torch.load 报错 ModuleNotFoundError 或 AttributeError

Python 3.11.3 (main, Apr  7 2023, 19:25:52) [Clang 14.0.0 (clang-1400.0.29.202)] on darwin
Type "help", "copyright", "credits" or "license" for more information.

正常情况下,我们会使用 torch.save 保存模型的 state_dict 。但我们也可以 torch.save 保存一个自定义类型对象,例如

import torch
import torch.nn as nn
class Module(nn.Module):def __init__(self) -> None:self._one = 1
torch.save(Module(), 'module.pth')

在读取 module.pth 时,可能会遇到 AttributeError

import torch
torch.load('module.pth')
Traceback (most recent call last):File "/Users/bytedance/Developer/todd/load.py", line 3, in <module>torch.load('module.pth')File "/Users/bytedance/.local/share/virtualenvs/todd-ARrcnwyq/lib/python3.11/site-packages/torch/serialization.py", line 809, in loadreturn _load(opened_zipfile, map_location, pickle_module, **pickle_load_args)^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^File "/Users/bytedance/.local/share/virtualenvs/todd-ARrcnwyq/lib/python3.11/site-packages/torch/serialization.py", line 1172, in _loadresult = unpickler.load()^^^^^^^^^^^^^^^^File "/Users/bytedance/.local/share/virtualenvs/todd-ARrcnwyq/lib/python3.11/site-packages/torch/serialization.py", line 1165, in find_classreturn super().find_class(mod_name, name)^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
AttributeError: Can't get attribute 'Module' on <module '__main__' from '/Users/bytedance/Developer/todd/load.py'>

这是因为 torch.save 底层通过 pickle 实现,而 pickle 在保存自定义类型对象时不会保存其类型定义。用户需要保证 torch.load 时,自定义类型可访问,以便构造被保存的对象。也就是说,如果我们将 Module 引用到当前命名空间,就可以正常加载 module.pth

import torch
from save import Module
torch.load('module.pth')

但是有些情况下,我们无法访问某些自定义类型,也不希望恢复被保存的对象,只想知道被保存的对象存储了哪些数据,可以用下面的方法

import torch
class Module:def __init__(self) -> None:# in case __setstate__ is not calledself._state = Nonedef __setstate__(self, state):# whenever state is not empty, __setstate__ is calledself._state = state
module = torch.load('module.pth')
print(module._state)
{'_one': 1}

但是如果自定义类型是从其他位置 import 得到的,例如

# module.py
import torch.nn as nn
class Module(nn.Module):def __init__(self) -> None:self._one = 1# save.py
import torch
from module import Module
torch.save(Module(), 'module.pth')

torch.load 会先尝试 import 相应的模块,如果不存在就会报错

Traceback (most recent call last):File "/Users/bytedance/Developer/todd/load.py", line 13, in <module>module = torch.load('module.pth')^^^^^^^^^^^^^^^^^^^^^^^^File "/Users/bytedance/.local/share/virtualenvs/todd-ARrcnwyq/lib/python3.11/site-packages/torch/serialization.py", line 809, in loadreturn _load(opened_zipfile, map_location, pickle_module, **pickle_load_args)^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^File "/Users/bytedance/.local/share/virtualenvs/todd-ARrcnwyq/lib/python3.11/site-packages/torch/serialization.py", line 1172, in _loadresult = unpickler.load()^^^^^^^^^^^^^^^^File "/Users/bytedance/.local/share/virtualenvs/todd-ARrcnwyq/lib/python3.11/site-packages/torch/serialization.py", line 1165, in find_classreturn super().find_class(mod_name, name)^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
ModuleNotFoundError: No module named 'module'

我们可以 mock 相应模块

import sys
from unittest.mock import Mock
import torch
sys.modules['module'] = Mock()
torch.load('module.pth')
Traceback (most recent call last):File "/Users/bytedance/Developer/todd/load.py", line 14, in <module>module = torch.load('module.pth')^^^^^^^^^^^^^^^^^^^^^^^^File "/Users/bytedance/.local/share/virtualenvs/todd-ARrcnwyq/lib/python3.11/site-packages/torch/serialization.py", line 809, in loadreturn _load(opened_zipfile, map_location, pickle_module, **pickle_load_args)^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^File "/Users/bytedance/.local/share/virtualenvs/todd-ARrcnwyq/lib/python3.11/site-packages/torch/serialization.py", line 1172, in _loadresult = unpickler.load()^^^^^^^^^^^^^^^^
_pickle.UnpicklingError: NEWOBJ class argument must be a type, not Mock

出现这个问题,是因为 Mock 具有递归创建的特性。我们可以手动修改

import sys
from unittest.mock import Mock
import torch
class Module:def __init__(self) -> None:self._state = Nonedef __setstate__(self, state):self._state = state
sys.modules['module'] = Mock()
sys.modules['module'].Module = Module
module = torch.load('module.pth')
print(module._state)
{'_one': 1}

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.mzph.cn/news/18075.shtml

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

向 Maven 中央仓库上传一个修改过的基于jeecg的autoPOI的 jar包记录

1、注册https://issues.sonatype.org/账号 下面就代表注册好了&#xff0c;同时提交的工单也通过了 2、这里主要是goupId 需要进行认证&#xff0c;需要到域名注册商近一个txt的解析&#xff0c;以便确保这个是你的 通过下面来验证你的域名信息&#xff0c;这里主要是上面的工…

git命令分类合集

配置 git config --global user.name <name>&#xff1a;设置全局用户名 git config --global user.email <email>&#xff1a;设置全局用户邮箱 git config --global core.editor <editor>&#xff1a;设置全局文本编辑器创建与克隆仓库 git init&#xf…

面试题:说一说深拷贝和浅拷贝?

JavaScript中存在两大数据类型&#xff1a; 基本类型 和 引用类型 基本类型数据保存在在栈内存中 引用类型数据保存在堆内存中&#xff0c;引用数据类型的变量是一个指向堆内存中实际对象的引用&#xff0c;存在栈中 深拷贝和浅拷贝都只针对于引用类型。 一、 浅拷贝&#xff1…

【2023年11月第四版教材】《第1章-信息化发展之<3 现代化创新发展>》

第1章-信息化发展之&#xff1c;3 现代化创新发展&#xff1e; 3 现代化创新发展3.1 农业现代化3.2 两化融合&#xff08;17下2&#xff09;&#xff08;18下2&#xff09; &#xff08;22下12)3.3 智能制造3.4 消费互联网 3 现代化创新发展 3.1 农业现代化 要素具体内容农业…

Cpp7 — 继承和多态

继承 -------- 面向对象的三大特性之一 面向对象的三大特性&#xff1a;封装、继承、多态 封装&#xff1a;把数据和方法都封装在一起&#xff0c;想给你访问的变成共有&#xff0c;不想给访问的&#xff0c;写成私有。 继承&#xff1a;继承是类设计层次的复用 多态&#…

[SQL挖掘机] - 索引

介绍: 当你在数据库中进行查询时&#xff0c;索引是一种用于提高查询性能的重要工具。索引是对表中的一列或多列进行排序的数据结构&#xff0c;它可以快速定位到满足特定条件的记录&#xff0c;从而减少了查询所需的时间和资源。 在数据库中使用索引的主要好处包括&#xff…

【AGI】Copilot AI编程辅助工具安装教程

1. 基础激活教程 GitHub和OpenAI联合为程序员们送上了编程神器——GitHub Copilot。 但是&#xff0c;Copilot目前不提供公开使用&#xff0c;需要注册账号通过审核&#xff0c;我也提交了申请&#xff1a;这里第一期记录下&#xff0c;开启教程&#xff0c;欢迎大佬们来讨论…

通向架构师的道路之apache性能调优

一、总结前一天的学习 在前两天的学习中我们知道、了解并掌握了Web Server结合App Server实现单向Https的这样的一个架构。这个架构是一个非常基础的J2ee工程上线布署时的一种架构。在前两天的教程中&#xff0c;还讲述了Http服务 器、App Server的最基本安全配置&#xff08;…

java 数组的使用

数组 基本介绍 数组可以存放多个同一类型的数据&#xff0c;数组也是一种数据类型&#xff0c;是引用类型。 即&#xff1a;数组就是一组数据。 数组的使用 1、数组的定义 方法一 -> 单独声明 数据类型[] 数组名 new 数据类型[大小] 说明&#xff1a;int[] a new int…

C/C++算法——散列表

1、散列表介绍 散列表的英文叫Hash Table&#xff0c;我们平时也叫它哈希表或者Hash 表。散列表用的是数组支持按照下标随机访问数据的特性&#xff0c;所以散列表其实就是数组的一种扩展&#xff0c;由数组演化而来。可以说&#xff0c;如果没有数组&#xff0c;就没有散列表。…

【SpringBoot】85、SpringBoot中Boolean类型数据转0/1返回序列化配置

在 SpringBoot 中,前端传参数 0,1,后端可自动解析为 boolean 类型,但后端返回前端 boolean 类型时,却无法自动转换为 0,1,所以我们需要自定义序列化配置,将 boolean 类型转化为 0,1 1、类型对应 boolean 类型有false,true对应的 int 类型0,12、序列化配置 import com.f…

iOS——锁与死锁问题

iOS中的锁 什么是锁锁的分类互斥锁1. synchronized2. NSLock3. pthread 递归锁1. NSRecursiveLock2. pthread 信号量Semaphore1. dispatch_semaphore_t2. pthread 条件锁1. NSCodition2. NSCoditionLock3. POSIX Conditions 分布式锁NSDistributedLock 读写锁1. dispatch_barri…

MySQL优化

目录 一. 优化 SQL 查询语句 1.1. 分析慢查询日志 1.2. 优化 SQL 查询语句的性能 1.2.1 优化查询中的索引 1.2.2 减少表的连接&#xff08;join&#xff09; 1.2.3 优化查询语句中的过滤条件 1.2.4 避免使用SELECT * 1.2.5 优化存储过程和函数 1.2.6 使用缓存 二. 优化表结构…

超全整理,Jmeter性能测试-常用Jmeter第三方插件详解(超细)

目录&#xff1a;导读 前言一、Python编程入门到精通二、接口自动化项目实战三、Web自动化项目实战四、App自动化项目实战五、一线大厂简历六、测试开发DevOps体系七、常用自动化测试工具八、JMeter性能测试九、总结&#xff08;尾部小惊喜&#xff09; 前言 Jmeter作为一个开…

React(4)

1.属性&#xff08;props&#xff09;初始 状态state都是组件内部写的&#xff0c;也就是A组件内的state就只能A组件里面用&#xff0c;其他组件复用不了。因此属性props就可以。 比如一个导航栏&#xff0c;首页有&#xff0c;购物车有&#xff0c;我的有&#xff0c;他们三个…

Vue使用QrcodeVue生成二维码并下载

生成二维码 1、安装qrcode.vue组件 npm install --save qrcode.vue<template><div id"app"><qrcode-vue :valuevalue :sizesize></qrcode-vue><br /></div> </template><script> //导入组件 import QrcodeVue fro…

《吐血整理》进阶系列教程-拿捏Fiddler抓包教程(20)-Fiddler精选插件扩展安装让你的Fiddler开挂到你怀疑人生

1.简介 Fiddler本身的功能其实也已经很强大了&#xff0c;但是Fiddler官方还有很多其他扩展插件功能&#xff0c;可以更好地辅助Fiddler去帮助用户去开发、测试和管理项目上的任务。Fiddler已有的功能已经够我们日常工作中使用了&#xff0c;为了更好的扩展Fiddler&#xff0c…

P4780 Phi的反函数

题目 思路 φ(x)n 当指数均为1时n最小 证明&#xff1a;容斥原理 代码 #include<bits/stdc.h> using namespace std; #define int long long const int maxn1e9; int ansINT_MAX,n; bool f; map<int,bool> mp; bool is_prime(int n){if(n<1) return false;fo…

Spring事务创建与使用

目录 前言Spring中事务的实现声明式事务Transactional 作⽤范围Transactional 参数说明对于事务不回滚的解决方案 前言 在数据库中我们提到了 事务, 事务的定义为, 将一系列操作封装成一个整体去调用 , 要么一起成功, 要么一起失败 Spring中事务的实现 在Spring中事务的操作…

发npm包

重点文件 .github -> workflow -> .yml文件 发自己的包 新建dev分支&#xff0c;合并到master后自动执行 fork别人的包 fork -> base dev新建本地rebase-dev分支 -> 提交push后合并至dev -> dev合并至master后自动执行 值得注意的是&#xff0c;fork别人的…