BasicSR项目(通用图像超分、修复、增强工具库)介绍

项目地址:https://github.com/XPixelGroup/BasicSR
文档地址:https://github.com/XPixelGroup/BasicSR-docs/releases
在这里插入图片描述

BasicSR 是一个开源项目,旨在提供一个方便易用的图像、视频的超分、复原、增强的工具箱。BasicSR 代码库从2018年4月20日开始第一个提交,然后随着做研究、打比赛、发论文,逐渐发展与完善起来。它从最开始的针对超分辨率算法到后来拓展到其他更多复原增强相关的算法,
因此,BasicSR 中 SR 的涵义也从 Super-Resolution 延拓到 Super-Restoration。2022年5月9日,BasicSR 迎来新的里程碑,它加入到 XPixel 大家庭中,和更多的小伙伴们一起致力于把 BasicSR 建设得更好!

1、基本说明

1.1 支持的模型

在BasicSR项目中支持的模型如下所示,虽然数量不多,但是可以轻易的添加其他模型结构到BasicSR项目中。很多最新的图像超分论文也是基于BasicSR项目完成,但代码更新没有同步到BasicSR库中。

1.2 运行环境要求

Python 和 Python 库 (对于 Python 库,我们提供了相应的安装脚本):
a) Python >= 3.7 (推荐使用Anaconda或者Miniconda)
b) PyTorch >= 1.7:目前深度学习领域广泛使用的深度学习框架

1.3 项目安装

打开https://github.com/XPixelGroup/BasicSR?tab=readme-ov-file,下载项目,然后在命令行里执行: pip install -e .
在这里插入图片描述
在上图可以看到,特殊原因导致安装失败。参考:https://blog.csdn.net/weixin_46455141/article/details/131353266 ,执行命令,pip config set global.index-url https://mirrors.aliyun.com/pypi/simple 更换源,然后重新执行安装命令 pip install -e . 可以看到成功安装。在这里插入图片描述

1.4 特殊算子支持

通过1.3方式安装的库不支持DCN(可变形卷积)、StyleGAN 中的特定的算子,比如:upfirdn2d, fused_act。安装时需要附加额外信息,支持可变形卷积。若无特殊需求,可以忽略。
需编译特殊算子的安装命令如下,可以看到是多了一个环境变量参数BASICSR_EXT=True

BASICSR_EXT=True pip install -e .

作者也提到,可以在执行代码时加入BASICSR_JIT=True参数,即时加加加载载载 (JIT) PyTorch C++ 编译算子

BASICSR_JIT=True python inference/inference_stylegan2.py

二者对比如下:
在这里插入图片描述

2、项目代码结构

2.1 基本结构

红色 表示和跑实验直接相关的文件,即我们平时打交道最多的文件;
蓝色 表示其他与 BasicSR 存在相关的代码文件;
通常只需要了解红色的部分即可。
在这里插入图片描述
basicsr目录下是该库的核心代码,其目录结构如下所示。可以看到关于模型有archs与models,archs才是模型网络结构与forward的定义。
在这里插入图片描述

2.2 models详情

models目录下详情如下,包含多种模型结构,主要以SRModel与SRRANModel为基类,最原始的基类是BaseModel。
在这里插入图片描述
快速过base_model.py,可以发现有函数model_ema,用于实现模型参数的指数更新。

对比SRModel与SRRANModel,可以发现SRRANModel多了一个net_d相关的参数(应该是鉴别器),对net_d进行检索,可以发现,SRRANModel,同样多了cri_gan(loss函数)的使用。与之对应,配置文件里当有相应的配置项。

        self.optimizer_d.zero_grad()# realreal_d_pred = self.net_d(self.gt)l_d_real = self.cri_gan(real_d_pred, True, is_disc=True)loss_dict['l_d_real'] = l_d_realloss_dict['out_d_real'] = torch.mean(real_d_pred.detach())l_d_real.backward()# fakefake_d_pred = self.net_d(self.output.detach())l_d_fake = self.cri_gan(fake_d_pred, False, is_disc=True)loss_dict['l_d_fake'] = l_d_fakeloss_dict['out_d_fake'] = torch.mean(fake_d_pred.detach())l_d_fake.backward()self.optimizer_d.step()

在观察RealESRNetModel与RealESRGANModel,可以发现其在feed_data函数上与原始基类不一样,代码含量极大。具体如下所示,可以看出其是包含了在线下采样策略。对其入参data进行分析,可以看到多了kernel1,kernel2,sinc_kernel等属性项,这表明这两类模型使用的dataload与原始基类模型不一样。

    @torch.no_grad()def feed_data(self, data):"""Accept data from dataloader, and then add two-order degradations to obtain LQ images."""if self.is_train and self.opt.get('high_order_degradation', True):# training data synthesisself.gt = data['gt'].to(self.device)# USM sharpen the GT imagesif self.opt['gt_usm'] is True:self.gt = self.usm_sharpener(self.gt)self.kernel1 = data['kernel1'].to(self.device)self.kernel2 = data['kernel2'].to(self.device)self.sinc_kernel = data['sinc_kernel'].to(self.device)ori_h, ori_w = self.gt.size()[2:4]# ----------------------- The first degradation process ----------------------- ## blurout = filter2D(self.gt, self.kernel1)# random resizeupdown_type = random.choices(['up', 'down', 'keep'], self.opt['resize_prob'])[0]if updown_type == 'up':scale = np.random.uniform(1, self.opt['resize_range'][1])elif updown_type == 'down':scale = np.random.uniform(self.opt['resize_range'][0], 1)else:scale = 1mode = random.choice(['area', 'bilinear', 'bicubic'])out = F.interpolate(out, scale_factor=scale, mode=mode)# add noisegray_noise_prob = self.opt['gray_noise_prob']if np.random.uniform() < self.opt['gaussian_noise_prob']:out = random_add_gaussian_noise_pt(out, sigma_range=self.opt['noise_range'], clip=True, rounds=False, gray_prob=gray_noise_prob)else:out = random_add_poisson_noise_pt(out,scale_range=self.opt['poisson_scale_range'],gray_prob=gray_noise_prob,clip=True,rounds=False)# JPEG compressionjpeg_p = out.new_zeros(out.size(0)).uniform_(*self.opt['jpeg_range'])out = torch.clamp(out, 0, 1)  # clamp to [0, 1], otherwise JPEGer will result in unpleasant artifactsout = self.jpeger(out, quality=jpeg_p)# ----------------------- The second degradation process ----------------------- ## blurif np.random.uniform() < self.opt['second_blur_prob']:out = filter2D(out, self.kernel2)# random resizeupdown_type = random.choices(['up', 'down', 'keep'], self.opt['resize_prob2'])[0]if updown_type == 'up':scale = np.random.uniform(1, self.opt['resize_range2'][1])elif updown_type == 'down':scale = np.random.uniform(self.opt['resize_range2'][0], 1)else:scale = 1mode = random.choice(['area', 'bilinear', 'bicubic'])out = F.interpolate(out, size=(int(ori_h / self.opt['scale'] * scale), int(ori_w / self.opt['scale'] * scale)), mode=mode)# add noisegray_noise_prob = self.opt['gray_noise_prob2']if np.random.uniform() < self.opt['gaussian_noise_prob2']:out = random_add_gaussian_noise_pt(out, sigma_range=self.opt['noise_range2'], clip=True, rounds=False, gray_prob=gray_noise_prob)else:out = random_add_poisson_noise_pt(out,scale_range=self.opt['poisson_scale_range2'],gray_prob=gray_noise_prob,clip=True,rounds=False)# JPEG compression + the final sinc filter# We also need to resize images to desired sizes. We group [resize back + sinc filter] together# as one operation.# We consider two orders:#   1. [resize back + sinc filter] + JPEG compression#   2. JPEG compression + [resize back + sinc filter]# Empirically, we find other combinations (sinc + JPEG + Resize) will introduce twisted lines.if np.random.uniform() < 0.5:# resize back + the final sinc filtermode = random.choice(['area', 'bilinear', 'bicubic'])out = F.interpolate(out, size=(ori_h // self.opt['scale'], ori_w // self.opt['scale']), mode=mode)out = filter2D(out, self.sinc_kernel)# JPEG compressionjpeg_p = out.new_zeros(out.size(0)).uniform_(*self.opt['jpeg_range2'])out = torch.clamp(out, 0, 1)out = self.jpeger(out, quality=jpeg_p)else:# JPEG compressionjpeg_p = out.new_zeros(out.size(0)).uniform_(*self.opt['jpeg_range2'])out = torch.clamp(out, 0, 1)out = self.jpeger(out, quality=jpeg_p)# resize back + the final sinc filtermode = random.choice(['area', 'bilinear', 'bicubic'])out = F.interpolate(out, size=(ori_h // self.opt['scale'], ori_w // self.opt['scale']), mode=mode)out = filter2D(out, self.sinc_kernel)# clamp and roundself.lq = torch.clamp((out * 255.0).round(), 0, 255) / 255.# random cropgt_size = self.opt['gt_size']self.gt, self.lq = paired_random_crop(self.gt, self.lq, gt_size, self.opt['scale'])# training pair poolself._dequeue_and_enqueue()self.lq = self.lq.contiguous()  # for the warning: grad and param do not obey the gradient layout contractelse:# for paired training or validationself.lq = data['lq'].to(self.device)if 'gt' in data:self.gt = data['gt'].to(self.device)self.gt_usm = self.usm_sharpener(self.gt)

2.3 archs详情

arch下是具体的超分模型或者是gan超分模型中的生成器
在这里插入图片描述
任意打开一个文件,如ridnet_arch.py,可以发现关键代码如下, 只要在模型类上添加@ARCH_REGISTRY.register()修饰,即可注册为BasicSR库中的模型。然后,模型只要能正常forward即可。

import torch
import torch.nn as nnfrom basicsr.utils.registry import ARCH_REGISTRY
from .arch_util import ResidualBlockNoBN, make_layer
@ARCH_REGISTRY.register()
class RIDNet(nn.Module):def __init__(self,in_channels,mid_channels,out_channels,num_block=4,img_range=255.,rgb_mean=(0.4488, 0.4371, 0.4040),rgb_std=(1.0, 1.0, 1.0)):super(RIDNet, self).__init__()self.sub_mean = MeanShift(img_range, rgb_mean, rgb_std)self.add_mean = MeanShift(img_range, rgb_mean, rgb_std, 1)self.head = nn.Conv2d(in_channels, mid_channels, 3, 1, 1)self.body = make_layer(EAM, num_block, in_channels=mid_channels, mid_channels=mid_channels, out_channels=mid_channels)self.tail = nn.Conv2d(mid_channels, out_channels, 3, 1, 1)self.relu = nn.ReLU(inplace=True)def forward(self, x):res = self.sub_mean(x)res = self.tail(self.body(self.relu(self.head(res))))res = self.add_mean(res)out = x + resreturn out

在通过对basicsr\archs_init_.py进行分析,可以看到只会将以‘_arch.py’结尾的模型文件添加到库中。

import importlib
from copy import deepcopy
from os import path as ospfrom basicsr.utils import get_root_logger, scandir
from basicsr.utils.registry import ARCH_REGISTRY__all__ = ['build_network']# automatically scan and import arch modules for registry
# scan all the files under the 'archs' folder and collect files ending with '_arch.py'
arch_folder = osp.dirname(osp.abspath(__file__))
arch_filenames = [osp.splitext(osp.basename(v))[0] for v in scandir(arch_folder) if v.endswith('_arch.py')]
# import all the arch modules
_arch_modules = [importlib.import_module(f'basicsr.archs.{file_name}') for file_name in arch_filenames]def build_network(opt):opt = deepcopy(opt)network_type = opt.pop('type')net = ARCH_REGISTRY.get(network_type)(**opt)logger = get_root_logger()logger.info(f'Network [{net.__class__.__name__}] is created.')return net

2.4 losses详情

losses目录下的文件到比较少,但通过对__init__.py进行分析可以发现只会将以‘_loss.py’结尾的文件注册到系统中。

在这里插入图片描述
通过对basic_loss.py进行查看,只要5种loss。但需要注意的是,PerceptualLoss是感知损失,需要依赖vgg模型对y_true与y_pred进行推理然后计算中间层特征的差异。
在这里插入图片描述
WeightedTVLoss是一种不需要y_true的loss,其主要目的是使梯度信息最少,然后使超分后的数据更加平滑。其实现代码如下所示:

@LOSS_REGISTRY.register()
class WeightedTVLoss(L1Loss):def __init__(self, loss_weight=1.0, reduction='mean'):if reduction not in ['mean', 'sum']:raise ValueError(f'Unsupported reduction mode: {reduction}. Supported ones are: mean | sum')super(WeightedTVLoss, self).__init__(loss_weight=loss_weight, reduction=reduction)def forward(self, pred, weight=None):if weight is None:y_weight = Nonex_weight = Noneelse:y_weight = weight[:, :, :-1, :]x_weight = weight[:, :, :, :-1]y_diff = super().forward(pred[:, :, :-1, :], pred[:, :, 1:, :], weight=y_weight)x_diff = super().forward(pred[:, :, :, :-1], pred[:, :, :, 1:], weight=x_weight)loss = x_diff + y_diffreturn loss

CharbonnierLoss的核心是charbonnier_loss函数,其对charbonnier_loss输出的结果进行加权。可以看到charbonnier_loss与rmse loss类似,但有一个eps充当类似的正则化参数。

@weighted_loss
def charbonnier_loss(pred, target, eps=1e-12):return torch.sqrt((pred - target)**2 + eps)

2.5 dataloader详情

dataloader对应着basicsr\data目录下的文件,通过对__init__.py进行分析可以发现只会将以‘_dataset.py’结尾的文件注册到系统中。可以看到一共有7个dataset文件,表明其支持7种存储结构下的数据集。
在这里插入图片描述
FFHQDataset与SingleImageDataset 通过对代码中__getitem__函数分析,可以看到FFHQDataset与支持数据中只有一种图片。SingleImageDataset与FFHQDataset类似,也是支持只有一种图片的数据集。

@DATASET_REGISTRY.register()
class FFHQDataset(data.Dataset):def __getitem__(self, index):if self.file_client is None:self.file_client = FileClient(self.io_backend_opt.pop('type'), **self.io_backend_opt)# load gt imagegt_path = self.paths[index]# avoid errors caused by high latency in reading filesretry = 3while retry > 0:try:img_bytes = self.file_client.get(gt_path)except Exception as e:logger = get_root_logger()logger.warning(f'File client error: {e}, remaining retry times: {retry - 1}')# change another file to readindex = random.randint(0, self.__len__())gt_path = self.paths[index]time.sleep(1)  # sleep 1s for occasional server congestionelse:breakfinally:retry -= 1img_gt = imfrombytes(img_bytes, float32=True)# random horizontal flipimg_gt = augment(img_gt, hflip=self.opt['use_hflip'], rotation=False)# BGR to RGB, HWC to CHW, numpy to tensorimg_gt = img2tensor(img_gt, bgr2rgb=True, float32=True)# normalizenormalize(img_gt, self.mean, self.std, inplace=True)return {'gt': img_gt, 'gt_path': gt_path}

使用代码如下,可以看到支持一种ffhq_256.lmdb的数据结构。

datasets:train:name: FFHQtype: FFHQDatasetdataroot_gt: datasets/ffhq/ffhq_256.lmdbio_backend:type: lmdbuse_hflip: truemean: [0.5, 0.5, 0.5]std: [0.5, 0.5, 0.5]

PairedImageDataset 通过对代码中__getitem__函数分析,可以看到PairedImageDataset支持数据是需要gt_path与lq_path。使用代码如下,可以看到需要配置dataroot_gt与dataroot_lq;meta_info_file与filename_tmpl只是可选参数。

datasets:train:name: DIV2Ktype: PairedImageDatasetdataroot_gt: datasets/DF2K/DIV2K_train_HR_subdataroot_lq: datasets/DF2K/DIV2K_train_LR_bicubic_X4_submeta_info_file: basicsr/data/meta_info/meta_info_DIV2K800sub_GT.txt# (for lmdb)# dataroot_gt: datasets/DIV2K/DIV2K_train_HR_sub.lmdb# dataroot_lq: datasets/DIV2K/DIV2K_train_LR_bicubic_X4_sub.lmdbfilename_tmpl: '{}'io_backend:type: disk# (for lmdb)# type: lmdbgt_size: 128use_hflip: trueuse_rot: true# data loadernum_worker_per_gpu: 6batch_size_per_gpu: 16dataset_enlarge_ratio: 100prefetch_mode: ~val:name: Set5type: PairedImageDatasetdataroot_gt: datasets/Set5/GTmod12dataroot_lq: datasets/Set5/LRbicx4io_backend:type: disk

其返回数据结构为

{'lq': img_lq, 'gt': img_gt, 'lq_path': lq_path, 'gt_path': gt_path}

RealESRGANDataset 是用于RealESRGAN模型的数据加载器。格式数据的使用如下所示,可以看到需要设置各种数据在线下采样的参数配置。但最为关键的是dataroot_gt与meta_info,但从使用中可以看出在meta_info的txt中只是存储了图像的相对文件名。

# Each line in the meta_info describes the relative path to an imagewith open(self.opt['meta_info']) as fin:paths = [line.strip().split(' ')[0] for line in fin]self.paths = [os.path.join(self.gt_folder, v) for v in paths]
datasets:train:name: DF2K+OSTtype: RealESRGANDatasetdataroot_gt: datasets/DF2Kmeta_info: datasets/DF2K/meta_info/meta_info_DF2Kmultiscale+OST_sub.txtio_backend:type: diskblur_kernel_size: 21kernel_list: ['iso', 'aniso', 'generalized_iso', 'generalized_aniso', 'plateau_iso', 'plateau_aniso']kernel_prob: [0.45, 0.25, 0.12, 0.03, 0.12, 0.03]sinc_prob: 0.1blur_sigma: [0.2, 3]betag_range: [0.5, 4]betap_range: [1, 2]blur_kernel_size2: 21kernel_list2: ['iso', 'aniso', 'generalized_iso', 'generalized_aniso', 'plateau_iso', 'plateau_aniso']kernel_prob2: [0.45, 0.25, 0.12, 0.03, 0.12, 0.03]sinc_prob2: 0.1blur_sigma2: [0.2, 1.5]betag_range2: [0.5, 4]betap_range2: [1, 2]final_sinc_prob: 0.8gt_size: 256use_hflip: Trueuse_rot: False

其返回数据结构为

{'gt': img_gt, 'kernel1': kernel, 'kernel2': kernel2, 'sinc_kernel': sinc_kernel, 'gt_path': gt_path}

RealESRGANPairedDataset 也是用于RealESRGAN的数据加载器,但其参数结构、返回数据结构与PairedImageDataset是一样的。返回结构为:

{'lq': img_lq, 'gt': img_gt, 'lq_path': lq_path, 'gt_path': gt_path}

**VideoTestDataset ** 是针对图像序列(视频)的数据加载,其同样需要设置’dataroot_gt’,‘dataroot_lq’,每一个路径要求的数据格式为如下所示。

        dataroot├── subfolder1├── frame000├── frame001├── ...├── subfolder2├── frame000├── frame001├── ...├── ...

返回的数据格式为:

return {'lq': imgs_lq,  # (t, c, h, w)'gt': img_gt,  # (c, h, w)'folder': folder,  # folder name'idx': self.data_info['idx'][index],  # e.g., 0/99'border': border,  # 1 for border, 0 for non-border'lq_path': lq_path  # center frame}

REDSDatasetVimeo90KDataset 是特定数据的加载方法

2.6 metrics详情

metrics目录下的是评价指标,目前通__init__.py文件可以看到只支持’calculate_psnr’, ‘calculate_ssim’, ‘calculate_niqe’ 三种。其中niqe是一种无参考的评价指标,我们可以将自行将其他

对应配置文件中的使用代码如下:

metrics:psnr: # metric name, can be arbitrarytype: calculate_psnrcrop_border: 0test_y_channel: falsessim:type: calculate_ssimcrop_border: 0test_y_channel: falseniqe:type: calculate_niqecrop_border: 4num_thread: 8

2.7 options详情(配置文件)

options目录与核心代码目录平级,在项目根路径下。主要包含train与test两个分支,里面存储的是对应的使用配置(含模型结构、数据加载器配置、loss配置等)。
以options\train\RealESRGAN\train_realesrgan_x2plus.yml为例

# general settings
name: train_RealESRGANx2plus_400k_B12G4
model_type: RealESRGANModel  #指定模型结构
scale: 2
num_gpu: auto  # auto: can infer from your visible devices automatically. official: 4 GPUs  #指定gpu数量
manual_seed: 0# ----------------- options for synthesizing training data in RealESRGANModel ----------------- #
# USM the ground-truth
l1_gt_usm: True
percep_gt_usm: True
gan_gt_usm: False# the first degradation process
resize_prob: [0.2, 0.7, 0.1]  # up, down, keep
resize_range: [0.15, 1.5]
gaussian_noise_prob: 0.5
noise_range: [1, 30]
poisson_scale_range: [0.05, 3]
gray_noise_prob: 0.4
jpeg_range: [30, 95]# the second degradation process
second_blur_prob: 0.8
resize_prob2: [0.3, 0.4, 0.3]  # up, down, keep
resize_range2: [0.3, 1.2]
gaussian_noise_prob2: 0.5
noise_range2: [1, 25]
poisson_scale_range2: [0.05, 2.5]
gray_noise_prob2: 0.4
jpeg_range2: [30, 95]gt_size: 256
queue_size: 180# dataset and data loader settings
datasets:train:name: DF2K+OSTtype: RealESRGANDatasetdataroot_gt: datasets/DF2Kmeta_info: datasets/DF2K/meta_info/meta_info_DF2Kmultiscale+OST_sub.txtio_backend:type: diskblur_kernel_size: 21kernel_list: ['iso', 'aniso', 'generalized_iso', 'generalized_aniso', 'plateau_iso', 'plateau_aniso']kernel_prob: [0.45, 0.25, 0.12, 0.03, 0.12, 0.03]sinc_prob: 0.1blur_sigma: [0.2, 3]betag_range: [0.5, 4]betap_range: [1, 2]blur_kernel_size2: 21kernel_list2: ['iso', 'aniso', 'generalized_iso', 'generalized_aniso', 'plateau_iso', 'plateau_aniso']kernel_prob2: [0.45, 0.25, 0.12, 0.03, 0.12, 0.03]sinc_prob2: 0.1blur_sigma2: [0.2, 1.5]betag_range2: [0.5, 4]betap_range2: [1, 2]final_sinc_prob: 0.8gt_size: 256use_hflip: Trueuse_rot: False# data loadernum_worker_per_gpu: 5batch_size_per_gpu: 12dataset_enlarge_ratio: 1prefetch_mode: ~# Uncomment these for validation# val:#   name: validation#   type: PairedImageDataset#   dataroot_gt: path_to_gt#   dataroot_lq: path_to_lq#   io_backend:#     type: disk# network structures
network_g:type: RRDBNetnum_in_ch: 3num_out_ch: 3num_feat: 64num_block: 23num_grow_ch: 32scale: 2network_d:type: UNetDiscriminatorSNnum_in_ch: 3num_feat: 64skip_connection: True# path
path:# use the pre-trained Real-ESRNet modelpretrain_network_g: experiments/pretrained_models/RealESRNet_x2plus.pthparam_key_g: params_emastrict_load_g: trueresume_state: ~# training settings
train:ema_decay: 0.999optim_g:type: Adamlr: !!float 1e-4weight_decay: 0betas: [0.9, 0.99]optim_d:type: Adamlr: !!float 1e-4weight_decay: 0betas: [0.9, 0.99]scheduler:type: MultiStepLRmilestones: [400000]gamma: 0.5total_iter: 400000warmup_iter: -1  # no warm up# lossespixel_opt:type: L1Lossloss_weight: 1.0reduction: mean# perceptual loss (content and style losses)perceptual_opt:type: PerceptualLosslayer_weights:# before relu'conv1_2': 0.1'conv2_2': 0.1'conv3_4': 1'conv4_4': 1'conv5_4': 1vgg_type: vgg19use_input_norm: trueperceptual_weight: !!float 1.0style_weight: 0range_norm: falsecriterion: l1# gan lossgan_opt:type: GANLossgan_type: vanillareal_label_val: 1.0fake_label_val: 0.0loss_weight: !!float 1e-1net_d_iters: 1net_d_init_iters: 0# Uncomment these for validation
# validation settings
# val:
#   val_freq: !!float 5e3
#   save_img: True#   metrics:
#     psnr: # metric name
#       type: calculate_psnr
#       crop_border: 4
#       test_y_channel: false# logging settings
logger:print_freq: 100save_checkpoint_freq: !!float 5e3use_tb_logger: truewandb:project: ~resume_id: ~# dist training settings
dist_params:backend: ncclport: 29500

3、其他关键信息

3.1 训练与测试

训练命令如下,主要是-opt对应的yml文件,该文件即为2.7中对应的配置项

python basicsr/train.py -opt options/train/SRResNet_SRGAN/train_MSRResNet_x4.yml

完整的命令行参数如下:

  • -opt,配置文件的路径,一般采用这个命令配置训练或者测试的 yml 文件。
  • – laucher,用于指定 distibuted training 的,比如 pytorch 或者 slurm。默认是 none,
    即单卡非 distributed training。
  • – auto_resume,是否自动 resume,即自动查找最近的 checkpoint ,然后 resume。
  • – debug,能够快速帮助 debug。
  • – local_rank,这个不用管,是 distributed training 中程序自动会传入。
  • – force_yml,方便在命令行中修改 yml 中的配置文件。

测试命令如下:

python basicsr/test.py -opt options/test/SRResNet_SRGAN/test_MSRResNet_x4.yml

3.2 模型保存与训练状态恢复

训练的时候, checkpoints 会保存两个文件:

  1. 网 络 参 数 .pth 文 件 。 在 每 个 实 验 的 models 文 件 夹 中 , 文 件 名 诸
    如:net_g_5000.pth、net_g_10000.pth
  2. 包含 optimizer 和 scheduler 信息的 .state 文件。在每个实验的 training_states 文件夹中,
    文件名诸如:5000.state、10000.state
    根据这两个文件,就可以 resume 了。

对应的参数配置如下,对应pretrain_network_g与resume_state的配置

path:# use the pre-trained Real-ESRNet modelpretrain_network_g: experiments/pretrained_models/RealESRNet_x2plus.pthparam_key_g: params_emastrict_load_g: trueresume_state: True # 默认为 ~, 表示删除该参数

也可以在命令行中加入 ‘–auto_resume’,程序就会找到保存的最近的模型
参数和状态,并加载进来,接着训练啦。

分布式训练 单机多卡 8 GPU训练命令

CUDA_VISIBLE_DEVICES=0,1,2,3,4,5,6,7  python -m torch.distributed.launch --nproc_per_node=8 --master_port=4321 basicsr/train.py -opt options/train/EDVR/train_EDVR_M_x4_SR_REDS_woTSA.yml --launcher pytorch

单机多卡 4 GPU训练命令

CUDA_VISIBLE_DEVICES=0,1,2,3 python -m torch.distributed.launch --nproc_per_node=4 --master_port=4321 basicsr/train.py -opt options/train/EDVR/train_EDVR_M_x4_SR_REDS_woTSA.yml --launcher pytorch

Slurm训练 4 GPU训练命令

GLOG_vmodule=MemcachedClient=-1 
srun -p [partition] --mpi=pmi2 --job-name=EDVRMwoTSA --gres=gpu:4 --ntasks=4 --ntasks-per-node=4 --cpus-per-task=4 --kill-on-bad-exit=1 \
python -u basicsr/train.py -opt options/train/EDVR/train_EDVR_M_x4_SR_REDS_woTSA.yml --launcher="slurm"

3.3 模型EMA

EMA (Exponential Moving Average),指数移动平均。 它是用来“平均”一个变量在历史上的值。使用怎样的权重平均呢?如名字所示,随着时间,越是过往的时间,以一个指数衰减的权重来平均。

在 BasicSR 里面,EMA 一般作用在模型的参数上。它的效果一般是:
• 稳定训练效果。GAN 训练的结果一般瑕疵更少,视觉效果更好
• 对于以 PSNR 为目的的模型,其 PSNR 一般会更高一些

由于开启 EMA的代价几乎可以不计,所以我们推荐开启 EMA。
在这里插入图片描述

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.mzph.cn/bicheng/46776.shtml

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

【Memcached】Memcached的工作原理

目录 ​编辑 第2章&#xff1a;Memcached工作原理 2.1 数据存储与访问 2.2 分布式架构 2.3 数据过期机制 第2章&#xff1a;Memcached工作原理 2.1 数据存储与访问 Memcached是一种键值存储系统&#xff0c;其中数据以键值对的形式存储。键是用于定位数据的唯一标识符&am…

libyaml库的交叉编译

目录 1.Ubuntu环境中安装libyaml库 2.交叉编译 3.success 1.Ubuntu环境中安装libyaml库 官方地址&#xff1a;https://pyyaml.org/wiki/LibYAML 下载路径&#xff1a;http://pyyaml.org/download/libyaml/yaml-0.2.5.tar.gz 2.交叉编译 官方的下载路径为/usr/local下&am…

【unity实战】使用unity制作一个红点系统

前言 注意&#xff0c;本文是本人的学习笔记记录&#xff0c;这里先记录基本的代码&#xff0c;后面用到了再回来进行实现和整理 素材 https://assetstore.unity.com/packages/2d/gui/icons/2d-simple-ui-pack-218050 框架&#xff1a; RedPointSystem.cs using System.…

PHP全功能微信投票迷你平台系统小程序源码

&#x1f525;让决策变得超简单&#xff01;&#x1f389; &#x1f680;【一键创建&#xff0c;秒速启动】 嘿小伙伴们&#xff0c;你还在为组织投票而手忙脚乱吗&#xff1f;来试试这款全功能投票迷你微信小程序吧&#xff01;只需轻轻一点&#xff0c;无论是班级选举、社团…

【postgresql】pg_dump备份数据库

pg_dump 介绍 pg_dump 是一个用于备份 PostgreSQL 数据库的实用工具。它可以将数据库的内容导出为一个 SQL 脚本文件或其他格式的文件&#xff0c;以便在需要时进行恢复或迁移。 基本用法 pg_dump [选项] [数据库名] 命令选项 -h 或 --host&#xff1a;指定数据库服务器的主…

2024年大数据高频面试题(上篇)

文章目录 HDFS读流程和写流程HDFS读数据流程NameNode和Secondary NameNode工作机制FsimageEdits文件Seen_txidnamenode工作机制HA NameNode如何工作ZKFCHealthMonitorActiveStandbyElectorJouranlNode集群DataNode工作机制DataNode数据损坏压缩MapReduce工作流程MapTask工作流R…

Visual Studio远程调试工具

路径&#xff1a;Visual Studio安装路径/Common7/IDE/Remote Debugger/平台/msvsmon.exe。 平台有x86、x64&#xff0c;x64即可调试x86进程也可调试x64进程。 将平台路径下的所有文件拷贝至其他PC&#xff0c;运行msvsmon.exe。 工具栏选择“工具&#xff08;T&#xff09;”…

Ubuntu18.04安装ROS

1.添加ROS软件源 sudo sh -c echo "deb http://packages.ros.org/ros/ubuntu $(lsb_release -sc) main" > /etc/apt/sources.list.d/ros-latest.listcurl -s https://raw.githubusercontent.com/ros/rosdistro/master/ros.asc输入指令&#xff1a;curl -s https:…

Python中EMD的安装教程

第一步&#xff1a;首先安装两个包 pip install pyemd pip install EMD-signal第二步&#xff1a;然后&#xff0c;进行改名 安装完之后&#xff0c;找到包所在的位置&#xff0c;然后要将原来pyemd的文件夹名称改为PyEMD&#xff1a;

cleanshot Mac 上的截图工具

笔者闲来无事&#xff0c;最近在找一些mac上好用的工具其中一款就是cleanShot。为什么不用原有的mac自带的呢。因为相对来说编辑功能不算全面&#xff0c;不支持长截图。那有没有一款软件支持关于截图的好用工具呢。 所以笔者找了这款。安装包是直接安装就可使用的。请大家点赞…

Golang | Leetcode Golang题解之第237题删除链表中的节点

题目&#xff1a; 题解&#xff1a; func deleteNode(node *ListNode) {node.Val node.Next.Valnode.Next node.Next.Next }

Python中的UnboundLocalError是什么错误?如何解决?

如果代码报错UnboundLocalError, 大概率犯了以下错误&#xff1a; money 10000 # 当前存款def add_money(value):money valueif __name__ __main__:print(当前存款:, money)add_money(1000)print(当前存款:, money)其中&#xff0c;变量money表示当前存款&#xff1b;函数…

DialogFragment 开发手游sdk代替透明的activity

前言 各位同学大家好 有段时间没有给各位更新文章了,最近在写新的项目 之前的手游sdk 都是用透明的activity 效果有缺陷,现在我改成用这个dialogfragment 来实现 , 废话不多说我们正式开始 效果图 : 为什么要使用dialogfragment: 之前开发手游sdk的时候 我这边都是使用透…

【区块链 + 智慧政务】区块链 +ETC 下一代公路联网收费关键技术优化项目 | FISCO BCOS应用案例

2020 年&#xff0c;我国取消省界收费站项目完成后&#xff0c;随着收费模式与收费方式的变化&#xff0c;形成了以门架为计费单元的新收 费体系&#xff1a;按照车辆通行门架数&#xff0c;RSU 天线读取 ETC 卡、电子标签 OBU 或 CPC 卡内标识的车型信息&#xff0c;车型门架计…

ALlegro批量替换封装?

1&#xff0c;此种情况批量修改同名封装&#xff0c;即改前改后的封装名相同 2&#xff0c;首先将改好后的封装放于库路径下 3&#xff0c;place ----update symbols —package symbols ----选择修改的封装名 4&#xff0c;refresh 完成

开源PS2模拟器 PCSX2 2.0版发布 性能与功能全面升级

时隔多年之后&#xff0c;备受玩家喜爱的PS2模拟器PCSX2迎来了重大更新&#xff0c;2.0版本正式发布&#xff01;此次更新包含了大量改进&#xff0c;几乎涵盖了模拟器各个方面&#xff0c;为玩家带来更流畅、更便捷的游戏体验。 下载地址&#xff1a; https://pcsx2.net/ 界…

Hadoop-29 ZooKeeper集群 Watcher机制 工作原理 与 ZK基本命令 测试集群效果 3台公网云服务器

章节内容 上节我们完成了&#xff1a; ZNode的基本介绍ZNode节点类型的介绍事务ID的介绍ZNode实机测试效果 背景介绍 这里是三台公网云服务器&#xff0c;每台 2C4G&#xff0c;搭建一个Hadoop的学习环境&#xff0c;供我学习。 之前已经在 VM 虚拟机上搭建过一次&#xff…

品牌形象的智能塑造:Kompas.ai如何构建品牌视觉识别

品牌形象是企业在消费者心中构建的独特印象&#xff0c;它对于品牌识别和记忆度至关重要。一个一致且具有辨识度的品牌形象能够帮助企业在激烈的市场竞争中脱颖而出。Kompas.ai&#xff0c;作为一款智能设计工具&#xff0c;正帮助品牌塑造和维护其独特的视觉识别系统。 一致的…

JMeter进行HTTP接口测试的技术要点

参数化 用户定义的变量 用的时候 ${名字} 用户参数 在参数列表中传递 并且也是${} csv数据文件设置 false 不忽略首行 要首行 从第一行读取 true 忽略首行 从第二行开始 请求时的参数设置&#xff1a; 这里的名称是看其接口需要的请求参数的名称 这里的变量名称就是为csv里面…

帮助中心如何提高用户粘性和活跃度?

帮助中心&#xff08;Help Center&#xff09;是在产品网站或者产品内部设立的一个功能模块&#xff0c;用于将产品使用上遇到的问题&#xff0c;或者关于产品的所有问题进行汇总&#xff0c;并通过Q&A&#xff08;问题与解答&#xff09;的形式展现给用户&#xff0c;帮助…