Qwen 微调脚本分析 Qwen/finetune.py

Qwen 微调脚本分析 Qwen/finetune.py

Qwen/finetune.py :

# 基于fastchat和tatsu-lab/stanford_alpaca的修订代码,用于训练语言模型
# 提供使用LoRA(低秩适应)和量化(QLoRA)压缩的选项,以及使用DeepSpeed的分布式训练支持# 导入各种必要的库和模块
from dataclasses import dataclass, field
import json
import math
import logging
import os
from typing import Dict, Optional, List
import torch
from torch.utils.data import Dataset
from deepspeed import zero
from deepspeed.runtime.zero.partition_parameters import ZeroParamStatus
import transformers
from transformers import Trainer, GPTQConfig, deepspeed
from transformers.trainer_pt_utils import LabelSmoother
from peft import LoraConfig, get_peft_model, prepare_model_for_kbit_training
from accelerate.utils import DistributedType# 忽略标记ID,来自LabelSmoother的忽略索引
IGNORE_TOKEN_ID = LabelSmoother.ignore_index# 模型参数类,定义模型相关的参数
@dataclass
class ModelArguments:model_name_or_path: Optional[str] = field(default="Qwen/Qwen-7B")# 数据参数类,定义数据集相关的参数
@dataclass
class DataArguments:data_path: str = field(default=None, metadata={"help": "Path to the training data."})eval_data_path: str = field(default=None, metadata={"help": "Path to the evaluation data."})lazy_preprocess: bool = False# 训练参数类,继承自transformers的TrainingArguments,定义训练相关的参数
@dataclass
class TrainingArguments(transformers.TrainingArguments):cache_dir: Optional[str] = field(default=None)optim: str = field(default="adamw_torch")model_max_length: int = field(default=8192,metadata={"help": "Maximum sequence length. Sequences will be right padded (and possibly truncated)."},)use_lora: bool = False# Lora参数类,定义Lora模型压缩相关的参数
@dataclass
class LoraArguments:lora_r: int = 64lora_alpha: int = 16lora_dropout: float = 0.05lora_target_modules: List[str] = field(default_factory=lambda: ["c_attn", "c_proj", "w1", "w2"])lora_weight_path: str = ""lora_bias: str = "none"q_lora: bool = False# 将参数从DeepSpeed的Zero优化器转换为可用于其他操作的形式
def maybe_zero_3(param):if hasattr(param, "ds_id"):assert param.ds_status == ZeroParamStatus.NOT_AVAILABLE# 如果使用Zero-3,将参数转移到CPU并克隆with zero.GatheredParameters([param]):param = param.data.detach().cpu().clone()else:# 否则,直接将参数转移到CPU并克隆param = param.detach().cpu().clone()return param# 从模型中获取PEFT状态字典,可能涉及Zero优化器的转换
def get_peft_state_maybe_zero_3(named_params, bias):# 根据bias参数选择要返回的参数部分if bias == "none":to_return = {k: t for k, t in named_params if "lora_" in k}elif bias == "all":to_return = {k: t for k, t in named_params if "lora_" in k or "bias" in k}elif bias == "lora_only":to_return = {}maybe_lora_bias = {}lora_bias_names = set()for k, t in named_params:if "lora_" in k:to_return[k] = tbias_name = k.split("lora_")[0] + "bias"lora_bias_names.add(bias_name)elif "bias" in k:maybe_lora_bias[k] = tfor k, t in maybe_lora_bias:if bias_name in lora_bias_names:to_return[bias_name] = telse:raise NotImplementedError# 对选取的参数进行必要的转换to_return = {k: maybe_zero_3(v) for k, v in to_return.items()}return to_return# 获取本地排名,如果未设置则为None
local_rank = None# 在本地排名为0时打印信息
def rank0_print(*args):if local_rank == 0:print(*args)def safe_save_model_for_hf_trainer(trainer: transformers.Trainer, output_dir: str, bias="none"):"""根据transformers.Trainer实例的安全方式保存模型到指定目录。参数:- trainer: transformers.Trainer 实例,代表一个训练器,用于获取模型状态等。- output_dir: str,指定保存模型的输出目录。- bias: str,默认为"none",表示处理模型偏置的方式,本函数暂不支持更改偏置。说明:- 该函数会根据是否启用zero3模式、是否使用LORA,来决定如何获取模型的状态字典,并保存至磁盘。- 仅当trainer的args.should_save标志为True且args.local_rank为0时,才会真正保存模型。"""# 检查是否启用了zero3模式if deepspeed.is_deepspeed_zero3_enabled():state_dict = trainer.model_wrapped._zero3_consolidated_16bit_state_dict()else:# 根据是否使用LORA,选择不同的方式获取状态字典if trainer.args.use_lora:state_dict = get_peft_state_maybe_zero_3(trainer.model.named_parameters(), bias)else:state_dict = trainer.model.state_dict()# 条件满足时,执行模型保存操作if trainer.args.should_save and trainer.args.local_rank == 0:trainer._save(output_dir, state_dict=state_dict)def preprocess(sources,tokenizer: transformers.PreTrainedTokenizer,max_len: int,system_message: str = "You are a helpful assistant."
) -> Dict:"""对输入源数据进行预处理,包括编码、填充等,以便于输入到模型中进行训练或推理。参数:- sources: 输入源数据列表,每个源数据是一个字典列表,表示对话中的不同角色的句子。- tokenizer: transformers.PreTrainedTokenizer,用于对文本进行编码的分词器。- max_len: int,指定编码后序列的最大长度。- system_message: str,默认为"You are a helpful assistant.",表示系统消息。返回值:- 预处理后数据的字典,包括input_ids(编码后的输入序列)、labels(对应的标签序列)和attention_mask(注意力掩码)。说明:- 该函数首先定义了不同角色的起始标记,然后根据输入源数据,应用模板并进行编码,同时生成对应的标签序列。- 最后,对编码后的序列进行填充至最大长度,并返回预处理后的数据。"""# 定义不同角色的标记roles = {"user": "<|im_start|>user", "assistant": "<|im_start|>assistant"}# 获取tokenizer的开始和结束标记的IDim_start = tokenizer.im_start_idim_end = tokenizer.im_end_id# 将换行符转换为input_idsnl_tokens = tokenizer('\n').input_ids# 为不同的角色创建对应的标记序列_system = tokenizer('system').input_ids + nl_tokens_user = tokenizer('user').input_ids + nl_tokens_assistant = tokenizer('assistant').input_ids + nl_tokens# 初始化存储输入ID和目标ID的列表input_ids, targets = [], []# 遍历每个对话数据for i, source in enumerate(sources):# 如果对话的第一个消息不是用户的消息,则跳过它if roles[source[0]["from"]] != roles["user"]:source = source[1:]# 初始化当前对话的输入ID和目标ID列表input_id, target = [], []# 创建系统消息的输入ID序列,并添加到input_idsystem = [im_start] + _system + tokenizer(system_message).input_ids + [im_end] + nl_tokensinput_id += system# 创建系统消息的目标ID序列,并添加到target,忽略序列中的某些部分target += [im_start] + [IGNORE_TOKEN_ID] * (len(system)-3) + [im_end] + nl_tokens# 确保input_id和target的长度相同assert len(input_id) == len(target)# 遍历对话中的每个句子for j, sentence in enumerate(source):role = roles[sentence["from"]]# 创建当前句子的输入ID序列_input_id = tokenizer(role).input_ids + nl_tokens + \tokenizer(sentence["value"]).input_ids + [im_end] + nl_tokensinput_id += _input_id# 根据角色创建当前句子的目标ID序列,并添加到targetif role == '<|im_start|>user':_target = [im_start] + [IGNORE_TOKEN_ID] * (len(_input_id)-3) + [im_end] + nl_tokenselif role == '<|im_start|>assistant':_target = [im_start] + [IGNORE_TOKEN_ID] * len(tokenizer(role).input_ids) + \_input_id[len(tokenizer(role).input_ids)+1:-2] + [im_end] + nl_tokenselse:raise NotImplementedErrortarget += _target# 确保input_id和target的长度相同assert len(input_id) == len(target)# 如果input_id序列长度小于max_len,则用padding填充input_id += [tokenizer.pad_token_id] * (max_len - len(input_id))# 对应地,target序列也进行填充target += [IGNORE_TOKEN_ID] * (max_len - len(target))# 将当前对话的input_id和target添加到列表中input_ids.append(input_id[:max_len])targets.append(target[:max_len])# 将input_ids和targets列表转换为张量input_ids = torch.tensor(input_ids, dtype=torch.int)targets = torch.tensor(targets, dtype=torch.int)# 创建包含input_ids、labels和attention_mask的字典,并返回return dict(input_ids=input_ids,labels=targets,attention_mask=input_ids.ne(tokenizer.pad_token_id),  # 创建attention_mask,排除padding部分)class SupervisedDataset(Dataset):"""Dataset for supervised fine-tuning."""# 构造函数,初始化SupervisedDataset类的实例def __init__(self, raw_data, tokenizer: transformers.PreTrainedTokenizer, max_len: int):super(SupervisedDataset, self).__init__()  # 调用基类的构造函数# 在rank 0进程中打印格式化输入信息rank0_print("Formatting inputs...")# 从原始数据中提取对话数据sources = [example["conversations"] for example in raw_data]# 使用preprocess函数预处理对话数据data_dict = preprocess(sources, tokenizer, max_len)# 将预处理后的数据赋值给实例变量self.input_ids = data_dict["input_ids"]self.labels = data_dict["labels"]self.attention_mask = data_dict["attention_mask"]# 返回数据集中样本的数量def __len__(self):return len(self.input_ids)# 根据索引i获取数据集中的样本def __getitem__(self, i) -> Dict[str, torch.Tensor]:return dict(input_ids=self.input_ids[i],labels=self.labels[i],attention_mask=self.attention_mask[i],)# 定义一个函数,用于创建监督式微调的数据模块
def make_supervised_data_module(tokenizer: transformers.PreTrainedTokenizer,  # 用于分词的tokenizer对象data_args,  # 数据参数,包含数据路径和是否使用惰性加载等配置max_len,  # 输入序列的最大长度
) -> Dict:# 根据是否采用惰性加载选择数据集类dataset_cls = (LazySupervisedDataset if data_args.lazy_preprocess else SupervisedDataset)# 在rank 0进程中打印加载数据信息rank0_print("Loading data...")# 打开并加载训练数据的JSON文件train_json = json.load(open(data_args.data_path, "r"))# 创建训练数据集实例train_dataset = dataset_cls(train_json, tokenizer=tokenizer, max_len=max_len)# 如果提供了评估数据路径,则加载评估数据集if data_args.eval_data_path:eval_json = json.load(open(data_args.eval_data_path, "r"))eval_dataset = dataset_cls(eval_json, tokenizer=tokenizer, max_len=max_len)else:# 如果没有提供评估数据路径,则设置评估数据集为Noneeval_dataset = None# 返回包含训练数据集和评估数据集的字典return dict(train_dataset=train_dataset, eval_dataset=eval_dataset)这段代码定义了一个名为 `train` 的函数,它是启动模型训练流程的主函数。下面是对函数中每一行的逐行注释:```python
# 定义训练函数
def train():global local_rank  # 声明local_rank为全局变量,以便在函数内部修改其值# 创建一个参数解析器,用于解析命令行参数到数据类parser = transformers.HfArgumentParser((ModelArguments, DataArguments, TrainingArguments, LoraArguments))# 解析命令行参数到对应的数据类实例(model_args,data_args,training_args,lora_args,) = parser.parse_args_into_dataclasses()# 如果使用DeepSpeed并且是单GPU环境,设置分布式类型为DeepSpeedif getattr(training_args, 'deepspeed', None) and int(os.environ.get("WORLD_SIZE", 1))==1:training_args.distributed_state.distributed_type = DistributedType.DEEPSPEED# 设置local_rank为训练参数中的local_rank值local_rank = training_args.local_rank# 初始化设备映射和世界大小device_map = Noneworld_size = int(os.environ.get("WORLD_SIZE", 1))ddp = world_size != 1# 如果使用QLoRA并且是分布式数据并行(DDP)环境,设置设备映射if lora_args.q_lora:device_map = {"": int(os.environ.get("LOCAL_RANK") or 0)} if ddp else "auto"# 如果同时使用FSDP或ZeRO-3和QLoRA,输出警告信息if len(training_args.fsdp) > 0 or deepspeed.is_deepspeed_zero3_enabled():logging.warning("FSDP or ZeRO3 are incompatible with QLoRA.")# 检查模型是否为聊天模型is_chat_model = 'chat' in model_args.model_name_or_path.lower()# 如果使用LoRA并且启用了ZeRO-3,但不是聊天模型,抛出错误if (training_args.use_loraand not lora_args.q_loraand deepspeed.is_deepspeed_zero3_enabled()and not is_chat_model):raise RuntimeError("ZeRO3 is incompatible with LoRA when finetuning on base model.")# 设置模型加载参数model_load_kwargs = {'low_cpu_mem_usage': not deepspeed.is_deepspeed_zero3_enabled(),}# 加载模型配置config = transformers.AutoConfig.from_pretrained(model_args.model_name_or_path,cache_dir=training_args.cache_dir,trust_remote_code=True,)# 禁用缓存config.use_cache = False# 加载模型和分词器model = transformers.AutoModelForCausalLM.from_pretrained(model_args.model_name_or_path,config=config,cache_dir=training_args.cache_dir,device_map=device_map,trust_remote_code=True,# 如果使用LoRA和QLoRA,设置量化配置quantization_config=GPTQConfig(bits=4, disable_exllama=True)if training_args.use_lora and lora_args.q_loraelse None,**model_load_kwargs,)tokenizer = transformers.AutoTokenizer.from_pretrained(model_args.model_name_or_path,cache_dir=training_args.cache_dir,model_max_length=training_args.model_max_length,padding_side="right",use_fast=False,trust_remote_code=True,)# 设置分词器的填充标记IDtokenizer.pad_token_id = tokenizer.eod_id# 如果使用LoRAif training_args.use_lora:if lora_args.q_lora or is_chat_model:modules_to_save = Noneelse:modules_to_save = ["wte", "lm_head"]# 创建LoRA配置lora_config = LoraConfig(r=lora_args.lora_r,lora_alpha=lora_args.lora_alpha,target_modules=lora_args.lora_target_modules,lora_dropout=lora_args.lora_dropout,bias=lora_args.lora_bias,task_type="CAUSAL_LM",modules_to_save=modules_to_save  # 用于添加新标记的参数)# 如果使用QLoRA,准备模型进行k-bit训练if lora_args.q_lora:model = prepare_model_for_kbit_training(model, use_gradient_checkpointing=training_args.gradient_checkpointing)# 应用LoRA配置到模型model = get_peft_model(model, lora_config)# 打印LoRA可训练参数model.print_trainable_parameters()# 如果启用梯度检查点,启用模型的输入梯度if training_args.gradient_checkpointing:model.enable_input_require_grads()# 加载数据模块data_module = make_supervised_data_module(tokenizer=tokenizer, data_args=data_args, max_len=training_args.model_max_length)# 创建训练器trainer = Trainer(model=model, tokenizer=tokenizer, args=training_args, **data_module)# 开始训练trainer.train()# 保存训练状态trainer.save_state()# 安全地保存模型到硬盘safe_save_model_for_hf_trainer(trainer=trainer, output_dir=training_args.output_dir, bias=lora_args.lora_bias)if __name__ == "__main__":train()

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.mzph.cn/pingmian/20119.shtml

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

个人关于ChatGPT的用法及建议

概述 这里只是个人常用的几个软件&#xff0c;做一下汇总&#xff0c;希望对各位有用。 如果有更高认知的朋友&#xff0c;请留下你的工具名称&#xff0c;提醒我一下&#xff0c;谢谢&#xff5e; 常用的chatgpt模型工具&#xff1a; 以下是一些知名的例子&#xff1a; 文…

代码随想录算法训练营Day55 | 647. 回文子串 516.最长回文子序列 动态规划总结篇

代码随想录算法训练营Day55 | 647. 回文子串 516.最长回文子序列 动态规划总结篇 LeetCode 647. 回文子串 题目链接&#xff1a;LeetCode 647. 回文子串 思路&#xff1a; class Solution { public:int countSubstrings(string s) {vector<vector<bool>> dp(s.…

AI学习指南机器学习篇-多元线性回归

AI学习指南机器学习篇-多元线性回归 在机器学习领域&#xff0c;多元线性回归是一种用于建立自变量和因变量之间关系的模型。在这篇博客中&#xff0c;我们将讨论多元线性回归模型的引入以及它对多个自变量对因变量的影响。我们还将讨论多元线性回归与简单线性回归的区别和应用…

江协科技STM32学习-1 购买24Mhz采样逻辑分析仪

前言&#xff1a; 本文是根据哔哩哔哩网站上“江协科技STM32”视频的学习笔记&#xff0c;在这里会记录下江协科技STM32开发板的配套视频教程所作的实验和学习笔记内容。本文大量引用了江协科技STM32教学视频和链接中的内容。 引用&#xff1a; STM32入门教程-2023版 细致讲…

CVPR2024 合成异常数据 工业异常检测 RealNet

前言 本文分享一个基于扩散模型的异常检测框架&#xff0c;用于检测工业场景的缺陷检测或异常检测。 强度可控扩散异常合成&#xff1a;基于扩散过程的合成策略&#xff0c;能够生成不同强度的异常样本&#xff0c;模仿真实异常样本的分布。异常感知特征选择&#xff1a;选择…

学习Java,stringbuilder用法

有sb.append添加元素&#xff0c;sb.reverse反转内容&#xff0c;sb.tostring转换成字符串&#xff0c;sb.length计算长度。

东莞酷得智能 组装机械狗电子玩具方案

这款机械狗玩具电子方案结合了现代电子技术和人工智能元素&#xff0c;旨在为用户提供一个高科技、互动性强的娱乐体验。通过不断的软件更新和硬件迭代&#xff0c;机械狗的功能将持续扩展。 一、功能特点&#xff1a; 1、自动巡游&#xff1a;机械狗能够自主在房间内巡游&am…

分库分表、读写分离--ShardingJDBC

1. 项目准备 1.1 建立数据库表 建立user_manage数据库&#xff0c;在该库中建立1张表app_user用来做分库前的测试&#xff0c;另外建12张按月份命名的表app_user_2024XX用来做分库。 CREATE DATABASE IF NOT EXISTS user_manage CHARACTER SET utf8 COLLATE utf8_general_ci…

Python中的__str__和__repr__:揭示字符串表示的奥秘

标题&#xff1a;Python中的__str__和__repr__&#xff1a;揭示字符串表示的奥秘 摘要 在Python中&#xff0c;对象的字符串表示对于调试和日志记录至关重要。__str__和__repr__是两个特殊的方法&#xff0c;用于定义对象的字符串表示形式。尽管它们在功能上相似&#xff0c;…

vm-bhyve网卡设定桥接故障解决@FreeBSD

问题 在使用vm-bhyve虚拟机管理软件的时候&#xff0c;使用vm无法绑定网卡igb0 vm switch add public igb0 报错&#xff1a;/usr/local/sbin/vm: ERROR: failed to add member igb0 to the virtual switch public 解决 于是准备用原生ifconfig命令来绑定&#xff0c;结果…

【Go基础】快速入门

Go基础入门 用20%的时间学习常用80%的语法 官方网址&#xff08;下载安装/官方文档/官方类库&#xff09; Download Go binaries from https://go.dev/dl/Reference the official Go documentation https://go.dev/doc/See all the the Go packages https://pkg.go.dev/Access…

Linux基础指令及其作用之网络操作

网络操作pingifconfigeth0 接口 ip常用选项和命令 netstat示例输出解释 curl示例输出及解释 wget示例输出解释 网络操作 ping ping 命令用于测试网络连接的连通性和响应时间。它通过向目标主机发送 ICMP 回显请求&#xff08;echo request&#xff09;数据包&#xff0c;并等…

wpf 依赖属性的含义理解

依赖属性允许没有自己的字段&#xff0c;可以通过Binding绑定到其它对象的属性或者说数据源上&#xff0c;从而获得值。 缘由 由于控件有很多的属性&#xff0c;有属性就有字段的内存开销&#xff0c;但实际上对于一个控件&#xff0c;我们大多数只会使用其部分常用属性&#…

ConvNeXt(CVPR 2022)论文解读

paper&#xff1a;A ConvNet for the 2020s official implementation&#xff1a;https://github.com/facebookresearch/ConvNeXt third-party implementation&#xff1a;https://github.com/huggingface/pytorch-image-models/blob/main/timm/models/convnext.py 背景 在…

代码随想录-算法训练营day55【动态规划16:两个字符串的删除操作、编辑距离、编辑距离总结篇】

代码随想录-035期-算法训练营【博客笔记汇总表】-CSDN博客 第九章 动态规划part16● 583. 两个字符串的删除操作 ● 72. 编辑距离 ● 编辑距离总结篇 详细布置 583. 两个字符串的删除操作 本题和动态规划&#xff1a;115.不同的子序列 相比&#xff0c;其实就是两个字符串都…

流量回放平台与传统测试工具的对比分析

文章目录 一、流量回放平台的优势与挑战二、传统测试工具的优势与挑战三、实际案例演示四、解决方案五、答疑解惑5.1、传统工具不是也可以做到流量会放平台的无侵入性测试和性能瓶颈分析吗&#xff1f;5.2、开发流量回放平台的成本和使用传统测试工具的成本哪个更大&#xff1f…

基于SSM框架的垃圾分类系统的设计与实现(含源码+sql+开题报告+论文+论文答辩模板)

图1 前台首页截图 首页展示&#xff1a;首页展示法律法规、公示公告、用户交流论坛、分类指南、垃圾站点、以及个人中心&#xff1b; 法律法规&#xff1a;展示我国《城市生活垃圾分类及其评价标准》以及《生活垃圾分类标志》等最新法律法规&#xff1b; 公示公告&#xff1…

另一棵树的子树(oj题)

一、题目链接 https://leetcode.cn/problems/subtree-of-another-tree/submissions/536304222 二、题目思路 1.首先遍历大树&#xff0c;判断大树的根结点的值是否等于小树的根结点的值&#xff0c;如果不相等&#xff0c;就找大树的左孩子或者右孩子&#xff0c;以左孩子为根…

【线性表 - 数组和矩阵】

数组是一种连续存储线性结构&#xff0c;元素类型相同&#xff0c;大小相等&#xff0c;数组是多维的&#xff0c;通过使用整型索引值来访问他们的元素&#xff0c;数组尺寸不能改变。 知识点数组与矩阵相关题目 # 知识点 数组的优点: 存取速度快 数组的缺点: 事先必须知道…

php 实现:给图片加文字水印,图片水印,压缩图片

演示环境&#xff1a; 1、windows10 2、phpstudy 3、php7.4 一、案例演示&#xff1a; 二、素材准备 1、准备一张原始图片 2、准备一张水印图片&#xff08;透明底图的最好&#xff09; 3、字体库&#xff08;windows系统自带的字体库&#xff0c;路径在&#xff1a;C:\Window…