# Multi-Backend LLM Support Implementation Plan (Final)

> **Version: v2.0**
> **Last updated: 2025-12-31**
> **Status: pending confirmation**

## 📋 Overview

### Core Design Principles

1. ✅ **Provider-based architecture**: a single adapter instance serves every model from the same provider
2. ✅ **OpenAI-compatible interface**: adapter parameters and return values mirror the OpenAI API exactly
3. ✅ **Lean configuration + code**: the config file stores only data; complex logic lives in code
4. ✅ **Uniform, standardized interface**: all adapters follow the same interface contract

### Supported Scenarios

- Standard OpenAI-compatible APIs
- APIs with non-standard authentication (API-key headers, request signing, etc.)
- APIs with different request/response formats
- APIs that need special handling

---

## 1. Architecture Design

### 1.1 Overall Architecture

```
┌─────────────────────────────────────────────────────────┐
│ Client Request                                          │
│ {"model": "gpt-4", "messages": [...]}                   │
└─────────────────────┬───────────────────────────────────┘
                      ↓
┌─────────────────────────────────────────────────────────┐
│ Chat Completions API                                    │
│ /v1/chat/completions                                    │
└─────────────────────┬───────────────────────────────────┘
                      ↓
┌─────────────────────────────────────────────────────────┐
│ Model Router                                            │
│ - looks up the provider for the requested model         │
│ - caches adapter instances (one per provider)           │
│ - dynamically loads the adapter (generic or custom)     │
│                                                         │
│ e.g. "gpt-4"         → OpenAI provider → GenericAdapter │
│      "gpt-3.5-turbo" → OpenAI provider → GenericAdapter │
│      (same adapter instance)                            │
└─────────────────────┬───────────────────────────────────┘
                      ↓
┌─────────────────────────────────────────────────────────┐
│ Model Adapter                                           │
│ ┌─────────────────────────────────────────────────┐     │
│ │ BaseAdapter (base class)                        │     │
│ │ - chat() / stream_chat()                        │     │
│ │ - get_headers()    → overridable (auth)         │     │
│ │ - format_request() → overridable (request)      │     │
│ │ - parse_response() → overridable (response)     │     │
│ └─────────────────────────────────────────────────┘     │
│                          ↓                              │
│ ┌─────────────────────────────────────────────────┐     │
│ │ GenericAdapter                                  │     │
│ │ for standard OpenAI-compatible APIs             │     │
│ └─────────────────────────────────────────────────┘     │
│                          ↓                              │
│ ┌─────────────────────────────────────────────────┐     │
│ │ CustomAdapter                                   │     │
│ │ overrides methods for non-standard APIs         │     │
│ │ - get_headers()    → custom authentication      │     │
│ │ - format_request() → request conversion         │     │
│ │ - parse_response() → response conversion        │     │
│ └─────────────────────────────────────────────────┘     │
└─────────────────────┬───────────────────────────────────┘
                      ↓
┌─────────────────────────────────────────────────────────┐
│ Backend LLM API (per vendor)                            │
└─────────────────────┬───────────────────────────────────┘
                      ↓
┌─────────────────────────────────────────────────────────┐
│ Standardized Response                                   │
│ (OpenAI-compatible format)                              │
└─────────────────────────────────────────────────────────┘
```

### 1.2 Core Components

| Component | File | Role |
|------|---------|------|
| **Config manager** | `app/core/model_config.py` | Loads the config; provider/model lookups |
| **Data models** | `app/models.py` | OpenAI-format request/response structures |
| **Adapter base class** | `app/model_adapters/base.py` | Defines the unified adapter interface |
| **Generic adapter** | `app/model_adapters/generic_adapter.py` | Handles standard OpenAI-compatible APIs |
| **Custom adapters** | `app/model_adapters/custom_*.py` | Handle the various non-standard APIs |
| **Model router** | `app/services.py` | Loads adapters dynamically, manages the instance cache |
| **API layer** | `app/main.py` | FastAPI routes, HTTP request handling |

### 1.3 Design Decisions

| Decision | Choice | Rationale |
|--------|------|------|
| Adapter granularity | Provider-based | One adapter serves many models and reuses the connection pool |
| Interface format | OpenAI-compatible | Industry standard; easy to understand and migrate to |
| Configuration | Lean YAML | Stores data only, never logic |
| Complex logic | In code | Easier to debug, maintain, and extend |
| Caching | Per provider | Reuses HTTP connections; fewer instances |

---

## 2. Configuration File Design (Minimal)

### 2.1 Structure (`config/models.yaml`)

```yaml
# ========================================
# Version
# ========================================
version: "2.0"

# ========================================
# Default model
# ========================================
default_model: gpt-3.5-turbo

# ========================================
# Providers
# ========================================
providers:
  # -------------------- Standard OpenAI-compatible API --------------------
  openai:
    adapter_type: generic
    api_url: https://api.openai.com/v1/chat/completions
    api_key_env: OPENAI_API_KEY
    description: "OpenAI API"
    models:
      - name: gpt-4
        display_name: "GPT-4"
        max_tokens: 8192
      - name: gpt-4-turbo
        display_name: "GPT-4 Turbo"
        max_tokens: 128000
      - name: gpt-3.5-turbo
        display_name: "GPT-3.5 Turbo"
        max_tokens: 4096

  # -------------------- Anthropic (needs a custom adapter) --------------------
  anthropic:
    adapter_type: custom
    adapter_class: app.model_adapters.anthropic.AnthropicAdapter
    api_url: https://api.anthropic.com/v1/messages
    api_key_env: ANTHROPIC_API_KEY
    description: "Anthropic Claude API"
    models:
      - name: claude-3-opus
        display_name: "Claude 3 Opus"
        max_tokens: 200000
      - name: claude-3-sonnet
        display_name: "Claude 3 Sonnet"
        max_tokens: 200000

  # -------------------- DashScope (Alibaba Cloud) --------------------
  dashscope:
    adapter_type: generic
    api_url: https://dashscope.aliyuncs.com/compatible-mode/v1/chat/completions
    api_key_env: DASHSCOPE_API_KEY
    description: "Alibaba Cloud DashScope"
    models:
      - name: qwen-max
        display_name: "Qwen Max"
        max_tokens: 6000
      - name: qwen-plus
        display_name: "Qwen Plus"
        max_tokens: 30000

  # -------------------- API with custom authentication --------------------
  custom_auth_api:
    adapter_type: custom
    adapter_class: app.model_adapters.custom_auth.CustomAuthAdapter
    api_url: https://custom-api.com/v1/chat
    api_key_env: CUSTOM_API_KEY
    description: "API authenticated via X-API-Key"
    models:
      - name: custom-model-v1
        display_name: "Custom Model V1"

  # -------------------- API requiring request signing --------------------
  signed_api:
    adapter_type: custom
    adapter_class: app.model_adapters.signed.SignedApiAdapter
    api_url: https://signed-api.com/v1/chat
    api_key_env: SIGNED_API_KEY
    api_secret_env: SIGNED_API_SECRET
    description: "API requiring HMAC signatures"
    models:
      - name: signed-model
        display_name: "Signed Model"

  # -------------------- Local Ollama --------------------
  local:
    adapter_type: generic
    api_url: http://localhost:11434/v1/chat/completions
    api_key_env: LOCAL_LLM_API_KEY
    description: "Local Ollama"
    models:
      - name: llama2
        display_name: "Llama 2"
      - name: mistral
        display_name: "Mistral"
```

### 2.2 Configuration Fields

| Field | Type | Required | Description |
|------|------|------|------|
| `adapter_type` | string | ✅ | `generic` or `custom` |
| `adapter_class` | string | ⚠️ | Dotted path to the custom adapter class (required when `adapter_type: custom`) |
| `api_url` | string | ✅ | API endpoint URL |
| `api_key_env` | string | ✅ | Name of the environment variable holding the API key |
| `api_secret_env` | string | ❌ | Name of the environment variable holding the API secret (signed APIs only) |
| `description` | string | ❌ | Human-readable description |
| `models` | array | ✅ | List of models |

### 2.3 Model Entries

```yaml
models:
  - name: str              # model name (required, unique identifier)
    display_name: str      # display name (required)
    max_tokens: int        # maximum token count (optional)
    default_params:        # default parameters (optional)
      temperature: 0.7
      top_p: 0.9
```
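
One plausible use of `default_params` is to pre-fill request parameters that the client left unset, with explicit request values taking precedence. The plan does not specify this merge; the sketch below is a hypothetical helper (`merge_params` is not part of the planned code) showing the intended precedence:

```python
# Hypothetical sketch: merge a model's default_params with per-request
# parameters; explicit (non-None) request values win over defaults.
def merge_params(default_params: dict, request_params: dict) -> dict:
    """Start from the model's defaults, then overlay explicit request values."""
    merged = dict(default_params)
    # Only non-None request values override the defaults
    merged.update({k: v for k, v in request_params.items() if v is not None})
    return merged

defaults = {"temperature": 0.7, "top_p": 0.9}
print(merge_params(defaults, {"temperature": 0.2, "max_tokens": None}))
# -> {'temperature': 0.2, 'top_p': 0.9}
```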

---

## 3. Data Model Design (OpenAI-Compatible)

### 3.1 Request Models (`app/models.py`)

```python
from typing import List, Dict, Any, Optional, Union
from pydantic import BaseModel, Field
from enum import Enum


class MessageRole(str, Enum):
    """Message role enum"""
    SYSTEM = "system"
    USER = "user"
    ASSISTANT = "assistant"
    TOOL = "tool"


class ChatMessage(BaseModel):
    """Chat message (identical to the OpenAI API format)"""
    role: MessageRole
    content: Optional[str] = None
    name: Optional[str] = None
    tool_calls: Optional[List[Dict[str, Any]]] = None
    tool_call_id: Optional[str] = None

    def model_dump(self, **kwargs) -> Dict[str, Any]:
        """Serialize to a dict, dropping None values"""
        data = super().model_dump(exclude_none=True, **kwargs)
        if "role" in data and isinstance(data["role"], MessageRole):
            data["role"] = data["role"].value
        return data


class ToolFunction(BaseModel):
    """Tool function definition"""
    name: str
    description: str
    parameters: Dict[str, Any]


class Tool(BaseModel):
    """Tool definition (OpenAI format)"""
    type: str = "function"
    function: ToolFunction


class AdapterRequest(BaseModel):
    """
    Standardized adapter request.

    Mirrors the OpenAI Chat Completions API exactly.
    """
    # ========== Required ==========
    messages: List[ChatMessage]
    model: str

    # ========== Optional ==========
    temperature: Optional[float] = Field(None, ge=0, le=2)
    max_tokens: Optional[int] = Field(None, gt=0)
    top_p: Optional[float] = Field(None, ge=0, le=1)
    n: Optional[int] = Field(1, ge=1)
    stream: bool = False

    # ========== Tools ==========
    tools: Optional[List[Tool]] = None
    tool_choice: Optional[Union[str, Dict]] = None

    # ========== Stop conditions ==========
    stop: Optional[Union[str, List[str]]] = None

    # ========== Penalties ==========
    presence_penalty: Optional[float] = Field(None, ge=-2, le=2)
    frequency_penalty: Optional[float] = Field(None, ge=-2, le=2)

    # ========== Vendor-specific extras ==========
    extra: Dict[str, Any] = Field(default_factory=dict)
```
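
The None-dropping serialization in `ChatMessage.model_dump` is what keeps optional fields (`name`, `tool_calls`, ...) out of the wire payload. A standard-library-only sketch of that behaviour (the real model uses pydantic; `Msg` and `dump` here are illustrative stand-ins):

```python
# Standalone sketch of the None-filtering serialization described above,
# using dataclasses instead of pydantic.
from dataclasses import dataclass, asdict
from typing import Optional


@dataclass
class Msg:
    role: str
    content: Optional[str] = None
    name: Optional[str] = None
    tool_call_id: Optional[str] = None


def dump(msg: Msg) -> dict:
    # Drop None fields so the payload matches what OpenAI-style APIs expect
    return {k: v for k, v in asdict(msg).items() if v is not None}


print(dump(Msg(role="user", content="Hello!")))
# -> {'role': 'user', 'content': 'Hello!'}
```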

### 3.2 Response Models (`app/models.py`)

```python
class ToolCall(BaseModel):
    """Tool call (OpenAI format)"""
    id: str
    type: str = "function"
    function: Dict[str, Any]  # {name, arguments}


class ResponseMessage(BaseModel):
    """Response message"""
    role: str = "assistant"
    content: Optional[str] = None
    tool_calls: Optional[List[ToolCall]] = None


class Choice(BaseModel):
    """Choice (non-streaming)"""
    index: int = 0
    message: ResponseMessage
    finish_reason: Optional[str] = None


class Usage(BaseModel):
    """Token usage statistics"""
    prompt_tokens: int
    completion_tokens: int
    total_tokens: int


class AdapterResponse(BaseModel):
    """
    Standardized adapter response.

    Mirrors the OpenAI Chat Completions response format exactly.
    """
    id: str
    object: str = "chat.completion"
    created: int
    model: str
    choices: List[Choice]
    usage: Optional[Usage] = None
```
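
For reference, this is the shape `AdapterResponse` validates and the access path a caller follows to get the generated text (plain dicts here; the field values are made up):

```python
# A minimal OpenAI-style chat.completion payload matching the models above.
response = {
    "id": "chatcmpl-abc123",
    "object": "chat.completion",
    "created": 1700000000,
    "model": "gpt-3.5-turbo",
    "choices": [
        {
            "index": 0,
            "message": {"role": "assistant", "content": "Hi there!"},
            "finish_reason": "stop",
        }
    ],
    "usage": {"prompt_tokens": 9, "completion_tokens": 3, "total_tokens": 12},
}

# The standard access path to the generated text
content = response["choices"][0]["message"]["content"]
print(content)  # -> Hi there!
```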

---

## 4. Adapter Implementation

### 4.1 Adapter Base Class (`app/model_adapters/base.py`)

```python
from abc import ABC, abstractmethod
from typing import AsyncGenerator, Optional, Dict, Any
import httpx
import logging
import time
import uuid

from ..models import AdapterRequest, AdapterResponse
from ..core.model_config import ProviderConfig

logger = logging.getLogger(__name__)


class BaseAdapter(ABC):
    """
    Adapter base class.

    Every adapter must derive from this class; it defines the unified interface.
    """

    def __init__(self, config: ProviderConfig):
        """
        Initialize the adapter.

        Args:
            config: provider configuration
        """
        self.config = config
        self.client: Optional[httpx.AsyncClient] = None
        self._is_initialized = False

    # ========== Core interface (must be implemented) ==========

    @abstractmethod
    async def chat(self, request: AdapterRequest) -> AdapterResponse:
        """
        Non-streaming chat.

        Args:
            request: standardized request (OpenAI format)

        Returns:
            AdapterResponse: standardized response (OpenAI format)
        """
        pass

    @abstractmethod
    async def stream_chat(self, request: AdapterRequest) -> AsyncGenerator[bytes, None]:
        """
        Streaming chat.

        Args:
            request: standardized request (OpenAI format)

        Yields:
            bytes: SSE data chunks, "data: {json}\n\n"
        """
        pass

    # ========== Hooks (override for non-standard APIs) ==========

    def get_headers(self) -> Dict[str, str]:
        """
        Build request headers (including authentication).

        Override to implement custom authentication.

        Returns:
            Dict[str, str]: header dict
        """
        return {
            "Authorization": f"Bearer {self.config.api_key}",
            "Content-Type": "application/json"
        }

    def format_request(self, request: AdapterRequest) -> Dict[str, Any]:
        """
        Convert an OpenAI-format request into the vendor's API format.

        Override to handle non-standard request formats.

        Args:
            request: OpenAI-format request

        Returns:
            Dict: request body expected by the vendor API
        """
        payload = {
            "model": request.model,
            "messages": [msg.model_dump() for msg in request.messages],
            "stream": request.stream
        }

        # Optional parameters
        if request.temperature is not None:
            payload["temperature"] = request.temperature
        if request.max_tokens is not None:
            payload["max_tokens"] = request.max_tokens
        if request.top_p is not None:
            payload["top_p"] = request.top_p
        if request.stop is not None:
            payload["stop"] = request.stop
        if request.tools:
            payload["tools"] = [tool.model_dump() for tool in request.tools]
        if request.tool_choice is not None:
            payload["tool_choice"] = request.tool_choice

        # Merge vendor-specific extras
        payload.update(request.extra)

        return payload

    def parse_response(self, response_data: Dict, request: AdapterRequest) -> AdapterResponse:
        """
        Convert a vendor API response into the OpenAI format.

        Override to handle non-standard response formats.

        Args:
            response_data: raw vendor API response
            request: the original request

        Returns:
            AdapterResponse: OpenAI-format response
        """
        # Default implementation (assumes the data is already OpenAI-format)
        return AdapterResponse(**response_data)

    # ========== Utilities ==========

    async def __aenter__(self):
        """Async context manager entry"""
        if not self._is_initialized:
            self.client = httpx.AsyncClient(
                timeout=httpx.Timeout(60.0, connect=10.0),
                limits=httpx.Limits(max_connections=100, max_keepalive_connections=20)
            )
            self._is_initialized = True
        return self

    async def __aexit__(self, exc_type, exc_val, exc_tb):
        """Async context manager exit"""
        if self.client:
            await self.client.aclose()
            self._is_initialized = False

    @staticmethod
    def generate_response_id() -> str:
        """Generate a response ID"""
        return f"chatcmpl-{uuid.uuid4().hex[:24]}"

    @staticmethod
    def get_current_timestamp() -> int:
        """Current Unix timestamp"""
        return int(time.time())
```
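
The default `format_request` hook always emits the required fields and adds optional ones only when they are set, then merges the `extra` passthrough dict last. A standalone sketch of that mapping (plain dicts stand in for the pydantic models; the function here is a simplified copy, not the class method):

```python
# Standalone sketch of the default format_request() behaviour: required
# fields always, optional fields only when set, extras merged last.
def format_request(model, messages, stream=False, temperature=None,
                   max_tokens=None, top_p=None, stop=None, extra=None):
    payload = {"model": model, "messages": messages, "stream": stream}
    for key, value in {"temperature": temperature, "max_tokens": max_tokens,
                       "top_p": top_p, "stop": stop}.items():
        if value is not None:
            payload[key] = value
    payload.update(extra or {})  # vendor-specific passthrough parameters
    return payload


payload = format_request("gpt-4", [{"role": "user", "content": "Hi"}],
                         temperature=0.3, extra={"seed": 42})
print(sorted(payload))
# -> ['messages', 'model', 'seed', 'stream', 'temperature']
```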

### 4.2 Generic Adapter (`app/model_adapters/generic_adapter.py`)

```python
from typing import AsyncGenerator
import httpx
import json
import logging

from .base import BaseAdapter
from ..models import AdapterRequest, AdapterResponse, Choice, ResponseMessage, Usage

logger = logging.getLogger(__name__)


class GenericAdapter(BaseAdapter):
    """
    Generic adapter for standard OpenAI-compatible APIs.
    """

    async def chat(self, request: AdapterRequest) -> AdapterResponse:
        """Non-streaming chat (implemented on top of streaming)"""
        full_content = ""
        tool_calls = []
        finish_reason = None

        request.stream = True

        async with self:
            async for chunk in self.stream_chat(request):
                # Assumes each chunk starts at an SSE "data: " line boundary
                if not chunk.startswith(b"data: "):
                    continue

                data_str = chunk[6:].decode().strip()
                if data_str == "[DONE]":
                    break

                try:
                    data = json.loads(data_str)
                    choices = data.get("choices", [])
                    if choices:
                        delta = choices[0].get("delta", {})

                        if "content" in delta:
                            full_content += delta["content"]

                        if "tool_calls" in delta:
                            self._merge_tool_calls(tool_calls, delta["tool_calls"])

                        finish_reason = choices[0].get("finish_reason")

                except json.JSONDecodeError:
                    pass

        return AdapterResponse(
            id=self.generate_response_id(),
            object="chat.completion",
            created=self.get_current_timestamp(),
            model=request.model,
            choices=[Choice(
                index=0,
                message=ResponseMessage(
                    role="assistant",
                    content=full_content or None,
                    tool_calls=tool_calls or None
                ),
                finish_reason=finish_reason
            )],
            usage=Usage(prompt_tokens=0, completion_tokens=0, total_tokens=0)
        )

    async def stream_chat(self, request: AdapterRequest) -> AsyncGenerator[bytes, None]:
        """Streaming chat"""
        if not self._is_initialized:
            raise RuntimeError("Adapter not initialized")

        headers = self.get_headers()
        payload = self.format_request(request)

        try:
            async with self.client.stream(
                "POST",
                self.config.api_url,
                headers=headers,
                json=payload
            ) as response:
                response.raise_for_status()
                async for chunk in response.aiter_bytes():
                    yield chunk

        except httpx.HTTPStatusError as e:
            logger.error(f"API error: {e.response.status_code}")
            yield self._format_error(f"API error: {e.response.status_code}", "api_error")

        except Exception as e:
            logger.error(f"Request failed: {e}")
            yield self._format_error(str(e), "request_error")

    def _merge_tool_calls(self, tool_calls: list, new_calls: list):
        """Merge incremental tool-call deltas"""
        for call in new_calls:
            index = call.get("index", 0)
            while len(tool_calls) <= index:
                tool_calls.append({
                    "id": "",
                    "type": "function",
                    "function": {"name": "", "arguments": ""}
                })

            tc = tool_calls[index]
            if "id" in call:
                tc["id"] = call["id"]
            if "function" in call:
                func = call["function"]
                if "name" in func:
                    tc["function"]["name"] += func["name"]
                if "arguments" in func:
                    tc["function"]["arguments"] += func["arguments"]

    def _format_error(self, message: str, error_type: str) -> bytes:
        """Format an error as an SSE chunk"""
        error_data = {"error": {"message": message, "type": error_type}}
        return f"data: {json.dumps(error_data)}\n\n".encode()
```
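
Streamed tool calls arrive as fragments: the first delta carries the call id and function name, later deltas append pieces of the JSON arguments string. A standalone copy of the `_merge_tool_calls` logic shows the concatenation by index:

```python
# Standalone version of the _merge_tool_calls logic above: streamed
# tool-call deltas are accumulated per index, arguments concatenated.
def merge_tool_calls(tool_calls, new_calls):
    for call in new_calls:
        index = call.get("index", 0)
        while len(tool_calls) <= index:
            tool_calls.append({"id": "", "type": "function",
                               "function": {"name": "", "arguments": ""}})
        tc = tool_calls[index]
        if "id" in call:
            tc["id"] = call["id"]
        func = call.get("function", {})
        tc["function"]["name"] += func.get("name", "")
        tc["function"]["arguments"] += func.get("arguments", "")
    return tool_calls


calls = []
merge_tool_calls(calls, [{"index": 0, "id": "call_1",
                          "function": {"name": "get_weather", "arguments": '{"ci'}}])
merge_tool_calls(calls, [{"index": 0, "function": {"arguments": 'ty": "SF"}'}}])
print(calls[0]["function"]["arguments"])  # -> {"city": "SF"}
```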

### 4.3 Custom Adapter Examples

#### Example 1: Custom authentication

```python
# app/model_adapters/custom_auth.py
from typing import Dict
from .generic_adapter import GenericAdapter


class CustomAuthAdapter(GenericAdapter):
    """
    Adapter for an API authenticated via X-API-Key.

    Configuration:
        adapter_type: custom
        adapter_class: app.model_adapters.custom_auth.CustomAuthAdapter
    """

    def get_headers(self) -> Dict[str, str]:
        """Override the authentication logic"""
        return {
            "X-API-Key": self.config.api_key,  # custom auth header
            "Content-Type": "application/json",
            "X-Custom-Header": "my-proxy"  # other fixed headers can be added here
        }

    # Everything else uses the inherited implementation:
    # chat() and stream_chat() come from GenericAdapter
```

#### Example 2: Different request format

```python
# app/model_adapters/different_request.py
from typing import Dict
from .generic_adapter import GenericAdapter
from ..models import AdapterRequest


class DifferentRequestAdapter(GenericAdapter):
    """
    Adapter for an API whose request field names differ:
    - messages    -> msgs
    - model       -> model_name
    - temperature -> temp
    """

    def format_request(self, request: AdapterRequest) -> Dict:
        """Override the request conversion"""
        # Convert to the non-standard format
        custom_payload = {
            "model_name": request.model,
            "msgs": [msg.model_dump() for msg in request.messages],
            "stream": request.stream
        }

        # Map the optional fields
        if request.temperature is not None:
            custom_payload["temp"] = request.temperature
        if request.max_tokens is not None:
            custom_payload["max_tokens"] = request.max_tokens

        # Map the tool-related fields
        if request.tools:
            custom_payload["functions"] = [tool.model_dump() for tool in request.tools]
        if request.tool_choice is not None:
            choice = request.tool_choice
            if choice == "auto":
                custom_payload["function_call"] = "auto"
            elif choice == "none":
                custom_payload["function_call"] = "none"

        return custom_payload
```

#### Example 3: Different response format

```python
# app/model_adapters/different_response.py
from typing import Dict
from .generic_adapter import GenericAdapter
from ..models import AdapterRequest, AdapterResponse, Choice, ResponseMessage


class DifferentResponseAdapter(GenericAdapter):
    """
    Adapter for an API with a different response shape.

    The API returns:  {"results": [{"text": "...", "stop_reason": "..."}]}
    We convert it to: {"choices": [{"message": {"content": "..."}}]}
    """

    def parse_response(self, response_data: Dict, request: AdapterRequest) -> AdapterResponse:
        """Override the response parsing"""
        # Extract data from the non-standard format
        results = response_data.get("results", [])
        if not results:
            raise ValueError("Empty response")

        first_result = results[0]

        # Build a standard OpenAI-format response
        return AdapterResponse(
            id=response_data.get("id", self.generate_response_id()),
            object="chat.completion",
            created=response_data.get("created", self.get_current_timestamp()),
            model=request.model,
            choices=[Choice(
                index=0,
                message=ResponseMessage(
                    role="assistant",
                    content=first_result.get("text")
                ),
                finish_reason=first_result.get("stop_reason")
            )],
            usage=None
        )
```

#### Example 4: Request signing

```python
# app/model_adapters/signed_api.py
import hmac
import hashlib
import time
from typing import Dict
from .generic_adapter import GenericAdapter


class SignedApiAdapter(GenericAdapter):
    """
    Adapter for an API requiring HMAC-SHA256 signatures.

    Configuration needs:
    - api_key_env
    - api_secret_env
    """

    def get_headers(self) -> Dict[str, str]:
        """Override the authentication logic to add a signature"""
        timestamp = str(int(time.time()))

        # Build the string to sign
        sign_str = f"api_key={self.config.api_key}&timestamp={timestamp}"

        # Compute the signature
        signature = hmac.new(
            self.config.api_secret.encode(),
            sign_str.encode(),
            hashlib.sha256
        ).hexdigest()

        return {
            "Content-Type": "application/json",
            "X-API-Key": self.config.api_key,
            "X-Timestamp": timestamp,
            "X-Signature": signature
        }
```
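
The signing scheme is deterministic: the same key, secret, and timestamp always yield the same HMAC-SHA256 hex digest, which is what lets the server recompute and compare the signature. A standalone check (the key/secret values are made up):

```python
# Standalone check of the HMAC-SHA256 signing scheme used above.
import hashlib
import hmac


def sign(api_key: str, api_secret: str, timestamp: str) -> str:
    sign_str = f"api_key={api_key}&timestamp={timestamp}"
    return hmac.new(api_secret.encode(), sign_str.encode(),
                    hashlib.sha256).hexdigest()


sig = sign("demo-key", "demo-secret", "1700000000")
print(len(sig), sig == sign("demo-key", "demo-secret", "1700000000"))
# -> 64 True
```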

---

## 5. Core Component Implementation

### 5.1 Configuration Manager (`app/core/model_config.py`)

```python
from pydantic import BaseModel
from typing import Dict, Optional, List
import yaml
import os
import logging

logger = logging.getLogger(__name__)


class ModelInfo(BaseModel):
    """Model information"""
    name: str
    display_name: str
    max_tokens: Optional[int] = None
    default_params: Dict = {}


class ProviderConfig(BaseModel):
    """Provider configuration"""
    name: str
    adapter_type: str  # generic or custom
    adapter_class: Optional[str] = None  # dotted path to the custom adapter class
    api_url: str
    api_key_env: str
    api_secret_env: Optional[str] = None
    api_key: Optional[str] = None
    api_secret: Optional[str] = None
    description: Optional[str] = None
    models: List[ModelInfo] = []
    raw_config: Dict = {}

    @property
    def model_names(self) -> List[str]:
        return [model.name for model in self.models]


class ModelConfigManager:
    """Model configuration manager"""

    def __init__(self, config_path: str = "config/models.yaml"):
        self.config_path = config_path
        self.providers: Dict[str, ProviderConfig] = {}
        self.model_to_provider: Dict[str, str] = {}
        self.default_model: Optional[str] = None
        self.load_config()

    def load_config(self):
        """Load the configuration file"""
        try:
            with open(self.config_path, 'r', encoding='utf-8') as f:
                config = yaml.safe_load(f)

            for provider_name, provider_config in config.get('providers', {}).items():
                api_key = os.getenv(provider_config.get('api_key_env', ''))
                api_secret = os.getenv(provider_config.get('api_secret_env', ''))

                models_data = provider_config.get('models', [])
                models = [ModelInfo(**model_data) for model_data in models_data]

                self.providers[provider_name] = ProviderConfig(
                    name=provider_name,
                    api_key=api_key,
                    api_secret=api_secret,
                    models=models,
                    raw_config=provider_config,
                    # 'models' is already passed above; everything else
                    # (adapter_type, api_url, api_key_env, ...) passes through
                    **{k: v for k, v in provider_config.items() if k != 'models'}
                )

                for model in models:
                    if model.name in self.model_to_provider:
                        logger.warning(f"Duplicate model '{model.name}'")
                    self.model_to_provider[model.name] = provider_name

            self.default_model = config.get('default_model')

            logger.info(f"Loaded {len(self.providers)} providers, "
                        f"{len(self.model_to_provider)} models")

        except Exception as e:
            logger.error(f"Failed to load config: {e}")
            raise

    def get_provider_for_model(self, model_name: str) -> Optional[ProviderConfig]:
        """Get the provider configuration for a model name"""
        provider_name = self.model_to_provider.get(model_name)
        return self.providers.get(provider_name) if provider_name else None

    def list_models(self) -> List[str]:
        """List all models"""
        return list(self.model_to_provider.keys())

    def get_default_model(self) -> Optional[str]:
        """Get the default model"""
        return self.default_model

    def model_exists(self, model_name: str) -> bool:
        """Check whether a model exists"""
        return model_name in self.model_to_provider
```
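
The heart of the manager is the model-to-provider index: every model name maps to exactly one provider, and a duplicate name is warned about and overwritten by the later provider. A standalone sketch of that index construction (`build_model_index` and the sample provider data are illustrative, not the planned API):

```python
# Standalone sketch of the model -> provider index the config manager
# builds; later providers win on duplicate names, with a warning.
import logging

logger = logging.getLogger(__name__)


def build_model_index(providers: dict) -> dict:
    model_to_provider = {}
    for provider_name, cfg in providers.items():
        for model in cfg.get("models", []):
            if model["name"] in model_to_provider:
                logger.warning("Duplicate model '%s'", model["name"])
            model_to_provider[model["name"]] = provider_name
    return model_to_provider


index = build_model_index({
    "openai": {"models": [{"name": "gpt-4"}, {"name": "gpt-3.5-turbo"}]},
    "local": {"models": [{"name": "llama2"}]},
})
print(index["gpt-4"], index["llama2"])  # -> openai local
```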

### 5.2 Model Router (`app/services.py`)

```python
from typing import Dict, Optional, Tuple
from importlib import import_module
import logging

from .core.model_config import ModelConfigManager, ProviderConfig
from .model_adapters.base import BaseAdapter
from .model_adapters.generic_adapter import GenericAdapter

logger = logging.getLogger(__name__)


class ModelRouter:
    """
    Model router.

    Responsibilities:
    1. Look up the provider for the requested model
    2. Cache adapter instances (one per provider)
    3. Dynamically load the adapter (generic or custom)
    """

    def __init__(self, config_manager: ModelConfigManager):
        self.config_manager = config_manager
        self._adapter_cache: Dict[str, BaseAdapter] = {}

    def get_adapter(self, model_name: str) -> Optional[Tuple[BaseAdapter, str]]:
        """
        Get the adapter for a model name.

        Args:
            model_name: model name

        Returns:
            Tuple[BaseAdapter, str]: (adapter instance, resolved model name)
        """
        actual_model = self._resolve_model_name(model_name)
        provider_config = self.config_manager.get_provider_for_model(actual_model)

        if not provider_config:
            logger.error(f"Model '{actual_model}' not found")
            return None

        # Check the cache
        if provider_config.name in self._adapter_cache:
            return self._adapter_cache[provider_config.name], actual_model

        # Create the adapter
        adapter = self._create_adapter(provider_config)
        if not adapter:
            return None

        # Cache it
        self._adapter_cache[provider_config.name] = adapter

        logger.info(f"Created adapter for provider '{provider_config.name}' "
                    f"(type: {provider_config.adapter_type}), "
                    f"supports models: {provider_config.model_names}")

        return adapter, actual_model

    def _create_adapter(self, config: ProviderConfig) -> Optional[BaseAdapter]:
        """Create an adapter instance based on adapter_type"""
        adapter_type = config.adapter_type

        if adapter_type == "generic":
            return GenericAdapter(config)

        elif adapter_type == "custom":
            if not config.adapter_class:
                logger.error(f"Provider '{config.name}': adapter_class not specified")
                return None

            try:
                module_path, class_name = config.adapter_class.rsplit('.', 1)
                module = import_module(module_path)
                adapter_class = getattr(module, class_name)
                return adapter_class(config)

            except Exception as e:
                logger.error(f"Failed to load adapter '{config.adapter_class}': {e}")
                return None

        else:
            logger.error(f"Unknown adapter_type: {adapter_type}")
            return None

    def _resolve_model_name(self, model_name: Optional[str]) -> str:
        """Resolve the model name, falling back to the default"""
        if not model_name:
            default_model = self.config_manager.get_default_model()
            if not default_model:
                raise ValueError("No model specified and no default model")
            return default_model
        return model_name


# Global instances
_model_config_manager = ModelConfigManager()
model_router = ModelRouter(_model_config_manager)
```
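
The dynamic loading of `adapter_class` is the standard dotted-path import: split on the last dot, import the module, fetch the attribute. A standalone sketch, demonstrated with a stdlib class in place of a real adapter:

```python
# Standalone sketch of the dynamic class loading used for adapter_class.
from importlib import import_module


def load_class(dotted_path: str):
    """Import 'pkg.module.ClassName' and return the class object."""
    module_path, class_name = dotted_path.rsplit(".", 1)
    return getattr(import_module(module_path), class_name)


cls = load_class("collections.OrderedDict")
print(cls.__name__)  # -> OrderedDict
```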

---

## 6. Implementation Steps

### Phase 1: Foundations (1-2 days)

- [ ] Create `app/core/model_config.py` with the configuration manager
- [ ] Update `app/models.py` with the standardized data models
- [ ] Create `app/model_adapters/base.py` with the adapter base class
- [ ] Update `config/models.yaml` to the lean configuration format
- [ ] Update `.env` with the API keys for each vendor

### Phase 2: Adapters (2-3 days)

- [ ] Implement `GenericAdapter` (the generic OpenAI-compatible adapter)
- [ ] Implement 2-3 custom adapter examples (custom_auth, signed_api, etc.)
- [ ] Add unit tests

### Phase 3: Routing and Service Integration (1-2 days)

- [ ] Implement `ModelRouter` in `app/services.py`
- [ ] Update the service functions to use ModelRouter
- [ ] Update the API endpoints in `app/main.py`
- [ ] Handle the model parameter

### Phase 4: Testing and Validation (1-2 days)

- [ ] Unit tests: config loading, adapters, router
- [ ] Integration tests: full request/response flow
- [ ] Multi-model tests: verify calls across different models
- [ ] Custom adapter tests: verify non-standard API handling

### Phase 5: Documentation and Deployment (1 day)

- [ ] Update README.md
- [ ] Write a configuration guide
- [ ] Write an adapter development guide
- [ ] Deploy to production

---

## 7. Usage Examples

### 7.1 Standard Configuration

```yaml
# The simplest configuration (standard OpenAI-compatible)
openai:
  adapter_type: generic
  api_url: https://api.openai.com/v1/chat/completions
  api_key_env: OPENAI_API_KEY
  models:
    - name: gpt-3.5-turbo
      display_name: "GPT-3.5 Turbo"
```

### 7.2 Custom Adapter Configuration

```yaml
# An API that needs special handling
custom_api:
  adapter_type: custom
  adapter_class: app.model_adapters.my_adapter.MyAdapter
  api_url: https://custom-api.com/v1/chat
  api_key_env: CUSTOM_API_KEY
  models:
    - name: custom-model
      display_name: "Custom Model"
```

### 7.3 Client Requests

Different backends are selected purely via the `model` field:

```json
{
  "model": "gpt-4",
  "messages": [{"role": "user", "content": "Hello!"}],
  "stream": true
}
```

```json
{
  "model": "custom-model",
  "messages": [{"role": "user", "content": "Hello!"}],
  "stream": false
}
```

---

## 8. Summary

### 8.1 Key Advantages

- ✅ **Provider-based**: one adapter serves many models, so resources are used efficiently
- ✅ **OpenAI-compatible**: the interface mirrors the OpenAI API, making it easy to understand
- ✅ **Lean configuration**: data only, no logic
- ✅ **Flexible code**: complex logic lives in code, where it is easy to debug
- ✅ **Easy to extend**: adding a model is a config change; supporting a new API just means writing an adapter

### 8.2 Configuration vs. Code Responsibilities

| Concern | Config file | Code |
|------|---------|------|
| API URL | ✅ | ❌ |
| Model list | ✅ | ❌ |
| Authentication | ❌ | ✅ (get_headers) |
| Request conversion | ❌ | ✅ (format_request) |
| Response conversion | ❌ | ✅ (parse_response) |
| Streaming | ❌ | ✅ (stream_chat) |

### 8.3 Project Layout

```
app/
├── model_adapters/
│   ├── __init__.py
│   ├── base.py                # adapter base class
│   ├── generic_adapter.py     # generic adapter
│   ├── custom_auth.py         # custom-auth example
│   ├── different_request.py   # different request format
│   ├── different_response.py  # different response format
│   └── signed_api.py          # signed API
├── core/
│   └── model_config.py        # configuration manager
├── models.py                  # data models
└── services.py                # service layer (ModelRouter)

config/
└── models.yaml                # configuration file
```

---

**Status**: ⏳ pending user confirmation

**Next step**: begin implementation once confirmed