# summarizer/app.py
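"""Biomedical Research Summarizer (Streamlit app).

Loads a BioBART model fine-tuned via a PEFT adapter, generates a summary for
each uploaded paper's abstract, ranks papers against a research question with
TF-IDF cosine similarity, and produces an overall summary of the top hits.
"""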
import streamlit as st
import pandas as pd
import torch
from transformers import AutoModelForSeq2SeqLM, AutoTokenizer
from peft import PeftModel
from sklearn.feature_extraction.text import TfidfVectorizer
from sklearn.metrics.pairwise import cosine_similarity
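# NOTE: the `use_auth_token` arguments below follow the older transformers
# API; newer releases deprecate it in favor of `token=`. Treat the exact
# version cutoff as an assumption about your environment.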
class ResearchSummarizer:
    def __init__(self):
        self.device = "cuda" if torch.cuda.is_available() else "cpu"
        # Both names point at the same repo: the adapter weights are loaded
        # on top of the base checkpoint from "pendar02/biobart-finetune".
        self.base_model_name = "pendar02/biobart-finetune"
        self.finetuned_model_name = "pendar02/biobart-finetune"

        try:
            self.tokenizer = AutoTokenizer.from_pretrained(
                self.base_model_name,
                use_auth_token=False,
                local_files_only=False
            )
            self.model = AutoModelForSeq2SeqLM.from_pretrained(
                self.base_model_name,
                use_auth_token=False,
                local_files_only=False,
                torch_dtype=torch.float32
            ).to(self.device)

            # Load the PEFT (LoRA) adapter on top of the base model
            self.model = PeftModel.from_pretrained(
                self.model,
                self.finetuned_model_name,
                use_auth_token=False,
                local_files_only=False
            ).to(self.device)
            # Unwrap to the adapter-injected base model so .generate() is
            # called on it directly; the adapter weights stay active.
            self.model = self.model.base_model
        except Exception as e:
            st.error(f"Error loading models: {str(e)}")
            raise

        self.vectorizer = TfidfVectorizer(stop_words='english')
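    # An alternative to unwrapping via .base_model is PEFT's merge_and_unload(),
    # which folds the adapter deltas into the base weights and removes the
    # wrapper entirely. A sketch, assuming a LoRA-style adapter ("base" and
    # "adapter_repo" are placeholder names):
    #
    #   peft_model = PeftModel.from_pretrained(base, adapter_repo)
    #   model = peft_model.merge_and_unload()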
    def summarize_text(self, text, max_length=150):
        """Generate a summary of `text` using the fine-tuned model.

        `max_length` is a token budget, not a character count.
        """
        try:
            # Skip summarization for very short texts
            text = text.strip()
            if len(text) < 50:
                return text

            # Tokenize with truncation to the model's 512-token input limit
            inputs = self.tokenizer(
                text,
                return_tensors="pt",
                max_length=512,
                truncation=True,
                padding=True
            ).to(self.device)

            # Generate with beam search; no_repeat_ngram_size=3 blocks
            # repeated trigrams in the output
            with torch.no_grad():
                outputs = self.model.generate(
                    inputs.input_ids,
                    attention_mask=inputs.attention_mask,
                    max_length=max_length,
                    min_length=25,
                    num_beams=4,
                    length_penalty=1.0,
                    no_repeat_ngram_size=3,
                    early_stopping=True
                )
            summary = self.tokenizer.decode(outputs[0], skip_special_tokens=True)
            return summary
        except Exception as e:
            st.error(f"Summarization error: {str(e)}")
            return f"Error generating summary: {str(e)}"
    def calculate_relevance_scores(self, papers, research_question):
        """Score each paper's abstract against the research question."""
        # Fit TF-IDF over the question plus all abstracts, then compare the
        # question vector (row 0) against every abstract vector
        all_texts = [research_question] + [paper['Abstract'] for paper in papers]
        tfidf_matrix = self.vectorizer.fit_transform(all_texts)
        cosine_similarities = cosine_similarity(tfidf_matrix[0:1], tfidf_matrix[1:])
        return cosine_similarities[0]
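    # For TF-IDF vectors a and b, cosine_similarity computes
    # sim(a, b) = (a · b) / (||a|| * ||b||). Because TF-IDF entries are
    # non-negative, every relevance score lies in [0, 1], with higher values
    # meaning closer to the research question's vocabulary.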
    def process_papers(self, df, research_question):
        """Process papers and generate summaries"""
        try:
            # Convert DataFrame rows to paper dictionaries
            papers = []
            for _, row in df.iterrows():
                paper = {
                    "Abstract": str(row.get('Abstract', '')),
                    "title": str(row.get('Article Title', '')),
                    "authors": str(row.get('Authors', '')),
                    "source": str(row.get('Source Title', '')),
                    "year": str(row.get('Publication Year', '')),
                    "doi": str(row.get('DOI', '')),
                }
                papers.append(paper)

            # Filter out empty abstracts
            papers = [paper for paper in papers if paper['Abstract'].strip() != ""]
            if not papers:
                return {
                    "research_question": research_question,
                    "papers": [],
                    "overall_summary": "No valid papers found for analysis",
                    "paper_count": 0
                }

            # Generate individual summaries
            with st.spinner("Summarizing abstracts..."):
                for paper in papers:
                    paper['summary'] = self.summarize_text(paper['Abstract'])

            # Calculate relevance scores
            relevance_scores = self.calculate_relevance_scores(papers, research_question)
            for i, paper in enumerate(papers):
                paper['relevance_score'] = float(relevance_scores[i])

            # Sort papers by relevance, most relevant first
            papers = sorted(papers, key=lambda x: x['relevance_score'], reverse=True)

            # Generate overall summary from the five most relevant papers
            overall_summary = self.generate_overall_summary(papers[:5], research_question)

            return {
                "research_question": research_question,
                "papers": papers,
                "overall_summary": overall_summary,
                "paper_count": len(papers)
            }
        except Exception as e:
            st.error(f"Error processing papers: {str(e)}")
            return None
    def generate_overall_summary(self, top_papers, research_question):
        """Generate an overall summary from the top-ranked papers"""
        if not top_papers:
            return "No relevant papers found for summary generation"

        # Concatenate the per-paper summaries into one context document,
        # then summarize that
        context = f"Research Question: {research_question}\n\nKey Findings:\n"
        for paper in top_papers:
            context += f"- From {paper['authors']} ({paper['year']}): {paper['summary']}\n"

        try:
            final_summary = self.summarize_text(context, max_length=200)
            return final_summary
        except Exception as e:
            st.error(f"Error generating overall summary: {str(e)}")
            return "Error generating overall summary"
# Streamlit App Configuration
# (st.set_page_config must be the first Streamlit call executed in the script)
st.set_page_config(
    page_title="Biomedical Research Summarizer",
    page_icon="🧬",
    layout="wide"
)
# App Title and Description
st.title("🧬 Biomedical Research Summarizer")
st.markdown("""
This tool analyzes biomedical research papers to:

* Generate AI-powered summaries of each paper
* Find papers most relevant to your research question
* Create an overall summary of key findings

Required Excel columns:

* Abstract
* Article Title
* Authors
* Source Title
* Publication Year
* DOI
""")
# Initialize summarizer
if 'summarizer' not in st.session_state:
with st.spinner('Loading AI models... This may take a few minutes.'):
try:
st.session_state.summarizer = ResearchSummarizer()
st.success('Models loaded successfully!')
except Exception as e:
st.error(f"Failed to initialize models: {str(e)}")
st.stop()
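# Alternative caching sketch, assuming Streamlit >= 1.18 (which introduced
# st.cache_resource): a cached loader would share one model instance across
# sessions instead of reloading per browser session.
#
#   @st.cache_resource
#   def load_summarizer():
#       return ResearchSummarizer()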
# File upload and research question
uploaded_file = st.file_uploader("Upload Excel File", type=["xlsx", "xls"])
research_question = st.text_area(
    "Research Question",
    placeholder="Enter your research question here..."
)
if st.button("Analyze Papers"):
    if not uploaded_file or not research_question.strip():
        st.error("Please provide both a file and a research question.")
    else:
        try:
            with st.spinner("Processing papers... This may take a few minutes."):
                df = pd.read_excel(uploaded_file)
                results = st.session_state.summarizer.process_papers(df, research_question)

            if results:
                st.subheader("📋 Research Question")
                st.write(results['research_question'])

                st.subheader("📊 Overall Summary")
                st.write(results['overall_summary'])

                st.subheader(f"📚 Top Papers (Total: {results['paper_count']})")
                for paper in results['papers'][:5]:
                    with st.expander(
                        f"📄 {paper['title']} (Relevance: {paper['relevance_score']:.2f})"
                    ):
                        st.markdown(f"**Authors:** {paper['authors']}")
                        st.markdown(f"**Source:** {paper['source']}")
                        st.markdown(f"**Year:** {paper['year']}")
                        st.markdown(f"**DOI:** {paper['doi']}")
                        st.markdown("**Summary:**")
                        st.markdown(paper['summary'])
        except Exception as e:
            st.error(f"An error occurred: {str(e)}")
            st.error("Please check your Excel file format and try again.")