扩散模型实战(九):使用CLIP模型引导和控制扩散模型

推荐阅读列表:

 扩散模型实战(一):基本原理介绍

扩散模型实战(二):扩散模型的发展

扩散模型实战(三):扩散模型的应用

扩散模型实战(四):从零构建扩散模型

扩散模型实战(五):采样过程

扩散模型实战(六):Diffusers DDPM初探

扩散模型实战(七):Diffusers蝴蝶图像生成实战

扩散模型实战(八):微调扩散模型

上篇文章中介绍了如何微调扩散模型,有时候微调的效果仍然不能满足需求,比如图片编辑,3D模型输出等都需要对生成的内容进行控制,本文将初步探索一下如何控制扩散模型的输出。

我们将使用在LSUM bedrooms数据集上训练并在WikiArt数据集上微调的模型,首先加载模型来查看一下模型的生成效果:

!pip install -qq diffusers datasets accelerate wandb open-clip-torch

 

import numpy as npimport torchimport torch.nn.functional as Fimport torchvisionfrom datasets import load_datasetfrom diffusers import DDIMScheduler, DDPMPipelinefrom matplotlib import pyplot as pltfrom PIL import Imagefrom torchvision import transformsfrom tqdm.auto import tqdm

 

device = (    "mps"    if torch.backends.mps.is_available()    else "cuda"    if torch.cuda.is_available()    else "cpu")

 

# 载入一个预训练过的管线pipeline_name = "johnowhitaker/sd-class-wikiart-from-bedrooms"image_pipe = DDPMPipeline.from_pretrained(pipeline_name).to(device) # 使用DDIM调度器,仅用40步生成一些图片scheduler = DDIMScheduler.from_pretrained(pipeline_name)scheduler.set_timesteps(num_inference_steps=40) # 将随机噪声作为出发点x = torch.randn(8, 3, 256, 256).to(device) # 使用一个最简单的采样循环for i, t in tqdm(enumerate(scheduler.timesteps)):    model_input = scheduler.scale_model_input(x, t)    with torch.no_grad():        noise_pred = image_pipe.unet(model_input, t)["sample"]    x = scheduler.step(noise_pred, t, x).prev_sample # 查看生成结果,如图5-10所示grid = torchvision.utils.make_grid(x, nrow=4)plt.imshow(grid.permute(1, 2, 0).cpu().clip(-1, 1) * 0.5 + 0.5)

图片

       正如上图所示,模型可以生成一些图片,那么如何进行控制输出呢?下面我们以控制图片生成绿色风格为例介绍AIGC模型控制:

       思路是:定义一个均方误差损失函数,让生成的图片像素值尽量接近目标颜色;

 

def color_loss(images, target_color=(0.1, 0.9, 0.5)):    """给定一个RGB值,返回一个损失值,用于衡量图片的像素值与目标颜色相差多少; 这里的目标颜色是一种浅蓝绿色,对应的RGB值为(0.1, 0.9, 0.5)"""    target = (        torch.tensor(target_color).to(images.device) * 2 - 1    )  # 首先对target_color进行归一化,使它的取值区间为(-1, 1)     target = target[        None, :, None, None    ]  # 将所生成目标张量的形状改为(b, c, h, w),以适配输入图像images的 # 张量形状    error = torch.abs(        images - target    ).mean()  # 计算图片的像素值以及目标颜色的均方误差    return error

接下来,需要修改采样循环操作,具体操作步骤如下:

  1. 创建输入图像X,并设置requires_grad设置为True;

  2. 计算“去噪”后的图像X0;

  3. 将“去噪”后的图像X0传递给损失函数;

  4. 计算损失函数对输入图像X的梯度;

  5. 在使用调度器之前,先用计算出来的梯度修改输入图像X,使输入图像X朝着减少损失值的方向改进

实现上述步骤有两种方法:

方法一:从UNet中获取噪声预测,并将输入图像X的requires_grad属性设置为True,这样可以充分利用内存(因为不需要通过扩散模型追踪梯度),但是这会导致梯度的精度降低;

方法二:先将输入图像X的requires_grad属性设置为True,然后传递给UNet并计算“去噪”后的图像X0;

下面分别看一下这两种方法的效果:

 

# 第一种方法 # guidance_loss_scale用于决定引导的强度有多大guidance_loss_scale = 40  # 可设定为5~100的任意数字 x = torch.randn(8, 3, 256, 256).to(device) for i, t in tqdm(enumerate(scheduler.timesteps)):     # 准备模型输入    model_input = scheduler.scale_model_input(x, t)     # 预测噪声    with torch.no_grad():        noise_pred = image_pipe.unet(model_input, t)["sample"]     # 设置x.requires_grad为True    x = x.detach().requires_grad_()     # 得到“去噪”后的图像    x0 = scheduler.step(noise_pred, t, x).pred_original_sample     # 计算损失值    loss = color_loss(x0) * guidance_loss_scale    if i % 10 == 0:        print(i, "loss:", loss.item())     # 获取梯度    cond_grad = -torch.autograd.grad(loss, x)[0]     # 使用梯度更新x    x = x.detach() + cond_grad     # 使用调度器更新x    x = scheduler.step(noise_pred, t, x).prev_sample# 查看结果grid = torchvision.utils.make_grid(x, nrow=4)im = grid.permute(1, 2, 0).cpu().clip(-1, 1) * 0.5 + 0.5Image.fromarray(np.array(im * 255).astype(np.uint8))

 

# 输出0 loss: 29.3701839447021510 loss: 12.11665058135986320 loss: 11.64170455932617230 loss: 11.78276252746582

图片

 

# 第二种方法:在模型预测前设置好x.requires_gradguidance_loss_scale = 40x = torch.randn(4, 3, 256, 256).to(device) for i, t in tqdm(enumerate(scheduler.timesteps)):     # 首先设置好requires_grad    x = x.detach().requires_grad_()    model_input = scheduler.scale_model_input(x, t)     # 预测    noise_pred = image_pipe.unet(model_input, t)["sample"]     # 得到“去噪”后的图像    x0 = scheduler.step(noise_pred, t, x).pred_original_sample     # 计算损失值    loss = color_loss(x0) * guidance_loss_scale    if i % 10 == 0:        print(i, "loss:", loss.item())     # 获取梯度    cond_grad = -torch.autograd.grad(loss, x)[0]     # 根据梯度修改x    x = x.detach() + cond_grad     # 使用调度器更新x    x = scheduler.step(noise_pred, t, x).prev_sample  grid = torchvision.utils.make_grid(x, nrow=4)im = grid.permute(1, 2, 0).cpu().clip(-1, 1) * 0.5 + 0.5Image.fromarray(np.array(im * 255).astype(np.uint8))

 

# 输出0 loss: 27.6226882934570310 loss: 16.84250640869140620 loss: 15.5464210510253930 loss: 15.545379638671875

图片

       从上图看出,第二种方法效果略差,但是第二种方法的输出更接近训练模型所使用的数据,也可以通过修改guidance_loss_scale参数来增强颜色的迁移效果。

CLIP控制图像生成

       虽然上述方式可以引导和控制图像生成某种颜色,但现在LLM更主流的方式是通过Prompt(仅仅打几行字描述需求)来得到自己想要的图像,那么CLIP是一个不错的选择。CLIP是有OpenAI开发的图文匹配大模型,由于这个过程是可微分的,所以可以将其作为损失函数来引导扩散模型。

使用CLIP控制图像生成的基本流程如下

  1. 使用CLIP模型对Prompt表示为512embedding向量;

  2. 在扩散模型的生成过程中需要多次执行如下步骤:

    1)生成多个“去噪”图像;

    2)对生成的每个“去噪”图像用CLIP模型进行embedding,并对Prompt embedding和图像的embedding进行对比;

    3)计算Prompt和“去噪”后图像的梯度,使用这个梯度先更新输入图像X,然后再使用调度器更新X;

    加载CLIP模型

 

import open_clip clip_model, _, preprocess = open_clip.create_model_and_transforms(    "ViT-B-32", pretrained="openai")clip_model.to(device) # 图像变换:用于修改图像尺寸和增广数据,同时归一化数据,以使数据能够适配CLIP模型 tfms = torchvision.transforms.Compose(    [        torchvision.transforms.RandomResizedCrop(224),# 随机裁剪        torchvision.transforms.RandomAffine(5),       # 随机扭曲图片        torchvision.transforms.RandomHorizontalFlip(),# 随机左右镜像, # 你也可以使用其他增广方法        torchvision.transforms.Normalize(            mean=(0.48145466, 0.4578275, 0.40821073),            std=(0.26862954, 0.26130258, 0.27577711),        ),    ]) # 定义一个损失函数,用于获取图片的特征,然后与提示文字的特征进行对比def clip_loss(image, text_features):    image_features = clip_model.encode_image(        tfms(image)    )  # 注意施加上面定义好的变换    input_normed = torch.nn.functional.normalize(image_features.       unsqueeze(1), dim=2)    embed_normed = torch.nn.functional.normalize(text_features.       unsqueeze(0), dim=2)    dists = (        input_normed.sub(embed_normed).norm(dim=2).div(2).           arcsin().pow(2).mul(2)    )  # 使用Squared Great Circle Distance计算距离    return dists.mean()

      下面是引导模型生成图像的过程,步骤与上述类似,只需要把color_loss()替换成CLIP的损失函数

 

prompt = "Red Rose (still life), red flower painting" # 读者可以探索一下这些超参数的影响guidance_scale = 8n_cuts = 4 # 这里使用稍微多一些的步数scheduler.set_timesteps(50) # 使用CLIP从提示文字中提取特征text = open_clip.tokenize([prompt]).to(device)with torch.no_grad(), torch.cuda.amp.autocast():    text_features = clip_model.encode_text(text) x = torch.randn(4, 3, 256, 256).to(    device)  for i, t in tqdm(enumerate(scheduler.timesteps)):     model_input = scheduler.scale_model_input(x, t)     # 预测噪声    with torch.no_grad():        noise_pred = image_pipe.unet(model_input, t)["sample"]     cond_grad = 0     for cut in range(n_cuts):         # 设置输入图像的requires_grad属性为True        x = x.detach().requires_grad_()         # 获得“去噪”后的图像        x0 = scheduler.step(noise_pred, t, x).pred_original_sample         # 计算损失值        loss = clip_loss(x0, text_features) * guidance_scale         # 获取梯度并使用n_cuts进行平均        cond_grad -= torch.autograd.grad(loss, x)[0] / n_cuts     if i % 25 == 0:        print("Step:", i, ", Guidance loss:", loss.item())     # 根据这个梯度更新x    alpha_bar = scheduler.alphas_cumprod[i]    x = (        x.detach() + cond_grad * alpha_bar.sqrt()    )  # 注意这里的缩放因子     # 使用调度器更新x    x = scheduler.step(noise_pred, t, x).prev_sample  grid = torchvision.utils.make_grid(x.detach(), nrow=4)im = grid.permute(1, 2, 0).cpu().clip(-1, 1) * 0.5 + 0.5Image.fromarray(np.array(im * 255).astype(np.uint8))

 

# 输出Step: 0 , Guidance loss: 7.418107986450195Step: 25 , Guidance loss: 7.085518836975098

图片

       上述生成的图像虽然不够完美,但可以调整一些超参数,比如梯度缩放因子alpha_bar.sqrt(),虽然理论上存在所谓的正确的缩放这些梯度方法,但在实践中仍需要实验来检验,下面介绍一些常用的方案:

 

plt.plot([1 for a in scheduler.alphas_cumprod], label="no scaling")plt.plot([a for a in scheduler.alphas_cumprod], label="alpha_bar")plt.plot([a.sqrt() for a in scheduler.alphas_cumprod],     label="alpha_bar.sqrt()")plt.plot(    [(1 - a).sqrt() for a in scheduler.alphas_cumprod], label="(1-     alpha_bar).sqrt()")plt.legend()

图片

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.mzph.cn/news/145017.shtml

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

接口自动化测试面试题

前言 前面总结了一篇关于接口测试的常规面试题,现在接口自动化测试用的比较多,也是被很多公司看好。那么想做接口自动化测试需要具备哪些能力呢? 也就是面试的过程中,面试官会考哪些问题,知道你是不是真的做过接口自动…

京东数据挖掘(京东运营数据分析):2023年宠物行业数据分析报告

随着社会经济的发展,人均收入水平逐渐提高,使得宠物成为越来越多家庭的成员,宠物数量不断增长。伴随养宠人群的增多,宠物相关产业的发展也不断升温,宠物经济规模持续增长。 根据鲸参谋平台的数据显示,在宠物…

传统游戏难产 育碧瞄向Web3

出品过《刺客信条》的游戏大厂育碧(Ubisoft)又在Web3游戏领域有了新动作。 首次试水NFT无功而返后,育碧(Ubisoft)战略创新实验室与Web3游戏网络Immutable达成合作,将利用Immutable 开发游戏的经验和及生态…

C++ 基础二

文章目录 四、流程控制语句4.1 选择结构4.1.1 if语句 4.1.2 三目运算符4.1.3 switch语句注意事项 4.1.4 if和switch的区别【CHAT】4.2 循环结构4.2.1 while循环语句4.2.2 do...while循环语句 4.2.3 for循环语句九九乘法表 4.3 跳转语句4.3.1 break语句4.3.2 continue语句4.3.3 …

python趣味编程-5分钟实现一个测验应用程序(含源码、步骤讲解)

Python测验是用 Python 编程语言编写的,这个关于 Python 编程的简单测验是一个简单的项目,用于测试一个人在给定主题考试中的知识能力。 Python 中的 Quiz项目仅包含用户端。用户必须先登录或注册才能开始Python 测验。 此外,还规定了解决问题的时间。用户应在时间结束前解…

计算机网络:网络层ARP协议

在实现IP通信时使用了两个地址:IP地址(网络层地址)和MAC地址(数据链路层地址) 问题:已知一个机器(主机或路由器)的IP地址,如何找到相应的MAC地址? 为了解决…

力扣刷题篇之数与位3

系列文章目录 目录 系列文章目录 前言 数学问题 总结 前言 本系列是个人力扣刷题汇总,本文是数与位。刷题顺序按照[力扣刷题攻略] Re:从零开始的力扣刷题生活 - 力扣(LeetCode) 数学问题 204. 计数质数 - 力扣(Le…

C# chatGPT API调用示例

# C# API现在需要Verify your phone number to create an API key using Newtonsoft.Json; using System.Text;class Program {static readonly HttpClient client new HttpClient();static async Task Main(){try{// 设置 API 密钥string apiKey "your api";clie…

Elasticsearch 8.9 Bulk批量给索引增加数据源码

一、相关API的handler二、RestBulkAction,组装bulkRequest调用TransportBulkAction三、TransportBulkAction 会把数据分发到数据该到的数据节点1、把数据按分片分组,按分片分组数据再发送到指定的数据节点(1) 计算此文档发往哪个分片1)根据索引是否是分区…

【Linux】vscode远程连接ubuntu失败

VSCode远程连接ubuntu服务器 这部分网上有很多,都烂大街了,自己搜吧。给个参考连接:VSCode远程连接ubuntu服务器 注意,这里我提前设置了免密登录。至于怎么设置远程免密登录,可以看其它帖子,比如这个。 …

Ps:锁定图层

使用“图层”面板上的锁定图层 Lock Layer功能可以完全或部分锁定图层以保护其内容。 比如,在完成某个图层后希望它不再被修改(包括不透明度和图层样式等),可将其完全锁定。 如果不想更改图像,但对其摆放位置还在犹豫不…

SpringBoot整合Quartz示例

数据表 加不加无所谓,如果需要重启服务器后重新执行所有JOB就把sql加上 如果不加表 将application.properties中的quartz数据库配置去掉 自己执行自己的逻辑来就好,大不了每次启动之后重新加载自己的逻辑 链接:https://pan.baidu.com/s/1KqOPYMfI4eHcEMxt5Bmt…

Unittest框架--自动化

Python中方法的传递 参数化 pip install parameterized -i https://pypi.douban.com/simple需求:验证Tpshop登录 # 断言参数化 import time import unittest from parameterized import parameterized from selenium import webdriver from selenium.webdriver.co…

小程序游戏、App游戏与H5游戏:三种不同的游戏开发与体验方式

在当今数字化的时代,游戏开发者面临着多种选择,以满足不同用户群体的需求。小程序游戏、App游戏和H5游戏是三种流行的游戏开发和发布方式,它们各自具有独特的特点和适用场景。 小程序游戏:轻巧便捷的社交体验 小程序游戏是近年来…

OpenCV快速入门:基本操作

文章目录 1. 像素操作1.1 像素统计1.2 两个图像之间的操作1.2.1 图像加法操作1.2.3 图像加权混合 1.3 二值化1.4 LUT(查找表)1.4.1 查找表原理1.4.2 代码演示 2 图像变换2.1 旋转操作2.1.1 旋转的基本原理2.1.2 代码实现 2.2 缩放操作2.3 平移操作2.3.1 …

第四代智能井盖传感器:万宾科技智能井盖位移监测方式一览

现在城市化水平不断提高,每个城市的井盖遍布在城市的街道上,是否能够实现常态化和系统化的管理,反映了一个城市治理现代化水平。而且近些年来住建部曾多次要求全国各个城市加强相关的井盖管理工作,作为基础设施重要的一个组成部分…

TP-LINK联洲面试题

文章目录 1.说一下微服务架构?2.微服务优缺点3.负载均衡的实现算法4.Redis集群部署方式?5.MySQL主从复制?5.1 配置流程5.2 优缺点分析6.口头手撕快排7.队列实现栈和栈实现队列7.1 队列实现栈7.2 栈实现队列8.进程有几种状态?9.Spring Boot Actuator?10.外键、主键和索引?…

python科研绘图:面积图

目录 1、面积图 2、堆积面积图 1、面积图 面积图是一种数据可视化图表,用于展示数据随时间或其他有序类别的变化趋势。它与折线图相似,但在展示数据变化的同时,面积图还强调了各个数据点之间的累积关系。这种图表通常通过在折线下方填充颜…

2023最新最全【虚幻4引擎】下载安装零基础教程

1、创建Epic Games账户 我们先打开浏览器,输入以下网址:unrealengine.com 随后点击【立即开始】 选择许可证类型,此处提供三种选项,分别是【游戏】、【非游戏】以及【私人定制】 第一类许可证适用于游戏和商业互动产品&#xff…

STM32笔记—USART

课外知识插入:STM32单片机extern全局变量_stm32全局变量-CSDN博客 如果你把temple定义在A中,然后让A.h和B.h包含在includes.h中,然后把includes.h放在A.c和B.c中单个编译是没有问题的,但是链接的时候会出现问题, “S…