PyTorch常用参数初始化方法详解

Python微信订餐小程序课程视频

https://edu.csdn.net/course/detail/36074

Python实战量化交易理财系统

https://edu.csdn.net/course/detail/35475

1、均匀分布初始化

torch.nn.init.uniform_(tensor, a=0, b=1)

从均匀分布U(a, b)中采样,初始化张量。  参数:

    • tensor - 需要填充的张量
      • a - 均匀分布的下界
      • b - 均匀分布的上界

例子

w = torch.empty(3, 5)
nn.init.uniform\_(w)
"""
tensor([[0.2116, 0.3085, 0.5448, 0.6113, 0.7697],[0.8300, 0.2938, 0.4597, 0.4698, 0.0624],[0.5034, 0.1166, 0.3133, 0.3615, 0.3757]])
"""

均匀分布详解

若 xxx 服从均匀分布,即 x U(a,b)x U(a,b)x~U(a,b),其概率密度函数(表征随机变量每个取值有多大的可能性)为,

f(x)={1b−a,a<x<b0,elsef(x)={1b−a,a<x<b0,elsef(x)=\left{\begin{array}{l}\frac{1}{b-a}, \quad a<x<b \ 0, \quad else \end{array}\right.

则有期望和方差,

E(x)=∫∞−∞xf(x)dx=12(a+b)D(x)=E(x2)−[E(x)]2=(b−a)212E(x)=∫∞−∞xf(x)dx=12(a+b)D(x)=E(x2)−[E(x)]2=(b−a)212\begin{array}{c}E(x)=\int_{-\infty}^{\infty} x f(x) d x=\frac{1}{2}(a+b) \D(x)=E\left(x{2}\right)-[E(x)]{2}=\frac{(b-a)^{2}}{12}\end{array}

2、正态(高斯)分布初始化

torch.nn.init.normal_(tensor, mean=0.0, std=1.0)

从给定的均值和标准差的正态分布 N(mean,std2)N(mean,std2)N\left(\right. mean, \left.s t d^{2}\right) 中生成值,初始化张量。

参数:

    • tensor - 需要填充的张量
      • mean - 正态分布的均值
      • std - 正态分布的标准偏差

例子

w = torch.Tensor(3, 5)
torch.nn.init.normal\_(w, mean=0, std=1)
"""
tensor([[-1.3903, 0.4045, 0.3048, 0.7537, -0.5189],[-0.7672, 0.1891, -0.2226, 0.2913, 0.1295],[ 1.4719, -0.3049, 0.3144, -1.0047, -0.5424]])
"""

正态分布详解:

若随机变量 xxx 服从正态分布,即 x∼N(μ,σ2)x∼N(μ,σ2)x \sim N\left(\mu, \sigma^{2}\right) , 其概率密度函数为,

f(x)=1σ√2πexp(−(x−μ2)2σ2)f(x)=\frac{1}{\sigma \sqrt{2 \pi}} \exp \left(-\frac{\left(x-\mu^{2}\right)}{2 \sigma^{2}}\right)

正态分布概率密度函数中一些特殊的概率值:

    • 68.268949% 的面积在平均值左右的一个标准差 σ\sigma 范围内 (μ±σ\mu \pm \sigma)
      • 95.449974% 的面积在平均值左右两个标准差 2σ2 \sigma 的范围内 (μ±2σ\mu \pm 2 \sigma)
      • 99.730020% 的面积在平均值左右三个标准差 3σ3 \sigma 的范围内 (μ±3σ\mu \pm 3 \sigma)
      • 99.993666% 的面积在平均值左右四个标准差 4σ4 \sigma 的范围内 (μ±4σ\mu \pm 4 \sigma)

μ=0\mu=0, σ=1\sigma=1 时的正态分布是标准正态分布。

3. Xavier初始化

3.1 Xavier均匀分布初始化

torch.nn.init.xavier_uniform_(tensor, gain=1.0)

又称 Glorot 初始化,按照 Glorot, X. & Bengio, Y.(2010)在论文Understanding the difficulty of training deep feedforward neural networks 中描述的方法,从均匀分布 U(−a,a)U(−a, a) 中采样,初始化输入张量 tensortensor,其中 aa 值由下式确定:

a= gain ×√6 fan_in + fan_out a=\text { gain } \times \sqrt{\frac{6}{\text { fan_in }+\text { fan_out }}}

例子

w = torch.Tensor(3, 5)
nn.init.xavier\_uniform\_(w, gain=torch.nn.init.calculate\_gain('relu'))
"""
tensor([[ 0.7695, -0.7687, -0.2561, -0.5307, 0.5195],[-0.6187, 0.4913, 0.3037, -0.6374, 0.9725],[-0.2658, -0.4051, -1.1006, -1.1264, -0.1310]])
"""

3.2 Xavier正态分布初始化

torch.nn.init.xavier_normal_(tensor, gain=1.0)

又称 Glorot 初始化,按照 Glorot, X. & Bengio, Y.(2010)在论文Understanding the difficulty of training deep feedforward neural networks 中描述的方法,从均匀分布 N(0,std2)N\left(0, s t d^{2}\right) 中采样,初始化输入张量 tensortensor,其中 stdstd 值由下式确定:

std= gain ×√2 fan_in + fan_out \operatorname{std}=\text { gain } \times \sqrt{\frac{2}{\text { fan_in }+\text { fan_out }}}

参数:

    • tensor - 需要初始化的张量
      • gain - 可选的放缩因子

例子

w = torch.arange(10).view(2,-1).type(torch.float32)
torch.nn.init.xavier\_normal\_(w)
"""
tensor([[-0.3139, -0.3557, 0.1285, -0.9556, 0.3255],[-0.6212, 0.3405, -0.4150, -1.3227, -0.0069]])
"""

4. kaiming初始化

4.1 kaiming均匀分布初始化

torch.nn.init.kaiming_uniform_(tensor, a=0, mode='fan_in', nonlinearity='leaky_relu')

又称 He 初始化,按照He, K. et al. (2015)在论文Delving deep into rectifiers: Surpassing human-level performance on ImageNet classification中描述的方法,从均匀分布U(−bound,bound)U(−bound, bound) 中采样,初始化输入张量 tensor,其中 bound 值由下式确定:

bound = gain ×√3 fan_mode \text { bound }=\text { gain } \times \sqrt{\frac{3}{\text { fan_mode }}}

参数:

    • tensor - 需要初始化的张量;
      • a\mathrm{a}- 这层之后使用的 rectifier的斜率系数,用来计算gain =\sqrt{\frac{2}{1+\mathrm{a}^{2}}} (此参数仅在参数nonlinea rity为’leaky_relu’时生效);
      • mode - 可以为“fan_in”(默认)或“fan_out”。“fan_in”维持前向传播时权值方差,“fan_out”维持反向传播时的方差;
      • nonlinearity - 非线性函数(nn.functional中的函数名),pytorch建议仅与“relu”或“leaky_relu”(默认)一起使用;

例子

w = torch.Tensor(3, 5)
torch.nn.init.kaiming\_uniform\_(w, mode='fan\_in', nonlinearity='relu')
"""
tensor([[-0.4362, -0.8177, -0.7034, 0.7306, -0.6457],[-0.5749, -0.6480, -0.8016, -0.1434, 0.0785],[ 1.0369, -0.0676, 0.7430, -0.2484, -0.0895]])
"""

4.2 kaiming正态分布初始化

torch.nn.init.kaiming_normal_(tensor, a=0, mode='fan_in', nonlinearity='leaky_relu')

又称He初始化,按照He, K. et al. (2015)在论文Delving deep into rectifiers: Surpassing human-level performance on ImageNet classification中描述的方法,从正态分布 N(0,std2)N\left(0, s t d^{2}\right) 中采样,初始化输入张量tensor,其中std值由下式确定:

参数:

    • tensor - 需要初始化的张量;
      • a\mathrm{a} - 这层之后使用的 rectifier 的斜率系数,用来计算 gain=√21+a2gain =\sqrt{\frac{2}{1+\mathrm{a}^{2}}} (此参数仅在参数nonlinea rity为’leaky_relu’时生效);
      • mode - 可以为"fan_in" (默认) 或“fan_out"。"fan_in"维持前向传播时权值方差,"fan_out"维持反 向传播时的方差;
      • nonlinearity - 非线性函数 (nn.functional中的函数名),pytorch建议仅与“relu”或"leaky_relu”(默 认)一起使用;

5、正交矩阵初始化

torch.nn.init.orthogonal_(tensor, gain=1)

用一个(半)正交矩阵初始化输入张量,参考Saxe, A. et al. (2013) - Exact solutions to the nonlinear dynamics of learning in deep linear neural networks。输入张量必须至少有 2 维,对于大于 2 维的张量,超出的维度将被flatten化。

正交初始化可以使得卷积核更加紧凑,可以去除相关性,使模型更容易学到有效的参数。

参数:

    • tensor - 需要初始化的张量
      • gain - 可选的放缩因子

例子:

w = torch.Tensor(3, 5)
torch.nn.init.orthogonal\_(w)
"""
tensor([[ 0.7395, -0.1503, 0.4474, 0.4321, -0.2090],[-0.2625, 0.0112, 0.6515, -0.4770, -0.5282],[ 0.4554, 0.6548, 0.0970, -0.4851, 0.3453]])
"""

6、稀疏矩阵初始化

torch.nn.init.sparse_(tensor, sparsity, std=0.01)

将2维的输入张量作为稀疏矩阵填充,其中非零元素由正态分布 N(0,0.012)N\left(0,0.01^{2}\right) 生成。 参考Martens, J.(2010)的 Deep learning via Hessian-free optimization。

参数:

    • tensor - 需要填充的张量
      • sparsity - 每列中需要被设置成零的元素比例
      • std - 用于生成非零元素的正态分布的标准偏差

例子:

w = torch.Tensor(3, 5)
torch.nn.init.sparse\_(w, sparsity=0.1)
"""
tensor([[-0.0026, 0.0000, 0.0100, 0.0046, 0.0048],[ 0.0106, -0.0046, 0.0000, 0.0000, 0.0000],[ 0.0000, -0.0005, 0.0150, -0.0097, -0.0100]])
"""

7、常数初始化

torch.nn.init.constant_(tensor, val)

使值为常数 val 。

例子:

w=torch.Tensor(3,5)
nn.init.constant\_(w,1.2)
"""
tensor([[1.2000, 1.2000, 1.2000, 1.2000, 1.2000],[1.2000, 1.2000, 1.2000, 1.2000, 1.2000],[1.2000, 1.2000, 1.2000, 1.2000, 1.2000]])
"""

8、单位矩阵初始化

torch.nn.init.eye_(tensor)

将二维 tensor 初始化为单位矩阵(the identity matrix)

例子:

w=torch.Tensor(3,5)
nn.init.eye\_(w)
"""
tensor([[1., 0., 0., 0., 0.],[0., 1., 0., 0., 0.],[0., 0., 1., 0., 0.]])
"""

9、零填充初始化

torch.nn.init.zeros_(tensor)

例子:

w = torch.empty(3, 5)
nn.init.zeros\_(w)
"""
tensor([[0., 0., 0., 0., 0.],[0., 0., 0., 0., 0.],[0., 0., 0., 0., 0.]])
"""

10、应用

例子:

print('module-----------')
print(model)
print('setup-----------')
for m in model.modules():if isinstance(m,nn.Linear):nn.init.xavier\_uniform\_(m.weight, gain=nn.init.calculate\_gain('relu'))
"""
module-----------
Sequential((flatten): FlattenLayer()(linear1): Linear(in\_features=784, out\_features=512, bias=True)(activation): ReLU()(linear2): Linear(in\_features=512, out\_features=256, bias=True)(linear3): Linear(in\_features=256, out\_features=10, bias=True)
)
setup-----------
"""

例子:

for param in model.parameters():nn.init.uniform\_(param)

例子:

def weights\_init(m):classname = m.\_\_class\_\_.\_\_name\_\_if classname.find('Conv2d') != -1:nn.init.xavier\_normal\_(m.weight.data)nn.init.constant\_(m.bias.data, 0.0)elif classname.find('Linear') != -1:nn.init.xavier\_normal\_(m.weight)nn.init.constant\_(m.bias, 0.0)
model.apply(weights\_init) #apply函数会递归地搜索网络内的所有module并把参数表示的函数应用到所有的module上。

  • 1、均匀分布初始化

  • 2、正态(高斯)分布初始化

  • 3. Xavier初始化

  • 3.1 Xavier均匀分布初始化

  • 3.2 Xavier正态分布初始化

  • 4. kaiming初始化

  • 4.1 kaiming均匀分布初始化

  • 4.2 kaiming正态分布初始化

  • 5、正交矩阵初始化

  • 6、稀疏矩阵初始化

  • 7、常数初始化

  • 8、单位矩阵初始化

  • 9、零填充初始化

  • 10、应用

    __EOF__

    [外链图片转存失败,源站可能有防盗链机制,建议将图片保存下来直接上传(img-xU3OTp9P-1646759404745)(https://blog.csdn.net/BlairGrowing)]Blair - 本文链接: https://blog.csdn.net/BlairGrowing/p/15981694.html

  • 关于博主: 评论和私信会在第一时间回复。或者直接私信我。
  • 版权声明: 本博客所有文章除特别声明外,均采用 BY-NC-SA 许可协议。转载请注明出处!
  • 声援博主: 如果您觉得文章对您有帮助,可以点击文章右下角**【[推荐](javascript:void(0)😉】**一下。

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.mzph.cn/news/401228.shtml

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

sql语句中的删除操作

drop: drop table tb; 删除内容和定义&#xff0c;释放空间。简单来说就是把整个表去掉。以后不能再新增数据&#xff0c;除非新增一个表。 truncate&#xff1a; truncate table tb; 删除内容、释放空间但不删除定义&#xff0c;即只是清空数&#xff0c;不会删除表的数据结构…

[面试题]事件循环经典面试题解析

Python微信订餐小程序课程视频 https://edu.csdn.net/course/detail/36074 Python实战量化交易理财系统 https://edu.csdn.net/course/detail/35475 基础概念 进程是计算机已经运行的程序,线程是操作系统能够进行运算调度的最小单位,它被包含在进程中.浏览器中每开一个Tab…

CC254x--OSAL

OSAL运行原理 蓝牙协议栈PROFILE、所有的应用程序、驱动等都是围绕着OSAL组织运行的。OSAL&#xff08;Operating System Abstraction Layer&#xff09;操作系统抽象层&#xff0c;它不是一个真正的操作系统&#xff08;它没有 Context Switch 上下文切换功能&#xff09;&am…

CLR 与 C++的常用类型转换笔记

1. System::String 转换到 const wchar_t* const wchar_t* ToUnmanagedUnicode( System::String^ str ){ pin_ptr<const WCHAR> nativeString1 PtrToStringChars( str ); return (const wchar_t*)nativeString1;} 2. const wchar_t* / const char* 转换到 System::Strin…

mysql跨节点join——federated引擎

一、 什么是federated引擎 mysql中的federated类似于oracle中的dblink。 federated是一个专门针对远程数据库的实现&#xff0c;一般情况下在本地数据库中建表会在数据库目录中生成相对应的表定义文件&#xff0c;并同时生成相对应的数据文件。 [图] 但是通过federated引擎创建…

【阅读SpringMVC源码】手把手带你debug验证SpringMVC执行流程

Python微信订餐小程序课程视频 https://edu.csdn.net/course/detail/36074 Python实战量化交易理财系统 https://edu.csdn.net/course/detail/35475 ✿ 阅读源码思路&#xff1a; 先跳过非重点&#xff0c;深入每个方法&#xff0c;进入的时候可以把整个可以理一下方法的执…

Zabbix监控(十六):分布式监控-Zabbix Proxy

说明&#xff1a;Zabbix支持分布式监控&#xff0c;利用Proxy代理功能&#xff0c;在其他网络环境中部署代理服务器&#xff0c;将监控数据汇总到Zabbix主服务器&#xff0c;实现多网络的分布式监控&#xff0c;集中监控。1、分布式监控原理Zabbix proxy和Zabbix server一样&am…

CC254x--BLE

BLE协议栈 BLE体系结构&#xff0c;着重了解GAP和GATT。 PHY物理层在2.4GHz的ISM频段中跳频识别。LL连接层&#xff1a;控制设备的状态。设备可能有5中状态&#xff1a;就绪standby&#xff0c;广播advertising&#xff0c;搜索scanning&#xff0c;初始化initiating和连接con…

DW 在onload运行recordset find.html时 发生了以下javascript错误

这两天打开Dreamweaver CS5&#xff0c;总是弹出一个错误&#xff0c;写着&#xff1a; 在onLoad运行RecordsetFind.htm时&#xff0c;发生了以下JavaScript错误&#xff1a; 在文件“RecordsetFind”中&#xff1a; findRsisnotdefined 在关闭Dreamweaver的时候也会弹出一个类…

Azure Container App(一)应用介绍

Python微信订餐小程序课程视频 https://edu.csdn.net/course/detail/36074 Python实战量化交易理财系统 https://edu.csdn.net/course/detail/35475 一&#xff0c;引言 容器技术正日益成为打包、部署应用程序的第一选择。Azure 提供了许多使用容器的选项。例如&#xff0…

CC254x--API

CC2541常用API 连接 定义广播数据 GAPRole_SetParameter(GAPROLE_ADVERT_DATA,…); 自定义扫描响应数据 GAPRole_SetParameter(GAPROLE_SCAN_RSP_DATA,…); 密码管理回调 ProcessPasscodeCB(); 状态管理回调 peripheralStateNotificationCB(); 通信控制 添加GATT服务 GATTServ…

mysql 分区信息查看

select partition_name part,partition_expression expr,partition_description descr,table_rows from INFORMATION_SCHEMA.PARTITIONS where TABLE_SCHEMASCHEMA() AND TABLE_NAMEmx_domain//查看分区信息 CREATE TABLE mx_domain (id int(10) NOT NULL AUTO_INCREMENT,name…

怎样配置键盘最方便,以及一些设计的思考

使用Emacs的人&#xff0c;如果肯折腾&#xff0c;肯定有重新映射键盘的经历。我原来经常看到的是把Ctrl和Capslock交换&#xff0c;但是我感觉没什么道理&#xff0c;因为Ctrl已经用的很熟练了&#xff0c;换了反而不方便&#xff0c;而且对其他程序影响太大。那么我们就要使用…

profile、服务、特征、属性之间的关系

一个profile有很多的服务&#xff0c;一个服务又有很多的特性&#xff0c;一个特性中又有几种属性条目组成。 profile&#xff08;数据配置文件&#xff09; 一个profile文件可以包含一个或者多个服务&#xff0c;一个profile文件包含需要的服务的信息或者为对等设备如何交互的…

面试突击32:为什么创建线程池一定要用ThreadPoolExecutor?

Python微信订餐小程序课程视频 https://edu.csdn.net/course/detail/36074 Python实战量化交易理财系统 https://edu.csdn.net/course/detail/35475 在 Java 语言中&#xff0c;并发编程都是依靠线程池完成的&#xff0c;而线程池的创建方式又有很多&#xff0c;但从大的分类…

Bootstrap datepicker 在弹出窗体modal中不工作

解决办法 在 show 方法后面 添加 下面一段代码 $(#modalCard).modal(show);—例子 打开 弹出窗体 //$(#modalCard).modal(hide); $(#modalCard).on(shown.bs.modal, function () { //$(.input-group.date).datetimepicker({ $(#dpReceiveDate).datetimepicker({ format: "…

学习Samba基础命令详解之大话西游01

服务名:smb配置目录:/etc/sabma/主配置文件:/etc/sabma/smb.conf# Global Settings 17行workgroup语法 workgtoup <工作组群>; 预设 workgroup MYGROUP 说明 设定 Samba Server 的工作组 例 workgroup workgroup 和WIN2000S设为一个组&#xff0c;可在网上邻居可中看到…

实例讲解getopt()函数的使用

[cpp] view plaincopy #include <stdio.h> #include <unistd.h> int main(int argc, char *argv[]) { extern char *optarg;//保存选项的参数 extern int optind, opterr, optopt; int ch; printf("\n\n"); pri…

机器学习实战 | SKLearn最全应用指南

Python微信订餐小程序课程视频 https://edu.csdn.net/course/detail/36074 Python实战量化交易理财系统 https://edu.csdn.net/course/detail/35475 作者&#xff1a;韩信子ShowMeAI 教程地址&#xff1a;http://www.showmeai.tech/tutorials/41 本文地址&#xff1a;http…