diff --git a/app/main.py b/app/main.py
index 83bcfdb..cac379f 100644
--- a/app/main.py
+++ b/app/main.py
@@ -7,7 +7,7 @@ from fastapi import FastAPI, HTTPException, Depends, Request
from starlette.responses import StreamingResponse
from .models import IncomingRequest, ProxyResponse
-from .services import process_chat_request, stream_llm_api, inject_tools_into_prompt, parse_llm_response_from_content
+from .services import process_chat_request, stream_llm_api, inject_tools_into_prompt, parse_llm_response_from_content, _parse_sse_data
from .core.config import get_settings, Settings
from .database import init_db, log_request, update_request_log
@@ -79,18 +79,104 @@ async def chat_completions(
# Handle streaming request
if request.stream:
logger.info(f"Initiating streaming request for log ID: {log_id}")
-
+
async def stream_and_log():
+ import json
+
stream_content_buffer = []
+ raw_chunks = []
+
+ # First, collect all chunks to detect if there are tool calls
async for chunk in stream_llm_api(messages_to_llm, settings, log_id):
- stream_content_buffer.append(chunk.decode('utf-8'))
- yield chunk
-
- # After the stream is complete, parse the full content and log it
+ raw_chunks.append(chunk)
+ # Extract content from SSE chunks
+ parsed = _parse_sse_data(chunk)
+ if parsed and parsed.get("type") != "done":
+ choices = parsed.get("choices")
+ if choices and len(choices) > 0:
+ delta = choices[0].get("delta")
+ if delta and "content" in delta:
+ stream_content_buffer.append(delta["content"])
+
+ # Parse the complete content
full_content = "".join(stream_content_buffer)
response_message = parse_llm_response_from_content(full_content)
+
+ # If tool_calls detected, send only OpenAI format tool_calls
+ if response_message.tool_calls:
+ logger.info(f"Tool calls detected in stream, sending OpenAI format for log ID {log_id}")
+
+ # Send tool_calls chunks
+ for tc in response_message.tool_calls:
+ # Send tool call start
+ chunk_data = {
+ "id": "chatcmpl-" + str(log_id),
+ "object": "chat.completion.chunk",
+ "created": 0,
+ "model": "gpt-3.5-turbo",
+ "choices": [{
+ "index": 0,
+ "delta": {
+ "tool_calls": [{
+ "index": 0,
+ "id": tc.id,
+ "type": tc.type,
+ "function": {
+ "name": tc.function.name,
+ "arguments": ""
+ }
+ }]
+ },
+ "finish_reason": None
+ }]
+ }
+ yield f"data: {json.dumps(chunk_data)}\n\n".encode('utf-8')
+
+ # Split arguments into smaller chunks to simulate streaming
+ args = tc.function.arguments
+ chunk_size = 20
+ for i in range(0, len(args), chunk_size):
+ chunk_data = {
+ "id": "chatcmpl-" + str(log_id),
+ "object": "chat.completion.chunk",
+ "created": 0,
+ "model": "gpt-3.5-turbo",
+ "choices": [{
+ "index": 0,
+ "delta": {
+ "tool_calls": [{
+ "index": 0,
+ "function": {
+ "arguments": args[i:i+chunk_size]
+ }
+ }]
+ },
+ "finish_reason": None
+ }]
+ }
+ yield f"data: {json.dumps(chunk_data)}\n\n".encode('utf-8')
+
+ # Send final chunk
+ final_chunk = {
+ "id": "chatcmpl-" + str(log_id),
+ "object": "chat.completion.chunk",
+ "created": 0,
+ "model": "gpt-3.5-turbo",
+ "choices": [{
+ "index": 0,
+ "delta": {},
+ "finish_reason": "tool_calls"
+ }]
+ }
+ yield f"data: {json.dumps(final_chunk)}\n\n".encode('utf-8')
+
+ else:
+ # No tool calls, yield original chunks
+ for chunk in raw_chunks:
+ yield chunk
+
+ # Log the response
proxy_response = ProxyResponse(message=response_message)
-
logger.info(f"Streaming client response for log ID {log_id}:\n{proxy_response.model_dump_json(indent=2)}")
update_request_log(log_id, client_response=proxy_response.model_dump())
diff --git a/app/response_parser.py b/app/response_parser.py
index 275e4db..8d89ca0 100644
--- a/app/response_parser.py
+++ b/app/response_parser.py
@@ -22,10 +22,10 @@ logger = logging.getLogger(__name__)
# Constants for tool call parsing
-# Using XML-style tags for clarity and better compatibility with JSON
-# LLM should emit:{"name": "...", "arguments": {...}}
-TOOL_CALL_START_TAG = "{"
-TOOL_CALL_END_TAG = "}"
+# Using XML-style tags to avoid confusion with JSON braces
+# LLM should emit: <tool_call>{"name": "...", "arguments": {...}}</tool_call>
+TOOL_CALL_START_TAG = "<tool_call>"
+TOOL_CALL_END_TAG = "</tool_call>"
class ToolCallParseError(Exception):
diff --git a/app/services.py b/app/services.py
index fb89bee..4972b84 100644
--- a/app/services.py
+++ b/app/services.py
@@ -47,17 +47,16 @@ def inject_tools_into_prompt(messages: List[ChatMessage], tools: List[Tool]) ->
tool_defs = json.dumps([tool.model_dump() for tool in tools], indent=2)
- # Build the format example separately to avoid f-string escaping issues
- # We need to show double braces: outer {{ }} are tags, inner { } is JSON
+ # Build the format example
json_example = '{"name": "search", "arguments": {"query": "example"}}'
- full_example = f'{{{json_example}}}'
+ full_example = f'{TOOL_CALL_START_TAG}{json_example}{TOOL_CALL_END_TAG}'
tool_prompt = f"""
You are a helpful assistant with access to a set of tools.
You can call them by emitting a JSON object inside tool call tags.
-IMPORTANT: Use double braces for tool calls - the outer braces are the tags ({TOOL_CALL_START_TAG} and {TOOL_CALL_END_TAG}), the inner braces are the JSON.
-Format: {TOOL_CALL_START_TAG}{{\"name\": \"tool_name\", \"arguments\": {{...}}}}{TOOL_CALL_END_TAG}
+IMPORTANT: Use the following format for tool calls:
+Format: {TOOL_CALL_START_TAG}{{"name": "tool_name", "arguments": {{...}}}}{TOOL_CALL_END_TAG}
Example: {full_example}
diff --git a/tests/test_response_parser.py b/tests/test_response_parser.py
index b549535..b9d3f3a 100644
--- a/tests/test_response_parser.py
+++ b/tests/test_response_parser.py
@@ -46,7 +46,7 @@ class TestResponseParser:
"""Test parsing a response with a single tool call."""
parser = ResponseParser()
text = f'''I'll check the weather for you.
-{TOOL_CALL_START_TAG}
+<tool_call>
{{
"name": "get_weather",
"arguments": {{
@@ -54,7 +54,7 @@ class TestResponseParser:
"units": "celsius"
}}
}}
-{TOOL_CALL_END_TAG}
+</tool_call>
'''
result = parser.parse(text)
@@ -74,14 +74,14 @@ class TestResponseParser:
def test_parse_response_with_tool_call_no_content(self):
"""Test parsing a response with only a tool call."""
parser = ResponseParser()
- text = f'''{TOOL_CALL_START_TAG}
+    text = f'''<tool_call>
{{
"name": "shell",
"arguments": {{
"command": ["ls", "-l"]
}}
}}
-{TOOL_CALL_END_TAG}
+</tool_call>
'''
result = parser.parse(text)
@@ -95,9 +95,9 @@ class TestResponseParser:
"""Test parsing a response with malformed JSON in tool call."""
parser = ResponseParser()
text = f'''Here's the result.
-{TOOL_CALL_START_TAG}
+<tool_call>
{{invalid json}}
-{TOOL_CALL_END_TAG}
+</tool_call>
'''
result = parser.parse(text)
@@ -109,13 +109,13 @@ class TestResponseParser:
def test_parse_response_with_missing_tool_name(self):
"""Test parsing a tool call without a name field."""
parser = ResponseParser()
- text = f'''{TOOL_CALL_START_TAG}
+    text = f'''<tool_call>
{{
"arguments": {{
"command": "echo hello"
}}
}}
-{TOOL_CALL_END_TAG}
+</tool_call>
'''
result = parser.parse(text)
@@ -129,7 +129,7 @@ class TestResponseParser:
"""Test parsing a tool call with complex nested arguments."""
parser = ResponseParser()
text = f'''Executing command.
-{TOOL_CALL_START_TAG}
+<tool_call>
{{
"name": "shell",
"arguments": {{
@@ -140,7 +140,7 @@ class TestResponseParser:
}}
}}
}}
-{TOOL_CALL_END_TAG}
+</tool_call>
'''
result = parser.parse(text)
@@ -182,7 +182,7 @@ class TestResponseParser:
chunks = [
"I'll run that ",
"command for you.",
- f'{TOOL_CALL_START_TAG}\n{{"name": "shell", "arguments": {{"command": ["echo", "hello"]}}}}\n{TOOL_CALL_END_TAG}'
+        f'<tool_call>\n{{"name": "shell", "arguments": {{"command": ["echo", "hello"]}}}}\n</tool_call>'
]
result = parser.parse_streaming_chunks(chunks)
@@ -267,8 +267,8 @@ class TestResponseParser:
"""Test that tool call IDs are unique."""
parser = ResponseParser()
- text1 = f'{TOOL_CALL_START_TAG}{{"name": "tool1", "arguments": {{}}}}{TOOL_CALL_END_TAG}'
- text2 = f'{TOOL_CALL_START_TAG}{{"name": "tool2", "arguments": {{}}}}{TOOL_CALL_END_TAG}'
+        text1 = f'<tool_call>{{"name": "tool1", "arguments": {{}}}}</tool_call>'
+        text2 = f'<tool_call>{{"name": "tool2", "arguments": {{}}}}</tool_call>'
result1 = parser.parse(text1)
result2 = parser.parse(text2)
@@ -286,7 +286,7 @@ class TestConvenienceFunctions:
def test_parse_response_default_parser(self):
"""Test the parse_response convenience function."""
- text = f'{TOOL_CALL_START_TAG}{{"name": "search", "arguments": {{"query": "test"}}}}{TOOL_CALL_END_TAG}'
+        text = f'<tool_call>{{"name": "search", "arguments": {{"query": "test"}}}}</tool_call>'
result = parse_response(text)
assert result.tool_calls is not None
@@ -332,14 +332,14 @@ class TestEdgeCases:
parser = ResponseParser()
special_chars = '@#$%^&*()'
text = f'''Here's the result with special chars: {special_chars}
-{TOOL_CALL_START_TAG}
+<tool_call>
{{
"name": "test",
"arguments": {{
"special": "!@#$%"
}}
}}
-{TOOL_CALL_END_TAG}
+</tool_call>
'''
result = parser.parse(text)
@@ -349,7 +349,7 @@ class TestEdgeCases:
def test_response_with_escaped_quotes(self):
"""Test parsing tool calls with escaped quotes in arguments."""
parser = ResponseParser()
- text = f'{TOOL_CALL_START_TAG}{{"name": "echo", "arguments": {{"message": "Hello \\"world\\""}}}}{TOOL_CALL_END_TAG}'
+        text = f'<tool_call>{{"name": "echo", "arguments": {{"message": "Hello \\"world\\""}}}}</tool_call>'
result = parser.parse(text)
arguments = json.loads(result.tool_calls[0].function.arguments)
@@ -359,13 +359,13 @@ class TestEdgeCases:
"""Test that only the first tool call is extracted."""
parser = ResponseParser()
text = f'''First call.
-{TOOL_CALL_START_TAG}
+<tool_call>
{{"name": "tool1", "arguments": {{}}}}
-{TOOL_CALL_END_TAG}
+</tool_call>
Some text in between.
-{TOOL_CALL_START_TAG}
+<tool_call>
{{"name": "tool2", "arguments": {{}}}}
-{TOOL_CALL_END_TAG}
+</tool_call>
'''
result = parser.parse(text)