pytorch之诗词生成3--utils

先上代码:


import numpy as np
import settingsdef generate_random_poetry(tokenizer, model, s=''):"""随机生成一首诗:param tokenizer: 分词器:param model: 用于生成古诗的模型:param s: 用于生成古诗的起始字符串,默认为空串:return: 一个字符串,表示一首古诗"""# 将初始字符串转成tokentoken_ids = tokenizer.encode(s)# 去掉结束标记[SEP]token_ids = token_ids[:-1]while len(token_ids) < settings.MAX_LEN:# 进行预测,只保留第一个样例(我们输入的样例数只有1)的、最后一个token的预测的、不包含[PAD][UNK][CLS]的概率分布output = model(np.array([token_ids, ], dtype=np.int32))_probas = output.numpy()[0, -1, 3:]del output# print(_probas)# 按照出现概率,对所有token倒序排列p_args = _probas.argsort()[::-1][:100]# 排列后的概率顺序p = _probas[p_args]# 先对概率归一p = p / sum(p)# 再按照预测出的概率,随机选择一个词作为预测结果target_index = np.random.choice(len(p), p=p)target = p_args[target_index] + 3# 保存token_ids.append(target)if target == 3:breakreturn tokenizer.decode(token_ids)def generate_acrostic(tokenizer, model, head):"""随机生成一首藏头诗:param tokenizer: 分词器:param model: 用于生成古诗的模型:param head: 藏头诗的头:return: 一个字符串,表示一首古诗"""# 使用空串初始化token_ids,加入[CLS]token_ids = tokenizer.encode('')token_ids = token_ids[:-1]# 标点符号,这里简单的只把逗号和句号作为标点punctuations = [',', '。']punctuation_ids = {tokenizer.token_to_id(token) for token in punctuations}# 缓存生成的诗的listpoetry = []# 对于藏头诗中的每一个字,都生成一个短句for ch in head:# 先记录下这个字poetry.append(ch)# 将藏头诗的字符转成token idtoken_id = tokenizer.token_to_id(ch)# 加入到列表中去token_ids.append(token_id)# 开始生成一个短句while True:# 进行预测,只保留第一个样例(我们输入的样例数只有1)的、最后一个token的预测的、不包含[PAD][UNK][CLS]的概率分布output = model(np.array([token_ids, ], dtype=np.int32))_probas = output.numpy()[0, -1, 3:]del output# 按照出现概率,对所有token倒序排列p_args = _probas.argsort()[::-1][:100]# 排列后的概率顺序p = _probas[p_args]# 先对概率归一p = p / sum(p)# 再按照预测出的概率,随机选择一个词作为预测结果target_index = np.random.choice(len(p), p=p)target = p_args[target_index] + 3# 保存token_ids.append(target)# 只有不是特殊字符时,才保存到poetry里面去if target > 3:poetry.append(tokenizer.id_to_token(target))if target in punctuation_ids:breakreturn ''.join(poetry)

我们首先看第一个函数generate_random_poetry,用于随机生成古诗。

def generate_random_poetry(tokenizer, model, s=''):"""随机生成一首诗:param tokenizer: 分词器:param model: 用于生成古诗的模型:param s: 用于生成古诗的起始字符串,默认为空串:return: 一个字符串,表示一首古诗"""

 tokenizer是一个分词器,用于将文本转化为模型可以理解的token序列。

# 将初始字符串转成tokentoken_ids = tokenizer.encode(s)# 去掉结束标记[SEP]token_ids = token_ids[:-1]

是将我们传入的字符串转化为id序列。并使用切片操作切除token序列的最后一个元素。这个元素是用来标记结束标志[SEP]的。

我们接着看:

 while len(token_ids) < settings.MAX_LEN:# 进行预测,只保留第一个样例(我们输入的样例数只有1)的、最后一个token的预测的、不包含[PAD][UNK][CLS]的概率分布output = model(np.array([token_ids, ], dtype=np.int32))_probas = output.numpy()[0, -1, 3:]del output

这里我们就可以看出为什么我们要对最后一个 token进行切片处理了,是为了控制预测的长度。

之后我们进行预测,通过model模型对当前的token_ids序列进行预测,model接受一个形状为(batch_size,sequence_length)的输入,这里batch_size为1,sequence_length为token_ids的长度。因此,我们将token_ids包装成一个形状为(1,sequence_length)的numpy数组,并将其传递给model进行预测。

输出的结果是一个3D数组,形状为(1,sequence_length,vocab_size)。其中vocab_size是词汇表,得到的是对下一个词的预测。

然后通过output.numpy()将输出转化为numpy数组,由于我们只关注当前序列的最后一个token的预测结果,所以使用[0,-1,3:]来获取这部分概率分布。具体的,[0,-1]表示获取第一个样例的最后一个token的预测,而[3:]表示去除前三个token,即[PAD],[UNK]和[CLS]。

最后,通过del output释放output变量所占用的内存空间,这样做是为了及时释放内存,避免内存占用过多。

解释一下四个特殊的token在自然语言处理任务中的不同作用。

  • [PAD](填充):在序列中用于填充长度不足的部分,使得序列长度统一,在很多深度学习模型中,要求输入的序列长度是相同的,因此当序列长度不足时,可以使用[PAD]进行填充,使得序列达到统一的长度。
  • [UNK](未知):用于表示模型未见过的或者无法识别的词汇,当模型在处理文本时遇到不在词汇表中的词汇,就会用[UNK]表示,这样可以避免模型对于未知词汇的处理错误。
  • [CLS](分类):在自然语言处理中的一些任务中,例如文本分类,文本语义相似度等,[CLS]被用作特殊的起始标记。模型会将[CLS]作为整个序列的表示,并用于后续任务的处理。例如:对于文本分类任务,模型可以根据[CLS]表示进行预测分类标签。
  • [SEP](分隔符):用于表示序列的分隔,常见于处理多个句子或多个文本之间的任务。他可以用来分隔句子,句子对,文本片段等,在一些任务中,如文本对分类,问答系统等,输入可能包含多个句子或文本片段,模型可以通过识别[SEP]来理解不同句子之间的关系。另外,[SEP]还可以用于标记序列的结束,在一些 情况下,序列的长度是固定的,使用[SEP]标记还可以表示序列的结束。

继续看我们的代码:

 # 按照出现概率,对所有token倒序排列p_args = _probas.argsort()[::-1][:100]# 排列后的概率顺序p = _probas[p_args]# 先对概率归一p = p / sum(p)# 再按照预测出的概率,随机选择一个词作为预测结果target_index = np.random.choice(len(p), p=p)target = p_args[target_index] + 3# 保存token_ids.append(target)

 首先,通过argsort()函数对得到的所有outputs进行从小到大的排序,然后经过[::-1]将排序结果进行反转,这样就变成了从大到小的排序,[:100]表示只选择排名前100的token。返回的是_probas中元素从小到大排序的索引数组。

得到索引数组之后,我们通过p=_probas[p_args],根据排名靠前的token的索引p_args,从预测的概率分布_probas中获取对应的概率值,这样得到的p就是排列后的概率顺序(概率值从大到小)。

之后对概率进行归一化,将概率值除以概率之和,使其概率和为1。

target_index = np.random.choice(len(p), p=p)根据参数p(给定的概率分布)在长度为len(p)的范围内进行随机选择,并返回选择的索引target_index。(显然,元素对应的p值越大,被选择的概率就越高)。

后面的target = p_args[target_index] + 3,其中+3是必不可少的,我刚刚因为没有+3就运行报错了,因为我们得到的target是对应于token_ids的索引,而token_ids的前三项是特殊字符,为了和实际中的模型相对应,我们需要将预测的token编号做一些偏离处理。

最后我们将我们的预测列表中加上我们预测的词的索引值,也就是target。

看着一段:

 if target == 3:break
return tokenizer.decode(token_ids)

如果生成的target的值是3,表示生成的token是特殊的[CLS]标记,这意味着生成的序列已结束。
返回我们解码之后的结果也就是生成的诗词。等于3的情况是存在的,3意味着结束标志SPE,当句子的长度合适时,其预测的下一个词很可能是3,我们规定3也属于可预测范围。

看生成藏头诗的代码:

def generate_acrostic(tokenizer, model, head):"""随机生成一首藏头诗:param tokenizer: 分词器:param model: 用于生成古诗的模型:param head: 藏头诗的头:return: 一个字符串,表示一首古诗"""# 使用空串初始化token_ids,加入[CLS]token_ids = tokenizer.encode('')token_ids = token_ids[:-1]# 标点符号,这里简单的只把逗号和句号作为标点punctuations = [',', '。']punctuation_ids = {tokenizer.token_to_id(token) for token in punctuations}# 缓存生成的诗的listpoetry = []# 对于藏头诗中的每一个字,都生成一个短句for ch in head:# 先记录下这个字poetry.append(ch)# 将藏头诗的字符转成token idtoken_id = tokenizer.token_to_id(ch)# 加入到列表中去token_ids.append(token_id)# 开始生成一个短句while True:# 进行预测,只保留第一个样例(我们输入的样例数只有1)的、最后一个token的预测的、不包含[PAD][UNK][CLS]的概率分布output = model(np.array([token_ids, ], dtype=np.int32))_probas = output.numpy()[0, -1, 3:]del output# 按照出现概率,对所有token倒序排列p_args = _probas.argsort()[::-1][:100]# 排列后的概率顺序p = _probas[p_args]# 先对概率归一p = p / sum(p)# 再按照预测出的概率,随机选择一个词作为预测结果target_index = np.random.choice(len(p), p=p)target = p_args[target_index] + 3# 保存token_ids.append(target)# 只有不是特殊字符时,才保存到poetry里面去if target > 3:poetry.append(tokenizer.id_to_token(target))if target in punctuation_ids:breakreturn ''.join(poetry)

我们看一下tokenizer.id_to_token(3,2,1,0)

这段代码和上面的随机生成古诗类似,不同的是:

上面的是完整生成一首古诗,其中就算遇到“,”或者“。”不会中断,直到遇到结束符号[SEP]才会中止。

而我们生成的藏头诗是对每一个字进行生成,这里遇到“,”,“。”就进行下一个字的继续生成。然后将每个字的自动生成进行拼接。得到完整的诗句。(当然,遇到[SEP]也是进行终止的,保证了句式的协调)。

return ''.join(poetry)

作用是将我们的peotry中的内容(短句),连接成一个字符串。并将其作为函数值返回。

如果不加join函数的话,输出是一个列表:

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.mzph.cn/news/745658.shtml

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

linux 安装/升级 svn

文章目录 下载最新版本安装包安装 下载最新版本安装包 wget https://dlcdn.apache.org/subversion/subversion-1.14.3.tar.gz tar -zxf subversion-1.14.3.tar.gz cd subversion-1.14.3 安装 ./configure 报错&#xff0c;提示缺少 apr-util 库&#xff0c;有的环境可能 apr 库…

人工智能|机器学习——CURE聚类算法(层次聚类)

1.CURE聚类概述 绝大多数聚类算法或者擅长处理球形和相似大小的聚类&#xff0e;或者在存在孤立点时变得比较脆弱。CURE采用了一种新颖的层次聚类算法&#xff0e;该算法选择基于质心和基于代表对象方法之间的中间策略。它不同于单个质心或对象来代表一个类&#xff0c;而是选择…

大话设计模式——6.工厂方法模式(Factory Method Pattern)

1.介绍 工厂方法模式也称工厂模式&#xff0c;是简单工厂模式的进一步抽象。定义一个用于创建对象的接口&#xff0c;使一个类的实例化延迟到其子类&#xff0c;让子类决定实例化哪个类。通过工厂父类定义负责创建产品的公共接口&#xff0c;通过子类确定所需要创建的类型。 属…

《ElementPlus 与 ElementUI 差异集合》el-input 多包裹一层 el-input__wrapper

差异 element-ui el-input 中&#xff0c;<div class"el-input"> 下一级就是 <input> 标签 &#xff1b;element-plus el-input中&#xff0c;<div class"el-input"> 和 <input> 标签之间多了一层 <div class"el-input__…

Nginx、LVS、HAProxy工作原理和负载均衡架构

当前大多数的互联网系统都使用了服务器集群技术&#xff0c;集群是将相同服务部署在多台服务器上构成一个集群整体对外提供服务&#xff0c;这些集群可以是 Web 应用服务器集群&#xff0c;也可以是数据库服务器集群&#xff0c;还可以是分布式缓存服务器集群等等。 在实际应用…

提升零售行业竞争力的信息抽取技术应用与实践

一、引言 在当今快速发展的零售行业中&#xff0c;沃尔玛、家乐福等大型连锁超市为消费者提供了丰富的日常食品和日用品。为了进一步提升客户体验和优化库存管理&#xff0c;这些零售巨头纷纷开始探索和应用先进的信息抽取技术。 本文将深入探讨一个成功的信息抽取项目&#…

使用HttpRequest工具类调用第三方URL传入普通以及文件参数并转换MultipartFile成File

使用HttpRequest工具类调用第三方URL传入普通以及文件参数 一、依赖及配置二、代码1、模拟第三方服务2、调用服务3、效果实现 一、依赖及配置 <!--工具依赖--><dependency><groupId>cn.hutool</groupId><artifactId>hutool-all</artifactId&g…

嵌入式单片机学习思路感想分享

今天看到了一个提问,原话如下: 曾经干了8年单片机工程师,对工程师从入门,到入行,再到普通,再到高级,整个路径还算清晰,比如什么阶段,会碰到什么瓶颈,怎么突破,我都经历过。 这个同学,有个典型的问题,就是学得太多且杂了,估计稍微复杂点的项目,做不出来。 现在…

Spring Boot 中@Scheduled是单线程还是多线程?

在开发Spring Boot应用程序时&#xff0c;定时任务是一项常见的需求。Spring Boot提供了Scheduled注解&#xff0c;可用于将方法标记为定时任务&#xff0c;并在预定的时间间隔内执行。那么Scheduled注解的执行方式是单线程执行&#xff0c;还是多线程执行&#xff1f;Schedule…

docker小白第十二天

docker小白第十二天 docker network简介 docker不启动时默认的网络情况。 # 停止docker服务 systemctl stop docker.socket systemctl stop docker # 查看docker镜像 docker images输入查看docker镜像命令后&#xff0c;显示未连接到docker服务器 docker启动时网络情况 sy…

使用BBDown下载bilibili视频的方法

一款命令行式哔哩哔哩下载器. Bilibili Downloader. 下载地址 https://github.com/nilaoda/BBDown 功能 番剧下载(Web|TV|App) 课程下载(Web) 普通内容下载(Web|TV|App) 合集/列表/收藏夹/个人空间解析 多分P自动下载 选择指定分P进行下载 选择指定清晰度进行下载 下载外挂字幕…

html--bug

文章目录 html html <!DOCTYPE html> <html><head><meta charset"UTF-8"><title>老师</title><style>body {background-color: #008000;margin: 0px;cursor: none;overflow: hidden;}</style></head><bod…

88. 合并两个有序数组 (Swift版本)

题目 给你两个按 非递减顺序 排列的整数数组 nums1 和 nums2&#xff0c;另有两个整数 m 和 n &#xff0c;分别表示 nums1 和 nums2 中的元素数目。 请你 合并 nums2 到 nums1 中&#xff0c;使合并后的数组同样按 非递减顺序 排列。 注意&#xff1a;最终&#xff0c;合并…

深度解析Java JDK 1.8中Stream流的源码实现:带你探寻数据流的奥秘

文章目录 一、 Stream流概述1.1 什么是Stream流&#xff0c;以及它的主要特点和优势1.2 Stream流的基本操作&#xff1a;过滤、映射、排序等 二、 Stream流源码解析2.1 接口和基本概念2.2 创建流2.3 源码分析2.3.1 流的起始2.3.2 流的初始2.3.3 认识BaseStream2.3.4 Stream接口…

java-ssm-jsp基于java的餐厅点餐系统的设计与实现

java-ssm-jsp基于java的餐厅点餐系统的设计与实现 获取源码——》公主号&#xff1a;计算机专业毕设大全

RabbitMQ:1.概述及安装

概述 AMQP协议 MQ Message Queue&#xff08;消息队列&#xff09;是在消息的传输过程中保存消息的容器&#xff0c;多用于系统之间的异步通信 AMQP Advanced Message Queuing Protocol(高级消息队列协议)是一个网络协议&#xff0c;2006年AMQP规范发布【类比HTTP】 专门为消…

基于单片机的公交车IC卡操作系统的设计

目 录 摘 要 III Abstract IV 前 言 1 第一章 绪论 2 1.1 设计的背景和意义 2 1.2 设计的现状和发展 2 1.3 设计的目的与意义 2 第二章 总体设计 4 2.1 总体方案的设计与实现 4 2.1.1 主要设计的内容 4 2.1.2 系统的总体设计 4 2.2 系统方案论证 5 2.2.1 单片机的选择 6 2.2.2…

Ubuntu虚拟机的IP总频繁变化,导致Xshell断开连接

文章目录 一、IP变化的原因二、解决方法&#xff1a;固定IP三、参考文章 一、IP变化的原因 1.DHCP协议 虚拟机系统(Ubuntu、CentOS、UOS等Linux系统)启动后&#xff0c;加入本地局域网网络时&#xff0c;会向本地网络申请租约一个IP地址&#xff0c;租约时长不定。我这里租约时…

独立开发的轻量级简洁开源论坛BBS PHP源码 – 2023新版发布

最新的轻量级开源论坛php源码发布啦&#xff01;这是一款独立开发的论坛系统&#xff0c;可以帮助你快速地开发出你想要的网站。 如果你是PHP初学者&#xff0c;这款论坛系统非常适合你入门学习。不过&#xff0c;需要注意的是&#xff0c;由于它并没有进行商业化改造&#xf…

如何正确地设置Outlook SMTP发送电子邮件?

Outlook SMTP发送邮件配置方法&#xff1f;Outlook怎么开启SMTP&#xff1f; 在使用Outlook发送邮件时&#xff0c;正确设置SMTP服务器是确保邮件能够顺利发送的关键步骤。接下来&#xff0c;就让AokSend一起探讨如何正确地设置Outlook SMTP发送电子邮件吧&#xff01; Outlo…