"""Streamlit app that predicts the century of a Hanmun text with HanmunRoBERTa."""

import streamlit as st
from transformers import pipeline
from string import punctuation
import pandas as pd
import plotly.express as px
import plotly.graph_objects as go

title = "HanmunRoBERTa Century Classifier"

# st.set_page_config must be the first Streamlit command of the script run,
# so it is issued before session-state setup and the (cached) model load,
# which displays its own spinner element.
st.set_page_config(page_title=title, page_icon="📚")

# Initialize or retrieve the session state variable.
# It is set only on the run in which a classification succeeds and is reset
# at the end of every run, so the chart appears only right after "Classify".
if 'plot_visible' not in st.session_state:
    st.session_state.plot_visible = False

# Characters deleted when "Remove punctuation" is checked: the listed
# CJK/annotation marks, the space character, and ASCII punctuation.
# The translation table is built once instead of on every call.
_CHARACTERS_TO_REMOVE = "○□()〔〕:\"。·, ?ㆍ" + punctuation
_REMOVE_TABLE = str.maketrans('', '', _CHARACTERS_TO_REMOVE)


def strip_input_str(x):
    """Return *x* with every character in ``_CHARACTERS_TO_REMOVE`` deleted
    and leading/trailing whitespace stripped."""
    return x.translate(_REMOVE_TABLE).strip()


def _century_label(label):
    """Format a model label such as ``"15"`` as ``"15th Century"``.

    Decimal labels get the correct English ordinal suffix (1st, 2nd, 3rd,
    11th-13th, 21st, ...); any other label keeps the plain "th" suffix.
    """
    if label.isdecimal():
        n = int(label)
        if 11 <= n % 100 <= 13:
            suffix = "th"
        else:
            suffix = {1: "st", 2: "nd", 3: "rd"}.get(n % 10, "th")
    else:
        suffix = "th"
    return f"{label}{suffix} Century"


@st.cache_resource
def load_model_pipeline():
    """Load the HanmunRoBERTa text-classification pipeline.

    Streamlit re-executes this whole script on every widget interaction;
    caching keeps a single pipeline instead of reloading the model each rerun.
    """
    return pipeline(task="text-classification", model="bdsl/HanmunRoBERTa")


# Load the pipeline with the HanmunRoBERTa model
model_pipeline = load_model_pipeline()

# Streamlit app layout
st.title(title)

# Checkbox to remove punctuation
remove_punct = st.checkbox(label="Remove punctuation", value=True)

# Text area for user input
input_str = st.text_area(
    "Input text",
    height=150,
    value="權知 高麗 國事臣某言。 伏惟小邦, 自 恭愍王 無嗣薨逝之後, 辛旽 子 禑 冒姓竊位者.",
    max_chars=500,
)

if remove_punct and input_str:
    input_str = strip_input_str(input_str)
    st.write("Processed input:", input_str)

# Prediction table for this run; stays None unless a classification ran.
data = None

# Button to classify the text
if st.button("Classify"):
    if input_str:
        with st.spinner("Classifying..."):
            # top_k=None returns a score for every label; truncation keeps
            # long tokenizations within the model's maximum sequence length.
            predictions = model_pipeline(input_str, top_k=None, truncation=True)
        data = pd.DataFrame(predictions)
        # Ascending order puts the highest score at the top of the
        # horizontal bar chart (Plotly draws the first category at the bottom).
        data = data.sort_values(by='score', ascending=True)
        data['label'] = data['label'].astype(str)
        st.session_state.plot_visible = True
    else:
        st.warning("Please enter some text to classify.")

# Only draw when this run actually produced predictions; this avoids
# referencing missing results when "Classify" is pressed with empty input.
if st.session_state.plot_visible and data is not None:
    colors = px.colors.qualitative.Plotly
    centuries = [_century_label(label) for label in data['label']]
    scores = data['score'].tolist()
    fig = go.Figure(
        go.Bar(
            x=scores,
            y=centuries,
            orientation='h',
            text=[f'{score:.3f}' for score in scores],
            textposition='outside',
            hoverinfo='text',
            # Plotly hover labels use HTML <br> for line breaks.
            hovertext=[
                f'{century}<br>Score: {score:.3f}'
                for century, score in zip(centuries, scores)
            ],
            marker=dict(color=[colors[i % len(colors)] for i in range(len(scores))]),
        )
    )
    fig.update_layout(
        height=300,
        xaxis_title='Score',
        yaxis_title='',
        title='Model predictions and scores',
        uniformtext_minsize=8,
        uniformtext_mode='hide',
    )
    st.plotly_chart(figure_or_data=fig, use_container_width=True)

# Reset unconditionally so the chart only shows on the run triggered by the
# button and the flag can never be left stuck at True.
st.session_state.plot_visible = False