经典网络 循环神经网络(一) | RNN结构解析,代码实现

文章目录

  • 1 提出背景
  • 2 RNN
    • 2.1 RNN结构
    • 2.2 RNN代码实现
    • 2.3 代码简洁实现

1 提出背景

为什么要引入RNN呢?

非常简单,之前我们的卷积神经网络CNN,全连接神经网络等都是单个神经元计算

但在序列模型中,前一个神经元往往对后面一个神经元有影响

比如

两句话

I like eating apples.

I want to have a apple watch

第一个苹果和第二个苹果的概念是不一样的,第一个苹果是红彤彤的苹果,第二个苹果是苹果公司的意思

如何知道

是因为apple的翻译参考了上下文,第一句话看到了eating这个单词,第二句话看到了watch这个单词

因而可见,对于语言这种时序信息,利用需要参考上下文进行

还有其他原因

  • 拿人类的某句话来说,也就是人类的自然语言,是不是符合某个逻辑或规则的字词拼凑排列起来的,这就是符合序列特性。
  • 语音,我们发出的声音,每一帧每一帧的衔接起来,才凑成了我们听到的话,这也具有序列特性、
  • 股票,随着时间的推移,会产生具有顺序的一系列数字,这些数字也是具有序列特性。

2 RNN

具有时序功能,从某种意义来说,RNN也就具有了记忆功能,好比我们人类自己,为什么会受到过去影响,因为我们具有记忆能力。

同时只有记忆能力是不够的,处理后的信息得储存起来,形成“新的记忆”

对于RNN,可以分为单向RNN,和双向RNN,其中单向的是只利用前面的信息,而双向的RNN既可以利用前面的信息,也可以利用后面的信息。

2.1 RNN结构

RNN的基本单元包含以下关键组件:

  • 输入 ( x t x_t xt ): 表示在时间步 (t) 的输入序列。
  • 隐藏状态 ( h t h_t ht ): 在时间步 (t) 的隐藏状态,是网络在处理序列过程中保留的信息相当于ht里面藏着上下文信息
  • 每一步的输出(Oi):每一个时间步有一个输出Oi,Oi综合了当前时间步和之前的很多信息,那么对于某些特定任务,如分类什么的,就可以直接用Oi去做判断。很多时候直接把隐藏状态拿去做了输出

如下图,图片来自《动手学深度学习》

在这里插入图片描述

那么每一个隐状态是通过怎样的方式得到的呢?

RNN的隐藏状态 (ht ) 的计算通过以下数学公式完成:

$h_t=tanh(W_{ih}x_t+b_{ih}+W_{hh}h_{t−1}+b_{hh}) $

这个公式展示了RNN如何根据当前输入 (xt ) 和前一个时间步的隐藏状态 (ht−1 ) 来计算当前时间步的隐藏状态 (ht )。其中 (tanh) 是双曲正切激活函数,用于引入非线性。

实际中我们可以看到

  • 权重矩阵 ($W_{ih} , W_{hh} $): 分别是输入到隐藏状态和隐藏状态到隐藏状态的权重矩阵。
  • 偏差 ($b_{ih} , b_{hh} $): 对应的偏差。

在这里插入图片描述

第一个问题 是每一个句子的长度不一致,你怎么用统一的矩阵呢?

​ 只实现了一个单层神经元,可以通过获得句子长度知道时间步数t,进一步做相关的调整

2.2 RNN代码实现

代码实现首先实现上图的一个神经元

def rnn(inputs, state, params):# inputs的形状:(时间步数量,批量大小,词表大小)W_xh, W_hh, b_h, W_hq, b_q = paramsH, = stateoutputs = []# X的形状:(批量大小,词表大小)for X in inputs:H = torch.tanh(torch.mm(X, W_xh) + torch.mm(H, W_hh) + b_h)Y = torch.mm(H, W_hq) + b_qoutputs.append(Y)return torch.cat(outputs, dim=0), (H,)class RNNModelScratch: #@save"""从零开始实现的循环神经网络模型"""def __init__(self, vocab_size, num_hiddens, device,get_params, init_state, forward_fn):self.vocab_size, self.num_hiddens = vocab_size, num_hiddensself.params = get_params(vocab_size, num_hiddens, device)self.init_state, self.forward_fn = init_state, forward_fndef __call__(self, X, state):X = F.one_hot(X.T, self.vocab_size).type(torch.float32)return self.forward_fn(X, state, self.params)def begin_state(self, batch_size, device):return self.init_state(batch_size, self.num_hiddens, device)

然后利用循环,根据语句长度做预测判断,损失函数计算优化

def predict_ch8(prefix, num_preds, net, vocab, device):  #@save"""在prefix后面生成新字符"""state = net.begin_state(batch_size=1, device=device)outputs = [vocab[prefix[0]]]get_input = lambda: torch.tensor([outputs[-1]], device=device).reshape((1, 1))for y in prefix[1:]:  # 预热期_, state = net(get_input(), state)outputs.append(vocab[y])for _ in range(num_preds):  # 预测num_preds步y, state = net(get_input(), state)outputs.append(int(y.argmax(dim=1).reshape(1)))return ''.join([vocab.idx_to_token[i] for i in outputs])

2.3 代码简洁实现

往往通过一个nn.RNN来实现

nn.RNN(input_size, hidden_size, num_layers=1, nonlinearity=tanh, bias=True, batch_first=False, dropout=0, bidirectional=False)

参数说明

input_size输入特征的维度, 一般rnn中输入的是词向量,那么 input_size 就等于一个词向量的维度
hidden_size隐藏层神经元个数,或者也叫输出的维度(因为rnn输出为各个时间步上的隐藏状态)
num_layers网络的层数,一般可以默认为1
nonlinearity激活函数
bias是否使用偏置
batch_first输入数据的形式,默认是 False,就是这样形式,(seq(num_step), batch, input_dim),也就是将序列长度放在第一位,batch 放在第二位
dropout是否应用dropout, 默认不使用,如若使用将其设置成一个0-1的数字即可
birdirectional是否使用双向的 rnn,默认是 False
注意某些参数的默认值在标题中已注明

rnn_layer = nn.RNN(input_size=vocab_size, hidden_size=num_hiddens, )

定义模型, 其中vocab_size = 1027, hidden_size = 256

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.mzph.cn/news/626759.shtml

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

为什么使用 atan2(sin(z), cos(z)) 进行角度归一化?

文章目录 为什么使用 atan2(sin(z), cos(z)) 进行归一化?为什么归一化后的角度等于原始角度? atan2 方法返回 -π 到 π 之间的值,代表点 (x, y) 相对于正X轴的偏移角度。这个角度是逆时针测量的,以弧度为单位。关于 atan2 函数为…

YOLOv5姿态估计:HRnet实时检测人体关键点

前言: Hello大家好,我是Dream。 今天来学习一下利用YOLOv5进行姿态估计,HRnet与SimDR检测图片、视频以及摄像头中的人体关键点,欢迎大家一起前来探讨学习~ 本文目录: 一、项目准备1Pycharm中克隆github上的项目2.具体步…

【Linux实用篇】Linux软件安装 JDK Tomcat MySQL lrzsz

1. 软件安装 1.1 软件安装方式 在Linux系统中,安装软件的方式主要有四种,这四种安装方式的特点如下: 安装方式特点二进制发布包安装软件已经针对具体平台编译打包发布,只要解压,修改配置即可rpm安装软件已经按照red…

微信好友批量自动添加:快捷方式解密

对于一些希望扩大社交圈子或者推广业务的人来说,手动添加好友可能是一个耗时且繁琐的任务。 不过,别担心,今天给大家种草一个能够批量自动添加好友的微信管理工具,让你轻松地扩展好友列表。 首先,当微信在个微管理系…

Python数据分析案例31——中国A股的月份效应研究(方差分析,虚拟变量回归)

案例背景 本次案例是博主本科在行为金融学课程上做的一个小项目,最近看很多经管类的学生作业都很需要,我就用python来重新做了一遍。不弄那些复杂的机器学习模型了,经管类同学就用简单的统计学方法来做模型就好。 研究目的 有效市场假说是现…

VUE项目快速打包发布

VUE项目快速打包发布 首先在你的VS Code中新建一个终端 输入 npm run build 回车等运行结束之后会在你的项目中生成一个dist目录 此时再iis部署的时候把你添加的网站指定的目录指向dist即可

STM32CubeMX配置STM32G071UART+DMA收发数据(HAL库开发)

时钟配置HSI主频配置64M 配置好串口&#xff0c;选择异步模式 配置DMA TX,RX,选择循环模式。 NVIC中勾选使能中断 勾选生成独立的.c和h文件 配置好需要的开发环境并获取代码 串口重定向勾选Use Micro LIB main.c文件修改 增加头文件和串口重定向 #include <string.h&g…

spring常见漏洞(3)

CVE-2017-8046 Spring-Data-REST-RCE(CVE-2017-8046)&#xff0c;Spring Data REST对PATCH方法处理不当&#xff0c;导致攻击者能够利用JSON数据造成RCE。本质还是因为spring的SPEL解析导致的RCE 影响版本 Spring Data REST versions < 2.5.12, 2.6.7, 3.0 RC3 Spring Bo…

光学雨量监测站比传统雨量站有哪些优势

光学雨量监测站相比传统雨量站具有许多优势。首先&#xff0c;光学雨量监测站采用光学原理进行雨量监测&#xff0c;而传统雨量站则依靠传感器和机械部件进行测量。光学雨量监测站的结构相对简单&#xff0c;不需要频繁维护和校准&#xff0c;减少了运维成本和工作量。 其次&am…

【Emgu CV教程】5.1、几何变换之平移

图像的几何变换对于图像处理来说&#xff0c;也是最基础的那一档次&#xff0c;包括平移、旋转、缩放、透视变换等等&#xff0c;也就是对图像整理形状的改变&#xff0c;用到的函数都比较简单&#xff0c;理解起来也很容易。但是为了凑字数&#xff0c;还是一个函数一个函数的…

如何通过企业司法协助信息API识别潜在的不良合作伙伴

引言 在商业合作中&#xff0c;合作伙伴的信誉和合规性是至关重要的。然而&#xff0c;在选择合作伙伴时&#xff0c;我们往往面临信息不对称的问题。如何有效地识别潜在的不良合作伙伴&#xff0c;避免潜在的风险呢&#xff1f;通过企业司法协助信息API&#xff0c;我们可以快…

mysql简单操作集成数据模型使用方法

查看表信息&#xff0c;其中包括字段信息以及创表信息 DESCRIBE asset; show COLUMNS FROM asset; SHOW CREATE TABLE asset; 常规操作表 --查询 select * FROM device_template --插入 INSERT into asset_package (protocol,project_code,lease_id,station_name,device_id,…

Qt点击按钮在附近弹出下拉框

效果 MainWindow.h #ifndef MAINWINDOW_H #define MAINWINDOW_H#include <QMainWindow> #include"toollayout.h" QT_BEGIN_NAMESPACE namespace Ui { class MainWindow; } QT_END_NAMESPACEclass MainWindow : public QMainWindow {Q_OBJECTpublic:MainWindow…

vue3+vite项目构建时报错npm ERR! code EPERMnpm ERR! syscall mkdir...

vscode终端中输入npm create vitelatest vueviteproject1 -- --vue命令后报错 具体报错如下&#xff1a; PS D:\project> npm create vitelatest vueviteproject1 -- --vue >> npm ERR! code EPERM npm ERR! syscall mkdir npm ERR! path D:\node\node_cache\_cac…

C语言:自定义类型——联合和枚举

一、联合体 1.1 联合体类型的声明 像结构体⼀样&#xff0c;联合体也是由⼀个或者多个成员构成&#xff0c;这些成员可以是不同的类型。 声明方式如下图&#xff1a; 那联合体和结构体究竟有什么区别呢&#xff1f;&#xff1f; 下面将重点讲解联合体的特点&#xff01;&am…

判断交叉编译工具是否支持C++20的标准

写个任意的测试程序hello_world 执行 arm-linux-gnueabihf-g -stdc14 main.cpp arm-linux-gnueabihf-g -stdc17 main.cpp arm-linux-gnueabihf-g -stdc20 main.cpp没报错则代表支持&#xff0c;报错则不支持.

数字图像处理常用算法的原理和代码实现详解

本专栏详细地分析了常用图像处理算法的数学原理、实现步骤。配有matlab或C实现代码&#xff0c;并对代码进行了详细的注释。最后&#xff0c;对算法的效果进行了测试。相信通过这个专栏&#xff0c;你可以对这些算法的原理及实现有深入的理解&#xff01;   如有疑问&#xf…

公司想做一套数字化管理系统,该怎么做?

引言 一个老板的心声&#xff1a;随着科技的迅猛发展&#xff0c;公司数字化已经成为提升业务竞争力不可或缺的关键因素。在这个数字时代&#xff0c;我们公司旨在顺应潮流&#xff0c;迎接挑战&#xff0c;以构建一套强大而高效的数字化管理系统为目标。通过本系统&#xff0…

8路DI高速计数器,8路DO支持PWM输出,Modbus TCP模块 YL93 开关量输入输出

特点&#xff1a; ● 8路开关量输入&#xff0c;8路开关量输出 ● DI每一路都可用作计数器或者频率测量 ● DO每一路都可独立输出PWM信号 ● DI和DO都支持PNP&#xff0c;NPN切换功能 ● 支持Modbus TCP 通讯协议 ● 内置网页功能&#xff0c;可以通过网页查询电平状…

安达发|APS计划排产排程排单软件功能解析

APS计划排产排程排单软件是一种用于生产计划和排程的工具&#xff0c;它可以帮助制造企业实现高效、准确的生产计划和排程。该软件具有多种功能&#xff0c;包括可视化人机互动排产、一键式全自动排产、设备产能约束排产、模具约束排产、人员约束排产、半成品约束排产、物料齐套…