预训练--微调

预训练–微调

一个很简单的道理,如果我们的模型是再ImageNet下训练的,那么这个模型一定是会比较复杂的,意思就是这个模型可以识别到很多种类别的即泛化能力很强,但是如果要它精确的识别是否某种类别,它的表现可能就不佳了,因此,我们需要在原来的基础上再对特定的我们需要识别的类别进行重新训练,微调原来网络结构中的参数,此时模型还是可以抽取较通用的图像特征。
在这里插入图片描述
参考自https://tangshusen.me/Dive-into-DL-PyTorch/#/chapter09_computer-vision/9.2_fine-tuning
当目标数据集远小于源数据集时,微调有助于提升模型的泛化能力。

热狗识别

源数据集是ImageNet,超过1000万个图像和1000类物体,热狗数据集包含1400个正类图像和其他多种负类图像
最开始还是导入所需要的库以及设置cuda

import torch
from torch import nn,optim
from torch.utils.data import Dataset, DataLoader
import torchvision
from torchvision.datasets import ImageFolder
from torchvision import transforms
from torchvision import models
import os
import d2lzh_pytorch as d2l
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

下载数据集https://apache-mxnet.s3-accelerate.amazonaws.com/gluon/dataset/hotdog.zip
我直接放在了我的默认路径下,读数据如下

train_imgs = ImageFolder("hotdog/train")
test_imgs = ImageFolder("hotdog/test")

然后我们观察一下数据集,可以看到大小,宽高比各不同

# 前八张正类图像和最后八张负类图像,可以看到宽高比、大小各不同
hotdogs = [train_imgs[i][0] for i in range(8)]
not_hotdogs = [test_imgs[-1-i][0] for i in range(8)]
d2l.show_images(hotdogs + not_hotdogs,2, 8, scale=2)

在这里插入图片描述
接下来就是训练时,我们先从图像中随机裁剪一块区域,然后将该区域缩放成224*224的图像进行输入,测试时,我们将图像的高和宽均缩放为256像素,然后从中裁剪出高、宽均为224的中心区域作为输入,此外对RGB三通道作标准化,每个数值减去通道的平均值,再除以标准差需要注意的是,在使用预训练模型时,一定要和预训练时作同样的预处理。 如果你使用的是torchvision的models,
那就要求: All pre-trained models expect input images normalized in the same way, i.e. mini-batches of 3-channel RGB images of shape (3 x H x W), where H and W are expected to be at least 224. The images have to be loaded in to a range of [0, 1] and then normalized using mean = [0.485, 0.456, 0.406] and std = [0.229, 0.224, 0.225].
如果你使用的是pretrained-models.pytorch仓库,请务必阅读其README,其中说明了如何预处理。

normalize = transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
train_augs = transforms.Compose([#transforms.Resize(size=256),  # 是将最小边调整到256#transforms.CenterCrop(size=224),transforms.RandomResizedCrop(size=224),transforms.RandomHorizontalFlip(),transforms.ToTensor(),normalize
])test_augs = transforms.Compose([transforms.Resize(size=256),transforms.CenterCrop(size=224),transforms.ToTensor(),normalize
])

需要注意的是,首先我有最开始有两点疑惑

  1. 为什么不能需要从图像中随机裁剪一块区域,然后将该区域缩放成224*224的图像进行输入。然后我测试了一下,如果不这样做的话,那么泛化能力会比较差
  2. 如果非要这么做,那么可不可以直接transforms.Resize(size=224)?不可以的,transforms.Resize(size=224)是把最短的边变为224,宽高比没变,那么这样就会导致图像的尺寸不一样,后面自然会报错,所以需要先transforms.Resize(size=256),然后transforms.CenterCrop(size=224)

之后我们使用在ImageNet上预训练的ResNet18,pretrained=True,自动下载预训练参数
不管你是使用的torchvision的models还是pretrained-models.pytorch仓库,默认都会将预训练好的模型参数下载到你的home目录下.torch文件夹。
你可以通过修改环境变量$TORCH_MODEL_ZOO来更改下载目录

pretrained_net = models.resnet18(pretrained=True)

修改最后一层

pretrained_net.fc = nn.Linear(512, 2)

接下来设置训练的参数,由于除了最后一层,之前的参数都经过预训练,所以我们学习率调小一点,最后的fc层是初始化过的,于是我们学习率调大一点

output_params = list(map(id, pretrained_net.fc.parameters()))  # fc层
feature_params = filter(lambda p: id(p) not in output_params, pretrained_net.parameters())  # 除了fc层
lr = 0.01 # 用来更新特征层
# fc层是lr * 10
optimizer = optim.SGD([{"params":feature_params},{"params":pretrained_net.fc.parameters(), "lr":lr*10}
] ,lr = lr, weight_decay=0.001)

在之后就是训练了

def train_fine_tuning(net, optimizer, batch_size=64, num_epochs=5):train_iter = DataLoader(ImageFolder("hotdog/train", transform=train_augs), batch_size, shuffle=True)test_iter = DataLoader(ImageFolder("hotdog/test", transform=test_augs), batch_size, shuffle=False)loss = torch.nn.CrossEntropyLoss()d2l.train(train_iter, test_iter, net, loss, optimizer, device, num_epochs)
train_fine_tuning(pretrained_net, optimizer)

在这里插入图片描述

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.mzph.cn/news/207704.shtml

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

07-2 Python模块和命名空间

1. 模块 概念:其实就是一个Python文件,正常文件有的变量,函数,类,模块都有 功能:模块可以被其它程序引入,以使用该模块中的函数等功能。 示例:test-module.py调用mymodule.py模块中的now_time…

一篇文章带你快速入门 Vue 核心语法

一篇文章带你快速入门 Vue 核心语法 一、为什么要学习Vue 1.前端必备技能 2.岗位多,绝大互联网公司都在使用Vue 3.提高开发效率 4.高薪必备技能(Vue2Vue3) 二、什么是Vue 概念:Vue (读音 /vjuː/,类似于 view) …

Mysql 日期函数大全

一、时间函数 (一)、获取当前时间 1、NOW() 获取当前日期和时间,在程序一开始执行便拿到时间 返回格式 YYYY-MM-DD hh:mm:ss eg: NOW() 得到 2023-12-03 12:20:02 NOW(),SLEEP(2),NOW() 得到 2023-12-03 12:20:02 | 0 | 2023-…

目标检测——OverFeat算法解读

论文:OverFeat: Integrated Recognition, Localization and Detection using Convolutional Networks 作者:Pierre Sermanet, David Eigen, Xiang Zhang, Michael Mathieu, Rob Fergus, Yann LeCun 链接:https://arxiv.org/abs/1312.6229 文章…

SpringAOP专栏二《原理篇》

上一篇SpringAOP专栏一《使用教程篇》-CSDN博客介绍了SpringAop如何使用,这一篇文章就会介绍Spring AOP 的底层实现原理,并通过源代码解析来详细阐述其实现过程。 前言 Spring AOP 的实现原理是基于动态代理和字节码操作的。不了解动态代理和字节码操作…

【C语言】函数递归详解(一)

目录 1.什么是递归: 1.1递归的思想: 1.2递归的限制条件: 2.递归举例: 2.1举例1:求n的阶乘: 2.1.1 分析和代码实现: 2.1.2图示递归过程: 2.2举例2:顺序打印一个整数的…

机器学习---集成学习的初步理解

1. 集成学习 集成学习(ensemble learning)是现在非常火爆的机器学习方法。它本身不是一个单独的机器学 习算法,而是通过构建并结合多个机器学习器来完成学习任务。也就是我们常说的“博采众长”。集 成学习可以用于分类问题集成,回归问题集成&#xff…

多线程并发Ping脚本

1. 前言 最近需要ping地址,还是挺多的,就使用python搞一个ping脚本,记录一下,以免丢失了。 2. 脚本介绍 首先检查是否存在True.txt或False.txt文件,并在用户确认后进行删除,然后从IP.txt的文件中读取IP地…

CSS——sticky定位

1. 大白话解释sticky定位 粘性定位通俗来说,它就是相对定位relative和固定定位fixed的结合体,它的触发过程分为三个阶段 在最近可滚动容器没有触发滑动之前,sticky盒子的表现为相对定位relative【第一阶段】, 但当最近可滚动容…

【MATLAB】tvfEMD信号分解+FFT+HHT组合算法

有意向获取代码,请转文末观看代码获取方式~也可转原文链接获取~ 1 基本定义 TVFEMDFFTHHT组合算法是一种结合了总体变分模态分解(TVFEMD)、傅里叶变换(FFT)和希尔伯特-黄变换(HHT)的信号分解方…

电子学会C/C++编程等级考试2021年06月(五级)真题解析

C/C++等级考试(1~8级)全部真题・点这里 第1题:数字变换 给定一个包含5个数字(0-9)的字符串,例如 “02943”,请将“12345”变换到它。 你可以采取3种操作进行变换 1. 交换相邻的两个数字 2. 将一个数字加1。如果加1后大于9,则变为0 3. 将一个数字加倍。如果加倍后大于…

Python configparser 模块:优雅处理配置文件的得力工具

更多资料获取 📚 个人网站:ipengtao.com 配置文件在软件开发中扮演着重要的角色,而Python中的 configparser 模块提供了一种优雅而灵活的方式来处理各种配置需求。本文将深入介绍 configparser 模块的各个方面,通过丰富的示例代码…

嵌入式杂记 - MDK的Code, RO-data , RW-data, ZI-data意思

嵌入式杂记 - Keil的Code, RO-data , RW-data, ZI-data意思 MDK中的数据分类MCU中的内部存储分布MDK中数据类型存储Code代码段例子 RO-data 只读数据段例子 RW-data 可读写数据段例子 ZI-data 清零数据段例子 在嵌入式开发中,我们经常都会使用一些IDE,例…

Hadoop学习笔记(HDP)-Part.17 安装Spark2

目录 Part.01 关于HDP Part.02 核心组件原理 Part.03 资源规划 Part.04 基础环境配置 Part.05 Yum源配置 Part.06 安装OracleJDK Part.07 安装MySQL Part.08 部署Ambari集群 Part.09 安装OpenLDAP Part.10 创建集群 Part.11 安装Kerberos Part.12 安装HDFS Part.13 安装Ranger …

Web前端 ---- 【Vue】Vuex的使用(辅助函数、模块化开发)

目录 前言 Vuex是什么 Vuex的配置 安装vuex 配置vuex文件 Vuex核心对象 actions mutations getters state Vuex在vue中的使用 辅助函数 Vuex模块化开发 前言 本文介绍一种新的用于组件传值的插件 —— vuex Vuex是什么 Vuex 是一个专为 Vue.js 应用程序开发的状态…

【ArcGIS Pro微课1000例】0053:基于SQL Server创建与启用地理数据库

之前的文章有讲述基于SQL Server创建企业级地理数据库,本文讲述在SQL Server中创建常规的关心数据库,然后在ArcGIS Pro中将其启用,转换为企业级地理数据库。 1. 在SQL Server中创建数据库** 打开SQL Server 2019,连接到数据库服务器。 展开数据库连接,在数据库上右键→新…

(四)Tiki-taka算法(TTA)求解无人机三维路径规划研究(MATLAB)

一、无人机模型简介: 单个无人机三维路径规划问题及其建模_IT猿手的博客-CSDN博客 参考文献: [1]胡观凯,钟建华,李永正,黎万洪.基于IPSO-GA算法的无人机三维路径规划[J].现代电子技术,2023,46(07):115-120 二、Tiki-taka算法(TTA&#xf…

基于SSH的java记账管理系统

基于SSH的java记账管理系统 一、系统介绍二、功能展示四、其他系统实现五、获取源码 一、系统介绍 项目类型:Java EE项目 项目名称:基于SSH的记账管理系统 项目架构:B/S架构 开发语言:Java语言 前端技术:HTML、CS…

初识优先级队列与堆

1.优先级队列 由前文队列queue可知,队列是一种先进先出(FIFO)的数据结构,但有些情况下,操作的数据可能带有优先级,一般出队列时,可能需要优先级高的元素先出队列,在此情况下,使用队列queue显然不…

git常用命令指南

目录 一、基本命令 1、创建分支 2、切换分支 3、合并分支 4、初始化空git仓库 二、文件操作 1、创建文件 2、添加多个文件 3、查看项目的当前状态 4、修改文件 5、删除文件 6、提交项目 三、实际操作 1、创建目录 2、进入新目录 3、初始化空git仓库 4、创建文…