"""FastAPI entry point for the LLM Tool Proxy.

Exposes an OpenAI-style ``/v1/chat/completions`` endpoint. Incoming messages
are rewritten (assistant tool calls converted to content, tool definitions
injected into the prompt), forwarded to the configured upstream LLM API, and
the reply is parsed for tool calls before being returned to the client.
Each request and its response are recorded through the ``.database`` helpers.
"""

import os
import sys
import logging
import time
import json

from dotenv import load_dotenv
from fastapi import FastAPI, HTTPException, Depends, Request
from starlette.responses import StreamingResponse

from .models import IncomingRequest, ProxyResponse
from .services import (
    process_chat_request,
    stream_llm_api,
    inject_tools_into_prompt,
    parse_llm_response_from_content,
    _parse_sse_data,
    convert_tool_calls_to_content,
)
from .core.config import get_settings, Settings
from .database import init_db, log_request, update_request_log

# --- Environment & Debug Loading ---
# load_dotenv()  # Uncomment if you run uvicorn directly and need to load .env
# ---

# --- Logging Configuration ---
logging.basicConfig(
    level=logging.INFO,
    format='%(asctime)s - %(name)s - %(levelname)s - %(message)s',
    handlers=[
        logging.FileHandler("llm_proxy.log"),
        logging.StreamHandler()
    ]
)
logger = logging.getLogger(__name__)
# --- End of Logging Configuration ---

# Header names whose values carry credentials and must never reach the log.
_SENSITIVE_HEADERS = frozenset({
    "authorization",
    "proxy-authorization",
    "cookie",
    "set-cookie",
    "x-api-key",
    "api-key",
})
_REDACTED = "***REDACTED***"

# Size (in characters) of each synthesized tool-call "arguments" fragment.
_TOOL_ARGS_CHUNK_SIZE = 20
# Model name reported in synthesized chunks when the client did not send one.
_DEFAULT_MODEL_NAME = "gpt-3.5-turbo"
# Terminator event that closes an OpenAI-style SSE stream.
_SSE_DONE = b"data: [DONE]\n\n"

app = FastAPI(
    title="LLM Tool Proxy",
    description="A proxy that intercepts LLM requests to inject and handle tool calls.",
    version="1.0.0",
)


def _redact_headers(headers) -> dict:
    """Return *headers* as a dict with credential-bearing values masked."""
    return {
        name: _REDACTED if name.lower() in _SENSITIVE_HEADERS else value
        for name, value in headers.items()
    }


def _sse_event(payload: dict) -> bytes:
    """Serialize *payload* as a single UTF-8 encoded SSE ``data:`` event."""
    return f"data: {json.dumps(payload)}\n\n".encode('utf-8')


def _completion_chunk(chunk_id, model, created, delta, finish_reason=None) -> dict:
    """Build one ``chat.completion.chunk`` payload carrying *delta*."""
    return {
        "id": chunk_id,
        "object": "chat.completion.chunk",
        "created": created,
        "model": model,
        "choices": [{
            "index": 0,
            "delta": delta,
            "finish_reason": finish_reason,
        }],
    }


def _delta_content(parsed):
    """Return the text content of a parsed SSE chunk, or ``None`` if absent.

    A delta whose ``content`` is present but ``null`` (or otherwise not a
    string) yields ``None`` so that it cannot break the later ``str.join``.
    """
    if not parsed or parsed.get("type") == "done":
        return None
    choices = parsed.get("choices")
    if not choices:
        return None
    delta = choices[0].get("delta")
    if not delta:
        return None
    content = delta.get("content")
    return content if isinstance(content, str) else None


def _tool_call_events(tool_calls, chunk_id, model, created):
    """Yield the SSE events that stream *tool_calls* in OpenAI chunk format.

    For every tool call an opening chunk (id, type, function name, empty
    arguments) is emitted, followed by the arguments split into fragments of
    ``_TOOL_ARGS_CHUNK_SIZE`` characters. Each tool call gets its own
    ``index`` so a client accumulating fragments by index keeps them apart.
    The stream is closed with a ``finish_reason="tool_calls"`` chunk and the
    ``[DONE]`` terminator.
    """
    for position, call in enumerate(tool_calls):
        opening = {
            "tool_calls": [{
                "index": position,
                "id": call.id,
                "type": call.type,
                "function": {
                    "name": call.function.name,
                    "arguments": "",
                },
            }]
        }
        yield _sse_event(_completion_chunk(chunk_id, model, created, opening))

        # Split arguments into smaller chunks to simulate streaming.
        arguments = call.function.arguments or ""
        for start in range(0, len(arguments), _TOOL_ARGS_CHUNK_SIZE):
            fragment = {
                "tool_calls": [{
                    "index": position,
                    "function": {
                        "arguments": arguments[start:start + _TOOL_ARGS_CHUNK_SIZE],
                    },
                }]
            }
            yield _sse_event(_completion_chunk(chunk_id, model, created, fragment))

    yield _sse_event(
        _completion_chunk(chunk_id, model, created, {}, finish_reason="tool_calls")
    )
    yield _SSE_DONE


# --- Middleware for logging basic request/response info ---
@app.middleware("http")
async def logging_middleware(request: Request, call_next):
    """Log method, path, client, headers, status code and latency per request.

    Credential-bearing headers are masked before logging, and a missing
    client address is reported as ``unknown`` instead of raising.
    """
    start_time = time.perf_counter()
    # request.client is None when the server provides no peer address.
    client_host = request.client.host if request.client else "unknown"
    logger.info(f"Request received: {request.method} {request.url.path} from {client_host}")
    logger.info(f"Request Headers: {_redact_headers(request.headers)}")

    response = await call_next(request)

    process_time = (time.perf_counter() - start_time) * 1000
    logger.info(f"Response sent: status_code={response.status_code} ({process_time:.2f}ms)")
    return response
# --- End of Middleware ---


@app.on_event("startup")
async def startup_event():
    """Initialize the database and log the configured upstream URL."""
    logger.info("Application startup complete.")
    init_db()
    logger.info("Database initialized.")
    current_settings = get_settings()
    logger.info(f"Loaded LLM API URL: {current_settings.REAL_LLM_API_URL}")


@app.post("/v1/chat/completions")
async def chat_completions(
    request: Request,
    settings: Settings = Depends(get_settings)
):
    """
    This endpoint mimics the OpenAI Chat Completions API and supports both
    streaming and non-streaming responses, with detailed logging.

    Raises:
        HTTPException: 400 when the body is not valid UTF-8 JSON or does not
            validate as an ``IncomingRequest``; 500 when the upstream API key
            or URL is not configured, or when the non-streaming call fails.
    """
    # Read and decode the raw request body; a malformed body is a client error.
    raw_body = await request.body()
    try:
        client_request = json.loads(raw_body.decode('utf-8'))
    except (UnicodeDecodeError, json.JSONDecodeError) as e:
        logger.error(f"Failed to parse request: {e}")
        raise HTTPException(status_code=400, detail=f"Invalid request format: {str(e)}") from e

    # Log the raw client request
    log_id = log_request(client_request=client_request)
    logger.info(f"Request body logged with ID: {log_id}")

    # Parse into IncomingRequest model for validation and type safety
    try:
        request_obj = IncomingRequest(**client_request)
    except Exception as e:
        logger.error(f"Failed to parse request: {e}")
        raise HTTPException(status_code=400, detail=f"Invalid request format: {str(e)}") from e

    if not settings.REAL_LLM_API_KEY or not settings.REAL_LLM_API_URL:
        logger.error("REAL_LLM_API_KEY or REAL_LLM_API_URL is not configured.")
        raise HTTPException(status_code=500, detail="LLM API Key or URL is not configured.")

    # Convert assistant messages with tool_calls to content format
    messages_to_llm = convert_tool_calls_to_content(request_obj.messages)
    logger.info(f"Converted tool calls to content format for log ID: {log_id}")

    if request_obj.tools:
        messages_to_llm = inject_tools_into_prompt(messages_to_llm, request_obj.tools)

    # Handle streaming request
    if request_obj.stream:
        logger.info(f"Initiating streaming request for log ID: {log_id}")
        chunk_id = "chatcmpl-" + str(log_id)
        # Validation succeeded, so client_request is a mapping here.
        model_name = client_request.get("model") or _DEFAULT_MODEL_NAME

        async def stream_and_log():
            """Buffer the upstream stream, then replay it or emit tool calls.

            Nothing is sent to the client until the upstream stream has
            ended, because whether the reply contains tool calls is only
            known once the complete content has been parsed.
            """
            content_parts = []
            raw_chunks = []
            try:
                # First, collect all chunks to detect if there are tool calls
                async for chunk in stream_llm_api(messages_to_llm, settings, log_id):
                    raw_chunks.append(chunk)
                    content = _delta_content(_parse_sse_data(chunk))
                    if content:
                        content_parts.append(content)

                # Parse the complete content
                response_message = parse_llm_response_from_content("".join(content_parts))
            except Exception as e:
                # Record the failure like the non-streaming path does, then
                # let it propagate so the response is aborted.
                logger.exception(f"An unexpected error occurred during streaming request for log ID: {log_id}")
                update_request_log(log_id, client_response={"error": str(e)})
                raise

            if response_message.tool_calls:
                # If tool_calls detected, send only OpenAI format tool_calls
                logger.info(f"Tool calls detected in stream, sending OpenAI format for log ID {log_id}")
                for event in _tool_call_events(
                    response_message.tool_calls, chunk_id, model_name, int(time.time())
                ):
                    yield event
            else:
                # No tool calls, yield original chunks
                for chunk in raw_chunks:
                    yield chunk

            # Log the response
            proxy_response = ProxyResponse(message=response_message)
            logger.info(f"Streaming client response for log ID {log_id}:\n{proxy_response.model_dump_json(indent=2)}")
            update_request_log(log_id, client_response=proxy_response.model_dump())

        return StreamingResponse(stream_and_log(), media_type="text/event-stream")

    # Handle non-streaming request
    try:
        logger.info(f"Initiating non-streaming request for log ID: {log_id}")
        response_message = await process_chat_request(messages_to_llm, settings, log_id)
        proxy_response = ProxyResponse(message=response_message)

        logger.info(f"Response body for log ID {log_id}:\n{proxy_response.model_dump_json(indent=2)}")

        # Log client response to DB
        update_request_log(log_id, client_response=proxy_response.model_dump())

        return proxy_response
    except Exception as e:
        logger.exception(f"An unexpected error occurred during non-streaming request for log ID: {log_id}")
        # Log the error to the database
        update_request_log(log_id, client_response={"error": str(e)})
        raise HTTPException(status_code=500, detail=f"An unexpected error occurred: {str(e)}") from e


@app.get("/")
def read_root():
    """Return a static message confirming that the proxy is running."""
    return {"message": "LLM Tool Proxy is running."}