yenniejun commited on
Commit
ef4bad6
Β·
1 Parent(s): c9a9ab8

Adding plotly plot

Browse files
Files changed (1) hide show
  1. app.py +16 -12
app.py CHANGED
@@ -21,6 +21,14 @@ colors = px.colors.qualitative.Plotly
21
  # from huggingface_hub import InferenceClient
22
  # client = InferenceClient(model="bdsl/HanmunRoBERTa")
23
 
 
 
 
 
 
 
 
 
24
  # Load the pipeline with the HanmunRoBERTa model
25
  model_pipeline = pipeline(task="text-classification", model="bdsl/HanmunRoBERTa")
26
 
@@ -33,22 +41,22 @@ st.title(title)
33
  remove_punct = st.checkbox(label="Remove punctuation", value=True)
34
 
35
  # Text area for user input
36
- input_str = st.text_area("Input text", height=275)
 
 
 
37
 
38
  # Remove punctuation if checkbox is selected
39
  if remove_punct and input_str:
40
- # Specify the characters to remove
41
- characters_to_remove = "β—‹β–‘()〔〕:\"。·, ?ㆍ" + punctuation
42
- translating = str.maketrans('', '', characters_to_remove)
43
- input_str = input_str.translate(translating)
44
 
45
- # Display the input text after processing
46
- st.write("Processed input:", input_str)
47
 
48
  # Predict and display the classification scores if input is provided
49
  if st.button("Classify"):
50
  if input_str:
51
- predictions = model_pipeline(input_str)
52
  data = pd.DataFrame(predictions)
53
  data=data.sort_values(by='score', ascending=True)
54
  data.label = data.label.astype(str)
@@ -67,17 +75,13 @@ if st.button("Classify"):
67
  marker=dict(color=[colors[i % len(colors)] for i in range(len(data))]), # Cycle through colors
68
 
69
  ))
70
- fig.update_traces(width=0.4)
71
 
72
  fig.update_layout(
73
  height=300, # Custom height
74
  xaxis_title='Score',
75
  yaxis_title='',
76
  title='Model predictions and scores',
77
- margin=dict(l=100, r=200, t=50, b=50),
78
  uniformtext_minsize=8,
79
  uniformtext_mode='hide',
80
  )
81
  st.plotly_chart(figure_or_data=fig, use_container_width=True)
82
- else:
83
- st.write("Please enter some text to classify.")
 
21
  # from huggingface_hub import InferenceClient
22
  # client = InferenceClient(model="bdsl/HanmunRoBERTa")
23
 
24
+ def strip_input_str(x):
25
+ # Specify the characters to remove
26
+ characters_to_remove = "β—‹β–‘()〔〕:\"。·, ?ㆍ" + punctuation
27
+ translating = str.maketrans('', '', characters_to_remove)
28
+ input_str = input_str.translate(translating)
29
+
30
+ return input_str.strip()
31
+
32
  # Load the pipeline with the HanmunRoBERTa model
33
  model_pipeline = pipeline(task="text-classification", model="bdsl/HanmunRoBERTa")
34
 
 
41
  remove_punct = st.checkbox(label="Remove punctuation", value=True)
42
 
43
  # Text area for user input
44
+ input_str = st.text_area(
45
+ "Input text",
46
+ height=150,
47
+ value=" 權ηŸ₯ ι«˜ιΊ— εœ‹δΊ‹θ‡£ζŸθ¨€γ€‚ δΌζƒŸε°ι‚¦, θ‡ͺ ζ­ζ„ηŽ‹ η„‘ε—£θ–¨ι€δΉ‹εΎŒ, θΎ›ζ—½ 子 禑 ε†’ε§“η«Šδ½θ€….")
48
 
49
  # Remove punctuation if checkbox is selected
50
  if remove_punct and input_str:
51
+ input_str = strip_input_str(input_str)
 
 
 
52
 
53
+ # Display the input text after processing
54
+ st.write("Processed input:", input_str)
55
 
56
  # Predict and display the classification scores if input is provided
57
  if st.button("Classify"):
58
  if input_str:
59
+ predictions = model_pipeline(input_str, top_k=None)
60
  data = pd.DataFrame(predictions)
61
  data=data.sort_values(by='score', ascending=True)
62
  data.label = data.label.astype(str)
 
75
  marker=dict(color=[colors[i % len(colors)] for i in range(len(data))]), # Cycle through colors
76
 
77
  ))
 
78
 
79
  fig.update_layout(
80
  height=300, # Custom height
81
  xaxis_title='Score',
82
  yaxis_title='',
83
  title='Model predictions and scores',
 
84
  uniformtext_minsize=8,
85
  uniformtext_mode='hide',
86
  )
87
  st.plotly_chart(figure_or_data=fig, use_container_width=True)