使用 LlamaFactory 结合开源大语言模型实现文本分类:从数据集构建到 LoRA 微调与推理评估

文章目录

    • 背景介绍
      • 文本分类数据集
      • Lora 微调
      • 模型部署与推理
        • 期待模型的输出结果
    • 文本分类评估代码

背景介绍

本文将一步一步地,介绍如何使用llamafactory框架利用开源大语言模型完成文本分类的实验,以 LoRA微调 qwen/Qwen2.5-7B-Instruct 为例。

文本分类数据集

按照 alpaca 样式构建数据集,并在将其添加到 LLaMA-Factory/data/dataset_info.json 文件中。如此方便直接根据自定义数据集的名字,获取到数据集的数据。

[{"instruction": "","input": "请将以下文本分类到一个最符合的类别中。以下是类别及其定义:\n\n要求}}\nreason: \nlabel:","output": "reason: 该文本主要讨论的是xxx。因此,该文本最符合“社会管理”这一类别。\n\nlabel: 社会管理"},...
]

Lora 微调

llamafactory 框架支持网页端训练,但本文选择在终端使用命令行微调模型。

模型微调训练的参数较多,将模型训练的参数都存储在 yaml 文件中。

qwen_train_cls.yaml 的文件内容如下:

### model
model_name_or_path: qwen/Qwen2.5-7B-Instruct### method
stage: sft
do_train: true
finetuning_type: lora
lora_target: all### dataset
# dataset_dir: data
dataset_dir: LLaMA-Factory/data/ 填写相应路径
dataset: 数据集名 
template: qwen
cutoff_len: 2048
# max_samples: 1000 若数据集较大,可随机筛选一部分数据微调模型
overwrite_cache: true
preprocessing_num_workers: 16### output
output_dir: output/qwen2.5-7B/cls_epoch2 训练的LoRA权重输出路径
logging_steps: 10
save_steps: 500
plot_loss: true
overwrite_output_dir: true### train
per_device_train_batch_size: 1
gradient_accumulation_steps: 8
learning_rate: 1.0e-4
num_train_epochs: 2.0
lr_scheduler_type: cosine
warmup_ratio: 0.1
bf16: true
ddp_timeout: 180000000### eval
# val_size: 0.1
# per_device_eval_batch_size: 1
# eval_strategy: steps
# eval_steps: 500

使用下述命令启动模型训练:

nohup llamafactory-cli train qwen_train_cls.yaml > qwen_train_cls.log 2>&1 &

命令分解介绍:
nohup, 全称为 “no hangup”(不要挂起)。它的作用是让命令在退出终端后仍然运行,防止因关闭终端或会话中断导致进程被终止。
默认情况下,nohup 会将输出重定向到 nohup.out 文件,但这里已经显式指定了输出位置。
llamafactory-cli train qwen_train_cls.yaml 运行 llamafactory-cli 工具,用于执行训练任务。
train 是子命令,表示进行训练。
qwen_train_cls.yaml 是一个配置文件,包含训练所需的超参数、数据路径、模型结构等。
qwen_train_cls.log
将标准输出 (stdout) 重定向到 qwen_train_cls.log 文件中。
即运行过程中的正常日志信息会被记录到这个文件。
2>&1: 将标准错误输出 (stderr) 重定向到标准输出 (stdout)。
这样,所有错误信息也会被写入到 qwen_train_cls.log 文件中。
&: 表示将整个命令放到后台运行。终端会立即返回,您可以继续进行其他操作,而不用等待命令完成。

模型部署与推理

模型训练完成后得到 Lora 权重。相关微调模型部署与推理,请浏览下述两篇文章,相比llamafactory原本的模型推理速度更快。

  • 基于 LLamafactory 的异步API高效调用实现与速度对比.https://blog.csdn.net/sjxgghg/article/details/144176645
  • 基于 LlamaFactory 的 LoRA 微调模型支持 vllm 批量推理的实现

目前llamafactory已经支持 vllm_infer 推理,这个PR是笔者提交的:

  • llamafactory vllm.https://github.com/hiyouga/LLaMA-Factory/blob/main/scripts/vllm_infer.py
期待模型的输出结果

下述是使用 llamafactory 推理出的数据格式,建议大家在做推理评估时,也做成这个样式,方便统一评估。

{"prompt": "请将以下文本分类到一个最符合的类别中。以下是类别及其定义:...", "predict": "\nreason: 该文本主要讨论了改革创新发展、行政区划调整、行政管理体制等方面的内容,涉及到体制机制的改革与完善,旨在推动高质量发展和提升生活品质。这些内容与社会管理和经济管理密切相关,但更侧重于行政管理和社会治理的改革,因此更符合“社会管理”这一类别。\n\nlabel: 社会管理", 		 "label": "reason: 该文本主要讨论的是改革创新、行政区划调整、体制机制障碍的破除以及行政管理体制等与政府治理和社会管理相关的内容,强调了与高质量发展和生活品质的关系。这些内容显示出对社会管理和行政管理的关注,尤其是在推动城乡一体化和适应高质量发展要求方面。因此,该文本最符合“社会管理”这一类别。\n\nlabel: 社会管理"
}

文本分类评估代码

import os
import re
import jsonfrom sklearn.metrics import classification_report, confusion_matrix# 文本类别
CLASS_NAME = ["产业相关",..."法律法规与行政事务","其他",
]def load_jsonl(file_path):"""加载指定路径的 JSON 文件并返回解析后的数据。:param file_path: JSON 文件的路径:return: 解析后的数据(通常是字典或列表):raises FileNotFoundError: 如果文件未找到:raises json.JSONDecodeError: 如果 JSON 格式不正确"""data = []try:with open(file_path, "r", encoding="utf-8") as file:for line in file:tmp = json.loads(line)data.append(tmp)except FileNotFoundError as e:print(f"文件未找到:{file_path}")raise eexcept json.JSONDecodeError as e:print(f"JSON 格式错误:{e}")raise ereturn datadef parser_label(text: str):pattern = r"label[::\s\.\d\*]*([^\s^\*]+)"matches = re.findall(pattern, text, re.DOTALL)if len(matches) == 1:return matches[0]return Nonedef trans2num(item):predict = parser_label(item["predict"])label = parser_label(item["label"])predict_idx = -1label_idx = -1for idx, cls_name in enumerate(CLASS_NAME):if predict == cls_name:predict_idx = idxif label == cls_name:label_idx = idxreturn predict_idx, label_idxdef cls_eval(input_file):data = load_jsonl(file_path=input_file)predicts = []labels = []for item in data:predict, label = trans2num(item)if label == -1:continuepredicts.append(predict)labels.append(label)return classification_report(predicts, labels, output_dict=False)

本文使用了大模型生成式预测文本类别,我没有使用结构化输出的方式,大家可以使用结构化的json格式输出,这样在提取大模型预测结果的时候会方便很多。

大家按照自己模型的输出结果,修改parser_label 函数,这个函数用于从大模型的输出结果提取label。

cls_eval("xxx/generated_predictions.jsonl")

就会得到下述的输出结果:

-1 代表模型预测的类别不在给定的类别中。
在这里插入图片描述

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.mzph.cn/pingmian/63241.shtml

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

ARM内核与单片机

1.单片机硬件架构如下所示:各种硬件通过总线进行连接。 2.M4内核架构 3.单片机如何工作: 4.CPU是通过读写寄存器来控制GPIO的 5.GPIO的硬件框架:一共有8种模式 (1)推挽/推挽复用输出。下图先看图1,如果输入…

PHP 命令执行漏洞学习记录

PHP 命令执行 命令函数 作用 例子 system() 执行外部程序,并且显示输出 system(whoami) exec() 执行一个外部程序 echo exec(whoami); shell_exec() 通过shell环境执行命令,并且将完整的输出以字符串的形式返回 echo shell_exec(whoami); passthru() 执行外部程序…

VSCode GDB远程嵌入开发板调试

VSCode GDB远程嵌入式开发板调试 一、原理 嵌入式系统中一般在 PC端运行 gdb工具,源码也是在 PC端,源码对应的可执行文件放到开发板中运行。为此我们需要在开发板中运行 gdbserver,通过网络与 PC端的 gdb进行通信。因此要想在 PC上通过 gdb…

专业140+总分420+上海交通大学819考研经验上交电子信息与通信工程,真题,大纲,参考书。博睿泽信息通信考研论坛,信息通信考研Jenny

考研结束,专业819信号系统与信号处理140,总分420,终于梦圆交大,高考时敢都不敢想目标,现在已经成为现实,考研后劲很大,这一年的复习经历,还是历历在目,整理一下&#xff…

【NLP修炼系列之Bert】Bert多分类多标签文本分类实战(附源码下载)

引言 今天我们就要用Bert做项目实战,实现文本多分类任务和我在实际公司业务中的多标签文本分类任务。通过本篇文章,可以让想实际入手Bert的NLP学习者迅速上手Bert实战项目。 1 项目介绍 本文是Bert文本多分类和多标签文本分类实战,其中多分…

[Redis#17] 主从复制 | 拓扑结构 | 复制原理 | 数据同步 | psync

目录 主从模式 主从复制作用 建立主从复制 主节点信息 从节点信息 断开主从复制关系 主从拓扑结构 主从复制原理 1. 复制过程 2. 数据同步(PSYNC) 3. 三种复制方式 一、全量复制 二、部分复制 三、实时复制 四、主从复制模式存在的问题 在…

【青牛科技】拥有两个独立的、高增益、内部相位补偿的双运算放大器,可适用于单电源或双电源工作——D4558

概述: D4558内部包括有两个独立的、高增益、内部相位补偿的双运算放大器,可适用于单电源或双电源工作。该电路具有电压增益高、噪声低等特点。主要应用于音频信号放大,有源滤波器等场合。 D4558采用DIP8、SOP8的封装形式 主要特点&#xff…

泰坦军团品牌焕新:LOGO变更开启电竞细分市场新篇章

深圳世纪创新显示电子有限公司旗下的高端电竞显示器品牌泰坦军团,上月发布通告,自2024年6月起已陆续进行品牌升级和LOGO变更。 泰坦军团自2015年成立以来,凭借先进的技术和顶级的工业设计,已成为众多年轻人首选的游戏显示器品牌&…

HALCON 算子 之 阈值分割算子

文章目录 什么是阈值分割?为什么要阈值分割?如何进行阈值分割?全局threshold —— 全局固定阈值分割auto_threshold —— 全局自动阈值分割fast_threshold —— 快速全局阈值分割watersheds_threshold —— 分水岭盆地阈值分割 局部dyn_thres…

【代码随想录|贪心算法重叠区间问题】

452.用最少数量的箭引爆气球 题目链接452. 用最少数量的箭引爆气球 - 力扣(LeetCode) 这道题是要求从下往上穿箭,把所有气球扎爆要的最少箭的数量 思路就是我们比较这个气球和上一个气球有没有重合的,重合我们一根箭一起就射了…

在商业智能BI系统中,如何配置高级感的数据可视化折线图?

在数据可视化的世界里,折线图作为一种直观且有效的数据展示方式,被广泛应用于各类数据分析与报告中。折线图不仅能够清晰地展示数据随时间或其他连续变量的变化趋势,还能通过不同的样式配置,增强图表的可读性和美观度。在JVS-智能…

容器镜像仓库

文章目录 1、docker hub1_注册2_登录3_创建容器镜像仓库4_在本地登录Docker Hub5_上传容器镜像6_下载容器镜像 2、harbor1_获取 docker compose二进制文件2_获取harbor安装文件3_获取TLS文件4_修改配置文件5_执行预备脚本6_执行安装脚本7_验证运行情况8_访问harborUI界面9_harb…

网站打开速度测试工具:互联网优化的得力助手

在信息飞速流转的互联网时代,网站如同企业与用户对话的窗口,其打开速度直接关乎用户体验,乃至业务的成败。所幸,一系列专业的网站打开速度测试工具应运而生,它们宛如幕后的技术侦探,精准剖析网站性能&#…

liunx docker 部署 nacos seata sentinel

部署nacos 1.按要求创建好数据库 2.创建docker 容器 docker run -d --name nacos-server -p 8848:8848 -e MODEstandalone -e SPRING_DATASOURCE_PLATFORMmysql -e MYSQL_SERVICE_HOST172.17.251.166 -e MYSQL_SERVICE_DB_NAMEry-config -e MYSQL_SERVICE_PORT3306 -e MYSQL…

爬虫项目基础知识详解

文章目录 Python爬虫项目基础知识一、爬虫与数据分析1.1 Python中的requests库Requests 库的安装Requests 库的 get() 方法爬取网页的通用代码框架HTTP 协议及 Requests 库方法Requests 库主要方法解析 1.2 python中的json库1.3 xpath学习之python中lxml库html了解html结构html…

结构型-组合模式(Composite Pattern)

什么是组合模式 又名部分整体模式,是用于把一组相似的对象当作一个单一的对象。组合模式依据树形结构来组合对象,用来表示部分以及整体层次。这种类型的设计模式属于结构型模式,它创建了对象组的树形结构。 结构 抽象根节点(Co…

Wordpress设置固定链接形式后出现404错误

比如固定连接设置为 /archives/%post_id%.html 这种形式,看起来比较舒服。对搜索引擎也友好。 出现404需要设置伪静态

小程序项目的基本组成结构

分类介绍 项目根目录下的文件及文件夹 pages文件夹 用来存放所有小程序的页面,其中每个页面都由4个基本文件组成,它们分别是: .js文件:页面的脚本文件,用于存放页面的数据、事件处理函数等 .json文件:…

Agent AI: Surveying the Horizons of Multimodal Interaction---摘要、引言、代理 AI 集成

题目 智能体AI:多模态交互视野的考察 论文地址:https://arxiv.org/abs/2401.03568 图1:可以在不同领域和应用程序中感知和行动的Agent AI系统概述。Agent AI是正在成为通用人工智能(AGI)的一个有前途的途径。Agent AI培训已经证…

LRU Cache替换算法

目录 1.什么是LRU Cache? 2.LRU Cache 的底层结构 3.LRU Cache的实现 LRUCache类中的接口总览 构造函数 get操作 put操作 打印 4.LRU Cache的测试 5.LRU Cache相关OJ题 6.LRU Cache类代码附录 1.什么是LRU Cache? 首先我想解释一下什么是cach…