8.6 循环神经网络的简洁实现

在这里插入图片描述
每个步长共用参数

加载数据

虽然 8.5节 对了解循环神经网络的实现方式具有指导意义,但并不方便。 本节将展示如何使用深度学习框架的高级API提供的函数更有效地实现相同的语言模型。 我们仍然从读取时光机器数据集开始。

import torch
from torch import nn
from torch.nn import functional as F
from d2l import torch as d2lbatch_size, num_steps = 32, 35
train_iter, vocab = d2l.load_data_time_machine(batch_size, num_steps)

定义模型

高级API提供了循环神经网络的实现。 我们构造一个具有256个隐藏单元的单隐藏层的循环神经网络层rnn_layer。 事实上,我们还没有讨论多层循环神经网络的意义(这将在 9.3节中介绍)。 现在仅需要将多层理解为一层循环神经网络的输出被用作下一层循环神经网络的输入就足够了。

  1. 设置一个单层 256个隐藏单元的神经网络 rnn
num_hiddens = 256
rnn_layer = nn.RNN(len(vocab), num_hiddens)

单隐藏层, 256个隐藏单元

隐藏层数(Number of Hidden Layers):
定义:隐藏层数指的是神经网络中介于输入层和输出层之间的层级数量。
作用:增加隐藏层数可以帮助神经网络学习更复杂的模式和特征,提高模型的表示能力。
影响:过多的隐藏层可能导致过拟合,而过少的隐藏层可能导致欠拟合。
选择:隐藏层数的选择通常需要根据具体问题的复杂度和数据集的特征来决定,可以通过交叉验证等方法来选择合适的层数。
隐藏单元数(Number of Hidden Units):
定义:隐藏单元数指的是每个隐藏层中包含的神经元数量。
作用:增加隐藏单元数可以增加神经网络的拟合能力,使其能够更好地学习数据的复杂关系。
影响:隐藏单元数的增加会增加模型的复杂度,可能导致过拟合;减少隐藏单元数可能导致模型欠拟合。
选择:隐藏单元数的选择通常需要在训练过程中进行调优,可以通过交叉验证或其他调参方法来确定合适的数量。
在实际应用中,选择合适的隐藏层数和隐藏单元数是一项重要的任务,需要根据具体问题的复杂度、数据集的特征以及计算资源等因素进行权衡和调整,以获得最佳的模型性能

  1. 我们使用张量来初始化隐状态,它的形状是(隐藏层数,批量大小,隐藏单元数)
state = torch.zeros((1, batch_size, num_hiddens))
state.shape
torch.Size([1, 32, 256])
  1. 通过一个隐状态和一个输入,我们就可以用更新后的隐状态计算输出。 需要强调的是,rnn_layer的“输出”(Y)不涉及输出层的计算: 它是指每个时间步的隐状态,这些隐状态可以用作后续输出层的输入。
X = torch.rand(size=(num_steps, batch_size, len(vocab)))
Y, state_new = rnn_layer(X, state)
Y.shape, state_new.shape
(torch.Size([35, 32, 256]), torch.Size([1, 32, 256]))
  1. 与 8.5节类似, 我们为一个完整的循环神经网络模型定义了一个RNNModel类。 注意,rnn_layer只包含隐藏的循环层,我们还需要创建一个单独的输出层。
#@save
class RNNModel(nn.Module):"""循环神经网络模型"""def __init__(self, rnn_layer, vocab_size, **kwargs):super(RNNModel, self).__init__(**kwargs)self.rnn = rnn_layerself.vocab_size = vocab_sizeself.num_hiddens = self.rnn.hidden_size# 如果RNN是双向的(之后将介绍),num_directions应该是2,否则应该是1if not self.rnn.bidirectional:self.num_directions = 1self.linear = nn.Linear(self.num_hiddens, self.vocab_size)else:self.num_directions = 2self.linear = nn.Linear(self.num_hiddens * 2, self.vocab_size)def forward(self, inputs, state):X = F.one_hot(inputs.T.long(), self.vocab_size)X = X.to(torch.float32)Y, state = self.rnn(X, state)# 全连接层首先将Y的形状改为(时间步数*批量大小,隐藏单元数)# 它的输出形状是(时间步数*批量大小,词表大小)。output = self.linear(Y.reshape((-1, Y.shape[-1]))) # 输出层return output, statedef begin_state(self, device, batch_size=1):if not isinstance(self.rnn, nn.LSTM):# nn.GRU以张量作为隐状态return  torch.zeros((self.num_directions * self.rnn.num_layers,batch_size, self.num_hiddens),device=device)else:# nn.LSTM以元组作为隐状态return (torch.zeros((self.num_directions * self.rnn.num_layers,batch_size, self.num_hiddens), device=device),torch.zeros((self.num_directions * self.rnn.num_layers,batch_size, self.num_hiddens), device=device))

训练与预测

在训练模型之前,让我们基于一个具有随机权重的模型进行预测。

device = d2l.try_gpu()
net = RNNModel(rnn_layer, vocab_size=len(vocab))
net = net.to(device)
d2l.predict_ch8('time traveller', 10, net, vocab, device)
'time travellerbbabbkabyg'

很明显,这种模型根本不能输出好的结果。 接下来,我们使用 8.5节中 定义的超参数调用train_ch8,并且使用高级API训练模型。

num_epochs, lr = 500, 1
d2l.train_ch8(net, train_iter, vocab, lr, num_epochs, device)
perplexity 1.3, 404413.8 tokens/sec on cuda:0
time travellerit would be remarkably convenient for the historia
travellery of il the hise fupt might and st was it loflers

在这里插入图片描述

#@save
def train_ch8(net, train_iter, vocab, lr, num_epochs, device,use_random_iter=False):"""训练模型(定义见第8章)"""loss = nn.CrossEntropyLoss()animator = d2l.Animator(xlabel='epoch', ylabel='perplexity',legend=['train'], xlim=[10, num_epochs])# 初始化if isinstance(net, nn.Module):updater = torch.optim.SGD(net.parameters(), lr)else:updater = lambda batch_size: d2l.sgd(net.params, lr, batch_size)predict = lambda prefix: predict_ch8(prefix, 50, net, vocab, device)# 训练和预测for epoch in range(num_epochs):ppl, speed = train_epoch_ch8(net, train_iter, loss, updater, device, use_random_iter)if (epoch + 1) % 10 == 0:print(predict('time traveller'))animator.add(epoch + 1, [ppl])print(f'困惑度 {ppl:.1f}, {speed:.1f} 词元/秒 {str(device)}')print(predict('time traveller'))print(predict('traveller'))
#@save
def train_epoch_ch8(net, train_iter, loss, updater, device, use_random_iter):"""训练网络一个迭代周期(定义见第8章)"""state, timer = None, d2l.Timer()metric = d2l.Accumulator(2)  # 训练损失之和,词元数量for X, Y in train_iter:if state is None or use_random_iter:# 在第一次迭代或使用随机抽样时初始化statestate = net.begin_state(batch_size=X.shape[0], device=device)else:if isinstance(net, nn.Module) and not isinstance(state, tuple):# state对于nn.GRU是个张量state.detach_()else:# state对于nn.LSTM或对于我们从零开始实现的模型是个张量for s in state:s.detach_()y = Y.T.reshape(-1)X, y = X.to(device), y.to(device)y_hat, state = net(X, state)l = loss(y_hat, y.long()).mean()if isinstance(updater, torch.optim.Optimizer):updater.zero_grad()l.backward()grad_clipping(net, 1)updater.step()else:l.backward()grad_clipping(net, 1)# 因为已经调用了mean函数updater(batch_size=1)metric.add(l * y.numel(), y.numel())return math.exp(metric[0] / metric[1]), metric[1] / timer.stop()

小结

  • 深度学习框架的高级API提供了循环神经网络层的实现。

  • 高级API的循环神经网络层返回一个输出和一个更新后的隐状态,我们还需要计算整个模型的输出层。

  • 相比从零开始实现的循环神经网络,使用高级API实现可以加速训练。

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.mzph.cn/news/785524.shtml

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

#include<初见C语言之指针(5)>

目录 一、sizeof和strlen的对比 1. sizeof 2.strlen 二、数组和指针题解析 1. ⼀维数组 1.1数组名理解 2.字符数组 3. ⼆维数组 三、指针运算题解析 总结 一、sizeof和strlen的对比 1. sizeof 我们前面介绍过sizeof是单目操作符 sizeof括号中有表达式,不…

解决PATH变量污染的问题

文章目录 解决PATH变量污染的问题概述笔记清空PATH变量之后的系统设置在命令行查看清空后的PATH变量以 gitea-1.17.1-gogit-windows-4.0-amd64.exe 为例以系统命令 where为例run_vs2019.bat备注 - 批处理的后缀最好是batEND 解决PATH变量污染的问题 概述 随着不断安装新软件,…

一文彻底搞懂如何创建线程

文章目录 1. java创建线程(Thread)方式2. 继承 Thread 类3. 实现 Runnable 接口4. 实现 Callable 接口5. 使用线程池6. 使用匿名类 1. java创建线程(Thread)方式 方式一:继承于Thread类,这是最常见的创建线程的方式之一,通过继承 Thread 类并…

BeanDefinition

这里写目录标题 BeanDefinitionBeanDefinitionReaderAbstractBeanDefinitionReaderXmlBeanDefinitionReader BeanDefinition 上述Spring的基本运行中,你所有的定义描述信息都在XML文件里面,如何读取呢? 当然通过 new ClassPathXmlApplicatio…

常见微服务的组件?

注册中心:就是一个服务注册的地方,我们可以把拆分的服务注册到注册中心,这样注册中心就能管理这些服务,服务之间的调用就会很方便,通过服务名就能相互调用。 负载均衡:被调用放的负载均衡,比如…

【智能算法】黄金正弦算法(GSA)原理及实现

目录 1.背景2.算法原理2.1算法思想2.2算法过程 3.结果展示4.参考文献 1.背景 2017年,Tanyildizi等人受到正弦函数单位圆内扫描启发,提出了黄金正弦算法(Golden Sine Algorithm, GSA)。 2.算法原理 2.1算法思想 GSA来源于正弦函…

FreeBSD下路由问题留档

碰到了一个非常神奇的事情,一台笔记本有以太网和wifi,都可以连到同一台路由器,ip地址配置在同一网段,但是如果插上网线,再拔掉网线的话,那么wifi即使连上,也无法上网。 看路由信息,发…

【Spring源码分析】透过源码看透Spring事务

阅读此需阅读下面这些博客先【Spring源码分析】Bean的元数据和一些Spring的工具【Spring源码分析】BeanFactory系列接口解读【Spring源码分析】执行流程之非懒加载单例Bean的实例化逻辑【Spring源码分析】从源码角度去熟悉依赖注入(一)【Spring源码分析】…

书生浦语笔记一

2023年6月,InternLM的第一代大模型正式发布。仅一个月后,该模型以及其全套工具链被开源。随后,在8月份,多模态语料库chat7B和lagent也被开源。而在接下来的9月份,InternLM20B的开源发布进一步加强了全线工具链的更新。…

所有企业都在用的微服务框架,需要多强的服务集成能力?

在数字化时代,随着业务规模的扩大和系统复杂性的增加,传统的单体应用架构由于其固有的局限性,已无法高效支撑企业日益增长的业务需求。 为了突破这一瓶颈,微服务架构以其独特的优势崭露头角,逐渐成为企业数字化转型的…

腾讯云容器与Serverless的融合:探索《2023技术实践精选集》中的创新实践

腾讯云容器与Serverless的融合:探索《2023技术实践精选集》中的创新实践 文章目录 腾讯云容器与Serverless的融合:探索《2023技术实践精选集》中的创新实践引言《2023腾讯云容器和函数计算技术实践精选集》整体评价特色亮点分析Serverless与Kubernetes的…

【刷题】 二分查找入门

送给大家一句话: 总有一天,你会站在最亮的地方,活成自己曾经渴望的模样—— 苑子文 & 苑子豪《我们都一样 年轻又彷徨》 二分查找入门 1 前言2 Leetcode 704. 二分查找2.1 题目描述2.2 算法思路 3 Leetcode 34. 在排序数组中查找元素的第一个和最后…

学习笔记——C语言基本概念指针(下)——(8)

1.指针和数组 数组指针 -- 指向数组的指针。 指针数组 -- 数组的元素都是指针。 换句话理解就是:数组指针就是个指针,指针数组就是个数组。 1.1数组指针 数组指针:指向数组的指针; 先回顾一下数组的特点: 1.相…

【C语言】联合体、枚举: 联合体与结构体区别,枚举的优点

目录 1、联合体 1.1、什么是联合体 1.2、联合体的声明 1.3、联合体的特点 1.4、联合体与结构体区别 1.5、联合体的大小 2、枚举 2.1、枚举类型的声明 2.2、枚举类型的优点 3、三种自定义类型:结构体、联合体、枚举 正文 1、联合体 1.1、什么是联合体 联…

如何使用route-detect在Web应用程序路由中扫描身份认证和授权漏洞

关于route-detect route-detect是一款功能强大的Web应用程序路由安全扫描工具,该工具可以帮助广大研究人员在Web应用程序路由中轻松识别和检测身份认证漏洞和授权漏洞。 Web应用程序HTTP路由中的身份认证(authn)和授权(authz&…

题目:小明的背包5(蓝桥OJ 1178)

问题描述&#xff1a; 解题思路&#xff1a; 分组背包模板题&#xff0c;与优化01背包的不同之处在于第一维不可省略&#xff0c;要写s循环。注意要初始化 #include <bits/stdc.h> using namespace std; const int N 1e3 9; int dp[N][N]; // 分组背包模板&#xff0c;…

代码随想录阅读笔记-二叉树【平衡二叉树】

题目 给定一个二叉树&#xff0c;判断它是否是高度平衡的二叉树。 本题中&#xff0c;一棵高度平衡二叉树定义为&#xff1a;一个二叉树每个节点 的左右两个子树的高度差的绝对值不超过1。 示例 1: 给定二叉树 [3,9,20,null,null,15,7] 返回 true 。 示例 2: 给定二叉树 [1,2,…

ZKFair 创新之旅,新阶段如何塑造财富前景

在当前区块链技术的发展中&#xff0c;Layer 2&#xff08;L2&#xff09;解决方案已成为提高区块链扩容性、降低交易成本和提升交易速度的关键技术&#xff0c;但它仍面临一些关键问题和挑战&#xff0c;例如用户体验的改进、跨链互操作性、安全性以及去中心化程度。在这些背景…

Unity中UI系统1——GUI

介绍 工作原理和主要作用 基本控件 a.文本和按钮控件 练习&#xff1a; b.多选框和单选框 练习&#xff1a; 用的是第三种方法 c.输入框和拖动框 练习&#xff1a; 练习二&#xff1a; e.图片绘制和框 练习&#xff1a; 复合控件 a.工具栏和选择网格 练习&#xff1a; b.滚动视…

纷享销客如何向生态型CRM进化 创始人罗旭给出了答案

自己挣1块钱时&#xff0c;渠道合作伙伴能够挣1块甚至更多。这是纷享销客与生态共建之道。 2024年纷享销客北方战区渠道生态伙伴发展共建会于日前在北京举行。在这场主题为“聚力纷享共赢巅峰”的大会上&#xff0c;各方探讨了企业高质量增长之源与SaaS行业渠道发展之路&#…