7. 探究模型参数与显存的关系以及不同精度造成的影响

这篇文章将探讨两个重点:

  • 模型参数与显存(GPU 内存)之间的关系
  • 不同精度的导入方式,以及它们对显存和性能的影响

理解这些概念会让你在模型的选择上更加游刃有余。

文章目录

  • 模型参数与显存的关系
    • 模型参数量与内存占用
    • GPU 显存需求
  • 不同精度的导入方式及其影响
    • 常见的数值精度格式
    • 对显存占用的影响
  • 精度的权衡与选择
    • 准确性 vs. 性能
    • 何时选择何种精度
    • 硬件兼容性
  • 实际应用中的精度技巧
    • 使用 FP16 精度
    • 使用 BF16 精度
    • 使用 INT8 量化
    • 消除警告
  • 实践示例
    • 对比不同精度下的显存占用
    • 常见问题及解决方案
      • 问题一:`RuntimeError: Failed to import transformers.models.gpt2.modeling_gpt2 because of the following error`
      • 问题二:`TypeError: dispatch_model() got an unexpected keyword argument 'offload_index'`
  • 总结
  • 参考文献

模型参数与显存的关系

模型参数量与内存占用

神经网络模型由多个层组成,每一层都包含权重(weight)和偏置(bias),这些统称为模型参数。而模型的参数量一般直接影响它的学习和表示能力。

模型大小计算公式
模型大小(字节) = 参数数量 × 每个参数的字节数 \text{模型大小(字节)} = \text{参数数量} \times \text{每个参数的字节数} 模型大小(字节)=参数数量×每个参数的字节数

示例:

对于一个拥有 10 亿(1,000,000,000) 参数的模型,使用 32 位浮点数(float32 表示,每个参数占用 4 字节,即:
模型大小 = 1 , 000 , 000 , 000 × 4 字节 = 4 GB \text{模型大小} = 1,000,000,000 \times 4 \text{字节} = 4 \text{GB} 模型大小=1,000,000,000×4字节=4GB

具体来讲,以meta-llama/Meta-Llama-3.1-70B-Instruct 这个拥有 700 亿(70B) 参数的大模型为例,我们仅考虑模型参数,它的显存需求就已经超过了大多数消费级 GPU(如 RTX 4090 最高 48G):
70 × 1 0 9 × 4 字节 = 280 GB 70 \times 10^9 \times 4 \text{字节} = 280 \text{GB} 70×109×4字节=280GB

GPU 显存需求

而在实际部署模型时,GPU 不仅需要容纳模型参数,还需要处理其他数据,这意味着更大的显存占用量。其中包括:

  • 模型参数:模型的权重和偏置。
  • 优化器状态(仅训练时):如动量(momentum)和梯度平方和等信息。
  • 中间激活值:前向传播和反向传播过程中产生的中间结果。
  • 批量大小(Batch Size):一次处理的数据样本数量。

推理与训练的区别

  • 推理阶段:仅需加载模型参数和少量的中间激活值。
  • 训练阶段:需要额外存储梯度和优化器状态,因此显存需求更大。

不同精度的导入方式及其影响

为了降低显存占用,我们可以使用不同的数值精度格式来存储模型参数,这些精度格式在内存使用和计算性能上各有优劣。

常见的数值精度格式

  • FP32(32 位浮点数):标准精度,每个参数占用 4 字节
  • FP16(16 位浮点数):半精度浮点数,每个参数占用 2 字节
  • BF16(16 位脑浮点数):与 FP16 类似,但具有更大的指数范围,适用于深度学习。
  • INT8(8 位整数):低精度整数,每个参数占用 1 字节
  • 量化格式4 位 或更低,用于特殊的量化算法,进一步减少内存占用。

对显存占用的影响

使用更低的精度可以显著减少模型的内存占用:

  • FP16 相对于 FP32:内存占用减半。
  • INT8 相对于 FP32:内存占用减少到原来的四分之一。

示例

对于一个 70B 参数的模型:

  • FP32 精度:280 GB 显存。
  • FP16/BF16 精度:140 GB 显存。
  • INT8 精度:70 GB 显存。

注意:实际显存占用还受到其他因素影响,如 CUDA 上下文、中间激活值和显存碎片等,因此不会严格按照理论值减半或减少四分之一。对于较小的模型,差距可能不会那么显著。

精度的权衡与选择

准确性 vs. 性能

  • 高精度(FP32)

    • 优点:更高的数值稳定性和模型准确性。
    • 缺点:占用更多显存,计算速度较慢
  • 低精度(FP16/INT8)

    • 优点:占用更少的显存,计算速度更快
    • 缺点:可能引入数值误差,影响模型性能。

何时选择何种精度

  • FP32

    • 适用于训练小型模型或对数值精度要求较高的任务。
  • FP16/BF16

    • 适用于训练大型模型,利用混合精度(Mixed Precision)来节省显存并加速计算。
  • INT8

    • 主要用于推理阶段,尤其是在显存资源有限的情况下部署超大模型

硬件兼容性

  • FP16 支持

    • 大多数现代 NVIDIA GPU(如 RTX 20 系列及以上)支持 FP16。
  • BF16 支持

    • 需要 NVIDIA A100、H100 等数据中心级别的 GPU,或最新的 RTX 40 系列 GPU。
  • INT8 支持

    • 需要特殊的库(如 bitsandbytes)和硬件支持。

实际应用中的精度技巧

使用 FP16 精度

在训练中启用混合精度

PyTorch 提供了 torch.cuda.amp 模块,可以方便地实现混合精度训练,加速计算并降低显存占用。

import torch
from torch import nn, optim
from torch.cuda.amp import GradScaler, autocast# MPS (Metal Performance Shaders) for Apple Silicon GPUs
device = torch.device("cuda" if torch.cuda.is_available() else "mps" if torch.backends.mps.is_available() else "cpu"model = nn.Sequential(...)  # 定义模型
model.to(device)
optimizer = optim.Adam(model.parameters(), lr=1e-3)
scaler = GradScaler()for data, labels in dataloader:data = data.to(device)labels = labels.to(device)optimizer.zero_grad()with autocast():outputs = model(data)loss = criterion(outputs, labels)scaler.scale(loss).backward()scaler.step(optimizer)scaler.update()

在推理中使用 FP16

model.half()  # 将模型转换为 FP16
model.to(device)  # 将模型移动到合适的设备
inputs = inputs.half().to('cuda')
outputs = model(inputs)

使用 BF16 精度

BF16(Brain Floating Point)具有与 FP32 相同的指数位数,减小了溢出和下溢的风险。

model = model.to(torch.bfloat16).to(device)
inputs = inputs.to(torch.bfloat16).to(device)
outputs = model(inputs)

注意:并非所有 GPU 都支持 BF16,你需要检查硬件兼容性。

使用 INT8 量化

安装 bitsandbytes

pip install bitsandbytes

使用 bitsandbytes 库实现 INT8 量化

from transformers import AutoModelForCausalLM
import bitsandbytes as bnbmodel_name = 'gpt2-large'model = AutoModelForCausalLM.from_pretrained(model_name,load_in_8bit=True,device_map='auto'
)

消除警告

在加载模型时,可能会遇到以下警告:

The load_in_4bit and load_in_8bit arguments are deprecated and will be removed in the future versions. Please, pass a BitsAndBytesConfig object in quantization_config argument instead.

解决方法:

使用 BitsAndBytesConfig 对象来配置量化参数。

from transformers import AutoModelForCausalLM, BitsAndBytesConfigbnb_config = BitsAndBytesConfig(load_in_8bit=True)model = AutoModelForCausalLM.from_pretrained(model_name,quantization_config=bnb_config,device_map='auto'
)

实践示例

对比不同精度下的显存占用

加载模型并查看显存占用

以下代码示例展示了在不同精度下加载 gpt2-large 模型时的显存占用情况,并进行简单的推理测试。gpt2-large 大约有 812M(8.12 亿)= 0.812B 个参数。

import os
import gc
import torch
from transformers import AutoModelForCausalLM, BitsAndBytesConfig
import bitsandbytes as bnbdef load_model_and_measure_memory(precision, model_name, device):if precision == 'fp32':model = AutoModelForCausalLM.from_pretrained(model_name).to(device)elif precision == 'fp16':model = AutoModelForCausalLM.from_pretrained(model_name,torch_dtype=torch.float16,low_cpu_mem_usage=True).to(device)elif precision == 'int8':bnb_config = BitsAndBytesConfig(load_in_8bit=True)model = AutoModelForCausalLM.from_pretrained(model_name,quantization_config=bnb_config,device_map='auto')else:raise ValueError("Unsupported precision")# 确保所有 CUDA 操作完成torch.cuda.synchronize()mem_allocated = torch.cuda.memory_allocated(device) / 1e9print(f"Precision: {precision}, Memory Allocated after loading model: {mem_allocated:.2f} GB")# 删除模型并清理缓存del modelgc.collect()torch.cuda.empty_cache()device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model_name = 'gpt2-large'for precision in ['fp32', 'fp16', 'int8']:print(f"\n--- Loading model with precision: {precision} ---")load_model_and_measure_memory(precision, model_name, device)

示例输出

--- Loading model with precision: fp32 ---
Precision: fp32, Memory Allocated after loading model: 3.21 GB--- Loading model with precision: fp16 ---
Precision: fp16, Memory Allocated after loading model: 1.60 GB--- Loading model with precision: int8 ---
Precision: int8, Memory Allocated after loading model: 0.89 GB

额外说明

  • torch.cuda.memory_allocated 仅测量由 PyTorch 当前进程分配的显存,不包括其他进程或系统预留的显存。

常见问题及解决方案

问题一:RuntimeError: Failed to import transformers.models.gpt2.modeling_gpt2 because of the following error

RuntimeError: Failed to import transformers.models.gpt2.modeling_gpt2 because of the following error (look up to see its traceback):
module ‘wandb.proto.wandb_internal_pb2’ has no attribute ‘Result’

解决方法

  • 卸载并重新安装 wandb

    pip uninstall wandb
    pip install wandb
    
  • 如果问题仍然存在,禁用 wandb

    import os
    os.environ["WANDB_DISABLED"] = "true"

问题二:TypeError: dispatch_model() got an unexpected keyword argument 'offload_index'

解决方法:

  • 检查 transformersaccelerate 库的版本:

    import transformers
    import accelerateprint(f"Transformers version: {transformers.__version__}")
    print(f"Accelerate version: {accelerate.__version__}")
    
  • 更新库:

    pip install --upgrade transformers accelerate
    

总结

现在你应该理解了模型参数与显存的关系,以及不同数值精度对显存和性能的影响,这不仅在实际应用中具有重要意义,也是面试中的常见考点,而且对于后续的学习同样很重要。毕竟看得懂代码在说什么,比当作黑箱要好得多。

最后的思考:

精度的降低意味着性能的妥协,在我过去的一些小型试验中,低精度下训练的性能还是一般都不如高精度。但,跑不跑的好是一回事,能不能跑又是另一回事,如果低显存能跑大模型,性能上的妥协也是完全可以接受的。

参考文献

  • PyTorch Mixed Precision Training
  • Transformers Documentation
  • bitsandbytes - Github

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.mzph.cn/diannao/53978.shtml

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

JMeter脚本开发

环境部署 Ubuntu系统 切换到root用户 sudo su 安装上传下载的命令 apt install lrzsz 切换文件目录 cd / 创建文件目录 mkdir java 切换到Java文件夹下 cd java 输入rz回车 选择jdk Linux文件上传 解压安装包 tar -zxvf jdktab键 新建数据库 运行sql文件 选择sql文件即…

基于51单片机的电饭锅控制系统proteus仿真

地址: https://pan.baidu.com/s/1CGyg6uPhFI0MeaBWwe_HAg 提取码:1234 仿真图: 芯片/模块的特点: AT89C52/AT89C51简介: AT89C52/AT89C51是一款经典的8位单片机,是意法半导体(STMicroelectro…

RedisTemplate操作ZSet的API

文章目录 ⛄概述⛄常见命令有⛄RedisTemplate API❄️❄️ 向集合中插入元素,并设置分数❄️❄️向集合中插入多个元素,并设置分数❄️❄️按照排名先后(从小到大)打印指定区间内的元素, -1为打印全部❄️❄️获得指定元素的分数❄️❄️返回集合内的成员个数❄️❄…

前端网络层性能优化

前言 在数字时代,速度已成为互联网体验的关键。用户对网页加载时间的容忍度越来越低,每一毫秒的延迟都可能导致用户的流失。根据谷歌的研究,页面加载时间超过3秒的网站,其跳出率会增加120%。在这个以用户为中心的网络世界里&…

Git换行符自动转换参数core.autocrlf的用法

core.autocrlf 是 Git 中用于控制换行符自动转换的配置选项。它有以下几个可能的值: 1. true 作用:在 checkin 时将 CRLF 转换为 LF,在 checkout 时将 LF 转换为 CRLF。适用场景:适用于 Windows 用户,希望在本地文件…

LineageOS刷机教程

版权归作者所有,如有转发,请注明文章出处:https://cyrus-studio.github.io/blog/ LineageOS 是一个基于 Android 开源项目(AOSP)的开源操作系统,主要由社区开发者维护。它起源于 CyanogenMod 项目&#xff…

10年Python程序员教你多平台采集10万+电商数据【附实例】

10万级电商数据采集需要注意什么? 在进行10万级电商数据采集时,有许多关键因素需要注意: 1. 采集平台覆盖:确保可以覆盖主流的电商平台,如淘宝、天猫、京东、拼多多等。 2. 数据字段覆盖:检查是否可以对平…

go 笔记

数据结构与 方法(增删改查) 安装goland,注意版本是2024.1.1,不是2024.2.1,软件下载地址也在链接中提供了 ‘go’ 不是内部或外部命令,也不是可运行的程序 或批处理文件。 在 Windows 搜索栏中输入“环境变量”&#…

架构理论碰撞:对比TOGAF、Zachman、DODAF和FEAF等主流架构框架

信息架构框架对比分析:选择适合企业的最佳方案 在企业数字化转型过程中,信息架构的设计与实施至关重要。成功的信息架构能够有效地支持业务流程优化,提升数据管理效率,推动技术创新。然而,不同的信息架构框架各有其独…

linux gcc 静态库的简单介绍

在 Linux 上,使用 GCC 编译器来创建和调用静态库时,涉及的实现原理和调用机制可以分为以下几个步骤: 1. 静态库的创建 静态库(通常以 .a 结尾)是由多个目标文件(.o 文件)打包在一起的归档文件…

判断线是否相交、判断点是否在线上、求线相交交点

先定义个点、线结构 typedef struct tagStruVertex {double x;double y;double distanceTo(const tagStruVertex& point) const{return sqrt((x - point.x) * (x - point.x) (y - point.y) * (y - point.y));}bool equal(const tagStruVertex& point) const{if (poin…

COMTRADE binary数据文件解析

一、COMTRADE 二进制文件的解析需要用到cfg文件中的配置信息,以及dat文件中的数据。 二、cfg文件 1、cfg文件整体配置 2、cfg文件实例 厂站名,记录装置,COMTRADE标准版本年号 SMARTSTATION,IED123,2013 总通道数,模拟通道编号&…

记录word转xml文件踩坑

word文件另存为xml文件后,xml文件乱码 解决方法: 1.用word打开.docx文件 2.另存为xml文件 3.点击工具 -> Web选项 -> 编码,选择UTF-8 4.点击确定 5.使用notpad打开xml文件 6.使用xml tool进行xml格式化即可。

uniapp小程序,使用腾讯地图获取定位

本篇文章分享一下在实际开发小程序时遇到的需要获取用户当前位置的问题,在小程序开发过程中经常使用到获取定位功能。uniapp官方也提供了相应的API供我们使用。 官网地址:uni.getLocation(OBJECT)) 官网获取位置的详细介绍这里就不再讲述了,大…

安宝特方案 | 医疗AR眼镜,重新定义远程会诊体验

【AR眼镜:重新定义远程会诊体验】 在快速发展的医疗领域,安宝特医疗AR眼镜以其尖端技术和创新功能,引领远程会诊的未来,致力于为为医生和患者带来更高效、精准和无缝的医疗体验。 探索安宝特医疗AR眼镜如何在医疗行业中引领新风潮…

视频推拉流/直播点播EasyDSS平台安装失败并报错“install mediaserver error”是什么原因?

TSINGSEE青犀视频推拉流/直播点播EasyDSS平台支持音视频采集、视频推拉流、播放H.265编码视频、存储、分发等视频能力服务,在应用场景中可实现视频直播、点播、转码、管理、录像、检索、时移回看等。此外,平台还支持用户自行上传视频文件,也可…

Gitbook 本地安装教程

Gitbook 本地安装教程 安装 node [nodejs的v10.21.0版本,下载地址:https://nodejs.org/dist/v10.21.0/node-v10.21.0-x64.msi] 其他版本有问题 npmnpm install -g gitbook-cligitbook init [初始化目录结构]gitbook build [编译]gitbook serve [运行] …

MongoDB日志级别

日志 查看当前的日志级别 根据你提供的 MongoDB 命令结果,命令 db.adminCommand({ getParameter: "logComponentVerbosity" }) 返回了 "ok" : 0,这意味着命令执行失败,没有成功获取到日志级别的配置信息。错误信息 &quo…

【项目一】基于pytest的自动化测试框架———解读requests模块

解读python的requests模块 什么是requests模块基础用法GET与POST的区别数据传递格式会话管理与持久性连接处理相应结果应对HTTPS证书验证错误处理与异常捕获 这篇blog主要聚焦如何使用 Python 中的 requests 模块来实现接口自动化测试。下面我介绍一下 requests 的常用方法、数…

【JAVA入门】Day45 - 压缩流 / 解压缩流

【JAVA入门】Day45 - 压缩流 / 解压缩流 文章目录 【JAVA入门】Day45 - 压缩流 / 解压缩流一、解压缩流二、压缩流 在文件传输过程中,文件体积比较大,传输较慢,因此我们发明了一种方法,把文件里的数据压缩到一种压缩文件中&#x…