Unsloth 微调 Llama 3

本文参考:
https://colab.research.google.com/drive/135ced7oHytdxu3N2DNe1Z0kqjyYIkDXp
改编自:https://blog.csdn.net/qq_38628046/article/details/138906504


文章目录

    • 一、项目说明
      • 安装相关依赖
      • 下载模型和数据
    • 二、训练
      • 1、加载 model、tokenizer
      • 2、设置LoRA训练参数
      • 3、准备数据集
        • 数据格式处理
        • 加载数据集并进行映射处理操作
      • 4、训练超参数配置
        • SFTTrainer
        • 显示当前内存状态
      • 5、执行训练
      • 6、模型推理
      • 7、保存LoRA模型
      • 8、加载模型
      • 9、执行推理
      • 10、保存完整模型
      • 11、保存为GGUF格式


一、项目说明

Llama-3-Chinese-Instruct 是基于Meta Llama-3的中文开源大模型,其在原版Llama-3的基础上使用了大规模中文数据进行增量预训练,并且使用精选指令数据进行精调,进一步提升了中文基础语义和指令理解能力,相比二代相关模型获得了显著性能提升。

GitHub:https://github.com/ymcui/Chinese-LLaMA-Alpaca-3


安装相关依赖

unsloth 根据不同改的 cuda 版本有不同的安装方式,详见:https://blog.csdn.net/lovechris00/article/details/140404957

pip install --no-deps "xformers<0.0.26" trl peft accelerate bitsandbytes

下载模型和数据

Unsloth 支持很多模型: https://huggingface.co/unsloth,包括 mistral,llama,gemma

这里我们使用 FlagAlpha/Llama3-Chinese-8B-Instruct 模型 和 kigner/ruozhiba-llama3 数据集

提前下载:

export HF_ENDPOINT=https://hf-mirror.comhuggingface-cli download FlagAlpha/Llama3-Chinese-8B-Instruct
uggingface-cli download --repo-type dataset kigner/ruozhiba-llama3

数据将保存到 ~/.cache/huggingface/hub


你也可以使用 modelscope下载,如:

from modelscope import snapshot_downloadmodel_dir = snapshot_download('FlagAlpha/Llama3-Chinese-8B-Instruct',cache_dir="/root/models")

安装 modelscope

pip install modelscope 

二、训练

1、加载 model、tokenizer

from unsloth import FastLanguageModel
import torchmodel, tokenizer = FastLanguageModel.from_pretrained(model_name = "/root/models/Llama3-Chinese-8B-Instruct", # 模型路径max_seq_length = 2048, # 可以设置为任何值内部做了自适应处理# dtype = torch.float16, # 数据类型使用float16dtype = None,  # 会自动推断类型load_in_4bit = True, # 使用4bit量化来减少内存使用

2、设置LoRA训练参数

model = FastLanguageModel.get_peft_model(model,r = 16, # 选择任何大于0的数字!建议使用8、16、32、64、128target_modules = ["q_proj", "k_proj", "v_proj", "o_proj","gate_proj", "up_proj", "down_proj",],lora_alpha = 16,lora_dropout = 0,  # 支持任何值,但等于0时经过优化bias = "none",    # 支持任何值,但等于"none"时经过优化# [NEW] "unsloth" 使用的VRAM减少30%,适用于2倍更大的批处理大小!use_gradient_checkpointing = "unsloth", # True或"unsloth"适用于非常长的上下文random_state = 3407,use_rslora = False,  # 支持排名稳定的LoRAloftq_config = None, # 和LoftQ

3、准备数据集

准备数据集其实就是指令集构建,LLM的微调一般指指令微调过程。所谓指令微调,就是使用指定的微调数据格式、形式。
训练目标是让模型具有理解并遵循用户指令的能力。因此在指令集构建时,应该针对目标任务,针对性的构建任务指令集。
这里使用 alpaca 格式的数据集,格式形式如下:

[{"instruction": "用户指令(必填)","input": "用户输入(选填)","output": "模型回答(必填)",},"system": "系统提示词(选填)","history": [["第一轮指令(选填)", "第一轮回答(选填)"],["第二轮指令(选填)", "第二轮回答(选填)"]]
]

  • instruction:用户指令,要求AI执行的任务或问题
  • input:用户输入,是完成用户指令所必须的输入内容,就是执行指令所需的具体信息或上下文
  • output:模型回答,根据给定的指令和输入生成答案

这里根据企业私有文档数据,生成相关格式的训练数据集,大概格式如下:

[{"instruction": "内退条件是什么?","input": "","output": "内退条件包括与公司签订正式劳动合同并连续工作满20年及以上,以及距离法定退休年龄不足5年。特殊工种符合国家相关规定可提前退休的也可在退休前5年内提出内退申请。"},
]

数据格式处理

定义对数据处理的函数方法

alpaca_prompt = """下面是一项描述任务的说明,配有提供进一步背景信息的输入。写出一个适当完成请求的回应。### Instruction:
{}### Input:
{}### Response:
{}"""EOS_TOKEN = tokenizer.eos_token # Must add EOS_TOKEN
def formatting_prompts_func(examples):instructions = examples["instruction"]inputs       = examples["input"]outputs      = examples["output"]texts = []for instruction, input, output in zip(instructions, inputs, outputs):# Must add EOS_TOKEN, otherwise your generation will go on forever!text = alpaca_prompt.format(instruction, input, output) + EOS_TOKENtexts.append(text)return { "text" : texts, }

加载数据集并进行映射处理操作
from datasets import load_dataset
dataset = load_dataset("kigner/ruozhiba-llama3", split = "train")
dataset = dataset.map(formatting_prompts_func, batched = True,)print(dataset[0])

经处理后的一条数据格式如下:

{'output': '输出内容','input': '','instruction': '指令内容','text': '下面是一项描述任务的说明,配有提供进一步背景信息的输入。写出一个适当完成请求的回应。\n\n### Instruction:\n指令内容?\n\n### Input:\n\n\n### Response:\n输出内容。<|end_of_text|>'
}

4、训练超参数配置

from transformers import TrainingArguments
from trl import SFTTrainertraining_args  = TrainingArguments(output_dir = "models/lora/llama", # 输出目录per_device_train_batch_size = 2, # 每个设备的训练批量大小gradient_accumulation_steps = 4, # 梯度累积步数warmup_steps = 5,max_steps = 60, # 最大训练步数,测试时设置# num_train_epochs= 5, # 训练轮数   logging_steps = 10,  # 日志记录频率save_strategy = "steps", # 模型保存策略save_steps = 100, # 模型保存步数learning_rate = 2e-4, # 学习率fp16 = not torch.cuda.is_bf16_supported(), # 是否使用float16训练bf16 = torch.cuda.is_bf16_supported(), # 是否使用bfloat16训练optim = "adamw_8bit",  # 优化器weight_decay = 0.01,  # 正则化技术,通过在损失函数中添加一个正则化项来减小权重的大小lr_scheduler_type = "linear",  # 学习率衰减策略seed = 3407, # 随机种子)

SFTTrainer
trainer = SFTTrainer(model=model, # 模型tokenizer=tokenizer, # 分词器args=training_args, # 训练参数train_dataset=dataset, # 训练数据集dataset_text_field="text", # 数据集文本字段名称max_seq_length=2048, # 最大序列长度dataset_num_proc=2, # 数据集处理进程数packing=False, # 可以让短序列的训练速度提高5倍
)

显示当前内存状态
# 当前GPU信息
gpu_stats = torch.cuda.get_device_properties(0)
# 当前模型内存占用
start_gpu_memory = round(torch.cuda.max_memory_reserved() / 1024 / 1024 / 1024, 3)
# GPU最大内存
max_memory = round(gpu_stats.total_memory / 1024 / 1024 / 1024, 3)print(f"GPU = {gpu_stats.name}. Max memory = {max_memory} GB.")
print(f"{start_gpu_memory} GB of memory reserved.")

可以看出当前模型占用5.633G显存


5、执行训练

trainer_stats = trainer.train()

显示最终内存和时间统计数据

# 计算总的GPU使用内存(单位:GB)
used_memory = round(torch.cuda.max_memory_reserved() / 1024 / 1024 / 1024, 3)
# 计算LoRA模型使用的GPU内存(单位:GB)
used_memory_for_lora = round(used_memory - start_gpu_memory, 3)
# 计算总的GPU内存使用百分比
used_percentage = round(used_memory / max_memory * 100, 3)
# 计算LoRA模型的GPU内存使用百分比
lora_percentage = round(used_memory_for_lora / max_memory * 100, 3)print(f"{trainer_stats.metrics['train_runtime']} seconds used for training.")
print(f"{round(trainer_stats.metrics['train_runtime'] / 60, 2)} minutes used for training.")
print(f"Peak reserved memory = {used_memory} GB.")
print(f"Peak reserved memory for training = {used_memory_for_lora} GB.")
print(f"Peak reserved memory % of max memory = {used_percentage} %.")
print(f"Peak reserved memory for training % of max memory = {lora_percentage} %.")

可以看出模型训练时显存增加了0.732G


6、模型推理

FastLanguageModel.for_inference(model) # 启用原生推理速度快2倍
inputs = tokenizer(
[alpaca_prompt.format("内退条件是什么?", # instruction"", # input"", # output)
], return_tensors = "pt").to("cuda")outputs = model.generate(**inputs, max_new_tokens = 64, use_cache = True)
tokenizer.batch_decode(outputs)

可以看出模型回答跟训练数据集中的数据意思基本一致。


7、保存LoRA模型

注意:这仅保存 LoRA 适配器,而不是完整模型

lora_model = '/home/username/models/lora/llama0715/llama_lora'
model.save_pretrained(lora_model)
# adapter_config.json  adapter_model.safetensors  README.mdtokenizer.save_pretrained(lora_model)
# tokenizer_config.json  special_tokens_map.json  tokenizer.json# 保存到huggingface
# model.push_to_hub("your_name/lora_model", token = "...")
# tokenizer.push_to_hub("your_name/lora_model", token = "...")

adapter_config.json 内容如下:

{"alpha_pattern": {},"auto_mapping": null,"base_model_name_or_path": "FlagAlpha/Llama3-Chinese-8B-Instruct","bias": "none","fan_in_fan_out": false,"inference_mode": true,"init_lora_weights": true,"layer_replication": null,"layers_pattern": null,"layers_to_transform": null,"loftq_config": {},"lora_alpha": 16,"lora_dropout": 0,"megatron_config": null,"megatron_core": "megatron.core","modules_to_save": null,"peft_type": "LORA","r": 16,"rank_pattern": {},"revision": "unsloth","target_modules": ["gate_proj","k_proj","up_proj","q_proj","o_proj","v_proj","down_proj"],"task_type": "CAUSAL_LM","use_dora": false,"use_rslora": false
}

8、加载模型

注意:从新加载模型将额外占用显存,若GPU显存不足,需关闭、清除先前加载、训练模型的内存占用
加载刚保存的LoRA适配器用于推断,他将自动加载整个模型及LoRA适配器。adapter_config.json定义了完整模型的路径。

import torch
from unsloth import FastLanguageModelmodel, tokenizer = FastLanguageModel.from_pretrained(model_name = "models/llama_lora",max_seq_length = 2048,dtype = torch.float16,load_in_4bit = True,
)FastLanguageModel.for_inference(model)

9、执行推理

outputs = model.generate(**inputs, max_new_tokens = 64, use_cache = True)
tokenizer.batch_decode(outputs)

10、保存完整模型

# 合并到16bit 保存到本地 OR huggingface
model.save_pretrained_merged("models/Llama3", tokenizer, save_method = "merged_16bit",)
# model.push_to_hub_merged("hf/model", tokenizer, save_method = "merged_16bit", token = "")# 合并到4bit 保存到本地 OR huggingface
model.save_pretrained_merged("models/Llama3", tokenizer, save_method = "merged_4bit",)
# model.push_to_hub_merged("hf/model", tokenizer, save_method = "merged_4bit", token = "")

11、保存为GGUF格式

将模型保存为GGUF格式

# 保存到 16bit GGUF 体积大
model.save_pretrained_gguf("model", tokenizer, quantization_method = "f16")
model.push_to_hub_gguf("hf/model", tokenizer, quantization_method = "f16", token = "")# 保存到 8bit Q8_0 体积适中
model.save_pretrained_gguf("model", tokenizer,)
model.push_to_hub_gguf("hf/model", tokenizer, token = "")# 保存到 q4_k_m GGUF 体积小
model.save_pretrained_gguf("model", tokenizer, quantization_method = "q4_k_m")
model.push_to_hub_gguf("hf/model", tokenizer, quantization_method = "q4_k_m", token = "")

2024-07-15(一)

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.mzph.cn/news/872557.shtml

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

从零开始接触人工智能大模型,该如何学习?

人工智能是计算机科学领域中最具前瞻性和影响力的技术之一。它是一种智慧型算法&#xff0c;能够模拟人类的思维过程&#xff0c;处理大量的数据和信息&#xff0c;从而发现隐藏在其中的规律和趋势。人工智能的应用范围非常广泛&#xff0c;包括语音识别、图像识别、自然语言处…

高精度减法(C++)

【题目描述】 求两个大的正整数相减的差。 【输入】 共2行&#xff0c;第1行是被减数a&#xff0c;第2行是减数b。每个大整数不超过200位&#xff0c;不会有多余的前导零。注意&#xff0c;a 可能小于 b。 【输出】 一行&#xff0c;即所求的差。 【输入样例】 99999999…

《简历宝典》14 - 简历中“项目经历”,实战讲解,前端篇

上一节我们针对项目经历做了内功式的讲解&#xff0c;为了加深读者的印象&#xff0c;可以更轻松的套用到自己的简历上&#xff0c;本章继续从前端开发、Java开发以及软件测试的三个角度&#xff0c;再以校招和初级、中级以及高级三个维度分别入手&#xff0c;以实战讲解的形式…

gihub导入gitee仓库实现仓库同步

昨天在GitHub里导入了gitee仓库&#xff0c;但是在仓库同步这里卡了很久&#xff0c;因为网上大多数都是从github导入gitee&#xff0c;然后github生成token放入实现同步&#xff0c;但是我找到一种更为方便的&#xff01; 1.首先找到项目文件下的.git文件里的config文件 2.在…

Python实战MySQL之数据库操作全流程详解

概要 MySQL是一种广泛使用的关系型数据库管理系统,Python可以通过多种方式与MySQL进行交互。本文将详细介绍如何使用Python操作MySQL数据库,包括安装必要的库、连接数据库、执行基本的CRUD(创建、读取、更新、删除)操作,并包含具体的示例代码,帮助全面掌握这一过程。 准…

Vue 和 React 框架实现滚动缓冲区

Vue 实现 <template><div id"app" scroll"handleScroll"><!-- 页面内容 --><div v-for"item in items" :key"item">{{ item }}</div></div> </template><script> export default {d…

dom4j 操作 xml 之按照顺序插入标签

最近学了一下 dom4j 操作 xml 文件&#xff0c;特此记录一下。 public class Dom4jNullTagFiller {public static void main(String[] args) throws DocumentException {SAXReader reader new SAXReader();//加载 xml 文件Document document reader.read("C:\\Users\\24…

基于jeecgboot-vue3的Flowable流程支持bpmn流程设计器与仿钉钉流程设计器-编辑多版本处理

因为这个项目license问题无法开源&#xff0c;更多技术支持与服务请加入我的知识星球。 1、前端编辑带有仿钉钉流程的处理 /** 编辑流程设计弹窗页面 */const handleLoadXml (row) > {console.log("handleLoadXml row",row)const params {flowKey: row.key,ver…

搜集日志。

logstash 负责&#xff1a; 接收数据 input — 解析过滤并转换数据 filter(此插件可选) — 输出数据 output input — decode — filter — encode — output elasticsearch 查询和保存数据 Elasticsearch 去中心化集群 Data node 消耗大量 CPU、内存和 I/O 资源 分担一部分…

四、GD32 MCU 常见外设介绍

系统架构 1.RCU 时钟介绍 众所周知&#xff0c;时钟是MCU能正常运行的基本条件&#xff0c;就好比心跳或脉搏&#xff0c;为所有的工作单元提供时间 基数。时钟控制单元提供了一系列频率的时钟功能&#xff0c;包括多个内部RC振荡器时钟(IRC)、一个外部 高速晶体振荡器时钟(H…

Docker修改Postgresql密码

在Docker环境中&#xff0c;对已运行的PostgreSQL数据库实例进行密码更改是一项常见的维护操作。下面将详述如何通过一系列命令行操作来实现这一目标。 修改方式 查看容器状态及信息 我们需要定位到正在运行的PostgreSQL容器以获取其相关信息。执行以下命令列出所有正在运行…

Mongodb多键索引中索引边界的混合

学习mongodb&#xff0c;体会mongodb的每一个使用细节&#xff0c;欢迎阅读威赞的文章。这是威赞发布的第93篇mongodb技术文章&#xff0c;欢迎浏览本专栏威赞发布的其他文章。如果您认为我的文章对您有帮助或者解决您的问题&#xff0c;欢迎在文章下面点个赞&#xff0c;或者关…

安全防御---防火墙双击热备与带宽管理

目录 一、实验拓扑 二、实验需求 三、实验的大致思路 四、实验过程 4、基础配置 4.1 FW4的接口信息 4.2 新建办公&#xff0c;生产&#xff0c;游客&#xff0c;电信&#xff0c;移动安全区域 4.3 接口的网络配置 生产区:10.0.1.2/24 办公区:10.0.2.2/24 4.4 FW4的…

极地生产力自主采样系统的观测:融池比例统计 MEDEA 融池比例数据集

Observations from the Autonomous Polar Productivity Sampling System. 极地生产力自主采样系统的观测结果 简介 该项目是美国国家航空航天局 ICESCAPE 大型项目的一部分&#xff0c;旨在研究浮游植物丰度的长期季节性变化与整个生长季节在波弗特海和楚科奇海测量到的海冰…

Spring与设计模式实战之策略模式

Spring与设计模式实战之策略模式 引言 在现代软件开发中&#xff0c;设计模式是解决常见设计问题的有效工具。它们提供了经过验证的解决方案&#xff0c;帮助开发人员构建灵活、可扩展和可维护的系统。本文将探讨策略模式在Spring框架中的应用&#xff0c;并通过实际例子展示…

Linux 驱动开发 举例

Linux驱动开发涉及编写内核模块或设备驱动程序&#xff0c;以便让Linux内核能够识别和控制硬件设备。以下是一个简单的Linux驱动开发示例&#xff0c;这个示例将展示如何创建一个简单的字符设备驱动。 示例&#xff1a;简单的字符设备驱动 1. 定义设备驱动结构 首先&#xf…

深度学习损失计算

文章目录 深度学习损失计算1.如何计算当前epoch的损失&#xff1f;2.为什么要计算样本平均损失&#xff0c;而不是计算批次平均损失&#xff1f; 深度学习损失计算 1.如何计算当前epoch的损失&#xff1f; 深度学习中的损失计算&#xff0c;通常为数据集的平均损失&#xff0…

CREC晶振产品分类

CREC晶振大类有石英晶体谐振器、石英晶体振荡器、石英晶体滤波器 其中石英晶体谐振器&#xff1a; KHZ石英谐振器 车规级32.768KHz石英谐振器 专为汽车RTC应用而设计&#xff0c;通过AECQ-200可靠性测试&#xff0c;满足汽车电子的高标准时频需求&#xff0c;为客户提供可靠…

完整的优化流程需要做什么工作

&#x1f47d;System.out.println(“&#x1f44b;&#x1f3fc;嗨&#xff0c;大家好&#xff0c;我是代码不会敲的小符&#xff0c;目前工作于上海某电商服务公司…”); &#x1f4da;System.out.println(“&#x1f388;如果文章中有错误的地方&#xff0c;恳请大家指正&…

三生随记——空调的诅咒

在一个炎热的夏日&#xff0c;小镇上的居民们都在忍受着高温的煎熬。阳光无情地炙烤着大地&#xff0c;空气仿佛凝固了一般&#xff0c;让人喘不过气来。 杰克和艾米是一对年轻的夫妻&#xff0c;他们刚刚搬进了这座小镇边缘的一座古老房子。这座房子虽然宽敞&#xff0c;但却透…