【学习】使用PyTorch训练与评估自己的ResNet网络教程

参考:保姆级使用PyTorch训练与评估自己的ResNet网络教程_训练自己的图像分类网络resnet101 pytorch-CSDN博客

项目地址:GitHub - Fafa-DL/Awesome-Backbones: Integrate deep learning models for image classification | Backbone learning/comparison/magic modification project

视频手把手教程:我将维护一个集成各主干网络的图像分类项目_哔哩哔哩_bilibili

主要是复现和训练测试自己的数据集

复现部分

0.环境问题

pytorch官网里面找个合适的CUDA11.0安装一下,然后把requirements.txt安装一下

pip install -r requirements.txt

 参考版本:

pip list
Package                Version
---------------------- ---------------
certifi                2021.5.30
cycler                 0.11.0
dataclasses            0.8
importlib-resources    5.4.0
joblib                 1.1.1
kiwisolver             1.3.1
matplotlib             3.3.4
mkl-fft                1.3.0
mkl-random             1.1.1
mkl-service            2.3.0
numpy                  1.19.2
olefile                0.46
opencv-contrib-python  4.0.1.24
opencv-python          4.0.1.24
opencv-python-headless 4.0.1.24
packaging              21.3
Pillow                 8.4.0
pip                    21.3.1
pyparsing              3.0.7
python-dateutil        2.9.0.post0
scikit-learn           0.24.2
scipy                  1.5.4
setuptools             36.4.0
six                    1.16.0
terminaltables         3.1.10
threadpoolctl          3.1.0
torch                  1.7.1
torchaudio             0.7.0a0+a853dff
torchvision            0.8.2
tqdm                   4.64.1
typing_extensions      4.1.1
wheel                  0.37.1
zipp                   3.6.0
  • 下载MobileNetV3-Small权重至datas
  • 利用项目里的猫狗图片检验一下安装情况
    python tools/single_test.py datas/cat-dog.png models/mobilenet/mobilenet_v3_small.py --classes-map datas/imageNet1kAnnotation.txt
    

    成功的话大概这样:

 1.数据集问题

 先下载花卉数据集(0zat):flower_photos.zip_免费高速下载|百度网盘-分享无限制 (baidu.com)

 原始地址在项目的资料部分:GitHub - Fafa-DL/Awesome-Backbones: Integrate deep learning models for image classification | Backbone learning/comparison/magic modification project

 目录结构,按照花卉类型存放

├─flower_photos
│  ├─daisy
│  │      100080576_f52e8ee070_n.jpg
│  │      10140303196_b88d3d6cec.jpg
│  │      ...
│  ├─dandelion
│  │      10043234166_e6dd915111_n.jpg
│  │      10200780773_c6051a7d71_n.jpg
│  │      ...
│  ├─roses
│  │      10090824183_d02c613f10_m.jpg
│  │      102501987_3cdb8e5394_n.jpg
│  │      ...
│  ├─sunflowers
│  │      1008566138_6927679c8a.jpg
│  │      1022552002_2b93faf9e7_n.jpg
│  │      ...
│  └─tulips
│  │      100930342_92e8746431_n.jpg
│  │      10094729603_eeca3f2cb6.jpg
│  │      ...
  • datas/中创建标签文件annotations.txt,按行将类别名的索引写入文件(应该已经写好了);即
    daisy 0
    dandelion 1
    roses 2
    sunflowers 3
    tulips 4
    

    之后进行数据集划分,随机分为训练和测试集。

  • 在tools/split_data.py中修改原始数据集地址和划分后的数据集地址。(new_datasets最好别更改)

    init_dataset = './flower_photos'
    new_dataset = './Awesome-Backbones/datasets'
    

    终端使用命令:

    python tools/split_data.py
    

    划分后的数据集格式大概为:

    ├─...
    ├─datasets
    │  ├─test
    │  │  ├─daisy
    │  │  ├─dandelion
    │  │  ├─roses
    │  │  ├─sunflowers
    │  │  └─tulips
    │  └─train
    │      ├─daisy
    │      ├─dandelion
    │      ├─roses
    │      ├─sunflowers
    │      └─tulips
    ├─...
    

    查看tools/get_annotation.py,看看路径要不要更改:

  • datasets_path   = '你的数据集路径'
    

 终端使用命令:

python tools/get_annotation.py

 该命令应该会在datas/下形成train.txt和test.txt,里面是具体照片的位置

2.修改配置文件

/models下有许多的模型配置文件

 以resnet为例

 挑一个顺眼的改改

以resnet101为例

# model settingsmodel_cfg = dict(backbone=dict(type='ResNet',depth=101,num_stages=4,out_indices=(3, ),style='pytorch'),neck=dict(type='GlobalAveragePooling'),head=dict(type='LinearClsHead',num_classes=5,in_channels=2048,loss=dict(type='CrossEntropyLoss', loss_weight=1.0),topk=(1, 5),))# dataloader pipeline
img_lighting_cfg = dict(eigval=[55.4625, 4.7940, 1.1475],eigvec=[[-0.5675, 0.7192, 0.4009], [-0.5808, -0.0045, -0.8140],[-0.5836, -0.6948, 0.4203]],alphastd=0.1,to_rgb=True)
policies = [dict(type='AutoContrast', prob=0.5),dict(type='Equalize', prob=0.5),dict(type='Invert', prob=0.5),dict(type='Rotate',magnitude_key='angle',magnitude_range=(0, 30),pad_val=0,prob=0.5,random_negative_prob=0.5),dict(type='Posterize',magnitude_key='bits',magnitude_range=(0, 4),prob=0.5),dict(type='Solarize',magnitude_key='thr',magnitude_range=(0, 256),prob=0.5),dict(type='SolarizeAdd',magnitude_key='magnitude',magnitude_range=(0, 110),thr=128,prob=0.5),dict(type='ColorTransform',magnitude_key='magnitude',magnitude_range=(-0.9, 0.9),prob=0.5,random_negative_prob=0.),dict(type='Contrast',magnitude_key='magnitude',magnitude_range=(-0.9, 0.9),prob=0.5,random_negative_prob=0.),dict(type='Brightness',magnitude_key='magnitude',magnitude_range=(-0.9, 0.9),prob=0.5,random_negative_prob=0.),dict(type='Sharpness',magnitude_key='magnitude',magnitude_range=(-0.9, 0.9),prob=0.5,random_negative_prob=0.),dict(type='Shear',magnitude_key='magnitude',magnitude_range=(0, 0.3),pad_val=0,prob=0.5,direction='horizontal',random_negative_prob=0.5),dict(type='Shear',magnitude_key='magnitude',magnitude_range=(0, 0.3),pad_val=0,prob=0.5,direction='vertical',random_negative_prob=0.5),dict(type='Cutout',magnitude_key='shape',magnitude_range=(1, 41),pad_val=0,prob=0.5),dict(type='Translate',magnitude_key='magnitude',magnitude_range=(0, 0.3),pad_val=0,prob=0.5,direction='horizontal',random_negative_prob=0.5,interpolation='bicubic'),dict(type='Translate',magnitude_key='magnitude',magnitude_range=(0, 0.3),pad_val=0,prob=0.5,direction='vertical',random_negative_prob=0.5,interpolation='bicubic')
]
train_pipeline = [dict(type='LoadImageFromFile'),dict(type='RandAugment',policies=policies,num_policies=2,magnitude_level=12),dict(type='RandomResizedCrop',size=224,efficientnet_style=True,interpolation='bicubic',backend='pillow'),dict(type='RandomFlip', flip_prob=0.5, direction='horizontal'),dict(type='ColorJitter', brightness=0.4, contrast=0.4, saturation=0.4),dict(type='Lighting', **img_lighting_cfg),dict(type='Normalize',mean=[123.675, 116.28, 103.53],std=[58.395, 57.12, 57.375],to_rgb=False),dict(type='ImageToTensor', keys=['img']),dict(type='ToTensor', keys=['gt_label']),dict(type='Collect', keys=['img', 'gt_label'])
]
val_pipeline = [dict(type='LoadImageFromFile'),dict(type='CenterCrop',crop_size=224,efficientnet_style=True,interpolation='bicubic',backend='pillow'),dict(type='Normalize',mean=[123.675, 116.28, 103.53],std=[58.395, 57.12, 57.375],to_rgb=True),dict(type='ImageToTensor', keys=['img']),dict(type='Collect', keys=['img'])
]# train
data_cfg = dict(batch_size = 32,num_workers = 0,train = dict(pretrained_flag = False,pretrained_weights = '',freeze_flag = False,freeze_layers = ('backbone',),epoches = 150,),test=dict(ckpt = './logs/ResNet/2024-06-26-10-37-00/Last_Epoch150.pth',metrics = ['accuracy', 'precision', 'recall', 'f1_score', 'confusion'],metric_options = dict(topk = (1,5),thrs = None,average_mode='none'))
)# optimizer
optimizer_cfg = dict(type='SGD',lr=0.001,momentum=0.9,weight_decay=1e-4)# learning 
lr_config = dict(type='StepLrUpdater', step=[30, 60, 90])

主要改model_cfg里面的num_classes,data_cfg里的batch_size与num_workers

若有预训练权重则可以将pretrained_weights设置为True并将预训练的路径赋值给pretrained_weights

optimizer_cfg中修改初始学习率,根据batch_size调试

3.训练

终端运行

python tools/train.py models/resnet/resnet101.py

 运行结果

4.评估

在实际使用的配置文件中将ckpt修改

ckpt = '你的训练权重路径'

终端运行

python tools/evaluation.py models/resnet/resnet101.py

 运行结果

 我跑出来的准确率不高哈

5.测试

单张测试

python tools/single_test.py datasets/test/dandelion/14283011_3e7452c5b2_n.jpg models/resnet/resnet101.py

多张测试

使用batch_test.py,路径使用文件夹路径。

----------------------------------------------------------------------------------------------

使用自己的数据集

1.数据集准备

2.配置文件

3.训练

4.评估

5.测试

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.mzph.cn/web/34590.shtml

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

【M1/M2】详细说明Parallels Desktop虚拟机的安装使用

希望文章能给到你启发和灵感~ 如果觉得有帮助的话,点赞+关注+收藏支持一下博主哦~ 阅读指南 开篇说明什么是Paralles Desktop一、基础环境说明1.1 硬件环境1.2 软件环境二、安装Parallels DeskTop2.1 下载软件安装包三、Parallels 下配置Windows 11 系统3.1 Windows 11安装3…

Docker系列之安全

Docker的安全前言一、Docker 容器与虚拟机的区别 1. 隔离与共享 2. 性能与损耗二、Docker 存在的安全问题 1.Docker 自身漏洞 2.Docker 源码问题三、 Docker 架构缺陷与安全机制 1. 容器之间的局域网攻击 2. DDoS 攻击耗尽资源 3. 有漏…

Vue_cli搭建过程项目创建

概述 vue-cli 官方提供的一个脚手架,用于快速生成一个 vue 的项目模板;预先定义 好的目录结构及基础代码,就好比咱们在创建 Maven 项目时可以选择创建一个 骨架项目,这个骨架项目就是脚手架,我们的开发更加的快速&am…

uni-app的showModal提示框,进行删除的二次确认,可自定义确定或取消操作

实现效果: 此处为删除的二次确认示例,点击删除按钮时出现该提示,该提示写在js script中。 实现方式: 通过uni.showModal进行提示,success为确认状态下的操作自定义,此处调用后端接口进行了删除操作&#…

如何成为专业的 .NET 开发人员

如今,网上有大量信息,找到正确的信息并非易事。当你开始编程之旅并希望获得全面的指南时,最好寻找一个可以指导你完成整个过程的指南。 本文将帮助您制定一份路线图,告诉您什么是重要的以及什么是需要学习的. 一.一切从软件基础…

【JavaScript】BOM编程

目录 一、BOM编程是什么 二、window对象的常用方法 1、弹窗API方法 2、计时器任务方法 三、window对象的属性对象常用方法 1、history网页浏览历史 2、location地址栏 3、数据存储属性对象 4、console控制台 一、BOM编程是什么 当我们使用浏览器打开一个网页窗口时,…

Volatility 内存取证【信安比赛快速入门】

一、练习基本命令使用 1、获取镜像信息 ./volatility -f Challenge.raw imageinfo 一般取第一个就可以了 2、查看用户 ./volatility -f Challenge.raw --profileWin7SP1x64 printkey -K "SAM\Domains\Account\Users\Names" 3、获取主机名 ./volatility -f Challenge…

Kafka入门-分区及压缩

一、生产者消息分区 Kafka的消息组织方式实际上是三级结构:主题-分区-消息。主题下的每条消息只会保存在某一个分区中,而不会在多个分区中被保存多份。 分区的作用就是提供负载均衡的能力,或者说对数据进行分区的主要原因,就是为…

门店客流统计)

门店客流统计 代码部分效果 代码部分 import cv2 import numpy as np from tracker import * import cvzone import timebg_subtractor cv2.createBackgroundSubtractorMOG2(history200, varThreshold140)# Open a video capture video_capture cv2.VideoCapture(r"sto…

昇思25天学习打卡营第3天|数据集与数据变换

数据集 数据集(Dataset)操作shufflemapbatch 数据变换(Transforms)Vision TransformsText TransformsLambda Transforms 总结 数据集(Dataset) 数据是深度学习的基础,深度神经网络的效果对数据…

algorithm中常见算法

1、前言 C的<algorithm>库是C标准库中的一个重要组成部分&#xff0c;它提供了一系列的函数&#xff0c;用于执行各种常见的算法操作&#xff0c;比如排序、查找、替换、合并等。这些算法函数通常以模板函数的形式提供&#xff0c;可以用于任何符合特定条件的容器类型。 …

玩个游戏 找以下2个wordpress外贸主题的不同 你几找到几处

Aitken艾特肯wordpress外贸主题 适合中国产品出海的蓝色风格wordpress外贸主题&#xff0c;产品多图展示、可自定义显示产品详细参数。 https://www.jianzhanpress.com/?p7060 Ultra奥创工业装备公司wordpress主题 蓝色风格wordpress主题&#xff0c;适合装备制造、工业设备…

用友U8 Cloud smartweb2.showRPCLoadingTip.d XXE漏洞复现

0x01 产品简介 用友U8 Cloud 提供企业级云ERP整体解决方案,全面支持多组织业务协同,实现企业互联网资源连接。 U8 Cloud 亦是亚太地区成长型企业最广泛采用的云解决方案。 0x02 漏洞概述 用友U8 Cloud smartweb2.showRPCLoadingTip.d 接口处存在XML实体,攻击者可通过该漏…

Origin做聚类分析并利用聚类插件绘制热力图

1.聚类分析 1.1 K均值聚类 step1、首先进行归一化&#xff0c;具体步骤如图1-1所示&#xff1a; 图1-1 操作后得到归一化值如图1-2所示&#xff1a; 图1-2 step2、执行K均值聚类分析&#xff0c;如图1-3所示&#xff0c;选中聚类列&#xff0c;接着点击“统计”—“多变量分析…

手把手从零开始搭建远程访问服务

远程访问服务工具——FRP frp 是一个能够实现内网穿透的高性能的反向代理应用&#xff0c;支持 TCP、UDP、HTTP、HTTPS 等多种协议。可以将内网服务以安全、便捷的方式通过具有公网的服务器来转发。 资源链接 根据自己服务型号和操作系统来选取对应的文件&#xff0c;不知道的…

VS2019中解决方案里的所有项目都是 <不同选项> 的解决方案

以上等等&#xff0c;全部是 <不同选项>。。。 这样的话&#xff0c;如何还原和查看原有的值呢&#xff0c;就这么丢失掉了吗&#xff1f; 不会&#xff0c;需要解决方案里配置一下。 解决&#xff1a; 解决方案右键属性 -> 配置属性 -> 配置 -> 将所有配置改…

三大办公软件实用小技巧 沈阳办公软件白领必修班

Word 学好办公软件能大大的提升我们的工作效率。下面让我们一起学习一下Word办公软件时几个实用小技巧&#xff01; 01.快速插入当前日期或时间 在使用Word办公软件进行文档的编辑处理时&#xff0c;如果需要在文章的末尾插入系统的当前日期或时间。通常情况下&#xff0c;我…

如何编写时区源文件

0、背景 ① 修改TZ环境变量改变时区不能立即生效。要求设置时区后立即生效&#xff0c;只能用修改/etc/localtime方式。 ② 原文作者 Bill Seymour&#xff0c;想要查看原文&#xff0c;点击官网地址https://www.iana.org/time-zones下载 zic 源码&#xff0c;源码目录中的 tz…

【TB作品】MSP430,G2533单片机,红外发射,红外接收,红外通信,IR发射

文章目录 题目红外NEC协议介绍基本概述数据帧结构位表示数据传输示例重复码&#xff08;Repeat Code&#xff09;实现细节发送端接收端 典型应用结论 最终效果代码 题目 遥控器 硬件&#xff1a;msp430g2553、oled显示器、ds18b20温度传感器、红外发射器、按键 软件功能&#…

MD5加密接口

签名算法 app_key和app_secret由对方系统提供 MD5_CALCULATE_HASH_FOR_CHAR&#xff08;中文加密与JAVA不一致&#xff09; 代码&#xff1a; *获取传输字段名的ASCII码&#xff0c;根据ASCII码对字段名进行排序SELECT * FROM zthr0051WHERE functionid iv_functionidINTO …