《动手学深度学习(PyTorch版)》笔记8.6

注:书中对代码的讲解并不详细,本文对很多细节做了详细注释。另外,书上的源代码是在Jupyter Notebook上运行的,较为分散,本文将代码集中起来,并加以完善,全部用vscode在python 3.9.18下测试通过,同时对于书上部分章节也做了整合。

Chapter8 Recurrent Neural Networks

8.6 Concise Implementation of RNN

import torch
from torch import nn
from torch.nn import functional as F
from d2l import torch as d2l
import matplotlib.pyplot as pltbatch_size, num_steps = 32, 35
train_iter, vocab = d2l.load_data_time_machine(batch_size, num_steps)num_hiddens = 256
rnn_layer = nn.RNN(len(vocab), num_hiddens)state = torch.zeros((1, batch_size, num_hiddens))
print(state.shape)X = torch.rand(size=(num_steps, batch_size, len(vocab)))
Y, state_new = rnn_layer(X, state)#Y不涉及输出层的计算
print(Y.shape, state_new.shape)class RNNModel(nn.Module):#@save"""循环神经网络模型"""def __init__(self, rnn_layer, vocab_size, **kwargs):super(RNNModel, self).__init__(**kwargs)self.rnn = rnn_layerself.vocab_size = vocab_sizeself.num_hiddens = self.rnn.hidden_size# 如果RNN是双向的(之后将介绍),num_directions应该是2,否则应该是1if not self.rnn.bidirectional:self.num_directions = 1self.linear = nn.Linear(self.num_hiddens, self.vocab_size)else:self.num_directions = 2self.linear = nn.Linear(self.num_hiddens * 2, self.vocab_size)def forward(self, inputs, state):X = F.one_hot(inputs.T.long(), self.vocab_size)X = X.to(torch.float32)Y, state = self.rnn(X, state)# 全连接层首先将Y的形状改为(时间步数*批量大小,隐藏单元数),它的输出形状是(时间步数*批量大小,词表大小)。output = self.linear(Y.reshape((-1, Y.shape[-1])))return output, statedef begin_state(self, device, batch_size=1):if not isinstance(self.rnn, nn.LSTM):#nn.GRU以张量作为隐状态#GRU为门控循环单元(Gated Recurrent Unit),是一种流行的循环神经网络变体。#GRU使用了一组门控机制来控制信息的流动,包括更新门(update gate)和重置门(reset gate),以更好地捕捉长期依赖关系return  torch.zeros((self.num_directions * self.rnn.num_layers,batch_size, self.num_hiddens),device=device)else:#nn.LSTM以元组作为隐状态#LSTM代表长短期记忆网络(Long Short-Term Memory),是另一种常用的循环神经网络类型。#相比于简单的循环神经网络,LSTM引入了三个门控单元:输入门(input gate)、遗忘门(forget gate)和输出门(output gate),以及一个记忆单元(cell state),可以更有效地处理长期依赖性。return (torch.zeros((self.num_directions * self.rnn.num_layers,batch_size, self.num_hiddens), device=device),torch.zeros((self.num_directions * self.rnn.num_layers,batch_size, self.num_hiddens), device=device))device = d2l.try_gpu()
net = RNNModel(rnn_layer, vocab_size=len(vocab))
net = net.to(device)
d2l.predict_ch8('time traveller', 10, net, vocab, device)num_epochs, lr = 500, 1
d2l.train_ch8(net, train_iter, vocab, lr, num_epochs, device)
plt.show()

训练结果:
在这里插入图片描述

与上一节相比,由于pytorch的高级API对代码进行了更多的优化,该模型在较短的时间内达到了较低的困惑度。

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.mzph.cn/news/678853.shtml

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

数据结构(2) 线性表

线性表 线性表的定义线性表的基本操作lnitList(&L)DestroyList(&L)Listlnsert(&L,i,e)ListDelete(&L,i,&e)LocateElem(L,e)GetElem(L,i)Length(L)PrintList(L)Empty(L)Tips:引用值 小结 根据数据结构的三要素–逻辑结构、数据的运算、存储结构,…

GeoServer 2.11.1升级解决Eclipse Jetty 的一系列安全漏洞问题

Eclipse Jetty 资源管理错误漏洞(CVE-2021-28165) Eclipse Jetty HTTP请求走私漏洞(CVE-2017-7656) Eclipse Jetty HTTP请求走私漏洞(CVE-2017-7657) Eclipse Jetty HTTP请求走私漏洞(CVE-2017-7658) Jetty 信息泄露漏洞(CVE-2017-9735) Eclipse Jetty 安全漏洞(CVE-2022-20…

Javaweb之SpringBootWeb案例之事务进阶的详细解析

1.3 事务进阶 前面我们通过spring事务管理注解Transactional已经控制了业务层方法的事务。接下来我们要来详细的介绍一下Transactional事务管理注解的使用细节。我们这里主要介绍Transactional注解当中的两个常见的属性: 异常回滚的属性:rollbackFor 事…

项目02《游戏-14-开发》Unity3D

基于 项目02《游戏-13-开发》Unity3D , 任务:战斗系统之击败怪物与怪物UI血条信息 using UnityEngine; public abstract class Living : MonoBehaviour{ protected float hp; protected float attack; protected float define; …

Linux网络编程——tcp套接字

文章目录 主要代码关于构造listen监听accepttelnet测试读取信息掉线重连翻译服务器演示 本章Gitee仓库&#xff1a;tcp套接字 主要代码 客户端&#xff1a; #pragma once#include"Log.hpp"#include<iostream> #include<cstring>#include<sys/wait.h…

第73左侧菜单实现

layout下面新建menu layout index.vue导入menu import Menu from /views/layout/menu菜单实现&#xff1a; <template><el-menuactive-text-color"#ffd04b"background-color"#2d3a4b"class"el-menu-vertical-demo"default-active&quo…

灰度发布浅见

在之前的稳定性生产文章中有一项对于研发人员比较重要的措施是变更管控&#xff0c;关于变更管控其实在实际生产活动中有很多措施&#xff0c;因为对于不太的行业&#xff0c;其行业特点和稳定性生产的要求也不一样&#xff0c;例如下图&#xff0c;我们可以看到信通院调研的不…

【开源】SpringBoot框架开发APK检测管理系统

目录 一、摘要1.1 项目介绍1.2 项目录屏 二、功能模块2.1 数据中心模块2.2 开放平台模块2.3 软件档案模块2.4 软件检测模块2.5 软件举报模块 三、系统设计3.1 用例设计3.2 数据库设计3.2.1 开放平台表3.2.2 软件档案表3.2.3 软件检测表3.2.4 软件举报表 四、系统展示五、核心代…

【el-tree 文字过长处理方案】

文字过长处理方案 一、示例代码二、关键代码三、效果图 一、示例代码 <divstyle"height: 600px;overflow: auto"class"text item"><el-treeref"tree":data"treeData":props"defaultProps"class"filter-tree&…

计算机网络——05Internet结构和ISP

Internet结构和ISP 互连网络结构&#xff1a;网络的网络 端系统通过接入ISPs连接到互连网 住宅、公司和大学的ISPs 接入ISPs相应的必须是互联的 因此任何2个端系统可相互发送分组到对方 导致的“网络的网络”非常复杂 发展和演化是通过经济的和国家的政策来驱动的 问题&…

课堂秩序要求有哪些内容

你是否曾经疑惑&#xff0c;为什么有些课堂总是秩序井然&#xff0c;而有些则混乱不堪&#xff1f;作为一位经验丰富的老师&#xff0c;我想告诉你&#xff0c;课堂秩序不仅仅是学生安静听讲那么简单&#xff0c;它背后涉及到许多关键因素&#xff0c;直接影响着教学质量和学习…

postgresql 手动清理wal日志的101个坑

新年的第一天&#xff0c;总结下去年遇到的关于WAL日志清理的101个坑&#xff0c;以及如何相对安全地进行清理。前面是关于WAL日志堆积的原因分析&#xff0c;清理相关可以直接看第三部分。 首先说明&#xff0c;手动清理wal日志是一个高风险的操作&#xff0c;尤其对于带主从的…

CleanMyMac X 4.14.7帮您安全清理Mac系统垃圾

CleanMyMac X 4.14.7是一款强大的 Mac 清理、加速工具和健康卫士,可以让您的 Mac 再次恢复巅峰性能。 移除大型和旧文件、卸载应用,并删除浪费磁盘空间的无用数据。 5倍 更多可用磁盘空间 CleanMyMac X 4.14.7帮您安全清理Mac系统垃圾 CleanMyMac X 4.14.7一键深度扫描mac系统…

Java常用类与基础API--String的理解与不可变性

文章目录 一、字符串相关类之不可变字符序列&#xff1a;String&#xff08;1&#xff09;对String类的理解(以JDK8为例说明)1、环境2、类的声明3、内部声明的属性 &#xff08;2&#xff09;String的特性&#xff08;3&#xff09;字符串常量的存储位置1、举例2、String的存储…

「优选算法刷题」:数青蛙

一、题目 给你一个字符串 croakOfFrogs&#xff0c;它表示不同青蛙发出的蛙鸣声&#xff08;字符串 "croak" &#xff09;的组合。由于同一时间可以有多只青蛙呱呱作响&#xff0c;所以 croakOfFrogs 中会混合多个 “croak” 。 请你返回模拟字符串中所有蛙鸣所需不…

Day39- 动态规划part07

一、爬楼梯 题目一&#xff1a;57. 爬楼梯 57. 爬楼梯&#xff08;第八期模拟笔试&#xff09; 题目描述 假设你正在爬楼梯。需要 n 阶你才能到达楼顶。 每次你可以爬至多m (1 < m < n)个台阶。你有多少种不同的方法可以爬到楼顶呢&#xff1f; 注意&#xff1a;…

【大厂AI课学习笔记】【1.6 人工智能基础知识】(3)神经网络

深度学习是机器学习中一种基于对数据进行表征学习的算法。观测值(例如一幅草莓照片)可以使用 多种方式来表示&#xff0c;如每个像素强度值的向量&#xff0c;或者更抽象地表示成一系列边、特定形状的区域等。 深度学习的最主要特征是使用神经网络作为计算模型。神经网络模型 …

深入理解 Nginx 插件及功能优化指南

深入理解 Nginx 插件及功能优化指南 深入理解 Nginx 插件及功能优化指南1. Nginx 插件介绍1.1 HTTP 模块插件ngx_http_rewrite_modulengx_http_access_module 1.2 过滤器插件ngx_http_gzip_modulengx_http_ssl_module 1.3 负载均衡插件ngx_http_upstream_modulengx_http_upstre…

CSS Selector—选择方法,和html自动——异步社区的爬取(动态网页)——爬虫(get和post的区别)

这里先说一下GET请求和POST请求&#xff1a; post我们平时是要加data的也就是信息&#xff0c;你会发现我们平时百度之类的 搜索都是post请求 get我们带的是params&#xff0c;是发送我们指定的内容。 要注意是get和post请求&#xff01;&#xff01;&#xff01; 先说一下异…

python+django人力资源管理系统7w5x3

技术栈 后端&#xff1a;python 前端&#xff1a;vue.jselementui 框架&#xff1a;django Python版本&#xff1a;python3.7 数据库&#xff1a;mysql5.7 数据库工具&#xff1a;Navicat 开发软件&#xff1a;PyCharm .设计框架&#xff1a;Vue 1. 表现层&#xff1a;写多…