【深度学习】如何找到最优学习率

经过了大量炼丹的同学都知道,超参数是一个非常玄乎的东西,比如batch size,学习率等,这些东西的设定并没有什么规律和原因,论文中设定的超参数一般都是靠经验决定的。但是超参数往往又特别重要,比如学习率,如果设置了一个太大的学习率,那么loss就爆了,设置的学习率太小,需要等待的时间就特别长,那么我们是否有一个科学的办法来决定我们的初始学习率呢?

在这篇文章中,我会讲一种非常简单却有效的方法来确定合理的初始学习率。

学习率的重要性

目前深度学习使用的都是非常简单的一阶收敛算法,梯度下降法,不管有多少自适应的优化算法,本质上都是对梯度下降法的各种变形,所以初始学习率对深层网络的收敛起着决定性的作用,下面就是梯度下降法的公式

深度学习:如何找到最优学习率

这里 α 就是学习率,如果学习率太小,会导致网络loss下降非常慢,如果学习率太大,那么参数更新的幅度就非常大,就会导致网络收敛到局部最优点,或者loss直接开始增加,如下图所示。

深度学习:如何找到最优学习率

学习率的选择策略在网络的训练过程中是不断在变化的,在刚开始的时候,参数比较随机,所以我们应该选择相对较大的学习率,这样loss下降更快;当训练一段时间之后,参数的更新就应该有更小的幅度,所以学习率一般会做衰减,衰减的方式也非常多,比如到一定的步数将学习率乘上0.1,也有指数衰减等。

这里我们关心的一个问题是初始学习率如何确定,当然有很多办法,一个比较笨的方法就是从0.0001开始尝试,然后用0.001,每个量级的学习率都去跑一下网络,然后观察一下loss的情况,选择一个相对合理的学习率,但是这种方法太耗时间了,能不能有一个更简单有效的办法呢?

一个简单的办法

Leslie N. Smith 在2015年的一篇论文“Cyclical Learning Rates for Training Neural Networks”中的3.3节描述了一个非常棒的方法来找初始学习率,同时推荐大家去看看这篇论文,有一些非常启发性的学习率设置想法。

这个方法在论文中是用来估计网络允许的最小学习率和最大学习率,我们也可以用来找我们的最优初始学习率,方法非常简单。首先我们设置一个非常小的初始学习率,比如1e-5,然后在每个batch之后都更新网络,同时增加学习率,统计每个batch计算出的loss。最后我们可以描绘出学习的变化曲线和loss的变化曲线,从中就能够发现最好的学习率。

下面就是随着迭代次数的增加,学习率不断增加的曲线,以及不同的学习率对应的loss的曲线。

深度学习:如何找到最优学习率
深度学习:如何找到最优学习率

从上面的图片可以看到,随着学习率由小不断变大的过程,网络的loss也会从一个相对大的位置变到一个较小的位置,同时又会增大,这也就对应于我们说的学习率太小,loss下降太慢,学习率太大,loss有可能反而增大的情况。从上面的图中我们就能够找到一个相对合理的初始学习率,0.1。

之所以上面的方法可以work,因为小的学习率对参数更新的影响相对于大的学习率来讲是非常小的,比如第一次迭代的时候学习率是1e-5,参数进行了更新,然后进入第二次迭代,学习率变成了5e-5,参数又进行了更新,那么这一次参数的更新可以看作是在最原始的参数上进行的,而之后的学习率更大,参数的更新幅度相对于前面来讲会更大,所以都可以看作是在原始的参数上进行更新的。正是因为这个原因,学习率设置要从小变到大,而如果学习率设置反过来,从大变到小,那么loss曲线就完全没有意义了。

实现

上面已经说明了算法的思想,说白了其实是非常简单的,就是不断地迭代,每次迭代学习率都不同,同时记录下来所有的loss,绘制成曲线就可以了。下面就是使用PyTorch实现的代码,因为在网络的迭代过程中学习率会不断地变化,而PyTorch的optim里面并没有把learning rate的接口暴露出来,导致显示修改学习率非常麻烦,所以我重新写了一个更加高层的包mxtorch,借鉴了gluon的一些优点,在定义层的时候暴露初始化方法,支持tensorboard,同时增加了大量的model zoo,包括inceptionresnetv2,resnext等等,提供预训练权重,model zoo参考于Cadene的repo。目前这个repo刚刚开始,欢迎有兴趣的小伙伴加入我。

下面就是部分代码,近期会把找学习率的代码合并到mxtorch中。这里使用的数据集是kaggle上的dog breed,使用预训练的resnet50,ScheduledOptim的源码点这里。

   
  1. criterion = torch.nn.CrossEntropyLoss()
  2. net = model_zoo.resnet50(pretrained=True)
  3. net.fc = nn.Linear(2048, 120)
  4.  
  5. with torch.cuda.device(0):
  6. net = net.cuda()
  7.  
  8. basic_optim = torch.optim.SGD(net.parameters(), lr=1e-5)
  9. optimizer = ScheduledOptim(basic_optim)
  10.  
  11.  
  12. lr_mult = (1 / 1e-5) ** (1 / 100)
  13. lr = []
  14. losses = []
  15. best_loss = 1e9
  16. for data, label in train_data:
  17. with torch.cuda.device(0):
  18. data = Variable(data.cuda())
  19. label = Variable(label.cuda())
  20. # forward
  21. out = net(data)
  22. loss = criterion(out, label)
  23. # backward
  24. optimizer.zero_grad()
  25. loss.backward()
  26. optimizer.step()
  27. lr.append(optimizer.learning_rate)
  28. losses.append(loss.data[0])
  29. optimizer.set_learning_rate(optimizer.learning_rate lr_mult)
  30. if loss.data[0] < best_loss:
  31. best_loss = loss.data[0]
  32. if loss.data[0] > 4 best_loss or optimizer.learning_rate > 1.:
  33. break
  34.  
  35. plt.figure()
  36. plt.xticks(np.log([1e-5, 1e-4, 1e-3, 1e-2, 1e-1, 1]), (1e-5, 1e-4, 1e-3, 1e-2, 1e-1, 1))
  37. plt.xlabel(‘learning rate’)
  38. plt.ylabel(‘loss’)
  39. plt.plot(np.log(lr), losses)
  40. plt.show()
  41. plt.figure()
  42. plt.xlabel(‘num iterations’)
  43. plt.ylabel(‘learning rate’)
  44. plt.plot(lr)

one more thing

通过上面的例子我们能够有一个非常有效的方法寻找初始学习率,同时在我们的认知中,学习率的策略都是不断地做decay,而上面的论文别出心裁,提出了一种循环变化学习率的思想,能够更快的达到最优解,非常具有启发性,推荐大家去阅读阅读。

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.mzph.cn/news/171616.shtml

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

建造者模式-C语言实现

UML类图&#xff1a; 代码实现&#xff1a; #include <stdio.h> #include <stdlib.h>// 产品类 typedef struct {char* part1;char* part2;char* part3; } Product;// 抽象建造者类 typedef struct {void (*buildPart1)(void*, const char*);void (*buildPart2)(v…

论文种使用的数据集怎么获取

1.根据论文中的描述&#xff0c;可能会提及数据集已上传至某个网站&#xff01; 最常见的 1.GitHub 2.paper with code 3.期刊官网找到这篇论文&#xff0c;看是否存在补充材料&#xff01; 4.论文中提到&#xff0c;若读者需要&#xff0c;可邮件联系XXX(某位作者或任意作者)…

RabbitMQ之延迟消息实战

RabbitMQ之延迟消息实战 使用死信交换机实现延迟消息 使用死信交换机的过期时间以及没有消费者进行消费&#xff0c;时间到了就会到死信队列中&#xff0c;由此可以实现延迟消息使用延迟消息插件 前提&#xff1a;需要mq配置插件 延时信息案例实战 把一个30分钟的延迟消息可以…

前端review

关于实时预览vs code中的颜色代码需要安装的插件&#xff0c;包括html文件格式中的颜色代码安装Flutter Color插件 VSCode 前端常用插件集合 1.Auto Close Tag自动闭合HTML/XML标签 2.Auto Rename Tag自动完成另一侧标签的同步修改 3.Beautify格式化代码&#xff0c;值得注…

【高可用架构】Haproxy 和 Keepalived 的区别

Haproxy 和 Keepalived 的区别 1.负载均衡器介绍2.Haproxy 和 Keepalived 的基本概念和特点2.1 Haproxy2.2 Keepalived 3.Haproxy 和 Keepalived 的区别3.1 功能上的区别3.2 架构上的区别3.3 配置上的区别 4.总结 1.负载均衡器介绍 负载均衡器是一种解决高并发和高可用的常用的…

每日OJ题_算法_双指针⑥剑指 Offer 57. 和为s的两个数字

目录 剑指 Offer 57. 和为s的两个数字 解析代码&#xff1a; 剑指 Offer 57. 和为s的两个数字 LCR 179. 查找总价格为目标值的两个商品 - 力扣&#xff08;LeetCode&#xff09; 难度 简单 购物车内的商品价格按照升序记录于数组 price。请在购物车中找到两个商品的价格总…

蓝桥杯官网练习题(平均)

问题描述 有一个长度为 n 的数组&#xff08; n 是 10 的倍数&#xff09;&#xff0c;每个数 ai 都是区间 [0,9] 中的整数。小明发现数组里每种数出现的次数不太平均&#xff0c;而更改第 i 个数的代价为 bi&#xff0c;他想更改若干个数的值使得这 10 种数出现的次数相等…

【开源】基于Vue和SpringBoot的农家乐订餐系统

项目编号&#xff1a; S 043 &#xff0c;文末获取源码。 \color{red}{项目编号&#xff1a;S043&#xff0c;文末获取源码。} 项目编号&#xff1a;S043&#xff0c;文末获取源码。 目录 一、摘要1.1 项目介绍1.2 项目录屏 二、功能模块2.1 用户2.2 管理员 三、系统展示四、核…

C/C++内存管理,malloc,realloc,calloc,new,delete详解!!!

1.初步了解内存中各个区间存储的数据特征 1.栈区&#xff1a;存储一些局部变量、函数参数、返回值等&#xff0c;跟函数栈振有关&#xff0c;出了作用域&#xff0c;生命周期结束。 2.堆区&#xff1a;用于动态开辟空间&#xff0c;如果不主动销毁空间&#xff0c;则程序运行结…

302. 任务安排3,斜率优化dp,一般情况

有 N 个任务排成一个序列在一台机器上等待执行&#xff0c;它们的顺序不得改变。 机器会把这 N 个任务分成若干批&#xff0c;每一批包含连续的若干个任务。 从时刻 0 开始&#xff0c;任务被分批加工&#xff0c;执行第 i 个任务所需的时间是 Ti。 另外&#xff0c;在每批任…

给WordPress 自带的搜索功能添加过滤只搜索文章的标题

如果想让 WordPress 自带的搜索功能只搜索文章标题&#xff0c;让搜索结果更加精确&#xff08;其实WordPress 自带的搜索功能本来模糊查找就很弱&#xff09;&#xff0c;可以将下面的代码添加到当前主题functions.php中&#xff1a; 用过滤器&#xff1a;posts_search 就可以…

因式分解的几何意义

本来准备和女儿一起玩一道几何题&#xff0c;想想还是算了&#xff0c;不如讲点更有趣的。 任何因式分解都是在堆积木&#xff0c;不信你看&#xff1a; 二项式定理&#xff0c;洋灰三角&#xff0c;都是面积&#xff0c;体积&#xff0c;超维体积的拼接&#xff0c;一个大超…

Python 安装django-cors-headers解决跨域问题

一、PythonCorsHeaders概念 PythonCorsHeaders是一个轻量级的Python工具&#xff0c;用于解决跨域HTTP请求的问题。它允许你指定哪些网站或IP地址可以访问你的站点&#xff0c;并控制这些站点可以访问哪些内容。 现代网站越来越多地使用Ajax技术&#xff0c;使得浏览器能够从不…

kafka基本操作以及kafka-topics.sh 使用方式

文章目录 1 kafka的基本操作1.1 创建topic1.2 查看topic1.3 查看topic属性1.4 发送消息1.5 消费消息 2 kafka-topics.sh 使用方式2.1 查看帮助2.2 副本数量规则2.3 创建主题2.4 查看broker上所有的主题2.5 查看指定主题 topic 的详细信息2.6 修改主题信息之增加主题分区数量2.7…

docker操作手册

写在前面的几个重要命令 docker与本地件的文件拷贝 # 查看容器ID docker ps -a# 本地文件拷本到容器 docker cp {local_path} {CONTAINER ID}:{path}# 容器拷本到本地 docker cp {CONTAINER ID}:{path} {local_path} # eg docker cp /Users/helloworld/Downloads/R-3.5.0 0a1…

【人工智能】Chatgpt的训练原理

前言 前不久&#xff0c;在学习C语言的我写了一段三子棋的代码&#xff0c;但是与我对抗的电脑是没有任何思考的&#xff0c;你看了这段代码就理解为什么了&#xff1a; void computerMove(char Board[ROW][COL], int row, int col) {while (1){unsigned int i rand() % ROW, …

设计模式之十二:复合模式

模式通常被一起使用&#xff0c;并被组合在同一个解决方案中。 复合模式在一个解决方案中结合两个或多个模式&#xff0c;以解决一般或重复发生的问题。 首先重新构建鸭子模拟器&#xff1a; package headfirst.designpatterns.combining.ducks;public interface Quackable …

【HarmonyOS】ArkUI状态管理:组件内状态、装饰器、高级用法与最佳实战

文章目录 ArkUI状态管理机制详解1. 概述2. 基本概念2.1 状态变量2.2 数据传递和同步2.3 初始化方法3. 装饰器总览3.1 管理组件拥有的状态3.2 管理应用拥有的状态3.3 其他状态管理功能4. @State装饰器详解4.1 使用规则说明4.2 传递/访问规则说明4.3 观察变化和行为表现5. 使用场…

2.一维数组——输入10个成绩,求平均成绩,将低于平均成绩的分数输出

文章目录 前言一、题目描述 二、题目分析 三、解题 程序运行代码 前言 本系列为一维数组编程题&#xff0c;点滴成长&#xff0c;一起逆袭。 一、题目描述 输入10个成绩&#xff0c;求平均成绩&#xff0c;将低于平均成绩的分数输出 二、题目分析 averagesum/输入个数; 三、…

计算机网络入门

计算机网络 一、计算机网络基础 定义计算机网络计算机网络的发展历程计算机网络的分类&#xff08;局域网、广域网、互联网等&#xff09; 1. 计算机网络的定义&#xff1a; 计算机网络是指通过通信链路将多台计算机连接在一起&#xff0c;以便它们之间能够相互通信和共享资…