File size: 2,960 Bytes
d3dbd6c
 
 
 
 
 
7b34f91
55eef5c
 
 
7b34f91
ef4bad6
 
 
4e5b267
 
ef4bad6
d3dbd6c
 
7b34f91
d3dbd6c
 
 
 
7b34f91
c9a9ab8
 
adf186a
c9a9ab8
ef4bad6
 
 
dfac00f
 
 
adf186a
c9a9ab8
ef4bad6
 
adf186a
55eef5c
c9a9ab8
55eef5c
c9a9ab8
55eef5c
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
import streamlit as st
from transformers import pipeline
from string import punctuation
import pandas as pd
import plotly.express as px
import plotly.graph_objects as go

# Seed the per-session "plot_visible" flag on the first script run only;
# Streamlit re-executes the whole script on every interaction, and values in
# st.session_state survive those reruns.
if "plot_visible" not in st.session_state:
    st.session_state["plot_visible"] = False

# Characters deleted by strip_input_str: circle/square markers, ASCII and CJK
# brackets, CJK full stop, middle dots, the ASCII space, plus all of
# string.punctuation. Built once at import time instead of on every call.
_STRIP_CHARS = "○□()〔〕:\"。·, ?ㆍ" + punctuation
_STRIP_TABLE = str.maketrans("", "", _STRIP_CHARS)


def strip_input_str(x):
    """Remove markers, punctuation and spaces from *x*.

    Every character in ``_STRIP_CHARS`` is deleted (note this includes the
    ASCII space, so words are joined together), then leading/trailing
    whitespace such as newlines or tabs is stripped.

    Args:
        x: Raw user input text.

    Returns:
        The cleaned string (possibly empty).
    """
    return x.translate(_STRIP_TABLE).strip()

# Load the HanmunRoBERTa text-classification pipeline once per server process.
# Streamlit re-executes this script on every widget interaction; without
# st.cache_resource the model would be re-instantiated on each rerun.
# show_spinner=False keeps this call from rendering any element, so it can run
# before st.set_page_config.
@st.cache_resource(show_spinner=False)
def _load_model_pipeline():
    """Return the Hugging Face text-classification pipeline for HanmunRoBERTa."""
    return pipeline(task="text-classification", model="bdsl/HanmunRoBERTa")


model_pipeline = _load_model_pipeline()

# Streamlit app layout.
# Keep set_page_config ahead of other rendered Streamlit elements; many
# Streamlit versions reject it otherwise.
title = "HanmunRoBERTa Century Classifier"
st.set_page_config(page_title=title, page_icon="📚")  # books emoji (was mis-encoded)
st.title(title)

# Checkbox to strip punctuation (and spaces, see strip_input_str) before classifying.
remove_punct = st.checkbox(label="Remove punctuation", value=True)

# Text area for user input, pre-filled with a sample Hanmun passage.
# The default literal was previously mis-encoded (UTF-8 bytes shown as
# Windows-codepage characters), which fed garbage to the model.
input_str = st.text_area(
    "Input text",
    height=150,
    value="權知 高麗 國事臣某言。 伏惟小邦, 自 恭愍王 無嗣薨逝之後, 辛旽 子 禑 冒姓竊位者.",
    max_chars=500,
)

# Show the cleaned text so the user sees exactly what is sent to the model.
if remove_punct and input_str:
    input_str = strip_input_str(input_str)
    st.write("Processed input:", input_str)

def _century_label(label):
    """Return a human-readable century name for a model label.

    Numeric labels get the correct English ordinal suffix ("21" -> "21st
    Century", "13" -> "13th Century"); any non-numeric label falls back to
    the "<label>th Century" format.
    """
    try:
        number = int(label)
    except ValueError:
        return f"{label}th Century"
    if 10 <= number % 100 <= 20:
        suffix = "th"  # 11th, 12th, 13th are irregular
    else:
        suffix = {1: "st", 2: "nd", 3: "rd"}.get(number % 10, "th")
    return f"{label}{suffix} Century"


# Classify the text and draw a horizontal bar chart of per-label scores.
# st.button returns True only on the rerun triggered by the click, so the chart
# is drawn once per click. No separate visibility flag is needed: toggling
# one could leave it inverted after a click with empty input, hiding the
# chart on the next valid click.
if st.button("Classify"):
    if not input_str:
        st.warning("Please enter some text to classify.")
    else:
        with st.spinner("Classifying..."):
            # top_k=None requests scores for every label, not just the best one.
            predictions = model_pipeline(input_str, top_k=None)
            data = pd.DataFrame(predictions)
            # Ascending order places the highest score at the top of the
            # horizontal bar chart (plotly draws categories bottom-up).
            data = data.sort_values(by="score", ascending=True)
            data["label"] = data["label"].astype(str)

        century_names = [_century_label(label) for label in data["label"]]
        scores = data["score"].to_numpy()
        colors = px.colors.qualitative.Plotly

        fig = go.Figure(
            go.Bar(
                x=scores,
                y=century_names,
                orientation="h",
                text=[f"{score:.3f}" for score in scores],
                textposition="outside",
                hoverinfo="text",
                hovertext=[
                    f"{name}<br>Score: {score:.3f}"
                    for name, score in zip(century_names, scores)
                ],
                marker=dict(
                    color=[colors[i % len(colors)] for i in range(len(century_names))]
                ),
            )
        )

        fig.update_layout(
            height=300,
            xaxis_title="Score",
            yaxis_title="",
            title="Model predictions and scores",
            uniformtext_minsize=8,
            uniformtext_mode="hide",
        )
        st.plotly_chart(figure_or_data=fig, use_container_width=True)