kotstantinovskii commited on
Commit
ac107e6
·
1 Parent(s): a82fc19

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +18 -9
app.py CHANGED
@@ -1,6 +1,6 @@
1
  import streamlit as st
2
 
3
- from torch import nn
4
 
5
  from model import ArxivModel, load_model
6
  from tokenizer import get_tokenizer
@@ -13,7 +13,7 @@ model = load_model()
13
  tokenizer = get_tokenizer()
14
 
15
  arxiv_model = ArxivModel(model, tokenizer)
16
- softmax = nn.Softmax(dim=1)
17
 
18
  st.markdown("### Classification of article topics")
19
 
@@ -27,9 +27,13 @@ with col2:
27
  summary_text = st.text_area("Write summary of article (optional)", key='arxiv_sum_input')
28
  click_button_text = st.button('Submit title and summary', key=1)
29
 
30
- if click_button_text and summary_text.strip() != "":
 
 
 
 
31
  text = title_text.strip() + '\t' + summary_text.strip()
32
- else:
33
  text = title_text.strip()
34
  text = text.strip()
35
 
@@ -48,10 +52,15 @@ if click_button_url and id_url != "":
48
  print(text)
49
 
50
  if text != "":
51
- idxs = arxiv_model.get_idx_class(text, thr=0.95)[:10]
52
-
53
- for idx, prob in idxs:
54
- if taxonomy.get(num_to_classes[idx], -1) != -1:
55
- st.markdown("{} \t {}%".format(taxonomy.get(num_to_classes[idx], -1), round(prob * 100, 1)))
 
 
 
 
 
56
  else:
57
  st.markdown("")
 
1
  import streamlit as st
2
 
3
+ from torch.nn import Softmax
4
 
5
  from model import ArxivModel, load_model
6
  from tokenizer import get_tokenizer
 
13
  tokenizer = get_tokenizer()
14
 
15
  arxiv_model = ArxivModel(model, tokenizer)
16
+ softmax = Softmax(dim=1)
17
 
18
  st.markdown("### Classification of article topics")
19
 
 
27
  summary_text = st.text_area("Write summary of article (optional)", key='arxiv_sum_input')
28
  click_button_text = st.button('Submit title and summary', key=1)
29
 
30
+ if click_button_text and title_text.strip() == "":
31
+ text = ""
32
+ if summary_text.strip() != "":
33
+ st.markdown(f'<p style="color:#FF2D00;font-size:18px">Please, input title</p>', unsafe_allow_html=True)
34
+ elif click_button_text and title_text.strip() != "" and summary_text.strip() != "":
35
  text = title_text.strip() + '\t' + summary_text.strip()
36
+ elif click_button_text and title_text.strip() != "":
37
  text = title_text.strip()
38
  text = text.strip()
39
 
 
52
  print(text)
53
 
54
  if text != "":
55
+ idxs = arxiv_model.get_idx_class(text, thr=0.95)
56
+ print(len(idxs))
57
+ if len(idxs) > 85:
58
+ st.markdown("#### Sorry, model can't classify the article with high confidence")
59
+ else:
60
+ idxs = idxs[:10]
61
+ st.markdown("#### The model have defined:")
62
+ for idx, prob in idxs:
63
+ if taxonomy.get(num_to_classes[idx], -1) != -1:
64
+ st.markdown("{} \t {}%".format(taxonomy.get(num_to_classes[idx], -1), round(prob * 100, 1)))
65
  else:
66
  st.markdown("")