使用PyTorch实现图像增广与模型训练实战

本文通过完整代码示例演示如何利用PyTorch和torchvision实现常用图像增广方法,并在CIFAR-10数据集上训练ResNet-18模型。我们将从基础图像变换到复杂数据增强策略逐步讲解,最终实现一个完整的训练流程。


一、图像增广基础操作

1.1 准备工作

#matplotlib inline
import torch
import torchvision
from torch import nn
from d2l import torch as d2ld2l.set_figsize()
img = d2l.Image.open('/workspace/data/cat.png')
d2l.plt.imshow(img)

1.2 图像变换工具函数

def apply(img, aug, num_rows=2, num_cols=4, scale=1.5, titles=None):y = [aug(img) for _ in range(num_rows*num_cols)]d2l.show_images(y, num_rows, num_cols, titles, scale)

二、常用图像增广方法

2.1 水平/垂直翻转

# 水平翻转
apply(img, torchvision.transforms.RandomHorizontalFlip())# 垂直翻转
apply(img, torchvision.transforms.RandomVerticalFlip())

2.2 随机裁剪

shape_aug = torchvision.transforms.RandomResizedCrop((200,200), scale=(0.1,1), ratio=(0.5,2))
apply(img, shape_aug)

2.3 颜色调整

color_aug = torchvision.transforms.ColorJitter(brightness=0.5, contrast=0.2, saturation=0.3, hue=0.5)
apply(img, color_aug)

2.4 组合增广策略

augs = torchvision.transforms.Compose([torchvision.transforms.RandomHorizontalFlip(),color_aug,shape_aug
])
apply(img, augs)

三、CIFAR-10数据增强实战

3.1 数据加载与可视化

all_images = torchvision.datasets.CIFAR10(train=True, root='/workspace/data', download=True)
d2l.show_images([all_images[i][0] for i in range(32)], 4, 8, scale=0.8)

3.2 数据预处理配置

train_augs = torchvision.transforms.Compose([torchvision.transforms.RandomHorizontalFlip(),torchvision.transforms.ToTensor()
])test_augs = torchvision.transforms.ToTensor()

3.3 数据加载函数

def load_cifar10(is_train, augs, batch_size):dataset = torchvision.datasets.CIFAR10(root='../data', train=is_train,transform=augs, download=True)return torch.utils.data.DataLoader(dataset, batch_size=batch_size,shuffle=is_train, num_workers=4)

四、模型训练实现

4.1 训练核心函数

def train_batch_ch13(net, X, y, loss, trainer, devices):if isinstance(X, list):X = [x.to(devices[0]) for x in X]else:X = X.to(devices[0])y = y.to(devices[0])net.train()trainer.zero_grad()pred = net(X)l = loss(pred, y)l.sum().backward()trainer.step()train_loss_sum = l.sum()train_acc_sum = d2l.accuracy(pred, y)return train_loss_sum, train_acc_sum

4.2 模型初始化

batch_size = 1024
devices = d2l.try_all_gpus()
net = d2l.resnet18(10, 3)def init_weights(m):if type(m) in [nn.Linear, nn.Conv2d]:nn.init.xavier_uniform_(m.weight)net.apply(init_weights)

4.3 训练入口函数

def train_with_data_aug(train_augs, test_augs, net, lr=0.001):train_iter = load_cifar10(True, train_augs, batch_size)test_iter = load_cifar10(False, test_augs, batch_size)loss = nn.CrossEntropyLoss(reduction='none')optimizer = torch.optim.Adam(net.parameters(), lr=lr)d2l.train_ch13(net, train_iter, test_iter, loss, optimizer, 10, devices)# 启动训练
train_with_data_aug(train_augs, test_augs, net)

五、训练结果分析

执行训练后可以看到类似如下输出:

train loss 0.018, train acc 0.895
test acc 0.856

典型训练过程特征:

  1. 训练损失持续下降

  2. 验证准确率稳步提升

  3. 最终测试准确率可达85%以上


六、关键知识点总结

  1. 图像增广作用:通过随机变换增加数据多样性,提升模型泛化能力

  2. 组合策略:合理组合几何变换与颜色变换可以达到最佳效果

  3. 训练技巧

    • 使用Xavier初始化保证参数合理分布

    • Adam优化器自动调整学习率

    • 多GPU并行加速训练


七、扩展改进方向

1.尝试更多增广组合:

advanced_augs = torchvision.transforms.Compose([torchvision.transforms.RandomRotation(15),torchvision.transforms.RandomPerspective(),torchvision.transforms.RandomGrayscale(p=0.1)
])

2.调整网络结构:

net = d2l.resnet50(10, 3)  # 使用更深层的ResNet-50

3.优化参数:

optimizer = torch.optim.SGD(net.parameters(), lr=0.01, momentum=0.9)

 完整代码已通过测试,可直接复制到Jupyter Notebook中运行。实际效果可能因硬件配置有所差异,建议使用GPU环境进行训练。如果遇到数据集下载问题,请检查root参数指定的路径是否正确。

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.mzph.cn/diannao/80347.shtml

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

解决Mac 安装 PyICU 依赖失败

失败日志: 解决办法 1、使用 homebrew 安装相关依赖 brew install icu4c 安装完成后,设置环境变量 echo export PATH"/opt/homebrew/opt/icu4c77/bin:$PATH" >> ~/.zshrcecho export PATH"/opt/homebrew/opt/icu4c77/sbin:$PATH…

Springboot后端查询参数接收

1.实现方式 假设前端发送的接口: /users?nameJohn&age30 后端怎么接收里面的name和age呢?以及再发别的参数后端怎么接收呢? 1.比较简单的方式 当控制器方法的参数类型是简单类型(如 String、Integer、Long 等&#xff09…

桌面应用中VUE使用新浏览器窗口打开页面

1、浏览器应用忽略此方式,可任意方式打开。针对桌面应用设置 newWindowClick(){try {this.fileUrl "";this.params.year ""this.params.date ""axios({method: post,url: /url/pdf/preview,data: this.params,}).then(res> {t…

华为手机怎么进行音频降噪?音频降噪技巧分享:提升听觉体验

在当今数字化时代,音频质量对于提升用户体验至关重要,无论是在通话、视频录制还是音频文件播放中,清晰的音频都能带来更佳的听觉享受。 而华为手机凭借其强大的音频处理技术,为用户提供了多种音频降噪功能,帮助用户在…

【数据可视化-22】脱发因素探索的可视化分析

🧑 博主简介:曾任某智慧城市类企业算法总监,目前在美国市场的物流公司从事高级算法工程师一职,深耕人工智能领域,精通python数据挖掘、可视化、机器学习等,发表过AI相关的专利并多次在AI类比赛中获奖。CSDN人工智能领域的优质创作者,提供AI相关的技术咨询、项目开发和个…

青少年编程与数学 02-018 C++数据结构与算法 06课题、树

青少年编程与数学 02-018 C数据结构与算法 06课题、树 一、树(Tree)1. 树的定义2. 树的基本术语3. 常见的树类型4. 树的主要操作5. 树的应用 二、二叉树(Binary Tree)1. 二叉树的定义2. 二叉树的基本术语3. 二叉树的常见类型4. 二叉树的主要操作5. 二叉树的实现代码说明输出示例…

【论文阅读】Visual Instruction Tuning

文章目录 导言1、论文简介2、论文主要方法3、论文针对的问题4、论文创新点总结 导言 本论文介绍了一个新兴的多模态模型——LLaVA(Large Language and Vision Assistant),旨在通过指令调优提升大型语言模型(LLM)在视觉…

【学习笔记】Cadence电子设计全流程(三)Capture CIS 原理图绘制(下)

【学习笔记】Cadence电子设计全流程(三)Capture CIS 原理图绘制(下) 3.16 原理图中元件的编辑与更新3.17 原理图元件跳转与查找3.18 原理图常见错误设置于编译检查3.19 低版本原理图文件输出3.20 原理图文件的锁定与解锁3.21 Orca…

js使用IntersectionObserver实现目标元素可见度的交互

文章目录 1、前言2、代码实现3、使用场景4、兼容性5、成熟的Hooks推荐 1、前言 IntersectionObserver 是浏览器原生提供的一个Api。可以"观察"我们的元素是否可见,原理是判断目标元素与可见区域的交叉比例,所以也被称为"交叉观察器"…

linux 中断子系统 层级中断编程

虚拟中断控制器代码&#xff1a; #include<linux/kernel.h> #include<linux/module.h> #include<linux/clk.h> #include<linux/err.h> #include<linux/init.h> #include<linux/interrupt.h> #include<linux/io.h> #include<linu…

虾皮(Shopee)商品详情 API 接口概述及 JSON 数据返回参考

前言 一、接口概述 Shopee 商品详情 API 接口是 Shopee 平台为开发者提供的&#xff0c;用于获取商品详细信息的接口服务。通过该接口&#xff0c;开发者可以获取商品的标题、价格、库存、描述、图片、规格参数、销量、评价等详细信息。这些数据为电商数据分析、商品比价工具…

three.js中的instancedMesh类优化渲染多个同网格材质的模型

three.js小白的学习之路。 在上上一篇博客中&#xff0c;简单验证了一下three.js中的网格共享。写的时候就有一些想法&#xff0c;如果说某个场景中有一万棵树&#xff0c;这些树共享一个geometry和material&#xff0c;有没有好的办法将其进行一定程度上的渲染优化&#xff0…

MySQL-自定义函数

自定义函数 函数的作用 mysql数据库中已经提供了内置的函数&#xff0c;比如&#xff1a;sum&#xff0c;avg&#xff0c;concat等等&#xff0c;方便我们日常的使用&#xff0c;当需要时mysql支持定义自定义的函数&#xff0c;方便与我们对于需用复用的功能进行封装。 基本…

ESP32上C语言实现JSON对象的创建和解析

在ESP32上使用C语言实现JSON对象的创建和解析&#xff0c;同样可以借助cJSON库。ESP-IDF&#xff08;Espressif IoT Development Framework&#xff09;本身已经集成了cJSON库&#xff0c;你可以直接使用。以下是详细的步骤和示例代码。 1. 创建一个新的ESP-IDF项目 首先&…

【FAQ】PCoIP 会话后物理工作站本地显示器黑屏

# 问题 工作人员从家里建立了到办公室工作站的 PCoIP 连接&#xff0c;该工作站安装了 HP Anyware Graphics Agent&#xff0c;并且还连接了本地显示器。然后&#xff0c;远程用户决定去办公室进行本地工作&#xff0c;工作站显示器显示黑屏&#xff08;有时没有信号&#xff…

el-table 目录树列表本地实现模糊查询

table目录树结构实现模糊查询 <el-form :model"queryParams" ref"queryForm" size"small" :inline"true" v-show"showSearch"><el-form-item label"名称:" prop"Name"><el-input v-mode…

力扣hot100 LeetCode 热题 100 Java 哈希篇

两数之和 1. 两数之和 - 力扣&#xff08;LeetCode&#xff09; 直接暴力 class Solution {public int[] twoSum(int[] nums, int target) {for(int i0;i<nums.length;i){for(int ji1;j<nums.length;j){long ans nums[i]nums[j];if(ans>target)continue;if(anstarg…

前后端部署

#在学习JavaWeb之后&#xff0c;进行了苍穹外卖的学习。在进行苍穹外卖的部署的时候&#xff0c;作者遇到了下面的问题# 1.前端工程nginx无法启动&#xff1a; 当我双击已经部署好的nginx工程中nginx.exe文件的时候&#xff0c;在服务中&#xff0c;并没有找到ngnix成功运行。…

基于 EFISH-SBC-RK3588 的无人机环境感知与数据采集方案

一、核心硬件架构设计‌ ‌高性能算力引擎&#xff08;RK3588 处理器&#xff09;‌ ‌异构计算架构‌&#xff1a;集成 8 核 CPU&#xff08;4Cortex-A762.4GHz 4Cortex-A551.8GHz&#xff09;&#xff0c;支持动态调频与多任务并行处理&#xff0c;单线程性能较传统四核方案…

什么是Maven

Maven的概念 Maven是一个一键式的自动化的构建工具。Maven 是 Apache 软件基金会组织维护的一款自动化构建工具&#xff0c;专注服务于Java 平台的项目构建和依赖管理。Maven 这个单词的本意是&#xff1a;专家&#xff0c;内行。Maven 是目前最流行的自动化构建工具&#xff0…