pytorch06:权重初始化

在这里插入图片描述

目录

  • 一、梯度消失和梯度爆炸
    • 1.1相关概念
    • 1.2 代码实现
    • 1.3 实验结果
    • 1.4 方差计算
    • 1.5 标准差计算
    • 1.6 控制网络层输出标准差为1
    • 1.7 带有激活函数的权重初始化
  • 二、Xavier方法与Kaiming方法
    • 2.1 Xavier初始化
    • 2.2 Kaiming初始化
    • 2.3 常见的初始化方法
  • 三、nn.init.calculate_gain

一、梯度消失和梯度爆炸

1.1相关概念

一个简易三层全连接神经网络图和神经元计算如下:
在这里插入图片描述
观察第二个隐藏层的权值的梯度是如何求取的,根据链式法则,可以得到如下计算公式,会发现w2的梯度依赖上一层的输出值H1;
在这里插入图片描述
当H1趋近于0的时候,W2的梯度也趋近于0;—>梯度消失
当H1趋近于无穷的时候,W2的梯度也趋近于无穷;—>梯度爆炸
在这里插入图片描述
一旦出现梯度消失或者梯度爆炸就会导致模型无法训练;

1.2 代码实现

import os
import torch
import random
import numpy as np
import torch.nn as nn
from common_tools import set_seedset_seed(1)  # 设置随机种子class MLP(nn.Module):def __init__(self, neural_num, layers):super(MLP, self).__init__()self.linears = nn.ModuleList([nn.Linear(neural_num, neural_num, bias=False) for i in range(layers)])self.neural_num = neural_numdef forward(self, x):for (i, linear) in enumerate(self.linears):x = linear(x)# x = torch.relu(x)# x = torch.tanh(x)print("layer:{}, std:{}".format(i, x.std()))  # 打印当前值的标准差if torch.isnan(x.std()):  # 判断是什么时候标准差为nanprint("output is nan in {} layers".format(i))breakreturn x# 权值初始化函数def initialize(self):for m in self.modules():if isinstance(m, nn.Linear):  # 判断当前网络层是否是线性层,如果是就进行权值初始化nn.init.normal_(m.weight.data)  # normal: mean=0, 控制标准差std在1左右# nn.init.normal_(m.weight.data, std=np.sqrt(1 / self.neural_num))# =======这段代码的目的是通过均匀分布初始化并结合tanh激活函数的特性,为神经网络的某一层(线性层)初始化合适的权重# a = np.sqrt(6 / (self.neural_num + self.neural_num))# tanh_gain = nn.init.calculate_gain('tanh')# a *= tanh_gain# nn.init.uniform_(m.weight.data, -a, a)# 将权重矩阵的值初始化为在 [-a, a] 范围内均匀分布的随机数。这个范围是通过之前的计算和调整得到的,目的是使得权重初始化在一个合适的范围内# nn.init.xavier_uniform_(m.weight.data, gain=tanh_gain)# ================凯明初始化方法================# nn.init.normal_(m.weight.data, std=np.sqrt(2 / self.neural_num))  # 适合relu激活函数初始化 凯明初始化手动计算方法# nn.init.kaiming_normal_(m.weight.data)# flag = 0
flag = 1if flag:layer_nums = 100  # 100层线性层neural_nums = 256  # 每增加一层网络 标准差扩大根号n倍batch_size = 16net = MLP(neural_nums, layer_nums)print(net)net.initialize()inputs = torch.randn((batch_size, neural_nums))  # normal: mean=0, std=1output = net(inputs)print(output)

1.3 实验结果

这里的初始化使用的是标准正态分布normal: mean=0, 控制标准差std在1左右的方法;
在这里插入图片描述
当输出层达到33层后就会出现梯度爆炸,超出了数据精度可以表示的范围。

1.4 方差计算

在这里插入图片描述
1.期望的计算公式
2,3.是方差的计算公式
根据1,2,3,可以得出,x,y的方差计算公式,当x,y的期望值都为0的时候,x,y的方差等于x的方差乘以y的方差。

1.5 标准差计算

在这里插入图片描述
通过计算可以得出每增加一层网络,标准差增加 n \sqrt{n} n ,n也就是神经元的个数;
代码展示:

if flag:layer_nums = 100  # 100层线性层neural_nums = 256  # 神经元个数 每增加一层网络 标准差扩大根号n倍batch_size = 16

执行结果:
可以看出第一层标准差是15.95,第二次标准差在上一层的基础上再乘以 256 \sqrt{256} 256
在这里插入图片描述

1.6 控制网络层输出标准差为1

从1.5可以看出D(H)的大小有三个因素决定,分别是n、D(X)、D(w),所以只要保证这三者乘积为1,就可以保证D(H)的值为1;
在这里插入图片描述
当我们权值的标准差为 1 / n \sqrt{1/n} 1/n ,那么就能保证网络层每一层的输出标准差都为1;

代码实现:
在这里插入图片描述

输出结果:
在这里插入图片描述
通过输出结果可以发现,几乎每一层网络输出的标准差都为1.

1.7 带有激活函数的权重初始化

在forward函数里面添加tanh激活函数
在这里插入图片描述
执行结果:
增加tanh激活函数之后,随着网络层的增加,标准差越来越小,从而会导致梯度消失的现象,下面将说明Xavier方法与Kaiming方法是如何解决该问题。
在这里插入图片描述

二、Xavier方法与Kaiming方法

2.1 Xavier初始化

方差一致性:保持数据尺度维持在恰当范围,通常方差为1
激活函数:饱和函数,如Sigmoid,Tanh
Xavier初始化公式如下:
在这里插入图片描述

代码实现:
手动代码实现
在这里插入图片描述

直接使用pytorch提供的xavier_uniform_函数方法

nn.init.xavier_uniform_(m.weight.data, gain=tanh_gain)

执行结果:
在这里插入图片描述
可以看到,每一层的网络输出标准差都在0.6左右

2.2 Kaiming初始化

当我们使用带有权值初始化的relu激活函数时,输出结果如下,会发现标准差随着网络层的增加逐渐减小,Kaiming初始化解决了这一问题。
在这里插入图片描述
在这里插入图片描述

方差一致性:保持数据尺度维持在恰当范围,通常方差为1
激活函数:ReLU及其变种
公式如下:
在这里插入图片描述

代码实现:

# ================凯明初始化方法================
nn.init.normal_(m.weight.data, std=np.sqrt(2 / self.neural_num))  # 适合relu激活函数初始化 凯明初始化手动计算方法
# nn.init.kaiming_normal_(m.weight.data)  # 使用pytorch自带方法

输出结果:
在这里插入图片描述

2.3 常见的初始化方法

  1. Xavier均匀分布
  2. Xavier正态分布
  3. Kaiming均匀分布
  4. Kaiming正态分布
  5. 均匀分布
  6. 正态分布
  7. 常数分布
  8. 正交矩阵初始化
  9. 单位矩阵初始化
  10. 稀疏矩阵初始化

三、nn.init.calculate_gain

主要功能:计算激活函数的方差变化尺度(也就是输入数据的方差/经过激活函数之后的方差)
主要参数
• nonlinearity: 激活函数名称
• param: 激活函数的参数,如Leaky ReLU的negative_slop

代码实现:

flag = 1if flag:x = torch.randn(10000)out = torch.tanh(x)gain = x.std() / out.std()  # 手动计算print('gain:{}'.format(gain))tanh_gain = nn.init.calculate_gain('tanh')  # pytorch自带函数print('tanh_gain in PyTorch:', tanh_gain)

输出结果:
在这里插入图片描述
总结:任何数据在经过tanh激活函数之后,方差缩小大约1.6倍。感兴趣的话也可以使用relu进行实验,最后我的到的结果方差尺度大约是1.4左右。

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.mzph.cn/news/594953.shtml

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

Golang Leetcode19 删除链表的倒数第N个节点 递归 双指针法+迭代

删除链表的倒数第N个节点 leetcode19 递归 由于本体是倒数第几个节点,非常适合递归 从终到始 的运行方式 func removeNthFromEnd(head *ListNode, n int) *ListNode {// 创建一个虚拟头节点,简化边界条件处理dummy : &ListNode{Next: head}//检查…

【Week-P4】CNN猴痘病识别

文章目录 一、环境配置二、准备数据三、搭建网络结构四、开始训练五、查看训练结果六、总结2.3 ⭐torch.utils.data.DataLoader()参数详解6.1 print()常用的三种输出格式6.2 修改网络结构,观察训练结果6.2.1 增加pool2、conv6、bn6,test_accuracy82.5%6.…

canvas绘制直角梯形(向右)

查看专栏目录 canvas示例教程100专栏,提供canvas的基础知识,高级动画,相关应用扩展等信息。canvas作为html的一部分,是图像图标地图可视化的一个重要的基础,学好了canvas,在其他的一些应用上将会起到非常重…

C++ 多态向下转型详解

文章目录 1 . 前言2 . 多态3 . 向下转型3.1 子类没有改进父类的方法下,去调用该方法3.2 子类有改进父类的方法下,去调用该方法3.3 子类没有改进父类虚函数的方法下,去调用改方法3.4 子类有改进父类虚函数的方法下,去调用改方法3.5…

Hive10_窗口函数

窗口函数(开窗函数) 1 相关函数说明 普通的聚合函数聚合的行集是组,开窗函数聚合的行集是窗口。因此,普通的聚合函数每组(Group by)只返回一个值,而开窗函数则可为窗口中的每行都返回一个值。简单理解,就是对查询的结果多出一列…

【MySQL·8.0·源码】MySQL 表的扫描方式

前言 在进一步介绍 MySQL 优化器时,先来了解一下 MySQL 单表都有哪些扫描方式。 单表扫描方法是基表的读取基础,也是完成表连接的基础,熟悉了基表的基本扫描方式, 即可以倒推理解 MySQL 优化器层的诸多考量。 基表,即…

myysql的正则表达式

上周遇见一个需求,有这样一棵树: 点击上级,展现所有子集,点击集团,显示所有产线(例子) 这个时候有两种方式: 添加产线时,将集团、事业部、公司、车间的id存起来。 然后…

数据结构【查找篇】

数据结构【查找篇】 文章目录 数据结构【查找篇】前言为什么突然想学算法了?为什么选择码蹄集作为刷题软件? 目录一、顺序查找二、折半查找三、 二叉排序树的查找四、红黑树 结语 前言 为什么突然想学算法了? > 用较为“官方”的语言讲&am…

普中STM32-PZ6806L开发板(HAL库函数实现-读取内部温度)

简介 主芯片STM32F103ZET6,读取内部温度其他知识 内部温度所在ADC通道 温度计算公式 V25跟Avg_Slope值 参考文档 stm32f103ze.pdf 电压计算公式 Vout Vref * (D / 2^n) 其中Vref代表参考电压, n为ADC的位数, D为ADC输入的数字信号。 实现…

人工智能在银行运营中的运用

机器学习在金融领域的运用:银行如何以最优的方式抓住 AI 机会? 大型企业若想获得超越竞争对手的优势,那么采用 AI 作为其业务战略是他们的重要任务,而在这方面,大型银行走在了前面。银行开始将 AI 和机器学习应用于前…

了解并使用django-rest-framework-jwt

一 JWT认证 在用户注册或登录后,我们想记录用户的登录状态,或者为用户创建身份认证的凭证。我们不再使用Session认证机制,而使用Json Web Token(本质就是token)认证机制。 Json web token (JWT), 是为了在网络应用环…

专题一_双指针(一)

文章目录 283.移动零题目解析讲解算法原理扩展编写代码 1089.复习零题目解析讲解算法原理编写代码 202.快乐数题目解析讲解算法原理证明编写代码 11.盛最多水的容器题目解析讲解算法原理暴力解法优秀的解法时间复杂度分析 编写代码 283.移动零 题目链接 题目解析 题目还是比较…

ThinkPHP6.0任意文件上传 PHPSESSION 已亲自复现

ThinkPHP6.0任意文件上传 PHPSESSION 已亲自复现 漏洞名称漏洞描述影响版本 漏洞复现环境搭建安装thinkphp6漏洞信息配置 漏洞利用 修复建议 漏洞名称 漏洞描述 2020年1月10日,ThinkPHP团队发布一个补丁更新,修复了一处由不安全的SessionId导致的任意文…

【GlobalMapper精品教程】069:中文属性表乱码问题及解决方法

参考阅读:【ArcGIS Pro微课1000例】0012:ArcGIS Pro属性表中文乱码完美解决办法汇总 文章目录 一、Globalmapper默认字符集设置二、shp属性表乱码三、转出的kmz乱码一、Globalmapper默认字符集设置 中文字体乱码通常是由字符编码不匹配造成的。 打开Globalmapper软件,点击工…

【动态规划】【字符串】扰乱字符串

作者推荐 视频算法专题 涉及知识点 动态规划 字符串 LeetCode87扰乱字符串 使用下面描述的算法可以扰乱字符串 s 得到字符串 t : 如果字符串的长度为 1 ,算法停止 如果字符串的长度 > 1 ,执行下述步骤: 在一个随机下标处将…

java struts2教务管理系统Myeclipse开发mysql数据库struts2结构java编程计算机网页项目

一、源码特点 java struts2 教务管理系统 是一套完善的web设计系统,对理解JSP java编程开发语言有帮助 struts2 框架开发,系统具有完整的源代码和数据库,系统主要采用B/S模式开发。开发环境 为TOMCAT7.0,Myeclipse8.5开发,数据库…

【ikbp】数据可视化DataV

天天查询一些数据,希望来一个托拉拽的展示,部署体验一下可视化大屏 快速搭建快速查询实时更新简单易用 启动服务 数据可视化 静态查询 配置数据 过滤数据 分享

系列七、Ribbon

一、Ribbon 1.1、概述 Ribbon是基于Netflix Ribbon实现的一套客户端负载均衡的工具,是Netflix发布的一款开源项目,其主要功能是提供客户端的软件负载均衡算法和服务调用,Ribbon客户端组件提供一系列完善的配置项,例如&#xff1a…

《低功耗方法学》翻译——卷首语

就目前半导体的发展现状来说,我们国家还处在奋力追赶阶段。在我国半导体行业历经多轮技术制裁的今天,我们不得不承认的是,半导体技术最先进的就是美国。我国早在上世纪六七十年代就有涉足半导体技术,大量华裔留美的爱国人士回国为…

文章解读与仿真程序复现思路——电网技术EI\CSCD\北大核心《基于碳捕集与封存-电转气-电解熔融盐协同的虚拟电厂优化调度》

本专栏栏目提供文章与程序复现思路,具体已有的论文与论文源程序可翻阅本博主的专栏栏目《论文与完整程序》 这个标题涉及到多个关键概念,让我们逐一解读: 碳捕集与封存(Carbon Capture and Storage,CCS)&a…