PyTorch|保存及加载模型、nn.Sequential、ModuleList和ModuleDict

系列文章目录

PyTorch|Dataset与DataLoader使用、构建自定义数据集
PyTorch|搭建分类网络实例、nn.Module源码学习
pytorch|autograd使用、训练模型

文章目录

  • 系列文章目录
  • 一、保存及加载模型
    • (一)保存及加载模型的权重
    • (二)保存及加载优化器的权重
    • (三)保存及加载整个模型
    • (四)保存及加载更具一般性的checkpoint
    • (五)保存多个模型
  • 二、nn.Sequential源码分析
    • (一)init函数
    • (二)forward函数
  • 三、ModuleList和ModuleDict
    • (一)ModuleList
    • (二)ModuleDict


一、保存及加载模型

通过torch.save可以将该模型的参数、优化器状态、batch normalization、dropout、buffer变量等信息。

import torch
import torchvision.models as models

(一)保存及加载模型的权重

模型取自torchvision.models里的vgg16,权重为IMAGENET1K_V1。

model.state_dict()是模型的权重。state_dict状态字典:一般包含当前model的参数及buffer变量

model = models.vgg16(weights='IMAGENET1K_V1')
torch.save(model.state_dict(), 'model_weights.pth')

推理时可以实现模型的加载:

  • 创建模型实例
  • 将实现保存的模型信息通过torch.load导入进来
  • 采用load_state_dict函数将模型信息载入模型实例
  • model.eval()使得模型进入推理模式
model = models.vgg16() # we do not specify ``weights``, i.e. create untrained model
model.load_state_dict(torch.load('model_weights.pth'))
model.eval()

(二)保存及加载优化器的权重

保存优化器权重:
在这里插入图片描述

加载优化器权重:
在这里插入图片描述

(三)保存及加载整个模型

保存整个模型:

torch.save(model, 'model.pth')

加载整个模型:

model = torch.load('model.pth')

(四)保存及加载更具一般性的checkpoint

保存并加载用于推理或恢复训练的一般性checkpoint有助于从上次中断的地方重新开始。在保存一般检查点时,不仅仅是保存模型的state_dict,还包括保存优化器的state_dict、停止使用的时间,最近记录的训练损失,外部的torch.nn.Embedding层等等。

# Additional information
EPOCH = 5
PATH = "model.pt"
LOSS = 0.4torch.save({'epoch': EPOCH,'model_state_dict': net.state_dict(),'optimizer_state_dict': optimizer.state_dict(),'loss': LOSS,}, PATH)

加载:

model = Net()
optimizer = optim.SGD(model.parameters(), lr=0.001, momentum=0.9)checkpoint = torch.load(PATH)
model.load_state_dict(checkpoint['model_state_dict'])
optimizer.load_state_dict(checkpoint['optimizer_state_dict'])
epoch = checkpoint['epoch']
loss = checkpoint['loss']model.eval()
# - or -
model.train()

(五)保存多个模型

保存多个模型时可以将其直接合并到一个大字典中保存。

# Specify a path to save to
PATH = "model.pt"torch.save({'modelA_state_dict': netA.state_dict(),'modelB_state_dict': netB.state_dict(),'optimizerA_state_dict': optimizerA.state_dict(),'optimizerB_state_dict': optimizerB.state_dict(),}, PATH)

二、nn.Sequential源码分析

nn.Sequential是有序的,当实例化nn.Sequential时,传入的模块顺序就是神经网络前向传播的顺序

在使用nn.Sequential时,可以按顺序传入模块,也可以输入一个字典。
在这里插入图片描述

(一)init函数

如果输入的是一个字典,init函数会采用遍历字典的方式,如果是一个一个的模块,init函数也会针对性的采取其他遍历方法。
在这里插入图片描述

(二)forward函数

对于一个模型的输入,nn.Sequential会依次的过其中的子模块。
在这里插入图片描述

nn.Sequential相比于ModuleList和ModuleDict来说,优势在于具有forward的功能。

三、ModuleList和ModuleDict

(一)ModuleList

pytorch允许我们把很多子模块放到一个列表中。ModuleList就是用于存放多个子模块的一个列表,在使用时可以对其进行遍历。ModuleList不单纯是一个列表,它本身就是一个module。
在这里插入图片描述

(二)ModuleDict

ModuleDict是用于存放多个子模块的一个字典,在使用时可以根据索引获得对应的子模块。ModuleDict不单纯是一个字典,它本身也是一个module。
在这里插入图片描述

除此之外,还有ParameterList、ParameterDict等,这些与ModuleList和ModuleDict的作用及使用方式类似。

参考:
8、深入剖析PyTorch的state_dict、parameters、modules源码
9、深入剖析PyTorch的nn.Sequential及ModuleList源码

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.mzph.cn/web/796.shtml

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

探究欧拉恒等式的美学与数学威力

正如老子所述,“道生一,一生二,二生三,三生万物”,数学作为人类认知自然法则的语言,其数系的不断发展象征着对世界理解的深化。从自然数经由分数、无理数至复数,复数虽看似反直觉,却…

MATLAB实现蚁群算法优化柔性车间调度(ACO-fjsp)

蚁群算法优化车间调度的步骤可以分为以下几个主要阶段: 1.初始化阶段: 设置算法参数,如信息素浓度、启发式因子等。这些参数将影响蚂蚁在选择路径时的决策过程。 确定车间调度的具体问题规模,包括工件数量、机器数量以及每个工件…

AI:162-如何使用Python进行图像识别与处理深度学习与卷积神经网络的应用

本文收录于专栏:精通AI实战千例专栏合集 从基础到实践,深入学习。无论你是初学者还是经验丰富的老手,对于本专栏案例和项目实践都有参考学习意义。 每一个案例都附带关键代码,详细讲解供大家学习,希望可以帮到大家。正…

OpenHarmony GIF图像渲染库—ohos-gif-drawable

简介 本项目是OpenHarmony系统的一款GIF图像渲染库,基于Canvas进行绘制,主要能力如下: 支持播放GIF图片。支持控制GIF播放/暂停。支持重置GIF播放动画。支持调节GIF播放速率。支持监听GIF所有帧显示完成后的回调。支持设置显示大小。支持7种不同的展示…

面试题:Redis如何防止缓存穿透 + 布隆过滤器原理

题目来源 招银网络-技术-1面 题目描述 缓存穿透是什么?如何防止缓存穿透布隆过滤器的原理是什么? 我的回答 缓存穿透是什么? 攻击者大量请求缓存和数据库中都不存在的key。如何防止缓存穿透 可以使用布隆过滤器布隆过滤器的原理是什么&a…

AI容器化部署开发尝试 (一)(Pycharm连接docker,并部署django测试)

注意:从 Docker 19.03 开始,Docker 引入了对 NVIDIA GPU 的原生支持,因此若AI要调用GPU算力的话docker版本也是有要求的,后面博客测试。 当然本篇博客还没设计到GPU的调用,主要Pycharm加Anaconda的方案用习惯了&#…

缓存的使用及常见问题的解决方案

用户通过浏览器向我们发送请求,这个时候浏览器就会建立一个缓存,主要缓存一些静态资源(js、css、图片),这样做可以降低之后访问的网络延迟。然后我们可以在Tomcat里面添加一些应用缓存,将一些从数据库查询到…

Flask:URL与视图的映射

默认端口号80、443 blog_id 限制数据类型的话(int) 除此之外别的数据类型也可以,或者多个(用any) /book/list?page6

【笔记】ASP.NET Core Web API之Token验证

在实际开发中经常需要对外提供接口以便客户获取数据,由于数据属于私密信息,并不能随意供其他人访问,所以就需要验证客户身份。那么如何才能验证客户的身份呢?一个简单的小例子,简述ASP.NET Core Web API开发过程中&…

Git学习笔记(三)Git分支

Git分支是Git中非常重要的一个概念,无论是个人开发还是多人协作中,分支都起着至关重要的作用。几乎所有的版本控制系统都以某种形式支持分支。 使用分支意味着你可以把你的工作从开发主线上分离 开来进行重大的Bug修改、开发新的功能,以免影响…

Linux驱动开发笔记(零)驱动基础知识及准备

文章目录 前言一、Liunx、MCU和FPGA编程的区别二、Linux内核模块1. 什么是内核模块2. 内核模块的代码架构3. 头文件4. 模块参数5. makefile说明 三、 驱动程序设计思路1. 基本步骤2. 设备号3. 数据结构3.1 file_operations3.2 file3.3 inode3.4 哈希表3.5 cdev结构体3.6 kobj_m…

[Linux][进程信号][一][信号基础][如何产生信号]详细解读

目录 0.前言预备1.系统定义的信号列表2.核心转储 -- Core Dump 1.信号基础1.信号概念2.信号处理方式概览3.理解信号如何被保存4.信号发送的本质 2.如何产生信号?1.终端按键产生信号2.系统调用接口1.kill()2.raise()3.abort()4.如何理解? 3.由软件条件产生…

C# 图像旋转一定角度后,对应坐标怎么计算?

原理分析 要计算图像内坐标在旋转一定角度后的新坐标,可以使用二维空间中的点旋转公式。假设图像的中心点(即旋转中心)为 (Cx, Cy),通常对于正方形图像而言,中心点坐标为 (Width / 2, Height / 2)。给定原坐标点 (X, …

开发与产品的战争之自动播放视频

开发与产品的战争之自动播放视频 起因 产品提了个需求,对于网站上的宣传视频,进入页面就自动播放。但是基于我对chromium内核的一些浅薄了解,我当时就给拒绝了: “浏览器不允许”。(后续我们浏览器默认都是chromium内核的&#…

【深度学习】Vision Transformer

一、Vision Transformer Vision Transformer (ViT)将Transformer应用在了CV领域。在学习它之前,需要了解ResNet、LayerNorm、Multi-Head Self-Attention。 ViT的结构图如下: 如图所示,ViT主要包括Embedding、Encoder、Head三大部分。Class …

OpenHarmony鸿蒙南向开发案例:【智能燃气检测设备】

样例简介 本文档介绍了安全厨房案例中的相关智能燃气检测设备,本安全厨房案例利用轻量级软总线能力,将两块欧智通V200Z-R/BES2600开发板模拟的智能燃气检测设备和燃气告警设备组合成。当燃气数值告警时,无需其它操作,直接通知软总…

VOS3000加装登陆服务器安全防护系统有用吗

VOS3000是一款专业的软交换系统,它主要用于中小规模的VoIP运营业务,包括运营费率设定、套餐管理,账户管理、业终端管理、网关管理、数据查询、卡类管理、号码管理、系统管理等功能1。而关于加装登陆服务器安全防护系统是否有用,这…

2.4 Web容器配置:Tomcat

2.4 Web容器配置 2.4.1Tomcat配置1.常规配置2. HTTPS配置 *********** 2.4.1Tomcat配置 1.常规配置 在SpringBoot项目中,可以内置Tomcat、Jetly、Undertow、Netty等容器。 当开发者添加了spring-boot-starter-web依赖之后,默认会使用Tomcat作为Web容器…

基于Springboot+Vue的Java项目-网上点餐系统开发实战(附演示视频+源码+LW)

大家好!我是程序员一帆,感谢您阅读本文,欢迎一键三连哦。 💞当前专栏:Java毕业设计 精彩专栏推荐👇🏻👇🏻👇🏻 🎀 Python毕业设计 &am…

【EdgeBox-8120AI-TX2】Ubuntu18.04 + ROS_ Melodic + 星秒PAVO2单线激光 雷达评测

大家好,我是虎哥,好久不见,最近这断时间出现了一点变故,开始自己创业,很多事需要忙,所以停更了大约大半年,最近一切已经理顺,所以我还是抽空继续我之前的FLAG,CSDN突破十…