import random
from dataclasses import dataclass, field
from functools import partial
from pathlib import Path

import jax
import jax.numpy as jnp
import numpy as np
from braceexpand import braceexpand
from datasets import Dataset, load_dataset

from .model.text import TextNormalizer


@dataclass
class Dataset:
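    """
    Light wrapper around a Hugging Face dataset: loads the requested splits,
    optionally filters and normalizes captions, tokenizes them, and exposes
    batch iterators for training and evaluation.
    """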
    dataset_repo_or_path: str
    train_file: str = None
    validation_file: str = None
    streaming: bool = True
    use_auth_token: bool = False
    text_column: str = "caption"
    encoding_column: str = "encoding"
    max_train_samples: int = None
    max_eval_samples: int = None
    preprocessing_num_workers: int = None
    overwrite_cache: bool = False
    do_train: bool = False
    do_eval: bool = True
    seed_dataset: int = None
    shard_by_host: bool = False
    blank_caption_prob: float = 0.0
    clip_score_column: str = "clip_score"
    min_clip_score: float = None
    max_clip_score: float = None
    filter_column: str = None
    filter_value: str = None
    multi_eval_ds: bool = False
    train_dataset: Dataset = field(init=False)
    eval_dataset: Dataset = field(init=False)
    other_eval_datasets: list = field(init=False)
    rng_dataset: jnp.ndarray = field(init=False)
    multi_hosts: bool = field(init=False)

    def __post_init__(self):
        if self.seed_dataset is None:
            # create a random seed when none is provided
            self.seed_dataset = random.randint(0, 2**32 - 1)
        # numpy rng, used for blanking captions deterministically
        self.np_rng = np.random.default_rng(self.seed_dataset)
        self.multi_hosts = jax.process_count() > 1

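        # blank captions are only supported in streaming mode: a cached
        # (non-streaming) dataset would keep the same captions blanked forever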
        if self.blank_caption_prob:
            assert (
                self.streaming is True
            ), "blank_caption_prob can only be used in streaming mode"

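        # build data_files from the provided train / validation files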
        if self.train_file is not None or self.validation_file is not None:
            # accept braceexpand notation, e.g. "data/part-{000..099}.parquet"
            for k in ["train_file", "validation_file"]:
                f = getattr(self, k)
                if isinstance(f, str):
                    setattr(self, k, list(braceexpand(f)))
            # shard the list of training files across hosts when requested
            if (
                isinstance(self.train_file, list)
                and self.multi_hosts
                and self.shard_by_host
            ):
                self.train_file = self.train_file[
                    jax.process_index() :: jax.process_count()
                ]
            data_files = {
                "train": self.train_file,
                "validation": self.validation_file,
            }
        else:
            data_files = None

        # multiple evaluation datasets: each sub-directory of the local path
        # becomes a split made of its parquet files
        if self.multi_eval_ds:
            assert Path(
                self.dataset_repo_or_path
            ).is_dir(), f"{self.dataset_repo_or_path} is not a directory, required for multi_eval_ds"
            data_files = {
                split.name: [str(f) for f in split.glob("*.parquet")]
                for split in Path(self.dataset_repo_or_path).glob("*")
            }
            # accept "valid" as an alias for "validation"
            if "valid" in data_files:
                data_files["validation"] = data_files["valid"]
                del data_files["valid"]
            # the files are then loaded through the generic parquet builder
            self.dataset_repo_or_path = "parquet"

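        # load the dataset (hub repo, local path, or parquet files)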
        dataset = load_dataset(
            self.dataset_repo_or_path,
            data_files=data_files,
            streaming=self.streaming,
            use_auth_token=self.use_auth_token,
        )
        if self.do_train:
            if "train" not in dataset:
                raise ValueError("Training requires a training dataset")
            self.train_dataset = dataset["train"]
            if self.max_train_samples is not None:
                self.train_dataset = (
                    self.train_dataset.take(self.max_train_samples)
                    if self.streaming
                    else self.train_dataset.select(range(self.max_train_samples))
                )
        if self.do_eval:
            if "validation" not in dataset:
                raise ValueError("Evaluating requires a validation dataset")
            self.eval_dataset = dataset["validation"]
            if self.max_eval_samples is not None:
                self.eval_dataset = (
                    self.eval_dataset.take(self.max_eval_samples)
                    if self.streaming
                    else self.eval_dataset.select(range(self.max_eval_samples))
                )
            # any other split is kept as an additional evaluation dataset
            other_eval_splits = dataset.keys() - {"train", "validation"}
            self.other_eval_datasets = {
                split: dataset[split] for split in other_eval_splits
            }

    def preprocess(self, tokenizer, config):
        # parameters required for tokenization and label preparation
        decoder_start_token_id = config.decoder_start_token_id
        normalize_text = config.normalize_text
        max_length = config.max_text_length

        if self.streaming:
            # we need to shuffle early in streaming mode
            if hasattr(self, "train_dataset"):
                self.train_dataset = self.train_dataset.shuffle(
                    buffer_size=5000, seed=self.seed_dataset
                )
        else:
            # prepare the rng used for shuffling at dataloader time
            self.rng_dataset = jax.random.PRNGKey(self.seed_dataset)

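        # filter examples by CLIP score range and/or on a metadata column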
        partial_filter_function = partial(
            filter_function,
            filter_column=self.filter_column,
            filter_value=self.filter_value,
            clip_score_column=self.clip_score_column,
            min_clip_score=self.min_clip_score,
            max_clip_score=self.max_clip_score,
        )
        for ds in ["train_dataset", "eval_dataset"]:
            if hasattr(self, ds):
                setattr(
                    self,
                    ds,
                    (
                        getattr(self, ds).filter(partial_filter_function)
                        if self.streaming
                        else getattr(self, ds).filter(
                            partial_filter_function,
                            num_proc=self.preprocessing_num_workers,
                            load_from_cache_file=not self.overwrite_cache,
                            desc="Filtering datasets",
                        )
                    ),
                )
        if hasattr(self, "other_eval_datasets"):
            self.other_eval_datasets = {
                split: (
                    ds.filter(partial_filter_function)
                    if self.streaming
                    else ds.filter(
                        partial_filter_function,
                        num_proc=self.preprocessing_num_workers,
                        load_from_cache_file=not self.overwrite_cache,
                        desc="Filtering datasets",
                    )
                )
                for split, ds in self.other_eval_datasets.items()
            }

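        # normalize the captions when required by the model config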
        if normalize_text:
            text_normalizer = TextNormalizer()
            partial_normalize_function = partial(
                normalize_function,
                text_column=self.text_column,
                text_normalizer=text_normalizer,
            )
            for ds in ["train_dataset", "eval_dataset"]:
                if hasattr(self, ds):
                    setattr(
                        self,
                        ds,
                        (
                            getattr(self, ds).map(partial_normalize_function)
                            if self.streaming
                            else getattr(self, ds).map(
                                partial_normalize_function,
                                num_proc=self.preprocessing_num_workers,
                                load_from_cache_file=not self.overwrite_cache,
                                desc="Normalizing datasets",
                            )
                        ),
                    )
            if hasattr(self, "other_eval_datasets"):
                self.other_eval_datasets = {
                    split: (
                        ds.map(partial_normalize_function)
                        if self.streaming
                        else ds.map(
                            partial_normalize_function,
                            num_proc=self.preprocessing_num_workers,
                            load_from_cache_file=not self.overwrite_cache,
                            desc="Normalizing datasets",
                        )
                    )
                    for split, ds in self.other_eval_datasets.items()
                }

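        # randomly blank out a fraction of the training captions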
        if self.blank_caption_prob:
            partial_blank_caption_function = partial(
                blank_caption_function,
                text_column=self.text_column,
                blank_caption_prob=self.blank_caption_prob,
                rng=self.np_rng,
            )
            if hasattr(self, "train_dataset"):
                self.train_dataset = (
                    self.train_dataset.map(partial_blank_caption_function)
                    if self.streaming
                    else self.train_dataset.map(
                        partial_blank_caption_function,
                        # multiprocessing is disabled when a dataset seed is set
                        num_proc=None
                        if self.seed_dataset
                        else self.preprocessing_num_workers,
                        load_from_cache_file=False,
                        desc="Blanking some captions",
                    )
                )

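        # tokenize the captions and build labels / decoder inputs from the
        # pre-computed encodings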
        partial_preprocess_function = partial(
            preprocess_function,
            tokenizer=tokenizer,
            text_column=self.text_column,
            encoding_column=self.encoding_column,
            max_length=max_length,
            decoder_start_token_id=decoder_start_token_id,
        )
        for ds in ["train_dataset", "eval_dataset"]:
            if hasattr(self, ds):
                setattr(
                    self,
                    ds,
                    (
                        getattr(self, ds).map(
                            partial_preprocess_function,
                            batched=True,
                            remove_columns=[
                                self.text_column,
                                self.encoding_column,
                            ],
                        )
                        if self.streaming
                        else getattr(self, ds).map(
                            partial_preprocess_function,
                            batched=True,
                            # "ds" is the attribute name, so fetch the dataset
                            # before reading its column names
                            remove_columns=getattr(self, ds).column_names,
                            num_proc=self.preprocessing_num_workers,
                            load_from_cache_file=not self.overwrite_cache,
                            desc="Preprocessing datasets",
                        )
                    ),
                )
        if hasattr(self, "other_eval_datasets"):
            self.other_eval_datasets = {
                split: (
                    ds.map(
                        partial_preprocess_function,
                        batched=True,
                        remove_columns=[
                            self.text_column,
                            self.encoding_column,
                        ],
                    )
                    if self.streaming
                    else ds.map(
                        partial_preprocess_function,
                        batched=True,
                        remove_columns=ds.column_names,
                        num_proc=self.preprocessing_num_workers,
                        load_from_cache_file=not self.overwrite_cache,
                        desc="Preprocessing datasets",
                    )
                )
                for split, ds in self.other_eval_datasets.items()
            }

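    # Build an iterator of model-ready batches for the requested split,
    # either from a streaming dataset or from an in-memory (map-style) dataset.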
    def dataloader(self, split, batch_size, epoch=None):
        def _dataloader_datasets_non_streaming(
            dataset: Dataset,
            rng: jax.random.PRNGKey = None,
        ):
            """
            Returns batches of size `batch_size` from the truncated `dataset`.
            Batches are shuffled when `rng` is set.
            """
            steps_per_epoch = len(dataset) // batch_size

            if rng is not None:
                batch_idx = jax.random.permutation(rng, len(dataset))
            else:
                batch_idx = jnp.arange(len(dataset))

            # drop the last incomplete batch
            batch_idx = batch_idx[: steps_per_epoch * batch_size]
            batch_idx = batch_idx.reshape((steps_per_epoch, batch_size))

            for idx in batch_idx:
                batch = dataset[idx]
                batch = {k: jnp.array(v) for k, v in batch.items()}
                yield batch

        def _dataloader_datasets_streaming(
            dataset: Dataset,
            epoch: int,
        ):
            keys = ["input_ids", "attention_mask", "labels", "decoder_input_ids"]
            batch = {k: [] for k in keys}
            first_loop = True  # stop after one pass unless on multiple hosts
            while (self.multi_hosts and split == "train") or first_loop:
                # in multi-host training we loop over the data indefinitely:
                # hosts must keep producing batches at the same pace even if
                # their shards do not hold exactly the same number of samples
                if epoch is not None:
                    assert split == "train"
                    # reshuffle training data at each epoch
                    dataset.set_epoch(epoch)
                    epoch += 1
                for item in dataset:
                    for k in keys:
                        batch[k].append(item[k])
                    if len(batch[keys[0]]) == batch_size:
                        batch = {k: jnp.array(v) for k, v in batch.items()}
                        yield batch
                        batch = {k: [] for k in keys}
                first_loop = False

        if split == "train":
            ds = self.train_dataset
        elif split == "eval":
            ds = self.eval_dataset
        else:
            ds = self.other_eval_datasets[split]

        if self.streaming:
            return _dataloader_datasets_streaming(ds, epoch)
        else:
            # only the training split is reshuffled; evaluation batches keep
            # the dataset order
            input_rng = None
            if split == "train":
                self.rng_dataset, input_rng = jax.random.split(self.rng_dataset)
            return _dataloader_datasets_non_streaming(ds, input_rng)

    @property
    def length(self):
        len_train_dataset, len_eval_dataset = None, None
        if self.streaming:
            # the length of a streaming dataset is unknown unless it was
            # explicitly truncated
            if self.max_train_samples is not None:
                len_train_dataset = self.max_train_samples
            if self.max_eval_samples is not None:
                len_eval_dataset = self.max_eval_samples
        else:
            len_train_dataset = (
                len(self.train_dataset) if hasattr(self, "train_dataset") else None
            )
            len_eval_dataset = (
                len(self.eval_dataset) if hasattr(self, "eval_dataset") else None
            )
        return len_train_dataset, len_eval_dataset


def shift_tokens_right(input_ids: np.ndarray, decoder_start_token_id: int):
    """
    Shift input ids one token to the right.
    """
    # zeros_like keeps the integer dtype of the token ids
    shifted_input_ids = np.zeros_like(input_ids)
    shifted_input_ids[:, 1:] = input_ids[:, :-1]
    shifted_input_ids[:, 0] = decoder_start_token_id
    return shifted_input_ids


def blank_caption_function(example, text_column, blank_caption_prob, rng=None):
    # replace the caption with an empty string with probability blank_caption_prob
    if (
        blank_caption_prob
        and (rng.random() if rng is not None else np.random.random())
        < blank_caption_prob
    ):
        example[text_column] = ""
    return example


def normalize_function(example, text_column, text_normalizer):
    example[text_column] = text_normalizer(example[text_column])
    return example


def filter_function(
    example,
    min_clip_score,
    max_clip_score,
    clip_score_column,
    filter_column,
    filter_value,
):
    # drop examples outside the requested clip score range
    if min_clip_score is not None and example[clip_score_column] < min_clip_score:
        return False
    if max_clip_score is not None and example[clip_score_column] > max_clip_score:
        return False
    # drop examples whose filter column does not match the requested value
    if filter_column is not None and example[filter_column] != filter_value:
        return False
    return True


def preprocess_function(
    examples,
    tokenizer,
    text_column,
    encoding_column,
    max_length,
    decoder_start_token_id,
):
    inputs = examples[text_column]
    # tokenize the captions, padded/truncated to max_length
    model_inputs = tokenizer(
        inputs,
        max_length=max_length,
        padding="max_length",
        truncation=True,
        return_tensors="np",
    )

    # the labels are the pre-computed encodings
    labels = examples[encoding_column]
    labels = np.asarray(labels)
    model_inputs["labels"] = labels

    # the decoder reads the labels shifted right by one position
    decoder_input_ids = shift_tokens_right(labels, decoder_start_token_id)
    model_inputs["decoder_input_ids"] = decoder_input_ids

    return model_inputs
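

# Illustrative usage sketch (names and values are hypothetical; `tokenizer`
# and `model.config` would come from the surrounding training code):
#
#   dataset = Dataset(
#       dataset_repo_or_path="user/encoded-dataset",
#       do_train=True,
#       streaming=True,
#   )
#   dataset.preprocess(tokenizer=tokenizer, config=model.config)
#   train_batches = dataset.dataloader("train", batch_size=64, epoch=0)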