tensorflow学习笔记七----------卷积神经网络

卷积神经网络比神经网络稍微复杂一些,因为其多了一个卷积层(convolutional layer)和池化层(pooling layer)。

使用mnist数据集,n个数据,每个数据的像素为28*28*1=784。先让这些数据通过第一个卷积层,在这个卷积上指定一个3*3*1的feature,这个feature的个数设为64。接着经过一个池化层,让这个池化层的窗口为2*2。然后在经过一个卷积层,在这个卷积上指定一个3*3*64的feature,这个featurn的个数设置为128,。接着经过一个池化层,让这个池化层的窗口为2*2。让结果经过一个全连接层,这个全连接层大小设置为1024,在经过第二个全连接层,大小设置为10,进行分类。

import numpy as np
import tensorflow as tf
import matplotlib.pyplot as plt
from tensorflow.examples.tutorials.mnist import input_data
mnist = input_data.read_data_sets('data/', one_hot=True)
trainimg   = mnist.train.images
trainlabel = mnist.train.labels
testimg    = mnist.test.images
testlabel  = mnist.test.labels
print ("MNIST ready")
#像素点为784
n_input  = 784
#十分类
n_output = 10
#wc1,第一个卷积层参数,3*3*1,共有64个
#wc2,第二个卷积层参数,3*3*64,共有128个
#wd1,第一个全连接层参数,经过两个池化层被压缩到7*7
#wd2,第二个全连接层参数
weights  = {'wc1': tf.Variable(tf.random_normal([3, 3, 1, 64], stddev=0.1)),'wc2': tf.Variable(tf.random_normal([3, 3, 64, 128], stddev=0.1)),'wd1': tf.Variable(tf.random_normal([7*7*128, 1024], stddev=0.1)),'wd2': tf.Variable(tf.random_normal([1024, n_output], stddev=0.1))}
biases   = {'bc1': tf.Variable(tf.random_normal([64], stddev=0.1)),'bc2': tf.Variable(tf.random_normal([128], stddev=0.1)),'bd1': tf.Variable(tf.random_normal([1024], stddev=0.1)),'bd2': tf.Variable(tf.random_normal([n_output], stddev=0.1))}

定义前向传播函数。先将输入数据预处理,变成tensorflow支持的四维图像;进行第一层的卷积层处理,调用conv2d函数;将卷积结果用激活函数进行处理(relu函数);将结果进行池化层处理,ksize代表窗口大小;将池化层的结果进行随机删除节点;进行第二层卷积和池化...;进行全连接层,先将数据进行reshape(此处为7*7*128);进行激活函数处理;得出结果。前向传播结束。

def conv_basic(_input, _w, _b, _keepratio):# INPUT_input_r = tf.reshape(_input, shape=[-1, 28, 28, 1])# CONV LAYER 1_conv1 = tf.nn.conv2d(_input_r, _w['wc1'], strides=[1, 1, 1, 1], padding='SAME')_conv1 = tf.nn.relu(tf.nn.bias_add(_conv1, _b['bc1']))_pool1 = tf.nn.max_pool(_conv1, ksize=[1, 2, 2, 1], strides=[1, 2, 2, 1], padding='SAME')_pool_dr1 = tf.nn.dropout(_pool1, _keepratio)# CONV LAYER 2_conv2 = tf.nn.conv2d(_pool_dr1, _w['wc2'], strides=[1, 1, 1, 1], padding='SAME')_conv2 = tf.nn.relu(tf.nn.bias_add(_conv2, _b['bc2']))_pool2 = tf.nn.max_pool(_conv2, ksize=[1, 2, 2, 1], strides=[1, 2, 2, 1], padding='SAME')_pool_dr2 = tf.nn.dropout(_pool2, _keepratio)# VECTORIZE_dense1 = tf.reshape(_pool_dr2, [-1, _w['wd1'].get_shape().as_list()[0]])# FULLY CONNECTED LAYER 1_fc1 = tf.nn.relu(tf.add(tf.matmul(_dense1, _w['wd1']), _b['bd1']))_fc_dr1 = tf.nn.dropout(_fc1, _keepratio)# FULLY CONNECTED LAYER 2_out = tf.add(tf.matmul(_fc_dr1, _w['wd2']), _b['bd2'])# RETURNout = { 'input_r': _input_r, 'conv1': _conv1, 'pool1': _pool1, 'pool1_dr1': _pool_dr1,'conv2': _conv2, 'pool2': _pool2, 'pool_dr2': _pool_dr2, 'dense1': _dense1,'fc1': _fc1, 'fc_dr1': _fc_dr1, 'out': _out}return out
print ("CNN READY")

定义损失函数,定义优化器

x = tf.placeholder(tf.float32, [None, n_input])
y = tf.placeholder(tf.float32, [None, n_output])
keepratio = tf.placeholder(tf.float32)# FUNCTIONS

_pred = conv_basic(x, weights, biases, keepratio)['out']
cost = tf.reduce_mean(tf.nn.softmax_cross_entropy_with_logits(_pred, y))
optm = tf.train.AdamOptimizer(learning_rate=0.001).minimize(cost)
_corr = tf.equal(tf.argmax(_pred,1), tf.argmax(y,1)) 
accr = tf.reduce_mean(tf.cast(_corr, tf.float32)) 
init = tf.global_variables_initializer()# SAVER
save_step = 1
saver = tf.train.Saver(max_to_keep=3) print ("GRAPH READY")

进行迭代

do_train = 1
sess = tf.Session()
sess.run(init)training_epochs = 15
batch_size      = 16
display_step    = 1
if do_train == 1:for epoch in range(training_epochs):avg_cost = 0.total_batch = int(mnist.train.num_examples/batch_size)# Loop over all batchesfor i in range(total_batch):batch_xs, batch_ys = mnist.train.next_batch(batch_size)# Fit training using batch datasess.run(optm, feed_dict={x: batch_xs, y: batch_ys, keepratio:0.7})# Compute average lossavg_cost += sess.run(cost, feed_dict={x: batch_xs, y: batch_ys, keepratio:1.})/total_batch# Display logs per epoch stepif epoch % display_step == 0: print ("Epoch: %03d/%03d cost: %.9f" % (epoch, training_epochs, avg_cost))train_acc = sess.run(accr, feed_dict={x: batch_xs, y: batch_ys, keepratio:1.})print (" Training accuracy: %.3f" % (train_acc))#test_acc = sess.run(accr, feed_dict={x: testimg, y: testlabel, keepratio:1.})#print (" Test accuracy: %.3f" % (test_acc))print ("OPTIMIZATION FINISHED")

 

转载于:https://www.cnblogs.com/xxp17457741/p/9480521.html

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.mzph.cn/news/266325.shtml

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

matlab strfind用法,findstr和strfind区别

matlab中这两个字符串查找的函数findstr(), strfind()表明上看起来用法相似,效果也相似。一. findstr(s1,s2)--在较长的字符串中查找较短的字符串出现的次数,并返回其位置,因此无论s1,s2哪个为长字符串,位置在前在后都没有关系。例…

python发邮件给女朋友代码_python实现邮件发送完整代码(带附件发送方式)

实例一:利用SMTP与EMAIL实现邮件发送,带附件(完整代码) __author__ Administrator #codinggb2312 from email.Header import Header from email.MIMEText import MIMEText from email.MIMEMultipart import MIMEMultipart import…

Ubuntu识别USB设备

参考:如何解决Ubuntu无法识别USB设备 作者:一只青木呀 发布时间:2020-08-28 21:02:00 网址:https://blog.csdn.net/weixin_45309916/article/details/108286829 目录1、U盘识别2、识别转换解决Ubuntu无法识别USB3.0方法一&#xf…

用ElasticSearch存储日志

介绍 如果你使用elasticsearch来存储你的日志,本文给你提供一些做法和建议。 如果你想从多台主机向elasticsearch汇集日志,你有以下多种选择: Graylog2 安装在一台中心机上,然后它负责往elasticsearch插入日志,而且你可…

解除单个文件的与svn服务器的关联

有些文件和个人开发环境有关不需要和svn服务器做同步,可以取消其和svn服务的关联。 右键选中要取消关联的文件,右键菜单 Tortoise SVN ---> unversion and add to ignore list 确定后,文件图标会变成一把小剪刀,说明已经…

java xml出错,Java xml出现错误 javax.xml.transform.TransformerException: java.lang.NullPointerException...

Java xml出现错误 javax.xml.transform.TransformerException: java.lang.NullPointerException解决办法:利用Java操作XML,在操作XML过程中,执行到最后一步,在利用Transformer进行XML转换时出现NullPointerException错误&#xff…

Ubuntu磁盘扩容及启动问题整理

参考:Ubuntu磁盘扩容及启动问题整理 作者:一只青木呀 发布时间: 2020-12-08 10:42:19 网址:https://blog.csdn.net/weixin_45309916/article/details/110850358 也可参照正点原子的:Ubuntu磁盘空间不足?一招…

函数求值需要运行所有线程_精读《深度学习 - 函数式之美》

1 引言函数式语言在深度学习领域应用很广泛,因为函数式与深度学习模型的契合度很高,The Beauty of Functional Languages in Deep Learning — Clojure and Haskell 就很好的诠释了这个道理。通过这篇文章可以加深我们对深度学习与函数式编程的理解。2…

IOS(常用移动终端设备) push实现通知中心

参考文章: http://blog.csdn.net/zhuqilin0/article/details/6527113 http://www.dozer.cc/2013/03/push-notifications-server-side-implement/ http://blog.sina.com.cn/s/blog_71ce775e0101b43e.html PushSharp:https://github.com/Redth/PushShar…

Ubuntu下无法看到共享文件夹的解决办法

参考:Ubuntu下无法看到共享文件夹的解决办法 作者:一只青木呀 发布时间:2020-08-07 10:09:04 网址:https://blog.csdn.net/weixin_45309916/article/details/107856157 今天早上起来突然发现共享文件夹的查看不到了,找…

MySQL提供了以下三种方法用于获取数据库对象的元数据

MySQL提供了以下三种方法用于获取数据库对象的元数据: 1)show语句 2)从INFORMATION_SCHEMA数据库里查询相关表 3)命令行程序,如mysqlshow, mysqldump 用SHOW语句获取元数据 MySQL用show语句获取元数据是最常用的方法&a…

laravel 5.1 php版本号,发行版本说明 | 序言 | Laravel 5.1 中文文档

发行版本说明由 学院君 创建于5年前, 最后更新于 11个月前版本号 #219641 views23 likes0 collects支持政策LTS版本,比如Laravel 5.1,将会提供两年的bug修复和三年的安全修复支持。这些版本将会提供最长时间的支持和维护。对于其他通用版本,只…

NYOJ-522 Interval

Interval 时间限制&#xff1a;2000 ms | 内存限制&#xff1a;65535 KB难度&#xff1a;4描述There are n(1 < n < 100000) intervals [ai, bi] and m(1 < m < 100000) queries, -100000 < ai < bi < 100000 are integers.Each query contains an inte…

实现图书增删的代码_不仅仅是图书信息管理系统

点击蓝字 关注我们不仅仅是图书信息管理系统基于双链表&#xff0c;采用面向对象编程方法制作的图书管理系统❞效果演示root用户&#xff1a;账号&#xff1a;0&#xff0c;密码&#xff1a;0普通用户&#xff1a;账号&#xff1a;1001&#xff0c;密码&#xff1a;666666图书信…

HDU1846 - Brave Game【巴什博弈】

十年前读大学的时候&#xff0c;中国每年都要从国外引进一些电影大片&#xff0c;其中有一部电影就叫《勇敢者的游戏》&#xff08;英文名称&#xff1a;Zathura&#xff09;&#xff0c;一直到现在&#xff0c;我依然对于电影中的部分电脑特技印象深刻。 今天&#xff0c;大家…

Ubuntu18.04换源更新国内源

参考&#xff1a;Ubuntu18.04更新国内源 作者&#xff1a;一只青木呀 发布时间&#xff1a;2020-08-05 10:24:11 网址&#xff1a;https://blog.csdn.net/weixin_45309916/article/details/107808268 树莓派换源博文&#xff1a;https://blog.csdn.net/zhuguanlin121/article/d…

php中解析数组,在PHP中解析多维数组

您应该在将数据解析为Smarty之前准备好数据。你可以这样做&#xff1a;$result array(array(name > Hockey Team 1, category_id > 1),array(name > Hockey Team 2, category_id > 2),array(name > Hockey Team 3, category_id > 3),array(name > Footba…

了解jQuery并掌握jQuery对象和DOM对象的区别

jQuery的优势&#xff1a; 开源--开放源代码 轻量级 强大的选择器 出色的DOM操作(对DOM元素的一个增删改查) 完善的Ajax&#xff0c;出色的浏览器兼容性&#xff0c;丰富的插件支持&#xff0c;完善的文档&#xff08;说明书&#xff09; 链式操作方式&#xff0c; 写得少&…

linux下复制

复制文件 cp - i file tofile 复制目录 cp - r dic todic转载于:https://www.cnblogs.com/Hero-Qiang/archive/2013/03/20/2971579.html

rh php56 php,在全球范围内提供RHSCL PHP的最佳方法

我使用以下网址安装了RHSCL 2&#xff1a;使用RedHat订阅管理器.然后我运行yum删除php *,然后是yum install rh-php56一切顺利,除非现在找不到PHP.然后我运行find / -name php并在以下目录中找到rh-php56&#xff1a;/var/opt/rh/rh-php56/lib/php/opt/rh/rh-php56/register.co…