# Hugging Face Spaces status: Runtime error
import streamlit as st | |
from transformers import pipeline | |
from string import punctuation | |
import pandas as pd | |
import plotly.express as px | |
import plotly.graph_objects as go | |
# Seed the per-session plot flag exactly once: st.session_state survives
# script reruns, so an existing value is left as it is.
st.session_state.setdefault("plot_visible", False)  # the plot starts hidden
# Characters always removed by strip_input_str: an explicit list of
# CJK/typographic marks followed by ASCII `string.punctuation`. The space
# character is part of the explicit list, so internal spaces are removed too.
# NOTE(review): the non-ASCII characters in this literal (Greek letters etc.)
# look like UTF-8 punctuation that was decoded with the wrong codec; confirm
# against the intended character set.
_CHARS_TO_REMOVE = "ββ‘()γγ:\"γΒ·, ?γ" + punctuation
# Translation table built once at import time instead of on every call.
_REMOVE_TABLE = str.maketrans('', '', _CHARS_TO_REMOVE)


def strip_input_str(x):
    """Return *x* with punctuation removed and surrounding whitespace trimmed.

    Every character in ``_CHARS_TO_REMOVE`` is deleted (including spaces).
    Any other character whose Unicode general category is punctuation
    (``P*``) is also deleted, e.g. full-width ``，！？；：「」。``. This means
    CJK punctuation is stripped even where the explicit list does not name it.
    Leading/trailing whitespace that remains (tabs, newlines) is then stripped.

    Args:
        x: The input text.

    Returns:
        The cleaned string (possibly empty).
    """
    import unicodedata  # local import so the file's import block is unchanged

    x = x.translate(_REMOVE_TABLE)
    x = "".join(ch for ch in x if not unicodedata.category(ch).startswith("P"))
    return x.strip()
# Load the pipeline with the HanmunRoBERTa model.
# Streamlit re-executes this entire script on every widget interaction. Caching
# the pipeline with `st.cache_resource` loads the model once and reuses the same
# object on later reruns, instead of reloading it each time.
# show_spinner=False: the cache spinner is a Streamlit element, and this code
# runs before st.set_page_config(), which older Streamlit versions require to be
# the first Streamlit command in the script.
@st.cache_resource(show_spinner=False)
def _load_model_pipeline(model_name="bdsl/HanmunRoBERTa"):
    """Build (once) and return the text-classification pipeline for *model_name*."""
    return pipeline(task="text-classification", model=model_name)


model_pipeline = _load_model_pipeline()
# Page metadata must be configured before anything is drawn; then the heading.
title = "HanmunRoBERTa Century Classifier"
# NOTE(review): this page_icon looks like a mis-encoded emoji; confirm the
# intended icon.
st.set_page_config(page_icon="π", page_title=title)
st.title(title)
# Let the user choose whether punctuation is stripped before classification.
remove_punct = st.checkbox(value=True, label="Remove punctuation")

# Passage to classify: pre-filled with a sample, limited to 500 characters.
# NOTE(review): the default sample text looks mis-encoded (UTF-8 read with the
# wrong codec); confirm the intended passage.
input_str = st.text_area(
    label="Input text",
    value="ζ¬η₯ ι«ιΊ εδΊθ£ζθ¨γ δΌζε°ι¦, θͺ ζζη η‘ε£θ¨ιδΉεΎ, θΎζ½ ε η¦ εε§η«δ½θ .",
    height=150,
    max_chars=500,
)

# Clean only when the option is ticked and there is text to clean, then echo
# the exact string that will be sent to the model.
if input_str and remove_punct:
    input_str = strip_input_str(input_str)
    st.write("Processed input:", input_str)
def _century_label(label):
    """Return a display label such as "14th Century" for a model class label.

    Purely numeric labels get the correct English ordinal suffix
    ("1st", "2nd", "3rd", "11th", "21st", ...). Any other label keeps the
    plain "th" suffix.

    Args:
        label: The class label produced by the model (converted with ``str``).

    Returns:
        The label text followed by its ordinal suffix and " Century".
    """
    label = str(label)
    suffix = "th"
    # isdecimal() (not isdigit()) guarantees int() can parse the string.
    if label.isdecimal():
        n = int(label)
        if not 11 <= n % 100 <= 13:  # 11th/12th/13th are irregular
            suffix = {1: "st", 2: "nd", 3: "rd"}.get(n % 10, "th")
    return f"{label}{suffix} Century"


# Button to classify the text and display the prediction chart.
# The chart is drawn on every click that has text to classify. A toggled
# session-state flag is deliberately not used: a click that draws nothing
# (e.g. empty input) would leave the toggle out of sync and hide the chart on
# the next real click.
if st.button("Classify"):
    if not input_str:
        st.warning("Please enter some text to classify.")
    else:
        with st.spinner("Classifying..."):
            # top_k=None requests a score for every label, not just the best one.
            predictions = model_pipeline(input_str, top_k=None)
        data = pd.DataFrame(predictions)
        # Plotly draws the first category of a horizontal bar chart at the
        # bottom, so ascending order puts the highest score on top.
        data = data.sort_values(by="score", ascending=True)
        data["label"] = data["label"].astype(str)
        century_labels = [_century_label(label) for label in data["label"]]
        scores = data["score"].tolist()

        colors = px.colors.qualitative.Plotly
        fig = go.Figure(
            go.Bar(
                x=data["score"].values,
                y=century_labels,
                orientation="h",
                text=[f"{score:.3f}" for score in scores],
                textposition="outside",
                hoverinfo="text",
                hovertext=[
                    f"{name}<br>Score: {score:.3f}"
                    for name, score in zip(century_labels, scores)
                ],
                # Bars are coloured by rank position, cycling the palette.
                marker=dict(color=[colors[i % len(colors)] for i in range(len(data))]),
            )
        )
        fig.update_layout(
            height=300,
            xaxis_title="Score",
            yaxis_title="",
            title="Model predictions and scores",
            uniformtext_minsize=8,
            uniformtext_mode="hide",
        )
        st.plotly_chart(figure_or_data=fig, use_container_width=True)