08-KNN手写数字识别

标签下载地址

文件内容备注
train-images-idx3-ubyte.gz训练集图片:55000张训练图片,5000张验证图片
train-labels-idx1-ubyte.gz训练集图片对应的数字标签
t10k-images-idx3-ubyte.gz测试集图片:10000张图片t表示test,测试图片,10k表示10*1000一共一万张图片
t10k-labels-idx1-ubyte.gz测试集图片对应的数字标签

对于每一个样本都有一个对应的标签进行唯一的标识,故为一个监督学习
操作的每个图片必须是灰度图(单通道0是白色,1是黑色)
对于标签5401
在这里插入图片描述
标签中的4,并不是存储4这个数字,而是存储十位(0-9),第五行为黑色,则为1,即0000100000,因为1所处于第5个,即描述为:4
在这里插入图片描述

KNN最近邻域法

KNN的根本原理:一张待检测的图片,与相应的样本进行比较,如果在样本图片中存在K个与待检测图片相类似的图片,那么就会把当前这K个图片记录下来。再在这K个图片中找到相似性最大的(例如10个图片中有8个描述的当前数字都是1,那么这个图片检测出来的就是1)

装载图片:
input_data.read_data_sets('MNIST_data',one_hot=True)
参数一:当前文件夹的名称
参数二:one_hot是个布尔类型,one_hot中有一个为1,其余都为0

随机获取训练数组的下标:
np.random.choice(trainNum,trainSize,replace=False)
参数一:随机值的范围
参数二:生成trainSize这么多个随机数
参数三:是否可以重复
在0-trainNum之间随机选取trainSize这么多个随机数,且不可重复

import tensorflow as tf
import numpy as np
import random 
from tensorflow.examples.tutorials.mnist import input_data
# load data 2 one_hot : 1 0000 1 fileName 
mnist = input_data.read_data_sets('E:\\Jupyter_workspace\\study\\DL\\MNIST_data',one_hot=True)#完成数据的装载,将装载的图片放入mnist中
# 属性设置
trainNum = 55000#总共需要训练多少张图片
testNum = 10000#测试图片
trainSize = 500#训练是需要多少张图片
testSize = 5#测试多少张图片
k = 4#从训练样本中找到K个与测试图片相近的图片,并且统计这K个图片中类别最多的几,并且把这个数作为最终的结果
# data 分解 1 trainSize   2范围0-trainNum 3 replace=False #数据的分解
#这里使用的是随机获取测试图片和训练图片的下标,故每次运行的结果都会不一样
trainIndex = np.random.choice(trainNum,trainSize,replace=False)#随机获取训练数组的下标
testIndex = np.random.choice(testNum,testSize,replace=False)#随机获取测试图片的标签下标
trainData = mnist.train.images[trainIndex]# 获取训练图片
trainLabel = mnist.train.labels[trainIndex]# 获取训练标签
testData = mnist.test.images[testIndex]# 获取测试的数据
testLabel = mnist.test.labels[testIndex]
print('trainData.shape=',trainData.shape)#训练数据的维度 500*784  500表示图片个数  图片的宽高为28*28 = 784,即图片上有784个像素点
print('trainLabel.shape=',trainLabel.shape)#训练标签的维度 500*10
print('testData.shape=',testData.shape)#测试数据的维度 5*784
print('testLabel.shape=',testLabel.shape)#测试标签的维度 5*10
print('testLabel=',testLabel)
#testLabel是个五行十列的数据,在标签中,所有的数据都放在数组中进行表示
'''
testLabel= [[0. 0. 0. 1. 0. 0. 0. 0. 0. 0.]    3--->testData [0][0. 1. 0. 0. 0. 0. 0. 0. 0. 0.]               1--->testData [1][0. 0. 0. 0. 0. 0. 0. 0. 0. 1.]               9--->testData [2][0. 0. 0. 0. 0. 0. 1. 0. 0. 0.]               6--->testData [3][0. 0. 0. 0. 1. 0. 0. 0. 0. 0.]]              4--->testData [4]
'''# tf input  784->image
trainDataInput = tf.placeholder(shape=[None,784],dtype=tf.float32)#定义训练的数组,784列的数据表示一张完整的图片,前面的行表示图片的个数这里用None表示
trainLabelInput = tf.placeholder(shape=[None,10],dtype=tf.float32)#列为10,因为每个数字都是10维的
testDataInput = tf.placeholder(shape=[None,784],dtype=tf.float32)#定义测试数据
testLabelInput = tf.placeholder(shape=[None,10],dtype=tf.float32)#定义测试标签#KNN的距离公式:
#knn distance 5*785.  5*1*784
# 5 500 784 (3D) 2500*784#计算trainData测试图片和trainData训练图片的距离之差,测试图片有5张,训练图片有500张,每个维度都是784维,故最后计算的结果为一个三维数据,(测试数据,训练数据,二者之差),会产生5*500*784个数据,故需要扩展testDataInput的维度f1 = tf.expand_dims(testDataInput,1) # 完成当前的维度转换,原本的testDataInput是一个5*785,经过维度转换则成为5*1*784  维度扩展
f2 = tf.subtract(trainDataInput,f1)# 完成测试图片与训练图片二者之差,得到的结果放入784维中,可以通过sum将这784维的差异累加到一块,即sum(784)
f3 = tf.reduce_sum(tf.abs(f2),reduction_indices=2)# 所有的数据都装载到f2中,因为有的距离是负数,需要取绝对值;设置在第二个维度上进行累加 即:完成数据累加取绝对值之后的784个像素点之间的差异 
#所有的差异距离都放入在放f3中,是个5*500数组f4 = tf.negative(f3)# 取反
f5,f6 = tf.nn.top_k(f4,k=4) # 选取f4中所有元素最大的四个值,因为f4是f3的取反,故选取f3中最小的四个数值
#f5为f3中最小的数,f6为这个最下的数所对应的下标# f6 index->trainLabelInput
#f6存储的是最近的图片的下标,通过这些下标作为索引去获取图片的标签
f7 = tf.gather(trainLabelInput,f6)#根据f6的下标来# f8 f9都是表示数字的获取# f8 num reduce_sum  reduction_indices=1 '竖直'
f8 = tf.reduce_sum(f7,reduction_indices=1)#完成数字的累加,将f7这个三维通过竖直的方向进行累加# tf.argmax 选取f8中,某一个最大的值,并记录其所处的下标index
f9 = tf.argmax(f8,dimension=1)#
# f9为5张测试图片中最大的下标 test5 image -> 5 num
with tf.Session() as sess:# f1 <- testData 5张图片p1 = sess.run(f1,feed_dict={testDataInput:testData[0:testSize]})#运行f1并给其一个参数,这个参数是testData测试图片,testData中总共有5张图片,这5张图片维待检测的手写数字print('p1=',p1.shape)# p1= (5, 1, 784) 每个图片必须用784维来表示p2 = sess.run(f2,feed_dict={trainDataInput:trainData,testDataInput:testData[0:testSize]})#运行f2 表示训练数据和测试二者对应数据做差print('p2=',p2.shape)#p2= (5, 500, 784) 例如:(1,100)表示第2张测试图片和第101张训练图片所有的像素对应做差都放入784中,784都为具体的值,故需要对784进行累加  p3 = sess.run(f3,feed_dict={trainDataInput:trainData,testDataInput:testData[0:testSize]})#print('p3=',p3.shape)#p3= (5, 500)表示(测试图片是哪一张,训练图片是哪一张)print('p3[0,0]=',p3[0,0]) #130.451表示第1张测试图片和第1张训练图片的距离差   knn distance p3[0,0]= 155.812p4 = sess.run(f4,feed_dict={trainDataInput:trainData,testDataInput:testData[0:testSize]})print('p4=',p4.shape)print('p4[0,0]',p4[0,0])p5,p6 = sess.run((f5,f6),feed_dict={trainDataInput:trainData,testDataInput:testData[0:testSize]})#p5= (5, 4) 每一张测试图片(5张)分别对应4张最近训练图片#p6= (5, 4)print('p5=',p5.shape)print('p6=',p6.shape)print('p5[0,0]',p5[0])# 第1张测试图片分别对应4张最近训练图片的值print('p6[0,0]',p6[0])# 第1张测试图片分别对应4张最近训练图片的下标p7 = sess.run(f7,feed_dict={trainDataInput:trainData,testDataInput:testData[0:testSize],trainLabelInput:trainLabel})print('p7=',p7.shape)#p7= (5, 4, 10)表示5组4行10列print('p7[]',p7)#5组表示5个测试图片,4行每行表示一个最近的测试图片,每一行中又有10个元素,这10个元素分别对应10个lable标签p8 = sess.run(f8,feed_dict={trainDataInput:trainData,testDataInput:testData[0:testSize],trainLabelInput:trainLabel})print('p8=',p8.shape)#p8=(5,10)print('p8[]=',p8)#5行10列,每一行为f7每一组所对应的竖直方向上的累加p9 = sess.run(f9,feed_dict={trainDataInput:trainData,testDataInput:testData[0:testSize],trainLabelInput:trainLabel})print('p9=',p9.shape)#p9=(5,)是一个一维数组,5列print('p9[]=',p9)#每一个元素表示p8中最大值所对应的下标p10 = np.argmax(testLabel[0:testSize],axis=1)#最终标签中的内容,统计一下第2个维度上的标签print('p10[]=',p10)#若p9和p10的内容相同,则检测概率为100%j = 0
for i in range(0,5):if p10[i] == p9[i]:j = j+1
print('ac=',j*100/testSize)

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.mzph.cn/news/378521.shtml

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

MFC odbc访问远程数据库

首先&#xff0c;MFC通过ODBC访问数据库&#xff0c;主要使用两个类&#xff0c;一个是CDataBase&#xff0c;一个是CRecordset。第一个是用于建立数据库连接的&#xff0c;第二个是数据集&#xff0c;用来查询的。步骤如下&#xff1a;1.实例化一个CDataBase对象&#xff0c;并…

微机原理——扩展存储器设计

目录【1】存储器的层次结构【2】存储器的分类【3】SRAM1、基本原理&#xff1a;2、结构&#xff1a;3、芯片参数与引脚解读&#xff1a;4、CPU与SRAM的连接方式【4】DRAM1、基本原理&#xff1a;2、结构3、芯片引脚解读&#xff1a;【5】存储器系统设计【6】存储器扩展设计&…

floatvalue 重写_Java Number floatValue()方法与示例

floatvalue 重写Number类floatValue()方法 (Number Class floatValue() method) floatValue() method is available in java.lang package. floatValue()方法在java.lang包中可用。 floatValue() method is used to return the value denoted by this Number object converted …

array_column php什么版本可以用,array_column兼容php5.5以下版本

gistfile1.txt// ----------------------------------------------------------------------// |获取二维数组中指定的一列&#xff0c;PHP5.5以后有专用函数array_column()// ----------------------------------------------------------------------// |param array $arr// …

。net学习之控件的使用注意点

jQuery使用 1、自定义属性的使用<script>$(#xwjj_i_main br[brinfoPd_KangQiao_Subject_Xwjj_br_1]).hide();</script> 2、ready代码块$(document).ready(function(){ //你的代码}); 3、简单的特效hide&#xff08;&#xff09;$("a").click(function()…

09-CNN手写数字识别

CNN卷积神经网络的本质就是卷积运算 维度的调整&#xff1a; tf.reshape(imageInput,[-1,28,28,1]) imageInput为[None,784]&#xff0c;N行* 784维 调整为 M28行28列*1通道 即&#xff1a;二维转化为四维数据 参数一&#xff1a;等价于运算结果M 参数二&#xff1a;28 28 表示…

【转】左值与右值

出处&#xff1a;http://www.embedded.com/electronics-blogs/programming-pointers/4023341/Lvalues-and-Rvalues C and C enforce subtle differences on the expressions to the left and right of the assignment operator If youve been programming in either C or C for…

Opencv将处理后的视频保存出现的问题

问题描述&#xff1a; 代码运行过程中&#xff0c;imshow出来的每帧的效果图是正确的&#xff0c;但是按照网上的方法保存下来却是0kb&#xff0c;打开不了。 参考的网上的一些方法&#xff0c;均是失败的&#xff0c;具体原因我也不清楚&#xff1a; 1、例如我这样设置&#x…

Java Number shortValue()方法与示例

Number类shortValue()方法 (Number Class shortValue() method) shortValue() method is available in java.lang package. shortValue()方法在java.lang包中可用。 shortValue() method is used to return the value denoted by this Number object converted to type short (…

MATLAB可以打开gms文件吗,gms文件扩展名,gms文件怎么打开?

.gms文件类型&#xff1a;Gesture and Motion Signal File扩展名为.gms的文件是一个数据文件。文件说明&#xff1a;Low-level, binary, minimal but generic format used to organize and store Gesture and Motion Signals in a flexible and optimized way; gesture-related…

黑白图片颜色反转并保存

将图像的黑白颜色反转并保存 import cv2 # opencv读取图像 img cv2.imread(rE:\Python-workspace\OpenCV\OpenCV/YY.png, 1) cv2.imshow(img, img) img_shape img.shape # 图像大小(565, 650, 3) print(img_shape) h img_shape[0] w img_shape[1] # 彩色图像转换为灰度图…

家猫WEB系统

现在只放源码在些.为它写应用很简单有空整理文档演示地址:jiamaocode.com/os/ 源码&#xff1a;http://jiamaocode.com/ProCts/2011/04/14/1918/1918.html转载于:https://www.cnblogs.com/jiamao/archive/2011/04/16/2018339.html

C# DataRow数组转换为DataTable

public DataTable ToDataTable(DataRow[] rows) { if (rows null || rows.Length 0) return null; DataTable tmp rows[0].Table.Clone(); // 复制DataRow的表结构 foreach (DataRow row in rows) tmp.Rows.Add(row); // 将DataRow添加…

plesk 运行不了php,如何在Plesk中使用composer(使用其他版本的PHP运行Composer)

对于基于Plesk的服务器, composer的默认安装将使用系统安装的PHP版本, 而不使用Plesk所安装的任何版本。尽管Composer至少需要PHP 5.3.2, 但是当你尝试在需要特定版本PHP的项目中安装依赖项时, 就会出现问题。例如, 如果你有一个至少需要PHP 7.2的项目, 并且系统的默认PHP安装是…

Java Calendar hashCode()方法与示例

日历类hashCode()方法 (Calendar Class hashCode() method) hashCode() method is available in java.util package. hashCode()方法在java.util包中可用。 hashCode() method is used to retrieve the hash code value of this Calendar. hashCode()方法用于检索此Calendar的哈…

Error: Flash Download failed - Target DLL has been cancelled

博主联系方式: QQ:1540984562 QQ交流群:892023501 群里会有往届的smarters和电赛选手,群里也会不时分享一些有用的资料,有问题可以在群里多问问。 由于换了新电脑,keil重装了下,然而之前的MCU的支持包没有安装,以及一些其他的问题,导致可以编译但是不能将程序烧录到单片…

设计一个较为合理的实验方案来研究芳纶纤维的染色热力学性能

请你设计一个较为合理的实验方案来研究芳纶纤维的染色热力学性能?包括吸附等温线、亲和力、染色热和染色熵的测定,并指出实验中应注意哪些事项来减少实验误差? 标准答案: 染色热力学理论研究染色平衡问题。研究染色热力学性能:首先研究选择适宜的染料 吸附等温线类型测定…

我也谈委托与事件

虽然在博客园里面已经有很多关于C#委托和事件的文章&#xff0c;但是为了自己在学习的过程中&#xff0c;加深对委托的理解&#xff0c;我还是决定写一下自己的心得体会。以备他日在回来复习。委托&#xff08;delegate&#xff09;是一个类&#xff0c;但是这个类在声明的时候…

php错误拦截机制,php拦截异常怎么写-PHP问题

php拦截异常可以通过PHP的错误、异常机制及其内建数set_exception_handler、set_error_handler、register_shutdown_function 来写。首先我们定义错误拦截类&#xff0c;该类用于将错误、异常拦截下来&#xff0c;用我们自己定义的处理方式进行处理&#xff0c;该类放在文件名为…

智能车复工日记【4】:关于图像的上下位机的调整问题总结

系列文章 【智能车Code review】—曲率计算、最小二乘法拟合 【智能车Code review】——坡道图像与控制处理 【智能车Code review】——拐点的寻找 【智能车Code review】——小S与中S道路判断 【智能车Code review】——环岛的判定与补线操作 智能车复工日记【1】——菜单索引…