boris committed on
Commit
db7d521
·
unverified ·
2 Parent(s): e6c2573 0fe3e72

Merge pull request #111 from borisdayma/feat-data

Browse files
Files changed (2) hide show
  1. dalle_mini/data.py +259 -0
  2. dev/seq2seq/run_seq2seq_flax.py +35 -220
dalle_mini/data.py ADDED
@@ -0,0 +1,259 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from dataclasses import dataclass, field
2
+ from datasets import load_dataset, Dataset
3
+ from functools import partial
4
+ import numpy as np
5
+ import jax
6
+ import jax.numpy as jnp
7
+ from flax.training.common_utils import shard
8
+ from .text import TextNormalizer
9
+
10
+
11
@dataclass
class Dataset:
    """Load, preprocess and batch the train/eval splits of an encoded dataset.

    The splits are loaded in ``__post_init__`` through ``load_dataset``.
    ``preprocess`` must be called before ``dataloader``: it shuffles (streaming
    mode) or prepares the shuffling PRNG key (non-streaming mode), optionally
    normalizes the text column and tokenizes the examples.
    """

    dataset_repo_or_path: str
    train_file: str = None
    validation_file: str = None
    dataset_type: str = "dataset"
    streaming: bool = True
    use_auth_token: bool = False
    text_column: str = "caption"
    encoding_column: str = "encoding"
    max_source_length: int = 128
    max_train_samples: int = None
    max_eval_samples: int = None
    preprocessing_num_workers: int = None
    overwrite_cache: bool = False
    do_train: bool = False
    do_eval: bool = True
    seed_dataset: int = None
    # The fields below are only set when the matching split / mode is in use,
    # so they are excluded from the generated __repr__ and __eq__ (which would
    # otherwise raise AttributeError on a missing attribute).
    # The annotations are quoted because this class shadows the `Dataset` name
    # imported from the `datasets` library; they refer to the library type.
    train_dataset: "datasets.Dataset" = field(init=False, repr=False, compare=False)
    eval_dataset: "datasets.Dataset" = field(init=False, repr=False, compare=False)
    rng_dataset: jnp.ndarray = field(init=False, repr=False, compare=False)

    def __post_init__(self):
        """Load the requested splits and optionally truncate them.

        Raises:
            ValueError: if ``do_train``/``do_eval`` is set but the loaded
                dataset has no "train"/"validation" split.
        """
        # define data_files
        if self.train_file is not None or self.validation_file is not None:
            data_files = {
                "train": self.train_file,
                "validation": self.validation_file,
            }
        else:
            data_files = None

        # load dataset
        dataset = load_dataset(
            self.dataset_repo_or_path,
            data_files=data_files,
            streaming=self.streaming,
            use_auth_token=self.use_auth_token,
        )
        if self.do_train:
            if "train" not in dataset:
                raise ValueError("Training requires a training dataset")
            self.train_dataset = self._truncate(
                dataset["train"], self.max_train_samples
            )
        if self.do_eval:
            if "validation" not in dataset:
                raise ValueError("Evaluating requires a validation dataset")
            self.eval_dataset = self._truncate(
                dataset["validation"], self.max_eval_samples
            )

    def _truncate(self, ds, max_samples):
        """Return `ds` limited to its first `max_samples` items (no-op if None)."""
        if max_samples is None:
            return ds
        if self.streaming:
            return ds.take(max_samples)
        return ds.select(range(max_samples))

    def preprocess(self, tokenizer, decoder_start_token_id, normalize_text):
        """Shuffle / seed, optionally normalize, then tokenize every loaded split.

        Args:
            tokenizer: callable tokenizer forwarded to ``preprocess_function``.
            decoder_start_token_id: token id placed at the start of the
                decoder inputs.
            normalize_text: when truthy, run ``TextNormalizer`` on the text
                column before tokenizing.
        """
        if self.streaming:
            # we need to shuffle early in streaming mode
            if hasattr(self, "train_dataset"):
                self.train_dataset = self.train_dataset.shuffle(1000, self.seed_dataset)
        else:
            # prepare rng for later shuffling
            if self.seed_dataset is None:
                self.seed_dataset = np.random.get_state()[1][0]
            self.rng_dataset = jax.random.PRNGKey(self.seed_dataset)

        # normalize text
        if normalize_text:
            text_normalizer = TextNormalizer()
            partial_normalize_function = partial(
                normalize_function,
                text_column=self.text_column,
                text_normalizer=text_normalizer,
            )
            for ds in ["train_dataset", "eval_dataset"]:
                if hasattr(self, ds):
                    current = getattr(self, ds)
                    setattr(
                        self,
                        ds,
                        (
                            current.map(partial_normalize_function)
                            if self.streaming
                            else current.map(
                                partial_normalize_function,
                                num_proc=self.preprocessing_num_workers,
                                load_from_cache_file=not self.overwrite_cache,
                                desc="Normalizing datasets",
                            )
                        ),
                    )

        # preprocess
        partial_preprocess_function = partial(
            preprocess_function,
            tokenizer=tokenizer,
            text_column=self.text_column,
            encoding_column=self.encoding_column,
            max_source_length=self.max_source_length,
            decoder_start_token_id=decoder_start_token_id,
        )
        for ds in ["train_dataset", "eval_dataset"]:
            if hasattr(self, ds):
                current = getattr(self, ds)
                setattr(
                    self,
                    ds,
                    (
                        current.map(
                            partial_preprocess_function,
                            batched=True,
                        )
                        if self.streaming
                        else current.map(
                            partial_preprocess_function,
                            batched=True,
                            # `ds` is only the attribute *name*; the columns
                            # must be read from the dataset object itself.
                            remove_columns=current.column_names,
                            num_proc=self.preprocessing_num_workers,
                            load_from_cache_file=not self.overwrite_cache,
                            desc="Preprocessing datasets",
                        )
                    ),
                )

    def dataloader(self, split, batch_size, epoch=None):
        """Return a generator of sharded batches for `split`.

        Args:
            split: either "train" or "eval".
            batch_size: number of examples per yielded batch.
            epoch: epoch number forwarded to the streaming train dataset's
                ``set_epoch``; ignored when ``None`` or in non-streaming mode.

        Raises:
            ValueError: if `split` is neither "train" nor "eval".
        """

        def _dataloader_datasets_non_streaming(
            dataset: "datasets.Dataset",
            batch_size: int,
            rng: "jnp.ndarray" = None,
        ):
            """
            Returns batches of size `batch_size` from truncated `dataset`, sharded over all local devices.
            Examples are shuffled when `rng` is given, otherwise kept in order.
            """
            steps_per_epoch = len(dataset) // batch_size

            if rng is not None:
                batch_idx = jax.random.permutation(rng, len(dataset))
            else:
                batch_idx = jnp.arange(len(dataset))

            batch_idx = batch_idx[
                : steps_per_epoch * batch_size
            ]  # Skip incomplete batch.
            batch_idx = batch_idx.reshape((steps_per_epoch, batch_size))

            for idx in batch_idx:
                batch = dataset[idx]
                batch = {k: jnp.array(v) for k, v in batch.items()}
                batch = shard(batch)
                yield batch

        def _dataloader_datasets_streaming(dataset: "datasets.Dataset", batch_size: int):
            """Accumulate streamed items into batches; a trailing partial batch is dropped."""
            keys = ["input_ids", "attention_mask", "labels", "decoder_input_ids"]
            batch = {k: [] for k in keys}
            for item in dataset:
                # Only collect the model inputs: streamed items may still carry
                # other columns, which must not end up in the batch.
                for k in keys:
                    batch[k].append(item[k])
                if len(batch[keys[0]]) == batch_size:
                    batch = {k: jnp.array(v) for k, v in batch.items()}
                    batch = shard(batch)
                    yield batch
                    batch = {k: [] for k in keys}

        if split == "train":
            ds = self.train_dataset
        elif split == "eval":
            ds = self.eval_dataset
        else:
            raise ValueError(f'split must be "train" or "eval", got {split}')

        if self.streaming:
            # Callers may omit `epoch`; only forward a real value so the
            # dataset never receives `None` as its epoch.
            if split == "train" and epoch is not None:
                ds.set_epoch(epoch)
            return _dataloader_datasets_streaming(ds, batch_size)

        # Only the train split is shuffled; eval keeps the dataset order.
        input_rng = None
        if split == "train":
            self.rng_dataset, input_rng = jax.random.split(self.rng_dataset)
        return _dataloader_datasets_non_streaming(ds, batch_size, input_rng)

    @property
    def length(self):
        """Return ``(len_train_dataset, len_eval_dataset)``; unknown lengths are ``None``."""
        len_train_dataset, len_eval_dataset = None, None
        if self.streaming:
            # we don't know the length, let's just assume max_samples if defined
            if self.max_train_samples is not None:
                len_train_dataset = self.max_train_samples
            if self.max_eval_samples is not None:
                len_eval_dataset = self.max_eval_samples
        else:
            len_train_dataset = (
                len(self.train_dataset) if hasattr(self, "train_dataset") else None
            )
            len_eval_dataset = (
                len(self.eval_dataset) if hasattr(self, "eval_dataset") else None
            )
        return len_train_dataset, len_eval_dataset
211
+
212
+
213
def shift_tokens_right(input_ids: np.array, decoder_start_token_id: int):
    """
    Shift input ids one token to the right.

    The first column is set to `decoder_start_token_id` and the last token of
    each row is dropped. `input_ids` must be 2-D (rows of token ids); it is not
    modified. The result keeps the dtype of `input_ids` so that integer token
    ids stay integers.
    """
    input_ids = np.asarray(input_ids)
    # zeros_like preserves the dtype, unlike np.zeros(shape) which is float64
    shifted_input_ids = np.zeros_like(input_ids)
    shifted_input_ids[:, 1:] = input_ids[:, :-1]
    shifted_input_ids[:, 0] = decoder_start_token_id
    return shifted_input_ids
221
+
222
+
223
def normalize_function(example, text_column, text_normalizer):
    """Replace `example[text_column]` with its normalized form, in place.

    `text_normalizer` is called on the current value of the column; the same
    `example` mapping is returned after being updated.
    """
    cleaned_text = text_normalizer(example[text_column])
    example[text_column] = cleaned_text
    return example
226
+
227
+
228
def preprocess_function(
    examples,
    tokenizer,
    text_column,
    encoding_column,
    max_source_length,
    decoder_start_token_id,
):
    """Tokenize a batch of captions and attach the decoder targets.

    The text in `examples[text_column]` is tokenized; the values in
    `examples[encoding_column]` become the "labels", and a right-shifted copy
    of them (starting with `decoder_start_token_id`) becomes the
    "decoder_input_ids". Returns the tokenizer output with both entries added.
    """
    # Pad to a fixed "max_length": jitted functions need fixed-length inputs.
    features = tokenizer(
        examples[text_column],
        max_length=max_source_length,
        padding="max_length",
        truncation=True,
        return_tensors="np",
    )

    # The labels are the target indices; they are kept alongside the decoder
    # inputs because the compute_loss function needs them.
    target_ids = np.asarray(examples[encoding_column])
    features["labels"] = target_ids

    # Decoder inputs are the labels shifted right: bos first, last token dropped.
    features["decoder_input_ids"] = shift_tokens_right(
        target_ids, decoder_start_token_id
    )

    return features
dev/seq2seq/run_seq2seq_flax.py CHANGED
@@ -28,9 +28,9 @@ from typing import Callable, Optional
28
  import json
29
 
30
  import datasets
31
- import numpy as np
32
- from datasets import Dataset, load_dataset
33
  from tqdm import tqdm
 
34
 
35
  import jax
36
  import jax.numpy as jnp
@@ -40,7 +40,7 @@ from flax import jax_utils, traverse_util
40
  from flax.serialization import from_bytes, to_bytes
41
  from flax.jax_utils import unreplicate
42
  from flax.training import train_state
43
- from flax.training.common_utils import get_metrics, onehot, shard, shard_prng_key
44
  from transformers import (
45
  AutoTokenizer,
46
  HfArgumentParser,
@@ -49,7 +49,7 @@ from transformers.models.bart.modeling_flax_bart import BartConfig
49
 
50
  import wandb
51
 
52
- from dalle_mini.text import TextNormalizer
53
  from dalle_mini.model import CustomFlaxBartForConditionalGeneration
54
 
55
  logger = logging.getLogger(__name__)
@@ -120,18 +120,21 @@ class DataTrainingArguments:
120
  "help": "The name of the column in the datasets containing the image encodings."
121
  },
122
  )
123
- dataset_repo_or_path: Optional[str] = field(
124
  default=None,
125
  metadata={"help": "The dataset repository containing encoded files."},
126
  )
127
  train_file: Optional[str] = field(
128
- default=None, metadata={"help": "The input training data file (a text file)."}
 
129
  )
130
  validation_file: Optional[str] = field(
131
  default=None,
132
- metadata={
133
- "help": "An optional input evaluation data file to evaluate the perplexity on (a text file)."
134
- },
 
 
135
  )
136
  # data loading should not be a bottleneck so we use "streaming" mode by default
137
  streaming: bool = field(
@@ -177,6 +180,13 @@ class DataTrainingArguments:
177
  "help": "Overwrite the cached training and evaluation sets. Not used in streaming mode."
178
  },
179
  )
 
 
 
 
 
 
 
180
 
181
  def __post_init__(self):
182
  if self.dataset_repo_or_path is None:
@@ -277,13 +287,6 @@ class TrainingArguments:
277
  "help": "Random seed for the model that will be set at the beginning of training."
278
  },
279
  )
280
- # default seed of None ensures we don't repeat the same items if script was interrupted during an epoch
281
- seed_dataset: int = field(
282
- default=None,
283
- metadata={
284
- "help": "Random seed for the dataset that will be set at the beginning of training."
285
- },
286
- )
287
 
288
  push_to_hub: bool = field(
289
  default=False,
@@ -327,45 +330,6 @@ class TrainState(train_state.TrainState):
327
  )
328
 
329
 
330
- def data_loader(
331
- dataset: Dataset,
332
- batch_size: int,
333
- rng: jax.random.PRNGKey = None,
334
- ):
335
- """
336
- Returns batches of size `batch_size` from truncated `dataset`, sharded over all local devices.
337
- Shuffle batches if `shuffle` is `True`.
338
- """
339
- steps_per_epoch = len(dataset) // batch_size
340
-
341
- if rng is not None:
342
- batch_idx = jax.random.permutation(rng, len(dataset))
343
- else:
344
- batch_idx = jnp.arange(len(dataset))
345
-
346
- batch_idx = batch_idx[: steps_per_epoch * batch_size] # Skip incomplete batch.
347
- batch_idx = batch_idx.reshape((steps_per_epoch, batch_size))
348
-
349
- for idx in batch_idx:
350
- batch = dataset[idx]
351
- batch = {k: jnp.array(v) for k, v in batch.items()}
352
- batch = shard(batch)
353
- yield batch
354
-
355
-
356
- def data_loader_streaming(dataset: Dataset, batch_size: int):
357
- keys = ["input_ids", "attention_mask", "labels", "decoder_input_ids"]
358
- batch = {k: [] for k in keys}
359
- for item in dataset:
360
- for k, v in item.items():
361
- batch[k].append(v)
362
- if len(batch[keys[0]]) == batch_size:
363
- batch = {k: jnp.array(v) for k, v in batch.items()}
364
- batch = shard(batch)
365
- yield batch
366
- batch = {k: [] for k in keys}
367
-
368
-
369
  def create_learning_rate_fn(
370
  num_warmup_steps: int,
371
  learning_rate: float,
@@ -447,18 +411,10 @@ def main():
447
  logger.info(f"Training/evaluation parameters {training_args}")
448
 
449
  # Load dataset
450
- if data_args.train_file is not None or data_args.validation_file is not None:
451
- data_files = {
452
- "train": data_args.train_file,
453
- "validation": data_args.validation_file,
454
- }
455
- else:
456
- data_files = None
457
- dataset = load_dataset(
458
- data_args.dataset_repo_or_path,
459
- data_files=data_files,
460
- streaming=data_args.streaming,
461
- use_auth_token=data_args.use_auth_token,
462
  )
463
 
464
  # Set up wandb run
@@ -552,141 +508,17 @@ def main():
552
  use_fast=True,
553
  )
554
 
555
- print(f"TPUs: {jax.device_count()}")
556
  assert jax.device_count() == 8, "TPUs in use, please check running processes"
557
 
558
  # Preprocessing the datasets.
559
- # We need to tokenize inputs and targets.
560
-
561
- # Get the column names for input/target.
562
- text_column = data_args.text_column
563
- encoding_column = data_args.encoding_column
564
-
565
- def shift_tokens_right(input_ids: np.array, decoder_start_token_id: int):
566
- """
567
- Shift input ids one token to the right.
568
- """
569
- shifted_input_ids = np.zeros(input_ids.shape)
570
- shifted_input_ids[:, 1:] = input_ids[:, :-1]
571
- shifted_input_ids[:, 0] = decoder_start_token_id
572
- return shifted_input_ids
573
-
574
- text_normalizer = TextNormalizer() if model.config.normalize_text else None
575
-
576
- def normalize_text(example):
577
- example[text_column] = text_normalizer(example[text_column])
578
- return example
579
-
580
- def preprocess_function(examples):
581
- inputs = examples[text_column]
582
- # Setting padding="max_length" as we need fixed length inputs for jitted functions
583
- model_inputs = tokenizer(
584
- inputs,
585
- max_length=data_args.max_source_length,
586
- padding="max_length",
587
- truncation=True,
588
- return_tensors="np",
589
- )
590
-
591
- # set up targets
592
- # Note: labels correspond to our target indices
593
- # decoder input ids are the same but shifted to the right with bos at the beginning (and without last token)
594
- labels = examples[encoding_column]
595
- labels = np.asarray(labels)
596
-
597
- # We need the labels, in addition to the decoder_input_ids, for the compute_loss function
598
- model_inputs["labels"] = labels
599
-
600
- # In our case, this prepends the bos token and removes the last one
601
- decoder_input_ids = shift_tokens_right(
602
- labels, model.config.decoder_start_token_id
603
- )
604
- model_inputs["decoder_input_ids"] = decoder_input_ids
605
-
606
- return model_inputs
607
-
608
- if training_args.do_train:
609
- if "train" not in dataset:
610
- raise ValueError("--do_train requires a train dataset")
611
- train_dataset = dataset["train"]
612
- if data_args.max_train_samples is not None:
613
- train_dataset = (
614
- train_dataset.take(data_args.max_train_samples)
615
- if data_args.streaming
616
- else train_dataset.select(range(data_args.max_train_samples))
617
- )
618
- if data_args.streaming:
619
- train_dataset = train_dataset.shuffle(1000, training_args.seed_dataset)
620
- else:
621
- seed_dataset = (
622
- training_args.seed_dataset
623
- if training_args.seed_dataset is not None
624
- else np.random.get_state()[1][0]
625
- )
626
- rng_dataset = jax.random.PRNGKey(seed_dataset)
627
- if model.config.normalize_text:
628
- train_dataset = (
629
- train_dataset.map(normalize_text)
630
- if data_args.streaming
631
- else train_dataset.map(
632
- normalize_text,
633
- num_proc=data_args.preprocessing_num_workers,
634
- load_from_cache_file=not data_args.overwrite_cache,
635
- desc="Normalizing the validation dataset",
636
- )
637
- )
638
- train_dataset = (
639
- train_dataset.map(
640
- preprocess_function,
641
- batched=True,
642
- )
643
- if data_args.streaming
644
- else train_dataset.map(
645
- preprocess_function,
646
- batched=True,
647
- num_proc=data_args.preprocessing_num_workers,
648
- remove_columns=train_dataset.column_names,
649
- load_from_cache_file=not data_args.overwrite_cache,
650
- desc="Running tokenizer on validation dataset",
651
- )
652
- )
653
 
654
- if training_args.do_eval:
655
- if "validation" not in dataset:
656
- raise ValueError("--do_eval requires a validation dataset")
657
- eval_dataset = dataset["validation"]
658
- if data_args.max_eval_samples is not None:
659
- eval_dataset = (
660
- eval_dataset.take(data_args.max_train_samples)
661
- if data_args.streaming
662
- else eval_dataset.select(range(data_args.max_train_samples))
663
- )
664
- if model.config.normalize_text:
665
- eval_dataset = (
666
- eval_dataset.map(normalize_text)
667
- if data_args.streaming
668
- else eval_dataset.map(
669
- normalize_text,
670
- num_proc=data_args.preprocessing_num_workers,
671
- load_from_cache_file=not data_args.overwrite_cache,
672
- desc="Normalizing the validation dataset",
673
- )
674
- )
675
- eval_dataset = (
676
- eval_dataset.map(
677
- preprocess_function,
678
- batched=True,
679
- )
680
- if data_args.streaming
681
- else eval_dataset.map(
682
- preprocess_function,
683
- batched=True,
684
- num_proc=data_args.preprocessing_num_workers,
685
- remove_columns=eval_dataset.column_names,
686
- load_from_cache_file=not data_args.overwrite_cache,
687
- desc="Running tokenizer on validation dataset",
688
- )
689
- )
690
 
691
  # Initialize our training
692
  rng = jax.random.PRNGKey(training_args.seed_model)
@@ -699,16 +531,7 @@ def main():
699
  )
700
  batch_size_per_update = train_batch_size * training_args.gradient_accumulation_steps
701
  eval_batch_size = int(training_args.per_device_eval_batch_size) * jax.device_count()
702
- len_train_dataset, len_eval_dataset = None, None
703
- if data_args.streaming:
704
- # we don't know the length, let's just assume max_samples if defined
705
- if data_args.max_train_samples is not None:
706
- len_train_dataset = data_args.max_train_samples
707
- if data_args.max_eval_samples is not None:
708
- len_eval_dataset = data_args.max_eval_samples
709
- else:
710
- len_train_dataset = len(train_dataset)
711
- len_eval_dataset = len(eval_dataset)
712
  steps_per_epoch = (
713
  len_train_dataset // train_batch_size if len_train_dataset is not None else None
714
  )
@@ -854,8 +677,8 @@ def main():
854
  # add interesting config parameters
855
  wandb.config.update(
856
  {
857
- "len_train": len_train_dataset,
858
- "len_eval": len_eval_dataset,
859
  "batch_size_per_update": batch_size_per_update,
860
  }
861
  )
@@ -867,10 +690,7 @@ def main():
867
  # ======================== Evaluating ==============================
868
  eval_metrics = []
869
  if training_args.do_eval:
870
- if data_args.streaming:
871
- eval_loader = data_loader_streaming(eval_dataset, eval_batch_size)
872
- else:
873
- eval_loader = data_loader(eval_dataset, eval_batch_size)
874
  eval_steps = (
875
  len_eval_dataset // eval_batch_size
876
  if len_eval_dataset is not None
@@ -985,12 +805,7 @@ def main():
985
  wandb_log({"train/epoch": epoch}, step=unreplicate(state.step))
986
 
987
  # Generate an epoch by shuffling sampling indices from the train dataset
988
- if data_args.streaming:
989
- train_dataset.set_epoch(epoch) # shuffle dataset
990
- train_loader = data_loader_streaming(train_dataset, train_batch_size)
991
- else:
992
- rng_dataset, input_rng = jax.random.split(rng_dataset)
993
- train_loader = data_loader(train_dataset, train_batch_size, rng=input_rng)
994
  # train
995
  for batch in tqdm(
996
  train_loader,
 
28
  import json
29
 
30
  import datasets
31
+ from datasets import Dataset
 
32
  from tqdm import tqdm
33
+ from dataclasses import asdict
34
 
35
  import jax
36
  import jax.numpy as jnp
 
40
  from flax.serialization import from_bytes, to_bytes
41
  from flax.jax_utils import unreplicate
42
  from flax.training import train_state
43
+ from flax.training.common_utils import get_metrics, onehot, shard_prng_key
44
  from transformers import (
45
  AutoTokenizer,
46
  HfArgumentParser,
 
49
 
50
  import wandb
51
 
52
+ from dalle_mini.data import Dataset
53
  from dalle_mini.model import CustomFlaxBartForConditionalGeneration
54
 
55
  logger = logging.getLogger(__name__)
 
120
  "help": "The name of the column in the datasets containing the image encodings."
121
  },
122
  )
123
+ dataset_repo_or_path: str = field(
124
  default=None,
125
  metadata={"help": "The dataset repository containing encoded files."},
126
  )
127
  train_file: Optional[str] = field(
128
+ default=None,
129
+ metadata={"help": "The input training data file (glob acceptable)."},
130
  )
131
  validation_file: Optional[str] = field(
132
  default=None,
133
+ metadata={"help": "An optional input evaluation data file (glob acceptable)."},
134
+ )
135
+ dataset_type: str = field(
136
+ default="datasets",
137
+ metadata={"help": "Either 🤗 'dataset' (default) or 'webdataset'."},
138
  )
139
  # data loading should not be a bottleneck so we use "streaming" mode by default
140
  streaming: bool = field(
 
180
  "help": "Overwrite the cached training and evaluation sets. Not used in streaming mode."
181
  },
182
  )
183
+ # default seed of None ensures we don't repeat the same items if script was interrupted during an epoch
184
+ seed_dataset: int = field(
185
+ default=None,
186
+ metadata={
187
+ "help": "Random seed for the dataset that will be set at the beginning of training."
188
+ },
189
+ )
190
 
191
  def __post_init__(self):
192
  if self.dataset_repo_or_path is None:
 
287
  "help": "Random seed for the model that will be set at the beginning of training."
288
  },
289
  )
 
 
 
 
 
 
 
290
 
291
  push_to_hub: bool = field(
292
  default=False,
 
330
  )
331
 
332
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
333
  def create_learning_rate_fn(
334
  num_warmup_steps: int,
335
  learning_rate: float,
 
411
  logger.info(f"Training/evaluation parameters {training_args}")
412
 
413
  # Load dataset
414
+ dataset = Dataset(
415
+ **asdict(data_args),
416
+ do_train=training_args.do_train,
417
+ do_eval=training_args.do_eval,
 
 
 
 
 
 
 
 
418
  )
419
 
420
  # Set up wandb run
 
508
  use_fast=True,
509
  )
510
 
511
+ logger.info(f"TPUs: {jax.device_count()}")
512
  assert jax.device_count() == 8, "TPUs in use, please check running processes"
513
 
514
  # Preprocessing the datasets.
515
+ # We need to normalize and tokenize inputs and targets.
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
516
 
517
+ dataset.preprocess(
518
+ tokenizer=tokenizer,
519
+ decoder_start_token_id=model.config.decoder_start_token_id,
520
+ normalize_text=model.config.normalize_text,
521
+ )
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
522
 
523
  # Initialize our training
524
  rng = jax.random.PRNGKey(training_args.seed_model)
 
531
  )
532
  batch_size_per_update = train_batch_size * training_args.gradient_accumulation_steps
533
  eval_batch_size = int(training_args.per_device_eval_batch_size) * jax.device_count()
534
+ len_train_dataset, len_eval_dataset = dataset.length
 
 
 
 
 
 
 
 
 
535
  steps_per_epoch = (
536
  len_train_dataset // train_batch_size if len_train_dataset is not None else None
537
  )
 
677
  # add interesting config parameters
678
  wandb.config.update(
679
  {
680
+ "len_train_dataset": len_train_dataset,
681
+ "len_eval_dataset": len_eval_dataset,
682
  "batch_size_per_update": batch_size_per_update,
683
  }
684
  )
 
690
  # ======================== Evaluating ==============================
691
  eval_metrics = []
692
  if training_args.do_eval:
693
+ eval_loader = dataset.dataloader("eval", eval_batch_size)
 
 
 
694
  eval_steps = (
695
  len_eval_dataset // eval_batch_size
696
  if len_eval_dataset is not None
 
805
  wandb_log({"train/epoch": epoch}, step=unreplicate(state.step))
806
 
807
  # Generate an epoch by shuffling sampling indices from the train dataset
808
+ train_loader = dataset.dataloader("train", train_batch_size)
 
 
 
 
 
809
  # train
810
  for batch in tqdm(
811
  train_loader,