pytorch06:权重初始化

在这里插入图片描述

目录

  • 一、梯度消失和梯度爆炸
    • 1.1相关概念
    • 1.2 代码实现
    • 1.3 实验结果
    • 1.4 方差计算
    • 1.5 标准差计算
    • 1.6 控制网络层输出标准差为1
    • 1.7 带有激活函数的权重初始化
  • 二、Xavier方法与Kaiming方法
    • 2.1 Xavier初始化
    • 2.2 Kaiming初始化
    • 2.3 常见的初始化方法
  • 三、nn.init.calculate_gain

一、梯度消失和梯度爆炸

1.1相关概念

一个简易三层全连接神经网络图和神经元计算如下:
在这里插入图片描述
观察第二个隐藏层的权值的梯度是如何求取的,根据链式法则,可以得到如下计算公式,会发现w2的梯度依赖上一层的输出值H1;
在这里插入图片描述
当H1趋近于0的时候,W2的梯度也趋近于0;—>梯度消失
当H1趋近于无穷的时候,W2的梯度也趋近于无穷;—>梯度爆炸
在这里插入图片描述
一旦出现梯度消失或者梯度爆炸就会导致模型无法训练;

1.2 代码实现

import os
import torch
import random
import numpy as np
import torch.nn as nn
from common_tools import set_seedset_seed(1)  # 设置随机种子class MLP(nn.Module):def __init__(self, neural_num, layers):super(MLP, self).__init__()self.linears = nn.ModuleList([nn.Linear(neural_num, neural_num, bias=False) for i in range(layers)])self.neural_num = neural_numdef forward(self, x):for (i, linear) in enumerate(self.linears):x = linear(x)# x = torch.relu(x)# x = torch.tanh(x)print("layer:{}, std:{}".format(i, x.std()))  # 打印当前值的标准差if torch.isnan(x.std()):  # 判断是什么时候标准差为nanprint("output is nan in {} layers".format(i))breakreturn x# 权值初始化函数def initialize(self):for m in self.modules():if isinstance(m, nn.Linear):  # 判断当前网络层是否是线性层,如果是就进行权值初始化nn.init.normal_(m.weight.data)  # normal: mean=0, 控制标准差std在1左右# nn.init.normal_(m.weight.data, std=np.sqrt(1 / self.neural_num))# =======这段代码的目的是通过均匀分布初始化并结合tanh激活函数的特性,为神经网络的某一层(线性层)初始化合适的权重# a = np.sqrt(6 / (self.neural_num + self.neural_num))# tanh_gain = nn.init.calculate_gain('tanh')# a *= tanh_gain# nn.init.uniform_(m.weight.data, -a, a)# 将权重矩阵的值初始化为在 [-a, a] 范围内均匀分布的随机数。这个范围是通过之前的计算和调整得到的,目的是使得权重初始化在一个合适的范围内# nn.init.xavier_uniform_(m.weight.data, gain=tanh_gain)# ================凯明初始化方法================# nn.init.normal_(m.weight.data, std=np.sqrt(2 / self.neural_num))  # 适合relu激活函数初始化 凯明初始化手动计算方法# nn.init.kaiming_normal_(m.weight.data)# flag = 0
flag = 1if flag:layer_nums = 100  # 100层线性层neural_nums = 256  # 每增加一层网络 标准差扩大根号n倍batch_size = 16net = MLP(neural_nums, layer_nums)print(net)net.initialize()inputs = torch.randn((batch_size, neural_nums))  # normal: mean=0, std=1output = net(inputs)print(output)

1.3 实验结果

这里的初始化使用的是标准正态分布normal: mean=0, 控制标准差std在1左右的方法;
在这里插入图片描述
当输出层达到33层后就会出现梯度爆炸,超出了数据精度可以表示的范围。

1.4 方差计算

在这里插入图片描述
1.期望的计算公式
2,3.是方差的计算公式
根据1,2,3,可以得出,x,y的方差计算公式,当x,y的期望值都为0的时候,x,y的方差等于x的方差乘以y的方差。

1.5 标准差计算

在这里插入图片描述
通过计算可以得出每增加一层网络,标准差增加 n \sqrt{n} n ,n也就是神经元的个数;
代码展示:

if flag:layer_nums = 100  # 100层线性层neural_nums = 256  # 神经元个数 每增加一层网络 标准差扩大根号n倍batch_size = 16

执行结果:
可以看出第一层标准差是15.95,第二次标准差在上一层的基础上再乘以 256 \sqrt{256} 256
在这里插入图片描述

1.6 控制网络层输出标准差为1

从1.5可以看出D(H)的大小有三个因素决定,分别是n、D(X)、D(w),所以只要保证这三者乘积为1,就可以保证D(H)的值为1;
在这里插入图片描述
当我们权值的标准差为 1 / n \sqrt{1/n} 1/n ,那么就能保证网络层每一层的输出标准差都为1;

代码实现:
在这里插入图片描述

输出结果:
在这里插入图片描述
通过输出结果可以发现,几乎每一层网络输出的标准差都为1.

1.7 带有激活函数的权重初始化

在forward函数里面添加tanh激活函数
在这里插入图片描述
执行结果:
增加tanh激活函数之后,随着网络层的增加,标准差越来越小,从而会导致梯度消失的现象,下面将说明Xavier方法与Kaiming方法是如何解决该问题。
在这里插入图片描述

二、Xavier方法与Kaiming方法

2.1 Xavier初始化

方差一致性:保持数据尺度维持在恰当范围,通常方差为1
激活函数:饱和函数,如Sigmoid,Tanh
Xavier初始化公式如下:
在这里插入图片描述

代码实现:
手动代码实现
在这里插入图片描述

直接使用pytorch提供的xavier_uniform_函数方法

nn.init.xavier_uniform_(m.weight.data, gain=tanh_gain)

执行结果:
在这里插入图片描述
可以看到,每一层的网络输出标准差都在0.6左右

2.2 Kaiming初始化

当我们使用带有权值初始化的relu激活函数时,输出结果如下,会发现标准差随着网络层的增加逐渐减小,Kaiming初始化解决了这一问题。
在这里插入图片描述
在这里插入图片描述

方差一致性:保持数据尺度维持在恰当范围,通常方差为1
激活函数:ReLU及其变种
公式如下:
在这里插入图片描述

代码实现:

# ================凯明初始化方法================
nn.init.normal_(m.weight.data, std=np.sqrt(2 / self.neural_num))  # 适合relu激活函数初始化 凯明初始化手动计算方法
# nn.init.kaiming_normal_(m.weight.data)  # 使用pytorch自带方法

输出结果:
在这里插入图片描述

2.3 常见的初始化方法

  1. Xavier均匀分布
  2. Xavier正态分布
  3. Kaiming均匀分布
  4. Kaiming正态分布
  5. 均匀分布
  6. 正态分布
  7. 常数分布
  8. 正交矩阵初始化
  9. 单位矩阵初始化
  10. 稀疏矩阵初始化

三、nn.init.calculate_gain

主要功能:计算激活函数的方差变化尺度(也就是输入数据的方差/经过激活函数之后的方差)
主要参数
• nonlinearity: 激活函数名称
• param: 激活函数的参数,如Leaky ReLU的negative_slop

代码实现:

flag = 1if flag:x = torch.randn(10000)out = torch.tanh(x)gain = x.std() / out.std()  # 手动计算print('gain:{}'.format(gain))tanh_gain = nn.init.calculate_gain('tanh')  # pytorch自带函数print('tanh_gain in PyTorch:', tanh_gain)

输出结果:
在这里插入图片描述
总结:任何数据在经过tanh激活函数之后,方差缩小大约1.6倍。感兴趣的话也可以使用relu进行实验,最后我的到的结果方差尺度大约是1.4左右。

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.mzph.cn/news/594953.shtml

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

Golang Leetcode19 删除链表的倒数第N个节点 递归 双指针法+迭代

删除链表的倒数第N个节点 leetcode19 递归 由于本体是倒数第几个节点,非常适合递归 从终到始 的运行方式 func removeNthFromEnd(head *ListNode, n int) *ListNode {// 创建一个虚拟头节点,简化边界条件处理dummy : &ListNode{Next: head}//检查…

【Week-P4】CNN猴痘病识别

文章目录 一、环境配置二、准备数据三、搭建网络结构四、开始训练五、查看训练结果六、总结2.3 ⭐torch.utils.data.DataLoader()参数详解6.1 print()常用的三种输出格式6.2 修改网络结构,观察训练结果6.2.1 增加pool2、conv6、bn6,test_accuracy82.5%6.…

JavaScript中window.open()方法【详解】

✨前言✨   本篇文章主要针对JavaScript中window.open()方法的全面了解,以及详细参数说明使用 🍒欢迎点赞 👍 收藏 ⭐留言评论 📝私信必回哟😁 🍒博主将持续更新学习记录收获,友友们有任何问题…

canvas绘制直角梯形(向右)

查看专栏目录 canvas示例教程100专栏,提供canvas的基础知识,高级动画,相关应用扩展等信息。canvas作为html的一部分,是图像图标地图可视化的一个重要的基础,学好了canvas,在其他的一些应用上将会起到非常重…

C++ 多态向下转型详解

文章目录 1 . 前言2 . 多态3 . 向下转型3.1 子类没有改进父类的方法下,去调用该方法3.2 子类有改进父类的方法下,去调用该方法3.3 子类没有改进父类虚函数的方法下,去调用改方法3.4 子类有改进父类虚函数的方法下,去调用改方法3.5…

linux 流量监控

linux 流量监控 Linux 网络流量监控利器 iftop命令详解及实战 https://blog.csdn.net/qq_50247813/article/details/134164093 iftop命令详解 https://www.cnblogs.com/gaoyuechen/p/17300017.html 1 ubuntu如何查看流量监控 Ubuntu是一种非常流行的Linux发行版&#xff0c…

Hive10_窗口函数

窗口函数(开窗函数) 1 相关函数说明 普通的聚合函数聚合的行集是组,开窗函数聚合的行集是窗口。因此,普通的聚合函数每组(Group by)只返回一个值,而开窗函数则可为窗口中的每行都返回一个值。简单理解,就是对查询的结果多出一列…

【MySQL·8.0·源码】MySQL 表的扫描方式

前言 在进一步介绍 MySQL 优化器时,先来了解一下 MySQL 单表都有哪些扫描方式。 单表扫描方法是基表的读取基础,也是完成表连接的基础,熟悉了基表的基本扫描方式, 即可以倒推理解 MySQL 优化器层的诸多考量。 基表,即…

Windows XP SP1源代码编译方法(笔记)

NT版本 : 5.1 编译号 :2600 编译时间 : 2001年8月17日11点48分 第一步 : 搭建编译环境 使用VMWare搭建Windows XP的编译环境,注意系统要使用英文版。 第二步 : 设置编译参数 1.将原代码解压到虚拟机的XP系统中的C盘C:\NT目录或者D盘D:\NT目录下…

myysql的正则表达式

上周遇见一个需求,有这样一棵树: 点击上级,展现所有子集,点击集团,显示所有产线(例子) 这个时候有两种方式: 添加产线时,将集团、事业部、公司、车间的id存起来。 然后…

人脸关键点检测dlib安装

dlib 库是一个用来人脸关键点检测的 python 库,但因为其是 C 编写(或需要 C编译?),使得我们在安装时遇到各种各样问题。笔者在不同电脑上安装遇到的问题都不同,但最后经过搜索,都解决了&#xf…

数据结构【查找篇】

数据结构【查找篇】 文章目录 数据结构【查找篇】前言为什么突然想学算法了?为什么选择码蹄集作为刷题软件? 目录一、顺序查找二、折半查找三、 二叉排序树的查找四、红黑树 结语 前言 为什么突然想学算法了? > 用较为“官方”的语言讲&am…

Git | tag相关命令

语法命令 git tag -h usage: git tag [-a | -s | -u <key-id>] [-f] [-m <msg> | -F <file>]<tagname> [<head>]or: git tag -d <tagname>...or: git tag -l [-n[<num>]] [--contains <commit>] [--no-contains <commit&g…

普中STM32-PZ6806L开发板(HAL库函数实现-读取内部温度)

简介 主芯片STM32F103ZET6&#xff0c;读取内部温度其他知识 内部温度所在ADC通道 温度计算公式 V25跟Avg_Slope值 参考文档 stm32f103ze.pdf 电压计算公式 Vout Vref * (D / 2^n) 其中Vref代表参考电压&#xff0c; n为ADC的位数&#xff0c; D为ADC输入的数字信号。 实现…

人工智能在银行运营中的运用

机器学习在金融领域的运用&#xff1a;银行如何以最优的方式抓住 AI 机会&#xff1f; 大型企业若想获得超越竞争对手的优势&#xff0c;那么采用 AI 作为其业务战略是他们的重要任务&#xff0c;而在这方面&#xff0c;大型银行走在了前面。银行开始将 AI 和机器学习应用于前…

201.【2023年华为OD机试真题(C卷)】最长子字符串的长度(一)(滑动窗口算法-JavaPythonC++JS实现)

🚀点击这里可直接跳转到本专栏,可查阅顶置最新的华为OD机试宝典~ 本专栏所有题目均包含优质解题思路,高质量解题代码(Java&Python&C++&JS分别实现),详细代码讲解,助你深入学习,深度掌握! 文章目录 一. 题目二.解题思路三.题解代码Python题解代码JAVA题解…

了解并使用django-rest-framework-jwt

一 JWT认证 在用户注册或登录后&#xff0c;我们想记录用户的登录状态&#xff0c;或者为用户创建身份认证的凭证。我们不再使用Session认证机制&#xff0c;而使用Json Web Token&#xff08;本质就是token&#xff09;认证机制。 Json web token (JWT), 是为了在网络应用环…

什么是跨域以及怎么处理跨域问题

文章目录 什么是跨域&#xff1f;跨域问题常见场景怎么处理跨域1、配置代理2、CORS&#xff08;跨域资源共享&#xff09;3、JSONP&#xff08;仅限 GET 请求&#xff09;4、使用 WebSocket 注意事项&#xff1a; 什么是跨域&#xff1f; 跨域&#xff08;Cross-Origin&#x…

专题一_双指针(一)

文章目录 283.移动零题目解析讲解算法原理扩展编写代码 1089.复习零题目解析讲解算法原理编写代码 202.快乐数题目解析讲解算法原理证明编写代码 11.盛最多水的容器题目解析讲解算法原理暴力解法优秀的解法时间复杂度分析 编写代码 283.移动零 题目链接 题目解析 题目还是比较…

ThinkPHP6.0任意文件上传 PHPSESSION 已亲自复现

ThinkPHP6.0任意文件上传 PHPSESSION 已亲自复现 漏洞名称漏洞描述影响版本 漏洞复现环境搭建安装thinkphp6漏洞信息配置 漏洞利用 修复建议 漏洞名称 漏洞描述 2020年1月10日&#xff0c;ThinkPHP团队发布一个补丁更新&#xff0c;修复了一处由不安全的SessionId导致的任意文…