TensorFlow实现简单的卷积网络

使用的数据集是MNIST,下载方法见之前的博客

from tensorflow.examples.tutorials.mnist import input_data
import tensorflow as tf
mnist = input_data.read_data_sets(r"D:\PycharmProjects\tensorflow\MNIST_data", one_hot=True)
sess = tf.InteractiveSession()# 后面有很多权重和偏置需要创建,所以这里定义创建权重和偏置的函数以方便重复使用
# 我们需要给权重制造噪声以打破完全对称,因为我们使用ReLU,也给偏置加一些小的正值以避免死亡节点
def weight_variable(shape):initial = tf.truncated_normal(shape, stddev=0.1)return tf.Variable(initial)def bias_variable(shape):initial = tf.constant(0.1, shape=shape)return tf.Variable(initial)# 卷积层和池化层也是接下来重复使用的,因此也为它们定义创建函数
# x是输入,W是卷积的参数,比如[5,5,1,32],前面两个数字是卷积核的尺寸,第三个数字代表有多少个channel
# 这里我们是灰度单色,所以是1,最后一个数字代表卷积核的数量,也就是这个卷积层会提取多少个特征
# 第三个参数是步长,虽然第三个参数提供的是一个长度为4的数组,但是第一维和最后一维的数字要求一定是 1
# 最后一个参数是填充的方法,SAME但表示添加全0填充,VALID表示不添加
def conv2d(x, W):return tf.nn.conv2d(x, W, strides=[1, 1, 1, 1], padding='SAME')# 第二个参数为过滤器的尺寸。虽然是一个长度为4的一维数组,但是这个数组的第一个和最后一个数必须为1。
# 这意味着池化层的过滤器是不可以跨不同输入样例或者节点矩阵深度的。因为x的第一维对应一个batch,第四维是channel数
# 因为希望整体上缩小尺寸,所以strides步长设为2,如果设为1,我们会得到一个尺寸不变的图片
def max_pool_2x2(x):return tf.nn.max_pool(x, ksize=[1, 2, 2, 1],strides=[1, 2, 2, 1], padding='SAME')  x = tf.placeholder(tf.float32, [None, 784])
y_ = tf.placeholder(tf.float32, [None, 10]) # 真实标签
x_image = tf.reshape(x, [-1,28,28,1]) #将1×784转为28×28,颜色通道只有1,-1代表样本数量不确定#定义第一个卷积层,尺寸为5×5,1个颜色通道,32个卷积核
#tf.nn.bias_add提供了一个方便的函数给每一个节点加上偏置项,注意这里不能直接使用加法
#因为矩阵上不同位置上的节点都需要加上同样的偏置项
W_conv1 = weight_variable([5, 5, 1, 32])
b_conv1 = bias_variable([32])
h_conv1 = tf.nn.relu(tf.nn.bias_add(conv2d(x_image, W_conv1), b_conv1))
h_pool1 = max_pool_2x2(h_conv1)#定义第二个卷积层
W_conv2 = weight_variable([5, 5, 32, 64])
b_conv2 = bias_variable([64])
h_conv2 = tf.nn.relu(tf.nn.bias_add(conv2d(h_pool1, W_conv2), b_conv2))
h_pool2 = max_pool_2x2(h_conv2)#因为前面经历了两次2×2的池化层,所以边长只有1/4即图片变成7×7,因为第二个卷积层的卷积核数量为64
#所以输出tensor的尺寸为7×7×64,将其转成1D向量,再连接一个1024个隐含节点的全连接层
W_fc1 = weight_variable([7 * 7 * 64, 1024])
b_fc1 = bias_variable([1024])
h_pool2_flat = tf.reshape(h_pool2, [-1, 7*7*64])
h_fc1 = tf.nn.relu(tf.matmul(h_pool2_flat, W_fc1) + b_fc1)#减轻过拟合
keep_prob = tf.placeholder(tf.float32)
h_fc1_drop = tf.nn.dropout(h_fc1, keep_prob)#将dropout输出层的输出连接一个softmax层,得到概率输出
W_fc2 = weight_variable([1024, 10])
b_fc2 = bias_variable([10])
y_conv=tf.nn.softmax(tf.matmul(h_fc1_drop, W_fc2) + b_fc2)cross_entropy = tf.reduce_mean(-tf.reduce_sum(y_ * tf.log(y_conv), reduction_indices=[1]))
train_step = tf.train.AdamOptimizer(1e-4).minimize(cross_entropy)#定义准确率
correct_prediction = tf.equal(tf.argmax(y_conv,1), tf.argmax(y_,1))
accuracy = tf.reduce_mean(tf.cast(correct_prediction, tf.float32))tf.global_variables_initializer().run()
for i in range(20000):batch = mnist.train.next_batch(50)if i%100 == 0: #每100次训练,对准确率进行一次评测train_accuracy = accuracy.eval(feed_dict={x:batch[0], y_: batch[1], keep_prob: 1.0})print("step %d, training accuracy %g"%(i, train_accuracy))train_step.run(feed_dict={x: batch[0], y_: batch[1], keep_prob: 0.5})print("test accuracy %g"%accuracy.eval(feed_dict={x: mnist.test.images, y_: mnist.test.labels, keep_prob: 1.0}))

 

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.mzph.cn/news/491730.shtml

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

BZOJ2819 Nim(DFS序)

题目:单点修改、树链查询。 可以直接用树链剖分做。。 修改是O(QlogN),查询是O(QlogNlogN),QN500000; 听说会超时。。 这题也可以用DFS序来做。 先不看修改,单单查询:可以求出每个点到根的xor值&#xff0c…

全球CMOS图像传感器厂商最新排名:黑马杀出

来源:半导体行业观察近期,台湾地区的Yuanta Research发布报告,介绍了其对CMOS图像传感器(CIS)市场的看法,以及到2022年的前景预期。从该研究报告可以看出,2018年全球CMOS图像传感器的市场规模为137亿美元,其…

下载CIFAR-10、CIFAR-100数据集的方法

该网站的数据集目录MNISTCIFAR-10CIFAR-100STL-10SVHNILSVRC2012 task 1 网址:http://rodrigob.github.io/are_we_there_yet/build/classification_datasets_results.html

吴恩达《机器学习》学习笔记十四——应用机器学习的建议实现一个机器学习模型的改进

吴恩达《机器学习》学习笔记十四——应用机器学习的建议实现一个机器学习模型的改进一、任务介绍二、代码实现1.准备数据2.代价函数3.梯度计算4.带有正则化的代价函数和梯度计算5.拟合数据6.创建多项式特征7.准备多项式回归数据8.绘制学习曲线𝜆0𝜆1&…

刘锋 吕乃基:互联网中心化与去中心化之争

前言:本文发表在2019年5月《中国社会科学报》上,主要从神经学角度分析互联网的发育过程,并对云计算和区块链为代表的中心化与去中心化技术趋势进行了探讨。当前,学术界和产业界对互联网的未来发展出现了分歧。随着谷歌、亚马逊、F…

胶囊网络不同实现代码

* Keras w/ TensorFlow backend: https://github.com/XifengGuo/CapsNet-keras * TensorFlow: https://github.com/naturomics/CapsNet-Tensorflow * PyTorch: https://github.com/gram-ai/capsule-networks

iOS-BMK标注覆盖物

在iOS开发中,地图算是一个比较重要的模块。我们常用的地图有高德地图,百度地图,谷歌地图,对于中国而言,苹果公司已经不再使用谷歌地图,官方使用的是高德地图。下面将讲述一下百度地图开发过程中的一些小的知…

PyTorch框架学习二——基本数据结构(张量)

PyTorch框架学习二——基本数据结构(张量)一、什么是张量?二、Tensor与Variable(PyTorch中)1.Variable2.Tensor三、Tensor的创建1.直接创建Tensor(1)torch.tensor()(2)to…

十年空缺一朝回归,百度正式任命王海峰出任CTO

来源:机器之心百度要回归技术初心了吗?自 2010 年李一男卸任百度 CTO 之后,百度对这一职位就再无公开任命,一空就是 10 年。而今天上午李彦宏突然发出的一纸职位调令,让这个空缺多年的百度 CTO 之位有了新的掌舵手。就…

Windows下卸载TensorFlow

激活tensorflow:activate tensorflow输入:pip uninstall tensorflowProceed(y/n):y如果是gpu版本: 激活tensorflow:activate tensorflow-gpu输入:pip uninstall tensorflow-gpuProceed&#xf…

大道至简第三章

大道至简读后感 这一章名为团队缺乏的不仅仅是管理,显而易见,这一章强调的就是作为一名软件工程实践者,团队协作的重要性。 这一章共分为八个小结,分别为三个人的团队,做项目 死亡游戏 ?,做 ISO 质量体系…

PyTorch框架学习三——张量操作

PyTorch框架学习三——张量操作一、拼接1.torch.cat()2.torch.stack()二、切分1.torch.chunk()2.torch.split()三、索引1.torch.index_select()2.torch.masked_select()四、变换1.torch.reshape()2.torch.transpace()3.torch.t()4.torch.squeeze()5.torch.unsqueeze()一、拼接 …

'chcp' 不是内部或外部命令,也不是可运行的程序

在cmd窗口中输入activate tensorflow时报错chcp 不是内部或外部命令,也不是可运行的程序 添加两个环境变量即可解决: 将Anaconda的安装地址添加到环境变量“PATH”,如果没有可以新建一个,我的安装地址是“D:\Anaconda”&#xf…

2019年全球企业人工智能发展现状分析报告

来源:199IT互联网数据中心《悬而未决的AI竞赛——全球企业人工智能发展现状》由德勤洞察发布,德勤中国科技、传媒和电信行业编译。为了解全球范围内的企业在应用人工智能技术方面的情况以及所取得的成效,德勤于2018年第三季度针对早期人工智能…

qt调动DLL

void func(void); // dll库中的函数 typedef void (*PFUNC)(void); 方法一&#xff1a; HMODULE g_hAPIDLL NULL; wchar_t tcDLLPath[100] L"D:\\name.dll"; g_hAPIDLL ::LoadLibrary(tcDLLPath); if (NULL g_hAPIDLL) { qDebug() << "load library f…

PyTorch框架学习四——计算图与动态图机制

PyTorch框架学习四——计算图与动态图机制一、计算图二、动态图与静态图三、torch.autograd1.torch.autograd.backward()2.torch.autograd.grad()3.autograd小贴士4.代码演示理解&#xff08;1&#xff09;构建计算图并反向求导&#xff1a;&#xff08;2&#xff09;grad_tens…

ipynb文件转为python(.py)文件

在Anaconda中的jupyter打开该ipynb文件&#xff0c;然后依次点击File—>Download as—>python(.py)

美国准备跳过5G直接到6G 用上万颗卫星包裹全球,靠谱吗?

来源&#xff1a;瞭望智库这项2015年提出的计划&#xff0c;规模极其巨大&#xff0c;总计要在2025年前发射近12000颗卫星。有自媒体认为&#xff0c;该计划表示美国将在太空中建立下一代宽带网络&#xff0c;绕过5G&#xff0c;直接升级到6G&#xff0c;并据此认为“6G并不遥远…

8月读书分享-《执行力是训练出来的》

写在最开头的是&#xff0c;没有拿到这本书之前其实我是很期待的&#xff0c;因为我觉得执行力是我所很需要的东西。但是拿到书之后就有一些失望了&#xff0c;因为我发现他的章节实在是太多了&#xff0c;我总觉得如果章节太多会不会其实是作者的归纳整理能力不太好呢&#xf…

PyTorch框架学习五——图像预处理transforms(一)

PyTorch框架学习五——图像预处理transforms&#xff08;一&#xff09;一、transforms运行机制二、transforms的具体方法1.裁剪&#xff08;1&#xff09;随机裁剪&#xff1a;transforms.RandomCrop()&#xff08;2&#xff09;中心裁剪&#xff1a;transforms.CenterCrop()&…