【深度学习实验】循环神经网络(一):循环神经网络(RNN)模型的实现与梯度裁剪

目录

一、实验介绍

二、实验环境

1. 配置虚拟环境

2. 库版本介绍

三、实验内容

0. 导入必要的工具包

1. 数据处理

2. rnn

测试

3. grad_clipping

4. 代码整合


        经验是智慧之父,记忆是智慧之母。

——谚语

一、实验介绍

        本实验介绍了一个简单的循环神经网络(RNN)模型,并探讨了梯度裁剪在模型训练中的应用。

        在前馈神经网络中,信息的传递是单向的,这种限制虽然使得网络变得更容易学习,但在一定程度上也减弱了神经网络模型的能力.在生物神经网络中,神经元之间的连接关系要复杂得多.前馈神经网络可以看作一个复杂的函数,每次输入都是独立的,即网络的输出只依赖于当前的输入.但是在很多现实任务中, 网络的输出不仅和当前时刻的输入相关,也和其过去一段时间的输出相关.比如一个有限状态自动机,其下一个时刻的状态(输出)不仅仅和当前输入相关,也和当前状态(上一个时刻的输出)相关.此外,前馈网络难以处理时序数据,比如视频、语音、文本等.时序数据的长度一般是不固定的,而前馈神经网络要求输入和输出的维数都是固定的,不能任意改变.因此,当处理这一类和时序数据相关 的问题时,就需要一种能力更强的模型. 循环神经网络(Recurrent Neural Network,RNN)是一类具有短期记忆能力的神经网络

        在循环神经网络中,神经元不但可以接受其他神经元的信息,也可以接受自身的信息,形成具有环路的网络结构.和前馈神经网络相比,循环神经网络更加符合生物神经网络的结构.循环神经网络已经被广泛应用在语音识别、语言模型以及自然语言生成等任务上.循环神经网络的参数学习可以通过随时间反向传播算法[Werbos, 1990]来学习.随时间反向传播算法即按照时间的逆序将错误信息一步步地往前传递.

二、实验环境

        本系列实验使用了PyTorch深度学习框架,相关操作如下:

1. 配置虚拟环境

conda create -n DL python=3.7 
conda activate DL
pip install torch==1.8.1+cu102 torchvision==0.9.1+cu102 torchaudio==0.8.1 -f https://download.pytorch.org/whl/torch_stable.html
conda install matplotlib
 conda install scikit-learn

2. 库版本介绍

软件包本实验版本目前最新版
matplotlib3.5.33.8.0
numpy1.21.61.26.0
python3.7.16
scikit-learn0.22.11.3.0
torch1.8.1+cu1022.0.1
torchaudio0.8.12.0.2
torchvision0.9.1+cu1020.15.2

三、实验内容

0. 导入必要的工具包

import torch

1. 数据处理

        与之前的模型有所不同,循环神经网络引入了隐藏状态时间步两个新概念。当前时间步的隐藏状态由当前时间的输入与上一个时间步的隐藏状态一起计算出。

         根据隐藏状态的计算公式,需要计算两次矩阵乘法和三次加法才能得到当前时刻的隐藏状态。这里通过代码说明: 该计算公式等价于将当前时刻的输入与上一个时间步的隐藏状态做拼接,将两个权重矩阵做拼接,然后对两个拼接后的结果做矩阵乘法。此处展示省略了偏置项。

# X为模拟的输入,H为模拟的隐藏状态,在实际情况时要更复杂一些
X, W_xh = torch.normal(0, 1, (3, 1)), torch.normal(0, 1, (1, 4))
H, W_hh = torch.normal(0, 1, (3, 4)), torch.normal(0, 1, (4, 4))
torch.matmul(X, W_xh) + torch.matmul(H, W_hh)

上面是按照公式计算得到的结果,下面是拼接后计算得到的结果,两个结果完全相同

torch.matmul(torch.cat((X, H), 1), torch.cat((W_xh, W_hh), 0))

  • X是一个形状为(3, 1)的张量,表示输入。
  • W_xh是一个形状为(1, 4)的张量,表示输入到隐藏状态的权重。
  • H是一个形状为(3, 4)的张量,表示隐藏状态。
  • W_hh是一个形状为(4, 4)的张量,表示隐藏状态到隐藏状态的权重。

2. rnn

        定义了一个名为rnn的函数,用于执行循环神经网络的前向传播,在函数内部,通过遍历输入序列的每个时间步,逐步计算隐藏状态和输出。

def rnn(inputs, state, params):# inputs的形状:(时间步数量,批量大小,词表大小)W_xh, W_hh, b_h, W_hq, b_q = paramsH = stateoutputs = []# X的形状:(批量大小,词表大小)for X in inputs:H = torch.tanh(torch.mm(X, W_xh) + torch.mm(H, W_hh) + b_h)Y = torch.mm(H, W_hq) + b_qoutputs.append(Y)return torch.cat(outputs, dim=0), (H,)
  • 参数:
    • inputs是一个形状为(时间步数量,批量大小,词表大小)的张量,表示输入序列。
    • state是一个形状为(批量大小,隐藏状态大小)的张量,表示初始隐藏状态。
    • params是一个包含了模型的参数的列表,包括W_xhW_hhb_hW_hqb_q
  • 对于每个时间步,
    • 使用tanh激活函数来更新隐藏状态
    • 根据更新后的隐藏状态,计算输出Y
    • 将输出添加到outputs列表中
  • 使用torch.cat函数将输出列表合并成一个张量,返回合并后的张量和最后一个隐藏状态 (H,)

测试

    inputs=torch.rand(10,3,50)params=[torch.rand((50,50)),torch.rand((50,50)),torch.rand((3,50)),torch.rand((50,60)),torch.rand((3,60))]state=torch.rand((3,50))output=rnn(inputs,state,params)print(output)
  • inputs是一个形状为(10, 3, 50)的随机张量,表示模拟的输入序列
  • params是一个包含了随机参数的列表,与rnn函数中的参数对应
  • state是一个形状为(3, 50)的随机张量,表示初始隐藏状态
  • 调用rnn函数
  • 打印输出结果output

3. grad_clipping

        在循环神经网络的训练中,当时间步较大时,可能导致数值不稳定, 例如梯度爆炸或梯度消失,所以一个很重要的步骤是梯度裁剪。通过下面的函数,梯度范数永远不会超过给定的阈值, 并且更新后的梯度完全与的原始方向对齐。

def grad_clipping(net, theta):if isinstance(net, nn.Module):params = [p for p in net.parameters() if p.requires_grad]else:params = net.paramsnorm = torch.sqrt(sum(torch.sum((p.grad ** 2)) for p in params))if norm > theta:for param in params:param.grad[:] *= theta / norm

        函数接受两个参数:net和theta。该函数首先根据net的类型获取需要梯度更新的参数,然后计算所有参数梯度的平方和的平方根,并将其与阈值theta进行比较。如果超过阈值,则对参数梯度进行裁剪,使其不超过阈值。

4. 代码整合

# 导入必要的工具包
import torch# # X为模拟的输入,H为模拟的隐藏状态,在实际情况时要更复杂一些
# X, W_xh = torch.normal(0, 1, (3, 1)), torch.normal(0, 1, (1, 4))
# H, W_hh = torch.normal(0, 1, (3, 4)), torch.normal(0, 1, (4, 4))
# # torch.matmul(X, W_xh) + torch.matmul(H, W_hh)
# #
# # torch.matmul(torch.cat((X, H), 1), torch.cat((W_xh, W_hh), 0))def rnn(inputs, state, params):# inputs的形状:(时间步数量,批量大小,词表大小)W_xh, W_hh, b_h, W_hq, b_q = paramsH = stateoutputs = []# X的形状:(批量大小,词表大小)for X in inputs:H = torch.tanh(torch.mm(X, W_xh) + torch.mm(H, W_hh) + b_h)Y = torch.mm(H, W_hq) + b_qoutputs.append(Y)return torch.cat(outputs, dim=0), (H,)def grad_clipping(net, theta):if isinstance(net, nn.Module):params = [p for p in net.parameters() if p.requires_grad]else:params = net.paramsnorm = torch.sqrt(sum(torch.sum((p.grad ** 2)) for p in params))if norm > theta:for param in params:param.grad[:] *= theta / normif __name__ == '__main__':inputs=torch.rand(10,3,50)params=[torch.rand((50,50)),torch.rand((50,50)),torch.rand((3,50)),torch.rand((50,60)),torch.rand((3,60))]state=torch.rand((3,50))output=rnn(inputs,state,params)print(output)

        使用随机生成的输入数据和参数进行模型的测试。测试结果显示,RNN模型能够正确计算隐藏状态和输出结果,并且通过梯度裁剪可以有效控制梯度的大小,提高模型的稳定性和训练效果。

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.mzph.cn/news/101867.shtml

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

如何优化前端图像和多媒体资源?

聚沙成塔每天进步一点点 ⭐ 专栏简介 前端入门之旅:探索Web开发的奇妙世界 欢迎来到前端入门之旅!感兴趣的可以订阅本专栏哦!这个专栏是为那些对Web开发感兴趣、刚刚踏入前端领域的朋友们量身打造的。无论你是完全的新手还是有一些基础的开发…

电子书制作软件Vellum mac中文版特点

Vellum mac是一款专业的电子书制作软件,它可以帮助用户将文本文件转换为高质量的电子书,支持多种格式,包括EPUB、MOBI、PDF等。Vellum具有直观的用户界面和易于使用的工具,可以让用户快速地创建和发布电子书。 Vellum mac软件特点…

追求极致性能!Qwik 1.0版本发布

前言 偶然发现 Qwik 这个 Node.js 后端框架,感觉非常新奇,它构建的网站,能够在谷歌的网站评测工具中跑出100分满分的成绩,而且还是移动端(一般情况下,移动端分值要低于PC端)!不得不…

IDEA XML文件里写SQL比较大小条件

背景 最近开发的时候&#xff0c;有一个需求的查询需要支持范围查询[a,b)&#xff0c;并且查询的结果要求查询的范围含头端点不含尾端点。因为between…and…查询的范围是含头含尾的&#xff0c;因而不能使用。 因此打算直接使用>和<来比较实现&#xff0c;使用>的时…

【Redis】Set集合内部编码方式

内部编码 集合类型的内部编码有两种&#xff1a; intset&#xff08;整数集合&#xff09;&#xff1a;当集合中的元素都是整数并且元素的个数⼩于set-max-intset-entries配置&#xff08;默认512个&#xff09;时&#xff0c;Redis会选⽤intset来作为集合的内部实现&#xf…

与艺术同频!卡萨帝在海外崭露头角

在品牌全球化步伐日益加快的当下&#xff0c;高端品牌如何真正实现业务全球化、品牌全球化乃至用户圈层全球化&#xff1f; 作为国际高端家电引领者&#xff0c;卡萨帝今年以来在全球范围内展开了一系列的品牌布局活动。1月&#xff0c;卡萨帝于巴基斯坦召开品牌发布会&#x…

hyperf框架WebSocket 服务

1&#xff1a;安装 composer require hyperf/websocket-server2&#xff1a;配置 Server 修改 config/autoload/server.php&#xff0c;增加以下配置。 return [servers > [[name > ws,type > Server::SERVER_WEBSOCKET,host > 0.0.0.0,port > 9502,sock_typ…

分类预测 | MATLAB实现基于RF-Adaboost随机森林结合AdaBoost多输入分类预测

分类预测 | MATLAB实现基于RF-Adaboost随机森林结合AdaBoost多输入分类预测 目录 分类预测 | MATLAB实现基于RF-Adaboost随机森林结合AdaBoost多输入分类预测预测效果基本介绍模型描述程序设计参考资料 预测效果 基本介绍 1.MATLAB实现基于RF-Adaboost随机森林结合AdaBoost多输…

[JAVA版本] Websocket获取B站直播弹幕——基于直播开放平台

教程 B站直播间弹幕Websocket获取 — 哔哩哔哩直播开放平台 基于B站直播开放平台开放且未上架时&#xff0c;只能个人使用。 代码实现 1、相关依赖 fastjson2用于解析JSON字符串&#xff0c;可自行替换成别的框架。 hutool-core用于解压zip数据&#xff0c;可自行替换成别的…

手写Spring系列【一】IOC的简单实现笔记

前言&#xff1a; &#x1f44f;作者简介&#xff1a;我是笑霸final&#xff0c;一名热爱技术的在校学生。 &#x1f4dd;个人主页&#xff1a;个人主页1 || 笑霸final的主页2 &#x1f4d5;系列专栏&#xff1a;项目专栏 &#x1f4e7;如果文章知识点有错误的地方&#xff0c;…

大日志(大文件)查看工具

一款很不错的日志查看工具&#xff0c; 优势是能查看很大的日志文档。 无需安装&#xff0c;解压后运行即可&#xff1b; 有注册版&#xff0c;不注册也可以使用。 官方地址&#xff1a; LogViewer - Home page 一个下载地址&#xff1a; 日志查看工具UVviewsoft LogViewer(超大…

makefile编译举例

makefile编译举例 # 定义编译器和编译选项 CC gcc CFLAGS -Wall -Werror # 定义目标文件名 TARGET myprogram # 定义需要编译的源文件目录和文件名 SRC_DIR1 src1 SRC_DIR2 src2 OBJ_DIR1 obj1 OBJ_DIR2 obj2 SRC_FILES1 file1.c file2.c SRC_FILES2…

电脑如何查看是否支持虚拟化及如何开启虚拟化

什么是虚拟化? Intel Virtualization Technology就是以前众所周知的“Vanderpool”技术&#xff08;简称VT&#xff0c;中文译为虚拟化技术&#xff09;&#xff0c;这种技术可以让一个CPU工作起来就像多个CPU并行运行&#xff0c;从而使得在一部电脑内同时运行多个操作系统成…

MyBatis的xml里#{}的参数为null报错、将null作为参数传递报错问题

今天在调试的过程中发现一个bug&#xff0c;把传入的参数写到查询分析器中执行没有问题&#xff0c;但是在程序中执行就报错&#xff1a;org.springframework.jdbc.UncategorizedSQLException : Error setting null parameter. Most JDBC drivers require that the JdbcType m…

开山之作 | YOLOv1算法超详细解析(包括诞生背景+论文解析+技术原理等)

前言&#xff1a;Hello大家好&#xff0c;我是小哥谈。目标检测是计算机视觉领域的一项重要研究方向&#xff0c;它在许多应用领域中都得到了广泛应用&#xff0c;如人脸识别、物体识别、自动驾驶、视频监控等。在过去&#xff0c;目标检测方法主要采用基于RCNN、Fast R-CNN等深…

Python+Tkinter 图形化界面基础篇:集成数据库

PythonTkinter 图形化界面基础篇&#xff1a;集成数据库 引言为什么选择 SQLite 数据库&#xff1f;集成 SQLite 数据库的步骤示例&#xff1a;创建一个任务管理应用程序步骤1&#xff1a;导入必要的模块步骤2&#xff1a;创建主窗口和数据库连接步骤3&#xff1a;创建数据库表…

高级深入--day30

Scrapy Shell Scrapy终端是一个交互终端,我们可以在未启动spider的情况下尝试及调试代码,也可以用来测试XPath或CSS表达式,查看他们的工作方式,方便我们爬取的网页中提取的数据。 如果安装了 IPython ,Scrapy终端将使用 IPython (替代标准Python终端)。 IPython 终端与其…

零基础学python之数据类型

文章目录 1、数据类型1.1 编程规范注释标识符命名规则命名规则python命名规则关于代码规范编程习惯的重要性 输入输出与变量输出输入变量 1.2 数值类型int(整型)浮点型&#xff08;float&#xff09;类型转化 1.3 字符串字符串创建字符串格式化**format**%s**f** 案例&#xff…

从零开始:深入理解Kubernetes架构及安装过程

K8s环境搭建 文章目录 K8s环境搭建集群类型安装方式环境规划克隆三台虚拟机系统环境配置集群搭建初始化集群&#xff08;仅在master节点&#xff09;配置环境变量&#xff08;仅在master节点&#xff09;工作节点加入集群&#xff08;knode1节点及knode2节点&#xff09;安装ca…

1806_emacs_org-mode归档的时候修改归档文件名称

全部学习汇总&#xff1a;GreyZhang/g_org: my learning trip for org-mode (github.com) 前面已经基本了解了org-mode的归档的规则或者方法&#xff0c;但是还有一点跟我现在的工作流有点不相符。我自己的工作流中会每月做一次工作的整理总结&#xff0c;因此归档的文件是按照…