Spider 数据集上实现nlp2sql训练任务

NLP2SQL(自然语言处理到 SQL 查询的转换)是一个重要的自然语言处理(NLP)任务,其目标是将用户的自然语言问题转换为相应的 SQL 查询。这一任务在许多场景下具有广泛的应用,尤其是在与数据库交互的场景中,例如数据分析、业务智能和问答系统。

任务目标
  • 理解自然语言: 理解用户输入的自然语言问题,包括意图、实体和上下文。
  • 生成 SQL 查询: 将理解后的信息转换为正确的 SQL 查询,以从数据库中检索所需的数据。

例如

输入: 用户的自然语言问题,“获取 Gelderland 区的总人口。”

输出: 对应的 SQL 查询,SELECT population FROM districts WHERE name = 'Gelderland';

Spider 是一个难度最大数据集

耶鲁大学在2018年新提出的一个大规模的NL2SQL(Text-to-SQL)数据集。
该数据集包含了10,181条自然语言问句、分布在200个独立数据库中的5,693条SQL,内容覆盖了138个不同的领域。
涉及的SQL语法最全面,是目前难度最大的NL2SQL数据集。

下载查看spider数据集内容

Question 1: How many singers do we have ? ||| concert_singer
SQL: select count(*) from singer

Question 2: What is the total number of singers ? ||| concert_singer
SQL: select count(*) from singer

Question 3: Show name , country , age for all singers ordered by age from the oldest to the youngest . ||| concert_singer
SQL: select name , country , age from singer order by age desc

...

首先需要转换为Spider的标准格式(参考tables.jsontrain.json):

{"db_id": "concert_singer","question": "Show name, country, age...","query": "SELECT name, country, age FROM singer ORDER BY age DESC","schema": {"table_names": ["singer"],"column_names": [[0, "name", "text"],[0, "country", "text"],[0, "age", "int"]]}
}

拆分为table.json的原因可能涉及到数据组织和重用。每个数据库的结构(表、列、外键)在多个问题中都会被重复使用。如果每个问题都附带完整的schema信息,会导致数据冗余,增加存储和处理的开销。所以,将schema单独存储为table.json,可以让不同的数据条目引用同一个数据库模式,减少重复数据。拆分后的结构需要更高效的数据管理,例如在训练模型时,根据每个问题的db_id去table.json中查找对应的schema信息。这样做的好处是当多个问题属于同一个数据库时,不需要每次都重复加载schema提高了效率。

column_names 表示数据库表中每一列的详细信息。具体来说,column_names 是一个列表,其中每个元素都是一个包含三个部分的子列表:

  1. 表索引(0):表示该列属于哪个表。在这个例子中,所有列都属于第一个表(索引为 0)。
  2. 列名("name"、"country"、"age"):表示列的名称。
  3. 数据类型("text"、"int"):表示该列的数据类型,例如文本(text)或整数(int)。

实现下面逻辑转换原始数据

def extract_columns_from_sql(sql):# 使用正则表达式匹配 SELECT 语句中的列名match = re.search(r"SELECT\s+(.*?)\s+FROM", sql, re.IGNORECASE)if match:# 提取列名columns = match.group(1).split(",")# 构建 column_names 列表column_names = []for index, column in enumerate(columns):column = column.strip()  # 去除多余的空格data_type = "text"  # 默认数据类型为 text,可以根据需要修改# 添加到 column_names 列表,假设所有列类型为 textcolumn_names.append([0, column, data_type])return column_namesreturn []# 从 dev.sql 文件读取数据
def load_sql_data(file_path):data_list = []with open(file_path, 'r', encoding='utf-8') as f:  # 指定编码为 UTF-8lines = f.readlines()for i in range(0, len(lines), 3):  # 每三行一组question_line = lines[i].strip()sql_line = lines[i + 1].strip()if not question_line or not sql_line:continue# 提取问题和 SQLquestion = question_line.split(': ', 1)[1].strip()  # 获取问题内容sql = sql_line.split(': ', 1)[1].strip()  # 获取 SQL 查询# 提取表名db_id = question_line.split('|||')[-1].strip()  # 从问题行获取表名question = question.split('|||')[0].strip()target_sql = preprocess(question, db_id, sql)data_list.append({"input_text": f"Translate to SQL: {question} [SEP] Tables: {db_id}","target_sql": json.dumps(target_sql)  # 将目标 SQL 转换为 JSON 格式字符串})return data_list

选择Tokenizer.from_pretrained("t5-base") 是用于加载 T5(Text-to-Text Transfer Transformer)模型的分词器。T5 是一个强大的自然语言处理模型,能够处理各种文本任务(如翻译、摘要、问答等),并且将所有任务视为文本到文本的转换。

from transformers import T5Tokenizertokenizer = T5Tokenizer.from_pretrained("t5-base")def preprocess(question, db_id, sql):# 提取列名column_names = extract_columns_from_sql(sql)# 构建目标格式target_sql = {"db_id": db_id,"question": question,"query": sql,"schema": {"table_names": [db_id],"column_names": column_names}}return target_sql# 示例数据
question = "Show name, country, age for all singers ordered by age from the oldest to the youngest."
schema = "singer(name, country, age)"
sql = "SELECT name, country, age FROM singer ORDER BY age DESC"input_text, target_sql = preprocess(question, schema, sql)
# input_text = "Translate to SQL: Show name... [SEP] Tables: singer(name, country, age)"
# target_sql = "select name, country, age from singer order by age desc"
print('input_text', input_text)
print('target_sql', target_sql)

所有nlp任务都涉及的需要token化,使用t5-base 做tokenize

def tokenize_function(examples):model_inputs = tokenizer(examples["input_text"],max_length=512,truncation=True,padding="max_length")with tokenizer.as_target_tokenizer():labels = tokenizer(examples["target_sql"],max_length=512,truncation=True,padding="max_length")model_inputs["labels"] = labels["input_ids"]return model_inputs

使用 tokenizer.as_target_tokenizer() 上下文管理器,确保目标文本(即 SQL 查询)被正确处理。目标文本也经过编码,转换为 token IDs,并同样进行填充和截断。将目标文本的编码结果(token IDs)存储在 model_inputs["labels"] 中。这是模型在训练时需要的输出,用于计算损失。最终返回一个字典 model_inputs,它包含了模型的输入和对应的标签。这种结构使得模型在训练时可以直接使用。

最后组织下训练代码

tokenized_datasets = dataset.map(tokenize_function, batched=True)# 加载模型
model = T5ForConditionalGeneration.from_pretrained("t5-base")# 训练参数
training_args = Seq2SeqTrainingArguments(output_dir="./results",evaluation_strategy="epoch",learning_rate=3e-5,per_device_train_batch_size=8,per_device_eval_batch_size=8,num_train_epochs=100,predict_with_generate=True,run_name="spider"
)# 开始训练
trainer = Seq2SeqTrainer(model=model,args=training_args,train_dataset=tokenized_datasets["train"] if 'train' in tokenized_datasets else tokenized_datasets,eval_dataset=tokenized_datasets["test"] if 'test' in tokenized_datasets else None,data_collator=DataCollatorForSeq2Seq(tokenizer)
)trainer.train()

这里使用的是Seq2SeqTrainer, 它是 Hugging Face 的 transformers 库中用于序列到序列(Seq2Seq)任务的训练器。它为处理诸如翻译、文本生成和问答等任务提供了一个高层次的接口,简化了训练过程。以下是 Seq2SeqTrainer 的主要功能和特点:

  1. 简化训练流程Seq2SeqTrainer 封装了许多常见的训练步骤,如数据加载、模型训练、评估和预测,使得用户可以更专注于模型和数据,而不必处理繁琐的训练细节。

  2. 支持多种训练参数: 通过 Seq2SeqTrainingArguments 类,可以灵活配置训练参数,如学习率、批量大小、训练轮数、评估策略等。

  3. 自动处理填充和截断: 在处理输入和输出序列时,Seq2SeqTrainer 可以自动填充和截断序列,以确保它们适应模型的输入要求。

  4. 集成评估和监控: 支持在训练过程中进行模型评估,并可以根据评估指标(如损失)监控训练进度。用户可以设置评估频率和评估数据集

开始训练,进行100次epoch

训练监控在 Weights & Biases ,Seq2SeqTrainer 能够向 Weights & Biases (wandb) 传输训练监控数据,主要是因为它内置了与 wandb 的集成。以下是一些关键点,解释了这一过程:

  1. 自动集成:当你使用 Seq2SeqTrainer 时,它会自动检测 wandb 的安装并在初始化时配置相关设置。这意味着你无需手动设置 wandb。

  2. 回调功能Trainer 类提供了回调功能,可以在训练过程中记录各种指标(如损失、准确率等)。这些指标会被自动发送到 wandb。

  3. 配置管理training_args 中的参数可以指定 wandb 的项目名称、运行名称等,从而更好地组织和管理实验。

  4. 训练循环:在每个训练和评估周期结束时,Trainer 会调用相应的回调函数,将重要的训练信息(如损失、学习率等)记录到 wandb。

  5. 可视化:通过 wandb,你可以实时监控训练过程,包括损失曲线、模型性能等,帮助你更好地理解模型的训练动态。

多次试验还可以比较训练性能

训练结束, 损失收敛到0.05410315271151268

{'eval_loss': 0.008576861582696438, 'eval_runtime': 1.3883, 'eval_samples_per_second': 74.912, 'eval_steps_per_second': 5.042, 'epoch': 100.0}
{'train_runtime': 2914.0548, 'train_samples_per_second': 31.914, 'train_steps_per_second': 2.025, 'train_loss': 0.05410315271151268, 'epoch': 100.0}
100%|█████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 5900/5900 [48:31<00:00,  2.03it/s]
wandb:
wandb: 🚀 View run spider at: https://wandb.ai/chenruithinking-4th-paradigm/huggingface/runs/dkccvpp4
wandb: Find logs at: wandb/run-20250207_112702-dkccvpp4/logs

测试下预测能力

import os
from transformers import T5Tokenizer, T5ForConditionalGeneration# 设置 NCCL 环境变量
os.environ["NCCL_P2P_DISABLE"] = "1"
os.environ["NCCL_IB_DISABLE"] = "1"# 加载分词器
tokenizer = T5Tokenizer.from_pretrained("t5-base")model = T5ForConditionalGeneration.from_pretrained("./results/t5-sql-model")
tokenizer.save_pretrained("./results/t5-sql-model")def generate_sql(question, db_id):input_text = f"Translate to SQL: {question} [SEP] Tables: {db_id}"input_ids = tokenizer.encode(input_text, return_tensors="pt")  # 使▒~T▒ PyTorch ▒~Z~D▒| ▒~G~O▒| ▒▒~Ooutput = model.generate(input_ids,max_length=512,num_beams=5,  # 或者尝试其他解码策略early_stopping=True)print('output', output)generated_sql = tokenizer.decode(output[0], skip_special_tokens=True)return generated_sqlquestion = "How many singers do we have ?"
db_id = "concert_singer"
evaluation_output = generate_sql(question, db_id)
print("evaluation_output:", evaluation_output)

输出结果

evaluation_output: "db_id": "concert_singer", "question": "How many singers do we have ?", "query": "select count(*) from singer", "schema": "table_names": ["concert_singer"], "column_names": [[0, "count(*)", "text"]]

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.mzph.cn/web/69260.shtml

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

IDEA+DeepSeek让Java开发起飞

1.获取DeepSeek秘钥 登录DeepSeek官网 : https://www.deepseek.com/ 进入API开放平台&#xff0c;第一次需要注册一个账号 进去之后需要创建一个API KEY&#xff0c;然后把APIkey记录保存下来 接着我们获取DeepSeek的API对话接口地址&#xff0c;点击左边的&#xff1a;接口…

intra-mart实现简易登录页面笔记

一、前言 最近在学习intra-mart框架&#xff0c;在此总结下笔记。 intra-mart是一个前后端不分离的框架&#xff0c;开发时主要用的就是xml、html、js这几个文件&#xff1b; xml文件当做配置文件&#xff0c;html当做前端页面文件&#xff0c;js当做后端文件&#xff08;js里…

Linux+Docer 容器化部署之 Shell 语法入门篇 【Shell 替代】

&#x1f380;&#x1f380;Shell语法入门篇 系列篇 &#x1f380;&#x1f380; LinuxDocer 容器化部署之 Shell 语法入门篇 【准备阶段】LinuxDocer 容器化部署之 Shell 语法入门篇 【Shell变量】LinuxDocer 容器化部署之 Shell 语法入门篇 【Shell数组与函数】LinuxDocer 容…

Intellij IDEA如何查看当前文件的类

快捷键&#xff1a;CtrlF12&#xff0c;我个人感觉记快捷键很麻烦&#xff0c;知道具体的位置更简单&#xff0c;如果忘了快捷键&#xff08;KeyMap&#xff09;看一下就记起来了&#xff0c;不需要再Google or Baidu or GPT啥的&#xff0c;位置&#xff1a;Navigate > Fi…

C++----继承

一、继承的基本概念 本质&#xff1a;代码复用类关系建模&#xff08;是多态的基础&#xff09; class Person { /*...*/ }; class Student : public Person { /*...*/ }; // public继承 派生类继承基类成员&#xff08;数据方法&#xff09;&#xff0c;可以通过监视窗口检…

2025.2.5——五、[网鼎杯 2020 青龙组]AreUSerialz 代码审计|反序列化

题目来源&#xff1a;BUUCTF [网鼎杯 2020 青龙组]AreUSerialz 目录 一、打开靶机&#xff0c;整理信息 二、解题思路 step 1&#xff1a;代码审计 step 2&#xff1a;开始解题 突破protected访问修饰符限制 三、小结 一、打开靶机&#xff0c;整理信息 直接得到一串ph…

Docker深度解析:安装各大环境

安装 Nginx 实现负载均衡&#xff1a; 挂载 nginx html 文件&#xff1a; 创建过载目录&#xff1a; mkdir -p /data/nginx/{conf,conf.d,html,logs} 注意&#xff1a;在挂载前需要对 conf/nginx.conf 文件进行编写 worker_processes 1;events {worker_connections 1024; …

基于SpringBoot养老院平台系统功能实现五

一、前言介绍&#xff1a; 1.1 项目摘要 随着全球人口老龄化的不断加剧&#xff0c;养老服务需求日益增长。特别是在中国&#xff0c;随着经济的快速发展和人民生活水平的提高&#xff0c;老年人口数量不断增加&#xff0c;对养老服务的质量和效率提出了更高的要求。传统的养…

【AIGC魔童】DeepSeek v3推理部署:vLLM/SGLang/LMDeploy

【AIGC魔童】DeepSeek v3推理部署&#xff1a;vLLM/SGLang/LMDeploy &#xff08;1&#xff09;使用vLLM推理部署DeepSeek&#xff08;2&#xff09;使用SGLang推理部署DeepSeek&#xff08;3&#xff09;使用LMDeploy推理部署DeepSeek &#xff08;1&#xff09;使用vLLM推理部…

C语言的灵魂——指针(2)

前言&#xff1a;上期我们介绍了如何理解地址&#xff0c;内存&#xff0c;以及指针的一些基础知识和运算&#xff1b;这期我们来介绍一下const修饰指针&#xff0c;野指针&#xff0c;assert断言&#xff0c;指针的传址调用。 上一篇指针&#xff08;1&#xff09; 文章目录 一…

Android studio 创建aar包给Unity使用

1、aar 是什么&#xff1f; 和 Jar有什么区别 aar 和 jar包 都是压缩包&#xff0c;可以使用压缩软件打开 jar包 用于封装 Java 类及其相关资源 aar 文件是专门为 Android 平台设计的 &#xff0c;可以包含Android的专有内容&#xff0c;比如AndroidManifest.xml 文件 &#…

ASP.NET Core中Filter与Middleware的区别

中间件是ASP.NET Core这个基础提供的功能&#xff0c;而Filter是ASP.NET Core MVC中提供的功能。ASP.NET Core MVC是由MVC中间件提供的框架&#xff0c;而Filter属于MVC中间件提供的功能。 区别 中间件可以处理所有的请求&#xff0c;而Filter只能处理对控制器的请求&#x…

基础篇05-图像直方图操作

本节将简要介绍Halcon中有关图像直方图操作的算子&#xff0c;重点介绍直方图获取和显示两类算子&#xff0c;以及直方图均衡化处理算子。 目录 1. 引言 2. 获取并显示直方图 2.1 获取&#xff08;灰度&#xff09;直方图 (1) gray_histo算子 (2) gray_histo_abs算子 (3…

MySQL | Navicat安装教程

MySQL | Navicat安装教程 &#x1fa84;个人博客&#xff1a;https://vite.xingji.fun 简介 Navicat 是一款流行的 图形化数据库管理工具&#xff0c;由 PremiumSoft 公司开发&#xff0c;支持多种主流数据库系统&#xff08;如 MySQL、MariaDB、SQL Server、Oracle、Postgre…

硬件实现I2C案例(寄存器实现)

一、需求分析 二、硬件电路设计 本次案例需求与前面软件模拟案例一致&#xff0c;这里不再赘述&#xff0c;不清楚可参见下面文章&#xff1a;软件模拟I2C案例&#xff08;寄存器实现&#xff09;-CSDN博客 值得注意的是&#xff0c;前面是软件模拟I2C&#xff0c;所以并没有…

基于SpringBoot养老院平台系统功能实现六

一、前言介绍&#xff1a; 1.1 项目摘要 随着全球人口老龄化的不断加剧&#xff0c;养老服务需求日益增长。特别是在中国&#xff0c;随着经济的快速发展和人民生活水平的提高&#xff0c;老年人口数量不断增加&#xff0c;对养老服务的质量和效率提出了更高的要求。传统的养…

matlab simulink 汽车四分之一模型轮胎带阻尼

1、内容简介 略 matlab simulink121-汽车四分之一模型轮胎带阻尼 可以交流、咨询、答疑 2、内容说明 略 3、仿真分析 略 4、参考论文 略

w196Spring Boot高校教师科研管理系统设计与实现

&#x1f64a;作者简介&#xff1a;多年一线开发工作经验&#xff0c;原创团队&#xff0c;分享技术代码帮助学生学习&#xff0c;独立完成自己的网站项目。 代码可以查看文章末尾⬇️联系方式获取&#xff0c;记得注明来意哦~&#x1f339;赠送计算机毕业设计600个选题excel文…

数据分析:企业数字化转型的金钥匙

引言&#xff1a;数字化浪潮下的数据金矿 在数字化浪潮席卷全球的背景下&#xff0c;有研究表明&#xff0c;只有不到30%的企业能够充分利用手中掌握的数据&#xff0c;这是否让人深思&#xff1f;数据已然成为企业最为宝贵的资产之一。然而&#xff0c;企业是否真正准备好从数…

Vue 入门到实战 八

第8章 组合API与响应性 目录 8.1 响应性 8.1.1 什么是响应性 8.1.2 响应性原理 8.2 为什么使用组合API 8.3 setup组件选项 8.3.1 setup函数的参数 8.3.2 setup函数的返回值 8.3.3 使用ref创建响应式引用 8.3.4 setup内部调用生命周期钩子函数 8.4 提供/注入 8.4.1 …