用ControlNet+Inpaint实现stable diffusion模特换衣
- ControlNet 训练与架构详解
- ControlNet 的架构
- 用于文本到图像扩散的 ControlNet
- 训练过程
- Zero卷积层的作用解释
- inpaint
- Inpaint Anything 的重要性
- Inpaint Anything 的功能概述
在现代计算机视觉领域,稳定扩散(Stable Diffusion)技术已经成为图像修复的重要工具之一。然而,虽然稳定扩散能够有效地填补图像中的缺失区域,但是对于用户来说,对修复过程进行更精准的控制往往是一项挑战。
为了解决这一问题,我们引入了ControlNet,这是一种专门设计用于在大型预训练文本到图像扩散模型中引入空间调节控制的神经网络架构。通过结合ControlNet与稳定扩散技术,我们实现了一种全新的图像修复方法,使用户能够通过各种条件输入来精确控制修复过程,例如Canny边缘、霍夫线、用户涂鸦、人体关键点、分割图、形状法线和深度等。本研究不仅证明了ControlNet在小型和大型数据集上的稳健性,还展示了其在图像修复领域的巨大潜力,为更广泛的图像处理应用提供了全新的可能性。
ControlNet 是一种神经网络架构,可以通过空间局部化、特定于任务的图像条件增强大型预训练文本到图像扩散模型。我们首先介绍下ControlNet的基本结构,
然后后面描述如何将ControlNet应用到图像扩散模型Stable Diffusion,以及Inpaint的方法
ControlNet 训练与架构详解
ControlNet 的架构
ControlNet 将附加条件注入到神经网络的块中。具体来说,ControlNet 的设计目的是在预训练模型的基础上,添加可训练的副本,以便处理新的控制信息(如草图、边缘图等)。这种设计可以保留预训练模型的优点,同时增强模型的多样性和灵活性。
核心概念
- 网络块:指一组神经层的组合,形成神经网络的一个单元,例如 ResNet 块、Conv-BN-ReLU 块、多头注意力块等。
- 锁定参数:冻结预训练模型的参数,使其保持稳定,并将其复制为可训练的副本。
- 零卷积:使用 1×1 卷积层,其权重和偏置初始化为零,确保初始训练时无有害噪声影响。
控制模块的添加
- 预训练块:
假设 F ( ⋅ ; Θ ) F(·;\Theta) F(⋅;Θ)是一个具有参数 Θ \Theta Θ的预训练神经块,将输入特征图 x x x转换为输出特征图 y y y。
y = F ( x ; Θ ) y=F(x;\Theta) y=F(x;Θ)
- ControlNet 模块:
锁定原始块的参数 Θ \Theta Θ,并克隆为可训练副本 Θ c \Theta_c Θc
可训练副本接受外部条件向量 c c c作为输入
- 计算过程:
ControlNet 的完整计算如下:
y c = F ( x ; Θ ) + Z ( F ( x + Z ( c ; Θ z 1 ) ; Z c ) ; Θ z 2 ) y_c = F(x;\Theta)+\Zeta(F(x+\Zeta(c;\Theta_{z1});\Zeta_c);\Theta_{z2}) yc=F(x;Θ)+Z(F(x+Z(c;Θz1);Zc);Θz2)
其中, Z ( ⋅ ; ⋅ ) \Zeta(·;·) Z(⋅;⋅)是零卷积层, Θ z 1 \Theta_{z1} Θz1和 Θ z 2 \Theta _{z2} Θz2是其参数。
- 零卷积的初始状态:
初始训练步骤中,零卷积的权重和偏置均为零,因此计算结果也是零,使得:
y c = y y_c = y yc=y
用于文本到图像扩散的 ControlNet
架构
- Stable Diffusion:采用 U-Net 结构,包含编码器、中间块和解码器。
- 编码器和解码器:每个包含 12 个块,总共有 25 个块。
- ViT(视觉变换器):在主块中包含多个交叉注意力和自注意力机制。
ControlNet 的集成
- 位置:ControlNet 结构应用于 U-Net 的每个编码器级别,共创建 12 个编码块和 1 个中间块的可训练副本。
- 分辨率:12 个编码块有 4 种分辨率,分别为 64×64, 32×32, 16×16 和 8×8,每种分辨率重复 3 次。
效率
- 计算效率:冻结的参数无需计算梯度,因此训练更高效。
- 资源消耗:与不使用 ControlNet 优化的 Stable Diffusion 相比,使用 ControlNet 优化只增加约 23% 的 GPU 内存和 34% 的时间。
训练过程
训练步骤
- 数据准备:
将输入图像逐渐添加噪声,生成噪声图像 z t z_t zt
条件包括时间步长 t t t、文本提示 c t c_t ct以及特定任务的条件 c f c_f cf。
- 损失函数
目标是预测添加到噪声图像中的噪声 ϵ \epsilon ϵ。
L = E z 0 , t , c t , c f , ϵ ∼ N ( 0 , 1 ) [ ∣ ∣ ϵ − ϵ θ ( z t , t , c t , c f ) ∣ ∣ 2 2 ] L = E_{z_0,t,c_t,c_f,\epsilon \sim N(0,1)}[||\epsilon - \epsilon_\theta(z_t,t,c_t,c_f)||_2^2] L=Ez0,t,ct,cf,ϵ∼N(0,1)[∣∣ϵ−ϵθ(zt,t,ct,cf)∣∣22]
- 随机替换文本提示:
训练过程中,随机将 50% 的文本提示替换为空字符串,以增强 ControlNet 直接识别输入调节图像语义的能力。
Zero卷积层的作用解释
在ControlNet架构中,zero卷积层起到了关键作用。具体来说,零卷积层的权重和偏置均初始化为零,这在模型的训练过程中起到了保护和稳定作用。以下是对每个zero卷积层作用的详细解释:
- 初始稳定性和保护作用:
初始训练时刚开始的训练,前面迭代的时候,零卷积层确保新加入的可训练网络块不会干扰预训练模型的输出。这是通过将零卷积层的权重和偏置初始化为零实现的。在训练初期,由于这些层的计算结果为零,新的可训练副本不会对模型的输出产生任何影响,从而保护了模型的稳定性。
Z ( c ; Θ z 1 ) = 0 和 Z ( c ; Θ z 2 ) = 0 \Zeta(c;\Theta_{z1}) = 0 \space和\space\Zeta(c;\Theta_{z2}) = 0 Z(c;Θz1)=0 和 Z(c;Θz2)=0
因此,初始输出为:
y c = F ( x ; Θ ) y_c = F(x;\Theta) yc=F(x;Θ)
在初始阶段,输出仅依赖于原始预训练模型,而不会受到新条件的噪声干扰。
- 渐进学习和控制注入:
随着训练的进行,zero卷积层逐渐学习到新的条件信息,并将其注入到模型中。这使得可训练副本能够逐步学习并适应新的控制信息,如草图、边缘图等。
具体来说,两个zero卷积层 Z ( c ; Θ z 1 ) \Zeta(c;\Theta_{z1}) Z(c;Θz1)和 Z ( F ( x + Z ( c ; Θ z 1 ) ; Θ c ) ; Θ z 2 ) \Zeta(F(x+\Zeta(c;\Theta_{z1});\Theta_c);\Theta_{z2}) Z(F(x+Z(c;Θz1);Θc);Θz2)
分别在两个不同的阶段起作用:
- 第一阶段:
Z ( c ; Θ z 1 ) \Zeta(c;\Theta_{z1}) Z(c;Θz1)
这个零卷积层接受外部条件向量 c c c,并将其转换为一个中间表示。
- 第二阶段:
Z ( F ( x + Z ( c ; Θ z 1 ) ; Θ c ) ; Θ z 2 ) \Zeta(F(x+\Zeta(c;\Theta_{z1});\Theta_c);\Theta_{z2}) Z(F(x+Z(c;Θz1);Θc);Θz2)
这个零卷积层接受经过第一阶段处理的输出,并进一步处理,最终将新条件信息注入到模型输出中。
随着训练的进行,这些卷积层的参数会被更新,从而使模型逐渐学习到如何在输出中包含新的控制信息。
- 消除初始训练步骤中的噪声:
零卷积层的设计确保在训练的初始阶段,不会有随机噪声影响网络的隐藏状态。这样,初始模型能够完全依赖于预训练模型的稳定性,并在此基础上逐步学习新的控制信息。
整个过程是有监督的,因为模型在训练过程中使用了成对的条件输入和目标输出图像。通过监督学习,模型逐渐学会在给定特定条件下生成相应的图像。具体步骤如下:
- 输入-输出对:提供(条件输入 c i c_i ci ,目标输出 y t a r g e t y_ {target} ytarget)对。
- Forward Pass:通过ControlNet计算生成图像 y c y_c yc。
- Loss Calculation:计算生成图像与目标输出之间的损失。
- backpropagation:根据损失更新模型参数。
通过这种有监督学习方法,ControlNet能够逐步学会在给定各种条件(如草图、深度图等)下,生成符合条件的图像,从而实现文本到图像生成的精确控制。
inpaint
论文标题:Inpaint Anything: Segment Anything Meets Image Inpainting
论文地址:https://arxiv.org/abs/2304.06790
github地址:https://github.com/geekyutao/Inpaint-Anything/tree/main
inpaint的作用:用户可以通过点击来选择图像中的任何对象。借助强大的视觉模型,例如,SAM、LaMa和稳定扩散(SD),Inpaint Anything能够平滑地移除对象(即,删除任何东西)。此外,在用户输入文本的提示下,Insaint Anything可以用任何期望的内容填充对象(即,填充任何东西)或任意替换它的背景(即,替换任何东西)。
Inpaint Anything 的重要性
当前图像修复的进展
目前,最先进的图像修复方法,如 LaMa、Repaint、MAT、ZITS等,已经取得了显著的进步。这些方法能够成功修复大面积区域,处理复杂的重复结构,并能够很好地适应高分辨率图像。然而,这些方法通常需要对每个掩码进行精细注释,这对于训练和推理至关重要。
Segment Anything Model (SAM) 的应用潜力
Segment Anything Model (SAM) 是一个强大的分割基础模型,能够根据点或框等输入提示生成高质量的对象蒙版,并且可以为图像中的所有对象生成全面且准确的蒙版。然而,SAM 的掩模分割预测功能在图像修复领域尚未得到充分利用。
现有方法的局限
现有的修复方法只能使用上下文填充被移除的区域。人工智能生成内容(AIGC)模型为创作开辟了新的机会,有可能满足大量需求,帮助用户生成他们所需的新内容。
综合解决方案的优势
通过结合 SAM、最先进的图像修复器和AIGC 模型的优势,论文提供了一个强大且用户友好的管道,来解决更常见的修复相关问题,如对象移除、新内容填充、背景替换等。这一综合方法不仅提高了图像修复的效果,还大大简化了用户操作,使得图像修复变得更加高效和便捷。
Inpaint Anything 的功能概述
删除任何内容 | 填充任何内容 | 替换任何内容 | 删除任何3D | 删除任何视频 |
---|---|---|---|---|
- 点击一个对象 | - 点击一个对象 | - 点击一个对象 | - 单击源视图的第一个视图中的对象 | - 点击视频第一帧中的对象 |
- 分割模型SAM将对象分割出来 | - SAM将目标分割出来 | - SAM将目标分割出来 | - SAM将对象分割出来(使用三个可能的掩码) | - SAM将对象分割出来(使用三个可能的掩码) |
- 修复模型填补“窟窿” | - 输入文字提示 | - 输入文字提示 | - 选择一个掩码 | - 选择一个掩码 |
- 文本提示引导的修复模型根据文本填充“洞” | - 文本提示引导的修复模型根据文本替换背景 | - 利用OSTrack等跟踪模型对这些视图中的目标进行跟踪 | - 利用OSTrack等跟踪模型对视频中的目标进行跟踪 | |
- SAM根据跟踪结果在每个源视图中分割出目标 | ||||
- 利用LaMa等修补模型对每个源视图中的对象进行修补 | ||||
- 利用NeRF等新的视图合成模型合成出不含物体的场景的新视图 |
填充内容代码如下 :fill_anything.py
import cv2
import sys
import argparse
import numpy as np
import torch
from pathlib import Path
from matplotlib import pyplot as plt
from typing import Any, Dict, Listfrom sam_segment import predict_masks_with_sam
from stable_diffusion_inpaint import fill_img_with_sd
from utils import load_img_to_array, save_array_to_img, dilate_mask, \show_mask, show_points, get_clicked_pointdef setup_args(parser):parser.add_argument("--input_img", type=str, required=True,help="Path to a single input img",)parser.add_argument("--coords_type", type=str, required=True,default="key_in", choices=["click", "key_in"], help="The way to select coords",)parser.add_argument("--point_coords", type=float, nargs='+', required=True,help="The coordinate of the point prompt, [coord_W coord_H].",)parser.add_argument("--point_labels", type=int, nargs='+', required=True,help="The labels of the point prompt, 1 or 0.",)parser.add_argument("--text_prompt", type=str, required=True,help="Text prompt",)parser.add_argument("--dilate_kernel_size", type=int, default=None,help="Dilate kernel size. Default: None",)parser.add_argument("--output_dir", type=str, required=True,help="Output path to the directory with results.",)parser.add_argument("--sam_model_type", type=str,default="vit_h", choices=['vit_h', 'vit_l', 'vit_b', 'vit_t'],help="The type of sam model to load. Default: 'vit_h")parser.add_argument("--sam_ckpt", type=str, required=True,help="The path to the SAM checkpoint to use for mask generation.",)parser.add_argument("--seed", type=int,help="Specify seed for reproducibility.",)parser.add_argument("--deterministic", action="store_true",help="Use deterministic algorithms for reproducibility.",)if __name__ == "__main__":"""Example usage:python fill_anything.py \--input_img FA_demo/FA1_dog.png \--coords_type key_in \--point_coords 750 500 \--point_labels 1 \--text_prompt "a teddy bear on a bench" \--dilate_kernel_size 15 \--output_dir ./results \--sam_model_type "vit_h" \--sam_ckpt sam_vit_h_4b8939.pth """parser = argparse.ArgumentParser()setup_args(parser)args = parser.parse_args(sys.argv[1:])device = "cuda" if torch.cuda.is_available() else "cpu"if args.coords_type == "click":latest_coords = get_clicked_point(args.input_img)elif args.coords_type == "key_in":latest_coords = args.point_coordsimg = load_img_to_array(args.input_img)masks, _, _ = predict_masks_with_sam(img,[latest_coords],args.point_labels,model_type=args.sam_model_type,ckpt_p=args.sam_ckpt,device=device,)masks = masks.astype(np.uint8) * 255# dilate mask to avoid unmasked edge effectif args.dilate_kernel_size is not None:masks = [dilate_mask(mask, args.dilate_kernel_size) for mask in masks]# visualize the segmentation resultsimg_stem = Path(args.input_img).stemout_dir = Path(args.output_dir) / img_stemout_dir.mkdir(parents=True, exist_ok=True)for idx, mask in enumerate(masks):# path to the resultsmask_p = out_dir / f"mask_{idx}.png"img_points_p = out_dir / f"with_points.png"img_mask_p = out_dir / f"with_{Path(mask_p).name}"# save the masksave_array_to_img(mask, mask_p)# save the pointed and masked imagedpi = plt.rcParams['figure.dpi']height, width = img.shape[:2]plt.figure(figsize=(width/dpi/0.77, height/dpi/0.77))plt.imshow(img)plt.axis('off')show_points(plt.gca(), [latest_coords], args.point_labels,size=(width*0.04)**2)plt.savefig(img_points_p, bbox_inches='tight', pad_inches=0)show_mask(plt.gca(), mask, random_color=False)plt.savefig(img_mask_p, bbox_inches='tight', pad_inches=0)plt.close()# fill the masked imagefor idx, mask in enumerate(masks):if args.seed is not None:torch.manual_seed(args.seed)mask_p = out_dir / f"mask_{idx}.png"img_filled_p = out_dir / f"filled_with_{Path(mask_p).name}"img_filled = fill_img_with_sd(img, mask, args.text_prompt, device=device)save_array_to_img(img_filled, img_filled_p)
上面代码以SAM 模型生成高质量的对象分割掩码,stable Diffusion 模型利用这些掩码和文本提示生成新内容,填补图像中的空洞。具体步骤如下:
- 生成掩码:SAM 根据用户提供的点坐标和标签生成对象的分割掩码。
- 膨胀掩码:对生成的掩码进行膨胀处理,以确保边缘过渡平滑。
- 生成新内容:Stable Diffusion 模型根据膨胀后的掩码和用户提供的文本提示生成新内容,填补图像中的空洞。