扩展说明: 指令微调 Llama 2

这篇博客是一篇来自 Meta AI,关于指令微调 Llama 2 的扩展说明。旨在聚焦构建指令数据集,有了它,我们则可以使用自己的指令来微调 Llama 2 基础模型。

目标是构建一个能够基于输入内容来生成指令的模型。这么做背后的逻辑是,模型如此就可以由其他人生成自己的指令数据集。这在当想开发私人个性化定制模型,如发送推特、写邮件等,时很方便。这也意味着你可以通过你的邮件来生成一个指令数据集,然后用它来训练一个模型来为你写邮件。

好,那我们来开始吧?我们将进行:

  1. 定义应用场景细节并创建指令的提示词模板

  2. 构建指令数据集

  3. 使用 trlSFTTrainer 指令微调 Llama 2

  4. 测试模型、进行推理

1. 定义应用场景细节并创建指令的提示词模板

在描述应用场景前,我们要更好的理解一下究竟什么是指令。

指令是一段文本或提供给大语言模型,类似 Llama,GPT-4 或 Claude,使用的提示词,用来指导它去生成回复。指令可以让人们做到把控对话,约束模型输出更自然、实用的输出,并使这些结果能够对齐用户的目的。制作清晰的、整洁的指令则是生成高质量对话的关键。

指令的例子如下表所示。

能力示例指令
头脑风暴提供一系列新口味的冰淇淋的创意。
分类根据剧情概要,将这些电影归类为喜剧、戏剧或恐怖片。
确定性问答用一个单词回答“法国的首都是哪里?”
生成用罗伯特·弗罗斯特的风格写一首关于大自然和季节变化的诗。
信息提取从这篇短文中提取主要人物的名字。
开放性问答为什么树叶在秋天会变色?用科学的理由解释一下。
摘要用 2-3 句话概括一下这篇关于可再生能源最新进展的文章。

如开头所述,我们想要微调模型,以便根据输入 (或输出) 生成指令。我们希望将其用作创建合成数据集的方法,以赋予 LLM 和代理个性化能力。

把这个想法转换成一个基础的提示模板,按照 Alpaca 格式.

### Instruction:
Use the Input below to create an instruction, which could have been used to generate the input using an LLM. ### Input:
Dear [boss name],I'm writing to request next week, August 1st through August 4th,
off as paid time off.I have some personal matters to attend to that week that require 
me to be out of the office. I wanted to give you as much advance 
notice as possible so you can plan accordingly while I am away.Please let me know if you need any additional information from me 
or have any concerns with me taking next week off. I appreciate you 
considering this request.Thank you, [Your name]### Response:
Write an email to my boss that I need next week 08/01 - 08/04 off.

2. 创建指令数据集

在定义了我们的应用场景和提示模板后,我们需要创建自己的指令数据集。创建高质量的指令数据集是获得良好模型性能的关键。研究表明,“对齐,越少越好” 表明,创建高质量、低数量 (大约 1000 个样本) 的数据集可以达到与低质量、高数量的数据集相同的性能。

创建指令数据集有几种方法,包括:

  1. 使用现有数据集并将其转换为指令数据集,例如 FLAN

  2. 使用现有的 LLM 创建合成指令数据集,例如 Alpaca

  3. 人力创建指令数据集,例如 Dolly。

每种方法都有其优缺点,这取决于预算、时间和质量要求。例如,使用现有数据集是最简单的,但可能不适合您的特定用例,而使用人力可能是最准确的,但必然耗时、昂贵。也可以结合几种不同方法来创建指令数据集,如 Orca: Progressive Learning from Complex Explanation Traces of GPT-4.。

为了简单起见,我们将使用 **Dolly**,这是一个开源的指令跟踪记录数据集,由数千名 Databricks 员工在 InstructGPT paper 中描述的几个行为类别中生成,包括头脑风暴、分类、确定性回答、生成、信息提取、开放性回答和摘要。

开始编程吧,首先,我们来安装依赖项。

!pip install "transformers==4.31.0" "datasets==2.13.0" "peft==0.4.0" "accelerate==0.21.0" "bitsandbytes==0.40.2" "trl==0.4.7" "safetensors>=0.3.1" --upgrade

我们使用 🤗 Datasets library 的 load_dataset() 方法加载 databricks/databricks-dolly-15k 数据集。

from datasets import load_dataset
from random import randrange# 从hub加载数据集
dataset = load_dataset("databricks/databricks-dolly-15k", split="train")print(f"dataset size: {len(dataset)}")
print(dataset[randrange(len(dataset))])
# dataset size: 15011

为了指导我们的模型,我们需要将我们的结构化示例转换为通过指令描述的任务集合。我们定义一个 formatting_function ,它接受一个样本并返回一个符合格式指令的字符串。

def format_instruction(sample):return f"""### Instruction:
Use the Input below to create an instruction, which could have been used to generate the input using an LLM. ### Input:
{sample['response']}### Response:
{sample['instruction']}
"""

我们来在一个随机的例子上测试一下我们的结构化函数。

from random import randrangeprint(format_instruction(dataset[randrange(len(dataset))]))

3. 使用 trlSFTTrainer 指令微调 Llama 2

我们将使用最近在由 Tim Dettmers 等人的发表的论文“QLoRA: Quantization-aware Low-Rank Adapter Tuning for Language Generation”中介绍的方法。QLoRA 是一种新的技术,用于在微调期间减少大型语言模型的内存占用,且并不会降低性能。QLoRA 的 TL;DR; 是这样工作的:

  • 将预训练模型量化为 4bit 位并冻结它。

  • 附加轻量化的、可训练的适配器层。(LoRA)

  • 在使用冻结的量化模型基于文本内容进行微调时,仅微调适配器层参数。

如果您想了解有关 QLoRA 及其工作原理的更多信息,我建议您阅读 Making LLMs even more accessible with bitsandbytes, 4-bit quantization and QLoRA 博客文章。

Flash Attention (快速注意力)

Flash Attention 是一种经过重新排序的注意力计算方法,它利用经典技术 (排列、重计算) 来显著加快速度,将序列长度的内存使用量从二次降低到线性。它基于论文“FlashAttention: Fast and Memory-Efficient Exact Attention with IO-Awareness”。

TL;DR; 将训练加速了 3 倍。在这儿获得更多信息 FlashAttention。Flash Attention 目前仅支持 Ampere (A10, A40, A100, …) & Hopper (H100, …) GPU。你可以检查一下你的 GPU 是否支持,并用下面的命令来安装它:

注意: 如果您的机器的内存小于 96GB,而 CPU 核心数足够多,请减少 MAX_JOBS 的数量。在我们使用的 g5.2xlarge 上,我们使用了 4

python -c "import torch; assert torch.cuda.get_device_capability()[0] >= 8, 'Hardware not supported for Flash Attention'"
pip install ninja packaging
MAX_JOBS=4 pip install flash-attn --no-build-isolation

_安装 flash attention 是会需要一些时间 (10-45 分钟)_。

该示例支持对所有 Llama 检查点使用 Flash Attention,但默认是未启用的。要开启 Flash Attention,请取消代码块中这段的注释, # COMMENT IN TO USE FLASH ATTENTION

import torch
from transformers import AutoTokenizer, AutoModelForCausalLM, BitsAndBytesConfiguse_flash_attention = False# COMMENT IN TO USE FLASH ATTENTION
# replace attention with flash attention 
# if torch.cuda.get_device_capability()[0] >= 8:
#     from utils.llama_patch import replace_attn_with_flash_attn
#     print("Using flash attention")
#     replace_attn_with_flash_attn()
#     use_flash_attention = True# Hugging Face 模型id
model_id = "NousResearch/Llama-2-7b-hf" # non-gated
# model_id = "meta-llama/Llama-2-7b-hf" # gated# BitsAndBytesConfig int-4 config 
bnb_config = BitsAndBytesConfig(load_in_4bit=True,bnb_4bit_use_double_quant=True,bnb_4bit_quant_type="nf4",bnb_4bit_compute_dtype=torch.bfloat16
)# 加载模型与分词器
model = AutoModelForCausalLM.from_pretrained(model_id, quantization_config=bnb_config, use_cache=False, device_map="auto")
model.config.pretraining_tp = 1 # 通过对比doc中的字符串,验证模型是在使用flash attention
if use_flash_attention:from utils.llama_patch import forward    assert model.model.layers[0].self_attn.forward.__doc__ == forward.__doc__, "Model is not using flash attention"tokenizer = AutoTokenizer.from_pretrained(model_id)
tokenizer.pad_token = tokenizer.eos_token
tokenizer.padding_side = "right"

SFTTrainer 支持与 peft 的本地集成,这使得高效地指令微调LLM变得非常容易。我们只需要创建 LoRAConfig 并将其提供给训练器。

from peft import LoraConfig, prepare_model_for_kbit_training, get_peft_model# 基于 QLoRA 论文来配置 LoRA
peft_config = LoraConfig(lora_alpha=16,lora_dropout=0.1,r=64,bias="none",task_type="CAUSAL_LM", 
)# 为训练准备好模型
model = prepare_model_for_kbit_training(model)
model = get_peft_model(model, peft_config)

在开始训练之前,我们需要定义自己想要的超参数 (TrainingArguments)。

from transformers import TrainingArgumentsargs = TrainingArguments(output_dir="llama-7-int4-dolly",num_train_epochs=3,per_device_train_batch_size=6 if use_flash_attention else 4,gradient_accumulation_steps=2,gradient_checkpointing=True,optim="paged_adamw_32bit",logging_steps=10,save_strategy="epoch",learning_rate=2e-4,bf16=True,tf32=True,max_grad_norm=0.3,warmup_ratio=0.03,lr_scheduler_type="constant",disable_tqdm=True # 当配置的参数都正确后可以关闭tqdm
)

我们现在有了用来训练模型 SFTTrainer 所需要准备的每一个模块。

from trl import SFTTrainermax_seq_length = 2048 # 数据集的最大长度序列trainer = SFTTrainer(model=model,train_dataset=dataset,peft_config=peft_config,max_seq_length=max_seq_length,tokenizer=tokenizer,packing=True,formatting_func=format_instruction, args=args,
)

通过调用 Trainer 实例上的 train() 方法来训练我们的模型。

# 训练
trainer.train() # tqdm关闭后将不显示进度条信息# 保存模型
trainer.save_model()

不使用 Flash Attention 的训练过程在 g5.2xlarge 上花费了 03:08:00。实例的成本为 1,212$/h ,总成本为 3.7$

使用 Flash Attention 的训练过程在 g5.2xlarge 上花费了 02:08:00。实例的成本为 1,212$/h ,总成本为 2.6$

使用 Flash Attention 的结果令人满意,速度提高了 1.5 倍,成本降低了 30%。

4. 测试模型、进行推理

在训练完成后,我们想要运行和测试模型。我们会使用 pefttransformers 将 LoRA 适配器加载到模型中。

if use_flash_attention:# 停止 flash attentionfrom utils.llama_patch import unplace_flash_attn_with_attnunplace_flash_attn_with_attn()import torch
from peft import AutoPeftModelForCausalLM
from transformers import AutoTokenizerargs.output_dir = "llama-7-int4-dolly"# 加载基础LLM模型与分词器
model = AutoPeftModelForCausalLM.from_pretrained(args.output_dir,low_cpu_mem_usage=True,torch_dtype=torch.float16,load_in_4bit=True,
) 
tokenizer = AutoTokenizer.from_pretrained(args.output_dir)

我们来再次用随机样本加载一次数据集,试着来生成一条指令。

from datasets import load_dataset 
from random import randrange# 从hub加载数据集并得到一个样本
dataset = load_dataset("databricks/databricks-dolly-15k", split="train")
sample = dataset[randrange(len(dataset))]prompt = f"""### Instruction:
Use the Input below to create an instruction, which could have been used to generate the input using an LLM. ### Input:
{sample['response']}### Response:
"""input_ids = tokenizer(prompt, return_tensors="pt", truncation=True).input_ids.cuda()
# with torch.inference_mode():
outputs = model.generate(input_ids=input_ids, max_new_tokens=100, do_sample=True, top_p=0.9,temperature=0.9)print(f"Prompt:\n{sample['response']}\n")
print(f"Generated instruction:\n{tokenizer.batch_decode(outputs.detach().cpu().numpy(), skip_special_tokens=True)[0][len(prompt):]}")
print(f"Ground truth:\n{sample['instruction']}")

太好了!我们的模型可以工作了!如果想要加速我们的模型,我们可以使用 Text Generation Inference 部署它。因此我们需要将我们适配器的参数合并到基础模型中去。

from peft import AutoPeftModelForCausalLMmodel = AutoPeftModelForCausalLM.from_pretrained(args.output_dir,low_cpu_mem_usage=True,
) # 合并 LoRA 与 base model
merged_model = model.merge_and_unload()# 保存合并后的模型
merged_model.save_pretrained("merged_model",safe_serialization=True)
tokenizer.save_pretrained("merged_model")# push合并的模型到hub上
# merged_model.push_to_hub("user/repo")
# tokenizer.push_to_hub("user/repo")

原文作者: Philschmid

原文链接: https://www.philschmid.de/instruction-tune-llama-2

译者: Xu Haoran

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.mzph.cn/news/675314.shtml

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

基于OpenCV灰度图像转GCode的斜向扫描实现

基于OpenCV灰度图像转GCode的斜向扫描实现基于OpenCV灰度图像转GCode的斜向扫描实现 引言激光雕刻简介OpenCV简介实现步骤 1.导入必要的库2. 读取灰度图像3. 图像预处理4. 生成GCode5. 保存生成的GCode6. 灰度图像斜向扫描代码示例 总结 系列文章 ⭐深入理解G0和G1指令&…

Python 深入理解 os 和 sys 模块

Python 深入理解 os 和 sys 模块 OS 介绍代码智能连接(拼接)路径创建目录展示(列出目录)删除文件重命名文件或目录 sys 介绍代码命令行参数处理 (sys.argv)标准输入输出重定向 (sys.stdin, sys.stdout, sys.stderr):解…

数据结构 - 线索树

一、 为什么要用到线索二叉树? 我们先来看看普通的二叉树有什么缺点。下面是一个普通二叉树(链式存储方式): 乍一看,会不会有一种违和感?整个结构一共有 7 个结点,总共 14 个指针域&#xff0c…

WordPress函数wptexturize的介绍及用法示例,字符串替换为HTML实体

在查看WordPress你好多莉插件时发现代码中使用了wptexturize()函数用来随机输出一句歌词,下面boke112百科就跟大家一起来学习一下WordPress函数wptexturize的介绍及用法示例。 WordPress函数wptexturize介绍 wptexturize( string $text, bool $reset false ): st…

Ubuntu搭建计算集群

计算机硬件和技术的发展使得高性能模拟和计算在生活和工作中的作用逐渐显现出来,无论是计算化学,计算物理和当下的人工智能都离不开高性能计算。笔者工作主要围绕计算化学和物理开展,亦受限于自身知识和技术所限,文中只是浅显地尝…

HarmonyOS class类对象基础使用

按我们之前的写法 就是 Entry Component struct Dom {p:Object {name: "小猫猫",age: 21,gf: {name: "小小猫猫",age: 18,}}build() {Row() {Column() {// ts-ignoreText(this.p.gf.name)}.width(100%)}.height(100%)} }直接用 Object 一层一层往里套 这…

C++进阶(十三)异常

📘北尘_:个人主页 🌎个人专栏:《Linux操作系统》《经典算法试题 》《C》 《数据结构与算法》 ☀️走在路上,不忘来时的初心 文章目录 一、C语言传统的处理错误的方式二、C异常概念三、异常的使用1、异常的抛出和捕获2、异常的重新…

java之spring AOP

AOP 面向切面编程&#xff0c; 切入点&#xff0c;就是你写的函数&#xff0c;装饰器&#xff0c;装饰到那些函数上 在哪里生效 引入依赖<dependency><groupId>org.springframework.boot</groupId><artifactId>spring-boot-starter-aop</artifact…

网络学习:数据链路层VLAN原理和配置

一、简介&#xff1a; VLAN又称为虚拟局域网&#xff0c;它是用来将使用路由器的网络分割成多个虚拟局域网&#xff0c;起到隔离广播域的作用&#xff0c;一个VLAN通常对应一个IP网段&#xff0c;不同VLAN通常规划到不同IP网段。划分VLAN可以提高网络的通讯质量和安全性。 二、…

MySQL进阶查询篇(5)-事务的隔离级别与应用

数据库事务(Transaction)是指作为一个单元执行的一系列操作&#xff0c;要么全部成功完成&#xff0c;要么全部失败回滚。数据库事务具有四个特性&#xff0c;即原子性(Atomicity)、一致性(Consistency)、隔离性(Isolation)和持久性(Durability)。本文将重点介绍MySQL数据库中的…

2024年华为OD机试真题-螺旋数字矩阵-Java-OD统一考试(C卷)

题目描述: 疫情期间,小明隔离在家,百无聊赖,在纸上写数字玩。他发明了一种写法: 给出数字个数n和行数m(0 < n ≤ 999,0 < m ≤ 999),从左上角的1开始,按照顺时针螺旋向内写方式,依次写出2,3...n,最终形成一个m行矩阵。 小明对这个矩阵有些要求: 1.每行数字的…

跟着小德学C++之TCP基础

嗨&#xff0c;大家好&#xff0c;我是出生在达纳苏斯的一名德鲁伊&#xff0c;我是要立志成为海贼王&#xff0c;啊不&#xff0c;是立志成为科学家的德鲁伊。最近&#xff0c;我发现我们所处的世界是一个虚拟的世界&#xff0c;并由此开始&#xff0c;我展开了对我们这个世界…

红队打靶练习:GLASGOW SMILE: 1.1

目录 信息收集 1、arp 2、nmap 3、nikto 4、whatweb 目录探测 1、gobuster 2、dirsearch WEB web信息收集 /how_to.txt /joomla CMS利用 1、爆破后台 2、登录 3、反弹shell 提权 系统信息收集 rob用户登录 abner用户 penguin用户 get root flag 信息收集…

flutter 国内源

Flutter 在中国由于网络原因&#xff0c;从官方默认的国外源下载Dart包和Flutter SDK可能会比较慢或者不稳定。为了加速依赖包的获取与Flutter SDK的安装&#xff0c;可以使用国内镜像源。以下是一些国内常用的Flutter和Dart包镜像源&#xff1a; 清华大学开源软件镜像站 Flu…

计算机网络(第六版)复习提纲29

第六章&#xff1a;应用层 SS6.1 域名系统DNS 1 DNS被设计为一个联机分布式数据库系统&#xff0c;并采用客户服务器方式&#xff08;C/S&#xff09; 2 域名的体系结构 3 域名服务器及其体系结构 A 域名服务器的分类 1 根域名服务器 2 顶级域名服务器&#xff08;TLD服务器&a…

Gitlab和Jenkins集成 实现CI (一)

版本声明 部署时通过docker拉取的最新版本 gitlab: 16.8 jenkins: 2.426.3 安装环境 可参考这篇文章 停止防火墙 由于在内网&#xff0c;这里防火墙彻底关掉&#xff0c;如果再外网或者云上的悠着点 systemctl stop firewalled systemctl disable firewalledsystemctl sto…

K8S之运用亲和性设置Pod的调度约束

亲和性 Node节点亲和性硬亲和实践软亲和性实践 Pod节点亲和性和反亲和性pod亲和性硬亲和实践 pod反亲和性 Pod 的yaml文件里 spec 字段中包含一个 affinity 字段&#xff0c;使用一组亲和性调度规则&#xff0c;指定pod的调度约束。 kubectl explain pods.spec.affinity 配置…

【代码】Processing笔触手写板笔刷代码合集

代码来源于openprocessing&#xff0c;考虑到国内不是很好访问&#xff0c;我把我找到的比较好的搬运过来&#xff01; 合集 参考&#xff1a;https://openprocessing.org/sketch/793375 https://github.com/SourceOf0-HTML/processing-p5.js/tree/master 这个可以体验6种笔触…

ubuntu22.04安装部署03: 设置root密码

一、前言 ubuntu22.04 安装完成以后&#xff0c;默认root用户是没有设置密码的&#xff0c;需要手动设置。具体的设置过程如下文内容所示&#xff1a; 相关文件&#xff1a; 《ubuntu22.04装部署01&#xff1a;禁用内核更新》 《ubuntu22.04装部署02&#xff1a;禁用显卡更…

简单聊聊go语言中引用模块的版本控制以及invalid: should be v0 or v1, not v2问题的解决

文章目录 前言具体示例手动升级依赖库should be v0 or v1, not v2总结 前言 如果你接触go语言比较早&#xff0c;一定有过当年所有go源码全部放入 GOPATH 下的混乱经历&#xff0c;不过发展到今天&#xff0c;go的包管理使用 go.mod 和 go.work 已经能得心应手&#xff0c;满足…