Pytorch-day07-模型保存与读取

PyTorch 模型保存&读取

  • 模型存储
  • 模型单卡存储&多卡存储
  • 模型单卡读取&多卡读取

1、模型存储

  • PyTorch存储模型主要采用pkl,pt,pth三种格式,就使用层面来说没有区别
  • PyTorch模型主要包含两个部分:模型结构和权重。其中模型是继承nn.Module的类,权重的数据结构是一个字典(key是层名,value是权重向量)
  • 存储也由此分为两种形式:存储整个模型(包括结构和权重)和只存储模型权重(推荐)。
import torch
from torchvision import models
model = models.resnet50(pretrained=True)
save_dir = './resnet50.pth'# 保存整个 模型结构+权重
torch.save(model, save_dir)
# 保存 模型权重
torch.save(model.state_dict, save_dir)# pt, pth和pkl三种数据格式均支持模型权重和整个模型的存储

2、模型单卡存储&多卡存储

  • PyTorch中将模型和数据放到GPU上有两种方式——.cuda()和.to(device)
  • 注:如果要使用多卡训练的话,需要对模型使用torch.nn.DataParallel

2.1、nn.DataParrallel

<CLASS torch.nn.DataParallel(module, device_ids=None, output_device=None, dim=0)>
[外链图片转存失败,源站可能有防盗链机制,建议将图片保存下来直接上传(img-puyISgkD-1692613764220)(attachment:image.png)]

  • module即表示你定义的模型
  • device_ids表示你训练的device
  • output_device这个参数表示输出结果的device,而这最后一个参数output_device一般情况下是省略不写的,那么默认就是在device_ids[0]

注:因此一般情况下第一张显卡的内存使用占比会更多

import os
import torch
from torchvision import models
#单卡
os.environ['CUDA_VISIBLE_DEVICES'] = '0' # 如果是多卡改成类似0,1,2
model = model.cuda()  # 单卡
#print(model)
---------------------------------------------------------------------------RuntimeError                              Traceback (most recent call last)~\AppData\Local\Temp/ipykernel_7460/77570021.py in <module>1 import os2 os.environ['CUDA_VISIBLE_DEVICES'] = '0' # 如果是多卡改成类似0,1,2
----> 3 model = model.cuda()  # 单卡D:\Users\xulele\Anaconda3\lib\site-packages\torch\nn\modules\module.py in cuda(self, device)903             Module: self904         """
--> 905         return self._apply(lambda t: t.cuda(device))906 907     def ipu(self: T, device: Optional[Union[int, device]] = None) -> T:D:\Users\xulele\Anaconda3\lib\site-packages\torch\nn\modules\module.py in _apply(self, fn)795     def _apply(self, fn):796         for module in self.children():
--> 797             module._apply(fn)798 799         def compute_should_use_set_data(tensor, tensor_applied):D:\Users\xulele\Anaconda3\lib\site-packages\torch\nn\modules\module.py in _apply(self, fn)818             # `with torch.no_grad():`819             with torch.no_grad():
--> 820                 param_applied = fn(param)821             should_use_set_data = compute_should_use_set_data(param, param_applied)822             if should_use_set_data:D:\Users\xulele\Anaconda3\lib\site-packages\torch\nn\modules\module.py in <lambda>(t)903             Module: self904         """
--> 905         return self._apply(lambda t: t.cuda(device))906 907     def ipu(self: T, device: Optional[Union[int, device]] = None) -> T:D:\Users\xulele\Anaconda3\lib\site-packages\torch\cuda\__init__.py in _lazy_init()245         if 'CUDA_MODULE_LOADING' not in os.environ:246             os.environ['CUDA_MODULE_LOADING'] = 'LAZY'
--> 247         torch._C._cuda_init()248         # Some of the queued calls may reentrantly call _lazy_init();249         # we need to just return without initializing in that case.RuntimeError: Found no NVIDIA driver on your system. Please check that you have an NVIDIA GPU and installed a driver from http://www.nvidia.com/Download/index.aspx

[外链图片转存失败,源站可能有防盗链机制,建议将图片保存下来直接上传(img-0G4NTv1z-1692613764220)(attachment:ed8eb711294e4c6e3e43690ddb2bf66.png)]

#多卡
os.environ['CUDA_VISIBLE_DEVICES'] = '0,1'
model = torch.nn.DataParallel(model).cuda()  # 多卡
#print(model)

[外链图片转存失败,源站可能有防盗链机制,建议将图片保存下来直接上传(img-eHt1Dn8t-1692613764221)(attachment:image.png)]

2.3、单卡保存+单卡加载

os.environ['CUDA_VISIBLE_DEVICES'] = '0'   #这里替换成希望使用的GPU编号
model = models.resnet50(pretrained=True)
model.cuda()save_dir = 'resnet50.pt'   #保存路径# 保存+读取整个模型
torch.save(model, save_dir)
loaded_model = torch.load(save_dir)
loaded_model.cuda()# 保存+读取模型权重
torch.save(model.state_dict(), save_dir)
# 先加载模型结构
loaded_model = models.resnet50()   
# 在加载模型权重
loaded_model.load_state_dict(torch.load(save_dir))
loaded_model.cuda()
---------------------------------------------------------------------------RuntimeError                              Traceback (most recent call last)~\AppData\Local\Temp/ipykernel_7460/585340704.py in <module>5 os.environ['CUDA_VISIBLE_DEVICES'] = '0'   #这里替换成希望使用的GPU编号6 model = models.resnet50(pretrained=True)
----> 7 model.cuda()8 9 save_dir = 'resnet50.pt'   #保存路径D:\Users\xulele\Anaconda3\lib\site-packages\torch\nn\modules\module.py in cuda(self, device)903             Module: self904         """
--> 905         return self._apply(lambda t: t.cuda(device))906 907     def ipu(self: T, device: Optional[Union[int, device]] = None) -> T:D:\Users\xulele\Anaconda3\lib\site-packages\torch\nn\modules\module.py in _apply(self, fn)795     def _apply(self, fn):796         for module in self.children():
--> 797             module._apply(fn)798 799         def compute_should_use_set_data(tensor, tensor_applied):D:\Users\xulele\Anaconda3\lib\site-packages\torch\nn\modules\module.py in _apply(self, fn)818             # `with torch.no_grad():`819             with torch.no_grad():
--> 820                 param_applied = fn(param)821             should_use_set_data = compute_should_use_set_data(param, param_applied)822             if should_use_set_data:D:\Users\xulele\Anaconda3\lib\site-packages\torch\nn\modules\module.py in <lambda>(t)903             Module: self904         """
--> 905         return self._apply(lambda t: t.cuda(device))906 907     def ipu(self: T, device: Optional[Union[int, device]] = None) -> T:D:\Users\xulele\Anaconda3\lib\site-packages\torch\cuda\__init__.py in _lazy_init()245         if 'CUDA_MODULE_LOADING' not in os.environ:246             os.environ['CUDA_MODULE_LOADING'] = 'LAZY'
--> 247         torch._C._cuda_init()248         # Some of the queued calls may reentrantly call _lazy_init();249         # we need to just return without initializing in that case.RuntimeError: Found no NVIDIA driver on your system. Please check that you have an NVIDIA GPU and installed a driver from http://www.nvidia.com/Download/index.aspx

2.4、单卡保存+多卡加载


os.environ['CUDA_VISIBLE_DEVICES'] = '0'   #这里替换成希望使用的GPU编号
model = models.resnet50(pretrained=True)
model.cuda()# 保存+读取整个模型
torch.save(model, save_dir)os.environ['CUDA_VISIBLE_DEVICES'] = '1,2'   #这里替换成希望使用的GPU编号
loaded_model = torch.load(save_dir)
loaded_model = nn.DataParallel(loaded_model).cuda()# 保存+读取模型权重
torch.save(model.state_dict(), save_dir)os.environ['CUDA_VISIBLE_DEVICES'] = '1,2'   #这里替换成希望使用的GPU编号
loaded_model = models.resnet50()   #注意这里需要对模型结构有定义
loaded_model.load_state_dict(torch.load(save_dir))
loaded_model = nn.DataParallel(loaded_model).cuda()

2.5、多卡保存+单卡加载

核心问题:如何去掉权重字典键名中的"module",以保证模型的统一性

  • 对于加载整个模型,直接提取模型的module属性即可
  • 对于加载模型权重,保存模型时保存模型的module属性对应的权重
os.environ['CUDA_VISIBLE_DEVICES'] = '1,2'   #这里替换成希望使用的GPU编号model = models.resnet50(pretrained=True)
model = nn.DataParallel(model).cuda()# 保存+读取整个模型
torch.save(model, save_dir)os.environ['CUDA_VISIBLE_DEVICES'] = '0'   #这里替换成希望使用的GPU编号
loaded_model = torch.load(save_dir).module
os.environ['CUDA_VISIBLE_DEVICES'] = '0,1,2'   #这里替换成希望使用的GPU编号model = models.resnet50(pretrained=True)
model = nn.DataParallel(model).cuda()# 保存权重
torch.save(model.module.state_dict(), save_dir)#加载模型权重
os.environ['CUDA_VISIBLE_DEVICES'] = '0'   #这里替换成希望使用的GPU编号
loaded_model = models.resnet50()   #注意这里需要对模型结构有定义
loaded_model.load_state_dict(torch.load(save_dir))
loaded_model.cuda()

2.6、多卡保存+多卡加载

保存整个模型时会同时保存所使用的GPU id等信息,读取时若这些信息和当前使用的GPU信息不符则可能会报错或者程序不按预定状态运行。可能出现以下2个问题:

  • 1、读取整个模型再使用nn.DataParallel进行分布式训练设置,这种情况很可能会造成保存的整个模型中GPU id和读取环境下设置的GPU id不符,训练时数据所在device和模型所在device不一致而报错
  • 2、读取整个模型而不使用nn.DataParallel进行分布式训练设置,发现程序会自动使用设备的前n个GPU进行训练(n是保存的模型使用的GPU个数)。此时如果指定的GPU个数少于n,则会报错

建议方案:

  • 只模型权重,之后再使用nn.DataParallel进行分布式训练设置则没有问题
  • 因此多卡模式下建议使用权重的方式存储和读取模型
os.environ['CUDA_VISIBLE_DEVICES'] = '0,1,2'   #这里替换成希望使用的GPU编号model = models.resnet50(pretrained=True)
model = nn.DataParallel(model).cuda()# 保存+读取模型权重,强烈建议!!
torch.save(model.state_dict(), save_dir)
#加载模型 权重
loaded_model = models.resnet50()   #注意这里需要对模型结构有定义
loaded_model.load_state_dict(torch.load(save_dir)))
loaded_model = nn.DataParallel(loaded_model).cuda()

建议

  • 不管是单卡保存还是多卡保存,建议以保存模型权重为主
  • 不管是单卡还是多卡,先load模型权重,再指定是多卡加载(nn.DataParallel)或单卡(cuda)
# 使用案例(截取片段代码)My_model.eval()
test_total_loss = 0
test_total_correct = 0
test_total_num = 0past_test_loss = 0 #上一轮的loss
save_model_step = 10 # 每10步保存一次modelfor iter,(images,labels) in enumerate(test_loader):images = images.to(device)labels = labels.to(device)outputs = My_model(images)loss = criterion(outputs,labels)test_total_correct += (outputs.argmax(1) == labels).sum().item()test_total_loss += loss.item()test_total_num += labels.shape[0]test_loss = test_total_loss / test_total_numprint("Epoch [{}/{}], train_loss:{:.4f}, train_acc:{:.4f}%, test_loss:{:.4f}, test_acc:{:.4f}%".format(i+1, epoch, train_total_loss / train_total_num, train_total_correct / train_total_num * 100, test_total_loss / test_total_num, test_total_correct / test_total_num * 100))# model saveif test_loss<past_test_loss:#保存模型权重torch.save(model.state_dict(), save_dir)#保存 模型权重+模型结构#torch.save(model, save_dir)if iter % save_model_step == 0:#保存模型权重torch.save(model.state_dict(), save_dir)#保存 模型权重+模型结构#torch.save(model, save_dir)past_test_loss = test_loss

单卡保存&单卡读取 案例

Google Colab:https://colab.research.google.com/drive/1hEOeqXYm4BfulY6d30QCI4HrFmCmmTQu?usp=sharing



本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.mzph.cn/news/55034.shtml

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

spring cloud整合spring boot,整合nacos、gateway、open-feign等组件

补充&#xff1a; 想看具体详情的可以看我的github链接&#xff1a;codeking01/platform-parent: spring cloud整合spring boot、nacos、gateway、open feign等组件 (github.com) 由于我升级了jdk17&#xff0c;所以用上了spring boot 3.0.2了。 踩坑无数&#xff0c;一堆无用文…

通过C实现sqlite3操作,导入电子词典

#include <stdio.h> #include <string.h> #include <stdlib.h> #include <sqlite3.h> int main(int argc, const char *argv[]) {//创建并打开一个数据库sqlite3 *db NULL;if(sqlite3_open("./dict.db",&db) ! SQLITE_OK){printf("…

宠物小程序开发

在当今社会&#xff0c;宠物已成为许多人生活中不可或缺的一部分。宠物市场的持续增长为创业者提供了巨大的商机。然而&#xff0c;作为一个创业者&#xff0c;要在竞争激烈的宠物市场中脱颖而出并不容易。因此&#xff0c;开发一个专属于自己的宠物小程序成为了解决这一难题的…

jdk 03.stream

01.集合处理数据的弊端 当我们在需要对集合中的元素进行操作的时候&#xff0c;除了必需的添加&#xff0c;删除&#xff0c;获取外&#xff0c;最典型的操作就是集合遍历 package com.bobo.jdk.stream; import java.util.ArrayList; import java.util.Arrays; import java.ut…

测试框架pytest教程(10)自定义命令行-pytest_addoption

pytest_addoption pytest_addoption是pytest插件系统中的一个钩子函数&#xff0c;用于向pytest添加自定义命令行选项。 在pytest中&#xff0c;可以使用命令行选项来控制测试的行为和配置。pytest_addoption钩子函数允许您在运行pytest时添加自定义的命令行选项&#xff0c;…

中国芯,寻找新赛道迫在眉睫

北京华兴万邦管理咨询有限公司 商瑞 陈皓 近期国内半导体行业的热点可以用两个“有点多”来描述&#xff0c;一个是中国芯群体中上市公司股价闪崩的有点多&#xff0c;另一个是行业和企业的活动有点多。前者说明了许多国内芯片设计企业&#xff08;fabless商业模式&#xff09;…

Ubuntu20 安装 libreoffice

1 更新apt-get sudo apt-get update2 安装jdk 查看jdk安装情况 Command java not found, but can be installed with:sudo apt install default-jre # version 2:1.11-72, or sudo apt install openjdk-11-jre-headless # version 11.0.138-0ubuntu1~20.04 sud…

在Jupyter Notebook中添加Anaconda环境(内核)

在使用前我们先要搞清楚一些事&#xff1a; 我们在安装anaconda的时候它就内置了Jupyter Notebook&#xff0c;这个jupyter初始只有base一个内核&#xff08;显示为Python3&#xff09; 此后其实我们就不需要重复安装完整的jupyter notebook了&#xff0c;只要按需为其添加新的…

【hibernate validator】(五)分组约束

首发博客地址 https://blog.zysicyj.top/ 一、请求组 1. 人组 package org.hibernate.validator.referenceguide.chapter05;public class Person { NotNull private String name; public Person(String name) { this.name name; } // getters and sette…

基于寄生捕食算法优化的BP神经网络(预测应用) - 附代码

基于寄生捕食算法优化的BP神经网络&#xff08;预测应用&#xff09; - 附代码 文章目录 基于寄生捕食算法优化的BP神经网络&#xff08;预测应用&#xff09; - 附代码1.数据介绍2.寄生捕食优化BP神经网络2.1 BP神经网络参数设置2.2 寄生捕食算法应用 4.测试结果&#xff1a;5…

服务器和普通电脑有何区别?43.248.189.x

简单来讲&#xff0c;服务器和电脑的功能是一样的&#xff0c;我们也可以把服务器称之为电脑&#xff08;PC机&#xff09;&#xff0c;只是服务器对稳定性与安全性以及处理器数据能力有更高要求&#xff0c;比如我们每天浏览一个网站&#xff0c;发现这个网站每天24小时都能访…

【ASP.NET】LIS实验室信息管理系统源码

LIS系统&#xff0c;即实验室信息管理系统&#xff0c;是一种基于互联网技术的医疗行业管理软件&#xff0c;它可以帮助实验室进行样本管理、检测流程管理、结果报告等一系列工作&#xff0c; 提高实验室工作效率和质量。 一、LIS系统的功能 1. 样本管理 LIS系统可以帮助实验…

【局部活动轮廓】使用水平集方法实现局部活动轮廓方法研究(Matlab代码实现)

&#x1f4a5;&#x1f4a5;&#x1f49e;&#x1f49e;欢迎来到本博客❤️❤️&#x1f4a5;&#x1f4a5; &#x1f3c6;博主优势&#xff1a;&#x1f31e;&#x1f31e;&#x1f31e;博客内容尽量做到思维缜密&#xff0c;逻辑清晰&#xff0c;为了方便读者。 ⛳️座右铭&a…

SpringBoot项目(支付宝整合)——springboot整合支付宝沙箱支付 从极简实现到IOC改进

目录 引出git代码仓库准备工作支付宝沙箱api内网穿透 [natapp.cn](https://natapp.cn/#download) springboot整合—极简实现版1.导包配置文件2.controller层代码3.进行支付流程4.支付成功回调 依赖注入的改进1.整体结构2.pom.xml文件依赖3.配置文件4.配置类&#xff0c;依赖注入…

下载的文件被Windows 11 安全中心自动删除

今天从CSDN上下载了自己曾经上传的文件&#xff0c;但是浏览器下载完之后文件被Windows安全中心自动删除&#xff0c;说是带病毒。实际是没有病毒的&#xff0c;再说了即便有病毒也不应该直接删除啊&#xff0c;至少给用户一个保留或删除的选项。 研究了一番&#xff0c;可以暂…

机器视觉之平面物体检测

平面物体检测是计算机视觉中的一个重要任务&#xff0c;它通常涉及检测和识别在图像或视频中出现的平面物体&#xff0c;如纸张、标志、屏幕、牌子等。下面是一个使用C和OpenCV进行平面物体检测的简单示例&#xff0c;使用了图像中的矩形轮廓检测方法&#xff1a; #include &l…

设计模式之八:迭代器与组合模式

有许多方法可以把对象堆起来成为一个集合&#xff08;Collection&#xff09;&#xff0c;比如放入数组、堆栈或散列表中。若用户直接从这些数据结构中取出对象&#xff0c;则需要知道具体是存在什么数据结构中&#xff08;如栈就用peek&#xff0c;数组[]&#xff09;。迭代器…

推荐前 6 名 JavaScript 和 HTML5 游戏引擎

推荐&#xff1a;使用 NSDT场景编辑器 助你快速搭建3D应用场景 事实是&#xff0c;自从引入JavaScript WebGL API以来&#xff0c;现代浏览器具有直观的功能&#xff0c;使它们能够渲染更复杂和复杂的2D和3D图形&#xff0c;而无需依赖第三方插件。 你可以用纯粹的JavaScript开…

tensordataset 和dataloader取值

测试1 from torch.utils.data import TensorDataset,DataLoader import numpy as np import torch a np.array([[1,2,3],[2,3,3],[1,1,2],[10,10,10],[100,200,200],[-1,-2,-3]]) print(a)X torch.FloatTensor(a) print(X)dataset TensorDataset(X,X)测试2 from torch.uti…

字节一面:闭包是什么?闭包的用途是什么?

前言 最近博主在字节面试中遇到这样一个面试题&#xff0c;这个问题也是前端面试的高频问题&#xff0c;因为在前端开发的日常开发中我们经常会用到闭包&#xff0c;我们会借助闭包来封装一些工具函数&#xff0c;所以更深的了解闭包是很有必要的&#xff0c;博主在这给大家细细…