snailyp commited on
Commit
fd4bd23
1 Parent(s): 379f9a0

Update main.py

Browse files
Files changed (1) hide show
  1. main.py +40 -2
main.py CHANGED
@@ -60,6 +60,45 @@ async def verify_authorization(authorization: str = Header(None)):
60
  return token
61
 
62
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
63
  @app.get("/v1/models")
64
  @app.get("/hf/v1/models")
65
  async def list_models(authorization: str = Header(None)):
@@ -68,8 +107,7 @@ async def list_models(authorization: str = Header(None)):
68
  api_key = next(key_cycle)
69
  logger.info(f"Using API key: {api_key[:8]}...")
70
  try:
71
- client = openai.OpenAI(api_key=api_key, base_url=config.settings.BASE_URL)
72
- response = client.models.list()
73
  logger.info("Successfully retrieved models list")
74
  return response
75
  except Exception as e:
 
60
  return token
61
 
62
 
63
+ def get_gemini_models(api_key):
64
+ base_url = "https://generativelanguage.googleapis.com/v1beta"
65
+ url = f"{base_url}/models?key={api_key}"
66
+
67
+ try:
68
+ response = requests.get(url)
69
+ if response.status_code == 200:
70
+ gemini_models = response.json()
71
+ return convert_to_openai_format(gemini_models)
72
+ else:
73
+ print(f"Error: {response.status_code}")
74
+ print(response.text)
75
+ return None
76
+
77
+ except requests.RequestException as e:
78
+ print(f"Request failed: {e}")
79
+ return None
80
+
81
+ def convert_to_openai_format(gemini_models):
82
+ openai_format = {
83
+ "object": "list",
84
+ "data": []
85
+ }
86
+
87
+ for model in gemini_models.get('models', []):
88
+ openai_model = {
89
+ "id": model['name'].split('/')[-1], # 取最后一部分作为ID
90
+ "object": "model",
91
+ "created": int(datetime.now(timezone.utc).timestamp()), # 使用当前时间戳
92
+ "owned_by": "google", # 假设所有Gemini模型都由Google拥有
93
+ "permission": [], # Gemini API可能没有直接对应的权限信息
94
+ "root": model['name'],
95
+ "parent": None, # Gemini API可能没有直接对应的父模型信息
96
+ }
97
+ openai_format["data"].append(openai_model)
98
+
99
+ return openai_format
100
+
101
+
102
  @app.get("/v1/models")
103
  @app.get("/hf/v1/models")
104
  async def list_models(authorization: str = Header(None)):
 
107
  api_key = next(key_cycle)
108
  logger.info(f"Using API key: {api_key[:8]}...")
109
  try:
110
+ response = get_gemini_models(api_key)
 
111
  logger.info("Successfully retrieved models list")
112
  return response
113
  except Exception as e: