Hunuan-DiT代码阅读

一 整体架构

该模型是以SD为基础的文生图模型,具体扩散模型原理参考https://zhouyifan.net/2023/07/07/20230330-diffusion-model/,代码地址https://github.com/Tencent/HunyuanDiT,这里介绍 Full-parameter Training

二 输入数据处理

这里主要包括图像和文本数据输入处理

2.1 图像处理

这里代码参考 hydit/data_loader/arrow_load_stream.py,生成1024*1024的图片,对于输入图片进行random_crop,之后包括随机水平翻转,转tensor,以及Normalize(减均值0.5, 除以标准差0.5,为什么是这个,是因为通过PIL Image读图之后转到tensor范围是0-1之间,不是opencv读出来像素值在0-255之间),得到最终image( B ∗ 3 ∗ 1024 ∗ 1024 B*3*1024*1024 B310241024

2.2 文本处理

输入的文本,通过BertTokenizer,进行映射,同时补齐长度到77,不够的补0,同时生成相应的attention_mask;同时还有T5TokenizerFast,对于T5的输入,会随机小于uncond_p_t5(目前给出的设置uncond_p_t5=5),输入为空,否则为文本输入,补齐长度256,同时生成相应的attention_mask

2.3 图像编码

对于输入图像,采用VAE encoder 进行编码,生成隐空间特征latents( B ∗ 4 ∗ 128 ∗ 128 B*4*128*128 B4128128,就是输入8倍下采样,计算过程latents = vae.encode(image).latent_dist.sample().mul_(vae_scaling_factor),具体VAE相关后续补充)

2.4 文本编码

包括两个部分,一个是CLIP的text编码,采用bert layer,生成encoder_hidden_states( B ∗ 77 ∗ 1024 B*77*1024 B771024);第二部分是mT5的text编码,生成encoder_hidden_states_t5( B ∗ 256 ∗ 2048 B*256*2048 B2562048

2.5 位置编码

这里是采用根据预设的分辨率,提前生成好的位置编码,这里采用ROPE,生成cos_cis_img, sin_cis_img (分别都是 4096 ∗ 88 4096*88 409688)

最终生成图像编码latents,文本编码(encoder_hidden_states以及对应的attention_mask,encoder_hidden_states_t5以及对应的attention_mask),以及位置编码cos_cis_img, sin_cis_img

三 DIT模型

3.1 add noise过程

  • 根据上一步的输出latents,作为x_start,随机选取一个time step,根据q_sample,得到增加噪声之后的输出x_t(具体公式参考如下,x0对应x_start,xt对应x_t)
    在这里插入图片描述

3.2 HunYuanDiT模型训练过程

  • 对于输入的文本编码,包括text_states( B ∗ 77 ∗ 1024 B*77*1024 B771024),text_states_t5( B ∗ 256 ∗ 2048 B*256*2048 B2562048)以及相应的attention_mask,对于text_states_t5通过Linear+Silu+Linear,转成 B ∗ 256 ∗ 1024 B*256*1024 B2561024,然后对着两个进行concat,得到text_states( B ∗ 333 ∗ 1024 B*333*1024 B3331024),对于attention_mask也concat得到clip_t5_mask( B ∗ 333 B*333 B333);这里会生成一个可学习的text_embedding_padding特征( B ∗ 333 ∗ 1024 B*333*1024 B3331024),对于clip_t5_mask中通过补0得到的特征全部替换成text_embedding_padding特征
  • 对于输入time step 先走timestep_embedding(就是sinusoidal编码),然后通过Linear+Silu+Linear得到最终t ( B ∗ 1408 B*1408 B1408)
  • 对于输入x(就是上一步的x_t),通过PatchEmbed(就是VIT前面对图像进行patch),得到x( B ∗ 4096 ∗ 1408 , 4096 是 64 ∗ 64 B*4096*1408,4096是64*64 B4096140840966464
  • 对于text_states_t5( B ∗ 256 ∗ 2048 B*256*2048 B2562048),添加一个AttentionPool模块,就是对于输入在256维度上,进行mean,当成query,然后将输入和query concat一起得到257维,作为key和value,(其中query,key,value都添加位置编码)做multi_head_attention,得到最终输出extra_vec( B ∗ 1024 B*1024 B1024
  • 对于extra_vec 通过Linear+Silu+Linear得到( B ∗ 1408 B*1408 B1408),然后与通过time step得到的t相加,得到c( B ∗ 1408 B*1408 B1408,作为所有extra_vectors)

3.2.1 进入Dit Block

一共40个block,前面0到18个block的生成输入,中间19,20作为middle block,剩余的block会增加一个前面19个block输出的结果作为skip

3.2.1.1 前面0到18共19个block
  • 前面一共19个block的过程,输入x( B ∗ 4096 ∗ 1408 B*4096*1408 B40961408),c( B ∗ 1408 B*1408 B1408),text_states( B ∗ 333 ∗ 1024 B*333*1024 B3331024),位置编码freqs_cis_img (cos_cis_img, sin_cis_img,分别都是 B ∗ 4096 ∗ 88 B*4096*88 B409688
HunYuanDiTBlock((norm1): FP32_Layernorm((1408,), eps=1e-06, elementwise_affine=True)(attn1): FlashSelfMHAModified((Wqkv): Linear(in_features=1408, out_features=4224, bias=True)(q_norm): LayerNorm((88,), eps=1e-06, elementwise_affine=True)(k_norm): LayerNorm((88,), eps=1e-06, elementwise_affine=True)(inner_attn): FlashSelfAttention((drop): Dropout(p=0.0, inplace=False))(out_proj): Linear(in_features=1408, out_features=1408, bias=True)(proj_drop): Dropout(p=0.0, inplace=False))(norm2): FP32_Layernorm((1408,), eps=1e-06, elementwise_affine=True)(mlp): Mlp((fc1): Linear(in_features=1408, out_features=6144, bias=True)(act): GELU(approximate='tanh')(drop1): Dropout(p=0, inplace=False)(norm): Identity()(fc2): Linear(in_features=6144, out_features=1408, bias=True)(drop2): Dropout(p=0, inplace=False))(default_modulation): Sequential((0): FP32_SiLU()(1): Linear(in_features=1408, out_features=1408, bias=True))(attn2): FlashCrossMHAModified((q_proj): Linear(in_features=1408, out_features=1408, bias=True)(kv_proj): Linear(in_features=1024, out_features=2816, bias=True)(q_norm): LayerNorm((88,), eps=1e-06, elementwise_affine=True)(k_norm): LayerNorm((88,), eps=1e-06, elementwise_affine=True)(inner_attn): FlashCrossAttention((drop): Dropout(p=0.0, inplace=False))(out_proj): Linear(in_features=1408, out_features=1408, bias=True)(proj_drop): Dropout(p=0.0, inplace=False))(norm3): FP32_Layernorm((1408,), eps=1e-06, elementwise_affine=True)
)
  • 对于c 通过default_modulation,得到shift_msa( B ∗ 4096 ∗ 1408 B*4096*1408 B40961408),与经过norm1之后的x进行相加作为attn1的输入(就是Flash Self Attention)
  • 将attn1的输出与原始的x进行残差相加,在经过norm3,与text_states一起作为attn2的输入(就是Flash Cross Attention)
  • 在将经过残差相加之后的x与attn2的输出在进行残差相加,作为输入,走FFN,即先经过norm2,在经过mlp,之后与输入残差相加
3.2.1.2 第19和20 middle block
  • 中间第19 和 20 两个block作为middle block,方式和上面一样
3.2.1.3 后面21到39共19个block
  • 从第21个block开始,增加一个输入,例如第21个block,会将第18个block的输出作为输入
  (skip_norm): FP32_Layernorm((2816,), eps=1e-06, elementwise_affine=True)(skip_linear): Linear(in_features=2816, out_features=1408, bias=True)
  • 就是对于新的输入skip,将skip与x进行concat之后,经过skip norm,然后在经过skip linear,得到输出x,剩余步骤与前面一样

3.2.2 最后FInal layer处理

  • 输入x和c,x是上面所有dit block的输出,c是上面的extra_vectors;对于c先进行SILU+Linear,得到( B ∗ 2816 B*2816 B2816),并彩分成shift 和 scale(分别为 B ∗ 1408 B*1408 B1408),最终通过x * (1 + scale.unsqueeze(1)) + shift.unsqueeze(1),然后通过Linear,得到最终输出x( B ∗ 4096 ∗ 32 B*4096*32 B409632),然后通过转换得到输出imgs ( B ∗ 8 ∗ 128 ∗ 128 B*8*128*128 B8128128

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.mzph.cn/bicheng/55997.shtml

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

netdata保姆级面板介绍

netdata保姆级面板介绍 基本介绍部署流程下载安装指令选择设置KSM为什么要启用 KSM?如何启用 KSM?验证 KSM 是否启用注意事项 检查端口启动状态 netdata和grafana的区别NetdataGrafananetdata各指标介绍总览system overview栏仪表盘1. CPU2. Load3. Disk…

3.使用条件语句编写存储过程(3/10)

引言 在现代数据库管理系统中,存储过程扮演着至关重要的角色。它们是一组为了执行特定任务而编写的SQL语句,这些语句被保存在数据库中,可以被重复调用。存储过程不仅可以提高数据库操作的效率,还可以增强数据的安全性和一致性。此…

RPA技术的定义与原理

RPA(Robotic Process Automation)即机器人流程自动化,是一种利用软件机器人或机器人工具来自动执行重复性、规则性和可预测性的业务流程的技术。以下是对RPA技术的详细介绍: 一、RPA技术的定义与原理 RPA技术通过模拟人工操作&a…

【redis-06】redis的stream流实现消息中间件

redis系列整体栏目 内容链接地址【一】redis基本数据类型和使用场景https://zhenghuisheng.blog.csdn.net/article/details/142406325【二】redis的持久化机制和原理https://zhenghuisheng.blog.csdn.net/article/details/142441756【三】redis缓存穿透、缓存击穿、缓存雪崩htt…

关于Linux查看系统及版本信息的命令lsb_release命令以及Centos7中将redis服务写入systemctl服务

一、关于Linux查看系统及版本信息的命令lsb_release命令 linux查看系统是centos还是ubuntu,之前一直使用uname -a以及cat /etc/issue。但在某个服务器上发些这些都不行。有一个更好用的命令:lsb_release -a。如执行时提示-bash: lsb_release: 未找到命令…

Vscode+Pycharm+Vue.js+WEUI+django火锅(三)理解Vue

新创建的Vue项目里面很多文件,对于新手,老老实实做一下了解。 1.框架逻辑 框架的逻辑都是相通的,花点时间理一下就清晰了。 2.文件目录及文件 创建好的vue项目下,主要的文件和文件夹要先认识一下,并与框架逻辑对应起…

计算机毕业设计 校内跑腿业务系统的设计与实现 Java实战项目 附源码+文档+视频讲解

博主介绍:✌从事软件开发10年之余,专注于Java技术领域、Python人工智能及数据挖掘、小程序项目开发和Android项目开发等。CSDN、掘金、华为云、InfoQ、阿里云等平台优质作者✌ 🍅文末获取源码联系🍅 👇🏻 精…

dayu_widgets-简介

前言: 越来越多的人开始使用python来做GUI程序,市面上却很少有好的UI控件。即使有也是走的商业收费协议,不敢使用,一个不小心就收到法律传票。 一、原始开源项目: 偶然在GitHub上发现了这个博主的开源项目。https://github.com/phenom-films…

YOLO11改进|SPPF篇|引入YOLOv9提出的SPPELAN模块

目录 一、【SPPELAN】模块1.1【SPPELAN】模块介绍1.2【SPPELAN】核心代码 二、添加【SPPELAN】模块2.1STEP12.2STEP22.3STEP32.4STEP4 三、yaml文件与运行3.1yaml文件3.2运行成功截图 一、【SPPELAN】模块 1.1【SPPELAN】模块介绍 下图是【SPPELAN】的结构图,让我们…

OOOPS:零样本实现360度开放全景分割,已开源 | ECCV‘24

全景图像捕捉360的视场(FoV),包含了对场景理解至关重要的全向空间信息。然而,获取足够的训练用密集标注全景图不仅成本高昂,而且在封闭词汇设置下训练模型时也受到应用限制。为了解决这个问题,论文定义了一…

环球资源网 海外 globalsource reese84 分析

声明: 本文章中所有内容仅供学习交流使用,不用于其他任何目的,抓包内容、敏感网址、数据接口等均已做脱敏处理,严禁用于商业用途和非法用途,否则由此产生的一切后果均与作者无关! 有相关问题请第一时间头像私信联系我删…

Graph知识图谱融入向量数据库,带来RAG效果飞升

01. 前言 随着大型语言模型(LLMs)在各种应用中的广泛使用,如何提升其回答的准确性和相关性成为一个关键问题。检索增强生成(RAG)技术通过整合外部知识库,为LLMs提供了额外的背景信息,有效地改…

使用激光跟踪仪提升码垛机器人精度

标题1.背景 码垛机器人是一种用于工业自动化的机器人,专门设计用来将物品按照一定的顺序和结构堆叠起来,通常用于仓库、物流中心和生产线上,它们可以自动执行重复的、高强度的搬运和堆垛任务。 图1 码垛机器人 传统调整码垛机器人的方法&a…

【重学 MySQL】四十六、创建表的方式

【重学 MySQL】四十六、创建表的方式 使用CREATE TABLE语句创建表使用CREATE TABLE LIKE语句创建表使用CREATE TABLE AS SELECT语句创建表使用CREATE TABLE SELECT语句创建表并从另一个表中选取数据(与CREATE TABLE AS SELECT类似)使用CREATE TEMPORARY …

maven指定模块快速打包idea插件Quick Maven Package

问题背景描述 在实际开发项目中,我们的maven项目结构可能不是单一maven项目结构,项目一般会用parent方式将各个项目进行规范; 随着组件的数量增加,就会引入一个问题:我们只想打包某一个修改后的组件A时就变得很不方便…

企业数据安全防泄密要怎么做?七个措施杜绝泄密风险!

随着信息技术的快速发展,企业的核心数据已成为最具价值的资产之一。然而,数据泄露事件频发,不仅会给企业造成严重的经济损失,还会影响企业的声誉。因此,如何防止企业数据泄密已成为每个企业管理者关注的重点。以下是七…

利用特征点采样一致性改进icp算法点云配准方法

1、index、vector 2、kdtree和kdtreeflann 3、if kdtree.radiusSearch(。。。) > 0)

js拼接html代码在线工具

具体请前往:在线Html转Js--将Html代码转成javascript动态拼接代码并保持原有格式

年薪30W的Java程序员都要求熟悉JVM与性能调优!

一、JVM 内存区域划分 1.程序计数器(线程私有) 程序计数器(Program Counter Register),也有称作为 PC 寄存器。保存的是程序当前执行的指令的地址(也可以说保存下一条指令的所在存储单元的地址&#xff0…

在线教育的未来:SpringBoot技术实现

1系统概述 1.1 研究背景 随着计算机技术的发展以及计算机网络的逐渐普及,互联网成为人们查找信息的重要场所,二十一世纪是信息的时代,所以信息的管理显得特别重要。因此,使用计算机来管理微服务在线教育系统的相关信息成为必然。开…