|
""" Conv2D w/ SAME padding, CondConv, MixedConv |
|
|
|
A collection of conv layers and padding helpers needed by EfficientNet, MixNet, and
MobileNetV3 models that maintain weight compatibility with the original TensorFlow models.
|
|
|
Copyright 2020 Ross Wightman |
|
""" |
|
import collections.abc |
|
import math |
|
from functools import partial |
|
from itertools import repeat |
|
from typing import Tuple, Optional |
|
|
|
import numpy as np |
|
import torch |
|
import torch.nn as nn |
|
import torch.nn.functional as F |
|
|
|
from .config import * |
|
|
|
|
|
|
|
def _ntuple(n):
    """Build a parser that normalizes an argument into a tuple.

    The returned ``parse(x)`` callable behaves as follows:

    * a non-string iterable is converted to a ``tuple`` of its elements
      (its length is NOT checked against ``n``);
    * any other value, including a string, is repeated ``n`` times.

    Always returning a real tuple lets callers concatenate the result with
    other tuples (e.g. when assembling a weight shape) even when the caller
    supplied a list.
    """
    def parse(x):
        # Strings are iterable, but they are scalar settings here, so they
        # must be repeated rather than passed through or split into characters.
        if isinstance(x, collections.abc.Iterable) and not isinstance(x, str):
            return tuple(x)
        return tuple(repeat(x, n))
    return parse


_single = _ntuple(1)
_pair = _ntuple(2)
_triple = _ntuple(3)
_quadruple = _ntuple(4)
|
|
|
|
|
def _is_static_pad(kernel_size, stride=1, dilation=1, **_):
    """Tell whether 'same' padding can be expressed as a fixed symmetric pad.

    That is the case only for a stride of 1 when ``dilation * (kernel_size - 1)``
    is even, i.e. the total padding splits evenly between both sides.
    Extra keyword arguments are accepted and ignored.
    """
    if stride != 1:
        return False
    total_pad = dilation * (kernel_size - 1)
    return total_pad % 2 == 0
|
|
|
|
|
def _get_padding(kernel_size, stride=1, dilation=1, **_):
    """Return the symmetric padding for the given kernel, stride and dilation.

    Computed as ``((stride - 1) + dilation * (kernel_size - 1)) // 2``.
    Extra keyword arguments are accepted and ignored.
    """
    dilated_extent = dilation * (kernel_size - 1)
    return ((stride - 1) + dilated_extent) // 2
|
|
|
|
|
def _calc_same_pad(i: int, k: int, s: int, d: int):
    """Total padding along one dimension for a 'SAME' style convolution.

    Args:
        i: input size along the dimension.
        k: kernel size along the dimension.
        s: stride along the dimension.
        d: dilation along the dimension.

    Returns:
        The non-negative amount of padding needed so that the convolution
        yields ``ceil(i / s)`` outputs.
    """
    out_size = -(i // -s)  # ceil(i / s) using integer arithmetic only
    required = (out_size - 1) * s + (k - 1) * d + 1 - i
    return max(required, 0)
|
|
|
|
|
def _same_pad_arg(input_size, kernel_size, stride, dilation):
    """Build the four-element pad list for 'SAME' padding of a 2D input.

    Args:
        input_size: (height, width) of the input.
        kernel_size: (height, width) of the kernel.
        stride: pair indexed as (height, width).
        dilation: pair indexed as (height, width).

    Returns:
        ``[w_before, w_after, h_before, h_after]``; when a total is odd the
        extra element goes to the "after" side.
    """
    in_h, in_w = input_size
    k_h, k_w = kernel_size
    total_h = _calc_same_pad(in_h, k_h, stride[0], dilation[0])
    total_w = _calc_same_pad(in_w, k_w, stride[1], dilation[1])
    w_before = total_w // 2
    h_before = total_h // 2
    return [w_before, total_w - w_before, h_before, total_h - h_before]
|
|
|
|
|
def _split_channels(num_chan, num_groups):
    """Divide ``num_chan`` channels across ``num_groups`` groups.

    Every group receives ``num_chan // num_groups`` channels and the first
    group additionally absorbs whatever remains, so the result sums to
    ``num_chan``.
    """
    channels = [num_chan // num_groups for _ in range(num_groups)]
    remainder = num_chan - sum(channels)
    channels[0] = channels[0] + remainder
    return channels
|
|
|
|
|
def conv2d_same(
        x, weight: torch.Tensor, bias: Optional[torch.Tensor] = None, stride: Tuple[int, int] = (1, 1),
        padding: Tuple[int, int] = (0, 0), dilation: Tuple[int, int] = (1, 1), groups: int = 1):
    """Functional 2D convolution with input-dependent 'SAME' padding.

    The padding is derived from the last two dimensions of ``x`` and
    ``weight`` on every call, applied with ``F.pad``, and the convolution is
    then run with zero built-in padding.

    NOTE: the ``padding`` argument is accepted for signature parity with
    ``F.conv2d`` but is not used.
    """
    in_h, in_w = x.size()[-2:]
    k_h, k_w = weight.size()[-2:]
    total_h = _calc_same_pad(in_h, k_h, stride[0], dilation[0])
    total_w = _calc_same_pad(in_w, k_w, stride[1], dilation[1])
    # When a total is odd the extra element goes to the trailing side.
    w_before = total_w // 2
    h_before = total_h // 2
    x = F.pad(x, [w_before, total_w - w_before, h_before, total_h - h_before])
    return F.conv2d(x, weight, bias, stride, (0, 0), dilation, groups)
|
|
|
|
|
class Conv2dSame(nn.Conv2d):
    """ TensorFlow-like 'SAME' convolution wrapper for 2D convolutions.

    The base ``nn.Conv2d`` is always built with a padding of 0; the actual
    padding is computed from the input size on every forward pass.
    """

    def __init__(self, in_channels, out_channels, kernel_size, stride=1,
                 padding=0, dilation=1, groups=1, bias=True):
        # `padding` is accepted for signature compatibility only and is ignored.
        super().__init__(
            in_channels, out_channels, kernel_size,
            stride=stride, padding=0, dilation=dilation, groups=groups, bias=bias)

    def forward(self, x):
        return conv2d_same(
            x, self.weight, self.bias,
            self.stride, self.padding, self.dilation, self.groups)
|
|
|
|
|
class Conv2dSameExport(nn.Conv2d):
    """ ONNX export friendly TensorFlow-like 'SAME' convolution wrapper for 2D convolutions.

    The padding is materialized as an ``nn.ZeroPad2d`` module that is built
    lazily from the spatial size of the input and cached. The cache is
    rebuilt whenever the spatial size of the input differs from the size the
    cached padding was computed for.

    NOTE: This does not currently work with torch.jit.script
    """

    def __init__(self, in_channels, out_channels, kernel_size, stride=1,
                 padding=0, dilation=1, groups=1, bias=True):
        # `padding` is accepted for signature compatibility only; the base
        # convolution is always created with a padding of 0.
        super(Conv2dSameExport, self).__init__(
            in_channels, out_channels, kernel_size, stride, 0, dilation, groups, bias)
        self.pad = None  # cached nn.ZeroPad2d, created on first forward
        self.pad_input_size = (0, 0)  # spatial size the cached pad was built for

    def forward(self, x):
        input_size = x.size()[-2:]
        # Previously the pad computed on the first call was reused for every
        # later input regardless of its size; rebuild it when the size changes.
        if self.pad is None or self.pad_input_size != input_size:
            pad_arg = _same_pad_arg(input_size, self.weight.size()[-2:], self.stride, self.dilation)
            self.pad = nn.ZeroPad2d(pad_arg)
            self.pad_input_size = input_size

        x = self.pad(x)
        return F.conv2d(
            x, self.weight, self.bias, self.stride, self.padding, self.dilation, self.groups)
|
|
|
|
|
def get_padding_value(padding, kernel_size, **kwargs):
    """Resolve a padding specification into a value and a 'dynamic' flag.

    Args:
        padding: either a non-string value, which is returned untouched, or a
            string (compared case-insensitively):
            * ``'same'``: a static symmetric padding when one exists for the
              given kernel/stride/dilation, otherwise 0 with the dynamic flag set;
            * ``'valid'``: no padding;
            * any other string (including ``''``): the default symmetric padding.
        kernel_size: kernel size forwarded to the padding helpers.
        **kwargs: forwarded to the padding helpers (``stride``, ``dilation``;
            other keys are ignored there).

    Returns:
        ``(padding, dynamic)`` where ``dynamic`` is True only when 'same'
        padding has to be computed from the input size at run time.
    """
    if not isinstance(padding, str):
        return padding, False

    mode = padding.lower()
    if mode == 'valid':
        return 0, False
    if mode == 'same' and not _is_static_pad(kernel_size, **kwargs):
        # No fixed symmetric pad reproduces 'same' here; pad at run time instead.
        return 0, True
    return _get_padding(kernel_size, **kwargs), False
|
|
|
|
|
def create_conv2d_pad(in_chs, out_chs, kernel_size, **kwargs):
    """Create a 2D convolution whose padding mode is chosen from ``padding``.

    ``padding`` is removed from ``kwargs`` (default ``''``) and resolved with
    ``get_padding_value``; ``bias`` defaults to False when not supplied.
    A statically resolvable padding yields a plain ``nn.Conv2d``. Otherwise a
    dynamically padded layer is returned: ``Conv2dSameExport`` when
    ``is_exportable()`` is true (scripting must then be off), else ``Conv2dSame``.
    """
    requested_padding = kwargs.pop('padding', '')
    kwargs.setdefault('bias', False)
    static_padding, needs_dynamic = get_padding_value(requested_padding, kernel_size, **kwargs)

    if not needs_dynamic:
        return nn.Conv2d(in_chs, out_chs, kernel_size, padding=static_padding, **kwargs)

    if is_exportable():
        assert not is_scriptable()
        conv_cls = Conv2dSameExport
    else:
        conv_cls = Conv2dSame
    return conv_cls(in_chs, out_chs, kernel_size, **kwargs)
|
|
|
|
|
class MixedConv2d(nn.ModuleDict):
    """ Mixed Grouped Convolution

    Based on MDConv and GroupedConv in MixNet impl:
      https://github.com/tensorflow/tpu/blob/master/models/official/mnasnet/mixnet/custom_layers.py

    One convolution is created per entry of ``kernel_size``; input and output
    channels are divided between them with ``_split_channels``. The forward
    pass splits the input along dim 1, runs each chunk through its own
    convolution and concatenates the results along dim 1.
    """

    def __init__(self, in_channels, out_channels, kernel_size=3,
                 stride=1, padding='', dilation=1, depthwise=False, **kwargs):
        super().__init__()

        # Only a list is treated as "several kernels"; anything else is one kernel.
        if not isinstance(kernel_size, list):
            kernel_size = [kernel_size]
        group_count = len(kernel_size)
        in_splits = _split_channels(in_channels, group_count)
        out_splits = _split_channels(out_channels, group_count)
        self.in_channels = sum(in_splits)
        self.out_channels = sum(out_splits)
        self.splits = in_splits

        per_group = zip(kernel_size, in_splits, out_splits)
        for idx, (kernel, group_in, group_out) in enumerate(per_group):
            # Depthwise mode uses one conv group per output channel of the split.
            conv_groups = group_out if depthwise else 1
            branch = create_conv2d_pad(
                group_in, group_out, kernel, stride=stride,
                padding=padding, dilation=dilation, groups=conv_groups, **kwargs)
            self.add_module(str(idx), branch)

    def forward(self, x):
        chunks = torch.split(x, self.splits, 1)
        outputs = [branch(chunks[idx]) for idx, branch in enumerate(self.values())]
        return torch.cat(outputs, 1)
|
|
|
|
|
def get_condconv_initializer(initializer, num_experts, expert_shape):
    """Wrap ``initializer`` so it can initialize a flattened bank of experts.

    Args:
        initializer: callable applied to each expert, viewed as ``expert_shape``.
        num_experts: number of rows expected in the weight passed later.
        expert_shape: shape of a single expert; its product is the expected
            number of columns.

    Returns:
        A function taking a ``[num_experts, prod(expert_shape)]`` tensor and
        applying ``initializer`` to every row viewed as ``expert_shape``.
        It raises ``ValueError`` for any other shape.
    """
    def condconv_initializer(weight):
        """CondConv initializer function."""
        expected_shape = (num_experts, np.prod(expert_shape))
        if tuple(weight.shape) != expected_shape:
            raise ValueError(
                'CondConv variables must have shape [num_experts, num_params]')
        for expert_idx in range(num_experts):
            expert = weight[expert_idx].view(expert_shape)
            initializer(expert)
    return condconv_initializer
|
|
|
|
|
class CondConv2d(nn.Module):
    """ Conditional Convolution

    Inspired by: https://github.com/tensorflow/tpu/blob/master/models/official/efficientnet/condconv/condconv_layers.py

    Grouped convolution hackery for parallel execution of the per-sample kernel filters inspired by this discussion:
    https://github.com/pytorch/pytorch/issues/17983

    The layer stores ``num_experts`` flattened kernels (and optionally biases).
    In ``forward`` they are mixed per sample with ``routing_weights``, and all
    samples are convolved in one grouped convolution by folding the batch
    into the channel dimension.
    """
    __constants__ = ['bias', 'in_channels', 'out_channels', 'dynamic_padding']

    def __init__(self, in_channels, out_channels, kernel_size=3,
                 stride=1, padding='', dilation=1, groups=1, bias=False, num_experts=4):
        super().__init__()

        self.in_channels = in_channels
        self.out_channels = out_channels
        self.kernel_size = _pair(kernel_size)
        self.stride = _pair(stride)
        static_pad, needs_dynamic_pad = get_padding_value(
            padding, kernel_size, stride=stride, dilation=dilation)
        self.dynamic_padding = needs_dynamic_pad
        self.padding = _pair(static_pad)
        self.dilation = _pair(dilation)
        self.groups = groups
        self.num_experts = num_experts

        # Each expert is stored flattened: one row per expert.
        self.weight_shape = (self.out_channels, self.in_channels // self.groups) + self.kernel_size
        flat_weight_len = 1
        for dim in self.weight_shape:
            flat_weight_len *= dim
        self.weight = torch.nn.Parameter(torch.Tensor(self.num_experts, flat_weight_len))

        if bias:
            self.bias_shape = (self.out_channels,)
            self.bias = torch.nn.Parameter(torch.Tensor(self.num_experts, self.out_channels))
        else:
            self.register_parameter('bias', None)

        self.reset_parameters()

    def reset_parameters(self):
        """Initialize every expert's weight (and bias, when present) independently."""
        weight_init = get_condconv_initializer(
            partial(nn.init.kaiming_uniform_, a=math.sqrt(5)), self.num_experts, self.weight_shape)
        weight_init(self.weight)
        if self.bias is None:
            return
        fan_in = np.prod(self.weight_shape[1:])
        bound = 1 / math.sqrt(fan_in)
        bias_init = get_condconv_initializer(
            partial(nn.init.uniform_, a=-bound, b=bound), self.num_experts, self.bias_shape)
        bias_init(self.bias)

    def forward(self, x, routing_weights):
        """Convolve ``x`` with per-sample kernels mixed from the experts.

        ``x`` is unpacked as ``(B, C, H, W)``. ``routing_weights`` is
        matrix-multiplied with the ``[num_experts, params]`` expert bank and
        the product is viewed as ``B`` stacked kernels, so it must provide one
        row of expert coefficients per sample.
        """
        B, C, H, W = x.shape
        # Mix the experts per sample, then stack the B kernels along the
        # output-channel dimension.
        stacked_shape = (B * self.out_channels, self.in_channels // self.groups) + self.kernel_size
        weight = torch.matmul(routing_weights, self.weight).view(stacked_shape)
        bias = None
        if self.bias is not None:
            bias = torch.matmul(routing_weights, self.bias).view(B * self.out_channels)

        # Fold the batch into channels so that `groups * B` convolution groups
        # apply each sample's own kernel.
        x = x.view(1, B * C, H, W)
        if self.dynamic_padding:
            out = conv2d_same(
                x, weight, bias, stride=self.stride, padding=self.padding,
                dilation=self.dilation, groups=self.groups * B)
        else:
            out = F.conv2d(
                x, weight, bias, stride=self.stride, padding=self.padding,
                dilation=self.dilation, groups=self.groups * B)
        # Unfold the batch from the channel dimension again.
        return out.permute([1, 0, 2, 3]).view(B, self.out_channels, out.shape[-2], out.shape[-1])
|
|
|
|
|
def select_conv2d(in_chs, out_chs, kernel_size, **kwargs):
    """Select and build the convolution variant implied by the arguments.

    * ``kernel_size`` given as a list -> ``MixedConv2d`` (``num_experts`` must
      not be supplied in that case).
    * ``num_experts`` > 0 -> ``CondConv2d``.
    * otherwise -> the layer returned by ``create_conv2d_pad``.

    ``groups`` must not be passed; it is derived from the optional
    ``depthwise`` flag (``out_chs`` groups when set, else 1). Remaining
    keyword arguments are forwarded to the selected layer.
    """
    assert 'groups' not in kwargs
    if isinstance(kernel_size, list):
        assert 'num_experts' not in kwargs
        return MixedConv2d(in_chs, out_chs, kernel_size, **kwargs)

    depthwise = kwargs.pop('depthwise', False)
    groups = out_chs if depthwise else 1
    # Always remove `num_experts` from kwargs: a disabled value (0 / None) must
    # not be forwarded to the plain convolution, whose constructor does not
    # accept that keyword.
    num_experts = kwargs.pop('num_experts', 0) or 0
    if num_experts > 0:
        return CondConv2d(
            in_chs, out_chs, kernel_size, groups=groups, num_experts=num_experts, **kwargs)
    return create_conv2d_pad(in_chs, out_chs, kernel_size, groups=groups, **kwargs)
|
|