pytorch迁移学习训练图像分类

pytorch迁移学习训练图像分类

  • 一、环境配置
  • 二、迁移学习关键代码
  • 三、完整代码
  • 四、结果对比

代码和图片等资源均来源于哔哩哔哩up主:同济子豪兄
讲解视频:Pytorch迁移学习训练自己的图像分类模型

一、环境配置

1,安装所需的包

pip install numpy pandas matplotlib seaborn plotly requests tqdm opencv-python pillow wandb -i https://pypi.tuna.tsinghua.edu.cn/simple

2,安装Pytorch

pip3 install torch torchvision torchaudio --extra-index-url https://download.pytorch.org/whl/cu113

3,创建目录

import os
# 存放训练得到的模型权重
os.mkdir('checkpoint')

4,下载数据集压缩包(下载之后需要解压数据集)

wget https://zihao-openmmlab.obs.cn-east-3.myhuaweicloud.com/20220716-mmclassification/dataset/fruit30/fruit30_split.zip

二、迁移学习关键代码

以下是迁移学习的三种选择,根据训练的需求选择不同的迁移方法:

  • 选择一:只微调训练模型最后一层(全连接分类层)
model = models.resnet18(pretrained=True) # 载入预训练模型
# 修改全连接层,使得全连接层的输出与 当前数据集类别数n_class 对应
model.fc = nn.Linear(model.fc.in_features, n_class)
# 只微调训练最后一层全连接层的参数,其它层冻结
optimizer = optim.Adam(model.fc.parameters())
  • 选择二:微调训练所有层。

适用于训练数据集与预训练模型相差大时,可以选择微调训练所有层,此时只使用预训练模型的部分权重和特征,例如原始模型为imageNet,而训练数据为医疗相关

model = models.resnet18(pretrained=True) # 载入预训练模型
model.fc = nn.Linear(model.fc.in_features, n_class)
optimizer = optim.Adam(model.parameters())
  • 选择三:随机初始化模型全部权重,从头训练所有层
model = models.resnet18(pretrained=False) # 只载入模型结构,不载入预训练权重参数
model.fc = nn.Linear(model.fc.in_features, n_class)
optimizer = optim.Adam(model.parameters())

三、完整代码

import time
import osimport numpy as np
from tqdm import tqdmimport torch
import torchvision
import torch.nn as nn# 忽略出现的红色提示
import warnings
warnings.filterwarnings("ignore")# 有 GPU 就用 GPU,没有就用 CPU
device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')
print('device', device)from torchvision import transforms# 训练集图像预处理:缩放裁剪、图像增强、转 Tensor、归一化
train_transform = transforms.Compose([transforms.RandomResizedCrop(224),transforms.RandomHorizontalFlip(),transforms.ToTensor(),transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])])# 测试集图像预处理-RCTN:缩放、裁剪、转 Tensor、归一化
test_transform = transforms.Compose([transforms.Resize(256),transforms.CenterCrop(224),transforms.ToTensor(),transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])])# 数据集文件夹路径
dataset_dir = 'fruit30_split'
train_path = os.path.join(dataset_dir, 'train')	# 测试集路径
test_path = os.path.join(dataset_dir, 'val')	# 测试集路径from torchvision import datasets# 载入训练集
train_dataset = datasets.ImageFolder(train_path, train_transform)# 载入测试集
test_dataset = datasets.ImageFolder(test_path, test_transform)# 各类别名称
class_names = train_dataset.classes
n_class = len(class_names)# 定义数据加载器DataLoader
from torch.utils.data import DataLoaderBATCH_SIZE = 32# 训练集的数据加载器
train_loader = DataLoader(train_dataset,batch_size=BATCH_SIZE,shuffle=True,num_workers=4)# 测试集的数据加载器
test_loader = DataLoader(test_dataset,batch_size=BATCH_SIZE,shuffle=False,num_workers=4)from torchvision import models
import torch.optim as optim# 选择一:只微调训练模型最后一层(全连接分类层)
model = models.resnet18(pretrained=True) # 载入预训练模型
# 修改全连接层,使得全连接层的输出与当前数据集类别数对应
# 新建的层默认 requires_grad=True,指定张量需要梯度计算
model.fc = nn.Linear(model.fc.in_features, n_class)
model.fc	# 查看全连接层
# 只微调训练最后一层全连接层的参数,其它层冻结
optimizer = optim.Adam(model.fc.parameters())    # optim 是 PyTorch 的一个优化器模块,用于实现各种梯度下降算法的优化方法# 选择二:微调训练所有层
# 训练数据集与预训练模型相差大时,可以选择微调训练所有层,只使用预训练模型的部分权重和特征,例如原始模型为imageNet,训练数据为医疗相关
# model = models.resnet18(pretrained=True) # 载入预训练模型
# model.fc = nn.Linear(model.fc.in_features, n_class)
# optimizer = optim.Adam(model.parameters())# 选择三:随机初始化模型全部权重,从头训练所有层
# model = models.resnet18(pretrained=False) # 只载入模型结构,不载入预训练权重参数
# model.fc = nn.Linear(model.fc.in_features, n_class)
# optimizer = optim.Adam(model.parameters())# 训练配置
model = model.to(device)# 交叉熵损失函数
criterion = nn.CrossEntropyLoss()# 训练轮次 Epoch
EPOCHS = 30# 遍历每个 EPOCH
for epoch in tqdm(range(EPOCHS)):model.train()for images, labels in train_loader:  # 获取训练集的一个 batch,包含数据和标注images = images.to(device)labels = labels.to(device)outputs = model(images)           # 前向预测,获得当前 batch 的预测结果loss = criterion(outputs, labels) # 比较预测结果和标注,计算当前 batch 的交叉熵损失函数optimizer.zero_grad()loss.backward()                   # 损失函数对神经网络权重反向传播求梯度optimizer.step()                  # 优化更新神经网络权重# 测试集上初步测试
model.eval()
with torch.no_grad():correct = 0total = 0for images, labels in tqdm(test_loader): # 获取测试集的一个 batch,包含数据和标注images = images.to(device)labels = labels.to(device)outputs = model(images)              # 前向预测,获得当前 batch 的预测置信度_, preds = torch.max(outputs, 1)     # 获得最大置信度对应的类别,作为预测结果total += labels.size(0)correct += (preds == labels).sum()   # 预测正确样本个数print('测试集上的准确率为 {:.3f} %'.format(100 * correct / total))# 保存模型
torch.save(model, 'checkpoint/fruit30_pytorch_A1.pth') # 选择一:微调全连接层
# torch.save(model, 'checkpoint/fruit30_pytorch_A2.pth') # 选择二:微调所有层
# torch.save(model, 'checkpoint/fruit30_pytorch_A3.pth') # 选择三:随机权重

四、结果对比

调用不同迁移学习得到的模型对比测试集准确率

# 测试集导入和图像预处理等代码和上述完整代码中一致,此处省略……# 调用自己训练的模型
model = torch.load('checkpoint/fruit30_pytorch_A1.pth')# 测试集上进行测试
model.eval()
with torch.no_grad():correct = 0total = 0for images, labels in tqdm(test_loader): # 获取测试集的一个 batch,包含数据和标注images = images.to(device)labels = labels.to(device)outputs = model(images)              # 前向预测,获得当前 batch 的预测置信度_, preds = torch.max(outputs, 1)     # 获得最大置信度对应的类别,作为预测结果total += labels.size(0)correct += (preds == labels).sum()   # 预测正确样本个数print('测试集上的准确率为 {:.3f} %'.format(100 * correct / total))

结果如下:
对于微调全连接层的选择一,测试集准确率为 72.078%
在这里插入图片描述
而所有权重随机的选择三测试集准确率为 43.228%
43.228

总体而言,迁移学习能够利用已有的知识和经验,加速模型的训练过程,提高模型的性能。

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.mzph.cn/news/83770.shtml

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

【ArcGIS】基本概念-矢量空间分析

栅格数据与矢量数据 1.1 栅格数据 栅格图是一个规则的阵列,包含着一定数量的像元或者栅格 常用的栅格图格式有:tif,png,jpeg/jpg等 1.2 矢量数据 矢量图是由一组描述点、线、面,以及它们的色彩、位置的数据&#x…

堆内存与栈内存

文章目录 1. 栈内存2. 堆内存3. 区别和联系参考资料 1. 栈内存 栈内存是为线程留出的临时空间 每个线程都有一个固定大小的栈空间,而且栈空间存储的数据只能由当前线程访问,所以它是线程安全的。栈空间的分配和回收是由系统来做的,我们不需…

如何玩转CSDN AI工具集

前言 人工智能生成内容(AIGC)是当下最具有前景的技术领域之一。AI能够以惊人的速度和准确度生成各种类型的内容,完成文章翻译、代码生成、AI对话、插图创作等工作,带来了许多令人兴奋的机遇。 本文将介绍CSDN AI工具集的基本使用…

「大数据-0」虚拟机VMware安装、配置、使用、创建虚拟机集群教程

目录 一、下载VMware Wworkstation Pro 16 二、安装VMware Wworkstation Pro 16 三、检查与设置VMware的网卡 1. 检查 2. 设置VMware网段 四、在VMware上安装Linux虚拟机 五、对安装好的虚拟机进行设置 1. 打开设置 2. 设置中文 3. 修改字体大小 4. 修改终端字体大小 5. 关闭虚…

pip pip3安装库时都指向python2的库

当在python3的环境下使用pip3安装库时&#xff0c;发现居然都指向了python2的库 pip -V pip3 -V安装命令更改为&#xff1a; python3 -m pip install <package>

【操作系统笔记】内存寻址

物理寻址 主存&#xff08;内存&#xff09; 计算机主存也可以称为物理内存&#xff0c;内存可以看成由若干个连续字节大小的单元组成的数组每个字节都有一个唯一的物理地址&#xff08;Physical Address&#xff09;CPU访问内存前&#xff0c;先拿到内存地址&#xff0c;然后…

【数据结构】二叉树的链式实现及遍历

文章目录 一、二叉树的遍历1、前序遍历2、中序遍历3、后序遍历4、层序遍历 二、二叉树结点个数及高度1、二叉树节点个数2、二叉树叶子节点个数3、二叉树第k层节点个数4、二叉树查找值为x的节点 三、二叉树创建及销毁1、通过前序遍历数组创建二叉树2、二叉树的销毁3、判断是否为…

Pytest单元测试框架 —— Pytest+Allure+Jenkins的应用

一、简介 pytestallurejenkins进行接口测试、生成测试报告、结合jenkins进行集成。 pytest是python的一种单元测试框架&#xff0c;与python自带的unittest测试框架类似&#xff0c;但是比unittest框架使用起来更简洁&#xff0c;效率更高 allure-pytest是python的一个第三方…

IMX6ULL移植篇-Linux内核源码文件表

一. Linux内核源码目录 我们在分析 Linux 之前&#xff0c;一定要先在 Ubuntu 中编译一下 Linux &#xff0c;因为编译过程会生成一些文件&#xff0c;而生成的这些恰恰是分析 Linux 不可或缺的文件。 二. Linux内核源码重要文件含义 编译后的 Linux内核源码重要的文件…

修改接口,字段的内容允许清空,避免歧义,参数校验:@NotNull

1. 问题描述 修改接口&#xff0c;字段的内容允许清空&#xff0c;是否应该做参数校验&#xff1f;如何做参数校验&#xff1f; 2. 说明 2.1. 需要对字段进行校验。 因为不校验&#xff0c;字段可能不传&#xff0c;或者字段的值为null&#xff1b;这样无法判断出&#xff…

【Linux基础】第27讲 Linux 查找和过滤命令(二)——grep命令

Grep命令 grep是根据文件的内容进行查找&#xff0c;会对文件的每一行按照给定的模式&#xff08;patter&#xff09;进行匹配查找 基本格式&#xff1a; grep [options]范围 [options] 主要参数 -c: 只输出匹配行的计数 -i : 不区分大小写 -n: 显示匹配行及行号 -w: 显示整个…

[Linux入门]---文本编辑器vim使用

文章目录 1.Linux编辑器-vim使用2.vim的基本概念4.vim正常模式命令集从正常模式进入插入模式从插入模式转换为命令模式移动光标删除文字复制替换撤销更改跳至指定行 5.vim末行模式命令集5.总结 1.Linux编辑器-vim使用 vi/vim作为Linux开发工具之一&#xff0c;从它的键盘操作图…

驱动开发练习,platform实现如下功能

实验要求 驱动代码 #include <linux/init.h> #include <linux/module.h> #include <linux/platform_device.h> #include <linux/mod_devicetable.h> #include <linux/of_gpio.h> #include <linux/unistd.h> #include <linux/interrupt…

PDCA循环

目录 1.认识PDCA&#xff1a; 2.PDCA循环的经典案例 3.PDCA的四个阶段和八个步骤 4.PDCA循环的优缺点&#xff1a; 5.案例 6.其他作用 1.认识PDCA&#xff1a; PDCA循环最早由美国质量统计控制之父Shewhat&#xff08;休哈特&#xff09;提出的PDS&#xff08;Plan Do Se…

hadoop3.x搭建到集群调优

一、基础环境安装 https://blog.csdn.net/fen_dou_shao_nian/article/details/120945221 二、hadoop运行环境搭建 2.1 模板虚拟机环境准备 0&#xff09;安装模板虚拟机&#xff0c;IP 地址 192.168.10.100、主机名称 hadoop100、内存 4G、硬盘 50G 1&#xff09;hadoop100…

【Html】用CSS定义咖啡 - 咖啡配料展示

显示效果 代码 index.html <!DOCTYPE html> <html lang"en"><head><meta charset"UTF-8"><title>CodePen - For The Love Of Coffee</title><link rel"stylesheet" href"./style.css">&l…

阿里云服务器部署安装hadoop与elasticsearch踩坑笔记

2023-09-12 14:00——2023.09.13 20:06 目录 00、软件版本 01、阿里云服务器部署hadoop 1.1、修改四个配置文件 1.1.1、core-site.xml 1.1.2、hdfs-site.xml 1.1.3、mapred-site.xml 1.1.4、yarn-site.xml 1.2、修改系统/etc/hosts文件与系统变量 1.2.1、修改主机名解…

基于ENC28J60+uIP1.0+STM32的UDP Server实现,以及主动发送数据,几个关键的问题可算整明白了!

ENC28J60&#xff0c;是一款SPI接口的以太网PHYMAC芯片&#xff0c;实现以太网物理层和MAC层硬件通信。uIP是一个TCP/IP软件协议栈&#xff0c;实现TCP、UDP、ARP、ICMP等网络协议。STM32F103RCT6通过SPI接口与ENC28J60通讯&#xff0c;并移植uIP协议&#xff0c;实现一个小型的…

利用Linux虚拟化技术实现资源隔离和管理

在现代计算机系统中&#xff0c;资源隔离和管理是非常重要的&#xff0c;特别是在多租户环境下。通过利用Linux虚拟化技术&#xff0c;我们可以实现对计算资源&#xff08;如CPU、内存和存储&#xff09;的隔离和管理&#xff0c;以提供安全、高效、稳定的计算环境。下面将详细…

如何将内网ip映射到外网?快解析内网穿透

关于内网ip映射到外网的问题&#xff0c;就是网络地址转换&#xff0c;私网借公网。要实现这个&#xff0c;看起来说得不错&#xff0c;实际上是有前提条件的。要实现内网ip映射到外网&#xff0c;首先要有一个固定的公网IP&#xff0c;可以从运营商那里得到。当你得到公网IP后…