Update main.py
Browse files
main.py
CHANGED
@@ -1,5 +1,5 @@
|
|
1 |
from fastapi import FastAPI, HTTPException
|
2 |
-
from typing import Any, Dict, Optional
|
3 |
from pydantic import BaseModel
|
4 |
from os import getenv
|
5 |
from huggingface_hub import InferenceClient
|
@@ -18,11 +18,11 @@ HF_TOKEN = getenv("HF_TOKEN")
|
|
18 |
|
19 |
class InputData(BaseModel):
|
20 |
model: str
|
21 |
-
system_prompt_template: str
|
22 |
-
prompt_template: str
|
23 |
end_token: str
|
24 |
-
system_prompt: str
|
25 |
-
user_input: str
|
26 |
history: str = ""
|
27 |
segment: bool = False
|
28 |
max_sentences: Optional[int] = None
|
@@ -36,24 +36,29 @@ async def generate_response(data: InputData) -> Dict[str, Any]:
|
|
36 |
if data.max_sentences is not None and data.max_sentences != 0:
|
37 |
data.segment = True
|
38 |
elif data.max_sentences == 0:
|
39 |
-
|
|
|
|
|
40 |
return {
|
41 |
"response": "",
|
42 |
"history": data.history + data.end_token
|
43 |
}
|
44 |
|
|
|
45 |
if data.segment:
|
46 |
-
|
47 |
-
|
|
|
48 |
else:
|
49 |
-
user_input_str = data.user_input
|
50 |
|
51 |
-
|
|
|
52 |
|
53 |
-
inputs =
|
54 |
-
|
55 |
-
data.
|
56 |
-
|
57 |
|
58 |
seed = random.randint(0, 2**32 - 1)
|
59 |
|
@@ -111,4 +116,4 @@ async def check_word(data: WordCheckData) -> Dict[str, Any]:
|
|
111 |
"found": found
|
112 |
}
|
113 |
|
114 |
-
return result
|
|
|
1 |
from fastapi import FastAPI, HTTPException
|
2 |
+
from typing import Any, Dict, List, Optional
|
3 |
from pydantic import BaseModel
|
4 |
from os import getenv
|
5 |
from huggingface_hub import InferenceClient
|
|
|
18 |
|
19 |
class InputData(BaseModel):
|
20 |
model: str
|
21 |
+
system_prompt_template: List[str]
|
22 |
+
prompt_template: List[str]
|
23 |
end_token: str
|
24 |
+
system_prompt: List[str]
|
25 |
+
user_input: List[str]
|
26 |
history: str = ""
|
27 |
segment: bool = False
|
28 |
max_sentences: Optional[int] = None
|
|
|
36 |
if data.max_sentences is not None and data.max_sentences != 0:
|
37 |
data.segment = True
|
38 |
elif data.max_sentences == 0:
|
39 |
+
for prompt in data.prompt_template:
|
40 |
+
for user_input in data.user_input:
|
41 |
+
data.history += prompt.replace("{Prompt}", user_input) + "\n"
|
42 |
return {
|
43 |
"response": "",
|
44 |
"history": data.history + data.end_token
|
45 |
}
|
46 |
|
47 |
+
user_input_str = ""
|
48 |
if data.segment:
|
49 |
+
for user_input in data.user_input:
|
50 |
+
user_sentences = tokenizer.tokenize(user_input)
|
51 |
+
user_input_str += "\n".join(user_sentences) + "\n"
|
52 |
else:
|
53 |
+
user_input_str = "\n".join(data.user_input)
|
54 |
|
55 |
+
for prompt in data.prompt_template:
|
56 |
+
data.history += prompt.replace("{Prompt}", user_input_str) + "\n"
|
57 |
|
58 |
+
inputs = ""
|
59 |
+
for system_prompt in data.system_prompt_template:
|
60 |
+
inputs += system_prompt.replace("{SystemPrompt}", "\n".join(data.system_prompt)) + "\n"
|
61 |
+
inputs += data.history
|
62 |
|
63 |
seed = random.randint(0, 2**32 - 1)
|
64 |
|
|
|
116 |
"found": found
|
117 |
}
|
118 |
|
119 |
+
return result
|