花卉分类CNN

tensorflow升级到1.0之后,增加了一些高级模块: 如tf.layers, tf.metrics, 和tf.losses,使得代码稍微有些简化。

任务:花卉分类

版本:tensorflow 1.3

数据:http://download.tensorflow.org/example_images/flower_photos.tgz

花总共有五类,分别放在5个文件夹下。

闲话不多说,直接上代码,希望大家能看懂:)

 

# -*- coding: utf-8 -*-
from skimage import io,transform
import glob
import os
import tensorflow as tf
import numpy as np
import timepath='e:/flower/'#将所有的图片resize成100*100
w=100
h=100
c=3#读取图片
def read_img(path):cate=[path+x for x in os.listdir(path) if os.path.isdir(path+x)]imgs=[]labels=[]for idx,folder in enumerate(cate):for im in glob.glob(folder+'/*.jpg'):print('reading the images:%s'%(im))img=io.imread(im)img=transform.resize(img,(w,h))imgs.append(img)labels.append(idx)return np.asarray(imgs,np.float32),np.asarray(labels,np.int32)
data,label=read_img(path)#打乱顺序
num_example=data.shape[0]
arr=np.arange(num_example)
np.random.shuffle(arr)
data=data[arr]
label=label[arr]#将所有数据分为训练集和验证集
ratio=0.8
s=np.int(num_example*ratio)
x_train=data[:s]
y_train=label[:s]
x_val=data[s:]
y_val=label[s:]#-----------------构建网络----------------------
#占位符
x=tf.placeholder(tf.float32,shape=[None,w,h,c],name='x')
y_=tf.placeholder(tf.int32,shape=[None,],name='y_')#第一个卷积层(100——>50)
conv1=tf.layers.conv2d(inputs=x,  filters=32,  kernel_size=[5, 5], padding="same", activation=tf.nn.relu,kernel_initializer=tf.truncated_normal_initializer(stddev=0.01))
pool1=tf.layers.max_pooling2d(inputs=conv1, pool_size=[2, 2], strides=2)#第二个卷积层(50->25)
conv2=tf.layers.conv2d( inputs=pool1, filters=64,  kernel_size=[5, 5], padding="same", activation=tf.nn.relu,kernel_initializer=tf.truncated_normal_initializer(stddev=0.01))
pool2=tf.layers.max_pooling2d(inputs=conv2, pool_size=[2, 2], strides=2)#第三个卷积层(25->12)
conv3=tf.layers.conv2d(inputs=pool2, filters=128, kernel_size=[3, 3], padding="same", activation=tf.nn.relu,kernel_initializer=tf.truncated_normal_initializer(stddev=0.01))
pool3=tf.layers.max_pooling2d(inputs=conv3, pool_size=[2, 2], strides=2)#第四个卷积层(12->6)
conv4=tf.layers.conv2d(  inputs=pool3, filters=128, kernel_size=[3, 3], padding="same",activation=tf.nn.relu,kernel_initializer=tf.truncated_normal_initializer(stddev=0.01))
pool4=tf.layers.max_pooling2d(inputs=conv4, pool_size=[2, 2], strides=2)re1 = tf.reshape(pool4, [-1, 6 * 6 * 128])#全连接层
dense1 = tf.layers.dense(inputs=re1,      units=1024,   activation=tf.nn.relu,kernel_initializer=tf.truncated_normal_initializer(stddev=0.01),kernel_regularizer=tf.contrib.layers.l2_regularizer(0.003))
dense2= tf.layers.dense(inputs=dense1,  units=512,  activation=tf.nn.relu,kernel_initializer=tf.truncated_normal_initializer(stddev=0.01),kernel_regularizer=tf.contrib.layers.l2_regularizer(0.003))
logits= tf.layers.dense(inputs=dense2,  units=5,   activation=None,kernel_initializer=tf.truncated_normal_initializer(stddev=0.01),kernel_regularizer=tf.contrib.layers.l2_regularizer(0.003))
#---------------------------网络结束---------------------------

loss=tf.losses.sparse_softmax_cross_entropy(labels=y_,logits=logits)
train_op=tf.train.AdamOptimizer(learning_rate=0.001).minimize(loss)
correct_prediction = tf.equal(tf.cast(tf.argmax(logits,1),tf.int32), y_)    
acc= tf.reduce_mean(tf.cast(correct_prediction, tf.float32))#定义一个函数,按批次取数据
def minibatches(inputs=None, targets=None, batch_size=None, shuffle=False):assert len(inputs) == len(targets)if shuffle:indices = np.arange(len(inputs))np.random.shuffle(indices)for start_idx in range(0, len(inputs) - batch_size + 1, batch_size):if shuffle:excerpt = indices[start_idx:start_idx + batch_size]else:excerpt = slice(start_idx, start_idx + batch_size)yield inputs[excerpt], targets[excerpt]#训练和测试数据,可将n_epoch设置更大一些

n_epoch=1000
batch_size=64
sess=tf.InteractiveSession()  
sess.run(tf.global_variables_initializer())
for epoch in range(n_epoch):start_time = time.time()#trainingtrain_loss, train_acc, n_batch = 0, 0, 0for x_train_a, y_train_a in minibatches(x_train, y_train, batch_size, shuffle=True):_,err,ac=sess.run([train_op,loss,acc], feed_dict={x: x_train_a, y_: y_train_a})train_loss += err; train_acc += ac; n_batch += 1print("   train loss: %f" % (train_loss/ n_batch))print("   train acc: %f" % (train_acc/ n_batch))#validationval_loss, val_acc, n_batch = 0, 0, 0for x_val_a, y_val_a in minibatches(x_val, y_val, batch_size, shuffle=False):err, ac = sess.run([loss,acc], feed_dict={x: x_val_a, y_: y_val_a})val_loss += err; val_acc += ac; n_batch += 1print("   validation loss: %f" % (val_loss/ n_batch))print("   validation acc: %f" % (val_acc/ n_batch))sess.close()

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.mzph.cn/news/387228.shtml

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

【模板】可持久化线段树

可持久化线段树/主席树: 顾名思义,该数据结构是可以访问历史版本的线段树。用于解决需要查询历史信息的区间问题。 在功能与时间复杂度上与开n棵线段树无异,然而空间复杂度从$O(n\times nlogn)$降到了$O(nlogn)$。 实现方法: 每次…

skimage库需要依赖 numpy+mkl 和scipy

skimage库需要依赖 numpymkl 和scipy1、打开运行,输入cmd回车,输入python回车,查看python版本2、在https://www.lfd.uci.edu/~gohlke/pythonlibs/#numpy 中,根据自己python版本下载需要的包 (因为我的是python 2.7.13 …

操作系统04进程同步与通信

4.1 进程间的相互作用 4.1.1 进程间的联系资源共享关系相互合作关系临界资源应互斥访问。临界区:不论是硬件临界资源,还是软件临界资源,多个进程必须互斥地对它们进行访问。把在每个进程中访问临界资源的那段代码称为临界资源区。显然&#x…

oracle迁移到greenplum的方案

oracle数据库是一种关系型数据库管理系统,在数据库领域一直处于领先的地位,适合于大型项目的开发;银行、电信、电商、金融等各领域都大量使用Oracle数据库。 greenplum是一款开源的分布式数据库存储解决方案,主要关注数据仓库和BI…

CNN框架的搭建及各个参数的调节

本文代码下载地址:我的github本文主要讲解将CNN应用于人脸识别的流程,程序基于PythonnumpytheanoPIL开发,采用类似LeNet5的CNN模型,应用于olivettifaces人脸数据库,实现人脸识别的功能,模型的误差降到了5%以…

操作系统05死锁

进程管理4--Deadlock and Starvation Concurrency: Deadlock and Starvation 内容提要 >产生死锁与饥饿的原因 >解决死锁的方法 >死锁/同步的经典问题:哲学家进餐问题 Deadlock 系统的一种随机性错误 Permanent blocking of a set of processes that eith…

CNN tensorflow 人脸识别

数据材料这是一个小型的人脸数据库,一共有40个人,每个人有10张照片作为样本数据。这些图片都是黑白照片,意味着这些图片都只有灰度0-255,没有rgb三通道。于是我们需要对这张大图片切分成一个个的小脸。整张图片大小是1190 942&am…

数据结构01绪论

第一章绪论 1.1 什么是数据结构 数据结构是一门研究非数值计算的程序设计问题中,计算机的操作对象以及他们之间的关系和操作的学科。 面向过程程序数据结构算法 数据结构是介于数学、计算机硬件、计算机软件三者之间的一门核心课程。 数据结构是程序设计、编译…

css3动画、2D与3D效果

1.兼容性 css3针对同一样式在不同浏览器的兼容 需要在样式属性前加上内核前缀; 谷歌(chrome) -webkit-transition: Opera(欧鹏) -o-transition: Firefox(火狐) -moz-transition Ie -ms-tr…

ES6学习笔记(六)数组的扩展

1.扩展运算符 1.1含义 扩展运算符(spread)是三个点(...)。它好比 rest 参数的逆运算,将一个数组转为用逗号分隔的参数序列。 console.log(...[1, 2, 3]) // 1 2 3console.log(1, ...[2, 3, 4], 5) // 1 2 3 4 5[...doc…

数据结构02线性表

第二章 线性表 C中STL顺序表:vector http://blog.csdn.net/weixin_37289816/article/details/54710677链表:list http://blog.csdn.net/weixin_37289816/article/details/54773406在数据元素的非空有限集中: (1)存在唯一一个被称作“第…

训练一个神经网络 能让她认得我

写个神经网络,让她认得我(๑•ᴗ•๑)(Tensorflow,opencv,dlib,cnn,人脸识别) 这段时间正在学习tensorflow的卷积神经网络部分,为了对卷积神经网络能够有一个更深的了解,自己动手实现一个例程是比较好的方式,所以就选了一个这样比…

数据结构03栈和队列

第三章栈和队列 STL栈:stack http://blog.csdn.net/weixin_37289816/article/details/54773495队列:queue http://blog.csdn.net/weixin_37289816/article/details/54773581priority_queue http://blog.csdn.net/weixin_37289816/article/details/5477…

Java动态编译执行

在某些情况下,我们需要动态生成java代码,通过动态编译,然后执行代码。JAVA API提供了相应的工具(JavaCompiler)来实现动态编译。下面我们通过一个简单的例子介绍,如何通过JavaCompiler实现java代码动态编译…

树莓派pwm驱动好盈电调及伺服电机

本文讲述如何通过树莓派的硬件PWM控制好盈电调来驱动RC车子的前进后退,以及如何驱动伺服电机来控制车子转向。 1. 好盈电调简介 车子上的电调型号为:WP-10BLS-A-RTR,在好盈官网并没有搜到对应手册,但找到一份通用RC竞速车的电调使…

数据结构04串

第四章 串 STL:string http://blog.csdn.net/weixin_37289816/article/details/54716009计算机上非数值处理的对象基本上是字符串数据。 在不同类型的应用中,字符串具有不同的特点,要有效的实现字符串的处理,必须选用合适的存储…

CAS单点登录原理解析

CAS单点登录原理解析 SSO英文全称Single Sign On,单点登录。SSO是在多个应用系统中,用户只需要登录一次就可以访问所有相互信任的应用系统。CAS是一种基于http协议的B/S应用系统单点登录实现方案,认识CAS之前首先要熟悉http协议、Session与Co…

JDK1.6版添加了新的ScriptEngine类,允许用户直接执行js代码。

JDK1.6版添加了新的ScriptEngine类,允许用户直接执行js代码。在Java中直接调用js代码 不能调用浏览器中定义的js函数,会抛出异常提示ReferenceError: “alert” is not defined。[java] view plaincopypackage com.sinaapp.manjushri; import javax.sc…

数据结构05数组和广义表

第五章 数组 和 广义表 数组和广义表可以看成是线性表在下述含义上的扩展:表中的数据元素本身也是一个数据结构。 5.1 数组的定义 n维数组中每个元素都受着n个关系的约束,每个元素都有一个直接后继元素。 可以把二维数组看成是这样一个定长线性表&…

k8s的ingress使用

ingress 可以配置一个入口来提供k8s上service从外部来访问的url、负载平衡流量、终止SSL和提供基于名称的虚拟主机。 配置ingress的yaml: 要求域名解析无误 要求service对应的pod正常 一、test1.domain.com --> service1:8080 apiVersion: extensions/v1beta1…