【深度学习实验】循环神经网络(一):循环神经网络(RNN)模型的实现与梯度裁剪

目录

一、实验介绍

二、实验环境

1. 配置虚拟环境

2. 库版本介绍

三、实验内容

0. 导入必要的工具包

1. 数据处理

2. rnn

测试

3. grad_clipping

4. 代码整合


        经验是智慧之父,记忆是智慧之母。

——谚语

一、实验介绍

        本实验介绍了一个简单的循环神经网络(RNN)模型,并探讨了梯度裁剪在模型训练中的应用。

        在前馈神经网络中,信息的传递是单向的,这种限制虽然使得网络变得更容易学习,但在一定程度上也减弱了神经网络模型的能力.在生物神经网络中,神经元之间的连接关系要复杂得多.前馈神经网络可以看作一个复杂的函数,每次输入都是独立的,即网络的输出只依赖于当前的输入.但是在很多现实任务中, 网络的输出不仅和当前时刻的输入相关,也和其过去一段时间的输出相关.比如一个有限状态自动机,其下一个时刻的状态(输出)不仅仅和当前输入相关,也和当前状态(上一个时刻的输出)相关.此外,前馈网络难以处理时序数据,比如视频、语音、文本等.时序数据的长度一般是不固定的,而前馈神经网络要求输入和输出的维数都是固定的,不能任意改变.因此,当处理这一类和时序数据相关 的问题时,就需要一种能力更强的模型. 循环神经网络(Recurrent Neural Network,RNN)是一类具有短期记忆能力的神经网络

        在循环神经网络中,神经元不但可以接受其他神经元的信息,也可以接受自身的信息,形成具有环路的网络结构.和前馈神经网络相比,循环神经网络更加符合生物神经网络的结构.循环神经网络已经被广泛应用在语音识别、语言模型以及自然语言生成等任务上.循环神经网络的参数学习可以通过随时间反向传播算法[Werbos, 1990]来学习.随时间反向传播算法即按照时间的逆序将错误信息一步步地往前传递.

二、实验环境

        本系列实验使用了PyTorch深度学习框架,相关操作如下:

1. 配置虚拟环境

conda create -n DL python=3.7 
conda activate DL
pip install torch==1.8.1+cu102 torchvision==0.9.1+cu102 torchaudio==0.8.1 -f https://download.pytorch.org/whl/torch_stable.html
conda install matplotlib
 conda install scikit-learn

2. 库版本介绍

软件包本实验版本目前最新版
matplotlib3.5.33.8.0
numpy1.21.61.26.0
python3.7.16
scikit-learn0.22.11.3.0
torch1.8.1+cu1022.0.1
torchaudio0.8.12.0.2
torchvision0.9.1+cu1020.15.2

三、实验内容

0. 导入必要的工具包

import torch

1. 数据处理

        与之前的模型有所不同,循环神经网络引入了隐藏状态时间步两个新概念。当前时间步的隐藏状态由当前时间的输入与上一个时间步的隐藏状态一起计算出。

         根据隐藏状态的计算公式,需要计算两次矩阵乘法和三次加法才能得到当前时刻的隐藏状态。这里通过代码说明: 该计算公式等价于将当前时刻的输入与上一个时间步的隐藏状态做拼接,将两个权重矩阵做拼接,然后对两个拼接后的结果做矩阵乘法。此处展示省略了偏置项。

# X为模拟的输入,H为模拟的隐藏状态,在实际情况时要更复杂一些
X, W_xh = torch.normal(0, 1, (3, 1)), torch.normal(0, 1, (1, 4))
H, W_hh = torch.normal(0, 1, (3, 4)), torch.normal(0, 1, (4, 4))
torch.matmul(X, W_xh) + torch.matmul(H, W_hh)

上面是按照公式计算得到的结果,下面是拼接后计算得到的结果,两个结果完全相同

torch.matmul(torch.cat((X, H), 1), torch.cat((W_xh, W_hh), 0))

  • X是一个形状为(3, 1)的张量,表示输入。
  • W_xh是一个形状为(1, 4)的张量,表示输入到隐藏状态的权重。
  • H是一个形状为(3, 4)的张量,表示隐藏状态。
  • W_hh是一个形状为(4, 4)的张量,表示隐藏状态到隐藏状态的权重。

2. rnn

        定义了一个名为rnn的函数,用于执行循环神经网络的前向传播,在函数内部,通过遍历输入序列的每个时间步,逐步计算隐藏状态和输出。

def rnn(inputs, state, params):# inputs的形状:(时间步数量,批量大小,词表大小)W_xh, W_hh, b_h, W_hq, b_q = paramsH = stateoutputs = []# X的形状:(批量大小,词表大小)for X in inputs:H = torch.tanh(torch.mm(X, W_xh) + torch.mm(H, W_hh) + b_h)Y = torch.mm(H, W_hq) + b_qoutputs.append(Y)return torch.cat(outputs, dim=0), (H,)
  • 参数:
    • inputs是一个形状为(时间步数量,批量大小,词表大小)的张量,表示输入序列。
    • state是一个形状为(批量大小,隐藏状态大小)的张量,表示初始隐藏状态。
    • params是一个包含了模型的参数的列表,包括W_xhW_hhb_hW_hqb_q
  • 对于每个时间步,
    • 使用tanh激活函数来更新隐藏状态
    • 根据更新后的隐藏状态,计算输出Y
    • 将输出添加到outputs列表中
  • 使用torch.cat函数将输出列表合并成一个张量,返回合并后的张量和最后一个隐藏状态 (H,)

测试

    inputs=torch.rand(10,3,50)params=[torch.rand((50,50)),torch.rand((50,50)),torch.rand((3,50)),torch.rand((50,60)),torch.rand((3,60))]state=torch.rand((3,50))output=rnn(inputs,state,params)print(output)
  • inputs是一个形状为(10, 3, 50)的随机张量,表示模拟的输入序列
  • params是一个包含了随机参数的列表,与rnn函数中的参数对应
  • state是一个形状为(3, 50)的随机张量,表示初始隐藏状态
  • 调用rnn函数
  • 打印输出结果output

3. grad_clipping

        在循环神经网络的训练中,当时间步较大时,可能导致数值不稳定, 例如梯度爆炸或梯度消失,所以一个很重要的步骤是梯度裁剪。通过下面的函数,梯度范数永远不会超过给定的阈值, 并且更新后的梯度完全与的原始方向对齐。

def grad_clipping(net, theta):if isinstance(net, nn.Module):params = [p for p in net.parameters() if p.requires_grad]else:params = net.paramsnorm = torch.sqrt(sum(torch.sum((p.grad ** 2)) for p in params))if norm > theta:for param in params:param.grad[:] *= theta / norm

        函数接受两个参数:net和theta。该函数首先根据net的类型获取需要梯度更新的参数,然后计算所有参数梯度的平方和的平方根,并将其与阈值theta进行比较。如果超过阈值,则对参数梯度进行裁剪,使其不超过阈值。

4. 代码整合

# 导入必要的工具包
import torch# # X为模拟的输入,H为模拟的隐藏状态,在实际情况时要更复杂一些
# X, W_xh = torch.normal(0, 1, (3, 1)), torch.normal(0, 1, (1, 4))
# H, W_hh = torch.normal(0, 1, (3, 4)), torch.normal(0, 1, (4, 4))
# # torch.matmul(X, W_xh) + torch.matmul(H, W_hh)
# #
# # torch.matmul(torch.cat((X, H), 1), torch.cat((W_xh, W_hh), 0))def rnn(inputs, state, params):# inputs的形状:(时间步数量,批量大小,词表大小)W_xh, W_hh, b_h, W_hq, b_q = paramsH = stateoutputs = []# X的形状:(批量大小,词表大小)for X in inputs:H = torch.tanh(torch.mm(X, W_xh) + torch.mm(H, W_hh) + b_h)Y = torch.mm(H, W_hq) + b_qoutputs.append(Y)return torch.cat(outputs, dim=0), (H,)def grad_clipping(net, theta):if isinstance(net, nn.Module):params = [p for p in net.parameters() if p.requires_grad]else:params = net.paramsnorm = torch.sqrt(sum(torch.sum((p.grad ** 2)) for p in params))if norm > theta:for param in params:param.grad[:] *= theta / normif __name__ == '__main__':inputs=torch.rand(10,3,50)params=[torch.rand((50,50)),torch.rand((50,50)),torch.rand((3,50)),torch.rand((50,60)),torch.rand((3,60))]state=torch.rand((3,50))output=rnn(inputs,state,params)print(output)

        使用随机生成的输入数据和参数进行模型的测试。测试结果显示,RNN模型能够正确计算隐藏状态和输出结果,并且通过梯度裁剪可以有效控制梯度的大小,提高模型的稳定性和训练效果。

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.mzph.cn/news/101867.shtml

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

如何优化前端图像和多媒体资源?

聚沙成塔每天进步一点点 ⭐ 专栏简介 前端入门之旅:探索Web开发的奇妙世界 欢迎来到前端入门之旅!感兴趣的可以订阅本专栏哦!这个专栏是为那些对Web开发感兴趣、刚刚踏入前端领域的朋友们量身打造的。无论你是完全的新手还是有一些基础的开发…

电子书制作软件Vellum mac中文版特点

Vellum mac是一款专业的电子书制作软件,它可以帮助用户将文本文件转换为高质量的电子书,支持多种格式,包括EPUB、MOBI、PDF等。Vellum具有直观的用户界面和易于使用的工具,可以让用户快速地创建和发布电子书。 Vellum mac软件特点…

与艺术同频!卡萨帝在海外崭露头角

在品牌全球化步伐日益加快的当下,高端品牌如何真正实现业务全球化、品牌全球化乃至用户圈层全球化? 作为国际高端家电引领者,卡萨帝今年以来在全球范围内展开了一系列的品牌布局活动。1月,卡萨帝于巴基斯坦召开品牌发布会&#x…

分类预测 | MATLAB实现基于RF-Adaboost随机森林结合AdaBoost多输入分类预测

分类预测 | MATLAB实现基于RF-Adaboost随机森林结合AdaBoost多输入分类预测 目录 分类预测 | MATLAB实现基于RF-Adaboost随机森林结合AdaBoost多输入分类预测预测效果基本介绍模型描述程序设计参考资料 预测效果 基本介绍 1.MATLAB实现基于RF-Adaboost随机森林结合AdaBoost多输…

手写Spring系列【一】IOC的简单实现笔记

前言: 👏作者简介:我是笑霸final,一名热爱技术的在校学生。 📝个人主页:个人主页1 || 笑霸final的主页2 📕系列专栏:项目专栏 📧如果文章知识点有错误的地方,…

大日志(大文件)查看工具

一款很不错的日志查看工具, 优势是能查看很大的日志文档。 无需安装,解压后运行即可; 有注册版,不注册也可以使用。 官方地址: LogViewer - Home page 一个下载地址: 日志查看工具UVviewsoft LogViewer(超大…

电脑如何查看是否支持虚拟化及如何开启虚拟化

什么是虚拟化? Intel Virtualization Technology就是以前众所周知的“Vanderpool”技术(简称VT,中文译为虚拟化技术),这种技术可以让一个CPU工作起来就像多个CPU并行运行,从而使得在一部电脑内同时运行多个操作系统成…

开山之作 | YOLOv1算法超详细解析(包括诞生背景+论文解析+技术原理等)

前言:Hello大家好,我是小哥谈。目标检测是计算机视觉领域的一项重要研究方向,它在许多应用领域中都得到了广泛应用,如人脸识别、物体识别、自动驾驶、视频监控等。在过去,目标检测方法主要采用基于RCNN、Fast R-CNN等深…

Python+Tkinter 图形化界面基础篇:集成数据库

PythonTkinter 图形化界面基础篇:集成数据库 引言为什么选择 SQLite 数据库?集成 SQLite 数据库的步骤示例:创建一个任务管理应用程序步骤1:导入必要的模块步骤2:创建主窗口和数据库连接步骤3:创建数据库表…

高级深入--day30

Scrapy Shell Scrapy终端是一个交互终端,我们可以在未启动spider的情况下尝试及调试代码,也可以用来测试XPath或CSS表达式,查看他们的工作方式,方便我们爬取的网页中提取的数据。 如果安装了 IPython ,Scrapy终端将使用 IPython (替代标准Python终端)。 IPython 终端与其…

从零开始:深入理解Kubernetes架构及安装过程

K8s环境搭建 文章目录 K8s环境搭建集群类型安装方式环境规划克隆三台虚拟机系统环境配置集群搭建初始化集群(仅在master节点)配置环境变量(仅在master节点)工作节点加入集群(knode1节点及knode2节点)安装ca…

1806_emacs_org-mode归档的时候修改归档文件名称

全部学习汇总:GreyZhang/g_org: my learning trip for org-mode (github.com) 前面已经基本了解了org-mode的归档的规则或者方法,但是还有一点跟我现在的工作流有点不相符。我自己的工作流中会每月做一次工作的整理总结,因此归档的文件是按照…

C++ PCL点云局部颜色变换

程序示例精选 C PCL点云局部颜色变换 如需安装运行环境或远程调试,见文章底部个人QQ名片,由专业技术人员远程协助! 前言 这篇博客针对《C PCL点云局部颜色变换》编写代码,代码整洁,规则,易读。 学习与应用…

基于SpringBoot的大学城水电管理系统

目录 前言 一、技术栈 二、系统功能介绍 管理员模块的实现 领用设备管理 消耗设备管理 设备申请管理 状态汇报管理 用户模块的实现 设备申请 状态汇报 用户反馈 三、核心代码 1、登录模块 2、文件上传模块 3、代码封装 前言 随着信息技术在管理上越来越深入而广泛…

深度学习简述

⭐️⭐️⭐️⭐️⭐️欢迎来到我的博客⭐️⭐️⭐️⭐️⭐️ 🐴作者:秋无之地 🐴简介:CSDN爬虫、后端、大数据领域创作者。目前从事python爬虫、后端和大数据等相关工作,主要擅长领域有:爬虫、后端、大数据…

点击劫持:X-Frame-Options 未配置

前言 X-Frame-Options作为HTTP头的一部分,是一种用于保护网站免受点击劫持攻击的安全措施。网站可以通过设置X-Frame-Options或csp报头来控制网站本身是否可以被嵌套到iframe中。 漏洞描述 Clickjacking(点击劫持)是一种安全漏洞&#xff…

Android 项目增加 res配置

main.res.srcDirs "src/main/res_test" build->android->sourceSets

简要归纳UE5 Lumen全局光照原理

一、Jim kajiya老爷子的渲染方程: 求全局光照就是求解渲染方程,我们将两边都有未知数的渲染方程变换成离散形式: 更形象的描述这个离散的渲染方程: 要给每个三角形着色就得先判断光线有没有和它相交,以下是求光线和三…

hive数据表创建

目录 分隔符 分区表 二级分区 分桶表 外部表 分隔符 CREATE TABLE emp( userid bigint, emp_name array<string>, emp_date map<string,date>, other_info struct<deptname:string, gender:string>) ROW FORMAT DELIMITED FIELDS TERMINATED BY \t COL…

【NUMA平衡】浅入介绍NUMA平衡技术及调度方式

在云计算方案设计或项目问题处理的时候&#xff0c;经常会遇到NUMA平衡的问题&#xff0c;进行让人不清楚NUMA到底有何用&#xff0c;如何发挥作用&#xff0c;本文就NUMA技术原理和调度进行简要整理&#xff0c;方便后续需要时候查阅学习。 一.背景 一般的对称多处理器中&am…