LLM微调(二)| 微调LLAMA-2和其他开源LLM的两种简单方法

本文将介绍两种开源工具来微调LLAMA-2。

一、使用autotrain-advanced微调LLAMA-2

        AutoTrain是一种无代码工具,用于为自然语言处理(NLP)任务、计算机视觉(CV)任务、语音任务甚至表格任务训练最先进的模型。

1) 安装相关库,使用huggingface_hub下载微调数据

!pip install autotrain-advanced!pip install huggingface_hub

2) 更新autotrain-advanced所需要的包

# update torch!autotrain setup --update-torch

3) 登录Huggingface

# Login to huggingfacefrom huggingface_hub import notebook_loginnotebook_login()

4) 开始微调LLAMA-2

! autotrain llm \--train \--model {MODEL_NAME} \--project-name {PROJECT_NAME} \--data-path data/ \--text-column text \--lr {LEARNING_RATE} \--batch-size {BATCH_SIZE} \--epochs {NUM_EPOCHS} \--block-size {BLOCK_SIZE} \--warmup-ratio {WARMUP_RATIO} \--lora-r {LORA_R} \--lora-alpha {LORA_ALPHA} \--lora-dropout {LORA_DROPOUT} \--weight-decay {WEIGHT_DECAY} \--gradient-accumulation {GRADIENT_ACCUMULATION}

核心参数含义

llm: 微调模型的类型

— project_name: 项目名称

— model: 需要微调的基础模型

— data_path: 指定微调所需要的数据,可以使用huggingface上的数据集

— text_column: 如果数据是表格,需要指定instructions和responses对应的列名

— use_peft: 指定peft某一种方法

— use_int4: 指定int 4量化

— learning_rate: 学习率

— train_batch_size: 训练批次大小

— num_train_epochs: 训练轮数大小

— trainer: 指定训练的方式

— model_max_length: 设置模型最大上下文窗口

— push_to_hub(可选): 微调好的模型是否需要存储到Hugging Face? 

— repo_id: 如果要存储微调好的模型到Hugging Face,需要指定repository ID

— block_size: 设置文本块大小

下面看一个具体的示例:

!autotrain llm--train--project_name "llama2-autotrain-openassitant"--model TinyPixel/Llama-2-7B-bf16-sharded--data_path timdettmers/openassistant-guanaco--text_column text--use_peft--use_int4--learning_rate 0.4--train_batch_size 3--num_train_epochs 2--trainer sft--model_max_length 1048--push_to_hub--repo_id trojrobert/llama2-autotrain-openassistant--block_size 1048 > training.log

二、使用TRL微调LLAMA-2

       TRL是一个全栈库,提供了通过强化学习来训练transformer语言模型一系列工具,包括从监督微调步骤(SFT)、奖励建模步骤(RM)到近端策略优化(PPO)步骤。

1)安装相关的库

!pip install -q -U trl peft transformers  datasets bitsandbytes wandb

2)从Huggingface导入数据集

from datasets import load_datasetdataset_name = "timdettmers/openassistant-guanaco"dataset = load_dataset(dataset_name, split="train")

3)量化配置,从Huggingface下载模型

import torchfrom transformers import AutoModelForCausalLM, AutoTokenizer, BitsAndBytesConfig# quantizition configurationbnb_config = BitsAndBytesConfig(    load_in_4bit=True,    bnb_4bit_quant_type="nf4",    bnb_4bit_compute_dtype=torch.float16,)# download modelmodel_name = "TinyPixel/Llama-2-7B-bf16-sharded"model = AutoModelForCausalLM.from_pretrained(    model_name,    quantization_config=bnb_config,    trust_remote_code=True)model.config.use_cache = False

4)下载Tokenizer

tokenizer = AutoTokenizer.from_pretrained(model_name, trust_remote_code=True)tokenizer.pad_token = tokenizer.eos_token

5)创建PEFT配置

from peft import LoraConfig, get_peft_modellora_alpha = 16lora_dropout = 0.1lora_r = 64peft_config = LoraConfig(    lora_alpha=lora_alpha,    lora_dropout=lora_dropout,    r=lora_r,    bias="none",    task_type="CAUSAL_LM")

6)创建微调和训练配置

from transformers import TrainingArgumentsoutput_dir = "./results"per_device_train_batch_size = 4gradient_accumulation_steps = 4optim = "paged_adamw_32bit"save_steps = 100logging_steps = 10learning_rate = 2e-4max_grad_norm = 0.3max_steps = 100warmup_ratio = 0.03lr_scheduler_type = "constant"training_arguments = TrainingArguments(    output_dir=output_dir,    per_device_train_batch_size=per_device_train_batch_size,    gradient_accumulation_steps=gradient_accumulation_steps,    optim=optim,    save_steps=save_steps,    logging_steps=logging_steps,    learning_rate=learning_rate,    fp16=True,    max_grad_norm=max_grad_norm,    max_steps=max_steps,    warmup_ratio=warmup_ratio,    group_by_length=True,    lr_scheduler_type=lr_scheduler_type,)

7)创建SFTTrainer配置

from trl import SFTTrainermax_seq_length = 512trainer = SFTTrainer(    model=model,    train_dataset=dataset,    peft_config=peft_config,    dataset_text_field="text",    max_seq_length=max_seq_length,    tokenizer=tokenizer,    args=training_arguments,)

8)在微调的时候,对LN层使用float 32训练更稳定

for name, module in trainer.model.named_modules():    if "norm" in name:        module = module.to(torch.float32)

9)开始微调

trainer.train()

10)保存微调好的模型

model_to_save = trainer.model.module if hasattr(trainer.model, 'module') else trainer.model  # Take care of distributed/parallel trainingmodel_to_save.save_pretrained("outputs")

11)加载微调好的模型

lora_config = LoraConfig.from_pretrained('outputs')tuned_model = get_peft_model(model, lora_config)

12)测试微调好的模型效果

text = "What is a large language model?"device = "cuda:0"inputs = tokenizer(text, return_tensors="pt").to(device)outputs = tuned_model.generate(**inputs, max_new_tokens=50)print(tokenizer.decode(outputs[0], skip_special_tokens=True))

参考文献:

[1] https://trojrobert.medium.com/4-easier-ways-for-fine-tuning-llama-2-and-other-open-source-llms-eb3218657f6e

[2] https://colab.research.google.com/drive/1JMEi2VMNGMOTyfEcQZyp23EISUrWg5cg?usp=sharing

[3] https://colab.research.google.com/drive/1ctevXhrE60s7o9RzsxpIqq37EjyU9tBn?usp=sharing#scrollTo=bsbdrb5p2ONa

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.mzph.cn/news/202229.shtml

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

循环队列中的求队列长度公式怎么来的?【数学角度】

循环队列中的队列长度怎么来的? 引入 在一个循环队列中,队列的元素个数可以通过头指针(Front,通常用F表示)和尾指针(Rear,通常用R表示)来计算。假设队列的存储空间大小为n,队列中…

二:C语言-数据类型和变量

二:数据类型和变量 1.数据类型的介绍: ​ 内置数据类型(C语言本身具有的):字符 - char;整型 - int;浮点型 - float;布尔类型 - _Bool ​ 自定义数据类型(自己创建的类…

选择更好的Notes索引附件方式

大家好,才是真的好。 首先介绍最近产品更新消息。在上一周,HCL主要发布了以下几个产品更新:HCL Verse 3.2.0、HCL Volt MX Go 2.0.2、HCL Domino Rest API 1.0.8。 HCL Verse是今后Domino的产品当中主要使用的webmail功能,这一次…

kafka学习笔记--基础知识概述

本文内容来自尚硅谷B站公开教学视频,仅做个人总结、学习、复习使用,任何对此文章的引用,应当说明源出处为尚硅谷,不得用于商业用途。 如有侵权、联系速删 视频教程链接:【尚硅谷】Kafka3.x教程(从入门到调优…

【开题报告】基于SpringBoot的农场管理系统的设计与实现

1.选题背景 随着社会经济的发展和人们对食品安全和质量的要求不断提高,农业管理也面临着新的挑战和需求。传统的农场管理方式往往依靠手工记录和经验积累,存在信息不及时、管理效率低下等问题。而基于SpringBoot的农场管理系统的设计与实现,…

国标GB28181设备注册安防监控平台EasyCVR不上线是什么原因?

安防视频监控EasyCVR平台兼容性强,可支持的接入协议众多,包括国标GB28181、RTSP/Onvif、RTMP,以及厂家的私有协议与SDK,如:海康ehome、海康sdk、大华sdk、宇视sdk、华为sdk、萤石云sdk、乐橙sdk等。平台能将接入的视频…

[头歌Python实验]Python入门之基础语法

目录 第1关:行与缩进 第2关:标识符与保留字 第3关:注释 第4关:输入输出 如果对你有帮助的话,不妨点赞收藏关注评论走一波吧,爱你么么哒吗😘💖💖💖 第1关…

逆向爬虫进阶实战:突破反爬虫机制,实现数据抓取

文章目录 一、引言二、逆向爬虫进阶技巧三、逆向爬虫进阶实战代码片段四、总结与展望好书推荐内容简介作者简介前言节选 一、引言 随着网络技术的发展,网站为了保护自己的数据和资源,纷纷采用了各种反爬虫机制。然而,逆向爬虫技术的出现&…

【Python百宝箱】漫游Python数据可视化宇宙:pyspark、dash、streamlit、matplotlib、seaborn全景式导览

Python数据可视化大比拼:从大数据处理到交互式Web应用 前言 在当今数字时代,数据可视化是解释和传达信息的不可或缺的工具之一。本文将深入探讨Python中流行的数据可视化库,从大数据处理到交互式Web应用,为读者提供全面的了解和…

控乐屋品牌|智汇恒星全宅智能空间万物互联,千亿蓝海蓄势待发

随着5G、大数据、云计算、物联网等技术的发展,智能化正覆盖人们生活的方方面面,全屋智能的出现为“一键式”智能家居生活享受提供无限可能。近年来智能家居行业总体规模增长迅速,数据显示,2022年中国智能家居行业市场规模约为6200…

Redis滚动分页的使用

Feed流 关注推送也叫Feed流。通过无限下拉刷新获取新的信息。 Feed流产品常见有两种模式: Timeline: 不做内容筛选,简单的按照内容发布时间排序,常用于好友或关注。例如朋友圈 优点:信息全面,不会有缺失。并且实现也…

2023五岳杯量子计算挑战赛APMCM亚太地区

问题一要求在特定区域内部署两个边缘服务器,以便根据计算需求分布覆盖最大的计算需求。每个边缘服务器都有一个覆盖半径为1。目标是确定两个边缘服务器的位置,以覆盖最大的计算需求。假设边缘服务器的位置位于网格的中心,每个网格内的计算需求…

我们为什么那么关注 Java 中的 String Template ,Java 21 特性

本心、输入输出、结果 文章目录 我们为什么那么关注 Java 中的 String Template ,Java 21 特性前言String TemplateString Template 有什么好处字符串连接 – 一个常见但无趣且容易出错的任务jetbrains IDEA 2023.2 版本及以上对于 String Template 的支持字符串模板…

试编写算法将带头结点的单链表就地逆置(就地是指辅助空间复杂度为 O(0))。

题目描述:试编写算法将带头结点的单链表就地逆置(就地是指辅助空间复杂度为 O(0)。 分析: 将单链表就地逆置可以考虑使用头插法。 LinkList Reverse(LinkList L){LNode *p L->next;LNode *r;L->next NULL;while(p){r p->next;p…

redis中使用pipeline

在操作数据库时,为了加快程序的执行速度,在新增或更新数据时,可以通过批量提交的方式来减少应用和数据库间的传输次数;在redis中也有这样的技术实现批量处理,也就是管道——Pipeline。它也是通过批量提交数据的方式来实…

FPS和SFTP的速度哪个更快?区别在哪里?

在互联网时代,我们频繁需要传输大文件,如视频、音乐、图片和文档等。这些文件不仅占用大量空间,而且传输时间长。确保传输过程的安全性和稳定性,以防文件被窃取或损坏成为重要考虑因素。在选择传输方式时,FPS和SFTP是两…

力扣-435.无重叠空间

利用快排,对数组右边界进行排序。 用一个变量记录区间的分割点,然后用这个分割点去和下一个区间做比较,如果没有重叠,更新右边界,没有重叠的区间个数加一。 然后更新右边界,继续进行比较。 最后用总区间…

Databend 如何利用 GPT-4 进行质量保证

背景 在数据库行业,质量是核心要素。 Databend 的应用场景广泛,特别是在金融相关领域,其查询结果的准确性对用户至关重要。因此,在快速迭代的过程中,如何确保产品质量,成为我们面临的重大挑战。 随着 Da…

leaflet:经纬度坐标转为地址,点击鼠标显示地址信息(137)

第137个 点击查看专栏目录 本示例的目的是介绍演示如何在vue+leaflet中将经纬度坐标转化为地址,点击鼠标显示某地的地址信息 。主要利用mapbox的api将坐标转化为地址,然后在固定的位置显示出来。 直接复制下面的 vue+leaflet源代码,操作2分钟即可运行实现效果 文章目录 示…

Springboot 项目关于版本升级到 3.x ,JDK升级到17的相关问题

由于spring 停止对2.x 版本的维护,以及 jdk 频繁发布等客观因素,现需要对已有springboot 工程做一次全面升级;已因对市面上第三方等依赖库的兼容问题; 现有工程使用哥技术栈是版本: freemarker :2.3.32 spr…