resnet50结构_无需额外数据、Tricks、架构调整,CMU开源首个将ResNet50精度提升至80%+新方法

本文是CMU的Zhiqiang Shen提出的一种提升标准ResNet50精度的方法,它应该是首个将ResNet50的Top1精度刷到80%+的(无需额外数据,无需其他tricks,无需网络架构调整)。该文对于研究知识蒸馏的同学应该是有不少可参考的价值,尤其是里面提到的几点讨论与结论,值得深思。

本文首发自极市平台,作者 @Happy,转载需获授权。

6766771378d8332af638c9342f7d90fb.png

paper: https://arxiv.org/abs/2009.08453

code: https://github.com/szq0214/MEAL-V2

Abstract

该文提出一种简单而有效的方法,无需任何tricks,它可以将标准ResNet50的Top1精度提升到80%+。该方法是基于作者之前MEAL(通过判别方式进行知识蒸馏集成)改进而来,作者对MEAL进行了以下两点改进:(1) 仅在最后的输出部分使用相似性损失与判别损失;(2) 采用所有老师模型的平均概率作为更强的监督信息进行蒸馏。该文提到一个非常重要的发现:在蒸馏阶段不应当使用one-hot方式的标签编码。这样一种简单的方案可以取得SOTA性能,且并未用到以下几种常见涨点tricks:(1)类似ResNet50-D的架构改进;(2)额外训练数据;(3) AutoAug、RandAug等;(4)cosine学习率机制;(5)mixup/cutmix数据增广策略;(6) 标签平滑。

在ImageNet数据集上,本文所提方法取得了80.67%的Top1精度(single crop@224),以极大的优势超越其他同架构方案。该方法可以视作采用知识蒸馏对ResNet50涨点的一个新的基准,该文可谓首个在不改变网络架构、无需额外训练数据的前提下将ResNet提升到超过80%Top1精度的方法。

Method

提升模型精度的trick一般包含这样几点:(1) 更好的数据增广方法,比如Mixup、Cutmix、AutoAug、RandAug、Fix resolution discrepancy等;(2) 网络架构的调整,比如SENet、ResNeSt之于ResNet;(3)更好的学习率调整机制,比如cosine;(4)额外的训练数据;(5) 知识蒸馏。而本文则聚焦于采用知识蒸馏(teacher-student)的方法提升标准ResNet50的精度。该文所用方法具有这样几点优势(与已有方法的对比见下表):

  • No Architecture Modification;
  • No outsize training data beyond ImageNet;
  • No cosine learning rate
  • No extra data augmentation, like mixup, autoaug;
  • No label Smoothing.

4e6e41f1b7c309b0bce72efc353a4310.png

与此同时,该文还得到这样一个发现:The one-hot/hard label is not neccssary and could not be used in the distillation process,该发现对于知识蒸馏尤为重要。

接下来,我们将从Teachers Ensemble, KL-divergence, Discriminator三个方面进行该文方法的介绍。

091bcb1581b7b02b28ef41c2a8e9c89c.png

Teachers Ensemble

在该文的知识蒸馏框架中,采用老师模型集成的方式提升更精度的预测并用于指导学生模型训练。上图给出了MEALV1与MEALV2的两者的区别与联系,在训练阶段,在每次迭代开始前MEALV1通过老师选择模块选择用于蒸馏的老师模型;而该文则是采用多个老师模型的平均预测概率作为监督信息。那么,这里所提到的Teachers Ensemble可以描述如下:

其中,

分别表示输入、老师模型个数,以及老师模型的预测概率。

KL-divergence

KL散度是知识蒸馏领域最常用的一种损失,它用度量两个概率分布之间的相似性。在该文中,KL散度用于度量学生模型的预测概率与前述老师模型的平均预测概率之间的相似性。KL散度损失函数可以描述如下:

当然,各位同学不用花费精力去研究上述公式,目前各大深度学习框架中均有该损失函数的实现,直接调用就好。除了KL散度损失外,另一个常用的损失函数就是交叉熵损失,定义如下:

各位有没有发现,截止到目前上述所提到的信息基本上就是知识蒸馏最基本的一些信息了。除了Teachers Ensemble外,该文的创新点在哪里呢?

Discriminator

判别器是一个二分类器,它用于判别输入特征来自老师模型还是来自学生模型。它由sigmoid与二值交叉熵损失构成,定义如下:

作者定义了一个sigmoid函数用于模拟老师-学生的概率,定义如下:

其中

表示一个三层感知器,即三个全连接,
表示logistic函数。该文采用最后未经softmax处理的输出作为该判别器的输入。

考虑到该文采用的是Teachers Ensemble方式,不方便得到中间特征输出;同时为了使整个框架更简洁,作者仅仅采用了相似损失与判别损失用于蒸馏。作者通过实验表明:老师集成模型的的最后一层输出足以蒸馏一个强学生模型。

Experiments

训练数据:ImageNet,即ILSVRC2012训练集,包含1000个类别,120W数据;测试集:ImageNet,包含5W数据。

在训练过程中,作者采用了最基本的数据增广:RandomResizedCrop、RandomHorizontalFlip,在测试阶段采用了CenterCrop。8GPU用训练,batch=512,优化器为SGD,未采用weight decay,StepLR,初始学习率为0.01,合计训练180epoch,在100epoch时学习率x0.1。

当学生模型的输入为

时,老师模型为senet154,resnet152_vl;当学生模型的输入为
时,老师模型为efficientnet_b4, efficientnet_b4_ns。注:预训练模型源自
rwightman大神(https://github.com/rwightman/pytorch-image-models)。

在实验方面,作者分别以ResNet50、MobileNetV3为基准进行了实验对比,那么接下来就分别进行相关结果的介绍咯。

ResNet50

b9e8be1959258eff4e4e6233e764d96a.png

上表给出了所提方法在ResNet50上的性能对比。当输入为

时,该方法取得了80.67%的Top1精度,以2.46%的指标优于MEAL;甚至,所提方法还超越了ResNeSt50-fast的80.64%(需要修改网路结构,同时用到了诸多tricks);当输入增大到
后,所提方法去了81.72%的Top1精度,以2.62%优于FixRes的79.1%(训练224,测试384)。

作者同时还探索了所提方法与其他数据增广的互补性,当引入CutMix后,模型的性能还可以进一步提升达到80.98%@224。尽管该提升并不大,但这意味着ResNet50还有继续提升的空间。

更有意思的是,所提学生模型的精度非常接近两个老师模型的精度(81.22%/95.36%, 81.01%/95.42%)了。

MobileNetV3

50b016993df7cb29b50937e96c27ad3d.png

上表给出了所提方法在MobileNetV3与EfficientNetB0上的性能对比。可以看到:MobileNetV3-Samall-0.75的性能提升了2.20%,MobileNetV3-Small-1.0的性能提升了2.25%, MobileNetV3-Large-1.0的性能提升了1.72%, EfficietnNet-B0的性能提升了1.49%(76.8/93.2源自EfficientNet原文,而77.3、93.5源自rwightman大神)。在轻量型模型上取得这样的性能提升着实令人惊讶,要知道,该文方法不会导致推理的任何调整。

Discussion

接下来就是“填坑”时间了,对前文的几个“坑”来进行简单的分析与讨论。

  • Why is the hard/one-hot label not necessary in knowledge distillation?

One-hot标签是人工标注的,存在不正确或标注信息不全。ImageNet数据中有不少图像包含不止一个目标,但仅赋予了one-hot标签,难以很好的表示图像的内容信息。而更精度的老师模型足以提供高质量的内容信息并更好的引导老师模型的优化方向。

  • How does the discriminator help the optimization?

判别器用于防止学生模型在训练数据上过拟合,同时可以起到正则作用。

  • How about the generalization ability of our method on large students?

作者同时还尝试了一些大模型(比如ResNeXt-101 32x8d)同时作为老师和学生模型,这意味着老师模型与学生模型具有相近的容量,正如所期望的,提升不如小模型,但仍可以看到一些提升。一般而言,源自老师模型的软监督信息要比人工标准信息更优化。总而言之一句话:更强的老师模型可以蒸馏出更强的学生模型。

  • Is there still room to improve the performance of vanilla ResNet50?

答案是肯定的。替换更多、更强的老师模型还可以进一步提升学生模型的精度,同时引入其他tricks可能同样有益(作者没有去尝试哦,资源约束,深表同感,哈哈)。作者提到:当前的老师-学生模型选择是从训练效率、计算资源等方面均衡的选择,该文的目的是验证方法的有效性,而非更高精度(看到这里,无言以对)。

全文到此结束,对该文感兴趣的同学建议去查看一下原文的分析。

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.mzph.cn/news/571647.shtml

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

linq to sql 行转列_SQL 难题:行转列

问题:有一张学生成绩表sc(sid 学号,cid 课程,score 成绩),请查询出每个学生的英语、数学的成绩(行转列,一个学生只有一行记录)。建表语句:create 实现方式1—…

c++测试cpu_测评丨NXP系列 LS1028 LS1046等产品网络性能测试

号外号外!继OK1012A-C面市以来,飞凌嵌入式公司相继推出了OK1043A-C、OK1046A-C,以及最新上市的OK1028A-C,OK10XX系列产品也是一个大家族了。正所谓春兰秋菊,各擅胜场。下面小编就各产品的网络性能为您简单介绍一下。先…

四.Windows I/O模型之重叠IO(overlapped)模型

1.适用于除Windows CE之外的各种Windows平台.在使用这个模型之前应该确保该系统安装了Winsock2.重叠模型的基本设计原理是使用一个重叠的数据结构,一次投递一个或多个Winsock I/O请求。在重叠模型中,收发数据使用WSA开头的函数。2.WSA_FLAG_OVERLAPPED标…

vscode怎样导入数据_【Python开发】用VSCode+Jupyter notebook 编写 Python

版权声明:小博主水平有限,希望大家多多指导。本文仅代表作者本人观点。1、过去,想要在 VSCode 中运行 Jupyter notebook 需要安装一个 Neuron 扩展,我也装过,感觉很强大、很方便。不过现在,VSCode 中 Pytho…

springboot怎么杀进程_全新Steam在线游戏 Among us太空狼人杀攻略

众多游戏爱好者已加入我们!带你发现好游戏!休闲娱乐小游戏!点击下方↓↓↓↓"开始游戏",赶紧进入吧!!戳“开始游戏”玩百款火爆小游戏!《Among us》游戏好玩吗?《Among us…

kafka 怎么样连接图形化界面_从零开始搭建Kafka+SpringBoot分布式消息系统

前言由于kafka强依赖于zookeeper,所以需先搭建好zookeeper集群。由于zookeeper是由java编写的,需运行在jvm上,所以首先应具备java环境。(ps:默认您的centos系统可联网,本教程就不教配置ip什么的了)(ps2:没有…

《Iterative-GAN》的算法伪代码整理

花了一下午时间整理本人的论文Iterative-GAN的算法伪代码,由于篇幅较长,投会议方面的文章就不加入了,以后如果投期刊再说。留此存档。 转载于:https://www.cnblogs.com/punkcure/p/7821031.html

h5能调取摄像头吗_高质感的国产中型车,实力能比肩本田雅阁吗?带你看红旗H5...

中国品牌的豪华中型车,带你看红旗H5伴随着经济的快速发展,大家的钱包现在也是越来越鼓,也开始向往更加美好的生活。曾经很多人买车都是为了满足基本的代步需求,如今也开始在车辆的品质与行驶质感上有了更高要求。而为了迎合市场变…

lstm网络_LSTM(长短期记忆网络)

在上篇文章一文看尽RNN(循环神经网络)中,我们对RNN模型做了总结。由于RNN也有梯度消失的问题,因此很难处理长序列的数据,大牛们对RNN做了改进,得到了RNN的特例LSTM(Long Short-Term Memory),它可以避免常规RNN的梯度消…

ant接口用什么天线_手机听收音机时,为什么必须用耳机作为天线?

名侦探柯基-十万个为什么 第七十六期起因,观看活着韩国丧尸电影时的一幕,刘亚仁想听电台广播,却无奈于所有设备都是无线的,由此疑惑到,只有插入有线的耳机,才能收听广播吗?耳机线就是天线&#…

qt c++ 图片预览_Qt多语言国际化

Qt附加工具介绍Qt Assistant(Qt助手)Qt Linguist(Qt语言家)Qt Designer(Qt设计师)Qt AssistantQt Assistant是可配置且可重新发布的文档阅读器,可以方便地进行定制并与Qt应用程序一起重新发布。Qt Assistan…

Icon+启动图尺寸

1、LaunchImage 启动图 命名格式: 1x -> xxx.png 2x -> xxx2x.png Retina 4 -> xxx2x.png     转载于:https://www.cnblogs.com/z-z-z/p/7828082.html

智商情商哪个重要_《所谓逆商高,就是心态好》:逆商,比情商和智商更重要...

所谓“逆商”,是指人们遇到逆境时的应对能力,即战胜挫折、摆脱困境和超越困难的能力。我们一生会面临各种各样的难题,也许是考试失利,也许是和心爱的人分离,也许是工作上竞争失败……在失意的时候你会做何选择&#xf…

mysql 排名_学会在MySQL中实现Rank高级排名函数,所有取前几名问题全部解决.

MySQL中没有Rank排名函数,当我们需要查询排名时,只能使用MySQL数据库中的基本查询语句来查询普通排名。尽管如此,可不要小瞧基础而简单的查询语句,我们可以利用其来达到Rank函数一样的高级排名效果。在这里我用一个简单例子来实现…

意大利_【解读】去意大利留学,一定要学意大利语吗?意大利语难吗?

喜欢意大利,想去意大利留学,但不想学意大利语可以吗?意大利语太难了,听说有英授专业(本来就要学英语、考雅思所以不担心英语)……问题来了去意大利留学,选择英授专业的话还需要学意大利语吗?我们一点点剖析…

MD5、SHA1、SHA256的简单讲解

简述: 最近在研究系统以及驱动,当下载比较大的文件时总会提供SHA1或者SHA256,下载结束后使用校验工具得到的值与它进行比对来判断下载是否成功。 使用工具校验 certutil -hashfile 文件名 sha1/sha256/md5正文: MD5、SHA1、SHA256这些都被称为 哈希…

java swarm集群_52个Java程序员不可或缺的 Docker 工具

Docker工具分类列表编排和调度持续集成/持续部署(CI / CD)监控记录安全存储/卷管理联网服务发现构建管理编排和调度 1. KubernetesKubernetes是市场上最实用的最受欢迎的容器编排引擎。最初作为一个Google项目开始,成千上万的团队使用它来部署生产中的容器。谷歌声称…

comsol显示电场计算结果_在 COMSOL 中构建磁流体动力学多物理场模型

COMSOL Multiphysics 软件中的模型都是从零开始构建的,软件支持多物理场,因此用户可以按照自己的意愿轻松地组合代表不同物理场现象的模型。有时这可以通过使用软件的内置功能来实现,但有些情况下,用户需要做一些额外的工作。我们…

RGB转LAB色彩空间

https://www.cnblogs.com/hrlnw/p/4126017.html 1.原理 RGB无法直接转换成LAB,需要先转换成XYZ再转换成LAB,即:RGB——XYZ——LAB 因此转换公式分两部分: (1)RGB转XYZ 假设r,g,b为像素三个通道,…

React- jsx的使用可以渲染html标签 或React组件

React 的 JSX 使用大、小写的约定来区分本地组件的类和 HTML 标签。既渲染html标签需要使用小写字母开头的标签名而渲染本地React组件需要使用大写字母开头的标签名 注意: 由于 JSX 就是 JavaScript,一些标识符像 class 和 for 不建议作为 XML 属性名。作为替代&…