"""Run a BERT sequence classifier on one sentence and compute a loss.

Loads the ``textattack/bert-base-uncased-yelp-polarity`` checkpoint via
``from_pretrained`` (requires network access or a local Hugging Face cache),
predicts a label for ``TEXT``, then computes the classification loss of the
same input against a fixed target class. Both results are printed.
"""

import torch

from transformers import AutoTokenizer, BertForSequenceClassification

# Single source of truth for the checkpoint used by tokenizer and model.
MODEL_NAME = "textattack/bert-base-uncased-yelp-polarity"
TEXT = "GEMINI AI Just got updated"
# Target class index for the loss demonstration. NOTE(review): presumably 1 is
# the "positive" class for this checkpoint -- confirm via model.config.id2label.
TARGET_CLASS_ID = 1

tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME)
# from_pretrained returns the model in eval mode, so dropout is disabled for
# both the prediction and the loss computed below.
model = BertForSequenceClassification.from_pretrained(MODEL_NAME)

# return_tensors="pt" yields PyTorch tensors batched with a single example.
inputs = tokenizer(TEXT, return_tensors="pt")

# Prediction only: no autograd graph is needed here.
with torch.no_grad():
    logits = model(**inputs).logits

# The batch holds exactly one example, so argmax over all logits is that
# example's predicted class index.
predicted_class_id = logits.argmax().item()
predicted_label = model.config.id2label[predicted_class_id]
print(f"Predicted label: {predicted_label}")

# Reloading the checkpoint with num_labels=len(id2label) would reproduce the
# same config and weights, so the already-loaded model is reused for the loss.
num_labels = len(model.config.id2label)

labels = torch.tensor([TARGET_CLASS_ID])
# Computed with autograd enabled (outside no_grad) so that loss.backward()
# could be called if this were a fine-tuning step.
loss = model(**inputs, labels=labels).loss
print(f"Loss: {round(loss.item(), 2)}")