【深度学习笔记】4_2-3 模型参数的访问、初始化和共享

注:本文为《动手学深度学习》开源内容,部分标注了个人理解,仅为个人学习记录,无抄袭搬运意图

4.2 模型参数的访问、初始化和共享

在3.3节(线性回归的简洁实现)中,我们通过init模块来初始化模型的参数。我们也介绍了访问模型参数的简单方法。本节将深入讲解如何访问和初始化模型参数,以及如何在多个层之间共享同一份模型参数。

我们先定义一个与上一节中相同的含单隐藏层的多层感知机。我们依然使用默认方式初始化它的参数,并做一次前向计算。与之前不同的是,在这里我们从nn中导入了init模块,它包含了多种模型初始化方法。

import torch
from torch import nn
from torch.nn import initnet = nn.Sequential(nn.Linear(4, 3), nn.ReLU(), nn.Linear(3, 1))  # pytorch已进行默认初始化print(net)
X = torch.rand(2, 4)
Y = net(X).sum()

输出:

Sequential((0): Linear(in_features=4, out_features=3, bias=True)(1): ReLU()(2): Linear(in_features=3, out_features=1, bias=True)
)

4.2.1 访问模型参数

回忆一下上一节中提到的Sequential类与Module类的继承关系。对于Sequential实例中含模型参数的层,我们可以通过Module类的parameters()或者named_parameters方法来访问所有参数(以迭代器的形式返回),后者除了返回参数Tensor外还会返回其名字。下面,访问多层感知机net的所有参数:

print(type(net.named_parameters()))
for name, param in net.named_parameters():print(name, param.size())

输出:

<class 'generator'>
0.weight torch.Size([3, 4])
0.bias torch.Size([3])
2.weight torch.Size([1, 3])
2.bias torch.Size([1])

可见返回的名字自动加上了层数的索引作为前缀。
我们再来访问net中单层的参数。对于使用Sequential类构造的神经网络,我们可以通过方括号[]来访问网络的任一层。索引0表示隐藏层为Sequential实例最先添加的层。

for name, param in net[0].named_parameters():print(name, param.size(), type(param))

输出:

weight torch.Size([3, 4]) <class 'torch.nn.parameter.Parameter'>
bias torch.Size([3]) <class 'torch.nn.parameter.Parameter'>

因为这里是单层的所以没有了层数索引的前缀。另外返回的param的类型为torch.nn.parameter.Parameter,其实这是Tensor的子类,和Tensor不同的是如果一个TensorParameter,那么它会自动被添加到模型的参数列表里,来看下面这个例子。

class MyModel(nn.Module):def __init__(self, **kwargs):super(MyModel, self).__init__(**kwargs)self.weight1 = nn.Parameter(torch.rand(20, 20))self.weight2 = torch.rand(20, 20)def forward(self, x):passn = MyModel()
for name, param in n.named_parameters():print(name)

输出:

weight1

上面的代码中weight1在参数列表中但是weight2却没在参数列表中。

因为ParameterTensor,即Tensor拥有的属性它都有,比如可以根据data来访问参数数值,用grad来访问参数梯度。

weight_0 = list(net[0].parameters())[0]
print(weight_0.data)
print(weight_0.grad) # 反向传播前梯度为None
Y.backward()
print(weight_0.grad)

输出:

tensor([[ 0.2719, -0.0898, -0.2462,  0.0655],[-0.4669, -0.2703,  0.3230,  0.2067],[-0.2708,  0.1171, -0.0995,  0.3913]])
None
tensor([[-0.2281, -0.0653, -0.1646, -0.2569],[-0.1916, -0.0549, -0.1382, -0.2158],[ 0.0000,  0.0000,  0.0000,  0.0000]])

4.2.2 初始化模型参数

我们在3.15节(数值稳定性和模型初始化)中提到了PyTorch中nn.Module的模块参数都采取了较为合理的初始化策略(不同类型的layer具体采样的哪一种初始化方法的可参考源代码)。但我们经常需要使用其他方法来初始化权重。PyTorch的init模块里提供了多种预设的初始化方法。在下面的例子中,我们将权重参数初始化成均值为0、标准差为0.01的正态分布随机数,并依然将偏差参数清零。

for name, param in net.named_parameters():if 'weight' in name:init.normal_(param, mean=0, std=0.01)print(name, param.data)

输出:

0.weight tensor([[ 0.0030,  0.0094,  0.0070, -0.0010],[ 0.0001,  0.0039,  0.0105, -0.0126],[ 0.0105, -0.0135, -0.0047, -0.0006]])
2.weight tensor([[-0.0074,  0.0051,  0.0066]])

下面使用常数来初始化权重参数。

for name, param in net.named_parameters():if 'bias' in name:init.constant_(param, val=0)print(name, param.data)

输出:

0.bias tensor([0., 0., 0.])
2.bias tensor([0.])

4.2.3 自定义初始化方法

有时候我们需要的初始化方法并没有在init模块中提供。这时,可以实现一个初始化方法,从而能够像使用其他初始化方法那样使用它。在这之前我们先来看看PyTorch是怎么实现这些初始化方法的,例如torch.nn.init.normal_

def normal_(tensor, mean=0, std=1):with torch.no_grad():return tensor.normal_(mean, std)

可以看到这就是一个inplace改变Tensor值的函数,而且这个过程是不记录梯度的。
类似的我们来实现一个自定义的初始化方法。在下面的例子里,我们令权重有一半概率初始化为0,有另一半概率初始化为 [ − 10 , − 5 ] [-10,-5] [10,5] [ 5 , 10 ] [5,10] [5,10]两个区间里均匀分布的随机数。

def init_weight_(tensor):with torch.no_grad():tensor.uniform_(-10, 10)tensor *= (tensor.abs() >= 5).float()for name, param in net.named_parameters():if 'weight' in name:init_weight_(param)print(name, param.data)

输出:

0.weight tensor([[ 7.0403,  0.0000, -9.4569,  7.0111],[-0.0000, -0.0000,  0.0000,  0.0000],[ 9.8063, -0.0000,  0.0000, -9.7993]])
2.weight tensor([[-5.8198,  7.7558, -5.0293]])

此外,参考2.3.2节,我们还可以通过改变这些参数的data来改写模型参数值同时不会影响梯度:

for name, param in net.named_parameters():if 'bias' in name:param.data += 1print(name, param.data)

输出:

0.bias tensor([1., 1., 1.])
2.bias tensor([1.])

4.2.4 共享模型参数

在有些情况下,我们希望在多个层之间共享模型参数。4.1.3节提到了如何共享模型参数: Module类的forward函数里多次调用同一个层。此外,如果我们传入Sequential的模块是同一个Module实例的话参数也是共享的,下面来看一个例子:

linear = nn.Linear(1, 1, bias=False)
net = nn.Sequential(linear, linear) 
print(net)
for name, param in net.named_parameters():init.constant_(param, val=3)print(name, param.data)

输出:

Sequential((0): Linear(in_features=1, out_features=1, bias=False)(1): Linear(in_features=1, out_features=1, bias=False)
)
0.weight tensor([[3.]])

在内存中,这两个线性层其实一个对象:

print(id(net[0]) == id(net[1]))
print(id(net[0].weight) == id(net[1].weight))

输出:

True
True

因为模型参数里包含了梯度,所以在反向传播计算时,这些共享的参数的梯度是累加的:

x = torch.ones(1, 1)
y = net(x).sum()
print(y)
y.backward()
print(net[0].weight.grad) # 单次梯度是3,两次所以就是6

输出:

tensor(9., grad_fn=<SumBackward0>)
tensor([[6.]])

小结

  • 有多种方法来访问、初始化和共享模型参数。
  • 可以自定义初始化方法。

注:本节与原书此节有一些不同,原书传送门

4.3 模型参数的延后初始化

由于使用Gluon创建的全连接层的时候不需要指定输入个数。所以当调用initialize函数时,由于隐藏层输入个数依然未知,系统也无法得知该层权重参数的形状。只有在当形状已知的输入X传进网络做前向计算net(X)时,系统才推断出该层的权重参数形状为多少,此时才进行真正的初始化操作。但是使用PyTorch在定义模型的时候就要指定输入的形状,所以也就不存在这个问题了,所以本节略。有兴趣的可以去看看原文,传送门。

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.mzph.cn/news/704061.shtml

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

AIGC专栏9——Scalable Diffusion Models with Transformers (DiT)结构解析

AIGC专栏9——Scalable Diffusion Models with Transformers &#xff08;DiT&#xff09;结构解析 学习前言源码下载地址网络构建一、什么是Diffusion Transformer (DiT)二、DiT的组成三、生成流程1、采样流程a、生成初始噪声b、对噪声进行N次采样c、单次采样解析I、预测噪声I…

kitti数据显示

画出track_id publish_utils.py中 def publish_3dbox(box3d_pub, corners_3d_velos, types, track_ids):marker_array MarkerArray()for i, corners_3d_velo in enumerate(corners_3d_velos):marker Marker()marker.header.frame_id FRAME_IDmarker.header.stamp rospy.T…

Pytorch训练RCAN QAT超分模型

Pytorch训练RCAN QAT超分模型 版本信息测试步骤准备数据集创建容器生成文件列表创建文件列表的代码执行脚本,生成文件列表训练RCAN模型准备工作修改开源代码编写训练代码执行训练脚本可视化本文以RCAN超分模型为例,演示了QAT的训练过程,步骤如下: 先训练FP32模型再加载FP32训练…

量子计算学习经验

推荐B站冉仕举老师视频&#xff08;老师讲的详细又耐心&#xff0c;张量网络做量子计算&#xff0c;不过有些基础概念都是通用的&#xff09; StringCNU的个人空间-StringCNU个人主页-哔哩哔哩视频 2《量子计算与量子信息》是经典的教材书的&#xff0c;但是大部分同学第一次看…

【随笔】固态硬盘数据删除无法恢复(开启TRIM),注意数据备份

文章目录 一、序二、机械硬盘和固态硬盘的物理结构与工作原理2.1 机械硬盘2.11 基本结构2.12 工作原理 2.2 固态硬盘2.21 基本结构2.22 工作原理 三、机械硬盘和固态硬盘的垃圾回收机制3.1 机械硬盘GC3.2 固态硬盘GC3.3 TRIM指令开启和关闭 四、做好数据备份 一、序 周末电脑突…

数据库设计过程中的各种模式

在数据库设计过程中&#xff0c;有几种常见的模式&#xff0c;它们有助于组织和管理数据。以下是这几种模式的简介&#xff1a; 主扩展模式&#xff08;也称为主从模式&#xff09;&#xff1a;这种模式适用于多个表具有相似结构的情况。这些表共享某些基本属性&#xff08;也…

备战蓝桥之二分

二分题目&#xff1a; B3880 [信息与未来 2015] 买木头 - 洛谷 | 计算机科学教育新生态 (luogu.com.cn) import java.io.BufferedReader; import java.io.IOException; import java.io.InputStreamReader; import java.io.PrintWriter; import java.security.PublicKey; impor…

【Qt学习】QLineEdit 控件 属性与实例(登录界面,验证密码,正则表达式)

文章目录 1. 介绍2. 实例使用2.1 登录界面2.2 对比两次密码是否相同2.3 通过按钮显示当前输入的密码&#xff08;并对2.2进行优化&#xff09;2.4 结语 3. 正则表达式3.1 QRegExp3.2 验证输入内容 4. 资源代码 1. 介绍 关于 QLineEdit 的详细介绍&#xff0c;可以去查阅官方文…

[计算机网络]--IP协议

前言 作者&#xff1a;小蜗牛向前冲 名言&#xff1a;我可以接受失败&#xff0c;但我不能接受放弃 如果觉的博主的文章还不错的话&#xff0c;还请点赞&#xff0c;收藏&#xff0c;关注&#x1f440;支持博主。如果发现有问题的地方欢迎❀大家在评论区指正 目录 一、IP协议…

202432读书笔记|《泰戈尔的诗》——什么事让你大笑,我生命的小蓓蕾

202432读书笔记|《泰戈尔的诗》——什么事让你大笑&#xff0c;我生命的小蓓蕾 《泰戈尔写给孩子的诗&#xff08;中英双语版&#xff09;》作者拉宾德拉纳特泰戈尔文 张王哲图&#xff0c;图文并茂的一本书&#xff0c;文字与图画都很美&#xff0c;相得益彰&#xff01;很值得…

【Memory协议栈】EEPROM Abstraction模块详细介绍

目录 前言 正文 1.功能简介 2.关键概念 3.功能详解 3.1 Addressing scheme and segmentation 3.2 Address calculation 3.3 Limitation of erase / write cycles 3.4 Handling of “immediate” data 3.5 Managing block consistency information 4.关键API定义 4.…

学习磁盘管理

文章目录 一、磁盘接口类型二、磁盘设备的命名三、fdisk分区四、自动挂载五、扩容swap六、GPT分区七、逻辑卷管理八、磁盘配额九、RAID十、软硬链接 一、磁盘接口类型 IDE、SATA、SCSI、SAS、FC&#xff08;光纤通道&#xff09; IDE, 该接口是并口。SATA, 该接口是串口。SCS…

Linux笔记--文件内容的查阅与统计

一、文件内容的查阅 1.cat指令 concatenate&#xff0c;连接文件并打印到标准输出设备上(查看文件) &#xff08;1&#xff09; #cat文件的路径 常用选项: -n列出行号 &#xff08;2&#xff09;#tac 含义:倒序显示&#xff08;应用:查看日志) 2. head指令 查看一个文件的前n行…

golang学习2,golang开发配置国内镜像

go env -w GO111MODULEon go env -w GOPROXYhttps://goproxy.cn,direct

npm已经配置淘宝源仍然无法使用

使用npm命令安装Taro框架的时候&#xff0c;尽管已经设置淘宝源但是仍然无法下载&#xff0c;提示错误 >npm ERR! code CERT_HAS_EXPIRED npm ERR! errno CERT_HAS_EXPIRED npm ERR! request to https://registry.npm.taobao.org/cnpm failed, reason: certificate h…

K8S部署Java项目(Gitlab CI/CD自动化部署终极版)

天行健&#xff0c;君子以自强不息&#xff1b;地势坤&#xff0c;君子以厚德载物。 每个人都有惰性&#xff0c;但不断学习是好好生活的根本&#xff0c;共勉&#xff01; 文章均为学习整理笔记&#xff0c;分享记录为主&#xff0c;如有错误请指正&#xff0c;共同学习进步。…

websocket入门及应用

websocket When to use a HTTP call instead of a WebSocket (or HTTP 2.0) WebSocket 是基于TCP/IP协议&#xff0c;独立于HTTP协议的通信协议。WebSocket 是双向通讯&#xff0c;有状态&#xff0c;客户端一&#xff08;多&#xff09;个与服务端一&#xff08;多&#xff09…

代码随想录刷题第43天

第一题是最后一块石头的重量IIhttps://leetcode.cn/problems/last-stone-weight-ii/&#xff0c;没啥思路&#xff0c;直接上题解了。本题可以看作将一堆石头尽可能分成两份重量相似的石头&#xff0c;于是问题转化为如何合理取石头&#xff0c;使其装满容量为石头总重量一半的…

【AI Agent系列】【MetaGPT多智能体学习】0. 环境准备 - 升级MetaGPT 0.7.2版本及遇到的坑

之前跟着《MetaGPT智能体开发入门课程》学了一些MetaGPT的知识和实践&#xff0c;主要关注在MetaGPT入门和单智能体部分&#xff08;系列文章附在文末&#xff0c;感兴趣的可以看下&#xff09;。现在新的教程来了&#xff0c;新教程主要关注多智能体部分。 本系列文章跟随《M…

五种主流数据库:常用字符函数

SQL 字符函数用于字符数据的处理&#xff0c;例如字符串的拼接、大小写转换、子串的查找和替换等。 本文比较五种主流数据库常用数值函数的实现和差异&#xff0c;包括 MySQL、Oracle、SQL Server、PostgreSQL 以及 SQLite。 字符函数函数功能MySQLOracleSQL ServerPostgreSQ…