import html
import io
import time
from urllib.parse import urlparse

import pdfplumber
import plotly.express as px
import plotly.graph_objects as go
import requests
import streamlit as st

from utils.util_classifier import TextClassificationPipeline
def validate_url(url):
    """Return True if *url* looks like an absolute URL.

    Only the shape is checked: both a scheme (e.g. ``https``) and a network
    location (host) must be present. Reachability is not tested and any
    scheme is accepted.

    Args:
        url: Candidate URL, normally the raw text typed by the user.

    Returns:
        ``True`` when both parts are present, ``False`` otherwise, including
        when ``urlparse`` rejects the input.
    """
    try:
        parts = urlparse(url)
    except (ValueError, TypeError, AttributeError):
        # ValueError: malformed URL (e.g. an unbalanced IPv6 bracket).
        # TypeError/AttributeError: input that is not str/bytes.
        # The previous bare ``except:`` also swallowed KeyboardInterrupt and
        # SystemExit.
        return False
    return bool(parts.scheme and parts.netloc)
def download_pdf(url, timeout=30):
    """Download a PDF from *url* into an in-memory buffer.

    Args:
        url: Address of the PDF document.
        timeout: Seconds passed to ``requests`` as the connect/read timeout.
            Without a timeout ``requests.get`` can block indefinitely on an
            unresponsive server.

    Returns:
        ``io.BytesIO`` holding the response body, or ``None`` on any failure.
        Failures are reported to the user via ``st.error``.
    """
    try:
        headers = {
            'User-Agent': 'Mozilla/5.0 (Windows NT 10.0; Win64; x64) AppleWebKit/537.36 (KHTML, like Gecko) Chrome/91.0.4472.124 Safari/537.36',
            'Accept': 'application/pdf,*/*',
            'Referer': ''
        }
        response = requests.get(url, headers=headers, timeout=timeout)
        response.raise_for_status()
        content_type = response.headers.get('content-type', '')
        # Some servers label PDFs generically (e.g. application/octet-stream);
        # accept those too when the body starts with the PDF magic header.
        is_pdf = (
            'application/pdf' in content_type.lower()
            or response.content.startswith(b'%PDF-')
        )
        if not is_pdf:
            raise ValueError(f"URL does not point to a PDF file. Content-Type: {content_type}")
        return io.BytesIO(response.content)
    except Exception as e:
        # UI boundary: show the problem to the user instead of crashing the app.
        st.error(f"Download error: {str(e)}")
        return None
def extract_text(pdf_file):
    """Extract the text of every page of a PDF.

    Args:
        pdf_file: File-like object holding the PDF (e.g. the ``io.BytesIO``
            returned by ``download_pdf`` or a Streamlit upload); anything
            ``pdfplumber.open`` accepts.

    Returns:
        The non-empty page texts joined with newlines and stripped, or
        ``None`` on failure. Failures, including a PDF with no extractable
        text, are reported to the user via ``st.error``.
    """
    try:
        # Reset file pointer so extraction starts at the beginning even if the
        # buffer has already been read.
        if hasattr(pdf_file, "seek"):
            pdf_file.seek(0)
        with pdfplumber.open(pdf_file) as pdf:
            page_texts = [page.extract_text() for page in pdf.pages]
        # Pages with no text layer yield None/"" and are skipped.
        text = "\n".join(t for t in page_texts if t).strip()
        if not text:
            raise ValueError("No text could be extracted from the PDF")
        return text
    except Exception as e:
        # UI boundary: show the problem to the user instead of crashing the app.
        st.error(f"Text extraction error: {str(e)}")
        return None
# Fall back to a pass-through decorator on Streamlit releases that predate
# ``st.cache_resource`` so the app still starts there (just without caching).
_cache_resource = getattr(st, "cache_resource", lambda func: func)


@_cache_resource
def _load_classifier(method):
    """Build the classification pipeline once and reuse it across reruns.

    Streamlit re-executes the script on every widget interaction; building
    the pipeline inside ``main`` re-created it on each of those reruns.
    NOTE(review): ``st.cache_resource`` shares this single instance across
    all sessions -- confirm ``TextClassificationPipeline.predict`` is safe to
    call concurrently.
    """
    return TextClassificationPipeline(method=method)


def _classify(classifier, text):
    """Classify *text* and return a single prediction dict.

    ``predict`` may return a list; only its first entry is used.
    """
    result = classifier.predict(text, return_probability=True)
    if isinstance(result, list):
        result = result[0]
    return result


def create_gauge_chart(confidence):
    """Create a gauge chart for a confidence score.

    Args:
        confidence: Score that is multiplied by 100 for display; presumably
            in [0, 1] given the 0-100 gauge range.

    Returns:
        A ``plotly.graph_objects.Figure``.
    """
    fig = go.Figure(go.Indicator(
        # No delta reference is configured, so a "+delta" readout would not
        # compare against anything meaningful; show gauge and number only.
        mode="gauge+number",
        value=confidence * 100,
        domain={'x': [0, 1], 'y': [0, 1]},
        gauge={
            'axis': {'range': [None, 100], 'tickwidth': 1, 'tickcolor': "darkblue"},
            'bar': {'color': "darkblue"},
            'bgcolor': "white",
            'borderwidth': 2,
            'bordercolor': "gray",
            'steps': [
                {'range': [0, 50], 'color': '#FF9999'},
                {'range': [50, 75], 'color': '#FFCC99'},
                {'range': [75, 100], 'color': '#99FF99'},
            ],
        },
        title={'text': "Confidence Score"},
    ))
    fig.update_layout(
        height=300,
        margin=dict(l=10, r=10, t=50, b=10),
        paper_bgcolor='rgba(0,0,0,0)',
        font={'color': "darkblue", 'family': "Arial"},
    )
    return fig


def create_probability_chart(probabilities):
    """Create a horizontal bar chart of per-category probabilities.

    Args:
        probabilities: Mapping of category label -> probability; values are
            multiplied by 100 and shown as percentages on a 0-100 axis.

    Returns:
        A ``plotly.graph_objects.Figure``.
    """
    labels = list(probabilities.keys())
    percentages = [p * 100 for p in probabilities.values()]
    # Start at the third Blues shade as before, but cycle through the palette:
    # indexing Blues[i + 2] directly raised IndexError once there were more
    # categories than remaining shades.
    palette = px.colors.sequential.Blues[2:]
    colors = [palette[i % len(palette)] for i in range(len(labels))]

    fig = go.Figure()
    fig.add_trace(go.Bar(
        y=labels,
        x=percentages,
        orientation='h',
        marker=dict(
            color=colors,
            line=dict(color='rgba(0,0,0,0.8)', width=2),
        ),
        text=[f'{pct:.1f}%' for pct in percentages],
        textposition='auto',
    ))
    fig.update_layout(
        title=dict(
            text='Probability Distribution',
            y=0.95,
            x=0.5,
            xanchor='center',
            yanchor='top',
            font=dict(size=20, color='darkblue'),
        ),
        xaxis_title="Probability (%)",
        yaxis_title="Categories",
        height=400,
        margin=dict(l=20, r=20, t=70, b=20),
        paper_bgcolor='rgba(0,0,0,0)',
        plot_bgcolor='rgba(0,0,0,0)',
        font=dict(family="Arial", size=14),
        showlegend=False,
    )
    fig.update_xaxes(
        range=[0, 100],
        gridcolor='rgba(0,0,0,0.1)',
        zerolinecolor='rgba(0,0,0,0.2)',
    )
    fig.update_yaxes(
        gridcolor='rgba(0,0,0,0.1)',
        zerolinecolor='rgba(0,0,0,0.2)',
    )
    return fig


# Card showing the predicted label; the ``{label}`` value must be HTML-escaped
# because the markup is rendered with ``unsafe_allow_html=True``.
_LABEL_CARD_HTML = (
    "<div style='background-color: white; padding: 20px; border-radius: 10px; "
    "box-shadow: 0 4px 6px rgba(0, 0, 0, 0.1); text-align: center; "
    "margin-bottom: 20px;'>"
    "<h4 style='color: #1f77b4; margin-bottom: 10px;'>Predicted Category</h4>"
    "<p style='font-size: 24px; font-weight: bold; color: #2c3e50; margin: 0; "
    "padding: 10px; background-color: #f8f9fa; border-radius: 5px;'>{label}</p>"
    "</div>"
)


def display_results(result):
    """Render a prediction: label card, confidence gauge, probability chart, details.

    Args:
        result: Prediction dict. Keys read here: ``predicted_label``,
            ``confidence``, ``probabilities`` (label -> probability),
            ``model_type`` (str) and ``text`` (str).
    """
    # Two columns: a narrow summary column and a wider chart column.
    col1, col2 = st.columns([1, 2])
    with col1:
        label = html.escape(str(result['predicted_label']))
        st.markdown(_LABEL_CARD_HTML.format(label=label), unsafe_allow_html=True)
        st.plotly_chart(create_gauge_chart(result['confidence']), use_container_width=True)
    with col2:
        st.plotly_chart(create_probability_chart(result['probabilities']), use_container_width=True)
    with st.expander("π Classification Details"):
        st.markdown(
            f"- **Model Type**: {result['model_type'].title()}\n"
            f"- **Document Length**: {len(result['text'])} characters"
        )


def _render_url_tab(classifier):
    """Download a PDF from a user-supplied URL, classify it and show results.

    Error paths return from this helper rather than from ``main``, so the
    upload tab is still rendered on the same rerun.
    """
    url = st.text_input("Enter PDF URL")
    process_btn = st.button("Classify Document", key="url_classify")
    if not (process_btn and url):
        return
    if not validate_url(url):
        st.error("Please enter a valid URL")
        return

    with st.container():
        # Step 1: Downloading
        with st.spinner("Downloading PDF..."):
            pdf_file = download_pdf(url)
        if pdf_file is None:
            return  # download_pdf has already reported the error
        st.success("PDF downloaded successfully!")

        # Step 2: Extracting Text
        with st.spinner("Extracting text from PDF..."):
            text = extract_text(pdf_file)
        if not (text and text.strip()):
            return  # extract_text has already reported the error
        st.success("Text extracted successfully!")
        with st.expander("View Extracted Text"):
            preview = text if len(text) <= 500 else text[:500] + "..."
            st.text(preview)

        # Step 3: Classification
        with st.spinner("Classifying document..."):
            result = _classify(classifier, text)

        st.markdown("### π Classification Results")
        display_results(result)


def _render_upload_tab(classifier):
    """Classify an uploaded PDF and show results the same way as the URL tab."""
    uploaded_file = st.file_uploader("Upload PDF file", type="pdf")
    process_btn = st.button("Classify Document", key="file_classify")
    if not (process_btn and uploaded_file):
        return

    with st.spinner("Processing uploaded PDF..."):
        text = extract_text(uploaded_file)
        if text is None:
            return  # extract_text has already reported the error
        result = _classify(classifier, text)

    st.markdown("### π Classification Results")
    display_results(result)


def main():
    """Streamlit entry point: classify a PDF given by URL or by upload."""
    # NOTE(review): the emoji in these UI strings look mis-encoded (e.g. "π―"
    # is most likely "🎯"); restore the intended characters.
    st.title("π― Document Classifier")

    # Model selection
    method = "bertbased"
    classifier = _load_classifier(method)

    tab1, tab2 = st.tabs(["π URL Input", "π File Upload"])
    with tab1:
        _render_url_tab(classifier)
    with tab2:
        _render_upload_tab(classifier)
# ``streamlit run`` executes this script with ``__name__ == "__main__"``, so the
# app still starts; importing the module no longer renders the UI as a side effect.
if __name__ == "__main__":
    main()