【大语言模型】ACL2024论文-07 BitDistiller: 释放亚4比特大型语言模型的潜力通过自蒸馏

【大语言模型】ACL2024论文-07 BitDistiller: 释放亚4比特大型语言模型的潜力通过自蒸馏


目录

文章目录

  • 【大语言模型】ACL2024论文-07 BitDistiller: 释放亚4比特大型语言模型的潜力通过自蒸馏
    • 目录
      • 摘要
      • 研究背景
      • 问题与挑战
      • 如何解决
      • 创新点
      • 算法模型
      • 实验效果
      • 代码
      • 推荐阅读指数:✭✭✭✭✩
    • 后记


BitDistiller: 释放亚4比特大型语言模型的潜力通过自蒸馏
在这里插入图片描述

摘要

本文介绍了BitDistiller,这是一个通过结合量化感知训练(QAT)和知识蒸馏(KD)来提升超低精度(亚4比特)大型语言模型(LLMs)性能的框架。BitDistiller首先采用定制的非对称量化和裁剪技术来尽可能保持量化权重的保真度,然后提出了一种新颖的基于置信度的Kullback-Leibler散度(CAKLD)目标,用于自蒸馏,以实现更快的收敛和更优的模型性能。实验评估表明,BitDistiller在3比特和2比特配置下,无论是在通用语言理解还是复杂推理基准测试中,都显著超越了现有方法。值得注意的是,BitDistiller更具成本效益,需要更少的数据和训练资源。

研究背景

随着大型语言模型(LLMs)规模的扩大,自然语言处理领域取得了令人印象深刻的进展。然而,这种模型规模的扩大在部署上带来了显著的挑战,尤其是在资源受限的设备上,因为它们需要大量的内存和计算能力。权重量化作为一种流行的策略,通过减少模型大小来提高LLMs的效率和可访问性,同时最小化性能损失。尽管4比特量化已被广泛采用,提供了显著的压缩比和保留LLM能力之间的平衡,但亚4比特量化会显著降低模型权重的保真度,尤其是在小型模型或需要复杂推理的任务中,导致模型性能恶化。
在这里插入图片描述

问题与挑战

在极端低比特QAT中实现高性能的两个基本挑战是:如何在量化过程中最大限度地保持权重保真度,以及如何在训练中有效学习低比特表示。

如何解决

BitDistiller通过以下方式解决上述挑战:

  1. 非对称量化和裁剪:BitDistiller采用了定制的非对称量化和裁剪策略,以保持全精度模型的能力,特别是在超低比特水平上。
  2. 自蒸馏:BitDistiller利用全精度模型作为教师,低比特模型作为学生,通过自蒸馏方法进行有效的低比特表示学习。
  3. CAKLD目标:BitDistiller创新性地提出了一种基于置信度的Kullback-Leibler散度(CAKLD)目标,优化知识传递效率,实现更快的收敛和增强的模型性能。

创新点

  • 非对称量化和裁剪:BitDistiller针对不同比特级别的量化采用了不同的量化策略,如NF格式和INT格式,以及非对称裁剪,以提高量化权重的表示保真度。
  • CAKLD目标:BitDistiller提出了一种新颖的CAKLD目标,它根据全精度模型对训练数据的置信度自动权衡模式寻求和模式覆盖行为。
  • 自蒸馏框架:BitDistiller将QAT与知识蒸馏相结合,使用全精度模型作为教师来指导低比特学生模型,这是一种简单而有效的自蒸馏方法。
    在这里插入图片描述

算法模型

BitDistiller的框架包括以下几个关键步骤:

  1. 非对称量化和裁剪:在QAT初始化阶段,BitDistiller对权重进行非对称裁剪,以减少量化误差。
  2. 自蒸馏:在训练过程中,全精度模型生成数据,低比特模型学习这些数据,通过CAKLD目标进行优化。
  3. CAKLD目标:CAKLD目标结合了反向KL散度和正向KL散度,根据全精度模型的置信度自动调整模式寻求和模式覆盖行为。
    在这里插入图片描述

实验效果

实验评估表明,BitDistiller在3比特和2比特配置下的性能显著优于现有的PTQ和QAT方法。以下是一些重要的数据和结论:

  • 语言建模任务:在WikiText-2的困惑度(PPL)和MMLU(5-shot)准确性方面,BitDistiller超越了竞争对手。
  • 推理任务:在HumanEval和GSM8K等推理基准测试中,BitDistiller在3比特和2比特量化中均展现出优越性能。
  • 成本效益:BitDistiller需要的训练数据和资源更少,更具成本效益。
    在这里插入图片描述
    在这里插入图片描述
    在这里插入图片描述
    在这里插入图片描述
    在这里插入图片描述

代码

https://github.com/DD-DuDa/BitDistiller.git
在这里插入图片描述

import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.autograd import Function
from tqdm import tqdm
import gc
# import bitsandbytes as bnb
import torch.nn as nn
from functools import partial
# import bitsandbytes.functional as bnbFclass Round(Function):@staticmethoddef forward(self, input):sign = torch.sign(input)output = sign * torch.floor(torch.abs(input) + 0.5)return output@staticmethoddef backward(self, grad_output):grad_input = grad_output.clone()return grad_input# core quantization method (simulated quantization)
def pseudo_quantize_tensor(w, n_bit=8,zero_point=True, q_group_size=-1,inplace=False,get_scale_zp=False):org_w_shape = w.shapeif q_group_size > 0:assert org_w_shape[-1] % q_group_size == 0w = w.reshape(-1, q_group_size)elif q_group_size == -1:w = w.reshape(-1, w.shape[-1])assert w.dim() == 2if zero_point:max_val = w.amax(dim=1, keepdim=True)min_val = w.amin(dim=1, keepdim=True)max_int = 2 ** n_bit - 1min_int = 0scales = (max_val - min_val).clamp(min=1e-5) / max_intzeros = (-torch.round(min_val / scales)).clamp_(min_int, max_int)else:  # we actually never used thisassert min_val is Nonemax_val = w.abs().amax(dim=1, keepdim=True)max_val = max_val.clamp(min=1e-5)max_int = 2 ** (n_bit - 1) - 1min_int = - 2 ** (n_bit - 1)scales = max_val / max_intzeros = 0assert torch.isnan(scales).sum() == 0assert torch.isnan(w).sum() == 0if inplace:((w.div_(scales).round_().add_(zeros)).clamp_(min_int, max_int).sub_(zeros)).mul_(scales)else:w = (torch.clamp(torch.round(w / scales) +zeros, min_int, max_int) - zeros) * scalesassert torch.isnan(w).sum() == 0w = w.reshape(org_w_shape)if get_scale_zp:return w, scales.view(w.shape[0], -1), zeros.view(w.shape[0], -1)else:return w@torch.no_grad()
def real_quantize_model_weight(model, w_bit, q_config,init_only=False
):from .qmodule import WQLinearfrom .pre_quant import get_blocks, get_named_linears, set_op_by_nameassert q_config["zero_point"], "We only support zero_point quantization now."layers = get_blocks(model)for i in tqdm(range(len(layers)), desc="real weight quantization..." + ("(init only)" if init_only else "")):layer = layers[i]named_linears = get_named_linears(layer)# scale_activations(layer)for name, module in named_linears.items():if init_only:q_linear = WQLinear.from_linear(module, w_bit, q_config['q_group_size'], True)q_linear.to(next(layer.parameters()).device)set_op_by_name(layer, name, q_linear)else:module.cuda()module.weight.data, scales, zeros = pseudo_quantize_tensor(module.weight.data, n_bit=w_bit, get_scale_zp=True, **q_config)# scales = scales.t().contiguous()# zeros = zeros.t().contiguous()q_linear = WQLinear.from_linear(module, w_bit, q_config['q_group_size'], False, scales, zeros)module.cpu()q_linear.to(next(layer.parameters()).device)set_op_by_name(layer, name, q_linear)torch.cuda.empty_cache()gc.collect()torch.cuda.empty_cache()gc.collect()def pseudo_quantize_n2f3_tensor(w, q_group_size=-1):quantizer = SteN2F3Quantizer(q_group_size=q_group_size)w = quantizer(w)return wclass SteInt3AsymQuantizer(nn.Module):def __init__(self, q_group_size=128):super().__init__()self.q_group_size = q_group_sizeself.bit = 3def forward(self, x):org_w_shape = x.shapeif self.q_group_size > 0:assert org_w_shape[-1] % self.q_group_size == 0x = x.reshape(-1, self.q_group_size)elif self.q_group_size == -1:assert org_w_shape[-1] % self.q_group_size == 0x = x.reshape(-1, x.shape[-1])assert x.dim() == 2max_val = x.amax(dim=1, keepdim=True)min_val = x.amin(dim=1, keepdim=True)max_int = 2 ** self.bit - 1min_int = 0scales = (max_val - min_val).clamp(min=1e-5) / max_intzeros = (-torch.round(min_val / scales)).clamp_(min_int, max_int)assert torch.isnan(scales).sum() == 0assert torch.isnan(x).sum() == 0x = (torch.clamp(Round.apply(x / scales) +zeros, min_int, max_int) - zeros) * scalesassert torch.isnan(x).sum() == 0x = x.reshape(org_w_shape)return xclass SteInt2AsymQuantizer(nn.Module):def __init__(self, q_group_size=64):super().__init__()self.q_group_size = q_group_sizeself.bit = 2def forward(self, x):org_w_shape = x.shapeif self.q_group_size > 0:assert org_w_shape[-1] % self.q_group_size == 0x = x.reshape(-1, self.q_group_size)assert x.dim() == 2max_val = x.amax(dim=1, keepdim=True)min_val = x.amin(dim=1, keepdim=True)max_int = 2 ** self.bit - 1min_int = 0scales = (max_val - min_val).clamp(min=1e-5) / max_intzeros = (-torch.round(min_val / scales)).clamp_(min_int, max_int)assert torch.isnan(scales).sum() == 0assert torch.isnan(x).sum() == 0x = (torch.clamp(Round.apply(x / scales) +zeros, min_int, max_int) - zeros) * scalesassert torch.isnan(x).sum() == 0x = x.reshape(org_w_shape)return xclass SteN2F3Quantizer(nn.Module):def __init__(self, q_group_size=128):super().__init__()self.q_group_size = q_group_sizedef forward(self, x):org_w_shape = x.shape# reshape to groupsizeif self.q_group_size > 0:assert org_w_shape[-1] % self.q_group_size == 0qx = x.reshape(-1, self.q_group_size)elif self.q_group_size == -1:qx = x.reshape(-1, x.shape[-1])assert qx.dim() == 2# Get the Min Maxmax_val = qx.amax(dim=1, keepdim=True)min_val = qx.amin(dim=1, keepdim=True)scale_pos = torch.abs(max_val)scale_neg = torch.abs(min_val)dev = qx.devicex_pos = torch.zeros_like(qx)x_neg = torch.zeros_like(qx)x_pos = torch.where(qx >= 0, qx, x_pos)x_neg = torch.where(qx < 0, qx, x_neg)q_pos = x_pos / scale_posq_neg = x_neg / scale_negq_pos, q_neg = self.round_pass(q_pos, q_neg, dev)qx = q_pos * scale_pos + q_neg * scale_negqx = qx.reshape(org_w_shape)return qxdef round_n2f3(self, q_pos, q_neg, dev):q_pos = torch.where(q_pos >= 0.8114928305149078,                                        torch.tensor(1.0).to(dev), q_pos)q_pos = torch.where((q_pos < 0.8114928305149078)    & (q_pos >= 0.5024898052215576),    torch.tensor(0.6229856610298157).to(dev), q_pos)q_pos = torch.where((q_pos < 0.5024898052215576)    & (q_pos >= 0.2826657369732857),    torch.tensor(0.3819939494132996).to(dev), q_pos)q_pos = torch.where((q_pos < 0.2826657369732857)    & (q_pos >= 0.0916687622666359),    torch.tensor(0.1833375245332718).to(dev), q_pos)q_pos = torch.where(q_pos < 0.0916687622666359,                                        torch.tensor(0).to(dev), q_pos)q_neg = torch.where(q_neg >= -0.1234657019376755,                                     torch.tensor(0).to(dev), q_neg)q_neg = torch.where((q_neg < -0.1234657019376755)   & (q_neg >= -0.39097706973552704),   torch.tensor(-0.2469314038753510).to(dev), q_neg)q_neg = torch.where((q_neg < -0.39097706973552704)   & (q_neg >= -0.7675113677978516),   torch.tensor(-0.5350227355957031).to(dev), q_neg)q_neg = torch.where(q_neg < -0.7675113677978516,                                        torch.tensor(-1.0).to(dev), q_neg)return q_pos, q_negdef round_pass(self, q_pos, q_neg, dev):y_grad_pos, y_grad_neg = q_pos, q_negy_pos, y_neg = self.round_n2f3(q_pos, q_neg, dev)return (y_pos - y_grad_pos).detach() + y_grad_pos, (y_neg - y_grad_neg).detach() + y_grad_neg

推荐阅读指数:✭✭✭✭✩

推荐理由

  • 创新性:BitDistiller通过结合QAT和KD,在亚4比特量化领域提供了一种新的解决方案,具有显著的性能提升。
  • 实用性:BitDistiller不仅在理论上具有创新性,而且在实际应用中也显示出了成本效益,这对于资源受限的设备尤为重要。
  • 广泛适用性:BitDistiller在多种语言和推理任务中都展现出了优越的性能,表明其方法的广泛适用性。

后记

如果您对我的博客内容感兴趣,欢迎三连击(点赞、收藏、关注和评论),我将持续为您带来计算机人工智能前沿技术(尤其是AI相关的大语言模型,深度学习和计算机视觉相关方向)最新学术论文及工程实践方面的内容分享,助力您更快更准更系统地了解 AI前沿技术

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.mzph.cn/pingmian/59700.shtml

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

RocketMQ部署教程

拉取 RocketMQ 镜像&#xff1a; 首先&#xff0c;从 Docker Hub 获取最新的 RocketMQ 镜像&#xff1a; docker pull apache/rocketmq:latest创建 Docker 网络&#xff1a; 为了使各容器之间能够通信&#xff0c;创建一个名为 rocketmq 的网络&#xff1a; docker network cre…

ORACLE批量插入更新如何拆分大事务?

拆分大事务 一、批量插入更新二、拆分事务之前文章MYSQL批量插入更新如何拆分大事务?说明了Mysql如何拆分,本篇文章探讨Oracle或OceanBase批量插入更新拆分大事务的问题 一、批量插入更新 oracle批量插入更新可使用merge语法eg: merge test ausing test_tmp bon (a.id = b.id…

鸿蒙next打包流程

目录 下载团结引擎 添加开源鸿蒙打包支持 打包报错 路径问题 安装DevEcoStudio 可以在DevEcoStudio进行打包hap和app 包结构 没法直接用previewer运行 真机运行和测试需要配置签名,DevEcoStudio可以自动配置, 模拟器安装hap提示报错 安装成功,但无法打开 团结1.3版本新增工具…

chatgpt3.5权重参数有多少MB;llama7B权重参数有多少MB

目录 chatgpt3.5权重参数有多少MB llama7B权重参数有多少MB chatgpt3.5权重参数有多少MB 关于ChatGPT 3.5的权重参数占用的存储空间大小,虽然直接给出具体的MB数值可能较为困难(因为这取决于多种因素,如参数表示的精度、是否进行了压缩等),但可以根据其参数量来估算一个…

基于Jeecgboot3.6.3vue3的flowable流程online表单的审批使用介绍

更多技术支持与服务请加入我的知识星球或加我微信&#xff0c;名称:亿事达nbcio技术交流社区https://t.zsxq.com/iPi8F 今天介绍一下基于jeecgboot3.6.3的flowable流程使用online表单进行审批的情况 1、首先建立一个online应用类型的流程&#xff0c;如下&#xff1a; 2、进行…

【LeetCode】【算法】238. 除自身以外数组的乘积

LeetCode 238. 除自身以外数组的乘积 题目描述 给你一个整数数组 nums&#xff0c;返回 数组 answer &#xff0c;其中 answer[i] 等于 nums 中除 nums[i] 之外其余各元素的乘积 。 题目数据保证数组 nums之中任意元素的全部前缀元素和后缀的乘积都在 32 位整数范围内。 请不…

如何构建一个可扩展的测试自动化框架?

以下为作者观点&#xff1a; 假设你是测试自动化方面的新手&#xff0c;想参与构建一个框架。在这种情况下&#xff0c;重要的是要了解框架所需的组件&#xff0c;以及它们是如何组合的。思考项目的具体需求和目标&#xff0c;以及可能遇到的困难和挑战。 假如你是一个测试架…

实战:索引的命中机制

在 SQL Server 中,查询是否能命中索引(即是否能使用 Index Seek)取决于多个因素,包括索引的结构、查询条件的排列、和数据库优化器的策略。以下是一些常见的命中索引和不能命中索引的情况,及其详细解释: 一、命中索引的情况 1. 前导列匹配(典型的命中索引场景) 索引结…

使用Docker快速部署FastAPI Web应用

Docker是基于 Linux 内核的cgroup、namespace以及 AUFS 类的Union FS 等技术&#xff0c;对进程进行封装隔离&#xff0c;一种操作系统层面的虚拟化技术。Docker中每个容器都基于镜像Image运行&#xff0c;镜像是容器的只读模板&#xff0c;容器是模板的一个实例。镜像是分层结…

C++【string类,模拟实现string类】

&#x1f31f;个人主页&#xff1a;落叶 &#x1f31f;当前专栏: C专栏 目录 为什么学习string类 C语言中的字符串 标准库中的string类 auto和范围for auto关键字 迭代器 范围for string类的常用接口说明和使用 1. string类对象的常见构造 2.string类对象的容量操作 3…

android——jetpack startup初始化框架

一、jetpack startup Android Jetpack Startup是一个库&#xff0c;它简化了Android应用启动过程&#xff0c;尤其是对于那些需要处理复杂数据绑定和初始化逻辑的应用。它的核心在于提供了一个StartupComponent&#xff0c;用于声明应用的初始化逻辑&#xff0c;这个逻辑会在首…

A019基于SpringBoot的校园闲置物品交易系统

&#x1f64a;作者简介&#xff1a;在校研究生&#xff0c;拥有计算机专业的研究生开发团队&#xff0c;分享技术代码帮助学生学习&#xff0c;独立完成自己的网站项目。 代码可以查看文章末尾⬇️联系方式获取&#xff0c;记得注明来意哦~&#x1f339; 赠送计算机毕业设计600…

【赵渝强老师】Redis的RDB数据持久化

Redis 是内存数据库&#xff0c;如果不将内存中的数据库状态保存到磁盘&#xff0c;那么一旦服务器进程退出会造成服务器中的数据库状态也会消失。所以 Redis 提供了数据持久化功能。Redis支持两种方式的持久化&#xff0c;一种是RDB方式&#xff1b;另一种是AOF&#xff08;ap…

Excel:vba实现批量插入图片批注

实现的效果&#xff1a;实现的代码如下&#xff1a; Sub InsertImageNamesAndPictures()Dim PicPath As StringDim PicName As StringDim PicFullPath As StringDim RowNum As IntegerDim Name As StringDim Comment As CommentDim folder As FileDialog 定义文件选择对话框 清…

tomcat启动失败和缓存清理办法

tomcat只在学校接触过并且是在window xp和win7的电脑上配置过&#xff08;中途升级过电脑系统&#xff09;&#xff0c;只记得在windows系统上可以将其设置成服务管理。但我已毕业10多年了&#xff0c;学的知识早就不知道丢哪里了。这次为了修改一个07&#xff0c;08年的项目&a…

SpringBoot使用ApplicationContext.getBean启动报空指针处理记录

问题&#xff1a;项目启动报空指针 定位&#xff1a;新增filter中init方法使用getbean控制 解决&#xff1a;在新增filter上加注解 DependsOn({"applicationContextUtils"}) Component DependsOn({"applicationContextUtils"})//此处解决空指针问题 pu…

ReactPress:深入解析技术方案设计与源码

ReactPress Github项目地址&#xff1a;https://github.com/fecommunity/reactpress 欢迎提出宝贵的建议&#xff0c;欢迎一起共建&#xff0c;感谢Star。 ReactPress是一个基于React框架开发的开源发布平台&#xff0c;它不仅仅是一个简单的博客系统&#xff0c;更是一个功能全…

华为OD机试真题-用户调度问题-2024年OD统一考试(E卷)

最新华为OD机试考点合集:华为OD机试2024年真题题库(E卷+D卷+C卷)_华为od机试题库-CSDN博客 每一题都含有详细的解题思路和代码注释,精编c++、JAVA、Python三种语言解法。帮助每一位考生轻松、高效刷题。订阅后永久可看,发现新题及时跟新。 题目描述 在通信系统中,一…

Docker 基础命令简介

目录 Docker 基础命令 1. Docker 版本信息 2. 获取 Docker 帮助 3. 列出所有运行中的容器 4. 运行一个新的容器 5. 查看容器日志 6. 停止容器 7. 启动已停止的容器 8. 删除容器 9. 列出所有镜像 10. 拉取镜像 11. 构建镜像 12. 删除镜像 13. 执行命令 14. 查看容…

A20红色革命文物征集管理系统

&#x1f64a;作者简介&#xff1a;在校研究生&#xff0c;拥有计算机专业的研究生开发团队&#xff0c;分享技术代码帮助学生学习&#xff0c;独立完成自己的网站项目。 代码可以查看文章末尾⬇️联系方式获取&#xff0c;记得注明来意哦~&#x1f339; 赠送计算机毕业设计600…