RCG自条件是如何添加到 Pixel Generator上的?

在自条件的训练过程中,需要将图像经过Pretrained encoder的表征Rep输入进已有的Pixel Generator上,目前RCG是向四种Pixel Generator上加入了自条件,关于它是如何将rep加到Pixel Generator上的,我来总结一下:

一、Pixel Generator: MAGE

在MAGE中,是使用rep替换embedding的 fake class token做的:

  1. 得到CFG的混合表征
  2. 将混合表征替换embedding的 fake class token
  3. 输入进ViT block
# replace fake class token with repif self.use_rep:# cfg(class free guidance) by masking representationdrop_rep_mask = torch.rand(bsz) < self.rep_drop_probdrop_rep_mask = drop_rep_mask.unsqueeze(-1).cuda().float()# 这里相当于cfg, O = αU + (1-a)C, 最终输出是由条件生成C(rep)和无条件生成U(fake_latent)的线性外推获得rep = drop_rep_mask * self.fake_latent + (1 - drop_rep_mask) * reprep = self.latent_prior_proj(rep)# 将rep赋值给embedding的(将图像的rep替换seq的第0维度,相当于替换了seq的fake class token),其实并没有对原始的MAGE做什么改变,只是将原来可学习的fake token换为了rep,从而输入进encoder# input_embeddings_after_drop:(64,129,768) <-- rep:(64,768)input_embeddings_after_drop[:, 0] = rep# class-conditional MAGEif self.use_class_label:class_emb = self.class_emb(class_label)input_embeddings_after_drop[:, 0] = class_emb# apply Transformer blocksx = input_embeddings_after_dropfor blk in self.blocks:x = blk(x)x = self.norm(x)# print("Encoder representation shape:", x.shape)return x, gt_indices, token_drop_mask, token_all_mask

二、Pixel Generator: DiT

从forward函数中,可以看到,

  1. 先使用CFG得到rep的混合表征rep
  2. rep加到timestep中 (16,1024) + (16,1024) =(16,1024)维度得到c。
  3. 然后将这个c作为融合条件输入进去噪block(transformer block)中。
    def forward(self, x, t, y, rep=None):"""Forward pass of DiT.x: (N, C, H, W) tensor of spatial inputs (images or latent representations of images)t: (N,) tensor of diffusion timestepsy: (N,) tensor of class labels"""x = self.x_embedder(x) + self.pos_embed  # (N, T, D), where T = H * W / patch_size ** 2t = self.t_embedder(t)                   # (N, D)y = self.y_embedder(y, self.training)    # (N, D)# rep condif rep is not None:# 1、get the CFG mixture repif self.training:drop_rep_mask = torch.rand(x.size(0)) < self.rep_dropout_probdrop_rep_mask = drop_rep_mask.unsqueeze(-1).cuda().float()rep = drop_rep_mask * self.fake_latent + (1 - drop_rep_mask) * reprep = self.rep_embedder(rep)# 2】直接将rep加到timestep t上从而作为下一步的输入 -->(16,1024)c = t + repelse:c = t + y  # (N, D)# 3、进一步处理for block in self.blocks:x = block(x, c)                      # (N, T, D)x = self.final_layer(x, c)                # (N, T, patch_size ** 2 * out_channels)x = self.unpatchify(x)                   # (N, out_channels, H, W)return x

三、Pixel Generator: ADM

这里和DiT的处理方式是一样的,直接将rep与timestep相加,然后输入进U-Net进行去噪

U-Net的forward(): 

model_output = model(x_t, self._scale_timesteps(t), rep=rep, **model_kwargs)
    def forward(self, x, timesteps, y=None, rep=None):"""Apply the model to an input batch.:param x: an [N x C x ...] Tensor of inputs.:param timesteps: a 1-D batch of timesteps.:param y: an [N] Tensor of labels, if class-conditional.:return: an [N x C x ...] Tensor of outputs."""assert (y is not None) == (self.num_classes is not None), "must specify y if and only if the model is class-conditional"assert (rep is not None) == self.rep_cond# 将timestep embeddinghs = []emb = self.time_embed(timestep_embedding(timesteps, self.model_channels))if self.num_classes is not None:assert y.shape == (x.shape[0],)emb = emb + self.label_emb(y)# 将timestep的embedding和rep相加,然后输入进U-Netif self.rep_cond:emb = emb + self.rep_proj(rep)h = x.type(self.dtype)for module in self.input_blocks:h = module(h, emb)hs.append(h)h = self.middle_block(h, emb)for module in self.output_blocks:h = th.cat([h, hs.pop()], dim=1)h = module(h, emb)h = h.type(x.dtype)return self.out(h)

四、Pixel Generator: LDM

整体来说没有什么太多的问题,就是将LDM中的condition换成了包含了/未包含condition信息的 rep

  1. 得到encoder后的图像x(4,4,32,32),表征rep c(4,1,256)
  2. 将带有condition信息的rep替换DDPM的原始condition
  3. 将encoder后的图像x(4,4,32,32),表征rep c(4,1,256), timestep t(4)输入进DDPM的后向过程求loss
    def forward(self, x, c, batch=None, gen_img=False, *args, **kwargs):if gen_img:return self.gen_imgs()# 1、得到encoder后的图像x(4,4,32,32),表征rep c(4,1,256)if batch is not None:x, c = self.get_input(batch, self.first_stage_key)if self.rep_cond:rep = c['rep']c = {'class_label': c['class_label']}t = torch.randint(0, self.num_timesteps, (x.shape[0],)).cuda().long()if self.model.conditioning_key is not None:assert c is not None# 将图像的label变为可学习的if self.cond_stage_trainable:c = self.get_learned_conditioning(c)if self.shorten_cond_schedule:  # TODO: drop this optiontc = self.cond_ids[t].cuda()c = self.q_sample(x_start=c, t=tc, noise=torch.randn_like(c.float()))# 2、将带有condition信息的rep替换DDPM的原始conditionif self.rep_cond:c = repc = c.unsqueeze(1)# 3、将encoder后的图像x(4,4,32,32),表征rep c(4,1,256), timestep t(4)输入进DDPM的后向过程求lossloss, loss_dict = self.p_losses(x, c, t, *args, **kwargs)if self.use_ema and batch is not None:self.model_ema(self.model)return loss, loss_dict

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.mzph.cn/news/783265.shtml

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

【前端Vue】Vue从0基础完整教程第4篇:面经PC端 - Element (下)【附代码文档】

Vue从0基础到大神学习完整教程完整教程&#xff08;附代码资料&#xff09;主要内容讲述&#xff1a;vue基本概念&#xff0c;vue-cli的使用&#xff0c;vue的插值表达式&#xff0c;{{ gaga }}&#xff0c;{{ if (obj.age > 18 ) { } }}&#xff0c;vue指令&#xff0c;综合…

树的重心——树的结构

树的重心是指对于某个点&#xff0c;将其删除后&#xff0c;可以使得剩余联通块的最大值最小。也就等价于一某个点为根的树&#xff0c;将根删除后&#xff0c;剩余的若干棵子树的大小最小。 例如下图的树的重心就是2。 性质&#xff1a; 性质一&#xff1a;重心的若干棵子树打…

Vue使用el-statistic和el-card显示大屏中的统计数据

​ 一、页面内容&#xff1a; <el-row :gutter"20"><el-col :span"6"><el-card class"box-card"><div><el-statisticgroup-separator",":precision"2":value"value2":title"tit…

【娱乐】战双帕弥什游戏笔记攻略

文章目录 Part.I IntroductionChap.I Information Part.II 新手攻略Chap.I 角色和武器挑选Chap.II 新手意识推荐 Part.II 阵容搭配Chap.I 一拖二Chap.II 毕业队 Reference Part.I Introduction 2019年12月5日全平台公测。 偶然间入坑战双&#xff0c;玩了几天&#xff0c;觉得…

elasticsearch基础应用

1._cat接口 | _cat接口 | 说明 | | GET /_cat/nodes | 查看所有节点 | | GET /_cat/health | 查看ES健康状况 | | GET /_cat/master | 查看主节点 | | GET /_cat/indices | 查看所有索引信息 | es 中会默认提供上面的几个索引&#xff0c;表头…

Hotspot虚拟机对象问题(对象头...对象创建)

目录 对象头 实例数据 对齐填充 对象是如何创建 对象头 在Hotspot虚拟机中&#xff0c;Java对象在内存中的布局大致可以分为三部分:对象头、实例数据和填充对齐。因为synchronized用的锁是存在对象头里的&#xff0c;这里我们需要重点了解对象头。如果对象头是数组类型则对…

springboot汉服推广网站

摘 要 本论文主要论述了如何使用JAVA语言开发一个汉服推广网站 &#xff0c;本系统将严格按照软件开发流程进行各个阶段的工作&#xff0c;采用B/S架构&#xff0c;面向对象编程思想进行项目开发。在引言中&#xff0c;作者将论述汉服推广网站的当前背景以及系统开发的目的&am…

创建第一个Electron程序

前置准备 创建一个文件夹&#xff0c;如: electest进入文件夹&#xff0c;初始化npm npm init -y 安装electron依赖包 注&#xff0c;这里使用npm i -D electron会特别卡&#xff0c;哪怕换成淘宝源也不行。可以使用下面方式安装。 首先&#xff0c;安装yarn npm i -g yarn 随…

从打开电视的过程认识接口(Java)

目录 问题引入 接口 什么是接口 接口的特点 简单理解接口 接口的实现 TV接口及其实现 Controlor接口及其实现 注意 情景实现 代码 输出 问题引入 在面向对象的世界里&#xff0c;不妨设想&#xff0c;我们打开电视需要些什么呢&#xff1f;显然的&#xff0c;首…

Leetcode刷题记录面试基础题day1(备战秋招)

hello&#xff0c;你好鸭&#xff0c;我是康康&#xff0c;很高兴你能来阅读&#xff0c;昵称是希望自己能不断精进&#xff0c;向着优秀程序员前行!&#x1f4aa;&#x1f4aa;&#x1f4aa; 目前博客主要更新Java系列、数据库、项目案例、计算机基础等知识点。感谢你的阅读和…

C/C++游戏编程实例-飞翔的小鸟

飞翔的小鸟游戏设计 首先需要包含以下库&#xff1a; #include<stdio.h> #include<windows.h> #include<stdlib.h> //包含system #include<conio.h>设置窗口大小&#xff1a; #define WIDTH 50 #define HEIGHT 16设置鸟的结构&#xff1a; struct …

CentOs7.9中修改Mysql8.0.28默认的3306端口防止被端口扫描入侵

若你的服务器被入侵&#xff0c;可以从这些地方找到证据&#xff1a; 若有上述信息&#xff0c;300%是被入侵了&#xff0c;重装服务器系统以后再重装Mysql数据库&#xff0c;除了设置一个复杂的密码以外&#xff0c;还需要修改默认的Mysql访问端口&#xff0c;逃避常规端口扫描…

wps没保存关闭了恢复数据教程

有时候我们因为电脑问题会忘记保存就关闭wps导致数据丢失&#xff0c;不知道wps没保存关闭了怎么恢复数据&#xff0c;其实数据是无法恢复的。 wps没保存关闭了怎么恢复数据 1、wps没有数据恢复功能&#xff0c;不过可以开启自动备份。 2、我们可以先点击wps左上角的“文件”…

全国超市数据可视化仪表板制作

全国超市消费数据展示 指定 Top几 客户销费数据展示 指定 Top几 省份销费数据展示 省份销售额数据分析 完整结果

Taro+vue3 监听当前的页面滚动的距离

1.需求 想实现一个这样的效果 一开始这个城市组件 是透明的 在顶部 的固定定位 当屏幕滑动的时候到一定的距离 将这个固定的盒子 背景颜色变成白色 2.Taro中的滚动 Taro中的滚动 有固定的api 像生命周期一样 这个生命周期是 usePageScroll import Taro, { useDidShow, useP…

蓝桥杯day14刷题日记

P8707 [蓝桥杯 2020 省 AB1] 走方格 思路&#xff1a;很典型的动态规划问题&#xff0c;对于偶数格特判&#xff0c;其他的正常遍历一遍&#xff0c;现在所处的格子的方案数等于左边的格子的方案数加上上面格子的方案数之和 #include <iostream> using namespace std; …

Linux系统下NAT网卡出现问题,无法上网的解决办法

NTA连接无法上网&#xff0c;如果你试过网上所有教程&#xff0c;检测了Windows环境和Ubuntu环境没问题&#xff0c;且无法启动系统服务、ping网络失败、重置虚拟机网络配置器也无效等种种以下所列原因无法解决&#xff0c;可能在于没有获取IP地址&#xff0c;才不能上网 netw…

书生·浦语大模型实战营 | 第1次学习笔记

前言 书生浦语大模型应用实战营 第二期正在开营&#xff0c;欢迎大家来学习。&#xff08;参与链接&#xff1a;https://mp.weixin.qq.com/s/YYSr3re6IduLJCAh-jgZqghttps://mp.weixin.qq.com/s/YYSr3re6IduLJCAh-jgZqg&#xff09; 第一堂课的视频链接&#xff1a;https://m…

多系统使用ffmpeg读取麦克风数据

提示&#xff1a;文章写完后&#xff0c;目录可以自动生成&#xff0c;如何生成可参考右边的帮助文档 文章目录 前言一、命令行1.Ubuntu1.alsa2.pulseaudio 2.Windows1.dshow 二、代码总结 前言 最近在搞一个项目需要用到麦克风读取数据并分析&#xff0c;我的开发环境是Ubunt…

Android 12系统源码_多窗口模式(一)和多窗口模式相关方法的调用顺序

前言 从 Android 7.0 开始&#xff0c;Google 推出了一个名为“多窗口模式”的新功能&#xff0c;允许在设备屏幕上同时显示多个应用&#xff0c;多窗口模式允许多个应用同时共享同一屏幕&#xff0c;多窗口模式&#xff08;Multi Window Supports&#xff09;目前支持以下三种…