Spaces:
Paused
Paused
# Copyright (c) OpenMMLab. All rights reserved.
from functools import partial | |
import torch | |
# Version string reported by the imported ``torch`` module. The helpers in this
# module compare it against 'parrots' to choose between parrots and PyTorch
# imports.
TORCH_VERSION = torch.__version__
def is_rocm_pytorch() -> bool:
    """Report whether the running PyTorch is a ROCm (HIP) build.

    Returns:
        bool: ``True`` only when ``TORCH_VERSION`` is not ``'parrots'``,
        ``ROCM_HOME`` can be imported from ``torch.utils.cpp_extension``
        and both ``torch.version.hip`` and ``ROCM_HOME`` are not ``None``.
        ``False`` in every other case.
    """
    if TORCH_VERSION == 'parrots':
        return False
    try:
        from torch.utils.cpp_extension import ROCM_HOME
    except ImportError:
        # This torch build does not expose ROCM_HOME.
        return False
    return torch.version.hip is not None and ROCM_HOME is not None
def _get_cuda_home():
    """Return the toolkit home value for the active backend.

    Returns:
        ``CUDA_HOME`` from ``parrots.utils.build_extension`` when
        ``TORCH_VERSION`` is ``'parrots'``; ``ROCM_HOME`` from
        ``torch.utils.cpp_extension`` when :func:`is_rocm_pytorch` is true;
        otherwise ``CUDA_HOME`` from ``torch.utils.cpp_extension``.
    """
    if TORCH_VERSION == 'parrots':
        from parrots.utils.build_extension import CUDA_HOME as home
        return home
    if is_rocm_pytorch():
        # On ROCm builds the ROCm location is returned in place of CUDA's.
        from torch.utils.cpp_extension import ROCM_HOME as home
        return home
    from torch.utils.cpp_extension import CUDA_HOME as home
    return home
def get_build_config():
    """Return the build configuration reported by the active backend.

    Returns:
        The result of ``torch.__config__.show()`` for PyTorch, or of
        ``parrots.config.get_build_info()`` when ``TORCH_VERSION`` is
        ``'parrots'``.
    """
    if TORCH_VERSION != 'parrots':
        return torch.__config__.show()
    from parrots.config import get_build_info
    return get_build_info()
def _get_conv():
    """Import the convolution base classes from the active backend.

    Returns:
        tuple: ``(_ConvNd, _ConvTransposeMixin)`` taken from
        ``parrots.nn.modules.conv`` when ``TORCH_VERSION`` is ``'parrots'``,
        otherwise from ``torch.nn.modules.conv``.
    """
    if TORCH_VERSION == 'parrots':
        from parrots.nn.modules.conv import _ConvNd as conv_nd
        from parrots.nn.modules.conv import _ConvTransposeMixin as transpose_mixin
        return conv_nd, transpose_mixin
    from torch.nn.modules.conv import _ConvNd as conv_nd
    from torch.nn.modules.conv import _ConvTransposeMixin as transpose_mixin
    return conv_nd, transpose_mixin
def _get_dataloader():
    """Import the data loader classes from ``torch.utils.data``.

    Returns:
        tuple: ``(DataLoader, PoolDataLoader)``. ``PoolDataLoader`` is
        imported from ``torch.utils.data`` when ``TORCH_VERSION`` is
        ``'parrots'``; otherwise it is the same object as ``DataLoader``.
    """
    # Both branches need DataLoader from the same module.
    from torch.utils.data import DataLoader as loader_cls
    if TORCH_VERSION == 'parrots':
        from torch.utils.data import PoolDataLoader as pool_loader_cls
    else:
        pool_loader_cls = loader_cls
    return loader_cls, pool_loader_cls
def _get_extension():
    """Import the extension-building helpers from the active backend.

    Returns:
        tuple: ``(BuildExtension, CppExtension, CUDAExtension)``. For PyTorch
        all three come from ``torch.utils.cpp_extension``. When
        ``TORCH_VERSION`` is ``'parrots'``, ``BuildExtension`` comes from
        ``parrots.utils.build_extension`` and the other two are
        ``functools.partial`` wrappers around its ``Extension`` with
        ``cuda=False`` and ``cuda=True`` respectively.
    """
    if TORCH_VERSION != 'parrots':
        from torch.utils.cpp_extension import (BuildExtension, CppExtension,
                                               CUDAExtension)
        return BuildExtension, CppExtension, CUDAExtension
    from parrots.utils.build_extension import BuildExtension, Extension
    cpp_extension = partial(Extension, cuda=False)
    cuda_extension = partial(Extension, cuda=True)
    return BuildExtension, cpp_extension, cuda_extension
def _get_pool():
    """Import the pooling base classes from the active backend.

    Returns:
        tuple: ``(_AdaptiveAvgPoolNd, _AdaptiveMaxPoolNd, _AvgPoolNd,
        _MaxPoolNd)`` taken from ``parrots.nn.modules.pool`` when
        ``TORCH_VERSION`` is ``'parrots'``, otherwise from
        ``torch.nn.modules.pooling``.
    """
    if TORCH_VERSION != 'parrots':
        from torch.nn.modules.pooling import (_AdaptiveAvgPoolNd,
                                              _AdaptiveMaxPoolNd, _AvgPoolNd,
                                              _MaxPoolNd)
        return _AdaptiveAvgPoolNd, _AdaptiveMaxPoolNd, _AvgPoolNd, _MaxPoolNd
    # Note the differing module name: parrots uses ``pool``, not ``pooling``.
    from parrots.nn.modules.pool import (_AdaptiveAvgPoolNd,
                                         _AdaptiveMaxPoolNd, _AvgPoolNd,
                                         _MaxPoolNd)
    return _AdaptiveAvgPoolNd, _AdaptiveMaxPoolNd, _AvgPoolNd, _MaxPoolNd
def _get_norm():
    """Import the normalization base classes from the active backend.

    Returns:
        tuple: ``(_BatchNorm, _InstanceNorm, SyncBatchNorm_)``. When
        ``TORCH_VERSION`` is ``'parrots'`` the first two come from
        ``parrots.nn.modules.batchnorm`` and the third is
        ``torch.nn.SyncBatchNorm2d``. Otherwise they come from
        ``torch.nn.modules.batchnorm`` and ``torch.nn.modules.instancenorm``
        and the third is ``torch.nn.SyncBatchNorm``.
    """
    if TORCH_VERSION == 'parrots':
        from parrots.nn.modules.batchnorm import _BatchNorm, _InstanceNorm
        return _BatchNorm, _InstanceNorm, torch.nn.SyncBatchNorm2d
    from torch.nn.modules.instancenorm import _InstanceNorm
    from torch.nn.modules.batchnorm import _BatchNorm
    sync_bn_cls = torch.nn.SyncBatchNorm
    return _BatchNorm, _InstanceNorm, sync_bn_cls
# Resolve the backend-specific classes once, at import time, and bind them to
# module-level names.
_ConvNd, _ConvTransposeMixin = _get_conv()
DataLoader, PoolDataLoader = _get_dataloader()
BuildExtension, CppExtension, CUDAExtension = _get_extension()
_BatchNorm, _InstanceNorm, SyncBatchNorm_ = _get_norm()
_AdaptiveAvgPoolNd, _AdaptiveMaxPoolNd, _AvgPoolNd, _MaxPoolNd = _get_pool()
class SyncBatchNorm(SyncBatchNorm_):
    """Subclass of ``SyncBatchNorm_`` that overrides the input-dimension check
    when ``TORCH_VERSION`` is ``'parrots'``."""

    def _check_input_dim(self, input):
        """Validate the dimensionality of ``input``.

        For PyTorch the check is delegated to the parent class. For parrots
        the input must have at least two dimensions.

        Args:
            input: Object exposing a ``dim()`` method.

        Raises:
            ValueError: Under parrots, if ``input.dim()`` is less than 2.
        """
        if TORCH_VERSION != 'parrots':
            super()._check_input_dim(input)
            return
        if input.dim() < 2:
            raise ValueError(
                f'expected at least 2D input (got {input.dim()}D input)')