File size: 1,006 Bytes
52f43c9 d9d80a0 77d60b1 d9d80a0 bedfda3 52f43c9 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 |
import pickle

import streamlit as st
import torch
from transformers import AutoTokenizer, AutoModelForSequenceClassification
# Run on the GPU when one is available, otherwise fall back to the CPU.
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
model_name = "MoritzLaurer/mDeBERTa-v3-base-mnli-xnli"


@st.cache_resource
def _load_model(name):
    """Load the tokenizer and model for *name*, with the model placed on ``device``.

    Streamlit re-executes the whole script on every widget interaction;
    ``st.cache_resource`` keeps one loaded copy instead of re-reading the
    weights on each rerun.
    """
    loaded_tokenizer = AutoTokenizer.from_pretrained(name)
    loaded_model = AutoModelForSequenceClassification.from_pretrained(name)
    # classify() moves its inputs to `device`; the weights must live there
    # too, otherwise GPU runs fail with a device-mismatch error.
    loaded_model.to(device)
    loaded_model.eval()
    return loaded_tokenizer, loaded_model


tokenizer, model = _load_model(model_name)

# NOTE(review): pickle.load can execute arbitrary code while unpickling; only
# ship/load a trusted articles_list.pkl.
with open('articles_list.pkl', 'rb') as file:
    articles_list = pickle.load(file)

# Candidate labels: a 15-item prefix of each of the first 20 articles
# (presumably the articles are strings, i.e. the first 15 characters -- TODO confirm).
label_names = [article[0:15] for article in articles_list[0:20]]
def classify(text, labels=None, hypothesis_template="This text is about {}."):
    """Zero-shot classify *text* against candidate labels using the NLI model.

    The loaded model is an NLI classifier: its logits mean
    entailment / neutral / contradiction, not topic labels, so they cannot be
    zipped directly with arbitrary label names.  Instead, each candidate label
    is turned into a hypothesis sentence via ``hypothesis_template``, paired
    with *text* as the premise, and scored by its entailment logit.  The
    entailment logits are softmax-normalised across the labels, so the
    returned percentages sum to roughly 100.

    Args:
        text: The text (premise) to classify.
        labels: Candidate labels; defaults to the module-level ``label_names``.
            Presumably strings -- they are inserted with ``str.format``.
        hypothesis_template: ``str.format`` template with one ``{}`` slot that
            turns a label into a hypothesis sentence.

    Returns:
        Dict mapping each label to its probability in percent, rounded to one
        decimal place.  Empty when there are no labels.
    """
    if labels is None:
        labels = label_names
    labels = list(labels)
    if not labels:
        return {}

    hypotheses = [hypothesis_template.format(label) for label in labels]
    # One (premise, hypothesis) pair per label, batched into a single forward pass.
    encoded = tokenizer(
        [text] * len(labels),
        hypotheses,
        truncation=True,
        padding=True,
        return_tensors="pt",
    )
    # Move every tensor (ids, attention mask, ...) to the model's device; the
    # attention mask matters now that the batch is padded.
    encoded = {key: tensor.to(device) for key, tensor in encoded.items()}

    with torch.no_grad():  # inference only: no autograd graph needed
        logits = model(**encoded).logits

    # Find the "entailment" column from the model config.  Fall back to 0,
    # which matches the model card's label order (entailment, neutral,
    # contradiction) -- confirm if the model is swapped.
    entailment_idx = next(
        (
            idx
            for name, idx in model.config.label2id.items()
            if str(name).lower().startswith("entail")
        ),
        0,
    )
    probabilities = torch.softmax(logits[:, entailment_idx], dim=-1).tolist()
    return {
        label: round(float(prob) * 100, 1)
        for label, prob in zip(labels, probabilities)
    }
# User-facing entry point: a single text box; nothing is classified until
# the user has entered non-empty text.
text = st.text_input('Enter some text:')  # Input field for new text
if text:
    # Display the raw label -> percentage mapping returned by classify().
    scores = classify(text)
    st.text(scores)