对pytorch optimizer中state_dict、state、param_groups的简要理解

先说结论:

  • state_dict():一个dict,里面有两个key(stateparam_groups),

    • state这个key对应的value是各个权重对应的优化器状态。具体来说,一个model有很多权重,model.parameters()会打印出该模型的各层的权重,比如使用Adam,每层权重都有一个momentum和variance,形状与权重相同,还有该层当前更新到的步数。state_dict()['state']是一个dict,每个key-value item结构如下:
      该权重在model.parameters()中的位置 : {'step': tensor, 'exp_avg': tensor, # exp_avg: exponential moving average of gradient values'exp_avg_sq: tensor # exp_avg_sq: exponential moving average of squared gradient values
      
    • param_groups这个key对应的value是一个list,其中每个元素都是超参数组成的一个dict,因为不同的权重可以使用不同的超参数,所以需要使用list来表示,而且dict中params表示该超参数配置作用于哪些权重。state_dict()['param_groups']是一个list,每个元素结构如下
      {'lr': 0.01, 'weight_decay': 0,  ...  , 'params', [该超参数配置作用于的权重的位置]}
      
  • state:是一个defaultdict,包含的信息类似于state_dict()['state']+model.parameters(),具体来说,每个key-value item结构如下:

    param_tensor :{'step': tensor, 'exp_avg': tensor, 'exp_avg_sq': tensor,	
    }
    
  • param_groups:是一个list,包含的信息类似于state_dict()['param_groups']+model.parameters(),具体来说,每个元素结构如下:

    {'params': [param1, param2, ...]'lr': 0.01, 'weight_decay': 0, ...# 注意相较于state_dict()['param_groups'],原来'params'这个key对应的是param的索引位置,现在直接就是tensor了
    }
    

示例代码:

import torch
from torch.nn import Module
from torch.optim import Adamclass MyModel(Module):def __init__(self, in_dim, hidden_dim):super(MyModel, self).__init__()self.linear = torch.nn.Linear(in_features=in_dim, out_features=hidden_dim, bias=True)self.linear2 = torch.nn.Linear(in_features=hidden_dim, out_features=in_dim, bias=False)def forward(self, x):y = self.linear(x)out = self.linear2(y)return outin_dim = 5
hidden_dim = 2
model = MyModel(in_dim=in_dim, hidden_dim=hidden_dim)optimier = Adam([{'params': model.linear.parameters(), 'lr': 0.05},{'params': model.linear2.parameters()}
], lr=0.01)x = torch.randn((in_dim))
out = model(x)
loss = torch.sum(out, dim=-1)
optimier.zero_grad()
loss.backward()
optimier.step()print('#' * 100)
print(optimier.state_dict())print('#' * 100)
print(optimier.state)print('#' * 100)
print(optimier.param_groups)

输出:

####################################################################################################
# state_dict()
{'state': {0: {'step': tensor(1.), 'exp_avg': tensor([[ 0.0503,  0.0738, -0.0199,  0.0365, -0.0079],[ 0.0139,  0.0204, -0.0055,  0.0101, -0.0022]]), 'exp_avg_sq': tensor([[2.5308e-04, 5.4452e-04, 3.9464e-05, 1.3313e-04, 6.2210e-06],[1.9335e-05, 4.1600e-05, 3.0150e-06, 1.0171e-05, 4.7527e-07]])}, 1: {'step': tensor(1.), 'exp_avg': tensor([0.0406, 0.0112]), 'exp_avg_sq': tensor([1.6472e-04, 1.2584e-05])}, 2: {'step': tensor(1.), 'exp_avg': tensor([[-0.0268,  0.0085],[-0.0268,  0.0085],[-0.0268,  0.0085],[-0.0268,  0.0085],[-0.0268,  0.0085]]), 'exp_avg_sq': tensor([[7.1624e-05, 7.2090e-06],[7.1624e-05, 7.2090e-06],[7.1624e-05, 7.2090e-06],[7.1624e-05, 7.2090e-06],[7.1624e-05, 7.2090e-06]])}}, 'param_groups': [{'lr': 0.05, 'betas': (0.9, 0.999), 'eps': 1e-08, 'weight_decay': 0, 'amsgrad': False, 'maximize': False, 'foreach': None, 'capturable': False, 'differentiable': False, 'fused': False, 'params': [0, 1]}, {'lr': 0.01, 'betas': (0.9, 0.999), 'eps': 1e-08, 'weight_decay': 0, 'amsgrad': False, 'maximize': False, 'foreach': None, 'capturable': False, 'differentiable': False, 'fused': False, 'params': [2]}]}####################################################################################################
# state
defaultdict(<class 'dict'>, {Parameter containing: tensor([[-0.1744, -0.0656,  0.3184, -0.2081,  0.2448],[ 0.3069, -0.4000, -0.0727,  0.3283,  0.1722]], requires_grad=True): {'step': tensor(1.), 'exp_avg': tensor([[ 0.0503,  0.0738, -0.0199,  0.0365, -0.0079],[ 0.0139,  0.0204, -0.0055,  0.0101, -0.0022]]), 'exp_avg_sq': tensor([[2.5308e-04, 5.4452e-04, 3.9464e-05, 1.3313e-04, 6.2210e-06],[1.9335e-05, 4.1600e-05, 3.0150e-06, 1.0171e-05, 4.7527e-07]])}, Parameter containing: tensor([ 0.1764, -0.1476], requires_grad=True): {'step': tensor(1.), 'exp_avg': tensor([0.0406, 0.0112]), 'exp_avg_sq': tensor([1.6472e-04, 1.2584e-05])}, Parameter containing: tensor([[-0.2588, -0.5732],[-0.2472,  0.2319],[ 0.4441, -0.6283],[ 0.5832,  0.3760],[-0.0654,  0.6558]], requires_grad=True): {'step': tensor(1.), 'exp_avg': tensor([[-0.0268,  0.0085],[-0.0268,  0.0085],[-0.0268,  0.0085],[-0.0268,  0.0085],[-0.0268,  0.0085]]), 'exp_avg_sq': tensor([[7.1624e-05, 7.2090e-06],[7.1624e-05, 7.2090e-06],[7.1624e-05, 7.2090e-06],[7.1624e-05, 7.2090e-06],[7.1624e-05, 7.2090e-06]])}}
)
####################################################################################################
# param_groups
[{'params': [Parameter containing: tensor([[-0.1744, -0.0656,  0.3184, -0.2081,  0.2448],[ 0.3069, -0.4000, -0.0727,  0.3283,  0.1722]], requires_grad=True), Parameter containing: tensor([ 0.1764, -0.1476], requires_grad=True)], 'lr': 0.05, 'betas': (0.9, 0.999), 'eps': 1e-08, 'weight_decay': 0, 'amsgrad': False, 'maximize': False, 'foreach': None, 'capturable': False, 'differentiable': False, 'fused': False}, {'params': [Parameter containing: tensor([[-0.2588, -0.5732],[-0.2472,  0.2319],[ 0.4441, -0.6283],[ 0.5832,  0.3760],[-0.0654,  0.6558]], requires_grad=True)], 'lr': 0.01, 'betas': (0.9, 0.999), 'eps': 1e-08, 'weight_decay': 0, 'amsgrad': False, 'maximize': False, 'foreach': None, 'capturable': False, 'differentiable': False, 'fused': False}
]

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.mzph.cn/diannao/49806.shtml

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

MyBatis相关问题汇总

sql预编译原理 https://www.cnblogs.com/Createsequence/p/16963891.html MyBatis一级缓存&二级缓存 mybatis的缓存机制&#xff08;一级缓存二级缓存和刷新缓存&#xff09;和mybatis整合ehcache_mybatis缓存机制-CSDN博客

每日一题~961div2A+B+C(阅读题,思维,数学log)

A 题意&#xff1a;给你 n*n 的表格和k 个筹码。每个格子上至多放一个 问至少占据多少对角线。 显然&#xff0c;要先 格数的多的格子去放。 n n-1 n-2 …1 只有n 的是一个&#xff08;主对角线&#xff09;&#xff0c;其他的是两个。 #include <bits/stdc.h> using na…

基于Java和MySQL的数据库优化技术

基于Java和MySQL的数据库优化技术 大家好&#xff0c;我是微赚淘客系统3.0的小编&#xff0c;是个冬天不穿秋裤&#xff0c;天冷也要风度的程序猿&#xff01;今天我们将探讨如何基于Java和MySQL进行数据库优化&#xff0c;提升系统的性能和稳定性。我们将从查询优化、索引使用…

管理和迁移Conda环境两种方法:conda env export 和 Conda-Pack

在管理和迁移Conda环境时&#xff0c;通常有两种常用的方法&#xff1a;conda env export 和 Conda-Pack。这两种方法各有优缺点&#xff0c;根据具体需求可以选择合适的方法。 方法一&#xff1a;Conda env export conda env export 是Conda自带的命令&#xff0c;用于导出当…

基于微信小程序图书馆座位预约管理系统设计与实现

1.1选题动因 当前的网络技术&#xff0c;软件技术等都具备成熟的理论基础&#xff0c;市场上也出现各种技术开发的软件&#xff0c;这些软件都被用于各个领域&#xff0c;包括生活和工作的领域。随着电脑和笔记本的广泛运用&#xff0c;以及各种计算机硬件的完善和升级&#x…

JS 事件循环(Event Loop)机制

事件循环机制的作用 事件循环机制是 JS 的一种执行机制&#xff0c;一种可以实现异步编程的机制。 因为 JS 是单线程的&#xff0c;单线程意味着所有任务需要排队执行。但是有一些 API&#xff08;比如&#xff1a;定时器和 Ajax 等&#xff09;是需要等待一定的时间才能得到…

【Python】一文向您详细介绍 K-means 算法

【Python】一文向您详细介绍 K-means 算法 下滑即可查看博客内容 &#x1f308; 欢迎莅临我的个人主页 &#x1f448;这里是我静心耕耘深度学习领域、真诚分享知识与智慧的小天地&#xff01;&#x1f387; &#x1f393; 博主简介&#xff1a;985高校的普通本硕&#xff…

Visual Studio 2022新建 cmake 工程测试 tensorRT 自带样例 sampleOnnxMNIST

1. 新建 cmake 工程 vs2022_cmake_sampleOnnxMNIST_test( 如何新建 cmake 工程&#xff0c;请参考博客&#xff1a;Visual Studio 2022新建 cmake 工程测试 opencv helloworld ) 2. 删除默认生成的 vs2022_cmake_sampleOnnxMNIST_test.h 头文件 3. 修改默认生成的 vs2022_cma…

【 C语言 】 C语言设计模式

一 、C语言和设计模式&#xff08;继承、封装、多态&#xff09; C有三个最重要的特点&#xff0c;即继承、封装、多态。我发现其实C语言也是可以面向对象的&#xff0c;也是可以应用设计模式的&#xff0c;关键就在于如何实现面向对象语言的三个重要属性。 &#xff08;1&…

BSV区块链在人工智能时代的数字化转型中的角色

​​发表时间&#xff1a;2024年6月13日 企业数字化转型已有约30年的历史&#xff0c;而人工智能&#xff08;以下简称AI&#xff09;将这种转型提升到了一个全新的高度。这并不难理解&#xff0c;因为AI终于使企业能够发挥其潜力&#xff0c;实现更宏大的目标。然而&#xff0…

MySQL中实现动态表单中JSON元素精准匹配的方法

目录 前言 一、动态表单技术 1、包含的主要信息 2、元素属性设置 3、表单内容 二、表单数据存储和查询 1、数据存储 2、数据的查询 3、在5.7版本中进行JSON检索 4、8.0后的优化查询 三、总结 前言 在很多有工作流设置的地方、比如需要在不同的流程中&#xff0c;需要…

什么是跨域问题及其解决方案

什么是跨域问题及其解决方案 在现代Web开发中&#xff0c;跨域问题是一个常见的挑战。了解什么是跨域问题以及如何解决它&#xff0c;对于开发者来说至关重要。在这篇博客中&#xff0c;我们将详细介绍什么是跨域问题&#xff0c;并探讨几种常用的解决方案。 什么是跨域问题&…

Docker 搭建GitLab

# 拉取镜像 docker pull gitlab/gitlab-ce # GitLab 需要持久存储来保存数据&#xff0c;如仓库数据、配置 mkdir -p /opt/gitlab/config /opt/gitlab/logs /opt/gitlab/data # 使用 docker run 命令来启动 GitLab 容器 docker run -itd \--hostname 192.168.111.128 \--p…

服务器数据恢复—V7000存储硬盘故障脱机的数据恢复案例

服务器存储数据恢复环境&#xff1a; 某品牌P740小型机AIXSybaseV7000磁盘阵列柜&#xff0c;磁盘阵列柜中有12块SAS机械硬盘&#xff08;其中包括一块热备盘&#xff09;。 服务器存储故障&#xff1a; 磁盘阵列柜中有一块磁盘出现故障&#xff0c;运维人员用新硬盘替换掉故障…

网络安全等级保护解决方案的主打产品

网络安全等级保护解决方案的主打产品&#xff1a; HiSec Insight安全态势感知系统、 FireHunter6000沙箱、 SecoManager安全控制器、 HiSecEngine USG系列防火墙和HiSecEngine AntiDDoS防御系统。 华为HiSec Insight安全态势感知系统是基于商用大数据平台FusionInsight的A…

【LeetCode】201. 数字范围按位与

1. 题目 2. 分析 这题挺难想的&#xff0c;我到现在还没想明白&#xff0c;为啥只用左区间和右区间就能找到目标值了&#xff0c;而不用挨个做与操作&#xff1f; 3. 代码 class Solution:def rangeBitwiseAnd(self, left: int, right: int) -> int:left_bin bin(left).…

lightningcss介绍及使用

lightningcss介绍及使用 一款使用 rust 编写的 css 解析器&#xff0c;转换器、及压缩器。 特性 特别快&#xff1a;可以在毫秒级别解析、压缩大量的 css 文件&#xff0c;而且比其他工具的打包结果更小给值添加类型&#xff1a;许多其他css解析器会将值解析成一个无类型的 …

k8s集群可视化工具安装(dashboard)

可视化安装 2.1、下载相关的yaml文件 wget https://raw.githubusercontent.com/kubernetes/dashboard/v2.0.0/aio/deploy/recommended.yaml Vim recommended.yaml 2.2、部署 kubectl apply -f recommended.yaml 查看那kubernetes-dashboard命令空间下的资源 kubectl get …

ZLMRTCClient配置说明与用法(含示例)

webRTC播放视频 后面在项目中会用到通过推拉播放视频流的技术&#xff0c;所以最近预研了一下webRTC 首先需要引入封装好的webRTC客户端的js文件ZLMRTCClient.js 下面是地址需要的自行下载 http://my.zsyou.top/2024/ZLMRTCClient.js 配置说明 new ZLMRTCClient.Endpoint…

技术分享!国产ARM + FPGA的SDIO通信开发介绍!

SDIO总线介绍 SDIO(Secure Digital lnput and Output),即安全数字输入输出接口。SDIO总线协议是由SD协议演化而来,它主要是对SD协议进行了一些扩展。 SDIO总线主要是为SDIO卡提供一个高速的I/O能力,并伴随着较低的功耗。SDIO总线不但支持SDIO卡,而且还兼容SD内存卡。支持…