PyTorch框架学习十三——优化器

PyTorch框架学习十三——优化器

  • 一、优化器
  • 二、Optimizer类
    • 1.基本属性
    • 2.基本方法
  • 三、学习率与动量
    • 1.学习率learning rate
    • 2.动量、冲量Momentum
  • 四、十种常见的优化器(简单罗列)

上次笔记简单介绍了一下损失函数的概念以及18种常用的损失函数,这次笔记介绍优化器的相关知识以及PyTorch中的使用。

一、优化器

PyTorch中的优化器:管理并更新模型中可学习参数的值,使得模型输出更接近真实标签。

导数:函数在指定坐标轴上的变化率。
方向导数:指定方向上的变化率。
梯度:一个向量,方向为方向导数取得最大值的方向。

二、Optimizer类

1.基本属性

  1. defaults:优化器超参数,包含优化选项的默认值的dict(当参数组没有指定这些值时使用)。
  2. state:参数的缓存,如momentum的缓存。
  3. param_groups:管理的参数组,形式上是列表,每个元素都是一个字典。

2.基本方法

(1)zero_grad():清空所管理的参数的梯度。因为PyTorch中张量梯度不会自动清零。

weight = torch.randn((2, 2), requires_grad=True)
weight.grad = torch.ones((2, 2))optimizer = optim.SGD([weight], lr=0.1)print("weight before step:{}".format(weight.data))
optimizer.step()        # 修改lr=1 0.1观察结果
print("weight after step:{}".format(weight.data))print("weight in optimizer:{}\nweight in weight:{}\n".format(id(optimizer.param_groups[0]['params'][0]), id(weight)))print("weight.grad is {}\n".format(weight.grad))
optimizer.zero_grad()
print("after optimizer.zero_grad(), weight.grad is\n{}".format(weight.grad))

结果如下:

weight before step:tensor([[0.6614, 0.2669],[0.0617, 0.6213]])
weight after step:tensor([[ 0.5614,  0.1669],[-0.0383,  0.5213]])
weight in optimizer:1314236528344
weight in weight:1314236528344weight.grad is tensor([[1., 1.],[1., 1.]])after optimizer.zero_grad(), weight.grad is
tensor([[0., 0.],[0., 0.]])

(2) step():执行一步优化更新。

weight = torch.randn((2, 2), requires_grad=True)
weight.grad = torch.ones((2, 2))optimizer = optim.SGD([weight], lr=0.1)print("weight before step:{}".format(weight.data))
optimizer.step()        # 修改lr=1 0.1观察结果
print("weight after step:{}".format(weight.data))

结果如下:

weight before step:tensor([[0.6614, 0.2669],[0.0617, 0.6213]])
weight after step:tensor([[ 0.5614,  0.1669],[-0.0383,  0.5213]])

(3) add_param_group():添加参数组。

weight = torch.randn((2, 2), requires_grad=True)
weight.grad = torch.ones((2, 2))optimizer = optim.SGD([weight], lr=0.1)print("optimizer.param_groups is\n{}".format(optimizer.param_groups))w2 = torch.randn((3, 3), requires_grad=True)optimizer.add_param_group({"params": w2, 'lr': 0.0001})print("optimizer.param_groups is\n{}".format(optimizer.param_groups))

结果如下:

optimizer.param_groups is
[{'params': [tensor([[0.6614, 0.2669],[0.0617, 0.6213]], requires_grad=True)], 'lr': 0.1, 'momentum': 0, 'dampening': 0, 'weight_decay': 0, 'nesterov': False}]
optimizer.param_groups is
[{'params': [tensor([[0.6614, 0.2669],[0.0617, 0.6213]], requires_grad=True)], 'lr': 0.1, 'momentum': 0, 'dampening': 0, 'weight_decay': 0, 'nesterov': False}, {'params': [tensor([[-0.4519, -0.1661, -1.5228],[ 0.3817, -1.0276, -0.5631],[-0.8923, -0.0583, -0.1955]], requires_grad=True)], 'lr': 0.0001, 'momentum': 0, 'dampening': 0, 'weight_decay': 0, 'nesterov': False}]

(4)state_dict():获取优化器当前状态信息字典。

weight = torch.randn((2, 2), requires_grad=True)
weight.grad = torch.ones((2, 2))optimizer = optim.SGD([weight], lr=0.1, momentum=0.9)
opt_state_dict = optimizer.state_dict()print("state_dict before step:\n", opt_state_dict)for i in range(10):optimizer.step()print("state_dict after step:\n", optimizer.state_dict())torch.save(optimizer.state_dict(), os.path.join(BASE_DIR, "optimizer_state_dict.pkl"))

结果如下:

state_dict before step:{'state': {}, 'param_groups': [{'lr': 0.1, 'momentum': 0.9, 'dampening': 0, 'weight_decay': 0, 'nesterov': False, 'params': [2872948098296]}]}
state_dict after step:{'state': {2872948098296: {'momentum_buffer': tensor([[6.5132, 6.5132],[6.5132, 6.5132]])}}, 'param_groups': [{'lr': 0.1, 'momentum': 0.9, 'dampening': 0, 'weight_decay': 0, 'nesterov': False, 'params': [2872948098296]}]}

获取到了优化器当前状态的信息字典,其中那个2872948098296是存放权重的地址,并将这些参数信息保存为一个pkl文件:
在这里插入图片描述

(5)load_state_dict():加载状态信息字典。

optimizer = optim.SGD([weight], lr=0.1, momentum=0.9)
state_dict = torch.load(os.path.join(BASE_DIR, "optimizer_state_dict.pkl"))print("state_dict before load state:\n", optimizer.state_dict())
optimizer.load_state_dict(state_dict)
print("state_dict after load state:\n", optimizer.state_dict())

从刚刚保存参数的pkl文件中读取参数赋给一个新的空的优化器,结果为:

state_dict before load state:{'state': {}, 'param_groups': [{'lr': 0.1, 'momentum': 0.9, 'dampening': 0, 'weight_decay': 0, 'nesterov': False, 'params': [1838346925624]}]}
state_dict after load state:{'state': {1838346925624: {'momentum_buffer': tensor([[6.5132, 6.5132],[6.5132, 6.5132]])}}, 'param_groups': [{'lr': 0.1, 'momentum': 0.9, 'dampening': 0, 'weight_decay': 0, 'nesterov': False, 'params': [1838346925624]}]}

注:state_dict()与load_state_dict()一般经常用于模型训练中的保存和读取模型参数,防止断电等突发情况导致模型训练强行中断而前功尽弃。

三、学习率与动量

1.学习率learning rate

梯度下降:在这里插入图片描述
其中LR就是学习率,作用是控制更新的步伐,如果太大可能导致模型无法收敛或者是梯度爆炸,如果太小可能使得训练时间过长,需要调节。

2.动量、冲量Momentum

结合当前梯度与上一次更新信息,用于当前更新。
PyTorch中梯度下降的更新公式为:
在这里插入图片描述
其中:
在这里插入图片描述

  • Wi:第i次更新的参数。
  • lr:学习率。
  • Vi:更新量。
  • m:momentum系数。
  • g(Wi):Wi的梯度。

举个例子:

100这个时刻的更新量不仅与当前梯度有关,还与之前的梯度有关,只是越以前的对当前时刻的影响就越小。

momentum的作用主要是可以加速收敛。

四、十种常见的优化器(简单罗列)

目前对优化器的了解还不多,以后会继续跟进,这里就简单罗列一下:

  1. optim.SGD:随机梯度下降法
  2. optim.Adagrad:自适应学习率梯度下降法
  3. optim.RMSprop:Adagrad的改进
  4. optim.Adadelta:Adagrad的改进
  5. optim.Adam:RMSprop结合Momentum
  6. optim.Adamax:Adam增加学习率上限
  7. optim.SparseAdam:稀疏版的Adam
  8. optim.ASGD:随机平均梯度下降法
  9. optim.Rprop:弹性反向传播
  10. optim.LBFGS :BFGS的改进

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.mzph.cn/news/491681.shtml

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

最全芯片产业报告出炉,计算、存储、模拟IC一文扫尽

来源:智东西最近几年, 半导体产业风起云涌。 一方面, 中国半导体异军突起, 另一方面, 全球产业面临超级周期,加上人工智能等新兴应用的崛起,中美科技摩擦频发,全球半导体现状如何&am…

python向CSV文件写内容

f open(r"D:\test.csv", w) f.write(1,2,3\n) f.write(4,5,6\n) f.close() 注意:上面例子中的123456这6个数字会分别写入不同的单元格里,即以逗号作为分隔符将字符串内容分开放到不同单元格 上面例子的图: 如果要把变量的值放入…

PyTorch框架学习十四——学习率调整策略

PyTorch框架学习十四——学习率调整策略一、_LRScheduler类二、六种常见的学习率调整策略1.StepLR2.MultiStepLR3.ExponentialLR4.CosineAnnealingLR5.ReduceLRonPlateau6.LambdaLR在上次笔记优化器的内容中介绍了学习率的概念,但是在整个训练过程中学习率并不是一直…

JavaScript数组常用方法

转载于:https://www.cnblogs.com/kenan9527/p/4926145.html

蕨叶形生物刷新生命史,动物界至少起源于5.7亿年前

来源 :newsweek.com根据发表于《古生物学》期刊(Palaeontology)的一项研究,动物界可能比科学界所知更加古老。研究人员发现,一种名为“美妙春光虫”(Stromatoveris psygmoglena)的海洋生物在埃迪…

PyTorch框架学习十五——可视化工具TensorBoard

PyTorch框架学习十五——可视化工具TensorBoard一、TensorBoard简介二、TensorBoard安装及测试三、TensorBoard的使用1.add_scalar()2.add_scalars()3.add_histogram()4.add_image()5.add_graph()之前的笔记介绍了模型训练中的数据、模型、损失函数和优化器,下面将介…

CNN、RNN、DNN的内部网络结构有什么区别?

来源:AI量化百科神经网络技术起源于上世纪五、六十年代,当时叫感知机(perceptron),拥有输入层、输出层和一个隐含层。输入的特征向量通过隐含层变换达到输出层,在输出层得到分类结果。早期感知机的推动者是…

L2级自动驾驶量产趋势解读

来源:《国盛计算机组》L2 级自动驾驶离我们比想象的更近。18 年下半年部分 L2 车型已面世,凯迪拉克、吉利、长城、长安、上汽等均已推出了 L2 自动驾驶车辆。国内目前在售2872个车型,L2级功能渗透率平均超过25%,豪华车甚至超过了6…

PyTorch框架学习十六——正则化与Dropout

PyTorch框架学习十六——正则化与Dropout一、泛化误差二、L2正则化与权值衰减三、正则化之Dropout补充:这次笔记主要关注防止模型过拟合的两种方法:正则化与Dropout。 一、泛化误差 一般模型的泛化误差可以被分解为三部分:偏差、方差与噪声…

HDU 5510 Bazinga 暴力匹配加剪枝

Bazinga Time Limit: 20 Sec Memory Limit: 256 MB 题目连接 http://acm.hdu.edu.cn/showproblem.php?pid5510 Description Ladies and gentlemen, please sit up straight.Dont tilt your head. Im serious.For n given strings S1,S2,⋯,Sn, labelled from 1 to n, you shou…

PyTorch框架学习十七——Batch Normalization

PyTorch框架学习十七——Batch Normalization一、BN的概念二、Internal Covariate Shift(ICS)三、BN的一个应用案例四、PyTorch中BN的实现1._BatchNorm类2.nn.BatchNorm1d/2d/3d(1)nn.BatchNorm1d(2)nn.Bat…

人工智能影响未来娱乐的31种方式

来源:资本实验室 技术改变生活,而各种新技术每天都在重新定义我们的生活状态。技术改变娱乐,甚至有了互联网时代“娱乐至死”的警语。当人工智能介入我们的生活,特别是娱乐的时候,一切又将大为不同。尽管很多时候我们很…

素数与量子物理的结合能带来解决黎曼猜想的新可能吗?

来源:中国科学院数学与系统科学研究院翻译:墨竹校对:杨璐1972年,物理学家弗里曼戴森(Freeman Dyson)写了一篇名为《错失的机会》(Missed Opportunities)的文章。在该文中&#xff0c…

PyTorch框架学习十八——Layer Normalization、Instance Normalization、Group Normalization

PyTorch框架学习十八——Layer Normalization、Instance Normalization、Group Normalization一、为什么要标准化?二、BN、LN、IN、GN的异同三、Layer Normalization四、Instance Normalization五、Group Normalization上次笔记介绍了Batch Normalization以及它在Py…

重磅!苹果祭出大招:史上最强 Mac 发布,iPad OS 惊艳问世

来源:网络大数据伴随着时间的脚步进入到 6 月份,一年一度的苹果 WWDC 开发者大会又如期到来;这是苹果自举办开发者大会以来的第三十场 WWDC,因而有着不一样的意义。今年的 WWDC 在苹果库比蒂诺总部附近的 San Jose McEnery Convention Center…

PyTorch框架学习十九——模型加载与保存

PyTorch框架学习十九——模型加载与保存一、序列化与反序列化二、PyTorch中的序列化与反序列化1.torch.save2.torch.load三、模型的保存1.方法一:保存整个Module2.方法二:仅保存模型参数四、模型的加载1.加载整个模型2.仅加载模型参数五、断点续训练1.断…

新计算推动信息技术产业新发展?

来源:工信头条未来智能实验室是人工智能学家与科学院相关机构联合成立的人工智能,互联网和脑科学交叉研究机构。未来智能实验室的主要工作包括:建立AI智能系统智商评测体系,开展世界人工智能智商评测;开展互联网&#…

保存tensorboard的损失曲线为图片

损失loss一般是标量,损失曲线一般显示在TensorBoard的SCALARS下,如图所示: 如果想将损失曲线保存下来,选中左边“Show data download links”按钮,曲线下面就会有一个下载按钮,但是只能保存为SVG文件&#…

美国服务机器人技术路线图

来源:美国国家科学基金会服务机器人正在以高速的增长速度加速步入我们的日常生活。正是基于广阔的市场前景,美国国家科学基金会颁布了《美国机器人技术路线图》,其中服务机器人是其中的重点一章。服务机器人的主要应用领域服务机器人是一类用…

OpenCV与图像处理学习一——图像基础知识、读入、显示、保存图像、灰度转化、通道分离与合并

OpenCV与图像处理学习一——图像基础知识、读入、显示、保存图像、灰度转化、通道分离与合并一、图像基础知识1.1 数字图像的概念1.2 数字图像的应用1.3 OpenCV介绍二、图像属性2.1 图像格式2.2 图像尺寸2.2.1 读入图像2.2.2 显示图像2.2.3 保存图像2.3 图像分辨率和通道数2.3.…