[pytorch、学习] - 9.1 图像增广

参考

9.1 图像增广

在5.6节(深度卷积神经网络)里我们提过,大规模数据集是成功应用神经网络的前提。图像增广(image augmentation)技术通过对训练图像做一系列随机改变,来产生相似但又不相同的训练样本,从而扩大训练数据集的规模。图像增广的另一种解释是,随机改变训练样本可以降低模型对某些属性的依赖,从而提高模型的泛化能力。

import time
import torch
from torch import nn, optim
from torch.utils.data import Dataset, DataLoader
import torchvision
from PIL import Imageimport sys
sys.path.append("..")
import d2lzh_pytorch as d2ldevice = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

9.1.1 常用的图像增广方法

d2l.set_figsize()
img = Image.open('../img/cat2.jpg')
d2l.plt.imshow(img)

在这里插入图片描述
下面定义绘图函数show_images

# 传入图片, 行、列 和规模
def show_images(imgs, num_rows, num_cols, scale=2):figsize = (num_cols * scale, num_rows * scale)_, axes = d2l.plt.subplots(num_rows, num_cols, figsize=figsize)for i in range(num_rows):for j in range(num_cols):axes[i][j].imshow(imgs[i * num_cols + j])axes[i][j].axes.get_xaxis().set_visible(False)axes[i][j].axes.get_yaxis().set_visible(False)return axes

大部分图像增广都有一定的随机性。为了方便观察图像增广效果,接下来我们定义一个辅助函数apply。这个函数对输入图像img多次运行图像增广方法aug并展示所有的结果。

def apply(img, aug, num_rows=2, num_cols=4, scale=1.5):Y = [aug(img) for  _ in range(num_rows * num_cols)]show_images(Y, num_rows, num_cols, scale)

9.1.1.1 翻转和裁剪

左右翻转图像通常不改变物体的类别。它是最早也是最广泛使用的一种图像增广方法。下面我们通过torchvision.transforms模块创建RandomHorizontalFlip实例来实现一半概率的图像水平(左右)翻转。

apply(img, torchvision.transforms.RandomHorizontalFlip(),num_rows= 2, scale=3)

在这里插入图片描述
上下翻转不如左右翻转通用。但是至少对于样例图像,上下翻转不会造成识别障碍。下面我们创建RandomVerticalFllip实例来实现一半概率的图像垂直(上下翻转)

apply(img, torchvision.transforms.RandomVerticalFlip(), scale=2.7)

在这里插入图片描述
在我们使用的样例图像里,猫在图像正中间,但一般情况下可能不是这样。在5.4节(池化层)里我们解释了池化层能降低卷积层对目标位置的敏感度。除此之外,我们还可以通过对图像随机裁剪来让物体以不同的比例出现在图像的不同位置,这样能够降低模型对目标位置的敏感性。

在下面的代码里,我们每次随机裁剪出一块面积为原面积10% ~ 100%的区域,且该区域的宽高和高之比随机取自 0.5 ~ 2, 然后再将该区域的宽和高分别缩放到200像素。若无特殊说明, 本节中 a 和 b之间的随机数指的是从区间[a,b]中随机均匀采样所得的连续值.

shape_aug = torchvision.transforms.RandomResizedCrop(200, scale=(0.1, 1), ratio = (0.5, 2))
apply(img, shape_aug, scale =3)

在这里插入图片描述

9.1.1.2 变化颜色

另一类增广方法是变化颜色。我们可以从4个方面改变图像的颜色: 亮度(brightness)、对比度(contrast)、饱和度(saturation)和色调(hue)。再下面的例子里,我们将图像的亮度随机变化为原亮度的50% (1 - 0.5) ~ 150% (1 + 0.5)

apply(img, torchvision.transforms.ColorJitter(brightness = 0.5), scale = 2.5)

在这里插入图片描述
我们也可以随机色调

apply(img, torchvision.transforms.ColorJitter(hue = 0.5), scale = 3)

在这里插入图片描述
类似地,我们也可以随机变化图像的对比度。

apply(img, torchvision.transforms.ColorJitter(contrast = 0.5), scale = 3)

在这里插入图片描述
我们也可以同时设置如何随机变化图像的亮度(brightness)、对比度(contrast)、饱和度(saturation)和色调(hue)。

color_aug = torchvision.transforms.ColorJitter(brightness=0.5, contrast=0.5, saturation=0.5, hue = 0.5
)apply(img, color_aug, num_cols=3, num_rows=2, scale =2.7)

在这里插入图片描述

9.1.1.3 叠加多个图像增广方法

实际应用中我们将会将多个图像增广方法叠加使用。我们可以通过Compose实例将上面定义的多个图像增广方法叠加起来,再应用到每张图像上mm

augs = torchvision.transforms.Compose([torchvision.transforms.RandomHorizontalFlip(), color_aug, shape_aug
])apply(img, augs)

在这里插入图片描述

9.1.2 使用图像增广训练模型

下面我们来看一个将图像增广应用在实际训练中的例子。这里我们使用CIFAR-10数据集,而不是之前我们一直使用的Fashion-MNIST数据集。这是因为Fashion-MNIST数据集中物体的位置和尺寸都已经经过归一化处理,而CIFAR-10数据集中物体的颜色和大小更加显著。下面展示了CIFAR-10数据集中前32张训练图像。

all_images = torchvision.datasets.CIFAR10(train= True, root="~/Datasets/CIFAR", download=True)
# all_images的每一个元素都是(image, label)

在这里插入图片描述

注: 此处根据下载的位置,用迅雷下载会比较快

show_images([all_images[i][0] for i in range(8)], 2, 4, scale=3)

在这里插入图片描述
为了在预测时得到确定的结果,我们通常只将图像增广应用在训练样本上,而不在预测时使用含随机操作的图像增广。在这里我们只使用最简单的随机左右翻转。此外,我们使用ToTensor将小批量图像转成PyTorch需要的格式,即形状为(batch_size, channels, height, width)、值域在0~1之间且类型为32位浮点数

flip_aug = torchvision.transforms.Compose([torchvision.transforms.RandomHorizontalFlip(),torchvision.transforms.ToTensor()
])no_aug = torchvision.transforms.Compose([torchvision.transforms.ToTensor()
])

接下来,我们定义一个辅助函数来方便读取图像并应用图像增广.

num_workers = 0 if sys.platform.startswith('win32') else 4def load_cifar10(is_train, augs, batch_size, root = "~/Datasets/CIFAR"):datasets = torchvision.datasets.CIFAR10(root=root, train=is_train, transform=augs, download =True)return DataLoader(datasets, batch_size=batch_size, shuffle=is_train, num_workers=num_workers)

使用图像增广训练模型

# 定义train函数使用GPU训练并评价模型def train(train_iter, test_iter, net, loss, optimizer, device, num_epochs):net = net.to(device)print("training on", device)batch_count =0for epoch in range(num_epochs):train_l_sum, train_acc_sum, n, start = 0.0, 0.0, 0, time.time()for X, y in train_iter:X = X.to(device)y = y.to(device)y_hat = net(X)l = loss(y_hat, y)optimizer.zero_grad()l.backward()optimizer.step()train_l_sum += l.cpu().item()train_acc_sum += (y_hat.argmax(dim=1) == y).sum().cpu().item()n += y.shape[0]batch_count += 1test_acc = d2l.evaluate_accuracy(test_iter, net)print('epoch %d, loss %.4f, train acc %.3f, test acc %.3f, time %.1f sec'% (epoch + 1, train_l_sum / batch_count, train_acc_sum / n, test_acc, time.time() - start))

然后可以定义train_with_data_aug函数使用图像增广来训练模型了。该函数使用Adam算法作为训练使用的优化算法,然后将图像增广应用于训练数据集之上,最好调用刚才定义的train函数训练并评价模型。

def train_with_data_aug(train_augs, test_augs, lr = 0.001):batch_size, net = 256, d2l.resnet18(10)optimizer = torch.optim.Adam(net.parameters(), lr=lr)loss = torch.nn.CrossEntropyLoss()train_iter = load_cifar10(True, train_augs, batch_size)test_iter = load_cifar10(False, test_augs, batch_size)train(train_iter, test_iter, net, loss, optimizer, device, num_epochs=20)

下面使用随机左右翻转的图像增广来训练模型

train_with_data_aug(flip_aug, no_aug)

在这里插入图片描述

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.mzph.cn/news/250113.shtml

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

mysql绿色版安装

导读:MySQL是一款关系型数据库产品,官网给出了两种安装包格式:MSI和ZIP。MSI格式是图形界面安装方式,基本只需下一步即可,这篇文章主要介绍ZIP格式的安装过程。ZIP Archive版是免安装的。只要解压就行了。 一、首先下…

在微信浏览器字体被调大导致页面错乱的解决办法

iOS的解决方案是覆盖掉微信的样式: body { /* IOS禁止微信调整字体大小 */-webkit-text-size-adjust: 100% !important; } 安卓的解决方案是通过 WeixinJSBridge 对象将网页的字体大小设置为默认大小,并且重写设置字体大小的方法,让用户不能在…

[pytorch、学习] - 9.2 微调

参考 9.2 微调 在前面得一些章节中,我们介绍了如何在只有6万张图像的Fashion-MNIST训练数据集上训练模型。我们还描述了学术界当下使用最广泛规模图像数据集ImageNet,它有超过1000万的图像和1000类的物体。然而,我们平常接触到数据集的规模通常在这两者之间。 假设我们想从图…

Springboot默认加载application.yml原理

Springboot默认加载application.yml原理以及扩展 SpringApplication.run(…)默认会加载classpath下的application.yml或application.properties配置文件。公司要求搭建的框架默认加载一套默认的配置文件demo.properties,让开发人员实现“零”配置开发,但…

java 集合(Set接口)

Set接口:无序集合,不允许有重复值,允许有null值 存入与取出的顺序有可能不一致 HashSet:具有set集合的基本特性,不允许重复值,允许null值 底层实现是哈希表结构 初始容量为16 保存自定义对象时,保证数据的唯…

关于mac机抓包的几点基础知识

1. 我使用的抓包工具为WireShark,以下操作按我当前的版本(Version 2.6.1)做的,以前的版本或者以后的版本可能有稍微的区别。 2. 将mac设置为热点:打开系统偏好设置,点击共享: 然后点击WIFI选项,设置WIFI名…

SpringBoot启动如何加载application.yml配置文件

一、前言 在spring时代配置文件的加载都是通过web.xml配置加载的(Servlet3.0之前)&#xff0c;可能配置方式有所不同&#xff0c;但是大多数都是通过指定路径的文件名的形式去告诉spring该加载哪个文件&#xff1b; <context-param><param-name>contextConfigLocat…

[github] - git使用小结(分支拉取、版本回退)

1. 首次(fork项目之后) $ git clone [master] $ git branch -a $ git checkout -b [自己的分支名] [远程仓库的分支名]克隆的是主干网络 2. 再次拉取代码 $ git pull [master下选择分支名] [分支名] $ git push origin HEAD:[分支名]拉取首先得进入主仓(不是自己的远程仓)然后…

MYSQL 查看最大连接数和修改最大连接数

MySQL查看最大连接数和修改最大连接数 1、查看最大连接数show variables like %max_connections%;2、修改最大连接数set GLOBAL max_connections 200; 以下的文章主要是向大家介绍的是MySQL最大连接数的修改&#xff0c;我们大家都知道MySQL最大连接数的默认值是100, 这个数值…

阿里云服务器端口开放对外访问权限

登陆阿里云管理控制台 点击自己的实例 点击安全组配置 点击配置规则 点击添加安全组规则 配置出入放心&#xff0c;和开放的端口号&#xff0c;以及那些网段可以访问&#xff0c;这里设置所有网段都可以访问 转自&#xff1a;https://jingyan.baidu.com/article/95c9d20d624d1e…

PageHelper工作原理

数据分页功能是我们软件系统中必备的功能&#xff0c;在持久层使用mybatis的情况下&#xff0c;pageHelper来实现后台分页则是我们常用的一个选择&#xff0c;所以本文专门类介绍下。 PageHelper原理 相关依赖 <dependency><groupId>org.mybatis</groupId>&…

10-多写一个@Autowired导致程序崩了

再是javaweb实验六中&#xff0c;是让我们改代码&#xff0c;让它跑起来&#xff0c;结果我少注释了一个&#xff0c;导致一直报错&#xff0c;检查许久没有找到&#xff0c;最后通过代码替换逐步查找&#xff0c;才发现问题。 转载于:https://www.cnblogs.com/zhumengdexiaoba…

Java class不分32位和64位

1、32位JDK编译的java class在32位系统和64位系统下都可以运行&#xff0c;64位系统兼容32位程序&#xff0c;可以理解。2、无论是Linux还是Windows平台下的JDK编译的java class在Linux、Windows平台下通用&#xff0c;Java跨平台特性。3、64位JDK编译的java class在32位的系统…

包装对象

原文地址&#xff1a;https://wangdoc.com/javascript/ 定义 对象是JavaScript语言最主要的数据类型&#xff0c;三种原始类型的值--数值、字符串、布尔值--在一定条件下&#xff0c;也会自动转为对象&#xff0c;也就是原始类型的包装对象。所谓包装对象&#xff0c;就是分别与…

[C++] 转义序列

参考 C Primer(第5版)P36 名称转义序列换行符\n横向制表符\t报警(响铃)符\a纵向制表符\v退格符\b双引号"反斜杠\问号?单引号’回车符\r进纸符\f

vue使用(二)

本节目标&#xff1a; 1.数据路径的三种方式 2.{{}}和v-html的区别 1.绑定图片的路径 方法一&#xff1a;直接写路径 <img src"http://pic.baike.soso.com/p/20140109/20140109142534-188809525.jpg"> 方法二&#xff1a;在data中写路径&#xff0c;在…

typedef 为类型取别名

#include <stdio.h> int main() {   typedef int myint; // 为int 类型取自己想要的名字   myint a 10;   printf("%d", a);   return 0;} 其他类型的用法也是一样的 typedef 类型 自己想要取得名字; 转载于:https://www.cnblogs.com/hello-dummy/p/9…

【C++】如何提高Cache的命中率,示例

参考链接 https://stackoverflow.com/questions/16699247/what-is-a-cache-friendly-code 只是堆积&#xff1a;缓存不友好与缓存友好代码的典型例子是矩阵乘法的“缓存阻塞”。 朴素矩阵乘法看起来像 for(i0;i<N;i) {for(j0;j<N;j) {dest[i][j] 0;for( k;k<N;i)…

springboot---整合redis

pom.xml新增 <dependency><groupId>org.springframework.boot</groupId><artifactId>spring-boot-starter-data-redis</artifactId></dependency>代码结构如下 其中redis.yml是连接redis的配置文件&#xff0c;RedisConfig.java是java配置…

[Head First Java] - 简单的建议程序

参考 - p481、p484 与我对接的业务层使用的是JAVA语言,因此花点时间入门java.下面几篇博客可能都是关于java的,我觉得在工作中可能会遇到的 简单的通信 DailyAdviceClient(客户端程序) import java.io.*; import java.net.*;public class DailyAdviceClient{public void go()…