深入 PyTorch:新手可探索 torch.nn 模块的强大功能

目录

引言

torch.nn使用和详解

Parameter

函数作用

使用技巧

使用方法和示例

UninitializedParameter

特点和用途

可进行的操作

使用示例

UninitializedBuffer

特点和用途

可进行的操作

使用示例

Module**(重点)

关键特性和功能

举例说明

torch.nn.Module 主要方法详解

add_module(name, module)

apply(fn)

bfloat16()

buffers(recurse=True)

children()

cpu()

cuda(device=None)

double()

eval()

extra_repr()

float()

forward(*input)

get_buffer(target)

get_parameter(target)

half()

load_state_dict(state_dict, strict=True, assign=False)

modules()

named_buffers(prefix='', recurse=True, remove_duplicate=True)

named_children()

named_modules(memo=None, prefix='', remove_duplicate=True)

named_parameters(prefix='', recurse=True, remove_duplicate=True)

parameters(recurse=True)

register_backward_hook(hook)

register_buffer(name, tensor, persistent=True)

register_forward_hook(hook)

register_forward_pre_hook(hook)

register_full_backward_hook(hook)

register_parameter(name, param)

state_dict()

to(*args, **kwargs)

train(mode=True)

type(dst_type)

zero_grad(set_to_none=True)

Sequential

主要特性

与 torch.nn.ModuleList 的区别

使用示例

append(module) 方法

ModuleList

主要特性

使用示例

ModuleList 的方法

ModuleDict

主要特性

使用示例

ModuleDict 的方法

ParameterList

主要特性

使用示例

ParameterList 的方法

ParameterDict

主要特性

使用示例

ParameterDict 的方法

总结


引言

        在当今快速发展的人工智能领域,深度学习已成为其中最引人注目的技术之一。PyTorch 作为一种流行的深度学习框架,因其灵活性和易用性而受到广泛欢迎。在 PyTorch 的众多组件中,torch.nn 模块无疑是构建复杂深度学习模型的基石。本文将深入探讨 torch.nn 模块的功能、优势和使用技巧,旨在为读者提供一个清晰的理解和应用指南。torch.nn 提供了构建神经网络所需的所有基本构建块,包括各种类型的层(如卷积层、池化层、激活函数)、损失函数和容器。这些组件不仅是模块化和可重用的,而且也支持灵活的网络架构设计。通过本文,我们将逐一解析这些组件的特性和使用场景,并分享一些实用的技巧来优化网络性能。无论是新手还是有经验的开发者,都可以从中获得宝贵的见解,以更好地利用这个强大的模块来设计和实现高效的深度学习模型。

        接下来的章节将从 torch.nn 的基础知识开始,逐步深入到更高级的主题,包括定制网络层、优化技巧和最佳实践。准备好,让我们开始这次深入浅出的 torch.nn 之旅吧!

torch.nn使用和详解

Parameter

    torch.nn.parameter.Parameter 是 PyTorch 深度学习框架中的一个重要类,用于表示神经网络中的参数。这个类是 Tensor 的子类,它在与 Module(模块)一起使用时具有特殊属性。当 Parameter 被赋值为 Module 的属性时,它自动被添加到模块的参数列表中,并且会出现在例如 parameters() 迭代器中。这与普通的 Tensor 不同,因为 Tensor 赋值给模块时不会有这样的效果。

函数作用

  • 目的: Parameter 主要用于将张量标记为模块的参数。这对于模型的训练和参数更新至关重要,因为只有被标记为 Parameter 的张量才会在模型训练时更新。
  • 使用场景: 在构建自定义神经网络层或整个模型时,需要用到 Parameter 来定义可训练的参数(如权重和偏置)。这些参数在训练过程中会通过反向传播进行优化。

使用技巧

  • 参数初始化: 在定义模型的参数时,可以直接使用 Parameter 类对其进行初始化,从而确保这些参数会被识别并在训练过程中更新。
  • 控制梯度: 通过设置 requires_grad 参数,可以控制特定参数是否需要在反向传播中计算梯度。这对于冻结模型的部分参数或进行特定的优化策略非常有用。

使用方法和示例

以下是 torch.nn.parameter.Parameter 的使用示例:

import torch
import torch.nn as nn# 定义一个自定义的线性层
class CustomLinearLayer(nn.Module):def __init__(self, in_features, out_features):super(CustomLinearLayer, self).__init__()# 定义权重为一个可训练的参数self.weight = nn.Parameter(torch.randn(out_features, in_features))# 定义偏置为一个可训练的参数self.bias = nn.Parameter(torch.randn(out_features))def forward(self, x):# 实现前向传播return torch.matmul(x, self.weight.t()) + self.bias# 创建一个自定义的线性层实例
layer = CustomLinearLayer(5, 3)
print(list(layer.parameters()))

        在上述代码中,CustomLinearLayer 类中定义了两个 Parameter 对象:weightbias。这些参数在模块被实例化时自动注册,并在训练过程中会被优化。通过打印 layer.parameters(),可以看到这些被注册的参数。

UninitializedParameter

torch.nn.parameter.UninitializedParameter 是 PyTorch 中的一个特殊类,用于表示尚未初始化的参数。这个类是 torch.nn.Parameter 的一个特殊情况,其主要特点是在创建时数据的形状(shape)还未知。

特点和用途

  • 尚未初始化: 与常规的 torch.nn.Parameter 不同,UninitializedParameter 不持有任何数据。这意味着在初始化之前,试图访问某些属性(如它们的形状)会引发运行时错误。
  • 灵活的初始化: UninitializedParameter 允许在模型定义阶段创建参数,而不必立即指定它们的大小或形状。这在某些情况下非常有用,例如,当参数的大小依赖于运行时才可知的因素时。

可进行的操作

  • 更改数据类型: 可以更改 UninitializedParameter 的数据类型。
  • 移动到不同设备: 可以将 UninitializedParameter 移动到不同的设备(例如从 CPU 移到 GPU)。
  • 转换为常规参数: 可以将 UninitializedParameter 转换为常规的 torch.nn.Parameter,此时需要指定其形状和数据。

使用示例

在下面的示例中,将展示如何创建一个未初始化的参数,并在稍后将其转换为常规参数:

import torch
import torch.nn as nnclass CustomLayer(nn.Module):def __init__(self):super(CustomLayer, self).__init__()# 创建一个未初始化的参数self.uninitialized_param = nn.parameter.UninitializedParameter()def forward(self, x):# 在前向传播中使用参数前必须先初始化if isinstance(self.uninitialized_param, nn.parameter.UninitializedParameter):# 初始化参数self.uninitialized_param = nn.Parameter(torch.randn(x.size(1), x.size(1)))return torch.matmul(x, self.uninitialized_param.t())# 创建自定义层的实例
layer = CustomLayer()# 假设输入x
x = torch.randn(10, 5)# 使用自定义层
output = layer(x)
print(output)

        在这个例子中,CustomLayer 在初始化时创建了一个 UninitializedParameter。在进行前向传播时,检查这个参数是否已初始化,如果没有,则对其进行初始化,并将其转换为常规的 Parameter。这种方式在处理动态大小的输入时特别有用。

UninitializedBuffer

        torch.nn.parameter.UninitializedBuffer 是 PyTorch 中的一个特殊类,它代表一个尚未初始化的缓冲区。这个类是 torch.Tensor 的一个特殊情形,其主要特点是在创建时数据的形状(shape)还未知。

特点和用途

  • 尚未初始化: 与常规的 torch.Tensor 不同,UninitializedBuffer 不持有任何数据。这意味着在初始化之前,尝试访问某些属性(如它们的形状)会引发运行时错误。
  • 适用场景: UninitializedBuffer 适用于那些在模型定义阶段需要创建缓冲区,但其大小或形状取决于后来才可知的数据或配置的情况。

可进行的操作

  • 更改数据类型: 可以更改 UninitializedBuffer 的数据类型。
  • 移动到不同设备: 可以将 UninitializedBuffer 移动到不同的设备(例如从 CPU 移到 GPU)。
  • 转换为常规张量: 可以将 UninitializedBuffer 转换为常规的 torch.Tensor,此时需要指定其形状和数据。

使用示例

在下面的示例中,将展示如何创建一个未初始化的缓冲区,并在稍后将其转换为常规张量:

import torch
import torch.nn as nnclass CustomLayer(nn.Module):def __init__(self):super(CustomLayer, self).__init__()# 创建一个未初始化的缓冲区self.uninitialized_buffer = nn.parameter.UninitializedBuffer()def forward(self, x):# 在前向传播中使用缓冲区前必须先初始化if isinstance(self.uninitialized_buffer, nn.parameter.UninitializedBuffer):# 初始化缓冲区self.uninitialized_buffer = torch.Tensor(x.size(0), x.size(1))# 在这里可以使用缓冲区进行计算或其他操作return x + self.uninitialized_buffer# 创建自定义层的实例
layer = CustomLayer()# 假设输入x
x = torch.randn(10, 5)# 使用自定义层
output = layer(x)
print(output)

        在这个例子中,CustomLayer 在初始化时创建了一个 UninitializedBuffer。在进行前向传播时,检查这个缓冲区是否已初始化,如果没有,则对其进行初始化,并将其转换为常规的 Tensor。这种方法在动态处理数据大小时非常有用,特别是在需要临时存储数据但在模型定义阶段无法确定其大小的情况下。  

Module**(重点)

         torch.nn.Module 是 PyTorch 中用于构建所有神经网络模型的基类。几乎所有的 PyTorch 神经网络模型都是通过继承 torch.nn.Module 来构建的。这个类提供了模型需要的基本功能,如参数管理、模型保存和加载、设备转移(例如,从 CPU 到 GPU)等。

关键特性和功能

  • 模块树结构: Module 可以包含其他 Module,形成一个嵌套的树状结构。这允许用户以模块化的方式构建复杂的神经网络。
  • 参数和缓冲区的管理: Module 自动管理其属性中的所有 ParameterBuffer 对象。这包括注册参数、转移到不同设备、保存和加载模型状态等。
  • 前向传播定义: 所有子类都应该覆盖 forward 方法,以定义其在接收输入时的计算过程。

举例说明

以下是一个基本的 torch.nn.Module 子类的示例:

import torch.nn as nn
import torch.nn.functional as Fclass SimpleModel(nn.Module):def __init__(self):super(SimpleModel, self).__init__()self.conv1 = nn.Conv2d(1, 20, 5)self.conv2 = nn.Conv2d(20, 20, 5)def forward(self, x):x = F.relu(self.conv1(x))return F.relu(self.conv2(x))model = SimpleModel()

        在这个例子中,SimpleModel 继承了 torch.nn.Module。在其构造函数中,定义了两个卷积层 conv1conv2,并在 forward 方法中定义了模型的前向传播逻辑。

torch.nn.Module 主要方法详解

add_module(name, module)
  • 功能:向当前模块添加子模块。
  • 参数:
    • name: 子模块的名称。
    • module: 要添加的子模块对象。
# 定义一个自定义模块
class CustomModule(nn.Module):def __init__(self):super(CustomModule, self).__init__()# 创建一个线性层linear = nn.Linear(10, 5)# 使用 add_module 添加线性层作为子模块self.add_module('linear', linear)
apply(fn)
  • 功能:递归地将函数 fn 应用于每个子模块及其自身。
  • 参数:
    • fn: 要应用的函数,通常用于初始化参数。
# 定义一个初始化权重的函数
def init_weights(m):if type(m) == nn.Linear:nn.init.uniform_(m.weight)# 应用 init_weights 函数初始化模型的权重
model = CustomModule()
model.apply(init_weights)
bfloat16()
  • 功能:将所有浮点参数和缓冲区转换为 bfloat16 数据类型。
  • 注意:此方法就地修改模块。
# 将模型的参数和缓冲区转换为 bfloat16 数据类型
model.bfloat16()
buffers(recurse=True)
  • 功能:返回一个迭代器,遍历模块的所有缓冲区。
  • 参数:
    • recurse: 如果为 True,则遍历此模块及所有子模块的缓冲区。
# 遍历模型的所有缓冲区
for buf in model.buffers():print(buf.size())
children()
  • 功能:返回一个迭代器,遍历模块的直接子模块。
# 遍历模型的直接子模块
for child in model.children():print(child)
cpu()
  • 功能:将所有模型参数和缓冲区移动到 CPU。
# 将模型移动到 CPU
model.cpu()
cuda(device=None)
  • 功能:将所有模型参数和缓冲区移动到 GPU。
  • 参数:
    • device: 指定 GPU 设备。
# 将模型移动到 GPU
model.cuda()
double()
  • 功能:将所有浮点参数和缓冲区转换为 double 数据类型。
# 将模型的参数和缓冲区转换为 double 数据类型
model.double()
eval()
  • 功能:将模块设置为评估模式。
# 将模型设置为评估模式
model.eval()
extra_repr()
  • 功能:设置模块的额外表示,用于自定义信息打印。
# 自定义模型的额外表示
class CustomModule(nn.Module):def __init__(self):super(CustomModule, self).__init__()def extra_repr(self):return '自定义信息'model = CustomModule()
print(model)
float()
  • 功能:将所有浮点参数和缓冲区转换为 float 数据类型。
# 将模型的参数和缓冲区转换为 float 数据类型
model.float()
forward(*input)
  • 功能:定义每次调用时的计算,所有子类必须覆盖此方法。
# 定义模型的前向传播
class CustomModule(nn.Module):def __init__(self):super(CustomModule, self).__init__()self.linear = nn.Linear(10, 5)def forward(self, x):return self.linear(x)model = CustomModule()
input = torch.randn(1, 10)
output = model(input)
get_buffer(target)
  • 功能:根据目标名称返回对应的缓冲区。
# 获取特定名称的缓冲区
buffer = model.get_buffer('buffer_name')
get_parameter(target)
  • 功能:根据目标名称返回对应的参数。
# 获取特定名称的参数
parameter = model.get_parameter('param_name')
half()
  • 功能:将所有浮点参数和缓冲区转换为半精度 (half) 数据类型。
# 将模型的参数和缓冲区转换为半精度 (half) 数据类型
model.half()
load_state_dict(state_dict, strict=True, assign=False)
  • 功能:从 state_dict 中复制参数和缓冲区到此模块及其后代。
  • 参数:
    • state_dict: 包含参数和持久缓冲区的字典。
    • strict: 是否严格匹配 state_dict 和模块的键。
# 从 state_dict 加载模型状态
state_dict = {'linear.weight': torch.randn(5, 10), 'linear.bias': torch.randn(5)}
model.load_state_dict(state_dict, strict=False)
modules()
  • 功能:返回一个迭代器,遍历网络中的所有模块。
# 遍历网络中的所有模块
for module in model.modules():print(module)
named_buffers(prefix='', recurse=True, remove_duplicate=True)
  • 功能:返回一个迭代器,遍历模块的所有缓冲区,同时提供缓冲区的名称。
# 遍历模型的所有缓冲区,同时提供缓冲区的名称
for name, buf in model.named_buffers():print(f"Buffer name: {name}, Buffer: {buf}")
named_children()
  • 功能:返回一个迭代器,遍历模块的直接子模块,同时提供子模块的名称。
# 遍历模型的直接子模块,同时提供子模块的名称
for name, child in model.named_children():print(f"Child name: {name}, Child module: {child}")
named_modules(memo=None, prefix='', remove_duplicate=True)
  • 功能:返回一个迭代器,遍历网络中的所有模块,同时提供模块的名称。
# 遍历网络中的所有模块,同时提供模块的名称
for name, module in model.named_modules():print(f"Module name: {name}, Module: {module}")
named_parameters(prefix='', recurse=True, remove_duplicate=True)
  • 功能:返回一个迭代器,遍历模块的所有参数,同时提供参数的名称。
# 遍历模型的所有参数,同时提供参数的名称
for name, param in model.named_parameters():print(f"Parameter name: {name}, Parameter: {param}")
parameters(recurse=True)
  • 功能:返回一个迭代器,遍历模块的所有参数。
# 遍历模型的所有参数
for param in model.parameters():print(param)
register_backward_hook(hook)
  • 功能:注册一个反向传播钩子。
# 注册一个反向传播钩子
def backward_hook(module, grad_input, grad_output):print(f"Backward hook in {module}")model.register_backward_hook(backward_hook)
register_buffer(name, tensor, persistent=True)
  • 功能:向模块添加一个缓冲区。
# 向模块添加一个缓冲区
model.register_buffer('new_buffer', torch.randn(5))
register_forward_hook(hook)
  • 功能:注册一个前向传播钩子。
# 注册一个前向传播钩子
def forward_hook(module, input, output):print(f"Forward hook in {module}")model.register_forward_hook(forward_hook)
register_forward_pre_hook(hook)
  • 功能:注册一个前向传播预处理钩子。
# 注册一个前向传播钩子
def forward_hook(module, input, output):print(f"Forward hook in {module}")model.register_forward_hook(forward_hook)
register_full_backward_hook(hook)
  • 功能:注册一个完整的反向传播钩子。
# 注册一个完整的反向传播钩子
def full_backward_hook(module, grad_input, grad_output):print(f"Full backward hook in {module}")model.register_full_backward_hook(full_backward_hook)
register_parameter(name, param)
  • 功能:向模块添加一个参数。
# 向模块添加一个参数
param = nn.Parameter(torch.randn(5))
model.register_parameter('new_param', param)
state_dict()
  • 功能:返回包含模块所有状态信息的字典。
# 获取模块所有状态信息的字典
state_dict = model.state_dict()
to(*args, **kwargs)
  • 功能:移动和/或转换参数和缓冲区。
# 移动和/或转换参数和缓冲区
# 移动模型到 GPU 并转换为 double 类型
model.to('cuda', dtype=torch.double)
train(mode=True)
  • 功能:将模块设置为训练模式。
# 将模块设置为训练模式
model.train()
type(dst_type)
  • 功能:将所有参数和缓冲区转换为指定类型。
# 将所有参数和缓冲区转换为指定类型
model.type(torch.float32)
zero_grad(set_to_none=True)
  • 功能:重置所有模型参数的梯度。
# 重置所有模型参数的梯度
model.zero_grad()

         这些示例涵盖了 torch.nn.Module 类中的大多数主要方法,展示了如何在实际情况中使用它们。

Sequential

   torch.nn.Sequential 是 PyTorch 中的一个容器模块,用于按顺序封装一系列子模块。它简化了模型的构建过程,使得将多个模块组合成一个单独的序列变得容易和直观。

主要特性

  • 顺序处理: Sequential 按照它们在构造函数中传递的顺序,依次处理每个子模块。输入数据首先被传递到第一个模块,然后依次传递到每个后续模块。
  • 容器作为单一模块: Sequential 允许将整个容器视为单一模块,对其进行的任何转换都适用于它存储的每个模块(每个模块都是 Sequential 的一个注册子模块)。

torch.nn.ModuleList 的区别

torch.nn.ModuleList 仅仅是一个存储子模块的列表,而 Sequential 中的层是级联连接的。在 ModuleList 中,层之间没有直接的数据流动关联,而在 Sequential 中,一个层的输出直接成为下一个层的输入。

使用示例

  1. 使用 Sequential 创建一个简单的模型:

model = nn.Sequential(nn.Conv2d(1, 20, 5),nn.ReLU(),nn.Conv2d(20, 64, 5),nn.ReLU())

在这个例子中,输入数据首先通过一个 Conv2d 层,然后是 ReLU 层,接着是第二个 Conv2d 层,最后是另一个 ReLU 层。

使用带有 OrderedDict 的 Sequential:

from collections import OrderedDictmodel = nn.Sequential(OrderedDict([('conv1', nn.Conv2d(1, 20, 5)),('relu1', nn.ReLU()),('conv2', nn.Conv2d(20, 64, 5)),('relu2', nn.ReLU())]))

         使用 OrderedDict 允许为每个模块指定一个唯一的名称。这在需要引用特定层或在打印模型结构时提高了可读性。

append(module) 方法

  • 功能: 将给定的模块添加到序列的末尾。
  • 参数:
    • module (nn.Module): 要附加的模块。
  • 返回类型: Sequential

这种方式构建的模型可以简化前向传播的实现,使得模型的构建和理解更加直观。

ModuleList

torch.nn.ModuleList 是 PyTorch 中用于存储子模块的列表容器。它类似于 Python 的常规列表,但具有额外的功能,使其能够适当地注册其中包含的模块,并使它们对所有 Module 方法可见。

主要特性

  • 列表式结构: ModuleList 提供了一个列表式的结构来保存模块,允许通过索引或迭代器访问这些模块。
  • 模块注册: 它包含的模块会被正确注册,这意味着当调用诸如 .parameters().to(device)Module 方法时,这些子模块也会被考虑在内。

使用示例

class MyModule(nn.Module):def __init__(self):super().__init__()# 使用 ModuleList 创建一个线性层的列表self.linears = nn.ModuleList([nn.Linear(10, 10) for i in range(10)])def forward(self, x):# ModuleList 可以作为迭代器,也可以使用索引访问for i, l in enumerate(self.linears):x = self.linears[i // 2](x) + l(x)return x

 在这个例子中,MyModule 创建了一个 ModuleList,其中包含了 10 个 nn.Linear(10, 10) 层。在 forward 方法中,使用了两种不同的方式来访问和应用这些层。

ModuleList 的方法

  1. append(module)

    • 功能:在列表末尾添加一个给定的模块。
    • 参数:
      • module (nn.Module):要添加的模块。
  2. extend(modules)

    • 功能:将来自 Python 可迭代对象的模块添加到列表的末尾。
    • 参数:
      • modules (iterable):可迭代的模块对象。
  3. insert(index, module)

    • 功能:在列表中给定索引之前插入一个模块。
    • 参数:
      • index (int):插入的索引。
      • module (nn.Module):要插入的模块。

ModuleList 提供了灵活的方式来管理子模块的集合,特别是当模型的某些部分是动态的或者模型结构中的层的数量在初始化时未知时非常有用。

ModuleDict

torch.nn.ModuleDict 是 PyTorch 中的一个容器模块,用于以字典形式保存子模块。它类似于 Python 的常规字典,但其包含的模块会被正确注册,并且对所有 Module 方法可见。

主要特性

  • 字典式结构: ModuleDict 提供了一个字典式的结构来保存模块,允许通过键值对访问这些模块。
  • 有序字典: 自 Python 3.6 起,ModuleDict 是一个有序字典,它会保留插入顺序和合并顺序。

使用示例

class MyModule(nn.Module):def __init__(self):super().__init__()# 使用 ModuleDict 创建一个由不同层组成的字典self.choices = nn.ModuleDict({'conv': nn.Conv2d(10, 10, 3),'pool': nn.MaxPool2d(3)})# 可以使用列表初始化 ModuleDictself.activations = nn.ModuleDict([['lrelu', nn.LeakyReLU()],['prelu', nn.PReLU()]])def forward(self, x, choice, act):# 通过键值访问 ModuleDict 中的模块x = self.choices[choice](x)x = self.activations[act](x)return x

在这个例子中,MyModule 创建了两个 ModuleDict,一个用于保存卷积层和池化层,另一个用于保存激活层。

ModuleDict 的方法

  1. clear()

    • 功能:清除 ModuleDict 中的所有项。
  2. items()

    • 功能:返回 ModuleDict 中的键/值对的迭代器。
  3. keys()

    • 功能:返回 ModuleDict 键的迭代器。
  4. pop(key)

    • 功能:从 ModuleDict 中移除键并返回其模块。
    • 参数:
      • key (str):要从 ModuleDict 中弹出的键。
  5. update(modules)

    • 功能:用来自映射或迭代器的键值对更新 ModuleDict,覆盖现有的键。
    • 参数:
      • modules (iterable):从字符串到模块的映射(字典),或键值对的迭代器。
  6. values()

    • 功能:返回 ModuleDict 中模块值的迭代器。

ModuleDict 提供了一个灵活的方式来管理具有特定键的子模块的集合。这在模型设计中特别有用,尤其是当模型的不同部分需要根据键动态选择时。

ParameterList

torch.nn.ParameterList 是 PyTorch 中的一个容器模块,用于按列表形式保存参数(Parameter 对象)。它类似于 Python 的常规列表,但其特殊之处在于其中包含的 Tensor 对象会被转换为 Parameter 对象,并正确注册,使得这些参数对所有 Module 方法可见。

主要特性

  • 列表式结构: ParameterList 提供了一个列表式的结构来保存参数,允许通过索引或迭代器访问这些参数。
  • 参数注册: 其中包含的 Tensor 对象会被自动转换为 Parameter 对象,确保它们可以被 PyTorch 的优化器等模块正确处理。

使用示例

class MyModule(nn.Module):def __init__(self):super().__init__()# 使用 ParameterList 创建一个包含多个参数的列表self.params = nn.ParameterList([nn.Parameter(torch.randn(10, 10)) for i in range(10)])def forward(self, x):# ParameterList 可以作为迭代器,也可以使用索引访问for i, p in enumerate(self.params):x = self.params[i // 2].mm(x) + p.mm(x)return x

在这个例子中,MyModule 创建了一个 ParameterList,其中包含了 10 个形状为 10x10 的随机参数。在 forward 方法中,这些参数被用于矩阵乘法操作。

ParameterList 的方法

  1. append(value)

    • 功能:在列表末尾添加一个给定的值(会被转换为 Parameter)。
    • 参数:
      • value (Any):要添加的值。
  2. extend(values)

    • 功能:将来自 Python 可迭代对象的值添加到列表的末尾(每个值都会被转换为 Parameter)。
    • 参数:
      • values (iterable):要添加的值的可迭代对象。

ParameterList 提供了一种灵活的方式来管理模型中的参数集合,特别是当模型的某些部分参数数量动态变化时非常有用。通过使用 ParameterList,您可以确保模型的所有参数都正确注册,并且可以通过标准的 PyTorch 方法进行访问和优化。

ParameterDict

torch.nn.ParameterDict 是 PyTorch 中用于以字典形式保存参数(Parameter 对象)的容器模块。它类似于 Python 的常规字典,但其特殊之处在于其中包含的参数被正确注册,并对所有 Module 方法可见。

主要特性

  • 字典式结构: ParameterDict 提供了一个字典式的结构来保存参数,允许通过键值对访问这些参数。
  • 有序字典: ParameterDict 是一个有序字典,它保留插入顺序和合并顺序(对于 OrderedDict 或另一个 ParameterDict)。

使用示例

class MyModule(nn.Module):def __init__(self):super().__init__()# 使用 ParameterDict 创建一个由不同参数组成的字典self.params = nn.ParameterDict({'left': nn.Parameter(torch.randn(5, 10)),'right': nn.Parameter(torch.randn(5, 10))})def forward(self, x, choice):# 通过键值访问 ParameterDict 中的参数x = self.params[choice].mm(x)return x

在这个例子中,MyModule 创建了一个 ParameterDict,其中包含了两个名为 'left' 和 'right' 的参数。在 forward 方法中,根据传入的 choice 键来选择相应的参数进行矩阵乘法操作。

ParameterDict 的方法

  1. clear()

    • 功能:清除 ParameterDict 中的所有项。
  2. copy()

    • 功能:返回这个 ParameterDict 实例的副本。
  3. fromkeys(keys, default=None)

    • 功能:根据提供的键返回一个新的 ParameterDict
    • 参数:
      • keys (iterable, string):用于创建新 ParameterDict 的键。
      • default (Parameter, 可选):为所有键设置的默认值。
  4. get(key, default=None)

    • 功能:如果存在,返回与 key 相关联的参数。否则,如果提供了 default,则返回 default;如果没有提供,则返回 None
  5. items()

    • 功能:返回 ParameterDict 键/值对的迭代器。
  6. keys()

    • 功能:返回 ParameterDict 键的迭代器。
  7. pop(key)

    • 功能:从 ParameterDict 中移除键并返回其参数。
    • 参数:
      • key (str):要从 ParameterDict 中弹出的键。
  8. popitem()

    • 功能:从 ParameterDict 中移除并返回最后插入的 (键, 参数) 对。
  9. setdefault(key, default=None)

    • 功能:如果 keyParameterDict 中,则返回其值。如果不是,插入 key 与参数 default 并返回 defaultdefault 默认为 None
  10. update(parameters)

    • 功能:用来自映射或迭代器的键值对更新 ParameterDict,覆盖现有的键。
  11. values()

    • 功能:返回 ParameterDict 中参数值的迭代器。

ParameterDict 提供了一种灵活的方式来管理模型中具有特定键的参数集合。这在模型设计中特别有用,尤其是当模型的不同部分需要根据键动态选择参数时。

总结

        本文深入探索了 PyTorch 框架中的 torch.nn 模块,这是构建和实现高效深度学习模型的核心组件。我们详细介绍了 torch.nn 的关键类别和功能,包括 Parameter, Module, Sequential, ModuleList, ModuleDict, ParameterListParameterDict,为读者提供了一个全面的理解和应用指南。这篇博客仅仅是torch.nn的一部分功能,后续我这边会继续更新这个模块的其他相关功能。

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.mzph.cn/news/592587.shtml

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

SSRF靶场攻略记录

目录 1 获取并显示指定文件内容的应用程序代码 靶场路径 漏洞点

ubuntu22.04装机记录

1.配网 文件地址: greeW19a406  ⚙ /etc/netplan  pwd  ✔  8  15:40:50 /etc/netplangreeW19a406  ⚙ /etc/netplan  #文件内容,注意缩进与网关可能报错需要添加初始化命令 # Let NetworkManager manage all…

后端开发——jdbc的学习(一)

上篇结束了Mysql数据库的基本使用,本篇开始对JDBC进行学习总结,开始先简单介绍jdbc的基本使用,以及简单的练习;后续会继续更新!以下代码可以直接复制到idea中运行,便于理解和练习。 JDBC的概念 JDBC&#…

【深度学习-基础学习】Transformer 笔记

本篇文章学习总结 李宏毅 2021 Spring 课程中关于 Transformer 相关的内容。课程链接以及PPT:李宏毅Spring2021ML这篇Blog需要Self-Attention为前置知识。 Transfomer 简介 Transfomer 架构主要是用来解决 Seq2Seq 问题的,也就是 Sequence to Sequence…

ntp校时服务器、ntp授时服务器、ntp时钟服务器

ntp校时服务器、ntp授时服务器、ntp时钟服务器 ntp校时服务器、ntp授时服务器、ntp时钟服务器 三者都是利用NTP技术来实现时间同步服务的一种电子科技产品,名称不同功能一样而已、设备采用冗余架构设计,高精度时钟直接来源于北斗、GPS系统中各个卫星的原…

web前端——clear可以清除浮动产生的影响

clear可以解决高度塌陷的问题&#xff0c;产生的副作用要小 未使用clear之前 <!DOCTYPE html> <head><meta charset"UTF-8"><title>高度塌陷相关学习</title><style>div{font-size:50px;}.box1{width:200px;height:200px;backg…

JavaScript:html获取url参数

使用場景&#xff1a;常用在分享页面 1、采用正则表达式获取地址栏参数 function getQueryString(name) {var reg new RegExp("(^|&)" name "([^&]*)(&|$)", "i" );var r window.location.search.substr(1).match(reg);if (r…

【已解决】打印PDF文件,如何跳过不需要的页面?

打印PDF文件的时候&#xff0c;有时候我们只需要打印其中的几页&#xff0c;并不需要全部打印&#xff0c;那如何在打印时跳过那些不需要的页面呢&#xff1f;不清楚的小伙伴一起来看看吧&#xff01; 如果你是通过网页打开PDF文件&#xff0c;那么可以在页面中找到并点击“打…

[每周一更]-(第53期):Python3和Django环境安装并搭建Django

Python和Django 的安装 Python和Django 兼容情况 django 1.11.x python 2.7 3.4 3.5 3.6 LTS python 目前在用版本 Python 3.6.5 2018-03-28 更新Python 2.7.15 2018-05-01 更新Python 2.7.5 2013-05-12 更新 python和python3安装pip 同时安装上 python2.7.18、python3.11…

如何用 GPT 去分析Excel数据

背景 需要尝试分析 Excel 的内容&#xff0c;每月都需要进行相关的分析&#xff0c;固定化流程&#xff0c;因此尝试制作固化的脚本&#xff0c;方便后续的分析。 执行步骤 帮我写一段 python 代码&#xff0c;我需要区分一个.xlsx的数据。格式示例如下&#xff1a; ”这块自…

mysql创建数据库和表

要在MySQL中创建数据库和表&#xff0c;可以按照以下步骤进行操作&#xff1a; 连接到MySQL服务器&#xff1a; mysql -u username -p其中&#xff0c;username 是你的MySQL用户名。执行上述命令后&#xff0c;系统会提示你输入密码&#xff0c;输入密码后即可登录到MySQL服务…

c语言结构体学习上篇

文章目录 前言一、结构体的声明1&#xff0c;什么叫结构体?2&#xff0c;结构体的类型3,结构体变量的创建和初始化4&#xff0c;结构体的类型5&#xff0c;结构体的初始化 二、结构体的访问1&#xff0c;结构体成员的点操作符访问2&#xff0c;结构体体成员的指针访问 前言 昨…

用户管理第2节课--idea 2023.2 后端--实现基本数据库操作(操作user表) -- 自动生成 --【本人】

一、插件安装 1.1 搜索插件 mybatis 安装 1.2 接受安装 1.3 再次进入&#xff0c;说明安装好了 1.4 与鱼皮不同点 1&#xff09;mybatis 版本不一致 鱼皮&#xff1a; 本人&#xff1a; 2&#xff09;鱼皮需重启安装 本人不需要 1.5 【需完成 三、步骤&#xff0c;再来看】 …

AI发展将来对人力市场有什么影响

#AI发展将来对人力市场有什么影响 #合同智能审查、合同要素智能提取、合同版本对比、合同智能起草、文本一致性对比、广告审查、合同范本库 人工智能的发展对人们的日常生活起到至关重要的作用&#xff0c;智能发展也是涉猎众多领域&#xff0c;人工智能技术对于企业法务管理…

Git - 强制替换覆盖 master 分支解决方案

问题描述 在版本迭代中&#xff0c;通常会保持一个主分支 master&#xff0c;及多个 dev 分支&#xff0c;但是因为 dev 分支的开发周期过长&#xff0c;迭代太多而没有及时维护 master &#xff0c;导致后来发版上线的大部分代码都在 dev 分支上&#xff0c;如果将代码在 mas…

LiveGBS流媒体平台GB/T28181功能-用户管理通道权限管理关联通道支持只看已选只看未选添加用户备注角色

LiveGBS功能用户管理通道权限管理关联通道支持只看已选只看未选添加用户备注角色 1、用户管理2、添加用户3、关联通道3.1、只看已选3.2、只看未选 4、自定义角色5、搭建GB28181视频直播平台 1、用户管理 LiveGBS支持用户管理&#xff0c;添加用户&#xff0c;及配置相关用户权…

promise.prototype.finally重写和兼容火狐低版本浏览器

一、finally()方法用于指定不管 Promise 对象最后状态如何&#xff0c;都会执行的操作。该方法是 ES2018 引入标准的 let promise new Promise() promise .then(result > {}) .catch(error > {}) .finally(() > {})finally方法的回调函数不接受任何参数;finally方法…

element-ui Tree 树形控件 过滤保留子级并获取过滤后的数据

本示例基于vue2 element-ui element-ui 的官网demo是只保留到过滤值一级的&#xff0c;并不会保留其子级 目标 1、Tree 树形控件 保留过滤值的子级 2、在第一次过滤数据的基础上进行第二次过滤 先看效果 Tree 树形控件 保留过滤值的子级 <el-treeclass"filter-t…

直观从零理解 梯度下降(Gradient descent) VS 随机梯度下降 (Stochastic gradient descent) 函数优化

首发于Data Science 单变量微分(Differentiation) 常用基本微分有&#xff1a; 四则运算法则&#xff1a; 链式法则(Chain-rule) 极大值(maxima)与极小值(minima) 向量微分 梯度下降(Gradient descent):几何直觉 学习率&#xff08;Learning Rate&#xff09;的直观理解…

笔记中所得(已删减)

1.交流电的一个周期内电压/电流的平均值都为0 2.电动势:电池将单位正电荷由负极搬到正极所做的功 5.额定能量:电池的额定容量乘以标称电压,以Wh为单位 6.500mAh意义是可以以500mA的电流放电1小时 7.电池容量的单位是mAh 13.实际电流源不能串联 14. 15. 16. 17. 18. 19.电…