File size: 1,049 Bytes
36da459
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
import torch
import torch.nn.functional as F
import pickle
import re


# Load the inference artefacts once, at import time. map_location='cpu' lets a
# checkpoint that was saved from a GPU be loaded on a CPU-only host.
# NOTE(review): torch.load and pickle.load unpickle arbitrary Python objects and
# can execute code while doing so -- only point these paths at trusted files.
# NOTE(review): on recent PyTorch releases torch.load defaults to
# weights_only=True; loading a fully pickled model/tokenizer there may need an
# explicit weights_only=False -- confirm against the deployed torch version.
model = torch.load("models/model", map_location='cpu')
# A module pickled straight after training keeps training=True; switch to
# inference mode so dropout (and similar layers) behave deterministically.
model.eval()
tokenizer = torch.load("models/tokenizer")

# Maps predicted class indices (ints, as produced by the predict_* helpers)
# to the label values returned to callers.
with open("models/label_dict", 'rb') as file:
    label_dict = pickle.load(file)

def preprocess_string(tweet: str) -> str:
    """Normalise raw tweet text before it is tokenised.

    The text is lower-cased and trimmed of surrounding whitespace. Then every
    character that is neither a word character nor whitespace is deleted, so
    punctuation disappears while letters, digits, underscores and spaces stay.

    Args:
        tweet: The raw tweet text.

    Returns:
        The normalised text.
    """
    lowered = tweet.lower()
    trimmed = lowered.strip()
    # Punctuation is removed after trimming, so a space that preceded removed
    # punctuation at the end of the text can remain (e.g. "hi !" -> "hi ").
    return re.sub(r'[^\w\s]', '', trimmed)

def predict_single(tweet: str) -> str:
    """Classify a single tweet and return its label.

    Args:
        tweet: Raw tweet text; it is normalised with ``preprocess_string``
            before tokenisation.

    Returns:
        The ``label_dict`` entry for the highest-scoring class index.
    """
    clean_tweet = preprocess_string(tweet)
    # truncation=True keeps over-long tweets within the model's input limit.
    inputs = tokenizer(clean_tweet, return_tensors='pt', truncation=True)
    # Inference only: don't record an autograd graph (saves memory and time).
    with torch.no_grad():
        logits = model(**inputs).logits
    # softmax is monotonic, so the argmax of the raw logits picks the same
    # class without the extra computation.
    pred = logits.argmax(dim=-1).item()
    return label_dict[pred]

def predict_batch(tweets):
    """Classify several tweets with a single forward pass.

    Args:
        tweets: Iterable of raw tweet strings; each one is normalised with
            ``preprocess_string`` before tokenisation.

    Returns:
        A list with one ``label_dict`` entry per input tweet, in input order.
        An empty input yields an empty list without calling the model.
    """
    clean_tweets = [preprocess_string(tweet) for tweet in tweets]
    if not clean_tweets:
        # Nothing to classify; avoid handing an empty batch to the
        # tokenizer/model.
        return []
    # padding=True pads to the longest tweet so the batch is rectangular;
    # truncation=True keeps over-long tweets within the model's input limit.
    inputs = tokenizer(clean_tweets, return_tensors='pt', padding=True, truncation=True)
    # Inference only: don't record an autograd graph (saves memory and time).
    with torch.no_grad():
        logits = model(**inputs).logits
    # softmax is monotonic, so argmax over the raw logits gives the same
    # per-row class indices.
    preds = logits.argmax(dim=-1).tolist()
    return [label_dict[pred] for pred in preds]