人工智能|机器学习——循环神经网络的简洁实现

循环神经网络的简洁实现

如何使用深度学习框架的高级API提供的函数更有效地实现相同的语言模型。 我们仍然从读取时光机器数据集开始。

import torch
from torch import nn
from torch.nn import functional as F
from d2l import torch as d2lbatch_size, num_steps = 32, 35
train_iter, vocab = d2l.load_data_time_machine(batch_size, num_steps)

定义模型

高级API提供了循环神经网络的实现。 我们构造一个具有256个隐藏单元的单隐藏层的循环神经网络层rnn_layer。 事实上,我们还没有讨论多层循环神经网络的意义。 现在仅需要将多层理解为一层循环神经网络的输出被用作下一层循环神经网络的输入就足够了。

num_hiddens = 256
rnn_layer = nn.RNN(len(vocab), num_hiddens)

我们使用张量来初始化隐状态,它的形状是(隐藏层数,批量大小,隐藏单元数)。

state = torch.zeros((1, batch_size, num_hiddens))
state.shapetorch.Size([1, 32, 256])

通过一个隐状态和一个输入,我们就可以用更新后的隐状态计算输出。 需要强调的是,rnn_layer的“输出”(Y)不涉及输出层的计算: 它是指每个时间步的隐状态,这些隐状态可以用作后续输出层的输入。

X = torch.rand(size=(num_steps, batch_size, len(vocab)))
Y, state_new = rnn_layer(X, state)
Y.shape, state_new.shape(torch.Size([35, 32, 256]), torch.Size([1, 32, 256]))

我们为一个完整的循环神经网络模型定义了一个RNNModel类。 注意,rnn_layer只包含隐藏的循环层,我们还需要创建一个单独的输出层。

#@save
class RNNModel(nn.Module):"""循环神经网络模型"""def __init__(self, rnn_layer, vocab_size, **kwargs):super(RNNModel, self).__init__(**kwargs)self.rnn = rnn_layerself.vocab_size = vocab_sizeself.num_hiddens = self.rnn.hidden_size# 如果RNN是双向的(之后将介绍),num_directions应该是2,否则应该是1if not self.rnn.bidirectional:self.num_directions = 1self.linear = nn.Linear(self.num_hiddens, self.vocab_size)else:self.num_directions = 2self.linear = nn.Linear(self.num_hiddens * 2, self.vocab_size)def forward(self, inputs, state):X = F.one_hot(inputs.T.long(), self.vocab_size)X = X.to(torch.float32)Y, state = self.rnn(X, state)# 全连接层首先将Y的形状改为(时间步数*批量大小,隐藏单元数)# 它的输出形状是(时间步数*批量大小,词表大小)。output = self.linear(Y.reshape((-1, Y.shape[-1])))return output, statedef begin_state(self, device, batch_size=1):if not isinstance(self.rnn, nn.LSTM):# nn.GRU以张量作为隐状态return  torch.zeros((self.num_directions * self.rnn.num_layers,batch_size, self.num_hiddens),device=device)else:# nn.LSTM以元组作为隐状态return (torch.zeros((self.num_directions * self.rnn.num_layers,batch_size, self.num_hiddens), device=device),torch.zeros((self.num_directions * self.rnn.num_layers,batch_size, self.num_hiddens), device=device))

 训练与预测

在训练模型之前,让我们基于一个具有随机权重的模型进行预测。

device = d2l.try_gpu()
net = RNNModel(rnn_layer, vocab_size=len(vocab))
net = net.to(device)
d2l.predict_ch8('time traveller', 10, net, vocab, device)

 很明显,这种模型根本不能输出好的结果。 接下来,我们使用定义的超参数调用train_ch8,并且使用高级API训练模型。 

num_epochs, lr = 500, 1
d2l.train_ch8(net, train_iter, vocab, lr, num_epochs, device)

perplexity 1.3, 404413.8 tokens/sec on cuda:0 time travellerit would be remarkably convenient for the historia travellery of il the hise fupt might and st was it loflers

由于深度学习框架的高级API对代码进行了更多的优化, 该模型在较短的时间内达到了较低的困惑度。  

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.mzph.cn/news/171560.shtml

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

itop4412移植lrzsz工具踩坑笔记

4412开发板在传输文件一直用的都是tftp文件传输,但这样效率有点慢,平常在linux上习惯用lrzsz工具来传输文件,特此记录下,因为不熟悉linux编译 踩坑了很多地方 在操作前 我们的虚拟机要线安装好编译环境 下载lrzsz源码&#xff0…

一起学docker系列之十docker安装tomcat实践

目录 前言1 安装tomcat的步骤步骤 1: 查找并拉取 Tomcat 镜像步骤 2: 运行 Tomcat 容器步骤 3: 管理 Tomcat 容器步骤 4: 访问 Tomcat 首页 2 解决访问首页的404访问不到的问题2.1 Tomcat 10 的默认设置问题2.2 端口映射或防火墙问题 3 推荐使用 Tomcat 8.5 版本总结 前言 当安…

最轻量级最完整的屏幕适配完全适配各个手机方案

当你看到这篇博客的时候,说明你已经迈出了惊人的一步,已经慢慢进入高级资深开发工程师行列了,这是开发之路必备技能。 当你接到一个任务时,每天按照需求原型、设计师UI图立刻积极的开发完成后,满满的兴高采烈去打包提测,板凳还没做安稳,测试人员就提了一个又一个的BUG,…

【华为OD题库-037】跳房子2-java

题目 跳房子,也叫跳飞机,是一种世界性的儿童游戏游戏。参与者需要分多个回合按顺序跳到第1格直到房子的最后一格,然后获得一次选房子的机会,直到所有房子被选完,房子最多的人获胜。 跳房子的过程中,如果有踩…

【Docker】从零开始:11.Harbor搭建企业镜像仓库

【Docker】从零开始:11.Harbor搭建企业镜像仓库 1. Harbor介绍2. 软硬件要求(1). 硬件要求(2). 软件要求 3.Harbor优势4.Harbor的误区5.Harbor的几种安装方式6.在线安装(1).安装composer(2).配置内核参数,开启路由转发(3).下载安装包并解压(4).创建并修改配置文件(5…

element-ui DatePicker 日期选择器-控制选择精确到时分秒-禁止选择今天之前-或者今天之后日期### 前言

前言 最近在使用芋道框架时候,后端使用生成代码,时间因为类型问题,只能是时间戳,否则为空(1970-) 前端其实很简单只要在日期选择器把类型改成时间错即可,但根据业务需求需要精确到时分秒 把时…

python+pytest接口自动化(1)-接口测试基础

一般我们所说的接口即API,那什么又是API呢,百度给的定义如下: API(Application Programming Interface,应用程序接口)是一些预先定义的接口(如函数、HTTP接口),或指软件系…

3款免费的语音视频转文本AI神器

最近有很多粉丝让我出一期关于语音转文本的免费AI神器,毕竟这类工具在学习和工作中经常会用到,那今天就给大家安排。 我亲测了好几款软件之后,最终评选留下了三款 剪映hugging face飞书妙记 接下来一一给大家讲解 1.剪映 剪映其实是一款视…

什么是proxy代理?

1. 什么是proxy代理 代理(Proxy)是 JavaScript 中一种非常强大而灵活的功能。代理允许你拦截并覆盖对象的默认行为,提供了一种拦截、定制和扩展对象操作的机制。 简单说,就是在访问对象属性或者赋值时,可以做一些额外…

引用、动态内存分配、函数、结构体

引用 定义和初始化 **数据类型 &引用名 目标名;**引用和目标共用同一片空间(相当于对一片空间取别名)。 引用的底层实现:数据类型 * const p; ------> 常指针 int const *p; -----> 修饰 *p const int *p; ----->…

第六题-红和蓝【第六届传智杯程序设计挑战赛解题分析详解复盘】(JavaPythonC++实现)

🚀 欢迎来到 ACM 算法题库专栏 🚀 在ACM算法题库专栏,热情推崇算法之美,精心整理了各类比赛题目的详细解法,包括但不限于ICPC、CCPC、蓝桥杯、LeetCode周赛、传智杯等等。无论您是刚刚踏入算法领域,还是经验丰富的竞赛选手,这里都是提升技能和知识的理想之地。 ✨ 经典…

stm32实现0.96oled图片显示,菜单功能

stm32实现0.96oled图片显示,菜单功能 功能展示简介代码介绍oled.coled.holedfont.h(字库文件)main函数 代码思路讲解 本期内容,我们将学习0.96寸oled的进阶使用,展示图片,实现菜单切换等功能,关…

QT visual stdio加载动态库报错126问题

报错126是找不到指定的模块 QT 查看构建目录,将依赖的动态库放到该目录下即可成功 visual stdio将依赖的动态库放到运行目录 在vs中使用导出类的动态库时,不但需要将对应的.dll放到对应的目录下,还需要将该动态库对应的.lib添加到如下配置才…

【SAS Planet 下载地图瓦片】

SAS Planet是一位俄罗斯爱好者创建的的开源应用,该应用可以浏览与下载主流网络地图,包括Google地图、Google地球、Bing地图、Esri 地图、Yandex地图等,支持100多图源。 安装包下载地址:https://www.sasgis.org/download/ github…

337. 打家劫舍III (二叉树)

题目 题解 # Definition for a binary tree node. # class TreeNode: # def __init__(self, val0, leftNone, rightNone): # self.val val # self.left left # self.right right class Solution:def rob(self, root: Optional[TreeNode]) ->…

指针的进阶

重中之重: 目录 1.字符指针: 2.指针数组 3.数组指针 4.数组参数、指针参数 5.函数指针 1.字符指针: 一般实现: int main() {char ch w;char *pc &ch;*pc w;return 0; } 二班实现: int main() {const c…

第一题-字符串拼接【第六届传智杯程序设计挑战赛解题分析详解复盘】(C/C++实现)

🚀 欢迎来到 ACM 算法题库专栏 🚀 在ACM算法题库专栏,热情推崇算法之美,精心整理了各类比赛题目的详细解法,包括但不限于ICPC、CCPC、蓝桥杯、LeetCode周赛、传智杯等等。无论您是刚刚踏入算法领域,还是经验丰富的竞赛选手,这里都是提升技能和知识的理想之地。 ✨ 经典…

CSS3媒体查询实现不同宽度的下不同内容的展示

文章目录 前言CSS3 多媒体查询实例520 到 699px 宽度 - 添加邮箱图标700 到 1000px - 添加文本前缀信息大于 1001px 宽度 - 添加邮件地址大于 1151px 宽度 - 添加图标代码后言 前言 hello world欢迎来到前端的新世界 😜当前文章系列专栏:CSS &#x1f43…

什么是闭包和作用域链?

1. 什么是闭包 闭包指的是那些引用了另一个函数作用域中变量的函数,通常是在嵌套函数中实现的。 举个栗子,createCounter 接受一个参数 n,然后返回一个匿名函数,这个匿名函数是闭包,它可以访问外部函数 createCounte…

智能优化算法应用:基于鲸鱼算法无线传感器网络(WSN)覆盖优化 - 附代码

智能优化算法应用:基于鲸鱼算法无线传感器网络(WSN)覆盖优化 - 附代码 文章目录 智能优化算法应用:基于鲸鱼算法无线传感器网络(WSN)覆盖优化 - 附代码1.无线传感网络节点模型2.覆盖数学模型及分析3.鲸鱼算法4.实验参数设定5.算法结果6.参考文献7.MATLAB…