动手学深度学习(Pytorch版)代码实践 -计算机视觉-38实战Kaggle比赛:图像分类 (CIFAR-10)

38实战Kaggle比赛:图像分类 (CIFAR-10)

比赛链接:CIFAR-10 - Object Recognition in Images | Kaggle

导入包
import os
import glob
import pandas as pd
import numpy as np
import torch
import torchvision
from torch.utils.data import Dataset, DataLoader
from torchvision import transforms
from PIL import Image
import matplotlib.pyplot as plt
from torch import nn
from d2l import torch as d2l
import liliPytorch as lp
import csv
预处理:数据集分析
# 获取精简数据集
#@save
d2l.DATA_HUB['cifar10_tiny'] = (d2l.DATA_URL + 'kaggle_cifar10_tiny.zip','2068874e4b9a9f0fb07ebe0ad2b29754449ccacd')
# 如果使用完整的Kaggle竞赛的数据集,设置demo为False
demo = True
if demo:data_dir = d2l.download_extract('cifar10_tiny')
else:data_dir = '../data/cifar-10/'train_path = '../data/kaggle_cifar10_tiny/train.csv'
file_path = '../data/kaggle_cifar10_tiny/'# 读取数据
train_data = pd.read_csv(train_path)
# 查看数据
print(train_data['label'].value_counts())
# """
# label
# automobile    112
# frog          107
# truck         103
# horse         102
# airplane      102
# deer           99
# bird           99
# ship           99
# cat            92
# dog            85
# """
1.数据处理与加载
train_path = '../data/kaggle_cifar10_tiny/train.csv'
test_path = '../data/kaggle_cifar10_tiny/test.csv'
file_path = '../data/kaggle_cifar10_tiny/'# 统计label种类,并排序
cifar_labels = sorted(list(set(train_data['label'])))
# 将label对应编号
labels_to_num = dict(zip(cifar_labels, range(len(cifar_labels))))
# print(labels_to_num)
"""
{'airplane': 0, 'automobile': 1, 'bird': 2, 'cat': 3, 'deer': 4, 
'dog': 5, 'frog': 6, 'horse': 7, 'ship': 8, 'truck': 9}
"""
# 将编号对应label,用于后续预测
num_to_labels = {value : key for key, value in labels_to_num.items()}
# print(num_to_labels)
"""
{0: 'airplane', 1: 'automobile', 2: 'bird', 3: 'cat', 4: 'deer', 
5: 'dog', 6: 'frog', 7: 'horse', 8: 'ship', 9: 'truck'}
"""def get_image_filenames(folder_path, extensions=['.png', '.jpg', '.jpeg']):# 获取指定文件夹中的所有图片文件image_files = []for ext in extensions:image_files.extend(glob.glob(os.path.join(folder_path, f'*{ext}')))# 返回图片文件名列表return [os.path.basename(image) for image in image_files]def save_filenames_to_csv(filenames, csv_path):with open(csv_path, mode='w', newline='', encoding='utf-8') as file:writer = csv.writer(file)# 写入CSV的第一行writer.writerow(['id'])# 写入每个文件名for filename in filenames:writer.writerow([filename])# 获取测试图片名
test_images_path = '../data/kaggle_cifar10_tiny/test'
image_filenames = get_image_filenames(test_images_path)
# 保存到CSV文件
save_filenames_to_csv(image_filenames, file_path + 'test.csv') class CifarDataset(Dataset):def __init__(self, csv_path, file_path, mode='train', valid_ratio=0.2, resize_height=224, resize_width=224):"""初始化 LeavesDataset 对象。参数:csv_path (str): 包含图像路径和标签的 CSV 文件路径。file_path (str): 图像文件所在目录的路径。mode (str, optional): 数据集的模式。可以是 'train', 'valid' 或 'test'。默认值为 'train'。valid_ratio (float, optional): 用于验证的数据比例。默认值为 0.2。resize_height (int, optional): 调整图像高度的大小。默认值为 224。resize_width (int, optional): 调整图像宽度的大小。默认值为 224。"""# 存储图像调整大小的高度和宽度self.resize_height = resize_heightself.resize_width = resize_width# 存储图像文件路径和模式(train/valid/test)if mode == 'train' or mode == 'valid':self.file_path = file_path + 'train/'else:self.file_path = file_path + 'test/'self.mode = mode# 读取包含图像路径和标签的 CSV 文件self.data_info = pd.read_csv(csv_path, header=0)# 获取样本总数self.data_len = len(self.data_info.index)# 计算训练集样本数self.train_len = int(self.data_len * (1 - valid_ratio))# 根据模式处理数据if self.mode == 'train':# 训练模式下的图像和标签self.train_img = np.asarray(self.data_info.iloc[0:self.train_len, 0])self.train_label = np.asarray(self.data_info.iloc[0:self.train_len, 1])self.image_arr = self.train_imgself.label_arr = self.train_labelelif self.mode == 'valid':# 验证模式下的图像和标签self.valid_img = np.asarray(self.data_info.iloc[self.train_len:, 0])self.valid_label = np.asarray(self.data_info.iloc[self.train_len:, 1])self.image_arr = self.valid_imgself.label_arr = self.valid_labelelif self.mode == 'test':# 测试模式下的图像self.test_img = np.asarray(self.data_info.iloc[:, 0])self.image_arr = self.test_img# 获取图像数组的长度self.len_image = len(self.image_arr)print(f'扫描所有 {mode} 数据,共 {self.len_image} 张图像')def __getitem__(self, idx):"""获取指定索引的图像和标签。参数:idx (int): 标签文本对应编号的索引返回:如果是测试模式,返回图像张量;否则返回图像张量和标签。"""# 打开图像文件if self.mode == 'test':self.img = Image.open(self.file_path + str(self.image_arr[idx]))else :self.img = Image.open(self.file_path + str(self.image_arr[idx]) + '.png')if self.mode == 'train':# 训练模式下的数据增强trans =torchvision.transforms.Compose([torchvision.transforms.Resize((self.resize_height, self.resize_width)),torchvision.transforms.RandomHorizontalFlip(p=0.5),torchvision.transforms.RandomVerticalFlip(p=0.5),torchvision.transforms.RandomResizedCrop(32, scale=(0.64, 1.0),ratio=(1.0, 1.0)),torchvision.transforms.RandomRotation(degrees=30),# torchvision.transforms.ColorJitter(brightness=0.2, contrast=0.2, saturation=0.2, hue=0.1),# torchvision.transforms.RandomResizedCrop(size=self.resize_height, scale=(0.8, 1.0)),torchvision.transforms.ToTensor(),torchvision.transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])])self.img = trans(self.img)else:# 验证和测试模式下的简单处理trans = torchvision.transforms.Compose([torchvision.transforms.Resize((self.resize_height, self.resize_width)),torchvision.transforms.ToTensor(),torchvision.transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])])self.img = trans(self.img)if self.mode == 'test':return self.imgelse:# 获取标签文本对应的编号self.label = labels_to_num[self.label_arr[idx]]return self.img, self.labeldef __call__(self, idx):"""使对象可以像函数一样被调用。参数:idx (int):标签文本对应编号的索引返回:调用 __getitem__ 方法并返回结果。"""return self.__getitem__(idx)def __len__(self):"""获取数据集的长度。返回:数据集中图像的数量。"""return self.len_imagetrain_dataset = CifarDataset(train_path,file_path, mode='train', valid_ratio=0.1, resize_height=40, resize_width=40)
valid_dataset = CifarDataset(train_path, file_path, mode='valid',valid_ratio=0.1, resize_height=40, resize_width=40)
test_dataset = CifarDataset(test_path, file_path, mode='test',valid_ratio=0.1, resize_height=40, resize_width=40)batch_size = 32 
train_iter = torch.utils.data.DataLoader(train_dataset, batch_size=batch_size, shuffle=True, drop_last=True, num_workers=0)
valid_iter = torch.utils.data.DataLoader(valid_dataset, batch_size=batch_size, shuffle=False, drop_last=True, num_workers=0)
test_iter = torch.utils.data.DataLoader(test_dataset, batch_size=batch_size, shuffle=False, drop_last=False, num_workers=0)
2.模型训练
def train_batch(net, X, y, loss, trainer, devices):"""使用多GPU训练一个小批量数据。参数:net: 神经网络模型。X: 输入数据,张量或张量列表。y: 标签数据。loss: 损失函数。trainer: 优化器。devices: GPU设备列表。返回:train_loss_sum: 当前批次的训练损失和。train_acc_sum: 当前批次的训练准确度和。"""# 如果输入数据X是列表类型if isinstance(X, list):# 将列表中的每个张量移动到第一个GPU设备X = [x.to(devices[0]) for x in X]else:X = X.to(devices[0])# 如果X不是列表,直接将X移动到第一个GPU设备y = y.to(devices[0])# 将标签数据y移动到第一个GPU设备net.train() # 设置网络为训练模式trainer.zero_grad()# 梯度清零pred = net(X) # 前向传播,计算预测值l = loss(pred, y) # 计算损失l.sum().backward()# 反向传播,计算梯度trainer.step() # 更新模型参数train_loss_sum = l.sum()# 计算当前批次的总损失train_acc_sum = d2l.accuracy(pred, y)# 计算当前批次的总准确度return train_loss_sum, train_acc_sum# 返回训练损失和与准确度和def train(net, train_iter, valid_iter, num_epochs, lr, wd, devices, lr_period, lr_decay,param_group=True):# trainer = torch.optim.SGD(net.parameters(), lr=lr, momentum=0.9,weight_decay=wd)trainer = torch.optim.Adam(net.parameters(), lr=lr,weight_decay=wd)scheduler = torch.optim.lr_scheduler.StepLR(trainer, lr_period, lr_decay)loss = nn.CrossEntropyLoss(reduction="none")num_batches, timer = len(train_iter), d2l.Timer()legend = ['train loss', 'train acc']if valid_iter is not None:legend.append('valid acc')animator = lp.Animator(xlabel='epoch', xlim=[1, num_epochs],legend=legend)net = nn.DataParallel(net, device_ids=devices).to(devices[0])for epoch in range(num_epochs):net.train()metric = lp.Accumulator(3)for i, (features, labels) in enumerate(train_iter):timer.start()l, acc = train_batch(net, features, labels,loss, trainer, devices)metric.add(l, acc, labels.shape[0])timer.stop()train_l = metric[0] / metric[2] # 计算训练损失train_acc = metric[1] / metric[2] # 计算训练准确率if (i + 1) % (num_batches // 5) == 0 or i == num_batches - 1:animator.add(epoch + (i + 1) / num_batches,(train_l , train_acc,None))if valid_iter is not None:valid_acc = d2l.evaluate_accuracy_gpu(net, valid_iter)animator.add(epoch + 1, (None, None, valid_acc))scheduler.step()print(f'loss {train_l:.3f}, train acc {train_acc:.3f}, 'f'valid_acc {valid_acc:.3f}')measures = (f'train loss {metric[0] / metric[2]:.3f}, 'f'train acc {metric[1] / metric[2]:.3f}')if valid_iter is not None:measures += f', valid acc {valid_acc:.3f}'print(measures + f'\n{metric[2] * num_epochs / timer.sum():.1f}'f' examples/sec on {str(devices)}')
3.定义超参数
# 定义模型
net = d2l.resnet18(len(cifar_labels),3)
devices, num_epochs, lr, wd = d2l.try_all_gpus(), 100, 3e-4, 5e-4
lr_period, lr_decay = 4, 0.9
train(net, train_iter, valid_iter, num_epochs, lr, wd, devices, lr_period, lr_decay)
plt.show()
# train loss 0.153, train acc 0.955, valid acc 0.469
# 873.5 examples/sec on [device(type='cuda', index=0)]

在这里插入图片描述

4.模型预测
# 针对测试集进行分类预测
def predict(net, data_loader, devices):"""使用模型进行预测参数:net (torch.nn.Module): 要进行预测的模型data_loader (torch.utils.data.DataLoader): 数据加载器,用于提供待预测的数据devices (list): 计算设备列表(CPU或GPU)返回:all_preds (list): 包含所有预测结果的列表"""all_preds = []  # 存储所有预测结果net.to(devices[0])  # 将模型移动到指定设备net.eval()  # 设置模型为评估模式with torch.no_grad():  # 在不需要计算梯度的上下文中进行for X in data_loader:  # 遍历数据加载器X = X.to(devices[0])  # 将数据移动到指定设备outputs = net(X)  # 前向传播,计算模型输出_, preds = torch.max(outputs, 1)  # 获取预测结果all_preds.extend(preds.cpu().numpy())  # 将预测结果添加到列表中return all_preds  # 返回所有预测结果# 调用预测函数
predictions = predict(net, test_iter, devices)
# 映射预测结果到标签
mapped_predictions = [num_to_labels[int(i)] for i in predictions]
# 读取测试数据
test_data = pd.read_csv(test_path)
# 将预测结果添加到测试数据中
test_data['label'] = pd.Series(mapped_predictions)
# 创建提交文件
submission = pd.concat([test_data['id'], test_data['label']], axis=1)
# 保存提交文件
submission.to_csv(file_path + 'submission.csv', index=False)

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.mzph.cn/pingmian/38608.shtml

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

R语言数据分析案例39-合肥市AQI聚类和多元线性回归

一、研究背景 随着全球工业化和城市化的迅速发展,空气污染问题日益凸显,已成为影响人类健康和环境质量的重大挑战。空气污染不仅会引发呼吸系统、心血管系统等多种疾病,还会对生态系统造成不可逆转的损害。因此,空气质量的监测和…

MySQL高阶:事务和并发

事务和并发 1. 事务创建事务 2. 并发和锁定并发问题 3. 事务隔离等级3.1 读取未提交隔离级别3.2 读取已提交隔离级别3.3 重复读取隔离级别3.4 序列化隔离级别 4. 死锁 1. 事务 事务(trasaction)是完成一个完整事件的一系列SQL语句。这一组SQL语句是一条…

经典小游戏(一)C实现——三子棋

switch(input){case 1:printf("三子棋\n");//这里先测试是否会执行成功break;case 0:printf("退出游戏\n");break;default :printf("选择错误,请重新选择!\n");break;}}while(input);//直到输入的结果为假,循环才会结束} …

go Channel原理 (二)

Channel 设计原理 不要通过共享内存的方式进行通信,而是应该通过通信的方式共享内存。 在主流编程语言中,多个线程传递数据的方式一般都是共享内存。 Go 可以使用共享内存加互斥锁进行通信,同时也提供了一种不同的并发模型,即通…

error: Sandbox: rsync.samba in Xcode project

在Targets 的 Build Settings 搜索:User script sandboxing 设置为NO

python课程设计作业-TCP客户端-服务端通信

说明文档 目录 小组成员分工 作品功能介绍 使用的工具和方法 设计的步骤 课程设计中遇到的问题 结论 1. 小组成员分工 本次课程设计由以下小组成员完成: xxx 2. 作品功能介绍 本次课程设计的作品是一个简单的基于 TCP 协议的客户端-服务端通信示例。通过这个示…

【SpringBoot Web框架实战教程】06 SpringBoot 整合 Druid

不积跬步,无以至千里;不积小流,无以成江海。大家好,我是闲鹤,微信:xxh_1459,十多年开发、架构经验,先后在华为、迅雷服役过,也在高校从事教学3年;目前已创业了…

阿里云centos7.9 挂载数据盘到 www目录

一、让系统显示中文 参考:centos7 怎么让命令行显示中文(英文->中文)_如何在命令行中显示中文-CSDN博客 1、输入命令:locale -a |grep "zh_CN" 可以看到已经存在了中文包 2、输入命令:sudo vi…

AGPT•intelligence:带你领略全新量化交易的风采

随着金融科技的快速发展,量化交易已经成为了投资领域的热门话题。越来越多的投资者开始关注和使用量化交易软件来进行投资决策。在市场上有许多量化交易软件可供选择。 Delaek,是一位资深的金融科技专家,在 2020年成立一家专注于数字资产量化…

第一后裔延迟高怎么办?快速降低第一后裔延迟

第一后裔/The First Descendant一款射击游戏,融合了刷宝、角色扮演、团队合作、剧情等元素,让每个玩家都能在自己的角度上,找到切入点,并不断地成长,一步步解开后裔身上隐藏的秘密。近期该作正式上线,很多玩…

vue项目创建+eslint+Prettier+git提交规范(commitizen+hooks+husk)

# 步骤 1、使用 vue-cli 创建项目 这一小节我们需要创建一个 vue3 的项目,而创建项目的方式依然是通过 vue-cli 进行创建。 不过这里有一点大家需要注意,因为我们需要使用最新的模板,所以请保证你的 vue-cli 的版本在 4.5.13 以上&#xff…

Debian linux忘记root密码如何重置

重启电脑, 到下图再按 e 键 在页面中可以看到有个ro的行,在ro行的尾部,添加 rw init/bin/bas 3. ctrl X 启动系统,最后会进入命令行模式 4. 重设root密码,输入命令 passwd root,按照提示输入新密码并确认 5. 重启系…

基于Python的自动化测试框架-Pytest总结-第一弹基础

Pytest总结第一弹基础 入门知识点安装pytest运行pytest测试用例发现规则执行方式命令行执行参数 配置发现规则 如何编写测试Case基础案例断言语句的使用pytest.fail() 和 Exceptions自定义断言函数异常测试测试类形式 pytest的Fixture使用Fixture入门案例使用fixture的Setup、T…

昇思25天学习打卡营第8天|模型训练

昇思25天学习打卡营第8天|模型训练 前言模型训练构建数据集定义神经网络模型定义超参、损失函数和优化器超参损失函数优化器 训练与评估 个人任务打卡(读者请忽略)个人理解与总结 前言 非常感谢华为昇思大模型平台和CSDN邀请体验昇思大模型!从…

linux中如何启动python虚拟环境

找到python虚拟环境所在目录 执行下面的命令即可 source auth_python/bin/activate

【遇坑笔记】Node.js 开发环境与配置 Visual Studio Code

【遇坑笔记】Node.js 开发环境与配置 Visual Studio Code 前言node.js开发环境配置解决pnpm 不是内部或外部命令的问题(pnpm安装教程) 解决 pnpm : 无法加载文件 C:\Program Files\nodejs\pnpm.ps1,因为在此系统上禁止运行脚本。 前言 最近部…

【代码随想录】【算法训练营】【第49天】 [300]最长递增子序列 [674]最长连续递增序列 [718]最长重复子数组

前言 思路及算法思维,指路 代码随想录。 题目来自 LeetCode。 day 49,周二,坚持不了一点~ 题目详情 [300] 最长递增子序列 题目描述 300 最长递增子序列 解题思路 前提:最大递增子序列的长度 思路:动态规划 d…

基于X86+FPGA的精密加工检测设备解决方案

应用场景 随着我国高新技术的发展和国防现代化发展,航空、航天等领域需 要的大型光电子器件,微型电子机械、 光 电信息等领域需要的微型器件,还有一些复杂零件的加工需求日益增加,这些都需要借助精密甚至超精密的加工检测设备 客…

esp12实现的网络时钟校准

网络时间的获取是通过向第三方服务器发送GET请求获取并解析出来的。 在本篇博客中,网络时间的获取是一种自动的行为,当系统成功连接WiFi获取到网络天气后,系统将自动获取并解析得到时间和日期,为了减少误差每两分钟左右进行一次校…

web平台—apache

web平台—apache 1. 学apache前需要知道的知识点2. apache详解2.1 概述2.2 工作模式2.3 启动apache网站整体流程2.4 相关文件保存位置2.5 配置文件详解 3. apache配置实验实验1:设置apache的目录别名实验2:apache的用户认证实验3:虚拟主机 (重…