【目标检测】“复制-粘贴 copy-paste” 数据增强实现

文章目录

  • 前言
  • 1. 效果展示
  • 代码说明
  • 3. 参考文档
  • 4. 不合适点


前言

本文来源论文《Simple Copy-Paste is a Strong Data Augmentation Method
for Instance Segmentation》(CVPR2020),对其数据增强方式进行实现。

论文地址:https://arxiv.org/abs/2012.07177

解读:https://mp.weixin.qq.com/s/nKC3bEe3m1eqPDI0LpVTIA

主要思想:
在这里插入图片描述

本文参考该数据增强的语义分割实现[1],相应修改为对应目标检测的实现,坐标变换的写法参考[2]。

其中,对应的标注信息为txt格式,如果自己的数据集是VOC或COCO格式,可自行修改,也可先转换成txt格式再使用下述代码。


1. 效果展示

数据来源CCPD2019数据集,下图分别为img_main和img_src:

在这里插入图片描述
将img_src的车牌目标“复制-粘贴”到img_main的结果:
在这里插入图片描述
新生成的图片大小与img_main一致,空白的部分会补灰边。

代码说明

'''
Descripttion: Data Augment for Object Detection.
version: 1.0.0
Author: lakuite
Date: 2021-08-06 13:37:38
Copyright: Copyright(c) 2021 lakuite. All Rights Reserved
'''import numpy as np
import cv2
import os
import tqdm
import argparse
from skimage.draw import polygon
import randomdef random_flip_horizontal(img, box, p=0.5):'''对img和mask随机进行水平翻转。box为二维np.array。https://blog.csdn.net/weixin_41735859/article/details/106468551img[:,:,::-1] gbr-->bgr、img[:,::-1,:] 水平翻转、img[::-1,:,:] 上下翻转'''if np.random.random() < p:w = img.shape[1]img = img[:, ::-1, :]box[:, [0, 2, 4, 6]] = w - box[:, [2, 0, 6, 4]] # 仅针对4个点变换return img, boxdef Large_Scale_Jittering(img, box, min_scale=0.1, max_scale=2.0):'''对img和box进行0.1-2.0的大尺度抖动,并变回h*w的大小。'''rescale_ratio = np.random.uniform(min_scale, max_scale)h, w, _ = img.shape# rescaleh_new, w_new = int(h * rescale_ratio), int(w * rescale_ratio)img = cv2.resize(img, (w_new, h_new), interpolation=cv2.INTER_LINEAR)# crop or padding# x,y是随机选择左上角的一个点,让小图片在这个位置,或者让大图片从这个位置开始裁剪x, y = int(np.random.uniform(0, abs(w_new - w))), int(np.random.uniform(0, abs(h_new - h)))# 如果图像缩小了,那么其余部分要填充为像素168大小if rescale_ratio <= 1.0:  # paddingimg_pad = np.ones((h, w, 3), dtype=np.uint8) * 168img_pad[y:y + h_new, x:x + w_new, :] = imgbox[:, [0, 2, 4, 6]] = box[:, [0, 2, 4, 6]] * w_new/w + x # x坐标box[:, [1, 3, 5, 7]] = box[:, [1, 3, 5, 7]] * h_new/h + y # y坐标return img_pad, box# 如果图像放大了,那么要裁剪成h*w的大小else:  # cropimg_crop = img[y:y + h, x:x + w, :]box[:, [0, 2, 4, 6]] = box[:, [0, 2, 4, 6]] * w_new/w - xbox[:, [1, 3, 5, 7]] = box[:, [1, 3, 5, 7]] * h_new/h - yreturn img_crop, boxdef img_add(img_src, img_main, mask_src, box_src):'''将src加到main图像中,结果图还是main图像的大小。'''if len(img_main.shape) == 3:h, w, c = img_main.shapeelif len(img_main.shape) == 2:h, w = img_main.shapesrc_h, src_w = img_src.shape[0], img_src.shape[1]mask = np.asarray(mask_src, dtype=np.uint8)# mask是二值图片,对src进行局部遮挡,即只露出目标物体的像素。sub_img01 = cv2.add(img_src, np.zeros(np.shape(img_src), dtype=np.uint8), mask=mask) # 报错深度不一致mask_02 = cv2.resize(mask, (w, h), interpolation=cv2.INTER_NEAREST)mask_02 = np.asarray(mask_02, dtype=np.uint8)sub_img02 = cv2.add(img_main, np.zeros(np.shape(img_main), dtype=np.uint8),mask=mask_02) # 在main图像上对应位置挖了一块# main图像减去要粘贴的部分的图,然后加上复制过来的图img_main = img_main - sub_img02 + cv2.resize(sub_img01, (w, h),interpolation=cv2.INTER_NEAREST)box_src[:, [0, 2, 4, 6]] = box_src[:, [0, 2, 4, 6]] * w/src_wbox_src[:, [1, 3, 5, 7]] = box_src[:, [1, 3, 5, 7]] * h/src_hreturn img_main, box_srcdef normal_(jpg_path, txt_path="", box=None):"""根据txt获得box或者根据box获得mask。:param jpg_path: 图片路径:param txt_path: x1,y1,x2,y2 x3,y3,x4,y4...:param box: 如果有box,则为根据box生成mask:return: 图像,box 或 掩码"""if isinstance(jpg_path, str): # 如果是路径就读取图片jpg_path = cv2.imread(jpg_path)img = jpg_path.copy()if box is None: # 一定有txt_pathlines = open(txt_path).readlines()box = []for line in lines:ceils = line.strip().split(',')xy = []for ceil in ceils:xy.append(round(float(ceil)))box.append(np.array(xy))return np.array(img), np.array(box)else: # 获得maskh, w = img.shape[:2]mask = np.zeros((h, w), dtype=np.float32)for xy in box: # 对每个框xy = np.array(xy).reshape(-1, 2)cv2.fillPoly(mask, [xy.astype(np.int32)], 1)return np.array(mask)def is_coincide(polygon_1, polygon_2):'''判断2个四边形是否重合:param polygon_1: [x1, y1,...,x4, y4]:param polygon_2::return:  bool,1表示重合'''rr1, cc1 = polygon([polygon_1[i] for i in range(0, len(polygon_1), 2)],[polygon_1[i] for i in range(1, len(polygon_1), 2)])rr2, cc2 = polygon([polygon_2[i] for i in range(0, len(polygon_2), 2)],[polygon_2[i] for i in range(1, len(polygon_2), 2)])try: # 能包含2个四边形的最小矩形长宽r_max = max(rr1.max(), rr2.max()) + 1c_max = max(cc1.max(), cc2.max()) + 1except:return 0# 相当于canvas是包含了2个多边形的一个画布,有2个多边形的位置像素为1,重合位置像素为2canvas = np.zeros((r_max, c_max))canvas[rr1, cc1] += 1canvas[rr2, cc2] += 1intersection = np.sum(canvas == 2)return 1 if intersection!=0 else 0def copy_paste(img_main_path, img_src_path, txt_main_path, txt_src_path, coincide=False, muti_obj=True):'''整个复制粘贴操作,输入2张图的图片和坐标路径,返回其融合后的图像和坐标结果。1. 传入随机选择的main图像和src图像的img和txt路径;2. 对其进行随机水平翻转;3. 对其进行随机抖动;4. 获得src变换完后对应的mask;5. 将src的结果加到main中,返回对应main_new的img和src图的box.'''# 读取图像和坐标img_main, box_main = normal_(img_main_path, txt_main_path)img_src, box_src = normal_(img_src_path, txt_src_path)# 随机水平翻转img_main, box_main = random_flip_horizontal(img_main, box_main)img_src, box_src = random_flip_horizontal(img_src, box_src)# LSJ, Large_Scale_Jittering 大尺度抖动,并变回h*w大小img_main, box_main = Large_Scale_Jittering(img_main, box_main)img_src, box_src = Large_Scale_Jittering(img_src, box_src)if not muti_obj or box_src.ndim==1: # 只复制粘贴一个目标id = random.randint(0, len(box_src)-1)box_src = box_src[id]box_src = box_src[np.newaxis, :] # 增加一维# 获得一系列变换后的img_src的maskmask_src = normal_(img_src_path, box=box_src)# 将src结果加到main图像中,返回main图像的大小的叠加图img, box_src = img_add(img_src, img_main, mask_src, box_src)# 判断融合后的区域是否重合if not coincide:for point_main in box_main:for point_src in box_src:if is_coincide(point_main, point_src):return None, Nonebox = np.vstack((box_main, box_src))return img, boxdef save_res(img, img_path, box, txt_path):'''保存图片和txt坐标结果。'''cv2.imwrite(img_path, img)h, w = img.shape[:2]with open(txt_path, 'w+') as ftxt:for point in box: # [x1,y1,...x4,,y4]strxy = ""for i, p in enumerate(point):if i%2==0: # x坐标p = np.clip(p, 0, w-1)else: # y坐标p = np.clip(p, 0, h-1)strxy = strxy +  str(p) + ','strxy = strxy[:-1] # 去掉最后一个逗号ftxt.writelines(strxy + "\n")def main(args):# 图像和坐标txt文件输入路径JPEGs = os.path.join(args.input_dir, 'jpg')BOXes = os.path.join(args.input_dir, 'txt')# 输出路径os.makedirs(args.output_dir, exist_ok=True)os.makedirs(os.path.join(args.output_dir, 'cpAug_jpg'), exist_ok=True)os.makedirs(os.path.join(args.output_dir, 'cpAug_txt'), exist_ok=True)# 参与数据增强的图片名称,不含后缀imgs_list = open(args.aug_txt, 'r').read().splitlines()flag = '.jpg' # 图像的后缀名 .jpg ,pngtbar = tqdm.tqdm(imgs_list, ncols=100)  # 进度条显示for src_name in tbar:# src图像img_src_path = os.path.join(JPEGs, src_name+flag)txt_src_path = os.path.join(BOXes, src_name+'.txt')# 随机选择main图像main_name = np.random.choice(imgs_list)img_main_path = os.path.join(JPEGs, main_name+flag)txt_main_path = os.path.join(BOXes, main_name+'.txt')# 数据增强img, box = copy_paste(img_main_path, img_src_path, txt_main_path, txt_src_path,args.coincide, args.muti_obj)if img is None:continue# 保存结果img_name = "copy_" + src_name + "_paste_" + main_namesave_res(img, os.path.join(args.output_dir, 'cpAug_jpg', img_name+flag),box, os.path.join(args.output_dir, 'cpAug_txt', img_name+'.txt'))def get_args():parser = argparse.ArgumentParser()parser.add_argument("--input_dir", default="./input_dir", type=str,help="要进行数据增强的图像路径,路径结构下应有jpg和txt文件夹")parser.add_argument("--output_dir", default="./output_dir", type=str,help="保存数据增强结果的路径")parser.add_argument("--aug_txt", default="./input_dir/test.txt",type=str, help="要进行数据增强的图像的名字,不包含后缀")parser.add_argument("--coincide", default=False, type=bool,help="True表示允许数据增强后的图像目标出现重合,默认不允许重合")parser.add_argument("--muti_obj", default=False, type=bool,help="True表示将src图上的所有目标都复制粘贴,False表示只随机粘贴一个目标")return parser.parse_args()if __name__ == "__main__":args = get_args()main(args)
  1. 图像路径:
    在这里插入图片描述
    input_dir存放要数据增强的图片和其对应的txt,其中图片和txt名称应相同,图片后缀可修改 flag,默认为.jpg。output_dir输出数据增强后的图片,无需创建。

  2. 需进行增强的图片列表test.txt,不含后缀:

生成test.txt代码[3]:
在这里插入图片描述

# 获取验证集训练集划分的txt文件,划分仅保存名字,不包含后缀import os
import randomrandom.seed(0)xmlfilepath = './input_dir/txt' # 标签路径
saveBasePath = "./input_dir" # 保存的位置trainval_percent = 0.9 # 训练+验证集的比例,不为1说明有测试集
train_percent = 1 # 训练集在训练+验证集中占的比例,如果代码是从训练集分出的验证集,那就不用改temp_xml = os.listdir(xmlfilepath)
total_xml = []
for xml in temp_xml:if xml.endswith(".txt"):total_xml.append(xml)num = len(total_xml)
list = range(num)
tv = int(num * trainval_percent)
tr = int(tv * train_percent)
trainval = random.sample(list, tv)
train = random.sample(trainval, tr)print("train and val size", tv)
print("traub suze", tr)
ftrainval = open(os.path.join(saveBasePath, 'trainval.txt'), 'w')
ftest = open(os.path.join(saveBasePath, 'test.txt'), 'w')
ftrain = open(os.path.join(saveBasePath, 'train.txt'), 'w')
fval = open(os.path.join(saveBasePath, 'val.txt'), 'w')for i in list:name = total_xml[i][:-4] + '\n'if i in trainval:ftrainval.write(name)if i in train:ftrain.write(name)else:fval.write(name)else:ftest.write(name)ftrainval.close()
ftrain.close()
fval.close()
ftest.close()

运行后可在input_dir下生成4个.txt,其中test.txt仅包含10% input_dir中的图片。

3.标签txt格式:
在这里插入图片描述

3. 参考文档

参考文档
[1] 代码复现:Copy-Paste 数据增强for 语义分割 https://blog.csdn.net/oyezhou/article/details/111696577

[2] 目标检测中的数据增强方法(附详细代码讲解)https://www.cnblogs.com/xiamuzi/p/13471386.html

4. 不合适点

以上是人家的代码,但用在我这边不合适,是因为:它的车牌不会有交叉覆盖,我的是烟火识别,
烟和火是两个目标,有覆盖。 所以不合适。

import globimport cv2
import numpy as np
import randomdef crop_image(image, x, y, width, height):cropped_image = image[y:y + height, x:x + width]return cropped_imagedef convert_to_absolute(label, image_width, image_height):class_id, relative_x_center, relative_y_center, relative_width, relative_height = label# 计算边界框的绝对坐标absolute_x_center = relative_x_center * image_widthabsolute_y_center = relative_y_center * image_heightabsolute_width = relative_width * image_widthabsolute_height = relative_height * image_height# 计算边界框的左上角和右下角坐标left = absolute_x_center - absolute_width / 2top = absolute_y_center - absolute_height / 2right = absolute_x_center + absolute_width / 2bottom = absolute_y_center + absolute_height / 2# 返回绝对坐标形式的边界框return [class_id, left, top, right, bottom]def convert_to_yolo_format(class_id, left, top, right, bottom, image_width, image_height):# 计算目标框的中心点坐标和宽高x = (left + right) / 2y = (top + bottom) / 2width = right - leftheight = bottom - top# 将坐标和尺寸归一化到[0, 1]之间x /= image_widthy /= image_heightwidth /= image_widthheight /= image_height# 返回Yolo格式的标注return f"{class_id} {x} {y} {width} {height}"def get_src():img_list = glob.glob(r"E:\Dataset\zhongwaiyun\data_fire(1w)\data_fire(1w)\scr_copy_paste\images\*.jpg")random.shuffle(img_list)img_path = img_list[0]txt_path = img_list[0].replace("images", "txt").replace(".jpg", ".txt")return img_path, txt_pathimg_list = glob.glob(r"E:\Dataset\zhongwaiyun\zwy_make_background\*.jpg")
for img_b_path in img_list:img_a_path, img_a_txt = get_src()image_a = cv2.imread(img_a_path)image_height, image_width, _ = image_a.shapeimg_b_txt = img_b_path.replace(".jpg", ".txt").replace("zwy_make_background", "zwy_make_fire_and_smoke")img_b_path_new = img_b_path.replace("zwy_make_background", "zwy_make_fire_and_smoke")src_location_map = []with open(img_a_txt) as f:for line_str in f:line_info = line_str.strip().split(" ")label = [int(line_info[0]), float(line_info[1]), float(line_info[2]), float(line_info[3]),float(line_info[4])]class_id, left, top, right, bottom = convert_to_absolute(label, image_width, image_height)src_location_map.append([class_id, left, top, right, bottom])image_b = cv2.imread(img_b_path)res_list = []for row in src_location_map:class_id, left, top, right, bottom = rowif left or top or right or bottom:try:# 目标可以出现在空白图片的任何位置,只要没有超过限制即可x = int(left)  # 指定区域的起始横坐标y = int(top)  # 指定区域的起始纵坐标width = int(right - left)  # 指定区域的宽度height = int(bottom - top)  # 指定区域的高度cropped_image_a = crop_image(image_a, int(x), int(y), int(width), int(height))image_b_height, image_b_width, _ = image_b.shapeb_x = random.randint(0, int(image_b_width - width - 5))b_y = random.randint(0, int(image_b_height - height - 5))image_b[b_y:b_y + height, b_x:b_x + width] = cropped_image_ares = convert_to_yolo_format(class_id, b_x, b_y, b_x + width, b_y + height, image_b_width, image_b_height)print("--==", img_b_txt)with open(img_b_txt, "a") as f:f.write(res)cv2.imwrite(img_b_path_new, image_b)breakexcept:break

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.mzph.cn/news/53359.shtml

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

[matlab]matlab配置mingw64编译器

第一步&#xff1a;下载官方绿色版本mingw64编译器然后解压放到一个非中文空格路径下面 比如我mingw64-win是我随便改的文件名&#xff0c;然后添加环境变量&#xff0c;选择用户或者系统环境变量添加下面的变量 变量名&#xff1a; MW_MINGW64_LOC 变量值&#xff1a;自己的m…

原生小案例:如何使用HTML5 Canvas构建画板应用程序

使用HTML5 Canvas构建绘图应用是在Web浏览器中创建交互式和动态绘图体验的绝佳方式。HTML5 Canvas元素提供了一个绘图表面&#xff0c;允许您操作像素并以编程方式创建各种形状和图形。本文将为您提供使用HTML5 Canvas创建绘图应用的概述和指导。此外&#xff0c;它还将通过解释…

解决无法远程连接MySQL服务的问题

① 设置MySQL中root用户的权限&#xff1a; [rootnginx-dev etc]# mysql -uroot -pRoot123 mysql> use mysql; mysql> GRANT ALL PRIVILEGES ON *.* TO root% IDENTIFIED BY Root123 WITH GRANT OPTION; mysql> select host,user,authentication_string from user; -…

【es6】中的Generator

Generator 一、Generator 是什么&#xff1f;1.1 与普通函数写法不一样&#xff0c;有两个不同 二、Generator 使用2.1 书写方法 三、yield语句3.1 yield和return3.2 注意事项3.3 yield*语句3.4 yield*应用 四、next方法4.1参数4.2 运行逻辑 五、异步解决方案六、Generator相关…

Java“牵手”根据关键词搜索(分类搜索)lazada商品列表页面数据获取方法,lazadaAPI实现批量商品数据抓取示例

lazada商城是一个网上购物平台&#xff0c;售卖各类商品&#xff0c;包括服装、鞋类、家居用品、美妆产品、电子产品等。要获取lazada商品列表和商品详情页面数据&#xff0c;您可以通过开放平台的接口或者直接访问lazada商城的网页来获取商品详情信息。以下是两种常用方法的介…

量子非凡暴风去广告接口

>>>https://videos.centos.chat/lzffbf.php/?url 免费提供综合去广告接口&#xff0c;各位请友好调用

Android 使用模拟器模拟Linux操作系统

1. 简介 在Android手机上使用模拟器模拟ubuntu等操作系统&#xff0c;便于测试 2. 软件准备 Termux&#xff1a;是一款 Android 终端模拟器和 Linux 环境应用程序&#xff0c;无需 root 或设置即可直接运行。虽然酷安和谷歌菜市场都能下载&#xff0c;但这些渠道都很久没更新…

集成学习:Bagging, Boosting,Stacking

目录 集成学习 一、bagging 二、boosting Bagging VS Boosting 1.1 集成学习是什么&#xff1f; Bagging Boosting Stacking 总结 集成学习 好比人做出一个决策时&#xff0c;会从不同方面&#xff0c;不同角度&#xff0c;不同层次去思考&#xff08;多个自我&am…

保姆级别0-10级完整的在线企业帮助中心、官网博客搭建教程及工具

在今天的数字时代&#xff0c;拥有一个完善的在线企业帮助中心和官网博客是非常重要的。这些平台不仅可以提供有关产品和服务的信息&#xff0c;还可以增加客户互动&#xff0c;建立品牌声誉。在这个教程中&#xff0c;我们将从零开始创建一个帮助中心和官网博客&#xff0c;包…

Config:服务端连接Git配置

创建子模块 Pom文件 <?xml version"1.0" encoding"UTF-8"?> <project xmlns"http://maven.apache.org/POM/4.0.0"xmlns:xsi"http://www.w3.org/2001/XMLSchema-instance"xsi:schemaLocation"http://maven.apache.org…

飞天使-k8s基础组件分析-安全

文章目录 名称空间解释访问kubernetes API的控制RBAC的介绍 kubeconfig用户的创建集群默认角色 给组创建授权针对pod配置服务账户参考文档 名称空间解释 名字是啥&#xff1f; 答&#xff1a;集群中每个对象的名称对于该类型的资源都是唯一的。并且每一个对象在整个集群中也有…

微服务基础知识

文章目录 微服务基础知识一、系统架构的演变1、单体应用架构2、垂直应用架构3、分布式SOA架构&#xff08;1&#xff09;什么是SOA&#xff08;2&#xff09;SOA架构 4、微服务架构5、SOA和微服务的关系&#xff08;1&#xff09;SOA&#xff08;2&#xff09;微服务架构 二、分…

【QT】progressBar的使用(13)

progressBar多用于记录程序运行的时间、文件下载的时间等等&#xff0c;今天就来看一下&#xff0c;如何熟练运用progressBar。 一.环境配置 1.python 3.7.8 可直接进入官网下载安装&#xff1a;Download Python | Python.org 2.QT Designer 官方下载路径&#xff1a;Qt…

数据处理 | Python实现基于DFCP张量分解结合贝叶斯优化的缺失数据填补

数据处理 | Python实现基于DFCP张量分解结合贝叶斯优化的缺失数据填补 目录 数据处理 | Python实现基于DFCP张量分解结合贝叶斯优化的缺失数据填补实践过程基本介绍研究背景程序设计参考资料实践过程 基本介绍 数据处理 | Python实现基于DFCP张量分解结合贝叶斯优化的缺失数据填…

【LeetCode-中等题】48. 旋转图像

文章目录 题目方法一&#xff1a;使用辅助数组矩阵 行列的规律方法二&#xff1a;原地修改 递推公式 题目 方法一&#xff1a;使用辅助数组矩阵 行列的规律 public void rotate(int[][] matrix) {int n matrix.length;int[][] matrix_new new int[n][n];for(int i 0 ; i<…

深入浅出AXI协议(2)——通道及信号

一、前言 在之前的文章中&#xff0c;我们主要介绍了什么是AXI协议&#xff0c;AXI协议的特点与优点&#xff0c;然后对于AXI协议非常重要的五通道结构进行了介绍&#xff0c;了解了5个通道各自的作用。本文我们继续AXI协议的学习&#xff0c;我们将讨论5个通道的具体内容和相对…

uni、js——点击与禁用(不可点击)、动态样式class

案例 没约满的时间可以点击进行选择&#xff0c;约满的就不能选择了。选择完之后变色变字。 核心思想就是创建一个第三方变量存起来&#xff0c;点击谁就存到第三方&#xff0c;在根据这个进行判断。 代码 <template><view class"content"><view cl…

JavaScript:基本语法(变量与函数的定义与使用)

文章目录 script 标签srcdefer 延迟加载 基本语法定义变量 与 使用变量基本类型typeof 查看变量类型复合类型数组类型定义对象类型定义 函数定义函数使用函数 script 标签 src 和scc一样可以内嵌也可以外src外引。 一般是推荐外引。 <script src"idx.js">&l…

opencv 进阶15-检测DoG特征并提取SIFT描述符cv2.SIFT_create()

前面我们已经了解了Harris函数来进行角点检测&#xff0c;因为角点的特性&#xff0c;这些角点在图像旋转的时候也可以被检测到。但是&#xff0c;如果我们放大或缩小图像时&#xff0c;就可能会丢失图像的某些部分&#xff0c;甚至有可能增加角点的质量。这种损失的现象需要一…

ant design自定义展开折叠查看子项和点击行查看详情

实现思路&#xff1a;通过配置rowSelection&#xff0c;列表项是否可选择来实现。 页面内容&#xff1a; <a-table :dataSource"integrationBonds" :columns"columns" :customRow"customintegrationBondsRow":pagination"{hideOnSingle…