李沐56_门控循环单元——自学笔记

关注每一个序列

1.不是每个观察值都是同等重要

2.想只记住的观察需要:能关注的机制(更新门 update gate)、能遗忘的机制(重置门 reset gate)

!pip install --upgrade d2l==0.17.5  #d2l需要更新
import torch
from torch import nn
from d2l import torch as d2lbatch_size, num_steps = 32, 35
train_iter, vocab = d2l.load_data_time_machine(batch_size, num_steps)
Downloading ../data/timemachine.txt from http://d2l-data.s3-accelerate.amazonaws.com/timemachine.txt...

下一步是初始化模型参数。 我们从标准差为0.01的高斯分布中提取权重, 并将偏置项设为0,超参数num_hiddens定义隐藏单元的数量, 实例化与更新门、重置门、候选隐状态和输出层相关的所有权重和偏置。

def get_params(vocab_size, num_hiddens, device):num_inputs = num_outputs = vocab_sizedef normal(shape):return torch.randn(size=shape, device=device)*0.01def three():return (normal((num_inputs, num_hiddens)),normal((num_hiddens, num_hiddens)),torch.zeros(num_hiddens, device=device))W_xz, W_hz, b_z = three()  # 更新门参数W_xr, W_hr, b_r = three()  # 重置门参数W_xh, W_hh, b_h = three()  # 候选隐状态参数# 输出层参数W_hq = normal((num_hiddens, num_outputs))b_q = torch.zeros(num_outputs, device=device)# 附加梯度params = [W_xz, W_hz, b_z, W_xr, W_hr, b_r, W_xh, W_hh, b_h, W_hq, b_q]for param in params:param.requires_grad_(True)return params

将定义隐状态的初始化函数init_gru_state。此函数返回一个形状为(批量大小,隐藏单元个数)的张量,张量的值全部为零。

def init_gru_state(batch_size, num_hiddens, device):return (torch.zeros((batch_size, num_hiddens), device=device), )

准备定义门控循环单元模型, 模型的架构与基本的循环神经网络单元是相同的, 只是权重更新公式更为复杂。

def gru(inputs, state, params):W_xz, W_hz, b_z, W_xr, W_hr, b_r, W_xh, W_hh, b_h, W_hq, b_q = paramsH, = stateoutputs = []for X in inputs:Z = torch.sigmoid((X @ W_xz) + (H @ W_hz) + b_z)R = torch.sigmoid((X @ W_xr) + (H @ W_hr) + b_r)H_tilda = torch.tanh((X @ W_xh) + ((R * H) @ W_hh) + b_h)H = Z * H + (1 - Z) * H_tildaY = H @ W_hq + b_qoutputs.append(Y)return torch.cat(outputs, dim=0), (H,)

训练结束后,我们分别打印输出训练集的困惑度, 以及前缀“time traveler”和“traveler”的预测序列上的困惑度。

vocab_size, num_hiddens, device = len(vocab), 256, d2l.try_gpu()
num_epochs, lr = 500, 1
model = d2l.RNNModelScratch(len(vocab), num_hiddens, device, get_params,init_gru_state, gru)
d2l.train_ch8(model, train_iter, vocab, lr, num_epochs, device)
perplexity 1.1, 31831.9 tokens/sec on cuda:0
time traveller for so it will be convenient to speak of himwas e
travelleryou can show black is white by argument said filby

在这里插入图片描述

简洁实现

num_inputs = vocab_size
gru_layer = nn.GRU(num_inputs, num_hiddens)
model = d2l.RNNModel(gru_layer, len(vocab))
model = model.to(device)
d2l.train_ch8(model, train_iter, vocab, lr, num_epochs, device)
perplexity 1.0, 255484.2 tokens/sec on cuda:0
time traveller for so it will be convenient to speak of himwas e
traveller with a slight accession ofcheerfulness really thi

在这里插入图片描述

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.mzph.cn/web/2445.shtml

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

Python 基于 OpenCV 视觉图像处理实战 之 OpenCV 简单人脸检测/识别实战案例 之五 简单进行车牌检测和识别

Python 基于 OpenCV 视觉图像处理实战 之 OpenCV 简单人脸检测/识别实战案例 之五 简单进行车牌检测和识别 目录 Python 基于 OpenCV 视觉图像处理实战 之 OpenCV 简单人脸检测/识别实战案例 之五 简单进行车牌检测和识别 一、简单介绍 二、简单进行车牌检测和识别实现原理 …

深入理解Java消息中间件-消息的可靠性

基石:在分布式系统中实现消息的可靠性及其原理 在构建现代的分布式系统时,能否可靠地传递消息是衡量系统成功与否的一个关键。本文章旨在讨论如何实现消息的可靠性及其背后的原理,帮助Java技术架构师和开发者构筑一个稳固的系统。 消息可靠…

关于上传自己本地项目到GitHub的相关命令

https://www.cnblogs.com/nature161/p/15014265.html 根据教程里的来,主要注意这个命令: $ git pull --rebase origin master # 对GitHub的仓库包含了readme.md文件的情况先要执行这个命令再pull 如果你的GitHub是main分支想上传到main分支&#xff0…

Opencv_10_自带颜色表操作

void color_style(Mat& image); Opencv_10_自带颜色表操作: void ColorInvert::color_style(Mat& image) { int colormap[] { COLORMAP_AUTUMN, COLORMAP_BONE , COLORMAP_JET , COLORMAP_WINTER, COLORMAP_RAINBOW , COLOR…

编译器的学习

常用的编译器: GCCVisual CClang(LLVM): Clang 可以被看作是建立在 LLVM 之上的一个项目, 实际上LLVM是clang的后端,clang作为前端前端生成LLVM IR,https://zhuanlan.zhihu.com/p/656699711MSVC &#xff…

Docker常用命令(镜像、容器、网络)

一、镜像 1.1 存出镜像 将镜像保存成为本地文件 格式&#xff1a;docker save -o 存储文件名 存储的镜像docker save -o nginx nginx:latest 1.2 载入镜像 将镜像文件导入到镜像库中 格式&#xff1a;docker load < 存出的文件或docker load -i 存出的文件…

程序猿成长之路之数据挖掘篇——朴素贝叶斯

朴素贝叶斯是数据挖掘分类的基础&#xff0c;本篇文章将介绍一下朴素贝叶斯算法 情景再现 以挑选西瓜为例&#xff0c;西瓜的色泽、瓜蒂、敲响声音、触感、脐部等特征都会影响到西瓜的好坏。那么我们怎么样可以挑选出一个好的西瓜呢&#xff1f; 分析过程 既然挑选西瓜有多个…

Android Studio 报错:AVD Pixel_3a_API_30_x86 is already running

在我的Android Studio和虚拟机运行时&#xff0c;我的电脑不小心关机了&#xff0c;在启动后再次打开Android Studio并运行虚拟机时发现报错。 Error while waiting for device: AVD Pixel_3a_API_30_x86 is already running. If that is not the case, delete the files at C…

c++设计模式之桥接模式(拼接组合)

桥接模式&#xff1a;就是进行拼接组装 应用举例&#xff1a; 1.定义了形状&#xff0c;抽象形状接口&#xff0c;圆&#xff0c;矩形 2.定义了颜色&#xff0c;抽象颜色接口&#xff0c;红色&#xff0c;蓝色 3&#xff0c;怎么桥接&#xff0c;抽象具体形状和具体颜色的组合…

应用部署tomcat的三种方式

由于一直在用springboot框架&#xff0c;集成了tomcat&#xff0c;快忘记如何单独部署tomcat了&#xff0c;以下&#xff0c;记录一下&#xff1a; 部署tomcat有三种方式&#xff1a; 一、方式一&#xff1a;将war包丢进webapps 这是最简单粗暴的方式&#xff1a;将web工程打…

用现成的容器来创建一个镜像,或者说再克隆一个一模一样的容器

前言&#xff1a;我在centos系统中使用docker拉取了一个centos镜像&#xff0c;并用这个镜像创建了一个hadoop容器&#xff0c;但是后面我又需要一个相同版本的hadoop镜像来创建其他容器&#xff08;比如hive容器&#xff09;&#xff0c;但是这个时候docker官网并没有对应版本…

个人投资者如何开通快速通道?

其实我们一直都在说快速通道&#xff0c;那么我们个人投资者如何才能开通快速通道呢&#xff1f; 怎样才能做到打板排序前列&#xff0c;成交速度快&#xff0c;适合高频交易呢&#xff1f; 我们一起来了解下&#xff01; 第一&#xff1a;什么是快速通道&#xff1f; 其实就…

C#基础|对象属性Property基础使用,业务特性

哈喽&#xff0c;你好&#xff0c;我是雷工。 探究OOP中属性的奥秘 认识类的属性&#xff08;Property&#xff09; 01 属性的使用 作用&#xff1a;在面向对象&#xff08;OOP&#xff09;中主要用来封装数据。 要求&#xff1a;一般采用Pascal命名法&#xff08;首字母要…

UDS的3字节故障码

在UDS的规范下面&#xff0c;使用19服务去读取故障码&#xff0c;会发现读到市面上各种车企的各种ECU中的所有的故障码读出来都是3个字节。这与前面的五位故障码占2个字节不符&#xff0c;其实读出来是3个字节就是UDS中制定的规范。 如今车企中主要采用的是三个字节的故障码。…

【北京迅为】《iTOP-3588开发板系统编程手册》-第19章 V4L2摄像头应用编程

RK3588是一款低功耗、高性能的处理器&#xff0c;适用于基于arm的PC和Edge计算设备、个人移动互联网设备等数字多媒体应用&#xff0c;RK3588支持8K视频编解码&#xff0c;内置GPU可以完全兼容OpenGLES 1.1、2.0和3.2。RK3588引入了新一代完全基于硬件的最大4800万像素ISP&…

c++新特性 智能指针实验

1.概要 c新特性 智能指针实验 2.实验 2.1 共同代码 #pragma once class A { public:A();~A();void fun();int a 5; }; #include "A.h" #include <iostream>A::A() {std::cout << "structure\n"; } A::~A() {std::cout << "de…

vben admin Table 实现表格列宽自由拖拽

更改BasicTable.vue文件 Table添加 resize-column“resizeColumn” 添加并 return resizeColumn const resizeColumn (w, col) > { setCacheColumnsByField(col.dataIndex, { width: w }); }; 在column中添加 resizable: true,

linux下编译c++程序报错“undefined reference to `std::allocator<char>::allocator()‘”

问题 linux下编译c程序报错“undefined reference to std::allocator::allocator()”。 原因 找不到c标准库文件。 解决办法 开始尝试给gcc指令添加-L和-l选项指定库路径和库文件名&#xff0c;但是一直不成功&#xff0c;后来把gcc改为g就可以了。

【后端】node.js安装与配置教程

提示&#xff1a;文章写完后&#xff0c;目录可以自动生成&#xff0c;如何生成可参考右边的帮助文档 文章目录 前言一、安装node.js二、验证安装三、配置 npm四、开发环境配置五、总结 前言 随着开发语言及人工智能工具的普及&#xff0c;使得越来越多的人会主动学习使用一些…

Rust - 引用和借用

上一篇章末尾提到&#xff0c;如果仅仅支持通过转移所有权的方式获取一个值&#xff0c;那会让程序变得复杂。 Rust 能否像其它编程语言一样&#xff0c;使用某个变量的指针或者引用呢&#xff1f;答案是可以。 Rust 通过 借用(Borrowing) 这个行为来达成上述的目的&#xff0…