MoE模型性能还能更上一层楼?一次QLoRA微调实践

Fine-Tuning Mixtral 8x7B with QLoRA:Enhancing Model Performance 🚀

编者按:最近,混合专家(Mixture of Experts,MoE)这种模型设计策略展现出了卓越的语言理解能力,如何在此基础上进一步提升 MoE 模型的性能成为业界热点。

本文作者使用一种名为 QLoRA 的方法,通过量化和 LoRA 技术对 MoE 模型 Mixtral-8x7B 进行微调,以期大幅提高其性能。

作者详细阐明这种方法的诸多优势,包括显著增强 MoE 模型的理解生成能力、计算效率更高等。文中还逐步介绍了使用 QLoRA 微调 Mixtral-8x7B 的全过程。

本文探索了使用 QLoRA 推动 MoE 模型的性能改进这一技术方案。期待未来更多关于 MoE 模型的性能改进方案出现!

一、简介

目前整个业界都希望经过优化的模型能够表现出卓越的性能,这一追求不断推动着自然语言理解(natural language understanding)的发展。Mixtral-8x7B Mixture of Experts(MoE)模型就是其中之一,该模型在各种基准测试(benchmarks)中表现出优于同类产品的性能,尤其是优于 Llama 2 70B。

本教程采用一种名为 QLoRA 的创新方法对 Mixtral-8x7B 模型进行微调,该方法结合了量化(quantization)和 LoRA(Local Representation Adaptation)技术。期望通过这两种技术的结合来进一步增强Mixtral-8x7B模型的能力。

image.png

Source: Mixtral[1]

二、相关定义

● Mixtral 8x7B:一种混合专家模型,因其架构设计在自然语言处理任务中表现出色而闻名。

● QLoRA:Quantization 和 LoRA 技术相结合的缩写。量化涉及降低模型权重的精度,从而优化内存使用并加快计算速度。LoRA 可调整模型中的局部表征,增强模型对特定上下文的理解。

三、优势

● 增强性能:使用 QLoRA 对 Mixtral 8x7B 进行微调,可提高其性能,从而更好地理解和生成各种领域的文本。

● 能效比高:量化的整合降低了内存需求和计算复杂度,使模型更节省资源。

● 针对垂直领域进行微调:通过微调,该模型可针对特定任务进行定制,从而提高其在特定领域的准确性和相关性。

四、代码实现说明

本教程在 Notebook 环境中(译者注:使用Jupyter notebook 或白海IDP自研notebook)使用 Python。整个过程包括使用 "bitsandbytes "库加载 4 位精度的大型 Mixtral 模型。随后,在训练阶段使用 Hugging Face 的 PEFT 库实现 LoRA。

4.1 步骤 1:安装相关库

# You only need to run this once per machine, even if you stop/restart it
!pip install --upgrade pip
!pip install -q -U bitsandbytes
!pip install -q -U git+https://github.com/huggingface/transformers.git
!pip install -q -U git+https://github.com/huggingface/peft.git
!pip install -q -U git+https://github.com/huggingface/accelerate.git
!pip install -q -U datasets scipy ipywidgets matplotlib

4.2 步骤 2:设置 Accelerator

from accelerate import FullyShardedDataParallelPlugin, Accelerator
from torch.distributed.fsdp.fully_sharded_data_parallel import FullOptimStateDictConfig, FullStateDictConfigfsdp_plugin = FullyShardedDataParallelPlugin(state_dict_config=FullStateDictConfig(offload_to_cpu=True, rank0_only=False),optim_state_dict_config=FullOptimStateDictConfig(offload_to_cpu=True, rank0_only=False),
)accelerator = Accelerator(fsdp_plugin=fsdp_plugin)

4.3 步骤 3:使用Weights & Biases追踪性能指标

!pip install -q wandb -Uimport wandb, os
wandb.login()wandb_project = "viggo-finetune"
if len(wandb_project) > 0:os.environ["WANDB_PROJECT"] = wandb_project

4.4 步骤 4:加载数据集

from datasets import load_datasetdataset_name = "databricks/databricks-dolly-15k"train_dataset = load_dataset(dataset_name, split="train[0:800]")
eval_dataset = load_dataset(dataset_name, split="train[800:1000]")

4.5 步骤 5:加载基础模型

import torch
from transformers import AutoTokenizer, AutoModelForCausalLM, BitsAndBytesConfigbase_model_id = "mistralai/Mixtral-8x7B-v0.1"
bnb_config = BitsAndBytesConfig(load_in_4bit=True,bnb_4bit_use_double_quant=True,bnb_4bit_compute_dtype=torch.bfloat16
)model = AutoModelForCausalLM.from_pretrained(base_model_id, quantization_config=bnb_config, device_map="auto")# Tokenization 
tokenizer = AutoTokenizer.from_pretrained(base_model_id,padding_side="left",add_eos_token=True,add_bos_token=True,
)
tokenizer.pad_token = tokenizer.eos_tokendef tokenize(prompt):result = tokenizer(prompt)result["labels"] = result["input_ids"].copy()return resultdef generate_and_tokenize_prompt(data_point):full_prompt = f"""Given a question and some additional context, provide an answer### Target sentence:Question: {data_point['instruction']}Additional Context: {f"Here is some context: {data_point['context']}" if len(data_point["context"]) > 0 else ""}Response: [/INST] {data_point['response']}</s>"""tokenized_prompt = tokenizer(full_prompt)return tokenized_prompttokenized_train_dataset = train_dataset.map(generate_and_tokenize_prompt)
tokenized_val_dataset = eval_dataset.map(generate_and_tokenize_prompt)untokenized_text = tokenizer.decode(tokenized_train_dataset[1]['input_ids']) 
print(untokenized_text)# Output
<s> Given a question and some additional context, provide an answer### Target sentence:Question: Alice's parents have three daughters: Amy, Jessy, and what’s the name of the third daughter?Additional Context: Response: [/INST] The name of the third daughter is Alice</s></s>

4.6 步骤 6:获取数据集中各个样本长度的分布情况

import matplotlib.pyplot as pltdef plot_data_lengths(tokenized_train_dataset, tokenized_val_dataset):lengths = [len(x['input_ids']) for x in tokenized_train_dataset]lengths += [len(x['input_ids']) for x in tokenized_val_dataset]print(len(lengths))# Plotting the histogramplt.figure(figsize=(10, 6))plt.hist(lengths, bins=20, alpha=0.7, color='blue')plt.xlabel('Length of input_ids')plt.ylabel('Frequency')plt.title('Distribution of Lengths of input_ids')plt.show()plot_data_lengths(tokenized_train_dataset, tokenized_val_dataset)

image.png

Source: Image created by Author

4.7 步骤 7:在数据的左侧添加 padding ,以减少内存的使用

max_length = 320 # This was an appropriate max length for my dataset# redefine the tokenize function and tokenizertokenizer = AutoTokenizer.from_pretrained(base_model_id,padding_side="left",add_eos_token=True,  add_bos_token=True,  
)
tokenizer.pad_token = tokenizer.eos_tokendef tokenize(prompt):result = tokenizer(prompt,truncation=True,max_length=max_length,padding="max_length",)result["labels"] = result["input_ids"].copy()return resulttokenized_train_dataset = train_dataset.map(generate_and_tokenize_prompt)
tokenized_val_dataset = eval_dataset.map(generate_and_tokenize_prompt)untokenized_text = tokenizer.decode(tokenized_train_dataset[4]['input_ids']) 
print(untokenized_text)# Output
<s> Given a target sentence construct the underlying meaning representation of the input sentence as a single function with attributes and attribute values.This function should describe the target string accurately and the function must be one of the following ['inform', 'request', 'give_opinion', 'confirm', 'verify_attribute', 'suggest', 'request_explanation', 'recommend', 'request_attribute'].The attributes must be one of the following: ['name', 'exp_release_date', 'release_year', 'developer', 'esrb', 'rating', 'genres', 'player_perspective', 'has_multiplayer', 'platforms', 'available_on_steam', 'has_linux_release', 'has_mac_release', 'specifier']### Target sentence:When did Virgin Australia start operating?Here is some context: Virgin Australia, the trading name of Virgin Australia Airlines Pty Ltd, is an Australian-based airline. It is the largest airline by fleet size to use the Virgin brand. It commenced services on 31 August 2000 as Virgin Blue, with two aircraft on a single route. It suddenly found itself as a major airline in Australia's domestic market after the collapse of Ansett Australia in September 2001. The airline has since grown to directly serve 32 cities in Australia, from hubs in Brisbane, Melbourne and Sydney.[/INST] Virgin Australia commenced services on 31 August 2000 as Virgin Blue, with two aircraft on a single route.</s></s>
plot_data_lengths(tokenized_train_dataset, tokenized_val_dataset)

image.png

Source: Image created by Author

4.8 步骤 8:设置 LoRA

from peft import prepare_model_for_kbit_trainingmodel.gradient_checkpointing_enable()
model = prepare_model_for_kbit_training(model)def print_trainable_parameters(model):"""Prints the number of trainable parameters in the model."""trainable_params = 0all_param = 0for _, param in model.named_parameters():all_param += param.numel()if param.requires_grad:trainable_params += param.numel()print(f"trainable params: {trainable_params} || all params: {all_param} || trainable%: {100 * trainable_params / all_param}")from peft import LoraConfig, get_peft_modelconfig = LoraConfig(r=8,lora_alpha=16,target_modules=["q_proj","k_proj","v_proj","o_proj","w1","w2","w3","lm_head",],bias="none",lora_dropout=0.05,  # Conventionaltask_type="CAUSAL_LM",
)model = get_peft_model(model, config)
print_trainable_parameters(model)# Apply the accelerator. You can comment this out to remove the accelerator.
model = accelerator.prepare_model(model)# Output
trainable params: 120350720 || all params: 23602952192 || trainable%: 0.5098968934945001

4.9 步骤 9:进行训练

import transformers
from datetime import datetimeif torch.cuda.device_count() > 1: # If more than 1 GPUmodel.is_parallelizable = Truemodel.model_parallel = Trueproject = "databricks-dolly-finetune"
base_model_name = "mixtral"
run_name = base_model_name + "-" + project
output_dir = "./" + run_nametokenizer.pad_token = tokenizer.eos_tokentrainer = transformers.Trainer(model=model,train_dataset=tokenized_train_dataset,eval_dataset=tokenized_val_dataset,args=transformers.TrainingArguments(output_dir=output_dir,warmup_steps=5,per_device_train_batch_size=1,gradient_checkpointing=True,gradient_accumulation_steps=4,max_steps=500,learning_rate=2.5e-5, logging_steps=25,fp16=True, optim="paged_adamw_8bit",logging_dir="./logs",        # Directory for storing logssave_strategy="steps",       # Save the model checkpoint every logging stepsave_steps=50,                # Save checkpoints every 50 stepsevaluation_strategy="steps", # Evaluate the model every logging stepeval_steps=50,               # Evaluate and save checkpoints every 50 stepsdo_eval=True,                # Perform evaluation at the end of trainingreport_to="wandb",           # Comment this out if you don't want to use weights & baisesrun_name=f"{run_name}-{datetime.now().strftime('%Y-%m-%d-%H-%M')}"          # Name of the W&B run (optional)),data_collator=transformers.DataCollatorForLanguageModeling(tokenizer, mlm=False),
)model.config.use_cache = False  # silence the warnings. Please re-enable for inference!
trainer.train()

4.10 步骤 10:使用训练完毕的模型

import torch
from transformers import AutoTokenizer, AutoModelForCausalLM, BitsAndBytesConfigbase_model_id = "mistralai/Mixtral-8x7B-v0.1"
bnb_config = BitsAndBytesConfig(load_in_4bit=True,bnb_4bit_use_double_quant=True,bnb_4bit_compute_dtype=torch.bfloat16
)base_model = AutoModelForCausalLM.from_pretrained(base_model_id,  # Mixtral, same as beforequantization_config=bnb_config,  # Same quantization config as beforedevice_map="auto",trust_remote_code=True,use_auth_token=True
)eval_tokenizer = AutoTokenizer.from_pretrained(base_model_id,add_bos_token=True,trust_remote_code=True,
)
from peft import PeftModelft_model = PeftModel.from_pretrained(base_model, "mixtral-databricks-dolly-finetune/checkpoint-100")
eval_prompt = """Given a question and some additional context, provide an answer### Target sentence:
Question: When was Tomoaki Komorida born?
Here is some context: Komorida was born in Kumamoto Prefecture on July 10, 1981. After graduating from high school, he joined the J1 League club Avispa Fukuoka in 2000. Although he debuted as a midfielder in 2001, he did not play much and the club was relegated to the J2 League at the end of the 2001 season. In 2002, he moved to the J2 club Oita Trinita. He became a regular player as a defensive midfielder and the club won the championship in 2002 and was promoted in 2003. He played many matches until 2005. In September 2005, he moved to the J2 club Montedio Yamagata. In 2006, he moved to the J2 club Vissel Kobe. Although he became a regular player as a defensive midfielder, his gradually was played less during the summer. In 2007, he moved to the Japan Football League club Rosso Kumamoto (later Roasso Kumamoto) based in his local region. He played as a regular player and the club was promoted to J2 in 2008. Although he did not play as much, he still played in many matches. In 2010, he moved to Indonesia and joined Persela Lamongan. In July 2010, he returned to Japan and joined the J2 club Giravanz Kitakyushu. He played often as a defensive midfielder and center back until 2012 when he retired.### Response:
"""model_input = eval_tokenizer(eval_prompt, return_tensors="pt").to("cuda")ft_model.eval()with torch.no_grad():print(eval_tokenizer.decode(ft_model.generate(**model_input, max_new_tokens=100)[0], skip_special_tokens=True))Given a question and some additional context, provide an answer### Target sentence:
Question: When was Tomoaki Komorida born?
Here is some context: Komorida was born in Kumamoto Prefecture on July 10, 1981. After graduating from high school, he joined the J1 League club Avispa Fukuoka in 2000. Although he debuted as a midfielder in 2001, he did not play much and the club was relegated to the J2 League at the end of the 2001 season. In 2002, he moved to the J2 club Oita Trinita. He became a regular player as a defensive midfielder and the club won the championship in 2002 and was promoted in 2003. He played many matches until 2005. In September 2005, he moved to the J2 club Montedio Yamagata. In 2006, he moved to the J2 club Vissel Kobe. Although he became a regular player as a defensive midfielder, his gradually was played less during the summer. In 2007, he moved to the Japan Football League club Rosso Kumamoto (later Roasso Kumamoto) based in his local region. He played as a regular player and the club was promoted to J2 in 2008. Although he did not play as much, he still played in many matches. In 2010, he moved to Indonesia and joined Persela Lamongan. In July 2010, he returned to Japan and joined the J2 club Giravanz Kitakyushu. He played often as a defensive midfielder and center back until 2012 when he retired.### Response:
Tomoaki Komorida was born on July 10, 1981.

五、结论

利用 QLoRA 对 Mixtral-8x7B 模型进行微调是自然语言处理 (NLP) 领域的一个重要进展,它将模型性能提升到了新的高度。这一缜密的过程融合了量化和 LoRA 等前沿技术,为超越基准(benchmarks)提供了一条稳健的途径,甚至在各种评估指标上超越了强大的 Llama 2 70B 模型。

本教程的核心在于使用QLoRA进行微调,利用bitsandbytes以4位精度实例化模型,并运用Hugging Face 🤗的PEFT库。该指南不仅概述了微调方法,还揭示了实践过程中可能遇到的问题,如OutOfMemory errors,为用户提供了精确的解决途径。

从本质上讲,该教程并非是一个技术指南,更像一个倡导模型微调最佳实践的指引。它倡导协作式微调,请邀请其他研究人员和从业者一同踏上推动语言理解模型发展的旅程。

前沿技术、详细的指导以及合作共赢的态度使得该教程对于NLP社区来说是一个非常重要且不可或缺的资源,期望能够引导 NLP 社区进一步提高模型性能,丰富理解能力。

Resources:

● Mixtral-8x7b[2]

● Thanks to Harper Carroll[2]

文中链接

[1]https://mistral.ai/news/mixtral-of-experts/

[2]https://huggingface.co/blog/mixtral

[3]https://twitter.com/HarperSCarroll

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.mzph.cn/news/613260.shtml

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

Win10子系统Ubuntu实战(二)

在 Windows 10 中安装 Ubuntu 子系统&#xff08;Windows Subsystem for Linux&#xff0c;简称 WSL&#xff09;有几个主要的用途和好处&#xff1a;Linux 环境的支持、跨平台开发、命令行工具、测试和验证、教育用途。总体而言&#xff0c;WSL 提供了一种将 Windows 和 Linux…

使用python执行系统命令的五种方式

在日常开发中&#xff0c;有时需要在Python脚本中执行系统命令&#xff0c;Python有五种方式来执行系统命令&#xff0c;推荐使用第五种。 python执行系统命令的五种方式 方法1: os.system 这是最简单的方法&#xff0c;适合简单的业务场景&#xff0c;输入为完整命令字符串…

【IP-Adapter】进阶 - 同款人物【2】 ☑

测试模型&#xff1a;###最爱的模型\flat2DAnimerge_v30_2.safetensors [b2c93e7a89] 原图&#xff1a; 加入 control1 [IP-Adapter] 加入 control 2 [OpenPose] 通过openpose骨骼图修改人物动作。 加入 control 3 lineart 加入cotrol3 …

【LV12 DAY13 UART 串口通信】

UART–&#xff08;一种通信协议&#xff09; 通用异步收发器&#xff0c;是一种通用的串行&#xff0c;异步通信总线&#xff0c;该总线有两条数据线&#xff0c;可以实现全双工的发送和接收&#xff0c;在嵌入式系统中常用于主机与辅助设备之间的通信。 波特率 波特率用于描…

1.10 Unity中的数据存储 JSON

一、介绍 Json是最常用也是目前用的比较多的一种&#xff0c;超轻量级&#xff0c;可便捷性使用&#xff0c;平时用到比较多的都是解析Json和往Json中添加数据、修改数据等等JSON(JavaScript Object Notation,JS对象标记)是一种轻量级的数据交换格式&#xff0c;它基于ECMAScr…

Leetcode18-算术三元组的数目(2367)

1、题目 给你一个下标从 0 开始、严格递增 的整数数组 nums 和一个正整数 diff 。如果满足下述全部条件&#xff0c;则三元组 (i, j, k) 就是一个 算术三元组 &#xff1a; i < j < k &#xff0c; nums[j] - nums[i] diff 且 nums[k] - nums[j] diff 返回不同 算术三…

PR-视频去水印

文章目录 前言PR-视频去水印实现示例 前言 如果您觉得有用的话&#xff0c;记得给博主点个赞&#xff0c;评论&#xff0c;收藏一键三连啊&#xff0c;写作不易啊^ _ ^。   而且听说点赞的人每天的运气都不会太差&#xff0c;实在白嫖的话&#xff0c;那欢迎常来啊!!! PR-视频…

python绘制热力图-数据处理-VOC数据类别标签分布及数量统计(附代码)

前言 当你需要统计训练数据中每个类别标签有多少&#xff0c;并且想知道坐标中心分布在图像的位置信息时&#xff0c;你可以利用一下脚本进行计算&#xff01; 步骤 要绘制热力图来分析VOC数据的分布统计&#xff0c;可以按照以下步骤进行&#xff1a; 数据处理&#xff1…

XCTF:MISCall[WriteUP]

使用file命令&#xff0c;查看该文件类型 file d02f31b893164d56b7a8e5edb47d9be5 文件类型&#xff1a;bzip2 使用bzip2命令可对该文件进行解压 bzip2 -d d02f31b893164d56b7a8e5edb47d9be5 生成了一个后缀为.out的文件 再次使用file命令&#xff0c;查看该文件类型 file…

LC17. 电话号码的字母组合

代码随想录 class Solution {String[] numString {"", "", "abc", "def", "ghi", "jkl", "mno", "pqrs", "tuv", "wxyz"};List<String> res new ArrayList<…

[开发语言][python][c++]:C++中的this指针和Python中的Self -- 26岁生日

C中的this指针和Python中的Self 1. python中的Self2. C中的this指针3. C中的this指针和Python中self的异同点&#xff1a; 以朋友的新岁祝福开篇&#xff0c;祝笔者也祝大家☺️&#xff1a; 一岁一礼 一寸欢喜且喜且乐 且以永日​ From VardoZ癸卯年十一月廿六(兔年)之…

缓存代理服务器

1 缓存代理 1.1 缓存代理的概述 web代理的作用 缓存网页对象&#xff0c;减少重复请求 存储一些之前被访问的或且可能将要备再次访问的静态网页资源对象&#xff0c;使用户可以直接从缓存代理服务器获取资源&#xff0c;从而减少上游原始服务器的负载压力&#xff0c;加快整…

LeetCode刷题--- 地下城游戏

个人主页&#xff1a;元清加油_【C】,【C语言】,【数据结构与算法】-CSDN博客 个人专栏 力扣递归算法题 http://t.csdnimg.cn/yUl2I 【C】 ​​​​​​http://t.csdnimg.cn/6AbpV 数据结构与算法 ​​​http://t.csdnimg.cn/hKh2l 前言&#xff1a;这个专栏主要讲述动…

Java基础 | 类和对象

Java基础 | 类和对象 类成员变量成员方法权限修饰符 局部变量final变量this关键字类的构造方法静态变量和静态方法static修饰符 类的主方法 对象对象的创建对象的引用 数据类型转换隐式类型转换显式类型转换 所有知识点均来源于《Java从入门到精通》&#xff08;第六版&#xf…

在软件测试过程中如何有效的开展接口自动化测试

一.简介 接口自动化测试是指使用自动化测试工具和脚本对软件系统中的接口进行测试的过程。其目的是在软件开发过程中&#xff0c;通过对接口的自动化测试来提高测试效率和测试质量&#xff0c;减少人工测试的工作量和测试成本&#xff0c;并且能够快速发现和修复接口错误&#…

最长回文数字

中心扩散 中心扩散就是从中心往外逐层扩散。以单个字符往两边扩散&#xff0c;如果两边字符相等则是回文串。扩散又分两种情况 分别是以该字符为中心&#xff0c;和以该字符和下一个字符的空隙为中心 let longestPalindrome2 function (s) {const n s.lengthif (n 1) retu…

如何保护linux服务器远程使用的安全

服务器安全是一个非常敏感的问题&#xff0c;因服务器远程入侵导致数据丢失的安全问题频频出现&#xff0c;一旦服务器入侵就会对个人和企业造成巨大的损失。因此&#xff0c;在日常使用服务器的时候&#xff0c;我们需要采取一些安全措施来保障服务器的安全性。 目前服务器系…

线程休眠、线程让步、线程优先级相关内容学习笔记

1、线程休眠 &#xff08;1&#xff09;sleep() 如果需要让当前正在执行的线程暂停一段时间&#xff0c;并进入阻塞状态&#xff08;Timed_Waiting)&#xff0c;则可以通过调用Thread类的静态sleep()方法来实现。 static void sleep(long millis)&#xff1a;让当前正在执行的线…

Shell编程--函数function

函数 1.定义函数2.调用函数2.1.取消函数2.2.其他脚本调用 3.函数传参 1.定义函数 函数声明&#xff1a; function_name () { list of commands } 函数名 function_name&#xff0c;这就是你将使用它从其他地方在你的脚本调用。 function (功能) 功能函数 计算机函数&…

C++ 异常处理

C++ 异常处理 实验介绍 所谓的异常便是程序中数据出现不合理的特殊情况处理,在做项目时我们常常需要想到的是特殊的情况,只有将所有的特殊情况处理好之后程序才能很好的运行。 那么异常处理跟多态有什么关系呢? C++ 标准库中已经存在异常处理类,并且就是使用了多态的方式…