带你学会深度学习之循环神经网络[RNN] - 2

前言

笔者写下此系列文章是希望在复习人工智能相关知识同时为想学此技术的人提供一定帮助。

图源网络,所有者可随时联系笔者删除。

代码不代表全部实现,只是为展示模型的关键结构。

与CNN不同,RNN被设计用来处理序列数据。它通过在网络的隐藏层中引入循环,使网络能够保留前一个状态的信息,并将这些信息用于当前状态的计算。这种设计使RNN特别适合处理语言翻译、自然语言处理、语音识别等需要理解数据序列中时间相关性的任务。

正文

RNN的多层结构

我们知道CNN模型的多层结构大致是这样的

代码上看,对于如下的两层卷积一层全连接的简单CNN模型定义如下: 

def __init__(self):super(SimpleCNN, self).__init__()# 第一个卷积层, 输入通道数1, 输出通道数16, 卷积核大小3x3self.conv1 = nn.Conv2d(1, 16, kernel_size=3, stride=1, padding=1)# 第二个卷积层, 输入通道数16, 输出通道数32, 卷积核大小3x3self.conv2 = nn.Conv2d(16, 32, kernel_size=3, stride=1, padding=1)# 最大池化层, 使用2x2窗口进行下采样self.pool = nn.MaxPool2d(kernel_size=2, stride=2)# 全连接层, 32 * 7 * 7 输入特征数, 10输出特征数 (10类分类,28 X 28)self.fc = nn.Linear(32 * 7 * 7, 10)

前向传播时,我们只需要如下操作(Pytorch,TF会完成后向传播)

def forward(self, x):# 通过第一个卷积层后激活x = F.relu(self.conv1(x))# 通过池化层x = self.pool(x)# 通过第二个卷积层后激活x = F.relu(self.conv2(x))# 通过池化层x = self.pool(x)# 展平特征图以供全连接层使用x = x.view(-1, 32 * 7 * 7)# 通过全连接层x = self.fc(x)return x

符合直觉,再来看看RNN我们会怎么做。

def __init__(self, input_size, hidden_size, num_layers, output_size):super(SimpleRNN, self).__init__()self.hidden_size = hidden_sizeself.num_layers = num_layers# RNN 层self.rnn = nn.RNN(input_size, hidden_size, num_layers, batch_first=True)# 输出层self.fc = nn.Linear(hidden_size, output_size)

nn.RNN 是 PyTorch 中的 RNN 层,我们设置它的输入尺寸、隐藏状态尺寸、层数,以及是否批量处理输入数据。

def forward(self, x):# 初始化隐藏状态h0 = torch.zeros(self.num_layers, x.size(0), self.hidden_size)# 前向传播RNNout, _ = self.rnn(x, h0)# 取最后一时间步out = out[:, -1, :]out = self.fc(out)return out

全连接层 (self.fc) 用于 RNN 的最后一个时间步的隐藏状态转换为输出。

最后一个时间步的隐藏状态是序列中的最后一个元素或时刻,每一个时间步都有其对应的数据输入。在循环神经网络中,网络会逐个时间步地处理整个序列,为每一个时间步生成一个输出。

对于很多任务来说,通常只需要序列的最终状态,因为在循环神经网络中,每个时间步的隐藏状态包含了到目前为止的序列信息,因此最后一个时间步的隐藏状态上包含了整个序列的信息。

从架构图上看,它是这样来堆叠多层的。

RNN 变种之GRU门控循环单元

GRU的提出主要是为了解决RNN在长距离依赖问题上的挑战,这个问题也被称为梯度消失或梯度爆炸问题,具体的说,如下为RNN隐藏状态梯度的通项公式,当时间步数T较大或者时间步t较小,将产生梯度衰减和梯度爆炸问题,不利于训练模型。

为了解决这些问题,GRU引入了两个关键的门控机制:更新门(update gate)和重置门(reset gate)。

  • 更新门用于控制前一个状态到当前状态的信息量。它决定了有多少之前的信息需要保留,以及有多少新的信息需要加入。通过这种方式,GRU能够在必要时保留长期信息,从而有效地缓解梯度消失问题。

  • zt​ 是时刻 t 的更新门,用于决定保留多少旧状态信息。
  • Wz​ 是更新门的权重参数。
  • σ 表示sigmoid函数,将输入压缩到0和1之间,以便作为门控信号。
  • [ht−1​,xt​] 表示ht−1​和xt​ 的连接。 
  • 重置门则控制了多少之前的信息需要忘记,这使得模型能够根据新的输入丢弃不相关的状态信息。这有助于模型更好地处理输入序列中的变化,使其对于序列中的重要事件更加敏感。

  • rt​ 是时刻 t 的重置门,用于决定有多少过去的信息需要被忘记。
  • Wr​ 是重置门的权重参数。
  • [ht−1​,xt​] 表示ht−1​和xt​ 的连接。

候选隐藏状态(Candidate Hidden State)

候选隐藏状态提供了一种可能的新状态,其基于当前的输入 xt​ 和通过重置门调整过的前一时刻的隐藏状态rt​∗ht−1​得出,候选隐藏状态的计算考虑了当前输入和经重置门修改后的上一时刻隐藏状态。如果重置门接近0,那么旧的信息会被忽略,候选状态几乎完全基于当前的输入,允许模型在需要的时候快速忘记无关的过去信息。

  • h~t​ 是时刻 t 的候选隐藏状态,它包含了在当前时刻可能需要加入到真正隐藏状态中的新信息。
  • W 是候选隐藏状态的权重参数。
  • rt​∗ht−1​ 表示重置门控制后的隐藏状态,其中的 ∗ 表示逐元素乘法。
  • tanh⁡tanh 函数帮助将数据压缩到 −1 和 1之间,帮助控制梯度流。

最终隐藏状态(Final Hidden State)

最终隐藏状态是根据更新门的输出决定的,它决定了从上一时刻隐藏状态 ht−1​ 到当前时刻 t 的隐藏状态 ht​ 应该保留多少信息,以及候选隐藏状态 h~t​ 应该贡献多少新信息,通过更新门,GRU在候选隐藏状态(代表了新信息)和前一隐藏状态(代表了旧信息)之间做权衡。如果更新门值接近1,意味着保留更多旧信息;如果接近0,意味着采纳更多新信息。

  • ht​ 是时刻 t 的最终隐藏状态,它通过更新门 zt​ 来融合之前的隐藏状态 ht−1​ 和当前的候选隐藏状态 h~t​。
  • 这一步骤使GRU能够在保留长期依赖信息的同时,也加入新的信息,解决了传统RNN中的梯度消失问题。

由于现在框架发展成熟,调用只需如下,分别提供了两种主流框架代码:

import tensorflow as tfmodel = tf.keras.Sequential([# 输入数据维度是 (None, 10, 64)# None代表批次大小,10代表序列长度,64代表每个时间步的特征维度tf.keras.layers.GRU(256, return_sequences=True, input_shape=(10, 64)),tf.keras.layers.GRU(128),tf.keras.layers.Dense(10, activation='softmax')
])model.compile(optimizer='adam', loss='categorical_crossentropy', metrics=['accuracy'])
model.summary()

两个GRU层来处理序列数据,其中第一个GRU层返回整个序列的输出供下一个GRU层使用,Dense层分类。

import torch
import torch.nn as nnclass GRUNet(nn.Module):def __init__(self):super(GRUNet, self).__init__()self.gru1 = nn.GRU(input_size=64, hidden_size=256, num_layers=1, batch_first=True)self.gru2 = nn.GRU(input_size=256, hidden_size=128, num_layers=1, batch_first=True)self.fc = nn.Linear(128, 10)def forward(self, x):# x shape [batch_size, sequence_length, feature_size]out, _ = self.gru1(x)out, _ = self.gru2(out)# 取序列的最后一个时间步out = out[:, -1, :]out = self.fc(out)return outmodel = GRUNet()
print(model)

两个GRU层和一个全连接层,全连接层10分类。

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.mzph.cn/news/769166.shtml

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

3GPP 协议资料学习和文档下载

一、登录3GPP官网 3GPP – The Mobile Broadband Standard 二、选择Specifications Per TSG Round 三、选择ftp下载路径 四、选择不同阶段的3GPP协议 包含了从1999年到R18,甚至更新到当前最新的协议。 五、查看对应版本的LTE或者5G NR协议 其中LTE射频相关章节为36.521系列&…

小目标检测篇 | YOLOv8改进之增加小目标检测层(针对Neck网络为AFPN)

前言:Hello大家好,我是小哥谈。小目标检测是计算机视觉领域中的一个研究方向,旨在从图像或视频中准确地检测和定位尺寸较小的目标物体。相比于常规目标检测任务,小目标检测更具挑战性,因为小目标通常具有低分辨率、低对比度和模糊等特点,容易被背景干扰或遮挡。本篇文章就…

MP4如何把视频转MOV格式? MP4视频转MOV格式的技巧

在现代的数字媒体时代,视频格式转换成为了许多用户必须掌握的技能。特别是将MP4视频转换为MOV格式,这对于需要在Apple设备上播放或编辑视频的用户来说尤为重要。本文将详细介绍如何将MP4视频转换为MOV格式,帮助读者轻松应对不同设备和平台的需…

三端可调正稳压器集成电路D317——输出电压范围是1.2V至37V,负载电流最大为1.5A

D317大电流可调稳压电路 1、 概述: D317是一款三端可调正稳压器集成电路,其输出电压范围是1.2V至37V,负载电流最大为1.5A。它的使用非常简单,仅需两个外接电阻来设置输出电压。此外,它的电压线性度和负载调整率也比标准…

使用Python制作一个批量查询搜索排名的SEO免费工具

搭建背景 最近工作中需要用上 Google SEO(搜索引擎优化),有了解过的朋友们应该都知道SEO必不可少的工作之一就是查询关键词的搜索排名。关键词少的时候可以一个一个去查没什么问题,但是到了后期,一个网站都有几百上千…

浏览器工作原理与实践--渲染流程(上):HTML、CSS和JavaScript,是如何变成页面的

在上一篇文章中我们介绍了导航相关的流程,那导航被提交后又会怎么样呢?就进入了渲染阶段。这个阶段很重要,了解其相关流程能让你“看透”页面是如何工作的,有了这些知识,你可以解决一系列相关的问题,比如能…

获取第三方小程序指定页面的path

获取第三方小程序指定页面的path wx.navigateToMiniProgramappIdpathwx.navigateToMiniProgram 在开发小程序时需要跳转到第三方小程序指定页面时,需通过wx.navigateToMiniProgram方法完成。其中有两个主要参数appId和path,文本以问卷星为例,分享两者获取方法。 appId 在…

使用Python批量实现文件夹下所有Excel文件的第二张表合并

目录 一、前言 二、准备工作 三、实现步骤 遍历文件夹获取所有Excel文件 读取每个Excel文件的第二张表 合并所有表格 主函数 四、案例实践 五、注意事项 六、扩展与改进 七、总结 在数据处理和分析中,经常需要对多个Excel文件进行批量操作,特…

代码随想录阅读笔记-栈与队列【滑动窗口最大值】

题目 给定一个数组 nums,有一个大小为 k 的滑动窗口从数组的最左侧移动到数组的最右侧。你只可以看到在滑动窗口内的 k 个数字。滑动窗口每次只向右移动一位。 返回滑动窗口中的最大值。 进阶: 你能在线性时间复杂度内解决此题吗? 提示&am…

如何提升FFmpeg 1‰的转码性能

在8K视频编解码特别是解码部分,我做了一些优化工作,转码速度提升了50%以上。专家们评价曰:“主要围绕算法并行度的优化,属于算法性能优化的常规手段,在创新性和技术难度方面的体现较为一般”。评价过于犀利&#xff0c…

一文道破将bean注入到Spring中的几种方式

前言: 前两天有学妹问我如何将bean注入到Spring中,虽问题较简单,但还是写此文以告之。 在Java的Spring框架中,将bean注入到容器中是核心概念之一,这是实现依赖注入的基础。Spring提供了多种方式来将bean注入到容器中…

MySQL高可用解决方案――从主从复制到InnoDB Cluster架构

2024送书福利正式起航 关注「哪吒编程」,提升Java技能 文末送5本《MySQL高可用解决方案――从主从复制到InnoDB Cluster架构》 大家好,我是哪吒。 爱奇艺每天都为数以亿计的用户提供7x24小时不间断的视频服务。通过爱奇艺的平台,用户可以…

力扣:290. 单词规律

前言:剑指offer刷题系列 问题: 给定一种规律 pattern 和一个字符串 s ,判断 s 是否遵循相同的规律。 这里的 遵循 指完全匹配,例如, pattern 里的每个字母和字符串 s 中的每个非空单词之间存在着双向连接的对应规律…

docker推拉时的数据交换详解

前言 docker用了这么久了, 有没有想过, 在执行docker push 和 docker pull命令的时候, 数据是如何传递的呢? 换句话说, 如果要实现一个镜像仓库, 针对推拉的服务, 如何实现接口呢? 根据OCI 分发规范文档 的描述, 已经对整个推拉过程中要调用的接口有描述了. 但是, 纸上学来…

CNN、Transformer、Uniformer之外,我们终于有了更高效的视频理解技术

ChatGPT狂飙160天,世界已经不是之前的样子。 新建了人工智能中文站https://ai.weoknow.com 每天给大家更新可用的国内可用chatGPT资源 发布在https://it.weoknow.com 更多资源欢迎关注 视频理解因大量时空冗余和复杂时空依赖,同时克服两个问题难度巨大…

力扣每日一题 2024/3/24 零钱兑换

题目描述 用例说明 思路讲解 动态规划五步法 第一步确定dp数组的含义:dp[i]为凑到金额为i所用最少的硬币数量 第二步确定动态规划方程:凑足金额为j-coins[i]所需最少的硬币个数为dp[j-coins[i]],那凑足金额为j所用的最少硬币数为dp[j-coin…

怎么将文件快速生成二维码?文件二维码的在线生成技巧

现在越来越多的人都开始通过二维码的方式来传递文件,将word、pdf、excel、pdf等格式的文件通过扫码的方式展示或者下载文件,这种方式有很多的优势,包括传播速度快成本低,只需要生成一张二维码图片,就可以让其他人能够同…

Prompt-RAG:在特定领域中应用的革新性无需向量嵌入的RAG技术

论文地址:https://arxiv.org/ftp/arxiv/papers/2401/2401.11246.pdf 原文地址:https://cobusgreyling.medium.com/prompt-rag-98288fb38190 2024 年 3 月 21 日 虽然 Prompt-RAG 确实有其局限性,但在特定情况下它可以有效地替代传统向量嵌入 …

QTableWidget删除单元格

如果单元格内有内容&#xff0c;可以使用函数selectedItems() 获取有内容行的一个链表 QList<QTableWidgetItem *> items ui->qtableWidget->selectedItems(); //选中有内容的行可选择有内容的行int count items.count();for(int i 0 ; i < count; i){ …

搭建vite项目

文章目录 Vite 是一个基于 Webpack 的开发服务器&#xff0c;用于开发 Vue 3 和 Vite 应用程序 一、创建一个vite项目二、集成Vue Router1.安装 vue-routernext插件2.在 src 目录下创建一个名为 router 的文件夹&#xff0c;并在其中创建一个名为 index.js 的文件。在这个文件中…