【Text2SQL】WikiSQL 数据集与 Seq2SQL 模型

论文:Seq2SQL: Generating Structured Queries from Natural Language using Reinforcement Learning

⭐⭐⭐⭐⭐

ICLR 2018

Dataset: github.com/salesforce/WikiSQL

Code:Seq2SQL 模型实现

一、论文速读

本文提出了 Text2SQL 方向的一个经典数据集 —— WikiSQL,同时提出了一个模型 Seq2SQL,用于把自然语言问句转为 SQL。

WikiSQL 数据集中的 SQL 形式较为简单,不包括排序(order by)、分组(group by)、子查询等其他复杂操作。根据这种简单的形式,本文的 Seq2SQL 模型针对一个 table 和一个 question,预测出 SELECT 部分、Aggregation 部分和 WHERE 部分,并将其构造成一个 SQL 语句。下图展示了一个示例:

在这里插入图片描述

Seq2SQL 基于 Augmented Pointer Network 来实现,下面先介绍一下这个网络结构,然后再介绍基于此来实现 Seq2SQL 模型。

二、Augmented Pointer Network(增广指针网络)

Augmented Pointer Network 能够从输入序列中选择 token 并逐个 token 生成输出序列。

对于一个 example,输入序列 x x x 是由"table 的列名"、“SQL 词汇表”、"question"三者用特殊分隔符拼接起来的序列:

在这里插入图片描述

比如在前面图片的示例中,列名 token 包括 “Pick”、“#”、“CFL” 等等组成,question token 包括 “How”、“many”、“CFL” 等等,SQL 词汇表包括 “SELECT”、“WHERE”、“COUNT”、“MIN” 等等。

这个网络首先对 input sequence x x x 做 word embedding,然后输入给两层的 Bi-LSTM 做编码得到 h e n c h^{enc} henc,其中 input 的第 i 个 token 的编码是 h t e n c h_t^{enc} htenc,这样每个 token 经过编码都变成了一个 vector。

解码器部分使用双层的单向 LSTM,每一步生成一个 token。具体生成方式是:使用上一步生成的 token y s − 1 y_{s-1} ys1 作为输入,输出一个 state g s g_s gs,然后拿 g s g_s gs 与 input sequence 的每个位置 t 的 h t h_t ht 做计算得到一个标量的注意力分数 α s , t p t r \alpha_{s,t}^{ptr} αs,tptr,选择分数最高的对应的输入 token 作为生成的下一个 token。其中注意力分数的计算公式如下:

20240518155338

三、Seq2SQL 模型

虽然可以直接训练 Augmented Pointer Network 让他生成 SQL 序列作为结果,但是这没有利用 SQL 本身固有的结构。本论文固定 SQL 的结构由三部分组成:SELECT、WHERE 和 Aggregation,并训练三个组件来分别生成这三部分:

在这里插入图片描述

3.1 Aggregation Classifier

他就是一个 classifier,最终输出一个 softmax 计算后的分布,从 NULLMAXMINCOUNTSUMAVG 中做分类,NULL 表示没有 aggregation 操作。其 loss L a g g L^{agg} Lagg 使用 cross entropy 来计算。

比如,“How many” 类型的 question 往往被分类为 COUNT

3.2 SELECT column prediction

SELECT column prediction 是一个匹配问题,这里使用指针网络的思想来解决:输入列名序列和 question 的拼接,输出与 question 最匹配的一个 column。

首先使用 LSTM 对每一列进行编码,column j j j 对应一个 vector e j c e_j^c ejc,然后对 input x x x 编码出一个 vector κ s e l \kappa^{sel} κsel,然后使用 MLP,计算 input representation κ s e l \kappa^{sel} κsel 与每一个 column j 的分数 α j s e l \alpha^{sel}_{j} αjsel,之后使用 softmax 对分数进行归一化:

  • 训练时,使用交叉熵损失 L s e l L^{sel} Lsel 来训练该模块
  • 预测时,选分数最大的 column 作为预测结果

对于输入 x x x 编码为 input representation 和计算分数的详细信息可以参考论文和代码实现

3.3 WHERE Clause

这里使用类似于 Augmented Pointer Network 的 pointer decoder 来训练这一模块。但是使用 cross entropy 有一个限制:两个 WHERE 条件可以被交换并产生相同结果。但两个顺序不同的 WHERE 会被 cross entropy 错误地惩罚,比如 year>18 and male=1male=1 and year>18 是等价的,但由于 cross entropy 是精确匹配 tokens,导致这个结果会被计算损失。

这里使用强化学习(RL)来训练, q ( y ) q(y) q(y) 是生成的查询, q g q_g qg 是真实查询,奖励函数的定义如下:

20240518171120

并根据此奖励函数计算出 loss L w h e L^{whe} Lwhe

3.4 Seq2SQL 的训练

设置一个混合损失函数 L = L a g g + L s e l + L w h e L = L^{agg} + L^{sel} + L^{whe} L=Lagg+Lsel+Lwhe,并使用梯度下降来最小化该 loss 从而训练模型。

四、WikiSQL 数据集

该文更重要的一个贡献是提供了一个 WikiSQL 数据集,包含 80654 条样本和 24241 个 schema。这些数据被随机划分为 train、dev 和 test 三个 split。

下面是一个 example:

20240518173309

解释如下:

  • phase: the phase in which the dataset was collected. We collected WikiSQL in two phases.
  • question: the natural language question written by the worker.
  • table_id: the ID of the table to which this question is addressed.
  • sql: the SQL query corresponding to the question. This has the following subfields:
    • sel: the numerical index of the column that is being selected. You can find the actual column from the table.
    • agg: the numerical index of the aggregation operator that is being used. You can find the actual operator from Query.agg_ops in lib/query.py.
    • conds: a list of triplets (column_index, operator_index, condition) where:
      • column_index: the numerical index of the condition column that is being used. You can find the actual column from the table.
      • operator_index: the numerical index of the condition operator that is being used. You can find the actual operator from Query.cond_ops in lib/query.py.
      • condition: the comparison value for the condition, in either string or float type.

同时还给出了每个 table 的 schema 和数据部分。

五、评估指标

  • N N N:数据集的样本总数
  • N e x N_{ex} Nex:运行生成的 SQL 后,得到正确结果的样本数
  • N l f N_{lf} Nlf:生成的 SQL 与 ground-truth SQL 字符串完全精确匹配的样本数

由此提出两个指标:

  • A C C e x = N e x / N ACC_{ex} = N_{ex} / N ACCex=Nex/N执行精度指标,如果生成的 SQL 与 ground-truth SQL 的执行结果相同,那就算作正确。存在一个缺点:如果构造一个错误的 SQL 但执行结果正确,依然被算作正确
  • A C C l f = N l f / N ACC_{lf} = N_{lf} / N ACClf=Nlf/N逻辑形式的精确指标,如果生成的 SQL 与 ground-truth SQL 完全匹配,才被算作正确。存在一个缺点:两个等价但写法不同的 SQL 会被算作错误

六、总结

这篇论文给出了一个 WikiSQL 数据集,并提出了 Text2SQL 的一个解决方案以及评价指标。

但是很明显,该方案存在不少缺点,之后的方案会继续改进。

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.mzph.cn/news/840410.shtml

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

Linux--10---安装JDK、MySQL

提示:文章写完后,目录可以自动生成,如何生成可参考右边的帮助文档 文章目录 安装JDK[Linux命令--03----JDK .Nginx. 数据库](https://blog.csdn.net/weixin_48052161/article/details/108997148) 第一步 查询系统中自带的JDK第二步 卸载系统中…

刀片式服务器是什么?

什么是刀片式服务器? 刀片式服务器是服务器的一种,能够在标准高度的机架式机箱中插装多个卡式的服务器单元,是专门为特殊应用行业和高密度计算环境专门设计的,主要的结构是一大型主体机箱,内部可以插入许多“刀片”。 …

Unity Physics入门

概述 在unity中物理属性是非常重要的,它可以模拟真实物理的效果在unity中,其中的组件是非常多的,让我们来学习一下这部分的内容吧。 Unity组件入门篇总目录----------点击导航 Character Controller(角色控制) 说明:组件是Unity提…

华为编程题目(实时更新)

1.大小端整数 计算机中对整型数据的表示有两种方式:大端序和小端序,大端序的高位字节在低地址,小端序的高位字节在高地址。例如:对数字 65538,其4字节表示的大端序内容为00 01 00 02,小端序内容为02 00 01…

Java数据结构与算法(平衡二叉树)

前言 平衡二叉树是为了提高二叉树的查询速度,通过满足特定的条件来保持其平衡性。平衡二叉树具有以下特点: 左子树和右子树的高度差不会大于1,这是为了确保树的高度不会过大,从而减少查询时的磁盘I/O开销,提高查询速…

【开源】史上最全的JAVA面试题总结

史上最全的JAVA面试题总结 为什么要做这件事情前言JAVA基础开发框架springSpringMVCmybatisdubbospringbootspringcloudnacos 数据库mysqloracle 缓存redismongodbElasticSearch 消息队列rabbitmqrocketmqkafka 监控prometheusgraylogzabbix 工具篇tcpdumpgitjenkins 容器docke…

【案例分享】医疗布草数字化管理系统:聚通宝赋能仟溪信息科技

内容概要 本文介绍了北京聚通宝科技有限公司与河南仟溪信息科技有限公司合作开发的医疗布草数字化管理系统。该系统利用物联网技术实现了医疗布草生产过程的实时监控和数据分析,解决了医疗布草洗涤厂面临的诸多挑战,包括人工记录、生产低效率和缺乏实时…

SpringBoot RPM制作

安装依赖 [root20240423-instance4 ~]# yum install rpmdevtools2.初始化目录 [root20240423-instance4 ~]# rpmdev-setuptree [root20240423-instance4 ~]# tree rpmbuild/ rpmbuild/ ├── BUILD ├── RPMS ├── SOURCES ├── SPECS └── SRPMS5 directories, 0 …

DNF手游攻略:角色培养与技能搭配!游戏辅助!

角色培养和技能搭配是《地下城与勇士》中提升战斗力的关键环节。每个职业都有独特的技能和发展路线,合理的属性加点和技能搭配可以最大化角色的潜力,帮助玩家在各种战斗中立于不败之地。接下来,我们将探讨如何有效地培养角色并搭配技能。 角色…

进程通信,队列,管道

【一】进程通信 1.多个进程之间的信息交换过程 2.如何实现(1)消息队列:把信息从一端放入队列中,另一个进程从另一端将消息取出非阻塞的,即发送进程不需要等待接收进程的响应即可继续执行。(2)管道:半双工的通信机制,同…

架构每日一学 11:快手高级副总裁给年轻人的几点建议

文章首发于公众平台:腐烂的橘子 于冰毕业于清华大学,从 05 年开始接触音视频领域,到现在已经在垂直行业深耕将近 20 年。先后经历了两次创业,曾在 Hulu、FreeWheel 等公司专攻音视频领域,现任快手高级副总裁。 作为一…

JavaEE之线程(9) _定时器的实现代码

前言 定时器也是软件开发中的一个重要组件. 类似于一个 “闹钟”。 达到一个设定的时间之后,就执行某个指定好的代码,比如: 在受上述场景中,当客户端发出去请求之后, 就要等待响应,如果服务器迟迟没有响应&…

新能源汽车结构与原理

第一章 新能源汽车概述 1.1 电动汽车及新能源汽车定义 新能源汽车是指采用非常规的车用燃料作为动力来源(或使用常规车用燃料、采用新型车载动力装置),综合车辆的动力控制和驱动方面的先进技术,形成的技术原理先进、具有新技术、…

大小字符判断

//函数int my_isalpha(char c)的功能是返回字符种类 //大写字母返回1,小写字母返回-1.其它字符返回0 //void a 调用my_isalpha(),返回大写,输出*;返回小写,输出#;其它,输出? #inclu…

【Linux】Linux的安装

文章目录 一、Linux环境的安装虚拟机 镜像文件云服务器(可能需要花钱) 未完待续 一、Linux环境的安装 我们往后的学习用的Linux版本为——CentOs 7 ,使用 Ubuntu 也可以 。这里提供几个安装方法: 电脑安装双系统(不…

深入解析力扣162题:寻找峰值(线性扫描与二分查找详解)

❤️❤️❤️ 欢迎来到我的博客。希望您能在这里找到既有价值又有趣的内容,和我一起探索、学习和成长。欢迎评论区畅所欲言、享受知识的乐趣! 推荐:数据分析螺丝钉的首页 格物致知 终身学习 期待您的关注 导航: LeetCode解锁100…

virtual box ubuntu20 全屏展示

virtual box 虚拟机 ubuntu20 系统 全屏展示 ubuntu20.04 视图-自动调整窗口大小 视图-自动调整显示尺寸 系统黑屏解决 ##设备-安装增强功能 ##进入终端 ##终端打不开,解决方案-传送门ubuntu Open in Terminal打不开终端解决方案-CSDN博客 ##点击cd盘按钮进入文…

YoloV8改进策略:蒸馏改进|MGDLoss|使用蒸馏模型实现YoloV8无损涨点|特征蒸馏

摘要 在本文中,我们成功应用蒸馏策略以实现YoloV8小模型的无损性能提升。我们采用了MGDLoss作为蒸馏方法的核心,通过对比在线和离线两种蒸馏方式,我们发现离线蒸馏在效果上更为出色。因此,为了方便广大读者和研究者应用&#xff…

【RabbitMQ】使用SpringAMQP的Publish/Subscribe(发布/订阅)

Publish/Subscribe **发布(Publish)、订阅(Subscribe):**允许将同一个消息发送给多个消费者 **注意:**exchange负责消息路由,而不是存储,路由失败则消息丢失 常见的**X(exchange–交换机)***类型: Fanout 广播Direc…

【设计模式】JAVA Design Patterns——Callback(回调模式)

🔍目的 回调是一部分被当为参数来传递给其他代码的可执行代码,接收方的代码可以在一些方便的时候来调用它。 🔍解释 真实世界例子 我们需要被通知当执行的任务结束时。我们为调用者传递一个回调方法然后等它调用通知我们。 通俗描述 回调是一…