feat: Initial commit of LLM Tool Proxy
This commit is contained in:
0
app/__init__.py
Normal file
0
app/__init__.py
Normal file
0
app/core/__init__.py
Normal file
0
app/core/__init__.py
Normal file
17
app/core/config.py
Normal file
17
app/core/config.py
Normal file
@@ -0,0 +1,17 @@
|
||||
import os
|
||||
from pydantic import BaseModel
|
||||
from typing import Optional
|
||||
|
||||
class Settings(BaseModel):
    """Manages application settings and configurations."""
    # Base URL of the real LLM backend the proxy forwards requests to.
    REAL_LLM_API_URL: Optional[str] = None
    # Bearer token used to authenticate against the real LLM backend.
    REAL_LLM_API_KEY: Optional[str] = None
|
||||
|
||||
def get_settings() -> Settings:
    """
    Build a Settings instance from the current process environment.

    Variables that are absent fall back to the model defaults (None); the
    caller decides how to handle an unconfigured proxy.
    """
    env = os.getenv
    return Settings(
        REAL_LLM_API_URL=env("REAL_LLM_API_URL"),
        REAL_LLM_API_KEY=env("REAL_LLM_API_KEY"),
    )
|
||||
79
app/main.py
Normal file
79
app/main.py
Normal file
@@ -0,0 +1,79 @@
|
||||
import os
|
||||
import sys
|
||||
from dotenv import load_dotenv
|
||||
|
||||
# --- Explicit Debugging & Env Loading ---
# Print the CWD before loading .env: load_dotenv() searches relative to it,
# so a wrong working directory is the usual cause of "missing" settings.
# stderr is used because logging is not configured yet at this point.
print(f"--- [DEBUG] Current Working Directory: {os.getcwd()}", file=sys.stderr)
# load_dotenv() returns True only when a .env file was found and read.
load_result = load_dotenv()
print(f"--- [DEBUG] load_dotenv() result: {load_result}", file=sys.stderr)
# ---
|
||||
|
||||
import logging
|
||||
from fastapi import FastAPI, HTTPException, Depends
|
||||
from starlette.responses import StreamingResponse
|
||||
from .models import IncomingRequest, ProxyResponse
|
||||
from .services import process_chat_request, stream_llm_api, inject_tools_into_prompt
|
||||
from .core.config import get_settings, Settings
|
||||
|
||||
# --- Logging Configuration ---
# Root-logger setup: INFO level, timestamped records, mirrored to both a
# log file and the console.
logging.basicConfig(
    level=logging.INFO,
    format='%(asctime)s - %(name)s - %(levelname)s - %(message)s',
    handlers=[
        # Persist logs to a file in the process working directory.
        logging.FileHandler("llm_proxy.log"),
        # Echo to stderr for interactive runs.
        logging.StreamHandler()
    ]
)
# Module-level logger, named after this module for filterable output.
logger = logging.getLogger(__name__)
# --- End of Logging Configuration ---
|
||||
|
||||
# FastAPI application instance; the metadata feeds the generated OpenAPI docs.
app = FastAPI(
    title="LLM Tool Proxy",
    description="A proxy that intercepts LLM requests to inject and handle tool calls.",
    version="1.0.0",
)
|
||||
|
||||
@app.on_event("startup")
|
||||
async def startup_event():
|
||||
logger.info("Application startup complete.")
|
||||
current_settings = get_settings()
|
||||
logger.info(f"Loaded LLM API URL: {current_settings.REAL_LLM_API_URL}")
|
||||
|
||||
@app.post("/v1/chat/completions")
|
||||
async def chat_completions(
|
||||
request: IncomingRequest,
|
||||
settings: Settings = Depends(get_settings)
|
||||
):
|
||||
"""
|
||||
This endpoint mimics the OpenAI Chat Completions API and supports both
|
||||
streaming (`stream=True`) and non-streaming (`stream=False`) responses.
|
||||
"""
|
||||
if not settings.REAL_LLM_API_KEY or not settings.REAL_LLM_API_URL:
|
||||
logger.error("REAL_LLM_API_KEY or REAL_LLM_API_URL is not configured.")
|
||||
raise HTTPException(status_code=500, detail="LLM API Key or URL is not configured.")
|
||||
|
||||
# Prepare messages, potentially with tool injection
|
||||
# This prepares the messages that will be sent to the LLM backend
|
||||
messages_to_llm = request.messages
|
||||
if request.tools:
|
||||
messages_to_llm = inject_tools_into_prompt(request.messages, request.tools)
|
||||
|
||||
# Handle streaming request
|
||||
if request.stream:
|
||||
logger.info(f"Initiating streaming request with {len(messages_to_llm)} messages.")
|
||||
generator = stream_llm_api(messages_to_llm, settings)
|
||||
return StreamingResponse(generator, media_type="text/event-stream")
|
||||
|
||||
# Handle non-streaming request
|
||||
try:
|
||||
logger.info(f"Initiating non-streaming request with {len(messages_to_llm)} messages.")
|
||||
response_message = await process_chat_request(messages_to_llm, request.tools, settings)
|
||||
logger.info("Successfully processed non-streaming request.")
|
||||
return ProxyResponse(message=response_message)
|
||||
except Exception as e:
|
||||
logger.exception("An unexpected error occurred during non-streaming request.")
|
||||
raise HTTPException(status_code=500, detail=f"An unexpected error occurred: {str(e)}")
|
||||
|
||||
@app.get("/")
|
||||
def read_root():
|
||||
return {"message": "LLM Tool Proxy is running."}
|
||||
41
app/models.py
Normal file
41
app/models.py
Normal file
@@ -0,0 +1,41 @@
|
||||
from pydantic import BaseModel, Field
|
||||
from typing import List, Dict, Any, Optional
|
||||
|
||||
# Models for incoming requests
class ChatMessage(BaseModel):
    """Represents a single message in the chat history."""
    # Speaker role, e.g. "system", "user", or "assistant".
    role: str
    # The message text.
    content: str
|
||||
|
||||
class Tool(BaseModel):
    """Represents a tool definition provided by the user."""
    # Tool category; OpenAI-style requests use "function".
    type: str
    # Free-form function spec (name/description/parameters) — not validated here.
    function: Dict[str, Any]
|
||||
|
||||
class IncomingRequest(BaseModel):
    """Defines the structure of the request from the client."""
    # Conversation history to forward to the LLM backend.
    messages: List[ChatMessage]
    # Optional tool definitions; when present they are injected as a system prompt.
    tools: Optional[List[Tool]] = None
    # When True, the proxy streams SSE back to the client.
    stream: Optional[bool] = False
|
||||
|
||||
# Models for outgoing responses
class ToolCallFunction(BaseModel):
    """Function call details within a tool call."""
    # Name of the function the LLM wants to invoke.
    name: str
    # JSON string of arguments (kept serialized, mirroring the OpenAI API).
    arguments: str
|
||||
|
||||
class ToolCall(BaseModel):
    """Represents a tool call requested by the LLM."""
    # Identifier for correlating tool results (this proxy uses "call_<name>").
    id: str
    # Only "function" tool calls are produced by this proxy.
    type: str = "function"
    # The function name and serialized arguments.
    function: ToolCallFunction
|
||||
|
||||
class ResponseMessage(BaseModel):
    """The message part of the response from the proxy."""
    # Responses are always attributed to the assistant.
    role: str = "assistant"
    # Text content; None when the response is purely a tool call.
    content: Optional[str] = None
    # Tool calls extracted from the LLM output, if any.
    tool_calls: Optional[List[ToolCall]] = None
|
||||
|
||||
class ProxyResponse(BaseModel):
    """Defines the final structured response sent back to the client."""
    # The assistant message produced for this request.
    message: ResponseMessage
|
||||
198
app/services.py
Normal file
198
app/services.py
Normal file
@@ -0,0 +1,198 @@
|
||||
import json
|
||||
import re
|
||||
import httpx
|
||||
import logging
|
||||
from typing import List, Dict, Any, Tuple, Optional, AsyncGenerator
|
||||
|
||||
from .models import ChatMessage, Tool, ResponseMessage, ToolCall, ToolCallFunction
|
||||
from .core.config import Settings
|
||||
|
||||
# Module-level logger, named after this module for filterable log output.
logger = logging.getLogger(__name__)
|
||||
|
||||
# --- Helper for parsing SSE ---
|
||||
# Regex to extract data field from SSE
|
||||
SSE_DATA_RE = re.compile(r"data:\s*(.*)")
|
||||
|
||||
def _parse_sse_data(chunk: bytes) -> Optional[Dict[str, Any]]:
|
||||
"""Parses a chunk of bytes as SSE and extracts the JSON data."""
|
||||
try:
|
||||
lines = chunk.decode("utf-8").splitlines()
|
||||
for line in lines:
|
||||
if line.startswith("data:"):
|
||||
match = SSE_DATA_RE.match(line)
|
||||
if match:
|
||||
data_str = match.group(1).strip()
|
||||
if data_str == "[DONE]": # Handle OpenAI-style stream termination
|
||||
return {"type": "done"}
|
||||
try:
|
||||
return json.loads(data_str)
|
||||
except json.JSONDecodeError:
|
||||
logger.warning(f"Failed to decode JSON from SSE data: {data_str}")
|
||||
return None
|
||||
except UnicodeDecodeError:
|
||||
logger.warning("Failed to decode chunk as UTF-8.")
|
||||
return None
|
||||
|
||||
# --- End Helper ---
|
||||
|
||||
|
||||
def inject_tools_into_prompt(messages: List[ChatMessage], tools: List[Tool]) -> List[ChatMessage]:
    """
    Injects tool definitions into the message list as a system prompt.

    The tool prompt is placed directly after a leading system message when one
    exists, otherwise at the front of the list. (The previous hard-coded
    insert at index 1 assumed messages[0] was a system prompt: it misplaced
    the tool prompt after the first user message and appended to the end of
    an empty list.)

    Returns a new list; *messages* is not mutated.
    """
    tool_defs = json.dumps([tool.model_dump() for tool in tools], indent=2)
    tool_prompt = f"""
You have access to a set of tools. You can call them by emitting a JSON object inside a <tool_call> XML tag.
The JSON object should have a "name" and "arguments" field.

Here are the available tools:
{tool_defs}

Only use the tools if strictly necessary.
"""
    new_messages = messages.copy()
    # Keep an existing leading system prompt first; otherwise lead with ours.
    insert_at = 1 if new_messages and new_messages[0].role == "system" else 0
    new_messages.insert(insert_at, ChatMessage(role="system", content=tool_prompt))
    return new_messages
|
||||
|
||||
|
||||
def parse_llm_response_from_content(text: str) -> ResponseMessage:
    """
    (Fallback) Parses the raw LLM text response to extract a message and any tool calls.

    This is used when the LLM does not support native tool calling. Malformed
    tool-call JSON (unparseable, not an object, or missing "name") degrades
    gracefully to a plain-content response instead of raising. (Previously a
    missing "name" produced ToolCallFunction(name=None), which failed Pydantic
    validation and crashed the request.)
    """
    if not text:
        return ResponseMessage(content=None)

    tool_call_match = re.search(r"<tool_call>(.*?)</tool_call>", text, re.DOTALL)
    if not tool_call_match:
        return ResponseMessage(content=text)

    tool_call_str = tool_call_match.group(1).strip()
    try:
        tool_call_data = json.loads(tool_call_str)
    except json.JSONDecodeError as e:
        logger.warning(f"Failed to parse tool call JSON from content: {tool_call_str}. Error: {e}")
        return ResponseMessage(content=text)

    # A tool call must be a JSON object with a usable name; otherwise surface
    # the raw text rather than crashing on validation.
    name = tool_call_data.get("name") if isinstance(tool_call_data, dict) else None
    if not name:
        logger.warning(f"Tool call JSON missing a usable 'name': {tool_call_str}")
        return ResponseMessage(content=text)

    tool_call = ToolCall(
        id="call_" + name,
        function=ToolCallFunction(
            name=name,
            arguments=json.dumps(tool_call_data.get("arguments", {})),
        )
    )
    # Any text preceding the tool call is preserved as the content.
    content_before = text.split("<tool_call>")[0].strip()
    return ResponseMessage(content=content_before if content_before else None, tool_calls=[tool_call])
|
||||
|
||||
|
||||
async def _raw_stream_from_llm(messages: List[ChatMessage], settings: Settings) -> AsyncGenerator[bytes, None]:
    """
    Makes the raw HTTP streaming call to the LLM backend.

    Yields raw byte chunks as received. On failure, yields one well-formed
    SSE "data:" error event and terminates the stream.
    """
    headers = { "Authorization": f"Bearer {settings.REAL_LLM_API_KEY}", "Content-Type": "application/json" }
    # NOTE(review): the model name is hard-coded; consider making it configurable.
    payload = { "model": "default-model", "messages": [msg.model_dump() for msg in messages], "stream": True }

    try:
        async with httpx.AsyncClient() as client:
            logger.info(f"Initiating raw stream to LLM API at {settings.REAL_LLM_API_URL}")
            async with client.stream("POST", settings.REAL_LLM_API_URL, headers=headers, json=payload, timeout=60.0) as response:
                response.raise_for_status()
                async for chunk in response.aiter_bytes():
                    yield chunk
    except httpx.HTTPStatusError as e:
        logger.error(f"LLM API returned an error during raw stream: {e.response.status_code}, response: '{e.response.text}'")
        # For streams, we log and emit a final SSE error event so the client
        # receives a parseable error instead of a silently broken stream.
        error_event = json.dumps({"error": "LLM API Error", "status_code": e.response.status_code})
        yield f"data: {error_event}\n\n".encode("utf-8")
    except httpx.RequestError as e:
        logger.error(f"An error occurred during raw stream request to LLM API: {e}")
        # json.dumps escapes quotes/backslashes inside str(e); the previous
        # byte concatenation produced invalid JSON whenever the exception
        # message contained a double quote.
        error_event = json.dumps({"error": "Network Error", "details": str(e)})
        yield f"data: {error_event}\n\n".encode("utf-8")
|
||||
|
||||
|
||||
async def stream_llm_api(messages: List[ChatMessage], settings: Settings) -> AsyncGenerator[bytes, None]:
    """
    Public streaming interface: forward raw byte chunks from the LLM backend.

    Chunks are assumed to arrive already SSE-formatted from the upstream API
    and are passed through untouched; a stricter implementation would
    re-chunk them into complete SSE event lines before yielding.
    """
    async for raw in _raw_stream_from_llm(messages, settings):
        yield raw
|
||||
|
||||
|
||||
async def process_llm_stream_for_non_stream_request(
    messages: List[ChatMessage],
    settings: Settings
) -> Dict[str, Any]:
    """
    Aggregates a streaming LLM response into a single, non-streaming message.

    Handles SSE parsing and delta accumulation (OpenAI-like stream format).
    Native "tool_calls" deltas are NOT reconstructed here — that aggregation
    is non-trivial and format-dependent; tool calls embedded in the final
    content string are handled downstream by parse_llm_response_from_content.

    Returns a message dict of the form {"role": "assistant", "content": str | None}.
    """
    full_content_parts: List[str] = []
    final_message_dict: Dict[str, Any] = {"role": "assistant", "content": None}

    async for chunk in _raw_stream_from_llm(messages, settings):
        parsed_data = _parse_sse_data(chunk)
        if not parsed_data:
            continue
        if parsed_data.get("type") == "done":
            break  # End of stream

        choices = parsed_data.get("choices")
        if not choices:
            continue
        delta = choices[0].get("delta")
        if delta:
            piece = delta.get("content")
            # Guard against null content: some backends send "content": null
            # in role/tool-call delta chunks, which previously ended up in
            # the parts list and made "".join() raise TypeError.
            if piece:
                full_content_parts.append(piece)
            # "tool_calls" deltas are intentionally ignored (see docstring).

    final_message_dict["content"] = "".join(full_content_parts) if full_content_parts else None
    logger.info(f"Aggregated non-streaming response content: {final_message_dict.get('content')}")

    return final_message_dict
|
||||
|
||||
|
||||
async def process_chat_request(
    messages: List[ChatMessage],
    tools: Optional[List[Tool]],
    settings: Settings,
) -> ResponseMessage:
    """
    Main service entry point for non-streaming requests.

    Injects tool definitions when tools are supplied, aggregates the
    backend's streamed reply into one message, then extracts tool calls —
    native ones when the aggregation produced them, otherwise by parsing
    the content string.
    """
    request_messages = inject_tools_into_prompt(messages, tools) if tools else messages

    # All interactions with the real LLM go through the streaming mechanism.
    llm_message_dict = await process_llm_stream_for_non_stream_request(request_messages, settings)

    # Priority 1: native tool calls, if the aggregation reconstructed any.
    # (Rebuilding tool_calls from stream deltas is complex; when present
    # they are assumed complete.)
    native_calls = llm_message_dict.get("tool_calls")
    if native_calls:
        logger.info("Native tool calls detected in aggregated LLM response.")
        # They must be a list of dicts for Pydantic validation to succeed.
        if isinstance(native_calls, list):
            return ResponseMessage.model_validate(llm_message_dict)
        logger.warning("Aggregated tool_calls not in expected list format. Treating as content.")

    # Priority 2 (fallback): parse tool calls out of the content string.
    logger.info("No native tool calls from aggregation. Falling back to content parsing.")
    return parse_llm_response_from_content(llm_message_dict.get("content"))
|
||||
Reference in New Issue
Block a user