kaggle计算机视觉比赛技巧,9. 计算机视觉 - 9.12. 实战Kaggle比赛:图像分类(CIFAR-10) - 《动手学深度学习》 - 书栈网 · BookStack...

9.12. 实战Kaggle比赛:图像分类(CIFAR-10)

到目前为止,我们一直在用Gluon的data包直接获取NDArray格式的图像数据集。然而,实际中的图像数据集往往是以图像文件的形式存在的。在本节中,我们将从原始的图像文件开始,一步步整理、读取并将其变换为NDArray格式。

我们曾在“图像增广”一节中实验过CIFAR-10数据集。它是计算机视觉领域的一个重要数据集。现在我们将应用前面所学的知识,动手实战CIFAR-10图像分类问题的Kaggle比赛。该比赛的网页地址是https://www.kaggle.com/c/cifar-10 。

图9.16展示了该比赛的网页信息。为了便于提交结果,请先在Kaggle网站上注册账号。

e21389f3c099b6fe538226d20f97726a.gif

图 9.16 CIFAR-10图像分类比赛的网页信息。比赛数据集可通过点击“Data”标签获取

首先,导入比赛所需的包或模块。In[1]:importd2lzhasd2l

frommxnetimportautograd,gluon,init

frommxnet.gluonimportdataasgdata,lossasgloss,nn

importos

importpandasaspd

importshutil

importtime

9.12.1. 获取和整理数据集

比赛数据分为训练集和测试集。训练集包含5万张图像。测试集包含30万张图像,其中有1万张图像用来计分,其他29万张不计分的图像是为了防止人工标注测试集并提交标注结果。两个数据集中的图像格式都是png,高和宽均为32像素,并含有RGB三个通道(彩色)。图像一共涵盖10个类别,分别为飞机、汽车、鸟、猫、鹿、狗、青蛙、马、船和卡车。图9.16的左上角展示了数据集中部分飞机、汽车和鸟的图像。

9.12.1.1. 下载数据集

登录Kaggle后,可以点击图9.16所示的CIFAR-10图像分类比赛网页上的“Data”标签,并分别下载训练数据集train.7z、测试数据集test.7z和训练数据集标签trainLabels.csv。

9.12.1.2. 解压数据集

下载完训练数据集train.7z和测试数据集test.7z后需要解压缩。解压缩后,将训练数据集、测试数据集以及训练数据集标签分别存放在以下3个路径:../data/kaggle_cifar10/train/[1-50000].png;

../data/kaggle_cifar10/test/[1-300000].png;

../data/kaggle_cifar10/trainLabels.csv。

为方便快速上手,我们提供了上述数据集的小规模采样,其中train_tiny.zip包含100个训练样本,而test_tiny.zip仅包含1个测试样本。它们解压后的文件夹名称分别为train_tiny和test_tiny。此外,将训练数据集标签的压缩文件解压,并得到trainLabels.csv。如果使用上述Kaggle比赛的完整数据集,还需要把下面demo变量改为False。In[2]:# 如果使用下载的Kaggle比赛的完整数据集,把demo变量改为False

demo=True

ifdemo:

importzipfile

forfin['train_tiny.zip','test_tiny.zip','trainLabels.csv.zip']:

withzipfile.ZipFile('../data/kaggle_cifar10/'+f,'r')asz:

z.extractall('../data/kaggle_cifar10/')

9.12.1.3. 整理数据集

我们需要整理数据集,以方便训练和测试模型。以下的read_label_file函数将用来读取训练数据集的标签文件。该函数中的参数valid_ratio是验证集样本数与原始训练集样本数之比。In[3]:defread_label_file(data_dir,label_file,train_dir,valid_ratio):

withopen(os.path.join(data_dir,label_file),'r')asf:

# 跳过文件头行(栏名称)

lines=f.readlines()[1:]

tokens=[l.rstrip().split(',')forlinlines]

idx_label=dict(((int(idx),label)foridx,labelintokens))

labels=set(idx_label.values())

n_train_valid=len(os.listdir(os.path.join(data_dir,train_dir)))

n_train=int(n_train_valid*(1-valid_ratio))

assert0

returnn_train// len(labels), idx_label

下面定义一个辅助函数,从而仅在路径不存在的情况下创建路径。In[4]:defmkdir_if_not_exist(path):# 本函数已保存在d2lzh包中方便以后使用

ifnotos.path.exists(os.path.join(*path)):

os.makedirs(os.path.join(*path))

我们接下来定义reorg_train_valid函数来从原始训练集中切分出验证集。以valid_ratio=0.1为例,由于原始训练集有50,000张图像,调参时将有45,000张图像用于训练并存放在路径input_dir/train下,而另外5,000张图像将作为验证集并存放在路径input_dir/valid下。经过整理后,同一类图像将被放在同一个文件夹下,便于稍后读取。In[5]:defreorg_train_valid(data_dir,train_dir,input_dir,n_train_per_label,

idx_label):

label_count={}

fortrain_fileinos.listdir(os.path.join(data_dir,train_dir)):

idx=int(train_file.split('.')[0])

label=idx_label[idx]

mkdir_if_not_exist([data_dir,input_dir,'train_valid',label])

shutil.copy(os.path.join(data_dir,train_dir,train_file),

os.path.join(data_dir,input_dir,'train_valid',label))

iflabelnotinlabel_countorlabel_count[label]

mkdir_if_not_exist([data_dir,input_dir,'train',label])

shutil.copy(os.path.join(data_dir,train_dir,train_file),

os.path.join(data_dir,input_dir,'train',label))

label_count[label]=label_count.get(label,0)+1

else:

mkdir_if_not_exist([data_dir,input_dir,'valid',label])

shutil.copy(os.path.join(data_dir,train_dir,train_file),

os.path.join(data_dir,input_dir,'valid',label))

下面的reorg_test函数用来整理测试集,从而方便预测时的读取。In[6]:defreorg_test(data_dir,test_dir,input_dir):

mkdir_if_not_exist([data_dir,input_dir,'test','unknown'])

fortest_fileinos.listdir(os.path.join(data_dir,test_dir)):

shutil.copy(os.path.join(data_dir,test_dir,test_file),

os.path.join(data_dir,input_dir,'test','unknown'))

最后,我们用一个函数分别调用前面定义的read_label_file函数、reorg_train_valid函数以及reorg_test函数。In[7]:defreorg_cifar10_data(data_dir,label_file,train_dir,test_dir,input_dir,

valid_ratio):

n_train_per_label,idx_label=read_label_file(data_dir,label_file,

train_dir,valid_ratio)

reorg_train_valid(data_dir,train_dir,input_dir,n_train_per_label,

idx_label)

reorg_test(data_dir,test_dir,input_dir)

我们在这里只使用100个训练样本和1个测试样本。训练数据集和测试数据集的文件夹名称分别为train_tiny和test_tiny。相应地,我们仅将批量大小设为1。实际训练和测试时应使用Kaggle比赛的完整数据集,并将批量大小batch_size设为一个较大的整数,如128。我们将10%的训练样本作为调参使用的验证集。In[8]:ifdemo:

# 注意,此处使用小训练集和小测试集并将批量大小相应设小。使用Kaggle比赛的完整数据集时可

# 设批量大小为较大整数

train_dir,test_dir,batch_size='train_tiny','test_tiny',1

else:

train_dir,test_dir,batch_size='train','test',128

data_dir,label_file='../data/kaggle_cifar10','trainLabels.csv'

input_dir,valid_ratio='train_valid_test',0.1

reorg_cifar10_data(data_dir,label_file,train_dir,test_dir,input_dir,

valid_ratio)

9.12.2. 图像增广

为应对过拟合,我们使用图像增广。例如,加入transforms.RandomFlipLeftRight()即可随机对图像做镜面翻转,也可以通过transforms.Normalize()对彩色图像RGB三个通道分别做标准化。下面列举了其中的部分操作,你可以根据需求来决定是否使用或修改这些操作。In[9]:transform_train=gdata.vision.transforms.Compose([

# 将图像放大成高和宽各为40像素的正方形

gdata.vision.transforms.Resize(40),

# 随机对高和宽各为40像素的正方形图像裁剪出面积为原图像面积0.64~1倍的小正方形,再放缩为

# 高和宽各为32像素的正方形

gdata.vision.transforms.RandomResizedCrop(32,scale=(0.64,1.0),

ratio=(1.0,1.0)),

gdata.vision.transforms.RandomFlipLeftRight(),

gdata.vision.transforms.ToTensor(),

# 对图像的每个通道做标准化

gdata.vision.transforms.Normalize([0.4914,0.4822,0.4465],

[0.2023,0.1994,0.2010])])

测试时,为保证输出的确定性,我们仅对图像做标准化。In[10]:transform_test=gdata.vision.transforms.Compose([

gdata.vision.transforms.ToTensor(),

gdata.vision.transforms.Normalize([0.4914,0.4822,0.4465],

[0.2023,0.1994,0.2010])])

9.12.3. 读取数据集

接下来,可以通过创建ImageFolderDataset实例来读取整理后的含原始图像文件的数据集,其中每个数据样本包括图像和标签。In[11]:# 读取原始图像文件。flag=1说明输入图像有3个通道(彩色)

train_ds=gdata.vision.ImageFolderDataset(

os.path.join(data_dir,input_dir,'train'),flag=1)

valid_ds=gdata.vision.ImageFolderDataset(

os.path.join(data_dir,input_dir,'valid'),flag=1)

train_valid_ds=gdata.vision.ImageFolderDataset(

os.path.join(data_dir,input_dir,'train_valid'),flag=1)

test_ds=gdata.vision.ImageFolderDataset(

os.path.join(data_dir,input_dir,'test'),flag=1)

我们在DataLoader中指明定义好的图像增广操作。在训练时,我们仅用验证集评价模型,因此需要保证输出的确定性。在预测时,我们将在训练集和验证集的并集上训练模型,以充分利用所有标注的数据。In[12]:train_iter=gdata.DataLoader(train_ds.transform_first(transform_train),

batch_size,shuffle=True,last_batch='keep')

valid_iter=gdata.DataLoader(valid_ds.transform_first(transform_test),

batch_size,shuffle=True,last_batch='keep')

train_valid_iter=gdata.DataLoader(train_valid_ds.transform_first(

transform_train),batch_size,shuffle=True,last_batch='keep')

test_iter=gdata.DataLoader(test_ds.transform_first(transform_test),

batch_size,shuffle=False,last_batch='keep')

9.12.4. 定义模型

与“残差网络(ResNet)”一节中的实现稍有不同,这里基于HybridBlock类构建残差块。这是为了提升执行效率。In[13]:classResidual(nn.HybridBlock):

def__init__(self,num_channels,use_1x1conv=False,strides=1,**kwargs):

super(Residual,self).__init__(**kwargs)

self.conv1=nn.Conv2D(num_channels,kernel_size=3,padding=1,

strides=strides)

self.conv2=nn.Conv2D(num_channels,kernel_size=3,padding=1)

ifuse_1x1conv:

self.conv3=nn.Conv2D(num_channels,kernel_size=1,

strides=strides)

else:

self.conv3=None

self.bn1=nn.BatchNorm()

self.bn2=nn.BatchNorm()

defhybrid_forward(self,F,X):

Y=F.relu(self.bn1(self.conv1(X)))

Y=self.bn2(self.conv2(Y))

ifself.conv3:

X=self.conv3(X)

returnF.relu(Y+X)

下面定义ResNet-18模型。In[14]:defresnet18(num_classes):

net=nn.HybridSequential()

net.add(nn.Conv2D(64,kernel_size=3,strides=1,padding=1),

nn.BatchNorm(),nn.Activation('relu'))

defresnet_block(num_channels,num_residuals,first_block=False):

blk=nn.HybridSequential()

foriinrange(num_residuals):

ifi==0andnotfirst_block:

blk.add(Residual(num_channels,use_1x1conv=True,strides=2))

else:

blk.add(Residual(num_channels))

returnblk

net.add(resnet_block(64,2,first_block=True),

resnet_block(128,2),

resnet_block(256,2),

resnet_block(512,2))

net.add(nn.GlobalAvgPool2D(),nn.Dense(num_classes))

returnnet

CIFAR-10图像分类问题的类别个数为10。我们将在训练开始前对模型进行Xavier随机初始化。In[15]:defget_net(ctx):

num_classes=10

net=resnet18(num_classes)

net.initialize(ctx=ctx,init=init.Xavier())

returnnet

loss=gloss.SoftmaxCrossEntropyLoss()

9.12.5. 定义训练函数

我们将根据模型在验证集上的表现来选择模型并调节超参数。下面定义了模型的训练函数train。我们记录了每个迭代周期的训练时间,这有助于比较不同模型的时间开销。In[16]:deftrain(net,train_iter,valid_iter,num_epochs,lr,wd,ctx,lr_period,

lr_decay):

trainer=gluon.Trainer(net.collect_params(),'sgd',

{'learning_rate':lr,'momentum':0.9,'wd':wd})

forepochinrange(num_epochs):

train_l_sum,train_acc_sum,n,start=0.0,0.0,0,time.time()

ifepoch>0andepoch%lr_period==0:

trainer.set_learning_rate(trainer.learning_rate*lr_decay)

forX,yintrain_iter:

y=y.astype('float32').as_in_context(ctx)

withautograd.record():

y_hat=net(X.as_in_context(ctx))

l=loss(y_hat,y).sum()

l.backward()

trainer.step(batch_size)

train_l_sum+=l.asscalar()

train_acc_sum+=(y_hat.argmax(axis=1)==y).sum().asscalar()

n+=y.size

time_s="time %.2f sec"%(time.time()-start)

ifvalid_iterisnotNone:

valid_acc=d2l.evaluate_accuracy(valid_iter,net,ctx)

epoch_s=("epoch %d, loss %f, train acc %f, valid acc %f, "

%(epoch+1,train_l_sum/n,train_acc_sum/n,

valid_acc))

else:

epoch_s=("epoch %d, loss %f, train acc %f, "%

(epoch+1,train_l_sum/n,train_acc_sum/n))

print(epoch_s+time_s+', lr '+str(trainer.learning_rate))

9.12.6. 训练并验证模型

现在,我们可以训练并验证模型了。下面的超参数都是可以调节的,如增加迭代周期等。由于lr_period和lr_decay分别设为80和0.1,优化算法的学习率将在每80个迭代周期后自乘0.1。简单起见,这里仅训练1个迭代周期。In[17]:ctx,num_epochs,lr,wd=d2l.try_gpu(),1,0.1,5e-4

lr_period,lr_decay,net=80,0.1,get_net(ctx)

net.hybridize()

train(net,train_iter,valid_iter,num_epochs,lr,wd,ctx,lr_period,

lr_decay)epoch1,loss5.998157,train acc0.055556,valid acc0.100000,time1.34sec,lr0.1

9.12.7. 对测试集分类并在Kaggle提交结果

得到一组满意的模型设计和超参数后,我们使用所有训练数据集(含验证集)重新训练模型,并对测试集进行分类。In[18]:net,preds=get_net(ctx),[]

net.hybridize()

train(net,train_valid_iter,None,num_epochs,lr,wd,ctx,lr_period,

lr_decay)

forX,_intest_iter:

y_hat=net(X.as_in_context(ctx))

preds.extend(y_hat.argmax(axis=1).astype(int).asnumpy())

sorted_ids=list(range(1,len(test_ds)+1))

sorted_ids.sort(key=lambdax:str(x))

df=pd.DataFrame({'id':sorted_ids,'label':preds})

df['label']=df['label'].apply(lambdax:train_valid_ds.synsets[x])

df.to_csv('submission.csv',index=False)epoch1,loss6.620115,train acc0.090000,time1.24sec,lr0.1

执行完上述代码后,我们会得到一个submission.csv文件。这个文件符合Kaggle比赛要求的提交格式。提交结果的方法与“实战Kaggle比赛:房价预测”一节中的类似。

9.12.8. 小结可以通过创建ImageFolderDataset实例来读取含原始图像文件的数据集。

可以应用卷积神经网络、图像增广和混合式编程来实战图像分类比赛。

9.12.9. 练习使用Kaggle比赛的完整CIFAR-10数据集。把批量大小batch_size和迭代周期数num_epochs分别改为128和300。可以在这个比赛中得到什么样的准确率和名次?

如果不使用图像增广的方法能得到什么样的准确率?

参与讨论,在社区交流方法和结果。你能发掘出其他更好的技巧吗?

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.mzph.cn/news/542271.shtml

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

qthread中获取当前优先级_Linux中强大的top命令

top命令算是最直观、好用的查看服务器负载的命令了。它实时动态刷新显示服务器状态信息,且可以通过交互式命令自定义显示内容,非常强大。在终端中输入top,回车后会显示如下内容:top - 21:48:39 up 8:57, 2 users, load average: 0…

JavaScript中带示例的String repeat()方法

JavaScript | 字符串repeat()方法 (JavaScript | String repeat() Method) The String.repeat() method in JavaScript is used to generate a string by repeating the calling string n number of times. n can be any integer from o to any possible number in JavaScript.…

Python生成验证码

#!/usr/bin/env python #coding:utf8 import random #方法1: str_codezxcvbnmasdfghjklqwertyuiopZXCVBNMASDFGHJKLQWERTYUIOP0123456789new_codefor i in range(4):   new_coderandom.choice(str_code)print new_code #方法2: new_code[]def str_code…

snmp 获得硬件信息_计算机网络基础课程—简单网络管理协议(SNMP)

简单网络管理协议(Simple Network Management Protocol)•除了提供网络层服务的协议和使用那些服务的应用程序,因特网还需要运行一些让管理员进行设备管理、调试问题、控制路由、监测机器状态的软件。这种行为称为网络管理。••随着网络技术的飞速发展,…

僵尸毁灭工程 服务器已停止运行,《僵尸毁灭工程》steam is not enabled错误解决方法...

Steam 上面的 Project Zomboid 因为带有 VAC 所以建服开服需要 Steam服务器认证,这也是出现 steam is not enabled 错误主要原因,也是无法和普通零售正版所建的服务器联机的罪魁祸首。分两种情况(下面 Project Zomboid 均简称PZ):1、steam版P…

spring boot 1.4默认使用 hibernate validator

spring boot 1.4默认使用 hibernate validator 5.2.4 Final实现校验功能。hibernate validator 5.2.4 Final是JSR 349 Bean Validation 1.1的具体实现。 How to disable Hibernate validation in a Spring Boot project As [M. Deinum] mentioned in a comment on my original …

python mpi开销_GitHub - hustpython/MPIK-Means

并行计算的K-Means聚类算法实现一,实验介绍聚类是拥有相同属性的对象或记录的集合,属于无监督学习,K-Means聚类算法是其中较为简单的聚类算法之一,具有易理解,运算深度块的特点.1.1 实验内容通过本次课程我们将使用C语…

服务器修改开机启动项,启动项设置_服务器开机启动项

最近很多观众老爷在苦觅关于启动项设置的解答,今天钦编为大家综合5条解答来给大家解开疑惑! 有98%玩家认为启动项设置_服务器开机启动项值得一读!启动项设置1.如何在bios设置硬盘为第一启动项详细步骤根据BIOS分类的不同操作不同:…

字符串查找字符出现次数_查找字符串作为子序列出现的次数

字符串查找字符出现次数Description: 描述: Its a popular interview question based of dynamic programming which has been already featured in Accolite, Amazon. 这是一个流行的基于动态编程的面试问题,已经在亚马逊的Accolite中得到了体现。 Pr…

Ubuntu 忘记密码的处理方法

Ubuntu系统启动时选择recovery mode,也就是恢复模式。接着选择Drop to root shell prompt ,也就是获取root权限。输入命令查看用户名 cat /etc/shadow ,$号前面的是用户名输入命令:passwd "用户名" 回车就可以输入新密码了转载于:…

服务器mdl文件转换,Simulink Project 中 MDL 到 SLX 模型文件格式的转换

打开弹体示例项目并将 MDL 文件另存为 SLX运行以下命令以创建并打开“sldemo_slproject_airframe”示例的工作副本。Simulink.ModelManagement.Project.projectDemo(airframe, svn);rebuild_s_functions(no_progress_dialog);Creating sandbox for project.Created example fil…

vue 修改div宽度_Vue 组件通信方式及其应用场景总结(1.5W字)

前言相信实际项目中用过vue的同学,一定对vue中父子组件之间的通信并不陌生,vue中采用良好的数据通讯方式,避免组件通信带来的困扰。今天笔者和大家一起分享vue父子组件之间的通信方式,优缺点,及其实际工作中的应用场景…

Java System类identityHashCode()方法及示例

系统类identityHashCode()方法 (System class identityHashCode() method) identityHashCode() method is available in java.lang package. identityHashCode()方法在java.lang包中可用。 identityHashCode() method is used to return the hashcode of the given object – B…

Linux中SysRq的使用(魔术键)

转:http://www.chinaunix.net/old_jh/4/902287.html 魔术键:Linux Magic System Request Key Hacks 当Linux 系统不能正常响应用户请求时, 可以使用SysRq小工具控制Linux. 一 SysRq的启用与关闭 要想启用SysRq, 需要在配置内核时设置Magic SysRq key (CO…

链接服务器访问接口返回了消息没有活动事务,因为链接服务器 SQLEHR 的 OLE DB 访问接口 SQLNCLI10 无法启动分布式事务。...

查看一下MSDTC啟動是否正確1、运行 regedt32,浏览至 HKEY_LOCAL_MACHINE\Software\Microsoft\MSDTC。添加一个 DWORD 值 TurnOffRpcSecurity,值数据为 1。2、重启MS DTC服务。3、打开“管理工具”的“组件服务”。a. 浏览至"启动管理工具"。b.…

micropython 蜂鸣器_基于MicroPython的TPYBoard微信远程可燃气体报警器的设计与实现...

前言在我们平时的生活中,经常看到因气体泄漏发生爆炸事故的新闻。房屋起火、人体中毒等此类的新闻报道层出不穷。这种情况下,人民就发明了可燃气体报警器。当工业环境、日常生活环境(如使用天然气的厨房)中可燃性气体发生泄露,可燃气体报警器…

Java PropertyPermission getActions()方法与示例

PropertyPermission类的getActions()方法 (PropertyPermission Class getActions() method) getActions() method is available in java.util package. getActions()方法在java.util包中可用。 getActions() method is used to get the list of current actions in the form of…

源码安装nginx以及平滑升级

源码安装nginx以及平滑升级作者:尹正杰版权声明:原创作品,谢绝转载!否则将追究法律责任。欢迎加入:高级运维工程师之路 598432640这个博客不方便上传软件包,我给大家把软件包放到百度云链接:htt…

ajax 跨站返回值,jquery ajax 跨域问题

补充回答:你的动态页只是一个请求页。例如你新建一个 get.asp 页面,用以下代码,在服务端实现像URL异步(ajax)请求,将请求结果输出。客户端页面再次用ajax(JS或者jquery的)向get.asp请求数据。两次ajax完成异域数据请求。get.asp代…

Bootstrap学习笔记系列1-------Bootstrap网格系统

目录 Bootstrap网格系统 学习笔记简单网格偏移列嵌套列列排序Bootstrap网格系统 学习笔记 简单网格 先上代码再解释 <!DOCTYPE html> <html><head><title>Bootstrap 模板</title><meta charset"utf-8"><!-- 引入 Bootstrap -…