预训练--微调

预训练–微调

一个很简单的道理,如果我们的模型是再ImageNet下训练的,那么这个模型一定是会比较复杂的,意思就是这个模型可以识别到很多种类别的即泛化能力很强,但是如果要它精确的识别是否某种类别,它的表现可能就不佳了,因此,我们需要在原来的基础上再对特定的我们需要识别的类别进行重新训练,微调原来网络结构中的参数,此时模型还是可以抽取较通用的图像特征。
在这里插入图片描述
参考自https://tangshusen.me/Dive-into-DL-PyTorch/#/chapter09_computer-vision/9.2_fine-tuning
当目标数据集远小于源数据集时,微调有助于提升模型的泛化能力。

热狗识别

源数据集是ImageNet,超过1000万个图像和1000类物体,热狗数据集包含1400个正类图像和其他多种负类图像
最开始还是导入所需要的库以及设置cuda

import torch
from torch import nn,optim
from torch.utils.data import Dataset, DataLoader
import torchvision
from torchvision.datasets import ImageFolder
from torchvision import transforms
from torchvision import models
import os
import d2lzh_pytorch as d2l
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

下载数据集https://apache-mxnet.s3-accelerate.amazonaws.com/gluon/dataset/hotdog.zip
我直接放在了我的默认路径下,读数据如下

train_imgs = ImageFolder("hotdog/train")
test_imgs = ImageFolder("hotdog/test")

然后我们观察一下数据集,可以看到大小,宽高比各不同

# 前八张正类图像和最后八张负类图像,可以看到宽高比、大小各不同
hotdogs = [train_imgs[i][0] for i in range(8)]
not_hotdogs = [test_imgs[-1-i][0] for i in range(8)]
d2l.show_images(hotdogs + not_hotdogs,2, 8, scale=2)

在这里插入图片描述
接下来就是训练时,我们先从图像中随机裁剪一块区域,然后将该区域缩放成224*224的图像进行输入,测试时,我们将图像的高和宽均缩放为256像素,然后从中裁剪出高、宽均为224的中心区域作为输入,此外对RGB三通道作标准化,每个数值减去通道的平均值,再除以标准差需要注意的是,在使用预训练模型时,一定要和预训练时作同样的预处理。 如果你使用的是torchvision的models,
那就要求: All pre-trained models expect input images normalized in the same way, i.e. mini-batches of 3-channel RGB images of shape (3 x H x W), where H and W are expected to be at least 224. The images have to be loaded in to a range of [0, 1] and then normalized using mean = [0.485, 0.456, 0.406] and std = [0.229, 0.224, 0.225].
如果你使用的是pretrained-models.pytorch仓库,请务必阅读其README,其中说明了如何预处理。

normalize = transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
train_augs = transforms.Compose([#transforms.Resize(size=256),  # 是将最小边调整到256#transforms.CenterCrop(size=224),transforms.RandomResizedCrop(size=224),transforms.RandomHorizontalFlip(),transforms.ToTensor(),normalize
])test_augs = transforms.Compose([transforms.Resize(size=256),transforms.CenterCrop(size=224),transforms.ToTensor(),normalize
])

需要注意的是,首先我有最开始有两点疑惑

  1. 为什么不能需要从图像中随机裁剪一块区域,然后将该区域缩放成224*224的图像进行输入。然后我测试了一下,如果不这样做的话,那么泛化能力会比较差
  2. 如果非要这么做,那么可不可以直接transforms.Resize(size=224)?不可以的,transforms.Resize(size=224)是把最短的边变为224,宽高比没变,那么这样就会导致图像的尺寸不一样,后面自然会报错,所以需要先transforms.Resize(size=256),然后transforms.CenterCrop(size=224)

之后我们使用在ImageNet上预训练的ResNet18,pretrained=True,自动下载预训练参数
不管你是使用的torchvision的models还是pretrained-models.pytorch仓库,默认都会将预训练好的模型参数下载到你的home目录下.torch文件夹。
你可以通过修改环境变量$TORCH_MODEL_ZOO来更改下载目录

pretrained_net = models.resnet18(pretrained=True)

修改最后一层

pretrained_net.fc = nn.Linear(512, 2)

接下来设置训练的参数,由于除了最后一层,之前的参数都经过预训练,所以我们学习率调小一点,最后的fc层是初始化过的,于是我们学习率调大一点

output_params = list(map(id, pretrained_net.fc.parameters()))  # fc层
feature_params = filter(lambda p: id(p) not in output_params, pretrained_net.parameters())  # 除了fc层
lr = 0.01 # 用来更新特征层
# fc层是lr * 10
optimizer = optim.SGD([{"params":feature_params},{"params":pretrained_net.fc.parameters(), "lr":lr*10}
] ,lr = lr, weight_decay=0.001)

在之后就是训练了

def train_fine_tuning(net, optimizer, batch_size=64, num_epochs=5):train_iter = DataLoader(ImageFolder("hotdog/train", transform=train_augs), batch_size, shuffle=True)test_iter = DataLoader(ImageFolder("hotdog/test", transform=test_augs), batch_size, shuffle=False)loss = torch.nn.CrossEntropyLoss()d2l.train(train_iter, test_iter, net, loss, optimizer, device, num_epochs)
train_fine_tuning(pretrained_net, optimizer)

在这里插入图片描述

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.mzph.cn/news/207704.shtml

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

07-2 Python模块和命名空间

1. 模块 概念:其实就是一个Python文件,正常文件有的变量,函数,类,模块都有 功能:模块可以被其它程序引入,以使用该模块中的函数等功能。 示例:test-module.py调用mymodule.py模块中的now_time…

充电桩IC

充电桩IC 电子元器件百科 文章目录 充电桩IC前言一、充电桩IC是什么二、充电桩IC的类别三、充电桩IC的应用实例四、充电桩IC的工作原理总结前言 充电桩IC的设计和功能会根据不同的充电协议和市场需求进行调整和定制。目前市场上有许多不同型号和厂家的充电桩IC可供选择,以满足…

一篇文章带你快速入门 Vue 核心语法

一篇文章带你快速入门 Vue 核心语法 一、为什么要学习Vue 1.前端必备技能 2.岗位多,绝大互联网公司都在使用Vue 3.提高开发效率 4.高薪必备技能(Vue2Vue3) 二、什么是Vue 概念:Vue (读音 /vjuː/,类似于 view) …

Mysql 日期函数大全

一、时间函数 (一)、获取当前时间 1、NOW() 获取当前日期和时间,在程序一开始执行便拿到时间 返回格式 YYYY-MM-DD hh:mm:ss eg: NOW() 得到 2023-12-03 12:20:02 NOW(),SLEEP(2),NOW() 得到 2023-12-03 12:20:02 | 0 | 2023-…

目标检测——OverFeat算法解读

论文:OverFeat: Integrated Recognition, Localization and Detection using Convolutional Networks 作者:Pierre Sermanet, David Eigen, Xiang Zhang, Michael Mathieu, Rob Fergus, Yann LeCun 链接:https://arxiv.org/abs/1312.6229 文章…

Go语言-让我印象深刻的13个特性

我们正在加速进入云原生时代,Go语言作为云原生的一块基石,确有它的独到之处。本文介绍Go语言的几个让我印象深刻的特性。 1、兼顾开发效率和性能 Go语言兼顾开发效率和性能。可以像Python那样有很快的开发速度,也可以像C那样有很快的执行速…

SpringAOP专栏二《原理篇》

上一篇SpringAOP专栏一《使用教程篇》-CSDN博客介绍了SpringAop如何使用,这一篇文章就会介绍Spring AOP 的底层实现原理,并通过源代码解析来详细阐述其实现过程。 前言 Spring AOP 的实现原理是基于动态代理和字节码操作的。不了解动态代理和字节码操作…

【C语言】函数递归详解(一)

目录 1.什么是递归: 1.1递归的思想: 1.2递归的限制条件: 2.递归举例: 2.1举例1:求n的阶乘: 2.1.1 分析和代码实现: 2.1.2图示递归过程: 2.2举例2:顺序打印一个整数的…

机器学习---集成学习的初步理解

1. 集成学习 集成学习(ensemble learning)是现在非常火爆的机器学习方法。它本身不是一个单独的机器学 习算法,而是通过构建并结合多个机器学习器来完成学习任务。也就是我们常说的“博采众长”。集 成学习可以用于分类问题集成,回归问题集成&#xff…

多线程并发Ping脚本

1. 前言 最近需要ping地址,还是挺多的,就使用python搞一个ping脚本,记录一下,以免丢失了。 2. 脚本介绍 首先检查是否存在True.txt或False.txt文件,并在用户确认后进行删除,然后从IP.txt的文件中读取IP地…

CSS——sticky定位

1. 大白话解释sticky定位 粘性定位通俗来说,它就是相对定位relative和固定定位fixed的结合体,它的触发过程分为三个阶段 在最近可滚动容器没有触发滑动之前,sticky盒子的表现为相对定位relative【第一阶段】, 但当最近可滚动容…

【MATLAB】tvfEMD信号分解+FFT+HHT组合算法

有意向获取代码,请转文末观看代码获取方式~也可转原文链接获取~ 1 基本定义 TVFEMDFFTHHT组合算法是一种结合了总体变分模态分解(TVFEMD)、傅里叶变换(FFT)和希尔伯特-黄变换(HHT)的信号分解方…

vivado时序方法检查8

TIMING-30 &#xff1a; 生成时钟所选主源管脚欠佳 生成时钟 <clock_name> 所选的主源管脚欠佳 &#xff0c; 时序可能处于消极状态。 描述 虽然 create_generated_clock 命令允许您指定任意参考时钟 &#xff0c; 但是生成时钟应引用在其直接扇入中传输的时钟。此…

电子学会C/C++编程等级考试2021年06月(五级)真题解析

C/C++等级考试(1~8级)全部真题・点这里 第1题:数字变换 给定一个包含5个数字(0-9)的字符串,例如 “02943”,请将“12345”变换到它。 你可以采取3种操作进行变换 1. 交换相邻的两个数字 2. 将一个数字加1。如果加1后大于9,则变为0 3. 将一个数字加倍。如果加倍后大于…

JS--异步的日常用法

目录 JS 异步编程并发&#xff08;concurrency&#xff09;和并行&#xff08;parallelism&#xff09;区别回调函数&#xff08;Callback&#xff09;GeneratorPromiseasync 及 await常用定时器函数 JS 异步编程 并发&#xff08;concurrency&#xff09;和并行&#xff08;p…

Python中一些有趣的例题

下面会写一些基础的例题&#xff0c;有兴趣的自己也可以练练手&#xff01; 1.假设手机短信收到的数字验证码为“278902”&#xff0c;编写一个程序&#xff0c;让用户输入数字验证码&#xff0c;如果数字验证码输入正确&#xff0c;提示“支付成功”&#xff1b;否则提示“数…

Python configparser 模块:优雅处理配置文件的得力工具

更多资料获取 &#x1f4da; 个人网站&#xff1a;ipengtao.com 配置文件在软件开发中扮演着重要的角色&#xff0c;而Python中的 configparser 模块提供了一种优雅而灵活的方式来处理各种配置需求。本文将深入介绍 configparser 模块的各个方面&#xff0c;通过丰富的示例代码…

嵌入式杂记 - MDK的Code, RO-data , RW-data, ZI-data意思

嵌入式杂记 - Keil的Code, RO-data , RW-data, ZI-data意思 MDK中的数据分类MCU中的内部存储分布MDK中数据类型存储Code代码段例子 RO-data 只读数据段例子 RW-data 可读写数据段例子 ZI-data 清零数据段例子 在嵌入式开发中&#xff0c;我们经常都会使用一些IDE&#xff0c;例…

Hadoop学习笔记(HDP)-Part.17 安装Spark2

目录 Part.01 关于HDP Part.02 核心组件原理 Part.03 资源规划 Part.04 基础环境配置 Part.05 Yum源配置 Part.06 安装OracleJDK Part.07 安装MySQL Part.08 部署Ambari集群 Part.09 安装OpenLDAP Part.10 创建集群 Part.11 安装Kerberos Part.12 安装HDFS Part.13 安装Ranger …

Web前端 ---- 【Vue】Vuex的使用(辅助函数、模块化开发)

目录 前言 Vuex是什么 Vuex的配置 安装vuex 配置vuex文件 Vuex核心对象 actions mutations getters state Vuex在vue中的使用 辅助函数 Vuex模块化开发 前言 本文介绍一种新的用于组件传值的插件 —— vuex Vuex是什么 Vuex 是一个专为 Vue.js 应用程序开发的状态…