tensorflow中batch normalization的用法

 转载网址:如果侵权,联系我删除

https://www.cnblogs.com/hrlnw/p/7227447.html

https://www.cnblogs.com/eilearn/p/9780696.html

https://www.cnblogs.com/stingsl/p/6428694.html

神经网络学习过程本质就是为了学习数据分布,一旦训练数据与测试数据的分布不同,那么网络的泛化能力也大大降低;另外一方面,一旦每批训练数据的分布各不相同(batch 梯度下降),那么网络就要在每次迭代都去学习适应不同的分布,这样将会大大降低网络的训练速度,这也正是为什么我们需要对数据都要做一个归一化预处理的原因。

对于深度网络的训练是一个复杂的过程,只要网络的前面几层发生微小的改变,那么后面几层就会被累积放大下去。一旦网络某一层的输入数据的分布发生改变,那么这一层网络就需要去适应学习这个新的数据分布,所以如果训练过程中,训练数据的分布一直在发生变化,那么将会影响网络的训练速度。

我们知道网络一旦train起来,那么参数就要发生更新,除了输入层的数据外(因为输入层数据,我们已经人为的为每个样本归一化),后面网络每一层的输入数据分布是一直在发生变化的,因为在训练的时候,前面层训练参数的更新将导致后面层输入数据分布的变化。以网络第二层为例:网络的第二层输入,是由第一层的参数和input计算得到的,而第一层的参数在整个训练过程中一直在变化,因此必然会引起后面每一层输入数据分布的改变。我们把网络中间层在训练过程中,数据分布的改变称之为:“Internal  Covariate Shift”。Paper所提出的算法,就是要解决在训练过程中,中间层数据分布发生改变的情况,于是就有了Batch  Normalization,这个牛逼算法的诞生。

1.原理

公式如下:

y=γ(x-μ)/σ+β

其中x是输入,y是输出,μ是均值,σ是方差,γ和β是缩放(scale)、偏移(offset)系数。

一般来讲,这些参数都是基于channel来做的,比如输入x是一个16*32*32*128(NWHC格式)的feature map,那么上述参数都是128维的向量。其中γ和β是可有可无的,有的话,就是一个可以学习的参数(参与前向后向),没有的话,就简化成y=(x-μ)/σ。而μ和σ,在训练的时候,使用的是batch内的统计值,测试/预测的时候,采用的是训练时计算出的滑动平均值。

 

2.tensorflow中使用

tensorflow中batch normalization的实现主要有下面三个:

tf.nn.batch_normalization

tf.layers.batch_normalization

tf.contrib.layers.batch_norm

封装程度逐个递进,建议使用tf.layers.batch_normalization或tf.contrib.layers.batch_norm,因为在tensorflow官网的解释比较详细。我平时多使用tf.layers.batch_normalization,因此下面的步骤都是基于这个。

 

3.训练

训练的时候需要注意两点,(1)输入参数training=True,(2)计算loss时,要添加以下代码(即添加update_ops到最后的train_op中)。这样才能计算μ和σ的滑动平均(测试时会用到)

  update_ops = tf.get_collection(tf.GraphKeys.UPDATE_OPS)with tf.control_dependencies(update_ops):train_op = optimizer.minimize(loss)

 

4.测试

测试时需要注意一点,输入参数training=False,其他就没了

 

5.预测

预测时比较特别,因为这一步一般都是从checkpoint文件中读取模型参数,然后做预测。一般来说,保存checkpoint的时候,不会把所有模型参数都保存下来,因为一些无关数据会增大模型的尺寸,常见的方法是只保存那些训练时更新的参数(可训练参数),如下:

var_list = tf.trainable_variables()
saver = tf.train.Saver(var_list=var_list, max_to_keep=5)

 

但使用了batch_normalization,γ和β是可训练参数没错,μ和σ不是,它们仅仅是通过滑动平均计算出的,如果按照上面的方法保存模型,在读取模型预测时,会报错找不到μ和σ。更诡异的是,利用tf.moving_average_variables()也没法获取bn层中的μ和σ(也可能是我用法不对),不过好在所有的参数都在tf.global_variables()中,因此可以这么写:

var_list = tf.trainable_variables()
g_list = tf.global_variables()
bn_moving_vars = [g for g in g_list if 'moving_mean' in g.name]
bn_moving_vars += [g for g in g_list if 'moving_variance' in g.name]
var_list += bn_moving_vars
saver = tf.train.Saver(var_list=var_list, max_to_keep=5)

按照上述写法,即可把μ和σ保存下来,读取模型预测时也不会报错,当然输入参数training=False还是要的。

注意上面有个不严谨的地方,因为我的网络结构中只有bn层包含moving_mean和moving_variance,因此只根据这两个字符串做了过滤,如果你的网络结构中其他层也有这两个参数,但你不需要保存,建议使用诸如bn/moving_mean的字符串进行过滤。

 

2018.4.22更新

提供一个基于mnist的示例,供大家参考。包含两个文件,分别用于train/test。注意bn_train.py文件的51-61行,仅保存了网络中的可训练变量和bn层利用统计得到的mean和var。注意示例中需要下载mnist数据集,要保持电脑可以联网。

import tensorflow as tf
import os
from tensorflow.examples.tutorials.mnist import input_datatf.logging.set_verbosity(tf.logging.INFO)if __name__ == '__main__':mnist = input_data.read_data_sets('mnist', one_hot=True)x = tf.placeholder(tf.float32, [None, 784])y_ = tf.placeholder(tf.float32, [None, 10])image = tf.reshape(x, [-1, 28, 28, 1])conv1 = tf.layers.conv2d(image, filters=32, kernel_size=[3, 3], strides=[1, 1], padding='same',activation=tf.nn.relu,kernel_initializer=tf.truncated_normal_initializer(stddev=0.1),name='conv1')bn1 = tf.layers.batch_normalization(conv1, training=True, name='bn1')pool1 = tf.layers.max_pooling2d(bn1, pool_size=[2, 2], strides=[2, 2], padding='same', name='pool1')conv2 = tf.layers.conv2d(pool1, filters=64, kernel_size=[3, 3], strides=[1, 1], padding='same',activation=tf.nn.relu,kernel_initializer=tf.truncated_normal_initializer(stddev=0.1),name='conv2')bn2 = tf.layers.batch_normalization(conv2, training=True, name='bn2')pool2 = tf.layers.max_pooling2d(bn2, pool_size=[2, 2], strides=[2, 2], padding='same', name='pool2')flatten_layer = tf.contrib.layers.flatten(pool2, 'flatten_layer')weights = tf.get_variable(shape=[flatten_layer.shape[-1], 10], dtype=tf.float32,initializer=tf.truncated_normal_initializer(stddev=0.1), name='fc_weights')biases = tf.get_variable(shape=[10], dtype=tf.float32,initializer=tf.constant_initializer(0.0), name='fc_biases')logit_output = tf.nn.bias_add(tf.matmul(flatten_layer, weights), biases, name='logit_output')cross_entropy = tf.reduce_mean(tf.nn.softmax_cross_entropy_with_logits(labels=y_, logits=logit_output))pred_label = tf.argmax(logit_output, 1)label = tf.argmax(y_, 1)accuracy = tf.reduce_mean(tf.cast(tf.equal(pred_label, label), tf.float32))update_ops = tf.get_collection(tf.GraphKeys.UPDATE_OPS)global_step = tf.get_variable('global_step', [], dtype=tf.int32,initializer=tf.constant_initializer(0), trainable=False)learning_rate = tf.train.exponential_decay(learning_rate=0.1, global_step=global_step, decay_steps=5000,decay_rate=0.1, staircase=True)opt = tf.train.AdadeltaOptimizer(learning_rate=learning_rate, name='optimizer')with tf.control_dependencies(update_ops):grads = opt.compute_gradients(cross_entropy)train_op = opt.apply_gradients(grads, global_step=global_step)tf_config = tf.ConfigProto()tf_config.gpu_options.allow_growth = Truetf_config.allow_soft_placement = Truesess = tf.InteractiveSession(config=tf_config)sess.run(tf.global_variables_initializer())# only save trainable and bn variablesvar_list = tf.trainable_variables()if global_step is not None:var_list.append(global_step)g_list = tf.global_variables()bn_moving_vars = [g for g in g_list if 'moving_mean' in g.name]bn_moving_vars += [g for g in g_list if 'moving_variance' in g.name]var_list += bn_moving_varssaver = tf.train.Saver(var_list=var_list,max_to_keep=5)# save all variables# saver = tf.train.Saver(max_to_keep=5)if tf.train.latest_checkpoint('ckpts') is not None:saver.restore(sess, tf.train.latest_checkpoint('ckpts'))train_loops = 10000for i in range(train_loops):batch_xs, batch_ys = mnist.train.next_batch(32)_, step, loss, acc = sess.run([train_op, global_step, cross_entropy, accuracy],feed_dict={x: batch_xs, y_: batch_ys})if step % 100 == 0:  # print training infolog_str = 'step:%d \t loss:%.6f \t acc:%.6f' % (step, loss, acc)tf.logging.info(log_str)if step % 1000 == 0:  # save current modelsave_path = os.path.join('ckpts', 'mnist-model.ckpt')saver.save(sess, save_path, global_step=step)sess.close()
import tensorflow as tf
from tensorflow.examples.tutorials.mnist import input_datatf.logging.set_verbosity(tf.logging.INFO)if __name__ == '__main__':mnist = input_data.read_data_sets('mnist', one_hot=True)x = tf.placeholder(tf.float32, [None, 784])y_ = tf.placeholder(tf.float32, [None, 10])image = tf.reshape(x, [-1, 28, 28, 1])conv1 = tf.layers.conv2d(image, filters=32, kernel_size=[3, 3], strides=[1, 1], padding='same',activation=tf.nn.relu,kernel_initializer=tf.truncated_normal_initializer(stddev=0.1),name='conv1')bn1 = tf.layers.batch_normalization(conv1, training=False, name='bn1')pool1 = tf.layers.max_pooling2d(bn1, pool_size=[2, 2], strides=[2, 2], padding='same', name='pool1')conv2 = tf.layers.conv2d(pool1, filters=64, kernel_size=[3, 3], strides=[1, 1], padding='same',activation=tf.nn.relu,kernel_initializer=tf.truncated_normal_initializer(stddev=0.1),name='conv2')bn2 = tf.layers.batch_normalization(conv2, training=False, name='bn2')pool2 = tf.layers.max_pooling2d(bn2, pool_size=[2, 2], strides=[2, 2], padding='same', name='pool2')flatten_layer = tf.contrib.layers.flatten(pool2, 'flatten_layer')weights = tf.get_variable(shape=[flatten_layer.shape[-1], 10], dtype=tf.float32,initializer=tf.truncated_normal_initializer(stddev=0.1), name='fc_weights')biases = tf.get_variable(shape=[10], dtype=tf.float32,initializer=tf.constant_initializer(0.0), name='fc_biases')logit_output = tf.nn.bias_add(tf.matmul(flatten_layer, weights), biases, name='logit_output')cross_entropy = tf.reduce_mean(tf.nn.softmax_cross_entropy_with_logits(labels=y_, logits=logit_output))pred_label = tf.argmax(logit_output, 1)label = tf.argmax(y_, 1)accuracy = tf.reduce_mean(tf.cast(tf.equal(pred_label, label), tf.float32))tf_config = tf.ConfigProto()tf_config.gpu_options.allow_growth = Truetf_config.allow_soft_placement = Truesess = tf.InteractiveSession(config=tf_config)saver = tf.train.Saver()if tf.train.latest_checkpoint('ckpts') is not None:saver.restore(sess, tf.train.latest_checkpoint('ckpts'))else:assert 'can not find checkpoint folder path!'loss, acc = sess.run([cross_entropy,accuracy],feed_dict={x: mnist.test.images,y_: mnist.test.labels})log_str = 'loss:%.6f \t acc:%.6f' % (loss, acc)tf.logging.info(log_str)sess.close()

 

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.mzph.cn/news/493474.shtml

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

2018全球最值得关注的60家半导体公司,7家中国公司新上榜 | 年度榜单

编译 | 张玺 四月来源:机器之能由《EE Times》每年评选全球值得关注的 60 家新创半导体公司排行榜——『Silicon 60』,今年已经迈向第 19 届,今年的关键词仍然是「机器学习」(machine learning),它正以硬件支持的运算形式强势崛起…

c语言,递归翻转一个单链表,c实现单链表

目的&#xff1a;主要是练习c里面单链表的实现&#xff0c;递归思想复习; #include <stdlib.h> #include <stdio.h>typedef struct _Node{//定义单链表的节点int value;struct _Node *next; }Node;Node* link(int len){//新建一个单链表int i0;Node* head (Node*)…

单片机实验报告-串口实验

一.实验目的 1. 掌握 51 单片机串口工作原理。 2. 掌握 51 单片机串口初始化编程。 3. 掌握 51 单片机串口的软硬件编程。 二.实验设备和器件 1.KEIL软件 2.PROTEUS仿真软件 3.伟福实验箱 三&#xff0e;实验内容 &#xff08;1&#xff09;编程实现&#xff1a…

学习率周期性变化

学习率周期性变化&#xff0c;能后解决陷入鞍点的问题&#xff0c;更多的方式请参考https://github.com/bckenstler/CLR base_lr:最低的学习率 max_lr:最高的学习率 step_size&#xff1a;&#xff08;2-8&#xff09;倍的每个epoch的训练次数。 scale_fn(x)&#xff1a;自…

清华发布《人工智能AI芯片研究报告》,一文读懂人才技术趋势

来源&#xff1a;Future智能摘要&#xff1a;大数据产业的爆炸性增长下&#xff0c;AI 芯片作为人工智能时代的技术核心之一&#xff0c;决定了平台的基础架构和发展生态。 近日&#xff0c;清华大学推出了《 人工智能芯片研究报告 》&#xff0c;全面讲解人工智能芯片&#xf…

开发者账号申请 真机调试 应用发布

技术博客http://www.cnblogs.com/ChenYilong/ 新浪微博http://weibo.com/luohanchenyilong 开发者账号申请 真机调试 应用发布 技术博客http://www.cnblogs.com/ChenYilong/新浪微博http://weibo.com/luohanchenyilong 要解决的问题 • 开发者账号申请 • 真机调试 • 真机调…

单片机实验-DA实验

一、实验目的 1、了解 D/A 转换的基本原理。 2、了解 D/A 转换芯片 0832 的性能及编程方法。 3、了解单片机系统中扩展 D/A 转换的基本方法。 二.实验设备和器件 1.KEIL软件 2.实验箱 三&#xff0e;实验内容 利用 DAC0832&#xff0c;编制程序产生锯齿波、三角波、正弦…

进化三部曲,从互联网大脑发育看产业互联网的未来

摘要&#xff1a;从互联网的左右大脑发育看&#xff0c;产业互联网可以看做互联网的下半场&#xff0c;但从互联网大脑的长远发育看&#xff0c;互联网依然处于大脑尚未发育成熟的婴儿时期&#xff0c;未来还需要漫长的时间发育。参考互联网右大脑的发育历程&#xff0c;我们判…

pycharm远程连接服务器(docker)调试+ssh连接多次报错

一&#xff0c;登入服务器建docker nvidia-docker run -it -v ~/workspace/:/workspace -w /workspace/ --namefzh_tf --shm-size 8G -p 1111:22 -p 1112:6006 -p 1113:8888 tensorflow/tensorflow:1.4.0-devel-gpu bash 二&#xff0c;开ssh服务 apt-get update apt-get i…

Verilog HDL语言设计4个独立的非门

代码&#xff1a; module yanxu11(in,out); input wire[3:0] in; output reg[3:0] out; always (in) begin out[0]~in[0]; out[1]~in[1]; out[2]~in[2]; out[3]~in[3]; end endmodule timescale 1ns/1ns module test(); reg[3:0] in; wire[3:0] out; yanxu11 U(…

深度长文:表面繁荣之下,人工智能的发展已陷入困境

来源&#xff1a;36氪编辑&#xff1a;郝鹏程摘要&#xff1a;《连线》杂志在其最近发布的12月刊上&#xff0c;以封面故事的形式报道了人工智能的发展状况。现在&#xff0c;深度学习面临着无法进行推理的困境&#xff0c;这也就意味着&#xff0c;它无法让机器具备像人一样的…

Facebook190亿美元收购WhatsApp

Facebook收购WhatsApp&#xff0c;前后只花费10天时间。这是Facebook迄今规模最大的一笔收购&#xff0c;可能也是史上最昂贵的一笔针对靠私人风投起家的企业的收购案。 2月9日&#xff0c;马克•扎克伯格(Mark Zuckerberg)与WhatsApp的创始人会面&#xff0c;到本周三&#xf…

Verilog HDL语言设计一个比较电路

设计一个比较电路&#xff0c;当输入的一位8421BCD码大于4时&#xff0c;输出为1&#xff0c;否则为0&#xff0c;进行功能仿真&#xff0c;查看仿真结果&#xff0c;将Verilog代码和仿真波形图整理入实验报告。 代码&#xff1a; module yanxu12(in,out); input wire[3:0] i…

交叉熵

1.公式 用sigmoid推导 上式做一下转换&#xff1a; y 视为类后验概率 p(y 1 | x)&#xff0c;则上式可以写为&#xff1a; 则有&#xff1a; 将上式进行简单综合&#xff0c;可写成如下形式&#xff1a; 写成对数形式就是我们熟知的交叉熵损失函数了&#xff0c;这也是交叉熵…

第5章 散列

我们在第4章讨论了查找树ADT&#xff0c;它允许对一组元素进行各种操作。本章讨论散列表(hash table)ADT&#xff0c;不过它只支持二叉查找树所允许的一部分操作。 散列表的实现常常叫作散列(hashing)。散列是一种以常数平均时间执行插入、删除和查找的技术。但是&#xff0c;那…

谷歌自动驾驶是个大坑,还好中国在构建自己的智能驾驶大系统

来源&#xff1a;张国斌中国有堪称全球最复杂的路况&#xff0c;例如上图是去年投入使用的重庆黄桷湾立交桥上下共5层&#xff0c;共20条匝道&#xff0c;堪称中国最复杂立交桥之最&#xff0c;据称走错一个路口要在这里一日游&#xff0c;这样的立交桥如果让谷歌无人驾驶车上去…

Qt 智能指针学习

原地址&#xff1a;http://blog.csdn.net/dbzhang800/article/details/6403285 从内存泄露开始&#xff1f; 很简单的入门程序&#xff0c;应该比较熟悉吧 ^_^ #include <QApplication> #include <QLabel>int main(int argc, char *argv[]) {QApplication app(argc…

Verilog HDL语言设计计数器+加法器

完成课本例题4.12&#xff0c;进行综合和仿真&#xff08;包括功能仿真和时序仿真&#xff09;&#xff0c;查看仿真结果&#xff0c;将Verilog代码和仿真波形图整理入实验报告。 功能文件&#xff1a; module shiyan1(out,reset,clk); input reset,clk; output reg[3:0] ou…

自动驾驶寒冬与否,关键看“芯”

来源&#xff1a;智车科技摘要&#xff1a;2018年&#xff0c;全世界瞩目的半导体行业大事件无疑是高通收购恩智浦了。虽然&#xff0c;最终这笔收购案以失败结尾&#xff0c;但高通的收购恩智浦的意图就是出自于拓展汽车芯片市场。智能汽车芯片的重要性也得以突显。前不久&…

如何知道自己的CPU支持SLAT

因为WP8 SDK发布&#xff0c;很多WP8的开发者们也开始陆续安装WP8的SDK的&#xff0c;然而安装WP8的SDK有很多软件和硬件的要求&#xff0c;其中有一个就是——要求CPU支持二级地址转换&#xff08;SLAT&#xff09;&#xff0c;如果CPU不支持二级地址转换的话&#xff0c;在电…