13 Tensorflow机制(翻译)

    代码: tensorflow/examples/tutorials/mnist/

    本文的目的是来展示如何使用Tensorflow训练和评估手写数字识别问题。本文的观众是那些对使用Tensorflow进行机器学习感兴趣的人。

    本文的目的并不是讲解机器学习。

    请确认您已经安装了Tensorflow。

 

    教程文件

文件作用
mnist.py用来创建一个完全连接的MNIST模型。
fully_connected_feed.py使用下载的数据集训练模型。

    运行fully_connected_feed.py文件开始训练。

python fully_connected_feed.py

 

    准备数据

    MNIST是机器学习的一个经典问题。这个问题是识别28*28像素图片上的数字,从0到9。

    更多信息,请参考Yann LeCun's MNIST page 或者 Chris Olah's visualizations of MNIST。

 

    数据下载

    在run_training()方法之前,input_data.read_data_sets()方法可以让数据下载到本机训练文件夹,解压数据并返回一个DataSet实例。

data_sets = input_data.read_data_sets(FLAGS.train_dir, FLAGS.fake_data)

    注意:fake_data是用来进行单元测试的,读者可以忽略。

数据集作用
data_sets.train55000图片和标签,用来训练。
data_sets.validation5000图片和标签,用来在迭代中校验模型准确度。
data_sets.test10000图片和标签,用来测试训练模型准确度。

   

    输入和占位符

    placeholder_inputs()函数创建两个tf.placeholder,用来定义输入的形状,包括fetch_size。

images_placeholder = tf.placeholder(tf.float32, shape=(batch_size, mnist.IMAGE_PIXELS))
labels_placeholder = tf.placeholder(tf.int32, shape=(batch_size))

    在训练循环中,图片和标签数据集会被切分成batch_size大小,跟占位符匹配,然后通过feed_dict参数传递到sess.run()方法中。

 

    创建图

    创建占位符后,mnist.py文件中会通过三个步骤来创建图:inference(), loss(), 和training()。

  1. inference() - 运行网络来进行预测。
  2. loss() - 用来计算损失值。
  3. training() - 计算梯度。

    inference层

    inference()函数创建图,返回预测结果。

    它把图片占位符当作输入,并在上面构建一对完全连接的层,使用ReLU激活后,连接一个10个节点的线性层。

    每一层都位于tf.name_scope声明的命名空间中。

with tf.name_scope('hidden1'):

    在该命名空间中,权重和偏置会产生tf.Variable实例,并具有所需的形状。

weights = tf.Variable(tf.truncated_normal([IMAGE_PIXELS, hidden1_units], stddev=1.0 / math.sqrt(float(IMAGE_PIXELS))), name='weights')
biases = tf.Variable(tf.zeros([hidden1_units]), name='biases')

    例如,这些会在hidden1命名空间中创建,那么权重的唯一名称为“hidden1/weights”。

    每个变量使用初始化器作为构造函数。

    通常,权重会使用tf.truncated_normal(截尾正态分布)作为初始化器,它是一个2D张量,第一个参数表示该层中的神经元数,第二个表示它连接的层中的神经元数。再第一层hidden1中,权限矩阵的大小是[图片像素, hidden1神经元数],因为该权重连接图片输入。tf.truncated_normal初始化器会根据平均值和标准差产生一些随机数。

  然后,偏置会使用tf.zeros作为初始化器,保证开始时所有数都是0。它们的形状跟它们连接的层的神经元一样。

  该图的三个主要运算:两个tf.nn.relu操作(包括隐层中的一个tf.matmul操作)和一个额外的tf.matmul操作。然后依次创建,连接到输入占位符或上一层的输出张量上。

 

hidden1 = tf.nn.relu(tf.matmul(images, weights) + biases)

 

hidden2 = tf.nn.relu(tf.matmul(hidden1, weights) + biases)
logits = tf.matmul(hidden2, weights) + biases

    最后,logits张量包含输出结果。

 

    损失

    loss()函数通过添加所需的损失操作来进一步构建图形。

    首先,将labels_placeholder的值转换为64位整数。 然后,添加tf.nn.sparse_softmax_cross_entropy_with_logits操作,以自动从labels_placeholder产生标签,并将inference()函数的输出与这些标签进行比较。

   

labels = tf.to_int64(labels)
cross_entropy = tf.nn.sparse_softmax_cross_entropy_with_logits(labels=labels, logits=logits, name='xentropy')    

    然后使用tf.reduce_mean将batch维度(第一维)的交叉熵的平均数作为总损耗。

loss = tf.reduce_mean(cross_entropy, name='xentropy_mean')

    然后返回包含损失值的张量。

    注意:交叉熵是信息论中的一个想法,它使我们能够描述神经网络的预测有多糟糕。有关更多信息,请阅读博客文章Visual Information Theory(http://colah.github.io/posts/2015-09-Visual-Information/)
    训练
    training()函数通过梯度下降法计算最小损失。
    首先,它从loss()函数中获取损失张量,并将其传递给tf.summary.scalar,该函数用于在与tf.summary.FileWriter一起使用时将事件生成摘要。
   
tf.summary.scalar('loss', loss)

    接下来,我们实例化一个tf.train.GradientDescentOptimizer,进行梯度下降算法。

optimizer = tf.train.GradientDescentOptimizer(learning_rate)

    然后,我们定义一个变量,用来作为全局训练步骤的计数器,并且tf.train.Optimizer.minimize op用于更新系统中的可训练权重,并增加全局步长。 通常,这个操作被称为train_op. 它是由TensorFlow会话运行的,以便引导一个完整的训练步骤。

global_step = tf.Variable(0, name='global_step', trainable=False)
train_op = optimizer.minimize(loss, global_step=global_step)

   

    训练模型

    构建图形后,可以在full_connected_feed.py中由用户代码控制的循环中进行迭代训练和评估。

    图

    在run_training()函数的顶部,其中的命令指示所有构建的操作都与默认的全局tf.Graph实例相关联。

with tf.Graph().as_default():

    tf.Graph是可以作为一组一起执行的操作的集合。 大多数TensorFlow用户只需要依赖于单个默认图形。
    更复杂的使用多个图形是可能的,但超出了这个简单教程的范围。

    会话

    一旦所有的构建准备工作已经完成并且生成了所有必要的操作,就会创建一个tf.Session来运行图形。

sess = tf.Session()

    或者,可以将会话生成到某个作用域中:

with tf.Session() as sess:

    会话的空参数表示此代码将附加到默认本地会话(或创建尚未创建)。
    在创建会话之后,所有的tf.Variable实例都通过在初始化操作中调用tf.Session.run来初始化。

init = tf.global_variables_initializer()
sess.run(init)

    tf.Session.run方法将进行参数传递操作。在这个调用中,只进行变量的初始值。 图的其余部分都不在这里运行; 这在下面的训练循环中运行。

 

    训练循环

    在会话初始化变量后,可以开始训练。
    用户代码控制每一步的训练,最简单的循环可以是:

for step in xrange(FLAGS.max_steps):sess.run(train_op)

    但是,本教程稍微复杂一些,因为它还必须分割每个步骤的输入数据,以匹配先前生成的占位符。

   

    数据输入到图

    对于每个步骤,代码将生成一个Feed字典,其中包含一组数据,用于训练,由其所对应的占位符操作输入。
    在fill_feed_dict()函数中,查询给定的DataSet用于其下一个batch_size图像和标签集,填充与占位符匹配的张量,其中包含下一个图像和标签。

images_feed, labels_feed = data_set.next_batch(FLAGS.batch_size, FLAGS.fake_data)

    然后生成一个python字典对象,其中占位符作为键,代表性的Feed张量作为值。  

feed_dict = {images_placeholder: images_feed,labels_placeholder: labels_feed,
}

    这将被传递给sess.run()函数的feed_dict参数,以供该训练循环使用。

   

    检查状态

    该代码指定在运行调用中获取的两个值:[train_op,loss]。

for step in xrange(FLAGS.max_steps):feed_dict = fill_feed_dict(data_sets.train,images_placeholder,labels_placeholder)_, loss_value = sess.run([train_op, loss],feed_dict=feed_dict)

    因为要获取两个值,所以sess.run()返回一个包含两个项的元组。 要提取的值列表中的每个Tensor对应于返回的元组中的numpy数组,在该训练步骤中填充该张量的值。 由于train_op是没有输出值的操作,返回的元组中的相应元素为None,因此被丢弃。 然而,如果模型在训练过程中发生分歧,则损失张量的值可能变为NaN,因此我们捕获该值用于记录。
    假设没有NaN,训练运行良好,训练循环还会每100个步骤打印一个简单的状态文本,让用户知道训练状态。

if step % 100 == 0:print('Step %d: loss = %.2f (%.3f sec)' % (step, loss_value, duration))

   

    状态可视化

    为了输出TensorBoard使用的事件文件,在图形构建阶段,所有的摘要(在这种情况下只有一个)被收集到一个Tensor中。

summary = tf.summary.merge_all()

    然后在创建会话之后,可以将tf.summary.FileWriter实例化为写入事件文件,其中包含图形本身和摘要的值。

summary_writer = tf.summary.FileWriter(FLAGS.train_dir, sess.graph)

    最后,每次评估摘要并将输出传递给add_summary()函数时,事件文件将被更新为新的摘要值。

summary_str = sess.run(summary, feed_dict=feed_dict)
summary_writer.add_summary(summary_str, step)

    当写入事件文件时,可以针对训练文件夹运行TensorBoard,以显示摘要中的值。

    注意:有关如何构建和运行Tensorboard的更多信息,请参阅随附的教程Tensorboard:可视化学习。

   

    保存检查点

    为了输出一个检查点文件,可以用于稍后恢复模型进行进一步的训练或评估,我们实例化一个tf.train.Saver。

saver = tf.train.Saver()

    在训练循环中,将定期调用tf.train.Saver.save方法,将训练中各变量的值写入检查点文件。

   

saver.save(sess, FLAGS.train_dir, global_step=step)

    在稍后的某些时候,可以使用tf.train.Saver.restore方法来重新加载模型参数来恢复训练。

saver.restore(sess, FLAGS.train_dir)

   

    评估模型

    每一步,代码将尝试针对训练和测试数据集来评估模型。 do_eval()函数被执行三次,用于训练,验证和测试数据集。

print('Training Data Eval:')
do_eval(sess,eval_correct,images_placeholder,labels_placeholder,data_sets.train)
print('Validation Data Eval:')
do_eval(sess,eval_correct,images_placeholder,labels_placeholder,data_sets.validation)
print('Test Data Eval:')
do_eval(sess,eval_correct,images_placeholder,labels_placeholder,data_sets.test)

    请注意,更复杂的使用通常会将data_sets.test隔离,以便在大量超参数调整后才能进行检查。 然而,为了简单的小MNIST问题,我们对所有数据进行评估。

   

    构建评估图

    在进入训练循环之前,评估操作应该是通过调用mnist.py中的evaluate()函数,使用与loss()函数相同的参数构建的。

eval_correct = mnist.evaluation(logits, labels_placeholder)

    评估函数简单地生成一个tf.nn.in_top_k操作,如果真正的标签可以在K个最可能的预测中找到,那么可以自动对每个模型输出进行评分。 在这种情况下,我们将K的值设置为1,以便仅对真实标签考虑预测是否正确。 

eval_correct = tf.nn.in_top_k(logits, labels, 1)

   

    评估输出

    然后可以创建一个填充feed_dict的循环,并针对eval_correct op调用sess.run()来评估给定数据集上的模型。

for step in xrange(steps_per_epoch):feed_dict = fill_feed_dict(data_set,images_placeholder,labels_placeholder)true_count += sess.run(eval_correct, feed_dict=feed_dict)

    true_count变量简单地累加了in_top_k op已经确定为正确的所有预测。 从那里可以从简单地除以实例的总数来计算精度。

precision = true_count / num_examples
print('  Num examples: %d  Num correct: %d  Precision @ 1: %0.04f' %(num_examples, true_count, precision))

 

 

   原文:《TensorFlow Mechanics 101》:https://www.tensorflow.org/get_started/mnist/mechanics

 

   

 

转载于:https://www.cnblogs.com/tengge/p/6920670.html

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.mzph.cn/news/269282.shtml

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

有趣的Web版Ubuntu Linux

其实这不是真的 Ubuntu 啦。不过,在看到 Wubuntu 时,其逼真的模仿效果真是令人惊叹不已。不管怎么样,让我们来体验一把 Web 版的 Ubuntu 吧。首先,我们会经历一个 Ubuntu 启动过程。其启动画面与真实的 Ubuntu 一模一样。接着&…

重新定义旅游网站,米胖新版发布

还记得一年多之前,我在web 2.0 网站推荐这篇博客中提到了米胖。没多久,我认识了米胖的两位帅气又有才气的当家人。在多次聊天之后,我被他们的激情与专注深深地折服了,在那时我就坚信米胖一定能够发展得很好,走出一条属…

wait和notify使用例子

public class Test2 {public static void main(String[] args) {String lock "lock";Thread thread1 new Thread(new Runnable() {Overridepublic void run() {synchronized (lock){System.out.println("线程1开始等待" System.currentTimeMillis());tr…

Linux基础系列:常用命令(5)_samba服务与nginx服务

作业一:部署samba 每个用户有自己的目录,可以浏览内容,也可以删除 所有的用户共享一个目录,只能浏览内容,不能删 安装samba服务 1、准备环境 setenforce 0 2、安装软件包 yum -y install samba 3、修改配置文件 /etc/s…

python练习,随机数字 函数,循环,if,格式化输出

# double ball game import random count 10000000000 # 设置多少注 blue_start 1 blue_end 5 a [] def make_surprise():i 0while i < 6:i 1number random.randrange(1, 32, 1)a.append(format({:02d}.format(number)))a.append(format({:02d}.format(rando…

notify()唤醒线程,不会立即释放锁对象,需要等到当前同步代码块都执行完后才能释放锁对象

notify()唤醒线程&#xff0c;不会立即释放锁对象&#xff0c;需要等到当前同步代码块都执行完后才能释放锁对象 public class Test3 {public static void main(String[] args) {List<String> list new ArrayList<>();Thread thread1 new Thread(new Runnable(…

LINUX下的APACHE的配置

今天写一下LINUX下的APACHE的配置方法。APACHE是作为WEB服务器的。它的优点在于用缓存方式来加快网页的搜索速度。APACHE缺省只支持静态网页LINUX下有APACHE的RPM包安装上第一张盘里的httpd-2.0.40-21.i386.rpm 包1 /etc/httpd/conf.d 放在这里的都是动态网页的配置文件2 /etc/…

程序实践:命令行之连连看

命令行之连连看 程序实践周课题&#xff0c;VC6.0上可编译执行 游戏截图&#xff1a; #include <cstdio>#include <cstring> #include <iostream> #include <windows.h> #include <time.h> #include <algorithm> using namespace std; in…

interrupt()会中断线程的wait等待

public class Thread5 {public static void main(String[] args) {SubThread subThread new SubThread();subThread.start();try {//主线程睡眠2秒&#xff0c;确保子线程处于wait状态Thread.sleep(2000);} catch (InterruptedException e) {e.printStackTrace();}subThread.i…

在ASP.Net 2.0中实现多语言界面的方法

1&#xff0e; 跟以前一样做界面&#xff0c;只是注意&#xff0c;把所有需要有多语言界面的文字都用label来做 2&#xff0e; 做完以后&#xff0c;在Solution Explorer里选中这个文件&#xff0c;选Tools-&#xff1e;Generate Local Resource3&#xff0e; 你会发现生成了一…

Qt 使用代码编写的自定义控件类

Qt 使用代码编写的自定义控件类 首先需要完成继承QWidget 或者Qt 原生控件类的类编写实现在需要使用自定义控件类的 UI 文件中添加一个 自定义类的控件&#xff08;也就是自定义类继承的控件&#xff09;将这个控件进行提升&#xff08;promote) 为自定义类&#xff0c;记得设…

mac使用word怎么显示左侧目录树

1&#xff0c;点击”视图” 2&#xff0c;点击“导航窗口” 3&#xff0c;点击如图所示图标

java BigDecimal去掉小数点后的零

new BigDecimal(spstFil.getCnt().stripTrailingZeros().toPlainString())

Qt 多重继承时 moc 编译出错

class SZNR103Client : public QObject , public CommBase {在这里插入代码片 bash 在这里插入代码片 注意一点&#xff1a; QOBject 必须写在自己的类前面&#xff0c;否则编译会有问题

中毒,重装,杀毒……最近一段时间,很烦的一件事,不断重复……

之前写的&#xff0c;因为最近太多人中毒了&#xff0c;太多人问了&#xff0c;太多人找我了…… 所以&#xff0c;很烦很烦…… 自己简直成了专业杀毒软件&#xff08;麻烦还没有杀毒软件的朋友&#xff0c;用金钱或者其他办法找个杀毒软件&#xff0c;一定…

使用postman发送HttpServletRequest请求

使用postman发送HttpServletRequest请求 使用postman发送HttpServletRequest请求 Headers部分是key: Content-Type value: application/x-www-form-urlencoded 后台使用这个接收String name request.getParameter("name");

第一次写,python爬虫图片,操作excel。

第一次写博客&#xff0c;其实老早就注册博客园了&#xff0c;有写博客的想法&#xff0c;就是没有行动&#xff0c;总是学了忘&#xff0c;忘了丢&#xff0c;最后啥都没有&#xff0c;电脑里零零散散&#xff0c;东找找&#xff0c;西看看&#xff0c;今天认识到写博客的重要…

JavaScript 异常处理

异常处理概述在代码的运行过程中&#xff0c;错误是不可避免的&#xff0c;总的来说&#xff0c;错误发生于两种情况&#xff1a;一是程序内部的逻辑或者语法错误&#xff0c;二是运行环境或者用户输入中不可预知的数据造成的错误。对于前者&#xff0c;就称之为错误&#xff0…