# mpt-7b-8k-chat / collator.py
# Uploaded by irenedea: "LLM-foundry update March 26, 2024 23:50:31" (commit fdb2891, verified; 15.5 kB).
import logging
import warnings
from typing import Any, Dict, List, Optional, Union
import torch
from transformers import PreTrainedTokenizer, PreTrainedTokenizerFast
# Module-level logger, named after this module.
log = logging.getLogger(__name__)
# Label value that marks a position as "not a training target". Throughout this
# file, labels equal to this value are treated as non-loss-generating, and label
# padding is filled with it.
# NOTE(review): presumably chosen to match the default ``ignore_index`` of the
# loss used by the model -- confirm against the training code.
_HF_IGNORE_INDEX = -100
# One tokenized dataset example. In this file it is read as ``example['turns']``:
# a list of per-turn dicts, each holding ``'input_ids'`` and ``'labels'`` token ids.
TokenizedExample = Dict[str, List[Dict[str, List[int]]]]
def ensure_list(x: Union[List, torch.Tensor]) -> List:
    """Return ``x`` as a flat Python list.

    Args:
        x: A list (returned unchanged, as the very same object) or a tensor of
            any shape.

    Returns:
        ``x`` itself when it is already a list; otherwise the tensor's elements
        in flattened order as native Python scalars.

    Raises:
        AssertionError: If ``x`` is neither a list nor a tensor.
    """
    if isinstance(x, torch.Tensor):
        # ``tolist`` yields plain Python numbers. Iterating the tensor with
        # ``list(...)`` would instead give a list of 0-dim tensors, which does
        # not match the ``list[int]`` contract the callers in this file expect.
        x = x.flatten().tolist()
    assert isinstance(x, list)
    return x
def validate_target_settings(target_prompts: str, target_responses: str, decoder_only_format: bool):
    """Raise an error if the target settings are invalid.

    Args:
        target_prompts: Must be ``"all"``, ``"none"`` or ``"length>=XX"`` where
            ``XX`` is a positive integer. The value is compared as given; no
            case normalization is done here.
        target_responses: Must be ``"all"`` or ``"last"``.
        decoder_only_format: When ``False`` (encoder-decoder), only
            ``target_prompts="none"`` with ``target_responses="last"`` is
            accepted.

    Raises:
        ValueError: If any of the rules above is violated.
    """
    if not decoder_only_format and (target_prompts != 'none' or target_responses != 'last'):
        raise ValueError('When using encoder_decoder format, you must use target_prompts="none" and target_responses="last".')

    if target_responses not in {'all', 'last'}:
        raise ValueError(f'target_responses must be either "last" or "all" but target_responses={target_responses!r}')

    if target_prompts.startswith('length>='):
        cutoff = target_prompts[len('length>='):]
        # ``isdecimal`` only accepts characters ``int()`` can parse; ``isdigit``
        # also accepts e.g. superscript digits, which would bypass this message
        # and fail inside ``int()`` with a less helpful error.
        if not cutoff.isdecimal():
            raise ValueError(f'target_prompts starts with "length>=" but the rest of the string is not digits (target_prompts={target_prompts!r}). ' + 'To use this configuration option, set target_prompts "length>=XX" where "XX" is a positive integer indicating ' + 'the length cutoff. Prompts of at least XX tokens in length will be treated as targets.')
        cutoff = int(cutoff)
        # The digits-only check above rules out a minus sign, so the only value
        # that can reach this branch is zero.
        if cutoff <= 0:
            raise ValueError(f'You are trying to set the target_prompts length cutoff to a non-positive number cutoff={cutoff!r}. This is not allowed.')
    elif target_prompts not in {'all', 'none'}:
        raise ValueError(f'target_prompts must either be "all", "none" or "length>=XX" where "XX" is a positive integer, but target_prompts={target_prompts!r}')
def _sequence_to_labels_all(sequence: list[int], is_last_turn: bool, cutoff: Optional[int]=None) -> list[int]:
    """Label policy ``"all"``: every token of ``sequence`` is a training target.

    ``is_last_turn`` and ``cutoff`` exist only so that all label policies share
    one call signature; this policy ignores them. The input list is returned
    as-is (not copied).
    """
    _ = (is_last_turn, cutoff)  # intentionally unused by this policy
    return sequence
def _sequence_to_labels_none(sequence: list[int], is_last_turn: bool, cutoff: Optional[int]=None) -> list[int]:
    """Label policy ``"none"``: no token of ``sequence`` is a training target.

    Returns a new list of the same length as ``sequence`` filled with
    ``_HF_IGNORE_INDEX``. ``is_last_turn`` and ``cutoff`` are accepted only for
    signature compatibility with the other policies and are ignored.
    """
    _ = (is_last_turn, cutoff)  # intentionally unused by this policy
    return [_HF_IGNORE_INDEX for _token in sequence]
def _sequence_to_labels_last(sequence: list[int], is_last_turn: bool, cutoff: Optional[int]=None) -> list[int]:
    """Label policy ``"last"``: only the final turn's tokens are training targets.

    Returns ``sequence`` itself (not a copy) when ``is_last_turn`` is true;
    otherwise a new list of the same length filled with ``_HF_IGNORE_INDEX``.
    ``cutoff`` is accepted only for signature compatibility and is ignored.
    """
    _ = cutoff  # intentionally unused by this policy
    if not is_last_turn:
        return [_HF_IGNORE_INDEX] * len(sequence)
    return sequence
def _sequence_to_labels_cutoff(sequence: list[int], is_last_turn: bool, cutoff: Optional[int]=None) -> list[int]:
    """Label policy ``"length"``: long-enough sequences are training targets.

    Returns ``sequence`` itself (not a copy) when it holds at least ``cutoff``
    tokens; otherwise a new list of the same length filled with
    ``_HF_IGNORE_INDEX``. ``is_last_turn`` is accepted only for signature
    compatibility and is ignored.

    Raises:
        ValueError: If ``cutoff`` is ``None``.
    """
    _ = is_last_turn  # intentionally unused by this policy
    if cutoff is None:
        raise ValueError('input ``cutoff`` must be provided')
    meets_cutoff = len(sequence) >= cutoff
    return sequence if meets_cutoff else [_HF_IGNORE_INDEX] * len(sequence)
# Maps a policy name to the function that turns a token sequence into labels.
# 'all' / 'none' / 'last' are looked up directly by setting value; 'length' is
# the entry used when ``target_prompts`` starts with "length" (the cutoff itself
# is parsed separately by the caller and passed as the ``cutoff`` argument).
_TARGET_POLICY_LOOKUP = {'all': _sequence_to_labels_all, 'none': _sequence_to_labels_none, 'last': _sequence_to_labels_last, 'length': _sequence_to_labels_cutoff}
def stitch_turns_decoder_only(example_turns: list[dict[str, list[int]]], target_prompts: str, target_responses: str, eos_token_id: Optional[int]=None, validate: bool=False) -> tuple[list[int], list[int]]:
    """Concatenate the turns of one example into decoder-only inputs and labels.

    Each turn contributes its ``'input_ids'`` (treated as the prompt/context)
    followed by its ``'labels'`` (treated as the response/target). The labels
    list mirrors the inputs token-for-token, with non-target positions replaced
    according to the selected policies.

    Args:
        example_turns: Turn dicts, each with ``'input_ids'`` and ``'labels'``
            (lists or tensors of token ids).
        target_prompts: Policy for prompt tokens: ``"all"``, ``"none"`` or
            ``"length>=XX"``. Case-insensitive.
        target_responses: Policy for response tokens: ``"all"`` or ``"last"``.
            Case-insensitive.
        eos_token_id: If given, the final turn's response is made to end with
            this token (appended only when not already present).
        validate: If true, run ``validate_target_settings`` on the lower-cased
            settings before use.

    Returns:
        ``(input_ids, labels)``, two lists of equal length.

    Raises:
        ValueError: If validation fails, or if the stitched inputs and labels
            end up with different lengths.
    """
    target_prompts = target_prompts.lower()
    target_responses = target_responses.lower()
    if validate:
        validate_target_settings(target_prompts, target_responses, decoder_only_format=True)

    if target_prompts.startswith('length'):
        # Everything after ">=" is the minimum prompt length.
        prompt_cutoff = int(target_prompts.split('>=')[-1])
        prompt_to_target = _TARGET_POLICY_LOOKUP['length']
    else:
        prompt_cutoff = None
        prompt_to_target = _TARGET_POLICY_LOOKUP[target_prompts]
    response_to_target = _TARGET_POLICY_LOOKUP[target_responses]

    input_ids = []
    labels = []
    last_idx = len(example_turns) - 1
    for idx, turn in enumerate(example_turns):
        is_last_turn = idx == last_idx
        context = ensure_list(turn['input_ids'])
        target = ensure_list(turn['labels'])
        if is_last_turn and eos_token_id is not None:
            # ``not target`` guards the ``target[-1]`` lookup: an empty final
            # response previously raised IndexError here, but only when an EOS
            # id was configured. It now simply becomes ``[eos_token_id]``.
            if not target or target[-1] != eos_token_id:
                # Build a new list so the caller's data is never mutated.
                target = target + [eos_token_id]
        input_ids += context
        input_ids += target
        labels += prompt_to_target(context, is_last_turn, prompt_cutoff)
        labels += response_to_target(target, is_last_turn)

    if len(input_ids) != len(labels):
        raise ValueError(f'input_ids and labels should be the same length, len(input_ids)={len(input_ids)!r}, len(labels)={len(labels)!r}')
    return (input_ids, labels)
def stitch_turns_encoder_decoder(example_turns: list[dict[str, list[int]]], eos_token_id: Optional[int]=None) -> tuple[list[int], list[int]]:
    """Split the turns of one example into encoder context and decoder target.

    Every turn's ``'input_ids'`` goes into the context. The ``'labels'`` of all
    turns except the last are appended to the context as well; the last turn's
    ``'labels'`` become the target.

    Args:
        example_turns: Turn dicts, each with ``'input_ids'`` and ``'labels'``
            (lists or tensors of token ids).
        eos_token_id: If given, the target is made to end with this token
            (appended only when not already present).

    Returns:
        ``(context, target)`` as two lists of token ids.

    Raises:
        ValueError: If ``example_turns`` is empty, so no target was produced.
    """
    context = []
    target = None
    last_idx = len(example_turns) - 1
    for idx, turn in enumerate(example_turns):
        turn_context = ensure_list(turn['input_ids'])
        turn_target = ensure_list(turn['labels'])
        context += turn_context
        if idx != last_idx:
            # Earlier responses are history the encoder should see.
            context += turn_target
            continue
        # ``not turn_target`` guards the ``turn_target[-1]`` lookup: an empty
        # final response previously raised IndexError here, but only when an
        # EOS id was configured. It now simply becomes ``[eos_token_id]``.
        if eos_token_id is not None and (not turn_target or turn_target[-1] != eos_token_id):
            # Build a new list so the caller's data is never mutated.
            turn_target = turn_target + [eos_token_id]
        target = turn_target

    if target is None:
        raise ValueError('target is still None but should be list[int]')
    return (context, target)
class Seq2SeqFinetuningCollator:
    """A general-purpose collator for sequence-to-sequence training/evaluation.

    Args:
        tokenizer: A HuggingFace tokenizer. Must have a pad_token set.
        max_seq_len (int): The maximum sequence length of the combined
            context/target sequence (decoder-only format) or of each of the
            context sequence and target sequence (encoder-decoder format).
        decoder_only_format (bool): Whether to format the batches for a
            decoder-only model (if True) or an encoder-decoder model (if False).
        target_responses (str): For multi-turn examples, this controls which
            responses are treated as training targets (i.e. generate loss).
            Options are:
                "last": (Default) Only the final response is used as the training
                    target; non-terminal responses are only part of the context.
                "all": All of the responses are used as training targets.
        target_prompts (str): This controls which prompts are treated as
            training targets (i.e. generate loss).
            Options are:
                "none": (Default) Prompts are never used as training targets.
                "all": Prompts are always used as training targets.
                "length>=XX": Prompt sequences are used as training targets when
                    they have length of at least XX tokens. For instance,
                    setting "length>=512" instructs the collator to use a prompt
                    sequence as a training target when it is at least 512 tokens long.
        allow_pad_trimming (bool, optional): Whether to allow the collator
            to trim padding, which may result in shorter but inconsistent
            sequence lengths from batch to batch. The first batch is never
            trimmed. Default: ``False`` ensures that all sequences are max_seq_len.
        batch_metadata (dict, optional): A dictionary of metadata which will be added
            to the batch. Each value is repeated once per example in the batch.

    Raises:
        ValueError: If ``batch_metadata`` uses a key reserved for model inputs,
            if the tokenizer has no pad token, or if the target settings are
            invalid for the chosen format.
    """

    def __init__(self, tokenizer: Union[PreTrainedTokenizer, PreTrainedTokenizerFast], max_seq_len: int, decoder_only_format: bool, target_responses: str='last', target_prompts: str='none', allow_pad_trimming: bool=False, batch_metadata: Optional[Dict[str, Any]]=None):
        self.tokenizer = tokenizer
        self.max_seq_len = max_seq_len
        self.decoder_only_format = decoder_only_format
        self.target_responses = target_responses.lower()
        self.target_prompts = target_prompts.lower()
        self.batch_metadata = batch_metadata or {}

        self._allow_pad_trimming = allow_pad_trimming
        self._seen_first_batch = False

        # Metadata is merged into the batch dict, so it must not shadow any key
        # the collator itself produces.
        illegal_keys = ['input_ids', 'labels', 'attention_mask', 'decoder_input_ids', 'decoder_attention_mask']
        found_keys = [key for key in illegal_keys if key in self.batch_metadata]
        if found_keys:
            raise ValueError(f"The following keys are in batch_metadata but are not allowed: {', '.join(found_keys)}.\n" + 'You cannot use keys that are used directly by the models. The prohibited keys are:\n' + f"{', '.join(illegal_keys)}")

        if max_seq_len % 8 != 0:
            log.warning('For performance, a max_seq_len as a multiple of 8 is recommended.')

        if self.tokenizer.pad_token_id is None:
            raise ValueError(f'{self.__class__.__name__} requires that the tokenizer has the pad token set, but it is None')

        validate_target_settings(self.target_prompts, self.target_responses, self.decoder_only_format)
        if self.target_prompts.startswith('length'):
            # Everything after ">=" is the minimum prompt length.
            self.prompt_cutoff = int(self.target_prompts.split('>=')[-1])
            self.prompt_to_target = _TARGET_POLICY_LOOKUP['length']
        else:
            self.prompt_cutoff = None
            self.prompt_to_target = _TARGET_POLICY_LOOKUP[self.target_prompts]
        self.response_to_target = _TARGET_POLICY_LOOKUP[self.target_responses]

        # Each truncation warning is emitted at most once per collator.
        self._warned_truncated = False
        self._warned_context = False
        self._warned_target = False

    def __call__(self, examples: List[TokenizedExample]) -> Dict[str, torch.Tensor]:
        """Collate ``examples`` into one padded batch of tensors.

        Raises:
            KeyError: If the first turn of the first example lacks
                ``'input_ids'`` or ``'labels'``.
        """
        first_turn = examples[0]['turns'][0]
        for check_key in ['input_ids', 'labels']:
            if check_key not in first_turn:
                raise KeyError(f'Examples returned by dataset do not include required key: {check_key}')

        if self.decoder_only_format:
            batch = self._process_and_batch_decoder_only(examples)
        else:
            batch = self._process_and_batch_encoder_decoder(examples)

        # Repeat every metadata value once per row of the batch.
        batch_size = batch['input_ids'].shape[0]
        batch.update({k: torch.tensor([v] * batch_size) for k, v in self.batch_metadata.items()})
        return batch

    @staticmethod
    def _keep_length(mask: torch.Tensor, multiple_of: int=8) -> int:
        """Return the longest per-row count of set entries in ``mask``, rounded up to ``multiple_of``."""
        n_non_padding = mask.sum(dim=1).max()
        return int(multiple_of * torch.ceil(n_non_padding / multiple_of))

    def _truncate_to_max_len(self, tokens: List[int]) -> List[int]:
        """Cut ``tokens`` to ``max_seq_len``, ending with EOS when the tokenizer defines one."""
        eos_token_id = self.tokenizer.eos_token_id
        if eos_token_id is None:
            # Without an EOS id there is nothing valid to append; appending
            # ``None`` would only fail later when the batch is tensorized.
            return tokens[:self.max_seq_len]
        return tokens[:self.max_seq_len - 1] + [eos_token_id]

    def _process_and_batch_decoder_only(self, examples: List[TokenizedExample]) -> Dict[str, torch.Tensor]:
        """Build a decoder-only batch: ``input_ids``, ``labels``, ``attention_mask``, ``sequence_id``.

        Raises:
            ValueError: If truncating an example to ``max_seq_len`` leaves it
                with no loss-generating label.
        """
        pad_left = self.tokenizer.padding_side == 'left'
        processed_examples = []
        for example in examples:
            input_ids, labels = stitch_turns_decoder_only(example_turns=example['turns'], target_prompts=self.target_prompts, target_responses=self.target_responses, eos_token_id=self.tokenizer.eos_token_id)

            orig_size = len(input_ids)
            if orig_size > self.max_seq_len:
                input_ids = input_ids[:self.max_seq_len]
                labels = labels[:self.max_seq_len]
                if not any(tok != _HF_IGNORE_INDEX for tok in labels):
                    raise ValueError(f'Truncating to max_seq_len={self.max_seq_len} has removed all loss-generating tokens. ' + f'Pre-truncation sequence length was {orig_size}. ' + 'This sample should have been filtered out before reaching the collator. If using ' + 'pre-tokenized streaming data, this may have resulted from using different ' + '``target_prompts``, ``target_responses``, or ``max_seq_len`` ' + 'settings when preparing the streaming dataset than what are currently being used.')
                if not self._warned_truncated:
                    warnings.warn(f'Truncating sequence of length={orig_size} to fit max_seq_len={self.max_seq_len}. ' + 'If truncation is a problem, consider increasing max_seq_len.')
                    self._warned_truncated = True

            attention_mask = [1] * len(input_ids)

            # Labels are padded here by hand, on the same side the tokenizer
            # pads, so they stay aligned with ``input_ids``.
            # NOTE(review): presumably ``tokenizer.pad`` leaves ``labels``
            # untouched -- confirm against the tokenizer documentation.
            i_pad = [_HF_IGNORE_INDEX] * (self.max_seq_len - len(input_ids))
            labels = i_pad + labels if pad_left else labels + i_pad

            processed_examples.append({'input_ids': input_ids, 'labels': labels, 'attention_mask': attention_mask})

        batch = self.tokenizer.pad(processed_examples, padding='max_length', max_length=self.max_seq_len, return_tensors='pt')

        # 0 where the attention mask is 1, -1 where it is 0 (padding).
        batch['sequence_id'] = batch['attention_mask'] - 1

        # Trimming is skipped when disabled, and always for the first batch.
        if not (self._allow_pad_trimming and self._seen_first_batch):
            self._seen_first_batch = True
            return batch

        keep_tokens = self._keep_length(batch['attention_mask'])
        for k, v in batch.items():
            # Only 2-D (batch, sequence) tensors carry padding to trim.
            if len(v.shape) < 2:
                continue
            if pad_left:
                batch[k] = v[:, -keep_tokens:].contiguous()
            else:
                batch[k] = v[:, :keep_tokens].contiguous()
        return batch

    def _process_and_batch_encoder_decoder(self, examples: List[TokenizedExample]) -> Dict[str, torch.Tensor]:
        """Build an encoder-decoder batch.

        The batch holds ``input_ids`` / ``attention_mask`` for the encoder and
        ``labels`` / ``decoder_input_ids`` / ``decoder_attention_mask`` for the
        decoder. Decoder-side tensors are always padded on the right.
        """
        processed_examples = []
        for example in examples:
            context, target = stitch_turns_encoder_decoder(example_turns=example['turns'], eos_token_id=self.tokenizer.eos_token_id)

            n_target = len(target)
            if n_target > self.max_seq_len:
                if not self._warned_target:
                    warnings.warn(f'Truncating TARGET sequence of length={n_target} ' + f'to max_seq_len={self.max_seq_len}. If truncation is ' + 'a problem, consider increasing max_seq_len.')
                    self._warned_target = True
                target = self._truncate_to_max_len(target)
            else:
                # A target of exactly ``max_seq_len`` tokens fits as-is: it gets
                # zero padding and, unlike before, no truncation warning.
                target = target + [_HF_IGNORE_INDEX] * (self.max_seq_len - n_target)

            if len(context) > self.max_seq_len:
                if not self._warned_context:
                    warnings.warn(f'Truncating CONTEXT sequence of length={len(context)} ' + f'to max_seq_len={self.max_seq_len}. If truncation is ' + 'a problem, consider increasing max_seq_len.')
                    self._warned_context = True
                context = self._truncate_to_max_len(context)

            processed_examples.append({'input_ids': context, 'labels': target, 'attention_mask': [1] * len(context)})

        batch = self.tokenizer.pad(processed_examples, padding='max_length', max_length=self.max_seq_len, return_tensors='pt')

        # Decoder inputs are the labels shifted right by one position, starting
        # with the pad token; ignore-index entries are replaced by the pad token.
        batch['decoder_input_ids'] = torch.cat([torch.full((len(processed_examples), 1), self.tokenizer.pad_token_id), batch['labels'][:, :-1]], dim=1)
        batch['decoder_input_ids'].masked_fill_(batch['decoder_input_ids'] == _HF_IGNORE_INDEX, self.tokenizer.pad_token_id)
        batch['decoder_attention_mask'] = torch.not_equal(batch['labels'], _HF_IGNORE_INDEX)

        # Trimming is skipped when disabled, and always for the first batch.
        if not (self._allow_pad_trimming and self._seen_first_batch):
            self._seen_first_batch = True
            return batch

        # Encoder tensors were padded by the tokenizer, so the real tokens sit
        # on the side opposite to ``padding_side``; keep that side.
        encoder_keep = self._keep_length(batch['attention_mask'])
        pad_left = self.tokenizer.padding_side == 'left'
        for k in ['input_ids', 'attention_mask']:
            if pad_left:
                batch[k] = batch[k][:, -encoder_keep:].contiguous()
            else:
                batch[k] = batch[k][:, :encoder_keep].contiguous()

        # Decoder tensors derive from ``labels``, which this method pads on the
        # right, so they are always trimmed from the right.
        decoder_keep = self._keep_length(batch['decoder_attention_mask'])
        for k in ['decoder_input_ids', 'decoder_attention_mask', 'labels']:
            batch[k] = batch[k][:, :decoder_keep].contiguous()
        return batch