225 lines
9.1 KiB
Python
225 lines
9.1 KiB
Python
import os
|
|
import sys
|
|
import logging
|
|
import time
|
|
import json
|
|
from dotenv import load_dotenv
|
|
from fastapi import FastAPI, HTTPException, Depends, Request
|
|
from starlette.responses import StreamingResponse
|
|
|
|
from .models import IncomingRequest, ProxyResponse
|
|
from .services import process_chat_request, stream_llm_api, inject_tools_into_prompt, parse_llm_response_from_content, _parse_sse_data, convert_tool_calls_to_content
|
|
from .core.config import get_settings, Settings
|
|
from .database import init_db, log_request, update_request_log
|
|
|
|
# --- Environment & Debug Loading ---
|
|
# load_dotenv() # Uncomment if you run uvicorn directly and need to load .env
|
|
# ---
|
|
|
|
# --- Logging Configuration ---
# Mirror every log record to a persistent file and to the console.
_LOG_HANDLERS = [
    logging.FileHandler("llm_proxy.log"),
    logging.StreamHandler(),
]
logging.basicConfig(
    level=logging.INFO,
    format='%(asctime)s - %(name)s - %(levelname)s - %(message)s',
    handlers=_LOG_HANDLERS,
)
logger = logging.getLogger(__name__)
# --- End of Logging Configuration ---


app = FastAPI(
    title="LLM Tool Proxy",
    description="A proxy that intercepts LLM requests to inject and handle tool calls.",
    version="1.0.0",
)
|
|
|
|
# --- Middleware for logging basic request/response info ---
@app.middleware("http")
async def logging_middleware(request: Request, call_next):
    """Log method, path, headers (credentials redacted) and latency per request.

    Fixes vs. previous version:
    - request.client is None for some ASGI test clients / unix-socket
      deployments; guard it instead of raising AttributeError.
    - Authorization / API-key headers were previously written verbatim to
      llm_proxy.log; redact them so credentials never land on disk.
    """
    start_time = time.time()
    client_host = request.client.host if request.client else "unknown"
    logger.info(f"Request received: {request.method} {request.url.path} from {client_host}")
    # Redact anything that looks like a credential before logging.
    sensitive = {"authorization", "api-key", "x-api-key", "cookie"}
    safe_headers = {
        k: ("<redacted>" if k.lower() in sensitive else v)
        for k, v in request.headers.items()
    }
    logger.info(f"Request Headers: {safe_headers}")

    response = await call_next(request)

    process_time = (time.time() - start_time) * 1000
    logger.info(f"Response sent: status_code={response.status_code} ({process_time:.2f}ms)")
    return response
# --- End of Middleware ---
|
|
|
|
|
|
@app.on_event("startup")
async def startup_event():
    """Initialize persistent state (database, settings) at server boot.

    The "startup complete" message is emitted last — the previous version
    logged it before init_db() had actually run.
    """
    init_db()
    logger.info("Database initialized.")
    current_settings = get_settings()
    logger.info(f"Loaded LLM API URL: {current_settings.REAL_LLM_API_URL}")
    logger.info("Application startup complete.")
|
|
|
|
def _tool_call_stream_chunks(log_id, tool_calls):
    """Yield OpenAI-format SSE byte chunks that replay *tool_calls* as a stream.

    Emits, per tool call: one chunk announcing id/type/name with empty
    arguments, then the argument string in small slices (simulating provider
    streaming), and finally a terminal chunk with finish_reason="tool_calls".
    """
    def _chunk(delta, finish_reason=None):
        # Minimal chat.completion.chunk envelope shared by every event.
        payload = {
            "id": "chatcmpl-" + str(log_id),
            "object": "chat.completion.chunk",
            "created": 0,
            "model": "gpt-3.5-turbo",
            "choices": [{
                "index": 0,
                "delta": delta,
                "finish_reason": finish_reason,
            }],
        }
        return f"data: {json.dumps(payload)}\n\n".encode('utf-8')

    for tc in tool_calls:
        # Announce the tool call (id/type/name) with empty arguments first.
        yield _chunk({
            "tool_calls": [{
                "index": 0,
                "id": tc.id,
                "type": tc.type,
                "function": {
                    "name": tc.function.name,
                    "arguments": "",
                },
            }]
        })
        # Stream the argument string in small slices to mimic real streaming.
        args = tc.function.arguments
        chunk_size = 20
        for i in range(0, len(args), chunk_size):
            yield _chunk({
                "tool_calls": [{
                    "index": 0,
                    "function": {"arguments": args[i:i + chunk_size]},
                }]
            })
    # Terminal chunk: empty delta, finish_reason signals tool_calls.
    yield _chunk({}, finish_reason="tool_calls")


@app.post("/v1/chat/completions")
async def chat_completions(
    request: Request,
    settings: Settings = Depends(get_settings)
):
    """
    Mimic the OpenAI Chat Completions API, supporting both streaming and
    non-streaming responses, with request/response logging to the database.

    Raises:
        HTTPException(400): malformed JSON body or schema validation failure.
        HTTPException(500): missing LLM configuration or unexpected errors.
    """
    # Read raw request body.
    raw_body = await request.body()
    body_str = raw_body.decode('utf-8')

    # BUGFIX: malformed JSON previously escaped as an unhandled 500;
    # reject it explicitly with a 400.
    try:
        client_request = json.loads(body_str)
    except json.JSONDecodeError as e:
        logger.error(f"Failed to decode request body as JSON: {e}")
        raise HTTPException(status_code=400, detail=f"Invalid JSON body: {str(e)}")

    # Log the raw client request before any transformation.
    log_id = log_request(client_request=client_request)
    logger.info(f"Request body logged with ID: {log_id}")

    # Parse into IncomingRequest model for validation and type safety.
    try:
        request_obj = IncomingRequest(**client_request)
    except Exception as e:
        logger.error(f"Failed to parse request: {e}")
        raise HTTPException(status_code=400, detail=f"Invalid request format: {str(e)}")

    if not settings.REAL_LLM_API_KEY or not settings.REAL_LLM_API_URL:
        logger.error("REAL_LLM_API_KEY or REAL_LLM_API_URL is not configured.")
        raise HTTPException(status_code=500, detail="LLM API Key or URL is not configured.")

    # Convert assistant messages carrying tool_calls into plain-content form.
    messages_to_llm = convert_tool_calls_to_content(request_obj.messages)
    logger.info(f"Converted tool calls to content format for log ID: {log_id}")

    if request_obj.tools:
        messages_to_llm = inject_tools_into_prompt(messages_to_llm, request_obj.tools)

    # --- Streaming path ---
    if request_obj.stream:
        logger.info(f"Initiating streaming request for log ID: {log_id}")

        async def stream_and_log():
            stream_content_buffer = []
            raw_chunks = []

            # Buffer the entire upstream stream first so tool calls can be
            # detected before anything is forwarded to the client.
            async for chunk in stream_llm_api(messages_to_llm, settings, log_id):
                logger.info(f"sse_result: {chunk}")
                raw_chunks.append(chunk)
                parsed = _parse_sse_data(chunk)
                logger.info(f"sse_result_data: {parsed}")
                # BUGFIX: the old condition called .get("finish_reason") on
                # the "choices" *list* (AttributeError) and crashed on every
                # "done" marker. Skip unparseable and "done" chunks instead.
                if not parsed or parsed.get("type") == "done":
                    continue
                choices = parsed.get("choices") or []
                if choices:
                    delta = choices[0].get("delta")
                    if delta and "content" in delta:
                        stream_content_buffer.append(delta["content"])

            # Reassemble the full text and check it for embedded tool calls.
            full_content = "".join(stream_content_buffer)
            response_message = parse_llm_response_from_content(full_content)

            if response_message.tool_calls:
                # Replace the raw stream with OpenAI-format tool_call chunks.
                logger.info(f"Tool calls detected in stream, sending OpenAI format for log ID {log_id}")
                for out_chunk in _tool_call_stream_chunks(log_id, response_message.tool_calls):
                    yield out_chunk
            else:
                # No tool calls detected: forward upstream chunks untouched.
                for out_chunk in raw_chunks:
                    yield out_chunk

            # Persist the final response for this request's log entry.
            proxy_response = ProxyResponse(message=response_message)
            logger.info(f"Streaming client response for log ID {log_id}:\n{proxy_response.model_dump_json(indent=2)}")
            update_request_log(log_id, client_response=proxy_response.model_dump())

        return StreamingResponse(stream_and_log(), media_type="text/event-stream")

    # --- Non-streaming path ---
    try:
        logger.info(f"Initiating non-streaming request for log ID: {log_id}")
        response_message = await process_chat_request(messages_to_llm, settings, log_id)

        proxy_response = ProxyResponse(message=response_message)
        logger.info(f"Response body for log ID {log_id}:\n{proxy_response.model_dump_json(indent=2)}")

        # Log client response to DB.
        update_request_log(log_id, client_response=proxy_response.model_dump())

        return proxy_response
    except HTTPException:
        # BUGFIX: don't re-wrap deliberate HTTP errors as generic 500s.
        raise
    except Exception as e:
        logger.exception(f"An unexpected error occurred during non-streaming request for log ID: {log_id}")
        # Log the error to the database.
        update_request_log(log_id, client_response={"error": str(e)})
        raise HTTPException(status_code=500, detail=f"An unexpected error occurred: {str(e)}")
|
|
|
|
@app.get("/")
def read_root():
    """Health-check endpoint: confirms the proxy process is up."""
    status_payload = {"message": "LLM Tool Proxy is running."}
    return status_payload
|