(动手学习深度学习)第13章 实战kaggle竞赛:CIFAR-10

  1. 导入相关库
import collections
import math
import os
import shutil
import pandas as pd
import torch
import torchvision
from torch import nn
from d2l import torch as d2l
  1. 下载数据集
d2l.DATA_HUB['cifar10_tiny'] = (d2l.DATA_URL + 'kaggle_cifar10_tiny.zip','2068874e4b9a9f0fb07ebe0ad2b29754449ccacd')# 如果使用完整的Kaggle竞赛的数据集,设置demo为False
demo = Trueif demo:data_dir = d2l.download_extract('cifar10_tiny')
else:data_dir = '../data/kaggle/cifar-10/'
  1. 整理数据集
# 查看数据集
def read_csv_labels(fname):"""读取‘fname’来给标签字典返回一个文件名"""with open(fname, 'r') as f:lines = f.readlines()[1:]  # readlines(): 每次读文档的一行,以后还需要逐步循环tokens = [l.rstrip().split(',') for l in lines]  # rstrip(): 删除字符串后面(右面)的空格或特殊字符, 还有lstrip(左面)、strip(两面)return dict((name, label) for name, label in tokens)labels = read_csv_labels(os.path.join(data_dir, 'trainLabels.csv'))
print('训练样本:', len(labels))
print('类别:', len(set(labels.values())))  # set(): 集合,里面不能包含重复的元素,接受一个list作为参数

在这里插入图片描述
将验证集从原始的训练集钟拆分出来

# 拆分数据集:训练集、验证集
def copyfile(filename, target_dir):"""将文件复制到目标目录"""os.makedirs(target_dir, exist_ok=True)  # 创建多层目录,exist_ok为True:在目标目录已存在的情况下不会触发FileExistsError异常。shutil.copy(filename, target_dir)  #拷贝文件,filename:要拷贝的文件;target_dir:目标文件夹def reorg_train_valid(data_dir, labels, valid_ratio):"""将验证集从原始训练集钟拆分出来"""# 训练数据集中样本数量最少的类别中的样本数# Counter: 计数器,返回一个字典,键为元素,值为元素个数;# .most_common(): 返回一个列表, 列表元素为(元素,出现次数),默认按出现频率排序# [-1]: 样本数量最少的类别(类别, 样本数),[-1][1]: 样本数数量最少的类别中的样本数n = collections.Counter(labels.values()).most_common()[-1][1]# 验证集中每个类别的样本数n_valid_per_label= max(1, math.floor((n * valid_ratio)))  # math.floor(): 向下取整  math.ceil(): 向上取整label_count = {}# 遍历原始训练集中的每个样本for train_file in os.listdir(os.path.join(data_dir, 'train')):label = labels[train_file.split('.')[0]]  # 从文件名中提取标签fname = os.path.join(data_dir, 'train', train_file)copyfile(fname, os.path.join(data_dir, 'train_valid_test', 'train_valid', label))# 如果该类别的样本数还未达到在验证集中的设定数量,则将样本复制到验证集中if label not in label_count or label_count[label] < n_valid_per_label:copyfile(fname, os.path.join(data_dir, 'train_valid_test', 'valid', label))label_count[label] = label_count.get(label, 0) + 1else:copyfile(fname, os.path.join(data_dir, 'train_valid_test', 'train', label))return n_valid_per_label# reorg_test函数用来在预测期间整理测试集,以方便读取
def reorg_test(data_dir):"""在预测期间整理测试集,以方便读取"""# 遍历测试集中的每个样本for test_file in os.listdir(os.path.join(data_dir, 'test')):# 将测试集中的样本复制到新的目录结构中的 'test' 子目录下,标签为 'unknown'copyfile(os.path.join(data_dir, 'test', test_file),os.path.join(data_dir, 'train_valid_test', 'test', 'unknown'))
# 整个处理数据集函数
def reorg_cifar10_data(data_dir, valid_ratio):labels = read_csv_labels(os.path.join(data_dir, 'trainLabels.csv'))reorg_train_valid(data_dir, labels, valid_ratio)reorg_test(data_dir)
  • 这个小规模数据集的批量大小是32,在实际的cifar-10数据集中,可以设为128
  • 将10%的训练样本作为调整超参数的验证集
batch_size = 32 if demo else 128
valid_ratio = 0.1
reorg_cifar10_data(data_dir, valid_ratio)
结果会生成一个train_valid_test的文件夹,里面有:
- test文件夹---unknow文件夹:5张没有标签的测试照片
- train_valid文件夹---10个类被的文件夹:每个文件夹包含所属类别的全部照片
- train文件夹--10个类别的文件夹:每个文件夹下包含90%的照片用于训练
- valid文件夹--10个类别的文件夹:每个文件夹下包含10%的照片用于验证
  1. 图像增广
transform_train = torchvision.transforms.Compose([# 原本图像是32*32,先放大成40*40, 在随机裁剪为32*32,实现训练数据的增强torchvision.transforms.Resize(40),torchvision.transforms.RandomResizedCrop(32, scale=(0.64, 1.0), ratio=(1.0, 1.0)),torchvision.transforms.RandomHorizontalFlip(),torchvision.transforms.ToTensor(),torchvision.transforms.Normalize([0.4914, 0.4822, 0.4465],[0.2023, 0.1994, 0.2010])
])
transform_test = torchvision.transforms.Compose([torchvision.transforms.ToTensor(),# 标准化图像的每个通道 : 消除评估结果中的随机性torchvision.transforms.Normalize([0.4914, 0.4822, 0.4465],[0.2023, 0.1994, 0.2010])
])
  1. 加载数据集
train_ds, train_valid_ds = [torchvision.datasets.ImageFolder(os.path.join(data_dir, 'train_valid_test', folder),transform=transform_train) for folder in ['train', 'train_valid']
]
valid_ds, test_ds = [torchvision.datasets.ImageFolder(os.path.join(data_dir, 'train_valid_test', folder), transform=transform_test) for folder in ['valid', 'test']
]
  1. 定义迭代器,方便快速迭代数据
train_iter, train_valid_iter = [torch.utils.data.DataLoader(dataset, batch_size, shuffle=True, drop_last=True) for dataset in (train_ds, train_valid_ds)
]
valid_iter = torch.utils.data.DataLoader(valid_ds, batch_size, shuffle=False, drop_last=True
)
test_iter = torch.utils.data.DataLoader(test_ds, batch_size, shuffle=False, drop_last=False
)
  1. 定义模型与损失函数
# 对resnet18做微调,输入通道数为3, 输出类别数为10
def get_net():num_classes = 10net = d2l.resnet18(num_classes, in_channels=3)return net
# 查看网络模型
get_net()

在这里插入图片描述

# 使用交叉熵损失函数作为损失函数: 直接返回n分样本的loss
loss = nn.CrossEntropyLoss(reduction='none')
  1. 定义训练函数
# 定义训练函数
def train(net, train_iter, valid_iter, num_epochs, lr, wd, devices, lr_period, lr_decay):trainer = torch.optim.SGD(net.parameters(), lr=lr, momentum=0.9, weight_decay=wd)scheduler = torch.optim.lr_scheduler.StepLR(trainer, lr_period, lr_decay)num_batches, timer = len(train_iter), d2l.Timer()legend = ['train loss', 'train acc']if valid_iter is not None:legend.append('valid acc')animator = d2l.Animator(xlabel='epoch', xlim=[1, num_epochs], legend=legend)net = nn.DataParallel(net, device_ids=devices).to(devices[0])for epoch in range(num_epochs):net.train()metric = d2l.Accumulator(3)for i, (features, labels) in enumerate(train_iter):timer.start()l, acc = d2l.train_batch_ch13(net, features, labels, loss, trainer, devices)metric.add(l, acc, labels.shape[0])timer.stop()if (i + 1) % (num_batches // 5) == 0 or i == num_batches - 1:animator.add(epoch + (i + 1) / num_batches,(metric[0]/ metric[2], metric[1] / metric[2], None))if valid_iter is not None:valid_acc = d2l.evaluate_accuracy_gpu(net, valid_iter)animator.add(epoch+1, (None, None, valid_acc))scheduler.step()measures = (f'train loss {metric[0] / metric[2]:.3f},'f'train acc{metric[1] / metric[2]:.3f}')if valid_iter is not None:measures += f', valid acc {valid_acc:.3f}'print(measures + f'\n{metric[2] * num_epochs /timer.sum():.1f}'f'example/sec on {str(devices)}')
  1. 训练模型
    • (数据集太小,导致精度不高)
import time# 在开头设置开始时间
start = time.perf_counter()  # start = time.clock() python3.8之前可以# 训练和验证模型
devices, num_epochs, lr, wd = d2l.try_all_gpus(), 20, 2e-4, 5e-4
lr_period, lr_decay, net = 4, 0.9, get_net()
train(net, train_iter, valid_iter, num_epochs, lr, wd, devices, lr_period, lr_decay)# 在程序运行结束的位置添加结束时间
end = time.perf_counter()  # end = time.clock()  python3.8之前可以# 再将其进行打印,即可显示出程序完成的运行耗时
print(f'运行耗时{(end-start):.4f}')

在这里插入图片描述
10. 对测试集进行分类并提交结果

net, preds = get_net(), []
train(net ,train_valid_iter, None, num_epochs, lr, wd, devices, lr_period, lr_decay)
for X, _ in test_iter:y_hat = net(X.to(devices[0]))preds.extend(y_hat.argmax(dim=1).type(torch.int32).cpu().numpy())
sorted_ids = list(range(1, len(test_ds) + 1))
sorted_ids.sort(key=lambda x: str(x))
df = pd.DataFrame({'id' : sorted_ids, 'label': preds})
df['label'] = df['label'].apply(lambda x: train_valid_ds.classes[x])
df.to_csv('submission.csv', index=False)

在这里插入图片描述

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.mzph.cn/news/152423.shtml

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

ElasticSearch在Windows上的下载与安装

Elasticsearch是一个开源的分布式搜索和分析引擎&#xff0c;它可以帮助我们快速地搜索、分析和处理大量数据。Elasticsearch能够快速地处理结构化和非结构化数据&#xff0c;支持全文检索、地理位置搜索、自动补全、聚合分析等功能&#xff0c;能够承载各种类型的应用&#xf…

基于Qt中操作MySQL数据库示例

# 一、安装驱动 ##(1)安装 在Qt中操作MySQL数据库首先要安装mysql的驱动文件,将MySQL下的libmusql.dll文件复制到Qt的安装路径下的bin文件夹下即可。 直接将libmysql.dll文件粘贴到此文件夹中。 ## (2)验证驱动是否安装成功 复制成功之后来测试一下驱动程序是否安装成功…

◢Django 分页+搜索

1、搜索数据 从数据库中获取数据&#xff0c;并进行筛选&#xff0c;xx__contains q作为条件&#xff0c;查找的是xx列中有q的所有数据条 当有多个筛选条件时&#xff0c;将条件变成一个字典&#xff0c;传入 **字典 &#xff0c;ORM会自行翻译并查找。 筛选电话号码这一列…

Java获取指定日期到当前日期的差距

Java获取指定日期到当前日期的差距 一、指定日期到今天的y年m月d日 private JSONObject getYesrMonthDay(String dataParam){JSONObject res new JSONObject();/*只比较年月日&#xff0c;不要时间*/DateTimeFormatter df DateTimeFormatter.ofPattern("yyyy-MM-dd&quo…

在Go编程中调用外部命令的几种场景

1.摘要 在很多场合, 使用Go语言需要调用外部命令来完成一些特定的任务, 例如: 使用Go语言调用Linux命令来获取执行的结果,又或者调用第三方程序执行来完成额外的任务。在go的标准库中, 专门提供了os/exec包来对调用外部程序提供支持, 本文将对调用外部命令的几种使用方法进行总…

【机器学习】034_多层感知机Part.2_从零实现多层感知机

一、解决XOR问题 1. 回顾XOR问题&#xff1a; 如图&#xff0c;如何对XOR面进行分割以划分四个输入 对应的输出 呢&#xff1f; 思路&#xff1a;采用两个分类器分类&#xff0c;每次分出两个输入 &#xff0c;再借助这两个分类从而分出 。 即采用同或运算&#xff0c;当两…

通过easyexcel实现数据导入功能

上一篇文章通过easyexcel导出数据到excel表格已经实现了简单的数据导出功能&#xff0c;这篇文章也介绍一下怎么通过easyexcel从excel表格中导入数据。 目录 一、前端代码 index.html index.js 二、后端代码 controller service SongServiceImpl 三、功能预览 四、后端…

WordPress画廊插件Envira Gallery v1.9.7河蟹版下载

Envira Gallery是一款功能强大的WordPress画廊插件。通过使用这个插件&#xff0c;你可以在WordPress的前台页面上创建出令人赏心悦目的图片画廊展示形式。 拖放生成器&#xff1a;轻松创建精美照片和视频画廊 自定义主题&#xff0c;打造独特外观 使用预设模板&#xff0c;为…

分类预测 | Matlab实现基于PSO-SDAE粒子群优化算法优化堆叠去噪自编码器的数据分类预测

分类预测 | Matlab实现基于PSO-SDAE粒子群优化算法优化堆叠去噪自编码器的数据分类预测 目录 分类预测 | Matlab实现基于PSO-SDAE粒子群优化算法优化堆叠去噪自编码器的数据分类预测分类效果基本描述程序设计参考资料 分类效果 基本描述 1.Matlab实现基于PSO-SDAE粒子群优化算法…

使用树莓派学习Linux系统编程的 --- 库编程(面试重点)

在之前的Linux系统编程中&#xff0c;学习了文件的打开&#xff1b;关闭&#xff1b;读写&#xff1b;进程&#xff1b;线程等概念.... 本节补充“Linux库概念 & 相关编程”&#xff0c;这是一个面试的重点&#xff01; 分文件编程 在之前的学习中&#xff0c;面对较大的…

算法---相等行列对

题目 给你一个下标从 0 开始、大小为 n x n 的整数矩阵 grid &#xff0c;返回满足 Ri 行和 Cj 列相等的行列对 (Ri, Cj) 的数目。 如果行和列以相同的顺序包含相同的元素&#xff08;即相等的数组&#xff09;&#xff0c;则认为二者是相等的。 示例 1&#xff1a; 输入&…

原理Redis-Dict字典

Dict 1) Dict组成2) Dict的扩容3) Dict的收缩4) Dict的rehash5) 总结 1) Dict组成 Redis是一个键值型&#xff08;Key-Value Pair&#xff09;的数据库&#xff0c;可以根据键实现快速的增删改查。而键与值的映射关系正是通过Dict来实现的。 Dict由三部分组成&#xff0c;分别…

无重复最长字符串(最长无重复子字符串),剑指offer,力扣

目录 原题&#xff1a; 力扣地址&#xff1a; 我们直接看题解吧&#xff1a; 解题方法&#xff1a; 难度分析&#xff1a; 难度算中下吧&#xff0c;这个总体不算很难&#xff0c;而且滑动窗口&#xff0c;以及哈希都比较常见 审题目事例提示&#xff1a; 解题思路&#xff08;…

vue3 setup展示数据

效果图 1.创建数据 content.js import { reactive } from vueconst data reactive({color:red,title: 二十四节气,subTitle: 节气&#xff0c;是干支历中表示自然节律变化以及确立“十二月建”&#xff08;月令&#xff09;的特定节令。,list: [{name: "立春",con…

代码随想录-刷题第二天

977. 有序数组的平方 题目链接&#xff1a;977. 有序数组的平方 思路&#xff1a;双指针思想&#xff0c;数组是有序的且含有负数&#xff0c;其中元素的平方一定是两边最大。定义两个指针&#xff0c;从两端开始向中间靠近&#xff0c;每次比较两个指针的元素平方大小&#…

线上bug-接口速度慢

&#x1f47d;System.out.println(“&#x1f44b;&#x1f3fc;嗨&#xff0c;大家好&#xff0c;我是代码不会敲的小符&#xff0c;双非大四&#xff0c;Java实习中…”); &#x1f4da;System.out.println(“&#x1f388;如果文章中有错误的地方&#xff0c;恳请大家指正&a…

腾讯云助力港华能源上线“碳汭星云2.0”,推动能源行业绿色低碳转型

11月17日&#xff0c;港华能源与腾讯云联合打造的港华智慧能源生态平台“碳汭星云2.0”升级上线。依托双方的连接、大数据能力和行业深耕经验&#xff0c;该平台打破了园区“数据孤岛”&#xff0c;进一步提升了数据治理、应用集成和复制推广能力&#xff0c;未来有望以综合能源…

【小呆的力学笔记】有限元专题之循环对称结构有限元原理

文章目录 1. 循环对称问题的提出2. 循环对称条件2.1 节点位移的循环对称关系2.2 节点内力的循环对称关系 3. 在平衡方程中引入循环对称条件 1. 循环对称问题的提出 许多工程结构都是其中某一扇面的n次周向重复&#xff0c;也就是是周期循环对称结构。如果弹性体的几何形状、约…

【每日刷题——语音信号篇】

思考与练习 练习2.1 语音信号在产生的过程中&#xff0c;以及被感知的过程中&#xff0c;分别要经过人体的哪些器官&#xff1f; 1.产生过程&#xff1a; 肺部空气 → \rightarrow →冲击声带 → \rightarrow →通过声道&#xff08;可以调节&#xff09; → \rightarrow →…

IDEA自动注解设置(中文版)

IDEA自动注解设置 1、添加类自动注释 文件 - 设置 - 编辑器 - 文件和代码模板 - Include - File Header /** *description&#xff1a;TODO *author&#xff1a; ${USER} *create&#xff1a; ${DATE} ${TIME} */2、添加类方法自动注释 文件 - 设置 - 编辑器 - 实时模版 - …