Spaces:
Runtime error
Runtime error
Rodolfo Torres
commited on
Commit
·
11f8b48
1
Parent(s):
97c4064
Fixing boolean answers
Browse files- main.py +29 -7
- static/js/app.js +9 -1
main.py
CHANGED
@@ -8,12 +8,14 @@ from fastapi.responses import JSONResponse
|
|
8 |
from io import BytesIO
|
9 |
import PyPDF2
|
10 |
from newspaper import Article
|
|
|
|
|
11 |
|
12 |
-
|
13 |
-
|
14 |
-
|
15 |
-
|
16 |
-
)
|
17 |
|
18 |
app = FastAPI()
|
19 |
|
@@ -27,6 +29,19 @@ class TextInfo(BaseModel):
|
|
27 |
|
28 |
class QuestionInfo(BaseModel):
|
29 |
question: str
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
30 |
|
31 |
@app.post("/store_text/{uuid}")
|
32 |
async def store_text(uuid: str, text_info: TextInfo):
|
@@ -97,14 +112,21 @@ async def upload_file(uuid: str, file: UploadFile = File(...)):
|
|
97 |
|
98 |
@app.post("/answer_question/{uuid}")
|
99 |
async def answer_question(uuid: str, question_info: QuestionInfo):
|
|
|
|
|
100 |
question = question_info.question
|
101 |
|
102 |
# Verifica si el texto con el ID existe en el diccionario
|
103 |
if uuid not in text_storage:
|
104 |
return {'error': 'Text not found'}
|
105 |
|
106 |
-
|
107 |
-
|
|
|
|
|
|
|
|
|
|
|
108 |
|
109 |
|
110 |
app.mount("/", StaticFiles(directory="static", html=True), name="static")
|
|
|
8 |
from io import BytesIO
|
9 |
import PyPDF2
|
10 |
from newspaper import Article
|
11 |
+
import torch
|
12 |
+
from transformers import AutoModelForMultipleChoice, AutoTokenizer
|
13 |
|
14 |
+
qa_pipeline = pipeline("question-answering", model="roaltopo/scan-u-doc_question-answer")
|
15 |
+
bool_q_pipeline = pipeline("text-classification", model="roaltopo/scan-u-doc_bool-question")
|
16 |
+
model_path = "roaltopo/scan-u-doc_bool-answer"
|
17 |
+
bool_a_tokenizer = AutoTokenizer.from_pretrained(model_path)
|
18 |
+
bool_a_model = AutoModelForMultipleChoice.from_pretrained(model_path)
|
19 |
|
20 |
app = FastAPI()
|
21 |
|
|
|
29 |
|
30 |
class QuestionInfo(BaseModel):
|
31 |
question: str
|
32 |
+
allow_bool: Optional[bool] = False
|
33 |
+
|
34 |
+
def predict_boolean_answer(text, question):
|
35 |
+
id2label = {0: "NO", 1: "YES"}
|
36 |
+
text += '\n'
|
37 |
+
question += '\n'
|
38 |
+
inputs = bool_a_tokenizer([[text, question+'no'], [text, question+'yes']], return_tensors="pt", padding=True)
|
39 |
+
labels = torch.tensor(0).unsqueeze(0)
|
40 |
+
|
41 |
+
outputs = bool_a_model(**{k: v.unsqueeze(0) for k, v in inputs.items()}, labels=labels)
|
42 |
+
logits = outputs.logits
|
43 |
+
|
44 |
+
return {'answer': id2label[int(logits.argmax().item())]}
|
45 |
|
46 |
@app.post("/store_text/{uuid}")
|
47 |
async def store_text(uuid: str, text_info: TextInfo):
|
|
|
112 |
|
113 |
@app.post("/answer_question/{uuid}")
|
114 |
async def answer_question(uuid: str, question_info: QuestionInfo):
|
115 |
+
bool_activate = question_info.allow_bool
|
116 |
+
|
117 |
question = question_info.question
|
118 |
|
119 |
# Verifica si el texto con el ID existe en el diccionario
|
120 |
if uuid not in text_storage:
|
121 |
return {'error': 'Text not found'}
|
122 |
|
123 |
+
answer = qa_pipeline(question=question, context=text_storage[uuid]['text'])
|
124 |
+
if bool_activate :
|
125 |
+
is_bool_inference = bool_q_pipeline(question)
|
126 |
+
if is_bool_inference[0]['label'] == 'YES' :
|
127 |
+
answer = predict_boolean_answer(answer['answer'], question)
|
128 |
+
|
129 |
+
return answer
|
130 |
|
131 |
|
132 |
app.mount("/", StaticFiles(directory="static", html=True), name="static")
|
static/js/app.js
CHANGED
@@ -202,9 +202,17 @@ async function getResponse(prompt) {
|
|
202 |
|
203 |
try {
|
204 |
let question = array_messages[array_messages.length - 1].content;
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
205 |
// Datos a enviar al servidor
|
206 |
var questionData = {
|
207 |
-
question: question
|
|
|
208 |
};
|
209 |
|
210 |
//console.log(message);
|
|
|
202 |
|
203 |
try {
|
204 |
let question = array_messages[array_messages.length - 1].content;
|
205 |
+
let curr_settings = getSettings();
|
206 |
+
|
207 |
+
allow_bool = false;
|
208 |
+
if(curr_settings['answersToggle']){
|
209 |
+
allow_bool = true;
|
210 |
+
}
|
211 |
+
|
212 |
// Datos a enviar al servidor
|
213 |
var questionData = {
|
214 |
+
question: question,
|
215 |
+
allow_bool: allow_bool,
|
216 |
};
|
217 |
|
218 |
//console.log(message);
|