PyTorch -- LSTM 快速实践

上篇介绍了 RNN 快速实践;使用 LSTM 的话,可以解决梯度离散及短期记忆问题;代码部署方面,增加了 c 值 (即 RNN 中的 h 变成了 LSTM 中的 (h,c)), 可对照 RNN 快速实践 来快速掌握。


  • LSTM Layer torch.nn.LSTM(input_size,hidden_size,num_layers,batch_first)
    • input_size: 输入的编码维度
    • hidden_size: 隐含层的维数
    • num_layers: 隐含层的层数
    • batch_first: ·True 指定输入的参数顺序为:
      • x:[batch, seq_len, input_size] # 或者用符号 c0
      • h0:[batch, num_layers, hidden_size]
  • LSTM 的输入
    • x:[seq_len, batch, input_size] # 或者用符号 c0
      • seq_len: 输入的序列长度
      • batch: batch size 批大小
    • (h0, c0):[num_layers, batch, hidden_size]
  • LSTM 的输出
    • y: [seq_len, batch, hidden_size]
    • (ht, ct):[num_layers, batch, hidden_size]

..........

三个门 ( σ \sigma σ处:遗忘f、输入i、输出o) 都是基于 x t \mathbf{x}_t xt h t − 1 \mathbf{h}_{t-1} ht1 产生,但是分别对应要学习的权重参数 W W W 不同,或可参照下简化图直观理解 LSTM 模块内部的处理流程


  • 实战之预测 正弦曲线:以下会以此为例,演示 RNN 预测任务的部署

    • 下述示例代码已注明区别行 ########################### (共3处)

    • 步骤一:确定 RNN Layer 相关参数值并基于此创建 Net (RNN->LSTM)

      import numpy as np
      from matplotlib import pyplot as pltimport torch
      import torch.nn as nn
      import torch.optim as optimseq_len     = 50
      batch       = 1
      num_time_steps = seq_leninput_size  = 1
      output_size = input_size
      hidden_size = 10  	
      num_layers = 1  	
      batch_first = True class Net(nn.Module):  ## model 定义def __init__(self):super(Net, self).__init__()self.rnn = nn.LSTM(  ##1.###################################### RNN->LSTMinput_size=input_size,hidden_size=hidden_size,num_layers=num_layers,batch_first=batch_first)# for p in self.rnn.parameters():# 	nn.init.normal_(p, mean=0.0, std=0.001)self.linear = nn.Linear(hidden_size, output_size)def forward(self, x, hidden_prev):out, hidden_prev = self.rnn(x, hidden_prev)# out: [batch, seq_len, hidden_size]out = out.view(-1, hidden_size)  # [batch*seq_len, hidden_size]out = self.linear(out) 			 # [batch*seq_len, output_size]out = out.unsqueeze(dim=0)    # [1, batch*seq_len, output_size]return out, hidden_prev
      
    • 步骤二:确定 训练流程

      lr=0.01def tarin_RNN():model = Net()print('model:\n',model)criterion = nn.MSELoss()optimizer = optim.Adam(model.parameters(), lr)#hidden_prev = torch.zeros(num_layers, batch, hidden_size)  #初始化hhidden_prev = (torch.zeros(num_layers, batch, hidden_size), torch.zeros(num_layers, batch, hidden_size))  ##2.###################################### --> 原先的 h 变为现在的 h cl = []for iter in range(100):  # 训练100次start = np.random.randint(10, size=1)[0]  ## 序列起点time_steps = np.linspace(start, start+10, num_time_steps)  ## 序列data = np.sin(time_steps).reshape(num_time_steps, 1)  ## 序列数据x = torch.tensor(data[:-1]).float().view(batch, seq_len-1, input_size)y = torch.tensor(data[1: ]).float().view(batch, seq_len-1, input_size)  # 目标为预测一个新的点output, hidden_prev = model(x, hidden_prev)# hidden_prev = hidden_prev.detach()  ## 最后一层隐藏层的状态要 detachhidden_prev = (hidden_prev[0].detach(), hidden_prev[1].detach()) ######################################## --> 原先的 h 变为现在的 h closs = criterion(output, y)model.zero_grad()loss.backward()optimizer.step()if iter % 100 == 0:print("Iteration: {} loss {}".format(iter, loss.item()))l.append(loss.item())##3.###########################绘制损失函数#################################plt.plot(l,'r')plt.xlabel('训练次数')plt.ylabel('loss')plt.title('RNN LOSS')plt.savefig('RNN_LOSS.png')return hidden_prev,modelhidden_prev,model = tarin_RNN()
      
    • 步骤三:测试训练结果

      start = np.random.randint(3, size=1)[0]  ## 序列起点
      time_steps = np.linspace(start, start+10, num_time_steps)  ## 序列
      data = np.sin(time_steps).reshape(num_time_steps, 1)  ## 序列数据
      x = torch.tensor(data[:-1]).float().view(batch, seq_len-1, input_size)
      y = torch.tensor(data[1: ]).float().view(batch, seq_len-1, input_size)  # 目标为预测一个新的点    predictions = []  ## 预测结果
      input = x[:,0,:]
      for _ in range(x.shape[1]):input = input.view(1, 1, 1)pred, hidden_prev = model(input, hidden_prev)input = pred  ## 循环获得每个input点输入网络predictions.append(pred.detach().numpy()[0])
      x= x.data.numpy()
      y = y.data.numpy( )
      plt.scatter(time_steps[:-1], x.squeeze(), s=90)
      plt.plot(time_steps[:-1], x.squeeze())
      plt.scatter(time_steps[1:],predictions)  ## 黄色为预测
      plt.show()
      

  • B站视频参考资料

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.mzph.cn/web/31107.shtml

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

【Linux】使用Linux find 命令根据时间过滤文件并输出文件名

那年夏天我和你躲在 这一大片宁静的海 直到后来我们都还在 对这个世界充满期待 今年冬天你已经不在 我的心空出了一块 很高兴遇见你 让我终究明白 回忆比真实精彩 🎵 王心凌《那年夏天宁静的海》 在Linux系统中,find 命令是一个强大…

【Linux】关于在华为云中开放了端口后仍然无法访问的问题

已在安全组中添加规则: 通过指令: netstat -nltp | head -2 && netstat -nltp | grep 8080 运行结果: 可以看到服务器确实处于监听状态了. 通过指令 telnet 公网ip port 也提示: "正在连接xxx.xx.xx.xxx...无法打开到主机的连接。 在端口 8080: 连接失败"…

【漏洞复现】世邦通信 SPON IP网络对讲广播系统 addscenedata.php 任意文件上传漏洞

免责声明: 本文内容旨在提供有关特定漏洞或安全漏洞的信息,以帮助用户更好地了解可能存在的风险。公布此类信息的目的在于促进网络安全意识和技术进步,并非出于任何恶意目的。阅读者应该明白,在利用本文提到的漏洞信息或进行相关测…

C语言 | Leetcode C语言题解之第171题Excel表列序号

题目: 题解: int titleToNumber(char* columnTitle) {int number 0;long multiple 1;for (int i strlen(columnTitle) - 1; i > 0; i--) {int k columnTitle[i] - A 1;number k * multiple;multiple * 26;}return number; }

Linux系统下的Swift与Ceph分布式存储解决方案

Linux系统下,Swift和Ceph是两种广泛使用的分布式存储解决方案,它们各有特色,适用于不同的存储需求场景。 Swift Swift是OpenStack项目的一部分,专为大规模对象存储设计。它提供了高度可扩展、持久且分布式的数据存储&#xff0c…

华为---静态路由-浮动静态路由及负载均衡(二)

7.2 浮动静态路由及负载均衡 7.2.1 原理概述 浮动静态路由(Floating Static Route)是一种特殊的静态路由,通过配置去往相同的目的网段,但优先级不同的静态路由,以保证在网络中优先级较高的路由,即主路由失效的情况下&#xff0c…

数据结构之“算法的时间复杂度和空间复杂度”

🌹个人主页🌹:喜欢草莓熊的bear 🌹专栏🌹:数据结构 目录 前言 一、算法效率 1.1算法的复杂度概念 1.2复杂度的重要性 二、时间复杂度 2.1时间复杂度的概念 2.2大O的渐进表示法 2.3常见的时间复杂度…

云计算【第一阶段(17)】账号和权限管理

目录 一、用户账号和组账号概述 1.1、用户账号的三种角色 1.2、组账号的两个角色 二、用户账号文件 2.1、/etc/passwd 2.2、/etc/shadow 2.3、chage 命令 三、组账号文件 3.1、/etc/group 3.2、/etc/gshadow 四、添加组账户 4.1、添加删除组成员 4.2、删除组账号 …

go 1.22 增强 http.ServerMux 路由能力

之前 server func main() {http.HandleFunc("/", func(w http.ResponseWriter, r *http.Request) {fmt.Println("Received request:", r.URL.Path)fmt.Fprintf(w, "Hello, client! You requested: %s\n", r.URL.Path)})log.Println("Serv…

Web3 学习

之前学习 web3,走了不少弯路,最近看到了 hackquest,重新刷了一遍以太坊基础,感觉非常nice,而且完全免费,有需要的可以试试,链接hackquest.io。

2024广东省职业技能大赛云计算赛项实战——Ceph集群部署

Ceph部署 前言 今年的比赛也是考了Ceph,题目是这道,往年国赛里翻到的: 使用提供的 ceph.tar.gz 软件包,安装 ceph 服务并完成初始化操作。使用提供的 ceph-14.2.22.tar.gz 软件包,在 OpenStack 平台上创建三台CentOS…

【Proteus仿真】【51单片机】基于物联网新能源电动车检测系统设计

文章目录 一、功能简介二、软件设计三、实验现象联系作者 一、功能简介 本项目使用Proteus8仿真51单片机控制器,使用LCD1602液晶显示模块、WIFI模块、蜂鸣器、LED按键、ADC PCF8591、DS18B20温度传感器等。 主要功能: 系统运行后,LCD1602显…

视频集市新增支持多格式流媒体拉流预览

流媒体除了常用实时流外还有大部分是以文件的形式存在,做融合预览必须要考虑多种兼容性能力,借用现有的ffmpeg生态可以迅速实现多种格式的支持,现在我们将按需拉流预览功能进行了拓展,正式支持了ffmpeg的功能,可快捷方…

数据标注概念

数据标注的步骤 数据清洗:处理数据中的噪声、缺失值和异常值,确保数据的质量和完整性。 数据转换:将数据从原始格式转换为适合机器学习模型处理的格式。 数据标注:根据应用需求,为数据添加标签或注释,标识…

初学51单片机之PWM实例呼吸灯以及遇到的问题(已解答)

PWM全名Pulse Width Modulation中文称呼脉冲宽度调制 如图 这是一个周期10ms、频率是100HZ的波形,但是每个周期内,高低电平宽度各不相同,这就是PWM的本质。 占空比是指高电平占整个周期的比列,上图第一个波形的占空比是40%,第二个…

Linux:多线程中的互斥与同步

多线程 线程互斥互斥锁互斥锁实现的原理封装原生线程库封装互斥锁 死锁避免死锁的四种方法 线程同步条件变量 线程互斥 在多线程中,如果存在有一个全局变量,那么这个全局变量会被所有执行流所共享。但是,资源共享就会存在一种问题&#xff1…

QT中常用控件的样式美化,已上传相应的qss样式和图片资源

1、QComboBox /*仅仅输入框*/ QComboBox {background-color: transparent;border-image: url(:/images/systemSetImage/common/comboBoxBk.png);border: 1px solid #7285CA

天才简史——Diederik P. Kingma与他的Adam优化器

一、了解Diederik P. Kingma 发生日期:2024年6月18日 前几日,与实验室同门一同前往七食堂吃饭。饭间,一位做随机优化的同门说他看过一篇被引18w的文章。随后,我表示不信,说你不会数错了吧,能有1.8w次被引都…

如何编译和运行您的第一个Java程序

​ 如何编译和运行您的第一个Java程序 让我们从一个简单的java程序开始。 简单的Java程序 这是一个非常基本的java程序,它会打印一条消息“这是我在java中的第一个程序”。 ​ public class FirstJavaProgram {public static void main(String[] args){System.…

【人机交互 复习】第7章 可视化设计

一、窗口界面类型 1.多文档界面 (1)优点 a.节省系统资源 b.最小的可视集 c.协同工作区 d.多文档同时可视化 (2)缺点 a.菜单随活动文档窗口状态变化,导致不一致性 b.文档窗口必须在主窗口内部,减弱多文档显…