LLM 可以从简单数据中学习吗?

在 10 月份的一次周会结束后,我提到 SFT 训练后的 Loss 曲线呈现阶梯状,至于为什么,并没有人有合理的解释,加上当时的重心是提升次日留存率,Loss 曲线呈现阶梯状与次日留存率的关系还太远,即使有问题,起码次日留存率是逐渐在提升。

幸运的是,在一次逛论坛时发现了一篇博客 Can LLMs learn from a single example?,也是我这篇博客的标题名称由来,在其基础上结合了公司业务的一些现状和我个人的思考。
图:fastchat 框架在 VTuber 数据集上训练全 BOT 回复,在 3 epoch 上的 Loss 曲线。

可以清楚地看到每个 epoch 的终点——loss 突然向下跳。我们以前也见过类似的损失曲线,但都是由于错误造成的。例如,在评估验证集时,很容易意外地让模型继续学习——这样在验证之后,模型就会突然变得更好。因此,开始寻找训练过程中的错误。

发现该“问题”的时间,恰好与单句重复问题同一时期(9 月份),于是推测是不是 context length 从 2k 变到 4k 所致,以及 Transformers 库和 RoPE 位置编码的问题。在开始逐步修改代码的同时,在 Alignment Lab AI Discord 上看到他人反馈的类似的奇怪 loss 曲线,并且每个回复的人也都在使用 Trainer,这在当时加深了我认为 Transformers 库存在问题的猜测,甚至我还去询问了同事李老师是否有同样的问题,以及 load model 时的 warning。

9 月中旬,老板要求我们加上验证 loss,于是出现了如下图所示的 eval loss 曲线。

eval loss 曲线

该问题在 Discord 上讨论得越来越激烈,也有人反映在不使用 Trainer 的情况下,也会出现阶梯状的 loss 曲线。

查阅资料,看到一种假设:即这些训练曲线实际上显示了过拟合。起初,这似乎是不可能的。这意味着模型正在学习识别来自一个或两个示例的输入。如果回过头来看我们展示的第一条曲线,就会发现 loss 在第二和第三个 epoch 期间,它根本没有学习到任何新东西。因此,除了在第一个 epoch 开始时的初始学习(学习了多轮对话的对齐方式)外,几乎所有表面上的学习都是(根据这一理论)对训练集的记忆。此外,对于每个问题,它只能获得极少量的信号:它对答案的猜测与真实标签的比较。

资料提到了一项实验:使用以下学习率计划,对 Kaggle 模型进行了两个 epoch 的训练:

如今,这种 schedule 并不常见,但莱斯利-史密斯(Leslie Smith)在 2015 年发表的论文《训练神经网络的循环学习率》(Cyclical Learning Rates for Training Neural Networks)中讨论了这种方法,并取得了很大成功。

下面就是我们因此而看到的看起来很疯狂的训练和验证损失曲线:

到目前为止,我们唯一能完全解释这种情况的方法就是假设是正确的:模型正在快速学习识别实例,即使只看到一次。让我们依次查看 loss 曲线的各个部分:

  • 从第一个 epoch 来看,这是一条非常标准的 loss 曲线。在第一个 10% 的 epoch 中,学习率开始升温,一旦达到温度后,训练和验证 loss 就会迅速降低;然后按照余弦曲线逐渐下降,两者都会放缓。
  • 第二个 epoch 才是我们感兴趣的地方。我们并没有在 epoch 开始时重新 shuffle 数据集,因此第二个 epoch 的第一批数据是学习率仍在预热的时候。这就是为什么在我们展示的第一条 loss 曲线中,没有看到像从 epoch 2 到 epoch 3 那样的直接阶跃变化——这些批次只有在学习率较低时才会出现,所以它学不到太多东西。在 epoch 2 开始 10% 时,训练 loss 急剧下降,因为在第一个 epoch 中看到这些批次时,学习率很高,模型已经知道了它们的样子,因此它可以非常自信地猜出正确答案。但在此期间,验证 loss 会受到影响。这是因为虽然模型变得非常自信,但实际上它的预测能力并没有提高。它只是记住了数据集(早期没有清洗掉训练数据中的保底回复以及一些涉及到公司信息的关键词,模型会输出这些内容,甚至会将原样的超时保底回复输出),但并没有提高泛化能力。过于自信的预测会导致验证损失变大,因为损失函数会对更自信的错误进行更高的惩罚。
  • 曲线的末端是特别有趣的地方。训练 loss 开始变得越来越大,而这是绝对不应该发生的!事实上,我还从未在使用合理的学习率时遇到过这种情况。根据记忆假说,这完全说得通:这些批次是模型在学习率再次下降时看到的,因此它无法有效地记忆这些批次。但模型仍然过于自信,因为它刚刚得到了一大堆几乎完全正确的批次,还没有适应现在看到的批次没有机会学得那么好这一事实。它会逐渐重新校准到一个更合理的置信度水平,但这需要一段时间,因为学习率越来越低。在重新校准的过程中,验证 loss 会再次下降。
    记忆假说很有可能是真的。按照先前小模型时代的训练经验,我们往往需要大量的数据来让模型学习输入分布和模式。使用随机梯度下降法(SGD)导航的损失面太崎岖,无法一下子跳得很远。不过,有些东西可以让损失面变得更平滑,比如使用残差连接,如经典论文《可视化神经网络的损失景观》(Li et al,2018)中所示。

很可能的情况是,预训练的大语言模型在接近最小损失的区域具有极其平滑的损失面,而开源社区所做的大量微调工作都是在这一区域。这是基于最初开发微调通用语言模型的基本前提。简单来说,我们的训练数据并不能够让模型跳出该平滑的损失面,只是让模型记住了 BOT 的回复、以及通过几个数据就让模型学到了说话风格。

如果以上猜测都属实,这不是什么糟糕的事情,拥有一个学习速度非常快、且能够举一反三的模型是一件非常棒的事情。同时,这也佐证了《LIMA:Less Is More for Alignment》、《A Few More Examples May Be Worth Billions of Parameters》、《Maybe only 0.5% Data is Needed: A Preliminary Exploration of Low Training Data Instruction Tuning》等一系列证明少量优质、多样性丰富的指令数据就能让模型有很强指令遵循的论文的有效性。以及最近出现的一系列关于指令数据集子集选择的论文,例如《Smaller Language Models are capable of selecting Instruction-Tuning Training Data for Larger Language Models》、《LESS: Selecting Influential Data for Targeted Instruction Tuning》。这些论文提到经过他们方法挑选出来的子集,在该子集上训练出来的模型比在全量数据集上微调的模型效果要更好。

我统计了从 7 月 到 11 月份所训练模型的 Loss 曲线是否呈现阶梯状,正常表示平滑下降,不正常表示阶梯下降(在每个 epoch 交界处骤降)。早期训练的模型的 loss 曲线都是正常,可惜的是早期的训练数据被删了,无法准确地判断是数据质量的因素,还是基底模型的因素。

早期训练遵循多阶段的方式,即先在 continual pretrain 得到的 base 模型上用 GPT4all 数据集以及一个闲聊场景的对话集进行训练,然后再用高质量的对话数据集再次微调。以此得到的模型表现平常,虽不会犯错,但也没有新意,不能提升平均对话轮数,因此后续我们不再进行 base model -> GPT4all + 闲聊数据集 -> 高质量对话数据集的多段式 SFT,而是直接在 base model 上用高质量对话数据集进行 SFT。在这之后训练的模型的 loss 曲线都是阶梯状,按照记忆假说和先前分析的内容来看,llama2、vicuna-13b-v1.5 等模型的对话、闲聊能力得到了提升(也有可能是 GPT4all 数据集让模型闲聊能力下降),在我们所认为的“高质量”数据集上进行训练,模型只是记住了对话内容,而非真正意义上地学习(训练数据集对于模型来说非常简单)。

PS:我没有否认和贬低这种方式,当模型的“脑容量”(记忆力)大到能够将我们提供的优质回复都记住,并且在合适的场景输出,这在业务上完全没有问题。在复读机问题上,将高质量数据集从 4k 扩充至 26k 后,的确减少了该问题的频次。

一个猜想:当模型的学习速度如此之快时,灾难性遗忘问题可能会突然变得明显得多。例如,如果一个模型看到了十个非常常见关系的示例,然后又看到了一个不太常见的反例,那么它很可能会记住这个反例,而不仅仅是稍微降低它对原来十个示例的记忆权重。

在 6 月下旬时,老板询问我为什么模型的效果不太好时,我想了想说是灾难性遗忘(找的理由)。现在看来,似乎的确大概率是这个原因。沿着 base model -> GPT4all + 闲聊数据集 -> 高质量数据集训练的路径,希望模型能够不断地进化,但实际上 base model 原先的知识和 GPT4all 数据集中的内容都遗忘得差不多。因此,不要多阶段 SFT,而是将每个阶段的训练数据进行混合,可以减少灾难性遗忘的影响,这或许就是后来尝试数据混合方案后,能够提升次日留存率的一个原因?

此外,我们还需要审视,对于模型来说,什么是高质量的数据集。

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.mzph.cn/news/835270.shtml

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

torch.searchsorted

torch.searchsorted 官方文档链接:torch.searchsorted — PyTorch 2.3 documentation 该函数用于在已排序的序列中查找要插入的值的位置,以保持序列的顺序, torch.searchsorted(sorted_sequence, values, *, out_int32False, rightFalse, s…

Python - 金三银四心路历程 之 数据结构与算法 刷题

目录 一.引言 二.心路历程 三.刷题经历 四.刷题历程 五.总结 一.引言 <夜深人静写算法> 是 23 年 12 月底博主打算跳槽时开始做刷题准备做的专栏&#xff0c;前后准备了大约一个月&#xff0c;刷题完毕后简单准备了项目和简历后就开始加入找工作大军了&#xff0c;最…

【机器学习】逻辑化讲清PCA主成分分析

碎碎念&#xff1a;小编去年数学建模比赛的时候真的理解不了主成分分析中的“主成分”的概念&#xff01;&#xff01;但是&#xff0c;时隔两年&#xff0c;在机器学习领域我又行了&#xff0c;终于搞明白了&#xff01;且看正文&#xff01;再分享一个今天听到的播客中非常触…

Web3 Tools - Base58

Base58编码 Base58编码是一种用于表示数字的非常见的编码方法。它通常用于加密货币领域&#xff0c;例如比特币和其他加密货币的地址表示。 什么是Base58编码&#xff1f; Base58编码是一种将数字转换为人类可读形式的编码方法。与常见的Base64编码不同&#xff0c;Base58编码…

Double 4 VR智能互动教学系统在轨道交通客运服务课堂上的应用

一、模拟真实场景&#xff0c;提高教学效果 Double 4 VR智能互动教学系统能够模拟真实的轨道交通客运场景&#xff0c;让学生身临其境地感受客运服务的全过程。通过虚拟现实技术&#xff0c;学生可以在系统内扮演不同的角色&#xff0c;如列车员、站务员、乘客等&#xff0c;亲…

JCR一区 | Matlab实现1D-2D-GASF-CNN-GRU-MATT的多通道输入数据分类预测

JCR一区 | Matlab实现1D-2D-GASF-CNN-GRU-MATT的多通道输入数据分类预测 目录 JCR一区 | Matlab实现1D-2D-GASF-CNN-GRU-MATT的多通道输入数据分类预测分类效果基本介绍程序设计参考资料 分类效果 基本介绍 基本介绍 Matlab实现1D-2D-GASF-CNN-GRU-MATT的多通道输入数据分类预…

android 蓝牙技术 学习记录 二

android 蓝牙连接 关键类 BluetoothDevice--------------------蓝牙设备 BluetoothGattCallback--------------连接回调 BluetoothGatt----------------------gatt 1. public BluetoothGatt connectGatt(Context context, boolean autoConnect, BluetoothGattCallback call…

Ascent DMS AE电源说明书和设备连接调试教程

Ascent DMS AE电源说明书和设备连接调试教程

世上最全前端开发教程(HTMLCSS)

HTML介绍 HTML&#xff0c;全称为HyperText Markup Language&#xff0c;即超文本标记语言&#xff0c;是一种用来创建网页的标准标记语言。HTML使用一系列的标签&#xff08;Tags&#xff09;来定义网页的不同部分和它们的行为&#xff0c;比如段落、链接、图片等。 CSS介绍 …

《这就是ChatGPT》读书笔记

书名&#xff1a;这就是ChatGPT 作者&#xff1a;[美] 斯蒂芬沃尔弗拉姆&#xff08;Stephen Wolfram&#xff09; ChatGPT在做什么&#xff1f; ChatGPT可以生成类似于人类书写的文本&#xff0c;它基本任务是弄清楚如何针对它得到的任何文本产生“合理的延续”。当ChatGPT写…

如何对Java对象数组多个属性值进行汇总

1、问题 最近在做报表统计相关的任务&#xff0c;中间涉及到很多的地方&#xff0c;需要同时使用SQL进行数据汇总或者在内存进行数据汇总&#xff0c;在内存汇总的时候&#xff0c;遇到一个场景&#xff0c;就是对Java对象数组多个属性值进行汇总。 最后认为方法三使用反射进…

数据库基础语法二

一、数据库 1、登陆数据库 2、创建数据库zoo 3、修改数据库zoo字符集为gbk 4、选择当前数据库为zoo 5、查看创建数据库zoo信息 6、删除数据库zoo mysql -uroot -p #登陆数据库 create database zoo; #创建数据库zoo alter database zoo character set gbk collate gbk_…

Android 12.0 TvSettings系统设置wifi连接密码框点击Enter键失去焦点

1.前言 在12.0的系统box产品开发中,在TvSettings中,在wifi连接的时候,在用遥控器输入wifi密码框的时候,会发现在按遥控器Enter键的时候, 发现EditText焦点失去了,导致输入法消失了,为了解决这个问题就需要拦截Enter键保证正常输入wifi密码,接下来就来实现这个功能 如图…

实用的Chrome命令 帮你打开Chrome浏览器的隐藏功能

前言 Chrome作为主力浏览器&#xff0c;支持相当丰富的第三方扩展&#xff0c;其实浏览器本身也内置了大量实用的命令。许多实用的功能并没有直接显示在Chrome的菜单上。在这篇文章中&#xff0c;我们将介绍几个实用的chrome:// commands。 通过下面整理的 Chrome 命令&#x…

uboot Ethernet初始化

目录 initr_net eth_initialize eth_common_init phy_init genphy_init uclass_first_device_check eth_write_hwaddr initr_net

Liunx-Tcp Wrapper-访问控制工具

目录 主要功能 工作原理 配置示例 允许所有来自特定 IP 范围的访问 拒绝所有其他访问 允许特定服务来自特定域名的访问 拒绝特定 IP 地址的访问 注意事项 应用范围 查看某一个程序是否支持 TCP wrapper Telnet 用途 1. 远程登录 2. 网络诊断 3. 测试网络服务 安…

LVS(Linux Virtual Server)知识点详解

引言 LVS&#xff08;Linux Virtual Server&#xff09;是一种虚拟服务器技术&#xff0c;它主要用于实现多台服务器的负载均衡。 LVS的核心作用是通过负载均衡技术将客户端的请求分发到不同的服务器上进行处理&#xff0c;以此来提高整体的服务能力和资源的利用效率。在构建高…

什么是Unreal Engine游戏引擎?它有什么优势?

大家好&#xff0c;我是咕噜土豆&#xff0c;很高兴又和大家见面了。在游戏开发行业中&#xff0c;选择合适的游戏引擎是非常重要的。其中&#xff0c;Unreal Engine作为一款功能强大的游戏引擎&#xff0c;在业界非常受欢迎。今天我带大家简单的了解一下。 什么是Unreal Engi…

基于STM32移植lvgl(V8.2)(SPI接口的LCD)

目录 概述 1 认识LVGL 1.1 LVGL官网 1.2 LVGL库文件下载 2 认识SPI接口型LCD 2.1 PIN引脚定义 2.2 MCU IO与LCD PIN对应关系 3 实现LCD驱动 3.1 使用STM32Cube配置Project 3.2 STM32Cube生成工程 4 移植LVGL 4.1 准备移植文件 4.2 添加lvgl库文件到项目 4.2.1 src下…

【Pytorch】6.torch.nn.functional.conv2d的使用

阅读之前应该先了解基础的CNN网络的逻辑 conv2d的作用 是PyTorch中用于执行二维卷积操作的函数。它的作用是对输入数据进行二维卷积操作&#xff0c;通常用于图像处理和深度学习中的卷积神经网络&#xff08;CNN&#xff09;模型。 conv2d的使用 我们先查看一下官方文档 inpu…