【学习】使用PyTorch训练与评估自己的ResNet网络教程

参考:保姆级使用PyTorch训练与评估自己的ResNet网络教程_训练自己的图像分类网络resnet101 pytorch-CSDN博客

项目地址:GitHub - Fafa-DL/Awesome-Backbones: Integrate deep learning models for image classification | Backbone learning/comparison/magic modification project

视频手把手教程:我将维护一个集成各主干网络的图像分类项目_哔哩哔哩_bilibili

主要是复现和训练测试自己的数据集

复现部分

0.环境问题

pytorch官网里面找个合适的CUDA11.0安装一下,然后把requirements.txt安装一下

pip install -r requirements.txt

 参考版本:

pip list
Package                Version
---------------------- ---------------
certifi                2021.5.30
cycler                 0.11.0
dataclasses            0.8
importlib-resources    5.4.0
joblib                 1.1.1
kiwisolver             1.3.1
matplotlib             3.3.4
mkl-fft                1.3.0
mkl-random             1.1.1
mkl-service            2.3.0
numpy                  1.19.2
olefile                0.46
opencv-contrib-python  4.0.1.24
opencv-python          4.0.1.24
opencv-python-headless 4.0.1.24
packaging              21.3
Pillow                 8.4.0
pip                    21.3.1
pyparsing              3.0.7
python-dateutil        2.9.0.post0
scikit-learn           0.24.2
scipy                  1.5.4
setuptools             36.4.0
six                    1.16.0
terminaltables         3.1.10
threadpoolctl          3.1.0
torch                  1.7.1
torchaudio             0.7.0a0+a853dff
torchvision            0.8.2
tqdm                   4.64.1
typing_extensions      4.1.1
wheel                  0.37.1
zipp                   3.6.0
  • 下载MobileNetV3-Small权重至datas
  • 利用项目里的猫狗图片检验一下安装情况
    python tools/single_test.py datas/cat-dog.png models/mobilenet/mobilenet_v3_small.py --classes-map datas/imageNet1kAnnotation.txt
    

    成功的话大概这样:

 1.数据集问题

 先下载花卉数据集(0zat):flower_photos.zip_免费高速下载|百度网盘-分享无限制 (baidu.com)

 原始地址在项目的资料部分:GitHub - Fafa-DL/Awesome-Backbones: Integrate deep learning models for image classification | Backbone learning/comparison/magic modification project

 目录结构,按照花卉类型存放

├─flower_photos
│  ├─daisy
│  │      100080576_f52e8ee070_n.jpg
│  │      10140303196_b88d3d6cec.jpg
│  │      ...
│  ├─dandelion
│  │      10043234166_e6dd915111_n.jpg
│  │      10200780773_c6051a7d71_n.jpg
│  │      ...
│  ├─roses
│  │      10090824183_d02c613f10_m.jpg
│  │      102501987_3cdb8e5394_n.jpg
│  │      ...
│  ├─sunflowers
│  │      1008566138_6927679c8a.jpg
│  │      1022552002_2b93faf9e7_n.jpg
│  │      ...
│  └─tulips
│  │      100930342_92e8746431_n.jpg
│  │      10094729603_eeca3f2cb6.jpg
│  │      ...
  • datas/中创建标签文件annotations.txt,按行将类别名的索引写入文件(应该已经写好了);即
    daisy 0
    dandelion 1
    roses 2
    sunflowers 3
    tulips 4
    

    之后进行数据集划分,随机分为训练和测试集。

  • 在tools/split_data.py中修改原始数据集地址和划分后的数据集地址。(new_datasets最好别更改)

    init_dataset = './flower_photos'
    new_dataset = './Awesome-Backbones/datasets'
    

    终端使用命令:

    python tools/split_data.py
    

    划分后的数据集格式大概为:

    ├─...
    ├─datasets
    │  ├─test
    │  │  ├─daisy
    │  │  ├─dandelion
    │  │  ├─roses
    │  │  ├─sunflowers
    │  │  └─tulips
    │  └─train
    │      ├─daisy
    │      ├─dandelion
    │      ├─roses
    │      ├─sunflowers
    │      └─tulips
    ├─...
    

    查看tools/get_annotation.py,看看路径要不要更改:

  • datasets_path   = '你的数据集路径'
    

 终端使用命令:

python tools/get_annotation.py

 该命令应该会在datas/下形成train.txt和test.txt,里面是具体照片的位置

2.修改配置文件

/models下有许多的模型配置文件

 以resnet为例

 挑一个顺眼的改改

以resnet101为例

# model settingsmodel_cfg = dict(backbone=dict(type='ResNet',depth=101,num_stages=4,out_indices=(3, ),style='pytorch'),neck=dict(type='GlobalAveragePooling'),head=dict(type='LinearClsHead',num_classes=5,in_channels=2048,loss=dict(type='CrossEntropyLoss', loss_weight=1.0),topk=(1, 5),))# dataloader pipeline
img_lighting_cfg = dict(eigval=[55.4625, 4.7940, 1.1475],eigvec=[[-0.5675, 0.7192, 0.4009], [-0.5808, -0.0045, -0.8140],[-0.5836, -0.6948, 0.4203]],alphastd=0.1,to_rgb=True)
policies = [dict(type='AutoContrast', prob=0.5),dict(type='Equalize', prob=0.5),dict(type='Invert', prob=0.5),dict(type='Rotate',magnitude_key='angle',magnitude_range=(0, 30),pad_val=0,prob=0.5,random_negative_prob=0.5),dict(type='Posterize',magnitude_key='bits',magnitude_range=(0, 4),prob=0.5),dict(type='Solarize',magnitude_key='thr',magnitude_range=(0, 256),prob=0.5),dict(type='SolarizeAdd',magnitude_key='magnitude',magnitude_range=(0, 110),thr=128,prob=0.5),dict(type='ColorTransform',magnitude_key='magnitude',magnitude_range=(-0.9, 0.9),prob=0.5,random_negative_prob=0.),dict(type='Contrast',magnitude_key='magnitude',magnitude_range=(-0.9, 0.9),prob=0.5,random_negative_prob=0.),dict(type='Brightness',magnitude_key='magnitude',magnitude_range=(-0.9, 0.9),prob=0.5,random_negative_prob=0.),dict(type='Sharpness',magnitude_key='magnitude',magnitude_range=(-0.9, 0.9),prob=0.5,random_negative_prob=0.),dict(type='Shear',magnitude_key='magnitude',magnitude_range=(0, 0.3),pad_val=0,prob=0.5,direction='horizontal',random_negative_prob=0.5),dict(type='Shear',magnitude_key='magnitude',magnitude_range=(0, 0.3),pad_val=0,prob=0.5,direction='vertical',random_negative_prob=0.5),dict(type='Cutout',magnitude_key='shape',magnitude_range=(1, 41),pad_val=0,prob=0.5),dict(type='Translate',magnitude_key='magnitude',magnitude_range=(0, 0.3),pad_val=0,prob=0.5,direction='horizontal',random_negative_prob=0.5,interpolation='bicubic'),dict(type='Translate',magnitude_key='magnitude',magnitude_range=(0, 0.3),pad_val=0,prob=0.5,direction='vertical',random_negative_prob=0.5,interpolation='bicubic')
]
train_pipeline = [dict(type='LoadImageFromFile'),dict(type='RandAugment',policies=policies,num_policies=2,magnitude_level=12),dict(type='RandomResizedCrop',size=224,efficientnet_style=True,interpolation='bicubic',backend='pillow'),dict(type='RandomFlip', flip_prob=0.5, direction='horizontal'),dict(type='ColorJitter', brightness=0.4, contrast=0.4, saturation=0.4),dict(type='Lighting', **img_lighting_cfg),dict(type='Normalize',mean=[123.675, 116.28, 103.53],std=[58.395, 57.12, 57.375],to_rgb=False),dict(type='ImageToTensor', keys=['img']),dict(type='ToTensor', keys=['gt_label']),dict(type='Collect', keys=['img', 'gt_label'])
]
val_pipeline = [dict(type='LoadImageFromFile'),dict(type='CenterCrop',crop_size=224,efficientnet_style=True,interpolation='bicubic',backend='pillow'),dict(type='Normalize',mean=[123.675, 116.28, 103.53],std=[58.395, 57.12, 57.375],to_rgb=True),dict(type='ImageToTensor', keys=['img']),dict(type='Collect', keys=['img'])
]# train
data_cfg = dict(batch_size = 32,num_workers = 0,train = dict(pretrained_flag = False,pretrained_weights = '',freeze_flag = False,freeze_layers = ('backbone',),epoches = 150,),test=dict(ckpt = './logs/ResNet/2024-06-26-10-37-00/Last_Epoch150.pth',metrics = ['accuracy', 'precision', 'recall', 'f1_score', 'confusion'],metric_options = dict(topk = (1,5),thrs = None,average_mode='none'))
)# optimizer
optimizer_cfg = dict(type='SGD',lr=0.001,momentum=0.9,weight_decay=1e-4)# learning 
lr_config = dict(type='StepLrUpdater', step=[30, 60, 90])

主要改model_cfg里面的num_classes,data_cfg里的batch_size与num_workers

若有预训练权重则可以将pretrained_weights设置为True并将预训练的路径赋值给pretrained_weights

optimizer_cfg中修改初始学习率,根据batch_size调试

3.训练

终端运行

python tools/train.py models/resnet/resnet101.py

 运行结果

4.评估

在实际使用的配置文件中将ckpt修改

ckpt = '你的训练权重路径'

终端运行

python tools/evaluation.py models/resnet/resnet101.py

 运行结果

 我跑出来的准确率不高哈

5.测试

单张测试

python tools/single_test.py datasets/test/dandelion/14283011_3e7452c5b2_n.jpg models/resnet/resnet101.py

多张测试

使用batch_test.py,路径使用文件夹路径。

----------------------------------------------------------------------------------------------

使用自己的数据集

1.数据集准备

2.配置文件

3.训练

4.评估

5.测试

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.mzph.cn/web/34590.shtml

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

PyScada(四)构建用户 HMI

构建用户 HMI(前端)的简短说明 在后端HMI部分: 图表,添加新图表页面,添加页面Widget,添加一个 Widget,在 Page 下选择您在 2 中添加的页面,并在Content下选择1 中的图表。小部件控…

【M1/M2】详细说明Parallels Desktop虚拟机的安装使用

希望文章能给到你启发和灵感~ 如果觉得有帮助的话,点赞+关注+收藏支持一下博主哦~ 阅读指南 开篇说明什么是Paralles Desktop一、基础环境说明1.1 硬件环境1.2 软件环境二、安装Parallels DeskTop2.1 下载软件安装包三、Parallels 下配置Windows 11 系统3.1 Windows 11安装3…

PostgreSQL 简介与基础(一)

1. 什么是 PostgreSQL 1.1 概述 PostgreSQL(常简称为Postgres)是一种功能强大的开源关系型数据库管理系统(RDBMS)。它以其可靠性、强大的功能和符合标准的特性著称。PostgreSQL 支持大部分 SQL 标准,并且具有许多现代…

护网面试内容

1.自我介绍 2.如何判断是不是误判 分析请求、响应内容,判断是否攻击成功首先看告警事件名称判断是网络攻击事件还是web攻击事件,网络攻击事件:定位五元组信息(源IP、目的IP、源端口、目的端口、协议),对整…

Docker系列之安全

Docker的安全前言一、Docker 容器与虚拟机的区别 1. 隔离与共享 2. 性能与损耗二、Docker 存在的安全问题 1.Docker 自身漏洞 2.Docker 源码问题三、 Docker 架构缺陷与安全机制 1. 容器之间的局域网攻击 2. DDoS 攻击耗尽资源 3. 有漏…

Vue_cli搭建过程项目创建

概述 vue-cli 官方提供的一个脚手架,用于快速生成一个 vue 的项目模板;预先定义 好的目录结构及基础代码,就好比咱们在创建 Maven 项目时可以选择创建一个 骨架项目,这个骨架项目就是脚手架,我们的开发更加的快速&am…

uni-app的showModal提示框,进行删除的二次确认,可自定义确定或取消操作

实现效果: 此处为删除的二次确认示例,点击删除按钮时出现该提示,该提示写在js script中。 实现方式: 通过uni.showModal进行提示,success为确认状态下的操作自定义,此处调用后端接口进行了删除操作&#…

【Android】App设置开机自启动之后但是无效的原因之一

问题描述 通过监听开机广播,然后启动App,但是在系统开机之后,App并没有启动。 问题原因 当一个App在安装之后,必须要先启动一次,然后这个开机启动的功能才可以正常使用。 但是由于我的这个App是一个服务类的App&am…

如何成为专业的 .NET 开发人员

如今,网上有大量信息,找到正确的信息并非易事。当你开始编程之旅并希望获得全面的指南时,最好寻找一个可以指导你完成整个过程的指南。 本文将帮助您制定一份路线图,告诉您什么是重要的以及什么是需要学习的. 一.一切从软件基础…

【JavaScript】BOM编程

目录 一、BOM编程是什么 二、window对象的常用方法 1、弹窗API方法 2、计时器任务方法 三、window对象的属性对象常用方法 1、history网页浏览历史 2、location地址栏 3、数据存储属性对象 4、console控制台 一、BOM编程是什么 当我们使用浏览器打开一个网页窗口时,…

Volatility 内存取证【信安比赛快速入门】

一、练习基本命令使用 1、获取镜像信息 ./volatility -f Challenge.raw imageinfo 一般取第一个就可以了 2、查看用户 ./volatility -f Challenge.raw --profileWin7SP1x64 printkey -K "SAM\Domains\Account\Users\Names" 3、获取主机名 ./volatility -f Challenge…

探索QCS6490目标检测AI应用开发(二):摄像头视频的拉取和解码

作为《探索QCS6490目标检测AI应用开发》文章,紧接上一期,我们介绍如何在应用程序中拉取视频流,并且用硬件解码,得到逐帧的图像画面。我们使用了高通的Intelligent Multimedia SDK(IM SDK)完成视频的拉流和硬…

Linux杀僵尸进程

ps -A -o stat,ppid,pid,cmd | grep -e ^[Zz] | awk {print $2}1、查看系统是否有僵尸进程 使用Top命令查找,当zombie前的数量不为0时,即系统内存在相应数量的僵尸进程。 2、定位僵尸进程 使用命令 ps -A -ostat,ppid,pid,cmd |grep -e [Zz]定位僵尸…

Kafka入门-分区及压缩

一、生产者消息分区 Kafka的消息组织方式实际上是三级结构:主题-分区-消息。主题下的每条消息只会保存在某一个分区中,而不会在多个分区中被保存多份。 分区的作用就是提供负载均衡的能力,或者说对数据进行分区的主要原因,就是为…

数据库与表的基本操作:构建数据世界的基石(三)

引言:从零构建数据结构的艺术 在上一章节《安装与配置》中,我们成功地在不同的操作系统上安装并配置了MySQL,为实战数据库管理奠定了坚实的基础。本章节,我们将深入探索数据库与表的基本操作,包括如何创建、删除数据库…

RandLA-Net语义分割

项⽬地址: GitHub - tsunghan-wu/RandLA-Net-pytorch: :four_leaf_clover: Pytorch Implementation of RandLA-Net (https://arxiv.org/abs/1911.11236) 搭建环境并配置RandLA-Net 根据Environment Setup 搭建环境(除了requirements.txt中的库&#xf…

畅谈GPT-5

前言 ChatGBT(Chat Generative Bidirectional Transformer)是一种基于自然语言处理技术的对话系统,它的出现是人工智能和自然语言处理技术发展的必然趋势。随着技术的更新和进步,GPT也迎来了一代代的更新迭代。 1.GPT的回顾 1.1 GPT-3的介绍 GPT-3(Gen…

门店客流统计)

门店客流统计 代码部分效果 代码部分 import cv2 import numpy as np from tracker import * import cvzone import timebg_subtractor cv2.createBackgroundSubtractorMOG2(history200, varThreshold140)# Open a video capture video_capture cv2.VideoCapture(r"sto…

昇思25天学习打卡营第3天|数据集与数据变换

数据集 数据集(Dataset)操作shufflemapbatch 数据变换(Transforms)Vision TransformsText TransformsLambda Transforms 总结 数据集(Dataset) 数据是深度学习的基础,深度神经网络的效果对数据…

力扣377 组合总和Ⅳ Java版本

文章目录 题目描述代码 题目描述 给你一个由 不同 整数组成的数组 nums ,和一个目标整数 target 。请你从 nums 中找出并返回总和为 target 的元素组合的个数。 题目数据保证答案符合 32 位整数范围。 示例 1: 输入:nums [1,2,3], targe…