Category: "PyTorch Functions in Plain Terms" series index
Related articles:
· PyTorch Functions in Plain Terms: torch.nn.Module
Recursively applies the function fn to every submodule (as returned by .children()) and then to self. A typical use is initializing a model's parameters (see torch.nn.init).
Syntax
torch.nn.Module.apply(fn)
Parameters
fn (Module -> None): the function applied to each submodule
Return value
self, that is, the torch.nn.Module on which apply was called
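Because apply returns self, the call can be assigned or chained without creating a copy. The following is a minimal sketch of that behavior; the helper name zero_bias and the layer sizes are illustrative assumptions, not part of the original article.

```python
import torch
import torch.nn as nn


@torch.no_grad()
def zero_bias(m: nn.Module) -> None:
    # Hypothetical helper: only touch Linear layers that actually have a bias.
    if isinstance(m, nn.Linear) and m.bias is not None:
        m.bias.zero_()


net = nn.Sequential(nn.Linear(2, 2), nn.Linear(2, 2))

# apply modifies the modules in place and hands back the same object.
returned = net.apply(zero_bias)
same_object = returned is net
print(same_object)
```

Since the returned object is net itself, a pattern such as model = build_model().apply(fn) works, where build_model stands for whatever constructor is in use.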
Example
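import torch
import torch.nn as nn

@torch.no_grad()
def init_weights(m):
    # Print every visited module; fill the weights of Linear layers with 1.0.
    print(m)
    if type(m) == nn.Linear:
        m.weight.fill_(1.0)
        print(m.weight)

net = nn.Sequential(nn.Linear(2, 2), nn.Linear(2, 2))
net.apply(init_weights)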
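The example above fills weights with a constant. In practice, initialization more often goes through torch.nn.init. The sketch below shows one possible variant, assuming Xavier-uniform weights and zero biases are wanted; the choice of initializer, the seed, and the layer sizes are assumptions made for illustration.

```python
import torch
import torch.nn as nn


def init_weights(m: nn.Module) -> None:
    # Dispatch on the module type; modules without parameters (e.g. ReLU) are skipped.
    if isinstance(m, nn.Linear):
        nn.init.xavier_uniform_(m.weight)
        if m.bias is not None:
            nn.init.zeros_(m.bias)


torch.manual_seed(0)  # assumed seed, only to make the run repeatable
model = nn.Sequential(nn.Linear(4, 8), nn.ReLU(), nn.Linear(8, 1))
model.apply(init_weights)

print(model[0].bias)
```

The functions in torch.nn.init modify tensors in place without recording gradients, so the @torch.no_grad() decorator is not needed in this variant.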
Implementation
def apply(self: T, fn: Callable[['Module'], None]) -> T:
    r"""Applies ``fn`` recursively to every submodule (as returned by ``.children()``)
    as well as self. Typical use includes initializing the parameters of a model
    (see also :ref:`nn-init-doc`).

    Args:
        fn (:class:`Module` -> None): function to be applied to each submodule

    Returns:
        Module: self

    Example::

        >>> @torch.no_grad()
        >>> def init_weights(m):
        >>>     print(m)
        >>>     if type(m) == nn.Linear:
        >>>         m.weight.fill_(1.0)
        >>>         print(m.weight)
        >>> net = nn.Sequential(nn.Linear(2, 2), nn.Linear(2, 2))
        >>> net.apply(init_weights)
        Linear(in_features=2, out_features=2, bias=True)
        Parameter containing:
        tensor([[1., 1.],
                [1., 1.]], requires_grad=True)
        Linear(in_features=2, out_features=2, bias=True)
        Parameter containing:
        tensor([[1., 1.],
                [1., 1.]], requires_grad=True)
        Sequential(
          (0): Linear(in_features=2, out_features=2, bias=True)
          (1): Linear(in_features=2, out_features=2, bias=True)
        )
    """
    for module in self.children():
        module.apply(fn)
    fn(self)
    return self
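The implementation recurses into the children before calling fn on the module itself, so modules are visited in post-order: the innermost layers first and the top-level container last. A small sketch that records the visiting order makes this visible; the nested model and the record helper are illustrative assumptions.

```python
import torch.nn as nn

visited = []


def record(m: nn.Module) -> None:
    # Hypothetical helper: remember the class name of each visited module.
    visited.append(type(m).__name__)


inner = nn.Sequential(nn.Linear(2, 2), nn.ReLU())
outer = nn.Sequential(inner, nn.Linear(2, 1))
outer.apply(record)

print(visited)
# ['Linear', 'ReLU', 'Sequential', 'Linear', 'Sequential']
```

This ordering matters when fn for a container depends on its children having already been processed.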