详细解读个性化定制大杀器IP-Adapter代码

Diffusion models代码解读:入门与实战

前言:IP-Adapter作为Diffusion Models最成功的技术之一,已经在诸多互联网应用中落地。介绍IP-Adapter原理和应用的博客有很多,但是逐行详细解读代码的博客很少。这篇博客从细节出发,结合原理详细解读个性化定制大杀器IP-Adapter代码。

目录

原理概述

代码详解

冻结模型

Image Projection

注意力替换

新注意力计算

打包IP-Adapter

推理

源码地址


原理概述

一句话概括:原有的Cross Attention计算是用text condition计算的;IP-Adapter在原有Cross Attention计算上加上了image condition。其中Q共用,K V重新计算。

在整个训练过程中,冻结了最初的 UNet 模型,所以在上述解耦交叉注意力中,只有W′k,W′v 两个参数是可训练的。而且W′k,W′v 的权重从原来对应的Cross Attention初始化。

代码详解

冻结模型

这张图中蓝色部分全部冻结:

    unet.requires_grad_(False)vae.requires_grad_(False)text_encoder.requires_grad_(False)image_encoder.requires_grad_(False)

红色部分可训练:

params_to_opt = itertools.chain(ip_adapter.image_proj_model.parameters(),  ip_adapter.adapter_modules.parameters())

Image Projection

对应于下图中的这个部分,由一个Linear层和一个Norm组成:

    image_proj_model = ImageProjModel(cross_attention_dim=unet.config.cross_attention_dim,clip_embeddings_dim=image_encoder.config.projection_dim,clip_extra_context_tokens=4,)
class ImageProjModel(torch.nn.Module):"""Projection Model"""def __init__(self, cross_attention_dim=1024, clip_embeddings_dim=1024, clip_extra_context_tokens=4):super().__init__()self.generator = Noneself.cross_attention_dim = cross_attention_dimself.clip_extra_context_tokens = clip_extra_context_tokensself.proj = torch.nn.Linear(clip_embeddings_dim, self.clip_extra_context_tokens * cross_attention_dim)self.norm = torch.nn.LayerNorm(cross_attention_dim)def forward(self, image_embeds):embeds = image_embedsclip_extra_context_tokens = self.proj(embeds).reshape(-1, self.clip_extra_context_tokens, self.cross_attention_dim)clip_extra_context_tokens = self.norm(clip_extra_context_tokens)return clip_extra_context_tokens

经过这个模型得到Image Feature,它的维度需要与计算Attention的维度对齐。

注意力替换

在原始的Unet模型中,需要把Cross Attention的计算替换成新的,但是self- attention不变。

其中attn1.processor 是 self- attention,这部分不变;attn2.processor代表交叉注意力层,这部分替换!

W′k,W′v 的权重从原来对应的Cross Attention初始化!也就是代码中的:

        if cross_attention_dim is None:attn_procs[name] = AttnProcessor()else:layer_name = name.split(".processor")[0]weights = {"to_k_ip.weight": unet_sd[layer_name + ".to_k.weight"],"to_v_ip.weight": unet_sd[layer_name + ".to_v.weight"],}attn_procs[name] = IPAttnProcessor(hidden_size=hidden_size, cross_attention_dim=cross_attention_dim)attn_procs[name].load_state_dict(weights)

这部分的完整代码如下:

    attn_procs = {}unet_sd = unet.state_dict()for name in unet.attn_processors.keys():cross_attention_dim = None if name.endswith("attn1.processor") else unet.config.cross_attention_dimif name.startswith("mid_block"):hidden_size = unet.config.block_out_channels[-1]elif name.startswith("up_blocks"):block_id = int(name[len("up_blocks.")])hidden_size = list(reversed(unet.config.block_out_channels))[block_id]elif name.startswith("down_blocks"):block_id = int(name[len("down_blocks.")])hidden_size = unet.config.block_out_channels[block_id]if cross_attention_dim is None:attn_procs[name] = AttnProcessor()else:layer_name = name.split(".processor")[0]weights = {"to_k_ip.weight": unet_sd[layer_name + ".to_k.weight"],"to_v_ip.weight": unet_sd[layer_name + ".to_v.weight"],}attn_procs[name] = IPAttnProcessor(hidden_size=hidden_size, cross_attention_dim=cross_attention_dim)attn_procs[name].load_state_dict(weights)unet.set_attn_processor(attn_procs)adapter_modules = torch.nn.ModuleList(unet.attn_processors.values())

新注意力计算

现在已经在Unet中把原来的Cross注意力计算都替换并且组装好了。

下一步我们看看被替换的注意力是如何计算的,也就是看看IPAttnProcessor 如何实现。

class IPAttnProcessor(nn.Module):r"""Attention processor for IP-Adapater.Args:hidden_size (`int`):The hidden size of the attention layer.cross_attention_dim (`int`):The number of channels in the `encoder_hidden_states`.scale (`float`, defaults to 1.0):the weight scale of image prompt.num_tokens (`int`, defaults to 4 when do ip_adapter_plus it should be 16):The context length of the image features."""def __init__(self, hidden_size, cross_attention_dim=None, scale=1.0, num_tokens=4):super().__init__()self.hidden_size = hidden_sizeself.cross_attention_dim = cross_attention_dimself.scale = scaleself.num_tokens = num_tokensself.to_k_ip = nn.Linear(cross_attention_dim or hidden_size, hidden_size, bias=False)self.to_v_ip = nn.Linear(cross_attention_dim or hidden_size, hidden_size, bias=False)def __call__(self,attn,hidden_states,encoder_hidden_states=None,attention_mask=None,temb=None,*args,**kwargs,):residual = hidden_statesif attn.spatial_norm is not None:hidden_states = attn.spatial_norm(hidden_states, temb)input_ndim = hidden_states.ndimif input_ndim == 4:batch_size, channel, height, width = hidden_states.shapehidden_states = hidden_states.view(batch_size, channel, height * width).transpose(1, 2)batch_size, sequence_length, _ = (hidden_states.shape if encoder_hidden_states is None else encoder_hidden_states.shape)attention_mask = attn.prepare_attention_mask(attention_mask, sequence_length, batch_size)if attn.group_norm is not None:hidden_states = attn.group_norm(hidden_states.transpose(1, 2)).transpose(1, 2)query = attn.to_q(hidden_states)if encoder_hidden_states is None:encoder_hidden_states = hidden_stateselse:# get encoder_hidden_states, ip_hidden_statesend_pos = encoder_hidden_states.shape[1] - self.num_tokensencoder_hidden_states, ip_hidden_states = (encoder_hidden_states[:, :end_pos, :],encoder_hidden_states[:, end_pos:, :],)if attn.norm_cross:encoder_hidden_states = attn.norm_encoder_hidden_states(encoder_hidden_states)key = attn.to_k(encoder_hidden_states)value = attn.to_v(encoder_hidden_states)query = attn.head_to_batch_dim(query)key = attn.head_to_batch_dim(key)value = attn.head_to_batch_dim(value)attention_probs = attn.get_attention_scores(query, key, attention_mask)hidden_states = torch.bmm(attention_probs, value)hidden_states = attn.batch_to_head_dim(hidden_states)# for ip-adapterip_key = self.to_k_ip(ip_hidden_states)ip_value = self.to_v_ip(ip_hidden_states)ip_key = attn.head_to_batch_dim(ip_key)ip_value = attn.head_to_batch_dim(ip_value)ip_attention_probs = attn.get_attention_scores(query, ip_key, None)self.attn_map = ip_attention_probsip_hidden_states = torch.bmm(ip_attention_probs, ip_value)ip_hidden_states = attn.batch_to_head_dim(ip_hidden_states)hidden_states = hidden_states + self.scale * ip_hidden_states# linear projhidden_states = attn.to_out[0](hidden_states)# dropouthidden_states = attn.to_out[1](hidden_states)if input_ndim == 4:hidden_states = hidden_states.transpose(-1, -2).reshape(batch_size, channel, height, width)if attn.residual_connection:hidden_states = hidden_states + residualhidden_states = hidden_states / attn.rescale_output_factorreturn hidden_states

核心是计算出ip_hidden_states,然后与原来的hidden_state相加:

        ip_key = self.to_k_ip(ip_hidden_states)ip_value = self.to_v_ip(ip_hidden_states)ip_key = attn.head_to_batch_dim(ip_key)ip_value = attn.head_to_batch_dim(ip_value)ip_attention_probs = attn.get_attention_scores(query, ip_key, None)self.attn_map = ip_attention_probsip_hidden_states = torch.bmm(ip_attention_probs, ip_value)ip_hidden_states = attn.batch_to_head_dim(ip_hidden_states)

相加的时候有一个scale,这个scale是一个固定的参数1:

hidden_states = hidden_states + self.scale * ip_hidden_states

打包IP-Adapter

因为是用accelerate训练的,所以把可训练的部分都打包成一个单独的类 ip_adapter,用accelerator包装一下,就可以很方便更新梯度和保存权重。

ip_adapter, optimizer, train_dataloader = accelerator.prepare(ip_adapter, optimizer, train_dataloader)

推理

累了,下次再写...

    def generate(self,pil_image=None,clip_image_embeds=None,prompt=None,negative_prompt=None,scale=1.0,num_samples=4,seed=None,guidance_scale=7.5,num_inference_steps=30,**kwargs,):self.set_scale(scale)if pil_image is not None:num_prompts = 1 if isinstance(pil_image, Image.Image) else len(pil_image)else:num_prompts = clip_image_embeds.size(0)if prompt is None:prompt = "best quality, high quality"if negative_prompt is None:negative_prompt = "monochrome, lowres, bad anatomy, worst quality, low quality"if not isinstance(prompt, List):prompt = [prompt] * num_promptsif not isinstance(negative_prompt, List):negative_prompt = [negative_prompt] * num_promptsimage_prompt_embeds, uncond_image_prompt_embeds = self.get_image_embeds(pil_image=pil_image, clip_image_embeds=clip_image_embeds)bs_embed, seq_len, _ = image_prompt_embeds.shapeimage_prompt_embeds = image_prompt_embeds.repeat(1, num_samples, 1)image_prompt_embeds = image_prompt_embeds.view(bs_embed * num_samples, seq_len, -1)uncond_image_prompt_embeds = uncond_image_prompt_embeds.repeat(1, num_samples, 1)uncond_image_prompt_embeds = uncond_image_prompt_embeds.view(bs_embed * num_samples, seq_len, -1)with torch.inference_mode():prompt_embeds_, negative_prompt_embeds_ = self.pipe.encode_prompt(prompt,device=self.device,num_images_per_prompt=num_samples,do_classifier_free_guidance=True,negative_prompt=negative_prompt,)prompt_embeds = torch.cat([prompt_embeds_, image_prompt_embeds], dim=1)negative_prompt_embeds = torch.cat([negative_prompt_embeds_, uncond_image_prompt_embeds], dim=1)generator = get_generator(seed, self.device)images = self.pipe(prompt_embeds=prompt_embeds,negative_prompt_embeds=negative_prompt_embeds,guidance_scale=guidance_scale,num_inference_steps=num_inference_steps,generator=generator,**kwargs,).imagesreturn images

源码地址

IP-Adapter/tutorial_train.py at main · tencent-ailab/IP-Adapter · GitHub

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.mzph.cn/web/58382.shtml

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

PHP JSON 教程

PHP JSON 教程 PHP 是一种广泛使用的开源服务器端脚本语言,而 JSON(JavaScript Object Notation)是一种轻量级的数据交换格式。PHP 提供了多种函数和库来处理 JSON 数据,使得在 PHP 应用程序中解析和生成 JSON 数据变得非常容易。本教程将详细介绍 PHP 中 JSON 的使用方法…

数据采集之scrapy框架2

本博文使用自动化爬虫框架完成微信开放社区文档信息的爬取(重点理解 scrapy 框架自动化爬 虫构建过程,能够分析 LinkExtractor 和 Rule 规则的基本用法) 包结构目录如下图所示: 主要代码: ( items.p…

WAPI认证过程如何实现?

WAPI(WLAN Authentication and Privacy Infrastructure)认证过程是通过一系列步骤来实现的,以确保无线局域网(WLAN)中设备的合法性和数据传输的安全性。以下是WAPI认证过程的详细实现步骤: 一、认证前的准…

从零开始的LeetCode刷题日记:746. 使用最小花费爬楼梯

一.相关链接 题目链接:746. 使用最小花费爬楼梯 二.心得体会 这道题还是动规五部曲。 1.首先是dp数组及其下标的含义,dp记录了每层楼梯对应的爬的方法,每个下标存储每个对应楼层。 2.然后是递归公式,这里的递归公式就不是简单…

深⼊理解指针(2)

目录 1. const修饰指针及变量 2. 野指针 3. assert断⾔ 4. 指针的传址调⽤ 一 const修饰指针及变量(const是场属性——不能改变的属性) 1 const修饰变量 那怎么证明被const修饰的变量本质还是变量呢? 上面我们绕过n,使…

每日科技资讯:2024年11月06日【龙】农历十月初六 ---文末送书

目录 1.OpenAI因算力瓶颈暂缓GPT-5发布 合作芯片开发寻求突破2.现在,𝕏 允许被你屏蔽的人继续查看你的帖子3.硬刚Intel与AMD!NVIDIA明年推出PC芯片4.苹果停止签署 iOS 18.0.1,不再允许从 18.1 降级5.Nvidia 加入道琼斯指数成份股 …

swoole扩展安装--入门篇

对于php来说,swoole是个强大的补充扩展。这是我第3次写swoole扩展安装,这次基于opencloudos8系统,php使用8.2。 安装swoole扩展首先想到的是用宝塔来安装,毕竟安装方便,还能统一管理。虽然获得swoole版本不是最新的&am…

layui xm-select的使用

一、文档 xm-select 二、使用 <div id"js-form-tags{$ke}{$index}" val"{$ke}"></div> <input type"hidden" class"selectkey" name"selectkey[]" value"{$ke}" /> function initSelect(id…

【大模型开发指南】llamaindex配置deepseek、jina embedding及chromadb实现本地RAG及知识库(win系统、CPU适配)

说一些坑&#xff0c;本来之前准备用milvus&#xff0c;但是发现win搞不了&#xff08;docker都配好了&#xff09;。然后转头搞chromadb。这里面还有就是embedding一般都是本地部署&#xff0c;但我电脑是cpu的没法玩&#xff0c;我就选了jina的embedding性能较优&#xff08;…

vue前端sku实现

this.value.skuStockList [];let skuList this.value.skuStockList;//只有一个属性时if (this.selectProductAttr.length 1) {let attr this.selectProductAttr[0];for (let i 0; i < attr.values.length; i) {skuList.push({spData: JSON.stringify([{key:attr.name,v…

pyspark基础准备

1.前言介绍 学习目标&#xff1a;了解什么是Speak、PySpark&#xff0c;了解为什么学习PySpark&#xff0c;了解课程是如何和大数据开发方向进行衔接 使用pyspark库所写出来的代码&#xff0c;既可以在电脑上简单运行&#xff0c;进行数据分析处理&#xff0c;又可以把代码无缝…

数据库基础(4) . 数据库结构

2.基础结构 2.1.结构及名称 数据库 database 表空间 tablespaces(Oracle) 表格 table 字段 column 记录 record 值 value 2.2.数据库 database 在配置文件中指定存放位置 # 设置mysql数据库的数据的存放目录 datadirD:\MySQL\mysql-8.0.16-winx64\data每个数据库对应…

Meme 币生态全景图分析:如何获得超额收益?

近期&#xff0c;BTC 再次突破 7 万美元大关&#xff0c;市场上贪婪指数再次达到 80&#xff0c;而 Meme 币往往是每次牛市冲锋的号角&#xff0c;比如 $GOAT 5 天内价格一度上涨超 1 万倍。通过对当前市场 TOP 25 Meme 币的交易数据分析&#xff0c;我们发现了几个值得关注的市…

数据结构之二叉树——堆 详解(含代码实现)

1.堆 如果有一个关键码的集合 K { &#xff0c; &#xff0c; &#xff0c; … &#xff0c;}&#xff0c;把它的所有元素按完全二叉树的顺序存储方式存储 在一个一维数组中&#xff0c;则称为小堆( 或大堆 ) 。将根节点最大的堆叫做最大堆或大根堆&#xff0c;根节点最小的…

高级 <HarmonyOS主题课>构建华为支付服务的课后习题

五色令人目盲&#xff1b; 五音令人耳聋&#xff1b; 五味令人口爽&#xff1b; 驰骋畋猎&#xff0c;令人心发狂&#xff1b; 难得之货&#xff0c;令人行妨&#xff1b; 是以圣人为腹不为目&#xff0c;故去彼取此。 本篇内容主要来自&#xff1a;<HarmonyOS主题课>构建…

酒店民宿小程序,探索行业数字化管理发展

在数字化发展时代&#xff0c;各行各业都开始向数字化转型发展&#xff0c;酒店民宿作为热门行业也逐渐趋向数字、智能化发展。 对于酒店民宿来说&#xff0c;如何将酒店特色服务优势等更加快速运营推广是重中之重。酒店民宿小程序作为一款集结预约、房源管理、客户订单管理等…

猎板PCB2到10层数的科技进阶与应用解析

1. 单层板&#xff08;Single-sided PCB&#xff09; 定义&#xff1a;单层板是最基本的PCB类型&#xff0c;导线只出现在其中一面&#xff0c;因此被称为单面板。限制&#xff1a;由于只有一面可以布线&#xff0c;设计线路上有许多限制&#xff0c;不适合复杂电路。应用&…

Python网络爬虫入门篇!

预备知识 学习者需要预先掌握Python的数字类型、字符串类型、分支、循环、函数、列表类型、字典类型、文件和第三方库使用等概念和编程方法。 2. Python爬虫基本流程 a. 发送请求 使用http库向目标站点发起请求&#xff0c;即发送一个Request&#xff0c;Request包含&#xf…

python数据结构基础(5)

本章学习的是栈.它最突出的特点是后进先出,与队列恰好相反,但其在实现过程中与队列异曲同工. 栈的基本结构: 栈是按照有序的后进先出规则运行的一种结构,其插入和删除操作均在栈项进行这一点区别于队列的队尾进队,队头出队. 栈一般包括入栈和出栈操作,且有一个顶指针(top)用于…

gerrit 搭建遇到的问题

1、启动Apache&#xff0c;端口被占用 : AH00072: make sock: could not bind to address (0S 10048)通常每个套接字地址(协议/网络地址/端口)只允许使用一次。: AH00072: make sock: could not bind to address 0.0.0.:443 a AH00451: no listening sockets available, shutti…