File size: 25,918 Bytes
07423df |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 184 185 186 187 188 189 190 191 192 193 194 195 196 197 198 199 200 201 202 203 204 205 206 207 208 209 210 211 212 213 214 215 216 217 218 219 220 221 222 223 224 225 226 227 228 229 230 231 232 233 234 235 236 237 238 239 240 241 242 243 244 245 246 247 248 249 250 251 252 253 254 255 256 257 258 259 260 261 262 263 264 265 266 267 268 269 270 271 272 273 274 275 276 277 278 279 280 281 282 283 284 285 286 287 288 289 290 291 292 293 294 295 296 297 298 299 300 301 302 303 304 305 306 307 308 309 310 311 312 313 314 315 316 317 318 319 320 321 322 323 324 325 326 327 328 329 330 331 332 333 334 335 336 337 338 339 340 341 342 343 344 345 346 347 348 349 350 351 352 353 354 355 356 357 358 359 360 361 362 363 364 365 366 367 368 369 370 371 372 373 374 375 376 377 378 379 380 381 382 383 384 385 386 387 388 389 390 391 392 393 394 395 396 397 398 399 400 401 402 403 404 405 406 407 408 409 410 411 412 413 414 415 416 417 418 419 420 421 422 423 424 425 426 427 428 429 430 431 432 433 434 435 436 437 438 439 440 441 442 443 444 445 446 447 448 449 450 451 452 453 454 455 456 457 458 459 460 461 462 463 464 465 466 467 468 469 470 471 472 473 474 475 476 477 478 479 480 481 482 483 484 485 486 487 488 489 490 491 492 493 494 495 496 497 498 499 500 501 502 503 504 505 506 507 508 509 510 511 512 513 514 515 516 517 518 519 520 521 522 523 524 525 526 527 
528 529 530 531 532 533 534 535 536 537 538 539 540 541 542 543 544 545 546 547 548 549 550 551 552 553 554 555 556 557 558 559 560 561 562 563 564 565 566 567 568 569 570 571 572 573 574 575 576 577 578 579 580 581 582 583 584 585 586 587 588 589 590 591 592 593 594 595 596 597 598 599 600 601 602 603 604 605 606 607 608 609 610 611 612 613 614 615 616 617 618 619 620 621 622 623 624 625 626 627 628 629 630 631 632 633 634 635 636 637 638 639 640 641 642 643 644 645 646 647 648 649 650 651 652 653 654 655 656 657 658 659 660 661 662 663 664 665 666 667 668 669 670 671 672 673 674 675 676 677 678 679 680 681 682 683 684 685 686 687 688 689 690 691 692 693 694 695 696 697 698 699 700 701 702 703 704 705 706 707 708 709 710 711 712 713 714 715 716 717 718 719 720 721 722 723 724 725 726 727 728 729 730 731 732 733 |
import os
# Limit the native thread pools of the numerical libraries to one thread each
# and disable tokenizer parallelism. This is done before numpy/torch and the
# other heavy libraries are imported below, so the settings are in place when
# those libraries initialise.
os.environ.update(
    {
        "OMP_NUM_THREADS": "1",
        "MKL_NUM_THREADS": "1",
        "OPENBLAS_NUM_THREADS": "1",
        "VECLIB_MAXIMUM_THREADS": "1",
        "NUMEXPR_NUM_THREADS": "1",
        "TOKENIZERS_PARALLELISM": "false",
    }
)
import argparse
import gc
import logging
import sys
import time
from distutils import util
from typing import Any, Callable, Dict, Tuple
import deepspeed
import numpy as np
import pandas as pd
import torch
from torch.cuda.amp import GradScaler, autocast
from torch.utils.data import DataLoader
from tqdm import tqdm
from transformers.deepspeed import HfDeepSpeedConfig
from llm_studio.src.loggers import MainLogger
from llm_studio.src.utils.config_utils import (
load_config_py,
load_config_yaml,
save_config_yaml,
)
from llm_studio.src.utils.data_utils import (
get_data,
get_inference_batch_size,
get_train_dataloader,
get_train_dataset,
get_val_dataloader,
get_val_dataset,
)
from llm_studio.src.utils.exceptions import LLMTrainingException
from llm_studio.src.utils.export_utils import save_prediction_outputs
from llm_studio.src.utils.gpu_utils import sync_across_processes
from llm_studio.src.utils.logging_utils import (
TqdmToLogger,
initialize_logging,
log_plot,
write_flag,
)
from llm_studio.src.utils.modeling_utils import (
activate_neftune,
check_disk_space,
get_ds_config,
get_number_of_validation_epochs,
get_optimizer,
get_scheduler,
get_torch_dtype,
load_checkpoint,
run_inference,
save_checkpoint,
save_predictions,
wrap_model_distributed,
)
from llm_studio.src.utils.utils import kill_ddp_processes, set_environment, set_seed
# Module-level logger used by the training and evaluation routines below.
logger = logging.getLogger(name=__name__)
def run_eval(
    cfg,
    model: torch.nn.Module,
    val_dataloader: DataLoader,
    val_df: pd.DataFrame,
    mode: str = "validation",
) -> Tuple:
    """Evaluate ``model`` on ``val_dataloader`` and log/save the results.

    Inference runs on every process without gradient tracking, and the
    model's previous train/eval mode is restored afterwards. With distributed
    inference enabled, each output entry is gathered via
    ``sync_across_processes``. After that, only local rank 0 post-processes,
    logs metrics, plots and saves predictions. Every other rank waits at a
    barrier (in distributed mode) and returns ``(0, 0)``.

    Args:
        cfg: config object
        model: model to evaluate
        val_dataloader: validation DataLoader
        val_df: validation DataFrame
        mode: split name used for inference, logging and saving predictions

    Returns:
        ``(val_loss, val_metric)`` on local rank 0, ``(0, 0)`` elsewhere.
    """
    env = cfg.environment

    with torch.no_grad():
        was_training = model.training
        model.eval()
        outputs: Dict[str, Any] = run_inference(
            cfg, model, val_dataloader, mode
        )  # type: ignore
        model.train(was_training)

    # Gather the predictions of all GPUs when inference was sharded.
    if env._distributed and env._distributed_inference:
        for name, tensor in outputs.items():
            outputs[name] = sync_across_processes(
                tensor, env._world_size, group=env._cpu_comm
            )

    # Only rank 0 post-processes; the others just join the final barrier.
    if env._local_rank != 0:
        if env._distributed:
            torch.distributed.barrier()
        return 0, 0

    # Drop observations beyond the dataset length (presumably padding that
    # appears when outputs are gathered across processes).
    n_samples = len(val_dataloader.dataset)  # type: ignore
    for name in outputs:
        outputs[name] = outputs[name][:n_samples]

    outputs = val_dataloader.dataset.postprocess_output(  # type: ignore
        cfg=cfg, df=val_df, output=outputs
    )

    loss_tensor = outputs.get("loss", torch.tensor(0))
    val_loss = np.mean(loss_tensor.float().cpu().numpy())
    val_metric = np.mean(outputs["metrics"])
    logger.info(f"{mode.capitalize()} {cfg.prediction.metric}: {val_metric:.5f}")

    # Log the mean of the loss and of every "additional_log_*" entry.
    for name in outputs:
        if not (name.startswith("additional_log_") or name == "loss"):
            continue
        mean_value = np.mean(outputs[name].float().cpu().numpy())
        display_name = name.replace("additional_log_", "")
        logger.info(f"Mean {mode} {display_name}: {mean_value:.5f}")
        cfg.logging._logger.log(
            mode,
            display_name,
            mean_value,
            step=env._curr_step,
        )

    cfg.logging._logger.log(
        mode, cfg.prediction.metric, val_metric, step=env._curr_step
    )

    if val_df is not None:
        figure = cfg.logging.plots_class.plot_validation_predictions(
            val_outputs=outputs, cfg=cfg, val_df=val_df, mode="validation"
        )
        log_plot(cfg, figure, "validation_predictions")

    save_predictions(cfg, outputs, val_dataloader, val_df, mode)

    if env._distributed:
        torch.distributed.barrier()

    return val_loss, val_metric
def run_train(
    cfg: Any,
    model: torch.nn.Module,
    optimizer,
    scheduler,
    epoch_steps,
    train_dataloader,
    val_dataloader,
    val_df: pd.DataFrame,
):
    """Runs the training loop.

    Trains for ``cfg.training.epochs`` epochs with optional mixed precision,
    gradient accumulation and gradient clipping, and logs training metrics on
    local rank 0. ``run_eval`` is called every ``evaluation_step`` iterations,
    and once before training when ``cfg.training.evaluate_before_training``
    is set. Checkpoints go to ``cfg.output_directory`` either before every
    evaluation ("last" checkpoint) or only when the validation metric
    improves (``cfg.training.save_best_checkpoint``).

    Args:
        cfg: config object
        model: model to train (possibly wrapped for distributed training)
        optimizer: optimizer stepping the model parameters
        scheduler: learning-rate scheduler stepped once per iteration, or None
        epoch_steps: number of iterations per epoch; sizes the progress bar
            and the logging/evaluation intervals
        train_dataloader: custom training Dataloader
        val_dataloader: custom validation Dataloader
        val_df: validation DataFrame

    Returns:
        ``(val_loss, val_metric)`` from the most recent ``run_eval`` call
        (``run_eval`` returns ``(0, 0)`` on ranks other than local rank 0).
    """
    # Enable NEFTune through activate_neftune when a positive alpha is set.
    if (
        hasattr(cfg.augmentation, "neftune_noise_alpha")
        and cfg.augmentation.neftune_noise_alpha > 0
    ):
        activate_neftune(model, cfg.augmentation.neftune_noise_alpha)

    # A scaler exists only with mixed precision, and it is only enabled
    # (actually scales the loss) when the autocast dtype is float16.
    scaler: GradScaler | None = None
    if cfg.environment.mixed_precision:
        scaler = GradScaler(
            enabled=(cfg.environment.mixed_precision_dtype == "float16")
        )

    optimizer.zero_grad(set_to_none=True)

    # Prepare NLP Augmentation
    nlp_augment = None
    if hasattr(cfg.augmentation, "nlp_augmentations_class"):
        nlp_augment = cfg.augmentation.nlp_augmentations_class(cfg=cfg)

    start_epoch = 0

    # The metric direction decides the initial "best" value and the
    # comparison used to detect an improvement.
    _, metric_mode, _ = cfg.prediction.metric_class.get(cfg.prediction.metric)
    objective_op: Callable[[float, float], bool]
    if metric_mode == "max":
        best_val_metric = -np.inf
        objective_op = np.greater
    else:
        best_val_metric = np.inf
        objective_op = np.less

    # NOTE(review): val_loss / val_metric are only assigned here and in the
    # evaluation branch inside the loop. If no evaluation runs at all, the
    # final `return` raises UnboundLocalError.
    if cfg.training.evaluate_before_training:
        val_loss, val_metric = run_eval(
            cfg=cfg, model=model, val_dataloader=val_dataloader, val_df=val_df
        )

    for epoch in range(start_epoch, cfg.training.epochs):
        # Seed differs per epoch and per local rank; offsets are multiples of
        # number_of_workers.
        set_seed(
            cfg.environment._seed
            + epoch * cfg.environment._world_size * cfg.environment.number_of_workers
            + cfg.environment._local_rank * cfg.environment.number_of_workers
        )
        if cfg.environment._local_rank == 0:
            logger.info(f"Training Epoch: {epoch + 1} / {cfg.training.epochs}")

        # Tell a distributed sampler the current epoch (skipped under DeepSpeed).
        if (
            cfg.environment._distributed
            and not cfg.environment.use_deepspeed
            and hasattr(train_dataloader.sampler, "set_epoch")
        ):
            train_dataloader.sampler.set_epoch(epoch)  # type: ignore

        # The progress bar writes through the logger and is disabled off rank 0.
        tqdm_out = TqdmToLogger(logger, level=logging.INFO)
        progress_bar = tqdm(
            total=epoch_steps,
            disable=cfg.environment._local_rank != 0,
            file=tqdm_out,
            ascii=True,
            desc="train loss",
            mininterval=0,
        )
        tr_it = iter(train_dataloader)

        losses = []
        model.train()

        # Progress updates roughly every 5% of the epoch; evaluation every
        # `evaluation_epochs` fraction of an epoch (never less than 1 step).
        log_update_steps = max(epoch_steps // 20, 1)
        evaluation_step = max(int(epoch_steps * cfg.training.evaluation_epochs), 1)
        logger.info(f"Evaluation step: {evaluation_step}")

        for itr, data in enumerate(tr_it):
            # Advances by batch_size * world_size per iteration, i.e. it
            # counts samples rather than iterations.
            cfg.environment._curr_step += (
                cfg.training.batch_size * cfg.environment._world_size
            )

            # Batch to device
            batch = cfg.dataset.dataset_class.batch_to_device(
                data, cfg.environment._device
            )

            # NLP augmentation
            if nlp_augment is not None:
                batch = nlp_augment(batch)

            # Plot first batch
            if epoch == 0 and itr == 0 and cfg.environment._local_rank == 0:
                plot = cfg.logging.plots_class.plot_batch(batch=batch, cfg=cfg)
                log_plot(cfg, plot, "train_data")

            # only need to sync gradients at last step of grad accumulation
            # NOTE(review): the optimizer steps when itr % grad_accumulation == 0.
            # That means already at itr 0 (after one micro-batch), then every
            # grad_accumulation iterations. Gradients left at the end of an
            # epoch are not zeroed and roll into the next epoch's first step.
            # `(itr + 1) % grad_accumulation == 0` may be what was intended.
            # require_backward_grad_sync is presumably DDP's all-reduce toggle;
            # confirm its effect for non-DDP / DeepSpeed models.
            model.require_backward_grad_sync = itr % cfg.training.grad_accumulation == 0

            # Forward pass
            with autocast(
                enabled=cfg.environment.mixed_precision,
                dtype=get_torch_dtype(cfg.environment.mixed_precision_dtype),
            ):
                output_dict = model.forward(batch)

            loss = output_dict["loss"]
            # Non-finite losses are tolerated during the first 21 iterations
            # (itr 0..20) of the first epoch; afterwards training aborts.
            if ~np.isfinite(loss.item()) and (epoch > start_epoch or itr > 20):
                raise LLMTrainingException(
                    "NaN caught in loss during training. "
                    "Please, reduce learning rate, change dtype, "
                    "or disable mixed precision. Alternatively, "
                    "gradient clipping may help to stabilize training."
                )
            losses.append(loss.item())

            # loss is a mean loss per batch/sample
            # as grad_accumulations sums up the gradients, this loss must be scaled
            # by the number of grad_accumulations, to have similar behavior for
            # BS * grad_accumulations = const.
            if cfg.training.grad_accumulation != 1:
                loss = loss / cfg.training.grad_accumulation

            # Backward pass
            # GradScaler path: mixed precision on GPU without DeepSpeed. The
            # scaler is non-None here because mixed_precision is set.
            if (
                cfg.environment.mixed_precision
                and len(cfg.environment.gpus)
                and not cfg.environment.use_deepspeed
            ):
                scaler.scale(loss).backward()  # type: ignore
                if itr % cfg.training.grad_accumulation == 0:
                    if cfg.training.gradient_clip > 0:
                        # Gradients must be unscaled before clipping by norm.
                        scaler.unscale_(optimizer)  # type: ignore
                        torch.nn.utils.clip_grad_norm_(
                            model.parameters(), cfg.training.gradient_clip
                        )
                    scaler.step(optimizer)  # type: ignore
                    scaler.update()
                    optimizer.zero_grad(set_to_none=True)
            else:
                # DeepSpeed engines expose their own backward().
                if cfg.environment.use_deepspeed:
                    model.backward(loss)  # type: ignore[operator]
                else:
                    loss.backward()
                if itr % cfg.training.grad_accumulation == 0:
                    if cfg.training.gradient_clip > 0:
                        torch.nn.utils.clip_grad_norm_(
                            model.parameters(), cfg.training.gradient_clip
                        )
                    optimizer.step()
                    optimizer.zero_grad(set_to_none=True)

            if cfg.environment._distributed:
                torch.cuda.synchronize(device=cfg.environment._local_rank)

            # The scheduler advances every iteration, not every optimizer step.
            if scheduler is not None:
                scheduler.step()

            if cfg.environment._local_rank == 0:
                cfg.logging._logger.log(
                    "train", "loss", losses[-1], step=cfg.environment._curr_step
                )
                cfg.logging._logger.log(
                    "meta",
                    "lr",
                    optimizer.param_groups[0]["lr"],
                    step=cfg.environment._curr_step,
                )
                # NOTE(review): assumes get_optimizer creates at least three
                # param groups when differential learning rates are configured.
                if cfg.training.differential_learning_rate_layers:
                    cfg.logging._logger.log(
                        "meta",
                        "lr_diff",
                        optimizer.param_groups[2]["lr"],
                        step=cfg.environment._curr_step,
                    )

                cfg.logging._logger.log(
                    "internal",
                    "current_step",
                    cfg.environment._curr_step,
                    step=cfg.environment._curr_step,
                )
                for key in output_dict:
                    if key.startswith("additional_log_"):
                        cfg.logging._logger.log(
                            "train",
                            key.replace("additional_log_", ""),
                            output_dict[key].item(),
                            step=cfg.environment._curr_step,
                        )

                # Show logs each 5% of the epoch (only if doing per epoch evaluation)
                if (itr + 1) % log_update_steps == 0 or itr == epoch_steps - 1:
                    progress_bar.set_description(
                        f"train loss: {np.mean(losses[-10:]):.2f}", refresh=False
                    )
                    if (itr + 1) % log_update_steps == 0:
                        progress_bar.update(log_update_steps)
                    else:
                        # Last iteration: advance by the remainder of the epoch.
                        progress_bar.update(epoch_steps % log_update_steps)

            del output_dict

            # Validation loop
            if (itr + 1) % evaluation_step == 0:
                if cfg.training.evaluation_epochs == 1:
                    progress_bar.close()

                # TODO: Move back after fixing slow generation of deepspeed.
                if not cfg.training.save_best_checkpoint:
                    checkpoint_path = cfg.output_directory
                    if cfg.environment._local_rank == 0:
                        logger.info(
                            f"Saving last model checkpoint to {checkpoint_path}"
                        )
                    save_checkpoint(model=model, path=checkpoint_path, cfg=cfg)

                val_loss, val_metric = run_eval(
                    cfg=cfg, model=model, val_dataloader=val_dataloader, val_df=val_df
                )

                # NOTE(review): run_eval returns (0, 0) on ranks other than
                # local rank 0, so this comparison can differ between ranks.
                # If save_checkpoint is a collective operation (e.g. under
                # DeepSpeed), ranks may diverge here; confirm.
                if cfg.training.save_best_checkpoint:
                    if objective_op(val_metric, best_val_metric):
                        checkpoint_path = cfg.output_directory
                        if cfg.environment._local_rank == 0:
                            logger.info(
                                f"Saving best model checkpoint: "
                                f"val_{cfg.prediction.metric} {best_val_metric:.5} -> "
                                f"{val_metric:.5} to {checkpoint_path}"
                            )
                        save_checkpoint(model=model, path=checkpoint_path, cfg=cfg)
                        best_val_metric = val_metric

                # run_eval restores the previous mode; make sure we train again.
                model.train()

        progress_bar.close()
        del progress_bar

        if cfg.environment._distributed:
            torch.cuda.synchronize(device=cfg.environment._local_rank)
            torch.distributed.barrier()

        if cfg.environment._local_rank == 0:
            cfg.logging._logger.log(
                "internal", "epoch", epoch + 1, step=cfg.environment._curr_step
            )

    if cfg.environment._distributed:
        torch.distributed.barrier()

    return val_loss, val_metric
def _format_runtime(seconds: float) -> str:
    """Format a duration as ``HH:MM:SS``, prefixed by ``<days>d `` from one day on.

    Uses integer arithmetic instead of ``time.strftime("%-jd ...")``. The
    ``%-j`` flag is a glibc extension that other platforms reject, and the
    day-of-year trick wraps after 366 days.
    """
    days, remainder = divmod(max(int(seconds), 0), 86400)
    hours, remainder = divmod(remainder, 3600)
    minutes, secs = divmod(remainder, 60)
    clock = f"{hours:02d}:{minutes:02d}:{secs:02d}"
    return f"{days}d {clock}" if days else clock


def run(cfg: Any) -> None:
    """Runs the complete training routine described by ``cfg``.

    Sets up the (optionally distributed) environment and seeds, prepares the
    data, model, optimizer and scheduler, and trains via ``run_train``. On
    local rank 0 it then writes the final config, the prediction outputs and
    the ``flags.json`` status/runtime flags into ``cfg.output_directory``.

    Args:
        cfg: config object with all the hyperparameters

    Raises:
        DeprecationWarning: if ``cfg.problem_type`` is
            ``"text_rlhf_language_modeling"``.
        ValueError: if DeepSpeed is combined with an int8/int4 backbone.
    """
    if cfg.problem_type == "text_rlhf_language_modeling":
        raise DeprecationWarning(
            "text_rlhf_language_modeling is deprecated. "
            "Please use DPO Modeling instead."
        )

    os.makedirs(cfg.output_directory, exist_ok=True)

    # Force evaluation if user trains 0 epochs
    cfg.training.evaluate_before_training = (
        cfg.training.evaluate_before_training or cfg.training.epochs == 0
    )

    # Set the random seed for reproducibility: random when the user set a
    # negative seed, otherwise the deterministic user-chosen seed.
    if cfg.environment.seed < 0:
        cfg.environment._seed = np.random.randint(1_000_000)
    else:
        cfg.environment._seed = cfg.environment.seed

    if (
        cfg.architecture.backbone_dtype in ["int8", "int4"]
        and cfg.environment.use_deepspeed
    ):
        raise ValueError(
            f"Deepspeed do not support backbone type {cfg.architecture.backbone_dtype}."
            + " Please set backbone type to float16 or bfloat16 for using deepspeed."
        )

    # Prepare environment: distributed mode is detected from WORLD_SIZE.
    cfg.environment._distributed = int(os.environ.get("WORLD_SIZE", 1)) > 1

    if cfg.environment._distributed:
        cfg.environment._local_rank = int(os.environ["LOCAL_RANK"])
        cfg.environment._device = "cuda:%d" % cfg.environment._local_rank
        if cfg.environment.use_deepspeed:
            deepspeed.init_distributed()
        else:
            torch.distributed.init_process_group(backend="nccl", init_method="env://")
        # CPU (gloo) group used to sync host-side data, e.g. the seed below.
        cfg.environment._cpu_comm = torch.distributed.new_group(backend="gloo")

        cfg.environment._world_size = torch.distributed.get_world_size()
        cfg.environment._rank = torch.distributed.get_rank()
        # Bind this process to its node-local GPU, the same index used for
        # _device. The global rank exceeds the per-node GPU count in
        # multi-node runs and would select a non-existent device.
        torch.cuda.set_device(cfg.environment._local_rank)
        logger.info(
            f"Training in distributed mode with multiple processes, "
            f"1 GPU per process. Process {cfg.environment._rank}, "
            f"total: {cfg.environment._world_size} "
            f"local rank: {cfg.environment._local_rank}."
        )

        # Sync the random seed so that all processes agree on one value
        # (element 0 of the synced array is used on every rank).
        cfg.environment._seed = int(
            sync_across_processes(
                np.array([cfg.environment._seed]),
                cfg.environment._world_size,
                group=cfg.environment._cpu_comm,
            )[0]
        )
    else:
        cfg.environment._local_rank = 0
        cfg.environment._device = (
            "cuda:0"
            if (torch.cuda.is_available() and len(cfg.environment.gpus) > 0)
            else "cpu"
        )
        if cfg.environment._device == "cpu":
            logger.warning("Training on CPU. This will be slow.")

    set_seed(cfg.environment._seed)
    if cfg.environment._local_rank == 0:
        logger.info(f"Problem Type: {cfg.problem_type}")
        logger.info(f"Global random seed: {cfg.environment._seed}")

    cfg = set_environment(cfg)

    # we need to get train dataframe and number of labels if not set or in training mode
    if cfg.environment._local_rank == 0:
        logger.info("Preparing the data...")
    train_df, val_df = get_data(cfg)

    # GPT-based metrics incur OpenAI API costs per record, so large validation
    # sets fall back to BLEU unless GPT_EVAL_MAX is raised.
    gpt_eval_max = int(os.getenv("GPT_EVAL_MAX", 100))
    if len(val_df) > gpt_eval_max and "GPT" in cfg.prediction.metric:
        logger.warning(
            f"More than {gpt_eval_max} validation records. "
            "Safeguarding against OpenAI API costs. Setting metric to BLEU. "
            "Change GPT_EVAL_MAX to run GPT validation."
        )
        cfg.prediction.metric = "BLEU"

    # prepare data
    if cfg.environment._local_rank == 0:
        logger.info("Preparing train and validation data")
    train_dataset = get_train_dataset(train_df=train_df, cfg=cfg)
    val_dataset = get_val_dataset(val_df=val_df, cfg=cfg)
    train_dataloader = get_train_dataloader(train_ds=train_dataset, cfg=cfg)
    val_dataloader = get_val_dataloader(val_ds=val_dataset, cfg=cfg)

    # Step totals are only needed for rank-0 logging below.
    if cfg.environment._local_rank == 0:
        total_training_steps = (
            cfg.training.epochs
            * len(train_dataloader)
            * cfg.training.batch_size
            * cfg.environment._world_size
        )

        num_eval_epochs = get_number_of_validation_epochs(
            training_epochs=cfg.training.epochs,
            evaluation_epochs=cfg.training.evaluation_epochs,
        )
        val_batch_size = get_inference_batch_size(cfg)

        # if zero shot, validate once before training
        total_validation_steps = (
            len(val_dataloader)
            * (num_eval_epochs + int(cfg.training.evaluate_before_training))
            * val_batch_size
            * cfg.environment._world_size
        )

    # Prepare model and optimizer
    if cfg.environment.use_deepspeed:
        ds_config = get_ds_config(cfg)
        # keep this object alive.
        dschf = HfDeepSpeedConfig(ds_config)  # noqa: F841
    with torch.device(cfg.environment._device):
        model = cfg.architecture.model_class(cfg)
        check_disk_space(model, cfg.output_directory)

        # load model weights
        if cfg.architecture.pretrained_weights != "":
            # Do not load strictly if continue training from the previous experiment
            load_checkpoint(cfg, model, strict=cfg.training.epochs == -1)
    model.to(cfg.environment._device)

    epoch_steps = len(train_dataloader)
    optimizer = get_optimizer(model=model, cfg=cfg)
    scheduler = get_scheduler(cfg=cfg, optimizer=optimizer, epoch_steps=epoch_steps)

    # Default to False so configs without this field do not raise.
    if getattr(cfg.architecture, "force_embedding_gradients", False):
        for module in model.modules():
            if isinstance(module, torch.nn.Embedding):
                for param in module.parameters():
                    param.requires_grad = True
                    param.data = param.data.float()

    if cfg.environment._distributed:
        (
            model,
            optimizer,
            train_dataloader,
            val_dataloader,
            scheduler,
        ) = wrap_model_distributed(
            model=model,
            optimizer=optimizer,
            lr_scheduler=scheduler,
            train_dataloader=train_dataloader,
            val_dataloader=val_dataloader,
            cfg=cfg,
        )

    if cfg.environment.compile_model:
        # deepspeed do not support torch.compile
        if cfg.environment.use_deepspeed:
            logger.warning(
                "Deepspeed is active, but it doesn't support torch.compile."
                "Skipping compilation for this experiment."
            )
        elif cfg.environment._distributed:
            model.module.backbone = torch.compile(model.module.backbone)
        else:
            model.backbone = torch.compile(model.backbone)

    # Force settings when saving best checkpoint
    if cfg.training.save_best_checkpoint:
        cfg.training.train_validation_data = False

    # reset steps
    cfg.environment._curr_step = 0
    cfg.environment._curr_val_step = 0

    gc.collect()

    global_start_time = time.time()
    if cfg.environment._local_rank == 0:
        # re-save cfg
        save_config_yaml(f"{cfg.output_directory}/cfg.yaml", cfg)
        cfg.logging._logger = MainLogger(cfg)
        cfg.logging._logger.log(
            "internal", "total_training_steps", total_training_steps, step=0
        )
        cfg.logging._logger.log(
            "internal", "total_validation_steps", total_validation_steps, step=0
        )
        cfg.logging._logger.log(
            "internal",
            "global_start_time",
            global_start_time,
            step=cfg.environment._curr_step,
        )
        # re-save config
        save_config_yaml(f"{cfg.output_directory}/cfg.yaml", cfg)

    val_loss, val_metric = run_train(
        cfg=cfg,
        model=model,
        optimizer=optimizer,
        scheduler=scheduler,
        epoch_steps=epoch_steps,
        train_dataloader=train_dataloader,
        val_dataloader=val_dataloader,
        val_df=val_df,
    )

    # reset external logging
    if cfg.environment._local_rank == 0:
        cfg.logging._logger.reset_external()

    experiment_path = f"{cfg.output_directory}"

    # With zero epochs run_train never reaches its checkpointing branch.
    if cfg.training.epochs == 0:
        checkpoint_path = cfg.output_directory
        if cfg.environment._local_rank == 0:
            logger.info(f"Saving last model checkpoint to {checkpoint_path}")
        save_checkpoint(model=model, path=checkpoint_path, cfg=cfg)

    if cfg.environment._local_rank == 0:
        save_config_yaml(f"{cfg.output_directory}/cfg.yaml", cfg)
        save_prediction_outputs(cfg.experiment_name, experiment_path)

        flag_path = os.path.join(cfg.output_directory, "flags.json")
        write_flag(flag_path, "status", "finished")
        time_took = time.time() - global_start_time
        write_flag(flag_path, "info", f"Runtime: {_format_runtime(time_took)}")
if __name__ == "__main__":
    parser = argparse.ArgumentParser(description="")
    parser.add_argument(
        "-C", "--config", help="config filename", default=argparse.SUPPRESS
    )
    parser.add_argument("-Y", "--yaml", help="yaml filename", default=argparse.SUPPRESS)
    parser_args, unknown = parser.parse_known_args(sys.argv)

    if "config" in parser_args:
        cfg = load_config_py(parser_args.config)
    elif "yaml" in parser_args:
        cfg = load_config_yaml(parser_args.yaml)
    else:
        raise ValueError("Please, provide a configuration file")

    # Register every "--<section>.<field>" override on the command line,
    # typed after the field's annotation on that config section.
    extra_args = []
    for arg_orig in unknown:
        if not arg_orig.startswith("-"):
            continue
        arg = arg_orig.replace("-", "").split(".")
        if arg in extra_args:
            # Already registered; argparse keeps the last occurrence instead
            # of failing on a conflicting option string.
            continue
        try:
            arg_type = getattr(cfg, arg[0]).get_annotations()[arg[1]]
        except (AttributeError, KeyError, IndexError):
            # Not a config override (or "--section" without ".field");
            # parse_args() below reports it as an unrecognized argument.
            continue
        if arg_type == bool:
            # NOTE: distutils (and util.strtobool) is removed in Python 3.12
            # (PEP 632); this needs a local replacement to run there.
            parser.add_argument(arg_orig, type=util.strtobool)
        else:
            parser.add_argument(arg_orig, type=arg_type)
        extra_args.append(arg)

    args = parser.parse_args()

    for arg in extra_args:
        value = getattr(args, ".".join(arg))
        setattr(getattr(cfg, arg[0]), arg[1], value)

    out_dir = cfg.output_directory
    os.makedirs(out_dir, exist_ok=True)

    initialize_logging(cfg)

    try:
        run(cfg=cfg)
    except Exception:
        logging.error("Exception occurred during the run:", exc_info=True)
        if ("WORLD_SIZE" in os.environ) and (int(os.environ["WORLD_SIZE"]) > 1):
            kill_ddp_processes()
        # Signal the failure to the launcher instead of exiting with status 0.
        sys.exit(1)
|