Rodolfo Torres commited on
Commit
11f8b48
·
1 Parent(s): 97c4064

Fixing boolean answers

Browse files
Files changed (2) hide show
  1. main.py +29 -7
  2. static/js/app.js +9 -1
main.py CHANGED
@@ -8,12 +8,14 @@ from fastapi.responses import JSONResponse
8
  from io import BytesIO
9
  import PyPDF2
10
  from newspaper import Article
 
 
11
 
12
- model_name = "roaltopo/scan-u-doc_question-answer"
13
- qa_pipeline = pipeline(
14
- "question-answering",
15
- model=model_name,
16
- )
17
 
18
  app = FastAPI()
19
 
@@ -27,6 +29,19 @@ class TextInfo(BaseModel):
27
 
28
  class QuestionInfo(BaseModel):
29
  question: str
 
 
 
 
 
 
 
 
 
 
 
 
 
30
 
31
  @app.post("/store_text/{uuid}")
32
  async def store_text(uuid: str, text_info: TextInfo):
@@ -97,14 +112,21 @@ async def upload_file(uuid: str, file: UploadFile = File(...)):
97
 
98
  @app.post("/answer_question/{uuid}")
99
  async def answer_question(uuid: str, question_info: QuestionInfo):
 
 
100
  question = question_info.question
101
 
102
  # Verifica si el texto con el ID existe en el diccionario
103
  if uuid not in text_storage:
104
  return {'error': 'Text not found'}
105
 
106
- r = qa_pipeline(question=question, context=text_storage[uuid]['text'], top_k=10)
107
- return r[0]
 
 
 
 
 
108
 
109
 
110
  app.mount("/", StaticFiles(directory="static", html=True), name="static")
 
8
  from io import BytesIO
9
  import PyPDF2
10
  from newspaper import Article
11
+ import torch
12
+ from transformers import AutoModelForMultipleChoice, AutoTokenizer
13
 
14
+ qa_pipeline = pipeline("question-answering", model="roaltopo/scan-u-doc_question-answer")
15
+ bool_q_pipeline = pipeline("text-classification", model="roaltopo/scan-u-doc_bool-question")
16
+ model_path = "roaltopo/scan-u-doc_bool-answer"
17
+ bool_a_tokenizer = AutoTokenizer.from_pretrained(model_path)
18
+ bool_a_model = AutoModelForMultipleChoice.from_pretrained(model_path)
19
 
20
  app = FastAPI()
21
 
 
29
 
30
  class QuestionInfo(BaseModel):
31
  question: str
32
+ allow_bool: Optional[bool] = False
33
+
34
+ def predict_boolean_answer(text, question):
35
+ id2label = {0: "NO", 1: "YES"}
36
+ text += '\n'
37
+ question += '\n'
38
+ inputs = bool_a_tokenizer([[text, question+'no'], [text, question+'yes']], return_tensors="pt", padding=True)
39
+ labels = torch.tensor(0).unsqueeze(0)
40
+
41
+ outputs = bool_a_model(**{k: v.unsqueeze(0) for k, v in inputs.items()}, labels=labels)
42
+ logits = outputs.logits
43
+
44
+ return {'answer': id2label[int(logits.argmax().item())]}
45
 
46
  @app.post("/store_text/{uuid}")
47
  async def store_text(uuid: str, text_info: TextInfo):
 
112
 
113
  @app.post("/answer_question/{uuid}")
114
  async def answer_question(uuid: str, question_info: QuestionInfo):
115
+ bool_activate = question_info.allow_bool
116
+
117
  question = question_info.question
118
 
119
  # Verifica si el texto con el ID existe en el diccionario
120
  if uuid not in text_storage:
121
  return {'error': 'Text not found'}
122
 
123
+ answer = qa_pipeline(question=question, context=text_storage[uuid]['text'])
124
+ if bool_activate :
125
+ is_bool_inference = bool_q_pipeline(question)
126
+ if is_bool_inference[0]['label'] == 'YES' :
127
+ answer = predict_boolean_answer(answer['answer'], question)
128
+
129
+ return answer
130
 
131
 
132
  app.mount("/", StaticFiles(directory="static", html=True), name="static")
static/js/app.js CHANGED
@@ -202,9 +202,17 @@ async function getResponse(prompt) {
202
 
203
  try {
204
  let question = array_messages[array_messages.length - 1].content;
 
 
 
 
 
 
 
205
  // Datos a enviar al servidor
206
  var questionData = {
207
- question: question
 
208
  };
209
 
210
  //console.log(message);
 
202
 
203
  try {
204
  let question = array_messages[array_messages.length - 1].content;
205
+ let curr_settings = getSettings();
206
+
207
+ allow_bool = false;
208
+ if(curr_settings['answersToggle']){
209
+ allow_bool = true;
210
+ }
211
+
212
  // Datos a enviar al servidor
213
  var questionData = {
214
+ question: question,
215
+ allow_bool: allow_bool,
216
  };
217
 
218
  //console.log(message);