Files
llmproxy/app/services.py
2026-01-12 14:12:15 +08:00

259 lines
10 KiB
Python

import json
import re
import httpx
import logging
from typing import List, Dict, Any, Tuple, Optional, AsyncGenerator
from .models import ChatMessage, Tool, ResponseMessage, ToolCall, ToolCallFunction
from .core.config import Settings
from .database import update_request_log
from .response_parser import ResponseParser, parse_response
# Get a logger instance for this module
logger = logging.getLogger(__name__)
# --- Helper for parsing SSE ---
# Regex to extract data field from SSE
SSE_DATA_RE = re.compile(r"data:\s*(.*)")
def _parse_sse_data(chunk: bytes) -> Optional[Dict[str, Any]]:
"""Parses a chunk of bytes as SSE and extracts the JSON data."""
try:
lines = chunk.decode("utf-8").splitlines()
for line in lines:
if line.startswith("data:"):
match = SSE_DATA_RE.match(line)
if match:
data_str = match.group(1).strip()
if data_str == "[DONE]": # Handle OpenAI-style stream termination
return {"type": "done"}
try:
return json.loads(data_str)
except json.JSONDecodeError:
logger.warning(f"Failed to decode JSON from SSE data: {data_str}")
return None
except UnicodeDecodeError:
logger.warning("Failed to decode chunk as UTF-8.")
return None
# --- End Helper ---
def convert_tool_calls_to_content(messages: List[ChatMessage]) -> List[ChatMessage]:
    """
    Rewrite assistant messages that carry ``tool_calls`` into plain content.

    Each tool call is serialized as JSON and wrapped in the project's XML-style
    tags so that models without native tool-calling can read the history.
    Messages without tool calls pass through untouched.

    Args:
        messages: List of ChatMessage objects from the client.

    Returns:
        New list of ChatMessage objects with tool_calls folded into content.

    Example:
        Input: [{"role": "assistant", "tool_calls": [...]}]
        Output: [{"role": "assistant", "content": "<invoke>{...}</invoke>"}]
    """
    from .response_parser import TOOL_CALL_START_TAG, TOOL_CALL_END_TAG

    result: List[ChatMessage] = []
    for msg in messages:
        # Pass through anything that is not an assistant message with tool calls.
        if msg.role != "assistant" or not msg.tool_calls:
            result.append(msg)
            continue
        # Start from the original text content, if any, then append each call.
        parts: List[str] = [msg.content] if msg.content else []
        for tc in msg.tool_calls:
            fn = tc.get("function", {})
            raw_args = fn.get("arguments", "{}")
            if isinstance(raw_args, str):
                # Validate the arguments JSON; fall back to {} on bad input.
                try:
                    args = json.loads(raw_args)
                except json.JSONDecodeError:
                    args = {}
            else:
                args = raw_args
            payload = json.dumps({"name": fn.get("name", ""), "arguments": args}, ensure_ascii=False)
            parts.append(f'{TOOL_CALL_START_TAG}{payload}{TOOL_CALL_END_TAG}')
        result.append(ChatMessage(role=msg.role, content="\n".join(parts)))
    return result
def inject_tools_into_prompt(messages: List[ChatMessage], tools: List[Tool]) -> List[ChatMessage]:
    """
    Injects a system prompt with tool definitions at the beginning of the message list.
    """
    # Local import, presumably to avoid a circular import with response_parser
    # -- TODO confirm.
    from .response_parser import TOOL_CALL_START_TAG, TOOL_CALL_END_TAG
    # Serialize tool schemas so they can be embedded verbatim in the prompt.
    tool_defs = json.dumps([tool.model_dump() for tool in tools], indent=2)
    # Build the format example
    json_example = '{"name": "search", "arguments": {"query": "example"}}'
    full_example = f'{TOOL_CALL_START_TAG}{json_example}{TOOL_CALL_END_TAG}'
    # NOTE: {{...}} below are f-string escapes that render as literal braces.
    tool_prompt = f"""
You are a helpful assistant with access to a set of tools.
## TOOL CALL FORMAT (CRITICAL)
When you need to use a tool, you MUST follow this EXACT format:
{TOOL_CALL_START_TAG}{{"name": "tool_name", "arguments": {{...}}}}{TOOL_CALL_END_TAG}
### IMPORTANT RULES:
1. ALWAYS include BOTH the opening tag ({TOOL_CALL_START_TAG}) AND closing tag ({TOOL_CALL_END_TAG})
2. The JSON must be valid and properly formatted
3. Keep arguments concise to avoid truncation
4. Do not include any text between the tags except the JSON
### Examples:
Simple call:
{full_example}
Multiple arguments:
{TOOL_CALL_START_TAG}{{"name": "search", "arguments": {{"query": "example", "limit": 5}}}}{TOOL_CALL_END_TAG}
## AVAILABLE TOOLS:
{tool_defs}
## REMEMBER:
- If you decide to call a tool, output ONLY the tool call tags (you may add brief text before or after)
- ALWAYS close your tags properly with {TOOL_CALL_END_TAG}
- Keep your arguments concise and essential
"""
    # Prepend the system prompt with tool definitions
    return [ChatMessage(role="system", content=tool_prompt)] + messages
def parse_llm_response_from_content(text: str) -> ResponseMessage:
    """
    (Fallback) Parse raw LLM text into a message plus any embedded tool calls.

    Used when the LLM does not support native tool calling; all parsing
    logic is delegated to the ResponseParser class.
    """
    return ResponseParser().parse(text)
async def _raw_stream_from_llm(messages: List[ChatMessage], settings: Settings, log_id: int) -> AsyncGenerator[bytes, None]:
    """
    Make the raw HTTP streaming call to the LLM backend.

    Yields raw byte chunks as received. On HTTP or network failure, logs the
    error, records it via update_request_log, and yields a single
    SSE-formatted error event instead of raising.

    Args:
        messages: Conversation history to send upstream.
        settings: Provides REAL_LLM_API_KEY and REAL_LLM_API_URL.
        log_id: Request-log row id used for request/response persistence.

    Yields:
        Raw byte chunks from the backend (or one error event on failure).
    """
    headers = {
        "Authorization": f"Bearer {settings.REAL_LLM_API_KEY}",
        "Content-Type": "application/json",
    }
    # NOTE(review): model name is hard-coded; consider sourcing from settings.
    payload = {
        "model": "gpt-4.1",
        "messages": [msg.model_dump() for msg in messages],
        "stream": True,
    }
    # Log the request payload to the database
    update_request_log(log_id, llm_request=payload)
    try:
        async with httpx.AsyncClient() as client:
            logger.info(f"Initiating raw stream to LLM API for log ID {log_id} at {settings.REAL_LLM_API_URL}")
            async with client.stream("POST", settings.REAL_LLM_API_URL, headers=headers, json=payload, timeout=60.0) as response:
                response.raise_for_status()
                async for chunk in response.aiter_bytes():
                    yield chunk
    except httpx.HTTPStatusError as e:
        error_message = f"LLM API returned an error during raw stream: {e.response.status_code}, response: '{e.response.text}'"
        logger.error(f"{error_message} for log ID {log_id}")
        update_request_log(log_id, llm_response={"error": error_message})
        # Build the payload with json.dumps so the emitted event is always
        # valid JSON (the previous byte-concatenation could not escape quotes).
        body = json.dumps({"error": "LLM API Error", "status_code": e.response.status_code})
        yield b"data: " + body.encode("utf-8") + b"\n\n"
    except httpx.RequestError as e:
        error_message = f"An error occurred during raw stream request to LLM API: {e}"
        logger.error(f"{error_message} for log ID {log_id}")
        update_request_log(log_id, llm_response={"error": error_message})
        # json.dumps escapes quotes/newlines in the exception text, keeping
        # the SSE data field valid JSON for downstream parsers.
        body = json.dumps({"error": "Network Error", "details": str(e)})
        yield b"data: " + body.encode("utf-8") + b"\n\n"
async def stream_llm_api(messages: List[ChatMessage], settings: Settings, log_id: int) -> AsyncGenerator[bytes, None]:
    """
    Public interface for streaming.

    Relays raw chunks from the backend to the caller while accumulating the
    decoded text, then logs the full response once the stream completes.

    Args:
        messages: Conversation history to forward upstream.
        settings: Backend connection settings.
        log_id: Request-log row id for response persistence.

    Yields:
        The raw byte chunks, unchanged.
    """
    llm_response_chunks = []
    async for chunk in _raw_stream_from_llm(messages, settings, log_id):
        # Decode once with errors="ignore" (the original decoded each chunk
        # twice and had a separate undecodable-bytes logging branch).
        text = chunk.decode("utf-8", errors="ignore")
        llm_response_chunks.append(text)
        # Lazy %-style args avoid formatting work when INFO is disabled.
        logger.info("Streaming chunk for log ID %s: %s", log_id, text.strip())
        yield chunk
    # Log the full LLM response after the stream is complete
    update_request_log(log_id, llm_response={"content": "".join(llm_response_chunks)})
async def process_llm_stream_for_non_stream_request(
    messages: List[ChatMessage],
    settings: Settings,
    log_id: int
) -> Dict[str, Any]:
    """
    Aggregate a streaming LLM response into a single, non-streaming message.

    Parses each chunk as SSE, accumulates the ``delta.content`` pieces, and
    logs the final aggregated message via update_request_log.

    Returns:
        A dict of the form {"role": "assistant", "content": <str or None>}.
    """
    full_content_parts: List[str] = []
    final_message_dict: Dict[str, Any] = {"role": "assistant", "content": None}
    async for chunk in _raw_stream_from_llm(messages, settings, log_id):
        # NOTE(review): assumes at most one SSE data event per chunk; events
        # split across chunks (or several batched into one) are not handled.
        parsed_data = _parse_sse_data(chunk)
        if not parsed_data:
            continue
        if parsed_data.get("type") == "done":
            break
        choices = parsed_data.get("choices")
        if choices and len(choices) > 0:
            delta = choices[0].get("delta")
            # Fix: the previous `"content" in delta` check appended None for
            # deltas like {"role": "assistant", "content": null}, which would
            # crash the "".join below with a TypeError.
            if delta and delta.get("content") is not None:
                full_content_parts.append(delta["content"])
    final_message_dict["content"] = "".join(full_content_parts) if full_content_parts else None
    # Log the aggregated LLM response
    logger.info(f"Aggregated non-streaming response content for log ID {log_id}: {final_message_dict.get('content')}")
    update_request_log(log_id, llm_response=final_message_dict)
    return final_message_dict
async def process_chat_request(
    messages: List[ChatMessage],
    settings: Settings,
    log_id: int
) -> ResponseMessage:
    """
    Entry point for non-streaming chat requests.

    Aggregates the upstream stream into one message dict, then hands it to
    ResponseParser, which handles both native and text-embedded tool calls.
    """
    aggregated = await process_llm_stream_for_non_stream_request(messages, settings, log_id)
    return ResponseParser().parse_native_tool_calls(aggregated)