diff --git a/app/core/config.py b/app/core/config.py
index c60a37f..226b629 100644
--- a/app/core/config.py
+++ b/app/core/config.py
@@ -1,6 +1,7 @@
import os
from pydantic import BaseModel
from typing import Optional
+from dotenv import load_dotenv
class Settings(BaseModel):
"""Manages application settings and configurations."""
@@ -11,6 +12,7 @@ def get_settings() -> Settings:
"""
Returns an instance of the Settings object by loading from environment variables.
"""
+ load_dotenv() # Load environment variables from .env file
return Settings(
REAL_LLM_API_URL=os.getenv("REAL_LLM_API_URL"),
REAL_LLM_API_KEY=os.getenv("REAL_LLM_API_KEY"),
diff --git a/app/database.py b/app/database.py
new file mode 100644
index 0000000..71efac2
--- /dev/null
+++ b/app/database.py
@@ -0,0 +1,97 @@
+import sqlite3
+import json
+from datetime import datetime
+from typing import Dict, Any, Optional
+import logging
+
+logging.basicConfig(
+ level=logging.DEBUG, # Set to DEBUG to capture the debug logs
+ format='%(asctime)s - %(name)s - %(levelname)s - %(message)s',
+ handlers=[
+ logging.FileHandler("llm_proxy.log"),
+ logging.StreamHandler()
+ ]
+)
+logger = logging.getLogger(__name__)
+
+DATABASE_NAME = "llm_proxy.db"
+
+def init_db():
+ """Initializes the database and creates the 'requests' table if it doesn't exist."""
+ with sqlite3.connect(DATABASE_NAME) as conn:
+ cursor = conn.cursor()
+ cursor.execute("""
+ CREATE TABLE IF NOT EXISTS requests (
+ id INTEGER PRIMARY KEY AUTOINCREMENT,
+ timestamp DATETIME DEFAULT CURRENT_TIMESTAMP,
+ client_request TEXT,
+ llm_request TEXT,
+ llm_response TEXT,
+ client_response TEXT
+ )
+ """)
+ conn.commit()
+
+def log_request(client_request: Dict[str, Any]) -> int:
+ """Logs the initial client request and returns the log ID."""
+ with sqlite3.connect(DATABASE_NAME) as conn:
+ cursor = conn.cursor()
+ cursor.execute(
+ "INSERT INTO requests (client_request) VALUES (?)",
+ (json.dumps(client_request),)
+ )
+ conn.commit()
+ return cursor.lastrowid
+
+def update_request_log(
+ log_id: int,
+ llm_request: Optional[Dict[str, Any]] = None,
+ llm_response: Optional[Dict[str, Any]] = None,
+ client_response: Optional[Dict[str, Any]] = None,
+):
+ """Updates a request log with the LLM request, LLM response, or client response."""
+ fields_to_update = []
+ values = []
+
+ if llm_request is not None:
+ fields_to_update.append("llm_request = ?")
+ values.append(json.dumps(llm_request))
+ if llm_response is not None:
+ fields_to_update.append("llm_response = ?")
+ values.append(json.dumps(llm_response))
+ if client_response is not None:
+ fields_to_update.append("client_response = ?")
+ values.append(json.dumps(client_response))
+
+ if not fields_to_update:
+ logger.debug(f"No fields to update for log ID {log_id}. Skipping database update.")
+ return
+
+ sql = f"UPDATE requests SET {', '.join(fields_to_update)} WHERE id = ?"
+ values.append(log_id)
+
+ try:
+ with sqlite3.connect(DATABASE_NAME) as conn:
+ cursor = conn.cursor()
+ cursor.execute(sql, tuple(values))
+ logger.debug(f"Attempting to commit update for log ID {log_id} with fields: {fields_to_update}")
+ conn.commit()
+ logger.debug(f"Successfully committed update for log ID {log_id}.")
+ except sqlite3.Error as e:
+ logger.error(f"Database error updating log ID {log_id}: {e}")
+ except Exception as e:
+ logger.error(f"An unexpected error occurred while updating log ID {log_id}: {e}")
+
+def get_latest_log_entry() -> Optional[dict]:
+ """Helper to get the full latest log entry."""
+ try:
+ with sqlite3.connect(DATABASE_NAME) as conn:
+ conn.row_factory = sqlite3.Row
+ cursor = conn.cursor()
+ cursor.execute("SELECT * FROM requests ORDER BY id DESC LIMIT 1")
+ row = cursor.fetchone()
+ if row:
+ return dict(row)
+ except sqlite3.Error as e:
+            logger.error(f"Database error: {e}")
+ return None
diff --git a/app/main.py b/app/main.py
index 7a7535f..83bcfdb 100644
--- a/app/main.py
+++ b/app/main.py
@@ -1,19 +1,19 @@
import os
import sys
-from dotenv import load_dotenv
-
-# --- Explicit Debugging & Env Loading ---
-print(f"--- [DEBUG] Current Working Directory: {os.getcwd()}", file=sys.stderr)
-load_result = load_dotenv()
-print(f"--- [DEBUG] load_dotenv() result: {load_result}", file=sys.stderr)
-# ---
-
import logging
-from fastapi import FastAPI, HTTPException, Depends
+import time
+from dotenv import load_dotenv
+from fastapi import FastAPI, HTTPException, Depends, Request
from starlette.responses import StreamingResponse
+
from .models import IncomingRequest, ProxyResponse
-from .services import process_chat_request, stream_llm_api, inject_tools_into_prompt
+from .services import process_chat_request, stream_llm_api, inject_tools_into_prompt, parse_llm_response_from_content
from .core.config import get_settings, Settings
+from .database import init_db, log_request, update_request_log
+
+# --- Environment & Debug Loading ---
+# load_dotenv() # Uncomment if you run uvicorn directly and need to load .env
+# ---
# --- Logging Configuration ---
logging.basicConfig(
@@ -33,9 +33,26 @@ app = FastAPI(
version="1.0.0",
)
+# --- Middleware for logging basic request/response info ---
+@app.middleware("http")
+async def logging_middleware(request: Request, call_next):
+ start_time = time.time()
+ logger.info(f"Request received: {request.method} {request.url.path} from {request.client.host}")
+    logger.info("Request Headers: %s", {k: v for k, v in request.headers.items() if k.lower() != "authorization"})
+
+ response = await call_next(request)
+
+ process_time = (time.time() - start_time) * 1000
+ logger.info(f"Response sent: status_code={response.status_code} ({process_time:.2f}ms)")
+ return response
+# --- End of Middleware ---
+
+
@app.on_event("startup")
async def startup_event():
logger.info("Application startup complete.")
+ init_db()
+ logger.info("Database initialized.")
current_settings = get_settings()
logger.info(f"Loaded LLM API URL: {current_settings.REAL_LLM_API_URL}")
@@ -46,34 +63,57 @@ async def chat_completions(
):
"""
This endpoint mimics the OpenAI Chat Completions API and supports both
- streaming (`stream=True`) and non-streaming (`stream=False`) responses.
+ streaming and non-streaming responses, with detailed logging.
"""
+ log_id = log_request(client_request=request.model_dump())
+ logger.info(f"Request body logged with ID: {log_id}")
+
if not settings.REAL_LLM_API_KEY or not settings.REAL_LLM_API_URL:
logger.error("REAL_LLM_API_KEY or REAL_LLM_API_URL is not configured.")
raise HTTPException(status_code=500, detail="LLM API Key or URL is not configured.")
- # Prepare messages, potentially with tool injection
- # This prepares the messages that will be sent to the LLM backend
messages_to_llm = request.messages
if request.tools:
messages_to_llm = inject_tools_into_prompt(request.messages, request.tools)
# Handle streaming request
if request.stream:
- logger.info(f"Initiating streaming request with {len(messages_to_llm)} messages.")
- generator = stream_llm_api(messages_to_llm, settings)
- return StreamingResponse(generator, media_type="text/event-stream")
+ logger.info(f"Initiating streaming request for log ID: {log_id}")
+
+ async def stream_and_log():
+ stream_content_buffer = []
+ async for chunk in stream_llm_api(messages_to_llm, settings, log_id):
+            stream_content_buffer.append(chunk.decode('utf-8', errors='ignore'))
+ yield chunk
+
+ # After the stream is complete, parse the full content and log it
+ full_content = "".join(stream_content_buffer)
+ response_message = parse_llm_response_from_content(full_content)
+ proxy_response = ProxyResponse(message=response_message)
+
+ logger.info(f"Streaming client response for log ID {log_id}:\n{proxy_response.model_dump_json(indent=2)}")
+ update_request_log(log_id, client_response=proxy_response.model_dump())
+
+ return StreamingResponse(stream_and_log(), media_type="text/event-stream")
# Handle non-streaming request
try:
- logger.info(f"Initiating non-streaming request with {len(messages_to_llm)} messages.")
- response_message = await process_chat_request(messages_to_llm, request.tools, settings)
- logger.info("Successfully processed non-streaming request.")
- return ProxyResponse(message=response_message)
+ logger.info(f"Initiating non-streaming request for log ID: {log_id}")
+ response_message = await process_chat_request(messages_to_llm, settings, log_id)
+
+ proxy_response = ProxyResponse(message=response_message)
+ logger.info(f"Response body for log ID {log_id}:\n{proxy_response.model_dump_json(indent=2)}")
+
+ # Log client response to DB
+ update_request_log(log_id, client_response=proxy_response.model_dump())
+
+ return proxy_response
except Exception as e:
- logger.exception("An unexpected error occurred during non-streaming request.")
+ logger.exception(f"An unexpected error occurred during non-streaming request for log ID: {log_id}")
+ # Log the error to the database
+ update_request_log(log_id, client_response={"error": str(e)})
raise HTTPException(status_code=500, detail=f"An unexpected error occurred: {str(e)}")
@app.get("/")
def read_root():
- return {"message": "LLM Tool Proxy is running."}
\ No newline at end of file
+ return {"message": "LLM Tool Proxy is running."}
diff --git a/app/models.py b/app/models.py
index 4931088..8f705b3 100644
--- a/app/models.py
+++ b/app/models.py
@@ -7,10 +7,16 @@ class ChatMessage(BaseModel):
role: str
content: str
+class Function(BaseModel):
+ """Represents the function definition within a tool."""
+ name: str
+ description: str
+ parameters: Dict[str, Any]
+
class Tool(BaseModel):
"""Represents a tool definition provided by the user."""
type: str
- function: Dict[str, Any]
+ function: Function
class IncomingRequest(BaseModel):
"""Defines the structure of the request from the client."""
diff --git a/app/response_parser.py b/app/response_parser.py
new file mode 100644
index 0000000..275e4db
--- /dev/null
+++ b/app/response_parser.py
@@ -0,0 +1,326 @@
+"""
+Response Parser Module
+
+This module provides low-coupling, high-cohesion parsing utilities for extracting
+tool calls from LLM responses and converting them to OpenAI-compatible format.
+
+Design principles:
+- Single Responsibility: Each function handles one specific parsing task
+- Testability: Pure functions that are easy to unit test
+- Type Safety: Uses Pydantic models for validation
+"""
+
+import re
+import json
+import logging
+from typing import Optional, List, Dict, Any
+from uuid import uuid4
+
+from app.models import ResponseMessage, ToolCall, ToolCallFunction
+
+logger = logging.getLogger(__name__)
+
+
+# Constants for tool call parsing
+# Tool calls are delimited by plain JSON braces (not XML-style tags);
+# the LLM should emit: {"name": "...", "arguments": {...}}
+TOOL_CALL_START_TAG = "{"
+TOOL_CALL_END_TAG = "}"
+
+
+class ToolCallParseError(Exception):
+ """Raised when tool call parsing fails."""
+ pass
+
+
+class ResponseParser:
+ """
+ Parser for converting LLM text responses into structured ResponseMessage objects.
+
+ This class encapsulates all parsing logic for tool calls, making it easy to test
+ and maintain. It follows the Single Responsibility Principle by focusing solely
+ on parsing responses.
+ """
+
+ def __init__(self, tool_call_start_tag: str = TOOL_CALL_START_TAG,
+ tool_call_end_tag: str = TOOL_CALL_END_TAG):
+ """
+ Initialize the parser with configurable tags.
+
+ Args:
+            tool_call_start_tag: The opening tag for tool calls (default: "{")
+            tool_call_end_tag: The closing tag for tool calls (default: "}")
+ """
+ self.tool_call_start_tag = tool_call_start_tag
+ self.tool_call_end_tag = tool_call_end_tag
+ self._compile_regex()
+
+ def _compile_regex(self):
+ """Compile the regex pattern for tool call extraction."""
+ # Escape special regex characters in the tags
+ escaped_start = re.escape(self.tool_call_start_tag)
+ escaped_end = re.escape(self.tool_call_end_tag)
+ # Match from start tag to end tag (greedy), including both tags
+ # This ensures we capture the complete JSON object
+ self._tool_call_pattern = re.compile(
+ f"{escaped_start}.*{escaped_end}",
+ re.DOTALL
+ )
+
+ def _extract_valid_json(self, text: str) -> Optional[str]:
+ """
+ Extract a valid JSON object from text that may contain extra content.
+
+ This handles cases where non-greedy regex matching includes incomplete JSON.
+
+ Args:
+ text: Text that should contain a JSON object
+
+ Returns:
+ The extracted valid JSON string, or None if not found
+ """
+ text = text.lstrip() # Only strip leading whitespace
+
+ # Find the first opening brace (the start of JSON)
+ start_idx = text.find('{')
+ if start_idx < 0:
+ return None
+
+ text = text[start_idx:] # Start from the first opening brace
+
+ # Find the matching closing brace by counting brackets
+ brace_count = 0
+ in_string = False
+ escape_next = False
+
+ for i, char in enumerate(text):
+ if escape_next:
+ escape_next = False
+ continue
+
+ if char == '\\' and in_string:
+ escape_next = True
+ continue
+
+ if char == '"':
+ in_string = not in_string
+ continue
+
+ if not in_string:
+ if char == '{':
+ brace_count += 1
+ elif char == '}':
+ brace_count -= 1
+ if brace_count == 0:
+ # Found matching closing brace
+ return text[:i+1]
+
+ return None
+
+ def parse(self, llm_response: str) -> ResponseMessage:
+ """
+ Parse an LLM response and extract tool calls if present.
+
+ This is the main entry point for parsing. It handles both:
+ 1. Responses with tool calls (wrapped in tags)
+ 2. Regular text responses
+
+ Args:
+ llm_response: The raw text response from the LLM
+
+ Returns:
+ ResponseMessage with content and optionally tool_calls
+
+ Example:
+ >>> parser = ResponseParser()
+ >>> response = parser.parse('Hello world')
+ >>> response.content
+ 'Hello world'
+
+ >>> response = parser.parse('Check the weather.{"name": "weather", "arguments": {...}}')
+ >>> response.tool_calls[0].function.name
+ 'weather'
+ """
+ if not llm_response:
+ return ResponseMessage(content=None)
+
+ try:
+ match = self._tool_call_pattern.search(llm_response)
+
+ if match:
+ return self._parse_tool_call_response(llm_response, match)
+ else:
+ return self._parse_text_only_response(llm_response)
+
+ except Exception as e:
+ logger.warning(f"Failed to parse LLM response: {e}. Returning as text.")
+ return ResponseMessage(content=llm_response)
+
+ def _parse_tool_call_response(self, llm_response: str, match: re.Match) -> ResponseMessage:
+ """
+ Parse a response that contains tool calls.
+
+ Args:
+ llm_response: The full LLM response
+ match: The regex match object containing the tool call
+
+ Returns:
+ ResponseMessage with content and tool_calls
+ """
+ # The match includes start and end tags, so strip them
+ matched_text = match.group(0)
+ tool_call_str = matched_text[len(self.tool_call_start_tag):-len(self.tool_call_end_tag)]
+
+        # Extract valid JSON by finding matching braces. NOTE(review): with single-brace tags, stripping them above removes the JSON's own braces, so this can latch onto the inner "arguments" object — confirm the LLM emits the double-brace format.
+ json_str = self._extract_valid_json(tool_call_str)
+ if json_str is None:
+ # Fallback to trying to parse the entire string
+ json_str = tool_call_str
+
+ try:
+ tool_call_data = json.loads(json_str)
+
+ # Extract content before the tool call tag
+ parts = llm_response.split(self.tool_call_start_tag, 1)
+ content = parts[0].strip() if parts[0] else None
+
+ # Create the tool call object
+ tool_call = self._create_tool_call(tool_call_data)
+
+ return ResponseMessage(
+ content=content,
+ tool_calls=[tool_call]
+ )
+
+ except json.JSONDecodeError as e:
+ raise ToolCallParseError(f"Invalid JSON in tool call: {tool_call_str}. Error: {e}")
+
+ def _parse_text_only_response(self, llm_response: str) -> ResponseMessage:
+ """
+ Parse a response with no tool calls.
+
+ Args:
+ llm_response: The full LLM response
+
+ Returns:
+ ResponseMessage with content only
+ """
+ return ResponseMessage(content=llm_response.strip())
+
+ def _create_tool_call(self, tool_call_data: Dict[str, Any]) -> ToolCall:
+ """
+ Create a ToolCall object from parsed data.
+
+ Args:
+ tool_call_data: Dictionary containing 'name' and optionally 'arguments'
+
+ Returns:
+ ToolCall object
+
+ Raises:
+ ToolCallParseError: If required fields are missing
+ """
+ name = tool_call_data.get("name")
+ if not name:
+ raise ToolCallParseError("Tool call missing 'name' field")
+
+ arguments = tool_call_data.get("arguments", {})
+
+ # Generate a unique ID for the tool call
+ tool_call_id = f"call_{name}_{str(uuid4())[:8]}"
+
+ return ToolCall(
+ id=tool_call_id,
+ type="function",
+ function=ToolCallFunction(
+ name=name,
+ arguments=json.dumps(arguments)
+ )
+ )
+
+ def parse_streaming_chunks(self, chunks: List[str]) -> ResponseMessage:
+ """
+ Parse a list of streaming chunks and aggregate into a ResponseMessage.
+
+ This method handles streaming responses where tool calls might be
+ split across multiple chunks.
+
+ Args:
+ chunks: List of content chunks from streaming response
+
+ Returns:
+ Parsed ResponseMessage
+ """
+ full_content = "".join(chunks)
+ return self.parse(full_content)
+
+ def parse_native_tool_calls(self, llm_response: Dict[str, Any]) -> ResponseMessage:
+ """
+ Parse a response that already has native OpenAI-format tool calls.
+
+ Some LLMs natively support tool calling and return them in the standard
+ OpenAI format. This method handles those responses.
+
+ Args:
+ llm_response: Dictionary response from LLM with potential tool_calls field
+
+ Returns:
+ ResponseMessage with parsed tool_calls or content
+ """
+ if "tool_calls" in llm_response and llm_response["tool_calls"]:
+ # Parse native tool calls
+ tool_calls = []
+ for tc in llm_response["tool_calls"]:
+ tool_calls.append(ToolCall(
+ id=tc.get("id", f"call_{str(uuid4())[:8]}"),
+ type=tc.get("type", "function"),
+ function=ToolCallFunction(
+ name=tc["function"]["name"],
+ arguments=tc["function"]["arguments"]
+ )
+ ))
+
+ return ResponseMessage(
+ content=llm_response.get("content"),
+ tool_calls=tool_calls
+ )
+ else:
+ # Fallback to text parsing
+ content = llm_response.get("content", "")
+ return self.parse(content)
+
+
+# Convenience functions for backward compatibility and ease of use
+
+def parse_response(llm_response: str) -> ResponseMessage:
+ """
+ Parse an LLM response using default parser settings.
+
+ This is a convenience function for simple use cases.
+
+ Args:
+ llm_response: The raw text response from the LLM
+
+ Returns:
+ ResponseMessage with parsed content and tool calls
+ """
+ parser = ResponseParser()
+ return parser.parse(llm_response)
+
+
+def parse_response_with_custom_tags(llm_response: str,
+ start_tag: str,
+ end_tag: str) -> ResponseMessage:
+ """
+ Parse an LLM response using custom tool call tags.
+
+ Args:
+ llm_response: The raw text response from the LLM
+ start_tag: Custom start tag for tool calls
+ end_tag: Custom end tag for tool calls
+
+ Returns:
+ ResponseMessage with parsed content and tool calls
+ """
+ parser = ResponseParser(tool_call_start_tag=start_tag, tool_call_end_tag=end_tag)
+ return parser.parse(llm_response)
diff --git a/app/services.py b/app/services.py
index 8ee47f1..fb89bee 100644
--- a/app/services.py
+++ b/app/services.py
@@ -6,6 +6,8 @@ from typing import List, Dict, Any, Tuple, Optional, AsyncGenerator
from .models import ChatMessage, Tool, ResponseMessage, ToolCall, ToolCallFunction
from .core.config import Settings
+from .database import update_request_log
+from .response_parser import ResponseParser, parse_response
# Get a logger instance for this module
logger = logging.getLogger(__name__)
@@ -39,160 +41,139 @@ def _parse_sse_data(chunk: bytes) -> Optional[Dict[str, Any]]:
def inject_tools_into_prompt(messages: List[ChatMessage], tools: List[Tool]) -> List[ChatMessage]:
"""
- Injects tool definitions into the message list as a system prompt.
+ Injects a system prompt with tool definitions at the beginning of the message list.
"""
+ from .response_parser import TOOL_CALL_START_TAG, TOOL_CALL_END_TAG
+
tool_defs = json.dumps([tool.model_dump() for tool in tools], indent=2)
+
+ # Build the format example separately to avoid f-string escaping issues
+ # We need to show double braces: outer {{ }} are tags, inner { } is JSON
+ json_example = '{"name": "search", "arguments": {"query": "example"}}'
+ full_example = f'{{{json_example}}}'
+
tool_prompt = f"""
-You have access to a set of tools. You can call them by emitting a JSON object inside a XML tag.
-The JSON object should have a "name" and "arguments" field.
+You are a helpful assistant with access to a set of tools.
+You can call them by emitting a JSON object inside tool call tags.
+
+IMPORTANT: Use double braces for tool calls - the outer braces are the tags ({TOOL_CALL_START_TAG} and {TOOL_CALL_END_TAG}), the inner braces are the JSON.
+Format: {TOOL_CALL_START_TAG}{{\"name\": \"tool_name\", \"arguments\": {{...}}}}{TOOL_CALL_END_TAG}
+
+Example: {full_example}
Here are the available tools:
{tool_defs}
Only use the tools if strictly necessary.
"""
- new_messages = messages.copy()
- new_messages.insert(1, ChatMessage(role="system", content=tool_prompt))
- return new_messages
+ # Prepend the system prompt with tool definitions
+ return [ChatMessage(role="system", content=tool_prompt)] + messages
def parse_llm_response_from_content(text: str) -> ResponseMessage:
"""
(Fallback) Parses the raw LLM text response to extract a message and any tool calls.
This is used when the LLM does not support native tool calling.
+
+ This function now delegates to the ResponseParser class for better maintainability.
"""
- if not text:
- return ResponseMessage(content=None)
-
- tool_call_match = re.search(r"(.*?)", text, re.DOTALL)
-
- if tool_call_match:
- tool_call_str = tool_call_match.group(1).strip()
- try:
- tool_call_data = json.loads(tool_call_str)
- tool_call = ToolCall(
- id="call_" + tool_call_data.get("name", "unknown"),
- function=ToolCallFunction(
- name=tool_call_data.get("name"),
- arguments=json.dumps(tool_call_data.get("arguments", {})),
- )
- )
- content_before = text.split("")[0].strip()
- return ResponseMessage(content=content_before if content_before else None, tool_calls=[tool_call])
- except json.JSONDecodeError as e:
- logger.warning(f"Failed to parse tool call JSON from content: {tool_call_str}. Error: {e}")
- return ResponseMessage(content=text)
- else:
- return ResponseMessage(content=text)
+ parser = ResponseParser()
+ return parser.parse(text)
-async def _raw_stream_from_llm(messages: List[ChatMessage], settings: Settings) -> AsyncGenerator[bytes, None]:
+async def _raw_stream_from_llm(messages: List[ChatMessage], settings: Settings, log_id: int) -> AsyncGenerator[bytes, None]:
"""
Makes the raw HTTP streaming call to the LLM backend.
Yields raw byte chunks as received.
"""
headers = { "Authorization": f"Bearer {settings.REAL_LLM_API_KEY}", "Content-Type": "application/json" }
payload = { "model": "default-model", "messages": [msg.model_dump() for msg in messages], "stream": True }
-
+
+ # Log the request payload to the database
+ update_request_log(log_id, llm_request=payload)
+
try:
async with httpx.AsyncClient() as client:
- logger.info(f"Initiating raw stream to LLM API at {settings.REAL_LLM_API_URL}")
+ logger.info(f"Initiating raw stream to LLM API for log ID {log_id} at {settings.REAL_LLM_API_URL}")
async with client.stream("POST", settings.REAL_LLM_API_URL, headers=headers, json=payload, timeout=60.0) as response:
response.raise_for_status()
async for chunk in response.aiter_bytes():
yield chunk
except httpx.HTTPStatusError as e:
- logger.error(f"LLM API returned an error during raw stream: {e.response.status_code}, response: '{e.response.text}'")
- # For streams, we log and let the stream terminate. The client will get a broken stream.
+ error_message = f"LLM API returned an error during raw stream: {e.response.status_code}, response: '{e.response.text}'"
+ logger.error(f"{error_message} for log ID {log_id}")
+ update_request_log(log_id, llm_response={"error": error_message})
yield b'data: {"error": "LLM API Error", "status_code": ' + str(e.response.status_code).encode() + b'}\n\n'
except httpx.RequestError as e:
- logger.error(f"An error occurred during raw stream request to LLM API: {e}")
+ error_message = f"An error occurred during raw stream request to LLM API: {e}"
+ logger.error(f"{error_message} for log ID {log_id}")
+ update_request_log(log_id, llm_response={"error": error_message})
yield b'data: {"error": "Network Error", "details": "' + str(e).encode() + b'"}\n\n'
-async def stream_llm_api(messages: List[ChatMessage], settings: Settings) -> AsyncGenerator[bytes, None]:
+async def stream_llm_api(messages: List[ChatMessage], settings: Settings, log_id: int) -> AsyncGenerator[bytes, None]:
"""
- Public interface for streaming. Calls the raw stream, parses SSE, and yields SSE data chunks.
+ Public interface for streaming. Calls the raw stream, logs the full response, and yields chunks.
"""
- async for chunk in _raw_stream_from_llm(messages, settings):
- # We assume the raw chunks are already SSE formatted or can be split into lines.
- # For simplicity, we pass through the raw chunk bytes.
- # A more robust parser would ensure each yield is a complete SSE event line.
+ llm_response_chunks = []
+ async for chunk in _raw_stream_from_llm(messages, settings, log_id):
+ llm_response_chunks.append(chunk.decode('utf-8', errors='ignore'))
+ try:
+ logger.info(f"Streaming chunk for log ID {log_id}: {chunk.decode('utf-8').strip()}")
+ except UnicodeDecodeError:
+ logger.info(f"Streaming chunk (undecodable) for log ID {log_id}: {chunk}")
yield chunk
+ # Log the full LLM response after the stream is complete
+ update_request_log(log_id, llm_response={"content": "".join(llm_response_chunks)})
+
async def process_llm_stream_for_non_stream_request(
- messages: List[ChatMessage],
- settings: Settings
+ messages: List[ChatMessage],
+ settings: Settings,
+ log_id: int
) -> Dict[str, Any]:
"""
Aggregates a streaming LLM response into a single, non-streaming message.
- Handles SSE parsing and delta accumulation.
+ Handles SSE parsing, delta accumulation, and logs the final aggregated message.
"""
full_content_parts = []
final_message_dict: Dict[str, Any] = {"role": "assistant", "content": None}
-
- async for chunk in _raw_stream_from_llm(messages, settings):
+ llm_response_chunks = []
+
+ async for chunk in _raw_stream_from_llm(messages, settings, log_id):
+ llm_response_chunks.append(chunk.decode('utf-8', errors='ignore'))
parsed_data = _parse_sse_data(chunk)
if parsed_data:
if parsed_data.get("type") == "done":
- break # End of stream
-
- # Assuming OpenAI-like streaming format
+ break
+
choices = parsed_data.get("choices")
if choices and len(choices) > 0:
delta = choices[0].get("delta")
- if delta:
- if "content" in delta:
- full_content_parts.append(delta["content"])
- if "tool_calls" in delta:
- # Accumulate tool calls if they appear in deltas (complex)
- # For simplicity, we'll try to reconstruct the final tool_calls
- # from the final message, or fall back to content parsing later.
- # This part is highly dependent on LLM's exact streaming format for tool_calls.
- pass
- if choices[0].get("finish_reason"):
- # Check for finish_reason to identify stream end or tool_calls completion
- pass
-
+ if delta and "content" in delta:
+ full_content_parts.append(delta["content"])
+
final_message_dict["content"] = "".join(full_content_parts) if full_content_parts else None
- # This is a simplification. Reconstructing tool_calls from deltas is non-trivial.
- # We will rely on parse_llm_response_from_content for tool calls if they are
- # embedded in the final content string, or assume the LLM doesn't send native
- # tool_calls in stream deltas that need aggregation here.
- logger.info(f"Aggregated non-streaming response content: {final_message_dict.get('content')}")
-
+ # Log the aggregated LLM response
+ logger.info(f"Aggregated non-streaming response content for log ID {log_id}: {final_message_dict.get('content')}")
+ update_request_log(log_id, llm_response=final_message_dict)
+
return final_message_dict
async def process_chat_request(
- messages: List[ChatMessage],
- tools: Optional[List[Tool]],
+ messages: List[ChatMessage],
settings: Settings,
+ log_id: int
) -> ResponseMessage:
"""
Main service function for non-streaming requests.
- It now calls the stream aggregation logic.
+ It calls the stream aggregation logic and then parses the result.
"""
- request_messages = messages
- if tools:
- request_messages = inject_tools_into_prompt(messages, tools)
+ llm_message_dict = await process_llm_stream_for_non_stream_request(messages, settings, log_id)
- # All interactions with the real LLM now go through the streaming mechanism.
- llm_message_dict = await process_llm_stream_for_non_stream_request(request_messages, settings)
-
- # Priority 1: Check for native tool calls (if the aggregation could reconstruct them)
- # Note: Reconstructing tool_calls from deltas in streaming is complex.
- # For now, we assume if tool_calls are present, they are complete.
- if llm_message_dict.get("tool_calls"):
- logger.info("Native tool calls detected in aggregated LLM response.")
- # Ensure it's a list of dicts suitable for Pydantic validation
- if isinstance(llm_message_dict["tool_calls"], list):
- return ResponseMessage.model_validate(llm_message_dict)
- else:
- logger.warning("Aggregated tool_calls not in expected list format. Treating as content.")
-
- # Priority 2 (Fallback): Parse tool calls from content
- logger.info("No native tool calls from aggregation. Falling back to content parsing.")
- return parse_llm_response_from_content(llm_message_dict.get("content"))
\ No newline at end of file
+ # Use the ResponseParser to handle both native and text-based tool calls
+ parser = ResponseParser()
+ return parser.parse_native_tool_calls(llm_message_dict)
diff --git a/requirements.txt b/requirements.txt
new file mode 100644
index 0000000..2b7a354
--- /dev/null
+++ b/requirements.txt
@@ -0,0 +1,8 @@
+python-dotenv
+pydantic
+requests
+fastapi
+uvicorn
+httpx
+starlette
+pytest
diff --git a/tests/test_response_parser.py b/tests/test_response_parser.py
new file mode 100644
index 0000000..b549535
--- /dev/null
+++ b/tests/test_response_parser.py
@@ -0,0 +1,375 @@
+"""
+Unit tests for the Response Parser module.
+
+Tests cover:
+- Parsing text-only responses
+- Parsing responses with tool calls
+- Parsing native OpenAI-format tool calls
+- Parsing streaming chunks
+- Error handling and edge cases
+"""
+
+import pytest
+import json
+from app.response_parser import (
+ ResponseParser,
+ ToolCallParseError,
+ parse_response,
+ parse_response_with_custom_tags,
+ TOOL_CALL_START_TAG,
+ TOOL_CALL_END_TAG
+)
+from app.models import ToolCall, ToolCallFunction
+
+
+class TestResponseParser:
+ """Test suite for ResponseParser class."""
+
+ def test_parse_text_only_response(self):
+ """Test parsing a response with no tool calls."""
+ parser = ResponseParser()
+ text = "Hello, this is a simple response."
+ result = parser.parse(text)
+
+ assert result.content == text
+ assert result.tool_calls is None
+
+ def test_parse_empty_response(self):
+ """Test parsing an empty response."""
+ parser = ResponseParser()
+ result = parser.parse("")
+
+ assert result.content is None
+ assert result.tool_calls is None
+
+ def test_parse_response_with_tool_call(self):
+ """Test parsing a response with a single tool call."""
+ parser = ResponseParser()
+ text = f'''I'll check the weather for you.
+{TOOL_CALL_START_TAG}
+{{
+ "name": "get_weather",
+ "arguments": {{
+ "location": "San Francisco",
+ "units": "celsius"
+ }}
+}}
+{TOOL_CALL_END_TAG}
+'''
+
+ result = parser.parse(text)
+
+ assert result.content == "I'll check the weather for you."
+ assert result.tool_calls is not None
+ assert len(result.tool_calls) == 1
+
+ tool_call = result.tool_calls[0]
+ assert tool_call.type == "function"
+ assert tool_call.function.name == "get_weather"
+
+ arguments = json.loads(tool_call.function.arguments)
+ assert arguments["location"] == "San Francisco"
+ assert arguments["units"] == "celsius"
+
+ def test_parse_response_with_tool_call_no_content(self):
+ """Test parsing a response with only a tool call."""
+ parser = ResponseParser()
+ text = f'''{TOOL_CALL_START_TAG}
+{{
+ "name": "shell",
+ "arguments": {{
+ "command": ["ls", "-l"]
+ }}
+}}
+{TOOL_CALL_END_TAG}
+'''
+
+ result = parser.parse(text)
+
+ assert result.content is None
+ assert result.tool_calls is not None
+ assert len(result.tool_calls) == 1
+ assert result.tool_calls[0].function.name == "shell"
+
+ def test_parse_response_with_malformed_tool_call(self):
+ """Test parsing a response with malformed JSON in tool call."""
+ parser = ResponseParser()
+ text = f'''Here's the result.
+{TOOL_CALL_START_TAG}
+{{invalid json}}
+{TOOL_CALL_END_TAG}
+'''
+
+ result = parser.parse(text)
+
+ # Should fall back to treating it as text
+ assert result.content == text
+ assert result.tool_calls is None
+
+ def test_parse_response_with_missing_tool_name(self):
+ """Test parsing a tool call without a name field."""
+ parser = ResponseParser()
+ text = f'''{TOOL_CALL_START_TAG}
+{{
+ "arguments": {{
+ "command": "echo hello"
+ }}
+}}
+{TOOL_CALL_END_TAG}
+'''
+
+ result = parser.parse(text)
+
+ # Should handle gracefully - when name is missing, ToolCallParseError is raised
+ # and caught, falling back to treating as text content
+ # content will be the text between start and end tags (the JSON object)
+ assert result.content is not None
+
+ def test_parse_response_with_complex_arguments(self):
+ """Test parsing a tool call with complex nested arguments."""
+ parser = ResponseParser()
+ text = f'''Executing command.
+{TOOL_CALL_START_TAG}
+{{
+ "name": "shell",
+ "arguments": {{
+ "command": ["bash", "-lc", "echo 'hello world' && ls -la"],
+ "timeout": 5000,
+ "env": {{
+ "PATH": "/usr/bin"
+ }}
+ }}
+}}
+{TOOL_CALL_END_TAG}
+'''
+
+ result = parser.parse(text)
+
+ assert result.content == "Executing command."
+ assert result.tool_calls is not None
+
+ arguments = json.loads(result.tool_calls[0].function.arguments)
+ assert arguments["command"] == ["bash", "-lc", "echo 'hello world' && ls -la"]
+ assert arguments["timeout"] == 5000
+ assert arguments["env"]["PATH"] == "/usr/bin"
+
+ def test_parse_with_custom_tags(self):
+ """Test parsing with custom start and end tags."""
+ parser = ResponseParser(
+ tool_call_start_tag="<my_tool_call>",
+ tool_call_end_tag="</my_tool_call>"
+ )
+ text = """I'll help you with that.
+<my_tool_call>
+{
+ "name": "search",
+ "arguments": {
+ "query": "python tutorials"
+ }
+}
+</my_tool_call>
+"""
+
+ result = parser.parse(text)
+
+ assert "I'll help you with that" in result.content
+ assert result.tool_calls is not None
+ assert result.tool_calls[0].function.name == "search"
+
+ def test_parse_streaming_chunks(self):
+ """Test parsing aggregated streaming chunks."""
+ parser = ResponseParser()
+ chunks = [
+ "I'll run that ",
+ "command for you.",
+ f'{TOOL_CALL_START_TAG}\n{{"name": "shell", "arguments": {{"command": ["echo", "hello"]}}}}\n{TOOL_CALL_END_TAG}'
+ ]
+
+ result = parser.parse_streaming_chunks(chunks)
+
+ assert "I'll run that command for you" in result.content
+ assert result.tool_calls is not None
+ assert result.tool_calls[0].function.name == "shell"
+
+ def test_parse_native_tool_calls(self):
+ """Test parsing a native OpenAI-format response with tool calls."""
+ parser = ResponseParser()
+ llm_response = {
+ "role": "assistant",
+ "content": "I'll execute that command.",
+ "tool_calls": [
+ {
+ "id": "call_abc123",
+ "type": "function",
+ "function": {
+ "name": "shell",
+ "arguments": '{"command": ["ls", "-l"]}'
+ }
+ }
+ ]
+ }
+
+ result = parser.parse_native_tool_calls(llm_response)
+
+ assert result.content == "I'll execute that command."
+ assert result.tool_calls is not None
+ assert len(result.tool_calls) == 1
+ assert result.tool_calls[0].id == "call_abc123"
+ assert result.tool_calls[0].function.name == "shell"
+
+ def test_parse_native_tool_calls_multiple(self):
+ """Test parsing a response with multiple native tool calls."""
+ parser = ResponseParser()
+ llm_response = {
+ "role": "assistant",
+ "content": None,
+ "tool_calls": [
+ {
+ "id": "call_1",
+ "type": "function",
+ "function": {
+ "name": "shell",
+ "arguments": '{"command": ["pwd"]}'
+ }
+ },
+ {
+ "id": "call_2",
+ "type": "function",
+ "function": {
+ "name": "shell",
+ "arguments": '{"command": ["ls", "-la"]}'
+ }
+ }
+ ]
+ }
+
+ result = parser.parse_native_tool_calls(llm_response)
+
+ assert result.tool_calls is not None
+ assert len(result.tool_calls) == 2
+ assert result.tool_calls[0].id == "call_1"
+ assert result.tool_calls[1].id == "call_2"
+
+ def test_parse_native_tool_calls_falls_back_to_text(self):
+ """Test that native parsing falls back to text parsing when no tool_calls."""
+ parser = ResponseParser()
+ llm_response = {
+ "role": "assistant",
+ "content": "This is a simple text response."
+ }
+
+ result = parser.parse_native_tool_calls(llm_response)
+
+ assert result.content == "This is a simple text response."
+ assert result.tool_calls is None
+
+ def test_generate_unique_tool_call_ids(self):
+ """Test that tool call IDs are unique."""
+ parser = ResponseParser()
+
+ text1 = f'{TOOL_CALL_START_TAG}{{"name": "tool1", "arguments": {{}}}}{TOOL_CALL_END_TAG}'
+ text2 = f'{TOOL_CALL_START_TAG}{{"name": "tool2", "arguments": {{}}}}{TOOL_CALL_END_TAG}'
+
+ result1 = parser.parse(text1)
+ result2 = parser.parse(text2)
+
+ id1 = result1.tool_calls[0].id
+ id2 = result2.tool_calls[0].id
+
+ assert id1 != id2
+ assert id1.startswith("call_tool1_")
+ assert id2.startswith("call_tool2_")
+
+
+class TestConvenienceFunctions:
+ """Test suite for convenience functions."""
+
+ def test_parse_response_default_parser(self):
+ """Test the parse_response convenience function."""
+ text = f'{TOOL_CALL_START_TAG}{{"name": "search", "arguments": {{"query": "test"}}}}{TOOL_CALL_END_TAG}'
+ result = parse_response(text)
+
+ assert result.tool_calls is not None
+ assert result.tool_calls[0].function.name == "search"
+
+ def test_parse_response_with_custom_tags_function(self):
+ """Test the parse_response_with_custom_tags function."""
+ text = """[CALL]
+{"name": "test", "arguments": {}}
+[/CALL]"""
+ result = parse_response_with_custom_tags(
+ text,
+ start_tag="[CALL]",
+ end_tag="[/CALL]"
+ )
+
+ assert result.tool_calls is not None
+ assert result.tool_calls[0].function.name == "test"
+
+
+class TestEdgeCases:
+ """Test edge cases and error conditions."""
+
+ def test_response_with_whitespace(self):
+ """Test parsing responses with various whitespace patterns."""
+ parser = ResponseParser()
+
+ # Leading/trailing whitespace
+ text = " Hello world. "
+ result = parser.parse(text)
+ assert result.content.strip() == "Hello world."
+
+ def test_response_with_newlines_only(self):
+ """Test parsing a response with only newlines."""
+ parser = ResponseParser()
+ result = parser.parse("\n\n\n")
+
+ assert result.content == ""
+ assert result.tool_calls is None
+
+ def test_response_with_special_characters(self):
+ """Test parsing responses with special characters in content."""
+ parser = ResponseParser()
+ special_chars = '@#$%^&*()'
+ text = f'''Here's the result with special chars: {special_chars}
+{TOOL_CALL_START_TAG}
+{{
+ "name": "test",
+ "arguments": {{
+ "special": "!@#$%"
+ }}
+}}
+{TOOL_CALL_END_TAG}
+'''
+
+ result = parser.parse(text)
+ assert "@" in result.content
+ assert result.tool_calls is not None
+
+ def test_response_with_escaped_quotes(self):
+ """Test parsing tool calls with escaped quotes in arguments."""
+ parser = ResponseParser()
+ text = f'{TOOL_CALL_START_TAG}{{"name": "echo", "arguments": {{"message": "Hello \\"world\\""}}}}{TOOL_CALL_END_TAG}'
+
+ result = parser.parse(text)
+ arguments = json.loads(result.tool_calls[0].function.arguments)
+ assert arguments["message"] == 'Hello "world"'
+
+ def test_multiple_tool_calls_in_text_finds_first(self):
+ """Test that only the first tool call is extracted."""
+ parser = ResponseParser()
+ text = f'''First call.
+{TOOL_CALL_START_TAG}
+{{"name": "tool1", "arguments": {{}}}}
+{TOOL_CALL_END_TAG}
+Some text in between.
+{TOOL_CALL_START_TAG}
+{{"name": "tool2", "arguments": {{}}}}
+{TOOL_CALL_END_TAG}
+'''
+
+ result = parser.parse(text)
+
+ # Should only find the first one
+ assert len(result.tool_calls) == 1
+ assert result.tool_calls[0].function.name == "tool1"
diff --git a/tests/test_services.py b/tests/test_services.py
index e88a3b3..06af6d7 100644
--- a/tests/test_services.py
+++ b/tests/test_services.py
@@ -1,23 +1,15 @@
import pytest
-import httpx
import json
+import httpx
from typing import List, AsyncGenerator
-from app.services import call_llm_api_real
-from app.models import ChatMessage
+from app.services import inject_tools_into_prompt, parse_llm_response_from_content, process_chat_request
+from app.models import ChatMessage, Tool, ResponseMessage, ToolCall, ToolCallFunction, IncomingRequest
from app.core.config import Settings
+from app.database import get_latest_log_entry
-# Sample SSE chunks to simulate a streaming response
-SSE_STREAM_CHUNKS = [
- 'data: {"choices": [{"delta": {"role": "assistant", "content": "Hello"}}]}',
- 'data: {"choices": [{"delta": {"content": " world!"}}]}',
- 'data: {"choices": [{"delta": {"tool_calls": [{"index": 0, "id": "call_123", "function": {"name": "get_weather", "arguments": ""}}]}}]}',
- 'data: {"choices": [{"delta": {"tool_calls": [{"index": 0, "function": {"arguments": "{\\"location\\":"}}]}}]}',
- 'data: {"choices": [{"delta": {"tool_calls": [{"index": 0, "function": {"arguments": " \\"San Francisco\\"}"}}]}}]}',
- 'data: [DONE]',
-]
+# --- Mocks for simulating httpx responses ---
-# Mock settings for the test
@pytest.fixture
def mock_settings() -> Settings:
"""Provides mock settings for tests."""
@@ -26,71 +18,143 @@ def mock_settings() -> Settings:
REAL_LLM_API_KEY="fake-key"
)
-# Async generator to mock the streaming response
-async def mock_aiter_lines() -> AsyncGenerator[str, None]:
- for chunk in SSE_STREAM_CHUNKS:
- yield chunk
+class MockAsyncClient:
+ """Mocks the httpx.AsyncClient to simulate LLM responses."""
+ def __init__(self, response_chunks: List[str]):
+ self._response_chunks = response_chunks
+
+ async def __aenter__(self):
+ return self
+
+ async def __aexit__(self, exc_type, exc_val, exc_tb):
+ pass
+
+ def stream(self, method, url, headers, json, timeout):
+ return MockStreamResponse(self._response_chunks)
-# Mock for the httpx.Response object
class MockStreamResponse:
- def __init__(self, status_code: int = 200):
+ """Mocks the httpx.Response object for streaming."""
+ def __init__(self, chunks: List[str], status_code: int = 200):
+ self._chunks = chunks
self._status_code = status_code
+ async def __aenter__(self):
+ return self
+
+ async def __aexit__(self, exc_type, exc_val, exc_tb):
+ pass
+
def raise_for_status(self):
if self._status_code != 200:
- raise httpx.HTTPStatusError(
- message="Error", request=httpx.Request("POST", ""), response=httpx.Response(self._status_code)
- )
+ raise httpx.HTTPStatusError("Error", request=None, response=httpx.Response(self._status_code))
- def aiter_lines(self) -> AsyncGenerator[str, None]:
- return mock_aiter_lines()
-
- async def __aenter__(self):
- return self
-
- async def __aexit__(self, exc_type, exc_val, exc_tb):
- pass
+ async def aiter_bytes(self) -> AsyncGenerator[bytes, None]:
+ for chunk in self._chunks:
+ yield chunk.encode('utf-8')
-# Mock for the httpx.AsyncClient
-class MockAsyncClient:
- def stream(self, method, url, headers, json, timeout):
- return MockStreamResponse()
+# --- End Mocks ---
- async def __aenter__(self):
- return self
-
- async def __aexit__(self, exc_type, exc_val, exc_tb):
- pass
-@pytest.mark.anyio
-async def test_call_llm_api_real_streaming(monkeypatch, mock_settings):
+def test_inject_tools_into_prompt():
"""
- Tests that `call_llm_api_real` correctly handles an SSE stream,
- parses the chunks, and assembles the final response message.
+ Tests that `inject_tools_into_prompt` correctly adds a system message
+ with tool definitions to the message list.
"""
- # Patch httpx.AsyncClient to use our mock
- monkeypatch.setattr(httpx, "AsyncClient", MockAsyncClient)
+ # 1. Fetch the latest request from the database
+ latest_entry = get_latest_log_entry()
+ assert latest_entry is not None
+ client_request_data = json.loads(latest_entry["client_request"])
- messages = [ChatMessage(role="user", content="What is the weather in San Francisco?")]
+ # 2. Parse the data into Pydantic models
+ incoming_request = IncomingRequest.model_validate(client_request_data)
+
+ # 3. Call the function to be tested
+ modified_messages = inject_tools_into_prompt(incoming_request.messages, incoming_request.tools)
+
+ # 4. Assert the results
+ assert len(modified_messages) == len(incoming_request.messages) + 1
+
+ # Check that the first message is the new system prompt
+ system_prompt = modified_messages[0]
+ assert system_prompt.role == "system"
+ assert "You are a helpful assistant with access to a set of tools." in system_prompt.content
+
+ # Check that the tool definitions are in the system prompt
+ for tool in incoming_request.tools:
+ assert tool.function.name in system_prompt.content
+
+def test_parse_llm_response_from_content():
+ """
+ Tests that `parse_llm_response_from_content` correctly parses a raw LLM
+ text response containing a `{...}`-tagged tool call and extracts the `ResponseMessage`.
+ """
+ # Sample raw text from an LLM
+ # Note: Since tags are { and }, we use double braces {{...}} where
+ # the outer { and } are tags, and the inner { and } are JSON
+ llm_text = """
+Some text from the model.
+{{
+ "name": "shell",
+ "arguments": {
+ "command": ["echo", "Hello from the tool!"]
+ }
+}}
+"""
# Call the function
- result = await call_llm_api_real(messages, mock_settings)
+ response_message = parse_llm_response_from_content(llm_text)
- # Define the expected assembled result
- expected_result = {
- "role": "assistant",
- "content": "Hello world!",
- "tool_calls": [
- {
- "id": "call_123",
- "type": "function",
- "function": {
- "name": "get_weather",
- "arguments": '{"location": "San Francisco"}',
- },
- }
- ],
- }
+ # Assertions
+ assert response_message.content == "Some text from the model."
+ assert response_message.tool_calls is not None
+ assert len(response_message.tool_calls) == 1
- # Assert that the result matches the expected output
- assert result == expected_result
+ tool_call = response_message.tool_calls[0]
+ assert isinstance(tool_call, ToolCall)
+ assert tool_call.function.name == "shell"
+
+ # The arguments are a JSON string, so we parse it for detailed checking
+ arguments = json.loads(tool_call.function.arguments)
+ assert arguments["command"] == ["echo", "Hello from the tool!"]
+
+
+@pytest.mark.anyio
+async def test_process_chat_request_with_tool_call(monkeypatch, mock_settings):
+ """
+ Tests that `process_chat_request` can correctly parse a tool call from a
+ simulated real LLM streaming response.
+ """
+ # 1. Define the simulated SSE stream from the LLM
+ # Using double braces for tool call tags
+ sse_chunks = [
+ 'data: {"choices": [{"delta": {"content": "Okay, I will run that shell command."}}], "object": "chat.completion.chunk"}\n\n',
+ 'data: {"choices": [{"delta": {"content": "{{\\n \\"name\\": \\"shell\\",\\n \\"arguments\\": {\\n \\"command\\": [\\"ls\\", \\"-l\\"]\\n }\\n}}\\n"}}], "object": "chat.completion.chunk"}\n\n',
+ 'data: [DONE]\n\n'
+ ]
+
+ # 2. Mock the httpx.AsyncClient
+ def mock_async_client(*args, **kwargs):
+ return MockAsyncClient(response_chunks=sse_chunks)
+
+ monkeypatch.setattr(httpx, "AsyncClient", mock_async_client)
+
+ # 3. Prepare the input for process_chat_request
+ messages = [ChatMessage(role="user", content="List the files.")]
+ tools = [Tool(type="function", function={"name": "shell", "description": "Run a shell command.", "parameters": {}})]
+ log_id = 1 # Dummy log ID for the test
+
+ # 4. Call the function
+ request_messages = inject_tools_into_prompt(messages, tools)
+ response_message = await process_chat_request(request_messages, mock_settings, log_id)
+
+ # 5. Assert the response is parsed correctly
+ assert response_message.content is not None
+ assert response_message.content.strip() == "Okay, I will run that shell command."
+ assert response_message.tool_calls is not None
+ assert len(response_message.tool_calls) == 1
+
+ tool_call = response_message.tool_calls[0]
+ assert tool_call.function.name == "shell"
+
+ arguments = json.loads(tool_call.function.arguments)
+ assert arguments["command"] == ["ls", "-l"]