Qwen2大模型微调入门实战(完整代码)

Qwen2是通义千问团队的开源大语言模型,由阿里云通义实验室研发。以Qwen2作为基座大模型,通过指令微调的方式实现高准确率的文本分类,是学习大语言模型微调的入门任务。

在这里插入图片描述

指令微调是一种通过在由(指令,输出)对组成的数据集上进一步训练LLMs的过程。 其中,指令代表模型的人类指令,输出代表遵循指令的期望输出。 这个过程有助于弥合LLMs的下一个词预测目标与用户让LLMs遵循人类指令的目标之间的差距。

在这个任务中我们会使用Qwen2-1.5b-Instruct模型在zh_cls_fudan_news数据集上进行指令微调任务,同时使用SwanLab进行监控和可视化。

  • 代码:完整代码直接看本文第5节
  • 实验日志过程:Qwen2-1.5B-Fintune - SwanLab
  • 模型:Modelscope
  • 数据集:zh_cls_fudan_news
  • SwanLab:https://swanlab.cn

本教程参考了这篇文章。

1.环境安装

本案例基于Python>=3.8,请在您的计算机上安装好Python,并且有一张英伟达显卡(显存要求并不高,大概10GB左右就可以跑)。

我们需要安装以下这几个Python库,在这之前,请确保你的环境内已安装了pytorch以及CUDA:

swanlab
modelscope
transformers
datasets
peft
accelerat
pandas

一键安装命令:

pip install swanlab modelscope transformers datasets peft pandas

本案例测试于modelscope1.14.0、transformers4.41.2、datasets2.18.0、peft0.11.1、accelerate0.30.1、swanlab0.3.9

2.准备数据集

本案例使用的是zh_cls_fudan-news数据集,该数据集主要被用于训练文本分类模型。

zh_cls_fudan-news由几千条数据,每条数据包含text、category、output三列:

  • text 是训练语料,内容是书籍或新闻的文本内容
  • category 是text的多个备选类型组成的列表
  • output 则是text唯一真实的类型

在这里插入图片描述

数据集例子如下:

"""
[PROMPT]Text: 第四届全国大企业足球赛复赛结束新华社郑州5月3日电(实习生田兆运)上海大隆机器厂队昨天在洛阳进行的第四届牡丹杯全国大企业足球赛复赛中,以5:4力克成都冶金实验厂队,进入前四名。沪蓉之战,双方势均力敌,90分钟不分胜负。最后,双方互射点球,沪队才以一球优势取胜。复赛的其它3场比赛,青海山川机床铸造厂队3:0击败东道主洛阳矿山机器厂队,青岛铸造机械厂队3:1战胜石家庄第一印染厂队,武汉肉联厂队1:0险胜天津市第二冶金机械厂队。在今天进行的决定九至十二名的两场比赛中,包钢无缝钢管厂队和河南平顶山矿务局一矿队分别击败河南平顶山锦纶帘子布厂队和江苏盐城无线电总厂队。4日将进行两场半决赛,由青海山川机床铸造厂队和青岛铸造机械厂队分别与武汉肉联厂队和上海大隆机器厂队交锋。本届比赛将于6日结束。(完)
Category: Sports, Politics
Output:[OUTPUT]Sports
"""

我们的训练任务,便是希望微调后的大模型能够根据Text和Category组成的提示词,预测出正确的Output。


我们将数据集下载到本地目录下。下载方式是前往zh_cls_fudan-news - 魔搭社区 ,将train.jsonltest.jsonl下载到本地根目录下即可:

在这里插入图片描述

3. 加载模型

这里我们使用modelscope下载Qwen2-1.5B-Instruct模型(modelscope在国内,所以下载不用担心速度和稳定性问题),然后把它加载到Transformers中进行训练:

from modelscope import snapshot_download, AutoTokenizer
from transformers import AutoModelForCausalLM, TrainingArguments, Trainer, DataCollatorForSeq2Seq# 在modelscope上下载Qwen模型到本地目录下
model_dir = snapshot_download("qwen/Qwen2-1.5B-Instruct", cache_dir="./", revision="master")# Transformers加载模型权重
tokenizer = AutoTokenizer.from_pretrained("./qwen/Qwen2-1___5B-Instruct/", use_fast=False, trust_remote_code=True)
model = AutoModelForCausalLM.from_pretrained("./qwen/Qwen2-1___5B-Instruct/", device_map="auto", torch_dtype=torch.bfloat16)

4. 配置训练可视化工具

我们使用SwanLab来监控整个训练过程,并评估最终的模型效果。

这里直接使用SwanLab和Transformers的集成来实现:

from swanlab.integration.huggingface import SwanLabCallbackswanlab_callback = SwanLabCallback(...)trainer = Trainer(...callbacks=[swanlab_callback],
)

如果你是第一次使用SwanLab,那么还需要去https://swanlab.cn上注册一个账号,在用户设置页面复制你的API Key,然后在训练开始时粘贴进去即可:

在这里插入图片描述

5. 完整代码

开始训练时的目录结构:

|--- train.py
|--- train.jsonl
|--- test.jsonl

train.py:

import json
import pandas as pd
import torch
from datasets import Dataset
from modelscope import snapshot_download, AutoTokenizer
from swanlab.integration.huggingface import SwanLabCallback
from peft import LoraConfig, TaskType, get_peft_model
from transformers import AutoModelForCausalLM, TrainingArguments, Trainer, DataCollatorForSeq2Seq
import os
import swanlabdef dataset_jsonl_transfer(origin_path, new_path):"""将原始数据集转换为大模型微调所需数据格式的新数据集"""messages = []# 读取旧的JSONL文件with open(origin_path, "r") as file:for line in file:# 解析每一行的json数据data = json.loads(line)context = data["text"]catagory = data["category"]label = data["output"]message = {"instruction": "你是一个文本分类领域的专家,你会接收到一段文本和几个潜在的分类选项,请输出文本内容的正确类型","input": f"文本:{context},类型选型:{catagory}","output": label,}messages.append(message)# 保存重构后的JSONL文件with open(new_path, "w", encoding="utf-8") as file:for message in messages:file.write(json.dumps(message, ensure_ascii=False) + "\n")def process_func(example):"""将数据集进行预处理"""MAX_LENGTH = 384 input_ids, attention_mask, labels = [], [], []instruction = tokenizer(f"<|im_start|>system\n你是一个文本分类领域的专家,你会接收到一段文本和几个潜在的分类选项,请输出文本内容的正确类型<|im_end|>\n<|im_start|>user\n{example['input']}<|im_end|>\n<|im_start|>assistant\n",add_special_tokens=False,)response = tokenizer(f"{example['output']}", add_special_tokens=False)input_ids = instruction["input_ids"] + response["input_ids"] + [tokenizer.pad_token_id]attention_mask = (instruction["attention_mask"] + response["attention_mask"] + [1])labels = [-100] * len(instruction["input_ids"]) + response["input_ids"] + [tokenizer.pad_token_id]if len(input_ids) > MAX_LENGTH:  # 做一个截断input_ids = input_ids[:MAX_LENGTH]attention_mask = attention_mask[:MAX_LENGTH]labels = labels[:MAX_LENGTH]return {"input_ids": input_ids, "attention_mask": attention_mask, "labels": labels}   def predict(messages, model, tokenizer):device = "cuda"text = tokenizer.apply_chat_template(messages,tokenize=False,add_generation_prompt=True)model_inputs = tokenizer([text], return_tensors="pt").to(device)generated_ids = model.generate(model_inputs.input_ids,max_new_tokens=512)generated_ids = [output_ids[len(input_ids):] for input_ids, output_ids in zip(model_inputs.input_ids, generated_ids)]response = tokenizer.batch_decode(generated_ids, skip_special_tokens=True)[0]print(response)return response# 在modelscope上下载Qwen模型到本地目录下
model_dir = snapshot_download("qwen/Qwen2-1.5B-Instruct", cache_dir="./", revision="master")# Transformers加载模型权重
tokenizer = AutoTokenizer.from_pretrained("./qwen/Qwen2-1___5B-Instruct/", use_fast=False, trust_remote_code=True)
model = AutoModelForCausalLM.from_pretrained("./qwen/Qwen2-1___5B-Instruct/", device_map="auto", torch_dtype=torch.bfloat16)
model.enable_input_require_grads()  # 开启梯度检查点时,要执行该方法# 加载、处理数据集和测试集
train_dataset_path = "train.jsonl"
test_dataset_path = "test.jsonl"train_jsonl_new_path = "new_train.jsonl"
test_jsonl_new_path = "new_test.jsonl"if not os.path.exists(train_jsonl_new_path):dataset_jsonl_transfer(train_dataset_path, train_jsonl_new_path)
if not os.path.exists(test_jsonl_new_path):dataset_jsonl_transfer(test_dataset_path, test_jsonl_new_path)# 得到训练集
train_df = pd.read_json(train_jsonl_new_path, lines=True)
train_ds = Dataset.from_pandas(train_df)
train_dataset = train_ds.map(process_func, remove_columns=train_ds.column_names)config = LoraConfig(task_type=TaskType.CAUSAL_LM,target_modules=["q_proj", "k_proj", "v_proj", "o_proj", "gate_proj", "up_proj", "down_proj"],inference_mode=False,  # 训练模式r=8,  # Lora 秩lora_alpha=32,  # Lora alaph,具体作用参见 Lora 原理lora_dropout=0.1,  # Dropout 比例
)model = get_peft_model(model, config)args = TrainingArguments(output_dir="./output/Qwen1.5",per_device_train_batch_size=4,gradient_accumulation_steps=4,logging_steps=10,num_train_epochs=2,save_steps=100,learning_rate=1e-4,save_on_each_node=True,gradient_checkpointing=True,report_to="none",
)swanlab_callback = SwanLabCallback(project="Qwen2-fintune",experiment_name="Qwen2-1.5B-Instruct",description="使用通义千问Qwen2-1.5B-Instruct模型在zh_cls_fudan-news数据集上微调。",config={"model": "qwen/Qwen2-1.5B-Instruct","dataset": "huangjintao/zh_cls_fudan-news",}
)trainer = Trainer(model=model,args=args,train_dataset=train_dataset,data_collator=DataCollatorForSeq2Seq(tokenizer=tokenizer, padding=True),callbacks=[swanlab_callback],
)trainer.train()# 用测试集的前10条,测试模型
test_df = pd.read_json(test_jsonl_new_path, lines=True)[:10]test_text_list = []
for index, row in test_df.iterrows():instruction = row['instruction']input_value = row['input']messages = [{"role": "system", "content": f"{instruction}"},{"role": "user", "content": f"{input_value}"}]response = predict(messages, model, tokenizer)messages.append({"role": "assistant", "content": f"{response}"})result_text = f"{messages[0]}\n\n{messages[1]}\n\n{messages[2]}"test_text_list.append(swanlab.Text(result_text, caption=response))swanlab.log({"Prediction": test_text_list})
swanlab.finish()

看到下面的进度条即代表训练开始:

在这里插入图片描述

6.训练结果演示

在SwanLab上查看最终的训练结果:

可以看到在2个epoch之后,微调后的qwen2的loss降低到了不错的水平——当然对于大模型来说,真正的效果评估还得看主观效果。

在这里插入图片描述

可以看到在一些测试样例上,微调后的qwen2能够给出准确的文本类型:

在这里插入图片描述

至此,你已经完成了qwen2指令微调的训练!

相关链接

  • 代码:完整代码直接看本文第5节
  • 实验日志过程:Qwen2-1.5B-Fintune - SwanLab
  • 模型:Modelscope
  • 数据集:zh_cls_fudan_news
  • SwanLab:https://swanlab.cn

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.mzph.cn/bicheng/25268.shtml

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

倩女幽魂手游攻略:云手机自动搬砖辅助教程!

《倩女幽魂》手游自问世以来一直备受玩家喜爱&#xff0c;其精美画面和丰富的游戏内容让人沉迷其中。而如今&#xff0c;借助VMOS云手机&#xff0c;玩家可以更轻松地进行搬砖&#xff0c;提升游戏体验。 一、准备工作 下载VMOS云手机&#xff1a; 在PC端或移动端下载并安装VM…

流程的控制

条件选择语句 我们一般将条件选择语句分为三类&#xff1a; 单条件双条件多条件 本篇文章将分开诉说着三类。 单条件 单条件的语法很简单&#xff1a; if (条件) {// 代码}条件这里我们需要注意下&#xff0c;可以向里写入两种&#xff1a; 布尔值布尔表达式 当然&…

Docker高级篇之Docker网络

文章目录 1. Docker Network简介2. Docker 网络模式3. Docker 网络模式之bridge4. Docker 网络模式之host5. Docker 网络模式之none6. Docker 网络模式之container7. Docker 网络模式之自定义网络模式 1. Docker Network简介 从Docker的架构和运作流程来看&#xff0c;Docker是…

计算机组成原理之指令寻址

一、顺序寻址 1、定长指令字结构 2、变长指令字结构 二、跳跃寻址 三、数据寻址 1、直接寻址 2、间接寻址 3、寄存器寻址 寄存器间接寻址 4、隐含寻址 5、立即寻址 6、偏移寻址 1、基址寻址 2、变址寻址 3、相对寻址

力扣199. 二叉树的右视图

给定一个二叉树的 根节点 root&#xff0c;想象自己站在它的右侧&#xff0c;按照从顶部到底部的顺序&#xff0c;返回从右侧所能看到的节点值。 示例 1: 输入: [1,2,3,null,5,null,4] 输出: [1,3,4]示例 2: 输入: [1,null,3] 输出: [1,3]示例 3: 输入: [] 输出: [] /*** Def…

语法分析!!!

一、实验题目 根据给定文法编写调试预测分析程序&#xff0c;对任意输入串用预测分析法进行语法分析。 二、实验目的 加深对预测分析法的理解。 三、实验内容 四、实验代码 #include <iostream> #include <stdio.h> #include <string> #include <…

隐式链接DLL

本文仅供学习交流&#xff0c;严禁用于商业用途&#xff0c;如本文涉及侵权请及时联系本人将于及时删除 【例9.5】创建的基于MFC对话框的应用程序MFCImLink2&#xff0c;隐式链接例9.2创建的MFCLibrary2.dll&#xff0c;使用其中的导出函数求正方形的面积。 (1) 使用MFC应用程…

【零基础一看就会】Python爬虫从入门到应用(下)

目录 一、urllib的学习 1.1 urllib介绍 1.2 urllib的基本方法介绍 urllib.Request &#xff08;1&#xff09;构造简单请求 &#xff08;2&#xff09;传入headers参数 &#xff08;3&#xff09;传入data参数 实现发送post请求&#xff08;示例&#xff09; response.…

野火FPGA跟练(四)——串口RS232、亚稳态

目录 简介接口与引脚通信协议亚稳态RS232接收模块模块框图时序波形RTL 代码易错点Testbench 代码仿真 RS232发送模块模块框图时序波形RTL 代码Testbench 代码仿真 简介 UART&#xff1a;Universal Asynchronous Receiver/Transmitter&#xff0c;异步串行通信接口。发送数据时…

微服务开发与实战Day04

一、网关路由 网关&#xff1a;就是网络的关口&#xff0c;负责请求的路由、转发、身份校验。 在SpringCloud中网关的实现包括两种&#xff1a; 1. 快速入门 Spring Cloud Gateway 步骤&#xff1a; ①新建hm-gateway模块 ②引入依赖pom.xml(hm-gateway) <?xml version…

解锁俄罗斯市场:如何选择优质的俄罗斯云服务器

在当前云计算市场上&#xff0c;很多大型的云厂商并没有俄罗斯服务器的云节点&#xff0c;这给许多企业在拓展海外业务时带来了一定的困扰。然而&#xff0c;俄罗斯作为一个经济发展迅速的国家&#xff0c;其市场潜力不可忽视。因此&#xff0c;选择一台优质的俄罗斯云服务器成…

【MySQL】(基础篇三) —— 创建数据库和表

管理数据库和表 管理数据库 创建数据库 在MySQL中&#xff0c;创建数据库的SQL命令相对简单&#xff0c;基本语法如下&#xff1a; CREATE DATABASE 数据库名;如果你想避免在尝试创建已经存在的数据库时出现错误&#xff0c;可以添加 IF NOT EXISTS 子句&#xff0c;这样如…

数据结构(C):二叉树前中后序和层序详解及代码实现及深度刨析

目录 &#x1f31e;0.前言 &#x1f688;1.二叉树链式结构的代码是实现 &#x1f688;2.二叉树的遍历及代码实现和深度刨析代码 &#x1f69d;2.1前序遍历 ✈️2.1.1前序遍历的理解 ✈️2.1.2前序代码的实现 ✈️2.1.3前序代码的深度解剖 &#x1f69d;2.2中序遍历 ✈…

计算机网络:数据链路层 - 扩展的以太网

计算机网络&#xff1a;数据链路层 - 扩展的以太网 集线器交换机自学习算法单点故障 集线器 这是以前常见的总线型以太网&#xff0c;他最初使用粗铜轴电缆作为传输媒体&#xff0c;后来演进到使用价格相对便宜的细铜轴电缆。 后来&#xff0c;以太网发展出来了一种使用大规模…

AI菜鸟向前飞 — LangChain系列之十七 - 剖析AgentExecutor

AgentExecutor 顾名思义&#xff0c;Agent执行器&#xff0c;本篇先简单看看LangChain是如何实现的。 先回顾 AI菜鸟向前飞 — LangChain系列之十四 - Agent系列&#xff1a;从现象看机制&#xff08;上篇&#xff09; AI菜鸟向前飞 — LangChain系列之十五 - Agent系列&#…

Springboot使用webupload大文件分片上传(包含前后端源码)

Springboot使用webupload大文件分片上传&#xff08;包含源码&#xff09; 1. 实现效果1.1 分片上传效果图1.2 分片上传技术介绍 2. 分片上传前端实现2.1 什么是WebUploader&#xff1f;功能特点接口说明事件APIHook 机制 2.2 前端代码实现2.2.1&#xff08;不推荐&#xff09;…

计算机组成原理之计算机系统层次结构

目录 计算机系统层次结构 复习提示 1.计算机系统的组成 2.计算机硬件 2.1冯诺依曼机基本思想 2.1.1冯诺依曼计算机的特点 2.2计算机的功能部件 2.2.1MAR 和 MDR 位数的概念和计算 3.计算机软件 3.1系统软件和应用软件 3.2三个级别的语言 3.2.1三种机器语言的特点 3…

★pwn 24.04环境搭建保姆级教程★

★pwn 24.04环境搭建保姆级教程★ &#x1f338;前言&#x1f33a;Ubuntu 24.04虚拟机&#x1f337;VM&#x1f337;Ubuntu 24.04镜像 &#x1f33a;工具&#x1f337;可能出现的git clone错误&#x1f337;复制粘贴问题&#x1f337;攻击&#x1f337;编题 &#x1f33a;美化&…

【AI大模型】Transformers大模型库(五):AutoModel、Model Head及查看模型结构

目录​​​​​​​ 一、引言 二、自动模型类&#xff08;AutoModel&#xff09; 2.1 概述 2.2 Model Head&#xff08;模型头&#xff09; 2.3 代码示例 三、总结 一、引言 这里的Transformers指的是huggingface开发的大模型库&#xff0c;为huggingface上数以万计的预…

使用 Keras 的 Stable Diffusion 实现高性能文生图

前言 在本文中&#xff0c;我们将使用基于 KerasCV 实现的 [Stable Diffusion] 模型进行图像生成&#xff0c;这是由 stable.ai 开发的文本生成图像的多模态模型。 Stable Diffusion 是一种功能强大的开源的文本到图像生成模型。虽然市场上存在多种开源实现可以让用户根据文本…