大模型微调(PEFT)

大模型微调(PEFT)

  • PEFT(Parameter-Efficient Fine-Tuning)
    • 一、PEFT 核心方法
      • 1. LoRA(Low-Rank Adaptation)
      • 2. Adapter
      • 3. Prefix Tuning
      • 4. Prompt Tuning
      • 5. QLoRA(Quantized LoRA)
    • 二、PEFT vs 全参数微调
    • 三、微调大模型示例代码
    • 四、加载微调后的大模型
      • 1. Lora
      • 2. prefix tuning

大模型微调方法描述

PEFT(Parameter-Efficient Fine-Tuning)

PEFT(参数高效微调)是一类用于大幅降低大模型微调成本的技术,核心思想是仅微调少量参数,而非整个模型。以下是系统化的解析:

一、PEFT 核心方法

1. LoRA(Low-Rank Adaptation)

  • 原理
    • 在原始权重旁添加低秩矩阵(W = W₀ + BA),仅训练BA
  • 适用场景:文本生成、对话系统
  • 代码示例
    • r(秩)通常为4~64,参数量减少90%+
    from peft import LoraConfig, get_peft_modelconfig = LoraConfig(r=8,                      # 秩lora_alpha=32,            # 缩放系数target_modules=["q_proj", "v_proj"],  # 作用模块lora_dropout=0.05,bias="none",
    )
    model = get_peft_model(model, config)  # 原始模型+LoRA
    

2. Adapter

  • 原理
    • 在Transformer层间插入小型全连接网络,仅训练Adapter层。
    • 参数量占比约0.5%~5%
  • 适用场景:多任务学习
  • 结构示例
    Transformer Layer → Adapter(Down→ReLU→Up) → Residual→ LayerNorm
    

3. Prefix Tuning

  • 原理

    • 在输入前添加可学习的“虚拟token”(prefix),引导模型生成。
    • 完全不修改原始参数
  • 适用场景:生成任务(如GPT)

  • 结构示例

    import torch
    import torch.nn as nn
    from transformers import AutoModelForCausalLM, AutoTokenizer# 加载预训练模型和分词器
    model_name = "gpt2"  # 可替换为你想要使用的模型名称
    model = AutoModelForCausalLM.from_pretrained(model_name)
    tokenizer = AutoTokenizer.from_pretrained(model_name)# 定义Prefix Tuning模块
    class PrefixTuning(nn.Module):def __init__(self, num_virtual_tokens, hidden_size):super(PrefixTuning, self).__init__()self.prefix_embeddings = nn.Embedding(num_virtual_tokens, hidden_size)nn.init.normal_(self.prefix_embeddings.weight, mean=0, std=0.02)def forward(self, input_ids, attention_mask):batch_size = input_ids.shape[0]prefix = self.prefix_embeddings.weight.unsqueeze(0).repeat(batch_size, 1, 1)new_input_ids = torch.cat([torch.full((batch_size, prefix.shape[1]), tokenizer.pad_token_id).to(input_ids.device), input_ids], dim=1)new_attention_mask = torch.cat([torch.ones((batch_size, prefix.shape[1])).to(attention_mask.device), attention_mask], dim=1)return new_input_ids, new_attention_mask
    

4. Prompt Tuning

  • 原理

    • 在输入层加入prompt tokens。
    • 完全不修改原始参数,简化版的Prefix Tuning,无需MLP调整,随着模型规模增大,效果接近full fine-tuning。
  • 结构示例

    prompt = "请回答以下问题:"
    prompt_ids = tokenizer.encode(prompt, return_tensors="pt").to(input_ids.device)
    new_input_ids = torch.cat([prompt_ids.repeat(batch_size, 1), input_ids], dim=1)
    

5. QLoRA(Quantized LoRA)

  • 原理
    4-bit量化基础模型 + LoRA微调,显存需求降低70%
  • 代码示例
    from transformers import BitsAndBytesConfigbnb_config = BitsAndBytesConfig(load_in_4bit=True,bnb_4bit_compute_dtype=torch.bfloat16
    )
    model = AutoModel.from_pretrained("Llama-3-8B", quantization_config=bnb_config)
    

二、PEFT vs 全参数微调

指标PEFT全参数微调
显存占用极低(可单卡微调70B)极高(需多卡)
训练速度快(仅更新少量参数)
效果接近全参数微调最优但差异<5%
部署便利性需合并适配器直接部署

三、微调大模型示例代码

注意:使用model.save_pretrained("fine_tuned_internvl_3") 保存经过 PEFT(如 LoRA 或其他 Adapter 微调)后的模型时,保存的权重通常不包含基础模型(base_model)的原始权重,仅保存微调过程中可训练的部分。

import math
import pandas as pd
import numpy as np
import torch
import torchvision.transforms as T
from decord import VideoReader, cpu
from PIL import Image
from torchvision.transforms.functional import InterpolationMode
from transformers import AutoModel, AutoTokenizer, AutoConfig, TrainingArguments, Trainer
from peft import LoraConfig, get_peft_model
import os# 模型加载
path = 'InternVL3'
device_map = split_model(path)
model = AutoModel.from_pretrained(path,torch_dtype=torch.bfloat16,load_in_8bit=True,low_cpu_mem_usage=True,use_flash_attn=True,trust_remote_code=True,device_map=device_map).eval()
tokenizer = AutoTokenizer.from_pretrained(path, trust_remote_code=True, use_fast=False)# 配置LoRA
lora_config = LoraConfig(r=8,lora_alpha=16,target_modules=["q_proj", "v_proj"],lora_dropout=0.1,bias="none",task_type="CAUSAL_LM"
)model = get_peft_model(model, lora_config)
model.print_trainable_parameters()# 读取数据集
data_path = 'data'
df = pd.read_parquet(data_path)
dataset = CustomDataset(df, tokenizer)# 训练参数设置
training_args = TrainingArguments(output_dir='./results',num_train_epochs=3,per_device_train_batch_size=4,gradient_accumulation_steps=4,save_steps=10_000,save_total_limit=2,evaluation_strategy="no",logging_steps=10,fp16=True
)# 创建Trainer
trainer = Trainer(model=model,args=training_args,train_dataset=dataset
)# 开始训练
trainer.train()# 保存微调后的模型
model.save_pretrained("fine_tuned_internvl3")    

四、加载微调后的大模型

1. Lora

  • 示例代码:
    from transformers import AutoModel
    from peft import PeftModel# 加载基础的预训练模型
    base_model_path = "base_model_path"  # 替换为基础预训练模型的路径
    base_model = AutoModel.from_pretrained(base_model_path)# 加载微调后的适配器
    adapter_path = "fine_tuned_adapter_path"  # 替换为微调后适配器的保存路径
    model = PeftModel.from_pretrained(base_model, adapter_path)
    

2. prefix tuning

  • 示例代码:
    import torch
    from transformers import AutoModelForCausalLM, AutoTokenizer
    import torch.nn as nn# 定义 Prefix Tuning 模块
    class PrefixTuning(nn.Module):def __init__(self, num_virtual_tokens, hidden_size):super(PrefixTuning, self).__init__()self.prefix_embeddings = nn.Embedding(num_virtual_tokens, hidden_size)def forward(self, input_ids, attention_mask):batch_size = input_ids.shape[0]prefix = self.prefix_embeddings.weight.unsqueeze(0).repeat(batch_size, 1, 1)new_input_ids = torch.cat([torch.full((batch_size, prefix.shape[1]), tokenizer.pad_token_id).to(input_ids.device),input_ids], dim=1)new_attention_mask = torch.cat([torch.ones((batch_size, prefix.shape[1])).to(attention_mask.device),attention_mask], dim=1)return new_input_ids, new_attention_mask# 加载基础的预训练模型和分词器
    model_name = "gpt2"  # 可替换为实际的模型名称
    model = AutoModelForCausalLM.from_pretrained(model_name)
    tokenizer = AutoTokenizer.from_pretrained(model_name)# 初始化 Prefix Tuning 模块
    num_virtual_tokens = 10  # 替换为实际的虚拟 token 数量
    hidden_size = model.config.hidden_size
    prefix_tuning = PrefixTuning(num_virtual_tokens, hidden_size)# 加载 Prefix Tuning 的参数
    try:prefix_tuning.load_state_dict(torch.load("path/to/prefix_tuning_weights.pth"))
    except FileNotFoundError:print("错误:未找到 Prefix Tuning 参数文件,请检查路径。")exit(1)# 将模型和 Prefix Tuning 模块移动到 GPU(如果可用)
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    model.to(device)
    prefix_tuning.to(device)# 输入文本
    input_text = "Once upon a time"
    input_ids = tokenizer.encode(input_text, return_tensors='pt').to(device)
    attention_mask = torch.ones_like(input_ids).to(device)# 使用 Prefix Tuning 处理输入
    new_input_ids, new_attention_mask = prefix_tuning(input_ids, attention_mask)# 进行推理
    with torch.no_grad():outputs = model(new_input_ids, attention_mask=new_attention_mask)logits = outputs.logits# 生成文本
    generated_ids = torch.argmax(logits, dim=-1)
    generated_text = tokenizer.decode(generated_ids[0], skip_special_tokens=True
    

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.mzph.cn/diannao/78970.shtml

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

flutter 打包mac程序 dmg教程

✅ 前提条件 ✅ 你已经在 macOS 上安装了 Android Studio Flutter SDK。 ✅ Flutter 支持 macOS 构建。 运行下面命令确认是否支持&#xff1a; Plain Text bash 复制编辑 flutter doctor ---## &#x1f9f1; 第一步&#xff1a;启用 macOS 支持如果是新项目&#xff0c;…

鸿蒙开发-动画

1. 动画-动画特效 // 定义接口 (每个列表项的数据结构) interface ImageCount {url: stringcount: number }// 需求1: 遮罩层显隐 透明度opacity 0-1 层级zIndex -1~99 // 需求2: 图片缩放 缩放scale 0-1Entry Component struct Index {// 基于接口, 准备数据State images…

js:循环查询数组对象中的某一项的值是否为空

循环检查 selinfo 数组中的每一个对象&#xff0c;判断其中的 po_qty 和 price 是否为空&#xff08;null、undefined 或空字符串 ""&#xff09;&#xff0c;可以使用以下几种方法&#xff1a; 方法1&#xff1a;使用 forEach 循环检查每一项 const selinfo this.…

x-cmd install | jellex - 用 Python 语法在终端里玩转 JSON 数据!

目录 核心功能与特点安装优势亮点适用场景 还在为命令行下处理 JSON 数据烦恼吗&#xff1f;jellex 来了&#xff01;它是一款基于终端的交互式 JSON 和 JSON Lines 数据处理工具&#xff0c;让你用熟悉的 Python 语法&#xff0c;轻松过滤、转换和探索 JSON 数据。 核心功能与…

4月份到9月份看6本书第二天【ERP与企业管理】

ERP与企业管理 1-11章全面介绍了ERP的基本原理、物料管理功能、计划功能、生产和采购管理功能、效益以及实施和应用ERP为企业带来的深层次的变化。 第12章讨论了软件系统的选型。 第13章介绍了ERP实施和运行管理的方法 第14章介绍了国际上广泛使用的ERP实施应用的评估方法。…

Opencv计算机视觉编程攻略-第十三节 跟踪视频中的物品

这是opencv系列的最后一节&#xff0c;主要学习视频序列&#xff0c;上一节介绍了读取、处理和存储视频的工具&#xff0c;本文将介绍几种跟踪图像序列中运动物体的算法。可见运动或表观运动&#xff0c;是物体以不同的速度在不同的方向上移动&#xff0c;或者是因为相机在移动…

001 蓝桥杯嵌入式赛道备赛——基础

个人笔记&#xff0c;不扭扭捏捏&#xff0c;一口气到位。方便自己也方便大家 00 时钟线 cubeMX已经完成了大多数工作 01 LED&#xff08;GPIO输出&#xff09; 在使用LED的时候先把SN74HC573锁存器PD2置高电平&#xff0c;然后写入LED所要的高低电平&#xff0c;然后置PD2低…

案例-索引对于并发Insert性能优化测试

前言 最近因业务并发量上升,开发反馈对订单表Insert性能降低。应开发要求对涉及Insert的表进行分析并提供优化方案。   一般对Insert 影响基本都在索引,涉及表已按创建日期做了分区表,索引全部为普通索引未做分区索引。 优化建议: 1、将UNIQUE改为HASH(64) GLOBAL IND…

【技术文章的标准结构与内容指南】

技术文章的标准结构与内容指南 技术文章是传递专业知识、分享实践经验的重要媒介。一篇高质量的技术文章不仅能够帮助读者解决问题&#xff0c;还能促进技术交流与创新。以下是技术文章通常包含的核心内容与结构指南。 1. 标题 一个好的技术文章标题应当&#xff1a; 简洁明…

豪越消防一体化安全管控平台:构建消防“一张图”新生态

在城市化进程加速、建筑规模与功能日益复杂的当下&#xff0c;消防救援工作面临着诸多严峻挑战。火灾隐患如同隐藏在暗处的“定时炸弹”&#xff0c;广泛分布于城市的各个角落&#xff0c;想要快速、精准定位绝非易事。信息传递的不顺畅更是雪上加霜&#xff0c;导致救援效率大…

重学Redis:Redis常用数据类型+存储结构(源码篇)

一、SDS 1&#xff0c;SDS源码解读 sds (Simple Dynamic String)&#xff0c;Simple的意思是简单&#xff0c;Dynamic即动态&#xff0c;意味着其具有动态增加空间的能力&#xff0c;扩容不需要使用者关心。String是字符串的意思。说白了就是用C语言自己封装了一个字符串类型&a…

抖音IP属地可以随便选择地址吗?深度解析

在当今社交媒体盛行的时代&#xff0c;抖音作为受欢迎的短视频平台之一&#xff0c;其IP属地显示功能引发了广泛关注。许多用户好奇&#xff1a;抖音的IP属地是否可以随意更改&#xff1f;是否存在方法可以“伪装”自己的位置&#xff1f;‌本文将深入探讨这一话题。 一、抖音I…

SOLID原则详解:提升软件设计质量的关键

前言 关于设计原则SOLID具体指的是什么&#xff0c;怎么理解这些设计原则&#xff0c;我觉得有必要记录一笔&#xff0c;毕竟这个设计原则确实经常在关键技术文档中提及&#xff0c;在编程思想中提及&#xff0c;在日常的开发中使用&#xff0c;但是对我来说&#xff0c;似乎知…

如何使用 ONLYOFFICE 恢复之前的文件版本?

如何使用 ONLYOFFICE 恢复之前的文件版本&#xff1f; https://www.onlyoffice.com/blog/zh-hans/2023/04/how-to-use-version-history

简简单单实现一个Python+Selenium的自动化测试框架

什么是Selenium&#xff1f; Selenium是一个基于浏览器的自动化测试工具&#xff0c;它提供了一种跨平台、跨浏览器的端到端的web自动化解决方案。Selenium主要包括三部分&#xff1a;Selenium IDE、Selenium WebDriver 和Selenium Grid。 Selenium IDE&#xff1a;Firefox的…

Java设计模式之中介者模式:从入门到架构级实践

一、什么是中介者模式&#xff1f; 中介者模式&#xff08;Mediator Pattern&#xff09;是一种行为型设计模式&#xff0c;其核心思想是通过引入一个中介对象来封装多个对象之间的交互关系。这种模式将原本复杂的网状通信结构转换为星型结构&#xff0c;类似于现实生活中的机…

Trinity三位一体开源程序是可解释的 AI 分析工具和 3D 可视化

一、软件介绍 文末提供源码和程序下载学习 Trinity三位一体开源程序是可解释的 AI 分析工具和 3D 可视化。Trinity 提供性能分析和 XAI 工具&#xff0c;非常适合深度学习系统或其他执行复杂分类或解码的模型。 二、软件作用和特征 Trinity 通过结合具有超维感知能力的不同交…

LeetCode 热题 100_单词拆分(86_139_中等_C++)(动态规划)

LeetCode 热题 100_单词拆分&#xff08;86_139&#xff09; 题目描述&#xff1a;输入输出样例&#xff1a;题解&#xff1a;解题思路&#xff1a;思路一&#xff08;动态规划&#xff09;&#xff1a; 代码实现代码实现&#xff08;思路一&#xff08;动态规划&#xff09;&a…

VM虚拟机安装及Ubuntu安装配置

VM虚拟机安装及Ubuntu安装配置 1、VM虚拟机安装2、创建虚拟机3、Ubuntu系统安装4、编译环境配置4.1 、Ubuntu和 Windows文件互传 文件互传4.1.1、 开启Ubunt下的FTP服务 4.2、 Ubuntu下NFS和SSH服务开启4.2.1、 NFS服务开启4.2.2、 SSH服务开启 4.3、 交叉编译器安装4.3.1 安装…

【KWDB 创作者计划】_产品技术解读_1

【KWDB 创作者计划】_产品技术解读_1 一、存储引擎:高性能混合存储架构1. 存储模型设计2. 存储压缩与编码3. 持久化策略二、KWDB 组件源码解析1. 核心模块分层架构2. 关键组件源码剖析三、KWDB 特性代码通读1. 实时分析能力(Real-Time OLAP)2. 混合负载隔离(HTAP)3. 智能索…