【深度学习】CodeFormer训练过程,如何训练人脸修复模型CodeFormer

文章目录

  • BasicSR介绍
  • 环境
  • 数据
  • 阶段 I - VQGAN
  • 阶段 II - CodeFormer (w=0)
  • 阶段 III - CodeFormer (w=1)

代码地址:https://github.com/sczhou/CodeFormer/releases/tag/v0.1.0

论文的一些简略介绍:
https://qq742971636.blog.csdn.net/article/details/134562550

BasicSR介绍

CodeFormer整个项目都沿袭BasicSR,了解一下BasicSR很有必要:

https://mp.csdn.net/mp_blog/creation/success/135674803

环境

# git clone this repository
git clone https://github.com/sczhou/CodeFormer
cd CodeFormer# create new anaconda env
conda create -n codeformer python=3.8 -y
conda activate codeformerconda install pytorch==1.13.1 torchvision==0.14.1 torchaudio==0.13.1 pytorch-cuda=11.7 -c pytorch -c nvidia# install python dependencies
pip3 install -r requirements.txt
python basicsr/setup.py developconda install -c conda-forge dlib (only for face detection or cropping with dlib)

数据

找一些高清人脸数据1024*1024。

人脸数据需要对齐,对齐方式为: https://qq742971636.blog.csdn.net/article/details/135521146

阶段 I - VQGAN

训练VQGAN:

python -m torch.distributed.launch --nproc_per_node=8 --master_port=4321 basicsr/train.py -opt options/VQGAN_512_ds32_nearest_stage1.yml --launcher pytorch
CUDA_VISIBLE_DEVICES=0,2,3 python -m torch.distributed.launch --nproc_per_node=3 --master_port=4321 basicsr/train.py -opt options/VQGAN_512_ds32_nearest_stage1.yml --launcher pytorch # 指定三张显卡训练,对应VQGAN_512_ds32_nearest_stage1.yaml也是需要修改的

训练完VQGAN后,可以通过下面代码预先获得训练数据集的密码本序列,从而加速后面阶段的训练过程:

python scripts/generate_latent_gt.py

如果你不需要训练自己的VQGAN,可以在Release v0.1.0文档中找到预训练的VQGAN (vqgan_code1024.pth)和对应的密码本序列 (latent_gt_code1024.pth): https://github.com/sczhou/CodeFormer/releases/tag/v0.1.0

打开日志查看训练过程:

tensorboard --logdir="/ssd/xiedong/CodeFormer/tb_logger/20240116_182107_VQGAN-512-ds32-nearest-stage1" --bind_all

在这里插入图片描述

VQGAN本身就是一个图生图的网络,在中间使用transformer将特征图转为embedding. 而 CodeFormer就是要利用这每张图的embedding来进行面部修复。

下面代码里用vqgan_code1024.pth获取训练数据的密码本,vqgan_code1024.pth的encoder输出的是2563232的特征图,由embedding给到1*1024,最终所有图保存为一个pytorch文件。

import argparse
import glob
import numpy as np
import os
import cv2
import torch
from torchvision.transforms.functional import normalize
from tqdm import tqdmfrom basicsr.utils import imwrite, img2tensor, tensor2imgfrom basicsr.utils.registry import ARCH_REGISTRYif __name__ == '__main__':parser = argparse.ArgumentParser()parser.add_argument('-i', '--test_path', type=str, default='/ssd/xiedong/FFHQ/faces_hq_sr')parser.add_argument('-o', '--save_root', type=str, default='/ssd/xiedong/FFHQ/lt_output')parser.add_argument('--codebook_size', type=int, default=1024)parser.add_argument('--ckpt_path', type=str, default='/ssd/xiedong/CodeFormer/weights/vqgan/vqgan_code1024.pth')args = parser.parse_args()if args.save_root.endswith('/'):  # solve when path ends with /args.save_root = args.save_root[:-1]dir_name = os.path.abspath(args.save_root)os.makedirs(dir_name, exist_ok=True)device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')test_path = args.test_pathsave_root = args.save_rootckpt_path = args.ckpt_pathcodebook_size = args.codebook_sizevqgan = ARCH_REGISTRY.get('VQAutoEncoder')(512, 64, [1, 2, 2, 4, 4, 8], 'nearest',codebook_size=codebook_size).to(device)checkpoint = torch.load(ckpt_path)['params_ema']vqgan.load_state_dict(checkpoint)vqgan.eval()sum_latent = np.zeros((codebook_size)).astype('float64')size_latent = 32latent = {}latent['orig'] = {}latent['hflip'] = {}for i in ['orig', 'hflip']:# for i in ['hflip']:for img_path in tqdm(sorted(glob.glob(os.path.join(test_path, '*.[jp][pn]g')))):img_name = os.path.basename(img_path)img = cv2.imread(img_path)if i == 'hflip':cv2.flip(img, 1, img)img = img2tensor(img / 255., bgr2rgb=True, float32=True)normalize(img, (0.5, 0.5, 0.5), (0.5, 0.5, 0.5), inplace=True)img = img.unsqueeze(0).to(device)with torch.no_grad():# output = net(img)[0]# x, feat_dict = vqgan.encoder(img, True)x = vqgan.encoder(img)x, _, log = vqgan.quantize(x)# del outputtorch.cuda.empty_cache()min_encoding_indices = log['min_encoding_indices']min_encoding_indices = min_encoding_indices.view(size_latent, size_latent)latent[i][img_name[:-4]] = min_encoding_indices.cpu().numpy()print(img_name, latent[i][img_name[:-4]].shape)latent_save_path = os.path.join(save_root, f'latent_gt_code{codebook_size}.pth')torch.save(latent, latent_save_path)print(f'\nLatent GT code are saved in {save_root}')

阶段 II - CodeFormer (w=0)

w=0 是需要模型完全追求抽象美学,w=1 是需要模型完全追求与原图相似。

在第一个阶段,得到了每张图对应的embedding。

训练密码本训练预测模块:

python -m torch.distributed.launch --nproc_per_node=8 --master_port=4322 basicsr/train.py -opt options/CodeFormer_stage2.yml --launcher pytorch

预训练CodeFormer第二阶段模型 (codeformer_stage2.pth)可以在Releases v0.1.0文档里下载: https://github.com/sczhou/CodeFormer/releases/tag/v0.1.0

阶段 III - CodeFormer (w=1)

训练可调模块:

python -m torch.distributed.launch --nproc_per_node=8 --master_port=4323 basicsr/train.py -opt options/CodeFormer_stage3.yml --launcher pytorch

预训练CodeFormer模型 (codeformer.pth)可以在Releases v0.1.0文档里下载: https://github.com/sczhou/CodeFormer/releases/tag/v0.1.0

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.mzph.cn/news/649380.shtml

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

链路追踪-调用链跟踪-Jaeger

文章目录 一、什么是链路跟踪二、OpenCensusOpenCensus 主要特点OpenTracing标准基本概念Span 三、典型服务端产品什么是OpenTracing?opentracing 使用介绍 四、JaegerJaeger 包含的模块Jaeger-client(客户端库) 五、Jaeger服务容器化部署过程问题整理 …

csdn黑色背景用法

在edge浏览器下,下载油猴脚本管理器 脚本下载 edge扩展 效果图如下:::

[ACM学习] 进制转换

进制的本质 本质是每一位的数位上的数字乘上这一位的权重 将任意进制转换为十进制 原来还很疑惑为什么从高位开始,原来从高位开始的,可以被滚动地乘很多遍。 将十进制转换为任意进制

适合深夜发朋友圈的心灵鸡汤(整理70句)

1、很多时候,我们赢得了口舌,却失去了感情。 2、失恋到极致的时候,我真的会用后退来保护自己。 3、全身心地去爱,你可能会受到伤害,但这是完整人生的唯一方式。 4、自由不是想干什么就干什么,而是不想干…

Linux中LVM实验

LVM实验: 1、分区 -L是大小的意思-n名称的意思 从vg0(卷组)分出来 2、格式化LV逻辑卷 LVM扩容 如果icdir空间不够了, 扩展空间lvextend -L 5G /dev/vg0/lv1 /dev/vg0/lv1(pp,vg,lv) 刷新文件系统xfs_growfs /lvdir VG扩容 …

php:规范小数位数,例:10.00展示为10,10.98展示为10.98

代码 <?php$value 10.98; // 原始的双精度类型值if ($value floor($value)) {$formattedValue number_format($value, 0); // 10.00 转换为 10echo $formattedValue;} else {$formattedValue number_format($value, 2); // 10.98 保持为 10.98echo $formattedValue;} …

Sublime Text 3配置 Java 开发环境

《开发工具系列》 《开发语言-Java》 Sublime Text 3配置 Java 开发环境 一、引言二、主要内容1. 初识 Sublime Text 32. 初识 Java3. 接入 Java3.1 JDK 下载3.2 安装和使用 java3.3 环境变量配置 4. 配置 Java 开发环境5. 编写 Java 代码6. 编译和运行 Java 代码7. 乱码问题 三…

服务器无法访问外网怎么办

目前是互联网时代&#xff0c;网络已经成为人们日常生活中不可或缺的一部分。我们通过网络获取信息、进行沟通、甚至进行工作&#xff0c;因此&#xff0c;保持网络的稳定和通畅是非常重要的。然而&#xff0c;有时候我们可能会遇到一些网络无法访问外网的问题&#xff0c;这给…

作者推荐 | 【深入浅出MySQL】「底层原理」探秘缓冲池的核心奥秘,揭示终极洞察

探秘缓冲池的核心奥秘&#xff0c;揭示终极洞察 缓存池BufferPool机制MySQL缓冲池缓冲池缓冲池的问题 缓冲池的原理数据预读程序的局部性原则&#xff08;集中读写原理&#xff09;时间局部性空间局部性 innodb的数据页查询InnoDB的数据页InnoDB缓冲池缓存数据页InnoDB缓存数据…

[DIOR | DIOR-R]旋转目标检测数据集——基于YOLOv8obb,map50已达81.8%

DIOR是一个用于光学遥感图像目标检测的大规模基准数据集。涵盖20个对象类。这20个对象类是飞机、机场、棒球场、篮球场、桥梁、烟囱、水坝、高速公路服务区、高速公路收费站、港口、高尔夫球场、地面田径场、天桥、船舶、体育场、储罐、网球场、火车站、车辆和风磨。 1. DIOR简…

常见の算法链表问题

时间复杂度 1.链表逆序 package class04;import java.util.ArrayList; import java.util.List;public class Code01_ReverseList {public static class Node {public int value;public Node next;public Node(int data) {value data;}}public static class DoubleNode {publi…

Java 字符串 05 练习-遍历字符串和统计字符个数

代码&#xff1a; import java.util.Scanner; public class practice{public static void main(String[] args) {//键盘录入一个字符串&#xff0c;并进行遍历&#xff1b;Scanner input new Scanner(System.in);System.out.println("输入一个字符串&#xff1a;")…

webassembly003 whisper.cpp的main项目-1

参数设置 /home/pdd/le/whisper.cpp-1.5.0/cmake-build-debug/bin/main options:-h, --help [default] show this help message and exit-t N, --threads N [4 ] number of threads to use during computation-p N, --processors …

Android App开发-简单控件(2)——视图基础

2.2 视图基础 本节介绍视图的几种基本概念及其用法&#xff0c;包括如何设置视图的宽度和高度&#xff0c;如何设置视图的外部间距和内部间距&#xff0c;如何设置视图的外部对齐方式和内部对齐方式等等。 2.2.1 设置视图的宽高 手机屏幕是块长方形区域&#xff0c;较短的那…

【星海随笔】unix 启动问题记录.

启动Ubuntu操作系统时&#xff0c;直接进入GRUB状态。 调试时候&#xff0c;曾显示 no bootable device no known filesystem detected 注意&#xff1a; 目前 GRUB 分成 GRUB legacy 和 GRUB 2。版本号是 0.9x 以及之前的版本都称为 GRUB Legacy &#xff0c;从 1.x 开始的就称…

NODE笔记 2 使用node操作飞书多维表格

前面简单介绍了node与简单的应用&#xff0c;本文通过结合飞书官方文档 使用node对飞书多维表格进行简单的操作&#xff08;获取token 查询多维表格recordid&#xff0c;删除多行数据&#xff0c;新增数据&#xff09; 文章目录 前言 前两篇文章对node做了简单的介绍&#xff…

eNSP学习——配置通过STelnet登陆系统

目录 背景 实验内容 实验目的 实验步骤 实验拓扑 详细配置过程 基础配置 配置SSH server 配置SSH client 配置SFTP server 与client 背景 由于Telnet缺少安全的认证方式&#xff0c;而且传输过程采用的是TCP进行明文传输。单纯的提供Telnet服务容易招致主机IP地址欺骗、路…

数据分析 - 图形化解释(后续添加)

图形化解释 作为数据分析师来说一个好的图形&#xff0c;就是自己的数据表达能力 简单文本 只有一两项数据需要分享的时候&#xff0c;简单的文本是最佳的沟通方法 下图的对比可以看出来文字的表达效果会好很多 散点图 散点图在展示两件事的关系时很有用&#xff0c;观察是否存…

【搞懂设计模式】命令模式:从遥控器到编程的妙用!

我们都熟悉电视遥控器&#xff0c;它有许多按钮&#xff0c;每个按钮都有确定的功能。你按下电源键电视就会打开&#xff0c;再按下一次电视就会关闭。编程世界里也有这种模式&#xff0c;这就是我们说的命令模式。 命令模式是一种设计模式&#xff0c;它把一个请求或操作封装…

以梦为码,CodeArts Snap 缩短我与算法的距离

背景 最近一直在体验华为云的 CodeArts Snap&#xff0c;逐渐掌握了使用方法&#xff0c;代码自动生成的准确程度大大提高了。 自从上次跟着 CodeArts Snap 学习用 Python 编程&#xff0c;逐渐喜欢上了 Python。 我还给 CodeArts Snap 起了一个花名&#xff1a; 最佳智能学…