Text-to-SQL将自然语言转换为数据库查询语句

有关Text-To-SQL方法,可以查阅我的另一篇文章,Text-to-SQL方法研究

直接与数据库对话-text2sql

Text2sql就是把文本转换为sql语言,这段时间公司有这方面的需求,调研了一下市面上text2sql的方法,比如阿里的Chat2DB,麻省理工开源的Vanna。试验了一下,最终还是决定自研,基于Vanna的思想,RAG+大模型。

    使用开源的Vanna实现text2sql比较方便,Vanna可以直接连接数据库,但是当用户权限能访问多个数据库的时候,就比较麻烦了,而且Vanna向量化存储之后,新的question作对比时没有区分数据库。因此自己实现了一下text2sq,仍然采用Vanna的思想,提前训练DDL,Sqlques,和数据库document。

这里简单做一下记录,以供后续学习使用。

基本思路

1、数据库DDL语句,SQL-Question,Dcoument信息获取

2、基于用户提问question和数据库Document锁定要分析的数据库

3、模型训练:借助数据库的DDL语句、元数据(描述数据库自身数据的信息)、相关文档说明、参考样例SQL等,训练一个RAG“模型”。

这一模型结合了embedding技术和向量数据库,使得数据库的结构和内容能够被高效地索引和检索。

4、语义检索: 当用户输入自然语言描述的问题时,①会从向量库里面检索,迅速找出与问题相关的内容;②进行BM25算法文本召回,找到与问题 最相关的内容;③分别使用RRF算法和Re-ranking重排序算法,锁定最相关内容

语义匹配:使用算法(如BERT等)来理解查询和文档的语义相似性

文本召回匹配:BM25算法文本召回,找到与问题最相关的内容

rerank结果重排序:对搜索结果进行排序。

5、Prompt构建: 检索到的相关信息会被组装进Prompt中,形成一个结构化的查询描述。这一Prompt随后会被传递给LLM(大型语言模型)用于生成准确的SQL查询。

实现逻辑图

实现架构图:

具体实现方式如下所示:

1.数据库的选择

class DataBaseSearch(object):def __init__(self, _model):self.name = 'DataBaseSearch'self.model = _modelself.instruction = "为这段内容生成表示以用于匹配文本描述:"self.SIZE = 1024self.index = faiss.IndexFlatL2(self.SIZE)self.textdata = []self.subdata = {}self.i2key = {}self.id2ddls = {}self.id2sqlques = {}self.id2docs = {}self.strtexts = {}# self.ddldata = []# self.sqlques_data = []# self.document_data = []self.load_textdata()         # 加载text数据self.load_textdata_vec()     # text数据向量化def load_textdata(self):try:response = requests.post(url="xxx",verify=False)print(response.text)jsonobj = json.loads(response.text)textdatas = jsonobj["data"]for textdata in textdatas:                                 # 提取每一个数据库内容cid = textdata["dataSetID"]cddls = textdata["ddl"]csql_ques = textdata["exp"]cdocuments = textdata["Intro"]self.textdata.append((cid, cddls, csql_ques, cdocuments))   # 整合所有数据except Exception as e:print(e)# print("load textdata ", self.textdata)def load_textdata_vec(self):num0 = 0for recode in self.textdata:_id = recode[0]_ddls = recode[1]_sql_ques = recode[2]_documents = recode[3]# _strtexts = str(_ddls) + str(_sql_ques) + str(_documents)_strtexts = str(_sql_ques) + str(_documents)text_embeddings = self.model.encode([_strtexts], normalize_embeddings=True)self.index.add(text_embeddings)self.i2key[num0] = _idself.strtexts[_id] = _strtextsself.id2ddls[_id] = _ddlsself.id2sqlques[_id] = _sql_quesself.id2docs[_id] = _documentsnum0 += 1# print("init instruction vec", num0)def calculate_score(self, score, question, kws):passdef find_vec_database(self, question, k, theata):# print(question)q_embeddings = self.model.encode([self.instruction + question], normalize_embeddings=True)D, I = self.index.search(q_embeddings, k)result = []for i in range(k):sim_i = I[0][i]uuid = self.i2key.get(sim_i, "none")sim_v = D[0][i]database_texts = self.strtexts.get(uuid, "none")# score = self.calculate_score(sim_v, question, database_texts) # wait implementscore = int(sim_v*1000)if score < theata:doc = {}doc["score"] = scoredoc["dataSetID"] = uuidresult.append(doc)# print(result)return resultif __name__ == '__main__':modelpath = "E:\\module\\bge-large-zh-v1.5"model = SentenceTransformer(modelpath)vs = DataBaseSearch(model)result = vs.find_vec_database("查询济南市第三幼儿园所有小班班级?", 1, 2000)print(result)

2.sql_ques:sql问题训练

class SqlQuesSearch(object):def __init__(self, _model):self.name = "SqlQuesSearch"self.model = _modelself.instruction = "为这段内容生成表示以用于匹配文本描述:"self.SIZE = 1024self.index = faiss.IndexFlatL2(self.SIZE)self.sqlquedata = []self.i2dbid = {}self.i2sqlid = {}self.id2sqlque = {}self.id2que = {}self.id2sql = {}self.dbid2sqlques = {}## self.sqlques = {}## self.i2key = {}## self.id2sqlques = {}## self.num2sqlque = {}# self.ddldata = []# self.sqlques_data = []# self.document_data = []self.load_textdata()  # 加载text数据self.load_textdata_vec()  # text数据向量化def load_textdata(self):try:response = requests.post(url="xxx",verify=False)print(response.text)jsonobj = json.loads(response.text)textdatas = jsonobj["data"]datadatas = jsonobj["data"]for datadata in datadatas:  # 提取每一个数据库sql-ques内容dbid = datadata["dataSetID"]sql_ques = datadata["exp"]self.sqlquedata.append((dbid, sql_ques))  # 整合sql数据except Exception as e:print(e)# print("load textdata ", self.sqlquedata)def load_textdata_vec(self):num0 = 0for recode in self.sqlquedata:db_id = recode[0]sql_ques = recode[1]for sql_que in sql_ques:sql_id = sql_que["sql_id"]question = sql_que["question"]sql = sql_que["sql"]ddl_embeddings = self.model.encode([question], normalize_embeddings=True)self.index.add(ddl_embeddings)self.i2dbid[num0] = db_idself.i2sqlid[num0] = sql_idself.id2que[sql_id] = questionself.id2sql[sql_id] = sqlnum0 += 1print("init sql-que vec", num0)def calculate_score(sim_v, question, sql_ques):passdef find_vec_sqlque(self, question, k, theta, dataSetID, number):q_embeddings = self.model.encode([self.instruction + question], normalize_embeddings=True)D, I = self.index.search(q_embeddings, k)result = []for i in range(k):sim_i = I[0][i]dbid = self.i2dbid.get(sim_i, "none")  # 获取数据库idsqlid = self.i2sqlid.get(sim_i, "none")question = self.id2que.get(sqlid, "none")sql = self.id2sql.get(sqlid, "none")if dbid == dataSetID:sim_v = D[0][i]score = int(sim_v * 1000)if score < theta:doc = {}doc["score"] = scoredoc["question"] = questiondoc["sql"] = sqlresult.append(doc)if len(result) == number:breakreturn resultif __name__ == '__main__':modelpath = "E:\\module\\bge-large-zh-v1.5"model = SentenceTransformer(modelpath)vs = SqlQuesSearch(model)result = vs.find_vec_sqlque("查询7月18日所有的儿童观察记录?", 3, 2000, dataSetID=111)print(result)

3.数据库DDL训练

class DdlQuesSearch(object):def __init__(self, _model):self.name = "DdlQuesSearch"self.model = _modelself.instruction = "为这段内容生成表示以用于匹配文本描述:"self.SIZE = 1024self.index = faiss.IndexFlatL2(self.SIZE)self.ddldata = []self.sqlques = {}self.i2dbid = {}self.i2ddlid = {}self.dbid2ddls = {}self.id2ddl = {}self.ddlid2dbid = {}# self.ddldata = []# self.sqlques_data = []# self.document_data = []self.load_ddldata()  # 加载text数据self.load_ddl_vec()  # text数据向量化def load_ddldata(self):try:response = requests.post(url="xxx",verify=False)print(response.text)jsonobj = json.loads(response.text)for database in databases:db_id = database["dataSetID"]ddls = database["ddl"]self.ddldata.append((db_id, ddls))# print(db_id)# for ddl in database["ddl"]:#     ddl_id = ddl["ddl_id"]#     ddl = ddl['ddl']##     self.id2ddl[ddl_id] = ddl# self.dbid2ddls[db_id] = self.id2ddlexcept Exception as e:print(e)# print("load textdata ", self.ddldata)def load_ddl_vec(self):num0 = 0for recode in self.ddldata:db_id = recode[0]ddls = recode[1]for ddl in ddls:ddl_id = ddl["ddl_id"]ddl_name = ddl["TABLE"]ddl = ddl['ddl']ddl_embeddings = self.model.encode([ddl], normalize_embeddings=True)self.index.add(ddl_embeddings)self.i2dbid[num0] = db_idself.i2ddlid[num0] = ddl_idself.id2ddl[ddl_id] = ddlself.ddlid2dbid[ddl_id] = db_idnum0 += 1self.dbid2ddls[db_id] = self.id2ddlprint("init ddl vec", num0)def find_vec_ddl(self, question, k, theata, dataSetID, number):       # dataSetID:数据库id# self.id2ddls.get(action_id)q_embeddings = self.model.encode([self.instruction + question], normalize_embeddings=True)D, I = self.index.search(q_embeddings, k)result = []for i in range(k):sim_i = I[0][i]dbid = self.i2dbid.get(sim_i, "none")         # 获取数据库idddlid = self.i2ddlid.get(sim_i, "none")if dbid == dataSetID:sim_v = D[0][i]score = int(sim_v * 1000)if score < theata:doc = {}doc["score"] = scoredoc["ddl"] = self.id2ddl.get(ddlid, "none")result.append(doc)if len(result) == number:breakreturn resultif __name__ == '__main__':modelpath = "E:\\module\\bge-large-zh-v1.5"model = SentenceTransformer(modelpath)vs = DdlQuesSearch(model)ss = vs.find_vec_ddl("定时任务执行记录表", 2, 2000, 111)print(ss)

4.数据库document训练

class DocQuesSearch(object):def __init__(self):self.name = "TestDataSearch"self.docdata = []self.load_doc_data()def load_doc_data(self):try:response = requests.post(url="xxx",verify=False)print(response.text)jsonobj = json.loads(response.text)databases = jsonobj["data"]for database in databases:db_id = database["dataSetID"]doc = database["Intro"]self.docdata.append((db_id, doc))except Exception as e:print(e)# print("load ddldata ", self.docdata)def find_similar_doc(self, dataSetID):result = []for recode in self.docdata:dbid = recode[0]doc = recode[1]if dbid == dataSetID:result.append(doc)return resultif __name__ == '__main__':docques_search = DocQuesSearch()result = docques_search.find_similar_doc(222)print(result)

5.生成sql语句,这里使用的qwen-max模型

import re
import random
import os, json
import dashscope
from dashscope.api_entities.dashscope_response import Message
from ddl_engine import DdlQuesSearch
from dashscope import Generation
from sqlques_engine import SqlQuesSearch
from sentence_transformers import SentenceTransformerclass Genarate(object):def __init__(self):self.api_key = os.environ.get('api_key')self.model_name = os.environ.get('model')def system_message(self, message):return {'role': 'system', 'content': message}def user_message(self, message):return {'role': 'user', 'content': message}def assistant_message(self, message):return {'role': 'assistant', 'content': message}def submit_prompt(self, prompt):resp = Generation.call(model=self.model_name,messages=prompt,seed=random.randint(1, 10000),result_format='message',api_key=self.api_key)if resp["status_code"] == 200:answer = resp.output.choices[0].message.contentglobal DEBUG_INFODEBUG_INFO = (prompt, answer)return answerelse:answer = Nonereturn answerdef generate_sql(self, question, sql_result, ddl_result, doc_result):prompt = self.get_sql_prompt(question = question,sql_result = sql_result,ddl_result = ddl_result,doc_result = doc_result)print("SQL Prompt:",prompt)llm_response = self.submit_prompt(prompt)sql = self.extrat_sql(llm_response)return sqldef extrat_sql(self, llm_response):sqls = re.findall(r"WITH.*?;", llm_response, re.DOTALL)if sqls:sql = sqls[-1]return sqlsqls = re.findall(r"SELECT.*?;", llm_response, re.DOTALL)if sqls:sql = sqls[-1]return sqlsqls = re.findall(r"```sql\n(.*)```", llm_response, re.DOTALL)if sqls:sql = sqls[-1]return sqlsqls = re.findall(r"```(.*)```", llm_response, re.DOTALL)if sqls:sql = sqls[-1]return sqlreturn llm_responsedef get_sql_prompt(self, question, sql_result, ddl_result, doc_result):initial_prompt = "You are a SQL expert. " + \"Please help to generate a SQL query to answer the question. Your response should ONLY be based on the given context and follow the response guidelines and format instructions. "initial_prompt = self.add_ddl_to_prompt( initial_prompt, ddl_result)initial_prompt = self.add_documentation_to_prompt(initial_prompt, doc_result)initial_prompt += ("===Response Guidelines \n""1. If the provided context is sufficient, please generate a valid SQL query without any explanations for the question. \n""2. If the provided context is almost sufficient but requires knowledge of a specific string in a particular column, please generate an intermediate SQL query to find the distinct strings in that column. Prepend the query with a comment saying intermediate_sql \n""3. If the provided context is insufficient, please explain why it can't be generated. \n""4. Please use the most relevant table(s). \n""5. If the question has been asked and answered before, please repeat the answer exactly as it was given before. \n")message_log = [self.system_message(initial_prompt)]message_log = self.add_sqlques_to_prompt(question, sql_result, message_log)return message_logdef add_ddl_to_prompt(self, initial_prompt, ddl_result):""":param initial_prompt::param ddl_result::return:"""ddl_list = [ ddl_['ddl'] for ddl_ in ddl_result]if len(ddl_list) > 0:initial_prompt += "\n===Tables \n"for ddl in ddl_list:initial_prompt += f"{ddl}\n\n"return initial_promptdef add_sqlques_to_prompt(self, question, sql_result, message_log):""":param sql_result::return:"""if len(sql_result) > 0:for example in sql_result:if example is not None and "question" in example and "sql" in example:message_log.append(self.user_message(example["question"]))message_log.append(self.assistant_message(example["sql"]))message_log.append(self.user_message(question))return message_logdef add_documentation_to_prompt(self, initial_prompt, doc_result):if len(doc_result) > 0:initial_prompt += "\n===Additional Context \n\n"for doc in doc_result:initial_prompt += f"{doc}\n\n"return initial_promptif __name__ == '__main__':modelpath = "E:\\module\\bge-large-zh-v1.5"model = SentenceTransformer(modelpath)vs = DdlQuesSearch(model)ss = vs.find_vec_ddl("定时任务执行记录表", 1, 2000, 111)print(ss)

6.执行结果显示

如图可以看到正确生成了sql,可以正常执行,因为表是拉取到,没有数据,所以查询结果为空。

需要源码的同学,可以留言。

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.mzph.cn/news/899649.shtml

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

golang 的strconv包常用方法

目录 1. 字符串与整数的转换 2. 字符串与浮点数的转换 3. 布尔值的转换 4. 字符串的转义 5. 补充&#xff1a;rune 类型的使用 方法功能详解 代码示例&#xff1a; 1. 字符串与整数的转换 方法名称功能描述示例Atoi将字符串转换为十进制整数。strconv.Atoi("123&q…

MATLAB详细图文安装教程(附安装包)

前言 MATLAB&#xff08;Matrix Laboratory&#xff09;是由MathWorks公司开发的一款高性能的编程语言和交互式环境&#xff0c;主要用于数值计算、数据分析和算法开发。内置数学函数和工具箱丰富&#xff0c;开发效率高&#xff0c;特别适合矩阵运算和领域特定问题。接下来就…

ShapeCrawler:.NET开发者的PPTX操控魔法

引言 在当今的软件开发领域&#xff0c;随着数据可视化和信息展示需求的不断增长&#xff0c;处理 PPTX 文件的场景日益频繁。无论是自动化生成报告、批量制作演示文稿&#xff0c;还是对现有 PPT 进行内容更新与格式调整&#xff0c;开发者都需要高效的工具来完成这些任务。传…

HTML5贪吃蛇游戏开发经验分享

HTML5贪吃蛇游戏开发经验分享 这里写目录标题 HTML5贪吃蛇游戏开发经验分享项目介绍技术栈核心功能实现1. 游戏初始化2. 蛇的移动控制3. 碰撞检测4. 食物生成 开发心得项目收获后续优化方向结语 项目介绍 在这个项目中&#xff0c;我使用HTML5 Canvas和原生JavaScript实现了一…

有关pip与conda的介绍

Conda vs. Pip vs. Virtualenv 命令对比 任务Conda 命令Pip 命令Virtualenv 命令安装包conda install $PACKAGE_NAMEpip install $PACKAGE_NAMEX更新包conda update --name $ENVIRONMENT_NAME $PACKAGE_NAMEpip install --upgrade $PACKAGE_NAMEX更新包管理器conda update con…

【Linux】调试器——gdb使用

目录 一、预备知识 二、常用指令 三、调试技巧 &#xff08;一&#xff09;监视变量的变化指令 watch &#xff08;二&#xff09;更改指定变量的值 set var 正文 一、预备知识 程序的发布形式有两种&#xff0c;debug和release模式&#xff0c;Linux gcc/g出来的二进制…

【Ubuntu常用命令】

1.将本地服务器文件或文件夹传输到远程服务器 文件 scp /data/a.txt administrator10.60.51.20:/home/administrator/ 文件夹 scp -r /data/ administrator10.60.51.20:/home/administrator/ 2.从远程服务器传输文件到本地服务器 scp administrator10.60.51.20:/data/a.txt /h…

golang 的time包的常用方法

目录 time 包方法总结 类型 time.Time 的方法 库函数 代码示例&#xff1a; time 包方法总结 类型 time.Time 的方法 方法名描述示例               ẵNow()获取当前时间和日期time.Now()Format()格式化时间为字符串time.Now().Format("2006-01-02 15…

Elasticsearch:使用 Azure AI 文档智能解析 PDF 文本和表格数据

作者&#xff1a;来自 Elastic James Williams 了解如何使用 Azure AI 文档智能解析包含文本和表格数据的 PDF 文档。 Azure AI 文档智能是一个强大的工具&#xff0c;用于从 PDF 中提取结构化数据。它可以有效地提取文本和表格数据。提取的数据可以索引到 Elastic Cloud Serve…

【ArcGIS操作】ArcGIS 进行空间聚类分析

ArcGIS 是一个强大的地理信息系统&#xff08;GIS&#xff09;软件&#xff0c;主要用于地理数据的存储、分析、可视化和制图 启动 ArcMap 在 Windows 中&#xff0c;点击“开始”菜单&#xff0c;找到 ArcGIS文件夹&#xff0c;然后点击 ArcMap 添加数据 添加数据 - 点击工具…

RabbitMQ消息相关

MQ的模式&#xff1a; 基本消息模式&#xff1a;一个生产者&#xff0c;一个消费者work模式&#xff1a;一个生产者&#xff0c;多个消费者订阅模式&#xff1a; fanout广播模式&#xff1a;在Fanout模式中&#xff0c;一条消息&#xff0c;会被所有订阅的队列都消费。 在广播…

缓存使用纪要

一、本地缓存&#xff1a;Caffeine 1、简介 Caffeine是一种高性能、高命中率、内存占用低的本地缓存库&#xff0c;简单来说它是 Guava Cache 的优化加强版&#xff0c;是当下最流行、最佳&#xff08;最优&#xff09;缓存框架。 Spring5 即将放弃掉 Guava Cache 作为缓存机…

2025年3月29日笔记

问题&#xff1a;创建一个长度为99的整数数组&#xff0c;输出数组的每个位置数字是几&#xff1f; 解题思路&#xff1a; 1.因为题中没有明确要求需要输入,所以所有类型的答案都需要写出 解法1&#xff1a; #include<iostream> #include<bits/stdc.h> using n…

hadoop相关面试题以及答案

什么是Hadoop&#xff1f;它的主要组件是什么&#xff1f; Hadoop是一个开源的分布式计算框架&#xff0c;用于处理大规模数据的存储和计算。其主要组件包括Hadoop Distributed File System&#xff08;HDFS&#xff09;和MapReduce。 解释HDFS的工作原理。 HDFS采用主从架构&…

微信小程序:数据拼接方法

1. 使用 concat() 方法拼接数组 // 在原有数组基础上拼接新数组 Page({data: {originalArray: [1, 2, 3]},appendData() {const newData [4, 5, 6];const combinedArray this.data.originalArray.concat(newData);this.setData({originalArray: combinedArray});} }) 2. 使…

Python之贪心算法

Python实现贪心算法(Greedy Algorithm) 概念 贪心算法是一种在每一步选择中都采取当前状态下最优的选择&#xff0c;从而希望导致结果是全局最优的算法策略。 基本特点 局部最优选择&#xff1a;每一步都做出当前看起来最佳的选择不可回退&#xff1a;一旦做出选择&#xf…

【 <二> 丹方改良:Spring 时代的 JavaWeb】之 Spring Boot 中的 AOP:实现日志记录与性能监控

<前文回顾> 点击此处查看 合集 https://blog.csdn.net/foyodesigner/category_12907601.html?fromshareblogcolumn&sharetypeblogcolumn&sharerId12907601&sharereferPC&sharesourceFoyoDesigner&sharefromfrom_link <今日更新> 一、开篇整…

TCP/IP协议簇

文章目录 应用层http/httpsDNS补充 传输层TCP1. 序列号与确认机制2. 超时重传3. 流量控制&#xff08;滑动窗口机制&#xff09;4. 拥塞控制5. 错误检测与校验6. 连接管理总结 网络层ARP**ARP 的核心功能**ARP 的工作流程1. ARP 请求&#xff08;Broadcast&#xff09;2. ARP 缓…

SpringBoot分布式项目订单管理实战:Mybatis最佳实践全解

一、架构设计与技术选型 典型分布式订单系统架构&#xff1a; [网关层] → [订单服务] ←→ [分布式缓存]↑ ↓ [用户服务] [支付服务]↓ ↓ [MySQL集群] ← [分库分表中间件]技术栈组合&#xff1a; Spring Boot 3.xMybatis-Plus 3.5.xShardingSpher…

微服务架构中的精妙设计:环境和工程搭建

一.前期准备 1.1开发环境安装 Oracle从JDK9开始每半年发布⼀个新版本, 新版本发布后, ⽼版本就不再进⾏维护. 但是会有⼏个⻓期维护的版本. ⽬前⻓期维护的版本有: JDK8, JDK11, JDK17, JDK21 在 JDK版本的选择上&#xff0c;尽量选择⻓期维护的版本. 为什么选择JDK17? S…