从0开始深度学习(19)——参数管理

在选择了模型架构,并设置了超参数之后,就进入了训练阶段,此时,我们的目标是找到使损失函数最小化的模型参数值。 经过训练后,我们将需要使用这些参数来做出未来的预测。
此外,有时我们希望提取参数,以便在其他环境中复用它们, 将模型保存下来,以便它可以在其他软件中执行, 或者为了获得科学的理解而进行检查。

本章将介绍:

  • 访问参数、用于调试、诊断和可视化
  • 参数初始化
  • 在不同模型组件间共享参数

以单隐藏层的多层感知机为例:

import torch
from torch import nnnet = nn.Sequential(nn.Linear(4, 8), nn.ReLU(), nn.Linear(8, 1))
X = torch.rand(size=(2, 4))
net(X)

运行结果:
在这里插入图片描述

1 参数访问

从已有模型中访问参数。 当通过Sequential类定义模型时, 我们可以通过索引来访问模型的任意层。 这就像模型是一个列表一样,每层的参数都在其属性中。 如下所示,我们可以检查第二个全连接层的参数

print(net[2].state_dict())

运行结果:
在这里插入图片描述
输出了权重矩阵偏置

2 目标参数

也可以选择性访问,如下代码:

print(type(net[2].bias))
print(net[2].weight)
print(net[2].bias.data)

运行结果:
在这里插入图片描述
或者使用下面的代码:

net.state_dict()['2.bias'].data

运行结果:
在这里插入图片描述

3 一次性访问所有参数

当我们需要对所有参数执行操作时,逐个访问它们可能会很麻烦,因为需要递归整个树来提取每个子块的参数,如下代码:

print(*[(name, param.shape) for name, param in net[0].named_parameters()])
print(*[(name, param.shape) for name, param in net.named_parameters()])
  • .named_parameters():调用这个方法会返回一个迭代器,其中每个元素都是一个元组,包含两个值:参数的名字(字符串类型)和参数本身(一个张量)。
  • *操作符:在print函数前使用*是为了将列表中的元素解包,这样print函数就可以直接打印出列表中的每个元素,而不是整个列表对象。

运行结果
在这里插入图片描述

4 从嵌套快收集参数

如果我们有多个块互相嵌套,那如何获取呢?我们先假设一个嵌套网络:

def block1():return nn.Sequential(nn.Linear(4, 8), nn.ReLU(),nn.Linear(8, 4), nn.ReLU())def block2():net = nn.Sequential()for i in range(4):# 在这里嵌套net.add_module(f'block {i}', block1())return netrgnet = nn.Sequential(block2(), nn.Linear(4, 1))
rgnet(X)

设计好网络之后,我们可以通过print查看网络结果:

print(rgnet)

在这里插入图片描述
假设我们访问(0)块中的block 0块中的(0)层的偏置项

rgnet[0][1][0].bias.data

在这里插入图片描述

5 参数初始化

默认情况下,PyTorch会根据一个范围均匀地初始化权重和偏置矩阵, 这个范围是根据输入和输出维度计算出的。 PyTorch的nn.init模块提供了多种预置初始化方法。

5.1 内置初始化

让我们首先调用内置的初始化器。 下面的代码将所有权重参数初始化为标准差为0.01的高斯随机变量, 且将偏置参数设置为0。

def init_normal(m):if type(m) == nn.Linear:nn.init.normal_(m.weight, mean=0, std=0.01)nn.init.zeros_(m.bias)
net.apply(init_normal)
net[0].weight.data[0], net[0].bias.data[0]

我们还可以将所有参数初始化为给定的常数,比如初始化为1。

def init_constant(m):if type(m) == nn.Linear:nn.init.constant_(m.weight, 1)nn.init.zeros_(m.bias)
net.apply(init_constant)
net[0].weight.data[0], net[0].bias.data[0]

我们还可以对某些块应用不同的初始化方法。 例如,下面我们使用Xavier初始化方法初始化第一个神经网络层, 然后将第三个神经网络层初始化为常量值42。

def init_xavier(m):if type(m) == nn.Linear:nn.init.xavier_uniform_(m.weight)
def init_42(m):if type(m) == nn.Linear:nn.init.constant_(m.weight, 42)net[0].apply(init_xavier)# 对第一个使用Xavier初始化
net[2].apply(init_42)# 对第三个使用常量初始化
print(net[0].weight.data[0])
print(net[2].weight.data)

5.2 自定义初始化

如果框架中没有我们需要的初始化方法,则需要自定义初始化,例如我们将使用下面的分布做初始化: w ∼ { U ( 5 , 10 ) 可能性  1 4 0 可能性  1 2 U ( − 10 , − 5 ) 可能性  1 4 \begin{split}\begin{aligned} w \sim \begin{cases} U(5, 10) & \text{ 可能性 } \frac{1}{4} \\ 0 & \text{ 可能性 } \frac{1}{2} \\ U(-10, -5) & \text{ 可能性 } \frac{1}{4} \end{cases} \end{aligned}\end{split} w U(5,10)0U(10,5) 可能性 41 可能性 21 可能性 41
即:
情况1:

  • w 从均匀分布 U ( 5 , 10 ) 中取值 w从均匀分布U(5,10)中取值 w从均匀分布U(5,10)中取值
  • 概率为 1 4 概率为\frac{1}{4} 概率为41

情况2:

  • w = 0 w=0 w=0
  • 概率为 1 2 概率为\frac{1}{2} 概率为21

情况3:

  • w 从均匀分布 U ( − 10 , − 5 ) 中取值 w从均匀分布U(-10,-5)中取值 w从均匀分布U(10,5)中取值
  • 概率为 1 4 概率为\frac{1}{4} 概率为41

使用下面代码来展示:

def my_init(m):if type(m) == nn.Linear:print("Init", *[(name, param.shape)for name, param in m.named_parameters()][0])nn.init.uniform_(m.weight, -10, 10)m.weight.data *= m.weight.data.abs() >= 5net.apply(my_init)
net[0].weight[:2]

1、nn.init.uniform_(m.weight, -10, 10)将权重参数m.weight初始化为从均匀分布中 U ( − 10 , 10 ) U(-10,10) U(10,10)中抽取的值

2、然后通过m.weight.data *= m.weight.data.abs() >= 5这串代码实现了概率分布的效果,让我们逐步分析:

  • m.weight.data.abs():这个操作计算 m.weight 中每个元素的绝对值,返回一个与 m.weight形状相同的张量,其中每个元素是原权重的绝对值。
  • m.weight.data.abs() >= 5:这个操作生成一个布尔张量,其中每个元素表示对应位置的权重绝对值是否大于或等于 5。结果是一个与 m.weight 形状相同的布尔张量,值为 True 或 False。
  • m.weight.data *= m.weight.data.abs() >= 5:这个操作将 m.weight 中的每个元素与其对应的布尔值相乘。在 Python 和 PyTorch 中,布尔值 True 被视为 1,False 被视为 0。
    因此,如果某个权重的绝对值大于或等于 5,布尔值为 True,乘法结果不变;如果某个权重的绝对值小于 5,布尔值为 False,乘法结果为 0。

概率分析:

  • 在 -10 到 10 之间的均匀分布中,权重落在 -10 到 -5 之间的概率是 25%(因为区间长度为 5,总区间长度为 20)。
  • 权重落在 5 到 10 之间的概率是 25%
  • 权重落在 -5 到 5 之间的概率是 50%(因为区间长度为 10,总区间长度为 20)。

经过 m.weight.data *= m.weight.data.abs() >= 5 操作后:

  • 权重在 -10 到 -5 之间的概率仍然是 25%,因为这些权重被保留。
  • 权重在 5 到 10 之间的概率仍然是 25%,因为这些权重被保留。
  • 权重在 -5 到 5 之间的概率变为 0,因为这些权重被设置为 0,因此,权重为 0 的概率是 50%。

6 参数绑定

有时我们希望在多个层间共享参数: 我们可以定义一个稠密层,然后使用它的参数来设置另一个层的参数。

# 我们需要给共享层一个名称,以便可以引用它的参数
shared = nn.Linear(8, 8)
net = nn.Sequential(nn.Linear(4, 8), nn.ReLU(),shared, nn.ReLU(),shared, nn.ReLU(),nn.Linear(8, 1))
net(X)
# 检查参数是否相同
print(net[2].weight.data[0] == net[4].weight.data[0])
net[2].weight.data[0, 0] = 100
# 确保它们实际上是同一个对象,而不只是有相同的值
print(net[2].weight.data[0] == net[4].weight.data[0])

输出结果:
在这里插入图片描述
共享参数通常可以节省内存

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.mzph.cn/diannao/57559.shtml

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

背包九讲——完全背包问题

目录 完全背包问题 问题定义 动态规划解法 状态转移方程 初始化 遍历顺序 三种解法: 朴素版——枚举k 进阶版——dp正推(一维滚动数组) 背包问题第三讲——完全背包问题 背包问题是一类经典的组合优化问题,通常涉及在限定…

【Linux笔记】Linux命令与使用

博文将不断学习补充 学习参考博文: Linux命令大全:掌握常用命令,轻松使用Linux操作系统-CSDN博客 文件或目录操作命令 zip # zip是使用最多的文档压缩格式 # 方便跨平台使用,但是压缩率不是很高 zip指令未安装 安装zip yum ins…

python实战项目47:Selenium采集百度股市通数据

python实战项目47:Selenium采集百度股市通数据 一、思路分析二、完整代码一、思路分析 这里以获取百度股市通股评下的投票数据为例,页面中的其他数据同理。由于此页面数据是js动态加载的,所以采用Selenium获取数据。思路很简单,通过Selenium打开页面,然后定位到“股评”选…

没有B柱?极氪MIX太大胆了!

文 | AUTO芯球 作者 | 雷慢 极氪又给国产车长脸了, 极氪MIX上市,创造了多个行业先例, 估计把合资看得一愣一愣的, 哪见过这样的每月都有新技术、黑科技冒出来, 我看完整个发布会就一个感想, 家里有小…

数据结构——哈夫曼树及其应用(哈夫曼编码)

判断树:用来描述分类过程的二叉树 哈夫曼树(最优二叉树)的基本概念 路径:从树中一个结点到另一个结点之间的分支构成这两个结点间的路径。 结点的路径长度:两结点间路径上的分支数。 结点的路径长度计算&#xff1…

PDF文件为什么不能编辑是?是啥原因导致的,有何解决方法

PDF文件格式广泛应用于工作中,但有时候我们可能遇到无法编辑PDF文件的情况。这可能导致工作效率降低,特别是在需要修改文件内容时显得尤为棘手。遇到PDF不能编辑时,可以看看是否以下3个原因导致的。 一、文件受保护 有些PDF文件可能被设置了…

leetcode动态规划(十二)-最后一块石头的重量

题目 1049.最后一块石头的重量 有一堆石头&#xff0c;用整数数组 stones 表示。其中 stones[i] 表示第 i 块石头的重量。 每一回合&#xff0c;从中选出任意两块石头&#xff0c;然后将它们一起粉碎。假设石头的重量分别为 x 和 y&#xff0c;且 x < y。那么粉碎的可能结…

矩阵matrix

点积 在 NumPy 中&#xff0c;dot 是矩阵或向量的点积&#xff08;dot product&#xff09;操作。 假设有两个向量a和 b&#xff0c;它们的点积定义为对应元素相乘&#xff0c;然后求和。公式如下&#xff1a; 例子&#xff1a; 点积的计算步骤是&#xff1a; 因此&#xf…

入门 | Prometheus+Grafana 普罗米修斯

一、prometheus介绍 1、监控系统组成 一个完整的监控系统需要包括如下功能&#xff1a;数据产生、数据采集、数据存储、数据处理、数据展示、分析、告警等。 &#xff08;1&#xff09;、数据来源 数据来源&#xff0c;也就是需要监控的数据。数据常见的产生、直接或间接暴露…

服务器磁盘爆满?别慌,教你轻松清理!

服务器磁盘爆满&#xff1f;别慌&#xff0c;教你轻松清理&#xff01; 简介 服务器磁盘空间告急&#xff0c;网站访问缓慢&#xff0c;甚至无法正常运行&#xff1f;别担心&#xff0c;这篇文章将为你提供一份详细的清理指南&#xff0c;帮助你快速释放服务器磁盘空间&#x…

【算法】Bellman-Ford单源最短路径算法(附动图)

目录 一、性质 二、思路 三、有边路限制的最短路 一、性质 适用于含有负权边的图&#xff08;Dijkstra不适用&#xff09; 更简单&#xff0c;但效率慢 如果对应路径存在负权回路则没有最短路径&#xff08;可用于判断图中是否存在负权回路&#xff09; 相比于spfa&#…

[分享] SQL在线编辑工具(好用)

在线SQL编写工具&#xff08;无广告&#xff09; - 在线SQL编写工具 - Web SQL - SQL在线编辑格式化 - WGCLOUD

物联网实训项目:绿色家居套件

1、基本介绍 绿色家居通过物联网技术将家中的各种设备连接到一起&#xff0c;提供家电控制、照明控制、电话远程控制、室内外遥控、防盗报警、环境监测、暖通控制、红外转发以及可编程定时控制等多种功能和手段。绿色家居提供全方位的信息交互功能&#xff0c;甚至为各种能源费…

solana phantom NFT图片显示不出来?

solana phantom NFT图片显示不出来&#xff1f; 问题 同样是jpeg格式图片&#xff0c;一个phatom可以显示&#xff0c;一个不可以显示为什么&#xff0c;nft图片格式大小有要求吗&#xff1f; 问题分析 Phantom 官网有一些关于 NFT 集成的文档,其中可能会有关于图片大小限制…

049_python基于Python的热门微博数据可视化分析

目录 系统展示 开发背景 代码实现 项目案例 获取源码 博主介绍&#xff1a;CodeMentor毕业设计领航者、全网关注者30W群落&#xff0c;InfoQ特邀专栏作家、技术博客领航者、InfoQ新星培育计划导师、Web开发领域杰出贡献者&#xff0c;博客领航之星、开发者头条/腾讯云/AW…

15分钟学Go 第7天:控制结构 - 条件语句

第7天&#xff1a;控制结构 - 条件语句 在Go语言中&#xff0c;控制结构是程序逻辑的重要组成部分。通过条件语句&#xff0c;我们可以根据不同的条件采取不同的行动。今天我们将详细探讨Go语言中的两种主要条件结构&#xff1a;if语句和switch语句。理解这些控制结构对于编写…

CTA-GAN:基于生成对抗网络对颈动脉和主动脉的非增强CT影像进行血管增强

写在前面 目前只分析了文章的大体内容和我个人认为的比较重要的细节&#xff0c;代码实现还没仔细看&#xff0c;后续有时间会补充代码细节部分。 文章地址&#xff1a;Generative Adversarial Network-based Noncontrast CT Angiography for Aorta and Carotid Arteries 代…

JAVA基础面试题准备

一些常见的JAVA基础题&#xff0c;面试中遇到过的会加*显示。 JAVA基础 1.Java中重载和重写的区别&#xff1f;* 2.int 和Integer类型这两个区别吗&#xff1f; 为什么需要有Integer类型&#xff1a; int和Integer类型的区别&#xff1a; 3.遍历list有那些方式吗&#xff1f;…

【Linux】进程信号(下)

目录 一、信号的阻塞 1.1 信号在内核中的保存方式 1.2 sigset_t信号集 &#xff08;1&#xff09;信号集操作 &#xff08;2&#xff09;sigprocmask函数 &#xff08;3&#xff09;sigpending函数 二、信号的处理 2.1 用户态和内核态 2.2 重谈进程地址空间 三、信号…

盘点2024年4款高清稳定的Windows10录屏工具。

Windows10电脑录屏在生活当中还是挺重要的&#xff0c;无论是教育领域的制作教程&#xff0c;还是游戏玩家记录精彩瞬间&#xff0c;亦或是商务人士进行演示&#xff0c;录屏都能发挥巨大作用。如果设备自带的一些工具无法完成录屏需求的话&#xff0c;这里帮大家找了几款好用到…