pytorch一致数据增强

分割任务对 image 做(某些)transform 时,要对 label(segmentation mask)也做对应的 transform,如 Resize、RandomRotation 等。如果对 image、label 分别用 transform 处理一遍,则涉及随机操作的可能不一致,如 RandomRotation 将 image 转了 a 度、却将 label 转了 b 度。

MONAI 有个 ArrayDataset 实现了这功能,思路是每次 transform 前都重置一次 random seed 先。对 monai 订制 transform 的方法不熟,torchvision.transforms 的订制接口比较简单,考虑基于 pytorch 实现。要改两个东西:

  • 扩展 torchvison.transforms.Compose,使之支持多个输入(image、label);
  • 一个 wrapper,扩展 transform,使之支持多输入。

思路也是重置 random seed,参考 [1-4]。

Code

  • to_multi:将处理单幅图的 transform 扩展成可处理多幅;
  • MultiCompose:扩展 torchvision.transforms.Compose,可输入多幅图。内部调用 to_multi 扩展传入的 transforms。
import random, os
import numpy as np
import torchdef seed_everything(seed=42):random.seed(seed)os.environ['PYTHONHASHSEED'] = str(seed)np.random.seed(seed)torch.manual_seed(seed)torch.cuda.manual_seed(seed)torch.backends.cudnn.deterministic = Truetorch.backends.cudnn.benchmark = Truedef to_multi(trfm):"""wrap a transform to extend to multiple input with synchronised random seedInput:trfm: transformation function/object (custom or from torchvision.transforms)Output:_multi_transform: function"""# numpy.random.seed range error:#   ValueError: Seed must be between 0 and 2**32 - 1min_seed = 0 # - 0x8000_0000_0000_0000max_seed = min(2**32 - 1, 0xffff_ffff_ffff_ffff)def _multi_transform(*images):"""images: [C, H, W]"""if len(images) == 1:return trfm(images[0])_seed = random.randint(min_seed, max_seed)res = []for img in images:seed_everything(_seed)res.append(trfm(img))return tuple(res)return _multi_transformclass MultiCompose:"""Extension of torchvision.transforms.Compose that accepts multiple input.Usage is the same as torchvision.transforms.Compose. This class will wrap inputtransforms with `to_multi` to support simultaneous multiple transformation.This can be useful when simultaneously transforming images & segmentation masks."""def __init__(self, transforms):"""transforms should be wrapped by `to_multi`"""self.transforms = [to_multi(t) for t in transforms]def __call__(self, *images):for t in self.transforms:images = t(*images)return images

test

测试一致性,用到预处理过的 verse’19 数据集、一些工具函数、一个订制 transform:

  • verse’19 数据集及预处理见 iTomxy/data/verse;
  • digit_sort_key:数据文件排序用;
  • get_palettecolor_segblend_seg:可视化用;
  • MyDataset:看其中 __getitem__ 的 transform 用法,即同时传入 image 和 label;
  • ResizeZoomPad:一个订制的 transform;
import os, os.path as osp, random
from glob import glob
import numpy as np
from PIL import Image
import torch
import torchvision.transforms as transforms
import torchvision.transforms.functional as Fdef digit_sort_key(s, num_pattern=re.compile('([0-9]+)')):"""natural sort,数据排序用"""return [int(text) for text in num_pattern.split(s) if text.isdigit()]def get_palette(n_classes, pil_format=True):"""创建调色盘,可视化用"""n = n_classespalette = [0] * (n * 3)for j in range(0, n):lab = jpalette[j * 3 + 0] = 0palette[j * 3 + 1] = 0palette[j * 3 + 2] = 0i = 0while lab:palette[j * 3 + 0] |= (((lab >> 0) & 1) << (7 - i))palette[j * 3 + 1] |= (((lab >> 1) & 1) << (7 - i))palette[j * 3 + 2] |= (((lab >> 2) & 1) << (7 - i))i += 1lab >>= 3if pil_format:return paletteres = []for i in range(0, len(palette), 3):res.append(tuple(palette[i: i+3]))return resdef color_seg(label, n_classes=0):"""segmentation mask 上色,可视化用"""if n_classes < 1:n_classes = math.ceil(np.max(label)) + 1label_rgb = Image.fromarray(label.astype(np.int32)).convert("L")label_rgb.putpalette(get_palette(n_classes))return label_rgb.convert("RGB")def blend_seg(image, label, n_classes=0, alpha=0.7, rescale=False, transparent_bg=True, save_file=""):"""融合 image 和其 segmentation mask,可视化用"""if rescale:denom = image.max() - image.min()if 0 != denom:image = (image - image.min()) / denom * 255image = np.clip(image, 0, 255).astype(np.uint8)img_pil = Image.fromarray(image).convert("RGB")lab_pil = color_seg(label, n_classes)blended_image = Image.blend(img_pil, lab_pil, alpha)if transparent_bg:blended_image = Image.fromarray(np.where((0 == label)[:, :, np.newaxis],np.asarray(img_pil),np.asarray(blended_image)))if save_file:blended_image.save(save_file)return blended_imageclass MyDataset(torch.utils.data.Dataset):"""订制 dataset,看 __getitem__ 处 transform 的调法"""def __init__(self, image_list, label_list, transform=None):assert len(image_list) == len(label_list)self.image_list = image_listself.label_list = label_listself.transform = transformdef __len__(self):return len(self.image_list)def __getitem__(self, index):img = np.load(self.image_list[index]) # [h, w]lab = np.load(self.label_list[index])img = torch.from_numpy(img).unsqueeze(0).float() # -> [c=1, h, w]lab = torch.from_numpy(lab).unsqueeze(0).int()if self.transform is not None:img, lab = self.transform(img, lab) # 同时传入 image、labelreturn img, labclass ResizeZoomPad:"""订制 resize"""def __init__(self, size, interpolation="bilinear"):if isinstance(size, int):assert size > 0self.size = [size, size]elif isinstance(size, (tuple, list)):assert len(size) == 2 and size[0] > 0 and size[1] > 0self.size = sizeif isinstance(interpolation, str):assert interpolation.lower() in {"nearest", "bilinear", "bicubic", "box", "hamming", "lanczos"}interpolation = {"nearest": F.InterpolationMode.NEAREST,"bilinear": F.InterpolationMode.BILINEAR,"bicubic": F.InterpolationMode.BICUBIC,"box": F.InterpolationMode.BOX,"hamming": F.InterpolationMode.HAMMING,"lanczos": F.InterpolationMode.LANCZOS}[interpolation.lower()]self.interpolation = interpolationdef __call__(self, image):"""image: [C, H, W]"""scale_h, scale_w = float(self.size[0]) / image.size(1), float(self.size[1]) / image.size(2)scale = min(scale_h, scale_w)tmp_size = [ # clipping to ensure sizemin(int(image.size(1) * scale), self.size[0]),min(int(image.size(2) * scale), self.size[1])]image = F.resize(image, tmp_size, self.interpolation)assert image.size(1) <= self.size[0] and image.size(2) <= self.size[1]pad_h, pad_w = self.size[0] - image.size(1), self.size[1] - image.size(2)if pad_h > 0 or pad_w > 0:pad_left, pad_right = pad_w // 2, (pad_w + 1) // 2pad_top, pad_bottom = pad_h // 2, (pad_h + 1) // 2image = F.pad(image, (pad_left, pad_top, pad_right, pad_bottom))return image# 读数据文件
data_path = os.path.expanduser("~/data/verse/processed-verse19-npy-horizontal")
train_images, train_labels, val_images, val_labels = [], [], [], []
for d in os.listdir(osp.join(data_path, "training")):if d.endswith("_ct"):img_p = osp.join(data_path, "training", d)lab_p = osp.join(data_path, "training", d[:-3]+"_seg-vert_msk")assert osp.isdir(lab_p)train_labels.extend(glob(os.path.join(lab_p, "*.npy")))train_images.extend(glob(os.path.join(img_p, "*.npy")))
for d in os.listdir(osp.join(data_path, "validation")):if d.endswith("_ct"):img_p = osp.join(data_path, "validation", d)lab_p = osp.join(data_path, "validation", d[:-3]+"_seg-vert_msk")assert osp.isdir(lab_p)val_labels.extend(glob(os.path.join(lab_p, "*.npy")))val_images.extend(glob(os.path.join(img_p, "*.npy")))# 数据文件名排序
train_images = sorted(train_images, key=lambda f: digit_sort_key(os.path.basename(f)))
train_labels = sorted(train_labels, key=lambda f: digit_sort_key(os.path.basename(f)))
val_images = sorted(val_images, key=lambda f: digit_sort_key(os.path.basename(f)))
val_labels = sorted(val_labels, key=lambda f: digit_sort_key(os.path.basename(f)))# transform
# 用 MultiCompose,其内部调用 to_multi 将 transforms wrap 成支持多输入的
train_trans = MultiCompose([ResizeZoomPad((224, 256)),transforms.RandomRotation(30),
])# 测试:读数据,可试化 image 和 label
check_ds = MyDataset(train_images, train_labels, train_trans)
check_loader = torch.utils.data.DataLoader(check_ds, batch_size=10, shuffle=True)
for images, labels in check_loader:print(images.size(), labels.size())for i in range(images.size(0)):# print(i, end='\r')img = images[i][0].numpy()lab = labels[i][0].numpy()print(np.unique(lab))seg_img = blend_seg(img, lab)img = (255 * (img - img.min()) / (img.max() - img.min())).astype(np.uint8)img = np.asarray(Image.fromarray(img).convert("RGB"))lab = np.asarray(color_seg(lab))comb = np.concatenate([img, lab, seg_img], axis=1)Image.fromarray(comb).save(f"test-dataset-{i}.png")break

效果:
test-dataset-7.png
可见,image 和 label 转了同一个随机角度。

Limits

有些 augmentations 是只对 image 做而不对 label 做的,如 ColorJitter,这里没有考虑怎么处理。

References

  1. How to Set Random Seeds in PyTorch and Tensorflow
  2. ihoromi4/seed_everything.py
  3. Reproducibility
  4. What is the max seed you can set up?

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.mzph.cn/news/220897.shtml

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

计算机网络网络层(期末、考研)

计算机网络总复习链接&#x1f517; 目录 路由算法静态路由与动态路由距离-向量算法链路状态路由算法层次路由 IPv4&#xff08;这个必考&#xff09;IPv4分组IPv4地址与NAT子网划分与子网掩码、CIDRARP、DHCP与ICMP地址解析协议ARP动态主机配置协议DHCP IPv6IPv6特点 路由协议…

android studio 创建按钮项目

1&#xff09;、新建一个empty activity项目&#xff0c;切换到project视图&#xff1a; 2&#xff09;、修改app\src\main\res\layout\activity_main.xml文件&#xff0c;修改后如下&#xff1a; <?xml version"1.0" encoding"utf-8"?> <andr…

html基础知识

1、文字阴影代码&#xff1a; <!DOCTYPE html> <html lang"en"> <head> <meta charset"UTF-8"> <meta name"viewport" content"widthdevice-width, initial-scale1.0"> <meta http-eq…

Vue 工作开发小技巧

一、汇总 ​ 本博客&#xff0c;记录了一些Vue在日常开发工作中比较实用的小技巧&#xff0c;后续会陆续添加更新。 ​ 1、利用Sass的:global定义全局样式。 ​ 2、在<style>内部使用v-bind给CSS属性绑定属性值。 ​ 3、父子组件传值时&#xff0c;使用.sync修饰符后…

cgteamwork与shotgrid对比

最近有项目接触使用并二开cgteamwork&#xff0c; 也重新认识了cgteamwork&#xff0c;感受到国产软件的强大&#xff0c;国内中小CG公司的首选&#xff0c;原因&#xff1a; 1 上手容易&#xff0c;不会的有售前工程师教&#xff0c;他们全国各地城市到处跑。 感概业务的强大…

智能优化算法应用:基于生物地理学算法3D无线传感器网络(WSN)覆盖优化 - 附代码

智能优化算法应用&#xff1a;基于生物地理学算法3D无线传感器网络(WSN)覆盖优化 - 附代码 文章目录 智能优化算法应用&#xff1a;基于生物地理学算法3D无线传感器网络(WSN)覆盖优化 - 附代码1.无线传感网络节点模型2.覆盖数学模型及分析3.生物地理学算法4.实验参数设定5.算法…

Visual studio+Qt开发环境搭建以及注意事项和打开qt的.pro项目

下载qt-然后安装5.14.2_msvc2017 不知道安装那个就全选5.14.2的父级按钮 https://download.qt.io/archive/qt/5.14/5.14.2/ 安装Visual studio,下载直接下一步就行 配置Visual studio的qt环境 在线安装-重启Visual studio会自动安装 离线安装-关闭Visual studio点击安装 关闭…

桂电|《操作系统》实验一:UNIX/LINUX及其使用环境(实验报告)

桂林电子科技大学2023-2024学年 第 一 学期 操作系统A 实验报告 实验名称 实验一 UNIX/LINUX及其使用环境 实验指导老师&#xff1a; 成绩 院 系 计算机与信息安全学院 专业 计算机科学与技术(卓越工程) 学 号 姓名 课内序…

Spring Boot+FreeMarker=打造高效Web应用

&#x1f973;&#x1f973;Welcome Huihuis Code World ! !&#x1f973;&#x1f973; 接下来看看由辉辉所写的关于Spring BootFreeMarker的相关操作吧 目录 &#x1f973;&#x1f973;Welcome Huihuis Code World ! !&#x1f973;&#x1f973; 一. FreeMarker是什么 二…

本地连锁门店经营可以借助系统实现哪些功能?

不少的连锁门店目前还是很基础的ERPPOS收银&#xff0c;其他的还是走传统的手工管理&#xff0c;大多连锁老板知道借助信息化系统可以帮助门店实现精细化管理&#xff0c;提高运营效率&#xff0c;降低成本&#xff0c;增强竞争力&#xff0c;但不知道怎么去做&#xff0c;能做…

每日汇评:黄金需要突破2050美元的供应区域才能延续复苏

周四早间&#xff0c;金价接近每盎司2,030美元&#xff0c;创下6天来的最高水平&#xff1b; 美联储确认鸽派政策转向&#xff0c;美元和美国国债收益率双双下挫&#xff1b; 英国央行和欧洲央行2023年的最终政策公告可能会进一步推高金价&#xff1b; 随着投资者重新评估美联储…

2020年第九届数学建模国际赛小美赛C题亚马逊野火解题全过程文档及程序

2020年第九届数学建模国际赛小美赛 C题 亚马逊野火 原题再现&#xff1a; 野火是指发生在乡村或荒野地区的可燃植被中的任何不受控制的火灾。这样的环境过程对人类生活有着重大的影响。因此&#xff0c;对这一现象进行建模&#xff0c;特别是对其空间发生和扩展进行建模&…

0x13 链表与邻接表

0x13 链表与邻接表 数组是一种支持随机访问&#xff0c;但不支持在任意位置插入和删除元素的数据结构。与之相对应&#xff0c;链表支持在任意位置插入或删除元素&#xff0c;但只能按顺序依次访问其中元素。我们可以使用一个struct来表示链表的节点&#xff0c;其中可以存储任…

《师兄啊师兄》第二季开播 李长寿渡劫归来扬名四海

看新国风&#xff0c;上优酷动漫&#xff01;由优酷出品&#xff0c;玄机科技制作&#xff0c;改编自阅文集团旗下起点读书小说《我师兄实在太稳健了》&#xff08;作者&#xff1a;言归正传&#xff09;的修仙喜剧动画《师兄啊师兄》第二季《海神扬名篇》于今日10:00正式回归。…

如何性能测试中进行业务验证?

在性能测试过程中&#xff0c;验证HTTP code和响应业务code码是比较基础的&#xff0c;但是在一些业务中&#xff0c;这些参数并不能保证接口正常响应了&#xff0c;很可能返回了错误信息&#xff0c;所以这个时候对接口进行业务验证就尤其重要。下面分享一个对某个资源进行业务…

Python多线程threading的使用方法

前言 有时候&#xff0c;我们在编写Python程序时&#xff0c;会遇到比较耗时的函数方法&#xff0c;我们的需求是等这个耗时的函数执行完毕之后&#xff0c;在执行后面的程序&#xff0c;这时候就需要用到多进程。 下面我们来举一个使用多进程threading的例子 例子 import t…

Unity | AVpro的最基础使用方法(视频播放插件)

一、 AVpro的使用方法 (一)准备播放器MediaPlayer 1. AVpro的播放器是MediaPlayer&#xff0c;在Heirarchy面板里创建 2.播放器里放视频 a.把视频放到StreamingAssets文件夹下 b.你就可以在MediaPlayer里面找到这个视频 c.选中以后&#xff0c;就会变成 这里点击播放可以播放…

FET偏置控制器电路的卫星接收器LNB电路

都具有FET偏置控制器电路的卫星接收器LNB电路 芯片的描述&#xff1a;D3211是一-块用于卫星接收LNBs的专用电路&#xff0c;具有极化电压检测切换、22KHz脉冲检测切换和提供高放、本振级GaAs或HEMT FET晶体管工作点偏置等功能。D321 1内部的22K检测及切换控制由22K有源滤波器、…

RT-DETR改进《目标对象计数》多任务实验:深度集成版来了!支持自定义数据集训练自定义模型

💡该教程为改进RT-DETR专栏,属于《芒果书》📚系列,包含大量的原创改进方式🚀 💡🚀🚀🚀内含改进源代码 按步骤操作运行改进后的代码即可💡更方便的统计更多实验数据,方便写作 RT-DETR改进《目标对象计数》多任务实验:深度集成版来了!支持自定义数据集训练…

车联网助力自动驾驶发展

单车智能决策难点 芯片&#xff0c;成为自动驾驶的最大瓶颈 自动驾驶对芯片算力要求极高。要求自动驾驶处理器在每秒能够处理数百万亿次的计算&#xff1b; 自动驾驶对计算的实时性要求极高。任何一点时延&#xff0c;都有可能造成车毁人亡&#xff1b; 对低能耗有极大的…