超干货|使用Keras和CNN构建分类器(内含代码和讲解)

摘要: 为了让文章不那么枯燥,我构建了一个精灵图鉴数据集(Pokedex)这都是一些受欢迎的精灵图。我们在已经准备好的图像数据集上,使用Keras库训练一个卷积神经网络(CNN)。


为了让文章不那么枯燥,我构建了一个精灵图鉴数据集(Pokedex)这都是一些受欢迎的精灵图。我们在已经准备好的图像数据集上,使用Keras库训练一个卷积神经网络(CNN)。


深度学习数据集



上图是来自我们的精灵图鉴深度学习数据集中的合成图样本。我的目标是使用Keras库和深度学习训练一个CNN,对Pokedex数据集中的图像进行识别和分类。Pokedex数据集包括:Bulbasaur (234 images)Charmander (238 images)Squirtle (223 images)Pikachu (234 images)Mewtwo (239 images)


训练图像包括以下组合:电视或电影的静态帧;交易卡;行动人物;玩具和小玩意儿;图纸和粉丝的艺术效果图。

在这种多样化的训练图像的情况下,实验结果证明,CNN模型的分类准确度高达97


CNNKeras库的项目结构

该项目分为几个部分,目录结构如下:



如上图所示,共分为3个目录:

1.数据集:包含五个类,每个类都是一个子目录。

2.示例:包含用于测试卷积神经网络的图像。

3.pyimagesearch模块:包含我们的SmallerVGGNet模型类。


另外,根目录下有5个文件:

1.plot.png:训练脚本运行后,生成的训练/测试准确性和损耗图。

2.lb.pickleLabelBinarizer序列化文件,在类名称查找机制中包含类索引。

3.pokedex.model:序列化Keras CNN模型文件(即权重文件)。

4.train.py:训练Keras CNN,绘制准确性/损耗函数,然后将卷积神经网络和类标签二进制文件序列化到磁盘。

5.classify.py:测试脚本。


KerasCNN架构


我们今天使用的CNN架构,是由SimonyanZisserman2014年的论文用于大规模图像识别的强深度卷积网络中介绍的VGGNet网络的简单版本,结构图如上图所示。该网络架构的特点是:

1.只使用3*3的卷积层堆叠在一起来增加深度。

2.使用最大池化来减小数组大小。

3.网络末端全连接层在softmax分类器之前。


假设你已经在系统上安装并配置了Keras。如果没有,请参照以下连接了解开发环境的配置教程:

1.配置Ubuntu,使用Python进行深度学习。

2.设置Ubuntu 16.04 + CUDA + GPU,使用Python进行深度学习。

3.配置macOS,使用Python进行深度学习。


继续使用SmallerVGGNet——VGGNet的更小版本。在pyimagesearch模块中创建一个名为smallervggnet.py的新文件,并插入以下代码:



注意:在pyimagesearch中创建一个_init_.py文件,以便Python知道该目录是一个模块。如果你对_init_.py文件不熟悉或者不知道如何使用它来创建模块,你只需在原文的下载部分下载目录结构、源代码、数据集和示例图像。


现在定义SmallerVGGNet类:



该构建方法需要四个参数:

1.width:图像宽度。

2.height :图像高度。

3.depth :图像深度。

4.classes :数据集中类的数量(这将影响模型的最后一层),我们使用了5Pokemon 类。


注意:我们使用的是深度为3、大小为96 * 96的输入图像。后边解释输入数组通过网络的空间维度时,请记住这一点。


由于我们使用的是TensorFlow后台,因此用“channels last”对输入数据进行排序;如果想用“channels last”,则可以用代码中的23-25行进行处理。

为模型添加层,下图为第一个CONV => RELU => POOL代码块:


卷积层有32个内核大小为3*3的滤波器,使用RELU激活函数,然后进行批量标准化。

池化层使用3 *3的池化,将空间维度从96 *96快速降低到32 * 32(输入图像的大小为96 * 96 * 3的来训练网络)。

如代码所示,在网络架构中使用DropoutDropout随机将节点从当前层断开,并连接到下一层。这个随机断开的过程有助于降低模型中的冗余——网络层中没有任何单个节点负责预测某个类、对象、边或角。

在使用另外一个池化层前,添加(CONV => RELU* 2层:


在降低输入数组的空间维度前,将多个卷积层RELU层堆叠在一起可以学习更丰富的特征集。

请注意:将滤波器大小从32增加到64。随着网络的深入,输入数组的空间维度越小,滤波器学习到的内容更多;将最大池化层从3*3降低到2*2,以确保不会过快地降低空间维度。在这个过程中再次执行Dropout

再添加一个(CONV => RELU)* 2 => POOL代码块:


我们已经将滤波器的大小增加到128。对25%的节点执行Droupout以减少过拟合。

最后,还有一组FC => RELU层和一个softmax分类器:


Dense1024使用具有校正的线性单位激活和批量归一化指定全连接层。

最后再执行一次Droupout——在训练期间我们Droupout50%的节点。通常情况下,你会在全连接层在较低速率下使用40-50%的Droupout,其他网络层为10-25%的Droupout

softmax分类器对模型进行四舍五入,该分类器将返回每个类别标签的预测概率值。


CNN + Keras训练脚本的实现

既然VGGNet小版本已经实现,现在我们使用Keras来训练卷积神经网络。

创建一个名为train.py的新文件,并插入以下代码,导入需要的软件包和库:


使用”Agg” matplotlib后台,以便可以将数字保存在背景中(第3行)。

ImageDataGenerator类用于数据增强,这是一种对数据集中的图像进行随机变换(旋转、剪切等)以生成其他训练数据的技术。数据增强有助于防止过拟合。

7行导入了Adam优化器,用于训练网络。

9行的LabelBinarizer是一个重要的类,其作用如下:


1.输入一组类标签的集合(即表示数据集中人类可读的类标签字符串)。

2.将类标签转换为独热编码矢量。

3.允许从Keras CNN中进行整型类别标签预测,并转换为人类可读标签。


经常会有读者问:如何将类标签字符串转换为整型?或者如何将整型转换为类标签字符串。答案就是使用LabelBinarizer类。

10行的train_test_split函数用来创建训练和测试分叉。

读者对我自己的imutils较为了解。如果你没有安装或更新,可以通过以下方式进行安装:


如果你使用的是Python虚拟环境,确保在安装或升级imutils之前,用workon命令访问特定的虚拟环境。

我们来解析一下命令行参数:



对于我们的训练脚本,有三个必须的参数:

1.--dataset:输入数据集的路径。数据集放在一个目录中,其子目录代表每个类,每个子目录约有250个精灵图片。

2.--model:输出模型的路径,将训练模型输出到磁盘。

3.--labelbin:输出标签二进制器的路径。


还有一个可选参数--plot。如果不指定路径或文件名,那么plot.png文件则在当前工作目录中。

不需要修改第22-31行来提供新的文件路径,代码在运行时会自行处理。

现在,初始化一些重要的变量:



35-38行对训练Keras CNN时使用的重要变量进行初始化:

1.-EPOCHS训练网络的次数。

2.-INIT-LR初始学习速率值,1e-3Adam优化器的默认值,用来优化网络。

3.-BS将成批的图像传送到网络中进行训练,同一时期会有多个批次,BS值控制批次的大小。

4.-IMAGE-DIMS提供输入图像的空间维度数。输入的图像为96*96*3(即RGB)。


然后初始化两个列表——datalabels,分别保存预处理后的图像和标签。第46-48行抓取所有的图像路径并随机扰乱。

现在,对所有的图像路径ImagePaths进行循环:


首先对imagePaths进行循环(第51行),再对图像进行加载(第53行),然后调整其大小以适应模型(第54行)。

现在,更新datalabels列表。

调用Keras库的img_to_arry函数,将图像转换为与Keras库兼容的数组(第55行),然后将图像添加到名为data的列表中(56)

对于labels列表,我们在第60行文件路径中提取出label,并将其添加在第61行。

那么,为什么需要类标签分解过程呢?

考虑到这样一个事实,我们有目的地创建dataset目录结构,格式如下:


60行的路径分隔符可以将路径分割成一个数组,然后获取列表中的倒数第二项——类标签。

然后进行额外的预处理、二值化标签和数据分区,代码如下:         


首先将data数组转换为NumPy数组,然后将像素强度缩放到[0,1]范围内(第64行),也要将列表中的labels转换为NumPy数组(第65行)。打印data矩阵的大小(以MB为单位)。

然后使用scikit-learn库的LabelBinarzer对标签进行二进制化(7071)

对于深度学习(或者任何机器学习),通常的做法是将训练和测试分开。第7576行将训练集和测试集按照80/20的比例进行分割。

接下来创建图像数据增强对象:


因为训练数据有限(每个类别的图像数量小于250),因此可以利用数据增强为模型提供更多的图像(基于现有图像),数据增强是一种很重要的工具。

7981行使用ImageDataGenerator对变量aug进行初始化,即ImageDataGenerator      

            现在,我们开始编译模型和训练:


85行和第86行使用96963的输入图像初始化Keras CNN模型。注意,我将SmallerVGGNet设计为接受96963输入图像。

87行使用具有学习速率衰减的Adam优化器,然后在88行和89行使用分类交叉熵编译模型。

若只有2个类别,则使用二元交叉熵作为损失函数。

93-97行调用Kerasfit_generator方法训练网络。这一过程需要花费点时间,这取决于你是用CPU还是GPU进行训练。

一旦Keras CNN训练完成,我们需要保存模型(1)和标签二进制化器(2),因为在训练或测试集以外的图像上进行测试时,需要从磁盘中加载出来:


对模型(101行)和标签二进制器(105-107行)进行序列化,以便稍后在classify.py脚本中使用。


最后,绘制训练和损失的准确性图,并保存到磁盘(第121行),而不是显示出来,原因有二:(1)我的服务器在云端;2)确保不会忘记保存图。


使用Keras训练CNN

执行以下代码训练模型:


训练脚本的输出结果如上图所示,Keras CNN模型在训练集上的分类准确率为96.84%;在测试集上的准确率为97.07

训练损失函数和准确性图如下:


如上图所示,对模型训练100次,并在有限的过拟合下实现了低损耗。在新的数据上也能获得更高的准确性。


创建CNNKeras的脚本

现在,CNN已经训练过了,我们需要编写一个脚本,对新图像进行分类。新建一个文件,并命名为classify.py,插入以下代码:


上图中第2-9行导入必要的库。


我们来解析下代码中的参数(12-19行),需要的三个参数如下:

1.--model:已训练模型的路径。

2.--labelbin:标签二进制器的路径。

3.--image:输入图像的路径。

接下来,加载图像并对其进行预处理:


22行加载输入图像image,并复制一个副本,赋值给out(第23行)。

和训练过程使用的预处理方式一样,我们对图像进行预处理(26-29行)。加载模型和标签二值化器(3435行),对图像进行分类:


随后,对图像进行分类并创建标签(39-41行)。

剩余的代码用于显示:


46-47行从filename中提取精灵图鉴的名字,并与label进行比较。Correct变量是正确(correct不正确(incorrect。然后执行以下操作:

1.50行将概率值和正确/不正确文本添加到类别标签label上。

2.51行调整输出图像大小,使其适合屏幕输出。

3.5253行在输出图像上绘制标签。

4.5758行显示输出图像并等待按键退出。


KNNKeras对图像分类

运行classify.py脚本(确保已经从原文下载部分获取代码和图片)!下载并解压缩文件到这个项目的根目录下,然后从Charmander图像开始。代码及试验结果如下:



Bulbasaur图像分类的代码及结果如下所示:



其他图像的分类代码和以上两个图像的代码一样,可自行验证其结果。


模型的局限性

该模型的主要局限是训练数据少。我在各种不同的图像进行测试,发现有时分类不正确。我仔细地检查了输入图像和神经网络,发现图像中的主要颜色会影响分类结果。

例如,如果图像中有许多红色和橙色,则可能会返回“Charmander”标签;图像中的黄色通常会返回“Pikachu”标签。这归因于输入数据,精灵图鉴是虚构的,它没有真实世界中的真实图像。并且,我们只为每个类别提供了比较有限的数据(约225-250张图片)。

理想情况下,训练卷积神经网络时,每个类别至少应有500-1,000幅图像。

可以将Keras深度学习模型作为REST API吗?

如果想将此模型(或任何其他深度学习模型)用作REST API运行,可以参照下面的博文内容:

1.构建一个简单的Keras + 深度学习REST API

2.可扩展的Keras + 深度学习REST API

3.使用KerasRedisFlaskApache进行深度学习


总结

这篇文章主要介绍了如何使用Keras库来训练卷积神经网络(CNN)。使用的是自己创建的数据集(精灵图鉴)作为训练集和测试集,其分类的准确度达到97.07%。


原文链接

干货好文,请关注扫描以下二维码:



本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.mzph.cn/news/521907.shtml

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

网易云音乐热评的规律,44万条数据告诉你

戳蓝字“CSDN云计算”关注我们哦!网易云的每日推荐里藏着你听过的歌,你听过的歌里藏着你的故事。网易云音乐的评论里,藏着许多人的故事。我们爬取了网易云音乐歌单中48400首歌的444054条热评,来看看网易云的热门评论里&#xff0c…

java servlet spring_spring与tomcat 对应关系,servlet各版本写法

构建项目时,需要springjdktomcat各版本对应关系,找了半天,一点都不高效,特此总结下,方便查阅。tomcatjdkservlet对应关系官网文档地址:http://tomcat.apache.org/whichversion.html当前时间版本关系。sprin…

数据科学指南

摘要: 本文为学习数据科学的指南,从编程语言、数据库管理、数据收集和整理方法、特征工程、机器学习模型、模型评估方面进行讲述。数据科学是一个相当庞大且具有多样化的领域,侧重于数学、计算机科学和领域专业知识。但是在本文中大部分内容将…

乐高ev3搭建图_乐高EV3机械爪合集

点击上方蓝字关注我!乐高EV3机械爪合集哈喽小伙伴们!新的一周我们又见面啦。这周给大家带来的是EV3的机械爪合集,5种不同结构类型的机械爪来自五十川老师的作品,可以应用于各种比赛或者任务场景中,下面我们先来看一下这…

马云卸任,张勇宣布未来五年目标:消费规模10万亿;华为发布基于5G和AI解决方案;新iPhone不支持5G 库克:市场不成熟……...

戳蓝字“CSDN云计算”关注我们哦!嗨,大家好,重磅君带来的【云重磅】特别栏目,如期而至,每周五第一时间为大家带来重磅新闻。把握技术风向标,了解行业应用与实践,就交给我重磅君吧!重…

望京“黑客”图鉴

摘要: 不是所有黑客可以登顶望京阿里中心 A 座 34 楼。本文转载自雷锋网宅客频道。最近《北京女子图鉴》很火。不过看这篇文章之前,你要做好几个心理准备:1.这篇文章不是讲黑客男主如何在北京遇上12个女主“打怪升级”的故事。2.因为我们的采…

解决:'webpack-dev-server' 不是内部或外部命令,也不是可运行的程序 或批处理文件。

webpack-dev-server错误法则: 前往项目根目录删除node_modules文件夹,然后在项目根目录路径下的终端运行"npm install"等待安装完之后,再次运行“npm run dev”,有些人的是马上就可以了,然而往往还会有人&am…

GPU云服务器深度学习性能模型初探

摘要: 本文根据实测数据,初步探讨了在弹性GPU云服务器上深度学习的性能模型,可帮助科学选择GPU实例的规格。1 背景得益于GPU强大的计算能力,深度学习近年来在图像处理、语音识别、自然语言处理等领域取得了重大突破,GP…

HDC.2019后再发力,AppGallery Connect服务新升级

不久前,华为2019开发者大会在东莞松山湖圆满落,来自全球的1500多个合作伙伴、5000多名开发者汇聚一堂,共同探讨未来科技发展,其中华为应用市场AppGallery Connect服务也在大会上重磅亮相,引起了广大开发者的关注。如今…

idea 利用vue.js插件创建vue初始化项目

IDEA 构建出的 Vue 项目是不含 node_modules 的,所以要先调出终端,执行 npm install 运行完成后,输入 npm run dev 即可。 vue初始化项目完成!!! 另外 IDE 嘛,总是在 UI 上下了很多功夫&am…

Tensorflow快餐教程(4) - 矩阵

摘要: Tensorflow矩阵基础运算矩阵矩阵的初始化矩阵因为元素更多,所以初始化函数更多了。光靠tf.linspace,tf.range之类的线性生成函数已经不够用了。可以通过先生成一个线性序列,然后再reshape成一个矩阵的方式来初始化。例&…

为什么阿里程序猿纷纷在内网晒代码?

摘要: 大家知道,阿里有两万多名可爱的程序员。 他们也没什么别的爱好,就是多才多艺了一点:这帮阿里程序猿在改变世界前 要先撼动歌坛 就是热心肠了一点:阿里有个程序员,因为闯红灯上新闻了 虽然怕老婆但也能…

从七个方面,面试大厂高级工程师

戳蓝字“CSDN云计算”关注我们哦!在上周,我密集面试了若干位Java后端的候选人,工作经验在3到5年间。我的标准其实不复杂:第一能干活;第二Java基础要好;第三最好熟悉些分布式框架。我相信其它公司招开发时&a…

Java需要掌握的底层知识_Java程序员应该掌握的底层知识

缓存缓存行:缓存行越大,局部性空间效率越高,但读取时间慢缓存行越小,局部性空间效率越低,但读取时间快取一个折中值,目前多用:64字节public class CacheLinePadding { //执行时间在4s左右public…

(vue基础试炼_01)使用vue.js 快速入门hello world

文章目录一、需求案例二、案例实现2.1. 原始js写法2.2. 怎样使用vue.js &#xff1f;2.3. 使用vue.js 写法三、案例vue简述&#xff1f;四、案例趣味延伸五、表达值作用及嘱咐语一、需求案例 在页面显示hello world 二、案例实现 2.1. 原始js写法 <!DOCTYPE html> &l…

如何让机器理解汉字一笔一画的奥秘?

摘要&#xff1a;从智能客服到机器翻译&#xff0c;从文本摘要生成到用户评论分析&#xff0c;从文本安全风控到商品描述建模&#xff0c;无不用到自然语言技术&#xff0c;作为人工智能领域的一个重要分支&#xff0c;如何让机器更懂得人类的语言&#xff0c;尤其是汉字这种强…

Logtail从入门到精通(三):机器分组配置

摘要&#xff1a; 基于集团内数年来的Agent运维经验总结&#xff0c;我们设计了一种灵活性更高、使用更加便捷、耦合度更低的配置&机器管理方式&#xff1a;自定义标识机器分组。此种方式对于动态环境非常适用&#xff0c;尤其适用于弹性伸缩服务和swarm、pouch(阿里docker…

(vue基础试炼_02)使用vue.js实现隔2秒显示不同内容

接上一篇&#xff1a;&#xff08;vue基础试炼_01&#xff09;使用vue.js 快速入门hello worldhttps://gblfy.blog.csdn.net/article/details/103841156 文章目录一、原始js写法① 效果图② 2秒之后二、使用vue实现① 思考② vue写法③ 效果图三、vue总结一、原始js写法 <!…

适合小团队作战,奖金+招聘绿色通道,这一届算法大赛关注下?

大赛背景伴随着5G、物联网与大数据形成的后互联网格局的逐步形成&#xff0c;日益多样化的用户触点、庞杂的行为数据和沉重的业务体量也给我们的数据资产管理带来了不容忽视的挑战。为了建立更加精准的数据挖掘形式和更加智能的机器学习算法&#xff0c;对不断生成的用户行为事…