Pytorch训练RCAN QAT超分模型

Pytorch训练RCAN QAT超分模型

  • 版本信息
  • 测试步骤
    • 准备数据集
    • 创建容器
    • 生成文件列表
      • 创建文件列表的代码
      • 执行脚本,生成文件列表
    • 训练RCAN模型
      • 准备工作
      • 修改开源代码
      • 编写训练代码
      • 执行训练脚本
    • 可视化

本文以RCAN超分模型为例,演示了QAT的训练过程,步骤如下:

  • 先训练FP32模型
  • 再加载FP32训练的权值,进行QAT训练
  • 连续5次loss没有下降则停止训练
  • 为了加快演示,当psnr大于33.0时就停止训练
  • 采用tensorboard观察Loss曲线

版本信息

属性
训练环境搭建步骤
GPU型号NVIDIA GeForce RTX 3080 12GB
数据集下载链接http://data.vision.ee.ethz.ch/cvl/DIV2K/DIV2K_train_HR.zip
http://data.vision.ee.ethz.ch/cvl/DIV2K/DIV2K_valid_HR.zip
http://data.vision.ee.ethz.ch/cvl/DIV2K/DIV2K_train_LR_bicubic_X2.zip
http://data.vision.ee.ethz.ch/cvl/DIV2K/DIV2K_valid_LR_bicubic_X2.zip
开源模型结构https://raw.githubusercontent.com/RussellEven/Multi-frame-RCAN/master/code/model/rcan.py
https://raw.githubusercontent.com/RussellEven/Multi-frame-RCAN/master/code/option.py
https://raw.githubusercontent.com/RussellEven/Multi-frame-RCAN/master/code/model/common.py
https://raw.githubusercontent.com/RussellEven/Multi-frame-RCAN/master/code/template.py

测试步骤

准备数据集

wget http://data.vision.ee.ethz.ch/cvl/DIV2K/DIV2K_train_HR.zip
wget http://data.vision.ee.ethz.ch/cvl/DIV2K/DIV2K_valid_HR.zip
wget http://data.vision.ee.ethz.ch/cvl/DIV2K/DIV2K_train_LR_bicubic_X2.zip
wget http://data.vision.ee.ethz.ch/cvl/DIV2K/DIV2K_valid_LR_bicubic_X2.zip

创建容器

按https://editor.csdn.net/md/?articleId=136176989的步骤构建镜像

docker stop rcan_dev
docker rm rcan_dev
nvidia-docker run -ti -e NVIDIA_VISIBLE_DEVICES=all --privileged \--net=host -p 6006:6006 -v $PWD:/home -w /home  \-v /mnt/disk/RCAN/:/RCAN --name rcan_dev  cuda_dev_image:v1.0 /bin/bash
conda activate ai_dev

生成文件列表

创建文件列表的代码

# generate_datalist.pyimport os
import cv2
import numpy as np
from PIL import Image
from tqdm import tqdmtrain_HR_path = './DIV2K_train_HR'
train_LR_path = './DIV2K_train_LR_bicubic/X2'
valid_HR_path = './DIV2K_valid_HR'
valid_LR_path = './DIV2K_valid_LR_bicubic/X2'train_file = 'datalist_div2k_train.txt'
valid_file = 'datalist_div2k_valid.txt'def get_images(input_path, format='png'):names = [os.path.splitext(fname)[0]for fname in os.listdir(input_path)if fname.endswith(format)]names.sort()return namesdef get_folders(input_path):names = [directory for directory in os.listdir(input_path)if os.path.isdir(os.path.join(input_path, directory))]names.sort()return namesthe_train_file = open(train_file, 'w')
image_names = get_images(train_HR_path)
for image_name in image_names:the_train_file.write('DIV2K_train_LR_bicubic/X2/' + image_name + 'x2.png' + ' ' + 'DIV2K_train_HR/' + image_name + '.png' + '\n')
the_train_file.close()the_valid_file = open(valid_file, 'w')
image_names = get_images(valid_HR_path)
for image_name in image_names: the_valid_file.write('DIV2K_valid_LR_bicubic/X2/' + image_name + 'x2.png' + ' ' + 'DIV2K_valid_HR/' + image_name + '.png' + '\n')
the_valid_file.close()

执行脚本,生成文件列表

cd /RCAN/
unzip DIV2K_train_HR.zip
unzip DIV2K_valid_HR.zip
unzip DIV2K_train_LR_bicubic_X2.zip
unzip DIV2K_valid_LR_bicubic_X2.zip
python generate_datalist.py

训练RCAN模型

准备工作

# 安装依赖
pip install tensorboard -i https://pypi.tuna.tsinghua.edu.cn/simple
pip install scikit-image -i https://pypi.tuna.tsinghua.edu.cn/simple# 设置环境变量
export PROTOCOL_BUFFERS_PYTHON_IMPLEMENTATION=python# 下载开源模型源码
cd /RCAN/
mkdir model
curl -L -o model/rcan.py https://raw.githubusercontent.com/RussellEven/Multi-frame-RCAN/master/code/model/rcan.py
curl -L -o model/option.py https://raw.githubusercontent.com/RussellEven/Multi-frame-RCAN/master/code/option.py
curl -L -o model/common.py https://raw.githubusercontent.com/RussellEven/Multi-frame-RCAN/master/code/model/common.py
curl -L -o template.py https://raw.githubusercontent.com/RussellEven/Multi-frame-RCAN/master/code/template.py

修改开源代码

  • model/rcan.py

image-20240220142852491

image-20240220144639916

  • model/common.py

    image-20240220143210588

编写训练代码

# train.pyimport os
import torch
import torch.nn as nn
import torch.optim as optim
import json
import copy
import time
from torch.utils.data import DataLoader
from tqdm import tqdm
from torch.quantization.quantize_fx import prepare_qat_fx,convert_fx
from torch.ao.quantization import qconfig
from torch.ao.quantization.fake_quantize import *
from torch.ao.quantization.observer import *
from torch.utils import tensorboard
from torch.autograd import Variable
from torch.utils.data import Dataset
from skimage.color import rgb2hsv, hsv2rgb
import imageio
import random
import numpy as npdef _apply(func, x):if isinstance(x, (list, tuple)):return [_apply(func, x_i) for x_i in x]elif isinstance(x, dict):y = {}for key, value in x.items():y[key] = _apply(func, value)return yelse:return func(x)def get_patch(*args, patch_size=96, scale=2, input_large=False):ih, iw = args[0].shape[:2]if not input_large:p = scaletp = p * patch_sizeip = tp // scaleelse:tp = patch_sizeip = patch_sizeix = random.randrange(0

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.mzph.cn/news/704058.shtml

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

【随笔】固态硬盘数据删除无法恢复(开启TRIM),注意数据备份

文章目录 一、序二、机械硬盘和固态硬盘的物理结构与工作原理2.1 机械硬盘2.11 基本结构2.12 工作原理 2.2 固态硬盘2.21 基本结构2.22 工作原理 三、机械硬盘和固态硬盘的垃圾回收机制3.1 机械硬盘GC3.2 固态硬盘GC3.3 TRIM指令开启和关闭 四、做好数据备份 一、序 周末电脑突…

【Qt学习】QLineEdit 控件 属性与实例(登录界面,验证密码,正则表达式)

文章目录 1. 介绍2. 实例使用2.1 登录界面2.2 对比两次密码是否相同2.3 通过按钮显示当前输入的密码(并对2.2进行优化)2.4 结语 3. 正则表达式3.1 QRegExp3.2 验证输入内容 4. 资源代码 1. 介绍 关于 QLineEdit 的详细介绍,可以去查阅官方文…

[计算机网络]--IP协议

前言 作者:小蜗牛向前冲 名言:我可以接受失败,但我不能接受放弃 如果觉的博主的文章还不错的话,还请点赞,收藏,关注👀支持博主。如果发现有问题的地方欢迎❀大家在评论区指正 目录 一、IP协议…

202432读书笔记|《泰戈尔的诗》——什么事让你大笑,我生命的小蓓蕾

202432读书笔记|《泰戈尔的诗》——什么事让你大笑,我生命的小蓓蕾 《泰戈尔写给孩子的诗(中英双语版)》作者拉宾德拉纳特泰戈尔文 张王哲图,图文并茂的一本书,文字与图画都很美,相得益彰!很值得…

【Memory协议栈】EEPROM Abstraction模块详细介绍

目录 前言 正文 1.功能简介 2.关键概念 3.功能详解 3.1 Addressing scheme and segmentation 3.2 Address calculation 3.3 Limitation of erase / write cycles 3.4 Handling of “immediate” data 3.5 Managing block consistency information 4.关键API定义 4.…

学习磁盘管理

文章目录 一、磁盘接口类型二、磁盘设备的命名三、fdisk分区四、自动挂载五、扩容swap六、GPT分区七、逻辑卷管理八、磁盘配额九、RAID十、软硬链接 一、磁盘接口类型 IDE、SATA、SCSI、SAS、FC(光纤通道) IDE, 该接口是并口。SATA, 该接口是串口。SCS…

golang学习2,golang开发配置国内镜像

go env -w GO111MODULEon go env -w GOPROXYhttps://goproxy.cn,direct

K8S部署Java项目(Gitlab CI/CD自动化部署终极版)

天行健,君子以自强不息;地势坤,君子以厚德载物。 每个人都有惰性,但不断学习是好好生活的根本,共勉! 文章均为学习整理笔记,分享记录为主,如有错误请指正,共同学习进步。…

websocket入门及应用

websocket When to use a HTTP call instead of a WebSocket (or HTTP 2.0) WebSocket 是基于TCP/IP协议,独立于HTTP协议的通信协议。WebSocket 是双向通讯,有状态,客户端一(多)个与服务端一(多&#xff09…

代码随想录刷题第43天

第一题是最后一块石头的重量IIhttps://leetcode.cn/problems/last-stone-weight-ii/,没啥思路,直接上题解了。本题可以看作将一堆石头尽可能分成两份重量相似的石头,于是问题转化为如何合理取石头,使其装满容量为石头总重量一半的…

【AI Agent系列】【MetaGPT多智能体学习】0. 环境准备 - 升级MetaGPT 0.7.2版本及遇到的坑

之前跟着《MetaGPT智能体开发入门课程》学了一些MetaGPT的知识和实践,主要关注在MetaGPT入门和单智能体部分(系列文章附在文末,感兴趣的可以看下)。现在新的教程来了,新教程主要关注多智能体部分。 本系列文章跟随《M…

Wagtail安装运行并结合内网穿透实现公网访问本地网站界面

文章目录 前言1. 安装并运行Wagtail1.1 创建并激活虚拟环境 2. 安装cpolar内网穿透工具3. 实现Wagtail公网访问4. 固定的Wagtail公网地址 正文开始前给大家推荐个网站,前些天发现了一个巨牛的 人工智能学习网站, 通俗易懂,风趣幽默&#xf…

C++Lambda表达式介绍

C11中引入了Lambda表达式,Lambda表达式是一种匿名函数,它可以在需要函数的地方直接定义和使用,而无需显式地定义一个函数。 lambda表达式 Lambda表达式语法定义 [capture-list](parameters) -> return-type { statement } capture-lis…

SQL Developer 小贴士:PL/SQL语法分析

对于SQL或PL/SQL中的语法错误和警告,SQL Developer可以用不同颜色的下划波浪线显示。 启用语法分析,可以用菜单Tool>Preferences>Code Editor>Completion Insight>Enable Semantic Analysis Info Tips 例如,以下的代码中&…

blender bvh显示关节名称

导入bvh,菜单选择布局,右边出现属性窗口, 在下图红色框依次点击选中,就可以查看bvh关节名称了。

自考《计算机网络原理》考前冲刺

常考选择填空 1、计算机网络的定义:计算机网络是互连的、自治的计算机的集合。 2、协议的定义:协议是网络通信实体之间在数据交换过程中需要遵循的规则或约定 3、协议的3个要素 (1) 语法:定义实体之间交换信息的格式与结构,或…

设计模式六:策略模式

1、策略模式 策略模式定义了一系列的算法,并将每一个算法封装起来,使每个算法可以相互替代,使算法本身和使用算法的客户端分割开来,相互独立。 策略模式的角色: 策略接口角色IStrategy:用来约束一系列具体…

第一次开机开机动画结束后闪白屏

开机动画结束会闪下白屏,再进入launcher 思路 : 分析下从开机动画结束到launcher起来之间的流程步骤 从ZygoteInit.java开始分析 : SystemServer起来后会启动一些核心服务 attachApplication方法中主要创建了Application和Activity 接下里RootActivityC…

快速搭建网站原型!8款网站原型软件推荐

现在,基于云的软件已经逐渐成为主流,网站原型设计工具也不例外。与桌面版本相比,在线原型工具具有独特的优势,无论您使用Linux,Mac 或者Windows,都不需要安装就可以使用这些工具。下面小编就为大家推荐8款非…

c++入门学习⑧——模板

目录 前言 基本介绍 什么是模板? 作用 特点 分类 函数模板 语法 使用方式 注意事项 函数模板和普通函数区别 普通函数和函数模板的调用规则 局限性 类模板 语法 类模板的成员函数创建时机 类模板实例化对象 类模板实例化对象做函数参数 类模板成…