【深度学习-图像识别】使用fastai对Caltech101数据集进行图像多分类(50行以内的代码就可达到很高准确率)


文章目录

  • 前言
    • fastai介绍
      • 数据集介绍
  • 一、环境准备
  • 二、数据集处理
    • 1.数据目录结构
    • 2.导入依赖项
    • 2.读入数据
    • 3.模型构建
      • 3.1 寻找合适的学习率
      • 3.2 模型调优
    • 4.模型保存与应用
  • 总结
      • 人工智能-图像识别 系列文章目录


前言

fastai介绍

fastai 是一个深度学习库,它为从业人员提供了高级组件,可以快速、轻松地在标准深度学习领域提供最先进的结果,并为研究人员提供了低级组件,可以混合和匹配以构建新的方法。以解耦抽象的方式表达了许多深度学习和数据处理技术的通用底层模式。
fastai 有两个主要的设计目标:易于使用、快速高效,同时具有很强的可破解性和可配置性。它建立在提供可组合构件的低级应用程序接口的层次结构之上。这样,如果用户想重写部分高级应用程序接口或添加特定行为以满足自己的需求,就不必学习如何使用最底层的应用程序接口。
在这里插入图片描述

数据集介绍

下载链接
Caltech101国内下载地址
Caltech101

Caltech101数据集内部有 101 个类别的物体图片。每个类别约有 40 至 800 张图片。大多数类别约有 50 张图片。每张图片的大小大约为 300 x 200 像素。并且作者还标注了这些图片中每个物体的轮廓,这些都包含在 "Annotations.tar "中。还有一个 MATLAB 脚本 "show_annotations.m "可以查看注释。

Collected in September 2003 by Fei-Fei Li, Marco Andreetto, and
Marc’Aurelio Ranzato。

一、环境准备

这里展示使用GPU进行训练的环境搭建,只用CPU也可以进行训练,只是训练时间比较慢。
首先安装Anaconda,通过conda安装我们需要的包

 conda install pytorch torchvision torchaudio pytorch-cuda=11.8 -c pytorch -c nvidiaconda install -c nvidia fastai anaconda

详情可见第一篇文章。

二、数据集处理

1.数据目录结构

├───data_iamge
│   ├───101_ObjectCategories
│   │   ├───accordion
│   │   ├───airplanes
│   │   ├───anchor
│   │   ├───ant
│   │   ├───BACKGROUND_Google
│   │   ├───barrel
│   │   ├───bass
│   │   ├───beaver
│   │   ├───binocular
│   │   ├───bonsai
│   │   ├───brain
│   │   ├───brontosaurus
...

2.导入依赖项

from fastai import *
from fastai.vision.all import *
from fastai.metrics import error_rateimport os
#from keras.utils import plot_model
import pandas as pd
import numpy as np
import matplotlib.pyplot as plt

查看环境以及版本信息,cuda.is_available()判断是否可以用GPU。

print(torch.cuda.is_available())
print(torch.__version__)
print(torch.version.cuda)
print(torch.backends.cudnn.version())

True
2.0.1
11.8
8700

'''SEED Everything'''
def seed_everything(SEED=42):random.seed(SEED)np.random.seed(SEED)torch.manual_seed(SEED)torch.cuda.manual_seed(SEED)torch.cuda.manual_seed_all(SEED)torch.backends.cudnn.benchmark = True # keep True if all the input have same size.
SEED=42
seed_everything(SEED=SEED)
'''SEED Everything'''

2.读入数据

代码如下(示例):

path='./data_image/101_ObjectCategories/'
image_rsize=224
item_tfms = [Resize((image_rsize,image_rsize))]
data = ImageDataLoaders.from_folder(path, train = '.', valid_pct=0.2,size=image_rsize,item_tfms=item_tfms)
data.show_batch(figsize=(7,6))

在这里插入图片描述

3.模型构建

这里使用预训练模型resnet101,这是一个非常优秀的残差网络模型。
这些残差网络更容易优化,并且可以从显着增加的深度中获得准确性。
这些残差网络的集合在 ImageNet 测试集上实现了 3.57% 的误差。该结果在ILSVRC 1分类任务中获得第一名。

learn = cnn_learner(data, models.resnet101, model_dir='./model', path = Path("."))

3.1 寻找合适的学习率

learn.lr_find()

在这里插入图片描述

接下来使用fit_one_cycle方法用更小的学习率进一步训练。fit_one_cycle使用的是一种周期性学习率,从较小的学习率开始学习,缓慢提高至较高的学习率,然后再慢慢下降,周而复始,每个周期的长度略微缩短,在训练的最后部分,允许学习率比之前的最小值降得更低。这不仅可以加速训练,还有助于防止模型落入损失平面的陡峭区域,使模型更倾向于寻找更平坦的极小值,从而缓解过拟合现象。

lr1 = 1e-3
lr2 = 1e-1
epoch	train_loss	valid_loss	time
0	1.417713	1.648756	00:45
1	3.097069	9.964518	00:43
2	5.385355	5.347832	00:44
3	4.194504	12.162844	00:44
4	2.985504	3.486863	00:43
5	2.152388	22.297184	00:43
6	1.295905	3.554162	00:43
7	0.630879	9.193820	00:43
8	0.361619	49.334236	00:43
9	0.255115	9.832499	00:43

3.2 模型调优

unfreeze
在fastai课程中使用的是预训练模型,模型卷积层的权重已经提前在ImageNet
上训练好了,在使用的时候一般只需要在预训练模型最后一层卷积层后添加自定义的全连接层即可。卷积层默认是freeze的,即在训练阶段进行反向传播时不会更新卷积层的权重,只会更新全连接层的权重。在训练几个epoch之后,全连接层的权重已经训练的差不多了,但accuracy还没有达到你的要求,这时你可以调用unfreeze然后再进行训练,这样在进行反向传播时便会更新卷积层的权重(一般不会对卷积层权重进行较大的更新,只会进行一点点的微调,越靠前的卷积层调整的幅度越小,所以有了differential
learning rate 这一想法)

precompute
当precompute=True时,会提前计算出每一个训练样本(不包括增强样本)在预训练模型最后一层卷积层的activation,
并将其缓存下来,之后在训练阶段进行前向传播的时候,直接将precompute 的activation 作为后面全连接层(FC
Layer)的输入,这样便省去前面卷积层进行前向传播的计算量,减少训练所需时间(这种优势在epoch比较大的时候能够显著0提高训练速度)。当precompute=False时,则不会提前计算训练样本的activation,每一个epoch都需要重新将训练样本+增强样本(前提是进行了增强操作)进行卷积层的前向传播,然后进行反向传播更新对应的权重。

learn.unfreeze()
learn.show_results()

在这里插入图片描述
从展示的部分训练结果可以看出,只有一张图被预测错误了,其他的都是正确的。

4.模型保存与应用

最后我们可以将模型保存下来,并且对验证集的图片的类别进行预测。

learn.export(Path("./model/export.pkl"))
from PIL import Image
img = Image.open(path+'ant/image_0001.jpg')
image_rsize=224
# Resize the image to 224x224
img_resized = img.resize((image_rsize,image_rsize))
pred, pred_idx, probs = learn.predict(img_resized)
im_t = cast(array(img_resized), TensorImage)
# Print the predicted label and probability
print(f"Predicted label: {pred}, probability: {probs[pred_idx]:.4f}")
img

在这里插入图片描述

总结

epoch	train_loss	valid_loss	time
0	1.030772	979.477417	00:52
1	1.074642	86.289436	00:52
2	0.553576	0.457210	00:52
3	0.302997	0.546438	00:52
4	0.176070	0.596845	00:52

我们借助fastai训练了resnet101模型,对 101 个类别的图像数据集进行了分类。
使用基于pytorch的fastai库,使用resnet模型和有101个类别的Caltech101图像数据集,训练了一个高准确率的多分类的深度学习模型,能够对101个类别的图像大数据集进行准确的图像类别识别。
使用简洁高效的代码,借助GPU提升训练速度(也可以使用CPU训练,本项目会自动识别硬件),首先数据集进行预处理,然后对模型进行训练,并将模型保存为pkl格式,最后对测试集的图像的类别进行预测。

可见,使用fastai进行图像多分类是非常简便的,所使用的代码行数非常少却能达到很高的准确率,而且借助GPU训练速度非常快。

这里将全部的代码和图片数据集打包起来了,方便大家复现。
开箱即用,欢迎下载:
使用fastai对Caltech101数据集进行图像多分类

单独下载数据集
Caltech101数据集 2023完整版 增加了更多图片


人工智能-图像识别 系列文章目录

  1. 环境搭建: pytorch以及fastai安装,配置GPU训练环境 待更新。。。
  2. 使用fastai对Caltech101数据集进行图像多分类(50行以内的代码就可达到很高准确率)

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.mzph.cn/news/48444.shtml

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

Spring Boot实践八--用户管理系统

一,技术介绍 技术选型功能说明springboot是一种基于 Spring 框架的快速开发应用程序的框架,它的主要作用是简化 Spring 应用程序的配置和开发,同时提供一系列开箱即用的功能和组件,如内置服务器、数据访问、安全、监控等&#xf…

[oneAPI] 基于BERT预训练模型的SWAG问答任务

[oneAPI] 基于BERT预训练模型的SWAG问答任务 基于Intel DevCloud for oneAPI下的Intel Optimization for PyTorch基于BERT预训练模型的SWAG问答任务数据集下载和描述数据集构建问答选择模型训练 结果参考资料 比赛:https://marketing.csdn.net/p/f3e44fbfe46c465f4d…

方案:AI边缘计算智慧工地解决方案

一、方案背景 在工程项目管理中,工程施工现场涉及面广,多种元素交叉,状况较为复杂,如人员出入、机械运行、物料运输等。特别是传统的现场管理模式依赖于管理人员的现场巡查。当发现安全风险时,需要提前报告&#xff0…

合宙Air724UG LuatOS-Air LVGL API--对象

对象 概念 在 LVGL 中,用户界面的基本构建块是对象。例如,按钮,标签,图像,列表,图表或文本区域。 属性 基本属性 所有对象类型都共享一些基本属性: Position (位置) Size (尺寸) Parent (父母…

linux 免交互

Linux 免交互 1、免交互概念2、基本免交互的例子2.1命令行免交互统计2.2使用脚本免交互统计2.3使用免交互命令打印2.4免交互修改密码2.5重定向查看2.6重定向到指定文件2.7重定向直接指定文件2.8使用脚本完成重定向输入2.9免交互脚本完成赋值变量2.10关闭变量替换功能&#xff0…

云计算在IT领域的发展和应用

文章目录 云计算的发展历程云计算的核心概念云计算在IT领域的应用1. 基础设施即服务(IaaS):2. 平台即服务(PaaS):3. 软件即服务(SaaS): 云计算的拓展应用结论 &#x1f3…

如何进行在线pdf转ppt?在线pdf转ppt的方法

在当今数字化时代,PDF文件的广泛应用为我们的工作和学习带来了巨大的便利。然而,有时候我们可能需要将PDF转换为PPT文件,以便更好地展示和分享内容。在线PDF转PPT工具因其操作简便、高效而备受欢迎。如何进行在线pdf转ppt呢?接下来&#xff…

fatal: not a git repository (or any of the parent directories): .git

提示说没有.git这样一个目录 在命令行 输入 git init 然后回车就好了 git remote add origin https:/.git git push -u origin "master"

《Java极简设计模式》第04章:建造者模式(Builder)

作者:冰河 星球:http://m6z.cn/6aeFbs 博客:https://binghe.gitcode.host 文章汇总:https://binghe.gitcode.host/md/all/all.html 源码地址:https://github.com/binghe001/java-simple-design-patterns/tree/master/j…

Node.js下载安装及环境配置教程

一、进入官网地址下载安装包 https://nodejs.org/zh-cn/download/ 选择对应你系统的Node.js版本,这里我选择的是Windows系统、64位 Tips:如果想下载指定版本,点击【以往的版本】,即可选择自己想要的版本下载 二、安装程序 &a…

【Apollo学习笔记】——规划模块TASK之PATH_REUSE_DECIDER

文章目录 前言PATH_REUSE_DECIDER功能简介PATH_REUSE_DECIDER相关配置PATH_REUSE_DECIDER总体流程PATH_REUSE_DECIDER相关子函数IsCollisionFreeTrimHistoryPathIsIgnoredBlockingObstacle和GetBlockingObstacleS Else参考 前言 在Apollo星火计划学习笔记——Apollo路径规划算…

Hive Cli / HiveServer2 中使用 dayofweek 函数引发的BUG!

文章目录 前言dayofweek 函数官方说明BUG 重现Spark SQL 中的使用总结 前言 使用的集群环境为: hive 3.1.2spark 3.0.2 dayofweek 函数官方说明 dayofweek(date) - Returns the day of the week for date/timestamp (1 Sunday, 2 Monday, …, 7 Saturday). …

数据封装与解封装过程

2.2数据封装与解封装过程(二) 如果网络世界只有终端设备,那么将不能称之为网络。正因为有很多中转设备才形成了今天如此复杂的Internet网络,只不过一贯作为网络用户的我们没有机会感知它们的存在,这都是传输层的“功劳”,由于传输…

在外SSH远程连接macOS服务器

文章目录 前言1. macOS打开远程登录2. 局域网内测试ssh远程3. 公网ssh远程连接macOS3.1 macOS安装配置cpolar3.2 获取ssh隧道公网地址3.3 测试公网ssh远程连接macOS 4. 配置公网固定TCP地址4.1 保留一个固定TCP端口地址4.2 配置固定TCP端口地址 5. 使用固定TCP端口地址ssh远程 …

科技云报道:云计算下半场,公有云市场生变,私有云风景独好

科技云报道原创。 大数据、云计算、人工智能,组成了恢弘的万亿级科技市场。这三个领域,无论远观近观,都如此性感和魅力,让一代又一代创业者为之杀伐攻略。 然而高手过招往往一瞬之间便已胜负知晓,云计算市场的巨幕甫…

测试框架pytest教程(11)-pytestAPI

常量 pytest.__version__ #输出pytest版本 pytest.version_tuple #输出版本的元组形式 功能 pytest.approx pytest.approx 是一个用于进行数值近似比较的 pytest 断言工具。 在测试中,有时候需要对浮点数或其他具有小数部分的数值进行比较。然而,由于…

Node.JS教程

文章目录 Node.JSNode.js学习指南一、Node.js基础1.认识Node.js2.开发环境搭建3. 模块、包、commonJS3.1、为什么要有模块化开发?3.2、CommonJS规范3.3、 modules模块化规范写法 总结 Node.JS Node.js学习指南 服务端开发底层平台周边生态 学习前提 JavaScript、E…

Rspack 创建 vue2/3 项目接入 antdv(rspack.config.js 配置 less 主题)

一、简介 Rspack CLI 官方文档。 rspack.config.js 官方文档。 二、创建 vue 项目 创建项目(文档中还提供了 Rspack 内置 monorepo 框架 Nx 的创建方式,根据需求进行选择) # npm 方式 $ npm create rspacklatest# yarn 方式 $ yarn create…

html动态爱心代码【二】(附源码)

目录 前言 效果演示 内容修改 完整代码 总结 前言 七夕马上就要到了,为了帮助大家高效表白,下面再给大家带来了实用的HTML浪漫表白代码(附源码)背景音乐,可用于520,情人节,生日,表白等场景&#xff0c…

Android 面试之Glide做了哪些优化?

前言 Glide可以说是最常用的图片加载框架了,Glide链式调用使用方便,性能上也可以满足大多数场景的使用,Glide源码与原理也是面试中的常客。 但是Glide的源码内容比较多,想要学习它的源码往往千头万绪,一时抓不住重点.…