joshvm committed on
Commit 8f243be · 1 Parent(s): d170aa9

update to torch2

Files changed (8)
  1. app.py +35 -0
  2. config.py +4 -1
  3. data/build.py +6 -15
  4. data/dataset_fg.py +52 -9
  5. inference.py +106 -29
  6. lr_scheduler.py +0 -1
  7. main.py +74 -19
  8. models/MetaFG.py +2 -1
app.py ADDED
@@ -0,0 +1,35 @@
+from inference import Inference
+import argparse
+import gradio as gr
+import glob
+
+def parse_option():
+    parser = argparse.ArgumentParser('MetaFG Inference script', add_help=False)
+    parser.add_argument('--cfg', type=str, metavar="FILE", help='path to config file', default="configs/MetaFG_2_224.yaml")
+    # easy config modification
+    parser.add_argument('--model-path', type=str, help="path to model data", default="./ckpt_4_mf2.pth")
+    parser.add_argument('--img-size', type=int, default=384, help='input image size')
+    parser.add_argument('--meta-path', default="meta.txt", type=str, help='path to meta data')
+    parser.add_argument('--names-path', default="names_mf2.txt", type=str, help='path to class names file')
+    args = parser.parse_args()
+    return args
+
+if __name__ == '__main__':
+    args = parse_option()
+
+    model = Inference(config_path=args.cfg,
+                      model_path=args.model_path,
+                      names_path=args.names_path)
+
+    def classify(image):
+        # Inference.infer returns a {class_name: confidence} dict,
+        # which gr.Label consumes directly
+        confidences = model.infer(img_path=image, meta_data_path=args.meta_path)
+        return confidences
+
+    gr.Interface(fn=classify,
+                 inputs=gr.Image(shape=(args.img_size, args.img_size), type="pil"),
+                 outputs=gr.Label(num_top_classes=10),
+                 examples=glob.glob("./example_images/*")).launch()
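Note on the Gradio wiring above: gr.Label accepts a plain {label: confidence} dict and renders it as a ranked list, so classify can return Inference.infer's output unchanged. A minimal standalone sketch (hedged; assumes the Gradio 3.x-era API implied by the shape= argument, with a dummy predictor standing in for the real model):

    import gradio as gr

    def classify(image):
        # stand-in for model.infer: any {class_name: confidence} mapping works
        return {"cat": 0.9, "dog": 0.1}

    gr.Interface(fn=classify,
                 inputs=gr.Image(type="pil"),
                 outputs=gr.Label(num_top_classes=2)).launch()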
config.py CHANGED
@@ -24,6 +24,8 @@ _C.DATA.BATCH_SIZE = 32
 _C.DATA.DATA_PATH = ''
 # Dataset name
 _C.DATA.DATASET = 'imagenet'
+# Dataset root folder
+_C.DATA.DATASET_ROOT = None
 # Input image size
 _C.DATA.IMG_SIZE = 224
 # Interpolation to resize image (random, bilinear, bicubic)
@@ -74,6 +76,7 @@ _C.MODEL.LABEL_SMOOTHING = 0.1
 _C.MODEL.PRETRAINED = None
 _C.MODEL.DORP_HEAD = True
 _C.MODEL.DORP_META = True
+_C.MODEL.FREEZE_BACKBONE = True
 
 _C.MODEL.ONLY_LAST_CLS = False
 _C.MODEL.EXTRA_TOKEN_NUM = 1
@@ -255,7 +258,7 @@ def update_config(config, args):
     config.MODEL.PRETRAINED = args.pretrain
 
     # set local rank for distributed training
-    config.LOCAL_RANK = args.local_rank
+    config.LOCAL_RANK = os.environ['LOCAL_RANK']
 
     # output folder
     config.OUTPUT = os.path.join(config.OUTPUT, config.MODEL.NAME, config.TAG)
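Context for the LOCAL_RANK change: PyTorch 2 deprecates torch.distributed.launch's --local_rank argument in favour of torchrun, which publishes the rank through the LOCAL_RANK environment variable instead. A minimal sketch of that pattern (hedged; note that os.environ returns a string, which is why the training code later casts with int()):

    import os
    import torch

    # torchrun sets LOCAL_RANK (plus RANK and WORLD_SIZE) per worker process
    local_rank = int(os.environ.get('LOCAL_RANK', 0))
    torch.cuda.set_device(local_rank)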
data/build.py CHANGED
@@ -13,7 +13,7 @@ from torchvision import datasets, transforms
 from timm.data.constants import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD
 from timm.data import Mixup
 from timm.data import create_transform
-from timm.data.transforms import _pil_interp
+from timm.data.transforms import str_to_interp_mode
 
 from .cached_image_folder import CachedImageFolder
 from .samplers import SubsetRandomSampler
@@ -81,50 +81,41 @@ def build_dataset(is_train, config):
         # root = os.path.join(config.DATA.DATA_PATH, prefix)
         root = './datasets/imagenet'
         dataset = datasets.ImageFolder(root, transform=transform)
-        nb_classes = 1000
     elif config.DATA.DATASET == 'inaturelist2021':
         root = './datasets/inaturelist2021'
         dataset = DatasetMeta(root=root,transform=transform,train=is_train,aux_info=config.DATA.ADD_META,dataset=config.DATA.DATASET)
-        nb_classes = 10000
     elif config.DATA.DATASET == 'inaturelist2021_mini':
         root = './datasets/inaturelist2021_mini'
         dataset = DatasetMeta(root=root,transform=transform,train=is_train,aux_info=config.DATA.ADD_META,dataset=config.DATA.DATASET)
-        nb_classes = 10000
     elif config.DATA.DATASET == 'inaturelist2017':
         root = './datasets/inaturelist2017'
         dataset = DatasetMeta(root=root,transform=transform,train=is_train,aux_info=config.DATA.ADD_META,dataset=config.DATA.DATASET)
-        nb_classes = 5089
     elif config.DATA.DATASET == 'inaturelist2018':
         root = './datasets/inaturelist2018'
         dataset = DatasetMeta(root=root,transform=transform,train=is_train,aux_info=config.DATA.ADD_META,dataset=config.DATA.DATASET)
-        nb_classes = 8142
     elif config.DATA.DATASET == 'cub-200':
         root = './datasets/cub-200'
         dataset = DatasetMeta(root=root,transform=transform,train=is_train,aux_info=config.DATA.ADD_META,dataset=config.DATA.DATASET)
-        nb_classes = 200
     elif config.DATA.DATASET == 'stanfordcars':
         root = './datasets/stanfordcars'
         dataset = DatasetMeta(root=root,transform=transform,train=is_train,aux_info=config.DATA.ADD_META,dataset=config.DATA.DATASET)
-        nb_classes = 196
     elif config.DATA.DATASET == 'oxfordflower':
         root = './datasets/oxfordflower'
         dataset = DatasetMeta(root=root,transform=transform,train=is_train,aux_info=config.DATA.ADD_META,dataset=config.DATA.DATASET)
-        nb_classes = 102
     elif config.DATA.DATASET == 'stanforddogs':
         root = './datasets/stanforddogs'
         dataset = DatasetMeta(root=root,transform=transform,train=is_train,aux_info=config.DATA.ADD_META,dataset=config.DATA.DATASET)
-        nb_classes = 120
     elif config.DATA.DATASET == 'nabirds':
         root = './datasets/nabirds'
        dataset = DatasetMeta(root=root,transform=transform,train=is_train,aux_info=config.DATA.ADD_META,dataset=config.DATA.DATASET)
-        nb_classes = 555
     elif config.DATA.DATASET == 'aircraft':
         root = './datasets/aircraft'
         dataset = DatasetMeta(root=root,transform=transform,train=is_train,aux_info=config.DATA.ADD_META,dataset=config.DATA.DATASET)
-        nb_classes = 100
     else:
-        raise NotImplementedError("We only support ImageNet and inaturelist.")
+        root = config.DATA.DATASET_ROOT
+        dataset = DatasetMeta(root=root,transform=transform,train=is_train,aux_info=config.DATA.ADD_META,dataset=config.DATA.DATASET)
 
+    nb_classes = len(dataset.class_to_idx)
     return dataset, nb_classes
@@ -153,14 +144,14 @@ def build_transform(is_train, config):
     if config.TEST.CROP:
         size = int((256 / 224) * config.DATA.IMG_SIZE)
         t.append(
-            transforms.Resize(size, interpolation=_pil_interp(config.DATA.INTERPOLATION)),
+            transforms.Resize(size, interpolation=str_to_interp_mode(config.DATA.INTERPOLATION)),
             # to maintain same ratio w.r.t. 224 images
         )
         t.append(transforms.CenterCrop(config.DATA.IMG_SIZE))
     else:
         t.append(
             transforms.Resize((config.DATA.IMG_SIZE, config.DATA.IMG_SIZE),
-                              interpolation=_pil_interp(config.DATA.INTERPOLATION))
+                              interpolation=str_to_interp_mode(config.DATA.INTERPOLATION))
         )
 
     t.append(transforms.ToTensor())
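Context for the interpolation change: recent timm releases removed the private _pil_interp helper; str_to_interp_mode is the public replacement and returns a torchvision InterpolationMode rather than a PIL constant. A minimal sketch (hedged; 'bicubic' is just an example value for config.DATA.INTERPOLATION):

    from timm.data.transforms import str_to_interp_mode
    from torchvision import transforms

    # maps 'bilinear' / 'bicubic' / 'nearest' to transforms.InterpolationMode members
    resize = transforms.Resize(224, interpolation=str_to_interp_mode('bicubic'))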
data/dataset_fg.py CHANGED
@@ -10,6 +10,7 @@ import pickle
 import numpy as np
 import pandas as pd
 import random
+from tqdm import tqdm
 random.seed(2021)
 from PIL import Image
 from scipy import io as scio
@@ -335,7 +336,7 @@ def find_images_and_targets_2017_2018(root,dataset,istrain=False,aux_info=False):
     else:
         images_and_targets.append((file_path,target))
     return images_and_targets,class_to_idx,images_info
-def find_images_and_targets(root,istrain=False,aux_info=False):
+def find_images_and_targets(root,istrain=False,aux_info=False, integrity_check=False):
     if os.path.exists(os.path.join(root,'train.json')):
         with open(os.path.join(root,'train.json'),'r') as f:
             train_class_info = json.load(f)
@@ -343,24 +344,59 @@ def find_images_and_targets(root,istrain=False,aux_info=False):
         with open(os.path.join(root,'train_mini.json'),'r') as f:
             train_class_info = json.load(f)
     else:
-        raise ValueError(f'not eixst file {root}/train.json or {root}/train_mini.json')
+        raise ValueError(f'{root}/train.json or {root}/train_mini.json doesn\'t exist')
+
     with open(os.path.join(root,'val.json'),'r') as f:
         val_class_info = json.load(f)
-    categories_2021 = [x['name'].strip().lower() for x in val_class_info['categories']]
-    class_to_idx = {c: idx for idx, c in enumerate(categories_2021)}
+
+    categories = [x['name'].strip().lower() for x in val_class_info['categories']]
+    class_to_idx = {c: idx for idx, c in enumerate(categories)}
     id2label = dict()
     for categorie in train_class_info['categories']:
         id2label[int(categorie['id'])] = categorie['name'].strip().lower()
     class_info = train_class_info if istrain else val_class_info
-
+    image_subdir = "train" if istrain else "val"
     images_and_targets = []
     images_info = []
     if aux_info:
         temporal_info = []
         spatial_info = []
 
-    for image,annotation in zip(class_info['images'],class_info['annotations']):
-        file_path = os.path.join(root,image['file_name'])
+    # map annotation id -> image id (not used below, kept for reference)
+    ann2im = {}
+    for ann in class_info['annotations']:
+        ann2im[ann['id']] = ann['image_id']
+
+    # index image records by id so annotations can be joined to them
+    ims = {}
+    for image in class_info['images']:
+        ims[image['id']] = image
+
+    print("Found", len(train_class_info['categories']), "categories")
+    print("Loading images and targets, checking image integrity")
+
+    for annotation in tqdm(class_info['annotations']):
+        image = ims[annotation['image_id']]
+        dir = train_class_info['categories'][annotation['category_id']]['image_dir_name']
+        file_path = os.path.join(root,image_subdir,dir,image['file_name'])
+
+        if not os.path.exists(file_path):
+            continue
+            # NOTE: this download fallback is unreachable while the continue above is in place
+            print(f"Download {file_path}")
+            os.makedirs(os.path.dirname(file_path), exist_ok=True)
+            import requests
+            with open(file_path, 'wb') as fp:
+                fp.write(requests.get(image['inaturalist_url']).content)
+
+        if integrity_check:
+            try:
+                _ = np.array(Image.open(file_path))
+            except Exception:
+                print(f"Failed to open {file_path}")
+                continue
+
         id_name = id2label[int(annotation['category_id'])]
         target = class_to_idx[id_name]
         date = image['date']
@@ -389,13 +425,15 @@ class DatasetMeta(data.Dataset):
                  transform=None,
                  train=False,
                  aux_info=False,
-                 dataset='inaturelist2021',
+                 dataset='coco_generic',
                  class_ratio=1.0,
                  per_sample=1.0):
         self.aux_info = aux_info
         self.dataset = dataset
         if dataset in ['inaturelist2021','inaturelist2021_mini']:
             images, class_to_idx,images_info = find_images_and_targets(root,train,aux_info)
+        elif dataset in ['coco_generic']:
+            images, class_to_idx,images_info = find_images_and_targets(root,train,aux_info)
         elif dataset in ['inaturelist2017','inaturelist2018']:
             images, class_to_idx,images_info = find_images_and_targets_2017_2018(root,dataset,train,aux_info)
         elif dataset == 'cub-200':
@@ -427,7 +465,12 @@ class DatasetMeta(data.Dataset):
             path, target,aux_info = self.samples[index]
         else:
             path, target = self.samples[index]
-        img = open(path, 'rb').read() if self.load_bytes else Image.open(path).convert('RGB')
+
+        try:
+            img = open(path, 'rb').read() if self.load_bytes else Image.open(path).convert('RGB')
+        except Exception:
+            # fall back to a black placeholder image rather than crashing the loader
+            img = Image.fromarray(np.zeros((224,224,3), dtype=np.uint8))
+
         if self.transform is not None:
             img = self.transform(img)
         if self.aux_info:
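The rewritten loader joins COCO-style annotations to image records by image_id instead of zipping the two lists positionally, which only works when they happen to align. A minimal sketch of that join (hedged; field names follow the iNaturalist/COCO JSON layout assumed by the diff, and 'val.json' is an illustrative path):

    import json

    with open('val.json') as f:  # hypothetical annotation file
        info = json.load(f)

    # index image records once, then resolve each annotation's image by id
    images_by_id = {im['id']: im for im in info['images']}
    for ann in info['annotations']:
        im = images_by_id[ann['image_id']]
        # ann['category_id'] gives the class; im['file_name'] locates the file on disk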
inference.py CHANGED
@@ -7,6 +7,10 @@ from torch.autograd import Variable
 from torchvision.transforms import transforms
 import numpy as np
 import argparse
+from pycocotools.coco import COCO
+import requests
+import os
+from tqdm.auto import tqdm
 
 try:
     from apex import amp
@@ -34,24 +38,32 @@ def read_class_names(file_path):
     class_list = []
 
     for l in lines:
-        line = l.strip().split()
+        line = l.strip()
         # class_list.append(line[0])
-        class_list.append(line[1][4:])
+        class_list.append(line)
 
     classes = tuple(class_list)
     return classes
 
 
+def read_class_names_coco(file_path):
+    dataset = COCO(file_path)
+    classes = [dataset.cats[k]['name'] for k in sorted(dataset.cats.keys())]
+
+    with open("names.txt", 'w') as fp:
+        for c in classes:
+            fp.write(f"{c}\n")
+
+    return classes
+
 class GenerateEmbedding:
-    def __init__(self, text_file):
-        self.text_file = text_file
+    def __init__(self):
         self.tokenizer = AutoTokenizer.from_pretrained("bert-base-uncased")
         self.model = AutoModel.from_pretrained("bert-base-uncased")
 
-    def generate(self):
+    def generate(self, text_file):
         text_list = []
-        with open(self.text_file, 'r') as f_text:
+        with open(text_file, 'r') as f_text:
             for line in f_text:
                 line = line.encode(encoding='UTF-8', errors='strict')
                 line = line.replace(b'\xef\xbf\xbd\xef\xbf\xbd', b' ')
@@ -69,57 +81,122 @@ class GenerateEmbedding:
 
 
 class Inference:
-    def __init__(self, config_path, model_path):
+    def __init__(self, config_path, model_path, names_path):
         self.config_path = config_path
         self.model_path = model_path
         self.device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
-        # self.classes = ("cat", "dog")
-        self.classes = read_class_names(r"D:\dataset\CUB_200_2011\CUB_200_2011\classes_custom.txt")
+        self.classes = read_class_names(names_path)
 
         self.config = model_config(self.config_path)
         self.model = build_model(self.config)
         self.checkpoint = torch.load(self.model_path, map_location='cpu')
-        self.model.load_state_dict(self.checkpoint['model'], strict=False)
+
+        # support both wrapped ({'model': state_dict}) and bare state_dict checkpoints
+        if 'model' in self.checkpoint:
+            self.model.load_state_dict(self.checkpoint['model'], strict=False)
+        else:
+            self.model.load_state_dict(self.checkpoint, strict=False)
+
         self.model.eval()
-        self.model.cuda()
+        self.model.to(self.device)
+        self.topk = 10
+        self.embedding_gen = GenerateEmbedding()
 
         self.transform_img = transforms.Compose([
-            transforms.Resize((224, 224), interpolation=Image.BILINEAR),
+            transforms.Resize((self.config.DATA.IMG_SIZE, self.config.DATA.IMG_SIZE), interpolation=Image.BILINEAR),
             transforms.ToTensor(),  # transforms.Normalize([0.5, 0.5, 0.5], [0.5, 0.5, 0.5])
             transforms.Normalize(IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD)
         ])
 
-    def infer(self, img_path, meta_data_path):
-        _, _, meta = GenerateEmbedding(meta_data_path).generate()
-        meta = meta.cuda()
-        img = Image.open(img_path).convert('RGB')
+    def infer(self, img_path, meta_data_path, topk=None):
+        # accept a URL, a local path, or an already-loaded PIL image
+        if isinstance(img_path, str):
+            if img_path.startswith("http"):
+                img = Image.open(requests.get(img_path, stream=True).raw).convert('RGB')
+            else:
+                img = Image.open(img_path).convert('RGB')
+        else:
+            img = img_path
+
+        """
+        _, _, meta = self.embedding_gen.generate(meta_data_path)
+        meta = meta.to(self.device)
+        """
+        meta = None
+
         img = self.transform_img(img)
         img.unsqueeze_(0)
-        img = img.cuda()
         img = Variable(img).to(self.device)
         out = self.model(img, meta)
 
-        _, pred = torch.max(out.data, 1)
-        predict = self.classes[pred.data.item()]
-        # print(Fore.MAGENTA + f"The Prediction is: {predict}")
-        return predict
+        y_pred = torch.nn.Softmax(dim=1)(out)
+        # descending order of confidence; list() makes the result sliceable
+        indices = list(reversed(torch.argsort(y_pred, dim=1).squeeze().tolist()))
+
+        if topk is not None:
+            return {self.classes[idx]: y_pred.squeeze()[idx].cpu().item() for idx in indices[:topk]}
+        else:
+            return {self.classes[idx]: y_pred.squeeze()[idx].cpu().item() for idx in indices}
 
 
 def parse_option():
     parser = argparse.ArgumentParser('MetaFG Inference script', add_help=False)
-    parser.add_argument('--cfg', type=str, default='D:/pycharmprojects/MetaFormer/configs/MetaFG_meta_bert_1_224.yaml', metavar="FILE", help='path to config file', )
+    parser.add_argument('--cfg', type=str, metavar="FILE", help='path to config file', default="configs/MetaFG_2_224.yaml")
     # easy config modification
-    parser.add_argument('--model-path', default='D:\pycharmprojects\MetaFormer\output\MetaFG_meta_1\cub_200\ckpt_epoch_92.pth', type=str, help="path to model data")
-    parser.add_argument('--img-path', default=r"D:\dataset\CUB_200_2011\CUB_200_2011\images\012.Yellow_headed_Blackbird\Yellow_Headed_Blackbird_0003_8337.jpg", type=str, help='path to image')
-    parser.add_argument('--meta-path', default=r"D:\dataset\CUB_200_2011\text_c10\012.Yellow_headed_Blackbird\Yellow_Headed_Blackbird_0003_8337.txt", type=str, help='path to meta data')
+    parser.add_argument('--model-path', type=str, help="path to model data", default="ckpt_epoch_12.pth")
+    parser.add_argument('--img-path', type=str, help='path to a single image')
+    parser.add_argument('--img-folder', type=str, help='path to a folder of images')
+    parser.add_argument('--meta-path', default="meta.txt", type=str, help='path to meta data')
+    parser.add_argument('--names-path', default="names_mf2.txt", type=str, help='path to class names file')
     args = parser.parse_args()
     return args
 
 
 if __name__ == '__main__':
     args = parse_option()
-    result = Inference(config_path=args.cfg,
-                       model_path=args.model_path).infer(img_path=args.img_path, meta_data_path=args.meta_path)
-    print("Predicted: ", result)
+    model = Inference(config_path=args.cfg,
+                      model_path=args.model_path,
+                      names_path=args.names_path)
+
+    from glob import glob
+    glob_imgs = glob(os.path.join(args.img_folder, "*.jpg"))
+    out_dir = f"results_{os.path.splitext(os.path.basename(args.model_path))[0]}"
+    os.makedirs(out_dir, exist_ok=True)
+
+    for img in tqdm(glob_imgs):
+        try:
+            res = model.infer(img_path=img, meta_data_path=args.meta_path)
+        except KeyboardInterrupt:
+            break
+        except Exception as e:
+            print(e)
+            continue
+
+        out = {}
+        out['preds'] = res
+
+        """
+        # res is a list of (class, score). Record true/false if the top-k class is correct
+        out['top1_correct'] = '_'.join(res[0][1].split(' ')).lower() in os.path.basename(img).lower()
+
+        out['top5_correct'] = False
+        for i in range(5):
+            out['top5_correct'] |= '_'.join(res[i][1].split(' ')).lower() in os.path.basename(img).lower()
+
+        out['top10_correct'] = False
+        for i in range(10):
+            out['top10_correct'] |= '_'.join(res[i][1].split(' ')).lower() in os.path.basename(img).lower()
+        """
+
+        # output JSON with inference results, using the image basename as filename
+        import json
+        with open(os.path.join(out_dir, os.path.splitext(os.path.basename(img))[0] + ".json"), 'w') as fp:
+            json.dump(out, fp, indent=1)
 
 # Usage: python inference.py --cfg 'path/to/cfg' --model-path 'path/to/model' --img-folder 'path/to/folder' --meta-path 'path/to/meta'
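A note on the new ranking logic: torch.topk returns the k largest probabilities and their indices already sorted, which avoids materialising a full argsort and the reversed()/slice dance. A minimal equivalent sketch (hedged; out stands for the model's raw logits and classes for the tuple loaded by read_class_names):

    import torch

    probs = torch.softmax(out, dim=1)              # (1, num_classes)
    values, idxs = torch.topk(probs, k=10, dim=1)  # already sorted descending
    top10 = {classes[i]: v.item()
             for i, v in zip(idxs.squeeze(0).tolist(), values.squeeze(0))}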
lr_scheduler.py CHANGED
@@ -21,7 +21,6 @@ def build_scheduler(config, optimizer, n_iter_per_epoch):
     lr_scheduler = CosineLRScheduler(
         optimizer,
         t_initial=num_steps,
-        t_mul=1.,
         lr_min=config.TRAIN.MIN_LR,
         warmup_lr_init=config.TRAIN.WARMUP_LR,
         warmup_t=warmup_steps,
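Context for dropping t_mul: recent timm releases renamed the cycle-length multiplier from t_mul to cycle_mul and removed the old keyword, and since its default is 1.0 the deleted argument was a no-op anyway. A minimal self-contained sketch of the current constructor (hedged; the optimizer and step counts are illustrative stand-ins for the values built by build_scheduler):

    import torch
    from timm.scheduler.cosine_lr import CosineLRScheduler

    optimizer = torch.optim.AdamW([torch.nn.Parameter(torch.zeros(1))], lr=5e-4)
    num_steps, warmup_steps = 10000, 500  # illustrative values

    lr_scheduler = CosineLRScheduler(
        optimizer,
        t_initial=num_steps,  # length of the cosine cycle, in steps
        cycle_mul=1.0,        # timm's rename of the removed t_mul kwarg
        lr_min=1e-6,
        warmup_lr_init=1e-7,
        warmup_t=warmup_steps,
    )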
main.py CHANGED
@@ -2,7 +2,9 @@ import os
 import time
 import argparse
 import datetime
+import json
 import numpy as np
+from collections import defaultdict
 
 import torch
 import torch.backends.cudnn as cudnn
@@ -18,13 +20,23 @@ from lr_scheduler import build_scheduler
 from optimizer import build_optimizer
 from logger import create_logger
 from utils import load_checkpoint, save_checkpoint, get_grad_norm, auto_resume_helper, reduce_tensor,load_pretained
-from torch.utils.tensorboard import SummaryWriter
+
+have_wandb = False
+try:
+    import wandb
+    have_wandb = True
+except ImportError:
+    pass
+
+# TODO use torch.amp
 try:
     # noinspection PyUnresolvedReferences
     from apex import amp
 except ImportError:
     amp = None
 
+import logging
+logging.basicConfig(level=logging.INFO)
 
 def parse_option():
     parser = argparse.ArgumentParser('MetaFG training and evaluation script', add_help=False)
@@ -77,20 +89,19 @@ def parse_option():
                         help='dataset')
     parser.add_argument('--lr-scheduler-name', type=str,
                         help='lr scheduler name,cosin linear,step')
 
     parser.add_argument('--pretrain', type=str,
                         help='pretrain')
 
-    parser.add_argument('--tensorboard', action='store_true', help='using tensorboard')
-
-    # distributed training
-    parser.add_argument("--local_rank", type=int, required=True, help='local rank for DistributedDataParallel')
+    parser.add_argument('--wandb_job', type=str)
 
     args, unparsed = parser.parse_known_args()
 
     config = get_config(args)
 
+    if have_wandb and int(config.LOCAL_RANK) == 0:
+        wandb.init(name=args.wandb_job, config=args)
+
     return args, config
@@ -98,14 +109,20 @@ def main(config):
     dataset_train, dataset_val, data_loader_train, data_loader_val, mixup_fn = build_loader(config)
     logger.info(f"Creating model:{config.MODEL.TYPE}/{config.MODEL.NAME}")
     model = build_model(config)
+    if have_wandb and int(config.LOCAL_RANK) == 0:
+        wandb.config['model_config'] = config
     model.cuda()
     logger.info(str(model))
 
     optimizer = build_optimizer(config, model)
     if config.AMP_OPT_LEVEL != "O0":
         model, optimizer = amp.initialize(model, optimizer, opt_level=config.AMP_OPT_LEVEL)
-    model = torch.nn.parallel.DistributedDataParallel(model, device_ids=[config.LOCAL_RANK], broadcast_buffers=False)
+    model = torch.nn.parallel.DistributedDataParallel(model, device_ids=[int(config.LOCAL_RANK)], broadcast_buffers=False)
     model_without_ddp = model.module
+
+    if have_wandb and int(config.LOCAL_RANK) == 0:
+        wandb.watch(model, log_freq=100)
+
     n_parameters = sum(p.numel() for p in model.parameters() if p.requires_grad)
     logger.info(f"number of params: {n_parameters}")
     if hasattr(model_without_ddp, 'flops'):
@@ -123,10 +140,15 @@ def main(config):
     max_accuracy = 0.0
     if config.MODEL.PRETRAINED:
         load_pretained(config,model_without_ddp,logger)
-    if config.EVAL_MODE:
-        acc1, acc5, loss = validate(config, data_loader_val, model)
-        logger.info(f"Accuracy of the network on the {len(dataset_val)} test images: {acc1:.1f}%")
-        return
+
+    # Run initial validation
+    logger.info("Start validation (on init)")
+    acc1, acc5, loss, stats = validate(config, data_loader_val, model, limit=10)
+
+    with open(os.path.join(config.OUTPUT, 'val_init.json'), 'w') as fp:
+        json.dump(stats, fp, indent=1)
+
+    logger.info(f"Accuracy of the network on the {len(dataset_val)} test images: {acc1:.1f}%")
 
     if config.TRAIN.AUTO_RESUME:
         resume_file = auto_resume_helper(config.OUTPUT)
@@ -143,11 +165,11 @@ def main(config):
     if config.MODEL.RESUME:
         logger.info(f"**********normal test***********")
         max_accuracy = load_checkpoint(config, model_without_ddp, optimizer, lr_scheduler, logger)
-        acc1, acc5, loss = validate(config, data_loader_val, model)
+        acc1, acc5, loss, stats = validate(config, data_loader_val, model)
         logger.info(f"Accuracy of the network on the {len(dataset_val)} test images: {acc1:.1f}%")
         if config.DATA.ADD_META:
             logger.info(f"**********mask meta test***********")
-            acc1, acc5, loss = validate(config, data_loader_val, model,mask_meta=True)
+            acc1, acc5, loss, stats = validate(config, data_loader_val, model,mask_meta=True)
             logger.info(f"Accuracy of the network on the {len(dataset_val)} test images: {acc1:.1f}%")
         if config.EVAL_MODE:
             return
@@ -165,18 +187,37 @@ def main(config):
         save_checkpoint(config, epoch, model_without_ddp, max_accuracy, optimizer, lr_scheduler, logger)
 
         logger.info(f"**********normal test***********")
-        acc1, acc5, loss = validate(config, data_loader_val, model)
+        acc1, acc5, loss, stats = validate(config, data_loader_val, model)
+
+        with open(os.path.join(config.OUTPUT, f'val_{epoch}.json'), 'w') as fp:
+            json.dump(stats, fp, indent=1)
+
         logger.info(f"Accuracy of the network on the {len(dataset_val)} test images: {acc1:.1f}%")
         max_accuracy = max(max_accuracy, acc1)
         logger.info(f'Max accuracy: {max_accuracy:.2f}%')
         if config.DATA.ADD_META:
             logger.info(f"**********mask meta test***********")
-            acc1, acc5, loss = validate(config, data_loader_val, model,mask_meta=True)
+            acc1, acc5, loss, stats = validate(config, data_loader_val, model,mask_meta=True)
+
+            with open(os.path.join(config.OUTPUT, f'val_{epoch}_meta.json'), 'w') as fp:
+                json.dump(stats, fp, indent=1)
+
             logger.info(f"Accuracy of the network on the {len(dataset_val)} test images: {acc1:.1f}%")
         # data_loader_train.terminate()
+
+        if have_wandb and int(config.LOCAL_RANK) == 0:
+            wandb.run.summary["acc_top_1"] = acc1
+            wandb.run.summary["acc_top_5"] = acc5
+            wandb.run.summary["val_loss"] = loss
+
+            wandb.log({'val/acc1': acc1})
+            wandb.log({'val/acc5': acc5})
+            wandb.log({'val/loss': loss})
+
     total_time = time.time() - start_time
     total_time_str = str(datetime.timedelta(seconds=int(total_time)))
     logger.info('Training time {}'.format(total_time_str))
+
 def train_one_epoch_local_data(config, model, criterion, data_loader, optimizer, epoch, mixup_fn, lr_scheduler,tb_logger=None):
     model.train()
     if hasattr(model.module,'cur_epoch'):
@@ -261,6 +302,8 @@ def train_one_epoch_local_data(config, model, criterion, data_loader, optimizer, epoch, mixup_fn, lr_scheduler,tb_logger=None):
             lr = optimizer.param_groups[0]['lr']
             memory_used = torch.cuda.max_memory_allocated() / (1024.0 * 1024.0)
             etas = batch_time.avg * (num_steps - idx)
+            if have_wandb and int(config.LOCAL_RANK) == 0 and idx % 100 == 0:
+                wandb.log({"train/loss": loss_meter.val})
             logger.info(
                 f'Train: [{epoch}/{config.TRAIN.EPOCHS}][{idx}/{num_steps}]\t'
                 f'eta {datetime.timedelta(seconds=int(etas))} lr {lr:.6f}\t'
@@ -271,7 +314,7 @@ def train_one_epoch_local_data(config, model, criterion, data_loader, optimizer, epoch, mixup_fn, lr_scheduler,tb_logger=None):
     epoch_time = time.time() - start
     logger.info(f"EPOCH {epoch} training takes {datetime.timedelta(seconds=int(epoch_time))}")
 @torch.no_grad()
-def validate(config, data_loader, model, mask_meta=False):
+def validate(config, data_loader, model, mask_meta=False, limit=None):
     criterion = torch.nn.CrossEntropyLoss()
     model.eval()
@@ -280,8 +323,16 @@ def validate(config, data_loader, model, mask_meta=False):
     acc1_meter = AverageMeter()
     acc5_meter = AverageMeter()
 
+    # per-class accuracy/loss records, keyed by target index
+    stats = defaultdict(list)
+
     end = time.time()
+
     for idx, data in enumerate(data_loader):
+
+        if limit is not None and idx > limit:
+            break
+
         if config.DATA.ADD_META:
             images,target,meta = data
             meta = [m.float() for m in meta]
@@ -314,6 +365,9 @@ def validate(config, data_loader, model, mask_meta=False):
         acc1_meter.update(acc1.item(), target.size(0))
         acc5_meter.update(acc5.item(), target.size(0))
 
+        for t in target:
+            stats[int(t.item())].append((acc1.item(), acc5.item(), loss.item()))
+
         # measure elapsed time
         batch_time.update(time.time() - end)
         end = time.time()
@@ -328,7 +382,8 @@ def validate(config, data_loader, model, mask_meta=False):
             f'Acc@5 {acc5_meter.val:.3f} ({acc5_meter.avg:.3f})\t'
             f'Mem {memory_used:.0f}MB')
     logger.info(f' * Acc@1 {acc1_meter.avg:.3f} Acc@5 {acc5_meter.avg:.3f}')
-    return acc1_meter.avg, acc5_meter.avg, loss_meter.avg
+
+    return acc1_meter.avg, acc5_meter.avg, loss_meter.avg, stats
 
 
 @torch.no_grad()
@@ -364,7 +419,7 @@ if __name__ == '__main__':
     else:
         rank = -1
         world_size = -1
-    torch.cuda.set_device(config.LOCAL_RANK)
+    torch.cuda.set_device(f'cuda:{config.LOCAL_RANK}')
    torch.distributed.init_process_group(backend='nccl', init_method='env://', world_size=world_size, rank=rank)
     torch.distributed.barrier()

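Putting the main.py changes together: with --local_rank gone from argparse, the script is meant to be launched with torchrun, which sets the environment variables that init_process_group(init_method='env://') reads. A minimal sketch of the bootstrap (hedged; the launch command and worker counts are illustrative):

    # launch: torchrun --nproc_per_node=<num_gpus> main.py --cfg <config.yaml>
    import os
    import torch

    local_rank = int(os.environ['LOCAL_RANK'])  # set per worker by torchrun
    torch.cuda.set_device(local_rank)           # an int ordinal is accepted directly
    torch.distributed.init_process_group(backend='nccl', init_method='env://')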
models/MetaFG.py CHANGED
@@ -54,7 +54,8 @@ class MetaFG(nn.Module):
                  qkv_bias=False, qk_scale=None, drop_rate=0., attn_drop_rate=0.,drop_path_rate=0.,
                  meta_dims=[],
                  only_last_cls=False,
-                 use_checkpoint=False):
+                 use_checkpoint=False,
+                 **kwargs):
         super().__init__()
         self.only_last_cls = only_last_cls
         self.img_size = img_size
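The **kwargs addition makes the constructor tolerant of extra keyword arguments, so new config keys forwarded by build_model (for example the FREEZE_BACKBONE flag added in this commit's config.py) no longer raise TypeError. A minimal sketch of the effect (hedged; freeze_backbone is an illustrative key, not part of the constructor):

    class Strict:
        def __init__(self, img_size=224):
            self.img_size = img_size

    class Tolerant:
        def __init__(self, img_size=224, **kwargs):
            self.img_size = img_size  # unknown keys are silently absorbed

    Tolerant(img_size=224, freeze_backbone=True)  # ok
    # Strict(img_size=224, freeze_backbone=True) would raise TypeError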