使用 LlamaFactory 结合开源大语言模型实现文本分类:从数据集构建到 LoRA 微调与推理评估

文章目录

    • 背景介绍
      • 文本分类数据集
      • Lora 微调
      • 模型部署与推理
        • 期待模型的输出结果
    • 文本分类评估代码

背景介绍

本文将一步一步地,介绍如何使用llamafactory框架利用开源大语言模型完成文本分类的实验,以 LoRA微调 qwen/Qwen2.5-7B-Instruct 为例。

文本分类数据集

按照 alpaca 样式构建数据集,并在将其添加到 LLaMA-Factory/data/dataset_info.json 文件中。如此方便直接根据自定义数据集的名字,获取到数据集的数据。

[{"instruction": "","input": "请将以下文本分类到一个最符合的类别中。以下是类别及其定义:\n\n要求}}\nreason: \nlabel:","output": "reason: 该文本主要讨论的是xxx。因此,该文本最符合“社会管理”这一类别。\n\nlabel: 社会管理"},...
]

Lora 微调

llamafactory 框架支持网页端训练,但本文选择在终端使用命令行微调模型。

模型微调训练的参数较多,将模型训练的参数都存储在 yaml 文件中。

qwen_train_cls.yaml 的文件内容如下:

### model
model_name_or_path: qwen/Qwen2.5-7B-Instruct### method
stage: sft
do_train: true
finetuning_type: lora
lora_target: all### dataset
# dataset_dir: data
dataset_dir: LLaMA-Factory/data/ 填写相应路径
dataset: 数据集名 
template: qwen
cutoff_len: 2048
# max_samples: 1000 若数据集较大,可随机筛选一部分数据微调模型
overwrite_cache: true
preprocessing_num_workers: 16### output
output_dir: output/qwen2.5-7B/cls_epoch2 训练的LoRA权重输出路径
logging_steps: 10
save_steps: 500
plot_loss: true
overwrite_output_dir: true### train
per_device_train_batch_size: 1
gradient_accumulation_steps: 8
learning_rate: 1.0e-4
num_train_epochs: 2.0
lr_scheduler_type: cosine
warmup_ratio: 0.1
bf16: true
ddp_timeout: 180000000### eval
# val_size: 0.1
# per_device_eval_batch_size: 1
# eval_strategy: steps
# eval_steps: 500

使用下述命令启动模型训练:

nohup llamafactory-cli train qwen_train_cls.yaml > qwen_train_cls.log 2>&1 &

命令分解介绍:
nohup, 全称为 “no hangup”(不要挂起)。它的作用是让命令在退出终端后仍然运行,防止因关闭终端或会话中断导致进程被终止。
默认情况下,nohup 会将输出重定向到 nohup.out 文件,但这里已经显式指定了输出位置。
llamafactory-cli train qwen_train_cls.yaml 运行 llamafactory-cli 工具,用于执行训练任务。
train 是子命令,表示进行训练。
qwen_train_cls.yaml 是一个配置文件,包含训练所需的超参数、数据路径、模型结构等。
qwen_train_cls.log
将标准输出 (stdout) 重定向到 qwen_train_cls.log 文件中。
即运行过程中的正常日志信息会被记录到这个文件。
2>&1: 将标准错误输出 (stderr) 重定向到标准输出 (stdout)。
这样,所有错误信息也会被写入到 qwen_train_cls.log 文件中。
&: 表示将整个命令放到后台运行。终端会立即返回,您可以继续进行其他操作,而不用等待命令完成。

模型部署与推理

模型训练完成后得到 Lora 权重。相关微调模型部署与推理,请浏览下述两篇文章,相比llamafactory原本的模型推理速度更快。

  • 基于 LLamafactory 的异步API高效调用实现与速度对比.https://blog.csdn.net/sjxgghg/article/details/144176645
  • 基于 LlamaFactory 的 LoRA 微调模型支持 vllm 批量推理的实现

目前llamafactory已经支持 vllm_infer 推理,这个PR是笔者提交的:

  • llamafactory vllm.https://github.com/hiyouga/LLaMA-Factory/blob/main/scripts/vllm_infer.py
期待模型的输出结果

下述是使用 llamafactory 推理出的数据格式,建议大家在做推理评估时,也做成这个样式,方便统一评估。

{"prompt": "请将以下文本分类到一个最符合的类别中。以下是类别及其定义:...", "predict": "\nreason: 该文本主要讨论了改革创新发展、行政区划调整、行政管理体制等方面的内容,涉及到体制机制的改革与完善,旨在推动高质量发展和提升生活品质。这些内容与社会管理和经济管理密切相关,但更侧重于行政管理和社会治理的改革,因此更符合“社会管理”这一类别。\n\nlabel: 社会管理", 		 "label": "reason: 该文本主要讨论的是改革创新、行政区划调整、体制机制障碍的破除以及行政管理体制等与政府治理和社会管理相关的内容,强调了与高质量发展和生活品质的关系。这些内容显示出对社会管理和行政管理的关注,尤其是在推动城乡一体化和适应高质量发展要求方面。因此,该文本最符合“社会管理”这一类别。\n\nlabel: 社会管理"
}

文本分类评估代码

import os
import re
import jsonfrom sklearn.metrics import classification_report, confusion_matrix# 文本类别
CLASS_NAME = ["产业相关",..."法律法规与行政事务","其他",
]def load_jsonl(file_path):"""加载指定路径的 JSON 文件并返回解析后的数据。:param file_path: JSON 文件的路径:return: 解析后的数据(通常是字典或列表):raises FileNotFoundError: 如果文件未找到:raises json.JSONDecodeError: 如果 JSON 格式不正确"""data = []try:with open(file_path, "r", encoding="utf-8") as file:for line in file:tmp = json.loads(line)data.append(tmp)except FileNotFoundError as e:print(f"文件未找到:{file_path}")raise eexcept json.JSONDecodeError as e:print(f"JSON 格式错误:{e}")raise ereturn datadef parser_label(text: str):pattern = r"label[::\s\.\d\*]*([^\s^\*]+)"matches = re.findall(pattern, text, re.DOTALL)if len(matches) == 1:return matches[0]return Nonedef trans2num(item):predict = parser_label(item["predict"])label = parser_label(item["label"])predict_idx = -1label_idx = -1for idx, cls_name in enumerate(CLASS_NAME):if predict == cls_name:predict_idx = idxif label == cls_name:label_idx = idxreturn predict_idx, label_idxdef cls_eval(input_file):data = load_jsonl(file_path=input_file)predicts = []labels = []for item in data:predict, label = trans2num(item)if label == -1:continuepredicts.append(predict)labels.append(label)return classification_report(predicts, labels, output_dict=False)

本文使用了大模型生成式预测文本类别,我没有使用结构化输出的方式,大家可以使用结构化的json格式输出,这样在提取大模型预测结果的时候会方便很多。

大家按照自己模型的输出结果,修改parser_label 函数,这个函数用于从大模型的输出结果提取label。

cls_eval("xxx/generated_predictions.jsonl")

就会得到下述的输出结果:

-1 代表模型预测的类别不在给定的类别中。
在这里插入图片描述

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.mzph.cn/pingmian/63241.shtml

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

发论文参考文献部分怎么注明数据集出处gitee

见的参考文献标注格式(如APA、MLA、Chicago等),电子文献或网络资源的标注通常包括作者(或组织)、标题、发布年份、获取路径(URL)等信息。 二、具体步骤 查找数据集信息: 在Gitee上找…

ARM内核与单片机

1.单片机硬件架构如下所示:各种硬件通过总线进行连接。 2.M4内核架构 3.单片机如何工作: 4.CPU是通过读写寄存器来控制GPIO的 5.GPIO的硬件框架:一共有8种模式 (1)推挽/推挽复用输出。下图先看图1,如果输入…

vue3:mitt

在 Vue 3 中使用 mitt 进行事件总线的实现非常简单。mitt 是一个轻量级的事件库,适用于 Vue 项目中的组件间通信。 实现自定义组件直接相互传值,父到子,子到子,子对子,子对孙,想怎么传就怎么传。和android…

PHP 命令执行漏洞学习记录

PHP 命令执行 命令函数 作用 例子 system() 执行外部程序,并且显示输出 system(whoami) exec() 执行一个外部程序 echo exec(whoami); shell_exec() 通过shell环境执行命令,并且将完整的输出以字符串的形式返回 echo shell_exec(whoami); passthru() 执行外部程序…

VSCode GDB远程嵌入开发板调试

VSCode GDB远程嵌入式开发板调试 一、原理 嵌入式系统中一般在 PC端运行 gdb工具,源码也是在 PC端,源码对应的可执行文件放到开发板中运行。为此我们需要在开发板中运行 gdbserver,通过网络与 PC端的 gdb进行通信。因此要想在 PC上通过 gdb…

【机器学习】机器学习的基本分类-无监督学习(Unsupervised Learning)

无监督学习(Unsupervised Learning) 无监督学习是一种机器学习方法,主要用于没有标签的数据集。其目标是从数据中挖掘出潜在的结构和模式。常见的无监督学习任务包括 聚类、降维、密度估计 和 异常检测。 1. 无监督学习的核心目标 1.1 聚类…

【Python]深入Python日志管理:从logging到分布式日志追踪的完整指南

《Python OpenCV从菜鸟到高手》带你进入图像处理与计算机视觉的大门! 日志是软件开发中的核心部分,尤其在分布式系统中,日志对于调试和问题定位至关重要。本篇文章将从Python标准库的logging模块出发,逐步探讨日志管理的最佳实践,涵盖日志配置、日志分层、日志格式化等基…

专业140+总分420+上海交通大学819考研经验上交电子信息与通信工程,真题,大纲,参考书。博睿泽信息通信考研论坛,信息通信考研Jenny

考研结束,专业819信号系统与信号处理140,总分420,终于梦圆交大,高考时敢都不敢想目标,现在已经成为现实,考研后劲很大,这一年的复习经历,还是历历在目,整理一下&#xff…

【NLP修炼系列之Bert】Bert多分类多标签文本分类实战(附源码下载)

引言 今天我们就要用Bert做项目实战,实现文本多分类任务和我在实际公司业务中的多标签文本分类任务。通过本篇文章,可以让想实际入手Bert的NLP学习者迅速上手Bert实战项目。 1 项目介绍 本文是Bert文本多分类和多标签文本分类实战,其中多分…

[Redis#17] 主从复制 | 拓扑结构 | 复制原理 | 数据同步 | psync

目录 主从模式 主从复制作用 建立主从复制 主节点信息 从节点信息 断开主从复制关系 主从拓扑结构 主从复制原理 1. 复制过程 2. 数据同步(PSYNC) 3. 三种复制方式 一、全量复制 二、部分复制 三、实时复制 四、主从复制模式存在的问题 在…

【青牛科技】拥有两个独立的、高增益、内部相位补偿的双运算放大器,可适用于单电源或双电源工作——D4558

概述: D4558内部包括有两个独立的、高增益、内部相位补偿的双运算放大器,可适用于单电源或双电源工作。该电路具有电压增益高、噪声低等特点。主要应用于音频信号放大,有源滤波器等场合。 D4558采用DIP8、SOP8的封装形式 主要特点&#xff…

泰坦军团品牌焕新:LOGO变更开启电竞细分市场新篇章

深圳世纪创新显示电子有限公司旗下的高端电竞显示器品牌泰坦军团,上月发布通告,自2024年6月起已陆续进行品牌升级和LOGO变更。 泰坦军团自2015年成立以来,凭借先进的技术和顶级的工业设计,已成为众多年轻人首选的游戏显示器品牌&…

HALCON 算子 之 阈值分割算子

文章目录 什么是阈值分割?为什么要阈值分割?如何进行阈值分割?全局threshold —— 全局固定阈值分割auto_threshold —— 全局自动阈值分割fast_threshold —— 快速全局阈值分割watersheds_threshold —— 分水岭盆地阈值分割 局部dyn_thres…

【代码随想录|贪心算法重叠区间问题】

452.用最少数量的箭引爆气球 题目链接452. 用最少数量的箭引爆气球 - 力扣(LeetCode) 这道题是要求从下往上穿箭,把所有气球扎爆要的最少箭的数量 思路就是我们比较这个气球和上一个气球有没有重合的,重合我们一根箭一起就射了…

鸿蒙获取 APP 信息及手机信息

前言:获取 APP 版本信息可以通过 bundleManager.getBundleInfoForSelfSync(bundleFlags) 去获取,获取手机信息可以通过 kit.BasicServicesKit 库去获取,以下是封装好的工具类。 import bundleManager from ohos.bundle.bundleManager; impo…

爬取的数据能实时更新吗?

在当今数字化时代,实时数据更新对于企业和个人都至关重要。无论是市场分析、商品类目监控还是其他需要实时数据的应用场景,爬虫技术都能提供有效的解决方案。本文将探讨如何利用PHP爬虫实现数据的实时更新,并提供相应的代码示例。 1. 实时数…

JS中多方式数组复制知识扩展

JS中多方式数组复制知识扩展 前言浅拷贝JavaScript 展开操作符for() 循环其他: array.forEachforEach方法详解 array.mapmap()方法详解 array.filterfilter()方法详解 array.reducereduce()方法详解 array.sliceslice()方法详解 Array.fromArray.from()方法详解 深拷…

Hive 中 Order By、Sort By、Cluster By 和 Distribute By 的详细解析

Hive 中 Order By、Sort By、Cluster By 和 Distribute By 的详细解析 在 Hive 数据查询与处理操作中,Order By、Sort By、Cluster By 和 Distribute By 这些语句对于数据的排序、分区以及在 Reduce 阶段的处理起着关键作用。本文将详细解析它们各自的语法、区别以…

okHttp的tcp连接池的复用

okhttp的连接池是tcp连接池吧,是两台机器之间的连接,ip:port连接,然后具体的接口再添加具体的url吗? 具体的 HTTP 请求(包括 URL、请求方法、头部等)则是在复用的 TCP 连接上进行传输的。 是的&#xff0c…

Linux 正确关机方式详解

在Linux系统中,正确地关机是一个重要的操作,它不仅影响到系统的数据完整性,还可能影响到其他用户的工作。本文将详细介绍Linux系统中的各种关机方式,包括它们的使用场景和具体命令。 为什么需要正确关机 在DOS和Windows系统中&a…