# HanmunRoBERTa / app.py
# Last commit by yenniejun: "adding max chars to input" (dfac00f)
import streamlit as st
from transformers import pipeline
from string import punctuation
import pandas as pd
import plotly.express as px
import plotly.graph_objects as go
# Session-state flag for whether the prediction plot is shown. Session state
# persists across Streamlit reruns, so the key is only created on the first
# run; the plot starts out hidden.
if "plot_visible" not in st.session_state:
    st.session_state["plot_visible"] = False
# Characters removed from user input before classification: Hanmun
# transcription marks plus all ASCII punctuation. The space is included on
# purpose, so spaces between words are removed too.
# The first two marks (WHITE CIRCLE, LIGHT SHADE) are spelled as named escapes
# so that re-encoding this file cannot garble them into other characters.
# NOTE(review): U+2591 LIGHT SHADE is what the previously garbled bytes decode
# to; if the intended lacuna marker was U+25A1 WHITE SQUARE, add it here.
_CHARACTERS_TO_REMOVE = "\N{WHITE CIRCLE}\N{LIGHT SHADE}()〔〕:\"。·, ?ㆍ" + punctuation
# Built once at import time rather than on every call.
_REMOVAL_TABLE = str.maketrans("", "", _CHARACTERS_TO_REMOVE)


def strip_input_str(x: str) -> str:
    """Remove punctuation and transcription marks from *x* and trim its ends.

    Args:
        x: Raw text entered by the user.

    Returns:
        *x* with every character in ``_CHARACTERS_TO_REMOVE`` deleted (spaces
        included), then stripped of any remaining leading/trailing whitespace.
    """
    return x.translate(_REMOVAL_TABLE).strip()
# Load the pipeline with the HanmunRoBERTa model.
# Streamlit re-runs this whole script on every widget interaction, so the
# pipeline is cached as a shared resource and built once, not on every rerun.
# show_spinner=False: this call runs before st.set_page_config() below, and a
# cache spinner would otherwise be rendered ahead of it.
@st.cache_resource(show_spinner=False)
def _load_model_pipeline():
    """Return the Hugging Face text-classification pipeline for bdsl/HanmunRoBERTa."""
    return pipeline(task="text-classification", model="bdsl/HanmunRoBERTa")


model_pipeline = _load_model_pipeline()
# Streamlit app layout
title = "HanmunRoBERTa Century Classifier"
# Page icon is the BOOKS emoji (U+1F4DA), written as a named escape so the
# file's encoding cannot garble it into a non-emoji string.
st.set_page_config(page_title=title, page_icon="\N{BOOKS}")
st.title(title)
# Checkbox to remove punctuation
remove_punct = st.checkbox(label="Remove punctuation", value=True)

# Example Hanmun passage pre-filled in the text area.
DEFAULT_INPUT = "權知 高麗 國事臣某言。 伏惟小邦, 自 恭愍王 無嗣薨逝之後, 辛旽 子 禑 冒姓竊位者."

# Text area for user input; max_chars caps how much text can be submitted.
input_str = st.text_area(
    "Input text",
    height=150,
    value=DEFAULT_INPUT,
    max_chars=500,
)

# Optionally strip punctuation/marks, then echo the text that will be classified.
if remove_punct and input_str:
    input_str = strip_input_str(input_str)
    st.write("Processed input:", input_str)
def _century_label(label: str) -> str:
    """Return the axis text for a class label, e.g. ``"15"`` -> ``"15th Century"``.

    Numeric labels get the correct English ordinal suffix (1st, 2nd, 3rd, 11th,
    21st, ...). A non-numeric label keeps the plain "th" suffix.
    """
    try:
        number = int(label)
    except ValueError:
        return f"{label}th Century"
    if number % 100 in (11, 12, 13):
        suffix = "th"
    else:
        suffix = {1: "st", 2: "nd", 3: "rd"}.get(number % 10, "th")
    return f"{label}{suffix} Century"


# Button to classify the text and plot the score of every label.
# The chart is rebuilt on each press, so no session-state flag decides whether
# to show it (toggling one made a click with empty input hide the plot on the
# following click).
if st.button("Classify"):
    if not input_str:
        # Nothing to classify: the box is empty, or it held only characters
        # removed by strip_input_str.
        st.warning("Please enter some text to classify.")
    else:
        with st.spinner("Classifying..."):
            # top_k=None returns a score for every label, not only the best one.
            # truncation=True asks the tokenizer to cut inputs longer than the
            # model's maximum length instead of letting the model raise.
            predictions = model_pipeline(input_str, top_k=None, truncation=True)

        data = pd.DataFrame(predictions)
        # Ascending order: plotly draws the first horizontal bar at the bottom,
        # so the highest-scoring label ends up on top.
        data = data.sort_values(by="score", ascending=True)
        data["label"] = data["label"].astype(str)
        century_labels = [_century_label(label) for label in data["label"]]
        scores = data["score"].tolist()

        colors = px.colors.qualitative.Plotly
        fig = go.Figure(
            go.Bar(
                x=scores,
                y=century_labels,
                orientation="h",
                text=[f"{score:.3f}" for score in scores],
                textposition="outside",
                hoverinfo="text",
                hovertext=[
                    f"{name}<br>Score: {score:.3f}"
                    for name, score in zip(century_labels, scores)
                ],
                marker=dict(color=[colors[i % len(colors)] for i in range(len(scores))]),
            )
        )
        fig.update_layout(
            height=300,
            xaxis_title="Score",
            yaxis_title="",
            title="Model predictions and scores",
            uniformtext_minsize=8,
            uniformtext_mode="hide",
        )
        st.plotly_chart(figure_or_data=fig, use_container_width=True)