Spaces:
Paused
Paused
File size: 1,022 Bytes
c120869 95964a8 c120869 95964a8 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 |
import torch
import torch.nn as nn
from transformers import AutoTokenizer, AutoModelForSequenceClassification
# Classify news text as fake / not fake with a fine-tuned DistilRoBERTa checkpoint.
# The tokenizer and model are loaded from the same checkpoint so the vocabulary
# matches the weights.
MODEL_NAME = "vikram71198/distilroberta-base-finetuned-fake-news-detection"
tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME)
model = AutoModelForSequenceClassification.from_pretrained(MODEL_NAME)

# Following the same truncation & padding strategy used while training
encoded_input = tokenizer(
    "Enter any news article to be classified. Can be a list of articles too.",
    truncation=True,
    padding="max_length",
    max_length=512,
    return_tensors="pt",
)

# Inference only: no_grad() skips building the autograd graph, so the logits
# need no .detach() before being converted to NumPy below.
with torch.no_grad():
    output = model(**encoded_input)["logits"]

# Applying softmax here for single label classification (one row per input text).
softmax = nn.Softmax(dim=1)
prediction_probabilities = list(softmax(output).numpy())

# Each row is unpacked as two scores. Column 0 is treated as "not fake" and
# column 1 as "fake"; a tie is reported as "fake_news".
# NOTE(review): this column order is assumed, not read from model.config.id2label;
# confirm it against the checkpoint's label mapping.
predictions = [
    "not_fake_news" if not_fake_prob > fake_prob else "fake_news"
    for not_fake_prob, fake_prob in prediction_probabilities
]
print(predictions)
|