gdip-yolo project walkthrough: how the gdip, mgdip and GDIP-regularizer modules are used

gdip-yolo is an end-to-end image-adaptive object detection framework proposed in 2022. The paper shows strong image-enhancement results and attributes the gains to the gdip, mgdip and GDIP-regularizer modules. This post digs into the gdip-yolo project code to see how those modules are actually used.
A review of the gdip-yolo paper is available at: https://hpg123.blog.csdn.net/article/details/135658906


1、Overall analysis

The gdip-yolo project is built on top of a yolov3 codebase; compared with the original code, only the training code has been removed. Most of this scaffolding has little to do with the core gdip functionality, and the configuration file is a regular yolov3 training config.

1.1 Configuration file

Shown below is the py-based configuration file used in the gdip-yolo project; compared with the newer yaml/yml config formats it has some drawbacks (a sketch of loading the same settings from YAML follows the config).

# coding=utf-8
# project
DATA_PATH = "/scratch/data"
PROJECT_PATH = "/scratch/"
WEIGHT_PATH = "/scratch/data/weights/darknet53_448.weights"

DATA = {"CLASSES": ['person', 'bicycle', 'car', 'bus', 'motorbike'],
        "NUM": 5}
# DATA = {"CLASSES":['bicycle','boat','bottle','bus','car','cat','chair','dog','motorbike','person'],
#         "NUM":10}

# model
MODEL = {"ANCHORS": [[(1.25, 1.625), (2.0, 3.75), (4.125, 2.875)],              # Anchors for small obj
                     [(1.875, 3.8125), (3.875, 2.8125), (3.6875, 7.4375)],      # Anchors for medium obj
                     [(3.625, 2.8125), (4.875, 6.1875), (11.65625, 10.1875)]],  # Anchors for big obj
         "STRIDES": [8, 16, 32],
         "ANCHORS_PER_SCLAE": 3}

# train
TRAIN = {"TRAIN_IMG_SIZE": 448,
         "AUGMENT": True,
         "BATCH_SIZE": 8,
         "MULTI_SCALE_TRAIN": False,
         "IOU_THRESHOLD_LOSS": 0.5,
         "EPOCHS": 80,
         "NUMBER_WORKERS": 5,
         "MOMENTUM": 0.9,
         "WEIGHT_DECAY": 0.0005,
         "LR_INIT": 1e-4,
         "LR_END": 1e-6,
         "WARMUP_EPOCHS": 2  # or None
         }

# test
TEST = {"TEST_IMG_SIZE": 448,
        "BATCH_SIZE": 1,
        "NUMBER_WORKERS": 0,
        "CONF_THRESH": 0.01,
        "NMS_THRESH": 0.5,
        "MULTI_SCALE_TEST": False,
        "FLIP_TEST": False,
        "DATASET_PATH": "/scratch/data/RTTS",
        "DATASET_DIRECTORY": "JPEGImages"}
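For comparison, the same settings could be kept in a YAML file and loaded back into plain dicts. This is only a sketch of the idea, not something that exists in the repo (the repo only ships the py-style config above), and it assumes PyYAML is installed:

import yaml

yaml_cfg = """
DATA:
  CLASSES: [person, bicycle, car, bus, motorbike]
  NUM: 5
MODEL:
  STRIDES: [8, 16, 32]
  ANCHORS_PER_SCLAE: 3
TEST:
  TEST_IMG_SIZE: 448
  CONF_THRESH: 0.01
  NMS_THRESH: 0.5
"""

cfg = yaml.safe_load(yaml_cfg)     # parses the YAML text into nested dicts
print(cfg["DATA"]["NUM"])          # 5
print(cfg["TEST"]["CONF_THRESH"])  # 0.01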

1.2 Inference and test code

The core of the inference code is the set of evaluator files under the eval directory; the inference scripts at the project top level are just for-loops that call them. Comparing the eval code shows that, apart from differences in data preprocessing, the test code for the different datasets is structurally identical.

from torch.utils.data import DataLoader
import utils.gpu as gpu
from model.yolov3_multilevel_gdip import Yolov3
from tqdm import tqdm
from utils.tools import *
from eval.evaluator_RTTS_GDIP import Evaluator
import argparse
import os
import config.yolov3_config_RTTS as cfg
from utils.visualize import *
from tqdm import tqdm
# import os
# os.environ["CUDA_VISIBLE_DEVICES"] = '0'


class Tester(object):
    def __init__(self, weight_path=None, gpu_id=0, img_size=544, visiual=None, eval=False):
        self.img_size = img_size
        self.__num_class = cfg.DATA["NUM"]
        self.__conf_threshold = cfg.TEST["CONF_THRESH"]
        self.__nms_threshold = cfg.TEST["NMS_THRESH"]
        self.__device = gpu.select_device(gpu_id)
        self.__multi_scale_test = cfg.TEST["MULTI_SCALE_TEST"]
        self.__flip_test = cfg.TEST["FLIP_TEST"]
        self.__visiual = visiual
        self.__eval = eval
        self.__classes = cfg.DATA["CLASSES"]

        self.__model = Yolov3(cfg).to(self.__device)
        self.__load_model_weights(weight_path)
        self.__evalter = Evaluator(self.__model, visiual=False)

    def __load_model_weights(self, weight_path):
        print("loading weight file from : {}".format(weight_path))
        weight = os.path.join(weight_path)
        chkpt = torch.load(weight, map_location=self.__device)
        self.__model.load_state_dict(chkpt)
        # self.__model.load_state_dict(chkpt['model'])
        print("loading weight file is done")
        del chkpt

    def test(self):
        if self.__visiual:
            imgs = os.listdir(self.__visiual)
            for v in tqdm(imgs):
                path = os.path.join(self.__visiual, v)
                # print("test images : {}".format(path))
                img = cv2.imread(path)
                assert img is not None

                bboxes_prd = self.__evalter.get_bbox(img)
                if bboxes_prd.shape[0] != 0:
                    boxes = bboxes_prd[..., :4]
                    class_inds = bboxes_prd[..., 5].astype(np.int32)
                    scores = bboxes_prd[..., 4]
                    visualize_boxes(image=img, boxes=boxes, labels=class_inds, probs=scores, class_labels=self.__classes)
                    path = os.path.join(cfg.PROJECT_PATH, "results/rtts/{}".format(v))
                    cv2.imwrite(path, img)


if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    parser.add_argument('--weight_path', type=str, default='best.pt', help='weight file path')
    parser.add_argument('--visiual', type=str, default='path/to/images', help='test data path or None')
    parser.add_argument('--eval', action='store_true', default=True, help='eval the mAP or not')
    parser.add_argument('--gpu_id', type=int, default=0, help='gpu id')
    opt = parser.parse_args()

    Tester(weight_path=opt.weight_path,
           gpu_id=opt.gpu_id,
           eval=opt.eval,
           visiual=opt.visiual).test()

2、Data loader analysis

The gdip-yolo paper states that no dedicated loss is needed during training, so in theory the original dataloader should suffice. However, the code contains separate loaders for foggy data and for low-light data, so they are analyzed here.

2.1 IA_datasets_foggy.py

The code is in utils\IA_datasets_foggy.py; the key parts are shown below (the code unrelated to fog handling has been removed by this post's author). IA_datasets_foggy returns both img and adv_img, where adv_img is a fogged copy of img produced by the getFog function. Curiously, adv_img is returned even though the paper does not make use of it. A foggy image is produced with probability 0.5, so gdip-yolo relies on online (on-the-fly) augmentation. For comparison, ia-yolo uses the ASM to generate fog at 10 different levels so that its synthetic training set covers more variance, prepares 4952 images from the PascalVOC 2007 test set in the same way as a synthetic test set (called V_F_Ts), and adopts a hybrid strategy that mixes foggy and clean images at a 2:1 ratio, i.e. a fog probability of 0.66.

class VocDataset(Dataset):
    def __getitem__(self, item):
        img_org, adv_img_org, bboxes_org = self.__parse_annotation(self.__annotations[item])
        img_org = img_org.transpose(2, 0, 1)          # HWC->CHW
        adv_img_org = adv_img_org.transpose(2, 0, 1)  # HWC->CHW

        img, adv_img, bboxes = dataAug.Mixup()(img_org, adv_img_org, bboxes_org)
        del img_org, bboxes_org, adv_img_org

        label_sbbox, label_mbbox, label_lbbox, sbboxes, mbboxes, lbboxes = self.__creat_label(bboxes)

        img = torch.from_numpy(img).float()
        adv_img = torch.from_numpy(adv_img).float()
        label_sbbox = torch.from_numpy(label_sbbox).float()
        label_mbbox = torch.from_numpy(label_mbbox).float()
        label_lbbox = torch.from_numpy(label_lbbox).float()
        sbboxes = torch.from_numpy(sbboxes).float()
        mbboxes = torch.from_numpy(mbboxes).float()
        lbboxes = torch.from_numpy(lbboxes).float()

        return img, adv_img, label_sbbox, label_mbbox, label_lbbox, sbboxes, mbboxes, lbboxes

    def __parse_annotation(self, annotation):
        """
        Data augmentation.
        :param annotation: Image path and bboxes' coordinates, categories.
            e.g. [image_path xmin,ymin,xmax,ymax,class_ind xmin,ymin,xmax,ymax,class_ind ...]
        :return: the augmented image pair and bboxes; bbox shape is [xmin, ymin, xmax, ymax, class_ind]
        """
        anno = annotation.strip().split(' ')

        img_path = anno[0]
        img = cv2.imread(img_path)  # H*W*C and C=BGR
        assert img is not None, 'File Not Found ' + img_path
        bboxes = np.array([list(map(float, box.split(','))) for box in anno[1:]])

        img, bboxes = dataAug.RandomHorizontalFilp()(np.copy(img), np.copy(bboxes))
        img, bboxes = dataAug.RandomCrop()(np.copy(img), np.copy(bboxes))
        img, bboxes = dataAug.RandomAffine()(np.copy(img), np.copy(bboxes))

        adv_img = img.copy()  # H*W*C and C=BGR
        if random.randint(0, 2) > 0:
            adv_img = normalize(adv_img)
            fog_img = getFog(adv_img.copy())
            fog_img = fog_img.astype(np.uint8)
            adv_img = fog_img.copy()
        # assert adv_img is not None, 'File Not Found ' + adv_img_path

        img, bboxes = dataAug.Resize((self.img_size, self.img_size), True)(np.copy(img), np.copy(bboxes))
        adv_img, _ = dataAug.Resize((self.img_size, self.img_size), True)(np.copy(adv_img), np.copy(bboxes))

        return img, adv_img, bboxes

2.2 IA_datasets_lightning.py

The code is in utils\IA_datasets_lightning.py; as before, the non-essential code has been removed for the analysis. IA_datasets_lightning follows the same pattern as the previous loader: it additionally returns adv_img, a low-light copy of the original image produced by the getLightning function (gamma sampled uniformly from 1.5 to 5). A low-light image is produced with probability 0.5, again an online augmentation strategy. For comparison, ia-yolo selects PascalVOC images containing the object classes present in ExDark and applies a gamma change (gamma sampled uniformly from 1.5 to 5) to simulate low-light conditions; during training it uses a hybrid strategy (similar to the fog setting) that mixes dark and clean images, i.e. a low-light probability of 0.66.

class VocDataset(Dataset):
    def __getitem__(self, item):
        img_org, adv_img_org, bboxes_org = self.__parse_annotation(self.__annotations[item])
        img_org = img_org.transpose(2, 0, 1)          # HWC->CHW
        adv_img_org = adv_img_org.transpose(2, 0, 1)  # HWC->CHW

        img, adv_img, bboxes = dataAug.Mixup()(img_org, adv_img_org, bboxes_org)
        del img_org, bboxes_org, adv_img_org

        label_sbbox, label_mbbox, label_lbbox, sbboxes, mbboxes, lbboxes = self.__creat_label(bboxes)

        img = torch.from_numpy(img).float()
        adv_img = torch.from_numpy(adv_img).float()
        label_sbbox = torch.from_numpy(label_sbbox).float()
        label_mbbox = torch.from_numpy(label_mbbox).float()
        label_lbbox = torch.from_numpy(label_lbbox).float()
        sbboxes = torch.from_numpy(sbboxes).float()
        mbboxes = torch.from_numpy(mbboxes).float()
        lbboxes = torch.from_numpy(lbboxes).float()

        return img, adv_img, label_sbbox, label_mbbox, label_lbbox, sbboxes, mbboxes, lbboxes

    def __parse_annotation(self, annotation):
        """
        Data augmentation.
        :param annotation: Image path and bboxes' coordinates, categories.
            e.g. [image_path xmin,ymin,xmax,ymax,class_ind xmin,ymin,xmax,ymax,class_ind ...]
        :return: the augmented image pair and bboxes; bbox shape is [xmin, ymin, xmax, ymax, class_ind]
        """
        anno = annotation.strip().split(' ')

        img_path = anno[0]
        img = cv2.imread(img_path)  # H*W*C and C=BGR
        assert img is not None, 'File Not Found ' + img_path
        bboxes = np.array([list(map(float, box.split(','))) for box in anno[1:]])

        img, bboxes = dataAug.RandomHorizontalFilp()(np.copy(img), np.copy(bboxes))
        img, bboxes = dataAug.RandomCrop()(np.copy(img), np.copy(bboxes))
        img, bboxes = dataAug.RandomAffine()(np.copy(img), np.copy(bboxes))

        adv_img = img.copy()  # H*W*C and C=BGR
        if random.randint(0, 2) > 0:
            adv_img = normalize(adv_img)
            l_img = getLightning(adv_img.copy())
            l_img = l_img.astype(np.uint8)
            adv_img = l_img.copy()
        # assert adv_img is not None, 'File Not Found ' + adv_img_path

        img, bboxes = dataAug.Resize((self.img_size, self.img_size), True)(np.copy(img), np.copy(bboxes))
        adv_img, _ = dataAug.Resize((self.img_size, self.img_size), True)(np.copy(adv_img), np.copy(bboxes))

        return img, adv_img, bboxes

2.3 The getFog and getLightning functions

The getFog implementation is shown below; compared with the ia-yolo implementation it takes more lines of code.

def getFog(img):
    h, w, c = img.shape
    x = np.linspace(0, w - 1, w)
    y = np.linspace(0, h - 1, h)
    xx, yy = np.meshgrid(x, y)
    x_c, y_c = w // 2, h // 2

    transmission_map = np.zeros((h, w, 1))
    c = np.random.uniform(0, 9)
    beta = 0.01 * c + 0.05
    A = 0.5
    d = -0.04 * np.sqrt((yy - y_c) ** 2 + (xx - x_c) ** 2) + np.sqrt(np.maximum(h, w))
    transmission_map[:, :, 0] = np.exp(-beta * d)

    fog_img = img * transmission_map + (1 - transmission_map) * A
    # fog_img = normalize(fog_img)
    fog_img = fog_img * 255.
    fog_img = np.clip(fog_img, 0, 255)
    return fog_img
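Reading off the code, the fog is synthesized with the atmospheric scattering model, using the (already normalized) image itself as the clean radiance:

$$
\text{fog\_img}(p) = \text{img}(p)\,t(p) + \bigl(1 - t(p)\bigr)A,\qquad
t(p) = e^{-\beta d(p)},\qquad
d(p) = -0.04\,\lVert p - p_c \rVert + \sqrt{\max(H, W)}
$$

where $A = 0.5$ is the atmospheric light, $\beta = 0.01c + 0.05$ with $c \sim U(0, 9)$ (so $\beta \in [0.05, 0.14]$), and $d$ is a pseudo-depth that is largest at the image center $p_c$.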

getLightning is implemented as follows, based on a gamma transform.

def getLightning(img):
    gamma = np.random.uniform(1.5, 5)
    img = img ** gamma
    img = img * 255.
    img = np.clip(img, 0, 255)
    return img
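Neither function is called outside the two loaders, so here is a minimal stand-alone sketch of how they are used there. The normalize() helper below is my own stand-in and only assumes what the loaders imply, namely that the BGR uint8 image is scaled to [0, 1] before the degradation is applied:

import cv2
import numpy as np

def normalize(img):
    # assumed behaviour of the repo's normalize(): scale a uint8 BGR image to float in [0, 1]
    return img.astype(np.float32) / 255.

img = cv2.imread("demo.jpg")                                  # H*W*C, BGR, uint8
foggy = getFog(normalize(img).copy()).astype(np.uint8)       # fogged copy, back to uint8
dark = getLightning(normalize(img).copy()).astype(np.uint8)  # low-light copy, back to uint8
cv2.imwrite("demo_foggy.jpg", foggy)
cv2.imwrite("demo_dark.jpg", dark)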

3、Key GDIP-yolo modules

The GDIP-yolo paper states that no extra loss is used, so the open-sourced loss code is identical to the original yolov3 loss. The GDIP regularizer module, however, does require an additional loss (an L1/MAE reconstruction term computed against the original image, used as a regularizer), yet no corresponding implementation could be found in the repository.
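Since that loss is missing from the repository, the following is only a speculative sketch of what such a term could look like; the function name, the lambda_reg weight and the choice of a plain L1 reconstruction are my own assumptions, not the paper's or the repo's:

import torch
import torch.nn.functional as F

def gdip_regularizer_loss(out_x, clean_img, lambda_reg=1.0):
    # hypothetical regularizer: pull the GDIP-enhanced image towards the clean image
    # out_x, clean_img: [B, 3, H, W] tensors in [0, 1]
    return lambda_reg * F.l1_loss(out_x, clean_img)

# hypothetical use in a training step with the model from section 3.3:
# out_x, gates, p, p_d = model(adv_img)
# loss = yolo_loss(p, p_d, targets) + gdip_regularizer_loss(out_x, img)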

3.1 GatedDIP

There are several GatedDIP modules in the gdip-yolo project; the one matching the paper's description is used as the reference here. The commented-out code shows that GatedDIP could also use vgg16 as its encoder; in this version the VisionEncoder is built into GatedDIP itself.

import math
import torch
import torchvision
from model.vision_encoder import VisionEncoder


class GatedDIP(torch.nn.Module):
    def __init__(self, encoder_output_dim: int = 256, num_of_gates: int = 7):
        super(GatedDIP, self).__init__()
        print("GatedDIP with custom Encoder!!")
        # Encoder Model
        # self.encoder = torchvision.models.vgg16(pretrained=False)
        self.encoder = VisionEncoder(encoder_output_dim=encoder_output_dim)
        # Gating Module
        self.gate_module = torch.nn.Sequential(torch.nn.Linear(encoder_output_dim, num_of_gates, bias=True))
        # White-Balance Module
        self.wb_module = torch.nn.Sequential(torch.nn.Linear(encoder_output_dim, 3, bias=True))
        # Gamma Module
        self.gamma_module = torch.nn.Sequential(torch.nn.Linear(encoder_output_dim, 1, bias=True))
        # Sharpning Module
        self.gaussian_blur = torchvision.transforms.GaussianBlur(13, sigma=(0.1, 5.0))
        self.sharpning_module = torch.nn.Sequential(torch.nn.Linear(encoder_output_dim, 1, bias=True))
        # De-Fogging Module
        self.defogging_module = torch.nn.Sequential(torch.nn.Linear(encoder_output_dim, 1, bias=True))
        # Contrast Module
        self.contrast_module = torch.nn.Sequential(torch.nn.Linear(encoder_output_dim, 1, bias=True))
        # Tone Module
        self.tone_module = torch.nn.Sequential(torch.nn.Linear(encoder_output_dim, 8, bias=True))

    def rgb2lum(self, img: torch.tensor):
        img = 0.27 * img[:, 0, :, :] + 0.67 * img[:, 1, :, :] + 0.06 * img[:, 2, :, :]
        return img

    def lerp(self, a, b, l: torch.tensor):
        return (1 - l.unsqueeze(2).unsqueeze(3)) * a + l.unsqueeze(2).unsqueeze(3) * b

    def dark_channel(self, x: torch.tensor):
        z = x.min(dim=1)[0].unsqueeze(1)
        return z

    def atmospheric_light(self, x: torch.tensor, dark: torch.tensor, top_k: int = 1000):
        h, w = x.shape[2], x.shape[3]
        imsz = h * w
        numpx = int(max(math.floor(imsz / top_k), 1))
        darkvec = dark.reshape(x.shape[0], imsz, 1)
        imvec = x.reshape(x.shape[0], 3, imsz).transpose(1, 2)
        indices = darkvec.argsort(1)
        indices = indices[:, imsz - numpx:imsz]
        atmsum = torch.zeros([x.shape[0], 1, 3]).cuda()
        # print(imvec[:,indices[0,0]].shape)
        for b in range(x.shape[0]):
            for ind in range(1, numpx):
                atmsum[b, :, :] = atmsum[b, :, :] + imvec[b, indices[b, ind], :]
        a = atmsum / numpx
        a = a.squeeze(1).unsqueeze(2).unsqueeze(3)
        return a

    def blur(self, x: torch.tensor):
        return self.gaussian_blur(x)

    def defog(self, x: torch.tensor, latent_out: torch.tensor, fog_gate: torch.tensor):
        """Defogging module removes fog from the image using the ASM (Atmospheric Scattering Model):
            I(X) = (1-T(X)) * J(X) + T(X) * A(X)
        I(X) => image containing the fog.
        T(X) => transmission map of the image.
        J(X) => true image radiance.
        A(X) => atmospheric scattering factor.
        fog_gate is a gate value in (0., 1.) that enables the defog module."""
        omega = self.defogging_module(latent_out).unsqueeze(2).unsqueeze(3)
        omega = self.tanh_range(omega, torch.tensor(0.1), torch.tensor(1.))
        dark_i = self.dark_channel(x)
        a = self.atmospheric_light(x, dark_i)
        i = x / a
        i = self.dark_channel(i)
        t = 1. - (omega * i)
        j = ((x - a) / (torch.maximum(t, torch.tensor(0.01)))) + a
        j = (j - j.min()) / (j.max() - j.min())
        # j = j * fog_gate.unsqueeze(1).unsqueeze(2).unsqueeze(3)
        return j

    def white_balance(self, x: torch.tensor, latent_out: torch.tensor, wb_gate: torch.tensor):
        """White balance of the image is predicted from the latent output of the encoder."""
        log_wb_range = 0.5
        wb = self.wb_module(latent_out)
        wb = torch.exp(self.tanh_range(wb, -log_wb_range, log_wb_range))
        color_scaling = 1. / (1e-5 + 0.27 * wb[:, 0] + 0.67 * wb[:, 1] + 0.06 * wb[:, 2])
        wb = color_scaling.unsqueeze(1) * wb
        wb_out = wb.unsqueeze(2).unsqueeze(3) * x
        wb_out = (wb_out - wb_out.min()) / (wb_out.max() - wb_out.min())
        # wb_out = wb_gate.unsqueeze(1).unsqueeze(2).unsqueeze(3) * wb_out
        return wb_out

    def tanh01(self, x: torch.tensor):
        return torch.tanh(x) * 0.5 + 0.5

    def tanh_range(self, x: torch.tensor, left, right):
        return self.tanh01(x) * (right - left) + left

    def gamma_balance(self, x: torch.tensor, latent_out: torch.tensor, gamma_gate: torch.tensor):
        log_gamma = torch.log(torch.tensor(2.5))
        gamma = self.gamma_module(latent_out).unsqueeze(2).unsqueeze(3)
        gamma = torch.exp(self.tanh_range(gamma, -log_gamma, log_gamma))
        g = torch.pow(torch.maximum(x, torch.tensor(1e-4)), gamma)
        g = (g - g.min()) / (g.max() - g.min())
        # g = g * gamma_gate.unsqueeze(1).unsqueeze(2).unsqueeze(3)
        return g

    def sharpning(self, x: torch.tensor, latent_out: torch.tensor, sharpning_gate: torch.tensor):
        out_x = self.blur(x)
        y = self.sharpning_module(latent_out).unsqueeze(2).unsqueeze(3)
        y = self.tanh_range(y, torch.tensor(0.1), torch.tensor(1.))
        s = x + (y * (x - out_x))
        s = (s - s.min()) / (s.max() - s.min())
        # s = s * (sharpning_gate.unsqueeze(1).unsqueeze(2).unsqueeze(3))
        return s

    def identity(self, x: torch.tensor, identity_gate: torch.tensor):
        # x = x * identity_gate.unsqueeze(1).unsqueeze(2).unsqueeze(3)
        return x

    def contrast(self, x: torch.tensor, latent_out: torch.tensor, contrast_gate: torch.tensor):
        alpha = torch.tanh(self.contrast_module(latent_out))
        luminance = torch.minimum(torch.maximum(self.rgb2lum(x), torch.tensor(0.0)), torch.tensor(1.0)).unsqueeze(1)
        contrast_lum = -torch.cos(math.pi * luminance) * 0.5 + 0.5
        contrast_image = x / (luminance + 1e-6) * contrast_lum
        contrast_image = self.lerp(x, contrast_image, alpha)
        contrast_image = (contrast_image - contrast_image.min()) / (contrast_image.max() - contrast_image.min())
        # contrast_image = contrast_image * contrast_gate.unsqueeze(1).unsqueeze(2).unsqueeze(3)
        return contrast_image

    def tone(self, x: torch.tensor, latent_out: torch.tensor, tone_gate: torch.tensor):
        curve_steps = 8
        tone_curve = self.tone_module(latent_out).reshape(-1, 1, curve_steps)
        tone_curve = self.tanh_range(tone_curve, 0.5, 2)
        tone_curve_sum = torch.sum(tone_curve, dim=2) + 1e-30
        total_image = x * 0
        for i in range(curve_steps):
            total_image += torch.clamp(x - 1.0 * i / curve_steps, 0, 1.0 / curve_steps) \
                           * tone_curve[:, :, i].unsqueeze(2).unsqueeze(3)
        total_image *= curve_steps / tone_curve_sum.unsqueeze(2).unsqueeze(3)
        total_image = (total_image - total_image.min()) / (total_image.max() - total_image.min())
        # total_image = total_image * tone_gate.unsqueeze(1).unsqueeze(2).unsqueeze(3)
        return total_image

    def forward(self, x: torch.Tensor):
        # latent_out = torch.nn.functional.relu_(self.encoder(x))
        latent_out = self.encoder(x)
        gate = self.tanh_range(self.gate_module(latent_out), 0.01, 1.0)
        out_idx = gate.argmax(dim=1)
        if out_idx == 0:
            wb_out = self.white_balance(x, latent_out, gate[:, 0])
            return wb_out, gate
        elif out_idx == 1:
            gamma_out = self.gamma_balance(x, latent_out, gate[:, 1])
            return gamma_out, gate
        elif out_idx == 2:
            identity_out = self.identity(x, gate[:, 2])
            return identity_out, gate
        elif out_idx == 3:
            sharpning_out = self.sharpning(x, latent_out, gate[:, 3])
            return sharpning_out, gate
        elif out_idx == 4:
            fog_out = self.defog(x, latent_out, gate[:, 4])
            return fog_out, gate
        elif out_idx == 5:
            contrast_out = self.contrast(x, latent_out, gate[:, 5])
            return contrast_out, gate
        else:
            tone_out = self.tone(x, latent_out, gate[:, 6])
            return tone_out, gate


if __name__ == '__main__':
    batch_size = 2
    encoder_out_dim = 256
    x = torch.randn(batch_size, 3, 448, 448)
    x = (x - x.min()) / (x.max() - x.min())
    model = GatedDIP(encoder_output_dim=encoder_out_dim)
    print(model)
    out, gate = model(x)
    print('out shape:', out.shape)
    print('gate shape:', gate.shape)

The vision encoder proposed by the authors is implemented as follows:

import torch


class VisionEncoder(torch.nn.Module):
    def __init__(self, encoder_output_dim=256):
        super(VisionEncoder, self).__init__()
        # conv_1
        self.conv_1 = torch.nn.Sequential(torch.nn.Conv2d(3, 64, kernel_size=3, stride=1), torch.nn.ReLU(True))
        self.max_pool_1 = torch.nn.AvgPool2d((3, 3), (2, 2))
        # conv_2
        self.conv_2 = torch.nn.Sequential(torch.nn.Conv2d(64, 128, kernel_size=3, stride=1), torch.nn.ReLU(True))
        self.max_pool_2 = torch.nn.AvgPool2d((3, 3), (2, 2))
        # conv_3
        self.conv_3 = torch.nn.Sequential(torch.nn.Conv2d(128, 256, kernel_size=3, stride=1), torch.nn.ReLU(True))
        self.max_pool_3 = torch.nn.AvgPool2d((3, 3), (2, 2))
        # conv_4
        self.conv_4 = torch.nn.Sequential(torch.nn.Conv2d(256, 512, kernel_size=3, stride=1), torch.nn.ReLU(True))
        self.max_pool_4 = torch.nn.AvgPool2d((3, 3), (2, 2))
        # conv_5
        self.conv_5 = torch.nn.Sequential(torch.nn.Conv2d(512, 1024, kernel_size=3, stride=1), torch.nn.ReLU(True))
        self.adp_pool_5 = torch.nn.AdaptiveAvgPool2d((1, 1))
        self.linear_proj_5 = torch.nn.Sequential(torch.nn.Linear(1024, encoder_output_dim), torch.nn.ReLU(True))

    def forward(self, x):
        out_x = self.conv_1(x)
        max_pool_1 = self.max_pool_1(out_x)
        out_x = self.conv_2(max_pool_1)
        max_pool_2 = self.max_pool_2(out_x)
        out_x = self.conv_3(max_pool_2)
        max_pool_3 = self.max_pool_3(out_x)
        out_x = self.conv_4(max_pool_3)
        max_pool_4 = self.max_pool_4(out_x)
        out_x = self.conv_5(max_pool_4)
        adp_pool_5 = self.adp_pool_5(out_x)
        linear_proj_5 = self.linear_proj_5(adp_pool_5.view(adp_pool_5.shape[0], -1))
        return linear_proj_5


if __name__ == '__main__':
    img = torch.randn(4, 3, 448, 448).cuda()
    encoder = VisionEncoder(encoder_output_dim=256).cuda()
    print('output shape:', encoder(img).shape)  # output should be [4, 256]

The gdip module is used as shown below. It stays compatible with the original yolo framework code, but additionally returns the enhanced image and the weight of each DIP operation.

import torch 
from model.gdip_model import GatedDIP
from model.yolov3 import Yolov3


class Yolov3GatedDIP(torch.nn.Module):
    def __init__(self):
        super(Yolov3GatedDIP, self).__init__()
        self.gated_dip = GatedDIP(256)
        self.yolov3 = Yolov3()
        # self.yolov3.load_darknet_weights(weights_path)

    def forward(self, x):
        out_x, gates = self.gated_dip(x)
        p, p_d = self.yolov3(out_x)
        return out_x, gates, p, p_d
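A quick smoke test of this wrapper could look like the sketch below; it is my own, and it assumes that the repo's Yolov3() builds with its default arguments (as the wrapper above implies) and that a GPU is available, since parts of GatedDIP call .cuda() internally:

if __name__ == '__main__':
    model = Yolov3GatedDIP().cuda().eval()
    x = torch.rand(1, 3, 448, 448).cuda()  # normalized input in [0, 1]
    out_x, gates, p, p_d = model(x)
    print('enhanced image:', out_x.shape)  # [1, 3, 448, 448]
    print('gates:', gates.shape)           # [1, 7], one weight per DIP operation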

3.2 MultiLevelGDIP

The code is in model\mgdip.py. MultiLevelGDIP re-implements its own GatedDIP, which shows how scattered the gdip-yolo codebase is. This GatedDIP has no built-in vision encoder; instead, the VisionEncoder lives in MultiLevelGDIP and GatedDIP is used as a building block inside it. The VisionEncoder used by mgdip also differs from the gdip one in its return value: to support multiple levels it returns a dict containing the feature projections of each scale.

class GatedDIP(torch.nn.Module):
    """The parts identical to the previous GatedDIP implementation are omitted here."""

    def forward(self, x, linear_proj):
        gate = self.tanh_range(self.gate_module(linear_proj), 0.01, 1.0)
        wb_out = self.white_balance(x, linear_proj, gate[:, 0])
        gamma_out = self.gamma_balance(x, linear_proj, gate[:, 1])
        identity_out = self.identity(x, gate[:, 2])
        sharpning_out = self.sharpning(x, linear_proj, gate[:, 3])
        fog_out = self.defog(x, linear_proj, gate[:, 4])
        contrast_out = self.contrast(x, linear_proj, gate[:, 5])
        tone_out = self.tone(x, linear_proj, gate[:, 6])
        x = wb_out + gamma_out + fog_out + sharpning_out + contrast_out + tone_out + identity_out
        x = (x - x.min()) / (x.max() - x.min())
        return x, gate


class MultiLevelGDIP(torch.nn.Module):
    def __init__(self, encoder_output_dim: int = 256, num_of_gates: int = 7):
        super(MultiLevelGDIP, self).__init__()
        self.vision_encoder = VisionEncoder(encoder_output_dim, base_channel=32)
        self.gdip1 = GatedDIP(encoder_output_dim, num_of_gates)
        self.gdip2 = GatedDIP(encoder_output_dim, num_of_gates)
        self.gdip3 = GatedDIP(encoder_output_dim, num_of_gates)
        self.gdip4 = GatedDIP(encoder_output_dim, num_of_gates)
        self.gdip5 = GatedDIP(encoder_output_dim, num_of_gates)
        self.gdip6 = GatedDIP(encoder_output_dim, num_of_gates)

    def forward(self, x: torch.Tensor):
        out_image = list()
        gates_list = list()
        output_dict = self.vision_encoder(x)

        x, gate_6 = self.gdip6(x, output_dict['linear_proj_6'])
        out_image.append(x)
        gates_list.append(gate_6)

        x, gate_5 = self.gdip5(x, output_dict['linear_proj_5'])
        out_image.append(x)
        gates_list.append(gate_5)

        x, gate_4 = self.gdip4(x, output_dict['linear_proj_4'])
        out_image.append(x)
        gates_list.append(gate_4)

        x, gate_3 = self.gdip3(x, output_dict['linear_proj_3'])
        out_image.append(x)
        gates_list.append(gate_3)

        x, gate_2 = self.gdip2(x, output_dict['linear_proj_2'])
        out_image.append(x)
        gates_list.append(gate_2)

        x, gate_1 = self.gdip1(x, output_dict['linear_proj_1'])
        out_image.append(x)
        gates_list.append(gate_1)

        return x, out_image, gates_list

It is used as follows:

import torch 
from model.mgdip import MultiLevelGDIP
from model.yolov3 import Yolov3


class Yolov3MGatedDIP(torch.nn.Module):
    def __init__(self):
        super(Yolov3MGatedDIP, self).__init__()
        self.mgdip = MultiLevelGDIP(256, 7)
        self.yolov3 = Yolov3()

    def forward(self, x):
        out_x, _, gates_list = self.mgdip(x)
        p, p_d = self.yolov3(out_x)
        return out_x, gates_list, p, p_d

3.3 GDIP regularizer

The code is in model\yolov3_multilevel_gdip.py; this is the mgdip regularizer module. Its MultiLevelGDIP again differs from the previous one: it has no built-in vision encoder and instead takes the three scale outputs of the Yolov3 backbone plus the original input as its feature inputs (i.e. the Yolov3 backbone serves as the feature extractor). The key code is shown below. Note that this MultiLevelGDIP uses no pretrained weights, while the paper states that MultiLevelGDIP can act as a regularizer: after training, the regularizer loss is supposed to make the features extracted by the Yolov3 backbone match those the MultiLevelGDIP path would produce. Here MultiLevelGDIP participates in training, but its output is not used in the forward pass to the detection heads.

class Yolov3(nn.Module):
    """
    Note: in __init__(), the modules must be defined in order, because the weight file is ordered.
    """
    def __init__(self, cfg, init_weights=True):
        super(Yolov3, self).__init__()

        self.__anchors = torch.FloatTensor(cfg.MODEL["ANCHORS"])
        self.__strides = torch.FloatTensor(cfg.MODEL["STRIDES"])
        self.__nC = cfg.DATA["NUM"]
        self.__out_channel = cfg.MODEL["ANCHORS_PER_SCLAE"] * (self.__nC + 5)

        self.__backnone = Darknet53()
        self.__fpn = FPN_YOLOV3(fileters_in=[1024, 512, 256],
                                fileters_out=[self.__out_channel, self.__out_channel, self.__out_channel])
        # small
        self.__head_s = Yolo_head(nC=self.__nC, anchors=self.__anchors[0], stride=self.__strides[0])
        # medium
        self.__head_m = Yolo_head(nC=self.__nC, anchors=self.__anchors[1], stride=self.__strides[1])
        # large
        self.__head_l = Yolo_head(nC=self.__nC, anchors=self.__anchors[2], stride=self.__strides[2])
        # multilevel gdip
        self.__multilevel_gdip = MultiLevelGDIP()

        if init_weights:
            self.__init_weights()

    def forward(self, x):
        out = []

        x_s, x_m, x_l = self.__backnone(x)
        out_x, img_list, gates_list = self.__multilevel_gdip(x, x_s, x_m, x_l)
        x_s, x_m, x_l = self.__fpn(x_l, x_m, x_s)

        out.append(self.__head_s(x_s))
        out.append(self.__head_m(x_m))
        out.append(self.__head_l(x_l))

        if self.training:
            p, p_d = list(zip(*out))
            return out_x, gates_list[-1], p, p_d  # small, medium, large
        else:
            p, p_d = list(zip(*out))
            return out_x, gates_list[-1], p, torch.cat(p_d, 0)
