TenorFlow多层感知机识别手写体

文章目录

  • 数据准备
  • 建立模型
          • 建立输入层 x
          • 建立隐藏层h1
          • 建立隐藏层h2
          • 建立输出层
  • 定义训练方式
          • 建立训练数据label真实值 placeholder
          • 定义loss function
          • 选择optimizer
  • 定义评估模型的准确率
          • 计算每一项数据是否正确预测
          • 将计算预测正确结果,加总平均
  • 开始训练
          • 画出误差执行结果
          • 画出准确率执行结果
  • 评估模型的准确率
  • 进行预测
  • 找出预测错误

GITHUB地址https://github.com/fz861062923/TensorFlow
注意下载数据连接的是外网,有一股神秘力量让你403

数据准备

import tensorflow as tf
import tensorflow.examples.tutorials.mnist.input_data as input_datamnist = input_data.read_data_sets("MNIST_data/", one_hot=True)
C:\Users\admin\AppData\Local\conda\conda\envs\tensorflow\lib\site-packages\h5py\__init__.py:36: FutureWarning: Conversion of the second argument of issubdtype from `float` to `np.floating` is deprecated. In future, it will be treated as `np.float64 == np.dtype(float).type`.from ._conv import register_converters as _register_convertersWARNING:tensorflow:From <ipython-input-1-2ee827ab903d>:4: read_data_sets (from tensorflow.contrib.learn.python.learn.datasets.mnist) is deprecated and will be removed in a future version.
Instructions for updating:
Please use alternatives such as official/mnist/dataset.py from tensorflow/models.
WARNING:tensorflow:From C:\Users\admin\AppData\Local\conda\conda\envs\tensorflow\lib\site-packages\tensorflow\contrib\learn\python\learn\datasets\mnist.py:260: maybe_download (from tensorflow.contrib.learn.python.learn.datasets.base) is deprecated and will be removed in a future version.
Instructions for updating:
Please write your own downloading logic.
WARNING:tensorflow:From C:\Users\admin\AppData\Local\conda\conda\envs\tensorflow\lib\site-packages\tensorflow\contrib\learn\python\learn\datasets\mnist.py:262: extract_images (from tensorflow.contrib.learn.python.learn.datasets.mnist) is deprecated and will be removed in a future version.
Instructions for updating:
Please use tf.data to implement this functionality.
Extracting MNIST_data/train-images-idx3-ubyte.gz
WARNING:tensorflow:From C:\Users\admin\AppData\Local\conda\conda\envs\tensorflow\lib\site-packages\tensorflow\contrib\learn\python\learn\datasets\mnist.py:267: extract_labels (from tensorflow.contrib.learn.python.learn.datasets.mnist) is deprecated and will be removed in a future version.
Instructions for updating:
Please use tf.data to implement this functionality.
Extracting MNIST_data/train-labels-idx1-ubyte.gz
WARNING:tensorflow:From C:\Users\admin\AppData\Local\conda\conda\envs\tensorflow\lib\site-packages\tensorflow\contrib\learn\python\learn\datasets\mnist.py:110: dense_to_one_hot (from tensorflow.contrib.learn.python.learn.datasets.mnist) is deprecated and will be removed in a future version.
Instructions for updating:
Please use tf.one_hot on tensors.
WARNING:tensorflow:From C:\Users\admin\AppData\Local\conda\conda\envs\tensorflow\lib\site-packages\tensorflow\contrib\learn\python\learn\datasets\base.py:252: _internal_retry.<locals>.wrap.<locals>.wrapped_fn (from tensorflow.contrib.learn.python.learn.datasets.base) is deprecated and will be removed in a future version.
Instructions for updating:
Please use urllib or similar directly.
Successfully downloaded t10k-images-idx3-ubyte.gz 1648877 bytes.
Extracting MNIST_data/t10k-images-idx3-ubyte.gz
Successfully downloaded t10k-labels-idx1-ubyte.gz 4542 bytes.
Extracting MNIST_data/t10k-labels-idx1-ubyte.gz
WARNING:tensorflow:From C:\Users\admin\AppData\Local\conda\conda\envs\tensorflow\lib\site-packages\tensorflow\contrib\learn\python\learn\datasets\mnist.py:290: DataSet.__init__ (from tensorflow.contrib.learn.python.learn.datasets.mnist) is deprecated and will be removed in a future version.
Instructions for updating:
Please use alternatives such as official/mnist/dataset.py from tensorflow/models.
print('train images     :', mnist.train.images.shape,'labels:'           , mnist.train.labels.shape)
print('validation images:', mnist.validation.images.shape,' labels:'          , mnist.validation.labels.shape)
print('test images      :', mnist.test.images.shape,'labels:'           , mnist.test.labels.shape)
train images     : (55000, 784) labels: (55000, 10)
validation images: (5000, 784)  labels: (5000, 10)
test images      : (10000, 784) labels: (10000, 10)

建立模型

def layer(output_dim,input_dim,inputs, activation=None):#激活函数默认为NoneW = tf.Variable(tf.random_normal([input_dim, output_dim]))#以正态分布的随机数建立并且初始化权重Wb = tf.Variable(tf.random_normal([1, output_dim]))XWb = tf.matmul(inputs, W) + bif activation is None:outputs = XWbelse:outputs = activation(XWb)return outputs
建立输入层 x
x = tf.placeholder("float", [None, 784])
建立隐藏层h1
h1=layer(output_dim=1000,input_dim=784,inputs=x ,activation=tf.nn.relu)  
建立隐藏层h2
h2=layer(output_dim=1000,input_dim=1000,inputs=h1 ,activation=tf.nn.relu)  
建立输出层
y_predict=layer(output_dim=10,input_dim=1000,inputs=h2,activation=None)

定义训练方式

建立训练数据label真实值 placeholder
y_label = tf.placeholder("float", [None, 10])#训练数据的个数很多所以设置为None
定义loss function
# 深度学习模型的训练中使用交叉熵训练的效果比较好
loss_function = tf.reduce_mean(tf.nn.softmax_cross_entropy_with_logits_v2(logits=y_predict , labels=y_label))
选择optimizer
optimizer = tf.train.AdamOptimizer(learning_rate=0.001) \.minimize(loss_function)
#使用Loss_function来计算误差,并且按照误差更新模型权重与偏差,使误差最小化

定义评估模型的准确率

计算每一项数据是否正确预测
correct_prediction = tf.equal(tf.argmax(y_label  , 1),tf.argmax(y_predict, 1))#将one-hot encoding转化为1所在的位数,方便比较
将计算预测正确结果,加总平均
accuracy = tf.reduce_mean(tf.cast(correct_prediction, "float"))

开始训练

trainEpochs = 15#执行15个训练周期
batchSize = 100#每一批的数量为100
totalBatchs = int(mnist.train.num_examples/batchSize)#计算每一个训练周期应该执行的次数
epoch_list=[];accuracy_list=[];loss_list=[];
from time import time
startTime=time()
sess = tf.Session()
sess.run(tf.global_variables_initializer())
for epoch in range(trainEpochs):#执行15个训练周期#每个训练周期执行550批次训练for i in range(totalBatchs):batch_x, batch_y = mnist.train.next_batch(batchSize)#用该函数批次读取数据sess.run(optimizer,feed_dict={x: batch_x,y_label: batch_y})#使用验证数据计算准确率loss,acc = sess.run([loss_function,accuracy],feed_dict={x: mnist.validation.images, #验证数据的featuresy_label: mnist.validation.labels})#验证数据的labelepoch_list.append(epoch)loss_list.append(loss);accuracy_list.append(acc)    print("Train Epoch:", '%02d' % (epoch+1), \"Loss=","{:.9f}".format(loss)," Accuracy=",acc)duration =time()-startTime
print("Train Finished takes:",duration)        
Train Epoch: 01 Loss= 133.117172241  Accuracy= 0.9194
Train Epoch: 02 Loss= 88.949943542  Accuracy= 0.9392
Train Epoch: 03 Loss= 80.701606750  Accuracy= 0.9446
Train Epoch: 04 Loss= 72.045913696  Accuracy= 0.9506
Train Epoch: 05 Loss= 71.911483765  Accuracy= 0.9502
Train Epoch: 06 Loss= 63.642936707  Accuracy= 0.9558
Train Epoch: 07 Loss= 67.192626953  Accuracy= 0.9494
Train Epoch: 08 Loss= 55.959281921  Accuracy= 0.9618
Train Epoch: 09 Loss= 58.867351532  Accuracy= 0.9592
Train Epoch: 10 Loss= 61.904548645  Accuracy= 0.9612
Train Epoch: 11 Loss= 58.283069611  Accuracy= 0.9608
Train Epoch: 12 Loss= 54.332244873  Accuracy= 0.9646
Train Epoch: 13 Loss= 58.152175903  Accuracy= 0.9624
Train Epoch: 14 Loss= 51.552104950  Accuracy= 0.9688
Train Epoch: 15 Loss= 52.803482056  Accuracy= 0.9678
Train Finished takes: 545.0556836128235
画出误差执行结果
%matplotlib inline
import matplotlib.pyplot as plt
fig = plt.gcf()#获取当前的figure图
fig.set_size_inches(4,2)#设置图的大小
plt.plot(epoch_list, loss_list, label = 'loss')
plt.ylabel('loss')
plt.xlabel('epoch')
plt.legend(['loss'], loc='upper left')
<matplotlib.legend.Legend at 0x1edb8d4c240>

外链图片转存失败,源站可能有防盗链机制,建议将图片保存下来直接上传

画出准确率执行结果
plt.plot(epoch_list, accuracy_list,label="accuracy" )
fig = plt.gcf()
fig.set_size_inches(4,2)
plt.ylim(0.8,1)
plt.ylabel('accuracy')
plt.xlabel('epoch')
plt.legend()
plt.show()

外链图片转存失败,源站可能有防盗链机制,建议将图片保存下来直接上传

评估模型的准确率

print("Accuracy:", sess.run(accuracy,feed_dict={x: mnist.test.images, y_label: mnist.test.labels}))
Accuracy: 0.9643

进行预测

prediction_result=sess.run(tf.argmax(y_predict,1),feed_dict={x: mnist.test.images })
prediction_result[:10]
array([7, 2, 1, 0, 4, 1, 4, 9, 6, 9], dtype=int64)
import matplotlib.pyplot as plt
import numpy as np
def plot_images_labels_prediction(images,labels,prediction,idx,num=10):fig = plt.gcf()fig.set_size_inches(12, 14)if num>25: num=25 for i in range(0, num):ax=plt.subplot(5,5, 1+i)ax.imshow(np.reshape(images[idx],(28, 28)), cmap='binary')title= "label=" +str(np.argmax(labels[idx]))if len(prediction)>0:title+=",predict="+str(prediction[idx]) ax.set_title(title,fontsize=10) ax.set_xticks([]);ax.set_yticks([])        idx+=1 plt.show()
plot_images_labels_prediction(mnist.test.images,mnist.test.labels,prediction_result,0)

外链图片转存失败,源站可能有防盗链机制,建议将图片保存下来直接上传

y_predict_Onehot=sess.run(y_predict,feed_dict={x: mnist.test.images })
y_predict_Onehot[8]
array([-6185.544  , -5329.589  ,  1897.1707 , -3942.7764 ,   347.9809 ,5513.258  ,  6735.7153 , -5088.5273 ,   649.2062 ,    69.50408],dtype=float32)

找出预测错误

for i in range(400):if prediction_result[i]!=np.argmax(mnist.test.labels[i]):print("i="+str(i)+"   label=",np.argmax(mnist.test.labels[i]),"predict=",prediction_result[i])
i=8   label= 5 predict= 6
i=18   label= 3 predict= 8
i=149   label= 2 predict= 4
i=151   label= 9 predict= 8
i=233   label= 8 predict= 7
i=241   label= 9 predict= 8
i=245   label= 3 predict= 5
i=247   label= 4 predict= 2
i=259   label= 6 predict= 0
i=320   label= 9 predict= 1
i=340   label= 5 predict= 3
i=381   label= 3 predict= 7
i=386   label= 6 predict= 5
sess.close()

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.mzph.cn/news/685555.shtml

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

倒计时52天(待续,,,

寒假学习内容总复习上&#xff1a; 倒计时67天-CSDN博客 1. #include<bits/stdc.h> using namespace std; #define int long long const int N2e56; const int inf0x3f3f3f3f; int month[13]{0,31,28,31,30,31,30,31,31,30,31,30,31}; int a[110]; void solve() {for(…

【超级干货】ArcGIS_空间连接_工具详解

帮助里对空间连接的解释&#xff1a; 根据空间关系将一个要素的属性连接到另一个要素。 目标要素和来自连接要素的被连接属性写入到输出要素类。 如上图所示&#xff0c;关键在于空间关系&#xff0c;只有当两个要素存在空间关系的时候&#xff0c;空间连接才有用武之地。 一…

JavaScript_00001_00000

contents 简介变量与数据类型自动类型转换强制类型转换 简介 变量与数据类型 根据变量定义的范围不同&#xff0c;变量有全局变量和局部变量之分。直接定义的变量是全局变量&#xff0c;全局变量可以被所有的脚本访问&#xff1b;在函数里定义的变量称为局部变量&#xff0c;…

【leetcode热题】对称二叉树

难度&#xff1a; 简单通过率&#xff1a; 42.2%题目链接&#xff1a;力扣&#xff08;LeetCode&#xff09;官网 - 全球极客挚爱的技术成长平台 题目描述 给定一个二叉树&#xff0c;检查它是否是镜像对称的。 例如&#xff0c;二叉树 [1,2,2,3,4,4,3] 是对称的。 1/ \2 …

JavaWeb之Servlet接口

Servlet接口 什么是Servlet&#xff1f; Servlet是一种基于Java技术的Web组件&#xff0c;用于生成动态内容&#xff0c;由容器管理&#xff0c;是平台无关的Java类组成&#xff0c;并且由Java Web服务器加载执行&#xff0c;是Web容器的最基本组成单元 什么是Servlet容器&…

杂文随笔_

己写于亥年正月廿六&#xff0c;校图书馆&#xff0c;天气阴 “你的工作将会是你生活中很大一部分&#xff0c;唯一能使自己得到真正满足的是&#xff0c;做你伟大的工作,做一份伟大的工作的唯一方法是&#xff0c;热爱你所做的工作。“这是乔布斯在斯坦福大学的一次演讲所说的…

【c++】list详细讲解

> 作者简介&#xff1a;დ旧言~&#xff0c;目前大二&#xff0c;现在学习Java&#xff0c;c&#xff0c;c&#xff0c;Python等 > 座右铭&#xff1a;松树千年终是朽&#xff0c;槿花一日自为荣。 > 目标&#xff1a;熟悉list库 > 毒鸡汤&#xff1a;你的脸上云淡…

外部中断0实验

实现现象&#xff1a;下载程序后&#xff0c;操作K3按键使LED1&#xff08;D11&#xff09;状态取反 注意事项&#xff1a;无。 #include "reg52.h" //此文件中定义了单片机的一些特殊功能寄存器typedef unsigned int u16; //对数据类型进行声明定义 typed…

定时器1中断实验

实现现象&#xff1a;下载程序后&#xff0c;静态数码管间隔一秒循环显示0-F。使用单片机内部定时器可以实现准确延时。 注意事项&#xff1a; 程序代码&#xff1a; #include "reg52.h" //此文件中定义了单片机的一些特殊功能寄存器typedef unsigned int u16; …

Java环境变量

&#xff08;1&#xff09;classpath:让jvm找到将要执行的java文件的目录 它通常包括程序当前目录和Java标准库的路径. 编译后&#xff0c;classpath的当前目录就是target下的classes目录&#xff0c;它是resource和main目录的合并&#xff0c;如果两个类目录相同&#xff0c;那…

关于Build Your Own Botnet的尝试

这是一次失败的尝试、 原文地址&#xff1a;关于Build Your Own Botnet的尝试 - Pleasure的博客 下面是正文内容&#xff1a; 前言 我在上一篇关于DDOS的文章种提到过这个项目&#xff0c;而且说明了由于这个项目是在2020年发布并开源的&#xff0c;并且已经有两年没有进行跟…

身份治理存在权限问题

身份治理正迅速成为 CISO 的首要考虑因素。二十年前&#xff0c;当萨班斯-奥克斯利法案(SoX) 和其他监管指令在互联网泡沫破灭后诞生时&#xff0c;身份治理要求就出现了。合规性控制&#xff0c;例如用户访问审查和有效管理员工访问生命周期的需要&#xff0c;是当时身份治理的…

1.2.1 相机模型—内参、外参

相机模型-内参、外参 更多内容&#xff0c;请关注&#xff1a; github&#xff1a;https://github.com/gotonote/Autopilot-Notes.git&#xff09; 针孔相机模型&#xff0c;包含四个坐标系&#xff1a;物理成像坐标系、像素坐标系、相机坐标系、世界坐标系。 相机参数包含&…

typescript类型详解

因为介绍了ts的全部类型,所以比较长,各位可以通过目录选择性观看 typescript类型概述typescript 类型注解概念-->监测类型变化 ts类型注解语法ts常用类型原始类型对象类型对象类型_数组类型 ts新增,联合类型ts函数类型ts 函数类型 voidts 函数类型可选参数 ts 对象类型ts 可…

The method toList() is undefined for the type Stream

The method toList() is undefined for the type Stream &#xff08;JDK16&#xff09; default List<T> toList() { return (List<T>) Collections.unmodifiableList(new ArrayList<>(Arrays.asList(this.toArray()))); }

Leetcode 503. 下一个更大元素 II

题意理解&#xff1a; 给定一个循环数组 nums &#xff08; nums[nums.length - 1] 的下一个元素是 nums[0] &#xff09;&#xff0c;返回 nums 中每个元素的 下一个更大元素 。 数字 x 的 下一个更大的元素 是按数组遍历顺序&#xff0c;这个数字之后的第一个比它更大的数&am…

C#系列-EF扩展框架Serilog.EntityFrameworkCore应用实例(39)

Serilog.EntityFrameworkCore 并不是一个官方或广泛认可的 NuGet 包。Serilog 是一个流行的日志记录库&#xff0c;它支持多种日志接收器&#xff08;sinks&#xff09;来将日志输出到不同的目的地&#xff0c;如文件、控制台、数据库等。但是&#xff0c;Serilog.EntityFramew…

作物模型狂奔:WOFOST(PCSE) 数据同化思路

去B吧&#xff0c;这里没图 整体思路&#xff1a;PCSE -》 敏感性分析 -》调参 -》同化 0、准备工作 0.0 电脑环境 我用的Win10啦&#xff0c;Linux、Mac可能得自己再去微调一下。 0.1 Python IDE 我用的Pycharm&#xff0c;个人感觉最好使的IDE&#xff0c;没有之一。 …

C#系列-EF框架的创新应用+利用EF框架技术的知名开源应用项目(42)

EF框架的创新应用 EF框架&#xff0c;即Entity Framework&#xff0c;是微软开发的一个开源的对象关系映射&#xff08;ORM&#xff09;框架&#xff0c;用于.NET应用程序中。它允许开发者以面向对象的方式处理数据库&#xff0c;而无需关心底层的SQL语句和数据库结构。 EF框架…

OpenAI Sora 初体验

OpenAI Sora 初体验 就在刚刚&#xff0c;OpenAI 再次投下一枚重磅炸弹——Sora&#xff0c;一个文本到视频生成模型。 我第一时间体验了 Sora。看过 Sora 的能力后&#xff0c;我真的印象深刻。对细节的关注、无缝的角色刻画以及生成视频的绝对质量真正将可能性提升到了一个新…