[pytorch、学习] - 9.2 微调

参考

9.2 微调

在前面得一些章节中,我们介绍了如何在只有6万张图像的Fashion-MNIST训练数据集上训练模型。我们还描述了学术界当下使用最广泛规模图像数据集ImageNet,它有超过1000万的图像和1000类的物体。然而,我们平常接触到数据集的规模通常在这两者之间。

假设我们想从图像中识别出不同种类的椅子,然后将购买链接推荐给用户。一种可能的方法是先找出100种常用的椅子,为椅子拍摄1000张不同角度的图像,然后在收集到的图像数据集上训练一个分类模型。这个椅子数据集虽然可能比Fashion-MNIST数据集要庞大,但样本仍然不及ImageNet数据集中样本数的十分之一。这可能会导致适用于ImageNet数据集的复杂模型在这个椅子数据集上过拟合。同时,因为数据量有限,但其成本仍热不可忽略。

另一种解决办法是应用迁移学习(transfer learning),将从源数据集学到的知识迁移到目标数据集上。例如,虽然ImageNet数据集的图像大多跟椅子无关,但在该数据集上训练的模型可以抽取较通用的图像特征,从而能够帮助识别边缘、纹理、形状和物体组成等。这些类似的特征对于识别椅子也可能同样有效。

本节我们介绍迁移学习中的一种常用技术: 微调(fine tuning)。如图9.1所示,微调由以下4步构成。

  1. 在源数据集(如ImageNet数据集)上预训练一个神经网络模型,即源模型。
  2. 创建一个新的神经网络模型,即目标模型。它复制了源模型上除了输出层外的所有模型设计及其参数。我们假设这些模型参数包含了源数据集上学习到的知识,且这些知识同样适用于目标数据集。我们还假设源模型的输出层跟源数据集的标签紧密相关,因此在目标模型中不予采用。
  3. 为目标模型添加一个输出大小为目标数据集类别个数的输出层,并随机初始化该层的模型参数。
  4. 在目标数据集(如椅子数据集)上训练目标模型。我们将从头训练输出层,而其余层的参数都是基于源模型的参数微调得到的。

9.2.1 热狗识别

接下来我们来实践一个具体的例子: 热狗识别。我们将基于一个小数据集在ImageNet数据集上训练好的ResNet模型进行微调。该小数据集含有数千张包含热狗和不包含热狗的图像。我们使用微调得到的模型来识别一张图像中是否包含热狗。

首先,导入实验所需要的包或模块。torchvision的models包提供了常用的预训练模型。如果希望获取更多的预训练模型,可以使用pretrained-models.pytorch仓库.

import torch
from torch import nn, optim
from torch.utils.data import Dataset, DataLoader
import torchvision
from torchvision.datasets import ImageFolder
from torchvision import transforms
from torchvision import models
import osimport sys
sys.path.append("..")
import d2lzh_pytorch as d2ldevice = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

9.2.1.1 获取数据集

我们使用的热狗数据集是从网上抓取的,它包含1400张含热狗的正类图像,和同样多包含其他食品的负类图像。各类的1000张图像被用于训练,其余则用于测试。

我们首先将压缩后的数据集下载到路径data_dir之下,然后在该路径将下载好的数据集解压,得到两个文件夹hotdog/trainhotdog/test。这两个文件夹下面均有hotdognot-hotdog两个类别文件夹,每个类别文件夹里面是图像文件。

data_dir = "C:/Users/1/Datasets"
os.listdir(os.path.join(data_dir, 'hotdog'))

我们创建两个ImageFolder实例来分别读取训练数据集和测试数据集中的所有图像文件

train_imgs = ImageFolder(os.path.join(data_dir, 'hotdog/train'))
test_imgs = ImageFolder(os.path.join(data_dir, 'hotdog/test'))

下面画出前8张正类图像和最后8张负类图像。

hotdogs = [train_imgs[i][0] for i in range(8)]
not_hotdogs = [train_imgs[-i- 1][0] for i in range(8)]
d2l.show_images(hotdogs + not_hotdogs, 2, 8, scale=1.4)

在这里插入图片描述
在训练时,我们先从图像中裁剪随机大小和随机宽高比的一块随机区域,然后将该区域缩放为高和宽均为224像素的输入。测试时,我们将图像的高和宽均缩放为256像素,然后从中裁剪出高和宽均为224像素的中心区域作为输入。此外,我们对RGB(红、绿、蓝)三个颜色通道的数值做标准化:每个数值减去通道所有数值的平均值,再除以该通道所有数值的标准差作为输出。

注: 使用pretrained-models仓库时,一定要对图像进行相应的预处理

All pre-trained models expect input images normalized in the same way, i.e. mini-batches of 3-channel RGB images of shape (3 x H x W), where H and W are expected to be at least 224. The images have to be loaded in to a range of [0, 1] and then normalized using mean = [0.485, 0.456, 0.406] and std = [0.229, 0.224, 0.225]

normalize = transforms.Normalize(mean=[0.485, 0.456, 0.406 ], std = [0.229, 0.224, 0.225])
train_augs = transforms.Compose([transforms.RandomResizedCrop(size= 224),transforms.RandomHorizontalFlip(),transforms.ToTensor(),normalize
])
test_augs = transforms.Compose([transforms.Resize(size=256),transforms.CenterCrop(size=224),transforms.ToTensor(),normalize
])

9.2.1.2 定义和初始化模型

我们使用在ImageNet数据集上预训练的ResNet-18作为源模型。这里指定pretrained=True来自动下载并记载预训练的模型参数。在第一次使用时需联网下载模型参数

pretrained_net  = models.resnet18(pretrained=True)

打印源模型的成员变量fc。作为一个全连接层,它将ResNet最终的全局平均池化层输出变成ImageNet数据集上1000类的输出

print(pretrained_net.fc)

在这里插入图片描述
可见此时pretrained_net最后的输出个数等于目标数据集的类别数1000。所以我们应该将最后的fc修改成我们需要输出类别数:

pretrained_net.fc = nn.Linear(512, 2)

此时,pretrained_netfc层就随机初始化了,但是其他层依然保存着预训练得到的参数。由于是在很大的ImageNet数据集上预训练的,所以参数已经足够好,因此一般只需使用较小的学习率来微调这些参数,而fc中的随机参数一般需要更大的学习率从头训练。PyTorch可以方便的对模型的不同部分设置不同的学习参数,我们在下面代码中将fc的学习率设置为已经预训练过的部分的10倍

output_params = list(map(id, pretrained_net.fc.parameters()))
feature_params = filter(lambda p: id(p) not in output_params, pretrained_net.parameters())lr = 0.01
optimizer = optim.SGD([{'params': feature_params},{'params': pretrained_net.fc.parameters(), 'lr': lr * 10}],lr = lr, weight_decay=0.001)

9.2.1.3 微调模型

我们先定义一个使用微调的训练函数train_fine_tuning以便多次调用。

def train_fine_tuning(net, optimizer, batch_size = 128, num_epochs = 15):train_iter = DataLoader(ImageFolder(os.path.join(data_dir, 'hotdog/train'), transform = train_augs), batch_size, shuffle=True)test_iter = DataLoader(ImageFolder(os.path.join(data_dir, 'hotdog/test'), transform=test_augs), batch_size)loss = torch.nn.CrossEntropyLoss()d2l.train(train_iter, test_iter, net, loss, optimizer, device, num_epochs)

根据前面的设置,我们将以10倍的学习率从头训练目标模型的输出层参数。

train_fine_tuning(pretrained_net, optimizer)

在这里插入图片描述
作为对比,我们定义一个相同的模型,但将它的所有模型参数都初始化为随机值。由于整个模型都需要从头训练,我们可以使用较大的学习率。

scratch_net = models.resnet18(pretrained=False, num_classes=2)
lr = 0.1
optimizer  = optim.SGD(scratch_net.parameters(), lr = lr, weight_decay = 0.001)
train_fine_tuning(scratch_net, optimizer)

在这里插入图片描述

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.mzph.cn/news/250110.shtml

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

关于mac机抓包的几点基础知识

1. 我使用的抓包工具为WireShark,以下操作按我当前的版本(Version 2.6.1)做的,以前的版本或者以后的版本可能有稍微的区别。 2. 将mac设置为热点:打开系统偏好设置,点击共享: 然后点击WIFI选项,设置WIFI名…

SpringBoot启动如何加载application.yml配置文件

一、前言 在spring时代配置文件的加载都是通过web.xml配置加载的(Servlet3.0之前)&#xff0c;可能配置方式有所不同&#xff0c;但是大多数都是通过指定路径的文件名的形式去告诉spring该加载哪个文件&#xff1b; <context-param><param-name>contextConfigLocat…

阿里云服务器端口开放对外访问权限

登陆阿里云管理控制台 点击自己的实例 点击安全组配置 点击配置规则 点击添加安全组规则 配置出入放心&#xff0c;和开放的端口号&#xff0c;以及那些网段可以访问&#xff0c;这里设置所有网段都可以访问 转自&#xff1a;https://jingyan.baidu.com/article/95c9d20d624d1e…

PageHelper工作原理

数据分页功能是我们软件系统中必备的功能&#xff0c;在持久层使用mybatis的情况下&#xff0c;pageHelper来实现后台分页则是我们常用的一个选择&#xff0c;所以本文专门类介绍下。 PageHelper原理 相关依赖 <dependency><groupId>org.mybatis</groupId>&…

10-多写一个@Autowired导致程序崩了

再是javaweb实验六中&#xff0c;是让我们改代码&#xff0c;让它跑起来&#xff0c;结果我少注释了一个&#xff0c;导致一直报错&#xff0c;检查许久没有找到&#xff0c;最后通过代码替换逐步查找&#xff0c;才发现问题。 转载于:https://www.cnblogs.com/zhumengdexiaoba…

springboot---整合redis

pom.xml新增 <dependency><groupId>org.springframework.boot</groupId><artifactId>spring-boot-starter-data-redis</artifactId></dependency>代码结构如下 其中redis.yml是连接redis的配置文件&#xff0c;RedisConfig.java是java配置…

[Head First Java] - Swing做一个简单的客户端

参考 - P487 1. vscode配置java的格式 点击左下角齿轮 -> 设置 -> 打开任意的setting.json输入如下代码 {code-runner.executorMap": {"java": "cd $dir && javac -encoding utf-8 $fileName && java $fileNameWithoutExt"},…

计算机网络知识总结

一 OSI与TCP/IP各层的结构与功能&#xff0c;都有哪些协议 OSI的七层体系结构概念清楚&#xff0c;理论也很完整&#xff0c;但是它比较复杂而且不实用。在这里顺带提一下之前一直被一些大公司甚至一些国家政府支持的OSI失败的原因&#xff1a; OSI的专家缺乏实际经验&#xff…

[Head First Java] - 给线程命名

参考 - P503 public class RunThreads implements Runnable {public static void main (String[] args) {RunThreads runner new RunThreads();Thread alpha new Thread(runner);Thread beta new Thread(runner);alpha.setName("Alpha thread");beta.setName(&qu…

快速排序的C++版

int Partition(int a[], int low, int high) {int x a[high];//将输入数组的最后一个数作为主元&#xff0c;用它来对数组进行划分int i low - 1;//i是最后一个小于主元的数的下标for (int j low; j < high; j)//遍历下标由low到high-1的数{if (a[j] < x)//如果数小于…

asp.net中提交表单数据时提示从客户端(。。。)中检测到有潜在危险的 Request.Form 值...

看到这个图是不是很亲切熟悉哈&#xff0c;做过。net的肯定都见过哈 已经 将近4年没碰。net了&#xff0c;今天正好朋友的程序有几个bug,让我帮忙修复下&#xff0c;于是我就抱着试试看的心情改了改&#xff0c;改到最后一个问题的时候也就是上面的这个问题&#xff0c;我一看&…

Shiro表结构设计

表设计 开发用户-角色-权限管理系统&#xff0c;首先我们需要知道用户-角色-权限管理系统的表结构设计。 在用户-角色-权限管理系统找那个一般会涉及5张表&#xff0c;分别为&#xff1a; 1.sys_users用户表 2.sys_roles角色表 3.sys_permissions权限表&#xff08;或资源表&…

[Java核心技术(卷I)] - 简易的日历

参考 - P102~P103 1. 目标 生成一个日历,格式如下图所示。 ps: 当前的天数需要标记为* 2. 核心 对日历的变量 import java.time.*; public class CalendarTest{public static void main(String[] args) {LocalDate date LocalDate.now(); // 获取当前日期int month date…

个人作业——福大微信公众号使用评测

案例分析&#xff1a;在福州大学公众号上&#xff0c;我们可以即时使用手机关注福大新闻&#xff0c;查看自身课表、成绩等。公众号可能存在一些小bug影响同学们的用户体验。本次作业中&#xff0c;作为一个用户——福大的学生&#xff0c;将切身体验该公众号的功能&#xff0c…

在winform中使用wpf窗体

在winform项目&#xff0c;通过引用dll可以添加WPF窗体&#xff0c;如下 但是如果直接在winform的项目中添加wpf窗体还是有部分问题&#xff0c;图片的显示。 直接在XAML界面中用Source属性设置图片会出现错误。必须通过后台代码的方式来实现。 image1.Source GetImageIcon(gl…

shiro---注解

RequiresAuthentication 验证用户是否登录&#xff0c;等同于方法subject.isAuthenticated() 结果为true时。 RequiresUser 验证用户是否被记忆&#xff0c;user有两种含义&#xff1a; 一种是成功登录的&#xff08;subject.isAuthenticated() 结果为true&#xff09;&…

[Java核心技术(卷I)] - Java中的参数能做什么和不能做什么

1. 参考 - P123 ~ P126 2. 你将学到 Java中对方法参数能做什么和不能做什么 方法不能修改基本数据类型的参数(数值型或布尔型)方法可以改变对象参数的状态方法不能让一个对象参数引用一个新的对象 3. 代码证明 public class ParamTest {public static void main(String[] ar…

软件构造 第五章第一节 可复用性的度量、形态和外部观察

第五章第一节 可复用性的度量、形态和外部观察 面向复用编程(programming for reuse)&#xff1a;开发出可复用的软件 基于复用编程(programming with reuse)&#xff1a;利用已有的可复用软件搭建应用系统 代码复用的类型&#xff1a; 白盒复用&#xff1a;源代码可见&#x…

[web性能优化] - 使用在线工具对html、js、css进行压缩

参考 1. 学习点 使用 在线工具对html、css、js进行压缩学会分析压缩前后的效率提高点 2. 解决方案: 2.1 HTML压缩 在线压缩nodejs提供了 html-minifier工具(在构建层对代码进行压缩)后端模板引擎渲染压缩 2.2 CSS压缩 使用html-minifier对html中的css进行压缩使用clean-cs…

SpringBoot之基础

简介 背景 J2EE笨重的开发 / 繁多的配置 / 低下的开发效率 / 复杂的部署流程 / 第三方技术集成难度大 特点 ① 快速创建独立运行的spring项目以及主流框架集成 ② 使用嵌入式的Servlet容器, 应用无需达成war包 ③ starters自动依赖和版本控制 ④ 大量自动配置, 简化开发, 也可修…