Files
llmproxy/app/services.py
Vertex-AI-Step-Builder 5c2904e010 feat: 增强工具调用代理功能,支持多工具调用和消息历史转换
主要改进:
- 新增 convert_tool_calls_to_content 函数,将消息历史中的 tool_calls 转换为 LLM 可理解的 XML 格式
- 修复 response_parser 支持同时解析多个 tool_calls
- 优化响应解析逻辑,支持 content 和 tool_calls 同时存在
- 添加完整的测试覆盖,包括多工具调用、消息转换和混合响应

技术细节:
- services.py: 实现工具调用历史到 content 的转换
- response_parser.py: 使用非贪婪匹配支持多个 tool_calls 解析
- main.py: 集成消息转换功能,确保消息历史正确传递给 LLM

🤖 Generated with [Claude Code](https://claude.com/claude-code)

Co-Authored-By: Claude Sonnet 4.5 <noreply@anthropic.com>
2025-12-31 13:33:25 +00:00

243 lines
9.6 KiB
Python

import json
import re
import httpx
import logging
from typing import List, Dict, Any, Tuple, Optional, AsyncGenerator
from .models import ChatMessage, Tool, ResponseMessage, ToolCall, ToolCallFunction
from .core.config import Settings
from .database import update_request_log
from .response_parser import ResponseParser, parse_response
# Module-level logger, named after this module (``__name__``).
logger = logging.getLogger(__name__)
# --- Helper for parsing SSE ---
# Matches an SSE "data:" line. group(1) captures everything after the "data:"
# field name and any whitespace that follows it; _parse_sse_data strips the
# captured payload before using it.
SSE_DATA_RE = re.compile(r"data:\s*(.*)")
def _parse_sse_data(chunk: bytes) -> Optional[Dict[str, Any]]:
    """
    Extract the first JSON payload carried by a ``data:`` line in an SSE chunk.

    Args:
        chunk: Raw bytes as received from the upstream stream.

    Returns:
        ``{"type": "done"}`` when the first usable data line is the
        OpenAI-style ``[DONE]`` sentinel; otherwise the decoded JSON of the
        first ``data:`` line that parses. Returns ``None`` when the chunk is
        not valid UTF-8 or contains no parseable ``data:`` line.

    Note:
        Only the first usable event in the chunk is returned; any later
        events in the same chunk are ignored.
    """
    try:
        text = chunk.decode("utf-8")
    except UnicodeDecodeError:
        logger.warning("Failed to decode chunk as UTF-8.")
        return None

    for line in text.splitlines():
        if not line.startswith("data:"):
            continue
        found = SSE_DATA_RE.match(line)
        if found is None:
            continue
        payload = found.group(1).strip()
        if payload == "[DONE]":  # OpenAI-style end-of-stream marker
            return {"type": "done"}
        try:
            return json.loads(payload)
        except json.JSONDecodeError:
            # Malformed payload: warn and keep scanning the remaining lines.
            logger.warning(f"Failed to decode JSON from SSE data: {payload}")
    return None
# --- End Helper ---
def convert_tool_calls_to_content(messages: List[ChatMessage]) -> List[ChatMessage]:
    """
    Rewrite assistant ``tool_calls`` in the message history as tagged text.

    Backends without native tool calling cannot consume the ``tool_calls``
    field. Each assistant message that carries tool calls is therefore
    replaced by a new assistant message whose content is built as follows:

    - the original content, if any;
    - then one ``TOOL_CALL_START_TAG{json}TOOL_CALL_END_TAG`` entry per call,
      where ``{json}`` is ``{"name": ..., "arguments": {...}}``;
    - all joined with newlines.

    Args:
        messages: Message history received from the client.

    Returns:
        A new list. Messages without tool calls are passed through as the
        same objects. Converted messages are fresh ``ChatMessage`` instances
        carrying only ``role`` and ``content``.

    Example:
        Input:  [{"role": "assistant", "tool_calls": [...]}]
        Output: [{"role": "assistant", "content": "<invoke>{...}</invoke>"}]
    """
    from .response_parser import TOOL_CALL_START_TAG, TOOL_CALL_END_TAG

    processed_messages: List[ChatMessage] = []
    for msg in messages:
        if msg.role != "assistant" or not msg.tool_calls:
            # Keep every other message as-is.
            processed_messages.append(msg)
            continue

        # Preserve original content (if any) ahead of the tagged tool calls.
        content_parts: List[str] = [msg.content] if msg.content else []
        for tc in msg.tool_calls:
            tool_call_json = _tool_call_to_prompt_json(tc)
            content_parts.append(
                f"{TOOL_CALL_START_TAG}"
                f"{json.dumps(tool_call_json, ensure_ascii=False)}"
                f"{TOOL_CALL_END_TAG}"
            )
        processed_messages.append(
            ChatMessage(role=msg.role, content="\n".join(content_parts))
        )
    return processed_messages


def _tool_call_to_prompt_json(tc: Any) -> Dict[str, Any]:
    """Normalize one client tool call into ``{"name": ..., "arguments": ...}``."""
    # Accept pydantic models (e.g. ToolCall) as well as plain dicts.
    if hasattr(tc, "model_dump"):
        tc = tc.model_dump()
    # ``or {}`` also covers an explicit ``"function": None``.
    function = (tc.get("function") if isinstance(tc, dict) else None) or {}
    name = function.get("name") or ""
    raw_arguments = function.get("arguments")

    if raw_arguments is None:
        arguments: Any = {}
    elif isinstance(raw_arguments, str):
        if not raw_arguments.strip():
            arguments = {}
        else:
            try:
                arguments = json.loads(raw_arguments)
            except json.JSONDecodeError:
                logger.warning(
                    f"Invalid JSON in tool call arguments for '{name}': {raw_arguments!r}; using {{}}"
                )
                arguments = {}
    else:
        # Already structured (e.g. a dict): use it directly.
        arguments = raw_arguments
    return {"name": name, "arguments": arguments}
def inject_tools_into_prompt(messages: List[ChatMessage], tools: List[Tool]) -> List[ChatMessage]:
    """
    Prepend a system message describing the available tools.

    The prompt lists every tool definition as JSON. It instructs the model to
    emit calls as ``TOOL_CALL_START_TAG{"name": ..., "arguments": {...}}TOOL_CALL_END_TAG``,
    using the tag constants defined in ``response_parser``.

    Args:
        messages: The conversation to send upstream.
        tools: Tool definitions supplied by the client.

    Returns:
        A new list containing the tool system prompt followed by ``messages``.
        When ``tools`` is empty (or None), there is nothing to describe, so a
        copy of ``messages`` is returned without an extra system message.
    """
    from .response_parser import TOOL_CALL_START_TAG, TOOL_CALL_END_TAG

    if not tools:
        return list(messages)

    # ensure_ascii=False keeps non-ASCII tool descriptions (e.g. CJK text) verbatim
    # instead of \uXXXX escapes, which are harder for the model to read and cost tokens.
    tool_defs = json.dumps([tool.model_dump() for tool in tools], indent=2, ensure_ascii=False)

    # Build the format example
    json_example = '{"name": "search", "arguments": {"query": "example"}}'
    full_example = f'{TOOL_CALL_START_TAG}{json_example}{TOOL_CALL_END_TAG}'
    tool_prompt = f"""
You are a helpful assistant with access to a set of tools.
You can call them by emitting a JSON object inside tool call tags.
IMPORTANT: Use the following format for tool calls:
Format: {TOOL_CALL_START_TAG}{{"name": "tool_name", "arguments": {{...}}}}{TOOL_CALL_END_TAG}
Example: {full_example}
Here are the available tools:
{tool_defs}
Only use the tools if strictly necessary.
"""
    # Prepend the system prompt so it precedes the whole conversation.
    return [ChatMessage(role="system", content=tool_prompt)] + messages
def parse_llm_response_from_content(text: str) -> ResponseMessage:
    """
    Fallback parser for backends that lack native tool calling.

    Hands the raw completion text to a fresh ``ResponseParser``. It returns
    whatever ``ResponseParser.parse`` produces: the assistant message plus
    any tool calls found in the text.
    """
    return ResponseParser().parse(text)
async def _raw_stream_from_llm(messages: List[ChatMessage], settings: Settings, log_id: int) -> AsyncGenerator[bytes, None]:
    """
    Make the raw streaming POST to the upstream LLM and yield its bytes.

    The request payload is recorded with ``update_request_log`` before sending.
    Upstream failures (HTTP status errors and httpx request errors) are not
    raised to the caller. Instead they are:

    - logged;
    - recorded in the request log;
    - surfaced as a single SSE ``data:`` event whose JSON body has an
      ``"error"`` key.

    Args:
        messages: Conversation to forward, serialized with ``model_dump``.
        settings: Supplies ``REAL_LLM_API_URL`` and ``REAL_LLM_API_KEY``.
        log_id: Request-log row to update.

    Yields:
        Raw byte chunks exactly as received, or one synthesized error event.
    """
    headers = {
        "Authorization": f"Bearer {settings.REAL_LLM_API_KEY}",
        "Content-Type": "application/json",
    }
    payload = {
        "model": "default-model",
        "messages": [msg.model_dump() for msg in messages],
        "stream": True,
    }
    # Log the request payload to the database
    update_request_log(log_id, llm_request=payload)

    try:
        async with httpx.AsyncClient() as client:
            logger.info(f"Initiating raw stream to LLM API for log ID {log_id} at {settings.REAL_LLM_API_URL}")
            async with client.stream("POST", settings.REAL_LLM_API_URL, headers=headers, json=payload, timeout=60.0) as response:
                if not response.is_success:
                    # Streaming bodies are not loaded automatically. Read the body now
                    # so the handler below can use ``e.response.text``; otherwise it
                    # raises httpx.ResponseNotRead once this context has exited.
                    await response.aread()
                response.raise_for_status()
                async for chunk in response.aiter_bytes():
                    yield chunk
    except httpx.HTTPStatusError as e:
        error_message = f"LLM API returned an error during raw stream: {e.response.status_code}, response: '{e.response.text}'"
        logger.error(f"{error_message} for log ID {log_id}")
        update_request_log(log_id, llm_response={"error": error_message})
        yield _sse_error_event({"error": "LLM API Error", "status_code": e.response.status_code})
    except httpx.RequestError as e:
        error_message = f"An error occurred during raw stream request to LLM API: {e}"
        logger.error(f"{error_message} for log ID {log_id}")
        update_request_log(log_id, llm_response={"error": error_message})
        yield _sse_error_event({"error": "Network Error", "details": str(e)})


def _sse_error_event(body: Dict[str, Any]) -> bytes:
    """Encode ``body`` as one SSE ``data:`` event, JSON-escaping all values."""
    return f"data: {json.dumps(body, ensure_ascii=False)}\n\n".encode("utf-8")
async def stream_llm_api(messages: List[ChatMessage], settings: Settings, log_id: int) -> AsyncGenerator[bytes, None]:
    """
    Stream the upstream LLM response to the caller, chunk by chunk.

    Each chunk is logged and yielded unchanged. The full response text is
    recorded with ``update_request_log`` when streaming ends. This also
    happens when the consumer stops early (e.g. a client disconnect), in
    which case the partial response received so far is recorded.

    Args:
        messages: Conversation to forward upstream.
        settings: Upstream connection settings.
        log_id: Request-log row to update.

    Yields:
        Raw byte chunks from ``_raw_stream_from_llm``.
    """
    received: List[bytes] = []
    upstream = _raw_stream_from_llm(messages, settings, log_id)
    try:
        async for chunk in upstream:
            received.append(chunk)
            try:
                logger.info(f"Streaming chunk for log ID {log_id}: {chunk.decode('utf-8').strip()}")
            except UnicodeDecodeError:
                # A multi-byte character can straddle a chunk boundary; log the raw bytes.
                logger.info(f"Streaming chunk (undecodable) for log ID {log_id}: {chunk}")
            yield chunk
    finally:
        # Decode once over the joined bytes so characters split across chunks survive.
        full_text = b"".join(received).decode("utf-8", errors="replace")
        update_request_log(log_id, llm_response={"content": full_text})
        # Release the upstream HTTP stream promptly (no-op if already exhausted).
        await upstream.aclose()
async def process_llm_stream_for_non_stream_request(
    messages: List[ChatMessage],
    settings: Settings,
    log_id: int
) -> Dict[str, Any]:
    """
    Aggregate a streaming LLM response into a single, non-streaming message.

    Raw bytes are buffered and split into complete SSE lines, so two kinds of
    event are handled correctly:

    - several events arriving in one network chunk;
    - one event split across chunks.

    ``choices[0].delta.content`` fragments are concatenated until
    ``data: [DONE]`` or the end of the stream.

    Args:
        messages: Conversation to forward upstream.
        settings: Upstream connection settings.
        log_id: Request-log row to update with the aggregated message.

    Returns:
        ``{"role": "assistant", "content": <joined text or None>}``. ``content``
        is None when no string content delta was received.
    """
    content_parts: List[str] = []
    pending = b""  # bytes of an SSE line not yet terminated by b"\n"
    upstream = _raw_stream_from_llm(messages, settings, log_id)
    try:
        finished = False
        async for chunk in upstream:
            pending += chunk
            # Split only on b"\n": that byte never occurs inside a UTF-8
            # multi-byte sequence, so complete lines never cut a character.
            *complete_lines, pending = pending.split(b"\n")
            if _collect_sse_content(complete_lines, content_parts):
                finished = True
                break
        if not finished and pending:
            # Upstream ended without a trailing newline; process the last line.
            _collect_sse_content([pending], content_parts)
    finally:
        # Close the upstream generator (and its HTTP connection) right away,
        # including when we stop early at [DONE].
        await upstream.aclose()

    final_message_dict: Dict[str, Any] = {
        "role": "assistant",
        "content": "".join(content_parts) if content_parts else None,
    }
    # Log the aggregated LLM response
    logger.info(f"Aggregated non-streaming response content for log ID {log_id}: {final_message_dict.get('content')}")
    update_request_log(log_id, llm_response=final_message_dict)
    return final_message_dict


def _collect_sse_content(raw_lines: List[bytes], content_parts: List[str]) -> bool:
    """Append delta text from complete SSE lines to content_parts; return True on [DONE]."""
    for raw_line in raw_lines:
        line = raw_line.decode("utf-8", errors="replace").rstrip("\r")
        if not line.startswith("data:"):
            continue  # comments, event:/id: fields, blank separators
        data_str = line[len("data:"):].strip()
        if not data_str:
            continue
        if data_str == "[DONE]":  # OpenAI-style stream termination
            return True
        try:
            event = json.loads(data_str)
        except json.JSONDecodeError:
            logger.warning(f"Failed to decode JSON from SSE data: {data_str}")
            continue
        if not isinstance(event, dict):
            continue
        choices = event.get("choices")
        if not choices or not isinstance(choices[0], dict):
            continue
        delta = choices[0].get("delta") or {}
        piece = delta.get("content")
        # Deltas may carry "content": null (e.g. role-only chunks); skip those.
        if isinstance(piece, str):
            content_parts.append(piece)
    return False
async def process_chat_request(
    messages: List[ChatMessage],
    settings: Settings,
    log_id: int,
) -> ResponseMessage:
    """
    Handle a non-streaming chat request end to end.

    The upstream stream is first aggregated into one assistant message dict.
    ``ResponseParser.parse_native_tool_calls`` then turns that dict into the
    ``ResponseMessage`` returned to the caller.
    """
    aggregated = await process_llm_stream_for_non_stream_request(messages, settings, log_id)
    # The parser handles both native and text-based tool calls.
    return ResponseParser().parse_native_tool_calls(aggregated)