Merge pull request #98 from borisdayma/feat-seq2seq
dev/seq2seq/run_seq2seq_flax.py  CHANGED  (+27 -64)
@@ -23,21 +23,19 @@ import os
 import logging as pylogging  # To avoid collision with transformers.utils.logging
 import sys
 from dataclasses import dataclass, field
-from functools import partial
 from pathlib import Path
 from typing import Callable, Optional
 import json
 
 import datasets
 import numpy as np
-from datasets import Dataset, load_dataset
+from datasets import Dataset, load_dataset
 from tqdm import tqdm
 
 import jax
 import jax.numpy as jnp
 import optax
 import transformers
-from filelock import FileLock
 from flax import jax_utils, traverse_util
 from flax.serialization import from_bytes, to_bytes
 import flax.linen as nn
@@ -45,15 +43,12 @@ from flax.jax_utils import unreplicate
 from flax.training import train_state
 from flax.training.common_utils import get_metrics, onehot, shard, shard_prng_key
 from transformers import (
-    FLAX_MODEL_FOR_SEQ_TO_SEQ_CAUSAL_LM_MAPPING,
     AutoTokenizer,
-    FlaxAutoModelForSeq2SeqLM,
     FlaxBartForConditionalGeneration,
     HfArgumentParser,
     TrainingArguments,
 )
 from transformers.models.bart.modeling_flax_bart import *
-from transformers.file_utils import is_offline_mode
 
 import wandb
 
@@ -62,10 +57,6 @@ from dalle_mini.text import TextNormalizer
 logger = pylogging.getLogger(__name__)
 
 
-MODEL_CONFIG_CLASSES = list(FLAX_MODEL_FOR_SEQ_TO_SEQ_CAUSAL_LM_MAPPING.keys())
-MODEL_TYPES = tuple(conf.model_type for conf in MODEL_CONFIG_CLASSES)
-
-
 # Model hyperparameters, for convenience
 # TODO: the model has now it's own definition file and should be imported
 OUTPUT_VOCAB_SIZE = 16384 + 1  # encoded image token space + 1 for bos
@@ -87,25 +78,12 @@ class ModelArguments:
             "Don't set if you want to train a model from scratch."
         },
     )
-    model_type: Optional[str] = field(
-        default=None,
-        metadata={
-            "help": "If training from scratch, pass a model type from the list: "
-            + ", ".join(MODEL_TYPES)
-        },
-    )
     config_name: Optional[str] = field(
         default=None,
         metadata={
            "help": "Pretrained config name or path if not the same as model_name"
        },
    )
-    cache_dir: Optional[str] = field(
-        default=None,
-        metadata={
-            "help": "Where do you want to store the pretrained models downloaded from s3"
-        },
-    )
     use_fast_tokenizer: bool = field(
         default=True,
         metadata={
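With model_type and cache_dir removed, the remaining ModelArguments fields are still consumed through HfArgumentParser as before. A minimal sketch of that usage, assuming the script's usual argument classes (DataTrainingArguments is assumed here, it is not part of this diff):

# illustrative only; the script is assumed to do something equivalent in main()
parser = HfArgumentParser((ModelArguments, DataTrainingArguments, TrainingArguments))
model_args, data_args, training_args = parser.parse_args_into_dataclasses()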
@@ -281,6 +259,19 @@ class TrainState(train_state.TrainState):
             dropout_rng=shard_prng_key(self.dropout_rng)
         )
 
+    def restore_state(self, artifact_dir):
+        # restore optimizer state
+        with (Path(artifact_dir) / "opt_state.msgpack").open("rb") as f:
+            new_opt_state = from_bytes(self.opt_state, f.read())
+
+        # restore steps
+        with (Path(artifact_dir) / "training_state.json").open("r") as f:
+            training_state = json.load(f)
+            new_step = training_state["step"]
+
+        # replace state
+        return self.replace(step=new_step, opt_state=new_opt_state)
+
 
 class CustomFlaxBartModule(FlaxBartModule):
     def setup(self):
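The new TrainState.restore_state method expects the checkpoint directory to contain opt_state.msgpack and training_state.json. For reference, a minimal sketch of the matching save side, assuming a state object with step and opt_state attributes (illustrative names, not necessarily the PR's exact checkpoint code):

import json
from pathlib import Path
from flax.serialization import to_bytes

def save_state_sketch(state, save_dir):
    save_dir = Path(save_dir)
    # serialize the optax optimizer state with flax's msgpack helpers
    with (save_dir / "opt_state.msgpack").open("wb") as f:
        f.write(to_bytes(state.opt_state))
    # store the global step so training can resume where it left off
    with (save_dir / "training_state.json").open("w") as f:
        json.dump({"step": int(state.step)}, f)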
@@ -480,22 +471,6 @@ def main():
         streaming=data_args.streaming,
     )
 
-    # Set up items to load or create
-    tokenizer = None
-    artifact_dir = None
-
-    def restore_state(state, artifact_dir):
-        # restore optimizer state
-        with (Path(artifact_dir) / "opt_state.msgpack").open("rb") as f:
-            opt_state = from_bytes(state.opt_state, f.read())
-
-        # restore steps
-        with (Path(artifact_dir) / "training_state.json").open("r") as f:
-            training_state = json.load(f)
-            step = training_state["step"]
-
-        return step, opt_state
-
     # Set up wandb run
     wandb.init(
         entity="dalle-mini",
@@ -510,22 +485,11 @@ def main():
         artifact_dir = artifact.download()
         model = CustomFlaxBartForConditionalGeneration.from_pretrained(artifact_dir)
 
-        #
-
-
-
-
-
-        # used in the preprocessing function
-        config = model.config
-
-        # load tokenizer if present
-        if (Path(artifact_dir) / "tokenizer_config.json").exists():
-            tokenizer = AutoTokenizer.from_pretrained(
-                model_args.model_name_or_path,
-                cache_dir=model_args.cache_dir,
-                use_fast=model_args.use_fast_tokenizer,
-            )
+        # load tokenizer
+        tokenizer = AutoTokenizer.from_pretrained(
+            artifact_dir,
+            use_fast=model_args.use_fast_tokenizer,
+        )
 
     else:
         # Set up our new model config
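Loading the tokenizer from artifact_dir instead of model_args.model_name_or_path only works if the tokenizer files were saved next to the model weights when the artifact was logged. A hedged sketch of that save/log side, with illustrative artifact names:

import wandb

def log_checkpoint_sketch(model, tokenizer, save_dir, run):
    # write model weights/config and tokenizer files into one directory ...
    model.save_pretrained(save_dir)       # config.json + flax_model.msgpack
    tokenizer.save_pretrained(save_dir)   # tokenizer_config.json, vocab files, ...
    # ... and log that directory as a wandb artifact; artifact.download() later
    # returns a folder that from_pretrained() can read directly
    artifact = wandb.Artifact("bart-checkpoint", type="bart_model")  # hypothetical names
    artifact.add_dir(str(save_dir))
    run.log_artifact(artifact)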
@@ -552,11 +516,9 @@ def main():
             config, seed=training_args.seed, dtype=getattr(jnp, model_args.dtype)
         )
 
-
-    if tokenizer is None:
+        # Load tokenizer
         tokenizer = AutoTokenizer.from_pretrained(
             model_args.model_name_or_path,
-            cache_dir=model_args.cache_dir,
             use_fast=model_args.use_fast_tokenizer,
         )
 
@@ -609,7 +571,9 @@ def main():
         model_inputs["labels"] = labels
 
         # In our case, this prepends the bos token and removes the last one
-        decoder_input_ids = shift_tokens_right(
+        decoder_input_ids = shift_tokens_right(
+            labels, model.config.decoder_start_token_id
+        )
         model_inputs["decoder_input_ids"] = decoder_input_ids
 
         return model_inputs
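As the comment says, the decoder inputs are the labels shifted right: the decoder start token is prepended and the last position is dropped. A standalone sketch of that operation (not the transformers implementation, which may also handle padding):

import numpy as np

def shift_right_sketch(labels: np.ndarray, decoder_start_token_id: int) -> np.ndarray:
    # [a, b, c] -> [<start>, a, b] for every row of the batch
    shifted = np.zeros_like(labels)
    shifted[:, 0] = decoder_start_token_id
    shifted[:, 1:] = labels[:, :-1]
    return shifted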
@@ -787,9 +751,9 @@ def main():
     )
     if model_args.from_checkpoint is not None:
         # restore optimizer state and step
-
-        state = state.replace(step=step, opt_state=opt_state)
+        state = state.restore_state(artifact_dir)
         # TODO: number of remaining training epochs/steps and dataloader state need to be adjusted
+        # TODO: optimizer may use a different step for learning rate, we should serialize/restore entire state
 
     # label smoothed cross entropy
     def loss_fn(logits, labels):
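One reading of the second TODO: with optax, learning-rate schedules keep their own counters inside the optimizer state, separate from TrainState.step. A small illustration (assumed hyperparameter values) of why restoring opt_state with from_bytes, as restore_state does, also carries the schedule position:

import jax.numpy as jnp
import optax

schedule = optax.linear_schedule(init_value=1e-4, end_value=0.0, transition_steps=10_000)
tx = optax.adamw(learning_rate=schedule)
opt_state = tx.init({"w": jnp.zeros(3)})  # toy params just to build a state
# opt_state is a tuple of per-transform states; the Adam state and the schedule
# state each carry their own count, so they need to be serialized along with `step`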
@@ -837,6 +801,7 @@ def main():
     p_eval_step = jax.pmap(eval_step, "batch")
 
     # Replicate the train state on each device
+    del model._params
     state = state.replicate()
 
     logger.info("***** Running training *****")
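The new del model._params removes the model object's reference to the parameters; the TrainState still holds them, and once state is rebound to its replicated version nothing keeps the original host copy alive, which presumably is the point of the change: less host memory after replication. A sketch of the pattern, reusing the script's names (optimizer and dropout_rng are assumed here):

state = TrainState.create(
    apply_fn=model.__call__, params=model.params, tx=optimizer, dropout_rng=dropout_rng
)
del model._params          # drop the model's reference so the host copy can be freed
state = state.replicate()  # custom TrainState helper shown earlier in this file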
@@ -976,14 +941,12 @@ def main():
         step = unreplicate(state.step)
         wandb_log({"train/epoch": epoch}, step=step)
 
-        # Create sampling rng
-        rng, input_rng = jax.random.split(rng)
-
         # Generate an epoch by shuffling sampling indices from the train dataset
         if data_args.streaming:
             train_dataset.set_epoch(epoch)
             train_loader = data_loader_streaming(train_dataset, train_batch_size)
         else:
+            rng, input_rng = jax.random.split(rng)
             train_loader = data_loader(
                 input_rng, train_dataset, train_batch_size, shuffle=True
             )
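With this change the PRNG key is split only where it is actually used, namely for the shuffling permutation in the non-streaming data_loader. A minimal sketch of such a permutation-based loader, with illustrative names (not the script's exact implementation):

import numpy as np
import jax

def data_loader_sketch(rng, dataset, batch_size, shuffle=False):
    # one pass over a map-style dataset, yielding dict-of-arrays batches
    n = len(dataset)
    order = np.asarray(jax.random.permutation(rng, n)) if shuffle else np.arange(n)
    for i in range(n // batch_size):
        idx = order[i * batch_size : (i + 1) * batch_size]
        batch = dataset[idx.tolist()]              # HF Dataset -> dict of lists
        yield {k: np.asarray(v) for k, v in batch.items()}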