Pytorch基础:torch.load_state_dict()方法在加载时不会检查类型

相关阅读

Pytorch基础icon-default.png?t=N7T8https://blog.csdn.net/weixin_45791458/category_12457644.html?spm=1001.2014.3001.5482


        笔者在使用torch.nn.module的load_state_dict中出现了一个问题,一个被注册的张量在加载后居然没有变化,一开始以为是加载出现了问题,但发现其他参数加载成功,思索后发现是注册的张量的类型是整型而checkpoint中保存为浮点数类型,恰好注册时的默认值给的是0,而checkpoint中的浮点数又在0到1之间,因此出现了这个令人困惑的bug。

        下面首先复现这个bug。

import torch
import torch.nn as nn# 定义一个简单的线性模型,参数类型为整数
class SimpleModel(nn.Module):def __init__(self):super(SimpleModel, self).__init__()self.register_buffer('test', torch.tensor(0)) # 注册一个整型张量# 创建一个简单模型实例
model = SimpleModel()# 创建一个浮点数作为参数
float_parameter = torch.tensor(0.6)# 将注册名指向另一个浮点型张量
model.test = float_parameter# 保存模型
torch.save(model.state_dict(), 'model.pth')# 直接使用原模型加载
checkpoint = torch.load('model.pth')
model.load_state_dict(checkpoint)# 打印加载后的参数
print(model.test)# 直接使用新模型加载
model_1 = SimpleModel()
model_1.load_state_dict(checkpoint)# 打印加载后的参数
print(model_1.test)
输出:
tensor(0.6000)
tensor(0)

        可以看到,当模型中注册的名字(test),指向了一个类型不符的张量后,并不会导致浮点型张量被截断为整型,这是因为此处是直接使用赋值号=,使名字指向了另一个张量。

        但使用load_state_dict()方法与使用赋值号是不同的,load_state_dict()方法的实现中,调用了_load_from_state_dict()方法,其中调用了copy_()方法,进行了原位(in-place)数据替换,这可能会进行截断,下面是原位替换的一个例子。

import torch# 创建两个张量
a = torch.tensor([[1, 2], [3, 4]])
b = torch.tensor([[5.1, 6.1], [7.1, 8.1]])# 查看张量对象的id
print(id(a))
print(id(b))# 查看底层存储的内存地址
print(a.storage().data_ptr())
print(b.storage().data_ptr())# 将张量 b 中的值复制到张量 a 中
a.copy_(b)# 打印复制后的结果
print(a)# 查看张量对象的id
print(id(a))
print(id(b))# 查看底层存储的内存地址
print(a.storage().data_ptr())
print(b.storage().data_ptr())
输出:
2604425272672
2604426953808  
2604511348096  
2602930352832  
tensor([[5, 6],[7, 8]])
2604425272672
2604426953808
2604511348096
2602930352832

        在保存了模型的状态字典后,使用load_state_dict()方法加载后,也不会有任何截断问题,因为对于原模型而言,名字test指向的是一个浮点型张量,此时原位替换,类型吻合。但是对于一个新的模型,此时的test指向的是一个整型张量,此时原位替换,会发生截断。

        因此,在注册一个张量时,需要确保其在注册时和保存时的类型吻合,此处除了指形状,还有类型,否则可能会出现意想不到的bug。

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.mzph.cn/news/831725.shtml

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

漏洞挖掘之某厂商OAuth2.0认证缺陷

0x00 前言 文章中的项目地址统一修改为: a.test.com 保护厂商也保护自己 0x01 OAuth2.0 经常出现的地方 1:网站登录处 2:社交帐号绑定处 0x02 某厂商绑定微博请求包 0x02.1 请求包1: Request: GET https://www.a.test.com/users/auth/weibo?…

SpringCloud微服务:Eureka 和 Nacos 注册中心

共同点 都支持服务注册和服务拉取都支持服务提供者心跳方式做健康检测 不同点 Nacos 支持服务端主动检测提供者状态:临时实例采用心跳模式,非临时(永久)实例采用主动检测模式Nacos 临时实例心跳不正常会被剔除,非临时实…

【C++基础】缺省参数

一&#xff0c;缺省参数概念 缺省参数是声明或定义一个函数时为函数的参数指定一个缺省值。 简单来说就是在定义函数的时候可以给形参赋一个初始化的值&#xff0c;这个值就叫做缺省值。 例&#xff1a; void Func(int a0) { cout<<a<<end1; } int main() { Fun…

深度学习中权重初始化的重要性

深度学习模型中的权重初始化经常被人忽略&#xff0c;而事实上这是非常重要的一个步骤&#xff0c;模型的初始化权重的好坏关系到模型的训练成功与否&#xff0c;以及训练速度是否快速&#xff0c;效果是否更好等等&#xff0c;这次我们专门来看看深度学习中的权重初始化问题。…

my-room-in-3d中的电脑,电视,桌面光带发光原理

1. my-room-in-3d中的电脑&#xff0c;电视&#xff0c;桌面光带发光原理 最近在github中&#xff0c;看到了这样的一个项目&#xff1b; 项目地址 我看到的时候&#xff0c;蛮好奇他这个光带时怎么做的。 最后发现&#xff0c;他是通过&#xff0c;加载一个 lightMap.jpg这个…

java的嵌套循环

在java中&#xff0c;也有嵌套循环。 下面是一个示例代码 public class Example17qiantaoxunhuan {public static void main(String[] args) {int i,j;for(i1;i<9;i){for(j1;j<i;j){System.out.println("*");}System.out.println("\n");}}}这段代码…

分割等和子集

416. 分割等和子集 给你一个 只包含正整数 的 非空 数组 nums 。请你判断是否可以将这个数组分割成两个子集&#xff0c;使得两个子集的元素和相等。 示例 1&#xff1a; 输入&#xff1a;nums [1,5,11,5] 输出&#xff1a;true 解释&#xff1a;数组可以分割成 [1, 5, 5] 和…

让我们一起来领悟带环问题的核心思想

一、带环的链表&#xff1a; 本质还是快慢指针来解决 关于如下一个带环链表怎么去找到他们想碰到的节点呢&#xff1f;&#xff1f;&#xff1f;&#xff1f;我们可以想到快慢指针&#xff0c;第一个快点走&#xff0c;若是有环就会进入环&#xff0c;此时快指针每次走2步&am…

2.1 上海雷卯电子PLC

PLC&#xff08;可编程逻辑控制器&#xff09;像是工厂自动化系统的“大脑”&#xff0c;负责监控和控制各种生产过程。PLC 能够精确地协调各类设备的操作&#xff0c;实现生产流程的自动化和优化。通过编程&#xff0c;它可以根据不同的生产需求灵活调整控制逻辑&#xff0c;提…

可视化大屏应用场景:智慧安防,保驾护航

hello&#xff0c;我是大千UI工场&#xff0c;本篇分享智慧安防的大屏设计&#xff0c;关注我们&#xff0c;学习N多UI干货&#xff0c;有设计需求&#xff0c;我们也可以接单。 实时监控与预警 可视化大屏可以将安防系统中的监控画面、报警信息、传感器数据等实时展示在大屏上…

快速幂笔记

快速幂即为快速求出一个数的幂&#xff0c;这样可以避免TLE&#xff08;超时&#xff09;的错误。 传送门&#xff1a;快速幂模板 前置知识&#xff1a; 1) 又 2) 代码&#xff1a; #include <bits/stdc.h> using namespace std; int quickPower(int a, int b) {int…

TiDB系列之:部署TiDB集群常见报错解决方法

TiDB系列之&#xff1a;部署TiDB集群常见报错解决方法 一、部署TiDB集群二、unsupported filesystem ext3三、soft limit of nofile四、THP is enabled五、numactl not usable六、net.ipv4.tcp_syncookies 1七、service irqbalance not found,八、登陆TiDB数据库 一、部署TiDB…

搜款网商品列表API接口:高效获取时尚潮流商品的新途径

API接口概述 搜款网商品列表API接口允许开发者根据设定的条件&#xff08;如分类、价格区间、关键词等&#xff09;查询搜款网上的商品信息&#xff0c;并返回符合条件的商品列表。通过调用该接口&#xff0c;您可以轻松获取到搜款网上最新、最热的时尚商品数据&#xff0c;为…

批量视频剪辑新选择:一键式按照指定秒数分割视频并轻松提取视频中的音频,让视频处理更高效!

是否经常为大量的视频剪辑工作感到头疼&#xff1f;还在一个个手动分割、提取音频吗&#xff1f;现在&#xff0c;我们为你带来了一款全新的视频批量剪辑神器&#xff0c;让你轻松应对各种视频处理需求&#xff01; 首先&#xff0c;进入媒体梦工厂的主页面&#xff0c;并在板…

数据结构===队列

文章目录 概要操作入队出队 顺序队列代码Python 链式队列代码Python 小结 概要 队列&#xff0c;就像现实中的排队一样&#xff0c;这样的数据结构&#xff0c;一说很多人都熟悉。 队列&#xff0c;就是像我们排队一样&#xff0c;有2个操作&#xff0c;入队&#xff0c;出队&…

TFT显示屏偶发无法点亮

一. 问题描述 最近接到一起客诉&#xff1a;设备偶发显示屏不亮。复现现象时&#xff0c;发现有如下规律&#xff1a; 上电后&#xff0c;如果显示屏正常启动&#xff0c;则在使用过程中会一直正常。反之&#xff0c;如果显示屏一上电就无法显示&#xff0c;则一直黑屏。 是…

安卓硬件访问服务

安卓硬件访问服务 硬件访问服务通过硬件抽象层模块来为应用程序提供硬件读写操作。 由于硬件抽象层模块是使用C语言开发的&#xff0c; 而应用程序框架层中的硬件访问服务是使用Java语言开发的&#xff0c; 因此&#xff0c; 硬件访问服务必须通过Java本地接口&#xff08;Jav…

【Python】数据类型

文章目录 数值列表列表的基本概念&#xff1a;列表的常用方法和操作&#xff1a;列表的迭代和遍历&#xff1a;列表的内部实现原理&#xff1a; 字典字典的基本概念&#xff1a;字典的常用方法和操作&#xff1a;词典的迭代和遍历&#xff1a;词典的内部实现原理&#xff1a; 集…

vector的使用

1.构造函数 void test_vector1() {vector<int> v; //无参的构造函数vector<int> v2(10, 0);//n个value构造&#xff0c;初始化为10个0vector<int> v3(v2.begin(), v2.end());//迭代器区间初始化,可以用其他容器的区间初始化vector<int> v4(v3); //拷贝…

Java项目:基于SSM框架实现的学院党员管理系统高校党员管理系统(ssm+B/S架构+源码+数据库+毕业论文+开题)

一、项目简介 本项目是一套基于SSM框架实现的学院党员管理系统 包含&#xff1a;项目源码、数据库脚本等&#xff0c;该项目附带全部源码可作为毕设使用。 项目都经过严格调试&#xff0c;eclipse或者idea 确保可以运行&#xff01; 该系统功能完善、界面美观、操作简单、功能齐…