Files
llmproxy/app/main.py
Vertex-AI-Step-Builder f7508d915b feat: 优化 chat 接口并修复 function 消息处理
主要变更:
- 使用原生 Request 对象接收请求数据
- 先记录原始 client_request(完整 JSON)到数据库
- 然后解析为 IncomingRequest 对象进行验证
- 添加请求解析的错误处理

修复问题:
- ChatMessage 的 content 改为 Optional[str],支持空值
- 添加 name 字段支持 function 角色的工具名称
- 添加 tool_calls 字段支持 assistant 消息的工具调用
- 修复 function 类型消息 content 为空时报错的问题

优化改进:
- 保留完整的原始客户端请求
- 更好的数据完整性和可追溯性
- 代码清理:移除重复的 import 语句

测试验证:
- 多轮工具调用对话正常工作
- function 消息空 content 正常处理
- 所有单元测试通过 (20/20)
- 完全兼容 OpenAI API 消息格式

🤖 Generated with [Claude Code](https://claude.com/claude-code)

Co-Authored-By: Claude Sonnet 4.5 <noreply@anthropic.com>
2025-12-31 09:32:37 +00:00

218 lines
8.7 KiB
Python

import os
import sys
import logging
import time
import json
from dotenv import load_dotenv
from fastapi import FastAPI, HTTPException, Depends, Request
from starlette.responses import StreamingResponse
from .models import IncomingRequest, ProxyResponse
from .services import process_chat_request, stream_llm_api, inject_tools_into_prompt, parse_llm_response_from_content, _parse_sse_data
from .core.config import get_settings, Settings
from .database import init_db, log_request, update_request_log
# --- Environment & Debug Loading ---
# load_dotenv() # Uncomment if you run uvicorn directly and need to load .env
# ---
# --- Logging Configuration ---
# Mirror every log record both to a rotating-candidate file and to the console.
_LOG_FORMAT = '%(asctime)s - %(name)s - %(levelname)s - %(message)s'
_LOG_HANDLERS = [
    logging.FileHandler("llm_proxy.log"),
    logging.StreamHandler(),
]
logging.basicConfig(level=logging.INFO, format=_LOG_FORMAT, handlers=_LOG_HANDLERS)
# Module-level logger used throughout this file.
logger = logging.getLogger(__name__)
# --- End of Logging Configuration ---
# FastAPI application instance; this metadata surfaces in the auto-generated
# OpenAPI docs (/docs, /redoc).
_APP_METADATA = {
    "title": "LLM Tool Proxy",
    "description": "A proxy that intercepts LLM requests to inject and handle tool calls.",
    "version": "1.0.0",
}
app = FastAPI(**_APP_METADATA)
# --- Middleware for logging basic request/response info ---
@app.middleware("http")
async def logging_middleware(request: Request, call_next):
    """Log method/path/client and headers on the way in, status and latency on the way out."""
    started = time.time()
    logger.info(
        f"Request received: {request.method} {request.url.path} from {request.client.host}"
    )
    logger.info(f"Request Headers: {dict(request.headers)}")
    response = await call_next(request)
    elapsed_ms = (time.time() - started) * 1000
    logger.info(f"Response sent: status_code={response.status_code} ({elapsed_ms:.2f}ms)")
    return response
# --- End of Middleware ---
@app.on_event("startup")
async def startup_event():
    """One-time initialization once the app boots: set up the DB and echo config."""
    logger.info("Application startup complete.")
    init_db()
    logger.info("Database initialized.")
    settings = get_settings()
    # Lazy %-style args avoid f-string work when the level is filtered out.
    logger.info("Loaded LLM API URL: %s", settings.REAL_LLM_API_URL)
@app.post("/v1/chat/completions")
async def chat_completions(
    request: Request,
    settings: Settings = Depends(get_settings)
):
    """
    OpenAI-compatible Chat Completions endpoint.

    Persists the raw client request, validates it as an IncomingRequest,
    optionally injects tool definitions into the prompt, and proxies the
    call to the real LLM API in either streaming or non-streaming mode.

    Raises:
        HTTPException 400: body is not valid JSON or not a valid request shape.
        HTTPException 500: proxy is misconfigured or the upstream call fails.
    """
    # Read and parse the raw request body. A malformed body is a client
    # error (400); previously an unguarded json.loads turned it into an
    # unhandled exception (500).
    raw_body = await request.body()
    try:
        client_request = json.loads(raw_body.decode('utf-8'))
    except (UnicodeDecodeError, json.JSONDecodeError) as e:
        logger.error(f"Failed to decode request body as JSON: {e}")
        raise HTTPException(status_code=400, detail=f"Invalid JSON body: {str(e)}")
    # Persist the complete original client request for traceability.
    log_id = log_request(client_request=client_request)
    logger.info(f"Request body logged with ID: {log_id}")
    # Parse into IncomingRequest model for validation and type safety.
    try:
        request_obj = IncomingRequest(**client_request)
    except Exception as e:
        logger.error(f"Failed to parse request: {e}")
        raise HTTPException(status_code=400, detail=f"Invalid request format: {str(e)}")
    if not settings.REAL_LLM_API_KEY or not settings.REAL_LLM_API_URL:
        logger.error("REAL_LLM_API_KEY or REAL_LLM_API_URL is not configured.")
        raise HTTPException(status_code=500, detail="LLM API Key or URL is not configured.")
    messages_to_llm = request_obj.messages
    if request_obj.tools:
        messages_to_llm = inject_tools_into_prompt(request_obj.messages, request_obj.tools)

    def _sse_chunk(delta: dict, finish_reason=None) -> bytes:
        """Serialize one OpenAI-style `chat.completion.chunk` as an SSE event."""
        payload = {
            "id": "chatcmpl-" + str(log_id),
            "object": "chat.completion.chunk",
            "created": 0,
            "model": "gpt-3.5-turbo",
            "choices": [{
                "index": 0,
                "delta": delta,
                "finish_reason": finish_reason,
            }],
        }
        return f"data: {json.dumps(payload)}\n\n".encode('utf-8')

    # Handle streaming request.
    if request_obj.stream:
        logger.info(f"Initiating streaming request for log ID: {log_id}")

        async def stream_and_log():
            # NOTE(review): every upstream chunk is buffered before anything
            # is forwarded, so the client sees no output until the upstream
            # stream finishes. This is required to detect tool calls in the
            # assembled content, but it defeats incremental delivery for
            # plain-text responses — confirm this trade-off is intended.
            stream_content_buffer = []
            raw_chunks = []
            # First, collect all chunks so we can detect tool calls.
            async for chunk in stream_llm_api(messages_to_llm, settings, log_id):
                raw_chunks.append(chunk)
                # Extract text content deltas from SSE chunks.
                parsed = _parse_sse_data(chunk)
                if parsed and parsed.get("type") != "done":
                    choices = parsed.get("choices")
                    if choices and len(choices) > 0:
                        delta = choices[0].get("delta")
                        if delta and "content" in delta:
                            stream_content_buffer.append(delta["content"])
            # Parse the complete assembled content.
            full_content = "".join(stream_content_buffer)
            response_message = parse_llm_response_from_content(full_content)
            if response_message.tool_calls:
                # Tool calls detected: re-emit them in OpenAI wire format
                # instead of forwarding the raw upstream chunks.
                logger.info(f"Tool calls detected in stream, sending OpenAI format for log ID {log_id}")
                for tc in response_message.tool_calls:
                    # Announce the tool call (id / type / name) with empty arguments.
                    yield _sse_chunk({
                        "tool_calls": [{
                            "index": 0,
                            "id": tc.id,
                            "type": tc.type,
                            "function": {
                                "name": tc.function.name,
                                "arguments": ""
                            }
                        }]
                    })
                    # Split arguments into small pieces to simulate streaming.
                    args = tc.function.arguments
                    chunk_size = 20
                    for i in range(0, len(args), chunk_size):
                        yield _sse_chunk({
                            "tool_calls": [{
                                "index": 0,
                                "function": {
                                    "arguments": args[i:i + chunk_size]
                                }
                            }]
                        })
                # Terminal chunk: finish_reason tells the client to run the tools.
                yield _sse_chunk({}, finish_reason="tool_calls")
            else:
                # No tool calls: forward the buffered upstream chunks verbatim.
                for chunk in raw_chunks:
                    yield chunk
            # Record the final response against the original request log entry.
            proxy_response = ProxyResponse(message=response_message)
            logger.info(f"Streaming client response for log ID {log_id}:\n{proxy_response.model_dump_json(indent=2)}")
            update_request_log(log_id, client_response=proxy_response.model_dump())

        return StreamingResponse(stream_and_log(), media_type="text/event-stream")
    # Handle non-streaming request.
    try:
        logger.info(f"Initiating non-streaming request for log ID: {log_id}")
        response_message = await process_chat_request(messages_to_llm, settings, log_id)
        proxy_response = ProxyResponse(message=response_message)
        logger.info(f"Response body for log ID {log_id}:\n{proxy_response.model_dump_json(indent=2)}")
        # Log client response to DB.
        update_request_log(log_id, client_response=proxy_response.model_dump())
        return proxy_response
    except Exception as e:
        logger.exception(f"An unexpected error occurred during non-streaming request for log ID: {log_id}")
        # Record the failure against the same log entry before surfacing it.
        update_request_log(log_id, client_response={"error": str(e)})
        raise HTTPException(status_code=500, detail=f"An unexpected error occurred: {str(e)}")
@app.get("/")
def read_root():
    """Health-check style root endpoint confirming the proxy is up."""
    status_message = "LLM Tool Proxy is running."
    return {"message": status_message}