|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
import os |
|
from typing import List, Literal, Optional |
|
|
|
from datasets import DatasetDict, concatenate_datasets, load_dataset, load_from_disk |
|
from datasets.builder import DatasetGenerationError |
|
|
|
from .configs import DataArguments |
|
|
|
|
|
# Fallback Jinja chat template. Each message whose role is `user`, `system` or `assistant` is
# rendered as a `<|role|>` header, a newline, the message content and `eos_token`; messages with any
# other role produce no header/content. When `add_generation_prompt` is true, a bare `<|assistant|>`
# header is emitted after the last message so a model can continue from there.
# NOTE(review): not referenced by the functions visible in this module section; presumably assigned
# to tokenizers that ship without a chat template elsewhere — confirm against callers.
DEFAULT_CHAT_TEMPLATE = "{% for message in messages %}\n{% if message['role'] == 'user' %}\n{{ '<|user|>\n' + message['content'] + eos_token }}\n{% elif message['role'] == 'system' %}\n{{ '<|system|>\n' + message['content'] + eos_token }}\n{% elif message['role'] == 'assistant' %}\n{{ '<|assistant|>\n' + message['content'] + eos_token }}\n{% endif %}\n{% if loop.last and add_generation_prompt %}\n{{ '<|assistant|>' }}\n{% endif %}\n{% endfor %}"
|
|
|
|
|
def maybe_insert_system_message(messages, tokenizer):
    """
    Prepend an empty system turn to `messages` if the tokenizer's chat template appears to support one.

    The list is modified in place and nothing is returned. Nothing happens when the conversation already
    starts with a system message, or when no chat template is available to inspect.

    Args:
        messages (`list` of `dict`):
            Conversation turns, each a mapping with at least a `"role"` key.
        tokenizer:
            Object exposing a `chat_template` attribute and, optionally, a `default_chat_template` fallback.
    """
    # An empty conversation has no leading system turn, so it falls through to the insertion logic
    # instead of raising IndexError.
    if messages and messages[0]["role"] == "system":
        return

    chat_template = tokenizer.chat_template
    if chat_template is None:
        # `default_chat_template` is not present on every tokenizer version, so look it up defensively.
        chat_template = getattr(tokenizer, "default_chat_template", None)
    if chat_template is None:
        # No template to inspect: we cannot tell whether a system role is supported, so leave the turns alone.
        return

    # Heuristic: a template that never mentions "system" presumably does not handle that role.
    if "system" in chat_template:
        messages.insert(0, {"role": "system", "content": ""})
|
|
|
|
|
def apply_chat_template(
    example,
    tokenizer,
    task: Literal["sft", "generation", "rm", "dpo"],
):
    """
    Render the conversation(s) stored in `example` to text using `tokenizer`'s chat template.

    Columns written, per task:
        - `sft` / `generation`: `text`, rendered from `example["messages"]` (with a generation prompt for
          `generation`).
        - `rm`: `text_chosen` and `text_rejected`, each rendered from the full `chosen` / `rejected` dialogue.
        - `dpo`: `text_prompt` (all turns of `chosen` except the last), plus `text_chosen` / `text_rejected`
          (only the final turn of each dialogue).

    An empty system turn is prepended via `maybe_insert_system_message`, which only does so when the
    template mentions a system role. This mutates the message lists held in `example`.

    Args:
        example (`dict`): A single dataset row.
        tokenizer: Tokenizer exposing `apply_chat_template(messages, tokenize=False, ...)`.
        task (`str`): One of `"sft"`, `"generation"`, `"rm"`, `"dpo"`.

    Returns:
        `dict`: The same `example`, with the text column(s) added.

    Raises:
        ValueError: If `task` is unsupported, or if `rm`/`dpo` rows lack `chosen`/`rejected` keys.
    """
    if task in ["sft", "generation"]:
        messages = example["messages"]
        maybe_insert_system_message(messages, tokenizer)
        example["text"] = tokenizer.apply_chat_template(
            messages, tokenize=False, add_generation_prompt=(task == "generation")
        )
    elif task == "rm":
        if all(k in example.keys() for k in ("chosen", "rejected")):
            chosen_messages = example["chosen"]
            rejected_messages = example["rejected"]

            maybe_insert_system_message(chosen_messages, tokenizer)
            maybe_insert_system_message(rejected_messages, tokenizer)

            example["text_chosen"] = tokenizer.apply_chat_template(chosen_messages, tokenize=False)
            example["text_rejected"] = tokenizer.apply_chat_template(rejected_messages, tokenize=False)
        else:
            raise ValueError(
                f"Could not format example as dialogue for `rm` task! Require `[chosen, rejected]` keys but found {list(example.keys())}"
            )
    elif task == "dpo":
        if all(k in example.keys() for k in ("chosen", "rejected")):
            # The prompt is taken from `chosen` only: this assumes `chosen` and `rejected` share their
            # leading turns and differ only in the final (response) turn.
            prompt_messages = example["chosen"][:-1]

            # Same system-turn rule as the sft/rm branches, so templates without system-role support
            # are not handed a system message. An empty prompt (single-turn `chosen`) is left as-is.
            if prompt_messages:
                maybe_insert_system_message(prompt_messages, tokenizer)

            # Only the final turn of each dialogue is used as the response.
            chosen_messages = example["chosen"][-1:]
            rejected_messages = example["rejected"][-1:]
            example["text_chosen"] = tokenizer.apply_chat_template(chosen_messages, tokenize=False)
            example["text_rejected"] = tokenizer.apply_chat_template(rejected_messages, tokenize=False)
            example["text_prompt"] = tokenizer.apply_chat_template(prompt_messages, tokenize=False)
        else:
            raise ValueError(
                f"Could not format example as dialogue for `dpo` task! Require `[chosen, rejected]` keys but found {list(example.keys())}"
            )
    else:
        raise ValueError(
            f"Task {task} not supported, please ensure that the provided task is one of {['sft', 'generation', 'rm', 'dpo']}"
        )

    return example
|
|
|
|
|
def get_datasets(
    data_config: DataArguments | dict,
    splits: Optional[List[str]] = None,
    shuffle: bool = True,
) -> DatasetDict:
    """
    Loads one or more datasets with varying training set proportions.

    Args:
        data_config (`DataArguments` or `dict`):
            Dataset configuration and split proportions. For a `DataArguments` instance (or subclass), its
            `dataset_mixer` attribute is used; a `dict` is used directly as the mixer, mapping dataset names
            to training fractions.
        splits (`List[str]`, *optional*, defaults to `['train', 'test']`):
            Dataset splits to load and mix; `None` selects the default. Assumes the splits exist in all
            datasets. The first split is used as training data and any further splits are concatenated into
            the test set.
        shuffle (`bool`, *optional*, defaults to `True`):
            Whether to shuffle the training and testing/validation data.

    Returns
        [`DatasetDict`]: The dataset dictionary containing the loaded datasets.

    Raises:
        ValueError: If `data_config` is neither a `DataArguments` nor a `dict`.
    """
    # A `None` sentinel avoids sharing one mutable default list across calls.
    if splits is None:
        splits = ["train", "test"]

    if isinstance(data_config, DataArguments):
        dataset_mixer = data_config.dataset_mixer
    elif isinstance(data_config, dict):
        dataset_mixer = data_config
    else:
        raise ValueError(f"Data config {data_config} not recognized.")

    return mix_datasets(dataset_mixer, splits=splits, shuffle=shuffle)
|
|
|
|
|
def mix_datasets(
    dataset_mixer: dict,
    splits: Optional[List[str]] = None,
    shuffle=True,
    seed: int = 42,
) -> DatasetDict:
    """
    Loads and mixes datasets according to proportions specified in `dataset_mixer`.

    Args:
        dataset_mixer (`dict`):
            Dictionary mapping dataset names (passed to `load_dataset`, or local directories) to the fraction
            of their training split to keep. Test/validation splits are always kept in full.
        splits (Optional[List[str]], *optional*, defaults to `None`):
            Dataset splits to load and mix; `None` means `["train", "test"]`. Assumes the splits exist in all
            datasets. The FIRST split of each dataset is treated as training data: it is truncated to
            `int(frac * len(split))` rows (taken from the start) and returned under `"train"`. Every later
            split is kept whole and concatenated under `"test"`.
        shuffle (`bool`, *optional*, defaults to `True`):
            Whether to shuffle the training and testing/validation data.
        seed (`int`, *optional*, defaults to `42`):
            Seed passed to `shuffle` when `shuffle` is `True`.

    Returns:
        [`DatasetDict`]: Contains a `"train"` and/or a `"test"` entry.

    Raises:
        ValueError: If any fraction is negative, or if nothing was loaded (e.g. `dataset_mixer` or `splits`
            is empty).
    """
    if splits is None:
        splits = ["train", "test"]

    # Validate before loading anything so a bad mixer fails fast rather than after every download.
    fracs = list(dataset_mixer.values())
    if any(frac < 0 for frac in fracs):
        raise ValueError("Dataset fractions cannot be negative.")

    raw_train_datasets = []
    raw_val_datasets = []
    for ds in dataset_mixer:
        for idx, split in enumerate(splits):
            try:
                dataset = load_dataset(ds, split=split)
            except DatasetGenerationError:
                # Fall back to a dataset stored on disk under `<ds>/<split>`.
                dataset = load_from_disk(os.path.join(ds, split))

            if idx == 0:
                raw_train_datasets.append(dataset)
            else:
                raw_val_datasets.append(dataset)

    raw_datasets = DatasetDict()

    if raw_train_datasets:
        # Exactly one training split is loaded per mixer entry, so this list lines up with `fracs`.
        train_subsets = [
            dataset.select(range(int(frac * len(dataset))))
            for dataset, frac in zip(raw_train_datasets, fracs)
        ]
        train = concatenate_datasets(train_subsets)
        raw_datasets["train"] = train.shuffle(seed=seed) if shuffle else train

    if raw_val_datasets:
        test = concatenate_datasets(raw_val_datasets)
        raw_datasets["test"] = test.shuffle(seed=seed) if shuffle else test

    if len(raw_datasets) == 0:
        # Report `splits` (always bound) rather than the loop variable, which is unbound when nothing looped.
        raise ValueError(
            f"Dataset {dataset_mixer} not recognized with splits {splits}. Check the dataset has been correctly formatted."
        )

    return raw_datasets
|
|
|
|