|
import html
import json
import os
import random
from datetime import datetime
from typing import Dict, List, Optional, Any, Callable

import httpx
import pytz
from fastapi import FastAPI, Request, HTTPException
from fastapi.middleware.cors import CORSMiddleware
from fastapi.responses import StreamingResponse, HTMLResponse
|
|
|
|
|
|
|
# --- Environment-driven settings ---------------------------------------------

# Raw comma-separated pieces of the SI_KEY variable (not yet stripped/filtered).
SI_KEYS = os.environ.get("SI_KEY", "").split(",")

# Optional JSON object that replaces the built-in alias table below.
MODEL_MAP_JSON = os.environ.get("MODEL_MAP")

# Upstream base URL. A trailing slash is dropped so that joining it with a
# path that starts with "/" never yields a double slash.
API_BASE_URL = os.environ.get(
    "API_BASE_URL", "https://api.siliconflow.cn/v1"
).rstrip("/")

# Value passed as the httpx client timeout for forwarded requests.
REQUEST_TIMEOUT = 120


# --- Type aliases -------------------------------------------------------------

# Short model alias -> model identifier sent upstream.
ModelMap = Dict[str, str]

# API key -> balance reported for that key.
KeyBalance = Dict[str, float]


# --- Model alias table --------------------------------------------------------

# Built-in aliases, used when MODEL_MAP is unset or empty.
DEFAULT_MODEL_MAP: ModelMap = {
    "qwen-72b": "Qwen/Qwen2.5-72B-Instruct",
    "qwen-32b": "Qwen/Qwen2.5-32B-Instruct",
    "qwen-14b": "Qwen/Qwen2.5-14B-Instruct",
    "qwen-7b": "Qwen/Qwen2.5-7B-Instruct",
    "qwen-vl": "Qwen/Qwen2-VL-72B-Instruct",
    "qwen-coder": "Qwen/Qwen2.5-Coder-32B-Instruct",
    "qwq": "Qwen/QwQ-32B-Preview",
    "o1": "AIDC-AI/Marco-o1",
    "deepseek": "deepseek-ai/DeepSeek-V2.5",
    "deepseek-vl": "deepseek-ai/deepseek-vl2",
    "glm-9b": "THUDM/glm-4-9b-chat",
    "bce": "netease-youdao/bce-embedding-base_v1",
    "bge-m3": "BAAI/bge-m3",
    "bge-zh": "BAAI/bge-large-zh-v1.5",
    "sd": "stabilityai/stable-diffusion-3-5-large",
    "sd-turbo": "stabilityai/stable-diffusion-3-5-large-turbo",
    "flux-s": "black-forest-labs/FLUX.1-schnell",
    "flux-d": "black-forest-labs/FLUX.1-dev",
}

# Effective alias table: the MODEL_MAP override when provided, else the default.
if MODEL_MAP_JSON:
    model_map: ModelMap = json.loads(MODEL_MAP_JSON)
    # Fail at startup with a clear message rather than later, when the table
    # is indexed or iterated as a mapping.
    if not isinstance(model_map, dict):
        raise ValueError("MODEL_MAP must be a JSON object of alias -> model id.")
else:
    model_map = DEFAULT_MODEL_MAP
|
|
|
|
|
# --- Mutable runtime state (rebound by the /check endpoint) -------------------

# Pool of usable upstream keys: the non-empty, whitespace-stripped SI_KEY pieces.
keys: List[str] = [piece.strip() for piece in SI_KEYS if piece.strip()]

# Balance recorded for each key during the most recent check.
key_balance: KeyBalance = {}

# Pre-rendered HTML fragment listing keys and balances for the status page.
key_balance_notes: str = ""

# Time zone used when stamping the last check time.
tz = pytz.timezone("Asia/Shanghai")

# Formatted timestamp of the most recent check ("" until one has run).
last_updated_time: str = ""
|
|
|
|
|
def get_api_key() -> str:
    """Return one API key picked uniformly at random from the pool.

    The pool itself is left untouched: shuffling the shared list in place on
    every request could reorder it underneath another coroutine that is
    iterating it across an ``await``.

    Returns:
        A key from ``keys``, or ``""`` when the pool is empty.
    """
    return random.choice(keys) if keys else ""
|
|
|
|
|
|
|
# ASGI application object; the route decorators below register on it.
app = FastAPI()

# Fully permissive CORS: any origin, method and header.
# NOTE(review): wildcard origins combined with allow_credentials=True lets any
# site issue credentialed cross-origin requests. Confirm that browser
# credentials are really needed; otherwise set allow_credentials=False or list
# the allowed origins explicitly.
app.add_middleware(
    CORSMiddleware,
    allow_origins=["*"],
    allow_credentials=True,
    allow_methods=["*"],
    allow_headers=["*"],
)
|
|
|
|
|
|
|
def format_key_balance_note(key: str, balance: float) -> str:
    """Render a key and its balance as an ``<h2>`` HTML fragment.

    Only the first and last four characters of the key are shown. Keys of
    eight characters or fewer are masked completely, because the two slices
    would otherwise overlap and reveal the whole key.

    Args:
        key: The API key to display in masked form.
        balance: The balance to print next to it.

    Returns:
        An HTML snippet of the form ``<h2>abcd****wxyz————1.5</h2>``.
    """
    if len(key) > 8:
        # The visible fragments end up inside HTML, so escape them.
        safe_key = f"{html.escape(key[:4])}****{html.escape(key[-4:])}"
    else:
        safe_key = "****"
    return f"<h2>{safe_key}————{balance}</h2>"
|
|
|
async def check_key(client: httpx.AsyncClient, key: str) -> Optional[float]:
    """Query the upstream account endpoint for one key and return its balance.

    Args:
        client: The httpx client used to send the request.
        key: The API key to check; sent as a Bearer token.

    Returns:
        The balance as a float, or ``None`` when the request fails, the
        upstream answers with an error status, or the payload does not
        contain a usable ``data.balance`` value.
    """
    url = f"{API_BASE_URL}/user/info"
    headers = {"Authorization": f"Bearer {key}"}
    # Never write the full secret to the log; short keys are masked entirely.
    masked = f"{key[:4]}****{key[-4:]}" if len(key) > 8 else "****"
    try:
        res = await client.get(url, headers=headers)
        res.raise_for_status()
        return float(res.json()["data"]["balance"])
    except httpx.HTTPError as e:
        print(f"Error checking key {masked}: {e}")
    except (KeyError, TypeError, ValueError) as e:
        # Non-JSON body, missing fields, or a balance that is not numeric.
        print(f"Unexpected balance payload for key {masked}: {e!r}")
    return None
|
|
|
async def forward_request(
    request: Request,
    url_path: str,
    is_stream: bool = False,
) -> Any:
    """Forward an incoming JSON request to the upstream API.

    The request body is parsed as JSON, a known ``model`` alias is replaced
    by its upstream identifier, and the body is POSTed to
    ``API_BASE_URL + url_path`` using a randomly chosen API key.

    Args:
        request: The incoming request whose JSON body is forwarded.
        url_path: Upstream path appended to ``API_BASE_URL``.
        is_stream: Whether this endpoint may stream. Streaming is only used
            when this is true AND the body carries a truthy ``"stream"``.

    Returns:
        A ``StreamingResponse`` relaying upstream bytes when streaming,
        otherwise the decoded upstream JSON payload.

    Raises:
        HTTPException: 400 for a body that is not a JSON object or when no
            API key is available; the upstream status code when the upstream
            answers with an error; 502 when the upstream cannot be reached.
    """
    try:
        body = await request.json()
    except ValueError as exc:  # JSON / unicode decoding errors
        raise HTTPException(
            status_code=400, detail="Request body must be valid JSON."
        ) from exc
    if not isinstance(body, dict):
        raise HTTPException(
            status_code=400, detail="Request body must be a JSON object."
        )

    key = get_api_key()
    if not key:
        raise HTTPException(status_code=400, detail="No valid API key available.")
    headers = {"Authorization": f"Bearer {key}"}

    # Translate a short alias to the upstream model id; unknown names pass through.
    requested_model = body.get("model")
    if isinstance(requested_model, str) and requested_model in model_map:
        body["model"] = model_map[requested_model]

    target = f"{API_BASE_URL}{url_path}"

    if is_stream and body.get("stream"):
        # Open the upstream stream BEFORE building the response so that an
        # upstream failure becomes a proper error status instead of a
        # truncated 200 stream. The client must outlive this function, so it
        # is closed manually rather than with ``async with``.
        client = httpx.AsyncClient(timeout=REQUEST_TIMEOUT)
        try:
            upstream = await client.send(
                client.build_request("POST", target, headers=headers, json=body),
                stream=True,
            )
        except BaseException as exc:
            await client.aclose()
            if isinstance(exc, httpx.RequestError):
                raise HTTPException(
                    status_code=502, detail=f"Upstream request failed: {exc}"
                ) from exc
            raise

        try:
            upstream.raise_for_status()
        except httpx.HTTPStatusError as exc:
            try:
                detail = (await upstream.aread()).decode("utf-8", errors="replace")
            finally:
                await upstream.aclose()
                await client.aclose()
            raise HTTPException(
                status_code=upstream.status_code, detail=detail
            ) from exc

        async def generate_response():
            # Relay upstream bytes as they arrive; always release the
            # connection and client when the stream ends or is abandoned.
            try:
                async for chunk in upstream.aiter_bytes():
                    if chunk:
                        yield chunk
            finally:
                await upstream.aclose()
                await client.aclose()

        return StreamingResponse(generate_response(), media_type="text/event-stream")

    try:
        async with httpx.AsyncClient(timeout=REQUEST_TIMEOUT) as client:
            response = await client.post(target, headers=headers, json=body)
        response.raise_for_status()
    except httpx.HTTPStatusError as exc:
        # Surface the upstream status and message instead of an opaque 500.
        raise HTTPException(
            status_code=exc.response.status_code, detail=exc.response.text
        ) from exc
    except httpx.RequestError as exc:
        raise HTTPException(
            status_code=502, detail=f"Upstream request failed: {exc}"
        ) from exc
    return response.json()
|
|
|
|
|
@app.get("/", response_class=HTMLResponse)
async def root():
    """Return an HTML status page.

    The page shows the number of keys in the pool, every model alias with its
    upstream identifier, the time of the last balance check, and the
    pre-rendered key/balance fragment.
    """
    # Alias-table entries may come from the MODEL_MAP environment variable,
    # so escape them before placing them in markup.
    models_info = "".join(
        f"<h2>{html.escape(str(alias))}————{html.escape(str(target))}</h2>"
        for alias, target in model_map.items()
    )

    # key_balance_notes is already an HTML fragment and is inserted as-is.
    return f"""
    <html>
        <head>
            <title>API 状态</title>
        </head>
        <body>
            <h1>有效Key数量: {len(keys)}</h1>
            {models_info}
            <h1>最后更新时间:{last_updated_time}</h1>
            {key_balance_notes}
        </body>
    </html>
    """
|
|
|
|
|
@app.get("/check")
async def check():
    """Re-check the balance of every configured key and refresh module state.

    Keys whose balance could not be fetched, or is below 0.1, are left out of
    the pool. On completion ``keys``, ``key_balance``, ``key_balance_notes``
    and ``last_updated_time`` are replaced together.

    Returns:
        A confirmation string containing the new timestamp.
    """
    global key_balance, key_balance_notes, last_updated_time, keys

    # Start from the configured key list rather than the current pool, so a
    # key dropped by a transient failure or a low balance can come back.
    candidates = [piece.strip() for piece in SI_KEYS if piece.strip()]

    # Collect results locally and publish them at the end, so concurrent
    # requests never observe a half-built pool or status fragment.
    fresh_balance: KeyBalance = {}
    fresh_keys: List[str] = []
    notes: List[str] = []

    async with httpx.AsyncClient() as client:
        for key in candidates:
            balance = await check_key(client, key)
            if balance is not None and balance >= 0.1:
                fresh_balance[key] = balance
                fresh_keys.append(key)
                notes.append(format_key_balance_note(key, balance))

    key_balance = fresh_balance
    key_balance_notes = "".join(notes)
    keys = fresh_keys
    last_updated_time = datetime.now(tz).strftime("%Y-%m-%d %H:%M:%S")
    return f"更新成功:{last_updated_time}"
|
|
|
|
|
@app.post("/hf/v1/chat/completions")
async def chat_completions(request: Request):
    """Forward a chat-completion request to the upstream ``/chat/completions``.

    Streaming is permitted for this endpoint; whether a given call actually
    streams is decided by ``forward_request`` from the request body.
    """
    return await forward_request(request, "/chat/completions", is_stream=True)
|
|
|
@app.post("/hf/v1/embeddings")
async def embeddings(request: Request):
    """Forward an embedding request to the upstream ``/embeddings`` path."""
    return await forward_request(request, "/embeddings")
|
|
|
|
|
@app.post("/hf/v1/images/generations")
async def images_generations(request: Request):
    """Forward an image-generation request to the upstream ``/images/generations`` path."""
    return await forward_request(request, "/images/generations")
|
|