使用torch实现RNN

在实验室的项目遇到了困难,弄不明白LSTM的原理。到网上搜索,发现LSTM是RNN的变种,那就从RNN开始学吧。

带隐藏状态的RNN可以用下面两个公式来表示:

可以看出,一个RNN的参数有W_xh,W_hh,b_h,W_hq,b_q和H(t)。其中H(t)是步数的函数。

参考的文章考虑了这样一个问题,对于x轴上的一列点,有一列sin值,我们想知道它对应的cos值,但是即使sin值相同,cos值也不同,因为输出结果不仅依赖于当前的输入值sinx,还依赖于之前的sin值。这时候可以用RNN来解决问题

用到的核心函数:torch.nn.RNN() 参数如下:

  • input_size – 输入x的特征数量。
  • hidden_size – 隐藏层的特征数量。
  • num_layers – RNN的层数。
  • nonlinearity – 指定非线性函数使用tanh还是relu。默认是tanh
  • bias – 如果是False,那么RNN层就不会使用偏置权重 bihbih和bhhbhh,默认是True
  • batch_first – 如果True的话,那么输入Tensor的shape应该是[batch_size, time_step, feature],输出也是这样。
  • dropout – 如果值非零,那么除了最后一层外,其它层的输出都会套上一个dropout层。
  • bidirectional – 如果True,将会变成一个双向RNN,默认为False

下面是代码:

 1 # encoding:utf-82 import torch3 import numpy as np4 import matplotlib.pyplot as plt  # 导入作图相关的包5 from torch import nn6 7 8 # 定义RNN模型9 class Rnn(nn.Module):
10     def __init__(self, INPUT_SIZE):
11         super(Rnn, self).__init__()
12 
13         # 定义RNN网络,输入单个数字.隐藏层size为[feature, hidden_size]
14         self.rnn = nn.RNN(
15                 input_size=INPUT_SIZE,
16                 hidden_size=32,
17                 num_layers=1,
18                 batch_first=True  # 注意这里用了batch_first=True 所以输入形状为[batch_size, time_step, feature]
19                 )
20         # 定义一个全连接层,本质上是令RNN网络得以输出
21         self.out = nn.Linear(32, 1)
22 
23     # 定义前向传播函数
24     def forward(self, x, h_state):
25         # 给定一个序列x,每个x.size=[batch_size, feature].同时给定一个h_state初始状态,RNN网络输出结果并同时给出隐藏层输出
26         r_out, h_state = self.rnn(x, h_state)
27         outs = []
28         for time in range(r_out.size(1)):  # r_out.size=[1,10,32]即将一个长度为10的序列的每个元素都映射到隐藏层上.
29             outs.append(self.out(r_out[:, time, :]))  # 依次抽取序列中每个单词,将之通过全连接层并输出.r_out[:, 0, :].size()=[1,32] -> [1,1]
30         return torch.stack(outs, dim=1), h_state  # stack函数在dim=1上叠加:10*[1,1] -> [1,10,1] 同时h_state已经被更新
31 
32 
33 TIME_STEP = 10
34 INPUT_SIZE = 1
35 LR = 0.02
36 
37 model = Rnn(INPUT_SIZE)
38 print(model)
39 
40 loss_func = nn.MSELoss()  # 使用均方误差函数
41 optimizer = torch.optim.Adam(model.parameters(), lr=LR)  # 使用Adam算法来优化Rnn的参数,包括一个nn.RNN层和nn.Linear层
42 
43 h_state = None  # 初始化h_state为None
44 
45 for step in range(300):
46     # 人工生成输入和输出,输入x.size=[1,10,1],输出y.size=[1,10,1]
47     start, end = step * np.pi, (step + 1)*np.pi
48 
49     steps = np.linspace(start, end, TIME_STEP, dtype=np.float32)
50     x_np = np.sin(steps)
51     y_np = np.cos(steps)
52 
53     x = torch.from_numpy(x_np[np.newaxis, :, np.newaxis])
54     y = torch.from_numpy(y_np[np.newaxis, :, np.newaxis])
55 
56     # 将x通过网络,长度为10的序列通过网络得到最终隐藏层状态h_state和长度为10的输出prediction:[1,10,1]
57     prediction, h_state = model(x, h_state)
58     h_state = h_state.data  # 这一步只取了h_state.data.因为h_state包含.data和.grad 舍弃了梯度
59     # 反向传播
60     loss = loss_func(prediction, y)
61     optimizer.zero_grad()
62     loss.backward()
63 
64     # 优化网络参数具体应指W_xh, W_hh, b_h.以及W_hq, b_q
65     optimizer.step()
66 
67 # 对最后一次的结果作图查看网络的预测效果
68 plt.plot(steps, y_np.flatten(), 'r-')
69 plt.plot(steps, prediction.data.numpy().flatten(), 'b-')
70 plt.show()

最后一步预测和实际y的结果作图如下:

可看出,训练RNN网络之后,对网络输入一个序列sinx,能正确输出对应的序列cosx

在线教程

  • 麻省理工学院人工智能视频教程 – 麻省理工人工智能课程
  • 人工智能入门 – 人工智能基础学习。Peter Norvig举办的课程
  • EdX 人工智能 – 此课程讲授人工智能计算机系统设计的基本概念和技术。
  • 人工智能中的计划 – 计划是人工智能系统的基础部分之一。在这个课程中,你将会学习到让机器人执行一系列动作所需要的基本算法。
  • 机器人人工智能 – 这个课程将会教授你实现人工智能的基本方法,包括:概率推算,计划和搜索,本地化,跟踪和控制,全部都是围绕有关机器人设计。
  • 机器学习 – 有指导和无指导情况下的基本机器学习算法
  • 机器学习中的神经网络 – 智能神经网络上的算法和实践经验
  • 斯坦福统计学习

请添加图片描述

人工智能书籍

  • OpenCV(中文版).(布拉德斯基等)
  • OpenCV+3计算机视觉++Python语言实现+第二版
  • OpenCV3编程入门 毛星云编著
  • 数字图像处理_第三版
  • 人工智能:一种现代的方法
  • 深度学习面试宝典
  • 深度学习之PyTorch物体检测实战
  • 吴恩达DeepLearning.ai中文版笔记
  • 计算机视觉中的多视图几何
  • PyTorch-官方推荐教程-英文版
  • 《神经网络与深度学习》(邱锡鹏-20191121)

  • 在这里插入图片描述

第一阶段:零基础入门(3-6个月)

新手应首先通过少而精的学习,看到全景图,建立大局观。 通过完成小实验,建立信心,才能避免“从入门到放弃”的尴尬。因此,第一阶段只推荐4本最必要的书(而且这些书到了第二、三阶段也能继续用),入门以后,在后续学习中再“哪里不会补哪里”即可。

第二阶段:基础进阶(3-6个月)

熟读《机器学习算法的数学解析与Python实现》并动手实践后,你已经对机器学习有了基本的了解,不再是小白了。这时可以开始触类旁通,学习热门技术,加强实践水平。在深入学习的同时,也可以探索自己感兴趣的方向,为求职面试打好基础。

第三阶段:工作应用

这一阶段你已经不再需要引导,只需要一些推荐书目。如果你从入门时就确认了未来的工作方向,可以在第二阶段就提前阅读相关入门书籍(对应“商业落地五大方向”中的前两本),然后再“哪里不会补哪里”。

在这里插入图片描述

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.mzph.cn/news/641440.shtml

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

[AutoSar]BSW_OS 06 Autosar OS_Alarms

一、 目录 一、关键词平台说明一、Timer1.1 配置1.2Periodical Interrupt Timer (PIT)和High Resolution Timer (HRT) 二、Alarm 工作机制三、Code3.1创建一个15ms的runnable3.2mapping到basic task3.3生成代码 关键词 嵌入式、C语言、autosar、OS、BSW 平台说明 项目ValueO…

k8s的helm

1、在没有helm之前,部署deployment、service、ingress等等 2、helm的作用:通过打包的方式,deployment、service、ingress这些打包在一块,一键部署服务、类似于yum功能 3、helm:官方提供的一种类似于仓库的功能&#…

时间轮设计

目录 基本概念 函数定义 函数实现与测试 测试1结果如下 测试2结果如下 基本概念 时间轮 是一种 实现延迟功能(定时器) 的 巧妙算法。如果一个系统存在大量的任务调度,时间轮可以高效的利用线程资源来进行批量化调度。把大批量的调度任务…

React16源码: React中的resetChildExpirationTime的源码实现

resetChildExpirationTime 1 )概述 在 completeUnitOfWork 当中,有一步比较重要的一个操作,就是重置 childExpirationTimechildExpirationTime 是非常重要的一个时间节点,它用来记录某一个节点的子树当中,目前优先级最…

C++提高编程——STL:string容器、vector容器

本专栏记录C学习过程包括C基础以及数据结构和算法,其中第一部分计划时间一个月,主要跟着黑马视频教程,学习路线如下,不定时更新,欢迎关注。 当前章节处于: ---------第1阶段-C基础入门 ---------第2阶段实战…

数据结构:堆与堆排序

目录 堆的定义: 堆的实现: 堆的元素插入: 堆元素删除: 堆初始化与销毁: 堆排序: 堆的定义: 堆是一种完全二叉树,完全二叉树定义如下: 一棵深度为k的有n个结点的二…

ffmpeg使用及java操作

1.文档 官网: FFmpeg 官方使用文档: ffmpeg Documentation 中文简介: https://www.cnblogs.com/leisure_chn/p/10297002.html 函数及时间: ffmpeg日记1011-过滤器-语法高阶,逻辑,函数使用_ffmpeg gte(t,2)-CSDN博客 java集成ffmpeg: SpringBoot集成f…

科技云报道:金融大模型落地,还需跨越几重山?

科技云报道原创。 时至今日,大模型的狂欢盛宴仍在持续,而金融行业得益于数据密集且有强劲的数字化基础,从一众场景中脱颖而出。 越来越多的公司开始布局金融行业大模型,无论是乐信、奇富科技、度小满、蚂蚁这样的金融科技公司&a…

深度学习如何弄懂那些难懂的数学公式?是否需要学习数学?

经过1~2年的学习,我觉得还是需要数学有一定认识,重新捡起高等数学、概率与数理、线代等这几本,起码基本微分方程、求导、对数、最小损失等等还是会用到。 下面给出几个链接,可以用于平时充电学习。 知乎上的: 机器学…

计算机毕业设计 基于SpringBoot的律师事务所案件管理系统的设计与实现 Java实战项目 附源码+文档+视频讲解

博主介绍:✌从事软件开发10年之余,专注于Java技术领域、Python人工智能及数据挖掘、小程序项目开发和Android项目开发等。CSDN、掘金、华为云、InfoQ、阿里云等平台优质作者✌ 🍅文末获取源码联系🍅 👇🏻 精…

git merge和git rebase区别

具体详情 具体常见如下,假设有master和change分支,从同一个节点分裂,随后各自进行了两次提交commit以及修改。随后即为change想合并到master分支中,但是直接git commit和git push是不成功的,因为分支冲突了【master以…

上位机图像处理和嵌入式模块部署(流程)

【 声明:版权所有,欢迎转载,请勿用于商业用途。 联系信箱:feixiaoxing 163.com】 前面我们说过,传统图像处理的方法,一般就是pccamera的处理方式。camera本身只是提供基本的raw data数据,所有的…

基于ADAS的车道线检测算法matlab仿真

目录 1.课题概述 2.系统仿真结果 3.核心程序与模型 4.系统原理简介 4.1 图像预处理 4.2 车道线特征提取 4.3 车道线跟踪 5.完整工程文件 1.课题概述 基于ADAS的车道线检测算法,通过hough变换和边缘检测方法提取视频样板中的车道线,然后根据车道线的弯曲情况…

Linux/Mac 命令行工具 tree 开发项目结构可以不用截图了 更方便 更清晰 更全

tree 是一个命令行工具,用于以树形结构显示文件系统目录的内容。它可用于列出指定目录下的所有文件和子目录,以及它们的层次关系。tree 命令在许多操作系统中都可用,包括Unix、Linux和macOS。 效果如下: 一、安装 linux # De…

Prometheus+Grafana监控Mysql数据库

Promethues Prometheus https://prometheus.io Prometheus是一个开源的服务监控系统,它负责采集和存储应用的监控指标数据,并以可视化的方式进行展示,以便于用户实时掌握系统的运行情况,并对异常进行检测。因此,如何…

Spring Boot3整合knife4j(swagger3)

目录 1.前置条件 2.导依赖 3.配置 1.前置条件 已经初始化好一个spring boot项目且版本为3X,项目可正常启动。 作者版本为3.2.2最新版 2.导依赖 knife4j官网: Knife4j 集Swagger2及OpenAPI3为一体的增强解决方案. | Knife4j (xiaominfo.com)http…

R语言简介

1.R语言 R语言是一种数学编程语言,主要用于统计分析、绘图和数据挖掘。 2.R语言特点 免费、开源,兼容性好(Windows、MacOS或Linux)。具有多种数据类型,如向量、矩阵、因子、数据集等常用数据结构。多用于交互式数据分析&#x…

股权众筹模式介绍(下)

3、线上线下两段式投资 对于已经成成立并运营的企业来说,由于《证券法》明确规定,向“不特定对象发行证券”以及“向特定对象发行证券累计超过200人”的行为属于公开发行证券,必须通过证监会核准,由证券公司承销。这些规定限定了…

RTDETR 引入 UniRepLKNet:用于音频、视频、点云、时间序列和图像识别的通用感知大卷积神经网络 | DRepConv

大卷积神经网络(ConvNets)近来受到了广泛研究关注,但存在两个未解决且需要进一步研究的关键问题。1)现有大卷积神经网络的架构主要遵循传统ConvNets或变压器的设计原则,而针对大卷积神经网络的架构设计仍未得到解决。2)随着变压器在多个领域的主导地位,有待研究ConvNets…

小程序商城 免 费 搭 建之java商城 电子商务Spring Cloud+Spring Boot+二次开发+mybatis+MQ+VR全景+b2b2c

java SpringCloud版本b2b2c鸿鹄云商平台全套解决方案 使用技术: Spring CloudSpring BootMybatis微服务服务监控可视化运营 B2B2C平台: 平台管理端(包含自营) 商家平台端(多商户入驻) PC买家端、手机wap/公众号买家端 微服务(30个通用…