【TensorFlow-windows】学习笔记八——简化网络书写

前言

之前写代码的时候都要预先初始化权重,还得担心变量是否会出现被重复定义的错误,但是看网上有直接用tf.layers构建网络,很简洁的方法。

这里主要尝试了不预定义权重,是否能够实现正常训练、模型保存和调用,事实证明阔以。

验证

训练与模型保存

很简洁的代码直接五十行实现了手写数字的网络训练

import tensorflow as tffrom tensorflow.examples.tutorials.mnist import input_data
mnist=input_data.read_data_sets("./TensorFlow-Examples-master/examples/3_NeuralNetworks/tmp",one_hot=True)steps=5000
batch_size=100
def conv_network(x):x=tf.reshape(x,[-1,28,28,1])#第一层卷积conv1=tf.layers.conv2d(inputs=x,filters=32,kernel_size=[5,5],activation=tf.nn.relu)conv1=tf.layers.max_pooling2d(conv1,pool_size=[2,2],strides=[2,2])#第二层卷积conv2=tf.layers.conv2d(inputs=conv1,filters=64,kernel_size=[3,3],activation=tf.nn.relu)conv2=tf.layers.max_pooling2d(inputs=conv2,pool_size=[2,2],strides=[2,2])#第三层卷积conv3=tf.layers.conv2d(inputs=conv2,filters=32,kernel_size=[3,3],activation=tf.nn.relu)conv3=tf.layers.max_pooling2d(inputs=conv3,pool_size=[2,2],strides=[2,2])#全连接fc1=tf.layers.flatten(conv3)fc1=tf.layers.dense(fc1,500,activation=tf.nn.relu)#输出fc2=tf.layers.dense(fc1,10)fc2=tf.nn.softmax(fc2) #因为loss里面用了softmax_cross_enrtopy,所以此行去掉return fc2input_img=tf.placeholder(dtype=tf.float32,shape=[None,28*28],name='X')
input_lab=tf.placeholder(dtype=tf.int32,shape=[None,10])#损失函数
output_lab=conv_network(input_img)
logit_loss=tf.nn.softmax_cross_entropy_with_logits_v2(labels=input_lab,logits=output_lab)
loss=tf.reduce_mean(tf.cast(logit_loss,tf.float32)) #可以去掉,因为softmax_cross_entroy自带求均值
optim=tf.train.AdamOptimizer(0.001).minimize(loss)
#评估函数
pred_equal=tf.equal(tf.arg_max(output_lab,1),tf.arg_max(input_lab,1))
accuracy=tf.reduce_mean(tf.cast(pred_equal,tf.float32))init=tf.global_variables_initializer()
saver=tf.train.Saver()
tf.add_to_collection('pred',output_lab)
with tf.Session() as sess:sess.run(init)for step in range(steps):data_x,data_y=mnist.train.next_batch(batch_size)sess.run(optim,feed_dict={input_img:data_x,input_lab:data_y})if step%100==0 or step==1:accuracy_val=sess.run(accuracy,feed_dict={input_img:data_x,input_lab:data_y})print('step'+str(step)+' ,loss '+'{:.4f}'.format(accuracy_val))print('training finished!!')saver.save(sess,'./layermodel/CNN_layer')

【更新日志】 2019-9-2
学艺不精,上面由于损失函数用的softmax_cross_entropy_with_logits_v2,所以输出会被归一化,得分也是一个batch的损失均值,因而构建网络的时候,没必要用最后下面两句话:

loss=tf.reduce_mean(tf.cast(logit_loss,tf.float32))
fc2=tf.nn.softmax(fc2)

调用模型

实现单张手写数字的识别

import tensorflow as tf
import numpy as np
import matplotlib.pyplot as plt
import cv2
saver=tf.train.import_meta_graph('./layermodel/CNN_layer.meta')
sess=tf.Session()
saver.restore(sess,'./layermodel/CNN_layer')
graph=tf.get_default_graph()
print(graph.get_all_collection_keys())
#['pred', 'train_op', 'trainable_variables', 'variables']
print(graph.get_collection('trainable_variables'))
prediction=graph.get_collection('pred')
X=graph.get_tensor_by_name('X:0')
#读取图片
image=cv2.imread('./mnist/test/2/2_2.png')
image=cv2.cvtColor(image,cv2.COLOR_BGR2GRAY)
plt.imshow(image)
plt.show()
#显示图片
input_img=np.reshape(image,[1,28*28])
result=sess.run(prediction,feed_dict={X:input_img})
print(result)
#[array([[0., 0., 1., 0., 0., 0., 0., 0., 0., 0.]], dtype=float32)]

后记

其实主要是为了后续使用tf.layers里面的其它结构比如BN做准备,因为代码越复杂,写起来越恶心,不如现在看看如何简化代码,第一步就是去除了权重的预定义,后续慢慢研究其它的。

训练代码:链接:https://pan.baidu.com/s/1gmX-YBkz4nNG3RpJ_rEBKQ 密码:o8u2

测试代码:链接:https://pan.baidu.com/s/1ME9pgyM9TNQadmzMeURlNg 密码:5z7k

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.mzph.cn/news/246602.shtml

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

强化学习——Qlearning

前言 在控制决策领域里面强化学习还是占很重比例的,最近出了几篇角色控制的论文需要研究,其中部分涉及到强化学习,都有开源,有兴趣可以点开看看: A Deep Learning Framework For Character Motion Synthesis and Edit…

【TensorFlow-windows】keras接口学习——线性回归与简单的分类

前言 之前有写过几篇TensorFlow相关文章,但是用的比较底层的写法,比如tf.nn和tf.layers,也写了部分基本模型如自编码和对抗网络等,感觉写起来不太舒服,最近看官方文档发现它的教程基本都使用的keras API,这…

【TensorFlow-windows】keras接口——卷积手写数字识别,模型保存和调用

前言 上一节学习了以TensorFlow为底端的keras接口最简单的使用,这里就继续学习怎么写卷积分类模型和各种保存方法(仅保存权重、权重和网络结构同时保存) 国际惯例,参考博客: 官方教程 【注】其实不用看博客,直接翻到文末看我的c…

【TensorFlow-windows】keras接口——BatchNorm和ResNet

前言 之前学习利用Keras简单地堆叠卷积网络去构建分类模型的方法,但是对于很深的网络结构很难保证梯度在各层能够正常传播,经常发生梯度消失、梯度爆炸或者其它奇奇怪怪的问题。为了解决这类问题,大佬们想了各种办法,比如最原始的…

【TensorFlow-windows】keras接口——卷积核可视化

前言 在机器之心上看到了关于卷积核可视化相关理论,但是作者的源代码是基于fastai写的,而fastai的底层是pytorch,本来准备自己用Keras复现一遍的,但是尴尬地发现Keras还没玩熟练,随后发现了一个keras-vis包可以用于做…

【TensorFlow-windows】投影变换

前言 没什么重要的,就是想测试一下tensorflow的投影变换函数tf.contrib.image.transform中每个参数的含义 国际惯例,参考文档 官方文档 描述 调用方法与默认参数: tf.contrib.image.transform(images,transforms,interpolationNEAREST,…

【TensorFlow-windows】扩展层之STN

前言 读TensorFlow相关代码看到了STN的应用,搜索以后发现可替代池化,增强网络对图像变换(旋转、缩放、偏移等)的抗干扰能力,简单说就是提高卷积神经网络的空间不变性。 国际惯例,参考博客: 理解Spatial Transformer…

【TensorFlow-windows】部分损失函数测试

前言 在TensorFlow中提供了挺多损失函数的,这里主要测试一下均方差与交叉熵相关的几个函数的计算流程。主要是测试来自于tf.nn与tf.losses的mean_square_error、sigmoid_cross_entry、softmax_cross_entry、sparse_softmax_cross_entry 国际惯例,参考博…

RS编码-Python工具包使用

前言 最近学习二维码相关知识,遇到了ReedSolomon编码,简称RS编码,中文名里德所罗门编码。遇到的问题是使用的工具包返回的编码是bytearray类型,而二维码是二进制01编码,所以本博客主要验证,如何将bytearra…

【TensorFlow-windows】MobileNet理论概览与实现

前言 轻量级神经网络中,比较重要的有MobileNet和ShuffleNet,其实还有其它的,比如SqueezeNet、Xception等。 本博客为MobileNet的前两个版本的理论简介与Keras中封装好的模块的对应实现方案。 国际惯例,参考博客: 纵…

【TensorFlow-windows】keras接口——ImageDataGenerator裁剪

前言 Keras中有一个图像数据处理器ImageDataGenerator,能够很方便地进行数据增强,并且从文件中批量加载图片,避免数据集过大时,一下子加载进内存会崩掉。但是从官方文档发现,并没有一个比较重要的图像增强方式&#x…

【TensorFlow-windows】name_scope与variable_scope

前言 探索一下variable_scope和name_scope相关的作用域,为下一章节tensorboard的学习做准备 其实关于variable_scope与get_variable实现变量共享,在最开始的博客有介绍过: 【TensorFlow-windows】学习笔记二——低级API 当然还是国际惯例…

【TensorFlow-windows】TensorBoard可视化

前言 紧接上一篇博客,学习tensorboard可视化训练过程。 国际惯例,参考博客: MNIST机器学习入门 Tensorboard 详解(上篇) Tensorboard 可视化好帮手 2 tf-dev-summit-tensorboard-tutorial tensorflow官方mnist_…

深度学习特征归一化方法——BN、LN、IN、GN

前言 最近看到Group Normalization的论文,主要提到了四个特征归一化方法:Batch Norm、Layer Norm、Instance Norm、Group Norm。此外,论文还提到了Local Response Normalization(LRN)、Weight Normalization(WN)、Batch Renormalization(BR)…

【TensorFlow-windows】keras接口——利用tensorflow的方法加载数据

前言 之前使用tensorflow和keras的时候,都各自有一套数据读取方法,但是遇到一个问题就是,在训练的时候,GPU的利用率忽高忽低,极大可能是由于训练过程中读取每个batch数据造成的,所以又看了tensorflow官方的…

【TensorFlow-serving】初步学习模型部署

前言 初步学习tensorflow serving的手写数字识别模型部署。包括简单的模型训练、保存、部署上线。因为对docker和网络不太熟悉,可能会有部分错误,但是看完博客,能跑通整个流程。此博客将详细介绍流程,但是不详细介绍每个流程的每…

Tensorflow 指令加速

一直没注意过使用Tensorflow的时候有一条warning: Warning: your cpu supports instructions that this tensorflow binary was not compiled to use: avx2 fma这玩意是可以用来加速推断的,分CPU和GPU版,业务相关部署在CPU上,实测…

骨骼动画——论文与代码精读《Phase-Functioned Neural Networks for Character Control》

前言 最近一直玩CV,对之前学的动捕知识都忘得差不多了,最近要好好总结一下一直以来学习的内容,不能学了忘。对2017年的SIGGRAPH论文《Phase-Functioned Neural Networks for Character Control》进行一波深入剖析吧,结合源码。 额…

颜色协调模型Color Harmoniztion

前言 最近做换脸,在肤色调整的那一块,看到一个有意思的文章,复现一波玩玩。不过最后一步掉链子了,有兴趣的可以一起讨论把链子补上。 主要是github上大佬的那个复现代码和原文有点差异,而且代码复杂度过高&#xff0…

Openpose推断阶段原理

前言 之前出过一个关于openpose配置的博客,不过那个代码虽然写的很好,而且是官方的,但是分析起来很困难,然后再opencv相关博客中找到了比较清晰的实现,这里分析一波openpose的推断过程。 国际惯例,参考博…