import math |
import torch |
from tqdm import tqdm |
from dataclasses import dataclass |
from contextlib import nullcontext |
from typing import Mapping, Optional, Tuple |
from accelerate import Accelerator |
from collections import defaultdict |
from transformers.modeling_outputs import BaseModelOutputWithPast |
def optional_grad_ctx(with_grad=False):
    """Pick the gradient context for a block of computation.

    Args:
        with_grad: When truthy, autograd is left untouched (a no-op ``nullcontext``);
            otherwise gradient tracking is disabled via ``torch.no_grad()``.

    Returns:
        A context manager intended for use in a ``with`` statement.
    """
    return nullcontext() if with_grad else torch.no_grad()
def move_to_device(data, device):
    """Recursively move every tensor inside ``data`` to ``device``.

    ``data`` may be a tensor, a mapping, a list, a tuple (including namedtuples),
    or any nesting of these. Container types are preserved, and non-tensor leaves
    are returned unchanged.

    Args:
        data: The (possibly nested) object to move.
        device: Target passed to ``Tensor.to(device=...)``.

    Returns:
        An object with the same structure as ``data``, with tensors on ``device``.
    """
    if isinstance(data, Mapping):
        return type(data)({k: move_to_device(v, device) for k, v in data.items()})
    if isinstance(data, tuple) and hasattr(data, "_make"):
        # namedtuple constructors take fields positionally, so
        # ``type(data)(iterable)`` would raise; ``_make`` accepts an iterable.
        return data._make(move_to_device(v, device) for v in data)
    if isinstance(data, (tuple, list)):
        return type(data)(move_to_device(v, device) for v in data)
    if isinstance(data, torch.Tensor):
        return data.to(device=device)
    return data
def compute_loss(logits, labels, shift=False):
    """Compute cross-entropy, aggregated over the whole batch and per sample.

    Positions whose label is ``-100`` (``cross_entropy``'s default ``ignore_index``)
    contribute zero loss and are excluded from the token counts.

    Args:
        logits: Prediction scores indexed as ``(batch, seq, vocab)``.
        labels: Target ids indexed as ``(batch, seq)``; ``-100`` marks ignored positions.
        shift: If True, pair position ``t`` of ``logits`` with position ``t + 1`` of
            ``labels`` by dropping the last logit step and the first label.

    Returns:
        A tuple ``(loss, batch_loss, valid_token_num)``:
            loss: Scalar mean loss over all valid tokens in the batch. When there
                are no valid tokens it is 0.
            batch_loss: ``(batch,)`` mean loss per sample. It is 0 for samples
                without valid tokens.
            valid_token_num: ``(batch,)`` number of non-ignored labels per sample.
    """
    if shift:
        logits = logits[:, :-1, :].contiguous()
        labels = labels[:, 1:].contiguous()
    labels = labels.to(logits.device)
    batch_size = logits.shape[0]
    token_loss = torch.nn.functional.cross_entropy(
        logits.flatten(0, 1),
        labels.reshape(-1),
        reduction="none",
    ).reshape(batch_size, -1)

    valid_token_num = (labels != -100).sum(-1)
    all_valid_token_num = valid_token_num.sum()
    # Ignored positions have exactly zero loss, so a zero count always pairs with
    # a zero sum. Clamping the divisor to 1 yields 0 for those entries without
    # creating a 0/0 NaN (which would otherwise need masking and can trip NaN
    # checks or anomaly detection).
    loss = token_loss.sum() / all_valid_token_num.clamp(min=1)
    batch_loss = token_loss.sum(-1) / valid_token_num.clamp(min=1)
    return loss, batch_loss, valid_token_num
@torch.no_grad()
def evaluate_perplexity(model, dataloader, accelerator: Optional[Accelerator] = None):
    """Compute perplexity over ``dataloader``, macro-averaged over sample ids.

    Each batch must carry an ``"index"`` tensor giving the sample id of every row.
    Rows sharing an id are pooled into one token-weighted mean loss. The result is
    ``exp`` of the unweighted mean of those per-id losses. Ids whose rows contain
    no valid (non ``-100``) label tokens are skipped, because their mean is undefined.

    Args:
        model: Callable as ``model(**batch)``. If it has a ``memory`` attribute,
            ``memory.reset()`` is called before every batch. When its output
            provides a non-None ``batch_loss``, that value and ``valid_token_num``
            are used. Otherwise the loss is computed from ``output.logits`` and
            ``batch["labels"]``.
        dataloader: Iterable of dict batches. The ``"index"`` and optional
            ``"length"`` entries are removed before the forward pass.
        accelerator: Optional ``Accelerator``. A plain ``DataLoader`` is passed
            through ``prepare``, and per-row results are gathered when more than
            one process is running.

    Returns:
        The perplexity as a Python float.

    Raises:
        ValueError: If no sample contained any valid label token.
    """
    # Exact type check on purpose: an already-prepared loader (presumably a
    # DataLoader subclass from accelerate) must not be prepared a second time.
    if accelerator is not None and type(dataloader) == torch.utils.data.DataLoader:
        dataloader = accelerator.prepare(dataloader)

    # sample id -> list of (summed token loss, valid token count) contributions
    all_loss = defaultdict(list)
    for x in tqdm(dataloader, desc="Computing Perplexity"):
        if hasattr(model, "memory"):
            model.memory.reset()

        index = x.pop("index")
        # "length" is not a model input; drop it before calling the model.
        x.pop("length", None)
        output = model(**x)

        # Dataclass outputs always define ``batch_loss`` (possibly as None), so
        # test the value rather than the attribute's existence.
        batch_loss = getattr(output, "batch_loss", None)
        if batch_loss is not None:
            valid_token_num = output.valid_token_num
        else:
            _, batch_loss, valid_token_num = compute_loss(output.logits, x["labels"], shift=True)

        index = index.tolist()
        batch_loss = batch_loss.tolist()
        valid_token_num = valid_token_num.tolist()

        if accelerator is not None and accelerator.num_processes > 1:
            index = accelerator.gather_for_metrics(index)
            batch_loss = accelerator.gather_for_metrics(batch_loss)
            valid_token_num = accelerator.gather_for_metrics(valid_token_num)

        for _id, _loss, _num in zip(index, batch_loss, valid_token_num):
            # batch_loss is a per-row mean; multiply back to a sum so rows of the
            # same id can be token-weighted.
            all_loss[_id].append((_loss * _num, _num))

    per_sample_loss = {}
    for _id, loss_and_num in all_loss.items():
        total_num = sum(num for _, num in loss_and_num)
        if total_num > 0:
            per_sample_loss[_id] = sum(loss_sum for loss_sum, _ in loss_and_num) / total_num

    if not per_sample_loss:
        raise ValueError("No sample with valid label tokens was found; perplexity is undefined.")
    return math.exp(sum(per_sample_loss.values()) / len(per_sample_loss))
@torch.no_grad()
def evaluate_generation(model, dataloader, accelerator: Optional[Accelerator] = None, tokenizer=None, return_new_tokens_only=True, **generation_config):
    """Generate text for every batch in ``dataloader`` and decode it with ``tokenizer``.

    Args:
        model: Object exposing ``generate(**batch, **generation_config)``. If it
            has a ``memory`` attribute, ``memory.reset()`` is called before every batch.
        dataloader: Iterable of dict batches containing at least ``"input_ids"``.
            The optional ``"index"`` (tensor of sample ids) and ``"length"``
            entries are removed before generation.
        accelerator: Optional ``Accelerator``. A plain ``DataLoader`` is passed
            through ``prepare``, and results are gathered when more than one
            process is running.
        tokenizer: Required. Generated ids are decoded with
            ``tokenizer.batch_decode(..., skip_special_tokens=True)``.
        return_new_tokens_only: If True, the first ``input_ids.shape[1]`` positions
            of each generated sequence are dropped before decoding. This assumes
            ``generate`` echoes the prompt, which is typical of decoder-only
            models; TODO confirm for the models used here.
        **generation_config: Forwarded verbatim to ``model.generate``.

    Returns:
        ``(all_indices, all_outputs)``: parallel lists of sample ids and decoded strings.

    Raises:
        ValueError: If ``tokenizer`` is None.
    """
    # Fail fast instead of crashing after the first (expensive) generate call.
    if tokenizer is None:
        raise ValueError("evaluate_generation requires a tokenizer to decode generated ids.")

    # Exact type check on purpose: an already-prepared loader must not be re-prepared.
    if accelerator is not None and type(dataloader) == torch.utils.data.DataLoader:
        dataloader = accelerator.prepare(dataloader)

    all_indices = []
    all_outputs = []
    # Running counter used to synthesize sample ids when a batch has no "index".
    next_index = 0
    for x in tqdm(dataloader, desc="Computing Generation"):
        if hasattr(model, "memory"):
            model.memory.reset()

        x.pop("length", None)
        indices = x.pop("index", None)
        if indices is None:
            # NOTE(review): with num_processes > 1 every process synthesizes the
            # same local ids, so gathered ids can collide. Supply "index" in the
            # batch for distributed runs.
            batch_size = x["input_ids"].shape[0]
            indices = list(range(next_index, next_index + batch_size))
            next_index += batch_size
        else:
            indices = indices.tolist()

        outputs = model.generate(**x, **generation_config)
        if return_new_tokens_only:
            outputs = outputs[:, x["input_ids"].shape[1]:]
        outputs = tokenizer.batch_decode(outputs, skip_special_tokens=True)

        if accelerator is not None and accelerator.num_processes > 1:
            outputs = accelerator.gather_for_metrics(outputs)
            indices = accelerator.gather_for_metrics(indices)

        all_indices.extend(indices)
        all_outputs.extend(outputs)
    return all_indices, all_outputs
@torch.no_grad()
def evaluate_nll(model, dataloader, accelerator: Optional[Accelerator] = None):
    """Collect per-row mean losses from ``dataloader``, grouped by sample id.

    Args:
        model: Callable as ``model(**batch)``. If it has a ``memory`` attribute,
            ``memory.reset()`` is called before every batch. When its output
            provides a non-None ``batch_loss``, that value is used. Otherwise the
            loss is computed from ``output.logits`` and ``batch["labels"]``.
        dataloader: Iterable of dict batches. The ``"index"`` (tensor of sample
            ids) and optional ``"length"`` entries are removed before the forward pass.
        accelerator: Optional ``Accelerator``. A plain ``DataLoader`` is passed
            through ``prepare``, and per-row results are gathered when more than
            one process is running.

    Returns:
        A ``defaultdict(list)`` mapping each sample id to the list of per-row
        mean losses observed for it. These losses are not weighted by token count.
    """
    # Exact type check on purpose: an already-prepared loader must not be re-prepared.
    if accelerator is not None and type(dataloader) == torch.utils.data.DataLoader:
        dataloader = accelerator.prepare(dataloader)

    all_loss = defaultdict(list)
    for x in tqdm(dataloader, desc="Computing NLL"):
        if hasattr(model, "memory"):
            model.memory.reset()

        index = x.pop("index")
        # "length" is not a model input; drop it before calling the model.
        x.pop("length", None)
        output = model(**x)

        # Dataclass outputs always define ``batch_loss`` (possibly as None), so
        # test the value rather than the attribute's existence.
        batch_loss = getattr(output, "batch_loss", None)
        if batch_loss is None:
            _, batch_loss, _ = compute_loss(output.logits, x["labels"], shift=True)

        if accelerator is not None and accelerator.num_processes > 1:
            index = accelerator.gather_for_metrics(index)
            batch_loss = accelerator.gather_for_metrics(batch_loss)

        for _id, _loss in zip(index.tolist(), batch_loss.tolist()):
            all_loss[_id].append(_loss)
    return all_loss
@dataclass
class BeaconModelOutput(BaseModelOutputWithPast):
    """Model output extending ``BaseModelOutputWithPast`` with loss statistics and logits.

    The loss fields share their names with what ``compute_loss`` in this module
    returns (``loss``, ``batch_loss``, ``valid_token_num``). The evaluation helpers
    here read ``batch_loss`` / ``valid_token_num`` from a model output instead of
    recomputing them from ``logits``. Every field defaults to None.

    NOTE(review): transformers ``ModelOutput`` supports tuple-style / integer access
    in field order, so reordering or inserting fields may break positional
    callers. Confirm before changing.
    """

    # Scalar loss; presumably the batch mean over valid tokens as produced by
    # compute_loss. Confirm against the model's forward.
    loss: Optional[torch.FloatTensor] = None
    # Per-sample mean loss; presumably shaped (batch,). TODO confirm.
    batch_loss: Optional[torch.FloatTensor] = None
    # Per-sample count of non-ignored label tokens; presumably shaped (batch,).
    valid_token_num: Optional[torch.LongTensor] = None
    # Prediction scores. Annotated Optional because the default is None.
    logits: Optional[torch.FloatTensor] = None
    # The following fields redeclare those of the base class, each with a None default.
    past_key_values: Optional[Tuple[Tuple[torch.FloatTensor]]] = None
    hidden_states: Optional[Tuple[torch.FloatTensor]] = None
    attentions: Optional[Tuple[torch.FloatTensor]] = None