# LLM service module: tool-call prompt injection, SSE streaming, and response parsing.
import json
|
|
import re
|
|
import httpx
|
|
import logging
|
|
from typing import List, Dict, Any, Tuple, Optional, AsyncGenerator
|
|
|
|
from .models import ChatMessage, Tool, ResponseMessage, ToolCall, ToolCallFunction
|
|
from .core.config import Settings
|
|
from .database import update_request_log
|
|
from .response_parser import ResponseParser, parse_response
|
|
|
|
# Get a logger instance for this module
|
|
logger = logging.getLogger(__name__)
|
|
|
|
# --- Helper for parsing SSE ---
|
|
# Regex to extract data field from SSE
|
|
SSE_DATA_RE = re.compile(r"data:\s*(.*)")
|
|
|
|
def _parse_sse_data(chunk: bytes) -> Optional[Dict[str, Any]]:
|
|
"""Parses a chunk of bytes as SSE and extracts the JSON data."""
|
|
try:
|
|
lines = chunk.decode("utf-8").splitlines()
|
|
for line in lines:
|
|
if line.startswith("data:"):
|
|
match = SSE_DATA_RE.match(line)
|
|
if match:
|
|
data_str = match.group(1).strip()
|
|
if data_str == "[DONE]": # Handle OpenAI-style stream termination
|
|
return {"type": "done"}
|
|
try:
|
|
return json.loads(data_str)
|
|
except json.JSONDecodeError:
|
|
logger.warning(f"Failed to decode JSON from SSE data: {data_str}")
|
|
return None
|
|
except UnicodeDecodeError:
|
|
logger.warning("Failed to decode chunk as UTF-8.")
|
|
return None
|
|
|
|
# --- End Helper ---
|
|
|
|
|
|
def convert_tool_calls_to_content(messages: List[ChatMessage]) -> List[ChatMessage]:
    """
    Rewrite assistant ``tool_calls`` into XML-tagged plain content.

    Assistant messages that carry ``tool_calls`` are replaced by content-only
    messages where each call is serialized as
    ``<start-tag>{"name": ..., "arguments": {...}}<end-tag>`` and joined to any
    pre-existing content with newlines.  All other messages pass through
    untouched.

    Args:
        messages: List of ChatMessage objects from the client

    Returns:
        Processed list of ChatMessage objects with tool_calls converted to content

    Example:
        Input: [{"role": "assistant", "tool_calls": [...]}]
        Output: [{"role": "assistant", "content": "<invoke>{...}</invoke>"}]
    """
    from .response_parser import TOOL_CALL_START_TAG, TOOL_CALL_END_TAG

    result: List[ChatMessage] = []
    for message in messages:
        has_calls = (
            message.role == "assistant"
            and message.tool_calls
            and len(message.tool_calls) > 0
        )
        if not has_calls:
            # Non-assistant messages (and assistant messages without calls)
            # are forwarded unchanged.
            result.append(message)
            continue

        rendered_calls = []
        for call in message.tool_calls:
            fn = call.get("function", {})
            raw_args = fn.get("arguments", "{}")

            # Normalize arguments to a dict; invalid JSON degrades to {}.
            if isinstance(raw_args, str):
                try:
                    parsed_args = json.loads(raw_args)
                except json.JSONDecodeError:
                    parsed_args = {}
            else:
                parsed_args = raw_args

            payload = {"name": fn.get("name", ""), "arguments": parsed_args}
            rendered_calls.append(
                f"{TOOL_CALL_START_TAG}{json.dumps(payload, ensure_ascii=False)}{TOOL_CALL_END_TAG}"
            )

        # Preserve original content (if any) ahead of the rendered calls.
        parts = ([message.content] if message.content else []) + rendered_calls
        result.append(ChatMessage(role=message.role, content="\n".join(parts)))

    return result
|
|
|
|
|
|
def inject_tools_into_prompt(messages: List[ChatMessage], tools: List[Tool]) -> List[ChatMessage]:
    """
    Injects a system prompt with tool definitions at the beginning of the message list.

    The generated prompt instructs the model to emit tool calls as JSON wrapped
    in the project's start/end tags so the text-based fallback parser in
    ``.response_parser`` can recover them.  The input ``messages`` list is not
    mutated; a new list with the system message prepended is returned.
    """
    # Imported at call time rather than module level — presumably to avoid a
    # circular import with .response_parser; TODO confirm.
    from .response_parser import TOOL_CALL_START_TAG, TOOL_CALL_END_TAG

    # Serialize the tool schemas (pydantic models) for embedding in the prompt.
    tool_defs = json.dumps([tool.model_dump() for tool in tools], indent=2)

    # Build the format example
    json_example = '{"name": "search", "arguments": {"query": "example"}}'
    full_example = f'{TOOL_CALL_START_TAG}{json_example}{TOOL_CALL_END_TAG}'

    tool_prompt = f"""
You are a helpful assistant with access to a set of tools.

## TOOL CALL FORMAT (CRITICAL)

When you need to use a tool, you MUST follow this EXACT format:

{TOOL_CALL_START_TAG}{{"name": "tool_name", "arguments": {{...}}}}{TOOL_CALL_END_TAG}

### IMPORTANT RULES:
1. ALWAYS include BOTH the opening tag ({TOOL_CALL_START_TAG}) AND closing tag ({TOOL_CALL_END_TAG})
2. The JSON must be valid and properly formatted
3. Keep arguments concise to avoid truncation
4. Do not include any text between the tags except the JSON

### Examples:
Simple call:
{full_example}

Multiple arguments:
{TOOL_CALL_START_TAG}{{"name": "search", "arguments": {{"query": "example", "limit": 5}}}}{TOOL_CALL_END_TAG}

## AVAILABLE TOOLS:
{tool_defs}

## REMEMBER:
- If you decide to call a tool, output ONLY the tool call tags (you may add brief text before or after)
- ALWAYS close your tags properly with {TOOL_CALL_END_TAG}
- Keep your arguments concise and essential
"""
    # Prepend the system prompt with tool definitions
    return [ChatMessage(role="system", content=tool_prompt)] + messages
|
|
|
|
|
|
def parse_llm_response_from_content(text: str) -> ResponseMessage:
    """
    (Fallback) Parse raw LLM text into a message plus any embedded tool calls.

    Used when the LLM does not support native tool calling; all parsing is
    delegated to the ResponseParser class for better maintainability.
    """
    return ResponseParser().parse(text)
|
|
|
|
|
|
async def _raw_stream_from_llm(messages: List[ChatMessage], settings: Settings, log_id: int) -> AsyncGenerator[bytes, None]:
    """
    Makes the raw HTTP streaming call to the LLM backend.
    Yields raw byte chunks as received.

    On HTTP or network errors the failure is logged, recorded against the
    request log, and emitted as a single SSE-formatted error event instead of
    raising, so downstream consumers always receive a stream.
    """
    headers = {"Authorization": f"Bearer {settings.REAL_LLM_API_KEY}", "Content-Type": "application/json"}
    # NOTE(review): model name is hard-coded; consider sourcing it from Settings.
    payload = {"model": "gpt-4.1", "messages": [msg.model_dump() for msg in messages], "stream": True}

    # Log the request payload to the database
    update_request_log(log_id, llm_request=payload)

    try:
        async with httpx.AsyncClient() as client:
            logger.info(f"Initiating raw stream to LLM API for log ID {log_id} at {settings.REAL_LLM_API_URL}")
            async with client.stream("POST", settings.REAL_LLM_API_URL, headers=headers, json=payload, timeout=60.0) as response:
                response.raise_for_status()
                async for chunk in response.aiter_bytes():
                    yield chunk
    except httpx.HTTPStatusError as e:
        error_message = f"LLM API returned an error during raw stream: {e.response.status_code}, response: '{e.response.text}'"
        logger.error(f"{error_message} for log ID {log_id}")
        update_request_log(log_id, llm_response={"error": error_message})
        # Build the error event via json.dumps so the payload is always valid
        # JSON (raw byte concatenation could not escape embedded quotes).
        event = json.dumps({"error": "LLM API Error", "status_code": e.response.status_code})
        yield f"data: {event}\n\n".encode("utf-8")
    except httpx.RequestError as e:
        error_message = f"An error occurred during raw stream request to LLM API: {e}"
        logger.error(f"{error_message} for log ID {log_id}")
        update_request_log(log_id, llm_response={"error": error_message})
        # json.dumps escapes quotes/backslashes/newlines in the exception text,
        # which the previous string splicing emitted verbatim (malformed JSON).
        event = json.dumps({"error": "Network Error", "details": str(e)})
        yield f"data: {event}\n\n".encode("utf-8")
|
|
|
|
|
|
async def stream_llm_api(messages: List[ChatMessage], settings: Settings, log_id: int) -> AsyncGenerator[bytes, None]:
    """
    Public streaming entry point.

    Relays chunks from the raw backend stream unchanged while collecting
    their text so the complete response can be written to the request log
    once the stream finishes.
    """
    collected: List[str] = []
    async for raw in _raw_stream_from_llm(messages, settings, log_id):
        collected.append(raw.decode('utf-8', errors='ignore'))
        try:
            logger.info(f"Streaming chunk for log ID {log_id}: {raw.decode('utf-8').strip()}")
        except UnicodeDecodeError:
            # Fall back to logging the raw bytes when strict decoding fails.
            logger.info(f"Streaming chunk (undecodable) for log ID {log_id}: {raw}")
        yield raw

    # Log the full LLM response after the stream is complete
    update_request_log(log_id, llm_response={"content": "".join(collected)})
|
|
|
|
|
|
async def process_llm_stream_for_non_stream_request(
    messages: List[ChatMessage],
    settings: Settings,
    log_id: int
) -> Dict[str, Any]:
    """
    Aggregates a streaming LLM response into a single, non-streaming message.

    Incoming bytes are buffered and split on SSE event boundaries (blank
    lines) before parsing, so events that span multiple network chunks — or
    several events arriving in one chunk — are all accounted for.  (The
    previous implementation inspected only the first ``data:`` line of each
    raw chunk and dropped the rest.)  The aggregated message is logged to the
    request log before being returned.
    """
    full_content_parts: List[str] = []
    final_message_dict: Dict[str, Any] = {"role": "assistant", "content": None}
    llm_response_chunks: List[str] = []
    buffer = ""
    done = False

    def _consume_event(event_text: str) -> bool:
        """Parse one SSE event's data lines; return True on the [DONE] sentinel."""
        for line in event_text.splitlines():
            if not line.startswith("data:"):
                continue
            data_str = line[len("data:"):].strip()
            if data_str == "[DONE]":
                return True
            try:
                parsed = json.loads(data_str)
            except json.JSONDecodeError:
                logger.warning(f"Failed to decode JSON from SSE data: {data_str}")
                continue
            choices = parsed.get("choices")
            if choices and len(choices) > 0:
                delta = choices[0].get("delta")
                if delta and "content" in delta:
                    full_content_parts.append(delta["content"])
        return False

    async for chunk in _raw_stream_from_llm(messages, settings, log_id):
        text = chunk.decode("utf-8", errors="ignore")
        llm_response_chunks.append(text)
        buffer += text
        # SSE events are separated by a blank line ("\n\n"); process every
        # complete event, keeping any trailing partial event in the buffer.
        while "\n\n" in buffer:
            event_text, buffer = buffer.split("\n\n", 1)
            if _consume_event(event_text):
                done = True
                break
        if done:
            break

    # Handle a final event that was not terminated by a blank line.
    if not done and buffer.strip():
        _consume_event(buffer)

    final_message_dict["content"] = "".join(full_content_parts) if full_content_parts else None

    # Log the aggregated LLM response
    logger.info(f"Aggregated non-streaming response content for log ID {log_id}: {final_message_dict.get('content')}")
    update_request_log(log_id, llm_response=final_message_dict)

    return final_message_dict
|
|
|
|
|
|
async def process_chat_request(
    messages: List[ChatMessage],
    settings: Settings,
    log_id: int
) -> ResponseMessage:
    """
    Main service function for non-streaming requests.

    Aggregates the backend stream into a single message dict, then runs it
    through ResponseParser, which handles both native and text-based tool
    calls.
    """
    aggregated = await process_llm_stream_for_non_stream_request(messages, settings, log_id)
    return ResponseParser().parse_native_tool_calls(aggregated)
|