Hunuan-DiT代码阅读

一 整体架构

该模型是以SD为基础的文生图模型,具体扩散模型原理参考https://zhouyifan.net/2023/07/07/20230330-diffusion-model/,代码地址https://github.com/Tencent/HunyuanDiT,这里介绍 Full-parameter Training

二 输入数据处理

这里主要包括图像和文本数据输入处理

2.1 图像处理

这里代码参考 hydit/data_loader/arrow_load_stream.py,生成1024*1024的图片,对于输入图片进行random_crop,之后包括随机水平翻转,转tensor,以及Normalize(减均值0.5, 除以标准差0.5,为什么是这个,是因为通过PIL Image读图之后转到tensor范围是0-1之间,不是opencv读出来像素值在0-255之间),得到最终image( B ∗ 3 ∗ 1024 ∗ 1024 B*3*1024*1024 B310241024

2.2 文本处理

输入的文本,通过BertTokenizer,进行映射,同时补齐长度到77,不够的补0,同时生成相应的attention_mask;同时还有T5TokenizerFast,对于T5的输入,会随机小于uncond_p_t5(目前给出的设置uncond_p_t5=5),输入为空,否则为文本输入,补齐长度256,同时生成相应的attention_mask

2.3 图像编码

对于输入图像,采用VAE encoder 进行编码,生成隐空间特征latents( B ∗ 4 ∗ 128 ∗ 128 B*4*128*128 B4128128,就是输入8倍下采样,计算过程latents = vae.encode(image).latent_dist.sample().mul_(vae_scaling_factor),具体VAE相关后续补充)

2.4 文本编码

包括两个部分,一个是CLIP的text编码,采用bert layer,生成encoder_hidden_states( B ∗ 77 ∗ 1024 B*77*1024 B771024);第二部分是mT5的text编码,生成encoder_hidden_states_t5( B ∗ 256 ∗ 2048 B*256*2048 B2562048

2.5 位置编码

这里是采用根据预设的分辨率,提前生成好的位置编码,这里采用ROPE,生成cos_cis_img, sin_cis_img (分别都是 4096 ∗ 88 4096*88 409688)

最终生成图像编码latents,文本编码(encoder_hidden_states以及对应的attention_mask,encoder_hidden_states_t5以及对应的attention_mask),以及位置编码cos_cis_img, sin_cis_img

三 DIT模型

3.1 add noise过程

  • 根据上一步的输出latents,作为x_start,随机选取一个time step,根据q_sample,得到增加噪声之后的输出x_t(具体公式参考如下,x0对应x_start,xt对应x_t)
    在这里插入图片描述

3.2 HunYuanDiT模型训练过程

  • 对于输入的文本编码,包括text_states( B ∗ 77 ∗ 1024 B*77*1024 B771024),text_states_t5( B ∗ 256 ∗ 2048 B*256*2048 B2562048)以及相应的attention_mask,对于text_states_t5通过Linear+Silu+Linear,转成 B ∗ 256 ∗ 1024 B*256*1024 B2561024,然后对着两个进行concat,得到text_states( B ∗ 333 ∗ 1024 B*333*1024 B3331024),对于attention_mask也concat得到clip_t5_mask( B ∗ 333 B*333 B333);这里会生成一个可学习的text_embedding_padding特征( B ∗ 333 ∗ 1024 B*333*1024 B3331024),对于clip_t5_mask中通过补0得到的特征全部替换成text_embedding_padding特征
  • 对于输入time step 先走timestep_embedding(就是sinusoidal编码),然后通过Linear+Silu+Linear得到最终t ( B ∗ 1408 B*1408 B1408)
  • 对于输入x(就是上一步的x_t),通过PatchEmbed(就是VIT前面对图像进行patch),得到x( B ∗ 4096 ∗ 1408 , 4096 是 64 ∗ 64 B*4096*1408,4096是64*64 B4096140840966464
  • 对于text_states_t5( B ∗ 256 ∗ 2048 B*256*2048 B2562048),添加一个AttentionPool模块,就是对于输入在256维度上,进行mean,当成query,然后将输入和query concat一起得到257维,作为key和value,(其中query,key,value都添加位置编码)做multi_head_attention,得到最终输出extra_vec( B ∗ 1024 B*1024 B1024
  • 对于extra_vec 通过Linear+Silu+Linear得到( B ∗ 1408 B*1408 B1408),然后与通过time step得到的t相加,得到c( B ∗ 1408 B*1408 B1408,作为所有extra_vectors)

3.2.1 进入Dit Block

一共40个block,前面0到18个block的生成输入,中间19,20作为middle block,剩余的block会增加一个前面19个block输出的结果作为skip

3.2.1.1 前面0到18共19个block
  • 前面一共19个block的过程,输入x( B ∗ 4096 ∗ 1408 B*4096*1408 B40961408),c( B ∗ 1408 B*1408 B1408),text_states( B ∗ 333 ∗ 1024 B*333*1024 B3331024),位置编码freqs_cis_img (cos_cis_img, sin_cis_img,分别都是 B ∗ 4096 ∗ 88 B*4096*88 B409688
HunYuanDiTBlock((norm1): FP32_Layernorm((1408,), eps=1e-06, elementwise_affine=True)(attn1): FlashSelfMHAModified((Wqkv): Linear(in_features=1408, out_features=4224, bias=True)(q_norm): LayerNorm((88,), eps=1e-06, elementwise_affine=True)(k_norm): LayerNorm((88,), eps=1e-06, elementwise_affine=True)(inner_attn): FlashSelfAttention((drop): Dropout(p=0.0, inplace=False))(out_proj): Linear(in_features=1408, out_features=1408, bias=True)(proj_drop): Dropout(p=0.0, inplace=False))(norm2): FP32_Layernorm((1408,), eps=1e-06, elementwise_affine=True)(mlp): Mlp((fc1): Linear(in_features=1408, out_features=6144, bias=True)(act): GELU(approximate='tanh')(drop1): Dropout(p=0, inplace=False)(norm): Identity()(fc2): Linear(in_features=6144, out_features=1408, bias=True)(drop2): Dropout(p=0, inplace=False))(default_modulation): Sequential((0): FP32_SiLU()(1): Linear(in_features=1408, out_features=1408, bias=True))(attn2): FlashCrossMHAModified((q_proj): Linear(in_features=1408, out_features=1408, bias=True)(kv_proj): Linear(in_features=1024, out_features=2816, bias=True)(q_norm): LayerNorm((88,), eps=1e-06, elementwise_affine=True)(k_norm): LayerNorm((88,), eps=1e-06, elementwise_affine=True)(inner_attn): FlashCrossAttention((drop): Dropout(p=0.0, inplace=False))(out_proj): Linear(in_features=1408, out_features=1408, bias=True)(proj_drop): Dropout(p=0.0, inplace=False))(norm3): FP32_Layernorm((1408,), eps=1e-06, elementwise_affine=True)
)
  • 对于c 通过default_modulation,得到shift_msa( B ∗ 4096 ∗ 1408 B*4096*1408 B40961408),与经过norm1之后的x进行相加作为attn1的输入(就是Flash Self Attention)
  • 将attn1的输出与原始的x进行残差相加,在经过norm3,与text_states一起作为attn2的输入(就是Flash Cross Attention)
  • 在将经过残差相加之后的x与attn2的输出在进行残差相加,作为输入,走FFN,即先经过norm2,在经过mlp,之后与输入残差相加
3.2.1.2 第19和20 middle block
  • 中间第19 和 20 两个block作为middle block,方式和上面一样
3.2.1.3 后面21到39共19个block
  • 从第21个block开始,增加一个输入,例如第21个block,会将第18个block的输出作为输入
  (skip_norm): FP32_Layernorm((2816,), eps=1e-06, elementwise_affine=True)(skip_linear): Linear(in_features=2816, out_features=1408, bias=True)
  • 就是对于新的输入skip,将skip与x进行concat之后,经过skip norm,然后在经过skip linear,得到输出x,剩余步骤与前面一样

3.2.2 最后FInal layer处理

  • 输入x和c,x是上面所有dit block的输出,c是上面的extra_vectors;对于c先进行SILU+Linear,得到( B ∗ 2816 B*2816 B2816),并彩分成shift 和 scale(分别为 B ∗ 1408 B*1408 B1408),最终通过x * (1 + scale.unsqueeze(1)) + shift.unsqueeze(1),然后通过Linear,得到最终输出x( B ∗ 4096 ∗ 32 B*4096*32 B409632),然后通过转换得到输出imgs ( B ∗ 8 ∗ 128 ∗ 128 B*8*128*128 B8128128

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.mzph.cn/bicheng/55997.shtml

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

netdata保姆级面板介绍

netdata保姆级面板介绍 基本介绍部署流程下载安装指令选择设置KSM为什么要启用 KSM?如何启用 KSM?验证 KSM 是否启用注意事项 检查端口启动状态 netdata和grafana的区别NetdataGrafananetdata各指标介绍总览system overview栏仪表盘1. CPU2. Load3. Disk…

3.使用条件语句编写存储过程(3/10)

引言 在现代数据库管理系统中,存储过程扮演着至关重要的角色。它们是一组为了执行特定任务而编写的SQL语句,这些语句被保存在数据库中,可以被重复调用。存储过程不仅可以提高数据库操作的效率,还可以增强数据的安全性和一致性。此…

ChatGPT进行翻译

1.建立客户端 from openai import OpenAI client OpenAI(base_url"https://api.chatanywhere.tech/v1" )2.建立相关函数 def get_openai_response(client, prompt, model"gpt-3.5-turbo"):response client.chat.completions.create(modelmodel,message…

自己写一个FTP客户端程序的过程

前言 以前在一个项目中遇到了内外网分离的问题,内网和外网不能直接通信,项目中外网的机器需要从内网拿数据,因此中间配置了一台FTP服务器,内网产生的数据形成文件,然后传到中间FTP服务器上,外网定时去FTP服…

RPA技术的定义与原理

RPA(Robotic Process Automation)即机器人流程自动化,是一种利用软件机器人或机器人工具来自动执行重复性、规则性和可预测性的业务流程的技术。以下是对RPA技术的详细介绍: 一、RPA技术的定义与原理 RPA技术通过模拟人工操作&a…

微信小程序15天

UniApp(Vue3组合式API)和微信小程序15天学习计划 第1天:开发环境配置和基础知识 UniApp和微信小程序概述及对比安装并配置HBuilderX(UniApp)和微信开发者工具创建第一个UniApp Vue3项目和微信小程序项目了解两个平台的项目结构差异配置外部浏览器和各种小程序模拟…

【redis-06】redis的stream流实现消息中间件

redis系列整体栏目 内容链接地址【一】redis基本数据类型和使用场景https://zhenghuisheng.blog.csdn.net/article/details/142406325【二】redis的持久化机制和原理https://zhenghuisheng.blog.csdn.net/article/details/142441756【三】redis缓存穿透、缓存击穿、缓存雪崩htt…

关于Linux查看系统及版本信息的命令lsb_release命令以及Centos7中将redis服务写入systemctl服务

一、关于Linux查看系统及版本信息的命令lsb_release命令 linux查看系统是centos还是ubuntu,之前一直使用uname -a以及cat /etc/issue。但在某个服务器上发些这些都不行。有一个更好用的命令:lsb_release -a。如执行时提示-bash: lsb_release: 未找到命令…

Vscode+Pycharm+Vue.js+WEUI+django火锅(三)理解Vue

新创建的Vue项目里面很多文件,对于新手,老老实实做一下了解。 1.框架逻辑 框架的逻辑都是相通的,花点时间理一下就清晰了。 2.文件目录及文件 创建好的vue项目下,主要的文件和文件夹要先认识一下,并与框架逻辑对应起…

云原生化 - 监控(简约版)

要在程序中暴露指标,并符合 Prometheus 和 Kubernetes 的规范,可以按照以下步骤进行: 1. 选择合适的库 根据你的编程语言选择适合的 Prometheus 客户端库。例如: Go: github.com/prometheus/client_golangJava: io.prometheus:…

计算机毕业设计 校内跑腿业务系统的设计与实现 Java实战项目 附源码+文档+视频讲解

博主介绍:✌从事软件开发10年之余,专注于Java技术领域、Python人工智能及数据挖掘、小程序项目开发和Android项目开发等。CSDN、掘金、华为云、InfoQ、阿里云等平台优质作者✌ 🍅文末获取源码联系🍅 👇🏻 精…

dayu_widgets-简介

前言: 越来越多的人开始使用python来做GUI程序,市面上却很少有好的UI控件。即使有也是走的商业收费协议,不敢使用,一个不小心就收到法律传票。 一、原始开源项目: 偶然在GitHub上发现了这个博主的开源项目。https://github.com/phenom-films…

YOLO11改进|SPPF篇|引入YOLOv9提出的SPPELAN模块

目录 一、【SPPELAN】模块1.1【SPPELAN】模块介绍1.2【SPPELAN】核心代码 二、添加【SPPELAN】模块2.1STEP12.2STEP22.3STEP32.4STEP4 三、yaml文件与运行3.1yaml文件3.2运行成功截图 一、【SPPELAN】模块 1.1【SPPELAN】模块介绍 下图是【SPPELAN】的结构图,让我们…

OOOPS:零样本实现360度开放全景分割,已开源 | ECCV‘24

全景图像捕捉360的视场(FoV),包含了对场景理解至关重要的全向空间信息。然而,获取足够的训练用密集标注全景图不仅成本高昂,而且在封闭词汇设置下训练模型时也受到应用限制。为了解决这个问题,论文定义了一…

软件测试面试题大全

什么是软件测试? 答案:软件测试是一系列活动,旨在评估软件产品的质量,并验证它是否满足规定的需求。它包括执行程序或系统以识别任何缺陷、问题或错误,并确保软件产品符合用户期望。 软件测试的目的是什么&#xff1f…

环球资源网 海外 globalsource reese84 分析

声明: 本文章中所有内容仅供学习交流使用,不用于其他任何目的,抓包内容、敏感网址、数据接口等均已做脱敏处理,严禁用于商业用途和非法用途,否则由此产生的一切后果均与作者无关! 有相关问题请第一时间头像私信联系我删…

Graph知识图谱融入向量数据库,带来RAG效果飞升

01. 前言 随着大型语言模型(LLMs)在各种应用中的广泛使用,如何提升其回答的准确性和相关性成为一个关键问题。检索增强生成(RAG)技术通过整合外部知识库,为LLMs提供了额外的背景信息,有效地改…

使用激光跟踪仪提升码垛机器人精度

标题1.背景 码垛机器人是一种用于工业自动化的机器人,专门设计用来将物品按照一定的顺序和结构堆叠起来,通常用于仓库、物流中心和生产线上,它们可以自动执行重复的、高强度的搬运和堆垛任务。 图1 码垛机器人 传统调整码垛机器人的方法&a…

【重学 MySQL】四十六、创建表的方式

【重学 MySQL】四十六、创建表的方式 使用CREATE TABLE语句创建表使用CREATE TABLE LIKE语句创建表使用CREATE TABLE AS SELECT语句创建表使用CREATE TABLE SELECT语句创建表并从另一个表中选取数据(与CREATE TABLE AS SELECT类似)使用CREATE TEMPORARY …

maven指定模块快速打包idea插件Quick Maven Package

问题背景描述 在实际开发项目中,我们的maven项目结构可能不是单一maven项目结构,项目一般会用parent方式将各个项目进行规范; 随着组件的数量增加,就会引入一个问题:我们只想打包某一个修改后的组件A时就变得很不方便…