|
|
|
import functools
|
|
|
|
import torch
|
|
|
|
|
|
def assert_tensor_type(func):
    """Decorate a method so it only runs when ``self.data`` is a Tensor.

    The wrapped callable is expected to be invoked with the owning object as
    its first positional argument. That object must expose a ``data``
    attribute and, for the error message, a ``datatype`` attribute.

    Args:
        func: The method to guard.

    Returns:
        A wrapper with the metadata of ``func`` (via ``functools.wraps``)
        that forwards all arguments to ``func`` when the check passes.

    Raises:
        AttributeError: Raised by the wrapper at call time when the first
            positional argument's ``data`` is not a ``torch.Tensor``.
    """

    @functools.wraps(func)
    def wrapper(*args, **kwargs):
        owner = args[0]
        # Happy path first: tensor-backed owners go straight to the method.
        if isinstance(owner.data, torch.Tensor):
            return func(*args, **kwargs)

        # Mimic the error raised for a missing attribute so that callers
        # see the tensor-only method as absent on non-tensor data.
        owner_name = owner.__class__.__name__
        raise AttributeError(
            f'{owner_name} has no attribute {func.__name__} '
            f'for type {owner.datatype}')

    return wrapper
|
|
|
|
|
|
class DataContainer:
    """Wrapper that carries an arbitrary object plus collate/scatter hints.

    Tensors are typically stacked in the collate function and sliced along
    some dimension in the scatter function. That approach has limitations:

    1. All tensors have to be the same size.
    2. Types are limited (numpy array or Tensor).

    `DataContainer` and `MMDataParallel` are designed to overcome these
    limitations. The wrapped data can be handled in one of these ways:

    - copy to GPU, pad all tensors to the same size and stack them
    - copy to GPU without stacking
    - leave the objects as is and pass it to the model

    ``pad_dims`` specifies the number of last few dimensions to do padding.

    This class itself only stores the data and the flags and exposes them
    read-only; it does not pad, stack or move anything.

    Args:
        data: The object to wrap; any type is accepted.
        stack: Stored and exposed through :attr:`stack`. Defaults to False.
        padding_value: Stored and exposed through :attr:`padding_value`.
            Defaults to 0.
        cpu_only: Stored and exposed through :attr:`cpu_only`.
            Defaults to False.
        pad_dims: Must be one of ``None``, 1, 2 or 3. Defaults to 2.
    """

    def __init__(self,
                 data,
                 stack=False,
                 padding_value=0,
                 cpu_only=False,
                 pad_dims=2):
        self._data = data
        self._stack = stack
        self._cpu_only = cpu_only
        self._padding_value = padding_value

        # Only None or 1-3 trailing dimensions are accepted.
        assert pad_dims in (None, 1, 2, 3)
        self._pad_dims = pad_dims

    def __repr__(self):
        """Return ``ClassName(<repr of the wrapped data>)``."""
        cls_name = self.__class__.__name__
        return f'{cls_name}({self.data!r})'

    def __len__(self):
        """Return ``len()`` of the wrapped data."""
        wrapped = self._data
        return len(wrapped)

    @property
    def data(self):
        """The wrapped object, exactly as passed to the constructor."""
        return self._data

    @property
    def datatype(self):
        """Type of the wrapped data.

        For a ``torch.Tensor`` this is the result of ``Tensor.type()``;
        for anything else it is the Python ``type`` object of the data.
        """
        wrapped = self.data
        if not isinstance(wrapped, torch.Tensor):
            return type(wrapped)
        return wrapped.type()

    @property
    def cpu_only(self):
        """The ``cpu_only`` flag given at construction."""
        return self._cpu_only

    @property
    def stack(self):
        """The ``stack`` flag given at construction."""
        return self._stack

    @property
    def padding_value(self):
        """The ``padding_value`` given at construction."""
        return self._padding_value

    @property
    def pad_dims(self):
        """The ``pad_dims`` setting given at construction."""
        return self._pad_dims

    @assert_tensor_type
    def size(self, *args, **kwargs):
        """Forward to ``size()`` of the wrapped tensor.

        Raises:
            AttributeError: If the wrapped data is not a ``torch.Tensor``.
        """
        tensor = self.data
        return tensor.size(*args, **kwargs)

    @assert_tensor_type
    def dim(self):
        """Forward to ``dim()`` of the wrapped tensor.

        Raises:
            AttributeError: If the wrapped data is not a ``torch.Tensor``.
        """
        tensor = self.data
        return tensor.dim()
|
|
|