计算机视觉教程的量化转移学习

(实验性)计算机视觉教程的量化转移学习

本教程以 Sasank Chilamkurthy 编写的原始 PyTorch 转移学习教程为基础。

转移学习是指利用预训练的模型应用于不同数据集的技术。 使用转移学习的主要方式有两种:

  1. 作为固定特征提取器的 ConvNet :在这里,您“冻结” 网络中所有参数的权重,但最后几层(又称“头部”)的权重通常 连接的图层)。 将这些最后一层替换为使用随机权重初始化的新层,并且仅训练这些层。
  2. 对 ConvNet 进行微调:使用随机训练的网络初始化模型,而不是随机初始化,然后像往常一样进行训练,但使用另一个数据集。 通常,如果输出数量不同,则在网络中也会更换磁头(或磁头的一部分)。 这种方法通常将学习率设置为较小的值。 这样做是因为已经对网络进行了训练,并且只需进行较小的更改即可将其“微调”到新的数据集。

您还可以结合以上两种方法:首先,可以冻结特征提取器,并训练头部。 之后,您可以解冻特征提取器(或其一部分),将学习率设置为较小的值,然后继续进行训练。

在本部分中,您将使用第一种方法-使用量化模型提取特征。

第 0 部分。先决条件

在深入学习迁移学习之前,让我们回顾一下“先决条件”,例如安装和数据加载/可视化。

# Imports
import copy
import matplotlib.pyplot as plt
import numpy as np
import os
import timeplt.ion()

安装每夜构建

因为您将使用 PyTorch 的实验部分,所以建议安装最新版本的torchtorchvision。 您可以在中找到有关本地安装的最新说明。 例如,要在没有 GPU 支持的情况下进行安装:

pip install numpy
pip install --pre torch torchvision -f https://download.pytorch.org/whl/nightly/cpu/torch_nightly.html
# For CUDA support use https://download.pytorch.org/whl/nightly/cu101/torch_nightly.html

载入资料

注意

本部分与原始的迁移学习教程相同。

我们将使用torchvisiontorch.utils.data包加载数据。

您今天要解决的问题是从图像中对蚂蚁蜜蜂进行分类。 该数据集包含约 120 张针对蚂蚁和蜜蜂的训练图像。 每个类别有 75 个验证图像。 可以认为这是一个很小的数据集。 但是,由于我们正在使用迁移学习,因此我们应该能够很好地概括。

此数据集是 imagenet 的很小子集。

Note

从此处下载数据,并将其提取到data目录。

import torch
from torchvision import transforms, datasets# Data augmentation and normalization for training
# Just normalization for validation
data_transforms = {'train': transforms.Compose([transforms.Resize(224),transforms.RandomCrop(224),transforms.RandomHorizontalFlip(),transforms.ToTensor(),transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])]),'val': transforms.Compose([transforms.Resize(224),transforms.CenterCrop(224),transforms.ToTensor(),transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])]),
}data_dir = 'data/hymenoptera_data'
image_datasets = {x: datasets.ImageFolder(os.path.join(data_dir, x),data_transforms[x])for x in ['train', 'val']}
dataloaders = {x: torch.utils.data.DataLoader(image_datasets[x], batch_size=16,shuffle=True, num_workers=8)for x in ['train', 'val']}
dataset_sizes = {x: len(image_datasets[x]) for x in ['train', 'val']}
class_names = image_datasets['train'].classesdevice = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

可视化一些图像

让我们可视化一些训练图像,以了解数据扩充。

import torchvisiondef imshow(inp, title=None, ax=None, figsize=(5, 5)):"""Imshow for Tensor."""inp = inp.numpy().transpose((1, 2, 0))mean = np.array([0.485, 0.456, 0.406])std = np.array([0.229, 0.224, 0.225])inp = std * inp + meaninp = np.clip(inp, 0, 1)if ax is None:fig, ax = plt.subplots(1, figsize=figsize)ax.imshow(inp)ax.set_xticks([])ax.set_yticks([])if title is not None:ax.set_title(title)# Get a batch of training data
inputs, classes = next(iter(dataloaders['train']))# Make a grid from batch
out = torchvision.utils.make_grid(inputs, nrow=4)fig, ax = plt.subplots(1, figsize=(10, 10))
imshow(out, title=[class_names[x] for x in classes], ax=ax)

模型训练的支持功能

以下是模型训练的通用功能。 此功能也

  • 安排学习率
  • 保存最佳模型
def train_model(model, criterion, optimizer, scheduler, num_epochs=25, device='cpu'):"""Support function for model training.Args:model: Model to be trainedcriterion: Optimization criterion (loss)optimizer: Optimizer to use for trainingscheduler: Instance of ``torch.optim.lr_scheduler``num_epochs: Number of epochsdevice: Device to run the training on. Must be 'cpu' or 'cuda'"""since = time.time()best_model_wts = copy.deepcopy(model.state_dict())best_acc = 0.0for epoch in range(num_epochs):print('Epoch {}/{}'.format(epoch, num_epochs - 1))print('-' * 10)# Each epoch has a training and validation phasefor phase in ['train', 'val']:if phase == 'train':model.train()  # Set model to training modeelse:model.eval()   # Set model to evaluate moderunning_loss = 0.0running_corrects = 0# Iterate over data.for inputs, labels in dataloaders[phase]:inputs = inputs.to(device)labels = labels.to(device)# zero the parameter gradientsoptimizer.zero_grad()# forward# track history if only in trainwith torch.set_grad_enabled(phase == 'train'):outputs = model(inputs)_, preds = torch.max(outputs, 1)loss = criterion(outputs, labels)# backward + optimize only if in training phaseif phase == 'train':loss.backward()optimizer.step()# statisticsrunning_loss += loss.item() * inputs.size(0)running_corrects += torch.sum(preds == labels.data)if phase == 'train':scheduler.step()epoch_loss = running_loss / dataset_sizes[phase]epoch_acc = running_corrects.double() / dataset_sizes[phase]print('{} Loss: {:.4f} Acc: {:.4f}'.format(phase, epoch_loss, epoch_acc))# deep copy the modelif phase == 'val' and epoch_acc > best_acc:best_acc = epoch_accbest_model_wts = copy.deepcopy(model.state_dict())print()time_elapsed = time.time() - sinceprint('Training complete in {:.0f}m {:.0f}s'.format(time_elapsed // 60, time_elapsed % 60))print('Best val Acc: {:4f}'.format(best_acc))# load best model weightsmodel.load_state_dict(best_model_wts)return model

可视化模型预测的支持功能

通用功能可显示一些图像的预测

def visualize_model(model, rows=3, cols=3):was_training = model.trainingmodel.eval()current_row = current_col = 0fig, ax = plt.subplots(rows, cols, figsize=(cols*2, rows*2))with torch.no_grad():for idx, (imgs, lbls) in enumerate(dataloaders['val']):imgs = imgs.cpu()lbls = lbls.cpu()outputs = model(imgs)_, preds = torch.max(outputs, 1)for jdx in range(imgs.size()[0]):imshow(imgs.data[jdx], ax=ax[current_row, current_col])ax[current_row, current_col].axis('off')ax[current_row, current_col].set_title('predicted: {}'.format(class_names[preds[jdx]]))current_col += 1if current_col >= cols:current_row += 1current_col = 0if current_row >= rows:model.train(mode=was_training)returnmodel.train(mode=was_training)

第 1 部分。训练基于量化特征提取器的自定义分类器

在本部分中,您将使用“冻结”量化特征提取器,并在其顶部训练自定义分类器头。 与浮点模型不同,您不需要为量化模型设置 require_grad = False,因为它没有可训练的参数。 请参阅文档了解更多详细信息。

加载预训练的模型:在本练习中,您将使用 ResNet-18 。

import torchvision.models.quantization as models# You will need the number of filters in the `fc` for future use.
# Here the size of each output sample is set to 2.
# Alternatively, it can be generalized to nn.Linear(num_ftrs, len(class_names)).
model_fe = models.resnet18(pretrained=True, progress=True, quantize=True)
num_ftrs = model_fe.fc.in_features

此时,您需要修改预训练模型。 该模型在开始和结束时都有量化/去量化块。 但是,由于只使用要素提取器,因此反量化层必须在线性层(头部)之前移动。 最简单的方法是将模型包装在nn.Sequential模块中。

第一步是在 ResNet 模型中隔离特征提取器。 尽管在本示例中,您被责成使用fc以外的所有图层作为特征提取器,但实际上,您可以根据需要选择任意数量的零件。 如果您也想替换一些卷积层,这将很有用。

Note

将特征提取器与量化模型的其余部分分开时,必须手动将量化器/去量化器放置在要保持量化的部分的开头和结尾。

下面的函数创建一个带有自定义头部的模型。

from torch import nndef create_combined_model(model_fe):# Step 1\. Isolate the feature extractor.model_fe_features = nn.Sequential(model_fe.quant,  # Quantize the inputmodel_fe.conv1,model_fe.bn1,model_fe.relu,model_fe.maxpool,model_fe.layer1,model_fe.layer2,model_fe.layer3,model_fe.layer4,model_fe.avgpool,model_fe.dequant,  # Dequantize the output)# Step 2\. Create a new "head"new_head = nn.Sequential(nn.Dropout(p=0.5),nn.Linear(num_ftrs, 2),)# Step 3\. Combine, and don't forget the quant stubs.new_model = nn.Sequential(model_fe_features,nn.Flatten(1),new_head,)return new_model

警告

当前,量化模型只能在 CPU 上运行。 但是,可以将模型的未量化部分发送到 GPU。

import torch.optim as optim
new_model = create_combined_model(model_fe)
new_model = new_model.to('cpu')criterion = nn.CrossEntropyLoss()# Note that we are only training the head.
optimizer_ft = optim.SGD(new_model.parameters(), lr=0.01, momentum=0.9)# Decay LR by a factor of 0.1 every 7 epochs
exp_lr_scheduler = optim.lr_scheduler.StepLR(optimizer_ft, step_size=7, gamma=0.1)

训练和评估

此步骤在 CPU 上大约需要 15-25 分钟。 由于量化模型只能在 CPU 上运行,因此您不能在 GPU 上运行训练。

new_model = train_model(new_model, criterion, optimizer_ft, exp_lr_scheduler,num_epochs=25, device='cpu')visualize_model(new_model)
plt.tight_layout()

第 2 部分。微调可量化模型

在这一部分中,我们将微调用于迁移学习的特征提取器,并对特征提取器进行量化。 请注意,在第 1 部分和第 2 部分中,特征提取器都是量化的。 不同之处在于,在第 1 部分中,我们使用了预训练的量化模型。 在这一部分中,我们将在对感兴趣的数据集进行微调之后创建一个量化的特征提取器,因此这是一种在具有量化优势的同时通过转移学习获得更好的准确性的方法。 请注意,在我们的特定示例中,训练集非常小(120 张图像),因此微调整个模型的好处并不明显。 但是,此处显示的过程将提高使用较大数据集进行传递学习的准确性。

预训练特征提取器必须是可量化的。 为确保其可量化,请执行以下步骤:

  1. 使用torch.quantization.fuse_modules熔断(Conv, BN, ReLU)(Conv, BN)(Conv, ReLU)
  2. 将特征提取器与自定义顶端连接。这需要对特征提取器的输出进行反量化。
  3. 在特征提取器的适当位置插入伪量化模块,以模拟训练期间的量化。

对于步骤(1),我们使用torchvision/models/quantization中的模型,这些模型具有成员方法fuse_model。 此功能将所有convbnrelu模块融合在一起。 对于定制模型,这将需要使用模块列表调用torch.quantization.fuse_modules API 进行手动融合。

步骤(2)由上一节中使用的create_combined_model功能执行。

步骤(3)通过使用torch.quantization.prepare_qat来实现,它会插入伪量化模块。

在步骤(4)中,您可以开始“微调”模型,然后将其转换为完全量化的版本(步骤 5)。

要将微调模型转换为量化模型,可以调用torch.quantization.convert函数(在本例中,仅对特征提取器进行量化)。

Note

由于随机初始化,您的结果可能与本教程中显示的结果不同。

#注意 <cite>quantize = False</cite> model = models.resnet18(pretrained = True,progress = True,quantize = False)num_ftrs = model.fc.in_features

#步骤 1 model.train()model.fuse_model()#步骤 2 model_ft = create_combined_model(model)model_ft [0] .qconfig = torch.quantization.default_qat_qconfig#使用默认 QAT 配置#步骤 3 model_ft = torch.quantization.prepare_qat (model_ft,inplace = True)

优化模型

在当前教程中,整个模型都经过了微调。 通常,这将导致更高的精度。 但是,由于此处使用的训练集很小,最终导致我们过度适应了训练集。

步骤 4.微调模型

for param in model_ft.parameters():param.requires_grad = Truemodel_ft.to(device)  # We can fine-tune on GPU if availablecriterion = nn.CrossEntropyLoss()# Note that we are training everything, so the learning rate is lower
# Notice the smaller learning rate
optimizer_ft = optim.SGD(model_ft.parameters(), lr=1e-3, momentum=0.9, weight_decay=0.1)# Decay LR by a factor of 0.3 every several epochs
exp_lr_scheduler = optim.lr_scheduler.StepLR(optimizer_ft, step_size=5, gamma=0.3)model_ft_tuned = train_model(model_ft, criterion, optimizer_ft, exp_lr_scheduler,num_epochs=25, device=device)

步骤 5.转换为量化模型

from torch.quantization import convert
model_ft_tuned.cpu()model_quantized_and_trained = convert(model_ft_tuned, inplace=False)

让我们看看量化模型在几张图像上的表现

visualize_model(model_quantized_and_trained)plt.ioff()
plt.tight_layout()
plt.show()

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.mzph.cn/news/55934.shtml

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

【Tkinter系列02/5】界面初步和布局

本文是系列文章第二部分。前文见&#xff1a;【Tkinter系列01/5】界面初步和布局_无水先生的博客-CSDN博客 说明 一般来说&#xff0c;界面开发中&#xff0c;如果不是大型的软件&#xff0c;就不必用QT之类的实现&#xff0c;用Tkinter已经足够&#xff0c;然而即便是Tkinter规…

解决vue中改变浏览器大小时其他页面的echarts不渲染了

1、在页面中加入ref <div class"echart_income" ref"echartsWaring"></div> 2、去掉echarts渲染出来的属性_echarts_instance_ initChart() {//移除图表中_echarts_instance_属性this.$refs.echartsWaring.removeAttribute(_echarts_insta…

spring boot 项目整合 websocket

1.业务背景 负责的项目有一个搜索功能&#xff0c;搜索的范围几乎是全表扫&#xff0c;且数据源类型贼多。目前对搜索的数据量量级未知&#xff0c;但肯定不会太少&#xff0c;不仅需要搜索还得点击下载文件。 关于搜索这块类型 众多&#xff0c;未了避免有个别极大数据源影响整…

linux操作系统的权限的深入学习(未完)

1.Linux权限的概念 Linux下有两种用户&#xff1a;超级用户&#xff08;root&#xff09;、普通用户。 超级用户&#xff1a;可以再linux系统下做任何事情&#xff0c;不受限制 普通用户&#xff1a;在linux下做有限的事情。 超级用户的命令提示符是“#”&#xff0c;普通用户…

Spring Authorization Server入门 (十六) Spring Cloud Gateway对接认证服务

前言 之前虽然单独讲过Security Client和Resource Server的对接&#xff0c;但是都是基于Spring webmvc的&#xff0c;Gateway这种非阻塞式的网关是基于webflux的&#xff0c;对于集成Security相关内容略有不同&#xff0c;且涉及到代理其它微服务&#xff0c;所以会稍微比较麻…

Prometheus 监控系统

常用的监控系统有哪些&#xff1f; 老牌传统 Zabbix Nagios Cacti 新一代的 Prometheus 夜莺 Zabbix 和 Prometheus 的区别&#xff1f;如何选择&#xff1f;【重中之重】 Zabbix 更适用于传统业务架构的物理机、虚拟机环境的监控&#xff0c;对容器环境的支持较差&#xf…

战略形成是权力妥协的过程,江湖,政治是常态

战略权力派&#xff1a;战略形成是各种权力妥协的过程【安志强趣讲270期】 趣讲大白话&#xff1a;有人的地方就有政治 **************************** 有人的地方就有江湖 有组织的地方就有政治 公司的战略是各种人的权力博弈的产物 围观权力&#xff1a;就是组织内部 宏观权力…

MyCAT命令行监控

9066端口 &#xff0c;用mysql命令行连接 Mysql –utest –ptest –P9066 show help 可显示所有相关管理命令 显示后端物理库连接信息&#xff0c;包括当前连接数&#xff0c;端口 Show backend Show connection 显示当前前端客户端连接情况&#xff0c;已经网络流量信息、…

Tomcat 部署时 war 和 war exploded区别

在 Tomcat 调试部署的时候&#xff0c;我们通常会看到有下面 2 个选项。 是选择war还是war exploded 这里首先看一下他们两个的区别&#xff1a; war 模式&#xff1a;将WEB工程以包的形式上传到服务器 &#xff1b;war exploded 模式&#xff1a;将WEB工程以当前文件夹的位置…

【Go 基础篇】Go语言数组遍历:探索多种遍历数组的方式

数组作为一种基本的数据结构&#xff0c;在Go语言中扮演着重要角色。而数组的遍历是使用数组的基础&#xff0c;它涉及到如何按顺序访问数组中的每个元素。在本文中&#xff0c;我们将深入探讨Go语言中多种数组遍历的方式&#xff0c;为你展示如何高效地处理数组数据。 前言 …

Excel筛选后复制粘贴不连续问题的解决

一直以来都没好好正视这个问题认真寻求解决办法 终于还是被需求逼出来了&#xff0c;懒人拯救世界[doge] 一共找到两个方法&#xff0c;个人比较喜欢第二种&#xff0c;用起来很方便 Way1&#xff1a;CtrlG定位可见单元格后使用vlookup解决&#xff08;感觉不定位直接公式向下…

C语言日常刷题 4

文章目录 题目答案与解析123456 题目 1、设变量已正确定义&#xff0c;以下不能统计出一行中输入字符个数&#xff08;不包含回车符&#xff09;的程序段是&#xff08; &#xff09; A: n0;while(chgetchar()!‘\n’)n; B: n0;while(getchar()!‘\n’)n; C: for(n0;getchar()…

golang http transport源码分析

golang http transport源码分析 前言 Golang http库在日常开发中使用会很多。这里通过一个demo例子出发&#xff0c;从源码角度梳理golang http库底层的数据结构以及大致的调用流程 例子 package mainimport ("fmt""net/http""net/url""…

2023年高教社杯 国赛数学建模思路 - 复盘:光照强度计算的优化模型

文章目录 0 赛题思路1 问题要求2 假设约定3 符号约定4 建立模型5 模型求解6 实现代码 建模资料 0 赛题思路 &#xff08;赛题出来以后第一时间在CSDN分享&#xff09; https://blog.csdn.net/dc_sinor?typeblog 1 问题要求 现在已知一个教室长为15米&#xff0c;宽为12米&…

YOLO目标检测——肺炎分类数据集下载分享

肺炎分类数据集总共21000图片&#xff0c;可应用于&#xff1a;肺炎检测、疾病诊断、疾病预测和预警等等。 数据集点击下载&#xff1a;YOLO肺炎分类数据集21000图片.rar

如何深入理解 Node.js 中的流(Streams)

Node.js是一个强大的允许开发人员构建可扩展和高效的应用程序。Node.js的一个关键特性是其内置对流的支持。流是Node.js中的一个基本概念&#xff0c;它能够实现高效的数据处理&#xff0c;特别是在处理大量信息或实时处理数据时。 在本文中&#xff0c;我们将探讨Node.js中的流…

腾讯云服务器地域和可用区详细介绍_选择攻略

腾讯云服务器地域有什么区别&#xff1f;怎么选择比较好&#xff1f;地域选择就近原则&#xff0c;距离地域越近网络延迟越低&#xff0c;速度越快。关于地域的选择还有很多因素&#xff0c;地域节点选择还要考虑到网络延迟速度方面、内网连接、是否需要备案、不同地域价格因素…

微服务dubbo

微服务是一种软件开发架构风格&#xff0c;它将一个应用程序拆分成一组小型、独立的服务&#xff0c;每个服务都可以独立部署、管理和扩展。每个服务都可以通过轻量级的通信机制&#xff08;通常是 HTTP/REST 或消息队列&#xff09;相互通信。微服务架构追求高内聚、低耦合&am…

什么是Git?解释Git的分布式版本控制系统的优势?

1、什么是Git&#xff1f;解释Git的分布式版本控制系统的优势&#xff1f; Git是一个开源的分布式版本控制系统&#xff0c;用于跟踪和管理代码库的版本历史。它允许用户在本地计算机上跟踪和管理代码库的更改&#xff0c;并与其他人协作开发项目。Git的分布式特性意味着它不需…

Cookie for Mac:隐私保护工具保护您的在线隐私

随着互联网的发展&#xff0c;我们每天都会浏览各种网站&#xff0c;享受在线购物、社交娱乐和学习资料等各种便利。然而&#xff0c;您是否曾经遇到过需要频繁输入用户名和密码的情况&#xff1f;或者不方便访问您常用的网站&#xff1f;如果是这样&#xff0c;那么Cookie for…