Spaces:
Running
Running
File size: 5,138 Bytes
960a587 0a61a36 960a587 fd4bd23 960a587 fd4bd23 960a587 4d8ee18 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 |
from fastapi import FastAPI, HTTPException, Header, Request
from fastapi.middleware.cors import CORSMiddleware
from fastapi.responses import StreamingResponse
from pydantic import BaseModel
import openai
from typing import List, Optional
import logging
from itertools import cycle
import asyncio
import uvicorn
from app import config
import requests
from datetime import datetime, timezone
# Configure logging: INFO level with timestamped records.
logging.basicConfig(
    level=logging.INFO, format="%(asctime)s - %(levelname)s - %(message)s"
)
logger = logging.getLogger(__name__)
app = FastAPI()
# Allow cross-origin requests from any origin, with any method and header.
# NOTE(review): wildcard origins combined with allow_credentials=True is very
# permissive — confirm this is intended for this deployment.
app.add_middleware(
    CORSMiddleware,
    allow_origins=["*"],
    allow_credentials=True,
    allow_methods=["*"],
    allow_headers=["*"],
)
# Upstream API keys, read from the project config.
API_KEYS = config.settings.API_KEYS
# Round-robin iterator over the configured keys; request handlers take the
# next key from it for every upstream call.
key_cycle = cycle(API_KEYS)
# Serializes access to key_cycle from the request handlers.
key_lock = asyncio.Lock()
class ChatRequest(BaseModel):
    """Request body for the OpenAI-compatible chat completions endpoint."""

    # Chat messages, forwarded unchanged to the upstream API.
    # Presumably OpenAI-style {"role": ..., "content": ...} dicts — confirm.
    messages: List[dict]
    # Model identifier forwarded to the upstream API.
    model: str = "llama-3.2-90b-text-preview"
    # Sampling temperature forwarded to the upstream API.
    temperature: Optional[float] = 0.7
    # When true, the endpoint answers with a server-sent-event stream.
    stream: Optional[bool] = False
async def verify_authorization(authorization: str = Header(None)):
    """Validate an ``Authorization: Bearer <token>`` header.

    Args:
        authorization: Raw ``Authorization`` header value, or ``None`` when
            the header is absent.

    Returns:
        The bearer token, once it is confirmed to be listed in
        ``config.settings.ALLOWED_TOKENS``.

    Raises:
        HTTPException: 401 when the header is missing, does not start with
            ``"Bearer "``, or carries a token that is not allowed.
    """
    prefix = "Bearer "
    if not authorization:
        logger.error("Missing Authorization header")
        raise HTTPException(status_code=401, detail="Missing Authorization header")
    if not authorization.startswith(prefix):
        logger.error("Invalid Authorization header format")
        raise HTTPException(
            status_code=401, detail="Invalid Authorization header format"
        )
    # Strip only the leading scheme. str.replace would also delete any
    # "Bearer " substring occurring inside the token itself.
    token = authorization[len(prefix):]
    if token not in config.settings.ALLOWED_TOKENS:
        logger.error("Invalid token")
        raise HTTPException(status_code=401, detail="Invalid token")
    return token
def get_gemini_models(api_key, timeout=10):
    """Fetch the Gemini model list and convert it to OpenAI format.

    Args:
        api_key: Gemini API key used for this request.
        timeout: Seconds to wait for the upstream server. Without a timeout
            ``requests`` can block indefinitely on a stalled connection.

    Returns:
        The OpenAI-style model list produced by ``convert_to_openai_format``,
        or ``None`` if the request failed, returned a non-200 status, or the
        body was not valid JSON.
    """
    url = "https://generativelanguage.googleapis.com/v1beta/models"
    try:
        # Send the key in a header rather than the query string, so it does
        # not appear in URLs echoed by exception messages and logs.
        response = requests.get(
            url, headers={"x-goog-api-key": api_key}, timeout=timeout
        )
        if response.status_code != 200:
            logger.error(
                f"Error fetching Gemini models: {response.status_code} "
                f"{response.text}"
            )
            return None
        gemini_models = response.json()
    except (requests.RequestException, ValueError) as e:
        # ValueError covers JSON decode failures on requests versions where
        # that error is not a RequestException subclass.
        logger.error(f"Request for Gemini models failed: {e}")
        return None
    return convert_to_openai_format(gemini_models)
def convert_to_openai_format(gemini_models):
    """Convert a Gemini model-list response into an OpenAI-style model list.

    Args:
        gemini_models: Parsed JSON from the Gemini ``/models`` endpoint. Only
            its ``"models"`` list and each entry's ``"name"`` are read
            (e.g. ``"models/gemini-pro"``). ``None`` is treated as an empty
            response.

    Returns:
        A dict shaped like OpenAI's model listing:
        ``{"object": "list", "data": [<model dict>, ...]}``.
    """
    # The upstream data carries no creation time here, so every entry is
    # stamped with "now" — computed once so a single listing is consistent.
    created = int(datetime.now(timezone.utc).timestamp())
    models = (gemini_models or {}).get("models", [])
    data = [
        {
            # "models/gemini-pro" -> "gemini-pro"
            "id": model["name"].split("/")[-1],
            "object": "model",
            "created": created,
            "owned_by": "google",  # all models from this API attributed to Google
            "permission": [],  # no equivalent permission info is mapped
            "root": model["name"],
            "parent": None,  # no equivalent parent-model info is mapped
        }
        for model in models
    ]
    return {"object": "list", "data": data}
@app.get("/v1/models")
@app.get("/hf/v1/models")
async def list_models(authorization: str = Header(None)):
    """List available models in an OpenAI-compatible format.

    Authenticates the caller, takes the next upstream API key in round-robin
    order, and fetches the Gemini model list with it.

    Raises:
        HTTPException: 401 from ``verify_authorization``; 500 when no API keys
            are configured or fetching fails unexpectedly; 502 when the
            upstream model list could not be retrieved.
    """
    await verify_authorization(authorization)
    async with key_lock:
        try:
            api_key = next(key_cycle)
        except StopIteration:
            # cycle() over an empty API_KEYS list is exhausted immediately;
            # letting StopIteration escape a coroutine becomes a RuntimeError.
            logger.error("No API keys configured")
            raise HTTPException(
                status_code=500, detail="No API keys configured"
            ) from None
    logger.info(f"Using API key: {api_key[:8]}...")
    try:
        # get_gemini_models uses blocking `requests`; run it in the default
        # thread pool so the event loop keeps serving other requests.
        loop = asyncio.get_running_loop()
        response = await loop.run_in_executor(None, get_gemini_models, api_key)
    except Exception as e:
        logger.error(f"Error listing models: {str(e)}")
        raise HTTPException(status_code=500, detail=str(e)) from e
    if response is None:
        # get_gemini_models signals upstream failure by returning None.
        logger.error("Error listing models: upstream request failed")
        raise HTTPException(
            status_code=502, detail="Failed to retrieve models from upstream"
        )
    logger.info("Successfully retrieved models list")
    return response
@app.post("/v1/chat/completions")
@app.post("/hf/v1/chat/completions")
async def chat_completion(request: ChatRequest, authorization: str = Header(None)):
    """Proxy an OpenAI-style chat completion to the configured upstream.

    Authenticates the caller, takes the next API key in round-robin order,
    and forwards the request to ``config.settings.BASE_URL`` with the async
    OpenAI client, so the upstream call does not block the event loop.

    Returns:
        The upstream completion object; or, when ``request.stream`` is true,
        a ``text/event-stream`` response with one ``data: <chunk JSON>`` event
        per chunk, terminated by ``data: [DONE]``.

    Raises:
        HTTPException: 401 from ``verify_authorization``; 500 when no API keys
            are configured or the upstream request fails before streaming
            begins.
    """
    await verify_authorization(authorization)
    async with key_lock:
        try:
            api_key = next(key_cycle)
        except StopIteration:
            # cycle() over an empty API_KEYS list is exhausted immediately;
            # letting StopIteration escape a coroutine becomes a RuntimeError.
            logger.error("No API keys configured")
            raise HTTPException(
                status_code=500, detail="No API keys configured"
            ) from None
    logger.info(f"Using API key: {api_key[:8]}...")
    logger.info(f"Chat completion request - Model: {request.model}")
    stream = bool(request.stream)

    client = None
    try:
        client = openai.AsyncOpenAI(
            api_key=api_key, base_url=config.settings.BASE_URL
        )
        response = await client.chat.completions.create(
            model=request.model,
            messages=request.messages,
            temperature=request.temperature,
            stream=stream,
        )
    except Exception as e:
        if client is not None:
            await client.close()
        logger.error(f"Error in chat completion: {str(e)}")
        raise HTTPException(status_code=500, detail=str(e)) from e

    if not stream:
        await client.close()
        logger.info("Chat completion successful")
        return response

    logger.info("Streaming response enabled")

    async def generate():
        try:
            async for chunk in response:
                yield f"data: {chunk.model_dump_json()}\n\n"
            # OpenAI-compatible clients treat this sentinel as end-of-stream.
            yield "data: [DONE]\n\n"
        except Exception as e:
            # Headers are already sent, so the failure can only be logged; the
            # stream ends without the [DONE] sentinel.
            logger.error(f"Error in chat completion stream: {str(e)}")
            raise
        finally:
            # Release the client's connection pool once streaming ends,
            # whether it finished, failed, or the caller disconnected.
            await client.close()

    return StreamingResponse(content=generate(), media_type="text/event-stream")
@app.get("/health")
@app.get("/")
async def health_check():
    """Report liveness; served on both ``/health`` and the root path."""
    logger.info("Health check endpoint called")
    payload = {"status": "healthy"}
    return payload
if __name__ == "__main__":
    # When executed directly, serve the app on all interfaces at port 8000.
    uvicorn.run(app, host="0.0.0.0", port=8000)