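"""
Helper utilities for the Uni-Perceiver training engine: metric logging and
construction of the train/val/test dataloaders and default event writers.
"""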
import uniperceiver.utils.comm as comm
import torch
import numpy as np
from uniperceiver.utils.events import get_event_storage
from typing import Dict
from uniperceiver.datasets import (
    build_standard_valtest_loader,
    build_unified_train_loader,
)
import weakref


def write_metrics(loss_dict: Dict[str, torch.Tensor],
                  data_time: float,
                  prefix: str = "",
                  ):
    """
    Log the given losses to the event storage on the main process.

    Args:
        loss_dict (dict): dict of scalar losses
        data_time (float): time taken by the dataloader iteration
        prefix (str): prefix prepended to the names of the logged scalars
    """
    metrics_dict = {}
    for k, v in loss_dict.items():
        if isinstance(v, torch.Tensor):
            metrics_dict.update({k: v.detach().cpu().item()})
        else:
            metrics_dict.update({k: v})
    metrics_dict["data_time"] = data_time

    # Only the local (per-process) metrics are used; no cross-rank gather is
    # performed here, and scalars are written on the main process only.
    all_metrics_dict = [metrics_dict]
    if comm.is_main_process():
        storage = get_event_storage()

        data_time = np.max([x.pop("data_time") for x in all_metrics_dict])
        storage.put_scalar("data_time", data_time)

        metrics_dict = {
            k: np.mean([x[k] for x in all_metrics_dict])
            for k in all_metrics_dict[0].keys()
        }
        total_losses_reduced = sum(metrics_dict.values())
        storage.put_scalar("{}total_loss".format(prefix),
                           total_losses_reduced)
        if len(metrics_dict) > 1:
            for k, v in metrics_dict.items():
                # Do not log the 'null_loss' entry individually.
                if k != 'null_loss':
                    storage.put_scalar(f'{prefix}{k}', v)
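# Illustrative usage sketch (an assumption, not taken from this repository):
# a training step would typically time its dataloader fetch and then call
# write_metrics with the per-task losses; `losses`, `start`, and `task_prefix`
# below are hypothetical names.
#
#     data_time = time.perf_counter() - start
#     write_metrics(losses, data_time, prefix=task_prefix)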
def build_writers(cfg, max_iter):
    """Build the default event writers for cfg.OUTPUT_DIR."""
    from uniperceiver.engine.defaults import default_writers
    return default_writers(cfg.OUTPUT_DIR, max_iter)
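# Illustrative usage sketch (an assumption, following the detectron2-style
# EventWriter interface): the training loop would periodically flush the
# writers returned above, e.g.
#
#     writers = build_writers(cfg, max_iter)
#     for writer in writers:
#         writer.write()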
def build_train_loader(cfg, task_cfg, model):
    """Build the unified training dataloader for all tasks in task_cfg."""
    if cfg.DATALOADER.UNIFIED_DATASET:
        # The loader only needs a handle to the model when LOAD_INLABEL is set;
        # a weakref.proxy of the unwrapped model is passed so the loader does
        # not hold a strong reference to it.
        loader = build_unified_train_loader(
            cfg,
            task_cfg,
            model=weakref.proxy(comm.unwrap_model(model)) if cfg.DATALOADER.LOAD_INLABEL else None,
        )
        return loader
    else:
        raise NotImplementedError('please use unified dataset.')


def build_test_loader(cfg, task_cfg):
    """Build one test dataloader per task in task_cfg."""
    loaders = dict()
    for name, new_cfg in task_cfg.items():
        # These tasks are evaluated with multi-GPU evaluation enabled.
        multi_gpu = name in [
            'K400_retrieve', 'imagenet', 'vqa', 'mscoco_caption',
            'flickr30k_caption', 'K700_retrieve', 'imagenet_caption'
        ]
        loaders[name] = build_standard_valtest_loader(new_cfg, task_cfg, stage='test', multi_gpu_eval=multi_gpu)
    return loaders


def build_val_loader(cfg, task_cfg):
    """Build one validation dataloader per task in task_cfg."""
    loaders = dict()
    for name, new_cfg in task_cfg.items():
        # These tasks are evaluated with multi-GPU evaluation enabled.
        multi_gpu = name in [
            'K400_retrieve', 'imagenet', 'vqa', 'mscoco_caption',
            'flickr30k_caption', 'K700_retrieve', 'imagenet_caption'
        ]
        loaders[name] = build_standard_valtest_loader(new_cfg, task_cfg, stage='val', multi_gpu_eval=multi_gpu)
    return loaders
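# Illustrative usage sketch (an assumption, not taken from this repository):
# evaluation code would typically iterate over the per-task loaders by name,
# e.g.
#
#     val_loaders = build_val_loader(cfg, task_cfg)
#     for task_name, loader in val_loaders.items():
#         ...  # run evaluation for this task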
def get_batch_data(cfg, train_data_loader_iter, train_data_loader):
    """Fetch the next training batch, restarting the iterator when exhausted."""
    if not cfg.DATALOADER.FAKE_DATA:
        try:
            data = next(train_data_loader_iter)
        except StopIteration:
            # The underlying loader is exhausted; start a new pass. Note that
            # the fresh iterator is local to this call.
            train_data_loader_iter = iter(train_data_loader)
            data = next(train_data_loader_iter)
    else:
        # Fake-data mode is not implemented here; fail loudly instead of
        # returning an undefined value.
        raise NotImplementedError('FAKE_DATA batch generation is not implemented.')
    return data
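# Illustrative usage sketch (an assumption, not taken from this repository):
# a training loop would keep one iterator over the unified train loader and
# fetch batches through get_batch_data, e.g.
#
#     train_loader = build_train_loader(cfg, task_cfg, model)
#     loader_iter = iter(train_loader)
#     for _ in range(max_iter):
#         data = get_batch_data(cfg, loader_iter, train_loader)
#         ...  # forward/backward pass on `data`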