pytorch学习(8)——现有网络模型的使用以及修改

1 vgg16模型

在这里插入图片描述

1.1 vgg16模型的下载

采用torchvision中的vgg16模型,能够实现1000个类型的图像分类,VGG模型在AlexNet的基础上使用3*3小卷积核,增加网络深度,具有很好的泛化能力。
首先下载vgg16模型,python代码如下:

import torchvision# 下载路径:C:\Users\win10\.cache\torch\hub\checkpoints
vgg16_false = torchvision.models.vgg16(pretrained=False)
vgg16_true = torchvision.models.vgg16(pretrained=True)
print("ok")

下载结果:

G:\Anaconda3\envs\pytorch\lib\site-packages\torchvision\models\_utils.py:208: UserWarning: The parameter 'pretrained' is deprecated since 0.13 and may be removed in the future, please use 'weights' instead.warnings.warn(
G:\Anaconda3\envs\pytorch\lib\site-packages\torchvision\models\_utils.py:223: UserWarning: Arguments other than a weight enum or `None` for 'weights' are deprecated since 0.13 and may be removed in the future. The current behavior is equivalent to passing `weights=None`.warnings.warn(msg)
G:\Anaconda3\envs\pytorch\lib\site-packages\torchvision\models\_utils.py:223: UserWarning: Arguments other than a weight enum or `None` for 'weights' are deprecated since 0.13 and may be removed in the future. The current behavior is equivalent to passing `weights=VGG16_Weights.IMAGENET1K_V1`. You can also use `weights=VGG16_Weights.DEFAULT` to get the most up-to-date weights.warnings.warn(msg)
ok

1.2 vgg16模型内部结构

查看预训练的模型和未预训练的模型的内部结构:

import torchvisionvgg16_false = torchvision.models.vgg16(pretrained=False)
vgg16_true = torchvision.models.vgg16(pretrained=True)print(vgg16_true)
print(vgg16_false)

预训练的模型和未预训练的模型在整体结构上相同,但内部节点的参数(weight和bias)有所不同。
输出结果如下:

VGG((features): Sequential((0): Conv2d(3, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))(1): ReLU(inplace=True)(2): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))(3): ReLU(inplace=True)(4): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)(5): Conv2d(64, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))(6): ReLU(inplace=True)(7): Conv2d(128, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))(8): ReLU(inplace=True)(9): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)(10): Conv2d(128, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))(11): ReLU(inplace=True)(12): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))(13): ReLU(inplace=True)(14): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))(15): ReLU(inplace=True)(16): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)(17): Conv2d(256, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))(18): ReLU(inplace=True)(19): Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))(20): ReLU(inplace=True)(21): Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))(22): ReLU(inplace=True)(23): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)(24): Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))(25): ReLU(inplace=True)(26): Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))(27): ReLU(inplace=True)(28): Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))(29): ReLU(inplace=True)(30): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False))(avgpool): AdaptiveAvgPool2d(output_size=(7, 7))(classifier): Sequential((0): Linear(in_features=25088, out_features=4096, bias=True)(1): ReLU(inplace=True)(2): Dropout(p=0.5, inplace=False)(3): Linear(in_features=4096, out_features=4096, bias=True)(4): ReLU(inplace=True)(5): Dropout(p=0.5, inplace=False)(6): Linear(in_features=4096, out_features=1000, bias=True))
)

可以发现(classifier)中最后一层可以发现out_features=1000,表示该模型能够支持1000种类型的图像分类。

2 迁移学习

迁移学习是机器学习的一个子领域,它允许一个已经在一个任务上训练好的模型用于另一个但相关的任务。通过这种方式,模型可以借用在原任务上学到的知识,从而更快地、更准确地完成新任务。

本文采用CIFAR10数据集,内部包含10个种类的图像,修改vgg16模型对数据集进行图像分类。为了将此数据集代入vgg16模型,需要对模型进行修改。

(classifier): Sequential(... ...(6): Linear(in_features=4096, out_features=1000, bias=True)
)

2.1 添加层

使用add_module()函数添加模块。由于最后的归一化层为4096通道输出转1000通道输出,因此添加一个归一化层将1000通道输出转换为10通道输出。

import torchvision
from torch import nnvgg16_true = torchvision.models.vgg16(pretrained=True)
print(vgg16_true)vgg16_true.classifier.add_module('add_linear', nn.Linear(1000, 10))
print(vgg16_true)

输出结果(部分):

(classifier): Sequential((0): Linear(in_features=25088, out_features=4096, bias=True)(1): ReLU(inplace=True)(2): Dropout(p=0.5, inplace=False)(3): Linear(in_features=4096, out_features=4096, bias=True)(4): ReLU(inplace=True)(5): Dropout(p=0.5, inplace=False)(6): Linear(in_features=4096, out_features=1000, bias=True)(add_linear): Linear(in_features=1000, out_features=10, bias=True))

2.2 修改层

对classifier内的第6层进行修改。由于最后的归一化层为4096通道输出转1000通道输出,因此需要修改为4096通道输出转换为10通道输出。

import torchvision
from torch import nnvgg16_false = torchvision.models.vgg16(pretrained=False)
print(vgg16_false)vgg16_false.classifier[6] = nn.Linear(4096, 10)
print(vgg16_false)

输出结果(部分):

(classifier): Sequential((0): Linear(in_features=25088, out_features=4096, bias=True)(1): ReLU(inplace=True)(2): Dropout(p=0.5, inplace=False)(3): Linear(in_features=4096, out_features=4096, bias=True)(4): ReLU(inplace=True)(5): Dropout(p=0.5, inplace=False)(6): Linear(in_features=4096, out_features=10, bias=True))

2.3 模型保存

有两种方式保存模型数据。第一种保存方式是将模型结构和模型参数保存,第二种保存方式只是保存模型参数,以字典类型保存。
python代码如下:

import torch
import torchvision
from torch import nn
from torch.nn import Linear, Conv2d, MaxPool2d, Flatten, Sequential, CrossEntropyLossvgg16_false = torchvision.models.vgg16(pretrained=False)        # 未经过训练的模型# 保存方式1,模型结构+模型参数
torch.save(vgg16_false, "G:\\Anaconda\\pycharm_pytorch\\learning_project\\model\\vgg16_method1.pth")# 保存方式2,模型参数(官方推荐)
torch.save(vgg16_false.state_dict(), "G:\\Anaconda\\pycharm_pytorch\\learning_project\\model\\vgg16_method2.pth")

保存修改过的模型或自己的编写的模型:

# 保存模型和导入模型时都需要导入MYNN这个类
class MYNN(nn.Module):def __init__(self):super(MYNN, self).__init__()self.model1 = Sequential(Conv2d(3, 32, 5, padding=2, stride=1),MaxPool2d(2),Conv2d(32, 32, 5, padding=2, stride=1),MaxPool2d(2),Conv2d(32, 64, 5, padding=2, stride=1),MaxPool2d(2),Flatten(),Linear(1024, 64),Linear(64, 10))def forward(self, x):x = self.model1(x)return xmynn = MYNN()
torch.save(mynn, "G:\\Anaconda\\pycharm_pytorch\\learning_project\\model\\mynn_method1.pth")

以上两部分Python代码运行结果如下:
在这里插入图片描述

2.4 模型导入

有两种方式导入模型数据。第一种导入方式能够直接使用,第二种导入方法需要将字典数据导入原来的网络模型。

import torch
import torchvision
from torch import nn
from torch.nn import Linear, Conv2d, MaxPool2d, Flatten, Sequential, CrossEntropyLoss# 方式1:加载模型
vgg16_import = torch.load("G:\\Anaconda\\pycharm_pytorch\\learning_project\\model\\vgg16_method1.pth")
print(vgg16_import)# 方式2:加载模型(字典数据)
vgg16_import2 = torch.load("G:\\Anaconda\\pycharm_pytorch\\learning_project\\model\\vgg16_method2.pth")
vgg16_new = torchvision.models.vgg16(pretrained=False)      # 重新加载模型
vgg16_new.load_state_dict(vgg16_import2)                    # 将数据填入模型
print(vgg16_import2)
print(vgg16_new)

导入保存的自己的模型,python代码如下:

# 需要导入自己网络模型
class MYNN(nn.Module):def __init__(self):super(MYNN, self).__init__()self.model1 = Sequential(Conv2d(3, 32, 5, padding=2, stride=1),MaxPool2d(2),Conv2d(32, 32, 5, padding=2, stride=1),MaxPool2d(2),Conv2d(32, 64, 5, padding=2, stride=1),MaxPool2d(2),Flatten(),Linear(1024, 64),Linear(64, 10))def forward(self, x):x = self.model1(x)return xmodel = torch.load("G:\\Anaconda\\pycharm_pytorch\\learning_project\\model\\mynn_method1.pth")
print(model)

自己的网络模型导入运行结果:

MYNN((model1): Sequential((0): Conv2d(3, 32, kernel_size=(5, 5), stride=(1, 1), padding=(2, 2))(1): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)(2): Conv2d(32, 32, kernel_size=(5, 5), stride=(1, 1), padding=(2, 2))(3): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)(4): Conv2d(32, 64, kernel_size=(5, 5), stride=(1, 1), padding=(2, 2))(5): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)(6): Flatten(start_dim=1, end_dim=-1)(7): Linear(in_features=1024, out_features=64, bias=True)(8): Linear(in_features=64, out_features=10, bias=True))
)

完整的模型训练套路-GPU训练

模型验证套路

Github上的代码

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.mzph.cn/news/54626.shtml

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

2308协程超传客户用法

原文 协超客使用文档 基本用法 如何包含协程超传客户 协程超传客户是雅库的子库,雅库是仅头的,下载雅库库之后,在自己的工程中包含目录: 包含目录(包含) 包含目录(包含/雅兰/第三方)如果是g编译器还需要启用C20协程: 如(c造c编译器标识 串等"GNU")置(c造c标志&q…

【编程题】有效三角形的个数

文章目录 一、题目二、算法讲解三、题目链接四、补充 一、题目 给定一个包含非负整数的数组 nums ,返回其中可以组成三角形三条边的三元组个数。 示例1: 输入: nums [2,2,3,4] 输出: 3 **解释:**有效的组合是: 2,3,4 (使用第一个 2) 2,3,4 (使用第二个 …

基于 kernel 4.0 初始kmalloc

kmalloc 系列函数是驱动者常用来向内核大管家申请内存的API,今天抽空扒一扒它是怎么工作的;首先看看它的原型 1. kmalloc () 函数 static __always_inline void *kmalloc(size_t size, gfp_t flags) {if (__builtin_constant_p(size)) {if (size > …

openssl 加密(encrypt)、解密(decrypt)、签名(sign)、验证(verify)

一、使用openssl rsautl 进行加密、解密、签名、验证 [kyzjjyyzc-zjjcs04 openssl]$ openssl rsautl --help Usage: rsautl [options] -in file input file -out file output file -inkey file input key -keyform arg private key format - default PEM …

【Spring Boot】SpringBoot完整实现社交网站系统

一个完整的社交网站系统需要涉及到用户登录、发布动态、关注、评论、私信等各方面。这里提供一个简单的实现示例&#xff0c;供参考。 前端代码 前端使用Vue框架&#xff0c;以下是部分代码示例&#xff1a; 登录页&#xff1a; <template><div><input type…

论文阅读:Model-Agnostic Meta-Learning for Fast Adaptation of Deep Networks

前言 要弄清MAML怎么做&#xff0c;为什么这么做&#xff0c;就要看懂这两张图。先说MAML**在做什么&#xff1f;**它是打着Mate-Learing的旗号干的是few-shot multi-task Learning的事情。具体而言就是想训练一个模型能够使用很少的新样本&#xff0c;快速适应新的任务。 定…

pdf转ppt软件哪个好用?推荐一个好用的pdf转ppt软件

在日常工作和学习中&#xff0c;我们经常会遇到需要将PDF文件转换为PPT格式的情况。PDF格式的文件通常用于展示和保留文档的原始格式&#xff0c;而PPT格式则更适合用于演示和展示。为了满足这一需求&#xff0c;许多软件提供了PDF转PPT的功能&#xff0c;使我们能够方便地将PD…

vscode免密连远程

一、本地操作 ssh-keygen cd ~/.ssh ls // config为安装vscode远程插件后的信息 // id_rsa为本地私钥 // id_rsa.pub为本地公钥 vim config // 在User下增加一行本地私钥地址 IdentityFile "/Users/xuerui/.ssh/id_rsa"二、远程 cd ~/.ssh vim authorized_keys //…

C语言暑假刷题冲刺篇——day5

目录 一、选择题 二、编程题 &#x1f388;个人主页&#xff1a;库库的里昂 &#x1f390;CSDN新晋作者 &#x1f389;欢迎 &#x1f44d;点赞✍评论⭐收藏✨收录专栏&#xff1a;C语言每日一练✨相关专栏&#xff1a;代码小游戏、C语言初阶、C语言进阶&#x1f91d;希望作者…

【CSS】CSS 特性 ( CSS 优先级 | 优先级引入 | 选择器基本权重 )

一、CSS 优先级 1、优先级引入 定义 CSS 样式时 , 可能出现 多个 类型相同的 规则 定义在 同一个元素上 , 如果 CSS 选择器 相同 , 执行 CSS 层叠性 , 根据 就近原则 选择执行的样式 , 如 : 出现两个 div 标签选择器 , 都设置 color 文本颜色 ; <style>div {color: re…

精准高效农业作业,植保无人机显身手

中国作为农业大国&#xff0c;拥有约18亿亩的农田&#xff0c;每年都需要进行种子喷洒和农药施用等农业作业&#xff0c;对于普通农户来说&#xff0c;这是一项耗时耗力的工程&#xff0c;同时&#xff0c;人工喷洒农药极易造成农药慢性中毒&#xff0c;对农民的身体健康产生极…

Redis——急速安装并设置自启(CentOS)

现状 对于开发人员来说&#xff0c;部署服务器环境并不是一个高频操作。所以就导致绝大部分开发人员不会花太多时间去学习记忆&#xff0c;而是直接百度&#xff08;有一些同学可能连链接都懒得收藏&#xff09;。所以到了部署环境的时候就头疼&#xff0c;甚至是抗拒。除了每次…

k8s 安装istio (一)

前置条件 已经完成 K8S安装过程十&#xff1a;Kubernetes CNI插件与CoreDNS服务部署 部署 istio 服务网格与 Ingress 服务用到了 helm 与 kubectl 这两个命令行工具&#xff0c;这个命令行工具依赖 ~/.kube/config 这个配置文件&#xff0c;目前只在 kubernetes master 节点中…

web前端开发中的响应式布局设计是什么意思?

响应式布局是指网页设计和开发中的一种技术方法&#xff0c;旨在使网页能够在不同大小的屏幕和设备上都能良好地显示和交互。这种方法使得网页可以自动适应不同的屏幕尺寸&#xff0c;包括桌面电脑、平板电脑和手机等。 在Web前端开发中&#xff0c;响应式布局通常使用CSS&…

随记-Kibana Dev Tools,ES 增删改查 索引,Document

索引 创建索引 创建索引 PUT index_test创建索引 并 修改分片信息 # 创建索引 并 修改分片信息 PUT index_test2 { # 必须换行, PUT XXX 必须独占一行&#xff0c;类似的 其他请求也需要独占一行 "settings": {"number_of_shards": 1, # 主分片"…

bug复刻,解决方案---在改变div层级关系时,导致传参失败

问题描述&#xff1a; 在优化页面时&#xff0c;为了实现网页顶部遮挡效果&#xff08;内容滚动&#xff0c;顶部导航栏不随着一起滚动&#xff0c;并且覆盖&#xff09;&#xff0c;做法是将内容都放在一个div里面&#xff0c;为这个新的div设置样式&#xff0c;margin-top w…

c++ qt--事件过滤(第七部分)

c qt–事件过滤&#xff08;第七部分&#xff09; 一.为什么要用事件过滤 上一篇博客中我们用到了事件来进行一些更加细致的操作&#xff0c;如监控鼠标的按下与抬起&#xff0c;但是我们发现如果有很多的组件那每个组件都要创建一个类&#xff0c;这样就显得很麻烦&#xff…

GO语言:Worker Pools线程池、Select语句、Metex互斥锁详细示例教程

目录标题 一、Buffered Channels and Worker Pools1. Goroutine and Channel Example 线程和通道示例2. Deadlock 死锁3. Closing buffered channels 关闭通道4. Length vs Capacity 长度和容量5. WaitGroup6. Worker Pool Implementation 线程池 二、Select1. Example2. Defau…

IO进程线程,文件与目录,实现linux任意目录下ls -la

注意文件的名字、路径是如何输入的。 函数opendir打开目录&#xff0c;struct dirent&#xff0c;struct stat这些结构体的含义。 readdir()函数是一个用于读取目录内容的系统调用或库函数&#xff0c;在类Unix操作系统中&#xff08;如Linux&#xff09;广泛使用。它用于遍历…

python爬虫10:selenium库

python爬虫10&#xff1a;selenium库 前言 ​ python实现网络爬虫非常简单&#xff0c;只需要掌握一定的基础知识和一定的库使用技巧即可。本系列目标旨在梳理相关知识点&#xff0c;方便以后复习。 申明 ​ 本系列所涉及的代码仅用于个人研究与讨论&#xff0c;并不会对网站产…