Files
llmproxy/app/services.py
2025-12-31 06:35:08 +00:00

198 lines
8.7 KiB
Python

import json
import re
import httpx
import logging
from typing import List, Dict, Any, Tuple, Optional, AsyncGenerator
from .models import ChatMessage, Tool, ResponseMessage, ToolCall, ToolCallFunction
from .core.config import Settings
# Get a logger instance for this module
logger = logging.getLogger(__name__)
# --- Helper for parsing SSE ---
# Regex to extract data field from SSE
SSE_DATA_RE = re.compile(r"data:\s*(.*)")
def _parse_sse_data(chunk: bytes) -> Optional[Dict[str, Any]]:
"""Parses a chunk of bytes as SSE and extracts the JSON data."""
try:
lines = chunk.decode("utf-8").splitlines()
for line in lines:
if line.startswith("data:"):
match = SSE_DATA_RE.match(line)
if match:
data_str = match.group(1).strip()
if data_str == "[DONE]": # Handle OpenAI-style stream termination
return {"type": "done"}
try:
return json.loads(data_str)
except json.JSONDecodeError:
logger.warning(f"Failed to decode JSON from SSE data: {data_str}")
return None
except UnicodeDecodeError:
logger.warning("Failed to decode chunk as UTF-8.")
return None
# --- End Helper ---
def inject_tools_into_prompt(messages: List[ChatMessage], tools: List[Tool]) -> List[ChatMessage]:
    """
    Return a copy of *messages* with the tool definitions injected as an
    extra system prompt (inserted at index 1, i.e. right after the first
    message in the conversation).
    """
    serialized_tools = [t.model_dump() for t in tools]
    tool_defs = json.dumps(serialized_tools, indent=2)
    tool_prompt = f"""
You have access to a set of tools. You can call them by emitting a JSON object inside a <tool_call> XML tag.
The JSON object should have a "name" and "arguments" field.
Here are the available tools:
{tool_defs}
Only use the tools if strictly necessary.
"""
    augmented = list(messages)
    augmented.insert(1, ChatMessage(role="system", content=tool_prompt))
    return augmented
def parse_llm_response_from_content(text: str) -> ResponseMessage:
    """
    (Fallback) Parse the raw LLM text response to extract a message and any tool calls.

    Used when the LLM does not support native tool calling: the model is
    expected to emit ``<tool_call>{"name": ..., "arguments": {...}}</tool_call>``
    inside its plain-text reply (see inject_tools_into_prompt).

    Args:
        text: The raw assistant content; may be None or empty.

    Returns:
        A ResponseMessage with ``tool_calls`` populated when a well-formed
        tool call was found, otherwise one carrying the text verbatim (or
        ``content=None`` for empty input).
    """
    if not text:
        return ResponseMessage(content=None)
    tool_call_match = re.search(r"<tool_call>(.*?)</tool_call>", text, re.DOTALL)
    if not tool_call_match:
        return ResponseMessage(content=text)
    tool_call_str = tool_call_match.group(1).strip()
    try:
        tool_call_data = json.loads(tool_call_str)
    except json.JSONDecodeError as e:
        logger.warning(f"Failed to parse tool call JSON from content: {tool_call_str}. Error: {e}")
        return ResponseMessage(content=text)
    # BUGFIX: json.loads can return a non-dict (list/str/number); the previous
    # code then crashed with AttributeError on .get(). Treat any non-object
    # payload as a malformed tool call and fall back to plain text.
    if not isinstance(tool_call_data, dict):
        logger.warning(f"Tool call payload is not a JSON object: {tool_call_str}")
        return ResponseMessage(content=text)
    tool_call = ToolCall(
        # BUGFIX: `or "unknown"` also covers an explicit "name": null, which
        # previously crashed on "call_" + None.
        id="call_" + (tool_call_data.get("name") or "unknown"),
        function=ToolCallFunction(
            name=tool_call_data.get("name"),
            arguments=json.dumps(tool_call_data.get("arguments", {})),
        ),
    )
    # Any prose preceding the tool call is preserved as the message content.
    content_before = text.split("<tool_call>")[0].strip()
    return ResponseMessage(content=content_before if content_before else None, tool_calls=[tool_call])
async def _raw_stream_from_llm(messages: List[ChatMessage], settings: Settings) -> AsyncGenerator[bytes, None]:
    """
    Make the raw HTTP streaming call to the LLM backend.

    Yields raw byte chunks exactly as received. On failure an SSE-formatted
    error event is yielded instead of raising, so downstream consumers see a
    parseable final event rather than an abruptly broken stream.

    Args:
        messages: The chat history to send.
        settings: Provides REAL_LLM_API_URL and REAL_LLM_API_KEY.
    """
    headers = {
        "Authorization": f"Bearer {settings.REAL_LLM_API_KEY}",
        "Content-Type": "application/json",
    }
    payload = {
        "model": "default-model",
        "messages": [msg.model_dump() for msg in messages],
        "stream": True,
    }
    try:
        async with httpx.AsyncClient() as client:
            logger.info(f"Initiating raw stream to LLM API at {settings.REAL_LLM_API_URL}")
            async with client.stream("POST", settings.REAL_LLM_API_URL, headers=headers, json=payload, timeout=60.0) as response:
                response.raise_for_status()
                async for chunk in response.aiter_bytes():
                    yield chunk
    except httpx.HTTPStatusError as e:
        logger.error(f"LLM API returned an error during raw stream: {e.response.status_code}, response: '{e.response.text}'")
        # For streams, we log and emit a final SSE error event; the client
        # then sees the stream terminate. BUGFIX: build the payload with
        # json.dumps instead of byte concatenation so the event is always
        # valid JSON.
        error_event = {"error": "LLM API Error", "status_code": e.response.status_code}
        yield b"data: " + json.dumps(error_event).encode("utf-8") + b"\n\n"
    except httpx.RequestError as e:
        logger.error(f"An error occurred during raw stream request to LLM API: {e}")
        # BUGFIX: str(e) may contain quotes/backslashes; the previous naive
        # concatenation produced malformed JSON. json.dumps escapes it.
        error_event = {"error": "Network Error", "details": str(e)}
        yield b"data: " + json.dumps(error_event).encode("utf-8") + b"\n\n"
async def stream_llm_api(messages: List[ChatMessage], settings: Settings) -> AsyncGenerator[bytes, None]:
    """
    Public streaming interface: delegates to the raw LLM stream and forwards
    each byte chunk unchanged.

    NOTE: chunks are assumed to already be SSE-formatted and are passed
    through as-is; a more robust implementation would re-frame them so every
    yield is one complete SSE event line.
    """
    raw_stream = _raw_stream_from_llm(messages, settings)
    async for raw_chunk in raw_stream:
        yield raw_chunk
async def process_llm_stream_for_non_stream_request(
    messages: List[ChatMessage],
    settings: Settings
) -> Dict[str, Any]:
    """
    Aggregate a streaming LLM response into a single, non-streaming message.

    Handles SSE parsing and delta accumulation for an OpenAI-like streaming
    format. Native tool-call deltas are NOT reconstructed here (see the note
    at the end); tool calls embedded in the content string are handled
    downstream by parse_llm_response_from_content.

    Returns:
        A message dict shaped like {"role": "assistant", "content": str|None}.
    """
    full_content_parts: List[str] = []
    final_message_dict: Dict[str, Any] = {"role": "assistant", "content": None}
    async for chunk in _raw_stream_from_llm(messages, settings):
        parsed_data = _parse_sse_data(chunk)
        if not parsed_data:
            continue
        if parsed_data.get("type") == "done":
            break  # End of stream
        # Assuming OpenAI-like streaming format
        choices = parsed_data.get("choices")
        if not choices:
            continue
        delta = choices[0].get("delta")
        if delta:
            # BUGFIX: OpenAI-style streams send "content": null in the
            # initial role delta; appending None made the final "".join
            # raise TypeError. Only accumulate real string pieces.
            content_piece = delta.get("content")
            if content_piece is not None:
                full_content_parts.append(content_piece)
            if "tool_calls" in delta:
                # Accumulating native tool-call deltas is non-trivial and
                # highly dependent on the LLM's exact streaming format; we
                # rely on content parsing downstream instead.
                pass
        if choices[0].get("finish_reason"):
            # finish_reason marks stream end / tool_calls completion;
            # nothing to do here since [DONE] also terminates the loop.
            pass
    final_message_dict["content"] = "".join(full_content_parts) if full_content_parts else None
    # Simplification: reconstructing tool_calls from deltas is non-trivial.
    # We rely on parse_llm_response_from_content for tool calls embedded in
    # the final content string, and assume the LLM doesn't send native
    # tool_calls deltas that need aggregation here.
    logger.info(f"Aggregated non-streaming response content: {final_message_dict.get('content')}")
    return final_message_dict
async def process_chat_request(
    messages: List[ChatMessage],
    tools: Optional[List[Tool]],
    settings: Settings,
) -> ResponseMessage:
    """
    Main service entry point for non-streaming requests.

    All interactions with the real LLM go through the streaming mechanism;
    the stream is aggregated into one message, which is then inspected for
    tool calls — native first, content-embedded as a fallback.
    """
    request_messages = inject_tools_into_prompt(messages, tools) if tools else messages

    # All interactions with the real LLM now go through the streaming mechanism.
    llm_message_dict = await process_llm_stream_for_non_stream_request(request_messages, settings)

    # Priority 1: native tool calls, if the aggregation could reconstruct
    # them. (Reconstructing tool_calls from stream deltas is complex; when
    # they are present we assume they are complete.)
    native_calls = llm_message_dict.get("tool_calls")
    if native_calls:
        logger.info("Native tool calls detected in aggregated LLM response.")
        if isinstance(native_calls, list):
            # A list of dicts is suitable for Pydantic validation.
            return ResponseMessage.model_validate(llm_message_dict)
        logger.warning("Aggregated tool_calls not in expected list format. Treating as content.")

    # Priority 2 (fallback): parse tool calls out of the content string.
    logger.info("No native tool calls from aggregation. Falling back to content parsing.")
    return parse_llm_response_from_content(llm_message_dict.get("content"))