import torch
from bert.tokenize import extract_inputs_masks, tokenize_encode_corpus
from torch.utils.data import TensorDataset, DataLoader


def predict(samples, tokenizer, scaler, model, device, max_len, batch_size,
            return_scaled=False):
    """Run the BERT-based model on raw text samples and return predictions.

    Returns a flat list of predictions, inverse-transformed with `scaler`
    unless `return_scaled` is True, in which case the raw (scaled) model
    outputs are returned instead.
    """
    model.eval()

    # Tokenize the samples and extract the encoded ids and attention masks.
    encoded_corpus = tokenize_encode_corpus(tokenizer, samples, max_len)
    input_ids, attention_mask = extract_inputs_masks(encoded_corpus)
    input_ids = torch.tensor(input_ids).to(device)
    attention_mask = torch.tensor(attention_mask).to(device)

    # Batch the encoded samples for inference.
    dataset = TensorDataset(input_ids, attention_mask)
    dataloader = DataLoader(dataset, batch_size=batch_size)

    output = []
    for batch in dataloader:
        batch_inputs, batch_masks = tuple(b.to(device) for b in batch)
        with torch.no_grad():
            output += model(batch_inputs, batch_masks).view(-1).tolist()
    if return_scaled:
        return output

    # Map the scaled predictions back to the original target range.
    return scaler.inverse_transform([output]).reshape(-1).tolist()
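

# Illustrative usage sketch. The names below (bert_tokenizer, target_scaler,
# regression_model) are hypothetical placeholders for objects produced by the
# training pipeline; they are not defined in this module.
#
#     device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
#     predictions = predict(
#         samples=["first document", "second document"],
#         tokenizer=bert_tokenizer,
#         scaler=target_scaler,
#         model=regression_model.to(device),
#         device=device,
#         max_len=128,
#         batch_size=32,
#     )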