动手学深度学习(Pytorch版)代码实践 -计算机视觉-39实战Kaggle比赛:狗的品种识别(ImageNet Dogs)

39实战Kaggle比赛:狗的品种识别(ImageNet Dogs

比赛链接:Dog Breed Identification | Kaggle

1.导入包
import torch
from torch import nn
import collections
import math
import os
import shutil
import torchvision
from d2l import torch as d2l
import matplotlib.pyplot as plt
import liliPytorch as lp
2.数据集处理
# 精简数据集
# file_path = '../data/kaggle_dog_tiny/'
# 原数据集
file_path = '../data/dog-breed-identification/'# 整理数据集
# 从原始训练集中拆分验证集,然后将图像移动到按标签分组的子文件夹中。
#@save
def read_csv_labels(fname):"""读取CSV文件中的标签,它返回一个字典,该字典将文件名中不带扩展名的部分映射到其标签"""with open(fname, 'r') as f:# 跳过文件头行(列名)lines = f.readlines()[1:]tokens = [l.rstrip().split(',') for l in lines]return dict(((name, label) for name, label in tokens))# labels = read_csv_labels(os.path.join(file_path, 'labels.csv'))
# print(labels) # {'0097c6242c6f3071762d9f85c3ef1b2f': 'bedlington_terrier', '00a338a92e4e7bf543340dc849230e75': 'dingo'}
# print('训练样本 :', len(labels)) # 训练样本 : 1000
# print('类别 :', len(set(labels.values()))) # 类别 : 120# 定义reorg_train_valid函数来将验证集从原始的训练集中拆分出来
#@save
def copyfile(filename, target_dir):"""将文件复制到目标目录"""os.makedirs(target_dir, exist_ok=True)shutil.copy(filename, target_dir)#@save
def reorg_train_valid(data_dir, labels, valid_ratio):"""将验证集从原始的训练集中拆分出来"""# 训练数据集中样本最少的类别中的样本数n = collections.Counter(labels.values()).most_common()[-1][1]# 验证集中每个类别的样本数n_valid_per_label = max(1, math.floor(n * valid_ratio))label_count = {}for train_file in os.listdir(os.path.join(data_dir, 'train')): # 遍历训练集文件夹中的所有文件。label = labels[train_file.split('.')[0]] # 获取文件名(去掉扩展名)fname = os.path.join(data_dir, 'train', train_file) # 构建完整的文件路径copyfile(fname, os.path.join(data_dir, 'train_valid_test','train_valid', label))if label not in label_count or label_count[label] < n_valid_per_label:copyfile(fname, os.path.join(data_dir, 'train_valid_test','valid', label))label_count[label] = label_count.get(label, 0) + 1else:copyfile(fname, os.path.join(data_dir, 'train_valid_test','train', label))return n_valid_per_label# reorg_test函数用来在预测期间整理测试集
#@save
def reorg_test(data_dir):"""在预测期间整理测试集,以方便读取"""for test_file in os.listdir(os.path.join(data_dir, 'test')):copyfile(os.path.join(data_dir, 'test', test_file),os.path.join(data_dir, 'train_valid_test', 'test','unknown'))def reorg_dog_data(data_dir, valid_ratio):labels = read_csv_labels(os.path.join(data_dir, 'labels.csv'))reorg_train_valid(data_dir, labels, valid_ratio)reorg_test(data_dir)reorg_dog_data(file_path, valid_ratio = 0.1)
3.数据集加载
# 数据图像增广
# 训练
transform_train = torchvision.transforms.Compose([# 随机裁剪图像,所得图像为原始面积的0.08~1之间,高宽比在3/4和4/3之间。# 然后,缩放图像以创建224x224的新图像torchvision.transforms.RandomResizedCrop(224, scale=(0.08, 1.0),ratio=(3.0/4.0, 4.0/3.0)),torchvision.transforms.RandomHorizontalFlip(),# 随机更改亮度,对比度和饱和度torchvision.transforms.ColorJitter(brightness=0.4,contrast=0.4,saturation=0.4),# 添加随机噪声torchvision.transforms.ToTensor(),# 标准化图像的每个通道torchvision.transforms.Normalize([0.485, 0.456, 0.406],[0.229, 0.224, 0.225])])
# 测试
transform_test = torchvision.transforms.Compose([torchvision.transforms.Resize(256),# 从图像中心裁切224x224大小的图片torchvision.transforms.CenterCrop(224),torchvision.transforms.ToTensor(),torchvision.transforms.Normalize([0.485, 0.456, 0.406],[0.229, 0.224, 0.225])])# 读取数据集
# 创建数据集对象
# 通常用于定义数据源及其预处理方法。
train_dataset, train_valid_dataset = [# ImageFolder 创建数据集时,它会遍历指定目录下的所有子文件夹,# 并将每个子文件夹的名称作为一个类别标签。然后,它会按字母顺序给每个类别分配一个索引torchvision.datasets.ImageFolder(os.path.join(file_path, 'train_valid_test', folder),transform=transform_train) for folder in ['train', 'train_valid']]valid_dataset, test_dataset = [torchvision.datasets.ImageFolder(os.path.join(file_path, 'train_valid_test', folder),transform=transform_test) for folder in ['valid', 'test']]# 显示每个类别名称和对应的索引
# print(train_dataset.class_to_idx) 4
# {'affenpinscher': 0, 'afghan_hound': 1, 'african_hunting_dog': 2}batch_size = 128
# 创建数据加载器
# 通常用于训练过程中按批次提供数据,具有更高效的数据加载和处理能力。
train_iter, train_valid_iter = [torch.utils.data.DataLoader(dataset, batch_size, shuffle=True, drop_last=True) for dataset in (train_dataset, train_valid_dataset)]valid_iter = torch.utils.data.DataLoader(valid_dataset, batch_size, shuffle=False,drop_last=True)test_iter = torch.utils.data.DataLoader(test_dataset, batch_size, shuffle=False,drop_last=False)
4.预训练模型resnet34
# 用于创建和配置训练模型
def get_net(devices):# 创建一个空的 nn.Sequential 容器finetune_net = nn.Sequential()# 加载预训练的 ResNet-34 模型,并将其特征层(features)部分添加到 finetune_net 中finetune_net.features = torchvision.models.resnet34(pretrained=True)# 定义一个新的输出网络finetune_net.output_new = nn.Sequential(nn.Linear(1000, 256),nn.ReLU(),nn.Linear(256, 120))# 将模型参数分配到指定的设备(如 GPU 或 CPU)finetune_net = finetune_net.to(devices[0])# 冻结预训练的特征层参数,以避免在训练过程中更新这些参数for param in finetune_net.features.parameters():param.requires_grad = False# 返回配置好的模型return finetune_net
5.模型训练
def train_batch(net, X, y, loss, trainer, devices):"""使用多GPU训练一个小批量数据。参数:net: 神经网络模型。X: 输入数据,张量或张量列表。y: 标签数据。loss: 损失函数。trainer: 优化器。devices: GPU设备列表。返回:train_loss_sum: 当前批次的训练损失和。train_acc_sum: 当前批次的训练准确度和。"""# 如果输入数据X是列表类型if isinstance(X, list):# 将列表中的每个张量移动到第一个GPU设备X = [x.to(devices[0]) for x in X]else:X = X.to(devices[0])# 如果X不是列表,直接将X移动到第一个GPU设备y = y.to(devices[0])# 将标签数据y移动到第一个GPU设备net.train() # 设置网络为训练模式trainer.zero_grad()# 梯度清零pred = net(X) # 前向传播,计算预测值l = loss(pred, y) # 计算损失l.sum().backward()# 反向传播,计算梯度trainer.step() # 更新模型参数train_loss_sum = l.sum()# 计算当前批次的总损失train_acc_sum = d2l.accuracy(pred, y)# 计算当前批次的总准确度return train_loss_sum, train_acc_sum# 返回训练损失和与准确度和def train(net, train_iter, valid_iter, num_epochs, lr, wd, devices, lr_period, lr_decay):trainer = torch.optim.SGD(# net.parameters():返回模型 net 中所有参数。# if param.requires_grad:仅选择那些 requires_grad 为 True 的参数。# 这些参数是需要进行梯度更新的(即未冻结的参数)(param for param in net.parameters()if param.requires_grad), # momentum用于加速 SGD 的收敛速度,通过在更新参数时考虑之前的更新方向,减少震荡# weight_decay权重衰减用于防止过拟合lr=lr,momentum=0.9, weight_decay=wd)# trainer = torch.optim.Adam(net.parameters(), lr=lr,weight_decay=wd)scheduler = torch.optim.lr_scheduler.StepLR(trainer, lr_period, lr_decay)loss = nn.CrossEntropyLoss(reduction="none")num_batches, timer = len(train_iter), d2l.Timer()legend = ['train loss', 'train acc']if valid_iter is not None:legend.append('valid acc')animator = lp.Animator(xlabel='epoch', xlim=[1, num_epochs],legend=legend)net = nn.DataParallel(net, device_ids=devices).to(devices[0])for epoch in range(num_epochs):net.train()metric = lp.Accumulator(3)for i, (features, labels) in enumerate(train_iter):timer.start()l, acc = train_batch(net, features, labels,loss, trainer, devices)metric.add(l, acc, labels.shape[0])timer.stop()# train_l = metric[0] / metric[2] # 计算训练损失# train_acc = metric[1] / metric[2] # 计算训练准确率if (i + 1) % (num_batches // 5) == 0 or i == num_batches - 1:animator.add(epoch + (i + 1) / num_batches,(metric[0] / metric[2], metric[1] / metric[2],None))if valid_iter is not None:valid_acc = d2l.evaluate_accuracy_gpu(net, valid_iter)animator.add(epoch + 1, (None, None, valid_acc))scheduler.step()# print(f'loss {train_l:.3f}, train acc {train_acc:.3f}, '#       f'valid_acc {valid_acc:.3f}')measures = (f'train loss {metric[0] / metric[2]:.3f}, 'f'train acc {metric[1] / metric[2]:.3f}')if valid_iter is not None:measures += f', valid acc {valid_acc:.3f}'print(measures + f'\n{metric[2] * num_epochs / timer.sum():.1f}'f' examples/sec on {str(devices)}')
6.模型预测
def predict(file_path_module):# 预测net = get_net(d2l.try_all_gpus())net.load_state_dict(torch.load(file_path_module + 'imageNet_Dogs.params'))# 初始化一个空列表preds用于存储预测结果preds = []# 遍历测试集中的每一个数据和标签for data, label in test_iter:# 使用神经网络(net)对数据进行预测,并使用softmax函数将输出转化为概率分布output = torch.nn.functional.softmax(net(data.to(devices[0])), dim=1)# 将预测结果从GPU中取出,转换为NumPy数组后,添加到preds列表中preds.extend(output.cpu().detach().numpy())# 获取测试数据文件夹中所有文件的id,并按字典顺序排序ids = sorted(os.listdir(os.path.join(file_path, 'train_valid_test', 'test', 'unknown')))# 打开一个新的CSV文件submission.csv用于写入预测结果with open(file_path + 'submission.csv', 'w') as f:# 写入CSV文件的表头,包含'id'和所有类别标签f.write('id,' + ','.join(train_valid_dataset.classes) + '\n')# 遍历文件id和对应的预测结果for i, output in zip(ids, preds):# 写入每个文件的id和对应的预测概率f.write(i.split('.')[0] + ',' + ','.join([str(num) for num in output]) + '\n')
7.定义超参数并保存训练参数
# 定义模型
devices, num_epochs, lr, wd = d2l.try_all_gpus(), 20, 1e-4, 1e-4
lr_period, lr_decay, net = 10, 0.1, get_net(devices)
train(net, train_iter, valid_iter, num_epochs, lr, wd, devices, lr_period, lr_decay)
# num_epochs, lr, wd, lr_period, lr_decay = 20, 1e-4, 1e-4, 4, 0.9 (简略数据集)
# train loss 0.750, train acc 0.814, valid acc 0.646
# 647.4 examples/sec on [device(type='cuda', index=0)]# num_epochs, lr, wd, lr_period, lr_decay = 20, 1e-4, 1e-4, 10, 0.1 (原数据集)
# train loss 0.863, train acc 0.759, valid acc 0.844
# 830.8 examples/sec on [device(type='cuda', index=0)]
plt.show()net = get_net(devices)
train(net, train_valid_iter, None, num_epochs, lr, wd, devices, lr_period,lr_decay)
# num_epochs, lr, wd, lr_period, lr_decay = 20, 1e-4, 1e-4, 4, 0.9 (简略数据集)
# train loss 0.721, train acc 0.815
# 704.9 examples/sec on [device(type='cuda', index=0)]# num_epochs, lr, wd, lr_period, lr_decay = 20, 1e-4, 1e-4, 10, 0.1 (原数据集)
# train loss 0.865, train acc 0.758
# 845.4 examples/sec on [device(type='cuda', index=0)]plt.show()
# 保存模型参数
file_path_module = '../limuPytorch/module/'
torch.save(net.state_dict(), file_path_module + 'imageNet_Dogs.params')

简略数据集:
在这里插入图片描述
在这里插入图片描述

原始数据集:
在这里插入图片描述
在这里插入图片描述

8.预测提交kaggle
predict(file_path_module)

在这里插入图片描述

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.mzph.cn/bicheng/36024.shtml

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

IaaS,PaaS,SaaS理解

目前主流的IaaS&#xff0c;PaaS&#xff0c;SaaS产品 一、简述应用方案 这里借用汽车的例子对IaaS、PaaS、SaaS的解释进一步阐述三者的区别。 假设你需要出去外出使用交通工具&#xff0c;我们有四种的方案&#xff1a; On-premise&#xff08;本地部署服务&#xff09; 自己…

【AI绘画】关于AI绘画做副业,你需要知道的事

前言 AI绘画是一种新兴的艺术形式&#xff0c;它利用人工智能技术来创造出各种各样的艺术作品。随着人工智能技术的不断发展&#xff0c;AI绘画已经成为了一种非常有前途的副业&#xff0c;可以帮助人们赚取额外的收入。下面是一些利用AI绘画副业方法。 1、利用AI绘画技术创作…

Java基础知识-线程

Java基础知识-线程 1、在 Java 中要想实现多线程代码有几种手段&#xff1f; 1. 一种是继承 Thread 类 2. 另一种就是实现 Runnable 接口 3. 最后一种就是实现 Callable 接口 4. 第四种也是实现 callable 接口&#xff0c;只不过有返回值而已 2、Thread 类中的 start() 和 …

JAVA课设必备环境配置 教程 JDK Tomcat配置 IDEA开发环境配置 项目部署参考视频 若依框架 链接数据库格式注意事项

JAVA环境配置 https://blog.csdn.net/xhmico/article/details/122390181 JAVA环境配置 前置条件&#xff1a;JDK安装 在开始配置Java环境之前&#xff0c;确保已经下载并安装了Java Development Kit (JDK)。JDK包含了Java编译器、Java虚拟机&#xff08;JVM&#xff09;以及…

微信公众号写作时必备的AI提示词(也称为指令或Prompt)

猫头虎 &#x1f42f; 微信公众号写作时必备的AI提示词&#xff08;也称为指令或Prompt&#xff09; &#x1f389; 大家好&#xff0c;我是猫头虎&#xff0c;科技自媒体博主。今天&#xff0c;我们来聊聊如何利用AI提示词&#xff0c;打造出爆款的微信公众号文章。&#x1…

Win10扩充C盘(把其他盘存储空间分给C盘)

C盘虽然没有安装任何软件&#xff0c;但无奈安装某些软件&#xff08;例如VS&#xff0c;QuarC等&#xff09;总会占用C盘容量&#xff0c;且C盘内存很小&#xff08;只有60G左右&#xff09;&#xff0c;看着D盘的三四十空闲内存&#xff0c;决定把D盘内存分给C盘30G&#xff…

css持续学习

一、样式层叠 当一个css样式发生冲突时&#xff0c;比如多处给一个字体设置了不同的颜色&#xff0c;这个时候就需要样式层叠了&#xff0c;它会进行三种比较 比较重要性 重要性从高到低&#xff1a; 1.带有 important 的作者样式&#xff08;作者样式就是开发者写的样式&…

【Red Hat 7.9---详细安装Oracle 11g】---图形化界面方式

原文&#xff1a;https://blog.csdn.net/qq_41840843/article/details/131198718?spm1001.2014.3001.5501 &#x1f53b; 一、安装前规划 规划项(本环境)描述操作系统版本Red Hat Enterprise Linux Server release 7.9 (Maipo)主机名db-oracle数据库版本Oracle 11gIp规划192.…

【毛毛虫案例-重力 Objective-C语言】

一、接下来,我们给这个毛毛虫,添加一下重力 1.把我们之前的代码,复制粘贴一份儿,改个名字,叫做:17-毛毛虫案例-重力, 重力的话,实际上,就比较简单了啊,那我们重力的话,去添加的时候,我也要在外面,去添加, 重力的话,叫做啥,UIGravityBehavior,啊, UIGravity…

Thinkphp/Laravel高校竞赛管理系统的设计与实现_9pi7u

高校竞赛管理&#xff0c;其工作流程繁杂、多样、管理复杂与设备维护繁琐。而计算机已完全能够胜任高校竞赛管理工作&#xff0c;而且更加准确、方便、快捷、高效、清晰、透明&#xff0c;它完全可以克服以上所述的不足之处。这将给查询信息和管理带来很大的方便&#xff0c;从…

时序约束(一):时钟的约束

目录 一、时钟约束的目的 二、约束工程项目 三、主时钟和生成时钟 四、主时钟约束 五、生成钟约束 一、时钟约束的目的 之前的文章对时序分析的基本原理做了介绍&#xff0c;我们会发现时序分析离不开时钟信号。对于时序分析工具来说同样如此&#xff0c;分析工具需要我…

【漏洞复现】用友GRP-U8——SQL注入

声明&#xff1a;本文档或演示材料仅供教育和教学目的使用&#xff0c;任何个人或组织使用本文档中的信息进行非法活动&#xff0c;均与本文档的作者或发布者无关。 文章目录 漏洞描述漏洞复现测试工具 漏洞描述 用友GRP-U8是一款企业管理软件&#xff0c;其系统dialog_moreUs…

财务RPA案例研究——分析成功的财务RPA实施案例

现代社会正加速向数字时代转型&#xff0c;数字技术以崭新的模式全面融入各行业领域。为顺应新一轮科技革命和产业变革趋势&#xff0c;越来越多的企业不断深化应用大数据、云计算、人工智能等新一代信息技术&#xff0c;积极迎接数字化转型&#xff0c;而RPA技术由于能够以自动…

常用组件详解(二):torch.nn.Flatten、torch.flatten()

文章目录 torch.nn.Flattentorch.flatten() 官方API文档&#xff1a;点击跳转。torch.nn.Flatten是Pytorch提供的类&#xff0c;常用于将输入数据进行展平&#xff0c;而torch.flatten()函数与之功能相同。 torch.nn.Flatten 类初始化方式&#xff1a; torch.nn.Flatten(star…

算法基础详解

大O记法 为了统一描述&#xff0c;大O不关注算法所用的时间&#xff0c;只关注其所用的步数。 比如数组不论多大&#xff0c;读取都只需1步。用大O记法来表示&#xff0c;就是&#xff1a;O(1)很多人将其读作“大O1”&#xff0c;也有些人读成“1数量级”。一般读成“O1”。虽…

友力科技广州数据中心搬迁

搬迁工作内容 1.搬迁技术工作 1)确定机房搬迁的负责人以及负责人的联系方式&#xff0c;保证在搬迁的过程中统一指挥管理。 2)确定服务器的数量&#xff0c;服务器的型号&#xff0c;服务器的配置等&#xff0c;如有需要&#xff0c;联系相关服务器的供货商或者厂家提供技术支持…

【极速入门版】编程小白也能轻松上手Comate AI编程插件

文章目录 概念使用错误检测与修复能力API生成代码生成json格式做开发测试 在目前的百模大战中&#xff0c;AI编程助手是程序员必不可少的东西&#xff0c;市面上琳琅满目的产品有没有好用一点的&#xff0c;方便一点的呢&#xff1f;今天工程师令狐向大家介绍一款极易入门的国产…

mysql中in参数过多优化

优化方式概述 未优化前 SELECT * FROM rb_product rb where sku in(1022044,1009786)方案2示例 public static void main(String[] args) {//往list里面设置3000个值List<String> list new ArrayList<>();for (int i 0; i < 3000; i) {list.add(""…

python-docx 获取页面大小、设置页面大小(纸张大小)

本文目录 前言一、docx纸张大小介绍1、document.xml① 关于 document.xml 的一些知识点② 纸张大小在哪里③ 纸张大小都有啥④ EMU对应的尺寸列表二、获取docx纸张大小1、完整代码2、运行效果图三、python为docx设置纸张大小1、完整代码2、效果图前言 今天的这边文章,我们来说…

项目实训-vue(八)

项目实训-vue&#xff08;八&#xff09; 文章目录 项目实训-vue&#xff08;八&#xff09;1.概述2.医院动态图像轮播3.页面背景板4.总结 1.概述 除了系统首页的轮播图展示之外&#xff0c;还需要在医院的首页展示医院动态部分的信息&#xff0c;展示医院动态是为了确保患者、…