详细解读个性化定制大杀器IP-Adapter代码

Diffusion models代码解读:入门与实战

前言:IP-Adapter作为Diffusion Models最成功的技术之一,已经在诸多互联网应用中落地。介绍IP-Adapter原理和应用的博客有很多,但是逐行详细解读代码的博客很少。这篇博客从细节出发,结合原理详细解读个性化定制大杀器IP-Adapter代码。

目录

原理概述

代码详解

冻结模型

Image Projection

注意力替换

新注意力计算

打包IP-Adapter

推理

源码地址


原理概述

一句话概括:原有的Cross Attention计算是用text condition计算的;IP-Adapter在原有Cross Attention计算上加上了image condition。其中Q共用,K V重新计算。

在整个训练过程中,冻结了最初的 UNet 模型,所以在上述解耦交叉注意力中,只有W′k,W′v 两个参数是可训练的。而且W′k,W′v 的权重从原来对应的Cross Attention初始化。

代码详解

冻结模型

这张图中蓝色部分全部冻结:

    unet.requires_grad_(False)vae.requires_grad_(False)text_encoder.requires_grad_(False)image_encoder.requires_grad_(False)

红色部分可训练:

params_to_opt = itertools.chain(ip_adapter.image_proj_model.parameters(),  ip_adapter.adapter_modules.parameters())

Image Projection

对应于下图中的这个部分,由一个Linear层和一个Norm组成:

    image_proj_model = ImageProjModel(cross_attention_dim=unet.config.cross_attention_dim,clip_embeddings_dim=image_encoder.config.projection_dim,clip_extra_context_tokens=4,)
class ImageProjModel(torch.nn.Module):"""Projection Model"""def __init__(self, cross_attention_dim=1024, clip_embeddings_dim=1024, clip_extra_context_tokens=4):super().__init__()self.generator = Noneself.cross_attention_dim = cross_attention_dimself.clip_extra_context_tokens = clip_extra_context_tokensself.proj = torch.nn.Linear(clip_embeddings_dim, self.clip_extra_context_tokens * cross_attention_dim)self.norm = torch.nn.LayerNorm(cross_attention_dim)def forward(self, image_embeds):embeds = image_embedsclip_extra_context_tokens = self.proj(embeds).reshape(-1, self.clip_extra_context_tokens, self.cross_attention_dim)clip_extra_context_tokens = self.norm(clip_extra_context_tokens)return clip_extra_context_tokens

经过这个模型得到Image Feature,它的维度需要与计算Attention的维度对齐。

注意力替换

在原始的Unet模型中,需要把Cross Attention的计算替换成新的,但是self- attention不变。

其中attn1.processor 是 self- attention,这部分不变;attn2.processor代表交叉注意力层,这部分替换!

W′k,W′v 的权重从原来对应的Cross Attention初始化!也就是代码中的:

        if cross_attention_dim is None:attn_procs[name] = AttnProcessor()else:layer_name = name.split(".processor")[0]weights = {"to_k_ip.weight": unet_sd[layer_name + ".to_k.weight"],"to_v_ip.weight": unet_sd[layer_name + ".to_v.weight"],}attn_procs[name] = IPAttnProcessor(hidden_size=hidden_size, cross_attention_dim=cross_attention_dim)attn_procs[name].load_state_dict(weights)

这部分的完整代码如下:

    attn_procs = {}unet_sd = unet.state_dict()for name in unet.attn_processors.keys():cross_attention_dim = None if name.endswith("attn1.processor") else unet.config.cross_attention_dimif name.startswith("mid_block"):hidden_size = unet.config.block_out_channels[-1]elif name.startswith("up_blocks"):block_id = int(name[len("up_blocks.")])hidden_size = list(reversed(unet.config.block_out_channels))[block_id]elif name.startswith("down_blocks"):block_id = int(name[len("down_blocks.")])hidden_size = unet.config.block_out_channels[block_id]if cross_attention_dim is None:attn_procs[name] = AttnProcessor()else:layer_name = name.split(".processor")[0]weights = {"to_k_ip.weight": unet_sd[layer_name + ".to_k.weight"],"to_v_ip.weight": unet_sd[layer_name + ".to_v.weight"],}attn_procs[name] = IPAttnProcessor(hidden_size=hidden_size, cross_attention_dim=cross_attention_dim)attn_procs[name].load_state_dict(weights)unet.set_attn_processor(attn_procs)adapter_modules = torch.nn.ModuleList(unet.attn_processors.values())

新注意力计算

现在已经在Unet中把原来的Cross注意力计算都替换并且组装好了。

下一步我们看看被替换的注意力是如何计算的,也就是看看IPAttnProcessor 如何实现。

class IPAttnProcessor(nn.Module):r"""Attention processor for IP-Adapater.Args:hidden_size (`int`):The hidden size of the attention layer.cross_attention_dim (`int`):The number of channels in the `encoder_hidden_states`.scale (`float`, defaults to 1.0):the weight scale of image prompt.num_tokens (`int`, defaults to 4 when do ip_adapter_plus it should be 16):The context length of the image features."""def __init__(self, hidden_size, cross_attention_dim=None, scale=1.0, num_tokens=4):super().__init__()self.hidden_size = hidden_sizeself.cross_attention_dim = cross_attention_dimself.scale = scaleself.num_tokens = num_tokensself.to_k_ip = nn.Linear(cross_attention_dim or hidden_size, hidden_size, bias=False)self.to_v_ip = nn.Linear(cross_attention_dim or hidden_size, hidden_size, bias=False)def __call__(self,attn,hidden_states,encoder_hidden_states=None,attention_mask=None,temb=None,*args,**kwargs,):residual = hidden_statesif attn.spatial_norm is not None:hidden_states = attn.spatial_norm(hidden_states, temb)input_ndim = hidden_states.ndimif input_ndim == 4:batch_size, channel, height, width = hidden_states.shapehidden_states = hidden_states.view(batch_size, channel, height * width).transpose(1, 2)batch_size, sequence_length, _ = (hidden_states.shape if encoder_hidden_states is None else encoder_hidden_states.shape)attention_mask = attn.prepare_attention_mask(attention_mask, sequence_length, batch_size)if attn.group_norm is not None:hidden_states = attn.group_norm(hidden_states.transpose(1, 2)).transpose(1, 2)query = attn.to_q(hidden_states)if encoder_hidden_states is None:encoder_hidden_states = hidden_stateselse:# get encoder_hidden_states, ip_hidden_statesend_pos = encoder_hidden_states.shape[1] - self.num_tokensencoder_hidden_states, ip_hidden_states = (encoder_hidden_states[:, :end_pos, :],encoder_hidden_states[:, end_pos:, :],)if attn.norm_cross:encoder_hidden_states = attn.norm_encoder_hidden_states(encoder_hidden_states)key = attn.to_k(encoder_hidden_states)value = attn.to_v(encoder_hidden_states)query = attn.head_to_batch_dim(query)key = attn.head_to_batch_dim(key)value = attn.head_to_batch_dim(value)attention_probs = attn.get_attention_scores(query, key, attention_mask)hidden_states = torch.bmm(attention_probs, value)hidden_states = attn.batch_to_head_dim(hidden_states)# for ip-adapterip_key = self.to_k_ip(ip_hidden_states)ip_value = self.to_v_ip(ip_hidden_states)ip_key = attn.head_to_batch_dim(ip_key)ip_value = attn.head_to_batch_dim(ip_value)ip_attention_probs = attn.get_attention_scores(query, ip_key, None)self.attn_map = ip_attention_probsip_hidden_states = torch.bmm(ip_attention_probs, ip_value)ip_hidden_states = attn.batch_to_head_dim(ip_hidden_states)hidden_states = hidden_states + self.scale * ip_hidden_states# linear projhidden_states = attn.to_out[0](hidden_states)# dropouthidden_states = attn.to_out[1](hidden_states)if input_ndim == 4:hidden_states = hidden_states.transpose(-1, -2).reshape(batch_size, channel, height, width)if attn.residual_connection:hidden_states = hidden_states + residualhidden_states = hidden_states / attn.rescale_output_factorreturn hidden_states

核心是计算出ip_hidden_states,然后与原来的hidden_state相加:

        ip_key = self.to_k_ip(ip_hidden_states)ip_value = self.to_v_ip(ip_hidden_states)ip_key = attn.head_to_batch_dim(ip_key)ip_value = attn.head_to_batch_dim(ip_value)ip_attention_probs = attn.get_attention_scores(query, ip_key, None)self.attn_map = ip_attention_probsip_hidden_states = torch.bmm(ip_attention_probs, ip_value)ip_hidden_states = attn.batch_to_head_dim(ip_hidden_states)

相加的时候有一个scale,这个scale是一个固定的参数1:

hidden_states = hidden_states + self.scale * ip_hidden_states

打包IP-Adapter

因为是用accelerate训练的,所以把可训练的部分都打包成一个单独的类 ip_adapter,用accelerator包装一下,就可以很方便更新梯度和保存权重。

ip_adapter, optimizer, train_dataloader = accelerator.prepare(ip_adapter, optimizer, train_dataloader)

推理

累了,下次再写...

    def generate(self,pil_image=None,clip_image_embeds=None,prompt=None,negative_prompt=None,scale=1.0,num_samples=4,seed=None,guidance_scale=7.5,num_inference_steps=30,**kwargs,):self.set_scale(scale)if pil_image is not None:num_prompts = 1 if isinstance(pil_image, Image.Image) else len(pil_image)else:num_prompts = clip_image_embeds.size(0)if prompt is None:prompt = "best quality, high quality"if negative_prompt is None:negative_prompt = "monochrome, lowres, bad anatomy, worst quality, low quality"if not isinstance(prompt, List):prompt = [prompt] * num_promptsif not isinstance(negative_prompt, List):negative_prompt = [negative_prompt] * num_promptsimage_prompt_embeds, uncond_image_prompt_embeds = self.get_image_embeds(pil_image=pil_image, clip_image_embeds=clip_image_embeds)bs_embed, seq_len, _ = image_prompt_embeds.shapeimage_prompt_embeds = image_prompt_embeds.repeat(1, num_samples, 1)image_prompt_embeds = image_prompt_embeds.view(bs_embed * num_samples, seq_len, -1)uncond_image_prompt_embeds = uncond_image_prompt_embeds.repeat(1, num_samples, 1)uncond_image_prompt_embeds = uncond_image_prompt_embeds.view(bs_embed * num_samples, seq_len, -1)with torch.inference_mode():prompt_embeds_, negative_prompt_embeds_ = self.pipe.encode_prompt(prompt,device=self.device,num_images_per_prompt=num_samples,do_classifier_free_guidance=True,negative_prompt=negative_prompt,)prompt_embeds = torch.cat([prompt_embeds_, image_prompt_embeds], dim=1)negative_prompt_embeds = torch.cat([negative_prompt_embeds_, uncond_image_prompt_embeds], dim=1)generator = get_generator(seed, self.device)images = self.pipe(prompt_embeds=prompt_embeds,negative_prompt_embeds=negative_prompt_embeds,guidance_scale=guidance_scale,num_inference_steps=num_inference_steps,generator=generator,**kwargs,).imagesreturn images

源码地址

IP-Adapter/tutorial_train.py at main · tencent-ailab/IP-Adapter · GitHub

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.mzph.cn/web/58382.shtml

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

数据采集之scrapy框架2

本博文使用自动化爬虫框架完成微信开放社区文档信息的爬取(重点理解 scrapy 框架自动化爬 虫构建过程,能够分析 LinkExtractor 和 Rule 规则的基本用法) 包结构目录如下图所示: 主要代码: ( items.p…

深⼊理解指针(2)

目录 1. const修饰指针及变量 2. 野指针 3. assert断⾔ 4. 指针的传址调⽤ 一 const修饰指针及变量(const是场属性——不能改变的属性) 1 const修饰变量 那怎么证明被const修饰的变量本质还是变量呢? 上面我们绕过n,使…

每日科技资讯:2024年11月06日【龙】农历十月初六 ---文末送书

目录 1.OpenAI因算力瓶颈暂缓GPT-5发布 合作芯片开发寻求突破2.现在,𝕏 允许被你屏蔽的人继续查看你的帖子3.硬刚Intel与AMD!NVIDIA明年推出PC芯片4.苹果停止签署 iOS 18.0.1,不再允许从 18.1 降级5.Nvidia 加入道琼斯指数成份股 …

swoole扩展安装--入门篇

对于php来说,swoole是个强大的补充扩展。这是我第3次写swoole扩展安装,这次基于opencloudos8系统,php使用8.2。 安装swoole扩展首先想到的是用宝塔来安装,毕竟安装方便,还能统一管理。虽然获得swoole版本不是最新的&am…

【大模型开发指南】llamaindex配置deepseek、jina embedding及chromadb实现本地RAG及知识库(win系统、CPU适配)

说一些坑,本来之前准备用milvus,但是发现win搞不了(docker都配好了)。然后转头搞chromadb。这里面还有就是embedding一般都是本地部署,但我电脑是cpu的没法玩,我就选了jina的embedding性能较优(…

pyspark基础准备

1.前言介绍 学习目标:了解什么是Speak、PySpark,了解为什么学习PySpark,了解课程是如何和大数据开发方向进行衔接 使用pyspark库所写出来的代码,既可以在电脑上简单运行,进行数据分析处理,又可以把代码无缝…

数据库基础(4) . 数据库结构

2.基础结构 2.1.结构及名称 数据库 database 表空间 tablespaces(Oracle) 表格 table 字段 column 记录 record 值 value 2.2.数据库 database 在配置文件中指定存放位置 # 设置mysql数据库的数据的存放目录 datadirD:\MySQL\mysql-8.0.16-winx64\data每个数据库对应…

Meme 币生态全景图分析:如何获得超额收益?

近期,BTC 再次突破 7 万美元大关,市场上贪婪指数再次达到 80,而 Meme 币往往是每次牛市冲锋的号角,比如 $GOAT 5 天内价格一度上涨超 1 万倍。通过对当前市场 TOP 25 Meme 币的交易数据分析,我们发现了几个值得关注的市…

数据结构之二叉树——堆 详解(含代码实现)

1.堆 如果有一个关键码的集合 K { , , , … ,},把它的所有元素按完全二叉树的顺序存储方式存储 在一个一维数组中,则称为小堆( 或大堆 ) 。将根节点最大的堆叫做最大堆或大根堆,根节点最小的…

高级 <HarmonyOS主题课>构建华为支付服务的课后习题

五色令人目盲&#xff1b; 五音令人耳聋&#xff1b; 五味令人口爽&#xff1b; 驰骋畋猎&#xff0c;令人心发狂&#xff1b; 难得之货&#xff0c;令人行妨&#xff1b; 是以圣人为腹不为目&#xff0c;故去彼取此。 本篇内容主要来自&#xff1a;<HarmonyOS主题课>构建…

酒店民宿小程序,探索行业数字化管理发展

在数字化发展时代&#xff0c;各行各业都开始向数字化转型发展&#xff0c;酒店民宿作为热门行业也逐渐趋向数字、智能化发展。 对于酒店民宿来说&#xff0c;如何将酒店特色服务优势等更加快速运营推广是重中之重。酒店民宿小程序作为一款集结预约、房源管理、客户订单管理等…

猎板PCB2到10层数的科技进阶与应用解析

1. 单层板&#xff08;Single-sided PCB&#xff09; 定义&#xff1a;单层板是最基本的PCB类型&#xff0c;导线只出现在其中一面&#xff0c;因此被称为单面板。限制&#xff1a;由于只有一面可以布线&#xff0c;设计线路上有许多限制&#xff0c;不适合复杂电路。应用&…

Python网络爬虫入门篇!

预备知识 学习者需要预先掌握Python的数字类型、字符串类型、分支、循环、函数、列表类型、字典类型、文件和第三方库使用等概念和编程方法。 2. Python爬虫基本流程 a. 发送请求 使用http库向目标站点发起请求&#xff0c;即发送一个Request&#xff0c;Request包含&#xf…

gerrit 搭建遇到的问题

1、启动Apache&#xff0c;端口被占用 : AH00072: make sock: could not bind to address (0S 10048)通常每个套接字地址(协议/网络地址/端口)只允许使用一次。: AH00072: make sock: could not bind to address 0.0.0.:443 a AH00451: no listening sockets available, shutti…

提升安全上网体验:Windows 11 启用 DOH(阿里公共DNS)

文章目录 阿里公共 DNS 介绍免费开通云解析 DNS 服务Windows 编辑 DNS 设置配置 IPv4配置 IPv6 路由器配置 DNS 阿里公共 DNS 介绍 https://alidns.com/ 免费开通云解析 DNS 服务 https://dnsnext.console.aliyun.com/pubDNS 开通服务后&#xff0c;获取 DOH 模板&#xff0…

项目实战使用gitee

1.创建本地仓库 2.进行提交到本地仓库 创建仓库后在idea中会显示图标&#xff0c;点击绿色的√进行快速提交 3.绑定远程仓库 4.番外篇-创建gitee仓库 注意不要勾选其他

【大模型LLM面试合集】大语言模型架构_chatglm系列模型

chatglm系列模型 1.ChatGLM 1.1 背景 主流的预训练框架主要有三种&#xff1a; autoregressive自回归模型&#xff08;AR模型&#xff09;&#xff1a;代表作GPT。本质上是一个left-to-right的语言模型。通常用于生成式任务&#xff0c;在长文本生成方面取得了巨大的成功&a…

yolov8涨点系列之HiLo注意力机制引入

文章目录 HiLo 注意力介绍原理特点 yolov8增加CBAM具体步骤HiLo代码(1)在__init.py__conv.py文件的__all__内添加‘HiLo’(2)conv.py文件复制粘贴HiLo代码(3)修改task.py文件 yolov8.yaml文件增加HiLo注意力机制yolov8.yamlyolov8.yaml引入HiLo注意力机制 将 HiLo 注意力引入 Y…

ReactPress—基于React的免费开源博客CMS内容管理系统

ReactPress Github项目地址&#xff1a;https://github.com/fecommunity/reactpress 欢迎提出宝贵的建议&#xff0c;感谢Star。 ![ReactPress](https://i-blog.csdnimg.cn/direct/0720f155edaa4eadba796f4d96d394d7.png#pic_center ReactPress 是使用React开发的开源发布平台&…

金华迪加 现场大屏互动系统 mobile.do.php 任意文件上传漏洞复现

0x01 产品简介 金华迪加现场大屏互动系统是一种集成了先进技术和创意设计的互动展示解决方案,旨在通过大屏幕和多种交互方式,为观众提供沉浸式的互动体验。该系统广泛应用于各类活动、展览、会议等场合,能够显著提升现场氛围和参与者的体验感。 0x02 漏洞概述 金华迪加 现…