import json
import re
import httpx
import logging
from typing import List, Dict, Any, Tuple, Optional, AsyncGenerator

from .models import ChatMessage, Tool, ResponseMessage, ToolCall, ToolCallFunction
from .core.config import Settings
from .database import update_request_log
from .response_parser import ResponseParser, parse_response

# Get a logger instance for this module
logger = logging.getLogger(__name__)

# --- Helper for parsing SSE ---
# Regex to extract the data field from a Server-Sent Events line.
SSE_DATA_RE = re.compile(r"data:\s*(.*)")


def _parse_sse_data(chunk: bytes) -> Optional[Dict[str, Any]]:
    """Parse a chunk of bytes as SSE and extract the first JSON data payload.

    Returns:
        The decoded JSON object, ``{"type": "done"}`` for an OpenAI-style
        ``[DONE]`` sentinel, or ``None`` if the chunk holds no parsable data.

    NOTE(review): only the FIRST ``data:`` line whose payload parses is
    returned; if the transport coalesces several SSE events into one network
    chunk, later events in the same chunk are dropped — confirm against the
    backend's framing.
    """
    try:
        lines = chunk.decode("utf-8").splitlines()
    except UnicodeDecodeError:
        logger.warning("Failed to decode chunk as UTF-8.")
        return None
    for line in lines:
        if not line.startswith("data:"):
            continue
        match = SSE_DATA_RE.match(line)
        if not match:
            continue
        data_str = match.group(1).strip()
        if data_str == "[DONE]":
            # Handle OpenAI-style stream termination
            return {"type": "done"}
        try:
            return json.loads(data_str)
        except json.JSONDecodeError:
            logger.warning(f"Failed to decode JSON from SSE data: {data_str}")
    return None
# --- End Helper ---


def convert_tool_calls_to_content(messages: List[ChatMessage]) -> List[ChatMessage]:
    """
    Converts assistant messages with tool_calls into content format using XML tags.

    This function processes the message history and converts any assistant
    messages that have tool_calls into a format that LLMs can understand. Each
    tool call is serialized as JSON and wrapped in the TOOL_CALL_START_TAG /
    TOOL_CALL_END_TAG markers in the content field.

    Args:
        messages: List of ChatMessage objects from the client

    Returns:
        Processed list of ChatMessage objects with tool_calls converted to content

    Example:
        Input:  [{"role": "assistant", "tool_calls": [...]}]
        Output: [{"role": "assistant", "content": "<tag>{...}</tag>"}]
    """
    from .response_parser import TOOL_CALL_START_TAG, TOOL_CALL_END_TAG

    processed_messages = []
    for msg in messages:
        # Check if this is an assistant message with tool_calls
        if msg.role == "assistant" and msg.tool_calls and len(msg.tool_calls) > 0:
            # Convert each tool call to XML tag format.
            # NOTE(review): assumes each tool_call entry is a plain dict
            # ({"function": {"name": ..., "arguments": ...}}) — confirm the
            # client never sends ToolCall model instances here.
            tool_call_contents = []
            for tc in msg.tool_calls:
                tc_data = tc.get("function", {})
                name = tc_data.get("name", "")
                arguments_str = tc_data.get("arguments", "{}")
                # Parse arguments JSON to ensure it's valid; pass through
                # values that are already decoded objects.
                try:
                    arguments = json.loads(arguments_str) if isinstance(arguments_str, str) else arguments_str
                except json.JSONDecodeError:
                    arguments = {}
                # Build the tool call JSON and wrap it in the XML tags.
                tool_call_json = {"name": name, "arguments": arguments}
                tool_call_content = f'{TOOL_CALL_START_TAG}{json.dumps(tool_call_json, ensure_ascii=False)}{TOOL_CALL_END_TAG}'
                tool_call_contents.append(tool_call_content)

            # Create new message with tool calls in content.
            # Preserve original content if it exists.
            content_parts = []
            if msg.content:
                content_parts.append(msg.content)
            content_parts.extend(tool_call_contents)
            new_content = "\n".join(content_parts)
            processed_messages.append(
                ChatMessage(role=msg.role, content=new_content)
            )
        else:
            # Keep other messages as-is
            processed_messages.append(msg)
    return processed_messages


def inject_tools_into_prompt(messages: List[ChatMessage], tools: List[Tool]) -> List[ChatMessage]:
    """
    Injects a system prompt with tool definitions at the beginning of the message list.

    Args:
        messages: The conversation history from the client.
        tools: Tool definitions to expose to the model.

    Returns:
        A new message list with one system message (describing the tool-call
        format and the available tools) prepended; ``messages`` is not mutated.
    """
    from .response_parser import TOOL_CALL_START_TAG, TOOL_CALL_END_TAG

    tool_defs = json.dumps([tool.model_dump() for tool in tools], indent=2)

    # Build the format example shown to the model.
    json_example = '{"name": "search", "arguments": {"query": "example"}}'
    full_example = f'{TOOL_CALL_START_TAG}{json_example}{TOOL_CALL_END_TAG}'

    tool_prompt = f"""
You are a helpful assistant with access to a set of tools.

## TOOL CALL FORMAT (CRITICAL)

When you need to use a tool, you MUST follow this EXACT format:
{TOOL_CALL_START_TAG}{{"name": "tool_name", "arguments": {{...}}}}{TOOL_CALL_END_TAG}

### IMPORTANT RULES:
1. ALWAYS include BOTH the opening tag ({TOOL_CALL_START_TAG}) AND closing tag ({TOOL_CALL_END_TAG})
2. The JSON must be valid and properly formatted
3. Keep arguments concise to avoid truncation
4. Do not include any text between the tags except the JSON

### Examples:

Simple call:
{full_example}

Multiple arguments:
{TOOL_CALL_START_TAG}{{"name": "search", "arguments": {{"query": "example", "limit": 5}}}}{TOOL_CALL_END_TAG}

## AVAILABLE TOOLS:
{tool_defs}

## REMEMBER:
- If you decide to call a tool, output ONLY the tool call tags (you may add brief text before or after)
- ALWAYS close your tags properly with {TOOL_CALL_END_TAG}
- Keep your arguments concise and essential
"""
    # Prepend the system prompt with tool definitions
    return [ChatMessage(role="system", content=tool_prompt)] + messages


def parse_llm_response_from_content(text: str) -> ResponseMessage:
    """
    (Fallback) Parses the raw LLM text response to extract a message and any tool calls.
    This is used when the LLM does not support native tool calling.

    This function now delegates to the ResponseParser class for better maintainability.
    """
    parser = ResponseParser()
    return parser.parse(text)


async def _raw_stream_from_llm(messages: List[ChatMessage], settings: Settings, log_id: int) -> AsyncGenerator[bytes, None]:
    """
    Makes the raw HTTP streaming call to the LLM backend.
    Yields raw byte chunks as received.

    On backend/network failure, yields a single SSE-formatted error event
    instead of raising, and records the error against ``log_id``.
    """
    headers = {
        "Authorization": f"Bearer {settings.REAL_LLM_API_KEY}",
        "Content-Type": "application/json"
    }
    payload = {
        "model": "default-model",
        "messages": [msg.model_dump() for msg in messages],
        "stream": True
    }

    # Log the request payload to the database
    update_request_log(log_id, llm_request=payload)

    try:
        async with httpx.AsyncClient() as client:
            logger.info(f"Initiating raw stream to LLM API for log ID {log_id} at {settings.REAL_LLM_API_URL}")
            async with client.stream("POST", settings.REAL_LLM_API_URL, headers=headers, json=payload, timeout=60.0) as response:
                response.raise_for_status()
                async for chunk in response.aiter_bytes():
                    yield chunk
    except httpx.HTTPStatusError as e:
        error_message = f"LLM API returned an error during raw stream: {e.response.status_code}, response: '{e.response.text}'"
        logger.error(f"{error_message} for log ID {log_id}")
        update_request_log(log_id, llm_response={"error": error_message})
        # Serialize with json.dumps rather than byte concatenation so the
        # payload is always valid JSON.
        error_event = json.dumps({"error": "LLM API Error", "status_code": e.response.status_code})
        yield f"data: {error_event}\n\n".encode("utf-8")
    except httpx.RequestError as e:
        error_message = f"An error occurred during raw stream request to LLM API: {e}"
        logger.error(f"{error_message} for log ID {log_id}")
        update_request_log(log_id, llm_response={"error": error_message})
        # json.dumps escapes quotes/backslashes in the exception text; the
        # previous byte-concatenation produced invalid JSON for such messages.
        error_event = json.dumps({"error": "Network Error", "details": str(e)})
        yield f"data: {error_event}\n\n".encode("utf-8")


async def stream_llm_api(messages: List[ChatMessage], settings: Settings, log_id: int) -> AsyncGenerator[bytes, None]:
    """
    Public interface for streaming. Calls the raw stream, logs the full
    response, and yields chunks.

    The concatenated (lossily decoded) text of all chunks is written to the
    request log once the stream completes.
    """
    llm_response_chunks = []
    async for chunk in _raw_stream_from_llm(messages, settings, log_id):
        # Decode once, tolerating invalid UTF-8, for both accumulation and logging.
        decoded = chunk.decode('utf-8', errors='ignore')
        llm_response_chunks.append(decoded)
        logger.info(f"Streaming chunk for log ID {log_id}: {decoded.strip()}")
        yield chunk

    # Log the full LLM response after the stream is complete
    update_request_log(log_id, llm_response={"content": "".join(llm_response_chunks)})


async def process_llm_stream_for_non_stream_request(
    messages: List[ChatMessage],
    settings: Settings,
    log_id: int
) -> Dict[str, Any]:
    """
    Aggregates a streaming LLM response into a single, non-streaming message.
    Handles SSE parsing, delta accumulation, and logs the final aggregated message.

    Returns:
        An OpenAI-style message dict: {"role": "assistant", "content": ...},
        where content is None if the stream produced no text deltas.

    NOTE(review): assumes each network chunk carries at most one SSE event and
    that events are not split across chunk boundaries — confirm against the
    backend; a split frame would be silently skipped by _parse_sse_data.
    """
    full_content_parts = []
    final_message_dict: Dict[str, Any] = {"role": "assistant", "content": None}
    llm_response_chunks = []

    async for chunk in _raw_stream_from_llm(messages, settings, log_id):
        llm_response_chunks.append(chunk.decode('utf-8', errors='ignore'))
        parsed_data = _parse_sse_data(chunk)
        if parsed_data:
            if parsed_data.get("type") == "done":
                break
            choices = parsed_data.get("choices")
            if choices and len(choices) > 0:
                delta = choices[0].get("delta")
                if delta and "content" in delta:
                    full_content_parts.append(delta["content"])

    final_message_dict["content"] = "".join(full_content_parts) if full_content_parts else None

    # Log the aggregated LLM response
    logger.info(f"Aggregated non-streaming response content for log ID {log_id}: {final_message_dict.get('content')}")
    update_request_log(log_id, llm_response=final_message_dict)

    return final_message_dict


async def process_chat_request(
    messages: List[ChatMessage],
    settings: Settings,
    log_id: int
) -> ResponseMessage:
    """
    Main service function for non-streaming requests.
    It calls the stream aggregation logic and then parses the result.
    """
    llm_message_dict = await process_llm_stream_for_non_stream_request(messages, settings, log_id)

    # Use the ResponseParser to handle both native and text-based tool calls
    parser = ResponseParser()
    return parser.parse_native_tool_calls(llm_message_dict)