TabR:检索增强能否让深度学习在表格数据上超过梯度增强模型?

这是一篇7月新发布的论文,他提出了使用自然语言处理的检索增强Retrieval Augmented技术,目的是让深度学习在表格数据上超过梯度增强模型。

检索增强一直是NLP中研究的一个方向,但是引入了检索增强的表格深度学习模型在当前实现与非基于检索的模型相比几乎没有改进。所以论文作者提出了一个新的TabR模型,模型通过增加一个类似注意力的检索组件来改进现有模型。据说,这种注意力机制的细节可以显著提高表格数据任务的性能。TabR模型在表格数据上的平均性能优于其他DL模型,在几个数据集上设置了新的标准,在某些情况下甚至超过了GBDT模型,特别是在通常被视为GBDT友好的数据集上。

TabR

表格数据集通常被表示为特征和标签对{(xi, yi)},其中xi和yi分别是第i个对象的特征和标签。一般有三种类型的主要任务:二元分类、多类分类和回归。

对于表格数据我们会将数据集分为训练部分、验证部分和测试部分,模型对“输入”或“目标”对象进行预测。当使用检索技术时,检索是在一组“上下文候选”或“候选”中完成的,被检索的对象称为“上下文对象”或简称为“上下文”。同一组候选对象用于所有输入对象。

论文的实验设置涉及调优和评估协议,其中需要超参数调优和基于验证集性能的早期停止。然后在15个随机种子的平均测试集上测试最佳超参数,并在算法比较中考虑标准偏差。

论文作者的目标是将检索功能集成到传统的前馈网络中。该过程包括通过编码器传递目标对象及其上下文候选者,然后检索组件会对目标对象进行的表示,最后预测器进行预测。

编码器和预测器模块很简单简单,因为它们不是工作的重点。检索模块对目标对象的表示以及候选对象的表示和标签进行操作。这个模块可以看作是注意力机制的一般化版本。

这个过程包括几个步骤:

  • 如果编码器包含至少一个块,则将表示进行规范化;
  • 根据与目标对象的相似性定义上下文对象;
  • 基于softmax函数对上下文对象的相似性分配权重;
  • 定义上下文对象的值;
  • 使用值和权重输出加权聚合。

上下文大小设置为一个较大的值96,softmax函数会自动选择有效的上下文大小。

检索模块是最重要的部分

作者探讨了检索模块的不同实现,特别是相似度模块和值模块。并且说明了是通过一下几个步骤得到最终的模型。

1、作者评估了传统注意力的相似性和值模块,发现该配置与多层感知器(MLP)相似,因此不能证明使用检索组件是合理的。

2、然后他们将上下文标签添加到值模块中,但发现这并没有改进,这表明传统注意力的相似性模块可能是瓶颈。

3、为了改进相似度模块,作者删除了查询的概念,并用L2距离替换点积。这种调整使得几个数据集上性能的显著跃升。

4、值模块也进行改进,灵感来自最近提出的DNNR(用于回归问题的kNN算法的广义版本)。新的值模块带来了进一步的性能改进。

5、最后,作者创建模型TabR。在相似性模块中省略缩放项,不包括目标对象在其自身的上下文中(使用交叉注意),平均而言会得到更好的结果。

生成的TabR模型为基于检索的表格深度学习问题提供了一种健壮的方法。

作者也强调了TabR模型的两个主要局限性:

与所有检索增强模型一样,从应用程序的角度来看,使用真实的训练对象进行预测可能会带来一些问题,例如隐私和道德问题。

TabR的检索组件虽然比以前的工作更有效,但会产生明显的开销。所以它可能无法有效地扩展以处理真正的大型数据集。

实验结果

作者将TabR与现有的检索增强解决方案和最先进的参数模型进行比较。除了完全配置的TabR,他们还使用了一个简化版本,TabR- s,它不使用特征嵌入,只有一个线性编码器和一个块预测器。

与全参数深度学习模型的比较表明,TabR在几个数据集上优于大多数模型,除了MI数据集,在其他数据集也很有竞争力。在许多数据集上,它比多层感知器(MLP)提供了显著的提升。

与GBDT模型相比,调整后的TabR在几个数据集上也有明显的改进,并且在其他数据集上保持竞争力(除了MI数据集),并且TabR的平均表现也优于GBDT模型。

总之,TabR将自己确立为表格数据问题的强大深度学习解决方案,展示了强大的平均性能,并在几个数据集上设置了新的基准。它的基于检索的方法具有良好的潜力,并且在某些数据集上可以明显优于梯度增强的决策树。

一些研究

1、冻结上下文以更快地训练TabR

在TabR的原始实现中,由于需要对所有候选对象进行编码并计算每个训练批次的相似度,因此在大型数据集上的训练可能很慢。作者提到在完整的“Weather prediction”数据集上训练一个TabR需要18个多小时,该数据集有300多万个对象。

作者注意到在训练过程中,平均训练对象的上下文(即,根据相似度模块S,前m个候选对象及其分布)趋于稳定,这为优化提供了机会。在一定数量的epoch之后,他们提出了一个“上下文冻结”,即最后一次计算所有训练对象的最新上下文,然后在其余的训练中重用。

这种简单的技术可以加速TabR的训练,并且不会在指标上造成重大损失。在上面提到的完整的“Weather prediction”数据集上,它使速度提高了近7倍(将训练时间从18小时9分钟减少到3小时15分钟),同时仍然保持有竞争力的均方根误差(RMSE)值。

2、用新的训练数据更新TabR不需要再训练(初步探索)

在现实世界的场景中,在机器学习模型已经训练完之后,通常会收到新的、看不见的训练数据。作者测试了TabR在不需要再训练的情况下合并新数据的能力,方法是将新数据添加到候选检索集中。

他们使用完整的“Weather prediction”数据集进行了这个测试。结果表明在线更新可以有效地将新数据整合到训练好的TabR模型中。这种方法可以通过在数据子集上训练模型并从完整数据集中检索模型来将TabR扩展到更大的数据集。

3、使用检索组件增强XGBoost

作者试图通过结合类似于TabR中的检索组件来提高XGBoost的性能。这种方法涉及在原始特征空间中找到与给定输入对象最接近的96个训练对象(匹配TabR的上下文大小)。然后对这些最近邻的特征和标签进行平均,将标签按原样用于回归任务,并将其转换为用于分类任务的单一编码。

将这些平均数据与目标对象的特征和标签连接起来,形成XGBoost的新输入向量。但是该策略并没有显著提高XGBoost的性能。试图改变邻居的数量也没有产生任何显著的改善。

总结

深度学习模型在表格类数据上一直没有超越梯度增强模型,TabR还在这个方向继续努力。

如果你对他感兴趣,一下是论文和源代码:

https://avoid.overfit.cn/post/9e8cc5f506af4b368516876e108a62c7

作者:Andrew Lukyanenko

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.mzph.cn/news/19736.shtml

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

MySQL的使用——【初识MySQL】第二节

MySQL的使用——【初识MySQL】第二节 文章目录 MySQL环境变量的配置(如使用Navicat可忽略)使用命令行连接MySQL(如使用Navicat可忽略)步骤注意 NavicatNavicat的下载Navicat的使用连接MySQL新建表 总结总结 MySQL环境变量的配置&a…

SpringCloudAlibaba之Nacos服务的发现与注册中心(二)配置

在nacos中的雪崩保护 和阈值&#xff08;0~1&#xff09;与spring.cloud.nacos.discovery.ephemeral这个配置有关 ephemeral为false&#xff0c;永久实例 epheme为true&#xff0c;临时实例&#xff08;默认&#xff09; 健康实例数/总实例数<保护阈值的时候&#xff0c;…

【秋招】算法岗的八股文之机器学习

目录 机器学习特征工程常见的计算模型总览线性回归模型与逻辑回归模型线性回归模型逻辑回归模型区别 朴素贝叶斯分类器模型 (Naive Bayes)决策树模型随机森林模型支持向量机模型 (Support Vector Machine)K近邻模型神经网络模型卷积神经网络&#xff08;CNN&#xff09;循环神经…

Redis中缓存穿透、击穿、雪崩以及解决方案

Redis中缓存穿透、击穿、雪崩以及解决方案 Redis作为一个高效的内存数据库&#xff0c;提供了缓存能力使得我们能够快速访问数据。然而&#xff0c;在使用Redis作为缓存时&#xff0c;我们可能会面临缓存穿透、缓存击穿和缓存雪崩的问题。接下来&#xff0c;我将详细解释这些现…

【雕爷学编程】MicroPython动手做(28)——物联网之Yeelight 3

知识点&#xff1a;什么是掌控板&#xff1f; 掌控板是一块普及STEAM创客教育、人工智能教育、机器人编程教育的开源智能硬件。它集成ESP-32高性能双核芯片&#xff0c;支持WiFi和蓝牙双模通信&#xff0c;可作为物联网节点&#xff0c;实现物联网应用。同时掌控板上集成了OLED…

Stable Diffusion:网页版 体验 / AI 绘图

一、官网地址 Stable Diffusion Online 二、Stable Diffusion AI 能做什么 Stable Diffusion AI绘图是一种基于Stable Diffusion模型的生成式AI技术&#xff0c;能够生成各种类型的图像&#xff0c;包括数字艺术、照片增强和图像修复等。以下是一些可能的应用&#xff1a; …

AX7A200教程(8): HDMI输入和输出显示1080p视频

文章目录 本章节主要将hdmi输入的1080p视频通过ddr3缓存&#xff0c;然后通过hdmi输出口输出到显示屏上显示 一&#xff0c; 突发读写命令 设置读写突发长度为64 //parameter defineparameter WRITE_LENGTH 64;parameter READ_LENGTH 64;parameter IDLE 3d0; …

C#实现结构体与字节流的相互转化

unity项目中&#xff0c;涉及到与C的相互通信&#xff0c;而通信接口为C封好的动态库。所以&#xff0c;传输信息时&#xff0c;需要向C端发送字节流信息。 对此&#xff0c;需将结构体数据转为字节流&#xff0c;其代码如下&#xff1a; public static byte[] StructToBytes(…

如何用DHTMLX组件为Web应用创建甘特图?(二)

dhtmlxGantt是用于跨浏览器和跨平台应用程序的功能齐全的Gantt图表。可满足项目管理应用程序的所有需求&#xff0c;是最完善的甘特图图表库。甘特图仍然是项目管理应用程序中最需要的工具之一&#xff0c;DHTMLX Gantt组件提供了能提升研发甘特图功能所需的重要工具。 在这篇…

Kotlin基础(十):函数进阶

前言 本文主要讲解kotlin函数&#xff0c;之前系列文章中提到过函数&#xff0c;本文是kotlin函数的进阶内容。 Kotlin文章列表 Kotlin文章列表: 点击此处跳转查看 目录 1.1 函数基本用法 Kotlin 是一种现代的静态类型编程语言&#xff0c;它在函数的定义和使用上有一些特点…

什么是互斥锁?怎么运用互斥锁?

1、什么是互斥锁&#xff1f; 互斥锁是&#xff08;Mutex&#xff09;是一种用于多线程编程的同步原语&#xff0c;用于确保在多个线程访问共享资源时的互斥性。 在多线程环境中&#xff0c;当多个线程同时访问共享资源时&#xff0c;可能会导致数据的竞争和不一致问题。为了…

国家之间的标准的语言代码

当涉及到标准的语言代码时&#xff0c;以下是一些常见的国家/地区与其对应的语言代码&#xff1a; 美国英语&#xff1a;en-US英国英语&#xff1a;en-GB加拿大英语&#xff1a;en-CA澳大利亚英语&#xff1a;en-AU法国法语&#xff1a;fr-FR德国德语&#xff1a;de-DE中国普通…

改造 ChatGPT-Next-Web 项目重新生成 Docker 镜像

改造 ChatGPT-Next-Web 项目重新生成 Docker 镜像 0.背景1. 修改代码2. 生成 Docker 镜像3. 上传 Docker 镜像4. 运行 Docker 镜像 0.背景 需要通过 ChatGPT-Next-Web 使用自己搭建的 OpenAI API 兼容的服务器&#xff0c;需要对 ChatGPT-Next-Web 项目的少量代码进行改造。 …

无人机自动返航的关键技术有哪些

无人机的广泛应用使得无人机自动返航技术变得至关重要。在各种应对意外情况的背景下&#xff0c;无人机自动返航技术的发展对确保无人机的安全&#xff0c;以及提高其应用范围具有重要意义。接下来&#xff0c;便为大家详细介绍无人机自动返航所运用到的关键技术。 一、定位与导…

20230802-下载jdk1.8、jre

搜索oracle oracle官网 https://www.oracle.com/cn/

ulog记录(RTTulog部分)

ulog.h int ulog_init(void); int ulog_async_init(void); void ulog_output_lock_enabled(rt_bool_t enabled); void ulog_deinit(void); log初始化、异步初始化、输出锁初始化、log反初始化&#xff1b; #define LOG_E(...) ulog_e(LOG_TAG, __VA_ARGS…

13-1_Qt 5.9 C++开发指南_多线程及QThread 创建多线程程序_ThreadSignal

一个应用程序一般只有一个线程&#xff0c;一个线程内的操作是顺序执行的&#xff0c;如果有某个比较消耗时间的计算或操作&#xff0c;比如网络通信中的文件传输&#xff0c;在一个线程内操作时&#xff0c;用户界面就可能会冻结而不能及时响应。这种情况下&#xff0c;可以创…

如何看待低级爬虫与高级爬虫?

爬虫之所以分为高级和低级&#xff0c;主要是基于其功能、复杂性和灵活性的差异。根据我总结大概有下面几点原因&#xff1a; 功能和复杂性&#xff1a;高级爬虫通常提供更多功能和扩展性&#xff0c;包括处理复杂页面结构、模拟用户操作、解析和清洗数据等。它们解决了开发者…

C++ 通过time.windows.com获取时间

C++ 通过time.windows.com获取时间 在C++中,你可以使用 <ctime>头文件中的 time()函数来获取当前的系统时间。然后,你可以使用 <ctime>头文件中的 localtime()函数将时间转换为本地时间,并从中获取小时、分钟和秒。 以下是一个示例代码,演示如何通过time.windo…

关于hardcoded账号和密码的问题的想法

在编写应用程序时&#xff0c;都会访问一些存储系统获取相关的信息。最简单的例子就是用户登录&#xff0c;需要访问存储的用户和密码&#xff0c;进而可以验证用户是否可以正常登录。为了访问数据库各种框架结构也都提供了相应的方法和接口支持。但是&#xff0c;程序员在实现…