# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

from typing import List, Tuple

import torch
from torch.nn.utils.rnn import pad_sequence
from torch.utils.data import Dataset

from datasets import load_dataset

# Not ideal to import this type here, but it's needed for the transform function.
from torchtune.modules import Tokenizer
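
# -100 is the default ignore_index of torch.nn.functional.cross_entropy, so label
# positions set to it are excluded from the loss.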
CROSS_ENTROPY_IGNORE_IDX = -100
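
# Token "colors": one id per region of the prompt template, recording which part
# of the sample each token came from.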
DEFAULT = 0
INSTRUCTION = 1
INPUT = 2
RESPONSE = 3


class ColoringAlpacaDataset(Dataset):
    """
    Variant of ``torchtune.datasets.alpaca.AlpacaDataset`` (see that class for the
    original implementation) whose constructor takes a dataset path directly.

    Each item is three parallel lists: the tokens, the labels, and the per-token
    colors (the original returns only the tokens and labels).
    """

    def __init__(
        self,
        tokenizer: Tokenizer,
        dataset_path: str = "yahma/alpaca-cleaned",
        train_on_input: bool = True,
        **kwargs
    ) -> None:
        self._data = load_dataset(dataset_path, split="train")
        self._tokenizer = tokenizer
        self.train_on_input = train_on_input
        self.num_colors = 4  # matches the color ids above: DEFAULT, INSTRUCTION, INPUT, RESPONSE

    def __len__(self):
        return len(self._data)

    def __getitem__(self, index: int) -> Tuple[List[int], List[int], List[int]]:
        sample = self._data[index]
        return self._transform(
            instruction=sample["instruction"],
            input=sample["input"],
            output=sample["output"],
        )
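
    # Each sample in the source dataset is expected to provide "instruction",
    # "input", and "output" string fields, as in yahma/alpaca-cleaned.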

    def _transform(
        self, instruction: str, input: str, output: str
    ) -> Tuple[List[int], List[int], List[int]]:
        """
        Tokenize a sample into parallel token, label, and color lists, masking the
        prompt portion of the labels when ``train_on_input`` is False.

        Args:
            instruction (str): Instruction text.
            input (str): Input text. Can be an empty string; determines which
                prompt template is used.
            output (str): Response text.

        Returns:
            Tuple of encoded inputs, labels, and token colors.
        """
        prompt = self._generate_prompt(instruction, input)

        # First handle the prompt.
        colors = []
        tokenized = []
        labels = []
        is_first = True
        for token_type, text in prompt:
            tokenized_part = self._tokenizer.encode(
                text=text, add_bos=is_first, add_eos=False
            )
            is_first = False

            tokenized += tokenized_part
            colors += [token_type] * len(tokenized_part)
            if not self.train_on_input:
                labels += [CROSS_ENTROPY_IGNORE_IDX] * len(tokenized_part)
            else:
                labels += tokenized_part

        # Now add the response tokens.
        tokenized_part = self._tokenizer.encode(
            text=output, add_bos=False, add_eos=True
        )
        tokenized += tokenized_part
        colors += [RESPONSE] * len(tokenized_part)
        labels += tokenized_part

        assert len(tokenized) == len(labels)
        assert len(tokenized) == len(colors)

        return tokenized, labels, colors
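
    # Note: with train_on_input=False, every prompt position in ``labels`` is
    # CROSS_ENTROPY_IGNORE_IDX (-100), so the loss is computed only over response
    # tokens; ``colors`` still records which template region produced each token.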

    def _generate_prompt(self, instruction: str, input: str) -> List[Tuple[int, str]]:
        """
        Generate the prompt from the instruction and input.

        Args:
            instruction (str): Instruction text.
            input (str): Input text.

        Returns:
            List of (color, templated text) pairs making up the prompt.
        """
        if input:
            return [
                (DEFAULT, (
                    "Below is an instruction that describes a task, paired with an input that provides further context. "
                    "Write a response that appropriately completes the request.\n\n"
                    "### Instruction:\n"
                )),
                (INSTRUCTION, instruction),
                (DEFAULT, "\n\n### Input:\n"),
                (INPUT, input),
                (DEFAULT, "\n\n### Response:\n"),
            ]
        else:
            return [
                (DEFAULT, (
                    "Below is an instruction that describes a task. "
                    "Write a response that appropriately completes the request.\n\n"
                    "### Instruction:\n"
                )),
                (INSTRUCTION, instruction),
                (DEFAULT, "\n\n### Response:\n"),
            ]
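

# A minimal usage sketch, assuming torchtune's ``Tokenizer.from_file`` constructor
# and a hypothetical path to a SentencePiece model file:
#
#     tokenizer = Tokenizer.from_file("/path/to/tokenizer.model")
#     dataset = ColoringAlpacaDataset(tokenizer=tokenizer)
#     tokens, labels, colors = dataset[0]
#     assert len(tokens) == len(labels) == len(colors)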


# TokenPair (name kept from the original two-list version) is a tuple of three
# lists: tokenized text inputs, labels, and token colors.
TokenPair = Tuple[List[int], List[int], List[int]]


def padded_collate(
    batch: List[TokenPair],
    padding_idx: int = 0,
    ignore_idx: int = -100,
) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
    """
    Pad a batch of (tokens, labels, colors) triples to the longest sequence in
    the batch. Tokens and colors are padded with ``padding_idx``; labels are
    padded with ``ignore_idx`` so padding is excluded from the loss.
    """
    input_ids = pad_sequence(
        [torch.tensor(x[0]) for x in batch],
        batch_first=True,
        padding_value=padding_idx,
    )
    labels = pad_sequence(
        [torch.tensor(x[1]) for x in batch],
        batch_first=True,
        padding_value=ignore_idx,
    )
    colors = pad_sequence(
        [torch.tensor(x[2]) for x in batch],
        batch_first=True,
        padding_value=padding_idx,
    )

    input_ids_seq_len = input_ids.shape[-1]
    labels_seq_len = labels.shape[-1]
    colors_seq_len = colors.shape[-1]
    assert input_ids_seq_len == labels_seq_len
    assert input_ids_seq_len == colors_seq_len

    return input_ids, labels, colors
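

if __name__ == "__main__":
    # Smoke test for padded_collate on hand-built (tokens, labels, colors)
    # triples of unequal length; the token ids below are made up for illustration.
    batch = [
        ([1, 306, 9508], [1, 306, 9508], [DEFAULT, INSTRUCTION, RESPONSE]),
        ([1, 306], [1, 306], [DEFAULT, INSTRUCTION]),
    ]
    input_ids, labels, colors = padded_collate(batch)
    assert input_ids.shape == labels.shape == colors.shape == (2, 3)
    assert labels[1, -1].item() == -100  # label padding uses ignore_idx
    assert colors[1, -1].item() == DEFAULT  # color padding uses padding_idx (0 == DEFAULT)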