【扩散模型(五)】IP-Adapter 源码详解3-推理代码

系列文章目录

  • 【扩散模型(一)】中介绍了 Stable Diffusion 可以被理解为重建分支(reconstruction branch)和条件分支(condition branch)
  • 【扩散模型(二)】IP-Adapter 从条件分支的视角,快速理解相关的可控生成研究
  • 【可控图像生成系列论文(一)】 简要介绍了 MimicBrush 的整体流程和方法;
  • 【可控图像生成系列论文(二)】 就MimicBrush 的具体模型结构训练数据纹理迁移进行了更详细的介绍。
  • 【可控图像生成系列论文(三)】介绍了一篇相对早期(2018年)的可控字体艺术化工作。
  • 【可控图像生成系列论文(四)】介绍了 IP-Adapter 具体是如何训练的?
  • 【可控图像生成系列论文(五)】ControlNet 和 IP-Adapter 之间的区别有哪些?
  • 【扩散模型(三)】IP-Adapter 源码详解1-训练输入 介绍了训练代码中的 image prompt 的输入部分,即 img projection 模块。
  • 【扩散模型(四)】IP-Adapter 源码详解2-训练核心(cross-attention)详细介绍 IP-Adapter 训练代码的核心部分,即插入 Unet 中的、针对 Image prompt 的 cross-attention 模块。

文章目录

  • 系列文章目录
  • 前言
  • 一、输入处理
  • 二、过 Unet
  • 三、Unet 中被替换的 CA


前言

这里以 /path/to/IP-Adapter/ip_adapter_demo.ipynb 中最基础的以图生图(Image Variations)为例:

SD1.5-IPA 的推理流程如下图所示,可被分为 3 个部分:

  1. 输入处理:对 img prompt 和 txt prompt 分别先得到 embedding 后再送入 SD 的 pipeline;
  2. 过 Unet:与一般输入 txt prompt 类似,通过 Unet 的各个模块;
  3. Unet 中的 CA:对于 img prompt 部分需要拆出来,单独过针对性的 k (to_k_ip)和 v(to_v_ip)。

其中的关键在第一部分,与一般将 txt prompt 直接送入 SD pipeline 不太一样,是先处理为 embedding 再送入 pipeline 的。
在这里插入图片描述

*图中的 bs 代表 batch size

一、输入处理

IP-Adapter 的推理代码核心是在 /path/to/IP-Adapter/ip_adapter/ip_adapter.py 文件的 IPAdapter 类的 generate() 函数中。

在这里插入图片描述

  1. 输入1: image prompt
    • 通过冻结住的 image encoder(CLIPImageProcessor 先预处理,再通过 CLIPVisionModelWithProjection)
    • 以及训练好的 image_proj_model(ImageProjModel)
  2. 输入1对应的输出1有:
    • image_prompt_embeds
    • uncond_image_prompt_embeds(纯 0 tensor 过一次 ImageProjModel)
# load image encoder
self.image_encoder = CLIPVisionModelWithProjection.from_pretrained(self.image_encoder_path).to(self.device, dtype=torch.float16
)
self.clip_image_processor = CLIPImageProcessor()
self.image_proj_model.load_state_dict(state_dict["image_proj"])# 从训好的权重中读取
...
clip_image = self.clip_image_processor(images=pil_image, return_tensors="pt").pixel_values
clip_image_embeds = self.image_encoder(clip_image.to(self.device, dtype=torch.float16)).image_embeds
image_prompt_embeds = self.image_proj_model(clip_image_embeds)
uncond_image_prompt_embeds = self.image_proj_model(torch.zeros_like(clip_image_embeds))
  1. 输入2: text prompt、negative_prompt(默认的 ['monochrome, lowres, bad anatomy, worst quality, low quality']

    • text prompt 通过 StableDiffusionPipeline 中的 .encode_prompt()
      • encode_prompt 中,对于直接文字的 prompt(str 字符串格式的),会先通过 tokenizer
      • 检查是否超过 clip 的长度
      • 通过 text_encoder (CLIPTextModel) 得到 prompt_embeds(文本特征)
    • negative_prompt 同样通过 tokenizer 和 text_encoder 得到 negative_prompt_embeds
  2. 输入2 对应的输出2有:

    • prompt_embeds_
    • negative_prompt_embeds_
  3. 输出1 的 image_prompt_embeds、uncond_image_prompt_embeds 分别和 输出2 prompt_embeds_、negative_prompt_embeds_ 在维度1上 torch.cat 后得到 self.pipe(第二次 encoder_prompt)的输入:prompt_embeds 和 negative_prompt_embeds。

with torch.inference_mode():prompt_embeds_, negative_prompt_embeds_ = self.pipe.encode_prompt(prompt,device=self.device,num_images_per_prompt=num_samples,do_classifier_free_guidance=True,negative_prompt=negative_prompt,)prompt_embeds = torch.cat([prompt_embeds_, image_prompt_embeds], dim=1)negative_prompt_embeds = torch.cat([negative_prompt_embeds_, uncond_image_prompt_embeds], dim=1)

二、过 Unet

  1. 按照 prompt 和 negative_prompt 为 None、将 prompt_embeds 和 negative_prompt_embeds 作为输入,通过 encode_prompt(),
    • 得到进一步的 prompt_embeds 和 negative_prompt_embeds
  2. prompt_embeds 和 negative_prompt_embeds 做 torch.cat 是在维度 0 上,这是针对 do_classifier_free_guidance 的操作,避免做两次前向传播。
 # For classifier free guidance, we need to do two forward passes.# Here we concatenate the unconditional and text embeddings into a single batch# to avoid doing two forward passesif self.do_classifier_free_guidance:prompt_embeds = torch.cat([negative_prompt_embeds, prompt_embeds])
  1. 接下来的路径和 SD1.5 基本的推理流程基本一致,除了被替换的 Cross-Attn(CA)。
    在这里插入图片描述

三、Unet 中被替换的 CA

该部分应该无需多说,与训练部分一致,即增加一个针对 image prompt 的 k 和 v。上篇 也有相应代码的介绍。

在这里插入图片描述

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.mzph.cn/web/47678.shtml

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

【OAuth2系列】集成微信小程序登录到 Spring Security OAuth 2.0

作者:后端小肥肠 创作不易,未经允许严禁转载。 姊妹篇: 【Spring Security系列】权限之旅:SpringSecurity小程序登录深度探索_spring security 微信小程序登录-CSDN博客 目录 1. 前言 2. 总体登录流程 3. 数据表设计 3.1. sys…

Python测试服务器连接的实战代码

大家好,我是爱编程的喵喵。双985硕士毕业,现担任全栈工程师一职,热衷于将数据思维应用到工作与生活中。从事机器学习以及相关的前后端开发工作。曾在阿里云、科大讯飞、CCF等比赛获得多次Top名次。现为CSDN博客专家、人工智能领域优质创作者。喜欢通过博客创作的方式对所学的…

windows server——4.安装DNS管理器

windows server——4.安装DNS管理器 一、准备二、安装DNS管理器1.打开服务器管理器2.添加dns服务器 三、验证 一、准备 windows server电脑(已安装IIS) 静态网站数据包 二、安装DNS管理器 1.打开服务器管理器 2.添加dns服务器 点击管理——添加角色和…

Java语言程序设计基础篇_编程练习题*15.16(两个可移动的顶点以及它们间的距离)

*15.16(两个可移动的顶点以及它们间的距离) 请编写一个程序,显示两个分别位于(40,40)和(120,150) 的半径为10的圆,并用一条直线连接两个圆,如图15-28b所示。圆之间的距离显示在直线上。 用户可以拖动圆&am…

指标平台新书发布:智能驱动,数据管研用一体化新革命

在当下数字化经营的市场环境中,企业面临着前所未有的挑战和机遇。随着业务的不断扩展和市场的日益复杂,数据作为企业的核心资产,其重要性愈发凸显。然而“数据孤岛和数据不清晰”这一问题却成为了制约企业数字化进程和竞争力的关键因素。为了…

Windows下载、安装、部署Redis服务的详细流程

本文介绍在Windows电脑中,下载、安装、部署并运行Redis数据库服务的方法。 Redis(Remote Dictionary Server)是一个开源、高性能的键值存储系统,最初由Salvatore Sanfilippo在2009年发布,并由Redis Labs维护。Redis因其…

软考高级第四版备考--第27天(项目工作绩效域)

核心概念: 项目工作可使团队保持专注,并使项目活动顺序进行实现的预期目标主要包含:高效且有效的项目绩效;适合项目和环境的项目过程;干系人适当的沟通和参与;对实物资源进行了有效的管理;对采购进行有效管…

创建React应用的2种方式

一、使用create-react-app创建 1、全局安装脚手架库: create-react-app npm i -g create-react-app 2、创建项目:create-react-app (my-app)项目名称; create-react-app my-app 3、进入项目文件夹 cd my-app 4、运行项目 npm start 二、使用vite创建&…

<数据集>水果识别数据集<目标检测>

数据集格式:VOCYOLO格式 图片数量:10012张 标注数量(xml文件个数):10012 标注数量(txt文件个数):10012 标注类别数:7 标注类别名称:[Watermelon, Orange, Grape, Apple, peach, Banana, Pineapple] 序…

自建网站统计工具 Umami 替代 Google Analytics

本文首发于只抄博客,欢迎点击原文链接了解更多内容。 前言 Umami 是一款开源的网站统计工具,与 Google Analytics 相比更加的轻量,且不会收集网站用户的个人信息。同时,Umami 的仪表盘界面简洁,UI 精美,方便我们查看网站的历史统计数据。 Umami 使用方式也与 Google Ana…

n7.Nginx 第三方模块

Nginx 第三方模块 第三模块是对nginx 的功能扩展,第三方模块需要在编译安装Nginx 的时候使用参数–add-modulePATH指定路径添加,有的模块是由公司的开发人员针对业务需求定制开发的,有的模块是开 源爱好者开发好之后上传到github进行开源的模…

《0基础》学习Python——第二十四讲__爬虫/<7>深度爬取

一、深度爬取 深度爬取是指在网络爬虫中,获取网页上的所有链接并递归地访问这些链接,以获取更深层次的页面数据。 通常,一个简单的爬虫只会获取到初始页面上的链接,并不会进一步访问这些链接上的其他页面。而深度爬取则会不断地获…

计数排序(桶排序思想)

这段代码是一个计数排序算法的实现。计数排序是一种非比较排序算法,适用于整数数组,其时间复杂度为O(nk),其中n是数组长度,k是数组中的最大值。以下是该算法的步骤: 首先检查输入数组是否为空或长度小于2,…

python os库使用教程

os库使用教程 1.创建文件夹os.path.exists()检查文件是否存在os.listdir查看文件夹下的所有文件filename.endswith()查看文件列表的png或者txt结尾的所有文件shutil.move移动目标到文件夹 1.创建文件夹 先在盘符里创建一个文件用来演示,我这里…

前端JS特效第48集:terseBanner焦点图轮播插件

terseBanner焦点图轮播插件&#xff0c;先来看看效果&#xff1a; 部分核心的代码如下(全部代码在文章末尾)&#xff1a; <!DOCTYPE html> <html lang"zh"> <head><meta charset"UTF-8"><meta http-equiv"X-UA-Compatibl…

word转pdf图变得模糊(解决)

日常小记 目录问题解决方案 结语 目录 问题 word转pdf图变得模糊后图变得不清晰 解决方案 首先在ppt中进行画图其次复制该图的所有元素直接复制到word&#xff0c;在粘贴中选中选择性粘贴&#xff0c;增强性图形即可解决&#xff01;&#xff01;&#xff01; 其余方案 可以…

Go语言 流程控制和循环语句

本文主要内容为Go语言中流程控制语句和循环语句介绍及示例。 目录 流程控制语句 If条件语句 If使用规则 表达式语句 Switch语句 使用fallthrough 判断表达式结果 从命令行获取参数 For循环语句 简单循环 省略循环条件 For无限循环 For循环中的continue 新型for循…

java中处理stream.filter()

在Java中&#xff0c;stream.filter方法用于对流中的元素进行筛选。filter方法接受一个Predicate&#xff08;一个返回布尔值的函数&#xff09;&#xff0c;然后返回一个包含所有匹配元素的新流。 使用场景 假设有一个包含多个元素的集合&#xff0c;需要对其中的元素进行筛…

HTTPServer改进思路1

Nginx源码思考项目改进 架构模式 事件驱动架构(EDA&#xff09;用于处理大量并发连接和IO操作 优点&#xff1a;高效处理大量并发请求&#xff0c;减少线程切换和阻塞调用技术实现&#xff1a;直接使用EPOLL&#xff0c;参考Node.js的http服务器 网络通信 协议&#xff1a;HTT…

Java 随笔记: 集合与泛型

文章目录 1. 集合框架概述2. 集合接口2.1 Collection 接口2.2 List 接口2.3 Set 接口2.4 Map 接口 3. 集合的常用操作3.1 添加元素3.2 删除元素3.3 遍历元素3.4 判断大小3.5 判断是否为空 4. 迭代器4.1 迭代器的作用4.2 迭代器的使用4.3 迭代器与增强 for 循环4.4 迭代器的注意…