项目地址:https://github.com/XPixelGroup/BasicSR
文档地址:https://github.com/XPixelGroup/BasicSR-docs/releases
BasicSR 是一个开源项目,旨在提供一个方便易用的图像、视频的超分、复原、增强的工具箱。BasicSR 代码库从2018年4月20日开始第一个提交,然后随着做研究、打比赛、发论文,逐渐发展与完善起来。它从最开始的针对超分辨率算法到后来拓展到其他更多复原增强相关的算法,
因此,BasicSR 中 SR 的涵义也从 Super-Resolution 延拓到 Super-Restoration。2022年5月9日,BasicSR 迎来新的里程碑,它加入到 XPixel 大家庭中,和更多的小伙伴们一起致力于把 BasicSR 建设得更好!
1、基本说明
1.1 支持的模型
在BasicSR项目中支持的模型如下所示,虽然数量不多,但是可以轻易的添加其他模型结构到BasicSR项目中。很多最新的图像超分论文也是基于BasicSR项目完成,但代码更新没有同步到BasicSR库中。
1.2 运行环境要求
Python 和 Python 库 (对于 Python 库,我们提供了相应的安装脚本):
a) Python >= 3.7 (推荐使用Anaconda或者Miniconda)
b) PyTorch >= 1.7:目前深度学习领域广泛使用的深度学习框架
1.3 项目安装
打开https://github.com/XPixelGroup/BasicSR?tab=readme-ov-file,下载项目,然后在命令行里执行: pip install -e .
在上图可以看到,特殊原因导致安装失败。参考:https://blog.csdn.net/weixin_46455141/article/details/131353266 ,执行命令,pip config set global.index-url https://mirrors.aliyun.com/pypi/simple
更换源,然后重新执行安装命令 pip install -e .
可以看到成功安装。
1.4 特殊算子支持
通过1.3方式安装的库不支持DCN(可变形卷积)、StyleGAN 中的特定的算子,比如:upfirdn2d, fused_act。安装时需要附加额外信息,支持可变形卷积。若无特殊需求,可以忽略。
需编译特殊算子的安装命令如下,可以看到是多了一个环境变量参数BASICSR_EXT=True
BASICSR_EXT=True pip install -e .
作者也提到,可以在执行代码时加入BASICSR_JIT=True参数,即时加加加载载载 (JIT) PyTorch C++ 编译算子
BASICSR_JIT=True python inference/inference_stylegan2.py
二者对比如下:
2、项目代码结构
2.1 基本结构
红色 表示和跑实验直接相关的文件,即我们平时打交道最多的文件;
蓝色 表示其他与 BasicSR 存在相关的代码文件;
通常只需要了解红色的部分即可。
basicsr目录下是该库的核心代码,其目录结构如下所示。可以看到关于模型有archs与models,archs才是模型网络结构与forward的定义。
2.2 models详情
models目录下详情如下,包含多种模型结构,主要以SRModel与SRRANModel为基类,最原始的基类是BaseModel。
快速过base_model.py,可以发现有函数model_ema,用于实现模型参数的指数更新。
对比SRModel与SRRANModel,可以发现SRRANModel多了一个net_d相关的参数(应该是鉴别器),对net_d进行检索,可以发现,SRRANModel,同样多了cri_gan(loss函数)的使用。与之对应,配置文件里当有相应的配置项。
self.optimizer_d.zero_grad()# realreal_d_pred = self.net_d(self.gt)l_d_real = self.cri_gan(real_d_pred, True, is_disc=True)loss_dict['l_d_real'] = l_d_realloss_dict['out_d_real'] = torch.mean(real_d_pred.detach())l_d_real.backward()# fakefake_d_pred = self.net_d(self.output.detach())l_d_fake = self.cri_gan(fake_d_pred, False, is_disc=True)loss_dict['l_d_fake'] = l_d_fakeloss_dict['out_d_fake'] = torch.mean(fake_d_pred.detach())l_d_fake.backward()self.optimizer_d.step()
在观察RealESRNetModel与RealESRGANModel,可以发现其在feed_data函数上与原始基类不一样,代码含量极大。具体如下所示,可以看出其是包含了在线下采样策略。对其入参data进行分析,可以看到多了kernel1
,kernel2
,sinc_kernel
等属性项,这表明这两类模型使用的dataload与原始基类模型不一样。
@torch.no_grad()def feed_data(self, data):"""Accept data from dataloader, and then add two-order degradations to obtain LQ images."""if self.is_train and self.opt.get('high_order_degradation', True):# training data synthesisself.gt = data['gt'].to(self.device)# USM sharpen the GT imagesif self.opt['gt_usm'] is True:self.gt = self.usm_sharpener(self.gt)self.kernel1 = data['kernel1'].to(self.device)self.kernel2 = data['kernel2'].to(self.device)self.sinc_kernel = data['sinc_kernel'].to(self.device)ori_h, ori_w = self.gt.size()[2:4]# ----------------------- The first degradation process ----------------------- ## blurout = filter2D(self.gt, self.kernel1)# random resizeupdown_type = random.choices(['up', 'down', 'keep'], self.opt['resize_prob'])[0]if updown_type == 'up':scale = np.random.uniform(1, self.opt['resize_range'][1])elif updown_type == 'down':scale = np.random.uniform(self.opt['resize_range'][0], 1)else:scale = 1mode = random.choice(['area', 'bilinear', 'bicubic'])out = F.interpolate(out, scale_factor=scale, mode=mode)# add noisegray_noise_prob = self.opt['gray_noise_prob']if np.random.uniform() < self.opt['gaussian_noise_prob']:out = random_add_gaussian_noise_pt(out, sigma_range=self.opt['noise_range'], clip=True, rounds=False, gray_prob=gray_noise_prob)else:out = random_add_poisson_noise_pt(out,scale_range=self.opt['poisson_scale_range'],gray_prob=gray_noise_prob,clip=True,rounds=False)# JPEG compressionjpeg_p = out.new_zeros(out.size(0)).uniform_(*self.opt['jpeg_range'])out = torch.clamp(out, 0, 1) # clamp to [0, 1], otherwise JPEGer will result in unpleasant artifactsout = self.jpeger(out, quality=jpeg_p)# ----------------------- The second degradation process ----------------------- ## blurif np.random.uniform() < self.opt['second_blur_prob']:out = filter2D(out, self.kernel2)# random resizeupdown_type = random.choices(['up', 'down', 'keep'], self.opt['resize_prob2'])[0]if updown_type == 'up':scale = np.random.uniform(1, self.opt['resize_range2'][1])elif updown_type == 'down':scale = np.random.uniform(self.opt['resize_range2'][0], 1)else:scale = 1mode = random.choice(['area', 'bilinear', 'bicubic'])out = F.interpolate(out, size=(int(ori_h / self.opt['scale'] * scale), int(ori_w / self.opt['scale'] * scale)), mode=mode)# add noisegray_noise_prob = self.opt['gray_noise_prob2']if np.random.uniform() < self.opt['gaussian_noise_prob2']:out = random_add_gaussian_noise_pt(out, sigma_range=self.opt['noise_range2'], clip=True, rounds=False, gray_prob=gray_noise_prob)else:out = random_add_poisson_noise_pt(out,scale_range=self.opt['poisson_scale_range2'],gray_prob=gray_noise_prob,clip=True,rounds=False)# JPEG compression + the final sinc filter# We also need to resize images to desired sizes. We group [resize back + sinc filter] together# as one operation.# We consider two orders:# 1. [resize back + sinc filter] + JPEG compression# 2. JPEG compression + [resize back + sinc filter]# Empirically, we find other combinations (sinc + JPEG + Resize) will introduce twisted lines.if np.random.uniform() < 0.5:# resize back + the final sinc filtermode = random.choice(['area', 'bilinear', 'bicubic'])out = F.interpolate(out, size=(ori_h // self.opt['scale'], ori_w // self.opt['scale']), mode=mode)out = filter2D(out, self.sinc_kernel)# JPEG compressionjpeg_p = out.new_zeros(out.size(0)).uniform_(*self.opt['jpeg_range2'])out = torch.clamp(out, 0, 1)out = self.jpeger(out, quality=jpeg_p)else:# JPEG compressionjpeg_p = out.new_zeros(out.size(0)).uniform_(*self.opt['jpeg_range2'])out = torch.clamp(out, 0, 1)out = self.jpeger(out, quality=jpeg_p)# resize back + the final sinc filtermode = random.choice(['area', 'bilinear', 'bicubic'])out = F.interpolate(out, size=(ori_h // self.opt['scale'], ori_w // self.opt['scale']), mode=mode)out = filter2D(out, self.sinc_kernel)# clamp and roundself.lq = torch.clamp((out * 255.0).round(), 0, 255) / 255.# random cropgt_size = self.opt['gt_size']self.gt, self.lq = paired_random_crop(self.gt, self.lq, gt_size, self.opt['scale'])# training pair poolself._dequeue_and_enqueue()self.lq = self.lq.contiguous() # for the warning: grad and param do not obey the gradient layout contractelse:# for paired training or validationself.lq = data['lq'].to(self.device)if 'gt' in data:self.gt = data['gt'].to(self.device)self.gt_usm = self.usm_sharpener(self.gt)
2.3 archs详情
arch下是具体的超分模型或者是gan超分模型中的生成器
任意打开一个文件,如ridnet_arch.py,可以发现关键代码如下, 只要在模型类上添加@ARCH_REGISTRY.register()
修饰,即可注册为BasicSR库中的模型。然后,模型只要能正常forward即可。
import torch
import torch.nn as nnfrom basicsr.utils.registry import ARCH_REGISTRY
from .arch_util import ResidualBlockNoBN, make_layer
@ARCH_REGISTRY.register()
class RIDNet(nn.Module):def __init__(self,in_channels,mid_channels,out_channels,num_block=4,img_range=255.,rgb_mean=(0.4488, 0.4371, 0.4040),rgb_std=(1.0, 1.0, 1.0)):super(RIDNet, self).__init__()self.sub_mean = MeanShift(img_range, rgb_mean, rgb_std)self.add_mean = MeanShift(img_range, rgb_mean, rgb_std, 1)self.head = nn.Conv2d(in_channels, mid_channels, 3, 1, 1)self.body = make_layer(EAM, num_block, in_channels=mid_channels, mid_channels=mid_channels, out_channels=mid_channels)self.tail = nn.Conv2d(mid_channels, out_channels, 3, 1, 1)self.relu = nn.ReLU(inplace=True)def forward(self, x):res = self.sub_mean(x)res = self.tail(self.body(self.relu(self.head(res))))res = self.add_mean(res)out = x + resreturn out
在通过对basicsr\archs_init_.py进行分析,可以看到只会将以‘_arch.py’结尾的模型文件添加到库中。
import importlib
from copy import deepcopy
from os import path as ospfrom basicsr.utils import get_root_logger, scandir
from basicsr.utils.registry import ARCH_REGISTRY__all__ = ['build_network']# automatically scan and import arch modules for registry
# scan all the files under the 'archs' folder and collect files ending with '_arch.py'
arch_folder = osp.dirname(osp.abspath(__file__))
arch_filenames = [osp.splitext(osp.basename(v))[0] for v in scandir(arch_folder) if v.endswith('_arch.py')]
# import all the arch modules
_arch_modules = [importlib.import_module(f'basicsr.archs.{file_name}') for file_name in arch_filenames]def build_network(opt):opt = deepcopy(opt)network_type = opt.pop('type')net = ARCH_REGISTRY.get(network_type)(**opt)logger = get_root_logger()logger.info(f'Network [{net.__class__.__name__}] is created.')return net
2.4 losses详情
losses目录下的文件到比较少,但通过对__init__.py进行分析可以发现只会将以‘_loss.py’结尾的文件注册到系统中。
通过对basic_loss.py进行查看,只要5种loss。但需要注意的是,PerceptualLoss是感知损失,需要依赖vgg模型对y_true与y_pred进行推理然后计算中间层特征的差异。
WeightedTVLoss是一种不需要y_true的loss,其主要目的是使梯度信息最少,然后使超分后的数据更加平滑。其实现代码如下所示:
@LOSS_REGISTRY.register()
class WeightedTVLoss(L1Loss):def __init__(self, loss_weight=1.0, reduction='mean'):if reduction not in ['mean', 'sum']:raise ValueError(f'Unsupported reduction mode: {reduction}. Supported ones are: mean | sum')super(WeightedTVLoss, self).__init__(loss_weight=loss_weight, reduction=reduction)def forward(self, pred, weight=None):if weight is None:y_weight = Nonex_weight = Noneelse:y_weight = weight[:, :, :-1, :]x_weight = weight[:, :, :, :-1]y_diff = super().forward(pred[:, :, :-1, :], pred[:, :, 1:, :], weight=y_weight)x_diff = super().forward(pred[:, :, :, :-1], pred[:, :, :, 1:], weight=x_weight)loss = x_diff + y_diffreturn loss
CharbonnierLoss的核心是charbonnier_loss函数,其对charbonnier_loss输出的结果进行加权。可以看到charbonnier_loss与rmse loss类似,但有一个eps充当类似的正则化参数。
@weighted_loss
def charbonnier_loss(pred, target, eps=1e-12):return torch.sqrt((pred - target)**2 + eps)
2.5 dataloader详情
dataloader对应着basicsr\data目录下的文件,通过对__init__.py进行分析可以发现只会将以‘_dataset.py’结尾的文件注册到系统中。可以看到一共有7个dataset文件,表明其支持7种存储结构下的数据集。
FFHQDataset与SingleImageDataset 通过对代码中__getitem__函数分析,可以看到FFHQDataset与支持数据中只有一种图片。SingleImageDataset与FFHQDataset类似,也是支持只有一种图片的数据集。
@DATASET_REGISTRY.register()
class FFHQDataset(data.Dataset):def __getitem__(self, index):if self.file_client is None:self.file_client = FileClient(self.io_backend_opt.pop('type'), **self.io_backend_opt)# load gt imagegt_path = self.paths[index]# avoid errors caused by high latency in reading filesretry = 3while retry > 0:try:img_bytes = self.file_client.get(gt_path)except Exception as e:logger = get_root_logger()logger.warning(f'File client error: {e}, remaining retry times: {retry - 1}')# change another file to readindex = random.randint(0, self.__len__())gt_path = self.paths[index]time.sleep(1) # sleep 1s for occasional server congestionelse:breakfinally:retry -= 1img_gt = imfrombytes(img_bytes, float32=True)# random horizontal flipimg_gt = augment(img_gt, hflip=self.opt['use_hflip'], rotation=False)# BGR to RGB, HWC to CHW, numpy to tensorimg_gt = img2tensor(img_gt, bgr2rgb=True, float32=True)# normalizenormalize(img_gt, self.mean, self.std, inplace=True)return {'gt': img_gt, 'gt_path': gt_path}
使用代码如下,可以看到支持一种ffhq_256.lmdb的数据结构。
datasets:train:name: FFHQtype: FFHQDatasetdataroot_gt: datasets/ffhq/ffhq_256.lmdbio_backend:type: lmdbuse_hflip: truemean: [0.5, 0.5, 0.5]std: [0.5, 0.5, 0.5]
PairedImageDataset 通过对代码中__getitem__函数分析,可以看到PairedImageDataset支持数据是需要gt_path与lq_path。使用代码如下,可以看到需要配置dataroot_gt与dataroot_lq;meta_info_file与filename_tmpl只是可选参数。
datasets:train:name: DIV2Ktype: PairedImageDatasetdataroot_gt: datasets/DF2K/DIV2K_train_HR_subdataroot_lq: datasets/DF2K/DIV2K_train_LR_bicubic_X4_submeta_info_file: basicsr/data/meta_info/meta_info_DIV2K800sub_GT.txt# (for lmdb)# dataroot_gt: datasets/DIV2K/DIV2K_train_HR_sub.lmdb# dataroot_lq: datasets/DIV2K/DIV2K_train_LR_bicubic_X4_sub.lmdbfilename_tmpl: '{}'io_backend:type: disk# (for lmdb)# type: lmdbgt_size: 128use_hflip: trueuse_rot: true# data loadernum_worker_per_gpu: 6batch_size_per_gpu: 16dataset_enlarge_ratio: 100prefetch_mode: ~val:name: Set5type: PairedImageDatasetdataroot_gt: datasets/Set5/GTmod12dataroot_lq: datasets/Set5/LRbicx4io_backend:type: disk
其返回数据结构为
{'lq': img_lq, 'gt': img_gt, 'lq_path': lq_path, 'gt_path': gt_path}
RealESRGANDataset 是用于RealESRGAN模型的数据加载器。格式数据的使用如下所示,可以看到需要设置各种数据在线下采样的参数配置。但最为关键的是dataroot_gt与meta_info,但从使用中可以看出在meta_info的txt中只是存储了图像的相对文件名。
# Each line in the meta_info describes the relative path to an imagewith open(self.opt['meta_info']) as fin:paths = [line.strip().split(' ')[0] for line in fin]self.paths = [os.path.join(self.gt_folder, v) for v in paths]
datasets:train:name: DF2K+OSTtype: RealESRGANDatasetdataroot_gt: datasets/DF2Kmeta_info: datasets/DF2K/meta_info/meta_info_DF2Kmultiscale+OST_sub.txtio_backend:type: diskblur_kernel_size: 21kernel_list: ['iso', 'aniso', 'generalized_iso', 'generalized_aniso', 'plateau_iso', 'plateau_aniso']kernel_prob: [0.45, 0.25, 0.12, 0.03, 0.12, 0.03]sinc_prob: 0.1blur_sigma: [0.2, 3]betag_range: [0.5, 4]betap_range: [1, 2]blur_kernel_size2: 21kernel_list2: ['iso', 'aniso', 'generalized_iso', 'generalized_aniso', 'plateau_iso', 'plateau_aniso']kernel_prob2: [0.45, 0.25, 0.12, 0.03, 0.12, 0.03]sinc_prob2: 0.1blur_sigma2: [0.2, 1.5]betag_range2: [0.5, 4]betap_range2: [1, 2]final_sinc_prob: 0.8gt_size: 256use_hflip: Trueuse_rot: False
其返回数据结构为
{'gt': img_gt, 'kernel1': kernel, 'kernel2': kernel2, 'sinc_kernel': sinc_kernel, 'gt_path': gt_path}
RealESRGANPairedDataset 也是用于RealESRGAN的数据加载器,但其参数结构、返回数据结构与PairedImageDataset是一样的。返回结构为:
{'lq': img_lq, 'gt': img_gt, 'lq_path': lq_path, 'gt_path': gt_path}
**VideoTestDataset ** 是针对图像序列(视频)的数据加载,其同样需要设置’dataroot_gt’,‘dataroot_lq’,每一个路径要求的数据格式为如下所示。
dataroot├── subfolder1├── frame000├── frame001├── ...├── subfolder2├── frame000├── frame001├── ...├── ...
返回的数据格式为:
return {'lq': imgs_lq, # (t, c, h, w)'gt': img_gt, # (c, h, w)'folder': folder, # folder name'idx': self.data_info['idx'][index], # e.g., 0/99'border': border, # 1 for border, 0 for non-border'lq_path': lq_path # center frame}
REDSDataset与Vimeo90KDataset 是特定数据的加载方法
2.6 metrics详情
metrics目录下的是评价指标,目前通__init__.py文件可以看到只支持’calculate_psnr’, ‘calculate_ssim’, ‘calculate_niqe’ 三种。其中niqe是一种无参考的评价指标,我们可以将自行将其他
对应配置文件中的使用代码如下:
metrics:psnr: # metric name, can be arbitrarytype: calculate_psnrcrop_border: 0test_y_channel: falsessim:type: calculate_ssimcrop_border: 0test_y_channel: falseniqe:type: calculate_niqecrop_border: 4num_thread: 8
2.7 options详情(配置文件)
options目录与核心代码目录平级,在项目根路径下。主要包含train与test两个分支,里面存储的是对应的使用配置(含模型结构、数据加载器配置、loss配置等)。
以options\train\RealESRGAN\train_realesrgan_x2plus.yml为例
# general settings
name: train_RealESRGANx2plus_400k_B12G4
model_type: RealESRGANModel #指定模型结构
scale: 2
num_gpu: auto # auto: can infer from your visible devices automatically. official: 4 GPUs #指定gpu数量
manual_seed: 0# ----------------- options for synthesizing training data in RealESRGANModel ----------------- #
# USM the ground-truth
l1_gt_usm: True
percep_gt_usm: True
gan_gt_usm: False# the first degradation process
resize_prob: [0.2, 0.7, 0.1] # up, down, keep
resize_range: [0.15, 1.5]
gaussian_noise_prob: 0.5
noise_range: [1, 30]
poisson_scale_range: [0.05, 3]
gray_noise_prob: 0.4
jpeg_range: [30, 95]# the second degradation process
second_blur_prob: 0.8
resize_prob2: [0.3, 0.4, 0.3] # up, down, keep
resize_range2: [0.3, 1.2]
gaussian_noise_prob2: 0.5
noise_range2: [1, 25]
poisson_scale_range2: [0.05, 2.5]
gray_noise_prob2: 0.4
jpeg_range2: [30, 95]gt_size: 256
queue_size: 180# dataset and data loader settings
datasets:train:name: DF2K+OSTtype: RealESRGANDatasetdataroot_gt: datasets/DF2Kmeta_info: datasets/DF2K/meta_info/meta_info_DF2Kmultiscale+OST_sub.txtio_backend:type: diskblur_kernel_size: 21kernel_list: ['iso', 'aniso', 'generalized_iso', 'generalized_aniso', 'plateau_iso', 'plateau_aniso']kernel_prob: [0.45, 0.25, 0.12, 0.03, 0.12, 0.03]sinc_prob: 0.1blur_sigma: [0.2, 3]betag_range: [0.5, 4]betap_range: [1, 2]blur_kernel_size2: 21kernel_list2: ['iso', 'aniso', 'generalized_iso', 'generalized_aniso', 'plateau_iso', 'plateau_aniso']kernel_prob2: [0.45, 0.25, 0.12, 0.03, 0.12, 0.03]sinc_prob2: 0.1blur_sigma2: [0.2, 1.5]betag_range2: [0.5, 4]betap_range2: [1, 2]final_sinc_prob: 0.8gt_size: 256use_hflip: Trueuse_rot: False# data loadernum_worker_per_gpu: 5batch_size_per_gpu: 12dataset_enlarge_ratio: 1prefetch_mode: ~# Uncomment these for validation# val:# name: validation# type: PairedImageDataset# dataroot_gt: path_to_gt# dataroot_lq: path_to_lq# io_backend:# type: disk# network structures
network_g:type: RRDBNetnum_in_ch: 3num_out_ch: 3num_feat: 64num_block: 23num_grow_ch: 32scale: 2network_d:type: UNetDiscriminatorSNnum_in_ch: 3num_feat: 64skip_connection: True# path
path:# use the pre-trained Real-ESRNet modelpretrain_network_g: experiments/pretrained_models/RealESRNet_x2plus.pthparam_key_g: params_emastrict_load_g: trueresume_state: ~# training settings
train:ema_decay: 0.999optim_g:type: Adamlr: !!float 1e-4weight_decay: 0betas: [0.9, 0.99]optim_d:type: Adamlr: !!float 1e-4weight_decay: 0betas: [0.9, 0.99]scheduler:type: MultiStepLRmilestones: [400000]gamma: 0.5total_iter: 400000warmup_iter: -1 # no warm up# lossespixel_opt:type: L1Lossloss_weight: 1.0reduction: mean# perceptual loss (content and style losses)perceptual_opt:type: PerceptualLosslayer_weights:# before relu'conv1_2': 0.1'conv2_2': 0.1'conv3_4': 1'conv4_4': 1'conv5_4': 1vgg_type: vgg19use_input_norm: trueperceptual_weight: !!float 1.0style_weight: 0range_norm: falsecriterion: l1# gan lossgan_opt:type: GANLossgan_type: vanillareal_label_val: 1.0fake_label_val: 0.0loss_weight: !!float 1e-1net_d_iters: 1net_d_init_iters: 0# Uncomment these for validation
# validation settings
# val:
# val_freq: !!float 5e3
# save_img: True# metrics:
# psnr: # metric name
# type: calculate_psnr
# crop_border: 4
# test_y_channel: false# logging settings
logger:print_freq: 100save_checkpoint_freq: !!float 5e3use_tb_logger: truewandb:project: ~resume_id: ~# dist training settings
dist_params:backend: ncclport: 29500
3、其他关键信息
3.1 训练与测试
训练命令如下,主要是-opt对应的yml文件,该文件即为2.7中对应的配置项
python basicsr/train.py -opt options/train/SRResNet_SRGAN/train_MSRResNet_x4.yml
完整的命令行参数如下:
- -opt,配置文件的路径,一般采用这个命令配置训练或者测试的 yml 文件。
- – laucher,用于指定 distibuted training 的,比如 pytorch 或者 slurm。默认是 none,
即单卡非 distributed training。 - – auto_resume,是否自动 resume,即自动查找最近的 checkpoint ,然后 resume。
- – debug,能够快速帮助 debug。
- – local_rank,这个不用管,是 distributed training 中程序自动会传入。
- – force_yml,方便在命令行中修改 yml 中的配置文件。
测试命令如下:
python basicsr/test.py -opt options/test/SRResNet_SRGAN/test_MSRResNet_x4.yml
3.2 模型保存与训练状态恢复
训练的时候, checkpoints 会保存两个文件:
- 网 络 参 数 .pth 文 件 。 在 每 个 实 验 的 models 文 件 夹 中 , 文 件 名 诸
如:net_g_5000.pth、net_g_10000.pth - 包含 optimizer 和 scheduler 信息的 .state 文件。在每个实验的 training_states 文件夹中,
文件名诸如:5000.state、10000.state
根据这两个文件,就可以 resume 了。
对应的参数配置如下,对应pretrain_network_g与resume_state的配置
path:# use the pre-trained Real-ESRNet modelpretrain_network_g: experiments/pretrained_models/RealESRNet_x2plus.pthparam_key_g: params_emastrict_load_g: trueresume_state: True # 默认为 ~, 表示删除该参数
也可以在命令行中加入 ‘–auto_resume’,程序就会找到保存的最近的模型
参数和状态,并加载进来,接着训练啦。
分布式训练 单机多卡 8 GPU训练命令
CUDA_VISIBLE_DEVICES=0,1,2,3,4,5,6,7 python -m torch.distributed.launch --nproc_per_node=8 --master_port=4321 basicsr/train.py -opt options/train/EDVR/train_EDVR_M_x4_SR_REDS_woTSA.yml --launcher pytorch
单机多卡 4 GPU训练命令
CUDA_VISIBLE_DEVICES=0,1,2,3 python -m torch.distributed.launch --nproc_per_node=4 --master_port=4321 basicsr/train.py -opt options/train/EDVR/train_EDVR_M_x4_SR_REDS_woTSA.yml --launcher pytorch
Slurm训练 4 GPU训练命令
GLOG_vmodule=MemcachedClient=-1
srun -p [partition] --mpi=pmi2 --job-name=EDVRMwoTSA --gres=gpu:4 --ntasks=4 --ntasks-per-node=4 --cpus-per-task=4 --kill-on-bad-exit=1 \
python -u basicsr/train.py -opt options/train/EDVR/train_EDVR_M_x4_SR_REDS_woTSA.yml --launcher="slurm"
3.3 模型EMA
EMA (Exponential Moving Average),指数移动平均。 它是用来“平均”一个变量在历史上的值。使用怎样的权重平均呢?如名字所示,随着时间,越是过往的时间,以一个指数衰减的权重来平均。
在 BasicSR 里面,EMA 一般作用在模型的参数上。它的效果一般是:
• 稳定训练效果。GAN 训练的结果一般瑕疵更少,视觉效果更好
• 对于以 PSNR 为目的的模型,其 PSNR 一般会更高一些
由于开启 EMA的代价几乎可以不计,所以我们推荐开启 EMA。