feat: Initial commit of LLM Tool Proxy

This commit is contained in:
Vertex-AI-Step-Builder
2025-12-31 06:35:08 +00:00
commit 0d14c98cf4
11 changed files with 775 additions and 0 deletions

198
app/services.py Normal file
View File

@@ -0,0 +1,198 @@
import json
import re
import httpx
import logging
from typing import List, Dict, Any, Tuple, Optional, AsyncGenerator
from .models import ChatMessage, Tool, ResponseMessage, ToolCall, ToolCallFunction
from .core.config import Settings
# Get a logger instance for this module
logger = logging.getLogger(__name__)
# --- Helper for parsing SSE ---
# Regex that captures everything after the "data:" field name (and any
# following whitespace) on a single Server-Sent Events line.
SSE_DATA_RE = re.compile(r"data:\s*(.*)")
def _parse_sse_data(chunk: bytes) -> Optional[Dict[str, Any]]:
"""Parses a chunk of bytes as SSE and extracts the JSON data."""
try:
lines = chunk.decode("utf-8").splitlines()
for line in lines:
if line.startswith("data:"):
match = SSE_DATA_RE.match(line)
if match:
data_str = match.group(1).strip()
if data_str == "[DONE]": # Handle OpenAI-style stream termination
return {"type": "done"}
try:
return json.loads(data_str)
except json.JSONDecodeError:
logger.warning(f"Failed to decode JSON from SSE data: {data_str}")
return None
except UnicodeDecodeError:
logger.warning("Failed to decode chunk as UTF-8.")
return None
# --- End Helper ---
def inject_tools_into_prompt(messages: List[ChatMessage], tools: List[Tool]) -> List[ChatMessage]:
    """
    Return a shallow copy of *messages* with a system message describing the
    available tools inserted at index 1 (i.e. just after the leading message).
    The original list is left untouched.
    """
    serialized = json.dumps([t.model_dump() for t in tools], indent=2)
    tool_prompt = f"""
You have access to a set of tools. You can call them by emitting a JSON object inside a <tool_call> XML tag.
The JSON object should have a "name" and "arguments" field.
Here are the available tools:
{serialized}
Only use the tools if strictly necessary.
"""
    augmented = list(messages)
    augmented.insert(1, ChatMessage(role="system", content=tool_prompt))
    return augmented
def parse_llm_response_from_content(text: str) -> ResponseMessage:
    """
    (Fallback) Extract an optional <tool_call> JSON payload from raw LLM text.

    Used when the LLM lacks native tool calling.  If the text carries a
    well-formed <tool_call> block, the returned message holds the reconstructed
    tool call plus any text that preceded the tag; otherwise the text is
    passed through unchanged as plain content.
    """
    if not text:
        return ResponseMessage(content=None)
    match = re.search(r"<tool_call>(.*?)</tool_call>", text, re.DOTALL)
    if match is None:
        # No tool-call markup at all: plain assistant text.
        return ResponseMessage(content=text)
    payload = match.group(1).strip()
    try:
        call_spec = json.loads(payload)
    except json.JSONDecodeError as e:
        logger.warning(f"Failed to parse tool call JSON from content: {payload}. Error: {e}")
        return ResponseMessage(content=text)
    call = ToolCall(
        id="call_" + call_spec.get("name", "unknown"),
        function=ToolCallFunction(
            name=call_spec.get("name"),
            arguments=json.dumps(call_spec.get("arguments", {})),
        ),
    )
    # Preserve any assistant prose that came before the tool-call tag.
    leading = text.split("<tool_call>")[0].strip()
    return ResponseMessage(content=leading or None, tool_calls=[call])
async def _raw_stream_from_llm(messages: List[ChatMessage], settings: Settings) -> AsyncGenerator[bytes, None]:
    """
    Make the raw HTTP streaming call to the LLM backend.
    Yields raw byte chunks as received.

    On failure the generator does not raise; it yields a single SSE-formatted
    JSON error event so downstream consumers always receive parseable output.
    """
    headers = {
        "Authorization": f"Bearer {settings.REAL_LLM_API_KEY}",
        "Content-Type": "application/json",
    }
    payload = {
        "model": "default-model",
        "messages": [msg.model_dump() for msg in messages],
        "stream": True,
    }
    try:
        async with httpx.AsyncClient() as client:
            logger.info(f"Initiating raw stream to LLM API at {settings.REAL_LLM_API_URL}")
            async with client.stream("POST", settings.REAL_LLM_API_URL, headers=headers, json=payload, timeout=60.0) as response:
                response.raise_for_status()
                async for chunk in response.aiter_bytes():
                    yield chunk
    except httpx.HTTPStatusError as e:
        logger.error(f"LLM API returned an error during raw stream: {e.response.status_code}, response: '{e.response.text}'")
        # For streams, we log and emit a terminal error event instead of raising.
        error_event = {"error": "LLM API Error", "status_code": e.response.status_code}
        yield b"data: " + json.dumps(error_event).encode("utf-8") + b"\n\n"
    except httpx.RequestError as e:
        logger.error(f"An error occurred during raw stream request to LLM API: {e}")
        # BUGFIX: the previous byte-concatenation embedded str(e) unescaped into
        # the JSON payload, yielding invalid JSON whenever the message contained
        # quotes, backslashes, or newlines.  json.dumps escapes it correctly.
        error_event = {"error": "Network Error", "details": str(e)}
        yield b"data: " + json.dumps(error_event).encode("utf-8") + b"\n\n"
async def stream_llm_api(messages: List[ChatMessage], settings: Settings) -> AsyncGenerator[bytes, None]:
    """
    Public streaming entry point: relay raw byte chunks from the LLM backend.

    NOTE: chunks are forwarded as-is on the assumption that they are already
    SSE formatted; a stricter implementation would re-frame each yield as one
    complete SSE event line.
    """
    raw_stream = _raw_stream_from_llm(messages, settings)
    async for raw_chunk in raw_stream:
        yield raw_chunk
async def process_llm_stream_for_non_stream_request(
    messages: List[ChatMessage],
    settings: Settings
) -> Dict[str, Any]:
    """
    Consume the streaming LLM response and collapse it into a single
    assistant-message dict (OpenAI-style delta format assumed).

    Only textual content deltas are accumulated; native tool_call deltas are
    intentionally skipped (reconstructing them is format-specific), so callers
    fall back to content parsing for tool calls.
    """
    content_chunks: List[str] = []
    aggregated: Dict[str, Any] = {"role": "assistant", "content": None}
    async for raw_chunk in _raw_stream_from_llm(messages, settings):
        event = _parse_sse_data(raw_chunk)
        if not event:
            continue
        if event.get("type") == "done":
            break  # End of stream
        # Assuming OpenAI-like streaming format.
        choices = event.get("choices")
        if not choices:
            continue
        delta = choices[0].get("delta")
        if delta:
            if "content" in delta:
                content_chunks.append(delta["content"])
            if "tool_calls" in delta:
                # Tool-call deltas would need per-index accumulation, which is
                # highly dependent on the LLM's exact streaming format. We rely
                # on content-embedded tool calls being parsed downstream.
                pass
        if choices[0].get("finish_reason"):
            # finish_reason marks stream end / tool-call completion; no-op here.
            pass
    aggregated["content"] = "".join(content_chunks) if content_chunks else None
    logger.info(f"Aggregated non-streaming response content: {aggregated.get('content')}")
    return aggregated
async def process_chat_request(
    messages: List[ChatMessage],
    tools: Optional[List[Tool]],
    settings: Settings,
) -> ResponseMessage:
    """
    Handle a non-streaming chat request end-to-end.

    Injects tool definitions when tools are supplied, aggregates the streamed
    LLM reply into one message, and prefers native tool calls (when the
    aggregation produced any) over tool calls embedded in the content text.
    """
    outgoing = inject_tools_into_prompt(messages, tools) if tools else messages
    # All interactions with the real LLM go through the streaming mechanism.
    aggregated = await process_llm_stream_for_non_stream_request(outgoing, settings)
    # Priority 1: native tool calls — only usable if aggregation reconstructed
    # them as a complete list of dicts suitable for Pydantic validation.
    native_calls = aggregated.get("tool_calls")
    if native_calls:
        logger.info("Native tool calls detected in aggregated LLM response.")
        if isinstance(native_calls, list):
            return ResponseMessage.model_validate(aggregated)
        logger.warning("Aggregated tool_calls not in expected list format. Treating as content.")
    # Priority 2 (fallback): look for <tool_call> blocks embedded in the text.
    logger.info("No native tool calls from aggregation. Falling back to content parsing.")
    return parse_llm_response_from_content(aggregated.get("content"))