【LLM教程-llama】如何Fine Tuning大语言模型?

今天给大家带来了一篇超级详细的教程,手把手教你如何对大语言模型进行微调(Fine Tuning)!(代码和详细解释放在后文)

目录

大语言模型进行微调(Fine Tuning)需要哪些步骤?

大语言模型进行微调(Fine Tuning)训练过程及代码


大语言模型进行微调(Fine Tuning)需要哪些步骤?

大语言模型进行微调(Fine Tuning)的主要步骤🤩

  1. 📚 准备训练数据集
    首先你需要准备一个高质量的训练数据集,最好是与你的应用场景相关的数据。可以是文本数据、对话数据等,格式一般为JSON/TXT等。

  2. 📦 选择合适的基础模型
    接下来需要选择一个合适的基础预训练模型,作为微调的起点。常见的有GPT、BERT、T5等大模型,可根据任务场景进行选择。

  3. ⚙️ 设置训练超参数
    然后是设置训练的各种超参数,比如学习率、批量大小、训练步数等等。选择合理的超参数对模型效果影响很大哦。

  4. 🧑‍💻 加载模型和数据集
    使用HuggingFace等库,把选定的基础模型和训练数据集加载进来。记得对数据集进行必要的前处理和划分。

  5. ⚡ 开始模型微调训练
    有了模型、数据集和超参数后,就可以开始模型微调训练了!可以使用PyTorch/TensorFlow等框架进行训练。

  6. 💾 保存微调后的模型
    训练结束后,别忘了把微调好的模型保存下来,方便后续加载使用哦。

  7. 🧪 在测试集上评估模型
    最后在准备好的测试集上评估一下微调后模型的效果。看看与之前的基础模型相比,是否有明显提升?

大语言模型进行微调(Fine Tuning)训练过程及代码

那如何使用 Lamini 库加载数据、设置模型和训练超参数、定义推理函数、微调基础模型、评估模型效果呢?

  • 首先,导入必要的库
import os
import lamini
import datasets
import tempfile
import logging
import random
import config
import os
import yaml
import time
import torch
import transformers
import pandas as pd
import jsonlinesfrom utilities import *
from transformers import AutoTokenizer
from transformers import AutoModelForCausalLM
from transformers import TrainingArguments
from transformers import AutoModelForCausalLM
from llama import BasicModelRunner

这部分导入了一些必需的Python库,包括Lamini、Hugging Face的Datasets、Transformers等。

  • 加载Lamini文档数据集
dataset_name = "lamini_docs.jsonl"
dataset_path = f"/content/{dataset_name}"
use_hf = False
dataset_path = "lamini/lamini_docs"
use_hf = True

这里指定了数据集的路径,同时设置了use_hf标志,表示是否使用Hugging Face的Datasets库加载数据。

  • 设置模型、训练配置和分词器
model_name = "EleutherAI/pythia-70m"
training_config = { ... }
tokenizer = AutoTokenizer.from_pretrained(model_name)
tokenizer.pad_token = tokenizer.eos_token
train_dataset, test_dataset = tokenize_and_split_data(training_config, tokenizer)

这部分指定了基础预训练模型的名称,并设置了训练配置(如最大长度等)。然后,它使用AutoTokenizer从预训练模型中加载分词器,并对分词器进行了一些调整。最后,它调用tokenize_and_split_data函数对数据进行分词和划分训练/测试集。

  • 加载基础模型
base_model = AutoModelForCausalLM.from_pretrained(model_name)
device_count = torch.cuda.device_count()
if device_count > 0:device = torch.device("cuda")
else:device = torch.device("cpu")
base_model.to(device)

这里使用AutoModelForCausalLM从预训练模型中加载基础模型,并根据设备(GPU或CPU)将模型移动到相应的设备上。

  • 定义推理函数
def inference(text, model, tokenizer, max_input_tokens=1000, max_output_tokens=100):...

这个函数用于在给定输入文本的情况下,使用模型和分词器进行推理并生成输出。它包括对输入文本进行分词、使用模型生成输出以及解码输出等步骤。

  • 尝试使用基础模型进行推理
test_text = test_dataset[0]['question']
print("Question input (test):", test_text)
print(f"Correct answer from Lamini docs: {test_dataset[0]['answer']}")
print("Model's answer: ")
print(inference(test_text, base_model, tokenizer))

这部分使用上一步定义的inference函数,在测试数据集的第一个示例上尝试使用基础模型进行推理。它打印了输入问题、正确答案和模型的输出。

  • 设置训练参数
max_steps = 3
trained_model_name = f"lamini_docs_{max_steps}_steps"
output_dir = trained_model_name
training_args = TrainingArguments(# Learning ratelearning_rate=1.0e-5,# Number of training epochsnum_train_epochs=1,# Max steps to train for (each step is a batch of data)# Overrides num_train_epochs, if not -1max_steps=max_steps,# Batch size for trainingper_device_train_batch_size=1,# Directory to save model checkpointsoutput_dir=output_dir,# Other argumentsoverwrite_output_dir=False, # Overwrite the content of the output directorydisable_tqdm=False, # Disable progress barseval_steps=120, # Number of update steps between two evaluationssave_steps=120, # After # steps model is savedwarmup_steps=1, # Number of warmup steps for learning rate schedulerper_device_eval_batch_size=1, # Batch size for evaluationevaluation_strategy="steps",logging_strategy="steps",logging_steps=1,optim="adafactor",gradient_accumulation_steps = 4,gradient_checkpointing=False,# Parameters for early stoppingload_best_model_at_end=True,save_total_limit=1,metric_for_best_model="eval_loss",greater_is_better=False
)

这一部分设置了训练的一些参数,包括最大训练步数、输出模型目录、学习率等超参数。

为什么要这样设置这些训练超参数:

  1. learning_rate=1.0e-5
    学习率控制了模型在每个训练步骤中从训练数据中学习的速度。1e-5是一个相对较小的学习率,可以有助于稳定训练过程,防止出现divergence(发散)的情况。

  2. num_train_epochs=1
    训练的轮数,即让数据在模型上循环多少次。这里设置为1,是因为我们只想进行轻微的微调,避免过度训练(overfitting)。

  3. max_steps=max_steps
    最大训练步数,会覆盖num_train_epochs。这样可以更好地控制训练的总步数。

  4. per_device_train_batch_size=1
    每个设备(GPU/CPU)上的训练批量大小。批量大小越大,内存占用越高,但训练过程可能更加稳定。

  5. output_dir=output_dir
    用于保存训练过程中的检查点(checkpoints)和最终模型的目录。

  6. overwrite_output_dir=False
    如果目录已存在,是否覆盖它。设为False可以避免意外覆盖之前的结果。

  7. eval_steps=120, save_steps=120
    每120步评估一次模型性能,并保存模型。频繁保存可以在训练中断时恢复。

  8. warmup_steps=1
    学习率warmup步数,一开始使用较小的学习率有助于稳定训练早期阶段。

  9. per_device_eval_batch_size=1
    评估时每个设备上的批量大小。通常与训练时相同。

  10. evaluation_strategy="steps", logging_strategy="steps"
    以步数为间隔进行评估和记录日志,而不是以epoch为间隔。

  11. optim="adafactor"
    使用Adafactor优化器,适用于大规模语言模型训练。

  12. gradient_accumulation_steps=4
    梯度积累步数,可以模拟使用更大批量大小的效果,节省内存。

  13. load_best_model_at_end=True
    保存验证集上性能最好的那个检查点,作为最终模型。

  14. metric_for_best_model="eval_loss", greater_is_better=False
    根据验证损失评估模型,损失越小越好。

model_flops = (base_model.floating_point_ops({"input_ids": torch.zeros((1, training_config["model"]["max_length"]))})* training_args.gradient_accumulation_steps
)print(base_model)
print("Memory footprint", base_model.get_memory_footprint() / 1e9, "GB")
print("Flops", model_flops / 1e9, "GFLOPs")print(base_model)
print("Memory footprint", base_model.get_memory_footprint() / 1e9, "GB")
print("Flops", model_flops / 1e9, "GFLOPs")

这里还计算并打印了模型的内存占用和计算复杂度(FLOPs)。

最后,使用这些参数创建了一个Trainer对象,用于实际进行模型训练。

trainer = Trainer(model=base_model,model_flops=model_flops,total_steps=max_steps,args=training_args,train_dataset=train_dataset,eval_dataset=test_dataset,
)
  • 训练模型几个步骤
training_output = trainer.train()

这一行代码启动了模型的微调训练过程,并将训练输出存储在training_output中。

  • 保存微调后的模型
save_dir = f'{output_dir}/final'
trainer.save_model(save_dir)
print("Saved model to:", save_dir)
finetuned_slightly_model = AutoModelForCausalLM.from_pretrained(save_dir, local_files_only=True)
finetuned_slightly_model.to(device)

这部分将微调后的模型保存到指定的目录中。

然后,它使用AutoModelForCausalLM.from_pretrained从保存的模型中重新加载该模型,并将其移动到相应的设备上。

  • 使用微调后的模型进行推理
test_question = test_dataset[0]['question']
print("Question input (test):", test_question)
print("Finetuned slightly model's answer: ")
print(inference(test_question, finetuned_slightly_model, tokenizer))
test_answer = test_dataset[0]['answer']
print("Target answer output (test):", test_answer)

这里使用之前定义的inference函数,在测试数据集的第一个示例上尝试使用微调后的模型进行推理。

打印了输入问题、模型输出以及正确答案。

  • 加载并运行其他预训练模型
finetuned_longer_model = AutoModelForCausalLM.from_pretrained("lamini/lamini_docs_finetuned")
tokenizer = AutoTokenizer.from_pretrained("lamini/lamini_docs_finetuned")
finetuned_longer_model.to(device)
print("Finetuned longer model's answer: ")
print(inference(test_question, finetuned_longer_model, tokenizer))bigger_finetuned_model = BasicModelRunner(model_name_to_id["bigger_model_name"])
bigger_finetuned_output = bigger_finetuned_model(test_question)
print("Bigger (2.8B) finetuned model (test): ", bigger_finetuned_output)

这部分加载了另一个经过更长时间微调的模型,以及一个更大的2.8B参数的微调模型。它使用这些模型在测试数据集的第一个示例上进行推理,并打印出结果。

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.mzph.cn/diannao/38847.shtml

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

VuePress介绍

从本文开始,动手搭建自己的博客!希望读者能跟着一起动手,这样才能真正掌握。 ‍ VuePress 是什么 VuePress 是由 Vue 作者带领团队开发的,非常火,使用的人很多;Vue 框架官网也是用了 VuePress 搭建的。即…

000.二分查找算法题解目录

000.二分查找算法题解目录 69. x 的平方根(简单)

4PCS点云配准算法实现

4PCS点云配准算法的C实现如下&#xff1a; #include <iostream> #include <pcl/io/pcd_io.h> #include <pcl/point_types.h> #include <pcl/common/common.h> #include <pcl/common/distances.h> #include <pcl/common/transforms.h> #in…

唯一ID:UUID 介绍与 google/uuid 库生成 UUID

UUID 即通用唯一识别码&#xff0c;是一种用于计算机系统中以确保全局唯一性的标识符。其标准定义于 RFC 4122 文档中。标准形式包含 32 个 16 进制数字&#xff0c;以连字符切割为五组&#xff0c;格式为 8-4-4-4-12&#xff0c;总共 36 个字符。&#xff08;形如, d169aa7f-4…

php 通过vendor文件 生成还原最新的composer.json

起因&#xff1a;因为历史原因&#xff0c;在本项目中composer.json基本算废了&#xff0c;没法直接使用composer管理扩展&#xff0c;今天尝试修复一下composer.json。 历史文件&#xff0c;可以看出来已经很久没有维护了&#xff0c;我们主要是恢复require的信息 {"na…

K8s节点维护流程

用途 用于下线异常节点、集群缩容等 操作步骤 1. 查看节点名称 先确认节点的名称 kubectl get node -o wide2. 设置节点不可调度 设置节点不可调度状态&#xff0c;禁止新的pod调度到该节点上 kubectl cordon ${node_name}3. 剔除节点上运行的pod&#xff08;生产环境慎…

Spring Boot中集成Redis实现缓存功能

Spring Boot中集成Redis实现缓存功能 大家好&#xff0c;我是免费搭建查券返利机器人省钱赚佣金就用微赚淘客系统3.0的小编&#xff0c;也是冬天不穿秋裤&#xff0c;天冷也要风度的程序猿&#xff01;今天我们将深入探讨如何在Spring Boot应用程序中集成Redis&#xff0c;实现…

AP无法上线原因分析及排障

一、AP未分配到IP地址 如果遇到AP无法上线问题&#xff0c;可以检查下AP是否分配到IP地址。AP获取IP地址有两种方式&#xff1a;静态方式&#xff1a;登录到AP设备&#xff0c;手工配置IP地址&#xff0c;该方式操作起来比较麻烦&#xff0c;不推荐使用&#xff1b;DHCP方式&am…

基于CNN的股票预测方法【卷积神经网络】

基于机器学习方法的股票预测系列文章目录 一、基于强化学习DQN的股票预测【股票交易】 二、基于CNN的股票预测方法【卷积神经网络】 文章目录 基于机器学习方法的股票预测系列文章目录一、CNN建模原理二、模型搭建三、模型参数的选择&#xff08;1&#xff09;探究window_size…

下代iPhone或回归可拆卸电池,苹果这操作把我看傻了

刚度过一个愉快的周末&#xff0c;苹果又双叒叕摊上事儿了。 iPhone13 系列被曝扎堆电池鼓包了。 早在去年&#xff0c;就有 iPhone13 和 iPhone14 用户反馈过类似的问题&#xff0c;表示在手机仅仅使用了一年多的时间就出现了电池鼓包的情况&#xff0c;而且还把屏幕给撑起来了…

舞会无领导:一种树形动态规划的视角

没有上司的舞会 Ural 大学有 &#x1d441; 名职员&#xff0c;编号为1∼&#x1d441;。 他们的关系就像一棵以校长为根的树&#xff0c;父节点就是子节点的直接上司。 每个职员有一个快乐指数&#xff0c;用整数 &#x1d43b;&#x1d456; 给出&#xff0c;其中1≤&…

校园卡手机卡怎么注销?

校园手机卡的注销流程可以根据不同的运营商和具体情况有所不同&#xff0c;但一般来说&#xff0c;以下是注销校园手机卡的几种常见方式&#xff0c;我将以分点的方式详细解释&#xff1a; 一、线上注销&#xff08;通过手机APP或官方网站&#xff09; 下载并打开对应运营商的…

C++ 指针介绍

指针是C编程语言中的一个强大且重要的特性。它允许程序员直接操作内存地址&#xff0c;从而提供了对低级别内存的访问和控制。虽然指针在使用时可能比较复杂且容易出错&#xff0c;但它们在提高程序效率和灵活性方面有着不可替代的作用。本文将介绍C指针的基本概念、用法及其应…

Docker 中 MySQL 迁移策略(单节点)

目录 一、 简介二、操作流程2.1 进入mysql容器2.2 导出 MySQL 数据2.3. 将导出的文件复制到宿主机2.4 创建 Docker Compose 配置2.5 启动新的 Docker 容器2.6 导入数据到新的容器2.7 验证数据2.8 删除旧的容器&#xff08;删除操作需慎重&#xff09; 三、推荐配置四、写在后面…

当年很多跑到美加澳写代码的人现在又移回香港?什么原因?

当年很多跑到美加澳写代码的人现在又移回香港&#xff1f;什么原因&#xff1f; 近年来&#xff0c;确实有部分曾经移民到美国、加拿大、澳大利亚等地的香港居民选择移回香港。这一现象与多种因素相关&#xff0c;主要可以归结为以下几点&#xff1a; 疫情后的环境变化&#…

【STM32】温湿度采集与OLED显示

一、任务要求 1. 学习I2C总线通信协议&#xff0c;使用STM32F103完成基于I2C协议的AHT20温湿度传感器的数据采集&#xff0c;并将采集的温度-湿度值通过串口输出。 任务要求&#xff1a; 1&#xff09;解释什么是“软件I2C”和“硬件I2C”&#xff1f;&#xff08;阅读野火配…

2025第13届常州国际工业装备博览会招商全面启动

常州智造 装备中国|2025第13届常州国际工业装备博览会招商全面启动 2025第13届常州国际工业装备博览会将于2025年4月11-13日在常州西太湖国际博览中心盛大举行&#xff01;目前&#xff0c;各项筹备工作正稳步推进。 60000平米的超大规模、800多家国内外工业装备制造名企将云集…

C++中的RAII(资源获取即初始化)原则

C中的RAII&#xff08;Resource Acquisition Is Initialization&#xff0c;资源获取即初始化&#xff09;原则是一种管理资源、避免资源泄漏的惯用法。RAII是C之父Bjarne Stroustrup提出的设计理念&#xff0c;其核心思想是将资源的获取&#xff08;如动态内存分配、文件句柄、…

最细最有条理解析:事件循环(消息循环)是什么?进程与线程的定义、关系与差异

目录 事件循环&#xff1a;引入 一、浏览器的进程模型 1.1、什么是进程&#xff08;Process&#xff09; 1.2、什么是线程&#xff08;Thread&#xff09; 1.3、进程与线程之间的关系联系与区别 二、浏览器有哪些进程和线程 2.1、浏览器的主要进程 ①浏览器进程 ②网络…

ctfshow sqli-libs web561--web568

web561 ?id-1 or 1--?id-1 union select 1,2,3--?id-1 union select 1,(select group_concat(column_name) from information_schema.columns where table_nameflags),3-- Your Username is : id,flag4s?id-1 union select 1,(select group_concat(flag4s) from ctfshow.f…