# TRI-VIDAR - Copyright 2022 Toyota Research Institute. All rights reserved.

import random
from collections import OrderedDict
from inspect import signature

import numpy as np
import torch

from vidar.utils.decorators import iterate1, iterate2
from vidar.utils.types import is_list, is_double_list, is_tuple, is_tensor, is_dict, is_seq

KEYS_IMAGE = [
    'rgb', 'mask',
    'input_depth', 'depth',
    'bwd_optical_flow', 'fwd_optical_flow',
]

KEYS_MATRIX = [
    'intrinsics', 'extrinsics', 'pose', 'semantic',
]
def modrem(v, n):
    """Return floor division and remainder"""
    return v // n, v % n
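# Example (illustrative values):
#   modrem(7, 3)  ->  (2, 1), i.e. the floor quotient and the remainder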
def flatten(lst):
    """Flatten a list of lists into a list"""
    return [l for ls in lst for l in ls] if is_double_list(lst) else lst
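# Example (illustrative values):
#   flatten([[1, 2], [3]])  ->  [1, 2, 3]
#   flatten([1, 2, 3])      ->  [1, 2, 3]  (non-nested lists pass through unchanged)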
def keys_with(dic, string, without=()):
    """Return keys from a dictionary that contain a certain string"""
    return [key for key in dic if string in key and not any(w in key for w in make_list(without))]
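# Example (hypothetical keys):
#   keys_with({'rgb': 0, 'rgb_context': 1, 'depth': 2}, 'rgb')
#       ->  ['rgb', 'rgb_context']
#   keys_with({'rgb': 0, 'rgb_context': 1, 'depth': 2}, 'rgb', without=('context',))
#       ->  ['rgb']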
def keys_startswith(dic, string):
    """Return keys from a dictionary that start with a certain string"""
    return [key for key in dic if key.startswith(string)]
def keys_in(dic, keys):
    """Return the subset of keys that are present in a dictionary"""
    return [key for key in keys if key in dic]
def str_not_in(string, keys):
    """Return True if none of the keys is a substring of string"""
    for key in keys:
        if key in string:
            return False
    return True
def make_list(var, n=None):
    """Wraps the input into a list, and optionally repeats it to be size n"""
    var = var if is_list(var) or is_tuple(var) else [var]
    if n is None:
        return var
    else:
        assert len(var) == 1 or len(var) == n, 'Wrong list length for make_list'
        return var * n if len(var) == 1 else var
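# Example (illustrative values):
#   make_list(3, n=4)       ->  [3, 3, 3, 3]
#   make_list([1, 2])       ->  [1, 2]
#   make_list([1, 2], n=2)  ->  [1, 2]  (length already matches n)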
def filter_args(func, keys):
    """Filters a dictionary, so it only contains keys that are arguments of a function"""
    filtered = {}
    sign = list(signature(func).parameters.keys())
    for k, v in {**keys}.items():
        if k in sign:
            filtered[k] = v
    return filtered
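# Example (hypothetical function and arguments):
#   def f(a, b=0): ...
#   filter_args(f, {'a': 1, 'c': 3})  ->  {'a': 1}  ('c' is not an argument of f)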
def dict_remove_nones(dic):
    """Filters dictionary to remove keys with None values"""
    return {key: val for key, val in dic.items() if val is not None}
@iterate1
def matmul1(v1, v2):
    """Iteratively multiply tensors"""
    return v1 @ v2

@iterate2
def matmul2(v1, v2):
    """Iteratively multiply tensors"""
    return v1 @ v2

@iterate1
def unsqueeze(x):
    """Iteratively unsqueeze tensors to batch size 1"""
    return x.unsqueeze(0) if is_tensor(x) else x
@iterate1
def fold(data, n):
    """Iteratively folds first and second dimensions into one"""
    shape = list(data.shape)
    if len(shape) == n + 1:
        shape = [shape[0] * shape[1]] + shape[2:]
        return data.view(*shape)
    else:
        return data
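# Shape example (assumed [batch, camera, ...] layout):
#   fold(torch.zeros(2, 6, 3, 64, 64), 4).shape  ->  torch.Size([12, 3, 64, 64])
#   fold(torch.zeros(2, 3, 64, 64), 4).shape     ->  torch.Size([2, 3, 64, 64])  (already folded)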
@iterate1
def expand(data, n, d):
    """Iteratively unsqueezes a new dimension at position d if the tensor has n dimensions"""
    shape = list(data.shape)
    if len(shape) == n:
        return data.unsqueeze(d)
    else:
        return data
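# Shape example (assumed [batch, channels, H, W] input, adding a camera dimension at d=1):
#   expand(torch.zeros(2, 3, 64, 64), 4, 1).shape     ->  torch.Size([2, 1, 3, 64, 64])
#   expand(torch.zeros(2, 6, 3, 64, 64), 4, 1).shape  ->  torch.Size([2, 6, 3, 64, 64])  (unchanged)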
def fold_batch(batch, device=None):
    """Combine the first (batch) and second (camera) dimensions of a batch"""
    if is_seq(batch):
        return [fold_batch(b, device=device) for b in batch]
    for key in keys_in(batch, KEYS_IMAGE):
        batch[key] = fold(batch[key], 4)
    for key in keys_in(batch, KEYS_MATRIX):
        batch[key] = fold(batch[key], 3)
    if device is not None:
        batch = batch_to_device(batch, device)
    return batch
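# Example (hypothetical multi-camera batch; shapes are assumptions):
#   batch = {'rgb': torch.zeros(2, 6, 3, 64, 64), 'intrinsics': torch.zeros(2, 6, 3, 3)}
#   folded = fold_batch(batch)
#   folded['rgb'].shape         ->  torch.Size([12, 3, 64, 64])
#   folded['intrinsics'].shape  ->  torch.Size([12, 3, 3])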
def expand_batch(batch, d, device=None):
    """Expand the batch with an additional dimension (d='batch' inserts dim 0, d='camera' inserts dim 1)"""
    if is_seq(batch):
        return [expand_batch(b, d, device=device) for b in batch]
    d = {'batch': 0, 'camera': 1}[d]
    for key in keys_in(batch, KEYS_IMAGE):
        batch[key] = expand(batch[key], 4, d)
    for key in keys_in(batch, KEYS_MATRIX):
        batch[key] = expand(batch[key], 3, d)
    if device is not None:
        batch = batch_to_device(batch, device)
    return batch
def batch_to_device(batch, device):
    """Copy batch information to device"""
    if is_dict(batch):
        return {key: batch_to_device(val, device) for key, val in batch.items()}
    if is_list(batch):
        return [batch_to_device(val, device) for val in batch]
    if is_tensor(batch):
        return batch.to(device)
    return batch
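# Example (hypothetical nested batch):
#   batch = {'rgb': torch.zeros(1, 3, 64, 64), 'meta': {'idx': torch.tensor([0])}}
#   batch = batch_to_device(batch, 'cuda:0')  # tensors are moved recursively; non-tensors pass through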
def num_trainable_params(model):
    """Return number of trainable parameters"""
    return sum(p.numel() for p in model.parameters() if p.requires_grad)
def set_random_seed(seed):
    """Seed the python, numpy, and torch RNGs for reproducibility (negative seeds are ignored)"""
    if seed >= 0:
        random.seed(seed)
        np.random.seed(seed)
        torch.manual_seed(seed)
        torch.cuda.manual_seed_all(seed)
def make_batch(batch, device=None):
    """Transforms a sample into a batch"""
    for key in batch.keys():
        if is_dict(batch[key]):
            batch[key] = make_batch(batch[key])
        elif is_tensor(batch[key]):
            batch[key] = batch[key].unsqueeze(0)
    if device is not None:
        batch = batch_to_device(batch, device)
    return batch
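# Example (hypothetical single sample; shapes are assumptions):
#   sample = {'rgb': torch.zeros(3, 64, 64), 'intrinsics': torch.zeros(3, 3)}
#   batch = make_batch(sample)
#   batch['rgb'].shape  ->  torch.Size([1, 3, 64, 64])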
def break_key(sample, n=None):
    """Break a multi-camera sample key, so different cameras have their own entries (context, camera)"""
    if sample is None:
        return sample
    new_sample = OrderedDict()
    for ctx in sample.keys():
        if is_dict(sample[ctx]):
            for key2, val in sample[ctx].items():
                if val.dim() == 1:
                    val = val.unsqueeze(1)
                for i in range(val.shape[1]):
                    if (ctx, i) not in new_sample.keys():
                        new_sample[(ctx, i)] = {}
                    new_sample[(ctx, i)][key2] = val[:, [i]]
        elif sample[ctx].dim() == n + 1:
            for i in range(sample[ctx].shape[1]):
                new_sample[(ctx, i)] = sample[ctx][:, i]
    return new_sample
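# Example (hypothetical two-camera pose sample; shapes are assumptions):
#   sample = {0: torch.zeros(2, 2, 4, 4)}   # (batch, camera, 4, 4) for context 0
#   broken = break_key(sample, n=3)
#   broken[(0, 0)].shape  ->  torch.Size([2, 4, 4])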
def break_batch(batch):
    """Break a multi-camera batch, so different cameras have their own entries (context, camera)"""
    for key in keys_in(batch, KEYS_IMAGE):
        for ctx in list(batch[key].keys()):
            if batch[key][ctx].dim() == 5:
                for n in range(batch[key][ctx].shape[1]):
                    batch[key][(ctx, n)] = batch[key][ctx][:, n]
                batch[key].pop(ctx)
    for key in keys_in(batch, KEYS_MATRIX):
        for ctx in list(batch[key].keys()):
            if batch[key][ctx].dim() == 4:
                for n in range(batch[key][ctx].shape[1]):
                    batch[key][(ctx, n)] = batch[key][ctx][:, n]
                batch[key].pop(ctx)
    return batch
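# Example (hypothetical two-camera batch; shapes are assumptions):
#   batch = {'rgb': {0: torch.zeros(2, 2, 3, 64, 64)}}
#   broken = break_batch(batch)
#   broken['rgb'][(0, 0)].shape  ->  torch.Size([2, 3, 64, 64])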
def dict_has(dic, key):
    """Check if a dictionary has a certain key"""
    return key in dic

def get_from_dict(dic, key):
    """Get value from a dictionary (return None if key is not in dictionary)"""
    return None if key not in dic else dic[key]
def get_mask_from_list(mask, i, return_ones=None):
    """Retrieve mask from a list (if it's not a list, return the mask itself, and create one if requested)"""
    if return_ones is None:
        return None if mask is None else mask[i] if is_list(mask) else mask
    if mask is None:
        mask = torch.ones_like(return_ones[i] if is_list(return_ones) else return_ones).bool()
    else:
        mask = mask[i].clone().bool() if is_list(mask) else mask.clone().bool()
    if mask.dim() == 4:
        return mask[:, [0]]
    elif mask.dim() == 3:
        return mask[..., [0]]
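# Example (illustrative; shapes are assumptions):
#   masks = [torch.ones(1, 1, 64, 64), torch.ones(1, 1, 32, 32)]
#   get_mask_from_list(masks, 1).shape  ->  torch.Size([1, 1, 32, 32])
#   get_mask_from_list(None, 0, return_ones=torch.ones(1, 3, 64, 64)).shape
#       ->  torch.Size([1, 1, 64, 64])  (all-ones boolean mask built from the reference tensor)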
def get_from_list(lst, i):
    """Get information from a list (return None if input is None, and return value directly if it's not a list)"""
    return None if lst is None else lst[i] if is_seq(lst) else lst