PyTorch 的 torch.nn.init
模块提供了一组用于初始化张量或模型参数的函数。这些初始化方法对深度学习模型的训练收敛速度和性能有显著影响,正确选择初始化方法可以避免梯度消失或爆炸等问题。
模块功能
torch.nn.init
提供了一系列函数,用于对张量(如权重或偏置)进行初始化。这些函数可以直接作用于张量,或者配合 nn.Module
的 apply
方法对模型参数进行批量初始化。
常用初始化方法
以下是 torch.nn.init
模块中常用的初始化方法及其适用场景:
1. 随机初始化
-
torch.nn.init.uniform_
将张量用均匀分布初始化。 -
torch.nn.init.uniform_(tensor, a=0.0, b=1.0)
-
- 参数:
a
和b
定义分布范围[a, b]
。 - 适用场景: 初始化偏置或某些非权重张量。
- 参数:
-
torch.nn.init.normal_
用正态分布初始化张量。 -
torch.nn.init.normal_(tensor, mean=0.0, std=1.0)
-