snailyp committed
Commit 9126c93 · verified · 1 Parent(s): 3e34048

Update main.py

Files changed (1):
  main.py (+243 -54)
main.py CHANGED
@@ -1,9 +1,9 @@
- from fastapi import FastAPI, HTTPException, Header, Request
+ from fastapi import FastAPI, HTTPException, Header
  from fastapi.middleware.cors import CORSMiddleware
  from fastapi.responses import StreamingResponse
  from pydantic import BaseModel
  import openai
- from typing import List, Optional,Union
+ from typing import List, Optional, Union
  import logging
  from itertools import cycle
  import asyncio
@@ -13,6 +13,10 @@ import uvicorn
  from app import config
  import requests
  from datetime import datetime, timezone
+ import json
+ import httpx
+ import uuid
+ import time

  # Configure logging
  logging.basicConfig(
@@ -36,12 +40,61 @@ API_KEYS = config.settings.API_KEYS

  # Create a cycling iterator over the keys
  key_cycle = cycle(API_KEYS)
- key_lock = asyncio.Lock()
+
+ # Create two independent locks
+ key_cycle_lock = asyncio.Lock()
+ failure_count_lock = asyncio.Lock()
+
+ # Track per-key failure counts
+ key_failure_counts = {key: 0 for key in API_KEYS}
+ MAX_FAILURES = 10  # maximum failure threshold
+ MAX_RETRIES = 3  # maximum number of retries
+
+
+ async def get_next_key():
+     """Get the next key only, without checking failure counts"""
+     async with key_cycle_lock:
+         return next(key_cycle)
+
+ async def is_key_valid(key):
+     """Check whether the key is valid"""
+     async with failure_count_lock:
+         return key_failure_counts[key] < MAX_FAILURES
+
+ async def reset_failure_counts():
+     """Reset the failure counts of all keys"""
+     async with failure_count_lock:
+         for key in key_failure_counts:
+             key_failure_counts[key] = 0
+
+ async def get_next_working_key():
+     """Get the next available API key"""
+     initial_key = await get_next_key()
+     current_key = initial_key
+
+     while True:
+         if await is_key_valid(current_key):
+             return current_key
+
+         current_key = await get_next_key()
+         if current_key == initial_key:  # we have cycled through every key
+             await reset_failure_counts()
+             return current_key
+
+ async def handle_api_failure(api_key):
+     """Handle an API call failure"""
+     async with failure_count_lock:
+         key_failure_counts[api_key] += 1
+         if key_failure_counts[api_key] >= MAX_FAILURES:
+             logger.warning(f"API key {api_key} has failed {MAX_FAILURES} times, switching to next key")
+
+     # Acquire the new key outside the lock
+     return await get_next_working_key()


  class ChatRequest(BaseModel):
      messages: List[dict]
-     model: str = "llama-3.2-90b-text-preview"
+     model: str = "gemini-1.5-flash-002"
      temperature: Optional[float] = 0.7
      stream: Optional[bool] = False
      tools: Optional[List[dict]] = []
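The hunk above replaces the single key_lock with two independent locks, so handing out keys never waits on failure bookkeeping, and handle_api_failure re-acquires a key only after releasing the count lock. A minimal self-contained sketch of the rotation behaviour (the toy keys and the lowered MAX_FAILURES are demo assumptions, not values from this commit):

import asyncio
from itertools import cycle

API_KEYS = ["key-a", "key-b"]  # toy keys for the demo
key_cycle = cycle(API_KEYS)
key_cycle_lock = asyncio.Lock()
failure_count_lock = asyncio.Lock()
key_failure_counts = {key: 0 for key in API_KEYS}
MAX_FAILURES = 2  # lowered from 10 so the demo trips quickly

async def get_next_key():
    async with key_cycle_lock:
        return next(key_cycle)

async def is_key_valid(key):
    async with failure_count_lock:
        return key_failure_counts[key] < MAX_FAILURES

async def get_next_working_key():
    initial_key = await get_next_key()
    current_key = initial_key
    while True:
        if await is_key_valid(current_key):
            return current_key
        current_key = await get_next_key()
        if current_key == initial_key:  # cycled through every key: reset and accept
            key_failure_counts.update({k: 0 for k in key_failure_counts})
            return current_key

async def demo():
    key_failure_counts["key-a"] = MAX_FAILURES  # pretend key-a is exhausted
    print(await get_next_working_key())  # -> key-b (key-a is skipped)

asyncio.run(demo())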
@@ -73,49 +126,104 @@ async def verify_authorization(authorization: str = Header(None)):
  def get_gemini_models(api_key):
      base_url = "https://generativelanguage.googleapis.com/v1beta"
      url = f"{base_url}/models?key={api_key}"
-
+
      try:
          response = requests.get(url)
          if response.status_code == 200:
              gemini_models = response.json()
-             return convert_to_openai_format(gemini_models)
+             return convert_to_openai_models_format(gemini_models)
          else:
              print(f"Error: {response.status_code}")
              print(response.text)
              return None
-
+
      except requests.RequestException as e:
          print(f"Request failed: {e}")
          return None

- def convert_to_openai_format(gemini_models):
-     openai_format = {
-         "object": "list",
-         "data": []
-     }
-
-     for model in gemini_models.get('models', []):
+ def convert_to_openai_models_format(gemini_models):
+     openai_format = {"object": "list", "data": []}
+
+     for model in gemini_models.get("models", []):
          openai_model = {
-             "id": model['name'].split('/')[-1],  # use the last path segment as the ID
+             "id": model["name"].split("/")[-1],  # use the last path segment as the ID
              "object": "model",
              "created": int(datetime.now(timezone.utc).timestamp()),  # use the current timestamp
              "owned_by": "google",  # assume all Gemini models are owned by Google
              "permission": [],  # the Gemini API may have no direct equivalent for permissions
-             "root": model['name'],
+             "root": model["name"],
              "parent": None,  # the Gemini API may have no direct equivalent for a parent model
          }
          openai_format["data"].append(openai_model)
-
+
      return openai_format
-
+
+
+ def convert_messages_to_gemini_format(messages):
+     """Convert OpenAI message format to Gemini format"""
+     gemini_messages = []
+     for message in messages:
+         gemini_message = {
+             "role": "user" if message["role"] == "user" else "model",
+             "parts": [{"text": message["content"]}],
+         }
+         gemini_messages.append(gemini_message)
+     return gemini_messages
+
+
+ def convert_gemini_response_to_openai(response, model, stream=False):
+     """Convert a Gemini response to OpenAI format"""
+     if stream:
+         # Handle a streaming chunk
+         chunk = response
+         if not chunk["candidates"]:
+             return None
+
+         return {
+             "id": "chatcmpl-" + str(uuid.uuid4()),
+             "object": "chat.completion.chunk",
+             "created": int(time.time()),
+             "model": model,
+             "choices": [
+                 {
+                     "index": 0,
+                     "delta": {
+                         "content": chunk["candidates"][0]["content"]["parts"][0]["text"]
+                     },
+                     "finish_reason": None,
+                 }
+             ],
+         }
+     else:
+         # Handle a regular (non-streaming) response
+         return {
+             "id": "chatcmpl-" + str(uuid.uuid4()),
+             "object": "chat.completion",
+             "created": int(time.time()),
+             "model": model,
+             "choices": [
+                 {
+                     "index": 0,
+                     "message": {
+                         "role": "assistant",
+                         "content": response["candidates"][0]["content"]["parts"][0]["text"],
+                     },
+                     "finish_reason": "stop",
+                 }
+             ],
+             "usage": {"prompt_tokens": 0, "completion_tokens": 0, "total_tokens": 0},
+         }
+

  @app.get("/v1/models")
  @app.get("/hf/v1/models")
  async def list_models(authorization: str = Header(None)):
      await verify_authorization(authorization)
-     async with key_lock:
-         api_key = next(key_cycle)
-         logger.info(f"Using API key: {api_key}")
+     api_key = await get_next_working_key()
+     logger.info(f"Using API key: {api_key}")
      try:
          response = get_gemini_models(api_key)
          logger.info("Successfully retrieved models list")
@@ -129,44 +237,125 @@ async def list_models(authorization: str = Header(None)):
  @app.post("/hf/v1/chat/completions")
  async def chat_completion(request: ChatRequest, authorization: str = Header(None)):
      await verify_authorization(authorization)
-     async with key_lock:
-         api_key = next(key_cycle)
-         logger.info(f"Using API key: {api_key}")
-
-     try:
-         logger.info(f"Chat completion request - Model: {request.model}")
-         client = openai.OpenAI(api_key=api_key, base_url=config.settings.BASE_URL)
-         response = client.chat.completions.create(
-             model=request.model,
-             messages=request.messages,
-             temperature=request.temperature,
-             stream=request.stream if hasattr(request, "stream") else False,
-         )
-
-         if hasattr(request, "stream") and request.stream:
-             logger.info("Streaming response enabled")
-
-             async def generate():
-                 for chunk in response:
-                     yield f"data: {chunk.model_dump_json()}\n\n"
-
-             return StreamingResponse(content=generate(), media_type="text/event-stream")
-
-         logger.info("Chat completion successful")
-         return response
-
-     except Exception as e:
-         logger.error(f"Error in chat completion: {str(e)}")
-         raise HTTPException(status_code=500, detail=str(e))
+     api_key = await get_next_working_key()
+     logger.info(f"Chat completion request - Model: {request.model}")
+     retries = 0
+
+     while retries < MAX_RETRIES:
+         try:
+             logger.info(f"Attempt {retries + 1} with API key: {api_key}")
+
+             if request.model in config.settings.MODEL_SEARCH:
+                 # Gemini API branch
+                 gemini_messages = convert_messages_to_gemini_format(request.messages)
+                 # Call the Gemini API
+                 payload = {
+                     "contents": gemini_messages,
+                     "generationConfig": {
+                         "temperature": request.temperature,
+                     },
+                     "tools": [{"googleSearch": {}}],
+                 }
+
+                 if request.stream:
+                     logger.info("Streaming response enabled")
+
+                     async def generate():
+                         nonlocal api_key, retries
+                         while retries < MAX_RETRIES:
+                             try:
+                                 async with httpx.AsyncClient() as client:
+                                     stream_url = f"https://generativelanguage.googleapis.com/v1beta/models/{request.model}:streamGenerateContent?alt=sse&key={api_key}"
+                                     async with client.stream("POST", stream_url, json=payload) as response:
+                                         if response.status_code == 429:
+                                             logger.warning(f"Rate limit reached for key: {api_key}")
+                                             api_key = await handle_api_failure(api_key)
+                                             logger.info(f"Retrying with new API key: {api_key}")
+                                             retries += 1
+                                             if retries >= MAX_RETRIES:
+                                                 yield f"data: {json.dumps({'error': 'Max retries reached'})}\n\n"
+                                                 break
+                                             continue
+
+                                         if response.status_code != 200:
+                                             logger.error(f"Error in streaming response: {response.status_code}")
+                                             yield f"data: {json.dumps({'error': f'API error: {response.status_code}'})}\n\n"
+                                             break
+
+                                         async for line in response.aiter_lines():
+                                             if line.startswith("data: "):
+                                                 try:
+                                                     chunk = json.loads(line[6:])
+                                                     openai_chunk = convert_gemini_response_to_openai(
+                                                         chunk, request.model, stream=True
+                                                     )
+                                                     if openai_chunk:
+                                                         yield f"data: {json.dumps(openai_chunk)}\n\n"
+                                                 except json.JSONDecodeError:
+                                                     continue
+                                         yield "data: [DONE]\n\n"
+                                         return
+                             except Exception as e:
+                                 logger.error(f"Stream error: {str(e)}")
+                                 api_key = await handle_api_failure(api_key)
+                                 retries += 1
+                                 if retries >= MAX_RETRIES:
+                                     yield f"data: {json.dumps({'error': 'Max retries reached'})}\n\n"
+                                     break
+                                 continue
+
+                     return StreamingResponse(content=generate(), media_type="text/event-stream")
+                 else:
+                     # Non-streaming response
+                     async with httpx.AsyncClient() as client:
+                         non_stream_url = f"https://generativelanguage.googleapis.com/v1beta/models/{request.model}:generateContent?key={api_key}"
+                         response = await client.post(non_stream_url, json=payload)
+                         gemini_response = response.json()
+                         logger.info("Chat completion successful")
+                         return convert_gemini_response_to_openai(gemini_response, request.model)
+
+             # OpenAI-compatible API branch
+             client = openai.OpenAI(api_key=api_key, base_url=config.settings.BASE_URL)
+             response = client.chat.completions.create(
+                 model=request.model,
+                 messages=request.messages,
+                 temperature=request.temperature,
+                 stream=request.stream if hasattr(request, "stream") else False,
+             )
+
+             if hasattr(request, "stream") and request.stream:
+                 logger.info("Streaming response enabled")
+
+                 async def generate():
+                     for chunk in response:
+                         yield f"data: {chunk.model_dump_json()}\n\n"
+                     logger.info("Chat completion successful")
+                 return StreamingResponse(content=generate(), media_type="text/event-stream")
+
+             logger.info("Chat completion successful")
+             return response
+
+         except Exception as e:
+             logger.error(f"Error in chat completion: {str(e)}")
+             api_key = await handle_api_failure(api_key)
+             retries += 1
+
+             if retries >= MAX_RETRIES:
+                 logger.error("Max retries reached, giving up")
+                 raise HTTPException(status_code=500, detail="Max retries reached with all available API keys")
+
+             logger.info(f"Retrying with new API key: {api_key}")
+             continue
+
+     raise HTTPException(status_code=500, detail="Unexpected error in chat completion")


  @app.post("/v1/embeddings")
  @app.post("/hf/v1/embeddings")
  async def embedding(request: EmbeddingRequest, authorization: str = Header(None)):
      await verify_authorization(authorization)
-     async with key_lock:
-         api_key = next(key_cycle)
-         logger.info(f"Using API key: {api_key}")
+     api_key = await get_next_working_key()
+     logger.info(f"Using API key: {api_key}")

      try:
          client = openai.OpenAI(api_key=api_key, base_url=config.settings.BASE_URL)
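Because both branches keep the OpenAI wire format, the retry loop and the Gemini search branch are invisible to callers. A sketch of a streaming call against this proxy with the openai SDK (the host, port, and key are placeholders; /hf/v1 is the prefix this diff registers):

# Client-side sketch; values below are assumptions, not part of the commit.
import openai

client = openai.OpenAI(
    api_key="your-proxy-key",               # checked by verify_authorization
    base_url="http://localhost:8000/hf/v1",
)
stream = client.chat.completions.create(
    model="gemini-1.5-flash-002",           # a MODEL_SEARCH model would get the googleSearch tool
    messages=[{"role": "user", "content": "Hello"}],
    stream=True,
)
for chunk in stream:
    delta = chunk.choices[0].delta.content
    if delta:
        print(delta, end="", flush=True)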
@@ -186,4 +375,4 @@ async def health_check():


  if __name__ == "__main__":
-     uvicorn.run(app, host="0.0.0.0", port=8000)
+     uvicorn.run(app, host="0.0.0.0", port=8000)