3.2.微调

微调

​ 对于一些样本数量有限的数据集,如果使用较大的模型,可能很快过拟合,较小的模型可能效果不好。这个问题的一个解决方案是收集更多数据,但其实在很多情况下这是很难做到的。

​ 另一种方法就是迁移学习(transfer learning),将源数据集学到地知识迁移到目标数据集,例如,我们只想识别椅子,只有100把椅子,每把椅子的1000张不同角度的图像,尽管ImageNet数据集中大多数图像与椅子无关,但在次数据集上训练的模型可能会提取更通用的图像特征(可以理解为越底层的layer提取的特征越通用),这有助于识别边缘、纹理、形状和对象组合,也可能有效地识别椅子。

在这里插入图片描述

1. 步骤

​ 微调是迁移学习中的常见技巧,步骤如下:

  1. 在源数据集(例如ImageNet数据集)上预训练神经网络模型,即源模型
  2. 创建一个新的神经网络模型,即目标模型。这将复制源模型上的所有模型设计及其参数(输出层除外)。我们假定这些模型参数包含从源数据集中学到的知识,这些知识也将适用于目标数据集。我们还假设源模型的输出层与源数据集的标签密切相关;因此不在目标模型中使用该层。
  3. 向目标模型添加输出层,其输出数是目标数据集中的类别数。然后随机初始化该层的模型参数。
  4. 在目标数据集(如椅子数据集)上训练目标模型。输出层将从头开始进行训练,而所有其他层的参数将根据源模型的参数进行微调。

在这里插入图片描述

1.1 目标模型的训练:

​ 是一个正在目标数据集上的正常训练任务,但使用更强的正则化(参数变化不大):

  • 更小的学习率
  • 更少的数据迭代

​ 如果源数据集远复杂于目标数据,通常微调效果更好

1.2 重用分类器权重

​ 有些时候源数据集可能也有目标数据中的部分标号,比如ImageNet里可能有椅子这一标签,那么可以使用预训练好的模型分类器中对应标号中对应的向量来做初始化(就直接copy)

1.3 固定一些层

​ 通常而言,神经网络中低层次的特征更加通用,高层次的特征则更跟数据集相关。

​ 那么可以固定底部一些层的参数,不参与更新,这样也能有更强的正则。

2.热狗识别

import os
import torch
import torchvision
from torch import nn
from d2l import torch as d2l
import torch_directmldevice = torch_directml.device()
# @save
d2l.DATA_HUB['hotdog'] = (d2l.DATA_URL + 'hotdog.zip','fba480ffa8aa7e0febbb511d181409f899b9baa5')data_dir = d2l.download_extract('hotdog')train_imgs = torchvision.datasets.ImageFolder(os.path.join(data_dir, 'train'))
test_imgs = torchvision.datasets.ImageFolder(os.path.join(data_dir, 'test'))hotdogs = [train_imgs[i][0] for i in range(8)]
not_hotdogs = [train_imgs[-i - 1][0] for i in range(8)]
d2l.show_images(hotdogs + not_hotdogs, 2, 8, scale=1.4)
d2l.plt.show()# 使用RGB通道的均值和标准差,以标准化每个通道 ,因为预训练的模型做了这个
normalize = torchvision.transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])# 先随机裁剪,并变为224 * 224 的图形,因为预训练模型输入是这个
train_augs = torchvision.transforms.Compose([torchvision.transforms.RandomResizedCrop(224),torchvision.transforms.RandomHorizontalFlip(),torchvision.transforms.ToTensor(),normalize])
# 将图像的高度和宽度都缩放到256像素,然后裁剪中央 224 * 224的区域来作为输入
test_augs = torchvision.transforms.Compose([torchvision.transforms.Resize([256, 256]),torchvision.transforms.CenterCrop(224),torchvision.transforms.ToTensor(),normalize])# 下载模型,pretrained参数已被弃用,使用weights来获取与训练模型
pretrained_net = torchvision.models.resnet18(weights=torchvision.models.ResNet18_Weights.IMAGENET1K_V1)
print(pretrained_net.fc)  # 预训练最后一层为输出层finetune_net = torchvision.models.resnet18(weights=torchvision.models.ResNet18_Weights.IMAGENET1K_V1).to(device)
finetune_net.fc = nn.Linear(finetune_net.fc.in_features, 2).to(device)
nn.init.xavier_uniform_(finetune_net.fc.weight)def train_batch_ch13(net, X, y, loss, trainer, devices):"""用多GPU进行小批量训练"""if isinstance(X, list):# 微调BERT中所需X = [x.to(devices) for x in X]else:X = X.to(devices)y = y.to(devices)net.train()trainer.zero_grad()pred = net(X)l = loss(pred, y)l.sum().backward()trainer.step()train_loss_sum = l.sum()train_acc_sum = d2l.accuracy(pred, y)return train_loss_sum, train_acc_sum# @save 多GPU的,把参数devices改成device了,本来是个列表
def train_ch13(net, train_iter, test_iter, loss, trainer, num_epochs,device):"""用多GPU进行模型训练"""timer, num_batches = d2l.Timer(), len(train_iter)animator = d2l.Animator(xlabel='epoch', xlim=[1, num_epochs], ylim=[0, 1],legend=['train loss', 'train acc', 'test acc'])# net = nn.DataParallel(net, device_ids=devices).to(devices[0])for epoch in range(num_epochs):# 4个维度:储存训练损失,训练准确度,实例数,特点数metric = d2l.Accumulator(4)for i, (features, labels) in enumerate(train_iter):timer.start()l, acc = train_batch_ch13(net, features, labels, loss, trainer, device)metric.add(l, acc, labels.shape[0], labels.numel())timer.stop()if (i + 1) % (num_batches // 5) == 0 or i == num_batches - 1:animator.add(epoch + (i + 1) / num_batches,(metric[0] / metric[2], metric[1] / metric[3],None))test_acc = d2l.evaluate_accuracy_gpu(net, test_iter)animator.add(epoch + 1, (None, None, test_acc))print(f'loss {metric[0] / metric[2]:.3f}, train acc 'f'{metric[1] / metric[3]:.3f}, test acc {test_acc:.3f}')print(f'{metric[2] * num_epochs / timer.sum():.1f} examples/sec on 'f'{str(device)}')# 微调
# 如果param_group=True,输出层中的模型参数将使用十倍的学习率
def train_fine_tuning(net, learning_rate, batch_size=128, num_epochs=5,param_group=True):train_iter = torch.utils.data.DataLoader(torchvision.datasets.ImageFolder(os.path.join(data_dir, 'train'), transform=train_augs),batch_size=batch_size, shuffle=True)test_iter = torch.utils.data.DataLoader(torchvision.datasets.ImageFolder(os.path.join(data_dir, 'test'), transform=test_augs),batch_size=batch_size)# devices = d2l.try_all_gpus()device = torch_directml.device()loss = nn.CrossEntropyLoss(reduction="none")if param_group:params_1x = [param for name, param in net.named_parameters()if name not in ["fc.weight", "fc.bias"]]# 最后一层使用10倍学习率trainer = torch.optim.SGD([{'params': params_1x},{'params': net.fc.parameters(),'lr': learning_rate * 10}],lr=learning_rate, weight_decay=0.001)else:trainer = torch.optim.SGD(net.parameters(), lr=learning_rate,weight_decay=0.001)train_ch13(net, train_iter, test_iter, loss, trainer, num_epochs,device)train_fine_tuning(finetune_net, 5e-5)
d2l.plt.show()

loss 0.270, train acc 0.899, test acc 0.948
232.6 examples/sec on privateuseone:0

一次训练效果就很好了,而且后续训练很平滑,没有过拟合。

在这里插入图片描述

​ 如果初始化为随机值:
在这里插入图片描述

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.mzph.cn/bicheng/51723.shtml

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

Go语言编程 学习笔记整理 第2章 顺序编程 前半部分

前言:《Go语言编程》编著 许式伟 吕桂华 等 1.1 变量 var v1 int var v2 string var v3 [10]int // 数组 var v4 []int // 数组切片 var v5 struct { f int } var v6 *int // 指针 var v7 map[string]int // map,key为string类型,value为in…

【QT】qt 文件操作

qt 文件 qt 文件1. Qt 文件概述2. 输入输出设备类3. 文件读写类4. 文件和目录信息类 qt 文件 1. Qt 文件概述 文件操作是应用程序必不可少的部分。Qt 作为⼀个通用开发库,提供了跨平台的文件操作能力。 Qt 提供了很多关于文件的类,通过这些类能够对文件…

微服务--配置管理

现在依然还有几个问题需要解决: 网关路由在配置文件中写死了,如果变更必须重启微服务 某些业务配置在配置文件中写死了,每次修改都要重启服务 每个微服务都有很多重复的配置,维护成本高 这些问题都可以通过统一的配置管理器服…

DP的优化途径---单调队列

1.前缀和单调队列&#xff1a;https://www.acwing.com/problem/content/137/ 我们先预处理下前缀和&#xff0c;以下标为i的点为有边界&#xff1a; 也就是求()的min&#xff0c;考虑到j的范围是定值&#xff0c;用单调队列维护即可。 AC代码&#xff1a; #include<bits/…

OpenGL3.3_C++_Windows(32)

demo SSAO SSAO 环境光照(Ambient Lighting)&#xff1a;光的散射&#xff0c;我们通过一个固定的常量作为环境光的模拟&#xff0c;但是这种固定的环境光并不能很好模拟散射&#xff0c;因为环境光不是一成不变的&#xff0c;环境光遮蔽&#xff1a;让&#xff08;褶皱、孔洞…

更新至2023年上市公司ESG数据合集(十份数据:华证年度、华证季度、商道融绿、wind、秩鼎、润灵环球、盟浪、富时罗素、上市银行华证ESG)

更新至2023年上市公司ESG数据合集&#xff08;十份数据&#xff1a;华证年度、华证季度、商道融绿、wind、秩鼎、润灵环球、盟浪、富时罗素、上市银行华证ESG&#xff09; 数据名称&#xff1a; 一、2018-2023年上市公司富时罗素ESG评分数据 二、2018-2023年上市公司Wind ES…

深度学习实战笔记3循环神经网络实现

我们要训练一个基于循环神经网络的字符级语言模型&#xff0c;根据用户提供的文本的前缀生成后续文本。 import math import torch from torch import nn from torch.nn import functional as F from d2l import torch as d2l batch_size, num_steps 32, 35 train_iter, voc…

C#插件 调用存储过程(输出参数类型)

存储过程 CREATE PROCEDURE [dbo].[GetSum]num1 INT,num2 INT,result INT OUTPUT AS BEGINselect result num1 num2 END C#代码 using Kingdee.BOS; using Kingdee.BOS.App.Data; using Kingdee.BOS.Core.Bill.PlugIn; using Kingdee.BOS.Util; using System; using System.…

MySQL死锁问题案例

MySQL死锁问题 问题描述&#xff1a;在一张流水生成的记录表中&#xff0c;当没有当前条件的数据时候&#xff0c;并发情况下会导致有线程因为死锁问题生成流水号失败。 场景 有一张生成流水的表&#xff1a; 场景复现&#xff1a; 简单来说&#xff0c;在根据流水类型、年、月…

Python如何快速定位最慢的代码?优雅了~

编写Python代码时&#xff0c;我们常常会遇到性能瓶颈&#xff0c;这不仅影响程序的执行效率&#xff0c;还可能导致用户体验下降。那么&#xff0c;如何快速定位代码中最慢的部分&#xff0c;成为每个开发者必须掌握的技能。 如何快速定位 Python 代码中的性能瓶颈&#xff1…

Url图标实现

Url图标实现 效果如下&#xff1a; 1.引入样式 <link rel"icon" href"favicon.ico"> favicon.ico和对应的html一般需要在同一个目录下&#xff08;同级别&#xff09;。 2.title是用来设置在url页签中显示的名称。 可能存在的问题&#xff1a; …

前端实现文本超出指定行数显示”展开”和”收起”效果

目录 效果演示步骤一&#xff1a;实现整体框架步骤二&#xff1a;实现样式步骤三&#xff1a;js实现元素控制完整代码 效果演示 本文方法是利用js原生进行实现的&#xff0c;可根据相关vue或react语法进行相关的改写&#xff0c;并实现效果 步骤一&#xff1a;实现整体框架 <…

c-periphery RS485串口库文档serial.md(serial.h)(非阻塞读)(VMIN、VTIME)

c-peripheryhttps://github.com/vsergeev/c-periphery 文章目录 NAMESYNOPSISENUMERATIONS关于奇偶校验枚举类型 DESCRIPTIONserial_new()serial_open()关于流控制软件流控制&#xff08;XON/XOFF&#xff09;硬件流控制&#xff08;RTS/CTS&#xff09;选择流控制方法 serial_…

独立3D网络游戏《战域重甲》开发与上架经验分享

“ 小编阿麟&#xff1a;心之所向便是光&#xff0c;我们都是追光者!这位独立游戏开发者的产品能力已经不输给许多小团队&#xff0c;希望他的故事和经验分享&#xff0c;可以给走在同样道路上的朋友一些信心和帮助。 背景介绍 2023年年底的时候&#xff0c;我突然有一个很强的…

硬件工程师笔面试真题汇总

目录 1、电阻 1&#xff09;上拉电阻的作用 2&#xff09;PTC热敏电阻作为电源电路保险丝的工作原理 2、电容 1&#xff09;电容的特性 2) 电容的特性曲线 3) 1uf的电容通常来滤除什么频率的信号 3、电感 4、二极管 1&#xff09;二极管特性 2&#xff09;二极管伏安…

HVV | .NET 攻防工具库,值得您拥有!

01阅读须知 此文所提供的信息只为网络安全人员对自己所负责的网站、服务器等&#xff08;包括但不限于&#xff09;进行检测或维护参考&#xff0c;未经授权请勿利用文章中的技术资料对任何计算机系统进行入侵操作。利用此文所提供的信息而造成的直接或间接后果和损失&#xf…

《破解验证码:用Requests和Selenium实现模拟登录的终极指南》

两种模拟登录方式(图形验证码) 超级鹰 打码平台&#xff0c;用于识别验证码 requests模拟登录 from chaojiying import Chaojiying_Client import requests from requests import Session from lxml import etree #获取图片信息 def get_pic_info(img_name):chaojiying Ch…

10个append()函数在Python程序开发中的创新应用

文末赠免费精品编程资料~~ 在Python编程的世界里&#xff0c;append()函数是列表操作中最常见的方法之一。它允许我们在列表的末尾添加一个元素&#xff0c;这一简单的功能却能激发无限的创造力。今天&#xff0c;我们将探讨append()函数在Python程序开发中的10种创新应用&…

代码随想录第23天|回溯

39.组合总和 题目链接/文章讲解&#xff1a; 代码随想录 视频讲解&#xff1a;带你学透回溯算法-组合总和&#xff08;对应「leetcode」力扣题目&#xff1a;39.组合总和&#xff09;| 回溯法精讲&#xff01;_哔哩哔哩_bilibili 第一想法&#xff1a; 组合总和与第22天组合总…

爬虫实战-掌上高考网实战

1.确定需求&#xff1a;爬取什么数据爬取大学名称 2.找到数据源地址数据在哪个链接中https://api.zjzw.cn/web/api/?keyword&page1&province_id&ranktype&request_type1&size20&top_school_id[3703,2461,659,3117,597,1724]&type&uriapidata/…