Pytorch训练RCAN QAT超分模型

Pytorch训练RCAN QAT超分模型

  • 版本信息
  • 测试步骤
    • 准备数据集
    • 创建容器
    • 生成文件列表
      • 创建文件列表的代码
      • 执行脚本,生成文件列表
    • 训练RCAN模型
      • 准备工作
      • 修改开源代码
      • 编写训练代码
      • 执行训练脚本
    • 可视化

本文以RCAN超分模型为例,演示了QAT的训练过程,步骤如下:

  • 先训练FP32模型
  • 再加载FP32训练的权值,进行QAT训练
  • 连续5次loss没有下降则停止训练
  • 为了加快演示,当psnr大于33.0时就停止训练
  • 采用tensorboard观察Loss曲线

版本信息

属性
训练环境搭建步骤
GPU型号NVIDIA GeForce RTX 3080 12GB
数据集下载链接http://data.vision.ee.ethz.ch/cvl/DIV2K/DIV2K_train_HR.zip
http://data.vision.ee.ethz.ch/cvl/DIV2K/DIV2K_valid_HR.zip
http://data.vision.ee.ethz.ch/cvl/DIV2K/DIV2K_train_LR_bicubic_X2.zip
http://data.vision.ee.ethz.ch/cvl/DIV2K/DIV2K_valid_LR_bicubic_X2.zip
开源模型结构https://raw.githubusercontent.com/RussellEven/Multi-frame-RCAN/master/code/model/rcan.py
https://raw.githubusercontent.com/RussellEven/Multi-frame-RCAN/master/code/option.py
https://raw.githubusercontent.com/RussellEven/Multi-frame-RCAN/master/code/model/common.py
https://raw.githubusercontent.com/RussellEven/Multi-frame-RCAN/master/code/template.py

测试步骤

准备数据集

wget http://data.vision.ee.ethz.ch/cvl/DIV2K/DIV2K_train_HR.zip
wget http://data.vision.ee.ethz.ch/cvl/DIV2K/DIV2K_valid_HR.zip
wget http://data.vision.ee.ethz.ch/cvl/DIV2K/DIV2K_train_LR_bicubic_X2.zip
wget http://data.vision.ee.ethz.ch/cvl/DIV2K/DIV2K_valid_LR_bicubic_X2.zip

创建容器

按https://editor.csdn.net/md/?articleId=136176989的步骤构建镜像

docker stop rcan_dev
docker rm rcan_dev
nvidia-docker run -ti -e NVIDIA_VISIBLE_DEVICES=all --privileged \--net=host -p 6006:6006 -v $PWD:/home -w /home  \-v /mnt/disk/RCAN/:/RCAN --name rcan_dev  cuda_dev_image:v1.0 /bin/bash
conda activate ai_dev

生成文件列表

创建文件列表的代码

# generate_datalist.pyimport os
import cv2
import numpy as np
from PIL import Image
from tqdm import tqdmtrain_HR_path = './DIV2K_train_HR'
train_LR_path = './DIV2K_train_LR_bicubic/X2'
valid_HR_path = './DIV2K_valid_HR'
valid_LR_path = './DIV2K_valid_LR_bicubic/X2'train_file = 'datalist_div2k_train.txt'
valid_file = 'datalist_div2k_valid.txt'def get_images(input_path, format='png'):names = [os.path.splitext(fname)[0]for fname in os.listdir(input_path)if fname.endswith(format)]names.sort()return namesdef get_folders(input_path):names = [directory for directory in os.listdir(input_path)if os.path.isdir(os.path.join(input_path, directory))]names.sort()return namesthe_train_file = open(train_file, 'w')
image_names = get_images(train_HR_path)
for image_name in image_names:the_train_file.write('DIV2K_train_LR_bicubic/X2/' + image_name + 'x2.png' + ' ' + 'DIV2K_train_HR/' + image_name + '.png' + '\n')
the_train_file.close()the_valid_file = open(valid_file, 'w')
image_names = get_images(valid_HR_path)
for image_name in image_names: the_valid_file.write('DIV2K_valid_LR_bicubic/X2/' + image_name + 'x2.png' + ' ' + 'DIV2K_valid_HR/' + image_name + '.png' + '\n')
the_valid_file.close()

执行脚本,生成文件列表

cd /RCAN/
unzip DIV2K_train_HR.zip
unzip DIV2K_valid_HR.zip
unzip DIV2K_train_LR_bicubic_X2.zip
unzip DIV2K_valid_LR_bicubic_X2.zip
python generate_datalist.py

训练RCAN模型

准备工作

# 安装依赖
pip install tensorboard -i https://pypi.tuna.tsinghua.edu.cn/simple
pip install scikit-image -i https://pypi.tuna.tsinghua.edu.cn/simple# 设置环境变量
export PROTOCOL_BUFFERS_PYTHON_IMPLEMENTATION=python# 下载开源模型源码
cd /RCAN/
mkdir model
curl -L -o model/rcan.py https://raw.githubusercontent.com/RussellEven/Multi-frame-RCAN/master/code/model/rcan.py
curl -L -o model/option.py https://raw.githubusercontent.com/RussellEven/Multi-frame-RCAN/master/code/option.py
curl -L -o model/common.py https://raw.githubusercontent.com/RussellEven/Multi-frame-RCAN/master/code/model/common.py
curl -L -o template.py https://raw.githubusercontent.com/RussellEven/Multi-frame-RCAN/master/code/template.py

修改开源代码

  • model/rcan.py

image-20240220142852491

image-20240220144639916

  • model/common.py

    image-20240220143210588

编写训练代码

# train.pyimport os
import torch
import torch.nn as nn
import torch.optim as optim
import json
import copy
import time
from torch.utils.data import DataLoader
from tqdm import tqdm
from torch.quantization.quantize_fx import prepare_qat_fx,convert_fx
from torch.ao.quantization import qconfig
from torch.ao.quantization.fake_quantize import *
from torch.ao.quantization.observer import *
from torch.utils import tensorboard
from torch.autograd import Variable
from torch.utils.data import Dataset
from skimage.color import rgb2hsv, hsv2rgb
import imageio
import random
import numpy as npdef _apply(func, x):if isinstance(x, (list, tuple)):return [_apply(func, x_i) for x_i in x]elif isinstance(x, dict):y = {}for key, value in x.items():y[key] = _apply(func, value)return yelse:return func(x)def get_patch(*args, patch_size=96, scale=2, input_large=False):ih, iw = args[0].shape[:2]if not input_large:p = scaletp = p * patch_sizeip = tp // scaleelse:tp = patch_sizeip = patch_sizeix = random.randrange(0

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.mzph.cn/news/704058.shtml

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

量子计算学习经验

推荐B站冉仕举老师视频(老师讲的详细又耐心,张量网络做量子计算,不过有些基础概念都是通用的) StringCNU的个人空间-StringCNU个人主页-哔哩哔哩视频 2《量子计算与量子信息》是经典的教材书的,但是大部分同学第一次看…

【随笔】固态硬盘数据删除无法恢复(开启TRIM),注意数据备份

文章目录 一、序二、机械硬盘和固态硬盘的物理结构与工作原理2.1 机械硬盘2.11 基本结构2.12 工作原理 2.2 固态硬盘2.21 基本结构2.22 工作原理 三、机械硬盘和固态硬盘的垃圾回收机制3.1 机械硬盘GC3.2 固态硬盘GC3.3 TRIM指令开启和关闭 四、做好数据备份 一、序 周末电脑突…

数据库设计过程中的各种模式

在数据库设计过程中,有几种常见的模式,它们有助于组织和管理数据。以下是这几种模式的简介: 主扩展模式(也称为主从模式):这种模式适用于多个表具有相似结构的情况。这些表共享某些基本属性(也…

备战蓝桥之二分

二分题目: B3880 [信息与未来 2015] 买木头 - 洛谷 | 计算机科学教育新生态 (luogu.com.cn) import java.io.BufferedReader; import java.io.IOException; import java.io.InputStreamReader; import java.io.PrintWriter; import java.security.PublicKey; impor…

【Qt学习】QLineEdit 控件 属性与实例(登录界面,验证密码,正则表达式)

文章目录 1. 介绍2. 实例使用2.1 登录界面2.2 对比两次密码是否相同2.3 通过按钮显示当前输入的密码(并对2.2进行优化)2.4 结语 3. 正则表达式3.1 QRegExp3.2 验证输入内容 4. 资源代码 1. 介绍 关于 QLineEdit 的详细介绍,可以去查阅官方文…

[计算机网络]--IP协议

前言 作者:小蜗牛向前冲 名言:我可以接受失败,但我不能接受放弃 如果觉的博主的文章还不错的话,还请点赞,收藏,关注👀支持博主。如果发现有问题的地方欢迎❀大家在评论区指正 目录 一、IP协议…

202432读书笔记|《泰戈尔的诗》——什么事让你大笑,我生命的小蓓蕾

202432读书笔记|《泰戈尔的诗》——什么事让你大笑,我生命的小蓓蕾 《泰戈尔写给孩子的诗(中英双语版)》作者拉宾德拉纳特泰戈尔文 张王哲图,图文并茂的一本书,文字与图画都很美,相得益彰!很值得…

【Memory协议栈】EEPROM Abstraction模块详细介绍

目录 前言 正文 1.功能简介 2.关键概念 3.功能详解 3.1 Addressing scheme and segmentation 3.2 Address calculation 3.3 Limitation of erase / write cycles 3.4 Handling of “immediate” data 3.5 Managing block consistency information 4.关键API定义 4.…

学习磁盘管理

文章目录 一、磁盘接口类型二、磁盘设备的命名三、fdisk分区四、自动挂载五、扩容swap六、GPT分区七、逻辑卷管理八、磁盘配额九、RAID十、软硬链接 一、磁盘接口类型 IDE、SATA、SCSI、SAS、FC(光纤通道) IDE, 该接口是并口。SATA, 该接口是串口。SCS…

Linux笔记--文件内容的查阅与统计

一、文件内容的查阅 1.cat指令 concatenate,连接文件并打印到标准输出设备上(查看文件) (1) #cat文件的路径 常用选项: -n列出行号 (2)#tac 含义:倒序显示(应用:查看日志) 2. head指令 查看一个文件的前n行…

golang学习2,golang开发配置国内镜像

go env -w GO111MODULEon go env -w GOPROXYhttps://goproxy.cn,direct

npm已经配置淘宝源仍然无法使用

使用npm命令安装Taro框架的时候,尽管已经设置淘宝源但是仍然无法下载,提示错误 >npm ERR! code CERT_HAS_EXPIRED npm ERR! errno CERT_HAS_EXPIRED npm ERR! request to https://registry.npm.taobao.org/cnpm failed, reason: certificate h…

K8S部署Java项目(Gitlab CI/CD自动化部署终极版)

天行健,君子以自强不息;地势坤,君子以厚德载物。 每个人都有惰性,但不断学习是好好生活的根本,共勉! 文章均为学习整理笔记,分享记录为主,如有错误请指正,共同学习进步。…

websocket入门及应用

websocket When to use a HTTP call instead of a WebSocket (or HTTP 2.0) WebSocket 是基于TCP/IP协议,独立于HTTP协议的通信协议。WebSocket 是双向通讯,有状态,客户端一(多)个与服务端一(多&#xff09…

代码随想录刷题第43天

第一题是最后一块石头的重量IIhttps://leetcode.cn/problems/last-stone-weight-ii/,没啥思路,直接上题解了。本题可以看作将一堆石头尽可能分成两份重量相似的石头,于是问题转化为如何合理取石头,使其装满容量为石头总重量一半的…

【AI Agent系列】【MetaGPT多智能体学习】0. 环境准备 - 升级MetaGPT 0.7.2版本及遇到的坑

之前跟着《MetaGPT智能体开发入门课程》学了一些MetaGPT的知识和实践,主要关注在MetaGPT入门和单智能体部分(系列文章附在文末,感兴趣的可以看下)。现在新的教程来了,新教程主要关注多智能体部分。 本系列文章跟随《M…

五种主流数据库:常用字符函数

SQL 字符函数用于字符数据的处理,例如字符串的拼接、大小写转换、子串的查找和替换等。 本文比较五种主流数据库常用数值函数的实现和差异,包括 MySQL、Oracle、SQL Server、PostgreSQL 以及 SQLite。 字符函数函数功能MySQLOracleSQL ServerPostgreSQ…

Wagtail安装运行并结合内网穿透实现公网访问本地网站界面

文章目录 前言1. 安装并运行Wagtail1.1 创建并激活虚拟环境 2. 安装cpolar内网穿透工具3. 实现Wagtail公网访问4. 固定的Wagtail公网地址 正文开始前给大家推荐个网站,前些天发现了一个巨牛的 人工智能学习网站, 通俗易懂,风趣幽默&#xf…

C++Lambda表达式介绍

C11中引入了Lambda表达式,Lambda表达式是一种匿名函数,它可以在需要函数的地方直接定义和使用,而无需显式地定义一个函数。 lambda表达式 Lambda表达式语法定义 [capture-list](parameters) -> return-type { statement } capture-lis…

cmake build

cmake -H. -Bbuild 这是使用 CMake 的命令行工具来配置项目的命令。具体来说: cmake 是 CMake 的命令行工具。-H. 表示 CMakeLists.txt 文件所在的源代码目录是当前目录 (.)。这个选项指定了 CMakeLists.txt 所在的路径,这样 CMake 就知道在哪里找到项目…