【大语言模型】ACL2024论文-07 BitDistiller: 释放亚4比特大型语言模型的潜力通过自蒸馏

【大语言模型】ACL2024论文-07 BitDistiller: 释放亚4比特大型语言模型的潜力通过自蒸馏


目录

文章目录

  • 【大语言模型】ACL2024论文-07 BitDistiller: 释放亚4比特大型语言模型的潜力通过自蒸馏
    • 目录
      • 摘要
      • 研究背景
      • 问题与挑战
      • 如何解决
      • 创新点
      • 算法模型
      • 实验效果
      • 代码
      • 推荐阅读指数:✭✭✭✭✩
    • 后记


BitDistiller: 释放亚4比特大型语言模型的潜力通过自蒸馏
在这里插入图片描述

摘要

本文介绍了BitDistiller,这是一个通过结合量化感知训练(QAT)和知识蒸馏(KD)来提升超低精度(亚4比特)大型语言模型(LLMs)性能的框架。BitDistiller首先采用定制的非对称量化和裁剪技术来尽可能保持量化权重的保真度,然后提出了一种新颖的基于置信度的Kullback-Leibler散度(CAKLD)目标,用于自蒸馏,以实现更快的收敛和更优的模型性能。实验评估表明,BitDistiller在3比特和2比特配置下,无论是在通用语言理解还是复杂推理基准测试中,都显著超越了现有方法。值得注意的是,BitDistiller更具成本效益,需要更少的数据和训练资源。

研究背景

随着大型语言模型(LLMs)规模的扩大,自然语言处理领域取得了令人印象深刻的进展。然而,这种模型规模的扩大在部署上带来了显著的挑战,尤其是在资源受限的设备上,因为它们需要大量的内存和计算能力。权重量化作为一种流行的策略,通过减少模型大小来提高LLMs的效率和可访问性,同时最小化性能损失。尽管4比特量化已被广泛采用,提供了显著的压缩比和保留LLM能力之间的平衡,但亚4比特量化会显著降低模型权重的保真度,尤其是在小型模型或需要复杂推理的任务中,导致模型性能恶化。
在这里插入图片描述

问题与挑战

在极端低比特QAT中实现高性能的两个基本挑战是:如何在量化过程中最大限度地保持权重保真度,以及如何在训练中有效学习低比特表示。

如何解决

BitDistiller通过以下方式解决上述挑战:

  1. 非对称量化和裁剪:BitDistiller采用了定制的非对称量化和裁剪策略,以保持全精度模型的能力,特别是在超低比特水平上。
  2. 自蒸馏:BitDistiller利用全精度模型作为教师,低比特模型作为学生,通过自蒸馏方法进行有效的低比特表示学习。
  3. CAKLD目标:BitDistiller创新性地提出了一种基于置信度的Kullback-Leibler散度(CAKLD)目标,优化知识传递效率,实现更快的收敛和增强的模型性能。

创新点

  • 非对称量化和裁剪:BitDistiller针对不同比特级别的量化采用了不同的量化策略,如NF格式和INT格式,以及非对称裁剪,以提高量化权重的表示保真度。
  • CAKLD目标:BitDistiller提出了一种新颖的CAKLD目标,它根据全精度模型对训练数据的置信度自动权衡模式寻求和模式覆盖行为。
  • 自蒸馏框架:BitDistiller将QAT与知识蒸馏相结合,使用全精度模型作为教师来指导低比特学生模型,这是一种简单而有效的自蒸馏方法。
    在这里插入图片描述

算法模型

BitDistiller的框架包括以下几个关键步骤:

  1. 非对称量化和裁剪:在QAT初始化阶段,BitDistiller对权重进行非对称裁剪,以减少量化误差。
  2. 自蒸馏:在训练过程中,全精度模型生成数据,低比特模型学习这些数据,通过CAKLD目标进行优化。
  3. CAKLD目标:CAKLD目标结合了反向KL散度和正向KL散度,根据全精度模型的置信度自动调整模式寻求和模式覆盖行为。
    在这里插入图片描述

实验效果

实验评估表明,BitDistiller在3比特和2比特配置下的性能显著优于现有的PTQ和QAT方法。以下是一些重要的数据和结论:

  • 语言建模任务:在WikiText-2的困惑度(PPL)和MMLU(5-shot)准确性方面,BitDistiller超越了竞争对手。
  • 推理任务:在HumanEval和GSM8K等推理基准测试中,BitDistiller在3比特和2比特量化中均展现出优越性能。
  • 成本效益:BitDistiller需要的训练数据和资源更少,更具成本效益。
    在这里插入图片描述
    在这里插入图片描述
    在这里插入图片描述
    在这里插入图片描述
    在这里插入图片描述

代码

https://github.com/DD-DuDa/BitDistiller.git
在这里插入图片描述

import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.autograd import Function
from tqdm import tqdm
import gc
# import bitsandbytes as bnb
import torch.nn as nn
from functools import partial
# import bitsandbytes.functional as bnbFclass Round(Function):@staticmethoddef forward(self, input):sign = torch.sign(input)output = sign * torch.floor(torch.abs(input) + 0.5)return output@staticmethoddef backward(self, grad_output):grad_input = grad_output.clone()return grad_input# core quantization method (simulated quantization)
def pseudo_quantize_tensor(w, n_bit=8,zero_point=True, q_group_size=-1,inplace=False,get_scale_zp=False):org_w_shape = w.shapeif q_group_size > 0:assert org_w_shape[-1] % q_group_size == 0w = w.reshape(-1, q_group_size)elif q_group_size == -1:w = w.reshape(-1, w.shape[-1])assert w.dim() == 2if zero_point:max_val = w.amax(dim=1, keepdim=True)min_val = w.amin(dim=1, keepdim=True)max_int = 2 ** n_bit - 1min_int = 0scales = (max_val - min_val).clamp(min=1e-5) / max_intzeros = (-torch.round(min_val / scales)).clamp_(min_int, max_int)else:  # we actually never used thisassert min_val is Nonemax_val = w.abs().amax(dim=1, keepdim=True)max_val = max_val.clamp(min=1e-5)max_int = 2 ** (n_bit - 1) - 1min_int = - 2 ** (n_bit - 1)scales = max_val / max_intzeros = 0assert torch.isnan(scales).sum() == 0assert torch.isnan(w).sum() == 0if inplace:((w.div_(scales).round_().add_(zeros)).clamp_(min_int, max_int).sub_(zeros)).mul_(scales)else:w = (torch.clamp(torch.round(w / scales) +zeros, min_int, max_int) - zeros) * scalesassert torch.isnan(w).sum() == 0w = w.reshape(org_w_shape)if get_scale_zp:return w, scales.view(w.shape[0], -1), zeros.view(w.shape[0], -1)else:return w@torch.no_grad()
def real_quantize_model_weight(model, w_bit, q_config,init_only=False
):from .qmodule import WQLinearfrom .pre_quant import get_blocks, get_named_linears, set_op_by_nameassert q_config["zero_point"], "We only support zero_point quantization now."layers = get_blocks(model)for i in tqdm(range(len(layers)), desc="real weight quantization..." + ("(init only)" if init_only else "")):layer = layers[i]named_linears = get_named_linears(layer)# scale_activations(layer)for name, module in named_linears.items():if init_only:q_linear = WQLinear.from_linear(module, w_bit, q_config['q_group_size'], True)q_linear.to(next(layer.parameters()).device)set_op_by_name(layer, name, q_linear)else:module.cuda()module.weight.data, scales, zeros = pseudo_quantize_tensor(module.weight.data, n_bit=w_bit, get_scale_zp=True, **q_config)# scales = scales.t().contiguous()# zeros = zeros.t().contiguous()q_linear = WQLinear.from_linear(module, w_bit, q_config['q_group_size'], False, scales, zeros)module.cpu()q_linear.to(next(layer.parameters()).device)set_op_by_name(layer, name, q_linear)torch.cuda.empty_cache()gc.collect()torch.cuda.empty_cache()gc.collect()def pseudo_quantize_n2f3_tensor(w, q_group_size=-1):quantizer = SteN2F3Quantizer(q_group_size=q_group_size)w = quantizer(w)return wclass SteInt3AsymQuantizer(nn.Module):def __init__(self, q_group_size=128):super().__init__()self.q_group_size = q_group_sizeself.bit = 3def forward(self, x):org_w_shape = x.shapeif self.q_group_size > 0:assert org_w_shape[-1] % self.q_group_size == 0x = x.reshape(-1, self.q_group_size)elif self.q_group_size == -1:assert org_w_shape[-1] % self.q_group_size == 0x = x.reshape(-1, x.shape[-1])assert x.dim() == 2max_val = x.amax(dim=1, keepdim=True)min_val = x.amin(dim=1, keepdim=True)max_int = 2 ** self.bit - 1min_int = 0scales = (max_val - min_val).clamp(min=1e-5) / max_intzeros = (-torch.round(min_val / scales)).clamp_(min_int, max_int)assert torch.isnan(scales).sum() == 0assert torch.isnan(x).sum() == 0x = (torch.clamp(Round.apply(x / scales) +zeros, min_int, max_int) - zeros) * scalesassert torch.isnan(x).sum() == 0x = x.reshape(org_w_shape)return xclass SteInt2AsymQuantizer(nn.Module):def __init__(self, q_group_size=64):super().__init__()self.q_group_size = q_group_sizeself.bit = 2def forward(self, x):org_w_shape = x.shapeif self.q_group_size > 0:assert org_w_shape[-1] % self.q_group_size == 0x = x.reshape(-1, self.q_group_size)assert x.dim() == 2max_val = x.amax(dim=1, keepdim=True)min_val = x.amin(dim=1, keepdim=True)max_int = 2 ** self.bit - 1min_int = 0scales = (max_val - min_val).clamp(min=1e-5) / max_intzeros = (-torch.round(min_val / scales)).clamp_(min_int, max_int)assert torch.isnan(scales).sum() == 0assert torch.isnan(x).sum() == 0x = (torch.clamp(Round.apply(x / scales) +zeros, min_int, max_int) - zeros) * scalesassert torch.isnan(x).sum() == 0x = x.reshape(org_w_shape)return xclass SteN2F3Quantizer(nn.Module):def __init__(self, q_group_size=128):super().__init__()self.q_group_size = q_group_sizedef forward(self, x):org_w_shape = x.shape# reshape to groupsizeif self.q_group_size > 0:assert org_w_shape[-1] % self.q_group_size == 0qx = x.reshape(-1, self.q_group_size)elif self.q_group_size == -1:qx = x.reshape(-1, x.shape[-1])assert qx.dim() == 2# Get the Min Maxmax_val = qx.amax(dim=1, keepdim=True)min_val = qx.amin(dim=1, keepdim=True)scale_pos = torch.abs(max_val)scale_neg = torch.abs(min_val)dev = qx.devicex_pos = torch.zeros_like(qx)x_neg = torch.zeros_like(qx)x_pos = torch.where(qx >= 0, qx, x_pos)x_neg = torch.where(qx < 0, qx, x_neg)q_pos = x_pos / scale_posq_neg = x_neg / scale_negq_pos, q_neg = self.round_pass(q_pos, q_neg, dev)qx = q_pos * scale_pos + q_neg * scale_negqx = qx.reshape(org_w_shape)return qxdef round_n2f3(self, q_pos, q_neg, dev):q_pos = torch.where(q_pos >= 0.8114928305149078,                                        torch.tensor(1.0).to(dev), q_pos)q_pos = torch.where((q_pos < 0.8114928305149078)    & (q_pos >= 0.5024898052215576),    torch.tensor(0.6229856610298157).to(dev), q_pos)q_pos = torch.where((q_pos < 0.5024898052215576)    & (q_pos >= 0.2826657369732857),    torch.tensor(0.3819939494132996).to(dev), q_pos)q_pos = torch.where((q_pos < 0.2826657369732857)    & (q_pos >= 0.0916687622666359),    torch.tensor(0.1833375245332718).to(dev), q_pos)q_pos = torch.where(q_pos < 0.0916687622666359,                                        torch.tensor(0).to(dev), q_pos)q_neg = torch.where(q_neg >= -0.1234657019376755,                                     torch.tensor(0).to(dev), q_neg)q_neg = torch.where((q_neg < -0.1234657019376755)   & (q_neg >= -0.39097706973552704),   torch.tensor(-0.2469314038753510).to(dev), q_neg)q_neg = torch.where((q_neg < -0.39097706973552704)   & (q_neg >= -0.7675113677978516),   torch.tensor(-0.5350227355957031).to(dev), q_neg)q_neg = torch.where(q_neg < -0.7675113677978516,                                        torch.tensor(-1.0).to(dev), q_neg)return q_pos, q_negdef round_pass(self, q_pos, q_neg, dev):y_grad_pos, y_grad_neg = q_pos, q_negy_pos, y_neg = self.round_n2f3(q_pos, q_neg, dev)return (y_pos - y_grad_pos).detach() + y_grad_pos, (y_neg - y_grad_neg).detach() + y_grad_neg

推荐阅读指数:✭✭✭✭✩

推荐理由

  • 创新性:BitDistiller通过结合QAT和KD,在亚4比特量化领域提供了一种新的解决方案,具有显著的性能提升。
  • 实用性:BitDistiller不仅在理论上具有创新性,而且在实际应用中也显示出了成本效益,这对于资源受限的设备尤为重要。
  • 广泛适用性:BitDistiller在多种语言和推理任务中都展现出了优越的性能,表明其方法的广泛适用性。

后记

如果您对我的博客内容感兴趣,欢迎三连击(点赞、收藏、关注和评论),我将持续为您带来计算机人工智能前沿技术(尤其是AI相关的大语言模型,深度学习和计算机视觉相关方向)最新学术论文及工程实践方面的内容分享,助力您更快更准更系统地了解 AI前沿技术

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.mzph.cn/pingmian/59700.shtml

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

鸿蒙next打包流程

目录 下载团结引擎 添加开源鸿蒙打包支持 打包报错 路径问题 安装DevEcoStudio 可以在DevEcoStudio进行打包hap和app 包结构 没法直接用previewer运行 真机运行和测试需要配置签名,DevEcoStudio可以自动配置, 模拟器安装hap提示报错 安装成功,但无法打开 团结1.3版本新增工具…

基于Jeecgboot3.6.3vue3的flowable流程online表单的审批使用介绍

更多技术支持与服务请加入我的知识星球或加我微信&#xff0c;名称:亿事达nbcio技术交流社区https://t.zsxq.com/iPi8F 今天介绍一下基于jeecgboot3.6.3的flowable流程使用online表单进行审批的情况 1、首先建立一个online应用类型的流程&#xff0c;如下&#xff1a; 2、进行…

【LeetCode】【算法】238. 除自身以外数组的乘积

LeetCode 238. 除自身以外数组的乘积 题目描述 给你一个整数数组 nums&#xff0c;返回 数组 answer &#xff0c;其中 answer[i] 等于 nums 中除 nums[i] 之外其余各元素的乘积 。 题目数据保证数组 nums之中任意元素的全部前缀元素和后缀的乘积都在 32 位整数范围内。 请不…

如何构建一个可扩展的测试自动化框架?

以下为作者观点&#xff1a; 假设你是测试自动化方面的新手&#xff0c;想参与构建一个框架。在这种情况下&#xff0c;重要的是要了解框架所需的组件&#xff0c;以及它们是如何组合的。思考项目的具体需求和目标&#xff0c;以及可能遇到的困难和挑战。 假如你是一个测试架…

实战:索引的命中机制

在 SQL Server 中,查询是否能命中索引(即是否能使用 Index Seek)取决于多个因素,包括索引的结构、查询条件的排列、和数据库优化器的策略。以下是一些常见的命中索引和不能命中索引的情况,及其详细解释: 一、命中索引的情况 1. 前导列匹配(典型的命中索引场景) 索引结…

使用Docker快速部署FastAPI Web应用

Docker是基于 Linux 内核的cgroup、namespace以及 AUFS 类的Union FS 等技术&#xff0c;对进程进行封装隔离&#xff0c;一种操作系统层面的虚拟化技术。Docker中每个容器都基于镜像Image运行&#xff0c;镜像是容器的只读模板&#xff0c;容器是模板的一个实例。镜像是分层结…

C++【string类,模拟实现string类】

&#x1f31f;个人主页&#xff1a;落叶 &#x1f31f;当前专栏: C专栏 目录 为什么学习string类 C语言中的字符串 标准库中的string类 auto和范围for auto关键字 迭代器 范围for string类的常用接口说明和使用 1. string类对象的常见构造 2.string类对象的容量操作 3…

A019基于SpringBoot的校园闲置物品交易系统

&#x1f64a;作者简介&#xff1a;在校研究生&#xff0c;拥有计算机专业的研究生开发团队&#xff0c;分享技术代码帮助学生学习&#xff0c;独立完成自己的网站项目。 代码可以查看文章末尾⬇️联系方式获取&#xff0c;记得注明来意哦~&#x1f339; 赠送计算机毕业设计600…

【赵渝强老师】Redis的RDB数据持久化

Redis 是内存数据库&#xff0c;如果不将内存中的数据库状态保存到磁盘&#xff0c;那么一旦服务器进程退出会造成服务器中的数据库状态也会消失。所以 Redis 提供了数据持久化功能。Redis支持两种方式的持久化&#xff0c;一种是RDB方式&#xff1b;另一种是AOF&#xff08;ap…

Excel:vba实现批量插入图片批注

实现的效果&#xff1a;实现的代码如下&#xff1a; Sub InsertImageNamesAndPictures()Dim PicPath As StringDim PicName As StringDim PicFullPath As StringDim RowNum As IntegerDim Name As StringDim Comment As CommentDim folder As FileDialog 定义文件选择对话框 清…

tomcat启动失败和缓存清理办法

tomcat只在学校接触过并且是在window xp和win7的电脑上配置过&#xff08;中途升级过电脑系统&#xff09;&#xff0c;只记得在windows系统上可以将其设置成服务管理。但我已毕业10多年了&#xff0c;学的知识早就不知道丢哪里了。这次为了修改一个07&#xff0c;08年的项目&a…

ReactPress:深入解析技术方案设计与源码

ReactPress Github项目地址&#xff1a;https://github.com/fecommunity/reactpress 欢迎提出宝贵的建议&#xff0c;欢迎一起共建&#xff0c;感谢Star。 ReactPress是一个基于React框架开发的开源发布平台&#xff0c;它不仅仅是一个简单的博客系统&#xff0c;更是一个功能全…

A20红色革命文物征集管理系统

&#x1f64a;作者简介&#xff1a;在校研究生&#xff0c;拥有计算机专业的研究生开发团队&#xff0c;分享技术代码帮助学生学习&#xff0c;独立完成自己的网站项目。 代码可以查看文章末尾⬇️联系方式获取&#xff0c;记得注明来意哦~&#x1f339; 赠送计算机毕业设计600…

先锋精科委身芯片“圈子” 引致交易不公允和信披不透明

不要违背圈子的规则&#xff0c;但也不要盲从圈子的规则。 ——语出马云。 引 言 “圈子”是钥匙&#xff0c;也是一把锁。 走进“圈子”&#xff0c;将获得包括资金、订单、货源、技术等企业发展所需的资源&#xff0c;能够助推一家企业乃至整个行业的跨越式发展&#…

MinerU容器构建教程

一、介绍 MinerU作为一款智能数据提取工具&#xff0c;其核心功能之一是处理PDF文档和网页内容&#xff0c;将其中的文本、图像、表格、公式等信息提取出来&#xff0c;并转换为易于阅读和编辑的格式&#xff08;如Markdown&#xff09;。在这个过程中&#xff0c;MinerU需要利…

【论文复现】基于深度学习的手势识别算法

本文所涉及所有资源均在这里可获取。 &#x1f4d5;作者简介&#xff1a;热爱跑步的恒川&#xff0c;致力于C/C、Java、Python等多编程语言&#xff0c;热爱跑步&#xff0c;喜爱音乐、摄影的一位博主。 &#x1f4d7;本文收录于论文复现系列&#xff0c;大家有兴趣的可以看一看…

使用QtWebEngine的Mac应用如何发布App Store

前言 因为QtWebEngine时第三方包,苹果并不直接支持进行App Store上签名和发布,所以构建和发布一个基于使用QtWebEngine的应用程序并不容易,这里我们对Qt 5.8稍微做一些修改,以便让我们的基于QtWeb引擎的应用程序并让签名能够得到苹果的许可。 QtWebEngine提供了C++和Qml的…

智能新纪元:人工智能技术的社会影响与伦理挑战-亿发

在数字化时代&#xff0c;人工智能&#xff08;AI&#xff09;正以其不可阻挡之势&#xff0c;深刻改变着我们的生产、生活和学习方式。它不仅是一项技术革命&#xff0c;更是推动社会进步的重要力量。本文将探讨人工智能如何重塑未来&#xff0c;以及它所带来的深远影响。 AI…

云平台虚拟机运维笔记整理,使用libvirt创建和管理虚拟机,以及开启虚拟机嵌套,虚拟磁盘扩容,物理磁盘扩容等等

云平台虚拟机运维笔记整理,使用libvirt创建和管理虚拟机,以及开启虚拟机嵌套,虚拟磁盘扩容,物理磁盘扩容等等。 掌握和使用qemu和libvirt,分别使用它们创建一个cirros虚拟机,并配置好网络。 宿主机node0的系统为ubuntu16,IP为192.168.56.200。 qemu和libvirt简介 QEMU…

mac crontab 不能使用问题简记

需要 crontab 有权限&#xff0c;如下截图设置 在访达上方【前往】-》【前往文件夹】输入/ 然后按 Command Shift . 显示隐藏文件&#xff0c;然后将 usr 放到左边栏 然后如下操作 系统设置中找到 隐私安全->完全访问磁盘 点击小锁头 点击号&#xff0c;将/usr/bin/c…