NLP之RNN的原理讲解(python示例)

目录

    • 代码示例
    • 代码解读
    • 知识点介绍

代码示例

import numpy as np
import tensorflow as tf
from tensorflow.keras.layers import SimpleRNNCell# 第t时刻要训练的数据
xt = tf.Variable(np.random.randint(2, 3, size=[1, 1]), dtype=tf.float32)
print(xt)
# https://www.cnblogs.com/Renyi-Fan/p/13722276.htmlcell = SimpleRNNCell(units=1, activation=None, use_bias=True, kernel_initializer='ones', recurrent_initializer='ones',bias_initializer=tf.keras.initializers.Constant(value=3))
cell.build(input_shape=[None, 1])
print('variables', cell.variables)
print('config:', cell.get_config())print(tf.nn.tanh(tf.constant([-float("inf"), 6, float("inf")])))# 第t时刻运算
ht_1 = tf.ones([1, 1])
out, ht = cell(xt, ht_1)  # LSTM
print(out, ht[0])
print(id(out), id(ht[0]))# 第t+1时刻运算
cell2 = SimpleRNNCell(units=1, activation=None, use_bias=True, kernel_initializer='ones',recurrent_initializer=tf.keras.initializers.Constant(value=3), bias_initializer='ones')
xt2 = tf.Variable(np.random.randint(3, 4, size=[1, 1]), dtype=tf.float32)
out2, ht2 = cell2(xt2, ht)
print(out2, ht2[0])

代码解读

这段代码包含了一些使用 TensorFlow 来创建和操作循环神经网络(RNN)的基础操作。我们将一步步地解释其含义。

  1. 导入所需的库:

    import numpy as np
    import tensorflow as tf
    from tensorflow.keras.layers import SimpleRNNCell
    

    代码导入了NumPy库、TensorFlow库以及SimpleRNNCell,这是一个实现了简单的RNN单元操作的类。

  2. 创建训练数据:

    xt = tf.Variable(np.random.randint(2, 3, size=[1, 1]), dtype=tf.float32)
    print(xt)
    

    这里创建了一个1x1的张量,其值是2或3之间的随机整数。这代表了在时间t的输入数据。

  3. 定义RNN单元:

    cell = SimpleRNNCell(units=1, activation=None, use_bias=True, kernel_initializer='ones', recurrent_initializer='ones',bias_initializer=tf.keras.initializers.Constant(value=3))
    

    使用SimpleRNNCell创建了一个RNN单元。这个单元有以下特性:

    • 只有一个神经元(units=1)。
    • 不使用激活函数(activation=None)。
    • 使用偏置,并初始化为3(bias_initializer=tf.keras.initializers.Constant(value=3))。
    • 输入权重和循环权重都初始化为1。
    • kernel_initializer='ones':
      • 这是一个初始化器,用于初始化RNN单元的权重(也称为内核权重)。
      • 'ones'表示所有的权重都被初始化为1。
      • 换句话说,当输入数据经过RNN单元时,它会与这些权重相乘,而这些权重的初始值都是1。
    • recurrent_initializer='ones':
      • 这是一个初始化器,用于初始化RNN单元的循环权重。
      • 在RNN中,当前时间步的隐藏状态是基于前一个时间步的隐藏状态计算的。这个计算涉及到的权重就是循环权重。
      • 'ones'表示所有的循环权重都被初始化为1。
    • bias_initializer=tf.keras.initializers.Constant(value=3):
      • 这是一个初始化器,用于初始化RNN单元的偏置。
      • tf.keras.initializers.Constant(value=3)表示所有的偏置被初始化为常数3。
    • 简而言之,这些参数(kernel_initializer、recurrent_initializer、bias_initializer)确定了RNN单元在开始训练之前的权重和偏置的初始状态。这些初始值在训练过程中会被更新。选择合适的初始化器对于模型的收敛速度和性能至关重要,尽管在这个特定的例子中,这些权重和偏置被赋予了特定的常数值。

    cell.build(input_shape=[None, 1])这行代码是用来告诉RNN单元输入的形状,这样它就可以创建相应的权重和偏置张量。

    • 在TensorFlow和Keras中,input_shape是用来指定输入数据的维度的参数。具体到这里的input_shape=[None, 1],我们可以解读它为:
    • [None, 1]:这是一个形状列表,其中有两个维度。
    • None
      • 第一个维度通常表示批处理的大小(即在一个批次中的样本数)。在许多情况下,为了使模型更加灵活,我们可能不想在定义模型时硬编码一个固定的批处理大小。
      • 使用None作为批处理的大小意味着模型可以接受任何大小的批次。
      • 例如,你可以选择在训练时使用64的批大小,在评估或推理时使用1的批大小,或者使用其他任何数字。
    • 1
      • 第二个维度是数据的特征维度。
      • 在这里,它指的是输入数据的每个样本有1个特征。
    • 综上所述,input_shape=[None, 1]表示模型可以接受一个二维的输入,其中第一个维度是任意大小的批处理,第二个维度是1个特征。
  4. 显示RNN单元的变量和配置:
    代码打印出RNN单元的所有变量(如权重和偏置)以及配置。

    print('variables', cell.variables)
    print('config:', cell.get_config())
    

    这两行代码是关于打印关于cell(这里的cell是一个SimpleRNNCell的实例)的相关信息。

    • print('variables', cell.variables):

      • cell.variables: 这是一个属性,它返回一个列表,该列表包含cell中的所有可训练变量(权重和偏置)。在RNN cell的上下文中,这通常包括核权重、递归权重以及偏置。
      • print(...): 打印变量列表,以便于你查看和调试。通常这可以帮助你理解RNN cell中的权重如何初始化(例如,这里你已经明确地设置了初始化器)。
    • print('config:', cell.get_config()):

      • cell.get_config(): 这是一个方法,它返回一个字典,该字典包含cell的配置。这通常包括其初始化时使用的参数(例如units的数量、激活函数、是否使用偏置等)。这允许你查看或者后续再次使用这些配置信息,例如,如果你想保存模型的结构并稍后再次创建它。
      • print(...): 打印配置字典,使你能够查看cell的配置。
    • 总之,这两行代码提供了关于SimpleRNNCell实例(cell)的详细信息,包括它的权重(和它们的初始值)以及它的配置。这是非常有用的,特别是当你在调试或了解你的模型结构时。

  5. 计算tanh的值:

    print(tf.nn.tanh(tf.constant([-float("inf"), 6, float("inf")])))
    

    这行代码计算了tanh函数在-∞、6和三个点的值。tanh是RNN和其他神经网络中常用的激活函数。

  6. 第t时刻的计算:
    这部分代码首先定义了上一个时间步的隐藏状态ht_1,然后使用cell(xt, ht_1)调用RNN单元来获取当前时间步的输出和隐藏状态。

    ht_1 = tf.ones([1, 1])
    out, ht = cell(xt, ht_1)  # LSTM
    print(out, ht[0])
    print(id(out), id(ht[0]))
    
  7. 第t+1时刻的计算:
    同样地,这部分代码定义了一个新的RNN单元cell2,然后用新的输入xt2和上一个时间步的隐藏状态ht来获取下一个时间步的输出和隐藏状态。

    cell2 = SimpleRNNCell(units=1, activation=None, use_bias=True, kernel_initializer='ones',recurrent_initializer=tf.keras.initializers.Constant(value=3), bias_initializer='ones')
    xt2 = tf.Variable(np.random.randint(3, 4, size=[1, 1]), dtype=tf.float32)
    out2, ht2 = cell2(xt2, ht)
    
  8. 输出与隐藏状态的关系:

    print(id(out), id(ht[0]))
    

    这部分代码展示了在简单的RNN中,输出状态out和隐藏状态ht是相同的对象。

最后,代码的主要目的是演示如何使用SimpleRNNCell在给定的输入和隐藏状态上进行计算,并展示其结果。

知识点介绍

tf.Variable 是 TensorFlow(TF)中的一个核心概念,它用于表示在 TF 计算过程中可能会发生变化的数据。在 TF 中,计算通常是通过计算图(graph)来定义的,而 tf.Variable 允许我们将可以变化的状态添加到这些计算图中。

以下是 tf.Variable 的一些关键点:

  1. 可变性:与 TensorFlow 的常量(tf.constant)不同,tf.Variable 表示的值是可变的。这意味着在训练过程中,可以更新、修改或赋予其新值。

  2. 用途tf.Variable 通常用于表示模型的参数,例如神经网络中的权重和偏置。

  3. 初始化:当创建一个 tf.Variable 时,你必须为它提供一个初始值。这个初始值可以是一个固定值,也可以是其他任何 TensorFlow 计算的结果。

  4. 赋值:使用 assignassign_add 等方法,你可以修改 tf.Variable 的值。

  5. 存储和恢复tf.Variable 的值可以被存储到磁盘并在之后恢复,这是通过 TensorFlow 的保存和恢复机制实现的,这样可以方便地保存和加载模型。

示例:

import tensorflow as tf# 创建一个初始化为1的变量
v = tf.Variable(1.0)# 使用变量
result = v * 2.0# 修改变量的值
v.assign(2.0)  # 现在 v 的值为 2.0

总之,tf.Variable 是 TensorFlow 中表示可变状态的主要方式,尤其是在模型训练中,它用于存储和更新模型的参数。

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.mzph.cn/news/121188.shtml

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

微服务技术导学

文章目录 微服务结构认识微服务技术栈 微服务结构 技术: 解决异常定位: 持续集成,解决自动化的部署: 总结如下: 认识微服务 微服务演变: 技术栈 SpringCloud与SpringBoot版本对应关系

VS2022 C# 读取 excel 2023年

今天是2023年6月26日,我有一个excel表要读数据,然后放到winform程序来处理,网上的资料太旧,很多用不起来,试了一个可以使用,记录一下: 一、excel文件后缀需要小写。 二、用VS2022建一个winform…

Java练习题2020 -1

统计1到N的整数中&#xff0c;被A除余A-1的偶数的个数 输入说明&#xff1a;整数 N(N<10000), A, (A 输出说明&#xff1a;符合条件的数的个数 输入样例&#xff1a;10 3 输出样例&#xff1a;2 (说明&#xff1a;样例中符合条件的2个数是 2、8) import java.util.Scanner;p…

Web:探索 SpreadJS强大的在线电子表格库

1、概述 SpreadJS 是葡萄城结合 40 余年专业控件技术和在电子表格应用领域的经验而推出的纯前端表格控件,基于 HTML5,兼容 450 多种 Excel 公式,具备“高性能、跨平台、与 Excel 高度兼容”的产品特性,SpreadJS 在界面和功能上与 Excel 高度类似,但又不局限于 Excel,而是…

简单线性回归模型(复习一下前向传播和反向传播)

案例1 import torch torch.__version__ xtorch.rand(3,4,requires_gradTrue) xtensor([[0.9795, 0.8240, 0.6395, 0.1617],[0.4833, 0.4409, 0.3758, 0.7234],[0.9857, 0.9663, 0.5842, 0.8751]], requires_gradTrue)btorch.rand(3,4,requires_gradTrue) txb yt.sum()y.backwa…

【电路笔记】-电路中的复数与相量(Phasor)

电路中的复数与相量(Phasor) 文章目录 电路中的复数与相量(Phasor)1、概述2、复数定义3、复数计算规则4、电子领域的复数5、总结 复数是一种重要的数学工具&#xff0c;广泛应用于包括电子学在内的许多物理领域。 这个概念可能看起来很奇怪&#xff0c;但它们的操作很简单&…

【Docker从入门到入土 6】Consul详解+Docker https安全认证(附证书申请方式)

Part 6 一、服务注册与发现的概念1.1 cmp问题1.2 服务注册与发现 二、Consul ----- 服务自动发现和注册2.1 简介2.2 为什么要用consul&#xff1f;2.3 consul的架构2.3 Consul-template 三、consul架构部署3.1 Consul服务器Step1 建立 Consul 服务Step2 查看集群信息Step3 通过…

axios封装以及详细用法

文章目录 axios用法(这里没有封装&#xff0c;下面有封装好的get&#xff0c;post方法&#xff0c;在axios封装里面)get &#xff0c;delete方法post&#xff0c;put方法 axios具体封装axios 具体参数配置 axios用法(这里没有封装&#xff0c;下面有封装好的get&#xff0c;pos…

项目中拖拽元素,可以使用html的draggable属性,当然也可以用第三方插件interact

项目中拖拽元素&#xff0c;可以使用html的draggable属性&#xff0c;当然也可以用第三方插件interact 一、安装二、引用三、使用 一、安装 npm install interactjs二、引用 import interact from interactjs三、使用 <div class"drag_box"> &…

基于android的 rk3399 同时支持多个USB摄像头

基于android的 rk3399 同时支持多个USB摄像头 一、前文二、CameraHal_Module.h三、CameraHal_Module.cpp四、编译&烧录Image五、App验证 一、前文 Android系统默认支持2个摄像头&#xff0c;一个前置摄像头&#xff0c;一个后置摄像头 需要支持数量更多的摄像头&#xff0…

selenium工作原理和反爬分析

一、 Selenium Selenium是最广泛使用的开源Web UI(用户界面)自动化测试套件之一&#xff0c;支持并行测试执行。Selenium通过使用特定于每种语言的驱动程序支持各种编程语言。Selenium支持的语言包括C#&#xff0c;Java&#xff0c;Perl&#xff0c;PHP&#xff0c;Python和Ru…

如何查看多开的逍遥模拟器的adb连接端口号

逍遥模拟器默认端口号为&#xff1a;21503。 不过&#xff0c;使用多开器多开的时候&#xff0c;端口就不一定是21503了。 如何查看&#xff1f; 进入G:\xiaoyao\Microvirt\MEmu\MemuHyperv VMs路径中 每多开一个模拟器&#xff0c;就会多出一个文件夹。 进入你要查找端口号…

2023年MathorCup高校数学建模挑战赛大数据挑战赛赛题浅析

比赛时长为期7天的妈杯大数据挑战赛如期开赛&#xff0c;为了帮助大家更好的选题&#xff0c;首先给大家带来赛题浅析&#xff0c;为了方便大家更好的选题。 赛道 A&#xff1a;基于计算机视觉的坑洼道路检测和识别 A题&#xff0c;图像处理类题目。这种题目的难度数模独一档…

SpringAOP源码解析之advice执行顺序(三)

上一章我们分析了Aspect中advice的排序为Around.class, Before.class, After.class, AfterReturning.class, AfterThrowing.class&#xff0c;然后advice真正的执行顺序是什么&#xff1f;多个Aspect之间的执行顺序又是什么&#xff1f;就是我们本章探讨的问题。 准备工作 既…

基于Python Django 的微博舆论、微博情感分析可视化系统(V2.0)

文章目录 1 简介2 意义3 技术栈Django 4 效果图微博首页情感分析关键词分析热门评论舆情预测 5 推荐阅读 1 简介 基于Python的微博舆论分析&#xff0c;微博情感分析可视化系统&#xff0c;项目后端分爬虫模块、数据分析模块、数据存储模块、业务逻辑模块组成。 Python基于微博…

第八节——Vue渲染列表+key作用

一、列表渲染 vue中使用v-for指令进行列表 <template><div><!-- item 代表 当前循环的每一项 --><!-- index 代表 当前循环的下标--><!-- 注意&#xff1a;必须要加key--><div v-for"(item, index) in arr" :key"index"…

UE5 Blueprint发送http请求

一、下载插件HttpBlueprint、Json Blueprint Utilities两个插件是互相依赖的&#xff0c;启用&#xff0c;重启项目 目前两个是Beta的状态&#xff0c;如果你使用的平台支持就可以使用&#xff0c;我们的项目因为需要取Header的值&#xff0c;所有没法使用这两个插件&#xff0…

Java集合面试题知识点总结(上篇)

大家好&#xff0c;我是栗筝i&#xff0c;从 2022 年 10 月份开始&#xff0c;我持续梳理出了全面的 Java 技术栈内容&#xff0c;一方面是对自己学习内容进行整合梳理&#xff0c;另一方面是希望对大家有所帮助&#xff0c;使我们一同进步。得到了很多读者的正面反馈。 而在 2…

DBeaver安装与使用教程(超详细安装与使用教程),好用免费的数据库管理工具

DBeaver安装步骤 资源下载&#xff1a; https://download.csdn.net/download/qq_37181642/88479235 官网地址&#xff1a; https://dbeaver.io/ 安装dbeaver 点击上图.exe安装工具&#xff0c;安装完成后不要打开 。 windows配置hosts 在hosts文件中加入&#xff1a; 127.0.0…

基于SSM民宿预订及个性化服务系统-计算机毕设 附源码 04846

SSM民宿预订及个性化服务系统 摘 要 伴随着国内旅游经济的迅猛发展民宿住宿行在国内也迎来了前所未有的发展机遇。传统的旅游模式已难以满足游客日益多元化的需求&#xff0c;随着人们外出度假的时间越来越长&#xff0c;导致人们在住宿的选择上更加追求舒适、个性化的住宿体验…