# arxiv_class / app.py
# Author: kotstantinovskii
# Commit 904bb36: "Update app.py" (861 bytes)
import streamlit as st
from torch.nn import Softmax
from model import ArxivModel, load_model
from tokenizer import get_tokenizer
from lables import num_to_classes, taxonomy
def _cache_resource(func):
    """Cache the result of *func* across Streamlit script reruns.

    Uses ``st.cache_resource`` when this Streamlit version provides it, and
    falls back to the legacy ``st.cache(allow_output_mutation=True)`` on
    releases that predate it.
    """
    cache_resource = getattr(st, "cache_resource", None)
    if cache_resource is not None:
        return cache_resource(func)
    return st.cache(allow_output_mutation=True)(func)


@_cache_resource
def _load_components():
    """Load the model and tokenizer and wrap them in an ``ArxivModel``.

    Returns:
        A ``(model, tokenizer, arxiv_model)`` tuple.
    """
    loaded_model = load_model()
    loaded_tokenizer = get_tokenizer()
    return loaded_model, loaded_tokenizer, ArxivModel(loaded_model, loaded_tokenizer)


# Streamlit re-executes this whole script on every widget interaction, so the
# heavy objects go through a cached loader instead of being rebuilt each time.
model, tokenizer, arxiv_model = _load_components()
# Not referenced by the rest of this script; kept so the module-level name
# stays available.
softmax = Softmax(dim=1)
st.markdown("### Classification of article topics")
title_text = st.text_area("Write title of article")
summary_text = st.text_area("Write summary of article (optional)")

# Combine title and summary, skipping blank parts so no stray separator
# space is left when either one is empty.
text = " ".join(part for part in (title_text.strip(), summary_text.strip()) if part)

if text:
    # NOTE(review): get_idx_class is assumed to return a sliceable sequence of
    # (class index, probability) pairs; the meaning of thr=0.95 lives in
    # ArxivModel — confirm there. Only the first 10 entries are displayed.
    idxs = arxiv_model.get_idx_class(text, thr=0.95)[:10]
    for idx, prob in idxs:
        label = num_to_classes[idx]
        percent = round(prob * 100, 1)
        # Show the human-readable name (tax[1]) of the first taxonomy entry
        # whose first field contains the label.
        # NOTE(review): if tax[0] is a string, `in` is a substring test —
        # confirm it cannot match the wrong category.
        for tax in taxonomy:
            if label in tax[0]:
                st.markdown("{} \t {}%".format(tax[1], percent))
                break
        else:
            # No taxonomy entry matched: show the raw label rather than
            # silently dropping this prediction from the list.
            st.markdown("{} \t {}%".format(label, percent))