|
import pickle

import streamlit as st
import torch
from transformers import AutoTokenizer, AutoModelForSequenceClassification
|
# Pick GPU when available; inputs in classify() are moved to this device,
# so the model must live on it too.
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")

# Multilingual zero-shot NLI checkpoint (mDeBERTa-v3 fine-tuned on MNLI/XNLI).
model_name = "MoritzLaurer/mDeBERTa-v3-base-mnli-xnli"
tokenizer = AutoTokenizer.from_pretrained(model_name)
# BUGFIX: the model was previously left on CPU while classify() moved the
# inputs to `device`, which crashes with a device-mismatch error on CUDA.
model = AutoModelForSequenceClassification.from_pretrained(model_name).to(device)
model.eval()  # inference only — disable dropout etc.
|
|
|
# Load the pickled article list produced elsewhere in the project.
# NOTE(review): pickle.load on an untrusted file can execute arbitrary
# code — only load files you produced yourself.
with open('articles_list.pkl', 'rb') as file:
    articles_list = pickle.load(file)

# Use the first 15 characters of the first 20 articles as display labels.
# NOTE(review): presumably each article is a string — confirm against the
# code that wrote articles_list.pkl.
label_names = [article[:15] for article in articles_list[:20]]
|
|
|
def classify(text):
    """Classify *text* and return a mapping of label name -> score in percent.

    Scores are softmax probabilities over the model's logits, rounded to
    one decimal place.

    NOTE(review): this MNLI/XNLI checkpoint emits NLI logits
    (entailment / neutral / contradiction), so zip() pairs at most the
    first few entries of ``label_names`` — confirm this is the intended
    zero-shot usage.
    """
    # `encoded` avoids shadowing the builtin `input`; .to(device) moves
    # input_ids AND attention_mask (previously the mask was dropped).
    encoded = tokenizer(text, truncation=True, return_tensors="pt").to(device)
    with torch.no_grad():  # inference only — skip autograd bookkeeping
        output = model(**encoded)
    probs = torch.softmax(output["logits"][0], -1).tolist()
    return {name: round(float(p) * 100, 1) for p, name in zip(probs, label_names)}
|
|
|
|
|
# Minimal Streamlit UI: prompt for a line of text and, once something is
# entered, display the classification scores as plain text.
user_text = st.text_input('Enter some text:')
if user_text:
    st.text(classify(user_text))