(动手学习深度学习)第13章 计算机视觉---微调

文章目录

    • 微调
      • 总结
    • 微调代码实现

微调

在这里插入图片描述
在这里插入图片描述
在这里插入图片描述
在这里插入图片描述
在这里插入图片描述
在这里插入图片描述

总结

  • 微调通过使用在大数据上的恶道的预训练好的模型来初始化模型权重来完成提升精度。
  • 预训练模型质量很重要
  • 微调通常速度更快、精确度更高

微调代码实现

  1. 导入相关库
%matplotlib inline
import os
import torch
import torchvision
from torch import nn
from d2l import torch as d2l
import matplotlib as plt
  1. 获取数据集
d2l.DATA_HUB['hotdog'] = (d2l.DATA_URL + 'hotdog.zip','fba480ffa8aa7e0febbb511d181409f899b9baa5')data_dir = d2l.download_extract('hotdog')
train_imgs = torchvision.datasets.ImageFolder(os.path.join(data_dir,'train'))
test_imgs = torchvision.datasets.ImageFolder(os.path.join(data_dir,'test'))
print(train_imgs)
print(train_imgs[0])
train_imgs[0][0]

在这里插入图片描述
查看数据集中图像的形状

hotdogs = [train_imgs[i][0] for i in range(8)]
not_hotdogs= [train_imgs[-i-1][0] for i in range(8)]
d2l.show_images(hotdogs + not_hotdogs, 2 ,8, scale=1.4)

在这里插入图片描述

  1. 数据增强
# 图像增广
normalize = torchvision.transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224,0.225]
)
train_augs = torchvision.transforms.Compose(  # 训练集数据增强[torchvision.transforms.RandomResizedCrop(224),torchvision.transforms.RandomHorizontalFlip(),torchvision.transforms.ToTensor(),normalize]
)
test_augs = torchvision.transforms.Compose(  # 验证集不做数据增强[torchvision.transforms.Resize(256),torchvision.transforms.CenterCrop(224),torchvision.transforms.ToTensor(),normalize]
)
  1. 定义和初始化模型
# 下载resnet18,
# 老:pretrain=True: 也下载预训练的模型参数
# 新:weights=torchvision.models.ResNet18_Weights.IMAGENET1K_V1
pretrained_net = torchvision.models.resnet18(weights=torchvision.models.ResNet18_Weights.IMAGENET1K_V1)
print(pretrained_net.fc)

在这里插入图片描述

  1. 微调模型
  • (1)直接修改网络层(如最后全连接层:512—>1000,改成512—>2)
  • (2)在增加一层分类层(如:512—>1000, 改成512—>1000, 1000—>2)

本次选择(1):将resnet18最后全连接层的输出,改成自己训练集的类别,并初始化最后全连接层的权重参数

finetune_net = pretrained_net
finetune_net.fc = nn.Linear(finetune_net.fc.in_features, 2)
nn.init.xavier_uniform_(finetune_net.fc.weight)

在这里插入图片描述

print(finetune_net)

在这里插入图片描述

  1. 训练模型
  • 特征提取层(预训练层):使用较小的学习率
  • 输出全连接层(微调层):使用较大的学习率
def train_fine_tuning(net, learning_rate, batch_size=128, num_epochs=10, param_group=True):train_iter = torch.utils.data.DataLoader(torchvision.datasets.ImageFolder(os.path.join(data_dir,'train'), transform=train_augs),batch_size=batch_size,shuffle=True)test_iter = torch.utils.data.DataLoader(torchvision.datasets.ImageFolder(os.path.join(data_dir, 'test'), transform=test_augs),batch_size=batch_size)device = d2l.try_all_gpus()loss = nn.CrossEntropyLoss(reduction='none')if param_group:params_1x = [param for name, param in net.named_parameters()if name not in ['fc.weight', 'fc.bias']]trainer = torch.optim.SGD([{'params': params_1x}, {'params': net.fc.parameters(), 'lr': learning_rate * 10}],lr=learning_rate, weight_decay=0.001)else:trainer = torch.optim.SGD(net.parameters(),lr=learning_rate,weight_decay=0.001)d2l.train_ch13(net, train_iter, test_iter, loss,trainer, num_epochs, device)

训练模型

import time# 在开头设置开始时间
start = time.perf_counter()  # start = time.clock() python3.8之前可以train_fine_tuning(finetune_net, 5e-5, 128, 10)# 在程序运行结束的位置添加结束时间
end = time.perf_counter()  # end = time.clock()  python3.8之前可以# 再将其进行打印,即可显示出程序完成的运行耗时
print(f'运行耗时{(end-start):.4f} s')

在这里插入图片描述

直接训练:整个模型都使用相同的学习率,重新训练

scracth_net = torchvision.models.resnet18()
scracth_net.fc = nn.Linear(scracth_net.fc.in_features, 2)import time# 在开头设置开始时间
start = time.perf_counter()  # start = time.clock() python3.8之前可以train_fine_tuning(scracth_net, 5e-4, param_group=False)# 在程序运行结束的位置添加结束时间
end = time.perf_counter()  # end = time.clock()  python3.8之前可以# 再将其进行打印,即可显示出程序完成的运行耗时
print(f'运行耗时{(end-start):.4f} s')

在这里插入图片描述

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.mzph.cn/news/149433.shtml

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

Midjourney绘画提示词Prompt参考学习教程

一、工具 SparkAi: SparkAi创作系统是基于OpenAI很火的ChatGPT进行开发的Ai智能问答系统和Midjourney绘画系统,支持OpenAI-GPT全模型国内AI全模型。本期针对源码系统整体测试下来非常完美,可以说SparkAi是目前国内一款的ChatGPT对接OpenAI软…

SDL2 播放音频数据(PCM)

1.简介 这里以常用的视频原始数据PCM数据为例,展示音频的播放。 SDL播放音频的流程如下: 初始化音频子系统:SDL_Init()。设置音频参数:SDL_AudioSpec。设置回调函数:SDL_AudioCallback。打开音频设备:SD…

ML-Net:通过深度学习彻底改变多标签分类

一、说明 多标签分类是一项具有挑战性的机器学习任务,其中输入可以同时属于多个类。传统的多标签分类方法通常依赖于将问题转化为一系列二元分类任务或使用集成方法。然而,深度学习的出现开创了多标签分类的新时代,ML-Net 等模型突破了该领域…

02-1解析xpath

我是在edge浏览器中安装的xpath,需要安装的朋友可以参考下面这篇博客最新版edge浏览器中安装xpath插件 一、xpathd的使用 安装lxml pip install lxml ‐i https://pypi.douban.com/simple导入lxml.etree from lxml import etreeetree.parse() 解析本地文件 htm…

vscode的git 工具使用

vscode的git 工具使用 目录概述需求: 设计思路实现思路分析1.git 工具的使用2.提交代码3.查看历史提交代码 参考资料和推荐阅读 Survive by day and develop by night. talk for import biz , show your perfect code,full busy,skip hardness,make a be…

Servlet API 详解

文章目录 前言什么是ServletHttpServletHttpServletRequest1. form表单构造POST请求2. JSON形式表示body部分 HttpServletResponse 前言 前面为大家介绍了如何使用 servlet 写一个简单的网站,前面只是大概了解了如何使用简单的 servlet,而平时网站的逻辑…

Hive Lateral View explode列为空时导致数据异常丢失

一、问题描述 日常工作中我们经常会遇到一些非结构化数据,因此常常会将Lateral View 结合explode使用,达到将非结构化数据转化成结构化数据的目的,但是该方法对应explode的内容是有非null限制的,否则就有可能造成数据缺失。 SE…

typora整理markdown笔记

效果 符号 快捷键 斜体 * * ctrlB(代表同时按) 加粗 ** ** ctrlI 竖线 > 超链接 清除样式 ctrl\ 图片 ![图片描述][图片绝对路径/相对路径] 如何在Typora中插入图像? ➊ 使用Markdown语法 (不推荐,太慢) ➋ 直接拷贝图…

Theory behind GAN

假如要生成一些人脸图,实际上就是想要找到一个分布,从这个分布内sample出来的图片像是人脸,分布之外生成的就不像人脸。而GAN要做的就是找到这个distribution。 在GAN之前用的是Maximum Likelihood Estimation。 Maximum Likelihood Estimat…

同创永益联合红帽打造一站式数字韧性解决方案

随着AI技术的快速兴起,IT技术已成为推动业务持续增长的重要驱动力,这要求企业不断尝试新事物,改变固有流程,加强IT技术与业务的合作,同时提升数字韧性能力,以实现业务目标。10月26日,红帽2023中…

Bert学习笔记(简单入门版)

目 录 一、基础架构 二、输入部分 三、预训练:MLMNSP 3.1 MLM:掩码语言模型 3.1.1 mask模型缺点 3.1.2 mask的概率问题 3.1.3 mask代码实践 3.2 NSP 四、如何微调Bert 五、如何提升BERT下游任务表现 5.1 一般做法 5.2 如何在相同领域数据中进…

学习UI第一天

在工作闲暇之余,自己画的原型图,再次做一次记录,哈哈哈 萌宠领养UI设计原型图 https://modao.cc/proto/lq2KqIVBs48xwylNZlA7OP/sharing?view_moderead_only #萌宠领养-分享 可以点击此链接,进行查看O(∩_∩)O哈哈~

二进制部署k8s集群-过程中的问题总结(接上篇的部署)

1、kube-apiserver部署过程中的问题 kube-apiserver.conf配置文件更改 2、calico的下载地址 curl https://docs.projectcalico.org/v3.20/manifests/calico.yaml -O 这里如果kubernetes的节点服务器为多网卡配置会产生报错 修改calino.yaml配置文件 解决方法: 调…

论文阅读——DiffusionDet

在目标检测上使用扩散模型 前向过程:真实框-->随机框 后向过程:随机框-->真实框 前向过程: 一般一张图片真实框的数目不同,填补到同一的N个框,填补方法可以是重复真实框,填补和图片大小一样的框&a…

Redis 学习

Redis 集群共3种集群模式: 1. 主从模式 (主写从读),写请求分配到主库,完后数据同步到从库,从库主要负责读请求 同步过程: 从库启动向主库发送同步请求,数据传输的形式是RDB文件&am…

C++算法入门练习——树的带权路径长度

现有一棵n个结点的树(结点编号为从0到n-1,根结点为0号结点),每个结点有各自的权值w。 结点的路径长度是指,从根结点到该结点的边数;结点的带权路径长度是指,结点权值乘以结点的路径长度&#x…

C语言-求一个整数储存在内存中的二进制中1的个数

#define _CRT_SECURE_NO_WARNINGS #include<stdio.h>int main() {/*求一个整数储存在内存中的二进制中1的个数*/int number;scanf("%d", &number);int i 0;int count 0;for (i 0; i < 32; i){if (1 ((number >> i) & 1)){count;}}printf(…

CentOS 7搭建Gitlab流程

目录 1、查询docker镜像gitlab-ce 2、拉取镜像 3、查询已下载的镜像 4、新建gitlab文件夹 5、在gitlab文件夹下新建相关文件夹 6、创建运行gitlab的容器 7、查看docker容器 8、根据Linux地址访问gitlab 9、进入docker容器&#xff0c;设置用户名的和密码 10、登录git…

卫生纸标准及鉴别

一、标准分类及含义 &#xff08;1&#xff09;标准分类 ①GB——国家强制标准&#xff08;即最低标准&#xff09; ②GB/T——国家推荐标准 ③QB——轻工行业标准 ④QB/T——轻工行业推荐标准 &#xff08;2&#xff09;含义 ①国家标准是指国家标准化主管机构批准发布的。…

适用于4×4MiMo 4G/5G,支持GNSS和WiFi 6E的车载天线解决方案

德思特Panorama智能天线致力于为用户提供在各类复杂场景中稳定供给5G、WIFI和GNSS信号的卓越性能和支持。随着5G新频段逐渐应用、WIFI 6E频率升级以及多频定位应用的普及&#xff0c;传统的BAT[G]M-7-60[-24-58]系列天线已不再适用于当前多变的环境。 然而&#xff0c;BAT天线的…