diff --git a/app/main.py b/app/main.py index 83bcfdb..cac379f 100644 --- a/app/main.py +++ b/app/main.py @@ -7,7 +7,7 @@ from fastapi import FastAPI, HTTPException, Depends, Request from starlette.responses import StreamingResponse from .models import IncomingRequest, ProxyResponse -from .services import process_chat_request, stream_llm_api, inject_tools_into_prompt, parse_llm_response_from_content +from .services import process_chat_request, stream_llm_api, inject_tools_into_prompt, parse_llm_response_from_content, _parse_sse_data from .core.config import get_settings, Settings from .database import init_db, log_request, update_request_log @@ -79,18 +79,104 @@ async def chat_completions( # Handle streaming request if request.stream: logger.info(f"Initiating streaming request for log ID: {log_id}") - + async def stream_and_log(): + import json + stream_content_buffer = [] + raw_chunks = [] + + # First, collect all chunks to detect if there are tool calls async for chunk in stream_llm_api(messages_to_llm, settings, log_id): - stream_content_buffer.append(chunk.decode('utf-8')) - yield chunk - - # After the stream is complete, parse the full content and log it + raw_chunks.append(chunk) + # Extract content from SSE chunks + parsed = _parse_sse_data(chunk) + if parsed and parsed.get("type") != "done": + choices = parsed.get("choices") + if choices and len(choices) > 0: + delta = choices[0].get("delta") + if delta and "content" in delta: + stream_content_buffer.append(delta["content"]) + + # Parse the complete content full_content = "".join(stream_content_buffer) response_message = parse_llm_response_from_content(full_content) + + # If tool_calls detected, send only OpenAI format tool_calls + if response_message.tool_calls: + logger.info(f"Tool calls detected in stream, sending OpenAI format for log ID {log_id}") + + # Send tool_calls chunks + for tc in response_message.tool_calls: + # Send tool call start + chunk_data = { + "id": "chatcmpl-" + str(log_id), + "object": 
"chat.completion.chunk", + "created": 0, + "model": "gpt-3.5-turbo", + "choices": [{ + "index": 0, + "delta": { + "tool_calls": [{ + "index": 0, + "id": tc.id, + "type": tc.type, + "function": { + "name": tc.function.name, + "arguments": "" + } + }] + }, + "finish_reason": None + }] + } + yield f"data: {json.dumps(chunk_data)}\n\n".encode('utf-8') + + # Split arguments into smaller chunks to simulate streaming + args = tc.function.arguments + chunk_size = 20 + for i in range(0, len(args), chunk_size): + chunk_data = { + "id": "chatcmpl-" + str(log_id), + "object": "chat.completion.chunk", + "created": 0, + "model": "gpt-3.5-turbo", + "choices": [{ + "index": 0, + "delta": { + "tool_calls": [{ + "index": 0, + "function": { + "arguments": args[i:i+chunk_size] + } + }] + }, + "finish_reason": None + }] + } + yield f"data: {json.dumps(chunk_data)}\n\n".encode('utf-8') + + # Send final chunk + final_chunk = { + "id": "chatcmpl-" + str(log_id), + "object": "chat.completion.chunk", + "created": 0, + "model": "gpt-3.5-turbo", + "choices": [{ + "index": 0, + "delta": {}, + "finish_reason": "tool_calls" + }] + } + yield f"data: {json.dumps(final_chunk)}\n\n".encode('utf-8') + + else: + # No tool calls, yield original chunks + for chunk in raw_chunks: + yield chunk + + # Log the response proxy_response = ProxyResponse(message=response_message) - logger.info(f"Streaming client response for log ID {log_id}:\n{proxy_response.model_dump_json(indent=2)}") update_request_log(log_id, client_response=proxy_response.model_dump()) diff --git a/app/response_parser.py b/app/response_parser.py index 275e4db..8d89ca0 100644 --- a/app/response_parser.py +++ b/app/response_parser.py @@ -22,10 +22,10 @@ logger = logging.getLogger(__name__) # Constants for tool call parsing -# Using XML-style tags for clarity and better compatibility with JSON -# LLM should emit:{"name": "...", "arguments": {...}} -TOOL_CALL_START_TAG = "{" -TOOL_CALL_END_TAG = "}" +# Using XML-style tags to avoid confusion 
with JSON braces +# LLM should emit: {"name": "...", "arguments": {...}} +TOOL_CALL_START_TAG = "<tool_call>" +TOOL_CALL_END_TAG = "</tool_call>" class ToolCallParseError(Exception): diff --git a/app/services.py b/app/services.py index fb89bee..4972b84 100644 --- a/app/services.py +++ b/app/services.py @@ -47,17 +47,16 @@ def inject_tools_into_prompt(messages: List[ChatMessage], tools: List[Tool]) -> tool_defs = json.dumps([tool.model_dump() for tool in tools], indent=2) - # Build the format example separately to avoid f-string escaping issues - # We need to show double braces: outer {{ }} are tags, inner { } is JSON + # Build the format example json_example = '{"name": "search", "arguments": {"query": "example"}}' - full_example = f'{{{json_example}}}' + full_example = f'{TOOL_CALL_START_TAG}{json_example}{TOOL_CALL_END_TAG}' tool_prompt = f""" You are a helpful assistant with access to a set of tools. You can call them by emitting a JSON object inside tool call tags. -IMPORTANT: Use double braces for tool calls - the outer braces are the tags ({TOOL_CALL_START_TAG} and {TOOL_CALL_END_TAG}), the inner braces are the JSON. -Format: {TOOL_CALL_START_TAG}{{\"name\": \"tool_name\", \"arguments\": {{...}}}}{TOOL_CALL_END_TAG} +IMPORTANT: Use the following format for tool calls: +Format: {TOOL_CALL_START_TAG}{{"name": "tool_name", "arguments": {{...}}}}{TOOL_CALL_END_TAG} Example: {full_example} diff --git a/tests/test_response_parser.py b/tests/test_response_parser.py index b549535..b9d3f3a 100644 --- a/tests/test_response_parser.py +++ b/tests/test_response_parser.py @@ -46,7 +46,7 @@ class TestResponseParser: """Test parsing a response with a single tool call.""" parser = ResponseParser() text = f'''I'll check the weather for you. 
-{TOOL_CALL_START_TAG} +<tool_call> {{ "name": "get_weather", "arguments": {{ @@ -54,7 +54,7 @@ class TestResponseParser: "units": "celsius" }} }} -{TOOL_CALL_END_TAG} +</tool_call> ''' result = parser.parse(text) @@ -74,14 +74,14 @@ class TestResponseParser: def test_parse_response_with_tool_call_no_content(self): """Test parsing a response with only a tool call.""" parser = ResponseParser() - text = f'''{TOOL_CALL_START_TAG} + text = f'''<tool_call> {{ "name": "shell", "arguments": {{ "command": ["ls", "-l"] }} }} -{TOOL_CALL_END_TAG} +</tool_call> ''' result = parser.parse(text) @@ -95,9 +95,9 @@ class TestResponseParser: """Test parsing a response with malformed JSON in tool call.""" parser = ResponseParser() text = f'''Here's the result. -{TOOL_CALL_START_TAG} +<tool_call> {{invalid json}} -{TOOL_CALL_END_TAG} +</tool_call> ''' result = parser.parse(text) @@ -109,13 +109,13 @@ class TestResponseParser: def test_parse_response_with_missing_tool_name(self): """Test parsing a tool call without a name field.""" parser = ResponseParser() - text = f'''{TOOL_CALL_START_TAG} + text = f'''<tool_call> {{ "arguments": {{ "command": "echo hello" }} }} -{TOOL_CALL_END_TAG} +</tool_call> ''' result = parser.parse(text) @@ -129,7 +129,7 @@ class TestResponseParser: """Test parsing a tool call with complex nested arguments.""" parser = ResponseParser() text = f'''Executing command. 
-{TOOL_CALL_START_TAG} +<tool_call> {{ "name": "shell", "arguments": {{ @@ -140,7 +140,7 @@ class TestResponseParser: }} }} }} -{TOOL_CALL_END_TAG} +</tool_call> ''' result = parser.parse(text) @@ -182,7 +182,7 @@ class TestResponseParser: chunks = [ "I'll run that ", "command for you.", - f'{TOOL_CALL_START_TAG}\n{{"name": "shell", "arguments": {{"command": ["echo", "hello"]}}}}\n{TOOL_CALL_END_TAG}' + f'<tool_call>\n{{"name": "shell", "arguments": {{"command": ["echo", "hello"]}}}}\n</tool_call>' ] result = parser.parse_streaming_chunks(chunks) @@ -267,8 +267,8 @@ class TestResponseParser: """Test that tool call IDs are unique.""" parser = ResponseParser() - text1 = f'{TOOL_CALL_START_TAG}{{"name": "tool1", "arguments": {{}}}}{TOOL_CALL_END_TAG}' - text2 = f'{TOOL_CALL_START_TAG}{{"name": "tool2", "arguments": {{}}}}{TOOL_CALL_END_TAG}' + text1 = f'<tool_call>{{"name": "tool1", "arguments": {{}}}}</tool_call>' + text2 = f'<tool_call>{{"name": "tool2", "arguments": {{}}}}</tool_call>' result1 = parser.parse(text1) result2 = parser.parse(text2) @@ -286,7 +286,7 @@ class TestConvenienceFunctions: def test_parse_response_default_parser(self): """Test the parse_response convenience function.""" - text = f'{TOOL_CALL_START_TAG}{{"name": "search", "arguments": {{"query": "test"}}}}{TOOL_CALL_END_TAG}' + text = f'<tool_call>{{"name": "search", "arguments": {{"query": "test"}}}}</tool_call>' result = parse_response(text) assert result.tool_calls is not None @@ -332,14 +332,14 @@ class TestEdgeCases: parser = ResponseParser() special_chars = '@#$%^&*()' text = f'''Here's the result with special chars: {special_chars} -{TOOL_CALL_START_TAG} +<tool_call> {{ "name": "test", "arguments": {{ "special": "!@#$%" }} }} -{TOOL_CALL_END_TAG} +</tool_call> ''' result = parser.parse(text) @@ -349,7 +349,7 @@ class TestEdgeCases: def test_response_with_escaped_quotes(self): """Test parsing tool calls with escaped quotes in arguments.""" parser = ResponseParser() - text = f'{TOOL_CALL_START_TAG}{{"name": "echo", "arguments": {{"message": "Hello \\"world\\""}}}}{TOOL_CALL_END_TAG}' + text = f'<tool_call>{{"name": "echo", "arguments": {{"message": "Hello \\"world\\""}}}}</tool_call>' result = parser.parse(text) arguments = json.loads(result.tool_calls[0].function.arguments) @@ -359,13 +359,13 @@ class TestEdgeCases: """Test that only the first tool call is extracted.""" parser = ResponseParser() text = f'''First call. -{TOOL_CALL_START_TAG} +<tool_call> {{"name": "tool1", "arguments": {{}}}} -{TOOL_CALL_END_TAG} +</tool_call> Some text in between. -{TOOL_CALL_START_TAG} +<tool_call> {{"name": "tool2", "arguments": {{}}}} -{TOOL_CALL_END_TAG} +</tool_call> ''' result = parser.parse(text)