"""Service layer for proxying chat requests to an upstream LLM API.

Responsibilities:
- inject tool definitions into the prompt for models without native tool calling,
- stream raw SSE bytes from the backend,
- aggregate streamed deltas into a single non-streaming response,
- parse tool calls back out of the model's text output as a fallback.
"""

import json
import logging
import re
from typing import Any, AsyncGenerator, Dict, List, Optional

import httpx

from .core.config import Settings
from .models import ChatMessage, ResponseMessage, Tool, ToolCall, ToolCallFunction

# Module-level logger (PEP 282 convention).
logger = logging.getLogger(__name__)

# --- Helpers for parsing SSE ---

# Extracts the payload of an SSE "data:" field.
SSE_DATA_RE = re.compile(r"data:\s*(.*)")

# FIX: the <tool_call> delimiters were missing from the original pattern
# (r"(.*?)" matched the empty string at position 0, so the JSON parse always
# failed and tool calls embedded in content were never detected; the follow-up
# text.split("") would also have raised ValueError on an empty separator).
TOOL_CALL_RE = re.compile(r"<tool_call>(.*?)</tool_call>", re.DOTALL)


def _parse_sse_data(chunk: bytes) -> Optional[Dict[str, Any]]:
    """Parse a chunk of bytes as SSE and return the first JSON data payload.

    Returns:
        ``{"type": "done"}`` for the OpenAI-style ``[DONE]`` sentinel,
        the decoded JSON object for a normal data line, or ``None`` when
        the chunk contains no parseable ``data:`` line.
    """
    try:
        lines = chunk.decode("utf-8").splitlines()
    except UnicodeDecodeError:
        logger.warning("Failed to decode chunk as UTF-8.")
        return None

    for line in lines:
        if not line.startswith("data:"):
            continue
        match = SSE_DATA_RE.match(line)
        if not match:
            continue
        data_str = match.group(1).strip()
        if data_str == "[DONE]":
            # OpenAI-style stream termination marker.
            return {"type": "done"}
        try:
            return json.loads(data_str)
        except json.JSONDecodeError:
            logger.warning("Failed to decode JSON from SSE data: %s", data_str)
    return None

# --- End Helpers ---


def inject_tools_into_prompt(messages: List[ChatMessage], tools: List[Tool]) -> List[ChatMessage]:
    """Inject tool definitions into the message list as a system prompt.

    The input list is not mutated; a shallow copy with the extra system
    message is returned.
    """
    tool_defs = json.dumps([tool.model_dump() for tool in tools], indent=2)
    tool_prompt = f"""
You have access to a set of tools. You can call them by emitting a JSON object inside a <tool_call> XML tag. The JSON object should have a "name" and "arguments" field.

Here are the available tools:
{tool_defs}

Only use the tools if strictly necessary.
"""
    new_messages = messages.copy()
    # NOTE(review): index 1 assumes messages[0] is an existing system prompt
    # the tool definitions should follow — confirm against callers. (insert
    # clamps, so an empty/short list still works.)
    new_messages.insert(1, ChatMessage(role="system", content=tool_prompt))
    return new_messages


def parse_llm_response_from_content(text: Optional[str]) -> ResponseMessage:
    """(Fallback) Parse raw LLM text to extract a message and any tool calls.

    Used when the LLM does not support native tool calling and instead emits
    a ``<tool_call>{...}</tool_call>`` block inside its content.

    Args:
        text: The aggregated content string; may be ``None`` or empty.
    """
    if not text:
        return ResponseMessage(content=None)

    tool_call_match = TOOL_CALL_RE.search(text)
    if not tool_call_match:
        return ResponseMessage(content=text)

    tool_call_str = tool_call_match.group(1).strip()
    try:
        tool_call_data = json.loads(tool_call_str)
    except json.JSONDecodeError as e:
        # Malformed tool-call JSON: fall back to returning the raw text.
        logger.warning(
            "Failed to parse tool call JSON from content: %s. Error: %s",
            tool_call_str, e,
        )
        return ResponseMessage(content=text)

    tool_call = ToolCall(
        id="call_" + tool_call_data.get("name", "unknown"),
        function=ToolCallFunction(
            name=tool_call_data.get("name"),
            # Arguments are re-serialized so downstream consumers always
            # receive a JSON string, matching the OpenAI tool-call schema.
            arguments=json.dumps(tool_call_data.get("arguments", {})),
        ),
    )
    # Any prose the model emitted before the tool call is kept as content.
    content_before = text.split("<tool_call>")[0].strip()
    return ResponseMessage(
        content=content_before if content_before else None,
        tool_calls=[tool_call],
    )


async def _raw_stream_from_llm(
    messages: List[ChatMessage], settings: Settings
) -> AsyncGenerator[bytes, None]:
    """Make the raw HTTP streaming call to the LLM backend.

    Yields raw byte chunks as received. On failure, yields a single
    SSE-formatted error event and terminates the stream.
    """
    headers = {
        "Authorization": f"Bearer {settings.REAL_LLM_API_KEY}",
        "Content-Type": "application/json",
    }
    payload = {
        # NOTE(review): model name is hard-coded — consider sourcing it from
        # Settings. TODO confirm whether the backend honors this field.
        "model": "default-model",
        "messages": [msg.model_dump() for msg in messages],
        "stream": True,
    }
    try:
        async with httpx.AsyncClient() as client:
            logger.info("Initiating raw stream to LLM API at %s", settings.REAL_LLM_API_URL)
            async with client.stream(
                "POST",
                settings.REAL_LLM_API_URL,
                headers=headers,
                json=payload,
                timeout=60.0,
            ) as response:
                response.raise_for_status()
                async for chunk in response.aiter_bytes():
                    yield chunk
    except httpx.HTTPStatusError as e:
        logger.error(
            "LLM API returned an error during raw stream: %s, response: '%s'",
            e.response.status_code, e.response.text,
        )
        # For streams, we log and emit one final SSE error event; the client
        # sees a cleanly terminated (if truncated) stream.
        # FIX: build the payload with json.dumps so it is always valid JSON.
        err = json.dumps({"error": "LLM API Error", "status_code": e.response.status_code})
        yield f"data: {err}\n\n".encode("utf-8")
    except httpx.RequestError as e:
        logger.error("An error occurred during raw stream request to LLM API: %s", e)
        # FIX: str(e) may contain quotes/backslashes; json.dumps escapes them.
        err = json.dumps({"error": "Network Error", "details": str(e)})
        yield f"data: {err}\n\n".encode("utf-8")


async def stream_llm_api(
    messages: List[ChatMessage], settings: Settings
) -> AsyncGenerator[bytes, None]:
    """Public interface for streaming.

    Currently a pass-through of the raw backend byte chunks. A more robust
    implementation would re-frame the bytes so each yield is one complete
    SSE event.
    """
    async for chunk in _raw_stream_from_llm(messages, settings):
        yield chunk


async def process_llm_stream_for_non_stream_request(
    messages: List[ChatMessage],
    settings: Settings,
) -> Dict[str, Any]:
    """Aggregate a streaming LLM response into a single non-streaming message.

    Buffers the byte stream and processes every complete SSE line, so events
    that are split across chunks (or packed several-per-chunk) are not lost.

    Returns:
        An assistant message dict: ``{"role": "assistant", "content": ...}``.
    """
    full_content_parts: List[str] = []
    final_message_dict: Dict[str, Any] = {"role": "assistant", "content": None}
    buffer = b""
    done = False

    async for chunk in _raw_stream_from_llm(messages, settings):
        buffer += chunk
        # FIX: the previous version parsed at most one event per network
        # chunk; line-buffering handles multiple events per chunk and events
        # straddling chunk boundaries.
        while b"\n" in buffer and not done:
            line, buffer = buffer.split(b"\n", 1)
            parsed_data = _parse_sse_data(line)
            if not parsed_data:
                continue
            if parsed_data.get("type") == "done":
                done = True  # End of stream ([DONE] sentinel).
                break
            # Assuming OpenAI-like streaming format.
            choices = parsed_data.get("choices")
            if not choices:
                continue
            delta = choices[0].get("delta")
            if delta:
                if "content" in delta:
                    full_content_parts.append(delta["content"])
                if "tool_calls" in delta:
                    # Reconstructing native tool_calls from deltas is
                    # non-trivial and format-specific; we rely on
                    # parse_llm_response_from_content for tool calls
                    # embedded in the final content string.
                    pass
            if choices[0].get("finish_reason"):
                # finish_reason marks stream end / tool_calls completion;
                # nothing to aggregate from it here.
                pass
        if done:
            break

    final_message_dict["content"] = "".join(full_content_parts) if full_content_parts else None
    logger.info(
        "Aggregated non-streaming response content: %s",
        final_message_dict.get("content"),
    )
    return final_message_dict


async def process_chat_request(
    messages: List[ChatMessage],
    tools: Optional[List[Tool]],
    settings: Settings,
) -> ResponseMessage:
    """Main service function for non-streaming requests.

    All interactions with the real LLM go through the streaming mechanism;
    this aggregates the stream and extracts tool calls.
    """
    request_messages = messages
    if tools:
        request_messages = inject_tools_into_prompt(messages, tools)

    llm_message_dict = await process_llm_stream_for_non_stream_request(request_messages, settings)

    # Priority 1: native tool calls, if the aggregation could reconstruct
    # them (assumed complete when present).
    if llm_message_dict.get("tool_calls"):
        logger.info("Native tool calls detected in aggregated LLM response.")
        if isinstance(llm_message_dict["tool_calls"], list):
            return ResponseMessage.model_validate(llm_message_dict)
        logger.warning("Aggregated tool_calls not in expected list format. Treating as content.")

    # Priority 2 (fallback): parse tool calls out of the content string.
    logger.info("No native tool calls from aggregation. Falling back to content parsing.")
    return parse_llm_response_from_content(llm_message_dict.get("content"))