File size: 915 Bytes
c1c5bd9
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
import numpy as np
import scipy
import torch


def _perplexity(logits, labels, pad_token=3):
    for i in range(len(labels)-1, -1, -1):
        if labels[i] != pad_token:
            last_not_pad_id = i
            break
    logits = logits[:last_not_pad_id + 1]
    labels = labels[:last_not_pad_id + 1]
    log_probas = scipy.special.log_softmax(logits, axis=1).astype(np.float32)
    log_probas = [log_probas[i][labels[i]] for i in range(len(labels))]
    l = np.mean(log_probas)
    return 2 ** (-l)


def perplexity(logits, labels, pad_token=3):
    """Return the mean of the per-sequence perplexities over a batch.

    ``logits`` and ``labels`` are iterated in lockstep along their first
    axis; each pair is scored by ``_perplexity`` and the scores are averaged.

    Args:
        logits: Batch of per-sequence logits, as a ``torch.Tensor`` or any
            iterable of array-likes.
        labels: Batch of per-sequence label ids, as a ``torch.Tensor`` or any
            iterable of array-likes. Each entry is cast to ``int``.
        pad_token: Label id treated as padding, forwarded to ``_perplexity``.

    Returns:
        ``np.mean`` of the per-sequence perplexity values.
    """
    # Tensors are detached and moved to host memory so that the rest of the
    # computation runs purely on NumPy arrays.
    if isinstance(labels, torch.Tensor):
        labels = labels.detach().cpu().numpy()
    if isinstance(logits, torch.Tensor):
        logits = logits.detach().cpu().numpy()

    per_sequence = [
        _perplexity(np.array(seq_logits), np.array(seq_labels).astype(int), pad_token)
        for seq_logits, seq_labels in zip(logits, labels)
    ]
    return np.mean(per_sequence)