使用PEFT库进行ChatGLM3-6B模型的LORA高效微调

PEFT库进行ChatGLM3-6B模型LORA高效微调

  • LORA微调ChatGLM3-6B模型
    • 安装相关库
    • 使用ChatGLM3-6B
    • 模型GPU显存占用
    • 准备数据集
    • 加载模型
    • 加载数据集
    • 数据处理
    • 数据集处理
    • 配置LoRA
    • 配置训练超参数
    • 开始训练
    • 保存LoRA模型
    • 模型推理
    • 从新加载
    • 合并模型
    • 使用微调后的模型

LORA微调ChatGLM3-6B模型

本文基于transformers、peft等框架,对ChatGLM3-6B模型进行Lora微调。

LORA(Low-Rank Adaptation)是一种高效的模型微调技术,它可以通过在预训练模型上添加额外的低秩权重矩阵来微调模型,从而仅需更新很少的参数即可获得良好的微调性能。这相比于全量微调大幅减少了训练时间和计算资源的消耗。

安装相关库

pip install ransformers==4.37.2 peft==0.8.0 accelerate==0.27.0 bitsandbytes

使用ChatGLM3-6B

直接调用ChatGLM3-6B模型来生成对话

from transformers import AutoTokenizer, AutoModelmodel_id = "/root/work/chatglm3-6b"
tokenizer = AutoTokenizer.from_pretrained(model_id, trust_remote_code=True)
#model = AutoModel.from_pretrained(model_id, trust_remote_code=True).half().cuda()
model = AutoModel.from_pretrained(model_id, trust_remote_code=True, device='cuda')model = model.eval()
response, history = model.chat(tokenizer, "你好", history=history)
print(response)

在这里插入图片描述

模型GPU显存占用

默认情况下,模型以半精度(float16)加载,模型权重需要大概 13GB显存。

获取当前模型占用的GPU显存

memory_bytes = model.get_memory_footprint()
# 转换为GB
memory_gb = memory_footprint_bytes / (1024 ** 3)  
print(f"{memory_gb :.2f}GB")

注意:与实际进程占用有差异,差值为预留给PyTorch的显存

准备数据集

准备数据集其实就是指令集构建,LLM的微调一般指指令微调过程。所谓指令微调,就是使用的微调数据格式、形式。

训练目标是让模型具有理解并遵循用户指令的能力。因此在指令集构建时,应该针对目标任务,针对性的构建任务指令集。

这里使用alpaca格式的数据集,格式形式如下:

[{"instruction": "用户指令(必填)","input": "用户输入(选填)","output": "模型回答(必填)",},"system": "系统提示词(选填)","history": [["第一轮指令(选填)", "第一轮回答(选填)"],["第二轮指令(选填)", "第二轮回答(选填)"]]
]
instruction:用户指令,要求AI执行的任务或问题input:用户输入,是完成用户指令所必须的输入内容,就是执行指令所需的具体信息或上下文output:模型回答,根据给定的指令和输入生成答案

这里根据企业私有文档数据,生成相关格式的训练数据集,大概格式如下:

[{"instruction": "内退条件是什么?","input": "","output": "内退条件包括与公司签订正式劳动合同并连续工作满20年及以上,以及距离法定退休年龄不足5年。特殊工种符合国家相关规定可提前退休的也可在退休前5年内提出内退申请。"},
]

加载模型

from transformers import AutoModel, AutoTokenizermodel_id = "/root/work/chatglm3-6b"
model = AutoModel.from_pretrained(model_id, trust_remote_code=True)
tokenizer = AutoTokenizer.from_pretrained(model_id, trust_remote_code=True)

加载数据集

from datasets import load_datasetdata_id="/root/work/jupyterlab/zd.json"
dataset = load_dataset("json", data_files=data_id)
print(dataset["train"])

在这里插入图片描述

数据处理

Lora训练数据是需要经过tokenize编码处理,然后后再输入模型进行训练。一般需要将输入文本编码为input_ids,将输出文本编码为labels,编码之后的结果都是多维的向量。

需要定义一个预处理函数,这个函数用于对每一个样本,编码其输入、输出文本并返回一个编码后的字典。

# tokenize_func 函数
def tokenize_func(example, tokenizer, ignore_label_id=-100):"""对单个数据样本进行tokenize处理。参数:example (dict): 包含'content'和'summary'键的字典,代表训练数据的一个样本。tokenizer (transformers.PreTrainedTokenizer): 用于tokenize文本的tokenizer。ignore_label_id (int, optional): 在label中用于填充的忽略ID,默认为-100。返回:dict: 包含'tokenized_input_ids'和'labels'的字典,用于模型训练。"""prompt_text = ''                          # 所有数据前的指令文本max_input_length = 512                    # 输入的最大长度max_output_length = 1536                  # 输出的最大长度# 构建问题文本question = prompt_text + example['instruction']if example.get('input', None) and example['input'].strip():question += f'\n{example["input"]}'# 构建答案文本answer = example['output']# 对问题和答案文本进行tokenize处理q_ids = tokenizer.encode(text=question, add_special_tokens=False)a_ids = tokenizer.encode(text=answer, add_special_tokens=False)# 如果tokenize后的长度超过最大长度限制,则进行截断if len(q_ids) > max_input_length - 2:  # 保留空间给gmask和bos标记q_ids = q_ids[:max_input_length - 2]if len(a_ids) > max_output_length - 1:  # 保留空间给eos标记a_ids = a_ids[:max_output_length - 1]# 构建模型的输入格式input_ids = tokenizer.build_inputs_with_special_tokens(q_ids, a_ids)question_length = len(q_ids) + 2  # 加上gmask和bos标记# 构建标签,对于问题部分的输入使用ignore_label_id进行填充labels = [ignore_label_id] * question_length + input_ids[question_length:]return {'input_ids': input_ids, 'labels': labels}

进行数据映射处理,同时删除特定列

# 获取 'train' 部分的列名
column_names = dataset['train'].column_names  # 使用lambda函数调用tokenize_func函数,并传入example和tokenizer作为参数
tokenized_dataset = dataset['train'].map(lambda example: tokenize_func(example, tokenizer),batched=False,  # 不按批次处理remove_columns=column_names  # 移除特定列(column_names中指定的列)
)

执行print(tokenized_dataset[0]),打印tokenize处理结果
在这里插入图片描述

数据集处理

还需要使用一个数据收集器,可以使用transformers 中的DataCollatorForSeq2Seq数据收集器

from transformers import DataCollatorForSeq2Seqdata_collator = DataCollatorForSeq2Seq(tokenizer,model=model,label_pad_token_id=-100,pad_to_multiple_of=None,padding=True
)

或者自定义实现一个数据收集器

import torch
from typing import List, Dict, Optional# DataCollatorForChatGLM 类
class DataCollatorForChatGLM:"""用于处理批量数据的DataCollator,尤其是在使用 ChatGLM 模型时。该类负责将多个数据样本(tokenized input)合并为一个批量,并在必要时进行填充(padding)。属性:pad_token_id (int): 用于填充(padding)的token ID。max_length (int): 单个批量数据的最大长度限制。ignore_label_id (int): 在标签中用于填充的ID。"""def __init__(self, pad_token_id: int, max_length: int = 2048, ignore_label_id: int = -100):"""初始化DataCollator。参数:pad_token_id (int): 用于填充(padding)的token ID。max_length (int): 单个批量数据的最大长度限制。ignore_label_id (int): 在标签中用于填充的ID,默认为-100。"""self.pad_token_id = pad_token_idself.ignore_label_id = ignore_label_idself.max_length = max_lengthdef __call__(self, batch_data: List[Dict[str, List]]) -> Dict[str, torch.Tensor]:"""处理批量数据。参数:batch_data (List[Dict[str, List]]): 包含多个样本的字典列表。返回:Dict[str, torch.Tensor]: 包含处理后的批量数据的字典。"""# 计算批量中每个样本的长度len_list = [len(d['input_ids']) for d in batch_data]batch_max_len = max(len_list)  # 找到最长的样本长度input_ids, labels = [], []for len_of_d, d in sorted(zip(len_list, batch_data), key=lambda x: -x[0]):pad_len = batch_max_len - len_of_d  # 计算需要填充的长度# 添加填充,并确保数据长度不超过最大长度限制ids = d['input_ids'] + [self.pad_token_id] * pad_lenlabel = d['labels'] + [self.ignore_label_id] * pad_lenif batch_max_len > self.max_length:ids = ids[:self.max_length]label = label[:self.max_length]input_ids.append(torch.LongTensor(ids))labels.append(torch.LongTensor(label))# 将处理后的数据堆叠成一个tensorinput_ids = torch.stack(input_ids)labels = torch.stack(labels)return {'input_ids': input_ids, 'labels': labels}
data_collator = DataCollatorForChatGLM(pad_token_id=tokenizer.pad_token_id)

配置LoRA

在peft中使用LoRA非常简单。借助PeftModel抽象,可以快速将低秩适配器(LoRA)应用到任意模型中。

在初始化相应的微调配置类(LoraConfig)时,需要显式指定在哪些层新增适配器(Adapter),并将其设置正确。

ChatGLM3-6B模型通过以下方式获取需要训练的模型层的名字

from peft.utils import TRANSFORMERS_MODELS_TO_LORA_TARGET_MODULES_MAPPINGtarget_modules = TRANSFORMERS_MODELS_TO_LORA_TARGET_MODULES_MAPPING['chatglm']

在PEFT库的 constants.py 文件中定义了不同的 PEFT 方法,在各类大模型上的微调适配模块。

在这里插入图片描述
主要是配置LoraConfig类,其中可以设置很多参数,但主要参数只有几个

# 从peft库导入LoraConfig和get_peft_model函数
from peft import LoraConfig, get_peft_model, TaskType# 创建一个LoraConfig对象,用于设置LoRA(Low-Rank Adaptation)的配置参数
config = LoraConfig(r=8,  # LoRA的秩,影响LoRA矩阵的大小lora_alpha=32,  # LoRA适应的比例因子# 指定需要训练的模型层的名字,不同模型对应层的名字不同# target_modules=["query_key_value"],target_modules=target_modules,lora_dropout=0.05,  # 在LoRA模块中使用的dropout率bias="none",  # 设置bias的使用方式,这里没有使用bias# task_type="CAUSAL_LM"  # 任务类型,这里设置为因果(自回归)语言模型task_type=TaskType.CAUSAL_LM
)# 使用get_peft_model函数和给定的配置来获取一个PEFT模型
model = get_peft_model(model, config)# 打印出模型中可训练的参数
model.print_trainable_parameters()

在这里插入图片描述

配置训练超参数

配置训练超参数使用TrainingArguments类,可配置参数同样有很多,但主要参数也是只有几个

from transformers import TrainingArguments, Trainertraining_args = TrainingArguments(output_dir=f"{model_id}-lora",  # 指定模型输出和保存的目录per_device_train_batch_size=4,  # 每个设备上的训练批量大小learning_rate=2e-4,  # 学习率fp16=True,  # 启用混合精度训练,可以提高训练速度,同时减少内存使用logging_steps=20,  # 指定日志记录的步长,用于跟踪训练进度save_strategy="steps",   # 模型保存策略save_steps=50,   # 模型保存步数# max_steps=50, # 最大训练步长num_train_epochs=1  # 训练的总轮数)

查看添加LoRA模块后的模型

print(model)

开始训练

配置model、参数、数据集后就可以进行训练了

trainer = Trainer(model=model,  # 指定训练时使用的模型train_dataset=tokenized_dataset,  # 指定训练数据集args=training_args,data_collator=data_collator,
)model.use_cache = False
# trainer.train() 
with torch.autocast("cuda"): trainer.train()

在这里插入图片描述

注意:

执行trainer.train() 时出现异常,参考:bitsandbytes的issues

保存LoRA模型

lora_model_path = "lora/chatglm3-6b-int8"
trainer.model.save_pretrained(lora_model_path )
#model.save_pretrained(lora_model_path )

在这里插入图片描述

模型推理

使用LoRA模型,进行模型推理

lora_model = trainer.model

1.文本补全

text = "人力资源部根据各部门人员"inputs = tokenizer(text, return_tensors="pt").to(0)out = lora_model.generate(**inputs, max_new_tokens=500)
print(tokenizer.decode(out[0], skip_special_tokens=True))

在这里插入图片描述
2.问答对话

from peft import PeftModelinput_text = '公司的招聘需求是如何提出的?'
model.eval()
response, history = lora_model.chat(tokenizer=tokenizer, query=input_text)
print(f'ChatGLM3-6B 微调后回答: \n{response}')

在这里插入图片描述

从新加载

加载源model与tokenizer,使用PeftModel合并源model与PEFT微调后的参数,然后进行推理测试。

from peft import PeftModel
from transformers import AutoModel, AutoTokenizermodel_path="/root/work/chatglm3-6b"
peft_model_checkpoint_path="./chatglm3-6b-lora/checkpoint-50"model = AutoModel.from_pretrained(model_path, trust_remote_code=True, low_cpu_mem_usage=True)
tokenizer = AutoTokenizer.from_pretrained(model_path, use_fast=False, trust_remote_code=True)# 将训练所得的LoRa权重加载起来
p_model = PeftModel.from_pretrained(model, model_id=peft_model_checkpoint_path) p_model = p_model.cuda()
response, history = p_model.chat(tokenizer, "内退条件是什么?", history=[])
print(response)

合并模型

将lora权重合并到大模型中,将模型参数加载为16位浮点数

from peft import PeftModel
from transformers import AutoModelForCausalLM, AutoTokenizer
import torch model_path="/root/work/chatglm3-6b"
peft_model_path="./lora/chatglm3-6b-int8"
save_path = "chatglm3-6b-lora"tokenizer = AutoTokenizer.from_pretrained(model_path, trust_remote_code=True)
model = AutoModelForCausalLM.from_pretrained(model_path, trust_remote_code=True, low_cpu_mem_usage=True, torch_dtype=torch.float16, device_map="auto")
model = PeftModel.from_pretrained(model, peft_model_path)
model = model.merge_and_unload()tokenizer.save_pretrained(save_path)
model.save_pretrained(save_path)

查看合并文件
在这里插入图片描述

使用微调后的模型

from transformers import AutoTokenizer, AutoModeltokenizer = AutoTokenizer.from_pretrained("chatglm3-6b-lora", trust_remote_code=True)
model = AutoModel.from_pretrained("chatglm3-6b-lora", trust_remote_code=True, device='cuda')model = model.eval()
response, history = model.chat(tokenizer, "内退条件是什么?", history=[])
print(response)

在这里插入图片描述

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.mzph.cn/diannao/35910.shtml

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

6 序列数据和文本的深度学习

6.1 使用文本数据 文本是常用的序列化数据类型之一。文本数据可以看作是一个字符序列或词的序列。对大多数问题,我们都将文本看作词序列。深度学习序列模型(如RNN及其变体)能够从文本数据中学习重要的模式。这些模式可以解决类似以下领域中的问题: 自然…

JVM专题十一:JVM 中的收集器一

上一篇JVM专题十:JVM中的垃圾回收机制专题中,我们主要介绍了Java的垃圾机制,包括垃圾回收基本概念,重点介绍了垃圾回收机制中自动内存管理与垃圾收集算法。如果说收集算法是内存回收的方法论,那么垃圾收集器就是内存回…

【开发者推荐】告别繁琐:一键解锁国产ETL新贵,Kettle的终结者

在数字化转型的今天,数据集成的重要性不言而喻。ETL工具作为数据管理的核心,对企业决策和运营至关重要。尽管Kettle广受欢迎,但国产ETL工具 TASKCTL 以其创新特性和卓越性能,为市场提供了新的选择。 TASKCTL概述 TASKCTL 是一款免…

wget之Win11中安装及使用

wget之Win11中安装及使用 文章目录 wget之Win11中安装及使用1. 下载2. 安装3. 配置环境变量4. 查看及使用1. 查看版本2. 帮助命令3. 基本使用 1. 下载 下载地址:https://eternallybored.org/misc/wget 选择对应的版本进行下载即可 2. 安装 将下载后的wget-1.21.4-w…

算法导论 总结索引 | 第四部分 第十六章:贪心算法

1、求解最优化问题的算法 通常需要经过一系列的步骤,在每个步骤都面临多种选择。对于许多最优化问题,使用动态规划算法求最优解有些杀鸡用牛刀了,可以使用更简单、更高效的算法 贪心算法(greedy algorithm)就是这样的算…

Git 学习笔记(超详细注释,从0到1)

Git学习笔记 1.1 关键词 Fork、pull requests、pull、fetch、push、diff、merge、commit、add、checkout 1.2 原理(看图学习) 1.3 Fork别人仓库到自己仓库中 记住2个地址 1)上游地址(upstream地址):http…

【Qt】Qt多线程编程指南:提升应用性能与用户体验

文章目录 前言1. Qt 多线程概述2. QThread 常用 API3. 使用线程4. 多线的使用场景5. 线程安全问题5.1. 加锁5.2. QReadWriteLocker、QReadLocker、QWriteLocker 6. 条件变量 与 信号量6.1. 条件变量6.2 信号量 总结 前言 在现代软件开发中,多线程编程已成为一个不可…

C语言类型转换理解不同的基本类型为什么能够进行运算

类型转换 1.类型转换1.1隐式转换1.2常用算术转换1.2强制类型转换 1.类型转换 在执行算数运算时,计算机比C语言的限制更多。为了让计算机执行算术运算,通常要求操作数用相同的大小(即为的数量相同),但是C语言却允许混合…

Java基础:常用类(四)

Java基础:常用类(四) 文章目录 Java基础:常用类(四)1. String字符串类1.1 简介1.2 创建方式1.3 构造方法1.4 连接操作符1.5 常用方法 2. StringBuffer和StringBuilder类2.1 StringBuffer类2.1.1 简介2.1.2 …

智能电能表如何助力智慧农业

智能电能表作为智能电网数据采集的基本设备之一,不仅具备传统电能表基本用电量的计量功能,还具备双向多种费率计量功能、用户端控制功能、多种数据传输模式的双向数据通信功能以及防窃电功能等智能化的功能。这些功能使得智能电能表在农业领域的应用具有…

【渗透测试】小程序反编译

前言 在渗透测试时,除了常规的Web渗透,小程序也是我们需要重点关注的地方,微信小程序反编译后,可以借助微信小程序开发者工具进行调试,搜索敏感关键字,或许能够发现泄露的AccessKey等敏感信息及数据 工具…

【SkiaSharp绘图11】SKCanvas属性详解

文章目录 SKCanvas构造SKCanvas构造光栅 Surface构造GPU Surface构造PDF文档构造XPS文档构造SVG文档SKNoDrawCanvas 变换剪裁和状态构造函数相关属性DeviceClipBounds获取裁切边界(设备坐标系)ClipRect修改裁切区域IsClipEmpty当前裁切区域是否为空IsClipRect裁切区域是否为矩形…

JFreeChart 生成Word图表

文章目录 1 思路1.1 概述1.2 支持的图表类型1.3 特性 2 准备模板3 导入依赖4 图表生成工具类 ChartWithChineseExample步骤 1: 准备字体文件步骤 2: 注册字体到FontFactory步骤 3: 设置图表具体位置的字体柱状图:饼图:折线图:完整代码&#x…

国产车规MCU OTA方案总结

目录 1. 旗芯微FC4150 OTA 2. 云途YTM32B1MD OTA 3.小结 今天没有废话,啪一下很快,把目前接触到的国内带eFlash的车规MCU硬件OTA方案做一个梳理。 1. 旗芯微FC4150 OTA 旗芯微FC4150是基于ARM Cortex(快去审核下官网介绍,少了个T)-M4F内…

openGauss Developer Day 2024丨MogDB实现数据库技术跨越,Ustore引擎革新存储新境界

openGauss Developer Day 2024 6月21日,openGauss Developer Day 2024在北京昆泰嘉瑞文化中心成功召开。大会聚集学术专家、行业用户、合作伙伴和开发者,共同探讨数据库面向多场景的技术创新,分享基于 openGauss 的行业联合创新成果及实践案例…

探索PHP中的魔术常量

PHP中的魔术常量(Magic Constants)是一些特殊的预定义常量,它们在不同的上下文中具有不同的值。这些常量可以帮助开发者获取文件路径、行号、函数名等信息,从而方便调试和日志记录。本文将详细介绍PHP中的魔术常量,帮助…

web前端——javaScript

目录 一、javaScript概述 1.javaScript历史 2.JavaScript与html,css关系 二、基本语法 ①放在head中 ②放在 body中 ③写在外部的.js文件中 1.变量 2.数据类型 3.算术运算符 4.逻辑运算符 5.赋值运算 6.逻辑运算符 7.条件运算符 8.控制语句 三、函数 1…

Arduino - 按钮 - 长按短按

Arduino - Button - Long Press Short Press Arduino - 按钮 - 长按短按 Arduino - Button - Long Press Short Press We will learn: 我们将学习: How to detect the button’s short press 如何检测按钮的短按How to detect the button’s long press 如何检测…

重大进展!微信支付收款码全场景接入银联网络

据中国银联6月19日消息,近日,银联网络迎来微信支付收款码场景的全面接入,推动条码支付互联互通取得新进展,为境内外广大消费者提供更多支付选择、更好支付体验。 2024年6月,伴随微信支付经营收款码的开放,微…

Rust: duckdb和polars读csv文件比较

一、文件准备 样本内容,N行9列的csv标准格式,有字符串,有浮点数,有整型。 有两个csv文件,一个大约是2.1万行;一个是64万行。 二、toml文件 [package] name "my_duckdb" version "0.1.0&…