import pickle
import re

import torch
import torch.nn.functional as F

# Load the fine-tuned model, its tokenizer, and the id-to-label mapping.
model = torch.load("models/model", map_location='cpu')
tokenizer = torch.load("models/tokenizer")
with open("models/label_dict", 'rb') as file:
    label_dict = pickle.load(file)


def preprocess_string(tweet: str) -> str:
    """Lowercase, strip whitespace, and remove punctuation from a tweet."""
    tweet = tweet.lower().strip()
    tweet = re.sub(r'[^\w\s]', '', tweet)
    return tweet


def predict_single(tweet: str) -> str:
    """Return the predicted label for a single tweet."""
    clean_tweet = preprocess_string(tweet)
    inputs = tokenizer(clean_tweet, return_tensors='pt', truncation=True)
    output = model(**inputs)
    # Index of the highest-probability class.
    pred = torch.max(F.softmax(output.logits, dim=-1), dim=-1)[1]
    pred = pred.data.item()
    return label_dict[pred]


def predict_batch(tweets: list) -> list:
    """Return the predicted labels for a batch of tweets."""
    clean_tweets = [preprocess_string(tweet) for tweet in tweets]
    inputs = tokenizer(clean_tweets, return_tensors='pt', padding=True, truncation=True)
    outputs = model(**inputs)
    # Indices of the highest-probability class for each tweet.
    preds = torch.max(F.softmax(outputs.logits, dim=-1), dim=-1)[1]
    preds = preds.tolist()
    return [label_dict[pred] for pred in preds]
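

# --- Usage sketch (illustrative, not part of the original file) ---
# A minimal example of calling the two prediction helpers above, assuming the
# model, tokenizer, and label_dict files loaded successfully. The tweet text
# below is a hypothetical placeholder.
if __name__ == "__main__":
    print(predict_single("Loving the new update!"))
    print(predict_batch(["Great game last night", "This traffic is awful"]))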