state_dict使用详解

     在PyTorch中,state_dict是一个非常重要的概念,它是一个包含模型参数的字典对象。每个模型的state_dict都包含了该模型的所有参数(权重和偏置等),用于在训练和推理过程中重现模型的内部状态.

      pytorch 中的 state_dict 是一个简单的python的字典对象,将每一层与它的对应参数建立映射关系.(如 model的每一层的weights及偏置等等) (注意,只有那些参数可以训练的layer才会被保存到模型的state_dict中,如卷积层,线性层等等) 优化器对象Optimizer也有一个state_dict,它包含了优化器的状态以及被使用的超参数(如lr, momentum,weight_decay等)

1. 保存模型参数

        使用torch.save(model.state_dict(), PATH)可以将state_dict保存到指定路径. 常用的保存 state_dict的格式是".pt"或’.pth’的文件,即下面命令的 PATH="./***.pt". 但是文件名字不影响,只是大家大家默认这个名字有辨识度,你取***.sp照样不影响.

torch.save(model.state_dicr(),PATH)  # PATH为存储的位置例如: path/best.pth

2.初始化模型

       即初始化模型的参数, 使用model.load_state_dict(torch.load(PATH))可以重新加载模型。

modle = MyModel(*args, **kwargs)
model.load_state_dict(torch.load(PATH)

3.取出或更新某一层参数

       前面说了state_dict()中的参数是按字典存取,即每个层都有一个key值索引, 所以按照字典规则取出该值即可. 现在假设某层的名字为 conv1.weight.

weight_data = torch.load('./model_state_dict.pt')['conv1.weight']

        修改某一层的值

# 假设 model 是一个已经初始化的模型  
# 更改第一层的权重  
model.state_dict()['layer1.weight'] = torch.randn(10, 10)

     在训练过程中,state_dict还用于存储梯度信息。在反向传播过程中,PyTorch会通过state_dict来更新模型参数.

4.控制model的某层是否需要梯度求导

加载模型参数后,如何设置某层某参数的"是否需要训练"(param.requires_grad)

for param in list(mode.pretrained.parameters()):param.requires_grad = True

5.手写网络层及state_dict()使用例子

import torch
import torch.nn as nn
import torch.optim as optim
import torchvision
import numpy as mp
import matplotlib.pyplot as plt
import torch.nn.functional as F
#define model
class TheModelClass(nn.Module):def __init__(self):super(TheModelClass,self).__init__()self.conv1=nn.Conv2d(3,6,5)self.pool=nn.MaxPool2d(2,2)self.conv2=nn.Conv2d(6,16,5)self.fc1=nn.Linear(16*5*5,120)self.fc2=nn.Linear(120,84)self.fc3=nn.Linear(84,10)def forward(self,x):x=self.pool(F.relu(self.conv1(x)))x=self.pool(F.relu(self.conv2(x)))x=x.view(-1,16*5*5)x=F.relu(self.fc1(x))x=F.relu(self.fc2(x))x=self.fc3(x)return xdef main():# Initialize modelmodel = TheModelClass()#Initialize optimizeroptimizer=optim.SGD(model.parameters(),lr=0.001,momentum=0.9)#print model's state_dictprint('Model.state_dict:')for param_tensor in model.state_dict():#打印 key value字典print(param_tensor,'\t',model.state_dict()[param_tensor].size())#print optimizer's state_dictprint('Optimizer,s state_dict:')for var_name in optimizer.state_dict():print(var_name,'\t',optimizer.state_dict()[var_name])if __name__=='__main__':main()

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.mzph.cn/news/198224.shtml

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

MySQL-含json字段表和与不含json字段表查询性能对比

含json字段表和与不含json字段表查询性能对比 说明: EP_USER_PICTURE_INFO_2:不含json字段表 20200729json_test:含有json字段表 其中20200729json_test 标准ID、MANAGER_NO、PHONE_NO 为非json字段 data为json字段 2个表中MANAGER_NO、PHONE_NO都创建了各自的索引 测试…

CUDA简介, 配置和运行第一个CUDA程序(Windows和Linux)

CUDA简介 CUDA(Compute Unified Device Architecture)是由NVIDIA开发的一种通用并行计算架构。CUDA允许程序员利用NVIDIA GPU的并行计算能力,加速各种计算密集型应用程序。 CUDA技术基于GPU的并行计算原理。传统的CPU处理器拥有少量的核心&…

js中继承的方法

前言: 本人刚写了一篇原型链的封装继承多态,用家有儿女做的demo。其实我个人感觉封装和多态都容易去理解与实现。关键在于继承,js的才是比较难的,也容易让人混乱,至少我是因为继承头大过\(^o^)/~ js中有很多方法可以实现继承,这篇文章主要对继承的方法进行学习与测试。 这里…

[STM32-1.点灯大师上线】

学习了江协科技的前4课,除了打开套件的第一秒是开心的,后面的时间都是在骂娘。因为51的基础已经几乎忘干净,c语言已经还给谭浩强,模电数电还有点底子,硬着头皮上吧。 本篇主要是讲述学习点灯的过程和疑惑解释。 1.工…

【3】密评-物理和环境安全测评

0x01 依据 GB/T 39786 -2021《信息安全技术 信息系统密码应用基本要求》针对等保三级系统要求: 物理和环境层面: a)宜采用密码技术进行物理访问身份鉴别,保证重要区域进入人员身份的真实性; b)宜采用密码技术保证电子门…

[HTML]Web前端开发技术7(HTML5、CSS3、JavaScript )CSS的定位机制——喵喵画网页

希望你开心,希望你健康,希望你幸福,希望你点赞! 最后的最后,关注喵,关注喵,关注喵,佬佬会看到更多有趣的博客哦!!! 喵喵喵,你对我真的…

电音制作入门软件FL Studio21.2.0最新永久免费版

FL Studio是一款出色的编曲软件,最新版本的FL Studio21新增了四款全新的插件,覆盖了音频设计、延迟、相位器等等。通过软件的不断更新,我们可以享受到更加智能的电子音乐创作工具,目前,FL Studio的正式版已经推出了超过…

内核启动时间信息打印

文章目录 一 串口打印1 借助串口助手2 dmesg自带时间3 内核显示时间信息4 借助initcall_debug二 图形花显示1 bootgraph工具使用2 Bootchart工具使用3 Grabserial工具使用一 串口打印 1 借助串口助手 2 dmesg自带时间 root@xboard:~# dmesg [ 0.000000] Booting Linux on …

操作系统概论:揭秘计算机背后的神秘力量

操作系统概论 & 功能 概述定义操作系统功能作为系统资源的管理者向上层提供方便易用的服务作为最接近硬件的层次 主页传送门:📀 传送 概述 概念: 定义 控制和管理计算机硬件和软件资源的程序一种系统软件为上层用户、应用程序提供简单易…

uniapp开发小程序经验记录

uniapp开发小程序的过程中会遇到很多问题,这里记录一下相关工具优化,便于后来者参考。 每次保存代码后,小程序都跳回首页 针对这个问题,常规的做法就是修改pages配置文件,但是这种方式不便于路由参数的设置&#xff…

某60区块链安全之JOP实战一学习记录

区块链安全 文章目录 区块链安全Jump Oriented Programming实战一实验目的实验环境实验工具实验原理实验内容Jump Oriented Programming实战一 实验步骤分析合约源代码漏洞Jump Oriented Programming实战一 实验目的 学会使用python3的web3模块 学会分析以太坊智能合约中中Ju…

CPP-SCNUOJ-Problem P24. [算法课贪心] 跳跃游戏

Problem P24. [算法课贪心] 跳跃游戏 给定一个非负整数数组 nums ,你最初位于数组的 第一个下标 。 数组中的每个元素代表你在该位置可以跳跃的最大长度 判断你是否能够到达最后一个下标。 输入 输入一行数组nums 输出 输出true/fasle 样例 标准输入 2 3 1 …

【Wireshark工具使用】Wireshark无法抓取TwinCAT的EtherCAT包(已解决)

写在前面 因项目需要,近期在在深入研究EtherCAT协议,之后会将协议做一个系统的总结,分享在这个分栏。在研究EtherCAT协议帧时,使用了一个网络数据分析工具Wireshark,本文是关于EtherCAT数据帧分析工具使用中遇到的一个…

【设计模式】策略模式设计-电影票打折功能

任务二:使用策略模式设计电影票打折功能 某电影院售标系统为不同类型的用户提供了不同的打折方式(Discount),学生凭学生证可享受8折优惠**(StudentDiscount),儿童可享受减免10元的优惠&#xf…

「Verilog学习笔记」时钟分频(偶数)

专栏前言 本专栏的内容主要是记录本人学习Verilog过程中的一些知识点,刷题网站用的是牛客网 timescale 1ns/1nsmodule even_div(input wire rst ,input wire clk_in,output wire clk_out2,output wire clk_out4,output wire clk_out8); //********…

新华三数字大赛复赛知识点 VLAN基本技术

VLAN IEEE 802.1Q 交换机端口类型 MVRP协议 VLAN Virtual LAN虚拟局域网。LAN可以是由几台少数家用计算机构成的网络,也可以是数以百计的计算机构成的企业网络。VLAN所指的LAN特指使用路由器分割的网络–也就是广播域。将一个物理的局域网在逻辑上划分成多个广播域…

苹果IOS在Safari浏览器中将网页添加到主屏幕做伪Web App,自定义图标,启动动画,自定义名称,全屏应用pwa

在ios中我们可以使用Safari浏览自带的将网页添加到主屏幕上,让我们的web页面看起来像一个本地应用程序一样,通过桌面APP图标一打开,直接全屏展示,就像在APP中效果一样,完全体会不到你是在浏览器中。 1.网站添加样式 在…

时间复杂度为 O(n^2) 的排序算法 | 京东物流技术团队

对于小规模数据,我们可以选用时间复杂度为 O(n2) 的排序算法。因为时间复杂度并不代表实际代码的执行时间,它省去了低阶、系数和常数,仅代表的增长趋势,所以在小规模数据情况下, O(n2) 的排序算法可能会比 O(nlogn) 的…

Stable Diffusion教程:4000字说清楚图生图

原文:Stable Diffusion教程:4000字说清楚图生图 - 知乎 目录 收起 基本使用 涂鸦绘制 局部绘制 局部绘制(涂鸦蒙版) 局部绘制(上传蒙版) 批量处理 总结 资源下载 “图生图”是 Stable Diffusion…

【Android知识笔记】架构专题(三)

如何用工程手段,提高写代码的生产力?(元编程) 即如何写同样多的代码,花费更少的时间?如何自动生成代码,哪种代码可以被自动生成?哪些环节能够作为自动生成代码的切入点? 代码自动生成技术 代码自动生成,指的并不是让计算机凭自己的意愿生成代码。而是让预先实现好…