import streamlit as st
import torch
import numpy as np
from transformers import BertTokenizer, BertModel
import pdfplumber
from sklearn.metrics.pairwise import cosine_similarity
# Load the pre-trained BERT model and tokenizer once
model_name = "bert-base-uncased"
tokenizer = BertTokenizer.from_pretrained(model_name)
model = BertModel.from_pretrained(model_name)
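# Optional: Streamlit reruns this whole script on every interaction, so a cached
# loader would avoid reloading BERT each time. A minimal sketch, assuming a recent
# Streamlit with st.cache_resource (the load_model name is illustrative):
#
# @st.cache_resource
# def load_model(name="bert-base-uncased"):
#     return BertTokenizer.from_pretrained(name), BertModel.from_pretrained(name)
#
# tokenizer, model = load_model()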
# Function to get BERT embeddings
def get_embeddings(text):
    # Check if input text is empty
    if not text.strip():
        raise ValueError("Input text is empty.")
    # Ensure that text length does not exceed BERT's maximum input length (512 tokens)
    inputs = tokenizer.encode_plus(
        text,
        add_special_tokens=True,
        max_length=512,
        truncation=True,  # Truncate the text to the maximum length
        return_attention_mask=True,
        return_tensors='pt'
    )
    with torch.no_grad():  # Disable gradient calculation for inference
        outputs = model(**inputs)
    # Extract the [CLS] token embedding from the last hidden state
    if hasattr(outputs, 'last_hidden_state'):
        return outputs.last_hidden_state[:, 0, :].detach().cpu().numpy()  # Move to CPU before converting to NumPy
    else:
        raise ValueError("Model output does not contain 'last_hidden_state'.")
# Extract text from PDF
def extract_text_from_pdf(pdf_file):
    with pdfplumber.open(pdf_file) as pdf:
        text = ""
        for page in pdf.pages:
            page_text = page.extract_text()
            if page_text:  # Check if page text is not empty
                text += page_text + "\n"  # Add newline for better separation
            else:
                st.warning("No extractable text found on a page.")
    return text
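# Note: pdfplumber only reads text embedded in the PDF, so scanned pages yield
# nothing. A rough per-page OCR fallback sketch, assuming pytesseract and the
# Tesseract binary are installed (not part of this app's requirements):
#
# import pytesseract
# ocr_text = pytesseract.image_to_string(page.to_image(resolution=300).original)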
# Split text into sentences for better matching
def split_text_into_sentences(text):
    # Split by newlines and drop blank lines, which get_embeddings would reject
    return [line for line in text.split('\n') if line.strip()]
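# Newline splitting is crude, since PDF line breaks rarely match sentence
# boundaries. A sketch of a proper sentence splitter, assuming nltk is installed
# and its 'punkt' data has been downloaded:
#
# from nltk.tokenize import sent_tokenize
# sentences = [s for s in sent_tokenize(text) if s.strip()]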
# Streamlit app
st.title("PDF Chatbot using BERT")

# PDF file upload
pdf_file = st.file_uploader("Upload a PDF file", type=["pdf"])

# Store the PDF text, sentences, and embeddings
# (pdf_sentences must exist even when no file is uploaded,
# since the button handler below checks it)
pdf_text = ""
pdf_sentences = []
pdf_embeddings = None
if pdf_file:
    pdf_text = extract_text_from_pdf(pdf_file)
    # Check if the extracted text is empty
    if not pdf_text.strip():
        st.error("The extracted PDF text is empty. Please upload a PDF with extractable text.")
    else:
        try:
            pdf_sentences = split_text_into_sentences(pdf_text)  # Split PDF text into sentences
            # Each embedding has shape (1, hidden_size); stack them into an
            # (N, hidden_size) matrix so cosine_similarity can consume it
            pdf_embeddings = np.vstack([get_embeddings(sentence) for sentence in pdf_sentences])
            st.success("PDF loaded successfully!")
        except Exception as e:
            st.error(f"Error while processing PDF: {e}")
# User input for chatbot
user_input = st.text_input("Ask a question about the PDF:")

if st.button("Get Response"):
    if not pdf_sentences:
        st.warning("Please upload a PDF file first.")
    elif not user_input.strip():
        st.warning("Please enter a question.")
    else:
        try:
            user_embeddings = get_embeddings(user_input)
            user_embeddings = user_embeddings.reshape(1, -1)  # Shape (1, hidden_size) for cosine similarity
            # Calculate cosine similarity between the question and each PDF sentence embedding
            similarities = cosine_similarity(user_embeddings, pdf_embeddings)
            best_match_index = np.argmax(similarities)  # Index of the best match
            # Display the most relevant sentence
            st.write("### Response:")
            st.write(pdf_sentences[best_match_index])
        except Exception as e:
            st.error(f"Error while processing user input: {e}")