实践心得:从读论文到复现到为开源贡献代码

摘要: 本文讲述了从在fast.ai库中读论文,到根据论文复制实验并做出改进,并将改进后的开源代码放入fast.ai库中。


介绍


去年我发现MOOC网上有大量的Keras和TensorKow教学视频,之后我从零开始学习及参加一些Kaggle比赛,并在二月底获得了fast.ai国际奖学金。去年秋天,当我在全力学习PyTorch时,我在feed中发现了一条关于新论文的推文:“平均权重会产生更广泛的局部优化和更好的泛化。”具体来说,就是我看到一条如何将其添加到fast.ai库的推文。现在我也参与到这项研究。

在一名软件工程师的职业生涯中,我发现学习一门新技术最好的方法是将它应用到具体的项目。所以我认为这不仅可以练习提高我的PyTorch能力,还能更好的熟悉fast.ai库,也能提高我阅读和理解深度学习论文的能力。

作者发表了使用随机加权平均(SWA)训练VGG16和预激活的Resnet-110模型时获得的改进。对于VGG网络结构,SWA将错误率从6.58%降低到6.28%,相对提高了4.5%,而Resnet模型则更明显,将误差从4.47%减少到3.85%,相对提高了13.9%。

论文

背景

随机加权平均(SWA)方法来自于集成。集成是用于提高机器学习模型性能的流行的技术。例如,ensemble算法获得了Nekix奖,因为Netkix过于复杂不适用于实际生产,而在像Kaggle这样的竞争平台上,集成最终性能表现结果可以远超单个模型。

最简单的方式为,集成可以对不同初始化的模型的若干副本进行训练,并将对副本的预测平均以得到整体的预测。但是这种方法的缺点是必须承担n个不同副本的成本。研究人员提出快照集成(Snapshot Ensembles)方法。改方法是对一个模型进行训练,并将模型收敛到几个局部最优点,保存每个最优点的权重。这样一个单一的训练就可以产生n个不同的模型,将这些预测平均就能预测出整体。

在发表SWA论文之前,作者曾发表过快速几何集成(FGE)方法的论文,改方法改进了快照集成的结果,FGE方法为“局部最优能通过近乎恒定损耗的简单曲线连接起来”也就是说,通过FGE作者能够发现损耗曲面中的曲线具有理想的特性,以及通过这些曲线集成模型。

在SWA论文中,作者提供了SWA接近FGE的证据。然而,SWA比FGE的好处是推理成本较低 。FGE需要产生n个模型的预测结果,而对于SWA而言,最终只需要一个模型,因此推断可以更快。

算法

SWA算法的工作原理相对简单。首先制作你正在训练的模型的副本,以便用于跟踪平均权重。在完成epoch训练后,通过以下公式更新副本的权重:

 

其中n_models是已经包含在平均值中的模型数量,w_swa表示副本的权重,w表示正在训练的模型的权重。这相当于在每个epoch训练时期结束时存储模型的运行平均值。这就是该算法的精髓,但论文还介绍了一些细节,首页作者制定了具体的学习率计划,以确保SGD在开始平均模型时就能够找到出最优点。其次,对网络进行预训练以达到开始时就有一定数量的epochs,而不是一开始就追踪平均值。另外,如果使用周期性学习率,那么需要在每个周期结束时存储平均值,而不是在每个epoch后。

寻找更广泛的最优点

SWA的算法的工作方式,作者提供了证据,证明与SGD相比,它能使模型达到更广泛的局部最优,从而能够提高模型的泛化能力,因为训练损失和测试数据可能不完全一致。因此,对训练数据进行更广泛的优化使得模型对测试数据进行优化。


图三的一部分

由图可得,训练损失(左)和测试错误(右)相似但不完全相同。例如,最右边的X处于训练损失表面的最佳点,但距离最优测试误差有一定距离。正是这些差异能更容易的寻找更广泛的最优点,这更可能成为训练和测试损失的最佳点。

作者提出观点:SWA可以找到更广泛的最优点。并在论文Optima Width章节中通过实验给出了证据,将损失作为给定方向上的Optima距离的函数,来比较SGD和SWA能够发现的最优点宽度。作者对10个不同的方向进行了采样,并测量了用SGD和SWA对CIFAR-10进行训练的Preactivation Resnet的损失,结果如下:

 

图4:“测试误差...作为随机射线上的点函数,起始于CIFAR-100上预激活ResNet-110的SWA(蓝色)和SGD(绿色)解决方案。”

图中数据提供了证据,表明SWA发现的optima比SGD所发现的更广泛,因为它与SWA最优的距离比增加同样数量的测试错误的距离更大。例如,要达到50%的测试误差,你必须从距离SGD的最佳距离为30,而SGD为50。

实验

作者进行了大量的实验来验证SWA方法在不同的数据集和模型架构上的有效性。首先,我将详细描述为了实现该算法做的实验设置,然后讲解一些关键结果。

使用VGG16和预激活的Resnet-110体系结构在CIFAR-10上进行了复制实验。每个体系结构都有一定的预算,以表示仅使用SGD +动量来训练模型收敛所需的时间数。VGG预算为200,而Resnet则为150。然后,为了测试SWA,模型用SGD +动力培训约75%的预算,然后用SWA进行额外的epochs训练,达到原始预算的1、1.25和1.5倍。对每个测试训练了三个模型,并报告平均值和标准偏差。

除了对CIFAR-10的实验外,作者还对CIFAR-100进行了类似的实验。他们还在ImageNet上测试了预训练模型,使用SWA运行了10个epochs,并发现在预训练的ResNet-50、ResNet152和DenseNet-161的精度提高了。最后,作者通过使用固定学习速率的SWA,成功地从scratch中训练了一个宽的ResNet-28-10。

实现

阅读并理解该论文后,我尝试在fast.ai库中找出哪个位置添加代码能够使SWA正常工作。该位置已经找到了,因为fast.ai库提供了添加自定义回调的功能。如果我用每个epoch结束时调用的hook来写回调,那么就能在适当的时间更新权重的运行平均值。这是结束的代码:



回调采用三个参数:model、swa_model和swa_start。前两个是我们正在训练的模型,以及我们将用来存储加权平均的模型副本。swa_start参数是平均开始的时间,因为在论文中,模型总是在开始跟踪平均权重之前,用SGD+动量对一定数量的epochs进行训练。

从这里你可以看到SWA回调如何将算法从文件转换成PyTorch代码。在SWA开始的epoch中,我们将更新参数的运行平均值,并增加平均值中包含的模型数量。

在SWA模型进行推断前,我们还需要用包含代码修复batchnorm的运算平均值。batchnorm层通常在训练期间计算这些运行统计数据,但由于模型的权重是作为其他模型的平均值计算的,所以这些运行统计数据对于SWA模型是错误的,因此需要再次单次传递数据让batchnorm层计算正确的运行统计数据。修复代码如下:

 

测试

测试非常重要,但是在机器学习代码中应用单元测试是很困难的,因为有一些不确定的因素或者测试的状态需要较长时间。为了确保所做工作实际上是有效的,我做了两个测试,一个是“功能”测试,它们是较小的代码块,通常运行在比较简单的模型上,旨在回答:“这个功能是否按照我的想法实现了?”例如,一项功能测试表明,在经过几个阶段的训练后,SWA模型实际上等于所有SGD模型参数的平均值:

 

这些测试通常在30秒内就能运行完成,所以在编写实现代码遇到问题时能快速提醒我。由于fast.ai库的开发速度非常快,这些测试还能在试图解决master分支合并问题时快速识别问题。

第二个测试为“实验”测试。它的目的是回答:“如果我用自己的实现和fast.ai库重新进行论文中的实验,我是否能观察到与论文相同的结果?”每次我实现一个功能就会运行这个测试,以确定SWA是否对库做出有用的贡献。实验测试要比功能测试花费的时间长,但能确保一切都按预期运行。

最后我可以复述论文的结果-随机权重平均确实在CIFAR-10上产生了比一般SGD更高的准确性,并且随着训练时期的增加,这种改善通常会增加。正如下表所示,我所有的结果都比原始论文结果更准确。其中一个因素可能是数据增强的方式——对于CIFAR-10,通过将每个图像填充4个像素并随机裁剪进行增强,并且我发现fast.ai默认使用不同类型的填充(rekection填充)。然而,可以清楚地看到SWA改善超过SGD +momentum的模式。

 

原始论文的结果

 

我的结果

获取测试代码请点击代码

结论

我对这个项目的最终结果非常满意,因为我从最前沿的研究论文中复制了一个实验,并为机器学习开源代码做出了自己的第一个贡献。我想鼓励大家下载fast.ai库,并尝试一下SWA吧!

原文链接

本文为云栖社区原创内容,未经允许不得转载。


本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.mzph.cn/news/521760.shtml

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

python xml etree word_使用python格式化插入的元素xml.etree模块,包括新行

我正在将一个元素插入到一个大的xml文件中。我希望插入的元素位于顶部(所以我需要使用根.插入方法,并且不能仅附加到文件中)。我也希望元素的格式与文件的其余部分相匹配。在原始XML文件的格式为....然后运行以下代码:^{pr2}$它以以下形式创建输出&#…

(vue基础试炼_08)Vue模板语法

文章目录一、插值表达式二、v-text 中不是字符串而是js表达式三、v-html四、js表达式&#xff0c;可以和字符串拼接五、源码链接一、插值表达式 <!DOCTYPE html> <html lang"en"> <head><meta charset"UTF-8"><title>Vue模…

FPGA设计中遇到的奇葩问题之“芯片也要看出身”

阿里云资深专家隐达分享了他十余年工作经历中的一段奇葩历程。文章诙谐幽默&#xff0c;用玄幻小说的写法分享技术问题&#xff0c;非常值得大家一读。 &#xff08;一&#xff09;昨夜西风凋碧树。独上高楼&#xff0c;望尽天涯路2000年的时候&#xff0c;做设计基本都是使用X…

php msgid排重,如何应用php数组对百万数据停止排重

如何应用php数组对百万数据停止排重如何应用php数组对百万数据停止排重在往常的工作中&#xff0c;常常接到要对网站的会员停止站内信、手机短信、email停止群发信息的告诉&#xff0c;用户列表普通由别的同事提供&#xff0c;当中难免会有反复&#xff0c;为了避免反复发送&am…

FPGA资源平民化的新晋- F3 技术解析

摘要&#xff1a; FPGA (现场可编程门阵列)由于其硬件并行加速能力和可编程特性&#xff0c;在传统通信领域和IC设计领域大放异彩。一路走来&#xff0c;FPGA并非一个新兴的硬件器件&#xff0c;由于其开发门槛过高&#xff0c;硬件加速算法的发布和部署保护要求非常高&#xf…

Vue计算属性、方法、侦听器

文章目录一、基础计算模板二、计算属性computed三、方法methods四、侦听器watch五、总结六、源码地址一、基础计算模板 <!DOCTYPE html> <html lang"en"> <head><meta charset"UTF-8"><title>Vue计算属性、方法、侦听器<…

python数据分析简答题_Python数据分析与数据可视化-中国大学mooc-试题题目及答案...

Python数据分析与数据可视化-中国大学mooc-试题题目及答案更多相关问题【简答题】城轨供电系统按功能划分为几部分&#xff1f;各有什么作用&#xff1f;【多选题】影响债券价格的因素有【单选题】关于注射剂的质量要求&#xff0c;叙述错误的是 prefix\"o\" ns\&quo…

漫画:五分钟看懂车联网

戳蓝字“CSDN云计算”关注我们哦&#xff01;福利扫描添加小编微信&#xff0c;备注“姓名公司职位”&#xff0c;加入【云计算学习交流群】&#xff0c;和志同道合的朋友们共同打卡学习&#xff01;推荐阅读&#xff1a;华为 | 泰山之巅 鲲鹏展翅 扶摇直上九万里聊聊我是如何在…

对数据科学家来说最重要的算法和统计模型

摘要&#xff1a; 本文提供了工业中常用的关键算法和统计技术的概要&#xff0c;以及与这些技术相关的短缺资源。作为一个在这个行业已经好几年的数据科学家&#xff0c;在LinkedIn和QuoLa上&#xff0c;我经常接触一些学生或者想转行的人&#xff0c;帮助他们进行机器学习的职…

JAVA ulimit,java-从linux中的jvm中查找硬打开和软打开文件限制(ulimit -n和ulimit -Hn)

我有一个问题,我需要从Java / groovy程序中找出Linux中进程的硬打开和软打开文件限制.当我从终端执行ulimit时,它将为硬打开文件限制和软打开文件限制提供单独的值.$ulimit -n1024$ulimit -Hn4096但是,如果我以常规方式执行它,它将忽略软限制并始终返回硬限制值.groovy> [ba…

计算属性的setter和getter

computed的属性不仅可以写一个get方法&#xff0c;通过其他的值算出一个新值&#xff1b;同时&#xff0c;也可以设置set方法&#xff0c;通过设置一个值&#xff0c;来改变他相关联的值&#xff01;而改变了相关联的值之后&#xff0c;又会引起fullName的重新计算&#xff0c;…

python制作远程桌面控制_Python 远程桌面协议RDPY简介

RDPY 是基于 Twisted Python 实现的微软 RDP 远程桌面协议。RDPY 提供了如下 RDP 和 VNC 支持&#xff1a;RDP Man In The Middle proxy which record sessionRDP HoneypotRDP screenshoterRDP clientVNC clientVNC screenshoterRSS Player目前能够找到的关于RDPY的中文介绍确实…

华为愿出售5G技术渴望对手;苹果将向印度投资10亿美元;华为全联接大会首发计算战略;腾讯自研轻量级物联网操作系统正式开源……...

戳蓝字“CSDN云计算”关注我们哦&#xff01;嗨&#xff0c;大家好&#xff0c;重磅君带来的【云重磅】特别栏目&#xff0c;如期而至&#xff0c;每周五第一时间为大家带来重磅新闻。把握技术风向标&#xff0c;了解行业应用与实践&#xff0c;就交给我重磅君吧&#xff01;重…

数组元素反序

和前面的字符串逆向输出有异曲同工之妙 第一位和最后一位交换位置&#xff0c;然后用比大小循环 那么接下来修改一下这个程序&#xff0c;我们接下来解释一下p的概念 画图解释&#xff1a; 在最前面的 定义的时候&#xff0c;我们将p&#xff08;0&#xff09;定义在了1上&…

如何计算Java对象所占内存的大小

摘要&#xff1a; 本文以如何计算Java对象占用内存大小为切入点&#xff0c;在讨论计算Java对象占用堆内存大小的方法的基础上&#xff0c;详细讨论了Java对象头格式并结合JDK源码对对象头中的协议字段做了介绍&#xff0c;涉及内存模型、锁原理、分代GC、OOP-Klass模型等内容。…

hilbert谱 matlab,怎么在matlab中做信号hilbert边际谱分析

摘要&#xff1a;传统的数字滤波器的设计过程复杂&#xff0c;计算工作量大&#xff0c;滤波特性调整困难&#xff0c;影响了它的应用。本文介绍了一种利用MATLAB信号处理工具箱(Signal Processing Toolbox)快速有效的设计由软件组成的常规数字滤波器的设计方法。给出了使用MAT…

时间序列数据的处理

摘要&#xff1a; 随着云计算和IoT的发展&#xff0c;时间序列数据的数据量急剧膨胀&#xff0c;高效的分析时间序列数据&#xff0c;使之产生业务价值成为一个热门话题。阿里巴巴数据库事业部的HiTSDB团队为您分享时间序列数据的计算分析的一般方法以及优化手段。演讲嘉宾简介…

saas java框架_XMReport-提供web项目Java套打解决方案

简介XMReport是国内首款支持在线编辑&#xff0c;维护的控件式报表产品。XMReport报表产品分为设计器与引擎两个部分&#xff0c;其中报表设计器是完全基于HTML5技术&#xff0c;提供优秀跨平台的支持&#xff0c;用户无需安装客户端或者插件&#xff0c;仅使用浏览器即可进行报…

只有程序员才能读懂的西游记

戳蓝字“CSDN云计算”关注我们哦&#xff01;这其实一个有关计算机网络协议的故事一、我佛造经传极乐话说我佛如来为度化天下苍生&#xff0c;有三藏真经&#xff0c;可劝人为善。就如图中所示&#xff0c;真经所藏之处&#xff0c;在于云端。佛祖所管辖之下&#xff0c;有四个…

Logtail从入门到精通(四):正则表达式Java日志采集实战

摘要&#xff1a; 为简化日志接入门槛&#xff0c;我们提供了极简模式的日志解析方式&#xff08;如[开启日志采集之旅]()中的介绍&#xff09;。为了更好的对日志进行分析&#xff0c;我们还提供了其他解析方式&#xff0c;例如&#xff1a;分隔符模式、完整正则模式、JSON模式…