Training an RCAN QAT Super-Resolution Model with PyTorch
- Version Information
- Test Steps
  - Prepare the Dataset
  - Create the Container
  - Generate the File Lists
    - Code to Create the File Lists
    - Run the Script to Generate the File Lists
  - Train the RCAN Model
    - Preparation
    - Modify the Open-Source Code
    - Write the Training Code
    - Run the Training Script
    - Visualization
This article uses the RCAN super-resolution model as an example to walk through the QAT training process. The steps are as follows (a sketch of the overall flow is given right after this list):
- First train an FP32 model
- Then load the FP32-trained weights and continue with QAT training
- Stop training when the loss has not decreased for 5 consecutive epochs
- To speed up the demo, also stop training once the PSNR exceeds 33.0
- Use TensorBoard to monitor the loss curve
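The sketch below outlines this flow with PyTorch's FX-graph-mode quantization API. It assumes PyTorch >= 1.13 (where prepare_qat_fx takes a QConfigMapping and example inputs); run_qat, train_one_epoch and evaluate are hypothetical placeholders, not functions from the actual train.py shown later.

```python
# Hypothetical outline of the QAT flow described above (not the article's train.py).
import copy
import torch
from torch.ao.quantization import get_default_qat_qconfig_mapping
from torch.quantization.quantize_fx import prepare_qat_fx, convert_fx

def run_qat(fp32_model, train_one_epoch, evaluate,
            max_epochs=100, patience=5, psnr_target=33.0):
    # Insert fake-quant/observer modules into a copy initialized with the FP32 weights
    qconfig_mapping = get_default_qat_qconfig_mapping("fbgemm")
    example_inputs = (torch.randn(1, 3, 48, 48),)          # an LR input patch
    model = prepare_qat_fx(copy.deepcopy(fp32_model).train(),
                           qconfig_mapping, example_inputs)

    best_loss, stall = float("inf"), 0
    for epoch in range(max_epochs):
        loss = train_one_epoch(model)                      # one QAT fine-tuning epoch
        psnr = evaluate(model)
        stall = 0 if loss < best_loss else stall + 1       # count non-improving epochs
        best_loss = min(best_loss, loss)
        if stall >= patience or psnr > psnr_target:        # early-stop rules from the list above
            break

    return convert_fx(model.eval())                        # fold fake-quant into an int8 model
```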
Version Information
| Attribute | Value |
|---|---|
| Training environment | Setup steps |
| GPU model | NVIDIA GeForce RTX 3080 12GB |
| Dataset download links | http://data.vision.ee.ethz.ch/cvl/DIV2K/DIV2K_train_HR.zip <br> http://data.vision.ee.ethz.ch/cvl/DIV2K/DIV2K_valid_HR.zip <br> http://data.vision.ee.ethz.ch/cvl/DIV2K/DIV2K_train_LR_bicubic_X2.zip <br> http://data.vision.ee.ethz.ch/cvl/DIV2K/DIV2K_valid_LR_bicubic_X2.zip |
| Open-source model code | https://raw.githubusercontent.com/RussellEven/Multi-frame-RCAN/master/code/model/rcan.py <br> https://raw.githubusercontent.com/RussellEven/Multi-frame-RCAN/master/code/option.py <br> https://raw.githubusercontent.com/RussellEven/Multi-frame-RCAN/master/code/model/common.py <br> https://raw.githubusercontent.com/RussellEven/Multi-frame-RCAN/master/code/template.py |
Test Steps
Prepare the Dataset
wget http://data.vision.ee.ethz.ch/cvl/DIV2K/DIV2K_train_HR.zip
wget http://data.vision.ee.ethz.ch/cvl/DIV2K/DIV2K_valid_HR.zip
wget http://data.vision.ee.ethz.ch/cvl/DIV2K/DIV2K_train_LR_bicubic_X2.zip
wget http://data.vision.ee.ethz.ch/cvl/DIV2K/DIV2K_valid_LR_bicubic_X2.zip
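Before unzipping, it can be worth confirming the downloads are intact. The snippet below is an illustrative check using Python's standard zipfile module (not part of the original article); the archive names match the wget commands above.

```python
# check_zips.py -- illustrative integrity check for the downloaded archives
import zipfile

archives = [
    'DIV2K_train_HR.zip',
    'DIV2K_valid_HR.zip',
    'DIV2K_train_LR_bicubic_X2.zip',
    'DIV2K_valid_LR_bicubic_X2.zip',
]

for name in archives:
    with zipfile.ZipFile(name) as zf:
        bad = zf.testzip()  # returns the first corrupt member, or None if all are fine
        print(name, 'OK' if bad is None else f'corrupt member: {bad}')
```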
Create the Container
Build the image by following the steps in https://editor.csdn.net/md/?articleId=136176989
docker stop rcan_dev
docker rm rcan_dev
nvidia-docker run -ti -e NVIDIA_VISIBLE_DEVICES=all --privileged \
    --net=host -p 6006:6006 -v $PWD:/home -w /home \
    -v /mnt/disk/RCAN/:/RCAN --name rcan_dev cuda_dev_image:v1.0 /bin/bash
conda activate ai_dev
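Once inside the container and the ai_dev environment, a quick check that PyTorch sees the GPU can save debugging time later. This is a generic sanity check, not part of the original article.

```python
# gpu_check.py -- confirm that PyTorch sees the RTX 3080 inside the container
import torch

print('torch version :', torch.__version__)
print('CUDA available:', torch.cuda.is_available())
if torch.cuda.is_available():
    print('device        :', torch.cuda.get_device_name(0))
```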
Generate the File Lists
Code to Create the File Lists
# generate_datalist.py
import os
import cv2
import numpy as np
from PIL import Image
from tqdm import tqdm

# Extracted DIV2K directories
train_HR_path = './DIV2K_train_HR'
train_LR_path = './DIV2K_train_LR_bicubic/X2'
valid_HR_path = './DIV2K_valid_HR'
valid_LR_path = './DIV2K_valid_LR_bicubic/X2'

# Output file lists: one "LR_path HR_path" pair per line
train_file = 'datalist_div2k_train.txt'
valid_file = 'datalist_div2k_valid.txt'


def get_images(input_path, format='png'):
    # Sorted base names (without extension) of all images in a folder
    names = [os.path.splitext(fname)[0]
             for fname in os.listdir(input_path)
             if fname.endswith(format)]
    names.sort()
    return names


def get_folders(input_path):
    # Sorted names of all sub-directories in a folder
    names = [directory for directory in os.listdir(input_path)
             if os.path.isdir(os.path.join(input_path, directory))]
    names.sort()
    return names


the_train_file = open(train_file, 'w')
image_names = get_images(train_HR_path)
for image_name in image_names:
    the_train_file.write('DIV2K_train_LR_bicubic/X2/' + image_name + 'x2.png' + ' ' + 'DIV2K_train_HR/' + image_name + '.png' + '\n')
the_train_file.close()

the_valid_file = open(valid_file, 'w')
image_names = get_images(valid_HR_path)
for image_name in image_names:
    the_valid_file.write('DIV2K_valid_LR_bicubic/X2/' + image_name + 'x2.png' + ' ' + 'DIV2K_valid_HR/' + image_name + '.png' + '\n')
the_valid_file.close()
Run the Script to Generate the File Lists
cd /RCAN/
unzip DIV2K_train_HR.zip
unzip DIV2K_valid_HR.zip
unzip DIV2K_train_LR_bicubic_X2.zip
unzip DIV2K_valid_LR_bicubic_X2.zip
python generate_datalist.py
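Each line of the generated lists pairs an LR image path with its HR counterpart (DIV2K images are named 0001.png, 0002.png, and so on). The quick check below is illustrative and not part of the original article.

```python
# Inspect the generated training list
with open('datalist_div2k_train.txt') as f:
    pairs = [line.split() for line in f if line.strip()]

print(len(pairs), 'training pairs')  # DIV2K provides 800 training images
print(pairs[0])                      # ['DIV2K_train_LR_bicubic/X2/0001x2.png', 'DIV2K_train_HR/0001.png']
```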
Train the RCAN Model
Preparation
# Install dependencies
pip install tensorboard -i https://pypi.tuna.tsinghua.edu.cn/simple
pip install scikit-image -i https://pypi.tuna.tsinghua.edu.cn/simple

# Set environment variables
export PROTOCOL_BUFFERS_PYTHON_IMPLEMENTATION=python

# Download the open-source model source code
cd /RCAN/
mkdir model
curl -L -o model/rcan.py https://raw.githubusercontent.com/RussellEven/Multi-frame-RCAN/master/code/model/rcan.py
curl -L -o model/option.py https://raw.githubusercontent.com/RussellEven/Multi-frame-RCAN/master/code/option.py
curl -L -o model/common.py https://raw.githubusercontent.com/RussellEven/Multi-frame-RCAN/master/code/model/common.py
curl -L -o template.py https://raw.githubusercontent.com/RussellEven/Multi-frame-RCAN/master/code/template.py
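The container maps port 6006 and tensorboard was installed above, so the loss curve mentioned in the introduction can be logged with torch.utils.tensorboard. Below is a minimal sketch; the writer name, log directory and dummy loss values are placeholders, not taken from the article's train.py.

```python
# tensorboard_sketch.py -- minimal loss logging, assuming log dir "runs/rcan_qat"
from torch.utils.tensorboard import SummaryWriter

writer = SummaryWriter(log_dir='runs/rcan_qat')
for step, loss_value in enumerate([0.12, 0.10, 0.09]):   # dummy values
    writer.add_scalar('train/loss', loss_value, step)
writer.close()
# View the curves with:  tensorboard --logdir runs --port 6006 --bind_all
```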
Modify the Open-Source Code
- model/rcan.py
- model/common.py
Write the Training Code
# train.py
import os
import torch
import torch.nn as nn
import torch.optim as optim
import json
import copy
import time
from torch.utils.data import DataLoader
from tqdm import tqdm
from torch.quantization.quantize_fx import prepare_qat_fx, convert_fx
from torch.ao.quantization import qconfig
from torch.ao.quantization.fake_quantize import *
from torch.ao.quantization.observer import *
from torch.utils import tensorboard
from torch.autograd import Variable
from torch.utils.data import Dataset
from skimage.color import rgb2hsv, hsv2rgb
import imageio
import random
import numpy as np


def _apply(func, x):
    # Apply func recursively to every element of a (possibly nested) list/tuple/dict
    if isinstance(x, (list, tuple)):
        return [_apply(func, x_i) for x_i in x]
    elif isinstance(x, dict):
        y = {}
        for key, value in x.items():
            y[key] = _apply(func, value)
        return y
    else:
        return func(x)


def get_patch(*args, patch_size=96, scale=2, input_large=False):
    # Crop matching random patches from the LR/HR image pair
    ih, iw = args[0].shape[:2]
    if not input_large:
        p = scale
        tp = p * patch_size
        ip = tp // scale
    else:
        tp = patch_size
        ip = patch_size
    ix = random.randrange(0