Add ghcproxy support (支持ghcproxy)

This commit adds:
  app/ghcproxy.py (new file, 285 lines)
@@ -0,0 +1,285 @@
|
||||
import os
|
||||
import json
|
||||
import random
|
||||
import time
|
||||
from typing import Optional, Dict, Any
|
||||
from datetime import datetime
|
||||
|
||||
import httpx
|
||||
from fastapi import FastAPI, Request, HTTPException
|
||||
from fastapi.responses import StreamingResponse, JSONResponse
|
||||
import uvicorn
|
||||
|
||||
# FastAPI application instance hosting the Copilot proxy endpoints.
app = FastAPI()
|
||||
|
||||
# Seconds before 'expires_at' at which a cached token is treated as stale.
TOKEN_EXPIRY_THRESHOLD = 60

# GitHub OAuth token used to mint Copilot session tokens.
# SECURITY: a hard-coded token was committed here; prefer the environment
# variable override below and rotate the leaked credential.
GITHUB_TOKEN = os.environ.get(
    "GHCPROXY_GITHUB_TOKEN",
    "ghu_kpJkheogXW18PMY0Eu6D0sL4r5bDsD3aS3EA",
)

# Endpoint that exchanges the GitHub OAuth token for a Copilot session token.
GITHUB_API_URL = "https://api.github.com/copilot_internal/v2/token"

# Process-wide cache of the most recently fetched Copilot token payload.
cached_token: Optional[Dict[str, Any]] = None
|
||||
|
||||
|
||||
def generate_uuid() -> str:
    """Return a random RFC 4122 version-4 style UUID string.

    BUG FIX: the previous implementation replaced every non-'x' template
    character -- including the literal dashes and the version digit '4' --
    with a random character from '89ab', producing malformed IDs.  Literal
    characters are now preserved; 'x' positions get a random hex digit and
    the single 'y' position gets one of '89ab' (the RFC 4122 variant bits).
    """
    template = 'xxxxxxxx-xxxx-4xxx-yxxx-xxxxxxxxxxxx'
    chars = []
    for c in template:
        if c == 'x':
            chars.append(random.choice('0123456789abcdef'))
        elif c == 'y':
            chars.append(random.choice('89ab'))
        else:
            # '-' separators and the version digit '4' are literal.
            chars.append(c)
    return ''.join(chars)
|
||||
|
||||
|
||||
def is_token_valid(token_data: Optional[Dict[str, Any]]) -> bool:
    """Return True when *token_data* holds a token that is not about to expire.

    A token is considered invalid if the payload is missing, lacks the
    'token' or 'expires_at' keys, or expires within TOKEN_EXPIRY_THRESHOLD
    seconds from now (early-expiry margin to avoid using a token mid-request).
    """
    if not token_data:
        return False
    if 'token' not in token_data or 'expires_at' not in token_data:
        return False
    now = int(time.time())
    # Valid only while the expiry lies beyond the safety threshold.
    return now + TOKEN_EXPIRY_THRESHOLD < token_data['expires_at']
|
||||
|
||||
|
||||
async def get_copilot_token() -> Dict[str, Any]:
    """Return a Copilot session token payload, using the module-level cache.

    Strategy: serve the cached token while it is still valid; otherwise
    refresh it from GITHUB_API_URL.  If the refresh fails (network error or
    non-200 status) a possibly-stale cached token is preferred over raising;
    an HTTPException is raised only when no cached token exists at all.
    """
    global cached_token

    if cached_token and is_token_valid(cached_token):
        return cached_token

    request_headers = {
        "Authorization": f"Bearer {GITHUB_TOKEN}",
        "Editor-Version": "JetBrains-IU/252.26830.84",
        "Editor-Plugin-Version": "copilot-intellij/1.5.58-243",
        "Copilot-Language-Server-Version": "1.382.0",
        "X-Github-Api-Version": "2024-12-15",
        "User-Agent": "GithubCopilot/1.382.0",
        "Accept": "*/*",
    }

    async with httpx.AsyncClient() as client:
        try:
            response = await client.get(GITHUB_API_URL, headers=request_headers, timeout=10.0)
        except httpx.RequestError as e:
            # Network failure: fall back to the stale cache when possible.
            if cached_token:
                return cached_token
            raise HTTPException(status_code=500, detail=f"Request failed: {str(e)}")

        if response.status_code != 200:
            if cached_token:
                return cached_token
            raise HTTPException(status_code=response.status_code, detail=f"Failed to get token: {response.text}")

        data = response.json()
        cached_token = data

        # Report how long the fresh token will be served from cache.
        remaining = data['expires_at'] - int(time.time()) - TOKEN_EXPIRY_THRESHOLD
        if remaining > 0:
            print(f"Token cached, will expire in {remaining} seconds")
        else:
            print("Warning: New token has short validity period")

        return data
|
||||
|
||||
|
||||
def get_headers_for_path(path: str) -> Dict[str, str]:
    """Build the outbound Copilot request headers for a proxied path.

    '/agents' and '/models' get the base editor-identification headers;
    '/chat/completions' additionally gets per-request interaction/request
    IDs and conversation-panel metadata.
    """
    base_headers = {
        "Editor-Version": "JetBrains-IU/252.26830.84",
        "Editor-Plugin-Version": "copilot-intellij/1.5.58-243",
        "Copilot-Language-Server-Version": "1.382.0",
        "X-Github-Api-Version": "2025-05-01",
        "Copilot-Integration-Id": "jetbrains-chat",
        "User-Agent": "GithubCopilot/1.382.0",
    }

    if path in ("/agents", "/models"):
        return base_headers

    if path == "/chat/completions":
        # Fresh IDs per chat request, mirroring the IntelliJ plugin.
        interaction_id = generate_uuid()
        request_id = generate_uuid()

        base_headers.update({
            "X-Initiator": "user",
            "X-Interaction-Id": interaction_id,
            "X-Interaction-Type": "conversation-panel",
            "Openai-Organization": "github-copilot",
            "X-Request-Id": request_id,
            "Vscode-Sessionid": "427689f2-5dad-4b50-95d9-7cca977450061761839746260",
            "Vscode-Machineid": "c9421c6ac240db1c5bc5117218aa21a73f3762bda7db1702d003ec2df103b812",
            "Openai-Intent": "conversation-panel",
            "Copilot-Vision-Request": "true",
        })

        print(f"/chat/completions path matched, Interaction-Id: {interaction_id}, Request-Id: {request_id}")

    return base_headers
|
||||
|
||||
|
||||
def has_non_empty_content(msg):
    """Return True if the message's 'content' field is meaningfully non-empty.

    None is empty; strings are empty when whitespace-only; lists/dicts are
    empty when they have no elements; every other type (numbers, booleans,
    ...) counts as non-empty.
    """
    content = msg.get('content')
    if content is None:
        return False
    # Whitespace-only strings are treated as empty.
    if isinstance(content, str):
        return content.strip() != ''
    # Containers: empty when they hold nothing.
    if isinstance(content, (list, dict)):
        return len(content) > 0
    # Anything else (int, float, bool, ...) is considered non-empty.
    return True
|
||||
|
||||
|
||||
def filter_messages_logic(messages):
    """
    Filtering logic:
    find a message A with role 'assistant' that carries a non-empty
    tool_calls list, and the first message B after it with role 'tool'.
    Remove every message strictly between A and B whose content is
    non-empty (see has_non_empty_content).

    Mutates *messages* in place; returns None.
    """
    if not messages or len(messages) < 3:  # need at least assistant, something, tool to act
        return

    i = 0
    while i < len(messages):
        current_msg = messages[i]

        # Is this an assistant message with a non-empty tool_calls list?
        is_assistant_with_tool_calls = (
            current_msg.get("role") == "assistant" and
            isinstance(current_msg.get("tool_calls"), list) and
            len(current_msg["tool_calls"]) > 0
        )

        if is_assistant_with_tool_calls:
            # Scan forward from the next message for the first role == 'tool'.
            j = i + 1
            found_tool = False
            indices_to_remove_between = []

            while j < len(messages):
                msg_to_check = messages[j]

                if msg_to_check.get("role") == "tool":
                    found_tool = True
                    break  # stop at the first tool message; delete what lies in between
                # Message at j sits between the assistant and the tool:
                # mark it for removal if its content is non-empty.
                if has_non_empty_content(msg_to_check):
                    indices_to_remove_between.append(j)
                j += 1

            if found_tool and indices_to_remove_between:
                # Delete back-to-front so earlier indices stay valid
                # while the list shrinks.
                for idx in sorted(indices_to_remove_between, reverse=True):
                    removed_msg = messages.pop(idx)
                    print(f"Removed intermediate message with non-empty content at index {idx}: {removed_msg}")
                # The list got shorter; re-examine position i on the next
                # pass instead of advancing, so shifted elements are seen.
                continue
            else:
                # Either no paired tool was found, or nothing between the
                # pair needed deleting; move on to the next message.
                i += 1
        else:
            # Current message does not match the pattern; check the next one.
            i += 1
|
||||
|
||||
|
||||
@app.api_route("/v1/{path:path}", methods=["GET", "POST", "PUT", "DELETE", "PATCH"])
async def proxy_request(request: Request, path: str):
    """Proxy a /v1/* request to the GitHub Copilot business backend.

    Pipeline: log the raw request body, acquire a Copilot session token,
    build path-specific headers, optionally filter chat messages in a JSON
    body, forward the request upstream, log the response, and relay it
    (streaming for SSE, JSON otherwise).

    Raises:
        HTTPException: 500 when the upstream request fails outright.
    """
    # Per-request timestamped directory for request/response log files.
    timestamp = datetime.now().strftime("%Y%m%d_%H%M%S_%f")
    log_dir = os.path.join("logs", timestamp)
    os.makedirs(log_dir, exist_ok=True)

    # Record the raw incoming request body.
    original_body = await request.body()
    with open(os.path.join(log_dir, "original_request.txt"), "wb") as f:
        f.write(original_body or b"")

    token_data = await get_copilot_token()
    # SECURITY FIX: do not print the full token payload (it contains the
    # bearer credential); log only the expiry for debugging.
    print(f"Copilot token acquired, expires_at={token_data.get('expires_at')}")
    headers = get_headers_for_path(f"/{path}")
    headers["Authorization"] = f"Bearer {token_data['token']}"
    headers["hello"] = "world"  # NOTE(review): debug header; confirm upstream ignores it

    body = original_body

    # Filter chat messages (see filter_messages_logic) when the body is JSON.
    if body:
        try:
            body_data = json.loads(body.decode('utf-8') if isinstance(body, bytes) else body)
            if "messages" in body_data and isinstance(body_data["messages"], list):
                messages = body_data["messages"]
                initial_len = len(messages)
                print(f"Processing messages, initial count: {initial_len}")
                filter_messages_logic(messages)
                final_len = len(messages)
                if initial_len != final_len:
                    # Messages were dropped: re-serialize the modified body.
                    body = json.dumps(body_data).encode('utf-8')
                    print(f"Messages filtered from {initial_len} to {final_len}.")
                # Record the (possibly modified) request body for auditing.
                with open(os.path.join(log_dir, "modified_request.txt"), "wb") as f:
                    f.write(body or b"")
        except json.JSONDecodeError:
            # Non-JSON body: forward it untouched.
            print("Request body is not valid JSON, skipping message filtering.")
            with open(os.path.join(log_dir, "modified_request.txt"), "wb") as f:
                f.write(body or b"")

    # target_url = f"https://qwapi.oopsapi.com/v1/{path}"
    target_url = "https://api.business.githubcopilot.com/" + path

    print(target_url, " ", str(body))

    async with httpx.AsyncClient() as client:
        try:
            response = await client.request(
                method=request.method,
                url=target_url,
                headers=headers,
                content=body if body else None,
                timeout=120.0,
            )

            content = response.content

            # Record the upstream response.
            with open(os.path.join(log_dir, "response.txt"), "wb") as f:
                f.write(content or b"")

            print("content: ", content)
            if response.headers.get("content-type", "").startswith("text/event-stream"):
                # Body is already buffered above, so aiter_bytes() replays it.
                return StreamingResponse(
                    response.aiter_bytes(),
                    status_code=response.status_code,
                    headers=dict(response.headers),
                )

            # Drop hop-by-hop length headers; Starlette recomputes them.
            return JSONResponse(
                content=json.loads(content) if content else {},
                status_code=response.status_code,
                headers={k: v for k, v in response.headers.items()
                         if k.lower() not in ['content-length', 'transfer-encoding']}
            )

        except httpx.RequestError as e:
            # BUG FIX: was `import backtrace` / `backtrace.print_exc()` --
            # no such stdlib module, so the error path itself crashed with
            # ImportError. Use the stdlib traceback module instead.
            import traceback
            traceback.print_exc()
            raise HTTPException(status_code=500, detail=f"Proxy request failed: {str(e)}")
|
||||
|
||||
|
||||
@app.get("/")
async def root():
    """Landing endpoint identifying this service."""
    payload = {"message": "GitHub Copilot Proxy API"}
    return payload
|
||||
|
||||
|
||||
# Run the proxy directly with uvicorn when executed as a script.
if __name__ == "__main__":
    uvicorn.run(app, host="0.0.0.0", port=8000)
|
||||
@@ -105,10 +105,12 @@ async def chat_completions(
|
||||
|
||||
# First, collect all chunks to detect if there are tool calls
|
||||
async for chunk in stream_llm_api(messages_to_llm, settings, log_id):
|
||||
logger.info(f"sse_result: {chunk}")
|
||||
raw_chunks.append(chunk)
|
||||
# Extract content from SSE chunks
|
||||
parsed = _parse_sse_data(chunk)
|
||||
if parsed and parsed.get("type") != "done":
|
||||
logger.info(f"sse_result_data: {parsed}")
|
||||
if parsed and ( parsed.get("type") != "done" or parsed.get("choices").get("finish_reason") == "stop" ):
|
||||
choices = parsed.get("choices")
|
||||
if choices and len(choices) > 0:
|
||||
delta = choices[0].get("delta")
|
||||
|
||||
@@ -166,7 +166,7 @@ async def _raw_stream_from_llm(messages: List[ChatMessage], settings: Settings,
|
||||
Yields raw byte chunks as received.
|
||||
"""
|
||||
headers = { "Authorization": f"Bearer {settings.REAL_LLM_API_KEY}", "Content-Type": "application/json" }
|
||||
payload = { "model": "default-model", "messages": [msg.model_dump() for msg in messages], "stream": True }
|
||||
payload = { "model": "gpt-4.1", "messages": [msg.model_dump() for msg in messages], "stream": True }
|
||||
|
||||
# Log the request payload to the database
|
||||
update_request_log(log_id, llm_request=payload)
|
||||
|
||||
Reference in New Issue
Block a user