人工智能-循环神经网络的简洁实现

循环神经网络的简洁实现

如何使用深度学习框架的高级API提供的函数更有效地实现相同的语言模型。 我们仍然从读取时光机器数据集开始。

import torch
from torch import nn
from torch.nn import functional as F
from d2l import torch as d2lbatch_size, num_steps = 32, 35
train_iter, vocab = d2l.load_data_time_machine(batch_size, num_steps)

定义模型

高级API提供了循环神经网络的实现。 我们构造一个具有256个隐藏单元的单隐藏层的循环神经网络层rnn_layer。 事实上,我们还没有讨论多层循环神经网络的意义。 现在仅需要将多层理解为一层循环神经网络的输出被用作下一层循环神经网络的输入就足够了。

num_hiddens = 256
rnn_layer = nn.RNN(len(vocab), num_hiddens)

我们使用张量来初始化隐状态,它的形状是(隐藏层数,批量大小,隐藏单元数)。

state = torch.zeros((1, batch_size, num_hiddens))
state.shape

torch.Size([1, 32, 256])

通过一个隐状态和一个输入,我们就可以用更新后的隐状态计算输出。 需要强调的是,rnn_layer的“输出”(Y)不涉及输出层的计算: 它是指每个时间步的隐状态,这些隐状态可以用作后续输出层的输入。

X = torch.rand(size=(num_steps, batch_size, len(vocab)))
Y, state_new = rnn_layer(X, state)
Y.shape, state_new.shape

 (torch.Size([35, 32, 256]), torch.Size([1, 32, 256]))

我们为一个完整的循环神经网络模型定义了一个RNNModel类。 注意,rnn_layer只包含隐藏的循环层,我们还需要创建一个单独的输出层。

#@save
class RNNModel(nn.Module):"""循环神经网络模型"""def __init__(self, rnn_layer, vocab_size, **kwargs):super(RNNModel, self).__init__(**kwargs)self.rnn = rnn_layerself.vocab_size = vocab_sizeself.num_hiddens = self.rnn.hidden_size# 如果RNN是双向的(之后将介绍),num_directions应该是2,否则应该是1if not self.rnn.bidirectional:self.num_directions = 1self.linear = nn.Linear(self.num_hiddens, self.vocab_size)else:self.num_directions = 2self.linear = nn.Linear(self.num_hiddens * 2, self.vocab_size)def forward(self, inputs, state):X = F.one_hot(inputs.T.long(), self.vocab_size)X = X.to(torch.float32)Y, state = self.rnn(X, state)# 全连接层首先将Y的形状改为(时间步数*批量大小,隐藏单元数)# 它的输出形状是(时间步数*批量大小,词表大小)。output = self.linear(Y.reshape((-1, Y.shape[-1])))return output, statedef begin_state(self, device, batch_size=1):if not isinstance(self.rnn, nn.LSTM):# nn.GRU以张量作为隐状态return  torch.zeros((self.num_directions * self.rnn.num_layers,batch_size, self.num_hiddens),device=device)else:# nn.LSTM以元组作为隐状态return (torch.zeros((self.num_directions * self.rnn.num_layers,batch_size, self.num_hiddens), device=device),torch.zeros((self.num_directions * self.rnn.num_layers,batch_size, self.num_hiddens), device=device))

 训练与预测

在训练模型之前,让我们基于一个具有随机权重的模型进行预测。

device = d2l.try_gpu()
net = RNNModel(rnn_layer, vocab_size=len(vocab))
net = net.to(device)
d2l.predict_ch8('time traveller', 10, net, vocab, device)

很明显,这种模型根本不能输出好的结果。 接下来,我们使用定义的超参数调用train_ch8,并且使用高级API训练模型。 

num_epochs, lr = 500, 1
d2l.train_ch8(net, train_iter, vocab, lr, num_epochs, device)

perplexity 1.3, 404413.8 tokens/sec on cuda:0 time travellerit would be remarkably convenient for the historia travellery of il the hise fupt might and st was it loflers

由于深度学习框架的高级API对代码进行了更多的优化, 该模型在较短的时间内达到了较低的困惑度。 

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.mzph.cn/news/159801.shtml

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

C#获取枚举Enum的描述

简单封装个扩展方法,方便下次使用。 using System.ComponentModel;namespace Order.Core.API.Enums {public enum AccountStatus{[Description("账号不存在!")]AccountNotExist 1,[Description("密码错误!")]PasswordE…

最常用的5款报表系统

在这个信息化飞速发展的时代,报表系统已经成为了企业管理和决策的重要工具。随着市场的需求不断增长,报表系统也在不断地更新和完善。如今,市面上有数不尽的报表系统,但是哪款才是最常用的呢?接下来,我们将…

开源之夏 2023 | Databend 社区项目总结与分享

开源之夏是由中科院软件所“开源软件供应链点亮计划”发起并长期支持的一项暑期开源活动,旨在鼓励在校学生积极参与开源软件的开发维护,培养和发掘更多优秀的开发者,促进优秀开源软件社区的蓬勃发展,助力开源软件供应链建设。 官…

node.js获取unsplash图片

1. 在Unsplash的开发者页面注册并创建一个应用程序,以便获取一个API访问密钥(即Access Key)。 2. 安装unsplash-js库和axios: npm install axios3. 使用获取到的API密钥进行请求。 示例代码如下: const axios req…

3DEXPERIENCE许可管理工具:掌控设计软件许可,提高企业竞争力

在当今竞争激烈的设计领域,合理的软件许可管理对于企业的成功至关重要。3DEXPERIENCE作为一款功能强大的设计软件,提供了全面且易于使用的许可管理工具。本文将详细介绍3DEXPERIENCE的许可管理工具及其对企业管理的价值,帮助您更好地了解和利…

【正点原子STM32连载】第五十六章 DSP BasicMath实验 摘自【正点原子】APM32F407最小系统板使用指南

1)实验平台:正点原子stm32f103战舰开发板V4 2)平台购买地址:https://detail.tmall.com/item.htm?id609294757420 3)全套实验源码手册视频下载地址: http://www.openedv.com/thread-340252-1-1.html## 第五…

【管理运筹学】背诵手册(四)| 整数规划

四、整数规划 整数规划根据变量取值的限制形式,可分为三种: 纯整数规划(IP)。混合整数规划(MIP)。0-1 整数规划(BIP)。 不能用凑整的方法去求整数规划问题的最优解,用…

ChinaSoft 论坛巡礼 | 新兴系统软件论坛

2023年CCF中国软件大会(CCF ChinaSoft 2023)由CCF主办,CCF系统软件专委会、形式化方法专委会、软件工程专委会以及复旦大学联合承办,将于2023年12月1-3日在上海国际会议中心举行。 本次大会主题是“智能化软件创新推动数字经济与社…

算法 全排列的应用

#include <iostream> #include <string>using namespace std;// 交换字符串中两个字符的位置 void swap(char& a, char& b) {char temp a;a b;b temp; }void fun(string str) {string a str.substr(0,4); int aa;sscanf(a.c_str(), "%d",…

Harmony 应用开发之size 脚本

作者&#xff1a;麦客奥德彪 在应用开发中&#xff0c;最终呈现在用户面前的UI&#xff0c;是用户能否继续使用应用的强力依据之一&#xff0c;在之前的开发中&#xff0c;Android 屏幕碎片化严重&#xff0c;所以出现了很多尺寸适配方案。 最小宽适配、百分比适配等等。 还有一…

知虾shopee收费,多少钱一个月

在当今电商行业的竞争激烈的环境下&#xff0c;许多商家都在寻求更好的方式来推广和销售他们的产品。这就是为什么越来越多的商家选择使用知虾shopee这样的平台来展示和销售他们的商品。但是&#xff0c;对于许多商家来说&#xff0c;他们可能会对知虾shopee的收费情况感到好奇…

MySql 计算同比、环比

一、理论 国家统计局同比、环比计算公式 增长速度是反映经济社会某一领域发展变化情况的重要数据&#xff0c;而同比和环比是反映增长速度最基础、最核心的数据指标&#xff0c;也是国际上通用的指标。在统计中&#xff0c; 同比和环比通常是同比变化率和环比变化率的简称&…

关于2023年11月25日PMI认证考试有关事项的通知

PMP项目管理学习专栏https://blog.csdn.net/xmws_it/category_10954848.html?spm1001.2014.3001.54822023年8月PMP考试成绩出炉|微思通过率95%以上-CSDN博客文章浏览阅读135次。国际注册项目管理师(PMP) 证书是项目管理领域含金量最高的职业资格证书&#xff0c;获得该资质是项…

CentOS简介、ISO类型、CentOS7安装与配置以及远程连接。

目录 1.CentOS简介 2.CentOS ISO类型 3.CentOS7安装与配置 4.远程连接 1.CentOS简介 CentOS&#xff08;Community Enterprise Operating System&#xff0c;中文意思是社区企业操作系统&#xff09;是Linux发行版之一&#xff0c;它是来自于Red Hat Enterprise Linux依照…

代码随想录算法训练营第23期60天完结总结

这六十天在自己准备中期答辩、完成中期答辩中度过了 其实挺惭愧的&#xff0c;60天里&#xff0c;每天能塌下心来的时间每日递减&#xff0c;从打卡的情况上也能看出来。最后十天都是补打卡&#xff0c;每道题都是想了几分种&#xff0c;不会就开始看解析&#xff0c;然后凭着…

微软Copilot即将对大陆开放,一起来看看都有什么好用的功能

微软发布了Copilot&#xff0c;12月1日起对大陆用户开放&#xff0c;以下是Copilot的11个新功能&#xff0c;你一定不想错过&#xff1a;1. PowerPoint&#xff1a; 将Word文档转换为演示文稿。从文件中快速创建演示文稿。通过关键幻灯片总结冗长的演示文稿。使用提示添加新的…

基于MS16F3211芯片的触摸控制灯的状态变化和亮度控制总结版(11.22)

1.任务需求 基于MS16F3211芯片实现功能一个按键通过长按可以控制当前处于亮状态的灯的亮度&#xff0c;当灯从最亮达到最暗时&#xff0c;所用时为3s。现有三盏颜色分别为红绿蓝的灯&#xff0c;在处于关机状态时红灯亮&#xff0c;处于开机状态时红灯灭。点按第一次仅绿灯亮&…

postgresql新增非空默认值字段是否需要重写表

简介&#xff1a; PostgreSQL 10 版本前表新增不带默认值的DDL不需要重写表&#xff0c;只需要更新数据字典&#xff0c;因此DDL能瞬间执行&#xff0c;如下: ALTER TABLE table_name ADD COLUMN flag text; 如果新增的字段带默认值&#xff0c;则需要重写表&#xff0c;表越大…

mysql使用--数据的插入,删除和更新

1.UNION合并多个结果集 如&#xff1a;SELECT m1, n1 FROM t1 WHERE m1 < 2 UNION SELECT m2, n2 FROM t2 WHERE m2 > 2; 默认下&#xff0c;某行在参与合并两个结果集均存在时&#xff0c;最终结果集中只包含一行。 如希望最终结果集包含两行&#xff0c;用UNION ALL。 …

慕尼黑电子展Samtec Demo | 回环测试带来Samtec产品组合优异表现

【摘要/前言】 大家好&#xff01;Electronica虎家展台Demo系列回来咯。 实践出真知&#xff0c;再好的纸面数据都不如来一场实际的测试和演示。Samtec团队始终在努力为客户带来卓越的产品和优质服务。而这其中&#xff0c;Demo演示的存在至关重要。演示过程可以为大家带来了…