大模型增量预训练新技巧-解决灾难性遗忘

大模型增量预训练新技巧-解决灾难性遗忘

机器学习算法与自然语言处理 2024年03月21日 00:02 吉林

以下文章来源于NLP工作站 ,作者刘聪NLP

NLP工作站.

AIGC前沿知识分享&落地经验总结

转载自 | NLP工作站

作者 | 刘聪NLP

目前不少开源模型在通用领域具有不错的效果,但由于缺乏领域数据,往往在一些垂直领域中表现不理想,这时就需要增量预训练和微调等方法来提高模型的领域能力。

但在领域数据增量预训练或微调时,很容易出现灾难性遗忘现象,也就是学会了垂直领域知识,但忘记了通用领域知识,之前介绍过增量预训练以及领域大模型训练技巧,详见:

  • 如何更好地继续预训练-Continue PreTraining

  • 领域大模型-训练Trick&落地思考

今天给大家带来一篇增量预训练方法-Llama-Pro,对LLMs进行Transformer块扩展后,增量预训练过程中仅对新增块进行训练,有效地进行模型知识注入,并且极大程度地避免灾难性遗忘。

图片

LLaMA Pro: Progressive LLaMA with Block Expansion

 

LLaMA Pro: Progressive LLaMA with Block Expansion
Paper: https://arxiv.org/abs/2401.02415
Github: https://github.com/TencentARC/LLaMA-Pro

块扩展方法

块扩展,顾名思义,就是在原始模型中每个Transformer块或者某几个Transformer块增加一个Transformer块,但为了保持扩展后的模型输出保持不变,需要增加的块为恒等块(输入输出相同),如下图所示。

图片

在构建恒等块过程中,主要是将多头注意力层和FFN层中的最后一个线性层(Linear权重置为0变成Zero-Linear,即可保持经过该块的输入输出一致。

PS:论文附录A中写了大段的推导公式来证明,在此不做过多介绍。

块的增加方式是,对原始模型的L个Transformer块分成N组,每组中包含M=L/N个Transformer块,对于每组后添加P个恒等块。代码实现具体如下:

model = AutoModelForCausalLM.from_pretrained(model_path, torch_dtype=torch.float16)
ckpt = model.state_dict()# original_layers是模型原始层数,layers是模型最后达到层数
split = int(original_layers / (layers - original_layers))layer_cnt = 0output = {}
for i in range(original_layers):for k in ckpt:if ('layers.' + str(i) + '.') in k:output[k.replace(('layers.' + str(i) + '.'), ('layers.' + str(layer_cnt) + '.'))] = ckpt[k]layer_cnt += 1if (i+1) % split == 0:for k in ckpt:if ('layers.' + str(i) + '.') in k:if 'down_proj' in k or 'o_proj' in k:output[k.replace(('layers.' + str(i) + '.'), ('layers.' + str(layer_cnt) + '.'))] = torch.zeros_like(ckpt[k])else:output[k.replace(('layers.' + str(i) + '.'), ('layers.' + str(layer_cnt) + '.'))] = ckpt[k]layer_cnt += 1assert layer_cnt==layers
for k in ckpt:if not 'layers' in k:output[k] = ckpt[k]torch.save(output, output_path)

实验细节

数据由代码和数学组成,其中代码数据采用The-Stack-Dedup数据集中Python语言部分共22B Token,数学数据采用Proof-Pile-2数据集中AlgebraicStack、OpenWebMath和ArXiv部分共55B,详细如下表所示。

图片

数据分布

基础模型为LLaMA2-7B模型,通过块扩展方法将32层模型扩展到40层,其中 P=1,M=4,N=8,每个组从4个Transformer块扩展到5个Transformer块。

对于代码和数学数据进行增量预训练,批量大小为1024,序列最大长度为4096,预热比率为6%,学习率为2e-4,采用余弦学习率调度器,BF16混合精度训练,权重衰减为0.1。使用16个NVIDIA H800 GPU进行了15900个步骤的训练,大约耗费2830个GPU/小时

ARC、HellaSwag、MMLU、TruthfulQA、Winogrande、GSM8K、GSM8K-PoT、HumanEval、MBPP等多个评测数据集中进行评测,可以看出,在保持通用任务能力不下降的情况下,数学和代码能力较原始LLaMA2-7B模型有很大提升。

图片

图片

讨论分析

对比块扩展方法与正常训练和Lora方法之间的区别,采用TRACE基准利用总体性能(OP)和逆向转移(BWT)指标进行评估。,如下表所示,块扩展方法整体提升较大。

图片

对比块个数对块扩展方法的影响,进行了不同个数块的实验,并且对比了MoE的方法,训练损失如下,MoE方法的损失下降程度与添加四个块相当

图片

代码和法律(16.7B)领域数据下进行增量预训练,在通用任务以及领域任务上比较不同个数块之间的差异,同时比较扩展块全部添加到模型底部或顶部之间的差别,如下所示。可以发现块个数为8时效果最佳,并且不能直接将扩展块全部堆积在头部或尾部需要分开插入

图片

写在最后

该方法主要通过增加恒定块扩展模型层数,使模型在增量训练过程中仅训练新增层、冻结原始层,保持模型原有能力,防止模型出现灾难性遗忘现象。

但有两点存疑:

  • 目前来说mistral要好于llama,为啥不用mistral进行实验

  • 不用恒定块,性能会差多少

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.mzph.cn/web/42029.shtml

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

G1 和 CMS

1、CMS CMS(Concurrent Mark Sweep,并发标记清除,是为了解决早期垃圾收集器在执行垃圾回收时导致应用程序暂停时间过长的问题而设计的。 CMS的工作流程主要包括以下几个阶段: 初始标记(Initial Mark)&…

一体化运维监控平台:赋能各行业用户运维升级

在当今数字化转型的大潮中,企业IT系统的复杂性和规模不断攀升,对运维团队提出了前所未有的挑战。如何高效、精准地监控和管理IT基础设施,确保业务连续性和稳定性,成为所有企业关注的焦点。美信,自2007年成立以来&#…

el-scrollbar实现自动滚动到底部(AI聊天)

目录 项目背景 实现步骤 实现代码 完整示例代码 项目背景 chatGPT聊天消息展示滚动面板,每次用户输入提问内容或者ai进行流式回答时需要不断的滚动到底部确保展示最新的消息。 实现步骤 采用element ui 的el-scrollbar作为聊天消息展示组件。 通过操作dom来实…

端、边、云三级算力网络

目录 端、边、云三级算力网络 NPU Arm架构 OpenStack kubernetes k3s轻量级Kubernetes kubernetes和docker区别 DCI(Data Center Interconnect) SD/WAN TF 端、边、云三级算力网络 算力网络从传统云网融合的角度出发,结合 边缘计算、网络云化以及智能控制的优势,通…

Qt开发 | Qt创建线程 | Qt并发-QtConcurrent

文章目录 一、Qt创建线程的三种方法二、Qt并发:QtConcurrent介绍三、QtConcurrent run参数说明四、获取QtConcurrent的返回值五、C其他线程技术介绍 一、Qt创建线程的三种方法 以下是Qt创建线程的三种方法: 方法一:派生于QThread 派生于QThre…

理解算法复杂度:空间复杂度详解

引言 在计算机科学中,算法复杂度是衡量算法效率的重要指标。时间复杂度和空间复杂度是算法复杂度的两个主要方面。在这篇博客中,我们将深入探讨空间复杂度,了解其定义、常见类型以及如何进行分析。空间复杂度是衡量算法在执行过程中所需内存…

ceph mgr [errno 39] RBD image has snapshots (error deleting image from trash)

ceph mgr 报错 debug 2024-07-08T09:25:56.512+0000 7f9c63bd2700 0 [rbd_support INFO root] execute_task: task={"sequence": 3, "id": "260b9fee-d567-4301-b7eb-b1fe1b037413", "message": "Removing image replicapool/8…

昇思25天学习打卡营第19天|Diffusion扩散模型

学AI还能赢奖品?每天30分钟,25天打通AI任督二脉 (qq.com) Diffusion扩散模型 本文基于Hugging Face:The Annotated Diffusion Model一文翻译迁移而来,同时参考了由浅入深了解Diffusion Model一文。 本教程在Jupyter Notebook上成…

python库 - missingno

missingno 是一个用于可视化和分析数据集中缺失值的 Python 库。它提供了一系列简单而强大的工具,帮助用户直观地理解数据中的缺失模式,从而更好地进行数据清洗和预处理。missingno 库特别适用于数据分析和数据科学项目,尤其是在处理缺失数据…

昇思MindSpore学习笔记5-02生成式--RNN实现情感分类

摘要: 记录MindSpore AI框架使用RNN网络对自然语言进行情感分类的过程、步骤和方法。 包括环境准备、下载数据集、数据集加载和预处理、构建模型、模型训练、模型测试等。 一、概念 情感分类。 RNN网络模型 实现效果: 输入: This film is terrible 正…

放大镜案例

放大镜 <!DOCTYPE html> <html lang"zh-cn"><head><meta charset"UTF-8"><meta name"viewport" content"widthdevice-width, initial-scale1.0"><title>商品放大镜</title><link rel&qu…

如何使用allure生成测试报告

第一步下载安装JDK1.8&#xff0c;参考链接JDK1.8下载、安装和环境配置教程-CSDN博客 第二步配置allure环境&#xff0c;参考链接allure的安装和使用(windows环境)_allure windows-CSDN博客 第三步&#xff1a; 第四步&#xff1a; pytest 查看目前运行的测试用例有无错误 …

如何使用 pytorch 创建一个神经网络

我已发布在&#xff1a;如何使用 pytorch 创建一个神经网络 SapientialM.Github.io 构建神经网络 1 导入所需包 import os import torch from torch import nn from torch.utils.data import DataLoader from torchvision import datasets, transforms2 检查GPU是否可用 dev…

ffmpeg滤镜创建过程

1、创建一个滤镜图 AVFilterGraph *filter_graph avfilter_graph_alloc(); 2、创建滤镜的输入和输出 AVFilterInOut *inputs avfilter_inout_alloc(); AVFilterInOut *outputs avfilter_inout_alloc(); 3、每个滤镜创建上下文 AVFilterContext *filter1_ctx avfilter_…

Yolov10训练,转化onnx,推理

yolov10对于大目标的效果好&#xff0c;小目标不好 一、如果你训练过yolov5&#xff0c;yolov8&#xff0c;的话那么你可以直接用之前的环境就行 目录 一、如果你训练过yolov5&#xff0c;yolov8&#xff0c;的话那么你可以直接用之前的环境就行 二、配置好后就可以配置文件…

android webview 远程调试

打开远程调试选项 MainActivity super.onCreate(savedInstanceState);// enable Cordova apps to be started in the backgroundBundle extras getIntent().getExtras();if (extras ! null && extras.getBoolean("cdvStartInBackground", false)) {moveT…

前端JS特效第24集:jquery css3实现瀑布流照片墙特效

jquery css3实现瀑布流照片墙特效&#xff0c;先来看看效果&#xff1a; 部分核心的代码如下(全部代码在文章末尾)&#xff1a; <!DOCTYPE html> <html lang"en"> <head> <meta charset"UTF-8" /> <title>jquerycss3实现瀑…

Nginx:负载均衡小专题

运维专题 Nginx&#xff1a;负载均衡小专题 - 文章信息 - Author: 李俊才 (jcLee95) Visit me at CSDN: https://jclee95.blog.csdn.netMy WebSite&#xff1a;http://thispage.tech/Email: 291148484163.com. Shenzhen ChinaAddress of this article:https://blog.csdn.net/…

在Conda环境中高效使用Kubernetes:跨平台容器化实践指南

摘要 Conda 是一个流行的跨平台包和环境管理器&#xff0c;广泛用于Python社区。而 Kubernetes 是一个开源的容器编排系统&#xff0c;用于自动化部署、扩展和管理容器化应用程序。本文将探讨如何在 Conda 环境中使用 Kubernetes&#xff0c;包括设置 Conda 环境、容器化应用程…

【专项刷题】— 位运算

常见类型介绍&#xff1a; & &#xff1a;有 0 就是 0 | &#xff1a;有 1 就是 1 ^ &#xff1a;相同为 0 &#xff0c;相异为 1 或者 无进位相加给定一个数确定它的二进制位的第x个数是0还是1&#xff1a;将一个数的二进制的第x位改成1&#xff1a;将一个数的二进制的第x…