smgc commited on
Commit
4e08475
·
verified ·
1 Parent(s): 7dc6e35

Update main.py

Browse files
Files changed (1) hide show
  1. main.py +21 -3
main.py CHANGED
@@ -25,6 +25,9 @@ from fastapi.staticfiles import StaticFiles
25
 
26
  from bearer_token import BearerTokenGenerator
27
 
 
 
 
28
  # 模型列表
29
  MODELS = ["gpt-4o", "gpt-4o-mini", "claude-3-5-sonnet", "claude"]
30
 
@@ -121,6 +124,18 @@ def is_base64_image(url: str) -> bool:
121
  """
122
  return url.startswith("data:image/")
123
 
 
 
 
 
 
 
 
 
 
 
 
 
124
  # 根路径GET请求处理
125
  @app.get("/", response_class=HTMLResponse)
126
  async def read_root():
@@ -140,7 +155,7 @@ async def read_root():
140
 
141
  # 聊天完成处理
142
  @app.post("/ai/v1/chat/completions")
143
- async def chat_completions(request: Request, background_tasks: BackgroundTasks):
144
  """
145
  处理聊天完成请求
146
  """
@@ -380,7 +395,7 @@ async def chat_completions(request: Request, background_tasks: BackgroundTasks):
380
 
381
  # 图像生成处理
382
  @app.post("/ai/v1/images/generations")
383
- async def images_generations(request: Request):
384
  """
385
  处理图像生成请求
386
  """
@@ -544,10 +559,13 @@ def main():
544
  parser.add_argument('--base_url', type=str, default='http://localhost', help='Base URL for accessing images')
545
  parser.add_argument('--port', type=int, default=INITIAL_PORT, help='服务器监听端口')
546
  args = parser.parse_args()
547
-
548
  base_url = args.base_url
549
  port = args.port
550
 
 
 
 
 
551
  # 确保 images 目录存在
552
  if not os.path.exists("images"):
553
  os.makedirs("images")
 
25
 
26
  from bearer_token import BearerTokenGenerator
27
 
28
+ from fastapi import Depends, HTTPException, Security
29
+ from fastapi.security import HTTPAuthorizationCredentials, HTTPBearer
30
+
31
  # 模型列表
32
  MODELS = ["gpt-4o", "gpt-4o-mini", "claude-3-5-sonnet", "claude"]
33
 
 
124
  """
125
  return url.startswith("data:image/")
126
 
127
+ # 添加 HTTPBearer 实例
128
+ security = HTTPBearer()
129
+
130
+ # 添加 API_KEY 验证函数
131
+ def verify_api_key(credentials: HTTPAuthorizationCredentials = Security(security)):
132
+ api_key = os.environ.get("API_KEY")
133
+ if api_key is None:
134
+ raise HTTPException(status_code=500, detail="API_KEY not set in environment variables")
135
+ if credentials.credentials != api_key:
136
+ raise HTTPException(status_code=401, detail="Invalid API key")
137
+ return credentials.credentials
138
+
139
  # 根路径GET请求处理
140
  @app.get("/", response_class=HTMLResponse)
141
  async def read_root():
 
155
 
156
  # 聊天完成处理
157
  @app.post("/ai/v1/chat/completions")
158
+ async def chat_completions(request: Request, background_tasks: BackgroundTasks, api_key: str = Depends(verify_api_key)):
159
  """
160
  处理聊天完成请求
161
  """
 
395
 
396
  # 图像生成处理
397
  @app.post("/ai/v1/images/generations")
398
+ async def images_generations(request: Request, api_key: str = Depends(verify_api_key)):
399
  """
400
  处理图像生成请求
401
  """
 
559
  parser.add_argument('--base_url', type=str, default='http://localhost', help='Base URL for accessing images')
560
  parser.add_argument('--port', type=int, default=INITIAL_PORT, help='服务器监听端口')
561
  args = parser.parse_args()
 
562
  base_url = args.base_url
563
  port = args.port
564
 
565
+ # 检查 API_KEY 是否设置
566
+ if not os.environ.get("API_KEY"):
567
+ print("警告: API_KEY 环境变量未设置。客户端验证将无法正常工作。")
568
+
569
  # 确保 images 目录存在
570
  if not os.path.exists("images"):
571
  os.makedirs("images")