from transformers import AutoTokenizer, AutoModelForSequenceClassification
import torch.nn as nn
tokenizer = AutoTokenizer.from_pretrained("vikram71198/distilroberta-base-finetuned-fake-news-detection")
model = AutoModelForSequenceClassification.from_pretrained("vikram71198/distilroberta-base-finetuned-fake-news-detection")
# Follow the same truncation & padding strategy used during training
encoded_input = tokenizer("Enter any news article to be classified. Can be a list of articles too.",
                          truncation=True, padding="max_length", max_length=512, return_tensors="pt")
output = model(**encoded_input)["logits"]
# Detach the logits from the computation graph
detached_output = output.detach()
# Apply softmax to turn the logits into class probabilities (single-label classification)
softmax = nn.Softmax(dim=1)
prediction_probabilities = list(softmax(detached_output).numpy())
# Each probability pair is ordered (not_fake_news, fake_news)
predictions = []
for not_fake_prob, fake_prob in prediction_probabilities:
    predictions.append("not_fake_news" if not_fake_prob > fake_prob else "fake_news")
print(predictions)
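
Since the tokenizer also accepts a list of articles, the same pipeline can be wrapped in a small helper for batch inference. Below is a minimal sketch under the same label ordering as above (index 0 = not_fake_news, index 1 = fake_news); the helper name classify is illustrative, and torch.no_grad() replaces the detach() calls by skipping gradient tracking altogether:

import torch

def classify(articles):
    # Tokenize with the same truncation & padding strategy as above
    encoded = tokenizer(articles, truncation=True, padding="max_length",
                        max_length=512, return_tensors="pt")
    # no_grad() builds no computation graph, so no detach() is needed
    with torch.no_grad():
        logits = model(**encoded)["logits"]
    probabilities = torch.softmax(logits, dim=1)
    # argmax over the class dimension, mapped to the label order used above
    labels = ["not_fake_news", "fake_news"]
    return [labels[i] for i in probabilities.argmax(dim=1).tolist()]

print(classify(["First news article ...", "Second news article ..."]))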