pytorch03:transforms常见数据增强操作

目录

  • 一、数据增强
  • 二、transforms--Crop裁剪
    • 2.1 transforms.CenterCrop
    • 2.2 transforms.RandomCrop
    • 2.3 RandomResizedCrop
    • 2.4 FiveCrop和TenCrop
  • 三、transforms—Flip翻转、旋转
    • 3.1RandomHorizontalFlip和RandomVerticalFlip
    • 3.2 RandomRotation
  • 四、transforms —图像变换
    • 4.1 transforms.Pad
    • 4.2 transforms.ColorJitter
    • 4.3 Grayscale和RandomGrayscale
    • 4.4 RandomAffine
    • 4.5 RandomErasing
  • 五、transforms的操作
    • 5.1 transforms.RandomChoice
    • 5.2 transforms.RandomApply
    • 5.3 transforms.RandomOrder
  • 六、自定义transforms
    • 6.1 自定义transforms要素
    • 6.2 通过类实现多参数传入
    • 6.3 椒盐噪声
    • 6.4 自定义transforms代码实现
  • 七、数据增强策略
    • 数据增强代码实现

一、数据增强

   数据增强又称为数据增广,数据扩增,它是对训练集进行变换,使训练集更丰富,从而让模型更具泛化能力。如下是对一张图片常见的增强操作例如:旋转、裁剪、像素抖动。
在这里插入图片描述

二、transforms–Crop裁剪

2.1 transforms.CenterCrop

功能:从图像中心裁剪图片
• size:所需裁剪图片尺寸

2.2 transforms.RandomCrop

功能:从图片中随机裁剪出尺寸为size的图片
在这里插入图片描述

• size:所需裁剪图片尺寸
• padding:设置填充大小
  当为a时,上下左右均填充a个像素,
  当为(a, b)时,上下填充b个像素,左右填充a个像素,
  当为(a, b, c, d)时,左,上,右,下分别填充a, b, c, d
• pad_if_need:若图像小于设定size,则填充
• padding_mode:填充模式,有4种模式
  1、constant:像素值由fill设定
  2、edge:像素值由图像边缘像素决定
  3、reflect:镜像填充,最后一个像素不镜像,eg:[1,2,3,4] → [3,2,1,2,3,4,3,2]
  4、symmetric:镜像填充,最后一个像素镜像,eg:[1,2,3,4] → [2,1,1,2,3,4,4,3]
• fill:constant时,设置填充的像素值

2.3 RandomResizedCrop

功能:随机大小、长宽比裁剪图片
在这里插入图片描述

• size:所需裁剪图片尺寸
• scale:随机裁剪面积比例, 默认(0.08, 1)
• ratio:随机长宽比,默认(3/4, 4/3)
• interpolation:插值方法
PIL.Image.NEAREST
PIL.Image.BILINEAR
PIL.Image.BICUBIC

2.4 FiveCrop和TenCrop

  功能:在图像的上下左右以及中心裁剪出尺寸为size的5张图片,TenCrop对这5张图片进行水平或者垂直镜像获得10张图片
在这里插入图片描述

• size:所需裁剪图片尺寸
• vertical_flip:是否垂直翻转

三、transforms—Flip翻转、旋转

3.1RandomHorizontalFlip和RandomVerticalFlip

在这里插入图片描述

功能:依概率水平(左右)或垂直(上下)翻转图片
• p:翻转概率

3.2 RandomRotation

功能:随机旋转图片
在这里插入图片描述
在这里插入图片描述

• degrees:旋转角度
  当为a时,在(-a,a)之间选择旋转角度
  当为(a, b)时,在(a, b)之间选择旋转角度
• resample:重采样方法
• expand:是否扩大图片,以保持原图

四、transforms —图像变换

4.1 transforms.Pad

功能:对图片边缘进行填充
在这里插入图片描述
• padding:设置填充大小
  当为a时,上下左右均填充a个像素
  当为(a, b)时,上下填充b个像素,左右填充a个像素
  当为(a, b, c, d)时,左,上,右,下分别填充a, b, c, d
• padding_mode:填充模式,有4种模式,constant、edge、reflect和symmetric
• fill:constant时,设置填充的像素值,(R, G, B) or (Gray)

4.2 transforms.ColorJitter

功能:调整亮度、对比度、饱和度和色相
在这里插入图片描述

• brightness:亮度调整因子
  当为a时,从[max(0, 1-a), 1+a]中随机选择
  当为(a, b)时,从[a, b]中
• contrast:对比度参数,同brightness
• saturation:饱和度参数,同brightness
• hue:色相参数,当为a时,从[-a, a]中选择参数,注: 0<= a <= 0.5
        当为(a, b)时,从[a, b]中选择参数,注:-0.5 <= a <= b <= 0.5

4.3 Grayscale和RandomGrayscale

功能:依概率将图片转换为灰度图
在这里插入图片描述
• num_ouput_channels:输出通道数只能设1或3
• p:概率值,图像被转换为灰度图的概率

4.4 RandomAffine

功能:对图像进行仿射变换,仿射变换是二维的线性变换,由五种基本原子变换构成,分别是旋转、平移、缩放、错切和翻转
在这里插入图片描述
在这里插入图片描述
• degrees:旋转角度设置
• translate:平移区间设置,如(a, b), a设置宽(width),b设置高(height)
    图像在宽维度平移的区间为 -img_width * a < dx < img_width * a
• scale:缩放比例(以面积为单位)
• fill_color:填充颜色设置

4.5 RandomErasing

功能:对图像进行随机遮挡
在这里插入图片描述

• p:概率值,执行该操作的概率
• scale:遮挡区域的面积
• ratio:遮挡区域长宽比
• value:设置遮挡区域的像素值,(R, G, B) or (Gray)

五、transforms的操作

5.1 transforms.RandomChoice

功能:从一系列transforms方法中随机挑选一个

transforms.RandomChoice([transforms1, transforms2, transforms3])

5.2 transforms.RandomApply

功能:依据概率执行一组transforms操作

transforms.RandomApply([transforms1, transforms2, transforms3], p=0.5)

5.3 transforms.RandomOrder

功能:对一组transforms操作打乱顺序

transforms.RandomOrder([transforms1, transforms2, transforms3])

六、自定义transforms

6.1 自定义transforms要素

1.仅接收一个参数,返回一个参数
2.注意上下游的输出与输入
当前transforms的输入是上一个transforms的输出,所以要保证数据类型匹配:
在这里插入图片描述

6.2 通过类实现多参数传入

在这里插入图片描述

在Python中,__call__是一个特殊的方法,用于使一个对象可以像函数一样被调用。如果一个类定义了__call__方法,那么实例化的对象就可以被当作函数一样调用,而调用的实际上是__call__方法。

class CallableClass:def __init__(self):print("Initializing the CallableClass")def __call__(self, *args, **kwargs):print("Calling the CallableClass with arguments:", args, kwargs)# 实例化对象
obj = CallableClass()# 调用对象,实际上调用了__call__方法
obj(1, 2, 3, keyword_arg="hello")

上面的例子中,CallableClass定义了__call__方法,这意味着实例obj可以像函数一样被调用。当你调用obj(1, 2, 3, keyword_arg=“hello”)时,实际上是在调用obj.call(1, 2, 3, keyword_arg=“hello”)。

6.3 椒盐噪声

椒盐噪声又称为脉冲噪声,是一种随机出现的白点或者黑点, 白点称为盐噪声,黑色为椒噪声
信噪比(Signal-Noise Rate, SNR)是衡量噪声的比例,图像中为图像像素的占比,从下图可以看出,信噪比越小,图片丢失的像素越多。
在这里插入图片描述

6.4 自定义transforms代码实现

class AddPepperNoise(object):"""增加椒盐噪声Args:snr (float): Signal Noise Rate 信噪比p (float): 概率值,依概率执行该操作Attributes:snr (float): 信噪比p (float): 操作执行的概率"""def __init__(self, snr, p=0.9):# 确保传入的snr和p是float类型assert isinstance(snr, float) and isinstance(p, float)self.snr = snrself.p = pdef __call__(self, img):"""对图像应用椒盐噪声操作。Args:img (PIL Image): PIL Image对象Returns:PIL Image: 处理后的PIL Image对象"""# 根据概率决定是否执行噪声操作if random.uniform(0, 1) < self.p:img_ = np.array(img).copy()h, w, c = img_.shapesignal_pct = self.snrnoise_pct = (1 - self.snr)# 生成噪声掩码,表示每个像素是原始图像、盐噪声还是椒噪声mask = np.random.choice((0, 1, 2), size=(h, w, 1),p=[signal_pct, noise_pct / 2., noise_pct / 2.])mask = np.repeat(mask, c, axis=2)# 根据噪声类型修改图像像素值img_[mask == 1] = 255  # 盐噪声img_[mask == 2] = 0    # 椒噪声# 将NumPy数组转换回PIL Image对象,并确保数据类型为uint8,颜色通道为RGBreturn Image.fromarray(img_.astype('uint8')).convert('RGB')else:return img

在这里插入图片描述

七、数据增强策略

原则:让训练集与测试集更接近可以使用下面这些方法
• 空间位置:平移
• 色彩:灰度图,色彩抖动
• 形状:仿射变换
• 上下文场景:遮挡,填充

例如我们训练集白猫比较多,可以改变白猫色彩,让白猫的颜色接近黑猫。
在这里插入图片描述

数据增强代码实现

要求:使用第四套RMB进行训练,要求能对第5套RMB识别正确。

我们只进行普通的图片处理训练好的模型,发现将第五套100元都识别成一元,因为第四套人民币的1元和第五套人民的100元颜色相近,所以会导致识别错误:
在这里插入图片描述
解决方法,将所有训练集颜色都进行灰度处理,代码修改如下:

train_transform = transforms.Compose([transforms.Resize((32, 32)),transforms.RandomCrop(32, padding=4),transforms.RandomGrayscale(p=0.9),  #图片灰度化transforms.ToTensor(),transforms.Normalize(norm_mean, norm_std),
])

修改后的预测结果如下:
在这里插入图片描述
训练完整代码如下:

# -*- coding: utf-8 -*-import os
import random
import numpy as np
import torch
import torch.nn as nn
from torch.utils.data import DataLoader
import torchvision.transforms as transforms
import torch.optim as optim
from matplotlib import pyplot as plt
from lenet import LeNet
from my_dataset import RMBDataset
from common_tools import transform_invertdef set_seed(seed=1):random.seed(seed)np.random.seed(seed)torch.manual_seed(seed)torch.cuda.manual_seed(seed)set_seed()  # 设置随机种子
rmb_label = {"1": 0, "100": 1}# 参数设置
MAX_EPOCH = 10
BATCH_SIZE = 16
LR = 0.01
log_interval = 10
val_interval = 1# ============================ step 1/5 数据 ============================split_dir = os.path.join("..", "..", "data", "rmb_split")
train_dir = os.path.join(split_dir, "train")
valid_dir = os.path.join(split_dir, "valid")norm_mean = [0.485, 0.456, 0.406]
norm_std = [0.229, 0.224, 0.225]train_transform = transforms.Compose([transforms.Resize((32, 32)),transforms.RandomCrop(32, padding=4),transforms.RandomGrayscale(p=0.9),transforms.ToTensor(),transforms.Normalize(norm_mean, norm_std),
])valid_transform = transforms.Compose([transforms.Resize((32, 32)),transforms.ToTensor(),transforms.Normalize(norm_mean, norm_std),
])# 构建MyDataset实例
train_data = RMBDataset(data_dir=train_dir, transform=train_transform)
valid_data = RMBDataset(data_dir=valid_dir, transform=valid_transform)# 构建DataLoder
train_loader = DataLoader(dataset=train_data, batch_size=BATCH_SIZE, shuffle=True)
valid_loader = DataLoader(dataset=valid_data, batch_size=BATCH_SIZE)# ============================ step 2/5 模型 ============================net = LeNet(classes=2)
net.initialize_weights()# ============================ step 3/5 损失函数 ============================
criterion = nn.CrossEntropyLoss()                                                   # 选择损失函数# ============================ step 4/5 优化器 ============================
optimizer = optim.SGD(net.parameters(), lr=LR, momentum=0.9)                        # 选择优化器
scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=10, gamma=0.1)     # 设置学习率下降策略# ============================ step 5/5 训练 ============================
train_curve = list()
valid_curve = list()for epoch in range(MAX_EPOCH):loss_mean = 0.correct = 0.total = 0.net.train()for i, data in enumerate(train_loader):# forwardinputs, labels = dataoutputs = net(inputs)# backwardoptimizer.zero_grad()loss = criterion(outputs, labels)loss.backward()# update weightsoptimizer.step()# 统计分类情况_, predicted = torch.max(outputs.data, 1)total += labels.size(0)correct += (predicted == labels).squeeze().sum().numpy()# 打印训练信息loss_mean += loss.item()train_curve.append(loss.item())if (i+1) % log_interval == 0:loss_mean = loss_mean / log_intervalprint("Training:Epoch[{:0>3}/{:0>3}] Iteration[{:0>3}/{:0>3}] Loss: {:.4f} Acc:{:.2%}".format(epoch, MAX_EPOCH, i+1, len(train_loader), loss_mean, correct / total))loss_mean = 0.scheduler.step()  # 更新学习率# validate the modelif (epoch+1) % val_interval == 0:correct_val = 0.total_val = 0.loss_val = 0.net.eval()with torch.no_grad():for j, data in enumerate(valid_loader):inputs, labels = dataoutputs = net(inputs)loss = criterion(outputs, labels)_, predicted = torch.max(outputs.data, 1)total_val += labels.size(0)correct_val += (predicted == labels).squeeze().sum().numpy()loss_val += loss.item()valid_curve.append(loss_val)print("Valid:\t Epoch[{:0>3}/{:0>3}] Iteration[{:0>3}/{:0>3}] Loss: {:.4f} Acc:{:.2%}".format(epoch, MAX_EPOCH, j+1, len(valid_loader), loss_val, correct / total))train_x = range(len(train_curve))
train_y = train_curvetrain_iters = len(train_loader)
valid_x = np.arange(1, len(valid_curve)+1) * train_iters*val_interval # 由于valid中记录的是epochloss,需要对记录点进行转换到iterations
valid_y = valid_curveplt.plot(train_x, train_y, label='Train')
plt.plot(valid_x, valid_y, label='Valid')plt.legend(loc='upper right')
plt.ylabel('loss value')
plt.xlabel('Iteration')
plt.show()# ============================ inference ============================BASE_DIR = os.path.dirname(os.path.abspath(__file__))
test_dir = os.path.join(BASE_DIR, "test_data")test_data = RMBDataset(data_dir=test_dir, transform=valid_transform)
valid_loader = DataLoader(dataset=test_data, batch_size=1)for i, data in enumerate(valid_loader):# forwardinputs, labels = dataoutputs = net(inputs)_, predicted = torch.max(outputs.data, 1)rmb = 1 if predicted.numpy()[0] == 0 else 100img_tensor = inputs[0, ...]  # C H Wimg = transform_invert(img_tensor, train_transform)plt.imshow(img)plt.title("LeNet got {} Yuan".format(rmb))plt.show()plt.pause(0.5)plt.close()

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.mzph.cn/news/586927.shtml

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

HarmonyOS4.0系统性深入开发09卡片使用动效能力

卡片使用动效能力 ArkTS卡片开放了使用动画效果的能力&#xff0c;支持显式动画、属性动画、组件内转场能力。需要注意的是&#xff0c;ArkTS卡片使用动画效果时具有以下限制&#xff1a; 名称参数说明限制描述duration动画播放时长限制最长的动效播放时长为1秒&#xff0c;当…

【华为数据之道学习笔记】8-3异常数据监控

质量控制是通过监控质量形成过程&#xff0c;消除全过程中引起不合格或不满意效果的因素&#xff0c;以达到质量要求而采用的各种质量作业技术和活动。要保证最终交付质量&#xff0c;必须对过程进行质量控制&#xff0c;通常是在过程中设置关键质量控制点。例如&#xff0c;可…

Dockerfile学习文档

Dockerfile详解 Dockerfile是一个组合映像命令的文本&#xff1b;可以使用在命令行中调用任何命令&#xff1b;Docker通过dockerfile中的指令自动生成镜像。 通过docker build -t repository:tag ./ 即可构建&#xff0c;要求&#xff1a;./下存在Dockerfile文件 之前我们聊的…

SpringMVC源码解析——DispatcherServlet的逻辑处理

DispatcherServlet类相关的结构图如下&#xff1a; 其中jakarta.servlet.http.HttpServlet的父类是jakarta.servlet.GenericServlet&#xff0c;实现接口jakarta.servlet.Servlet。我们先看一下jakarta.servlet.Servlet接口的源码如下&#xff1a; /*** 定义所有servlet必须实…

PHP8的JIT(Just-In-Time)编译器是什么?

PHP8的JIT&#xff08;Just-In-Time&#xff09;编译器是什么&#xff1f; PHP8是最新的PHP版本&#xff0c;引入了JIT&#xff08;Just-In-Time&#xff09;编译器&#xff0c;以进一步提高性能和执行速度。 JIT编译器是一种在运行时将解释性语言转化为机器码的技术。在过去…

【网络安全】网络隔离设备

一、网络和终端隔离产品 网络和终端隔离产品分为终端隔离产品和网络隔离产品两大类。终端隔离产品一般指隔离卡或者隔离计算机。网络隔离产品根据产品形态和功能上的不同&#xff0c;该类产品可以分为协议转换产品、网闸和网络单向导入产品三种。 图1为终端隔离产品的一个典型…

2022-2023年度广东省职业院校学生专业技能大赛“软件测试”赛项性能测试题目-Jmeter

性能测试-JM 1、脚本添加: 脚本文件名称:SuppAndComp,测试计划名称:SuppAndComp。测试计划下添加两个线程组: (1)线程组一操作内容:系统管理员登录、进行新增供应商操作。 线程组名称SuppAdd。具体要求如下: 登录操作存放到仅一次控制器中,供应商名称前4位为固定…

dash 中的模式匹配回调函数Pattern-Matching Callbacks 8

模式匹配 模式匹配回调选择器 MATCH、ALL 和 ALLSMALLER 允许您编写可以响应或更新任意或动态数量组件的回调函数。 此示例呈现任意数量的 dcc. Dropdown 元素&#xff0c;并且只要任何 dcc. Dropdown 元素发生更改&#xff0c;就会触发回调。尝试添加几个下拉菜单并选择它们的…

Grafana增加仪表盘

1.Grafana介绍 grafana 是一款采用Go语言编写的开源应用&#xff0c;主要用于大规模指标数据的可视化展现&#xff0c;是网络架构和应用分析中最流行的时序数据展示工具&#xff0c;目前已经支持绝大部分常用的时序数据库。 Grafana下载地址&#xff1a;https://grafana.com/g…

burpsuite的安装与介绍

安装(挑一个你喜欢的版本安装就行) 编程环境安装指南:Java、Python 和 Burp Suite抓包工具_burpsuite和java-CSDN博客 简介 Burp Suite是一个用于攻击Web应用程序的集成平台。它集成了多种渗透测试组件,能够帮助我们更好地完成对Web应用的渗透测试和攻击,无论是自动化还…

基于CNN神经网络的手写字符识别实验报告

作业要求 具体实验内容根据实际情况自拟&#xff0c;可以是传统的BP神经网络&#xff0c;Hopfield神经网络&#xff0c;也可以是深度学习相关内容。 数据集自选&#xff0c;可以是自建数据集&#xff0c;或MNIST&#xff0c;CIFAR10等公开数据集。 实验报告内容包括但不限于&am…

nodejs+vue+微信小程序+python+PHP的会议管理系统-计算机毕业设计推荐

会议管理系统可以为公司领导提供会议记录管理功能&#xff0c;公司领导也就是系统的管理员&#xff0c;具有员工管理、公告管理、会议室管理、会议资料管理、会议投票管理、意见收集管理等管理的权限&#xff0c;添加或者删除用户基本信息。管理员需要先进行登录&#xff0c;获…

[C#]opencvsharp进行图像拼接普通拼接stitch算法拼接

介绍&#xff1a; opencvsharp进行图像拼一般有2种方式&#xff1a;一种是传统方法将2个图片上下或者左右拼接&#xff0c;还有一个方法就是融合拼接&#xff0c;stitch拼接就是一种非常好的算法。opencv里面已经有stitch拼接算法因此我们很容易进行拼接。 效果&#xff1a; …

PayPal账户被封是因为什么?如何解决?

Paypal作为跨境出海玩家最常用的付款工具之一&#xff0c;同时也是最容易出现冻结封号现象。保障PP账号安全非常重要&#xff0c;只有支付渠道安全&#xff0c;才不会“白费力气”&#xff0c;那么最重要的就是要了解它的封号原因以做好规避。 一、Paypal账号被封原因 1、账号…

FreeRTOS列表与列表项相关知识总结以及列表项的插入与删除实战

1.列表与列表项概念及结构体介绍 1.1列表项简介 列表相当于链表&#xff0c;列表项相当于节点&#xff0c;FreeRTOS 中的列表是一个双向环形链表 1.2 列表、列表项、迷你列表项结构体 1&#xff09;列表结构体 typedef struct xLIST { listFIRST_LIST_INTEGRITY_CHECK_VAL…

List常见方法和遍历操作

List集合的特点 有序&#xff1a; 存和取的元素顺序一致有索引&#xff1a;可以通过索引操作元素可重复&#xff1a;存储的元素可以重复 List集合的特有方法 Collection的方法List都继承了List集合因为有索引&#xff0c;所以有了很多操作索引的方法 ublic static void main…

SpringBoot如何优雅的处理免登录接口

在项目开发过程中&#xff0c;会有很多API接口不需要登录就能直接访问&#xff0c;比如公开数据查询之类的 ~ 常规处理方法基本是 使用拦截器或过滤器&#xff0c;拦截需要认证的请求路径。在拦截器中判断session或token信息&#xff0c;如果存在则放行&#xff0c;否则跳转到…

挑战 ChatGPT 和 Google Bard 的防御

到目前为止&#xff0c;科学家已经创建了基于人工智能的聊天机器人&#xff0c;可以帮助内容生成。我们还看到人工智能被用来创建像 WormGPT 这样的恶意软件&#xff0c;尽管地下社区对此并不满意。但现在正在创建聊天机器人&#xff0c;可以使用生成人工智能通过即时注入活动来…

编程笔记 html5cssjs 014 网页布局框架

编程笔记 html5&css&js 014 网页布局框架 一、Bootstrap简介二、使用Bootstrap布局 网页布局不只用HTML&#xff0c;还要用CSS和JAVASCRIPT等技术完成,这里暂时简单了解一下Bootstrap。 一、Bootstrap简介 这是一个开源的前端框架&#xff0c;由Twitter的前端工程师Ma…

OpenHarmony之分布式软总线

背景概述 从之前的文档(OpenHarmony之内核层)可知 分布式软总线是多设备终端的统一基座&#xff0c;为设备间的无缝互联提供了统一的分布式通信能力&#xff0c;能够快速发现并连接设备&#xff0c;高效地传输任务和数据。 分布式软总线实现近场设备间统一的分布式通信管理能…