[大模型]BlueLM-7B-Chat Lora 微调

BlueLM-7B-Chat Lora 微调

概述

本节我们简要介绍如何基于 transformers、peft 等框架,对 BlueLM-7B-Chat 模型进行 Lora 微调。Lora 是一种高效微调方法,深入了解其原理可参见博客:知乎|深入浅出Lora。

这个教程会在同目录下给大家提供一个 [notebook](./04-BlueLM-7B-Chat Lora 微调.ipynb) 文件,来让大家更好的学习。

环境配置

在完成基本环境配置和本地模型部署的情况下,你还需要安装一些第三方库,可以使用以下命令:

pip install transformers==4.35.2
pip install peft==0.4.0
pip install datasets==2.10.1
pip install accelerate==0.20.3
pip install tiktoken
pip install transformers_stream_generator

在本节教程里,我们将微调数据集放置在根目录 /dataset。

指令集构建

LLM 的微调一般指指令微调过程。所谓指令微调,是说我们使用的微调数据形如:

{"instruction": "解释什么是人工智能。\n","input": "","output": "人工智能是一种利用计算机程序和算法创造出类似人类智能的技术,可以让计算机在解决问题、学习、推理和自然语言处理等方面表现出类似人类的能力。"
}

其中,instruction 是用户指令,告知模型其需要完成的任务;input 是用户输入,是完成用户指令所必须的输入内容;output 是模型应该给出的输出。而在BlueLM中数据的目标格式是这样的

{"inputs": "[|Human|]:解释什么是人工智能。\n[|AI|]:", "targets": "人工智能是一种利用计算机程序和算法创造出类似人类智能的技术,可以让计算机在解决问题、学习、推理和自然语言处理等方面表现出类似人类的能力。"}

数据格式化

Lora 训练的数据是需要经过格式化、编码之后再输入给模型进行训练的,如果是熟悉 Pytorch 模型训练流程的同学会知道,我们一般需要将输入文本编码为 input_ids,将输出文本编码为 labels,编码之后的结果都是多维的向量。我们首先定义一个预处理函数,这个函数用于对每一个样本,编码其输入、输出文本并返回一个编码后的字典:

def process_func(example):MAX_LENGTH = 384input_ids = []labels = []instruction = tokenizer(text=f"[|Human|]:现在你要扮演皇帝身边的女人--甄嬛\n\n {example['instruction']}{example['input']}[|AI|]:", add_special_tokens=False)response = tokenizer(text=f"{example['output']}", add_special_tokens=False)input_ids = [tokenizer.bos_token_id] + instruction["input_ids"] + response["input_ids"] + [tokenizer.eos_token_id]labels = [tokenizer.bos_token_id] + [-100] * len(instruction["input_ids"]) + response["input_ids"] + [tokenizer.eos_token_id]if len(input_ids) > MAX_LENGTH:input_ids = input_ids[:MAX_LENGTH]labels = labels[:MAX_LENGTH]return {"input_ids": input_ids,"labels": labels}

经过格式化的数据,也就是送入模型的每一条数据,都是一个字典,包含了 input_idslabels 两个键值对,其中 input_ids 是输入文本的编码,labels 是输出文本的编码。decode之后应该是这样的:

<s> [|Human|]: 现在你要扮演皇帝身边的女人--甄嬛\n\n 这个温太医啊,也是古怪,谁不知太医不得皇命不能为皇族以外的人请脉诊病,他倒好,十天半月便往咱们府里跑。 [|AI|]:  你们俩话太多了,我该和温太医要一剂药,好好治治你们。</s>

为什么会是这个形态呢?好问题!不同模型所对应的格式化输入都不一样,BlueLM只有[|Human|]和[|AI|]两个角色,所以自然而然数据格式就是这样的啦。

加载tokenizer和模型

import torchmodel = AutoModelForCausalLM.from_pretrained('vivo-ai/BlueLM-7B-Chat', trust_remote_code=True, torch_dtype=torch.half, device_map="auto")
model.generation_config = GenerationConfig.from_pretrained('vivo-ai/BlueLM-7B-Chat')
model.generation_config.pad_token_id = model.generation_config.eos_token_id

定义LoraConfig

LoraConfig这个类中可以设置很多参数,但主要的参数没多少,简单讲一讲,感兴趣的同学可以直接看源码。

  • task_type:模型类型
  • target_modules:需要训练的模型层的名字,主要就是attention部分的层,不同的模型对应的层的名字不同,可以传入数组,也可以字符串,也可以正则表达式。
  • rlora的秩,具体可以看Lora原理
  • lora_alphaLora alaph,具体作用参见 Lora 原理

Lora的缩放是啥嘞?当然不是r(秩),这个缩放就是lora_alpha/r, 在这个LoraConfig中缩放就是4倍。

config = LoraConfig(task_type=TaskType.CAUSAL_LM, target_modules=["c_attn", "c_proj", "w1", "w2"],inference_mode=False, # 训练模式r=8, # Lora 秩lora_alpha=32, # Lora alaph,具体作用参见 Lora 原理lora_dropout=0.1# Dropout 比例
)

自定义 TrainingArguments 参数

TrainingArguments这个类的源码也介绍了每个参数的具体作用,当然大家可以来自行探索,这里就简单说几个常用的。

  • output_dir:模型的输出路径
  • per_device_train_batch_size:顾名思义 batch_size
  • gradient_accumulation_steps: 梯度累加,如果你的显存比较小,那可以把 batch_size 设置小一点,梯度累加增大一些。
  • logging_steps:多少步,输出一次log
  • num_train_epochs:顾名思义 epoch
  • gradient_checkpointing:梯度检查,这个一旦开启,模型就必须执行model.enable_input_require_grads(),这个原理大家可以自行探索,这里就不细说了。
args = TrainingArguments(output_dir="./output/Qwen",per_device_train_batch_size=8,gradient_accumulation_steps=2,logging_steps=10,num_train_epochs=3,gradient_checkpointing=True,save_steps=100,learning_rate=1e-4,save_on_each_node=True
)

使用 Trainer 训练

把 model 放进去,把上面设置的参数放进去,数据集放进去,OK!开始训练!

trainer = Trainer(model=model,args=args,train_dataset=tokenized_id,data_collator=DataCollatorForSeq2Seq(tokenizer=tokenizer, padding=True),
)
trainer.train()

模型推理

使用最常用的方式进行推理

text = "小姐,别的秀女都在求中选,唯有咱们小姐想被撂牌子,菩萨一定记得真真儿的——"
inputs = tokenizer(f"[|Human|]:{text}[|AI|]:", return_tensors="pt")
outputs = model.generate(**inputs.to(model.device), max_new_tokens=100)
result = tokenizer.decode(outputs[0], skip_special_tokens=True)
print(result)

完整代码如下:

from datasets import Dataset
import pandas as pd
from transformers import AutoTokenizer, AutoModelForCausalLM, DataCollatorForSeq2Seq, TrainingArguments, Trainer, GenerationConfig
import torch
from peft import LoraConfig, TaskType, get_peft_modeldef process_func(example):MAX_LENGTH = 384input_ids = []labels = []instruction = tokenizer(text=f"[|Human|]:现在你要扮演皇帝身边的女人--甄嬛\n\n {example['instruction']}{example['input']}[|AI|]:",add_special_tokens=False)response = tokenizer(text=f"{example['output']}", add_special_tokens=False)input_ids = [tokenizer.bos_token_id] + instruction["input_ids"] + response["input_ids"] + [tokenizer.eos_token_id]labels = [tokenizer.bos_token_id] + [-100] * len(instruction["input_ids"]) + response["input_ids"] + [tokenizer.eos_token_id]if len(input_ids) > MAX_LENGTH:input_ids = input_ids[:MAX_LENGTH]labels = labels[:MAX_LENGTH]return {"input_ids": input_ids,"labels": labels}# lora配置
config = LoraConfig(task_type=TaskType.CAUSAL_LM,target_modules=["q_proj", "k_proj", "v_proj", "o_proj", "gate_proj", "up_proj", "down_proj"],inference_mode=False, # 训练模式r=8, # Lora 秩lora_alpha=32, # Lora alaph,具体作用参见 Lora 原理lora_dropout=0.1# Dropout 比例
)# 训练参数
args = TrainingArguments(output_dir="./output/BlueLM",per_device_train_batch_size=8,gradient_accumulation_steps=2,logging_steps=10,num_train_epochs=3,save_steps=100,learning_rate=1e-4,save_on_each_node=True,gradient_checkpointing=True
)if __name__ == '__main__':# 将JSON文件转换为CSV文件df = pd.read_json('./huanhuan.json')ds = Dataset.from_pandas(df)# 加载tokenizertokenizer = AutoTokenizer.from_pretrained('vivo-ai/BlueLM-7B-Chat', use_fast=False, trust_remote_code=True)# 将数据集变化为token形式tokenized_id = ds.map(process_func, remove_columns=ds.column_names)# 创建模型model = AutoModelForCausalLM.from_pretrained('vivo-ai/BlueLM-7B-Chat', trust_remote_code=True,torch_dtype=torch.half, device_map="auto")model.enable_input_require_grads()  # 开启梯度检查点时,要执行该方法# 模型合并model = get_peft_model(model, config)# 使用trainer训练trainer = Trainer(model=model,args=args,train_dataset=tokenized_id,data_collator=DataCollatorForSeq2Seq(tokenizer=tokenizer, padding=True),)trainer.train()  # 开始训练

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.mzph.cn/news/814424.shtml

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

VMware 替代专题|金融、制造、医疗等行业用户实践合集(含虚拟化、vSAN、整体替代)

随着 VMware 调整产品组合和订阅模式&#xff0c;不少国内用户都将寻找 VMware 的替代方案提上日程。根据我们在 3 月初 VMware 升级替代研讨会上收集的用户反馈&#xff0c;近 50% 的 VMware 用户已计划使用其他厂商的超融合方案或相关产品替换 VMware 超融合或部分组件。 在…

使用python互相转换AVI、MP4、GIF格式视频文件

一、AVI文件转MP4文件 要将AVI格式的视频转换为 MP4&#xff0c;你可以使用 Python的 moviepy 库。以下是一个示例代码&#xff0c;用于将 AVI 文件转换为 MP4 文件&#xff1a; from moviepy.editor import VideoFileClip# 读取 AVI 文件 clip VideoFileClip("input.a…

【spring】AOP切面注解学习(二)

文接上篇&#xff1a;【spring】AOP切面注解学习&#xff08;一&#xff09; AOP切面注解测试示例代码 示例代码 一 maven的pom文件导入 <dependency><groupId>org.springframework</groupId><artifactId>spring-aop</artifactId></depende…

使用Kotlin进行全栈开发 Ktor+Kotlin/JS

首发于Enaium的个人博客 前言 本文将介绍如何使用 Kotlin 全栈技术栈KtorKotlin/JS来构建一个简单的全栈应用。 准备工作 创建项目 首先我们需要创建一个Kotlin项目&#xff0c;之后继续在其中新建两个子项目&#xff0c;一个是Kotlin/JS项目&#xff0c;另一个是Ktor项目。…

上海计算机学会 2023年10月月赛 丙组T1 三个数的中位数(模拟)

第一题&#xff1a;T1三个数的中位数 标签&#xff1a;模拟题意&#xff1a;给定三个整数&#xff0c;请输出按大小排序后&#xff0c;位于正中间的数字。题解&#xff1a;给三个数从小到大排序&#xff0c;输出中间的即可。代码&#xff1a; #include <bits/stdc.h> u…

itop4412内核编译_编译自定义函数到内核

我的itop4412开发板是半路捡的&#xff0c;所以没办法加他们的售后群&#xff0c;遇到的问题只好一点点记录吧 内核驱动编译 在日常工作过程中&#xff0c;编写内核程序可能机会不多&#xff0c;但是将厂商提供的内核源码编译到固件中&#xff0c;这个技能还是必须掌握的。 i…

每天学习一个Linux命令之w

每天学习一个Linux命令之w 介绍&#xff1a; 在Linux操作系统中&#xff0c;我们经常需要查看当前登录用户信息、系统负载以及其他用户的登录情况。w命令就是一个很常用的命令&#xff0c;它可以提供这些信息。本篇博客将详细介绍w命令及其所有可用的选项&#xff0c;帮助你更…

Redis入门到通关之String命令

文章目录 ⛄1 String 介绍⛄2 命令⛄3 对应 RedisTemplate API❄️❄️ 3.1 添加缓存❄️❄️ 3.2 设置过期时间(单独设置)❄️❄️ 3.3 获取缓存值❄️❄️ 3.4 删除key❄️❄️ 3.5 顺序递增❄️❄️ 3.6 顺序递减 ⛄4 以下是一些常用的API⛄5 应用场景 ⛄1 String 介绍 Stri…

Asterisk 21.2.0编译安装经常遇到的问题和解决办法之卸载pjsip

目录 会安装也要会卸载make uninstallldconfig 会安装也要会卸载 有些人就只会装。 最常见的场景就是需要卸载之前版本的pjproject。 一般来说&#xff0c;其他版本的 pjproject 会被作为静态链接库安装。这些库跟 Asterisk可能不兼容。 因此&#xff0c;在安装正确版本的pjpro…

连锁收银系统哪个好用 国内三大连锁收银系统评比

随着数字化管理趋势下互联网技术的不断发展革新&#xff0c;互联网技术&#xff0c;以及不断升级优化传统行业渠道模式&#xff0c;线上线下结合的电子商务模式正逐渐成为企业发展的趋势。而门店管理系统也在越来越多的企业应用。但市场上连锁店管理系统品牌诸多&#xff0c;很…

生产事故:线程管理不善诱发P0故障

背景 处于业务诉求&#xff0c;需要建立一个统一的调度平台&#xff0c;最终是基于 Dolphinscheduler 的 V1.3.6 版本去做二次开发。在平台调研建立时&#xff0c;这个版本是最新的版本 命运之轮开始转动 事故 表象 上班后业务部门反馈工作流阻塞&#xff0c;登录系统发现大…

设计模式(23):访问者模式

定义 表示一个作用于某对象结构中的各元素的操作&#xff0c;它使我们可以在不改变元素的类的前提下定义作用与这些元素的新操作。 模式动机 对于存储在一个集合中的对象&#xff0c;他们可能具有不同的类型(即使有一个公共的接口)&#xff0c;对于该集合中的对象&#xff0…

Java-博客系统(前后端交互)

目录 前言 博客系统基本情况 1 创建项目&#xff0c;引入依赖 2 数据库设计 2.1 分析 2.2 建库建表 3 封装数据库 3.1 在java目录下创建DBUtil类&#xff0c;通过这个类对数据库进行封装 3.2 在java目录下创建实体类&#xff08;博客类Blog&#xff09; 3.2 在java目录下创建…

docker nginx-lua发送post json 请求

环境准备 dockerfile from fabiocicerchia/nginx-lua:1.25.3-ubuntu22.04 run apt-get -qq update && apt-get -qq install luarocks run luarocks install lua-cjson run luarocks install lua-iconv run luarocks install lua-resty-http后台代理服务准备&#xff…

3D场景编辑方法——CustomNeRF

提示&#xff1a;文章写完后&#xff0c;目录可以自动生成&#xff0c;如何生成可参考右边的帮助文档 文章目录 摘要Abstract文献阅读&#xff1a;3D场景编辑方法——CustomNeRF1、研究背景2、提出方法3、CustomNeRF3.1、整体框架步骤3.2、对特定问题的解决 4、实验结果5、总结…

组合模式:构建树形对象结构的设计艺术

在软件开发中&#xff0c;组合模式是一种结构型设计模式&#xff0c;用于表示对象的部分-整体层次结构。通过使单个对象和组合对象具有相同的接口&#xff0c;这种模式允许客户端以统一的方式处理单个对象和组合对象。本文将详细介绍组合模式的定义、实现、应用场景以及优缺点。…

自动化运维(二十六)Ansible 实战变量插件和连接插件

Ansible 支持多种类型的插件&#xff0c;这些插件可以帮助你扩展和定制 Ansible 的功能。每种插件类型都有其特定的用途和应用场景。今天我们一起学习变量插件和连接插件。 一、变量插件 Ansible 变量插件允许动态地添加变量到主机或组中&#xff0c;这些变量可以在 playbook…

.net Web Api Post请求传递数据

.net c#调用Web Api Post请求传输数据&#xff0c;用.net8一直传不了自定义的json格式数据&#xff0c;后面找到用实体传递Api那边用一样字段的实体接收才能正常传输数据。记录一下 var mails new {Name "tt",Hobby "test" }; string json JsonConv…

2024HW--->入侵排查

在蓝队的面试中&#xff0c;我们有可能会被问到对可能被入侵的机器&#xff0c;怎么样去排查&#xff0c;下面就来总结一下 1.Windows入侵排查 1.检查系统账号的安全 检测系统账号&#xff0c;其实最重要的就是一个点 "查看服务器是否存在可疑账号、新增账号。" 最…

数据结构课程设计选做(一)---数字排序(哈希、排序)

2.1.1 题目内容 2.1.1-A [问题描述] 给定n个整数&#xff0c;请统计出每个整数出现的次数&#xff0c;按出现次数从多到少的顺序输出。 2.1.1-B [基本要求] &#xff08;1&#xff09;输入格式&#xff1a; 输入的第一行包含一个整数n&#xff0c;表示给定数字的个数。 第二…