Update app.py
Browse files
app.py
CHANGED
@@ -1,137 +1,192 @@
|
|
1 |
-
from fastapi import FastAPI,Request,HTTPException
|
2 |
-
from fastapi.responses import StreamingResponse
|
3 |
-
from fastapi.responses import HTMLResponse
|
4 |
from fastapi.middleware.cors import CORSMiddleware
|
5 |
import httpx
|
6 |
-
import json
|
|
|
|
|
7 |
from datetime import datetime
|
8 |
import pytz
|
9 |
-
|
10 |
-
|
11 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
12 |
app = FastAPI()
|
|
|
13 |
app.add_middleware(
|
14 |
CORSMiddleware,
|
15 |
-
allow_origins=["*"],
|
16 |
-
allow_credentials=True,
|
17 |
-
allow_methods=["*"],
|
18 |
-
allow_headers=["*"],
|
19 |
)
|
20 |
-
base_url = "https://api.siliconflow.cn/v1/chat/completions"
|
21 |
-
if os.environ.get("MODEL_MAP"):
|
22 |
-
model_map=json.loads(os.environ.get("MODEL_MAP"))
|
23 |
-
else:
|
24 |
-
model_map={
|
25 |
-
"qwen-72b":"Qwen/Qwen2.5-72B-Instruct",
|
26 |
-
"qwen-32b":"Qwen/Qwen2.5-32B-Instruct",
|
27 |
-
"qwen-14b":"Qwen/Qwen2.5-14B-Instruct",
|
28 |
-
"qwen-7b":"Qwen/Qwen2.5-7B-Instruct",
|
29 |
-
"qwen-vl":"Qwen/Qwen2-VL-72B-Instruct",
|
30 |
-
"qwen-coder":"Qwen/Qwen2.5-Coder-32B-Instruct",
|
31 |
-
"qwq":"Qwen/QwQ-32B-Preview",
|
32 |
-
"o1":"AIDC-AI/Marco-o1",
|
33 |
-
"deepseek":"deepseek-ai/DeepSeek-V2.5",
|
34 |
-
"deepseek-vl":"deepseek-ai/deepseek-vl2",
|
35 |
-
"glm-9b":"THUDM/glm-4-9b-chat",
|
36 |
-
"bce":"netease-youdao/bce-embedding-base_v1",
|
37 |
-
"bge-m3":"BAAI/bge-m3",
|
38 |
-
"bge-zh":"BAAI/bge-large-zh-v1.5"
|
39 |
-
}
|
40 |
-
if os.environ.get("SI_KEY"):
|
41 |
-
keys=os.environ.get("SI_KEY").split(",")
|
42 |
-
else:
|
43 |
-
keys=["sk-jopguysfqvrlciqjvlydkbxtynqagxqdhmrprfjivupuutfk","sk-bvkqkygxqrusyhdfrqoyoctqfhnxylpuoajgfzbtkhiecffo","sk-eeoefacvytttokuwhslnvjazspigjhrdkhxuzdrxrizrpeep"]
|
44 |
-
key_balacnce={}
|
45 |
-
key_balacnce_notes=""
|
46 |
-
# 创建一个东八区的时区对象
|
47 |
-
tz = pytz.timezone('Asia/Shanghai')
|
48 |
|
49 |
-
|
50 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
51 |
async def root():
|
52 |
-
|
53 |
-
|
54 |
-
|
55 |
-
models+=f"<h2>{key}————{model_map[key]}</h2>"
|
56 |
-
global now
|
57 |
-
return f"""
|
58 |
<html>
|
59 |
<head>
|
60 |
<title>富文本示例</title>
|
61 |
</head>
|
62 |
<body>
|
63 |
-
<h1>有效key数量:{len(keys)}</h1>
|
64 |
-
{
|
65 |
-
<h1>最后更新时间:{
|
66 |
-
{
|
67 |
</body>
|
68 |
</html>
|
69 |
"""
|
|
|
70 |
@app.get("/check")
|
71 |
async def check():
|
72 |
-
|
73 |
-
|
74 |
-
|
75 |
-
|
76 |
-
|
77 |
-
|
78 |
-
}
|
79 |
-
async with httpx.AsyncClient() as client:
|
80 |
-
res=await client.get(url,headers=headers)
|
81 |
-
if res.status_code==200:
|
82 |
-
balance=res.json()['data']['balance']
|
83 |
-
if float(balance)<0.1:
|
84 |
-
keys.pop(i)
|
85 |
-
continue
|
86 |
-
va=f'''<h2>{key.strip()[0:4]}****{key.strip()[-4:]}————{balance}</h2>'''
|
87 |
-
key_balacnce[key.strip()]=balance
|
88 |
-
key_balacnce_notes+=va
|
89 |
-
|
90 |
-
now = datetime.now(tz)
|
91 |
-
|
92 |
-
return f"更新成功:{now}"
|
93 |
-
@app.post("/hf/v1/chat/completions")
|
94 |
-
async def reforword(request:Request):
|
95 |
body = await request.json()
|
96 |
-
|
97 |
-
|
98 |
-
|
99 |
-
|
100 |
-
|
101 |
-
|
102 |
-
|
103 |
-
if "model" in body
|
104 |
-
body[
|
105 |
-
|
106 |
-
|
107 |
async def generate_response():
|
108 |
async with httpx.AsyncClient() as client:
|
109 |
-
async with client.stream("POST",
|
110 |
-
|
111 |
-
|
112 |
-
|
113 |
-
|
114 |
-
|
115 |
return StreamingResponse(generate_response(), media_type="text/event-stream")
|
116 |
else:
|
117 |
-
# 发送 POST 请求
|
118 |
async with httpx.AsyncClient() as client:
|
119 |
-
|
120 |
-
|
121 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
122 |
@app.post("/hf/v1/embeddings")
|
123 |
-
async def
|
124 |
-
|
125 |
-
|
126 |
-
|
127 |
-
body['model']=body_map[body['model']]
|
128 |
-
# 获取 API 密钥
|
129 |
-
key = get_si_key()
|
130 |
-
print(key)
|
131 |
-
headers = {
|
132 |
-
"Authorization": f"Bearer {key}"
|
133 |
-
}
|
134 |
-
async with httpx.AsyncClient() as client:
|
135 |
-
response = await client.post(base_url, json=body, headers=headers)
|
136 |
-
response.raise_for_status() # 检查请求是否成功
|
137 |
-
return response.json()
|
|
|
1 |
+
from fastapi import FastAPI, Request, HTTPException
|
2 |
+
from fastapi.responses import StreamingResponse, HTMLResponse
|
|
|
3 |
from fastapi.middleware.cors import CORSMiddleware
|
4 |
import httpx
|
5 |
+
import json
|
6 |
+
import os
|
7 |
+
import random
|
8 |
from datetime import datetime
|
9 |
import pytz
|
10 |
+
import logging
|
11 |
+
|
12 |
+
# 配置日志
|
13 |
+
logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(levelname)s - %(message)s')
|
14 |
+
|
15 |
+
# 从环境变量或者文件中加载配置
|
16 |
+
BASE_URL = os.environ.get("BASE_URL", "https://api.siliconflow.cn/v1/chat/completions")
|
17 |
+
MODEL_MAP = {
|
18 |
+
"qwen-72b": "Qwen/Qwen2.5-72B-Instruct",
|
19 |
+
"qwen-32b": "Qwen/Qwen2.5-32B-Instruct",
|
20 |
+
"qwen-14b": "Qwen/Qwen2.5-14B-Instruct",
|
21 |
+
"qwen-7b": "Qwen/Qwen2.5-7B-Instruct",
|
22 |
+
"qwen-vl": "Qwen/Qwen2-VL-72B-Instruct",
|
23 |
+
"qwen-coder": "Qwen/Qwen2.5-Coder-32B-Instruct",
|
24 |
+
"qwq": "Qwen/QwQ-32B-Preview",
|
25 |
+
"o1": "AIDC-AI/Marco-o1",
|
26 |
+
"deepseek": "deepseek-ai/DeepSeek-V2.5",
|
27 |
+
"deepseek-vl": "deepseek-ai/deepseek-vl2",
|
28 |
+
"glm-9b": "THUDM/glm-4-9b-chat",
|
29 |
+
"bce": "netease-youdao/bce-embedding-base_v1",
|
30 |
+
"bge-m3": "BAAI/bge-m3",
|
31 |
+
"bge-zh": "BAAI/bge-large-zh-v1.5",
|
32 |
+
}
|
33 |
+
if os.environ.get("MODEL_MAP"):
|
34 |
+
MODEL_MAP = json.loads(os.environ.get("MODEL_MAP"))
|
35 |
+
|
36 |
+
KEY_STR = os.environ.get("SI_KEY")
|
37 |
+
if KEY_STR is None:
|
38 |
+
logging.error("SI_KEY not found in env")
|
39 |
+
raise EnvironmentError("SI_KEY not found in env")
|
40 |
+
|
41 |
+
KEYS = KEY_STR.split(",")
|
42 |
+
KEY_BALANCE = {}
|
43 |
+
KEY_BALANCE_NOTES = ""
|
44 |
+
# 创建一个东八区的时区对象
|
45 |
+
TIMEZONE = pytz.timezone("Asia/Shanghai")
|
46 |
+
LAST_UPDATE_TIME = ""
|
47 |
+
|
48 |
app = FastAPI()
|
49 |
+
|
50 |
app.add_middleware(
|
51 |
CORSMiddleware,
|
52 |
+
allow_origins=["*"],
|
53 |
+
allow_credentials=True,
|
54 |
+
allow_methods=["*"],
|
55 |
+
allow_headers=["*"],
|
56 |
)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
57 |
|
58 |
+
# API Key 管理类
|
59 |
+
class ApiKeyManager:
|
60 |
+
def __init__(self, keys):
|
61 |
+
self.keys = keys
|
62 |
+
self.key_balance = {}
|
63 |
+
self.key_balance_notes = ""
|
64 |
+
self.last_update_time = ""
|
65 |
+
|
66 |
+
def get_key(self) -> str:
|
67 |
+
"""随机获取一个可用 API Key"""
|
68 |
+
if not self.keys:
|
69 |
+
raise HTTPException(status_code=500, detail="No available API keys")
|
70 |
+
random.shuffle(self.keys)
|
71 |
+
return self.keys[0]
|
72 |
+
|
73 |
+
def remove_key(self,key:str):
|
74 |
+
"""移除不可用 API Key"""
|
75 |
+
if key in self.keys:
|
76 |
+
self.keys.remove(key)
|
77 |
+
else:
|
78 |
+
logging.warning("try remove a not exists key from key_pool")
|
79 |
+
|
80 |
+
async def check_keys_balance(self):
|
81 |
+
"""检查所有 API Key 的余额,并移除余额不足的 API Key"""
|
82 |
+
self.key_balance_notes = ""
|
83 |
+
key_to_remove = []
|
84 |
+
|
85 |
+
for key in self.keys:
|
86 |
+
try:
|
87 |
+
balance = await self._fetch_balance(key)
|
88 |
+
if balance < 0.1:
|
89 |
+
key_to_remove.append(key)
|
90 |
+
else:
|
91 |
+
balance_info = f"<h2>{key.strip()[0:4]}****{key.strip()[-4:]}————{balance}</h2>"
|
92 |
+
self.key_balance[key.strip()] = balance
|
93 |
+
self.key_balance_notes += balance_info
|
94 |
+
except HTTPException as e:
|
95 |
+
logging.error(f"Key {key} check balance failed, detail:{e.detail}")
|
96 |
+
key_to_remove.append(key)
|
97 |
+
except Exception as e:
|
98 |
+
logging.error(f"Key {key} check balance failed, unexcept error:{e}")
|
99 |
+
key_to_remove.append(key)
|
100 |
+
for remove_key in key_to_remove:
|
101 |
+
self.remove_key(remove_key)
|
102 |
+
self.last_update_time = datetime.now(TIMEZONE)
|
103 |
+
|
104 |
+
async def _fetch_balance(self, key: str) -> float:
|
105 |
+
"""发送 API 请求,获取 API Key 的余额"""
|
106 |
+
url = "https://api.siliconflow.cn/v1/user/info"
|
107 |
+
headers = {"Authorization": f"Bearer {key.strip()}"}
|
108 |
+
async with httpx.AsyncClient() as client:
|
109 |
+
try:
|
110 |
+
res = await client.get(url, headers=headers)
|
111 |
+
res.raise_for_status()
|
112 |
+
balance = res.json()["data"]["balance"]
|
113 |
+
return float(balance)
|
114 |
+
except httpx.HTTPError as exc:
|
115 |
+
logging.error("httpx request error, detail:" + str(exc))
|
116 |
+
raise HTTPException(status_code=500, detail=f"Check balance failed with status:{exc.response.status_code},url:{exc.request.url}")
|
117 |
+
key_manager = ApiKeyManager(KEYS)
|
118 |
+
async def get_model_info():
|
119 |
+
models = ""
|
120 |
+
for key, value in MODEL_MAP.items():
|
121 |
+
models += f"<h2>{key}————{value}</h2>"
|
122 |
+
return models
|
123 |
+
|
124 |
+
|
125 |
+
@app.get("/", response_class=HTMLResponse)
|
126 |
async def root():
|
127 |
+
"""根路由,返回 HTML 页面,展示模型信息和更新时间"""
|
128 |
+
models_info = await get_model_info()
|
129 |
+
return f"""
|
|
|
|
|
|
|
130 |
<html>
|
131 |
<head>
|
132 |
<title>富文本示例</title>
|
133 |
</head>
|
134 |
<body>
|
135 |
+
<h1>有效key数量:{len(key_manager.keys)}</h1>
|
136 |
+
{models_info}
|
137 |
+
<h1>最后更新时间:{key_manager.last_update_time}</h1>
|
138 |
+
{key_manager.key_balance_notes}
|
139 |
</body>
|
140 |
</html>
|
141 |
"""
|
142 |
+
|
143 |
@app.get("/check")
|
144 |
async def check():
|
145 |
+
"""手动触发检查 API Key 余额的路由"""
|
146 |
+
await key_manager.check_keys_balance()
|
147 |
+
return f"更新成功:{key_manager.last_update_time}"
|
148 |
+
|
149 |
+
async def _forward_request(request: Request, api_type: str):
|
150 |
+
"""转发请求到硅流 API"""
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
151 |
body = await request.json()
|
152 |
+
try:
|
153 |
+
key = key_manager.get_key()
|
154 |
+
except HTTPException as e:
|
155 |
+
return e
|
156 |
+
logging.info(f"using key {key[0:4]}***{key[-4:]} to {api_type}")
|
157 |
+
headers = {"Authorization": f"Bearer {key}"}
|
158 |
+
# 处理模型映射
|
159 |
+
if "model" in body and body["model"] in MODEL_MAP:
|
160 |
+
body["model"] = MODEL_MAP[body["model"]]
|
161 |
+
|
162 |
+
if api_type == "chat" and "stream" in body and body["stream"]:
|
163 |
async def generate_response():
|
164 |
async with httpx.AsyncClient() as client:
|
165 |
+
async with client.stream("POST", BASE_URL, headers=headers, json=body) as response:
|
166 |
+
response.raise_for_status() # 检查响应状态码
|
167 |
+
async for chunk in response.aiter_bytes():
|
168 |
+
if chunk:
|
169 |
+
yield chunk
|
|
|
170 |
return StreamingResponse(generate_response(), media_type="text/event-stream")
|
171 |
else:
|
|
|
172 |
async with httpx.AsyncClient() as client:
|
173 |
+
try:
|
174 |
+
response = await client.post(BASE_URL, headers=headers, json=body)
|
175 |
+
response.raise_for_status()
|
176 |
+
return response.json()
|
177 |
+
except httpx.HTTPError as exc :
|
178 |
+
logging.error("httpx request error:" + str(exc))
|
179 |
+
raise HTTPException(
|
180 |
+
status_code=500,
|
181 |
+
detail=f"Request failed with status: {exc.response.status_code},url:{exc.request.url},detail:{exc.response.text}"
|
182 |
+
)
|
183 |
+
|
184 |
+
@app.post("/hf/v1/chat/completions")
|
185 |
+
async def chat_completions(request: Request):
|
186 |
+
"""转发 chat 完成请求"""
|
187 |
+
return await _forward_request(request, "chat")
|
188 |
@app.post("/hf/v1/embeddings")
|
189 |
+
async def embeddings(request: Request):
|
190 |
+
"""转发 embeddings 请求"""
|
191 |
+
return await _forward_request(request, "embedding")
|
192 |
+
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|