CNN训练模型 花卉

一、CNN训练模型

模型尺寸分析:卷积层全都采用了补0,所以经过卷积层长和宽不变,只有深度加深。池化层全都没有补0,所以经过池化层长和宽均减小,深度不变。http://download.tensorflow.org/example_images/flower_photos.tgz

模型尺寸变化:100×100×3->100×100×32->50×50×32->50×50×64->25×25×64->25×25×128->12×12×128->12×12×128->6×6×128

CNN训练代码如下:

?
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
from skimage import io,transform
import glob
import os
import tensorflow as tf
import numpy as np
import time
#数据集地址
path='E:/data/datasets/flower_photos/'
#模型保存地址
model_path='E:/data/model/flower/model.ckpt'
#将所有的图片resize成100*100
w=100
h=100
c=3
#读取图片
def read_img(path):
    cate=[path+x for x in os.listdir(path) if os.path.isdir(path+x)]
    imgs=[]
    labels=[]
    for idx,folder in enumerate(cate):
        for im in glob.glob(folder+'/*.jpg'):
            print('reading the images:%s'%(im))
            img=io.imread(im)
            img=transform.resize(img,(w,h))
            imgs.append(img)
            labels.append(idx)
    return np.asarray(imgs,np.float32),np.asarray(labels,np.int32)
data,label=read_img(path)
#打乱顺序
num_example=data.shape[0]
arr=np.arange(num_example)
np.random.shuffle(arr)
data=data[arr]
label=label[arr]
#将所有数据分为训练集和验证集
ratio=0.8
s=np.int(num_example*ratio)
x_train=data[:s]
y_train=label[:s]
x_val=data[s:]
y_val=label[s:]
#-----------------构建网络----------------------
#占位符
x=tf.placeholder(tf.float32,shape=[None,w,h,c],name='x')
y_=tf.placeholder(tf.int32,shape=[None,],name='y_')
def inference(input_tensor, train, regularizer):
    with tf.variable_scope('layer1-conv1'):
        conv1_weights = tf.get_variable("weight",[5,5,3,32],initializer=tf.truncated_normal_initializer(stddev=0.1))
        conv1_biases = tf.get_variable("bias", [32], initializer=tf.constant_initializer(0.0))
        conv1 = tf.nn.conv2d(input_tensor, conv1_weights, strides=[1, 1, 1, 1], padding='SAME')
        relu1 = tf.nn.relu(tf.nn.bias_add(conv1, conv1_biases))
    with tf.name_scope("layer2-pool1"):
        pool1 = tf.nn.max_pool(relu1, ksize = [1,2,2,1],strides=[1,2,2,1],padding="VALID")
    with tf.variable_scope("layer3-conv2"):
        conv2_weights = tf.get_variable("weight",[5,5,32,64],initializer=tf.truncated_normal_initializer(stddev=0.1))
        conv2_biases = tf.get_variable("bias", [64], initializer=tf.constant_initializer(0.0))
        conv2 = tf.nn.conv2d(pool1, conv2_weights, strides=[1, 1, 1, 1], padding='SAME')
        relu2 = tf.nn.relu(tf.nn.bias_add(conv2, conv2_biases))
    with tf.name_scope("layer4-pool2"):
        pool2 = tf.nn.max_pool(relu2, ksize=[1, 2, 2, 1], strides=[1, 2, 2, 1], padding='VALID')
    with tf.variable_scope("layer5-conv3"):
        conv3_weights = tf.get_variable("weight",[3,3,64,128],initializer=tf.truncated_normal_initializer(stddev=0.1))
        conv3_biases = tf.get_variable("bias", [128], initializer=tf.constant_initializer(0.0))
        conv3 = tf.nn.conv2d(pool2, conv3_weights, strides=[1, 1, 1, 1], padding='SAME')
        relu3 = tf.nn.relu(tf.nn.bias_add(conv3, conv3_biases))
    with tf.name_scope("layer6-pool3"):
        pool3 = tf.nn.max_pool(relu3, ksize=[1, 2, 2, 1], strides=[1, 2, 2, 1], padding='VALID')
    with tf.variable_scope("layer7-conv4"):
        conv4_weights = tf.get_variable("weight",[3,3,128,128],initializer=tf.truncated_normal_initializer(stddev=0.1))
        conv4_biases = tf.get_variable("bias", [128], initializer=tf.constant_initializer(0.0))
        conv4 = tf.nn.conv2d(pool3, conv4_weights, strides=[1, 1, 1, 1], padding='SAME')
        relu4 = tf.nn.relu(tf.nn.bias_add(conv4, conv4_biases))
    with tf.name_scope("layer8-pool4"):
        pool4 = tf.nn.max_pool(relu4, ksize=[1, 2, 2, 1], strides=[1, 2, 2, 1], padding='VALID')
        nodes = 6*6*128
        reshaped = tf.reshape(pool4,[-1,nodes])
    with tf.variable_scope('layer9-fc1'):
        fc1_weights = tf.get_variable("weight", [nodes, 1024],
                                      initializer=tf.truncated_normal_initializer(stddev=0.1))
        if regularizer != None: tf.add_to_collection('losses', regularizer(fc1_weights))
        fc1_biases = tf.get_variable("bias", [1024], initializer=tf.constant_initializer(0.1))
        fc1 = tf.nn.relu(tf.matmul(reshaped, fc1_weights) + fc1_biases)
        if train: fc1 = tf.nn.dropout(fc1, 0.5)
    with tf.variable_scope('layer10-fc2'):
        fc2_weights = tf.get_variable("weight", [1024, 512],
                                      initializer=tf.truncated_normal_initializer(stddev=0.1))
        if regularizer != None: tf.add_to_collection('losses', regularizer(fc2_weights))
        fc2_biases = tf.get_variable("bias", [512], initializer=tf.constant_initializer(0.1))
        fc2 = tf.nn.relu(tf.matmul(fc1, fc2_weights) + fc2_biases)
        if train: fc2 = tf.nn.dropout(fc2, 0.5)
    with tf.variable_scope('layer11-fc3'):
        fc3_weights = tf.get_variable("weight", [512, 5],
                                      initializer=tf.truncated_normal_initializer(stddev=0.1))
        if regularizer != None: tf.add_to_collection('losses', regularizer(fc3_weights))
        fc3_biases = tf.get_variable("bias", [5], initializer=tf.constant_initializer(0.1))
        logit = tf.matmul(fc2, fc3_weights) + fc3_biases
    return logit
#---------------------------网络结束---------------------------
regularizer = tf.contrib.layers.l2_regularizer(0.0001)
logits = inference(x,False,regularizer)
#(小处理)将logits乘以1赋值给logits_eval,定义name,方便在后续调用模型时通过tensor名字调用输出tensor
b = tf.constant(value=1,dtype=tf.float32)
logits_eval = tf.multiply(logits,b,name='logits_eval')
loss=tf.nn.sparse_softmax_cross_entropy_with_logits(logits=logits, labels=y_)
train_op=tf.train.AdamOptimizer(learning_rate=0.001).minimize(loss)
correct_prediction = tf.equal(tf.cast(tf.argmax(logits,1),tf.int32), y_)   
acc= tf.reduce_mean(tf.cast(correct_prediction, tf.float32))
#定义一个函数,按批次取数据
def minibatches(inputs=None, targets=None, batch_size=None, shuffle=False):
    assert len(inputs) == len(targets)
    if shuffle:
        indices = np.arange(len(inputs))
        np.random.shuffle(indices)
    for start_idx in range(0, len(inputs) - batch_size + 1, batch_size):
        if shuffle:
            excerpt = indices[start_idx:start_idx + batch_size]
        else:
            excerpt = slice(start_idx, start_idx + batch_size)
        yield inputs[excerpt], targets[excerpt]
#训练和测试数据,可将n_epoch设置更大一些
n_epoch=10                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                          
batch_size=64
saver=tf.train.Saver()
sess=tf.Session() 
sess.run(tf.global_variables_initializer())
for epoch in range(n_epoch):
    start_time = time.time()
    #training
    train_loss, train_acc, n_batch = 0, 0, 0
    for x_train_a, y_train_a in minibatches(x_train, y_train, batch_size, shuffle=True):
        _,err,ac=sess.run([train_op,loss,acc], feed_dict={x: x_train_a, y_: y_train_a})
        train_loss += err; train_acc += ac; n_batch += 1
    print("   train loss: %f" % (np.sum(train_loss)/ n_batch))
    print("   train acc: %f" % (np.sum(train_acc)/ n_batch))
    #validation
    val_loss, val_acc, n_batch = 0, 0, 0
    for x_val_a, y_val_a in minibatches(x_val, y_val, batch_size, shuffle=False):
        err, ac = sess.run([loss,acc], feed_dict={x: x_val_a, y_: y_val_a})
        val_loss += err; val_acc += ac; n_batch += 1
    print("   validation loss: %f" % (np.sum(val_loss)/ n_batch))
    print("   validation acc: %f" % (np.sum(val_acc)/ n_batch))
saver.save(sess,model_path)
sess.close()

二、调用模型进行预测

调用模型进行花卉的预测,代码如下:

?
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
from skimage import io,transform
import tensorflow as tf
import numpy as np
path1 = "E:/data/datasets/flower_photos/daisy/5547758_eea9edfd54_n.jpg"
path2 = "E:/data/datasets/flower_photos/dandelion/7355522_b66e5d3078_m.jpg"
path3 = "E:/data/datasets/flower_photos/roses/394990940_7af082cf8d_n.jpg"
path4 = "E:/data/datasets/flower_photos/sunflowers/6953297_8576bf4ea3.jpg"
path5 = "E:/data/datasets/flower_photos/tulips/10791227_7168491604.jpg"
flower_dict = {0:'dasiy',1:'dandelion',2:'roses',3:'sunflowers',4:'tulips'}
w=100
h=100
c=3
def read_one_image(path):
    img = io.imread(path)
    img = transform.resize(img,(w,h))
    return np.asarray(img)
with tf.Session() as sess:
    data = []
    data1 = read_one_image(path1)
    data2 = read_one_image(path2)
    data3 = read_one_image(path3)
    data4 = read_one_image(path4)
    data5 = read_one_image(path5)
    data.append(data1)
    data.append(data2)
    data.append(data3)
    data.append(data4)
    data.append(data5)
    saver = tf.train.import_meta_graph('E:/data/model/flower/model.ckpt.meta')
    saver.restore(sess,tf.train.latest_checkpoint('E:/data/model/flower/'))
    graph = tf.get_default_graph()
    x = graph.get_tensor_by_name("x:0")
    feed_dict = {x:data}
    logits = graph.get_tensor_by_name("logits_eval:0")
    classification_result = sess.run(logits,feed_dict)
    #打印出预测矩阵
    print(classification_result)
    #打印出预测矩阵每一行最大值的索引
    print(tf.argmax(classification_result,1).eval())
    #根据索引通过字典对应花的分类
    output = []
    output = tf.argmax(classification_result,1).eval()
    for i in range(len(output)):
        print("第",i+1,"朵花预测:"+flower_dict[output[i]])

运行结果:

?
1
2
3
4
5
6
7
8
9
10
11
[[  5.76620245   3.18228579  -3.89464641  -2.81310582   1.40294015]
 [ -1.01490593   3.55570269  -2.76053429   2.93104005  -3.47138596]
 [ -8.05292606  -7.26499033  11.70479774   0.59627819   2.15948296]
 [ -5.12940931   2.18423128  -3.33257103   9.0591135    5.03963232]
 [ -4.25288343  -0.95963973  -2.33347392   1.54485476   5.76069307]]
[0 1 2 3 4]
1 朵花预测:dasiy
2 朵花预测:dandelion
3 朵花预测:roses
4 朵花预测:sunflowers
5 朵花预测:tulips

预测结果和调用模型代码中的五个路径相比较是完全准确的。

本文的模型对于花卉的分类准确率大概在70%左右,采用迁移学习调用Inception-v3模型对本文中的花卉数据集分类准确率在95%左右。主要的原因在于本文的CNN模型较于简单,而且花卉数据集本身就比mnist手写数字数据集分类难度就要大一点,同样的模型在mnist手写数字的识别上准确率要比花卉数据集准确率高不少。

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.mzph.cn/news/387233.shtml

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

Linux re

正则表达式并不是一个工具程序,而是一个字符串处理的标准依据,如果想要以正则表达式的方式处理字符串,就得使用支持正则表达式的工具,例如grep、vi、sed、asw等。 注意:ls不支持正则表达式。 grep 正则表达式: 注意gr…

操作系统02进程管理Process_Description_and_Control

作业的基本概念:用户再一次计算过程中或一次事务处理过程中,要求计算机系统所做的工作的集合。 包含多个程序、多个数据、作业控制说明书 系统调用时操作系统提供给编程人员的唯一接口。 1、文件操作类; 2、进程控制类; 3、资…

蓝桥杯 方格填数(全排列+图形补齐)

方格填数 如下的10个格子 填入0~9的数字,同一数字不能重复填。要求:连续的两个数字不能相邻。(左右、上下、对角都算相邻) 一共有多少种可能的填数方案? 请填写表示方案数目的整数。注意:你提交的应该是一个…

操作系统03进程管理Process_Scheduling

2 Process Scheduling >Type of scheduling >Scheduling Criteria (准则) >Scheduling Algorithm >Real-Time Scheduling (嵌入式系统) 2.1 Learning Objectives By the end of this lecture you should be able to Explain what is Response Time 响应时间-…

花卉分类CNN

tensorflow升级到1.0之后,增加了一些高级模块: 如tf.layers, tf.metrics, 和tf.losses,使得代码稍微有些简化。 任务:花卉分类 版本:tensorflow 1.3 数据:http://download.tensorflow.org/example_images/f…

【模板】可持久化线段树

可持久化线段树/主席树: 顾名思义,该数据结构是可以访问历史版本的线段树。用于解决需要查询历史信息的区间问题。 在功能与时间复杂度上与开n棵线段树无异,然而空间复杂度从$O(n\times nlogn)$降到了$O(nlogn)$。 实现方法: 每次…

skimage库需要依赖 numpy+mkl 和scipy

skimage库需要依赖 numpymkl 和scipy1、打开运行,输入cmd回车,输入python回车,查看python版本2、在https://www.lfd.uci.edu/~gohlke/pythonlibs/#numpy 中,根据自己python版本下载需要的包 (因为我的是python 2.7.13 …

操作系统04进程同步与通信

4.1 进程间的相互作用 4.1.1 进程间的联系资源共享关系相互合作关系临界资源应互斥访问。临界区:不论是硬件临界资源,还是软件临界资源,多个进程必须互斥地对它们进行访问。把在每个进程中访问临界资源的那段代码称为临界资源区。显然&#x…

oracle迁移到greenplum的方案

oracle数据库是一种关系型数据库管理系统,在数据库领域一直处于领先的地位,适合于大型项目的开发;银行、电信、电商、金融等各领域都大量使用Oracle数据库。 greenplum是一款开源的分布式数据库存储解决方案,主要关注数据仓库和BI…

CNN框架的搭建及各个参数的调节

本文代码下载地址:我的github本文主要讲解将CNN应用于人脸识别的流程,程序基于PythonnumpytheanoPIL开发,采用类似LeNet5的CNN模型,应用于olivettifaces人脸数据库,实现人脸识别的功能,模型的误差降到了5%以…

操作系统05死锁

进程管理4--Deadlock and Starvation Concurrency: Deadlock and Starvation 内容提要 >产生死锁与饥饿的原因 >解决死锁的方法 >死锁/同步的经典问题:哲学家进餐问题 Deadlock 系统的一种随机性错误 Permanent blocking of a set of processes that eith…

CNN tensorflow 人脸识别

数据材料这是一个小型的人脸数据库,一共有40个人,每个人有10张照片作为样本数据。这些图片都是黑白照片,意味着这些图片都只有灰度0-255,没有rgb三通道。于是我们需要对这张大图片切分成一个个的小脸。整张图片大小是1190 942&am…

数据结构01绪论

第一章绪论 1.1 什么是数据结构 数据结构是一门研究非数值计算的程序设计问题中,计算机的操作对象以及他们之间的关系和操作的学科。 面向过程程序数据结构算法 数据结构是介于数学、计算机硬件、计算机软件三者之间的一门核心课程。 数据结构是程序设计、编译…

css3动画、2D与3D效果

1.兼容性 css3针对同一样式在不同浏览器的兼容 需要在样式属性前加上内核前缀; 谷歌(chrome) -webkit-transition: Opera(欧鹏) -o-transition: Firefox(火狐) -moz-transition Ie -ms-tr…

ES6学习笔记(六)数组的扩展

1.扩展运算符 1.1含义 扩展运算符(spread)是三个点(...)。它好比 rest 参数的逆运算,将一个数组转为用逗号分隔的参数序列。 console.log(...[1, 2, 3]) // 1 2 3console.log(1, ...[2, 3, 4], 5) // 1 2 3 4 5[...doc…

数据结构02线性表

第二章 线性表 C中STL顺序表:vector http://blog.csdn.net/weixin_37289816/article/details/54710677链表:list http://blog.csdn.net/weixin_37289816/article/details/54773406在数据元素的非空有限集中: (1)存在唯一一个被称作“第…

训练一个神经网络 能让她认得我

写个神经网络,让她认得我(๑•ᴗ•๑)(Tensorflow,opencv,dlib,cnn,人脸识别) 这段时间正在学习tensorflow的卷积神经网络部分,为了对卷积神经网络能够有一个更深的了解,自己动手实现一个例程是比较好的方式,所以就选了一个这样比…

数据结构03栈和队列

第三章栈和队列 STL栈:stack http://blog.csdn.net/weixin_37289816/article/details/54773495队列:queue http://blog.csdn.net/weixin_37289816/article/details/54773581priority_queue http://blog.csdn.net/weixin_37289816/article/details/5477…

Java动态编译执行

在某些情况下,我们需要动态生成java代码,通过动态编译,然后执行代码。JAVA API提供了相应的工具(JavaCompiler)来实现动态编译。下面我们通过一个简单的例子介绍,如何通过JavaCompiler实现java代码动态编译…

树莓派pwm驱动好盈电调及伺服电机

本文讲述如何通过树莓派的硬件PWM控制好盈电调来驱动RC车子的前进后退,以及如何驱动伺服电机来控制车子转向。 1. 好盈电调简介 车子上的电调型号为:WP-10BLS-A-RTR,在好盈官网并没有搜到对应手册,但找到一份通用RC竞速车的电调使…