# coding=utf-8 | |
# Copyright 2020-present the HuggingFace Inc. team. | |
# | |
# Licensed under the Apache License, Version 2.0 (the "License"); | |
# you may not use this file except in compliance with the License. | |
# You may obtain a copy of the License at | |
# | |
# http://www.apache.org/licenses/LICENSE-2.0 | |
# | |
# Unless required by applicable law or agreed to in writing, software | |
# distributed under the License is distributed on an "AS IS" BASIS, | |
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |
# See the License for the specific language governing permissions and | |
# limitations under the License. | |
""" | |
The Trainer class, to easily train a 🤗 Transformers model from scratch or fine-tune it on a new task.
""" | |
import contextlib | |
import functools | |
import glob | |
import inspect | |
import math | |
import os | |
import random | |
import re | |
import shutil | |
import sys | |
import time | |
import warnings | |
from collections.abc import Mapping | |
from pathlib import Path | |
from typing import TYPE_CHECKING, Any, Callable, Dict, List, Optional, Tuple, Union | |
from tqdm.auto import tqdm | |
# Integrations must be imported before ML frameworks: | |
# isort: off | |
from transformers.integrations import ( | |
default_hp_search_backend, | |
get_reporting_integration_callbacks, | |
hp_params, | |
is_fairscale_available, | |
is_optuna_available, | |
is_ray_tune_available, | |
is_sigopt_available, | |
is_wandb_available, | |
run_hp_search_optuna, | |
run_hp_search_ray, | |
run_hp_search_sigopt, | |
run_hp_search_wandb, | |
) | |
# isort: on | |
import numpy as np | |
import torch | |
import torch.distributed as dist | |
from huggingface_hub import Repository, create_repo | |
from packaging import version | |
from torch import nn | |
from torch.utils.data import DataLoader, Dataset, RandomSampler, SequentialSampler | |
from torch.utils.data.distributed import DistributedSampler | |
from transformers import __version__ | |
from transformers.configuration_utils import PretrainedConfig | |
from transformers.data.data_collator import DataCollator, DataCollatorWithPadding, default_data_collator | |
from transformers.debug_utils import DebugOption, DebugUnderflowOverflow | |
from transformers.deepspeed import deepspeed_init, deepspeed_load_checkpoint, is_deepspeed_zero3_enabled | |
from transformers.dependency_versions_check import dep_version_check | |
from transformers.modelcard import TrainingSummary | |
from transformers.modeling_utils import PreTrainedModel, load_sharded_checkpoint, unwrap_model | |
from transformers.models.auto.modeling_auto import MODEL_FOR_CAUSAL_LM_MAPPING_NAMES, MODEL_MAPPING_NAMES | |
from transformers.optimization import Adafactor, get_scheduler | |
from transformers.pytorch_utils import ALL_LAYERNORM_LAYERS, is_torch_greater_or_equal_than_1_10, is_torch_less_than_1_11 | |
from transformers.tokenization_utils_base import PreTrainedTokenizerBase | |
from transformers.trainer_callback import ( | |
CallbackHandler, | |
DefaultFlowCallback, | |
PrinterCallback, | |
ProgressCallback, | |
TrainerCallback, | |
TrainerControl, | |
TrainerState, | |
) | |
from transformers.trainer_pt_utils import ( | |
DistributedLengthGroupedSampler, | |
DistributedSamplerWithLoop, | |
DistributedTensorGatherer, | |
IterableDatasetShard, | |
LabelSmoother, | |
LengthGroupedSampler, | |
SequentialDistributedSampler, | |
ShardSampler, | |
distributed_broadcast_scalars, | |
distributed_concat, | |
find_batch_size, | |
get_model_param_count, | |
get_module_class_from_name, | |
get_parameter_names, | |
nested_concat, | |
nested_detach, | |
nested_numpify, | |
nested_truncate, | |
nested_xla_mesh_reduce, | |
reissue_pt_warnings, | |
) | |
from transformers.trainer_utils import ( | |
PREFIX_CHECKPOINT_DIR, | |
BestRun, | |
EvalLoopOutput, | |
EvalPrediction, | |
FSDPOption, | |
HPSearchBackend, | |
HubStrategy, | |
IntervalStrategy, | |
PredictionOutput, | |
RemoveColumnsCollator, | |
ShardedDDPOption, | |
TrainerMemoryTracker, | |
TrainOutput, | |
default_compute_objective, | |
default_hp_space, | |
denumpify_detensorize, | |
enable_full_determinism, | |
find_executable_batch_size, | |
get_last_checkpoint, | |
has_length, | |
number_of_arguments, | |
seed_worker, | |
set_seed, | |
speed_metrics, | |
) | |
from transformers.training_args import OptimizerNames, ParallelMode, TrainingArguments | |
from transformers.utils import ( | |
ADAPTER_SAFE_WEIGHTS_NAME, | |
ADAPTER_WEIGHTS_NAME, | |
CONFIG_NAME, | |
SAFE_WEIGHTS_INDEX_NAME, | |
SAFE_WEIGHTS_NAME, | |
WEIGHTS_INDEX_NAME, | |
WEIGHTS_NAME, | |
can_return_loss, | |
find_labels, | |
get_full_repo_name, | |
is_accelerate_available, | |
is_apex_available, | |
is_datasets_available, | |
is_in_notebook, | |
is_ipex_available, | |
is_peft_available, | |
is_safetensors_available, | |
is_sagemaker_dp_enabled, | |
is_sagemaker_mp_enabled, | |
is_torch_compile_available, | |
is_torch_neuroncore_available, | |
is_torch_tpu_available, | |
logging, | |
strtobool, | |
) | |
from transformers.utils.generic import ContextManagers | |
_is_native_cpu_amp_available = is_torch_greater_or_equal_than_1_10 | |
DEFAULT_CALLBACKS = [DefaultFlowCallback] | |
DEFAULT_PROGRESS_CALLBACK = ProgressCallback | |
if is_in_notebook(): | |
from transformers.utils.notebook import NotebookProgressCallback | |
DEFAULT_PROGRESS_CALLBACK = NotebookProgressCallback | |
if is_apex_available(): | |
from apex import amp | |
if is_datasets_available(): | |
import datasets | |
if is_torch_tpu_available(check_device=False): | |
import torch_xla.core.xla_model as xm | |
import torch_xla.debug.metrics as met | |
import torch_xla.distributed.parallel_loader as pl | |
if is_fairscale_available(): | |
dep_version_check("fairscale") | |
import fairscale | |
from fairscale.nn.data_parallel import FullyShardedDataParallel as FullyShardedDDP | |
from fairscale.nn.data_parallel import ShardedDataParallel as ShardedDDP | |
from fairscale.nn.wrap import auto_wrap | |
from fairscale.optim import OSS | |
from fairscale.optim.grad_scaler import ShardedGradScaler | |
if is_sagemaker_mp_enabled(): | |
import smdistributed.modelparallel.torch as smp | |
from smdistributed.modelparallel import __version__ as SMP_VERSION | |
IS_SAGEMAKER_MP_POST_1_10 = version.parse(SMP_VERSION) >= version.parse("1.10") | |
from transformers.trainer_pt_utils import smp_forward_backward, smp_forward_only, smp_gather, smp_nested_concat | |
else: | |
IS_SAGEMAKER_MP_POST_1_10 = False | |
if is_safetensors_available(): | |
import safetensors.torch | |
if is_peft_available(): | |
from peft import PeftModel | |
skip_first_batches = None | |
if is_accelerate_available(): | |
from accelerate import __version__ as accelerate_version | |
if version.parse(accelerate_version) >= version.parse("0.16"): | |
from accelerate import skip_first_batches | |
from accelerate import Accelerator | |
from accelerate.utils import DistributedDataParallelKwargs | |
if TYPE_CHECKING: | |
import optuna | |
logger = logging.get_logger(__name__) | |
# Name of the files used for checkpointing | |
TRAINING_ARGS_NAME = "training_args.bin" | |
TRAINER_STATE_NAME = "trainer_state.json" | |
OPTIMIZER_NAME = "optimizer.pt" | |
SCHEDULER_NAME = "scheduler.pt" | |
SCALER_NAME = "scaler.pt" | |
class Trainer: | |
""" | |
Trainer is a simple but feature-complete training and eval loop for PyTorch, optimized for 🤗 Transformers. | |
Args: | |
model ([`PreTrainedModel`] or `torch.nn.Module`, *optional*): | |
The model to train, evaluate or use for predictions. If not provided, a `model_init` must be passed. | |
<Tip> | |
[`Trainer`] is optimized to work with the [`PreTrainedModel`] provided by the library. You can still use | |
your own models defined as `torch.nn.Module` as long as they work the same way as the 🤗 Transformers | |
models. | |
</Tip> | |
args ([`TrainingArguments`], *optional*): | |
The arguments to tweak for training. Will default to a basic instance of [`TrainingArguments`] with the | |
`output_dir` set to a directory named *tmp_trainer* in the current directory if not provided. | |
data_collator (`DataCollator`, *optional*): | |
The function to use to form a batch from a list of elements of `train_dataset` or `eval_dataset`. Will | |
default to [`default_data_collator`] if no `tokenizer` is provided, an instance of | |
[`DataCollatorWithPadding`] otherwise. | |
train_dataset (`torch.utils.data.Dataset` or `torch.utils.data.IterableDataset`, *optional*): | |
The dataset to use for training. If it is a [`~datasets.Dataset`], columns not accepted by the | |
`model.forward()` method are automatically removed. | |
Note that if it's a `torch.utils.data.IterableDataset` with some randomization and you are training in a | |
distributed fashion, your iterable dataset should either use an internal attribute `generator` that is a
`torch.Generator` for the randomization that must be identical on all processes (and the Trainer will | |
manually set the seed of this `generator` at each epoch) or have a `set_epoch()` method that internally | |
sets the seed of the RNGs used. | |
eval_dataset (Union[`torch.utils.data.Dataset`, Dict[str, `torch.utils.data.Dataset`]), *optional*): | |
The dataset to use for evaluation. If it is a [`~datasets.Dataset`], columns not accepted by the | |
`model.forward()` method are automatically removed. If it is a dictionary, it will evaluate on each | |
dataset prepending the dictionary key to the metric name. | |
tokenizer ([`PreTrainedTokenizerBase`], *optional*): | |
The tokenizer used to preprocess the data. If provided, will be used to automatically pad the inputs to the | |
maximum length when batching inputs, and it will be saved along the model to make it easier to rerun an | |
interrupted training or reuse the fine-tuned model. | |
model_init (`Callable[[], PreTrainedModel]`, *optional*): | |
A function that instantiates the model to be used. If provided, each call to [`~Trainer.train`] will start | |
from a new instance of the model as given by this function. | |
The function may have zero arguments, or a single one containing the optuna/Ray Tune/SigOpt trial object, to
be able to choose different architectures according to hyperparameters (such as layer count, sizes of
inner layers, dropout probabilities, etc.).
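For example, a minimal sketch using an Optuna trial (the checkpoint name and search space below are purely
illustrative):

```python
def model_init(trial):
    dropout = 0.1 if trial is None else trial.suggest_float("dropout", 0.1, 0.3)
    return AutoModelForSequenceClassification.from_pretrained(
        "bert-base-uncased", hidden_dropout_prob=dropout
    )
```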
compute_metrics (`Callable[[EvalPrediction], Dict]`, *optional*): | |
The function that will be used to compute metrics at evaluation. Must take an [`EvalPrediction`] and return
a dictionary mapping metric names (strings) to metric values.
callbacks (List of [`TrainerCallback`], *optional*): | |
A list of callbacks to customize the training loop. Will add those to the list of default callbacks | |
detailed in [here](callback). | |
If you want to remove one of the default callbacks used, use the [`Trainer.remove_callback`] method. | |
optimizers (`Tuple[torch.optim.Optimizer, torch.optim.lr_scheduler.LambdaLR]`, *optional*): A tuple | |
containing the optimizer and the scheduler to use. Will default to an instance of [`AdamW`] on your model | |
and a scheduler given by [`get_linear_schedule_with_warmup`] controlled by `args`. | |
preprocess_logits_for_metrics (`Callable[[torch.Tensor, torch.Tensor], torch.Tensor]`, *optional*): | |
A function that preprocesses the logits right before caching them at each evaluation step. Must take two
tensors, the logits and the labels, and return the logits once processed as desired. The modifications made | |
by this function will be reflected in the predictions received by `compute_metrics`. | |
Note that the labels (second parameter) will be `None` if the dataset does not have them. | |
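For example, a common memory-saving sketch is `preprocess_logits_for_metrics=lambda logits, labels:
logits.argmax(dim=-1)`, which caches only the predicted class ids instead of the full logits.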
Important attributes: | |
- **model** -- Always points to the core model. If using a transformers model, it will be a [`PreTrainedModel`] | |
subclass. | |
- **model_wrapped** -- Always points to the most external model in case one or more other modules wrap the | |
original model. This is the model that should be used for the forward pass. For example, under `DeepSpeed`, | |
the inner model is wrapped in `DeepSpeed` and then again in `torch.nn.DistributedDataParallel`. If the inner | |
model hasn't been wrapped, then `self.model_wrapped` is the same as `self.model`. | |
- **is_model_parallel** -- Whether or not a model has been switched to a model parallel mode (different from | |
data parallelism, this means some of the model layers are split on different GPUs). | |
- **place_model_on_device** -- Whether or not to automatically place the model on the device - it will be set | |
to `False` if model parallel or deepspeed is used, or if the default | |
`TrainingArguments.place_model_on_device` is overridden to return `False` . | |
- **is_in_train** -- Whether or not a model is currently running `train` (e.g. when `evaluate` is called while | |
in `train`) | |
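Example (a minimal training sketch; the checkpoint name is illustrative and `train_dataset` is assumed to be a
dataset you have already tokenized):

```python
from transformers import AutoModelForSequenceClassification, AutoTokenizer, Trainer, TrainingArguments

model = AutoModelForSequenceClassification.from_pretrained("bert-base-uncased", num_labels=2)
tokenizer = AutoTokenizer.from_pretrained("bert-base-uncased")
args = TrainingArguments(output_dir="my_model", per_device_train_batch_size=8, num_train_epochs=1)

trainer = Trainer(model=model, args=args, train_dataset=train_dataset, tokenizer=tokenizer)
trainer.train()
trainer.save_model()
```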
""" | |
from transformers.trainer_pt_utils import _get_learning_rate, log_metrics, metrics_format, save_metrics, save_state | |
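# The `trainer_pt_utils` helpers imported at class level above become methods of `Trainer`
# (e.g. `trainer.log_metrics(...)`, `trainer.save_metrics(...)`, `trainer.save_state()`).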
def __init__( | |
self, | |
model: Union[PreTrainedModel, nn.Module] = None, | |
args: TrainingArguments = None, | |
data_collator: Optional[DataCollator] = None, | |
train_dataset: Optional[Dataset] = None, | |
eval_dataset: Optional[Union[Dataset, Dict[str, Dataset]]] = None, | |
tokenizer: Optional[PreTrainedTokenizerBase] = None, | |
model_init: Optional[Callable[[], PreTrainedModel]] = None, | |
compute_metrics: Optional[Callable[[EvalPrediction], Dict]] = None, | |
callbacks: Optional[List[TrainerCallback]] = None, | |
optimizers: Tuple[torch.optim.Optimizer, torch.optim.lr_scheduler.LambdaLR] = (None, None), | |
preprocess_logits_for_metrics: Optional[Callable[[torch.Tensor, torch.Tensor], torch.Tensor]] = None, | |
): | |
if args is None: | |
output_dir = "tmp_trainer" | |
logger.info(f"No `TrainingArguments` passed, using `output_dir={output_dir}`.") | |
args = TrainingArguments(output_dir=output_dir) | |
self.args = args | |
# Seed must be set before instantiating the model when using model_init.
enable_full_determinism(self.args.seed) if self.args.full_determinism else set_seed(self.args.seed) | |
self.hp_name = None | |
self.is_in_train = False | |
self.create_accelerator_and_postprocess() | |
# memory metrics - must set up as early as possible | |
self._memory_tracker = TrainerMemoryTracker(self.args.skip_memory_metrics) | |
self._memory_tracker.start() | |
# set the correct log level depending on the node | |
log_level = args.get_process_log_level() | |
logging.set_verbosity(log_level) | |
# force device and distributed setup init explicitly | |
args._setup_devices | |
if model is None: | |
if model_init is not None: | |
self.model_init = model_init | |
model = self.call_model_init() | |
else: | |
raise RuntimeError("`Trainer` requires either a `model` or `model_init` argument") | |
else: | |
if model_init is not None: | |
warnings.warn( | |
"`Trainer` requires either a `model` or `model_init` argument, but not both. `model_init` will" | |
" overwrite your model when calling the `train` method. This will become a fatal error in the next" | |
" release.", | |
FutureWarning, | |
) | |
self.model_init = model_init | |
if model.__class__.__name__ in MODEL_MAPPING_NAMES: | |
raise ValueError( | |
f"The model you have picked ({model.__class__.__name__}) cannot be used as is for training: it only " | |
"computes hidden states and does not accept any labels. You should choose a model with a head " | |
"suitable for your task like any of the `AutoModelForXxx` listed at " | |
"https://huggingface.co/docs/transformers/model_doc/auto." | |
) | |
if hasattr(model, "is_parallelizable") and model.is_parallelizable and model.model_parallel: | |
self.is_model_parallel = True | |
else: | |
self.is_model_parallel = False | |
if getattr(model, "hf_device_map", None) is not None: | |
devices = [device for device in set(model.hf_device_map.values()) if device not in ["cpu", "disk"]] | |
if len(devices) > 1: | |
self.is_model_parallel = True | |
else: | |
self.is_model_parallel = self.args.device != torch.device(devices[0]) | |
# warn users | |
logger.info( | |
"You have loaded a model on multiple GPUs. `is_model_parallel` attribute will be force-set" | |
" to `True` to avoid any unexpected behavior such as device placement mismatching." | |
) | |
# At this stage the model is already loaded | |
if getattr(model, "is_quantized", False): | |
if getattr(model, "_is_quantized_training_enabled", False): | |
logger.info(
"The model is loaded in 8-bit precision. To train this model you need to add additional modules"
" inside the model such as adapters using the `peft` library and freeze the model weights. Please"
" check the examples in https://github.com/huggingface/peft for more details."
) | |
else: | |
raise ValueError(
"The model you want to train is loaded in 8-bit precision. If you want to fine-tune an 8-bit"
" model, please make sure that you have installed `bitsandbytes>=0.37.0`."
) | |
# Setup Sharded DDP training | |
self.sharded_ddp = None | |
if len(args.sharded_ddp) > 0: | |
if self.is_deepspeed_enabled: | |
raise ValueError( | |
"Using --sharded_ddp xxx together with --deepspeed is not possible, deactivate one of those flags." | |
) | |
if len(args.fsdp) > 0: | |
raise ValueError( | |
"Using --sharded_ddp xxx together with --fsdp is not possible, deactivate one of those flags." | |
) | |
if args.parallel_mode != ParallelMode.DISTRIBUTED: | |
raise ValueError("Using sharded DDP only works in distributed training.") | |
elif not is_fairscale_available(): | |
raise ImportError("Sharded DDP training requires fairscale: `pip install fairscale`.") | |
elif ShardedDDPOption.SIMPLE not in args.sharded_ddp and FullyShardedDDP is None: | |
raise ImportError( | |
"Sharded DDP in a mode other than simple training requires fairscale version >= 0.3, found " | |
f"{fairscale.__version__}. Upgrade your fairscale library: `pip install --upgrade fairscale`." | |
) | |
elif ShardedDDPOption.SIMPLE in args.sharded_ddp: | |
self.sharded_ddp = ShardedDDPOption.SIMPLE | |
elif ShardedDDPOption.ZERO_DP_2 in args.sharded_ddp: | |
self.sharded_ddp = ShardedDDPOption.ZERO_DP_2 | |
elif ShardedDDPOption.ZERO_DP_3 in args.sharded_ddp: | |
self.sharded_ddp = ShardedDDPOption.ZERO_DP_3 | |
self.fsdp = None | |
if len(args.fsdp) > 0: | |
if self.is_deepspeed_enabled: | |
raise ValueError( | |
"Using --fsdp xxx together with --deepspeed is not possible, deactivate one of those flags." | |
) | |
if not args.fsdp_config["xla"] and args.parallel_mode != ParallelMode.DISTRIBUTED: | |
raise ValueError("Using fsdp only works in distributed training.") | |
# dep_version_check("torch>=1.12.0") | |
# Would have to update setup.py with torch>=1.12.0 | |
# which isn't ideal, given that it would force people not using FSDP to also use torch>=1.12.0
# below is the current alternative. | |
if version.parse(version.parse(torch.__version__).base_version) < version.parse("1.12.0"): | |
raise ValueError("FSDP requires PyTorch >= 1.12.0") | |
from torch.distributed.fsdp.fully_sharded_data_parallel import BackwardPrefetch, ShardingStrategy | |
if FSDPOption.FULL_SHARD in args.fsdp: | |
self.fsdp = ShardingStrategy.FULL_SHARD | |
elif FSDPOption.SHARD_GRAD_OP in args.fsdp: | |
self.fsdp = ShardingStrategy.SHARD_GRAD_OP | |
elif FSDPOption.NO_SHARD in args.fsdp: | |
self.fsdp = ShardingStrategy.NO_SHARD | |
self.backward_prefetch = BackwardPrefetch.BACKWARD_PRE | |
if "backward_prefetch" in self.args.fsdp_config and "backward_post" in self.args.fsdp_config.get( | |
"backward_prefetch", [] | |
): | |
self.backward_prefetch = BackwardPrefetch.BACKWARD_POST | |
self.forward_prefetch = False | |
if self.args.fsdp_config.get("forward_prefetch", False):
self.forward_prefetch = True | |
self.limit_all_gathers = False | |
if self.args.fsdp_config.get("limit_all_gathers", False): | |
self.limit_all_gathers = True | |
# one place to sort out whether to place the model on device or not | |
# postpone switching model to cuda when: | |
# 1. MP - since we are trying to fit a much bigger than 1 gpu model | |
# 2. fp16-enabled DeepSpeed loads the model in half the size and it doesn't need .to() anyway, | |
# and we only use deepspeed for training at the moment | |
# 3. full bf16 or fp16 eval - since the model needs to be cast to the right dtype first | |
# 4. Sharded DDP - same as MP | |
# 5. FSDP - same as MP | |
self.place_model_on_device = args.place_model_on_device | |
if ( | |
self.is_model_parallel | |
or self.is_deepspeed_enabled | |
or ((args.fp16_full_eval or args.bf16_full_eval) and not args.do_train) | |
or (self.sharded_ddp in [ShardedDDPOption.ZERO_DP_2, ShardedDDPOption.ZERO_DP_3]) | |
or (self.fsdp is not None) | |
or self.is_fsdp_enabled | |
): | |
self.place_model_on_device = False | |
default_collator = default_data_collator if tokenizer is None else DataCollatorWithPadding(tokenizer) | |
self.data_collator = data_collator if data_collator is not None else default_collator | |
self.train_dataset = train_dataset | |
self.eval_dataset = eval_dataset | |
self.tokenizer = tokenizer | |
if self.place_model_on_device and not getattr(model, "is_loaded_in_8bit", False): | |
self._move_model_to_device(model, args.device) | |
# Force n_gpu to 1 to avoid DataParallel as MP will manage the GPUs | |
if self.is_model_parallel: | |
self.args._n_gpu = 1 | |
# later use `self.model is self.model_wrapped` to check if it's wrapped or not | |
self.model_wrapped = model | |
self.model = model | |
self.compute_metrics = compute_metrics | |
self.preprocess_logits_for_metrics = preprocess_logits_for_metrics | |
self.optimizer, self.lr_scheduler = optimizers | |
if model_init is not None and (self.optimizer is not None or self.lr_scheduler is not None): | |
raise RuntimeError( | |
"Passing a `model_init` is incompatible with providing the `optimizers` argument. " | |
"You should subclass `Trainer` and override the `create_optimizer_and_scheduler` method." | |
) | |
if is_torch_tpu_available() and self.optimizer is not None: | |
for param in self.model.parameters(): | |
model_device = param.device | |
break | |
for param_group in self.optimizer.param_groups: | |
if len(param_group["params"]) > 0: | |
optimizer_device = param_group["params"][0].device | |
break | |
if model_device != optimizer_device: | |
raise ValueError(
"The model and the optimizer parameters are not on the same device, which probably means you"
" created an optimizer around your model **before** putting it on the device and passing it to the"
" `Trainer`. Make sure the lines `import torch_xla.core.xla_model as xm` and"
" `model.to(xm.xla_device())` are performed before the optimizer creation in your script."
) | |
if ((self.sharded_ddp is not None) or self.is_deepspeed_enabled or (self.fsdp is not None)) and ( | |
self.optimizer is not None or self.lr_scheduler is not None | |
): | |
raise RuntimeError(
"Passing `optimizers` is not allowed if Fairscale, Deepspeed or PyTorch FSDP is enabled. "
"You should subclass `Trainer` and override the `create_optimizer_and_scheduler` method."
) | |
default_callbacks = DEFAULT_CALLBACKS + get_reporting_integration_callbacks(self.args.report_to) | |
callbacks = default_callbacks if callbacks is None else default_callbacks + callbacks | |
self.callback_handler = CallbackHandler( | |
callbacks, self.model, self.tokenizer, self.optimizer, self.lr_scheduler | |
) | |
self.add_callback(PrinterCallback if self.args.disable_tqdm else DEFAULT_PROGRESS_CALLBACK) | |
# Will be set to True by `self._setup_loggers()` on first call to `self.log()`. | |
self._loggers_initialized = False | |
# Create clone of distant repo and output directory if needed | |
if self.args.push_to_hub: | |
self.init_git_repo(at_init=True) | |
# In case of pull, we need to make sure every process has the latest. | |
if is_torch_tpu_available(): | |
xm.rendezvous("init git repo") | |
elif args.parallel_mode == ParallelMode.DISTRIBUTED: | |
dist.barrier() | |
if self.args.should_save: | |
os.makedirs(self.args.output_dir, exist_ok=True) | |
if not callable(self.data_collator) and callable(getattr(self.data_collator, "collate_batch", None)): | |
raise ValueError("The `data_collator` should be a simple callable (function, class with `__call__`).") | |
if args.max_steps > 0: | |
logger.info("max_steps is given, it will override any value given in num_train_epochs") | |
if train_dataset is not None and not has_length(train_dataset) and args.max_steps <= 0: | |
raise ValueError( | |
"The train_dataset does not implement __len__, max_steps has to be specified. " | |
"The number of steps needs to be known in advance for the learning rate scheduler." | |
) | |
if ( | |
train_dataset is not None | |
and isinstance(train_dataset, torch.utils.data.IterableDataset) | |
and args.group_by_length | |
): | |
raise ValueError("The `--group_by_length` option is only available for `Dataset`, not `IterableDataset`.")
self._signature_columns = None | |
# Mixed precision setup | |
self.use_apex = False | |
self.use_cuda_amp = False | |
self.use_cpu_amp = False | |
# Mixed precision setup for SageMaker Model Parallel | |
if is_sagemaker_mp_enabled(): | |
# BF16 + model parallelism in SageMaker: currently not supported, raise an error | |
if args.bf16: | |
raise ValueError("SageMaker Model Parallelism does not support BF16 yet. Please use FP16 instead.")
if IS_SAGEMAKER_MP_POST_1_10: | |
# When there's mismatch between SMP config and trainer argument, use SMP config as truth | |
if args.fp16 != smp.state.cfg.fp16: | |
logger.warning(
f"FP16 provided in SM_HP_MP_PARAMETERS is {smp.state.cfg.fp16}, "
f"but FP16 provided in trainer argument is {args.fp16}, "
f"setting to {smp.state.cfg.fp16}."
) | |
args.fp16 = smp.state.cfg.fp16 | |
else: | |
# smp < 1.10 does not support fp16 in trainer. | |
if hasattr(smp.state.cfg, "fp16"): | |
logger.warning( | |
f"FP16 provided in SM_HP_MP_PARAMETERS is {smp.state.cfg.fp16}, " | |
"but SageMaker Model Parallelism < 1.10 does not support FP16 in trainer." | |
) | |
if (args.fp16 or args.bf16) and self.sharded_ddp is not None: | |
if args.half_precision_backend == "auto": | |
if args.device == torch.device("cpu"): | |
if args.fp16: | |
raise ValueError("Tried to use `fp16` but it is not supported on cpu") | |
elif _is_native_cpu_amp_available: | |
args.half_precision_backend = "cpu_amp" | |
else: | |
raise ValueError("Tried to use cpu amp but native cpu amp is not available") | |
else: | |
args.half_precision_backend = "cuda_amp" | |
logger.info(f"Using {args.half_precision_backend} half precision backend") | |
self.do_grad_scaling = False | |
if (args.fp16 or args.bf16) and not (self.is_deepspeed_enabled or is_sagemaker_mp_enabled()): | |
# deepspeed and SageMaker Model Parallel manage their own half precision | |
if self.sharded_ddp is not None: | |
if args.half_precision_backend == "cuda_amp": | |
self.use_cuda_amp = True | |
self.amp_dtype = torch.float16 if args.fp16 else torch.bfloat16 | |
# bf16 does not need grad scaling | |
self.do_grad_scaling = self.amp_dtype == torch.float16 | |
if self.do_grad_scaling: | |
if self.sharded_ddp is not None: | |
self.scaler = ShardedGradScaler() | |
elif self.fsdp is not None: | |
from torch.distributed.fsdp.sharded_grad_scaler import ( | |
ShardedGradScaler as FSDPShardedGradScaler, | |
) | |
self.scaler = FSDPShardedGradScaler() | |
elif is_torch_tpu_available(): | |
from torch_xla.amp import GradScaler | |
self.scaler = GradScaler() | |
else: | |
self.scaler = torch.cuda.amp.GradScaler() | |
elif args.half_precision_backend == "cpu_amp": | |
self.use_cpu_amp = True | |
self.amp_dtype = torch.bfloat16 | |
elif args.half_precision_backend == "apex": | |
if not is_apex_available(): | |
raise ImportError( | |
"Using FP16 with APEX but APEX is not installed, please refer to" | |
" https://www.github.com/nvidia/apex." | |
) | |
self.use_apex = True | |
# FP16 + model parallelism in SageMaker: gradient clipping does not work for now so we raise a helpful error. | |
if ( | |
is_sagemaker_mp_enabled() | |
and self.use_cuda_amp | |
and args.max_grad_norm is not None | |
and args.max_grad_norm > 0 | |
): | |
raise ValueError( | |
"SageMaker Model Parallelism in mixed precision mode does not support gradient clipping yet. Pass " | |
"along 'max_grad_norm': 0 in your hyperparameters." | |
) | |
# Label smoothing | |
if self.args.label_smoothing_factor != 0: | |
self.label_smoother = LabelSmoother(epsilon=self.args.label_smoothing_factor) | |
else: | |
self.label_smoother = None | |
self.state = TrainerState( | |
is_local_process_zero=self.is_local_process_zero(), | |
is_world_process_zero=self.is_world_process_zero(), | |
) | |
self.control = TrainerControl() | |
# Internal variable to count flos in each process, will be accumulated in `self.state.total_flos` then | |
# returned to 0 every time flos need to be logged | |
self.current_flos = 0 | |
self.hp_search_backend = None | |
self.use_tune_checkpoints = False | |
default_label_names = find_labels(self.model.__class__) | |
self.label_names = default_label_names if self.args.label_names is None else self.args.label_names | |
self.can_return_loss = can_return_loss(self.model.__class__) | |
self.control = self.callback_handler.on_init_end(self.args, self.state, self.control) | |
# Internal variables to keep track of the original batch size | |
self._train_batch_size = args.train_batch_size | |
# very last | |
self._memory_tracker.stop_and_update_metrics() | |
# torch.compile | |
if args.torch_compile and not is_torch_compile_available(): | |
raise RuntimeError("Using torch.compile requires PyTorch 2.0 or higher.") | |
def add_callback(self, callback): | |
""" | |
Add a callback to the current list of [`~transformers.TrainerCallback`].
Args:
callback (`type` or [`~transformers.TrainerCallback`]):
A [`~transformers.TrainerCallback`] class or an instance of a [`~transformers.TrainerCallback`]. In the
first case, will instantiate a member of that class. | |
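Example (sketch; `MyCallback` stands in for any `TrainerCallback` subclass you have defined):

```python
trainer.add_callback(MyCallback)    # pass the class: the Trainer instantiates it for you
trainer.add_callback(MyCallback())  # or pass an already-instantiated callback
```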
""" | |
self.callback_handler.add_callback(callback) | |
def pop_callback(self, callback): | |
""" | |
Remove a callback from the current list of [`~transformers.TrainerCallback`] and return it.
If the callback is not found, returns `None` (and no error is raised).
Args:
callback (`type` or [`~transformers.TrainerCallback`]):
A [`~transformers.TrainerCallback`] class or an instance of a [`~transformers.TrainerCallback`]. In the
first case, will pop the first member of that class found in the list of callbacks.
Returns:
[`~transformers.TrainerCallback`]: The callback removed, if found.
""" | |
return self.callback_handler.pop_callback(callback) | |
def remove_callback(self, callback): | |
""" | |
Remove a callback from the current list of [`~transformers.TrainerCallback`].
Args:
callback (`type` or [`~transformers.TrainerCallback`]):
A [`~transformers.TrainerCallback`] class or an instance of a [`~transformers.TrainerCallback`]. In the
first case, will remove the first member of that class found in the list of callbacks. | |
""" | |
self.callback_handler.remove_callback(callback) | |
def _move_model_to_device(self, model, device): | |
model = model.to(device) | |
# Moving a model to an XLA device disconnects the tied weights, so we have to retie them. | |
if self.args.parallel_mode == ParallelMode.TPU and hasattr(model, "tie_weights"): | |
model.tie_weights() | |
def _set_signature_columns_if_needed(self): | |
if self._signature_columns is None: | |
# Inspect model forward signature to keep only the arguments it accepts. | |
signature = inspect.signature(self.model.forward) | |
self._signature_columns = list(signature.parameters.keys()) | |
# Labels may be named label or label_ids, the default data collator handles that. | |
self._signature_columns += list(set(["label", "label_ids"] + self.label_names)) | |
def _remove_unused_columns(self, dataset: "datasets.Dataset", description: Optional[str] = None): | |
if not self.args.remove_unused_columns: | |
return dataset | |
self._set_signature_columns_if_needed() | |
signature_columns = self._signature_columns | |
ignored_columns = list(set(dataset.column_names) - set(signature_columns)) | |
if len(ignored_columns) > 0: | |
dset_description = "" if description is None else f"in the {description} set" | |
logger.info( | |
f"The following columns {dset_description} don't have a corresponding argument in " | |
f"`{self.model.__class__.__name__}.forward` and have been ignored: {', '.join(ignored_columns)}." | |
f" If {', '.join(ignored_columns)} are not expected by `{self.model.__class__.__name__}.forward`, " | |
" you can safely ignore this message." | |
) | |
columns = [k for k in signature_columns if k in dataset.column_names] | |
if version.parse(datasets.__version__) < version.parse("1.4.0"): | |
dataset.set_format( | |
type=dataset.format["type"], columns=columns, format_kwargs=dataset.format["format_kwargs"] | |
) | |
return dataset | |
else: | |
return dataset.remove_columns(ignored_columns) | |
def _get_collator_with_removed_columns( | |
self, data_collator: Callable, description: Optional[str] = None | |
) -> Callable: | |
"""Wrap the data collator in a callable removing unused columns.""" | |
if not self.args.remove_unused_columns: | |
return data_collator | |
self._set_signature_columns_if_needed() | |
signature_columns = self._signature_columns | |
remove_columns_collator = RemoveColumnsCollator( | |
data_collator=data_collator, | |
signature_columns=signature_columns, | |
logger=logger, | |
description=description, | |
model_name=self.model.__class__.__name__, | |
) | |
return remove_columns_collator | |
def _get_train_sampler(self) -> Optional[torch.utils.data.Sampler]: | |
if self.train_dataset is None or not has_length(self.train_dataset): | |
return None | |
generator = None | |
if self.args.world_size <= 1: | |
generator = torch.Generator() | |
# for backwards compatibility, we generate a seed here (which is sampled from a generator seeded with | |
# `args.seed`) if data_seed isn't provided. | |
# Further on in this method, we default to `args.seed` instead. | |
if self.args.data_seed is None: | |
seed = int(torch.empty((), dtype=torch.int64).random_().item()) | |
else: | |
seed = self.args.data_seed | |
generator.manual_seed(seed) | |
seed = self.args.data_seed if self.args.data_seed is not None else self.args.seed | |
# Build the sampler. | |
if self.args.group_by_length: | |
if is_datasets_available() and isinstance(self.train_dataset, datasets.Dataset): | |
lengths = ( | |
self.train_dataset[self.args.length_column_name] | |
if self.args.length_column_name in self.train_dataset.column_names | |
else None | |
) | |
else: | |
lengths = None | |
model_input_name = self.tokenizer.model_input_names[0] if self.tokenizer is not None else None | |
if self.args.world_size <= 1: | |
return LengthGroupedSampler( | |
self.args.train_batch_size * self.args.gradient_accumulation_steps, | |
dataset=self.train_dataset, | |
lengths=lengths, | |
model_input_name=model_input_name, | |
generator=generator, | |
) | |
else: | |
return DistributedLengthGroupedSampler( | |
self.args.train_batch_size * self.args.gradient_accumulation_steps, | |
dataset=self.train_dataset, | |
num_replicas=self.args.world_size, | |
rank=self.args.process_index, | |
lengths=lengths, | |
model_input_name=model_input_name, | |
seed=seed, | |
) | |
else: | |
if self.args.world_size <= 1: | |
return RandomSampler(self.train_dataset, generator=generator) | |
elif ( | |
self.args.parallel_mode in [ParallelMode.TPU, ParallelMode.SAGEMAKER_MODEL_PARALLEL] | |
and not self.args.dataloader_drop_last | |
): | |
# Use a loop for TPUs when drop_last is False to have all batches have the same size. | |
return DistributedSamplerWithLoop( | |
self.train_dataset, | |
batch_size=self.args.per_device_train_batch_size, | |
num_replicas=self.args.world_size, | |
rank=self.args.process_index, | |
seed=seed, | |
) | |
else: | |
return DistributedSampler( | |
self.train_dataset, | |
num_replicas=self.args.world_size, | |
rank=self.args.process_index, | |
seed=seed, | |
) | |
def get_train_dataloader(self) -> DataLoader: | |
""" | |
Returns the training [`~torch.utils.data.DataLoader`]. | |
Will use no sampler if `train_dataset` does not implement `__len__`, a random sampler (adapted to distributed | |
training if necessary) otherwise. | |
Subclass and override this method if you want to inject some custom behavior. | |
""" | |
if self.train_dataset is None: | |
raise ValueError("Trainer: training requires a train_dataset.") | |
train_dataset = self.train_dataset | |
data_collator = self.data_collator | |
if is_datasets_available() and isinstance(train_dataset, datasets.Dataset): | |
train_dataset = self._remove_unused_columns(train_dataset, description="training") | |
else: | |
data_collator = self._get_collator_with_removed_columns(data_collator, description="training") | |
if isinstance(train_dataset, torch.utils.data.IterableDataset): | |
if self.args.world_size > 1: | |
train_dataset = IterableDatasetShard( | |
train_dataset, | |
batch_size=self._train_batch_size, | |
drop_last=self.args.dataloader_drop_last, | |
num_processes=self.args.world_size, | |
process_index=self.args.process_index, | |
) | |
return DataLoader( | |
train_dataset, | |
batch_size=self._train_batch_size, | |
collate_fn=data_collator, | |
num_workers=self.args.dataloader_num_workers, | |
pin_memory=self.args.dataloader_pin_memory, | |
) | |
train_sampler = self._get_train_sampler() | |
return DataLoader( | |
train_dataset, | |
batch_size=self._train_batch_size, | |
sampler=train_sampler, | |
collate_fn=data_collator, | |
drop_last=self.args.dataloader_drop_last, | |
num_workers=self.args.dataloader_num_workers, | |
pin_memory=self.args.dataloader_pin_memory, | |
worker_init_fn=seed_worker, | |
) | |
def _get_eval_sampler(self, eval_dataset: Dataset) -> Optional[torch.utils.data.Sampler]: | |
# Deprecated code | |
if self.args.use_legacy_prediction_loop: | |
if is_torch_tpu_available(): | |
return SequentialDistributedSampler( | |
eval_dataset, num_replicas=xm.xrt_world_size(), rank=xm.get_ordinal() | |
) | |
elif is_sagemaker_mp_enabled(): | |
return SequentialDistributedSampler( | |
eval_dataset, | |
num_replicas=smp.dp_size(), | |
rank=smp.dp_rank(), | |
batch_size=self.args.per_device_eval_batch_size, | |
) | |
elif self.args.parallel_mode == ParallelMode.DISTRIBUTED: | |
return SequentialDistributedSampler(eval_dataset) | |
else: | |
return SequentialSampler(eval_dataset) | |
if self.args.world_size <= 1: | |
return SequentialSampler(eval_dataset) | |
else: | |
return ShardSampler( | |
eval_dataset, | |
batch_size=self.args.per_device_eval_batch_size, | |
num_processes=self.args.world_size, | |
process_index=self.args.process_index, | |
) | |
def get_eval_dataloader(self, eval_dataset: Optional[Dataset] = None) -> DataLoader: | |
""" | |
Returns the evaluation [`~torch.utils.data.DataLoader`]. | |
Subclass and override this method if you want to inject some custom behavior. | |
Args: | |
eval_dataset (`torch.utils.data.Dataset`, *optional*): | |
If provided, will override `self.eval_dataset`. If it is a [`~datasets.Dataset`], columns not accepted | |
by the `model.forward()` method are automatically removed. It must implement `__len__`. | |
""" | |
if eval_dataset is None and self.eval_dataset is None: | |
raise ValueError("Trainer: evaluation requires an eval_dataset.") | |
eval_dataset = eval_dataset if eval_dataset is not None else self.eval_dataset | |
data_collator = self.data_collator | |
if is_datasets_available() and isinstance(eval_dataset, datasets.Dataset): | |
eval_dataset = self._remove_unused_columns(eval_dataset, description="evaluation") | |
else: | |
data_collator = self._get_collator_with_removed_columns(data_collator, description="evaluation") | |
if isinstance(eval_dataset, torch.utils.data.IterableDataset): | |
if self.args.world_size > 1: | |
eval_dataset = IterableDatasetShard( | |
eval_dataset, | |
batch_size=self.args.per_device_eval_batch_size, | |
drop_last=self.args.dataloader_drop_last, | |
num_processes=self.args.world_size, | |
process_index=self.args.process_index, | |
) | |
return DataLoader( | |
eval_dataset, | |
batch_size=self.args.eval_batch_size, | |
collate_fn=data_collator, | |
num_workers=self.args.dataloader_num_workers, | |
pin_memory=self.args.dataloader_pin_memory, | |
) | |
eval_sampler = self._get_eval_sampler(eval_dataset) | |
return DataLoader( | |
eval_dataset, | |
sampler=eval_sampler, | |
batch_size=self.args.eval_batch_size, | |
collate_fn=data_collator, | |
drop_last=self.args.dataloader_drop_last, | |
num_workers=self.args.dataloader_num_workers, | |
pin_memory=self.args.dataloader_pin_memory, | |
) | |
def get_test_dataloader(self, test_dataset: Dataset) -> DataLoader: | |
""" | |
Returns the test [`~torch.utils.data.DataLoader`]. | |
Subclass and override this method if you want to inject some custom behavior. | |
Args: | |
test_dataset (`torch.utils.data.Dataset`, *optional*): | |
The test dataset to use. If it is a [`~datasets.Dataset`], columns not accepted by the | |
`model.forward()` method are automatically removed. It must implement `__len__`. | |
""" | |
data_collator = self.data_collator | |
if is_datasets_available() and isinstance(test_dataset, datasets.Dataset): | |
test_dataset = self._remove_unused_columns(test_dataset, description="test") | |
else: | |
data_collator = self._get_collator_with_removed_columns(data_collator, description="test") | |
if isinstance(test_dataset, torch.utils.data.IterableDataset): | |
if self.args.world_size > 1: | |
test_dataset = IterableDatasetShard( | |
test_dataset, | |
batch_size=self.args.eval_batch_size, | |
drop_last=self.args.dataloader_drop_last, | |
num_processes=self.args.world_size, | |
process_index=self.args.process_index, | |
) | |
return DataLoader( | |
test_dataset, | |
batch_size=self.args.eval_batch_size, | |
collate_fn=data_collator, | |
num_workers=self.args.dataloader_num_workers, | |
pin_memory=self.args.dataloader_pin_memory, | |
) | |
test_sampler = self._get_eval_sampler(test_dataset) | |
# We use the same batch_size as for eval. | |
return DataLoader( | |
test_dataset, | |
sampler=test_sampler, | |
batch_size=self.args.eval_batch_size, | |
collate_fn=data_collator, | |
drop_last=self.args.dataloader_drop_last, | |
num_workers=self.args.dataloader_num_workers, | |
pin_memory=self.args.dataloader_pin_memory, | |
) | |
def create_optimizer_and_scheduler(self, num_training_steps: int): | |
""" | |
Setup the optimizer and the learning rate scheduler. | |
We provide a reasonable default that works well. If you want to use something else, you can pass a tuple in the | |
Trainer's init through `optimizers`, or subclass and override this method (or `create_optimizer` and/or | |
`create_scheduler`) in a subclass. | |
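For example, a sketch of passing a custom optimizer/scheduler pair at init instead of overriding this method
(`model`, `args` and `train_dataset` are assumed to already exist):

```python
from torch.optim import SGD
from torch.optim.lr_scheduler import LambdaLR

optimizer = SGD(model.parameters(), lr=5e-3)
scheduler = LambdaLR(optimizer, lr_lambda=lambda step: 1.0)
trainer = Trainer(
    model=model, args=args, train_dataset=train_dataset, optimizers=(optimizer, scheduler)
)
```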
""" | |
self.create_optimizer() | |
if IS_SAGEMAKER_MP_POST_1_10 and smp.state.cfg.fp16: | |
# If smp >= 1.10 and fp16 is enabled, we unwrap the optimizer | |
optimizer = self.optimizer.optimizer | |
else: | |
optimizer = self.optimizer | |
self.create_scheduler(num_training_steps=num_training_steps, optimizer=optimizer) | |
def create_optimizer(self): | |
""" | |
Setup the optimizer. | |
We provide a reasonable default that works well. If you want to use something else, you can pass a tuple in the | |
Trainer's init through `optimizers`, or subclass and override this method in a subclass. | |
""" | |
opt_model = self.model_wrapped if is_sagemaker_mp_enabled() else self.model | |
if self.optimizer is None: | |
decay_parameters = get_parameter_names(opt_model, ALL_LAYERNORM_LAYERS) | |
decay_parameters = [name for name in decay_parameters if "bias" not in name] | |
optimizer_grouped_parameters = [ | |
{ | |
"params": [ | |
p for n, p in opt_model.named_parameters() if (n in decay_parameters and p.requires_grad) | |
], | |
"weight_decay": self.args.weight_decay, | |
}, | |
{ | |
"params": [ | |
p for n, p in opt_model.named_parameters() if (n not in decay_parameters and p.requires_grad) | |
], | |
"weight_decay": 0.0, | |
}, | |
] | |
optimizer_cls, optimizer_kwargs = Trainer.get_optimizer_cls_and_kwargs(self.args) | |
if self.sharded_ddp == ShardedDDPOption.SIMPLE: | |
self.optimizer = OSS( | |
params=optimizer_grouped_parameters, | |
optim=optimizer_cls, | |
**optimizer_kwargs, | |
) | |
else: | |
self.optimizer = optimizer_cls(optimizer_grouped_parameters, **optimizer_kwargs) | |
if optimizer_cls.__name__ == "Adam8bit": | |
import bitsandbytes | |
manager = bitsandbytes.optim.GlobalOptimManager.get_instance() | |
skipped = 0 | |
for module in opt_model.modules(): | |
if isinstance(module, nn.Embedding): | |
skipped += sum({p.data_ptr(): p.numel() for p in module.parameters()}.values()) | |
logger.info(f"skipped {module}: {skipped/2**20}M params") | |
manager.register_module_override(module, "weight", {"optim_bits": 32}) | |
logger.debug(f"bitsandbytes: will optimize {module} in fp32") | |
logger.info(f"skipped: {skipped/2**20}M params") | |
if is_sagemaker_mp_enabled(): | |
self.optimizer = smp.DistributedOptimizer(self.optimizer) | |
return self.optimizer | |
@staticmethod
def get_optimizer_cls_and_kwargs(args: TrainingArguments) -> Tuple[Any, Any]:
""" | |
Returns the optimizer class and optimizer parameters based on the training arguments. | |
Args: | |
args (`transformers.training_args.TrainingArguments`): | |
The training arguments for the training session. | |
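For example, `--optim_args "momentum_dtype=float32,use_kahan_summation=False"` (illustrative keys for the
`adamw_anyprecision` optimizer) is parsed below into `{"momentum_dtype": "float32", "use_kahan_summation": "False"}`.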
""" | |
# parse args.optim_args | |
optim_args = {} | |
if args.optim_args: | |
for mapping in args.optim_args.replace(" ", "").split(","): | |
key, value = mapping.split("=") | |
optim_args[key] = value | |
optimizer_kwargs = {"lr": args.learning_rate} | |
adam_kwargs = { | |
"betas": (args.adam_beta1, args.adam_beta2), | |
"eps": args.adam_epsilon, | |
} | |
if args.optim == OptimizerNames.ADAFACTOR: | |
optimizer_cls = Adafactor | |
optimizer_kwargs.update({"scale_parameter": False, "relative_step": False}) | |
elif args.optim == OptimizerNames.ADAMW_HF: | |
from transformers.optimization import AdamW | |
optimizer_cls = AdamW | |
optimizer_kwargs.update(adam_kwargs) | |
elif args.optim in [OptimizerNames.ADAMW_TORCH, OptimizerNames.ADAMW_TORCH_FUSED]: | |
from torch.optim import AdamW | |
optimizer_cls = AdamW | |
optimizer_kwargs.update(adam_kwargs) | |
if args.optim == OptimizerNames.ADAMW_TORCH_FUSED: | |
optimizer_kwargs.update({"fused": True}) | |
elif args.optim == OptimizerNames.ADAMW_TORCH_XLA: | |
try: | |
from torch_xla.amp.syncfree import AdamW | |
optimizer_cls = AdamW | |
optimizer_kwargs.update(adam_kwargs) | |
except ImportError: | |
raise ValueError("Trainer failed to import syncfree AdamW from torch_xla.") | |
elif args.optim == OptimizerNames.ADAMW_APEX_FUSED: | |
try: | |
from apex.optimizers import FusedAdam | |
optimizer_cls = FusedAdam | |
optimizer_kwargs.update(adam_kwargs) | |
except ImportError: | |
raise ValueError("Trainer tried to instantiate apex FusedAdam but apex is not installed!") | |
elif args.optim in [ | |
OptimizerNames.ADAMW_BNB, | |
OptimizerNames.ADAMW_8BIT, | |
OptimizerNames.PAGED_ADAMW, | |
OptimizerNames.PAGED_ADAMW_8BIT, | |
OptimizerNames.LION, | |
OptimizerNames.LION_8BIT, | |
OptimizerNames.PAGED_LION, | |
OptimizerNames.PAGED_LION_8BIT, | |
]: | |
try: | |
from bitsandbytes.optim import AdamW, Lion | |
is_paged = False | |
optim_bits = 32 | |
optimizer_cls = None | |
additional_optim_kwargs = adam_kwargs | |
if "paged" in args.optim: | |
is_paged = True | |
if "8bit" in args.optim: | |
optim_bits = 8 | |
if "adam" in args.optim: | |
optimizer_cls = AdamW | |
elif "lion" in args.optim: | |
optimizer_cls = Lion | |
additional_optim_kwargs = {"betas": (args.adam_beta1, args.adam_beta2)} | |
bnb_kwargs = {"is_paged": is_paged, "optim_bits": optim_bits} | |
optimizer_kwargs.update(additional_optim_kwargs) | |
optimizer_kwargs.update(bnb_kwargs) | |
except ImportError: | |
raise ValueError("Trainer tried to instantiate bnb optimizer but bnb is not installed!") | |
elif args.optim == OptimizerNames.ADAMW_BNB: | |
try: | |
from bitsandbytes.optim import Adam8bit | |
optimizer_cls = Adam8bit | |
optimizer_kwargs.update(adam_kwargs) | |
except ImportError: | |
raise ValueError("Trainer tried to instantiate bnb Adam8bit but bnb is not installed!") | |
elif args.optim == OptimizerNames.ADAMW_ANYPRECISION: | |
try: | |
from torchdistx.optimizers import AnyPrecisionAdamW | |
optimizer_cls = AnyPrecisionAdamW | |
optimizer_kwargs.update(adam_kwargs) | |
# TODO Change dtypes back to M=FP32, Var = BF16, Kahan = False once they can be cast together in torchdistx. | |
optimizer_kwargs.update( | |
{ | |
"use_kahan_summation": strtobool(optim_args.get("use_kahan_summation", "False")), | |
"momentum_dtype": getattr(torch, optim_args.get("momentum_dtype", "float32")), | |
"variance_dtype": getattr(torch, optim_args.get("variance_dtype", "float32")), | |
"compensation_buffer_dtype": getattr( | |
torch, optim_args.get("compensation_buffer_dtype", "bfloat16") | |
), | |
} | |
) | |
except ImportError: | |
raise ValueError("Please install https://github.com/pytorch/torchdistx") | |
elif args.optim == OptimizerNames.SGD: | |
optimizer_cls = torch.optim.SGD | |
elif args.optim == OptimizerNames.ADAGRAD: | |
optimizer_cls = torch.optim.Adagrad | |
else: | |
raise ValueError(f"Trainer cannot instantiate unsupported optimizer: {args.optim}") | |
return optimizer_cls, optimizer_kwargs | |
def create_scheduler(self, num_training_steps: int, optimizer: torch.optim.Optimizer = None): | |
""" | |
Setup the scheduler. The optimizer of the trainer must have been set up either before this method is called or | |
passed as an argument. | |
Args: | |
num_training_steps (int): The number of training steps to do. | |
""" | |
if self.lr_scheduler is None: | |
self.lr_scheduler = get_scheduler( | |
self.args.lr_scheduler_type, | |
optimizer=self.optimizer if optimizer is None else optimizer, | |
num_warmup_steps=self.args.get_warmup_steps(num_training_steps), | |
num_training_steps=num_training_steps, | |
) | |
return self.lr_scheduler | |
def num_examples(self, dataloader: DataLoader) -> int: | |
""" | |
Helper to get number of samples in a [`~torch.utils.data.DataLoader`] by accessing its dataset. When | |
dataloader.dataset does not exist or has no length, estimates as best it can | |
""" | |
try: | |
dataset = dataloader.dataset | |
# Special case for IterableDatasetShard, we need to dig deeper | |
if isinstance(dataset, IterableDatasetShard): | |
return len(dataloader.dataset.dataset) | |
return len(dataloader.dataset) | |
except (NameError, AttributeError, TypeError): # no dataset or length, estimate by length of dataloader | |
return len(dataloader) * self.args.per_device_train_batch_size | |
def _hp_search_setup(self, trial: Union["optuna.Trial", Dict[str, Any]]): | |
"""HP search setup code""" | |
self._trial = trial | |
if self.hp_search_backend is None or trial is None: | |
return | |
if self.hp_search_backend == HPSearchBackend.OPTUNA: | |
params = self.hp_space(trial) | |
elif self.hp_search_backend == HPSearchBackend.RAY: | |
params = trial | |
params.pop("wandb", None) | |
elif self.hp_search_backend == HPSearchBackend.SIGOPT: | |
params = {k: int(v) if isinstance(v, str) else v for k, v in trial.assignments.items()} | |
elif self.hp_search_backend == HPSearchBackend.WANDB: | |
params = trial | |
for key, value in params.items(): | |
if not hasattr(self.args, key): | |
logger.warning( | |
f"Trying to set {key} in the hyperparameter search but there is no corresponding field in" | |
" `TrainingArguments`." | |
) | |
continue | |
old_attr = getattr(self.args, key, None) | |
# Casting value to the proper type | |
if old_attr is not None: | |
value = type(old_attr)(value) | |
setattr(self.args, key, value) | |
if self.hp_search_backend == HPSearchBackend.OPTUNA: | |
logger.info(f"Trial: {trial.params}") | |
if self.hp_search_backend == HPSearchBackend.SIGOPT: | |
logger.info(f"SigOpt Assignments: {trial.assignments}") | |
if self.hp_search_backend == HPSearchBackend.WANDB: | |
logger.info(f"W&B Sweep parameters: {trial}") | |
if self.is_deepspeed_enabled: | |
if self.args.deepspeed is None: | |
raise ValueError("For sweeps with deepspeed, `args.deepspeed` must be set") | |
# Rebuild the deepspeed config to reflect the updated training parameters | |
from accelerate.utils import DeepSpeedPlugin | |
from transformers.deepspeed import HfTrainerDeepSpeedConfig | |
self.args.hf_deepspeed_config = HfTrainerDeepSpeedConfig(self.args.deepspeed) | |
self.args.hf_deepspeed_config.trainer_config_process(self.args) | |
self.args.deepspeed_plugin = DeepSpeedPlugin(hf_ds_config=self.args.hf_deepspeed_config) | |
self.create_accelerator_and_postprocess() | |
def _report_to_hp_search(self, trial: Union["optuna.Trial", Dict[str, Any]], step: int, metrics: Dict[str, float]): | |
if self.hp_search_backend is None or trial is None: | |
return | |
self.objective = self.compute_objective(metrics.copy()) | |
if self.hp_search_backend == HPSearchBackend.OPTUNA: | |
import optuna | |
trial.report(self.objective, step) | |
if trial.should_prune(): | |
self.callback_handler.on_train_end(self.args, self.state, self.control) | |
raise optuna.TrialPruned() | |
elif self.hp_search_backend == HPSearchBackend.RAY: | |
from ray import tune | |
if self.control.should_save: | |
self._tune_save_checkpoint() | |
tune.report(objective=self.objective, **metrics) | |
def _tune_save_checkpoint(self): | |
from ray import tune | |
if not self.use_tune_checkpoints: | |
return | |
with tune.checkpoint_dir(step=self.state.global_step) as checkpoint_dir: | |
output_dir = os.path.join(checkpoint_dir, f"{PREFIX_CHECKPOINT_DIR}-{self.state.global_step}") | |
self.save_model(output_dir, _internal_call=True) | |
if self.args.should_save: | |
self.state.save_to_json(os.path.join(output_dir, TRAINER_STATE_NAME)) | |
torch.save(self.optimizer.state_dict(), os.path.join(output_dir, OPTIMIZER_NAME)) | |
torch.save(self.lr_scheduler.state_dict(), os.path.join(output_dir, SCHEDULER_NAME)) | |
def call_model_init(self, trial=None): | |
model_init_argcount = number_of_arguments(self.model_init) | |
if model_init_argcount == 0: | |
model = self.model_init() | |
elif model_init_argcount == 1: | |
model = self.model_init(trial) | |
else: | |
raise RuntimeError("model_init should have 0 or 1 argument.") | |
if model is None: | |
raise RuntimeError("model_init should not return None.") | |
return model | |
def torch_jit_model_eval(self, model, dataloader, training=False): | |
if not training: | |
if dataloader is None: | |
logger.warning("Failed to use PyTorch jit mode because the current dataloader is None.")
return model | |
example_batch = next(iter(dataloader)) | |
example_batch = self._prepare_inputs(example_batch) | |
try: | |
jit_model = model.eval() | |
with ContextManagers([self.autocast_smart_context_manager(cache_enabled=False), torch.no_grad()]): | |
if version.parse(version.parse(torch.__version__).base_version) >= version.parse("1.14.0"): | |
if isinstance(example_batch, dict): | |
jit_model = torch.jit.trace(jit_model, example_kwarg_inputs=example_batch, strict=False) | |
else: | |
jit_model = torch.jit.trace( | |
jit_model, | |
example_kwarg_inputs={key: example_batch[key] for key in example_batch}, | |
strict=False, | |
) | |
else: | |
jit_inputs = [] | |
for key in example_batch: | |
example_tensor = torch.ones_like(example_batch[key]) | |
jit_inputs.append(example_tensor) | |
jit_inputs = tuple(jit_inputs) | |
jit_model = torch.jit.trace(jit_model, jit_inputs, strict=False) | |
jit_model = torch.jit.freeze(jit_model) | |
with torch.no_grad(): | |
jit_model(**example_batch) | |
jit_model(**example_batch) | |
model = jit_model | |
self.use_cpu_amp = False | |
self.use_cuda_amp = False | |
except (RuntimeError, TypeError, ValueError, NameError, IndexError) as e: | |
logger.warning(f"Failed to use PyTorch jit mode due to: {e}.")
return model | |
def ipex_optimize_model(self, model, training=False, dtype=torch.float32): | |
if not is_ipex_available(): | |
raise ImportError( | |
"Using IPEX but IPEX is not installed or IPEX's version does not match current PyTorch, please refer" | |
" to https://github.com/intel/intel-extension-for-pytorch." | |
) | |
import intel_extension_for_pytorch as ipex | |
if not training: | |
model.eval() | |
dtype = torch.bfloat16 if not self.is_in_train and self.args.bf16_full_eval else dtype | |
# conv_bn_folding is disabled as it fails in symbolic tracing, resulting in ipex warnings | |
model = ipex.optimize(model, dtype=dtype, level="O1", conv_bn_folding=False, inplace=not self.is_in_train) | |
else: | |
if not model.training: | |
model.train() | |
model, self.optimizer = ipex.optimize( | |
model, dtype=dtype, optimizer=self.optimizer, inplace=True, level="O1" | |
) | |
return model | |
def _wrap_model(self, model, training=True, dataloader=None): | |
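"""
Wrap the model for the current hardware/parallelism setup (IPEX, SageMaker MP/DP, apex, DataParallel,
Fairscale sharded DDP, FSDP/XLA-FSDP or plain DDP). For evaluation-only calls, the model is returned before
any distributed wrapper is applied.
"""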
if self.args.use_ipex: | |
dtype = torch.bfloat16 if self.use_cpu_amp else torch.float32 | |
model = self.ipex_optimize_model(model, training, dtype=dtype) | |
if is_sagemaker_mp_enabled(): | |
# Wrapping the base model twice in a DistributedModel will raise an error. | |
if isinstance(self.model_wrapped, smp.model.DistributedModel): | |
return self.model_wrapped | |
return smp.DistributedModel(model, backward_passes_per_step=self.args.gradient_accumulation_steps) | |
# train/eval could be run multiple-times - if already wrapped, don't re-wrap it again | |
if unwrap_model(model) is not model: | |
return model | |
# Mixed precision training with apex (torch < 1.6) | |
if self.use_apex and training: | |
model, self.optimizer = amp.initialize(model, self.optimizer, opt_level=self.args.fp16_opt_level) | |
# Multi-gpu training (should be after apex fp16 initialization) / 8bit models does not support DDP | |
if self.args.n_gpu > 1 and not getattr(model, "is_loaded_in_8bit", False): | |
model = nn.DataParallel(model) | |
if self.args.jit_mode_eval: | |
start_time = time.time() | |
model = self.torch_jit_model_eval(model, dataloader, training) | |
self.jit_compilation_time = round(time.time() - start_time, 4) | |
# Note: in torch.distributed mode, there's no point in wrapping the model | |
# inside a DistributedDataParallel as we'll be under `no_grad` anyways. | |
if not training: | |
return model | |
# Distributed training (should be after apex fp16 initialization) | |
if self.sharded_ddp is not None: | |
# Sharded DDP! | |
if self.sharded_ddp == ShardedDDPOption.SIMPLE: | |
model = ShardedDDP(model, self.optimizer) | |
else: | |
mixed_precision = self.args.fp16 or self.args.bf16 | |
cpu_offload = ShardedDDPOption.OFFLOAD in self.args.sharded_ddp | |
zero_3 = self.sharded_ddp == ShardedDDPOption.ZERO_DP_3 | |
# XXX: Breaking the self.model convention but I see no way around it for now. | |
if ShardedDDPOption.AUTO_WRAP in self.args.sharded_ddp: | |
model = auto_wrap(model) | |
self.model = model = FullyShardedDDP( | |
model, | |
mixed_precision=mixed_precision, | |
reshard_after_forward=zero_3, | |
cpu_offload=cpu_offload, | |
).to(self.args.device) | |
# Distributed training using PyTorch FSDP | |
elif self.fsdp is not None and self.args.fsdp_config["xla"]: | |
try: | |
from torch_xla.distributed.fsdp import XlaFullyShardedDataParallel as FSDP | |
from torch_xla.distributed.fsdp import checkpoint_module | |
from torch_xla.distributed.fsdp.wrap import ( | |
size_based_auto_wrap_policy, | |
transformer_auto_wrap_policy, | |
) | |
except ImportError: | |
raise ImportError("Missing XLA FSDP related module; please make sure to use torch-xla >= 2.0.") | |
auto_wrap_policy = None | |
auto_wrapper_callable = None | |
if self.args.fsdp_config["fsdp_min_num_params"] > 0: | |
auto_wrap_policy = functools.partial( | |
size_based_auto_wrap_policy, min_num_params=self.args.fsdp_config["fsdp_min_num_params"] | |
) | |
elif self.args.fsdp_config.get("fsdp_transformer_layer_cls_to_wrap", None) is not None: | |
transformer_cls_to_wrap = set() | |
for layer_class in self.args.fsdp_config["fsdp_transformer_layer_cls_to_wrap"]: | |
transformer_cls = get_module_class_from_name(model, layer_class) | |
if transformer_cls is None: | |
raise Exception("Could not find the transformer layer class to wrap in the model.") | |
else: | |
transformer_cls_to_wrap.add(transformer_cls) | |
auto_wrap_policy = functools.partial( | |
transformer_auto_wrap_policy, | |
# Transformer layer class to wrap | |
transformer_layer_cls=transformer_cls_to_wrap, | |
) | |
fsdp_kwargs = self.args.xla_fsdp_config | |
if self.args.fsdp_config["xla_fsdp_grad_ckpt"]: | |
# Apply gradient checkpointing to auto-wrapped sub-modules if specified | |
def auto_wrapper_callable(m, *args, **kwargs): | |
return FSDP(checkpoint_module(m), *args, **kwargs) | |
# Wrap the base model with an outer FSDP wrapper | |
self.model = model = FSDP( | |
model, | |
auto_wrap_policy=auto_wrap_policy, | |
auto_wrapper_callable=auto_wrapper_callable, | |
**fsdp_kwargs, | |
) | |
# Patch `xm.optimizer_step` so that it does not reduce gradients in this case,
# as FSDP does not need gradient reduction over sharded parameters.
def patched_optimizer_step(optimizer, barrier=False, optimizer_args={}): | |
loss = optimizer.step(**optimizer_args) | |
if barrier: | |
xm.mark_step() | |
return loss | |
xm.optimizer_step = patched_optimizer_step | |
elif is_sagemaker_dp_enabled(): | |
model = nn.parallel.DistributedDataParallel( | |
model, device_ids=[int(os.getenv("SMDATAPARALLEL_LOCAL_RANK"))] | |
) | |
elif self.args.parallel_mode == ParallelMode.DISTRIBUTED: | |
if is_torch_neuroncore_available(): | |
return model | |
kwargs = {} | |
if self.args.ddp_find_unused_parameters is not None: | |
kwargs["find_unused_parameters"] = self.args.ddp_find_unused_parameters | |
elif isinstance(model, PreTrainedModel): | |
# find_unused_parameters breaks checkpointing as per | |
# https://github.com/huggingface/transformers/pull/4659#issuecomment-643356021 | |
kwargs["find_unused_parameters"] = not model.is_gradient_checkpointing | |
else: | |
kwargs["find_unused_parameters"] = True | |
if self.args.ddp_bucket_cap_mb is not None: | |
kwargs["bucket_cap_mb"] = self.args.ddp_bucket_cap_mb | |
self.accelerator.ddp_handler = DistributedDataParallelKwargs(**kwargs) | |
return model | |
def train( | |
self, | |
resume_from_checkpoint: Optional[Union[str, bool]] = None, | |
trial: Union["optuna.Trial", Dict[str, Any]] = None, | |
ignore_keys_for_eval: Optional[List[str]] = None, | |
**kwargs, | |
): | |
""" | |
Main training entry point. | |
Args: | |
resume_from_checkpoint (`str` or `bool`, *optional*): | |
If a `str`, local path to a saved checkpoint as saved by a previous instance of [`Trainer`]. If a | |
`bool` and equals `True`, load the last checkpoint in *args.output_dir* as saved by a previous instance | |
of [`Trainer`]. If present, training will resume from the model/optimizer/scheduler states loaded here. | |
trial (`optuna.Trial` or `Dict[str, Any]`, *optional*): | |
The trial run or the hyperparameter dictionary for hyperparameter search. | |
ignore_keys_for_eval (`List[str]`, *optional*):
A list of keys in the output of your model (if it is a dictionary) that should be ignored when | |
gathering predictions for evaluation during the training. | |
kwargs: | |
Additional keyword arguments used to hide deprecated arguments | |
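Example (a minimal sketch; `my_trainer` is assumed to be an already configured [`Trainer`]):

    # resume_from_checkpoint=True picks up the last checkpoint saved in `args.output_dir`
    # by a previous run of this Trainer.
    train_output = my_trainer.train(resume_from_checkpoint=True)
    print(train_output.metrics)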
""" | |
if resume_from_checkpoint is False: | |
resume_from_checkpoint = None | |
# memory metrics - must set up as early as possible | |
self._memory_tracker.start() | |
args = self.args | |
self.is_in_train = True | |
# do_train is not a reliable argument, as it might not be set and .train() still called, so | |
# the following is a workaround: | |
if (args.fp16_full_eval or args.bf16_full_eval) and not args.do_train: | |
self._move_model_to_device(self.model, args.device) | |
if "model_path" in kwargs: | |
resume_from_checkpoint = kwargs.pop("model_path") | |
warnings.warn( | |
"`model_path` is deprecated and will be removed in a future version. Use `resume_from_checkpoint` " | |
"instead.", | |
FutureWarning, | |
) | |
if len(kwargs) > 0: | |
raise TypeError(f"train() received got unexpected keyword arguments: {', '.join(list(kwargs.keys()))}.") | |
# This might change the seed so needs to run first. | |
self._hp_search_setup(trial) | |
self._train_batch_size = self.args.train_batch_size | |
# Model re-init | |
model_reloaded = False | |
if self.model_init is not None: | |
# Seed must be set before instantiating the model when using model_init. | |
enable_full_determinism(self.args.seed) if self.args.full_determinism else set_seed(self.args.seed) | |
self.model = self.call_model_init(trial) | |
model_reloaded = True | |
# Reinitializes optimizer and scheduler | |
self.optimizer, self.lr_scheduler = None, None | |
# Load potential model checkpoint | |
if isinstance(resume_from_checkpoint, bool) and resume_from_checkpoint: | |
resume_from_checkpoint = get_last_checkpoint(args.output_dir) | |
if resume_from_checkpoint is None: | |
raise ValueError(f"No valid checkpoint found in output directory ({args.output_dir})") | |
if resume_from_checkpoint is not None and not is_sagemaker_mp_enabled() and not self.is_deepspeed_enabled: | |
self._load_from_checkpoint(resume_from_checkpoint) | |
# If model was re-initialized, put it on the right device and update self.model_wrapped | |
if model_reloaded: | |
if self.place_model_on_device: | |
self._move_model_to_device(self.model, args.device) | |
self.model_wrapped = self.model | |
inner_training_loop = find_executable_batch_size( | |
self._inner_training_loop, self._train_batch_size, args.auto_find_batch_size | |
) | |
return inner_training_loop( | |
args=args, | |
resume_from_checkpoint=resume_from_checkpoint, | |
trial=trial, | |
ignore_keys_for_eval=ignore_keys_for_eval, | |
) | |
def _inner_training_loop( | |
self, batch_size=None, args=None, resume_from_checkpoint=None, trial=None, ignore_keys_for_eval=None | |
): | |
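"""
The actual training loop: builds the train dataloader, derives `max_steps`/`num_train_epochs`, creates (or
delays) the optimizer and scheduler, wraps the model, optionally resumes from a checkpoint, then iterates
over epochs and batches performing gradient accumulation, clipping, optimizer/scheduler steps, logging,
evaluation and checkpointing. Called through `find_executable_batch_size` so it can be retried with a
smaller batch size when `auto_find_batch_size` is enabled.
"""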
self.accelerator.free_memory() | |
self._train_batch_size = batch_size | |
logger.debug(f"Currently training with a batch size of: {self._train_batch_size}") | |
# Data loader and number of training steps | |
train_dataloader = self.get_train_dataloader() | |
# Setting up training control variables: | |
# number of training epochs: num_train_epochs | |
# number of training steps per epoch: num_update_steps_per_epoch | |
# total number of training steps to execute: max_steps | |
total_train_batch_size = args.train_batch_size * args.gradient_accumulation_steps * args.world_size | |
len_dataloader = None | |
if has_length(train_dataloader): | |
len_dataloader = len(train_dataloader) | |
num_update_steps_per_epoch = len_dataloader // args.gradient_accumulation_steps | |
num_update_steps_per_epoch = max(num_update_steps_per_epoch, 1) | |
num_examples = self.num_examples(train_dataloader) | |
if args.max_steps > 0: | |
max_steps = args.max_steps | |
num_train_epochs = args.max_steps // num_update_steps_per_epoch + int( | |
args.max_steps % num_update_steps_per_epoch > 0 | |
) | |
# May be slightly incorrect if the last batch in the training dataloader has a smaller size but it's | |
# the best we can do. | |
num_train_samples = args.max_steps * total_train_batch_size | |
else: | |
max_steps = math.ceil(args.num_train_epochs * num_update_steps_per_epoch) | |
num_train_epochs = math.ceil(args.num_train_epochs) | |
num_train_samples = self.num_examples(train_dataloader) * args.num_train_epochs | |
elif args.max_steps > 0: # Rely on max_steps when dataloader does not have a working size | |
max_steps = args.max_steps | |
# Setting a very large number of epochs so we go as many times as necessary over the iterator. | |
num_train_epochs = sys.maxsize | |
num_update_steps_per_epoch = max_steps | |
num_examples = total_train_batch_size * args.max_steps | |
num_train_samples = args.max_steps * total_train_batch_size | |
else: | |
raise ValueError( | |
"args.max_steps must be set to a positive value if dataloader does not have a length, was" | |
f" {args.max_steps}" | |
) | |
# Compute absolute values for logging, eval, and save if given as ratio | |
if args.logging_steps and args.logging_steps < 1: | |
args.logging_steps = math.ceil(max_steps * args.logging_steps) | |
if args.eval_steps and args.eval_steps < 1: | |
args.eval_steps = math.ceil(max_steps * args.eval_steps) | |
if args.save_steps and args.save_steps < 1: | |
args.save_steps = math.ceil(max_steps * args.save_steps) | |
if DebugOption.UNDERFLOW_OVERFLOW in self.args.debug: | |
if self.args.n_gpu > 1: | |
# nn.DataParallel(model) replicates the model, creating new variables, and the module
# references registered here no longer work on the other GPUs, breaking the debug module.
raise ValueError( | |
"Currently --debug underflow_overflow is not supported under DP. Please use DDP" | |
" (torch.distributed.launch)." | |
) | |
else: | |
debug_overflow = DebugUnderflowOverflow(self.model) # noqa | |
delay_optimizer_creation = ( | |
self.sharded_ddp is not None | |
and self.sharded_ddp != ShardedDDPOption.SIMPLE | |
or is_sagemaker_mp_enabled() | |
or self.fsdp is not None | |
) | |
if self.is_deepspeed_enabled: | |
self.optimizer, self.lr_scheduler = deepspeed_init(self, num_training_steps=max_steps) | |
if not delay_optimizer_creation: | |
self.create_optimizer_and_scheduler(num_training_steps=max_steps) | |
self.state = TrainerState() | |
self.state.is_hyper_param_search = trial is not None | |
# Activate gradient checkpointing if needed | |
if args.gradient_checkpointing: | |
self.model.gradient_checkpointing_enable() | |
model = self._wrap_model(self.model_wrapped) | |
if is_sagemaker_mp_enabled() and resume_from_checkpoint is not None: | |
self._load_from_checkpoint(resume_from_checkpoint, model) | |
# as the model is wrapped, don't use `accelerator.prepare` | |
# this is for unhandled cases such as | |
# Fairscale Sharded DDP, FSDP-XLA, SageMaker MP/DP, DataParallel, IPEX | |
use_accelerator_prepare = model is self.model
if delay_optimizer_creation: | |
self.create_optimizer_and_scheduler(num_training_steps=max_steps) | |
# prepare using `accelerator` prepare | |
if use_accelerator_prepare: | |
if hasattr(self.lr_scheduler, "step"): | |
if self.use_apex: | |
model = self.accelerator.prepare(self.model) | |
else: | |
model, self.optimizer = self.accelerator.prepare(self.model, self.optimizer) | |
else: | |
# to handle cases wherein we pass "DummyScheduler" such as when it is specified in DeepSpeed config. | |
model, self.optimizer, self.lr_scheduler = self.accelerator.prepare( | |
self.model, self.optimizer, self.lr_scheduler | |
) | |
if self.is_fsdp_enabled: | |
self.model = model | |
# for the rest of this function `model` is the outside model, whether it was wrapped or not | |
if model is not self.model: | |
self.model_wrapped = model | |
# backward compatibility | |
if self.is_deepspeed_enabled: | |
self.deepspeed = self.model_wrapped | |
# deepspeed ckpt loading | |
if resume_from_checkpoint is not None and self.is_deepspeed_enabled: | |
deepspeed_load_checkpoint(self.model_wrapped, resume_from_checkpoint) | |
# Check if saved optimizer or scheduler states exist | |
self._load_optimizer_and_scheduler(resume_from_checkpoint) | |
# important: at this point: | |
# self.model is the Transformers Model | |
# self.model_wrapped is DDP(Transformers Model), Deepspeed(Transformers Model), etc. | |
# Train! | |
logger.info("***** Running training *****") | |
logger.info(f" Num examples = {num_examples:,}") | |
logger.info(f" Num Epochs = {num_train_epochs:,}") | |
logger.info(f" Instantaneous batch size per device = {self._train_batch_size:,}") | |
logger.info(f" Total train batch size (w. parallel, distributed & accumulation) = {total_train_batch_size:,}") | |
logger.info(f" Gradient Accumulation steps = {args.gradient_accumulation_steps}") | |
logger.info(f" Total optimization steps = {max_steps:,}") | |
logger.info(f" Number of trainable parameters = {get_model_param_count(model, trainable_only=True):,}") | |
self.state.epoch = 0 | |
start_time = time.time() | |
epochs_trained = 0 | |
steps_trained_in_current_epoch = 0 | |
steps_trained_progress_bar = None | |
# Check if continuing training from a checkpoint | |
if resume_from_checkpoint is not None and os.path.isfile( | |
os.path.join(resume_from_checkpoint, TRAINER_STATE_NAME) | |
): | |
self.state = TrainerState.load_from_json(os.path.join(resume_from_checkpoint, TRAINER_STATE_NAME)) | |
epochs_trained = self.state.global_step // num_update_steps_per_epoch | |
if not args.ignore_data_skip: | |
steps_trained_in_current_epoch = self.state.global_step % (num_update_steps_per_epoch) | |
steps_trained_in_current_epoch *= args.gradient_accumulation_steps | |
else: | |
steps_trained_in_current_epoch = 0 | |
logger.info(" Continuing training from checkpoint, will skip to saved global_step") | |
logger.info(f" Continuing training from epoch {epochs_trained}") | |
logger.info(f" Continuing training from global step {self.state.global_step}") | |
if not args.ignore_data_skip: | |
if skip_first_batches is None: | |
logger.info( | |
f" Will skip the first {epochs_trained} epochs then the first" | |
f" {steps_trained_in_current_epoch} batches in the first epoch. If this takes a lot of time," | |
" you can install the latest version of Accelerate with `pip install -U accelerate`.You can" | |
" also add the `--ignore_data_skip` flag to your launch command, but you will resume the" | |
" training on data already seen by your model." | |
) | |
else: | |
logger.info( | |
f" Will skip the first {epochs_trained} epochs then the first" | |
f" {steps_trained_in_current_epoch} batches in the first epoch." | |
) | |
if self.is_local_process_zero() and not args.disable_tqdm and skip_first_batches is None: | |
steps_trained_progress_bar = tqdm(total=steps_trained_in_current_epoch) | |
steps_trained_progress_bar.set_description("Skipping the first batches") | |
# Update the references | |
self.callback_handler.model = self.model | |
self.callback_handler.optimizer = self.optimizer | |
self.callback_handler.lr_scheduler = self.lr_scheduler | |
self.callback_handler.train_dataloader = train_dataloader | |
if self.hp_name is not None and self._trial is not None: | |
# use self._trial because the SigOpt/Optuna HPO only calls `_hp_search_setup(trial)` instead of passing the
# trial parameter to `train` when using DDP.
self.state.trial_name = self.hp_name(self._trial) | |
if trial is not None: | |
assignments = trial.assignments if self.hp_search_backend == HPSearchBackend.SIGOPT else trial | |
self.state.trial_params = hp_params(assignments) | |
else: | |
self.state.trial_params = None | |
# This should be the same if the state has been saved but in case the training arguments changed, it's safer | |
# to set this after the load. | |
self.state.max_steps = max_steps | |
self.state.num_train_epochs = num_train_epochs | |
self.state.is_local_process_zero = self.is_local_process_zero() | |
self.state.is_world_process_zero = self.is_world_process_zero() | |
# tr_loss is a tensor to avoid synchronization of TPUs through .item() | |
tr_loss = torch.tensor(0.0).to(args.device) | |
# _total_loss_scalar is updated every time .item() has to be called on tr_loss and stores the sum of all losses
self._total_loss_scalar = 0.0 | |
self._globalstep_last_logged = self.state.global_step | |
model.zero_grad() | |
self.control = self.callback_handler.on_train_begin(args, self.state, self.control) | |
# Skip the first epochs_trained epochs to get the random state of the dataloader at the right point. | |
if not args.ignore_data_skip: | |
for epoch in range(epochs_trained): | |
is_random_sampler = hasattr(train_dataloader, "sampler") and isinstance( | |
train_dataloader.sampler, RandomSampler | |
) | |
if is_torch_less_than_1_11 or not is_random_sampler: | |
# We just need to begin an iteration to create the randomization of the sampler. | |
# That was before PyTorch 1.11 however... | |
for _ in train_dataloader: | |
break | |
else: | |
# Otherwise we need to iterate through the whole sampler because a random operation is added
# at the very end!
_ = list(train_dataloader.sampler) | |
total_batched_samples = 0 | |
for epoch in range(epochs_trained, num_train_epochs): | |
if isinstance(train_dataloader, DataLoader) and isinstance(train_dataloader.sampler, DistributedSampler): | |
train_dataloader.sampler.set_epoch(epoch) | |
elif hasattr(train_dataloader, "dataset") and isinstance(train_dataloader.dataset, IterableDatasetShard): | |
train_dataloader.dataset.set_epoch(epoch) | |
if is_torch_tpu_available(): | |
parallel_loader = pl.ParallelLoader(train_dataloader, [args.device]).per_device_loader(args.device) | |
epoch_iterator = parallel_loader | |
else: | |
epoch_iterator = train_dataloader | |
# Reset the past mems state at the beginning of each epoch if necessary. | |
if args.past_index >= 0: | |
self._past = None | |
steps_in_epoch = ( | |
len(epoch_iterator) | |
if len_dataloader is not None | |
else args.max_steps * args.gradient_accumulation_steps | |
) | |
self.control = self.callback_handler.on_epoch_begin(args, self.state, self.control) | |
if epoch == epochs_trained and resume_from_checkpoint is not None and steps_trained_in_current_epoch == 0: | |
self._load_rng_state(resume_from_checkpoint) | |
rng_to_sync = False | |
steps_skipped = 0 | |
if skip_first_batches is not None and steps_trained_in_current_epoch > 0: | |
epoch_iterator = skip_first_batches(epoch_iterator, steps_trained_in_current_epoch) | |
steps_skipped = steps_trained_in_current_epoch | |
steps_trained_in_current_epoch = 0 | |
rng_to_sync = True | |
step = -1 | |
for step, inputs in enumerate(epoch_iterator): | |
total_batched_samples += 1 | |
if rng_to_sync: | |
self._load_rng_state(resume_from_checkpoint) | |
rng_to_sync = False | |
# Skip past any already trained steps if resuming training | |
if steps_trained_in_current_epoch > 0: | |
steps_trained_in_current_epoch -= 1 | |
if steps_trained_progress_bar is not None: | |
steps_trained_progress_bar.update(1) | |
if steps_trained_in_current_epoch == 0: | |
self._load_rng_state(resume_from_checkpoint) | |
continue | |
elif steps_trained_progress_bar is not None: | |
steps_trained_progress_bar.close() | |
steps_trained_progress_bar = None | |
if step % args.gradient_accumulation_steps == 0: | |
self.control = self.callback_handler.on_step_begin(args, self.state, self.control) | |
with self.accelerator.accumulate(model): | |
tr_loss_step = self.training_step(model, inputs) | |
if ( | |
args.logging_nan_inf_filter | |
and not is_torch_tpu_available() | |
and (torch.isnan(tr_loss_step) or torch.isinf(tr_loss_step)) | |
): | |
# if loss is nan or inf simply add the average of previous logged losses | |
tr_loss += tr_loss / (1 + self.state.global_step - self._globalstep_last_logged) | |
else: | |
tr_loss += tr_loss_step | |
self.current_flos += float(self.floating_point_ops(inputs)) | |
# should this be under the accumulate context manager? | |
# the `or` condition of `steps_in_epoch <= args.gradient_accumulation_steps` is not covered | |
# in accelerate | |
if total_batched_samples % args.gradient_accumulation_steps == 0 or ( | |
# last step in epoch but step is always smaller than gradient_accumulation_steps | |
steps_in_epoch <= args.gradient_accumulation_steps | |
and (step + 1) == steps_in_epoch | |
): | |
# Gradient clipping | |
if args.max_grad_norm is not None and args.max_grad_norm > 0: | |
# deepspeed does its own clipping | |
if self.do_grad_scaling: | |
# Reduce gradients first for XLA | |
if is_torch_tpu_available(): | |
gradients = xm._fetch_gradients(self.optimizer) | |
xm.all_reduce("sum", gradients, scale=1.0 / xm.xrt_world_size()) | |
# AMP: gradients need unscaling | |
self.scaler.unscale_(self.optimizer) | |
if is_sagemaker_mp_enabled() and args.fp16: | |
self.optimizer.clip_master_grads(args.max_grad_norm) | |
elif hasattr(self.optimizer, "clip_grad_norm"): | |
# Some optimizers (like the sharded optimizer) have a specific way to do gradient clipping | |
self.optimizer.clip_grad_norm(args.max_grad_norm) | |
elif hasattr(model, "clip_grad_norm_"): | |
# Some models (like FullyShardedDDP) have a specific way to do gradient clipping | |
model.clip_grad_norm_(args.max_grad_norm) | |
elif self.use_apex: | |
# Revert to normal clipping otherwise, handling Apex or full precision | |
nn.utils.clip_grad_norm_( | |
amp.master_params(self.optimizer), | |
args.max_grad_norm, | |
) | |
else: | |
self.accelerator.clip_grad_norm_( | |
model.parameters(), | |
args.max_grad_norm, | |
) | |
# Optimizer step | |
optimizer_was_run = True | |
if is_torch_tpu_available(): | |
if self.do_grad_scaling: | |
self.scaler.step(self.optimizer) | |
self.scaler.update() | |
else: | |
xm.optimizer_step(self.optimizer) | |
elif self.do_grad_scaling: | |
scale_before = self.scaler.get_scale() | |
self.scaler.step(self.optimizer) | |
self.scaler.update() | |
scale_after = self.scaler.get_scale() | |
optimizer_was_run = scale_before <= scale_after | |
else: | |
self.optimizer.step() | |
optimizer_was_run = not self.accelerator.optimizer_step_was_skipped | |
if optimizer_was_run: | |
# Delay optimizer scheduling until metrics are generated | |
if not isinstance(self.lr_scheduler, torch.optim.lr_scheduler.ReduceLROnPlateau): | |
self.lr_scheduler.step() | |
model.zero_grad() | |
self.state.global_step += 1 | |
self.state.epoch = epoch + (step + 1 + steps_skipped) / steps_in_epoch | |
self.control = self.callback_handler.on_step_end(args, self.state, self.control) | |
self._maybe_log_save_evaluate(tr_loss, model, trial, epoch, ignore_keys_for_eval) | |
else: | |
self.control = self.callback_handler.on_substep_end(args, self.state, self.control) | |
if self.control.should_epoch_stop or self.control.should_training_stop: | |
break | |
if step < 0: | |
logger.warning( | |
"There seems to be not a single sample in your epoch_iterator, stopping training at step" | |
f" {self.state.global_step}! This is expected if you're using an IterableDataset and set" | |
f" num_steps ({max_steps}) higher than the number of available samples." | |
) | |
self.control.should_training_stop = True | |
self.control = self.callback_handler.on_epoch_end(args, self.state, self.control) | |
self._maybe_log_save_evaluate(tr_loss, model, trial, epoch, ignore_keys_for_eval) | |
if DebugOption.TPU_METRICS_DEBUG in self.args.debug: | |
if is_torch_tpu_available(): | |
# tpu-comment: Logging debug metrics for PyTorch/XLA (compile, execute times, ops, etc.) | |
xm.master_print(met.metrics_report()) | |
else: | |
logger.warning( | |
"You enabled PyTorch/XLA debug metrics but you don't have a TPU " | |
"configured. Check your training configuration if this is unexpected." | |
) | |
if self.control.should_training_stop: | |
break | |
if args.past_index and hasattr(self, "_past"): | |
# Clean the state at the end of training | |
delattr(self, "_past") | |
logger.info("\n\nTraining completed. Do not forget to share your model on huggingface.co/models =)\n\n") | |
if args.load_best_model_at_end and self.state.best_model_checkpoint is not None: | |
# Wait for everyone to get here so we are sure the model has been saved by process 0.
if is_torch_tpu_available(): | |
xm.rendezvous("load_best_model_at_end") | |
elif args.parallel_mode == ParallelMode.DISTRIBUTED: | |
dist.barrier() | |
elif is_sagemaker_mp_enabled(): | |
smp.barrier() | |
self._load_best_model() | |
# add remaining tr_loss | |
self._total_loss_scalar += tr_loss.item() | |
train_loss = self._total_loss_scalar / self.state.global_step | |
metrics = speed_metrics("train", start_time, num_samples=num_train_samples, num_steps=self.state.max_steps) | |
self.store_flos() | |
metrics["total_flos"] = self.state.total_flos | |
metrics["train_loss"] = train_loss | |
self.is_in_train = False | |
self._memory_tracker.stop_and_update_metrics(metrics) | |
self.log(metrics) | |
run_dir = self._get_output_dir(trial) | |
checkpoints_sorted = self._sorted_checkpoints(use_mtime=False, output_dir=run_dir) | |
# Delete the last checkpoint when save_total_limit=1 if it's different from the best checkpoint and process allowed to save. | |
if self.args.should_save and self.state.best_model_checkpoint is not None and self.args.save_total_limit == 1: | |
for checkpoint in checkpoints_sorted: | |
if checkpoint != self.state.best_model_checkpoint: | |
logger.info(f"Deleting older checkpoint [{checkpoint}] due to args.save_total_limit") | |
shutil.rmtree(checkpoint) | |
self.control = self.callback_handler.on_train_end(args, self.state, self.control) | |
return TrainOutput(self.state.global_step, train_loss, metrics) | |
def _get_output_dir(self, trial): | |
if self.hp_search_backend is not None and trial is not None: | |
if self.hp_search_backend == HPSearchBackend.OPTUNA: | |
run_id = trial.number | |
elif self.hp_search_backend == HPSearchBackend.RAY: | |
from ray import tune | |
run_id = tune.get_trial_id() | |
elif self.hp_search_backend == HPSearchBackend.SIGOPT: | |
run_id = trial.id | |
elif self.hp_search_backend == HPSearchBackend.WANDB: | |
import wandb | |
run_id = wandb.run.id | |
run_name = self.hp_name(trial) if self.hp_name is not None else f"run-{run_id}" | |
run_dir = os.path.join(self.args.output_dir, run_name) | |
else: | |
run_dir = self.args.output_dir | |
return run_dir | |
def _load_from_checkpoint(self, resume_from_checkpoint, model=None): | |
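"""
Reload model weights from a checkpoint directory, supporting safetensors, sharded checkpoints, SageMaker MP
and FSDP, and warn when the checkpoint was produced by a different Transformers version.
"""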
if model is None: | |
model = self.model | |
config_file = os.path.join(resume_from_checkpoint, CONFIG_NAME) | |
weights_file = os.path.join(resume_from_checkpoint, WEIGHTS_NAME) | |
weights_index_file = os.path.join(resume_from_checkpoint, WEIGHTS_INDEX_NAME) | |
safe_weights_file = os.path.join(resume_from_checkpoint, SAFE_WEIGHTS_NAME) | |
safe_weights_index_file = os.path.join(resume_from_checkpoint, SAFE_WEIGHTS_INDEX_NAME) | |
if not any( | |
os.path.isfile(f) for f in [weights_file, safe_weights_file, weights_index_file, safe_weights_index_file] | |
): | |
raise ValueError(f"Can't find a valid checkpoint at {resume_from_checkpoint}") | |
logger.info(f"Loading model from {resume_from_checkpoint}.") | |
if os.path.isfile(config_file): | |
config = PretrainedConfig.from_json_file(config_file) | |
checkpoint_version = config.transformers_version | |
if checkpoint_version is not None and checkpoint_version != __version__: | |
logger.warning( | |
f"You are resuming training from a checkpoint trained with {checkpoint_version} of " | |
f"Transformers but your current version is {__version__}. This is not recommended and could " | |
"yield to errors or unwanted behaviors." | |
) | |
if os.path.isfile(weights_file) or os.path.isfile(safe_weights_file): | |
# If the model is on the GPU, it still works! | |
if is_sagemaker_mp_enabled(): | |
if os.path.isfile(os.path.join(resume_from_checkpoint, "user_content.pt")): | |
# If the 'user_content.pt' file exists, load with the new smp api. | |
# Checkpoint must have been saved with the new smp api. | |
smp.resume_from_checkpoint( | |
path=resume_from_checkpoint, tag=WEIGHTS_NAME, partial=False, load_optimizer=False | |
) | |
else: | |
# If the 'user_content.pt' file does NOT exist, load with the old smp api. | |
# Checkpoint must have been saved with the old smp api. | |
if hasattr(self.args, "fp16") and self.args.fp16 is True: | |
logger.warning( | |
"Enabling FP16 and loading from smp < 1.10 checkpoint together is not suppported." | |
) | |
state_dict = torch.load(weights_file, map_location="cpu") | |
# Required for smp to not auto-translate state_dict from hf to smp (is already smp). | |
state_dict["_smp_is_partial"] = False | |
load_result = model.load_state_dict(state_dict, strict=True) | |
# release memory | |
del state_dict | |
elif self.is_fsdp_enabled: | |
self.accelerator.state.fsdp_plugin.load_model(self.accelerator, model, resume_from_checkpoint) | |
else: | |
# We load the model state dict on the CPU to avoid an OOM error. | |
if self.args.save_safetensors and os.path.isfile(safe_weights_file): | |
state_dict = safetensors.torch.load_file(safe_weights_file, device="cpu") | |
else: | |
state_dict = torch.load(weights_file, map_location="cpu") | |
# workaround for FSDP bug https://github.com/pytorch/pytorch/issues/82963 | |
# which takes *args instead of **kwargs | |
load_result = model.load_state_dict(state_dict, False) | |
# release memory | |
del state_dict | |
self._issue_warnings_after_load(load_result) | |
else: | |
# We load the sharded checkpoint | |
load_result = load_sharded_checkpoint( | |
model, resume_from_checkpoint, strict=is_sagemaker_mp_enabled(), prefer_safe=self.args.save_safetensors | |
) | |
if not is_sagemaker_mp_enabled(): | |
self._issue_warnings_after_load(load_result) | |
def _load_best_model(self): | |
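"""
Reload the weights of `self.state.best_model_checkpoint` into `self.model` (or the wrapped model for
SageMaker MP / DeepSpeed), handling safetensors, sharded checkpoints, FSDP and PEFT adapters.
"""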
logger.info(f"Loading best model from {self.state.best_model_checkpoint} (score: {self.state.best_metric}).") | |
best_model_path = os.path.join(self.state.best_model_checkpoint, WEIGHTS_NAME) | |
best_safe_model_path = os.path.join(self.state.best_model_checkpoint, SAFE_WEIGHTS_NAME) | |
best_adapter_model_path = os.path.join(self.state.best_model_checkpoint, ADAPTER_WEIGHTS_NAME) | |
best_safe_adapter_model_path = os.path.join(self.state.best_model_checkpoint, ADAPTER_SAFE_WEIGHTS_NAME) | |
model = self.model_wrapped if is_sagemaker_mp_enabled() else self.model | |
if ( | |
os.path.exists(best_model_path) | |
or os.path.exists(best_safe_model_path) | |
or os.path.exists(best_adapter_model_path) | |
or os.path.exists(best_safe_adapter_model_path) | |
): | |
if self.is_deepspeed_enabled: | |
deepspeed_load_checkpoint(self.model_wrapped, self.state.best_model_checkpoint) | |
else: | |
has_been_loaded = True | |
if is_sagemaker_mp_enabled(): | |
if os.path.isfile(os.path.join(self.state.best_model_checkpoint, "user_content.pt")): | |
# If the 'user_content.pt' file exists, load with the new smp api. | |
# Checkpoint must have been saved with the new smp api. | |
smp.resume_from_checkpoint( | |
path=self.state.best_model_checkpoint, | |
tag=WEIGHTS_NAME, | |
partial=False, | |
load_optimizer=False, | |
) | |
else: | |
# If the 'user_content.pt' file does NOT exist, load with the old smp api. | |
# Checkpoint must have been saved with the old smp api. | |
if self.args.save_safetensors and os.path.isfile(best_safe_model_path): | |
state_dict = safetensors.torch.load_file(best_safe_model_path, device="cpu") | |
else: | |
state_dict = torch.load(best_model_path, map_location="cpu") | |
state_dict["_smp_is_partial"] = False | |
load_result = model.load_state_dict(state_dict, strict=True) | |
elif self.is_fsdp_enabled: | |
self.accelerator.state.fsdp_plugin.load_model( | |
self.accelerator, model, self.state.best_model_checkpoint | |
) | |
else: | |
if is_peft_available() and isinstance(model, PeftModel): | |
# If training a model with PEFT & LoRA, assume that the adapter has been saved properly.
if hasattr(model, "active_adapter") and hasattr(model, "load_adapter"): | |
if os.path.exists(best_adapter_model_path) or os.path.exists(best_safe_adapter_model_path): | |
model.load_adapter(self.state.best_model_checkpoint, model.active_adapter) | |
# `load_adapter` has no return value, so build an empty load result; modify this when appropriate.
from torch.nn.modules.module import _IncompatibleKeys | |
load_result = _IncompatibleKeys([], []) | |
else: | |
logger.warning( | |
"The intermediate checkpoints of PEFT may not be saved correctly, " | |
f"using `TrainerCallback` to save {ADAPTER_WEIGHTS_NAME} in corresponding folders, " | |
"here are some examples https://github.com/huggingface/peft/issues/96" | |
) | |
has_been_loaded = False | |
else: | |
logger.warning("Could not load adapter model, make sure to have `peft>=0.3.0` installed") | |
has_been_loaded = False | |
else: | |
# We load the model state dict on the CPU to avoid an OOM error. | |
if self.args.save_safetensors and os.path.isfile(best_safe_model_path): | |
state_dict = safetensors.torch.load_file(best_safe_model_path, device="cpu") | |
else: | |
state_dict = torch.load(best_model_path, map_location="cpu") | |
# If the model is on the GPU, it still works! | |
# workaround for FSDP bug https://github.com/pytorch/pytorch/issues/82963 | |
# which takes *args instead of **kwargs | |
load_result = model.load_state_dict(state_dict, False) | |
if not is_sagemaker_mp_enabled() and has_been_loaded: | |
self._issue_warnings_after_load(load_result) | |
elif os.path.exists(os.path.join(self.state.best_model_checkpoint, WEIGHTS_INDEX_NAME)): | |
load_result = load_sharded_checkpoint( | |
model, self.state.best_model_checkpoint, strict=is_sagemaker_mp_enabled() | |
) | |
if not is_sagemaker_mp_enabled(): | |
self._issue_warnings_after_load(load_result) | |
else: | |
logger.warning( | |
f"Could not locate the best model at {best_model_path}, if you are running a distributed training " | |
"on multiple nodes, you should activate `--save_on_each_node`." | |
) | |
def _issue_warnings_after_load(self, load_result): | |
if len(load_result.missing_keys) != 0: | |
if self.model._keys_to_ignore_on_save is not None and set(load_result.missing_keys) == set( | |
self.model._keys_to_ignore_on_save | |
): | |
self.model.tie_weights() | |
else: | |
logger.warning(f"There were missing keys in the checkpoint model loaded: {load_result.missing_keys}.") | |
if len(load_result.unexpected_keys) != 0: | |
logger.warning( | |
f"There were unexpected keys in the checkpoint model loaded: {load_result.unexpected_keys}." | |
) | |
def _maybe_log_save_evaluate(self, tr_loss, model, trial, epoch, ignore_keys_for_eval): | |
if self.control.should_log: | |
if is_torch_tpu_available(): | |
xm.mark_step() | |
logs: Dict[str, float] = {} | |
# all_gather + mean() to get average loss over all processes | |
tr_loss_scalar = self._nested_gather(tr_loss).mean().item() | |
# reset tr_loss to zero | |
tr_loss -= tr_loss | |
logs["loss"] = round(tr_loss_scalar / (self.state.global_step - self._globalstep_last_logged), 4) | |
logs["learning_rate"] = self._get_learning_rate() | |
self._total_loss_scalar += tr_loss_scalar | |
self._globalstep_last_logged = self.state.global_step | |
self.store_flos() | |
self.log(logs) | |
metrics = None | |
if self.control.should_evaluate: | |
if isinstance(self.eval_dataset, dict): | |
metrics = {} | |
for eval_dataset_name, eval_dataset in self.eval_dataset.items(): | |
dataset_metrics = self.evaluate( | |
eval_dataset=eval_dataset, | |
ignore_keys=ignore_keys_for_eval, | |
metric_key_prefix=f"eval_{eval_dataset_name}", | |
) | |
metrics.update(dataset_metrics) | |
else: | |
metrics = self.evaluate(ignore_keys=ignore_keys_for_eval) | |
self._report_to_hp_search(trial, self.state.global_step, metrics) | |
# Run delayed LR scheduler now that metrics are populated | |
if isinstance(self.lr_scheduler, torch.optim.lr_scheduler.ReduceLROnPlateau): | |
metric_to_check = self.args.metric_for_best_model | |
if not metric_to_check.startswith("eval_"): | |
metric_to_check = f"eval_{metric_to_check}" | |
self.lr_scheduler.step(metrics[metric_to_check]) | |
if self.control.should_save: | |
self._save_checkpoint(model, trial, metrics=metrics) | |
self.control = self.callback_handler.on_save(self.args, self.state, self.control) | |
def _load_rng_state(self, checkpoint): | |
# Load RNG states from `checkpoint` | |
if checkpoint is None: | |
return | |
if self.args.world_size > 1: | |
process_index = self.args.process_index | |
rng_file = os.path.join(checkpoint, f"rng_state_{process_index}.pth") | |
if not os.path.isfile(rng_file): | |
logger.info( | |
f"Didn't find an RNG file for process {process_index}, if you are resuming a training that " | |
"wasn't launched in a distributed fashion, reproducibility is not guaranteed." | |
) | |
return | |
else: | |
rng_file = os.path.join(checkpoint, "rng_state.pth") | |
if not os.path.isfile(rng_file): | |
logger.info( | |
"Didn't find an RNG file, if you are resuming a training that was launched in a distributed " | |
"fashion, reproducibility is not guaranteed." | |
) | |
return | |
checkpoint_rng_state = torch.load(rng_file) | |
random.setstate(checkpoint_rng_state["python"]) | |
np.random.set_state(checkpoint_rng_state["numpy"]) | |
torch.random.set_rng_state(checkpoint_rng_state["cpu"]) | |
if torch.cuda.is_available(): | |
if self.args.parallel_mode == ParallelMode.DISTRIBUTED: | |
torch.cuda.random.set_rng_state_all(checkpoint_rng_state["cuda"]) | |
else: | |
try: | |
torch.cuda.random.set_rng_state(checkpoint_rng_state["cuda"]) | |
except Exception as e: | |
logger.info( | |
f"Didn't manage to set back the RNG states of the GPU because of the following error:\n {e}" | |
"\nThis won't yield the same results as if the training had not been interrupted." | |
) | |
if is_torch_tpu_available(): | |
xm.set_rng_state(checkpoint_rng_state["xla"]) | |
def _save_checkpoint(self, model, trial, metrics=None): | |
# In all cases, including ddp/dp/deepspeed, self.model is always a reference to the model we | |
# want to save except FullyShardedDDP. | |
# assert unwrap_model(model) is self.model, "internal model should be a reference to self.model" | |
# Save model checkpoint | |
checkpoint_folder = f"{PREFIX_CHECKPOINT_DIR}" # changed by homeway, 20230711 | |
if self.hp_search_backend is None and trial is None: | |
self.store_flos() | |
run_dir = self._get_output_dir(trial=trial) | |
output_dir = os.path.join(run_dir, checkpoint_folder) | |
self.save_model(output_dir, _internal_call=True) | |
if self.is_deepspeed_enabled: | |
# Under ZeRO-3 the model file itself doesn't get saved since it's bogus, unless the deepspeed
# config `stage3_gather_16bit_weights_on_model_save` is True.
self.model_wrapped.save_checkpoint(output_dir) | |
# Save optimizer and scheduler | |
if self.sharded_ddp == ShardedDDPOption.SIMPLE: | |
self.optimizer.consolidate_state_dict() | |
if self.fsdp: | |
# FSDP has a different interface for saving optimizer states. | |
# Needs to be called on all ranks to gather all states. | |
# full_optim_state_dict will be deprecated after PyTorch 2.2!
full_osd = self.model.__class__.full_optim_state_dict(self.model, self.optimizer) | |
if is_torch_tpu_available(): | |
xm.rendezvous("saving_optimizer_states") | |
xm.save(self.optimizer.state_dict(), os.path.join(output_dir, OPTIMIZER_NAME)) | |
with warnings.catch_warnings(record=True) as caught_warnings: | |
xm.save(self.lr_scheduler.state_dict(), os.path.join(output_dir, SCHEDULER_NAME)) | |
reissue_pt_warnings(caught_warnings) | |
elif is_sagemaker_mp_enabled(): | |
opt_state_dict = self.optimizer.local_state_dict(gather_if_shard=False) | |
smp.barrier() | |
if smp.rdp_rank() == 0 or smp.state.cfg.shard_optimizer_state: | |
smp.save( | |
opt_state_dict, | |
os.path.join(output_dir, OPTIMIZER_NAME), | |
partial=True, | |
v3=smp.state.cfg.shard_optimizer_state, | |
) | |
if self.args.should_save: | |
with warnings.catch_warnings(record=True) as caught_warnings: | |
torch.save(self.lr_scheduler.state_dict(), os.path.join(output_dir, SCHEDULER_NAME)) | |
reissue_pt_warnings(caught_warnings) | |
if self.do_grad_scaling: | |
torch.save(self.scaler.state_dict(), os.path.join(output_dir, SCALER_NAME)) | |
elif self.args.should_save and not self.is_deepspeed_enabled: | |
# deepspeed.save_checkpoint above saves model/optim/sched | |
if self.fsdp: | |
torch.save(full_osd, os.path.join(output_dir, OPTIMIZER_NAME)) | |
else: | |
torch.save(self.optimizer.state_dict(), os.path.join(output_dir, OPTIMIZER_NAME)) | |
with warnings.catch_warnings(record=True) as caught_warnings: | |
torch.save(self.lr_scheduler.state_dict(), os.path.join(output_dir, SCHEDULER_NAME)) | |
reissue_pt_warnings(caught_warnings) | |
if self.do_grad_scaling: | |
torch.save(self.scaler.state_dict(), os.path.join(output_dir, SCALER_NAME)) | |
# Determine the new best metric / best model checkpoint | |
if metrics is not None and self.args.metric_for_best_model is not None: | |
metric_to_check = self.args.metric_for_best_model | |
if not metric_to_check.startswith("eval_"): | |
metric_to_check = f"eval_{metric_to_check}" | |
metric_value = metrics[metric_to_check] | |
operator = np.greater if self.args.greater_is_better else np.less | |
if ( | |
self.state.best_metric is None | |
or self.state.best_model_checkpoint is None | |
or operator(metric_value, self.state.best_metric) | |
): | |
self.state.best_metric = metric_value | |
self.state.best_model_checkpoint = output_dir | |
# Save the Trainer state | |
if self.args.should_save: | |
self.state.save_to_json(os.path.join(output_dir, TRAINER_STATE_NAME)) | |
# Save RNG state in non-distributed training | |
rng_states = { | |
"python": random.getstate(), | |
"numpy": np.random.get_state(), | |
"cpu": torch.random.get_rng_state(), | |
} | |
if torch.cuda.is_available(): | |
if self.args.parallel_mode == ParallelMode.DISTRIBUTED: | |
# Save the CUDA RNG state of all devices visible to this process (this also covers DataParallel).
rng_states["cuda"] = torch.cuda.random.get_rng_state_all() | |
else: | |
rng_states["cuda"] = torch.cuda.random.get_rng_state() | |
if is_torch_tpu_available(): | |
rng_states["xla"] = xm.get_rng_state() | |
# A process can arrive here before the process 0 has a chance to save the model, in which case output_dir may | |
# not yet exist. | |
os.makedirs(output_dir, exist_ok=True) | |
if self.args.world_size <= 1: | |
torch.save(rng_states, os.path.join(output_dir, "rng_state.pth")) | |
else: | |
torch.save(rng_states, os.path.join(output_dir, f"rng_state_{self.args.process_index}.pth")) | |
if self.args.push_to_hub: | |
self._push_from_checkpoint(output_dir) | |
# Maybe delete some older checkpoints. | |
if self.args.should_save: | |
self._rotate_checkpoints(use_mtime=True, output_dir=run_dir) | |
def _load_optimizer_and_scheduler(self, checkpoint): | |
"""If optimizer and scheduler states exist, load them.""" | |
if checkpoint is None: | |
return | |
if self.is_deepspeed_enabled: | |
# deepspeed loads optimizer/lr_scheduler together with the model in deepspeed_init | |
return | |
checkpoint_file_exists = ( | |
glob.glob(os.path.join(checkpoint, OPTIMIZER_NAME) + "_*") | |
if is_sagemaker_mp_enabled() | |
else os.path.isfile(os.path.join(checkpoint, OPTIMIZER_NAME)) | |
) | |
if checkpoint_file_exists and os.path.isfile(os.path.join(checkpoint, SCHEDULER_NAME)): | |
# Load in optimizer and scheduler states | |
if is_torch_tpu_available(): | |
# On TPU we have to take some extra precautions to properly load the states on the right device. | |
optimizer_state = torch.load(os.path.join(checkpoint, OPTIMIZER_NAME), map_location="cpu") | |
with warnings.catch_warnings(record=True) as caught_warnings: | |
lr_scheduler_state = torch.load(os.path.join(checkpoint, SCHEDULER_NAME), map_location="cpu") | |
reissue_pt_warnings(caught_warnings) | |
xm.send_cpu_data_to_device(optimizer_state, self.args.device) | |
xm.send_cpu_data_to_device(lr_scheduler_state, self.args.device) | |
self.optimizer.load_state_dict(optimizer_state) | |
self.lr_scheduler.load_state_dict(lr_scheduler_state) | |
else: | |
if is_sagemaker_mp_enabled(): | |
if os.path.isfile(os.path.join(checkpoint, "user_content.pt")): | |
# Optimizer checkpoint was saved with smp >= 1.10 | |
def opt_load_hook(mod, opt): | |
opt.load_state_dict(smp.load(os.path.join(checkpoint, OPTIMIZER_NAME), partial=True)) | |
else: | |
# Optimizer checkpoint was saved with smp < 1.10 | |
def opt_load_hook(mod, opt): | |
if IS_SAGEMAKER_MP_POST_1_10: | |
opt.load_state_dict( | |
smp.load(os.path.join(checkpoint, OPTIMIZER_NAME), partial=True, back_compat=True) | |
) | |
else: | |
opt.load_state_dict(smp.load(os.path.join(checkpoint, OPTIMIZER_NAME), partial=True)) | |
self.model_wrapped.register_post_step_hook(opt_load_hook) | |
else: | |
# We use the CPU when training on one GPU to avoid an OOM of GPU RAM when training big models.
# In distributed training however, we load directly on each GPU and risk a GPU OOM, as it's more
# likely to get an OOM on CPU (since we would load the optimizer state num_gpu times on CPU).
map_location = self.args.device if self.args.world_size > 1 else "cpu" | |
if self.fsdp: | |
full_osd = None | |
# In FSDP, we need to load the full optimizer state dict on rank 0 and then shard it | |
if self.args.process_index == 0: | |
full_osd = torch.load(os.path.join(checkpoint, OPTIMIZER_NAME)) | |
# call scatter_full_optim_state_dict on all ranks | |
sharded_osd = self.model.__class__.scatter_full_optim_state_dict(full_osd, self.model) | |
self.optimizer.load_state_dict(sharded_osd) | |
else: | |
self.optimizer.load_state_dict( | |
torch.load(os.path.join(checkpoint, OPTIMIZER_NAME), map_location=map_location) | |
) | |
with warnings.catch_warnings(record=True) as caught_warnings: | |
self.lr_scheduler.load_state_dict(torch.load(os.path.join(checkpoint, SCHEDULER_NAME))) | |
reissue_pt_warnings(caught_warnings) | |
if self.do_grad_scaling and os.path.isfile(os.path.join(checkpoint, SCALER_NAME)): | |
self.scaler.load_state_dict(torch.load(os.path.join(checkpoint, SCALER_NAME))) | |
def hyperparameter_search( | |
self, | |
hp_space: Optional[Callable[["optuna.Trial"], Dict[str, float]]] = None, | |
compute_objective: Optional[Callable[[Dict[str, float]], float]] = None, | |
n_trials: int = 20, | |
direction: str = "minimize", | |
backend: Optional[Union["str", HPSearchBackend]] = None, | |
hp_name: Optional[Callable[["optuna.Trial"], str]] = None, | |
**kwargs, | |
) -> BestRun: | |
""" | |
Launch a hyperparameter search using `optuna`, `Ray Tune` or `SigOpt`. The optimized quantity is determined
by `compute_objective`, which defaults to a function returning the evaluation loss when no metric is provided,
and the sum of all metrics otherwise.
<Tip warning={true}> | |
To use this method, you need to have provided a `model_init` when initializing your [`Trainer`]: we need to | |
reinitialize the model at each new run. This is incompatible with the `optimizers` argument, so you need to | |
subclass [`Trainer`] and override the method [`~Trainer.create_optimizer_and_scheduler`] for custom | |
optimizer/scheduler. | |
</Tip> | |
Args: | |
hp_space (`Callable[["optuna.Trial"], Dict[str, float]]`, *optional*): | |
A function that defines the hyperparameter search space. Will default to | |
[`~trainer_utils.default_hp_space_optuna`] or [`~trainer_utils.default_hp_space_ray`] or | |
[`~trainer_utils.default_hp_space_sigopt`] depending on your backend. | |
compute_objective (`Callable[[Dict[str, float]], float]`, *optional*): | |
A function computing the objective to minimize or maximize from the metrics returned by the `evaluate` | |
method. Will default to [`~trainer_utils.default_compute_objective`]. | |
n_trials (`int`, *optional*, defaults to 20):
The number of trial runs to test. | |
direction (`str`, *optional*, defaults to `"minimize"`): | |
Whether to optimize for a greater or lower objective. Can be `"minimize"` or `"maximize"`, you should pick
`"minimize"` when optimizing the validation loss, `"maximize"` when optimizing one or several metrics. | |
backend (`str` or [`~training_utils.HPSearchBackend`], *optional*): | |
The backend to use for hyperparameter search. Will default to optuna or Ray Tune or SigOpt, depending | |
on which one is installed. If all are installed, will default to optuna. | |
hp_name (`Callable[["optuna.Trial"], str]`, *optional*):
A function that defines the trial/run name. Will default to None. | |
kwargs (`Dict[str, Any]`, *optional*): | |
Additional keyword arguments passed along to `optuna.create_study` or `ray.tune.run`. For more | |
information see: | |
- the documentation of | |
[optuna.create_study](https://optuna.readthedocs.io/en/stable/reference/generated/optuna.study.create_study.html) | |
- the documentation of [tune.run](https://docs.ray.io/en/latest/tune/api_docs/execution.html#tune-run) | |
- the documentation of [sigopt](https://app.sigopt.com/docs/endpoints/experiments/create) | |
Returns: | |
[`trainer_utils.BestRun`]: All the information about the best run. Experiment summary can be found in | |
`run_summary` attribute for Ray backend. | |
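Example (a minimal sketch with the optuna backend; `my_trainer` is assumed to have been created with a
`model_init` function, and `my_hp_space` is a user-defined search space):

    def my_hp_space(trial):
        return {"learning_rate": trial.suggest_float("learning_rate", 1e-6, 1e-4, log=True)}

    best_run = my_trainer.hyperparameter_search(
        hp_space=my_hp_space, n_trials=10, direction="minimize", backend="optuna"
    )
    print(best_run.hyperparameters)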
""" | |
if backend is None: | |
backend = default_hp_search_backend() | |
if backend is None: | |
raise RuntimeError( | |
"At least one of optuna or ray should be installed. " | |
"To install optuna run `pip install optuna`. " | |
"To install ray run `pip install ray[tune]`. " | |
"To install sigopt run `pip install sigopt`." | |
) | |
backend = HPSearchBackend(backend) | |
if backend == HPSearchBackend.OPTUNA and not is_optuna_available(): | |
raise RuntimeError("You picked the optuna backend, but it is not installed. Use `pip install optuna`.") | |
if backend == HPSearchBackend.RAY and not is_ray_tune_available(): | |
raise RuntimeError( | |
"You picked the Ray Tune backend, but it is not installed. Use `pip install 'ray[tune]'`." | |
) | |
if backend == HPSearchBackend.SIGOPT and not is_sigopt_available(): | |
raise RuntimeError("You picked the sigopt backend, but it is not installed. Use `pip install sigopt`.") | |
if backend == HPSearchBackend.WANDB and not is_wandb_available(): | |
raise RuntimeError("You picked the wandb backend, but it is not installed. Use `pip install wandb`.") | |
self.hp_search_backend = backend | |
if self.model_init is None: | |
raise RuntimeError( | |
"To use hyperparameter search, you need to pass your model through a model_init function." | |
) | |
self.hp_space = default_hp_space[backend] if hp_space is None else hp_space | |
self.hp_name = hp_name | |
self.compute_objective = default_compute_objective if compute_objective is None else compute_objective | |
backend_dict = { | |
HPSearchBackend.OPTUNA: run_hp_search_optuna, | |
HPSearchBackend.RAY: run_hp_search_ray, | |
HPSearchBackend.SIGOPT: run_hp_search_sigopt, | |
HPSearchBackend.WANDB: run_hp_search_wandb, | |
} | |
best_run = backend_dict[backend](self, n_trials, direction, **kwargs) | |
self.hp_search_backend = None | |
return best_run | |
def log(self, logs: Dict[str, float]) -> None: | |
""" | |
Log `logs` on the various objects watching training. | |
Subclass and override this method to inject custom behavior. | |
Args: | |
logs (`Dict[str, float]`): | |
The values to log. | |
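Example (a minimal sketch; `my_trainer` is an assumed existing [`Trainer`] instance):

    my_trainer.log({"my_custom_metric": 0.42})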
""" | |
if self.state.epoch is not None: | |
logs["epoch"] = round(self.state.epoch, 2) | |
output = {**logs, **{"step": self.state.global_step}} | |
self.state.log_history.append(output) | |
self.control = self.callback_handler.on_log(self.args, self.state, self.control, logs) | |
def _prepare_input(self, data: Union[torch.Tensor, Any]) -> Union[torch.Tensor, Any]: | |
""" | |
Prepares one `data` before feeding it to the model, be it a tensor or a nested list/dictionary of tensors. | |
""" | |
if isinstance(data, Mapping): | |
return type(data)({k: self._prepare_input(v) for k, v in data.items()}) | |
elif isinstance(data, (tuple, list)): | |
return type(data)(self._prepare_input(v) for v in data) | |
elif isinstance(data, torch.Tensor): | |
kwargs = {"device": self.args.device} | |
if self.is_deepspeed_enabled and (torch.is_floating_point(data) or torch.is_complex(data)): | |
# NLP models inputs are int/uint and those get adjusted to the right dtype of the | |
# embedding. Other models such as wav2vec2's inputs are already float and thus | |
# may need special handling to match the dtypes of the model | |
kwargs.update({"dtype": self.accelerator.state.deepspeed_plugin.hf_ds_config.dtype()}) | |
return data.to(**kwargs) | |
return data | |
def _prepare_inputs(self, inputs: Dict[str, Union[torch.Tensor, Any]]) -> Dict[str, Union[torch.Tensor, Any]]: | |
""" | |
Prepare `inputs` before feeding them to the model, converting them to tensors if they are not already and | |
handling potential state. | |
""" | |
inputs = self._prepare_input(inputs) | |
if len(inputs) == 0: | |
raise ValueError( | |
"The batch received was empty, your model won't be able to train on it. Double-check that your " | |
f"training dataset contains keys expected by the model: {','.join(self._signature_columns)}." | |
) | |
if self.args.past_index >= 0 and self._past is not None: | |
inputs["mems"] = self._past | |
return inputs | |
def compute_loss_context_manager(self): | |
""" | |
A helper wrapper to group together context managers. | |
""" | |
return self.autocast_smart_context_manager() | |
def autocast_smart_context_manager(self, cache_enabled: Optional[bool] = True): | |
""" | |
A helper wrapper that creates an appropriate context manager for `autocast` while feeding it the desired | |
arguments, depending on the situation. | |
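Example (a sketch mirroring how `training_step` uses this helper internally):

    with self.autocast_smart_context_manager():
        loss = self.compute_loss(model, inputs)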
""" | |
if self.use_cuda_amp or self.use_cpu_amp: | |
if is_torch_greater_or_equal_than_1_10: | |
ctx_manager = ( | |
torch.cpu.amp.autocast(cache_enabled=cache_enabled, dtype=self.amp_dtype) | |
if self.use_cpu_amp | |
else torch.cuda.amp.autocast(cache_enabled=cache_enabled, dtype=self.amp_dtype) | |
) | |
else: | |
ctx_manager = torch.cuda.amp.autocast() | |
else: | |
ctx_manager = contextlib.nullcontext() if sys.version_info >= (3, 7) else contextlib.suppress() | |
return ctx_manager | |
def training_step(self, model: nn.Module, inputs: Dict[str, Union[torch.Tensor, Any]]) -> torch.Tensor: | |
""" | |
Perform a training step on a batch of inputs. | |
Subclass and override to inject custom behavior. | |
Args: | |
model (`nn.Module`): | |
The model to train. | |
inputs (`Dict[str, Union[torch.Tensor, Any]]`): | |
The inputs and targets of the model. | |
The dictionary will be unpacked before being fed to the model. Most models expect the targets under the | |
argument `labels`. Check your model's documentation for all accepted arguments. | |
Return: | |
`torch.Tensor`: The tensor with training loss on this batch. | |
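Example (a minimal sketch of a subclass override; `MyTrainer` is a hypothetical name):

    class MyTrainer(Trainer):
        def training_step(self, model, inputs):
            loss = super().training_step(model, inputs)
            # inject custom behavior here, e.g. extra logging of the per-step loss
            return loss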
""" | |
model.train() | |
inputs = self._prepare_inputs(inputs) | |
if is_sagemaker_mp_enabled(): | |
loss_mb = smp_forward_backward(model, inputs, self.args.gradient_accumulation_steps) | |
return loss_mb.reduce_mean().detach().to(self.args.device) | |
with self.compute_loss_context_manager(): | |
loss = self.compute_loss(model, inputs) | |
if self.args.n_gpu > 1: | |
loss = loss.mean() # mean() to average on multi-gpu parallel training | |
if self.do_grad_scaling: | |
self.scaler.scale(loss).backward() | |
elif self.use_apex: | |
with amp.scale_loss(loss, self.optimizer) as scaled_loss: | |
scaled_loss.backward() | |
else: | |
self.accelerator.backward(loss) | |
return loss.detach() / self.args.gradient_accumulation_steps | |
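# Illustrative subclass sketch (hypothetical `MyTrainer`, not part of this file) for the
# "subclass and override" pattern described in the docstring above; it defers to the parent
# implementation and only adds a side effect:
#
#     from transformers import Trainer
#
#     class MyTrainer(Trainer):
#         def training_step(self, model, inputs):
#             loss = super().training_step(model, inputs)
#             # custom per-step bookkeeping could go here (the returned loss is already
#             # detached and divided by gradient_accumulation_steps)
#             return loss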
def compute_loss(self, model, inputs, return_outputs=False): | |
""" | |
How the loss is computed by Trainer. By default, all models return the loss in the first element. | |
Subclass and override for custom behavior. | |
""" | |
if self.label_smoother is not None and "labels" in inputs: | |
labels = inputs.pop("labels") | |
else: | |
labels = None | |
outputs = model(**inputs) | |
# Save past state if it exists | |
# TODO: this needs to be fixed and made cleaner later. | |
if self.args.past_index >= 0: | |
self._past = outputs[self.args.past_index] | |
if labels is not None: | |
if unwrap_model(model)._get_name() in MODEL_FOR_CAUSAL_LM_MAPPING_NAMES.values(): | |
loss = self.label_smoother(outputs, labels, shift_labels=True) | |
else: | |
loss = self.label_smoother(outputs, labels) | |
else: | |
if isinstance(outputs, dict) and "loss" not in outputs: | |
raise ValueError( | |
"The model did not return a loss from the inputs, only the following keys: " | |
f"{','.join(outputs.keys())}. For reference, the inputs it received are {','.join(inputs.keys())}." | |
) | |
# We don't use .loss here since the model may return tuples instead of ModelOutput. | |
loss = outputs["loss"] if isinstance(outputs, dict) else outputs[0] | |
return (loss, outputs) if return_outputs else loss | |
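# Illustrative subclass sketch (hypothetical `WeightedLossTrainer` and `self.class_weights`,
# which the user would have to set themselves) showing a custom loss via the override hook
# mentioned in the docstring above:
#
#     import torch
#     from transformers import Trainer
#
#     class WeightedLossTrainer(Trainer):
#         def compute_loss(self, model, inputs, return_outputs=False):
#             labels = inputs.pop("labels")
#             outputs = model(**inputs)
#             logits = outputs.logits
#             loss_fct = torch.nn.CrossEntropyLoss(weight=self.class_weights)
#             loss = loss_fct(logits.view(-1, logits.size(-1)), labels.view(-1))
#             return (loss, outputs) if return_outputs else loss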
def is_local_process_zero(self) -> bool: | |
""" | |
Whether or not this process is the local main process (e.g., the main process of its machine, if training in a
distributed fashion on several machines).
""" | |
return self.args.local_process_index == 0 | |
def is_world_process_zero(self) -> bool: | |
""" | |
Whether or not this process is the global main process (when training in a distributed fashion on several | |
machines, this is only going to be `True` for one process). | |
""" | |
# Special case for SageMaker ModelParallel since their process_index is dp_process_index, not the global
# process index. | |
if is_sagemaker_mp_enabled(): | |
return smp.rank() == 0 | |
else: | |
return self.args.process_index == 0 | |
def save_model(self, output_dir: Optional[str] = None, _internal_call: bool = False): | |
""" | |
Will save the model, so you can reload it using `from_pretrained()`. | |
Will only save from the main process. | |
""" | |
if output_dir is None: | |
output_dir = self.args.output_dir | |
if is_torch_tpu_available(): | |
self._save_tpu(output_dir) | |
elif is_sagemaker_mp_enabled(): | |
# Calling the state_dict needs to be done on the wrapped model and on all processes. | |
os.makedirs(output_dir, exist_ok=True) | |
state_dict = self.model_wrapped.state_dict() | |
if self.args.should_save: | |
self._save(output_dir, state_dict=state_dict) | |
if IS_SAGEMAKER_MP_POST_1_10: | |
# 'user_content.pt' indicates model state_dict saved with smp >= 1.10 | |
Path(os.path.join(output_dir, "user_content.pt")).touch() | |
elif ( | |
ShardedDDPOption.ZERO_DP_2 in self.args.sharded_ddp | |
or ShardedDDPOption.ZERO_DP_3 in self.args.sharded_ddp | |
or self.fsdp is not None | |
or self.is_fsdp_enabled | |
): | |
if self.is_fsdp_enabled: | |
os.makedirs(output_dir, exist_ok=True) | |
self.accelerator.state.fsdp_plugin.save_model(self.accelerator, self.model, output_dir) | |
else: | |
state_dict = self.model.state_dict() | |
if self.args.should_save: | |
self._save(output_dir, state_dict=state_dict) | |
elif self.is_deepspeed_enabled: | |
# this takes care of everything as long as we aren't under zero3 | |
if self.args.should_save: | |
self._save(output_dir) | |
if is_deepspeed_zero3_enabled(): | |
# It's too complicated to try to override the different places where the weights dump gets
# saved, so since under zero3 the file is bogus, simply delete it. The user should either
# use the deepspeed checkpoint to resume, or run zero_to_fp32.py stored in the checkpoint
# to recover the full weights.
if self.args.should_save: | |
file = os.path.join(output_dir, WEIGHTS_NAME) | |
if os.path.isfile(file): | |
# logger.info(f"deepspeed zero3: removing {file}, see zero_to_fp32.py to recover weights") | |
os.remove(file) | |
# now save the real model if stage3_gather_16bit_weights_on_model_save=True | |
# if false it will not be saved. | |
# This must be called on all ranks | |
if not self.model_wrapped.save_16bit_model(output_dir, WEIGHTS_NAME): | |
logger.warning( | |
"deepspeed.save_16bit_model didn't save the model, since" | |
" stage3_gather_16bit_weights_on_model_save=false. Saving the full checkpoint instead, use" | |
" zero_to_fp32.py to recover weights" | |
) | |
self.model_wrapped.save_checkpoint(output_dir) | |
elif self.args.should_save: | |
self._save(output_dir) | |
# Push to the Hub when `save_model` is called by the user. | |
if self.args.push_to_hub and not _internal_call: | |
self.push_to_hub(commit_message="Model save") | |
def _save_tpu(self, output_dir: Optional[str] = None): | |
output_dir = output_dir if output_dir is not None else self.args.output_dir | |
logger.info(f"Saving model checkpoint to {output_dir}") | |
if xm.is_master_ordinal(): | |
os.makedirs(output_dir, exist_ok=True) | |
torch.save(self.args, os.path.join(output_dir, TRAINING_ARGS_NAME)) | |
# Save a trained model and configuration using `save_pretrained()`. | |
# They can then be reloaded using `from_pretrained()` | |
xm.rendezvous("saving_checkpoint") | |
if not isinstance(self.model, PreTrainedModel): | |
if isinstance(unwrap_model(self.model), PreTrainedModel): | |
unwrap_model(self.model).save_pretrained( | |
output_dir, | |
is_main_process=self.args.should_save, | |
state_dict=self.model.state_dict(), | |
save_function=xm.save, | |
) | |
else: | |
logger.info("Trainer.model is not a `PreTrainedModel`, only saving its state dict.") | |
state_dict = self.model.state_dict() | |
xm.save(state_dict, os.path.join(output_dir, WEIGHTS_NAME)) | |
else: | |
self.model.save_pretrained(output_dir, is_main_process=self.args.should_save, save_function=xm.save) | |
if self.tokenizer is not None and self.args.should_save: | |
self.tokenizer.save_pretrained(output_dir) | |
def _save(self, output_dir: Optional[str] = None, state_dict=None): | |
# If we are executing this function, we are the process zero, so we don't check for that. | |
output_dir = output_dir if output_dir is not None else self.args.output_dir | |
os.makedirs(output_dir, exist_ok=True) | |
logger.info(f"Saving model checkpoint to {output_dir}") | |
supported_classes = (PreTrainedModel,) if not is_peft_available() else (PreTrainedModel, PeftModel) | |
# Save a trained model and configuration using `save_pretrained()`. | |
# They can then be reloaded using `from_pretrained()` | |
if not isinstance(self.model, supported_classes): | |
if state_dict is None: | |
state_dict = self.model.state_dict() | |
if isinstance(unwrap_model(self.model), supported_classes): | |
unwrap_model(self.model).save_pretrained( | |
output_dir, state_dict=state_dict, safe_serialization=self.args.save_safetensors | |
) | |
else: | |
logger.info("Trainer.model is not a `PreTrainedModel`, only saving its state dict.") | |
if self.args.save_safetensors: | |
safetensors.torch.save_file(state_dict, os.path.join(output_dir, SAFE_WEIGHTS_NAME)) | |
else: | |
torch.save(state_dict, os.path.join(output_dir, WEIGHTS_NAME)) | |
else: | |
self.model.save_pretrained( | |
output_dir, state_dict=state_dict, safe_serialization=self.args.save_safetensors | |
) | |
if self.tokenizer is not None: | |
self.tokenizer.save_pretrained(output_dir) | |
# Good practice: save your training arguments together with the trained model | |
torch.save(self.args, os.path.join(output_dir, TRAINING_ARGS_NAME)) | |
def store_flos(self): | |
# Storing the number of floating-point operations that went into the model | |
if self.args.parallel_mode == ParallelMode.DISTRIBUTED: | |
self.state.total_flos += ( | |
distributed_broadcast_scalars([self.current_flos], device=self.args.device).sum().item() | |
) | |
self.current_flos = 0 | |
else: | |
self.state.total_flos += self.current_flos | |
self.current_flos = 0 | |
def _sorted_checkpoints( | |
self, output_dir=None, checkpoint_prefix=PREFIX_CHECKPOINT_DIR, use_mtime=False | |
) -> List[str]: | |
ordering_and_checkpoint_path = [] | |
glob_checkpoints = [str(x) for x in Path(output_dir).glob(f"{checkpoint_prefix}-*") if os.path.isdir(x)] | |
for path in glob_checkpoints: | |
if use_mtime: | |
ordering_and_checkpoint_path.append((os.path.getmtime(path), path)) | |
else: | |
regex_match = re.match(f".*{checkpoint_prefix}-([0-9]+)", path) | |
if regex_match is not None and regex_match.groups() is not None: | |
ordering_and_checkpoint_path.append((int(regex_match.groups()[0]), path)) | |
checkpoints_sorted = sorted(ordering_and_checkpoint_path) | |
checkpoints_sorted = [checkpoint[1] for checkpoint in checkpoints_sorted] | |
# Make sure we don't delete the best model. | |
if self.state.best_model_checkpoint is not None: | |
best_model_index = checkpoints_sorted.index(str(Path(self.state.best_model_checkpoint))) | |
for i in range(best_model_index, len(checkpoints_sorted) - 2): | |
checkpoints_sorted[i], checkpoints_sorted[i + 1] = checkpoints_sorted[i + 1], checkpoints_sorted[i] | |
return checkpoints_sorted | |
def _rotate_checkpoints(self, use_mtime=False, output_dir=None) -> None: | |
if self.args.save_total_limit is None or self.args.save_total_limit <= 0: | |
return | |
# Check if we should delete older checkpoint(s) | |
checkpoints_sorted = self._sorted_checkpoints(use_mtime=use_mtime, output_dir=output_dir) | |
if len(checkpoints_sorted) <= self.args.save_total_limit: | |
return | |
# If save_total_limit=1 with load_best_model_at_end=True, we could end up deleting the last checkpoint, which
# we avoid so that resuming from it stays possible.
save_total_limit = self.args.save_total_limit | |
if ( | |
self.state.best_model_checkpoint is not None | |
and self.args.save_total_limit == 1 | |
and checkpoints_sorted[-1] != self.state.best_model_checkpoint | |
): | |
save_total_limit = 2 | |
number_of_checkpoints_to_delete = max(0, len(checkpoints_sorted) - save_total_limit) | |
checkpoints_to_be_deleted = checkpoints_sorted[:number_of_checkpoints_to_delete] | |
for checkpoint in checkpoints_to_be_deleted: | |
logger.info(f"Deleting older checkpoint [{checkpoint}] due to args.save_total_limit") | |
shutil.rmtree(checkpoint, ignore_errors=True) | |
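# Illustrative example (hypothetical checkpoint names): with `save_total_limit=2` and
# checkpoint-500, checkpoint-1000, checkpoint-1500 on disk, the oldest one (checkpoint-500)
# is deleted; if it happens to be `state.best_model_checkpoint`, `_sorted_checkpoints` has
# already shifted it out of the deleted prefix so the best model is preserved.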
def evaluate( | |
self, | |
eval_dataset: Optional[Dataset] = None, | |
ignore_keys: Optional[List[str]] = None, | |
metric_key_prefix: str = "eval", | |
) -> Dict[str, float]: | |
""" | |
Run evaluation and return the resulting metrics.
The calling script will be responsible for providing a method to compute metrics, as they are task-dependent | |
(pass it to the init `compute_metrics` argument). | |
You can also subclass and override this method to inject custom behavior. | |
Args: | |
eval_dataset (`Dataset`, *optional*): | |
Pass a dataset if you wish to override `self.eval_dataset`. If it is a [`~datasets.Dataset`], columns | |
not accepted by the `model.forward()` method are automatically removed. It must implement the `__len__` | |
method. | |
ignore_keys (`List[str]`, *optional*): | |
A list of keys in the output of your model (if it is a dictionary) that should be ignored when | |
gathering predictions. | |
metric_key_prefix (`str`, *optional*, defaults to `"eval"`): | |
An optional prefix to be used as the metrics key prefix. For example the metrics "bleu" will be named | |
"eval_bleu" if the prefix is "eval" (default) | |
Returns: | |
A dictionary containing the evaluation loss and the potential metrics computed from the predictions. The | |
dictionary also contains the epoch number which comes from the training state. | |
""" | |
# memory metrics - must set up as early as possible | |
self._memory_tracker.start() | |
eval_dataloader = self.get_eval_dataloader(eval_dataset) | |
start_time = time.time() | |
eval_loop = self.prediction_loop if self.args.use_legacy_prediction_loop else self.evaluation_loop | |
output = eval_loop( | |
eval_dataloader, | |
description="Evaluation", | |
# No point gathering the predictions if there are no metrics, otherwise we defer to | |
# self.args.prediction_loss_only | |
prediction_loss_only=True if self.compute_metrics is None else None, | |
ignore_keys=ignore_keys, | |
metric_key_prefix=metric_key_prefix, | |
) | |
total_batch_size = self.args.eval_batch_size * self.args.world_size | |
if f"{metric_key_prefix}_jit_compilation_time" in output.metrics: | |
start_time += output.metrics[f"{metric_key_prefix}_jit_compilation_time"] | |
output.metrics.update( | |
speed_metrics( | |
metric_key_prefix, | |
start_time, | |
num_samples=output.num_samples, | |
num_steps=math.ceil(output.num_samples / total_batch_size), | |
) | |
) | |
self.log(output.metrics) | |
if DebugOption.TPU_METRICS_DEBUG in self.args.debug: | |
# tpu-comment: Logging debug metrics for PyTorch/XLA (compile, execute times, ops, etc.) | |
xm.master_print(met.metrics_report()) | |
self.control = self.callback_handler.on_evaluate(self.args, self.state, self.control, output.metrics) | |
self._memory_tracker.stop_and_update_metrics(output.metrics) | |
return output.metrics | |
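# Illustrative usage sketch (hypothetical datasets `val_ds`/`test_ds`): metric keys are prefixed
# with `metric_key_prefix`, which keeps several evaluation sets apart, e.g.
#
#     val_metrics = trainer.evaluate(eval_dataset=val_ds)                               # keys like "eval_loss"
#     test_metrics = trainer.evaluate(eval_dataset=test_ds, metric_key_prefix="test")   # keys like "test_loss"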
def predict( | |
self, test_dataset: Dataset, ignore_keys: Optional[List[str]] = None, metric_key_prefix: str = "test" | |
) -> PredictionOutput: | |
""" | |
Run prediction and return predictions and potential metrics.
Depending on the dataset and your use case, your test dataset may contain labels. In that case, this method | |
will also return metrics, like in `evaluate()`. | |
Args: | |
test_dataset (`Dataset`): | |
Dataset to run the predictions on. If it is a `datasets.Dataset`, columns not accepted by the
`model.forward()` method are automatically removed. It has to implement the method `__len__`.
ignore_keys (`List[str]`, *optional*): | |
A list of keys in the output of your model (if it is a dictionary) that should be ignored when | |
gathering predictions. | |
metric_key_prefix (`str`, *optional*, defaults to `"test"`): | |
An optional prefix to be used as the metrics key prefix. For example the metrics "bleu" will be named | |
"test_bleu" if the prefix is "test" (default) | |
<Tip> | |
If your predictions or labels have different sequence lengths (for instance because you're doing dynamic padding
in a token classification task), the predictions will be padded (on the right) to allow for concatenation into
one array. The padding index is -100. | |
</Tip> | |
Returns: *NamedTuple* A namedtuple with the following keys: | |
- predictions (`np.ndarray`): The predictions on `test_dataset`. | |
- label_ids (`np.ndarray`, *optional*): The labels (if the dataset contained some). | |
- metrics (`Dict[str, float]`, *optional*): The potential dictionary of metrics (if the dataset contained | |
labels). | |
""" | |
# memory metrics - must set up as early as possible | |
self._memory_tracker.start() | |
test_dataloader = self.get_test_dataloader(test_dataset) | |
start_time = time.time() | |
eval_loop = self.prediction_loop if self.args.use_legacy_prediction_loop else self.evaluation_loop | |
output = eval_loop( | |
test_dataloader, description="Prediction", ignore_keys=ignore_keys, metric_key_prefix=metric_key_prefix | |
) | |
total_batch_size = self.args.eval_batch_size * self.args.world_size | |
if f"{metric_key_prefix}_jit_compilation_time" in output.metrics: | |
start_time += output.metrics[f"{metric_key_prefix}_jit_compilation_time"] | |
output.metrics.update( | |
speed_metrics( | |
metric_key_prefix, | |
start_time, | |
num_samples=output.num_samples, | |
num_steps=math.ceil(output.num_samples / total_batch_size), | |
) | |
) | |
self.control = self.callback_handler.on_predict(self.args, self.state, self.control, output.metrics) | |
self._memory_tracker.stop_and_update_metrics(output.metrics) | |
return PredictionOutput(predictions=output.predictions, label_ids=output.label_ids, metrics=output.metrics) | |
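# Illustrative usage sketch (hypothetical dataset `test_ds`): `predict` returns a `PredictionOutput`
# whose `predictions` are numpy arrays, e.g. for a classification head:
#
#     output = trainer.predict(test_ds)
#     preds = output.predictions.argmax(axis=-1)
#     # output.metrics only carries task metrics when the dataset has labels and
#     # `compute_metrics` was passed to the Trainer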
def evaluation_loop( | |
self, | |
dataloader: DataLoader, | |
description: str, | |
prediction_loss_only: Optional[bool] = None, | |
ignore_keys: Optional[List[str]] = None, | |
metric_key_prefix: str = "eval", | |
) -> EvalLoopOutput: | |
""" | |
Prediction/evaluation loop, shared by `Trainer.evaluate()` and `Trainer.predict()`. | |
Works both with or without labels. | |
""" | |
args = self.args | |
prediction_loss_only = prediction_loss_only if prediction_loss_only is not None else args.prediction_loss_only | |
# if eval is called w/o train, handle model prep here | |
if self.is_deepspeed_enabled and self.model_wrapped is self.model: | |
_, _ = deepspeed_init(self, num_training_steps=0, inference=True) | |
model = self._wrap_model(self.model, training=False, dataloader=dataloader) | |
if len(self.accelerator._models) == 0 and model is self.model: | |
model = ( | |
self.accelerator.prepare(model) | |
if self.is_deepspeed_enabled | |
else self.accelerator.prepare_model(model, evaluation_mode=True) | |
) | |
if self.is_fsdp_enabled: | |
self.model = model | |
# for the rest of this function `model` is the outside model, whether it was wrapped or not | |
if model is not self.model: | |
self.model_wrapped = model | |
# backward compatibility | |
if self.is_deepspeed_enabled: | |
self.deepspeed = self.model_wrapped | |
# if full fp16 or bf16 eval is wanted and this ``evaluation`` or ``predict`` isn't called | |
# while ``train`` is running, cast it to the right dtype first and then put on device | |
if not self.is_in_train: | |
if args.fp16_full_eval: | |
model = model.to(dtype=torch.float16, device=args.device) | |
elif args.bf16_full_eval: | |
model = model.to(dtype=torch.bfloat16, device=args.device) | |
batch_size = self.args.eval_batch_size | |
logger.info(f"***** Running {description} *****") | |
if has_length(dataloader): | |
logger.info(f" Num examples = {self.num_examples(dataloader)}") | |
else: | |
logger.info(" Num examples: Unknown") | |
logger.info(f" Batch size = {batch_size}") | |
model.eval() | |
self.callback_handler.eval_dataloader = dataloader | |
# Do this before wrapping. | |
eval_dataset = getattr(dataloader, "dataset", None) | |
if is_torch_tpu_available(): | |
dataloader = pl.ParallelLoader(dataloader, [args.device]).per_device_loader(args.device) | |
if args.past_index >= 0: | |
self._past = None | |
# Initialize containers | |
# losses/preds/labels on GPU/TPU (accumulated for eval_accumulation_steps) | |
losses_host = None | |
preds_host = None | |
labels_host = None | |
inputs_host = None | |
# losses/preds/labels on CPU (final containers) | |
all_losses = None | |
all_preds = None | |
all_labels = None | |
all_inputs = None | |
# Will be useful when we have an iterable dataset, so we don't know its length in advance.
observed_num_examples = 0 | |
# Main evaluation loop | |
for step, inputs in enumerate(dataloader): | |
# Update the observed num examples | |
observed_batch_size = find_batch_size(inputs) | |
if observed_batch_size is not None: | |
observed_num_examples += observed_batch_size | |
# For batch samplers, batch_size is not known by the dataloader in advance. | |
if batch_size is None: | |
batch_size = observed_batch_size | |
# Prediction step | |
loss, logits, labels = self.prediction_step(model, inputs, prediction_loss_only, ignore_keys=ignore_keys) | |
inputs_decode = self._prepare_input(inputs["input_ids"]) if args.include_inputs_for_metrics else None | |
if is_torch_tpu_available(): | |
xm.mark_step() | |
# Update containers on host | |
if loss is not None: | |
losses = self._nested_gather(loss.repeat(batch_size)) | |
losses_host = losses if losses_host is None else torch.cat((losses_host, losses), dim=0) | |
if labels is not None: | |
labels = self._pad_across_processes(labels) | |
if inputs_decode is not None: | |
inputs_decode = self._pad_across_processes(inputs_decode) | |
inputs_decode = self._nested_gather(inputs_decode) | |
inputs_host = ( | |
inputs_decode | |
if inputs_host is None | |
else nested_concat(inputs_host, inputs_decode, padding_index=-100) | |
) | |
if logits is not None: | |
logits = self._pad_across_processes(logits) | |
if self.preprocess_logits_for_metrics is not None: | |
logits = self.preprocess_logits_for_metrics(logits, labels) | |
logits = self._nested_gather(logits) | |
preds_host = logits if preds_host is None else nested_concat(preds_host, logits, padding_index=-100) | |
if labels is not None: | |
labels = self._nested_gather(labels) | |
labels_host = labels if labels_host is None else nested_concat(labels_host, labels, padding_index=-100) | |
self.control = self.callback_handler.on_prediction_step(args, self.state, self.control) | |
# Gather all tensors and put them back on the CPU if we have done enough accumulation steps. | |
if args.eval_accumulation_steps is not None and (step + 1) % args.eval_accumulation_steps == 0: | |
if losses_host is not None: | |
losses = nested_numpify(losses_host) | |
all_losses = losses if all_losses is None else np.concatenate((all_losses, losses), axis=0) | |
if preds_host is not None: | |
logits = nested_numpify(preds_host) | |
all_preds = logits if all_preds is None else nested_concat(all_preds, logits, padding_index=-100) | |
if inputs_host is not None: | |
inputs_decode = nested_numpify(inputs_host) | |
all_inputs = ( | |
inputs_decode | |
if all_inputs is None | |
else nested_concat(all_inputs, inputs_decode, padding_index=-100) | |
) | |
if labels_host is not None: | |
labels = nested_numpify(labels_host) | |
all_labels = ( | |
labels if all_labels is None else nested_concat(all_labels, labels, padding_index=-100) | |
) | |
# Set back to None to begin a new accumulation | |
losses_host, preds_host, inputs_host, labels_host = None, None, None, None | |
if args.past_index and hasattr(self, "_past"): | |
# Clean the state at the end of the evaluation loop | |
delattr(self, "_past") | |
# Gather all remaining tensors and put them back on the CPU | |
if losses_host is not None: | |
losses = nested_numpify(losses_host) | |
all_losses = losses if all_losses is None else np.concatenate((all_losses, losses), axis=0) | |
if preds_host is not None: | |
logits = nested_numpify(preds_host) | |
all_preds = logits if all_preds is None else nested_concat(all_preds, logits, padding_index=-100) | |
if inputs_host is not None: | |
inputs_decode = nested_numpify(inputs_host) | |
all_inputs = ( | |
inputs_decode if all_inputs is None else nested_concat(all_inputs, inputs_decode, padding_index=-100) | |
) | |
if labels_host is not None: | |
labels = nested_numpify(labels_host) | |
all_labels = labels if all_labels is None else nested_concat(all_labels, labels, padding_index=-100) | |
# Number of samples | |
if has_length(eval_dataset): | |
num_samples = len(eval_dataset) | |
# The instance check is weird and does not actually check for the type, but whether the dataset has the right | |
# methods. Therefore we need to make sure it also has the attribute. | |
elif isinstance(eval_dataset, IterableDatasetShard) and getattr(eval_dataset, "num_examples", 0) > 0: | |
num_samples = eval_dataset.num_examples | |
else: | |
if has_length(dataloader): | |
num_samples = self.num_examples(dataloader) | |
else: # both len(dataloader.dataset) and len(dataloader) fail | |
num_samples = observed_num_examples | |
if num_samples == 0 and observed_num_examples > 0: | |
num_samples = observed_num_examples | |
# Number of losses has been rounded to a multiple of batch_size and in a distributed training, the number of
# samples has been rounded to a multiple of batch_size, so we truncate.
if all_losses is not None: | |
all_losses = all_losses[:num_samples] | |
if all_preds is not None: | |
all_preds = nested_truncate(all_preds, num_samples) | |
if all_labels is not None: | |
all_labels = nested_truncate(all_labels, num_samples) | |
if all_inputs is not None: | |
all_inputs = nested_truncate(all_inputs, num_samples) | |
# Metrics! | |
if self.compute_metrics is not None and all_preds is not None and all_labels is not None: | |
if args.include_inputs_for_metrics: | |
metrics = self.compute_metrics( | |
EvalPrediction(predictions=all_preds, label_ids=all_labels, inputs=all_inputs) | |
) | |
else: | |
metrics = self.compute_metrics(EvalPrediction(predictions=all_preds, label_ids=all_labels)) | |
else: | |
metrics = {} | |
# To be JSON-serializable, we need to remove numpy types or zero-d tensors | |
metrics = denumpify_detensorize(metrics) | |
if all_losses is not None: | |
metrics[f"{metric_key_prefix}_loss"] = all_losses.mean().item() | |
if hasattr(self, "jit_compilation_time"): | |
metrics[f"{metric_key_prefix}_jit_compilation_time"] = self.jit_compilation_time | |
# Prefix all keys with metric_key_prefix + '_' | |
for key in list(metrics.keys()): | |
if not key.startswith(f"{metric_key_prefix}_"): | |
metrics[f"{metric_key_prefix}_{key}"] = metrics.pop(key) | |
return EvalLoopOutput(predictions=all_preds, label_ids=all_labels, metrics=metrics, num_samples=num_samples) | |
def _nested_gather(self, tensors, name=None): | |
""" | |
Gather the value of `tensors` (tensor or list/tuple of nested tensors) from all processes so they can later be
concatenated.
""" | |
if tensors is None: | |
return | |
if is_torch_tpu_available(): | |
if name is None: | |
name = "nested_gather" | |
tensors = nested_xla_mesh_reduce(tensors, name) | |
elif is_sagemaker_mp_enabled(): | |
tensors = smp_gather(tensors) | |
elif (self.args.distributed_state is not None and self.args.distributed_state.distributed_type != "NO") or ( | |
self.args.distributed_state is None and self.local_rank != -1 | |
): | |
tensors = distributed_concat(tensors) | |
return tensors | |
# Copied from Accelerate. | |
def _pad_across_processes(self, tensor, pad_index=-100): | |
""" | |
Recursively pad the tensors in a nested list/tuple/dictionary of tensors from all devices to the same size so | |
they can safely be gathered. | |
""" | |
if isinstance(tensor, (list, tuple)): | |
return type(tensor)(self._pad_across_processes(t, pad_index=pad_index) for t in tensor) | |
elif isinstance(tensor, dict): | |
return type(tensor)({k: self._pad_across_processes(v, pad_index=pad_index) for k, v in tensor.items()}) | |
elif not isinstance(tensor, torch.Tensor): | |
raise TypeError( | |
f"Can't pad the values of type {type(tensor)}, only of nested list/tuple/dicts of tensors." | |
) | |
if len(tensor.shape) < 2: | |
return tensor | |
# Gather all sizes | |
size = torch.tensor(tensor.shape, device=tensor.device)[None] | |
sizes = self._nested_gather(size).cpu() | |
max_size = max(s[1] for s in sizes) | |
# When extracting XLA graphs for compilation, max_size is 0, | |
# so use inequality to avoid errors. | |
if tensor.shape[1] >= max_size: | |
return tensor | |
# Then pad to the maximum size | |
old_size = tensor.shape | |
new_size = list(old_size) | |
new_size[1] = max_size | |
new_tensor = tensor.new_zeros(tuple(new_size)) + pad_index | |
new_tensor[:, : old_size[1]] = tensor | |
return new_tensor | |
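# Illustrative example (hypothetical shapes): if process 0 holds a tensor of shape (8, 12, V)
# and process 1 holds (8, 15, V), both are padded along dim 1 to (8, 15, V) with `pad_index`
# so the subsequent gather can concatenate them safely.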
def prediction_step( | |
self, | |
model: nn.Module, | |
inputs: Dict[str, Union[torch.Tensor, Any]], | |
prediction_loss_only: bool, | |
ignore_keys: Optional[List[str]] = None, | |
) -> Tuple[Optional[torch.Tensor], Optional[torch.Tensor], Optional[torch.Tensor]]: | |
""" | |
Perform an evaluation step on `model` using `inputs`. | |
Subclass and override to inject custom behavior. | |
Args: | |
model (`nn.Module`): | |
The model to evaluate. | |
inputs (`Dict[str, Union[torch.Tensor, Any]]`): | |
The inputs and targets of the model. | |
The dictionary will be unpacked before being fed to the model. Most models expect the targets under the | |
argument `labels`. Check your model's documentation for all accepted arguments. | |
prediction_loss_only (`bool`): | |
Whether or not to return the loss only. | |
ignore_keys (`List[str]`, *optional*): | |
A list of keys in the output of your model (if it is a dictionary) that should be ignored when | |
gathering predictions. | |
Return: | |
Tuple[Optional[torch.Tensor], Optional[torch.Tensor], Optional[torch.Tensor]]: A tuple with the loss, | |
logits and labels (each being optional). | |
""" | |
has_labels = False if len(self.label_names) == 0 else all(inputs.get(k) is not None for k in self.label_names) | |
# For CLIP-like models capable of returning loss values. | |
# If `return_loss` is not specified or is `None` in `inputs`, we check if the default value of `return_loss`
# is `True` in `model.forward`. | |
return_loss = inputs.get("return_loss", None) | |
if return_loss is None: | |
return_loss = self.can_return_loss | |
loss_without_labels = True if len(self.label_names) == 0 and return_loss else False | |
inputs = self._prepare_inputs(inputs) | |
if ignore_keys is None: | |
if hasattr(self.model, "config"): | |
ignore_keys = getattr(self.model.config, "keys_to_ignore_at_inference", []) | |
else: | |
ignore_keys = [] | |
# labels may be popped when computing the loss (label smoothing for instance) so we grab them first. | |
if has_labels or loss_without_labels: | |
labels = nested_detach(tuple(inputs.get(name) for name in self.label_names)) | |
if len(labels) == 1: | |
labels = labels[0] | |
else: | |
labels = None | |
with torch.no_grad(): | |
if is_sagemaker_mp_enabled(): | |
raw_outputs = smp_forward_only(model, inputs) | |
if has_labels or loss_without_labels: | |
if isinstance(raw_outputs, dict): | |
loss_mb = raw_outputs["loss"] | |
logits_mb = tuple(v for k, v in raw_outputs.items() if k not in ignore_keys + ["loss"]) | |
else: | |
loss_mb = raw_outputs[0] | |
logits_mb = raw_outputs[1:] | |
loss = loss_mb.reduce_mean().detach().cpu() | |
logits = smp_nested_concat(logits_mb) | |
else: | |
loss = None | |
if isinstance(raw_outputs, dict): | |
logits_mb = tuple(v for k, v in raw_outputs.items() if k not in ignore_keys) | |
else: | |
logits_mb = raw_outputs | |
logits = smp_nested_concat(logits_mb) | |
else: | |
if has_labels or loss_without_labels: | |
with self.compute_loss_context_manager(): | |
loss, outputs = self.compute_loss(model, inputs, return_outputs=True) | |
loss = loss.mean().detach() | |
if isinstance(outputs, dict): | |
logits = tuple(v for k, v in outputs.items() if k not in ignore_keys + ["loss"]) | |
else: | |
logits = outputs[1:] | |
else: | |
loss = None | |
with self.compute_loss_context_manager(): | |
outputs = model(**inputs) | |
if isinstance(outputs, dict): | |
logits = tuple(v for k, v in outputs.items() if k not in ignore_keys) | |
else: | |
logits = outputs | |
# TODO: this needs to be fixed and made cleaner later. | |
if self.args.past_index >= 0: | |
self._past = outputs[self.args.past_index - 1] | |
if prediction_loss_only: | |
return (loss, None, None) | |
logits = nested_detach(logits) | |
if len(logits) == 1: | |
logits = logits[0] | |
return (loss, logits, labels) | |
def floating_point_ops(self, inputs: Dict[str, Union[torch.Tensor, Any]]): | |
""" | |
For models that inherit from [`PreTrainedModel`], uses that method to compute the number of floating point | |
operations for every backward + forward pass. If using another model, either implement such a method in the | |
model or subclass and override this method. | |
Args: | |
inputs (`Dict[str, Union[torch.Tensor, Any]]`): | |
The inputs and targets of the model. | |
Returns: | |
`int`: The number of floating-point operations. | |
""" | |
if hasattr(self.model, "floating_point_ops"): | |
return self.model.floating_point_ops(inputs) | |
else: | |
return 0 | |
def init_git_repo(self, at_init: bool = False): | |
""" | |
Initializes a git clone of `self.args.hub_model_id` in `self.args.output_dir`.
Args: | |
at_init (`bool`, *optional*, defaults to `False`): | |
Whether this function is called before any training or not. If `self.args.overwrite_output_dir` is | |
`True` and `at_init` is `True`, the path to the repo (which is `self.args.output_dir`) might be wiped | |
out. | |
""" | |
if not self.is_world_process_zero(): | |
return | |
if self.args.hub_model_id is None: | |
repo_name = Path(self.args.output_dir).absolute().name | |
else: | |
repo_name = self.args.hub_model_id | |
if "/" not in repo_name: | |
repo_name = get_full_repo_name(repo_name, token=self.args.hub_token) | |
# Make sure the repo exists. | |
create_repo(repo_name, token=self.args.hub_token, private=self.args.hub_private_repo, exist_ok=True) | |
try: | |
self.repo = Repository(self.args.output_dir, clone_from=repo_name, token=self.args.hub_token) | |
except EnvironmentError: | |
if self.args.overwrite_output_dir and at_init: | |
# Try again after wiping output_dir | |
shutil.rmtree(self.args.output_dir) | |
self.repo = Repository(self.args.output_dir, clone_from=repo_name, token=self.args.hub_token) | |
else: | |
raise | |
self.repo.git_pull() | |
# By default, ignore the checkpoint folders | |
if ( | |
not os.path.exists(os.path.join(self.args.output_dir, ".gitignore")) | |
and self.args.hub_strategy != HubStrategy.ALL_CHECKPOINTS | |
): | |
with open(os.path.join(self.args.output_dir, ".gitignore"), "w", encoding="utf-8") as writer: | |
writer.writelines(["checkpoint-*/"]) | |
# Add "*.sagemaker" to .gitignore if using SageMaker | |
if os.environ.get("SM_TRAINING_ENV"): | |
self._add_sm_patterns_to_gitignore() | |
self.push_in_progress = None | |
def create_model_card( | |
self, | |
language: Optional[str] = None, | |
license: Optional[str] = None, | |
tags: Union[str, List[str], None] = None, | |
model_name: Optional[str] = None, | |
finetuned_from: Optional[str] = None, | |
tasks: Union[str, List[str], None] = None, | |
dataset_tags: Union[str, List[str], None] = None, | |
dataset: Union[str, List[str], None] = None, | |
dataset_args: Union[str, List[str], None] = None, | |
): | |
""" | |
Creates a draft of a model card using the information available to the `Trainer`. | |
Args: | |
language (`str`, *optional*): | |
The language of the model (if applicable) | |
license (`str`, *optional*): | |
The license of the model. Will default to the license of the pretrained model used, if the original | |
model given to the `Trainer` comes from a repo on the Hub. | |
tags (`str` or `List[str]`, *optional*): | |
Some tags to be included in the metadata of the model card. | |
model_name (`str`, *optional*): | |
The name of the model. | |
finetuned_from (`str`, *optional*): | |
The name of the model used to fine-tune this one (if applicable). Will default to the name of the repo | |
of the original model given to the `Trainer` (if it comes from the Hub). | |
tasks (`str` or `List[str]`, *optional*): | |
One or several task identifiers, to be included in the metadata of the model card. | |
dataset_tags (`str` or `List[str]`, *optional*): | |
One or several dataset tags, to be included in the metadata of the model card. | |
dataset (`str` or `List[str]`, *optional*): | |
One or several dataset identifiers, to be included in the metadata of the model card. | |
dataset_args (`str` or `List[str]`, *optional*): | |
One or several dataset arguments, to be included in the metadata of the model card. | |
""" | |
if not self.is_world_process_zero(): | |
return | |
training_summary = TrainingSummary.from_trainer( | |
self, | |
language=language, | |
license=license, | |
tags=tags, | |
model_name=model_name, | |
finetuned_from=finetuned_from, | |
tasks=tasks, | |
dataset_tags=dataset_tags, | |
dataset=dataset, | |
dataset_args=dataset_args, | |
) | |
model_card = training_summary.to_model_card() | |
with open(os.path.join(self.args.output_dir, "README.md"), "w") as f: | |
f.write(model_card) | |
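# Illustrative usage sketch (hypothetical names): called after training on the main process,
# this writes a README.md draft into `args.output_dir`, e.g.
#
#     trainer.create_model_card(
#         model_name="my-finetuned-model",
#         finetuned_from="bert-base-uncased",
#         tasks="text-classification",
#     )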
def _push_from_checkpoint(self, checkpoint_folder): | |
# Only push from one node. | |
if not self.is_world_process_zero() or self.args.hub_strategy == HubStrategy.END: | |
return | |
# If we haven't finished the last push, we don't do this one. | |
if self.push_in_progress is not None and not self.push_in_progress.is_done: | |
return | |
output_dir = self.args.output_dir | |
# To avoid a new synchronization of all model weights, we just copy the file from the checkpoint folder | |
modeling_files = [CONFIG_NAME, WEIGHTS_NAME, SAFE_WEIGHTS_NAME] | |
for modeling_file in modeling_files: | |
if os.path.isfile(os.path.join(checkpoint_folder, modeling_file)): | |
shutil.copy(os.path.join(checkpoint_folder, modeling_file), os.path.join(output_dir, modeling_file)) | |
# Saving the tokenizer is fast and we don't know how many files it may have spawned, so we resave it to be sure. | |
if self.tokenizer is not None: | |
self.tokenizer.save_pretrained(output_dir) | |
# Same for the training arguments | |
torch.save(self.args, os.path.join(output_dir, TRAINING_ARGS_NAME)) | |
try: | |
if self.args.hub_strategy == HubStrategy.CHECKPOINT: | |
# Temporarily move the checkpoint just saved for the push | |
tmp_checkpoint = os.path.join(output_dir, "last-checkpoint") | |
# We have to remove the "last-checkpoint" dir if it exists, otherwise the checkpoint is moved as a | |
# subfolder. | |
if os.path.isdir(tmp_checkpoint): | |
shutil.rmtree(tmp_checkpoint) | |
shutil.move(checkpoint_folder, tmp_checkpoint) | |
if self.args.save_strategy == IntervalStrategy.STEPS: | |
commit_message = f"Training in progress, step {self.state.global_step}" | |
else: | |
commit_message = f"Training in progress, epoch {int(self.state.epoch)}" | |
push_work = self.repo.push_to_hub(commit_message=commit_message, blocking=False, auto_lfs_prune=True) | |
# Return type of `Repository.push_to_hub` is either None or a tuple. | |
if push_work is not None: | |
self.push_in_progress = push_work[1] | |
except Exception as e: | |
logger.error(f"Error when pushing to hub: {e}") | |
finally: | |
if self.args.hub_strategy == HubStrategy.CHECKPOINT: | |
# Move back the checkpoint to its place | |
shutil.move(tmp_checkpoint, checkpoint_folder) | |
def push_to_hub(self, commit_message: Optional[str] = "End of training", blocking: bool = True, **kwargs) -> str: | |
""" | |
Upload *self.model* and *self.tokenizer* to the 🤗 model hub on the repo *self.args.hub_model_id*. | |
Parameters: | |
commit_message (`str`, *optional*, defaults to `"End of training"`): | |
Message to commit while pushing. | |
blocking (`bool`, *optional*, defaults to `True`): | |
Whether the function should return only when the `git push` has finished. | |
kwargs: | |
Additional keyword arguments passed along to [`~Trainer.create_model_card`]. | |
Returns: | |
The URL of the commit of your model in the given repository if `blocking=True`, or a tuple with the URL of
the commit and an object to track the progress of the push if `blocking=False`.
""" | |
# If a user calls manually `push_to_hub` with `self.args.push_to_hub = False`, we try to create the repo but | |
# it might fail. | |
if not hasattr(self, "repo"): | |
self.init_git_repo() | |
model_name = kwargs.pop("model_name", None) | |
if model_name is None and self.args.should_save: | |
if self.args.hub_model_id is None: | |
model_name = Path(self.args.output_dir).name | |
else: | |
model_name = self.args.hub_model_id.split("/")[-1] | |
# Needs to be executed on all processes for TPU training, but will only save on the process determined by
# self.args.should_save. | |
self.save_model(_internal_call=True) | |
# Only push from one node. | |
if not self.is_world_process_zero(): | |
return | |
# Cancel any async push in progress if blocking=True. The commits will all be pushed together. | |
if blocking and self.push_in_progress is not None and not self.push_in_progress.is_done: | |
self.push_in_progress._process.kill() | |
self.push_in_progress = None | |
git_head_commit_url = self.repo.push_to_hub( | |
commit_message=commit_message, blocking=blocking, auto_lfs_prune=True | |
) | |
# Push the model card separately so it is independent from the rest of the model
if self.args.should_save: | |
self.create_model_card(model_name=model_name, **kwargs) | |
try: | |
self.repo.push_to_hub( | |
commit_message="update model card README.md", blocking=blocking, auto_lfs_prune=True | |
) | |
except EnvironmentError as exc: | |
logger.error(f"Error pushing update to the model card. Please read logs and retry.\n${exc}") | |
return git_head_commit_url | |
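# Illustrative usage sketch (hypothetical tags): with the Hub settings configured on the
# TrainingArguments (hub_model_id, token), a single call uploads the final weights, tokenizer
# and model card; extra kwargs are forwarded to `create_model_card`, e.g.
#
#     url = trainer.push_to_hub(commit_message="End of training", tags=["text-classification"])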
# | |
# Deprecated code | |
# | |
def prediction_loop( | |
self, | |
dataloader: DataLoader, | |
description: str, | |
prediction_loss_only: Optional[bool] = None, | |
ignore_keys: Optional[List[str]] = None, | |
metric_key_prefix: str = "eval", | |
) -> EvalLoopOutput: | |
""" | |
Prediction/evaluation loop, shared by `Trainer.evaluate()` and `Trainer.predict()`. | |
Works both with or without labels. | |
""" | |
args = self.args | |
if not has_length(dataloader): | |
raise ValueError("dataloader must implement a working __len__") | |
prediction_loss_only = prediction_loss_only if prediction_loss_only is not None else args.prediction_loss_only | |
# if eval is called w/o train, handle model prep here | |
if self.is_deepspeed_enabled and self.model_wrapped is self.model: | |
_, _ = deepspeed_init(self, num_training_steps=0, inference=True) | |
model = self._wrap_model(self.model, training=False, dataloader=dataloader) | |
if len(self.accelerator._models) == 0 and model is self.model: | |
model = ( | |
self.accelerator.prepare(model) | |
if self.is_deepspeed_enabled | |
else self.accelerator.prepare_model(model, evaluation_mode=True) | |
) | |
if self.is_fsdp_enabled: | |
self.model = model | |
# for the rest of this function `model` is the outside model, whether it was wrapped or not | |
if model is not self.model: | |
self.model_wrapped = model | |
# backward compatibility | |
if self.is_deepspeed_enabled: | |
self.deepspeed = self.model_wrapped | |
# if full fp16 or bf16 eval is wanted and this ``evaluation`` or ``predict`` isn't called | |
# while ``train`` is running, cast it to the right dtype first and then put on device | |
if not self.is_in_train: | |
if args.fp16_full_eval: | |
model = model.to(dtype=torch.float16, device=args.device) | |
elif args.bf16_full_eval: | |
model = model.to(dtype=torch.bfloat16, device=args.device) | |
batch_size = dataloader.batch_size | |
num_examples = self.num_examples(dataloader) | |
logger.info(f"***** Running {description} *****") | |
logger.info(f" Num examples = {num_examples}") | |
logger.info(f" Batch size = {batch_size}") | |
losses_host: torch.Tensor = None | |
preds_host: Union[torch.Tensor, List[torch.Tensor]] = None | |
labels_host: Union[torch.Tensor, List[torch.Tensor]] = None | |
inputs_host: Union[torch.Tensor, List[torch.Tensor]] = None | |
world_size = max(1, args.world_size) | |
eval_losses_gatherer = DistributedTensorGatherer(world_size, num_examples, make_multiple_of=batch_size) | |
if not prediction_loss_only: | |
# The actual number of eval_sample can be greater than num_examples in distributed settings (when we pass | |
# a batch size to the sampler) | |
make_multiple_of = None | |
if hasattr(dataloader, "sampler") and isinstance(dataloader.sampler, SequentialDistributedSampler): | |
make_multiple_of = dataloader.sampler.batch_size | |
preds_gatherer = DistributedTensorGatherer(world_size, num_examples, make_multiple_of=make_multiple_of) | |
labels_gatherer = DistributedTensorGatherer(world_size, num_examples, make_multiple_of=make_multiple_of) | |
inputs_gatherer = DistributedTensorGatherer(world_size, num_examples, make_multiple_of=make_multiple_of) | |
model.eval() | |
if is_torch_tpu_available(): | |
dataloader = pl.ParallelLoader(dataloader, [args.device]).per_device_loader(args.device) | |
if args.past_index >= 0: | |
self._past = None | |
self.callback_handler.eval_dataloader = dataloader | |
for step, inputs in enumerate(dataloader): | |
loss, logits, labels = self.prediction_step(model, inputs, prediction_loss_only, ignore_keys=ignore_keys) | |
inputs_decode = self._prepare_input(inputs["input_ids"]) if args.include_inputs_for_metrics else None | |
if loss is not None: | |
losses = loss.repeat(batch_size) | |
losses_host = losses if losses_host is None else torch.cat((losses_host, losses), dim=0) | |
if logits is not None: | |
preds_host = logits if preds_host is None else nested_concat(preds_host, logits, padding_index=-100) | |
if labels is not None: | |
labels_host = labels if labels_host is None else nested_concat(labels_host, labels, padding_index=-100) | |
if inputs_decode is not None: | |
inputs_host = ( | |
inputs_decode | |
if inputs_host is None | |
else nested_concat(inputs_host, inputs_decode, padding_index=-100) | |
) | |
self.control = self.callback_handler.on_prediction_step(args, self.state, self.control) | |
# Gather all tensors and put them back on the CPU if we have done enough accumulation steps. | |
if args.eval_accumulation_steps is not None and (step + 1) % args.eval_accumulation_steps == 0: | |
eval_losses_gatherer.add_arrays(self._gather_and_numpify(losses_host, "eval_losses")) | |
if not prediction_loss_only: | |
preds_gatherer.add_arrays(self._gather_and_numpify(preds_host, "eval_preds")) | |
labels_gatherer.add_arrays(self._gather_and_numpify(labels_host, "eval_label_ids")) | |
inputs_gatherer.add_arrays(self._gather_and_numpify(inputs_host, "eval_inputs_ids")) | |
# Set back to None to begin a new accumulation | |
losses_host, preds_host, labels_host, inputs_host = None, None, None, None | |
if args.past_index and hasattr(self, "_past"): | |
# Clean the state at the end of the evaluation loop | |
delattr(self, "_past") | |
# Gather all remaining tensors and put them back on the CPU | |
eval_losses_gatherer.add_arrays(self._gather_and_numpify(losses_host, "eval_losses")) | |
if not prediction_loss_only: | |
preds_gatherer.add_arrays(self._gather_and_numpify(preds_host, "eval_preds")) | |
labels_gatherer.add_arrays(self._gather_and_numpify(labels_host, "eval_label_ids")) | |
inputs_gatherer.add_arrays(self._gather_and_numpify(inputs_host, "eval_inputs_ids")) | |
eval_loss = eval_losses_gatherer.finalize() | |
preds = preds_gatherer.finalize() if not prediction_loss_only else None | |
label_ids = labels_gatherer.finalize() if not prediction_loss_only else None | |
inputs_ids = inputs_gatherer.finalize() if not prediction_loss_only else None | |
if self.compute_metrics is not None and preds is not None and label_ids is not None: | |
if args.include_inputs_for_metrics: | |
metrics = self.compute_metrics( | |
EvalPrediction(predictions=preds, label_ids=label_ids, inputs=inputs_ids) | |
) | |
else: | |
metrics = self.compute_metrics(EvalPrediction(predictions=preds, label_ids=label_ids)) | |
else: | |
metrics = {} | |
# To be JSON-serializable, we need to remove numpy types or zero-d tensors | |
metrics = denumpify_detensorize(metrics) | |
if eval_loss is not None: | |
metrics[f"{metric_key_prefix}_loss"] = eval_loss.mean().item() | |
# Prefix all keys with metric_key_prefix + '_' | |
for key in list(metrics.keys()): | |
if not key.startswith(f"{metric_key_prefix}_"): | |
metrics[f"{metric_key_prefix}_{key}"] = metrics.pop(key) | |
return EvalLoopOutput(predictions=preds, label_ids=label_ids, metrics=metrics, num_samples=num_examples) | |
def _gather_and_numpify(self, tensors, name): | |
""" | |
Gather the value of `tensors` (tensor or list/tuple of nested tensors) from all processes and convert them to
numpy arrays before returning them.
""" | |
if tensors is None: | |
return | |
if is_torch_tpu_available(): | |
tensors = nested_xla_mesh_reduce(tensors, name) | |
elif is_sagemaker_mp_enabled(): | |
tensors = smp_gather(tensors) | |
elif self.args.parallel_mode == ParallelMode.DISTRIBUTED: | |
tensors = distributed_concat(tensors) | |
return nested_numpify(tensors) | |
def _add_sm_patterns_to_gitignore(self) -> None: | |
"""Add SageMaker Checkpointing patterns to .gitignore file.""" | |
# Make sure we only do this on the main process | |
if not self.is_world_process_zero(): | |
return | |
patterns = ["*.sagemaker-uploading", "*.sagemaker-uploaded"] | |
# Get current .gitignore content | |
if os.path.exists(os.path.join(self.repo.local_dir, ".gitignore")): | |
with open(os.path.join(self.repo.local_dir, ".gitignore"), "r") as f: | |
current_content = f.read() | |
else: | |
current_content = "" | |
# Add the patterns to .gitignore | |
content = current_content | |
for pattern in patterns: | |
if pattern not in content: | |
if content.endswith("\n"): | |
content += pattern | |
else: | |
content += f"\n{pattern}" | |
# Write the .gitignore file if it has changed | |
if content != current_content: | |
with open(os.path.join(self.repo.local_dir, ".gitignore"), "w") as f: | |
logger.debug(f"Writing .gitignore file. Content: {content}") | |
f.write(content) | |
self.repo.git_add(".gitignore") | |
# avoid race condition with git status | |
time.sleep(0.5) | |
if not self.repo.is_repo_clean(): | |
self.repo.git_commit("Add *.sagemaker patterns to .gitignore.") | |
self.repo.git_push() | |
def create_accelerator_and_postprocess(self): | |
# create accelerator object | |
self.accelerator = Accelerator( | |
deepspeed_plugin=self.args.deepspeed_plugin, | |
gradient_accumulation_steps=self.args.gradient_accumulation_steps, | |
) | |
# deepspeed and accelerate flags covering both trainer args and accelerate launcher | |
self.is_deepspeed_enabled = getattr(self.accelerator.state, "deepspeed_plugin", None) is not None | |
self.is_fsdp_enabled = getattr(self.accelerator.state, "fsdp_plugin", None) is not None | |
# post accelerator creation setup | |
if self.is_fsdp_enabled: | |
fsdp_plugin = self.accelerator.state.fsdp_plugin | |
fsdp_plugin.limit_all_gathers = self.args.fsdp_config.get("limit_all_gathers", False) | |
fsdp_plugin.use_orig_params = self.args.fsdp_config.get("use_orig_params", False) | |
if self.is_deepspeed_enabled: | |
if getattr(self.args, "hf_deepspeed_config", None) is None: | |
from transformers.deepspeed import HfTrainerDeepSpeedConfig | |
ds_plugin = self.accelerator.state.deepspeed_plugin | |
ds_plugin.hf_ds_config = HfTrainerDeepSpeedConfig(ds_plugin.hf_ds_config.config) | |
ds_plugin.deepspeed_config = ds_plugin.hf_ds_config.config | |
ds_plugin.hf_ds_config.trainer_config_process(self.args) | |