kotstantinovskii commited on
Commit
904bb36
·
1 Parent(s): 91fd099

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +10 -16
app.py CHANGED
@@ -1,5 +1,3 @@
1
- import time
2
-
3
  import streamlit as st
4
 
5
  from torch.nn import Softmax
@@ -7,34 +5,30 @@ from torch.nn import Softmax
7
  from model import ArxivModel, load_model
8
  from tokenizer import get_tokenizer
9
 
10
- from lables import num_to_classes
11
 
12
 
13
- start_time = time.time()
14
  model = load_model()
15
- end_time = time.time()
16
-
17
- print("Model:", (end_time - start_time))
18
-
19
- start_time = time.time()
20
  tokenizer = get_tokenizer()
21
- end_time = time.time()
22
-
23
- print("Tokenizer:", (end_time - start_time))
24
 
25
  arxiv_model = ArxivModel(model, tokenizer)
26
  softmax = Softmax(dim=1)
27
 
28
  st.markdown("### Classification of article topics")
29
- # st.markdown("<img width=200px src='https://rozetked.me/images/uploads/dwoilp3BVjlE.jpg'>", unsafe_allow_html=True)
30
 
31
- text = st.text_area("Write title and (optional) summary of article")
 
 
32
  text = text.strip()
33
 
34
  if text != "":
35
  idxs = arxiv_model.get_idx_class(text, thr=0.95)
 
36
 
37
- for idx in idxs:
38
- st.markdown(num_to_classes[idx])
 
 
 
39
  else:
40
  st.markdown("")
 
 
 
1
  import streamlit as st
2
 
3
  from torch.nn import Softmax
 
5
  from model import ArxivModel, load_model
6
  from tokenizer import get_tokenizer
7
 
8
+ from lables import num_to_classes, taxonomy
9
 
10
 
 
11
  model = load_model()
 
 
 
 
 
12
  tokenizer = get_tokenizer()
 
 
 
13
 
14
  arxiv_model = ArxivModel(model, tokenizer)
15
  softmax = Softmax(dim=1)
16
 
17
  st.markdown("### Classification of article topics")
 
18
 
19
+ title_text = st.text_area("Write title of article")
20
+ summary_text = st.text_area("Write summary of article (optional)")
21
+ text = title_text.strip() + " " + summary_text.strip()
22
  text = text.strip()
23
 
24
  if text != "":
25
  idxs = arxiv_model.get_idx_class(text, thr=0.95)
26
+ idxs = idxs[:10]
27
 
28
+ for idx, prob in idxs:
29
+ for tax in taxonomy:
30
+ if num_to_classes[idx] in tax[0]:
31
+ st.markdown("{} \t {}%".format(tax[1], round(prob*100, 1)))
32
+ break
33
  else:
34
  st.markdown("")