手写一个RNN前向传播以及反向传播

前向传播

根据公式

st = tanh (Uxt + Wst-1 + ba)

ot = softmax(Vst + by )

m = 3 词的个数   n = 5

import numpy as np
import tensorflow as tf
# 单个cell 的前向传播过程
# 两个输入,x_t,s_prev,parameters
def rnn_cell_forward(x_t,s_prev,parameters):"""单个cell 的前向传播过程:param x_t: 当前T时刻的序列输入:param s_prev: 上一个cell的隐藏层状态输入:param parameters: cell中参数,字典:return: 隐层输出 s_next,out_pred,cache"""# 取出参数U = parameters["U"]W = parameters["W"]V = parameters["V"]ba = parameters["ba"]by = parameters["by"]# 根据公式计算# 隐层输出计算s_next = np.tanh(np.dot(U,x_t) + np.dot(W,s_prev) + ba)# 计算cell的输出out_pred = tf.nn.softmax(np.dot(V,s_next) + by)# 记录每层的值,用于反向传播计算使用cache = (s_next,s_prev,x_t,parameters)return s_next,out_pred,cache
if __name__ == '__main__':# forwardnp.random.seed(1)# 定义该cell的输入x_t = np.random.randn(3, 1,)s_prev = np.random.randn(5, 1)# 定义参数W = np.random.randn(5, 5)U = np.random.randn(5, 3)V = np.random.randn(3, 5)ba = np.random.randn(5, 1)by = np.random.randn(3, 1)parameters = {"U": U, "W": W, "V": V, "ba": ba, "by": by}s_next, out_pred, caches = rnn_cell_forward(x_t, s_prev, parameters)print("s_next = ", s_next)print("s_next.shapr = ", s_next.shape)print("out_pred =", out_pred)print("out_pred.shape = ",out_pred.shape)

单个cell反向传播

根据图我们能够知道需要计算的梯度变量有哪些

ds_next:表示当前cell的损失对输出s的导数

dtanh:表示当前cel的损失对激活函数的导数

dx_t:表示当前cell的损失对输入xt的导数。

dU:表示当前cell的损失对U的导数

ds_prev:表示当前cell的损失对上一个cell的输入的导数

dW:表示当前cell的损失对W的导数

dba:表示当前cell的损失对dba的导数

表示公式:

def rnn_cell_forward(x_t,s_prev,parameters):"""单个cell 的前向传播过程:param x_t: 当前T时刻的序列输入:param s_prev: 上一个cell的隐藏层状态输入:param parameters: cell中参数,字典:return: 隐层输出 s_next,out_pred,cache"""# 取出参数U = parameters["U"]W = parameters["W"]V = parameters["V"]ba = parameters["ba"]by = parameters["by"]# 根据公式计算# 隐层输出计算s_next = np.tanh(np.dot(U,x_t) + np.dot(W,s_prev) + ba)# 计算cell的输出out_pred = tf.nn.softmax(np.dot(V,s_next) + by)# 记录每层的值,用于反向传播计算使用cache = (s_next,s_prev,x_t,parameters)return s_next,out_pred,cache
def rnn_cell_backward(ds_next, cache):"""对单个cell进行反向传播:param ds_next: 当前隐层输出结果相对于损失的导数:param cache: 每个cell的缓存:return:gradients"""# 获取缓存值(s_next, s_prev, x_t, parameters) = cacheprint(type(parameters))# 获取参数U = parameters["U"]W = parameters["W"]# V = parameters["V"]# ba = parameters["ba"]# by = parameters["by"]# 计算tanh的梯度通过对s_nextdtanh = (1 - s_next ** 2) * ds_next# 计算U的梯度值dx_t = np.dot(U.T, dtanh)dU = np.dot(dtanh, x_t.T)# 计算W的梯度值ds_prev = np.dot(W.T, dtanh)dW = np.dot(dtanh, s_prev.T)# 计算b的梯度dba = np.sum(dtanh,axis=1,keepdims= 1)# 梯度字典gradients = {"dtanh" : dtanh,"dx_t": dx_t, "ds_prev": ds_prev, "dU": dU, "dW": dW, "dba": dba}return gradients

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.mzph.cn/diannao/4262.shtml

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

运算符重载(1)

1.加号运算符重载&#xff0c;这里用编译器统一的名称operator代替函数名 #include<iostream> using namespace std; //1.成员函数的加号重载 //2.全局函数的加号重载 class Person { public:Person() {};//1.成员函数的加号重载//Person operator(Person& p)//{// P…

前端HTML5学习2(新增多媒体标签,H5的兼容性处理)

前端HTML5学习2新增多媒体标签&#xff0c;H5的兼容性处理&#xff09; 分清标签和属性新增多媒体标签新增视频标签新增音频标签新增全局属性 H5的兼容性处理 分清标签和属性 标签&#xff08;HTML元素&#xff09;和属性&#xff0c;标签定义了内容的类型或结构&#xff0c;而…

k8s学习(三十七)centos下离线部署kubernetes1.30(高可用)

文章目录 准备工作1、升级操作系统内核1.1、查看操作系统和内核版本1.2、下载内核离线升级包1.3、升级内核1.4、确认内核版本 2、修改主机名/hosts文件2.1、修改主机名2.2、修改hosts文件 3、关闭防火墙4、关闭SELINUX配置5、时间同步5.1、下载NTP5.2、卸载5.3、安装5.4、配置5…

BPE、Wordpiece、Unigram、SpanBERT等Tokenizer细节总结

BPE(Byte Pair Encoding) GPT-2和Roberta用的是这种&#xff0c;不会产生[UNK]这个unknown字符 这部分部分摘录自https://martinlwx.github.io/zh-cn/the-bpe-tokenizer/ 看以下code例子就足够理解了&#xff0c;核心是维护self.merges&#xff08;维护一个pair->str的字…

[蓝桥杯2024]-Reverse:rc4解析(对称密码rc4)

无壳 查看ida 这里应该运行就可以得flag&#xff0c;但是这个程序不能直接点击运行 按照伪代码写exp 完整exp&#xff1a; keylist(gamelab) content[0xB6,0x42,0xB7,0xFC,0xF0,0xA2,0x5E,0xA9,0x3D,0x29,0x36,0x1F,0x54,0x29,0x72,0xA8, 0x63,0x32,0xF2,0x44,0x8B,0x85,0x…

如何在 Visual Studio 中通过 NuGet 添加包

在安装之前要先确定Nuget的包源是否有问题。 Visual Studio中怎样更改Nuget程序包源-CSDN博客 1.图形界面安装 打开您的项目&#xff0c;并在解决方案资源管理器中选择您的项目。单击“项目”菜单&#xff0c;然后选择“管理 NuGet 程序包”选项。在“NuGet 包管理器”窗口中…

详解如何品味品深茶的精髓

在众多的茶品牌中&#xff0c;品深茶以其独特的韵味和深厚的文化底蕴&#xff0c;赢得了众多茶友的喜爱。今天&#xff0c;让我们一同探寻品深茶的精髓&#xff0c;品味其独特的魅力。 品深茶&#xff0c;源自中国传统茶文化的精髓&#xff0c;承载着世代茶人的智慧与匠心。这…

03-MVC执行流程-参数解析与Model

重要组件 准备Model&#xff0c;Controller Configuration public class WebConfig {ControllerAdvicestatic class MyControllerAdvice {ModelAttribute("b")public String bar() {return "bar";}}Controllerstatic class Controller1 {ResponseStatus(H…

windows环境下安装Apache

首先apache官网下载地址&#xff1a;http://www.apachelounge.com/download/按照自己的电脑操作系统来安装 这里我安装的是win64 主版本是2.4的apache。 然后解压压缩包到一个全英文的路径下&#xff01;&#xff01;&#xff01;一定一定不要有中文 中文符号也不要有&#xff…

ansible-copy用法

目录 概述实践不带目录拷贝带目录拷贝 概述 ansible copy 常用用法举例 不带目录拷贝&#xff0c;拷贝的地址要写全 带目录拷贝&#xff0c;拷贝路径不要写在 dest 路径中 实践 不带目录拷贝 # with_fileglob 是 Ansible 中的一个循环关键字&#xff0c;用于处理文件通配符匹…

【Vue3+Tres 三维开发】02-Debug

预览 介绍 Debug 这里主要是讲在三维中的调试,同以前threejs中使用的lil-gui类似,TRESJS也提供了一套可视化参数调试的插件。使用方式和之前的组件相似。 使用 通过导入useTweakPane 即可 import { useTweakPane, OrbitControls } from "@tresjs/cientos"const {…

数字文旅重塑旅游发展新格局:以数字化转型为突破口,提升旅游服务的智能化水平,为游客带来全新的旅游体验

随着信息技术的迅猛发展&#xff0c;数字化已成为推动各行各业创新发展的重要力量。在旅游业领域&#xff0c;数字文旅的兴起正以其强大的驱动力&#xff0c;重塑旅游发展的新格局。数字文旅以数字化转型为突破口&#xff0c;通过提升旅游服务的智能化水平&#xff0c;为游客带…

HarmonyOS Next从入门到精通实战精品课

第一阶段&#xff1a;HarmonyOS Next星河版从入门到精通该阶段由HarmonyOS Next星河版本出发&#xff0c;介绍HarmonyOS Next版本应用开发基础概念&#xff0c;辅助学员快速上手新版本开发范式&#xff0c;共计42课时 第一天鸿蒙NEXT Mac版、Windows版【编辑器】和【模拟器】&a…

BootStrap详解

Bootstrap简介 什么是BootStrap&#xff1f; BootStrap来自Twitter&#xff0c;是目前最受欢迎的响应式前端框Bootstrap是基于HTML、CSS、JavaScript的&#xff0c;它简洁灵活&#xff0c;使得Web开发更加快捷 为什么使用Bootstrap&#xff1f; 移动设备优先&#xff1a;自…

Kafka 3.x.x 入门到精通(07)——Java应用场景——SpringBoot集成

Kafka 3.x.x 入门到精通&#xff08;07&#xff09;——Java应用场景——SpringBoot集成 4. Java应用场景——SpringBoot集成4.1 创建SpringBoot项目4.1.1 创建SpringBoot项目4.1.2 修改pom.xml文件4.1.3 在resources中增加application.yml文件 4.2 编写功能代码4.2.1 创建配置…

机器人-轨迹规划

旋转矩阵 旋转矩阵--R--一个3*3的矩阵&#xff0c;其每列的值时B坐标系在A坐标系上的投影值。 代表B坐标系相对于A坐标系的姿态。 旋转矩阵的转置矩阵 其实A相对于B的旋转矩阵就相当于把B的列放到行上就行。 视频 &#xff08;将矩阵的行列互换得到的新矩阵称为转置矩阵。&…

SQLite尽如此轻量

众所周知&#xff0c;SQLite是个轻量级数据库&#xff0c;适用于中小型服务应用等&#xff0c;在我真正使用的时候才发现&#xff0c;它虽然轻量&#xff0c;但不知道它却如此轻量。 下载 官网&#xff1a; SQLite Download Page 安装 1、将下载好的两个压缩包同时解压到一个…

【PG-2】PostgreSQL存储管理器

2. PostgreSQL存储管理器 src/backend/storage (base) torrestorresの机革:~/codes/postgresql-16.2/src/backend/storage$ ls Makefile buffer file freespace ipc large_object lmgr meson.build objfiles.txt page smgr sync存储管理器—smgr 通用存储管理器 …

航拍图像拼接 | 使用C++实现的无人机航拍图像拼接

项目应用场景 面向无人机航拍图像拼接场景&#xff0c;项目使用 C 实现&#xff0c;使用 harris 角点查找特征点 非极大值抑制&#xff0c;由于航拍图像没有严重的尺度旋转变化&#xff0c;使用了 berief 描述子&#xff0c;然后使用 RANSAC 求 H&#xff0c;最后进行图像拼接…

linux 中 make 和 gmake的关系

1. 关系 gmake特指GNU make。 make是指系统默认的make实现; 在大多数Linux发行版中&#xff0c;make就是GNU make&#xff0c;但是在其他unix中&#xff0c;gmake可以指代make的某些其他实现&#xff0c;例如BSD make或各种商业unix的make实现。 gmake是GNU Make的缩写。 Linux…