"""Report FLOPs and parameter count of a timm-registered classification model.

The model is built with ``timm.models.create_model`` and measured with mmcv's
``get_model_complexity_info``.
"""
import argparse

import torch
from timm.models import create_model

# Star import kept as in the original; presumably it registers the CoAt
# architectures with timm so ``create_model`` can find them -- TODO confirm.
from models.CoAt import *

try:
    from mmcv.cnn import get_model_complexity_info
    from mmcv.cnn.utils.flops_counter import get_model_complexity_info, flops_to_string, params_to_string
except ImportError as err:
    # Chain the original error so the real cause (missing package vs. missing
    # symbol) stays visible in the traceback.
    raise ImportError('Please upgrade mmcv to >0.6.2') from err


def parse_args():
    """Parse the command line.

    Returns:
        argparse.Namespace with ``model`` (str) and ``shape`` (list of one or
        more ints; defaults to ``[224]``).
    """
    parser = argparse.ArgumentParser(description='Get FLOPS of a classification model')
    # The value is forwarded to ``create_model`` as the model name, so the
    # help text describes it as such (it is not a config file path).
    parser.add_argument('model', help='name of the model to build via timm create_model')
    parser.add_argument(
        '--shape',
        type=int,
        nargs='+',
        default=[224,],
        help='input image size')
    args = parser.parse_args()
    return args


def _resolve_input_shape(shape):
    """Turn the ``--shape`` values into a 3-tuple with a leading 3.

    Args:
        shape: sequence of ints from the command line.

    Returns:
        ``(3, s, s)`` for a single value ``s``; ``(3, a, b)`` for two values.

    Raises:
        ValueError: if ``shape`` does not contain exactly one or two values.
    """
    if len(shape) == 1:
        return (3, shape[0], shape[0])
    if len(shape) == 2:
        return (3,) + tuple(shape)
    raise ValueError('invalid input shape')


def get_flops(model, input_shape):
    """Measure ``model`` and return human-readable complexity strings.

    Args:
        model: the module handed to ``get_model_complexity_info``.
        input_shape: shape tuple handed to ``get_model_complexity_info``.

    Returns:
        Tuple ``(flops, params)`` formatted by ``flops_to_string`` and
        ``params_to_string``.
    """
    # Only counts are needed, so skip autograd bookkeeping during the
    # measurement pass.
    with torch.no_grad():
        flops, params = get_model_complexity_info(model, input_shape, as_strings=False)
    return flops_to_string(flops), params_to_string(params)


def main():
    """Build the requested model, measure it and print the results."""
    args = parse_args()
    input_shape = _resolve_input_shape(args.shape)

    # NOTE(review): only the first --shape value is passed as ``img_size``,
    # even when two values are given; confirm whether the model should also
    # receive the second dimension.
    model = create_model(
        args.model,
        pretrained=False,
        num_classes=1000,
        img_size=args.shape[0],
    )
    model.name = args.model

    if torch.cuda.is_available():
        model.cuda()
    model.eval()

    flops, params = get_flops(model, input_shape)
    split_line = '=' * 30
    print(f'{split_line}\nInput shape: {input_shape}\n'
          f'Flops: {flops}\nParams: {params}\n{split_line}')
    print('!!!Please be cautious if you use the results in papers. '
          'You may need to check if all ops are supported and verify that the '
          'flops computation is correct.')


if __name__ == '__main__':
    main()