LLM优化:开源星火13B显卡及内存占用优化

1. 背景

本qiang~这两天接了一个任务,部署几个开源的模型,并且将本地经过全量微调的模型与开源模型做一个效果对比。

部署的开源模型包括:星火13B,Baichuan2-13B, ChatGLM6B等

其他两个模型基于transformers架构封装,因此推理服务启动还是十分丝滑,但星火13B是基于Megatron-DeepSpeed框架实现,地址是:https://gitee.com/iflytekopensource/iFlytekSpark-13B,启动推理服务的过程中发现启动13B的显卡占用71G-78G,有些反直觉。

此文就是整理开源星火13B的显存及内存排查并优化的整理过程,至于哪家开源模型效果好,不在此文的讨论范围内。

2. 原因分析

直观上来说,13B的模型,数据类型为bf16,显卡占用大概在26G左右,但星火13B直接占用70G+,不可思议,怪不得网上关于星火开源模型的讨论少之又少,原因显而易见,这么大的显存占用只能用多卡或者A800等80G显卡才能适配。穷人家的孩子,哪有这么多余粮。

排查原因的过程中,少不了源码的调试与分析。在排查的过程中,启动推理服务的文件run_iFlytekSpark_text_generation.py中,model_provider方法是初始化模型并加载模型文件的方法。

def model_provider(pre_process=True, post_process=True):"""Build the model."""print_rank_0('building iFlytekSpark model ...')args = get_args()config = core_transformer_config_from_args(args)### 初始化星火模型model = iFlytekSparkModel(config,num_tokentypes=0,parallel_output=False,pre_process=pre_process,post_process=post_process,return_moe_loss=False)if args.from_pretrained is not None:assert os.path.exists(args.from_pretrained)ckpt_path = get_checkpoint_name(args.from_pretrained)print_rank_0('Loading from {} '.format(args.from_pretrained))# 模型加载权重文件state_dict = torch.load(ckpt_path, map_location=f"cuda:{torch.cuda.current_device()}")if 'module' in state_dict:state_dict = state_dict['module']model.load_state_dict(state_dict)return model

其中,加载权重文件可以看到,加载state_dict时,直接将权重文件加载到显卡中,而非加载至CPU,然后再执行to方法,转移到GPU。因此该处是一个潜在的优化点

再打入iFlytekSparkModel内部,词表Embedding层,线性转换层,等初始化weight时,也是直接将weight分配在GPU上运行。例如下例:

class RowParallelLinear(torch.nn.Module):def __init__(self, input_size: int, output_size: int, *,config: ModelParallelConfig,init_method: Callable,bias: bool = True,input_is_parallel: bool = False,stride: int = 1,keep_master_weight_for_test: bool = False,skip_bias_add: bool = False,moe=False, enable_expert_tensor_parallelism=False):super(RowParallelLinear, self).__init__()# .........if config.use_cpu_initialization:self.weight = Parameter(torch.empty(self.output_size,self.input_size_per_partition,dtype=config.params_dtype))if config.perform_initialization:self.master_weight = _initialize_affine_weight_cpu(self.weight, self.output_size, self.input_size,self.input_size_per_partition, 1, init_method,stride=stride, return_master_weight=keep_master_weight_for_test,params_dtype=config.params_dtype)else:# 默认按照启动sh命令,会走该分支self.weight = Parameter(torch.empty(self.output_size, self.input_size_per_partition,device=get_accelerator().current_device_name(), dtype=config.params_dtype))if config.perform_initialization:_initialize_affine_weight_gpu(self.weight, init_method,partition_dim=1, stride=stride)if bias:if config.use_cpu_initialization:self.bias = Parameter(torch.empty(self.output_size,dtype=config.params_dtype))else:# 默认按照启动sh命令,会走该分支self.bias = Parameter(torch.empty(self.output_size, device=get_accelerator().current_device_name(),dtype=config.params_dtype))setattr(self.bias, 'sequence_parallel', self.sequence_parallel)if config.perform_initialization:# Always initialize bias to zero.with torch.no_grad():self.bias.zero_()else:self.register_parameter('bias', None)

3. 优化方案

1. 模型初始化时,模型的Embedding,线性层的权重weight均直接加载至GPU,因此可以优化为先将这些weight加载至CPU

改进的方式也很简单,从上面的源码层面,可以看到,当增加参数” use_cpu_initialization”,将使用CPU进行初始化权重,因此只需要在启动推理服务的脚本中增加” --use-cpu-initialization”参数即可。

2. 加载模型文件时,直接加载至GPU,然后run_iFlytekSpark_text_generation.py中的get_model方法中,当模型加载完成后,会进行分配至GPU以及FP16的转换的操作。如下代码所示。

def get_model(model_provider_func, model_type=ModelType.encoder_or_decoder, wrap_with_ddp=True):"""Build the model."""args = get_args()args.model_type = model_type# ..........# GPU allocation.for model_module in model:model_module.to(get_accelerator().current_device_name())# Fp16 conversion.if args.fp16 or args.bf16:model = [Float16Module(model_module, args) for model_module in model]# .......return model

因此,优化的方式也很简单,可以优化为先加载至CPU,再运行get_model中的默认分配至GPU,加载完后,再使用垃圾回收机制清除CPU占用的内存即可

话不多说,优化后的代码如下:

def model_provider(pre_process=True, post_process=True):"""Build the model."""print_rank_0('building iFlytekSpark model ...')args = get_args()config = core_transformer_config_from_args(args)model = iFlytekSparkModel(config,num_tokentypes=0,parallel_output=False,pre_process=pre_process,post_process=post_process,return_moe_loss=False)if args.from_pretrained is not None:print(args.from_pretrained)assert os.path.exists(args.from_pretrained)ckpt_path = get_checkpoint_name(args.from_pretrained)print_rank_0('Loading from {} '.format(args.from_pretrained))# state_dict = torch.load(ckpt_path, map_location=f"cuda:{torch.cuda.current_device()}")# CPU进行加载state_dict = torch.load(ckpt_path, map_location=f"cpu")if 'module' in state_dict:state_dict = state_dict['module']model.load_state_dict(state_dict)# 加载完成,删除state_dict,并垃圾回收del state_dictgc.collect()torch.cuda.empty_cache()return model

4. 效果对比

(1) 优化前的显卡占用: 71.5G

(2) 优化前的内存占用: 虚拟内存占用94.5G

(3) 优化后的显卡占用: 26G

(4) 优化后的内存占用: 43.1G

5. 总结

一句话足矣~

本文主要是针对开源星火13B的显存及内存占用过大的一个代码优化。核心思想是使用CPU预加载模型,再转换至GPU。

后期如有遇到此类问题,可以借鉴之~

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.mzph.cn/bicheng/4646.shtml

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

表单提交出现问题却没有报错

最近搞毕设提交表单传给后台总是出现错误,有时候可以运行成功,有时候运行不了但是没有报错,以为是jQuery导入的问题尝试换了jQuery的其他导入方式没有解决,后来发现前端页面的表单要防止默认操作!!&#xf…

qt学习篇---C++基础学习

本学习笔记学习下面视频总结,感兴趣可以去学习。讲的很详细 【北京迅为】嵌入式学习之QT学习篇_哔哩哔哩_bilibilihttps://www.bilibili.com/video/BV1tp4y1i7EJ/?spm_id_from333.337.search-card.all.click&vd_source8827cc0da16223b9f2ad8ae7111de9e2 目录 C…

PDCA循环:持续精进的工具

文章目录 一、什么是PDCA二、PDCA的应用场景三、PDCA在信息系统项目管理中的应用 一、什么是PDCA PDCA循环是由美国质量管理专家沃特阿曼德休哈特(Walter A. Shewhart)在20世纪30年代提出的,最初用于制造业的质量管理。休哈特博士在构想PDCA…

【C++题解】1418. 求一个5位数的各个位之和

问题:1418. 求一个5位数的各个位之和 类型:基本运算、拆位求解 题目描述: 从键盘读入一个 5 位的正整数,请求出这个 5 位数的各个位之和。 输入: 一个 5 位的正整数 n 。 输出: 这个 5 位数的各个位之…

Aiseesoft Blu-ray Player for Mac:蓝光播放器

Aiseesoft Blu-ray Player for Mac是一款功能强大且易于使用的蓝光播放器,专为Mac用户打造。它以其卓越的性能和简洁的操作界面,为用户带来了全新的高清蓝光播放体验。 Aiseesoft Blu-ray Player for Mac v6.6.50激活版下载 这款软件支持播放任何高质量的…

ArcGIS Pro3.0软件破解版安装教程

软件名称:ArcGIS Pro 3.0 安装环境:Windows 软件大小:7.3GB 硬件要求:CPU2GHz,内存4G(或更高) 百度云下载链接 : https://pan.baidu.com/s/1CXy1MSwdQXdVnJoV2X422A 提 取 码 :r0w1 教学内…

AI图书推荐:ChatGPT写论文的流程与策略

论文一直是任何学术学位的顶峰。它展示了学生在研究领域的兴趣和专业知识。撰写论文也是一个学习经验,为学术工作以及专业研究角色做好准备。但是,论文工作总是艰苦的,通常是充满乐趣和创造性的,但有时也是乏味和无聊的。生成式人…

正点原子[第二期]Linux之ARM(MX6U)裸机篇学习笔记-6.4

前言: 本文是根据哔哩哔哩网站上“正点原子[第二期]Linux之ARM(MX6U)裸机篇”视频的学习笔记,在这里会记录下正点原子 I.MX6ULL 开发板的配套视频教程所作的实验和学习笔记内容。本文大量引用了正点原子教学视频和链接中的内容。…

采用前后端分离Vue,Ant-Design技术开发的(手麻系统成品源码)适用于三甲医院

开发环境 技术架构:前后端分离 开发语言:C#.net6.0 开发工具:vs2022,vscode 前端框架:Vue,Ant-Design 后端框架:百小僧开源框架 数 据 库:sqlserver2019 系统特性 麻zui、护理、PACU等围术期业务全覆…

FreeRTOS学习——FreeRTOS队列(上)

本篇文章记录我学习FreeRTOS队列的相关知识,主要包括队列简介、队列的结构体、队列创建等知识。 队列是为了任务与任务、任务与中断之间的通信而准备的,可以在任务与任务、任务与中断之间传递消息,队列中可以存储有限的、大小固定的数据项目。…

Android 在attrs.xml添加属性时出现 Found item Attr/****** more than one time

Android 在attrs.xml添加属性时出现 Found item Attr/****** more than one time 问题描述解决办法方式一方式二 小结 问题描述 在Android应用开发过程中,经常需要自定义控件,并且定义控件的属性,方便灵活的修改控件的显示样式,提…

IT廉连看——UniApp——样式绑定

IT廉连看——UniApp——样式绑定 一、样式绑定 两种添加样式的方法: 1、第一种写法 写一个class属性,然后将css样式写在style中。 2、第二种写法 直接把style写在class后面 添加一些效果:字体大小 查看效果 证明这样添加样式是没有问题的…

【提示学习论文】PMF:Efficient Multimodal Fusion via Interactive Prompting论文原理

Efficient Multimodal Fusion via Interactive Prompting(CVPR2023) 基于交互式提示的高效多模态融合方法减少针对下游任务微调模型的计算成本提出模块化多模态融合架构,促进不同模态之间的相互交互将普通提示分为三种类型,仅在单…

websocket 单点通信,广播通信

Websocket协议是对http的改进,可以实现client 与 server之间的双向通信; websocket连接一旦建立就始终保持,直到client或server 中断连接,弥补了http无法保持长连接的不足,方便了客户端应用与服务器之间实时通信。 参…

大数据005-hadoop003-了解MR及Java的简单实现

了解MapReduce MapReduce过程分为两个阶段:map阶段、reduce阶段。每个阶段搜键-值对作为输入和输出。 要执行一个MR任务,需要完成map、reduce函数的代码开发。 Hellow World 【Hadoop权威指南】中的以分析气象数据为例,找到每年的最高气温。…

Jenkins持续化集成

优质博文:IT-BLOG-CN 工作过程如下环境准备 开发人员提交代码>jenkins获取代码>调用单元测试>打包>发布 环境准备Jenkins的安装 Tomcat、Maven、Git或Svn、Jdk Jenkins的安装 1、官网下载war :http://Jenkins-ci.org/ 2、tomcat-users.…

NTFS文件权限管理

实验环境 windows server 2016 实验要求 实验步骤 1、 新建文件 2、打开文件夹的属性->安全->高级 3、禁用继承 4、添加组或用户 技术资料: 常用软件: 手机端项目: 电脑端项目: 公司制度: 销售资源&#xff…

【Scala---01】Scala『 Scala简介 | 函数式编程简介 | Scala VS Java | 安装与部署』

文章目录 1. Scala简介2. 函数式编程简介3. Scala VS Java4. 安装与部署 1. Scala简介 Scala是由于Spark的流行而兴起的。Scala是高级语言,Scala底层使用的是Java,可以看做是对Java的进一步封装,更加简洁,代码量是Java的一半。 因…

JAVA读取从WPS在Excel中嵌入的图片资源

读取从WPS在Excel中嵌入的图片资源 引言 许多数据文件中可能包含嵌入式图片,这些图片对于数据分析和可视化非常重要。然而,从 WPS 在 Excel 中读取这些图片可能会有一些技术挑战。在本文中,我将展示如何从 WPS Excel 文件中读取嵌入的图片&am…

海外三大AI图片生成器对比(Stable Diffusion、Midjourney、DALL·E 3)

Stable Diffusion DreamStudio 是Stable Diffusion 的官方网页,价格便宜,对图片的操作性强,但同时编辑页面不太直观,对使用者的要求较高。 与 DALLE 和 Midjourney 不同,Stable Diffusion 是开源的。这也意味着&…