Spaces:
Runtime error
Runtime error
Adding plotly plot
Browse files
app.py
CHANGED
@@ -21,6 +21,14 @@ colors = px.colors.qualitative.Plotly
|
|
21 |
# from huggingface_hub import InferenceClient
|
22 |
# client = InferenceClient(model="bdsl/HanmunRoBERTa")
|
23 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
24 |
# Load the pipeline with the HanmunRoBERTa model
|
25 |
model_pipeline = pipeline(task="text-classification", model="bdsl/HanmunRoBERTa")
|
26 |
|
@@ -33,22 +41,22 @@ st.title(title)
|
|
33 |
remove_punct = st.checkbox(label="Remove punctuation", value=True)
|
34 |
|
35 |
# Text area for user input
|
36 |
-
input_str = st.text_area(
|
|
|
|
|
|
|
37 |
|
38 |
# Remove punctuation if checkbox is selected
|
39 |
if remove_punct and input_str:
|
40 |
-
|
41 |
-
characters_to_remove = "ββ‘()γγ:\"γΒ·, ?γ" + punctuation
|
42 |
-
translating = str.maketrans('', '', characters_to_remove)
|
43 |
-
input_str = input_str.translate(translating)
|
44 |
|
45 |
-
# Display the input text after processing
|
46 |
-
st.write("Processed input:", input_str)
|
47 |
|
48 |
# Predict and display the classification scores if input is provided
|
49 |
if st.button("Classify"):
|
50 |
if input_str:
|
51 |
-
predictions = model_pipeline(input_str)
|
52 |
data = pd.DataFrame(predictions)
|
53 |
data=data.sort_values(by='score', ascending=True)
|
54 |
data.label = data.label.astype(str)
|
@@ -67,17 +75,13 @@ if st.button("Classify"):
|
|
67 |
marker=dict(color=[colors[i % len(colors)] for i in range(len(data))]), # Cycle through colors
|
68 |
|
69 |
))
|
70 |
-
fig.update_traces(width=0.4)
|
71 |
|
72 |
fig.update_layout(
|
73 |
height=300, # Custom height
|
74 |
xaxis_title='Score',
|
75 |
yaxis_title='',
|
76 |
title='Model predictions and scores',
|
77 |
-
margin=dict(l=100, r=200, t=50, b=50),
|
78 |
uniformtext_minsize=8,
|
79 |
uniformtext_mode='hide',
|
80 |
)
|
81 |
st.plotly_chart(figure_or_data=fig, use_container_width=True)
|
82 |
-
else:
|
83 |
-
st.write("Please enter some text to classify.")
|
|
|
21 |
# from huggingface_hub import InferenceClient
|
22 |
# client = InferenceClient(model="bdsl/HanmunRoBERTa")
|
23 |
|
24 |
+
def strip_input_str(x):
|
25 |
+
# Specify the characters to remove
|
26 |
+
characters_to_remove = "ββ‘()γγ:\"γΒ·, ?γ" + punctuation
|
27 |
+
translating = str.maketrans('', '', characters_to_remove)
|
28 |
+
input_str = input_str.translate(translating)
|
29 |
+
|
30 |
+
return input_str.strip()
|
31 |
+
|
32 |
# Load the pipeline with the HanmunRoBERTa model
|
33 |
model_pipeline = pipeline(task="text-classification", model="bdsl/HanmunRoBERTa")
|
34 |
|
|
|
41 |
remove_punct = st.checkbox(label="Remove punctuation", value=True)
|
42 |
|
43 |
# Text area for user input
|
44 |
+
input_str = st.text_area(
|
45 |
+
"Input text",
|
46 |
+
height=150,
|
47 |
+
value=" ζ¬η₯ ι«ιΊ εδΊθ£ζθ¨γ δΌζε°ι¦, θͺ ζζη η‘ε£θ¨ιδΉεΎ, θΎζ½ ε η¦ εε§η«δ½θ
.")
|
48 |
|
49 |
# Remove punctuation if checkbox is selected
|
50 |
if remove_punct and input_str:
|
51 |
+
input_str = strip_input_str(input_str)
|
|
|
|
|
|
|
52 |
|
53 |
+
# Display the input text after processing
|
54 |
+
st.write("Processed input:", input_str)
|
55 |
|
56 |
# Predict and display the classification scores if input is provided
|
57 |
if st.button("Classify"):
|
58 |
if input_str:
|
59 |
+
predictions = model_pipeline(input_str, top_k=None)
|
60 |
data = pd.DataFrame(predictions)
|
61 |
data=data.sort_values(by='score', ascending=True)
|
62 |
data.label = data.label.astype(str)
|
|
|
75 |
marker=dict(color=[colors[i % len(colors)] for i in range(len(data))]), # Cycle through colors
|
76 |
|
77 |
))
|
|
|
78 |
|
79 |
fig.update_layout(
|
80 |
height=300, # Custom height
|
81 |
xaxis_title='Score',
|
82 |
yaxis_title='',
|
83 |
title='Model predictions and scores',
|
|
|
84 |
uniformtext_minsize=8,
|
85 |
uniformtext_mode='hide',
|
86 |
)
|
87 |
st.plotly_chart(figure_or_data=fig, use_container_width=True)
|
|
|
|