Update main.py
main.py CHANGED
@@ -7,7 +7,6 @@ import random
 from json_repair import repair_json
 import nltk
 import sys
-sys.setrecursionlimit(10000) # Set a higher recursion limit
 
 app = FastAPI()
 
@@ -22,6 +21,8 @@ FALLBACK_MODELS = [
     "mistralai/Mistral-7B-Instruct-v0.2", "mistralai/Mistral-7B-Instruct-v0.1"
 ]
 
+MAX_RETRIES = 3 # Maximum number of retries
+
 class InputData(BaseModel):
     model: str
     system_prompt_template: str
@@ -52,45 +53,48 @@ async def generate_response(data: InputData) -> Any:
     seed = random.randint(0, 2**32 - 1)
 
     models_to_try = [data.model] + FALLBACK_MODELS
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-                for key, value in repaired_response.items():
-                    cleaned_key = key.replace("###", "")
-                    cleaned_response[cleaned_key] = value
-
-                for i, text in enumerate(cleaned_response["New response"]):
-                    if i <= 2:
-                        sentences = tokenizer.tokenize(text)
-                        if sentences:
-                            cleaned_response["New response"][i] = sentences[0]
-                        else:
-                            del cleaned_response["New response"][i]
-                if cleaned_response.get("Sentence count"):
-                    if cleaned_response["Sentence count"] > 3:
-                        cleaned_response["Sentence count"] = 3
-                else:
-                    cleaned_response
+    retries = 0
+
+    while retries < MAX_RETRIES:
+        for model in models_to_try:
+            try:
+                response = client.text_generation(inputs,
+                                                  temperature=1.0,
+                                                  max_new_tokens=1000,
+                                                  seed=seed)
+
+                strict_response = str(response)
+
+                repaired_response = repair_json(strict_response,
+                                                return_objects=True)
+
+                if isinstance(repaired_response, str):
+                    raise HTTPException(status_code=500, detail="Invalid response from model")
+                else:
+                    cleaned_response = {}
+                    for key, value in repaired_response.items():
+                        cleaned_key = key.replace("###", "")
+                        cleaned_response[cleaned_key] = value
+
+                    for i, text in enumerate(cleaned_response["New response"]):
+                        if i <= 2:
+                            sentences = tokenizer.tokenize(text)
+                            if sentences:
+                                cleaned_response["New response"][i] = sentences[0]
+                            else:
+                                del cleaned_response["New response"][i]
+                    if cleaned_response.get("Sentence count"):
+                        if cleaned_response["Sentence count"] > 3:
+                            cleaned_response["Sentence count"] = 3
+                    else:
+                        cleaned_response["Sentence count"] = len(cleaned_response["New response"])
 
-
+                    data.history += str(cleaned_response)
 
-
+                    return cleaned_response
 
-
-
+            except Exception as e:
+                print(f"Model {model} failed with error: {e}")
+                retries += 1
 
-    raise HTTPException(status_code=500, detail="All models failed to generate response")
+    raise HTTPException(status_code=500, detail="All models failed to generate response after maximum retries")