FastViT实战:使用FastViT实现图像分类任务(一)

文章目录

  • 摘要
  • 安装包
    • 安装timm
    • 安装 grad-cam
    • 安装mmcv
  • 数据增强Cutout和Mixup
  • EMA
  • 项目结构
  • 计算mean和std
  • 生成数据集
  • 补充一个知识点:torch.jit
    • 两种保存方式

摘要

论文翻译:https://wanghao.blog.csdn.net/article/details/132407722?spm=1001.2014.3001.5502
或者
https://blog.csdn.net/m0_47867638/article/details/132441806?spm=1001.2014.3001.5502

官方源码:https://github.com/apple/ml-fastvit

FastViT是一种混合ViT架构,它通过引入一种新型的token混合运算符RepMixer来达到最先进的延迟-准确性权衡。RepMixer通过消除网络中的跳过连接来降低内存访问成本。FastViT进一步应用训练时间过度参数化和大核卷积来提高准确性,并根据经验表明这些选择对延迟的影响最小。实验结果表明,FastViT在移动设备上的速度比最近的混合Transformer架构CMT快3.5倍,比EfficientNet快4.9倍,比ConvNeXt快1.9倍。在相似的延迟下,FastViT在ImageNet上的Top-1精度比MobileOne高出4.2%。此外,FastViT模型能够较好的适应域外和破损数据,相较于其它SOTA架构具备很强的鲁棒性和泛化性能。

在这里插入图片描述

这篇文章使用FastViT完成植物分类任务,模型采用fastvit_t8向大家展示如何使用FastViT。fastvit_t8在这个数据集上实现了95+%的ACC,如下图:

在这里插入图片描述
在这里插入图片描述

通过这篇文章能让你学到:

  1. 如何使用数据增强,包括transforms的增强、CutOut、MixUp、CutMix等增强手段?
  2. 如何实现FastViT模型实现训练?
  3. 如何使用pytorch自带混合精度?
  4. 如何使用梯度裁剪防止梯度爆炸?
  5. 如何使用DP多显卡训练?
  6. 如何绘制loss和acc曲线?
  7. 如何生成val的测评报告?
  8. 如何编写测试脚本测试测试集?
  9. 如何使用余弦退火策略调整学习率?
  10. 如何使用AverageMeter类统计ACC和loss等自定义变量?
  11. 如何理解和统计ACC1和ACC5?
  12. 如何使用EMA?

如果基础薄弱,对上面的这些功能难以理解可以看我的专栏:经典主干网络精讲与实战
这个专栏,从零开始时,一步一步的讲解这些,让大家更容易接受。

安装包

安装timm

使用pip就行,命令:

pip install timm

mixup增强和EMA用到了timm

安装 grad-cam

pip install grad-cam

安装mmcv

pip install -U openmim
mim install mmcv

数据增强Cutout和Mixup

为了提高成绩我在代码中加入Cutout和Mixup这两种增强方式。实现这两种增强需要安装torchtoolbox。安装命令:

pip install torchtoolbox

Cutout实现,在transforms中。

from torchtoolbox.transform import Cutout
# 数据预处理
transform = transforms.Compose([transforms.Resize((224, 224)),Cutout(),transforms.ToTensor(),transforms.Normalize([0.5, 0.5, 0.5], [0.5, 0.5, 0.5])])

需要导入包:from timm.data.mixup import Mixup,

定义Mixup,和SoftTargetCrossEntropy

  mixup_fn = Mixup(mixup_alpha=0.8, cutmix_alpha=1.0, cutmix_minmax=None,prob=0.1, switch_prob=0.5, mode='batch',label_smoothing=0.1, num_classes=12)criterion_train = SoftTargetCrossEntropy()

参数详解:

mixup_alpha (float): mixup alpha 值,如果 > 0,则 mixup 处于活动状态。

cutmix_alpha (float):cutmix alpha 值,如果 > 0,cutmix 处于活动状态。

cutmix_minmax (List[float]):cutmix 最小/最大图像比率,cutmix 处于活动状态,如果不是 None,则使用这个 vs alpha。

如果设置了 cutmix_minmax 则cutmix_alpha 默认为1.0

prob (float): 每批次或元素应用 mixup 或 cutmix 的概率。

switch_prob (float): 当两者都处于活动状态时切换cutmix 和mixup 的概率 。

mode (str): 如何应用 mixup/cutmix 参数(每个’batch’,‘pair’(元素对),‘elem’(元素)。

correct_lam (bool): 当 cutmix bbox 被图像边框剪裁时应用。 lambda 校正

label_smoothing (float):将标签平滑应用于混合目标张量。

num_classes (int): 目标的类数。

EMA

EMA(Exponential Moving Average)是指数移动平均值。在深度学习中的做法是保存历史的一份参数,在一定训练阶段后,拿历史的参数给目前学习的参数做一次平滑。具体实现如下:


import logging
from collections import OrderedDict
from copy import deepcopy
import torch
import torch.nn as nn_logger = logging.getLogger(__name__)class ModelEma:def __init__(self, model, decay=0.9999, device='', resume=''):# make a copy of the model for accumulating moving average of weightsself.ema = deepcopy(model)self.ema.eval()self.decay = decayself.device = device  # perform ema on different device from model if setif device:self.ema.to(device=device)self.ema_has_module = hasattr(self.ema, 'module')if resume:self._load_checkpoint(resume)for p in self.ema.parameters():p.requires_grad_(False)def _load_checkpoint(self, checkpoint_path):checkpoint = torch.load(checkpoint_path, map_location='cpu')assert isinstance(checkpoint, dict)if 'state_dict_ema' in checkpoint:new_state_dict = OrderedDict()for k, v in checkpoint['state_dict_ema'].items():# ema model may have been wrapped by DataParallel, and need module prefixif self.ema_has_module:name = 'module.' + k if not k.startswith('module') else kelse:name = knew_state_dict[name] = vself.ema.load_state_dict(new_state_dict)_logger.info("Loaded state_dict_ema")else:_logger.warning("Failed to find state_dict_ema, starting from loaded model weights")def update(self, model):# correct a mismatch in state dict keysneeds_module = hasattr(model, 'module') and not self.ema_has_modulewith torch.no_grad():msd = model.state_dict()for k, ema_v in self.ema.state_dict().items():if needs_module:k = 'module.' + kmodel_v = msd[k].detach()if self.device:model_v = model_v.to(device=self.device)ema_v.copy_(ema_v * self.decay + (1. - self.decay) * model_v)

加入到模型中。

#初始化
if use_ema:model_ema = ModelEma(model_ft,decay=model_ema_decay,device='cpu',resume=resume)# 训练过程中,更新完参数后,同步update shadow weights
def train():optimizer.step()if model_ema is not None:model_ema.update(model)# 将model_ema传入验证函数中
val(model_ema.ema, DEVICE, test_loader)

针对没有预训练的模型,容易出现EMA不上分的情况,这点大家要注意啊!

项目结构

FastViT_Demo
├─data1
│  ├─Black-grass
│  ├─Charlock
│  ├─Cleavers
│  ├─Common Chickweed
│  ├─Common wheat
│  ├─Fat Hen
│  ├─Loose Silky-bent
│  ├─Maize
│  ├─Scentless Mayweed
│  ├─Shepherds Purse
│  ├─Small-flowered Cranesbill
│  └─Sugar beet
├─models
│  ├─__init__.py
│  ├─modules
│  │  ├─mobileone.py
│  │  └─replknet.py
│  └─fastvit.py
├─mean_std.py
├─export_model.py
├─makedata.py
├─train.py
├─cam_image.py
└─test.py

models:来源官方代码,对面的代码做了一些适应性修改。
export_model.py:导出重参数模型
mean_std.py:计算mean和std的值。
makedata.py:生成数据集。
ema.py:EMA脚本
train.py:训练InceptionNext模型
cam_image.py:热力图可视化

计算mean和std

为了使模型更加快速的收敛,我们需要计算出mean和std的值,新建mean_std.py,插入代码:

from torchvision.datasets import ImageFolder
import torch
from torchvision import transformsdef get_mean_and_std(train_data):train_loader = torch.utils.data.DataLoader(train_data, batch_size=1, shuffle=False, num_workers=0,pin_memory=True)mean = torch.zeros(3)std = torch.zeros(3)for X, _ in train_loader:for d in range(3):mean[d] += X[:, d, :, :].mean()std[d] += X[:, d, :, :].std()mean.div_(len(train_data))std.div_(len(train_data))return list(mean.numpy()), list(std.numpy())if __name__ == '__main__':train_dataset = ImageFolder(root=r'data1', transform=transforms.ToTensor())print(get_mean_and_std(train_dataset))

数据集结构:

image-20220221153058619

运行结果:

([0.3281186, 0.28937867, 0.20702125], [0.09407319, 0.09732835, 0.106712654])

把这个结果记录下来,后面要用!

生成数据集

我们整理还的图像分类的数据集结构是这样的

data
├─Black-grass
├─Charlock
├─Cleavers
├─Common Chickweed
├─Common wheat
├─Fat Hen
├─Loose Silky-bent
├─Maize
├─Scentless Mayweed
├─Shepherds Purse
├─Small-flowered Cranesbill
└─Sugar beet

pytorch和keras默认加载方式是ImageNet数据集格式,格式是

├─data
│  ├─val
│  │   ├─Black-grass
│  │   ├─Charlock
│  │   ├─Cleavers
│  │   ├─Common Chickweed
│  │   ├─Common wheat
│  │   ├─Fat Hen
│  │   ├─Loose Silky-bent
│  │   ├─Maize
│  │   ├─Scentless Mayweed
│  │   ├─Shepherds Purse
│  │   ├─Small-flowered Cranesbill
│  │   └─Sugar beet
│  └─train
│      ├─Black-grass
│      ├─Charlock
│      ├─Cleavers
│      ├─Common Chickweed
│      ├─Common wheat
│      ├─Fat Hen
│      ├─Loose Silky-bent
│      ├─Maize
│      ├─Scentless Mayweed
│      ├─Shepherds Purse
│      ├─Small-flowered Cranesbill
│      └─Sugar beet

新增格式转化脚本makedata.py,插入代码:

import glob
import os
import shutilimage_list=glob.glob('data1/*/*.png')
print(image_list)
file_dir='data'
if os.path.exists(file_dir):print('true')#os.rmdir(file_dir)shutil.rmtree(file_dir)#删除再建立os.makedirs(file_dir)
else:os.makedirs(file_dir)from sklearn.model_selection import train_test_split
trainval_files, val_files = train_test_split(image_list, test_size=0.3, random_state=42)
train_dir='train'
val_dir='val'
train_root=os.path.join(file_dir,train_dir)
val_root=os.path.join(file_dir,val_dir)
for file in trainval_files:file_class=file.replace("\\","/").split('/')[-2]file_name=file.replace("\\","/").split('/')[-1]file_class=os.path.join(train_root,file_class)if not os.path.isdir(file_class):os.makedirs(file_class)shutil.copy(file, file_class + '/' + file_name)for file in val_files:file_class=file.replace("\\","/").split('/')[-2]file_name=file.replace("\\","/").split('/')[-1]file_class=os.path.join(val_root,file_class)if not os.path.isdir(file_class):os.makedirs(file_class)shutil.copy(file, file_class + '/' + file_name)

完成上面的内容就可以开启训练和测试了。

补充一个知识点:torch.jit

FastViT用到了Torch.jit保存模型。所以,我把这个知识点做个说明,方便大家理解。模型训练好后自然想要将里面所有层涉及的权重保存下来,这样子我们的模型就能部署在任意有pytorch环境下了。但是,用Torch.save/load还会依赖模型文件。

torch.jit是PyTorch的模型压缩和序列化工具,它可以将训练好的神经网络模型转换成TorchScript格式的脚本,以便在不需要Python解释器的情况下进行部署和运行。不再依赖模型文件。

torch.jit可以将训练好的神经网络模型转换成TorchScript格式的脚本,这样可以大大减少模型的内存占用,提高模型的运行速度,同时也可以避免Python环境的不稳定性对模型运行的影响。

两种保存方式

torch.jit.trace:这种方式为追踪一个函数的执行流,使用时需要提供一个测试输入。详见:https://pytorch.org/docs/1.6.0/generated/torch.jit.trace.html?highlight=jit%20trace#torch.jit.trace

需要注意的是这个接口只追踪测试输入走过的函数执行流(如果模型中有多条分支的话只会保存测试输入走过的分支!!!!!),所以对于一些多分支的模型不要采用这种方式,采用下面的Torch.jit.script。比如model.eval()和model.train()可以控制模型内BN层和dropout的权重是否固定,如果采用这种方式只能保留其中之一状态(固定或不固定)。

torch.jit.script:使用这种方式可以将一个模型完整的保存下来,和上面的trace正好相对。如果模型中的分支很多,并且在运行时会改变的话一定要用这种形式保存。详见:https://pytorch.org/docs/1.6.0/generated/torch.jit.script.html?highlight=torch%20jit%20script#torch.jit.script

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.mzph.cn/news/75874.shtml

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

mysql leetcode打题记录

文章目录 完成度基本语法高级语法连接日期 函数编写函数聚合函数 因为上过的数据库课实在太水了,所以打算先在菜鸟教程/CSDN/leetcode先学一下基本语法,然后去做Stanford数据库原理的课程CS145。 小目标:把leetcode上不用钱的mysql的题先做一…

企业密码安全:ADSelfService Plus 提升密码管理的千里之行

在当今数字化时代,企业的密码安全变得至关重要。密码是保护企业敏感信息和数据的第一道防线,而有效的密码管理对于确保网络安全至关重要。ADSelfService Plus是一款强大的密码管理和自助服务解决方案,它在提供密码安全方面走在了前沿。 ADSel…

【大数据之Kafka】十、Kafka消费者工作流程

1 Kafka消费方式 (1)pull(拉)模式:消费者从broker中主动拉取数据。(Kafka中使用) 不足:如果Kafka中没有数据,消费者可能会陷入循环,一直返回空数据。 &#…

Python自动化测试(1)-自动化测试及基本技术手段概述

生产力概述 在如今以google为首的互联网时代,软件的开发和生产模式都已经发生了变化, 在《参与感》一书提到:某位从微软出来的工程师很困惑,微软在google还有facebook这些公司发展的时候,为何为感觉没法有效还击&…

嵌入式基础知识-信息安全与加密

本篇来介绍计算机领域的信息安全以及加密相关基础知识,这些在嵌入式软件开发中也同样会用到。 1 信息安全 1.1 信息安全的基本要素 保密性:确保信息不被泄露给未授权的实体。包括最小授权原则、防暴露、信息加密、物理加密。完整性:保证数…

电容笔值不值得买?开学季比较好用的电容笔

眼看着新学期即将到来,到底应该选择什么样的电容笔?一款原装的苹果Pencil,就卖到了将近一千块,这对于很多人来说,都是一个十分昂贵的价格。事实上,由于平替电容笔的价格非常便宜,只要一二百元就…

【Spring Boot 源码学习】OnClassCondition 详解

Spring Boot 源码学习系列 OnClassCondition 详解 引言往期内容主要内容1. getOutcomes 方法2. 多处理器拆分处理3. StandardOutcomesResolver 内部类4. getMatchOutcome 方法 总结 引言 上篇博文带大家从源码深入了自动配置过滤匹配父类 FilteringSpringBootCondition&#x…

尚硅谷大数据项目《在线教育之离线数仓》笔记007

视频地址:尚硅谷大数据项目《在线教育之离线数仓》_哔哩哔哩_bilibili 目录 第12章 报表数据导出 P112 01、创建数据表 02、修改datax的jar包 03、ads_traffic_stats_by_source.json文件 P113 P114 P115 P116 P117 P118 P119 P120 P121 P122【122_在…

小米13Pro/13Ultra刷面具ROOT后激活LSPosed框架微X模块详细教程

喜欢买小米手机,很多是因为小米手机的开放,支持root权限,而ROOT对普通用户来说更多的是刷入DIY模块功能,今天ROM乐园小编就教大家如何使用面具ROOT,实现大家日常情况下非常依赖的微X模块功能,体验微X模块的…

Redis原理:动态字符串SDS

(课程总结自b站黑马程序员课程) 一、引言 Redis中保存的Key是字符串,value往往是字符串或者字符串的集合。可见字符串是Redis中最常用的一种数据结构。 不过Redis没有直接使用C语言中的字符串,因为C语言字符串存在很多问题&…

DHTMLX Gantt 8.0.5 Crack -甘特图

8.0.5 2023 年 9 月 1 日。错误修复版本 修复 修复通过gantt.getGanttInstance配置启用扩展而触发的错误警告修复启用skip_off_time配置时gantt.exportToExcel()的不正确工作示例查看器的改进 8.0.4 2023 年 7 月 31 日。错误修复版本 修复 修复数据处理器不跟踪资源数据…

微信小程序slot插槽的介绍,以及如何通过uniapp使用动态插槽

微信小程序文档 - slots介绍 由上述文档看俩来&#xff0c;微信小程序官方并没有提及动态插槽内容。 uniapp文档 - slots介绍 uni官方也未提及关于动态插槽的内容 在实际使用中&#xff0c;直接通过 <<slot :name"item.xxx" /> 这种形式会报错&#xff…

23062C++QTday4

仿照string类&#xff0c;完成myString 类 代码&#xff1a; #include <iostream> #include <cstring> using namespace std; class myString {private:char *str; //记录c风格的字符串int size; //记录字符串的实际长度public://无参构造my…

分布式AKF拆分原则

目录 1 前言2 什么是AKF3 如何基于 AKF X 轴扩展系统&#xff1f;4 如何基于 AKF Y 轴扩展系统&#xff1f;5 如何基于 AKF Z 轴扩展系统&#xff1f;6 小结 1 前言 当我们需要分布式系统提供更强的性能时&#xff0c;该怎样扩展系统呢&#xff1f;什么时候该加机器&#xff1…

项目打包docker镜像 | 上传nexus | jenkins一键构建

文章目录 前言准备实操1、打开docker的远程访问2、编写dockerfile文件3、指定nexus环境4、配置jenkins5、使用jenkins构建 总结 前言 Docker部署项目是指使用Docker容器化技术将应用程序及其依赖项打包成一个独立的、可移植的运行环境&#xff0c;并在各种操作系统和平台上进行…

Unreal Engine Loop 流程

引擎LOOP 虚幻引擎的启动是怎么一个过程。 之前在分析热更新和加载流程过程中&#xff0c;做了一个图。记录一下&#xff01;&#xff01; ![在这里插入图片描述](https://img-blog.csdnimg.cn/f11f7762f5dd42f9b4dd9b7455fa7a74.png#pic_center 只是记录&#xff0c;以备后用…

使用LightPicture开源搭建私人图床:详细教程及远程访问配置方法

文章目录 1.前言2. Lightpicture网站搭建2.1. Lightpicture下载和安装2.2. Lightpicture网页测试2.3.cpolar的安装和注册 3.本地网页发布3.1.Cpolar云端设置3.2.Cpolar本地设置 4.公网访问测试5.结语 1.前言 现在的手机越来越先进&#xff0c;功能也越来越多&#xff0c;而手机…

大数据技术之Hadoop:Yarn集群部署(七)

目录 一、部署说明 二、集群规划 三、开始配置 3.1 MapReduce配置文件 3.2 YARN配置文件 3.3 分发配置文件 四、集群启停 4.1 命令介绍 4.2 演示 4.3 查看YARN的WEB UI页面 一、部署说明 Hadoop HDFS分布式文件系统&#xff0c;我们会启动&#xff1a; NameNode进…

LeGo-LOAM 源码解析

文章目录 0、整体框架1、imageProjection —— 点云分割0. main()1. cloudHandler()2. copyPointCloud()3. findStartEndAngle()4. projectPointCloud()5. groundRemoval()6. cloudSegmentation()7. labelComponents()8. publishCloud()9. resetParameters() 2、featureAssocia…

java多线程(超详细)

1 - 线程 1.1 - 进程 进程就是正在运行中的程序&#xff08;进程是驻留在内存中的&#xff09; 是系统执行资源分配和调度的独立单位 每一进程都有属于自己的存储空间和系统资源 注意&#xff1a;进程A和进程B的内存独立不共享。 1.2 - 线程 线程就是进程中的单个顺序控制…