Ganesh43 commited on
Commit
4265c8b
1 Parent(s): e32de15

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +14 -11
app.py CHANGED
@@ -1,32 +1,35 @@
1
  import torch
2
  import streamlit as st
3
- from transformers import BertTokenizer, BertModel
4
 
5
- # Load the pre-trained model and tokenizer
6
  tokenizer = BertTokenizer.from_pretrained('bert-base-uncased')
7
- model = BertModel.from_pretrained("bert-base-uncased")
8
 
9
  def answer_query(question, context):
10
- # Preprocess the question and context using the tokenizer
11
  inputs = tokenizer(question, context, return_tensors="pt")
12
 
13
- # Use the model to get the answer
14
  with torch.no_grad():
15
  outputs = model(**inputs)
16
 
17
- # Access the logits from the model's output structure
18
- start_logits = outputs.hidden_states[-1][:, 0, :] # Access from hidden states
19
- end_logits = outputs.hidden_states[-1][:, 1, :]
20
 
21
- # Find the most likely answer span
22
  answer_start = torch.argmax(start_logits)
23
  answer_end = torch.argmax(end_logits) + 1
24
 
25
- # Extract the answer from the context
26
- answer = tokenizer.convert_tokens_to_string(context)[answer_start:answer_end]
 
 
27
 
28
  return answer
29
 
 
30
  # Streamlit app
31
  st.title("Question Answering App")
32
 
 
1
  import torch
2
  import streamlit as st
3
+ from transformers import BertTokenizer, BertForQuestionAnswering
4
 
5
+ # Utilize BertForQuestionAnswering model for direct start/end logits
6
  tokenizer = BertTokenizer.from_pretrained('bert-base-uncased')
7
+ model = BertForQuestionAnswering.from_pretrained("bert-base-uncased")
8
 
9
  def answer_query(question, context):
10
+ # Preprocess using tokenizer
11
  inputs = tokenizer(question, context, return_tensors="pt")
12
 
13
+ # Use model for question answering
14
  with torch.no_grad():
15
  outputs = model(**inputs)
16
 
17
+ # Retrieve logits directly
18
+ start_logits = outputs.start_logits
19
+ end_logits = outputs.end_logits
20
 
21
+ # Find answer span
22
  answer_start = torch.argmax(start_logits)
23
  answer_end = torch.argmax(end_logits) + 1
24
 
25
+ # Extract answer from context
26
+ answer = tokenizer.convert_tokens_to_string(
27
+ tokenizer.convert_ids_to_tokens(inputs["input_ids"][0]) # Access original tokens
28
+ )[answer_start:answer_end]
29
 
30
  return answer
31
 
32
+
33
  # Streamlit app
34
  st.title("Question Answering App")
35