Pavel Malov committed on
Commit
28f6ce1
·
1 Parent(s): d2af509
Files changed (5) hide show
  1. app.py +15 -12
  2. inference.py +81 -0
  3. requirements.txt +2 -0
  4. resources/model.ckpt +3 -0
  5. resources/tag_mapping.json +172 -0
app.py CHANGED
@@ -1,24 +1,27 @@
1
  import streamlit as st
 
2
 
3
 
4
  st.set_page_config(layout="wide")
5
 
6
- st.markdown("""
7
- <style>
8
- .big-font {
9
- font-size:300px !important;
10
- }
11
- </style>
12
- """, unsafe_allow_html=True)
13
-
14
- st.title("ArxivTitlePicker")
15
  st.write("This app helps define category of your scientific paper based on its name and abstract.")
16
  name = st.text_input("Paste here name of your paper")
17
  abstract = st.text_area("Paste here abstract of your paper")
18
 
19
- if name != '':
20
- st.text("Your paper:\nName: " + name + '.\nAbstract: ' + abstract)
 
 
 
21
 
22
  if st.button("Start processing"):
23
  if name == '':
24
- st.write('<p style="font-family:sans-serif; color:Red; font-size: 21px;">Please, provide name of the paper!πŸ™‡β€β™‚οΈ</p>', unsafe_allow_html=True)
 
 
 
 
 
 
 
 
1
  import streamlit as st
2
+ from inference import InferenceModel
3
 
4
 
5
  st.set_page_config(layout="wide")
6
 
7
st.title("ArxivTopicPicker")
st.write("This app helps define category of your scientific paper based on its name and abstract.")
name = st.text_input("Paste here name of your paper")
abstract = st.text_area("Paste here abstract of your paper")


def _load_model() -> InferenceModel:
    """Build the inference model and run one throwaway prediction to warm it up."""
    loaded = InferenceModel()
    loaded.inference('load')
    return loaded


# Streamlit re-executes this script on every widget interaction, so building
# the model at top level reloaded DistilBERT and the checkpoint on each rerun.
# Use whichever resource cache the installed Streamlit provides; if neither
# exists, fall back to the original behaviour (rebuild on every run).
_resource_cache = getattr(st, "cache_resource", None) or getattr(st, "experimental_singleton", None)
if _resource_cache is not None:
    _load_model = _resource_cache(_load_model)

model = _load_model()

if st.button("Start processing"):
    # A title made only of whitespace carries no information either.
    if not name.strip():
        st.write('<p style="font-family:sans-serif; color:Red; font-size: 21px;">Please, provide name of the paper!🙇‍♂️</p>', unsafe_allow_html=True)
    else:
        # The abstract is optional: classify on the title alone when it is empty.
        input_text = name + '. ' + abstract if abstract != '' else name + '.'
        top_topics = model.inference(input_text)
        if not top_topics:
            st.text("We don't know yet😰")
        else:
            st.text(top_topics)
inference.py ADDED
@@ -0,0 +1,81 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import json
2
+ import torch
3
+ from torch import nn
4
+ from typing import List, Dict, Set
5
+ from pathlib import Path
6
+ from transformers import DistilBertTokenizer, DistilBertModel
7
+
8
+
9
class Nnet(nn.Module):
    """Two-layer classifier: Linear -> ReLU -> BatchNorm1d -> Linear.

    The layers live in an ``nn.Sequential`` stored under the attribute name
    ``nnet`` so that ``state_dict`` keys stay ``nnet.0.*``, ``nnet.2.*`` and
    ``nnet.3.*``.

    Args:
        in_features: Size of each input vector (default 768).
        hidden_features: Width of the hidden layer (default 256).
        num_classes: Number of output logits (default 85).
    """

    def __init__(
        self,
        in_features: int = 768,
        hidden_features: int = 256,
        num_classes: int = 85,
    ) -> None:
        super().__init__()

        self.nnet = nn.Sequential(
            nn.Linear(in_features, hidden_features),
            nn.ReLU(),
            nn.BatchNorm1d(hidden_features),
            nn.Linear(hidden_features, num_classes),
        )

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Return the output of the sequential stack applied to ``x``."""
        return self.nnet(x)
22
+
23
+
24
class ClassificationHead(nn.Module):
    """``Nnet`` with weights restored from a checkpoint file.

    Args:
        ckpt_path: Path of the checkpoint to read. The loaded object must be
            a mapping with a ``'state_dict'`` entry.
    """

    def __init__(self, ckpt_path: str = "resources/model.ckpt") -> None:
        super().__init__()

        self.nnet = Nnet()

        # map_location="cpu" lets a checkpoint that was saved from a GPU
        # process load on a machine without CUDA.
        # NOTE: torch.load unpickles the file; only load trusted checkpoints.
        ckpt = torch.load(ckpt_path, map_location="cpu")
        # strict=False tolerates extra entries in the checkpoint, but it also
        # hides *missing* ones, which would leave layers randomly initialised.
        # Surface that case instead of silently predicting noise.
        load_result = self.nnet.load_state_dict(ckpt['state_dict'], strict=False)
        if load_result.missing_keys:
            import warnings

            warnings.warn(
                "Checkpoint %r has no weights for: %s"
                % (ckpt_path, ", ".join(load_result.missing_keys))
            )

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Add a leading dimension to ``x`` and return the wrapped network's output."""
        return self.nnet(x.unsqueeze(0))
35
+
36
class InferenceModel:
    """Text -> topic labels: tokenizer, DistilBERT encoder, classification head.

    Attributes are set in ``__init__``:
        tokenizer / bert: pretrained ``distilbert-base-uncased`` objects.
        head: ``ClassificationHead`` producing one logit per class.
        mapping: class index -> topic name, built from
            ``resources/tag_mapping.json``.
    """

    def __init__(self) -> None:
        self.tokenizer: DistilBertTokenizer = DistilBertTokenizer.from_pretrained('distilbert-base-uncased')
        self.bert: DistilBertModel = DistilBertModel.from_pretrained("distilbert-base-uncased")
        self.head: nn.Module = ClassificationHead()

        tag_to_topic = json.loads(Path('resources/tag_mapping.json').read_text(encoding='utf-8'))
        # Distinct topic names; the empty string marks a tag without a topic.
        values: Set[str] = set(tag_to_topic.values())
        values.discard('')
        # Enumerating the set directly gave a different index -> label
        # assignment on every interpreter start (str hashes are randomised),
        # so predictions were labelled inconsistently. Sorting makes the
        # assignment reproducible.
        # NOTE(review): assumes the head was trained with labels in sorted
        # order - confirm against the training code.
        self.mapping: Dict[int, str] = dict(enumerate(sorted(values)))

    def topp(self, probs: torch.Tensor, p: float = 0.95) -> torch.Tensor:
        """Return indices of the most probable classes covering mass ``p``.

        Classes are taken in descending probability until the running sum
        first exceeds ``p``; the class that crosses the threshold is included,
        so the result is never empty for a non-empty ``probs``.

        Args:
            probs: 1-D tensor of class probabilities.
            p: Cumulative-probability threshold.

        Returns:
            1-D tensor of class indices, most probable first.
        """
        sorted_probs, sorted_inds = torch.sort(probs, descending=True)
        accum = torch.cumsum(sorted_probs, dim=0)
        crossed = torch.nonzero(accum > p)
        if crossed.numel() == 0:
            # Threshold never reached: keep every class instead of raising.
            return sorted_inds
        last = int(crossed[0].item())
        # +1 keeps the crossing class itself; without it a single class with
        # probability above p produced an empty selection.
        return sorted_inds[:last + 1]

    def get_lables(self, classes: torch.Tensor) -> str:
        """Return the topic names for ``classes``, one per line.

        Each name is followed by ``'\\n'``; an empty ``classes`` gives ``''``.
        """
        return ''.join(self.mapping[idx] + '\n' for idx in classes.tolist())

    def inference(self, x: str) -> str:
        """Classify ``x`` and return the selected topic names, one per line."""
        self.bert.eval()
        self.head.eval()
        with torch.no_grad():
            # tokenize: str -> Tokens
            encoded_input = self.tokenizer(x, return_tensors='pt', truncation=True)
            # get embedding: Tokens -> Embeddings -> MeanEmbedding
            embeddings = self.bert(**encoded_input)
            mean_embedding = embeddings[0].mean(dim=1)[0]
            # get probs: MeanEmbedding -> Probs
            probs = self.head(mean_embedding).softmax(-1)[0]

        # get top_p classes: Probs -> classes covering 95% of the mass
        topp_classes = self.topp(probs)
        # map classes to labels
        return self.get_lables(topp_classes)
81
+
requirements.txt ADDED
@@ -0,0 +1,2 @@
 
 
 
1
+ torch==1.13
2
+ transformers
resources/model.ckpt ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:7d581cc499259712e58a5cf251c7c2d8054d8d67cad61bde6c0e936ff4e285ca
3
+ size 2643089
resources/tag_mapping.json ADDED
@@ -0,0 +1,172 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ {
2
+ "60g15": "Probability",
3
+ "62-07": "Statistics Theory",
4
+ "62f15": "Parametric inference",
5
+ "62g08": "Nonparametric inference",
6
+ "62h30": "Multivariate analysis",
7
+ "62m45": "Inference from stochastic processes",
8
+ "65k10": "Mathematical programming, optimization and variational techniques",
9
+ "68q32": "Theory of computing",
10
+ "68t01": "Artificial intelligence",
11
+ "68t05": "Artificial intelligence",
12
+ "68t10": "Artificial intelligence",
13
+ "68t20": "Artificial intelligence",
14
+ "68t27": "Artificial intelligence",
15
+ "68t30": "Artificial intelligence",
16
+ "68t37": "Artificial intelligence",
17
+ "68t40": "Artificial intelligence",
18
+ "68t45": "Artificial intelligence",
19
+ "68t50": "Artificial intelligence",
20
+ "68txx": "Artificial intelligence",
21
+ "68u10": "Computing methodologies and applications",
22
+ "90c25": "Mathematical programming",
23
+ "90c26": "Mathematical programming",
24
+ "90c90": "Mathematical programming",
25
+ "91f20": "Other social and behavioral sciences (mathematical treatment)",
26
+ "92b20": "Mathematical biology in general",
27
+ "94a08": "Communication, information",
28
+ "97r40": "Mathematics education",
29
+ "astro-ph.im": "Instrumentation and Methods for Astrophysics",
30
+ "c.1.3": "Distributed, Parallel, and Cluster Computing",
31
+ "c.2.4": "Distributed, Parallel, and Cluster Computing",
32
+ "cmp-lg": "Computation and Language",
33
+ "cond-mat.dis-nn": "Disordered Systems and Neural Networks",
34
+ "cond-mat.stat-mech": "",
35
+ "cs.ai": "Artificial intelligence",
36
+ "cs.ar": "Hardware Architecture",
37
+ "cs.cc": "Computational Complexity",
38
+ "cs.ce": "Computational Engineering, Finance, and Science",
39
+ "cs.cg": "Computational Geometry",
40
+ "cs.cl": "Computation and Language",
41
+ "cs.cr": "Cryptography and Security",
42
+ "cs.cv": "Computer Vision and Pattern Recognition",
43
+ "cs.cy": "Computers and Society",
44
+ "cs.db": "Databases",
45
+ "cs.dc": "Distributed, Parallel, and Cluster Computing",
46
+ "cs.dl": "Digital Libraries",
47
+ "cs.dm": "Discrete Mathematics",
48
+ "cs.ds": "Data Structures and Algorithms",
49
+ "cs.et": "Emerging Technologies",
50
+ "cs.fl": "Formal Languages and Automata Theory",
51
+ "cs.gr": "Graphics",
52
+ "cs.gt": "Computer Science and Game Theory",
53
+ "cs.hc": "Human-Computer Interaction",
54
+ "cs.ir": "Information Retrieval",
55
+ "cs.it": "Information Theory",
56
+ "cs.lg": "Machine Learning",
57
+ "cs.lo": "Logic in Computer Science",
58
+ "cs.ma": "Multiagent Systems",
59
+ "cs.mm": "Multimedia",
60
+ "cs.ms": "Mathematical Software",
61
+ "cs.na": "Numerical Analysis",
62
+ "cs.ne": "Neural and Evolutionary Computing",
63
+ "cs.ni": "Networking and Internet Architecture",
64
+ "cs.pf": "Performance",
65
+ "cs.pl": "Programming Languages",
66
+ "cs.ro": "Robotics",
67
+ "cs.sc": "Symbolic Computation",
68
+ "cs.sd": "Sound",
69
+ "cs.se": "Software Engineering",
70
+ "cs.si": "Social and Information Networks",
71
+ "cs.sy": "Systems and Control",
72
+ "d.1.3": "Distributed, Parallel, and Cluster Computing",
73
+ "d.1.6": "Programming Languages",
74
+ "d.2.2": "Software Engineering",
75
+ "d.3.1": "Programming Languages",
76
+ "d.3.2": "Programming Languages",
77
+ "d.3.3": "Programming Languages",
78
+ "e.2": "Databases",
79
+ "e.4": "Information Theory",
80
+ "eess.as": "Sound",
81
+ "eess.iv": "Computer Vision and Pattern Recognition",
82
+ "eess.sp": "Signal Processing",
83
+ "f.1.1": "Formal Languages and Automata Theory",
84
+ "f.1.3": "Computational Complexity",
85
+ "f.2": "Data Structures and Algorithms",
86
+ "f.2.2": "Data Structures and Algorithms",
87
+ "f.4.1": "Logic in Computer Science",
88
+ "f.4.2": "Logic in Computer Science",
89
+ "g.1.2": "Numerical Analysis",
90
+ "g.1.3": "Numerical Analysis",
91
+ "g.1.6": "Numerical Analysi",
92
+ "g.2.2": "Discrete Mathematics",
93
+ "g.3": "Discrete Mathematics",
94
+ "h.1.1": "Information Theory",
95
+ "h.1.2": "Human-Computer Interaction",
96
+ "h.2.4": "Databases",
97
+ "h.2.8": "Databases",
98
+ "h.3.1": "Information Retrieval",
99
+ "h.3.3": "Information Retrieval",
100
+ "h.3.4": "Information Retrieval",
101
+ "h.3.5": "Information Retrieval",
102
+ "h.5.1": "Sound",
103
+ "h.5.2": "Sound",
104
+ "h.5.3": "Sound",
105
+ "i.2": "Artificial intelligence",
106
+ "i.2.0": "Artificial intelligence",
107
+ "i.2.1": "Artificial intelligence",
108
+ "i.2.10": "Artificial intelligence",
109
+ "i.2.11": "Artificial intelligence",
110
+ "i.2.2": "Artificial intelligence",
111
+ "i.2.3": "Artificial intelligence",
112
+ "i.2.4": "Artificial intelligence",
113
+ "i.2.6": "Artificial intelligence",
114
+ "i.2.7": "Artificial intelligence",
115
+ "i.2.8": "Artificial intelligence",
116
+ "i.2.9": "Artificial intelligence",
117
+ "i.4": "Computer Vision and Pattern Recognition",
118
+ "i.4.1": "Computer Vision and Pattern Recognition",
119
+ "i.4.10": "Computer Vision and Pattern Recognition",
120
+ "i.4.3": "Computer Vision and Pattern Recognition",
121
+ "i.4.5": "Computer Vision and Pattern Recognition",
122
+ "i.4.6": "Computer Vision and Pattern Recognition",
123
+ "i.4.7": "Computer Vision and Pattern Recognition",
124
+ "i.4.8": "Computer Vision and Pattern Recognition",
125
+ "i.4.9": "Computer Vision and Pattern Recognition",
126
+ "i.5": "Computer Vision and Pattern Recognition",
127
+ "i.5.1": "Computer Vision and Pattern Recognition",
128
+ "i.5.2": "Computer Vision and Pattern Recognition",
129
+ "i.5.3": "Computer Vision and Pattern Recognition",
130
+ "i.5.4": "Computer Vision and Pattern Recognition",
131
+ "i.5.5": "Computer Vision and Pattern Recognition",
132
+ "j.2": "Computer Applications",
133
+ "j.3": "Computer Applications",
134
+ "j.4": "Computer Applications",
135
+ "j.5": "Computer Applications",
136
+ "k.3.2": "Computers and Society",
137
+ "math.ag": "Algebraic Geometry",
138
+ "math.co": "Combinatorics",
139
+ "math.ct": "Category Theory",
140
+ "math.dg": "Differential Geometry",
141
+ "math.ds": "Dynamical Systems",
142
+ "math.fa": "Functional Analysis",
143
+ "math.it": "Information Theory",
144
+ "math.lo": "Logic",
145
+ "math.na": "Numerical Analysis",
146
+ "math.oc": "Optimization and Control",
147
+ "math.pr": "Probability",
148
+ "math.st": "Statistics Theory",
149
+ "nlin.ao": "Adaptation and Self-Organizing Systems",
150
+ "nlin.cd": "Chaotic Dynamics",
151
+ "nlin.cg": "Cellular Automata and Lattice Gases",
152
+ "physics.ao-ph": "Astrophysics",
153
+ "physics.bio-ph": "Biological Physics",
154
+ "physics.chem-ph": "Chemical Physics",
155
+ "physics.comp-ph": "Computational Physics",
156
+ "physics.data-an": "Data Analysis, Statistics and Probability",
157
+ "physics.med-ph": "Medical Physics",
158
+ "physics.optics": "Optics",
159
+ "physics.soc-ph": "Physics and Society",
160
+ "q-bio.bm": "Biomolecules",
161
+ "q-bio.gn": "Genomics",
162
+ "q-bio.mn": "Molecular Networks",
163
+ "q-bio.nc": "Neurons and Cognition",
164
+ "q-bio.pe": "Populations and Evolution",
165
+ "q-bio.qm": "Quantitative Methods",
166
+ "quant-ph": "Quantum Physics",
167
+ "stat.ap": "Applications",
168
+ "stat.co": "Computation",
169
+ "stat.me": "Methodology",
170
+ "stat.ml": "Machine Learning",
171
+ "stat.th": "Statistics Theory"
172
+ }