【深度学习笔记】6_5 RNN的pytorch实现

注:本文为《动手学深度学习》开源内容,部分标注了个人理解,仅为个人学习记录,无抄袭搬运意图

6.5 循环神经网络的简洁实现

本节将使用PyTorch来更简洁地实现基于循环神经网络的语言模型。首先,我们读取周杰伦专辑歌词数据集。

import time
import math
import numpy as np
import torch
from torch import nn, optim
import torch.nn.functional as Fimport sys
sys.path.append("..") 
import d2lzh_pytorch as d2l
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')(corpus_indices, char_to_idx, idx_to_char, vocab_size) = d2l.load_data_jay_lyrics()

6.5.1 定义模型

PyTorch中的nn模块提供了循环神经网络的实现。下面构造一个含单隐藏层、隐藏单元个数为256的循环神经网络层rnn_layer

num_hiddens = 256
# rnn_layer = nn.LSTM(input_size=vocab_size, hidden_size=num_hiddens) # 已测试
rnn_layer = nn.RNN(input_size=vocab_size, hidden_size=num_hiddens)

与上一节中实现的循环神经网络不同,这里rnn_layer的输入形状为(时间步数, 批量大小, 输入个数)。其中输入个数即one-hot向量长度(词典大小)。此外,rnn_layer作为nn.RNN实例,在前向计算后会分别返回输出和隐藏状态h,其中输出指的是隐藏层在各个时间步上计算并输出的隐藏状态,它们通常作为后续输出层的输入。需要强调的是,该“输出”本身并不涉及输出层计算,形状为(时间步数, 批量大小, 隐藏单元个数)。而nn.RNN实例在前向计算返回的隐藏状态指的是隐藏层在最后时间步的隐藏状态:当隐藏层有多层时,每一层的隐藏状态都会记录在该变量中;对于像长短期记忆(LSTM),隐藏状态是一个元组(h, c),即hidden state和cell state。我们会在本章的后面介绍长短期记忆和深度循环神经网络。关于循环神经网络(以LSTM为例)的输出,可以参考下图(图片来源)。

在这里插入图片描述

循环神经网络(以LSTM为例)的输出

来看看我们的例子,输出形状为(时间步数, 批量大小, 隐藏单元个数),隐藏状态h的形状为(层数, 批量大小, 隐藏单元个数)。

num_steps = 35
batch_size = 2
state = None
X = torch.rand(num_steps, batch_size, vocab_size)
Y, state_new = rnn_layer(X, state)
print(Y.shape, len(state_new), state_new[0].shape)

输出:

torch.Size([35, 2, 256]) 1 torch.Size([2, 256])

如果rnn_layernn.LSTM实例,那么上面的输出是什么?

接下来我们继承Module类来定义一个完整的循环神经网络。它首先将输入数据使用one-hot向量表示后输入到rnn_layer中,然后使用全连接输出层得到输出。输出个数等于词典大小vocab_size

# 本类已保存在d2lzh_pytorch包中方便以后使用
class RNNModel(nn.Module):def __init__(self, rnn_layer, vocab_size):super(RNNModel, self).__init__()self.rnn = rnn_layerself.hidden_size = rnn_layer.hidden_size * (2 if rnn_layer.bidirectional else 1) self.vocab_size = vocab_sizeself.dense = nn.Linear(self.hidden_size, vocab_size)self.state = Nonedef forward(self, inputs, state): # inputs: (batch, seq_len)# 获取one-hot向量表示X = d2l.to_onehot(inputs, self.vocab_size) # X是个listY, self.state = self.rnn(torch.stack(X), state)# 全连接层会首先将Y的形状变成(num_steps * batch_size, num_hiddens),它的输出# 形状为(num_steps * batch_size, vocab_size)output = self.dense(Y.view(-1, Y.shape[-1]))return output, self.state

6.5.2 训练模型

同上一节一样,下面定义一个预测函数。这里的实现区别在于前向计算和初始化隐藏状态的函数接口。

# 本函数已保存在d2lzh_pytorch包中方便以后使用
def predict_rnn_pytorch(prefix, num_chars, model, vocab_size, device, idx_to_char,char_to_idx):state = Noneoutput = [char_to_idx[prefix[0]]] # output会记录prefix加上输出for t in range(num_chars + len(prefix) - 1):X = torch.tensor([output[-1]], device=device).view(1, 1)if state is not None:if isinstance(state, tuple): # LSTM, state:(h, c)  state = (state[0].to(device), state[1].to(device))else:   state = state.to(device)(Y, state) = model(X, state)if t < len(prefix) - 1:output.append(char_to_idx[prefix[t + 1]])else:output.append(int(Y.argmax(dim=1).item()))return ''.join([idx_to_char[i] for i in output])

让我们使用权重为随机值的模型来预测一次。

model = RNNModel(rnn_layer, vocab_size).to(device)
predict_rnn_pytorch('分开', 10, model, vocab_size, device, idx_to_char, char_to_idx)

输出:

'分开戏想暖迎凉想征凉征征'

接下来实现训练函数。算法同上一节的一样,但这里只使用了相邻采样来读取数据。

# 本函数已保存在d2lzh_pytorch包中方便以后使用
def train_and_predict_rnn_pytorch(model, num_hiddens, vocab_size, device,corpus_indices, idx_to_char, char_to_idx,num_epochs, num_steps, lr, clipping_theta,batch_size, pred_period, pred_len, prefixes):loss = nn.CrossEntropyLoss()optimizer = torch.optim.Adam(model.parameters(), lr=lr)model.to(device)state = Nonefor epoch in range(num_epochs):l_sum, n, start = 0.0, 0, time.time()data_iter = d2l.data_iter_consecutive(corpus_indices, batch_size, num_steps, device) # 相邻采样for X, Y in data_iter:if state is not None:# 使用detach函数从计算图分离隐藏状态, 这是为了# 使模型参数的梯度计算只依赖一次迭代读取的小批量序列(防止梯度计算开销太大)if isinstance (state, tuple): # LSTM, state:(h, c)  state = (state[0].detach(), state[1].detach())else:   state = state.detach()(output, state) = model(X, state) # output: 形状为(num_steps * batch_size, vocab_size)# Y的形状是(batch_size, num_steps),转置后再变成长度为# batch * num_steps 的向量,这样跟输出的行一一对应y = torch.transpose(Y, 0, 1).contiguous().view(-1)l = loss(output, y.long())optimizer.zero_grad()l.backward()# 梯度裁剪d2l.grad_clipping(model.parameters(), clipping_theta, device)optimizer.step()l_sum += l.item() * y.shape[0]n += y.shape[0]try:perplexity = math.exp(l_sum / n)except OverflowError:perplexity = float('inf')if (epoch + 1) % pred_period == 0:print('epoch %d, perplexity %f, time %.2f sec' % (epoch + 1, perplexity, time.time() - start))for prefix in prefixes:print(' -', predict_rnn_pytorch(prefix, pred_len, model, vocab_size, device, idx_to_char,char_to_idx))

使用和上一节实验中一样的超参数(除了学习率)来训练模型。

num_epochs, batch_size, lr, clipping_theta = 250, 32, 1e-3, 1e-2 # 注意这里的学习率设置
pred_period, pred_len, prefixes = 50, 50, ['分开', '不分开']
train_and_predict_rnn_pytorch(model, num_hiddens, vocab_size, device,corpus_indices, idx_to_char, char_to_idx,num_epochs, num_steps, lr, clipping_theta,batch_size, pred_period, pred_len, prefixes)

输出:

epoch 50, perplexity 10.658418, time 0.05 sec- 分开始我妈  想要你 我不多 让我心到的 我妈妈 我不能再想 我不多再想 我不要再想 我不多再想 我不要- 不分开 我想要你不你 我 你不要 让我心到的 我妈人 可爱女人 坏坏的让我疯狂的可爱女人 坏坏的让我疯狂的
epoch 100, perplexity 1.308539, time 0.05 sec- 分开不会痛 不要 你在黑色幽默 开始了美丽全脸的梦滴 闪烁成回忆 伤人的美丽 你的完美主义 太彻底 让我- 不分开不是我不要再想你 我不能这样牵着你的手不放开 爱可不可以简简单单没有伤害 你 靠着我的肩膀 你 在我
epoch 150, perplexity 1.070370, time 0.05 sec- 分开不能去河南嵩山 学少林跟武当 快使用双截棍 哼哼哈兮 快使用双截棍 哼哼哈兮 习武之人切记 仁者无敌- 不分开 在我会想通 是谁开没有全有开始 他心今天 一切人看 我 一口令秋软语的姑娘缓缓走过外滩 消失的 旧
epoch 200, perplexity 1.034663, time 0.05 sec- 分开不能去吗周杰伦 才离 没要你在一场悲剧 我的完美主义 太彻底 分手的话像语言暴力 我已无能为力再提起- 不分开 让我面到你 爱情来的太快就像龙卷风 离不开暴风圈来不及逃 我不能再想 我不能再想 我不 我不 我不
epoch 250, perplexity 1.021437, time 0.05 sec- 分开 我我外的家边 你知道这 我爱不看的太  我想一个又重来不以 迷已文一只剩下回忆 让我叫带你 你你的- 不分开 我我想想和 是你听没不  我不能不想  不知不觉 你已经离开我 不知不觉 我跟了这节奏 后知后觉 

小结

  • PyTorch的nn模块提供了循环神经网络层的实现。
  • PyTorch的nn.RNN实例在前向计算后会分别返回输出和隐藏状态。该前向计算并不涉及输出层计算。

注:除代码外本节与原书此节基本相同,原书传送门

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.mzph.cn/news/733534.shtml

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

Python操作Redis 各种数据类型

本文将深入探讨如何使用Python操作Redis&#xff0c;覆盖从基础数据类型到高级功能的广泛主题。无论是字符串、列表、散列、集合还是有序集合&#xff0c;我们将一一解析&#xff0c;同时提供丰富的代码示例帮助读者更好地理解和应用。除此之外&#xff0c;本文还将介绍Redis的…

【20240309】WORD宏设置批量修改全部表格格式

WORD宏设置批量修改全部表格格式 引言1. 设置表格文字样式2. 设置表格边框样式3. 设置所有表格边框样式为075pt4. 删除行参考 引言 这两周已经彻底变为office工程师了&#xff0c;更准确一点应该是Word工程师&#xff0c;一篇文档动不动就成百上千页&#xff0c;表格图片也是上…

STM32之串口中断接收UART_Start_Receive_IT

网上搜索了好多&#xff0c;都是说主函数增加UART_Receive_IT()函数来着&#xff0c;实际正确的是UART_Start_Receive_IT()函数。 —————————————————— 参考时间&#xff1a;2024年3月9日 Cube版本&#xff1a;STM32CubeMX 6.8.1版本 参考芯片&#xff1a…

Svg Flow Editor 原生svg流程图编辑器(二)

系列文章 Svg Flow Editor 原生svg流程图编辑器&#xff08;一&#xff09; 说明 这项目也是我第一次写TS代码哈&#xff0c;现在还被绕在类型中头昏脑胀&#xff0c;更新可能会慢点&#xff0c;大家见谅~ 目前实现的功能&#xff1a;1. 元件的创建、移动、形变&#xff1b;2…

【C语言】字符指针

在指针的类型中我们知道有一种指针类型为字符指针char* 一般使用&#xff1a; int main() { char ch w; char *pc &ch; *pc w; return 0; } 还有一种使用方式&#xff0c;如下&#xff1a; int main() { const char* pstr "hello bit.";//这⾥是把⼀个字…

plantUML使用指南之序列图

文章目录 前言一、序列图1.1 语法规则1.1.1 参与者1.1.2 生命线1.1.3 消息1.1.4 自动编号1.1.5 注释1.1.6 其它1.1.7 例子 1.2 如何画好 参考 前言 在软件开发、系统设计和架构文档编写过程中&#xff0c;图形化建模工具扮演着重要的角色。而 PlantUML 作为一种强大且简洁的开…

【stm32 外部中断】

中断&#xff1a;在主程序运行过程中&#xff0c;出现了特定的中断触发条件&#xff08;中断源&#xff09;&#xff0c;使得CPU暂停当前正在运行的程序&#xff0c;转而去处理中断程序&#xff0c;处理完成后又返回原来被暂停的位置继续运行 中断优先级&#xff1a;当有多个中…

LoadBalancer (本地负载均衡)

1.loadbalancer本地负载均衡客户端 VS Nginx服务端负载均衡区别 Nginx是服务器负载均衡&#xff0c;客户端所有请求都会交给nginx&#xff0c;然后由nginx实现转发请求&#xff0c;即负载均衡是由服务端实现的。 loadbalancer本地负载均衡&#xff0c;在调用微服务接口时候&a…

考研复习C语言初阶(4)+标记和BFS展开的扫雷游戏

目录 1. 一维数组的创建和初始化。 1.1 数组的创建 1.2 数组的初始化 1.3 一维数组的使用 1.4 一维数组在内存中的存储 2. 二维数组的创建和初始化 2.1 二维数组的创建 2.2 二维数组的初始化 2.3 二维数组的使用 2.4 二维数组在内存中的存储 3. 数组越界 4. 冒泡…

【Java JVM】Class 文件的加载

Java 虚拟机把描述类的数据从 Class 文件加载到内存, 并对数据进行校验, 转换解析和初始化, 最终形成可以被虚拟机直接使用的 Java 类型, 这个过程被称作虚拟机的类加载机制。 与那些在编译时需要进行连接的语言不同, 在 Java 语言里面, 类的加载, 连接和初始化过程都是在程序…

解决阿里云服务器开启frp服务端,内网服务器开启frp客户端却连接不上的问题

解决方法&#xff1a; 把阿里云自带的Alibabxxxxxxxlinux系统 换成centos 7系统&#xff01;&#xff01;&#xff01;&#xff01; 说一下我的过程和问题&#xff1a;由于我们内网的服务器在校外是不能连接的&#xff0c;因此我弄了个阿里云服务器做内网穿透&#xff0c;所谓…

大模型学习过程记录

一、基础知识 自然语言处理&#xff1a;能够让计算理解人类的语言。 检测计算机是否智能化的方法&#xff1a;图灵测试 自然语言处理相关基础点&#xff1a; 基础点1——词表示问题&#xff1a; 1、词表示&#xff1a;把自然语言中最基本的语言单位——词&#xff0c;将它转…

你应该打好你的日志,起码避免被甩锅

大家好&#xff0c;我是蓝胖子,相信大家或多或少都有这样的经历&#xff0c;当你负责的功能出现线上问题时&#xff0c;领导第一时间便是找到你询问原因&#xff0c;然而有时问题的根因或许不在你这儿&#xff0c;只是这个功能或许依赖了第三方或者内部其他部门&#xff0c;这个…

【Unity InputSystem】实用指南:在PC端(鼠标与键盘)、手机端(触摸屏)、主机手柄上同步实现角色移动与跳跃功能

前引 随着Unity的不断发展&#xff0c;开发者对于项目的输入系统要求也日益提高。在进行多平台适配和跨平台移植时&#xff0c;常常需要改变输入系统&#xff0c;这给开发者带来了不少困扰。而Unity官方推出的InputSystem插件&#xff0c;则是为了解决这一问题而推出的全新输入…

Linux内存管理--系列文章壹

一、引子 作者、我在上班闲着没事的时候&#xff0c;看了一些关于Linux内存管理和程序装载、链接的文章&#xff0c;然后自己就总结出了一些东西。 本系列文章一方面将资料中的长篇大论总结到最少、以方便可以直接找到答案&#xff0c;一方面也是方便面试的时候可以吹牛逼。 L…

【Docker】golang使用DockerFile正确食用指南

【Docker】golang使用DockerFile正确食用指南 大家好 我是寸铁&#x1f44a; 总结了一篇golang使用DockerFile正确食用指南✨ 喜欢的小伙伴可以点点关注 &#x1f49d; 问题背景 今天寸铁想让编写好的go程序在docker上面跑&#xff0c;要想实现这样的效果&#xff0c;就需要用…

小程序 van-field label和输入框改成上下布局

在组件上面加个样式就行&#xff1a;custom-style"display:block;" <van-field label"备注说明" type"textarea" clearable title-width"100px" custom-style"display:block;" placeholder"请输入" /> …

大载重无人机基础技术,研发一款50KG负重六旋翼无人机技术及成本分析

六旋翼无人机是一种多旋翼无人机&#xff0c;具有六个旋翼&#xff0c;通常呈“X”形布局。它采用电动串列式结构&#xff0c;具有垂直起降、悬停、前飞、后飞、侧飞、俯仰、翻滚等多种飞行动作的能力。六旋翼无人机通常被用于航拍、农业植保、环境监测、地形测绘等领域。 六旋…

Day34-Linux网络管理4

Day34-Linux网络管理4 1. IP地址分类与子网划分基础1.1 什么是IP地址1.2 十进制与二进制的转换1.3 IP地址的分类1.4 私网地址和局域网地址 2. 通信类型3. 子网划分讲解3.1 为什么要划分子网&#xff1f;3.2 什么是子网划分&#xff1f;3.3 子网划分的作用&#xff1f;3.4 子网划…

云计算项目十一:构建完整的日志分析平台

检查k8s集群环境&#xff0c;master主机操作&#xff0c;确定是ready 启动harbor [rootharbor ~]# cd /usr/local/harbor [rootharbor harbor]# /usr/local/bin/docker-compose up -d 检查head插件是否启动&#xff0c;如果没有&#xff0c;需要启动 [rootes-0001 ~]# system…