Qwen量化脚本run_gptq.py解析

Qwen量化脚本run_gptq.py解析

代码路径 https://github.com/QwenLM/Qwen/
run_gptq.py路径 https://github.com/QwenLM/Qwen/blob/main/run_gptq.py

代码解析:

import argparse
import json
from typing import Dict
import loggingimport torch
import transformers
from transformers import AutoTokenizer
from transformers.trainer_pt_utils import LabelSmoother
from auto_gptq import AutoGPTQForCausalLM, BaseQuantizeConfig
IGNORE_TOKEN_ID = LabelSmoother.ignore_index#其中json文件格式如下
# [
#   {
#     "id": "identity_0",
#     "conversations": [
#       {
#         "from": "user",
#         "value": "xxxx"
#       },
#       {
#         "from": "assistant",
#         "value": "xxx"
#       }
#     ]
#   },
#   {
#     "id": "identity_1",
#     "conversations": [
#       {
#         "from": "user",
#         "value": "xxx"
#       },
#       {
#         "from": "assistant",
#         "value": "xxx"
#       }
#     ]
#   },
# ]def preprocess(sources,tokenizer: transformers.PreTrainedTokenizer,max_len: int,system_message: str = "You are a helpful assistant."
) -> Dict:"""preprocess函数接收一个包含对话数据的json列表作为输入,\n通过调用transformers库中的tokenizer对数据进行编码,\n并按照特定格式构建输入ID序列和目标ID序列.\n返回一个包含预处理数据的列表,这些数据已转换为PyTorch张量,适合于后续模型训练或推断"""#roles字典:为对话中的角色("user"和"assistant")分配特殊的前缀标签,用于区分对话双方roles = {"user": "<|im_start|>user", "assistant": "<|im_start|>assistant"}#im_start和im_end:指定tokenizer中im_start_id和im_end_id对应的整数ID。im_start = tokenizer.im_start_idim_end = tokenizer.im_end_id#nl_tokens:存储tokenizer处理换行符\n得到的输入ID序列。nl_tokens = tokenizer('\n').input_ids#_system、_user和_assistant:分别存储经过tokenizer处理后的"system"、"user"和"assistant"标签及其后的换行符对应的输入ID序列。_system = tokenizer('system').input_ids + nl_tokens_user = tokenizer('user').input_ids + nl_tokens_assistant = tokenizer('assistant').input_ids + nl_tokens# Apply prompt templates 定义空列表data,用于存放预处理后的数据样本data = []# input_ids, targets = [], []#遍历输入数据sources中的每个样本(source)for i, source in enumerate(sources):source = source["conversations"]#检查首个对话是否由用户发起(即source[0]["from"]是否为"user"),如果不是,则从源数据中移除首个对话。#过滤无效的identityif roles[source[0]["from"]] != roles["user"]:source = source[1:]#初始化空列表input_id和target,分别用于存储当前样本的输入ID序列和目标ID序列input_id, target = [], []#添加系统消息:将系统消息(包含system_message内容)转换为ID序列,添加到input_id和target中。system = [im_start] + _system + tokenizer(system_message).input_ids + [im_end] + nl_tokensinput_id += system#target中的非关键部分(如系统标签和消息内容)用IGNORE_TOKEN_ID填充。target += [im_start] + [IGNORE_TOKEN_ID] * (len(system)-3) + [im_end] + nl_tokensassert len(input_id) == len(target)#遍历源数据中的每个对话(sentence)for j, sentence in enumerate(source):# 提取角色和消息内容,并转换为ID序列role = roles[sentence["from"]]_input_id = tokenizer(role).input_ids + nl_tokens + \tokenizer(sentence["value"]).input_ids + [im_end] + nl_tokens#添加到input_id中input_id += _input_id#根据角色类型,生成对应_target的目标ID序列,_target只提取assistant的对话内容,忽略user的对话内容。if role == '<|im_start|>user':#若角色为"user",则目标ID序列仅包含开始标签和结束标签,用忽略ID填充对话内容。_target = [im_start] + [IGNORE_TOKEN_ID] * (len(_input_id)-3) + [im_end] + nl_tokens#若角色为"assistant",则目标ID序列包含开始标签、忽略ID填充(仅对角色标签)、对话内容(不包括角色标签和结束标签)、结束标签elif role == '<|im_start|>assistant':_target = [im_start] + [IGNORE_TOKEN_ID] * len(tokenizer(role).input_ids) + \_input_id[len(tokenizer(role).input_ids)+1:-2] + [im_end] + nl_tokenselse:raise NotImplementedErrortarget += _targetassert len(input_id) == len(target)#截取并转换为张量:#截取input_id和target至最大长度max_leninput_id = torch.tensor(input_id[:max_len], dtype=torch.int)target = torch.tensor(target[:max_len], dtype=torch.int)#创建一个字典,包含键input_ids(存储输入张量)和attention_mask(等于输入张量,用于指示非填充位置)。将该字典添加到data列表中data.append(dict(input_ids=input_id, attention_mask=input_id.ne(tokenizer.pad_token_id)))return dataif __name__ == "__main__":parser = argparse.ArgumentParser("Model Quantization using AutoGPTQ")parser.add_argument("--model_name_or_path", type=str, help="model path")parser.add_argument("--data_path", type=str, help="calibration data path")parser.add_argument("--out_path", type=str, help="output path of the quantized model")parser.add_argument("--max_len", type=int, default=8192, help="max length of calibration data")parser.add_argument("--bits", type=int, default=4, help="the bits of quantized model. 4 indicates int4 models.")parser.add_argument("--group-size", type=int, default=128, help="the group size of quantized model")args = parser.parse_args()quantize_config = BaseQuantizeConfig(bits=args.bits,group_size=args.group_size,damp_percent=0.01,desc_act=False,  # set to False can significantly speed up inference but the perplexity may slightly badstatic_groups=False,sym=True,true_sequential=True,model_name_or_path=None,model_file_base_name="model")#使用AutoTokenizer类从给定路径args.model_name_or_path加载预训练的tokenizertokenizer = AutoTokenizer.from_pretrained(args.model_name_or_path, trust_remote_code=True)tokenizer.pad_token_id = tokenizer.eod_id#加载json数据文件,调用process函数预处理数据,返回处理后的数据data = preprocess(json.load(open(args.data_path)), tokenizer, args.max_len)#加载预训练的模型model = AutoGPTQForCausalLM.from_pretrained(args.model_name_or_path, quantize_config, device_map="auto", trust_remote_code=True)logging.basicConfig(format="%(asctime)s %(levelname)s [%(name)s] %(message)s", level=logging.INFO, datefmt="%Y-%m-%d %H:%M:%S")#对模型进行量化,不在GPU上缓存示例数据model.quantize(data, cache_examples_on_gpu=False)#保存量化后的模型model.save_quantized(args.out_path, use_safetensors=True)#将tokenizer保存到输出路径tokenizer.save_pretrained(args.out_path)

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.mzph.cn/web/1274.shtml

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

音视频封装格式解析(1)——H264格式简析,I/P/B帧是什么?H264压缩原理

文章目录 1. H264编码参数2. H264编码原理2.1 压缩原理2.2 编码结构解析 3. NALU结构4. H264 annexb模式5. 补充说明5.1 I帧5.2 P帧5.3 B帧 1. H264编码参数 视频质量和⽹络带宽占⽤是相⽭盾的。通常情况下&#xff0c;视频流占⽤的带宽越⾼则视频质量也越⾼&#xff0c;需要的…

用友NC Cloud saveImageServlet/doPost接口存在任意文件上传漏洞

声明&#xff1a; 本文仅用于技术交流&#xff0c;请勿用于非法用途 由于传播、利用此文所提供的信息而造成的任何直接或者间接的后果及损失&#xff0c;均由使用者本人负责&#xff0c;文章作者不为此承担任何责任。 简介 用友NC Cloud 是基于云计算技术的企业管理软件。它提…

git push 错误: RPC failed; HTTP 413 curl 22 The requested URL returned error: 413

问题 在git push时&#xff0c;发生了如下错误&#xff1a; git 提交代码报错 &#xff1a;error: RPC failed; HTTP 413 curl 22 The requested URL returned error: 413 原因分析 一般两种&#xff1a; 本地git缓存设置太小。 这个的解决方法是&#xff1a;设置缓存大小 g…

lvm动态扩容

1、场景 因项目需要&#xff0c;前期磁盘评估不准确&#xff0c;系统在用了一段时间后发现磁盘空间不够用了&#xff0c;通过查看磁盘做的有lvm&#xff0c;因此联系了云平台的进行了扩容&#xff0c;扩容后如第二部分所看到的vdb是1000G&#xff0c;但是通过查看磁盘/home/da…

网络安全产品---数据库防火墙/审计

数据库防火墙 防火墙的类型繁多&#xff0c;即使下一代防火墙或者说AI防火墙集成功能再多&#xff0c;我觉得waf与数据库防火墙也有其无法被替代的理由&#xff0c;以此记录我对数据库防火墙的理解 what 数据库防火墙是基于数据库协议分析与访问行为控制的数据库安全防护产品…

IOS 32位调试环境搭建

一、背景 调试IOS程序经常使用gdb&#xff0c;目前gdb只支持32位程序调试&#xff0c;暂不支持IOS 64位程序调试。IOS 32位程序使用GDB调试之前&#xff0c;必须确保手机已越狱&#xff0c;否则无法安装和使用GDB调试软件。下面详细介绍GDB调试IOS 32位程序的环境搭建。 二、I…

MapReduce排序机制(Hadoop)

在MapReduce中&#xff0c;排序的目的是为了方便Reduce阶段的处理&#xff0c;通常是为了将相同键的键值对聚合在一起&#xff0c;以便进行聚合操作或其他处理。 1. Map阶段的局部排序&#xff08;Local Sorting&#xff09;&#xff1a; 对于MapTask&#xff0c;它会将处理的…

python爬虫-----深入了解 requests 库下篇(第二十五天)

&#x1f388;&#x1f388;作者主页&#xff1a; 喔的嘛呀&#x1f388;&#x1f388; &#x1f388;&#x1f388;所属专栏&#xff1a;python爬虫学习&#x1f388;&#x1f388; ✨✨谢谢大家捧场&#xff0c;祝屏幕前的小伙伴们每天都有好运相伴左右&#xff0c;一定要天天…

一文学会 ts 构建工具 —— tsup

文章目录 能打包什么&#xff1f;安装用法自定义配置文件条件配置在 package.json 中配置多入口打包生成类型声明文件sourcemap生成格式自定义输出文件代码分割产物目标环境支持 es5编译的环境变量对开发命令行工具友好监听模式 watch提供成功构建的钩子 onSuccess压缩产物 min…

RocketMQ学习笔记

kafka适合于日志收集的场景&#xff08;不需要太多topic&#xff1b;topic下面的partition多了会造成写文件的速度变慢&#xff0c;因为要造很多索引&#xff09; RocketMQ更适合于电商场景&#xff08;适用于topic特别多的情况&#xff09; 快速安装RocketMQ RocketMQ的官网…

C++进阶:搜索树

目录 1. 二叉搜索树1.1 二叉搜索树的结构1.2 二叉搜索树的接口及其优点与不足1.3 二叉搜索树自实现1.3.1 二叉树结点结构1.3.2 查找1.3.3 插入1.3.4 删除1.3.5 中序遍历 2. 二叉树进阶相关练习2.1 根据二叉树创建字符串2.2 二叉树的层序遍历I2.3 二叉树层序遍历II2.4 二叉树最近…

一、MinIO基本知识

MinIO基本知识 一、简介1.许可 二、部署1.Docker部署1.1 部署容器 1.2 MinIO页面访问1.3 创建Bucket![在这里插入图片描述](https://img-blog.csdnimg.cn/direct/6c8aa92975f146b691f1f36ce1033e7c.png) 三、Python-API1.安装包2.Bucket、Object概念3.Bucket-API4.MinIOClient-…

【算法刷题day29】Leetcode:491. 非递减子序列、46. 全排列、47. 全排列 II

文章目录 Leetcode 491. 非递减子序列解题思路代码总结 Leetcode 46. 全排列解题思路代码总结 Leetcode 47. 全排列 II解题思路代码总结 草稿图网站 java的Deque Leetcode 491. 非递减子序列 题目&#xff1a;491. 非递减子序列 解析&#xff1a;代码随想录解析 解题思路 大题…

C++修炼之路之多态--多态的条件与例外,重载+重写+重定义

目录 前言 一&#xff1a;构成多态的条件及一些特殊情况&#xff08;前提是构成父子类&#xff09; 1.多态是在不同的继承关系的类对象&#xff0c;去调用同一函数&#xff0c;产生了不同的结果 2.两个条件 3.三同的两个例外 1.协变---返回值类型可以不同&#xff0c;但必…

c++ qt6.5 打包sqlite组件无法使用,尽然 也需要dll支持!这和开发php 有什么区别!

运行 程序会默认使用当前所在文件夹中的 dll 文件&#xff0c;若文件不存在&#xff0c;会使用系统环境变量路径中的文件&#xff1b;又或者是需要在程序源代码中明确指定使用的 dll 的路径。由于我安装 Qt 时将相关 dll 文件路径都添加到了系统环境变量中&#xff0c;所以即使…

【R语言】混合图:小提琴图+箱线图

{ggstatsplot} 是 {ggplot2} 包的扩展&#xff0c;用于创建图形&#xff0c;其中包含信息丰富的绘图本身中包含的统计测试的详细信息。在典型的探索性数据分析工作流程中&#xff0c;数据可视化和统计建模是两个不同的阶段&#xff1a;可视化通知建模&#xff0c;而建模又可以建…

Vue 3 Hooks:优雅管理组件状态的完整指南

一、介绍 Hooks是Vue 3中的特性&#xff0c;允许在函数组件中使用状态和其他React的逻辑。本教程将演示如何使用TypeScript和Hooks管理Vue 3组件的状态和生命周期。 二、创建Hooks 首先&#xff0c;创建一个hooks.ts文件&#xff0c;包含自定义hooks。 import { ref, onMou…

MapReduce工作流程(Hadoop3.x)

MapReduce 是一种用于并行处理大规模数据集的——编程模型和处理框架。它通常用于分布式计算环境中&#xff0c;如Apache Hadoop。 工作流程 1. 切分阶段&#xff08;Splitting&#xff09;&#xff1a; 数据集被分成多个数据块&#xff0c;每个数据块的大小通常在64MB到12…

Spring+Thymeleaf自定义Formatter

Thymeleaf 关于自定义转换的文档 Configuring a Conversion Service 中&#xff0c;是通过WebMvcConfigurerAdapter 来配置的&#xff0c;但是目前最新版的Spring Boot(V3.2.5)&#xff0c;已经没有这个类了&#xff0c;得用 WebMvcConfigurationSupport 配置&#xff0c;比如实…

kaggle 泰坦尼克使用xgboost 得分0.73684

流程 导入所要使用的包引入kaggle的数据集csv文件查看数据集有无空值填充这些空值提取特征分离训练集和测试集调用模型 导入需要的包 import pandas as pd import numpy as np import matplotlib.pyplot as plt import seaborn as sns import warnings warnings.filterwarni…