import torch
from transformers import DistilBertModel, DistilBertTokenizer

# Load the tokenizer and model.
tokenizer = DistilBertTokenizer.from_pretrained("distilbert-base-uncased")

# NOTE(review): DistilBertCNN is neither imported nor defined in this file; it
# must be defined (or imported) before this line, otherwise this raises
# NameError. Presumably it adds custom classification layers on top of
# DistilBertModel — confirm against its definition.
model = DistilBertCNN(num_labels=3)

# Run inference on the CPU.
device = torch.device("cpu")
model.to(device)

# Load the saved model state dictionary. map_location remaps tensors that were
# saved on another device (e.g. a GPU) onto the CPU.
# NOTE(review): torch.load unpickles the file, so only load checkpoints from a
# trusted source (consider weights_only=True on PyTorch >= 1.13).
model.load_state_dict(torch.load("model.pt", map_location=device))

# Put the model in evaluation mode so train-only behaviour (e.g. dropout, if the
# model has any) is switched off for inference.
model.eval()


def classify_tweet(tweet: str) -> int:
    """Return the predicted class index for a single tweet.

    The tweet is tokenized to exactly 128 tokens (truncated or padded) and run
    through the module-level ``model`` on ``device`` without gradient tracking.

    Args:
        tweet: Raw tweet text.

    Returns:
        Index of the highest-scoring logit along the class (last) axis.
    """
    # Calling the tokenizer directly is the current API (encode_plus is
    # deprecated). Padding to max_length keeps the input length fixed at 128.
    inputs = tokenizer(
        tweet,
        add_special_tokens=True,
        max_length=128,
        padding="max_length",
        truncation=True,
        return_tensors="pt",
    )
    input_ids = inputs["input_ids"].to(device)
    attention_mask = inputs["attention_mask"].to(device)

    with torch.no_grad():
        outputs = model(input_ids=input_ids, attention_mask=attention_mask)

    # NOTE(review): assumes the model returns either a (batch, num_labels)
    # logits tensor or a tuple/ModelOutput whose first element is the logits.
    # Confirm against DistilBertCNN.forward.
    logits = outputs[0]
    # Reduce over the class axis explicitly. A bare argmax flattens every
    # dimension and only yields a class index because the batch size is 1.
    return torch.argmax(logits, dim=-1).item()


if __name__ == "__main__":
    # Example usage (runs only when executed as a script, not on import).
    tweet = "This is a sample tweet."
    predicted_class = classify_tweet(tweet)
    print(f"Predicted Class: {predicted_class}")