# Scraped from a hosted file view (page links: raw, history, blame).
# Author: samyak152002 — commit 5a3da59, "Update script" — 1.25 kB.
import torch
from transformers import DistilBertModel, DistilBertTokenizer
# --- Inference setup -------------------------------------------------------
# Everything below runs on the CPU.
device = torch.device("cpu")

# Pretrained DistilBERT tokenizer (uncased vocabulary).
tokenizer = DistilBertTokenizer.from_pretrained("distilbert-base-uncased")

# NOTE(review): DistilBertCNN is neither defined nor imported in this file, so
# this line raises NameError as written. The unused DistilBertModel import
# suggests the class was meant to wrap it — TODO define or import it here.
model = DistilBertCNN(num_labels=3)
model.to(device)

# Restore the saved weights, mapping every tensor onto `device`.
# NOTE(review): torch.load unpickles the file — only load trusted checkpoints;
# if model.pt is a plain state dict, consider weights_only=True (PyTorch >= 1.13).
_weights = torch.load("model.pt", map_location=device)
model.load_state_dict(_weights)
del _weights  # drop the extra reference once the weights are copied in

# Inference mode (for nn.Module subclasses this disables dropout etc.).
model.eval()
# Define a function to predict the class of a given tweet
def classify_tweet(tweet):
inputs = tokenizer.encode_plus(
tweet,
add_special_tokens=True,
max_length=128,
padding="max_length",
truncation=True,
return_tensors="pt"
)
input_ids = inputs["input_ids"].to(device)
attention_mask = inputs["attention_mask"].to(device)
with torch.no_grad():
outputs = model(input_ids=input_ids, attention_mask=attention_mask)
logits = outputs[0]
predicted_class = torch.argmax(logits).item()
return predicted_class
if __name__ == "__main__":
    # Example usage: classify one hard-coded sample when run as a script.
    # Guarded so that importing this module (to reuse classify_tweet) does
    # not run the demo and print as a side effect.
    tweet = "This is a sample tweet."
    predicted_class = classify_tweet(tweet)
    print(f"Predicted Class: {predicted_class}")