【扩散模型(五)】IP-Adapter 源码详解3-推理代码

系列文章目录

  • 【扩散模型(一)】中介绍了 Stable Diffusion 可以被理解为重建分支(reconstruction branch)和条件分支(condition branch)
  • 【扩散模型(二)】IP-Adapter 从条件分支的视角,快速理解相关的可控生成研究
  • 【可控图像生成系列论文(一)】 简要介绍了 MimicBrush 的整体流程和方法;
  • 【可控图像生成系列论文(二)】 就MimicBrush 的具体模型结构训练数据纹理迁移进行了更详细的介绍。
  • 【可控图像生成系列论文(三)】介绍了一篇相对早期(2018年)的可控字体艺术化工作。
  • 【可控图像生成系列论文(四)】介绍了 IP-Adapter 具体是如何训练的?
  • 【可控图像生成系列论文(五)】ControlNet 和 IP-Adapter 之间的区别有哪些?
  • 【扩散模型(三)】IP-Adapter 源码详解1-训练输入 介绍了训练代码中的 image prompt 的输入部分,即 img projection 模块。
  • 【扩散模型(四)】IP-Adapter 源码详解2-训练核心(cross-attention)详细介绍 IP-Adapter 训练代码的核心部分,即插入 Unet 中的、针对 Image prompt 的 cross-attention 模块。

文章目录

  • 系列文章目录
  • 前言
  • 一、输入处理
  • 二、过 Unet
  • 三、Unet 中被替换的 CA


前言

这里以 /path/to/IP-Adapter/ip_adapter_demo.ipynb 中最基础的以图生图(Image Variations)为例:

SD1.5-IPA 的推理流程如下图所示,可被分为 3 个部分:

  1. 输入处理:对 img prompt 和 txt prompt 分别先得到 embedding 后再送入 SD 的 pipeline;
  2. 过 Unet:与一般输入 txt prompt 类似,通过 Unet 的各个模块;
  3. Unet 中的 CA:对于 img prompt 部分需要拆出来,单独过针对性的 k (to_k_ip)和 v(to_v_ip)。

其中的关键在第一部分,与一般将 txt prompt 直接送入 SD pipeline 不太一样,是先处理为 embedding 再送入 pipeline 的。
在这里插入图片描述

*图中的 bs 代表 batch size

一、输入处理

IP-Adapter 的推理代码核心是在 /path/to/IP-Adapter/ip_adapter/ip_adapter.py 文件的 IPAdapter 类的 generate() 函数中。

在这里插入图片描述

  1. 输入1: image prompt
    • 通过冻结住的 image encoder(CLIPImageProcessor 先预处理,再通过 CLIPVisionModelWithProjection)
    • 以及训练好的 image_proj_model(ImageProjModel)
  2. 输入1对应的输出1有:
    • image_prompt_embeds
    • uncond_image_prompt_embeds(纯 0 tensor 过一次 ImageProjModel)
# load image encoder
self.image_encoder = CLIPVisionModelWithProjection.from_pretrained(self.image_encoder_path).to(self.device, dtype=torch.float16
)
self.clip_image_processor = CLIPImageProcessor()
self.image_proj_model.load_state_dict(state_dict["image_proj"])# 从训好的权重中读取
...
clip_image = self.clip_image_processor(images=pil_image, return_tensors="pt").pixel_values
clip_image_embeds = self.image_encoder(clip_image.to(self.device, dtype=torch.float16)).image_embeds
image_prompt_embeds = self.image_proj_model(clip_image_embeds)
uncond_image_prompt_embeds = self.image_proj_model(torch.zeros_like(clip_image_embeds))
  1. 输入2: text prompt、negative_prompt(默认的 ['monochrome, lowres, bad anatomy, worst quality, low quality']

    • text prompt 通过 StableDiffusionPipeline 中的 .encode_prompt()
      • encode_prompt 中,对于直接文字的 prompt(str 字符串格式的),会先通过 tokenizer
      • 检查是否超过 clip 的长度
      • 通过 text_encoder (CLIPTextModel) 得到 prompt_embeds(文本特征)
    • negative_prompt 同样通过 tokenizer 和 text_encoder 得到 negative_prompt_embeds
  2. 输入2 对应的输出2有:

    • prompt_embeds_
    • negative_prompt_embeds_
  3. 输出1 的 image_prompt_embeds、uncond_image_prompt_embeds 分别和 输出2 prompt_embeds_、negative_prompt_embeds_ 在维度1上 torch.cat 后得到 self.pipe(第二次 encoder_prompt)的输入:prompt_embeds 和 negative_prompt_embeds。

with torch.inference_mode():prompt_embeds_, negative_prompt_embeds_ = self.pipe.encode_prompt(prompt,device=self.device,num_images_per_prompt=num_samples,do_classifier_free_guidance=True,negative_prompt=negative_prompt,)prompt_embeds = torch.cat([prompt_embeds_, image_prompt_embeds], dim=1)negative_prompt_embeds = torch.cat([negative_prompt_embeds_, uncond_image_prompt_embeds], dim=1)

二、过 Unet

  1. 按照 prompt 和 negative_prompt 为 None、将 prompt_embeds 和 negative_prompt_embeds 作为输入,通过 encode_prompt(),
    • 得到进一步的 prompt_embeds 和 negative_prompt_embeds
  2. prompt_embeds 和 negative_prompt_embeds 做 torch.cat 是在维度 0 上,这是针对 do_classifier_free_guidance 的操作,避免做两次前向传播。
 # For classifier free guidance, we need to do two forward passes.# Here we concatenate the unconditional and text embeddings into a single batch# to avoid doing two forward passesif self.do_classifier_free_guidance:prompt_embeds = torch.cat([negative_prompt_embeds, prompt_embeds])
  1. 接下来的路径和 SD1.5 基本的推理流程基本一致,除了被替换的 Cross-Attn(CA)。
    在这里插入图片描述

三、Unet 中被替换的 CA

该部分应该无需多说,与训练部分一致,即增加一个针对 image prompt 的 k 和 v。上篇 也有相应代码的介绍。

在这里插入图片描述

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.mzph.cn/web/47678.shtml

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

【OAuth2系列】集成微信小程序登录到 Spring Security OAuth 2.0

作者:后端小肥肠 创作不易,未经允许严禁转载。 姊妹篇: 【Spring Security系列】权限之旅:SpringSecurity小程序登录深度探索_spring security 微信小程序登录-CSDN博客 目录 1. 前言 2. 总体登录流程 3. 数据表设计 3.1. sys…

Python测试服务器连接的实战代码

大家好,我是爱编程的喵喵。双985硕士毕业,现担任全栈工程师一职,热衷于将数据思维应用到工作与生活中。从事机器学习以及相关的前后端开发工作。曾在阿里云、科大讯飞、CCF等比赛获得多次Top名次。现为CSDN博客专家、人工智能领域优质创作者。喜欢通过博客创作的方式对所学的…

windows server——4.安装DNS管理器

windows server——4.安装DNS管理器 一、准备二、安装DNS管理器1.打开服务器管理器2.添加dns服务器 三、验证 一、准备 windows server电脑(已安装IIS) 静态网站数据包 二、安装DNS管理器 1.打开服务器管理器 2.添加dns服务器 点击管理——添加角色和…

Java语言程序设计基础篇_编程练习题*15.16(两个可移动的顶点以及它们间的距离)

*15.16(两个可移动的顶点以及它们间的距离) 请编写一个程序,显示两个分别位于(40,40)和(120,150) 的半径为10的圆,并用一条直线连接两个圆,如图15-28b所示。圆之间的距离显示在直线上。 用户可以拖动圆&am…

指标平台新书发布:智能驱动,数据管研用一体化新革命

在当下数字化经营的市场环境中,企业面临着前所未有的挑战和机遇。随着业务的不断扩展和市场的日益复杂,数据作为企业的核心资产,其重要性愈发凸显。然而“数据孤岛和数据不清晰”这一问题却成为了制约企业数字化进程和竞争力的关键因素。为了…

Windows下载、安装、部署Redis服务的详细流程

本文介绍在Windows电脑中,下载、安装、部署并运行Redis数据库服务的方法。 Redis(Remote Dictionary Server)是一个开源、高性能的键值存储系统,最初由Salvatore Sanfilippo在2009年发布,并由Redis Labs维护。Redis因其…

<数据集>水果识别数据集<目标检测>

数据集格式:VOCYOLO格式 图片数量:10012张 标注数量(xml文件个数):10012 标注数量(txt文件个数):10012 标注类别数:7 标注类别名称:[Watermelon, Orange, Grape, Apple, peach, Banana, Pineapple] 序…

自建网站统计工具 Umami 替代 Google Analytics

本文首发于只抄博客,欢迎点击原文链接了解更多内容。 前言 Umami 是一款开源的网站统计工具,与 Google Analytics 相比更加的轻量,且不会收集网站用户的个人信息。同时,Umami 的仪表盘界面简洁,UI 精美,方便我们查看网站的历史统计数据。 Umami 使用方式也与 Google Ana…

n7.Nginx 第三方模块

Nginx 第三方模块 第三模块是对nginx 的功能扩展,第三方模块需要在编译安装Nginx 的时候使用参数–add-modulePATH指定路径添加,有的模块是由公司的开发人员针对业务需求定制开发的,有的模块是开 源爱好者开发好之后上传到github进行开源的模…

《0基础》学习Python——第二十四讲__爬虫/<7>深度爬取

一、深度爬取 深度爬取是指在网络爬虫中,获取网页上的所有链接并递归地访问这些链接,以获取更深层次的页面数据。 通常,一个简单的爬虫只会获取到初始页面上的链接,并不会进一步访问这些链接上的其他页面。而深度爬取则会不断地获…

python os库使用教程

os库使用教程 1.创建文件夹os.path.exists()检查文件是否存在os.listdir查看文件夹下的所有文件filename.endswith()查看文件列表的png或者txt结尾的所有文件shutil.move移动目标到文件夹 1.创建文件夹 先在盘符里创建一个文件用来演示,我这里…

前端JS特效第48集:terseBanner焦点图轮播插件

terseBanner焦点图轮播插件&#xff0c;先来看看效果&#xff1a; 部分核心的代码如下(全部代码在文章末尾)&#xff1a; <!DOCTYPE html> <html lang"zh"> <head><meta charset"UTF-8"><meta http-equiv"X-UA-Compatibl…

HTTPServer改进思路1

Nginx源码思考项目改进 架构模式 事件驱动架构(EDA&#xff09;用于处理大量并发连接和IO操作 优点&#xff1a;高效处理大量并发请求&#xff0c;减少线程切换和阻塞调用技术实现&#xff1a;直接使用EPOLL&#xff0c;参考Node.js的http服务器 网络通信 协议&#xff1a;HTT…

【LeetCode】对称二叉树

目录 一、题目二、解法完整代码 一、题目 给你一个二叉树的根节点 root &#xff0c; 检查它是否轴对称。 示例 1&#xff1a; 输入&#xff1a;root [1,2,2,3,4,4,3] 输出&#xff1a;true 示例 2&#xff1a; 输入&#xff1a;root [1,2,2,null,3,null,3] 输出&#…

友力科技数据中心搬迁方案

将当前运行机房中的所有设备、应用系统安全搬迁至新数据中心机房&#xff0c;实现平滑切换、平稳过渡&#xff0c;最大限度地降低搬迁工作对业务的影响。 为了确保企事业单位能够顺利完成数据中心机房搬迁工作&#xff0c;我们根据实际经验提供了4个基本原则&#xff0c;希望能…

异步电机矢量控制matlab simulink

1、内容简介 略 86-可以交流、咨询、答疑 异步电机、矢量控制 2、内容说明 略 3、仿真分析 略 4、参考论文 略

YOLOv2小白精讲

YOLOv2是一个集成了分类和检测任务的神经网络&#xff0c;它将目标检测和分类任务统一在一个单一的网络中进行处理。 本文在yolov1的基础上&#xff0c;对yolov2的网络结构和改进部分进行讲解。yolov1的知识点可以看我另外一篇博客&#xff08;yolov1基础精讲-CSDN博客&#xf…

设计模式-抽象工厂

抽象工厂属于创建型模式。 抽象工厂和工厂设计模式的区别&#xff1a; 工厂模式的是设计模式中最简单的一种设计模式&#xff0c;主要设计思想是&#xff0c;分离对象的创建和使用&#xff0c;在Java中&#xff0c;如果需要使用一个对象时&#xff0c;需要new Class()&#xff…

RAG-LLM Survey

大模型虽然厉害&#xff0c;但是存在着幻觉、知识陈旧等问题。检索增强生成&#xff08;Retrieval-Augmented Generation, RAG&#xff09;可以通过挂载外部知识库&#xff0c;来提升生成内容的准确性和可信度。了解一个研究方向的最快的方法&#xff0c;就是阅读相关的综述。今…

Python数据可视化------动态柱状图

一、基础柱状图 # 基础柱状图 # 导包 from pyecharts.charts import Bar from pyecharts.options import *# 构建柱状图 bar Bar() # 添加数据&#xff08;列表&#xff09; x_list ["张三", "李四", "王五", "赵六"] y_list [50,…