LLaMA-Factory GLM4-9B-CHAT LoRA 微调实战

🤩LLaMA-Factory GLM LoRA 微调

安装llama-factory包

git clone --depth 1 https://github.com/hiyouga/LLaMA-Factory.git

进入下载好的llama-factory,安装依赖包

cd LLaMA-Factory
pip install -e ".[torch,metrics]"
#上面这步操作会完成torch、transformers、datasets等相关依赖包的安装

LLaMAFactory目录结构

大模型下载

modelscope代码方式下载

import torch
from modelscope import snapshot_download, AutoModel, AutoTokenizer
import os
model_dir = snapshot_download('ZhipuAI/glm-4-9b-chat', cache_dir='/root/autodl-tmp', revision='master')

git clone方式下载

git lfs install # 大文件传输
sudo apt-get install git-lfsgit clone https://www.modelscope.cn/ZhipuAI/glm-4-9b-chat.git

指令数据集构建-Alpaca 格式

Alpaca 格式是一种用于训练自然语言处理(NLP)模型的数据集格式,特别是在对话系统和问答系统中。这种格式通常包含指令(instruction)、输入(input)和输出(output)三个部分,它们分别对应模型的提示、模型的输入和模型的预期输出。三者的数据都是字符串形式

Alpaca 格式训练数据集

[{"instruction": "Answer the following question about the movie 'Inception'.","input": "What is the main theme of the movie?","output": "The main theme of the movie 'Inception' is the exploration of dreams and reality."},{"instruction": "Provide a summary of the book '1984' by George Orwell.","input": "What is the book '1984' about?","output": "The book '1984' is a dystopian novel that tells the story of a totalitarian regime and the protagonist's struggle against it."}
]

我的数据集构建如下

crop_train.json

[{"instruction": "你是农作物领域专门进行关系抽取的专家。请从给定的文本中抽取出关系三元组,不存在的关系返回空列表。请按照JSON字符串的格式回答。","input": "煤是一种常见的化石燃料,家庭用煤经过了从\"煤球\"到\"蜂窝煤\"的演变。","output": "[{\"head\": \"煤\", \"relation\": \"use\", \"tail\": \"燃料\"}]"},{"instruction": "你是农作物领域专门进行关系抽取的专家。请从给定的文本中抽取出关系三元组,不存在的关系返回空列表。请按照JSON字符串的格式回答。","input": "内分泌疾病是指内分泌腺或内分泌组织本身的分泌功能和(或)结构异常时发生的症候群。","output": "[{\"head\": \"腺\", \"relation\": \"use\", \"tail\": \"分泌\"}]"},
]

我的数据格式转换代码:

import json
import re# 选择要格式转换的数据集
file_name = "merged_trainProcess.json"# 读取原始数据
with open(f'./{file_name}', 'r', encoding='utf-8') as file:data = json.load(file)# 转换数据格式
converted_data = [{"instruction": item["instruction"],"input": item["text"],"output": json.dumps(item["triplets"], ensure_ascii=False),} for item in data]# 将转换后的数据写入新文件
output_file_name = f'processed_{file_name}'
with open(output_file_name, 'w', encoding='utf-8') as file:json.dump(converted_data, file, ensure_ascii=False, indent=4)print(f'{output_file_name} Done')

构建好后保存到llama-factory目录中某文件下

修改 LLaMa-Factory 目录中的 data/dataset_info.json 文件,在其中添加:

"crop_merged": {"file_name": "/home/featurize/data/crop_train.json"		#自己的训练数据集.json文件的绝对路径}

微调模型代码

在 LLaMA-Factory 目录中**新建配置文件 crop_glm4_lora_sft.yaml :**

### model:glm-4-9b-chat模型地址的绝对路径
model_name_or_path: /home/featurize/glm-4-9b-chat### method
stage: sft			# supervised fine-tuning(监督式微调)
do_train: true		# 是否执行训练过程
finetuning_type: lora		# 微调技术的类型
lora_target: all			# all 表示对模型的所有参数进行 LoRA 微调### dataset
# dataset 要和 data/dataset_info.json 中添加的信息保持一致
dataset: crop_merged			# 数据集的名称
template: glm4					# 数据集的模板类型
cutoff_len: 2048				# 输入序列的最大长度
max_samples: 100000				# 最大样本数量, 代表在训练过程中最多只会用到训练集的1000条数据;-1代表训练所有训练数据集
overwrite_cache: true	
preprocessing_num_workers: 16		# 数据预处理时使用的进程数			### output
# output_dir是模型训练过程中的checkpoint,训练日志等的保存目录
output_dir: saves/crop-glm4-epoch10/lora/sft
logging_steps: 10		# 日志记录的频率
#save_steps: 500
plot_loss: true			# 绘制损失曲线
overwrite_output_dir: true		# 是否覆盖输出目录中的现有内容
save_strategy: epoch			# 保存策略,epoch 表示每个 epoch 结束时保存。### train
per_device_train_batch_size: 1		# 每个设备上的批次大小,每增加几乎显存翻倍
gradient_accumulation_steps: 8		# 梯度累积的步数,这里设置为 8,意味着每 8 步执行一次梯度更新。
learning_rate: 1.0e-4				# 学习率
num_train_epochs: 3				# 训练的 epoch 数
lr_scheduler_type: cosine			#学习率调度器的类型
warmup_ratio: 0.1
fp16: true							# 更适合模型推理,混合精度训练,在训练中同时使用FP16和FP32
#bf16: true							# 更适合训练阶段
gradient_checkpointing: true		# 启用梯度检查点(gradient checkpointing),减少中间激活值的显存占用### eval
do_eval: false				# 是否进行评估
val_size: 0.1				# 验证集的大小比例,这里设置为 0.1,即 10%
per_device_eval_batch_size: 1		# 每个设备上的评估批次大小
eval_strategy: steps		# 评估策略,steps 表示根据步数进行评估
eval_steps: 1000				# 评估的步数间隔
### model 
model_name_or_path: /home/featurize/glm-4-9b-chat### method
stage: sft
do_train: true
finetuning_type: lora
lora_target: all### dataset
dataset: crop_merged
template: glm4
cutoff_len: 512
overwrite_cache: true
preprocessing_num_workers: 16### output
output_dir: saves/crop-glm4-epoch10/lora/sft
logging_steps: 100
save_steps: 1000
plot_loss: true
overwrite_output_dir: true### train
per_device_train_batch_size: 1
gradient_accumulation_steps: 8
learning_rate: 1.0e-4
num_train_epochs: 1
lr_scheduler_type: cosine
warmup_ratio: 0.1
bf16: true
gradient_checkpointing: true### eval
val_size: 0.1
per_device_eval_batch_size: 1
eval_strategy: epoch

执行以下命令开始微调:

cd LLaMA-Factory
llamafactory-cli train crop_glm4_lora_sft.yaml

合并模型代码

训练完成后,在 LLaMA-Factory 目录中**新建配置文件 crop_glm4_lora_sft_export.yaml:**

导出的模型将被保存在models/CropLLM-glm-4-9b-chat目录下。

### model
model_name_or_path: /home/featurize/glm-4-9b-chat
# 刚才crop_glm4_lora_sft.yaml文件中的 output_dir
adapter_name_or_path: saves/crop-glm4-epoch10/lora/sft
template: glm4
finetuning_type: lora### export
export_dir: models/CropLLM-glm-4-9b-chat
export_size: 2
export_device: cpu
export_legacy_format: false

半精度(FP16、BF16)

执行以下命令开始合并模型:

cd LLaMA-Factory
llamafactory-cli export crop_glm4_lora_sft_export.yaml

models/CropLLM-glm-4-9b-chat 目录中就可以获得经过Lora微调后的完整模型

在这里插入图片描述

推理预测

import torch
from transformers import AutoModelForCausalLM, AutoTokenizermodel_path = "/home/featurize/LLaMA-Factory/models/CropLLM-glm-4-9b-chat"device = "cuda"
tokenizer = AutoTokenizer.from_pretrained(model_path, trust_remote_code=True)system_prompt = "你是农作物领域专门进行关系抽取的专家。请从给定的文本中抽取出关系三元组,不存在的关系返回空列表。请按照JSON字符串的格式回答。"
input = "玉米黑粉病又名“乌霉”和“瘤黑粉病”,病原菌是真菌,:担孢子菌,是由于玉米黑粉菌所引起的一种局部浸染性病害。病瘤内的黑粉是病菌的"inputs = tokenizer.apply_chat_template([{"role": "system", "content": system_prompt},{"role": "user", "content": input}],add_generation_prompt=True,tokenize=True,return_tensors="pt",return_dict=True)inputs = inputs.to(device)
model = AutoModelForCausalLM.from_pretrained(model_path,torch_dtype=torch.bfloat16,low_cpu_mem_usage=True,trust_remote_code=True
).to(device).eval()gen_kwargs = {"max_length": 2500, "do_sample": True, "top_k": 1}
with torch.no_grad():outputs = model.generate(**inputs, **gen_kwargs)outputs = outputs[:, inputs['input_ids'].shape[1]:]print(f"model: {model_path}")print(f"task: {system_prompt}")print(f"input: {input}")output = tokenizer.decode(outputs[0], skip_special_tokens=True)print(f"output: {output}")

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.mzph.cn/bicheng/65076.shtml

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

view draw aosp15

基础/背景知识 如何理解Drawable? 在 Android 中,Drawable 是一个抽象的概念,表示可以绘制到屏幕上的内容。 它可以是位图图像、矢量图形、形状、颜色等。 Drawable 本身并不是一个 View,它不能直接添加到布局中,而是…

gridcontrol表格某一列设置成复选框,选择多行(repositoryItemCheckEdit1)

1. 往表格中添加repositoryItemCheckEdit1 2. 事件: repositoryItemCheckEdit1.QueryCheckStateByValue repositoryItemCheckEdit1_QueryCheckStateByValue; private void repositoryItemCheckEdit1_QueryCheckStateByValue(object sender, DevExpress.XtraEditor…

重温设计模式--适配器模式

文章目录 适配器模式(Adapter Pattern)概述适配器模式UML图适配器模式的结构目标接口(Target):适配器(Adapter):被适配者(Adaptee): 作用&#xf…

C语言项目 天天酷跑(上篇)

前言 这里讲述这个天天酷跑是怎么实现的,我会在天天酷跑的下篇添加源代码,这里会讲述天天酷跑这个项目是如何实现的每一个思路,都是作者自己学习于别人的代码而创作的项目和思路,这个代码和网上有些许不一样,因为掺杂了…

公交车信息管理系统:构建智能城市交通的基石

程序设计 本系统主要使用Java语言编码设计功能,MySQL数据库管控数据信息,SSM框架创建系统架构,通过这些关键技术对系统进行详细设计,设计和实现系统相关的功能模块。最后对系统进行测试,这一环节的结果,基本…

MDS-NPV/NPIV

在存储区域网络(SAN)中,域ID(Domain ID)是一个用于区分不同存储区域的关键参数。域ID允许SAN环境中的不同部分独立操作,从而提高效率和安全性。以下是关于域ID的一些关键信息: 域ID的作用&…

【网络安全产品大调研系列】1. 漏洞扫描

1. 为什么会出现漏扫技术? 每次黑客攻击事件进行追溯的时候,根据日志分析后,我们往往发现基本都是系统、Web、 弱口令、配置这四个方面中的其中一个出现的安全问题导致黑客可以轻松入侵的。 操作系统的版本滞后,没有更新补丁&am…

验证 Dijkstra 算法程序输出的奥秘

一、引言 Dijkstra 算法作为解决图中单源最短路径问题的经典算法,在网络路由、交通规划、资源分配等众多领域有着广泛应用。其通过不断选择距离源节点最近的未访问节点,逐步更新邻居节点的最短路径信息,以求得从源节点到其他所有节点的最短路径。在实际应用中,确保 Dijkst…

【论文阅读笔记】Learning to sample

Learning to sample 前沿引言方法问题声明S-NET匹配ProgressiveNet: sampling as ordering 实验分类检索重建 结论附录 前沿 这是一篇比较经典的基于深度学习的点云下采样方法 核心创新点: 首次提出了一种学习驱动的、任务特定的点云采样方法引入了两种采样网络&…

Mysql大数据量表分页查询性能优化

一、模拟场景 1、产品表t_product,数据量500万+ 2、未做任何优化前,cout查询时间大约4秒;LIMIT offset, count 时,offset 值较大时查询时间越久。 count查询 SELECT COUNT(*) AS total FROM t_product WHERE deleted = 0 AND tenant_id = 1 分页查询 SELECT * FROM t_…

pythonWeb~伍~初识Django

初识Django 1.技术栈 Python知识点:函数、面向对象。前端知识点:HTML、CSS、JavaScript、jQuery、BootStrap。MySQL数据库。Python的Web框架: Flask,自身短小精悍 第三方组件。Django,内部已集成了很多组件 第三方…

kimi搜索AI多线程批量生成txt原创文章软件-不需要账号及key

kimi搜索AI多线程批量生成txt原创文章软件介绍: 软件可以设置三种模型写文章:kimi:默认AI模型,kimi-search:联网检索模型 ,kimi-research:探索版搜索聚合模型 1、可以设置写联网搜索文章&#…

DevNow x Notion

前言 Notion 应该是目前用户量比较大的一个在线笔记软件,它的文档系统也非常完善,支持多种文档格式,如 Markdown、富文本、表格、公式等。 早期我也用过一段时间,后来有点不习惯,就换到了 Obsidian ,但是…

CSPM认证最推荐学习哪个级别?

一、什么是CSPM? CSPM的全称是Certified Strategic Project Manager,中文名称为“项目管理专业人员能力评价等级证书”。这是由中国标准化协会依据国家标准《项目管理专业人员能力评价要求》(GB/T 41831-2022)推出的一项认证&…

oracle: create new database

用database configuration Assistant 引导创建数据库。记得给system,sys 设置自己的口令,便于添加新操作用户。 创建操作用户: -- 别加双引号,否则,无法用 create user geovindu identified by 888888; create user geovin identi…

快速建站(网站如何在自己的电脑里跑起来) 详细步骤 一

1.选择开源网站 开源网站的建设平台有多种类型,每种类型都针对不同的需求和用途。我们根据自己的需求来选择 一、内容管理系统(CMS) 内容管理系统是开源网站建设中最常见的类型之一,它们提供了一个易于使用的界面,使用…

ECharts散点图-气泡图,附视频讲解与代码下载

引言: ECharts散点图是一种常见的数据可视化图表类型,它通过在二维坐标系或其它坐标系中绘制散乱的点来展示数据之间的关系。本文将详细介绍如何使用ECharts库实现一个散点图,包括图表效果预览、视频讲解及代码下载,让你轻松掌握…

【蓝桥杯——物联网设计与开发】拓展模块5 - 光敏/热释电模块

目录 一、光敏/热释电模块 (1)资源介绍 🔅原理图 🔅AS312 🌙简介 🌙特性 🔅LDR (2)STM32CubeMX 软件配置 (3)代码编写 (4&#x…

Escalate_Linux靶机

Escalate_Linux靶机 前言:集合了多种liunx提权方法的靶场,通过该靶场可以简单的了解liunx提权方法 1,扫描一下端口 80/tcp open http 111/tcp open rpc 2049/udp nfs要知道对方的共享才能挂载 139/445 Samba SMB是一个协议名&#xff0c…

PPO算法基础(一)

PPO近端策略优化算法 我们今天还是主要来理解PPO算法的数学原理。PPO是一种策略梯度方法,简单的策略梯度对每个样本(或者一组样本)进行一次梯度更新,对单个样本执行多个梯度步骤会导致一些问题,因为梯度偏差太大&…