BaiChuan13B-GPTQ量化详解

知识要点:
1、按照网上搜索的一些代码,如使用auto_gptq原生库进行训练后量化,可能会正常量化,但是在线推理时会出现如找不到bin文件或者tf文件,即模型权重文件,所以和网上大部分代码不同的地方在于,需要提前保存对应模型的权重文件,如果是BaiChuan13B,那么在进行模型量化前,对其进行保存
代码如下:

def save_bin(pretrained_model_dir, quantized_model_dir):from transformers import AutoModelForCausalLMimport torchimport osoriginal_model = AutoModelForCausalLM.from_pretrained(pretrained_model_dir, trust_remote_code=True,torch_dtype=torch.float16,      # 不执行这个保存的bin文件会非常的大,大概50多Gsafetensors=True)print("保存bin文件...")model_path = os.path.join(quantized_model_dir, "pytorch_model"+".bin")torch.save(original_model.state_dict(), model_path)print("保存bin文件完成...")

量化代码,使用原生库auto_gptq进行量化:

def from_authority_autogptq(pretrained_model_dir, quantized_model_dir):from transformers import AutoTokenizer, AutoModelForCausalLMfrom auto_gptq import AutoGPTQForCausalLM, BaseQuantizeConfigimport loggingimport torchimport oslogging.basicConfig(format="%(asctime)s %(levelname)s [%(name)s] %(message)s", level=logging.INFO, datefmt="%Y-%m-%d %H:%M:%S")# 量化分词器加载tokenizer = AutoTokenizer.from_pretrained(pretrained_model_dir, use_fast=False, trust_remote_code=True)examples = [tokenizer("auto-gptq is an easy-to-use model quantization library with user-friendly apis, based on GPTQ algorithm.")]# 量化参数配置quantize_config = BaseQuantizeConfig(bits=4,             # quantize model to 4-bitgroup_size=128,     # it is recommended to set the value to 128desc_act=False,     # set to False can significantly speed up inference but the perplexity may slightly bad)# load un-quantized model, by default, the model will always be loaded into CPU memoryquantize_model = AutoGPTQForCausalLM.from_pretrained(pretrained_model_dir, quantize_config=quantize_config, trust_remote_code=True,device_map="auto",)print("开始量化模型.......")quantize_model.quantize(examples)# save model weightsprint("保存量化文件...")quantize_model.save_quantized(quantized_model_dir)print("保存量化文件完成...")print("保存tokenizer...")tokenizer.save_pretrained(quantized_model_dir)print("保存tokenizer完成...")

按照上述步骤,此时模型量化文件保存成功,接下来就是模型在线推理

def get_baichuan2_autogptq(quantized_model_dir):from transformers import AutoModelForCausalLM, AutoTokenizerfrom transformers.generation.utils import GenerationConfigimport torch# 模型地址model_id = quantized_model_dirprint("加载分词器tokenizer...")tokenizer = AutoTokenizer.from_pretrained(model_id, trust_remote_code=True,use_fast=False)'''warnings.warn(f'Input type into Linear4bit is torch.float16, but bnb_4bit_compute_type=torch.float32 (default).This will lead to slow inference or training speed'''print("加载量化model...")quantized_model_4bit = AutoModelForCausalLM.from_pretrained(# 要载入的模型名称model_id, load_in_4bit=True,# 仅使用本地模型,不通过网络下载模型local_files_only=True,# 指定模型精度torch_dtype=torch.float16,trust_remote_code=True,safetensors=True)print("加载config...")quantized_model_4bit.generation_config = GenerationConfig.from_pretrained(model_id)# 实例测试print("生成...")messages = []messages.append({"role": "user", "content":"亚历山大为何如此厉害"})response = quantized_model_4bit.chat(tokenizer, messages)print(response)return response 

最后整合代码:

'''bin 文件是保存的是原始的加载模型文件,不涉及量化操作的模型过程,不然会报错或者加载不出来!!!'''
def save_bin(pretrained_model_dir, quantized_model_dir):from transformers import AutoModelForCausalLMimport torchimport osoriginal_model = AutoModelForCausalLM.from_pretrained(pretrained_model_dir, trust_remote_code=True,torch_dtype=torch.float16,      # 不执行这个保存的bin文件会非常的大,大概50多Gsafetensors=True)print("保存bin文件...")model_path = os.path.join(quantized_model_dir, "pytorch_model"+".bin")torch.save(original_model.state_dict(), model_path)print("保存bin文件完成...")# auto_gptq原生库, 量化占用显存7-10G不等,用时23分钟,推理18G
def from_authority_autogptq(pretrained_model_dir, quantized_model_dir):from transformers import AutoTokenizer, AutoModelForCausalLMfrom auto_gptq import AutoGPTQForCausalLM, BaseQuantizeConfigimport loggingimport torchimport oslogging.basicConfig(format="%(asctime)s %(levelname)s [%(name)s] %(message)s", level=logging.INFO, datefmt="%Y-%m-%d %H:%M:%S")# 量化分词器加载tokenizer = AutoTokenizer.from_pretrained(pretrained_model_dir, use_fast=False, trust_remote_code=True)examples = [tokenizer("auto-gptq is an easy-to-use model quantization library with user-friendly apis, based on GPTQ algorithm.")]# 量化参数配置quantize_config = BaseQuantizeConfig(bits=4,             # quantize model to 4-bitgroup_size=128,     # it is recommended to set the value to 128desc_act=False,     # set to False can significantly speed up inference but the perplexity may slightly bad)# load un-quantized model, by default, the model will always be loaded into CPU memoryquantize_model = AutoGPTQForCausalLM.from_pretrained(pretrained_model_dir, quantize_config=quantize_config, trust_remote_code=True,device_map="auto",)print("开始量化模型.......")quantize_model.quantize(examples)# save model weightsprint("保存量化文件...")quantize_model.save_quantized(quantized_model_dir)print("保存量化文件完成...")print("保存tokenizer...")tokenizer.save_pretrained(quantized_model_dir)print("保存tokenizer完成...")# 加载量化后的模型方法
def get_baichuan2_autogptq(quantized_model_dir):from transformers import AutoModelForCausalLM, AutoTokenizerfrom transformers.generation.utils import GenerationConfigimport torch# 模型地址model_id = quantized_model_dirprint("加载分词器tokenizer...")tokenizer = AutoTokenizer.from_pretrained(model_id, trust_remote_code=True,use_fast=False)'''warnings.warn(f'Input type into Linear4bit is torch.float16, but bnb_4bit_compute_type=torch.float32 (default).This will lead to slow inference or training speed'''print("加载量化model...")quantized_model_4bit = AutoModelForCausalLM.from_pretrained(# 要载入的模型名称model_id, load_in_4bit=True,# 仅使用本地模型,不通过网络下载模型local_files_only=True,# 指定模型精度torch_dtype=torch.float16,trust_remote_code=True,safetensors=True)print("加载config...")quantized_model_4bit.generation_config = GenerationConfig.from_pretrained(model_id)# 实例测试print("生成...")messages = []messages.append({"role": "user", "content":"```桥架\n1、名称:机房走线架(铝合金) 2、规格:300mm*100mm 3、含支吊架制作安装 4、其它:具体详见图纸、技术规范书、图集、招标文件、招标答疑、政府相关文件、规范等其它资料,满足验收要求```\n请仔细阅读上文,并从中分析出实体列表中的各实体。请使用json字典格式回答,其中,键为各实体名称,值为从文本中提取出的内容(若没有相应实体则值为'无')。\n实体列表如下(目标实体之间通过“;”隔开): ```名称;型号;材质;类型;规格;接地方式```"})response = quantized_model_4bit.chat(tokenizer, messages)print(response)return response if __name__ == "__main__":# from_transformers_autogptq 方法量化模型# pretrained_model_dir = "/root/lk/big_model/Baichuan2-13B-Chat"# quantized_model_dir = "/root/lk/big_model/baichuan2_autogptq"# from_transformers_autogptq(pretrained_model_dir, quantized_model_dir)import datetimeprint("程序开始时间------->>>>>>", datetime.datetime.now().strftime('%Y-%m-%d %H:%M:%S'))# 地址pretrained_model_dir = "/root/lk/big_model/Baichuan2-13B-Chat"quantized_model_dir = "/root/lk/big_model/baichuan2_autogptq"# 第一步:保存原始模型的Bin文件,然后再量化(很关键)# save_bin(pretrained_model_dir, quantized_model_dir)# 第二部:执行来自autogptq原始包量化模型# from_authority_autogptq(pretrained_model_dir, quantized_model_dir)# 第三部:使用量化模型进行推理(需要添加对应文件)get_baichuan2_autogptq(quantized_model_dir)print("程序结束时间------->>>>>>", datetime.datetime.now().strftime('%Y-%m-%d %H:%M:%S'))

对应包版本:

auto-gptq==0.6.0
transformers==4.39.2
torch==2.0.1

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.mzph.cn/pingmian/2064.shtml

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

javascript1

[TOC](javascript初始)一.编程语言 编程 计算机语言 编程语言:汇编语言 高级语言:java c python javascript 标记语言:css html 二.计算机基础 计算机组成 数据存储 单位 运行 三. javascript 初识 历史? 布兰登. 艾奇创立 是什么?…

HarmonyOS ArkUI实战开发-手势密码(PatternLock)

ArkUI开发框架提供了图案密码锁 PatternLock 组件,它以宫格图案的方式输入密码,用于密码验证,本节读者简单介绍一下该控件的使用。 PatternLock定义介绍 interface PatternLockInterface {(controller?: PatternLockController): PatternL…

react 父组件调用子组件的属性或方法

前言 在vue3中, 子组件会使用 defineExpose 暴露出父组件需要访问的 变量 或 方法父组件通过 ref 函数定义子组件的 refName,并通过 refName.value.xxx 继续访问 react 中呢? 可使用 useImperativeHandle、forwardRef、useRef 第一步&am…

Unity2D 学习笔记 1.如何高效切换场景

Unity2D 学习笔记 1.如何高效切换场景 前言采用Scene的方式切换创建场景设置场景模板保存模板使用方法 前言 关于Unity版本,VS Studio版本以及其它相关设置,请移步上一篇Unity2D 学习笔记 0.Unity需要记住的常用知识 本节的方法源于Unity中文课程网《U…

Ubuntu下使用VisualStudioCode进行Java开发

0-1开始Java语言编程之路 一、Ubuntu下Java语言环境搭建 二、Ubuntu下Docker环境安装 三、使用Docker搭建本地Nexus Maven私有仓库 四、Ubuntu下使用VisualStudioCode进行Java开发 Visual Studio Code 下载 点击这个链接Visual Studio Code,进入VisualStudioCode的…

一文看懂电位器的接线方式

电位器是一种用于精确控制电路中电压、电流或信号幅度的电子元件,通过调整其内部电刷相对于电阻体的位置,可以连续改变其电阻值,进而实现对电路特性的微调或控制。根据电阻体材料、结构特点以及输出电压与输入电压(或电刷位移&…

Java | Leetcode Java题解之第43题字符串相乘

题目: 题解: class Solution {public String multiply(String num1, String num2) {if (num1.equals("0") || num2.equals("0")) {return "0";}int m num1.length(), n num2.length();int[] ansArr new int[m n];for…

3D Gaussian Splatting介绍

目录 一、概述二、基础介绍1. 多维高斯分布2. 将3D 高斯投影到2D像素平面3. 球谐函数4. Splatting and α \alpha α blending 三、整体流程四、 伪代码五、评价指标六、实验结果七、reference 一、概述 3D Gaussian Splatting和NeRF一样,主要用于新视图合成。 特…

三相电子式电表ADL400储能防逆流含CE/MID认证

安科瑞薛瑶瑶18701709087/17343930412 ADL400 导轨式多功能电能表,是主要针对电力系统,工矿企业,公用设施的电能统计、 管理需求而设计的一款智能仪表,产品具有精度高、体积小、安装方便等优点。集成 常见 电 力参数测量及电能…

Ozone V3.32a使用总结

目录 前言 Ozone介绍 Ozone下载使用 总结 前言 由于项目需要,现在正在使用Ozone作为软件debug的工具,不同于Keil集成了代码编辑器,编译器,调试器,Ozone则主要作为一个代码调试工具使用。最近发现Ozone还有些功能挺…

算法-合并素数

给一个数组,每次操作可以把相邻的两个素数元素进行合并, 合并后的新数为原来的两个数之和,并删除原来两个数。现在希望最终数组的元素数量尽可能少。 输入 第一行 n 代表数组元素个数 第二行 数组的各个元素 4 7 2 2 3 输出 最终的个数 1 pac…

vue项目打包时因为图片问题报错

执行 npm run build命令打包项目时报错,看起来是图片的问题: package.json里面image-webpack-loader的版本是^7.0.1 解决方案: 1、先卸载 npm uninstall image-webpack-loader 2、用cnpm重新安装 cnpm install image-webpack-loader --save…

PLSQL程序块中的无名块

文章目录 PLsql ---过程化语言程序块:无名块变量利用 select into 语句给变量赋值 打印输出手动输入变量类型引用型变量类型%TYPE%ROWTYPE 记录型变量类型 在程序块下的增删改RETURNING INTO增加数据修改数据删除数据 PLsql —过程化语言 程序块 plsql是Oracle默认…

单细胞+RIP-seq项目文章| Cell ReportshnRNPU蛋白在小鼠精原干细胞池建立的关键作用

精原干细胞(SSCs)是负责精子发生的干细胞,具有自我更新和分化产生功能性精子的能力。SSCs的持续再生对于维持雄性生育力至关重要。然而,SSC池的发育起源尚不清楚。在哺乳动物中,SSCs源自名为 prospermatogonia&#xf…

Android Studio开发工具学习之Git远程仓库拉取与推送

Git远程仓库操作 1.1 推送项目到远端服务器1.1.1 进入Gitee或Github、创建一个新的仓库1.1.2 将Android Studio中项目推送至Gitee 1.2 从远端服务器拉取项目1.2.1 AS工程页拉取新项目1.2.2 AS启动页拉取项目 1.1 推送项目到远端服务器 1.1.1 进入Gitee或Github、创建一个新的仓…

微服务两种方式登录

目录 1.restTemplate方式 1.1页面 1.2消费者 1.3生产者 1.4效果 2.Feign方式 2.1Service 2.2生产者 三个生产者 一个消费者,三个生产者需要用mysqlmybatis 三个不同的数据库。 页面输入用户名和密码,提交到后端消费者,消费者传到生产…

RabbitMQ入门实战

文章目录 RabbitMQ入门实战基本概念安装快速入门单向发送多消费者 RabbitMQ入门实战 官方:https://www.rabbitmq.com 基本概念 AMQP协议:https://www.rabbitmq.com/tutorials/amqp-concepts.html 定义:高级信息队列协议(Advanc…

让流程图动起来

我们平时画流程,然后贴到文档,就完事了。但是过程演示的时候,如果只是一张静态图,很难吸引到听众的注意力,表达效果并不太好。常用的方法是可以用PPT进行动态演示,做PPT也是需要花一些时间,同时…

7、OpenCompass 大模型评测实战

0、为什么要研究大模型的评测? 首先,研究评测对于我们全面了解大型语言模型的优势和限制至关重要。尽管许多研究表明大型语言模型在多个通用任务上已经达到或超越了人类水平,但仍然存在质疑,即这些模型的能力是否只是对训练数据的…

【数据结构(邓俊辉)学习笔记】向量01——接口与实现

文章目录 0.意图1、概述2 从数组到向量3 向量ADT接口4 Vector 模板类5 构造与析构5.1默认构造方法5.2基于复制的构造方法5.3 析构方法 0.意图 一方面是将工作学习中零星的知识点串起来,另一方面向量是其他数据类型的基础,比如栈队列等,所以基…