基于 PyTorch 的树叶分类任务:从数据准备到模型训练与测试

基于 PyTorch 的树叶分类任务:从数据准备到模型训练与测试


1. 引言

在计算机视觉领域,图像分类是一个经典的任务。本文将详细介绍如何使用 PyTorch 实现一个树叶分类任务。我们将从数据准备开始,逐步构建模型、训练模型,并在测试集上进行预测,最终生成提交文件。
在这里插入图片描述


2. 环境准备

首先,确保已安装以下 Python 库:

pip install torch torchvision pandas d2l
  • torch:PyTorch 核心库。
  • torchvision:提供计算机视觉相关的工具。
  • pandas:用于处理 CSV 文件。
  • d2l:深度学习工具库,提供辅助函数。

3. 数据准备

竞赛链接:https://www.kaggle.com/competitions/classify-leaves/leaderboard?tab=public

3.1 数据集结构

假设数据集位于 classify-leaves 目录下,包含以下文件:

classify-leaves/
├── train.csv
├── test.csv
├── images/├── image1.jpg├── image2.jpg...
  • train.csv:包含训练图像的路径和标签。
  • test.csv:包含测试图像的路径。

3.2 数据加载与预处理

import os
import pandas as pd
import randomimgpath = "classify-leaves"
trainlist = pd.read_csv(f"{imgpath}/train.csv")
num2name = list(trainlist["label"].value_counts().index)
random.shuffle(num2name)
name2num = {}
for i in range(len(num2name)):name2num[num2name[i]] = i
  • num2name:获取所有类别标签,并按类别数量排序。
  • name2num:将类别名称映射到数字编号。

4. 自定义数据集类

为了加载数据,我们需要定义一个自定义数据集类 Leaf_data

from torch.utils.data import Dataset
from d2l import torch as d2lclass Leaf_data(Dataset):def __init__(self, path, train, transform=lambda x: x):super().__init__()self.path = pathself.transform = transformself.train = trainif train:self.datalist = pd.read_csv(f"{path}/train.csv")else:self.datalist = pd.read_csv(f"{path}/test.csv")def __getitem__(self, index):res = ()tmplist = self.datalist.iloc[index, :]for i in tmplist.index:if i == "image":res += (self.transform(d2l.Image.open(f"{self.path}/{tmplist[i]}")),)else:res += (name2num[tmplist[i]],)if len(res) < 2:res += (tmplist[i],)return resdef __len__(self):return len(self.datalist)
  • __getitem__:根据索引返回一个样本,包括图像和标签。
  • __len__:返回数据集的长度。

5. 模型定义与初始化

我们使用预训练的 ResNet34 模型,并修改最后一层以适应分类任务:

import torch
import torchvision
from torch import nndef init_weight(m):if type(m) in [nn.Linear, nn.Conv2d]:nn.init.xavier_normal_(m.weight)net = torchvision.models.resnet34(weights=torchvision.models.ResNet34_Weights.IMAGENET1K_V1)
net.fc = nn.Linear(in_features=512, out_features=len(name2num), bias=True)
net.fc.apply(init_weight)
net.to(try_gpu())
  • init_weight:使用 Xavier 初始化方法初始化全连接层的权重。
  • net:加载预训练的 ResNet34 模型,并修改最后一层全连接层。

6. 训练过程

6.1 优化器与损失函数

lr = 1e-4
parames = [parame for name, parame in net.named_parameters() if name not in ["fc.weight", "fc.bias"]]
trainer = torch.optim.Adam([{"params": parames}, {"params": net.fc.parameters(), "lr": lr * 10}], lr=lr)
LR_con = torch.optim.lr_scheduler.CosineAnnealingLR(trainer, 1, 0)
loss = nn.CrossEntropyLoss(reduction='none')
  • trainer:使用 Adam 优化器,对全连接层使用更高的学习率。
  • LR_con:使用余弦退火学习率调度器。
  • loss:使用交叉熵损失函数。

6.2 训练函数


def train_batch(features, labels, net, loss, trainer, device):# 将数据移动到指定设备(如 GPU)features, labels = features.to(device), labels.to(device)# 前向传播outputs = net(features)l = loss(outputs, labels).mean()  # 计算损失# 反向传播和优化trainer.zero_grad()  # 梯度清零l.backward()         # 反向传播trainer.step()      # 更新参数# 计算准确率acc = (outputs.argmax(dim=1) == labels).float().mean()return l.item(), acc.item()def train(train_data, test_data, net, loss, trainer, num_epochs, device=try_gpu()):best_acc = 0timer = d2l.Timer()plot = d2l.Animator(xlabel="epoch", xlim=[1, num_epochs], legend=['train loss', 'train acc', 'test loss'], ylim=[0, 1])for epoch in range(num_epochs):metric = d2l.Accumulator(4)for i, (features, labels) in enumerate(train_data):timer.start()l, acc = train_batch(features, labels, net, loss, trainer, device)metric.add(l, acc, labels.shape[0], labels.numel())timer.stop()test_acc = d2l.evaluate_accuracy_gpu(net, test_data, device=device)if test_acc > best_acc:save_model(net)best_acc = test_accplot.add(epoch + 1, (metric[0] / metric[2], metric[1] / metric[3], test_acc))print(f'loss {metric[0] / metric[2]:.3f}, train acc {metric[1] / metric[3]:.3f}, test acc {test_acc:.3f}')print(f'loss {metric[0] / metric[2]:.3f}, train acc {metric[1] / metric[3]:.3f}, test acc {test_acc:.3f}')print(f'{metric[2] * num_epochs / timer.sum():.1f} examples/sec on {str(device)}')print(f"best acc {best_acc}")return metric[0] / metric[2], metric[1] / metric[3], test_acc
  • train:训练模型,记录损失和准确率,并在验证集上评估模型。

7. 测试与结果保存

在测试集上进行预测,并保存结果到 CSV 文件:

net.load_state_dict(torch.load(model_path))
augs = torchvision.transforms.Compose([torchvision.transforms.Resize(224),torchvision.transforms.ToTensor(), norm
])
test_data = Leaf_data(imgpath, False, augs)
test_dataloader = Data.DataLoader(test_data, batch_size=64, shuffle=False)
res = pd.DataFrame(columns=["image", "label"], index=range(len(test_data)))
net = net.cpu()
count = 0
for X, y in test_dataloader:preds = net(X).detach().argmax(dim=-1).numpy()preds = pd.DataFrame(y, index=map(lambda x: num2name[x], preds))preds.loc[:, 1] = preds.indexpreds.index = range(count, count + len(y))res.iloc[preds.index] = predscount += len(y)print(f"loaded {count}/{len(test_data)} datas")
res.to_csv('./submission.csv', index=False)
  • test_dataloader:加载测试数据。
  • res:保存预测结果到 CSV 文件。

8. 总结

本文详细介绍了如何使用 PyTorch 实现一个树叶分类任务,包括数据准备、模型定义、训练、验证和测试。通过本文,您可以掌握以下技能:

  1. 自定义数据集类的实现。
  2. 使用预训练模型进行迁移学习。
  3. 训练模型并保存最佳模型。
  4. 在测试集上进行预测并生成提交文件。

希望本文对您有所帮助!如果有任何问题,欢迎在评论区留言讨论。😊

完整代码

import os
import torch
from torch.utils import data as Data
import torchvision
from torch import nn
from d2l import torch as d2l
import pandas as pd
import random# 数据准备
imgpath = "classify-leaves"
trainlist = pd.read_csv(f"{imgpath}/train.csv")
num2name = list(trainlist["label"].value_counts().index)
random.shuffle(num2name)
name2num = {}
for i in range(len(num2name)):name2num[num2name[i]] = i# GPU 检查
def try_gpu():if torch.cuda.device_count() > 0:return torch.device('cuda')return torch.device('cpu')# 模型保存路径
model_dir = './models'
if not os.path.exists(model_dir):os.makedirs(model_dir)
model_path = os.path.join(model_dir, 'pre_res_model.ckpt')def save_model(net):torch.save(net.state_dict(), model_path)# 自定义数据集类
class Leaf_data(Data.Dataset):def __init__(self, path, train, transform=lambda x: x):super().__init__()self.path = pathself.transform = transformself.train = trainif train:self.datalist = pd.read_csv(f"{path}/train.csv")else:self.datalist = pd.read_csv(f"{path}/test.csv")def __getitem__(self, index):res = ()tmplist = self.datalist.iloc[index, :]for i in tmplist.index:if i == "image":res += (self.transform(d2l.Image.open(f"{self.path}/{tmplist[i]}")),)else:res += (name2num[tmplist[i]],)if len(res) < 2:res += (tmplist[i],)return resdef __len__(self):return len(self.datalist)def train_batch(features, labels, net, loss, trainer, device):# 将数据移动到指定设备(如 GPU)features, labels = features.to(device), labels.to(device)# 前向传播outputs = net(features)l = loss(outputs, labels).mean()  # 计算损失# 反向传播和优化trainer.zero_grad()  # 梯度清零l.backward()         # 反向传播trainer.step()      # 更新参数# 计算准确率acc = (outputs.argmax(dim=1) == labels).float().mean()return l.item(), acc.item()# 训练函数
def train(train_data, test_data, net, loss, trainer, num_epochs, device=try_gpu()):best_acc = 0timer = d2l.Timer()plot = d2l.Animator(xlabel="epoch", xlim=[1, num_epochs], legend=['train loss', 'train acc', 'test loss'], ylim=[0, 1])for epoch in range(num_epochs):metric = d2l.Accumulator(4)for i, (features, labels) in enumerate(train_data):timer.start()l, acc = train_batch(features, labels, net, loss, trainer, device)metric.add(l, acc, labels.shape[0], labels.numel())timer.stop()test_acc = d2l.evaluate_accuracy_gpu(net, test_data, device=device)if test_acc > best_acc:save_model(net)best_acc = test_accplot.add(epoch + 1, (metric[0] / metric[2], metric[1] / metric[3], test_acc))print(f'loss {metric[0] / metric[2]:.3f}, train acc {metric[1] / metric[3]:.3f}, test acc {test_acc:.3f}')print(f'loss {metric[0] / metric[2]:.3f}, train acc {metric[1] / metric[3]:.3f}, test acc {test_acc:.3f}')print(f'{metric[2] * num_epochs / timer.sum():.1f} examples/sec on {str(device)}')print(f"best acc {best_acc}")return metric[0] / metric[2], metric[1] / metric[3], test_acc# 模型初始化
def init_weight(m):if type(m) in [nn.Linear, nn.Conv2d]:nn.init.xavier_normal_(m.weight)net = torchvision.models.resnet34(weights=torchvision.models.ResNet34_Weights.IMAGENET1K_V1)
net.fc = nn.Linear(in_features=512, out_features=len(name2num), bias=True)
net.fc.apply(init_weight)
net.to(try_gpu())# 优化器和损失函数
lr = 1e-4
parames = [parame for name, parame in net.named_parameters() if name not in ["fc.weight", "fc.bias"]]
trainer = torch.optim.Adam([{"params": parames}, {"params": net.fc.parameters(), "lr": lr * 10}], lr=lr)
LR_con = torch.optim.lr_scheduler.CosineAnnealingLR(trainer, 1, 0)
loss = nn.CrossEntropyLoss(reduction='none')# 数据增强和数据加载
batch = 64
num_epochs = 10
norm = torchvision.transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
augs = torchvision.transforms.Compose([torchvision.transforms.Resize(224),torchvision.transforms.RandomHorizontalFlip(p=0.5),torchvision.transforms.ToTensor(), norm
])
train_data, valid_data = Data.random_split(dataset=Leaf_data(imgpath, True, augs),lengths=[0.8, 0.2]
)
train_dataloder = Data.DataLoader(train_data, batch, True)
valid_dataloder = Data.DataLoader(valid_data, batch, True)# 训练模型
train(train_dataloder, valid_dataloder, net, loss, trainer, num_epochs)# 测试模型
net.load_state_dict(torch.load(model_path))
augs = torchvision.transforms.Compose([torchvision.transforms.Resize(224),torchvision.transforms.ToTensor(), norm
])
test_data = Leaf_data(imgpath, False, augs)
test_dataloader = Data.DataLoader(test_data, batch_size=64, shuffle=False)
res = pd.DataFrame(columns=["image", "label"], index=range(len(test_data)))
net = net.cpu()
count = 0
for X, y in test_dataloader:preds = net(X).detach().argmax(dim=-1).numpy()preds = pd.DataFrame(y, index=map(lambda x: num2name[x], preds))preds.loc[:, 1] = preds.indexpreds.index = range(count, count + len(y))res.iloc[preds.index] = predscount += len(y)print(f"loaded {count}/{len(test_data)} datas")
res.to_csv('./submission.csv', index=False)

参考链接:

  • PyTorch 官方文档
  • torchvision 官方文档
  • d2l 深度学习工具库

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.mzph.cn/pingmian/69757.shtml

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

pytest测试专题 - 1.2 如何获得美观的测试报告

<< 返回目录 1 pytest测试专题 - 1.2 如何获得美观的测试报告 1.1 背景 虽然pytest命令的报文很详细&#xff0c;用例在执行调试时还算比较方便阅读和提取失败信息&#xff0c; 但对于大量测试用例运行时&#xff0c;可能会存在以下不足 报文被冲掉测试日志没法归档 …

压缩stl文件大小

1、MeshLab下载界面&#xff0c;从MeshLab下载适合自己系统的最新版本。 2、打开 MeshLab软件&#xff0c;将stl文件拖入其中。 3、 4、Percentage reduction参数即为缩放比例&#xff0c;根据自身想要将文件压缩到多大来。 然后点击apply 5、CtrlE弹出窗口保存文件后&…

FPGA 28 ,基于 Vivado Verilog 的呼吸灯效果设计与实现( 使用 Vivado Verilog 实现呼吸灯效果 )

目录 前言 一. 设计流程 1.1 需求分析 1.2 方案设计 1.3 PWM解析 二. 实现流程 2.1 确定时间单位和精度 2.2 定义参数和寄存器 2.3 实现计数器逻辑 2.4 控制 LED 状态 三. 整体流程 3.1 全部代码 3.2 代码逻辑 1. 参数定义 2. 分级计数 3. 状态切换 4. LED 输…

【云安全】云原生-K8S- API Server 未授权访问

API Server 是 Kubernetes 集群的核心管理接口&#xff0c;所有资源请求和操作都通过 kube-apiserver 提供的 API 进行处理。默认情况下&#xff0c;API Server 会监听两个端口&#xff1a;8080 和 6443。如果配置不当&#xff0c;可能会导致未授权访问的安全风险。 8080 端口…

DeepSeek的大模型介绍

文章目录 DeepSeek是什么DeepSeek平台使用DeepSeek的使用场景DeepSeek的本地部署 DeepSeek是什么 DeepSeek是一家2023/7月年成立的人工智能公司&#xff0c;致力于开发高效、高性能的生成式AI模型&#xff0c;在短短一年多的时间里推出了多款强大的开源模型&#xff0c;包括De…

问卷数据分析|SPSS实操之单因素方差分析

适用条件&#xff1a; 检验分类变量和定量变量之间的差异 分类变量数量要大于等于三 具体操作&#xff1a; 1.选择分析--比较平均值--单因素ANOVA检验 2. 下方填分类变量&#xff0c;上方为各个量表数据Z1-Y2 3. 点击选项&#xff0c;选择描述和方差齐性检验 4.此处为结果数…

信息收集-主机服务器系统识别IP资产反查技术端口扫描协议探针角色定性

知识点&#xff1a; 1、信息收集-服务器系统-操作系统&IP资产 2、信息收集-服务器系统-端口扫描&服务定性 一、演示案例-应用服务器-操作系统&IP资产 操作系统 1、Web大小写(windows不区分大小写&#xff0c;linux区分大小写) 2、端口服务特征(22就是linux上的服…

vmware安装win7

1、版本说明 vmware workstation 16 win7 X64 2、安装步骤 安装步骤有点独特&#xff0c;先配置虚拟机&#xff0c;然后再虚拟机的虚拟光驱里添加下载的win7。 配置完了之后&#xff0c;点击要运行的虚拟机&#xff0c;然后一直往下走就可以完成系统的安装。 3、配置系统以解…

【ThreeJS Basics 1-3】Hello ThreeJS,实现第一个场景

文章目录 环境创建一个项目安装依赖基础 Web 页面概念解释编写代码运行项目 环境 我的环境是 node version 22 创建一个项目 首先&#xff0c;新建一个空的文件夹&#xff0c;然后 npm init -y , 此时会快速生成好默认的 package.json 安装依赖 在新建的项目下用 npm 安装依…

Linux下的进程切换与调度

目录 1.进程的优先级 优先级是什么 Linux下优先级的具体做法 优先级的调整为什么要受限 2.Linux下的进程切换 3.Linux下进程的调度 1.进程的优先级 我们在使用计算机的时候&#xff0c;通常会启动多个程序&#xff0c;这些程序最后都会变成进程&#xff0c;但是我们的硬…

使用 EMQX 接入 LwM2M 协议设备

LwM2M 协议介绍 LwM2M 是一种轻量级的物联网设备管理协议&#xff0c;由 OMA&#xff08;Open Mobile Alliance&#xff09;组织制定。它基于 CoAP &#xff08;Constrained Application Protocol&#xff09;协议&#xff0c;专门针对资源受限的物联网设备设计&#xff0c;例…

【Java 面试 八股文】Redis篇

Redis 1. 什么是缓存穿透&#xff1f;怎么解决&#xff1f;2. 你能介绍一下布隆过滤器吗&#xff1f;3. 什么是缓存击穿&#xff1f;怎么解决&#xff1f;4. 什么是缓存雪崩&#xff1f;怎么解决&#xff1f;5. redis做为缓存&#xff0c;mysql的数据如何与redis进行同步呢&…

第二天:工具的使用

每天上午9点左右更新一到两篇文章到专栏《Python爬虫训练营》中&#xff0c;对于爬虫有兴趣的伙伴可以订阅专栏一起学习&#xff0c;完全免费。 键盘为桨&#xff0c;代码作帆。这趟为期30天左右的Python爬虫特训即将启航&#xff0c;每日解锁新海域&#xff1a;从Requests库的…

MySQL8.0 innodb Cluster 高可用集群部署(MySQL、MySQL Shell、MySQL Router安装)

简介 MySQL InnoDB集群&#xff08;Cluster&#xff09;提供了一个集成的&#xff0c;本地的&#xff0c;HA解决方案。Mysq Innodb Cluster是利用组复制的 pxos 协议&#xff0c;保障数据一致性&#xff0c;组复制支持单主模式和多主模式。 InnoDB Cluster组件&#xff1a; …

Unity-Mirror网络框架-从入门到精通之LagCompensation示例

文章目录 前言什么是滞后补偿Lag Compensation示例延迟补偿原理ServerCubeClientCubeCapture2DSnapshot3D补充LagCompensation.cs 独立算法滞后补偿器组件注意:算法最小示例前言 在现代游戏开发中,网络功能日益成为提升游戏体验的关键组成部分。本系列文章将为读者提供对Mir…

初窥强大,AI识别技术实现图像转文字(OCR技术)

⭐️⭐️⭐️⭐️⭐️欢迎来到我的博客⭐️⭐️⭐️⭐️⭐️ &#x1f434;作者&#xff1a;秋无之地 &#x1f434;简介&#xff1a;CSDN爬虫、后端、大数据、人工智能领域创作者。目前从事python全栈、爬虫和人工智能等相关工作&#xff0c;主要擅长领域有&#xff1a;python…

不小心删除服务[null]后,git bash出现错误

不小心删除服务[null]后&#xff0c;git bash出现错误&#xff0c;如何解决&#xff1f; 错误描述&#xff1a;打开 git bash、msys2都会出现错误「bash: /dev/null: No such device or address」 问题定位&#xff1a; 1.使用搜索引擎搜索「bash: /dev/null: No such device o…

k8s部署logstash

1. 编写logstash.yaml配置文件 --- apiVersion: v1 kind: Service metadata:name: logstash spec:type: ClusterIPclusterIP: Noneports:- name: logstash-tcpport: 5000targetPort: 5000- name: logstash-beatsport: 5044targetPort: 5044- name: logstash-apiport: 9600targ…

【JVM详解一】类加载过程与内存区域划分

一、简介 1.1 概述 JVM是Java Virtual Machine&#xff08;Java虚拟机&#xff09;的缩写&#xff0c;是通过在实际的计算机上仿真模拟各种计算机功能来实现的。由一套字节码指令集、一组寄存器、一个栈、一个垃圾回收堆和一个存储方法域等组成。JVM屏蔽了与操作系统平台相关…

25考研电子信息复试面试常见核心问题真题汇总,电子信息考研复试没有项目怎么办?电子信息考研复试到底该如何准备?

你是不是在为电子信息考研复试焦虑&#xff1f;害怕被老师问到刁钻问题、担心专业面答不上来&#xff1f;别慌&#xff01;作为复试面试92分逆袭上岸的学姐&#xff0c;今天手把手教你拆解电子信息类复试通关密码&#xff01;看完这篇&#xff0c;让你面试现场直接开大&#xf…