文章目录
- 前言
- 一、 安装依赖包
- 二、 设置数据库连接
- 三、 扫描数据库结构
- 四、 生成 SQL 查询
- 五、 执行 SQL 查询
- 六、 运行示例
- 七、 封装成类
- 总结
前言
前一篇文章中,我们一起写了一个agent,为了简化代码是直接传递sql的,这一篇文章我们将通过大模型根据我们的自然语言生成sql,然后再通过agent查询数据并交给大模型思考得出结果。
一、 安装依赖包
首先,我们需要安装必要的 Python 包。我们将使用 langchain
和 SQLAlchemy
进行数据库连接和查询生成。
pip install langchain sqlalchemy
二、 设置数据库连接
我们将以 SQLite 数据库为例,展示如何设置数据库连接并创建一个示例表。
from sqlalchemy import create_engine# 设置数据库连接
engine = create_engine('sqlite:///example.db')# 创建一个示例表并插入一些数据
with engine.connect() as connection:connection.execute("""CREATE TABLE IF NOT EXISTS users (id INTEGER PRIMARY KEY,name TEXT,age INTEGER)""")connection.execute("""INSERT INTO users (name, age) VALUES('Alice', 30),('Bob', 25),('Charlie', 35)""")
三、 扫描数据库结构
为了让语言模型生成正确的 SQL 查询,我们需要提供数据库的结构信息(表名和列名)。我们将使用 SQLAlchemy 的 inspect
模块来扫描数据库结构。
from sqlalchemy import inspectdef inspect_db_structure(engine):inspector = inspect(engine)structure = {}for table_name in inspector.get_table_names():columns = inspector.get_columns(table_name)structure[table_name] = [column['name'] for column in columns]return structuredb_structure = inspect_db_structure(engine)
四、 生成 SQL 查询
我们将使用 LangChain 的大语言模型(LLM)来生成 SQL 查询。为此,我们需要定义一个提示模板,并将用户的自然语言请求和数据库结构信息传递给模型。
from langchain.prompts import PromptTemplate
from langchain.llms import OpenAI# 初始化LangChain的LLM
llm = OpenAI(api_key="YOUR_OPENAI_API_KEY")# 定义生成SQL查询的提示模板
sql_generation_prompt = PromptTemplate(template="You are an AI assistant. Given the following user request and the database structure, generate a SQL query.\n""User request: {request}\n""Database structure: {db_structure}\n""SQL query:",input_variables=["request", "db_structure"]
)def generate_sql_query(request, db_structure):query = llm({"prompt": sql_generation_prompt.format(request=request,db_structure=db_structure)})return query["choices"][0]["text"].strip()
五、 执行 SQL 查询
定义执行 SQL 查询的工具函数,并使用 LangChain 初始化 Agent 来执行查询。
from langchain.agents import initialize_agent, Tool# 创建数据库会话
Session = sessionmaker(bind=engine)
session = Session()# 定义执行SQL查询的工具函数
def execute_sql_query(query):try:result = session.execute(query)return result.fetchall()except Exception as e:return str(e)# 定义LangChain的工具
sql_tool = Tool(name="SQL Executor",func=execute_sql_query,description="Executes SQL queries and returns the result"
)# 创建自定义Agent
agent = initialize_agent(tools=[sql_tool],llm=llm,agent_type="zero_shot",prompt_template=PromptTemplate(template="You are an SQL agent. Execute the following SQL query: {query}",input_variables=["query"])
)
六、 运行示例
结合以上所有步骤,使用自定义 Agent 自动生成 SQL 查询并执行:
def main():user_request = "Find all users older than 30"sql_query = generate_sql_query(user_request, db_structure)print(f"Generated SQL Query: {sql_query}")result = agent({"query": sql_query})print(result)if __name__ == "__main__":main()
七、 封装成类
我们将上述功能封装到一个类中。
封装到一个类中有许多好处,包括模块化、可重用性、扩展性、简化复杂性和增强可维护性。通过封装,我们可以将复杂的逻辑抽象出来,使得代码更容易理解和维护,并且可以在不同的项目或不同的部分中重复使用。以下是封装后的完整代码:
class SQLAgent:def __init__(self, database_url, api_key):# 设置数据库连接self.engine = create_engine(database_url)self.Session = sessionmaker(bind=self.engine)self.session = self.Session()# 扫描数据库结构self.db_structure = self.inspect_db_structure()# 初始化LangChain的LLMself.llm = OpenAI(api_key=api_key)# 定义执行SQL查询的工具函数def execute_sql_query(query):try:result = self.session.execute(query)return result.fetchall()except Exception as e:return str(e)# 定义LangChain的工具sql_tool = Tool(name="SQL Executor",func=execute_sql_query,description="Executes SQL queries and returns the result")# 定义生成SQL查询的提示模板self.sql_generation_prompt = PromptTemplate(template="You are an AI assistant. Given the following user request and the database structure, generate a SQL query.\n""User request: {request}\n""Database structure: {db_structure}\n""SQL query:",input_variables=["request", "db_structure"])# 创建自定义Agentself.agent = initialize_agent(tools=[sql_tool],llm=self.llm,agent_type="zero_shot",prompt_template=PromptTemplate(template="You are an SQL agent. Execute the following SQL query: {query}",input_variables=["query"]))def inspect_db_structure(self):inspector = inspect(self.engine)structure = {}for table_name in inspector.get_table_names():columns = inspector.get_columns(table_name)structure[table_name] = [column['name'] for column in columns]return structuredef generate_sql_query(self, request):query = self.llm({"prompt": self.sql_generation_prompt.format(request=request,db_structure=self.db_structure)})return query["choices"][0]["text"].strip()def execute(self, user_request):sql_query = self.generate_sql_query(user_request)print(f"Generated SQL Query: {sql_query}")result = self.agent({"query": sql_query})return result# 使用示例
if __name__ == "__main__":database_url = 'sqlite:///example.db'api_key = 'YOUR_OPENAI_API_KEY'sql_agent = SQLAgent(database_url, api_key)user_request = "Find all users older than 30"result = sql_agent.execute(user_request)print(result)
总结
通过本文的示例,我们展示了如何使用 LangChain 和 SQLAlchemy 创建一个自定义的 SQL 查询 Agent。生产的步骤主要就是:扫描数据库结构 -> 生成 SQL 查询
-> 执行 SQL 查询。
大家可以使用自己喜欢的大模型来测试和练习