人工智能-循环神经网络的简洁实现

循环神经网络的简洁实现

如何使用深度学习框架的高级API提供的函数更有效地实现相同的语言模型。 我们仍然从读取时光机器数据集开始。

import torch
from torch import nn
from torch.nn import functional as F
from d2l import torch as d2lbatch_size, num_steps = 32, 35
train_iter, vocab = d2l.load_data_time_machine(batch_size, num_steps)

定义模型

高级API提供了循环神经网络的实现。 我们构造一个具有256个隐藏单元的单隐藏层的循环神经网络层rnn_layer。 事实上,我们还没有讨论多层循环神经网络的意义。 现在仅需要将多层理解为一层循环神经网络的输出被用作下一层循环神经网络的输入就足够了。

num_hiddens = 256
rnn_layer = nn.RNN(len(vocab), num_hiddens)

我们使用张量来初始化隐状态,它的形状是(隐藏层数,批量大小,隐藏单元数)。

state = torch.zeros((1, batch_size, num_hiddens))
state.shape

torch.Size([1, 32, 256])

通过一个隐状态和一个输入,我们就可以用更新后的隐状态计算输出。 需要强调的是,rnn_layer的“输出”(Y)不涉及输出层的计算: 它是指每个时间步的隐状态,这些隐状态可以用作后续输出层的输入。

X = torch.rand(size=(num_steps, batch_size, len(vocab)))
Y, state_new = rnn_layer(X, state)
Y.shape, state_new.shape

 (torch.Size([35, 32, 256]), torch.Size([1, 32, 256]))

我们为一个完整的循环神经网络模型定义了一个RNNModel类。 注意,rnn_layer只包含隐藏的循环层,我们还需要创建一个单独的输出层。

#@save
class RNNModel(nn.Module):"""循环神经网络模型"""def __init__(self, rnn_layer, vocab_size, **kwargs):super(RNNModel, self).__init__(**kwargs)self.rnn = rnn_layerself.vocab_size = vocab_sizeself.num_hiddens = self.rnn.hidden_size# 如果RNN是双向的(之后将介绍),num_directions应该是2,否则应该是1if not self.rnn.bidirectional:self.num_directions = 1self.linear = nn.Linear(self.num_hiddens, self.vocab_size)else:self.num_directions = 2self.linear = nn.Linear(self.num_hiddens * 2, self.vocab_size)def forward(self, inputs, state):X = F.one_hot(inputs.T.long(), self.vocab_size)X = X.to(torch.float32)Y, state = self.rnn(X, state)# 全连接层首先将Y的形状改为(时间步数*批量大小,隐藏单元数)# 它的输出形状是(时间步数*批量大小,词表大小)。output = self.linear(Y.reshape((-1, Y.shape[-1])))return output, statedef begin_state(self, device, batch_size=1):if not isinstance(self.rnn, nn.LSTM):# nn.GRU以张量作为隐状态return  torch.zeros((self.num_directions * self.rnn.num_layers,batch_size, self.num_hiddens),device=device)else:# nn.LSTM以元组作为隐状态return (torch.zeros((self.num_directions * self.rnn.num_layers,batch_size, self.num_hiddens), device=device),torch.zeros((self.num_directions * self.rnn.num_layers,batch_size, self.num_hiddens), device=device))

 训练与预测

在训练模型之前,让我们基于一个具有随机权重的模型进行预测。

device = d2l.try_gpu()
net = RNNModel(rnn_layer, vocab_size=len(vocab))
net = net.to(device)
d2l.predict_ch8('time traveller', 10, net, vocab, device)

很明显,这种模型根本不能输出好的结果。 接下来,我们使用定义的超参数调用train_ch8,并且使用高级API训练模型。 

num_epochs, lr = 500, 1
d2l.train_ch8(net, train_iter, vocab, lr, num_epochs, device)

perplexity 1.3, 404413.8 tokens/sec on cuda:0 time travellerit would be remarkably convenient for the historia travellery of il the hise fupt might and st was it loflers

由于深度学习框架的高级API对代码进行了更多的优化, 该模型在较短的时间内达到了较低的困惑度。 

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.mzph.cn/news/159801.shtml

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

最常用的5款报表系统

在这个信息化飞速发展的时代,报表系统已经成为了企业管理和决策的重要工具。随着市场的需求不断增长,报表系统也在不断地更新和完善。如今,市面上有数不尽的报表系统,但是哪款才是最常用的呢?接下来,我们将…

开源之夏 2023 | Databend 社区项目总结与分享

开源之夏是由中科院软件所“开源软件供应链点亮计划”发起并长期支持的一项暑期开源活动,旨在鼓励在校学生积极参与开源软件的开发维护,培养和发掘更多优秀的开发者,促进优秀开源软件社区的蓬勃发展,助力开源软件供应链建设。 官…

【正点原子STM32连载】第五十六章 DSP BasicMath实验 摘自【正点原子】APM32F407最小系统板使用指南

1)实验平台:正点原子stm32f103战舰开发板V4 2)平台购买地址:https://detail.tmall.com/item.htm?id609294757420 3)全套实验源码手册视频下载地址: http://www.openedv.com/thread-340252-1-1.html## 第五…

ChinaSoft 论坛巡礼 | 新兴系统软件论坛

2023年CCF中国软件大会(CCF ChinaSoft 2023)由CCF主办,CCF系统软件专委会、形式化方法专委会、软件工程专委会以及复旦大学联合承办,将于2023年12月1-3日在上海国际会议中心举行。 本次大会主题是“智能化软件创新推动数字经济与社…

算法 全排列的应用

#include <iostream> #include <string>using namespace std;// 交换字符串中两个字符的位置 void swap(char& a, char& b) {char temp a;a b;b temp; }void fun(string str) {string a str.substr(0,4); int aa;sscanf(a.c_str(), "%d",…

Harmony 应用开发之size 脚本

作者&#xff1a;麦客奥德彪 在应用开发中&#xff0c;最终呈现在用户面前的UI&#xff0c;是用户能否继续使用应用的强力依据之一&#xff0c;在之前的开发中&#xff0c;Android 屏幕碎片化严重&#xff0c;所以出现了很多尺寸适配方案。 最小宽适配、百分比适配等等。 还有一…

知虾shopee收费,多少钱一个月

在当今电商行业的竞争激烈的环境下&#xff0c;许多商家都在寻求更好的方式来推广和销售他们的产品。这就是为什么越来越多的商家选择使用知虾shopee这样的平台来展示和销售他们的商品。但是&#xff0c;对于许多商家来说&#xff0c;他们可能会对知虾shopee的收费情况感到好奇…

MySql 计算同比、环比

一、理论 国家统计局同比、环比计算公式 增长速度是反映经济社会某一领域发展变化情况的重要数据&#xff0c;而同比和环比是反映增长速度最基础、最核心的数据指标&#xff0c;也是国际上通用的指标。在统计中&#xff0c; 同比和环比通常是同比变化率和环比变化率的简称&…

关于2023年11月25日PMI认证考试有关事项的通知

PMP项目管理学习专栏https://blog.csdn.net/xmws_it/category_10954848.html?spm1001.2014.3001.54822023年8月PMP考试成绩出炉|微思通过率95%以上-CSDN博客文章浏览阅读135次。国际注册项目管理师(PMP) 证书是项目管理领域含金量最高的职业资格证书&#xff0c;获得该资质是项…

微软Copilot即将对大陆开放,一起来看看都有什么好用的功能

微软发布了Copilot&#xff0c;12月1日起对大陆用户开放&#xff0c;以下是Copilot的11个新功能&#xff0c;你一定不想错过&#xff1a;1. PowerPoint&#xff1a; 将Word文档转换为演示文稿。从文件中快速创建演示文稿。通过关键幻灯片总结冗长的演示文稿。使用提示添加新的…

基于MS16F3211芯片的触摸控制灯的状态变化和亮度控制总结版(11.22)

1.任务需求 基于MS16F3211芯片实现功能一个按键通过长按可以控制当前处于亮状态的灯的亮度&#xff0c;当灯从最亮达到最暗时&#xff0c;所用时为3s。现有三盏颜色分别为红绿蓝的灯&#xff0c;在处于关机状态时红灯亮&#xff0c;处于开机状态时红灯灭。点按第一次仅绿灯亮&…

慕尼黑电子展Samtec Demo | 回环测试带来Samtec产品组合优异表现

【摘要/前言】 大家好&#xff01;Electronica虎家展台Demo系列回来咯。 实践出真知&#xff0c;再好的纸面数据都不如来一场实际的测试和演示。Samtec团队始终在努力为客户带来卓越的产品和优质服务。而这其中&#xff0c;Demo演示的存在至关重要。演示过程可以为大家带来了…

关于Flink的旁路缓存与异步操作

1. 旁路缓存 1. 什么是旁路缓存? 将数据库中的数据,比较经常访问的数据,保存起来,以减少和硬盘数据库的交互 比如: 我们使用mysql时 经常查询一个表 , 而这个表又一般不会变化,就可以放在内存中,查找时直接对内存进行查找,而不需要再和mysql交互 2. 旁路缓存例子使用 dim层…

Vue-报错No “exports“ main defined in xx

vue报错&#xff1a;No "exports" main defined in F:\wjh\vue#Practice\EasyQuestionnaire-web-master\EasyQuestionnaire-web-master\node_modules\babel\helper-compilation-targets\package.json 1.在文件中找到该路径的package.json文件&#xff0c; 2.按照提示…

MEMS制造的基本工艺——晶圆键合工艺

晶圆键合是一种晶圆级封装技术&#xff0c;用于制造微机电系统 (MEMS)、纳米机电系统 (NEMS)、微电子学和光电子学&#xff0c;确保机械稳定和气密密封。用于 MEMS/NEMS 的晶圆直径范围为 100 毫米至 200 毫米&#xff08;4 英寸至 8 英寸&#xff09;&#xff0c;用于生产微电…

java序列化与反序列化

java中序列化与反序列化 概念 在Java中&#xff0c;序列化是指将对象转换为字节流的过程&#xff0c;而反序列化则是将字节流转换回对象的过程。序列化和反序列化通常用于在网络上传输对象或将对象持久化到磁盘上。 要对一个对象进行序列化&#xff0c;可以使用ObjectOutput…

github访问失败

1. 问题场景 今天了解到notepad可以安装许多插件&#xff0c;但是自动下载插件时总是失败&#xff0c;这些插件的下载源都是github&#xff0c;将地址复制到浏览器也打不开&#xff0c;所以查了下github的访问问题&#xff0c;目前插件已正常下载。 2. 解决方法 gitee上搜索…

BUUCTF [SWPU2019]神奇的二维码 1

BUUCTF:https://buuoj.cn/challenges 题目描述&#xff1a; 得到的 flag 请包上 flag{} 提交。 密文&#xff1a; 下载附件&#xff0c;得到一个.png图片。 解题思路&#xff1a; 1、使用QR research扫一下&#xff0c;得到“swpuctf{flag_is_not_here}”的提示。 2、放到0…

orvibo的Mini网关VS20ZW玩法

概述 闲鱼淘来一个2016年生产的网关,此网关的型号:VS20ZW。 已经不能用APP入网了,没事拆来玩玩。 此设备已经被淘汰,很多新的zigbee产品不再支持入网。 官网设备的简介: ZigBee Mini网关,智能家居网关,智能家居主机|ORVIBO欧瑞博智能网关 设备概貌: 主要器件: …

Uptime Kuma 企业微信群机器人告警

curl https://qyapi.weixin.qq.com/cgi-bin/webhook/send?key693axxx6-7aoc-4bc4-97a0-0ec2sifa5aaa \-H Content-Type: application/json \-d {"msgtype": "text","text": {"content": "hello world"}}企业微信群机器人ke…