Pytorch深度学习-----神经网络模型的保存与加载(VGG16模型)

系列文章目录

PyTorch深度学习——Anaconda和PyTorch安装
Pytorch深度学习-----数据模块Dataset类
Pytorch深度学习------TensorBoard的使用
Pytorch深度学习------Torchvision中Transforms的使用(ToTensor,Normalize,Resize ,Compose,RandomCrop)
Pytorch深度学习------torchvision中dataset数据集的使用(CIFAR10)
Pytorch深度学习-----DataLoader的用法
Pytorch深度学习-----神经网络的基本骨架-nn.Module的使用
Pytorch深度学习-----神经网络的卷积操作
Pytorch深度学习-----神经网络之卷积层用法详解
Pytorch深度学习-----神经网络之池化层用法详解及其最大池化的使用
Pytorch深度学习-----神经网络之非线性激活的使用(ReLu、Sigmoid)
Pytorch深度学习-----神经网络之线性层用法
Pytorch深度学习-----神经网络之Sequential的详细使用及实战详解
Pytorch深度学习-----损失函数(L1Loss、MSELoss、CrossEntropyLoss)
Pytorch深度学习-----优化器详解(SGD、Adam、RMSprop)
Pytorch深度学习-----现有网络模型的使用及修改(VGG16模型)


文章目录

  • 系列文章目录
  • 一、网络模型的保存
    • 1.方法一
    • 2.方法二
  • 二、网络模型的加载
    • 1.方法一
    • 2.方法二
  • 三、总结


一、网络模型的保存

1.方法一

保存整个模型,包括其相关的所有参数

torch.save(obj, f, pickle_protocol=DEFAULT_PROTOCOL)

参数说明:

obj: 要保存的对象,可以是模型、张量、字典等。
f: 要保存到的文件路径或文件对象。
pickle_protocol: 序列化协议的版本,默认为DEFAULT_PROTOCOL。

代码如下:

import torch
import torchvision.models as models
from torch import nnvgg16_true = models.vgg16(weights=True)
vgg16_false = models.vgg16(weights=False)torch.save(vgg16_true, "vgg16_model_true.pth")

其中.pth是后缀标志。

在这里插入图片描述

2.方法二

只保存模型参数,在原有vgg16对象中使用.state_dict()方法即可。

代码如下:

import torch
import torchvision.models as models
from torch import nnvgg16_true = models.vgg16(weights=True)
vgg16_false = models.vgg16(weights=False)torch.save(vgg16_true.state_dict(), "vgg16_model_true_2.pth")

在这里插入图片描述

二、网络模型的加载

1.方法一

对应于上述中保存模型的方法1进行加载。

相关函数如下:

torch.load(f, map_location=None, pickle_module=pickle, **pickle_load_args)

参数说明:

f: 要加载的文件路径或文件对象。
map_location: 可选参数,用于指定在哪个设备上加载模型。如果不提供该参数,默认会加载到当前设备。
pickle_module: 可选参数,用于指定用于反序列化的模块。默认为pickle。
pickle_load_args: 其他可选的用于反序列化的参数。

代码如下:

import torch
import torchvision.models as models
from torch import nnmodel1 = torch.load("vgg16_model_true.pth")  # 因为vgg16_model_true.pth是使用方法一保存的,故输出后是整个模型网络结构
print(model1)
model2 = torch.load("vgg16_model_true_2.pth")  # 因为vgg16_model_true_2.pth是使用方法二保存的,只保留模型参数,故输出后是整个字典类型
print(model2)

vgg16_model_true.pth加载结果

VGG((features): Sequential((0): Conv2d(3, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))(1): ReLU(inplace=True)(2): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))(3): ReLU(inplace=True)(4): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)(5): Conv2d(64, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))(6): ReLU(inplace=True)(7): Conv2d(128, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))(8): ReLU(inplace=True)(9): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)(10): Conv2d(128, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))(11): ReLU(inplace=True)(12): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))(13): ReLU(inplace=True)(14): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))(15): ReLU(inplace=True)(16): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)(17): Conv2d(256, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))(18): ReLU(inplace=True)(19): Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))(20): ReLU(inplace=True)(21): Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))(22): ReLU(inplace=True)(23): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)(24): Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))(25): ReLU(inplace=True)(26): Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))(27): ReLU(inplace=True)(28): Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))(29): ReLU(inplace=True)(30): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False))(avgpool): AdaptiveAvgPool2d(output_size=(7, 7))(classifier): Sequential((0): Linear(in_features=25088, out_features=4096, bias=True)(1): ReLU(inplace=True)(2): Dropout(p=0.5, inplace=False)(3): Linear(in_features=4096, out_features=4096, bias=True)(4): ReLU(inplace=True)(5): Dropout(p=0.5, inplace=False)(6): Linear(in_features=4096, out_features=1000, bias=True))
)

vgg16_model_true_2.pth加载结果

OrderedDict([('features.0.weight', tensor([[[[-5.5373e-01,  1.4270e-01,  5.2896e-01],[-5.8312e-01,  3.5655e-01,  7.6566e-01],[-6.9022e-01, -4.8019e-02,  4.8409e-01]],[[ 1.7548e-01,  9.8630e-03, -8.1413e-02],[ 4.4089e-02, -7.0323e-02, -2.6035e-01],[ 1.3239e-01, -1.7279e-01, -1.3226e-01]],[[ 3.1303e-01, -1.6591e-01, -4.2752e-01],[ 4.7519e-01, -8.2677e-02, -4.8700e-01],[ 6.3203e-01,  1.9308e-02, -2.7753e-01]]],[[[ 2.3254e-01,  1.2666e-01,  1.8605e-01],[-4.2805e-01, -2.4349e-01,  2.4628e-01],[-2.5066e-01,  1.4177e-01, -5.4864e-03]],[[-1.4076e-01, -2.1903e-01,  1.5041e-01],[-8.4127e-01, -3.5176e-01,  5.6398e-01],[-2.4194e-01,  5.1928e-01,  5.3915e-01]],[[-3.1432e-01, -3.7048e-01, -1.3094e-01],[-4.7144e-01, -1.5503e-01,  3.4589e-01],[ 5.4384e-02,  5.8683e-01,  4.9580e-01]]],[[[ 1.7715e-01,  5.2149e-01,  9.8740e-03],[-2.7185e-01, -7.1709e-01,  3.1292e-01],[-7.5753e-02, -2.2079e-01,  3.3455e-01]],[[ 3.0924e-01,  6.7071e-01,  2.0546e-02],[-4.6607e-01, -1.0697e+00,  3.3501e-01],[-8.0284e-02, -3.0522e-01,  5.4460e-01]],[[ 3.1572e-01,  4.2335e-01, -3.4976e-01],[ 8.6354e-02, -4.6457e-01,  1.1803e-02],[ 1.0483e-01, -1.4584e-01, -1.5765e-02]]],...,

2.方法二

import torch
import torchvision.models as models
from torch import nnvgg16_true = models.vgg16(weights=True)vgg16_true.load_state_dict(torch.load("vgg16_model_true_2.pth"))  # 针对第二种加载参数的情况,使其显示完整的网络结构
print(vgg16_true)
VGG((features): Sequential((0): Conv2d(3, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))(1): ReLU(inplace=True)(2): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))(3): ReLU(inplace=True)(4): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)(5): Conv2d(64, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))(6): ReLU(inplace=True)(7): Conv2d(128, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))(8): ReLU(inplace=True)(9): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)(10): Conv2d(128, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))(11): ReLU(inplace=True)(12): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))(13): ReLU(inplace=True)(14): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))(15): ReLU(inplace=True)(16): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)(17): Conv2d(256, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))(18): ReLU(inplace=True)(19): Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))(20): ReLU(inplace=True)(21): Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))(22): ReLU(inplace=True)(23): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)(24): Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))(25): ReLU(inplace=True)(26): Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))(27): ReLU(inplace=True)(28): Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))(29): ReLU(inplace=True)(30): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False))(avgpool): AdaptiveAvgPool2d(output_size=(7, 7))(classifier): Sequential((0): Linear(in_features=25088, out_features=4096, bias=True)(1): ReLU(inplace=True)(2): Dropout(p=0.5, inplace=False)(3): Linear(in_features=4096, out_features=4096, bias=True)(4): ReLU(inplace=True)(5): Dropout(p=0.5, inplace=False)(6): Linear(in_features=4096, out_features=1000, bias=True))
)

注意: 加载模型时,要确保当前代码中使用的模型类与之前保存的模型类相同。

三、总结

torch.load()是PyTorch中用于加载保存的对象的函数,可以加载之前使用torch.save()保存的模型、张量、字典等。可以指定要加载的文件路径或文件对象,并可选地指定加载到的设备、反序列化模块等参数。

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.mzph.cn/news/29979.shtml

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

Git介绍及常用命令详解

一、Git的概述 Git是一个分布式版本控制工具,通常用来对软件开发过程中的源代码文件进行管理。 Git 会跟踪我们对文件所做的更改,因此我们可以记录已完成的工作,并且可以在需要时恢复到特定或以前的版本。Git 还使多人协作变得更加容易&…

个人对前后端分离的一些看法

内容简介:前端开发过程中能完全不依赖后端的才是真正的前后端分离指的是工作过程中,前端的的代码中往往会掺杂一些后端的逻辑。后端返回了一个json对象 前端开发过程中能完全不依赖后端的才是真正的前后端分离 指的是工作过程中,前端的的代码…

涉及JS时实用的简洁方法

当涉及到JavaScript编程时,有许多简洁和实用的方法可以帮助你更有效地编写代码。以下是一些常用的简洁方法: 箭头函数: 箭头函数是一种简洁的语法形式,适用于单行函数表达式。它可以让你更紧凑地定义匿名函数。 // 传统函数 fun…

Linux系统中的自旋锁(两幅图清晰说明)

总结: 多CPU下的自旋锁采取的是忙等待(原地打转)机制,虽然忙等待的线程占用了它所在的cpu,但其他线程仍可放到其他CPU上执行。所以自旋锁上锁和解锁之间的临界区代码要尽量的短,最好不要超过5行&#xff0c…

jenkins流水线

1.拉取代码 https://gitee.com/Wjc_project/yygh-parent.git2、项目编译 mvn clean package -Dmaven.test.skiptrue ls hospital-manage/target3、构建镜像 ls hospital-manage/target docker build -t hospital-manage:latest -f hospital-manage/Dockerfile ./hospital-ma…

AWD攻防学习总结(草稿状态,待陆续补充)

AWD攻防学习总结 防守端1、修改密码2、备份网站3、备份数据库4、部署WAF5、部署文件监控脚本6、部署流量监控脚本/工具7、D盾扫描,删除预留webshell8、代码审计,seay/fortify扫描,漏洞修复及利用9、时刻关注流量和积分信息,掉分时…

业绩难言乐观,皓泽电子撤回上市申请,小米等为其关联方

撰稿|行星 来源|贝多财经 8月8日,深圳证券交易所披露的信息显示,由于河南皓泽电子股份有限公司(下称“皓泽电子”)及其保荐人主动要求撤回申请文件,深交所终止了皓泽电子的发行注册程序。 据此前招股书披露&#xff…

python爬虫实战(1)--爬取新闻数据

想要每天看到新闻数据又不想占用太多时间去整理,萌生自己抓取新闻网站的想法。 1. 准备工作 使用python语言可以快速实现,调用BeautifulSoup包里面的方法 安装BeautifulSoup pip install BeautifulSoup完成以后引入项目 2. 开发 定义请求头&#xf…

Fast Tone Mapping for High Dynamic Range Images

Abstract 我们提出了一种快速、有效、灵活的色调再现方法,在低动态范围再现设备中保留了高动态范围场景的可视性和对比度印象。 一个单一的参数控制能见度和对比度在一个简单和优雅的方式和互动速度。 新方法使用简单,计算效率高。 实验表明&#xff0c…

Spring Boot Actuator未授权访问漏洞

1.问题 Spring Boot Actuator 端点的未授权访问漏洞是一个安全性问题,可能会导致未经授权的用户访问敏感的应用程序信息。 可是并不用太过担心,Spring Boot Actuator 默认暴漏的信息有限,一般情况下并不会暴露敏感数据。 注册中心有些功能集…

Jenkins+Docker+SpringCloud微服务持续集成

JenkinsDockerSpringCloud微服务持续集成 JenkinsDockerSpringCloud持续集成流程说明SpringCloud微服务源码概述本地运行微服务本地部署微服务 Docker安装和Dockerfile制作微服务镜像Harbor镜像仓库安装及使用在Harbor创建用户和项目上传镜像到Harbor从Harbor下载镜像 微服务持…

RK3568蓝牙程序开发过程

1、搭建蓝牙开发环境 蓝牙开发可以使用C语言开发或python语言开发,使用的是蓝牙开发库为bluez库。 本文开发使用python语言开发,安装bluez库,可以使用pip install PyBluez来安装。 如果安装不上的话,可以使用sudo apt install pyt…

Kafka与Zookeeper版本对应关系

文章目录 了解版本对应Kafka安装包Kafka源码包 了解 比如: kafka_2.11-1.1.1.jar包 其中2.11表示的是Scala的版本,因为Kafka服务器端代码完全由Scala语音编写。”-“后面的1.1.1表示的kafka的版本信息。遵循一个基本原则,Kafka客户端版本和服…

无涯教程-Perl - getnetbyname函数

描述 此函数返回由NAME指定的网络信息(在列表context中)($name,$aliases,$addrtype,$net) 语法 以下是此函数的简单语法- getnetbyname NAME返回值 此函数在错误时返回undef,否则在标量context中返回网络地址,在错误时返回空列表,否则在列表context中返回网络记录(名称,别…

pandas 笔记 date_range

返回固定频率下的datetime 1 使用方法 pandas.date_range(startNone, endNone, periodsNone, freqNone, tzNone, normalizeFalse, nameNone, inclusiveboth, *, unitNone, **kwargs) 2 基本参数 start、end、periods至少需要两个 start生成日期的左边界end生成日期的右边界…

错误: XXXAdapter不是抽象的, 并且未覆盖Adapter中的抽象方法onBindViewHolder(ViewHolder,int)

一、问题描述 在学习Android可侧滑删除的RecyclerView的时候,遇到了下面的报错 错误: SwipeDelAdapter不是抽象的, 并且未覆盖Adapter中的抽象方法onBindViewHolder(ViewHolder,int) public class SwipeDelAdapter extends RecyclerView.Adapter { ^ 在上面的…

java springboot word文档转pdf

java springboot word文档转pdf 1、环境2、依赖3、代码 1、环境 1、java、springboot 2、maven或者gradle 3、办公软件(自己电脑上的wps或者office等,如果部署到服务器上也要安装,linux、Mac 都有,自己安装) 可能会遇…

用zabbix实现web监控

上篇我们说到了用最简单的web页面监控,如果你的页面只有ip和port就可以访问的话,那么简单的监测没有问题了。如果……开发给你的网站在后边加了个目录呢?那么就绕不开了web场景监控了。 一、添加模板 在【模板】上新建一个模板,…

操作系统—调度算法

进程调度算法 进程调度算法也称CPU调度算法 调度发生时期 当进程从运行状态转到等待状态;当进程从运行状态转到就绪状态;当进程从等待状态转到就绪状态;当进程从运行状态转到终止状态; 其中发生在 1 和 4 两种情况下的调度称为…

[QCM6125][Android13] 关闭救援模式

文章目录 开发平台基本信息问题描述解决方法 开发平台基本信息 芯片: QCM6125 版本: Android 13 kernel: msm-4.14 问题描述 安装系统在未响应5分钟的时候,系统会自动进入救援模式,这时候需要通过音量键和电源键进行操作才能再次进入系统。对于无人值…