286 lines
10 KiB
Python
286 lines
10 KiB
Python
import os
|
||
import json
|
||
import random
|
||
import time
|
||
from typing import Optional, Dict, Any
|
||
from datetime import datetime
|
||
|
||
import httpx
|
||
from fastapi import FastAPI, Request, HTTPException
|
||
from fastapi.responses import StreamingResponse, JSONResponse
|
||
import uvicorn
|
||
|
||
app = FastAPI()

# Safety margin in seconds: a cached token whose expiry is this close (or
# closer) is treated as already expired so it gets refreshed early.
TOKEN_EXPIRY_THRESHOLD = 60

# GitHub OAuth token that is exchanged for short-lived Copilot tokens.
# SECURITY: supply it through the COPILOT_GITHUB_TOKEN environment variable.
# The literal fallback only keeps existing deployments working; a token that
# has been committed to source should be rotated and this default removed.
GITHUB_TOKEN = os.environ.get(
    "COPILOT_GITHUB_TOKEN",
    "ghu_kpJkheogXW18PMY0Eu6D0sL4r5bDsD3aS3EA",
)
GITHUB_API_URL = "https://api.github.com/copilot_internal/v2/token"

# Last token payload fetched from GITHUB_API_URL; None until the first fetch.
cached_token: Optional[Dict[str, Any]] = None
|
||
|
||
|
||
def generate_uuid() -> str:
    """Return a random version-4 UUID in canonical text form.

    The result is lowercase hexadecimal in the hyphenated 8-4-4-4-12 layout
    (``xxxxxxxx-xxxx-4xxx-yxxx-xxxxxxxxxxxx`` with ``y`` one of ``8 9 a b``).
    """
    # Function-scope import so this block is self-contained.
    import uuid

    # The standard library guarantees the hyphens, the version nibble and the
    # variant nibble are all in place, which a hand-rolled per-character
    # template substitution can easily get wrong.
    return str(uuid.uuid4())
|
||
|
||
|
||
def is_token_valid(token_data: Optional[Dict[str, Any]]) -> bool:
    """Tell whether a token payload is present and not about to expire.

    A payload is usable only when it is non-empty, carries both ``token`` and
    ``expires_at``, and ``expires_at`` lies more than
    ``TOKEN_EXPIRY_THRESHOLD`` seconds after the current Unix time.
    """
    if not token_data:
        return False

    for required_key in ('token', 'expires_at'):
        if required_key not in token_data:
            return False

    # Shift "now" forward by the safety margin before comparing.
    refresh_deadline = int(time.time()) + TOKEN_EXPIRY_THRESHOLD
    return not (refresh_deadline >= token_data['expires_at'])
|
||
|
||
|
||
async def get_copilot_token() -> Dict[str, Any]:
    """Return a Copilot token payload, refreshing the module-level cache when needed.

    The cached payload is returned while ``is_token_valid`` accepts it.
    Otherwise a new payload is requested from ``GITHUB_API_URL`` using
    ``GITHUB_TOKEN`` and stored in ``cached_token``.

    If the refresh fails (network error, non-200 status, or a body that is
    not a JSON object containing ``token`` and ``expires_at``), the previously
    cached payload is returned when one exists, even though it may already be
    past its expiry margin.

    Returns:
        The token payload dict; it contains at least ``token`` and
        ``expires_at`` when freshly fetched.

    Raises:
        HTTPException: when the refresh fails and nothing is cached. Status is
            500 for a transport error, the upstream status for a non-200
            reply, and 502 for a malformed payload.
    """
    global cached_token

    # is_token_valid already rejects None / empty payloads.
    if is_token_valid(cached_token):
        return cached_token

    headers = {
        "Authorization": f"Bearer {GITHUB_TOKEN}",
        "Editor-Version": "JetBrains-IU/252.26830.84",
        "Editor-Plugin-Version": "copilot-intellij/1.5.58-243",
        "Copilot-Language-Server-Version": "1.382.0",
        "X-Github-Api-Version": "2024-12-15",
        "User-Agent": "GithubCopilot/1.382.0",
        "Accept": "*/*",
    }

    async with httpx.AsyncClient() as client:
        try:
            response = await client.get(GITHUB_API_URL, headers=headers, timeout=10.0)
        except httpx.RequestError as e:
            if cached_token:
                return cached_token
            raise HTTPException(status_code=500, detail=f"Request failed: {str(e)}") from e

    if response.status_code != 200:
        if cached_token:
            return cached_token
        raise HTTPException(status_code=response.status_code, detail=f"Failed to get token: {response.text}")

    # Validate before caching so a malformed reply can neither replace a
    # usable cached payload nor surface later as a KeyError.
    try:
        data = response.json()
    except ValueError:
        data = None
    if not isinstance(data, dict) or 'token' not in data or 'expires_at' not in data:
        if cached_token:
            return cached_token
        raise HTTPException(status_code=502, detail="Token endpoint returned an unexpected payload")

    cached_token = data

    # Remaining lifetime after subtracting the early-refresh margin.
    expiry_ttl = data['expires_at'] - int(time.time()) - TOKEN_EXPIRY_THRESHOLD
    if expiry_ttl > 0:
        print(f"Token cached, will expire in {expiry_ttl} seconds")
    else:
        print("Warning: New token has short validity period")

    return data
|
||
|
||
|
||
def get_headers_for_path(path: str) -> Dict[str, str]:
    """Build the upstream request headers for the given API path.

    Every path receives the common editor/client identification headers.
    ``/chat/completions`` additionally receives conversation headers,
    including a freshly generated interaction id and request id. A new dict
    is returned on each call.
    """
    common = {
        "Editor-Version": "JetBrains-IU/252.26830.84",
        "Editor-Plugin-Version": "copilot-intellij/1.5.58-243",
        "Copilot-Language-Server-Version": "1.382.0",
        "X-Github-Api-Version": "2025-05-01",
        "Copilot-Integration-Id": "jetbrains-chat",
        "User-Agent": "GithubCopilot/1.382.0",
    }

    # "/agents", "/models" and any other path use the common headers as-is.
    if path != "/chat/completions":
        return common

    interaction_id = generate_uuid()
    request_id = generate_uuid()

    chat_only = {
        "X-Initiator": "user",
        "X-Interaction-Id": interaction_id,
        "X-Interaction-Type": "conversation-panel",
        "Openai-Organization": "github-copilot",
        "X-Request-Id": request_id,
        "Vscode-Sessionid": "427689f2-5dad-4b50-95d9-7cca977450061761839746260",
        "Vscode-Machineid": "c9421c6ac240db1c5bc5117218aa21a73f3762bda7db1702d003ec2df103b812",
        "Openai-Intent": "conversation-panel",
        "Copilot-Vision-Request": "true",
    }
    common.update(chat_only)

    print(f"/chat/completions path matched, Interaction-Id: {interaction_id}, Request-Id: {request_id}")

    return common
|
||
|
||
|
||
def has_non_empty_content(msg):
    """Report whether a message's ``content`` field holds something non-empty.

    Missing or ``None`` content is empty. Strings are judged after stripping
    surrounding whitespace, lists and dicts by whether they have any items,
    and every other value (numbers, booleans, ...) counts as non-empty.
    """
    payload = msg.get('content')
    if payload is None:
        return False

    if isinstance(payload, str):
        # Whitespace-only text counts as empty.
        payload = payload.strip()
    elif not isinstance(payload, (list, dict)):
        # Scalars such as 0 or False are still "some content".
        return True

    return bool(payload)
|
||
|
||
|
||
def filter_messages_logic(messages):
    """Drop content-bearing messages wedged between a tool call and its result.

    For each ``assistant`` message with a non-empty ``tool_calls`` list, the
    first later message whose role is ``tool`` is located. Every message
    strictly between the two whose content is non-empty (per
    ``has_non_empty_content``) is removed. ``messages`` is modified in place
    and nothing is returned. Lists with fewer than three messages are left
    alone, since there is no room for an assistant, something in between,
    and a tool message.
    """
    if not messages or len(messages) < 3:
        return

    pos = 0
    while pos < len(messages):
        anchor = messages[pos]
        calls = anchor.get("tool_calls") if anchor.get("role") == "assistant" else None

        if not isinstance(calls, list) or len(calls) == 0:
            # Not an assistant message carrying tool calls: move on.
            pos += 1
            continue

        # Walk forward to the first tool message, noting removable indices.
        removable = []
        for k in range(pos + 1, len(messages)):
            candidate = messages[k]
            if candidate.get("role") == "tool":
                break
            if has_non_empty_content(candidate):
                removable.append(k)
        else:
            # Reached the end without meeting a tool message: remove nothing.
            removable.clear()

        if not removable:
            pos += 1
            continue

        # Pop from the highest index down so earlier indices stay valid.
        for k in reversed(removable):
            dropped = messages.pop(k)
            print(f"Removed intermediate message with non-empty content at index {k}: {dropped}")
        # pos is deliberately not advanced: the same assistant message is
        # examined once more against the shortened list.
||
|
||
|
||
@app.api_route("/v1/{path:path}", methods=["GET", "POST", "PUT", "DELETE", "PATCH"])
async def proxy_request(request: Request, path: str):
    """Forward a ``/v1/<path>`` request to the Copilot API and relay the reply.

    Steps:
      1. Create ``logs/<timestamp>/`` and save the raw request body there.
      2. Obtain a Copilot token and build the upstream headers for the path.
      3. If the body is a JSON object with a ``messages`` list, run
         ``filter_messages_logic`` on it and re-serialise when it changed.
      4. Send the request upstream and save the response body.
      5. Relay the response: as a ``StreamingResponse`` for
         ``text/event-stream``, as JSON when the body parses as JSON, and
         otherwise as the raw bytes.

    Args:
        request: The incoming request; only its method and body are forwarded.
        path: The part of the URL after ``/v1/``.

    Raises:
        HTTPException: 500 when the upstream request cannot be completed, plus
            whatever ``get_copilot_token`` raises.
    """
    # Function-scope imports so this block does not depend on edits elsewhere.
    import traceback
    from fastapi.responses import Response

    # One log directory per request, named by a microsecond timestamp.
    timestamp = datetime.now().strftime("%Y%m%d_%H%M%S_%f")
    log_dir = os.path.join("logs", timestamp)
    os.makedirs(log_dir, exist_ok=True)

    # Record the request exactly as received.
    original_body = await request.body()
    with open(os.path.join(log_dir, "original_request.txt"), "wb") as f:
        f.write(original_body or b"")

    # The token payload is intentionally not printed: it is a bearer credential.
    token_data = await get_copilot_token()
    headers = get_headers_for_path(f"/{path}")
    headers["Authorization"] = f"Bearer {token_data['token']}"
    # NOTE(review): looks like a leftover debug header; confirm it is needed.
    headers["hello"] = "world"

    body = original_body

    if body:
        try:
            body_data = json.loads(body.decode('utf-8'))
        except (json.JSONDecodeError, UnicodeDecodeError):
            # Not JSON text: forward the body untouched.
            print("Request body is not valid JSON, skipping message filtering.")
        else:
            messages = body_data.get("messages") if isinstance(body_data, dict) else None
            # Only filter well-formed message lists; the filter calls .get()
            # on every entry.
            if isinstance(messages, list) and all(isinstance(m, dict) for m in messages):
                initial_len = len(messages)
                print(f"Processing messages, initial count: {initial_len}")
                filter_messages_logic(messages)
                final_len = len(messages)
                if initial_len != final_len:
                    body = json.dumps(body_data).encode('utf-8')
                    print(f"Messages filtered from {initial_len} to {final_len}.")

        # Record the body that is actually forwarded, changed or not.
        with open(os.path.join(log_dir, "modified_request.txt"), "wb") as f:
            f.write(body or b"")

    target_url = "https://api.business.githubcopilot.com/" + path

    print(target_url, " ", str(body))

    async with httpx.AsyncClient() as client:
        try:
            response = await client.request(
                method=request.method,
                url=target_url,
                headers=headers,
                content=body if body else None,
                timeout=120.0,
            )
        except httpx.RequestError as e:
            traceback.print_exc()
            raise HTTPException(status_code=500, detail=f"Proxy request failed: {str(e)}") from e

    content = response.content

    # Record the upstream response body.
    with open(os.path.join(log_dir, "response.txt"), "wb") as f:
        f.write(content or b"")

    print("content: ", content)

    # httpx has already decoded any gzip/deflate encoding and the body is
    # re-framed here, so the upstream framing/encoding headers no longer
    # describe what is sent to the caller.
    stale_headers = {'content-length', 'transfer-encoding', 'content-encoding'}
    relay_headers = {
        k: v for k, v in response.headers.items()
        if k.lower() not in stale_headers
    }

    if response.headers.get("content-type", "").startswith("text/event-stream"):
        # The body was fully read above, so this replays the buffered bytes
        # rather than streaming them incrementally from upstream.
        return StreamingResponse(
            response.aiter_bytes(),
            status_code=response.status_code,
            headers=relay_headers,
        )

    if content:
        try:
            payload = json.loads(content)
        except ValueError:
            # Upstream sent something other than JSON (e.g. a plain-text or
            # HTML error page): pass the bytes through unchanged.
            return Response(
                content=content,
                status_code=response.status_code,
                headers=relay_headers,
            )
    else:
        payload = {}

    return JSONResponse(
        content=payload,
        status_code=response.status_code,
        headers=relay_headers,
    )
|
||
|
||
|
||
@app.get("/")
async def root():
    """Return a small JSON banner identifying this service."""
    banner = {"message": "GitHub Copilot Proxy API"}
    return banner
|
||
|
||
|
||
if __name__ == "__main__":
    # Bind address and port can be overridden through the environment; the
    # defaults match the previous hard-coded values. "0.0.0.0" listens on
    # every network interface, so set COPILOT_PROXY_HOST=127.0.0.1 to keep
    # the proxy reachable from the local machine only.
    uvicorn.run(
        app,
        host=os.environ.get("COPILOT_PROXY_HOST", "0.0.0.0"),
        port=int(os.environ.get("COPILOT_PROXY_PORT", "8000")),
    )
|