Spaces:
Build error
Build error
import argparse | |
import logging | |
import torch | |
import torch.nn as nn | |
import timeit | |
from maskrcnn_benchmark.layers import * | |
from maskrcnn_benchmark.modeling.backbone.resnet_big import StdConv2d | |
from maskrcnn_benchmark.modeling.backbone.fpn import * | |
from maskrcnn_benchmark.modeling.rpn.inference import * | |
from maskrcnn_benchmark.modeling.roi_heads.box_head.inference import PostProcessor | |
from maskrcnn_benchmark.modeling.rpn.anchor_generator import BufferList | |
def profile(model, input_size, custom_ops=None, device="cpu", verbose=False,
            extra_args=None, return_time=False):
    """Estimate the op count and parameter count of ``model``.

    Every leaf module (one without children) gets two scratch buffers,
    ``total_ops`` and ``total_params``, plus a forward hook looked up by the
    module's exact type in ``custom_ops`` first and then in the module-level
    ``register_hooks`` table.  One forward pass on an all-zeros input lets the
    hooks record their op counts in ``total_ops``; the per-leaf values are then
    summed.  Leaf types found in neither table are printed and contribute 0
    ops.  Parameters owned directly by non-leaf modules are not counted.

    The model is temporarily put in eval mode on ``device``.  Whether or not
    the forward pass raises, the hooks and scratch buffers are removed and the
    original training mode and device are restored afterwards.

    Args:
        model: ``nn.Module`` to profile.
        input_size: shape of the zero tensor passed as the first argument.
        custom_ops: optional ``{module_type: hook}`` overrides for
            ``register_hooks``; hooks take ``(module, inputs, output)``.
        device: device on which the forward pass runs.
        verbose: print every module a counter is registered for.
        extra_args: optional keyword arguments forwarded to ``model``.
        return_time: also return the wall-clock time of the forward pass.

    Returns:
        ``(total_ops, total_params)`` as floats, or
        ``(total_ops, total_params, total_time)`` when ``return_time`` is true.
    """
    # None instead of a shared mutable ``{}`` default.
    custom_ops = {} if custom_ops is None else custom_ops
    extra_args = {} if extra_args is None else extra_args
    handler_collection = []
    profiled_modules = []
    seen_ids = set()

    def add_hooks(m):
        # Only leaf modules are instrumented; containers are skipped.
        if len(list(m.children())) > 0:
            return
        # Module.apply visits a module shared by several parents once per
        # parent; instrument it only once so its hook is not registered twice.
        if id(m) in seen_ids:
            return
        seen_ids.add(id(m))
        # float64 keeps large per-module counts exact.
        m.register_buffer('total_ops', torch.zeros(1, dtype=torch.float64))
        m.register_buffer('total_params', torch.zeros(1, dtype=torch.float64))
        profiled_modules.append(m)
        for p in m.parameters():
            m.total_params += p.numel()
        m_type = type(m)
        fn = None
        if m_type in custom_ops:
            fn = custom_ops[m_type]
        elif m_type in register_hooks:
            fn = register_hooks[m_type]
        else:
            print("Not implemented for ", m)
        if fn is not None:
            if verbose:
                print("Register FLOP counter for module %s" % str(m))
            handler = m.register_forward_hook(fn)
            handler_collection.append(handler)

    def synchronize():
        # CUDA kernels are launched asynchronously; wait for them so the
        # timer measures the actual forward pass.
        if torch.device(device).type == "cuda":
            torch.cuda.synchronize(device)

    # Remember where the model lives (first parameter, else first buffer) so
    # it can be moved back; a model with neither is left on ``device``.
    first_tensor = next(model.parameters(), None)
    if first_tensor is None:
        first_tensor = next(model.buffers(), None)
    original_device = first_tensor.device if first_tensor is not None else None
    training = model.training
    try:
        model.eval().to(device)
        model.apply(add_hooks)
        x = torch.zeros(input_size).to(device)
        with torch.no_grad():
            synchronize()
            tic = timeit.time.perf_counter()
            model(x, **extra_args)
            synchronize()
            toc = timeit.time.perf_counter()
        total_time = toc - tic
        # Sum as Python floats (float64) rather than float32 tensors.
        total_ops = sum((float(m.total_ops) for m in profiled_modules), 0.0)
        total_params = sum((float(m.total_params) for m in profiled_modules), 0.0)
    finally:
        for handler in handler_collection:
            handler.remove()
        # Drop the scratch buffers so they do not end up in state_dict().
        for m in profiled_modules:
            for name in ('total_ops', 'total_params'):
                if hasattr(m, name):
                    delattr(m, name)
        model.train(training)
        if original_device is not None:
            model.to(original_device)
    if return_time:
        return total_ops, total_params, total_time
    return total_ops, total_params
# Cost assigned to one multiply-accumulate by the convolution counters
# (kernel_ops = multiply_adds * kh * kw * channels // groups): 1 counts each
# multiply-add pair as a single op.
multiply_adds: int = 1
def count_conv2d(m, x, y):
    """Forward hook: add the multiply-accumulate cost of a 2-D convolution.

    Each output element costs ``multiply_adds * kh * kw * in_channels //
    groups`` ops, plus one more when the layer has a bias.  The total is added
    to ``m.total_ops``.

    Args:
        m: the convolution module (reads ``in_channels``, ``out_channels``,
            ``kernel_size``, ``groups`` and ``bias``).
        x: tuple of positional forward inputs; only ``x[0]`` is used, for the
            batch size.
        y: output tensor; dims 2 and 3 are read as the output height/width
            (assumes an (N, C, H, W) layout — TODO confirm for custom convs).
    """
    x = x[0]
    kh, kw = m.kernel_size
    batch_size = x.size(0)
    out_h, out_w = y.size(2), y.size(3)
    # kernel_mul = kh * kw * cin, kernel_add = kh * kw * cin - 1; one
    # multiply-add pair is weighted by ``multiply_adds``.
    kernel_ops = multiply_adds * kh * kw * m.in_channels // m.groups
    bias_ops = 1 if m.bias is not None else 0
    output_elements = batch_size * out_w * out_h * m.out_channels
    total_ops = output_elements * (kernel_ops + bias_ops)
    # Accumulate rather than overwrite: a module called more than once in a
    # single forward pass fires this hook once per call.
    m.total_ops = getattr(m, "total_ops", 0) + torch.Tensor([int(total_ops)])
def count_convtranspose2d(m, x, y):
    """Forward hook: add the multiply-accumulate cost of a 2-D transposed conv.

    A transposed convolution scatters every input element through
    ``out_channels // groups * kh * kw`` weights, so the multiply-adds scale
    with the *input* size (weighted by ``multiply_adds``, as in
    ``count_conv2d``).  A bias adds one op per output element.  The total is
    added to ``m.total_ops``.

    Args:
        m: the transposed-convolution module (reads ``out_channels``,
            ``kernel_size``, ``groups`` and ``bias``).
        x: tuple of positional forward inputs; only ``x[0]`` is used.
        y: output tensor.
    """
    x = x[0]
    kh, kw = m.kernel_size
    # Per input element: one multiply-add for each weight it is scattered
    # through.  (Multiplying y.nelement() by m.weight.nelement() would count
    # out_channels twice.)
    kernel_ops = multiply_adds * kh * kw * m.out_channels // m.groups
    bias_ops = y.numel() if m.bias is not None else 0
    total_ops = x.numel() * kernel_ops + bias_ops
    # Accumulate rather than overwrite: a module called more than once in a
    # single forward pass fires this hook once per call.
    m.total_ops = getattr(m, "total_ops", 0) + torch.Tensor([int(total_ops)])
def count_bn(m, x, y):
    """Forward hook: add 4 ops per input element of a normalization layer.

    The four ops are subtract, divide, gamma and beta.  The total is added to
    ``m.total_ops``.

    Args:
        m: the normalization module (only ``total_ops`` is touched).
        x: tuple of positional forward inputs; only ``x[0]`` is used.
        y: output tensor (unused).
    """
    nelements = x[0].numel()
    # subtract, divide, gamma, beta
    total_ops = 4 * nelements
    # Accumulate rather than overwrite: a module called more than once in a
    # single forward pass fires this hook once per call.
    m.total_ops = getattr(m, "total_ops", 0) + torch.Tensor([int(total_ops)])
def count_relu(m, x, y):
    """Forward hook: add one op per input element of an activation.

    The total is added to ``m.total_ops``.

    Args:
        m: the activation module (only ``total_ops`` is touched).
        x: tuple of positional forward inputs; only ``x[0]`` is used.
        y: output tensor (unused).
    """
    total_ops = x[0].numel()
    # Accumulate rather than overwrite: a module called more than once in a
    # single forward pass fires this hook once per call.
    m.total_ops = getattr(m, "total_ops", 0) + torch.Tensor([int(total_ops)])
def count_softmax(m, x, y):
    """Forward hook: add the cost of a softmax over one dimension.

    Each softmax row of ``n`` features costs ``n`` exps, ``n - 1`` adds and
    ``n`` divisions.  The dimension is ``m.dim``.  When ``m.dim`` is absent
    or None, the dimension is picked by the implicit rule ``nn.Softmax`` uses
    (dim 0 for 0-, 1- and 3-D inputs, else dim 1).  For 2-D inputs this gives
    the same count as reducing over dim 1.  The total is added to
    ``m.total_ops``.

    Args:
        m: the softmax module (reads ``dim`` if present).
        x: tuple of positional forward inputs; only ``x[0]`` is used.
        y: output tensor (unused).
    """
    x = x[0]
    dim = getattr(m, "dim", None)
    if dim is None:
        dim = 0 if x.dim() in (0, 1, 3) else 1
    nfeatures = x.size(dim) if x.dim() > 0 else 1
    num_rows = x.numel() // nfeatures if nfeatures else 0
    total_exp = nfeatures
    total_add = nfeatures - 1
    total_div = nfeatures
    total_ops = num_rows * (total_exp + total_add + total_div)
    # Accumulate rather than overwrite: a module called more than once in a
    # single forward pass fires this hook once per call.
    m.total_ops = getattr(m, "total_ops", 0) + torch.Tensor([int(total_ops)])
def count_maxpool(m, x, y):
    """Forward hook: add ``window_size`` comparison ops per output element.

    An int ``kernel_size`` is expanded to every spatial dim of the pooling
    layer: 3 for ``nn.MaxPool3d``, 2 for ``nn.MaxPool2d`` and 1 otherwise.
    Without that expansion ``nn.MaxPool2d(3)`` would count a 3-element window
    instead of 3x3.  The total is added to ``m.total_ops``.

    Args:
        m: the pooling module (reads ``kernel_size``).
        x: tuple of positional forward inputs (unused).
        y: output tensor.
    """
    kernel = m.kernel_size
    if isinstance(kernel, int):
        # Max-pool layers may keep an int kernel_size as given.
        if isinstance(m, nn.MaxPool3d):
            kernel = (kernel,) * 3
        elif isinstance(m, nn.MaxPool2d):
            kernel = (kernel,) * 2
        else:
            kernel = (kernel,)
    kernel_ops = 1
    for k in kernel:
        kernel_ops *= k
    total_ops = kernel_ops * y.numel()
    # Accumulate rather than overwrite: a module called more than once in a
    # single forward pass fires this hook once per call.
    m.total_ops = getattr(m, "total_ops", 0) + torch.Tensor([int(total_ops)])
def count_adap_maxpool(m, x, y):
    """Forward hook: add the cost of an adaptive max pool.

    The pooling window is approximated per spatial dim as
    ``input_size // output_size``, costing one comparison per window element
    for each output element.  An int ``output_size`` applies to every spatial
    dim.  A ``None`` entry (keep the input size) gives a window of 1 along
    that dim.  The total is added to ``m.total_ops``.

    Args:
        m: the pooling module (reads ``output_size``).
        x: tuple of positional forward inputs; ``x[0].shape[2:]`` is taken as
            the spatial size (assumes a batched (N, C, *spatial) input).
        y: output tensor.
    """
    in_sizes = x[0].shape[2:]
    out_sizes = m.output_size
    if not isinstance(out_sizes, (tuple, list)):
        out_sizes = (out_sizes,) * len(in_sizes)
    kernel_ops = 1
    for in_size, out_size in zip(in_sizes, out_sizes):
        kernel_ops *= (in_size // out_size) if out_size else 1
    total_ops = kernel_ops * y.numel()
    # Accumulate rather than overwrite: a module called more than once in a
    # single forward pass fires this hook once per call.
    m.total_ops = getattr(m, "total_ops", 0) + torch.Tensor([int(total_ops)])
def count_avgpool(m, x, y):
    """Forward hook: add ``window_size`` adds plus one division per output.

    An int ``kernel_size`` is expanded to every spatial dim of the pooling
    layer: 3 for ``nn.AvgPool3d``, 2 for ``nn.AvgPool2d`` and 1 otherwise.
    Without that expansion ``nn.AvgPool2d(3)`` would count a 3-element window
    instead of 3x3.  The total is added to ``m.total_ops``.

    Args:
        m: the pooling module (reads ``kernel_size``).
        x: tuple of positional forward inputs (unused).
        y: output tensor.
    """
    kernel = m.kernel_size
    if isinstance(kernel, int):
        # Avg-pool layers may keep an int kernel_size as given.
        if isinstance(m, nn.AvgPool3d):
            kernel = (kernel,) * 3
        elif isinstance(m, nn.AvgPool2d):
            kernel = (kernel,) * 2
        else:
            kernel = (kernel,)
    total_add = 1
    for k in kernel:
        total_add *= k
    total_div = 1
    kernel_ops = total_add + total_div
    total_ops = kernel_ops * y.numel()
    # Accumulate rather than overwrite: a module called more than once in a
    # single forward pass fires this hook once per call.
    m.total_ops = getattr(m, "total_ops", 0) + torch.Tensor([int(total_ops)])
def count_adap_avgpool(m, x, y):
    """Forward hook: add the cost of an adaptive average pool.

    The pooling window is approximated per spatial dim as
    ``input_size // output_size``.  Each output element costs one add per
    window element plus one division.  An int ``output_size`` applies to every
    spatial dim.  A ``None`` entry (keep the input size) gives a window of 1
    along that dim.  The total is added to ``m.total_ops``.

    Args:
        m: the pooling module (reads ``output_size``).
        x: tuple of positional forward inputs; ``x[0].shape[2:]`` is taken as
            the spatial size (assumes a batched (N, C, *spatial) input).
        y: output tensor.
    """
    in_sizes = x[0].shape[2:]
    out_sizes = m.output_size
    if not isinstance(out_sizes, (tuple, list)):
        out_sizes = (out_sizes,) * len(in_sizes)
    total_add = 1
    for in_size, out_size in zip(in_sizes, out_sizes):
        total_add *= (in_size // out_size) if out_size else 1
    total_div = 1
    kernel_ops = total_add + total_div
    total_ops = kernel_ops * y.numel()
    # Accumulate rather than overwrite: a module called more than once in a
    # single forward pass fires this hook once per call.
    m.total_ops = getattr(m, "total_ops", 0) + torch.Tensor([int(total_ops)])
def count_linear(m, x, y):
    """Forward hook: add the cost of a fully connected layer.

    Each output element costs ``in_features`` multiplies and
    ``in_features - 1`` adds.  The total is added to ``m.total_ops``.

    NOTE(review): this counts a multiply and an add separately and ignores the
    bias.  ``count_conv2d`` instead counts one op per multiply-add pair
    (``multiply_adds``) plus one bias op.  Linear and conv totals are
    therefore not on the same scale; confirm which convention is wanted
    before changing it.

    Args:
        m: the linear module (reads ``in_features``).
        x: tuple of positional forward inputs (unused).
        y: output tensor.
    """
    total_mul = m.in_features
    total_add = m.in_features - 1
    total_ops = (total_mul + total_add) * y.numel()
    # Accumulate rather than overwrite: a module called more than once in a
    # single forward pass fires this hook once per call.
    m.total_ops = getattr(m, "total_ops", 0) + torch.Tensor([int(total_ops)])
def count_LastLevelMaxPool(m, x, y):
    """Forward hook: add one op per element of the module's last output.

    ``y`` is indexed with ``y[-1]``, so it is expected to be a sequence of
    tensors.  The total is added to ``m.total_ops``.

    Args:
        m: the module whose cost is recorded.
        x: tuple of positional forward inputs (unused).
        y: sequence of output tensors; only the last one is counted.
    """
    total_ops = y[-1].numel()
    # Accumulate rather than overwrite: a module called more than once in a
    # single forward pass fires this hook once per call.
    m.total_ops = getattr(m, "total_ops", 0) + torch.Tensor([int(total_ops)])
def count_ROIAlign(m, x, y):
    """Forward hook: add 4 ops per element of the pooled output.

    The total is added to ``m.total_ops``.

    Args:
        m: the ROI-align module whose cost is recorded.
        x: tuple of positional forward inputs (unused).
        y: output tensor.
    """
    total_ops = y.numel() * 4
    # Accumulate rather than overwrite: a module called more than once in a
    # single forward pass fires this hook once per call.
    m.total_ops = getattr(m, "total_ops", 0) + torch.Tensor([int(total_ops)])
# Exact-type -> forward-hook table that ``profile`` consults after
# ``custom_ops``.  Types mapped to None are recognised but contribute no ops;
# leaf modules whose type is absent are reported as "Not implemented".
# Consecutive entries that share a counter are grouped with dict.fromkeys.
register_hooks = {
    Scale: None,
    **dict.fromkeys(
        (Conv2d, nn.Conv2d, ModulatedDeformConv, StdConv2d),
        count_conv2d,
    ),
    **dict.fromkeys(
        (
            nn.BatchNorm1d,
            nn.BatchNorm2d,
            nn.BatchNorm3d,
            FrozenBatchNorm2d,
            nn.GroupNorm,
            NaiveSyncBatchNorm2d,
        ),
        count_bn,
    ),
    **dict.fromkeys((nn.ReLU, nn.ReLU6), count_relu),
    **dict.fromkeys((swish, nn.ConstantPad2d), None),
    **dict.fromkeys((SPPLayer, LastLevelMaxPool), count_LastLevelMaxPool),
    **dict.fromkeys((nn.MaxPool1d, nn.MaxPool2d, nn.MaxPool3d), count_maxpool),
    **dict.fromkeys(
        (nn.AdaptiveMaxPool1d, nn.AdaptiveMaxPool2d, nn.AdaptiveMaxPool3d),
        count_adap_maxpool,
    ),
    **dict.fromkeys((nn.AvgPool1d, nn.AvgPool2d, nn.AvgPool3d), count_avgpool),
    **dict.fromkeys(
        (nn.AdaptiveAvgPool1d, nn.AdaptiveAvgPool2d, nn.AdaptiveAvgPool3d),
        count_adap_avgpool,
    ),
    nn.Linear: count_linear,
    **dict.fromkeys((nn.Upsample, nn.Dropout, nn.Sigmoid, DropBlock2D), None),
    ROIAlign: count_ROIAlign,
    **dict.fromkeys(
        (
            RPNPostProcessor,
            PostProcessor,
            BufferList,
            RetinaPostProcessor,
            FCOSPostProcessor,
            ATSSPostProcessor,
        ),
        None,
    ),
}