diff --git a/app/core/config.py b/app/core/config.py index c60a37f..226b629 100644 --- a/app/core/config.py +++ b/app/core/config.py @@ -1,6 +1,7 @@ import os from pydantic import BaseModel from typing import Optional +from dotenv import load_dotenv class Settings(BaseModel): """Manages application settings and configurations.""" @@ -11,6 +12,7 @@ def get_settings() -> Settings: """ Returns an instance of the Settings object by loading from environment variables. """ + load_dotenv() # Load environment variables from .env file return Settings( REAL_LLM_API_URL=os.getenv("REAL_LLM_API_URL"), REAL_LLM_API_KEY=os.getenv("REAL_LLM_API_KEY"), diff --git a/app/database.py b/app/database.py new file mode 100644 index 0000000..71efac2 --- /dev/null +++ b/app/database.py @@ -0,0 +1,97 @@ +import sqlite3 +import json +from datetime import datetime +from typing import Dict, Any, Optional +import logging + +logging.basicConfig( + level=logging.DEBUG, # Set to DEBUG to capture the debug logs + format='%(asctime)s - %(name)s - %(levelname)s - %(message)s', + handlers=[ + logging.FileHandler("llm_proxy.log"), + logging.StreamHandler() + ] +) +logger = logging.getLogger(__name__) + +DATABASE_NAME = "llm_proxy.db" + +def init_db(): + """Initializes the database and creates the 'requests' table if it doesn't exist.""" + with sqlite3.connect(DATABASE_NAME) as conn: + cursor = conn.cursor() + cursor.execute(""" + CREATE TABLE IF NOT EXISTS requests ( + id INTEGER PRIMARY KEY AUTOINCREMENT, + timestamp DATETIME DEFAULT CURRENT_TIMESTAMP, + client_request TEXT, + llm_request TEXT, + llm_response TEXT, + client_response TEXT + ) + """) + conn.commit() + +def log_request(client_request: Dict[str, Any]) -> int: + """Logs the initial client request and returns the log ID.""" + with sqlite3.connect(DATABASE_NAME) as conn: + cursor = conn.cursor() + cursor.execute( + "INSERT INTO requests (client_request) VALUES (?)", + (json.dumps(client_request),) + ) + conn.commit() + return cursor.lastrowid + +def update_request_log( + log_id: int, + llm_request: Optional[Dict[str, Any]] = None, + llm_response: Optional[Dict[str, Any]] = None, + client_response: Optional[Dict[str, Any]] = None, +): + """Updates a request log with the LLM request, LLM response, or client response.""" + fields_to_update = [] + values = [] + + if llm_request is not None: + fields_to_update.append("llm_request = ?") + values.append(json.dumps(llm_request)) + if llm_response is not None: + fields_to_update.append("llm_response = ?") + values.append(json.dumps(llm_response)) + if client_response is not None: + fields_to_update.append("client_response = ?") + values.append(json.dumps(client_response)) + + if not fields_to_update: + logger.debug(f"No fields to update for log ID {log_id}. Skipping database update.") + return + + sql = f"UPDATE requests SET {', '.join(fields_to_update)} WHERE id = ?" + values.append(log_id) + + try: + with sqlite3.connect(DATABASE_NAME) as conn: + cursor = conn.cursor() + cursor.execute(sql, tuple(values)) + logger.debug(f"Attempting to commit update for log ID {log_id} with fields: {fields_to_update}") + conn.commit() + logger.debug(f"Successfully committed update for log ID {log_id}.") + except sqlite3.Error as e: + logger.error(f"Database error updating log ID {log_id}: {e}") + except Exception as e: + logger.error(f"An unexpected error occurred while updating log ID {log_id}: {e}") + +def get_latest_log_entry() -> Optional[dict]: + """Helper to get the full latest log entry.""" + try: + with sqlite3.connect(DATABASE_NAME) as conn: + conn.row_factory = sqlite3.Row + cursor = conn.cursor() + cursor.execute("SELECT * FROM requests ORDER BY id DESC LIMIT 1") + row = cursor.fetchone() + if row: + return dict(row) + except sqlite3.Error as e: + print(f"Database error: {e}") + return None diff --git a/app/main.py b/app/main.py index 7a7535f..83bcfdb 100644 --- a/app/main.py +++ b/app/main.py @@ -1,19 +1,19 @@ import os import sys -from dotenv import load_dotenv - -# --- Explicit Debugging & Env Loading --- -print(f"--- [DEBUG] Current Working Directory: {os.getcwd()}", file=sys.stderr) -load_result = load_dotenv() -print(f"--- [DEBUG] load_dotenv() result: {load_result}", file=sys.stderr) -# --- - import logging -from fastapi import FastAPI, HTTPException, Depends +import time +from dotenv import load_dotenv +from fastapi import FastAPI, HTTPException, Depends, Request from starlette.responses import StreamingResponse + from .models import IncomingRequest, ProxyResponse -from .services import process_chat_request, stream_llm_api, inject_tools_into_prompt +from .services import process_chat_request, stream_llm_api, inject_tools_into_prompt, parse_llm_response_from_content from .core.config import get_settings, Settings +from .database import init_db, log_request, update_request_log + +# --- Environment & Debug Loading --- +# load_dotenv() # Uncomment if you run uvicorn directly and need to load .env +# --- # --- Logging Configuration --- logging.basicConfig( @@ -33,9 +33,26 @@ app = FastAPI( version="1.0.0", ) +# --- Middleware for logging basic request/response info --- +@app.middleware("http") +async def logging_middleware(request: Request, call_next): + start_time = time.time() + logger.info(f"Request received: {request.method} {request.url.path} from {request.client.host}") + logger.info(f"Request Headers: {dict(request.headers)}") + + response = await call_next(request) + + process_time = (time.time() - start_time) * 1000 + logger.info(f"Response sent: status_code={response.status_code} ({process_time:.2f}ms)") + return response +# --- End of Middleware --- + + @app.on_event("startup") async def startup_event(): logger.info("Application startup complete.") + init_db() + logger.info("Database initialized.") current_settings = get_settings() logger.info(f"Loaded LLM API URL: {current_settings.REAL_LLM_API_URL}") @@ -46,34 +63,57 @@ async def chat_completions( ): """ This endpoint mimics the OpenAI Chat Completions API and supports both - streaming (`stream=True`) and non-streaming (`stream=False`) responses. + streaming and non-streaming responses, with detailed logging. """ + log_id = log_request(client_request=request.model_dump()) + logger.info(f"Request body logged with ID: {log_id}") + if not settings.REAL_LLM_API_KEY or not settings.REAL_LLM_API_URL: logger.error("REAL_LLM_API_KEY or REAL_LLM_API_URL is not configured.") raise HTTPException(status_code=500, detail="LLM API Key or URL is not configured.") - # Prepare messages, potentially with tool injection - # This prepares the messages that will be sent to the LLM backend messages_to_llm = request.messages if request.tools: messages_to_llm = inject_tools_into_prompt(request.messages, request.tools) # Handle streaming request if request.stream: - logger.info(f"Initiating streaming request with {len(messages_to_llm)} messages.") - generator = stream_llm_api(messages_to_llm, settings) - return StreamingResponse(generator, media_type="text/event-stream") + logger.info(f"Initiating streaming request for log ID: {log_id}") + + async def stream_and_log(): + stream_content_buffer = [] + async for chunk in stream_llm_api(messages_to_llm, settings, log_id): + stream_content_buffer.append(chunk.decode('utf-8')) + yield chunk + + # After the stream is complete, parse the full content and log it + full_content = "".join(stream_content_buffer) + response_message = parse_llm_response_from_content(full_content) + proxy_response = ProxyResponse(message=response_message) + + logger.info(f"Streaming client response for log ID {log_id}:\n{proxy_response.model_dump_json(indent=2)}") + update_request_log(log_id, client_response=proxy_response.model_dump()) + + return StreamingResponse(stream_and_log(), media_type="text/event-stream") # Handle non-streaming request try: - logger.info(f"Initiating non-streaming request with {len(messages_to_llm)} messages.") - response_message = await process_chat_request(messages_to_llm, request.tools, settings) - logger.info("Successfully processed non-streaming request.") - return ProxyResponse(message=response_message) + logger.info(f"Initiating non-streaming request for log ID: {log_id}") + response_message = await process_chat_request(messages_to_llm, settings, log_id) + + proxy_response = ProxyResponse(message=response_message) + logger.info(f"Response body for log ID {log_id}:\n{proxy_response.model_dump_json(indent=2)}") + + # Log client response to DB + update_request_log(log_id, client_response=proxy_response.model_dump()) + + return proxy_response except Exception as e: - logger.exception("An unexpected error occurred during non-streaming request.") + logger.exception(f"An unexpected error occurred during non-streaming request for log ID: {log_id}") + # Log the error to the database + update_request_log(log_id, client_response={"error": str(e)}) raise HTTPException(status_code=500, detail=f"An unexpected error occurred: {str(e)}") @app.get("/") def read_root(): - return {"message": "LLM Tool Proxy is running."} \ No newline at end of file + return {"message": "LLM Tool Proxy is running."} diff --git a/app/models.py b/app/models.py index 4931088..8f705b3 100644 --- a/app/models.py +++ b/app/models.py @@ -7,10 +7,16 @@ class ChatMessage(BaseModel): role: str content: str +class Function(BaseModel): + """Represents the function definition within a tool.""" + name: str + description: str + parameters: Dict[str, Any] + class Tool(BaseModel): """Represents a tool definition provided by the user.""" type: str - function: Dict[str, Any] + function: Function class IncomingRequest(BaseModel): """Defines the structure of the request from the client.""" diff --git a/app/response_parser.py b/app/response_parser.py new file mode 100644 index 0000000..275e4db --- /dev/null +++ b/app/response_parser.py @@ -0,0 +1,326 @@ +""" +Response Parser Module + +This module provides low-coupling, high-cohesion parsing utilities for extracting +tool calls from LLM responses and converting them to OpenAI-compatible format. + +Design principles: +- Single Responsibility: Each function handles one specific parsing task +- Testability: Pure functions that are easy to unit test +- Type Safety: Uses Pydantic models for validation +""" + +import re +import json +import logging +from typing import Optional, List, Dict, Any +from uuid import uuid4 + +from app.models import ResponseMessage, ToolCall, ToolCallFunction + +logger = logging.getLogger(__name__) + + +# Constants for tool call parsing +# Using XML-style tags for clarity and better compatibility with JSON +# LLM should emit:{"name": "...", "arguments": {...}} +TOOL_CALL_START_TAG = "{" +TOOL_CALL_END_TAG = "}" + + +class ToolCallParseError(Exception): + """Raised when tool call parsing fails.""" + pass + + +class ResponseParser: + """ + Parser for converting LLM text responses into structured ResponseMessage objects. + + This class encapsulates all parsing logic for tool calls, making it easy to test + and maintain. It follows the Single Responsibility Principle by focusing solely + on parsing responses. + """ + + def __init__(self, tool_call_start_tag: str = TOOL_CALL_START_TAG, + tool_call_end_tag: str = TOOL_CALL_END_TAG): + """ + Initialize the parser with configurable tags. + + Args: + tool_call_start_tag: The opening tag for tool calls (default: {...") + tool_call_end_tag: The closing tag for tool calls (default: ...}) + """ + self.tool_call_start_tag = tool_call_start_tag + self.tool_call_end_tag = tool_call_end_tag + self._compile_regex() + + def _compile_regex(self): + """Compile the regex pattern for tool call extraction.""" + # Escape special regex characters in the tags + escaped_start = re.escape(self.tool_call_start_tag) + escaped_end = re.escape(self.tool_call_end_tag) + # Match from start tag to end tag (greedy), including both tags + # This ensures we capture the complete JSON object + self._tool_call_pattern = re.compile( + f"{escaped_start}.*{escaped_end}", + re.DOTALL + ) + + def _extract_valid_json(self, text: str) -> Optional[str]: + """ + Extract a valid JSON object from text that may contain extra content. + + This handles cases where non-greedy regex matching includes incomplete JSON. + + Args: + text: Text that should contain a JSON object + + Returns: + The extracted valid JSON string, or None if not found + """ + text = text.lstrip() # Only strip leading whitespace + + # Find the first opening brace (the start of JSON) + start_idx = text.find('{') + if start_idx < 0: + return None + + text = text[start_idx:] # Start from the first opening brace + + # Find the matching closing brace by counting brackets + brace_count = 0 + in_string = False + escape_next = False + + for i, char in enumerate(text): + if escape_next: + escape_next = False + continue + + if char == '\\' and in_string: + escape_next = True + continue + + if char == '"': + in_string = not in_string + continue + + if not in_string: + if char == '{': + brace_count += 1 + elif char == '}': + brace_count -= 1 + if brace_count == 0: + # Found matching closing brace + return text[:i+1] + + return None + + def parse(self, llm_response: str) -> ResponseMessage: + """ + Parse an LLM response and extract tool calls if present. + + This is the main entry point for parsing. It handles both: + 1. Responses with tool calls (wrapped in tags) + 2. Regular text responses + + Args: + llm_response: The raw text response from the LLM + + Returns: + ResponseMessage with content and optionally tool_calls + + Example: + >>> parser = ResponseParser() + >>> response = parser.parse('Hello world') + >>> response.content + 'Hello world' + + >>> response = parser.parse('Check the weather.{"name": "weather", "arguments": {...}}') + >>> response.tool_calls[0].function.name + 'weather' + """ + if not llm_response: + return ResponseMessage(content=None) + + try: + match = self._tool_call_pattern.search(llm_response) + + if match: + return self._parse_tool_call_response(llm_response, match) + else: + return self._parse_text_only_response(llm_response) + + except Exception as e: + logger.warning(f"Failed to parse LLM response: {e}. Returning as text.") + return ResponseMessage(content=llm_response) + + def _parse_tool_call_response(self, llm_response: str, match: re.Match) -> ResponseMessage: + """ + Parse a response that contains tool calls. + + Args: + llm_response: The full LLM response + match: The regex match object containing the tool call + + Returns: + ResponseMessage with content and tool_calls + """ + # The match includes start and end tags, so strip them + matched_text = match.group(0) + tool_call_str = matched_text[len(self.tool_call_start_tag):-len(self.tool_call_end_tag)] + + # Extract valid JSON by finding matching braces + json_str = self._extract_valid_json(tool_call_str) + if json_str is None: + # Fallback to trying to parse the entire string + json_str = tool_call_str + + try: + tool_call_data = json.loads(json_str) + + # Extract content before the tool call tag + parts = llm_response.split(self.tool_call_start_tag, 1) + content = parts[0].strip() if parts[0] else None + + # Create the tool call object + tool_call = self._create_tool_call(tool_call_data) + + return ResponseMessage( + content=content, + tool_calls=[tool_call] + ) + + except json.JSONDecodeError as e: + raise ToolCallParseError(f"Invalid JSON in tool call: {tool_call_str}. Error: {e}") + + def _parse_text_only_response(self, llm_response: str) -> ResponseMessage: + """ + Parse a response with no tool calls. + + Args: + llm_response: The full LLM response + + Returns: + ResponseMessage with content only + """ + return ResponseMessage(content=llm_response.strip()) + + def _create_tool_call(self, tool_call_data: Dict[str, Any]) -> ToolCall: + """ + Create a ToolCall object from parsed data. + + Args: + tool_call_data: Dictionary containing 'name' and optionally 'arguments' + + Returns: + ToolCall object + + Raises: + ToolCallParseError: If required fields are missing + """ + name = tool_call_data.get("name") + if not name: + raise ToolCallParseError("Tool call missing 'name' field") + + arguments = tool_call_data.get("arguments", {}) + + # Generate a unique ID for the tool call + tool_call_id = f"call_{name}_{str(uuid4())[:8]}" + + return ToolCall( + id=tool_call_id, + type="function", + function=ToolCallFunction( + name=name, + arguments=json.dumps(arguments) + ) + ) + + def parse_streaming_chunks(self, chunks: List[str]) -> ResponseMessage: + """ + Parse a list of streaming chunks and aggregate into a ResponseMessage. + + This method handles streaming responses where tool calls might be + split across multiple chunks. + + Args: + chunks: List of content chunks from streaming response + + Returns: + Parsed ResponseMessage + """ + full_content = "".join(chunks) + return self.parse(full_content) + + def parse_native_tool_calls(self, llm_response: Dict[str, Any]) -> ResponseMessage: + """ + Parse a response that already has native OpenAI-format tool calls. + + Some LLMs natively support tool calling and return them in the standard + OpenAI format. This method handles those responses. + + Args: + llm_response: Dictionary response from LLM with potential tool_calls field + + Returns: + ResponseMessage with parsed tool_calls or content + """ + if "tool_calls" in llm_response and llm_response["tool_calls"]: + # Parse native tool calls + tool_calls = [] + for tc in llm_response["tool_calls"]: + tool_calls.append(ToolCall( + id=tc.get("id", f"call_{str(uuid4())[:8]}"), + type=tc.get("type", "function"), + function=ToolCallFunction( + name=tc["function"]["name"], + arguments=tc["function"]["arguments"] + ) + )) + + return ResponseMessage( + content=llm_response.get("content"), + tool_calls=tool_calls + ) + else: + # Fallback to text parsing + content = llm_response.get("content", "") + return self.parse(content) + + +# Convenience functions for backward compatibility and ease of use + +def parse_response(llm_response: str) -> ResponseMessage: + """ + Parse an LLM response using default parser settings. + + This is a convenience function for simple use cases. + + Args: + llm_response: The raw text response from the LLM + + Returns: + ResponseMessage with parsed content and tool calls + """ + parser = ResponseParser() + return parser.parse(llm_response) + + +def parse_response_with_custom_tags(llm_response: str, + start_tag: str, + end_tag: str) -> ResponseMessage: + """ + Parse an LLM response using custom tool call tags. + + Args: + llm_response: The raw text response from the LLM + start_tag: Custom start tag for tool calls + end_tag: Custom end tag for tool calls + + Returns: + ResponseMessage with parsed content and tool calls + """ + parser = ResponseParser(tool_call_start_tag=start_tag, tool_call_end_tag=end_tag) + return parser.parse(llm_response) diff --git a/app/services.py b/app/services.py index 8ee47f1..fb89bee 100644 --- a/app/services.py +++ b/app/services.py @@ -6,6 +6,8 @@ from typing import List, Dict, Any, Tuple, Optional, AsyncGenerator from .models import ChatMessage, Tool, ResponseMessage, ToolCall, ToolCallFunction from .core.config import Settings +from .database import update_request_log +from .response_parser import ResponseParser, parse_response # Get a logger instance for this module logger = logging.getLogger(__name__) @@ -39,160 +41,139 @@ def _parse_sse_data(chunk: bytes) -> Optional[Dict[str, Any]]: def inject_tools_into_prompt(messages: List[ChatMessage], tools: List[Tool]) -> List[ChatMessage]: """ - Injects tool definitions into the message list as a system prompt. + Injects a system prompt with tool definitions at the beginning of the message list. """ + from .response_parser import TOOL_CALL_START_TAG, TOOL_CALL_END_TAG + tool_defs = json.dumps([tool.model_dump() for tool in tools], indent=2) + + # Build the format example separately to avoid f-string escaping issues + # We need to show double braces: outer {{ }} are tags, inner { } is JSON + json_example = '{"name": "search", "arguments": {"query": "example"}}' + full_example = f'{{{json_example}}}' + tool_prompt = f""" -You have access to a set of tools. You can call them by emitting a JSON object inside a XML tag. -The JSON object should have a "name" and "arguments" field. +You are a helpful assistant with access to a set of tools. +You can call them by emitting a JSON object inside tool call tags. + +IMPORTANT: Use double braces for tool calls - the outer braces are the tags ({TOOL_CALL_START_TAG} and {TOOL_CALL_END_TAG}), the inner braces are the JSON. +Format: {TOOL_CALL_START_TAG}{{\"name\": \"tool_name\", \"arguments\": {{...}}}}{TOOL_CALL_END_TAG} + +Example: {full_example} Here are the available tools: {tool_defs} Only use the tools if strictly necessary. """ - new_messages = messages.copy() - new_messages.insert(1, ChatMessage(role="system", content=tool_prompt)) - return new_messages + # Prepend the system prompt with tool definitions + return [ChatMessage(role="system", content=tool_prompt)] + messages def parse_llm_response_from_content(text: str) -> ResponseMessage: """ (Fallback) Parses the raw LLM text response to extract a message and any tool calls. This is used when the LLM does not support native tool calling. + + This function now delegates to the ResponseParser class for better maintainability. """ - if not text: - return ResponseMessage(content=None) - - tool_call_match = re.search(r"(.*?)", text, re.DOTALL) - - if tool_call_match: - tool_call_str = tool_call_match.group(1).strip() - try: - tool_call_data = json.loads(tool_call_str) - tool_call = ToolCall( - id="call_" + tool_call_data.get("name", "unknown"), - function=ToolCallFunction( - name=tool_call_data.get("name"), - arguments=json.dumps(tool_call_data.get("arguments", {})), - ) - ) - content_before = text.split("")[0].strip() - return ResponseMessage(content=content_before if content_before else None, tool_calls=[tool_call]) - except json.JSONDecodeError as e: - logger.warning(f"Failed to parse tool call JSON from content: {tool_call_str}. Error: {e}") - return ResponseMessage(content=text) - else: - return ResponseMessage(content=text) + parser = ResponseParser() + return parser.parse(text) -async def _raw_stream_from_llm(messages: List[ChatMessage], settings: Settings) -> AsyncGenerator[bytes, None]: +async def _raw_stream_from_llm(messages: List[ChatMessage], settings: Settings, log_id: int) -> AsyncGenerator[bytes, None]: """ Makes the raw HTTP streaming call to the LLM backend. Yields raw byte chunks as received. """ headers = { "Authorization": f"Bearer {settings.REAL_LLM_API_KEY}", "Content-Type": "application/json" } payload = { "model": "default-model", "messages": [msg.model_dump() for msg in messages], "stream": True } - + + # Log the request payload to the database + update_request_log(log_id, llm_request=payload) + try: async with httpx.AsyncClient() as client: - logger.info(f"Initiating raw stream to LLM API at {settings.REAL_LLM_API_URL}") + logger.info(f"Initiating raw stream to LLM API for log ID {log_id} at {settings.REAL_LLM_API_URL}") async with client.stream("POST", settings.REAL_LLM_API_URL, headers=headers, json=payload, timeout=60.0) as response: response.raise_for_status() async for chunk in response.aiter_bytes(): yield chunk except httpx.HTTPStatusError as e: - logger.error(f"LLM API returned an error during raw stream: {e.response.status_code}, response: '{e.response.text}'") - # For streams, we log and let the stream terminate. The client will get a broken stream. + error_message = f"LLM API returned an error during raw stream: {e.response.status_code}, response: '{e.response.text}'" + logger.error(f"{error_message} for log ID {log_id}") + update_request_log(log_id, llm_response={"error": error_message}) yield b'data: {"error": "LLM API Error", "status_code": ' + str(e.response.status_code).encode() + b'}\n\n' except httpx.RequestError as e: - logger.error(f"An error occurred during raw stream request to LLM API: {e}") + error_message = f"An error occurred during raw stream request to LLM API: {e}" + logger.error(f"{error_message} for log ID {log_id}") + update_request_log(log_id, llm_response={"error": error_message}) yield b'data: {"error": "Network Error", "details": "' + str(e).encode() + b'"}\n\n' -async def stream_llm_api(messages: List[ChatMessage], settings: Settings) -> AsyncGenerator[bytes, None]: +async def stream_llm_api(messages: List[ChatMessage], settings: Settings, log_id: int) -> AsyncGenerator[bytes, None]: """ - Public interface for streaming. Calls the raw stream, parses SSE, and yields SSE data chunks. + Public interface for streaming. Calls the raw stream, logs the full response, and yields chunks. """ - async for chunk in _raw_stream_from_llm(messages, settings): - # We assume the raw chunks are already SSE formatted or can be split into lines. - # For simplicity, we pass through the raw chunk bytes. - # A more robust parser would ensure each yield is a complete SSE event line. + llm_response_chunks = [] + async for chunk in _raw_stream_from_llm(messages, settings, log_id): + llm_response_chunks.append(chunk.decode('utf-8', errors='ignore')) + try: + logger.info(f"Streaming chunk for log ID {log_id}: {chunk.decode('utf-8').strip()}") + except UnicodeDecodeError: + logger.info(f"Streaming chunk (undecodable) for log ID {log_id}: {chunk}") yield chunk + # Log the full LLM response after the stream is complete + update_request_log(log_id, llm_response={"content": "".join(llm_response_chunks)}) + async def process_llm_stream_for_non_stream_request( - messages: List[ChatMessage], - settings: Settings + messages: List[ChatMessage], + settings: Settings, + log_id: int ) -> Dict[str, Any]: """ Aggregates a streaming LLM response into a single, non-streaming message. - Handles SSE parsing and delta accumulation. + Handles SSE parsing, delta accumulation, and logs the final aggregated message. """ full_content_parts = [] final_message_dict: Dict[str, Any] = {"role": "assistant", "content": None} - - async for chunk in _raw_stream_from_llm(messages, settings): + llm_response_chunks = [] + + async for chunk in _raw_stream_from_llm(messages, settings, log_id): + llm_response_chunks.append(chunk.decode('utf-8', errors='ignore')) parsed_data = _parse_sse_data(chunk) if parsed_data: if parsed_data.get("type") == "done": - break # End of stream - - # Assuming OpenAI-like streaming format + break + choices = parsed_data.get("choices") if choices and len(choices) > 0: delta = choices[0].get("delta") - if delta: - if "content" in delta: - full_content_parts.append(delta["content"]) - if "tool_calls" in delta: - # Accumulate tool calls if they appear in deltas (complex) - # For simplicity, we'll try to reconstruct the final tool_calls - # from the final message, or fall back to content parsing later. - # This part is highly dependent on LLM's exact streaming format for tool_calls. - pass - if choices[0].get("finish_reason"): - # Check for finish_reason to identify stream end or tool_calls completion - pass - + if delta and "content" in delta: + full_content_parts.append(delta["content"]) + final_message_dict["content"] = "".join(full_content_parts) if full_content_parts else None - # This is a simplification. Reconstructing tool_calls from deltas is non-trivial. - # We will rely on parse_llm_response_from_content for tool calls if they are - # embedded in the final content string, or assume the LLM doesn't send native - # tool_calls in stream deltas that need aggregation here. - logger.info(f"Aggregated non-streaming response content: {final_message_dict.get('content')}") - + # Log the aggregated LLM response + logger.info(f"Aggregated non-streaming response content for log ID {log_id}: {final_message_dict.get('content')}") + update_request_log(log_id, llm_response=final_message_dict) + return final_message_dict async def process_chat_request( - messages: List[ChatMessage], - tools: Optional[List[Tool]], + messages: List[ChatMessage], settings: Settings, + log_id: int ) -> ResponseMessage: """ Main service function for non-streaming requests. - It now calls the stream aggregation logic. + It calls the stream aggregation logic and then parses the result. """ - request_messages = messages - if tools: - request_messages = inject_tools_into_prompt(messages, tools) + llm_message_dict = await process_llm_stream_for_non_stream_request(messages, settings, log_id) - # All interactions with the real LLM now go through the streaming mechanism. - llm_message_dict = await process_llm_stream_for_non_stream_request(request_messages, settings) - - # Priority 1: Check for native tool calls (if the aggregation could reconstruct them) - # Note: Reconstructing tool_calls from deltas in streaming is complex. - # For now, we assume if tool_calls are present, they are complete. - if llm_message_dict.get("tool_calls"): - logger.info("Native tool calls detected in aggregated LLM response.") - # Ensure it's a list of dicts suitable for Pydantic validation - if isinstance(llm_message_dict["tool_calls"], list): - return ResponseMessage.model_validate(llm_message_dict) - else: - logger.warning("Aggregated tool_calls not in expected list format. Treating as content.") - - # Priority 2 (Fallback): Parse tool calls from content - logger.info("No native tool calls from aggregation. Falling back to content parsing.") - return parse_llm_response_from_content(llm_message_dict.get("content")) \ No newline at end of file + # Use the ResponseParser to handle both native and text-based tool calls + parser = ResponseParser() + return parser.parse_native_tool_calls(llm_message_dict) diff --git a/requirements.txt b/requirements.txt new file mode 100644 index 0000000..2b7a354 --- /dev/null +++ b/requirements.txt @@ -0,0 +1,3 @@ +python-dotenv +pydantic +requests diff --git a/tests/test_response_parser.py b/tests/test_response_parser.py new file mode 100644 index 0000000..b549535 --- /dev/null +++ b/tests/test_response_parser.py @@ -0,0 +1,375 @@ +""" +Unit tests for the Response Parser module. + +Tests cover: +- Parsing text-only responses +- Parsing responses with tool calls +- Parsing native OpenAI-format tool calls +- Parsing streaming chunks +- Error handling and edge cases +""" + +import pytest +import json +from app.response_parser import ( + ResponseParser, + ToolCallParseError, + parse_response, + parse_response_with_custom_tags, + TOOL_CALL_START_TAG, + TOOL_CALL_END_TAG +) +from app.models import ToolCall, ToolCallFunction + + +class TestResponseParser: + """Test suite for ResponseParser class.""" + + def test_parse_text_only_response(self): + """Test parsing a response with no tool calls.""" + parser = ResponseParser() + text = "Hello, this is a simple response." + result = parser.parse(text) + + assert result.content == text + assert result.tool_calls is None + + def test_parse_empty_response(self): + """Test parsing an empty response.""" + parser = ResponseParser() + result = parser.parse("") + + assert result.content is None + assert result.tool_calls is None + + def test_parse_response_with_tool_call(self): + """Test parsing a response with a single tool call.""" + parser = ResponseParser() + text = f'''I'll check the weather for you. +{TOOL_CALL_START_TAG} +{{ + "name": "get_weather", + "arguments": {{ + "location": "San Francisco", + "units": "celsius" + }} +}} +{TOOL_CALL_END_TAG} +''' + + result = parser.parse(text) + + assert result.content == "I'll check the weather for you." + assert result.tool_calls is not None + assert len(result.tool_calls) == 1 + + tool_call = result.tool_calls[0] + assert tool_call.type == "function" + assert tool_call.function.name == "get_weather" + + arguments = json.loads(tool_call.function.arguments) + assert arguments["location"] == "San Francisco" + assert arguments["units"] == "celsius" + + def test_parse_response_with_tool_call_no_content(self): + """Test parsing a response with only a tool call.""" + parser = ResponseParser() + text = f'''{TOOL_CALL_START_TAG} +{{ + "name": "shell", + "arguments": {{ + "command": ["ls", "-l"] + }} +}} +{TOOL_CALL_END_TAG} +''' + + result = parser.parse(text) + + assert result.content is None + assert result.tool_calls is not None + assert len(result.tool_calls) == 1 + assert result.tool_calls[0].function.name == "shell" + + def test_parse_response_with_malformed_tool_call(self): + """Test parsing a response with malformed JSON in tool call.""" + parser = ResponseParser() + text = f'''Here's the result. +{TOOL_CALL_START_TAG} +{{invalid json}} +{TOOL_CALL_END_TAG} +''' + + result = parser.parse(text) + + # Should fall back to treating it as text + assert result.content == text + assert result.tool_calls is None + + def test_parse_response_with_missing_tool_name(self): + """Test parsing a tool call without a name field.""" + parser = ResponseParser() + text = f'''{TOOL_CALL_START_TAG} +{{ + "arguments": {{ + "command": "echo hello" + }} +}} +{TOOL_CALL_END_TAG} +''' + + result = parser.parse(text) + + # Should handle gracefully - when name is missing, ToolCallParseError is raised + # and caught, falling back to treating as text content + # content will be the text between start and end tags (the JSON object) + assert result.content is not None + + def test_parse_response_with_complex_arguments(self): + """Test parsing a tool call with complex nested arguments.""" + parser = ResponseParser() + text = f'''Executing command. +{TOOL_CALL_START_TAG} +{{ + "name": "shell", + "arguments": {{ + "command": ["bash", "-lc", "echo 'hello world' && ls -la"], + "timeout": 5000, + "env": {{ + "PATH": "/usr/bin" + }} + }} +}} +{TOOL_CALL_END_TAG} +''' + + result = parser.parse(text) + + assert result.content == "Executing command." + assert result.tool_calls is not None + + arguments = json.loads(result.tool_calls[0].function.arguments) + assert arguments["command"] == ["bash", "-lc", "echo 'hello world' && ls -la"] + assert arguments["timeout"] == 5000 + assert arguments["env"]["PATH"] == "/usr/bin" + + def test_parse_with_custom_tags(self): + """Test parsing with custom start and end tags.""" + parser = ResponseParser( + tool_call_start_tag="", + tool_call_end_tag="" + ) + text = """I'll help you with that. + +{ + "name": "search", + "arguments": { + "query": "python tutorials" + } +} + +""" + + result = parser.parse(text) + + assert "I'll help you with that" in result.content + assert result.tool_calls is not None + assert result.tool_calls[0].function.name == "search" + + def test_parse_streaming_chunks(self): + """Test parsing aggregated streaming chunks.""" + parser = ResponseParser() + chunks = [ + "I'll run that ", + "command for you.", + f'{TOOL_CALL_START_TAG}\n{{"name": "shell", "arguments": {{"command": ["echo", "hello"]}}}}\n{TOOL_CALL_END_TAG}' + ] + + result = parser.parse_streaming_chunks(chunks) + + assert "I'll run that command for you" in result.content + assert result.tool_calls is not None + assert result.tool_calls[0].function.name == "shell" + + def test_parse_native_tool_calls(self): + """Test parsing a native OpenAI-format response with tool calls.""" + parser = ResponseParser() + llm_response = { + "role": "assistant", + "content": "I'll execute that command.", + "tool_calls": [ + { + "id": "call_abc123", + "type": "function", + "function": { + "name": "shell", + "arguments": '{"command": ["ls", "-l"]}' + } + } + ] + } + + result = parser.parse_native_tool_calls(llm_response) + + assert result.content == "I'll execute that command." + assert result.tool_calls is not None + assert len(result.tool_calls) == 1 + assert result.tool_calls[0].id == "call_abc123" + assert result.tool_calls[0].function.name == "shell" + + def test_parse_native_tool_calls_multiple(self): + """Test parsing a response with multiple native tool calls.""" + parser = ResponseParser() + llm_response = { + "role": "assistant", + "content": None, + "tool_calls": [ + { + "id": "call_1", + "type": "function", + "function": { + "name": "shell", + "arguments": '{"command": ["pwd"]}' + } + }, + { + "id": "call_2", + "type": "function", + "function": { + "name": "shell", + "arguments": '{"command": ["ls", "-la"]}' + } + } + ] + } + + result = parser.parse_native_tool_calls(llm_response) + + assert result.tool_calls is not None + assert len(result.tool_calls) == 2 + assert result.tool_calls[0].id == "call_1" + assert result.tool_calls[1].id == "call_2" + + def test_parse_native_tool_calls_falls_back_to_text(self): + """Test that native parsing falls back to text parsing when no tool_calls.""" + parser = ResponseParser() + llm_response = { + "role": "assistant", + "content": "This is a simple text response." + } + + result = parser.parse_native_tool_calls(llm_response) + + assert result.content == "This is a simple text response." + assert result.tool_calls is None + + def test_generate_unique_tool_call_ids(self): + """Test that tool call IDs are unique.""" + parser = ResponseParser() + + text1 = f'{TOOL_CALL_START_TAG}{{"name": "tool1", "arguments": {{}}}}{TOOL_CALL_END_TAG}' + text2 = f'{TOOL_CALL_START_TAG}{{"name": "tool2", "arguments": {{}}}}{TOOL_CALL_END_TAG}' + + result1 = parser.parse(text1) + result2 = parser.parse(text2) + + id1 = result1.tool_calls[0].id + id2 = result2.tool_calls[0].id + + assert id1 != id2 + assert id1.startswith("call_tool1_") + assert id2.startswith("call_tool2_") + + +class TestConvenienceFunctions: + """Test suite for convenience functions.""" + + def test_parse_response_default_parser(self): + """Test the parse_response convenience function.""" + text = f'{TOOL_CALL_START_TAG}{{"name": "search", "arguments": {{"query": "test"}}}}{TOOL_CALL_END_TAG}' + result = parse_response(text) + + assert result.tool_calls is not None + assert result.tool_calls[0].function.name == "search" + + def test_parse_response_with_custom_tags_function(self): + """Test the parse_response_with_custom_tags function.""" + text = """[CALL] +{"name": "test", "arguments": {}} +[/CALL]""" + result = parse_response_with_custom_tags( + text, + start_tag="[CALL]", + end_tag="[/CALL]" + ) + + assert result.tool_calls is not None + assert result.tool_calls[0].function.name == "test" + + +class TestEdgeCases: + """Test edge cases and error conditions.""" + + def test_response_with_whitespace(self): + """Test parsing responses with various whitespace patterns.""" + parser = ResponseParser() + + # Leading/trailing whitespace + text = " Hello world. " + result = parser.parse(text) + assert result.content.strip() == "Hello world." + + def test_response_with_newlines_only(self): + """Test parsing a response with only newlines.""" + parser = ResponseParser() + result = parser.parse("\n\n\n") + + assert result.content == "" + assert result.tool_calls is None + + def test_response_with_special_characters(self): + """Test parsing responses with special characters in content.""" + parser = ResponseParser() + special_chars = '@#$%^&*()' + text = f'''Here's the result with special chars: {special_chars} +{TOOL_CALL_START_TAG} +{{ + "name": "test", + "arguments": {{ + "special": "!@#$%" + }} +}} +{TOOL_CALL_END_TAG} +''' + + result = parser.parse(text) + assert "@" in result.content + assert result.tool_calls is not None + + def test_response_with_escaped_quotes(self): + """Test parsing tool calls with escaped quotes in arguments.""" + parser = ResponseParser() + text = f'{TOOL_CALL_START_TAG}{{"name": "echo", "arguments": {{"message": "Hello \\"world\\""}}}}{TOOL_CALL_END_TAG}' + + result = parser.parse(text) + arguments = json.loads(result.tool_calls[0].function.arguments) + assert arguments["message"] == 'Hello "world"' + + def test_multiple_tool_calls_in_text_finds_first(self): + """Test that only the first tool call is extracted.""" + parser = ResponseParser() + text = f'''First call. +{TOOL_CALL_START_TAG} +{{"name": "tool1", "arguments": {{}}}} +{TOOL_CALL_END_TAG} +Some text in between. +{TOOL_CALL_START_TAG} +{{"name": "tool2", "arguments": {{}}}} +{TOOL_CALL_END_TAG} +''' + + result = parser.parse(text) + + # Should only find the first one + assert len(result.tool_calls) == 1 + assert result.tool_calls[0].function.name == "tool1" diff --git a/tests/test_services.py b/tests/test_services.py index e88a3b3..06af6d7 100644 --- a/tests/test_services.py +++ b/tests/test_services.py @@ -1,23 +1,15 @@ import pytest -import httpx import json +import httpx from typing import List, AsyncGenerator -from app.services import call_llm_api_real -from app.models import ChatMessage +from app.services import inject_tools_into_prompt, parse_llm_response_from_content, process_chat_request +from app.models import ChatMessage, Tool, ResponseMessage, ToolCall, ToolCallFunction, IncomingRequest from app.core.config import Settings +from app.database import get_latest_log_entry -# Sample SSE chunks to simulate a streaming response -SSE_STREAM_CHUNKS = [ - 'data: {"choices": [{"delta": {"role": "assistant", "content": "Hello"}}]}', - 'data: {"choices": [{"delta": {"content": " world!"}}]}', - 'data: {"choices": [{"delta": {"tool_calls": [{"index": 0, "id": "call_123", "function": {"name": "get_weather", "arguments": ""}}]}}]}', - 'data: {"choices": [{"delta": {"tool_calls": [{"index": 0, "function": {"arguments": "{\\"location\\":"}}]}}]}', - 'data: {"choices": [{"delta": {"tool_calls": [{"index": 0, "function": {"arguments": " \\"San Francisco\\"}"}}]}}]}', - 'data: [DONE]', -] +# --- Mocks for simulating httpx responses --- -# Mock settings for the test @pytest.fixture def mock_settings() -> Settings: """Provides mock settings for tests.""" @@ -26,71 +18,143 @@ def mock_settings() -> Settings: REAL_LLM_API_KEY="fake-key" ) -# Async generator to mock the streaming response -async def mock_aiter_lines() -> AsyncGenerator[str, None]: - for chunk in SSE_STREAM_CHUNKS: - yield chunk +class MockAsyncClient: + """Mocks the httpx.AsyncClient to simulate LLM responses.""" + def __init__(self, response_chunks: List[str]): + self._response_chunks = response_chunks + + async def __aenter__(self): + return self + + async def __aexit__(self, exc_type, exc_val, exc_tb): + pass + + def stream(self, method, url, headers, json, timeout): + return MockStreamResponse(self._response_chunks) -# Mock for the httpx.Response object class MockStreamResponse: - def __init__(self, status_code: int = 200): + """Mocks the httpx.Response object for streaming.""" + def __init__(self, chunks: List[str], status_code: int = 200): + self._chunks = chunks self._status_code = status_code + async def __aenter__(self): + return self + + async def __aexit__(self, exc_type, exc_val, exc_tb): + pass + def raise_for_status(self): if self._status_code != 200: - raise httpx.HTTPStatusError( - message="Error", request=httpx.Request("POST", ""), response=httpx.Response(self._status_code) - ) + raise httpx.HTTPStatusError("Error", request=None, response=httpx.Response(self._status_code)) - def aiter_lines(self) -> AsyncGenerator[str, None]: - return mock_aiter_lines() - - async def __aenter__(self): - return self - - async def __aexit__(self, exc_type, exc_val, exc_tb): - pass + async def aiter_bytes(self) -> AsyncGenerator[bytes, None]: + for chunk in self._chunks: + yield chunk.encode('utf-8') -# Mock for the httpx.AsyncClient -class MockAsyncClient: - def stream(self, method, url, headers, json, timeout): - return MockStreamResponse() +# --- End Mocks --- - async def __aenter__(self): - return self - - async def __aexit__(self, exc_type, exc_val, exc_tb): - pass -@pytest.mark.anyio -async def test_call_llm_api_real_streaming(monkeypatch, mock_settings): +def test_inject_tools_into_prompt(): """ - Tests that `call_llm_api_real` correctly handles an SSE stream, - parses the chunks, and assembles the final response message. + Tests that `inject_tools_into_prompt` correctly adds a system message + with tool definitions to the message list. """ - # Patch httpx.AsyncClient to use our mock - monkeypatch.setattr(httpx, "AsyncClient", MockAsyncClient) + # 1. Fetch the latest request from the database + latest_entry = get_latest_log_entry() + assert latest_entry is not None + client_request_data = json.loads(latest_entry["client_request"]) - messages = [ChatMessage(role="user", content="What is the weather in San Francisco?")] + # 2. Parse the data into Pydantic models + incoming_request = IncomingRequest.model_validate(client_request_data) + + # 3. Call the function to be tested + modified_messages = inject_tools_into_prompt(incoming_request.messages, incoming_request.tools) + + # 4. Assert the results + assert len(modified_messages) == len(incoming_request.messages) + 1 + + # Check that the first message is the new system prompt + system_prompt = modified_messages[0] + assert system_prompt.role == "system" + assert "You are a helpful assistant with access to a set of tools." in system_prompt.content + + # Check that the tool definitions are in the system prompt + for tool in incoming_request.tools: + assert tool.function.name in system_prompt.content + +def test_parse_llm_response_from_content(): + """ + Tests that `parse_llm_response_from_content` correctly parses a raw LLM + text response containing a { and extracts the `ResponseMessage`. + """ + # Sample raw text from an LLM + # Note: Since tags are { and }, we use double braces {{...}} where + # the outer { and } are tags, and the inner { and } are JSON + llm_text = """ +Some text from the model. +{{ + "name": "shell", + "arguments": { + "command": ["echo", "Hello from the tool!"] + } +}} +""" # Call the function - result = await call_llm_api_real(messages, mock_settings) + response_message = parse_llm_response_from_content(llm_text) - # Define the expected assembled result - expected_result = { - "role": "assistant", - "content": "Hello world!", - "tool_calls": [ - { - "id": "call_123", - "type": "function", - "function": { - "name": "get_weather", - "arguments": '{"location": "San Francisco"}', - }, - } - ], - } + # Assertions + assert response_message.content == "Some text from the model." + assert response_message.tool_calls is not None + assert len(response_message.tool_calls) == 1 - # Assert that the result matches the expected output - assert result == expected_result + tool_call = response_message.tool_calls[0] + assert isinstance(tool_call, ToolCall) + assert tool_call.function.name == "shell" + + # The arguments are a JSON string, so we parse it for detailed checking + arguments = json.loads(tool_call.function.arguments) + assert arguments["command"] == ["echo", "Hello from the tool!"] + + +@pytest.mark.anyio +async def test_process_chat_request_with_tool_call(monkeypatch, mock_settings): + """ + Tests that `process_chat_request` can correctly parse a tool call from a + simulated real LLM streaming response. + """ + # 1. Define the simulated SSE stream from the LLM + # Using double braces for tool call tags + sse_chunks = [ + 'data: {"choices": [{"delta": {"content": "Okay, I will run that shell command."}}], "object": "chat.completion.chunk"}\n\n', + 'data: {"choices": [{"delta": {"content": "{{\\n \\"name\\": \\"shell\\",\\n \\"arguments\\": {\\n \\"command\\": [\\"ls\\", \\"-l\\"]\\n }\\n}}\\n"}}], "object": "chat.completion.chunk"}\n\n', + 'data: [DONE]\n\n' + ] + + # 2. Mock the httpx.AsyncClient + def mock_async_client(*args, **kwargs): + return MockAsyncClient(response_chunks=sse_chunks) + + monkeypatch.setattr(httpx, "AsyncClient", mock_async_client) + + # 3. Prepare the input for process_chat_request + messages = [ChatMessage(role="user", content="List the files.")] + tools = [Tool(type="function", function={"name": "shell", "description": "Run a shell command.", "parameters": {}})] + log_id = 1 # Dummy log ID for the test + + # 4. Call the function + request_messages = inject_tools_into_prompt(messages, tools) + response_message = await process_chat_request(request_messages, mock_settings, log_id) + + # 5. Assert the response is parsed correctly + assert response_message.content is not None + assert response_message.content.strip() == "Okay, I will run that shell command." + assert response_message.tool_calls is not None + assert len(response_message.tool_calls) == 1 + + tool_call = response_message.tool_calls[0] + assert tool_call.function.name == "shell" + + arguments = json.loads(tool_call.function.arguments) + assert arguments["command"] == ["ls", "-l"]