ajaynagotha committed
Update app.py
app.py CHANGED
@@ -1,6 +1,6 @@
 import gradio as gr
 from datasets import load_dataset
-from transformers import
+from transformers import BartForQuestionAnswering, BartTokenizer
 import torch
 import logging
 from fastapi import FastAPI, HTTPException
@@ -18,9 +18,9 @@ logger.info("Dataset loaded successfully")
 
 # Load model and tokenizer
 logger.info("Loading the model and tokenizer")
-model_name = "
-tokenizer =
-model =
+model_name = "facebook/bart-large-cnn"
+tokenizer = BartTokenizer.from_pretrained(model_name)
+model = BartForQuestionAnswering.from_pretrained(model_name)
 logger.info("Model and tokenizer loaded successfully")
 
 def clean_answer(answer):
@@ -34,23 +34,57 @@ def answer_question(question):
         logger.info("Combining text from dataset")
         context = " ".join([item.get('Text', '') for item in ds['train']])
         logger.info(f"Combined context length: {len(context)} characters")
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
+
+        # Implement sliding window approach
+        max_length = 1024
+        stride = 512
+        answers = []
+        for i in range(0, len(context), stride):
+            chunk = context[i:i+max_length]
+
+            inputs = tokenizer.encode_plus(
+                question,
+                chunk,
+                return_tensors="pt",
+                max_length=max_length,
+                truncation=True,
+                padding='max_length'
+            )
+
+            inputs = {k: v.to(model.device) for k, v in inputs.items()}
+
+            logger.info(f"Input tokens shape: {inputs['input_ids'].shape}")
+
+            logger.info("Getting model output")
+            with torch.no_grad():
+                outputs = model(**inputs)
+
+            logger.info(f"Output logits shapes: start={outputs.start_logits.shape}, end={outputs.end_logits.shape}")
+
+            answer_start = torch.argmax(outputs.start_logits)
+            answer_end = torch.argmax(outputs.end_logits) + 1
+
+            ans = tokenizer.convert_tokens_to_string(
+                tokenizer.convert_ids_to_tokens(inputs["input_ids"][0][answer_start:answer_end])
+            )
+
+            score = torch.max(outputs.start_logits) + torch.max(outputs.end_logits)
+            answers.append((ans, score.item()))
+
+        # Select best answer
+        best_answer = max(answers, key=lambda x: x[1])[0]
+
+        # Post-processing
+        best_answer = clean_answer(best_answer)
+        best_answer = best_answer.capitalize()
+
+        logger.info(f"Generated answer: {best_answer}")
+        if not best_answer or len(best_answer) < 5:
+            logger.warning("Generated answer was empty or too short after cleaning")
+            best_answer = "I'm sorry, but I couldn't find a specific answer to that question based on the Bhagavad Gita. Could you please rephrase your question or ask about one of the core concepts like dharma, karma, bhakti, or the different types of yoga discussed in the Gita?"
+
         logger.info("Answer generated successfully")
-        return
+        return best_answer
     except Exception as e:
         logger.error(f"Error in answer_question function: {str(e)}")
        return "I'm sorry, but an error occurred while processing your question. Please try again later."
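The commit only touches the imports, the model setup, and answer_question; the rest of app.py, including the UI implied by the gradio and FastAPI imports at the top of the file, is not shown in this diff. For orientation, a minimal sketch of how such a function is typically exposed as a Space interface is given below. The labels, title, and launch call are assumptions for illustration and are not part of this commit.

import gradio as gr

# Hypothetical wiring, not taken from the repo: expose answer_question through a
# simple Gradio interface, as the `import gradio as gr` at the top of app.py suggests.
demo = gr.Interface(
    fn=answer_question,                   # function updated in this commit
    inputs=gr.Textbox(label="Question"),  # free-text question about the Gita
    outputs=gr.Textbox(label="Answer"),   # best-scoring answer from the sliding-window search
    title="Bhagavad Gita Q&A",            # assumed title
)

if __name__ == "__main__":
    demo.launch()

The FastAPI import suggests the Gradio app may instead be mounted onto a FastAPI instance in the full file, but that wiring is likewise outside the scope of this diff.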