diff --git a/docs/multi_backend_final_design.md b/docs/multi_backend_final_design.md
new file mode 100644
index 0000000..3c30486
--- /dev/null
+++ b/docs/multi_backend_final_design.md

# Multi-Backend LLM Support — Implementation Plan (Final)

> **Version: v2.0**
> **Updated: 2025-12-31**
> **Status: pending confirmation**

## 📋 Overview

### Core Design Principles
1. ✅ **Provider-based architecture**: a single adapter instance serves every model of one provider
2. ✅ **OpenAI-compatible interface**: adapter parameters and return values mirror the OpenAI API
3. ✅ **Minimal config + code**: config files hold data only; complex logic lives in code
4. ✅ **Uniform standardized interface**: every adapter implements the same contract

### Target Scenarios
- Standard OpenAI-compatible APIs
- APIs with different authentication schemes (API key, request signing, etc.)
- APIs with different request/response formats
- APIs that need special handling

---

## 1. Architecture

### 1.1 Architecture Diagram

```
┌─────────────────────────────────────────────────────────┐
│                     Client Request                      │
│          {"model": "gpt-4", "messages": [...]}          │
└─────────────────────┬───────────────────────────────────┘
                      ↓
┌─────────────────────────────────────────────────────────┐
│                 Chat Completions API                    │
│                 /v1/chat/completions                    │
└─────────────────────┬───────────────────────────────────┘
                      ↓
┌─────────────────────────────────────────────────────────┐
│                    Model Router                         │
│  - Looks up the provider for the requested model        │
│  - Caches adapter instances (one per provider)          │
│  - Dynamically loads adapters (generic or custom)       │
│                                                         │
│  e.g. "gpt-4"         → OpenAI provider → GenericAdapter│
│       "gpt-3.5-turbo" → OpenAI provider → GenericAdapter│
│       (same adapter instance)                           │
└─────────────────────┬───────────────────────────────────┘
                      ↓
┌─────────────────────────────────────────────────────────┐
│                    Model Adapter                        │
│  ┌─────────────────────────────────────────────────┐    │
│  │ BaseAdapter (base class)                        │    │
│  │  - chat() / stream_chat()                       │    │
│  │  - get_headers()    → overridable (auth)        │    │
│  │  - format_request() → overridable (request)     │    │
│  │  - parse_response() → overridable (response)    │    │
│  └─────────────────────────────────────────────────┘    │
│                         ↓                               │
│  ┌─────────────────────────────────────────────────┐    │
│  │ GenericAdapter                                  │    │
│  │  for standard OpenAI-compatible APIs            │    │
│  └─────────────────────────────────────────────────┘    │
│                         ↓                               │
│  ┌─────────────────────────────────────────────────┐    │
│  │ CustomAdapter                                   │    │
│  │  overrides methods for non-standard interfaces  │    │
│  │  - get_headers()    → custom authentication     │    │
│  │  - format_request() → request conversion        │    │
│  │  - parse_response() → response conversion       │    │
│  └─────────────────────────────────────────────────┘    │
└─────────────────────┬───────────────────────────────────┘
                      ↓
┌─────────────────────────────────────────────────────────┐
│              Backend LLM APIs (vendors)                 │
└─────────────────────────────────────────────────────────┘
                      ↓
┌─────────────────────────────────────────────────────────┐
│              Standardized Response                      │
│              (OpenAI-compatible format)                 │
└─────────────────────────────────────────────────────────┘
```

### 1.2 Core Components

| Component | File | Role |
|------|---------|------|
| **Config manager** | `app/core/model_config.py` | Loads config; provider/model lookup |
| **Data models** | `app/models.py` | OpenAI-format request/response structures |
| **Adapter base class** | `app/model_adapters/base.py` | Defines the uniform adapter interface |
| **Generic adapter** | `app/model_adapters/generic_adapter.py` | Handles standard OpenAI-compatible APIs |
| **Custom adapters** | `app/model_adapters/custom_*.py` | Handle assorted non-standard interfaces |
| **Model router** | `app/services.py` | Loads adapters dynamically; caches instances |
| **API layer** | `app/main.py` | FastAPI routes; HTTP handling |

### 1.3 Design Decisions

| Decision | Choice | Rationale |
|--------|------|------|
| Adapter granularity | Provider-based | One adapter serves many models and reuses its connection pool |
| Interface format | OpenAI-compatible | Industry standard; easy to understand and migrate to |
| Configuration | Minimal YAML | Data only, no logic |
| Complex logic | In code | Easier to debug, maintain, and extend |
| Caching | Per provider | Reuses HTTP connections; fewer instances |

---

## 2. Configuration File Design (Minimal)

### 2.1 Configuration Structure (`config/models.yaml`)

```yaml
# ========================================
# Version
# ========================================
version: "2.0"

# ========================================
# Default model
# ========================================
default_model: gpt-3.5-turbo

# ========================================
# Providers
# ========================================
providers:
  # -------------------- Standard OpenAI-compatible API --------------------
  openai:
    adapter_type: generic
    api_url: https://api.openai.com/v1/chat/completions
    api_key_env: OPENAI_API_KEY
description: "OpenAI API" + models: + - name: gpt-4 + display_name: "GPT-4" + max_tokens: 8192 + - name: gpt-4-turbo + display_name: "GPT-4 Turbo" + max_tokens: 128000 + - name: gpt-3.5-turbo + display_name: "GPT-3.5 Turbo" + max_tokens: 4096 + + # -------------------- Anthropic (需要自定义适配器)-------------------- + anthropic: + adapter_type: custom + adapter_class: app.model_adapters.anthropic.AnthropicAdapter + api_url: https://api.anthropic.com/v1/messages + api_key_env: ANTHROPIC_API_KEY + description: "Anthropic Claude API" + models: + - name: claude-3-opus + display_name: "Claude 3 Opus" + max_tokens: 200000 + - name: claude-3-sonnet + display_name: "Claude 3 Sonnet" + max_tokens: 200000 + + # -------------------- DashScope (阿里云) -------------------- + dashscope: + adapter_type: generic + api_url: https://dashscope.aliyuncs.com/compatible-mode/v1/chat/completions + api_key_env: DASHSCOPE_API_KEY + description: "阿里云 DashScope" + models: + - name: qwen-max + display_name: "Qwen Max" + max_tokens: 6000 + - name: qwen-plus + display_name: "Qwen Plus" + max_tokens: 30000 + + # -------------------- 自定义认证 API -------------------- + custom_auth_api: + adapter_type: custom + adapter_class: app.model_adapters.custom_auth.CustomAuthAdapter + api_url: https://custom-api.com/v1/chat + api_key_env: CUSTOM_API_KEY + description: "使用 X-API-Key 认证的 API" + models: + - name: custom-model-v1 + display_name: "Custom Model V1" + + # -------------------- 需要签名的 API -------------------- + signed_api: + adapter_type: custom + adapter_class: app.model_adapters.signed.SignedApiAdapter + api_url: https://signed-api.com/v1/chat + api_key_env: SIGNED_API_KEY + api_secret_env: SIGNED_API_SECRET + description: "需要 HMAC 签名的 API" + models: + - name: signed-model + display_name: "Signed Model" + + # -------------------- 本地 Ollama -------------------- + local: + adapter_type: generic + api_url: http://localhost:11434/v1/chat/completions + api_key_env: LOCAL_LLM_API_KEY + description: "本地 Ollama" + 
    models:
      - name: llama2
        display_name: "Llama 2"
      - name: mistral
        display_name: "Mistral"
```

### 2.2 Configuration Fields

| Field | Type | Required | Description |
|------|------|------|------|
| `adapter_type` | string | ✅ | `generic` or `custom` |
| `adapter_class` | string | ⚠️ | Import path of the custom adapter class (required when adapter_type=custom) |
| `api_url` | string | ✅ | API endpoint URL |
| `api_key_env` | string | ✅ | Name of the environment variable holding the API key |
| `api_secret_env` | string | ❌ | Name of the environment variable holding the API secret (for signed APIs) |
| `description` | string | ❌ | Free-form description |
| `models` | array | ✅ | List of models |

### 2.3 Model Entries

```yaml
models:
  - name: str           # model name (required, unique identifier)
    display_name: str   # display name (required)
    max_tokens: int     # maximum tokens (optional)
    default_params:     # default parameters (optional)
      temperature: 0.7
      top_p: 0.9
```

---

## 3. Data Models (OpenAI-Compatible)

### 3.1 Request Models (`app/models.py`)

```python
from typing import List, Dict, Any, Optional, Union
from pydantic import BaseModel, Field
from enum import Enum

class MessageRole(str, Enum):
    """Message role enum"""
    SYSTEM = "system"
    USER = "user"
    ASSISTANT = "assistant"
    TOOL = "tool"

class ChatMessage(BaseModel):
    """Chat message (same shape as the OpenAI API)"""
    role: MessageRole
    content: Optional[str] = None
    name: Optional[str] = None
    tool_calls: Optional[List[Dict[str, Any]]] = None
    tool_call_id: Optional[str] = None

    def model_dump(self, **kwargs) -> Dict[str, Any]:
        """Serialize to a dict, dropping None values"""
        data = super().model_dump(exclude_none=True, **kwargs)
        if "role" in data and isinstance(data["role"], MessageRole):
            data["role"] = data["role"].value
        return data

class ToolFunction(BaseModel):
    """Tool function definition"""
    name: str
    description: str
    parameters: Dict[str, Any]

class Tool(BaseModel):
    """Tool definition (OpenAI format)"""
    type: str = "function"
    function: ToolFunction

class AdapterRequest(BaseModel):
    """
    Standardized adapter request.
    Mirrors the OpenAI Chat Completions API.
    """
    # ========== Required ==========
    messages: List[ChatMessage]
    model: str

    # ========== Optional ==========
    temperature: Optional[float] = Field(None, ge=0, le=2)
    max_tokens: Optional[int] = Field(None, gt=0)
    top_p: Optional[float] = Field(None, ge=0, le=1)
    n: Optional[int] = Field(1, ge=1)
    stream: bool = False

    # ========== Tools ==========
    tools: Optional[List[Tool]] = None
    tool_choice: Optional[Union[str, Dict]] = None

    # ========== Stop conditions ==========
    stop: Optional[Union[str, List[str]]] = None

    # ========== Penalties ==========
    presence_penalty: Optional[float] = Field(None, ge=-2, le=2)
    frequency_penalty: Optional[float] = Field(None, ge=-2, le=2)

    # ========== Vendor-specific extras ==========
    extra: Dict[str, Any] = Field(default_factory=dict)
```

### 3.2 Response Models (`app/models.py`)

```python
class ToolCall(BaseModel):
    """Tool call (OpenAI format)"""
    id: str
    type: str = "function"
    function: Dict[str, Any]  # {name, arguments}

class ResponseMessage(BaseModel):
    """Response message"""
    role: str = "assistant"
    content: Optional[str] = None
    tool_calls: Optional[List[ToolCall]] = None

class Choice(BaseModel):
    """Choice (non-streaming)"""
    index: int = 0
    message: ResponseMessage
    finish_reason: Optional[str] = None

class Usage(BaseModel):
    """Token usage statistics"""
    prompt_tokens: int
    completion_tokens: int
    total_tokens: int

class AdapterResponse(BaseModel):
    """
    Standardized adapter response.
    Mirrors the OpenAI Chat Completions API response format.
    """
    id: str
    object: str = "chat.completion"
    created: int
    model: str
    choices: List[Choice]
    usage: Optional[Usage] = None
```

---

## 4. Adapter Implementation

### 4.1 Adapter Base Class (`app/model_adapters/base.py`)

```python
from abc import ABC, abstractmethod
from typing import AsyncGenerator, Optional, Dict, Any
import httpx
import logging
import time
import uuid

from ..models import AdapterRequest, AdapterResponse
from ..core.model_config import ProviderConfig

logger = logging.getLogger(__name__)

class BaseAdapter(ABC):
    """
    Adapter base class.

    Every adapter implements this class, which defines the uniform interface.
    """

    def __init__(self, config: ProviderConfig):
        """
        Initialize the adapter.

        Args:
            config: provider configuration
        """
        self.config = config
        self.client: Optional[httpx.AsyncClient] = None
        self._is_initialized = False

    # ========== Core interface (must implement) ==========

    @abstractmethod
    async def chat(self, request: AdapterRequest) -> AdapterResponse:
        """
        Non-streaming chat.

        Args:
            request: standardized request (OpenAI format)

        Returns:
            AdapterResponse: standardized response (OpenAI format)
        """
        pass

    @abstractmethod
    async def stream_chat(self, request: AdapterRequest) -> AsyncGenerator[bytes, None]:
        """
        Streaming chat.

        Args:
            request: standardized request (OpenAI format)

        Yields:
            bytes: SSE data chunks, "data: {json}\n\n"
        """
        pass

    # ========== Hooks (override for non-standard interfaces) ==========

    def get_headers(self) -> Dict[str, str]:
        """
        Build the request headers (including authentication).

        Override to implement custom authentication.

        Returns:
            Dict[str, str]: header dict
        """
        return {
            "Authorization": f"Bearer {self.config.api_key}",
            "Content-Type": "application/json"
        }

    def format_request(self, request: AdapterRequest) -> Dict[str, Any]:
        """
        Convert an OpenAI-format request into the vendor's API format.

        Override to handle non-standard request formats.

        Args:
            request: OpenAI-format request

        Returns:
            Dict: request body expected by the vendor API
        """
        payload = {
            "model": request.model,
            "messages": [msg.model_dump() for msg in request.messages],
            "stream": request.stream
        }

        # Optional parameters
        if request.temperature is not None:
            payload["temperature"] = request.temperature
        if request.max_tokens is not None:
            payload["max_tokens"] = request.max_tokens
        if request.top_p is not None:
            payload["top_p"] = request.top_p
        if request.stop is not None:
            payload["stop"] = request.stop
        if request.tools:
            payload["tools"] = [tool.model_dump() for tool in request.tools]
        if request.tool_choice is not None:
            payload["tool_choice"] = request.tool_choice

        # Merge vendor-specific extras
        payload.update(request.extra)

        return payload

    def parse_response(self, response_data: Dict, request: AdapterRequest) -> AdapterResponse:
        """
        Convert the vendor API response into OpenAI format.

        Override to handle non-standard response formats.

        Args:
            response_data: raw data returned by the vendor API
            request: the original request

        Returns:
            AdapterResponse: OpenAI-format response
        """
        # Default implementation (assumes the data is already OpenAI-shaped)
        return AdapterResponse(**response_data)

    # ========== Utilities ==========

    async def __aenter__(self):
        """Async context manager entry"""
        if not self._is_initialized:
            self.client = httpx.AsyncClient(
                timeout=httpx.Timeout(60.0, connect=10.0),
                limits=httpx.Limits(max_connections=100, max_keepalive_connections=20)
            )
            self._is_initialized = True
        return self

    async def __aexit__(self, exc_type, exc_val, exc_tb):
        """Async context manager exit"""
        if self.client:
            await self.client.aclose()
            self._is_initialized = False

    @staticmethod
    def generate_response_id() -> str:
        """Generate a response ID"""
        return f"chatcmpl-{uuid.uuid4().hex[:24]}"

    @staticmethod
    def get_current_timestamp() -> int:
        """Current Unix timestamp"""
        return int(time.time())
```

### 4.2 Generic Adapter (`app/model_adapters/generic_adapter.py`)

```python
from typing import AsyncGenerator
import httpx
import json
import logging

from .base import BaseAdapter
from ..models import AdapterRequest, AdapterResponse, Choice, ResponseMessage, Usage
from ..core.model_config import ProviderConfig

logger = logging.getLogger(__name__)

class GenericAdapter(BaseAdapter):
    """
    Generic adapter.

    For standard OpenAI-compatible APIs.
    """

    async def chat(self, request: AdapterRequest) -> AdapterResponse:
        """Non-streaming chat (implemented on top of streaming)"""
        full_content = ""
        tool_calls = []
        finish_reason = None

        request.stream = True

        async with self:
            async for chunk in self.stream_chat(request):
                if not chunk.startswith(b"data: "):
                    continue

                data_str = chunk[6:].decode().strip()
                if data_str == "[DONE]":
                    break

                try:
                    data = json.loads(data_str)
                    choices = data.get("choices", [])
                    if choices:
                        delta = choices[0].get("delta", {})

                        if "content" in delta:
                            full_content += delta["content"]

                        if "tool_calls" in delta:
                            self._merge_tool_calls(tool_calls, delta["tool_calls"])

                        finish_reason = choices[0].get("finish_reason")

                except json.JSONDecodeError:
                    pass

        return AdapterResponse(
            id=self.generate_response_id(),
            object="chat.completion",
            created=self.get_current_timestamp(),
            model=request.model,
            choices=[Choice(
                index=0,
                message=ResponseMessage(
                    role="assistant",
                    content=full_content or None,
                    tool_calls=tool_calls or None
                ),
                finish_reason=finish_reason
            )],
            # Token usage is not tracked when aggregating a stream
            usage=Usage(prompt_tokens=0, completion_tokens=0, total_tokens=0)
        )

    async def stream_chat(self, request: AdapterRequest) -> AsyncGenerator[bytes, None]:
        """Streaming chat"""
        if not self._is_initialized:
            raise RuntimeError("Adapter not initialized")

        headers = self.get_headers()
        payload = self.format_request(request)

        try:
            async with self.client.stream(
                "POST",
                self.config.api_url,
                headers=headers,
                json=payload
            ) as response:
                response.raise_for_status()
                async for chunk in response.aiter_bytes():
                    yield chunk

        except httpx.HTTPStatusError as e:
            logger.error(f"API error: {e.response.status_code}")
            yield self._format_error(f"API error: {e.response.status_code}", "api_error")

        except Exception as e:
            logger.error(f"Request failed: {e}")
            yield self._format_error(str(e), "request_error")

    def _merge_tool_calls(self, tool_calls: list, new_calls: list):
        """Merge incremental tool-call deltas"""
        for call in new_calls:
            index = call.get("index", 0)
            while len(tool_calls) <= index:
                tool_calls.append({
                    "id": "",
                    "type": "function",
                    "function": {"name": "", "arguments": ""}
                })

            tc = tool_calls[index]
            if "id" in call:
                tc["id"] = call["id"]
            if "function" in call:
                func = call["function"]
                if "name" in func:
                    tc["function"]["name"] += func["name"]
                if "arguments" in func:
                    tc["function"]["arguments"] += func["arguments"]

    def _format_error(self, message: str, error_type: str) -> bytes:
        """Format an error as SSE"""
        error_data = {"error": {"message": message, "type": error_type}}
        return f"data: {json.dumps(error_data)}\n\n".encode()
```

### 4.3 Custom Adapter Examples

The examples below subclass `GenericAdapter` (rather than the abstract `BaseAdapter`) so that `chat()` and `stream_chat()` are inherited and only the relevant hook needs overriding.

#### Example 1: Custom Authentication

```python
# app/model_adapters/custom_auth.py
from typing import Dict
from .generic_adapter import GenericAdapter

class CustomAuthAdapter(GenericAdapter):
    """
    Adapter for APIs authenticated via X-API-Key.

    Configuration:
      adapter_type: custom
      adapter_class: app.model_adapters.custom_auth.CustomAuthAdapter
    """

    def get_headers(self) -> Dict[str, str]:
        """Override the authentication logic"""
        return {
            "X-API-Key": self.config.api_key,   # custom auth header
            "Content-Type": "application/json",
            "X-Custom-Header": "my-proxy"       # other fixed headers as needed
        }

    # Everything else is inherited:
    # chat() and stream_chat() come from GenericAdapter
```

#### Example 2: Different Request Format

```python
# app/model_adapters/different_request.py
from typing import Dict
from .generic_adapter import GenericAdapter
from ..models import AdapterRequest

class DifferentRequestAdapter(GenericAdapter):
    """
    Adapter for an API whose request field names differ:

      messages    -> msgs
      model       -> model_name
      temperature -> temp
    """

    def format_request(self, request: AdapterRequest) -> Dict:
        """Override the request conversion"""
        # Convert to the non-standard shape
        custom_payload = {
            "model_name": request.model,
            "msgs": [msg.model_dump() for msg in request.messages],
            "stream": request.stream
        }

        # Map optional fields
        if request.temperature is not None:
            custom_payload["temp"] = request.temperature
        if request.max_tokens is not None:
            custom_payload["max_tokens"] = request.max_tokens

        # Map tool-related fields
        if request.tools:
            custom_payload["functions"] = [tool.model_dump() for tool in request.tools]
        if request.tool_choice is not None:
            choice = request.tool_choice
            if choice == "auto":
                custom_payload["function_call"] = "auto"
            elif choice == "none":
                custom_payload["function_call"] = "none"

        return custom_payload
```

#### Example 3: Different Response Format

```python
# app/model_adapters/different_response.py
from typing import Dict
from .generic_adapter import GenericAdapter
from ..models import AdapterRequest, AdapterResponse, Choice, ResponseMessage

class DifferentResponseAdapter(GenericAdapter):
    """
    Adapter for an API with a different response structure.

    The API returns:  {"results": [{"text": "...", "stop_reason": "..."}]}
    We convert it to: {"choices": [{"message": {"content": "..."}}]}
    """

    def parse_response(self, response_data: Dict, request: AdapterRequest) -> AdapterResponse:
        """Override the response parsing"""
        # Extract data from the non-standard shape
        results = response_data.get("results", [])
        if not results:
            raise ValueError("Empty response")

        first_result = results[0]

        # Build a standard OpenAI-format response
        return AdapterResponse(
            id=response_data.get("id", self.generate_response_id()),
            object="chat.completion",
            created=response_data.get("created", self.get_current_timestamp()),
            model=request.model,
            choices=[Choice(
                index=0,
                message=ResponseMessage(
                    role="assistant",
                    content=first_result.get("text")
                ),
                finish_reason=first_result.get("stop_reason")
            )],
            usage=None
        )
```

#### Example 4: Request Signing

```python
# app/model_adapters/signed_api.py
import hmac
import hashlib
import time
from typing import Dict
from .generic_adapter import GenericAdapter

class SignedApiAdapter(GenericAdapter):
    """
    Adapter for APIs that require an HMAC-SHA256 signature.

    Configuration requires:
      - api_key_env
      - api_secret_env
    """

    def get_headers(self) -> Dict[str, str]:
        """Override the authentication logic to add a signature"""
        timestamp = str(int(time.time()))

        # Build the string to sign
        sign_str = f"api_key={self.config.api_key}&timestamp={timestamp}"

        # Compute the signature
        signature = hmac.new(
            self.config.api_secret.encode(),
            sign_str.encode(),
            hashlib.sha256
        ).hexdigest()

        return {
            "Content-Type": "application/json",
            "X-API-Key": self.config.api_key,
            "X-Timestamp": timestamp,
            "X-Signature": signature
        }
```

---

## 5. Core Components

### 5.1 Configuration Manager (`app/core/model_config.py`)

```python
from pydantic import BaseModel
from typing import Dict, Optional, List
import yaml
import os
import logging

logger = logging.getLogger(__name__)

class ModelInfo(BaseModel):
    """Model information"""
    name: str
    display_name: str
    max_tokens: Optional[int] = None
    default_params: Dict = {}

class ProviderConfig(BaseModel):
    """Provider configuration"""
    name: str
    adapter_type: str                    # generic or custom
    adapter_class: Optional[str] = None  # import path of the custom adapter class
    api_url: str
    api_key_env: str
    api_secret_env: Optional[str] = None
    api_key: Optional[str] = None
    api_secret: Optional[str] = None
    description: Optional[str] = None
    models: List[ModelInfo] = []
    raw_config: Dict = {}

    @property
    def model_names(self) -> List[str]:
        return [model.name for model in self.models]

class ModelConfigManager:
    """Model configuration manager"""

    def __init__(self, config_path: str = "config/models.yaml"):
        self.config_path = config_path
        self.providers: Dict[str, ProviderConfig] = {}
        self.model_to_provider: Dict[str, str] = {}
        self.default_model: Optional[str] = None
        self.load_config()

    def load_config(self):
        """Load the configuration file"""
        try:
            with open(self.config_path, 'r', encoding='utf-8') as f:
                config = yaml.safe_load(f)

            for provider_name, provider_config in config.get('providers', {}).items():
                api_key = os.getenv(provider_config.get('api_key_env', ''))
                api_secret = os.getenv(provider_config.get('api_secret_env', ''))

                models_data = provider_config.get('models', [])
                models = [ModelInfo(**model_data) for model_data in models_data]

                # Pass everything except 'models' through to ProviderConfig;
                # api_key_env / api_secret_env are declared fields and must be kept
                self.providers[provider_name] = ProviderConfig(
                    name=provider_name,
                    api_key=api_key,
                    api_secret=api_secret,
                    models=models,
                    raw_config=provider_config,
                    **{k: v for k, v in provider_config.items() if k != 'models'}
                )

                for model in models:
                    if model.name in self.model_to_provider:
                        logger.warning(f"Duplicate model '{model.name}'")
                    self.model_to_provider[model.name] = provider_name

            self.default_model = config.get('default_model')

            logger.info(f"Loaded {len(self.providers)} providers, {len(self.model_to_provider)} models")

        except Exception as e:
            logger.error(f"Failed to load config: {e}")
            raise

    def get_provider_for_model(self, model_name: str) -> Optional[ProviderConfig]:
        """Look up the provider configuration for a model name"""
        provider_name = self.model_to_provider.get(model_name)
        return self.providers.get(provider_name) if provider_name else None

    def list_models(self) -> List[str]:
"""列出所有模型""" + return list(self.model_to_provider.keys()) + + def get_default_model(self) -> Optional[str]: + """获取默认模型""" + return self.default_model + + def model_exists(self, model_name: str) -> bool: + """检查模型是否存在""" + return model_name in self.model_to_provider +``` + +### 5.2 模型路由器 (`app/services.py`) + +```python +from typing import Dict, Optional, Tuple +from importlib import import_module +import logging + +from .core.model_config import ModelConfigManager, ProviderConfig +from .model_adapters.base import BaseAdapter +from .model_adapters.generic_adapter import GenericAdapter + +logger = logging.getLogger(__name__) + +class ModelRouter: + """ + 模型路由器 + + 功能: + 1. 根据 model 字段查找对应的 provider + 2. 缓存 adapter 实例(按 provider 缓存) + 3. 动态加载适配器(generic 或 custom) + """ + + def __init__(self, config_manager: ModelConfigManager): + self.config_manager = config_manager + self._adapter_cache: Dict[str, BaseAdapter] = {} + + def get_adapter(self, model_name: str) -> Optional[Tuple[BaseAdapter, str]]: + """ + 根据模型名称获取适配器 + + Args: + model_name: 模型名称 + + Returns: + Tuple[BaseAdapter, str]: (适配器实例, 实际模型名称) + """ + actual_model = self._resolve_model_name(model_name) + provider_config = self.config_manager.get_provider_for_model(actual_model) + + if not provider_config: + logger.error(f"Model '{actual_model}' not found") + return None + + # 检查缓存 + if provider_config.name in self._adapter_cache: + return self._adapter_cache[provider_config.name], actual_model + + # 创建适配器 + adapter = self._create_adapter(provider_config) + if not adapter: + return None + + # 缓存 + self._adapter_cache[provider_config.name] = adapter + + logger.info(f"Created adapter for provider '{provider_config.name}' " + f"(type: {provider_config.adapter_type}), " + f"supports models: {provider_config.model_names}") + + return adapter, actual_model + + def _create_adapter(self, config: ProviderConfig) -> Optional[BaseAdapter]: + """根据 adapter_type 创建适配器实例""" + adapter_type = config.adapter_type + + if 
adapter_type == "generic": + return GenericAdapter(config) + + elif adapter_type == "custom": + if not config.adapter_class: + logger.error(f"Provider '{config.name}': adapter_class not specified") + return None + + try: + module_path, class_name = config.adapter_class.rsplit('.', 1) + module = import_module(module_path) + adapter_class = getattr(module, class_name) + return adapter_class(config) + + except Exception as e: + logger.error(f"Failed to load adapter '{config.adapter_class}': {e}") + return None + + else: + logger.error(f"Unknown adapter_type: {adapter_type}") + return None + + def _resolve_model_name(self, model_name: Optional[str]) -> str: + """解析模型名称""" + if not model_name: + default_model = self.config_manager.get_default_model() + if not default_model: + raise ValueError("No model specified and no default model") + return default_model + return model_name + +# 全局实例 +_model_config_manager = ModelConfigManager() +model_router = ModelRouter(_model_config_manager) +``` + +--- + +## 六、实现步骤 + +### 阶段一:基础设施(1-2天) + +- [ ] 创建 `app/core/model_config.py`,实现配置管理器 +- [ ] 更新 `app/models.py`,添加标准化数据模型 +- [ ] 创建 `app/model_adapters/base.py`,定义适配器基类 +- [ ] 更新 `config/models.yaml`,采用简洁配置 +- [ ] 更新 `.env`,添加各厂商 API Key + +### 阶段二:适配器实现(2-3天) + +- [ ] 实现 `GenericAdapter`(通用 OpenAI 兼容适配器) +- [ ] 实现 2-3 个自定义适配器示例(custom_auth, signed_api 等) +- [ ] 添加单元测试 + +### 阶段三:路由和服务集成(1-2天) + +- [ ] 在 `app/services.py` 中实现 `ModelRouter` +- [ ] 更新服务函数,使用 ModelRouter +- [ ] 更新 `app/main.py` 中的 API 接口 +- [ ] 处理 model 参数 + +### 阶段四:测试和验证(1-2天) + +- [ ] 单元测试:配置加载、适配器、路由器 +- [ ] 集成测试:完整的请求/响应流程 +- [ ] 多模型测试:验证不同模型的调用 +- [ ] 自定义适配器测试:验证非标准接口处理 + +### 阶段五:文档和部署(1天) + +- [ ] 更新 README.md +- [ ] 编写配置指南 +- [ ] 编写适配器开发指南 +- [ ] 部署到生产环境 + +--- + +## 七、使用示例 + +### 7.1 标准配置示例 + +```yaml +# 最简单的配置(标准 OpenAI 兼容) +openai: + adapter_type: generic + api_url: https://api.openai.com/v1/chat/completions + api_key_env: OPENAI_API_KEY + models: + - name: gpt-3.5-turbo + display_name: "GPT-3.5 Turbo" +``` + 
### 7.2 Custom Adapter Configuration

```yaml
# An API that needs special handling
custom_api:
  adapter_type: custom
  adapter_class: app.model_adapters.my_adapter.MyAdapter
  api_url: https://custom-api.com/v1/chat
  api_key_env: CUSTOM_API_KEY
  models:
    - name: custom-model
      display_name: "Custom Model"
```

### 7.3 Client Requests

Clients simply switch the `model` field:

```json
{
  "model": "gpt-4",
  "messages": [{"role": "user", "content": "Hello!"}],
  "stream": true
}
```

```json
{
  "model": "custom-model",
  "messages": [{"role": "user", "content": "Hello!"}],
  "stream": false
}
```

---

## 8. Summary

### 8.1 Key Advantages

- ✅ **Provider-based**: one adapter serves many models; efficient resource use
- ✅ **OpenAI-compatible**: the interface mirrors the OpenAI API and is easy to understand
- ✅ **Minimal configuration**: data only, no logic
- ✅ **Flexible code**: complex logic lives in code, where it is easy to debug
- ✅ **Easy to extend**: new models need only configuration; new APIs need only an adapter

### 8.2 Config vs. Code Responsibilities

| Concern | Config file | Code |
|------|---------|------|
| API URL | ✅ | ❌ |
| Model list | ✅ | ❌ |
| Authentication | ❌ | ✅ (get_headers) |
| Request conversion | ❌ | ✅ (format_request) |
| Response conversion | ❌ | ✅ (parse_response) |
| Stream handling | ❌ | ✅ (stream_chat) |

### 8.3 Project Layout

```
app/
├── model_adapters/
│   ├── __init__.py
│   ├── base.py               # adapter base class
│   ├── generic_adapter.py    # generic adapter
│   ├── custom_auth.py        # custom-auth example
│   ├── different_request.py  # different request format
│   ├── different_response.py # different response format
│   └── signed_api.py         # signed API
├── core/
│   └── model_config.py       # configuration manager
├── models.py                 # data models
└── services.py               # service layer (ModelRouter)

config/
└── models.yaml               # configuration file
```

---

**Status**: ⏳ pending user confirmation

**Next step**: start implementation once confirmed
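Appendix: the stream-to-completion aggregation that `GenericAdapter.chat()` performs in section 4.2 can be exercised standalone. The sketch below folds synthetic SSE chunks (the payloads are made up, not real API output) into a single content string plus finish reason:

```python
import json
from typing import List, Optional, Tuple

def fold_sse(chunks: List[bytes]) -> Tuple[str, Optional[str]]:
    """Accumulate delta.content from SSE chunks until the [DONE] sentinel."""
    content: str = ""
    finish_reason: Optional[str] = None
    for chunk in chunks:
        if not chunk.startswith(b"data: "):
            continue
        data_str = chunk[6:].decode().strip()
        if data_str == "[DONE]":
            break
        try:
            data = json.loads(data_str)
        except json.JSONDecodeError:
            continue  # skip malformed or keep-alive lines
        choices = data.get("choices", [])
        if choices:
            delta = choices[0].get("delta", {})
            content += delta.get("content", "") or ""
            finish_reason = choices[0].get("finish_reason") or finish_reason
    return content, finish_reason

# Synthetic stream: two content deltas, then the end-of-stream sentinel.
chunks = [
    b'data: {"choices":[{"delta":{"content":"Hel"}}]}\n\n',
    b'data: {"choices":[{"delta":{"content":"lo"},"finish_reason":"stop"}]}\n\n',
    b"data: [DONE]\n\n",
]
print(fold_sse(chunks))  # ('Hello', 'stop')
```

Note that this sketch, like the adapter above, assumes each yielded chunk contains whole `data:` lines; a production implementation would buffer partial lines across network reads.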