LLM微调(一)| 单GPU使用QLoRA微调Llama 2.0实战

      最近LLaMA 2在LLaMA1 的基础上做了很多优化,比如上下文从2048扩展到4096,使用了Grouped-Query Attention(GQA)共享多头注意力的key 和value矩阵,具体可以参考:

关于LLaMA 2 的细节,可以参考如下文章:

Meta发布升级大模型LLaMA 2:开源可商用

揭秘最领先的Llama2中文大模型!

使用QLoRA微调LLaMA 2

安装环境

pip install transformers datasets peft accelerate bitsandbytes safetensors

导入库

import os, sysimport torchimport datasetsfrom transformers import (    AutoTokenizer,    AutoModelForCausalLM,    BitsAndBytesConfig,    DataCollatorForLanguageModeling,    DataCollatorForSeq2Seq,    Trainer,    TrainingArguments,    GenerationConfig)from peft import PeftModel, LoraConfig, prepare_model_for_kbit_training, get_peft_model

导入LLaMA 2模型

### config ###model_id = "NousResearch/Llama-2-7b-hf" # optional meta-llama/Llama-2–7b-chat-hfmax_length = 512device_map = "auto"batch_size = 128micro_batch_size = 32gradient_accumulation_steps = batch_size // micro_batch_size# nf4" use a symmetric quantization scheme with 4 bits precisionbnb_config = BitsAndBytesConfig(    load_in_4bit=True,    bnb_4bit_use_double_quant=True,    bnb_4bit_quant_type="nf4",    bnb_4bit_compute_dtype=torch.bfloat16)# load model from huggingfacemodel = AutoModelForCausalLM.from_pretrained(    model_id,    quantization_config=bnb_config,    use_cache=False,    device_map=device_map)# load tokenizer from huggingfacetokenizer = AutoTokenizer.from_pretrained(model_id)tokenizer.pad_token = tokenizer.eos_tokentokenizer.padding_side = "right"

输出模型的可训练参数量

def print_number_of_trainable_model_parameters(model):    trainable_model_params = 0    all_model_params = 0    for _, param in model.named_parameters():        all_model_params += param.numel()        if param.requires_grad:            trainable_model_params += param.numel()    print(f"trainable model parameters: {trainable_model_params}. All model parameters: {all_model_params} ")    return trainable_model_paramsori_p = print_number_of_trainable_model_parameters(model)# 输出# trainable model parameter: 262,410,240

配置LoRA参数

# LoRA configmodel = prepare_model_for_kbit_training(model)peft_config = LoraConfig(    r=8,    lora_alpha=32,    lora_dropout=0.1,    target_modules=["q_proj", "v_proj"],    bias="none",    task_type="CAUSAL_LM",)model = get_peft_model(model, peft_config)### compare trainable parameters #peft_p = print_number_of_trainable_model_parameters(model)print(f"# Trainable Parameter \nBefore: {ori_p} \nAfter: {peft_p} \nPercentage: {round(peft_p / ori_p * 100, 2)}")# 输出# trainable model parameter: 4,194,304

r:更新矩阵的秩,也称为Lora注意力维度。较低的秩导致具有较少可训练参数的较小更新矩阵。增加r(不超过32)将导致更健壮的模型,但同时会导致更高的内存消耗。

lora_lpha:控制lora比例因子

target_modules:是一个模块名称列表,如“q_proj”和“v_proj“,用作LoRA模型的目标。具体的模块名称可能因基础模型而异。

bias:指定是否应训练bias参数。可选参数为:“none”、“all”或“lora_only”。

输出LoRA Adapter的参数,发现只占原模型的不到2%。

在微调LLaMA 2之前,我们看一下LLaMA 2的生成效果

### generate ###prompt = "Write me a poem about Singapore."inputs = tokenizer(prompt, return_tensors="pt")generate_ids = model.generate(inputs.input_ids, max_length=64)print('\nAnswer: ', tokenizer.decode(generate_ids[0]))res = tokenizer.batch_decode(generate_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False)[0]print(res)

       当要求模型写一首关于新加坡的诗时,产生的输出似乎相当模糊和重复,这表明模型很难提供连贯和有意义的回应。

微调数据加载

为了方便演示,我们使用开源的databricks/databricks-dolly-15k,数据格式如下:

{    'instruction': 'Why can camels survive for long without water?',    'context': '',    'response': 'Camels use the fat in their humps to keep them filled with energy and hydration for long periods of time.',    'category': 'open_qa',}

       要揭秘LLM能力,构建Prompt是至关重要,通常的Prompt形式有三个字段:Instruction、Input(optional)、Response。由于Input是可选的,因为这里设置了两种prompt_template,分别是有Input 的prompt_input和无Input 的prompt_no_input,代码如下:

max_length = 256dataset = datasets.load_dataset(    "databricks/databricks-dolly-15k", split='train')### generate prompt based on template ###prompt_template = {    "prompt_input": \    "Below is an instruction that describes a task, paired with an input that provides further context.\    Write a response that appropriately completes the request.\    \n\n### Instruction:\n{instruction}\n\n### Input:\n{input}\n\n### Response:\n",    "prompt_no_input": \    "Below is an instruction that describes a task.\    Write a response that appropriately completes the request.\    \n\n### Instruction:\n{instruction}\n\n### Response:\n",    "response_split": "### Response:"}def generate_prompt(instruction, input=None, label=None, prompt_template=prompt_template):    if input:        res = prompt_template["prompt_input"].format(            instruction=instruction, input=input)    else:        res = prompt_template["prompt_no_input"].format(            instruction=instruction)    if label:        res = f"{res}{label}"    return res

      使用generate_prompt函数把instruction, context和response拼接起来;然后进行tokenize分词处理,转换为input_ids和attention_mask,为了让模型可以预测下一个token,设计了类似input_ids的labels便于右移操作;

def tokenize(tokenizer, prompt, max_length=max_length, add_eos_token=False):    result = tokenizer(        prompt,        truncation=True,        max_length=max_length,        padding=False,        return_tensors=None)    result["labels"] = result["input_ids"].copy()    return resultdef generate_and_tokenize_prompt(data_point):    full_prompt = generate_prompt(        data_point["instruction"],        data_point["context"],        data_point["response"],    )    tokenized_full_prompt = tokenize(tokenizer, full_prompt)    user_prompt = generate_prompt(data_point["instruction"], data_point["context"])    tokenized_user_prompt = tokenize(tokenizer, user_prompt)    user_prompt_len = len(tokenized_user_prompt["input_ids"])    mask_token = [-100] * user_prompt_len    tokenized_full_prompt["labels"] = mask_token + tokenized_full_prompt["labels"][user_prompt_len:]    return tokenized_full_promptdataset = dataset.train_test_split(test_size=1000, shuffle=True, seed=42)cols = ["instruction", "context", "response", "category"]train_data = dataset["train"].shuffle().map(generate_and_tokenize_prompt, remove_columns=cols)val_data = dataset["test"].shuffle().map(generate_and_tokenize_prompt, remove_columns=cols,)

模型训练

args = TrainingArguments(    output_dir="./llama-7b-int4-dolly",    num_train_epochs=20,    max_steps=200,    fp16=True,    optim="paged_adamw_32bit",    learning_rate=2e-4,    lr_scheduler_type="constant",    per_device_train_batch_size=micro_batch_size,    gradient_accumulation_steps=gradient_accumulation_steps,    gradient_checkpointing=True,    group_by_length=False,    logging_steps=10,    save_strategy="epoch",    save_total_limit=3,    disable_tqdm=False,)trainer = Trainer(    model=model,    train_dataset=train_data,    eval_dataset=val_data,    args=args,    data_collator=DataCollatorForSeq2Seq(      tokenizer, pad_to_multiple_of=8, return_tensors="pt", padding=True),)# silence the warnings. re-enable for inference!model.config.use_cache = Falsetrainer.train()model.save_pretrained("llama-7b-int4-dolly")

模型测试

       模型训练几个小时结束后,我们合并预训练模型Llama-2–7b-hf和LoRA参数,我们还是以“Write me a poem about Singapore”测试效果,代码如下:

# model path and weightmodel_id = "NousResearch/Llama-2-7b-hf"peft_path = "./llama-7b-int4-dolly"# loading modelmodel = AutoModelForCausalLM.from_pretrained(    model_id,    quantization_config=bnb_config,    use_cache=False,    device_map="auto")# loading peft weightmodel = PeftModel.from_pretrained(    model,    peft_path,    torch_dtype=torch.float16,)model.eval()# generation configgeneration_config = GenerationConfig(    temperature=0.1,    top_p=0.75,    top_k=40,    num_beams=4, # beam search)# generating replywith torch.no_grad():    prompt = "Write me a poem about Singapore."    inputs = tokenizer(prompt, return_tensors="pt")    generation_output = model.generate(        input_ids=inputs.input_ids,        generation_config=generation_config,        return_dict_in_generate=True,        output_scores=True,        max_new_tokens=64,    )    print('\nAnswer: ', tokenizer.decode(generation_output.sequences[0]))

生成模型中参数temperaturetop-ktop-pnum_beam含义可以参考:https://github.com/ArronAI007/Awesome-AGI/blob/main/LLM%E4%B9%8BGenerate%E4%B8%AD%E5%8F%82%E6%95%B0%E8%A7%A3%E8%AF%BB.ipynb

参考文献:

[1] https://ai.plainenglish.io/fine-tuning-llama2-0-with-qloras-single-gpu-magic-1b6a6679d436

[2] https://github.com/ChanCheeKean/DataScience/blob/main/13%20-%20NLP/E04%20-%20Parameter%20Efficient%20Fine%20Tuning%20(PEFT).ipynb

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.mzph.cn/news/83470.shtml

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

zotero通过DOI快速导入文献

之前我经常采用两种方式导入文献: (1)下载PDF,然后拖入zotero 这种方法比较费时间,有些文献无法下载pdf (2)通过google scholar检索文献,然后点击引用——EndNote,chorme…

Kotlin中函数的基本用法以及函数类型

函数的基本用法 1、函数的基本格式 2、函数的缺省值 可以为函数设置指定的初始值&#xff0c;而不必要传入值 private fun fix(name: String,age: Int 2){println(name age) }fun main(args: Array<String>) {fix("张三") }输出结果为&#xff1a;张三2 …

WebGL层次模型——多节点模型

目录 多节点模型 MultiJointModel中的层次结构 控制各部件旋转角度的变量 示例程序——共用顶点数据&#xff0c;通过模型矩阵缩放实现&#xff08;MultiJointModel.js&#xff09; MultiJointModel.js&#xff08;按键响应部分&#xff09; MultiJointModel.js&#x…

刷题日记——将x减到0的最小操作数

将x减到0的最小操作数 题目链接&#xff1a;https://leetcode.cn/problems/minimum-operations-to-reduce-x-to-zero/ 题目解读 题目要求移除元素总和等于参数x&#xff0c;这道题给我的第一感觉就是从数组的两边入手&#xff0c;对数据进行加和删除&#xff0c;但是这里有一…

滚雪球学Java(24):Java反射

&#x1f3c6;本文收录于「滚雪球学Java」专栏&#xff0c;专业攻坚指数级提升&#xff0c;助你一臂之力&#xff0c;带你早日登顶&#x1f680;&#xff0c;欢迎大家关注&&收藏&#xff01;持续更新中&#xff0c;up&#xff01;up&#xff01;up&#xff01;&#xf…

EasySwipeMenuLayout - 独立的侧滑删除

官网 GitHub - anzaizai/EasySwipeMenuLayout: A sliding menu library not just for recyclerview, but all views. 项目介绍 A sliding menu library not just for recyclerview, but all views. Recommended in conjunction with BaseRecyclerViewAdapterHelper Feature…

TS泛型的使用

函数中使用泛型&#xff1a; function identity<T>(arg: T): T {return arg; }let result identity<number>(10); // 传入number类型&#xff0c;返回number类型 console.log(result); // 输出: 10let value identity<string>(Hello); // 传入string类型&a…

ad18学习笔记十二:如何把同属性的元器件全部高亮?

1、先选择需要修改的器件的其中一个。 2、右键find similar objects&#xff0c;然后在弹出的对话框中&#xff0c;将要修改的属性后的any改为same 3、像这样勾选的话&#xff0c;能把同属性的元器件选中&#xff0c;其他器件颜色不变 注意了&#xff0c;如果这个时候&#xff…

初学phar反序列化

以下内容参考大佬博客&#xff1a;PHP Phar反序列化浅学习 - 跳跳糖 首先了解phar是什么东东 Phar是PHP的压缩文档&#xff0c;是PHP中类似于JAR的一种打包文件。它可以把多个文件存放至同一个文件中&#xff0c;无需解压&#xff0c;PHP就可以进行访问并执行内部语句。 默认开…

VuePress网站如何使用axios请求第三方接口

前言 VuePress是一个纯静态网站生成器,也就是它是无后端,纯前端的,那想要在VuePress中,发送ajax请求,请求一些第三方接口,有时想要达到自己一些目的 在VuePress中&#xff0c;使用axios请求第三方接口&#xff0c;需要先安装axios&#xff0c;然后引入&#xff0c;最后使用 本文…

爬虫框架Scrapy学习笔记-2

前言 Scrapy是一个功能强大的Python爬虫框架&#xff0c;它被广泛用于抓取和处理互联网上的数据。本文将介绍Scrapy框架的架构概览、工作流程、安装步骤以及一个示例爬虫的详细说明&#xff0c;旨在帮助初学者了解如何使用Scrapy来构建和运行自己的网络爬虫。 爬虫框架Scrapy学…

【Linux学习笔记】权限

1. 普通用户和root用户权限之间的切换2. 权限的三个w2.1. 什么是权限&#xff08;what&#xff09;2.1.1. 用户角色2.1.2. 文件属性 2.2. 怎么操作权限呢&#xff1f;&#xff08;how&#xff09;2.2.1. ugo-rwx方案2.2.2. 八进制方案2.2.3. 文件权限的初始模样2.2.4. 进入一个…

并发编程——synchronized

文章目录 原子性、有序性、可见性原子性有序性可见性 synchronized使用synchronized锁升级synchronized-ObjectMonitor 原子性、有序性、可见性 原子性 数据库事务的原子性&#xff1a;是一个最小的执行的单位&#xff0c;一次事务的多次操作要么都成功&#xff0c;要么都失败…

蓝桥杯 题库 简单 每日十题 day6

01 删除字符 题目描述 给定一个单词&#xff0c;请问在单词中删除t个字母后&#xff0c;能得到的字典序最小的单词是什么&#xff1f; 输入描述 输入的第一行包含一个单词&#xff0c;由大写英文字母组成。 第二行包含一个正整数t。 其中&#xff0c;单词长度不超过100&#x…

记录selenium和chrome使用socks代理打开网页以及查看selenium的版本

使用前&#xff0c;首先打开socks5全局代理。 之前我还写过一篇关于编程中使用到代理的情况&#xff1a; 记录一下python编程中需要使用代理的解决方法_python 使用全局代理_小小爬虾的博客-CSDN博客 在本文中&#xff0c;首先安装selenium和安装chrome浏览器。 参考我的文章…

用VS Code运行C语言(安装VS Code,mingw的下载和安装)

下载并安装VS code。 安装扩展包&#xff1a; 此时&#xff0c;写完代码右键之后并没有运行代码的选项&#xff0c;如图&#xff1a; 接下来安装编译器mingw。 下载链接&#xff1a; https://sourceforge.net/projects/mingw-w64/ 得到压缩包&#xff1a; 解压&#xff1a; …

滚雪球学Java(26):Java进制转换

&#x1f3c6;本文收录于「滚雪球学Java」专栏&#xff0c;专业攻坚指数级提升&#xff0c;助你一臂之力&#xff0c;带你早日登顶&#x1f680;&#xff0c;欢迎大家关注&&收藏&#xff01;持续更新中&#xff0c;up&#xff01;up&#xff01;up&#xff01;&#xf…

由于数字化转型对集成和扩展性的要求,定制化需求难以满足,百数低代码服务商该如何破局?

当政策、技术环境的日益成熟&#xff0c;数字化转型逐步成为企业发展的必选项&#xff0c;企业数字化转型不再是一道选择题&#xff0c;而是决定其生存发展的必由之路。通过数字化转型升级生产方式、管理模式和组织形式&#xff0c;激发内生动力&#xff0c;成为企业顺应时代变…

最新适合小白前端 Javascript 高级常见知识点详细教程(每周更新中)

1. window.onload 窗口或者页面的加载事件&#xff0c;当文档内容完全加载完成会触发的事件&#xff08;包括图形&#xff0c;JS脚本&#xff0c;CSS文件&#xff09;&#xff0c;就会调用处理的函数。 <button>点击</button> <script> btn document.q…

python项目2to3方案预研

目录 官方工具2to3工具安装参数解释基本使用工具缺陷 future工具安装参数解释基本使用工具缺陷 python-modernize工具安装参数解释基本使用工具缺陷 pyupgrade工具安装参数解释基本使用工具缺陷 对比 官方工具2to3 2to3 是Python官方提供的用于将Python 2代码转换为Python 3代…