GPT实战系列-大模型训练和预测,如何加速、降低显存

GPT实战系列-大模型训练和预测,如何加速、降低显存

不做特别处理,深度学习默认参数精度为浮点32位精度(FP32)。大模型参数庞大,10-1000B级别,如果不注意优化,既耗费大量的显卡资源,也耗费大量的训练时间,AI算法中心的训练的投入都给英伟达送钱去了。有的地方32位精度没有太大必要,这就是浮点精度和量化的动力来源。

大模型的训练和预测过程中,如何加快训练速度?如何降低显存占用?
有哪些简单,快速上手的方法?

文章目录

  • GPT实战系列-大模型训练和预测,如何加速、降低显存
    • 混合精度
      • 精度数位表示
      • 转换流程
    • 量化
      • 量化训练
      • 量化推理

混合精度

混合精度训练(mixed precision training)是一种加速深度学习训练的技术。其主要思想是在精度降低可忍受的范围内,使用较低精度的浮点数(如FP16)来表示神经网络中的权重和激活值,从而减少内存使用和计算开销,进而加速训练过程。

混合精度训练的实现可以分为以下几个步骤:

  1. 将FP32的权重转换为FP16格式,然后进行前向计算,得到FP32的损失(loss)。
  2. 使用FP16计算梯度。
  3. 将梯度转换为FP32格式,并将其更新到权重上。

由于FP16精度较低,表示的数值范围小,可能会导致精度损失,因此在混合精度训练中,需要使用一些技巧来保持模型的精确性。例如,可以使用梯度缩放(GradScaler)来控制梯度的大小,以避免梯度下降过快而影响模型的准确性。

精度数位表示

  • FP32:单精度浮点数,使用32位二进制数表示,其中1位表示符号位,8位表示指数位,23位表示尾数位,能够表示的数值范围为 ± 3.4 × 1 0 38 ±3.4×10^{38} ±3.4×1038
  • FP16:半精度浮点数,使用16位二进制数表示,其中1位表示符号位,5位表示指数位,10位表示尾数位,能够表示的数值范围为 ± 2 15 ±2^{15} ±215
  • FP64:双精度浮点数,使用64位二进制数表示,其中1位表示符号位,11位表示指数位,52位表示尾数位,能够表示的数值范围为 ± 1.8 × 1 0 308 ±1.8×10^{308} ±1.8×10308
  • INT8:8位整数,其中1位表示符号位,能够表示的数值范围为 $ -128到127$。
  • INT4:4位整数,其中1位表示符号位,能够表示的数值范围为 − 8 到 7 -8到7 87

在这里插入图片描述

  • 转换流程

混合精度训练的流程如下:

  1. 将FP32的权重转换为FP16格式,然后进行前向计算,得到FP32的损失(loss)。
  2. 使用FP16计算梯度。
  3. 将梯度转换为FP32格式,并将其更新到权重上。

在训练过程中,使用autocast将输入和输出转换为FP16格式,使用GradScaler对损失值进行缩放,以避免梯度下降过快而影响模型的准确性。

量化

量化是一种通过整型数值表示浮点的计算方式,减少数字表示的位数来减小模型存储量和计算量的方法。在深度学习中,通常使用32位浮点数来表示权重和激活值。但是,这种精度可能会导致计算和存储的开销非常高。因此,量化使用更短的整数表示权重和激活值,从而减少内存和计算开销。

量化使用整型数值,避免使用浮点处理,加速计算过程,同时也减少用于表示数字或值的比特数,降低存储的技术。将通过将权重存储在低精度数据类型中,来降低模型参数的训练、预测计算过程和模型和中间缓存的存储空间。由于量化减少了模型大小,因此它有利于在CPU或嵌入式系统等资源受限的设备上部署模型。

一种常用的方法是将模型权重从原始的16位浮点值量化为精度较低的8位整数值

8bit 参数量化

GPT,Baichuan2,ChatGLM3等大模型LLM已经展示出色的能力,但是它需要大量的CPU和内存,其中使用一种方法可以使用量化来压缩这些模型,以减少内存占用并加速计算推理,并且尽量保持模型精度性能。


在量化过程中,可以使用两种方法:动态量化和静态量化

  • 动态量化在运行时收集数据,并根据数据动态地量化模型。
  • 静态量化在训练过程中对模型进行量化,并在推理时应用量化。

量化会导致模型精确度下降,因为更低的精度可能会导致舍入误差。因此,在量化期间,需要进行一些技巧来保持模型的准确程度,例如:对权重进行缩放或使用动态范围量化。

同时,在量化模型之前,需要对模型进行测试,确保精确度可以接受。另外,不是所有的模型都可以被量化,只有支持动态量化的模型才可以使用该方法进行量化

例如:load_in_8bit=True

 from transformers import AutoTokenizer, AutoModel model = AutoModel.from_pretrained("THUDM/chatglm3-6b",revision='v0.1.0',load_in_8bit=True,trust_remote_code=True,device_map="auto")

总的来说,量化是一种非常有用的方法,可以减少模型的存储和计算开销,提高模型在设备上的执行效率。

量化训练

在深度学习中,量化是一种通过减少数字表示的位数来减小模型存储量和计算量的方法。在使用混合精度训练时,可以将模型权重和梯度从FP32转换为FP16,以节省内存和加速训练。同样的思路,量化训练可以将激活值转换为更短的整数,从而减少内存和计算开销

PyTorch中提供一些量化训练的工具和API,例如QAT(量化感知训练),使用动态范围量化等。其中,使用Adam8bit进行量化训练是一种方法。

量化推理

使用load_in_8bit方法可以实现模型的量化。该方法可以将模型权重和激活值量化为8位整数,从而减少内存和计算开销。具体实现方法如下:

import torch
from transformers import AutoModel# 加载模型
model = AutoModel.from_pretrained('bert-base-uncased',load_in_8bit=True)

需要注意的是,使用load_in_8bit方法量化模型可能会导致模型精确度下降。另外,不是所有的模型都可以被量化,只有支持动态量化的模型才可以使用该方法进行量化。

点个赞 点个赞 点个赞

觉得有用 收藏 收藏 收藏

End


GPT专栏文章:
GPT实战系列-ChatGLM2部署Ubuntu+Cuda11+显存24G实战方案

GPT实战系列-Baichuan2本地化部署实战方案

GPT实战系列-ChatGLM3本地部署CUDA11+1080Ti+显卡24G实战方案

GPT实战系列-如何用自己数据微调ChatGLM2模型训练

GPT实战系列-GPT训练的Pretraining,SFT,Reward Modeling,RLHF

GPT实战系列-P-Tuning本地化训练ChatGLM2等LLM模型,到底做了什么?(二)

GPT实战系列-P-Tuning本地化训练ChatGLM2等LLM模型,到底做了什么?(一)

GPT实战系列-ChatGLM2模型的微调训练参数解读

GPT实战系列-GPT训练的Pretraining,SFT,Reward Modeling,RLHF


决策引擎专栏:
Falcon构建轻量级的REST API服务

决策引擎-利用Drools实现简单防火墙策略

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.mzph.cn/news/188176.shtml

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

Python应用:利用matplotlib画学生成绩分布饼图

1. 题目 给定一组学生成绩:[85, 92, 78, 65, 95, 88, 72, 60, 98, 45, 100, 46, 23, 88, 67, 89, 67, 88, 99],现在评分等级为优(90-100)、良(70-89)、及格(60-69)、不及格&#xff…

玩转大数据4:大数据的崛起与应用领域探索

图片来源网络 引言 在当今数字化时代,大数据正以前所未有的速度和规模崛起。大数据的出现不仅改变了企业和组织的经营模式,也对我们的社会生活带来了深刻的影响。Java作为一种广泛使用的编程语言,在大数据领域发挥着重要的作用。本文将重点…

工程师每日刷题 -4

文章目录 1、深度学习2、算法与数据结构2.1、暴力解法2.2、滑动窗口法 3、编程基础 1、深度学习 问题:CNN的本质和优势? CNN 本质上是一个多层感知机 (MLP),其成功的原因关键在于它所采用的【稀疏连接】(局部感受)和…

【带头学C++】----- 九、类和对象 ---- 9.3 析构函数

9.3 析构函数 9.3.1 如何定义析构函数 函数名和类名称相同,在函数名前加 ~ ,没有返回值类型,没有函数形参。 (不能被重载) 当对象生命周期结束的时候,系统自动调用析构函数(析构函数会先清理对象占用内存空间存放的…

【openssl】Window系统如何编译openssl

本文主要记录如何编译出windows版本的openss的lib库 1.下载openssl,获得openssl-master.zip。 a.可以通过github(网址在下方)上下载最新的代码、今天是2023.12.1我用的master版本,下载之后恭喜大侠获得《openssl-master.zip》 …

快递物流模拟系统

快递物流模拟系统 文章目录 快递物流模拟系统一、目的二、技术实现:三、网页功能具体介绍 一、目的 调用百度地图 JavaScript API 创建的简单的基站物流GPS定位与监控系统的示例网页 二、技术实现: 使用百度地图 JavaScript API 版本 2.0。利用 BMap …

Webpack——Webpack简介

1、什么是Webpack? Webpack是一个开源的JavaScript模块打包工具,其最核心的功能是解决模块之间的依赖,把各个模块按照特定的规则和顺序组织在一起,最终合并为一个JS文件(有时会有多个,这里讨论的只是最基本…

SQL Sever 基础知识 - 数据排序

SQL Sever 基础知识 - 二 、数据排序 二 、对数据进行排序第1节 ORDER BY 子句简介第2节 ORDER BY 子句示例2.1 按一列升序对结果集进行排序2.2 按一列降序对结果集进行排序2.3 按多列对结果集排序2.4 按多列对结果集不同排序2.5 按不在选择列表中的列对结果集进行排序2.6 按表…

人才缺口达150万!云计算凭什么这么火?

《中国互联网发展报告2022》指出,2021年,我国云计算市场规模达到3229亿元,增速为54.4%。未来5年内,我国云计算产业将面临高达近150万的人才缺口,预计未来市场仍将保持30%的增速。与此同时,随着大数据、人工…

【每日OJ —— KY11 二叉树遍历】

每日OJ —— KY11 二叉树遍历 1.题目:KY11 二叉树遍历2.解法2.1.算法讲解2.2.代码实现2.3.提交通过展示 1.题目:KY11 二叉树遍历 2.解法 2.1.算法讲解 1.首先需要创建二叉树结构。 2.其次,根据题目根据题目遍历的顺序要求来实现构建二叉树的…

代码demo-内部订单批量投料

为了简化用户操作,开发内部订单批量投料功能 用户可以批量上传,或者选择对应的物料,输入库位和内部订单号后进行过账操作 对用户选择的内部订单做校验,内部订单是否正确 内部订单的公司是否和工厂对应的公司一致等等 下面展示…

Sui与阿联酋科技孵化器Hub71合作支持生态项目建设,扩大全球影响力

近日,总部位于阿联酋( United Arab Emirates ,UAE)的科技孵化器Hub71宣布与Mysten Labs合作,将支持Sui上的新项目。通过本次合作,孵化项目的开发者们不仅可以获得Mysten Labs的技术专业知识和支持&#xff…

Flutter基础开发

参考:http://bbs.itying.com/topic/5cdb83b7fac8b00944a7a0c3 参考:https://www.bilibili.com/video/BV1S4411E7LY?p34&spm_id_frompageDriver 1.使用镜像 由于在国内访问Flutter有时可能会受到限制,Flutter官方为中国开发者搭建了临时镜像,大家可以…

SpringBoot整合MyBatis-Plus

🙈作者简介:练习时长两年半的Java up主 🙉个人主页:程序员老茶 🙊 ps:点赞👍是免费的,却可以让写博客的作者开心好久好久😎 📚系列专栏:Java全栈,…

【HTTP协议】简述HTTP协议的概念和特点

🎊专栏【网络编程】 🍔喜欢的诗句:更喜岷山千里雪 三军过后尽开颜。 🎆音乐分享【如愿】 🥰欢迎并且感谢大家指出小吉的问题 文章目录 🌺概念🌺特点🎄请求协议🎄响应协议…

java第二十六课

数据库多表 多表做到每个表的字段名称不一样 Mysql 关系数据库 结合到商城:用户表 订单表 商品表 商品详情表 用户表:字段: 用户 id:唯一标志用户 用户名称:name 用户性别:sex 用户年龄:age 用户地址:position 用户密码…

C++相关闲碎记录(2)

1、误用shared_ptr int* p new int; shared_ptr<int> sp1(p); shared_ptr<int> sp2(p); //error // 通过原始指针两次创建shared_ptr是错误的shared_ptr<int> sp1(new int); shared_ptr<int> sp2(sp1); //ok 如果对C相关闲碎记录(1)中记录的shar…

智慧机场视频监控系统方案:AI智能助力机场智慧运营

一、方案背景 随着人们生活物质水平的上升&#xff0c;人们对机场的需求也日益增多&#xff0c;在民航新建、迁建、扩建机场项目猛增的同时&#xff0c;也需同步配备相应的安防监控系统&#xff0c;以满足民航机场安全管理要求和机场运营业务的高速发展。 二、方案概述 智慧机…

MySQL 教程 1.4

MySQL 连接 使用mysql二进制方式连接 您可以使用MySQL二进制方式进入到mysql命令提示符下来连接MySQL数据库。 实例 以下是从命令行中连接mysql服务器的简单实例&#xff1a; [roothost]# mysql -u root -p Enter password:****** 在登录成功后会出现 mysql> 命令提示窗…

Redis7--基础篇6(复制replica)

1. 复制(replica)介绍 Redis数据库支持主从复制&#xff0c;master以写为主&#xff0c;slave以读为主&#xff0c;当master数据变化的时候&#xff0c;自动将新的数据异步同步到slave数据库。 实现读写分离、容灾恢复、数据备份、水平扩容支撑高并发。 2. 案例演示 2.1 架构…