[pytorch、学习] - 9.1 图像增广

参考

9.1 图像增广

在5.6节(深度卷积神经网络)里我们提过,大规模数据集是成功应用神经网络的前提。图像增广(image augmentation)技术通过对训练图像做一系列随机改变,来产生相似但又不相同的训练样本,从而扩大训练数据集的规模。图像增广的另一种解释是,随机改变训练样本可以降低模型对某些属性的依赖,从而提高模型的泛化能力。

import time
import torch
from torch import nn, optim
from torch.utils.data import Dataset, DataLoader
import torchvision
from PIL import Imageimport sys
sys.path.append("..")
import d2lzh_pytorch as d2ldevice = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

9.1.1 常用的图像增广方法

d2l.set_figsize()
img = Image.open('../img/cat2.jpg')
d2l.plt.imshow(img)

在这里插入图片描述
下面定义绘图函数show_images

# 传入图片, 行、列 和规模
def show_images(imgs, num_rows, num_cols, scale=2):figsize = (num_cols * scale, num_rows * scale)_, axes = d2l.plt.subplots(num_rows, num_cols, figsize=figsize)for i in range(num_rows):for j in range(num_cols):axes[i][j].imshow(imgs[i * num_cols + j])axes[i][j].axes.get_xaxis().set_visible(False)axes[i][j].axes.get_yaxis().set_visible(False)return axes

大部分图像增广都有一定的随机性。为了方便观察图像增广效果,接下来我们定义一个辅助函数apply。这个函数对输入图像img多次运行图像增广方法aug并展示所有的结果。

def apply(img, aug, num_rows=2, num_cols=4, scale=1.5):Y = [aug(img) for  _ in range(num_rows * num_cols)]show_images(Y, num_rows, num_cols, scale)

9.1.1.1 翻转和裁剪

左右翻转图像通常不改变物体的类别。它是最早也是最广泛使用的一种图像增广方法。下面我们通过torchvision.transforms模块创建RandomHorizontalFlip实例来实现一半概率的图像水平(左右)翻转。

apply(img, torchvision.transforms.RandomHorizontalFlip(),num_rows= 2, scale=3)

在这里插入图片描述
上下翻转不如左右翻转通用。但是至少对于样例图像,上下翻转不会造成识别障碍。下面我们创建RandomVerticalFllip实例来实现一半概率的图像垂直(上下翻转)

apply(img, torchvision.transforms.RandomVerticalFlip(), scale=2.7)

在这里插入图片描述
在我们使用的样例图像里,猫在图像正中间,但一般情况下可能不是这样。在5.4节(池化层)里我们解释了池化层能降低卷积层对目标位置的敏感度。除此之外,我们还可以通过对图像随机裁剪来让物体以不同的比例出现在图像的不同位置,这样能够降低模型对目标位置的敏感性。

在下面的代码里,我们每次随机裁剪出一块面积为原面积10% ~ 100%的区域,且该区域的宽高和高之比随机取自 0.5 ~ 2, 然后再将该区域的宽和高分别缩放到200像素。若无特殊说明, 本节中 a 和 b之间的随机数指的是从区间[a,b]中随机均匀采样所得的连续值.

shape_aug = torchvision.transforms.RandomResizedCrop(200, scale=(0.1, 1), ratio = (0.5, 2))
apply(img, shape_aug, scale =3)

在这里插入图片描述

9.1.1.2 变化颜色

另一类增广方法是变化颜色。我们可以从4个方面改变图像的颜色: 亮度(brightness)、对比度(contrast)、饱和度(saturation)和色调(hue)。再下面的例子里,我们将图像的亮度随机变化为原亮度的50% (1 - 0.5) ~ 150% (1 + 0.5)

apply(img, torchvision.transforms.ColorJitter(brightness = 0.5), scale = 2.5)

在这里插入图片描述
我们也可以随机色调

apply(img, torchvision.transforms.ColorJitter(hue = 0.5), scale = 3)

在这里插入图片描述
类似地,我们也可以随机变化图像的对比度。

apply(img, torchvision.transforms.ColorJitter(contrast = 0.5), scale = 3)

在这里插入图片描述
我们也可以同时设置如何随机变化图像的亮度(brightness)、对比度(contrast)、饱和度(saturation)和色调(hue)。

color_aug = torchvision.transforms.ColorJitter(brightness=0.5, contrast=0.5, saturation=0.5, hue = 0.5
)apply(img, color_aug, num_cols=3, num_rows=2, scale =2.7)

在这里插入图片描述

9.1.1.3 叠加多个图像增广方法

实际应用中我们将会将多个图像增广方法叠加使用。我们可以通过Compose实例将上面定义的多个图像增广方法叠加起来,再应用到每张图像上mm

augs = torchvision.transforms.Compose([torchvision.transforms.RandomHorizontalFlip(), color_aug, shape_aug
])apply(img, augs)

在这里插入图片描述

9.1.2 使用图像增广训练模型

下面我们来看一个将图像增广应用在实际训练中的例子。这里我们使用CIFAR-10数据集,而不是之前我们一直使用的Fashion-MNIST数据集。这是因为Fashion-MNIST数据集中物体的位置和尺寸都已经经过归一化处理,而CIFAR-10数据集中物体的颜色和大小更加显著。下面展示了CIFAR-10数据集中前32张训练图像。

all_images = torchvision.datasets.CIFAR10(train= True, root="~/Datasets/CIFAR", download=True)
# all_images的每一个元素都是(image, label)

在这里插入图片描述

注: 此处根据下载的位置,用迅雷下载会比较快

show_images([all_images[i][0] for i in range(8)], 2, 4, scale=3)

在这里插入图片描述
为了在预测时得到确定的结果,我们通常只将图像增广应用在训练样本上,而不在预测时使用含随机操作的图像增广。在这里我们只使用最简单的随机左右翻转。此外,我们使用ToTensor将小批量图像转成PyTorch需要的格式,即形状为(batch_size, channels, height, width)、值域在0~1之间且类型为32位浮点数

flip_aug = torchvision.transforms.Compose([torchvision.transforms.RandomHorizontalFlip(),torchvision.transforms.ToTensor()
])no_aug = torchvision.transforms.Compose([torchvision.transforms.ToTensor()
])

接下来,我们定义一个辅助函数来方便读取图像并应用图像增广.

num_workers = 0 if sys.platform.startswith('win32') else 4def load_cifar10(is_train, augs, batch_size, root = "~/Datasets/CIFAR"):datasets = torchvision.datasets.CIFAR10(root=root, train=is_train, transform=augs, download =True)return DataLoader(datasets, batch_size=batch_size, shuffle=is_train, num_workers=num_workers)

使用图像增广训练模型

# 定义train函数使用GPU训练并评价模型def train(train_iter, test_iter, net, loss, optimizer, device, num_epochs):net = net.to(device)print("training on", device)batch_count =0for epoch in range(num_epochs):train_l_sum, train_acc_sum, n, start = 0.0, 0.0, 0, time.time()for X, y in train_iter:X = X.to(device)y = y.to(device)y_hat = net(X)l = loss(y_hat, y)optimizer.zero_grad()l.backward()optimizer.step()train_l_sum += l.cpu().item()train_acc_sum += (y_hat.argmax(dim=1) == y).sum().cpu().item()n += y.shape[0]batch_count += 1test_acc = d2l.evaluate_accuracy(test_iter, net)print('epoch %d, loss %.4f, train acc %.3f, test acc %.3f, time %.1f sec'% (epoch + 1, train_l_sum / batch_count, train_acc_sum / n, test_acc, time.time() - start))

然后可以定义train_with_data_aug函数使用图像增广来训练模型了。该函数使用Adam算法作为训练使用的优化算法,然后将图像增广应用于训练数据集之上,最好调用刚才定义的train函数训练并评价模型。

def train_with_data_aug(train_augs, test_augs, lr = 0.001):batch_size, net = 256, d2l.resnet18(10)optimizer = torch.optim.Adam(net.parameters(), lr=lr)loss = torch.nn.CrossEntropyLoss()train_iter = load_cifar10(True, train_augs, batch_size)test_iter = load_cifar10(False, test_augs, batch_size)train(train_iter, test_iter, net, loss, optimizer, device, num_epochs=20)

下面使用随机左右翻转的图像增广来训练模型

train_with_data_aug(flip_aug, no_aug)

在这里插入图片描述

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.mzph.cn/news/250113.shtml

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

mysql绿色版安装

导读:MySQL是一款关系型数据库产品,官网给出了两种安装包格式:MSI和ZIP。MSI格式是图形界面安装方式,基本只需下一步即可,这篇文章主要介绍ZIP格式的安装过程。ZIP Archive版是免安装的。只要解压就行了。 一、首先下…

[pytorch、学习] - 9.2 微调

参考 9.2 微调 在前面得一些章节中,我们介绍了如何在只有6万张图像的Fashion-MNIST训练数据集上训练模型。我们还描述了学术界当下使用最广泛规模图像数据集ImageNet,它有超过1000万的图像和1000类的物体。然而,我们平常接触到数据集的规模通常在这两者之间。 假设我们想从图…

关于mac机抓包的几点基础知识

1. 我使用的抓包工具为WireShark,以下操作按我当前的版本(Version 2.6.1)做的,以前的版本或者以后的版本可能有稍微的区别。 2. 将mac设置为热点:打开系统偏好设置,点击共享: 然后点击WIFI选项,设置WIFI名…

SpringBoot启动如何加载application.yml配置文件

一、前言 在spring时代配置文件的加载都是通过web.xml配置加载的(Servlet3.0之前)&#xff0c;可能配置方式有所不同&#xff0c;但是大多数都是通过指定路径的文件名的形式去告诉spring该加载哪个文件&#xff1b; <context-param><param-name>contextConfigLocat…

阿里云服务器端口开放对外访问权限

登陆阿里云管理控制台 点击自己的实例 点击安全组配置 点击配置规则 点击添加安全组规则 配置出入放心&#xff0c;和开放的端口号&#xff0c;以及那些网段可以访问&#xff0c;这里设置所有网段都可以访问 转自&#xff1a;https://jingyan.baidu.com/article/95c9d20d624d1e…

PageHelper工作原理

数据分页功能是我们软件系统中必备的功能&#xff0c;在持久层使用mybatis的情况下&#xff0c;pageHelper来实现后台分页则是我们常用的一个选择&#xff0c;所以本文专门类介绍下。 PageHelper原理 相关依赖 <dependency><groupId>org.mybatis</groupId>&…

10-多写一个@Autowired导致程序崩了

再是javaweb实验六中&#xff0c;是让我们改代码&#xff0c;让它跑起来&#xff0c;结果我少注释了一个&#xff0c;导致一直报错&#xff0c;检查许久没有找到&#xff0c;最后通过代码替换逐步查找&#xff0c;才发现问题。 转载于:https://www.cnblogs.com/zhumengdexiaoba…

springboot---整合redis

pom.xml新增 <dependency><groupId>org.springframework.boot</groupId><artifactId>spring-boot-starter-data-redis</artifactId></dependency>代码结构如下 其中redis.yml是连接redis的配置文件&#xff0c;RedisConfig.java是java配置…

[Head First Java] - Swing做一个简单的客户端

参考 - P487 1. vscode配置java的格式 点击左下角齿轮 -> 设置 -> 打开任意的setting.json输入如下代码 {code-runner.executorMap": {"java": "cd $dir && javac -encoding utf-8 $fileName && java $fileNameWithoutExt"},…

计算机网络知识总结

一 OSI与TCP/IP各层的结构与功能&#xff0c;都有哪些协议 OSI的七层体系结构概念清楚&#xff0c;理论也很完整&#xff0c;但是它比较复杂而且不实用。在这里顺带提一下之前一直被一些大公司甚至一些国家政府支持的OSI失败的原因&#xff1a; OSI的专家缺乏实际经验&#xff…

[Head First Java] - 给线程命名

参考 - P503 public class RunThreads implements Runnable {public static void main (String[] args) {RunThreads runner new RunThreads();Thread alpha new Thread(runner);Thread beta new Thread(runner);alpha.setName("Alpha thread");beta.setName(&qu…

快速排序的C++版

int Partition(int a[], int low, int high) {int x a[high];//将输入数组的最后一个数作为主元&#xff0c;用它来对数组进行划分int i low - 1;//i是最后一个小于主元的数的下标for (int j low; j < high; j)//遍历下标由low到high-1的数{if (a[j] < x)//如果数小于…

asp.net中提交表单数据时提示从客户端(。。。)中检测到有潜在危险的 Request.Form 值...

看到这个图是不是很亲切熟悉哈&#xff0c;做过。net的肯定都见过哈 已经 将近4年没碰。net了&#xff0c;今天正好朋友的程序有几个bug,让我帮忙修复下&#xff0c;于是我就抱着试试看的心情改了改&#xff0c;改到最后一个问题的时候也就是上面的这个问题&#xff0c;我一看&…

Shiro表结构设计

表设计 开发用户-角色-权限管理系统&#xff0c;首先我们需要知道用户-角色-权限管理系统的表结构设计。 在用户-角色-权限管理系统找那个一般会涉及5张表&#xff0c;分别为&#xff1a; 1.sys_users用户表 2.sys_roles角色表 3.sys_permissions权限表&#xff08;或资源表&…

[Java核心技术(卷I)] - 简易的日历

参考 - P102~P103 1. 目标 生成一个日历,格式如下图所示。 ps: 当前的天数需要标记为* 2. 核心 对日历的变量 import java.time.*; public class CalendarTest{public static void main(String[] args) {LocalDate date LocalDate.now(); // 获取当前日期int month date…

个人作业——福大微信公众号使用评测

案例分析&#xff1a;在福州大学公众号上&#xff0c;我们可以即时使用手机关注福大新闻&#xff0c;查看自身课表、成绩等。公众号可能存在一些小bug影响同学们的用户体验。本次作业中&#xff0c;作为一个用户——福大的学生&#xff0c;将切身体验该公众号的功能&#xff0c…

在winform中使用wpf窗体

在winform项目&#xff0c;通过引用dll可以添加WPF窗体&#xff0c;如下 但是如果直接在winform的项目中添加wpf窗体还是有部分问题&#xff0c;图片的显示。 直接在XAML界面中用Source属性设置图片会出现错误。必须通过后台代码的方式来实现。 image1.Source GetImageIcon(gl…

shiro---注解

RequiresAuthentication 验证用户是否登录&#xff0c;等同于方法subject.isAuthenticated() 结果为true时。 RequiresUser 验证用户是否被记忆&#xff0c;user有两种含义&#xff1a; 一种是成功登录的&#xff08;subject.isAuthenticated() 结果为true&#xff09;&…

[Java核心技术(卷I)] - Java中的参数能做什么和不能做什么

1. 参考 - P123 ~ P126 2. 你将学到 Java中对方法参数能做什么和不能做什么 方法不能修改基本数据类型的参数(数值型或布尔型)方法可以改变对象参数的状态方法不能让一个对象参数引用一个新的对象 3. 代码证明 public class ParamTest {public static void main(String[] ar…

软件构造 第五章第一节 可复用性的度量、形态和外部观察

第五章第一节 可复用性的度量、形态和外部观察 面向复用编程(programming for reuse)&#xff1a;开发出可复用的软件 基于复用编程(programming with reuse)&#xff1a;利用已有的可复用软件搭建应用系统 代码复用的类型&#xff1a; 白盒复用&#xff1a;源代码可见&#x…