Spaces:
Sleeping
Sleeping
# coding=utf-8
# author: xusong
# time: 2022/8/22 12:06
import numpy as np | |
import torch | |
from transformers import FillMaskPipeline | |
class PerplexityPipeline(FillMaskPipeline):
    """Fill-mask pipeline that scores a sentence by masked-LM perplexity.

    Every position between the first and the last token of the encoded
    sentence is masked in turn. The model predicts each masked position and
    the probability it assigns to the original token is collected. The
    perplexity is ``exp`` of the mean negative log-probability.
    """

    def create_sequential_mask(self, input_data, mask_count=1):
        """Expand one encoded sentence into a batch with one mask per row.

        Row ``k`` of the returned ``input_ids`` is a copy of the sentence with
        position ``k + 1`` replaced by the tokenizer's mask token id.

        Args:
            input_data: Mapping holding a 2-D ``"input_ids"`` tensor and,
                optionally, ``"token_type_ids"`` and ``"attention_mask"``.
                NOTE(review): labels are read from row 0 only, so this
                presumably expects a single sentence (first dim == 1).
            mask_count: Accepted for backward compatibility but ignored; the
                number of masked rows is always ``seq_length - 2``.

        Returns:
            A tuple ``(new_data, masked_lm_positions, masked_lm_labels)``:
            the expanded tensors, the masked positions (``1..seq_length-2``),
            and the original token id at each of those positions.
        """
        input_ids = input_data["input_ids"]
        _, seq_length = input_ids.shape
        # The first and last positions are never masked (presumably the
        # special tokens added by the tokenizer). Clamp so that a sequence
        # shorter than two tokens yields an empty batch instead of an error.
        num_masks = max(seq_length - 2, 0)

        new_input_ids = torch.repeat_interleave(input_ids, repeats=num_masks, dim=0)
        new_data = {"input_ids": new_input_ids}
        # Not every tokenizer emits token_type_ids; copy only what is present.
        for key in ("token_type_ids", "attention_mask"):
            if key in input_data:
                new_data[key] = torch.repeat_interleave(
                    input_data[key], repeats=num_masks, dim=0
                )

        masked_lm_positions = list(range(1, num_masks + 1))
        masked_lm_labels = input_ids[0, 1:num_masks + 1].tolist()
        # Mask the shifted diagonal: row k gets its mask at column k + 1.
        rows = torch.arange(num_masks, device=new_input_ids.device)
        new_input_ids[rows, rows + 1] = self.tokenizer.mask_token_id

        return new_data, masked_lm_positions, masked_lm_labels

    def __call__(self, input_text, *args, **kwargs):
        """Compute the perplexity of a single sentence.

        Args:
            input_text: The sentence to score.
            *args: Unused.
            **kwargs: Unused.

        Returns:
            ``None`` when ``input_text`` is not a ``str``. Otherwise a dict
            with ``"tokens"`` (a list of ``{"token", "prob"}`` entries, one
            per masked position) and ``"ppl"`` (the perplexity as a float).
            When there is nothing to mask, ``"tokens"`` is empty and
            ``"ppl"`` is ``nan``.
        """
        if not isinstance(input_text, str):
            return None

        # 1. create sequential mask
        encoded = self.tokenizer(input_text, return_tensors='pt')
        masked_inputs, _, masked_lm_labels = self.create_sequential_mask(encoded.data)
        if not masked_lm_labels:
            # Nothing to score; avoid dividing by zero below.
            return {"tokens": [], "ppl": float("nan")}
        labels = torch.tensor(masked_lm_labels)

        # 2. predict (inference only, so skip building the autograd graph)
        with torch.no_grad():
            model_outputs = self.model(**masked_inputs)

        # 3. compute perplexity
        tokens = []
        for i in range(len(masked_lm_labels)):
            row = slice(i, i + 1)
            # Feed one masked row at a time to the parent postprocess and
            # restrict its candidates to the original token.
            single_output = {
                "input_ids": masked_inputs["input_ids"][row],
                "logits": model_outputs["logits"][row],
            }
            outputs = self.postprocess(single_output, target_ids=labels[row])
            tokens.append({"token": outputs[0]["token_str"],
                           "prob": outputs[0]["score"]})

        mean_neg_log_prob = -sum(np.log(token["prob"]) for token in tokens) / len(tokens)
        return {"tokens": tokens, "ppl": float(np.exp(mean_neg_log_prob))}