解锁ChatGLM-6B的潜力:优化大语言模型训练,突破任务困难与答案解析难题

解锁ChatGLM-6B的潜力:优化大语言模型训练,突破任务困难与答案解析难题

LLM(Large Language Model)通常拥有大量的先验知识,使得其在许多自然语言处理任务上都有着不错的性能。

但,想要直接利用 LLM 完成一些任务会存在一些答案解析上的困难,如规范化输出格式,严格服从输入信息等。

因此,在这个项目下我们参考 ChatGLM-Tuning 的代码,尝试对大模型 ChatGLM-6B 进行 Finetune,使其能够更好的对齐我们所需要的输出格式。

1. 环境安装

由于 ChatGLM 需要的环境和该项目中其他实验中的环境有所不同,因此我们强烈建议您创建一个新的虚拟环境来执行该目录下的全部代码。

下面,我们将以 Anaconda 为例,展示如何快速搭建一个环境:

  1. 创建一个虚拟环境,您可以把 llm_env 修改为任意你想要新建的环境名称:
conda create -n llm_env python=3.8
  1. 激活新建虚拟环境并安装响应的依赖包:
conda activate llm_env
pip install -r requirements.txt
  1. 安装对应版本的 peft
cd peft-chatglm
python setup.py install

2. 数据集准备

在该实验中,我们将尝试使用 信息抽取 + 文本分类 任务的混合数据集喂给模型做 finetune,数据集在 data/mixed_train_dataset.jsonl

每一条数据都分为 contexttarget 两部分:

  1. context 部分是接受用户的输入。

  2. target 部分用于指定模型的输出。

context 中又包括 2 个部分:

  1. Instruction:用于告知模型的具体指令,当需要一个模型同时解决多个任务时可以设定不同的 Instruction 来帮助模型判别当前应当做什么任务。

  2. Input:当前用户的输入。

  • 信息抽取数据示例

Instruction 部分告诉模型现在需要做「阅读理解」任务,Input 部分告知模型要抽取的句子以及输出的格式。

{"context": "Instruction: 你现在是一个很厉害的阅读理解器,严格按照人类指令进行回答。\nInput: 找到句子中的三元组信息并输出成json给我:\n\n九玄珠是在纵横中文网连载的一部小说,作者是龙马。\nAnswer: ", "target": "```json\n[{\"predicate\": \"连载网站\", \"object_type\": \"网站\", \"subject_type\": \"网络小说\", \"object\": \"纵横中文网\", \"subject\": \"九玄珠\"}, {\"predicate\": \"作者\", \"object_type\": \"人物\", \"subject_type\": \"图书作品\", \"object\": \"龙马\", \"subject\": \"九玄珠\"}]\n```"
}
  • 文本分类数据示例

Instruction 部分告诉模型现在需要做「阅读理解」任务,Input 部分告知模型要抽取的句子以及输出的格式。

{"context": "Instruction: 你现在是一个很厉害的阅读理解器,严格按照人类指令进行回答。\nInput: 下面句子可能是一条关于什么的评论,用列表形式回答:\n\n很不错,很新鲜,快递小哥服务很好,水果也挺甜挺脆的\nAnswer: ", "target": "[\"水果\"]"
}

3. 模型训练

3.1 单卡训练

实验中支持使用 LoRA Finetune 和 P-Tuning 两种微调方式。

运行 train.sh 文件,根据自己 GPU 的显存调节 batch_size, max_source_seq_len, max_target_seq_len 参数:

# LoRA Finetune
python train.py \--train_path data/mixed_train_dataset.jsonl \--dev_path data/mixed_dev_dataset.jsonl \--use_lora True \--lora_rank 8 \--batch_size 1 \--num_train_epochs 2 \--save_freq 1000 \--learning_rate 3e-5 \--logging_steps 100 \--max_source_seq_len 400 \--max_target_seq_len 300 \--save_dir checkpoints/finetune \--img_log_dir "log/fintune_log" \--img_log_name "ChatGLM Fine-Tune" \--device cuda:0# P-Tuning
python train.py \--train_path data/mixed_train_dataset.jsonl \--dev_path data/mixed_dev_dataset.jsonl \--use_ptuning True \--pre_seq_len 128 \--batch_size 1 \--num_train_epochs 2 \--save_freq 200 \--learning_rate 2e-4 \--logging_steps 100 \--max_source_seq_len 400 \--max_target_seq_len 300 \--save_dir checkpoints/ptuning \--img_log_dir "log/fintune_log" \--img_log_name "ChatGLM P-Tuning" \--device cuda:0

成功运行程序后,会看到如下界面:

...
global step 900 ( 49.89% ) , epoch: 1, loss: 0.78065, speed: 1.25 step/s, ETA: 00:12:05
global step 1000 ( 55.43% ) , epoch: 2, loss: 0.71768, speed: 1.25 step/s, ETA: 00:10:44
Model has saved at checkpoints/model_1000.
Evaluation Loss: 0.17297
Min eval loss has been updated: 0.26805 --> 0.17297
Best model has saved at checkpoints/model_best.
global step 1100 ( 60.98% ) , epoch: 2, loss: 0.66633, speed: 1.24 step/s, ETA: 00:09:26
global step 1200 ( 66.52% ) , epoch: 2, loss: 0.62207, speed: 1.24 step/s, ETA: 00:08:06
...

log/finetune_log 下会看到训练 loss 的曲线图:

3.2 多卡训练

运行 train_multi_gpu.sh 文件,通过 CUDA_VISIBLE_DEVICES 指定可用显卡,num_processes 指定使用显卡数:

# LoRA Finetune
CUDA_VISIBLE_DEVICES=0,1 accelerate launch --multi_gpu --mixed_precision=fp16 --num_processes=2 train_multi_gpu.py \--train_path data/mixed_train_dataset.jsonl \--dev_path data/mixed_dev_dataset.jsonl \--use_lora True \--lora_rank 8 \--batch_size 1 \--num_train_epochs 2 \--save_freq 500 \--learning_rate 3e-5 \--logging_steps 100 \--max_source_seq_len 400 \--max_target_seq_len 300 \--save_dir checkpoints_parrallel/finetune \--img_log_dir "log/fintune_log" \--img_log_name "ChatGLM Fine-Tune(parallel)"# P-Tuning
CUDA_VISIBLE_DEVICES=0,1 accelerate launch --multi_gpu --mixed_precision=fp16 --num_processes=2 train_multi_gpu.py \--train_path data/mixed_train_dataset.jsonl \--dev_path data/mixed_dev_dataset.jsonl \--use_ptuning True \--pre_seq_len 128 \--batch_size 1 \--num_train_epochs 2 \--save_freq 500 \--learning_rate 2e-4 \--logging_steps 100 \--max_source_seq_len 400 \--max_target_seq_len 300 \--save_dir checkpoints_parrallel/ptuning \--img_log_dir "log/fintune_log" \--img_log_name "ChatGLM P-Tuning(parallel)"

相同数据集下,单卡使用时间:

Used 00:27:18.

多卡(2并行)使用时间:

Used 00:13:05.

4. 模型预测

修改训练模型的存放路径,运行 python inference.py 以测试训练好模型的效果:

device = 'cuda:0'
max_new_tokens = 300
model_path = "checkpoints/model_1000"           # 模型存放路径tokenizer = AutoTokenizer.from_pretrained(model_path, trust_remote_code=True
)model = AutoModel.from_pretrained(model_path,trust_remote_code=True
).half().to(device)
...

您也可以使用我们提供的 Playground 来进行模型效果测试:

streamlit run playground_local.py --server.port 8001

在浏览器中打开对应的 机器ip:8001 即可访问。

5. 标注平台

如果您需要标注自己的数据,也可以在 Playground 中完成。

streamlit run playground_local.py --server.port 8001

在浏览器中打开对应的 机器ip:8001 即可访问。

项目链接:https://github.com/HarderThenHarder/transformers_tasks/blob/main/LLM/chatglm_finetune/readme.md

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.mzph.cn/news/49940.shtml

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

【rust/egui】(四)看看template的app.rs:update以及组件TopBottomPanelButton

说在前面 rust新手,egui没啥找到啥教程,这里自己记录下学习过程环境:windows11 22H2rust版本:rustc 1.71.1egui版本:0.22.0eframe版本:0.22.0上一篇:这里 update update实际上还是eframe::App的…

自研分布式IM-HubuIM RFC草案

HubuIM RFC草案 消息协议设计 基本协议 评估标准 【性能】协议传输效率,尽可能降低端到端的延迟,延迟高于200ms用户侧就会有所感知 【兼容】既要向前兼容也要向后兼容 【存储】减少消息包的大小,降低空间占用率,一个字节在亿…

最新AI系统ChatGPT程序源码/微信公众号/H5端+搭建部署教程+完整知识库

一、前言 SparkAi系统是基于国外很火的ChatGPT进行开发的Ai智能问答系统。本期针对源码系统整体测试下来非常完美,可以说SparkAi是目前国内一款的ChatGPT对接OpenAI软件系统。 那么如何搭建部署AI创作ChatGPT?小编这里写一个详细图文教程吧&#xff01…

二级MySQL(五)——完整性约束练习

在MYSQL中,通常用来指定一个已有数据库作为当前数据库的语句是【USE】 下列选项中不是MYSQL中常用数据类型的是【VAR】 在MYSQL中,常用【NULL】表示一个字段没有值或缺值 在CREATE TABLE语句中,通常使用【PRIMARY KEY】关键字来指定主键 …

IoT DC3 是一个基于 Spring Cloud 的开源的、分布式的物联网(IoT)平台本地部署步骤

dc3 windows 本地搭建步骤: ​​ 必要软件环境 进入原网页# 务必保证至少需要给 docker 分配:1 核 CPU 以及 4G 以上的运行内存! JDK : 推荐使用 Oracle JDK 1.8 或者 OpenJDK8,理论来说其他版本也行; Maven : 推荐…

【视觉SLAM入门】5.2. 2D-3D PNP 3D-3D ICP BA非线性优化方法 数学方法SVD DLT

"养气之学,戒之躁急" 1. 3D-2D PNP1.1 代数法1.1.1 DLT(直接线性变换法)1.1.2. P3P 1.2 优化法BA (Bundle Adjustment)法 2. 3D-3D ICP2.1 代数法2.1.1 SVD方法 2.2 优化(BA)法2.2.2 非线性优化方法 前置事项: 1. 3D-2D PNP 该问题描述为&am…

记录《现有docker中安装spark3.4.1》

基础docker环境中存储hadoop3--方便后续查看 参考: 实践: export JAVA_HOME/opt/apache/jdk1.8.0_333 export SPARK_MASTER_IP192.168.0.220 export SPARK_WORKER_MEMORY4g export SPARK_WORKER_CORES2 export SPARK_EXECUTOR_MEMORY4g export HADOOP_H…

Mongodb两种启动方法

一、命令行启动 1.修改存放数据库的位置 说明:E:\data\mongodb;我在E盘创建的文件夹mongodb mongod --dbpathE:\data\mongodb 2.成功启动 说明:默认端口27017,代表已经启动成功 ,并在mongodb自动创建文件 二、配置项…

华为AR路由器 典型配置案例——以太网交换

目录 Eth-Trunk 例:配置三层链路聚合 组网需求 操作步骤 检查配置结果 配置脚本 VLAN 举例:配置基于接口划分VLAN,实现同一VLAN内的互通(同设备) 组网需求 操作步骤 检查配置结果 配置脚本 举例&#xff…

TCP半连接队列和全连接队列

目录 什么是 TCP 半连接队列和全连接队列? TCP 全连接队列溢出 如何知道应用程序的 TCP 全连接队列大小? 如何模拟 TCP 全连接队列溢出的场景? 全连接队列溢出会发生什么 ? 如何增大全连接队列呢 ? TCP 半连接队列溢出 如何查看 TC…

22款美规奔驰S500升级原厂香氛负离子系统,清香宜人,久闻不腻

奔驰原厂香氛合理性可通过车内空气调节组件营造芳香四溢的怡人氛围。通过更换手套箱内香氛喷雾发生器所用的香水瓶,可轻松选择其他香氛。香氛的浓度和持续时间可调。淡雅的香氛缓缓喷出,并且在关闭后能够立刻散去。车内气味不会永久改变,香氛…

Linux 虚拟机安装 hadoop

目录 1 hadoop下载 2 解压hadoop 3 为 hadoop 文件夹改名 4 给 hadoop 文件夹赋权 5 修改环境变量 6 刷新环境变量 7 在hadoop313目录下创建文件夹data 8 检查文件 9 编辑 ./core-site.xml文件 10 编辑./hadoop-env.sh文件 11 编辑./hdfs-site.xml文件 12 编辑./mapr…

【C++入门到精通】C++入门 —— 模版(template)

阅读导航 前言一、模版的概念二、函数模版1. 函数模板概念2. 函数模板定义格式3. 函数模板的原理4. 函数模版的实例化🚩隐式实例化🚩显式实例化 5. 函数模板的匹配原则 三、类模板1. 类模板的定义格式2. 类模板的实例化 四、非类型模板参数1. 概念2. 定义…

动态规划之路径问题

路径问题 1. 不同路径(medium)2. 不同路径II(medium)3. 礼物最大值(medium)4. 下降路径最小和(medium)5. 最⼩路径和(medium)6. 地下城游戏(hard&…

图床项目进度(一)——UI首页

1. 前言 前面我不是说了要做一个图床吗,现在在做ui。 我vue水平不够高,大部分参考b站项目照猫画虎。 vue实战后台 我使用ts,vite,vue3进行了重构。 当然,我对这些理解并不深刻,许多代码都是游离于表面&am…

基于决策树(Decision Tree)的乳腺癌诊断

决策树(DecisionTree)学习是以实例为基础的归纳学习算法。算法从--组无序、无规则的事例中推理出决策树表示形式的分类规则,决策树也能表示为多个If-Then规则。一般在决策树中采用“自顶向下、分而治之”的递归方式,将搜索空间分为若千个互不相交的子集,在决策树的内部节点(非叶…

Spring事务和事务传播机制

目录 一. Spring 中事务的实现 1. 编程式事务 2. 声明式事务 3. Transaction 参数的设置 4. Transaction 的隔离级别 5. Transaction 的工作原理 二. Spring 事务传播机制 七种事务传播机制 支持当前事务 不支持当前事务 嵌套事务 一. Spring 中事务的实现 1. 编程式事…

mmdetection基于 PyTorch 的目标检测开源工具箱 入门教程

安装环境 MMDetection 支持在 Linux,Windows 和 macOS 上运行。它需要 Python 3.7 以上,CUDA 9.2 以上和 PyTorch 1.8 及其以上。 1、安装依赖 步骤 0. 从官方网站下载并安装 Miniconda。 步骤 1. 创建并激活一个 conda 环境。 conda create --name…

Docker是什么?详谈它的框架、使用场景、优势

作者:Insist-- 个人主页:insist--个人主页 作者会持续更新网络知识和python基础知识,期待你的关注 目录 一、什么是 Docker? 二、Docker 的架构 1、Docker客户端 2、Docker守护进程 3、Docker镜像 4、Docker容器 5、Docker…

【业务功能篇76】微服务网关路由predicates断言条件-filters路由转换地址-跨域问题-多级目录树化层级设计-mybatisPlus逻辑删除

业务开发-基础业务-分类管理 启动renren-fast如果出现如下错误 -Djps.track.ap.dependenciesfalse 添加相关配置即可 分类管理 1.后端分类接口 JDK8特性:https://blog.csdn.net/qq_38526573/category_11113126.html 在后端服务中我们需要查询出所有的三级分类信…