微调Fine tune

网络架构
一个神经网络一般可以分为两块

  • 特征抽取将原始像素变成容易线性分割的特征
  • 线性分类器来做分类
    在这里插入图片描述

微调:使用之前已经训练好的特征抽取模块来直接使用到现有模型上,而对于线性分类器由于标号可能发生改变而不能直接使用

训练
是一个目标数据集上的正常训练任务,但使用更强的正则化

  • 使用更小的学习率
  • 使用更小的数据迭代
    源数据集远复杂于目标数据,通常微调效果会更好

重用分类器权重

  • 源数据集可能也有目标数据中的部分标号
  • 可以使用预训练好模型分类器中对应标号对应的向量来做初始化

固定一些层
神经网络通常学习有层次的特征表示

  • 低层次的特征更加通用
  • 高层次的特征则更和数据集相关
    可以固定底部一些层的参数,不参与更新来减小模型的复杂度
  • 更强的正则

微调通过使用在大数据上得到的预训练好的模型来初始化模型权重来完成提升精度
预训练模型质量很重要
微调通常速度更快,精度更好

就是重用在大数据集上训练好的模型的特征提取模块,用来做自己模型的特征提取的初始化,用来使得相比于随机初始化有更好的效果

1. 实现

%matplotlib inline
import os
import torch
import torchvision
from torch import nn
from d2l import torch as d2l
# 热狗数据集来源于网络d2l.DATA_HUB['hotdog'] = (d2l.DATA_URL + 'hotdog.zip','fba480ffa8aa7e0febbb511d181409f899b9baa5')data_dir = d2l.download_extract('hotdog')train_imgs = torchvision.datasets.ImageFolder(os.path.join(data_dir, 'train'))
test_imgs = torchvision.datasets.ImageFolder(os.path.join(data_dir, 'test'))
# 图像的大小和纵横比各有不同hotdogs = [train_imgs[i][0] for i in range(8)]
not_hotdogs = [train_imgs[-i - 1][0] for i in range(8)]
d2l.show_images(hotdogs + not_hotdogs, 2, 8, scale=1.4)
# 数据增广normalize = torchvision.transforms.Normalize([0.485, 0.456, 0.406], # 因为要使用imageNet上的特征提取模块,所以要对数据先进行归一化【方差,均值】[0.229, 0.224, 0.225])train_augs = torchvision.transforms.Compose([torchvision.transforms.RandomResizedCrop(224),torchvision.transforms.RandomHorizontalFlip(),torchvision.transforms.ToTensor(), normalize])test_augs = torchvision.transforms.Compose([torchvision.transforms.Resize(256),torchvision.transforms.CenterCrop(224),torchvision.transforms.ToTensor(), normalize])
# 定义和初始化模型pretrained_net = torchvision.models.resnet18(pretrained=True)pretrained_net.fc # 输出Linear(in_features=512, out_features=1000, bias=True)finetune_net = torchvision.models.resnet18(pretrained=True)
finetune_net.fc = nn.Linear(finetune_net.fc.in_features, 2) # 只对最后一层的类别改变
nn.init.xavier_uniform_(finetune_net.fc.weight)
# 微调模型def train_fine_tuning(net, learning_rate, batch_size=128, num_epochs=5,param_group=True):train_iter = torch.utils.data.DataLoader(torchvision.datasets.ImageFolder(os.path.join(data_dir, 'train'),transform=train_augs),batch_size=batch_size, shuffle=True)test_iter = torch.utils.data.DataLoader(torchvision.datasets.ImageFolder(os.path.join(data_dir, 'test'),transform=test_augs),batch_size=batch_size)devices = d2l.try_all_gpus()loss = nn.CrossEntropyLoss(reduction="none")if param_group:params_1x = [param for name, param in net.named_parameters()if name not in ["fc.weight", "fc.bias"]]trainer = torch.optim.SGD([{'params': params_1x}, {'params': net.fc.parameters(), # 最后一层分类器的学习率提高十倍,为了使其能够更快的学习'lr': learning_rate * 10}], lr=learning_rate,weight_decay=0.001)else:trainer = torch.optim.SGD(net.parameters(), lr=learning_rate,weight_decay=0.001)d2l.train_ch13(net, train_iter, test_iter, loss, trainer, num_epochs,devices)
# 使用较小的学习率train_fine_tuning(finetune_net, 5e-5)# 为了进行比较, 所有模型参数初始化为随机值 -》结果没有之前微调的效果好
scratch_net = torchvision.models.resnet18()
scratch_net.fc = nn.Linear(scratch_net.fc.in_features, 2)
train_fine_tuning(scratch_net, 5e-4, param_group=False)

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.mzph.cn/news/187719.shtml

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

道可云元宇宙每日资讯|智慧旅游发展大会暨智慧旅游示范展示活动在南京举办

道可云元宇宙每日简报(2023年11月28日)讯,今日元宇宙新鲜事有: 智慧旅游发展大会暨智慧旅游示范展示活动在南京举办 2023年11月23日至25日,由文化和旅游部资源开发司、江苏省文化和旅游厅共同主办的“智慧旅游发展大会…

Linux驱动开发——网络设备驱动(实战篇)

目录 四、 网络设备驱动实例 五、DM9000 网络设备驱动代码分析 六、NAPI 七、习题 书接上回: Linux驱动开发——网络设备驱动(理论篇)-CSDN博客 (没看过上面博客的同学,skb是linux对于网络套接字缓冲区的一个虚拟…

Leetcode 136. 只出现一次的数字

class Solution {//任何数与0异或结果都是原来的数//任何数和自身异或结果都是0//异或满足交换律和结合律//a ^ b ^ a (a ^ a) ^ b 0 ^ b bpublic int singleNumber(int[] nums) {int res nums[0];for(int i 1; i < nums.length; i){res ^ nums[i];}return res;} }

OpenCvSharp从入门到实践-(04)色彩空间

目录 1、GRAY色彩空间 2、从BGR色彩空间转换到GRAY色彩空间 2.1色彩空间转换码 2.2实例 BGR色彩空间转换到GRAY色彩空间 3、HSV色彩空间 4、从BGR色彩空间转换到HSV色彩空间 4.1色彩空间转换码 4.2实例 BGR色彩空间转换到HSV色彩空间 1、GRAY色彩空间 GRAY色彩空间通常…

26、Spring是如何解决Bean的循环依赖?

Spring是如何解决Bean的循环依赖&#xff1f; 采用三级缓存解决的 就是三个Map &#xff1b; 关键&#xff1a; 一定要有一个缓存保存它的早期对象作为死循环的出口 一级缓存&#xff1a;存储完整的Bean二级缓存&#xff1a; 避免多重循环依赖的情况 重复创建动态代理。三级缓…

Spring简单的存储和读取

前言 前面讲了spring的创建&#xff0c;现在说说关于Bean和五大类注解 一、Bean是什么&#xff1f; 在 Java 语⾔中对象也叫做 Bean&#xff0c;所以后⾯咱们再遇到对象就以 Bean 著称。这篇文章还是以spring创建为主。 二、存储对象 2.1 俩种存储方式 需要在 spring-conf…

FlinkSql-Temporal Joins-Lookup Join

说明 在 Flink SQL 中&#xff0c;Temporal Joins 是一种常见的数据关联操作&#xff0c;特别适用于处理包含时间维度的数据。Lookup Join 是 Temporal Joins 的一种类型&#xff0c;它允许将流数据与维表数据进行关联。使用场景如下&#xff1a; 实时维度关联&#xff1a; 当…

Python---文件备份案例

需求&#xff1a;用户输入当前目录下任意文件名&#xff0c;完成对该文件的备份功能(备份文件名为xx[备份]后缀&#xff0c;例如&#xff1a;test[备份].txt)。 思考&#xff1a; ① 接收用户输入的文件名 ② 规划备份文件名 ③ 备份文件写入数据 代码 # 1、接收用户输入的…

paddle detection整体结构

核心思想就是通过Yaml文件将主体模块和可拔插的模块组成一个完整的pipline. train.py流程解析&#xff1a; 初始化训练参数 1 parserArgsParser() #读取命令行传递参数&#xff0c;加载yaml文件参数 2 整合参数&#xff0c;检查参数配置是否正确 3 检查是否使用GPU加速 4 检查…

Ubuntu 18.04 ARM离线安装cifs-utils

1、环境说明 由于本地都是x86&#xff0c;不支持arm架构&#xff0c;所以用Docker容器下载离线包本地环境&#xff1a;Docker、Ubuntu 22.04.1 LTS x86&#xff08;可上网&#xff09;安装环境&#xff1a;Ubuntu 18.04.4 LTS arm&#xff08;内网&#xff09; 2、启动qemu-a…

使用Jmeter进行http接口性能测试

在进行网页或应用程序后台接口开发时&#xff0c;一般要及时测试开发的接口能否正确接收和返回数据&#xff0c;对于单次测试&#xff0c;Postman插件是个不错的Http请求模拟工具。 但是Postman只能模拟单客户端的单次请求&#xff0c;而对于模拟多用户并发等性能测试&#xf…

[Verilog语法]:===和!==运算符使用注意事项

[Verilog语法]&#xff1a;和!运算符使用注意事项 1&#xff0c; 和 !运算符使用注意事项2&#xff0c;3&#xff0c; 1&#xff0c; 和 !运算符使用注意事项 参考文献&#xff1a; 1&#xff0c;[SystemVerilog语法拾遗] 和!运算符使用注意事项 2&#xff0c; 3&#xff0c;

机器学习入门(第五天)——决策树(每次选一边)

Decision tree 知识树 Knowledge tree 一个小故事 A story 挑苹果&#xff1a; 根据这些特征&#xff0c;如颜色是否是红色、硬度是否是硬、香味是否是香&#xff0c;如果全部满足绝对是好苹果&#xff0c;或者红色硬但是无味也是好苹果&#xff0c;从上图可以看出来&#…

数据可视化:用图表和图形展示数据

写在开头 在当今信息爆炸的时代,海量的数据如同一座沉默的宝库,等待着我们挖掘和理解。然而,这些庞大的数据集本身可能令人望而生畏。在这个时候,数据可视化成为了解数据、发现模式和传达信息的强大工具。本篇博客将带领你探索数据可视化的奇妙世界,学习如何在python中使…

91基于matlab的以GUI实现指纹的识别和匹配百分比

基于matlab的以GUI实现指纹的识别和匹配百分比,中间有对指纹的二值化&#xff0c;M连接&#xff0c;特征提取等处理功能。数据可更换自己的&#xff0c;程序已调通&#xff0c;可直接运行。 91M连接 特征提取 (xiaohongshu.com)

Windows 安装redis,设置开机自启动

Windows 安装redis,设置开机自启动 文章目录 Windows 安装redis,设置开机自启动下载, 解压到指定目录设置redis密码启动redis服务端停止redis服务端设置自启动 下载, 解压到指定目录 官网地址: https://redis.io/ 安装包下载地址: https://github.com/tporadowski/redis/relea…

NB-IoT BC260Y Open CPU SDK⑥ADC的应用

NB-IoT BC260Y Open CPU SDK⑥ADC的应用 1、BC260Y_CN_AA模块 ADC的介绍2、ADC相关API的介绍3、软件设计4、实例分析5、以下是调试的结果:1、BC260Y_CN_AA模块 ADC的介绍 BC260Y-CN QuecOpen 模块提供 2 个专用于 ADC(ADC0、ADC1)功能的 I/O 引脚。通过相应的 API函数可以直…

掌握Vue侦听器(watch)的应用

文章目录 &#x1f341;watch 的优缺点&#x1f342;Watch 优点&#x1f342;Watch 缺点 &#x1f341;watch 的用法&#x1f342;对象式 watch&#x1f342;函数式 watch &#x1f341;代码示例&#x1f342;监听基本数据类型&#x1f342;监听复杂数据类型&#xff08;Object…

GPLT(有空就写)

L2 - 047 锦标赛 思路&#xff1a; 将其放入一颗满二叉树上去考虑&#xff1a;从二叉树的最底层开始&#xff0c;每一轮比赛&#xff0c;为同一个祖先的左右两个儿子进行比较&#xff0c;而你需要将败者的能力值填到左右两个儿子其中一个上面&#xff0c;另一个就向上传递表示胜…

Day51:503.下一个更大元素II、42. 接雨水

文章目录 503.下一个更大元素II思路代码实现 42. 接雨水思路代码实现 503.下一个更大元素II 题目链接 思路 这道题和下一个更大元素 I的不同之处在于这个查找是循环的。 循环直接可以用查找两次来解决&#xff0c;所以题目步骤唯一不同的就是循环的终止位置。 for(int i1;i…