Try avoid hf hub git rate limits
Browse files- config.gin +2 -3
- config.json +1 -1
- small_nl24_pretrain.gin +2 -3
- start_train.sh +2 -1
- tasks.py +70 -37
- train.py +0 -689
- train/events.out.tfevents.1649073594.t1v-n-304587cf-w-0.1316481.0.v2 +0 -3
- train/events.out.tfevents.1649092520.t1v-n-304587cf-w-0.1399566.0.v2 +0 -3
- train/{events.out.tfevents.1649056216.t1v-n-304587cf-w-0.1239745.0.v2 → events.out.tfevents.1649705066.t1v-n-304587cf-w-0.2549834.0.v2} +2 -2
- training_eval/pretrain_finnish/events.out.tfevents.1649073594.t1v-n-304587cf-w-0.1316481.1.v2 +0 -3
- training_eval/pretrain_finnish/events.out.tfevents.1649092520.t1v-n-304587cf-w-0.1399566.1.v2 +0 -3
- training_eval/pretrain_finnish/{events.out.tfevents.1649056216.t1v-n-304587cf-w-0.1239745.1.v2 → events.out.tfevents.1649705066.t1v-n-304587cf-w-0.2549834.1.v2} +2 -2
config.gin
CHANGED
@@ -12,7 +12,7 @@ import tasks
|
|
12 |
|
13 |
# Macros:
|
14 |
# ==============================================================================
|
15 |
-
BATCH_SIZE =
|
16 |
DROPOUT_RATE = 0.0
|
17 |
LABEL_SMOOTHING = 0.0
|
18 |
LOSS_NORMALIZING_FACTOR = None
|
@@ -23,7 +23,7 @@ MODEL_DIR = '/researchdisk/t5x-small-nl24-finnish'
|
|
23 |
OPTIMIZER = @adafactor.Adafactor()
|
24 |
RANDOM_SEED = None
|
25 |
SHUFFLE_TRAIN_EXAMPLES = True
|
26 |
-
TASK_FEATURE_LENGTHS = {'inputs': 512, 'targets':
|
27 |
TRAIN_STEPS = 500000
|
28 |
USE_CACHED_TASKS = False
|
29 |
USE_HARDWARE_RNG = False
|
@@ -123,7 +123,6 @@ network.T5Config.vocab_size = 32128
|
|
123 |
train_script.train.checkpoint_cfg = @utils.CheckpointConfig()
|
124 |
train_script.train.eval_period = 10000
|
125 |
train_script.train.eval_steps = 20
|
126 |
-
train_script.train.hub_model_id = 'Finnish-NLP/t5x-small-nl24-finnish'
|
127 |
train_script.train.infer_eval_dataset_cfg = None
|
128 |
train_script.train.model = %MODEL
|
129 |
train_script.train.model_dir = %MODEL_DIR
|
|
|
12 |
|
13 |
# Macros:
|
14 |
# ==============================================================================
|
15 |
+
BATCH_SIZE = 256
|
16 |
DROPOUT_RATE = 0.0
|
17 |
LABEL_SMOOTHING = 0.0
|
18 |
LOSS_NORMALIZING_FACTOR = None
|
|
|
23 |
OPTIMIZER = @adafactor.Adafactor()
|
24 |
RANDOM_SEED = None
|
25 |
SHUFFLE_TRAIN_EXAMPLES = True
|
26 |
+
TASK_FEATURE_LENGTHS = {'inputs': 512, 'targets': 512}
|
27 |
TRAIN_STEPS = 500000
|
28 |
USE_CACHED_TASKS = False
|
29 |
USE_HARDWARE_RNG = False
|
|
|
123 |
train_script.train.checkpoint_cfg = @utils.CheckpointConfig()
|
124 |
train_script.train.eval_period = 10000
|
125 |
train_script.train.eval_steps = 20
|
|
|
126 |
train_script.train.infer_eval_dataset_cfg = None
|
127 |
train_script.train.model = %MODEL
|
128 |
train_script.train.model_dir = %MODEL_DIR
|
config.json
CHANGED
@@ -7,7 +7,7 @@
|
|
7 |
"d_kv": 64,
|
8 |
"d_model": 512,
|
9 |
"decoder_start_token_id": 0,
|
10 |
-
"dropout_rate": 0.
|
11 |
"eos_token_id": 1,
|
12 |
"feed_forward_proj": "gated-gelu",
|
13 |
"initializer_factor": 1.0,
|
|
|
7 |
"d_kv": 64,
|
8 |
"d_model": 512,
|
9 |
"decoder_start_token_id": 0,
|
10 |
+
"dropout_rate": 0.1,
|
11 |
"eos_token_id": 1,
|
12 |
"feed_forward_proj": "gated-gelu",
|
13 |
"initializer_factor": 1.0,
|
small_nl24_pretrain.gin
CHANGED
@@ -11,7 +11,6 @@ include 't5x/configs/runs/pretrain.gin'
|
|
11 |
# ------------------- Training specification overrides --------------------------
|
12 |
train_script.train:
|
13 |
eval_period = 10000
|
14 |
-
hub_model_id = "Finnish-NLP/t5x-small-nl24-finnish"
|
15 |
|
16 |
utils.SaveCheckpointConfig:
|
17 |
period = 10000
|
@@ -19,7 +18,7 @@ utils.SaveCheckpointConfig:
|
|
19 |
|
20 |
MIXTURE_OR_TASK_NAME = "pretrain_finnish"
|
21 |
USE_CACHED_TASKS = False
|
22 |
-
TASK_FEATURE_LENGTHS = {"inputs": 512, "targets":
|
23 |
TRAIN_STEPS = 500000
|
24 |
DROPOUT_RATE = 0.0
|
25 |
-
BATCH_SIZE =
|
|
|
11 |
# ------------------- Training specification overrides --------------------------
|
12 |
train_script.train:
|
13 |
eval_period = 10000
|
|
|
14 |
|
15 |
utils.SaveCheckpointConfig:
|
16 |
period = 10000
|
|
|
18 |
|
19 |
MIXTURE_OR_TASK_NAME = "pretrain_finnish"
|
20 |
USE_CACHED_TASKS = False
|
21 |
+
TASK_FEATURE_LENGTHS = {"inputs": 512, "targets": 512}
|
22 |
TRAIN_STEPS = 500000
|
23 |
DROPOUT_RATE = 0.0
|
24 |
+
BATCH_SIZE = 256
|
start_train.sh
CHANGED
@@ -2,10 +2,11 @@
|
|
2 |
unset LD_PRELOAD
|
3 |
|
4 |
PROJECT_DIR="/researchdisk/t5x-small-nl24-finnish"
|
|
|
5 |
MODEL_DIR="/researchdisk/t5x-small-nl24-finnish"
|
6 |
export PYTHONPATH=${PROJECT_DIR}
|
7 |
|
8 |
-
python3 train.py \
|
9 |
--gin_search_paths=${PROJECT_DIR} \
|
10 |
--gin_file="small_nl24_pretrain.gin" \
|
11 |
--gin.MODEL_DIR=\"${MODEL_DIR}\"
|
|
|
2 |
unset LD_PRELOAD
|
3 |
|
4 |
PROJECT_DIR="/researchdisk/t5x-small-nl24-finnish"
|
5 |
+
T5X_DIR=${HOME}"/t5x" # directory where the t5x is cloned.
|
6 |
MODEL_DIR="/researchdisk/t5x-small-nl24-finnish"
|
7 |
export PYTHONPATH=${PROJECT_DIR}
|
8 |
|
9 |
+
python3 ${T5X_DIR}/t5x/train.py \
|
10 |
--gin_search_paths=${PROJECT_DIR} \
|
11 |
--gin_file="small_nl24_pretrain.gin" \
|
12 |
--gin.MODEL_DIR=\"${MODEL_DIR}\"
|
tasks.py
CHANGED
@@ -1,49 +1,82 @@
|
|
|
|
|
|
1 |
import functools
|
|
|
2 |
import seqio
|
3 |
-
|
|
|
|
|
|
|
4 |
from t5.data import preprocessors
|
|
|
|
|
5 |
|
6 |
-
|
7 |
-
|
8 |
-
|
9 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
10 |
}
|
11 |
|
12 |
-
|
13 |
-
|
14 |
-
|
15 |
-
|
16 |
-
|
17 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
18 |
preprocessors=[
|
19 |
functools.partial(
|
20 |
-
|
21 |
-
field_names=["text"],
|
22 |
-
field_delim="\n"),
|
23 |
-
functools.partial(
|
24 |
-
preprocessors.rekey, key_map={
|
25 |
"inputs": None,
|
26 |
-
"targets":
|
27 |
-
}),
|
28 |
seqio.preprocessors.tokenize,
|
29 |
-
seqio.CacheDatasetPlaceholder(),
|
30 |
-
preprocessors.span_corruption,
|
31 |
seqio.preprocessors.append_eos_after_trim,
|
32 |
],
|
33 |
-
|
34 |
-
|
35 |
-
|
36 |
-
# dataset = seqio.get_mixture_or_task("pretrain_finnish").get_dataset(
|
37 |
-
# sequence_length={"inputs": 512, "targets": 114},
|
38 |
-
# split="train",
|
39 |
-
# shuffle=True,
|
40 |
-
# num_epochs=1,
|
41 |
-
# #shard_info=seqio.ShardInfo(index=0, num_shards=10),
|
42 |
-
# use_cached=False,
|
43 |
-
# seed=42
|
44 |
-
# )
|
45 |
-
|
46 |
-
|
47 |
-
# # Print the first 5 examples.
|
48 |
-
# for _, ex in zip(range(5), dataset.as_numpy_iterator()):
|
49 |
-
# print(ex)
|
|
|
1 |
+
# adapted from https://huggingface.co/pere/pk-nb-t5x/blob/main/tasks.py
|
2 |
+
|
3 |
import functools
|
4 |
+
|
5 |
import seqio
|
6 |
+
import tensorflow as tf
|
7 |
+
import t5.data
|
8 |
+
from datasets import load_dataset, load_from_disk
|
9 |
+
from t5.data import postprocessors
|
10 |
from t5.data import preprocessors
|
11 |
+
from t5.evaluation import metrics
|
12 |
+
from seqio import FunctionDataSource, utils
|
13 |
|
14 |
+
TaskRegistry = seqio.TaskRegistry
|
15 |
+
|
16 |
+
vocabulary = seqio.SentencePieceVocabulary('spiece.model', extra_ids=0)
|
17 |
+
|
18 |
+
DEFAULT_OUTPUT_FEATURES = {
|
19 |
+
"inputs": seqio.Feature(
|
20 |
+
vocabulary=vocabulary, add_eos=True,
|
21 |
+
required=False),
|
22 |
+
"targets": seqio.Feature(
|
23 |
+
vocabulary=vocabulary, add_eos=True)
|
24 |
}
|
25 |
|
26 |
+
|
27 |
+
def gen_dataset(split, shuffle=False, seed=None, column="text", dataset=None):
|
28 |
+
if shuffle:
|
29 |
+
if seed:
|
30 |
+
dataset = dataset.shuffle(seed=seed)
|
31 |
+
else:
|
32 |
+
dataset = dataset.shuffle()
|
33 |
+
while True:
|
34 |
+
for item in dataset[str(split)]:
|
35 |
+
yield item[column]
|
36 |
+
|
37 |
+
|
38 |
+
def dataset_fn(split, shuffle_files, seed=None, dataset=None):
|
39 |
+
return tf.data.Dataset.from_generator(
|
40 |
+
functools.partial(gen_dataset, split, shuffle_files, seed, dataset=dataset),
|
41 |
+
output_signature=tf.TensorSpec(shape=(), dtype=tf.string, name=dataset_name)
|
42 |
+
)
|
43 |
+
|
44 |
+
|
45 |
+
@utils.map_over_dataset
|
46 |
+
def target_to_key(x, key_map, target_key):
|
47 |
+
"""Assign the value from the dataset to target_key in key_map"""
|
48 |
+
return {**key_map, target_key: x}
|
49 |
+
|
50 |
+
|
51 |
+
# Final pretraining task used in Raffel et al., 2019 adaptated to NCC
|
52 |
+
dataset_name = "/researchdisk/lm_training_dataset_full"
|
53 |
+
dataset_params = {"from_disk_path": dataset_name}
|
54 |
+
|
55 |
+
if "from_disk_path" in dataset_params:
|
56 |
+
dataset = load_from_disk(dataset_params.get("from_disk_path"))
|
57 |
+
else:
|
58 |
+
dataset = load_dataset(**dataset_params)
|
59 |
+
|
60 |
+
dataset_shapes = {"train": dataset["train"].num_rows, "validation": dataset["validation"].num_rows}
|
61 |
+
TaskRegistry.add(
|
62 |
+
"pretrain_finnish",
|
63 |
+
source=seqio.FunctionDataSource(
|
64 |
+
dataset_fn=functools.partial(dataset_fn, dataset=dataset),
|
65 |
+
splits=("train", "validation"),
|
66 |
+
caching_permitted=False,
|
67 |
+
num_input_examples=dataset_shapes,
|
68 |
+
),
|
69 |
preprocessors=[
|
70 |
functools.partial(
|
71 |
+
target_to_key, key_map={
|
|
|
|
|
|
|
|
|
72 |
"inputs": None,
|
73 |
+
"targets": None,
|
74 |
+
}, target_key="targets"),
|
75 |
seqio.preprocessors.tokenize,
|
76 |
+
# seqio.CacheDatasetPlaceholder(),
|
77 |
+
preprocessors.span_corruption,
|
78 |
seqio.preprocessors.append_eos_after_trim,
|
79 |
],
|
80 |
+
output_features={"targets": DEFAULT_OUTPUT_FEATURES["targets"]},
|
81 |
+
metric_fns=[metrics.accuracy]
|
82 |
+
)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
train.py
DELETED
@@ -1,689 +0,0 @@
|
|
1 |
-
# Copyright 2022 The T5X Authors.
|
2 |
-
#
|
3 |
-
# Licensed under the Apache License, Version 2.0 (the "License");
|
4 |
-
# you may not use this file except in compliance with the License.
|
5 |
-
# You may obtain a copy of the License at
|
6 |
-
#
|
7 |
-
# http://www.apache.org/licenses/LICENSE-2.0
|
8 |
-
#
|
9 |
-
# Unless required by applicable law or agreed to in writing, software
|
10 |
-
# distributed under the License is distributed on an "AS IS" BASIS,
|
11 |
-
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
12 |
-
# See the License for the specific language governing permissions and
|
13 |
-
# limitations under the License.
|
14 |
-
|
15 |
-
r"""Script to pretrain or finetune in JAX using a SeqIO pipeline.
|
16 |
-
|
17 |
-
"""
|
18 |
-
import functools
|
19 |
-
import itertools
|
20 |
-
import math
|
21 |
-
import os
|
22 |
-
import time
|
23 |
-
from typing import Callable, Iterator, Sequence, Mapping, Tuple, Type, Optional
|
24 |
-
import subprocess
|
25 |
-
|
26 |
-
# Set Linen to add profiling information when constructing Modules.
|
27 |
-
# Must be set before flax imports.
|
28 |
-
# pylint:disable=g-import-not-at-top
|
29 |
-
os.environ['FLAX_PROFILE'] = 'true'
|
30 |
-
# TODO(adarob): Re-enable once users are notified and tests are updated.
|
31 |
-
os.environ['FLAX_LAZY_RNG'] = 'no'
|
32 |
-
from absl import logging
|
33 |
-
from clu import metric_writers
|
34 |
-
import jax
|
35 |
-
from jax import random
|
36 |
-
from jax.experimental import multihost_utils
|
37 |
-
import jax.numpy as jnp
|
38 |
-
import numpy as np
|
39 |
-
import seqio
|
40 |
-
from t5x import models
|
41 |
-
from t5x import partitioning
|
42 |
-
from t5x import train_state as train_state_lib
|
43 |
-
from t5x import trainer as trainer_lib
|
44 |
-
from t5x import utils
|
45 |
-
from t5x import checkpoint_importer
|
46 |
-
LazyArray = checkpoint_importer.LazyArray
|
47 |
-
import tensorflow as tf
|
48 |
-
|
49 |
-
|
50 |
-
# Automatically search for gin files relative to the T5X package.
|
51 |
-
_DEFAULT_GIN_SEARCH_PATHS = [
|
52 |
-
os.path.dirname(os.path.dirname(os.path.abspath(__file__)))
|
53 |
-
]
|
54 |
-
PyTreeDef = type(jax.tree_structure(None))
|
55 |
-
P = partitioning.PartitionSpec
|
56 |
-
# Special key that used to distinguish train metrics.
|
57 |
-
TRAIN_METRIC_KEY = 'train'
|
58 |
-
# String keys that is acceptable from config.
|
59 |
-
_ACTION_KEYS = frozenset(trainer_lib.ActionMode.__members__.keys())
|
60 |
-
|
61 |
-
|
62 |
-
def run_actions(
|
63 |
-
mode: trainer_lib.ActionMode, actions: trainer_lib.ActionMapType,
|
64 |
-
train_state: train_state_lib.TrainState,
|
65 |
-
metrics_by_task: Mapping[str, trainer_lib.MetricValueMapType]) -> bool:
|
66 |
-
"""Invokes all actions on the given mode on host 0, then broadcasts to all.
|
67 |
-
|
68 |
-
Args:
|
69 |
-
mode: The mode to run the actions. e.g., if mode is `train`, only actions
|
70 |
-
configured to run with `train` mode will be invoked.
|
71 |
-
actions: A mapping of actions that runs after train, eval or infer_eval, to
|
72 |
-
inspect the model and perform useful operations, e.g., early stopping.
|
73 |
-
train_state: The current train_state of the trainer.
|
74 |
-
metrics_by_task: A map of metrics keyed by task name.
|
75 |
-
|
76 |
-
Returns:
|
77 |
-
A bool indicating whether training should be halted.
|
78 |
-
|
79 |
-
Raises:
|
80 |
-
RuntimeError: When the metrics processed on host 0 is None.
|
81 |
-
"""
|
82 |
-
stop_training = False
|
83 |
-
if jax.process_index() == 0:
|
84 |
-
if not metrics_by_task:
|
85 |
-
raise RuntimeError('Metric is unexpectedly empty on process 0')
|
86 |
-
for action in actions.get(mode, []):
|
87 |
-
stop_training |= action.run(train_state, metrics_by_task=metrics_by_task)
|
88 |
-
# Broadcast result from host 0 to others.
|
89 |
-
return bool(multihost_utils.broadcast_one_to_all(jnp.array(stop_training)))
|
90 |
-
|
91 |
-
|
92 |
-
def train(
|
93 |
-
*,
|
94 |
-
model: models.BaseTransformerModel,
|
95 |
-
train_dataset_cfg: utils.DatasetConfig,
|
96 |
-
train_eval_dataset_cfg: Optional[utils.DatasetConfig],
|
97 |
-
infer_eval_dataset_cfg: Optional[utils.DatasetConfig],
|
98 |
-
checkpoint_cfg: utils.CheckpointConfig,
|
99 |
-
partitioner: partitioning.BasePartitioner,
|
100 |
-
trainer_cls: Type[trainer_lib.BaseTrainer],
|
101 |
-
model_dir: str,
|
102 |
-
total_steps: int,
|
103 |
-
eval_steps: int,
|
104 |
-
eval_period: int,
|
105 |
-
stats_period: Optional[int] = None,
|
106 |
-
random_seed: Optional[int],
|
107 |
-
use_hardware_rng: bool = False,
|
108 |
-
summarize_config_fn: Callable[[str, metric_writers.MetricWriter, int],
|
109 |
-
None],
|
110 |
-
inference_evaluator_cls: Type[seqio.Evaluator] = seqio.Evaluator,
|
111 |
-
get_dataset_fn: utils.GetDatasetCallable = utils.get_dataset,
|
112 |
-
concurrent_metrics: bool = True,
|
113 |
-
actions: Optional[Mapping[str, Sequence[trainer_lib.BaseAction]]] = None,
|
114 |
-
train_eval_get_dataset_fn: Optional[utils.GetDatasetCallable] = None,
|
115 |
-
run_eval_before_training: bool = False,
|
116 |
-
hub_model_id: str = None,
|
117 |
-
) -> Tuple[int, train_state_lib.TrainState]:
|
118 |
-
"""Train function.
|
119 |
-
|
120 |
-
Args:
|
121 |
-
model: The model object to use for training.
|
122 |
-
train_dataset_cfg: Specification for the dataset to train with.
|
123 |
-
train_eval_dataset_cfg: Specification for the dataset to evaluate with using
|
124 |
-
the train metrics and no inference (e.g., uses teacher forcing). If None,
|
125 |
-
train eval is disabled.
|
126 |
-
infer_eval_dataset_cfg: Specification for the dataset to evaluate with using
|
127 |
-
the inference metrics (e.g., uses sampled decoding). If None, inference
|
128 |
-
eval is disabled.
|
129 |
-
checkpoint_cfg: Specification for saving and restoring model parameters and
|
130 |
-
dataset state to/from checkpoints.
|
131 |
-
partitioner: Partitioner for model parameters and data across devices.
|
132 |
-
trainer_cls: An implementation of BaseTrainer.
|
133 |
-
model_dir: Path of directory to store checkpoints and metric summaries.
|
134 |
-
total_steps: The step number to stop training after. The number of actual
|
135 |
-
steps trained in this run will be this number minus the starting step from
|
136 |
-
the checkpoint.
|
137 |
-
eval_steps: The number of batches to process for each train-eval loop.
|
138 |
-
eval_period: The number of train steps between each evaluation (both
|
139 |
-
train-eval and infer-eval).
|
140 |
-
stats_period: The number of train steps between writing scalar stats. If
|
141 |
-
None, defaults to eval_period.
|
142 |
-
random_seed: A random seed to use for dropout and initialization. If None, a
|
143 |
-
fast, non-deterministic hardware-based RNG is used.
|
144 |
-
use_hardware_rng: Whether to force using the RngBitGenerator based hardware
|
145 |
-
rng, which takes seeds and acts similarly to software PRNG in that it
|
146 |
-
should be seed-deterministic. The new RngBitGenerator custom PRNG system
|
147 |
-
should be reproducible for a given sharding, but the numbers will change
|
148 |
-
for different shardings of the same model.
|
149 |
-
summarize_config_fn: A function that takes in the model directory, a
|
150 |
-
SummaryWriter, and the step number, and writes a summary of the
|
151 |
-
inference_evaluator_cls: seqio.Evaluator class to use for inference
|
152 |
-
evaluation, potentially with bound configuration args.
|
153 |
-
get_dataset_fn: The callable use to get the train and train-eval datasets
|
154 |
-
based on the DatasetConfig and shard information.
|
155 |
-
concurrent_metrics: If True, allow metrics computation and logging to
|
156 |
-
overlap with training. Will likely result in additional TPU memory usage.
|
157 |
-
actions: A mapping of actions that runs after train, eval or infer_eval, to
|
158 |
-
inspect the model and perform useful operations, e.g., early stopping. The
|
159 |
-
key must have a 1:1 mapping to ActionMode enum. For EVAL actions to
|
160 |
-
actually work, this requires `concurrent_metrics` to be turned off,
|
161 |
-
since chaining futures and mutating states concurrently might be
|
162 |
-
error-prone.
|
163 |
-
train_eval_get_dataset_fn: Optional callable use to get the train-eval
|
164 |
-
datasets based on the DatasetConfig and shard information. If missing, it
|
165 |
-
defaults to `get_dataset_fn`.
|
166 |
-
run_eval_before_training: If True, calculate training eval and inference
|
167 |
-
eval metrics before training begins.
|
168 |
-
|
169 |
-
Returns:
|
170 |
-
The tuple of (last_step, last_train_state).
|
171 |
-
"""
|
172 |
-
logging.info('Process ID: %d', jax.process_index())
|
173 |
-
tf.io.gfile.makedirs(model_dir)
|
174 |
-
|
175 |
-
# Each "epoch" of the training loop should be the min of the eval period,
|
176 |
-
# checkpoint period or the full training.
|
177 |
-
# We compute here to ensure that the eval period and checkpoint period are
|
178 |
-
# divisible by this number, otherwise we fail.
|
179 |
-
eval_enabled = (train_eval_dataset_cfg or infer_eval_dataset_cfg)
|
180 |
-
eval_period = eval_period if eval_enabled else 0
|
181 |
-
checkpoint_period = checkpoint_cfg.save.period if checkpoint_cfg.save else 0
|
182 |
-
if eval_period or checkpoint_period:
|
183 |
-
steps_per_epoch = min(eval_period or np.inf, checkpoint_period or np.inf)
|
184 |
-
else:
|
185 |
-
steps_per_epoch = total_steps
|
186 |
-
stats_period = stats_period or steps_per_epoch
|
187 |
-
if (eval_period and eval_period % steps_per_epoch or
|
188 |
-
checkpoint_period and checkpoint_period % steps_per_epoch):
|
189 |
-
raise ValueError(
|
190 |
-
f'Checkpoint period ({checkpoint_period}) must evenly divide eval '
|
191 |
-
f'period ({eval_period}), or vice-versa.')
|
192 |
-
|
193 |
-
if use_hardware_rng or random_seed is None:
|
194 |
-
logging.info(
|
195 |
-
'Using fast RngBitGenerator PRNG for initialization and dropout.')
|
196 |
-
|
197 |
-
if random_seed is None:
|
198 |
-
random_seed = multihost_utils.broadcast_one_to_all(np.int32(time.time()))
|
199 |
-
logging.info('Random seed not provided, using RNG seed %s', random_seed)
|
200 |
-
else:
|
201 |
-
logging.warning(
|
202 |
-
'When using hardware RNG with a fixed seed, repeatability is only '
|
203 |
-
'guaranteed for fixed hardware and partitioning schemes and for a '
|
204 |
-
'fixed version of this code and its dependencies.')
|
205 |
-
utils.set_hardware_rng_ops()
|
206 |
-
rng = random.PRNGKey(random_seed)
|
207 |
-
else:
|
208 |
-
logging.info('Using seed for initialization and dropout RNG: %d',
|
209 |
-
random_seed)
|
210 |
-
rng = random.PRNGKey(random_seed)
|
211 |
-
|
212 |
-
init_rng, trainer_rng = random.split(rng, 2)
|
213 |
-
|
214 |
-
# ---------------------------------------------------------------------------
|
215 |
-
# Initialize datasets
|
216 |
-
# ---------------------------------------------------------------------------
|
217 |
-
|
218 |
-
if (train_dataset_cfg.seed and
|
219 |
-
not (checkpoint_cfg.save or checkpoint_cfg.save.save_dataset)):
|
220 |
-
logging.warning(
|
221 |
-
'Providing a random seed for the train dataset with '
|
222 |
-
'`checkpoint_train_ds=False` is dangerous since each '
|
223 |
-
'preemption/restart will cause the dataset to deterministically replay '
|
224 |
-
'from the beginning.')
|
225 |
-
|
226 |
-
data_layout = partitioner.get_data_layout(train_dataset_cfg.batch_size)
|
227 |
-
ds_shard_id = data_layout.shard_id
|
228 |
-
num_ds_shards = data_layout.num_shards
|
229 |
-
|
230 |
-
def _verify_matching_vocabs(cfg: utils.DatasetConfig):
|
231 |
-
ds_vocabs = utils.get_vocabulary(cfg)
|
232 |
-
if (ds_vocabs[0] != model.input_vocabulary or
|
233 |
-
ds_vocabs[1] != model.output_vocabulary):
|
234 |
-
raise ValueError(f'Model and Task vocabularies do not match:\n'
|
235 |
-
f' task={cfg.mixture_or_task_name}\n'
|
236 |
-
f' ds_vocabs=({ds_vocabs[0]}, {ds_vocabs[1]})\n'
|
237 |
-
f' model.input_vocabulary={model.input_vocabulary}\n'
|
238 |
-
f' model.output_vocabulary={model.output_vocabulary}\n')
|
239 |
-
|
240 |
-
_verify_matching_vocabs(train_dataset_cfg)
|
241 |
-
|
242 |
-
train_ds = get_dataset_fn(train_dataset_cfg, ds_shard_id, num_ds_shards,
|
243 |
-
model.FEATURE_CONVERTER_CLS)
|
244 |
-
|
245 |
-
if train_eval_dataset_cfg:
|
246 |
-
_verify_matching_vocabs(train_eval_dataset_cfg)
|
247 |
-
train_eval_datasets = utils.get_training_eval_datasets(
|
248 |
-
train_eval_dataset_cfg,
|
249 |
-
ds_shard_id,
|
250 |
-
num_ds_shards,
|
251 |
-
eval_steps,
|
252 |
-
model.FEATURE_CONVERTER_CLS,
|
253 |
-
get_dataset_fn=train_eval_get_dataset_fn if train_eval_get_dataset_fn
|
254 |
-
is not None else get_dataset_fn) # type: Mapping[str, tf.data.Dataset]
|
255 |
-
if not train_eval_datasets:
|
256 |
-
logging.warning(
|
257 |
-
'No train_eval datasets loaded from config `train_eval_dataset_cfg`: '
|
258 |
-
'%s', train_eval_dataset_cfg)
|
259 |
-
else:
|
260 |
-
train_eval_datasets = {}
|
261 |
-
|
262 |
-
# Initialize optimizer, maybe from an existing checkpoint.
|
263 |
-
checkpointable_train_iter: tf.data.Iterator = iter(train_ds) # pytype:disable=annotation-type-mismatch
|
264 |
-
train_iter: Iterator[trainer_lib.BatchType] = map(
|
265 |
-
lambda x: jax.tree_map(np.array, x), checkpointable_train_iter)
|
266 |
-
|
267 |
-
# The manner in which parameters are initialized follows this order of
|
268 |
-
# preference:
|
269 |
-
# 1. From a T5X checkpoint in `model_dir`, if one exists.
|
270 |
-
# 2. From a T5X or TF checkpoint specified by `cfg.path`, if set.
|
271 |
-
# 3. From scratch using `init_fn`.
|
272 |
-
|
273 |
-
# 1. From a T5X checkpoint in `model_dir`, if one exists.
|
274 |
-
if checkpoint_cfg.restore is not None:
|
275 |
-
state_transforms_for_restore = [
|
276 |
-
functools.partial(fn, is_resuming=True)
|
277 |
-
for fn in checkpoint_cfg.restore.state_transformation_fns
|
278 |
-
]
|
279 |
-
else:
|
280 |
-
state_transforms_for_restore = []
|
281 |
-
restore_cfgs = [
|
282 |
-
utils.RestoreCheckpointConfig(
|
283 |
-
path=model_dir,
|
284 |
-
mode='latest',
|
285 |
-
dtype=checkpoint_cfg.save.dtype,
|
286 |
-
checkpointer_cls=checkpoint_cfg.save.checkpointer_cls,
|
287 |
-
# Restore dataset state if it is being saved.
|
288 |
-
restore_dataset=(checkpoint_cfg.save and
|
289 |
-
checkpoint_cfg.save.save_dataset),
|
290 |
-
state_transformation_fns=state_transforms_for_restore)
|
291 |
-
]
|
292 |
-
# 2. From a checkpoint specified by `checkpoint_cfg.restore.path`, if set.
|
293 |
-
if checkpoint_cfg.restore:
|
294 |
-
if checkpoint_cfg.restore.mode == 'all':
|
295 |
-
raise ValueError(
|
296 |
-
"Restore checkpoint mode 'all' is not supported in training.")
|
297 |
-
|
298 |
-
# TODO(dhgarrette): Split "restore" behavior into separate configurations
|
299 |
-
# for the initial restoration for a new run, vs resuming a stopped run.
|
300 |
-
if isinstance(checkpoint_cfg.restore.path, str):
|
301 |
-
restore_cfgs.append(checkpoint_cfg.restore)
|
302 |
-
elif not checkpoint_cfg.restore.path:
|
303 |
-
# `path` is an empty (non-`str`) sequence, so there is nothing to restore.
|
304 |
-
pass
|
305 |
-
else:
|
306 |
-
raise ValueError(
|
307 |
-
'Restore checkpoint config may only have a single path in training.')
|
308 |
-
|
309 |
-
# Need to use full batch size.
|
310 |
-
input_shapes = {
|
311 |
-
k: (data_layout.batch_size, *v.shape[1:])
|
312 |
-
for k, v in train_ds.element_spec.items()
|
313 |
-
}
|
314 |
-
input_types = {
|
315 |
-
k: v.dtype.as_numpy_dtype() for k, v in train_ds.element_spec.items()
|
316 |
-
}
|
317 |
-
init_or_restore_tick = time.time()
|
318 |
-
train_state_initializer = utils.TrainStateInitializer(
|
319 |
-
optimizer_def=model.optimizer_def,
|
320 |
-
init_fn=model.get_initial_variables,
|
321 |
-
input_shapes=input_shapes,
|
322 |
-
input_types=input_types,
|
323 |
-
partitioner=partitioner)
|
324 |
-
# 3. From scratch using `init_fn`.
|
325 |
-
train_state = train_state_initializer.from_checkpoint_or_scratch(
|
326 |
-
restore_cfgs, init_rng=init_rng, ds_iter=checkpointable_train_iter)
|
327 |
-
train_state_axes = train_state_initializer.train_state_axes
|
328 |
-
init_or_restore_secs = time.time() - init_or_restore_tick
|
329 |
-
logging.info('Initialize/restore complete (%.2f seconds).',
|
330 |
-
init_or_restore_secs)
|
331 |
-
|
332 |
-
# Log the variable shapes information and write to a file.
|
333 |
-
log_file = os.path.join(model_dir, 'model-info.txt')
|
334 |
-
utils.log_model_info(log_file,
|
335 |
-
train_state_initializer.global_train_state_shape,
|
336 |
-
partitioner)
|
337 |
-
|
338 |
-
if checkpoint_period:
|
339 |
-
checkpointer = checkpoint_cfg.save.checkpointer_cls(
|
340 |
-
train_state=train_state_initializer.global_train_state_shape,
|
341 |
-
partitioner=partitioner,
|
342 |
-
checkpoints_dir=model_dir,
|
343 |
-
dataset_iterator=(checkpointable_train_iter
|
344 |
-
if checkpoint_cfg.save.save_dataset else None),
|
345 |
-
save_dtype=checkpoint_cfg.save.dtype,
|
346 |
-
keep=checkpoint_cfg.save.keep)
|
347 |
-
|
348 |
-
|
349 |
-
# Restore step from last checkpoint or set to 0 if training from scratch.
|
350 |
-
host_step = int(train_state.step)
|
351 |
-
|
352 |
-
# ---------------------------------------------------------------------------
|
353 |
-
# Trainer
|
354 |
-
# ---------------------------------------------------------------------------
|
355 |
-
|
356 |
-
trainer: trainer_lib.BaseTrainer = trainer_cls(
|
357 |
-
model=model,
|
358 |
-
train_state=train_state,
|
359 |
-
partitioner=partitioner,
|
360 |
-
train_state_axes=train_state_axes,
|
361 |
-
eval_names=train_eval_datasets.keys(),
|
362 |
-
summary_dir=model_dir,
|
363 |
-
rng=trainer_rng)
|
364 |
-
del train_state
|
365 |
-
|
366 |
-
train_metrics = trainer.train_metrics_manager
|
367 |
-
summarize_config_fn(model_dir, train_metrics.summary_writer, host_step)
|
368 |
-
|
369 |
-
train_metrics.write_scalar('timing/init_or_restore_seconds',
|
370 |
-
init_or_restore_secs, host_step)
|
371 |
-
|
372 |
-
# ----------------------------------------------------------------------------
|
373 |
-
# SeqIO (inference-based) evaluation setup
|
374 |
-
# ----------------------------------------------------------------------------
|
375 |
-
# Init evaluator to set up cached datasets
|
376 |
-
evaluator = None
|
377 |
-
if infer_eval_dataset_cfg is not None:
|
378 |
-
_verify_matching_vocabs(infer_eval_dataset_cfg)
|
379 |
-
evaluator = inference_evaluator_cls(
|
380 |
-
log_dir=os.path.join(model_dir, 'inference_eval'),
|
381 |
-
mixture_or_task_name=infer_eval_dataset_cfg.mixture_or_task_name,
|
382 |
-
feature_converter=model.FEATURE_CONVERTER_CLS(pack=False),
|
383 |
-
eval_split=infer_eval_dataset_cfg.split,
|
384 |
-
use_cached=infer_eval_dataset_cfg.use_cached,
|
385 |
-
seed=infer_eval_dataset_cfg.seed,
|
386 |
-
sequence_length=infer_eval_dataset_cfg.task_feature_lengths,
|
387 |
-
use_memory_cache=infer_eval_dataset_cfg.use_memory_cache)
|
388 |
-
if not evaluator.eval_tasks:
|
389 |
-
# Skip evaluaton.
|
390 |
-
evaluator = None
|
391 |
-
|
392 |
-
if evaluator is not None:
|
393 |
-
predict_fn = utils.get_infer_fn(
|
394 |
-
infer_step=model.predict_batch,
|
395 |
-
batch_size=infer_eval_dataset_cfg.batch_size,
|
396 |
-
train_state_axes=train_state_axes,
|
397 |
-
partitioner=partitioner)
|
398 |
-
|
399 |
-
score_fn = utils.get_infer_fn(
|
400 |
-
infer_step=model.score_batch,
|
401 |
-
batch_size=infer_eval_dataset_cfg.batch_size,
|
402 |
-
train_state_axes=train_state_axes,
|
403 |
-
partitioner=partitioner)
|
404 |
-
|
405 |
-
if actions is None:
|
406 |
-
actions = {}
|
407 |
-
|
408 |
-
if set(actions.keys()).difference(_ACTION_KEYS):
|
409 |
-
raise ValueError(f'actions keys must be one of {_ACTION_KEYS}, but got : '
|
410 |
-
f'{actions.keys()}')
|
411 |
-
|
412 |
-
# Transform the string key into proper ActionMode enum.
|
413 |
-
actions = {trainer_lib.ActionMode[k]: v for k, v in actions.items()}
|
414 |
-
|
415 |
-
if concurrent_metrics and actions.get(trainer_lib.ActionMode.INFER_EVAL,
|
416 |
-
None) is not None:
|
417 |
-
logging.warning('Actions for INFER_EVAL will not be triggered when async '
|
418 |
-
'metrics computation is enabled')
|
419 |
-
if concurrent_metrics and actions.get(trainer_lib.ActionMode.TRAIN,
|
420 |
-
None) is not None:
|
421 |
-
logging.warning('Actions for TRAIN will not be triggered when async '
|
422 |
-
'metrics computation is enabled')
|
423 |
-
|
424 |
-
# ----------------------------------------------------------------------------
|
425 |
-
# Setup Eval Utility Functions
|
426 |
-
# ----------------------------------------------------------------------------
|
427 |
-
def _run_training_eval(first_run: bool = False):
|
428 |
-
if first_run:
|
429 |
-
logging.info('Compiling training eval loop.')
|
430 |
-
trainer.compile_eval({
|
431 |
-
task: utils.get_zeros_batch_like_dataset(ds)
|
432 |
-
for task, ds in train_eval_datasets.items()
|
433 |
-
})
|
434 |
-
logging.info('Computing training evaluation metrics.')
|
435 |
-
eval_batch_iters = {
|
436 |
-
task: ds.as_numpy_iterator()
|
437 |
-
for task, ds in train_eval_datasets.items()
|
438 |
-
}
|
439 |
-
eval_summaries = trainer.eval(eval_batch_iters)
|
440 |
-
trainer.stop_training = run_actions(trainer_lib.ActionMode.TRAIN_EVAL,
|
441 |
-
actions, trainer.train_state,
|
442 |
-
eval_summaries)
|
443 |
-
|
444 |
-
def _run_inference_eval():
|
445 |
-
"""Run prediction based inference eval."""
|
446 |
-
if evaluator is None:
|
447 |
-
return
|
448 |
-
logging.info('Running inference evaluation.')
|
449 |
-
evaluate_tick = time.time()
|
450 |
-
all_metrics, _, _ = evaluator.evaluate(
|
451 |
-
compute_metrics=jax.process_index() == 0,
|
452 |
-
step=host_step,
|
453 |
-
predict_fn=functools.partial(
|
454 |
-
predict_fn,
|
455 |
-
train_state=trainer.train_state,
|
456 |
-
rng=jax.random.PRNGKey(0)),
|
457 |
-
score_fn=functools.partial(score_fn, train_state=trainer.train_state))
|
458 |
-
if not concurrent_metrics:
|
459 |
-
# Ensure metrics are finished being computed.
|
460 |
-
all_metrics_done = all_metrics.result() or {}
|
461 |
-
trainer.stop_training = run_actions(trainer_lib.ActionMode.INFER_EVAL,
|
462 |
-
actions, trainer.train_state,
|
463 |
-
all_metrics_done)
|
464 |
-
train_metrics.write_scalar('timing/evaluate_seconds',
|
465 |
-
time.time() - evaluate_tick, host_step)
|
466 |
-
|
467 |
-
# Optionally run teacher-forcing training eval and SeqIO inference-base eval
|
468 |
-
# before training. Useful for testing how much a model knows before any
|
469 |
-
# finetuning.
|
470 |
-
if run_eval_before_training:
|
471 |
-
if train_eval_datasets:
|
472 |
-
logging.info('Running training eval before training.')
|
473 |
-
_run_training_eval(first_run=True)
|
474 |
-
if evaluator is not None:
|
475 |
-
logging.info('Running inference eval before training.')
|
476 |
-
_run_inference_eval()
|
477 |
-
|
478 |
-
# ----------------------------------------------------------------------------
|
479 |
-
# Main training loop
|
480 |
-
# ----------------------------------------------------------------------------
|
481 |
-
logging.info('Starting training loop.')
|
482 |
-
|
483 |
-
first_step = host_step
|
484 |
-
|
485 |
-
if total_steps < first_step:
|
486 |
-
raise ValueError(
|
487 |
-
f'Unexpected total_steps ({total_steps}) < checkpoint step '
|
488 |
-
f' ({first_step}).')
|
489 |
-
|
490 |
-
logging.info('Starting main loop over steps %d-%d', first_step, total_steps)
|
491 |
-
|
492 |
-
steps_per_epoch = min(steps_per_epoch, total_steps)
|
493 |
-
first_epoch = first_step // steps_per_epoch
|
494 |
-
num_epochs = first_epoch + math.ceil(
|
495 |
-
(total_steps - first_step) / steps_per_epoch)
|
496 |
-
logging.info('Training with artificial "epochs" of %d steps.',
|
497 |
-
steps_per_epoch)
|
498 |
-
|
499 |
-
# Kickstart training dataset and compile train loop.
|
500 |
-
logging.info('Kickstarting train dataset prefetch.')
|
501 |
-
logging.flush()
|
502 |
-
|
503 |
-
ds_tick = time.time()
|
504 |
-
# Get first batch to warm up the dataset pipeline.
|
505 |
-
first_batch = next(train_iter)
|
506 |
-
# Prepend first batch back to iterator to be used by trainer.
|
507 |
-
train_iter = itertools.chain([first_batch], train_iter)
|
508 |
-
train_metrics.write_scalar('timing/dataset_warmup_seconds',
|
509 |
-
time.time() - ds_tick, host_step)
|
510 |
-
logging.info('Compiling train loop.')
|
511 |
-
logging.flush()
|
512 |
-
trainer.compile_train(first_batch)
|
513 |
-
|
514 |
-
# Main Loop over "epochs".
|
515 |
-
for epoch in range(first_epoch, num_epochs):
|
516 |
-
final_epoch = epoch == num_epochs - 1
|
517 |
-
logging.info('Epoch %d of %d', epoch, num_epochs)
|
518 |
-
|
519 |
-
# `stop_training` is requested, break out the main loop immediately.
|
520 |
-
if trainer.stop_training:
|
521 |
-
break
|
522 |
-
|
523 |
-
logging.info('BEGIN Train loop.')
|
524 |
-
try:
|
525 |
-
# Until the last epoch, `num_steps = steps_per_epoch`
|
526 |
-
num_steps = min(total_steps - host_step, steps_per_epoch)
|
527 |
-
epoch_end_step = host_step + num_steps
|
528 |
-
logging.info('Training for %d steps.', num_steps)
|
529 |
-
while host_step < epoch_end_step:
|
530 |
-
if trainer.stop_training:
|
531 |
-
logging.info('Saving a checkpoint before early stopping...')
|
532 |
-
checkpointer.save(trainer.train_state,
|
533 |
-
checkpoint_cfg.save.state_transformation_fns)
|
534 |
-
|
535 |
-
if hub_model_id:
|
536 |
-
# convert checkpoint to HF Flax model and push to hub
|
537 |
-
checkpoint_step = trainer.train_state.step
|
538 |
-
checkpoint_step = checkpoint_step.get() if isinstance(checkpoint_step, LazyArray) else checkpoint_step
|
539 |
-
checkpoint_step = int(checkpoint_step) # Integer, to avoid side effects in the checkpoint path.
|
540 |
-
config_path = os.path.join(model_dir, 'config.json')
|
541 |
-
subprocess.run(["python3", "convert_t5x_checkpoint_to_flax.py", f"--t5x_checkpoint_path='checkpoint_{checkpoint_step}'/'", f'--config_name="{config_path}"', "--flax_dump_folder_path='./'"])
|
542 |
-
subprocess.run("git lfs prune --verify-remote", shell=True)
|
543 |
-
subprocess.run("git add .", shell=True)
|
544 |
-
subprocess.run(f'git commit -m "Saving weights and logs of step {checkpoint_step}"', shell=True)
|
545 |
-
subprocess.Popen("git push", shell=True)
|
546 |
-
|
547 |
-
logging.info('Stopping training loop early since `stop_training` is '
|
548 |
-
'requested.')
|
549 |
-
break
|
550 |
-
|
551 |
-
inner_num_steps = min(epoch_end_step - host_step, stats_period)
|
552 |
-
train_summary = trainer.train(
|
553 |
-
train_iter, inner_num_steps, start_step=host_step)
|
554 |
-
if not concurrent_metrics:
|
555 |
-
# Note that we always pass the dictionary of `tasks` -> summary so
|
556 |
-
# that the actions can be performed without special casing. The only
|
557 |
-
# caveat is that train would need its own special `key` given no
|
558 |
-
# `task` will be applied.
|
559 |
-
trainer.stop_training = run_actions(
|
560 |
-
trainer_lib.ActionMode.TRAIN, actions, trainer.train_state,
|
561 |
-
{TRAIN_METRIC_KEY: train_summary.result()})
|
562 |
-
|
563 |
-
host_step += inner_num_steps
|
564 |
-
logging.info('END Train loop.')
|
565 |
-
except trainer_lib.PreemptionError as e:
|
566 |
-
logging.info('Saving emergency checkpoint.')
|
567 |
-
checkpointer.save(trainer.train_state,
|
568 |
-
checkpoint_cfg.save.state_transformation_fns)
|
569 |
-
logging.info('Saving emergency checkpoint done.')
|
570 |
-
raise e
|
571 |
-
|
572 |
-
step_offset = host_step - first_step
|
573 |
-
|
574 |
-
is_eval_epoch = eval_period and (final_epoch or
|
575 |
-
step_offset % eval_period == 0)
|
576 |
-
|
577 |
-
# Training Evaluation (i.e., with teacher forcing).
|
578 |
-
if is_eval_epoch and train_eval_datasets:
|
579 |
-
# Maybe less if final step < period.
|
580 |
-
first_run = step_offset // eval_period <= 1
|
581 |
-
_run_training_eval(first_run and not run_eval_before_training)
|
582 |
-
|
583 |
-
# Maybe save a checkpoint.
|
584 |
-
if checkpoint_period and (final_epoch or
|
585 |
-
step_offset % checkpoint_period == 0):
|
586 |
-
# Make sure last train step has completed before starting the clock.
|
587 |
-
train_summary.result()
|
588 |
-
logging.info('Saving checkpoint.')
|
589 |
-
checkpoint_tick = time.time()
|
590 |
-
checkpointer.save(trainer.train_state,
|
591 |
-
checkpoint_cfg.save.state_transformation_fns)
|
592 |
-
checkpoint_tock = time.time()
|
593 |
-
train_metrics.write_scalar('timing/checkpoint_seconds',
|
594 |
-
checkpoint_tock - checkpoint_tick, host_step)
|
595 |
-
|
596 |
-
if hub_model_id:
|
597 |
-
# convert checkpoint to HF Flax model and push to hub
|
598 |
-
checkpoint_step = trainer.train_state.step
|
599 |
-
checkpoint_step = checkpoint_step.get() if isinstance(checkpoint_step, LazyArray) else checkpoint_step
|
600 |
-
checkpoint_step = int(checkpoint_step) # Integer, to avoid side effects in the checkpoint path.
|
601 |
-
config_path = os.path.join(model_dir, 'config.json')
|
602 |
-
subprocess.run(["python3", "convert_t5x_checkpoint_to_flax.py", f"--t5x_checkpoint_path='checkpoint_{checkpoint_step}'/'", f'--config_name="{config_path}"', "--flax_dump_folder_path='./'"])
|
603 |
-
subprocess.run("git lfs prune --verify-remote", shell=True)
|
604 |
-
subprocess.run("git add .", shell=True)
|
605 |
-
subprocess.run(f'git commit -m "Saving weights and logs of step {checkpoint_step}"', shell=True)
|
606 |
-
subprocess.Popen("git push", shell=True)
|
607 |
-
|
608 |
-
# Inference Evaluation (i.e., with decoding or scoring).
|
609 |
-
if evaluator is not None:
|
610 |
-
_run_inference_eval()
|
611 |
-
|
612 |
-
# Wait until computations are done before exiting
|
613 |
-
logging.info('Finished.')
|
614 |
-
trainer.close()
|
615 |
-
if evaluator:
|
616 |
-
evaluator.close()
|
617 |
-
multihost_utils.sync_global_devices('complete')
|
618 |
-
|
619 |
-
return host_step, trainer.train_state
|
620 |
-
|
621 |
-
|
622 |
-
if __name__ == '__main__':
|
623 |
-
# pylint: disable=g-import-not-at-top
|
624 |
-
from absl import app
|
625 |
-
from absl import flags
|
626 |
-
import gin
|
627 |
-
from t5x import gin_utils
|
628 |
-
# pylint: enable=g-import-not-at-top
|
629 |
-
|
630 |
-
FLAGS = flags.FLAGS
|
631 |
-
|
632 |
-
jax.config.parse_flags_with_absl()
|
633 |
-
|
634 |
-
flags.DEFINE_multi_string(
|
635 |
-
'gin_file',
|
636 |
-
default=None,
|
637 |
-
help='Path to gin configuration file. Multiple paths may be passed and '
|
638 |
-
'will be imported in the given order, with later configurations '
|
639 |
-
'overriding earlier ones.')
|
640 |
-
|
641 |
-
flags.DEFINE_multi_string(
|
642 |
-
'gin_bindings', default=[], help='Individual gin bindings.')
|
643 |
-
|
644 |
-
flags.DEFINE_list(
|
645 |
-
'gin_search_paths',
|
646 |
-
default=['.'],
|
647 |
-
help='Comma-separated list of gin config path prefixes to be prepended '
|
648 |
-
'to suffixes given via `--gin_file`. If a file appears in. Only the '
|
649 |
-
'first prefix that produces a valid path for each suffix will be '
|
650 |
-
'used.')
|
651 |
-
|
652 |
-
flags.DEFINE_string(
|
653 |
-
'tfds_data_dir', None,
|
654 |
-
'If set, this directory will be used to store datasets prepared by '
|
655 |
-
'TensorFlow Datasets that are not available in the public TFDS GCS '
|
656 |
-
'bucket. Note that this flag overrides the `tfds_data_dir` attribute of '
|
657 |
-
'all `Task`s.')
|
658 |
-
|
659 |
-
flags.DEFINE_list(
|
660 |
-
'seqio_additional_cache_dirs', [],
|
661 |
-
'Directories to search for cached Tasks in addition to defaults.')
|
662 |
-
|
663 |
-
|
664 |
-
|
665 |
-
def main(argv: Sequence[str]):
|
666 |
-
"""Wrapper for pdb post mortems."""
|
667 |
-
_main(argv)
|
668 |
-
|
669 |
-
def _main(argv: Sequence[str]):
|
670 |
-
"""True main function."""
|
671 |
-
if len(argv) > 1:
|
672 |
-
raise app.UsageError('Too many command-line arguments.')
|
673 |
-
|
674 |
-
if FLAGS.tfds_data_dir:
|
675 |
-
seqio.set_tfds_data_dir_override(FLAGS.tfds_data_dir)
|
676 |
-
|
677 |
-
seqio.add_global_cache_dirs(FLAGS.seqio_additional_cache_dirs)
|
678 |
-
|
679 |
-
# Create gin-configurable version of `train`.
|
680 |
-
train_using_gin = gin.configurable(train)
|
681 |
-
|
682 |
-
gin_utils.parse_gin_flags(
|
683 |
-
# User-provided gin paths take precedence if relative paths conflict.
|
684 |
-
FLAGS.gin_search_paths + _DEFAULT_GIN_SEARCH_PATHS,
|
685 |
-
FLAGS.gin_file,
|
686 |
-
FLAGS.gin_bindings)
|
687 |
-
train_using_gin()
|
688 |
-
|
689 |
-
gin_utils.run(main)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
train/events.out.tfevents.1649073594.t1v-n-304587cf-w-0.1316481.0.v2
DELETED
@@ -1,3 +0,0 @@
|
|
1 |
-
version https://git-lfs.github.com/spec/v1
|
2 |
-
oid sha256:224a0411c5fc4e0e882c7a647ff554b58fec3f79dc12f9809b26b3a319225c1d
|
3 |
-
size 7585
|
|
|
|
|
|
|
|
train/events.out.tfevents.1649092520.t1v-n-304587cf-w-0.1399566.0.v2
DELETED
@@ -1,3 +0,0 @@
|
|
1 |
-
version https://git-lfs.github.com/spec/v1
|
2 |
-
oid sha256:d5a8424af960443ad6fbb097216b8add4fb9af5298424f440afb56a27ee260b9
|
3 |
-
size 16363
|
|
|
|
|
|
|
|
train/{events.out.tfevents.1649056216.t1v-n-304587cf-w-0.1239745.0.v2 → events.out.tfevents.1649705066.t1v-n-304587cf-w-0.2549834.0.v2}
RENAMED
@@ -1,3 +1,3 @@
|
|
1 |
version https://git-lfs.github.com/spec/v1
|
2 |
-
oid sha256:
|
3 |
-
size
|
|
|
1 |
version https://git-lfs.github.com/spec/v1
|
2 |
+
oid sha256:59fc20fda0f88a5e31f18b0ebc9497d4131b6458ac28d6d2a52875d1ba5c5b13
|
3 |
+
size 10402
|
training_eval/pretrain_finnish/events.out.tfevents.1649073594.t1v-n-304587cf-w-0.1316481.1.v2
DELETED
@@ -1,3 +0,0 @@
|
|
1 |
-
version https://git-lfs.github.com/spec/v1
|
2 |
-
oid sha256:7a4fbca2952ba8dbad98a5e752b6bcf0e80ea9d1376380bbbbd669b7fb0897e7
|
3 |
-
size 1431
|
|
|
|
|
|
|
|
training_eval/pretrain_finnish/events.out.tfevents.1649092520.t1v-n-304587cf-w-0.1399566.1.v2
DELETED
@@ -1,3 +0,0 @@
|
|
1 |
-
version https://git-lfs.github.com/spec/v1
|
2 |
-
oid sha256:b3efa0a2b6af0ef032441aaf0e97ac69bab14c84b90bbbfa22001dc094926fb2
|
3 |
-
size 9261
|
|
|
|
|
|
|
|
training_eval/pretrain_finnish/{events.out.tfevents.1649056216.t1v-n-304587cf-w-0.1239745.1.v2 → events.out.tfevents.1649705066.t1v-n-304587cf-w-0.2549834.1.v2}
RENAMED
@@ -1,3 +1,3 @@
|
|
1 |
version https://git-lfs.github.com/spec/v1
|
2 |
-
oid sha256:
|
3 |
-
size
|
|
|
1 |
version https://git-lfs.github.com/spec/v1
|
2 |
+
oid sha256:6ec5c452edec8f5036cab8e5f3d67f492e55a15ff004d4ee5847b2d2cd56f2df
|
3 |
+
size 4024
|