Compare commits

...

7 Commits

Author SHA1 Message Date
yhydev
9b0c32b6f2 Support ghcproxy 2026-01-12 14:12:15 +08:00
Vertex-AI-Step-Builder
03e216373f fix: correct the startup command path in the Dockerfile
Changed CMD from 'uvicorn main:app' to 'uvicorn app.main:app'
to match the actual application entry point (app/main.py).

🤖 Generated with [Claude Code](https://claude.com/claude-code)

Co-Authored-By: Claude Sonnet 4.5 <noreply@anthropic.com>
2025-12-31 16:05:36 +00:00
Vertex-AI-Step-Builder
912b027864 feat: add requirements, Dockerfile and docker-compose 2025-12-31 15:30:00 +00:00
Vertex-AI-Step-Builder
fa419ccac4 Multi-backend support documentation 2025-12-31 15:22:40 +00:00
Vertex-AI-Step-Builder
cecfc74a96 Improve prompts and handle missing tool-call closing tags 2025-12-31 14:20:20 +00:00
Vertex-AI-Step-Builder
6bcdbc2560 docs: update the README with details on the new features
Main updates:
- Added a section on message-history conversion (4.1)
- Updated the response-parser feature notes (4.3)
- Added a "Key Features" chapter (8)
- Added an API request example with message history (7.2)
- Added a changelog chapter (10)
- Expanded the test-script documentation

🤖 Generated with [Claude Code](https://claude.com/claude-code)

Co-Authored-By: Claude Sonnet 4.5 <noreply@anthropic.com>
2025-12-31 13:35:45 +00:00
Vertex-AI-Step-Builder
5c2904e010 feat: enhance the tool-call proxy with multi-tool-call support and message-history conversion
Main improvements:
- Added the convert_tool_calls_to_content function, which converts tool_calls in the message history into an XML format the LLM can understand
- Fixed response_parser to parse multiple tool_calls at once
- Improved the response-parsing logic so that content and tool_calls can coexist
- Added full test coverage, including multiple tool calls, message conversion, and mixed responses

Technical details:
- services.py: convert tool-call history into content
- response_parser.py: non-greedy matching to parse multiple tool_calls
- main.py: integrate message conversion so the history is passed to the LLM correctly

🤖 Generated with [Claude Code](https://claude.com/claude-code)

Co-Authored-By: Claude Sonnet 4.5 <noreply@anthropic.com>
2025-12-31 13:33:25 +00:00
13 changed files with 2391 additions and 86 deletions

3
.gitignore vendored

@@ -133,3 +133,6 @@ dmypy.json
# Cython debug symbols
cython_debug/
# logs
logs/

11
Dockerfile Normal file

@@ -0,0 +1,11 @@
FROM hub.rat.dev/library/python:3.10-alpine
WORKDIR /app
COPY requirements.txt .
RUN pip install --no-cache-dir -i https://pypi.tuna.tsinghua.edu.cn/simple -r requirements.txt
COPY . .
CMD ["uvicorn", "app.ghcproxy:app", "--host", "0.0.0.0", "--port", "8000"]

275
README.md

@@ -2,68 +2,131 @@
## 1. 概述 (Overview)
本项目是一个基于 FastAPI 实现的智能LLM大语言模型代理服务。其核心功能是拦截发往LLMAPI请求动态地将客户端定义的`tools`工具信息注入到提示词Prompt然后将LLM返回的结果进行解析将其中可能包含的工具调用Tool Call指令提取出来最后以结构化的格式返回给调用者。
本项目是一个基于 FastAPI 实现的智能 LLM大语言模型代理服务。其核心功能是拦截发往 LLMAPI 请求,动态地将客户端定义的 `tools`工具信息注入到提示词Prompt然后将 LLM 返回的结果进行解析将其中可能包含的工具调用Tool Call指令提取出来最后以结构化的格式返回给调用者。
这使得即使底层LLM原生不支持工具调用参数我们也能通过提示工程的方式赋予其使用工具的能力。
这使得即使底层 LLM 原生不支持工具调用参数,我们也能通过提示工程的方式赋予其使用工具的能力。
## 2. 设计原则 (Design Principles)
本程序在设计上严格遵循了以下原则:
- **高内聚 (High Cohesion)**: 业务逻辑被集中在服务层 (`app/services.py`) 中与API路由和数据模型分离。
- **高内聚 (High Cohesion)**: 业务逻辑被集中在服务层 (`app/services.py`) 中,与 API 路由和数据模型分离。
- **低耦合 (Low Coupling)**:
- API层 (`app/main.py`) 只负责路由和请求校验,不关心业务实现细节。
- API 层 (`app/main.py`) 只负责路由和请求校验,不关心业务实现细节。
- 通过依赖注入 (`Depends`) 获取配置,避免了全局状态。
- LLM调用被抽象为独立的函数方便未来切换不同的LLM后端或在测试中使用模拟Mock实现。
- **可测试性 (Testability)**: 项目包含了完整的单元测试和集成测试 (`tests/`)使用 `pytest``TestClient`确保每个模块的正确性和整体流程的稳定性。
- LLM 调用被抽象为独立的函数,方便未来切换不同的 LLM 后端或在测试中使用模拟Mock实现。
- **可测试性 (Testability)**: 项目包含了完整的单元测试和集成测试 (`tests/`)以及功能测试脚本,确保每个模块的正确性和整体流程的稳定性。
## 3. 项目结构 (Project Structure)
```
.
├── app/ # 核心应用代码
│ ├── core/ # 配置管理
│ │ └── config.py
│ ├── main.py # FastAPI 应用实例和 API 路由
│ ├── models.py # Pydantic 数据模型
│   └── services.py # 核心业务逻辑
├── tests/ # 测试代码
│ └── test_main.py
├── .env # 环境变量文件 (需手动创建)
├── .gitignore # Git 忽略文件
├── README.md # 本文档
└── .venv/ # Python 虚拟环境 (由 uv 创建)
├── app/ # 核心应用代码
│ ├── core/ # 配置管理
│ │ └── config.py # 环境变量配置
│ ├── main.py # FastAPI 应用实例和 API 路由
│ ├── models.py # Pydantic 数据模型
│   ├── services.py # 核心业务逻辑
│   ├── response_parser.py # 响应解析器(工具调用提取)
│ └── database.py # 数据库操作(请求日志)
├── tests/ # 测试代码
│ ├── test_main.py
│ ├── test_services.py
│ └── test_response_parser.py
├── test_*.py # 功能测试脚本
│ ├── test_tool_call_conversion.py # 工具调用转换测试
│ ├── test_multiple_tool_calls.py # 多工具调用测试
│ └── test_content_with_tool_calls.py # 内容和工具调用混合测试
├── .env # 环境变量文件 (需手动创建)
├── .gitignore # Git 忽略文件
├── README.md # 本文档
└── .venv/ # Python 虚拟环境
```
## 4. 核心逻辑详解 (Core Logic)
### 4.1. 提示词注入 (Prompt Injection)
### 4.1. 消息历史转换 (Message History Conversion)
**新增功能** - 这是本次更新的核心功能之一。
- **实现函数**: `app.services.convert_tool_calls_to_content`
- **策略**:
1. 遍历消息历史,识别 `role``assistant` 且包含 `tool_calls` 的消息。
2. 将这些消息中的 `tool_calls` 转换为 LLM 可理解的 XML 标签格式 `<invoke>{"name": "tool_name", "arguments": {...}}</invoke>`。
3. 保留消息原有的 `content` 字段(如果存在)。
4. 支持多个 tool_calls 的转换,用换行符连接。
- **目的**: 确保消息历史中的工具调用能够被底层 LLM 理解,保持对话上下文的连贯性。
**转换示例**:
```python
# 转换前
{
"role": "assistant",
"tool_calls": [
{"function": {"name": "get_weather", "arguments": '{"location": "北京"}'}}
]
}
# 转换后(发送给 LLM
{
"role": "assistant",
"content": "<invoke>{\"name\": \"get_weather\", \"arguments\": {\"location\": \"北京\"}}</invoke>"
}
```
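这一转换步骤可以用一个独立的小函数来概括(`convert_tool_calls_to_content` 的简化示意;这里为便于演示使用 dict 形式的消息,实际实现操作的是 `app/services.py` 中的 `ChatMessage` 模型):

```python
import json

# Tag constants assumed for illustration; they match the examples above.
TOOL_CALL_START_TAG = "<invoke>"
TOOL_CALL_END_TAG = "</invoke>"

def convert_tool_calls_to_content(messages):
    """Rewrite assistant messages that carry tool_calls into plain content."""
    out = []
    for msg in messages:
        tool_calls = msg.get("tool_calls") or []
        if msg.get("role") == "assistant" and tool_calls:
            # Preserve any original content, then append one tag per call.
            parts = [msg["content"]] if msg.get("content") else []
            for tc in tool_calls:
                fn = tc.get("function", {})
                try:
                    args = json.loads(fn.get("arguments", "{}"))
                except json.JSONDecodeError:
                    args = {}
                payload = json.dumps(
                    {"name": fn.get("name", ""), "arguments": args},
                    ensure_ascii=False,
                )
                parts.append(f"{TOOL_CALL_START_TAG}{payload}{TOOL_CALL_END_TAG}")
            out.append({"role": "assistant", "content": "\n".join(parts)})
        else:
            out.append(msg)  # non-assistant or tool-call-free messages pass through
    return out
```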
### 4.2. 提示词注入 (Prompt Injection)
- **实现函数**: `app.services.inject_tools_into_prompt`
- **策略**:
1. 将客户端请求中 `tools` 列表JSON数组序列化为格式化的JSON字符串。
1. 将客户端请求中 `tools` 列表JSON数组序列化为格式化的 JSON 字符串。
2. 创建一个新的、`role``system` 的独立消息。
3. 此消息包含明确的指令告诉LLM它拥有哪些工具以及如何通过特定的格式来调用它们。
4. **调用格式约定**: 指示LLM在需要调用工具时必须输出一个 `<tool_call>{...}</tool_call>` 的XML标签其中包含一个带有 `name``arguments` 字段的JSON对象
5. 这个系统消息被插入到原始消息列表的第二个位置索引1然后整个修改后的消息列表被发送到真实的LLM后端
- **目的**: 对调用者透明,将工具使用的契约通过上下文传递给LLM。
3. 此消息包含明确的指令,告诉 LLM 它拥有哪些工具以及如何通过特定的格式来调用它们。
4. **调用格式约定**: 指示 LLM 在需要调用工具时,必须输出一个 `<invoke>{"name": "tool_name", "arguments": {...}}</invoke>` 形式的 XML 标签
5. 这个系统消息被插入到消息列表的开头
- **目的**: 对调用者透明,将工具使用的"契约"通过上下文传递给 LLM。
### 4.2. 响应解析 (Response Parsing)
### 4.3. 响应解析 (Response Parsing)
- **实现函数**: `app.services.parse_llm_response`
- **实现**: `app.response_parser.ResponseParser`
- **策略**:
1. 使用正则表达式 (`re.search`) 在LLM返回的文本响应中查找 `<tool_call>...</tool_call>` 标签。
2. 如果找到它会提取标签内的JSON字符串并将其解析为一个结构化的 `ToolCall` 对象。此时,返回给客户端的 `ResponseMessage``tool_calls` 字段将被填充,而 `content` 字段可能为 `None`
3. 如果未找到标签则将LLM的全部响应视为常规的文本内容,填充 `content` 字段。
- **目的**: 将LLM的非结构化或半结构化输出转换为客户端可以轻松处理的、定义良好的结构化数据
1. 使用**非贪婪正则表达式**在 LLM 返回的文本响应中查找**所有** `<invoke>...</invoke>` 标签。
2. 支持同时解析**多个 tool_calls**
3. 提取工具调用前后的文本内容,合并到 `content` 字段。
4. 如果找到工具调用,将 `tool_calls` 字段填充为结构化的 `ToolCall` 对象列表
5. 如果未找到标签,则将 LLM 的全部响应视为常规的文本内容。
- **新特性**:
-**支持多个 tool_calls 同时解析** - 使用非贪婪匹配和 finditer
-**支持 content 和 tool_calls 同时存在** - 符合 OpenAI API 规范
-**支持文本在前、在后或前后都有文本的场景**
- **目的**: 将 LLM 的非结构化输出转换为标准的 OpenAI 格式响应。
**响应示例**:
```json
{
"message": {
"role": "assistant",
"content": "好的,我来帮你查询。",
"tool_calls": [
{
"id": "call_123",
"type": "function",
"function": {
"name": "get_weather",
"arguments": "{\"location\": \"北京\"}"
}
}
]
}
}
```
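上述非贪婪、多工具调用的解析思路可以用几行代码示意(`app/response_parser.py` 的简化替身;`<invoke>` 标签与文本合并规则以本 README 中的示例为准):

```python
import json
import re

# Non-greedy body match: stops at the first closing tag, so several
# <invoke>...</invoke> blocks in one reply are matched separately.
TAG_RE = re.compile(r"<invoke>(.*?)</invoke>", re.DOTALL)

def parse_response(text: str):
    """Return (content, tool_calls) extracted from a raw LLM reply."""
    calls = [json.loads(m.group(1)) for m in TAG_RE.finditer(text)]
    # split() with a capturing group alternates text and captures;
    # every other element is the surrounding prose.
    text_parts = [p.strip() for p in TAG_RE.split(text)[::2]]
    content = " ".join(p for p in text_parts if p) or None
    return content, calls
```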
## 5. 配置管理 (Configuration)
- 配置文件为根目录下的 `.env`
- `app/core/config.py` 中的 `get_settings` 函数通过依赖注入的方式在每次请求时加载环境变量,确保配置的实时性和在测试中的灵活性
- `app/core/config.py` 中的 `get_settings` 函数通过依赖注入的方式在每次请求时加载环境变量。
- **必需变量**:
- `REAL_LLM_API_URL`: 真实LLM后端的地址
- `REAL_LLM_API_KEY`: 用于访问真实LLMAPI密钥
- `REAL_LLM_API_URL`: 真实 LLM 后端的地址
- `REAL_LLM_API_KEY`: 用于访问真实 LLMAPI 密钥
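按请求加载配置的思路可以用下面的草图说明(仅作示意;`app/core/config.py` 实际使用 python-dotenv,而非这里手写的 `.env` 读取器):

```python
import os

def load_env_file(path: str = ".env") -> None:
    """Minimal .env loader (the project itself relies on python-dotenv)."""
    if not os.path.exists(path):
        return
    with open(path) as f:
        for line in f:
            line = line.strip()
            if line and not line.startswith("#") and "=" in line:
                key, _, value = line.partition("=")
                # Existing environment variables win over .env entries.
                os.environ.setdefault(key.strip(), value.strip().strip('"'))

class Settings:
    def __init__(self):
        load_env_file()
        self.REAL_LLM_API_URL = os.environ.get("REAL_LLM_API_URL", "")
        self.REAL_LLM_API_KEY = os.environ.get("REAL_LLM_API_KEY", "")

def get_settings() -> Settings:
    # A fresh instance per call; wired into FastAPI routes via
    # `settings: Settings = Depends(get_settings)`.
    return Settings()
```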
## 6. 如何运行与测试 (Usage)
@@ -71,28 +134,49 @@
```bash
# 创建虚拟环境
uv venv
python -m venv .venv
# 激活虚拟环境
source .venv/bin/activate # Linux/Mac
# 或
.venv\Scripts\activate # Windows
# 安装依赖
uv pip install fastapi uvicorn httpx pytest
pip install fastapi uvicorn httpx pytest python-dotenv
```
### 6.2. 运行开发服务器
### 6.2. 配置环境变量
创建 `.env` 文件:
```bash
REAL_LLM_API_URL="https://api.example.com/v1/chat/completions"
REAL_LLM_API_KEY="your-api-key"
```
### 6.3. 运行开发服务器
```bash
uvicorn app.main:app --reload
```
服务将运行在 `http://127.0.0.1:8000`
### 6.3. 运行测试
### 6.4. 运行测试
```bash
# 使用 .venv 中的 python 解释器执行 pytest
.venv/bin/python -m pytest
# 运行所有单元测试
pytest
# 运行功能测试脚本
python test_tool_call_conversion.py # 测试工具调用转换
python test_multiple_tool_calls.py # 测试多工具调用
python test_content_with_tool_calls.py # 测试内容和工具调用混合
```
## 7. API 端点示例 (API Example)
### 7.1. 基本请求
**端点**: `POST /v1/chat/completions`
**请求示例 (带工具)**:
@@ -101,7 +185,7 @@ curl -X POST "http://127.0.0.1:8000/v1/chat/completions" \
-H "Content-Type: application/json" \
-d '{
"messages": [
{"role": "user", "content": "What is the weather in San Francisco?"}
{"role": "user", "content": "What is the weather in Beijing?"}
],
"tools": [
{
@@ -109,16 +193,115 @@ curl -X POST "http://127.0.0.1:8000/v1/chat/completions" \
"function": {
"name": "get_weather",
"description": "Get weather for a city",
"parameters": {}
"parameters": {
"type": "object",
"properties": {
"location": {"type": "string"}
}
}
}
}
]
}'
```
## 8. 未来升级方向 (Future Improvements)
### 7.2. 带消息历史的请求
- **支持多种LLM后端**: 修改 `call_llm_api_real` 函数使其能根据请求参数或配置选择不同的LLM提供商。
- **更灵活的工具调用格式**: 支持除XML标签外的其他格式例如纯JSON输出模式。
- **流式响应 (Streaming)**: 支持LLM的流式输出并实时解析和返回给客户端。
- **错误处理增强**: 针对不同的LLM API错误码和网络问题提供更精细的错误反馈。
```bash
curl -X POST "http://127.0.0.1:8000/v1/chat/completions" \
-H "Content-Type: application/json" \
-d '{
"messages": [
{"role": "user", "content": "What is the weather in Beijing?"},
{
"role": "assistant",
"tool_calls": [
{
"id": "call_123",
"type": "function",
"function": {
"name": "get_weather",
"arguments": "{\"location\": \"Beijing\"}"
}
}
]
},
{
"role": "tool",
"tool_call_id": "call_123",
"content": "Temperature: 20°C, Sunny"
},
{"role": "user", "content": "What about Shanghai?"}
],
"tools": [...]
}'
```
注意:消息历史中的 `tool_calls` 会被自动转换为 XML 格式发送给 LLM。
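同样的带历史请求也可以用 Python 构造(示意代码,仅使用标准库;URL、工具 schema 与负载结构沿用上面的 curl 示例,实际发送部分放在 `__main__` 保护块中,假定代理已在本地运行):

```python
import json

def build_payload() -> dict:
    """Assemble a chat request whose history already contains a tool call."""
    return {
        "messages": [
            {"role": "user", "content": "What is the weather in Beijing?"},
            {"role": "assistant", "tool_calls": [{
                "id": "call_123", "type": "function",
                "function": {"name": "get_weather",
                             "arguments": json.dumps({"location": "Beijing"})},
            }]},
            {"role": "tool", "tool_call_id": "call_123",
             "content": "Temperature: 20°C, Sunny"},
            {"role": "user", "content": "What about Shanghai?"},
        ],
        "tools": [{  # minimal illustrative tool schema
            "type": "function",
            "function": {"name": "get_weather",
                         "description": "Get weather for a city",
                         "parameters": {"type": "object",
                                        "properties": {"location": {"type": "string"}}}},
        }],
    }

if __name__ == "__main__":
    import urllib.request
    req = urllib.request.Request(
        "http://127.0.0.1:8000/v1/chat/completions",
        data=json.dumps(build_payload()).encode(),
        headers={"Content-Type": "application/json"},
    )
    with urllib.request.urlopen(req, timeout=60) as r:
        print(r.read().decode())
```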
## 8. 关键特性说明 (Key Features)
### 8.1. 多工具调用支持
系统现在支持在单次响应中返回多个工具调用:
```
LLM 输出:
<invoke>{"name": "get_weather", "arguments": {"location": "北京"}}</invoke>
<invoke>{"name": "get_weather", "arguments": {"location": "上海"}}</invoke>
解析为:
{
"tool_calls": [
{"function": {"name": "get_weather", ...}},
{"function": {"name": "get_weather", ...}}
]
}
```
### 8.2. 内容和工具调用混合
支持同时返回文本内容和工具调用:
```
LLM 输出:
好的,我来帮你查询。
<invoke>{"name": "search", "arguments": {"query": "..."}}</invoke>
请稍等片刻。
解析为:
{
"content": "好的,我来帮你查询。 请稍等片刻。",
"tool_calls": [...]
}
```
### 8.3. OpenAI 兼容性
- ✅ 完全兼容 OpenAI Chat Completions API 格式
- ✅ 支持流式和非流式响应
- ✅ 支持工具调用定义和执行
- ⚠️ 注意:虽然 OpenAI 的 GPT-4o 等模型通常只返回 `content``tool_calls` 中的一个,但本代理支持两者同时存在,以提供更大的灵活性并兼容不同的后端 LLM。
## 9. 未来升级方向 (Future Improvements)
- **支持多种 LLM 后端**: 修改调用函数,使其能根据请求参数或配置选择不同的 LLM 提供商。
- **更灵活的工具调用格式**: 支持除 XML 标签外的其他格式,例如纯 JSON 输出模式。
- **错误处理增强**: 针对不同的 LLM API 错误码和网络问题,提供更精细的错误反馈。
- **性能优化**: 添加缓存机制,减少重复请求的处理时间。
- **监控和日志**: 增强日志系统,添加性能监控和告警功能。
## 10. 更新日志 (Changelog)
### v1.1.0 (最新)
- ✨ 新增消息历史转换功能,支持 tool_calls 到 XML 格式的转换
- ✨ 优化响应解析器,支持多个 tool_calls 同时解析
- ✨ 支持内容和工具调用混合返回
- ✨ 添加完整的功能测试覆盖
- 🐛 修复流式工具调用解析的边界情况
### v1.0.0
- 🎉 初始版本
- ✨ 实现基本的工具调用代理功能
- ✨ OpenAI 兼容的 API 接口
- ✨ 流式和非流式响应支持

285
app/ghcproxy.py Normal file

@@ -0,0 +1,285 @@
import os
import json
import random
import time
from typing import Optional, Dict, Any
from datetime import datetime
import httpx
from fastapi import FastAPI, Request, HTTPException
from fastapi.responses import StreamingResponse, JSONResponse
import uvicorn
app = FastAPI()
TOKEN_EXPIRY_THRESHOLD = 60
GITHUB_TOKEN = "ghu_kpJkheogXW18PMY0Eu6D0sL4r5bDsD3aS3EA"  # NOTE: a hard-coded token is a security risk
GITHUB_API_URL = "https://api.github.com/copilot_internal/v2/token"
cached_token: Optional[Dict[str, Any]] = None
def generate_uuid() -> str:
    template = 'xxxxxxxx-xxxx-4xxx-yxxx-xxxxxxxxxxxx'
    return ''.join(
        random.choice('0123456789abcdef') if c == 'x'
        else random.choice('89ab') if c == 'y'
        else c  # keep the literal '-' and '4' characters of the UUID4 template
        for c in template
    )
def is_token_valid(token_data: Optional[Dict[str, Any]]) -> bool:
if not token_data or 'token' not in token_data or 'expires_at' not in token_data:
return False
current_time = int(time.time())
if current_time + TOKEN_EXPIRY_THRESHOLD >= token_data['expires_at']:
return False
return True
async def get_copilot_token() -> Dict[str, Any]:
global cached_token
if cached_token and is_token_valid(cached_token):
return cached_token
headers = {
"Authorization": f"Bearer {GITHUB_TOKEN}",
"Editor-Version": "JetBrains-IU/252.26830.84",
"Editor-Plugin-Version": "copilot-intellij/1.5.58-243",
"Copilot-Language-Server-Version": "1.382.0",
"X-Github-Api-Version": "2024-12-15",
"User-Agent": "GithubCopilot/1.382.0",
"Accept": "*/*",
}
async with httpx.AsyncClient() as client:
try:
response = await client.get(GITHUB_API_URL, headers=headers, timeout=10.0)
if response.status_code != 200:
if cached_token:
return cached_token
raise HTTPException(status_code=response.status_code, detail=f"Failed to get token: {response.text}")
data = response.json()
cached_token = data
expiry_ttl = data['expires_at'] - int(time.time()) - TOKEN_EXPIRY_THRESHOLD
if expiry_ttl > 0:
print(f"Token cached, will expire in {expiry_ttl} seconds")
else:
print("Warning: New token has short validity period")
return data
except httpx.RequestError as e:
if cached_token:
return cached_token
raise HTTPException(status_code=500, detail=f"Request failed: {str(e)}")
def get_headers_for_path(path: str) -> Dict[str, str]:
headers = {
"Editor-Version": "JetBrains-IU/252.26830.84",
"Editor-Plugin-Version": "copilot-intellij/1.5.58-243",
"Copilot-Language-Server-Version": "1.382.0",
"X-Github-Api-Version": "2025-05-01",
"Copilot-Integration-Id": "jetbrains-chat",
"User-Agent": "GithubCopilot/1.382.0",
}
if path == "/agents" or path == "/models":
return headers
elif path == "/chat/completions":
interaction_id = generate_uuid()
request_id = generate_uuid()
headers.update({
"X-Initiator": "user",
"X-Interaction-Id": interaction_id,
"X-Interaction-Type": "conversation-panel",
"Openai-Organization": "github-copilot",
"X-Request-Id": request_id,
"Vscode-Sessionid": "427689f2-5dad-4b50-95d9-7cca977450061761839746260",
"Vscode-Machineid": "c9421c6ac240db1c5bc5117218aa21a73f3762bda7db1702d003ec2df103b812",
"Openai-Intent": "conversation-panel",
"Copilot-Vision-Request": "true",
})
print(f"/chat/completions path matched, Interaction-Id: {interaction_id}, Request-Id: {request_id}")
return headers
def has_non_empty_content(msg):
    """Check whether a message's content is non-empty."""
    content = msg.get('content')
    if content is None:
        return False
    if isinstance(content, str):
        return bool(content.strip())  # strings count only after stripping whitespace
    if isinstance(content, (list, dict)):
        return bool(content)  # lists/dicts are non-empty if truthy
    # Other types (numbers, booleans, ...) are treated as non-empty
    return True
def filter_messages_logic(messages):
    """
    Optimized filtering logic:
    Find an assistant message A that has tool_calls, and the first
    role == 'tool' message B that follows it.
    Delete every message between A and B whose content is non-empty.
    """
    if not messages or len(messages) < 3:  # need at least assistant, something, tool
        return
    i = 0
    while i < len(messages):
        current_msg = messages[i]
        # Is this an assistant message with a non-empty tool_calls list?
        is_assistant_with_tool_calls = (
            current_msg.get("role") == "assistant" and
            isinstance(current_msg.get("tool_calls"), list) and
            len(current_msg["tool_calls"]) > 0
        )
        if is_assistant_with_tool_calls:
            # Starting from the next message, look for the first role == 'tool' message
            j = i + 1
            found_tool = False
            indices_to_remove_between = []
            while j < len(messages):
                msg_to_check = messages[j]
                if msg_to_check.get("role") == "tool":
                    found_tool = True
                    break  # stop at the first tool message; the in-between ones get removed
                # Does this message (between assistant and tool) have non-empty content?
                if has_non_empty_content(msg_to_check):
                    indices_to_remove_between.append(j)
                j += 1
            if found_tool and indices_to_remove_between:
                # Delete back-to-front so earlier indices stay valid
                for idx in sorted(indices_to_remove_between, reverse=True):
                    removed_msg = messages.pop(idx)
                    print(f"Removed intermediate message with non-empty content at index {idx}: {removed_msg}")
                # The list shrank, so re-check the current position on the next
                # iteration; do not advance i here.
                continue
            else:
                # Either no paired tool message was found, or there was nothing
                # to delete in between; move on to the next message.
                i += 1
        else:
            # Current message does not match; check the next one
            i += 1
@app.api_route("/v1/{path:path}", methods=["GET", "POST", "PUT", "DELETE", "PATCH"])
async def proxy_request(request: Request, path: str):
# Create a timestamped directory for request/response logs
timestamp = datetime.now().strftime("%Y%m%d_%H%M%S_%f")
log_dir = os.path.join("logs", timestamp)
os.makedirs(log_dir, exist_ok=True)
# Record the original request body
original_body = await request.body()
with open(os.path.join(log_dir, "original_request.txt"), "wb") as f:
f.write(original_body or b"")
token_data = await get_copilot_token()
print(token_data)
headers = get_headers_for_path(f"/{path}")
headers["Authorization"] = f"Bearer {token_data['token']}"
headers["hello"] = "world"
body = original_body
# Filter messages (optimized logic)
if body:
try:
body_data = json.loads(body.decode('utf-8') if isinstance(body, bytes) else body)
if "messages" in body_data and isinstance(body_data["messages"], list):
messages = body_data["messages"]
initial_len = len(messages)
print(f"Processing messages, initial count: {initial_len}")
filter_messages_logic(messages)
final_len = len(messages)
if initial_len != final_len:
body = json.dumps(body_data).encode('utf-8')
print(f"Messages filtered from {initial_len} to {final_len}.")
# Record the modified request body
with open(os.path.join(log_dir, "modified_request.txt"), "wb") as f:
f.write(body or b"")
else:
# If nothing was modified, still record the original body as modified_request
with open(os.path.join(log_dir, "modified_request.txt"), "wb") as f:
f.write(body or b"")
except json.JSONDecodeError:
# Body is not JSON; leave it unchanged
print("Request body is not valid JSON, skipping message filtering.")
with open(os.path.join(log_dir, "modified_request.txt"), "wb") as f:
f.write(body or b"")
# target_url = f"https://qwapi.oopsapi.com/v1/{path}"
target_url = "https://api.business.githubcopilot.com/" + path
print(target_url, " ", str(body))
# request_headers = {k: v for k, v in request.headers.items()
# if k.lower() not in ['host', 'content-length']}
# request_headers.update(headers)
async with httpx.AsyncClient() as client:
try:
response = await client.request(
method=request.method,
url=target_url,
headers=headers,
content=body if body else None,
timeout=120.0,
)
content = response.content
# Record the response
with open(os.path.join(log_dir, "response.txt"), "wb") as f:
f.write(content or b"")
print("content: ", content)
if response.headers.get("content-type", "").startswith("text/event-stream"):
return StreamingResponse(
response.aiter_bytes(),
status_code=response.status_code,
headers=dict(response.headers),
)
return JSONResponse(
content=json.loads(content) if content else {},
status_code=response.status_code,
headers={k: v for k, v in response.headers.items()
if k.lower() not in ['content-length', 'transfer-encoding']}
)
except httpx.RequestError as e:
    import traceback
    traceback.print_exc()
    raise HTTPException(status_code=500, detail=f"Proxy request failed: {str(e)}")
@app.get("/")
async def root():
return {"message": "GitHub Copilot Proxy API"}
if __name__ == "__main__":
uvicorn.run(app, host="0.0.0.0", port=8000)
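The filtering pass above can be exercised in isolation. Below is a condensed, self-contained re-sketch of `has_non_empty_content` plus `filter_messages_logic` (behavior-equivalent for the common cases; the proxy's logging and in-place mutation style are preserved, the print statements are dropped):

```python
def has_non_empty_content(msg: dict) -> bool:
    """True when the message carries meaningful content."""
    content = msg.get("content")
    if isinstance(content, str):
        return bool(content.strip())
    return bool(content)

def filter_messages(messages: list) -> list:
    """Drop non-empty-content messages that sit between an assistant
    message carrying tool_calls and its following role == "tool" message."""
    i = 0
    while i < len(messages):
        msg = messages[i]
        if msg.get("role") == "assistant" and msg.get("tool_calls"):
            j, to_remove = i + 1, []
            while j < len(messages) and messages[j].get("role") != "tool":
                if has_non_empty_content(messages[j]):
                    to_remove.append(j)
                j += 1
            if j < len(messages) and to_remove:  # a paired tool message exists
                for idx in reversed(to_remove):  # delete back-to-front
                    messages.pop(idx)
                continue                         # re-check the same position
        i += 1
    return messages  # mutated in place; returned for convenience
```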


@@ -8,7 +8,7 @@ from fastapi import FastAPI, HTTPException, Depends, Request
from starlette.responses import StreamingResponse
from .models import IncomingRequest, ProxyResponse
from .services import process_chat_request, stream_llm_api, inject_tools_into_prompt, parse_llm_response_from_content, _parse_sse_data
from .services import process_chat_request, stream_llm_api, inject_tools_into_prompt, parse_llm_response_from_content, _parse_sse_data, convert_tool_calls_to_content
from .core.config import get_settings, Settings
from .database import init_db, log_request, update_request_log
@@ -87,8 +87,13 @@ async def chat_completions(
raise HTTPException(status_code=500, detail="LLM API Key or URL is not configured.")
messages_to_llm = request_obj.messages
# Convert assistant messages with tool_calls to content format
messages_to_llm = convert_tool_calls_to_content(messages_to_llm)
logger.info(f"Converted tool calls to content format for log ID: {log_id}")
if request_obj.tools:
messages_to_llm = inject_tools_into_prompt(request_obj.messages, request_obj.tools)
messages_to_llm = inject_tools_into_prompt(messages_to_llm, request_obj.tools)
# Handle streaming request
if request_obj.stream:
@@ -100,10 +105,12 @@ async def chat_completions(
# First, collect all chunks to detect if there are tool calls
async for chunk in stream_llm_api(messages_to_llm, settings, log_id):
logger.info(f"sse_result: {chunk}")
raw_chunks.append(chunk)
# Extract content from SSE chunks
parsed = _parse_sse_data(chunk)
if parsed and parsed.get("type") != "done":
logger.info(f"sse_result_data: {parsed}")
if parsed and (parsed.get("type") != "done" or (parsed.get("choices") or [{}])[0].get("finish_reason") == "stop"):  # choices is a list; index the first entry
choices = parsed.get("choices")
if choices and len(choices) > 0:
delta = choices[0].get("delta")


@@ -60,10 +60,10 @@ class ResponseParser:
# Escape special regex characters in the tags
escaped_start = re.escape(self.tool_call_start_tag)
escaped_end = re.escape(self.tool_call_end_tag)
# Match from start tag to end tag (greedy), including both tags
# This ensures we capture the complete JSON object
# Use non-greedy matching to find all tool call occurrences
# This allows us to extract multiple tool calls from a single response
self._tool_call_pattern = re.compile(
f"{escaped_start}.*{escaped_end}",
f"{escaped_start}.*?{escaped_end}",
re.DOTALL
)
@@ -124,6 +124,8 @@ class ResponseParser:
This is the main entry point for parsing. It handles both:
1. Responses with tool calls (wrapped in tags)
2. Regular text responses
3. Multiple tool calls in a single response
4. Incomplete tool calls (missing closing tag) - fallback parsing
Args:
llm_response: The raw text response from the LLM
@@ -145,55 +147,129 @@ class ResponseParser:
return ResponseMessage(content=None)
try:
match = self._tool_call_pattern.search(llm_response)
# Find all tool call occurrences
matches = list(self._tool_call_pattern.finditer(llm_response))
if match:
return self._parse_tool_call_response(llm_response, match)
if matches:
return self._parse_tool_call_response(llm_response, matches)
else:
# Check for incomplete tool call (opening tag without closing tag)
if self.tool_call_start_tag in llm_response:
logger.warning("Detected incomplete tool call (missing closing tag). Attempting fallback parsing.")
return self._parse_incomplete_tool_call(llm_response)
return self._parse_text_only_response(llm_response)
except Exception as e:
logger.warning(f"Failed to parse LLM response: {e}. Returning as text.")
return ResponseMessage(content=llm_response)
def _parse_tool_call_response(self, llm_response: str, match: re.Match) -> ResponseMessage:
def _parse_tool_call_response(self, llm_response: str, matches: List[re.Match]) -> ResponseMessage:
"""
Parse a response that contains tool calls.
Args:
llm_response: The full LLM response
match: The regex match object containing the tool call
matches: List of regex match objects containing the tool calls
Returns:
ResponseMessage with content and tool_calls
"""
# The match includes start and end tags, so strip them
matched_text = match.group(0)
tool_call_str = matched_text[len(self.tool_call_start_tag):-len(self.tool_call_end_tag)]
tool_calls = []
last_end = 0 # Track the position of the last tool call
# Extract valid JSON by finding matching braces
json_str = self._extract_valid_json(tool_call_str)
if json_str is None:
# Fallback to trying to parse the entire string
json_str = tool_call_str
for match in matches:
# The match includes start and end tags, so strip them
matched_text = match.group(0)
tool_call_str = matched_text[len(self.tool_call_start_tag):-len(self.tool_call_end_tag)]
# Extract valid JSON by finding matching braces
json_str = self._extract_valid_json(tool_call_str)
if json_str is None:
# Fallback to trying to parse the entire string
json_str = tool_call_str
try:
tool_call_data = json.loads(json_str)
# Create the tool call object
tool_call = self._create_tool_call(tool_call_data)
tool_calls.append(tool_call)
except json.JSONDecodeError as e:
logger.warning(f"Failed to parse tool call JSON: {tool_call_str}. Error: {e}")
continue
# Update the last end position
last_end = match.end()
# Extract content before the first tool call tag
first_match_start = matches[0].start()
content_before = llm_response[:first_match_start].strip() if first_match_start > 0 else None
# Extract content between tool calls and after the last tool call
content_parts = []
if content_before:
content_parts.append(content_before)
# Check if there's content after the last tool call
content_after = llm_response[last_end:].strip() if last_end < len(llm_response) else None
if content_after:
content_parts.append(content_after)
# Combine all content parts
content = " ".join(content_parts) if content_parts else None
return ResponseMessage(
content=content,
tool_calls=tool_calls if tool_calls else None
)
def _parse_incomplete_tool_call(self, llm_response: str) -> ResponseMessage:
"""
Parse a response with an incomplete tool call (missing closing tag).
This is a fallback method when the LLM doesn't close the tag properly.
It attempts to extract the tool call JSON and complete it.
Args:
llm_response: The full LLM response with incomplete tool call
Returns:
ResponseMessage with content and optionally tool_calls
"""
try:
tool_call_data = json.loads(json_str)
# Find the opening tag
start_idx = llm_response.find(self.tool_call_start_tag)
if start_idx == -1:
return self._parse_text_only_response(llm_response)
# Extract content before the tool call tag
parts = llm_response.split(self.tool_call_start_tag, 1)
content = parts[0].strip() if parts[0] else None
# Extract content before the opening tag
content_before = llm_response[:start_idx].strip() if start_idx > 0 else None
# Create the tool call object
tool_call = self._create_tool_call(tool_call_data)
# Extract everything after the opening tag
after_tag = llm_response[start_idx + len(self.tool_call_start_tag):]
return ResponseMessage(
content=content,
tool_calls=[tool_call]
)
# Try to extract valid JSON
json_str = self._extract_valid_json(after_tag)
if json_str:
try:
tool_call_data = json.loads(json_str)
tool_call = self._create_tool_call(tool_call_data)
except json.JSONDecodeError as e:
raise ToolCallParseError(f"Invalid JSON in tool call: {tool_call_str}. Error: {e}")
logger.info(f"Successfully parsed incomplete tool call: {tool_call.function.name}")
return ResponseMessage(
content=content_before,
tool_calls=[tool_call]
)
except json.JSONDecodeError as e:
logger.warning(f"Failed to parse JSON from incomplete tool call: {e}")
# If all else fails, return as text
logger.warning("Could not salvage incomplete tool call, returning as text")
return ResponseMessage(content=llm_response)
except Exception as e:
logger.warning(f"Error in _parse_incomplete_tool_call: {e}")
return ResponseMessage(content=llm_response)
def _parse_text_only_response(self, llm_response: str) -> ResponseMessage:
"""


@@ -39,6 +39,70 @@ def _parse_sse_data(chunk: bytes) -> Optional[Dict[str, Any]]:
# --- End Helper ---
def convert_tool_calls_to_content(messages: List[ChatMessage]) -> List[ChatMessage]:
"""
Converts assistant messages with tool_calls into content format using XML tags.
This function processes the message history and converts any assistant messages
that have tool_calls into a format that LLMs can understand. The tool_calls
are converted to <invoke>...</invoke> tags in the content field.
Args:
messages: List of ChatMessage objects from the client
Returns:
Processed list of ChatMessage objects with tool_calls converted to content
Example:
Input: [{"role": "assistant", "tool_calls": [...]}]
Output: [{"role": "assistant", "content": "<invoke>{...}</invoke>"}]
"""
from .response_parser import TOOL_CALL_START_TAG, TOOL_CALL_END_TAG
processed_messages = []
for msg in messages:
# Check if this is an assistant message with tool_calls
if msg.role == "assistant" and msg.tool_calls and len(msg.tool_calls) > 0:
# Convert each tool call to XML tag format
tool_call_contents = []
for tc in msg.tool_calls:
tc_data = tc.get("function", {})
name = tc_data.get("name", "")
arguments_str = tc_data.get("arguments", "{}")
# Parse arguments JSON to ensure it's valid
try:
arguments = json.loads(arguments_str) if isinstance(arguments_str, str) else arguments_str
except json.JSONDecodeError:
arguments = {}
# Build the tool call JSON
tool_call_json = {"name": name, "arguments": arguments}
# Wrap in XML tags
tool_call_content = f'{TOOL_CALL_START_TAG}{json.dumps(tool_call_json, ensure_ascii=False)}{TOOL_CALL_END_TAG}'
tool_call_contents.append(tool_call_content)
# Create new message with tool calls in content
# Preserve original content if it exists
content_parts = []
if msg.content:
content_parts.append(msg.content)
content_parts.extend(tool_call_contents)
new_content = "\n".join(content_parts)
processed_messages.append(
ChatMessage(role=msg.role, content=new_content)
)
else:
# Keep other messages as-is
processed_messages.append(msg)
return processed_messages
def inject_tools_into_prompt(messages: List[ChatMessage], tools: List[Tool]) -> List[ChatMessage]:
"""
Injects a system prompt with tool definitions at the beginning of the message list.
@@ -53,17 +117,33 @@ def inject_tools_into_prompt(messages: List[ChatMessage], tools: List[Tool]) ->
tool_prompt = f"""
You are a helpful assistant with access to a set of tools.
You can call them by emitting a JSON object inside tool call tags.
IMPORTANT: Use the following format for tool calls:
Format: {TOOL_CALL_START_TAG}{{"name": "tool_name", "arguments": {{...}}}}{TOOL_CALL_END_TAG}
## TOOL CALL FORMAT (CRITICAL)
Example: {full_example}
When you need to use a tool, you MUST follow this EXACT format:
Here are the available tools:
{TOOL_CALL_START_TAG}{{"name": "tool_name", "arguments": {{...}}}}{TOOL_CALL_END_TAG}
### IMPORTANT RULES:
1. ALWAYS include BOTH the opening tag ({TOOL_CALL_START_TAG}) AND closing tag ({TOOL_CALL_END_TAG})
2. The JSON must be valid and properly formatted
3. Keep arguments concise to avoid truncation
4. Do not include any text between the tags except the JSON
### Examples:
Simple call:
{full_example}
Multiple arguments:
{TOOL_CALL_START_TAG}{{"name": "search", "arguments": {{"query": "example", "limit": 5}}}}{TOOL_CALL_END_TAG}
## AVAILABLE TOOLS:
{tool_defs}
Only use the tools if strictly necessary.
## REMEMBER:
- If you decide to call a tool, output ONLY the tool call tags (you may add brief text before or after)
- ALWAYS close your tags properly with {TOOL_CALL_END_TAG}
- Keep your arguments concise and essential
"""
# Prepend the system prompt with tool definitions
return [ChatMessage(role="system", content=tool_prompt)] + messages
@@ -86,7 +166,7 @@ async def _raw_stream_from_llm(messages: List[ChatMessage], settings: Settings,
Yields raw byte chunks as received.
"""
headers = { "Authorization": f"Bearer {settings.REAL_LLM_API_KEY}", "Content-Type": "application/json" }
payload = { "model": "default-model", "messages": [msg.model_dump() for msg in messages], "stream": True }
payload = { "model": "gpt-4.1", "messages": [msg.model_dump() for msg in messages], "stream": True }
# Log the request payload to the database
update_request_log(log_id, llm_request=payload)

18
docker-compose.yml Normal file

@@ -0,0 +1,18 @@
version: '3.8'
services:
sqlite-web:
image: docker.1ms.run/coleifer/sqlite-web
volumes:
- .:/data
environment:
SQLITE_DATABASE: llm_proxy.db
ports:
- 8580:8080
llmproxy:
build: .
ports:
- "8000:8000"
volumes:
- .:/app
restart: unless-stopped

File diff suppressed because it is too large


@@ -1,3 +1,4 @@
fastapi
uvicorn[standard]
httpx
python-dotenv
pydantic
requests


@@ -0,0 +1,155 @@
#!/usr/bin/env python3
"""
Test that the chat endpoint can return text content and tool_calls at the same time
"""
import sys
import os
import json
sys.path.insert(0, os.path.dirname(os.path.abspath(__file__)))
from app.response_parser import ResponseParser
from app.models import ResponseMessage, ToolCall, ToolCallFunction
def test_content_and_tool_calls():
"""Test the scenarios where text content and tool_calls are returned together"""
parser = ResponseParser()
print("=" * 70)
print("Test: returning text content and tool_calls together")
print("=" * 70)
# Scenario 1: text first + tool_calls
print("\nScenario 1: speak first, then call the tool")
print("-" * 70)
text1 = """好的,我来帮你查询北京的天气情况。
<invoke>{"name": "get_weather", "arguments": {"location": "北京", "unit": "celsius"}}</invoke>"""
result1 = parser.parse(text1)
print(f"Input text:\n{text1}\n")
print("Parse result:")
print(f" - content: {result1.content}")
print(f" - tool_calls: {len(result1.tool_calls) if result1.tool_calls else 0}")
if result1.tool_calls:
for tc in result1.tool_calls:
print(f" * {tc.function.name}: {tc.function.arguments}")
# Verify
assert result1.content is not None, "Content should not be None"
assert result1.tool_calls is not None, "Tool calls should not be None"
assert len(result1.tool_calls) == 1, "Should have 1 tool call"
assert "北京" in result1.content or "查询" in result1.content, "Content should contain original text"
print(" ✓ Scenario 1 passed")
# Scenario 2: tool_calls + text after
print("\nScenario 2: call the tool first, then speak")
print("-" * 70)
text2 = """<invoke>{"name": "search", "arguments": {"query": "今天天气"}}</invoke>
我已经帮你查询了,请稍等片刻。"""
result2 = parser.parse(text2)
print(f"Input text:\n{text2}\n")
print(f"Parsed result:")
print(f" - content: {result2.content}")
print(f" - tool_calls: {len(result2.tool_calls) if result2.tool_calls else 0}")
if result2.tool_calls:
for tc in result2.tool_calls:
print(f" * {tc.function.name}: {tc.function.arguments}")
assert result2.content is not None
assert result2.tool_calls is not None
assert "稍等" in result2.content or "查询" in result2.content
print(" ✓ Scenario 2 passed")
# Scenario 3: text - tool_calls - text
print("\nScenario 3: text - tool call - text (sandwich structure)")
print("-" * 70)
text3 = """让我先查一下北京的温度。
<invoke>{"name": "get_weather", "arguments": {"location": "北京"}}</invoke>
查到了,我再查一下上海的。
<invoke>{"name": "get_weather", "arguments": {"location": "上海"}}</invoke>
好了,两个城市都查询完毕。"""
result3 = parser.parse(text3)
print(f"Input text:\n{text3}\n")
print(f"Parsed result:")
print(f" - content: {result3.content}")
print(f" - tool_calls: {len(result3.tool_calls) if result3.tool_calls else 0}")
if result3.tool_calls:
for i, tc in enumerate(result3.tool_calls, 1):
print(f" * {tc.function.name}: {tc.function.arguments}")
assert result3.content is not None
assert result3.tool_calls is not None
assert len(result3.tool_calls) == 2
assert "先查一下" in result3.content
assert "查询完毕" in result3.content
print(" ✓ Scenario 3 passed")
# Scenario 4: ResponseMessage serialization
print("\nScenario 4: verify ResponseMessage serializes to JSON correctly")
print("-" * 70)
msg = ResponseMessage(
role="assistant",
content="好的,我来帮你查询。",
tool_calls=[
ToolCall(
id="call_123",
type="function",
function=ToolCallFunction(
name="get_weather",
arguments=json.dumps({"location": "北京"})
)
)
]
)
json_str = msg.model_dump_json(indent=2)
print("Serialized JSON response:")
print(json_str)
parsed_back = ResponseMessage.model_validate_json(json_str)
assert parsed_back.content == msg.content
assert parsed_back.tool_calls is not None
assert len(parsed_back.tool_calls) == 1
print(" ✓ Scenario 4 passed - JSON serialization/deserialization works")
print("\n" + "=" * 70)
print("All tests passed! ✓")
print("=" * 70)
print("\nSummary:")
print("✓ the chat endpoint can return text content and tool_calls at the same time")
print("✓ content and tool_calls can coexist")
print("✓ text before, after, or on both sides of the tool calls is supported")
print("✓ multiple tool_calls mixed with text content are supported")
print("✓ JSON serialization/deserialization works")
print("\nExample of a real-world flow:")
print("""
Assistant: "Sure, let me look that up for you."
[calls the get_weather tool]
[receives the tool result]
Assistant: "Beijing is sunny today, 25°C."
""")
if __name__ == "__main__":
test_content_and_tool_calls()

test_multiple_tool_calls.py Normal file

@@ -0,0 +1,154 @@
#!/usr/bin/env python3
"""
End-to-end test of multiple tool_calls
"""
import sys
import os
sys.path.insert(0, os.path.dirname(os.path.abspath(__file__)))
from app.services import convert_tool_calls_to_content
from app.response_parser import ResponseParser
from app.models import ChatMessage
def test_multiple_tool_calls():
"""End-to-end flow with multiple tool_calls"""
print("=" * 60)
print("Test scenario: message history contains multiple tool_calls")
print("=" * 60)
# Simulate a conversation:
# the user asks for the weather in Beijing and Shanghai; the assistant made two tool calls
messages = [
ChatMessage(
role="user",
content="帮我查一下北京和上海的天气"
),
ChatMessage(
role="assistant",
tool_calls=[
{
"id": "call_1",
"type": "function",
"function": {
"name": "get_weather",
"arguments": '{"location": "北京", "unit": "celsius"}'
}
},
{
"id": "call_2",
"type": "function",
"function": {
"name": "get_weather",
"arguments": '{"location": "上海", "unit": "celsius"}'
}
}
]
),
ChatMessage(
role="user",
content="结果怎么样?"
)
]
print("\n1. Original messages:")
for i, msg in enumerate(messages):
print(f" Message {i+1}: {msg.role}")
if msg.content:
print(f" Content: {msg.content}")
if msg.tool_calls:
print(f" Tool calls: {len(msg.tool_calls)}")
for j, tc in enumerate(msg.tool_calls):
print(f" {j+1}. {tc['function']['name']}")
# Convert tool_calls into content
print("\n2. Converted messages (sent to the LLM):")
converted = convert_tool_calls_to_content(messages)
for i, msg in enumerate(converted):
print(f" Message {i+1}: {msg.role}")
if msg.content:
# show only the first 150 characters
content_preview = msg.content[:150] + "..." if len(msg.content) > 150 else msg.content
print(f" Content: {content_preview}")
# Verify the conversion
assert "<invoke>" in converted[1].content
assert converted[1].content.count("<invoke>") == 2
print("\n ✓ Conversion succeeded! Both tool_calls were converted into the XML tag format")
# Simulate a fresh LLM response (again containing multiple tool_calls)
print("\n3. Simulated LLM response (containing multiple tool_calls):")
llm_response = '''好的,我来帮你查一下其他城市的天气。
<invoke>{"name": "get_weather", "arguments": {"location": "广州"}}</invoke>
<invoke>{"name": "get_weather", "arguments": {"location": "深圳"}}</invoke>
请稍等。'''
print(f" {llm_response}")
# Parse the LLM response
print("\n4. Parse the LLM response:")
parser = ResponseParser()
parsed = parser.parse(llm_response)
print(f" Content: {parsed.content}")
print(f" Tool call count: {len(parsed.tool_calls) if parsed.tool_calls else 0}")
if parsed.tool_calls:
import json  # hoisted above the loop so it runs only once
for i, tc in enumerate(parsed.tool_calls):
args = json.loads(tc.function.arguments)
print(f" {i+1}. {tc.function.name}(location={args['location']})")
# Verify the parsing
assert parsed.tool_calls is not None
assert len(parsed.tool_calls) == 2
assert parsed.tool_calls[0].function.name == "get_weather"
assert parsed.tool_calls[1].function.name == "get_weather"
print("\n ✓ Parsing succeeded! Both tool_calls were extracted correctly")
# Scenario 2: a single tool_call (backward compatibility)
print("\n" + "=" * 60)
print("Test scenario: single tool_call (backward compatibility)")
print("=" * 60)
single_response = '''我来帮你查询。
<invoke>{"name": "search", "arguments": {"query": "今天天气"}}</invoke>'''
parsed_single = parser.parse(single_response)
print(f"Content: {parsed_single.content}")
print(f"Tool call count: {len(parsed_single.tool_calls) if parsed_single.tool_calls else 0}")
assert parsed_single.tool_calls is not None
assert len(parsed_single.tool_calls) == 1
assert parsed_single.tool_calls[0].function.name == "search"
print("✓ Single tool_call parsed correctly")
# Scenario 3: no tool_call
print("\n" + "=" * 60)
print("Test scenario: no tool_call")
print("=" * 60)
no_tool_response = "你好!有什么可以帮助你的吗?"
parsed_no_tool = parser.parse(no_tool_response)
print(f"Content: {parsed_no_tool.content}")
print(f"Tool calls: {parsed_no_tool.tool_calls}")
assert parsed_no_tool.content == no_tool_response
assert parsed_no_tool.tool_calls is None
print("✓ Plain-text response parsed correctly")
print("\n" + "=" * 60)
print("All tests passed! ✓")
print("=" * 60)
print("\nSummary:")
print("- multiple tool_calls in the message history are correctly converted to XML format")
print("- multiple tool_calls in an LLM response are correctly parsed")
print("- backward compatible with a single tool_call and plain-text responses")
if __name__ == "__main__":
test_multiple_tool_calls()


@@ -0,0 +1,193 @@
#!/usr/bin/env python3
"""
Test converting tool_calls in the message history into content
"""
import sys
import os
# add the project root to sys.path
sys.path.insert(0, os.path.dirname(os.path.abspath(__file__)))
from app.services import convert_tool_calls_to_content
from app.models import ChatMessage
def test_convert_tool_calls_to_content():
"""Test the tool_calls-to-content conversion"""
# Test case 1: assistant message with tool_calls
print("=" * 60)
print("Test case 1: assistant message with tool_calls")
print("=" * 60)
messages = [
ChatMessage(
role="user",
content="帮我查询一下天气"
),
ChatMessage(
role="assistant",
tool_calls=[
{
"id": "call_123",
"type": "function",
"function": {
"name": "get_weather",
"arguments": '{"location": "北京", "unit": "celsius"}'
}
}
]
),
ChatMessage(
role="user",
content="那上海呢?"
)
]
print("\nOriginal messages:")
for i, msg in enumerate(messages):
print(f" Message {i+1}:")
print(f" Role: {msg.role}")
if msg.content:
print(f" Content: {msg.content}")
if msg.tool_calls:
print(f" Tool calls: {len(msg.tool_calls)}")
# Convert
converted = convert_tool_calls_to_content(messages)
print("\nConverted messages:")
for i, msg in enumerate(converted):
print(f" Message {i+1}:")
print(f" Role: {msg.role}")
if msg.content:
print(f" Content: {msg.content[:100]}...")  # show only the first 100 characters
# verify the second message was converted correctly
assert converted[1].role == "assistant"
assert "<invoke>" in converted[1].content
assert "get_weather" in converted[1].content
assert "北京" in converted[1].content
assert converted[1].tool_calls is None  # tool_calls should be removed
print("\n✓ Test case 1 passed!")
# Test case 2: assistant message with both content and tool_calls
print("\n" + "=" * 60)
print("Test case 2: assistant message with both content and tool_calls")
print("=" * 60)
messages2 = [
ChatMessage(
role="assistant",
content="好的,让我帮你查询天气。",
tool_calls=[
{
"id": "call_456",
"type": "function",
"function": {
"name": "search",
"arguments": '{"query": "今天天气"}'
}
}
]
)
]
print("\nOriginal message:")
print(f" Role: {messages2[0].role}")
print(f" Content: {messages2[0].content}")
print(f" Tool calls: {messages2[0].tool_calls}")
converted2 = convert_tool_calls_to_content(messages2)
print("\nConverted message:")
print(f" Role: {converted2[0].role}")
print(f" Content: {converted2[0].content}")
# Verify
assert "好的,让我帮你查询天气。" in converted2[0].content
assert "<invoke>" in converted2[0].content
assert "search" in converted2[0].content
print("\n✓ Test case 2 passed!")
# Test case 3: multiple tool_calls
print("\n" + "=" * 60)
print("Test case 3: multiple tool_calls")
print("=" * 60)
messages3 = [
ChatMessage(
role="assistant",
tool_calls=[
{
"id": "call_1",
"type": "function",
"function": {
"name": "get_weather",
"arguments": '{"location": "北京"}'
}
},
{
"id": "call_2",
"type": "function",
"function": {
"name": "get_weather",
"arguments": '{"location": "上海"}'
}
}
]
)
]
print("\nOriginal message:")
print(f" Role: {messages3[0].role}")
print(f" Tool call count: {len(messages3[0].tool_calls)}")
converted3 = convert_tool_calls_to_content(messages3)
print("\nConverted message:")
print(f" Content: {converted3[0].content}")
# verify both tool calls were converted
assert converted3[0].content.count("<invoke>") == 2
assert "北京" in converted3[0].content
assert "上海" in converted3[0].content
print("\n✓ Test case 3 passed!")
# Test case 4: messages without tool_calls (should stay unchanged)
print("\n" + "=" * 60)
print("Test case 4: messages without tool_calls")
print("=" * 60)
messages4 = [
ChatMessage(role="user", content="你好"),
ChatMessage(role="assistant", content="你好,有什么可以帮助你的吗?"),
ChatMessage(role="user", content="再见")
]
print("\nOriginal messages:")
for i, msg in enumerate(messages4):
print(f" Message {i+1}: {msg.role} - {msg.content}")
converted4 = convert_tool_calls_to_content(messages4)
print("\nConverted messages:")
for i, msg in enumerate(converted4):
print(f" Message {i+1}: {msg.role} - {msg.content}")
# verify the messages are unchanged
assert len(converted4) == len(messages4)
assert converted4[0].content == "你好"
assert converted4[1].content == "你好,有什么可以帮助你的吗?"
assert converted4[2].content == "再见"
print("\n✓ Test case 4 passed!")
print("\n" + "=" * 60)
print("All test cases passed! ✓")
print("=" * 60)
if __name__ == "__main__":
test_convert_tool_calls_to_content()
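
The implementations these tests exercise (`convert_tool_calls_to_content` in app/services.py and the `ResponseParser` in app/response_parser.py) are not shown in this comparison. The following is a rough, hypothetical sketch of the behavior the tests assert; the names `tool_calls_to_xml`, `parse_invokes`, and `INVOKE_RE` are illustrative only and not the project's actual API:

```python
import json
import re

# Non-greedy match so each <invoke>...</invoke> pair yields one call,
# which is what lets multiple tool_calls in one reply be parsed separately.
INVOKE_RE = re.compile(r"<invoke>(.*?)</invoke>", re.DOTALL)

def tool_calls_to_xml(tool_calls):
    """Render OpenAI-style tool_calls dicts as <invoke> tags the LLM can read back."""
    tags = []
    for tc in tool_calls:
        fn = tc["function"]
        tags.append("<invoke>" + json.dumps(
            {"name": fn["name"], "arguments": json.loads(fn["arguments"])},
            ensure_ascii=False,
        ) + "</invoke>")
    return "\n".join(tags)

def parse_invokes(text):
    """Split an LLM reply into plain content plus a list of parsed calls."""
    calls = [json.loads(body) for body in INVOKE_RE.findall(text)]
    content = INVOKE_RE.sub("", text).strip() or None
    return content, calls

content, calls = parse_invokes(
    'Checking now.\n'
    '<invoke>{"name": "get_weather", "arguments": {"location": "Beijing"}}</invoke>\n'
    '<invoke>{"name": "get_weather", "arguments": {"location": "Shanghai"}}</invoke>'
)
print(content)                        # Checking now.
print([c["name"] for c in calls])     # ['get_weather', 'get_weather']
```

A greedy pattern (`<invoke>(.*)</invoke>`) would swallow everything between the first opening and the last closing tag, which matches the "非贪婪匹配" (non-greedy matching) fix described in commit 5c2904e010.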