Spaces:
Running
Running
#!/usr/bin/env python | |
# coding=utf-8 | |
# Copyright 2021 The HuggingFace Team All rights reserved. | |
# | |
# Licensed under the Apache License, Version 2.0 (the "License"); | |
# you may not use this file except in compliance with the License. | |
# You may obtain a copy of the License at | |
# | |
# http://www.apache.org/licenses/LICENSE-2.0 | |
# | |
# Unless required by applicable law or agreed to in writing, software | |
# distributed under the License is distributed on an "AS IS" BASIS, | |
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |
# See the License for the specific language governing permissions and | |
# limitations under the License. | |
""" | |
Fine-tuning the library models for seq2seq, text to image. | |
Script adapted from run_summarization_flax.py | |
""" | |
import json | |
import logging | |
import os | |
import sys | |
import time | |
from dataclasses import asdict, dataclass, field | |
from pathlib import Path | |
from typing import Callable, Optional | |
import datasets | |
import jax | |
import jax.numpy as jnp | |
import optax | |
import transformers | |
import wandb | |
from datasets import Dataset | |
from flax import jax_utils, traverse_util | |
from flax.jax_utils import unreplicate | |
from flax.serialization import from_bytes, to_bytes | |
from flax.training import train_state | |
from flax.training.common_utils import get_metrics, onehot, shard_prng_key | |
from tqdm import tqdm | |
from transformers import AutoTokenizer, HfArgumentParser | |
from dalle_mini.data import Dataset | |
from dalle_mini.model import DalleBartConfig, DalleBart | |
# Module-level logger; its level is set per process inside main().
logger = logging.getLogger(__name__)
@dataclass
class ModelArguments:
    """
    Arguments pertaining to which model/config/tokenizer we are going to fine-tune, or train from scratch.

    Declared as a dataclass so that `HfArgumentParser` can turn every field into
    a command-line / JSON option and so that the values become real instance
    attributes instead of raw `dataclasses.Field` objects.
    """

    # Checkpoint used to initialise the weights; None means train from scratch.
    model_name_or_path: Optional[str] = field(
        default=None,
        metadata={
            "help": "The model checkpoint for weights initialization. "
            "Don't set if you want to train a model from scratch."
        },
    )
    # Config to load when it differs from `model_name_or_path`.
    config_name: Optional[str] = field(
        default=None,
        metadata={
            "help": "Pretrained config name or path if not the same as model_name"
        },
    )
    # Tokenizer to load when it differs from `model_name_or_path`.
    tokenizer_name: Optional[str] = field(
        default=None,
        metadata={
            "help": "Pretrained tokenizer name or path if not the same as model_name_or_path"
        },
    )
    # Name of a `jax.numpy` dtype; main() resolves it with `getattr(jnp, dtype)`.
    dtype: Optional[str] = field(
        default="float32",
        metadata={
            "help": "Floating-point format in which the model weights should be initialized and trained. Choose one of `[float32, float16, bfloat16]`."
        },
    )
@dataclass
class DataTrainingArguments:
    """
    Arguments pertaining to what data we are going to input our model for training and eval.

    Declared as a dataclass: main() expands an instance with `asdict(...)` and
    `__post_init__` is only invoked by the dataclass-generated `__init__`.

    Raises:
        ValueError: if `dataset_repo_or_path` is not provided.
    """

    text_column: Optional[str] = field(
        default="caption",
        metadata={
            "help": "The name of the column in the datasets containing the full texts (for summarization)."
        },
    )
    encoding_column: Optional[str] = field(
        default="encoding",
        metadata={
            "help": "The name of the column in the datasets containing the image encodings."
        },
    )
    # Defaults to None only so that the missing value can be reported by
    # __post_init__ with a clear message.
    dataset_repo_or_path: Optional[str] = field(
        default=None,
        metadata={"help": "The dataset repository containing encoded files."},
    )
    train_file: Optional[str] = field(
        default=None,
        metadata={"help": "The input training data file (glob acceptable)."},
    )
    validation_file: Optional[str] = field(
        default=None,
        metadata={"help": "An optional input evaluation data file (glob acceptable)."},
    )
    # data loading should not be a bottleneck so we use "streaming" mode by default
    streaming: bool = field(
        default=True,
        metadata={"help": "Whether to stream the dataset."},
    )
    use_auth_token: bool = field(
        default=False,
        metadata={
            "help": "Whether to use the authentication token for private datasets."
        },
    )
    max_train_samples: Optional[int] = field(
        default=None,
        metadata={
            "help": "For debugging purposes or quicker training, truncate the number of training examples to this "
            "value if set."
        },
    )
    max_eval_samples: Optional[int] = field(
        default=None,
        metadata={
            "help": "For debugging purposes or quicker training, truncate the number of evaluation examples to this "
            "value if set."
        },
    )
    preprocessing_num_workers: Optional[int] = field(
        default=None,
        metadata={
            "help": "The number of processes to use for the preprocessing. Not used in streaming mode."
        },
    )
    overwrite_cache: bool = field(
        default=False,
        metadata={
            "help": "Overwrite the cached training and evaluation sets. Not used in streaming mode."
        },
    )
    # default seed of None ensures we don't repeat the same items if script was interrupted during an epoch
    seed_dataset: Optional[int] = field(
        default=None,
        metadata={
            "help": "Random seed for the dataset that will be set at the beginning of training."
        },
    )

    def __post_init__(self):
        """Reject configurations that do not name a dataset."""
        if self.dataset_repo_or_path is None:
            raise ValueError("Need a dataset repository or path.")
@dataclass
class TrainingArguments:
    """
    Arguments pertaining to training parameters.

    Declared as a dataclass so that `HfArgumentParser` can expose each field
    as an option. `output_dir` has no default and is therefore required.
    """

    output_dir: str = field(
        metadata={
            "help": "The output directory where the model predictions and checkpoints will be written."
        },
    )
    overwrite_output_dir: bool = field(
        default=False,
        metadata={
            "help": (
                "Overwrite the content of the output directory. "
                "Use this to continue training if output_dir points to a checkpoint directory."
            )
        },
    )

    do_train: bool = field(default=False, metadata={"help": "Whether to run training."})
    do_eval: bool = field(
        default=False, metadata={"help": "Whether to run eval on the dev set."}
    )

    per_device_train_batch_size: int = field(
        default=8, metadata={"help": "Batch size per GPU/TPU core/CPU for training."}
    )
    per_device_eval_batch_size: int = field(
        default=8, metadata={"help": "Batch size per GPU/TPU core/CPU for evaluation."}
    )

    gradient_accumulation_steps: int = field(
        default=1,
        metadata={
            "help": "Number of updates steps to accumulate before performing a backward/update pass."
        },
    )

    learning_rate: float = field(
        default=5e-5, metadata={"help": "The initial learning rate."}
    )
    adafactor: bool = field(
        default=False,
        metadata={"help": "Whether or not to replace AdamW by Adafactor."},
    )
    # None means "no weight decay requested".
    weight_decay: Optional[float] = field(
        default=None, metadata={"help": "Weight decay if we apply some."}
    )
    adam_beta1: float = field(
        default=0.9, metadata={"help": "Beta1 for AdamW optimizer"}
    )
    adam_beta2: float = field(
        default=0.999, metadata={"help": "Beta2 for AdamW optimizer"}
    )
    adam_epsilon: float = field(
        default=1e-8, metadata={"help": "Epsilon for AdamW optimizer."}
    )
    max_grad_norm: float = field(
        default=1.0, metadata={"help": "Max gradient norm for Adafactor."}
    )
    use_decay: bool = field(
        default=False,
        metadata={"help": "Whether to use decay in the learning rate scheduler."},
    )

    num_train_epochs: float = field(
        default=3.0, metadata={"help": "Total number of training epochs to perform."}
    )
    warmup_steps: int = field(
        default=0, metadata={"help": "Linear warmup over warmup_steps."}
    )

    logging_steps: int = field(
        default=40, metadata={"help": "Log every X updates steps."}
    )
    eval_steps: int = field(
        default=400, metadata={"help": "Run an evaluation every X steps."}
    )
    save_steps: int = field(
        default=4000, metadata={"help": "Save checkpoint every X updates steps."}
    )
    log_model: bool = field(
        default=False,
        metadata={"help": "Log model to wandb at `save_steps` frequency."},
    )

    seed_model: int = field(
        default=42,
        metadata={
            "help": "Random seed for the model that will be set at the beginning of training."
        },
    )

    push_to_hub: bool = field(
        default=False,
        metadata={
            "help": "Whether or not to upload the trained model to the model hub after training."
        },
    )

    resume_from_checkpoint: Optional[str] = field(
        default=None,
        metadata={"help": "Reference to a wandb artifact for resuming training."},
    )
class TrainState(train_state.TrainState):
    """Flax train state extended with the dropout RNG and training bookkeeping."""

    dropout_rng: jnp.ndarray = None
    epoch: int = 0
    train_time: float = 0.0  # total time the model trained
    train_samples: int = 0  # number of samples seen

    def replicate(self):
        """Return a replicated copy of the state whose dropout RNG is sharded."""
        replicated = jax_utils.replicate(self)
        return replicated.replace(dropout_rng=shard_prng_key(self.dropout_rng))

    def restore_state(self, artifact_dir):
        """Return a copy of the state updated from the files in `artifact_dir`.

        Reads `opt_state.msgpack` (optimizer state, deserialized against the
        current `opt_state` as template) and `training_state.json` (step,
        train_time and train_samples). The stored epoch is not restored.
        """
        checkpoint_dir = Path(artifact_dir)

        # optimizer state
        opt_state_bytes = (checkpoint_dir / "opt_state.msgpack").read_bytes()
        restored_opt_state = from_bytes(self.opt_state, opt_state_bytes)

        # scalar training counters
        with (checkpoint_dir / "training_state.json").open("r") as handle:
            saved = json.load(handle)

        return self.replace(
            opt_state=restored_opt_state,
            step=saved["step"],
            train_time=saved["train_time"],
            train_samples=saved["train_samples"],
        )
def create_learning_rate_fn(
    num_warmup_steps: int,
    learning_rate: float,
    use_decay: bool,
    num_train_steps: Optional[int] = None,  # used only with `use_decay`, typically train_size // batch_size * num_epochs
) -> Callable[[int], jnp.ndarray]:
    """Returns a linear warmup, linear_decay learning rate function.

    Args:
        num_warmup_steps: number of steps over which the rate grows linearly
            from 0 to `learning_rate`.
        learning_rate: peak learning rate reached at the end of warmup.
        use_decay: when False only the warmup schedule is returned; when True
            the warmup is followed by a linear decay to 0.
        num_train_steps: total number of training steps; required when
            `use_decay` is True.

    Returns:
        A schedule mapping a step count to a learning rate.

    Raises:
        ValueError: if `use_decay` is True and `num_train_steps` is None.
    """
    # Explicit check instead of `assert`: assertions are stripped under -O and
    # the failure would then surface as an opaque `None - int` TypeError below.
    if use_decay and num_train_steps is None:
        raise ValueError("Learning rate with decay requires number of training steps")

    warmup_fn = optax.linear_schedule(
        init_value=0.0, end_value=learning_rate, transition_steps=num_warmup_steps
    )
    if not use_decay:
        return warmup_fn

    decay_fn = optax.linear_schedule(
        init_value=learning_rate,
        end_value=0,
        transition_steps=num_train_steps - num_warmup_steps,
    )
    # Switch from warmup to decay once `num_warmup_steps` is reached.
    return optax.join_schedules(
        schedules=[warmup_fn, decay_fn], boundaries=[num_warmup_steps]
    )
def wandb_log(metrics, step=None, prefix=None):
    """Send `metrics` to wandb from process 0 only.

    Keys are logged as `"{prefix}/{key}"` when `prefix` is given. When `step`
    is given it is added to the same payload under `"train/step"`.
    """
    if jax.process_index() != 0:
        return

    payload = {}
    for key, value in metrics.items():
        name = key if prefix is None else f"{prefix}/{key}"
        payload[name] = value

    if step is not None:
        payload["train/step"] = step
    wandb.log(payload)
def main():
    """Parse arguments, build the model and data pipeline, then train.

    Steps, in order: argument parsing (CLI flags or a single JSON file),
    logging setup, dataset creation, wandb run setup, model/tokenizer loading
    (from a wandb artifact when resuming, otherwise from the model arguments),
    optimizer and train-state creation, then the epoch loop with periodic
    logging, evaluation and checkpointing.

    Raises:
        ValueError: if the output directory exists, is not empty, training is
            requested and `overwrite_output_dir` is not set.
        AssertionError: if the number of visible devices is not 8.
    """
    # See all possible arguments by passing the --help flag to this script.
    parser = HfArgumentParser(
        (ModelArguments, DataTrainingArguments, TrainingArguments)
    )
    if len(sys.argv) == 2 and sys.argv[1].endswith(".json"):
        # If we pass only one argument to the script and it's the path to a json file,
        # let's parse it to get our arguments.
        model_args, data_args, training_args = parser.parse_json_file(
            json_file=os.path.abspath(sys.argv[1])
        )
    else:
        model_args, data_args, training_args = parser.parse_args_into_dataclasses()

    if (
        os.path.exists(training_args.output_dir)
        and os.listdir(training_args.output_dir)
        and training_args.do_train
        and not training_args.overwrite_output_dir
    ):
        raise ValueError(
            f"Output directory ({training_args.output_dir}) already exists and is not empty. "
            "Use --overwrite_output_dir to overcome."
        )

    # Make one log on every process with the configuration for debugging.
    logging.basicConfig(
        format="%(asctime)s - %(levelname)s - %(name)s - %(message)s",
        datefmt="%m/%d/%Y %H:%M:%S",
        level=logging.INFO,
    )
    # Setup logging, we only want one process per machine to log things on the screen.
    logger.setLevel(logging.INFO if jax.process_index() == 0 else logging.ERROR)
    if jax.process_index() == 0:
        datasets.utils.logging.set_verbosity_warning()
        transformers.utils.logging.set_verbosity_info()
    else:
        datasets.utils.logging.set_verbosity_error()
        transformers.utils.logging.set_verbosity_error()

    logger.info(f"TPUs: {jax.device_count()}")
    # NOTE(review): hard-coded device count; presumably targets an 8-core TPU
    # host where fewer devices means another process holds them — confirm.
    assert jax.device_count() == 8, "TPUs in use, please check running processes"

    logger.info(f"Training/evaluation parameters {training_args}")

    # Load dataset
    dataset = Dataset(
        **asdict(data_args),
        do_train=training_args.do_train,
        do_eval=training_args.do_eval,
    )

    # Set up wandb run
    # The config is built from the already-parsed dataclasses: re-parsing
    # sys.argv here would fail when the arguments came from a JSON file.
    wandb.init(
        entity="dalle-mini",
        project="dalle-mini",
        job_type="Seq2Seq",
        config={
            **asdict(model_args),
            **asdict(data_args),
            **asdict(training_args),
        },
    )

    if training_args.resume_from_checkpoint is not None:
        artifact = wandb.run.use_artifact(training_args.resume_from_checkpoint)
        artifact_dir = artifact.download()

        # load model
        model = DalleBart.from_pretrained(artifact_dir)
        # avoid OOM on TPU: see https://github.com/google/flax/issues/1658
        print(model.params)

        # load tokenizer
        tokenizer = AutoTokenizer.from_pretrained(
            artifact_dir,
            use_fast=True,
        )

    else:
        # Set up our new model config
        if model_args.config_name:
            config = DalleBartConfig.from_pretrained(model_args.config_name)
        else:
            config = DalleBartConfig.from_pretrained(model_args.model_name_or_path)

        # Load or create new model
        if model_args.model_name_or_path:
            model = DalleBart.from_pretrained(
                model_args.model_name_or_path,
                config=config,
                seed=training_args.seed_model,
                dtype=getattr(jnp, model_args.dtype),
            )
            # avoid OOM on TPU: see https://github.com/google/flax/issues/1658
            print(model.params)
        else:
            model = DalleBart(
                config,
                seed=training_args.seed_model,
                dtype=getattr(jnp, model_args.dtype),
            )

        # Load tokenizer
        if model_args.tokenizer_name is not None:
            tokenizer = AutoTokenizer.from_pretrained(
                model_args.tokenizer_name, use_fast=True
            )
        else:
            tokenizer = AutoTokenizer.from_pretrained(
                model_args.model_name_or_path,
                use_fast=True,
            )

    # Preprocessing the datasets.
    # We need to normalize and tokenize inputs and targets.
    dataset.preprocess(
        tokenizer=tokenizer,
        decoder_start_token_id=model.config.decoder_start_token_id,
        normalize_text=model.config.normalize_text,
        max_length=model.config.max_text_length,
    )

    # Initialize our training
    rng = jax.random.PRNGKey(training_args.seed_model)
    rng, dropout_rng = jax.random.split(rng)

    # Store some constant
    num_epochs = int(training_args.num_train_epochs)
    train_batch_size = (
        int(training_args.per_device_train_batch_size) * jax.device_count()
    )
    batch_size_per_update = train_batch_size * training_args.gradient_accumulation_steps
    eval_batch_size = int(training_args.per_device_eval_batch_size) * jax.device_count()
    # Either length may be None (handled below), in which case the derived
    # step counts are None as well.
    len_train_dataset, len_eval_dataset = dataset.length
    steps_per_epoch = (
        len_train_dataset // train_batch_size if len_train_dataset is not None else None
    )
    num_train_steps = (
        steps_per_epoch * num_epochs if steps_per_epoch is not None else None
    )
    num_params = model.num_params

    # Create learning rate schedule
    learning_rate_fn = create_learning_rate_fn(
        training_args.warmup_steps,
        training_args.learning_rate,
        training_args.use_decay,
        num_train_steps,
    )

    # We use Optax's "masking" functionality to not apply weight decay
    # to bias and LayerNorm scale parameters. decay_mask_fn returns a
    # mask boolean with the same structure as the parameters.
    # The mask is True for parameters that should be decayed.
    # Note that this mask is specifically adapted for FlaxBart.
    def decay_mask_fn(params):
        flat_params = traverse_util.flatten_dict(params)
        layer_norm_params = [
            (name, "scale")
            for name in [
                "self_attn_layer_norm",
                "layernorm_embedding",
                "final_layer_norm",
            ]
        ]
        flat_mask = {
            path: (path[-1] != "bias" and path[-2:] not in layer_norm_params)
            for path in flat_params
        }
        return traverse_util.unflatten_dict(flat_mask)

    # create the optimizer
    if training_args.adafactor:
        # We use the default parameters here to initialize adafactor,
        # For more details about the parameters please check https://github.com/deepmind/optax/blob/ed02befef9bf81cbbf236be3d2b0e032e9ed4a40/optax/_src/alias.py#L74
        optimizer = optax.adafactor(
            learning_rate=learning_rate_fn,
            weight_decay_rate=training_args.weight_decay,
            weight_decay_mask=decay_mask_fn,
            clipping_threshold=training_args.max_grad_norm,
        )
    else:
        # adamw needs a numeric weight decay; the argument defaults to None,
        # which is treated here as "no decay".
        adamw_weight_decay = (
            training_args.weight_decay
            if training_args.weight_decay is not None
            else 0.0
        )
        optimizer = optax.adamw(
            learning_rate=learning_rate_fn,
            b1=training_args.adam_beta1,
            b2=training_args.adam_beta2,
            eps=training_args.adam_epsilon,
            weight_decay=adamw_weight_decay,
            mask=decay_mask_fn,
        )

    # add gradient accumulation
    if training_args.gradient_accumulation_steps > 1:
        optimizer = optax.chain(
            optax.apply_every(training_args.gradient_accumulation_steps), optimizer
        )

    # Setup train state
    state = TrainState.create(
        apply_fn=model.__call__,
        params=model.params,
        tx=optimizer,
        dropout_rng=dropout_rng,
    )
    if training_args.resume_from_checkpoint is not None:
        # restore optimizer state and other parameters
        # we currently ignore partial epoch training: see https://github.com/borisdayma/dalle-mini/issues/105
        state = state.restore_state(artifact_dir)

    # cross entropy averaged over all positions
    def loss_fn(logits, labels):
        loss = optax.softmax_cross_entropy(logits, onehot(labels, logits.shape[-1]))
        loss = loss.mean()
        return loss

    # Define gradient update step fn
    def train_step(state, batch, delta_time):
        dropout_rng, new_dropout_rng = jax.random.split(state.dropout_rng)

        def compute_loss(params, batch):
            labels = batch.pop("labels")
            logits = state.apply_fn(
                **batch, params=params, dropout_rng=dropout_rng, train=True
            )[0]
            loss = loss_fn(logits, labels)
            return loss

        grad_fn = jax.value_and_grad(compute_loss)
        loss, grads = grad_fn(state.params, batch)
        grads = jax.lax.pmean(grads, "batch")
        state = state.apply_gradients(
            grads=grads,
            dropout_rng=new_dropout_rng,
            train_time=state.train_time + delta_time,
            train_samples=state.train_samples + train_batch_size,
        )

        metrics = {
            "loss": loss,
            "learning_rate": learning_rate_fn(state.step),
        }
        metrics = jax.lax.pmean(metrics, axis_name="batch")

        return state, metrics

    # Define eval fn
    def eval_step(params, batch):
        labels = batch.pop("labels")
        logits = model(**batch, params=params, train=False)[0]
        loss = loss_fn(logits, labels)

        # summarize metrics
        metrics = {"loss": loss}
        metrics = jax.lax.pmean(metrics, axis_name="batch")
        return metrics

    # Create parallel version of the train and eval step
    p_train_step = jax.pmap(train_step, "batch", donate_argnums=(0,))
    p_eval_step = jax.pmap(eval_step, "batch")

    logger.info("***** Running training *****")
    logger.info(f" Num examples = {len_train_dataset}")
    logger.info(f" Num Epochs = {num_epochs}")
    logger.info(
        f" Instantaneous batch size per device = {training_args.per_device_train_batch_size}"
    )
    logger.info(
        f" Total train batch size (w. parallel, distributed & gradient accumulation) = {batch_size_per_update}"
    )
    logger.info(f" Model parameters = {num_params:,}")
    epochs = tqdm(
        range(state.epoch, num_epochs), desc=f"Epoch ... (1/{num_epochs})", position=0
    )

    # set default x-axis as 'train/step'
    wandb_log({}, step=state.step)
    wandb.define_metric("*", step_metric="train/step")

    # add interesting config parameters
    wandb.config.update(
        {
            "len_train_dataset": len_train_dataset,
            "len_eval_dataset": len_eval_dataset,
            "batch_size_per_update": batch_size_per_update,
            "num_params": num_params,
        }
    )

    # replicate state on each device
    state = state.replicate()

    def run_evaluation():
        # ======================== Evaluating ==============================
        # Reads `state` and `epoch` from the enclosing scope at call time.
        # Returns the averaged eval metrics, or None when do_eval is off.
        eval_metrics = []
        if training_args.do_eval:
            eval_loader = dataset.dataloader("eval", eval_batch_size)
            eval_steps = (
                len_eval_dataset // eval_batch_size
                if len_eval_dataset is not None
                else None
            )
            for batch in tqdm(
                eval_loader,
                desc="Evaluating...",
                position=2,
                leave=False,
                total=eval_steps,
            ):
                # Model forward
                metrics = p_eval_step(state.params, batch)
                eval_metrics.append(metrics)

            # normalize eval metrics
            eval_metrics = get_metrics(eval_metrics)
            eval_metrics = jax.tree_map(jnp.mean, eval_metrics)

            # log metrics
            wandb_log(eval_metrics, step=unreplicate(state.step), prefix="eval")

            # Print metrics and update progress bar
            desc = f"Epoch... ({epoch + 1}/{num_epochs} | Eval Loss: {eval_metrics['loss']})"
            epochs.write(desc)
            epochs.desc = desc

            return eval_metrics
        return None

    def run_save_model(state, eval_metrics=None):
        # Only process 0 writes checkpoints.
        if jax.process_index() == 0:
            output_dir = Path(training_args.output_dir)
            params = jax.device_get(unreplicate(state.params))
            # save model locally
            model.save_pretrained(
                training_args.output_dir,
                params=params,
            )

            # save tokenizer
            tokenizer.save_pretrained(training_args.output_dir)

            # save state
            opt_state = unreplicate(state.opt_state)
            with (output_dir / "opt_state.msgpack").open("wb") as f:
                f.write(to_bytes(opt_state))
            state_dict = {
                k: jax.device_get(unreplicate(getattr(state, k))).item()
                for k in ["step", "epoch", "train_time", "train_samples"]
            }
            with (output_dir / "training_state.json").open("w") as f:
                json.dump(
                    state_dict,
                    f,
                )

            # save to W&B
            if training_args.log_model:
                # save some space
                c = wandb.wandb_sdk.wandb_artifacts.get_artifacts_cache()
                c.cleanup(wandb.util.from_human_size("10GB"))

                metadata = dict(state_dict)
                metadata["num_params"] = num_params
                if eval_metrics is not None:
                    metadata["eval"] = eval_metrics
                artifact = wandb.Artifact(
                    name=f"model-{wandb.run.id}", type="bart_model", metadata=metadata
                )
                # Files written above by the model, the tokenizer and this function.
                for filename in (
                    "flax_model.msgpack",
                    "config.json",
                    "tokenizer.json",
                    "tokenizer_config.json",
                    "vocab.json",
                    "merges.txt",
                    "special_tokens_map.json",
                    "opt_state.msgpack",
                    "training_state.json",
                ):
                    artifact.add_file(str(output_dir / filename))
                wandb.run.log_artifact(artifact)

            # save to the hub
            if training_args.push_to_hub:
                model.save_pretrained(
                    training_args.output_dir,
                    params=params,
                    push_to_hub=training_args.push_to_hub,
                    commit_message=f"Saving weights and logs at step {unreplicate(state.step)+1}",
                    temp_dir=True,  # avoid issues with being in a repository
                )

    # init variables
    last_time = time.perf_counter()
    train_metrics = None

    for epoch in epochs:
        # `replace` returns a new state; the result must be kept, otherwise
        # the epoch is never recorded in the state or in saved checkpoints.
        state = state.replace(epoch=jax_utils.replicate(epoch))
        # ======================== Training ================================
        wandb_log({"train/epoch": epoch}, step=unreplicate(state.step))

        # Generate an epoch by shuffling sampling indices from the train dataset
        train_loader = dataset.dataloader("train", train_batch_size)
        # train
        for batch in tqdm(
            train_loader,
            desc="Training...",
            position=1,
            leave=False,
            total=steps_per_epoch,
        ):

            # calculate delta time (we have a lag of one step but it's ok)
            new_time = time.perf_counter()
            delta_time = new_time - last_time
            last_time = new_time

            # train step
            state, train_metrics = p_train_step(
                state, batch, jax_utils.replicate(delta_time)
            )
            step = unreplicate(state.step)

            if step % training_args.logging_steps == 0 and jax.process_index() == 0:
                # log metrics
                metrics = unreplicate(train_metrics)
                # log state parameters
                state_dict = {
                    k.split("_")[-1]: unreplicate(getattr(state, k))
                    for k in ["epoch", "train_time", "train_samples"]
                }
                wandb_log({**metrics, **state_dict}, step=step, prefix="train")

            eval_metrics = None
            if training_args.eval_steps and step % training_args.eval_steps == 0:
                eval_metrics = run_evaluation()

            if step % training_args.save_steps == 0:
                run_save_model(state, eval_metrics)

        # log final train metrics
        if train_metrics is not None:
            train_metrics = unreplicate(train_metrics)
            wandb_log(train_metrics, step=step, prefix="train")

            epochs.write(
                f"Epoch... ({epoch + 1}/{num_epochs} | Loss: {train_metrics['loss']}, Learning Rate: {train_metrics['learning_rate']})"
            )

        # Final evaluation
        eval_metrics = run_evaluation()

        # save checkpoint after each epoch
        run_save_model(state, eval_metrics)
# Run training only when this file is executed as a script.
if __name__ == "__main__":
    main()