|
|
|
from typing import List, Tuple |
|
|
|
from datasets import load_dataset |
|
from torch.utils.data import Dataset |
|
|
|
|
|
from torchtune.modules import Tokenizer |
|
|
|
|
|
# Sentinel label value that torch.nn.CrossEntropyLoss skips by default
# (its ``ignore_index`` is -100). Used below to mask prompt tokens out of
# the loss when ``train_on_input`` is False.
CROSS_ENTROPY_IGNORE_IDX = -100

# Standard Alpaca prompt templates. ``prompt_input`` is used when a sample
# carries a non-empty "input" field; ``prompt_no_input`` otherwise
# (see AlpacaDataset._generate_prompt). The response text is appended
# after the trailing "### Response:\n" marker at tokenization time.
_PROMPT_TEMPLATE = {
    "prompt_input": (
        "Below is an instruction that describes a task, paired with an input that provides further context. "
        "Write a response that appropriately completes the request.\n\n"
        "### Instruction:\n{instruction}\n\n### Input:\n{input}\n\n### Response:\n"
    ),
    "prompt_no_input": (
        "Below is an instruction that describes a task. "
        "Write a response that appropriately completes the request.\n\n"
        "### Instruction:\n{instruction}\n\n### Response:\n"
    ),
}
|
|
|
|
|
class AlpacaDataset(Dataset):
    """
    Map-style dataset of Alpaca-format instruction samples.

    See torchtune.datasets.AlpacaDataset for the original implementation.
    This version supports custom dataset paths.

    Each sample is expected to provide "instruction", "input", and "output"
    fields; ``__getitem__`` returns the tokenized prompt+response together
    with the training labels.

    Args:
        dataset_path (str): Path or hub identifier passed to
            ``datasets.load_dataset``; the "train" split is loaded.
        tokenizer (Tokenizer): Tokenizer used to encode prompts and responses.
        train_on_input (bool): If False, prompt tokens in the labels are
            replaced with ``CROSS_ENTROPY_IGNORE_IDX`` so only the response
            contributes to the loss. Default: True.
        **kwargs: Accepted for config-instantiation compatibility;
            currently unused.
    """

    def __init__(
        self,
        dataset_path: str,
        tokenizer: Tokenizer,
        train_on_input: bool = True,
        **kwargs,
    ) -> None:
        self._data = load_dataset(dataset_path, split="train")
        self._tokenizer = tokenizer
        self.train_on_input = train_on_input

    def __len__(self) -> int:
        return len(self._data)

    def __getitem__(self, index: int) -> Tuple[List[int], List[int]]:
        sample = self._data[index]

        return self._transform(
            instruction=sample["instruction"],
            input=sample["input"],
            output=sample["output"],
        )

    def _transform(
        self, instruction: str, input: str, output: str
    ) -> Tuple[List[int], List[int]]:
        """
        Split a sample on ``response`` tag to create input and labels.

        Args:
            instruction (str): Instruction text.
            input (str): Input text. Can be an empty string. Determines the
                prompt generation template used.
            output (str): Response text.

        Returns:
            Tuple of encoded inputs and labels.

        Raises:
            RuntimeError: If the encoded inputs and labels disagree in length
                (internal invariant violation).
        """
        prompt = self._generate_prompt(instruction, input)
        prompt_with_response = prompt + output

        # Encode the prompt twice: alone (no EOS) only to learn how many
        # leading tokens belong to the prompt, and with the response appended
        # (with EOS) to produce the actual model input.
        encoded_prompt = self._tokenizer.encode(
            text=prompt, add_bos=True, add_eos=False
        )
        encoded_prompt_with_response = self._tokenizer.encode(
            text=prompt_with_response, add_bos=True, add_eos=True
        )
        labels = encoded_prompt_with_response.copy()

        if not self.train_on_input:
            # Mask the prompt portion so the loss is computed on the
            # response tokens only.
            labels[: len(encoded_prompt)] = [CROSS_ENTROPY_IGNORE_IDX] * len(
                encoded_prompt
            )

        # Explicit check rather than ``assert`` so the invariant still holds
        # when running under ``python -O`` (asserts are stripped).
        if len(encoded_prompt_with_response) != len(labels):
            raise RuntimeError(
                "Encoded inputs and labels must have the same length."
            )

        return encoded_prompt_with_response, labels

    def _generate_prompt(self, instruction: str, input: str) -> str:
        """
        Generate prompt from instruction and input.

        Args:
            instruction (str): Instruction text.
            input (str): Input text.

        Returns:
            Prompt text.
        """
        if input:
            prompt = _PROMPT_TEMPLATE["prompt_input"].format(
                instruction=instruction, input=input
            )
        else:
            prompt = _PROMPT_TEMPLATE["prompt_no_input"].format(instruction=instruction)
        return prompt
|
|