人工智能|机器学习——循环神经网络的简洁实现

循环神经网络的简洁实现

如何使用深度学习框架的高级API提供的函数更有效地实现相同的语言模型。 我们仍然从读取时光机器数据集开始。

import torch
from torch import nn
from torch.nn import functional as F
from d2l import torch as d2lbatch_size, num_steps = 32, 35
train_iter, vocab = d2l.load_data_time_machine(batch_size, num_steps)

定义模型

高级API提供了循环神经网络的实现。 我们构造一个具有256个隐藏单元的单隐藏层的循环神经网络层rnn_layer。 事实上,我们还没有讨论多层循环神经网络的意义。 现在仅需要将多层理解为一层循环神经网络的输出被用作下一层循环神经网络的输入就足够了。

num_hiddens = 256
rnn_layer = nn.RNN(len(vocab), num_hiddens)

我们使用张量来初始化隐状态,它的形状是(隐藏层数,批量大小,隐藏单元数)。

state = torch.zeros((1, batch_size, num_hiddens))
state.shapetorch.Size([1, 32, 256])

通过一个隐状态和一个输入,我们就可以用更新后的隐状态计算输出。 需要强调的是,rnn_layer的“输出”(Y)不涉及输出层的计算: 它是指每个时间步的隐状态,这些隐状态可以用作后续输出层的输入。

X = torch.rand(size=(num_steps, batch_size, len(vocab)))
Y, state_new = rnn_layer(X, state)
Y.shape, state_new.shape(torch.Size([35, 32, 256]), torch.Size([1, 32, 256]))

我们为一个完整的循环神经网络模型定义了一个RNNModel类。 注意,rnn_layer只包含隐藏的循环层,我们还需要创建一个单独的输出层。

#@save
class RNNModel(nn.Module):"""循环神经网络模型"""def __init__(self, rnn_layer, vocab_size, **kwargs):super(RNNModel, self).__init__(**kwargs)self.rnn = rnn_layerself.vocab_size = vocab_sizeself.num_hiddens = self.rnn.hidden_size# 如果RNN是双向的(之后将介绍),num_directions应该是2,否则应该是1if not self.rnn.bidirectional:self.num_directions = 1self.linear = nn.Linear(self.num_hiddens, self.vocab_size)else:self.num_directions = 2self.linear = nn.Linear(self.num_hiddens * 2, self.vocab_size)def forward(self, inputs, state):X = F.one_hot(inputs.T.long(), self.vocab_size)X = X.to(torch.float32)Y, state = self.rnn(X, state)# 全连接层首先将Y的形状改为(时间步数*批量大小,隐藏单元数)# 它的输出形状是(时间步数*批量大小,词表大小)。output = self.linear(Y.reshape((-1, Y.shape[-1])))return output, statedef begin_state(self, device, batch_size=1):if not isinstance(self.rnn, nn.LSTM):# nn.GRU以张量作为隐状态return  torch.zeros((self.num_directions * self.rnn.num_layers,batch_size, self.num_hiddens),device=device)else:# nn.LSTM以元组作为隐状态return (torch.zeros((self.num_directions * self.rnn.num_layers,batch_size, self.num_hiddens), device=device),torch.zeros((self.num_directions * self.rnn.num_layers,batch_size, self.num_hiddens), device=device))

 训练与预测

在训练模型之前,让我们基于一个具有随机权重的模型进行预测。

device = d2l.try_gpu()
net = RNNModel(rnn_layer, vocab_size=len(vocab))
net = net.to(device)
d2l.predict_ch8('time traveller', 10, net, vocab, device)

 很明显,这种模型根本不能输出好的结果。 接下来,我们使用定义的超参数调用train_ch8,并且使用高级API训练模型。 

num_epochs, lr = 500, 1
d2l.train_ch8(net, train_iter, vocab, lr, num_epochs, device)

perplexity 1.3, 404413.8 tokens/sec on cuda:0 time travellerit would be remarkably convenient for the historia travellery of il the hise fupt might and st was it loflers

由于深度学习框架的高级API对代码进行了更多的优化, 该模型在较短的时间内达到了较低的困惑度。  

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.mzph.cn/news/171560.shtml

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

itop4412移植lrzsz工具踩坑笔记

4412开发板在传输文件一直用的都是tftp文件传输,但这样效率有点慢,平常在linux上习惯用lrzsz工具来传输文件,特此记录下,因为不熟悉linux编译 踩坑了很多地方 在操作前 我们的虚拟机要线安装好编译环境 下载lrzsz源码&#xff0…

一起学docker系列之十docker安装tomcat实践

目录 前言1 安装tomcat的步骤步骤 1: 查找并拉取 Tomcat 镜像步骤 2: 运行 Tomcat 容器步骤 3: 管理 Tomcat 容器步骤 4: 访问 Tomcat 首页 2 解决访问首页的404访问不到的问题2.1 Tomcat 10 的默认设置问题2.2 端口映射或防火墙问题 3 推荐使用 Tomcat 8.5 版本总结 前言 当安…

【华为OD题库-037】跳房子2-java

题目 跳房子,也叫跳飞机,是一种世界性的儿童游戏游戏。参与者需要分多个回合按顺序跳到第1格直到房子的最后一格,然后获得一次选房子的机会,直到所有房子被选完,房子最多的人获胜。 跳房子的过程中,如果有踩…

【Docker】从零开始:11.Harbor搭建企业镜像仓库

【Docker】从零开始:11.Harbor搭建企业镜像仓库 1. Harbor介绍2. 软硬件要求(1). 硬件要求(2). 软件要求 3.Harbor优势4.Harbor的误区5.Harbor的几种安装方式6.在线安装(1).安装composer(2).配置内核参数,开启路由转发(3).下载安装包并解压(4).创建并修改配置文件(5…

python+pytest接口自动化(1)-接口测试基础

一般我们所说的接口即API,那什么又是API呢,百度给的定义如下: API(Application Programming Interface,应用程序接口)是一些预先定义的接口(如函数、HTTP接口),或指软件系…

3款免费的语音视频转文本AI神器

最近有很多粉丝让我出一期关于语音转文本的免费AI神器,毕竟这类工具在学习和工作中经常会用到,那今天就给大家安排。 我亲测了好几款软件之后,最终评选留下了三款 剪映hugging face飞书妙记 接下来一一给大家讲解 1.剪映 剪映其实是一款视…

引用、动态内存分配、函数、结构体

引用 定义和初始化 **数据类型 &引用名 目标名;**引用和目标共用同一片空间(相当于对一片空间取别名)。 引用的底层实现:数据类型 * const p; ------> 常指针 int const *p; -----> 修饰 *p const int *p; ----->…

stm32实现0.96oled图片显示,菜单功能

stm32实现0.96oled图片显示,菜单功能 功能展示简介代码介绍oled.coled.holedfont.h(字库文件)main函数 代码思路讲解 本期内容,我们将学习0.96寸oled的进阶使用,展示图片,实现菜单切换等功能,关…

QT visual stdio加载动态库报错126问题

报错126是找不到指定的模块 QT 查看构建目录,将依赖的动态库放到该目录下即可成功 visual stdio将依赖的动态库放到运行目录 在vs中使用导出类的动态库时,不但需要将对应的.dll放到对应的目录下,还需要将该动态库对应的.lib添加到如下配置才…

【SAS Planet 下载地图瓦片】

SAS Planet是一位俄罗斯爱好者创建的的开源应用,该应用可以浏览与下载主流网络地图,包括Google地图、Google地球、Bing地图、Esri 地图、Yandex地图等,支持100多图源。 安装包下载地址:https://www.sasgis.org/download/ github…

337. 打家劫舍III (二叉树)

题目 题解 # Definition for a binary tree node. # class TreeNode: # def __init__(self, val0, leftNone, rightNone): # self.val val # self.left left # self.right right class Solution:def rob(self, root: Optional[TreeNode]) ->…

指针的进阶

重中之重: 目录 1.字符指针: 2.指针数组 3.数组指针 4.数组参数、指针参数 5.函数指针 1.字符指针: 一般实现: int main() {char ch w;char *pc &ch;*pc w;return 0; } 二班实现: int main() {const c…

什么是闭包和作用域链?

1. 什么是闭包 闭包指的是那些引用了另一个函数作用域中变量的函数,通常是在嵌套函数中实现的。 举个栗子,createCounter 接受一个参数 n,然后返回一个匿名函数,这个匿名函数是闭包,它可以访问外部函数 createCounte…

智能优化算法应用:基于鲸鱼算法无线传感器网络(WSN)覆盖优化 - 附代码

智能优化算法应用:基于鲸鱼算法无线传感器网络(WSN)覆盖优化 - 附代码 文章目录 智能优化算法应用:基于鲸鱼算法无线传感器网络(WSN)覆盖优化 - 附代码1.无线传感网络节点模型2.覆盖数学模型及分析3.鲸鱼算法4.实验参数设定5.算法结果6.参考文献7.MATLAB…

Nginx常见的中间件漏洞

目录 1、Nginx文件名逻辑漏洞 2、Nginx解析漏洞 3、Nginx越权读取缓存漏洞 这里需要的漏洞环境可以看:Nginx 配置错误导致的漏洞-CSDN博客 1、Nginx文件名逻辑漏洞 该漏洞利用条件有两个: Nginx 0.8.41 ~ 1.4.3 / 1.5.0 ~ 1.5.7 php-fpm.conf中的s…

百度 Comate 终于支持 IntelliJ IDEA 了

大家好,我是伍六七。 对于一直关注 AI 编程的阿七来说,编程助手绝对是必不可少的,除了 GitHub Copilot 之外,国内百度的 Comate 一直是我关注的重点。 但是之前,Comate 还支持 VS code,并不支持 IntelliJ…

mybatis的使用,mybatis的实现原理,mybatis的优缺点,MyBatis缓存,MyBatis运行的原理,MyBatis的编写方式

文章目录 MyBatis简介结构图Mybatis缓存(一级缓存、二级缓存)MyBatis是什么?mybatis的实现原理JDBC编程有哪些不足之处,MyBatis是如何解决这些问题的?Mybatis优缺点优点缺点映射关系 MyBatis的解析和运行原理MyBatis的…

【单片机学习笔记】STC8H1K08参考手册学习笔记

STC8H1K08参考手册学习笔记 STC8H系列芯片STC8H1K08开发环境串口烧录 STC8H系列芯片 STC8H 系列单片机是不需要外部晶振和外部复位的单片机,是以超强抗干扰/超低价/高速/低功耗为目标的 8051 单片机,在相同的工作频率下,STC8H 系列单片机比传统的 8051约快12 倍速度…

数组题目:645. 错误的集合、 697. 数组的度、 448. 找到所有数组中消失的数字、442. 数组中重复的数据 、41. 缺失的第一个正数

645. 错误的集合 思路: 我们定义一个数组cnt,记录每个数出现的次数。然后我们遍历数组,从1开始,如果cnt[i] 0 那就说明这个是错误的数,如果 cnt[i] 2,那就说明是重复的数。 代码: class So…

RabbitMQ之消费者可靠性

文章目录 前言一、消费者确认机制二、失败重试机制三、失败处理策略四、业务幂等性唯一消息ID业务判断 五、兜底方案总结 前言 当RabbitMQ向消费者投递消息以后,需要知道消费者的处理状态如何。因为消息投递给消费者并不代表就一定被正确消费了,可能出现…