Merge pull request #111 from borisdayma/feat-data
Files changed:
- dalle_mini/data.py               +259  -0
- dev/seq2seq/run_seq2seq_flax.py   +35  -220
dalle_mini/data.py
ADDED
@@ -0,0 +1,259 @@
from dataclasses import dataclass, field
from datasets import load_dataset, Dataset
from functools import partial
import numpy as np
import jax
import jax.numpy as jnp
from flax.training.common_utils import shard
from .text import TextNormalizer


@dataclass
class Dataset:
    dataset_repo_or_path: str
    train_file: str = None
    validation_file: str = None
    dataset_type: str = "dataset"
    streaming: bool = True
    use_auth_token: bool = False
    text_column: str = "caption"
    encoding_column: str = "encoding"
    max_source_length: int = 128
    max_train_samples: int = None
    max_eval_samples: int = None
    preprocessing_num_workers: int = None
    overwrite_cache: bool = False
    do_train: bool = False
    do_eval: bool = True
    seed_dataset: int = None
    train_dataset: Dataset = field(init=False)
    eval_dataset: Dataset = field(init=False)
    rng_dataset: jnp.ndarray = field(init=False)

    def __post_init__(self):
        # define data_files
        if self.train_file is not None or self.validation_file is not None:
            data_files = {
                "train": self.train_file,
                "validation": self.validation_file,
            }
        else:
            data_files = None

        # load dataset
        dataset = load_dataset(
            self.dataset_repo_or_path,
            data_files=data_files,
            streaming=self.streaming,
            use_auth_token=self.use_auth_token,
        )
        if self.do_train:
            if "train" not in dataset:
                raise ValueError("Training requires a training dataset")
            self.train_dataset = dataset["train"]
            if self.max_train_samples is not None:
                self.train_dataset = (
                    self.train_dataset.take(self.max_train_samples)
                    if self.streaming
                    else self.train_dataset.select(range(self.max_train_samples))
                )
        if self.do_eval:
            if "validation" not in dataset:
                raise ValueError("Evaluating requires a validation dataset")
            self.eval_dataset = dataset["validation"]
            if self.max_eval_samples is not None:
                self.eval_dataset = (
                    self.eval_dataset.take(self.max_eval_samples)
                    if self.streaming
                    else self.eval_dataset.select(range(self.max_eval_samples))
                )

    def preprocess(self, tokenizer, decoder_start_token_id, normalize_text):
        if self.streaming:
            # we need to shuffle early in streaming mode
            if hasattr(self, "train_dataset"):
                self.train_dataset = self.train_dataset.shuffle(1000, self.seed_dataset)
        else:
            # prepare rng for later shuffling
            if self.seed_dataset is None:
                self.seed_dataset = np.random.get_state()[1][0]
            self.rng_dataset = jax.random.PRNGKey(self.seed_dataset)

        # normalize text
        if normalize_text:
            text_normalizer = TextNormalizer()
            partial_normalize_function = partial(
                normalize_function,
                text_column=self.text_column,
                text_normalizer=text_normalizer,
            )
            for ds in ["train_dataset", "eval_dataset"]:
                if hasattr(self, ds):
                    setattr(
                        self,
                        ds,
                        (
                            getattr(self, ds).map(partial_normalize_function)
                            if self.streaming
                            else getattr(self, ds).map(
                                partial_normalize_function,
                                num_proc=self.preprocessing_num_workers,
                                load_from_cache_file=not self.overwrite_cache,
                                desc="Normalizing datasets",
                            )
                        ),
                    )

        # preprocess
        partial_preprocess_function = partial(
            preprocess_function,
            tokenizer=tokenizer,
            text_column=self.text_column,
            encoding_column=self.encoding_column,
            max_source_length=self.max_source_length,
            decoder_start_token_id=decoder_start_token_id,
        )
        for ds in ["train_dataset", "eval_dataset"]:
            if hasattr(self, ds):
                setattr(
                    self,
                    ds,
                    (
                        getattr(self, ds).map(
                            partial_preprocess_function,
                            batched=True,
                        )
                        if self.streaming
                        else getattr(self, ds).map(
                            partial_preprocess_function,
                            batched=True,
                            remove_columns=getattr(ds, "column_names"),
                            num_proc=self.preprocessing_num_workers,
                            load_from_cache_file=not self.overwrite_cache,
                            desc="Preprocessing datasets",
                        )
                    ),
                )

    def dataloader(self, split, batch_size, epoch=None):
        def _dataloader_datasets_non_streaming(
            dataset: Dataset,
            batch_size: int,
            rng: jax.random.PRNGKey = None,
        ):
            """
            Returns batches of size `batch_size` from truncated `dataset`, sharded over all local devices.
            Shuffle batches if `shuffle` is `True`.
            """
            steps_per_epoch = len(dataset) // batch_size

            if rng is not None:
                batch_idx = jax.random.permutation(rng, len(dataset))
            else:
                batch_idx = jnp.arange(len(dataset))

            batch_idx = batch_idx[
                : steps_per_epoch * batch_size
            ]  # Skip incomplete batch.
            batch_idx = batch_idx.reshape((steps_per_epoch, batch_size))

            for idx in batch_idx:
                batch = dataset[idx]
                batch = {k: jnp.array(v) for k, v in batch.items()}
                batch = shard(batch)
                yield batch

        def _dataloader_datasets_streaming(dataset: Dataset, batch_size: int):
            keys = ["input_ids", "attention_mask", "labels", "decoder_input_ids"]
            batch = {k: [] for k in keys}
            for item in dataset:
                for k, v in item.items():
                    batch[k].append(v)
                if len(batch[keys[0]]) == batch_size:
                    batch = {k: jnp.array(v) for k, v in batch.items()}
                    batch = shard(batch)
                    yield batch
                    batch = {k: [] for k in keys}

        if split == "train":
            ds = self.train_dataset
        elif split == "eval":
            ds = self.eval_dataset
        else:
            raise ValueError(f'split must be "train" or "eval", got {split}')

        if self.streaming:
            if split == "train":
                ds.set_epoch(epoch)
            return _dataloader_datasets_streaming(ds, batch_size)
        else:
            if split == "train":
                self.rng_dataset, input_rng = jax.random.split(self.rng_dataset)
            return _dataloader_datasets_non_streaming(ds, batch_size, input_rng)

    @property
    def length(self):
        len_train_dataset, len_eval_dataset = None, None
        if self.streaming:
            # we don't know the length, let's just assume max_samples if defined
            if self.max_train_samples is not None:
                len_train_dataset = self.max_train_samples
            if self.max_eval_samples is not None:
                len_eval_dataset = self.max_eval_samples
        else:
            len_train_dataset = (
                len(self.train_dataset) if hasattr(self, "train_dataset") else None
            )
            len_eval_dataset = (
                len(self.eval_dataset) if hasattr(self, "eval_dataset") else None
            )
        return len_train_dataset, len_eval_dataset


def shift_tokens_right(input_ids: np.array, decoder_start_token_id: int):
    """
    Shift input ids one token to the right.
    """
    shifted_input_ids = np.zeros(input_ids.shape)
    shifted_input_ids[:, 1:] = input_ids[:, :-1]
    shifted_input_ids[:, 0] = decoder_start_token_id
    return shifted_input_ids


def normalize_function(example, text_column, text_normalizer):
    example[text_column] = text_normalizer(example[text_column])
    return example


def preprocess_function(
    examples,
    tokenizer,
    text_column,
    encoding_column,
    max_source_length,
    decoder_start_token_id,
):
    inputs = examples[text_column]
    # Setting padding="max_length" as we need fixed length inputs for jitted functions
    model_inputs = tokenizer(
        inputs,
        max_length=max_source_length,
        padding="max_length",
        truncation=True,
        return_tensors="np",
    )

    # set up targets
    # Note: labels correspond to our target indices
    # decoder input ids are the same but shifted to the right with bos at the beginning (and without last token)
    labels = examples[encoding_column]
    labels = np.asarray(labels)

    # We need the labels, in addition to the decoder_input_ids, for the compute_loss function
    model_inputs["labels"] = labels

    # In our case, this prepends the bos token and removes the last one
    decoder_input_ids = shift_tokens_right(labels, decoder_start_token_id)
    model_inputs["decoder_input_ids"] = decoder_input_ids

    return model_inputs
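
For reference, the effect of shift_tokens_right on a batch of image-token encodings is easiest to see on a tiny example. The snippet below is illustrative only and not part of the commit: the token ids and decoder_start_token_id are made up, and note that the helper returns a float array because it starts from np.zeros.

import numpy as np
from dalle_mini.data import shift_tokens_right

encodings = np.array([[11, 22, 33, 44]])  # one sequence of 4 image-token ids (made up)
decoder_input_ids = shift_tokens_right(encodings, decoder_start_token_id=0)  # placeholder BOS id
print(decoder_input_ids)  # [[ 0. 11. 22. 33.]] -> BOS prepended, last token dropped
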
dev/seq2seq/run_seq2seq_flax.py
CHANGED
@@ -28,9 +28,9 @@ from typing import Callable, Optional
 import json
 
 import datasets
-
-from datasets import Dataset, load_dataset
+from datasets import Dataset
 from tqdm import tqdm
+from dataclasses import asdict
 
 import jax
 import jax.numpy as jnp
@@ -40,7 +40,7 @@ from flax import jax_utils, traverse_util
 from flax.serialization import from_bytes, to_bytes
 from flax.jax_utils import unreplicate
 from flax.training import train_state
-from flax.training.common_utils import get_metrics, onehot, shard, shard_prng_key
+from flax.training.common_utils import get_metrics, onehot, shard_prng_key
 from transformers import (
     AutoTokenizer,
     HfArgumentParser,
@@ -49,7 +49,7 @@ from transformers.models.bart.modeling_flax_bart import BartConfig
 
 import wandb
 
-from dalle_mini.text import TextNormalizer
+from dalle_mini.data import Dataset
 from dalle_mini.model import CustomFlaxBartForConditionalGeneration
 
 logger = logging.getLogger(__name__)
@@ -120,18 +120,21 @@ class DataTrainingArguments:
             "help": "The name of the column in the datasets containing the image encodings."
         },
     )
-    dataset_repo_or_path:
+    dataset_repo_or_path: str = field(
         default=None,
         metadata={"help": "The dataset repository containing encoded files."},
     )
     train_file: Optional[str] = field(
-        default=None,
+        default=None,
+        metadata={"help": "The input training data file (glob acceptable)."},
     )
     validation_file: Optional[str] = field(
         default=None,
-        metadata={
-
-
+        metadata={"help": "An optional input evaluation data file (glob acceptable)."},
+    )
+    dataset_type: str = field(
+        default="datasets",
+        metadata={"help": "Either 🤗 'dataset' (default) or 'webdataset'."},
     )
     # data loading should not be a bottleneck so we use "streaming" mode by default
     streaming: bool = field(
@@ -177,6 +180,13 @@ class DataTrainingArguments:
             "help": "Overwrite the cached training and evaluation sets. Not used in streaming mode."
         },
     )
+    # default seed of None ensures we don't repeat the same items if script was interrupted during an epoch
+    seed_dataset: int = field(
+        default=None,
+        metadata={
+            "help": "Random seed for the dataset that will be set at the beginning of training."
+        },
+    )
 
     def __post_init__(self):
         if self.dataset_repo_or_path is None:
@@ -277,13 +287,6 @@ class TrainingArguments:
             "help": "Random seed for the model that will be set at the beginning of training."
         },
     )
-    # default seed of None ensures we don't repeat the same items if script was interrupted during an epoch
-    seed_dataset: int = field(
-        default=None,
-        metadata={
-            "help": "Random seed for the dataset that will be set at the beginning of training."
-        },
-    )
 
     push_to_hub: bool = field(
         default=False,
@@ -327,45 +330,6 @@ class TrainState(train_state.TrainState):
     )
 
 
-def data_loader(
-    dataset: Dataset,
-    batch_size: int,
-    rng: jax.random.PRNGKey = None,
-):
-    """
-    Returns batches of size `batch_size` from truncated `dataset`, sharded over all local devices.
-    Shuffle batches if `shuffle` is `True`.
-    """
-    steps_per_epoch = len(dataset) // batch_size
-
-    if rng is not None:
-        batch_idx = jax.random.permutation(rng, len(dataset))
-    else:
-        batch_idx = jnp.arange(len(dataset))
-
-    batch_idx = batch_idx[: steps_per_epoch * batch_size]  # Skip incomplete batch.
-    batch_idx = batch_idx.reshape((steps_per_epoch, batch_size))
-
-    for idx in batch_idx:
-        batch = dataset[idx]
-        batch = {k: jnp.array(v) for k, v in batch.items()}
-        batch = shard(batch)
-        yield batch
-
-
-def data_loader_streaming(dataset: Dataset, batch_size: int):
-    keys = ["input_ids", "attention_mask", "labels", "decoder_input_ids"]
-    batch = {k: [] for k in keys}
-    for item in dataset:
-        for k, v in item.items():
-            batch[k].append(v)
-        if len(batch[keys[0]]) == batch_size:
-            batch = {k: jnp.array(v) for k, v in batch.items()}
-            batch = shard(batch)
-            yield batch
-            batch = {k: [] for k in keys}
-
-
 def create_learning_rate_fn(
     num_warmup_steps: int,
     learning_rate: float,
@@ -447,18 +411,10 @@ def main():
     logger.info(f"Training/evaluation parameters {training_args}")
 
     # Load dataset
-
-
-
-
-        }
-    else:
-        data_files = None
-    dataset = load_dataset(
-        data_args.dataset_repo_or_path,
-        data_files=data_files,
-        streaming=data_args.streaming,
-        use_auth_token=data_args.use_auth_token,
+    dataset = Dataset(
+        **asdict(data_args),
+        do_train=training_args.do_train,
+        do_eval=training_args.do_eval,
     )
 
     # Set up wandb run
@@ -552,141 +508,17 @@ def main():
         use_fast=True,
     )
 
-
+    logger.info(f"TPUs: {jax.device_count()}")
     assert jax.device_count() == 8, "TPUs in use, please check running processes"
 
     # Preprocessing the datasets.
-    # We need to tokenize inputs and targets.
-
-    # Get the column names for input/target.
-    text_column = data_args.text_column
-    encoding_column = data_args.encoding_column
-
-    def shift_tokens_right(input_ids: np.array, decoder_start_token_id: int):
-        """
-        Shift input ids one token to the right.
-        """
-        shifted_input_ids = np.zeros(input_ids.shape)
-        shifted_input_ids[:, 1:] = input_ids[:, :-1]
-        shifted_input_ids[:, 0] = decoder_start_token_id
-        return shifted_input_ids
-
-    text_normalizer = TextNormalizer() if model.config.normalize_text else None
-
-    def normalize_text(example):
-        example[text_column] = text_normalizer(example[text_column])
-        return example
-
-    def preprocess_function(examples):
-        inputs = examples[text_column]
-        # Setting padding="max_length" as we need fixed length inputs for jitted functions
-        model_inputs = tokenizer(
-            inputs,
-            max_length=data_args.max_source_length,
-            padding="max_length",
-            truncation=True,
-            return_tensors="np",
-        )
-
-        # set up targets
-        # Note: labels correspond to our target indices
-        # decoder input ids are the same but shifted to the right with bos at the beginning (and without last token)
-        labels = examples[encoding_column]
-        labels = np.asarray(labels)
-
-        # We need the labels, in addition to the decoder_input_ids, for the compute_loss function
-        model_inputs["labels"] = labels
-
-        # In our case, this prepends the bos token and removes the last one
-        decoder_input_ids = shift_tokens_right(
-            labels, model.config.decoder_start_token_id
-        )
-        model_inputs["decoder_input_ids"] = decoder_input_ids
-
-        return model_inputs
-
-    if training_args.do_train:
-        if "train" not in dataset:
-            raise ValueError("--do_train requires a train dataset")
-        train_dataset = dataset["train"]
-        if data_args.max_train_samples is not None:
-            train_dataset = (
-                train_dataset.take(data_args.max_train_samples)
-                if data_args.streaming
-                else train_dataset.select(range(data_args.max_train_samples))
-            )
-        if data_args.streaming:
-            train_dataset = train_dataset.shuffle(1000, training_args.seed_dataset)
-        else:
-            seed_dataset = (
-                training_args.seed_dataset
-                if training_args.seed_dataset is not None
-                else np.random.get_state()[1][0]
-            )
-            rng_dataset = jax.random.PRNGKey(seed_dataset)
-        if model.config.normalize_text:
-            train_dataset = (
-                train_dataset.map(normalize_text)
-                if data_args.streaming
-                else train_dataset.map(
-                    normalize_text,
-                    num_proc=data_args.preprocessing_num_workers,
-                    load_from_cache_file=not data_args.overwrite_cache,
-                    desc="Normalizing the validation dataset",
-                )
-            )
-        train_dataset = (
-            train_dataset.map(
-                preprocess_function,
-                batched=True,
-            )
-            if data_args.streaming
-            else train_dataset.map(
-                preprocess_function,
-                batched=True,
-                num_proc=data_args.preprocessing_num_workers,
-                remove_columns=train_dataset.column_names,
-                load_from_cache_file=not data_args.overwrite_cache,
-                desc="Running tokenizer on validation dataset",
-            )
-        )
+    # We need to normalize and tokenize inputs and targets.
 
-
-
-
-
-
-        eval_dataset = (
-            eval_dataset.take(data_args.max_train_samples)
-            if data_args.streaming
-            else eval_dataset.select(range(data_args.max_train_samples))
-        )
-        if model.config.normalize_text:
-            eval_dataset = (
-                eval_dataset.map(normalize_text)
-                if data_args.streaming
-                else eval_dataset.map(
-                    normalize_text,
-                    num_proc=data_args.preprocessing_num_workers,
-                    load_from_cache_file=not data_args.overwrite_cache,
-                    desc="Normalizing the validation dataset",
-                )
-            )
-        eval_dataset = (
-            eval_dataset.map(
-                preprocess_function,
-                batched=True,
-            )
-            if data_args.streaming
-            else eval_dataset.map(
-                preprocess_function,
-                batched=True,
-                num_proc=data_args.preprocessing_num_workers,
-                remove_columns=eval_dataset.column_names,
-                load_from_cache_file=not data_args.overwrite_cache,
-                desc="Running tokenizer on validation dataset",
-            )
-        )
+    dataset.preprocess(
+        tokenizer=tokenizer,
+        decoder_start_token_id=model.config.decoder_start_token_id,
+        normalize_text=model.config.normalize_text,
+    )
 
     # Initialize our training
     rng = jax.random.PRNGKey(training_args.seed_model)
@@ -699,16 +531,7 @@ def main():
     )
     batch_size_per_update = train_batch_size * training_args.gradient_accumulation_steps
    eval_batch_size = int(training_args.per_device_eval_batch_size) * jax.device_count()
-    len_train_dataset, len_eval_dataset = None, None
-    if data_args.streaming:
-        # we don't know the length, let's just assume max_samples if defined
-        if data_args.max_train_samples is not None:
-            len_train_dataset = data_args.max_train_samples
-        if data_args.max_eval_samples is not None:
-            len_eval_dataset = data_args.max_eval_samples
-    else:
-        len_train_dataset = len(train_dataset)
-        len_eval_dataset = len(eval_dataset)
+    len_train_dataset, len_eval_dataset = dataset.length
     steps_per_epoch = (
         len_train_dataset // train_batch_size if len_train_dataset is not None else None
    )
@@ -854,8 +677,8 @@ def main():
     # add interesting config parameters
     wandb.config.update(
         {
-            "
-            "
+            "len_train_dataset": len_train_dataset,
+            "len_eval_dataset": len_eval_dataset,
             "batch_size_per_update": batch_size_per_update,
         }
     )
@@ -867,10 +690,7 @@ def main():
     # ======================== Evaluating ==============================
     eval_metrics = []
     if training_args.do_eval:
-        if data_args.streaming:
-            eval_loader = data_loader_streaming(eval_dataset, eval_batch_size)
-        else:
-            eval_loader = data_loader(eval_dataset, eval_batch_size)
+        eval_loader = dataset.dataloader("eval", eval_batch_size)
         eval_steps = (
             len_eval_dataset // eval_batch_size
             if len_eval_dataset is not None
@@ -985,12 +805,7 @@ def main():
         wandb_log({"train/epoch": epoch}, step=unreplicate(state.step))
 
         # Generate an epoch by shuffling sampling indices from the train dataset
-        if data_args.streaming:
-            train_dataset.set_epoch(epoch)  # shuffle dataset
-            train_loader = data_loader_streaming(train_dataset, train_batch_size)
-        else:
-            rng_dataset, input_rng = jax.random.split(rng_dataset)
-            train_loader = data_loader(train_dataset, train_batch_size, rng=input_rng)
+        train_loader = dataset.dataloader("train", train_batch_size)
         # train
         for batch in tqdm(
             train_loader,