用PyTorch创建一个图像分类器?So easy!(Part 2)

在第一部分中,我们知道了为什么以及如何加载预先训练好的神经网络,我们可以用自己的分类器代替已有神经网络的分类器。那么,在这篇文章中,我们将学习如何训练分类器。

训练分类器

首先,我们需要为分类器提供待分类的图像。本文使用ImageFolder加载图像,预训练神经网络的输入有特定的格式,因此,我们需要用一些变换来调整图像的大小,即在将图像输入到神经网络之前,对其进行裁剪和标准化处理。

具体来说,将图像大小调整为224*224,并对图像进行标准化处理,即均值为 [0.485,0.456,0.406],标准差为[0.229,0.224,0.225],颜色管道的均值设为0,标准差缩放为1。

然后,使用DataLoader批量传递图像,由于有三个数据集:训练数据集、验证数据集和测试数据集,因此需要为每个数据集创建一个加载器。一切准备就绪后,就可以训练分类器了。

在这里,最重要的挑战就是——正确率(accuracy)。

让模型识别一个已经知道的图像,这不算啥事,但是我们现在的要求是:能够概括、确定以前从未见过的图像中花的类型。在实现这一目标过程中,我们一定要避免过拟合,即“分析的结果与特定数据集的联系过于紧密或完全对应,因此可能无法对其他数据集进行可靠的预测或分析”。

隐藏层

实现适当拟合的方法有很多种,其中一种很简单的方法就是:隐藏层

我们很容易陷入这样一种误区:拥有更多或更大的隐藏层,能够提高分类器的正确率,但事实并非如此。

增加隐藏层的数量或大小以后,我们的分类器就需要考虑更多不必要的参数。举个例子来说,将噪音看做是花朵的一部分,这会导致过拟合,也会降低精度,不仅如此,分类器还需要更长的时间来训练和预测。

因此,我建议你从数量较少的隐藏层开始,然后根据需要增加隐藏层的数量或大小,而不是一开始就使用特别多或特别大的隐藏层。

在第一部分介绍的《AI Programming with Python Nanodegree》课程中的花卉分类器项目中,我只需要一个小的隐藏层,在第一个完整训练周期内,就得到了70%以上的正确率。

数据增强

我们有很多图像可供模型训练,这非常不错。如果拥有更多的图像,数据增强就可以发挥作用了。每个图像在每个训练周期都会作为神经网络的输入,对神经网络训练一次。在这之前,我们可以对输入图像做一些随机变化,比如旋转、平移或缩放。这样,在每个训练周期内,输入图像都会有差异。

增加训练数据的种类有利于减少过拟合,同样也提高了分类器的概括能力,从而提高模型分类的整体准确度。

Shuffle

在训练分类器时,我们需要提供一系列随机的图像,以免引入任何误差。

举个例子来说,我们刚开始训练分类器时,我们使用“牵牛花”图像对模型进行训练,这样一来,分类器在后续训练过程中将会偏向“牵牛花”,因为它只知道“牵牛花”。因此,在我们使用其他类型的花进行训练时,分类器最初的偏好也将持续一段时间。

为了避免这一现象,我们就需要在数据加载器中使用不同的图像,这很简单,只需要在加载器中添加shuffle=true,代码如下:

trainloader = torch.utils.data.DataLoader(train_dataset, batch_size=batch_size, shuffle=True)

Dropout

有的时候,分类器中的节点可能会导致其他节点不能进行适当的训练,此外,节点可能会产生共同依赖,这就会导致过拟合。

Dropout技术通过在每个训练步骤中使一些节点处于不活跃状态,来避免这一问题。这样一来,在每个训练阶段都使用不同的节点子集,从而减少过拟合。

除了过拟合,我们一定要记住,学习率( learning rate )是最关键的超参数。如果学习率过大,模型的误差永远都不会降到最小;如果学习率过小,分类器将会训练的特别慢,因此,学习率不能过大也不能过小。一般来说,学习率可以是0.01,0.001,0.0001……,依此类推。

最后,在最后一层选择正确的激活函数会对模型的正确率会产生特别大的影响。举个例子来说,如果我们使用 negative log likelihood loss(NLLLoss),那么,在最后一层中,建议使用LogSoftmax激活函数。

结论

理解模型的训练过程,将有助于创建能够概括的模型,在预测新图像类型时的准确度更高。

在本文中,我们讨论了过拟合将会如何降低模型的概括能力,并学习了降低过拟合的方法。另外,我们也强调了学习率的重要性及其常用值。最后,我们知道,为最后一层选择正确的激活函数非常关键。

现在,我们已经知道应该如何训练分类器,那么,我们就可以用它来预测以前从未见过的花型了!

 

原文链接
本文为云栖社区原创内容,未经允许不得转载。

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.mzph.cn/news/520221.shtml

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

涨姿势,一个通信项目从开始到结束,原来还包括这些工作

戳蓝字“CSDN云计算”关注我们哦!作者 | 小枣君责编 | 阿秃本月12日,中国移动31个省的通信工程设计与可行性研究集采正式启动。这次集采规模庞大,涵盖了无线网(5G、FDD、NB等)、核心网、承载网、支撑网等专业方向,预估基本规模超4…

听说支付宝有一个“疯起来连自己都打”的项目

小蚂蚁说: 自古红蓝出CP,在蚂蚁金服就有这样两支“相爱相杀”的队伍——红军和蓝军。蓝军是进攻方,主要职责是挖掘系统的弱点并发起“真实”的攻击,俗称“找茬”;红军则是防守方,其防控体系建设中的实时核…

蚂蚁金服红蓝军技术攻防演练究竟有多“狠”

如果一个技术团队不干别的,专门“搞破坏”,这是一种怎样的存在?这真的不是“天方夜谭”,在支付宝确实有这么一支队伍——技术蓝军。蓝军的任务就是不断地攻击和进攻,而防守方则是技术红军。在支付宝,蓝军从…

阿里巴巴在内蒙古旱区试水物联网灌溉技术,一年省出1.5个西湖

阿里巴巴正用物联网技术解决干旱地区的灌溉问题,通过搭建农业物联网平台,全面监测农作物的生长状态,从而匹配最节约的灌溉方案。12月19日试验区研究人员得出预测结果:一年可以省出1.5个西湖的水。 一直以来干旱是困扰人类的重要环…

网易考拉在服务化改造方面的实践

导读: 网易考拉(以下简称考拉)是网易旗下以跨境业务为主的综合型电商,自2015年1月9日上线公测后,业务保持了高速增长,这背后离不开其技术团队的支撑。微服务化是电商IT架构演化的必然趋势,网易…

Oracle 11g Java驱动包ojdbc6.jar安装到maven库,并查看jar具体版本号

ojdbc6.jar下载 Oracle官方宣布的Oracle数据库11g的驱动jar包是ojdbc6.jar ojdbc6.jar下载地址:https://www.oracle.com/technetwork/database/enterprise-edition/jdbc-112010-090769.html (Oracle Database 11g Release 2 (11.2.0.4) JDBC Drivers & UCP Do…

阿里重磅开源Blink:为什么我们等了这么久?

12月20日,由阿里巴巴承办的 Flink Forward China 峰会在北京国家会议中心召开,来自阿里、华为、腾讯、美团点评、滴滴、字节跳动等公司的技术专家与参会者分享了各公司基于 Flink 的应用和实践经验。 感兴趣的开发者可以看云栖社区的对于大会的主会5场分…

GAN是一种特殊的损失函数?

数据科学家Jeremy Howard在fast.ai的《生成对抗网络(GAN)》课程中曾经讲过这样一句话: “从本质上来说,生成对抗网络(GAN)是一种特殊的损失函数。” 你是否能够理解这句话的意思?读完本文&…

matlab 三维 作图 坐标轴_这张图(不全),想利用matlab画一张三维图,X Y z 轴分别为经度 纬度 频率,这...

xrangeminx:dx:maxx; yrangeminy:dy:maxy;[X,Y] meshgrid(xrange,yrange);griddata(lon,lat,SST,X,Y);mesh(X,Y,Z), hold onplot3(lon,lat,SST,o),hold offmatlab 作图方法2113:plot3 三维曲线图;plot3(x1,y1,z1,x2,y2,z2,…,xn,yn,zn): surf(x,y,z)…

(Python)零起步数学+神经网络入门

在这篇文章中,我们将在Python中从头开始了解用于构建具有各种层神经网络(完全连接,卷积等)的小型库中的机器学习和代码。最终,我们将能够写出如下内容: 假设你对神经网络已经有一定的了解,这篇文…

短视频宝贝=慢?阿里巴巴工程师这样秒开短视频

前言 随着短视频兴起,各大APP中短视频随处可见,feeds流、详情页等等。怎样让用户有一个好的视频观看体验显得越来越重要了。大部分feeds里面滑动观看视频的时候,有明显的等待感,体验不是很好。针对这个问题我们展开了一波优化&am…

Haproxy 管控台介绍

Queue 队列 简称全称说明Curcurrent queued requests当前的队列请求数量Maxmax queued requests最大的队列请求数量Limit队列限制数量 Session rate (每秒的连接回话)列表 简称全称说明scurcurrent sessions每秒的当前回话的限制数量smaxmax sessions每秒的新的最大的回话量s…

阿里云时空数据库引擎HBase Ganos上线,场景、功能、优势全解析

随着全球卫星导航定位系统、传感网、移动互联网、IoT等技术的快速发展,越来越多的终端设备连接至网络,由此产生了大规模的时空位置信息,如车辆轨迹、个人轨迹、群体活动、可穿戴设备时空位置等。这些数据具有动态变化(数据写入频繁…

云栖专辑|阿里开发者们的第二个感悟:PG大V德哥的使命感与开放心态

2015年12月20日,云栖社区上线。2018年12月20日,云栖社区3岁。 阿里巴巴常说“晴天修屋顶”。 在我们看来,寒冬中,最值得投资的是学习,是增厚的知识储备。 所以社区特别制作了这个专辑——分享给开发者们20个弥足珍贵的…

VS Code 全局配置

文章目录1. settings.json2. 在项目根目录添加.eslintrc.js3. 在项目根目录添加.prettierrc.json1. settings.json ctrlshirtp 搜索settings.json替换为下面内容即可 {// 主题颜色 浅色主题"workbench.colorTheme": "Monokai","workbench.iconTheme…

云栖专辑 | 阿里开发者们的第3个感悟:从身边开源开始学习,用过才能更好理解代码

2015年12月20日,云栖社区上线。2018年12月20日,云栖社区3岁。 阿里巴巴常说“晴天修屋顶”。 在我们看来,寒冬中,最值得投资的是学习,是增厚的知识储备。 所以社区特别制作了这个专辑——分享给开发者们20个弥足珍贵的…

个人帐目管理系统java_Java 项目 个人帐目管理系统

目录第一部分项目描述 31.1项目目的 3第二部分需求和开发环境 32.1使用技术和开发环境 32.2项目需求 32.3详细功能 32.4 E-R图 32.5数据库的设计 32.5.1数据表的设计 32.5.2数据库约束的设计 42.5.3数据库序列的设计 42.5.4数据库索引的设计 42.5.5数据库视图的设计 52.5.6数据…

KubeCon 2018 参会记录 —— FluentBit Deep Dive

在最近的上海和北美KubeCon大会上,来自于Treasure Data的Eduardo Silva(Fluentd Maintainer)带来了最期待的关于容器日志采集工具FluentBit的最新进展以及深入解析的分享;我们知道Fluentd是在2016年底正式加入CNCF,成为…

全球首个!阿里云开源批流一体机器学习平台Alink……

11月28日,阿里云正式开源机器学习平台 Alink,这也是全球首个批流一体的算法平台,旨在降低算法开发门槛,帮助开发者掌握机器学习的生命全周期。 Flink Forward 2019在京举办,吸引众多开发者参与标题Alink基于实时计算引…