深度学习基础练习:从pytorch API出发复现LSTM与LSTMP

2024/11/5-2024/11/7:

前置知识:

[译] 理解 LSTM(Long Short-Term Memory, LSTM) 网络 - wangduo - 博客园

【官方双语】LSTM(长短期记忆神经网络)StatQuest_哔哩哔哩_bilibili

大部分思路来自于:

PyTorch LSTM和LSTMP的原理及其手写复现_哔哩哔哩_bilibiliicon-default.png?t=O83Ahttps://www.bilibili.com/video/BV1zq4y1m7aH/?spm_id_from=333.880.my_history.page.click&vd_source=db0d5acc929b82408b1040d67f2b1dde

部分常量设置与官方api使用:

        其实在实现RNN之后可以发现,lstm基本是同样的套路。 在看完上面的前置知识之后,理解三个门的作用即可对lstm有一个具体的认识,这里不再赘述。

         关于输入设置这方面,参考如下:

# 定义常量
bs, T, i_size, h_size = 2, 3, 4, 5
# 输入序列
input = torch.randn(bs, T, i_size)
# 初始值,不需要训练
c0 = torch.randn(bs, h_size)
h0 = torch.randn(bs, h_size)

        将定义的常量输入官方api:

# 调用官方api
lstm_layer = nn.LSTM(i_size, h_size, batch_first=True)
# 单层单项lstm,h0与c0的第0维度为 D(是否双向)*num_layers 故增加0维,维度为1
output, (h_final, c_final) = lstm_layer(input, (h0.unsqueeze(0), c0.unsqueeze(0)))
print(output.shape)
print(output)for k, v in lstm_layer.named_parameters():print(k)print(v.shape)

        输出如下:

torch.Size([2, 3, 5])
tensor([[[ 0.1134, -0.1032,  0.1496,  0.1853, -0.3758],
         [ 0.1831,  0.0223,  0.0377,  0.0867, -0.1090],
         [ 0.1233,  0.1121,  0.0574, -0.0401, -0.1576]],

        [[-0.2761,  0.3259,  0.1687, -0.0632,  0.2046],
         [ 0.1796,  0.3110,  0.0974,  0.0294,  0.0220],
         [ 0.1205,  0.1815,  0.0840, -0.1714, -0.1216]]],
       grad_fn=<TransposeBackward0>)
weight_ih_l0
torch.Size([20, 4])
weight_hh_l0
torch.Size([20, 5])
bias_ih_l0
torch.Size([20])
bias_hh_l0
torch.Size([20])

        可以看到LSTM的内置参数有 weight_ih_l0、weight_hh_l0、bias_ih_l0、bias_hh_l0,将关于三个门的知识结合在一起看差不多就明白接下来应该怎么做了: 

        h和c都是统一经过*weight+bias的操作,加在一起后经过tahn或者sigmoid激活函数,最后或点乘或加在h或者c上进行对参数的更新。只要不把维度的对应关系搞混还是比较好复现的。

        需要注意的是:三个门中的四个weight和bias(遗忘门一个,输入门两个,输出门一个)全部都按照第0维度拼在了一起方便同时进行矩阵运算,所以我们可以看到这些权重和偏置的第0维度的大小为4*h_size。一开始这一点也带给了我比较大的困惑。

代码复现与验证:

        代码较为简单,跟上次实现RNN的思路也差不多,基本是照着官方api那给的公式一步一步来的:

# 代码复现
def lstm_forward(input, initial_states, w_ih, w_hh, b_ih, b_hh):h0, c0 = initial_statesbs, T, i_size = input.shapeh_size = w_ih.shape[0] // 4prev_h = h0     # [bs, h_size]prev_c = c0     # [bs, h_size]"""w_ih    # 4*h_size, i_sizew_hh    # 4*h_size, h_size"""# 输出序列output_size = h_sizeoutput = torch.zeros(bs, T, output_size)for t in range(T):x = input[:, t, :]  # 当前时刻输入向量, [bs, i_size]w_times_x = torch.matmul(w_ih, x.unsqueeze(-1)).squeeze(-1)         # [bs, 4*h_size]w_times_h_prev = torch.matmul(w_hh, prev_h.unsqueeze(-1)).squeeze(-1)    # [bs, 4*h_size]# 分别计算输入门(i),遗忘门(f),输出门(o),cell(g)i_t = torch.sigmoid(w_times_x[:, : h_size] + w_times_h_prev[:, : h_size] + b_ih[: h_size] + b_hh[: h_size])f_t = torch.sigmoid(w_times_x[:, h_size: 2*h_size] + w_times_h_prev[:, h_size: 2*h_size] +b_ih[h_size: 2*h_size] + b_hh[h_size: 2*h_size])g_t = torch.tanh(w_times_x[:, 2*h_size: 3*h_size] + w_times_h_prev[:, 2*h_size: 3*h_size] +b_ih[2*h_size: 3*h_size] + b_hh[2*h_size: 3*h_size])o_t = torch.sigmoid(w_times_x[:, 3*h_size:] + w_times_h_prev[:, 3*h_size:] + b_ih[3*h_size:] + b_hh[3*h_size:])# 更新流prev_c = f_t * prev_c + i_t * g_tprev_h = o_t * torch.tanh(prev_c)output[:, t, :] = prev_hreturn output, (prev_h, prev_c)

        输出结果对比验证:

# 调用官方api
lstm_layer = nn.LSTM(i_size, h_size, batch_first=True)
# 单层单项lstm,h0与c0的第0维度为 D(是否双向)*num_layers 故增加0维,维度为1
output, (h_final, c_final) = lstm_layer(input, (h0.unsqueeze(0), c0.unsqueeze(0)))
print(output.shape)
print(output)for k, v in lstm_layer.named_parameters():print(k, v.shape)
output, (h_final, c_final) = lstm_forward(input, (h0, c0), lstm_layer.weight_ih_l0,lstm_layer.weight_hh_l0, lstm_layer.bias_ih_l0, lstm_layer.bias_hh_l0)print(output)

        结果如下:

torch.Size([2, 3, 5])
tensor([[[-0.6394, -0.1796,  0.0831,  0.0816, -0.0620],
         [-0.5798, -0.2235,  0.0539, -0.0120, -0.0272],
         [-0.4229, -0.0798, -0.0762, -0.0030, -0.0668]],

        [[ 0.0294,  0.3240, -0.4318,  0.5005, -0.0223],
         [-0.1458,  0.0472, -0.1115,  0.3445,  0.3558],
         [-0.2922, -0.1013, -0.1755,  0.3065,  0.1130]]],
       grad_fn=<TransposeBackward0>)
weight_ih_l0 torch.Size([20, 4])
weight_hh_l0 torch.Size([20, 5])
bias_ih_l0 torch.Size([20])
bias_hh_l0 torch.Size([20])
tensor([[[-0.6394, -0.1796,  0.0831,  0.0816, -0.0620],
         [-0.5798, -0.2235,  0.0539, -0.0120, -0.0272],
         [-0.4229, -0.0798, -0.0762, -0.0030, -0.0668]],

        [[ 0.0294,  0.3240, -0.4318,  0.5005, -0.0223],
         [-0.1458,  0.0472, -0.1115,  0.3445,  0.3558],
         [-0.2922, -0.1013, -0.1755,  0.3065,  0.1130]]], grad_fn=<CopySlices>)

         复现成功。

appendix:

        这里放下LSTMP的参数设置:

# lstmp对h_size进行压缩
proj_size = 3
# h0的h_size也改为proj_size,而c0不变
h0 = torch.randn(bs, proj_size)# 调用官方api
lstmp_layer = nn.LSTM(i_size, h_size, batch_first=True, proj_size=proj_size)
# 单层单项lstm,h0与c0的第0维度为 D(是否双向)*num_layers 故增加0维,维度为1
output, (h, c) = lstmp_layer(input, (h0.unsqueeze(0), c0.unsqueeze(0)), )print(output)
print(output.shape)
print(h.shape)
print(c.shape)
for k, v in lstmp_layer.named_parameters():print(k, v.shape)

tensor([[[-0.0492,  0.0265,  0.0883],
         [-0.1028, -0.0327, -0.0542],
         [ 0.0250, -0.0231, -0.1199]],

        [[-0.2417, -0.1737, -0.0755],
         [-0.2351, -0.0837, -0.0376],
         [-0.2527, -0.0258, -0.0236]]], grad_fn=<TransposeBackward0>)
torch.Size([2, 3, 3])
torch.Size([1, 2, 3])
torch.Size([1, 2, 5])
weight_ih_l0 torch.Size([20, 4])
weight_hh_l0 torch.Size([20, 3])
bias_ih_l0 torch.Size([20])
bias_hh_l0 torch.Size([20])
weight_hr_l0 torch.Size([3, 5])

         其实LSTMP就多出了个weight_hr_l0对h进行压缩,但是不对cell压缩,目的是减少lstm的参数量,在小一点的sequence上基本没啥区别。若要支持lstmp,在前面的代码上改动几行即可:

def lstm_forward(input, initial_states, w_ih, w_hh, b_ih, b_hh, w_hr=None):h0, c0 = initial_statesbs, T, i_size = input.shapeh_size = w_ih.shape[0] // 4prev_h = h0     # [bs, h_size]prev_c = c0     # [bs, h_size]"""w_ih    # 4*h_size, i_sizew_hh    # 4*h_size, h_size"""if w_hr is not None:# 输出压缩至p_sizep_size = w_hr.shape[0]output_size = p_sizeelse:output_size = h_sizeoutput = torch.zeros(bs, T, output_size)for t in range(T):x = input[:, t, :]  # 当前时刻输入向量, [bs, i_size]w_times_x = torch.matmul(w_ih, x.unsqueeze(-1)).squeeze(-1)         # [bs, 4*h_size]w_times_h_prev = torch.matmul(w_hh, prev_h.unsqueeze(-1)).squeeze(-1)    # [bs, 4*h_size]# 分别计算输入门(i),遗忘门(f),输出门(o),cell(g)i_t = torch.sigmoid(w_times_x[:, : h_size] + w_times_h_prev[:, : h_size] + b_ih[: h_size] + b_hh[: h_size])f_t = torch.sigmoid(w_times_x[:, h_size: 2*h_size] + w_times_h_prev[:, h_size: 2*h_size] +b_ih[h_size: 2*h_size] + b_hh[h_size: 2*h_size])g_t = torch.tanh(w_times_x[:, 2*h_size: 3*h_size] + w_times_h_prev[:, 2*h_size: 3*h_size] +b_ih[2*h_size: 3*h_size] + b_hh[2*h_size: 3*h_size])o_t = torch.sigmoid(w_times_x[:, 3*h_size:] + w_times_h_prev[:, 3*h_size:] + b_ih[3*h_size:] + b_hh[3*h_size:])# 更新流prev_c = f_t * prev_c + i_t * g_tprev_h = o_t * torch.tanh(prev_c)   # [bs, h_size]if w_hr is not None:prev_h = torch.matmul(w_hr, prev_h.unsqueeze(-1)).squeeze(-1)   # [bs, p_size]output[:, t, :] = prev_hreturn output, (prev_h, prev_c)

        经过验证,复现成功。

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.mzph.cn/diannao/59664.shtml

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

L1G3000 提示工程(Prompt Engineering)

什么是Prompt(提示词)? Prompt是一种灵活、多样化的输入方式&#xff0c;可以用于指导大语言模型生成各种类型的内容。什么是提示工程? 提示工程是一种通过设计和调整输入(Prompts)来改善模型性能或控制其输出结果的技术。 六大基本原则: 指令要清晰提供参考内容复杂的任务拆…

Servlet 3.0 新特性全解

文章目录 Servlet3.0新特性全解Servlet 3.0 新增特性Servlet3.0的注解Servlet3.0的Web模块支持servlet3.0提供的异步处理提供异步原因实现异步原理配置servlet类成为异步的servlet类具体实现异步监听器改进的ServletAPI(上传文件) Servlet3.0新特性全解 tomcat 7以上的版本都支…

PPT文件设置了修改权限,如何取消权?

不知道大家在使用PPT文件的时候&#xff0c;是否遇到过下面的提示框&#xff0c;这就是PPT文件设置了修改权限&#xff0c;只有输入密码才可以编辑文件。 如果我们没有输入密码&#xff0c;以只读方式进入&#xff0c;那么我们会发现功能栏中的按钮全是灰色&#xff0c;无法使用…

牛客sql题目总结(1)

1.第N高的薪水 AC: create function getnthhighestsalary(n int) returns int begindeclare m int; set m n - 1; return (select distinct salaryfrom employeeorder by salary desclimit m, 1); end 2.平均播放进度大于60%的视频类别 AC&#xff1a; select tb_video_info…

【NLP】使用 SpaCy、ollama 创建用于命名实体识别的合成数据集

命名实体识别 (NER) 是自然语言处理 (NLP) 中的一项重要任务&#xff0c;用于自动识别和分类文本中的实体&#xff0c;例如人物、位置、组织等。尽管它很重要&#xff0c;但手动注释大型数据集以进行 NER 既耗时又费钱。受本文 ( https://huggingface.co/blog/synthetic-data-s…

【51单片机】串口通信原理 + 使用

学习使用的开发板&#xff1a;STC89C52RC/LE52RC 编程软件&#xff1a;Keil5 烧录软件&#xff1a;stc-isp 开发板实图&#xff1a; 文章目录 串口硬件电路UART串口相关寄存器 编码单片机通过串口发送数据电脑通过串口发送数据控制LED灯 串口 串口是一种应用十分广泛的通讯接…

线程函数和线程启动的几种不同形式

线程函数和线程启动的几种不同形式 在C中&#xff0c;线程函数和线程启动可以通过多种形式实现。以下是几种常见的形式&#xff0c;并附有相应的示例代码。 1. 使用函数指针启动线程 最基本的方式是使用函数指针来启动线程。 示例代码&#xff1a; #include <iostream&g…

C语言网络编程 -- TCP/iP协议

一、Socket简介 1.1 什么是socket socket通常也称作"套接字"&#xff0c;⽤于描述IP地址和端⼝&#xff0c;是⼀个通信链的句柄&#xff0c;应⽤ 程序通常通过"套接字"向⽹络发出请求或者应答⽹络请求。⽹络通信就是两个进程 间的通信&#xff0c;这两个进…

Qt字符编码

目前字符编码有以下几种&#xff1a; 1、UTF-8 UTF-8编码是Unicode字符集的一种编码方式(CEF)&#xff0c;其特点是使用变长字节数(即变长码元序列、变宽码元序列)来编码。一般是1到4个字节&#xff0c;当然&#xff0c;也可以更长。 2、UTF-16 UTF-16是Unicode字符编码五层次…

Linux下的ADC

ADC ADC简介 ADC是 Analog Digital Converter 的缩写&#xff0c;翻译过来为模数转换器&#xff0c;ADC可以将模拟值转换成数字值。模拟值是什么呢?比如我们日常生活中的温度&#xff0c;速度&#xff0c;湿度等等都是模拟值。所以如果我们想测量这些模拟值的值是多少&#x…

理解Web登录机制:会话管理与跟踪技术解析(二)-JWT令牌

JWT令牌是一种用于安全地在各方之间传递信息的开放标准&#xff0c;它不仅能够验证用户的身份&#xff0c;还可以安全地传递有用的信息。由于其结构简单且基于JSON&#xff0c;JWT可以在不同的系统、平台和语言间无缝传递&#xff0c;成为现代Web开发中不可或缺的一部分。 文章…

论 ONLYOFFICE:开源办公套件的深度探索

公主请阅 引言第一部分&#xff1a;ONLYOFFICE 的历史背景1.1 开源软件的崛起1.2 ONLYOFFICE 的发展历程 第二部分&#xff1a;ONLYOFFICE 的核心功能2.1 文档处理2.2 电子表格2.3 演示文稿 第三部分&#xff1a;技术架构与兼容性3.1 技术架构3.2 兼容性 第四部分&#xff1a;部…

sql报错信息将字符串转换为 uniqueidentifier 时失败

报错信息&#xff1a; [42000] [Microsoft][SQL Server Native Client 10.0][SQL Server]将字符串转换为 uniqueidentifier 时失败 出错行如下&#xff1a; 表A.SourceCode 表B.ID 出错原因&#xff1a; SourceCode是nvarchar,但ID是uniqueidentifier 数据库查询字段和类…

[复健计划][紫书]Chapter 7 暴力求解法

7.1 简单枚举 例7-1 Division uva725 输入正整数n&#xff0c;按从小到大的顺序输出所有形如abcde/fghij n的表达式&#xff0c;其中a&#xff5e;j恰好为数字0&#xff5e;9的一个排列&#xff08;可以有前导0&#xff09;&#xff0c;2≤n≤79。枚举fghij&#xff0c;验证a…

大语言模型鼻祖Transformer的模型架构和底层原理

Transformer 模型的出现标志着自然语言处理&#xff08;NLP&#xff09;技术的一次重大进步。这个概念最初是针对机器翻译等任务而提出的&#xff0c;Transformer 后来被拓展成各种形式——每种形式都针对特定的应用&#xff0c;包括原始的编码器-解码器&#xff08;encoder-de…

基于MySQL的企业专利数据高效查询与统计实现

背景 在进行产业链/产业评估工作时&#xff0c;我们需要对企业的专利进行评估&#xff0c;其中一个重要指标是统计企业每一年的专利数量。本文基于MySQL数据库&#xff0c;通过公司名称查询该公司每年的专利数&#xff0c;实现了高效的专利数据统计。 流程 项目流程概述如下&…

canfestival主站多电机对象字典配置

不要使用数组进行命名&#xff1a;无法运行PDO 使用各自命名的方式&#xff1a;

bat批量处理脚本细节研究

文章目录 bat批处理脚本&#xff08;框架&#xff09;set变量设置基本语法显示环境变量 自定义环境变量临时环境变量和永久环境变量特殊环境变量和系统默认环境变量set命令利用选项的其他应用 !与%解析变量的区别/为什么使用setlocal enabledelayedexpansion区别%的规则!使用 %…

Java 网络编程(一)—— UDP数据报套接字编程

概念 在网络编程中主要的对象有两个&#xff1a;客户端和服务器。客户端是提供请求的&#xff0c;归用户使用&#xff0c;发送的请求会被服务器接收&#xff0c;服务器根据请求做出响应&#xff0c;然后再将响应的数据包返回给客户端。 作为程序员&#xff0c;我们主要关心应…

使用C++来编写VTK项目时,就是要写自己的算法

其实&#xff0c;使用VTK可以使用很多种语言&#xff0c;比如java&#xff0c;python&#xff0c;和C。那么为什么非要使用C 呢&#xff1f;一个原因是觉得C语言处理数据比较快&#xff0c;另一个原因是需要自己写算法。通过继承polyDataAlgorithm来写自己的算法&#xff0c;很…