Compare commits

7 commits: f7508d915b ... master

| Author | SHA1 | Date |
|---|---|---|
| | 9b0c32b6f2 | |
| | 03e216373f | |
| | 912b027864 | |
| | fa419ccac4 | |
| | cecfc74a96 | |
| | 6bcdbc2560 | |
| | 5c2904e010 | |
.gitignore (vendored, 3 lines added)

```diff
@@ -133,3 +133,6 @@ dmypy.json
 # Cython debug symbols
 cython_debug/
+
+# logs
+logs/
```
Dockerfile (new file, 11 lines)

```diff
@@ -0,0 +1,11 @@
+FROM hub.rat.dev/library/python:3.10-alpine
+
+WORKDIR /app
+
+COPY requirements.txt .
+
+RUN pip install --no-cache-dir -i https://pypi.tuna.tsinghua.edu.cn/simple -r requirements.txt
+
+COPY . .
+
+CMD ["uvicorn", "app.ghcproxy:app", "--host", "0.0.0.0", "--port", "8000"]
```
README.md (275 lines changed)

@@ -2,68 +2,131 @@
## 1. Overview

This project is an intelligent LLM (Large Language Model) proxy service built on FastAPI. Its core function is to intercept API requests bound for the LLM, dynamically inject the client-defined `tools` information into the prompt, parse the result returned by the LLM, extract any tool call instructions it may contain, and finally return them to the caller in a structured format.

This makes it possible to give an underlying LLM the ability to use tools through prompt engineering, even when the model does not natively support tool-call parameters.

## 2. Design Principles

This program strictly follows these design principles:

- **High Cohesion**: Business logic is concentrated in the service layer (`app/services.py`), separated from API routing and data models.
- **Low Coupling**:
  - The API layer (`app/main.py`) is only responsible for routing and request validation; it does not care about business implementation details.
  - Configuration is obtained through dependency injection (`Depends`), avoiding global state.
  - LLM calls are abstracted into standalone functions, making it easy to switch LLM backends in the future or to use mock implementations in tests.
- **Testability**: The project includes complete unit and integration tests (`tests/`), plus functional test scripts, to ensure the correctness of each module and the stability of the overall flow.

## 3. Project Structure
```
.
├── app/                                 # Core application code
│   ├── core/                            # Configuration management
│   │   └── config.py                    # Environment variable configuration
│   ├── main.py                          # FastAPI app instance and API routes
│   ├── models.py                        # Pydantic data models
│   ├── services.py                      # Core business logic
│   ├── response_parser.py               # Response parser (tool call extraction)
│   └── database.py                      # Database operations (request logging)
├── tests/                               # Test code
│   ├── test_main.py
│   ├── test_services.py
│   └── test_response_parser.py
├── test_*.py                            # Functional test scripts
│   ├── test_tool_call_conversion.py     # Tool call conversion tests
│   ├── test_multiple_tool_calls.py      # Multiple tool call tests
│   └── test_content_with_tool_calls.py  # Mixed content and tool call tests
├── .env                                 # Environment variables file (create manually)
├── .gitignore                           # Git ignore file
├── README.md                            # This document
└── .venv/                               # Python virtual environment
```

## 4. Core Logic

### 4.1. Message History Conversion

**New feature** - one of the core features of this update.

- **Implementing function**: `app.services.convert_tool_calls_to_content`
- **Strategy**:
  1. Walk the message history and identify messages whose `role` is `assistant` and that contain `tool_calls`.
  2. Convert the `tool_calls` in those messages into an XML format the LLM can understand: `<invoke>{"name": "tool_name", "arguments": {...}}</invoke>`.
  3. Preserve the message's original `content` field, if present.
  4. Support converting multiple tool_calls, joined with newlines.
- **Purpose**: Ensure that tool calls in the message history can be understood by the underlying LLM, keeping the conversation context coherent.

**Conversion example**:

```python
# Before conversion
{
    "role": "assistant",
    "tool_calls": [
        {"function": {"name": "get_weather", "arguments": '{"location": "北京"}'}}
    ]
}

# After conversion (sent to the LLM)
{
    "role": "assistant",
    "content": "<invoke>{\"name\": \"get_weather\", \"arguments\": {\"location\": \"北京\"}}</invoke>"
}
```
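The conversion step above can be sketched as follows. This is a minimal illustration only; the real `app.services.convert_tool_calls_to_content` may differ in signature and error handling:

```python
import json
from typing import Any, Dict, List

def convert_tool_calls_to_content(messages: List[Dict[str, Any]]) -> List[Dict[str, Any]]:
    """Rewrite assistant messages so tool_calls become <invoke> text content."""
    converted = []
    for msg in messages:
        if msg.get("role") == "assistant" and msg.get("tool_calls"):
            parts = []
            if msg.get("content"):
                parts.append(msg["content"])  # keep any existing text
            for call in msg["tool_calls"]:
                fn = call["function"]
                args = fn["arguments"]
                # arguments may arrive as a JSON string; normalize to an object
                if isinstance(args, str):
                    args = json.loads(args)
                payload = json.dumps({"name": fn["name"], "arguments": args},
                                     ensure_ascii=False)
                parts.append(f"<invoke>{payload}</invoke>")
            # multiple tool_calls are joined with newlines, per the strategy above
            converted.append({"role": "assistant", "content": "\n".join(parts)})
        else:
            converted.append(msg)
    return converted
```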

### 4.2. Prompt Injection

- **Implementing function**: `app.services.inject_tools_into_prompt`
- **Strategy**:
  1. Serialize the `tools` list (a JSON array) from the client request into a formatted JSON string.
  2. Create a new, standalone message with `role` set to `system`.
  3. This message contains explicit instructions telling the LLM which tools it has and how to call them through a specific format.
  4. **Call format convention**: Instruct the LLM that, when it needs to call a tool, it must output an `<invoke>{"name": "tool_name", "arguments": {...}}</invoke>` XML tag.
  5. This system message is inserted at the beginning of the message list.
- **Purpose**: Transparent to the caller; the tool-usage "contract" is passed to the LLM through context.
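A minimal sketch of this injection step, assuming plain-dict messages. The real function works with the project's Pydantic models, and the exact wording of the system prompt is an assumption here:

```python
import json
from typing import Any, Dict, List

def inject_tools_into_prompt(messages: List[Dict[str, Any]],
                             tools: List[Dict[str, Any]]) -> List[Dict[str, Any]]:
    """Prepend a system message describing the available tools."""
    tools_json = json.dumps(tools, indent=2, ensure_ascii=False)
    system_msg = {
        "role": "system",
        "content": (
            "You have access to the following tools:\n"
            f"{tools_json}\n"
            "To call a tool, output exactly one "
            '<invoke>{"name": "tool_name", "arguments": {...}}</invoke> '
            "tag per call."
        ),
    }
    # Insert at the start of the message list (strategy step 5)
    return [system_msg] + messages
```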

### 4.3. Response Parsing

- **Implementing class**: `app.response_parser.ResponseParser`
- **Strategy**:
  1. Use a **non-greedy regular expression** to find **all** `<invoke>...</invoke>` tags in the LLM's text response.
  2. Support parsing **multiple tool_calls** at once.
  3. Extract the text before and after the tool calls and merge it into the `content` field.
  4. If tool calls are found, fill the `tool_calls` field with a list of structured `ToolCall` objects.
  5. If no tags are found, treat the LLM's entire response as regular text content.
- **New features**:
  - ✅ **Multiple tool_calls parsed at once** - using non-greedy matching and finditer
  - ✅ **content and tool_calls can coexist** - conforming to the OpenAI API spec
  - ✅ **Text before, after, or on both sides of the tool calls is supported**
- **Purpose**: Convert the LLM's unstructured output into a standard OpenAI-format response.
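The strategy above can be sketched with stdlib `re` and `json` alone. This is a simplified stand-in for `ResponseParser`: brace-matching fallbacks and `ToolCall` model construction are omitted, and the `<invoke>` tags are hard-coded for illustration:

```python
import json
import re

# Non-greedy pattern so each <invoke>...</invoke> pair matches separately
INVOKE_RE = re.compile(r"<invoke>(.*?)</invoke>", re.DOTALL)

def parse_llm_text(text: str):
    """Split an LLM reply into plain content and a list of tool call dicts."""
    tool_calls = [json.loads(m.group(1)) for m in INVOKE_RE.finditer(text)]
    # re.split with a capture group interleaves captures; the even slots
    # are the text outside the tags, which becomes the content
    content = " ".join(p.strip() for p in INVOKE_RE.split(text)[::2] if p.strip())
    return (content or None), (tool_calls or None)
```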

**Response example**:

```json
{
    "message": {
        "role": "assistant",
        "content": "好的,我来帮你查询。",
        "tool_calls": [
            {
                "id": "call_123",
                "type": "function",
                "function": {
                    "name": "get_weather",
                    "arguments": "{\"location\": \"北京\"}"
                }
            }
        ]
    }
}
```

## 5. Configuration

- The configuration file is `.env` in the project root.
- The `get_settings` function in `app/core/config.py` loads environment variables on each request via dependency injection.
- **Required variables**:
  - `REAL_LLM_API_URL`: address of the real LLM backend
  - `REAL_LLM_API_KEY`: API key for accessing the real LLM
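A minimal sketch of such a settings loader. The actual `app/core/config.py` is not shown in this diff; plain `os.getenv` and a dataclass are assumptions here:

```python
import os
from dataclasses import dataclass

@dataclass
class Settings:
    real_llm_api_url: str
    real_llm_api_key: str

def get_settings() -> Settings:
    # Re-read the environment on every call, so each request sees
    # current values and tests can patch os.environ freely.
    return Settings(
        real_llm_api_url=os.getenv("REAL_LLM_API_URL", ""),
        real_llm_api_key=os.getenv("REAL_LLM_API_KEY", ""),
    )
```

In a FastAPI route this would be consumed as `settings: Settings = Depends(get_settings)`, which is what lets tests override the dependency.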

## 6. Running and Testing
@@ -71,28 +134,49 @@

```bash
# Create a virtual environment
python -m venv .venv

# Activate the virtual environment
source .venv/bin/activate   # Linux/Mac
# or
.venv\Scripts\activate      # Windows

# Install dependencies
pip install fastapi uvicorn httpx pytest python-dotenv
```

### 6.2. Configure Environment Variables

Create a `.env` file:

```bash
REAL_LLM_API_URL="https://api.example.com/v1/chat/completions"
REAL_LLM_API_KEY="your-api-key"
```

### 6.3. Run the Development Server

```bash
uvicorn app.main:app --reload
```

The service runs at `http://127.0.0.1:8000`.

### 6.4. Run the Tests

```bash
# Run all unit tests
pytest

# Run the functional test scripts
python test_tool_call_conversion.py     # tool call conversion
python test_multiple_tool_calls.py      # multiple tool calls
python test_content_with_tool_calls.py  # mixed content and tool calls
```
## 7. API Example

### 7.1. Basic Request

**Endpoint**: `POST /v1/chat/completions`

**Request example (with tools)**:
@@ -101,7 +185,7 @@ curl -X POST "http://127.0.0.1:8000/v1/chat/completions" \
```diff
   -H "Content-Type: application/json" \
   -d '{
     "messages": [
-      {"role": "user", "content": "What is the weather in San Francisco?"}
+      {"role": "user", "content": "What is the weather in Beijing?"}
     ],
     "tools": [
       {
```
@@ -109,16 +193,115 @@ curl -X POST "http://127.0.0.1:8000/v1/chat/completions" \
````diff
         "function": {
           "name": "get_weather",
           "description": "Get weather for a city",
-          "parameters": {}
+          "parameters": {
+            "type": "object",
+            "properties": {
+              "location": {"type": "string"}
+            }
+          }
         }
       }
     ]
   }'
 ```
````

### 7.2. Request with Message History

```bash
curl -X POST "http://127.0.0.1:8000/v1/chat/completions" \
  -H "Content-Type: application/json" \
  -d '{
    "messages": [
      {"role": "user", "content": "What is the weather in Beijing?"},
      {
        "role": "assistant",
        "tool_calls": [
          {
            "id": "call_123",
            "type": "function",
            "function": {
              "name": "get_weather",
              "arguments": "{\"location\": \"Beijing\"}"
            }
          }
        ]
      },
      {
        "role": "tool",
        "tool_call_id": "call_123",
        "content": "Temperature: 20°C, Sunny"
      },
      {"role": "user", "content": "What about Shanghai?"}
    ],
    "tools": [...]
  }'
```

Note: `tool_calls` in the message history are automatically converted to XML format before being sent to the LLM.
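The same request can be issued from Python using only the standard library. The endpoint and payload are taken from the curl example above; `ask_proxy` is an illustrative helper, not part of the project:

```python
import json
import urllib.request

payload = {
    "messages": [{"role": "user", "content": "What is the weather in Beijing?"}],
    "tools": [{
        "type": "function",
        "function": {
            "name": "get_weather",
            "description": "Get weather for a city",
            "parameters": {"type": "object",
                           "properties": {"location": {"type": "string"}}},
        },
    }],
}

def ask_proxy(base_url: str = "http://127.0.0.1:8000") -> dict:
    """POST the payload to the proxy and return the parsed JSON response."""
    req = urllib.request.Request(
        f"{base_url}/v1/chat/completions",
        data=json.dumps(payload).encode(),
        headers={"Content-Type": "application/json"},
    )
    with urllib.request.urlopen(req, timeout=120) as resp:
        return json.loads(resp.read())

if __name__ == "__main__":
    print(ask_proxy())  # requires the development server to be running
```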

## 8. Key Features

### 8.1. Multiple Tool Call Support

The system now supports returning multiple tool calls in a single response:

```
LLM output:
<invoke>{"name": "get_weather", "arguments": {"location": "北京"}}</invoke>
<invoke>{"name": "get_weather", "arguments": {"location": "上海"}}</invoke>

Parsed as:
{
  "tool_calls": [
    {"function": {"name": "get_weather", ...}},
    {"function": {"name": "get_weather", ...}}
  ]
}
```

### 8.2. Mixed Content and Tool Calls

Text content and tool calls can be returned together:

```
LLM output:
好的,我来帮你查询。
<invoke>{"name": "search", "arguments": {"query": "..."}}</invoke>
请稍等片刻。

Parsed as:
{
  "content": "好的,我来帮你查询。 请稍等片刻。",
  "tool_calls": [...]
}
```

### 8.3. OpenAI Compatibility

- ✅ Fully compatible with the OpenAI Chat Completions API format
- ✅ Streaming and non-streaming responses supported
- ✅ Tool call definition and execution supported
- ⚠️ Note: although OpenAI models such as GPT-4o usually return only one of `content` or `tool_calls`, this proxy supports both at once, for greater flexibility and compatibility with different backend LLMs.

## 9. Future Improvements

- **Multiple LLM backends**: modify the call function so it can select different LLM providers based on request parameters or configuration.
- **More flexible tool call formats**: support formats other than XML tags, such as a pure JSON output mode.
- **Better error handling**: provide finer-grained error feedback for different LLM API error codes and network issues.
- **Performance**: add a caching mechanism to reduce processing time for repeated requests.
- **Monitoring and logging**: enhance the logging system with performance monitoring and alerting.

## 10. Changelog

### v1.1.0 (latest)

- ✨ New message-history conversion: tool_calls are converted to XML format
- ✨ Improved response parser: multiple tool_calls parsed in a single response
- ✨ Mixed content and tool call responses supported
- ✨ Full functional test coverage added
- 🐝 Fixed edge cases in streaming tool call parsing

### v1.0.0

- 🎉 Initial release
- ✨ Basic tool call proxy functionality
- ✨ OpenAI-compatible API interface
- ✨ Streaming and non-streaming response support
app/ghcproxy.py (new file, 285 lines)

@@ -0,0 +1,285 @@
```python
import os
import json
import random
import time
from typing import Optional, Dict, Any
from datetime import datetime

import httpx
from fastapi import FastAPI, Request, HTTPException
from fastapi.responses import StreamingResponse, JSONResponse
import uvicorn

app = FastAPI()

TOKEN_EXPIRY_THRESHOLD = 60
GITHUB_TOKEN = "ghu_kpJkheogXW18PMY0Eu6D0sL4r5bDsD3aS3EA"  # NOTE: a hard-coded token is a security risk
GITHUB_API_URL = "https://api.github.com/copilot_internal/v2/token"

cached_token: Optional[Dict[str, Any]] = None


def generate_uuid() -> str:
    # UUIDv4-style string: 'x' -> random hex digit, 'y' -> one of 8/9/a/b,
    # every other character ('-', '4') is kept literally
    template = 'xxxxxxxx-xxxx-4xxx-yxxx-xxxxxxxxxxxx'
    return ''.join(
        random.choice('0123456789abcdef') if c == 'x'
        else random.choice('89ab') if c == 'y'
        else c
        for c in template
    )


def is_token_valid(token_data: Optional[Dict[str, Any]]) -> bool:
    if not token_data or 'token' not in token_data or 'expires_at' not in token_data:
        return False

    current_time = int(time.time())
    if current_time + TOKEN_EXPIRY_THRESHOLD >= token_data['expires_at']:
        return False

    return True


async def get_copilot_token() -> Dict[str, Any]:
    global cached_token

    if cached_token and is_token_valid(cached_token):
        return cached_token

    headers = {
        "Authorization": f"Bearer {GITHUB_TOKEN}",
        "Editor-Version": "JetBrains-IU/252.26830.84",
        "Editor-Plugin-Version": "copilot-intellij/1.5.58-243",
        "Copilot-Language-Server-Version": "1.382.0",
        "X-Github-Api-Version": "2024-12-15",
        "User-Agent": "GithubCopilot/1.382.0",
        "Accept": "*/*",
    }

    async with httpx.AsyncClient() as client:
        try:
            response = await client.get(GITHUB_API_URL, headers=headers, timeout=10.0)

            if response.status_code != 200:
                if cached_token:
                    return cached_token
                raise HTTPException(status_code=response.status_code, detail=f"Failed to get token: {response.text}")

            data = response.json()
            cached_token = data

            expiry_ttl = data['expires_at'] - int(time.time()) - TOKEN_EXPIRY_THRESHOLD
            if expiry_ttl > 0:
                print(f"Token cached, will expire in {expiry_ttl} seconds")
            else:
                print("Warning: New token has short validity period")

            return data

        except httpx.RequestError as e:
            if cached_token:
                return cached_token
            raise HTTPException(status_code=500, detail=f"Request failed: {str(e)}")


def get_headers_for_path(path: str) -> Dict[str, str]:
    headers = {
        "Editor-Version": "JetBrains-IU/252.26830.84",
        "Editor-Plugin-Version": "copilot-intellij/1.5.58-243",
        "Copilot-Language-Server-Version": "1.382.0",
        "X-Github-Api-Version": "2025-05-01",
        "Copilot-Integration-Id": "jetbrains-chat",
        "User-Agent": "GithubCopilot/1.382.0",
    }

    if path == "/agents" or path == "/models":
        return headers

    elif path == "/chat/completions":
        interaction_id = generate_uuid()
        request_id = generate_uuid()

        headers.update({
            "X-Initiator": "user",
            "X-Interaction-Id": interaction_id,
            "X-Interaction-Type": "conversation-panel",
            "Openai-Organization": "github-copilot",
            "X-Request-Id": request_id,
            "Vscode-Sessionid": "427689f2-5dad-4b50-95d9-7cca977450061761839746260",
            "Vscode-Machineid": "c9421c6ac240db1c5bc5117218aa21a73f3762bda7db1702d003ec2df103b812",
            "Openai-Intent": "conversation-panel",
            "Copilot-Vision-Request": "true",
        })

        print(f"/chat/completions path matched, Interaction-Id: {interaction_id}, Request-Id: {request_id}")

    return headers


def has_non_empty_content(msg):
    """Check whether a message's content is non-empty."""
    content = msg.get('content')
    if content is None:
        return False
    if isinstance(content, str):
        return bool(content.strip())  # strings must be non-blank after stripping
    if isinstance(content, (list, dict)):
        return bool(content)  # lists or dicts: non-empty means True
    # other types (numbers, booleans, ...) are treated as non-empty
    return True


def filter_messages_logic(messages):
    """
    Filtering logic:
    find a message A whose role is 'assistant' and that has tool_calls,
    and the next message B after it whose role is 'tool'.
    Delete every message between A and B whose content is non-empty.
    """
    if not messages or len(messages) < 3:  # need at least assistant, something, tool
        return

    i = 0
    while i < len(messages):
        current_msg = messages[i]

        # is the current message an assistant message with non-empty tool_calls?
        is_assistant_with_tool_calls = (
            current_msg.get("role") == "assistant" and
            isinstance(current_msg.get("tool_calls"), list) and
            len(current_msg["tool_calls"]) > 0
        )

        if is_assistant_with_tool_calls:
            # starting from the next message, look for the first role='tool' message
            j = i + 1
            found_tool = False
            indices_to_remove_between = []

            while j < len(messages):
                msg_to_check = messages[j]

                if msg_to_check.get("role") == "tool":
                    found_tool = True
                    break  # stop at the first tool message; delete what lies between
                # does the message at j (between assistant and tool) have non-empty content?
                if has_non_empty_content(msg_to_check):
                    indices_to_remove_between.append(j)
                j += 1

            if found_tool and indices_to_remove_between:
                # delete back-to-front so indices stay valid as the list shrinks
                for idx in sorted(indices_to_remove_between, reverse=True):
                    removed_msg = messages.pop(idx)
                    print(f"Removed intermediate message with non-empty content at index {idx}: {removed_msg}")
                # after deleting, the list is shorter; do not advance i and
                # let the outer loop re-examine the current position.
                continue
            else:
                # either no paired tool message was found, or there was nothing
                # to delete between them; move on to the next message.
                i += 1
        else:
            # current message does not match; check the next one
            i += 1


@app.api_route("/v1/{path:path}", methods=["GET", "POST", "PUT", "DELETE", "PATCH"])
async def proxy_request(request: Request, path: str):
    # create a timestamped directory for logs
    timestamp = datetime.now().strftime("%Y%m%d_%H%M%S_%f")
    log_dir = os.path.join("logs", timestamp)
    os.makedirs(log_dir, exist_ok=True)

    # record the original request body
    original_body = await request.body()
    with open(os.path.join(log_dir, "original_request.txt"), "wb") as f:
        f.write(original_body or b"")

    token_data = await get_copilot_token()
    print(token_data)
    headers = get_headers_for_path(f"/{path}")
    headers["Authorization"] = f"Bearer {token_data['token']}"
    headers["hello"] = "world"

    body = original_body

    # filter messages
    if body:
        try:
            body_data = json.loads(body.decode('utf-8') if isinstance(body, bytes) else body)
            if "messages" in body_data and isinstance(body_data["messages"], list):
                messages = body_data["messages"]
                initial_len = len(messages)
                print(f"Processing messages, initial count: {initial_len}")
                filter_messages_logic(messages)
                final_len = len(messages)
                if initial_len != final_len:
                    body = json.dumps(body_data).encode('utf-8')
                    print(f"Messages filtered from {initial_len} to {final_len}.")
                    # record the modified request body
                    with open(os.path.join(log_dir, "modified_request.txt"), "wb") as f:
                        f.write(body or b"")
                else:
                    # nothing changed; still record the body as modified_request
                    with open(os.path.join(log_dir, "modified_request.txt"), "wb") as f:
                        f.write(body or b"")

        except json.JSONDecodeError:
            # the body is not JSON; leave it untouched
            print("Request body is not valid JSON, skipping message filtering.")
            with open(os.path.join(log_dir, "modified_request.txt"), "wb") as f:
                f.write(body or b"")

    # target_url = f"https://qwapi.oopsapi.com/v1/{path}"
    target_url = "https://api.business.githubcopilot.com/" + path

    print(target_url, " ", str(body))

    # request_headers = {k: v for k, v in request.headers.items()
    #                    if k.lower() not in ['host', 'content-length']}
    # request_headers.update(headers)
    async with httpx.AsyncClient() as client:
        try:
            response = await client.request(
                method=request.method,
                url=target_url,
                headers=headers,
                content=body if body else None,
                timeout=120.0,
            )

            content = response.content

            # record the response
            with open(os.path.join(log_dir, "response.txt"), "wb") as f:
                f.write(content or b"")

            print("content: ", content)
            if response.headers.get("content-type", "").startswith("text/event-stream"):
                return StreamingResponse(
                    response.aiter_bytes(),
                    status_code=response.status_code,
                    headers=dict(response.headers),
                )

            return JSONResponse(
                content=json.loads(content) if content else {},
                status_code=response.status_code,
                headers={k: v for k, v in response.headers.items()
                         if k.lower() not in ['content-length', 'transfer-encoding']}
            )

        except httpx.RequestError as e:
            import traceback
            traceback.print_exc()
            raise HTTPException(status_code=500, detail=f"Proxy request failed: {str(e)}")


@app.get("/")
async def root():
    return {"message": "GitHub Copilot Proxy API"}


if __name__ == "__main__":
    uvicorn.run(app, host="0.0.0.0", port=8000)
```
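The filtering rule implemented above can be exercised standalone. This is a compressed reimplementation for illustration; the real `filter_messages_logic` additionally treats whitespace-only strings and empty lists or dicts as empty content:

```python
# Between an assistant message that carries tool_calls and the next
# 'tool' message, any message with non-empty content is dropped.
def filter_between_tool_calls(messages):
    i = 0
    while i < len(messages):
        msg = messages[i]
        if msg.get("role") == "assistant" and msg.get("tool_calls"):
            j = i + 1
            doomed = []
            while j < len(messages) and messages[j].get("role") != "tool":
                if messages[j].get("content"):
                    doomed.append(j)
                j += 1
            if j < len(messages) and doomed:  # a paired tool message was found
                for idx in reversed(doomed):  # back-to-front keeps indices valid
                    messages.pop(idx)
                continue  # re-examine the same position
        i += 1
    return messages
```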
app/main.py (13 lines changed)

```diff
@@ -8,7 +8,7 @@ from fastapi import FastAPI, HTTPException, Depends, Request
 from starlette.responses import StreamingResponse

 from .models import IncomingRequest, ProxyResponse
-from .services import process_chat_request, stream_llm_api, inject_tools_into_prompt, parse_llm_response_from_content, _parse_sse_data
+from .services import process_chat_request, stream_llm_api, inject_tools_into_prompt, parse_llm_response_from_content, _parse_sse_data, convert_tool_calls_to_content
 from .core.config import get_settings, Settings
 from .database import init_db, log_request, update_request_log

```
```diff
@@ -87,8 +87,13 @@ async def chat_completions(
         raise HTTPException(status_code=500, detail="LLM API Key or URL is not configured.")

     messages_to_llm = request_obj.messages
+
+    # Convert assistant messages with tool_calls to content format
+    messages_to_llm = convert_tool_calls_to_content(messages_to_llm)
+    logger.info(f"Converted tool calls to content format for log ID: {log_id}")
+
     if request_obj.tools:
-        messages_to_llm = inject_tools_into_prompt(request_obj.messages, request_obj.tools)
+        messages_to_llm = inject_tools_into_prompt(messages_to_llm, request_obj.tools)

     # Handle streaming request
     if request_obj.stream:
```
```diff
@@ -100,10 +105,12 @@ async def chat_completions(

             # First, collect all chunks to detect if there are tool calls
             async for chunk in stream_llm_api(messages_to_llm, settings, log_id):
+                logger.info(f"sse_result: {chunk}")
                 raw_chunks.append(chunk)
                 # Extract content from SSE chunks
                 parsed = _parse_sse_data(chunk)
-                if parsed and parsed.get("type") != "done":
+                logger.info(f"sse_result_data: {parsed}")
+                if parsed and (parsed.get("type") != "done" or (parsed.get("choices") or [{}])[0].get("finish_reason") == "stop"):
                     choices = parsed.get("choices")
                     if choices and len(choices) > 0:
                         delta = choices[0].get("delta")
```
```diff
@@ -60,10 +60,10 @@ class ResponseParser:

         # Escape special regex characters in the tags
         escaped_start = re.escape(self.tool_call_start_tag)
         escaped_end = re.escape(self.tool_call_end_tag)
-        # Match from start tag to end tag (greedy), including both tags
-        # This ensures we capture the complete JSON object
+        # Use non-greedy matching to find all tool call occurrences
+        # This allows us to extract multiple tool calls from a single response
         self._tool_call_pattern = re.compile(
-            f"{escaped_start}.*{escaped_end}",
+            f"{escaped_start}.*?{escaped_end}",
             re.DOTALL
         )
```
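The one-character change above (`.*` to `.*?`) is what enables multi-call extraction; a quick standalone check, using `<invoke>` tags for illustration (the parser's actual tags are configurable):

```python
import re

text = '<invoke>{"name": "a"}</invoke> and <invoke>{"name": "b"}</invoke>'

greedy = re.compile(r"<invoke>.*</invoke>", re.DOTALL)
lazy = re.compile(r"<invoke>.*?</invoke>", re.DOTALL)

# Greedy: one match swallowing both calls and the text between them
assert len(greedy.findall(text)) == 1
# Non-greedy: each call matched separately
assert len(lazy.findall(text)) == 2
```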
```diff
@@ -124,6 +124,8 @@ class ResponseParser:
         This is the main entry point for parsing. It handles both:
         1. Responses with tool calls (wrapped in tags)
         2. Regular text responses
+        3. Multiple tool calls in a single response
+        4. Incomplete tool calls (missing closing tag) - fallback parsing

         Args:
             llm_response: The raw text response from the LLM
```
@@ -145,55 +147,129 @@ class ResponseParser:
|
|||||||
return ResponseMessage(content=None)
|
return ResponseMessage(content=None)
|
||||||
|
|
||||||
try:
|
try:
|
||||||
match = self._tool_call_pattern.search(llm_response)
|
# Find all tool call occurrences
|
||||||
|
matches = list(self._tool_call_pattern.finditer(llm_response))
|
||||||
|
|
||||||
if match:
|
if matches:
|
||||||
return self._parse_tool_call_response(llm_response, match)
|
return self._parse_tool_call_response(llm_response, matches)
|
||||||
else:
|
else:
|
||||||
|
# Check for incomplete tool call (opening tag without closing tag)
|
||||||
|
if self.tool_call_start_tag in llm_response:
|
||||||
|
logger.warning("Detected incomplete tool call (missing closing tag). Attempting fallback parsing.")
|
||||||
|
return self._parse_incomplete_tool_call(llm_response)
|
||||||
return self._parse_text_only_response(llm_response)
|
return self._parse_text_only_response(llm_response)
|
||||||
|
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
logger.warning(f"Failed to parse LLM response: {e}. Returning as text.")
|
logger.warning(f"Failed to parse LLM response: {e}. Returning as text.")
|
||||||
return ResponseMessage(content=llm_response)
|
return ResponseMessage(content=llm_response)
|
||||||
|
|
||||||
    def _parse_tool_call_response(self, llm_response: str, matches: List[re.Match]) -> ResponseMessage:
        """
        Parse a response that contains tool calls.

        Args:
            llm_response: The full LLM response
            matches: List of regex match objects containing the tool calls

        Returns:
            ResponseMessage with content and tool_calls
        """
        tool_calls = []
        last_end = 0  # Track the position of the last tool call

        for match in matches:
            # The match includes start and end tags, so strip them
            matched_text = match.group(0)
            tool_call_str = matched_text[len(self.tool_call_start_tag):-len(self.tool_call_end_tag)]

            # Extract valid JSON by finding matching braces
            json_str = self._extract_valid_json(tool_call_str)
            if json_str is None:
                # Fallback to trying to parse the entire string
                json_str = tool_call_str

            try:
                tool_call_data = json.loads(json_str)
                # Create the tool call object
                tool_call = self._create_tool_call(tool_call_data)
                tool_calls.append(tool_call)
            except json.JSONDecodeError as e:
                logger.warning(f"Failed to parse tool call JSON: {tool_call_str}. Error: {e}")
                continue

            # Update the last end position
            last_end = match.end()

        # Extract content before the first tool call tag
        first_match_start = matches[0].start()
        content_before = llm_response[:first_match_start].strip() if first_match_start > 0 else None

        # Extract content between tool calls and after the last tool call
        content_parts = []
        if content_before:
            content_parts.append(content_before)

        # Check if there's content after the last tool call
        content_after = llm_response[last_end:].strip() if last_end < len(llm_response) else None
        if content_after:
            content_parts.append(content_after)

        # Combine all content parts
        content = " ".join(content_parts) if content_parts else None

        return ResponseMessage(
            content=content,
            tool_calls=tool_calls if tool_calls else None
        )

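The matching-and-stitching idea behind `_parse_tool_call_response` can be shown in a few self-contained lines. This is a minimal sketch, independent of the `ResponseParser` class and its `_extract_valid_json`/`_create_tool_call` helpers; it assumes the same literal `<invoke>…</invoke>` tags and well-formed JSON between them.

```python
import json
import re

# Same tag convention the parser uses (assumption: tags are literal strings).
START, END = "<invoke>", "</invoke>"
PATTERN = re.compile(re.escape(START) + r"(.*?)" + re.escape(END), re.DOTALL)

def split_tool_calls(text: str):
    """Return (plain_text, tool_call_dicts) for a tagged LLM response."""
    calls = [json.loads(m.group(1)) for m in PATTERN.finditer(text)]
    # Everything outside the tags becomes the plain-text content,
    # joined with single spaces as in the method above.
    outside = PATTERN.split(text)[::2]  # split() alternates text / captured JSON
    content = " ".join(p.strip() for p in outside if p.strip()) or None
    return content, calls

content, calls = split_tool_calls(
    'Checking. <invoke>{"name": "get_weather", "arguments": {"location": "Beijing"}}</invoke> Done.'
)
```

The real method is more defensive (brace matching, JSON fallbacks), but the shape of the result is the same: the tagged spans become structured calls, the surrounding text becomes `content`.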
    def _parse_incomplete_tool_call(self, llm_response: str) -> ResponseMessage:
        """
        Parse a response with an incomplete tool call (missing closing tag).

        This is a fallback method when the LLM doesn't close the tag properly.
        It attempts to extract the tool call JSON and complete it.

        Args:
            llm_response: The full LLM response with incomplete tool call

        Returns:
            ResponseMessage with content and optionally tool_calls
        """
        try:
            # Find the opening tag
            start_idx = llm_response.find(self.tool_call_start_tag)
            if start_idx == -1:
                return self._parse_text_only_response(llm_response)

            # Extract content before the opening tag
            content_before = llm_response[:start_idx].strip() if start_idx > 0 else None

            # Extract everything after the opening tag
            after_tag = llm_response[start_idx + len(self.tool_call_start_tag):]

            # Try to extract valid JSON
            json_str = self._extract_valid_json(after_tag)
            if json_str:
                try:
                    tool_call_data = json.loads(json_str)
                    tool_call = self._create_tool_call(tool_call_data)

                    logger.info(f"Successfully parsed incomplete tool call: {tool_call.function.name}")

                    return ResponseMessage(
                        content=content_before,
                        tool_calls=[tool_call]
                    )
                except json.JSONDecodeError as e:
                    logger.warning(f"Failed to parse JSON from incomplete tool call: {e}")

            # If all else fails, return as text
            logger.warning("Could not salvage incomplete tool call, returning as text")
            return ResponseMessage(content=llm_response)

        except Exception as e:
            logger.warning(f"Error in _parse_incomplete_tool_call: {e}")
            return ResponseMessage(content=llm_response)

    def _parse_text_only_response(self, llm_response: str) -> ResponseMessage:
        """
@@ -39,6 +39,70 @@ def _parse_sse_data(chunk: bytes) -> Optional[Dict[str, Any]]:
# --- End Helper ---


def convert_tool_calls_to_content(messages: List[ChatMessage]) -> List[ChatMessage]:
    """
    Converts assistant messages with tool_calls into content format using XML tags.

    This function processes the message history and converts any assistant messages
    that have tool_calls into a format that LLMs can understand. The tool_calls
    are converted to <invoke>...</invoke> tags in the content field.

    Args:
        messages: List of ChatMessage objects from the client

    Returns:
        Processed list of ChatMessage objects with tool_calls converted to content

    Example:
        Input: [{"role": "assistant", "tool_calls": [...]}]
        Output: [{"role": "assistant", "content": "<invoke>{...}</invoke>"}]
    """
    from .response_parser import TOOL_CALL_START_TAG, TOOL_CALL_END_TAG

    processed_messages = []

    for msg in messages:
        # Check if this is an assistant message with tool_calls
        if msg.role == "assistant" and msg.tool_calls and len(msg.tool_calls) > 0:
            # Convert each tool call to XML tag format
            tool_call_contents = []
            for tc in msg.tool_calls:
                tc_data = tc.get("function", {})
                name = tc_data.get("name", "")
                arguments_str = tc_data.get("arguments", "{}")

                # Parse arguments JSON to ensure it's valid
                try:
                    arguments = json.loads(arguments_str) if isinstance(arguments_str, str) else arguments_str
                except json.JSONDecodeError:
                    arguments = {}

                # Build the tool call JSON
                tool_call_json = {"name": name, "arguments": arguments}

                # Wrap in XML tags
                tool_call_content = f'{TOOL_CALL_START_TAG}{json.dumps(tool_call_json, ensure_ascii=False)}{TOOL_CALL_END_TAG}'
                tool_call_contents.append(tool_call_content)

            # Create new message with tool calls in content
            # Preserve original content if it exists
            content_parts = []
            if msg.content:
                content_parts.append(msg.content)

            content_parts.extend(tool_call_contents)
            new_content = "\n".join(content_parts)

            processed_messages.append(
                ChatMessage(role=msg.role, content=new_content)
            )
        else:
            # Keep other messages as-is
            processed_messages.append(msg)

    return processed_messages

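The per-call conversion above can be exercised without the app's Pydantic models. Below is a minimal sketch of the same round-trip for one OpenAI-style tool-call dict; `tool_call_to_content` is a hypothetical helper written for illustration, not a function from the codebase, and the literal `<invoke>` tags are an assumption matching the parser's convention.

```python
import json

# Same tag convention as the parser (assumption).
START, END = "<invoke>", "</invoke>"

def tool_call_to_content(tool_call: dict) -> str:
    """Render one OpenAI-style tool-call dict as an <invoke> tag string."""
    fn = tool_call.get("function", {})
    args = fn.get("arguments", "{}")
    payload = {
        "name": fn.get("name", ""),
        # In the OpenAI format, arguments arrive as a JSON-encoded string.
        "arguments": json.loads(args) if isinstance(args, str) else args,
    }
    return START + json.dumps(payload, ensure_ascii=False) + END

rendered = tool_call_to_content({
    "id": "call_1",
    "type": "function",
    "function": {"name": "get_weather", "arguments": '{"location": "Beijing"}'},
})
```

The rendered string is exactly what the downstream parser expects to find in an assistant message, which is what makes the history round-trip consistent.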
def inject_tools_into_prompt(messages: List[ChatMessage], tools: List[Tool]) -> List[ChatMessage]:
    """
    Injects a system prompt with tool definitions at the beginning of the message list.
@@ -53,17 +117,33 @@ def inject_tools_into_prompt(messages: List[ChatMessage], tools: List[Tool]) ->

    tool_prompt = f"""
You are a helpful assistant with access to a set of tools.

## TOOL CALL FORMAT (CRITICAL)

When you need to use a tool, you MUST follow this EXACT format:

{TOOL_CALL_START_TAG}{{"name": "tool_name", "arguments": {{...}}}}{TOOL_CALL_END_TAG}

### IMPORTANT RULES:
1. ALWAYS include BOTH the opening tag ({TOOL_CALL_START_TAG}) AND closing tag ({TOOL_CALL_END_TAG})
2. The JSON must be valid and properly formatted
3. Keep arguments concise to avoid truncation
4. Do not include any text between the tags except the JSON

### Examples:
Simple call:
{full_example}

Multiple arguments:
{TOOL_CALL_START_TAG}{{"name": "search", "arguments": {{"query": "example", "limit": 5}}}}{TOOL_CALL_END_TAG}

## AVAILABLE TOOLS:
{tool_defs}

## REMEMBER:
- If you decide to call a tool, output ONLY the tool call tags (you may add brief text before or after)
- ALWAYS close your tags properly with {TOOL_CALL_END_TAG}
- Keep your arguments concise and essential
"""
    # Prepend the system prompt with tool definitions
    return [ChatMessage(role="system", content=tool_prompt)] + messages
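What the rendered system prompt looks like for one sample tool can be sketched as follows. The construction of `tool_defs` and `full_example` here is an assumption for illustration, since the diff does not show how the real code builds those variables; only the tag format itself comes from the prompt above.

```python
import json

# Tags as used elsewhere in the project (assumption: literal strings).
TOOL_CALL_START_TAG, TOOL_CALL_END_TAG = "<invoke>", "</invoke>"

# One sample tool in OpenAI function format (hypothetical example tool).
tool = {
    "name": "get_weather",
    "description": "Look up the current weather",
    "parameters": {"type": "object", "properties": {"location": {"type": "string"}}},
}

# Hypothetical stand-ins for the variables the f-string interpolates.
tool_defs = json.dumps(tool, ensure_ascii=False)
full_example = (
    TOOL_CALL_START_TAG
    + json.dumps({"name": "get_weather", "arguments": {"location": "Beijing"}})
    + TOOL_CALL_END_TAG
)

# A condensed version of the prompt skeleton above.
prompt = (
    "When you need to use a tool, you MUST follow this EXACT format:\n\n"
    + TOOL_CALL_START_TAG + '{"name": "tool_name", "arguments": {...}}' + TOOL_CALL_END_TAG
    + "\n\n### Examples:\nSimple call:\n" + full_example
    + "\n\n## AVAILABLE TOOLS:\n" + tool_defs + "\n"
)
print(prompt)
```

Seeing the interpolated result makes it easier to judge whether a given model will actually follow the tag format before wiring it into the proxy.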
@@ -86,7 +166,7 @@ async def _raw_stream_from_llm(messages: List[ChatMessage], settings: Settings,

    Yields raw byte chunks as received.
    """
    headers = { "Authorization": f"Bearer {settings.REAL_LLM_API_KEY}", "Content-Type": "application/json" }
    payload = { "model": "gpt-4.1", "messages": [msg.model_dump() for msg in messages], "stream": True }

    # Log the request payload to the database
    update_request_log(log_id, llm_request=payload)
18
docker-compose.yml
Normal file
@@ -0,0 +1,18 @@
version: '3.8'

services:
  sqlite-web:
    image: docker.1ms.run/coleifer/sqlite-web
    volumes:
      - .:/data
    environment:
      SQLITE_DATABASE: llm_proxy.db
    ports:
      - 8580:8080
  llmproxy:
    build: .
    ports:
      - "8000:8000"
    volumes:
      - .:/app
    restart: unless-stopped
1139
docs/multi_backend_final_design.md
Normal file
File diff suppressed because it is too large.
@@ -1,3 +1,4 @@
fastapi
uvicorn[standard]
httpx
python-dotenv
155
test_content_with_tool_calls.py
Normal file
@@ -0,0 +1,155 @@
#!/usr/bin/env python3
"""
Test that the chat endpoint can return text content and tool_calls at the same time
"""
import sys
import os
import json

sys.path.insert(0, os.path.dirname(os.path.abspath(__file__)))

from app.response_parser import ResponseParser
from app.models import ResponseMessage, ToolCall, ToolCallFunction

def test_content_and_tool_calls():
    """Test the scenarios that return both text content and tool_calls"""

    parser = ResponseParser()

    print("=" * 70)
    print("Test: returning text content and tool_calls together")
    print("=" * 70)

    # Scenario 1: text first + tool_calls
    print("\nScenario 1: speak first, then call a tool")
    print("-" * 70)

    text1 = """好的,我来帮你查询北京的天气情况。

<invoke>{"name": "get_weather", "arguments": {"location": "北京", "unit": "celsius"}}</invoke>"""

    result1 = parser.parse(text1)

    print(f"Input text:\n{text1}\n")
    print(f"Parse result:")
    print(f"  - content: {result1.content}")
    print(f"  - tool_calls: {len(result1.tool_calls) if result1.tool_calls else 0}")
    if result1.tool_calls:
        for tc in result1.tool_calls:
            print(f"    * {tc.function.name}: {tc.function.arguments}")

    # Verify
    assert result1.content is not None, "Content should not be None"
    assert result1.tool_calls is not None, "Tool calls should not be None"
    assert len(result1.tool_calls) == 1, "Should have 1 tool call"
    assert "北京" in result1.content or "查询" in result1.content, "Content should contain original text"

    print("  ✓ Scenario 1 passed")

    # Scenario 2: tool_calls + text after
    print("\nScenario 2: call a tool first, then speak")
    print("-" * 70)

    text2 = """<invoke>{"name": "search", "arguments": {"query": "今天天气"}}</invoke>

我已经帮你查询了,请稍等片刻。"""

    result2 = parser.parse(text2)

    print(f"Input text:\n{text2}\n")
    print(f"Parse result:")
    print(f"  - content: {result2.content}")
    print(f"  - tool_calls: {len(result2.tool_calls) if result2.tool_calls else 0}")
    if result2.tool_calls:
        for tc in result2.tool_calls:
            print(f"    * {tc.function.name}: {tc.function.arguments}")

    assert result2.content is not None
    assert result2.tool_calls is not None
    assert "稍等" in result2.content or "查询" in result2.content

    print("  ✓ Scenario 2 passed")

    # Scenario 3: text - tool_calls - text
    print("\nScenario 3: text - tool call - text (sandwich structure)")
    print("-" * 70)

    text3 = """让我先查一下北京的温度。

<invoke>{"name": "get_weather", "arguments": {"location": "北京"}}</invoke>

查到了,我再查一下上海的。

<invoke>{"name": "get_weather", "arguments": {"location": "上海"}}</invoke>

好了,两个城市都查询完毕。"""

    result3 = parser.parse(text3)

    print(f"Input text:\n{text3}\n")
    print(f"Parse result:")
    print(f"  - content: {result3.content}")
    print(f"  - tool_calls: {len(result3.tool_calls) if result3.tool_calls else 0}")
    if result3.tool_calls:
        for i, tc in enumerate(result3.tool_calls, 1):
            print(f"    * {tc.function.name}: {tc.function.arguments}")

    assert result3.content is not None
    assert result3.tool_calls is not None
    assert len(result3.tool_calls) == 2
    assert "先查一下" in result3.content
    assert "查询完毕" in result3.content

    print("  ✓ Scenario 3 passed")

    # Scenario 4: ResponseMessage serialization
    print("\nScenario 4: verify ResponseMessage serializes to JSON correctly")
    print("-" * 70)

    msg = ResponseMessage(
        role="assistant",
        content="好的,我来帮你查询。",
        tool_calls=[
            ToolCall(
                id="call_123",
                type="function",
                function=ToolCallFunction(
                    name="get_weather",
                    arguments=json.dumps({"location": "北京"})
                )
            )
        ]
    )

    json_str = msg.model_dump_json(indent=2)
    print("Serialized JSON response:")
    print(json_str)

    parsed_back = ResponseMessage.model_validate_json(json_str)
    assert parsed_back.content == msg.content
    assert parsed_back.tool_calls is not None
    assert len(parsed_back.tool_calls) == 1

    print("  ✓ Scenario 4 passed - JSON serialization/deserialization works")

    print("\n" + "=" * 70)
    print("All tests passed! ✓")
    print("=" * 70)

    print("\nSummary:")
    print("✓ The chat endpoint supports returning text content and tool_calls together")
    print("✓ content and tool_calls can coexist")
    print("✓ Text before, after, or on both sides of the tool calls is supported")
    print("✓ Multiple tool_calls mixed with text content are supported")
    print("✓ JSON serialization/deserialization works")

    print("\nExample of a real-world flow:")
    print("""
    Assistant: "好的,我来帮你查询一下。"
    [calls the get_weather tool]
    [receives the tool result]
    Assistant: "北京今天晴天,气温 25°C。"
    """)

if __name__ == "__main__":
    test_content_and_tool_calls()
154
test_multiple_tool_calls.py
Normal file
@@ -0,0 +1,154 @@
#!/usr/bin/env python3
"""
Test the full flow with multiple tool_calls
"""
import sys
import os

sys.path.insert(0, os.path.dirname(os.path.abspath(__file__)))

from app.services import convert_tool_calls_to_content
from app.response_parser import ResponseParser
from app.models import ChatMessage

def test_multiple_tool_calls():
    """Test the full flow with multiple tool_calls"""

    print("=" * 60)
    print("Scenario: the message history contains multiple tool_calls")
    print("=" * 60)

    # Simulated conversation:
    # the user asks for the weather in Beijing and Shanghai, and the assistant called two tools
    messages = [
        ChatMessage(
            role="user",
            content="帮我查一下北京和上海的天气"
        ),
        ChatMessage(
            role="assistant",
            tool_calls=[
                {
                    "id": "call_1",
                    "type": "function",
                    "function": {
                        "name": "get_weather",
                        "arguments": '{"location": "北京", "unit": "celsius"}'
                    }
                },
                {
                    "id": "call_2",
                    "type": "function",
                    "function": {
                        "name": "get_weather",
                        "arguments": '{"location": "上海", "unit": "celsius"}'
                    }
                }
            ]
        ),
        ChatMessage(
            role="user",
            content="结果怎么样?"
        )
    ]

    print("\n1. Original messages:")
    for i, msg in enumerate(messages):
        print(f"  Message {i+1}: {msg.role}")
        if msg.content:
            print(f"    Content: {msg.content}")
        if msg.tool_calls:
            print(f"    Tool calls: {len(msg.tool_calls)}")
            for j, tc in enumerate(msg.tool_calls):
                print(f"      {j+1}. {tc['function']['name']}")

    # Convert tool_calls into content
    print("\n2. Converted messages (sent to the LLM):")
    converted = convert_tool_calls_to_content(messages)
    for i, msg in enumerate(converted):
        print(f"  Message {i+1}: {msg.role}")
        if msg.content:
            # Show only the first 150 characters
            content_preview = msg.content[:150] + "..." if len(msg.content) > 150 else msg.content
            print(f"    Content: {content_preview}")

    # Verify the conversion
    assert "<invoke>" in converted[1].content
    assert converted[1].content.count("<invoke>") == 2
    print("\n  ✓ Conversion succeeded! Both tool_calls were converted to XML tag format")

    # Simulate a new LLM response (also containing multiple tool_calls)
    print("\n3. Simulated LLM response (with multiple tool_calls):")
    llm_response = '''好的,我来帮你查一下其他城市的天气。

<invoke>{"name": "get_weather", "arguments": {"location": "广州"}}</invoke>
<invoke>{"name": "get_weather", "arguments": {"location": "深圳"}}</invoke>

请稍等。'''

    print(f"  {llm_response}")

    # Parse the LLM response
    print("\n4. Parsed LLM response:")
    parser = ResponseParser()
    parsed = parser.parse(llm_response)

    print(f"  Content: {parsed.content}")
    print(f"  Number of tool calls: {len(parsed.tool_calls) if parsed.tool_calls else 0}")

    if parsed.tool_calls:
        for i, tc in enumerate(parsed.tool_calls):
            import json
            args = json.loads(tc.function.arguments)
            print(f"  {i+1}. {tc.function.name}(location={args['location']})")

    # Verify the parse
    assert parsed.tool_calls is not None
    assert len(parsed.tool_calls) == 2
    assert parsed.tool_calls[0].function.name == "get_weather"
    assert parsed.tool_calls[1].function.name == "get_weather"
    print("\n  ✓ Parse succeeded! Both tool_calls were extracted correctly")

    # Scenario 2: a single tool_call (backward compatibility)
    print("\n" + "=" * 60)
    print("Scenario: a single tool_call (backward compatibility)")
    print("=" * 60)

    single_response = '''我来帮你查询。

<invoke>{"name": "search", "arguments": {"query": "今天天气"}}</invoke>'''

    parsed_single = parser.parse(single_response)
    print(f"Content: {parsed_single.content}")
    print(f"Number of tool calls: {len(parsed_single.tool_calls) if parsed_single.tool_calls else 0}")

    assert parsed_single.tool_calls is not None
    assert len(parsed_single.tool_calls) == 1
    assert parsed_single.tool_calls[0].function.name == "search"
    print("✓ A single tool_call parses correctly")

    # Scenario 3: no tool_call
    print("\n" + "=" * 60)
    print("Scenario: no tool_call")
    print("=" * 60)

    no_tool_response = "你好!有什么可以帮助你的吗?"

    parsed_no_tool = parser.parse(no_tool_response)
    print(f"Content: {parsed_no_tool.content}")
    print(f"Tool calls: {parsed_no_tool.tool_calls}")

    assert parsed_no_tool.content == no_tool_response
    assert parsed_no_tool.tool_calls is None
    print("✓ Plain text responses parse correctly")

    print("\n" + "=" * 60)
    print("All tests passed! ✓")
    print("=" * 60)
    print("\nSummary:")
    print("- Multiple tool_calls in the message history are correctly converted to XML format")
    print("- Multiple tool_calls in an LLM response are correctly parsed")
    print("- Backward compatible with a single tool_call and plain text responses")

if __name__ == "__main__":
    test_multiple_tool_calls()
193
test_tool_call_conversion.py
Normal file
@@ -0,0 +1,193 @@
#!/usr/bin/env python3
"""
Test the conversion of tool_calls into content
"""
import sys
import os

# Add the project path to sys.path
sys.path.insert(0, os.path.dirname(os.path.abspath(__file__)))

from app.services import convert_tool_calls_to_content
from app.models import ChatMessage

def test_convert_tool_calls_to_content():
    """Test the tool-call conversion feature"""

    # Case 1: an assistant message with tool_calls
    print("=" * 60)
    print("Case 1: an assistant message with tool_calls")
    print("=" * 60)

    messages = [
        ChatMessage(
            role="user",
            content="帮我查询一下天气"
        ),
        ChatMessage(
            role="assistant",
            tool_calls=[
                {
                    "id": "call_123",
                    "type": "function",
                    "function": {
                        "name": "get_weather",
                        "arguments": '{"location": "北京", "unit": "celsius"}'
                    }
                }
            ]
        ),
        ChatMessage(
            role="user",
            content="那上海呢?"
        )
    ]

    print("\nOriginal messages:")
    for i, msg in enumerate(messages):
        print(f"  Message {i+1}:")
        print(f"    Role: {msg.role}")
        if msg.content:
            print(f"    Content: {msg.content}")
        if msg.tool_calls:
            print(f"    Tool calls: {len(msg.tool_calls)}")

    # Convert
    converted = convert_tool_calls_to_content(messages)

    print("\nConverted messages:")
    for i, msg in enumerate(converted):
        print(f"  Message {i+1}:")
        print(f"    Role: {msg.role}")
        if msg.content:
            print(f"    Content: {msg.content[:100]}...")  # Show only the first 100 characters

    # Verify that the second message was converted correctly
    assert converted[1].role == "assistant"
    assert "<invoke>" in converted[1].content
    assert "get_weather" in converted[1].content
    assert "北京" in converted[1].content
    assert converted[1].tool_calls is None  # tool_calls should be removed

    print("\n✓ Case 1 passed!")

    # Case 2: an assistant message with both content and tool_calls
    print("\n" + "=" * 60)
    print("Case 2: an assistant message with both content and tool_calls")
    print("=" * 60)

    messages2 = [
        ChatMessage(
            role="assistant",
            content="好的,让我帮你查询天气。",
            tool_calls=[
                {
                    "id": "call_456",
                    "type": "function",
                    "function": {
                        "name": "search",
                        "arguments": '{"query": "今天天气"}'
                    }
                }
            ]
        )
    ]

    print("\nOriginal message:")
    print(f"  Role: {messages2[0].role}")
    print(f"  Content: {messages2[0].content}")
    print(f"  Tool calls: {messages2[0].tool_calls}")

    converted2 = convert_tool_calls_to_content(messages2)

    print("\nConverted message:")
    print(f"  Role: {converted2[0].role}")
    print(f"  Content: {converted2[0].content}")

    # Verify
    assert "好的,让我帮你查询天气。" in converted2[0].content
    assert "<invoke>" in converted2[0].content
    assert "search" in converted2[0].content

    print("\n✓ Case 2 passed!")

    # Case 3: multiple tool_calls
    print("\n" + "=" * 60)
    print("Case 3: multiple tool_calls")
    print("=" * 60)

    messages3 = [
        ChatMessage(
            role="assistant",
            tool_calls=[
                {
                    "id": "call_1",
                    "type": "function",
                    "function": {
                        "name": "get_weather",
                        "arguments": '{"location": "北京"}'
                    }
                },
                {
                    "id": "call_2",
                    "type": "function",
                    "function": {
                        "name": "get_weather",
                        "arguments": '{"location": "上海"}'
                    }
                }
            ]
        )
    ]

    print("\nOriginal message:")
    print(f"  Role: {messages3[0].role}")
    print(f"  Number of tool calls: {len(messages3[0].tool_calls)}")

    converted3 = convert_tool_calls_to_content(messages3)

    print("\nConverted message:")
    print(f"  Content: {converted3[0].content}")

    # Verify that both tool calls were converted
    assert converted3[0].content.count("<invoke>") == 2
    assert "北京" in converted3[0].content
    assert "上海" in converted3[0].content

    print("\n✓ Case 3 passed!")

    # Case 4: messages without tool_calls (should stay unchanged)
    print("\n" + "=" * 60)
    print("Case 4: messages without tool_calls")
    print("=" * 60)

    messages4 = [
        ChatMessage(role="user", content="你好"),
        ChatMessage(role="assistant", content="你好,有什么可以帮助你的吗?"),
        ChatMessage(role="user", content="再见")
    ]

    print("\nOriginal messages:")
    for i, msg in enumerate(messages4):
        print(f"  Message {i+1}: {msg.role} - {msg.content}")

    converted4 = convert_tool_calls_to_content(messages4)

    print("\nConverted messages:")
    for i, msg in enumerate(converted4):
        print(f"  Message {i+1}: {msg.role} - {msg.content}")

    # Verify the messages are unchanged
    assert len(converted4) == len(messages4)
    assert converted4[0].content == "你好"
    assert converted4[1].content == "你好,有什么可以帮助你的吗?"
    assert converted4[2].content == "再见"

    print("\n✓ Case 4 passed!")

    print("\n" + "=" * 60)
    print("All test cases passed! ✓")
    print("=" * 60)

if __name__ == "__main__":
    test_convert_tool_calls_to_content()