# Source: Hugging Face Space page (Space status: Paused).
import torch
import torch.nn as nn
from transformers import AutoTokenizer, AutoModelForSequenceClassification
# Hub checkpoint used for both the tokenizer and the classifier; loading both
# from the same name keeps token ids consistent with the model's vocabulary.
MODEL_NAME = "vikram71198/distilroberta-base-finetuned-fake-news-detection"

# Label for each logit column: column 0 -> "not_fake_news", column 1 -> "fake_news".
# NOTE(review): presumably this matches the checkpoint's training label order —
# confirm against model.config.id2label.
LABELS = ("not_fake_news", "fake_news")

tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME)
model = AutoModelForSequenceClassification.from_pretrained(MODEL_NAME)

# Following the same truncation & padding strategy used while training.
# The input may be a single article string or a list of article strings;
# each one becomes one row of the batch.
encoded_input = tokenizer(
    "Enter any news article to be classified. Can be a list of articles too.",
    truncation=True,
    padding="max_length",
    max_length=512,
    return_tensors="pt",
)

# Inference only: under no_grad autograd does not record the forward pass,
# so no computation graph (and the memory it holds) is built.
with torch.no_grad():
    output = model(**encoded_input)["logits"]

# Kept so existing readers of this name still work; under no_grad the logits
# already have no graph attached, so this is just a tensor sharing storage.
detached_output = output.detach()

# Softmax across the class dimension (dim=1) turns each row of logits into
# probabilities for single-label classification.
softmax = nn.Softmax(dim=1)
prediction_probabilities = list(softmax(detached_output).numpy())

# Binary decision per article: the larger probability wins; an exact tie goes
# to "fake_news", matching the original strict `x > y` comparison.
predictions = [
    LABELS[0] if not_fake_prob > fake_prob else LABELS[1]
    for not_fake_prob, fake_prob in prediction_probabilities
]
print(predictions)