大模型增量预训练新技巧:解决灾难性遗忘

大家好,目前不少开源模型在通用领域具有不错的效果,但由于缺乏领域数据,往往在一些垂直领域中表现不理想,这时就需要增量预训练和微调等方法来提高模型的领域能力。

但在领域数据增量预训练或微调时,很容易出现灾难性遗忘现象,也就是学会了垂直领域知识,但忘记了通用领域知识,之前介绍过增量预训练以及领域大模型训练技巧。

今天给大家带来一篇增量预训练方法-Llama-Pro,对LLMs进行Transformer块扩展后,增量预训练过程中仅对新增块进行训练,有效地进行模型知识注入,并且极大程度地避免灾难性遗忘。

图片

LLaMA Pro: Progressive LLaMA with Block Expansion

LLaMA Pro: Progressive LLaMA with Block Expansion
Paper: https://arxiv.org/abs/2401.02415
Github: https://github.com/TencentARC/LLaMA-Pro

文章目录

    • 技术交流群
    • 用通俗易懂方式讲解系列
    • 块扩展方法
    • 实验细节
    • 讨论分析
    • 写在最后

技术交流群

前沿技术资讯、算法交流、求职内推、算法竞赛、面试交流(校招、社招、实习)等、与 10000+来自港科大、北大、清华、中科院、CMU、腾讯、百度等名校名企开发者互动交流~

我们建了大模型面试与技术交流群, 想要进交流群、需要源码&资料、提升技术的同学,可以直接加微信号:mlc2060。加的时候备注一下:研究方向 +学校/公司+CSDN,即可。然后就可以拉你进群了。

方式①、微信搜索公众号:机器学习社区,后台回复:加群
方式②、添加微信号:mlc2060,备注:技术交流

资料1
在这里插入图片描述

用通俗易懂方式讲解系列

  • 用通俗易懂的方式讲解:自然语言处理初学者指南(附1000页的PPT讲解)
  • 用通俗易懂的方式讲解:1.6万字全面掌握 BERT
  • 用通俗易懂的方式讲解:NLP 这样学习才是正确路线
  • 用通俗易懂的方式讲解:28张图全解深度学习知识!
  • 用通俗易懂的方式讲解:不用再找了,这就是 NLP 方向最全面试题库
  • 用通俗易懂的方式讲解:实体关系抽取入门教程
  • 用通俗易懂的方式讲解:灵魂 20 问帮你彻底搞定Transformer
  • 用通俗易懂的方式讲解:图解 Transformer 架构
  • 用通俗易懂的方式讲解:大模型算法面经指南(附答案)
  • 用通俗易懂的方式讲解:十分钟部署清华 ChatGLM-6B,实测效果超预期
  • 用通俗易懂的方式讲解:内容讲解+代码案例,轻松掌握大模型应用框架 LangChain
  • 用通俗易懂的方式讲解:如何用大语言模型构建一个知识问答系统
  • 用通俗易懂的方式讲解:最全的大模型 RAG 技术概览
  • 用通俗易懂的方式讲解:利用 LangChain 和 Neo4j 向量索引,构建一个RAG应用程序
  • 用通俗易懂的方式讲解:使用 Neo4j 和 LangChain 集成非结构化知识图增强 QA
  • 用通俗易懂的方式讲解:面了 5 家知名企业的NLP算法岗(大模型方向),被考倒了。。。。。
  • 用通俗易懂的方式讲解:NLP 算法实习岗,对我后续找工作太重要了!。
  • 用通俗易懂的方式讲解:理想汽车大模型算法工程师面试,被问的瑟瑟发抖。。。。
  • 用通俗易懂的方式讲解:基于 Langchain-Chatchat,我搭建了一个本地知识库问答系统
  • 面试了字节大模型算法岗(实习),快被问哭了。。。。

块扩展方法

块扩展,顾名思义,就是在原始模型中每个Transformer块或者某几个Transformer块后增加一个Transformer块,但为了保持扩展后的模型输出保持不变,需要增加的块为恒等块(输入输出相同),如下图所示。

图片

在构建恒等块过程中,主要是将多头注意力层和FFN层中的最后一个线性层(Linear)权重置为0变成Zero-Linear,即可保持经过该块的输入输出一致。

PS:论文附录A中写了大段的推导公式来证明,在此不做过多介绍。

块的增加方式是,对原始模型的 个Transformer块分成 组,每组中包含 个Transformer块,对于每组后添加 个恒等块。代码实现具体如下:

model = AutoModelForCausalLM.from_pretrained(model_path, torch_dtype=torch.float16)
ckpt = model.state_dict()# original_layers是模型原始层数,layers是模型最后达到层数
split = int(original_layers / (layers - original_layers))layer_cnt = 0output = {}
for i in range(original_layers):for k in ckpt:if ('layers.' + str(i) + '.') in k:output[k.replace(('layers.' + str(i) + '.'), ('layers.' + str(layer_cnt) + '.'))] = ckpt[k]layer_cnt += 1if (i+1) % split == 0:for k in ckpt:if ('layers.' + str(i) + '.') in k:if 'down_proj' in k or 'o_proj' in k:output[k.replace(('layers.' + str(i) + '.'), ('layers.' + str(layer_cnt) + '.'))] = torch.zeros_like(ckpt[k])else:output[k.replace(('layers.' + str(i) + '.'), ('layers.' + str(layer_cnt) + '.'))] = ckpt[k]layer_cnt += 1assert layer_cnt==layers
for k in ckpt:if not 'layers' in k:output[k] = ckpt[k]torch.save(output, output_path)

实验细节

数据由代码和数学组成,其中代码数据采用The-Stack-Dedup数据集中Python语言部分共22B Token,数学数据采用Proof-Pile-2数据集中AlgebraicStack、OpenWebMath和ArXiv部分共55B,详细如下表所示。

图片

数据分布

基础模型为LLaMA2-7B模型,通过块扩展方法将32层模型扩展到40层,其中 、 、 ,每个组从4个Transformer块扩展到5个Transformer块。

对于代码和数学数据进行增量预训练,批量大小为1024,序列最大长度为4096,预热比率为6%,学习率为2e-4,采用余弦学习率调度器,BF16混合精度训练,权重衰减为0.1。使用16个NVIDIA H800 GPU进行了15900个步骤的训练,大约耗费2830个GPU/小时。

在ARC、HellaSwag、MMLU、TruthfulQA、Winogrande、GSM8K、GSM8K-PoT、HumanEval、MBPP等多个评测数据集中进行评测,可以看出,在保持通用任务能力不下降的情况下,数学和代码能力较原始LLaMA2-7B模型有很大提升。

图片

图片

讨论分析

对比块扩展方法与正常训练和Lora方法之间的区别,采用TRACE基准利用总体性能(OP)和逆向转移(BWT)指标进行评估。,如下表所示,块扩展方法整体提升较大。

图片

对比块个数对块扩展方法的影响,进行了不同个数块的实验,并且对比了MoE的方法,训练损失如下,MoE方法的损失下降程度与添加四个块相当。

图片

在代码和法律(16.7B)领域数据下进行增量预训练,在通用任务以及领域任务上比较不同个数块之间的差异,同时比较扩展块全部添加到模型底部或顶部之间的差别,如下所示。可以发现块个数为8时效果最佳,并且不能直接将扩展块全部堆积在头部或尾部,需要分开插入。

图片

写在最后

该方法主要通过增加恒定块扩展模型层数,使模型在增量训练过程中仅训练新增层、冻结原始层,保持模型原有能力,防止模型出现灾难性遗忘现象。

但有两点存疑:

  • 目前来说mistral要好于llama,为啥不用mistral进行实验

  • 不用恒定块,性能会差多少

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.mzph.cn/news/666106.shtml

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

git 如何修改仓库地址

问题背景:组内更换大部门之后,代码仓的地址也迁移了,所以原来的git仓库地址失效了。 虽然重新建一个新的文件夹,再把每个项目都git clone一遍也可以。但是有点繁琐,而且有的项目本地还有已经开发一半的代码&#xff0c…

遗失的源代码之回归之路的探索与实践

背景 最近比较突然被安排接手一个项目,该项目的情况如下 原生和RN结合的混合开发模式组件化开发,有很多基础组件以及业务组件但是在梳理项目依赖时发现了个别组件源码不全的情况,于是写了个cli用于对比两个版本产物文件,生成差异结果以便于快速进行源码找回恢复。 结果如下…

【lesson9】高并发内存池Page Cache层释放内存的实现

文章目录 Page Cache层释放内存的流程Page Cache层释放内存的实现 Page Cache层释放内存的流程 如果central cache释放回一个span,则依次寻找span的前后page id的没有在使用的空闲span,看是否可以合并,如果合并继续向前寻找。这样就可以将切…

05、全文检索 -- Solr -- Solr 全文检索之图形界面的文档管理(文档的添加、删除,如何通过关键字等参数查询文档)

目录 Solr 全文检索之文档管理添加文档使用 JSON 添加文档:使用 XML 添加文档: 删除文档使用 JSON 删除文档:使用 XML 删除文档: 查询文档查询文档的详细参数fq(Filter Query):过滤sort:排序sta…

[Linux 进程(六)] 写时拷贝 - 进程终止

文章目录 1、写时拷贝2、进程终止2.1 进程退出场景2.1.1 退出码2.1.2 错误码错误码 vs 退出码2.1.3 代码异常终止引入 2.2 进程常见退出方法2.2.1 exit函数2.2.2 _exit函数 本片我们主要来讲进程控制,讲之前我们先把写时拷贝理清,然后再开始讲进程控制。…

[office] 在Excel2010中设定某些单元格数据不参与排序的方法介绍 #其他#知识分享#笔记

在Excel2010中设定某些单元格数据不参与排序的方法介绍 在Excel中排序,相信大家都会了,直接将一组数据按照从小到大或者从大到小进行排序,但是,现在要求我们规定其中几组数据不进行排序,只排序其余的部分。又该如何操作…

ruoyi(若依)(el-menu也可参考)菜单栏过长显示省略号才显示气泡

一、背景 若依前后端分离的版本,新版本中优化了菜单名称过长悬停显示标题,但是是悬浮所有长度大于5的标题。可以查看提交记录:https://gitee.com/y_project/RuoYi-Cloud/commit/99932d91c0144da9f34f5bb05683cc0b86303217 但是我希望是只悬浮…

VC++中使用OpenCV绘制直线、矩形、圆和文字

VC中使用OpenCV绘制直线、矩形、圆和文字 在VC中使用OpenCV绘制直线、矩形、圆和文字非常简单,分别使用OpenCV中的line、rectangle、circle、putText这四个函数即可。具体可以参考OpenCV官方文档:https://docs.opencv.org/4.x/index.html 下面的代码展…

nodejs express中使用连接池或者MySQL链接数据库出现Cannot read property ‘query‘ of undefined报错

1.如果你已经排除了数据库的启动状态原因和本地服务是否启动的原因 2.不妨看看你是否没有排查其他的数据库,我就是一直在排查第一个主数据库,却忘了我还连接了第二个数据库,就是第二个数据库的原因,出现这个错误。 3.我们可以通…

「 CISSP学习笔记 」08. 安全运营

该知识领域涉及如下考点,具体内容分布于如下各个子章节: 理解并遵守调查执行记录和监控活动执行配置管理 (CM)(例如,预配、基线、自动化)应用基本的安全操作概念应用资源保护执行事故管理执行和维护检测和预防措施实施…

我们使用的IPv4耗尽(We‘re running out of IPv4)

IPv4(Internet Protocol version 4)是互联网上使用最广泛的网络层协议之一,于1981年在 RFC 791 中发布,它定义了 32 位的IP地址结构和基本的协议操作。 由于 IPv4 使用 32 位的地址,因此只有四十亿(4,294,967,296,2^32)个地址。 这就导致随着地址不断被分配,IPv4 地…

未来电话呼叫技术的社会影响与发展趋势----云微呼

未来电话呼叫技术将以更为智能化、便捷化和个性化为主要发展趋势,其所带来的社会影响也将是多层面的。以下将探讨未来电话呼叫技术可能的发展趋势以及对社会的影响: 智能化助力生活便捷: 未来电话呼叫技术将更加智能化,通过人工智…

uniapp 组件封装

1. uniapp 组件封装时间戳格式化为星期 1.1. components/m-week.vue <template><text>{{week}}</text> </template> <script>export default {props: {time: String},mounted(e) {this.week this.getWeek(Number(this.time))},data() {return …

车载测试Vector工具——基于DoIP的ECU/车辆的连接故障排除

车载测试Vector工具——基于DoIP的ECU/车辆的连接故障排除 我是穿拖鞋的汉子,魔都中坚持长期主义的汽车电子工程师(Wechat:gongkenan2013)。 老规矩,分享一段喜欢的文字,避免自己成为高知识低文化的工程师: 屏蔽力是信息过载时代一个人的特殊竞争力,任何消耗你的人和…

【考研408】计算机网络笔记

文章目录 计算机网络体系结构计算机网络概述计算机网络的组成计算机网络的功能计算机网络的分类计算机网络的性能指标课后习题 计算机网络体系结构与参考模型计算机网络协议、接口、服务的概念ISO/OSI参考模型和TCP/IP模型课后习题 物理层通信基础基本概念奈奎斯特定理与香农定…

PyCharm / DataSpell 导入WSL2 解析器,实现GPU加速

PyCharm / DataSpell 导入WSL2 解析器的实现 Windows的解析器不好么&#xff1f;设置WSL2和实现GPU加速为PyCharm / DataSpell 设置WSL解析器设置Interpreter Windows的解析器不好么&#xff1f; Windows上的解析器的确很方便&#xff0c;也省去了我们很多的麻烦。但是WSL2的解…

cesium-水平测距

cesium测量两点间的距离 <template><div id"cesiumContainer" style"height: 100vh;"></div><div id"toolbar" style"position: fixed;top:20px;left:220px;"><el-breadcrumb><el-breadcrumb-item&…

React16源码: React中处理hydrate的核心流程源码实现

hydrate 1 &#xff09;概述 hydrate 在react当中不算特别重要, 但是很多时候会用到的一个API这个 API 它主要作用就是在进入第一次渲染的时候&#xff0c;如果本身 dom 树上面已经有一个dom结构存在是否可以去利用这一部分已经存在的dom&#xff0c;然后去避免掉在第一次渲染…

[晓理紫]CCF系列会议截稿时间订阅

关注{晓理紫|小李子}&#xff0c;每日更新CCF系列会议信息&#xff0c;如感兴趣&#xff0c;请转发给有需要的同学&#xff0c;谢谢支持&#xff01;&#xff01; 如果你感觉对你有所帮助&#xff0c;请关注我&#xff0c;每日准时为你推送最新会议信息。 SAC (CCF C) Select…