CNN tensorflow 人脸识别

数据材料

这是一个小型的人脸数据库,一共有40个人,每个人有10张照片作为样本数据。这些图片都是黑白照片,意味着这些图片都只有灰度0-255,没有rgb三通道。于是我们需要对这张大图片切分成一个个的小脸。整张图片大小是1190 × 942,一共有20 × 20张照片。那么每张照片的大小就是(1190 / 20)× (942 / 20)= 57 × 47 (大约,以为每张图片之间存在间距)。

问题解决:

10类样本,利用CNN训练可以分类10类数据的神经网络,与手写字符识别类似

olivettifaces.gif

 

复制代码
#coding=utf-8
#http://www.jianshu.com/p/3e5ddc44aa56
#tensorflow 1.3.1
#python 3.6
import os
import numpy as np
import tensorflow as tf
import matplotlib.pyplot as plt
import matplotlib.image as mpimg
import matplotlib.patches as patches
import numpy
from PIL import Image#获取dataset
def load_data(dataset_path):img = Image.open(dataset_path)# 定义一个20 × 20的训练样本,一共有40个人,每个人都10张样本照片img_ndarray = np.asarray(img, dtype='float64') / 256#img_ndarray = np.asarray(img, dtype='float32') / 32# 记录脸数据矩阵,57 * 47为每张脸的像素矩阵faces = np.empty((400, 57 * 47))for row in range(20):for column in range(20):faces[20 * row + column] = np.ndarray.flatten(img_ndarray[row * 57: (row + 1) * 57, column * 47 : (column + 1) * 47])label = np.zeros((400, 40))for i in range(40):label[i * 10: (i + 1) * 10, i] = 1# 将数据分成训练集,验证集,测试集train_data = np.empty((320, 57 * 47))train_label = np.zeros((320, 40))vaild_data = np.empty((40, 57 * 47))vaild_label = np.zeros((40, 40))test_data = np.empty((40, 57 * 47))test_label = np.zeros((40, 40))for i in range(40):train_data[i * 8: i * 8 + 8] = faces[i * 10: i * 10 + 8]train_label[i * 8: i * 8 + 8] = label[i * 10: i * 10 + 8]vaild_data[i] = faces[i * 10 + 8]vaild_label[i] = label[i * 10 + 8]test_data[i] = faces[i * 10 + 9]test_label[i] = label[i * 10 + 9]train_data = train_data.astype('float32')vaild_data = vaild_data.astype('float32')test_data = test_data.astype('float32')return [(train_data, train_label),(vaild_data, vaild_label),(test_data, test_label)]def convolutional_layer(data, kernel_size, bias_size, pooling_size):kernel = tf.get_variable("conv", kernel_size, initializer=tf.random_normal_initializer())bias = tf.get_variable('bias', bias_size, initializer=tf.random_normal_initializer())conv = tf.nn.conv2d(data, kernel, strides=[1, 1, 1, 1], padding='SAME')linear_output = tf.nn.relu(tf.add(conv, bias))pooling = tf.nn.max_pool(linear_output, ksize=pooling_size, strides=pooling_size, padding="SAME")return poolingdef linear_layer(data, weights_size, biases_size):weights = tf.get_variable("weigths", weights_size, initializer=tf.random_normal_initializer())biases = tf.get_variable("biases", biases_size, initializer=tf.random_normal_initializer())return tf.add(tf.matmul(data, weights), biases)def convolutional_neural_network(data):# 根据类别个数定义最后输出层的神经元n_ouput_layer = 40kernel_shape1=[5, 5, 1, 32]kernel_shape2=[5, 5, 32, 64]full_conn_w_shape = [15 * 12 * 64, 1024]out_w_shape = [1024, n_ouput_layer]bias_shape1=[32]bias_shape2=[64]full_conn_b_shape = [1024]out_b_shape = [n_ouput_layer]data = tf.reshape(data, [-1, 57, 47, 1])# 经过第一层卷积神经网络后,得到的张量shape为:[batch, 29, 24, 32]with tf.variable_scope("conv_layer1") as layer1:layer1_output = convolutional_layer(data=data,kernel_size=kernel_shape1,bias_size=bias_shape1,pooling_size=[1, 2, 2, 1])# 经过第二层卷积神经网络后,得到的张量shape为:[batch, 15, 12, 64]with tf.variable_scope("conv_layer2") as layer2:layer2_output = convolutional_layer(data=layer1_output,kernel_size=kernel_shape2,bias_size=bias_shape2,pooling_size=[1, 2, 2, 1])with tf.variable_scope("full_connection") as full_layer3:# 讲卷积层张量数据拉成2-D张量只有有一列的列向量layer2_output_flatten = tf.contrib.layers.flatten(layer2_output)layer3_output = tf.nn.relu(linear_layer(data=layer2_output_flatten,weights_size=full_conn_w_shape,biases_size=full_conn_b_shape))# layer3_output = tf.nn.dropout(layer3_output, 0.8)with tf.variable_scope("output") as output_layer4:output = linear_layer(data=layer3_output,weights_size=out_w_shape,biases_size=out_b_shape)return output;def train_facedata(dataset, model_dir,model_path):# train_set_x = data[0][0]# train_set_y = data[0][1]# valid_set_x = data[1][0]# valid_set_y = data[1][1]# test_set_x = data[2][0]# test_set_y = data[2][1]# X = tf.placeholder(tf.float32, shape=(None, None), name="x-input")  # 输入数据# Y = tf.placeholder(tf.float32, shape=(None, None), name='y-input')  # 输入标签
batch_size = 40# train_set_x, train_set_y = dataset[0]# valid_set_x, valid_set_y = dataset[1]# test_set_x, test_set_y = dataset[2]train_set_x = dataset[0][0]train_set_y = dataset[0][1]valid_set_x = dataset[1][0]valid_set_y = dataset[1][1]test_set_x = dataset[2][0]test_set_y = dataset[2][1]X = tf.placeholder(tf.float32, [batch_size, 57 * 47])Y = tf.placeholder(tf.float32, [batch_size, 40])predict = convolutional_neural_network(X)cost_func = tf.reduce_mean(tf.nn.softmax_cross_entropy_with_logits(logits=predict, labels=Y))optimizer = tf.train.AdamOptimizer(1e-2).minimize(cost_func)# 用于保存训练的最佳模型saver = tf.train.Saver()#model_dir = './model'#model_path = model_dir + '/best.ckpt'
    with tf.Session() as session:# 若不存在模型数据,需要训练模型参数if not os.path.exists(model_path + ".index"):session.run(tf.global_variables_initializer())best_loss = float('Inf')for epoch in range(20):epoch_loss = 0for i in range((int)(np.shape(train_set_x)[0] / batch_size)):x = train_set_x[i * batch_size: (i + 1) * batch_size]y = train_set_y[i * batch_size: (i + 1) * batch_size]_, cost = session.run([optimizer, cost_func], feed_dict={X: x, Y: y})epoch_loss += costprint(epoch, ' : ', epoch_loss)if best_loss > epoch_loss:best_loss = epoch_lossif not os.path.exists(model_dir):os.mkdir(model_dir)print("create the directory: %s" % model_dir)save_path = saver.save(session, model_path)print("Model saved in file: %s" % save_path)# 恢复数据并校验和测试
        saver.restore(session, model_path)correct = tf.equal(tf.argmax(predict,1), tf.argmax(Y,1))valid_accuracy = tf.reduce_mean(tf.cast(correct,'float'))print('valid set accuracy: ', valid_accuracy.eval({X: valid_set_x, Y: valid_set_y}))test_pred = tf.argmax(predict, 1).eval({X: test_set_x})test_true = np.argmax(test_set_y, 1)test_correct = correct.eval({X: test_set_x, Y: test_set_y})incorrect_index = [i for i in range(np.shape(test_correct)[0]) if not test_correct[i]]for i in incorrect_index:print('picture person is %i, but mis-predicted as person %i'%(test_true[i], test_pred[i]))plot_errordata(incorrect_index, "olivettifaces.gif")#画出在测试集中错误的数据
def plot_errordata(error_index, dataset_path):img = mpimg.imread(dataset_path)plt.imshow(img)currentAxis = plt.gca()for index in error_index:row = index // 2column = index % 2currentAxis.add_patch(patches.Rectangle(xy=(47 * 9 if column == 0 else 47 * 19,row * 57),width=47,height=57,linewidth=1,edgecolor='r',facecolor='none'))plt.savefig("result.png")plt.show()def main():dataset_path = "olivettifaces.gif"data = load_data(dataset_path)model_dir = './model'model_path = model_dir + '/best.ckpt'train_facedata(data, model_dir, model_path)if __name__ == "__main__" :main()
复制代码

 C:\python36\python.exe X:/DeepLearning/code/face/TensorFlow_CNN_face/facerecognition_main.py
valid set accuracy:  0.825
picture person is 0, but mis-predicted as person 23
picture person is 6, but mis-predicted as person 38
picture person is 8, but mis-predicted as person 34
picture person is 15, but mis-predicted as person 11
picture person is 24, but mis-predicted as person 7
picture person is 29, but mis-predicted as person 7
picture person is 33, but mis-predicted as person 39

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.mzph.cn/news/387221.shtml

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

数据结构01绪论

第一章绪论 1.1 什么是数据结构 数据结构是一门研究非数值计算的程序设计问题中,计算机的操作对象以及他们之间的关系和操作的学科。 面向过程程序数据结构算法 数据结构是介于数学、计算机硬件、计算机软件三者之间的一门核心课程。 数据结构是程序设计、编译…

css3动画、2D与3D效果

1.兼容性 css3针对同一样式在不同浏览器的兼容 需要在样式属性前加上内核前缀; 谷歌(chrome) -webkit-transition: Opera(欧鹏) -o-transition: Firefox(火狐) -moz-transition Ie -ms-tr…

数据结构02线性表

第二章 线性表 C中STL顺序表:vector http://blog.csdn.net/weixin_37289816/article/details/54710677链表:list http://blog.csdn.net/weixin_37289816/article/details/54773406在数据元素的非空有限集中: (1)存在唯一一个被称作“第…

训练一个神经网络 能让她认得我

写个神经网络,让她认得我(๑•ᴗ•๑)(Tensorflow,opencv,dlib,cnn,人脸识别) 这段时间正在学习tensorflow的卷积神经网络部分,为了对卷积神经网络能够有一个更深的了解,自己动手实现一个例程是比较好的方式,所以就选了一个这样比…

数据结构03栈和队列

第三章栈和队列 STL栈:stack http://blog.csdn.net/weixin_37289816/article/details/54773495队列:queue http://blog.csdn.net/weixin_37289816/article/details/54773581priority_queue http://blog.csdn.net/weixin_37289816/article/details/5477…

树莓派pwm驱动好盈电调及伺服电机

本文讲述如何通过树莓派的硬件PWM控制好盈电调来驱动RC车子的前进后退,以及如何驱动伺服电机来控制车子转向。 1. 好盈电调简介 车子上的电调型号为:WP-10BLS-A-RTR,在好盈官网并没有搜到对应手册,但找到一份通用RC竞速车的电调使…

数据结构04串

第四章 串 STL:string http://blog.csdn.net/weixin_37289816/article/details/54716009计算机上非数值处理的对象基本上是字符串数据。 在不同类型的应用中,字符串具有不同的特点,要有效的实现字符串的处理,必须选用合适的存储…

CAS单点登录原理解析

CAS单点登录原理解析 SSO英文全称Single Sign On,单点登录。SSO是在多个应用系统中,用户只需要登录一次就可以访问所有相互信任的应用系统。CAS是一种基于http协议的B/S应用系统单点登录实现方案,认识CAS之前首先要熟悉http协议、Session与Co…

数据结构05数组和广义表

第五章 数组 和 广义表 数组和广义表可以看成是线性表在下述含义上的扩展:表中的数据元素本身也是一个数据结构。 5.1 数组的定义 n维数组中每个元素都受着n个关系的约束,每个元素都有一个直接后继元素。 可以把二维数组看成是这样一个定长线性表&…

数据结构06树和二叉树

第六章 树和二叉树 6.1 树的定义和基本术语 树 Tree 是n个结点的有限集。 任意一棵非空树中: (1)有且仅有一个特定的称为根(root)的结点; (2)当n>1时,其余结点可…

CountDownLatch,CyclicBarrier和Semaphore

在java 1.5中,提供了一些非常有用的辅助类来帮助我们进行并发编程,比如CountDownLatch,CyclicBarrier和Semaphore,今天我们就来学习一下这三个辅助类的用法。以下是本文目录大纲:一.CountDownLatch用法二.CyclicBarrie…

数据结构07排序

第十章内部排序 10.1 概述 排序就是把一组数据按关键字的大小有规律地排列。经过排序的数据更易于查找。 排序前KiKj,且Ki在前: 排序方法是稳定的,若排序后Ki在前; 排序方法是不稳定的,如排序后Kj在前。 分类: 内…

数据结构08查找

第九章 查找 另一种在实际应用中大量使用的数据结构--查找表。 所谓查找,即为在一个含有众多的数据元素的查找表中找出某个“特定的”数据元素。 查找表 search table 是由同一类型的数据元素构成的集合。集合中的数据元素之间存在着完全松散的关系,故…

下载Centos7 64位镜像

下载Centos7 64位镜像 1.打开Centos官网 打开Centos官方网站地址:https://www.centos.org/,点击Get CentOS Now 2.点击Minimal ISO镜像 Minimal ISO镜像,与DVD ISO镜像的差别有很多,这里只说两点 1.Minimal ISO类似于Windows的纯净…

Scala01入门

第1章 可伸展的语言 Scala应用范围广,从编写脚本,到建立大型系统。 运行在标准Java平台上,与Java库无缝交互。 更能发挥力量的地方:建立大型系统或可重用控件的架构。 将面向对象和函数式编程加入到静态类型语言。 在Scala中&a…

Java网络01基本网络概念

协议 Protocol:明确规则 (1)地址格式; (2)数据如何分包; ... TCP/IP四层模型: 应用层 HTTP SMTP POP IMAP 传输层 TCP UDP 网际层 IP 主机网络层 host to host layer 数模、…

Java网络02基本Web概念

URI Uniform Resource Identifier 同一资源标识符 以特定语法标识一个资源的字符串 绝对URI:URI模式模式特有部分 scheme:scheme-specific-part scheme分为: data file本地文件系统 ftp http telnet urn 统一资源名 scheme-specific-part为&am…

解决自建ca认证后浏览器警告

前一篇讲解了基本的建立证书的过程,但是建立后总是会在浏览器那里警告: 此链接不是私密链接 --谷歌浏览器 此证书颁发机构不可信 此证书不是这个网站的 --ie浏览器 总之证书是生成成功了,但是其中的内容填写错误了&a…

Java网络03流

网络程序所做的很大一部分工作只是输入和输出:从一个系统向另一个系统移动数据。 输出流 Java的基本输出流类是java.io.OutputStream: public abstract class OutputStream 这个类提供了写入数据所需的基本方法,包括: public abstract vo…

基于微信小程序开发的仿微信demo

(本文参考自github/liujians,地址:https://github.com/liujians/weApp) 作者声明: 基于微信小程序开发的仿微信demo 整合了ionic的样式库和weui的样式库 使用请查看使用必读! 更新日志请点击这里 目前功能 查看消息 网络请求获取数据(download示例server…