Hiera实战:使用Hiera实现图像分类任务(一)

文章目录

  • 摘要
  • 安装包
    • 安装timm
  • 数据增强Cutout和Mixup
  • EMA
  • 项目结构
  • 计算mean和std
  • 生成数据集
  • _pickle.PicklingError: Can't pickle <function Head.<lambda> at 0x000001DE8DD7F240>: attribute lookup Head.<lambda> on hiera.hiera failed

摘要

现代层次视觉变换器在追求监督分类表现时增加了几个特定于视觉的组件。 这些组件虽然带来了有效的准确性和吸引人的FLOP计数,但增加的复杂性实际上使这些变换器比普通ViT更快。作者认为这种额外的体积是不必要的。 通过使用强大的视觉预训练任务(MAE)进行预训练,可以从最先进的多阶段视觉变换器中去除所有花里胡哨的东西,同时不会丢失准确性。 在此过程中,作者创建了Hiera,这是一种极其简单的层次视觉变换器,它比以前的模型更准确,同时在推理和训练过程中都明显更快。 在各种任务上评估了Hiera对于图像和视频识别的表现。 代码和模型可以在https://github.com/facebookresearch/hiera上获得。

这篇文章使用Hiera完成植物分类任务,模型采用hiera_tiny_224向大家展示如何使用Hiera。hiera_tiny_224在这个数据集上实现了70+%的ACC,如下图:

在这里插入图片描述

在这里插入图片描述

准确率很低,下面的文章中,我会讲解如何使用MAE提高ACC。使用MAE可以让ACC提升了85+%。

通过这篇文章能让你学到:

  1. 如何使用数据增强,包括transforms的增强、CutOut、MixUp、CutMix等增强手段?
  2. 如何实现Hiera模型实现训练?
  3. 如何使用pytorch自带混合精度?
  4. 如何使用梯度裁剪防止梯度爆炸?
  5. 如何使用DP多显卡训练?
  6. 如何绘制loss和acc曲线?
  7. 如何生成val的测评报告?
  8. 如何编写测试脚本测试测试集?
  9. 如何使用余弦退火策略调整学习率?
  10. 如何使用AverageMeter类统计ACC和loss等自定义变量?
  11. 如何理解和统计ACC1和ACC5?
  12. 如何使用EMA?

如果基础薄弱,对上面的这些功能难以理解可以看我的专栏:经典主干网络精讲与实战
这个专栏,从零开始时,一步一步的讲解这些,让大家更容易接受。

安装包

安装timm

使用pip就行,命令:

pip install timm

mixup增强和EMA用到了timm

数据增强Cutout和Mixup

为了提高成绩我在代码中加入Cutout和Mixup这两种增强方式。实现这两种增强需要安装torchtoolbox。安装命令:

pip install torchtoolbox

Cutout实现,在transforms中。

from torchtoolbox.transform import Cutout
# 数据预处理
transform = transforms.Compose([transforms.Resize((224, 224)),Cutout(),transforms.ToTensor(),transforms.Normalize([0.5, 0.5, 0.5], [0.5, 0.5, 0.5])])

需要导入包:from timm.data.mixup import Mixup,

定义Mixup,和SoftTargetCrossEntropy

  mixup_fn = Mixup(mixup_alpha=0.8, cutmix_alpha=1.0, cutmix_minmax=None,prob=0.1, switch_prob=0.5, mode='batch',label_smoothing=0.1, num_classes=12)criterion_train = SoftTargetCrossEntropy()

Mixup 是一种在图像分类任务中常用的数据增强技术,它通过将两张图像以及其对应的标签进行线性组合来生成新的数据和标签。
参数详解:

mixup_alpha (float): mixup alpha 值,如果 > 0,则 mixup 处于活动状态。这个参数控制输入数据和标签的混合程度。如果mixup_alpha接近0,那么输出将更接近原始输入;如果mixup_alpha接近1,那么输出将更混合。例如,如果mixup_alpha=0.5,那么新的输入数据将会是原始输入数据和另一个随机选取的输入数据的线性组合,标签也会相应地进行混合。

cutmix_alpha (float):cutmix alpha 值,如果 > 0,cutmix 处于活动状态。

cutmix_minmax (List[float]):这个参数用于控制CutMix数据增强技术的混合程度。如果cutmix_alpha为0,那么不会进行任何混合;如果cutmix_alpha为1,那么替换部分的大小将在0到1之间随机选择。例如,如果cutmix_alpha=0.5,那么在随机选取的两个输入数据中,会有50%的概率将一个输入数据的一部分替换为另一个输入数据。

如果设置了 cutmix_minmax 则cutmix_alpha 默认为1.0

prob (float): 每批次或元素应用 mixup 或 cutmix 的概率。如果prob=0.1,那么有10%的概率会对输入数据和标签进行Mixup处理。

switch_prob (float): 当两者都处于活动状态时切换cutmix 和mixup 的概率 。如果switch_prob=0.5,那么在应用Mixup处理时,有50%的概率会将标签从原始标签切换到另一个标签。

mode (str): 如何应用 mixup/cutmix 参数(每个’batch’,‘pair’(元素对),‘elem’(元素))。 'batch’模式会对整个批次进行Mixup操作,而’instance’模式会对每个实例进行Mixup操作。

correct_lam (bool): 当 cutmix bbox 被图像边框剪裁时应用。 lambda 校正

label_smoothing (float):将标签平滑应用于混合目标张量。 标签平滑是一种技术,通过将概率分布向所有类别均匀分布的方式,使得模型的预测更平滑。这有助于防止模型过度优化并增强模型的鲁棒性。例如,如果label_smoothing=0.1,那么在计算损失函数时,会将每个类别的概率分布向所有类别均匀分布,使得模型的预测更加不确定。

num_classes (int): 目标的类数。分类任务中的类别数量。对于回归任务,这个参数应该是None。例如,如果这是一个二分类问题,那么num_classes=2。

EMA

EMA(Exponential Moving Average)是指数移动平均值。在训练过程中,可以使用EMA来对模型参数进行平均,以提高测试指标并增加模型的鲁棒性。

具体来说,EMA通过以下方式实现:假设我们有n个数据,那么EMA的计算公式为:EMA = (1 - λ) * EMA + λ * 数据值,其中λ是权重因子,通常设置为0.9-0.999。

在深度学习中,使用EMA对模型参数进行平均可以帮助提高模型的泛化能力和稳定性。

在深度学习中的做法是保存历史的一份参数,在一定训练阶段后,拿历史的参数给目前学习的参数做一次平滑。具体实现如下:


import logging
from collections import OrderedDict
from copy import deepcopy
import torch
import torch.nn as nn_logger = logging.getLogger(__name__)class ModelEma:def __init__(self, model, decay=0.9999, device='', resume=''):# make a copy of the model for accumulating moving average of weightsself.ema = deepcopy(model)self.ema.eval()self.decay = decayself.device = device  # perform ema on different device from model if setif device:self.ema.to(device=device)self.ema_has_module = hasattr(self.ema, 'module')if resume:self._load_checkpoint(resume)for p in self.ema.parameters():p.requires_grad_(False)def _load_checkpoint(self, checkpoint_path):checkpoint = torch.load(checkpoint_path, map_location='cpu')assert isinstance(checkpoint, dict)if 'state_dict_ema' in checkpoint:new_state_dict = OrderedDict()for k, v in checkpoint['state_dict_ema'].items():# ema model may have been wrapped by DataParallel, and need module prefixif self.ema_has_module:name = 'module.' + k if not k.startswith('module') else kelse:name = knew_state_dict[name] = vself.ema.load_state_dict(new_state_dict)_logger.info("Loaded state_dict_ema")else:_logger.warning("Failed to find state_dict_ema, starting from loaded model weights")def update(self, model):# correct a mismatch in state dict keysneeds_module = hasattr(model, 'module') and not self.ema_has_modulewith torch.no_grad():msd = model.state_dict()for k, ema_v in self.ema.state_dict().items():if needs_module:k = 'module.' + kmodel_v = msd[k].detach()if self.device:model_v = model_v.to(device=self.device)ema_v.copy_(ema_v * self.decay + (1. - self.decay) * model_v)

加入到模型中。

#初始化
if use_ema:model_ema = ModelEma(model_ft,decay=model_ema_decay,device='cpu',resume=resume)# 训练过程中,更新完参数后,同步update shadow weights
def train():optimizer.step()if model_ema is not None:model_ema.update(model)# 将model_ema传入验证函数中
val(model_ema.ema, DEVICE, test_loader)

针对没有预训练的模型,容易出现EMA不上分的情况,这点大家要注意啊!

项目结构

Hiera_Demo
├─data1
│  ├─Black-grass
│  ├─Charlock
│  ├─Cleavers
│  ├─Common Chickweed
│  ├─Common wheat
│  ├─Fat Hen
│  ├─Loose Silky-bent
│  ├─Maize
│  ├─Scentless Mayweed
│  ├─Shepherds Purse
│  ├─Small-flowered Cranesbill
│  └─Sugar beet
├─hiera
│  ├─__init__.py
│  ├─benchmarking.py
│  ├─hiera.py
│  ├─hiera_mae.py
│  └─hiera_utils.py
├─mean_std.py
├─makedata.py
├─train.py
└─test.py

mean_std.py:计算mean和std的值。
makedata.py:生成数据集。
ema.py:EMA脚本
train.py:训练RevCol模型
hiera:来源官方代码,对面的代码做了一些适应性修改。这里使用hiera_tiny_224模型举例。

计算mean和std

为了使模型更加快速的收敛,我们需要计算出mean和std的值,新建mean_std.py,插入代码:

from torchvision.datasets import ImageFolder
import torch
from torchvision import transformsdef get_mean_and_std(train_data):train_loader = torch.utils.data.DataLoader(train_data, batch_size=1, shuffle=False, num_workers=0,pin_memory=True)mean = torch.zeros(3)std = torch.zeros(3)for X, _ in train_loader:for d in range(3):mean[d] += X[:, d, :, :].mean()std[d] += X[:, d, :, :].std()mean.div_(len(train_data))std.div_(len(train_data))return list(mean.numpy()), list(std.numpy())if __name__ == '__main__':train_dataset = ImageFolder(root=r'data1', transform=transforms.ToTensor())print(get_mean_and_std(train_dataset))

数据集结构:

image-20220221153058619

运行结果:

([0.3281186, 0.28937867, 0.20702125], [0.09407319, 0.09732835, 0.106712654])

把这个结果记录下来,后面要用!

生成数据集

我们整理还的图像分类的数据集结构是这样的

data
├─Black-grass
├─Charlock
├─Cleavers
├─Common Chickweed
├─Common wheat
├─Fat Hen
├─Loose Silky-bent
├─Maize
├─Scentless Mayweed
├─Shepherds Purse
├─Small-flowered Cranesbill
└─Sugar beet

pytorch和keras默认加载方式是ImageNet数据集格式,格式是

├─data
│  ├─val
│  │   ├─Black-grass
│  │   ├─Charlock
│  │   ├─Cleavers
│  │   ├─Common Chickweed
│  │   ├─Common wheat
│  │   ├─Fat Hen
│  │   ├─Loose Silky-bent
│  │   ├─Maize
│  │   ├─Scentless Mayweed
│  │   ├─Shepherds Purse
│  │   ├─Small-flowered Cranesbill
│  │   └─Sugar beet
│  └─train
│      ├─Black-grass
│      ├─Charlock
│      ├─Cleavers
│      ├─Common Chickweed
│      ├─Common wheat
│      ├─Fat Hen
│      ├─Loose Silky-bent
│      ├─Maize
│      ├─Scentless Mayweed
│      ├─Shepherds Purse
│      ├─Small-flowered Cranesbill
│      └─Sugar beet

新增格式转化脚本makedata.py,插入代码:

import glob
import os
import shutilimage_list=glob.glob('data1/*/*.png')
print(image_list)
file_dir='data'
if os.path.exists(file_dir):print('true')#os.rmdir(file_dir)shutil.rmtree(file_dir)#删除再建立os.makedirs(file_dir)
else:os.makedirs(file_dir)from sklearn.model_selection import train_test_split
trainval_files, val_files = train_test_split(image_list, test_size=0.3, random_state=42)
train_dir='train'
val_dir='val'
train_root=os.path.join(file_dir,train_dir)
val_root=os.path.join(file_dir,val_dir)
for file in trainval_files:file_class=file.replace("\\","/").split('/')[-2]file_name=file.replace("\\","/").split('/')[-1]file_class=os.path.join(train_root,file_class)if not os.path.isdir(file_class):os.makedirs(file_class)shutil.copy(file, file_class + '/' + file_name)for file in val_files:file_class=file.replace("\\","/").split('/')[-2]file_name=file.replace("\\","/").split('/')[-1]file_class=os.path.join(val_root,file_class)if not os.path.isdir(file_class):os.makedirs(file_class)shutil.copy(file, file_class + '/' + file_name)

完成上面的内容就可以开启训练和测试了。

_pickle.PicklingError: Can’t pickle <function Head. at 0x000001DE8DD7F240>: attribute lookup Head. on hiera.hiera failed

这个错误,是由于Hiera中使用了lambda表达式,如下图:
在这里插入图片描述
将Head中的代码修改为:

class Head(nn.Module):def __init__(self,dim: int,num_classes: int,dropout_rate: float = 0.0,):super().__init__()self.dropout = nn.Dropout(dropout_rate) if dropout_rate > 0 else nn.Identity()self.projection = nn.Linear(dim, num_classes)def forward(self, x: torch.Tensor) -> torch.Tensor:x = self.dropout(x)x = self.projection(x)if not self.training:x = F.softmax(x)return x

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.mzph.cn/news/174284.shtml

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

Vue组件的几种通信方式

这里写目录标题 Vue组件的几种通信&#xff08;数据传递&#xff09;方式非父子组件间通信&#xff08;Bus事件总线&#xff09;介绍实例 非父子通信-provide&inject1.作用2.场景3.语法4.注意 父子组件间的通信固定props属性名&#xff08;v-model&#xff09;介绍实例 不固…

vue页面表单提交时如何做校验

我们在做新增的时候&#xff0c;新增对话框是要加必填校验的&#xff0c;否则就可能会加空数据或者会产生sql的报错。那么这个校验是如何加的呢&#xff1f;下面我们来说一下。 文章目录 一、必填校验1.1 给form表单绑定一个:rules校验规则&#xff0c;给每个item加上一个prop…

api自动化测试

API测试已成为日常的测试任务之一&#xff0c;为了提高测试效率&#xff0c;减少重复的手工操作&#xff0c;API自动化测试也逐渐变得愈加重要&#xff0c;本文是自己在API自动化测试方面的一些经验积累和心得、汇总成文&#xff0c;以飨读者 我相信自动化技能已经成为高级测试…

5. 链表

内存空间是所有程序的公共资源&#xff0c;在一个复杂的系统运行环境下&#xff0c;空闲的内存空间可能散落在内存各处。我们知道&#xff0c;存储数组的内存空间必须是连续的&#xff0c;而当数组非常大时&#xff0c;内存可能无法提供如此大的连续空间。此时链表的灵活性优势…

【回眸】Tessy 单元测试软件使用指南(二)常见导入问题答疑篇

目录 TESSY如何导入工程的头文件&#xff08;单个&递归导入&#xff09; 问题1&#xff1a;XXXXXXXX.h: No such file or directory 问题2&#xff1a;warning&#xff1a;null character(s) ignored 问题3&#xff1a;Tessy软件在analyze过程中遇到大量乱码也找不到原因…

接口测试快速入门 以飞致云平台为例

飞致云电商API地址系统来自飞致云项目。接口API地址&#xff1a;https://gz.fit2cloud.com/swagger-ui.html 飞致云电商系统接口文档 V1.0&#xff1a;见 有道云笔记 该网站可以做接口测试练习。快速了解如何测试接口&#xff0c;如何做关联 系统基地址&#xff1a;https://g…

编程在现代社会中的重要性

文章目录 编程的重要性引言编程在现代社会中的重要性常见的编程应用场景结语 编程的重要性 引言 编程在现代社会中的重要性是不言而喻的&#xff0c;它可以让我们创造出各种有用的软件&#xff0c;解决各种复杂的问题&#xff0c;甚至改变世界。 编程在现代社会中的重要性 编…

如何选择到适合自己的IP代理服务商?IPIDEA为您分享

随着互联网的发展&#xff0c;越来越多的企业开始依赖网络进行各种业务&#xff0c;对于代理IP这个工具来说&#xff0c;应该都不陌生&#xff0c;尤其是大数据、跨境行业的企业&#xff0c;为了让出海业务更顺利&#xff0c;也为了保护企业的数据安全和隐私&#xff0c;许多企…

深入探索Maven:优雅构建Java项目的新方式(一)

Maven高级 1&#xff0c;分模块开发1.1 分模块开发设计1.2 分模块开发实现 2&#xff0c;依赖管理2.1 依赖传递与冲突问题2.2 可选依赖和排除依赖方案一:可选依赖方案二:排除依赖 3&#xff0c;聚合和继承3.1 聚合步骤1:创建一个空的maven项目步骤2:将项目的打包方式改为pom步骤…

XmlRPC协议详解(一款不支持原生异步请求的协议)

XmlRPC协议详解 文章目录 XmlRPC协议详解什么是RPC&#xff1f;什么是XmlRPC&#xff1f;XmlRPC详解请求示例响应示例错误响应示例参数的数据类型 结束语 什么是RPC&#xff1f; RPC&#xff08;远程过程调用&#xff09;是一种用于实现分布式系统中不同进程或不同计算机之间通…

freertos任务切换的现场保存、恢复(任务栈空间)深度分析(以RISC-V架构为例)

1、任务控制块在内存中的布局 RISC-V架构采用的减栈&#xff0c;即栈向低地址空间生长&#xff1b;在freertos中采用任务控制块&#xff08;TCB&#xff09;结构来表示一个任务每个任务有自己的任务栈&#xff0c;任务栈是紧挨着TCB的&#xff0c;且TCB在地址高位&#xff0c;任…

关于el-table的二次封装及使用,支持自定义列内容

关于el-table的二次封装及使用 table组件 <template><el-table ref"tableComRef" :data"tableData" border style"width: 100%"><el-table-column v-if"tableHeaderList[0]?.type selection" type"selection&…

人力资源管理后台 === 左树右表

1.角色管理-编辑角色-进入行内编辑 获取数据之后针对每个数据定义标识-使用$set-代码位置(src/views/role/index.vue) // 针对每一行数据添加一个编辑标记this.list.forEach(item > {// item.isEdit false // 添加一个属性 初始值为false// 数据响应式的问题 数据变化 视图…

〔005〕虚幻 UE5 像素流多用户部署

✨ 目录 ▷ 为什么要部署多用户▷ 开启分发服务器▷ 配置启动多个信令服务器▷配置启动客户端▷多用户启动整体流程和预览▷注意事项▷ 为什么要部署多用户 之前的像素流部署,属于单用户,是有很大的弊端的打开多个窗口访问,可以看到当一个用户操作界面的时候,另一个界面也会…

为社会做贡献的EasyDarwin 4.0.1发布了,支持视频点播、文件直播、摄像机直播、直播录像、直播回放、录像MP4合成下载

经过几个月的不懈努力和测试&#xff0c;最新的EasyDarwin 4.0版本总算是发布出来了&#xff0c;功能还是老几样&#xff1a;文件点播、视频直播&#xff08;支持各种视频源&#xff09;、直播录像与回放、录像合成MP4下载&#xff0c;稍稍看一下细节&#xff1a; 文件上传与点…

cmdline

cmdline是一个kv结构,就是uboot参数传给kernel使用的 举例: Kernel command line: user_debug=31 storagemedia=mtd androidboot.storagemedia=mtd androidboot.mode=normal mac=00FA89112233 serial=LONBON12345 earlycon=uart8250,mmio32,0xff570000 console=ttyFIQ0…

PostgreSQL | EXTRACT | 获取时间的年月日字串

EXTRACT EXTRACT 函数是 PostgreSQL 中用于从日期和时间类型中提取特定部分&#xff08;如年、月、日、小时等&#xff09;的函数。 格式 EXTRACT(field FROM source) -- field 参数是要提取的部分&#xff0c;例如 YEAR、MONTH、DAY、HOUR 等。 -- source 参数是包含日期或…

Redis:持久化RDB和AOF

目录 概述RDB持久化流程指定备份文件的名称指定备份文件存放的目录触发RDB备份redis.conf 其他一些配置rdb的备份和恢复优缺点停止RDB AOF持久化流程AOF启动/修复/恢复AOF同步频率设置rewrite压缩原理触发机制重写流程no-appendfsync-on-rewrite 优缺点 如何选择 概述 Redis是…

带submodule的git仓库自动化一键git push、git pull脚本

前言 很久没写博客了&#xff0c;今天难得闲下来写一次。 不知道大家在使用git的时候有没有遇到过这样的问题&#xff1a;发现git submodule特别好用&#xff0c;适合用于满足同时开发和部署的需求&#xff0c;并且结构清晰&#xff0c;方便我们对整个代码层次有一个大概的了…

数据库第十第十一章 恢复和并发简答题

数据库第一章 概论简答题 数据库第二章 关系数据库简答题 数据库第三章 SQL简答题 数据库第四第五章 安全性和完整性简答题 数据库第七章 数据库设计简答题 数据库第九章 查询处理和优化简答题 1.什么是数据库中的事务&#xff1f;它有哪些特性&#xff1f;这些特性的含义是什么…