update to torch2
- app.py +35 -0
- config.py +4 -1
- data/build.py +6 -15
- data/dataset_fg.py +52 -9
- inference.py +106 -29
- lr_scheduler.py +0 -1
- main.py +74 -19
- models/MetaFG.py +2 -1
app.py
ADDED
@@ -0,0 +1,35 @@
+from inference import Inference
+import argparse
+import gradio as gr
+import glob
+
+def parse_option():
+    parser = argparse.ArgumentParser('MetaFG Inference script', add_help=False)
+    parser.add_argument('--cfg', type=str, metavar="FILE", help='path to config file', default="configs/MetaFG_2_224.yaml")
+    # easy config modification
+    parser.add_argument('--model-path', type=str, help="path to model data", default="./ckpt_4_mf2.pth")
+    parser.add_argument('--img-size', type=int, default=384, help='input image size')
+    parser.add_argument('--meta-path', default="meta.txt", type=str, help='path to meta data')
+    parser.add_argument('--names-path', default="names_mf2.txt", type=str, help='path to class names file')
+    args = parser.parse_args()
+    return args
+
+if __name__ == '__main__':
+    args = parse_option()
+
+    model = Inference(config_path=args.cfg,
+                      model_path=args.model_path,
+                      names_path=args.names_path)
+
+    def classify(image):
+        # Inference.infer returns a {class_name: score} mapping
+        preds = model.infer(img_path=image, meta_data_path=args.meta_path)
+        print(len(model.classes))
+        print(model.classes)
+        confidences = {c: float(preds.get(c, 0.0)) for c in model.classes}
+
+        return confidences
+
+    gr.Interface(fn=classify,
+                 inputs=gr.Image(shape=(args.img_size, args.img_size), type="pil"),
+                 outputs=gr.Label(num_top_classes=10),
+                 examples=glob.glob("./example_images/*")).launch()
config.py
CHANGED
@@ -24,6 +24,8 @@ _C.DATA.BATCH_SIZE = 32
 _C.DATA.DATA_PATH = ''
 # Dataset name
 _C.DATA.DATASET = 'imagenet'
+# Dataset root folder
+_C.DATA.DATASET_ROOT = None
 # Input image size
 _C.DATA.IMG_SIZE = 224
 # Interpolation to resize image (random, bilinear, bicubic)
@@ -74,6 +76,7 @@ _C.MODEL.LABEL_SMOOTHING = 0.1
 _C.MODEL.PRETRAINED = None
 _C.MODEL.DORP_HEAD = True
 _C.MODEL.DORP_META = True
+_C.MODEL.FREEZE_BACKBONE = True
 
 _C.MODEL.ONLY_LAST_CLS = False
 _C.MODEL.EXTRA_TOKEN_NUM = 1
@@ -255,7 +258,7 @@ def update_config(config, args):
     config.MODEL.PRETRAINED = args.pretrain
 
     # set local rank for distributed training
-    config.LOCAL_RANK = args.local_rank
+    config.LOCAL_RANK = os.environ['LOCAL_RANK']
 
     # output folder
     config.OUTPUT = os.path.join(config.OUTPUT, config.MODEL.NAME, config.TAG)
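Note on the LOCAL_RANK change: torchrun, the PyTorch 2 replacement for torch.distributed.launch, no longer passes --local_rank to the script; it exports LOCAL_RANK as an environment variable instead, which is why the argparse flag disappears from main.py below. A minimal sketch of the same read with a single-process fallback (the int cast and default are illustrative additions; the diff stores the raw string and casts at the call sites):

    import os

    # torchrun exports LOCAL_RANK per worker; default to 0 when run without a launcher
    local_rank = int(os.environ.get('LOCAL_RANK', 0))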
data/build.py
CHANGED
@@ -13,7 +13,7 @@ from torchvision import datasets, transforms
 from timm.data.constants import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD
 from timm.data import Mixup
 from timm.data import create_transform
-from timm.data.transforms import _pil_interp
+from timm.data.transforms import str_to_interp_mode
 
 from .cached_image_folder import CachedImageFolder
 from .samplers import SubsetRandomSampler
@@ -81,50 +81,41 @@ def build_dataset(is_train, config):
         # root = os.path.join(config.DATA.DATA_PATH, prefix)
         root = './datasets/imagenet'
         dataset = datasets.ImageFolder(root, transform=transform)
-        nb_classes = 1000
     elif config.DATA.DATASET == 'inaturelist2021':
         root = './datasets/inaturelist2021'
         dataset = DatasetMeta(root=root,transform=transform,train=is_train,aux_info=config.DATA.ADD_META,dataset=config.DATA.DATASET)
-        nb_classes = 10000
     elif config.DATA.DATASET == 'inaturelist2021_mini':
         root = './datasets/inaturelist2021_mini'
         dataset = DatasetMeta(root=root,transform=transform,train=is_train,aux_info=config.DATA.ADD_META,dataset=config.DATA.DATASET)
-        nb_classes = 10000
     elif config.DATA.DATASET == 'inaturelist2017':
         root = './datasets/inaturelist2017'
         dataset = DatasetMeta(root=root,transform=transform,train=is_train,aux_info=config.DATA.ADD_META,dataset=config.DATA.DATASET)
-        nb_classes = 5089
     elif config.DATA.DATASET == 'inaturelist2018':
         root = './datasets/inaturelist2018'
         dataset = DatasetMeta(root=root,transform=transform,train=is_train,aux_info=config.DATA.ADD_META,dataset=config.DATA.DATASET)
-        nb_classes = 8142
     elif config.DATA.DATASET == 'cub-200':
         root = './datasets/cub-200'
         dataset = DatasetMeta(root=root,transform=transform,train=is_train,aux_info=config.DATA.ADD_META,dataset=config.DATA.DATASET)
-        nb_classes = 200
     elif config.DATA.DATASET == 'stanfordcars':
         root = './datasets/stanfordcars'
         dataset = DatasetMeta(root=root,transform=transform,train=is_train,aux_info=config.DATA.ADD_META,dataset=config.DATA.DATASET)
-        nb_classes = 196
     elif config.DATA.DATASET == 'oxfordflower':
         root = './datasets/oxfordflower'
         dataset = DatasetMeta(root=root,transform=transform,train=is_train,aux_info=config.DATA.ADD_META,dataset=config.DATA.DATASET)
-        nb_classes = 102
     elif config.DATA.DATASET == 'stanforddogs':
         root = './datasets/stanforddogs'
         dataset = DatasetMeta(root=root,transform=transform,train=is_train,aux_info=config.DATA.ADD_META,dataset=config.DATA.DATASET)
-        nb_classes = 120
     elif config.DATA.DATASET == 'nabirds':
         root = './datasets/nabirds'
         dataset = DatasetMeta(root=root,transform=transform,train=is_train,aux_info=config.DATA.ADD_META,dataset=config.DATA.DATASET)
-        nb_classes = 555
     elif config.DATA.DATASET == 'aircraft':
         root = './datasets/aircraft'
         dataset = DatasetMeta(root=root,transform=transform,train=is_train,aux_info=config.DATA.ADD_META,dataset=config.DATA.DATASET)
-        nb_classes = 100
     else:
-        raise NotImplementedError
+        root = config.DATA.DATASET_ROOT
+        dataset = DatasetMeta(root=root,transform=transform,train=is_train,aux_info=config.DATA.ADD_META,dataset=config.DATA.DATASET)
 
+    nb_classes = len(dataset.class_to_idx)
     return dataset, nb_classes
 
 
@@ -153,14 +144,14 @@ def build_transform(is_train, config):
     if config.TEST.CROP:
         size = int((256 / 224) * config.DATA.IMG_SIZE)
         t.append(
-            transforms.Resize(size, interpolation=_pil_interp(config.DATA.INTERPOLATION)),
+            transforms.Resize(size, interpolation=str_to_interp_mode(config.DATA.INTERPOLATION)),
            # to maintain same ratio w.r.t. 224 images
        )
        t.append(transforms.CenterCrop(config.DATA.IMG_SIZE))
    else:
        t.append(
            transforms.Resize((config.DATA.IMG_SIZE, config.DATA.IMG_SIZE),
-                              interpolation=_pil_interp(config.DATA.INTERPOLATION))
+                              interpolation=str_to_interp_mode(config.DATA.INTERPOLATION))
        )
 
    t.append(transforms.ToTensor())
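With nb_classes now computed from dataset.class_to_idx, any dataset name outside the hard-coded elif chain falls through to a generic DatasetMeta rooted at DATA.DATASET_ROOT. A sketch of driving that branch through the yacs config (the root path is made up):

    from config import _C  # the yacs defaults extended above

    cfg = _C.clone()
    cfg.defrost()
    cfg.DATA.DATASET = 'coco_generic'              # not in the elif chain, so the generic branch runs
    cfg.DATA.DATASET_ROOT = './datasets/my_birds'  # hypothetical root holding train.json / val.json
    cfg.freeze()
    # build_dataset(is_train=True, config=cfg) now returns (dataset, len(dataset.class_to_idx))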
data/dataset_fg.py
CHANGED
@@ -10,6 +10,7 @@ import pickle
 import numpy as np
 import pandas as pd
 import random
+from tqdm import tqdm
 random.seed(2021)
 from PIL import Image
 from scipy import io as scio
@@ -335,7 +336,7 @@ def find_images_and_targets_2017_2018(root,dataset,istrain=False,aux_info=False)
         else:
             images_and_targets.append((file_path,target))
     return images_and_targets,class_to_idx,images_info
-def find_images_and_targets(root,istrain=False,aux_info=False):
+def find_images_and_targets(root,istrain=False,aux_info=False, integrity_check=False):
     if os.path.exists(os.path.join(root,'train.json')):
         with open(os.path.join(root,'train.json'),'r') as f:
             train_class_info = json.load(f)
@@ -343,24 +344,59 @@ def find_images_and_targets(root,istrain=False,aux_info=False):
         with open(os.path.join(root,'train_mini.json'),'r') as f:
             train_class_info = json.load(f)
     else:
-        raise ValueError(f'
+        raise ValueError(f'{root}/train.json or {root}/train_mini.json doesn\'t exist')
+
     with open(os.path.join(root,'val.json'),'r') as f:
         val_class_info = json.load(f)
+
+    categories = [x['name'].strip().lower() for x in val_class_info['categories']]
+    class_to_idx = {c: idx for idx, c in enumerate(categories)}
     id2label = dict()
     for categorie in train_class_info['categories']:
         id2label[int(categorie['id'])] = categorie['name'].strip().lower()
     class_info = train_class_info if istrain else val_class_info
+    image_subdir = "train" if istrain else "val"
     images_and_targets = []
     images_info = []
     if aux_info:
         temporal_info = []
         spatial_info = []
 
+    ann2im = {}
+    for ann in class_info['annotations']:
+        ann2im[ann['id']] = ann['image_id']
+
+    ims = {}
+    for image in class_info['images']:
+        ims[image['id']] = image
+
+    print("Found", len(train_class_info['categories']), "categories")
+    print("Loading images and targets, checking image integrity")
+
+    for annotation in tqdm(class_info['annotations']):
+
+        image = ims[annotation['image_id']]
+        dir = train_class_info['categories'][annotation['category_id']]['image_dir_name']
+
+        file_path = os.path.join(root,image_subdir,dir,image['file_name'])
+
+        if not os.path.exists(file_path):
+            continue
+            # unreachable while the continue above is in place: on-demand
+            # re-download of missing files from their iNaturalist URL
+            print(f"Download {file_path}")
+            os.makedirs(os.path.dirname(file_path), exist_ok=True)
+            import requests
+            with open(file_path, 'wb') as fp:
+                fp.write(requests.get(image['inaturalist_url']).content)
+
+        if integrity_check:
+            try:
+                _ = np.array(Image.open(file_path))
+            except:
+                print(f"Failed to open {file_path}")
+                continue
+
         id_name = id2label[int(annotation['category_id'])]
         target = class_to_idx[id_name]
         date = image['date']
@@ -389,13 +425,15 @@ class DatasetMeta(data.Dataset):
                  transform=None,
                  train=False,
                  aux_info=False,
-                 dataset='inaturelist2021',
+                 dataset='coco_generic',
                  class_ratio=1.0,
                  per_sample=1.0):
         self.aux_info = aux_info
         self.dataset = dataset
         if dataset in ['inaturelist2021','inaturelist2021_mini']:
             images, class_to_idx,images_info = find_images_and_targets(root,train,aux_info)
+        elif dataset in ['coco_generic']:
+            images, class_to_idx,images_info = find_images_and_targets(root,train,aux_info)
         elif dataset in ['inaturelist2017','inaturelist2018']:
             images, class_to_idx,images_info = find_images_and_targets_2017_2018(root,dataset,train,aux_info)
         elif dataset == 'cub-200':
@@ -427,7 +465,12 @@ class DatasetMeta(data.Dataset):
             path, target,aux_info = self.samples[index]
         else:
             path, target = self.samples[index]
-        img = open(path, 'rb').read() if self.load_bytes else Image.open(path).convert('RGB')
+
+        try:
+            img = open(path, 'rb').read() if self.load_bytes else Image.open(path).convert('RGB')
+        except:
+            # corrupt or missing file: fall back to a black dummy image
+            img = Image.fromarray(np.zeros((224,224,3), dtype=np.uint8))
+
         if self.transform is not None:
             img = self.transform(img)
         if self.aux_info:
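The rewritten find_images_and_targets walks an iNaturalist-style COCO json. For orientation, a sketch of the minimal structure it expects, with field names taken from the code above and made-up values:

    # train.json / val.json skeleton consumed by find_images_and_targets
    example = {
        "categories": [
            {"id": 0, "name": "Parus major", "image_dir_name": "00000_parus_major"},
        ],
        "images": [
            {"id": 1, "file_name": "a1b2c3.jpg", "date": "2021-06-01 10:00:00+00:00",
             "inaturalist_url": "https://static.inaturalist.org/photos/1/medium.jpg"},
        ],
        "annotations": [
            {"id": 1, "image_id": 1, "category_id": 0},
        ],
    }
    # image files are resolved as <root>/<train|val>/<image_dir_name>/<file_name>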
inference.py
CHANGED
@@ -7,6 +7,10 @@ from torch.autograd import Variable
 from torchvision.transforms import transforms
 import numpy as np
 import argparse
+from pycocotools.coco import COCO
+import requests
+import os
+from tqdm.auto import tqdm
 
 try:
     from apex import amp
@@ -34,24 +38,32 @@ def read_class_names(file_path):
     class_list = []
 
     for l in lines:
-        line = l.strip()
+        line = l.strip()
         # class_list.append(line[0])
-        class_list.append(line
+        class_list.append(line)
 
     classes = tuple(class_list)
     return classes
 
 
+def read_class_names_coco(file_path):
+    dataset = COCO(file_path)
+    classes = [dataset.cats[k]['name'] for k in sorted(dataset.cats.keys())]
+
+    with open("names.txt", 'w') as fp:
+        for c in classes:
+            fp.write(f"{c}\n")
 
+    return classes
+
+class GenerateEmbedding:
+    def __init__(self):
         self.tokenizer = AutoTokenizer.from_pretrained("bert-base-uncased")
         self.model = AutoModel.from_pretrained("bert-base-uncased")
 
-    def generate(self):
+    def generate(self, text_file):
         text_list = []
-        with open(
+        with open(text_file, 'r') as f_text:
             for line in f_text:
                 line = line.encode(encoding='UTF-8', errors='strict')
                 line = line.replace(b'\xef\xbf\xbd\xef\xbf\xbd', b' ')
@@ -69,57 +81,122 @@ class GenerateEmbedding:
 
 
 class Inference:
-    def __init__(self, config_path, model_path):
+    def __init__(self, config_path, model_path, names_path):
+
         self.config_path = config_path
         self.model_path = model_path
         self.device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
-
-        self.classes = read_class_names(r"D:\dataset\CUB_200_2011\CUB_200_2011\classes_custom.txt")
+        self.classes = read_class_names(names_path)
 
         self.config = model_config(self.config_path)
+
         self.model = build_model(self.config)
         self.checkpoint = torch.load(self.model_path, map_location='cpu')
-        self.model.load_state_dict(self.checkpoint['model'], strict=False)
+
+        if 'model' in self.checkpoint:
+            self.model.load_state_dict(self.checkpoint['model'], strict=False)
+        else:
+            self.model.load_state_dict(self.checkpoint, strict=False)
+
         self.model.eval()
-        self.model.
+        self.model.to(self.device)
+        self.topk = 10
+        self.embedding_gen = GenerateEmbedding()
 
         self.transform_img = transforms.Compose([
-            transforms.Resize((
+            transforms.Resize((self.config.DATA.IMG_SIZE, self.config.DATA.IMG_SIZE), interpolation=Image.BILINEAR),
             transforms.ToTensor(), # transforms.Normalize([0.5, 0.5, 0.5], [0.5, 0.5, 0.5])
             transforms.Normalize(IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD)
         ])
 
-    def infer(self, img_path, meta_data_path):
+    def infer(self, img_path, meta_data_path, topk=None):
+
+        if isinstance(img_path, str):
+            if img_path.startswith("http"):
+                img = Image.open(requests.get(img_path, stream=True).raw).convert('RGB')
+            else:
+                img = Image.open(img_path).convert('RGB')
+        else:
+            img = img_path
+
+        """
+        _, _, meta = self.embedding_gen(meta_data_path)
+        meta = meta.to(self.device)
+        """
+        meta = None
+
         img = self.transform_img(img)
         img.unsqueeze_(0)
-        img = img.
+        img = img.to(self.device)
         img = Variable(img).to(self.device)
         out = self.model(img, meta)
 
+        f = torch.nn.Softmax(dim=1)
+        y_pred = f(out)
+        # descending class indices by score (list() so the top-k slice below works)
+        indices = list(reversed(torch.argsort(y_pred, dim=1).squeeze().tolist()))
+
+        if topk is not None:
+            predict = [{self.classes[idx] : y_pred.squeeze()[idx].cpu().item() for idx in indices[:topk]}]
+            return predict
+        else:
+            return {self.classes[idx] : y_pred.squeeze()[idx].cpu().item() for idx in indices}
 
 
 def parse_option():
     parser = argparse.ArgumentParser('MetaFG Inference script', add_help=False)
-    parser.add_argument('--cfg', type=str,
+    parser.add_argument('--cfg', type=str, metavar="FILE", help='path to config file', default="configs/MetaFG_2_224.yaml")
     # easy config modification
-    parser.add_argument('--model-path',
-    parser.add_argument('--img-path',
-    parser.add_argument('--
+    parser.add_argument('--model-path', type=str, help="path to model data", default="ckpt_epoch_12.pth")
+    parser.add_argument('--img-path', type=str, help='path to image')
+    parser.add_argument('--img-folder', type=str, help='path to folder of images')
+    parser.add_argument('--meta-path', default="meta.txt", type=str, help='path to meta data')
+    parser.add_argument('--names-path', default="names_mf2.txt", type=str, help='path to class names file')
     args = parser.parse_args()
     return args
 
 
 if __name__ == '__main__':
     args = parse_option()
-    model = Inference(config_path=args.cfg,
-                      model_path=args.model_path)
+    model = Inference(config_path=args.cfg,
+                      model_path=args.model_path,
+                      names_path=args.names_path)
+
+    from glob import glob
+    glob_imgs = glob(os.path.join(args.img_folder, "*.jpg"))
+    out_dir = f"results_{os.path.splitext(os.path.basename(args.model_path))[0]}"
+    os.makedirs(out_dir, exist_ok=True)
+
+    for img in tqdm(glob_imgs):
+        try:
+            res = model.infer(img_path=img, meta_data_path=args.meta_path)
+        except KeyboardInterrupt:
+            break
+        except Exception as e:
+            print(e)
+            continue
+
+        out = {}
+        out['preds'] = res
+
+        """
+        # Out is a list of (class, score). Return true/false if the top1 class is correct
+        out['top1_correct'] = '_'.join(res[0][1].split(' ')).lower() in os.path.basename(img).lower()
+
+        out['top5_correct'] = False
+        print(os.path.basename(img).lower())
+        for i in range(5):
+            out['top5_correct'] |= '_'.join(res[i][1].split(' ')).lower() in os.path.basename(img).lower()
+            print('_'.join(res[i][1].split(' ')).lower())
+
+        out['top10_correct'] = False
+        for i in range(10):
+            out['top10_correct'] |= '_'.join(res[i][1].split(' ')).lower() in os.path.basename(img).lower()
+        """
+
+        # output json with inference results, use image basename as filename
+        import json
+        with open(os.path.join(out_dir, os.path.splitext(os.path.basename(img))[0]+".json"), 'w') as fp:
+            json.dump(out, fp, indent=1)
 
 # Usage: python inference.py --cfg 'path/to/cfg' --model-path 'path/to/model' --img-folder 'path/to/imgs' --meta-path 'path/to/meta'
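A minimal programmatic use of the reworked Inference class, assuming the default config, checkpoint and names files from the arguments above are present (the image path is hypothetical):

    from inference import Inference

    model = Inference(config_path="configs/MetaFG_2_224.yaml",
                      model_path="ckpt_epoch_12.pth",
                      names_path="names_mf2.txt")

    # full {class_name: probability} mapping, best class first
    scores = model.infer(img_path="example_images/sample.jpg", meta_data_path="meta.txt")

    # or only the 5 best, returned as a single-element list wrapping the mapping
    top5 = model.infer(img_path="example_images/sample.jpg", meta_data_path="meta.txt", topk=5)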
lr_scheduler.py
CHANGED
@@ -21,7 +21,6 @@ def build_scheduler(config, optimizer, n_iter_per_epoch):
     lr_scheduler = CosineLRScheduler(
         optimizer,
         t_initial=num_steps,
-        t_mul=1.,
         lr_min=config.TRAIN.MIN_LR,
         warmup_lr_init=config.TRAIN.WARMUP_LR,
         warmup_t=warmup_steps,
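Dropping t_mul=1. tracks timm's scheduler refactor: recent timm releases replaced t_mul with cycle_mul, whose default of 1.0 matches the removed value, so the schedule is unchanged. A runnable sketch against the newer signature (dummy optimizer; worth re-checking the kwargs against the pinned timm version):

    import torch
    from timm.scheduler.cosine_lr import CosineLRScheduler

    opt = torch.optim.SGD([torch.nn.Parameter(torch.zeros(1))], lr=0.1)
    # cycle_mul (formerly t_mul) defaults to 1.0, i.e. equal-length cosine cycles
    sched = CosineLRScheduler(opt, t_initial=1000, lr_min=1e-5,
                              warmup_lr_init=1e-6, warmup_t=100)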
main.py
CHANGED
@@ -2,7 +2,9 @@ import os
 import time
 import argparse
 import datetime
+import json
 import numpy as np
+from collections import defaultdict
 
 import torch
 import torch.backends.cudnn as cudnn
@@ -18,13 +20,23 @@ from lr_scheduler import build_scheduler
 from optimizer import build_optimizer
 from logger import create_logger
 from utils import load_checkpoint, save_checkpoint, get_grad_norm, auto_resume_helper, reduce_tensor,load_pretained
 
+have_wandb = False
+try:
+    import wandb
+    have_wandb = True
+except:
+    pass
+
+# TODO use torch.amp
 try:
     # noinspection PyUnresolvedReferences
     from apex import amp
 except ImportError:
     amp = None
 
+import logging
+logging.basicConfig(level=logging.INFO)
 
 def parse_option():
     parser = argparse.ArgumentParser('MetaFG training and evaluation script', add_help=False)
@@ -77,20 +89,19 @@ def parse_option():
                         help='dataset')
     parser.add_argument('--lr-scheduler-name', type=str,
                         help='lr scheduler name,cosin linear,step')
 
     parser.add_argument('--pretrain', type=str,
                         help='pretrain')
 
-    parser.add_argument('--
-
+    parser.add_argument('--wandb_job', type=str)
 
-    # distributed training
-    parser.add_argument("--local_rank", type=int, required=True, help='local rank for DistributedDataParallel')
-
     args, unparsed = parser.parse_known_args()
 
     config = get_config(args)
 
+    if have_wandb and int(config.LOCAL_RANK) == 0:
+        wandb.init(name=args.wandb_job, config=args)
+
     return args, config
 
 
@@ -98,14 +109,20 @@ def main(config):
     dataset_train, dataset_val, data_loader_train, data_loader_val, mixup_fn = build_loader(config)
     logger.info(f"Creating model:{config.MODEL.TYPE}/{config.MODEL.NAME}")
     model = build_model(config)
+    if have_wandb and int(config.LOCAL_RANK) == 0:
+        wandb.config['model_config'] = config
     model.cuda()
     logger.info(str(model))
 
     optimizer = build_optimizer(config, model)
     if config.AMP_OPT_LEVEL != "O0":
         model, optimizer = amp.initialize(model, optimizer, opt_level=config.AMP_OPT_LEVEL)
-    model = torch.nn.parallel.DistributedDataParallel(model, device_ids=[config.LOCAL_RANK], broadcast_buffers=False)
+    model = torch.nn.parallel.DistributedDataParallel(model, device_ids=[int(config.LOCAL_RANK)], broadcast_buffers=False)
     model_without_ddp = model.module
+
+    if have_wandb and int(config.LOCAL_RANK) == 0:
+        wandb.watch(model, log_freq=100)
+
     n_parameters = sum(p.numel() for p in model.parameters() if p.requires_grad)
     logger.info(f"number of params: {n_parameters}")
     if hasattr(model_without_ddp, 'flops'):
@@ -123,10 +140,15 @@ def main(config):
     max_accuracy = 0.0
     if config.MODEL.PRETRAINED:
         load_pretained(config,model_without_ddp,logger)
+
+    # Run initial validation
+    logger.info("Start validation (on init)")
+    acc1, acc5, loss, stats = validate(config, data_loader_val, model, limit=10)
+
+    with open(os.path.join(config.OUTPUT, f'val_init.json'), 'w') as fp:
+        json.dump(stats, fp, indent=1)
+
+    logger.info(f"Accuracy of the network on the {len(dataset_val)} test images: {acc1:.1f}%")
 
     if config.TRAIN.AUTO_RESUME:
         resume_file = auto_resume_helper(config.OUTPUT)
@@ -143,11 +165,11 @@ def main(config):
     if config.MODEL.RESUME:
         logger.info(f"**********normal test***********")
         max_accuracy = load_checkpoint(config, model_without_ddp, optimizer, lr_scheduler, logger)
-        acc1, acc5, loss = validate(config, data_loader_val, model)
+        acc1, acc5, loss, stats = validate(config, data_loader_val, model)
         logger.info(f"Accuracy of the network on the {len(dataset_val)} test images: {acc1:.1f}%")
         if config.DATA.ADD_META:
             logger.info(f"**********mask meta test***********")
-            acc1, acc5, loss = validate(config, data_loader_val, model,mask_meta=True)
+            acc1, acc5, loss, stats = validate(config, data_loader_val, model,mask_meta=True)
             logger.info(f"Accuracy of the network on the {len(dataset_val)} test images: {acc1:.1f}%")
         if config.EVAL_MODE:
             return
@@ -165,18 +187,37 @@ def main(config):
             save_checkpoint(config, epoch, model_without_ddp, max_accuracy, optimizer, lr_scheduler, logger)
 
         logger.info(f"**********normal test***********")
-        acc1, acc5, loss = validate(config, data_loader_val, model)
+        acc1, acc5, loss, stats = validate(config, data_loader_val, model)
+
+        with open(os.path.join(config.OUTPUT, f'val_{epoch}.json'), 'w') as fp:
+            json.dump(stats, fp, indent=1)
+
         logger.info(f"Accuracy of the network on the {len(dataset_val)} test images: {acc1:.1f}%")
         max_accuracy = max(max_accuracy, acc1)
         logger.info(f'Max accuracy: {max_accuracy:.2f}%')
         if config.DATA.ADD_META:
             logger.info(f"**********mask meta test***********")
-            acc1, acc5, loss = validate(config, data_loader_val, model,mask_meta=True)
+            acc1, acc5, loss, stats = validate(config, data_loader_val, model,mask_meta=True)
+
+            with open(os.path.join(config.OUTPUT, f'val_{epoch}_meta.json'), 'w') as fp:
+                json.dump(stats, fp, indent=1)
+
             logger.info(f"Accuracy of the network on the {len(dataset_val)} test images: {acc1:.1f}%")
     # data_loader_train.terminate()
+
+    if have_wandb and int(config.LOCAL_RANK) == 0:
+        wandb.run.summary["acc_top_1"] = acc1
+        wandb.run.summary["acc_top_5"] = acc5
+        wandb.run.summary["val_loss"] = loss
+
+        wandb.log({'val/acc1': acc1})
+        wandb.log({'val/acc5': acc5})
+        wandb.log({'val/loss': loss})
+
     total_time = time.time() - start_time
     total_time_str = str(datetime.timedelta(seconds=int(total_time)))
     logger.info('Training time {}'.format(total_time_str))
+
 def train_one_epoch_local_data(config, model, criterion, data_loader, optimizer, epoch, mixup_fn, lr_scheduler,tb_logger=None):
     model.train()
     if hasattr(model.module,'cur_epoch'):
@@ -261,6 +302,8 @@ def train_one_epoch_local_data(config, model, criterion, data_loader, optimizer,
             lr = optimizer.param_groups[0]['lr']
             memory_used = torch.cuda.max_memory_allocated() / (1024.0 * 1024.0)
             etas = batch_time.avg * (num_steps - idx)
+            if have_wandb and int(config.LOCAL_RANK) == 0 and idx % 100 == 0:
+                wandb.log({"train/loss": loss_meter.val})
             logger.info(
                 f'Train: [{epoch}/{config.TRAIN.EPOCHS}][{idx}/{num_steps}]\t'
                 f'eta {datetime.timedelta(seconds=int(etas))} lr {lr:.6f}\t'
@@ -271,7 +314,7 @@ def train_one_epoch_local_data(config, model, criterion, data_loader, optimizer,
     epoch_time = time.time() - start
     logger.info(f"EPOCH {epoch} training takes {datetime.timedelta(seconds=int(epoch_time))}")
 @torch.no_grad()
-def validate(config, data_loader, model, mask_meta=False):
+def validate(config, data_loader, model, mask_meta=False, limit=None):
     criterion = torch.nn.CrossEntropyLoss()
     model.eval()
 
@@ -280,8 +323,16 @@ def validate(config, data_loader, model, mask_meta=False):
     acc1_meter = AverageMeter()
     acc5_meter = AverageMeter()
 
+    stats = defaultdict(list)
+
     end = time.time()
+
     for idx, data in enumerate(data_loader):
+
+        if limit:
+            if idx > limit:
+                break
+
         if config.DATA.ADD_META:
             images,target,meta = data
             meta = [m.float() for m in meta]
@@ -314,6 +365,9 @@ def validate(config, data_loader, model, mask_meta=False):
         acc1_meter.update(acc1.item(), target.size(0))
         acc5_meter.update(acc5.item(), target.size(0))
 
+        for t in target:
+            stats[int(t.item())].append((acc1.item(), acc5.item(), loss.item()))
+
         # measure elapsed time
         batch_time.update(time.time() - end)
         end = time.time()
@@ -328,7 +382,8 @@ def validate(config, data_loader, model, mask_meta=False):
         f'Acc@5 {acc5_meter.val:.3f} ({acc5_meter.avg:.3f})\t'
         f'Mem {memory_used:.0f}MB')
     logger.info(f' * Acc@1 {acc1_meter.avg:.3f} Acc@5 {acc5_meter.avg:.3f}')
-    return acc1_meter.avg, acc5_meter.avg, loss_meter.avg
+
+    return acc1_meter.avg, acc5_meter.avg, loss_meter.avg, stats
 
 
 @torch.no_grad()
@@ -364,7 +419,7 @@ if __name__ == '__main__':
     else:
         rank = -1
         world_size = -1
-    torch.cuda.set_device(config.LOCAL_RANK)
+    torch.cuda.set_device(f'cuda:{config.LOCAL_RANK}')
     torch.distributed.init_process_group(backend='nccl', init_method='env://', world_size=world_size, rank=rank)
     torch.distributed.barrier()
models/MetaFG.py
CHANGED
@@ -54,7 +54,8 @@ class MetaFG(nn.Module):
                  qkv_bias=False, qk_scale=None, drop_rate=0., attn_drop_rate=0.,drop_path_rate=0.,
                  meta_dims=[],
                  only_last_cls=False,
-                 use_checkpoint=False):
+                 use_checkpoint=False,
+                 **kwargs):
         super().__init__()
         self.only_last_cls = only_last_cls
         self.img_size = img_size