Files
llmproxy/app/services.py
Vertex-AI-Step-Builder 42548108ba fix: 修复流式工具调用解析并更新标签格式
主要变更:
- 将工具调用标签从 {} 改为 <invoke></invoke>,避免与 JSON 括号冲突
- 修复流式请求未解析工具调用的问题,现在返回 OpenAI 格式的 tool_calls
- 从 SSE 响应中正确提取 content 并解析工具调用
- 更新提示词格式以使用新标签
- 更新所有相关测试用例

问题修复:
- 流式请求现在正确返回 OpenAI 格式的 tool_calls
- 标签冲突导致的解析失败问题已解决
- 所有单元测试通过 (20/20)
- API 完全兼容 OpenAI REST API tools 字段行为

🤖 Generated with [Claude Code](https://claude.com/claude-code)

Co-Authored-By: Claude Sonnet 4.5 <noreply@anthropic.com>
2025-12-31 09:20:08 +00:00

179 lines
7.2 KiB
Python

import json
import re
import httpx
import logging
from typing import List, Dict, Any, Tuple, Optional, AsyncGenerator
from .models import ChatMessage, Tool, ResponseMessage, ToolCall, ToolCallFunction
from .core.config import Settings
from .database import update_request_log
from .response_parser import ResponseParser, parse_response
# Get a logger instance for this module
logger = logging.getLogger(__name__)
# --- Helper for parsing SSE ---
# Regex to extract data field from SSE
SSE_DATA_RE = re.compile(r"data:\s*(.*)")
def _parse_sse_data(chunk: bytes) -> Optional[Dict[str, Any]]:
"""Parses a chunk of bytes as SSE and extracts the JSON data."""
try:
lines = chunk.decode("utf-8").splitlines()
for line in lines:
if line.startswith("data:"):
match = SSE_DATA_RE.match(line)
if match:
data_str = match.group(1).strip()
if data_str == "[DONE]": # Handle OpenAI-style stream termination
return {"type": "done"}
try:
return json.loads(data_str)
except json.JSONDecodeError:
logger.warning(f"Failed to decode JSON from SSE data: {data_str}")
return None
except UnicodeDecodeError:
logger.warning("Failed to decode chunk as UTF-8.")
return None
# --- End Helper ---
def inject_tools_into_prompt(messages: List[ChatMessage], tools: List[Tool]) -> List[ChatMessage]:
    """Prepend a system message describing the available tools.

    Builds a system prompt that lists every tool definition as JSON and
    explains the tag-delimited format the model must use for tool calls,
    then returns a new message list with that prompt in front.
    """
    from .response_parser import TOOL_CALL_START_TAG, TOOL_CALL_END_TAG

    serialized_tools = json.dumps([t.model_dump() for t in tools], indent=2)

    # Build a concrete worked example of a tagged tool call.
    sample_call = '{"name": "search", "arguments": {"query": "example"}}'
    sample_block = f'{TOOL_CALL_START_TAG}{sample_call}{TOOL_CALL_END_TAG}'

    tool_prompt = f"""
You are a helpful assistant with access to a set of tools.
You can call them by emitting a JSON object inside tool call tags.
IMPORTANT: Use the following format for tool calls:
Format: {TOOL_CALL_START_TAG}{{"name": "tool_name", "arguments": {{...}}}}{TOOL_CALL_END_TAG}
Example: {sample_block}
Here are the available tools:
{serialized_tools}
Only use the tools if strictly necessary.
"""
    # Return a fresh list: system prompt first, then the caller's messages.
    return [ChatMessage(role="system", content=tool_prompt), *messages]
def parse_llm_response_from_content(text: str) -> ResponseMessage:
    """(Fallback) Extract a message and any tool calls from raw LLM text.

    Used when the LLM does not support native tool calling. Delegates to
    the ResponseParser class for better maintainability.
    """
    return ResponseParser().parse(text)
async def _raw_stream_from_llm(messages: List[ChatMessage], settings: Settings, log_id: int) -> AsyncGenerator[bytes, None]:
    """
    Makes the raw HTTP streaming call to the LLM backend.

    Yields raw byte chunks as received. On failure, logs the error to the
    request log and yields a single SSE-formatted error event instead.

    Args:
        messages: Chat messages to send as the request body.
        settings: Provides REAL_LLM_API_KEY and REAL_LLM_API_URL.
        log_id: Database row ID used by update_request_log.
    """
    headers = {
        "Authorization": f"Bearer {settings.REAL_LLM_API_KEY}",
        "Content-Type": "application/json",
    }
    payload = {
        "model": "default-model",
        "messages": [msg.model_dump() for msg in messages],
        "stream": True,
    }
    # Log the request payload to the database
    update_request_log(log_id, llm_request=payload)
    try:
        async with httpx.AsyncClient() as client:
            logger.info(f"Initiating raw stream to LLM API for log ID {log_id} at {settings.REAL_LLM_API_URL}")
            async with client.stream("POST", settings.REAL_LLM_API_URL, headers=headers, json=payload, timeout=60.0) as response:
                response.raise_for_status()
                async for chunk in response.aiter_bytes():
                    yield chunk
    except httpx.HTTPStatusError as e:
        error_message = f"LLM API returned an error during raw stream: {e.response.status_code}, response: '{e.response.text}'"
        logger.error(f"{error_message} for log ID {log_id}")
        update_request_log(log_id, llm_response={"error": error_message})
        # Bug fix: build the error event with json.dumps instead of raw byte
        # concatenation so the payload is always valid JSON.
        event = json.dumps({"error": "LLM API Error", "status_code": e.response.status_code})
        yield f"data: {event}\n\n".encode("utf-8")
    except httpx.RequestError as e:
        error_message = f"An error occurred during raw stream request to LLM API: {e}"
        logger.error(f"{error_message} for log ID {log_id}")
        update_request_log(log_id, llm_response={"error": error_message})
        # Bug fix: json.dumps escapes quotes/backslashes in str(e); the old
        # manual template produced invalid JSON for such error messages.
        event = json.dumps({"error": "Network Error", "details": str(e)})
        yield f"data: {event}\n\n".encode("utf-8")
async def stream_llm_api(messages: List[ChatMessage], settings: Settings, log_id: int) -> AsyncGenerator[bytes, None]:
    """
    Public interface for streaming.

    Passes through every raw chunk from the LLM backend while accumulating
    a decoded copy, then logs the complete response once the stream ends.
    """
    collected_text: List[str] = []
    async for raw_chunk in _raw_stream_from_llm(messages, settings, log_id):
        collected_text.append(raw_chunk.decode('utf-8', errors='ignore'))
        try:
            logger.info(f"Streaming chunk for log ID {log_id}: {raw_chunk.decode('utf-8').strip()}")
        except UnicodeDecodeError:
            # Chunk was not valid UTF-8; log the raw bytes instead.
            logger.info(f"Streaming chunk (undecodable) for log ID {log_id}: {raw_chunk}")
        yield raw_chunk
    # Log the full LLM response after the stream is complete
    update_request_log(log_id, llm_response={"content": "".join(collected_text)})
async def process_llm_stream_for_non_stream_request(
    messages: List[ChatMessage],
    settings: Settings,
    log_id: int
) -> Dict[str, Any]:
    """
    Aggregates a streaming LLM response into a single, non-streaming message.

    Handles SSE parsing, delta accumulation, and logs the final aggregated
    message.

    Bug fix: a single network chunk can carry several SSE events, and one
    event can be split across two chunks. The previous implementation
    parsed at most one event per chunk, silently dropping content. Incoming
    bytes are now accumulated in a line buffer so every complete ``data:``
    line is processed exactly once.
    """
    full_content_parts: List[str] = []
    final_message_dict: Dict[str, Any] = {"role": "assistant", "content": None}
    llm_response_chunks: List[str] = []
    line_buffer = ""  # holds a trailing partial line between chunks
    done = False
    async for chunk in _raw_stream_from_llm(messages, settings, log_id):
        decoded = chunk.decode('utf-8', errors='ignore')
        llm_response_chunks.append(decoded)
        line_buffer += decoded
        # Split off every complete line; keep the (possibly empty) remainder
        # after the last newline for the next chunk.
        *complete_lines, line_buffer = line_buffer.split("\n")
        for line in complete_lines:
            if not line.startswith("data:"):
                continue
            data_str = line[len("data:"):].strip()
            if data_str == "[DONE]":  # OpenAI-style stream terminator
                done = True
                break
            try:
                event = json.loads(data_str)
            except json.JSONDecodeError:
                logger.warning(f"Failed to decode JSON from SSE data: {data_str}")
                continue
            choices = event.get("choices")
            if choices and len(choices) > 0:
                delta = choices[0].get("delta")
                if delta and "content" in delta:
                    full_content_parts.append(delta["content"])
        if done:
            break
    final_message_dict["content"] = "".join(full_content_parts) if full_content_parts else None
    # Log the aggregated LLM response
    logger.info(f"Aggregated non-streaming response content for log ID {log_id}: {final_message_dict.get('content')}")
    update_request_log(log_id, llm_response=final_message_dict)
    return final_message_dict
async def process_chat_request(
    messages: List[ChatMessage],
    settings: Settings,
    log_id: int
) -> ResponseMessage:
    """
    Main service function for non-streaming requests.

    Aggregates the backend stream into one message dict, then parses it
    (handling both native and text-based tool calls) into a ResponseMessage.
    """
    aggregated_message = await process_llm_stream_for_non_stream_request(messages, settings, log_id)
    # Use the ResponseParser to handle both native and text-based tool calls
    return ResponseParser().parse_native_tool_calls(aggregated_message)