用 Kaggle 经典案例教你用 CNN 做图像分类!

我们来看一个 Kaggle 上比较经典的一个图像分类的比赛 CIFAR( CIFAR-10 - Object Recognition in Images ),这个比赛现在已经关闭了,但不妨碍我们来去通过它学习一下卷积神经网络做图像识别的代码结构。相信很多学过深度学习的同学都尝试过这个比赛,如果对此比较熟悉的可以跳过本篇,如果没有尝试过的同学可以来学习一下哈。

整个代码已经放在了我的 GitHub 上,建议可以把代码 pull 下来,边看文章边看代码。

GitHub 地址:NELSONZHAO/zhihu

如果觉得有帮助,麻烦点个 star 啦~

介绍

文章主要分为两个部分,第一部分我们将通过一个简单的 KNN 来实现图像的分类,第二部分我们通过卷积神经网络提升整个图像分类的性能。

第一部分

提到图像分类,我们可能会想到传统机器学习中 KNN 算法,通过找到当前待分类图像的 K 个近邻,以近邻的类别判断当前图像的类别。

由于我们的图像实际上是由一个一个像素组成的,因此每一个图像可以看做是一个向量,那么我们此时就可以来计算向量(图片)之间的距离。比如,我们的图片如果是 32x32 像素的,那么可以展开成一个 1x1024 的向量,就可以计算这些向量间的 L1 或者 L2 距离,找到它们的近邻,从而根据近邻的类别来判断图像的类别。

以下例子中 K=5。

用 Kaggle 经典案例教你用 CNN 做图像分类!

下面我们就来用 scikit-learn 实现以下 KNN 对图像的分类。

首先我们需要下载数据文件,网址为 https://www.cs.toronto.edu/~kriz/cifar-10-python.tar.gz 。我们数据包含了 60000 万图片,每张图片的维度为 32 x 32 x 3,这些图片都有各自的标注,一共分为了以下十类:

  • airplane

  • automobile

  • bird

  • cat

  • deer

  • dog

  • frog

  • horse

  • ship

  • truck

数据是被序列化以后存储的,因此我们需要使用 Python 中的 pickle 包将它们读进来。整个压缩包解压以后,会有 5 个 data_batch 和 1 个 test_batch。我们首先把数据加载进来:

用 Kaggle 经典案例教你用 CNN 做图像分类!

我们定义了一个函数来获取 batch 中的 features 和 labels,通过上面的步骤,我们就可以获得 train 数据与 test 数据。

我们的每个图片的维度是 32 x 32 x 3,其中 3 代表 RGB。我们先来看一些这些图片长什么样子。

用 Kaggle 经典案例教你用 CNN 做图像分类!

每张图片的像素其实很低,缩小以后我们可以看到图片中有汽车,马,飞机等。

构造好了我们的 x_train, y_train, x_test 以及 y_test 以后,我们就可以开始建模过程。在将图片扔进模型之前,我们首先要对数据进行预处理,包括重塑和归一化两步,首先将 32 x 32 x 3 转化为一个 3072 维的向量,再对数据进行归一化,归一化的目的在于计算距离时保证各个维度的量纲一致。

用 Kaggle 经典案例教你用 CNN 做图像分类!

到此为止,我们已经对数据进行了预处理,下面就可以调用 KNN 来进行训练,我分别采用了 K=1,3,5 来看模型的效果。

用 Kaggle 经典案例教你用 CNN 做图像分类!

从 KNN 的分类准确率来看,是要比我们随机猜测类别提高了不少。我们随机猜测图片类别时,准确率大概是 10%,KNN 方式的图片分类可以将准确率提高到 35% 左右。当然有兴趣的小伙伴还可以去测试一下其他的 K 值,同时在上面的算法中,默认距离衡量方式是欧式距离,还可以尝试其他度量距离来进行建模。

虽然 KNN 在 test 数据集上表现有所提升,但是这个准确率还是太低了。除此之外,KNN 有一个缺点,就是所有的计算时间都在 predict 阶段,当一个新的图来的时候,涉及到大量的距离计算,这就意味着一旦我们要拿它来进行图像识别,那可能要等非常久才能拿到结果,而且还不是那么的准。


第二部分

在上一部分,我们用了非常简单的 KNN 思想实现了图像分类。在这个部分,我们将通过卷积神经网络来实现一个更加准确、高效的模型。

加载数据的过程与上一部分相同,不再赘述。当我们将数据加载完毕后,首先要做以下三件事:

  • 对输入数据归一化

  • 对标签进行 one-hot 编码

  • 构造训练集,验证集和测试集

对输入数据归一化

在这里我们使用 sklearn 中的 minmax 归一化。

用 Kaggle 经典案例教你用 CNN 做图像分类!

首先将训练数据集重塑为 [50000, 3072] 的形状,利用 minmax 来进行归一化。最后再将图像重塑回原来的形状。

对标签进行 one-hot 编码

同样我们在这里使用 sklearn 中的 LabelBinarizer 来进行 one-hot 编码。

用 Kaggle 经典案例教你用 CNN 做图像分类!

构造 train 和 val

目前我们已经有了 train 和 test 数据集,接下来我们要将加载进来的 train 分成训练集和验证集。从而在训练过程中观察验证集的结果。

用 Kaggle 经典案例教你用 CNN 做图像分类!

我们将训练数据集按照 8:2 分为 train 和 validation。

卷积网络

完成了数据的预处理,我们接下来就要开始进行建模。

首先我们把一些重要的参数设置好,并且将输入和标签 tensor 构造好。

用 Kaggle 经典案例教你用 CNN 做图像分类!

img_shape 是整个训练集的形状,为 [40000, 32, 32, 3],同时我们的输入形状是 [batch_size, 32, 32, 3],由于前面我们已经对标签进行了 one-hot 编码,因此标签是一个 [batch_size, 10] 的 tensor。

接下来我们先来看一下整个卷积网络的结构:

用 Kaggle 经典案例教你用 CNN 做图像分类!

在这里我设置了两层卷积 + 两层全连接层的结构,大家也可以尝试其他不同的结构和参数。

用 Kaggle 经典案例教你用 CNN 做图像分类!

conv2d 中我自己定义了初始化权重为 truncated_normal,事实证明权重初始化对于卷积结果有一定的影响。

在这里,我们来说一下 conv2d 的参数:

  • 输入 tensor:inputs_

  • 滤波器的数量:64

  • 滤波器的 size:height=2, width=2, depth 默认与 inputs_的 depth 相同

  • strides:strides 默认为 1x1,因此在这里我没有重新设置 strides

  • padding:padding 我选了 same,在 strides 是 1 的情况下,经过卷积以后 height 和 width 与原图保持一致

  • kernel_initializer:滤波器的初始化权重

在这里讲一下卷积函数中的两种常见 padding 方式,分别是 valid,same。假设我们输入图片长和宽均为 h,filter 的 size 为 k x k,strides 为 s x s,padding 大小 = p。当 padding=valid 时,经过卷积以后的图片新的长(或宽)为 用 Kaggle 经典案例教你用 CNN 做图像分类! ;当 padding=same 时,经过卷积以后 用 Kaggle 经典案例教你用 CNN 做图像分类! 。但在 TensorFlow 中的实现与这里有所区别,在 TensorFlow 中,当 padding=valid 时, 用 Kaggle 经典案例教你用 CNN 做图像分类! ;当 padding=same 时, 用 Kaggle 经典案例教你用 CNN 做图像分类! 。

其余参数类似,这里不再赘述,如果还不是很清楚的小伙伴可以去查看官方文档。

在第一个全连接层中我加入了 dropout 正则化防止过拟合,同时加快训练速度。

训练模型

完成了模型的构建,下面我们就来开始训练整个模型。

用 Kaggle 经典案例教你用 CNN 做图像分类!

在训练过程中,每 100 轮打印一次日志,显示出当前 train loss 和 validation 上的准确率。

我们来看一下最终的训练结果:

用 Kaggle 经典案例教你用 CNN 做图像分类!

上图是我之前跑的一次结果,这次跑出来可能有所出入,但准确率大概会在 65%-70% 之间。

最后在 validation 上的准确率大约稳定在了 70% 左右,我们接下来看一下在 test 数据上的准确率。下面的代码是在 test 测试准确率的代码。

用 Kaggle 经典案例教你用 CNN 做图像分类!

我们把训练结果加载进来,设置 test 的 batchs_size 为 100,来测试我们的训练结果。最终我们的测试准确率也基本在 70% 左右。

总结

至此,我们实现了两种图像分类的算法。第一种是 KNN,它的思想非常好理解,但缺点在于计算量都集中在测试阶段,训练阶段的计算量几乎为 0,另外,它的准确性也非常差。第二种我们利用 CNN 实现了分类,最终的测试结果大约在 70% 左右,相比 KNN 的 30% 准确率,它的分类效果表现的相当好。当然,如果想要继续提升模型的准确率,就需要采用其他的一些手段,如果感兴趣的小伙伴可以去看一下相关链接(http://rodrigob.github.io/are_we_there_yet/build/classification_datasets_results.html#43494641522d3130) 里的技巧,Kaggle 上的第一名准确率已经超过了 95%。

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.mzph.cn/news/473479.shtml

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

Flask知识点回顾以及重点内容

1. HTTP通信与Web框架 1.1 流程 客户端将请求打包成HTTP的请求报文(HTTP协议格式的请求数据) 采用TCP传输发送给服务器端 服务器接收到请求报文后按照HTTP协议进行解析 服务器根据解析后获知的客户端请求进行逻辑执行 服务器将执行后的结果封装成HTTP的响…

机器学习回归算法—线性回归及案例分析

一、回归算法回归是统计学中最有力的工具之一。机器学习监督学习算法分为分类算法和回归算法两种,其实就是根据类别标签分布类型为离散型、连续性而定义的。回归算法用于连续型分布预测,针对的是数值型的样本,使用回归,可以在给定…

LeetCode 1669. 合并两个链表

文章目录1. 题目2. 解题1. 题目 给你两个链表 list1 和 list2 ,它们包含的元素分别为 n 个和 m 个。 请你将 list1 中第 a 个节点到第 b 个节点删除,并将list2 接在被删除节点的位置。 下图中蓝色边和节点展示了操作后的结果: 请你返回结果…

机器学习回归算法—性能评估欠拟合与过拟合

机器学习中的泛化,泛化即是,模型学习到的概念在它处于学习的过程中时模型没有遇见过的样本时候的表现。在机器学习领域中,当我们讨论一个机器学习模型学习和泛化的好坏时,我们通常使用术语:过拟合和欠拟合。我们知道模…

Nginx安全配置

nginx本身不能处理PHP,它只是个web服务器,当接收到请求后,如果是php请求,则发给php解释器处理,并把结果返回给客户端。nginx一般是把请求发fastcgi管理进程处理,fastcgi管理进程选择cgi子进程处理结果并返回…

LeetCode 1670. 设计前中后队列(deque)

文章目录1. 题目2. 解题1. 题目 请你设计一个队列,支持在前,中,后三个位置的 push 和 pop 操作。 请你完成 FrontMiddleBack 类: FrontMiddleBack() 初始化队列。 void pushFront(int val) 将 val 添加到队列的 最前面 。 void…

java 1.7 新特性

1.对Java集合(Collections)的增强支持 在JDK1.7之前的版本中,Java集合容器中存取元素的形式如下: 以List、Set、Map集合容器为例: 在JDK1.7中,摒弃了Java集合接口的实现类,如:ArrayL…

LeetCode 1671. 得到山形数组的最少删除次数(最长上升子序DP nlogn)

文章目录1. 题目2. 解题2.1 n^2 解法2.2 nlogn 解法197 / 1891,前10.4%435 / 6154,前7.07%前三题如下: LeetCode 5557. 最大重复子字符串 LeetCode 5558. 合并两个链表 LeetCode 5560. 设计前中后队列(deque) 1. 题目…

机器学习Tensorflow基础知识、张量与变量

TensorFlow是一个采用数据流图(data flow graphs),用于数值计算的开源软件库。节点(Nodes)在图中表示数学操作,图中的线(edges)则表示在节点间相互联系的多维数据数组,即…

LeetCode 1672. 最富有客户的资产总量

文章目录1. 题目2. 解题1. 题目 给你一个 m x n 的整数网格 accounts ,其中 accounts[i][j] 是第 i​​​​​​​​​​​​ 位客户在第 j 家银行托管的资产数量。 返回最富有客户所拥有的 资产总量 。 客户的 资产总量 就是他们在各家银行托管的资产数量之和。最…

机器学习Tensorflow基本操作:线程队列图像

一、线程和队列在使用TensorFlow进行异步计算时,队列是一种强大的机制。为了感受一下队列,让我们来看一个简单的例子。我们先创建一个“先入先出”的队列(FIFOQueue),并将其内部所有元素初始化为零。然后,我…

关于使用ModelSim中编写testbench模板问题

对于初学者来说写Testbench测试文件还是比较困难的,但Modelsim和quartus ii都提供了模板,下面就如何使用Modelsim提供的模板进行操作。 Modelsim提供了很多Testbench模板,我们直接拿过来用可以减少工作量。对源文件编译完后,鼠标光…

LeetCode 1673. 找出最具竞争力的子序列(单调栈)

文章目录1. 题目2. 解题1. 题目 给你一个整数数组 nums 和一个正整数 k ,返回长度为 k 且最具 竞争力 的 nums 子序列。 数组的子序列是从数组中删除一些元素(可能不删除元素)得到的序列。 在子序列 a 和子序列 b 第一个不相同的位置上&am…

android获取string.xml的值

为什么需要把应用中出现的文字单独存放在string.xml文件中呢? 一:是为了国际化,当需要国际化时,只需要再提供一个string.xml文件,把里面的汉子信息都修改为对应的语言(如,English),再…

牛客 怕npy的牛牛(双指针)

文章目录1. 题目2. 解题1. 题目 链接:https://ac.nowcoder.com/acm/contest/9556/B 来源:牛客网 题目描述 牛牛非常怕他的女朋友,怕到了走火入魔的程度,以至于每当他看到一个字符串同时含有n,p,y三个字母他都害怕的不行。 现在…

Flask入门之上传文件到本地服务器

Flask入门之上传文件到服务器今天要做一个简单的页面,可以实现将文件 上传到服务器(保存在指定文件夹)#Sample.py1 # coding:utf-82 3 from flask import Flask,render_template,request,redirect,url_for4 from werkzeug.utils import secur…

对象的三种状态

来自为知笔记(Wiz)转载于:https://www.cnblogs.com/zmpandzmp/p/3649196.html

Cygwin中如何像在Ubuntu中一样安装软件

cygwin作为windows下模拟Linux环境的的工具,使得我们能在windows下非常方便的使用Linux的命令和工具,下面讲讲怎样在cygwin添加不支持的命令。 1.首先安装cygwin:我们可以到Cygwin的官方网站下载Cygwin的安装程序,地址是&#xff…

LeetCode 321. 拼接最大数(单调栈)*

文章目录1. 题目2. 解题1. 题目 给定长度分别为 m 和 n 的两个数组&#xff0c;其元素由 0-9 构成&#xff0c;表示两个自然数各位上的数字。 现在从这两个数组中选出 k (k < m n) 个数字拼接成一个新的数&#xff0c;要求从同一个数组中取出的数字保持其在原数组中的相对…

pandas数据分析选则接近数值的最接优方案

import numpy as np import pandas as pd# pandas数据分析选则接近数值的最接优方案# 1.准备数据 CHILD_TABLE (720, 750) CHIDL_STOOL (300, 350) CHILD_PLAY_LEN (300, 400) CHILD_TENT (1100, 1300) # 2.遍历循环&#xff0c;添加到列表中 sum_length_lst [] play_lst …