pytorch03:transforms常见数据增强操作

目录

  • 一、数据增强
  • 二、transforms--Crop裁剪
    • 2.1 transforms.CenterCrop
    • 2.2 transforms.RandomCrop
    • 2.3 RandomResizedCrop
    • 2.4 FiveCrop和TenCrop
  • 三、transforms—Flip翻转、旋转
    • 3.1RandomHorizontalFlip和RandomVerticalFlip
    • 3.2 RandomRotation
  • 四、transforms —图像变换
    • 4.1 transforms.Pad
    • 4.2 transforms.ColorJitter
    • 4.3 Grayscale和RandomGrayscale
    • 4.4 RandomAffine
    • 4.5 RandomErasing
  • 五、transforms的操作
    • 5.1 transforms.RandomChoice
    • 5.2 transforms.RandomApply
    • 5.3 transforms.RandomOrder
  • 六、自定义transforms
    • 6.1 自定义transforms要素
    • 6.2 通过类实现多参数传入
    • 6.3 椒盐噪声
    • 6.4 自定义transforms代码实现
  • 七、数据增强策略
    • 数据增强代码实现

一、数据增强

   数据增强又称为数据增广,数据扩增,它是对训练集进行变换,使训练集更丰富,从而让模型更具泛化能力。如下是对一张图片常见的增强操作例如:旋转、裁剪、像素抖动。
在这里插入图片描述

二、transforms–Crop裁剪

2.1 transforms.CenterCrop

功能:从图像中心裁剪图片
• size:所需裁剪图片尺寸

2.2 transforms.RandomCrop

功能:从图片中随机裁剪出尺寸为size的图片
在这里插入图片描述

• size:所需裁剪图片尺寸
• padding:设置填充大小
  当为a时,上下左右均填充a个像素,
  当为(a, b)时,上下填充b个像素,左右填充a个像素,
  当为(a, b, c, d)时,左,上,右,下分别填充a, b, c, d
• pad_if_need:若图像小于设定size,则填充
• padding_mode:填充模式,有4种模式
  1、constant:像素值由fill设定
  2、edge:像素值由图像边缘像素决定
  3、reflect:镜像填充,最后一个像素不镜像,eg:[1,2,3,4] → [3,2,1,2,3,4,3,2]
  4、symmetric:镜像填充,最后一个像素镜像,eg:[1,2,3,4] → [2,1,1,2,3,4,4,3]
• fill:constant时,设置填充的像素值

2.3 RandomResizedCrop

功能:随机大小、长宽比裁剪图片
在这里插入图片描述

• size:所需裁剪图片尺寸
• scale:随机裁剪面积比例, 默认(0.08, 1)
• ratio:随机长宽比,默认(3/4, 4/3)
• interpolation:插值方法
PIL.Image.NEAREST
PIL.Image.BILINEAR
PIL.Image.BICUBIC

2.4 FiveCrop和TenCrop

  功能:在图像的上下左右以及中心裁剪出尺寸为size的5张图片,TenCrop对这5张图片进行水平或者垂直镜像获得10张图片
在这里插入图片描述

• size:所需裁剪图片尺寸
• vertical_flip:是否垂直翻转

三、transforms—Flip翻转、旋转

3.1RandomHorizontalFlip和RandomVerticalFlip

在这里插入图片描述

功能:依概率水平(左右)或垂直(上下)翻转图片
• p:翻转概率

3.2 RandomRotation

功能:随机旋转图片
在这里插入图片描述
在这里插入图片描述

• degrees:旋转角度
  当为a时,在(-a,a)之间选择旋转角度
  当为(a, b)时,在(a, b)之间选择旋转角度
• resample:重采样方法
• expand:是否扩大图片,以保持原图

四、transforms —图像变换

4.1 transforms.Pad

功能:对图片边缘进行填充
在这里插入图片描述
• padding:设置填充大小
  当为a时,上下左右均填充a个像素
  当为(a, b)时,上下填充b个像素,左右填充a个像素
  当为(a, b, c, d)时,左,上,右,下分别填充a, b, c, d
• padding_mode:填充模式,有4种模式,constant、edge、reflect和symmetric
• fill:constant时,设置填充的像素值,(R, G, B) or (Gray)

4.2 transforms.ColorJitter

功能:调整亮度、对比度、饱和度和色相
在这里插入图片描述

• brightness:亮度调整因子
  当为a时,从[max(0, 1-a), 1+a]中随机选择
  当为(a, b)时,从[a, b]中
• contrast:对比度参数,同brightness
• saturation:饱和度参数,同brightness
• hue:色相参数,当为a时,从[-a, a]中选择参数,注: 0<= a <= 0.5
        当为(a, b)时,从[a, b]中选择参数,注:-0.5 <= a <= b <= 0.5

4.3 Grayscale和RandomGrayscale

功能:依概率将图片转换为灰度图
在这里插入图片描述
• num_ouput_channels:输出通道数只能设1或3
• p:概率值,图像被转换为灰度图的概率

4.4 RandomAffine

功能:对图像进行仿射变换,仿射变换是二维的线性变换,由五种基本原子变换构成,分别是旋转、平移、缩放、错切和翻转
在这里插入图片描述
在这里插入图片描述
• degrees:旋转角度设置
• translate:平移区间设置,如(a, b), a设置宽(width),b设置高(height)
    图像在宽维度平移的区间为 -img_width * a < dx < img_width * a
• scale:缩放比例(以面积为单位)
• fill_color:填充颜色设置

4.5 RandomErasing

功能:对图像进行随机遮挡
在这里插入图片描述

• p:概率值,执行该操作的概率
• scale:遮挡区域的面积
• ratio:遮挡区域长宽比
• value:设置遮挡区域的像素值,(R, G, B) or (Gray)

五、transforms的操作

5.1 transforms.RandomChoice

功能:从一系列transforms方法中随机挑选一个

transforms.RandomChoice([transforms1, transforms2, transforms3])

5.2 transforms.RandomApply

功能:依据概率执行一组transforms操作

transforms.RandomApply([transforms1, transforms2, transforms3], p=0.5)

5.3 transforms.RandomOrder

功能:对一组transforms操作打乱顺序

transforms.RandomOrder([transforms1, transforms2, transforms3])

六、自定义transforms

6.1 自定义transforms要素

1.仅接收一个参数,返回一个参数
2.注意上下游的输出与输入
当前transforms的输入是上一个transforms的输出,所以要保证数据类型匹配:
在这里插入图片描述

6.2 通过类实现多参数传入

在这里插入图片描述

在Python中,__call__是一个特殊的方法,用于使一个对象可以像函数一样被调用。如果一个类定义了__call__方法,那么实例化的对象就可以被当作函数一样调用,而调用的实际上是__call__方法。

class CallableClass:def __init__(self):print("Initializing the CallableClass")def __call__(self, *args, **kwargs):print("Calling the CallableClass with arguments:", args, kwargs)# 实例化对象
obj = CallableClass()# 调用对象,实际上调用了__call__方法
obj(1, 2, 3, keyword_arg="hello")

上面的例子中,CallableClass定义了__call__方法,这意味着实例obj可以像函数一样被调用。当你调用obj(1, 2, 3, keyword_arg=“hello”)时,实际上是在调用obj.call(1, 2, 3, keyword_arg=“hello”)。

6.3 椒盐噪声

椒盐噪声又称为脉冲噪声,是一种随机出现的白点或者黑点, 白点称为盐噪声,黑色为椒噪声
信噪比(Signal-Noise Rate, SNR)是衡量噪声的比例,图像中为图像像素的占比,从下图可以看出,信噪比越小,图片丢失的像素越多。
在这里插入图片描述

6.4 自定义transforms代码实现

class AddPepperNoise(object):"""增加椒盐噪声Args:snr (float): Signal Noise Rate 信噪比p (float): 概率值,依概率执行该操作Attributes:snr (float): 信噪比p (float): 操作执行的概率"""def __init__(self, snr, p=0.9):# 确保传入的snr和p是float类型assert isinstance(snr, float) and isinstance(p, float)self.snr = snrself.p = pdef __call__(self, img):"""对图像应用椒盐噪声操作。Args:img (PIL Image): PIL Image对象Returns:PIL Image: 处理后的PIL Image对象"""# 根据概率决定是否执行噪声操作if random.uniform(0, 1) < self.p:img_ = np.array(img).copy()h, w, c = img_.shapesignal_pct = self.snrnoise_pct = (1 - self.snr)# 生成噪声掩码,表示每个像素是原始图像、盐噪声还是椒噪声mask = np.random.choice((0, 1, 2), size=(h, w, 1),p=[signal_pct, noise_pct / 2., noise_pct / 2.])mask = np.repeat(mask, c, axis=2)# 根据噪声类型修改图像像素值img_[mask == 1] = 255  # 盐噪声img_[mask == 2] = 0    # 椒噪声# 将NumPy数组转换回PIL Image对象,并确保数据类型为uint8,颜色通道为RGBreturn Image.fromarray(img_.astype('uint8')).convert('RGB')else:return img

在这里插入图片描述

七、数据增强策略

原则:让训练集与测试集更接近可以使用下面这些方法
• 空间位置:平移
• 色彩:灰度图,色彩抖动
• 形状:仿射变换
• 上下文场景:遮挡,填充

例如我们训练集白猫比较多,可以改变白猫色彩,让白猫的颜色接近黑猫。
在这里插入图片描述

数据增强代码实现

要求:使用第四套RMB进行训练,要求能对第5套RMB识别正确。

我们只进行普通的图片处理训练好的模型,发现将第五套100元都识别成一元,因为第四套人民币的1元和第五套人民的100元颜色相近,所以会导致识别错误:
在这里插入图片描述
解决方法,将所有训练集颜色都进行灰度处理,代码修改如下:

train_transform = transforms.Compose([transforms.Resize((32, 32)),transforms.RandomCrop(32, padding=4),transforms.RandomGrayscale(p=0.9),  #图片灰度化transforms.ToTensor(),transforms.Normalize(norm_mean, norm_std),
])

修改后的预测结果如下:
在这里插入图片描述
训练完整代码如下:

# -*- coding: utf-8 -*-import os
import random
import numpy as np
import torch
import torch.nn as nn
from torch.utils.data import DataLoader
import torchvision.transforms as transforms
import torch.optim as optim
from matplotlib import pyplot as plt
from lenet import LeNet
from my_dataset import RMBDataset
from common_tools import transform_invertdef set_seed(seed=1):random.seed(seed)np.random.seed(seed)torch.manual_seed(seed)torch.cuda.manual_seed(seed)set_seed()  # 设置随机种子
rmb_label = {"1": 0, "100": 1}# 参数设置
MAX_EPOCH = 10
BATCH_SIZE = 16
LR = 0.01
log_interval = 10
val_interval = 1# ============================ step 1/5 数据 ============================split_dir = os.path.join("..", "..", "data", "rmb_split")
train_dir = os.path.join(split_dir, "train")
valid_dir = os.path.join(split_dir, "valid")norm_mean = [0.485, 0.456, 0.406]
norm_std = [0.229, 0.224, 0.225]train_transform = transforms.Compose([transforms.Resize((32, 32)),transforms.RandomCrop(32, padding=4),transforms.RandomGrayscale(p=0.9),transforms.ToTensor(),transforms.Normalize(norm_mean, norm_std),
])valid_transform = transforms.Compose([transforms.Resize((32, 32)),transforms.ToTensor(),transforms.Normalize(norm_mean, norm_std),
])# 构建MyDataset实例
train_data = RMBDataset(data_dir=train_dir, transform=train_transform)
valid_data = RMBDataset(data_dir=valid_dir, transform=valid_transform)# 构建DataLoder
train_loader = DataLoader(dataset=train_data, batch_size=BATCH_SIZE, shuffle=True)
valid_loader = DataLoader(dataset=valid_data, batch_size=BATCH_SIZE)# ============================ step 2/5 模型 ============================net = LeNet(classes=2)
net.initialize_weights()# ============================ step 3/5 损失函数 ============================
criterion = nn.CrossEntropyLoss()                                                   # 选择损失函数# ============================ step 4/5 优化器 ============================
optimizer = optim.SGD(net.parameters(), lr=LR, momentum=0.9)                        # 选择优化器
scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=10, gamma=0.1)     # 设置学习率下降策略# ============================ step 5/5 训练 ============================
train_curve = list()
valid_curve = list()for epoch in range(MAX_EPOCH):loss_mean = 0.correct = 0.total = 0.net.train()for i, data in enumerate(train_loader):# forwardinputs, labels = dataoutputs = net(inputs)# backwardoptimizer.zero_grad()loss = criterion(outputs, labels)loss.backward()# update weightsoptimizer.step()# 统计分类情况_, predicted = torch.max(outputs.data, 1)total += labels.size(0)correct += (predicted == labels).squeeze().sum().numpy()# 打印训练信息loss_mean += loss.item()train_curve.append(loss.item())if (i+1) % log_interval == 0:loss_mean = loss_mean / log_intervalprint("Training:Epoch[{:0>3}/{:0>3}] Iteration[{:0>3}/{:0>3}] Loss: {:.4f} Acc:{:.2%}".format(epoch, MAX_EPOCH, i+1, len(train_loader), loss_mean, correct / total))loss_mean = 0.scheduler.step()  # 更新学习率# validate the modelif (epoch+1) % val_interval == 0:correct_val = 0.total_val = 0.loss_val = 0.net.eval()with torch.no_grad():for j, data in enumerate(valid_loader):inputs, labels = dataoutputs = net(inputs)loss = criterion(outputs, labels)_, predicted = torch.max(outputs.data, 1)total_val += labels.size(0)correct_val += (predicted == labels).squeeze().sum().numpy()loss_val += loss.item()valid_curve.append(loss_val)print("Valid:\t Epoch[{:0>3}/{:0>3}] Iteration[{:0>3}/{:0>3}] Loss: {:.4f} Acc:{:.2%}".format(epoch, MAX_EPOCH, j+1, len(valid_loader), loss_val, correct / total))train_x = range(len(train_curve))
train_y = train_curvetrain_iters = len(train_loader)
valid_x = np.arange(1, len(valid_curve)+1) * train_iters*val_interval # 由于valid中记录的是epochloss,需要对记录点进行转换到iterations
valid_y = valid_curveplt.plot(train_x, train_y, label='Train')
plt.plot(valid_x, valid_y, label='Valid')plt.legend(loc='upper right')
plt.ylabel('loss value')
plt.xlabel('Iteration')
plt.show()# ============================ inference ============================BASE_DIR = os.path.dirname(os.path.abspath(__file__))
test_dir = os.path.join(BASE_DIR, "test_data")test_data = RMBDataset(data_dir=test_dir, transform=valid_transform)
valid_loader = DataLoader(dataset=test_data, batch_size=1)for i, data in enumerate(valid_loader):# forwardinputs, labels = dataoutputs = net(inputs)_, predicted = torch.max(outputs.data, 1)rmb = 1 if predicted.numpy()[0] == 0 else 100img_tensor = inputs[0, ...]  # C H Wimg = transform_invert(img_tensor, train_transform)plt.imshow(img)plt.title("LeNet got {} Yuan".format(rmb))plt.show()plt.pause(0.5)plt.close()

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.mzph.cn/news/586927.shtml

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

HarmonyOS4.0系统性深入开发09卡片使用动效能力

卡片使用动效能力 ArkTS卡片开放了使用动画效果的能力&#xff0c;支持显式动画、属性动画、组件内转场能力。需要注意的是&#xff0c;ArkTS卡片使用动画效果时具有以下限制&#xff1a; 名称参数说明限制描述duration动画播放时长限制最长的动效播放时长为1秒&#xff0c;当…

SpringMVC源码解析——DispatcherServlet的逻辑处理

DispatcherServlet类相关的结构图如下&#xff1a; 其中jakarta.servlet.http.HttpServlet的父类是jakarta.servlet.GenericServlet&#xff0c;实现接口jakarta.servlet.Servlet。我们先看一下jakarta.servlet.Servlet接口的源码如下&#xff1a; /*** 定义所有servlet必须实…

【网络安全】网络隔离设备

一、网络和终端隔离产品 网络和终端隔离产品分为终端隔离产品和网络隔离产品两大类。终端隔离产品一般指隔离卡或者隔离计算机。网络隔离产品根据产品形态和功能上的不同&#xff0c;该类产品可以分为协议转换产品、网闸和网络单向导入产品三种。 图1为终端隔离产品的一个典型…

dash 中的模式匹配回调函数Pattern-Matching Callbacks 8

模式匹配 模式匹配回调选择器 MATCH、ALL 和 ALLSMALLER 允许您编写可以响应或更新任意或动态数量组件的回调函数。 此示例呈现任意数量的 dcc. Dropdown 元素&#xff0c;并且只要任何 dcc. Dropdown 元素发生更改&#xff0c;就会触发回调。尝试添加几个下拉菜单并选择它们的…

Grafana增加仪表盘

1.Grafana介绍 grafana 是一款采用Go语言编写的开源应用&#xff0c;主要用于大规模指标数据的可视化展现&#xff0c;是网络架构和应用分析中最流行的时序数据展示工具&#xff0c;目前已经支持绝大部分常用的时序数据库。 Grafana下载地址&#xff1a;https://grafana.com/g…

burpsuite的安装与介绍

安装(挑一个你喜欢的版本安装就行) 编程环境安装指南:Java、Python 和 Burp Suite抓包工具_burpsuite和java-CSDN博客 简介 Burp Suite是一个用于攻击Web应用程序的集成平台。它集成了多种渗透测试组件,能够帮助我们更好地完成对Web应用的渗透测试和攻击,无论是自动化还…

基于CNN神经网络的手写字符识别实验报告

作业要求 具体实验内容根据实际情况自拟&#xff0c;可以是传统的BP神经网络&#xff0c;Hopfield神经网络&#xff0c;也可以是深度学习相关内容。 数据集自选&#xff0c;可以是自建数据集&#xff0c;或MNIST&#xff0c;CIFAR10等公开数据集。 实验报告内容包括但不限于&am…

[C#]opencvsharp进行图像拼接普通拼接stitch算法拼接

介绍&#xff1a; opencvsharp进行图像拼一般有2种方式&#xff1a;一种是传统方法将2个图片上下或者左右拼接&#xff0c;还有一个方法就是融合拼接&#xff0c;stitch拼接就是一种非常好的算法。opencv里面已经有stitch拼接算法因此我们很容易进行拼接。 效果&#xff1a; …

PayPal账户被封是因为什么?如何解决?

Paypal作为跨境出海玩家最常用的付款工具之一&#xff0c;同时也是最容易出现冻结封号现象。保障PP账号安全非常重要&#xff0c;只有支付渠道安全&#xff0c;才不会“白费力气”&#xff0c;那么最重要的就是要了解它的封号原因以做好规避。 一、Paypal账号被封原因 1、账号…

FreeRTOS列表与列表项相关知识总结以及列表项的插入与删除实战

1.列表与列表项概念及结构体介绍 1.1列表项简介 列表相当于链表&#xff0c;列表项相当于节点&#xff0c;FreeRTOS 中的列表是一个双向环形链表 1.2 列表、列表项、迷你列表项结构体 1&#xff09;列表结构体 typedef struct xLIST { listFIRST_LIST_INTEGRITY_CHECK_VAL…

List常见方法和遍历操作

List集合的特点 有序&#xff1a; 存和取的元素顺序一致有索引&#xff1a;可以通过索引操作元素可重复&#xff1a;存储的元素可以重复 List集合的特有方法 Collection的方法List都继承了List集合因为有索引&#xff0c;所以有了很多操作索引的方法 ublic static void main…

挑战 ChatGPT 和 Google Bard 的防御

到目前为止&#xff0c;科学家已经创建了基于人工智能的聊天机器人&#xff0c;可以帮助内容生成。我们还看到人工智能被用来创建像 WormGPT 这样的恶意软件&#xff0c;尽管地下社区对此并不满意。但现在正在创建聊天机器人&#xff0c;可以使用生成人工智能通过即时注入活动来…

OpenHarmony之分布式软总线

背景概述 从之前的文档(OpenHarmony之内核层)可知 分布式软总线是多设备终端的统一基座&#xff0c;为设备间的无缝互联提供了统一的分布式通信能力&#xff0c;能够快速发现并连接设备&#xff0c;高效地传输任务和数据。 分布式软总线实现近场设备间统一的分布式通信管理能…

代码随想录刷题第三十四天| 1005.K次取反后最大化的数组和 ● 134. 加油站 ● 135. 分发糖果

代码随想录刷题第三十四天 K次取反后最大化的数组和 (LC 1005) 题目思路&#xff1a; 代码实现&#xff1a; class Solution:def largestSumAfterKNegations(self, nums: List[int], k: int) -> int:nums.sort(keylambda x: abs(x), reverseTrue)for i in range(len(nums…

8868体育助力意甲罗马俱乐部 迪巴拉有望付出

8868体育助力意甲罗马俱乐部 迪巴拉有望付出 意甲罗马俱乐部是8868体育合作球队之一&#xff0c;本赛季&#xff0c;在意甲第14轮的比赛中&#xff0c;罗马客场2-1战胜萨索洛&#xff0c;积分上升到意甲第4位。 有报道称&#xff0c;迪巴拉在对阵佛罗伦萨的比赛中受伤&#xff…

网络Ping不通故障定位思路

故障分析 Ping不通是指Ping报文在网络中传输&#xff0c;由于各种原因&#xff08;如链路故障、ARP学习失败等&#xff09;而接收不到所有Ping应答报文的现象。 如图1-1所示&#xff0c;以一个Ping不通的尝试示例&#xff0c;介绍Ping不通故障的定位思路。 图1-1 Ping不通故…

在机器学习训练测试集中,如何切分出一份验证集

文章目录 1.读取数据&#xff1a;2.绘图查看target数量情况&#xff1a;3.特征拓展&#xff1a;4.构建X&#xff0c;y&#xff1a;5.拆分训练集和测试集&#xff0c;特征做缩放处理&#xff1a;6.从训练集里再切一次出验证集&#xff0c;特征做缩放处理&#xff1a;7.测试集训练…

Shell基本命令 Mkdir创建 cp 复制 ls-R递归打印 文件权限

mkdir 是在命令行界面下创建目录&#xff08;文件夹&#xff09;的命令。它是 “make directory” 的缩写。 以下是 mkdir 命令的基本语法&#xff1a; mkdir 目录路径其中&#xff0c;目录路径 是要创建的目录的路径和名称。 例如&#xff0c;要在当前目录下创建一个名为 m…

Java介绍

Java 是一门纯粹的面向对象编程语言&#xff0c;它吸收了C的各种优点&#xff0c;还努力摒弃了C里难以理解的多继承、指针等概念&#xff0c;真正地实现了面向对象理论&#xff0c;因而具有功能强大和简单易用两个特征。 除了基础语法之外&#xff0c;Java还有许多必须弄懂的特…

OpenCV-Python(22):直方图的计算绘制与分析

目标 了解直方图的原理及应用使用OpenCV 或Numpy 函数计算直方图使用Opencv 或者Matplotlib 函数绘制直方图学习函数cv2.calcHist()、np.histogram()等 原理及应用 直方图是一种统计图形&#xff0c;是对图像的另一种解释&#xff0c;用于表示图像中各个像素值的频次分布。直…