Spaces:
Running
Running
Update main.py
Browse files
main.py
CHANGED
@@ -25,6 +25,9 @@ from fastapi.staticfiles import StaticFiles
|
|
25 |
|
26 |
from bearer_token import BearerTokenGenerator
|
27 |
|
|
|
|
|
|
|
28 |
# 模型列表
|
29 |
MODELS = ["gpt-4o", "gpt-4o-mini", "claude-3-5-sonnet", "claude"]
|
30 |
|
@@ -121,6 +124,18 @@ def is_base64_image(url: str) -> bool:
|
|
121 |
"""
|
122 |
return url.startswith("data:image/")
|
123 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
124 |
# 根路径GET请求处理
|
125 |
@app.get("/", response_class=HTMLResponse)
|
126 |
async def read_root():
|
@@ -140,7 +155,7 @@ async def read_root():
|
|
140 |
|
141 |
# 聊天完成处理
|
142 |
@app.post("/ai/v1/chat/completions")
|
143 |
-
async def chat_completions(request: Request, background_tasks: BackgroundTasks):
|
144 |
"""
|
145 |
处理聊天完成请求
|
146 |
"""
|
@@ -380,7 +395,7 @@ async def chat_completions(request: Request, background_tasks: BackgroundTasks):
|
|
380 |
|
381 |
# 图像生成处理
|
382 |
@app.post("/ai/v1/images/generations")
|
383 |
-
async def images_generations(request: Request):
|
384 |
"""
|
385 |
处理图像生成请求
|
386 |
"""
|
@@ -544,10 +559,13 @@ def main():
|
|
544 |
parser.add_argument('--base_url', type=str, default='http://localhost', help='Base URL for accessing images')
|
545 |
parser.add_argument('--port', type=int, default=INITIAL_PORT, help='服务器监听端口')
|
546 |
args = parser.parse_args()
|
547 |
-
|
548 |
base_url = args.base_url
|
549 |
port = args.port
|
550 |
|
|
|
|
|
|
|
|
|
551 |
# 确保 images 目录存在
|
552 |
if not os.path.exists("images"):
|
553 |
os.makedirs("images")
|
|
|
25 |
|
26 |
from bearer_token import BearerTokenGenerator
|
27 |
|
28 |
+
from fastapi import Depends, HTTPException, Security
|
29 |
+
from fastapi.security import HTTPAuthorizationCredentials, HTTPBearer
|
30 |
+
|
31 |
# 模型列表
|
32 |
MODELS = ["gpt-4o", "gpt-4o-mini", "claude-3-5-sonnet", "claude"]
|
33 |
|
|
|
124 |
"""
|
125 |
return url.startswith("data:image/")
|
126 |
|
127 |
+
# 添加 HTTPBearer 实例
|
128 |
+
security = HTTPBearer()
|
129 |
+
|
130 |
+
# 添加 API_KEY 验证函数
|
131 |
+
def verify_api_key(credentials: HTTPAuthorizationCredentials = Security(security)):
|
132 |
+
api_key = os.environ.get("API_KEY")
|
133 |
+
if api_key is None:
|
134 |
+
raise HTTPException(status_code=500, detail="API_KEY not set in environment variables")
|
135 |
+
if credentials.credentials != api_key:
|
136 |
+
raise HTTPException(status_code=401, detail="Invalid API key")
|
137 |
+
return credentials.credentials
|
138 |
+
|
139 |
# 根路径GET请求处理
|
140 |
@app.get("/", response_class=HTMLResponse)
|
141 |
async def read_root():
|
|
|
155 |
|
156 |
# 聊天完成处理
|
157 |
@app.post("/ai/v1/chat/completions")
|
158 |
+
async def chat_completions(request: Request, background_tasks: BackgroundTasks, api_key: str = Depends(verify_api_key)):
|
159 |
"""
|
160 |
处理聊天完成请求
|
161 |
"""
|
|
|
395 |
|
396 |
# 图像生成处理
|
397 |
@app.post("/ai/v1/images/generations")
|
398 |
+
async def images_generations(request: Request, api_key: str = Depends(verify_api_key)):
|
399 |
"""
|
400 |
处理图像生成请求
|
401 |
"""
|
|
|
559 |
parser.add_argument('--base_url', type=str, default='http://localhost', help='Base URL for accessing images')
|
560 |
parser.add_argument('--port', type=int, default=INITIAL_PORT, help='服务器监听端口')
|
561 |
args = parser.parse_args()
|
|
|
562 |
base_url = args.base_url
|
563 |
port = args.port
|
564 |
|
565 |
+
# 检查 API_KEY 是否设置
|
566 |
+
if not os.environ.get("API_KEY"):
|
567 |
+
print("警告: API_KEY 环境变量未设置。客户端验证将无法正常工作。")
|
568 |
+
|
569 |
# 确保 images 目录存在
|
570 |
if not os.path.exists("images"):
|
571 |
os.makedirs("images")
|