Update main.py
main.py
CHANGED
@@ -21,8 +21,6 @@ FALLBACK_MODELS = [
     "mistralai/Mistral-7B-Instruct-v0.2", "mistralai/Mistral-7B-Instruct-v0.1"
 ]
 
-MAX_RETRIES = 3 # Maximum number of retries
-
 class InputData(BaseModel):
     model: str
     system_prompt_template: str

@@ -53,48 +51,45 @@ async def generate_response(data: InputData) -> Any:
     seed = random.randint(0, 2**32 - 1)
 
     models_to_try = [data.model] + FALLBACK_MODELS
-                    if sentences:
-                        cleaned_response["New response"][i] = sentences[0]
-                    else:
-                        del cleaned_response["New response"][i]
-                if cleaned_response.get("Sentence count"):
-                    if cleaned_response["Sentence count"] > 3:
-                        cleaned_response["Sentence count"] = 3
-                else:
-            retries += 1
-    raise HTTPException(status_code=500, detail="All models failed to generate response")
+
+    for model in models_to_try:
+        try:
+            response = client.text_generation(inputs,
+                                              temperature=1.0,
+                                              max_new_tokens=1000,
+                                              seed=seed)
+
+            strict_response = str(response)
+
+            repaired_response = repair_json(strict_response,
+                                            return_objects=True)
+
+            if isinstance(repaired_response, str):
+                raise HTTPException(status_code=500, detail="Invalid response from model")
+            else:
+                cleaned_response = {}
+                for key, value in repaired_response.items():
+                    cleaned_key = key.replace("###", "")
+                    cleaned_response[cleaned_key] = value
+
+                for i, text in enumerate(cleaned_response["New response"]):
+                    if i <= 2:
+                        sentences = tokenizer.tokenize(text)
+                        if sentences:
+                            cleaned_response["New response"][i] = sentences[0]
+                        else:
+                            del cleaned_response["New response"][i]
+                if cleaned_response.get("Sentence count"):
+                    if cleaned_response["Sentence count"] > 3:
+                        cleaned_response["Sentence count"] = 3
+                else:
+                    cleaned_response["Sentence count"] = len(cleaned_response["New response"])
+
+                data.history += str(cleaned_response)
+
+                return cleaned_response
+
+        except Exception as e:
+            print(f"Model {model} failed with error: {e}")
+
+    raise HTTPException(status_code=500, detail="All models failed to generate response")
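
For reference, a minimal standalone sketch of the fallback flow this change introduces, assuming client is a huggingface_hub InferenceClient and repair_json comes from the json_repair package; generate_with_fallback is an illustrative name, not something defined in main.py:

# Rough sketch only: assumes huggingface_hub's InferenceClient and the
# json_repair package; generate_with_fallback is an illustrative helper.
from huggingface_hub import InferenceClient
from json_repair import repair_json

FALLBACK_MODELS = [
    "mistralai/Mistral-7B-Instruct-v0.2", "mistralai/Mistral-7B-Instruct-v0.1"
]

client = InferenceClient()

def generate_with_fallback(prompt: str, preferred_model: str, seed: int):
    # Try the requested model first, then each fallback in order.
    for model in [preferred_model] + FALLBACK_MODELS:
        try:
            raw = client.text_generation(prompt,
                                         model=model,
                                         temperature=1.0,
                                         max_new_tokens=1000,
                                         seed=seed)
            repaired = repair_json(str(raw), return_objects=True)
            if isinstance(repaired, str):
                # repair_json yields a plain string when no JSON object can be built
                raise ValueError("model output did not repair into a JSON object")
            return repaired
        except Exception as e:
            # Any failure (HTTP error, unparseable output) moves on to the next model.
            print(f"Model {model} failed with error: {e}")
    raise RuntimeError("All models failed to generate a response")

This mirrors the new loop in generate_response: the first model that returns output repairable into an object wins, and only after every model in the list has failed does the caller see an error (a 500 HTTPException in the endpoint itself).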