""" |
|
Copyright (c) 2022, salesforce.com, inc. |
|
All rights reserved. |
|
SPDX-License-Identifier: BSD-3-Clause |
|
For full license text, see the LICENSE_Lavis file in the repo root or https://opensource.org/licenses/BSD-3-Clause |
|
""" |
|
|
|
import logging |
|
import os |
|
|
|
import numpy as np |
|
import torch |
|
import torch.nn as nn |
|
from medomni.common.dist_utils import download_cached_file, is_dist_avail_and_initialized |
|
from medomni.common.utils import get_abs_path, is_url |
|
from omegaconf import OmegaConf |
|
import ipdb |
|
|
|
|
|
class BaseModel(nn.Module): |
|
"""Base class for models.""" |
|
|
|
def __init__(self): |
|
super().__init__() |
|
|
|
@property |
|
def device(self): |
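        # Index of the current CUDA device (an int, accepted wherever torch
        # expects a device); this assumes the model is run with CUDA available.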
        return torch.cuda.current_device()

    def load_checkpoint(self, url_or_filename):
        """
        Load from a finetuned checkpoint.

        This should expect no mismatch between the model keys and the checkpoint keys.
        """
        if is_url(url_or_filename):
            cached_file = download_cached_file(
                url_or_filename, check_hash=False, progress=True
            )
            checkpoint = torch.load(cached_file, map_location="cpu")
        elif os.path.isfile(url_or_filename):
            checkpoint = torch.load(url_or_filename, map_location="cpu")
        else:
            raise RuntimeError("checkpoint url or path is invalid")

        if "model" in checkpoint.keys():
            state_dict = checkpoint["model"]
        else:
            state_dict = checkpoint
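        # strict=False tolerates keys that are missing from the checkpoint or
        # unexpected in it; mismatches are reported via the returned object and
        # logged below.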
        msg = self.load_state_dict(state_dict, strict=False)

        logging.info("Missing keys {}".format(msg.missing_keys))
        logging.info("load checkpoint from %s" % url_or_filename)

        return msg

    @classmethod
    def default_config_path(cls, model_type):
        assert (
            model_type in cls.PRETRAINED_MODEL_CONFIG_DICT
        ), "Unknown model type {}".format(model_type)
        return get_abs_path(cls.PRETRAINED_MODEL_CONFIG_DICT[model_type])
    def load_checkpoint_from_config(self, cfg, **kwargs):
        """
        Load checkpoint as specified in the config file.

        If load_finetuned is True, load the finetuned model; otherwise, load the pretrained model.
        When loading the pretrained model, each task-specific architecture may define its
        own load_from_pretrained() method.
        """
        load_finetuned = cfg.get("load_finetuned", True)
        if load_finetuned:
            finetune_path = cfg.get("finetuned", None)
            assert (
                finetune_path is not None
            ), "Found load_finetuned is True, but finetune_path is None."
            self.load_checkpoint(url_or_filename=finetune_path)
        else:
            pretrain_path = cfg.get("pretrained", None)
            assert (
                pretrain_path is not None
            ), "Found load_finetuned is False, but pretrain_path is None."
            self.load_from_pretrained(url_or_filename=pretrain_path, **kwargs)

    def before_evaluation(self, **kwargs):
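        # Hook for subclasses to run setup right before evaluation; no-op by default.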
        pass

    def show_n_params(self, return_str=True):
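        # Total parameter count, formatted as e.g. "123.4M" or "56.7K" when return_str is True.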
        tot = sum(p.numel() for p in self.parameters())
        if return_str:
            if tot >= 1e6:
                return "{:.1f}M".format(tot / 1e6)
            else:
                return "{:.1f}K".format(tot / 1e3)
        else:
            return tot


class BaseEncoder(nn.Module):
""" |
|
Base class for primitive encoders, such as ViT, TimeSformer, etc. |
|
""" |
|
|
|
def __init__(self): |
|
super().__init__() |
|
|
|
def forward_features(self, samples, **kwargs): |
|
raise NotImplementedError |
|
|
|
@property |
|
def device(self): |
|
|
|
return next(self.parameters()).device |
|
|
|
class SharedQueueMixin: |
|
@torch.no_grad() |
|
def _dequeue_and_enqueue(self, image_feat, text_feat, idxs=None): |
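        # Gather features from all ranks, then write them into the shared queues at
        # the current pointer position; the queues act as ring buffers of size
        # self.queue_size, so the oldest entries are overwritten.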
        image_feats = concat_all_gather(image_feat)
        text_feats = concat_all_gather(text_feat)

        batch_size = image_feats.shape[0]

        ptr = int(self.queue_ptr)
        assert self.queue_size % batch_size == 0

        self.image_queue[:, ptr : ptr + batch_size] = image_feats.T
        self.text_queue[:, ptr : ptr + batch_size] = text_feats.T

        if idxs is not None:
            idxs = concat_all_gather(idxs)
            self.idx_queue[:, ptr : ptr + batch_size] = idxs.T

        ptr = (ptr + batch_size) % self.queue_size
        self.queue_ptr[0] = ptr


class MomentumDistilationMixin:
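    """
    Keeps momentum (EMA) copies of model components. The subclass is expected to
    define self.model_pairs (a list of [online, momentum] module pairs) and
    self.momentum: copy_params() initializes each momentum copy and freezes it,
    while _momentum_update() blends in the online weights with factor (1 - self.momentum).
    """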
    @torch.no_grad()
    def copy_params(self):
        for model_pair in self.model_pairs:
            for param, param_m in zip(
                model_pair[0].parameters(), model_pair[1].parameters()
            ):
                param_m.data.copy_(param.data)  # initialize
                param_m.requires_grad = False  # not updated by gradient

    @torch.no_grad()
    def _momentum_update(self):
        for model_pair in self.model_pairs:
            for param, param_m in zip(
                model_pair[0].parameters(), model_pair[1].parameters()
            ):
                param_m.data = param_m.data * self.momentum + param.data * (
                    1.0 - self.momentum
                )


class GatherLayer(torch.autograd.Function):
    """
    Gather tensors from all workers with support for backward propagation.
    This implementation does not cut the gradients as torch.distributed.all_gather does.
    """

    @staticmethod
    def forward(ctx, x):
        output = [
            torch.zeros_like(x) for _ in range(torch.distributed.get_world_size())
        ]
        torch.distributed.all_gather(output, x)
        return tuple(output)

    @staticmethod
    def backward(ctx, *grads):
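        # Sum the gradient contributions from all ranks, then return only the slice
        # that corresponds to this rank's original input to forward().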
        all_gradients = torch.stack(grads)
        torch.distributed.all_reduce(all_gradients)
        return all_gradients[torch.distributed.get_rank()]


def all_gather_with_grad(tensors):
    """
    Performs all_gather operation on the provided tensors.
    Graph remains connected for backward grad computation.
    """
    world_size = torch.distributed.get_world_size()

    if world_size == 1:
        return tensors

    tensor_all = GatherLayer.apply(tensors)

    return torch.cat(tensor_all, dim=0)


@torch.no_grad()
def concat_all_gather(tensor):
    """
    Performs all_gather operation on the provided tensors.
    *** Warning ***: torch.distributed.all_gather has no gradient.
    """
    if not is_dist_avail_and_initialized():
        return tensor

    tensors_gather = [
        torch.ones_like(tensor) for _ in range(torch.distributed.get_world_size())
    ]
    torch.distributed.all_gather(tensors_gather, tensor, async_op=False)

    output = torch.cat(tensors_gather, dim=0)
    return output


def tile(x, dim, n_tile):
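    """
    Repeat each slice of x along dimension `dim` n_tile times, keeping the copies
    adjacent (similar to torch.repeat_interleave), e.g. [a, b] with n_tile=2
    becomes [a, a, b, b].
    """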
    init_dim = x.size(dim)
    repeat_idx = [1] * x.dim()
    repeat_idx[dim] = n_tile
    x = x.repeat(*(repeat_idx))
    order_index = torch.LongTensor(
        np.concatenate([init_dim * np.arange(n_tile) + i for i in range(init_dim)])
    )
    return torch.index_select(x, dim, order_index.to(x.device))