Jiading Fang
add define
fc16538
raw
history blame
7.58 kB
# TRI-VIDAR - Copyright 2022 Toyota Research Institute. All rights reserved.
import random
from collections import OrderedDict
from inspect import signature
import numpy as np
import torch
from vidar.utils.decorators import iterate1, iterate2
from vidar.utils.types import is_list, is_double_list, is_tuple, is_tensor, is_dict, is_seq
# Batch keys that fold_batch, expand_batch and break_batch handle as image-like entries
KEYS_IMAGE = [
'rgb', 'mask',
'input_depth', 'depth',
'bwd_optical_flow', 'fwd_optical_flow',
]
# Batch keys that fold_batch, expand_batch and break_batch handle as matrix-like entries
KEYS_MATRIX = [
'intrinsics', 'extrinsics', 'pose', 'semantic',
]
def modrem(v, n):
    """Return the floor-division quotient and the remainder of v by n, as a (quotient, remainder) tuple"""
    quotient = v // n
    remainder = v % n
    return quotient, remainder
def flatten(lst):
    """Concatenate the inner lists of a list of lists; any other input is returned unchanged"""
    if not is_double_list(lst):
        return lst
    flat = []
    for inner in lst:
        flat.extend(inner)
    return flat
def keys_with(dic, string, without=()):
    """Return the keys of a dictionary that contain `string` and none of the substrings in `without`"""
    # `without` may be a single substring or a list/tuple of substrings
    excluded = make_list(without)
    matches = []
    for key in dic:
        if string not in key:
            continue
        if any(w in key for w in excluded):
            continue
        matches.append(key)
    return matches
def keys_startswith(dic, string):
    """Return the keys of a dictionary that start with a certain string"""
    matches = []
    for key in dic:
        if key.startswith(string):
            matches.append(key)
    return matches
def keys_in(dic, keys):
    """Return the entries of `keys` that are present in the dictionary (in the order given by `keys`)"""
    present = []
    for key in keys:
        if key in dic:
            present.append(key)
    return present
def str_not_in(string, keys):
    """Return True if none of the entries in `keys` is a substring of `string`"""
    return not any(key in string for key in keys)
def make_list(var, n=None):
    """
    Wrap the input into a list (lists and tuples are kept as they are).

    If n is given, a single-element input is repeated n times, and an input
    that already has n elements is returned unchanged; any other length fails
    the assertion.
    """
    if not (is_list(var) or is_tuple(var)):
        var = [var]
    if n is None:
        return var
    assert len(var) == 1 or len(var) == n, 'Wrong list length for make_list'
    if len(var) == 1:
        return var * n
    return var
def filter_args(func, keys):
    """Return a new dictionary with only the entries of `keys` whose key is a parameter name of `func`"""
    accepted = signature(func).parameters
    return {name: value for name, value in {**keys}.items() if name in accepted}
def dict_remove_nones(dic):
    """Return a copy of the dictionary without the entries whose value is None"""
    kept = {}
    for key, val in dic.items():
        if val is not None:
            kept[key] = val
    return kept
@iterate1
def matmul1(v1, v2):
    """Matrix-multiply v1 by v2 (v1 @ v2); iteration over the inputs is handled by the iterate1 decorator"""
    product = v1 @ v2
    return product
@iterate2
def matmul2(v1, v2):
    """Matrix-multiply v1 by v2 (v1 @ v2); iteration over the inputs is handled by the iterate2 decorator"""
    product = v1 @ v2
    return product
@iterate1
def unsqueeze(x):
    """Add a leading dimension of size 1 to a tensor; non-tensor inputs are returned unchanged"""
    if not is_tensor(x):
        return x
    return x.unsqueeze(0)
@iterate1
def fold(data, n):
    """Merge the first and second dimensions of `data` into one when it has n + 1 dimensions"""
    dims = list(data.shape)
    # Only data with exactly one dimension more than n is folded
    if len(dims) != n + 1:
        return data
    merged = [dims[0] * dims[1], *dims[2:]]
    return data.view(*merged)
@iterate1
def expand(data, n, d):
    """Insert a new dimension at position d when `data` has exactly n dimensions; otherwise return it unchanged"""
    rank = len(list(data.shape))
    if rank != n:
        return data
    return data.unsqueeze(d)
def fold_batch(batch, device=None):
    """
    Combine the first (batch) and second (camera) dimensions of a batch.

    Sequences are processed element by element and returned as a list.
    Otherwise the entries listed in KEYS_IMAGE and KEYS_MATRIX are folded
    in place, and the batch is optionally copied to `device`.
    """
    if is_seq(batch):
        return [fold_batch(item, device=device) for item in batch]
    # (key group, n passed to fold): entries with n + 1 dimensions get folded
    for group, ndim in ((KEYS_IMAGE, 4), (KEYS_MATRIX, 3)):
        for key in keys_in(batch, group):
            batch[key] = fold(batch[key], ndim)
    if device is not None:
        batch = batch_to_device(batch, device)
    return batch
def expand_batch(batch, d, device=None):
    """
    Expand the batch to include an additional dimension.

    `d` is either 'batch' (new dimension at position 0) or 'camera' (new
    dimension at position 1). Sequences are processed element by element and
    returned as a list. Otherwise the entries listed in KEYS_IMAGE and
    KEYS_MATRIX are expanded in place, and the batch is optionally copied
    to `device`.
    """
    if is_seq(batch):
        return [expand_batch(item, d, device=device) for item in batch]
    dim = {'batch': 0, 'camera': 1}[d]
    # (key group, n passed to expand): entries with exactly n dimensions get expanded
    for group, ndim in ((KEYS_IMAGE, 4), (KEYS_MATRIX, 3)):
        for key in keys_in(batch, group):
            batch[key] = expand(batch[key], ndim, dim)
    if device is not None:
        batch = batch_to_device(batch, device)
    return batch
def batch_to_device(batch, device):
    """
    Recursively copy the tensors of a batch to a device.

    Dictionaries and lists are rebuilt with their values processed
    recursively, tensors are moved with `.to(device)`, and anything else is
    returned unchanged.
    """
    if is_dict(batch):
        moved = {}
        for key, val in batch.items():
            moved[key] = batch_to_device(val, device)
        return moved
    if is_list(batch):
        moved = []
        for val in batch:
            moved.append(batch_to_device(val, device))
        return moved
    if is_tensor(batch):
        return batch.to(device)
    return batch
def num_trainable_params(model):
    """Count the elements of all model parameters that have requires_grad set"""
    total = 0
    for param in model.parameters():
        if param.requires_grad:
            total += param.numel()
    return total
def set_random_seed(seed):
    """Seed the Python, NumPy and PyTorch (CPU and CUDA) random generators; a negative seed does nothing"""
    if seed >= 0:
        seeders = (
            random.seed,
            np.random.seed,
            torch.manual_seed,
            torch.cuda.manual_seed_all,
        )
        for seed_fn in seeders:
            seed_fn(seed)
def make_batch(batch, device=None):
    """
    Transform a sample into a batch.

    Every tensor value gets a leading dimension of size 1 and nested
    dictionaries are processed recursively; the input dictionary is modified
    in place. The result is optionally copied to `device`.
    """
    for key, val in batch.items():
        if is_dict(val):
            batch[key] = make_batch(val)
        elif is_tensor(val):
            batch[key] = val.unsqueeze(0)
    if device is not None:
        batch = batch_to_device(batch, device)
    return batch
def break_key(sample, n=None):
    """
    Break a multi-camera sample key, so different cameras have their own entries (context, camera).

    Returns an OrderedDict keyed by (context, camera index), or None when
    `sample` is None. Dictionary entries are split along their second
    dimension, keeping that dimension with size 1. Other entries are split
    only when they have n + 1 dimensions, and are left out of the result
    otherwise.
    """
    if sample is None:
        return sample
    broken = OrderedDict()
    for ctx, entry in sample.items():
        if is_dict(entry):
            for name, val in entry.items():
                # 1-dimensional values get a second dimension of size 1 before splitting
                if val.dim() == 1:
                    val = val.unsqueeze(1)
                for cam in range(val.shape[1]):
                    broken.setdefault((ctx, cam), {})[name] = val[:, [cam]]
        elif entry.dim() == n + 1:
            # NOTE(review): reaching this comparison with the default n=None raises a TypeError
            for cam in range(entry.shape[1]):
                broken[(ctx, cam)] = entry[:, cam]
    return broken
def break_batch(batch):
    """
    Break a multi-camera batch, so different cameras have their own entries (context, camera).

    For each KEYS_IMAGE entry with 5 dimensions and each KEYS_MATRIX entry
    with 4 dimensions, the context key is replaced in place by one
    (context, camera index) key per index of the second dimension.
    """
    # (key group, number of dimensions that marks a multi-camera entry)
    for group, ndim in ((KEYS_IMAGE, 5), (KEYS_MATRIX, 4)):
        for key in keys_in(batch, group):
            entries = batch[key]
            # Iterate over a snapshot of the keys, since entries are added and removed below
            for ctx in list(entries.keys()):
                tensor = entries[ctx]
                if tensor.dim() != ndim:
                    continue
                for cam in range(tensor.shape[1]):
                    entries[(ctx, cam)] = tensor[:, cam]
                entries.pop(ctx)
    return batch
def dict_has(dic, key):
    """Return True if `key` is present in the dictionary, False otherwise"""
    found = key in dic
    return found
def get_from_dict(dic, key):
    """Return the value stored under `key`, or None if the key is not in the dictionary"""
    if key in dic:
        return dic[key]
    return None
def get_mask_from_list(mask, i, return_ones=None):
    """
    Retrieve a mask from a list (if it is not a list, the mask itself is used).

    Without `return_ones`, the selected mask (or None) is returned as it is.
    With `return_ones`, a missing mask is replaced by an all-ones boolean mask
    shaped like `return_ones` (or `return_ones[i]` if it is a list), and an
    existing mask is cloned and cast to bool. Only index 0 of the second
    dimension (4-dimensional mask) or of the last dimension (3-dimensional
    mask) is then returned, keeping that dimension. Any other number of
    dimensions returns None.
    """
    if return_ones is None:
        if mask is None:
            return None
        return mask[i] if is_list(mask) else mask
    if mask is None:
        reference = return_ones[i] if is_list(return_ones) else return_ones
        mask = torch.ones_like(reference).bool()
    elif is_list(mask):
        mask = mask[i].clone().bool()
    else:
        mask = mask.clone().bool()
    ndim = mask.dim()
    if ndim == 4:
        return mask[:, [0]]
    if ndim == 3:
        return mask[..., [0]]
    return None
def get_from_list(lst, i):
    """Return element i of a sequence; None is returned for a None input, and non-sequence inputs are returned directly"""
    if lst is None:
        return None
    if is_seq(lst):
        return lst[i]
    return lst