神经网络优化(二) - 滑动平均

1 滑动平均概述

滑动平均(也称为 影子值 ):记录了每一个参数一段时间内过往值的平均,增加了模型的泛化性。

滑动平均通常针对所有参数进行优化:W 和 b,

简单地理解,滑动平均像是给参数加了一个影子,参数变化,影子缓慢追随。

滑动平均的表示公式为

影子 = 衰减率 * 影子 + ( 1 - 衰减率 ) * 参数

滑动平均值 = 衰减率 * 滑动平均值 + ( 1 - 衰减率 )* 参数

备注

影子初值 = 参数初值

衰减率 = min{ MOVING_AVERAGE_DECAY, (1+轮数) / (10 + 轮数 ) }

示例:

MOVING_AVERAGE_DECAY 为 0.99, 参数 w1 为 0,轮数 global_step 为 0,w1的滑动平均值为 0 。

参数w1更新为 1 时,则

 w1的滑动平均值 = min( 0.99, 1/10 ) * 0 + ( 1 - min( 0.99, 1/10 ) * 1 = 0.9

 假设轮数 global_step 为 100 时,参数 w1 更新为 10 时,则

w1滑动平均值 = min(0.99, 101/110) * 0.9 + ( 1 - min( 0.99, 101/110) * 10 = 1.644

再次运行

w1滑动平均值 = min(0.99, 101/110) * 1.644 + ( 1 - min( 0.99, 101/110) * 10 = 2.328

再次运行

w1滑动平均值 = 2.956

 

2 滑动平均在Tensorflow中的表示方式

第一步 实例化滑动平均类ema

ema = tf.train.ExponentialMovingAverage(MOVING_AVERAGE_DECAY(滑动平均衰减率),global_step(轮数计数器,表示当前轮数)
)

备注:

MOVING_AVERAGE_DECAY 滑动平均衰减率是超参数,一般设定的值比较大;

global_step - 轮数计数器,表示当前轮数,这个参数与其他计数器公用。

第二步 求算滑动平均节点ema_op

ema_op = ema.apply([])

ema.apply([ ]) 函数表示对 [ ] 中的所有数值求滑动平均。

示例:

ema_op = ema.apply(tf.trainable_variables())

每当运行此代码时,会对所以待优化参数进行求滑动平均运算。

第三步 具体实现方式

在工程应用中,我们通常会将计算滑动平均 ema_op 和训练过程 train_step 绑定在一起运行,使其合成一个训练节点,实现的代码如下

with tf.control_dependencies([ train_step, ema_op ]):train_op = tf.no_op(name = 'train')

 

另外:

查看某参数的滑动平均值

函数ema.average(参数名) --->  返回 ’ 参数名 ’ 的滑动平均值,

3 示例代码

# 待优化参数w1,不断更新w1参数,求w1的滑动平均(影子)import tensorflow as tf# 1. 定义变量及滑动平均类# 定义一个32位浮点变量并赋初值为0.0,
w1 = tf.Variable(0, dtype=tf.float32)# 轮数计数器,表示NN的迭代轮数,赋初始值为0,同时不可被优化(不参数训练)
global_step = tf.Variable(0, trainable=False)# 设定衰减率为0.99
MOVING_AVERAGE_DECAY = 0.99# 实例化滑动平均类
ema = tf.train.ExponentialMovingAverage(MOVING_AVERAGE_DECAY, global_step)# ema.apply()函数中的参数为待优化更新列表
# 每运行sess.run(ema_op)时,会对函数中的参数求算滑动平均值
# tf.trainable_variables()函数会自动将所有待训练的参数汇总为待列表
# 因该段代码中仅有w1一个参数,ema_op = ema.apply([w1])与下段代码等价
ema_op = ema.apply(tf.trainable_variables())# 2. 查看不同迭代中变量取值的变化。
with tf.Session() as sess:# 初始化init_op = tf.global_variables_initializer()sess.run(init_op)# 用ema.average(w1)获取w1滑动平均值 (要运行多个节点,作为列表中的元素列出,写在sess.run中)# 打印出当前参数w1和w1滑动平均值print("current global_step:", sess.run(global_step))print("current w1", sess.run([w1, ema.average(w1)]))# 参数w1的值赋为1sess.run(tf.assign(w1, 1))sess.run(ema_op)print("current global_step:", sess.run(global_step))print("current w1", sess.run([w1, ema.average(w1)]))# 更新global_step和w1的值,模拟出轮数为100时,参数w1变为10, 以下代码global_step保持为100,每次执行滑动平均操作,影子值会更新 sess.run(tf.assign(global_step, 100))sess.run(tf.assign(w1, 10))sess.run(ema_op)print("current global_step:", sess.run(global_step))print("current w1:", sess.run([w1, ema.average(w1)]))# 每次sess.run会更新一次w1的滑动平均值
    sess.run(ema_op)print("current global_step:", sess.run(global_step))print("current w1:", sess.run([w1, ema.average(w1)]))sess.run(ema_op)print("current global_step:", sess.run(global_step))print("current w1:", sess.run([w1, ema.average(w1)]))sess.run(ema_op)print("current global_step:" , sess.run(global_step))print("current w1:", sess.run([w1, ema.average(w1)]))sess.run(ema_op)print("current global_step:" , sess.run(global_step))print("current w1:", sess.run([w1, ema.average(w1)]))

运行

current global_step: 0
current w1 [0.0, 0.0]
current global_step: 0
current w1 [1.0, 0.9]
current global_step: 100
current w1: [10.0, 1.6445453]
current global_step: 100
current w1: [10.0, 2.3281732]
current global_step: 100
current w1: [10.0, 2.955868]
current global_step: 100
current w1: [10.0, 3.532206]
current global_step: 100
current w1: [10.0, 4.061389]

 

w1 的滑动平均值都向参数 w1 靠近。可见,滑动平均追随参数的变化而变化。

转载于:https://www.cnblogs.com/gengyi/p/9901502.html

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.mzph.cn/news/450826.shtml

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

Docker完全自学手册

阿里云大学免费课程:Docker完全自学手册课程介绍:Docker 是 PaaS 提供商 dotCloud 开源的一个基于 LXC 的高级容器引擎,源代码托管在 Github 上, 基于go语言并遵从Apache2.0协议开源。Docker 是一个开源的应用容器引擎,让开发者可…

Spring 之注解事务 @Transactional

前些天发现了一个巨牛的人工智能学习网站,通俗易懂,风趣幽默,忍不住分享一下给大家。点击跳转到教程。 先让我们看代码吧! 以下代码为在“Spring3事务管理——基于tx/aop命名空间的配置”基础上修改。首先修改applicationContext…

超级程序员神话

摘要:大部分的程序员在思想里都会某种程度的承认,承认自己只是一个普通的程序员,但这世界上确实有一些超级程序员,在一个为企业开发应用的程序员和一个为谷歌写搜索算法的程序员之间,或和一个开发用来控制读写头从磁盘…

HashMap30连问,彻底搞懂HashMap

文章目录一、背景知识1、什么是Map?2、什么是Hash?3、什么是哈希表?4、什么是HashMap?5、如何使用HashMap?6、HashMap有哪些核心参数?7、HashMap与HashTable的对比?8、HashMap和HashSet的区别?…

博弈论的算法总结

开头先啰嗦一句:想学好博弈,必然要花费很多的时间,深入学习,不要存在一知半解,应该是一看到题目,就想到博弈的类型。 以及,想不断重复不断重复,做大量各大oj网站的题目,最…

Slog55_lua面向对象之lua类

Slog55_lua面向对象之lua类 ArthurSlog SLog-55 Year1 GuangzhouChina Aug 30th 2018 微信扫描二维码,关注我的公众号GitHub 掘金主页 简书主页 segmentfault 现实中的事情不是根据人的喜好而定的 比如长在你嘴里的智齿 大部分情况下 你会因为自己&#xff0…

Spring中的@scope注解

前些天发现了一个巨牛的人工智能学习网站,通俗易懂,风趣幽默,忍不住分享一下给大家。点击跳转到教程。 Scope 简单点说就是用来指定bean的作用域作用域 (官方解释:scope用来声明IOC容器中的对象应该处的限定场景或者…

编程语言大比拼——谁的效率高

摘要:C、C、Java这几个屹立不倒的开发语言,如果以功能点作为单位的话,谁的效率最高呢?如果在项目初期就能确定功能点数量,那么就可以很好的预测项目完成时间。这一点是不是对你很有帮助呢? 一份6000个项目的…

Hadoop之Flume详解

1、日志采集框架Flume   1.1 Flume介绍     Flume是一个分布式、可靠、和高可用的海量日志采集、聚合和传输的系统。     Flume可以采集文件,socket数据包等各种形式源数据,又可以将采集到的数据输出到HDFS、hbase、hive、     kafka等众多…

搞懂Java的反射机制

搞懂Java的反射机制 1.什么是反射? java的反射机制是指可以在运行状态下获取类和对象的所有属性和方法。 2.反射的作用? 1、在运行时获取一个类/对象的成员变量和方法 2、在运行时创建一个类的对象 3、在运行时判断一个对象是否属于一个类 3.反射有哪些…

表单oninput和onchange事件区别

oninput事件是元素value发生变化是立刻触发,而onchange是元素发生变化并且失去焦点时才会触发。 转载于:https://www.cnblogs.com/ykli/p/9565601.html

Struts2中<s:iterator>基本用法及示例

前些天发现了一个巨牛的人工智能学习网站&#xff0c;通俗易懂&#xff0c;风趣幽默&#xff0c;忍不住分享一下给大家。点击跳转到教程。 Struts2中<s:iterator>基本用法及示例 Iterator用于遍历集合&#xff08;java.util.Collection&#xff09;或枚举值&#xff08;j…

如何使用postman做接口测试

1、get请求传参 只要是get请求都可以在浏览器中直接发&#xff1a; 在访问地址后面拼 ?keyvalue&keyvalue 例如&#xff1a;在浏览器中直接输入访问地址&#xff0c;后面直接拼需要传给服务器的参数http://api.nnzhp.cn/api/user/stu_info?stu_name小黑2、post请求&…

【狂神说】分析前后端分离开源项目?

文章目录1.如何分析开源项目项目简介项目源码2.观察开源项目3.开源项目下载4.跑起来是第一步5.前后端分离项目固定套路6.如何找到一个开源项目1.如何分析开源项目 学习的方式&#xff1a; 不知道这个代码怎么来的这个代码跑不起来这个项目对我们有什么帮助&#xff0c;不会模…

设计公共API的六个注意事项

摘要&#xff1a;俗话说&#xff1a;“好东西就要贡献出来和大家一起分享”&#xff0c;尤其是在互联网业务高度发达的今天&#xff0c;如果你的创业公司提供了一项很酷的技术或者服务&#xff0c;并且其他用户也非常喜欢该产品&#xff0c;在这种情况下&#xff0c;最好的解决…

go 交叉编译

golang中windows交叉编译 env GOOSlinux GOARCHamd64 go build .打包镜像 FROM alpineMAINTAINER "congge"ADD ./casino_niuniu /usr/local/casino_niuniu/bin/casino_niuniu ADD ./templates /usr/loca/lcasino_niuniu/bin/templates ADD ./public /usr/local/casin…

IntelliJ Idea 2017 免费激活方法

见&#xff1a;https://www.cnblogs.com/suiyueqiannian/p/6754091.html 1. 到网站 http://idea.lanyus.com/ 获取注册码。 2.填入下面的license server: http://intellij.mandroid.cn/   http://idea.imsxm.com/   http://idea.iteblog.com/key.php 以上方法验证均可以

P3193 [HNOI2008]GT考试

传送门 容易看出是道DP 考虑一位一位填数字 设 f [ i ] [ j ] 表示填到第 i 位&#xff0c;在不吉利串上匹配到第 j 位时不出现不吉利数字的方案数 设 g [ i ] [ j ] 表示不吉利串匹配到第 i 位&#xff0c;再添加一个数字&#xff0c;使串匹配到第 j 位的方案数 那么方程显然为…

LeetCode刷题攻略

目录 一、LeetCode简介 二、刷leetcode的主要目的 三、常用的数据结构 四、常用的算法思想 五、选择算法题 1、刷题选择 2、刷题方法 方法一&#xff1a;顺序法 方法二&#xff1a;标签法 方法三&#xff1a;随机法 方法四&#xff1a;必杀法 六、刷题攻略 TIP 1&…

SQLserver数据库反编译生成Hibernate实体类和映射文件

一、建立项目和sqlserver数据库 eclipse&#xff0c;我使用的版本是neon3 二、Data Source Explorer 选择OK 在data source Explorer的Database Connections 选择New 填写好General的连接信息 新建New Driver Definition 填写完选择OK 选择刚才的Drivers Test Connetion测试 N…