Update main.py
main.py CHANGED
@@ -1,12 +1,11 @@
 from fastapi import FastAPI, HTTPException
-from typing import Any
+from typing import Any, Dict
 from pydantic import BaseModel
 from os import getenv
 from huggingface_hub import InferenceClient
 import random
 from json_repair import repair_json
 import nltk
-import sys
 
 app = FastAPI()
 
@@ -26,7 +25,7 @@ class InputData(BaseModel):
     history: str = ""
 
 @app.post("/generate-response/")
-async def generate_response(data: InputData) -> Any:
+async def generate_response(data: InputData) -> Dict[str, Any]:
     client = InferenceClient(model=data.model, token=HF_TOKEN)
 
     sentences = tokenizer.tokenize(data.user_input)
@@ -38,27 +37,27 @@ async def generate_response(data: InputData) -> Any:
     data.history += data.prompt_template.replace("{Prompt}", str(data_dict))
 
     inputs = (
-        data.system_prompt_template.replace("{SystemPrompt}",
-                                            data.system_prompt) +
+        data.system_prompt_template.replace("{SystemPrompt}", data.system_prompt) +
         data.system_prompt_template.replace("{SystemPrompt}", data.json_prompt) +
-        data.history)
+        data.history
+    )
 
     seed = random.randint(0, 2**32 - 1)
 
     try:
-        response = client.text_generation(
-
-
-
+        response = client.text_generation(
+            inputs,
+            temperature=1.0,
+            max_new_tokens=1000,
+            seed=seed
+        )
 
         strict_response = str(response)
 
-        repaired_response = repair_json(strict_response,
-                                        return_objects=True)
+        repaired_response = repair_json(strict_response, return_objects=True)
 
         if isinstance(repaired_response, str):
             raise HTTPException(status_code=500, detail="Invalid response from model")
-
         else:
             cleaned_response = {}
             for key, value in repaired_response.items():
@@ -72,6 +71,7 @@ async def generate_response(data: InputData) -> Any:
                         cleaned_response["New response"][i] = sentences[0]
                     else:
                         del cleaned_response["New response"][i]
+
             if cleaned_response.get("Sentence count"):
                 if cleaned_response["Sentence count"] > 3:
                     cleaned_response["Sentence count"] = 3
@@ -80,7 +80,10 @@ async def generate_response(data: InputData) -> Any:
 
             data.history += str(cleaned_response)
 
-            return
+            return {
+                "response": cleaned_response,
+                "history": data.history
+            }
 
     except Exception as e:
         print(f"Model {data.model} failed with error: {e}")
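For reference, a minimal sketch of how a client might exercise the updated endpoint and the new return shape. The request fields mirror the InputData attributes the handler reads (model, user_input, prompt_template, system_prompt_template, system_prompt, json_prompt, history); the example values, the local URL, and the exact set of required fields are assumptions for illustration, not taken from the Space's configuration.

# Hypothetical smoke test for the updated /generate-response/ route.
# Values below are placeholders; only the field names come from the diff.
import requests

payload = {
    "model": "mistralai/Mistral-7B-Instruct-v0.2",    # assumed model id
    "user_input": "Tell me something interesting about otters.",
    "prompt_template": "[INST] {Prompt} [/INST]",      # assumed template shape
    "system_prompt_template": "<<SYS>> {SystemPrompt} <</SYS>>\n",
    "system_prompt": "You are a concise assistant.",
    "json_prompt": "Reply as JSON with the keys 'New response' and 'Sentence count'.",
    "history": "",
}

r = requests.post("http://localhost:8000/generate-response/", json=payload, timeout=120)
r.raise_for_status()

body = r.json()
# After this commit the handler returns Dict[str, Any] with two keys
# instead of a bare value:
print(body["response"])   # cleaned, JSON-repaired model output
print(body["history"])    # accumulated prompt history including this exchange

Returning the accumulated history alongside the cleaned response lets the caller feed it back into the next request's history field instead of reconstructing it client-side.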