【Text2SQL】WikiSQL 数据集与 Seq2SQL 模型

论文:Seq2SQL: Generating Structured Queries from Natural Language using Reinforcement Learning

⭐⭐⭐⭐⭐

ICLR 2018

Dataset: github.com/salesforce/WikiSQL

Code:Seq2SQL 模型实现

一、论文速读

本文提出了 Text2SQL 方向的一个经典数据集 —— WikiSQL,同时提出了一个模型 Seq2SQL,用于把自然语言问句转为 SQL。

WikiSQL 数据集中的 SQL 形式较为简单,不包括排序(order by)、分组(group by)、子查询等其他复杂操作。根据这种简单的形式,本文的 Seq2SQL 模型针对一个 table 和一个 question,预测出 SELECT 部分、Aggregation 部分和 WHERE 部分,并将其构造成一个 SQL 语句。下图展示了一个示例:

在这里插入图片描述

Seq2SQL 基于 Augmented Pointer Network 来实现,下面先介绍一下这个网络结构,然后再介绍基于此来实现 Seq2SQL 模型。

二、Augmented Pointer Network(增广指针网络)

Augmented Pointer Network 能够从输入序列中选择 token 并逐个 token 生成输出序列。

对于一个 example,输入序列 x x x 是由"table 的列名"、“SQL 词汇表”、"question"三者用特殊分隔符拼接起来的序列:

在这里插入图片描述

比如在前面图片的示例中,列名 token 包括 “Pick”、“#”、“CFL” 等等组成,question token 包括 “How”、“many”、“CFL” 等等,SQL 词汇表包括 “SELECT”、“WHERE”、“COUNT”、“MIN” 等等。

这个网络首先对 input sequence x x x 做 word embedding,然后输入给两层的 Bi-LSTM 做编码得到 h e n c h^{enc} henc,其中 input 的第 i 个 token 的编码是 h t e n c h_t^{enc} htenc,这样每个 token 经过编码都变成了一个 vector。

解码器部分使用双层的单向 LSTM,每一步生成一个 token。具体生成方式是:使用上一步生成的 token y s − 1 y_{s-1} ys1 作为输入,输出一个 state g s g_s gs,然后拿 g s g_s gs 与 input sequence 的每个位置 t 的 h t h_t ht 做计算得到一个标量的注意力分数 α s , t p t r \alpha_{s,t}^{ptr} αs,tptr,选择分数最高的对应的输入 token 作为生成的下一个 token。其中注意力分数的计算公式如下:

20240518155338

三、Seq2SQL 模型

虽然可以直接训练 Augmented Pointer Network 让他生成 SQL 序列作为结果,但是这没有利用 SQL 本身固有的结构。本论文固定 SQL 的结构由三部分组成:SELECT、WHERE 和 Aggregation,并训练三个组件来分别生成这三部分:

在这里插入图片描述

3.1 Aggregation Classifier

他就是一个 classifier,最终输出一个 softmax 计算后的分布,从 NULLMAXMINCOUNTSUMAVG 中做分类,NULL 表示没有 aggregation 操作。其 loss L a g g L^{agg} Lagg 使用 cross entropy 来计算。

比如,“How many” 类型的 question 往往被分类为 COUNT

3.2 SELECT column prediction

SELECT column prediction 是一个匹配问题,这里使用指针网络的思想来解决:输入列名序列和 question 的拼接,输出与 question 最匹配的一个 column。

首先使用 LSTM 对每一列进行编码,column j j j 对应一个 vector e j c e_j^c ejc,然后对 input x x x 编码出一个 vector κ s e l \kappa^{sel} κsel,然后使用 MLP,计算 input representation κ s e l \kappa^{sel} κsel 与每一个 column j 的分数 α j s e l \alpha^{sel}_{j} αjsel,之后使用 softmax 对分数进行归一化:

  • 训练时,使用交叉熵损失 L s e l L^{sel} Lsel 来训练该模块
  • 预测时,选分数最大的 column 作为预测结果

对于输入 x x x 编码为 input representation 和计算分数的详细信息可以参考论文和代码实现

3.3 WHERE Clause

这里使用类似于 Augmented Pointer Network 的 pointer decoder 来训练这一模块。但是使用 cross entropy 有一个限制:两个 WHERE 条件可以被交换并产生相同结果。但两个顺序不同的 WHERE 会被 cross entropy 错误地惩罚,比如 year>18 and male=1male=1 and year>18 是等价的,但由于 cross entropy 是精确匹配 tokens,导致这个结果会被计算损失。

这里使用强化学习(RL)来训练, q ( y ) q(y) q(y) 是生成的查询, q g q_g qg 是真实查询,奖励函数的定义如下:

20240518171120

并根据此奖励函数计算出 loss L w h e L^{whe} Lwhe

3.4 Seq2SQL 的训练

设置一个混合损失函数 L = L a g g + L s e l + L w h e L = L^{agg} + L^{sel} + L^{whe} L=Lagg+Lsel+Lwhe,并使用梯度下降来最小化该 loss 从而训练模型。

四、WikiSQL 数据集

该文更重要的一个贡献是提供了一个 WikiSQL 数据集,包含 80654 条样本和 24241 个 schema。这些数据被随机划分为 train、dev 和 test 三个 split。

下面是一个 example:

20240518173309

解释如下:

  • phase: the phase in which the dataset was collected. We collected WikiSQL in two phases.
  • question: the natural language question written by the worker.
  • table_id: the ID of the table to which this question is addressed.
  • sql: the SQL query corresponding to the question. This has the following subfields:
    • sel: the numerical index of the column that is being selected. You can find the actual column from the table.
    • agg: the numerical index of the aggregation operator that is being used. You can find the actual operator from Query.agg_ops in lib/query.py.
    • conds: a list of triplets (column_index, operator_index, condition) where:
      • column_index: the numerical index of the condition column that is being used. You can find the actual column from the table.
      • operator_index: the numerical index of the condition operator that is being used. You can find the actual operator from Query.cond_ops in lib/query.py.
      • condition: the comparison value for the condition, in either string or float type.

同时还给出了每个 table 的 schema 和数据部分。

五、评估指标

  • N N N:数据集的样本总数
  • N e x N_{ex} Nex:运行生成的 SQL 后,得到正确结果的样本数
  • N l f N_{lf} Nlf:生成的 SQL 与 ground-truth SQL 字符串完全精确匹配的样本数

由此提出两个指标:

  • A C C e x = N e x / N ACC_{ex} = N_{ex} / N ACCex=Nex/N执行精度指标,如果生成的 SQL 与 ground-truth SQL 的执行结果相同,那就算作正确。存在一个缺点:如果构造一个错误的 SQL 但执行结果正确,依然被算作正确
  • A C C l f = N l f / N ACC_{lf} = N_{lf} / N ACClf=Nlf/N逻辑形式的精确指标,如果生成的 SQL 与 ground-truth SQL 完全匹配,才被算作正确。存在一个缺点:两个等价但写法不同的 SQL 会被算作错误

六、总结

这篇论文给出了一个 WikiSQL 数据集,并提出了 Text2SQL 的一个解决方案以及评价指标。

但是很明显,该方案存在不少缺点,之后的方案会继续改进。

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.mzph.cn/news/840410.shtml

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

Linux--10---安装JDK、MySQL

提示:文章写完后,目录可以自动生成,如何生成可参考右边的帮助文档 文章目录 安装JDK[Linux命令--03----JDK .Nginx. 数据库](https://blog.csdn.net/weixin_48052161/article/details/108997148) 第一步 查询系统中自带的JDK第二步 卸载系统中…

Unity Physics入门

概述 在unity中物理属性是非常重要的,它可以模拟真实物理的效果在unity中,其中的组件是非常多的,让我们来学习一下这部分的内容吧。 Unity组件入门篇总目录----------点击导航 Character Controller(角色控制) 说明:组件是Unity提…

华为编程题目(实时更新)

1.大小端整数 计算机中对整型数据的表示有两种方式:大端序和小端序,大端序的高位字节在低地址,小端序的高位字节在高地址。例如:对数字 65538,其4字节表示的大端序内容为00 01 00 02,小端序内容为02 00 01…

【案例分享】医疗布草数字化管理系统:聚通宝赋能仟溪信息科技

内容概要 本文介绍了北京聚通宝科技有限公司与河南仟溪信息科技有限公司合作开发的医疗布草数字化管理系统。该系统利用物联网技术实现了医疗布草生产过程的实时监控和数据分析,解决了医疗布草洗涤厂面临的诸多挑战,包括人工记录、生产低效率和缺乏实时…

DNF手游攻略:角色培养与技能搭配!游戏辅助!

角色培养和技能搭配是《地下城与勇士》中提升战斗力的关键环节。每个职业都有独特的技能和发展路线,合理的属性加点和技能搭配可以最大化角色的潜力,帮助玩家在各种战斗中立于不败之地。接下来,我们将探讨如何有效地培养角色并搭配技能。 角色…

JavaEE之线程(9) _定时器的实现代码

前言 定时器也是软件开发中的一个重要组件. 类似于一个 “闹钟”。 达到一个设定的时间之后,就执行某个指定好的代码,比如: 在受上述场景中,当客户端发出去请求之后, 就要等待响应,如果服务器迟迟没有响应&…

大小字符判断

//函数int my_isalpha(char c)的功能是返回字符种类 //大写字母返回1,小写字母返回-1.其它字符返回0 //void a 调用my_isalpha(),返回大写,输出*;返回小写,输出#;其它,输出? #inclu…

【Linux】Linux的安装

文章目录 一、Linux环境的安装虚拟机 镜像文件云服务器(可能需要花钱) 未完待续 一、Linux环境的安装 我们往后的学习用的Linux版本为——CentOs 7 ,使用 Ubuntu 也可以 。这里提供几个安装方法: 电脑安装双系统(不…

深入解析力扣162题:寻找峰值(线性扫描与二分查找详解)

❤️❤️❤️ 欢迎来到我的博客。希望您能在这里找到既有价值又有趣的内容,和我一起探索、学习和成长。欢迎评论区畅所欲言、享受知识的乐趣! 推荐:数据分析螺丝钉的首页 格物致知 终身学习 期待您的关注 导航: LeetCode解锁100…

virtual box ubuntu20 全屏展示

virtual box 虚拟机 ubuntu20 系统 全屏展示 ubuntu20.04 视图-自动调整窗口大小 视图-自动调整显示尺寸 系统黑屏解决 ##设备-安装增强功能 ##进入终端 ##终端打不开,解决方案-传送门ubuntu Open in Terminal打不开终端解决方案-CSDN博客 ##点击cd盘按钮进入文…

【RabbitMQ】使用SpringAMQP的Publish/Subscribe(发布/订阅)

Publish/Subscribe **发布(Publish)、订阅(Subscribe):**允许将同一个消息发送给多个消费者 **注意:**exchange负责消息路由,而不是存储,路由失败则消息丢失 常见的**X(exchange–交换机)***类型: Fanout 广播Direc…

【设计模式】JAVA Design Patterns——Callback(回调模式)

🔍目的 回调是一部分被当为参数来传递给其他代码的可执行代码,接收方的代码可以在一些方便的时候来调用它。 🔍解释 真实世界例子 我们需要被通知当执行的任务结束时。我们为调用者传递一个回调方法然后等它调用通知我们。 通俗描述 回调是一…

试用百川智能的百小应-说的太多,做的太少

“百小应”的品牌标识(logo)上有一缕黄色,这是王小川特意设计的。他说,其他AI应用都在强调科技感,更愿意用蓝色或者冷色调。但他觉得这一代科技与上个时代不一样,现代科技应该像人,所以选择使用…

Java进阶学习笔记5——Static应用知识:单例设计模式

设计模式: 架构师会使用到设计模式,开发框架,就需要掌握很多设计模式。 在Java基础阶段学习设计模式,将来面试笔试的时候,笔试题目会经常靠到设计模式。 将来会用到设计模式。框架代码中会用到设计模式。 什么是设计…

linux常用软件源码安装-2

jdk、tomcat、Apache、nginx、mysql、redis、maven、nexus安装文档:linux常用软件源码安装 9.sonarqube安装 前置条件:mysql5.6和jdk8 1.下载 官网 2.安装unzip并解压sonarqube,然后移动到/usr/local yum install -y unzip unzip sonarq…

基于Matlab完整版孤立词识别系统

欢迎大家点赞、收藏、关注、评论啦 ,由于篇幅有限,只展示了部分核心代码。 文章目录 一项目简介 二、功能三、系统四. 总结 一项目简介 一、项目背景与意义 孤立词识别是语音识别领域的一个重要分支,其目标是将输入的语音信号转换为计算机可…

收入极高的副业兼职,单价179元的养生爆款卖出3000份

做养生赛道切忌不要只靠接广告变现,换个思路以产品为核心,走低粉高变现才是变现效率最高的。 周周近财:让网络小白少花冤枉钱,赚取第一桶金 今天,我们要分析的这位养生博主,仅凭一款售价为179元的温通膏&a…

excel poi的titleRows 和 headRows含义

titleRows 这个参数的意思是:excel标题占多少行,而不是第几行headRows 这个参数的意思是:excel表头占几行,而不是第几行(多行的意思是合并的行数) 比如有一个excel如下,1-2行是标题&#xff0c…

将 MOV 转换为 MP4 的 10 个最佳工具

在当今的数字时代,内容创作和消费正处于巅峰,对多功能和兼容媒体格式的需求从未如此之高。在众多可用的视频格式中,MOV 和 MP4 因其在各种设备和平台中的广泛使用而脱颖而出。然而,将 MOV 文件转换为更通用兼容的 MP4 格式的需求已…

运算符重载(上)

目录 运算符重载日期类的比较判断日期是否相等判断日期大小 赋值运算符重载赋值运算符重载格式赋值运算符只能重载成类的成员函数不能重载成全局函数用户没有显式实现时,编译器会生成一个默认赋值运算符重载,以值的方式逐字节拷贝 感谢各位大佬对我的支持…