import streamlit as st
import pandas as pd
import torch
from transformers import AutoModelForSeq2SeqLM, AutoTokenizer
from peft import PeftModel
from sklearn.feature_extraction.text import TfidfVectorizer
from sklearn.metrics.pairwise import cosine_similarity
class ResearchSummarizer:
    def __init__(self):
        self.device = "cuda" if torch.cuda.is_available() else "cpu"
        # Both names point at the same repo: the base weights and the
        # PEFT adapter are loaded from the same fine-tuned checkpoint
        self.base_model_name = "pendar02/biobart-finetune"
        self.finetuned_model_name = "pendar02/biobart-finetune"
        try:
            self.tokenizer = AutoTokenizer.from_pretrained(
                self.base_model_name,
                use_auth_token=False,
                local_files_only=False
            )
            self.model = AutoModelForSeq2SeqLM.from_pretrained(
                self.base_model_name,
                use_auth_token=False,
                local_files_only=False,
                torch_dtype=torch.float32
            ).to(self.device)
            # Load the PEFT (LoRA) adapter on top of the base model
            self.model = PeftModel.from_pretrained(
                self.model,
                self.finetuned_model_name,
                use_auth_token=False,
                local_files_only=False
            ).to(self.device)
            # Unwrap the PEFT wrapper; the adapter weights stay injected
            # in the underlying model
            self.model = self.model.base_model
        except Exception as e:
            st.error(f"Error loading models: {str(e)}")
            raise

        self.vectorizer = TfidfVectorizer(stop_words='english')
    def summarize_text(self, text, max_length=150):
        """Generate a summary using the fine-tuned model."""
        try:
            # Skip summarization for very short texts
            text = text.strip()
            if len(text) < 50:
                return text

            # Tokenize with truncation to the model's 512-token input limit
            inputs = self.tokenizer(
                text,
                return_tensors="pt",
                max_length=512,
                truncation=True,
                padding=True
            ).to(self.device)

            # Generate summary with beam search; passing the attention mask
            # keeps generation correct when inputs are padded
            with torch.no_grad():
                outputs = self.model.generate(
                    inputs.input_ids,
                    attention_mask=inputs.attention_mask,
                    max_length=max_length,
                    min_length=25,
                    num_beams=4,
                    length_penalty=1.0,
                    no_repeat_ngram_size=3,
                    early_stopping=True
                )

            summary = self.tokenizer.decode(outputs[0], skip_special_tokens=True)
            return summary
        except Exception as e:
            st.error(f"Summarization error: {str(e)}")
            return f"Error generating summary: {str(e)}"
    def calculate_relevance_scores(self, papers, research_question):
        """Score each abstract against the research question via TF-IDF cosine similarity."""
        all_texts = [research_question] + [paper['Abstract'] for paper in papers]
        tfidf_matrix = self.vectorizer.fit_transform(all_texts)
        # Row 0 is the question; compare it against every abstract row
        cosine_similarities = cosine_similarity(tfidf_matrix[0:1], tfidf_matrix[1:])
        return cosine_similarities[0]
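
    # Illustrative usage (hypothetical data, not real output): the returned
    # array holds one cosine similarity per paper, in input order, with higher
    # scores for abstracts that share vocabulary with the question, e.g.
    #   scores = summarizer.calculate_relevance_scores(
    #       [{"Abstract": "CRISPR editing in mice"},
    #        {"Abstract": "Dietary fiber and the gut microbiome"}],
    #       "CRISPR-based therapies")
    #   # -> e.g. array([0.41, 0.0])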
    def process_papers(self, df, research_question):
        """Process papers and generate summaries."""
        try:
            # Convert DataFrame rows to paper dictionaries
            papers = []
            for _, row in df.iterrows():
                paper = {
                    "Abstract": str(row.get('Abstract', '')),
                    "title": str(row.get('Article Title', '')),
                    "authors": str(row.get('Authors', '')),
                    "source": str(row.get('Source Title', '')),
                    "year": str(row.get('Publication Year', '')),
                    "doi": str(row.get('DOI', '')),
                }
                papers.append(paper)

            # Filter out empty abstracts
            papers = [paper for paper in papers if paper['Abstract'].strip() != ""]
            if not papers:
                return {
                    "research_question": research_question,
                    "papers": [],
                    "overall_summary": "No valid papers found for analysis",
                    "paper_count": 0
                }

            # Generate individual summaries
            with st.spinner("Summarizing abstracts..."):
                for paper in papers:
                    paper['summary'] = self.summarize_text(paper['Abstract'])

            # Calculate relevance scores
            relevance_scores = self.calculate_relevance_scores(papers, research_question)
            for i, paper in enumerate(papers):
                paper['relevance_score'] = relevance_scores[i]

            # Sort papers by relevance
            papers = sorted(papers, key=lambda x: x['relevance_score'], reverse=True)

            # Generate overall summary from the five most relevant papers
            overall_summary = self.generate_overall_summary(papers[:5], research_question)

            return {
                "research_question": research_question,
                "papers": papers,
                "overall_summary": overall_summary,
                "paper_count": len(papers)
            }
        except Exception as e:
            st.error(f"Error processing papers: {str(e)}")
            return None
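
    # For reference, a successful process_papers() call returns a dict shaped
    # like the following (values illustrative, not real output):
    #   {"research_question": "...",
    #    "papers": [{"title": "...", "summary": "...", "relevance_score": 0.41, ...}],
    #    "overall_summary": "...",
    #    "paper_count": 12}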
    def generate_overall_summary(self, top_papers, research_question):
        """Generate an overall summary from the top-ranked papers."""
        if not top_papers:
            return "No relevant papers found for summary generation"

        # Concatenate the individual summaries as context for a second pass
        context = f"Research Question: {research_question}\n\nKey Findings:\n"
        for paper in top_papers:
            context += f"- From {paper['authors']} ({paper['year']}): {paper['summary']}\n"

        # Generate final summary
        try:
            final_summary = self.summarize_text(context, max_length=200)
            return final_summary
        except Exception as e:
            st.error(f"Error generating overall summary: {str(e)}")
            return "Error generating overall summary"
# Streamlit App Configuration
st.set_page_config(
    page_title="Biomedical Research Summarizer",
    page_icon="🧬",
    layout="wide"
)

# App Title and Description
st.title("🧬 Biomedical Research Summarizer")
st.markdown("""
This tool analyzes biomedical research papers to:
* Generate AI-powered summaries of each paper
* Find papers most relevant to your research question
* Create an overall summary of key findings

Required Excel columns:
* Abstract
* Article Title
* Authors
* Source Title
* Publication Year
* DOI
""")
# Initialize summarizer | |
if 'summarizer' not in st.session_state: | |
with st.spinner('Loading AI models... This may take a few minutes.'): | |
try: | |
st.session_state.summarizer = ResearchSummarizer() | |
st.success('Models loaded successfully!') | |
except Exception as e: | |
st.error(f"Failed to initialize models: {str(e)}") | |
st.stop() | |
# File upload and research question | |
uploaded_file = st.file_uploader("Upload Excel File", type=["xlsx", "xls"]) | |
research_question = st.text_area( | |
"Research Question", | |
placeholder="Enter your research question here..." | |
) | |
if st.button("Analyze Papers"):
    if not uploaded_file or not research_question.strip():
        st.error("Please provide both a file and a research question.")
    else:
        try:
            with st.spinner("Processing papers... This may take a few minutes."):
                df = pd.read_excel(uploaded_file)
                results = st.session_state.summarizer.process_papers(df, research_question)

            if results:
                st.subheader("Research Question")
                st.write(results['research_question'])

                st.subheader("Overall Summary")
                st.write(results['overall_summary'])

                st.subheader(f"Top Papers (Total: {results['paper_count']})")
                for paper in results['papers'][:5]:
                    with st.expander(
                        f"{paper['title']} (Relevance: {paper['relevance_score']:.2f})"
                    ):
                        st.markdown(f"**Authors:** {paper['authors']}")
                        st.markdown(f"**Source:** {paper['source']}")
                        st.markdown(f"**Year:** {paper['year']}")
                        st.markdown(f"**DOI:** {paper['doi']}")
                        st.markdown("**Summary:**")
                        st.markdown(paper['summary'])
        except Exception as e:
            st.error(f"An error occurred: {str(e)}")
            st.error("Please check your Excel file format and try again.")