Pytorch中RNN入门思想及实现


RNN循环神经网络


整体思想:

将整个序列划分成多个时间步,将每一个时间步的信息依次输入模型,同时将模型输出的结果传给下一个时间步,也就是说后面的结果受前面输入的影响。
在这里插入图片描述

RNN的实现公式:

在这里插入图片描述
在这里插入图片描述

个人思路:

首先设置一个初始权重h(一般情况下设置为全0矩阵),将h(0)与权重W相乘在加上输入X(0)与权重U相乘再加上偏差值b(本文设为0)的和通过tanh这个激活函数后的结果为第二层的h(1),依次类推。各层h(除了h(0)和最后一个输出的h(n)以外)组成的向量既为输出结果,h(n)为预测隐含层结果。

Pytorch模型下RNN代码实现:

import torch
import torch.nn as nn
import numpy as np"""
实现简单的神经网络
使用pytorch实现RNN
不考虑偏差值
"""
class TorchRNN(nn.Module):def __init__(self, input_size, hidden_size):super(TorchRNN, self).__init__()self.layer = nn.RNN(input_size, hidden_size, bias=False, batch_first=True)def forward(self, x):return self.layer(x)x = np.array([[1, 2, 3], [3, 4, 5], [5, 6, 7]])  #网络输入#torch实验
hidden_size = 4     #隐单元是4维 1×4矩阵
torch_model = TorchRNN(3, hidden_size)      #输入是3维,因此输出应为3×4矩阵
#print(len(torch_model.state_dict()))
w_ih = torch_model.state_dict()["layer.weight_ih_l0"]	#获得随机初始化的权重U
w_hh = torch_model.state_dict()["layer.weight_hh_l0"]	#获得随机初始化的权重Wtorch_x = torch.FloatTensor([x])
output, h = torch_model.forward(torch_x)
print(output.detach().numpy(), "torch模型预测结果")
#这里的输出第一行是第一个输入值之后得到的h,以此类推
#print(h.detach().numpy(), "torch模型预测隐含层结果")
#预测隐含层结果就是最后一个输入值得到的h

在自定义模型中的RNN的代码实现:

"""
手动实现简单的神经网络
手动实现RNN
需要将该代码放入pytorch模型中才能运行
对比理解过程
"""
#自定义RNN模型
class DiyModel:def __init__(self, w_ih, w_hh, hidden_size):self.w_ih = w_ihself.w_hh = w_hhself.hidden_size = hidden_sizedef forward(self, x):ht = np.zeros((self.hidden_size))   #初始化h为全0矩阵output = []for xt in x:ux = np.dot(self.w_ih, xt)		#函数实现wh = np.dot(self.w_hh, ht)		#函数实现ht_next = np.tanh(ux + wh)      #函数实现output.append(ht_next)      #放入output列表中ht = ht_next        #h向下更新return np.array(output), htdiy_model = DiyModel(w_ih, w_hh, hidden_size)
output, h = diy_model.forward(x)
print(output, "diy模型预测结果")
# print(h, "diy模型预测隐含层结果")

因为自定义模型中需要得到跟pytorch模型中的输入和随机初始化的W和,如果需要运行需将自定义模型代码放在pytorch模型的下面运行。同样不考虑偏差值。

结果如下:

在这里插入图片描述
可以看到两者运行结果相同,说明运算逻辑没有问题。

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.mzph.cn/news/389585.shtml

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

小扎不哭!FB又陷数据泄露风波,9000万用户受影响

对小扎来说,又是多灾多难的一个月。 继不久前Twitter曝出修补了一个可能造成数以百万计用户私密消息被共享给第三方开发人员的漏洞,连累Facebook股价跟着短线跳水之后,9月28日,Facebook又双叒叕曝出因安全漏洞遭到黑客攻击&#…

在衡量欧洲的政治意识形态时,调查规模的微小变化可能会很重要

(Related post: On a scale from 1 to 10, how much do the numbers used in survey scales really matter?)(相关文章: 从1到10的量表,调查量表中使用的数字到底有多重要? ) At Pew Research Center, survey questions about respondents’…

Pytorch中CNN入门思想及实现

CNN卷积神经网络 基础概念: 以卷积操作为基础的网络结构,每个卷积核可以看成一个特征提取器。 思想: 每次观察数据的一部分,如图,在整个矩阵中只观察黄色部分33的矩阵,将这【33】矩阵(点乘)权重得到特…

事件映射 消息映射_映射幻影收费站

事件映射 消息映射When I was a child, I had a voracious appetite for books. I was constantly visiting the library and picking new volumes to read, but one I always came back to was The Phantom Tollbooth, written by Norton Juster and illustrated by Jules Fei…

前端代码调试常用

转载于:https://www.cnblogs.com/tabCtrlShift/p/9076752.html

Pytorch中BN层入门思想及实现

批归一化层-BN层(Batch Normalization) 作用及影响: 直接作用:对输入BN层的张量进行数值归一化,使其成为均值为零,方差为一的张量。 带来影响: 1.使得网络更加稳定,结果不容易受到…

匿名内部类和匿名类_匿名schanonymous

匿名内部类和匿名类Everybody loves a fad. You can pinpoint someone’s generation better than carbon dating by asking them what their favorite toys and gadgets were as a kid. Tamagotchi and pogs? You were born around 1988, weren’t you? Coleco Electronic Q…

Pytorch框架中SGD&Adam优化器以及BP反向传播入门思想及实现

因为这章内容比较多,分开来叙述,前面先讲理论后面是讲代码。最重要的是代码部分,结合代码去理解思想。 SGD优化器 思想: 根据梯度,控制调整权重的幅度 公式: 权重(新) 权重(旧) - 学习率 梯度 Adam…

朱晔和你聊Spring系列S1E3:Spring咖啡罐里的豆子

标题中的咖啡罐指的是Spring容器,容器里装的当然就是被称作Bean的豆子。本文我们会以一个最基本的例子来熟悉Spring的容器管理和扩展点。阅读PDF版本 为什么要让容器来管理对象? 首先我们来聊聊这个问题,为什么我们要用Spring来管理对象&…

ab实验置信度_为什么您的Ab测试需要置信区间

ab实验置信度by Alos Bissuel, Vincent Grosbois and Benjamin HeymannAlosBissuel,Vincent Grosbois和Benjamin Heymann撰写 The recent media debate on COVID-19 drugs is a unique occasion to discuss why decision making in an uncertain environment is a …

基于Pytorch的NLP入门任务思想及代码实现:判断文本中是否出现指定字

今天学了第一个基于Pytorch框架的NLP任务: 判断文本中是否出现指定字 思路:(注意:这是基于字的算法) 任务:判断文本中是否出现“xyz”,出现其中之一即可 训练部分: 一&#xff…

支撑阻力指标_使用k表示聚类以创建支撑和阻力

支撑阻力指标Note from Towards Data Science’s editors: While we allow independent authors to publish articles in accordance with our rules and guidelines, we do not endorse each author’s contribution. You should not rely on an author’s works without seek…

高版本(3.9版本)python在anaconda安装opencv库及skimage库(scikit_image库)诸多问题解决办法

今天开始CV方向的学习,然而刚拿到基础代码的时候发现 from skimage.color import rgb2gray 和 import cv2标红(这里是因为我已经配置成功了,所以没有红标),我以为是单纯两个库没有下载,去pycharm中下载ski…

单机安装ZooKeeper

2019独角兽企业重金招聘Python工程师标准>>> zookeeper下载、安装以及配置环境变量 本节介绍单机的zookeeper安装,官方下载地址如下: https://archive.apache.org/dist/zookeeper/ 我这里使用的是3.4.11版本,所以找到相应的版本点…

均线交易策略的回测 r_使用r创建交易策略并进行回测

均线交易策略的回测 rR Programming language is an open-source software developed by statisticians and it is widely used among Data Miners for developing Data Analysis. R can be best programmed and developed in RStudio which is an IDE (Integrated Development…

opencv入门课程:彩色图像灰度化和二值化(采用skimage库和opencv库两种方法)

用最简单的办法实现彩色图像灰度化和二值化: 首先采用skimage库(skimage库现在在scikit_image库中)实现: from skimage.color import rgb2gray import numpy as np import matplotlib.pyplot as plt""" skimage库…

instagram分析以预测与安的限量版运动鞋转售价格

Being a sneakerhead is a culture on its own and has its own industry. Every month Biggest brands introduce few select Limited Edition Sneakers which are sold in the markets according to Lottery System called ‘Raffle’. Which have created a new market of i…

opencv:用最邻近插值和双线性插值法实现上采样(放大图像)与下采样(缩小图像)

上采样与下采样 概念: 上采样: 放大图像(或称为上采样(upsampling)或图像插值(interpolating))的主要目的 是放大原图像,从而可以显示在更高分辨率的显示设备上。 下采样&#xff…

CSS魔法堂:那个被我们忽略的outline

前言 在CSS魔法堂:改变单选框颜色就这么吹毛求疵!中我们要模拟原生单选框通过Tab键获得焦点的效果,这里涉及到一个常常被忽略的属性——outline,由于之前对其印象确实有些模糊,于是本文打算对其进行稍微深入的研究^_^ …

初创公司怎么做销售数据分析_初创公司与Faang公司的数据科学

初创公司怎么做销售数据分析介绍 (Introduction) In an increasingly technological world, data scientist and analyst roles have emerged, with responsibilities ranging from optimizing Yelp ratings to filtering Amazon recommendations and designing Facebook featu…