ColossalAI Open-Sora 1.1 项目技术报告 (视频生成)

项目信息

  • 项目地址:https://github.com/hpcaitech/Open-Sora
  • 技术报告:
    • Open-Sora 1:https://github.com/hpcaitech/Open-Sora/blob/main/docs/report_01.md
    • Open-Sora 1.1:https://github.com/hpcaitech/Open-Sora/blob/main/docs/report_02.md
  • 项目介绍:
    • Open-Sora 是潞晨科技 (ColossalAI) 团队实现的一个致力于高效生产高质量视频的开源项目,旨在让所有人都能够访问先进的视频生成技术。该项目遵循开源原则,不仅使视频生成技术的访问民主化,还提供了一个简化和用户友好的平台,以简化视频制作的复杂性。Open-Sora 的目标是在内容创作领域激发创新、创造力和包容性。
    • 目前发布了两个版本
      • Open-Sora 1.0:生成 512x512 的 2s 视频
      • Open-Sora 1.1:生成 2s~15s, 144p to 720p, any aspect ratio,支持 text-to-video, image-to-video, video-to-video, infinite time generation 等模式

Open-Sora 1.1 效果

请添加图片描述

Open-Sora 1.1 技术报告

在 Open-Sora 1.1 版本中,训练了一个700M的模型,使用了 10M 的数据(相较于Open-Sora 1.0使用的40万数据)以及更好的 STDiT 架构。实现了 sora 报告中提到的以下功能:

  • 可变的时长、分辨率、纵横比(采样灵活性、改进的框架和构图)
  • 用图像和视频提示(动画图像、扩展生成的视频、视频编辑、连接视频)
  • 图像生成能力

为实现这一目标,在预训练阶段使用了多任务学习。对于扩散模型,使用不同采样时间步的训练已经是一种多任务学习。进一步将这一理念扩展到多分辨率、纵横比、帧长、帧率以及不同的图像和视频条件生成的掩码策略。模型在 0 到 15 秒,144p 到 720p,各种纵横比的视频上进行训练。尽管由于训练 FLOP 的限制,时间一致性的质量不是很高,但仍能看到模型的潜力。

模型架构修改

对原始 ST-DiT 进行了以下修改,以提高训练稳定性和性能(ST-DiT-2):

  • 用于时间注意力的 Rope 嵌入:借鉴 LLM 的最佳实践,将正弦位置编码更改为 Rope 嵌入,用于时间注意力,因为它也是一种序列预测任务。
  • AdaIN 和 Layernorm 用于时间注意力:用AdaIN和Layernorm包装时间注意力,就像空间注意力一样,以稳定训练。
  • 带有 RMSNorm 的 QK 标准化:借鉴 SD3,将 QK 标准化应用于所有注意力,以提高半精度训练的稳定性。
  • 动态输入大小支持和视频信息条件:为了支持多分辨率、纵横比和帧率训练,使 ST-DiT-2 接受任何输入大小,并自动调整位置嵌入。参考 PixArt-alpha 的理念,根据视频的高度、宽度、纵横比、帧长和帧率进行条件设置。
  • 将 T5 的 tokens 从 120 扩展到 200:通常 caption 不超过 200 个 tokens,发现模型可以很好地处理较长的文本。

支持多时间/分辨率/纵横比/帧率训练

如 sora 报告中所述,使用原始视频的分辨率、纵横比和长度进行训练可以增加采样灵活性并改进框架和构图。找到三种实现这一目标的方法:

  • NaViT:通过掩码支持同一批次内的动态大小,效率损失较小。然而,系统实现有点复杂,可能无法从优化的内核(如 flash-attention)中受益。
  • 填充(FiT,Open-Sora-Plan):通过填充支持同一批次内的动态大小。然而,将不同分辨率填充到同一大小效率不高。
  • 桶(SDXL,PixArt):通过分桶支持不同批次内的动态大小,但同一批次内的大小必须相同,并且只能应用固定数量的大小。在同一批次中使用相同大小,不需要实现复杂的掩码或填充。

为了实现的简便性,选择了桶方法。预定义了一些固定的分辨率,并将不同的样本分配到不同的桶中。分桶的担忧如下,但在我们的案例中,这些担忧并不大。

  1. 桶大小的限制
    桶大小限制为固定数量:首先,在实际应用中,只有少数纵横比(如9:16,3:4)和分辨率(如 240p,1080p)常用。其次,发现训练好的模型可以很好地泛化到未见过的分辨率。
  2. 在每个批次中的大小相同,打破了独立同分布假设
    由于使用多台 GPU,不同 GPU 上的本地批次大小不同。没有观察到由于此问题导致的显著性能下降。
  3. 样本可能不足以填充每个桶,分布可能有偏差
    首先,数据集足够大,当本地批次大小不太大时,可以填充每个桶。其次,应该分析数据在不同大小上的分布,并据此定义桶大小。第三,不平衡的分布没有显著影响训练过程。
  4. 不同分辨率和帧长度可能具有不同的处理速度
    不同于 PixArt,仅处理相似分辨率(类似 token 数量)的纵横比,需要考虑不同分辨率和帧长度的处理速度。可以使用 bucket_config 定义每个桶的批次大小,以确保处理速度相似。

桶分配策略
如图所示,桶是(分辨率、帧数、纵横比)的三元组。为不同的分辨率提供了预定义的纵横比,涵盖了大多数常见的视频纵横比。在每个训练周期之前,打乱数据集并将样本分配到不同的桶中,如图所示。将一个样本放入比视频小的最大分辨率和帧长的桶中。

考虑到计算资源有限,为每个(分辨率、帧数)引入了两个属性:keep_prob 和 batch_size,以减少计算成本并实现多阶段训练。具体来说,高分辨率视频将以 1-keep_prob 的概率降采样到较低分辨率,每个桶的批量大小为 batch_size。通过这种方式,可以控制不同桶中的样本数量,并通过搜索合适的批量大小来平衡 GPU 负载。

Masked DiT 作为图像/视频生成模型

Transformer 可以轻松扩展以支持图像到图像视频到视频的任务。提出了一种掩码策略来支持图像和视频的条件生成。掩码策略如下图所示。
掩码策略

通常情况下,对于图像/视频生成条件,去除作为条件的帧的掩码。在 ST-DiT 前向过程中,未掩码的帧将具有时间步 0,而其他帧保持不变(t)。发现直接将该策略应用于已训练的模型会产生较差的结果,因为扩散模型在训练过程中没有学习在一个样本中处理不同的时间步。

受 UL2 启发,在训练过程中引入随机掩码策略。具体来说,在训练过程中随机去除帧的掩码,包括去除第一帧、前 k 帧、最后一帧、最后 k 帧、前后 k 帧、随机帧等。基于 Open-Sora 1.0,使用 50% 的概率应用掩码,发现模型在 10k 步内可以学习处理图像条件(30%的概率效果较差),同时文本到视频性能略有下降。因此,对于 Open-Sora 1.1,从头开始预训练模型并应用掩码策略。

下面提供了一个推理中使用的掩码策略配置示例。一个五元组提供了定义掩码策略的极大灵活性。通过对生成的帧进行条件化,可以自回归地生成无限帧(尽管误差会传播)。
通过 5 元组实现上述的多种掩码策略

数据收集与处理管道

在 Open-Sora 1.0 中发现,数据数量和质量对于训练一个好的模型至关重要,因此致力于扩大数据集。首先,参考 SVD 创建了一个自动化管道,包括场景切割、字幕生成、各种评分和过滤,以及数据集管理脚本和规范。
数据处理管道

计划使用 panda-70M 和其他数据来训练模型,约为 3000万+ 数据。然而,发现磁盘 IO 在同时进行训练和数据处理时是一个瓶颈。因此,只能准备1000万数据集,并未经过我们构建的所有处理管道。最终使用了 970 万视频 + 260 万图像进行预训练,560k 视频+160 万图像进行微调。预训练数据集统计如下。

  • 图像文本 token 长度统计(使用 T5 分词器)
    图像文本token(使用T5分词器)

  • 视频文本 token 长度统计(使用 T5 分词器)。直接使用 panda 的短字幕进行训练,并为其他数据集生成字幕。生成的字幕通常少于 200 个 token。
    视频文本 token 长度统计

  • 视频时长:
    在这里插入图片描述

训练细节

由于计算资源有限,需要仔细监控训练过程,如果推测模型学习效果不佳,则更改训练策略,因为没有计算消融研究的资源。因此,Open-Sora 1.1 的训练包括多次更改,因此未应用 ema。

  • 首先,从 Pixart-alpha-1024 检查点开始,使用不同分辨率的图像进行 6000 步微调。发现模型很容易适应生成不同分辨率的图像。使用 SpeeDiT(iddpm-speed) 加速扩散训练。
  • [阶段1] 然后,在 64 台 H800 GPU 上预训练模型 24000 步,耗时 4 天。尽管模型看到的样本数量相同,但发现模型相比于较小的批量大小学习得更慢。推测在早期阶段,步数对训练更重要。大多数视频分辨率为240p,配置与 stage2.py 相似。视频效果良好,但模型对时间知识了解不多。使用 10% 的掩码率。
  • [阶段1] 为增加步数,切换到较小的批量大小,不使用梯度检查点 (gradient-checkpointing)。此时还增加了 fps 条件。训练了 40000 步,耗时 2 天。大多数视频分辨率为144p,配置文件为 stage1.py。使用较低分辨率,因为在 Open-Sora 1.0 中发现模型可以在较低分辨率下学习时间知识。
  • [阶段1] 发现模型无法很好地学习长视频,生成结果存在噪声,推测是 Open-Sora 1.0 训练中发现的半精度问题。因此,采用 QK 标准化以稳定训练。类似于 SD3,发现模型很快适应了 QK 标准化。还将 iddpm-speed 切换到 iddpm,并将掩码率增加到 25%,因为发现图像条件学习效果不好。训练了 17000 步,耗时 14 小时。大多数视频分辨率为 144p,配置文件为 stage1.py。第一阶段训练持续约一周,总步数为 81000 步。
  • [阶段2] 切换到更高分辨率,大多数视频分辨率为 240p 和 480p。在所有预训练数据上训练了 22000 步,耗时一天。
  • [阶段3] 切换到更高分辨率,大多数视频分辨率为480p和720p(stage3.py)。在高质量数据上训练了 4000 步,耗时一天。发现加载上一阶段的优化器状态可以帮助模型更快学习。

总结,Open-Sora 1.1 的训练在 64 台 H800 GPU 上大约需要 9 天。

限制与未来工作

在接近 Sora 的复制时,发现当前模型存在许多限制,这些限制指向未来的工作。

  • 生成失败:发现许多情况下(尤其是 token 数量大或内容复杂时),模型无法生成场景。可能是时间注意力崩溃,已在代码中发现潜在错误,正在努力修复。此外,将在下一版本中增加模型大小和训练数据以提高生成质量。
  • 生成的内容有噪声和不连贯:发现生成的模型有时有噪声且不连贯,尤其是长视频。认为问题在于未使用时间 VAE。Pixart-Sigma 发现适应新 VAE 很简单,计划在下一版本中为模型开发时间 VAE。
  • 缺乏时间一致性:发现模型无法生成时间一致性高的视频。认为问题在于训练 FLOP 不足。计划收集更多数据并继续训练模型以提高时间一致性。
  • 人类视频生成质量差:发现模型无法生成高质量的人类视频。认为问题在于缺乏人类数据。计划收集更多人类数据并继续训练模型以提高人类视频生成质量。
  • 美学评分低:发现模型的美学评分不高。问题在于缺乏美学评分过滤,因 IO 瓶颈未进行。计划通过美学评分过滤数据并微调模型以提高美学评分。
  • 长视频生成质量较差:发现使用相同提示,长视频质量较差。这意味着图像质量未能同样适应不同长度的序列。

总结

  • 在 open-sora 1.1 上的更新还是挺多的,在多分辨率、变长方面的支持已经是能实现 sora 的基本功能了,对于开源领域的视频生成是个很大的进展
  • 目前训练量并不大,64 台 H800 GPU 上大约需要 9 天。下个版本模型感觉有很大可能能修复一些 bug 后 scale 到更大的集群规模,可以期待效果会有显著提升。优质数据后续可能对开源领域来说瓶颈很大。
  • 依然没有上时序压缩的 VAE,时序压缩有限。不过看起来下个版本应该要用了。
  • 分 stage 训练时加载上一阶段的优化器状态有助于训练应该是基操
  • 长视频效果不好看起来可能是和 frame 分布有关,frame 多的视频看数据分布本来就占比较少

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.mzph.cn/news/841025.shtml

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

swift中json和字典Dict或者数组相互转换,JSONSerialization的强大使用

在Swift中,你可以使用JSONSerialization类将JSON字符串转换为字典。要将 Swift 字典转换为 JSON 字符串,我们可以使用JSONSerialization类的data(withJSONObject:options:)方法。这个方法将字典转换为二进制数据,然后我们可以使用String(data…

Day23:Leetcode:530.二叉搜索树的最小绝对差 + 501.二叉搜索树中的众数 + 236. 二叉树的最近公共祖先

LeetCode:530.二叉搜索树的最小绝对差 问题描述 解决方案: 1.思路 中序遍历 2.代码实现 class Solution {int pre;int ans;public int getMinimumDifference(TreeNode root) {ans Integer.MAX_VALUE;pre -1;dfs(root);return ans;}public void d…

Unity射击游戏开发教程:(26)创建绕圈跑的效果

unity游戏 在本文中,我将介绍如何为敌人创建圆周运动。gif 中显示的确切行为是敌人沿着屏幕向下移动,直到到达某个点,一旦到达该点,它就会绕圈移动。

从浮点数定义到FP8: AI模型中不同的数据类型

背景:AI模型中不同的数据类型对硬件算力和内存的需求是不同的,为了提高模型在硬件平台的吞吐量,减少数据通信带宽需求,往往倾向于将高位宽数据运算转向较低位宽的数据运算。本文通过重新回顾计算机中整数和浮点数的定义&#xff0…

HCIP-Datacom-ARST自选题库__ISIS简答【3道题】

1.IS-1S是链路状态路由协议,便用SPF算法进行路由计算。某园区同时部署了IPv4和IPV6井运行IS-IS实现网络的互联互通,如图所示,该网络IPv4和IPV6开销相同,R1和R4只支持IPV4。缺省情况下,计算形成的IPv6最短路径树中&…

python数据分析——字符串和文本数据2

参考资料:活用pandas库 1、字符串格式化 (1)格式化字符串 要格式化字符串,需要编写一个带有特殊占位符的字符串,并在字符串上调用format方法向占位符插入值。 # 案例1 varflesh wound s"Its just a {}" p…

solidworks画螺母学习笔记

螺母 单位mm 六边形 直径16mm,水平约束,内圆直径10mm 拉伸 选择两侧对称,厚度7mm 拉伸切除 画相切圆 切除深度7mm,反向切除 拔模角度45 镜像切除 倒角 直径1mm 异形孔向导 螺纹线 偏移打勾,距离为2mm…

java:static关键字用法

在静态方法中不能访问类的非静态成员变量和非静态方法, 因为非静态成员变量和非静态方法都必须依赖于具体的对象才能被调用。 从上面代码里看出: 1.静态方法不能调用非静态成员变量。静态方法test2()中调用非静态成员变量address,编译失败…

从容应对亿级QPS访问,Redis还缺少什么?no.29

众所周知,Redis 在线上实际运行时,面对海量数据、高并发访问,会遇到不少问题,需要进行针对性扩展及优化。本课时,我会结合微博在使用 Redis 中遇到的问题,来分析如何在生产环境下对 Redis 进行扩展改造&…

算法金 | Dask,一个超强的 python 库

本文来源公众号“算法金”,仅用于学术分享,侵权删,干货满满。 原文链接:Dask,一个超强的 python 库 1 Dask 概览 在数据科学和大数据处理的领域,高效处理海量数据一直是一项挑战。 为了应对这一挑战&am…

滑动菜单栏

效果如下&#xff1a; NavigationView 新建menu布局,表示菜单栏的选项 <menu xmlns:android"http://schemas.android.com/apk/res/android"> <group android:checkableBehavior"single"> <item android:id"id/navCall" android…

海外CDN加速方式

随着全球化经济的进一步推进和互联网时代的到来&#xff0c;给对外贸易行业带来了巨大的商机&#xff0c;众多传统的贸易公司都纷纷建立起自已的外贸网站或服务站点等各种信息化平台&#xff0c; 相当多的贸易公司也从他们所构建的平台中得到了很高的利益&#xff0c;然而由于当…

医疗科技:UWB模块为智能医疗设备带来的变革

随着医疗科技的不断发展和人们健康意识的提高&#xff0c;智能医疗设备的应用越来越广泛。超宽带&#xff08;UWB&#xff09;技术作为一种新兴的定位技术&#xff0c;正在引领着智能医疗设备的变革。UWB模块作为UWB技术的核心组成部分&#xff0c;在智能医疗设备中发挥着越来越…

抖音运营_打造高流量的抖音账号

目录 一 账号定位 行业定位 用户定位 内容定位 二 账号人设 我是谁? 我的优势 我的差异化 三 创建账号 名字 头像 简介 四 抖音养号 为什么要养号&#xff1f; 抖音快速养号 正确注册抖音账号 一机一卡一号 实名认证 正确填写账号信息 养号期间的操作 五…

韵搜坊 -- Elastic Stack快速入门

文章目录 现有问题Elastic Stack介绍&#xff08;一套技术栈&#xff09;安装ES安装KibanaElasticsearch概念倒排索引Mapping分词器IK分词器&#xff08;ES插件&#xff09;打分机制 ES的几种调用方式restful api调用&#xff08;http 请求&#xff09;kibana devtools客户端调…

程序员做推广?我劝你别干

关注卢松松&#xff0c;会经常给你分享一些我的经验和观点。 这是卢松松会员专区&#xff0c;一位会员朋友的咨询&#xff0c;如果你也有自研产品&#xff0c;但不知道如何推广&#xff0c;一定要阅读本文!强烈建议收藏关注&#xff0c;因为你关注的人&#xff0c;决定你看到的…

【机器学习300问】98、卷积神经网络中的卷积核到底有什么用?以边缘检测为例说明其意义。

卷积核是用于从输入数据中提取特征的关键工具。卷积核的设计直接关系到网络能够识别和学习的特征类型。本文让我以边缘检测为例&#xff0c;带大家深入理解卷积核的作用。 一、卷积核的作用 卷积核&#xff0c;又称为过滤器&#xff0c;本质上是一个小的矩阵&#xff0c;其元素…

微信小程序毕业设计-智慧旅游平台系统项目开发实战(附源码+演示视频+LW)

大家好&#xff01;我是程序猿老A&#xff0c;感谢您阅读本文&#xff0c;欢迎一键三连哦。 &#x1f49e;当前专栏&#xff1a;微信小程序毕业设计 精彩专栏推荐&#x1f447;&#x1f3fb;&#x1f447;&#x1f3fb;&#x1f447;&#x1f3fb; &#x1f380; Python毕业设计…

【算法】二分算法——山脉数组的峰顶索引

该题用二分算法解“山脉数组的峰顶索引”&#xff0c;有需要借鉴即可。 目录 1.题目2.总结 1.题目 题目链接&#xff1a;LINK 暴力求解很简单&#xff0c;这里不再提及。 这个可以根据峰顶值分为两部分&#xff0c;因而具有“二段性”&#xff0c;可以用二分算法&#xff0c…

默认路由实现两个网段互通实验

默认路由实现两个网段互通实验 **默认路由&#xff1a;**是一种特殊的静态路由&#xff0c;当路由表中与数据包目的地址没有匹配的表项时&#xff0c;数据包将根据默认路由条目进行转发。默认路由在某些时候是非常有效的&#xff0c;例如在末梢网络中&#xff0c;默认路由可以…