Spaces:
Runtime error
Runtime error
File size: 2,960 Bytes
d3dbd6c 7b34f91 55eef5c 7b34f91 ef4bad6 4e5b267 ef4bad6 d3dbd6c 7b34f91 d3dbd6c 7b34f91 c9a9ab8 adf186a c9a9ab8 ef4bad6 dfac00f adf186a c9a9ab8 ef4bad6 adf186a 55eef5c c9a9ab8 55eef5c c9a9ab8 55eef5c |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 |
from string import punctuation
import unicodedata

import pandas as pd
import plotly.express as px
import plotly.graph_objects as go
import streamlit as st
from transformers import pipeline
# Ensure the plot-visibility flag exists in the per-session state before any
# later code reads it; a fresh session starts with the plot hidden.
if "plot_visible" not in st.session_state:
    st.session_state["plot_visible"] = False
def strip_input_str(x):
    """Remove punctuation from *x* and trim surrounding whitespace.

    Two passes are applied:

    1. Every character in an explicit removal set is deleted. The set is a
       hand-picked literal plus ASCII ``string.punctuation``. As written, the
       literal also contains an ASCII space, so interior spaces are removed too.
    2. Any remaining character whose Unicode general category is punctuation
       (``P*``) is deleted. This catches full-width/CJK marks such as ``，``,
       ``。``, ``、`` and ``「」``, which the explicit set does not list.

    Args:
        x: Raw user input text.

    Returns:
        The cleaned string, which may be empty.
    """
    # NOTE(review): the hand-picked literal below looks mis-encoded in this copy
    # (Greek letters appear where CJK punctuation was presumably intended).
    # Confirm it against the original file. The category filter further down
    # handles CJK punctuation either way.
    characters_to_remove = "ββ‘()γγ:\"γΒ·, ?γ" + punctuation
    translating = str.maketrans('', '', characters_to_remove)
    x = x.translate(translating)
    # Drop any other Unicode punctuation (Pc, Pd, Ps, Pe, Pi, Pf, Po).
    x = "".join(ch for ch in x if not unicodedata.category(ch).startswith("P"))
    return x.strip()
@st.cache_resource(show_spinner=False)
def load_model_pipeline():
    """Build the HanmunRoBERTa text-classification pipeline.

    Streamlit re-executes this script on every widget interaction. Wrapping
    the construction in ``st.cache_resource`` lets reruns reuse the same
    pipeline object instead of reloading the model each time. The spinner is
    disabled so this call does not emit a Streamlit element ahead of
    ``st.set_page_config`` further down the script.
    """
    return pipeline(task="text-classification", model="bdsl/HanmunRoBERTa")


# Load the pipeline with the HanmunRoBERTa model (cached across reruns).
model_pipeline = load_model_pipeline()
# Page chrome: browser-tab title and icon first, then the on-page heading.
# NOTE(review): the icon literal looks like a mis-encoded emoji in this copy;
# confirm against the original file.
title = "HanmunRoBERTa Century Classifier"
st.set_page_config(page_icon="π", page_title=title)
st.title(body=title)

# Option to strip punctuation before classification (on by default).
remove_punct = st.checkbox("Remove punctuation", value=True)
# Text area for user input; the returned text is stored in `input_str`.
# NOTE(review): in this copy the default `value` literal looks mis-encoded
# (Greek letters where Classical Chinese text was presumably intended). It is
# also split across two physical lines, which is not valid inside a
# single-quoted string. Restore it from the original file.
input_str = st.text_area(
    "Input text",
    height=150,
    value="ζ¬η₯ ι«ιΊ εδΊθ£ζθ¨γ δΌζε°ι¦, θͺ ζζη η‘ε£θ¨ιδΉεΎ, θΎζ½ ε η¦ εε§η«δ½θ
.",
    max_chars=500  # widget-level limit on input length
)
# When the checkbox is ticked and there is text, clean it and show the user
# exactly what will be sent to the model.
if input_str and remove_punct:
    input_str = strip_input_str(input_str)
    st.write("Processed input:", input_str)
def _century_label(label):
    """Return a display name such as ``"15th Century"`` for a model label.

    Purely decimal labels get the correct English ordinal suffix
    (1st, 2nd, 3rd, 4th, 11th-13th, 21st, ...). Any other label keeps the
    ``"<label>th Century"`` format.
    """
    if label.isdecimal():
        n = int(label)
        if n % 100 in (11, 12, 13):
            suffix = "th"
        else:
            suffix = {1: "st", 2: "nd", 3: "rd"}.get(n % 10, "th")
        return f"{label}{suffix} Century"
    return f"{label}th Century"


# Button to classify the text and plot the per-century scores.
# The chart is drawn in the same run as the click. It is not gated by a
# toggled session-state flag, because a click with empty input would flip
# such a flag without producing any predictions.
if st.button("Classify"):
    if not input_str:
        st.warning("Please enter some text to classify.")
    else:
        with st.spinner("Classifying..."):
            # top_k=None asks for a score for every label. truncation=True
            # keeps over-long inputs within the model's maximum length.
            predictions = model_pipeline(input_str, top_k=None, truncation=True)
        # Ascending order puts the highest score at the top of the
        # horizontal bar chart.
        data = pd.DataFrame(predictions).sort_values(by="score", ascending=True)
        data["label"] = data["label"].astype(str)

        century_names = [_century_label(label) for label in data["label"]]
        scores = data["score"].to_numpy()
        colors = px.colors.qualitative.Plotly

        fig = go.Figure(
            go.Bar(
                x=scores,
                y=century_names,
                orientation="h",
                text=[f"{score:.3f}" for score in scores],
                textposition="outside",
                hoverinfo="text",
                hovertext=[
                    f"{name}<br>Score: {score:.3f}"
                    for name, score in zip(century_names, scores)
                ],
                # Cycle through the qualitative palette, one colour per bar.
                marker=dict(color=[colors[i % len(colors)] for i in range(len(data))]),
            )
        )
        fig.update_layout(
            height=300,
            xaxis_title="Score",
            yaxis_title="",
            title="Model predictions and scores",
            uniformtext_minsize=8,
            uniformtext_mode="hide",
        )
        st.plotly_chart(figure_or_data=fig, use_container_width=True)
        # Leave the flag in the same state the previous version left it in
        # after drawing; it no longer gates rendering.
        st.session_state.plot_visible = False
|