# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

from typing import List, Tuple

import torch
from torch.nn.utils.rnn import pad_sequence
from torch.utils.data import Dataset

from datasets import load_dataset

# Not ideal to import this type here but it's needed for the transform function
from torchtune.modules import Tokenizer


CROSS_ENTROPY_IGNORE_IDX = -100


# Token "colors": per-token segment ids marking which part of the prompt or
# response each token belongs to.
DEFAULT = 0
INSTRUCTION = 1
INPUT = 2
RESPONSE = 3

class ColoringAlpacaDataset(Dataset):
    """
    See torchtune.datasets.alpaca.AlpacaDataset for the original implementation.

    Constructor now takes in a dataset path directly.
    
    This implementation returns 3 lists representing the tokens, labels, and token colors
    (as opposed to just the tokens & labels from the original).
    """

    def __init__(
        self,
        tokenizer: Tokenizer,
        dataset_path: str = "yahma/alpaca-cleaned",
        train_on_input: bool = True,
        **kwargs
    ) -> None:
        self._data = load_dataset(dataset_path, split="train")
        self._tokenizer = tokenizer
        self.train_on_input = train_on_input
        self.num_colors = 4  # matches DEFAULT, INSTRUCTION, INPUT, RESPONSE above

    def __len__(self):
        return len(self._data)

    def __getitem__(self, index: int) -> Tuple[List[int], List[int], List[int]]:
        sample = self._data[index]

        return self._transform(
            instruction=sample["instruction"],
            input=sample["input"],
            output=sample["output"],
        )

    def _transform(
        self, instruction: str, input: str, output: str
    ) -> Tuple[List[int], List[int], List[int]]:
        """
        Split a sample on ``response`` tag to create input and labels.

        Args:
            instruction (str): Instruction text.
            input (str): Input text. Can be an empty string. Determines the prompt generation template
                used.
            output (str): Response text.

        Returns:
            Tuple of encoded inputs, labels, token colors.
        """
        prompt = self._generate_prompt(instruction, input)

        # First handle the prompt
        colors = []
        tokenized = []
        labels = []
        is_first = True
        for token_type, text in prompt:
            tokenized_part = self._tokenizer.encode(
                text=text, add_bos=is_first, add_eos=False
            )
            is_first = False

            tokenized += tokenized_part
            colors += [token_type] * len(tokenized_part)
            if not self.train_on_input:
                labels += [CROSS_ENTROPY_IGNORE_IDX] * len(tokenized_part)
            else:
                labels += tokenized_part

        # Now add the response tokens
        tokenized_part = self._tokenizer.encode(
            text=output, add_bos=False, add_eos=True
        )
        tokenized += tokenized_part
        colors += [RESPONSE] * len(tokenized_part)
        labels += tokenized_part

        assert len(tokenized) == len(labels)
        assert len(tokenized) == len(colors)

        return tokenized, labels, colors

    def _generate_prompt(self, instruction: str, input: str) -> List[Tuple[int, str]]:
        """
        Generate prompt from instruction and input.

        Args:
            instruction (str): Instruction text.
            input (str): Input text.

        Returns:
            List of (token color, templated text) pairs.
        """
        if input:
            return [
                (DEFAULT, (
                    "Below is an instruction that describes a task, paired with an input that provides further context. "
                    "Write a response that appropriately completes the request.\n\n"
                    "### Instruction:\n"
                )),
                (INSTRUCTION, instruction),
                (DEFAULT, "\n\n### Input:\n"),
                (INPUT, input),
                (DEFAULT, "\n\n### Response:\n"),
            ]
        else:
            return [
                (DEFAULT, (
                    "Below is an instruction that describes a task. "
                    "Write a response that appropriately completes the request.\n\n"
                    "### Instruction:\n"
                )),
                (INSTRUCTION, instruction),
                (DEFAULT, "\n\n### Response:\n"),
            ]
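

# Illustrative note: for a sample with a non-empty input, _transform produces
# three parallel lists, roughly
#
#   tokens:  [bos] + prompt header + instruction + input header + input
#            + response header + response + [eos]
#   labels:  identical to tokens when train_on_input=True; otherwise every
#            prompt position is CROSS_ENTROPY_IGNORE_IDX and only the response
#            tokens are kept as targets
#   colors:  DEFAULT for template text, INSTRUCTION / INPUT / RESPONSE for the
#            corresponding segments
#
# All three lists always have the same length (enforced by the asserts above).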


# Despite the name, TokenPair is here a tuple of three lists: tokenized text
# inputs, labels, and token colors.
TokenPair = Tuple[List[int], List[int], List[int]]


def padded_collate(
    batch: List[TokenPair],
    padding_idx: int = 0,
    ignore_idx: int = -100,
) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
    """
    Pad a batch of (tokens, labels, colors) samples to the longest sequence in
    the batch and stack each into a tensor.

    Tokens and colors are padded with ``padding_idx`` (the default of 0
    coincides with the DEFAULT color); labels are padded with ``ignore_idx``
    so padded positions are ignored by the loss.
    """
    input_ids = pad_sequence(
        [torch.tensor(x[0]) for x in batch],
        batch_first=True,
        padding_value=padding_idx,
    )
    labels = pad_sequence(
        [torch.tensor(x[1]) for x in batch],
        batch_first=True,
        padding_value=ignore_idx,
    )
    colors = pad_sequence(
        [torch.tensor(x[2]) for x in batch],
        batch_first=True,
        padding_value=padding_idx,
    )

    input_ids_seq_len = input_ids.shape[-1]
    labels_seq_len = labels.shape[-1]
    colors_seq_len = colors.shape[-1]

    assert input_ids_seq_len == labels_seq_len
    assert input_ids_seq_len == colors_seq_len

    return input_ids, labels, colors
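

if __name__ == "__main__":
    # Minimal usage sketch, not part of the module's API. It assumes you can
    # construct a torchtune Tokenizer (construction details depend on the
    # torchtune version, so it is left as a placeholder) and that the default
    # "yahma/alpaca-cleaned" dataset is reachable. The batch size is
    # illustrative.
    from torch.utils.data import DataLoader

    tokenizer = ...  # hypothetical placeholder: build your Tokenizer here

    dataset = ColoringAlpacaDataset(tokenizer=tokenizer, train_on_input=True)
    loader = DataLoader(
        dataset,
        batch_size=8,
        shuffle=True,
        collate_fn=padded_collate,
    )

    for input_ids, labels, colors in loader:
        # Each tensor is (batch_size, max_seq_len_in_batch); tokens and colors
        # are padded with 0, labels with -100 so the loss ignores padding.
        print(input_ids.shape, labels.shape, colors.shape)
        break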