import logging
import warnings
from typing import Any, Dict, List, Optional, Union

import torch
from transformers import PreTrainedTokenizer, PreTrainedTokenizerFast

log = logging.getLogger(__name__)

_HF_IGNORE_INDEX = -100

TokenizedExample = Dict[str, List[Dict[str, List[int]]]]

def ensure_list(x: Union[List, torch.Tensor]) -> List:
    """Converts a tensor to a flat Python list; passes lists through unchanged."""
    if isinstance(x, torch.Tensor):
        x = list(x.flatten())
    assert isinstance(x, list)
    return x

def validate_target_settings(target_prompts: str, target_responses: str, decoder_only_format: bool):
    """Raises an error if target settings are invalid."""
    if not decoder_only_format and (target_prompts != 'none' or target_responses != 'last'):
        raise ValueError('When using the encoder-decoder format, you must use target_prompts="none" and target_responses="last".')
    if target_responses not in {'all', 'last'}:
        raise ValueError(f'target_responses must be either "last" or "all" but target_responses={target_responses!r}')
    if target_prompts.startswith('length>='):
        cutoff = target_prompts[8:]
        if not cutoff.isdigit():
            raise ValueError(f'target_prompts starts with "length>=" but the rest of the string is not digits (target_prompts={target_prompts!r}). ' + 'To use this configuration option, set target_prompts "length>=XX" where "XX" is a positive integer indicating ' + 'the length cutoff. Prompts of at least XX tokens in length will be treated as targets.')
        cutoff = int(cutoff)
        if cutoff <= 0:
            raise ValueError(f'The target_prompts length cutoff must be a positive integer, but got cutoff={cutoff!r}. This is not allowed.')
    elif target_prompts not in {'all', 'none'}:
        raise ValueError(f'target_prompts must either be "all", "none" or "length>=XX" where "XX" is a positive integer, but target_prompts={target_prompts!r}')

def _sequence_to_labels_all(sequence: list[int], is_last_turn: bool, cutoff: Optional[int]=None) -> list[int]:
    del is_last_turn, cutoff
    return sequence

def _sequence_to_labels_none(sequence: list[int], is_last_turn: bool, cutoff: Optional[int]=None) -> list[int]:
    del is_last_turn, cutoff
    return [_HF_IGNORE_INDEX] * len(sequence)

def _sequence_to_labels_last(sequence: list[int], is_last_turn: bool, cutoff: Optional[int]=None) -> list[int]:
    del cutoff
    if is_last_turn:
        return sequence
    else:
        return [_HF_IGNORE_INDEX] * len(sequence)

def _sequence_to_labels_cutoff(sequence: list[int], is_last_turn: bool, cutoff: Optional[int]=None) -> list[int]:
    del is_last_turn
    if cutoff is None:
        raise ValueError('input ``cutoff`` must be provided')
    if len(sequence) >= cutoff:
        return sequence
    else:
        return [_HF_IGNORE_INDEX] * len(sequence)
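
# Maps each target policy name to the function that converts a raw token
# sequence into its training labels under that policy.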
_TARGET_POLICY_LOOKUP = {
    'all': _sequence_to_labels_all,
    'none': _sequence_to_labels_none,
    'last': _sequence_to_labels_last,
    'length': _sequence_to_labels_cutoff,
}

def stitch_turns_decoder_only(
    example_turns: list[dict[str, list[int]]],
    target_prompts: str,
    target_responses: str,
    eos_token_id: Optional[int] = None,
    validate: bool = False,
) -> tuple[list[int], list[int]]:
    """Stitches the turns of one example into (input_ids, labels) for decoder-only training."""
    target_prompts = target_prompts.lower()
    target_responses = target_responses.lower()
    if validate:
        validate_target_settings(target_prompts, target_responses, decoder_only_format=True)
    if target_prompts.startswith('length'):
        prompt_cutoff = int(target_prompts.split('>=')[-1])
        prompt_to_target = _TARGET_POLICY_LOOKUP['length']
    else:
        prompt_cutoff = None
        prompt_to_target = _TARGET_POLICY_LOOKUP[target_prompts]
    response_to_target = _TARGET_POLICY_LOOKUP[target_responses]
    input_ids = []
    labels = []
    for idx, turn in enumerate(example_turns):
        is_last_turn = idx + 1 == len(example_turns)
        context = ensure_list(turn['input_ids'])
        target = ensure_list(turn['labels'])
        if is_last_turn and eos_token_id is not None:
            if target[-1] != eos_token_id:
                target = target + [eos_token_id]
        input_ids += context
        input_ids += target
        labels += prompt_to_target(context, is_last_turn, prompt_cutoff)
        labels += response_to_target(target, is_last_turn)
    if len(input_ids) != len(labels):
        raise ValueError(f'input_ids and labels should be the same length, len(input_ids)={len(input_ids)!r}, len(labels)={len(labels)!r}')
    return (input_ids, labels)
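# Illustrative example of stitch_turns_decoder_only (placeholder token ids; assumes eos_token_id=9):
#   turns = [{'input_ids': [1, 2], 'labels': [3]}, {'input_ids': [4], 'labels': [5]}]
#   with target_prompts='none', target_responses='last' yields
#     input_ids = [1, 2, 3, 4, 5, 9]
#     labels    = [-100, -100, -100, -100, 5, 9]
#   i.e. only the final response (plus the appended EOS) generates loss.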

def stitch_turns_encoder_decoder(
    example_turns: list[dict[str, list[int]]],
    eos_token_id: Optional[int] = None,
) -> tuple[list[int], list[int]]:
    """Stitches the turns of one example into (context, target) for encoder-decoder training."""
    context = []
    target = None
    for idx, turn in enumerate(example_turns):
        is_last_turn = idx + 1 == len(example_turns)
        turn_context = ensure_list(turn['input_ids'])
        turn_target = ensure_list(turn['labels'])
        context += turn_context
        if is_last_turn:
            if eos_token_id is not None and turn_target[-1] != eos_token_id:
                turn_target = turn_target + [eos_token_id]
            target = turn_target
        else:
            context += turn_target
    if target is None:
        raise ValueError('target is still None but should be list[int]')
    return (context, target)
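# Illustrative example of stitch_turns_encoder_decoder (placeholder token ids; assumes eos_token_id=9):
#   turns = [{'input_ids': [1, 2], 'labels': [3]}, {'input_ids': [4], 'labels': [5]}]
#   yields context = [1, 2, 3, 4] and target = [5, 9]; every turn except the last is
#   folded into the encoder context, and only the final response is decoded as the target.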

class Seq2SeqFinetuningCollator:
    """A general-purpose collator for sequence-to-sequence training/evaluation.

    Args:
        tokenizer: A HuggingFace tokenizer. Must have a pad_token set.
        max_seq_len (int): The maximum sequence length of the combined
            context/target sequence (decoder-only format), or of each of the
            context sequence and target sequence (encoder-decoder format).
        decoder_only_format (bool): Whether to format the batches for a
            decoder-only model (if True) or an encoder-decoder model (if False).
        target_responses (str): For multi-turn examples, this controls which
            responses are treated as training targets (i.e. generate loss).
            Options are:
                "last": (Default) Only the final response is used as the training
                    target; non-terminal responses are only part of the context.
                "all": All of the responses are used as training targets.
        target_prompts (str): This controls which prompts are treated as
            training targets (i.e. generate loss).
            Options are:
                "none": (Default) Prompts are never used as training targets.
                "all": Prompts are always used as training targets.
                "length>=XX": Prompt sequences are used as training targets when
                    they have length of at least XX tokens. For instance,
                    setting "length>=512" instructs the collator to use a prompt
                    sequence as a training target when it is at least 512 tokens long.
        allow_pad_trimming (bool, optional): Whether to allow the collator
            to trim padding, which may result in smaller but inconsistent batch
            sizes. Default: ``False`` ensures that all sequences are max_seq_len.
        batch_metadata (dict, optional): A dictionary of metadata which will be added
            to the batch.
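
    Example:
        A minimal usage sketch (the ``gpt2`` tokenizer choice and the token ids
        below are illustrative assumptions, not taken from this module)::

            from transformers import AutoTokenizer

            tokenizer = AutoTokenizer.from_pretrained('gpt2')
            tokenizer.pad_token = tokenizer.eos_token  # gpt2 ships without a pad token

            collator = Seq2SeqFinetuningCollator(
                tokenizer=tokenizer,
                max_seq_len=16,
                decoder_only_format=True,
            )
            # One single-turn example; the ids are placeholders, not a real tokenization.
            examples = [{'turns': [{'input_ids': [1, 2, 3], 'labels': [4, 5]}]}]
            batch = collator(examples)
            # 'input_ids', 'labels', 'attention_mask', and 'sequence_id' each have
            # shape (1, 16) because the batch is padded to max_seq_len.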
    """

    def __init__(
        self,
        tokenizer: Union[PreTrainedTokenizer, PreTrainedTokenizerFast],
        max_seq_len: int,
        decoder_only_format: bool,
        target_responses: str = 'last',
        target_prompts: str = 'none',
        allow_pad_trimming: bool = False,
        batch_metadata: Optional[Dict[str, Any]] = None,
    ):
        self.tokenizer = tokenizer
        self.max_seq_len = max_seq_len
        self.decoder_only_format = decoder_only_format
        self.target_responses = target_responses.lower()
        self.target_prompts = target_prompts.lower()
        self.batch_metadata = batch_metadata or {}
        self._allow_pad_trimming = allow_pad_trimming
        self._seen_first_batch = False
        illegal_keys = ['input_ids', 'labels', 'attention_mask', 'decoder_input_ids', 'decoder_attention_mask']
        found_keys = []
        for illegal_key in illegal_keys:
            if illegal_key in self.batch_metadata:
                found_keys.append(illegal_key)
        if found_keys:
            raise ValueError(f"The following keys are in batch_metadata but are not allowed: {', '.join(found_keys)}.\n" + f'You cannot use keys that are used directly by the models. The prohibited keys are:\n' + f"{', '.join(illegal_keys)}")
        if max_seq_len % 8 != 0:
            log.warning('For performance, it is recommended that max_seq_len be a multiple of 8.')
        if self.tokenizer.pad_token_id is None:
            raise ValueError(f'{self.__class__.__name__} requires that the tokenizer has the pad token set, but it is None')
        validate_target_settings(self.target_prompts, self.target_responses, self.decoder_only_format)
        if self.target_prompts.startswith('length'):
            self.prompt_cutoff = int(self.target_prompts.split('>=')[-1])
            self.prompt_to_target = _TARGET_POLICY_LOOKUP['length']
        else:
            self.prompt_cutoff = None
            self.prompt_to_target = _TARGET_POLICY_LOOKUP[self.target_prompts]
        self.response_to_target = _TARGET_POLICY_LOOKUP[self.target_responses]
        self._warned_truncated = False
        self._warned_context = False
        self._warned_target = False

    def __call__(self, examples: List[TokenizedExample]) -> Dict[str, torch.Tensor]:
        for check_key in ['input_ids', 'labels']:
            if check_key not in examples[0]['turns'][0]:
                raise KeyError(f'Examples returned by dataset do not include required key: {check_key}')
        if self.decoder_only_format:
            batch = self._process_and_batch_decoder_only(examples)
        else:
            batch = self._process_and_batch_encoder_decoder(examples)
        batch_size = batch['input_ids'].shape[0]
        batch.update({k: torch.tensor([v] * batch_size) for k, v in self.batch_metadata.items()})
        return batch

    def _process_and_batch_decoder_only(self, examples: List[TokenizedExample]) -> Dict[str, torch.Tensor]:
        processed_examples = []
        for example in examples:
            input_ids, labels = stitch_turns_decoder_only(example_turns=example['turns'], target_prompts=self.target_prompts, target_responses=self.target_responses, eos_token_id=self.tokenizer.eos_token_id)
            orig_size = len(input_ids)
            if orig_size > self.max_seq_len:
                input_ids = input_ids[:self.max_seq_len]
                labels = labels[:self.max_seq_len]
                if all(l == _HF_IGNORE_INDEX for l in labels):
                    raise ValueError(f'Truncating to max_seq_len={self.max_seq_len} has removed all loss-generating tokens. ' + f'Pre-truncation sequence length was {orig_size}. ' + 'This sample should have been filtered out before reaching the collator. If using ' + 'pre-tokenized streaming data, this may have resulted from using different ' + '``target_prompts``, ``target_responses``, or ``max_seq_len`` ' + 'settings when preparing the streaming dataset than what are currently being used.')
                if not self._warned_truncated:
                    warnings.warn(f'Truncating sequence of length={orig_size} to fit max_seq_len={self.max_seq_len}. ' + f'If truncation is a problem, consider increasing max_seq_len.')
                    self._warned_truncated = True
            attention_mask = [1] * len(input_ids)
            n_total = len(input_ids)
            i_pad = [_HF_IGNORE_INDEX] * (self.max_seq_len - n_total)
            if self.tokenizer.padding_side == 'left':
                labels = i_pad + labels
            else:
                labels = labels + i_pad
            processed_example = {'input_ids': input_ids, 'labels': labels, 'attention_mask': attention_mask}
            processed_examples.append(processed_example)
        batch = self.tokenizer.pad(processed_examples, padding='max_length', max_length=self.max_seq_len, return_tensors='pt')
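        # With no sequence packing in this collator, sequence_id (attention_mask - 1) is
        # simply 0 for real tokens and -1 at padding positions.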
        batch['sequence_id'] = batch['attention_mask'] - 1
        if not (self._allow_pad_trimming and self._seen_first_batch):
            self._seen_first_batch = True
            return batch
        self._seen_first_batch = True
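        # Pad trimming: cut the batch down to the longest non-padded length in this
        # batch, rounded up to a multiple of 8 to keep tensor shapes hardware-friendly.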
        multiple_of = 8
        n_non_padding = batch['attention_mask'].sum(dim=1).max()
        keep_tokens = int(multiple_of * torch.ceil(n_non_padding / multiple_of))
        for k, v in batch.items():
            if len(v.shape) < 2:
                continue
            if self.tokenizer.padding_side == 'left':
                batch[k] = v[:, -keep_tokens:].contiguous()
            else:
                batch[k] = v[:, :keep_tokens].contiguous()
        return batch

    def _process_and_batch_encoder_decoder(self, examples: List[TokenizedExample]) -> Dict[str, torch.Tensor]:
        processed_examples = []
        for example in examples:
            context, target = stitch_turns_encoder_decoder(example_turns=example['turns'], eos_token_id=self.tokenizer.eos_token_id)
            if len(target) < self.max_seq_len:
                i_pad = [_HF_IGNORE_INDEX] * (self.max_seq_len - len(target))
                target = target + i_pad
            else:
                if not self._warned_target:
                    warnings.warn(f'Truncating TARGET sequence of length={len(target)} ' + f'to max_seq_len={self.max_seq_len}. If truncation is ' + f'a problem, consider increasing max_seq_len.')
                    self._warned_target = True
                target = target[:self.max_seq_len - 1] + [self.tokenizer.eos_token_id]
            if len(context) > self.max_seq_len:
                if not self._warned_context:
                    warnings.warn(f'Truncating CONTEXT sequence of length={len(context)} ' + f'to max_seq_len={self.max_seq_len}. If truncation is ' + f'a problem, consider increasing max_seq_len.')
                    self._warned_context = True
                context = context[:self.max_seq_len - 1] + [self.tokenizer.eos_token_id]
            processed_example = {'input_ids': context, 'labels': target, 'attention_mask': [1] * len(context)}
            processed_examples.append(processed_example)
        batch = self.tokenizer.pad(processed_examples, padding='max_length', max_length=self.max_seq_len, return_tensors='pt')
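        # Build decoder inputs by right-shifting the labels behind a leading pad token,
        # then replace any remaining ignore-index positions with the pad token id.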
        batch['decoder_input_ids'] = torch.cat([torch.full((len(processed_examples), 1), self.tokenizer.pad_token_id), batch['labels'][:, :-1]], dim=1)
        batch['decoder_input_ids'].masked_fill_(batch['decoder_input_ids'] == _HF_IGNORE_INDEX, self.tokenizer.pad_token_id)
        batch['decoder_attention_mask'] = torch.not_equal(batch['labels'], _HF_IGNORE_INDEX)
        if not (self._allow_pad_trimming and self._seen_first_batch):
            self._seen_first_batch = True
            return batch
        self._seen_first_batch = True
        multiple_of = 8
        n_non_padding = batch['attention_mask'].sum(dim=1).max()
        keep_tokens = int(multiple_of * torch.ceil(n_non_padding / multiple_of))
        for k in ['input_ids', 'attention_mask']:
            batch[k] = batch[k][:, :keep_tokens].contiguous()
        n_non_padding = batch['decoder_attention_mask'].sum(dim=1).max()
        keep_tokens = int(multiple_of * torch.ceil(n_non_padding / multiple_of))
        for k in ['decoder_input_ids', 'decoder_attention_mask', 'labels']:
            batch[k] = batch[k][:, :keep_tokens].contiguous()
        return batch