diff --git a/.gitignore b/.gitignore index 768be7d..051a000 100644 --- a/.gitignore +++ b/.gitignore @@ -133,3 +133,6 @@ dmypy.json # Cython debug symbols cython_debug/ + +# logs +logs/ diff --git a/Dockerfile b/Dockerfile index 586b117..df66bf3 100644 --- a/Dockerfile +++ b/Dockerfile @@ -1,4 +1,4 @@ -FROM python:3.11-slim +FROM hub.rat.dev/library/python:3.10-alpine WORKDIR /app @@ -8,4 +8,4 @@ RUN pip install --no-cache-dir -i https://pypi.tuna.tsinghua.edu.cn/simple -r re COPY . . -CMD ["uvicorn", "app.main:app", "--host", "0.0.0.0", "--port", "8000"] +CMD ["uvicorn", "app.ghcproxy:app", "--host", "0.0.0.0", "--port", "8000"] diff --git a/app/ghcproxy.py b/app/ghcproxy.py new file mode 100644 index 0000000..8a5b4a7 --- /dev/null +++ b/app/ghcproxy.py @@ -0,0 +1,285 @@ +import os +import json +import random +import time +from typing import Optional, Dict, Any +from datetime import datetime + +import httpx +from fastapi import FastAPI, Request, HTTPException +from fastapi.responses import StreamingResponse, JSONResponse +import uvicorn + +app = FastAPI() + +TOKEN_EXPIRY_THRESHOLD = 60 +GITHUB_TOKEN = "ghu_kpJkheogXW18PMY0Eu6D0sL4r5bDsD3aS3EA" # 注意:硬编码令牌存在安全风险 +GITHUB_API_URL = "https://api.github.com/copilot_internal/v2/token" + +cached_token: Optional[Dict[str, Any]] = None + + +def generate_uuid() -> str: + template = 'xxxxxxxx-xxxx-4xxx-yxxx-xxxxxxxxxxxx' + return ''.join( + random.choice('0123456789abcdef') if c == 'x' else random.choice('89ab') if c == 'y' else c + for c in template + ) + + +def is_token_valid(token_data: Optional[Dict[str, Any]]) -> bool: + if not token_data or 'token' not in token_data or 'expires_at' not in token_data: + return False + + current_time = int(time.time()) + if current_time + TOKEN_EXPIRY_THRESHOLD >= token_data['expires_at']: + return False + + return True + + +async def get_copilot_token() -> Dict[str, Any]: + global cached_token + + if cached_token and is_token_valid(cached_token): + return cached_token + + headers = { + "Authorization": 
f"Bearer {GITHUB_TOKEN}", + "Editor-Version": "JetBrains-IU/252.26830.84", + "Editor-Plugin-Version": "copilot-intellij/1.5.58-243", + "Copilot-Language-Server-Version": "1.382.0", + "X-Github-Api-Version": "2024-12-15", + "User-Agent": "GithubCopilot/1.382.0", + "Accept": "*/*", + } + + async with httpx.AsyncClient() as client: + try: + response = await client.get(GITHUB_API_URL, headers=headers, timeout=10.0) + + if response.status_code != 200: + if cached_token: + return cached_token + raise HTTPException(status_code=response.status_code, detail=f"Failed to get token: {response.text}") + + data = response.json() + cached_token = data + + expiry_ttl = data['expires_at'] - int(time.time()) - TOKEN_EXPIRY_THRESHOLD + if expiry_ttl > 0: + print(f"Token cached, will expire in {expiry_ttl} seconds") + else: + print("Warning: New token has short validity period") + + return data + + except httpx.RequestError as e: + if cached_token: + return cached_token + raise HTTPException(status_code=500, detail=f"Request failed: {str(e)}") + + +def get_headers_for_path(path: str) -> Dict[str, str]: + headers = { + "Editor-Version": "JetBrains-IU/252.26830.84", + "Editor-Plugin-Version": "copilot-intellij/1.5.58-243", + "Copilot-Language-Server-Version": "1.382.0", + "X-Github-Api-Version": "2025-05-01", + "Copilot-Integration-Id": "jetbrains-chat", + "User-Agent": "GithubCopilot/1.382.0", + } + + if path == "/agents" or path == "/models": + return headers + + elif path == "/chat/completions": + interaction_id = generate_uuid() + request_id = generate_uuid() + + headers.update({ + "X-Initiator": "user", + "X-Interaction-Id": interaction_id, + "X-Interaction-Type": "conversation-panel", + "Openai-Organization": "github-copilot", + "X-Request-Id": request_id, + "Vscode-Sessionid": "427689f2-5dad-4b50-95d9-7cca977450061761839746260", + "Vscode-Machineid": "c9421c6ac240db1c5bc5117218aa21a73f3762bda7db1702d003ec2df103b812", + "Openai-Intent": "conversation-panel", + 
"Copilot-Vision-Request": "true", + }) + + print(f"/chat/completions path matched, Interaction-Id: {interaction_id}, Request-Id: {request_id}") + + return headers + + +def has_non_empty_content(msg): + """检查消息的 content 是否非空""" + content = msg.get('content') + if content is None: + return False + if isinstance(content, str): + return bool(content.strip()) # 字符串需要去除空格后判断 + if isinstance(content, (list, dict)): + return bool(content) # 列表或字典,非空则为 True + # 其他类型 (数字, 布尔值等) 通常视为非空 + return True + + +def filter_messages_logic(messages): + """ + 优化后的过滤逻辑: + 找到一个 role 为 assistant 且有 tool_calls 的消息 A, + 以及它后面紧接着的 role 为 tool 的消息 B。 + 删除 A 和 B 之间所有 content 非空的消息。 + """ + if not messages or len(messages) < 3: # 至少需要 assistant, something, tool 才能操作 + return + + i = 0 + while i < len(messages): + current_msg = messages[i] + + # 检查当前消息是否为 assistant 且有 tool_calls (且 tool_calls 非空) + is_assistant_with_tool_calls = ( + current_msg.get("role") == "assistant" and + isinstance(current_msg.get("tool_calls"), list) and + len(current_msg["tool_calls"]) > 0 + ) + + if is_assistant_with_tool_calls: + # 从下一个消息开始查找第一个 role='tool' 的消息 + j = i + 1 + found_tool = False + indices_to_remove_between = [] + + while j < len(messages): + msg_to_check = messages[j] + + if msg_to_check.get("role") == "tool": + found_tool = True + break # 找到第一个 tool 就停止,准备删除中间的 + # 检查 j 位置的消息 (在 assistant 和 tool 之间) 是否有非空 content + if has_non_empty_content(msg_to_check): + indices_to_remove_between.append(j) + j += 1 + + if found_tool and indices_to_remove_between: + # 从后往前删除,避免因列表长度变化导致索引失效 + for idx in sorted(indices_to_remove_between, reverse=True): + removed_msg = messages.pop(idx) + print(f"Removed intermediate message with non-empty content at index {idx}: {removed_msg}") + # 删除后,列表变短,下一次循环的 i 应该在当前位置, + # 因为原来的 i+1 位置的元素现在移动到了 i。 + # 所以这里我们不增加 i,让外层循环来处理。 + continue + else: + # 如果找到了 assistant 但没有找到配对的 tool, + # 或者找到了 tool 但中间没有需要删除的内容, + # 都正常检查下一条消息。 + i += 1 + else: + # 当前消息不符合条件,继续检查下一条 + i += 1 + + 
+@app.api_route("/v1/{path:path}", methods=["GET", "POST", "PUT", "DELETE", "PATCH"]) +async def proxy_request(request: Request, path: str): + # 创建时间戳目录用于存放日志 + timestamp = datetime.now().strftime("%Y%m%d_%H%M%S_%f") + log_dir = os.path.join("logs", timestamp) + os.makedirs(log_dir, exist_ok=True) + + # 记录原始请求数据 + original_body = await request.body() + with open(os.path.join(log_dir, "original_request.txt"), "wb") as f: + f.write(original_body or b"") + + token_data = await get_copilot_token() + print(token_data) + headers = get_headers_for_path(f"/{path}") + headers["Authorization"] = f"Bearer {token_data['token']}" + headers["hello"] = "world" + + body = original_body + + # 过滤 messages:优化后的逻辑 + if body: + try: + body_data = json.loads(body.decode('utf-8') if isinstance(body, bytes) else body) + if "messages" in body_data and isinstance(body_data["messages"], list): + messages = body_data["messages"] + initial_len = len(messages) + print(f"Processing messages, initial count: {initial_len}") + filter_messages_logic(messages) + final_len = len(messages) + if initial_len != final_len: + body = json.dumps(body_data).encode('utf-8') + print(f"Messages filtered from {initial_len} to {final_len}.") + # 记录修改后的请求体 + with open(os.path.join(log_dir, "modified_request.txt"), "wb") as f: + f.write(body or b"") + else: + # 如果没有修改,也记录原始内容作为modified_request + with open(os.path.join(log_dir, "modified_request.txt"), "wb") as f: + f.write(body or b"") + + except json.JSONDecodeError: + # body 不是 JSON,保持原样 + print("Request body is not valid JSON, skipping message filtering.") + with open(os.path.join(log_dir, "modified_request.txt"), "wb") as f: + f.write(body or b"") + + # target_url = f"https://qwapi.oopsapi.com/v1/{path}" + target_url = "https://api.business.githubcopilot.com/" + path + + print(target_url, " ", str(body)) + + # request_headers = {k: v for k, v in request.headers.items() + # if k.lower() not in ['host', 'content-length']} + # request_headers.update(headers) + 
async with httpx.AsyncClient() as client: + try: + response = await client.request( + method=request.method, + url=target_url, + headers=headers, + content=body if body else None, + timeout=120.0, + ) + + content = response.content + + # 记录响应结果 + with open(os.path.join(log_dir, "response.txt"), "wb") as f: + f.write(content or b"") + + print("content: ", content) + if response.headers.get("content-type", "").startswith("text/event-stream"): + return StreamingResponse( + response.aiter_bytes(), + status_code=response.status_code, + headers=dict(response.headers), + ) + + return JSONResponse( + content=json.loads(content) if content else {}, + status_code=response.status_code, + headers={k: v for k, v in response.headers.items() + if k.lower() not in ['content-length', 'transfer-encoding']} + ) + + except httpx.RequestError as e: + import traceback + traceback.print_exc() + raise HTTPException(status_code=500, detail=f"Proxy request failed: {str(e)}") + + +@app.get("/") +async def root(): + return {"message": "GitHub Copilot Proxy API"} + + +if __name__ == "__main__": + uvicorn.run(app, host="0.0.0.0", port=8000) diff --git a/app/main.py b/app/main.py index e9ef5d4..41c5105 100644 --- a/app/main.py +++ b/app/main.py @@ -105,10 +105,12 @@ async def chat_completions( # First, collect all chunks to detect if there are tool calls async for chunk in stream_llm_api(messages_to_llm, settings, log_id): + logger.info(f"sse_result: {chunk}") raw_chunks.append(chunk) # Extract content from SSE chunks parsed = _parse_sse_data(chunk) - if parsed and parsed.get("type") != "done": + logger.info(f"sse_result_data: {parsed}") + if parsed and ( parsed.get("type") != "done" or (parsed.get("choices") or [{}])[0].get("finish_reason") == "stop" ): choices = parsed.get("choices") if choices and len(choices) > 0: delta = choices[0].get("delta") diff --git a/app/services.py b/app/services.py index 0735ce1..2a5bd13 100644 --- a/app/services.py +++ b/app/services.py @@ -166,7 +166,7 @@ async def 
_raw_stream_from_llm(messages: List[ChatMessage], settings: Settings, Yields raw byte chunks as received. """ headers = { "Authorization": f"Bearer {settings.REAL_LLM_API_KEY}", "Content-Type": "application/json" } - payload = { "model": "default-model", "messages": [msg.model_dump() for msg in messages], "stream": True } + payload = { "model": "gpt-4.1", "messages": [msg.model_dump() for msg in messages], "stream": True } # Log the request payload to the database update_request_log(log_id, llm_request=payload) diff --git a/docker-compose.yml b/docker-compose.yml index 4b5b586..7eda3b6 100644 --- a/docker-compose.yml +++ b/docker-compose.yml @@ -1,6 +1,14 @@ version: '3.8' services: + sqlite-web: + image: docker.1ms.run/coleifer/sqlite-web + volumes: + - .:/data + environment: + SQLITE_DATABASE: llm_proxy.db + ports: + - 8580:8080 llmproxy: build: . ports: