(动手学习深度学习)第13章 实战kaggle竞赛:狗的品种识别

文章目录

      • 1. 导入相关库
      • 2. 加载数据集
      • 3. 整理数据集
      • 4. 图像增广
      • 5. 读取数据
      • 6. 微调预训练模型
      • 7. 定义损失函数和评价损失函数
      • 9. 训练模型

1. 导入相关库

import os
import torch
import torchvision
from torch import nn
from d2l import torch as d2l

2. 加载数据集

- 该数据集是完整数据集的小规模样本
# 下载数据集
d2l.DATA_HUB['dog_tiny'] = (d2l.DATA_URL + 'kaggle_dog_tiny.zip','0cb91d09b814ecdc07b50f31f8dcad3e81d6a86d')# 如果使用Kaggle比赛的完整数据集,请将下面的变量更改为False
demo = True
if demo:data_dir = d2l.download_extract('dog_tiny')
else:data_dir = os.path.join('..', 'data', 'dog-breed-identification')

3. 整理数据集

def reorg_dog_data(data_dir, valid_ratio):labels = d2l.read_csv_labels(os.path.join(data_dir, 'labels.csv'))d2l.reorg_train_valid(data_dir, labels, valid_ratio)d2l.reorg_test(data_dir)batch_size = 32 if demo else 128
valid_ratio = 0.1
reorg_dog_data(data_dir, valid_ratio)

4. 图像增广

transform_train = torchvision.transforms.Compose([torchvision.transforms.RandomResizedCrop(224, scale=(0.08, 1.0), ratio=(3.0/4.0,4.0/3.0)),torchvision.transforms.RandomHorizontalFlip(),torchvision.transforms.ColorJitter(brightness=0.4, contrast=0.4, saturation=0.4),torchvision.transforms.ToTensor(),torchvision.transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
])
transform_test = torchvision.transforms.Compose([torchvision.transforms.Resize(256),torchvision.transforms.CenterCrop(224),torchvision.transforms.ToTensor(),torchvision.transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
])

5. 读取数据

train_ds, train_valid_ds = [torchvision.datasets.ImageFolder(os.path.join(data_dir, 'train_valid_test', folder),transform=transform_train) for folder in ['train', 'train_valid']
]
valid_ds, test_ds = [torchvision.datasets.ImageFolder(os.path.join(data_dir, 'train_valid_test', folder),transform=transform_test) for folder in ['valid', 'test']
]
train_iter, train_valid_iter = [torch.utils.data.DataLoader(dataset, batch_size, shuffle=True, drop_last=True) for dataset in (train_ds, train_valid_ds)
]
valid_iter = torch.utils.data.DataLoader(valid_ds, batch_size, shuffle=False, drop_last=True
)
test_iter = torch.utils.data.DataLoader(test_ds, batch_size, shuffle=False, drop_last=True
)

6. 微调预训练模型

def get_net(devices):finetune_net = nn.Sequential()finetune_net.features = torchvision.models.resnet34(weights=torchvision.models.ResNet34_Weights.IMAGENET1K_V1)# 定义一个新的输出网络,共有120个输出类别finetune_net.output_new = nn.Sequential(nn.Linear(1000, 256),nn.ReLU(),nn.Linear(256, 120))finetune_net = finetune_net.to(devices[0])# 冻结参数for param in finetune_net.features.parameters():param.requires_grad = Falsereturn finetune_net
# 查看网络模型
get_net(devices=d2l.try_all_gpus())

在这里插入图片描述

7. 定义损失函数和评价损失函数

# 定义损失函数
loss = nn.CrossEntropyLoss(reduction='none')def evaluate_loss(data_iter, net, device):l_sum, n = 0.0, 0for features, labels in data_iter:features, labels = features.to(device[0]), labels.to(device[0])outputs = net(features)l = loss(outputs, labels)l_sum += l.sum()n += labels.numel()return (l_sum / n).to('cpu')
  1. 定义训练函数
def train(net, train_iter, valid_iter, num_epochs, lr, wd, devices, lr_period, lr_decay):# 只训练小型定义输出网络net = nn.DataParallel(net, device_ids=devices).to(devices[0])trainer = torch.optim.SGD((param for param in net.parameters() if param.requires_grad),lr=lr, momentum=0.9, weight_decay=wd)scheduler = torch.optim.lr_scheduler.StepLR(trainer, lr_period, lr_decay)num_batches, timer = len(train_iter), d2l.Timer()legend = ['train loss']if valid_iter is not None:legend.append('valid loss')animator = d2l.Animator(xlabel='epoch', xlim=[1, num_epochs], legend=legend)for epoch in range(num_epochs):metric = d2l.Accumulator(2)for i, (features, labels) in enumerate(train_iter):timer.start()features, labels = features.to(devices[0]), labels.to(devices[0])trainer.zero_grad()output = net(features)l = loss(output, labels).sum()l.backward()trainer.step()metric.add(l, labels.shape[0])timer.stop()if (i + 1) % (num_batches // 5) == 0 or i == num_batches - 1:animator.add(epoch + (i + 1) / num_batches, (metric[0] / metric[1], None))measures = f'train loss {metric[0] / metric[1]:.3f}'if valid_iter is not None :valid_loss = evaluate_loss(valid_iter, net, devices)animator.add(epoch + 1, (None, valid_loss.detach().cpu()))scheduler.step()if valid_iter is not None:measures += f', valid loss {valid_loss:.3f}'print(measures + f'\n{metric[1] * num_epochs / timer.sum():.1f}'f'examples/sec on {str(devices)}')

9. 训练模型

devices, num_epochs, lr, wd = d2l.try_all_gpus(), 10, 1e-4, 1e-4
lr_period, lr_decay, net, = 2, 0.9, get_net(devices)
import time# 在开头设置开始时间
start = time.perf_counter()  # start = time.clock() python3.8之前可以train(net, train_iter, valid_iter, num_epochs, lr, wd, devices, lr_period, lr_decay)# 在程序运行结束的位置添加结束时间
end = time.perf_counter()  # end = time.clock()  python3.8之前可以# 再将其进行打印,即可显示出程序完成的运行耗时
print(f'运行耗时{(end-start):.4f}')

在这里插入图片描述

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.mzph.cn/news/158843.shtml

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

基于野马算法优化概率神经网络PNN的分类预测 - 附代码

基于野马算法优化概率神经网络PNN的分类预测 - 附代码 文章目录 基于野马算法优化概率神经网络PNN的分类预测 - 附代码1.PNN网络概述2.变压器故障诊街系统相关背景2.1 模型建立 3.基于野马优化的PNN网络5.测试结果6.参考文献7.Matlab代码 摘要:针对PNN神经网络的光滑…

【FPGA】Verilog:计数器 | 异步计数器 | 同步计数器 | 2位二进制计数器的实现 | 4位十进制计数器的实现

目录 Ⅰ. 实践说明 0x00 计数器(Counter) 0x01 异步计数器(Asynchronous Counter)

深度学习中的图像融合:图像融合论文阅读与实战

个人博客:Sekyoro的博客小屋 个人网站:Proanimer的个人网站 abs 介绍图像融合概念,回顾sota模型,其中包括数字摄像图像融合,多模态图像融合, 接着评估一些代表方法 介绍一些常见应用,比如RGBT目标跟踪,…

鸿蒙4.0开发笔记之DevEco Studio启动时不直接打开原项目(二)

1、想要在DevEco Studio启动时不直接打开关闭前的那个项目,可以在设置中进行。 有两个位置可以进入“设置”,一个是左上角的File>Settings,二是右上方的设置图标。 2、进入Settings界面以后,选择Appearance&Behavior下面…

Java Stream中的API你都用过了吗?

公众号「架构成长指南」,专注于生产实践、云原生、分布式系统、大数据技术分享。 在本教程中,您将通过大量示例来学习 Java 8 Stream API。 Java 在 Java 8 中提供了一个新的附加包,称为 java.util.stream。该包由类、接口和枚举组成&#x…

AIGC创作系统ChatGPT网站系统源码,支持最新GPT-4-Turbo模型

一、AI创作系统 SparkAi创作系统是基于OpenAI很火的ChatGPT进行开发的Ai智能问答系统和Midjourney绘画系统,支持OpenAI-GPT全模型国内AI全模型。本期针对源码系统整体测试下来非常完美,可以说SparkAi是目前国内一款的ChatGPT对接OpenAI软件系统。那么如…

万宾科技智能井盖的效果怎么样?

日常出行过程中,人们最不想看到交通拥堵或者道路维修等现象,因为这代表出行受到影响甚至会导致不能按时赴约等。所以城市路面的安全和稳定,是市民朋友非常关心的话题。骑行在路上的时候,如果经过井盖时发出异常声响,骑…

Apache配置文件详解

引言: Apache是一种功能强大的Web服务器软件,通过配置文件可以对其行为进行高度定制。对于初学者来说,理解和正确配置Apache的配置文件是非常重要的。本文将详细解释Apache配置文件的各个方面,并给出一些入门指南,帮助读者快速上手。 1、主配置文件(httpd.conf): 主…

jQuery【菜单功能、淡入淡出轮播图(上)、淡入淡出轮播图(下)、折叠面板】(五)-全面详解(学习总结---从入门到深化)

目录 菜单功能 淡入淡出轮播图(上) 淡入淡出轮播图(下) 折叠面板 菜单功能 <!DOCTYPE html> <html lang"en"> <head><meta charset"UTF-8"><meta http-equiv"X-UA-Compatible" content"IEedge"><…

这是大学生就业网站最基础的布局。

<!DOCTYPE html> <html> <head> <title>大学生就业网站</title> <style> /* Reset default margin and padding */ *, *::before, *::after { margin: 0; padding: 0; box-s…

SpingBoot原理

目录 配置优先级Bean管理 (掌握)Bean的获取 ApplicationContext.getBeanBean的作用域 Scope("prototype") Lazy第三方Bean Bean Configuration SpringBoot底层原理 起步依赖与自动配置(无需手撸但面试高频知识点)自动配置引入第三方依赖常见方案方案1&#xff1a;Com…

Java引用和内部类

引用 引用变量 引用相当于一个 “别名”, 也可以理解成一个指针. 创建一个引用只是相当于创建了一个很小的变量, 这个变量保存了一个整数, 这个整数表示内存中的一个地址. new 出来的数组肯定是在堆上开辟的空间,那么在栈中存放的就是引用,引用存放的的就是一个对象的地址,代表…

格式化名称节点,启动Hadoop

1.循环删除hadoop目录下的tmp文件&#xff0c;记住在hadoop目录下进行 rm tmp -rf 使用上述命令&#xff0c;hadoop目录下为&#xff1a; 2.格式化名称节点 # 格式化名称节点 ./bin/hdfs namenode -format 3.启动所有节点 ./sbin/start-all.sh 效果图&#xff1a; 4.查看节…

HCIP-五、OSPF-1 邻居状态机和 DR 选举

五、OSPF-1 邻居状态机和 DR 选举 实验拓扑实验需求及解法1.如图所示&#xff0c;配置设备 IP 地址。2.配置 OSPF3. 按照以下步骤观察 R1 与 R2 的邻居关系建立过程 实验拓扑 实验需求及解法 通过本次实验&#xff0c;验证 OSPF 邻居状态机变化过程&#xff0c;以及 DR 选举过…

今日定音,博通以610亿美元成功收购VMware | 百能云芯

博通&#xff08;Broadcom&#xff09;日前宣布&#xff0c;已获得中国监管机构的批准&#xff0c;将于今日完成对云计算公司VMware的收购交易。这意味着&#xff0c;610亿美元的收购案正式收关。 据悉&#xff0c;中国市场监管总局在11月21日晚发布了有关附加限制性条件批准博…

Python-大数据分析之常用库

Python-大数据分析之常用库 1. 数据采集与第三方数据接入 1-1. Beautiful Soup ​ Beautiful Soup 是一个用于解析HTML和XML文档的库&#xff0c;非常适用于网页爬虫和数据抓取。可以提取所需信息&#xff0c;无需手动分析网页源代码&#xff0c;简化了从网页中提取数据的过…

排序算法--冒泡排序

实现逻辑 ① 比较相邻的元素。如果第一个比第二个大&#xff0c;就交换他们两个。 ②对每一对相邻元素作同样的工作&#xff0c;从开始第一对到结尾的最后一对。在这一点&#xff0c;最后的元素应该会是最大的数。 ③针对所有的元素重复以上的步骤&#xff0c;除了最后一个。 ④…

centos7 网卡聚合bond0模式配置

提示&#xff1a;文章写完后&#xff0c;目录可以自动生成&#xff0c;如何生成可参考右边的帮助文档 文章目录 一、什么是网卡bond二、网卡bond的模式三、配置bond0 一、什么是网卡bond 所谓bond&#xff0c;就是把多个物理网卡绑定成一个逻辑上的网卡&#xff0c;使用同一个…

威胁攻击层出不穷,信息化负责人该如何增强对抗与防御能力?

11月15日&#xff0c;以“加快推进智慧校园建设 赋能为党育才为党献策”为主题的2023年华东地区党校&#xff08;行政学院&#xff09;信息化和图书馆工作高质量发展专题研讨班顺利举办。 作为国内云原生安全领导厂商&#xff0c;安全狗受邀出席活动。 厦门服云信息科技有限公…

腾讯云服务器99元一年是真的吗?假的!

腾讯云服务器99元一年是真的吗&#xff1f;假的&#xff0c;不用99元&#xff0c;只要88元即可购买一台2核2G3M带宽的轻量应用服务器&#xff0c;99元太多了&#xff0c;88元就够了&#xff0c;腾讯云百科活动 txybk.com/go/txy 活动打开如下图&#xff1a; 腾讯云轻量服务器 腾…