一、目录
1 大模型训练需要多少算力?
2. 大模型训练需要多少显存?
3. 大模型需要多少数据量训练?
4. 训练时间估计
5. epoch 选择经验
6. 浮点计算性能测试
二、实现
1 大模型训练需要多少算力?
训练总算力(Flops)= 6 * 模型的参数量 * 训练数据的token 数
参考:https://blog.csdn.net/qq_29788741/article/details/135411259?utm_medium=distribute.pc_relevant.none-task-blog-2~default~baidujs_baidulandingword~default-0-135411259-blog-134679887.235^v43^control&spm=1001.2101.3001.4242.1&utm_relevant_index=1
模型的参数量和训练数据的 token 数之间也有个比例关系,这也很容易理解,只要把模型想象成数据的压缩版本就行了,压缩比总是有极限的。模型的参数量太小,就吃不下训练数据里面所有的知识;模型的参数量如果大于训练数据的 token 数,那又浪费,还容易导致 over-fitting。
- 大模型训练需要多少显存?
内存分配: 1.模型参数 2. 梯度 3.优化器参数。
chatglm3 6B为例:全精度模型参数是float32类型:1b(10亿)个模型参数,约占用4G显存(实际大小:10^9 * 4 / 1024^3 ~= 3.725 GB),那么LLaMA的参数量为6b,那么加载模型参数需要的显存为:3.725 * 6 ~= 22 GB
5. 训练显存计算大小=模型参数占用+梯度占用+优化器占用+CUDA kernel占用LLaMA-6B为例:模型参数:等于参数量*每个参数所需内存。对于 fp32,需要 6B*4 bytes = 24GB内存梯度:同上,等于参数量*每个梯度参数所需内存。对于 fp32,需要 6B*4 bytes = 24GB内存优化器参数:不同的优化器所储存的参数量不同。对于常用的 AdamW 来说,需要储存两倍的模型参数(用来储存一阶和二阶momentum)。fp32 的 AdamW 需要 6B*8 bytes = 48 GB除此之外,CUDA kernel也会占据一些 RAM,大概 1.3GB 左右,查看方式如下。综上,模型部分大致需要 24+24GB+48GB+1.3GB = 97GB 左右
- 大模型需要多少数据量训练?
2022 年 9 月,DeepMind(Chinchilla 论文)中提出Hoffman scaling laws:表明每个参数需要大约 20 个文本token进行训练。比如一个7B的模型需要140B token,若每个token使用int32(四字节)进行编码的话,就是560GB的数据。
训练模型参数量与训练数据量的统计
参考:https://zhuanlan.zhihu.com/p/667363516
https://zhuanlan.zhihu.com/p/636812912?utm_id=0
4. 训练时间估计
理想清空下,训练总算力(Flops)= 6 * 模型的参数量 * 训练数据的token 数
一般GPU 利用率在0.3 到 0.55 之间。 GPU峰值:每张卡每秒实际做的浮点运算数,一般在理论上限的50%以上,现在衡量计算速度的标准是TFLOPS。
浮点运算峰值计算能力 = 每个SM的CUDA核心数 * 每个CUDA核心的时钟频率 * 每个CUDA核心的浮点运算能力。
- epoch 选择经验
1 如果你有百万数据量的话,一个epoch足够了。如果只有几千上万的数据量,可以尝试1~3个epoch
2 深度神经网络以及最新的视觉Transformer模型训练数百个epoch是很常见的操作,不过大型语言模型通常指训练1个epoch。研究人员对维基百科的数据进行了一项相关实验,相比C4来说他们认为维基百科是高质量的,不过事实证明,当维基百科数据在训练期间重复多个epoch后发生了退化现象。
- 浮点计算性能测试
import torch
from torch.utils import benchmark
typ = torch.float16 #数据精度 #FP16 精度
#typ = torch.float32 #数据精度 #tf32
#typ = torch.float64 #数据精度 #FP64
n = 1024 * 16
a = torch.randn(n, n).type(typ).cuda()
b = torch.randn(n, n).type(typ).cuda()
t = benchmark.Timer( stmt='a @ b',globals={'a': a, 'b': b})
x = t.timeit(50)
print(2*n**3 / x.median /1e12)
V100 测试结果:
理论值网址:https://www.nvidia.cn/data-center/v100/