Spaces:
Paused
Paused
File size: 2,021 Bytes
508927a |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 |
# Copyright (c) OpenMMLab. All rights reserved.
import importlib
import os
import pkgutil
import warnings
from collections import namedtuple
import torch
if torch.__version__ != 'parrots':

    def load_ext(name, funcs):
        """Import ``mmcv.<name>`` and check that it exposes every op in ``funcs``.

        Args:
            name (str): Name of the ``mmcv`` submodule holding the compiled
                ops, e.g. ``'_ext'``.
            funcs (list[str]): Names of the ops that must be attributes of
                the imported module.

        Returns:
            module: The imported ``mmcv.<name>`` module.

        Raises:
            ImportError: If ``mmcv.<name>`` cannot be imported.
            AssertionError: On the first name in ``funcs`` that is missing
                from the module.
        """
        ext = importlib.import_module('mmcv.' + name)
        for fun in funcs:
            if not hasattr(ext, fun):
                # Raised explicitly instead of via ``assert`` so the check
                # still runs under ``python -O``; the exception type is kept
                # as AssertionError for backward compatibility.
                raise AssertionError(f'{fun} miss in module {name}')
        return ext
else:
    from parrots import extension
    from parrots.base import ParrotsException

    # Ops whose loaded parrots extension object is used through its ``.op``
    # attribute; every other op is used through ``.op_`` (see ``load_ext``).
    # NOTE(review): the name suggests these ops return a value while the
    # ``.op_`` variants do not -- confirm against the parrots extension API.
    has_return_value_ops = [
        'nms',
        'softnms',
        'nms_match',
        'nms_rotated',
        'top_pool_forward',
        'top_pool_backward',
        'bottom_pool_forward',
        'bottom_pool_backward',
        'left_pool_forward',
        'left_pool_backward',
        'right_pool_forward',
        'right_pool_backward',
        'fused_bias_leakyrelu',
        'upfirdn2d',
        'ms_deform_attn_forward',
        'pixel_group',
        'contour_expand',
    ]

    def get_fake_func(name, e):
        """Build a placeholder for an op that failed to load under parrots.

        Args:
            name (str): Name of the op that could not be loaded.
            e (ParrotsException): The exception raised while loading it.

        Returns:
            callable: A function that accepts any arguments, warns that
            ``name`` is not supported, and then re-raises ``e``.
        """

        def fake_func(*args, **kwargs):
            warnings.warn(f'{name} is not supported in parrots now')
            raise e

        return fake_func

    def load_ext(name, funcs):
        """Load each op in ``funcs`` from the parrots extension ``name``.

        An op that fails to load is replaced by a placeholder from
        ``get_fake_func``, so the failure surfaces only when that op is
        actually called.

        Args:
            name (str): Name of the extension to load the ops from.
            funcs (list[str]): Names of the ops to load. They also become the
                field names of the returned namedtuple.

        Returns:
            ExtModule: A namedtuple with one field per entry of ``funcs``
            (in order), holding the loaded op or its placeholder.
        """
        ExtModule = namedtuple('ExtModule', funcs)
        ext_list = []
        # Directory two levels above this file's resolved path; passed to
        # parrots as the directory to load the extension library from.
        lib_root = os.path.dirname(os.path.dirname(os.path.realpath(__file__)))
        for fun in funcs:
            try:
                ext_fun = extension.load(fun, name, lib_dir=lib_root)
            except ParrotsException as e:
                # Failures whose message contains 'No element registered'
                # are not warned about here (presumably the op is simply not
                # registered in this build); any other failure is warned
                # about immediately.
                if 'No element registered' not in e.message:
                    warnings.warn(e.message)
                ext_fun = get_fake_func(fun, e)
                ext_list.append(ext_fun)
            else:
                if fun in has_return_value_ops:
                    ext_list.append(ext_fun.op)
                else:
                    ext_list.append(ext_fun.op_)
        return ExtModule(*ext_list)
def check_ops_exist(module_name='mmcv._ext'):
    """Return whether the compiled op extension module can be located.

    The module is only looked up, not imported, although its parent package
    is imported as part of the lookup.

    Args:
        module_name (str): Dotted name of the extension module to look for.
            Defaults to ``'mmcv._ext'``.

    Returns:
        bool: True if a loader for ``module_name`` is found, otherwise False.

    Raises:
        ImportError: If the lookup itself fails, e.g. because the parent
            package cannot be imported.
    """
    # ``pkgutil.find_loader`` is deprecated since Python 3.12 and removed in
    # 3.14; ``importlib.util.find_spec`` is its documented replacement.
    import importlib.util

    try:
        spec = importlib.util.find_spec(module_name)
    except (ImportError, AttributeError, TypeError, ValueError) as err:
        # Mirror pkgutil.find_loader, which wrapped lookup errors in
        # ImportError, so callers see the same exception type as before.
        raise ImportError(f'Error while finding loader for {module_name!r} '
                          f'({type(err)}: {err})') from err
    return spec is not None and spec.loader is not None
|