使用 FastAPI 实现聊天完成 API 详解
- 简介
- 基础概念
- FastAPI
- Pydantic
- PyTorch
- 代码详解
- 1. 定义 API 端点
- 2. 请求验证
- 3. 生成参数字典
- 4. 处理流式响应
- 5. 工具调用处理
- 6. 非流式响应处理
- 7. 处理使用信息和工具调用
- 8. 构建聊天消息
- 9. 构建响应选择
- 10. 更新使用信息
- 11. 返回最终响应
- 总结
- 示例代码
简介
在这篇博客中,我们将详细解释一段使用 FastAPI 构建的聊天完成 API 代码。这段代码实现了一个 POST 请求的 API 端点,用于处理聊天消息并生成响应。我们将逐行解析代码,并提供必要的背景知识和示例代码。
基础概念
FastAPI
FastAPI 是一个用于构建 API 的现代、快速(高性能)的 Web 框架,基于 Python 3.6+。它使用类型提示来自动生成文档和验证请求数据。
Pydantic
Pydantic 是一个用于数据验证和设置管理的库,常与 FastAPI 一起使用。它通过 Python 的类型提示来定义数据模型,并自动验证输入数据的类型和格式。
PyTorch
PyTorch 是一个开源的深度学习框架,广泛用于研究和生产环境。它提供了灵活的张量计算和自动求导功能。
代码详解
下面是完整的代码,我们将逐段进行解释。
@app.post("/v1/chat/completions", response_model=ChatCompletionResponse)
async def create_chat_completion(request: ChatCompletionRequest):if len(request.messages) < 1 or request.messages[-1].role == "assistant":raise HTTPException(status_code=400, detail="Invalid request")gen_params = dict(messages=request.messages,temperature=request.temperature,top_p=request.top_p,max_tokens=request.max_tokens or 1024,echo=False,stream=request.stream,repetition_penalty=request.repetition_penalty,tools=request.tools,tool_choice=request.tool_choice,)logger.debug(f"==== request ====\n{gen_params}")if request.stream:predict_stream_generator = predict_stream(request.model, gen_params)output = await anext(predict_stream_generator)if output:return EventSourceResponse(predict_stream_generator, media_type="text/event-stream")logger.debug(f"First result output:\n{output}")function_call = Noneif output and request.tools:try:function_call = process_response(output, request.tools, use_tool=True)except:logger.warning("Failed to parse tool call")if isinstance(function_call, dict):function_call = ChoiceDeltaToolCallFunction(**function_call)generate = parse_output_text(request.model, output, function_call=function_call)return EventSourceResponse(generate, media_type="text/event-stream")else:return EventSourceResponse(predict_stream_generator, media_type="text/event-stream")response = ""async for response in generate_stream_glm4(gen_params):passif response["text"].startswith("\n"):response["text"] = response["text"][1:]response["text"] = response["text"].strip()usage = UsageInfo()function_call, finish_reason = None, "stop"tool_calls = Noneif request.tools:try:function_call = process_response(response["text"], request.tools, use_tool=True)except Exception as e:logger.warning(f"Failed to parse tool call: {e}")if isinstance(function_call, dict):finish_reason = "tool_calls"function_call_response = ChoiceDeltaToolCallFunction(**function_call)function_call_instance = FunctionCall(name=function_call_response.name,arguments=function_call_response.arguments)tool_calls = [ChatCompletionMessageToolCall(id=generate_id('call_', 24),function=function_call_instance,type="function")]message = ChatMessage(role="assistant",content=None if tool_calls else response["text"],function_call=None,tool_calls=tool_calls,)logger.debug(f"==== message ====\n{message}")choice_data = ChatCompletionResponseChoice(index=0,message=message,finish_reason=finish_reason,)task_usage = UsageInfo.model_validate(response["usage"])for usage_key, usage_value in task_usage.model_dump().items():setattr(usage, usage_key, getattr(usage, usage_key) + usage_value)return ChatCompletionResponse(model=request.model,choices=[choice_data],object="chat.completion",usage=usage)
1. 定义 API 端点
@app.post("/v1/chat/completions", response_model=ChatCompletionResponse)
async def create_chat_completion(request: ChatCompletionRequest):
这行代码定义了一个 POST 请求的 API 端点 /v1/chat/completions
,并指定了请求的响应模型为 ChatCompletionResponse
。create_chat_completion
函数将处理传入的 ChatCompletionRequest
请求。
2. 请求验证
if len(request.messages) < 1 or request.messages[-1].role == "assistant":raise HTTPException(status_code=400, detail="Invalid request")
这里我们进行请求验证:
- 检查
messages
列表的长度是否小于 1。 - 检查最后一条消息的角色是否为 “assistant”。
如果以上任一条件为真,则抛出 HTTP 400 错误。
3. 生成参数字典
gen_params = dict(messages=request.messages,temperature=request.temperature,top_p=request.top_p,max_tokens=request.max_tokens or 1024,echo=False,stream=request.stream,repetition_penalty=request.repetition_penalty,tools=request.tools,tool_choice=request.tool_choice,)logger.debug(f"==== request ====\n{gen_params}")
将请求中的参数转化为一个字典 gen_params
,用于后续的生成操作。同时,记录调试信息。
4. 处理流式响应
if request.stream:predict_stream_generator = predict_stream(request.model, gen_params)output = await anext(predict_stream_generator)if output:return EventSourceResponse(predict_stream_generator, media_type="text/event-stream")logger.debug(f"First result output:\n{output}")
如果请求中指定了流式响应,则调用 predict_stream
函数生成流式响应生成器,并返回 EventSourceResponse
。如果第一个输出存在,则直接返回生成器作为事件流响应。
5. 工具调用处理
function_call = Noneif output and request.tools:try:function_call = process_response(output, request.tools, use_tool=True)except:logger.warning("Failed to parse tool call")if isinstance(function_call, dict):function_call = ChoiceDeltaToolCallFunction(**function_call)generate = parse_output_text(request.model, output, function_call=function_call)return EventSourceResponse(generate, media_type="text/event-stream")else:return EventSourceResponse(predict_stream_generator, media_type="text/event-stream")
如果输出存在且请求包含工具调用,则尝试解析工具调用。如果解析成功,则处理工具调用并生成新的事件流响应;否则,继续返回原始的事件流生成器。
6. 非流式响应处理
response = ""async for response in generate_stream_glm4(gen_params):passif response["text"].startswith("\n"):response["text"] = response["text"][1:]response["text"] = response["text"].strip()
如果请求未指定流式响应,则调用 generate_stream_glm4
生成响应。在生成响应后,去掉开头的换行符并修剪两端空白。
7. 处理使用信息和工具调用
usage = UsageInfo()function_call, finish_reason = None, "stop"tool_calls = Noneif request.tools:try:function_call = process_response(response["text"], request.tools, use_tool=True)except Exception as e:logger.warning(f"Failed to parse tool call: {e}")if isinstance(function_call, dict):finish_reason = "tool_calls"function_call_response = ChoiceDeltaToolCallFunction(**function_call)function_call_instance = FunctionCall(name=function_call_response.name,arguments=function_call_response.arguments)tool_calls = [ChatCompletionMessageToolCall(id=generate_id('call_', 24),function=function_call_instance,type="function")]
在处理响应后,创建 UsageInfo
实例并检查是否有工具调用。如果有工具调用,则解析并生成工具调用响应。
8. 构建聊天消息
message = ChatMessage(role="assistant",content=None if tool_calls else response["text"],function_call=None,tool_calls=tool_calls,)
根据生成的响应和工具调用信息,创建一个 ChatMessage
实例。
9. 构建响应选择
logger.debug(f"==== message ====\n{message}")choice_data = ChatCompletionResponseChoice(index=0,message=message,finish_reason=finish_reason,)
这段代码将生成的 ChatMessage
实例记录到日志中,并且创建一个 ChatCompletionResponseChoice
实例,其中包含了消息的索引、消息内容和完成原因。
10. 更新使用信息
task_usage = UsageInfo.model_validate(response["usage"])for usage_key, usage_value in task_usage.model_dump().items():setattr(usage, usage_key, getattr(usage, usage_key) + usage_value)
从响应中提取使用信息,并将其添加到 usage
实例中。UsageInfo.model_validate
方法用于验证并创建一个包含使用信息的实例。
11. 返回最终响应
return ChatCompletionResponse(model=request.model,choices=[choice_data],object="chat.completion",usage=usage)
最后,创建并返回一个 ChatCompletionResponse
实例,其中包含了模型名称、选项列表和使用信息。
总结
通过这篇博客,我们详细解析了一个基于 FastAPI 实现的聊天完成 API 的代码。我们逐行解释了代码的功能,并介绍了相关的基础概念和库。
示例代码
为了帮助理解,我们提供一个简化版的示例代码,用于实现类似的聊天完成 API:
from fastapi import FastAPI, HTTPException
from pydantic import BaseModel
from typing import Listapp = FastAPI()class ChatMessage(BaseModel):role: strcontent: strclass ChatCompletionRequest(BaseModel):messages: List[ChatMessage]temperature: floatmax_tokens: intclass ChatCompletionResponse(BaseModel):message: ChatMessage@app.post("/v1/chat/completions", response_model=ChatCompletionResponse)
async def create_chat_completion(request: ChatCompletionRequest):if len(request.messages) < 1 or request.messages[-1].role == "assistant":raise HTTPException(status_code=400, detail="Invalid request")# Simplified response generation logicresponse_text = "This is a response."response_message = ChatMessage(role="assistant", content=response_text)return ChatCompletionResponse(message=response_message)
这段简化代码定义了一个基本的聊天完成 API 端点,处理请求并返回简单的响应。通过这个示例,可以更好地理解完整代码的工作原理。