最实战的GLM4微调入门:从文本分类开始

GLM4是清华智谱团队最近开源的大语言模型。

以GLM4作为基座大模型,通过指令微调的方式做高精度文本分类,是学习LLM微调的入门任务。

在这里插入图片描述

使用的9B模型,显存要求相对较高,需要40GB左右。

在本文中,我们会使用 GLM4-9b-Chat 模型在 复旦中文新闻 数据集上做指令微调训练,同时使用SwanLab监控训练过程、评估模型效果。

  • 代码:完整代码直接看本文第5节
  • 实验日志过程:GLM4-Fintune - SwanLab
  • 模型:Modelscope
  • 数据集:zh_cls_fudan_news
  • SwanLab:https://swanlab.cn

相关文章:Qwen2指令微调

知识点:什么是指令微调?

大模型指令微调(Instruction Tuning)是一种针对大型预训练语言模型的微调技术,其核心目的是增强模型理解和执行特定指令的能力,使模型能够根据用户提供的自然语言指令准确、恰当地生成相应的输出或执行相关任务。

指令微调特别关注于提升模型在遵循指令方面的一致性和准确性,从而拓宽模型在各种应用场景中的泛化能力和实用性。

在实际应用中,我的理解是,指令微调更多把LLM看作一个更智能、更强大的传统NLP模型(比如Bert),来实现更高精度的文本预测任务。所以这类任务的应用场景覆盖了以往NLP模型的场景,甚至很多团队拿它来标注互联网数据

下面是实战正片:

1.环境安装

本案例基于Python>=3.8,请在您的计算机上安装好Python,并且有一张英伟达显卡(显存要求并不高,大概10GB左右就可以跑)。

我们需要安装以下这几个Python库,在这之前,请确保你的环境内已安装了pytorch以及CUDA:

swanlab
modelscope
transformers
datasets
peft
accelerate
pandas
tiktoken

一键安装命令:

pip install swanlab modelscope transformers datasets peft pandas accelerate tiktoken

本案例测试于modelscope1.14.0、transformers4.41.2、datasets2.18.0、peft0.11.1、accelerate0.30.1、swanlab0.3.10、tiktokn==0.7.0,更多环境细节可以查看这里

2.准备数据集

本案例使用的是zh_cls_fudan-news数据集,该数据集主要被用于训练文本分类模型。

zh_cls_fudan-news由几千条数据,每条数据包含text、category、output三列:

  • text 是训练语料,内容是书籍或新闻的文本内容
  • category 是text的多个备选类型组成的列表
  • output 则是text唯一真实的类型

在这里插入图片描述

数据集例子如下:

"""
[PROMPT]Text: 第四届全国大企业足球赛复赛结束新华社郑州5月3日电(实习生田兆运)上海大隆机器厂队昨天在洛阳进行的第四届牡丹杯全国大企业足球赛复赛中,以5:4力克成都冶金实验厂队,进入前四名。沪蓉之战,双方势均力敌,90分钟不分胜负。最后,双方互射点球,沪队才以一球优势取胜。复赛的其它3场比赛,青海山川机床铸造厂队3:0击败东道主洛阳矿山机器厂队,青岛铸造机械厂队3:1战胜石家庄第一印染厂队,武汉肉联厂队1:0险胜天津市第二冶金机械厂队。在今天进行的决定九至十二名的两场比赛中,包钢无缝钢管厂队和河南平顶山矿务局一矿队分别击败河南平顶山锦纶帘子布厂队和江苏盐城无线电总厂队。4日将进行两场半决赛,由青海山川机床铸造厂队和青岛铸造机械厂队分别与武汉肉联厂队和上海大隆机器厂队交锋。本届比赛将于6日结束。(完)
Category: Sports, Politics
Output:[OUTPUT]Sports
"""

我们的训练任务,便是希望微调后的大模型能够根据Text和Category组成的提示词,预测出正确的Output。


我们将数据集下载到本地目录下。下载方式是前往zh_cls_fudan-news - 魔搭社区 ,将train.jsonltest.jsonl下载到本地根目录下即可:

在这里插入图片描述

3. 加载模型

这里我们使用modelscope下载GLM4-9b-Chat模型(modelscope在国内,所以下载不用担心速度和稳定性问题),然后把它加载到Transformers中进行训练:

from modelscope import snapshot_download, AutoTokenizer
from transformers import AutoModelForCausalLM, TrainingArguments, Trainer, DataCollatorForSeq2Seq# 在modelscope上下载GLM模型到本地目录下
model_dir = snapshot_download("ZhipuAI/glm-4-9b-chat", cache_dir="./", revision="master")# Transformers加载模型权重
tokenizer = AutoTokenizer.from_pretrained("./ZhipuAI/glm-4-9b-chat/", use_fast=False, trust_remote_code=True)
model = AutoModelForCausalLM.from_pretrained("./ZhipuAI/glm-4-9b-chat/", device_map="auto", torch_dtype=torch.bfloat16, trust_remote_code=True)

4. 配置训练可视化工具

我们使用SwanLab来监控整个训练过程,并评估最终的模型效果。

这里直接使用SwanLab和Transformers的集成来实现:

from swanlab.integration.huggingface import SwanLabCallbackswanlab_callback = SwanLabCallback(...)trainer = Trainer(...callbacks=[swanlab_callback],
)

如果你是第一次使用SwanLab,那么还需要去https://swanlab.cn上注册一个账号,在用户设置页面复制你的API Key,然后在训练开始时粘贴进去即可:

在这里插入图片描述

5. 完整代码

开始训练时的目录结构:

|--- train.py
|--- train.jsonl
|--- test.jsonl

train.py:

import json
import pandas as pd
import torch
from datasets import Dataset
from modelscope import snapshot_download, AutoTokenizer
from swanlab.integration.huggingface import SwanLabCallback
from peft import LoraConfig, TaskType, get_peft_model
from transformers import AutoModelForCausalLM, TrainingArguments, Trainer, DataCollatorForSeq2Seq
import os
import swanlabdef dataset_jsonl_transfer(origin_path, new_path):"""将原始数据集转换为大模型微调所需数据格式的新数据集"""messages = []# 读取旧的JSONL文件with open(origin_path, "r") as file:for line in file:# 解析每一行的json数据data = json.loads(line)context = data["text"]catagory = data["category"]label = data["output"]message = {"instruction": "你是一个文本分类领域的专家,你会接收到一段文本和几个潜在的分类选项,请输出文本内容的正确类型","input": f"文本:{context},类型选型:{catagory}","output": label,}messages.append(message)# 保存重构后的JSONL文件with open(new_path, "w", encoding="utf-8") as file:for message in messages:file.write(json.dumps(message, ensure_ascii=False) + "\n")def process_func(example):"""将数据集进行预处理"""MAX_LENGTH = 384 input_ids, attention_mask, labels = [], [], []instruction = tokenizer(f"<|system|>\n你是一个文本分类领域的专家,你会接收到一段文本和几个潜在的分类选项,请输出文本内容的正确类型<|endoftext|>\n<|user|>\n{example['input']}<|endoftext|>\n<|assistant|>\n",add_special_tokens=False,)response = tokenizer(f"{example['output']}", add_special_tokens=False)input_ids = instruction["input_ids"] + response["input_ids"] + [tokenizer.pad_token_id]attention_mask = (instruction["attention_mask"] + response["attention_mask"] + [1])labels = [-100] * len(instruction["input_ids"]) + response["input_ids"] + [tokenizer.pad_token_id]if len(input_ids) > MAX_LENGTH:  # 做一个截断input_ids = input_ids[:MAX_LENGTH]attention_mask = attention_mask[:MAX_LENGTH]labels = labels[:MAX_LENGTH]return {"input_ids": input_ids, "attention_mask": attention_mask, "labels": labels}   def predict(messages, model, tokenizer):device = "cuda"text = tokenizer.apply_chat_template(messages,tokenize=False,add_generation_prompt=True)model_inputs = tokenizer([text], return_tensors="pt").to(device)generated_ids = model.generate(model_inputs.input_ids,max_new_tokens=512)generated_ids = [output_ids[len(input_ids):] for input_ids, output_ids in zip(model_inputs.input_ids, generated_ids)]response = tokenizer.batch_decode(generated_ids, skip_special_tokens=True)[0]print(response)return response# 在modelscope上下载GLM模型到本地目录下
model_dir = snapshot_download("ZhipuAI/glm-4-9b-chat", cache_dir="./", revision="master")# Transformers加载模型权重
tokenizer = AutoTokenizer.from_pretrained("./ZhipuAI/glm-4-9b-chat/", use_fast=False, trust_remote_code=True)
model = AutoModelForCausalLM.from_pretrained("./ZhipuAI/glm-4-9b-chat/", device_map="auto", torch_dtype=torch.bfloat16, trust_remote_code=True)
model.enable_input_require_grads()  # 开启梯度检查点时,要执行该方法# 加载、处理数据集和测试集
train_dataset_path = "train.jsonl"
test_dataset_path = "test.jsonl"train_jsonl_new_path = "new_train.jsonl"
test_jsonl_new_path = "new_test.jsonl"if not os.path.exists(train_jsonl_new_path):dataset_jsonl_transfer(train_dataset_path, train_jsonl_new_path)
if not os.path.exists(test_jsonl_new_path):dataset_jsonl_transfer(test_dataset_path, test_jsonl_new_path)# 得到训练集
train_df = pd.read_json(train_jsonl_new_path, lines=True)
train_ds = Dataset.from_pandas(train_df)
train_dataset = train_ds.map(process_func, remove_columns=train_ds.column_names)config = LoraConfig(task_type=TaskType.CAUSAL_LM,target_modules=["query_key_value", "dense", "dense_h_to_4h", "activation_func", "dense_4h_to_h"],inference_mode=False,  # 训练模式r=8,  # Lora 秩lora_alpha=32,  # Lora alaph,具体作用参见 Lora 原理lora_dropout=0.1,  # Dropout 比例
)model = get_peft_model(model, config)args = TrainingArguments(output_dir="./output/GLM4-9b",per_device_train_batch_size=4,gradient_accumulation_steps=4,logging_steps=10,num_train_epochs=2,save_steps=100,learning_rate=1e-4,save_on_each_node=True,gradient_checkpointing=True,report_to="none",
)swanlab_callback = SwanLabCallback(project="GLM4-fintune",experiment_name="GLM4-9B-Chat",description="使用智谱GLM4-9B-Chat模型在zh_cls_fudan-news数据集上微调。",config={"model": "ZhipuAI/glm-4-9b-chat","dataset": "huangjintao/zh_cls_fudan-news",},
)trainer = Trainer(model=model,args=args,train_dataset=train_dataset,data_collator=DataCollatorForSeq2Seq(tokenizer=tokenizer, padding=True),callbacks=[swanlab_callback],
)trainer.train()# 用测试集的前10条,测试模型
test_df = pd.read_json(test_jsonl_new_path, lines=True)[:10]test_text_list = []
for index, row in test_df.iterrows():instruction = row['instruction']input_value = row['input']messages = [{"role": "system", "content": f"{instruction}"},{"role": "user", "content": f"{input_value}"}]response = predict(messages, model, tokenizer)messages.append({"role": "assistant", "content": f"{response}"})result_text = f"{messages[0]}\n\n{messages[1]}\n\n{messages[2]}"test_text_list.append(swanlab.Text(result_text, caption=response))swanlab.log({"Prediction": test_text_list})
swanlab.finish()

看到下面的进度条即代表训练开始,这些loss、grad_norm等信息会到一定的step时打印出来:

在这里插入图片描述

6.训练结果演示

在SwanLab上查看最终的训练结果:

可以看到在2个epoch之后,微调后的glm4的loss降低到了不错的水平——当然对于大模型来说,真正的效果评估还得看主观效果。

在这里插入图片描述

可以看到在一些测试样例上,微调后的glm4能够给出准确的文本类型:

在这里插入图片描述

至此,你已经完成了glm4指令微调的训练!

7. 模型推理

import torch
from transformers import AutoModelForCausalLM, AutoTokenizer
from peft import PeftModeldef predict(messages, model, tokenizer):device = "cuda"text = tokenizer.apply_chat_template(messages, tokenize=False, add_generation_prompt=True)model_inputs = tokenizer([text], return_tensors="pt").to(device)generated_ids = model.generate(model_inputs.input_ids, max_new_tokens=512)generated_ids = [output_ids[len(input_ids):] for input_ids, output_ids in zip(model_inputs.input_ids, generated_ids)]response = tokenizer.batch_decode(generated_ids, skip_special_tokens=True)[0]return response# 加载原下载路径的tokenizer和model
tokenizer = AutoTokenizer.from_pretrained("./ZhipuAI/glm-4-9b-chat/", use_fast=False, trust_remote_code=True)
model = AutoModelForCausalLM.from_pretrained("./ZhipuAI/glm-4-9b-chat/", device_map="auto", torch_dtype=torch.bfloat16)# 加载训练好的Lora模型,将下面的checkpointXXX替换为实际的checkpoint文件名名称
model = PeftModel.from_pretrained(model, model_id="./output/GLM4-9b/checkpoint-XXX")test_texts = {'instruction': "你是一个文本分类领域的专家,你会接收到一段文本和几个潜在的分类选项,请输出文本内容的正确类型",'input': "文本:航空动力学报JOURNAL OF AEROSPACE POWER1998年 第4期 No.4 1998科技期刊管路系统敷设的并行工程模型研究*陈志英* * 马 枚北京航空航天大学【摘要】 提出了一种应用于并行工程模型转换研究的标号法,该法是将现行串行设计过程(As-is)转换为并行设计过程(To-be)。本文应用该法将发动机外部管路系统敷设过程模型进行了串并行转换,应用并行工程过程重构的手段,得到了管路敷设并行过程模型。"
}instruction = test_texts['instruction']
input_value = test_texts['input']messages = [{"role": "system", "content": f"{instruction}"},{"role": "user", "content": f"{input_value}"}
]response = predict(messages, model, tokenizer)
print(response)

相关链接

  • 代码:完整代码直接看本文第5节
  • 实验日志过程:GLM4-Fintune - SwanLab
  • 模型:Modelscope
  • 数据集:zh_cls_fudan_news
  • SwanLab:https://swanlab.cn

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.mzph.cn/news/856171.shtml

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

全志 Android 11:实现响应全局按键

一、篇头 最近实现热键想功能&#xff0c;简单总结了下全志平台Android 11 的响应全局热键的方法。 二、需求 实现全局热键&#xff0c;响应F-、AF、F三个按键&#xff0c;AF只用于启动调焦界面&#xff0c;F-和F除了可以启动调焦界面外&#xff0c;还用于调整镜头的焦距&…

Go web框架|Beego、Gin、Echo、Iris等干货教程

Go 是一门正在快速增长的编程语言&#xff0c;专为构建简单、快速且可靠的软件而设计。golang提供的net/http库已经很好了&#xff0c;对于http的协议的实现非常好&#xff0c;基于此再造框架&#xff0c;也不会是难事&#xff0c;因此生态中出现了很多框架。 本篇文章主要介绍…

使用node把任意网站封装为可执行文件

直接上步骤&#xff1a; 1. node.js 环境准备 下载地址 那个版本都行&#xff0c;下一步->下一步 安装即可 2. windows 系统下&#xff0c; 快捷键 winr ->输入 cmd -> 回车 3. 执行第一个命令&#xff0c;安装 nativefier 等一段时间 npm install nativefier -g 4…

全面赋能,永久免费!讯飞星火API能力正式免费开放

2023年5月&#xff0c;讯飞星火正式发布&#xff0c;迅速成为千万用户获取知识、学习知识的“超级助手”&#xff0c;成为解放生产力、释放想象力的“超级杠杆”。 2024年5月&#xff0c;讯飞星火API能力正式免费开放&#xff0c;携手生态开发者加快大模型赋能刚需场景。 领…

【考研408计算机组成原理】存储系统之Cache考点

苏泽 “弃工从研”的路上很孤独&#xff0c;于是我记下了些许笔记相伴&#xff0c;希望能够帮助到大家 另外&#xff0c;利用了工作之余的一点点时间&#xff0c;整理了一套考研408的知识图谱&#xff0c; 我根据这一套知识图谱打造了这样一个408知识图谱问答系统 里面的每一…

【C++题解】1324 - 扩建鱼塘问题

问题&#xff1a;1324 - 扩建鱼塘问题 类型&#xff1a;分支问题 题目描述&#xff1a; 有一个尺寸为 mn 的矩形鱼塘&#xff0c;请问如果要把该鱼塘扩建为正方形&#xff0c;那么它的面积至少增加了多少平方米&#xff1f; 输入&#xff1a; 两个整数 m 和 n 。 输出&…

LeetCode 54.螺旋矩阵

1.题目要求如图所示: 各位看官们&#xff0c;大家好呀&#xff0c;今天小编用的方法比较麻烦&#xff0c;就是按顺时针遍历&#xff0c;但也挺好理解的&#xff0c;因为就是迭代法循环&#xff0c;所以就不给大家讲步骤了&#xff0c;直接就发代码了: /*** Note: The returne…

深入浅出Netty:高性能网络应用框架的原理与实践

深入浅出Netty&#xff1a;高性能网络应用框架的原理与实践 1. Netty简介 Netty是一个基于Java的异步事件驱动的网络应用框架&#xff0c;广泛用于构建高性能、高可扩展性的网络服务器和客户端。它提供对多种协议&#xff08;如TCP、UDP、SSL等&#xff09;的支持&#xff0c;…

【计算机网络篇】数据链路层(11)在数据链路层扩展以太网

文章目录 &#x1f354;使用网桥在数据链路层扩展以太网&#x1f95a;网桥的主要结构和基本工作原理&#x1f388;网桥的主要结构&#x1f50e;网桥转发帧的例子&#x1f50e;网桥丢弃帧的例子&#x1f50e;网桥转发广播帧的例子 &#x1f95a;透明网桥&#x1f50e;透明网桥的…

网络基础篇:网络模型

目录 一、初识网络 二、网络的分层 OSI七层模型 TCP/IP四层模型 网络与系统的关系 网络传输基本流程 数据包封装和分用 三、IP地址与MAC地址 认识IP地址 认识MAC地址 IP与MAC的关系 一、初识网络 同一台设备上的进程间通信有很多种方式 &#xff1a; 管道&#xff08…

需求虽小但是问题很多,浅谈JavaScript导出excel文件

最近我在进行一些前端小开发&#xff0c;遇到了一个小需求&#xff1a;我想要将数据导出到 Excel 文件&#xff0c;并希望能够封装成一个函数来实现。这个函数需要接收一个二维数组作为参数&#xff0c;数组的第一行是表头。在导出的过程中&#xff0c;要能够确保避免出现中文乱…

二叉树(数据结构篇)

数据结构之二叉树 二叉树 概念&#xff1a; 二叉树(binary tree)是一颗每个节点都不能多于两个子节点的树&#xff0c;左边的子树称为左子树&#xff0c;右边的子树称为右子树 性质&#xff1a; 二叉树实际上是图&#xff0c;二叉树相对于树更常用。 平衡二叉树的深度要比…

正版 navicat 下载

1. 打开浏览器访问 navicat 官网 Navicat | 下载 Navicat Premium 14 天免费 Windows、macOS 和 Linux 的试用版 windows 用户选择这三项其中一个就可以 2. 下载 点击之后等个几秒钟就会开始下载了 3. 双击打开 下载好的 .exe 程序 进入安装程序 (不影响之前已经安装过的) 可…

客户ITSS案例 — 江苏中友讯华信息科技有限公司

● 2019年12月17日至12月20日&#xff0c;中国电子工业标准化技术协会信息技术服务分会&#xff08;以下称ITSS分会&#xff09;组织召开了运行维护服务能力成熟度符合性评估专家评审会。在江苏新世纪信息科技有限公司的咨询辅导下&#xff0c;江苏中友讯华信息科技有限公司顺利…

猫头虎分享已解决Bug || **Mismatched Types**: `mismatched types`

&#x1f42f; 猫头虎分享已解决Bug || Mismatched Types: mismatched types &#x1f42f; 关于猫头虎 大家好&#xff0c;我是猫头虎&#xff0c;别名猫头虎博主&#xff0c;擅长的技术领域包括云原生、前端、后端、运维和AI。我的博客主要分享技术教程、bug解决思路、开发…

ECharts 雷达图案例001-自定义节点动画

ECharts 雷达图案例001-自定义节点动画 引言 在数据可视化的领域中&#xff0c;ECharts 提供了一种强大的工具来展示多维数据。本文将介绍如何使用 ECharts 创建一个自定义节点样式的雷达图&#xff0c;让数据展示更加生动和个性化。 效果预览 通过自定义节点样式&#xff…

AI早班车2024.6.19

全球AI新闻速递 1.广东 / 山东警方破获两起“AI 换脸伪造不雅照”案。 2.腾讯混元、港科大、清华推出表情包框架&#xff1a;Follow Your Emoji。 3.抖音联合博纳影业推出首部 AIGC 科幻短剧集《三星堆&#xff1a;未来启示录》。 4.亚马逊&#xff1a;宣布向全球创企提供 …

【Java】BigDecimal类型——BigDecimal 为什么可以保证精度不丢失

目录 简介类介绍案例分析总结BigDecimal类型的使用场景MySQL中存储BigDecimal类型数据补充&#xff1a;BigDecimal类型使用时的注意事项BigDecimal类型的其他使用 简介 BigDecimal是Java中的一个类&#xff0c;用于处理大数运算。它提供了精确的数值计算&#xff0c;可以处理任…

真空玻璃可见光透射比检测 玻璃制品检测 玻璃器皿检测

建筑玻璃检测 防火玻璃、钢化玻璃、夹层玻璃、均质钢化玻璃、平板玻璃、中空玻璃、真空玻璃、镀膜玻璃夹丝玻璃、光栅玻璃、压花玻璃、建筑用U形玻璃、镶嵌玻璃、玻璃幕墙等 工业玻璃检测 钢化安全玻璃、电加温玻璃、玻璃、半钢化玻璃、视镜玻璃、汽车安全玻璃、汽车后窗电热…

Walrus:去中心化存储和DA协议,可以基于Sui构建L2和大型存储

Walrus是为区块链应用和自主代理提供的创新去中心化存储网络。Walrus存储系统今天以开发者预览版的形式发布&#xff0c;面向Sui开发者征求反馈意见&#xff0c;并预计很快会向其他Web3社区广泛推广。 通过采用纠删编码创新技术&#xff0c;Walrus能够快速且稳健地将非结构化数…