|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
"""Perplexity Metric.""" |
|
|
|
import datasets |
|
import numpy as np |
|
import torch |
|
from torch.nn import CrossEntropyLoss |
|
from transformers import AutoModelForCausalLM, AutoTokenizer |
|
|
|
import evaluate |
|
from evaluate import logging |
|
|
|
|
|
# No canonical citation is provided for this measurement.
_CITATION = """\
"""

# Short user-facing description; also prepended to the class docstring by the decorator below.
_DESCRIPTION = """
Perplexity (PPL) can be used to evaluate how similar a dataset is to the distribution of text that a given model was trained on.
It is defined as the exponentiated average negative log-likelihood of a sequence, calculated with exponent base `e`.

For more information, see https://huggingface.co/docs/transformers/perplexity
"""

# Argument / return documentation for `Perplexity.compute`, including runnable doctests.
_KWARGS_DESCRIPTION = """
Args:
    model_id (str): model used for calculating Perplexity
            NOTE: Perplexity can only be calculated for causal language models.
                    This includes models such as gpt2, causal variations of bert,
                    causal versions of t5, and more (the full list can be found
                    in the AutoModelForCausalLM documentation here:
                    https://huggingface.co/docs/transformers/master/en/model_doc/auto#transformers.AutoModelForCausalLM )

    data (list of str): input data, each separate text snippet
        is one list entry.
    batch_size (int): the batch size to run texts through the model. Defaults to 16.
    add_start_token (bool): whether to add the start token to the texts,
        so the perplexity can include the probability of the first word. Defaults to True.
        Requires the model's tokenizer to define a BOS token.
    device (str): device to run on: 'cpu', 'cuda', or 'gpu' (an alias for 'cuda').
        Defaults to 'cuda' when available, otherwise 'cpu'.
    max_length (int): maximum number of tokens per text, counting the start token
        when add_start_token=True. Longer texts are truncated to this length.
        Defaults to None (no truncation).
Returns:
    A dictionary with:
        perplexities (list of float): the perplexity score of each text in the input list.
        mean_perplexity (float): the mean of those scores.
    Texts are only truncated when max_length is set; otherwise they are passed to the
    model in full, so texts longer than the model's maximum input length may raise an error.
Examples:
    Example 1:
        >>> perplexity = evaluate.load("perplexity", module_type="measurement")
        >>> data = ["lorem ipsum", "Happy Birthday!", "Bienvenue"]
        >>> results = perplexity.compute(model_id='gpt2',
        ...                              add_start_token=False,
        ...                              data=data) # doctest:+ELLIPSIS
        >>> print(list(results.keys()))
        ['perplexities', 'mean_perplexity']
        >>> print(round(results["mean_perplexity"], 0))
        647.0
        >>> print(round(results["perplexities"][0], 0))
        32.0

    Example 2:
        >>> from datasets import load_dataset
        >>> perplexity = evaluate.load("perplexity", module_type="measurement")
        >>> data = load_dataset("wikitext", "wikitext-2-raw-v1", split="test")["text"][:10] # doctest: +SKIP
        >>> data = [s for s in data if s!='']
        >>> results = perplexity.compute(model_id='gpt2',
        ...                              data=data)
        >>> print(list(results.keys()))
        ['perplexities', 'mean_perplexity']
        >>> print(round(results["mean_perplexity"], 2)) # doctest: +SKIP
        576.76
        >>> print(round(results["perplexities"][0], 2)) # doctest: +SKIP
        889.28
"""
|
|
|
|
|
@evaluate.utils.file_utils.add_start_docstrings(_DESCRIPTION, _KWARGS_DESCRIPTION)
class Perplexity(evaluate.Measurement):
    # Measurement that scores each input text by its perplexity under a causal
    # language model; the user-facing contract is in _DESCRIPTION/_KWARGS_DESCRIPTION.

    def _info(self):
        # Declare the measurement's metadata and its single string input column "data".
        return evaluate.MeasurementInfo(
            module_type="measurement",
            description=_DESCRIPTION,
            citation=_CITATION,
            inputs_description=_KWARGS_DESCRIPTION,
            features=datasets.Features(
                {
                    "data": datasets.Value("string"),
                }
            ),
            reference_urls=["https://huggingface.co/docs/transformers/perplexity"],
        )

    @staticmethod
    def _pad_right(sequences, device):
        """Stack variable-length token-id lists into right-padded ``(batch, longest)`` tensors.

        Pad positions get id 0 and attention 0. Because they sit after every real
        token, causal attention keeps them from influencing real-token logits, and the
        shifted attention mask removes them from the loss, so any valid id works
        (id 0 exists in every vocabulary). Returns ``(input_ids, attention_mask)``
        as int64 tensors on ``device``.
        """
        width = max(len(seq) for seq in sequences)
        input_ids = torch.zeros((len(sequences), width), dtype=torch.long)
        attention_mask = torch.zeros((len(sequences), width), dtype=torch.long)
        for row, seq in enumerate(sequences):
            input_ids[row, : len(seq)] = torch.tensor(seq, dtype=torch.long)
            attention_mask[row, : len(seq)] = 1
        return input_ids.to(device), attention_mask.to(device)

    def _compute(
        self, data, model_id, batch_size: int = 16, add_start_token: bool = True, device=None, max_length=None
    ):
        """Compute the perplexity of every text in ``data`` under the causal LM ``model_id``.

        Args:
            data: sequence of input strings, one text per entry.
            model_id: name/path passed to ``AutoModelForCausalLM``/``AutoTokenizer.from_pretrained``.
            batch_size: number of texts per forward pass.
            add_start_token: prepend the tokenizer's BOS token so the first real token is scored.
            device: "cpu", "cuda" or "gpu" (alias of "cuda"); defaults to cuda when available.
            max_length: if truthy, truncate each text so that, including BOS when
                ``add_start_token`` is set, it is at most ``max_length`` tokens.

        Returns:
            ``{"perplexities": [float per text], "mean_perplexity": mean of those}``.
            An empty ``data`` yields an empty list and a NaN mean.

        Raises:
            AssertionError: on an unknown device, a missing BOS token when
                ``add_start_token`` is set, or texts too short to score.
        """
        if device is not None:
            assert device in ["gpu", "cpu", "cuda"], "device should be one of 'gpu', 'cpu' or 'cuda'."
            if device == "gpu":
                device = "cuda"
        else:
            device = "cuda" if torch.cuda.is_available() else "cpu"

        # Nothing to score: avoid loading the model and crashing inside the tokenizer.
        if len(data) == 0:
            return {"perplexities": [], "mean_perplexity": float("nan")}

        tokenizer = AutoTokenizer.from_pretrained(model_id)

        # The BOS id is prepended to every text below, so it must exist regardless of
        # whether max_length is given (previously only checked when it was).
        if add_start_token:
            assert (
                tokenizer.bos_token is not None
            ), "Input model must already have a BOS token if using add_start_token=True. Please use a different model, or set add_start_token=False"

        # Leave room for the BOS token inside the max_length budget.
        max_tokenized_len = max_length - 1 if add_start_token and max_length else max_length

        # Tokenize without padding; each batch is padded on its own below. Padding the
        # whole dataset at once required a pad token even for batch_size=1 and padded
        # every text to the longest one in the dataset.
        token_ids = tokenizer(
            data,
            add_special_tokens=False,
            truncation=bool(max_tokenized_len),
            max_length=max_tokenized_len,
        )["input_ids"]

        # Each scored sequence needs at least one (input, label) pair after shifting.
        lengths = [len(ids) for ids in token_ids]
        if add_start_token:
            assert all(n >= 1 for n in lengths), "Each input text must be at least one token long."
        else:
            assert all(
                n >= 2 for n in lengths
            ), "When add_start_token=False, each input text must be at least two tokens long. Run with add_start_token=True if inputting strings of only one token, and remove all empty input strings."

        model = AutoModelForCausalLM.from_pretrained(model_id)
        model = model.to(device)

        loss_fct = CrossEntropyLoss(reduction="none")
        ppls = []

        for start_index in logging.tqdm(range(0, len(token_ids), batch_size)):
            batch_ids = token_ids[start_index : start_index + batch_size]
            if add_start_token:
                # BOS goes directly before the first real token (never before padding).
                batch_ids = [[tokenizer.bos_token_id] + list(ids) for ids in batch_ids]
            encoded_batch, attn_mask = self._pad_right(batch_ids, device)

            with torch.no_grad():
                out_logits = model(encoded_batch, attention_mask=attn_mask).logits

            # Position t predicts token t+1: drop the last logit and the first label.
            shift_logits = out_logits[..., :-1, :].contiguous()
            shift_labels = encoded_batch[..., 1:].contiguous()
            shift_attention_mask_batch = attn_mask[..., 1:].contiguous()

            # Per-text mean NLL over real (unpadded) positions, exponentiated.
            perplexity_batch = torch.exp(
                (loss_fct(shift_logits.transpose(1, 2), shift_labels) * shift_attention_mask_batch).sum(1)
                / shift_attention_mask_batch.sum(1)
            )

            ppls.extend(perplexity_batch.tolist())

        return {"perplexities": ppls, "mean_perplexity": np.mean(ppls)}
|
|