"""Streamlit demo: classify a piece of text into arXiv categories with BERT."""

import streamlit as st
import torch
from torch import nn
from transformers import BertModel, AutoTokenizer, AutoModel, pipeline
from time import time
import matplotlib.pyplot as plt

# device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
device = 'cpu'

# dicts for encoding / decoding labels
# `labels` maps an arXiv category code to the class index used by the model;
# `labels_decoder` maps the same code to a human-readable name.
labels = {'cs.NE': 0, 'cs.CL': 1, 'cs.AI': 2, 'stat.ML': 3, 'cs.CV': 4, 'cs.LG': 5}
labels_decoder = {
    'cs.NE': 'Neural and Evolutionary Computing',
    'cs.CL': 'Computation and Language',
    'cs.AI': 'Artificial Intelligence',
    'stat.ML': 'Machine Learning (stat)',
    'cs.CV': 'Computer Vision',
    'cs.LG': 'Machine Learning',
}

model_name = 'bert-base-uncased'
tokenizer = AutoTokenizer.from_pretrained(model_name)


class BertClassifier(nn.Module):
    """BERT encoder followed by dropout, a linear layer and a ReLU."""

    def __init__(self, n_classes, dropout=0.5, model_name='bert-base-uncased'):
        """Build the classifier.

        Args:
            n_classes: Number of output classes.
            dropout: Dropout probability applied to the pooled BERT output.
            model_name: Name of the pretrained BERT checkpoint to load.
        """
        super().__init__()
        self.bert = BertModel.from_pretrained(model_name)
        self.dropout = nn.Dropout(dropout)
        # Read the width from the encoder config instead of hard-coding 768
        # (768 is what 'bert-base-uncased' reports, so saved weights still fit).
        self.linear = nn.Linear(self.bert.config.hidden_size, n_classes)
        self.relu = nn.ReLU()

    def forward(self, input_id, mask):
        """Return the ReLU-activated class scores for the given token ids and mask."""
        # With return_dict=False the encoder returns a tuple; only the pooled
        # output (second element) is used.
        _, pooled_output = self.bert(
            input_ids=input_id, attention_mask=mask, return_dict=False
        )
        dropout_output = self.dropout(pooled_output)
        linear_output = self.linear(dropout_output)
        return self.relu(linear_output)


# allow_output_mutation=True stops Streamlit from hashing the (large) returned
# model on every rerun to detect mutation.
@st.cache(suppress_st_warning=True, allow_output_mutation=True)
def build_model():
    """Create the classifier, load its trained weights and switch it to eval mode.

    Returns:
        The ready-to-use ``BertClassifier`` placed on ``device``.
    """
    model = BertClassifier(n_classes=len(labels))
    st.markdown("Model created")
    # NOTE: torch.load unpickles the file, so only trusted checkpoints
    # should ever be placed at this path.
    model.load_state_dict(torch.load('model_weights_1.pt', map_location=device))
    model.to(device)
    model.eval()
    st.markdown("Model weights loaded")
    return model


def inference(txt, mode=None):
    """Score ``txt`` against every known category with the module-level ``model``.

    Args:
        txt: Text to classify.
        mode: Currently unused; kept so existing calls keep working.

    Returns:
        A list of ``(label, score)`` tuples in the order of ``labels``, where
        the scores are percentages of the summed model outputs. If every model
        output is zero, all scores are ``0.0`` instead of NaN.
    """
    # NOTE(review): newlines are removed rather than replaced by a space, which
    # joins words across line breaks - presumably matches training; confirm.
    t2 = tokenizer(
        txt.lower().replace('\n', ''),
        padding='max_length',
        max_length=512,
        truncation=True,
        return_tensors="pt",
    )
    inp2 = t2['input_ids'].to(device)
    # NOTE(review): the mask gets an extra leading dimension here - presumably
    # this mirrors the shape used during training; confirm before changing.
    mask2 = t2['attention_mask'].unsqueeze(0).to(device)

    # No gradients are needed for prediction, so skip building the autograd graph.
    with torch.no_grad():
        out = model(inp2, mask2)
    out = out.cpu().numpy().reshape(-1)

    # The final ReLU can zero every output; dividing by a zero sum would
    # produce NaN scores, so normalise only when there is something to scale.
    total = out.sum()
    if total > 0:
        out = out / total * 100
    return list(zip(labels, out.tolist()))


model = build_model()

st.markdown("### Privet, mir!")
st.markdown("", unsafe_allow_html=True)

text = st.text_area("ENTER TEXT HERE")

if not text.strip():
    # Nothing to classify yet: skip the (slow) forward pass on empty input.
    st.markdown("Enter some text above to run the classifier.")
else:
    start_time = time()
    st.markdown("INFERENCE STARTS ...")
    res = inference(text, mode=None)
    # Highest score first; the sort is stable, so ties keep their label order.
    res.sort(key=lambda item: item[1], reverse=True)

    st.markdown("INFERENCE RESULT:")
    for lbl, score in res:
        if score >= 1:
            st.markdown(f"[ {lbl:<7}] {labels_decoder[lbl]:<35} {score:.1f}%")

    # Plot the top categories until together they cover at least 95 %.
    fig, ax = plt.subplots()
    total = 0
    for lbl, score in res:
        if total >= 95:
            break
        ax.barh(lbl, score)
        total += score
    # Pass the figure explicitly (the argument-less form relies on the global
    # pyplot figure) and close it so figures do not pile up across reruns.
    st.pyplot(fig)
    plt.close(fig)

    st.markdown(f"cycle time = {time() - start_time:.2f} s.")