BatchSize
显存占用:与batch_size呈线性关系,可理解为 M t o t a l = M f i x e d + B a t c h S i z e ∗ M p e r − s a m p l e M_{total}=M_{fixed}+BatchSize*M_{per-sample} Mtotal=Mfixed+BatchSize∗Mper−sample,其中 M f i x e d M_{fixed} Mfixed指的是模型本身固定占用的显存(由参数数量决定)和优化器状态(也由参数数量决定)
总训练时间:理论上与BatchSize无关(总数不变,单步训练时间增加,总步数减少),但实际中随BatchSize越大,总时间可能减少(硬件并行效率提升),直到显存或硬件并行能力达到瓶颈。
截断长度(输入序列分词后的最大长度,即每条样本被大模型读取的最大长度)
1. 显存占用
在大型语言模型(如 Transformer)中,显存占用主要与模型的激活值(Activations)有关,而激活值的大小受到输入序列长度(即截断长度)的直接影响。以下是逐步分析:
激活值的定义
激活值是指模型在正向传播过程中每一层计算出的中间结果,通常存储在显存中,以便反向传播时计算梯度。对于 Transformer 模型,激活值主要与注意力机制(Self-Attention)和前馈网络(Feed-Forward Network, FFN)的计算相关。
显存占用的组成
显存占用主要包括:
- 模型参数(权重和偏置):与模型规模(层数、隐藏维度)相关,与截断长度无关。
- 激活值:与输入序列长度(截断长度 L L L)、批次大小(batch size B B B)、隐藏维度(hidden size H H H)和层数( N N N)成正比。
- 梯度(训练时):与参数量和激活值大小相关。
对于激活值部分,显存占用主要来源于:
- 注意力机制:计算 Q ⋅ K T Q \cdot K^T Q⋅KT的注意力分数矩阵,尺寸为 ( B , L , L ) (B, L, L) (B,L,L),每层需要存储。
- 中间张量:如 V V V的加权和、前馈层的输出等。
数学表达式
假设: L L L:截断长度(序列长度), B B B:批次大小, H H H:隐藏维度, N N N:模型层数, P P P:浮点数精度(如 FP32 为 4 字节,FP16 为 2 字节)
激活值的显存占用近似为:
显存 激活值 ≈ N ⋅ B ⋅ L ⋅ H ⋅ P + N ⋅ B ⋅ L 2 ⋅ P \text{显存}_{\text{激活值}} \approx N \cdot B \cdot L \cdot H \cdot P + N \cdot B \cdot L^2 \cdot P 显存激活值≈N⋅B⋅L⋅H⋅P+N⋅B⋅L2⋅P
- 第一项 N ⋅ B ⋅ L ⋅ H ⋅ P N \cdot B \cdot L \cdot H \cdot P N⋅B⋅L⋅H⋅P:表示每层的线性张量(如 Q , K , V Q, K, V Q,K,V或 FFN 输出)的显存占用。
- 第二项 N ⋅ B ⋅ L 2 ⋅ P N \cdot B \cdot L^2 \cdot P N⋅B⋅L2⋅P:表示注意力分数矩阵的显存占用(仅在标准注意力机制中显著,若使用优化如 FlashAttention,则可能减少)。
结论:显存占用与截断长度 L L L呈线性( O ( L ) O(L) O(L))到二次方( O ( L 2 ) O(L^2) O(L2))的关系,具体取决于注意力机制的实现方式。
2. 训练时间
训练时间主要与计算量(FLOPs,浮点运算次数)和硬件并行能力有关,而截断长度会影响计算量。
计算量的组成
- 注意力机制:每层的计算量与 L 2 L^2 L2相关,因为需要计算 L × L L \times L L×L的注意力矩阵。
- 前馈网络:每层的计算量与 L L L线性相关,因为对每个 token 独立计算。
总计算量(FLOPs)近似为:
FLOPs ≈ N ⋅ B ⋅ ( 2 ⋅ L 2 ⋅ H + 4 ⋅ L ⋅ H 2 ) \text{FLOPs} \approx N \cdot B \cdot (2 \cdot L^2 \cdot H + 4 \cdot L \cdot H^2) FLOPs≈N⋅B⋅(2⋅L2⋅H+4⋅L⋅H2)
- 2 ⋅ L 2 ⋅ H 2 \cdot L^2 \cdot H 2⋅L2⋅H:注意力机制的矩阵乘法(如 Q ⋅ K T Q \cdot K^T Q⋅KT和 softmax ⋅ V \text{softmax} \cdot V softmax⋅V),
- 4 ⋅ L ⋅ H 2 4 \cdot L \cdot H^2 4⋅L⋅H2:前馈网络的计算(假设 FFN 隐藏层维度为 4 H 4H 4H)。
训练时间
训练时间与 FLOPs 成正比,同时受硬件并行能力(如 GPU 的计算核心数)影响。假设每秒浮点运算能力为 F GPU F_{\text{GPU}} FGPU(单位:FLOPs/s),则单次前向+反向传播的训练时间为:
时间 ≈ FLOPs F GPU ≈ N ⋅ B ⋅ ( 2 ⋅ L 2 ⋅ H + 4 ⋅ L ⋅ H 2 ) F GPU \text{时间} \approx \frac{\text{FLOPs}}{F_{\text{GPU}}} \approx \frac{N \cdot B \cdot (2 \cdot L^2 \cdot H + 4 \cdot L \cdot H^2)}{F_{\text{GPU}}} 时间≈FGPUFLOPs≈FGPUN⋅B⋅(2⋅L2⋅H+4⋅L⋅H2)
结论:训练时间与截断长度 L L L呈线性( O ( L ) O(L) O(L))到二次方( O ( L 2 ) O(L^2) O(L2))的关系,具体取决于注意力机制的计算占比。
3. 总结
- 显存占用:与 L L L呈 O ( L ) O(L) O(L)或 O ( L 2 ) O(L^2) O(L2)关系,取决于是否存储完整的注意力矩阵。
- 训练时间:与 L L L呈 O ( L ) O(L) O(L)到 O ( L 2 ) O(L^2) O(L2)关系,注意力机制的二次项通常更显著。
1
假设某模型大小为5GB,推理所需显存也为5GB,普通Lora微调(FP16)所需显存为5GB*2=10GB,8bit的QLora量化为5GB/2=2.5GB,4bit的QLora量化为5GB/4=1.25GB