from dataclasses import dataclass, field
from functools import partial

import jax
import jax.numpy as jnp
import numpy as np
from braceexpand import braceexpand
from datasets import Dataset, load_dataset

from .text import TextNormalizer


@dataclass
class Dataset:
    """Load, preprocess and batch a Hugging Face dataset for training/evaluation.

    NOTE: this class shadows ``datasets.Dataset`` imported above. Inside the
    class body (the field annotations below) ``Dataset`` still resolves to the
    Hugging Face class; after the class statement the module-level name refers
    to this dataclass.
    """

    dataset_repo_or_path: str
    train_file: str = None  # str (braceexpand pattern allowed) or list of files
    validation_file: str = None  # str (braceexpand pattern allowed) or list of files
    streaming: bool = True
    use_auth_token: bool = False
    text_column: str = "caption"
    encoding_column: str = "encoding"
    max_train_samples: int = None
    max_eval_samples: int = None
    preprocessing_num_workers: int = None
    overwrite_cache: bool = False
    do_train: bool = False
    do_eval: bool = True
    seed_dataset: int = None
    shard_by_host: bool = False
    # set in __post_init__ only when do_train / do_eval is enabled
    train_dataset: Dataset = field(init=False)
    eval_dataset: Dataset = field(init=False)
    # set in preprocess() in non-streaming mode only
    rng_dataset: jnp.ndarray = field(init=False)
    multi_hosts: bool = field(init=False)

    def __post_init__(self):
        """Resolve data files, load the dataset and select the train/eval splits.

        Raises:
            ValueError: if ``do_train`` (resp. ``do_eval``) is set but the loaded
                dataset has no ``"train"`` (resp. ``"validation"``) split.
        """
        self.multi_hosts = jax.process_count() > 1
        # define data_files
        if self.train_file is not None or self.validation_file is not None:
            # accept braceexpand notation
            for k in ["train_file", "validation_file"]:
                f = getattr(self, k)
                if isinstance(f, str):
                    setattr(self, k, list(braceexpand(f)))
            # for list of files, split training data shards by host
            if (
                isinstance(self.train_file, list)
                and self.multi_hosts
                and self.shard_by_host
            ):
                self.train_file = self.train_file[
                    jax.process_index() :: jax.process_count()
                ]
            # only pass the splits that were actually provided: a `None` entry
            # is not a valid file pattern for `load_dataset`
            data_files = {
                split: files
                for split, files in (
                    ("train", self.train_file),
                    ("validation", self.validation_file),
                )
                if files is not None
            }
        else:
            data_files = None

        # load dataset
        dataset = load_dataset(
            self.dataset_repo_or_path,
            data_files=data_files,
            streaming=self.streaming,
            use_auth_token=self.use_auth_token,
        )
        if self.do_train:
            if "train" not in dataset:
                raise ValueError("Training requires a training dataset")
            self.train_dataset = dataset["train"]
            if self.max_train_samples is not None:
                self.train_dataset = (
                    self.train_dataset.take(self.max_train_samples)
                    if self.streaming
                    else self.train_dataset.select(range(self.max_train_samples))
                )
        if self.do_eval:
            if "validation" not in dataset:
                raise ValueError("Evaluating requires a validation dataset")
            self.eval_dataset = dataset["validation"]
            if self.max_eval_samples is not None:
                self.eval_dataset = (
                    self.eval_dataset.take(self.max_eval_samples)
                    if self.streaming
                    else self.eval_dataset.select(range(self.max_eval_samples))
                )

    def preprocess(self, tokenizer, decoder_start_token_id, normalize_text, max_length):
        """Shuffle/seed, optionally normalize text, then tokenize the loaded splits.

        Args:
            tokenizer: callable used by ``preprocess_function`` to encode the text
                column (presumably a Hugging Face tokenizer — it is called with
                ``padding``/``truncation``/``return_tensors`` keywords).
            decoder_start_token_id: token placed first in ``decoder_input_ids``.
            normalize_text: if true, apply ``TextNormalizer`` to the text column.
            max_length: fixed length the text is padded/truncated to.
        """
        if self.streaming:
            # we need to shuffle early in streaming mode
            if hasattr(self, "train_dataset"):
                # keywords, because the positional order of
                # IterableDataset.shuffle differs between `datasets` releases
                self.train_dataset = self.train_dataset.shuffle(
                    buffer_size=1000, seed=self.seed_dataset
                )
        else:
            # prepare rng for later shuffling
            if self.seed_dataset is None:
                self.seed_dataset = np.random.get_state()[1][0]
            self.rng_dataset = jax.random.PRNGKey(self.seed_dataset)

        # normalize text
        if normalize_text:
            text_normalizer = TextNormalizer()
            partial_normalize_function = partial(
                normalize_function,
                text_column=self.text_column,
                text_normalizer=text_normalizer,
            )
            for ds in ["train_dataset", "eval_dataset"]:
                if hasattr(self, ds):
                    split_ds = getattr(self, ds)
                    setattr(
                        self,
                        ds,
                        (
                            split_ds.map(partial_normalize_function)
                            if self.streaming
                            else split_ds.map(
                                partial_normalize_function,
                                num_proc=self.preprocessing_num_workers,
                                load_from_cache_file=not self.overwrite_cache,
                                desc="Normalizing datasets",
                            )
                        ),
                    )

        # preprocess
        partial_preprocess_function = partial(
            preprocess_function,
            tokenizer=tokenizer,
            text_column=self.text_column,
            encoding_column=self.encoding_column,
            max_length=max_length,
            decoder_start_token_id=decoder_start_token_id,
        )
        for ds in ["train_dataset", "eval_dataset"]:
            if hasattr(self, ds):
                split_ds = getattr(self, ds)
                setattr(
                    self,
                    ds,
                    (
                        split_ds.map(
                            partial_preprocess_function,
                            batched=True,
                        )
                        if self.streaming
                        else split_ds.map(
                            partial_preprocess_function,
                            batched=True,
                            # drop the raw columns of *this split's dataset*
                            # (`ds` is only the attribute name string)
                            remove_columns=split_ds.column_names,
                            num_proc=self.preprocessing_num_workers,
                            load_from_cache_file=not self.overwrite_cache,
                            desc="Preprocessing datasets",
                        )
                    ),
                )

    def dataloader(
        self, split, per_device_batch_size, gradient_accumulation_steps=None, epoch=None
    ):
        """Return a generator of batches for ``split`` (``"train"`` or ``"eval"``).

        Each batch holds ``per_device_batch_size * local_device_count *
        gradient_accumulation_steps`` examples (``None`` accumulation counts as 1).
        Values are converted to ``jnp`` arrays. When ``gradient_accumulation_steps``
        is given, the leading axis is reshaped to ``(-1, per_device_batch_size)``.

        Non-streaming: the train split is shuffled with a fresh key split off
        ``self.rng_dataset``. The eval split keeps dataset order, and the
        trailing incomplete batch is dropped.
        Streaming: in multi-host training the generator loops forever. Otherwise
        it makes a single pass. ``epoch``, if given, is passed to
        ``dataset.set_epoch`` and incremented at each pass.

        Raises:
            ValueError: if ``split`` is neither ``"train"`` nor ``"eval"``.
        """
        num_devices = jax.local_device_count()
        accumulation_steps = (
            gradient_accumulation_steps if gradient_accumulation_steps is not None else 1
        )
        batch_size = per_device_batch_size * num_devices * accumulation_steps

        def _to_arrays(batch):
            """Convert batch values to jnp arrays and reshape for accumulation."""
            # e.g. per_device = 5, devices = 8, accumulation = 2:
            #   (80, ...) -> (16, 5, ...)
            batch = {k: jnp.array(v) for k, v in batch.items()}
            if gradient_accumulation_steps is not None:
                batch = jax.tree_util.tree_map(
                    lambda x: x.reshape((-1, per_device_batch_size) + x.shape[1:]),
                    batch,
                )
            return batch

        def _dataloader_datasets_non_streaming(dataset, rng=None):
            """Yield full batches from ``dataset``, shuffled if ``rng`` is set."""
            steps_per_epoch = len(dataset) // batch_size
            if rng is not None:
                batch_idx = jax.random.permutation(rng, len(dataset))
            else:
                batch_idx = jnp.arange(len(dataset))
            # Skip incomplete batch.
            batch_idx = batch_idx[: steps_per_epoch * batch_size]
            batch_idx = batch_idx.reshape((steps_per_epoch, batch_size))
            for idx in batch_idx:
                yield _to_arrays(dataset[np.asarray(idx)])

        keys = ["input_ids", "attention_mask", "labels", "decoder_input_ids"]

        def _dataloader_datasets_streaming(dataset, split, epoch):
            """Yield batches of the model-input ``keys`` from a streaming dataset."""
            batch = {k: [] for k in keys}
            first_loop = True  # stop after one loop in some cases
            while (self.multi_hosts and split == "train") or first_loop:
                # in multi-host, we run forever (no epoch) as hosts need to stop
                # at the same time and training data may not be split equally
                # For validation data we put the entire set on each host as we
                # could lose too many samples on pods
                if epoch is not None:
                    # reshuffle training data at each epoch
                    dataset.set_epoch(epoch)
                    epoch += 1
                for item in dataset:
                    # items may carry extra columns (the streaming `map` in
                    # `preprocess` keeps the raw text/encoding columns), so only
                    # collect the model inputs
                    for k in keys:
                        batch[k].append(item[k])
                    if len(batch[keys[0]]) == batch_size:
                        yield _to_arrays(batch)
                        batch = {k: [] for k in keys}
                first_loop = False

        if split == "train":
            ds = self.train_dataset
        elif split == "eval":
            ds = self.eval_dataset
        else:
            raise ValueError(f'split must be "train" or "eval", got {split}')

        if self.streaming:
            return _dataloader_datasets_streaming(ds, split, epoch)
        # only the training split is shuffled; eval keeps dataset order
        input_rng = None
        if split == "train":
            self.rng_dataset, input_rng = jax.random.split(self.rng_dataset)
        return _dataloader_datasets_non_streaming(ds, input_rng)

    @property
    def length(self):
        """Return ``(len_train_dataset, len_eval_dataset)``; ``None`` when unknown.

        In streaming mode the lengths are unknown, so ``max_*_samples`` is used
        when set.
        """
        len_train_dataset, len_eval_dataset = None, None
        if self.streaming:
            # we don't know the length, let's just assume max_samples if defined
            if self.max_train_samples is not None:
                len_train_dataset = self.max_train_samples
            if self.max_eval_samples is not None:
                len_eval_dataset = self.max_eval_samples
        else:
            len_train_dataset = (
                len(self.train_dataset) if hasattr(self, "train_dataset") else None
            )
            len_eval_dataset = (
                len(self.eval_dataset) if hasattr(self, "eval_dataset") else None
            )
        return len_train_dataset, len_eval_dataset


def shift_tokens_right(input_ids: np.ndarray, decoder_start_token_id: int):
    """
    Shift input ids one token to the right.

    ``input_ids`` is indexed as 2-D ``(batch, seq)``. The first position of each
    row becomes ``decoder_start_token_id`` and the last token is dropped. The
    result keeps the dtype of ``input_ids``, so integer ids stay integers.
    """
    shifted_input_ids = np.zeros_like(input_ids)
    shifted_input_ids[:, 1:] = input_ids[:, :-1]
    shifted_input_ids[:, 0] = decoder_start_token_id
    return shifted_input_ids


def normalize_function(example, text_column, text_normalizer):
    """Apply ``text_normalizer`` to ``example[text_column]`` in place and return it."""
    example[text_column] = text_normalizer(example[text_column])
    return example


def preprocess_function(
    examples,
    tokenizer,
    text_column,
    encoding_column,
    max_length,
    decoder_start_token_id,
):
    """Tokenize a batch of texts and attach ``labels`` / ``decoder_input_ids``.

    ``labels`` is ``examples[encoding_column]`` as a numpy array, and
    ``decoder_input_ids`` is ``labels`` shifted right with
    ``decoder_start_token_id`` prepended.
    """
    inputs = examples[text_column]
    # Setting padding="max_length" as we need fixed length inputs for jitted functions
    model_inputs = tokenizer(
        inputs,
        max_length=max_length,
        padding="max_length",
        truncation=True,
        return_tensors="np",
    )

    # set up targets
    # Note: labels correspond to our target indices
    # decoder input ids are the same but shifted to the right with bos at the
    # beginning (and without last token)
    labels = examples[encoding_column]
    labels = np.asarray(labels)

    # We need the labels, in addition to the decoder_input_ids, for the compute_loss function
    model_inputs["labels"] = labels

    # In our case, this prepends the bos token and removes the last one
    decoder_input_ids = shift_tokens_right(labels, decoder_start_token_id)
    model_inputs["decoder_input_ids"] = decoder_input_ids

    return model_inputs