【LLM训练系列02】如何找到一个大模型Lora的target_modules

方法1:观察attention中的线性层

import numpy as np
import pandas as pd
from peft import PeftModel
import torch
import torch.nn.functional as F
from torch import Tensor
from transformers import AutoTokenizer, AutoModel, BitsAndBytesConfig
from typing import List
from tqdm.auto import tqdm
from sentence_transformers import SentenceTransformer
import os
os.environ['CUDA_VISIBLE_DEVICES']='1,2'
os.environ["TOKENIZERS_PARALLELISM"] = "false"model_path ="/home/jovyan/codes/llms/Qwen2.5-14B-Instruct"
base_model = AutoModel.from_pretrained(model_path, device_map='cuda:0',trust_remote_code=True)

打印attention模型层的名字

for name, module in base_model.named_modules():if 'attn' in name or 'attention' in name:  # Common attention module namesprint(name)for sub_name, sub_module in module.named_modules():  # Check sub-modules within attentionprint(f"  - {sub_name}")

方法2:通过bitsandbytes量化查找线性层

import bitsandbytes as bnb
def find_all_linear_names(model):lora_module_names = set()for name, module in model.named_modules():if isinstance(module, bnb.nn.Linear4bit):names = name.split(".")# model-specificlora_module_names.add(names[0] if len(names) == 1 else names[-1])if "lm_head" in lora_module_names:  # needed for 16-bitlora_module_names.remove("lm_head")return list(lora_module_names)

加载模型

bnb_config = BitsAndBytesConfig(load_in_4bit=True,bnb_4bit_use_double_quant=True,bnb_4bit_quant_type="nf4",bnb_4bit_compute_dtype=torch.bfloat16)
base_model = AutoModel.from_pretrained(model_path,quantization_config=bnb_config,device_map="auto")

查找Lora的目标层

find_all_linear_names(base_model)


还有个函数,一样的原理

def find_target_modules(model):# Initialize a Set to Store Unique Layersunique_layers = set()# Iterate Over All Named Modules in the Modelfor name, module in model.named_modules():# Check if the Module Type Contains 'Linear4bit'if "Linear4bit" in str(type(module)):# Extract the Type of the Layerlayer_type = name.split('.')[-1]# Add the Layer Type to the Set of Unique Layersunique_layers.add(layer_type)# Return the Set of Unique Layers Converted to a Listreturn list(unique_layers)find_target_modules(base_model)

方法3:通过分析开源框架的源码swift

代码地址

from collections import OrderedDict
from dataclasses import dataclass, field
from typing import List, Union@dataclass
class ModelKeys:model_type: str = Nonemodule_list: str = Noneembedding: str = Nonemlp: str = Nonedown_proj: str = Noneattention: str = Noneo_proj: str = Noneq_proj: str = Nonek_proj: str = Nonev_proj: str = Noneqkv_proj: str = Noneqk_proj: str = Noneqa_proj: str = Noneqb_proj: str = Nonekva_proj: str = Nonekvb_proj: str = Noneoutput: str = None@dataclass
class MultiModelKeys(ModelKeys):language_model: Union[List[str], str] = field(default_factory=list)connector: Union[List[str], str] = field(default_factory=list)vision_tower: Union[List[str], str] = field(default_factory=list)generator: Union[List[str], str] = field(default_factory=list)def __post_init__(self):# compatfor key in ['language_model', 'connector', 'vision_tower', 'generator']:v = getattr(self, key)if isinstance(v, str):setattr(self, key, [v])if v is None:setattr(self, key, [])LLAMA_KEYS = ModelKeys(module_list='model.layers',mlp='model.layers.{}.mlp',down_proj='model.layers.{}.mlp.down_proj',attention='model.layers.{}.self_attn',o_proj='model.layers.{}.self_attn.o_proj',q_proj='model.layers.{}.self_attn.q_proj',k_proj='model.layers.{}.self_attn.k_proj',v_proj='model.layers.{}.self_attn.v_proj',embedding='model.embed_tokens',output='lm_head',
)INTERNLM2_KEYS = ModelKeys(module_list='model.layers',mlp='model.layers.{}.feed_forward',down_proj='model.layers.{}.feed_forward.w2',attention='model.layers.{}.attention',o_proj='model.layers.{}.attention.wo',qkv_proj='model.layers.{}.attention.wqkv',embedding='model.tok_embeddings',output='output',
)CHATGLM_KEYS = ModelKeys(module_list='transformer.encoder.layers',mlp='transformer.encoder.layers.{}.mlp',down_proj='transformer.encoder.layers.{}.mlp.dense_4h_to_h',attention='transformer.encoder.layers.{}.self_attention',o_proj='transformer.encoder.layers.{}.self_attention.dense',qkv_proj='transformer.encoder.layers.{}.self_attention.query_key_value',embedding='transformer.embedding',output='transformer.output_layer',
)BAICHUAN_KEYS = ModelKeys(module_list='model.layers',mlp='model.layers.{}.mlp',down_proj='model.layers.{}.mlp.down_proj',attention='model.layers.{}.self_attn',qkv_proj='model.layers.{}.self_attn.W_pack',embedding='model.embed_tokens',output='lm_head',
)YUAN_KEYS = ModelKeys(module_list='model.layers',mlp='model.layers.{}.mlp',down_proj='model.layers.{}.mlp.down_proj',attention='model.layers.{}.self_attn',qk_proj='model.layers.{}.self_attn.qk_proj',o_proj='model.layers.{}.self_attn.o_proj',q_proj='model.layers.{}.self_attn.q_proj',k_proj='model.layers.{}.self_attn.k_proj',v_proj='model.layers.{}.self_attn.v_proj',embedding='model.embed_tokens',output='lm_head',
)CODEFUSE_KEYS = ModelKeys(module_list='gpt_neox.layers',mlp='gpt_neox.layers.{}.mlp',down_proj='gpt_neox.layers.{}.mlp.dense_4h_to_h',attention='gpt_neox.layers.{}.attention',o_proj='gpt_neox.layers.{}.attention.dense',qkv_proj='gpt_neox.layers.{}.attention.query_key_value',embedding='gpt_neox.embed_in',output='gpt_neox.embed_out',
)PHI2_KEYS = ModelKeys(module_list='transformer.h',mlp='transformer.h.{}.mlp',down_proj='transformer.h.{}.mlp.c_proj',attention='transformer.h.{}.mixer',o_proj='transformer.h.{}.mixer.out_proj',qkv_proj='transformer.h.{}.mixer.Wqkv',embedding='transformer.embd',output='lm_head',
)QWEN_KEYS = ModelKeys(module_list='transformer.h',mlp='transformer.h.{}.mlp',down_proj='transformer.h.{}.mlp.c_proj',attention='transformer.h.{}.attn',o_proj='transformer.h.{}.attn.c_proj',qkv_proj='transformer.h.{}.attn.c_attn',embedding='transformer.wte',output='lm_head',
)PHI3_KEYS = ModelKeys(module_list='model.layers',mlp='model.layers.{}.mlp',down_proj='model.layers.{}.mlp.down_proj',attention='model.layers.{}.self_attn',o_proj='model.layers.{}.self_attn.o_proj',qkv_proj='model.layers.{}.self_attn.qkv_proj',embedding='model.embed_tokens',output='lm_head',
)PHI3_SMALL_KEYS = ModelKeys(module_list='model.layers',mlp='model.layers.{}.mlp',down_proj='model.layers.{}.mlp.down_proj',attention='model.layers.{}.self_attn',o_proj='model.layers.{}.self_attn.dense',qkv_proj='model.layers.{}.self_attn.query_key_value',embedding='model.embed_tokens',output='lm_head',
)DEEPSEEK_V2_KEYS = ModelKeys(module_list='model.layers',mlp='model.layers.{}.mlp',down_proj='model.layers.{}.mlp.down_proj',attention='model.layers.{}.self_attn',o_proj='model.layers.{}.self_attn.o_proj',qa_proj='model.layers.{}.self_attn.q_a_proj',qb_proj='model.layers.{}.self_attn.q_b_proj',kva_proj='model.layers.{}.self_attn.kv_a_proj_with_mqa',kvb_proj='model.layers.{}.self_attn.kv_b_proj',embedding='model.embed_tokens',output='lm_head',
)

我的博客即将同步至腾讯云开发者社区,邀请大家一同入驻:https://cloud.tencent.com/developer/support-plan?invite_code=3hiaca88ulogc

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.mzph.cn/diannao/61418.shtml

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

05_Spring JdbcTemplate

在继续了解Spring的核心知识前,我们先看看Spring的一个模板类JdbcTemplate,它是一个JDBC的模板类,用来简化JDBC的操作。 接下来以实际来进行说明 一、实例环境准备 数据库及表准备 我们在本地mysql中新增一个数据库test,并新增一张数据表:user create database if not…

萨瑞MCU R7FA8D1BH环境搭建教程

萨瑞MCU R7FA8D1BH环境搭建教程 如果你是大学生 遇到电子技术 学习 成长 入行难题 佳喔威信,给你提供一定资源和战略方法上的帮助 相信我的专业职业经历一定能帮到你 目录 概述 2. 开发板介绍3. 搭建rtthread环境4. 安装瑞萨的keil环境5. 搭建瑞萨的keil辅助环境…

鸿蒙实战:使用显式Want启动Ability

文章目录 1. 实战概述2. 实现步骤2.1 创建鸿蒙应用项目2.2 修改Index.ets代码2.3 创建SecondAbility2.4 创建Second.ets 3. 测试效果4. 实战总结5. 拓展练习 - 启动文件管理器5.1 创建鸿蒙应用项目5.2 修改Index.ets代码5.3 测试应用运行效果 1. 实战概述 本实战详细阐述了在 …

【Nginx】反向代理Https时相关参数:

在Nginx代理后台HTTPS服务时,有几个关键的参数需要配置,以确保代理服务器能够正确地与后端服务器进行通信。一些重要参数的介绍: proxy_ssl_server_name:这个参数用于指定是否在TLS握手时通过SNI(Server Name Indicati…

PH热榜 | 2024-11-19

DevNow 是一个精简的开源技术博客项目模版,支持 Vercel 一键部署,支持评论、搜索等功能,欢迎大家体验。 在线预览 1. Layer 标语:受大脑启发的规划器 介绍:体验一下这款新一代的任务和项目管理系统吧!它…

React Native 基础

React 的核心概念 定义函数式组件 import组件 要定义一个Cat组件,第一步要使用 import 语句来引入React以及React Native的 Text 组件: import React from react; import { Text } from react-native; 定义函数作为组件 const CatApp = () => {}; 渲染Text组件

一文详细了解websocket应用以及连接断开的解决方案

文章目录 websocketvite 热启动探索websocket -心跳websocket 事件监听应用过程中问题总结 websocket Websocket简介 定义和工作原理 Websocket是一种在单个TCP连接上进行全双工通信的协议。与传统的HTTP请求 - 响应模式不同,它允许服务器主动向客户端推送数据。例…

Vue 3与TypeScript集成指南:构建类型安全的前端应用

在Vue 3中使用TypeScript,可以让你的组件更加健壮和易于维护。以下是使用TypeScript与Vue 3结合的详细步骤和知识点: 1. 环境搭建 首先,确保你安装了Node.js(推荐使用最新的LTS版本)和npm或Yarn。然后,安…

React-useRef与DOM操作

#题引:我认为跟着官方文档学习不会走歪路 ref使用 组件重新渲染时,react组件函数里的代码会重新执行,返回新的JSX,当你希望组件“记住”某些信息,但又不想让这些信息触发新的渲染时,你可以使用ref&#x…

# Spring事务

Spring事务 什么是spring的事务? 在Spring框架中,事务管理是一种控制数据库操作执行边界的技术,确保一系列操作要么全部成功,要么全部失败,从而维护数据的一致性和完整性。Spring的事务管理主要关注以下几点&#xf…

Jenkins更换主题颜色+登录页面LOGO图片

默认主题和logo图片展示 默认主题黑色和白色。 默认LOGO图片 安装插件 Login ThemeMaterial Theme 系统管理–>插件管理–>Available plugins 搜不到Login Theme是因为我提前装好了 没有外网的可以参考这篇离线安装插件 验证插件并修改主题颜色 系统管理–>A…

LLM文档对话 —— pdf解析关键问题

一、为什么需要进行pdf解析? 最近在探索ChatPDF和ChatDoc等方案的思路,也就是用LLM实现文档助手。在此记录一些难题和解决方案,首先讲解主要思想,其次以问题回答的形式展开。 二、为什么需要对pdf进行解析? 当利用L…

【虚幻引擎】UE5数字人开发实战教程

本套课程将会交大家如何去开发属于自己的数字人,包含大模型接入,流式输出,语音识别,语音合成,口型驱动,动画蓝图,语音唤醒等功能。 课程介绍视频如下: 【虚幻引擎】UE5 历时一个多月…

上位机编程命名规范

1.大小写规范 文件名全部小写是一种广泛使用的命名约定,特别是在跨平台开发和开源项目中。主要原因涉及技术约束、可读性和一致性等方面。以下是原因和优劣势的详细分析: 1. 避免跨平台问题 不同操作系统对文件名的大小写处理方式不同: Li…

JAVA:探索 PDF 文字提取的技术指南

1、简述 随着信息化的发展,PDF 文档成为了信息传播的重要媒介。在许多应用场景下,如数据迁移、内容分析和信息检索,我们需要从 PDF 文件中提取文字内容。JAVA提供了多种库来处理 PDF 文件,其中 PDFBox 和 iText 是最常用的两个。…

form表单的使用

模板 <template><el-form :model"formData" ref"form1Ref" :rules"rules"><el-form-item label"手机号" prop"tel"><el-input v-model"formData.tel" /></el-form-item><el-f…

【priority_queue的使用及模拟实现】—— 我与C++的不解之缘(十六)

前言 ​ priority_queue&#xff0c;翻译过来就是优先级队列&#xff0c;但是它其实是我们的堆结构&#xff08;如果堆一些遗忘的可以看一下前面的文章复习一下【数据结构】二叉树——顺序结构——堆及其实现_二叉树顺序结构-CSDN博客&#xff09;&#xff0c;本篇文章就来使用…

php 与 thinkphp 13 张 表 关联 查询,a.pry_key=b.pry_key and c.pry_key= b.pry_key 代码示例

在 PHP 中&#xff0c;假设你有 13 张表并且这些表之间通过 pry_key 关联&#xff0c;你可以使用 SQL 的 JOIN 来将这些表连接在一起&#xff0c;然后通过 PHP 执行该查询。以下是一个简化的示例&#xff0c;展示如何通过 JOIN 语句将 13 张表联接&#xff0c;并使用 PHP 代码执…

MacOS下的Opencv3.4.16的编译

前言 MacOS下编译opencv还是有点麻烦的。 1、Opencv3.4.16的下载 注意&#xff0c;我们使用的是Mac&#xff0c;所以ios pack并不能使用。 如何嫌官网上下载比较慢的话&#xff0c;可以考虑在csdn网站上下载&#xff0c;应该也是可以找到的。 2、cmake的下载 官网的链接&…

Kibana 本地安装使用

一 Kibana简介 1.1 Kibana 是一种数据可视化工具&#xff0c;通常需要结合Elasticsearch使用&#xff1a; Elasticsearch 是一个实时分布式搜索和分析引擎。 Logstash 为用户提供数据采集、转换、优化和输出的能力。 Kibana 是一种数据可视化工具&#xff0c;为 Elasticsear…