anthony.galtier
Added light code files
06a851e
import torch
from bert.tokenize import extract_inputs_masks, tokenize_encode_corpus
from torch.utils.data import TensorDataset, DataLoader
def predict(samples, tokenizer, scaler, model, device, max_len, batch_size,
            return_scaled=False):
    """Run ``model`` over ``samples`` in batches and return flat predictions.

    Args:
        samples: Inputs to score; passed unchanged to
            ``tokenize_encode_corpus``.
        tokenizer: Tokenizer forwarded to ``tokenize_encode_corpus``.
        scaler: Object whose ``inverse_transform`` maps model outputs back to
            the original target scale (presumably a fitted scikit-learn
            scaler -- confirm against the training code).
        model: Callable invoked as ``model(input_ids, attention_mask)`` on
            each batch. It is switched to eval mode and left in that mode.
        device: Device every batch is moved to before the forward pass.
        max_len: Forwarded to ``tokenize_encode_corpus`` (presumably the
            maximum encoded sequence length -- confirm in ``bert.tokenize``).
        batch_size: Batch size handed to the ``DataLoader``.
        return_scaled: When true, return the raw model outputs and skip
            ``scaler.inverse_transform``.

    Returns:
        A flat Python list with every value the model produced, in input
        order. Values are passed through ``scaler.inverse_transform`` unless
        ``return_scaled`` is true. An empty list is returned when the model
        produced no output at all.
    """
    model.eval()

    encoded_corpus = tokenize_encode_corpus(tokenizer, samples, max_len)
    input_ids, attention_mask = extract_inputs_masks(encoded_corpus)

    # Build the dataset on the host and move only the current batch to
    # ``device`` inside the loop, so the whole corpus never has to fit in
    # device memory at once.
    dataset = TensorDataset(torch.tensor(input_ids),
                            torch.tensor(attention_mask))
    dataloader = DataLoader(dataset, batch_size=batch_size)

    output = []
    with torch.no_grad():
        for batch_inputs, batch_masks in dataloader:
            batch_preds = model(batch_inputs.to(device),
                                batch_masks.to(device))
            # reshape (unlike view) also accepts non-contiguous outputs.
            output.extend(batch_preds.reshape(-1).tolist())

    # Nothing to rescale: avoid handing a zero-width row to the scaler.
    if return_scaled or not output:
        return output

    # The predictions are passed as a single row, exactly as before.
    # NOTE(review): this relies on the scaler broadcasting its fitted
    # parameters across that row -- confirm how the scaler was fitted.
    rescaled = scaler.inverse_transform([output])
    return rescaled.reshape(-1).tolist()