RNN预测下一句文本简单示例

根据句子前半句的内容推理出后半部分的内容,这样的任务可以使用循环的方式来实现。

RNN(Recurrent Neural Network,循环神经网络)是一种用于处理序列数据的强大神经网络模型。与传统的前馈神经网络不同,RNN能够通过其循环结构捕获序列内部的时间依赖性或顺序信息。

在RNN中,每个时间步(timestep)的隐藏状态不仅取决于当前输入,还与上一时间步的隐藏状态有关。这种递归特性使得网络能记忆过去的信息,并将其与当前输入相结合以做出决策或生成输出。

由于存在“梯度消失”和“梯度爆炸”的问题,在长序列建模时原始RNN可能效果不佳。因此,发展出了更复杂的变体,如LSTM(Long Short-Term Memory)和GRU(Gated Recurrent Units),它们通过门控机制更好地保留长期依赖信息。这些改进后的循环神经网络广泛应用于语音识别、自然语言处理(NLP)、机器翻译、视频分析等多种领域。

training_file = 'wordstest.txt' 在里面随便写入一些文章,当做数据,

具体代码如下,写了注释

import torch
import torch.nn.functional as F
import time
import random
import numpy as np
from collections import Counter# 确保每次结果可复现
RANDOM_SEED = 123
torch.manual_seed(RANDOM_SEED)DEVICE = torch.device('cuda' if torch.cuda.is_available() else 'cpu')def elapsed(sec):if sec<60:return str(sec) + " sec"elif sec<(60*60):return str(sec/60) + " min"else:return str(sec/(60*60)) + " hr"#中文多文件
def readalltxt(txt_files):labels = []for txt_file in txt_files:target = get_ch_lable(txt_file)labels.append(target)return labelsdef get_ch_lable(txt_file):"""读取数据:param txt_file::return:"""labels = ""with open(txt_file, 'rb') as f:for label in f:labels += label.decode('utf-8')#labels += label.decode('gb2312')return labelsdef get_ch_lable_v(txt_file, word_num_map, txt_label=None):"""字符转向量:param txt_file::param word_num_map::param txt_label::return:"""words_size = len(word_num_map)to_num = lambda word: word_num_map.get(word, words_size)if txt_file != None:txt_label = get_ch_lable(txt_file)labels_vector = list(map(to_num, txt_label))return labels_vector# 文本预处理,生成词向量
training_file = 'wordstest.txt'
training_data = get_ch_lable(training_file)
print("Loaded training data...")
print('样本长度:', len(training_data))
counter = Counter(training_data)
words = sorted(counter)
words_size= len(words)
word_num_map = dict(zip(words, range(words_size)))  # 给每个字构建索引,通过索引来处理计算预测每一个字
print('字表大小:', words_size)
wordlabel = get_ch_lable_v(training_file, word_num_map)'''
GRU 构建 RNN 模型
1、将输入的文字索引转为词嵌入
2、将词嵌入结果输入用 GRU 所形成的网络层
3、对步骤 2 的输出结果做全连接处理,得到维度为【字表长度】的预测结果,这个结果代表的是每个文字的频率
'''
class GRURNN(torch.nn.Module):def __init__(self, word_size, embed_dim,hidden_dim, output_size, num_layers):super(GRURNN, self).__init__()self.num_layers = num_layersself.hidden_dim = hidden_dimself.embed = torch.nn.Embedding(word_size, embed_dim)self.gru = torch.nn.GRU(input_size=embed_dim,hidden_size=hidden_dim,num_layers=num_layers, bidirectional=True)# bidirectional=True 代表网络是双向的,从前往后,从后往前# hidden_dim*2 代表包含了两个维度的层数# 全连接层(线性层),它将接收前面双向GRU输出的隐藏状态作为输入self.fc = torch.nn.Linear(hidden_dim*2, output_size)def forward(self, features, hidden):embedded = self.embed(features.view(1, -1))output, hidden = self.gru(embedded.view(1, 1, -1), hidden)output = self.fc(output.view(1, -1))return output, hiddendef init_zero_state(self):"""一个初始化隐藏状态的方法,主要用于循环神经网络(RNN)类的实例。这个方法的作用是为RNN创建一组全零初始隐藏状态。self.num_layers * 2: 表示双向RNN时的层数(如果模型是双向的,即参数bidirectional=True),因为每个方向都会有一个隐藏层,所以总共有num_layers * 2个隐藏层。1: 表示批量大小(batch size),在这里初始化的是单个样本的隐藏状态,因此设置为1。若需要处理批量数据,则应根据实际批量大小调整。self.hidden_dim: 表示隐藏层的维度(hidden dimension),也就是每个隐藏单元的特征数量。:return:"""init_hidden = torch.zeros(self.num_layers * 2, 1, self.hidden_dim).to(DEVICE)return init_hiddenEMBEDDING_DIM = 10  # 向量的维度或者说长度
HIDDEN_DIM = 20  # 每一个隐藏层的神经元数量
NUM_LAYERS = 1  # 隐藏层数量model = GRURNN(words_size, EMBEDDING_DIM, HIDDEN_DIM, words_size, NUM_LAYERS)
model = model.to(DEVICE)  # 将模型移动到指定设备上进行计算
# model.parameters():获取模型中所有需要优化的参数。
# Adam:是优化算法的一种,它基于梯度下降法,并结合了动量项(Momentum)和自适应学习率调整策略(RMSProp)。Adam通常在很多深度学习任务中表现良好,因为它能够自动调整学习率并减少对初始化学习率的敏感性。
# lr=0.005:表示设置学习率为0.005,这是Adam算法中的一个重要超参数,决定了每次更新参数时步伐的大小。
optimizer = torch.optim.Adam(model.parameters(), lr=0.005)def evaluate(model, prime_str, predict_len, temperature=0.8):"""评估函数:param model::param prime_str: 一个表示起始序列的整数列表,每个整数代表词汇表中的索引:param predict_len: 指定要预测的字符或单词数量:param temperature: 控制生成文本时随机性的一个超参数,较小的值会让模型更倾向于生成概率最高的结果,较大的值则会增加多样性:return:"""hidden = model.init_zero_state().to(DEVICE)predicted = ''# 处理输入语义# 将生成的字符添加到预测结果字符串 predicted 中for p in range(len(prime_str) - 1):_, hidden = model(prime_str[p], hidden)predicted += words[prime_str[p]]# 用最后一个输入字符开始进行预测inp = prime_str[-1]predicted += words[inp]for p in range(predict_len):output, hidden = model(inp, hidden)#从多项式分布中采样# 将模型输出转换为分布形式,通过除以温度 temperature 并求指数得到softmax分布output_dist = output.data.view(-1).div(temperature).exp()# 根据调整后的分布采样下一个字符的索引inp = torch.multinomial(output_dist, 1)[0]predicted += words[inp]return predicted#定义参数训练模型
training_iters = 5000
display_step = 1000
n_input = 4
step = 0
offset = random.randint(0, n_input+1)  # 每次迭代结束时,将偏移值向后移动 n_input+1 个距离,保证输入样本的相对均匀
end_offset = n_input + 1while step < training_iters:start_time = time.time()# 随机取一个位置偏移if offset > (len(training_data) - end_offset):offset = random.randint(0, n_input+1)# 取出偏移量为 4 的数据长度,因为文本时序列数据inwords = wordlabel[offset:offset + n_input]# [n_input, -1, 1] 表示重塑后的三维形状:# 第一维是序列长度(即每个序列有 n_input 个元素),# 第二维 -1 表示自动计算以适应原始数据大小,# 第三维为通道数(这里设为1,通常用于表示一维特征)inwords = np.reshape(np.array(inwords), [n_input, -1,  1])# 编码out_onehot = wordlabel[offset+1:offset+n_input+1]# 初始化隐藏层hidden = model.init_zero_state()'''模型完成一次前向传播计算并得到损失(loss)后,在反向传播(backpropagation)之前,需要调用这个函数来清零所有可训练参数的梯度。在开始新一轮的前向传播和反向传播之前,使用 optimizer.zero_grad() 来清零所有参数的梯度是至关重要的,确保每次优化步骤只基于当前批次数据计算出的梯度来进行参数更新'''optimizer.zero_grad()'''模型训练'''loss = 0.# 将输入数据 inwords 和目标数据 out_onehot 转换为PyTorch张量inputs, targets = torch.LongTensor(inwords).to(DEVICE), torch.LongTensor(out_onehot).to(DEVICE)for c in range(n_input):# 当前时间步的输入和前一时间步的隐藏状态运行模型,得到输出 (outputs) 和新的隐藏状态 (hidden)。outputs, hidden = model(inputs[c], hidden)# 计算当前时间步的交叉熵损失(Cross Entropy Loss),将模型预测的输出与实际的目标标签比较loss += F.cross_entropy(outputs, targets[c].view(1))# 所有时间步完成后,平均损失值loss /= n_input# 反向传播计算梯度:调用 .backward() 函数来计算关于损失函数关于模型参数的梯度loss.backward()# 使用优化器(在这里是 optimizer)根据计算出的梯度更新模型参数optimizer.step()#输出日志# with torch.set_grad_enabled(False): 这一上下文管理器用于在计算过程中暂时禁用梯度计算。这样,在打印损失、评估模型性能等操作时,# 不会占用额外的内存来存储中间计算的梯度,同时避免不必要的反向传播计算。with torch.set_grad_enabled(False):if (step+1) % display_step == 0:print(f'Time elapsed: {(time.time() - start_time)/60:.4f} min')print(f'step {step+1} | Loss {loss.item():.2f}\n\n')# torch.no_grad() 上下文管理器再次禁用梯度计算,以便于高效地进行模型评估,并且不影响之前或之后的梯度计算状态。with torch.no_grad():print(evaluate(model, inputs, 32), '\n')print(50*'=')step += 1offset += (n_input+1)#中间隔了一个,作为预测print("Finished!")# 使用模型
while True:prompt = "请输入几个字,最好是%s个: " % n_inputsentence = input(prompt)inputword = sentence.strip()try:inputword = get_ch_lable_v(None, word_num_map, inputword)keys = np.reshape(np.array(inputword), [len(inputword), -1, 1])'''调用 model.eval() 方法将模型设置为评估模式。在评估模式下,模型中的批量归一化层(如果有)会使用经过训练时平均的移动统计量,并且不会更新模型参数(梯度计算被禁用)。接下来,通过 with torch.no_grad(): 语句创建了一个临时上下文,在此上下文中执行所有操作时都不会累积梯度。这对于生成任务非常关键,因为在这种情况下我们并不关心反向传播以更新模型权重,而是要利用当前模型状态来生成文本。'''model.eval()with torch.no_grad():sentence = evaluate(model, torch.LongTensor(keys).to(DEVICE), 32)print(sentence)except:print("该字我还没学会")

运行结果类似:

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.mzph.cn/news/652180.shtml

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

32GPIO输入LED闪烁蜂鸣器

一.GPIO简介 所有的GPIO都挂载到APB2上&#xff0c;每个GPIO有&#xff11;&#xff16;个引脚 内核可以通过APB&#xff12;对寄存器进行读写&#xff0c;寄存器都是32位的&#xff0c;但每个引脚端口只有&#xff11;&#xff16;位 驱动器用于增加信号的驱动能力 二.具体…

【Go】深入理解 Go map:赋值和扩容迁移 ①

文章目录 map底层实现hmapbmap map hash冲突了怎么办&#xff1f; map扩容触发扩容时机扩容小结为什么map扩容选择增量&#xff08;渐进式扩容&#xff09;&#xff1f;迁移是逐步进行的。那如果在途中又要扩容了&#xff0c;怎么办&#xff1f; map翻倍扩容原理 map写入数据内…

数据库查询3

目录 1. 多表查询 1.1.1 介绍 1.1.2 分类 1.2 内连接 1.3 外连接 1.4 子查询 1.4.1 介绍 1.4.2 标量子查询 1.4.3 列子查询 1.4.4 行子查询 1.4.5 表子查询 2. 事务 2.1 操作 2.2 四大特性 数据库总结2 数据库总结1 1. 多表查询 1.1.1 介绍 多表查询&#xff…

研发日记,Matlab/Simulink避坑指南(七)——数据溢出钳位Bug

文章目录 前言 背景介绍 问题描述 分析排查 解决方案 总结归纳 前言 见《研发日记&#xff0c;Matlab/Simulink避坑指南(二)——非对称数据溢出Bug》 见《研发日记&#xff0c;Matlab/Simulink避坑指南(三)——向上取整Bug》 见《研发日记&#xff0c;Matlab/Simulink避坑…

C语言第十一弹---函数(下)

​ ✨个人主页&#xff1a; 熬夜学编程的小林 &#x1f497;系列专栏&#xff1a; 【C语言详解】 【数据结构详解】 函数 1、嵌套调用和链式访问 1.1、嵌套调用 1.2、链式访问 2、函数的声明和定义 2.1、单个文件 2.2、多个文件 2.3、static 和 extern 2.3.1、static…

【自然语言处理】【深度学习】文本向量化、one-hot、word embedding编码

因为文本不能够直接被模型计算&#xff0c;所以需要将其转化为向量 把文本转化为向量有两种方式&#xff1a; 转化为one-hot编码转化为word embedding 一、one-hot 编码 在one-hot编码中&#xff0c;每一个token使用一个长度为N的向量表示&#xff0c;N表示词典的数量。 即&…

dos攻击与ddos攻击的区别

①DOS攻击&#xff1a; DOS&#xff1a;中文名称是拒绝服务&#xff0c;一切能引起DOS行为的攻击都被称为dos攻击。该攻击的效果是使得计算机或网络无法提供正常的服务。常见的DOS攻击有针对计算机网络带宽和连通性的攻击。 DOS是单机于单机之间的攻击。 DOS攻击的原理&#…

【GitHub项目推荐--常见的国内镜像】【转载】

由于国内网络原因&#xff0c;下载依赖包或者软件&#xff0c;对于不少互联网从业者来说&#xff0c;都有不小的挑战&#xff0c;时间浪费在这上边&#xff0c;实在可惜。这个项目介绍了常见依赖&#xff0c;软件的国内镜像&#xff0c;助力大家畅爽编码。 这是一个归纳梳理类…

C# 将HTML网页、HTML字符串转换为PDF

将HTML转换为PDF可实现格式保留、可靠打印、文档归档等多种用途&#xff0c;满足不同领域和情境下的需求。本文将通过以下两个示例&#xff0c;演示如何使用第三方库Spire.PDF for .NET和QT插件在C# 中将Html 网页&#xff08;URL&#xff09;或HTML字符串转为PDF文件。 HTML转…

【C语言/数据结构】排序(选择排序,推排序,冒泡排序)

&#x1f308;个人主页&#xff1a;秦jh__https://blog.csdn.net/qinjh_?spm1010.2135.3001.5343&#x1f525; 系列专栏&#xff1a;《数据结构》https://blog.csdn.net/qinjh_/category_12536791.html?spm1001.2014.3001.5482 ​​​​ 目录 选择排序 选择排序 ​编辑…

js实现动漫拼图2.0版

比较与1.0版&#xff0c;2.0版就更像与华容道类似的拼图游戏&#xff0c;从头到尾都只能控制白色块移动&#xff0c;而且打乱拼图和求助的实现与1.0都不相同 文章目录 1 实现效果2 实现思路2.1 打乱拼图2.2 求助功能2.3 判赢 3 代码实现 js实现动漫拼图1.0版 https://blog.csdn…

python222网站实战(SpringBoot+SpringSecurity+MybatisPlus+thymeleaf+layui)-菜单管理实现

锋哥原创的SpringbootLayui python222网站实战&#xff1a; python222网站实战课程视频教程&#xff08;SpringBootPython爬虫实战&#xff09; ( 火爆连载更新中... )_哔哩哔哩_bilibilipython222网站实战课程视频教程&#xff08;SpringBootPython爬虫实战&#xff09; ( 火…

cmake-find_package链接第三方库

文章目录 基本调用形式和模块模式使用方式 之前我们是使用了绝对路径来链接OpenCV第三方库&#xff0c;但是现在很多库一般会自己写一些cmake文件提供给用户&#xff0c;用户可以直接使用其中的内置变量即可。使用的命令就是find_package。 基本调用形式和模块模式 find_packa…

【RTP】webrtc 学习2: webrtc对h264的rtp打包

切片只是拷贝帧的split的各个部分到新的rtp 包的封装中。并没有在rtp包本身标记是否为关键帧FU-A 切片 输入的H.264 数据进行split :SplitNalu SplitNalu : 按照最大1200字节进行切分 切分后会返回一个数组 对于FU-A :split的数据总大小是 去掉一个字节的nalu header size …

实战 | OpenCV+OCR实现弧形文字识别实例(详细步骤 + 源码)

导 读 本文主要介绍基于OpenCV+OCR实现弧形文字识别实例,并给详细步骤和代码。源码在文末。 背景介绍 测试图如下,目标是正确识别图中的字符。图片来源: https://www.51halcon.com/forum.php?mod=viewthread&tid=6712 同样,论坛中已经给出了Halcon实现代码,…

1948-2022年金融许可信息明细数据

1948-2022年金融许可信息明细数据 1、时间&#xff1a;1948-2022年 2、来源&#xff1a;银监会&#xff08;银监会许可证发布系统&#xff09; 3、指标&#xff1a;来源表、机构编码、机构名称、所属银行、机构类型、业务范围、机构住所、地理坐标、行政区划代码、所属区县、…

【计算机网络】深入掌握计算机网络的核心要点(面试专用)

写在前面 前言四层模型网络地址管理Linux下设置ipARP请求包总结 前言 计算机网络是指将分散的计算机设备通过通信线路连接起来&#xff0c;形成一个统一的网络。为了使得各个计算机之间能够相互通信&#xff0c;需要遵循一定的协议和规范。OSI参考模型和TCP/IP参考模型是计算机…

(南京观海微电子)——OLED驱动与调试

一、OLED DDIC分类 OLED DDIC的技术方向可以分为3类&#xff1a;带Ram【内存】的IC、Ram-less IC和TDDI【显示&触控集成的IC】 1、带Ram的OLED DDIC OLED DDIC有两个Ram&#xff0c;分别是Demura Ram和Display Ram。 1、带Ram的OLED DDIC 1-1&#xff09;Demura Ram&a…

一张图文深入了解信息量概念

通信原理第10页最后一段&#xff1a; 概率论告诉我们&#xff0c;事件的不确定程度可以用其出现的概率来描述。因此&#xff0c;消息中包含的信息量与消息发生的概率密切相关。消息出现的概率越小&#xff0c;则消息中包含的信息量就越大。 这句话怎么理解呢&#xff1f; 比如…

安利6款免费又高清的视频转GIF方法,值得收藏

前言 平时我们在聊天的时候会发的很多有趣表情包&#xff0c;其实有些就是视频里面的画面&#xff0c;觉得好玩有趣就被网友转换成了GIF&#xff0c;聊天的时候就可以用这些表情包来代表当时的心情。 如何将视频转成GIF动图&#xff1f;对于还不知道怎么将视频转成GIF的朋友&a…