Uhhy committed on
Commit ebc22be · verified · 1 Parent(s): c8e35b7

Update app.py

Files changed (1)
  1. app.py +15 -14
app.py CHANGED
@@ -1,4 +1,3 @@
-
 from fastapi import FastAPI, HTTPException, Request
 from pydantic import BaseModel
 from llama_cpp import Llama
@@ -6,12 +5,19 @@ from concurrent.futures import ThreadPoolExecutor, as_completed
 import uvicorn
 import re
 from dotenv import load_dotenv
-import spaces
+from spaces import GPU
 
 load_dotenv()
 
 app = FastAPI()
 
+# Initialize ZeroGPU
+try:
+    GPU.initialize()
+except Exception as e:
+    print(f"ZeroGPU initialization failed: {e}")
+
+# Global data dictionary
 global_data = {
     'models': {},
     'tokens': {
@@ -57,7 +63,7 @@ class ModelManager:
             return {"model": Llama.from_pretrained(repo_id=model_config['repo_id'], filename=model_config['filename']), "name": model_config['name']}
         except Exception as e:
             print(f"Error loading model {model_config['name']}: {e}")
-            pass
+            return None
 
     def load_all_models(self):
         if self.loaded:
@@ -77,7 +83,6 @@ class ModelManager:
             return global_data['models']
         except Exception as e:
             print(f"Error loading models: {e}")
-            pass
             return {}
 
 model_manager = ModelManager()
@@ -112,28 +117,24 @@ def remove_repetitive_responses(responses):
         normalized_response = remove_duplicates(response['response'])
         if normalized_response not in seen:
             seen.add(normalized_response)
-
-
             unique_responses.append({'model': response['model'], 'response': normalized_response})
     return unique_responses
 
-@app.post("/chat/")
-@spaces.GPU(duration=0)
-async def chat(request: ChatRequest):
+@app.post("/generate/")
+async def generate(request: ChatRequest):
     try:
         normalized_message = normalize_input(request.message)
         with ThreadPoolExecutor() as executor:
             futures = [executor.submit(model.generate, f"<s>[INST]{normalized_message} [/INST]",
                                        top_k=request.top_k, top_p=request.top_p, temperature=request.temperature)
                        for model in global_data['models'].values()]
-            responses = []
-            for future, model_name in zip(as_completed(futures), global_data['models'].keys()):
-                response = future.result()
-                responses.append({'model': model_name, 'response': response})
+            responses = [{'model': model, 'response': future.result()}
+                         for model, future in zip(global_data['models'].keys(), as_completed(futures))]
+
             unique_responses = remove_repetitive_responses(responses)
         return unique_responses
     except Exception as e:
-        raise HTTPException(status_code=500, detail=f"An error occurred: {e}")
+        raise HTTPException(status_code=500, detail=f"Error generating responses: {e}")
 
 if __name__ == "__main__":
     uvicorn.run(app, host="0.0.0.0", port=8000)
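
A minimal client sketch for the renamed endpoint: this commit moves the route from POST /chat/ to POST /generate/, so a caller would hit it roughly as below. The message, top_k, top_p, and temperature field names are assumed from their use inside the handler (the ChatRequest model itself is not part of this diff, and its defaults are not shown); the host and port come from the uvicorn.run call in the __main__ block.

import requests

# Hypothetical request body matching the parameters the handler reads
# from ChatRequest; field names and values are assumptions for illustration.
payload = {
    "message": "Explain what a context window is.",
    "top_k": 40,
    "top_p": 0.95,
    "temperature": 0.7,
}

# Server started via `python app.py` listens on 0.0.0.0:8000 per the diff.
resp = requests.post("http://localhost:8000/generate/", json=payload, timeout=120)
resp.raise_for_status()

# The handler returns a list of {'model': ..., 'response': ...} dicts
# after remove_repetitive_responses() drops duplicate outputs.
for item in resp.json():
    print(item["model"], "->", item["response"][:80])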