使用 QLoRA 在 Google Colab 中微调 Mistral 7b(完整指南)
在本文中,我们将在一个名为 Enlighten 的游戏的整个代码库上微调 Mistral 7b,所有这些都在 Google Colab(或 Kaggle)中免费提供合成数据。在我们的基准测试中,由此产生的模型将优于 Openai 的 GPT-4。
步骤如下:
- 将代码库转换为基于合成对话的训练测试数据集
- 使用 QLoRA 进行微调
- 评估新模型
- 评估基础模型 + GPT-4
- (可选)将适配器与基本型号合并
- (可选)量化模型并获取 GGUF 格式
本文中使用的所有代码和数据集都在资源部分提供,您只需要一个免费的拥抱脸和一个免费的 Google Colab 或 Kaggle 帐户。
介绍
考虑到你已经深入研究LLMs了,你已经知道什么是大型语言模型或它是如何训练的,所以请随时跳过这一部分,如果你只想学习编码部分,请跳到这一部分。
安德烈·卡帕西(Andre Karpathy)也制作了这个惊人的视频LLMs。本文的介绍主要基于它。
大型语言模型(LLM)是进行下一个标记预测的深度学习模型,它们将文本作为输入并预测句子中下一个最可能的单词。
下一个代币预测
然后,他们获取这个新标记,将其添加到上一个输入的末尾,并预测下一个输入。这种情况一直持续到预测的代币是 EOS(序列结束)代币,此时它会停止。
训练这些LLMs需要大量的数据和计算,例如,Llama 2 70b 在 10TB 文本和 6000 个 GPU 上训练了 12 天,花费了大约 200 万美元,最终,我们得到的是一个完成文档的模型。此步骤称为预训练。
骆驼 2 70b
为了让它像助手一样说话,模型稍后会使用基于对话的数据进行训练,在此步骤中,模型将学习像助手一样行事,并使用在相关步骤中学到的所有知识。这为我们提供了指令模型。
此外,为了增加更多知识或增强模型在某些领域或更多领域的能力,我们根据新数据对模型进行微调,与之前的步骤相比,这种微调需要的数据和计算要少得多,但在消费级硬件上仍然不可能。
PEFT(参数高效微调)解决了这个问题,并应用了巧妙的方法,即使在 Google Colab 的免费套餐上也可以进行微调。
我们将使用量化和 LoRA(Low-Rank Adaptation)来微调 Mistral 7b,指导并向其介绍新知识。
LoRA(低等级适配)
LoRA 冻结模型的参数 (W0), 将小的可训练适配器层 (ΔW = BA) 附加到模型上, 并且只训练适配器.这大大减少了可训练参数,并消耗了更少的 RAM。LoRA 中一个重要的超参数是 r,在本例中为 r=2。
因此,我们将使用这种方法并对我们的数据进行微调 Mistral。
准备数据
数据集可在此处找到,因此您不需要在本节中运行任何代码,但我建议您阅读它以了解数据集是如何制作的。
如前所述,我们将对 Enlighten 的代码库进行微调,首先我在每个类中编写了一些关于该类和所有方法的文档。下面是其中一个类作为示例。
using DG.Tweening;
using UnityEngine;
using UnityEngine.Events;/** Player.Interactables.AnimatedInteractable* InteractableObject.cs is the base(abstract) class for all interactable objects in the game. they all must inherit from it or one of its children* all interactable objects have a child of InteractableObject.cs class attached to them* each script that inherits from InteractableObject.cs has its own custom logic for when the player is focusing on it and when the player interacts with it* this class (AnimatedInteractable.cs) inherits from InteractableObject.cs and adds the functionality of playing an animation when the player interacts with the object* other scripts can subscribe to the onInteractAction event to add custom logic when the player interacts with the object* gameObjects with this script attached to them must have an animator component with a trigger parameter called "OnInteract" and an animation that plays when the trigger is called*/[RequireComponent(typeof(Animator))]
public class AnimatedInteractable : InteractableObject {private Animator _animator;[Tooltip("If true, the object will only be animated once then disabled.")] [SerializeField] private bool isOneTimeAnimated;//cooldown between each interaction. If 0, there is no cooldown[SerializeField] private float cooldown;//the action to invoke when the player interacts with the object. Set in the inspector[SerializeField] private UnityEvent onInteractAction;[SerializeField] private AudioSource audioSource;private void Start() {_animator = GetComponent<Animator>();}//player is no longer focusing on the current interactable object. child classes can override this method to add custom logicprotected override void OnObjectWentOutOfFocus() { }//Called by PlayerInteractableObjectsManager.cs when the player presses the interact button while focusing on the object. Plays the animation and invokes the onInteractActionpublic override void Interact() {//play the animationaudioSource.Play();_animator.SetTrigger("OnInteract");//if the object is one time animated, disable the collider so the player can't interact with it againif (isOneTimeAnimated) GetComponent<Collider>().enabled = false;//if it has a cooldown, disable the collider for the duration of the cooldownelse if (cooldown != 0) {GetComponent<Collider>().enabled = false;DOTween.Sequence().AppendInterval(cooldown).OnComplete(() => { GetComponent<Collider>().enabled = true; });}//invoke the onInteractActiononInteractAction?.Invoke();}
}
正如你所看到的,我们的数据目前只是一堆 C# 类,但我们需要基于指令的数据,以问答对的形式出现(这些原始 C# 类可用于微调基本的非指令模型,但生成的模型也将是非指令的,只会完成一个文档,主要用于代码完成)
为了解决这个问题,我们将使用一个更大、更强大的模型来基于代码库合成生成我们的数据,我选择了新发布的 Google Gemini Pro 来完成这项任务,因为它对于这个用例既免费又强大。(Gpt-4 将是最好的模型,但 API 不是免费的)。
我们需要两个数据集,一个用于训练,一个用于测试,两者都将由 Gemini 合成生成。训练数据将是一个关于代码及其答案的问题。测试数据集将采用多项选择题的形式,这是一个问题,后跟 4 个选项和正确的一个。
对于每个 C# 类,我生成了 20 个用于训练的 Q/A,生成了 3 个用于测试的 Q/A。20 个培训问题包括 10 个 just-code 问题和 10 个一般问题。总共有大约 90 个 C# 类。
现在我们将把每个类都交给 Gemini 并要求它生成我们的数据,为此,我们需要我们自己的自定义系统消息。这些是我设计的系统消息。
10 个纯代码问题的系统消息(训练数据)
# you take a C# class from a Unity project with it's documentaion and create 10 question/answer pairs for it.
# you are GREAT at extracting the most important information out of these classes.
# all questions should ask for code and all answers should be C# code.
# questions should have context.
# DO NOT say 'the code' or 'this code' in your questions and do not refer to it at all.
# mention the class name in all questions.
# your answer MUST be a valid json format, a list of 10 objects, each object has 3 text fields: class name, Q(question), A(answer)
# any '\n' in the text fields MUST be '\\n' so that when reading it later on, we won't run into any issues# example output with 2 question/answer
[{"class": "className.cs","Q": "in 'className.cs' How does the beast check if the player is in sight?", "A": " ```csharp\\nVector3 direction = (player.position - new Vector3(0, 0.5f, 0)) - beastTransform.position;\\nif (Physics.Raycast(beastTransform.position, direction, out hit, eyeSightDistance, layerMask)) {\\n if (hit.collider.CompareTag("Player")) {\\n return true;\\n }\\n}\\n```"},{"class": "className.cs","Q": "What is the code used to calculate the distance between the player and the beast in 'className.cs' ?","A": " ```csharp\\nif (navMeshAgent.remainingDistance > distanceToPlayerThatBeastRuns)\\n navMeshAgent.speed = Constants.BeastFastSpeed;\\nelse navMeshAgent.speed = normalSpeed;\\n```"}
]
# end of examples.# this is the C# class:
10 个一般问题的系统消息(训练数据)
# you take a C# class from a Unity project with it's documentaion and create 10 question/answer pairs for it.
# you are GREAT at extracting the most important information out of these classes.
# questions should have context.
# DO NOT say 'the code' or 'this code' in your questions and do not refer to it at all.
# mention the class name in all questions.
# your answer MUST be a valid json format, a list of 10 objects, each object has 3 text fields: class name, Q(question), A(answer)
# any '\n' in the text fields MUST be '\\n' so that when reading it later on, we won't run into any issues# example output with 2 question/answer
[{"class": "className.cs","Q": "What is the purpose of the className.cs class?", "A": "The className.cs class is the main controller for the beast. It manages the state of the beast and the transitions between them.\\n it is implemented in singleton pattern"},{"class": "className.cs", "Q": "in 'className.cs' What is the purpose of the _roamingState variable?","A": "The _roamingState variable is an instance of the BeastStateRoaming class, which represents the beast's roaming state. It manages the behavior and transitions related to the roaming state, including moving between predefined roaming positions."}
]
# end of examples.# this is the C# class:
3 道多项选择题的系统信息(测试数据)
# you take a C# class from a Unity project with it's documentaion and create 3 question/answer pairs for it.
# you are GREAT at extracting the most important information out of these classes.
# DO NOT say 'the code' or 'this code' in your questions and do not refer to it at all.
# mention the class name in all questions.
# your answer MUST be a valid json format, a list of 3 objects, each object has 6 text fields: class name, Question, a,b,c,d,Answer# example output with 2 question/answer
[{"class": "className.cs","Question": "In className.cs what is the purpose of the PlayerManager class?", "a": "To control player movement", "b": "To manage some player behavior functionality", "c": "To handle player combat actions", "d": "To store references to key player components", "Answer": "b"},{"class": "className.cs","Question": "What does the FarthestPlaceFromPlayer() method do in className.cs?", "a": "Finds the farthest destination from the player", "b": "Teleports the player", "c": "Returns a random destination", "d": "Sets the player position", "Answer": "a"}
]
# end of examples.# this is the C# class:
在这些系统消息中LLM,我们首先告诉它我们希望它做什么,然后告诉它规则。响应是有效的 JSON,这使我们以后更容易。我还使用了小样本提示技术,给出LLM了一个示例响应,以便其输出更符合我们的需要。
DataGenerator.ipynb 完成所有这些操作,从读取所有 C# 类到生成合成数据,并将其另存为 CSV 文件。我们不会全部介绍,因为它不是本文的主要重点,但这两个代码块基本上是它的核心。
如何调用 Gemini API
genai.configure(api_key=geminiApiKey)
model = genai.GenerativeModel('gemini-pro')def get_raw_text_gemini(file_content,systemMessage):response = model.generate_content(systemMessage+"\n\n"+file_content,generation_config=genai.types.GenerationConfig(max_output_tokens=4000))return(response.text)
将响应转换为 LLM Pandas 数据帧
def make_df(text):data = json.loads(text)df = pd.DataFrame(data)df=df.map(lambda x: x.replace('\\n', '\n'))return dfraw_response=get_raw_text_gemini(file_content,test_systemMessage)df=make_df(raw_response)
我还在训练数据集中添加了一些非合成数据。首先,我写了一些关于 Enlighten 项目的一般信息,然后为每个类添加了以下内容:
问题:编写“ClassName”类
答:整个ClassName.cs代码
最后,训练数据大约有一百万个令牌,我们得到了两个 CSV 文件,TestData.csv 和 TrainData.csv
使用 LoRA 进行微调
编码开始,您可以在 Google Colab 中运行此代码(整个微调笔记本),但首先更改运行时类型并激活 T4 GPU(如果您使用的是 Kaggle,请激活 P100 GPU)。让我们从声明一些变量开始
base_model = "mistralai/Mistral-7B-Instruct-v0.2"
new_model = "Enlighten_Instruct"test_path='/content/Enlighten-Instruct/Dataset/TestData.csv'
train_path='/content/Enlighten-Instruct/Dataset/TrainData.csv'
然后我们安装一些包,克隆 git 存储库(仅用于数据集),并导入库
%%capture
!git clone 'https://github.com/ali7919/Enlighten-Instruct.git'
!pip install -U bitsandbytes
!pip install transformers==4.36.2
!pip install -U peft
!pip install -U accelerate
!pip install -U trl
!pip install datasets==2.16.0
!pip install sentencepiece
from transformers import AutoModelForCausalLM, AutoTokenizer, BitsAndBytesConfig,HfArgumentParser,TrainingArguments,pipeline, logging
from peft import LoraConfig, PeftModel, prepare_model_for_kbit_training, get_peft_model
import os,torch
from datasets import load_dataset
from trl import SFTTrainer
import pandas as pd
import pyarrow as pa
import pyarrow.dataset as ds
import pandas as pd
from datasets import Dataset
import re
要将最终结果上传到 Hugging Face,我们必须首先登录它,我们将使用“密钥”,首先从左侧工具栏中选择“密钥”选项卡🗝️,然后添加一个名为“HUGGINGFACE_TOKEN”的新密钥和您的拥抱脸令牌的值(您可以通过本指南获取令牌)。最后,检查笔记本访问权限,以便机密在笔记本中可用。
现在我们登录到拥抱的脸。
from google.colab import userdata
secret_hf = userdata.get('HUGGINGFACE_TOKEN')
!huggingface-cli login --token $secret_hf
是时候对模型进行微调了,但首先,我们还有最后一步来准备好我们的数据,每个指令调整都LLM遵循特定的指令/响应格式。其中一种格式是 ChatML,这种格式包含三个部分,系统、用户和助手。
这是 ChatML 格式生成的文本的示例。
<|im_start|>system
Assistant is an intelligent chatbot designed to help users answer their tax related questions.
<|im_end|>
<|im_start|>user
When do I need to file my taxes by?
<|im_end|>
<|im_start|>assistant
In 2023, you will need to file your taxes by April 18th. The date falls after the usual April 15th deadline because April 15th falls on a Saturday in 2023. For more details, see https://www.irs.gov/filing/individuals/when-to-file
<|im_end|>
Mistral 7b instruct 使用更简单的格式。首先,我们有 BOS(序列开始)令牌,即,没有系统消息,然后用户的提示在 [INST] 和 [/INST] 之间,然后是助手的响应,最后,我们有 EOS(序列结束)令牌。
这是以 Mistral 格式生成的文本的示例。
<s>[INST] What is your favourite condiment? [/INST]
Well, I'm quite partial to a good squeeze of fresh lemon juice. It adds just the right amount of zesty flavour to whatever I'm cooking up in the kitchen!</s>
因此,在我们开始微调之前,我们必须更改数据以遵循此格式,我还在说明之前添加了“@Enlighten.”,以便它像系统消息一样运行。稍后,我们还将在提示的开头包含它。
构建数据集:
df = pd.read_csv(train_path)# build training dataset with the right format
df['text'] = '<s>[INST]@Enlighten. ' + df['Q'] +'[/INST]'+ df['A'] + '</s>'# remove columns
df=df.drop(['Q','A','class'],axis=1)# convert to dataset object
dataset = ds.dataset(pa.Table.from_pandas(df).to_batches())
dataset = Dataset(pa.Table.from_pandas(df))p
现在数据已准备就绪,它只有列“text”,并且是一个 Dataset 对象。下一步是加载基本模型,在本例中为 Mistral 7b 指令。
# Load base model
bnb_config = BitsAndBytesConfig(load_in_4bit= True,bnb_4bit_quant_type= "nf4",bnb_4bit_compute_dtype= torch.bfloat16,bnb_4bit_use_double_quant= False,
)
model = AutoModelForCausalLM.from_pretrained(base_model,load_in_4bit=True,quantization_config=bnb_config,torch_dtype=torch.bfloat16,device_map="auto",trust_remote_code=True,
)model.config.use_cache = False
model.config.pretraining_tp = 1
model.gradient_checkpointing_enable()# Load tokenizer
tokenizer = AutoTokenizer.from_pretrained(base_model, trust_remote_code=True)
tokenizer.padding_side = 'right'
tokenizer.pad_token = tokenizer.eos_token
tokenizer.add_eos_token = True
tokenizer.bos_token, tokenizer.eos_token
即使使用 LoRA,也无法使用我们拥有的 VRAM 以全精度加载模型并对其执行微调,因此要解决这个问题,我们必须以 4 位精度加载基本模型。
在代码的最后一部分,我们用EOS令牌填充数据查询的其余部分,以便它们都具有相同的长度。
然后,我们准备模型,以便使用LoRA进行参数高效微调。我选择了 r=64 和 alpha=16,更常见的是将 alpha 设置为 r。
model = prepare_model_for_kbit_training(model)
peft_config = LoraConfig(lora_alpha=16,lora_dropout=0.1,r=64,bias="none",task_type="CAUSAL_LM",target_modules=["q_proj", "k_proj", "v_proj", "o_proj","gate_proj"]
)
model = get_peft_model(model, peft_config)
接下来,我们设置训练参数,使用这些参数大约需要一个小时来微调模型。
#Hyperparamter
training_arguments = TrainingArguments(output_dir="./results",num_train_epochs=1,per_device_train_batch_size=4,gradient_accumulation_steps=1,optim="paged_adamw_32bit",save_steps=50,logging_steps=1,learning_rate=2e-4,weight_decay=0.001,fp16=False,bf16=False,max_grad_norm=0.3,max_steps=-1,warmup_ratio=0.03,group_by_length=True,lr_scheduler_type="constant",
)
然后,我们用刚刚准备的变量初始化训练器。
# Setting sft parameters
trainer = SFTTrainer(model=model,train_dataset=dataset,peft_config=peft_config,max_seq_length= None,dataset_text_field="text",tokenizer=tokenizer,args=training_arguments,packing= False,
)
最后,是时候开始培训了。
trainer.train()
训练完成后,适配器就是我们最终得到的,它是一个大小约为 350MB 的单个 .safetensor 文件,将“r”设置为低于 64 的值将导致适配器更小。
首先,我们保存适配器。
trainer.model.save_pretrained(new_model)
model.config.use_cache = True
model.eval()
然后我们把适配器推到拥抱脸上,这是我的适配器。
trainer.model.push_to_hub(new_model)
微调完成。
推断模型
微调已经完成,我们的适配器已经应用并加载,以提示模型首先创建一个管道。
logging.set_verbosity(logging.CRITICAL)
pipe = pipeline(task="text-generation", model=model, tokenizer=tokenizer, max_length=200)
然后编写一个辅助函数来创建提示,并在开头包含“@Enlighten.”。
def build_prompt(question):prompt=f"<s>[INST]@Enlighten. {question} [/INST]"return prompt
现在提示模型就这么简单。
question = "what is Light Gun?"
prompt = build_prompt(question)
result = pipe(prompt)print(result[0]['generated_text'])
提示:@Enlighten。什么是停电事件?
答:停电事件是断电时触发的全局事件。 当电源关闭时,它由级别管理器触发。 它用于更新 UI 和更改电源开关的状态。它还用于关闭电源开关和打开应急灯。它是在关卡管理器中定义的自定义事件。它是一个不接受任何参数且不返回任何内容的委托。当电源关闭时,它由级别管理器调用。 当电源打开时,它由级别管理器调用。当电源关闭时,它由级别管理器调用。 当电源打开时,它由级别管理器调用。当电源关闭时,它由级别管理器调用。 当电源打开时,它由级别管理器调用。
正如你所看到的,该模型已经在项目代码库上成功微调,现在具有特定于领域的知识。
不要运行下一个代码,它仅供参考,我们为推理编写的代码是在微调所有内容已加载后发生的,但您可以使用此代码在任何适配器上进行推理。
base_model = "mistralai/Mistral-7B-Instruct-v0.2"
new_model = "codersan/Enlighten_Instruct"
base_model_reload = AutoModelForCausalLM.from_pretrained(base_model,torch_dtype=torch.bfloat16,return_dict=True,low_cpu_mem_usage=True,device_map="auto",trust_remote_code=True,
)
model = PeftModel.from_pretrained(base_model_reload, new_model)
pipe = pipeline(task="text-generation", model=model, tokenizer=tokenizer, max_length=200)
测试模型
是时候测试微调后的模型了,并将其与基础模型和 GPT-4 进行比较,我们的推理已经设置好了,要测试模型,我们将遍历测试数据集中的每个问题。
df_test=pd.read_csv(test_path)
for index, row in df_test.iterrows():#rest of the code goes here
并执行以下操作:
1.构建提示:系统消息+问题+选项+custonPhrase
testGuide='Answer the following question, at the end of your response write the answer like this: Answer:a or Answer:b or Answer:c or Answer:d \n'
chainOfThoughtActivator='\nfirst think step by step\n'question=testGuide + row['Question'] + '\na)' + row['a'] + '\nb)' + row['b'] + '\nc)' + row['c'] + '\nd)' + row['d'] + chainOfThoughtActivator
prompt = build_prompt(question)
2. 生成响应
result = pipe(prompt)
llmAnswer = result[0]['generated_text']
3. 从响应中提取所选选项
#remove our prompt from the result
promptEnding = "[/INST]"
index = llmAnswer.find(promptEnding)
llmAnswer = llmAnswer[len(promptEnding)+index:]#remove spaces
llmAnswer=llmAnswer.replace(' ','')#find the option in response
index = llmAnswer.find('Answer:')
4. 检查所选选项是否为正确答案
#true answer
truth=row['Answer'] #find and match the option
next_char = llmAnswer[index+len('Answer:'):][0]
if next_char in truth:print('correct')
else:print('wrong')
完整的代码是这样的(我在LLM拒绝回答问题时添加了重试)
df_test=pd.read_csv(test_path)questionCounter=0
correct=0
promptEnding = "[/INST]"# this must be >= 2
fail_limit=10# chain of thought activator, model might run out of output tokens
USE_COT=True#this comes before the question
testGuide='Answer the following question, at the end of your response write the answer like this: Answer:a or Answer:b or Answer:c or Answer:d \n'for index, row in df_test.iterrows():print("#############################")questionCounter = questionCounter + 1#chain of thought activatorif USE_COT:chainOfThoughtActivator='\nfirst think step by step\n'else:chainOfThoughtActivator='\n'#build the promptquestion=testGuide + row['Question'] + '\na)' + row['a'] + '\nb)' + row['b'] + '\nc)' + row['c'] + '\nd)' + row['d'] + chainOfThoughtActivatorprint(question)#true answertruth=row['Answer']#use a loop, if llm stopped before reaching the answer. ask againindex=-1failCounter=0while(index==-1):#build the promptprompt = build_prompt(question)#generate answerresult = pipe(prompt)llmAnswer = result[0]['generated_text']#remove our prompt from itindex = llmAnswer.find(promptEnding)llmAnswer = llmAnswer[len(promptEnding)+index:]print("LLM Answer:")print(llmAnswer)#remove spacesllmAnswer=llmAnswer.replace(' ','')#find the option in responseindex = llmAnswer.find('Answer:')#edge case - llm stoped at the worst timeif(index+len('Answer:')==len(llmAnswer)):index=-1#update question for the next try. remove chain of thoughtquestion=testGuide + row['Question'] + '\na)' + row['a'] + '\nb)' + row['b'] + '\nc)' + row['c'] + '\nd)' + row['d']#Don't get stock on a questionfailCounter=failCounter+1if failCounter==fail_limit:breakif failCounter==fail_limit:continue#find and match the optionnext_char = llmAnswer[index+len('Answer:'):][0]if next_char in truth:correct=correct+1print('correct')else:print('wrong')#update accuracyaccuracy=correct/questionCounterprint(f"Progress: {questionCounter/len(df_test)}")print(f"Accuracy: {accuracy}")
我将在下一节中揭示最终的准确性,但在此之前,我们还将在我们的测试数据集上测试基础 Mistral 7b 指令和 GPT-4,并计算它们的准确性。测试基础 Mitral 很简单,您只需加载基础模型并运行前面的代码,但测试 GPT-4 有点不同。
以下是使用 Openai API 的方式:
from openai import OpenAI
client = OpenAI(api_key=API_KEY)
def generate_response(system_message,prompt):completion = client.chat.completions.create(model="gpt-4",messages=[{"role": "system", "content": system_message},{"role": "user", "content": prompt}],)return completion.choices[0].message.contentgenerate_response("you are an assistant","Hello, how are you?")
其余的基本和以前一样,你可以在这里看到测试GPT-4的笔记本。
结果
我用这个简单的公式计算了 Gpt-4、基础 Mistral 和我们微调的 Mistral 在测试数据集上的精度:
准确性 = 正确答案 / 全部
这些是准确性:
- GPT-4 达到 59% 的准确率
- Base Mistral 7b 指令达到 48% 的准确率
- 我们微调的 Mistral 7b 指令达到了 70% 的准确率
这些问题有 4 个选项,因此只需随机回答它们即可获得 25% 的准确率。基础模型获得超过 25% 的事实意味着有些问题可以用逻辑来回答,但尽管如此,我们新的微调模型在关于 Enlighten 代码库的问题上超过了其基础模型和 GPT-4。
将基本模型与适配器合并
工作已经完成,我们对 Mistral 进行了微调并获得了适配器,但我们可以更进一步,将适配器与基本模型合并并获得我们自己的模型。
我们将在 Kaggle 上的另一个笔记本中执行此操作,因为我无法让它在 Colab 上工作。
首先,我们将基本模型加载到 16 位浮点数中(new_model更改为拥抱脸适配器存储库 ID)。
base_model = "mistralai/Mistral-7B-Instruct-v0.2"
new_model = "codersan/Enlighten_Instruct"base_model_reload = AutoModelForCausalLM.from_pretrained(base_model,torch_dtype=torch.bfloat16,return_dict=True,low_cpu_mem_usage=True,device_map="auto",trust_remote_code=True,
)
然后我们加载适配器并将其与基本模型合并。
# merge adopter with base model
model = PeftModel.from_pretrained(base_model_reload, new_model)
model = model.merge_and_unload()# Reload tokenizer
tokenizer = AutoTokenizer.from_pretrained(base_model, trust_remote_code=True)
tokenizer.pad_token = tokenizer.eos_token
tokenizer.padding_side = "right"
我们在上一个笔记本中没有执行 merge_and_unload() 的原因是,目前,当以 4 位加载基本模型时,您无法将适配器与基本模型合并,因此这里我们以 16 位加载基本模型。
如果您想知道为什么我们以前没有这样做,那是因为以 16 位加载基本模型很好,但由于 RAM 要求,无法进行微调。
最后,我们将新模型推到了 Hugging Face。
#push the model to hub
hf_name=new_model+'_merged'
model.push_to_hub(hf_name)
tokenizer.push_to_hub(hf_name)
现在我们有了自己的基本模型,它就像基本的Mistral模型一样巨大。这是我的结果。
量化和 GGUF 格式
量化使模型LLM更轻、更快,但缺点是准确性稍差,本文的最后一步是将我们的模型量化为 5 位和 4 位精度。
这是使用 TheBloke 的脚本完成的,用于 ggml 转换和量化。我不会回顾代码,但您可以在此处使用我的笔记本版本。
只需将第一行更改为您自己的 Hugging Face 存储库 ID。
repo_id = "codersan/Enlighten_Instruct_merged"
最后,您将获得两个 .gguf 文件,一个用于 4 位,一个用于 5 位。借助 LM Studio 等开源软件,在此 GGUF 模型上运行推理要容易得多,您可以下载它并使用加载了我们模型的聊天界面。
GitHub - ali7919/Enlighten-Instruct: Fine-tune Mistral-7b on the Enlighten codebase