|
""" |
|
Copyright (c) 2022, salesforce.com, inc. |
|
All rights reserved. |
|
SPDX-License-Identifier: BSD-3-Clause |
|
For full license text, see the LICENSE_Lavis file in the repo root or https://opensource.org/licenses/BSD-3-Clause |
|
""" |
|
|
|
import datetime |
|
import json |
|
import logging |
|
import os |
|
import time |
|
from pathlib import Path |
|
|
|
import torch |
|
import torch.distributed as dist |
|
import webdataset as wds |
|
from medomni.common.dist_utils import ( |
|
download_cached_file, |
|
get_rank, |
|
get_world_size, |
|
is_main_process, |
|
main_process, |
|
) |
|
from medomni.common.registry import registry |
|
from medomni.common.utils import is_url |
|
from medomni.datasets.data_utils import concat_datasets, reorg_datasets_by_split, ChainDataset |
|
from medomni.datasets.datasets.dataloader_utils import ( |
|
IterLoader, |
|
MultiIterLoader, |
|
PrefetchLoader, |
|
) |
|
from torch.nn.parallel import DistributedDataParallel as DDP |
|
from torch.utils.data import DataLoader, DistributedSampler |
|
import ipdb |
|
|
|
@registry.register_runner("runner_base") |
|
class RunnerBase: |
|
""" |
|
A runner class to train and evaluate a model given a task and datasets. |
|
|
|
The runner uses pytorch distributed data parallel by default. Future release |
|
will support other distributed frameworks. |
|
""" |
|
|
|
def __init__(self, cfg, task, model, datasets, job_id):
    """
    Keep references to the run inputs and prepare the output directories.

    Args:
        cfg: run configuration, stored as ``self.config``.
        task: task object, stored as ``self.task``.
        model: model to run, stored unwrapped as ``self._model``.
        datasets: datasets for the run, stored as ``self.datasets``.
        job_id: identifier of this run; it becomes the last component of
            the output directory created by ``setup_output_dir``.
    """
    # Objects handed in by the caller.
    self.config, self.job_id = cfg, job_id
    self.task, self.datasets = task, datasets
    self._model = model

    # Caches that the corresponding properties populate on demand.
    for lazy_slot in (
        "_wrapped_model",
        "_device",
        "_optimizer",
        "_scaler",
        "_dataloaders",
        "_lr_sched",
    ):
        setattr(self, lazy_slot, None)

    self.start_epoch = 0

    # Reads self.config and self.job_id, so it has to run last.
    self.setup_output_dir()
|
|
|
@property
def device(self):
    """Return the ``torch.device`` named by ``run_cfg.device``, created once and cached."""
    cached = self._device
    if cached is not None:
        return cached

    self._device = torch.device(self.config.run_cfg.device)
    return self._device
|
|
|
@property
def use_distributed(self):
    """Return the ``distributed`` flag of the run configuration."""
    run_cfg = self.config.run_cfg
    return run_cfg.distributed
|
|
|
@property
def model(self):
    """
    Return the model on the runner's device, DDP-wrapped when distributed.

    The underlying model is moved to ``self.device`` whenever its reported
    device differs. In distributed mode the ``DistributedDataParallel``
    wrapper is created once and reused; otherwise the plain model is
    returned.

    The wrapper is resolved independently of the device move, so the
    property never returns ``None`` for a model that already sits on the
    target device.
    """
    # Move the model to the target device if it is not there yet.
    if self._model.device != self.device:
        self._model = self._model.to(self.device)

    if not self.use_distributed:
        self._wrapped_model = self._model
    elif self._wrapped_model is None:
        # Build the DDP wrapper only once; later accesses reuse it.
        self._wrapped_model = DDP(
            self._model,
            device_ids=[self.config.run_cfg.gpu],
            find_unused_parameters=True,
        )

    return self._wrapped_model
|
|
|
@property
def optimizer(self):
    """
    Return the AdamW optimizer, building it on first access.

    Trainable parameters (``requires_grad=True``) are split into two groups:
    parameters with fewer than two dimensions, or whose name contains
    "bias", "ln" or "bn", get no weight decay; all others use
    ``run_cfg.weight_decay``. The learning rate is ``run_cfg.init_lr`` and
    the betas are ``(0.9, run_cfg.beta2)`` with ``beta2`` defaulting to 0.999.
    """
    if self._optimizer is not None:
        return self._optimizer

    run_cfg = self.config.run_cfg

    num_parameters = 0
    p_wd, p_non_wd = [], []
    for name, param in self.model.named_parameters():
        if not param.requires_grad:
            continue  # frozen parameters are not handed to the optimizer
        # Debug-level only: printing every name to stdout floods the console.
        logging.debug("trainable parameter: %s", name)
        if param.ndim < 2 or "bias" in name or "ln" in name or "bn" in name:
            p_non_wd.append(param)
        else:
            p_wd.append(param)
        num_parameters += param.numel()
    logging.info("number of trainable parameters: %d", num_parameters)

    weight_decay = float(run_cfg.weight_decay)
    optim_params = [
        {"params": p_wd, "weight_decay": weight_decay},
        {"params": p_non_wd, "weight_decay": 0},
    ]
    beta2 = run_cfg.get("beta2", 0.999)
    self._optimizer = torch.optim.AdamW(
        optim_params,
        lr=float(run_cfg.init_lr),
        weight_decay=weight_decay,
        betas=(0.9, beta2),
    )

    return self._optimizer
|
|
|
@property
def scaler(self):
    """
    Return the gradient scaler used for mixed-precision training.

    A ``torch.cuda.amp.GradScaler`` is created on first access when the
    ``amp`` option of the run configuration is truthy. Whatever is cached in
    ``self._scaler`` (``None`` if nothing was created) is returned.
    """
    amp_enabled = self.config.run_cfg.get("amp", False)
    if amp_enabled and self._scaler is None:
        self._scaler = torch.cuda.amp.GradScaler()

    return self._scaler
|
|
|
@property
def lr_scheduler(self):
    """
    Return the learning-rate scheduler, constructing it on first access.

    The scheduler class is looked up in the registry under
    ``run_cfg.lr_sched``. ``iters_per_epoch`` is taken from the run
    configuration when present; otherwise the length of the train
    dataloader is used, falling back to 10000 when taking that length
    raises ``AttributeError`` or ``TypeError``.
    """
    if self._lr_sched is not None:
        return self._lr_sched

    run_cfg = self.config.run_cfg
    scheduler_cls = registry.get_lr_scheduler_class(run_cfg.lr_sched)

    # Settings handed to the scheduler besides the optimizer and step count.
    schedule = {
        "max_epoch": self.max_epoch,
        "min_lr": self.min_lr,
        "init_lr": self.init_lr,
        "decay_rate": run_cfg.get("lr_decay_rate", None),
        "warmup_start_lr": run_cfg.get("warmup_lr", -1),
        "warmup_steps": run_cfg.get("warmup_steps", 0),
    }

    steps_per_epoch = run_cfg.get("iters_per_epoch", None)
    if steps_per_epoch is None:
        try:
            steps_per_epoch = len(self.dataloaders["train"])
        except (AttributeError, TypeError):
            # e.g. the train loader does not implement __len__
            steps_per_epoch = 10000

    self._lr_sched = scheduler_cls(
        optimizer=self.optimizer,
        iters_per_epoch=steps_per_epoch,
        **schedule,
    )
    return self._lr_sched
|
|
|
@property
def dataloaders(self) -> dict:
    """
    Return the per-split dataloaders, building them on first access.

    On the first call the datasets are regrouped by split with
    ``reorg_datasets_by_split`` (the result replaces ``self.datasets``), the
    number of records per split is logged, and one loader per split is
    built through ``create_loaders``. A split whose value is a list or tuple
    of datasets is passed on as a group, together with one collate function
    per member.

    The "train" split uses ``run_cfg.batch_size_train``; every other split
    uses ``run_cfg.batch_size_eval``. A split counts as a training split
    when its name is listed in ``self.train_splits``.

    Returns:
        dict: {split_name: dataloader}
    """
    if self._dataloaders is not None:
        return self._dataloaders

    logging.info(
        "dataset_ratios not specified, datasets will be concatenated (map-style datasets) or chained (webdataset.DataPipeline)."
    )

    self.datasets = reorg_datasets_by_split(self.datasets)

    # Report how many records each split holds.
    for split_name in self.datasets:
        split_data = self.datasets[split_name]

        if isinstance(split_data, (tuple, list)):
            # Group of datasets: members of exactly these two types count
            # as zero, every other member contributes its len().
            num_records = sum(
                0 if type(d) in [wds.DataPipeline, ChainDataset] else len(d)
                for d in split_data
            )
        elif hasattr(split_data, "__len__"):
            num_records = len(split_data)
        else:
            num_records = -1
            logging.info(
                "Only a single wds.DataPipeline dataset, no __len__ attribute."
            )

        if num_records >= 0:
            logging.info(
                "Loaded {} records for {} split from the dataset.".format(
                    num_records, split_name
                )
            )

    # Collect the per-split arguments in a stable (sorted) split order.
    split_names = sorted(self.datasets.keys())

    per_split_data, train_flags, batch_sizes, collate_fns = [], [], [], []
    for split in split_names:
        split_data = self.datasets[split]
        per_split_data.append(split_data)
        train_flags.append(split in self.train_splits)

        if split == "train":
            batch_sizes.append(self.config.run_cfg.batch_size_train)
        else:
            batch_sizes.append(self.config.run_cfg.batch_size_eval)

        if isinstance(split_data, (tuple, list)):
            collate_fns.append([getattr(d, "collater", None) for d in split_data])
        else:
            collate_fns.append(getattr(split_data, "collater", None))

    loaders = self.create_loaders(
        datasets=per_split_data,
        num_workers=self.config.run_cfg.num_workers,
        batch_sizes=batch_sizes,
        is_trains=train_flags,
        collate_fns=collate_fns,
    )

    self._dataloaders = dict(zip(split_names, loaders))
    return self._dataloaders
|
|
|
@property
def cuda_enabled(self):
    """Return True when the type of the runner's device is "cuda"."""
    device_type = self.device.type
    return device_type == "cuda"
|
|
|
@property
def max_epoch(self):
    """Return ``run_cfg.max_epoch`` converted to ``int``."""
    configured = self.config.run_cfg.max_epoch
    return int(configured)
|
|
|
@property
def log_freq(self):
    """Return the ``log_freq`` run option as ``int`` (50 when not configured)."""
    return int(self.config.run_cfg.get("log_freq", 50))
|
|
|
@property
def init_lr(self):
    """Return ``run_cfg.init_lr`` converted to ``float``."""
    configured = self.config.run_cfg.init_lr
    return float(configured)
|
|
|
@property
def min_lr(self):
    """Return ``run_cfg.min_lr`` converted to ``float``."""
    configured = self.config.run_cfg.min_lr
    return float(configured)
|
|
|
@property
def accum_grad_iters(self):
    """Return the ``accum_grad_iters`` run option as ``int`` (1 when not configured)."""
    configured = self.config.run_cfg.get("accum_grad_iters", 1)
    return int(configured)
|
|
|
@property
def valid_splits(self):
    """
    Return the configured validation split names (empty list by default).

    An info message is logged every time the property is read while no
    validation split is configured.
    """
    splits = self.config.run_cfg.get("valid_splits", [])
    if not len(splits):
        logging.info("No validation splits found.")
    return splits
|
|
|
@property
def test_splits(self):
    """Return the configured test split names (empty list by default)."""
    return self.config.run_cfg.get("test_splits", [])
|
|
|
@property
def train_splits(self):
    """
    Return the configured training split names (empty list by default).

    An info message is logged every time the property is read while no
    training split is configured.
    """
    splits = self.config.run_cfg.get("train_splits", [])
    if not len(splits):
        logging.info("Empty train splits.")
    return splits
|
|
|
@property
def evaluate_only(self):
    """Return ``run_cfg.evaluate``; when it is true, training is skipped."""
    run_cfg = self.config.run_cfg
    return run_cfg.evaluate
|
|
|
@property
def use_dist_eval_sampler(self):
    """Return the ``use_dist_eval_sampler`` run option (True when not configured)."""
    run_cfg = self.config.run_cfg
    return run_cfg.get("use_dist_eval_sampler", True)
|
|
|
@property
def resume_ckpt_path(self):
    """Return the ``resume_ckpt_path`` run option (None when not configured)."""
    run_cfg = self.config.run_cfg
    return run_cfg.get("resume_ckpt_path", None)
|
|
|
@property
def train_loader(self):
    """Return the dataloader registered under the "train" split."""
    return self.dataloaders["train"]
|
|
|
def setup_output_dir(self):
    """
    Create the run's output and result directories and publish their paths.

    The output directory is ``<library_root>/<run_cfg.output_dir>/<job_id>``
    and the result directory is its ``result`` sub-folder. Both paths are
    registered in the registry (as strings) and stored on the runner as
    ``self.output_dir`` and ``self.result_dir``.
    """
    library_root = Path(registry.get_path("library_root"))

    output_dir = library_root / self.config.run_cfg.output_dir / self.job_id
    result_dir = output_dir / "result"

    for directory in (output_dir, result_dir):
        directory.mkdir(parents=True, exist_ok=True)

    for path_name, path in (("result_dir", result_dir), ("output_dir", output_dir)):
        registry.register_path(path_name, str(path))

    self.result_dir = result_dir
    self.output_dir = output_dir
|
|
|
def train(self):
    """
    Run the training loop, then evaluate on the test splits.

    For every epoch from ``self.start_epoch`` up to ``self.max_epoch``:
    train for one epoch (skipped when ``evaluate_only`` is set), then
    evaluate on each validation split. On the main process, a checkpoint is
    saved as "best" whenever the ``agg_metrics`` value of the split named
    "val" exceeds the best value seen so far. When no validation split is
    configured, a checkpoint is saved after every training epoch instead.
    With ``evaluate_only`` set, the loop stops after its first pass.

    Afterwards the test splits are evaluated (from the best checkpoint when
    validation splits exist) and the elapsed wall-clock time is logged.
    """
    start_time = time.time()
    # NOTE(review): starting at 0 means a "val" metric that never rises
    # above zero never produces a best checkpoint.
    best_agg_metric = 0
    best_epoch = 0

    self.log_config()

    # Resume from a checkpoint if one is specified.
    if not self.evaluate_only and self.resume_ckpt_path is not None:
        self._load_checkpoint(self.resume_ckpt_path)

    # Read once: the property logs a message on every access when empty.
    valid_splits = self.valid_splits

    # Pre-bind the epoch label so the testing phase below still has one when
    # the loop body never runs (start_epoch >= max_epoch, e.g. after resuming
    # from the final checkpoint). It refers to the last completed epoch.
    cur_epoch = max(self.start_epoch - 1, 0)

    for cur_epoch in range(self.start_epoch, self.max_epoch):
        # Training phase.
        if not self.evaluate_only:
            logging.info("Start training")
            train_stats = self.train_epoch(cur_epoch)
            self.log_stats(split_name="train", stats=train_stats)

        # Evaluation phase.
        if len(valid_splits) > 0:
            for split_name in valid_splits:
                logging.info("Evaluating on {}.".format(split_name))

                val_log = self.eval_epoch(split_name=split_name, cur_epoch=cur_epoch)
                if val_log is None or not is_main_process():
                    continue

                assert (
                    "agg_metrics" in val_log
                ), "No agg_metrics found in validation log."

                agg_metrics = val_log["agg_metrics"]
                # Only the split literally named "val" drives model selection.
                if agg_metrics > best_agg_metric and split_name == "val":
                    best_epoch, best_agg_metric = cur_epoch, agg_metrics
                    self._save_checkpoint(cur_epoch, is_best=True)

                val_log.update({"best_epoch": best_epoch})
                self.log_stats(val_log, split_name)
        elif not self.evaluate_only:
            # Without validation splits, keep a checkpoint for every epoch.
            self._save_checkpoint(cur_epoch, is_best=False)

        if self.evaluate_only:
            break

        if self.use_distributed:
            dist.barrier()

    # Testing phase.
    test_epoch = "best" if len(valid_splits) > 0 else cur_epoch
    self.evaluate(cur_epoch=test_epoch, skip_reload=self.evaluate_only)

    total_time = time.time() - start_time
    total_time_str = str(datetime.timedelta(seconds=int(total_time)))
    logging.info("Training time {}".format(total_time_str))
|
|
|
def evaluate(self, cur_epoch="best", skip_reload=False):
    """
    Evaluate on every configured test split.

    Args:
        cur_epoch: epoch label forwarded to ``eval_epoch``.
        skip_reload (bool): forwarded to ``eval_epoch``.

    Returns:
        dict: {split_name: result of ``eval_epoch``}; empty when no test
        split is configured.
    """
    if len(self.test_splits) == 0:
        return dict()

    return {
        split_name: self.eval_epoch(
            split_name=split_name, cur_epoch=cur_epoch, skip_reload=skip_reload
        )
        for split_name in self.test_splits
    }
|
|
|
def train_epoch(self, epoch):
    """
    Train for one epoch by delegating to the task.

    The model is switched to training mode, then ``self.task.train_epoch``
    is called with the runner's model, train loader, optimizer, scaler,
    LR scheduler and logging / accumulation settings.

    Args:
        epoch: epoch index forwarded to the task.

    Returns:
        Whatever ``self.task.train_epoch`` returns.
    """
    self.model.train()

    step_kwargs = dict(
        epoch=epoch,
        model=self.model,
        data_loader=self.train_loader,
        optimizer=self.optimizer,
        scaler=self.scaler,
        lr_scheduler=self.lr_scheduler,
        cuda_enabled=self.cuda_enabled,
        log_freq=self.log_freq,
        accum_grad_iters=self.accum_grad_iters,
    )
    return self.task.train_epoch(**step_kwargs)
|
|
|
@torch.no_grad()
def eval_epoch(self, split_name, cur_epoch, skip_reload=False):
    """
    Evaluate the model on a given split.

    Args:
        split_name (str): name of the split to evaluate on.
        cur_epoch (int or str): current epoch, or "best" to evaluate the
            best checkpoint.
        skip_reload (bool): when ``cur_epoch`` is "best", skip reloading the
            best checkpoint and evaluate the weights currently held.

    Returns:
        The output of ``self.task.after_evaluation``, or None when the task's
        ``evaluation`` step returns None.
    """
    data_loader = self.dataloaders.get(split_name, None)
    assert data_loader, "data_loader for split {} is None.".format(split_name)

    # Evaluation runs on the bare model, without the distributed wrapper.
    model = self.unwrap_dist_model(self.model)
    if cur_epoch == "best" and not skip_reload:
        model = self._reload_best_model(model)
    model.eval()

    self.task.before_evaluation(model=model, dataset=self.datasets[split_name])
    results = self.task.evaluation(model, data_loader)

    if results is None:
        return None

    return self.task.after_evaluation(
        val_result=results,
        split_name=split_name,
        epoch=cur_epoch,
    )
|
|
|
def unwrap_dist_model(self, model):
    """Return ``model.module`` in distributed mode, otherwise ``model`` itself."""
    return model.module if self.use_distributed else model
|
|
|
def create_loaders(
    self,
    datasets,
    num_workers,
    batch_sizes,
    is_trains,
    collate_fns,
    dataset_ratios=None,
):
    """
    Create one dataloader per split.

    Args:
        datasets: one entry per split; either a single dataset or a
            list/tuple of datasets that are sampled together.
        num_workers: worker count passed to every ``DataLoader``.
        batch_sizes: batch size per split.
        is_trains: per split, whether it is a training split.
        collate_fns: per split, a collate function, or a list of them when
            the split is a list/tuple of datasets.
        dataset_ratios: optional sampling ratios used for every grouped
            split. When omitted, each grouped split uses the
            ``sample_ratio`` attributes of its own datasets, if the first
            dataset has one.

    Returns:
        list: loaders in the same order as ``datasets``.
    """

    def _create_loader(dataset, num_workers, bsz, is_train, collate_fn):
        """Build the loader for a single dataset."""
        if isinstance(dataset, (ChainDataset, wds.DataPipeline)):
            # Streaming datasets get a plain DataLoader iterator; no sampler
            # or collate function is passed for them.
            return iter(
                DataLoader(
                    dataset,
                    batch_size=bsz,
                    num_workers=num_workers,
                    pin_memory=True,
                )
            )

        # Map-style dataset: shard across ranks when running distributed.
        sampler = None
        if self.use_distributed:
            sampler = DistributedSampler(
                dataset,
                shuffle=is_train,
                num_replicas=get_world_size(),
                rank=get_rank(),
            )
            if not self.use_dist_eval_sampler and not is_train:
                # Evaluation on the full dataset on every rank.
                sampler = None

        loader = DataLoader(
            dataset,
            batch_size=bsz,
            num_workers=num_workers,
            pin_memory=True,
            sampler=sampler,
            shuffle=sampler is None and is_train,
            collate_fn=collate_fn,
            drop_last=bool(is_train),
        )
        loader = PrefetchLoader(loader)

        # Training and evaluation loaders are wrapped the same way.
        return IterLoader(loader, use_distributed=self.use_distributed)

    loaders = []

    for dataset, bsz, is_train, collate_fn in zip(
        datasets, batch_sizes, is_trains, collate_fns
    ):
        if isinstance(dataset, (list, tuple)):
            # Resolve the ratios per split: writing them back into
            # ``dataset_ratios`` would make every later grouped split reuse
            # the ratios of the first one.
            ratios = dataset_ratios
            if ratios is None and len(dataset) > 0 and hasattr(dataset[0], "sample_ratio"):
                ratios = [d.sample_ratio for d in dataset]

            loader = MultiIterLoader(
                loaders=[
                    _create_loader(d, num_workers, bsz, is_train, collate_fn[i])
                    for i, d in enumerate(dataset)
                ],
                ratios=ratios,
            )
        else:
            loader = _create_loader(dataset, num_workers, bsz, is_train, collate_fn)

        loaders.append(loader)

    return loaders
|
|
|
@main_process
def _save_checkpoint(self, cur_epoch, is_best=False):
    """
    Save a checkpoint for ``cur_epoch`` into the output directory.

    Parameters with ``requires_grad=False`` are left out of the saved model
    weights. Besides the weights, the optimizer state, the configuration,
    the scaler state (None when there is no scaler) and the epoch are
    stored. The file is named ``checkpoint_best.pth`` when ``is_best`` is
    true and ``checkpoint_<cur_epoch>.pth`` otherwise.
    """
    net = self.unwrap_dist_model(self.model)

    frozen_names = [
        name for name, param in net.named_parameters() if not param.requires_grad
    ]
    # Prune the state dict in place so it stays the object torch produced.
    weights = net.state_dict()
    for name in frozen_names:
        weights.pop(name, None)

    payload = {
        "model": weights,
        "optimizer": self.optimizer.state_dict(),
        "config": self.config.to_dict(),
        "scaler": self.scaler.state_dict() if self.scaler else None,
        "epoch": cur_epoch,
    }

    tag = "best" if is_best else cur_epoch
    save_to = os.path.join(self.output_dir, "checkpoint_{}.pth".format(tag))
    logging.info("Saving checkpoint at epoch {} to {}.".format(cur_epoch, save_to))
    torch.save(payload, save_to)
|
|
|
def _reload_best_model(self, model):
    """
    Load the weights of ``checkpoint_best.pth`` into ``model``.

    A strict load is attempted first; if it raises ``RuntimeError`` a
    warning is logged and the load is repeated with ``strict=False``.

    Args:
        model: the (unwrapped) model to load the weights into.

    Returns:
        The same ``model`` object.
    """
    best_path = os.path.join(self.output_dir, "checkpoint_best.pth")
    logging.info("Loading checkpoint from {}.".format(best_path))

    checkpoint = torch.load(best_path, map_location="cpu")
    try:
        model.load_state_dict(checkpoint["model"])
    except RuntimeError:
        logging.warning(
            """
            Key mismatch when loading checkpoint. This is expected if only part of the model is saved.
            Trying to load the model with strict=False.
            """
        )
        model.load_state_dict(checkpoint["model"], strict=False)

    return model
|
|
|
def _load_checkpoint(self, url_or_filename):
    """
    Resume model, optimizer and scaler state from a checkpoint.

    Args:
        url_or_filename (str): URL (downloaded through
            ``download_cached_file``) or local path of the checkpoint.

    Raises:
        RuntimeError: if ``url_or_filename`` is neither a URL nor an
            existing file.

    After loading, ``self.start_epoch`` is set to the epoch following the
    one stored in the checkpoint.
    """
    if is_url(url_or_filename):
        checkpoint_file = download_cached_file(
            url_or_filename, check_hash=False, progress=True
        )
    elif os.path.isfile(url_or_filename):
        checkpoint_file = url_or_filename
    else:
        raise RuntimeError("checkpoint url or path is invalid")

    # ``strict`` is not a torch.load option (it would be forwarded to the
    # unpickler and raise TypeError); it belongs to load_state_dict below.
    # NOTE(review): torch.load unpickles the file, so only resume from
    # checkpoints that come from a trusted source.
    checkpoint = torch.load(checkpoint_file, map_location=self.device)

    # Checkpoints written by _save_checkpoint omit parameters that do not
    # require gradients, so a strict load would fail on the missing keys.
    state_dict = checkpoint["model"]
    self.unwrap_dist_model(self.model).load_state_dict(state_dict, strict=False)

    self.optimizer.load_state_dict(checkpoint["optimizer"])
    # The stored scaler state is None when the checkpoint was written
    # without a scaler.
    if self.scaler and checkpoint.get("scaler") is not None:
        self.scaler.load_state_dict(checkpoint["scaler"])

    self.start_epoch = checkpoint["epoch"] + 1
    logging.info("Resume checkpoint from {}".format(url_or_filename))
|
|
|
@main_process
def log_stats(self, stats, split_name):
    """
    Append ``stats`` as one JSON line to ``log.txt`` in the output directory.

    Only dict stats are written; every key is prefixed with
    ``<split_name>_``. Any other input (including lists) is ignored.
    """
    if not isinstance(stats, dict):
        return

    record = {f"{split_name}_{k}": v for k, v in stats.items()}
    log_path = os.path.join(self.output_dir, "log.txt")
    with open(log_path, "a") as log_file:
        log_file.write(json.dumps(record) + "\n")
|
|
|
@main_process
def log_config(self):
    """Append the configuration, as indented JSON, to ``log.txt`` in the output directory."""
    log_path = os.path.join(self.output_dir, "log.txt")
    with open(log_path, "a") as log_file:
        config_text = json.dumps(self.config.to_dict(), indent=4)
        log_file.write(config_text + "\n")
|
|