PyTorch -- LSTM 快速实践

上篇介绍了 RNN 快速实践;使用 LSTM 的话,可以解决梯度离散及短期记忆问题;代码部署方面,增加了 c 值 (即 RNN 中的 h 变成了 LSTM 中的 (h,c)), 可对照 RNN 快速实践 来快速掌握。


  • LSTM Layer torch.nn.LSTM(input_size,hidden_size,num_layers,batch_first)
    • input_size: 输入的编码维度
    • hidden_size: 隐含层的维数
    • num_layers: 隐含层的层数
    • batch_first: ·True 指定输入的参数顺序为:
      • x:[batch, seq_len, input_size] # 或者用符号 c0
      • h0:[batch, num_layers, hidden_size]
  • LSTM 的输入
    • x:[seq_len, batch, input_size] # 或者用符号 c0
      • seq_len: 输入的序列长度
      • batch: batch size 批大小
    • (h0, c0):[num_layers, batch, hidden_size]
  • LSTM 的输出
    • y: [seq_len, batch, hidden_size]
    • (ht, ct):[num_layers, batch, hidden_size]

..........

三个门 ( σ \sigma σ处:遗忘f、输入i、输出o) 都是基于 x t \mathbf{x}_t xt h t − 1 \mathbf{h}_{t-1} ht1 产生,但是分别对应要学习的权重参数 W W W 不同,或可参照下简化图直观理解 LSTM 模块内部的处理流程


  • 实战之预测 正弦曲线:以下会以此为例,演示 RNN 预测任务的部署

    • 下述示例代码已注明区别行 ########################### (共3处)

    • 步骤一:确定 RNN Layer 相关参数值并基于此创建 Net (RNN->LSTM)

      import numpy as np
      from matplotlib import pyplot as pltimport torch
      import torch.nn as nn
      import torch.optim as optimseq_len     = 50
      batch       = 1
      num_time_steps = seq_leninput_size  = 1
      output_size = input_size
      hidden_size = 10  	
      num_layers = 1  	
      batch_first = True class Net(nn.Module):  ## model 定义def __init__(self):super(Net, self).__init__()self.rnn = nn.LSTM(  ##1.###################################### RNN->LSTMinput_size=input_size,hidden_size=hidden_size,num_layers=num_layers,batch_first=batch_first)# for p in self.rnn.parameters():# 	nn.init.normal_(p, mean=0.0, std=0.001)self.linear = nn.Linear(hidden_size, output_size)def forward(self, x, hidden_prev):out, hidden_prev = self.rnn(x, hidden_prev)# out: [batch, seq_len, hidden_size]out = out.view(-1, hidden_size)  # [batch*seq_len, hidden_size]out = self.linear(out) 			 # [batch*seq_len, output_size]out = out.unsqueeze(dim=0)    # [1, batch*seq_len, output_size]return out, hidden_prev
      
    • 步骤二:确定 训练流程

      lr=0.01def tarin_RNN():model = Net()print('model:\n',model)criterion = nn.MSELoss()optimizer = optim.Adam(model.parameters(), lr)#hidden_prev = torch.zeros(num_layers, batch, hidden_size)  #初始化hhidden_prev = (torch.zeros(num_layers, batch, hidden_size), torch.zeros(num_layers, batch, hidden_size))  ##2.###################################### --> 原先的 h 变为现在的 h cl = []for iter in range(100):  # 训练100次start = np.random.randint(10, size=1)[0]  ## 序列起点time_steps = np.linspace(start, start+10, num_time_steps)  ## 序列data = np.sin(time_steps).reshape(num_time_steps, 1)  ## 序列数据x = torch.tensor(data[:-1]).float().view(batch, seq_len-1, input_size)y = torch.tensor(data[1: ]).float().view(batch, seq_len-1, input_size)  # 目标为预测一个新的点output, hidden_prev = model(x, hidden_prev)# hidden_prev = hidden_prev.detach()  ## 最后一层隐藏层的状态要 detachhidden_prev = (hidden_prev[0].detach(), hidden_prev[1].detach()) ######################################## --> 原先的 h 变为现在的 h closs = criterion(output, y)model.zero_grad()loss.backward()optimizer.step()if iter % 100 == 0:print("Iteration: {} loss {}".format(iter, loss.item()))l.append(loss.item())##3.###########################绘制损失函数#################################plt.plot(l,'r')plt.xlabel('训练次数')plt.ylabel('loss')plt.title('RNN LOSS')plt.savefig('RNN_LOSS.png')return hidden_prev,modelhidden_prev,model = tarin_RNN()
      
    • 步骤三:测试训练结果

      start = np.random.randint(3, size=1)[0]  ## 序列起点
      time_steps = np.linspace(start, start+10, num_time_steps)  ## 序列
      data = np.sin(time_steps).reshape(num_time_steps, 1)  ## 序列数据
      x = torch.tensor(data[:-1]).float().view(batch, seq_len-1, input_size)
      y = torch.tensor(data[1: ]).float().view(batch, seq_len-1, input_size)  # 目标为预测一个新的点    predictions = []  ## 预测结果
      input = x[:,0,:]
      for _ in range(x.shape[1]):input = input.view(1, 1, 1)pred, hidden_prev = model(input, hidden_prev)input = pred  ## 循环获得每个input点输入网络predictions.append(pred.detach().numpy()[0])
      x= x.data.numpy()
      y = y.data.numpy( )
      plt.scatter(time_steps[:-1], x.squeeze(), s=90)
      plt.plot(time_steps[:-1], x.squeeze())
      plt.scatter(time_steps[1:],predictions)  ## 黄色为预测
      plt.show()
      

  • B站视频参考资料

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.mzph.cn/web/31107.shtml

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

【Linux】关于在华为云中开放了端口后仍然无法访问的问题

已在安全组中添加规则: 通过指令: netstat -nltp | head -2 && netstat -nltp | grep 8080 运行结果: 可以看到服务器确实处于监听状态了. 通过指令 telnet 公网ip port 也提示: "正在连接xxx.xx.xx.xxx...无法打开到主机的连接。 在端口 8080: 连接失败"…

【漏洞复现】世邦通信 SPON IP网络对讲广播系统 addscenedata.php 任意文件上传漏洞

免责声明: 本文内容旨在提供有关特定漏洞或安全漏洞的信息,以帮助用户更好地了解可能存在的风险。公布此类信息的目的在于促进网络安全意识和技术进步,并非出于任何恶意目的。阅读者应该明白,在利用本文提到的漏洞信息或进行相关测…

C语言 | Leetcode C语言题解之第171题Excel表列序号

题目: 题解: int titleToNumber(char* columnTitle) {int number 0;long multiple 1;for (int i strlen(columnTitle) - 1; i > 0; i--) {int k columnTitle[i] - A 1;number k * multiple;multiple * 26;}return number; }

华为---静态路由-浮动静态路由及负载均衡(二)

7.2 浮动静态路由及负载均衡 7.2.1 原理概述 浮动静态路由(Floating Static Route)是一种特殊的静态路由,通过配置去往相同的目的网段,但优先级不同的静态路由,以保证在网络中优先级较高的路由,即主路由失效的情况下&#xff0c…

数据结构之“算法的时间复杂度和空间复杂度”

🌹个人主页🌹:喜欢草莓熊的bear 🌹专栏🌹:数据结构 目录 前言 一、算法效率 1.1算法的复杂度概念 1.2复杂度的重要性 二、时间复杂度 2.1时间复杂度的概念 2.2大O的渐进表示法 2.3常见的时间复杂度…

云计算【第一阶段(17)】账号和权限管理

目录 一、用户账号和组账号概述 1.1、用户账号的三种角色 1.2、组账号的两个角色 二、用户账号文件 2.1、/etc/passwd 2.2、/etc/shadow 2.3、chage 命令 三、组账号文件 3.1、/etc/group 3.2、/etc/gshadow 四、添加组账户 4.1、添加删除组成员 4.2、删除组账号 …

go 1.22 增强 http.ServerMux 路由能力

之前 server func main() {http.HandleFunc("/", func(w http.ResponseWriter, r *http.Request) {fmt.Println("Received request:", r.URL.Path)fmt.Fprintf(w, "Hello, client! You requested: %s\n", r.URL.Path)})log.Println("Serv…

Web3 学习

之前学习 web3,走了不少弯路,最近看到了 hackquest,重新刷了一遍以太坊基础,感觉非常nice,而且完全免费,有需要的可以试试,链接hackquest.io。

【Proteus仿真】【51单片机】基于物联网新能源电动车检测系统设计

文章目录 一、功能简介二、软件设计三、实验现象联系作者 一、功能简介 本项目使用Proteus8仿真51单片机控制器,使用LCD1602液晶显示模块、WIFI模块、蜂鸣器、LED按键、ADC PCF8591、DS18B20温度传感器等。 主要功能: 系统运行后,LCD1602显…

视频集市新增支持多格式流媒体拉流预览

流媒体除了常用实时流外还有大部分是以文件的形式存在,做融合预览必须要考虑多种兼容性能力,借用现有的ffmpeg生态可以迅速实现多种格式的支持,现在我们将按需拉流预览功能进行了拓展,正式支持了ffmpeg的功能,可快捷方…

初学51单片机之PWM实例呼吸灯以及遇到的问题(已解答)

PWM全名Pulse Width Modulation中文称呼脉冲宽度调制 如图 这是一个周期10ms、频率是100HZ的波形,但是每个周期内,高低电平宽度各不相同,这就是PWM的本质。 占空比是指高电平占整个周期的比列,上图第一个波形的占空比是40%,第二个…

Linux:多线程中的互斥与同步

多线程 线程互斥互斥锁互斥锁实现的原理封装原生线程库封装互斥锁 死锁避免死锁的四种方法 线程同步条件变量 线程互斥 在多线程中,如果存在有一个全局变量,那么这个全局变量会被所有执行流所共享。但是,资源共享就会存在一种问题&#xff1…

天才简史——Diederik P. Kingma与他的Adam优化器

一、了解Diederik P. Kingma 发生日期:2024年6月18日 前几日,与实验室同门一同前往七食堂吃饭。饭间,一位做随机优化的同门说他看过一篇被引18w的文章。随后,我表示不信,说你不会数错了吧,能有1.8w次被引都…

【人机交互 复习】第7章 可视化设计

一、窗口界面类型 1.多文档界面 (1)优点 a.节省系统资源 b.最小的可视集 c.协同工作区 d.多文档同时可视化 (2)缺点 a.菜单随活动文档窗口状态变化,导致不一致性 b.文档窗口必须在主窗口内部,减弱多文档显…

台积电(TSMC)正在探索采用新型先进芯片封装技术

台积电(TSMC)正在探索采用新型先进芯片封装技术,使用类似面板的矩形基板,以应对日益增长的先进多芯片组处理器需求。据日经亚洲报道,这项开发仍处于早期阶段,可能需要数年时间才能商业化,但如果…

Minecraft服务端配置教程

一、下载服务端核心文件 下载 | FastMirror 无极镜像 | 我的世界核心下载 Downloads for Minecraft Forge for MinecraftForge服务端下载 MCVersions.net - Minecraft Versions Download List原版 注意,这个网站可以下载Forge水桶等插件和模组端,如果…

STM32HAL库--定时器篇

STM32F429 有14个定时器,其中包括 2 个基本定时器(TIM6 和 TIM7)、 10 个通用定时器(TIM2~TIM5,TIM9~TIM14)、 2 个高级控制定时器(TIM1 和 TIM8)。 由上表知道:除了 TIM…

视频服务网关的特点

一、视频服务网关的介绍 视频服务网关采用Linux操作系统,可支持国内外不同品牌、不同协议、不同设备类型监控产品的统一接入管理,同时提供标准的H5播放接口供其他应用平台快速对接,让您快速拥有视频集成能力。不受开发环境、跨系统跨平台等条…

数据分析思考

数据分析工作流程 在我的数据分析职业发展过程中,我从基础的数据提取工作开始,逐步深入到更为复杂和具有战略意义的领域。这包括构建和完善指标体系、设计风险预警模型,以及与多部门协作完成公司整体经营分析等工作。 在这个过程中&#xf…

Rust中的数据抓取:代理和scraper的协同工作

一、数据抓取的基本概念 数据抓取,又称网络爬虫或网页爬虫,是一种自动从互联网上提取信息的程序。这些信息可以是文本、图片、音频、视频等,用于数据分析、市场研究或内容聚合。 为什么选择Rust进行数据抓取? 性能:…