TensorFlow实现单隐层神经网络

这里使用MNIST数据集,MNIST数据集的下载地址http://yann.lecun.com/exdb/mnist/

from tensorflow.examples.tutorials.mnist import input_data
import tensorflow as tf
mnist = input_data.read_data_sets(r"D:\PycharmProjects\tensorflow\MNIST_data", one_hot=True)# 创建默认的InteractiveSession,这样后面执行的各项操作就无需指定Session了
sess = tf.InteractiveSession()in_units = 784 # 输入节点数
h1_units = 300 # 隐含单元数
W1 = tf.Variable(tf.truncated_normal([in_units, h1_units], stddev=0.1))
b1 = tf.Variable(tf.zeros([h1_units]))
W2 = tf.Variable(tf.zeros([h1_units, 10])) # 因为是识别数字,输出单元数为10
b2 = tf.Variable(tf.zeros([10]))x = tf.placeholder(tf.float32, [None, in_units])
keep_prob = tf.placeholder(tf.float32)hidden1 = tf.nn.relu(tf.matmul(x, W1) + b1) # 使用relu激活函数可以解决梯度弥散
hidden1_drop = tf.nn.dropout(hidden1, keep_prob) # 使用Dropout方法,keep_prob是保留节点的概率
y = tf.nn.softmax(tf.matmul(hidden1_drop, W2) + b2) # 预测的标签y_ = tf.placeholder(tf.float32, [None, 10]) # 正确的标签
cross_entropy = tf.reduce_mean(-tf.reduce_sum(y_ * tf.log(y), reduction_indices=[1])) # 损失函数
train_step = tf.train.AdagradOptimizer(0.3).minimize(cross_entropy) # 学习率是0.3# 训练
tf.global_variables_initializer().run()
# 一共采用3000个batch,每个batch100个样本,一共300000个样本
# 一个数据集55000个样本,相当于对全数据集进行5轮(epoch)迭代
for i in range(3000):batch_xs, batch_ys = mnist.train.next_batch(100)train_step.run({x: batch_xs, y_: batch_ys, keep_prob: 0.75})# 测试
correct_prediction = tf.equal(tf.argmax(y, 1), tf.argmax(y_, 1))
accuracy = tf.reduce_mean(tf.cast(correct_prediction, tf.float32))
# 预测时keep_prob应该等于1,即使用全部特征来预测样本的类别
print(accuracy.eval({x: mnist.test.images, y_: mnist.test.labels, keep_prob: 1.0}))

 

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.mzph.cn/news/491769.shtml

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

美国一箭投放60颗卫星 马斯克组互联网“星链”

来源:新华网美国太空探索公司当地时间23日晚在美国佛罗里达州一处空军基地发射火箭,将60颗小卫星送入近地轨道。这标志着企业家埃隆马斯克组建互联网卫星群的“星链”项目迈出实质性一步,抢在电子商务巨头亚马逊公司创始人杰夫贝索斯的“柯伊…

Ubuntu 中Mysql 操作

一、mysql服务操作 0、查看数据库版本 sql-> status; 1、net start mysql //启动mysql服务 2、net stop mysql //停止mysql服务  3、mysql -h主机地址 -u用户名 -p用户密码 //进入mysql数据库 4、quit //退出mysql操作 5、mysqladmin -u用户名 -p旧密码 passwor…

吴恩达《机器学习》学习笔记五——逻辑回归

吴恩达《机器学习》学习笔记五——逻辑回归一、 分类(classification)1.定义2.阈值二、 逻辑(logistic)回归假设函数1.假设的表达式2.假设表达式的意义3.决策界限三、 代价函数1.平方误差函数的问题2.logistic回归的代价函数四、梯…

协方差与相关系数

定义: 协方差用于衡量两个变量的总体误差。而方差是协方差的一种特殊情况,即当两个变量是相同的情况。 期望值分别为E[X]与E[Y]的两个实随机变量X与Y之间的协方差Cov(X,Y)定义为: 如果两个变量的变化趋势一致,也就是说如果其中一…

吴恩达《机器学习》学习笔记六——过拟合与正则化

吴恩达《机器学习》学习笔记六——过拟合与正则化一、 过拟合问题1.线性回归过拟合问题2.逻辑回归过拟合问题3.过拟合的解决二、 正则化后的代价函数1.正则化思想2.实际使用的正则化三、 正则化的线性回归1.梯度下降的情况2.正规方程的情况四、 正则化的逻辑回归1.梯度下降的情…

Swift - 数组排序方法(附样例)

下面通过一个样例演示如何对数组元素进行排序。数组内为自定义用户对象,最终要实现按用户名排序,数据如下: 1234var userList [UserInfo]()userList.append(UserInfo(name: "张三", phone: "4234"))userList.append(Use…

5G时代,智能工厂迎来4大改变!

来源:亿欧网作为新一代移动通信技术,5G技术切合了传统制造企业智能制造转型对无线网络的应用需求,能满足工业环境下设备互联和远程交互应用需求。在物联网、工业自动化控制、物流追踪、工业AR、云化机器人等工业应用领域,5G技术起…

主成分分析PCA以及特征值和特征向量的意义

定义: 主成分分析(Principal Component Analysis,PCA), 是一种统计方法。通过正交变换将一组可能存在相关性的变量转换为一组线性不相关的变量,转换后的这组变量叫主成分。PCA的思想是将n维特征映射到k维上…

吴恩达《机器学习》学习笔记七——逻辑回归(二分类)代码

吴恩达《机器学习》学习笔记七——逻辑回归(二分类)代码一、无正则项的逻辑回归1.问题描述2.导入模块3.准备数据4.假设函数5.代价函数6.梯度下降7.拟合参数8.用训练集预测和验证9.寻找决策边界二、正则化逻辑回归1.准备数据2.特征映射3.正则化代价函数4.…

需要自己调研的框架,以及需要学习的内容

_webView 图文解析需要学习 Core Text 来解决图文混合的xml文件 项目中用到哪些数据持久化,什么场景下使用? http://www.cocoachina.com/ios/20150720/12610.html 解决问题的网站 CSDN.NETCocoaChina51CTO技术论坛特酷吧国外it技术论坛stack overflow 需…

Python获得一篇文档的不重复词列表并创建词向量

获得一篇文档的不重复词列表: def loadDataSet():postingList [[my, dog, has, flea, problems, help, please],[maybe, not, take, him, to, dog, park, stupid],[my, dalmation, is, so, cute, I, love, him],[stop, posting, stupid, worthless, garbage],[mr,…

从认知学到进化论,详述强化学习两大最新突破

来源:大数据文摘深层强化学习(deep RL)近年来在人工智能方面取得了令人瞩目的进步,在Atari游戏、围棋及无限制扑克等领域战胜了人类。通过将表征学习与奖励驱动行为相结合,深层强化学习又引发了心理学和神经科学领域的…

Python实现一个数组除以一个数

如果直接用python的一个list除以一个数,会报错: a [1.0, 1.0, 1.0] c a/3 print(c) TypeError: unsupported operand type(s) for /: list and int 使用Numpy可以轻松做到: import numpy as npa np.array([1,1,1]) c a/3 print(c)

吴恩达《机器学习》学习笔记九——神经网络相关(1)

吴恩达《机器学习》学习笔记九——神经网络相关(1)一、 非线性假设的问题二、 神经网络相关知识1.神经网络的大致历史2.神经网络的表示3.前向传播:向量化表示三、 例子与直觉理解1.问题描述:异或XOR、同或XNOR2.单个神经元如何计算…

刚刚,科学家发现了一大堆解释人类进化的基因...

图片来源:《Nature Genetics》来源:中国生物技术网 5月27日发表在《Nature Genetics》上的一项新研究发现, 以前被认为在不同生物体中具有相似作用的数十种基因,实际上是人类独有的, 这或许有助于解释我们这个物种是如…

Android数据存储——SQLite数据库(模板)

本篇整合Android使用数据库,要保存一个实体类的样本。 首先看一下数据库语句: ORM:关系对象映射 添加数据: ContentValues values new ContentValues();values.put("name", "小丽");values.put("phone", &qu…

Python切分文本(将文本文档切分为词列表)

对于一个句子,一种简单的方法是使用split() a This is an apple. Do you like apple? b a.split() print(b) # [This, is, an, apple., Do, you, like, apple?] 可以看到切分结果不错,但标点符号也当成了词的一部分,可以使用正则表达式…

吴恩达《机器学习》学习笔记八——逻辑回归(多分类)代码

吴恩达《机器学习》笔记八——逻辑回归(多分类)代码导入模块及加载数据sigmoid函数与假设函数代价函数梯度下降一对多分类预测验证课程链接:https://www.bilibili.com/video/BV164411b7dx?fromsearch&seid5329376196520099118 之前笔记…

DeepMind 综述深度强化学习:智能体和人类相似度竟然如此高!

来源:AI科技评论近年来,深度强化学习(Deep reinforcement learning)方法在人工智能方面取得了瞩目的成就,从 Atari 游戏、到围棋、再到无限制扑克等领域,AI 的表现都大大超越了专业选手,这一进展…

Python随机选择一部分训练样本作为测试样本

假设训练样本有30个,从训练样本中随机获得10个作为测试样本,剩下20个继续作为训练样本 import numpy as nptrainingSet list(range(30)) # 训练样本下标 testSet [] for i in range(10):randIndex int(np.random.uniform(0, len(training…