ajaynagotha commited on
Commit
cd4ab5c
·
verified ·
1 Parent(s): 385bfcf

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +54 -20
app.py CHANGED
@@ -1,6 +1,6 @@
1
  import gradio as gr
2
  from datasets import load_dataset
3
- from transformers import AutoTokenizer, AutoModelForQuestionAnswering
4
  import torch
5
  import logging
6
  from fastapi import FastAPI, HTTPException
@@ -18,9 +18,9 @@ logger.info("Dataset loaded successfully")
18
 
19
  # Load model and tokenizer
20
  logger.info("Loading the model and tokenizer")
21
- model_name = "deepset/roberta-large-squad2"
22
- tokenizer = AutoTokenizer.from_pretrained(model_name)
23
- model = AutoModelForQuestionAnswering.from_pretrained(model_name)
24
  logger.info("Model and tokenizer loaded successfully")
25
 
26
  def clean_answer(answer):
@@ -34,23 +34,57 @@ def answer_question(question):
34
  logger.info("Combining text from dataset")
35
  context = " ".join([item.get('Text', '') for item in ds['train']])
36
  logger.info(f"Combined context length: {len(context)} characters")
37
- logger.info("Tokenizing input")
38
- inputs = tokenizer.encode_plus(question, context, return_tensors="pt", max_length=2048, truncation=True)
39
- logger.info(f"Input tokens shape: {inputs['input_ids'].shape}")
40
- logger.info("Getting model output")
41
- outputs = model(**inputs)
42
- logger.info(f"Output logits shapes: start={outputs.start_logits.shape}, end={outputs.end_logits.shape}")
43
- logger.info("Processing output to get answer")
44
- answer_start = torch.argmax(outputs.start_logits)
45
- answer_end = torch.argmax(outputs.end_logits) + 1
46
- raw_answer = tokenizer.convert_tokens_to_string(tokenizer.convert_ids_to_tokens(inputs["input_ids"][0][answer_start:answer_end]))
47
- answer = clean_answer(raw_answer)
48
- logger.info(f"Generated answer: {answer}")
49
- if not answer:
50
- logger.warning("Generated answer was empty after cleaning")
51
- answer = "I'm sorry, but I couldn't find a specific answer to that question based on the Bhagavad Gita. Could you please rephrase your question or ask about one of the core concepts like dharma, karma, bhakti, or the different types of yoga discussed in the Gita?"
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
52
  logger.info("Answer generated successfully")
53
- return answer
54
  except Exception as e:
55
  logger.error(f"Error in answer_question function: {str(e)}")
56
  return "I'm sorry, but an error occurred while processing your question. Please try again later."
 
1
  import gradio as gr
2
  from datasets import load_dataset
3
+ from transformers import BartForQuestionAnswering, BartTokenizer
4
  import torch
5
  import logging
6
  from fastapi import FastAPI, HTTPException
 
18
 
19
  # Load model and tokenizer
20
  logger.info("Loading the model and tokenizer")
21
+ model_name = "facebook/bart-large-cnn"
22
+ tokenizer = BartTokenizer.from_pretrained(model_name)
23
+ model = BartForQuestionAnswering.from_pretrained(model_name)
24
  logger.info("Model and tokenizer loaded successfully")
25
 
26
  def clean_answer(answer):
 
34
  logger.info("Combining text from dataset")
35
  context = " ".join([item.get('Text', '') for item in ds['train']])
36
  logger.info(f"Combined context length: {len(context)} characters")
37
+
38
+ # Implement sliding window approach
39
+ max_length = 1024
40
+ stride = 512
41
+ answers = []
42
+ for i in range(0, len(context), stride):
43
+ chunk = context[i:i+max_length]
44
+
45
+ inputs = tokenizer.encode_plus(
46
+ question,
47
+ chunk,
48
+ return_tensors="pt",
49
+ max_length=max_length,
50
+ truncation=True,
51
+ padding='max_length'
52
+ )
53
+
54
+ inputs = {k: v.to(model.device) for k, v in inputs.items()}
55
+
56
+ logger.info(f"Input tokens shape: {inputs['input_ids'].shape}")
57
+
58
+ logger.info("Getting model output")
59
+ with torch.no_grad():
60
+ outputs = model(**inputs)
61
+
62
+ logger.info(f"Output logits shapes: start={outputs.start_logits.shape}, end={outputs.end_logits.shape}")
63
+
64
+ answer_start = torch.argmax(outputs.start_logits)
65
+ answer_end = torch.argmax(outputs.end_logits) + 1
66
+
67
+ ans = tokenizer.convert_tokens_to_string(
68
+ tokenizer.convert_ids_to_tokens(inputs["input_ids"][0][answer_start:answer_end])
69
+ )
70
+
71
+ score = torch.max(outputs.start_logits) + torch.max(outputs.end_logits)
72
+ answers.append((ans, score.item()))
73
+
74
+ # Select best answer
75
+ best_answer = max(answers, key=lambda x: x[1])[0]
76
+
77
+ # Post-processing
78
+ best_answer = clean_answer(best_answer)
79
+ best_answer = best_answer.capitalize()
80
+
81
+ logger.info(f"Generated answer: {best_answer}")
82
+ if not best_answer or len(best_answer) < 5:
83
+ logger.warning("Generated answer was empty or too short after cleaning")
84
+ best_answer = "I'm sorry, but I couldn't find a specific answer to that question based on the Bhagavad Gita. Could you please rephrase your question or ask about one of the core concepts like dharma, karma, bhakti, or the different types of yoga discussed in the Gita?"
85
+
86
  logger.info("Answer generated successfully")
87
+ return best_answer
88
  except Exception as e:
89
  logger.error(f"Error in answer_question function: {str(e)}")
90
  return "I'm sorry, but an error occurred while processing your question. Please try again later."