【TensorFlow篇】--Tensorflow框架实现SoftMax模型识别手写数字集

一、前述

本文讲述用Tensorflow框架实现SoftMax模型识别手写数字集,来实现多分类。

同时对模型的保存和恢复做下示例。

二、具体原理

代码一:实现代码

#!/usr/bin/python
# -*- coding: UTF-8 -*-
# 文件名: 12_Softmax_regression.pyfrom tensorflow.examples.tutorials.mnist import input_data
import tensorflow as tf# mn.SOURCE_URL = "http://yann.lecun.com/exdb/mnist/"
my_mnist = input_data.read_data_sets("MNIST_data_bak/", one_hot=True)#从本地路径加载进来# The MNIST data is split into three parts:
# 55,000 data points of training data (mnist.train)#训练集图片
# 10,000 points of test data (mnist.test), and#测试集图片
# 5,000 points of validation data (mnist.validation).#验证集图片# Each image is 28 pixels by 28 pixels# 输入的是一堆图片,None表示不限输入条数,784表示每张图片都是一个784个像素值的一维向量
# 所以输入的矩阵是None乘以784二维矩阵
x = tf.placeholder(dtype=tf.float32, shape=(None, 784)) #x矩阵是m行*784列
# 初始化都是0,二维矩阵784乘以10个W值 #初始值最好不为0
W = tf.Variable(tf.zeros([784, 10]))#W矩阵是784行*10列
b = tf.Variable(tf.zeros([10]))#bias也必须有10个

y = tf.nn.softmax(tf.matmul(x, W) + b)# x*w 即为m行10列的矩阵就是y #预测值# 训练
# labels是每张图片都对应一个one-hot的10个值的向量
y_ = tf.placeholder(dtype=tf.float32, shape=(None, 10))#真实值 m行10列
# 定义损失函数,交叉熵损失函数
# 对于多分类问题,通常使用交叉熵损失函数
# reduction_indices等价于axis,指明按照每行加,还是按照每列加
cross_entropy = tf.reduce_mean(-tf.reduce_sum(y_ * tf.log(y),reduction_indices=[1]))#指明按照列加和 一列是一个类别
train_step = tf.train.GradientDescentOptimizer(0.5).minimize(cross_entropy)#将损失函数梯度下降 #0.5是学习率# 初始化变量
sess = tf.InteractiveSession()#初始化Session
tf.global_variables_initializer().run()#初始化所有变量
for _ in range(1000):batch_xs, batch_ys = my_mnist.train.next_batch(100)#每次迭代取100行数据sess.run(train_step, feed_dict={x: batch_xs, y_: batch_ys})
#每次迭代内部就是求梯度,然后更新参数
# 评估# tf.argmax()是一个从tensor中寻找最大值的序号 就是分类号,tf.argmax就是求各个预测的数字中概率最大的那一个
correct_prediction = tf.equal(tf.argmax(y, 1), tf.argmax(y_, 1))# 用tf.cast将之前correct_prediction输出的bool值转换为float32,再求平均
accuracy = tf.reduce_mean(tf.cast(correct_prediction, tf.float32))# 测试
print(accuracy.eval({x: my_mnist.test.images, y_: my_mnist.test.labels}))# 总结
# 1,定义算法公式,也就是神经网络forward时的计算
# 2,定义loss,选定优化器,并指定优化器优化loss
# 3,迭代地对数据进行训练
# 4,在测试集或验证集上对准确率进行评测

代码二:保存模型

# 有时候需要把模型保持起来,有时候需要做一些checkpoint在训练中
# 以致于如果计算机宕机,我们还可以从之前checkpoint的位置去继续
# TensorFlow使得我们去保存和加载模型非常方便,仅需要去创建Saver节点在构建阶段最后
# 然后在计算阶段去调用save()方法from tensorflow.examples.tutorials.mnist import input_data
import tensorflow as tf
# mn.SOURCE_URL = "http://yann.lecun.com/exdb/mnist/"
my_mnist = input_data.read_data_sets("MNIST_data_bak/", one_hot=True)# The MNIST data is split into three parts:
# 55,000 data points of training data (mnist.train)
# 10,000 points of test data (mnist.test), and
# 5,000 points of validation data (mnist.validation).# Each image is 28 pixels by 28 pixels# 输入的是一堆图片,None表示不限输入条数,784表示每张图片都是一个784个像素值的一维向量
# 所以输入的矩阵是None乘以784二维矩阵
x = tf.placeholder(dtype=tf.float32, shape=(None, 784))
# 初始化都是0,二维矩阵784乘以10个W值
W = tf.Variable(tf.zeros([784, 10]))
b = tf.Variable(tf.zeros([10]))y = tf.nn.softmax(tf.matmul(x, W) + b)# 训练
# labels是每张图片都对应一个one-hot的10个值的向量
y_ = tf.placeholder(dtype=tf.float32, shape=(None, 10))
# 定义损失函数,交叉熵损失函数
# 对于多分类问题,通常使用交叉熵损失函数
# reduction_indices等价于axis,指明按照每行加,还是按照每列加
cross_entropy = tf.reduce_mean(-tf.reduce_sum(y_ * tf.log(y),reduction_indices=[1]))
train_step = tf.train.GradientDescentOptimizer(0.5).minimize(cross_entropy)# 初始化变量
init = tf.global_variables_initializer()
# 创建Saver()节点
saver = tf.train.Saver()#在运算之前,初始化之后

n_epoch = 1000with tf.Session() as sess:sess.run(init)for epoch in range(n_epoch):if epoch % 100 == 0:save_path = saver.save(sess, "./my_model.ckpt")#每跑100次save一次模型,可以保证容错性#直接保存session即可。
batch_xs, batch_ys = my_mnist.train.next_batch(100)#每一批次跑的数据 用m行数据/迭代次数来计算出来。sess.run(train_step, feed_dict={x: batch_xs, y_: batch_ys})best_theta = W.eval()save_path = saver.save(sess, "./my_model_final.ckpt")#保存最后的模型,session实际上保存的上面所有的数据

代码三:恢复模型

from tensorflow.examples.tutorials.mnist import input_data
import tensorflow as tf
# mn.SOURCE_URL = "http://yann.lecun.com/exdb/mnist/"
my_mnist = input_data.read_data_sets("MNIST_data_bak/", one_hot=True)# The MNIST data is split into three parts:
# 55,000 data points of training data (mnist.train)
# 10,000 points of test data (mnist.test), and
# 5,000 points of validation data (mnist.validation).# Each image is 28 pixels by 28 pixels# 输入的是一堆图片,None表示不限输入条数,784表示每张图片都是一个784个像素值的一维向量
# 所以输入的矩阵是None乘以784二维矩阵
x = tf.placeholder(dtype=tf.float32, shape=(None, 784))
# 初始化都是0,二维矩阵784乘以10个W值
W = tf.Variable(tf.zeros([784, 10]))
b = tf.Variable(tf.zeros([10]))y = tf.nn.softmax(tf.matmul(x, W) + b)
# labels是每张图片都对应一个one-hot的10个值的向量
y_ = tf.placeholder(dtype=tf.float32, shape=(None, 10))saver = tf.train.Saver()with tf.Session() as sess: saver.restore(sess, "./my_model_final.ckpt")#把路径下面所有的session的数据加载进来 y y_head还有模型都保存下来了。# 评估# tf.argmax()是一个从tensor中寻找最大值的序号,tf.argmax就是求各个预测的数字中概率最大的那一个correct_prediction = tf.equal(tf.argmax(y, 1), tf.argmax(y_, 1))# 用tf.cast将之前correct_prediction输出的bool值转换为float32,再求平均accuracy = tf.reduce_mean(tf.cast(correct_prediction, tf.float32))# 测试print(accuracy.eval({x: my_mnist.test.images, y_: my_mnist.test.labels}))

转载于:https://www.cnblogs.com/LHWorldBlog/p/8661434.html

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.mzph.cn/news/278228.shtml

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

web页面锁屏初级尝试

因为工作需要&#xff0c;所以在网上找了一些素材来弄这个功能。在我找到的素材中&#xff0c;大多都是不完善的。虽然我的也不是很完善&#xff0c;但是怎么说呢。要求不是很高的话。可以直接拿来用的【需要引用jQuery】。废话不多说直接上代码 这部分是js代码 1 <script&g…

Java 并发工具箱之concurrent包

概述 java.util.concurrent 包是专为 Java并发编程而设计的包。包下的所有类可以分为如下几大类&#xff1a; locks部分&#xff1a;显式锁(互斥锁和速写锁)相关&#xff1b;atomic部分&#xff1a;原子变量类相关&#xff0c;是构建非阻塞算法的基础&#xff1b;executor部分&…

如何提高gps精度_如何在锻炼应用程序中提高GPS跟踪精度

如何提高gps精度l i g h t p o e t/Shutterstocklightpoet /快门Tracking your runs, bike rides, and other workouts is fun because you can see how much you’re improving (or, in my case, dismally failing to improve). For it to be effective, though, you have to …

centos proftp_在CentOS上禁用ProFTP

centos proftpI realize this is probably only relevant to about 3 of the readers, but I’m posting this so I don’t forget how to do it myself! In my efforts to ban the completely insecure FTP protocol from my life entirely, I’ve decided to disable the FTP…

Java通过Executors提供四种线程池

http://cuisuqiang.iteye.com/blog/2019372 Java通过Executors提供四种线程池&#xff0c;分别为&#xff1a;newCachedThreadPool创建一个可缓存线程池&#xff0c;如果线程池长度超过处理需要&#xff0c;可灵活回收空闲线程&#xff0c;若无可回收&#xff0c;则新建线程。n…

一个在线编写前端代码的好玩的工具

https://codesandbox.io/ 可以编写 Angular&#xff0c;React&#xff0c;Vue 等前端代码。 可以实时编辑和 preview。 live 功能&#xff0c;可以多人协作编辑&#xff0c;不过是收费的功能。 可以增加依赖的包&#xff0c;比如编写 React 时&#xff0c;可以安装任意的第三…

MySQL数据库基础(五)——SQL查询

MySQL数据库基础&#xff08;五&#xff09;——SQL查询 一、单表查询 1、查询所有字段 在SELECT语句中使用星号“”通配符查询所有字段在SELECT语句中指定所有字段select from TStudent; 2、查询指定字段 查询多个字段select Sname,sex,email from TStudent; 3、查询指定记录…

使用生成器创建新的迭代模式

一个函数中需要有一个 yield 语句即可将其转换为一个生成器。 def frange(start, stop, increment):x startwhile x < stop:yield xx incrementfor i in frange(0, 4, 2):print(i) # 0 2 一个生成器函数主要特征是它只会回应在迭代中使用到的 next 操作 def cutdata(n):p…

前端异常捕获与上报

在一般情况下我们代码报错啥的都会觉得 下图 然后现在来说下经常用的异常 1.try catch 这个是比较常见的异常捕获方式通常都是 使用try catch能够很好的捕获异常并对应进行相应处理&#xff0c;不至于让页面挂掉&#xff0c;但是其存在一些弊端&#xff0c;比如需要在捕获异常的…

Codeforces 924D Contact ATC (看题解)

Contact ATC 我跑去列方程&#xff0c; 然后就gg了。。。 我们计每个飞机最早到达时间为L[ i ], 最晚到达时间为R[ i ]&#xff0c; 对于面对面飞行的一对飞机&#xff0c; 只要他们的时间有交集则必定满足条件。 对于相同方向飞行的飞机&#xff0c; 只有其中一个的时间包含另…

基于ZXing Android实现生成二维码图片和相机扫描二维码图片即时解码的功能

NextQRCode ZXing开源库的精简版 **基于ZXing Android实现生成二维码图片和相机扫描二维码图片即时解码的功能原文博客 附源码下载地址** 与原ZXingMini项目对比 NextQRCode做了重大架构修改&#xff0c;原ZXingMini项目与当前NextQRCode不兼容 dependencies {compile com.gith…

flask sqlalchemy 单表查询

主要内容: 1 sqlalchemy: 一个python的ORM框架 2 使用sqlalchemy 的流程: 创建一个类 创建数据库引擎 将所有的类序列化成数据表 进行增删改查操作 # 1.创建一个 Class from sqlalchemy.ext.declarative import declarative_base Base declarative_base() # Base 是 ORM模型 基…

如何在Windows 7或Vista上安装IIS

If you are a developer using ASP.NET, one of the first things you’ll want to install on Windows 7 or Vista is IIS (internet information server). Keep in mind that your version of Windows may not come with IIS. I’m using Windows 7 Ultimate edition. 如果您…

Dubbo的使用及原理浅析

https://www.cnblogs.com/wang-meng/p/5791598.html转载于:https://www.cnblogs.com/h-wt/p/10490345.html

ThinkPHP3.2 实现阿里云OSS上传文件

为什么80%的码农都做不了架构师&#xff1f;>>> 0、配置文件Config&#xff0c;加入OSS配置选项&#xff0c;设置php.ini最大上传大小&#xff08;自行解决&#xff0c;这里不做演示&#xff09; OSS > array(ACCESS_KEY_ID > **************, //从OSS获得的…

ipad和iphone切图_如何在iPhone,iPad和Mac上签名PDF

ipad和iphone切图Khamosh PathakKhamosh PathakDo you have documents to sign? You don’t need to worry about printing, scanning, or even downloading a third-party app. You can sign PDFs right on your iPhone, iPad, and Mac. 你有文件要签名吗&#xff1f; 您无需…

一个页面上有大量的图片(大型电商网站),加载很慢,你有哪些方法优化这些图片的加载,给用户更好的体验。...

a. 图片懒加载&#xff0c;滚动到相应位置才加载图片。 b. 图片预加载&#xff0c;如果为幻灯片、相册等&#xff0c;将当前展示图片的前一张和后一张优先下载。 c. 使用CSSsprite&#xff0c;SVGsprite&#xff0c;Iconfont、Base64等技术&#xff0c;如果图片为css图片的话。…

[function.require]: Failed opening required 杰奇cms

在配置杰奇cms移动端的时候&#xff0c;出现了[function.require]: Failed opening required 不要慌&#xff0c;百度一下即可解决。这个就是权限问题。由于移动端要请求pc端的文件&#xff0c;没权限。加上一个iis_iusrs读写权限即可搞定&#xff01;转载于:https://www.cnblo…

在Ubuntu服务器上打开第二个控制台会话

Ubuntu Server has the native ability to run multiple console sessions from the server console prompt. If you are working on the actual console and are waiting for a long running command to finish, there’s no reason why you have to sit and wait… you can j…

Cloudstack系统配置(三)

系统配置 CloudStack提供一个基于web的UI&#xff0c;管理员和终端用户能够使用这个界面。用户界面版本依赖于登陆时使用的凭证不同而不同。用户界面是适用于大多数流行的浏览器包括IE7,IE8,IE9,Firefox Chrome等。URL是:(用你自己的管理控制服务器IP地址代替) 1http://<ma…