1字节=8bit
16float=2字节
模型后面的xxb的单位是字节。
1b 字节≈ 0.93G,这个是以8bit运行,4bit减半,16bit(float)加倍,32bit(double)炒鸡加倍。
剩下的是小头,需要参数计算:
- s:最大序列长度(输入中的令牌数量)
- b:批大小
- h:模型的隐藏维度
- a:注意头的数量
对于整个层
总内存需求总计为11sbh + 5as²b(来自注意力块)+ 19sbh(来自MLP块)+ 4sbh(来自LN)
。
每层激活内存消耗= 34 sbh + 5as²b
小头一般远小于10G。
所以比如llama7b,只需要7*0.93≈9G,再加10,内存19G就可以(实际会更少,因为小头远低于10G),注意这个是以8bit运行,4bit减半,16bit(float)加倍,32bit(double)炒鸡加倍。
感谢博客:https://developer.aliyun.com/article/1496103
感谢github: