《动手学深度学习 Pytorch版》 8.6 循环神经网络的简洁实现

import torch
from torch import nn
from torch.nn import functional as F
from d2l import torch as d2lbatch_size, num_steps = 32, 35
train_iter, vocab = d2l.load_data_time_machine(batch_size, num_steps)

8.6.1 定义模型

num_hiddens = 256
rnn_layer = nn.RNN(len(vocab), num_hiddens)
state = torch.zeros((1, batch_size, num_hiddens))
state.shape  # (隐藏层数,批量大小,隐藏单元数)
torch.Size([1, 32, 256])

通过一个隐状态和一个输入可以用更新后的隐状态计算输出。

需要强调的是,rnn_layer的“输出”(Y)不涉及输出层的计算:它是指每个时间步的隐状态,这些隐状态可以用作后续输出层的输入。

X = torch.rand(size=(num_steps, batch_size, len(vocab)))
Y, state_new = rnn_layer(X, state)
Y.shape, state_new.shape
(torch.Size([35, 32, 256]), torch.Size([1, 32, 256]))
#@save
class RNNModel(nn.Module):"""循环神经网络模型"""def __init__(self, rnn_layer, vocab_size, **kwargs):super(RNNModel, self).__init__(**kwargs)self.rnn = rnn_layerself.vocab_size = vocab_sizeself.num_hiddens = self.rnn.hidden_size# 如果RNN是双向的(之后将介绍),num_directions应该是2,否则应该是1if not self.rnn.bidirectional:self.num_directions = 1self.linear = nn.Linear(self.num_hiddens, self.vocab_size)else:self.num_directions = 2self.linear = nn.Linear(self.num_hiddens * 2, self.vocab_size)def forward(self, inputs, state):X = F.one_hot(inputs.T.long(), self.vocab_size)X = X.to(torch.float32)Y, state = self.rnn(X, state)# 全连接层首先将Y的形状改为(时间步数*批量大小,隐藏单元数)# 它的输出形状是(时间步数*批量大小,词表大小)。output = self.linear(Y.reshape((-1, Y.shape[-1])))return output, statedef begin_state(self, device, batch_size=1):if not isinstance(self.rnn, nn.LSTM):# nn.GRU以张量作为隐状态return  torch.zeros((self.num_directions * self.rnn.num_layers,batch_size, self.num_hiddens),device=device)else:# nn.LSTM以元组作为隐状态return (torch.zeros((self.num_directions * self.rnn.num_layers,batch_size, self.num_hiddens), device=device),torch.zeros((self.num_directions * self.rnn.num_layers,batch_size, self.num_hiddens), device=device))

8.6.2 训练与预测

device = d2l.try_gpu()
net = RNNModel(rnn_layer, vocab_size=len(vocab))
net = net.to(device)
d2l.predict_ch8('time traveller', 10, net, vocab, device)
'time travellerffffffffff'
num_epochs, lr = 500, 1
d2l.train_ch8(net, train_iter, vocab, lr, num_epochs, device)  # 比自己写的跑得快
perplexity 1.3, 213489.4 tokens/sec on cuda:0
time traveller held in his han so withtre scon the thin one mige
travellericho for the prof read haly and hes it nople hat d

在这里插入图片描述

练习

(1)尝试使用高级API,能使循环神经网络模型过拟合吗?

略。


(2)如果在循环神经网络模型中增加隐藏层的数量会发生什么?能使模型正常工作吗?

num_hiddens1 = 1024
rnn_layer1 = nn.RNN(len(vocab), num_hiddens1)net1 = RNNModel(rnn_layer1, vocab_size=len(vocab))
net1 = net1.to(device)num_epochs, lr = 500, 1
d2l.train_ch8(net1, train_iter, vocab, lr, num_epochs, device)  # 效果更好了,但是曲线没那么平滑了
perplexity 1.0, 97329.8 tokens/sec on cuda:0
time travelleryou can show black is white by argument said filby
travelleryou can show black is white by argument said filby

在这里插入图片描述


(3)尝试使用循环神经网络实现 8.1 节的自回归模型。

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.mzph.cn/news/107687.shtml

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

【python高级】设计模式、类工厂、对象工厂

一、说明 最近试着读Design pattern, 不过有些概念实在太抽象了, 整理一下自己所学抽象工厂的精神,就是要有abstract class(not implement),而所有不同种类的对象,都是继承这个abstract class&a…

Linux命令(94)之history

linux命令之history 1.history介绍 linux命令history会记录并显示用户所执行过的所有命令,也可以对其命令进行修改和删除操作。 2.history用法 history [参数] history参数 参数说明-a将当前会话的历史信息追加到历史文件(.bash_history)中-c删除所有条目从而清…

【云计算网络安全】DDoS 攻击类型:什么是 ACK 洪水 DDoS 攻击

文章目录 一、什么是 ACK 洪水 DDoS 攻击?二、什么是数据包?三、什么是 ACK 数据包?四、ACK 洪水攻击如何工作?五、SYN ACK 洪水攻击如何工作?六、文末送书《AWD特训营》内容简介读者对象 一、什么是 ACK 洪水 DDoS 攻…

猫头虎博客带您使用Markdown编辑器

🌷🍁 博主猫头虎 带您 Go to New World.✨🍁 🦄 博客首页——猫头虎的博客🎐 🐳《面试题大全专栏》 文章图文并茂🦕生动形象🦖简单易学!欢迎大家来踩踩~🌺 &a…

BAT026:删除当前目录及子目录下的空文件夹

引言:编写批处理程序,实现批量删除当前目录及子目录下的空文件夹。 一、新建Windows批处理文件 参考博客: CSDNhttps://mp.csdn.net/mp_blog/creation/editor/132137544 二、写入批处理代码 1.右键新建的批处理文件,点击【编辑…

代码随想录训练营二刷第五十八天 | 583. 两个字符串的删除操作 72. 编辑距离

代码随想录训练营二刷第五十八天 | 583. 两个字符串的删除操作 72. 编辑距离 一、583. 两个字符串的删除操作 题目链接:https://leetcode.cn/problems/delete-operation-for-two-strings/ 思路:定义dp[i][j]为要是得区间[0,i-1]和区间[0,j-1]所需要删除…

leetcode做题笔记179. 最大数

给定一组非负整数 nums,重新排列每个数的顺序(每个数不可拆分)使之组成一个最大的整数。 注意:输出结果可能非常大,所以你需要返回一个字符串而不是整数。 示例 1: 输入:nums [10,2] 输出&am…

Linux实现无需手动输入密码的自动化SSH身份验证

SSH密钥身份验证是一种安全的方式,使您能够在无需手动输入密码的情况下连接到远程服务器。以下是如何设置SSH密钥身份验证,以便您的脚本能够自动运行: 步骤 生成SSH密钥对: 在您的本地系统上生成SSH密钥对。如果您尚未生成,请使用…

JimuReport 积木报表 v1.6.4 稳定版本正式发布 — 开源免费的低代码报表

项目介绍 一款免费的数据可视化报表,含报表和大屏设计,像搭建积木一样在线设计报表!功能涵盖,数据报表、打印设计、图表报表、大屏设计等! Web 版报表设计器,类似于excel操作风格,通过拖拽完成报…

【前端设计模式】之备忘录模式

备忘录模式是一种行为设计模式,它允许在不破坏封装性的前提下捕获和恢复对象的内部状态。在前端开发中,备忘录模式可以用于保存和恢复用户界面的状态,以及实现撤销和重做功能。 备忘录模式特性: 封装了对象的状态:备…

【微服务 SpringCloud】实用篇 · Eureka注册中心

微服务(3) 文章目录 微服务(3)1. Eureka的结构和作用2. 搭建eureka-server2.1 创建eureka-server服务2.2 引入eureka依赖2.3 编写启动类2.4 编写配置文件2.5 启动服务 3. 服务注册1)引入依赖2)配置文件3&am…

android 13.0 SystemUI导航栏添加虚拟按键功能(二)

1.概述 在13.0的系统产品开发中,对于在SystemUI的原生系统中默认只有三键导航,想添加其他虚拟按键就需要先在构建导航栏的相关布局 中分析结构,然后添加相关的图标xml就可以了,然后添加对应的点击事件,就可以了,接下来先分析第二步关于导航栏的相关布局情况 然后实现功能…

尚硅谷Flink(三)时间、窗口

1 🎰🎲🕹️ 🎰时间、窗口 🎲窗口 🕹️是啥 Flink 是一种流式计算引擎,主要是来处理无界数据流的,数据源源不断、无穷无尽。想要更加方便高效地处理无界流,一种方式就…

react简单写一个transition动画组件然后在modal组件中应用

1. 动画组件 原理: 利用transition和class类的切换执行transition的动画效果 import React, { useEffect } from "react"; import _ from "lodash";interface IProps {name: string;children: React.ReactNode;transitionShow: boolean; } sessionStorag…

期中考Web复现

第一题 1z_php <?php //Yeedo told you to study hard! echo !(!(!(!(include "flag.php")||(!error_reporting(0))||!isset($_GET[OoO])||!isset($_GET[0o0])||($_GET[OoO]2023)||!(intval($_GET[OoO][0])2023)||$_GET[0o0]$_GET[OoO]||!(md5($_GET[0o0])md5($_…

从头开始机器学习:线性回归

从头开始机器学习&#xff1a;线性回归 跟随 16 分钟阅读 28月 <> 1 一、说明 本篇实现线性回归的先决知识是&#xff1a;基本线性代数&#xff0c;微积分&#xff08;偏导数&#xff09;、梯度和、Python &#xff08;NumPy&#xff09;&#xff1b;从线性方程入手。 代…

基于vue实现滑块动画效果

主要实现&#xff1a;通过鼠标移移动、触摸元素、鼠标释放、离开元素事件来进行触发 创建了一个滑动盒子&#xff0c;其中包含一个滑块图片。通过鼠标按下或触摸开始事件&#xff0c;开始跟踪滑块的位置和鼠标/触摸位置之间的偏移量。然后&#xff0c;通过计算偏移量和起始时的…

【京东开源项目】微前端框架MicroApp 1.0正式发布

介绍 MicroApp是由京东前端团队推出的一款微前端框架&#xff0c;它从组件化的思维&#xff0c;基于类WebComponent进行微前端的渲染&#xff0c;旨在降低上手难度、提升工作效率。MicroApp无关技术栈&#xff0c;也不和业务绑定&#xff0c;可以用于任何前端框架。 源码地址…

vueday01——文本渲染与挂载

1.定义html样式字符串 const rawHtml "<span stylecolor:red>htmlTest</span>" 2.创建标签&#xff0c;分别渲染普通文本和html文本 <p> 你好<span v-html"rawHtml"></span></p> 3.代码展示 4.结果展示

day01——禁用按钮和输入框等组件

1.代码展示 <button :disabled"true" click"printId">Print ID {{ resultId }}</button> 2.非禁用情况 <button :disabled"false" click"printId">Print ID {{ resultId }}</button> 3.禁用情况 <butt…