DL Homework 10

习题6-1P 推导RNN反向传播算法BPTT.
习题6-2 推导公式(6.40)和公式(6.41)中的梯度

习题6-3 当使用公式(6.50)作为循环神经网络的状态更新公式时, 分析其可能存在梯度爆炸的原因并给出解决方法.

         当然,因为我数学比较菜,我看了好半天还是没看懂怎么库库推过来的,等我慢慢研究,这个博客会定时更新的,哭死

        但是可能存在梯度爆炸的原因还是比较明确的:RNN发生梯度消失和梯度爆炸的原因如图所示,将公式改为上式后当γ<1时,t-k趋近于无穷时,γ不会趋近于零,解决了梯度消失问题,但是梯度爆炸仍然存在。当γ>1时,随着传播路径的增加,γ趋近于无穷,产生梯度爆炸。

        如果时刻t的输出yt依赖于时刻k的输入xk,当间隔t-k比较大时,简单神经网络很难建模这种长距离的依赖关系, 称为长程依赖问题(Long-Term ependencies Problem)

        由于梯度爆炸或消失问题,实际上只能学习到短周期的依赖关系。

        至于改进方案老师给出了两种,一种是较为直接的修改选取合适参数,同时使用非饱和激活函数,尽量使得diag({f}'(z_{\tau }))U^T\approx 1 需要足够的人工调参经验,限制了模型的广泛应用.

        另一种则是比较有效的改进模型

  • 权重衰减 通过给参数增加L1或L2范数的正则化项来限制参数的取值范围,从而使得 γ ≤ 1.
  • 梯度截断 当梯度的模大于一定阈值时,就将它截断成为一个较小的数
习题6-2P 设计简单RNN模型,分别用Numpy、Pytorch实现反向传播算子,并代入数值测试.

1.反向求导的函数

import numpy as np
import torch.nn# GRADED FUNCTION: rnn_cell_forward
def softmax(a):exp_a = np.exp(a)sum_exp_a = np.sum(exp_a)y = exp_a / sum_exp_areturn ydef rnn_cell_forward(xt, a_prev, parameters):Wax = parameters["Wax"]Waa = parameters["Waa"]Wya = parameters["Wya"]ba = parameters["ba"]by = parameters["by"]a_next = np.tanh(np.dot(Wax, xt) + np.dot(Waa, a_prev) + ba)yt_pred = softmax(np.dot(Wya, a_next) + by)cache = (a_next, a_prev, xt, parameters)return a_next, yt_pred, cachedef rnn_cell_backward(da_next, cache):(a_next, a_prev, xt, parameters) = cacheWax = parameters["Wax"]Waa = parameters["Waa"]Wya = parameters["Wya"]ba = parameters["ba"]by = parameters["by"]dtanh = (1 - a_next * a_next) * da_next dxt = np.dot(Wax.T, dtanh)dWax = np.dot(dtanh, xt.T)da_prev = np.dot(Waa.T, dtanh)dWaa = np.dot(dtanh, a_prev.T)dba = np.sum(dtanh, keepdims=True, axis=-1)  gradients = {"dxt": dxt, "da_prev": da_prev, "dWax": dWax, "dWaa": dWaa, "dba": dba}return gradients# GRADED FUNCTION: rnn_forward
np.random.seed(1)
xt = np.random.randn(3, 10)
a_prev = np.random.randn(5, 10)
Wax = np.random.randn(5, 3)
Waa = np.random.randn(5, 5)
Wya = np.random.randn(2, 5)
ba = np.random.randn(5, 1)
by = np.random.randn(2, 1)
parameters = {"Wax": Wax, "Waa": Waa, "Wya": Wya, "ba": ba, "by": by}a_next, yt, cache = rnn_cell_forward(xt, a_prev, parameters)da_next = np.random.randn(5, 10)
gradients = rnn_cell_backward(da_next, cache)
print("gradients[\"dxt\"][1][2] =", gradients["dxt"][1][2])
print("gradients[\"dxt\"].shape =", gradients["dxt"].shape)
print("gradients[\"da_prev\"][2][3] =", gradients["da_prev"][2][3])
print("gradients[\"da_prev\"].shape =", gradients["da_prev"].shape)
print("gradients[\"dWax\"][3][1] =", gradients["dWax"][3][1])
print("gradients[\"dWax\"].shape =", gradients["dWax"].shape)
print("gradients[\"dWaa\"][1][2] =", gradients["dWaa"][1][2])
print("gradients[\"dWaa\"].shape =", gradients["dWaa"].shape)
print("gradients[\"dba\"][4] =", gradients["dba"][4])
print("gradients[\"dba\"].shape =", gradients["dba"].shape)
gradients["dxt"][1][2] = -0.4605641030588796
gradients["dxt"].shape = (3, 10)
gradients["da_prev"][2][3] = 0.08429686538067724
gradients["da_prev"].shape = (5, 10)
gradients["dWax"][3][1] = 0.39308187392193034
gradients["dWax"].shape = (5, 3)
gradients["dWaa"][1][2] = -0.28483955786960663
gradients["dWaa"].shape = (5, 5)
gradients["dba"][4] = [0.80517166]
gradients["dba"].shape = (5, 1)

# GRADED FUNCTION: rnn_forward
def rnn_forward(x, a0, parameters):caches = []n_x, m, T_x = x.shapen_y, n_a = parameters["Wya"].shapea = np.zeros((n_a, m, T_x))y_pred = np.zeros((n_y, m, T_x))a_next = a0for t in range(T_x):a_next, yt_pred, cache = rnn_cell_forward(x[:, :, t], a_next, parameters)a[:, :, t] = a_nexty_pred[:, :, t] = yt_predcaches.append(cache)caches = (caches, x)return a, y_pred, cachesnp.random.seed(1)
x = np.random.randn(3, 10, 4)
a0 = np.random.randn(5, 10)
Waa = np.random.randn(5, 5)
Wax = np.random.randn(5, 3)
Wya = np.random.randn(2, 5)
ba = np.random.randn(5, 1)
by = np.random.randn(2, 1)
parameters = {"Waa": Waa, "Wax": Wax, "Wya": Wya, "ba": ba, "by": by}a, y_pred, caches = rnn_forward(x, a0, parameters)
print("a[4][1] = ", a[4][1])
print("a.shape = ", a.shape)
print("y_pred[1][3] =", y_pred[1][3])
print("y_pred.shape = ", y_pred.shape)
print("caches[1][1][3] =", caches[1][1][3])
print("len(caches) = ", len(caches))

 用numpy和pytorh去实现反向传播算子,并且二者对比

class RNNCell:def __init__(self, weight_ih, weight_hh,bias_ih, bias_hh):self.weight_ih = weight_ihself.weight_hh = weight_hhself.bias_ih = bias_ihself.bias_hh = bias_hhself.x_stack = []self.dx_list = []self.dw_ih_stack = []self.dw_hh_stack = []self.db_ih_stack = []self.db_hh_stack = []self.prev_hidden_stack = []self.next_hidden_stack = []# temporary cacheself.prev_dh = Nonedef __call__(self, x, prev_hidden):self.x_stack.append(x)next_h = np.tanh(np.dot(x, self.weight_ih.T)+ np.dot(prev_hidden, self.weight_hh.T)+ self.bias_ih + self.bias_hh)self.prev_hidden_stack.append(prev_hidden)self.next_hidden_stack.append(next_h)# clean cacheself.prev_dh = np.zeros(next_h.shape)return next_hdef backward(self, dh):x = self.x_stack.pop()prev_hidden = self.prev_hidden_stack.pop()next_hidden = self.next_hidden_stack.pop()d_tanh = (dh + self.prev_dh) * (1 - next_hidden ** 2)self.prev_dh = np.dot(d_tanh, self.weight_hh)dx = np.dot(d_tanh, self.weight_ih)self.dx_list.insert(0, dx)dw_ih = np.dot(d_tanh.T, x)self.dw_ih_stack.append(dw_ih)dw_hh = np.dot(d_tanh.T, prev_hidden)self.dw_hh_stack.append(dw_hh)self.db_ih_stack.append(d_tanh)self.db_hh_stack.append(d_tanh)return self.dx_listif __name__ == '__main__':np.random.seed(123)torch.random.manual_seed(123)np.set_printoptions(precision=6, suppress=True)rnn_PyTorch = torch.nn.RNN(4, 5).double()rnn_numpy = RNNCell(rnn_PyTorch.all_weights[0][0].data.numpy(),rnn_PyTorch.all_weights[0][1].data.numpy(),rnn_PyTorch.all_weights[0][2].data.numpy(),rnn_PyTorch.all_weights[0][3].data.numpy())nums = 3x3_numpy = np.random.random((nums, 3, 4))x3_tensor = torch.tensor(x3_numpy, requires_grad=True)h3_numpy = np.random.random((1, 3, 5))h3_tensor = torch.tensor(h3_numpy, requires_grad=True)dh_numpy = np.random.random((nums, 3, 5))dh_tensor = torch.tensor(dh_numpy, requires_grad=True)h3_tensor = rnn_PyTorch(x3_tensor, h3_tensor)h_numpy_list = []h_numpy = h3_numpy[0]for i in range(nums):h_numpy = rnn_numpy(x3_numpy[i], h_numpy)h_numpy_list.append(h_numpy)h3_tensor[0].backward(dh_tensor)for i in reversed(range(nums)):rnn_numpy.backward(dh_numpy[i])print("numpy_hidden :\n", np.array(h_numpy_list))print("tensor_hidden :\n", h3_tensor[0].data.numpy())print("------")print("dx_numpy :\n", np.array(rnn_numpy.dx_list))print("dx_tensor :\n", x3_tensor.grad.data.numpy())print("------")print("dw_ih_numpy :\n",np.sum(rnn_numpy.dw_ih_stack, axis=0))print("dw_ih_tensor :\n",rnn_PyTorch.all_weights[0][0].grad.data.numpy())print("------")print("dw_hh_numpy :\n",np.sum(rnn_numpy.dw_hh_stack, axis=0))print("dw_hh_tensor :\n",rnn_PyTorch.all_weights[0][1].grad.data.numpy())print("------")print("db_ih_numpy :\n",np.sum(rnn_numpy.db_ih_stack, axis=(0, 1)))print("db_ih_tensor :\n",rnn_PyTorch.all_weights[0][2].grad.data.numpy())print("------")print("db_hh_numpy :\n",np.sum(rnn_numpy.db_hh_stack, axis=(0, 1)))print("db_hh_tensor :\n",rnn_PyTorch.all_weights[0][3].grad.data.numpy())

实验结果:numpy实现和torch实现结果基本一样 

总结

本次实验主要是围绕BPTT的手推和代码(举例子推我推的很明白,但是理论硬推的时候,数学的基础是真跟不上阿,有心无力害,但是课下多努力吧,这篇博客本人写的感觉不是很好,因为数学知识不太跟得上感觉很多东西力不从心,也不算真正写完了吧,博客之后会持续更新de)

首先对于RTRL和BPTT,对于两种的学习算法要明确推导的过程(虽然我还没特别明确,半知半解)

关于梯度爆炸,梯度消失,对我们来说不陌生了,怎么能尽可能减少他们两者对我们的危害,比如梯度爆炸可以采取权重衰减和梯度截断等等,要明确梯度消失可以增加非线性等等,对于增加非线性后的容量问题,引入门控机制,LSTM等等,都应该对这块的知识有一个完整的体系。

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.mzph.cn/news/213183.shtml

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

Vue之数据绑定

在我们Vue当中有两种数据绑定的方法 1.单向绑定 2.双向绑定 让我为大家介绍一下吧&#xff01; 1、单向绑定(v-bind) 数据只能从data流向页面 举个例子&#xff1a; <!DOCTYPE html> <html lang"en"> <head><meta charset"UTF-8"…

BASH中export使用:命令行中传入变量

可以看到通过export address/project这句话 定义了一个变量address,数值为/project。

运维06:监控

监控生命周期 1.服务器上架到机柜2.基础设施监控 服务器温度&#xff0c;风扇转速 ipmitool命令&#xff0c;只能用在物理机上 存储的监控&#xff08;df, fdisk, iotop&#xff09; cpu&#xff08;lscpu, uptime, top, htop, glances&#xff09; 内存情况&#xff08;free&…

MVC Gantt Wrapper:RadiantQ jQuery

The RadiantQ jQuery Gantt Package includes fully functional native MVC Wrappers that let you declaratively and seamlessly configure the Gantt component within your aspx or cshtm pages just like any other MVC extensions. 如果您还没有准备好转向完全基于客户端…

(C++)只出现一次的数字I--异或

个人主页&#xff1a;Lei宝啊 愿所有美好如期而遇 力扣&#xff08;LeetCode&#xff09;官网 - 全球极客挚爱的技术成长平台备战技术面试&#xff1f;力扣提供海量技术面试资源&#xff0c;帮助你高效提升编程技能&#xff0c;轻松拿下世界 IT 名企 Dream Offer。https://le…

OpenAI承认ChatGPT变懒惰,正在修复该问题

OpenAI旗下的官方ChatGPT账号在社交平台表示&#xff0c;已经收到了大量用户关于GPT-4变懒惰的反馈。 这是因为自11月11日以来&#xff0c;OpenAI就没有更新过该模型。当然这不是故意的&#xff0c;大模型的行为是不可预测的&#xff0c;正在研究修复该问题。 外界猜测&#x…

企业欠税信息API:实现税务管理的智能化与高效化

前言 随着经济的发展和社会的进步&#xff0c;企业欠税问题逐渐凸显&#xff0c;成为制约经济发展的重要因素。为了解决这一问题&#xff0c;企业欠税信息API应运而生。它通过先进的技术手段&#xff0c;提供了一种全新的欠税信息查询方式&#xff0c;帮助企业实现税务管理的智…

nginx多ip部署

1.修改网卡信息自定义多个IP 进入/etc/sysconfig/network-scripts&#xff0c;编辑ifcfg-ens33网卡文件。将dhcp动态分配修改成static&#xff0c;同时添加ip地址子网掩码、网关和DNS。 修改完成后重启网卡&#xff0c;systemctl restart network 2.修改nginx配置文件 有几个…

Vue3无废话,快速上手

Vue3无废话&#xff0c;快速上手 认识Vue3 1. Vue2 选项式 API vs Vue3 组合式API <script> export default {data(){return {count:0}},methods:{addCount(){this.count}} } </script><script setup> import { ref } from vue const count ref(0) const…

【c++随笔16】reserve之后,使用std::copy会崩溃?

【c随笔16】reserve之后&#xff0c;使用std::copy会崩溃? 一、reserve之后&#xff0c;使用std::copy会崩溃?二、函数std::reserve、std::resize、std::copy1、std::resize&#xff1a;2、std::reserve&#xff1a;3、std::copy&#xff1a; 三、崩溃原因分析方案1、你可以使…

机器学习 | Python贝叶斯超参数优化模型答疑

机器学习 | Python贝叶斯超参数优化模型答疑 目录 机器学习 | Python贝叶斯超参数优化模型答疑问题汇总问题1答疑问题2答疑问题3答疑问题汇总 问题1:想问一下贝叶斯优化是什么? 问题2:为什么使用贝叶斯优化? 问题3:如何实现? 问题1答疑 超参数优化在大多数机器学习流水线…

浅析不同NAND架构的差异与影响

SSD的存储介质是什么&#xff0c;它就是NAND闪存。那你知道NAND闪存是怎么工作的吗&#xff1f;其实&#xff0c;它就是由很多个晶体管组成的。这些晶体管里面存储着电荷&#xff0c;代表着我们的二进制数据&#xff0c;要么是“0”&#xff0c;要么是“1”。NAND闪存原理上是一…

Spring日志完结篇,MyBatis操作数据库(入门)

目录 Spring可以对日志进行分目录打印 日志持久化&#xff08;让日志进行长期的保存&#xff09; MyBatis操作数据库(优秀的持久层框架) MyBatis的写法 开发规范&#xff1a; 单元测试的写法 传递参数 Spring可以对日志进行分目录打印 他的意思是说spring相关只打印INFO…

mysql中的DQL查询

表格为&#xff1a; DQL 基础查询 语法&#xff1a;select 查询列表 from 表名&#xff1a;&#xff08;查询的结果是一个虚拟表格&#xff09; -- 查询指定的列 SELECT NAME,birthday,phone FROM student -- 查询所有的列 * 所有的列&#xff0c; 查询结果是虚拟的表格&am…

中国各省、市乡村振兴水平数据(附stata计算代码,2000-2022)

数据简介&#xff1a;乡村振兴是当下经济学研究的热点之一&#xff0c;对乡村振兴进行测度&#xff0c;是研究基础。测度乡村振兴水平的学术论文广泛发表在《数量经济技术经济研究》等顶刊上。数据来源&#xff1a;主要来源于《中国农村统计年鉴》、《中国人口和就业统计年鉴》…

CRM系统选择技巧,什么样的CRM系统好用?

SaaS行业发展迅速&#xff0c;更多的企业逐渐选择CRM管理系统。打开搜索引擎&#xff0c;有非常多的结果。怎样在数十万个搜索结果中选择适合您的CRM系统&#xff1f;下面我们将聊聊&#xff0c;怎样选择CRM系统。 第一步&#xff1a;明确自身需求 重要性&#xff1a;每家企业…

仿照MyBatis手写一个持久层框架学习

首先数据准备&#xff0c;创建MySQL数据库mybatis&#xff0c;创建表并插入数据。 DROP TABLE IF EXISTS user_t; CREATE TABLE user_t ( id INT PRIMARY KEY, username VARCHAR ( 128 ) ); INSERT INTO user_t VALUES(1,Tom); INSERT INTO user_t VALUES(2,Jerry);JDBC API允…

nginx中Include使用

1.include介绍 自己的理解&#xff1a;如果学过C语言的话&#xff0c;感觉和C语言中的Include引入是一样的&#xff0c;引入的文件中可以写任何东西&#xff0c;比如server相关信息&#xff0c;相当于替换的作用&#xff0c;一般情况下server是写在nginx.conf配置文件中的&…

VR串流线方案:实现同时充电传输视频信号

VR&#xff08;Virtual Reality&#xff09;&#xff0c;俗称虚拟现实技术&#xff0c;是一项具有巨大潜力的技术创新&#xff0c;正在以惊人的速度改变我们的生活方式和体验&#xff0c;利用专门设计的设备&#xff0c;如头戴式显示器&#xff08;VR头盔&#xff09;、手柄、定…

idea 本身快捷键ctrl+d复制 无法像eclipse快捷键ctrl+alt+上下键,自动换行格式问题解决

问题 例如我使用ctrld 想复制如下内容 复制效果如下&#xff0c;没有自动换行&#xff0c;还需要自己在进行调整 解决 让如下快捷键第一个删除 修改成如下&#xff0c;将第二个添加ctrld 提示&#xff1a;对应想要修改的item&#xff0c;直接右键&#xff0c;remove是删…