Fabrice-TIERCELIN commited on
Commit
8e10f60
·
verified ·
1 Parent(s): 91e2303

Delete clipseg/score.py

Browse files
Files changed (1) hide show
  1. clipseg/score.py +0 -453
clipseg/score.py DELETED
@@ -1,453 +0,0 @@
1
- from torch.functional import Tensor
2
-
3
- import torch
4
- import inspect
5
- import json
6
- import yaml
7
- import time
8
- import sys
9
-
10
- from general_utils import log
11
-
12
- import numpy as np
13
- from os.path import expanduser, join, isfile, realpath
14
-
15
- from torch.utils.data import DataLoader
16
-
17
- from metrics import FixedIntervalMetrics
18
-
19
- from general_utils import load_model, log, score_config_from_cli_args, AttributeDict, get_attribute, filter_args
20
-
21
-
22
- DATASET_CACHE = dict()
23
-
24
- def load_model(checkpoint_id, weights_file=None, strict=True, model_args='from_config', with_config=False, ignore_weights=False):
25
-
26
- config = json.load(open(join('logs', checkpoint_id, 'config.json')))
27
-
28
- if model_args != 'from_config' and type(model_args) != dict:
29
- raise ValueError('model_args must either be "from_config" or a dictionary of values')
30
-
31
- model_cls = get_attribute(config['model'])
32
-
33
- # load model
34
- if model_args == 'from_config':
35
- _, model_args, _ = filter_args(config, inspect.signature(model_cls).parameters)
36
-
37
- model = model_cls(**model_args)
38
-
39
- if weights_file is None:
40
- weights_file = realpath(join('logs', checkpoint_id, 'weights.pth'))
41
- else:
42
- weights_file = realpath(join('logs', checkpoint_id, weights_file))
43
-
44
- if isfile(weights_file) and not ignore_weights:
45
- weights = torch.load(weights_file)
46
- for _, w in weights.items():
47
- assert not torch.any(torch.isnan(w)), 'weights contain NaNs'
48
- model.load_state_dict(weights, strict=strict)
49
- else:
50
- if not ignore_weights:
51
- raise FileNotFoundError(f'model checkpoint {weights_file} was not found')
52
-
53
- if with_config:
54
- return model, config
55
-
56
- return model
57
-
58
-
59
- def compute_shift2(model, datasets, seed=123, repetitions=1):
60
- """ computes shift """
61
-
62
- model.eval()
63
- model.cuda()
64
-
65
- import random
66
- random.seed(seed)
67
-
68
- preds, gts = [], []
69
- for i_dataset, dataset in enumerate(datasets):
70
-
71
- loader = DataLoader(dataset, batch_size=1, num_workers=0, shuffle=False, drop_last=False)
72
-
73
- max_iterations = int(repetitions * len(dataset.dataset.data_list))
74
-
75
- with torch.no_grad():
76
-
77
- i, losses = 0, []
78
- for i_all, (data_x, data_y) in enumerate(loader):
79
-
80
- data_x = [v.cuda(non_blocking=True) if v is not None else v for v in data_x]
81
- data_y = [v.cuda(non_blocking=True) if v is not None else v for v in data_y]
82
-
83
- pred, = model(data_x[0], data_x[1], data_x[2])
84
- preds += [pred.detach()]
85
- gts += [data_y]
86
-
87
- i += 1
88
- if max_iterations and i >= max_iterations:
89
- break
90
-
91
- from metrics import FixedIntervalMetrics
92
- n_values = 51
93
- thresholds = np.linspace(0, 1, n_values)[1:-1]
94
- metric = FixedIntervalMetrics(resize_pred=True, sigmoid=True, n_values=n_values)
95
-
96
- for p, y in zip(preds, gts):
97
- metric.add(p.unsqueeze(1), y)
98
-
99
- best_idx = np.argmax(metric.value()['fgiou_scores'])
100
- best_thresh = thresholds[best_idx]
101
-
102
- return best_thresh
103
-
104
-
105
- def get_cached_pascal_pfe(split, config):
106
- from datasets.pfe_dataset import PFEPascalWrapper
107
- try:
108
- dataset = DATASET_CACHE[(split, config.image_size, config.label_support, config.mask)]
109
- except KeyError:
110
- dataset = PFEPascalWrapper(mode='val', split=split, mask=config.mask, image_size=config.image_size, label_support=config.label_support)
111
- DATASET_CACHE[(split, config.image_size, config.label_support, config.mask)] = dataset
112
- return dataset
113
-
114
-
115
-
116
-
117
- def main():
118
- config, train_checkpoint_id = score_config_from_cli_args()
119
-
120
- metrics = score(config, train_checkpoint_id, None)
121
-
122
- for dataset in metrics.keys():
123
- for k in metrics[dataset]:
124
- if type(metrics[dataset][k]) in {float, int}:
125
- print(dataset, f'{k:<16} {metrics[dataset][k]:.3f}')
126
-
127
-
128
- def score(config, train_checkpoint_id, train_config):
129
-
130
- config = AttributeDict(config)
131
-
132
- print(config)
133
-
134
- # use training dataset and loss
135
- train_config = AttributeDict(json.load(open(f'logs/{train_checkpoint_id}/config.json')))
136
-
137
- cp_str = f'_{config.iteration_cp}' if config.iteration_cp is not None else ''
138
-
139
-
140
- model_cls = get_attribute(train_config['model'])
141
-
142
- _, model_args, _ = filter_args(train_config, inspect.signature(model_cls).parameters)
143
-
144
- model_args = {**model_args, **{k: config[k] for k in ['process_cond', 'fix_shift'] if k in config}}
145
-
146
- strict_models = {'ConditionBase4', 'PFENetWrapper'}
147
- model = load_model(train_checkpoint_id, strict=model_cls.__name__ in strict_models, model_args=model_args,
148
- weights_file=f'weights{cp_str}.pth', )
149
-
150
-
151
- model.eval()
152
- model.cuda()
153
-
154
- metric_args = dict()
155
-
156
- if 'threshold' in config:
157
- if config.metric.split('.')[-1] == 'SkLearnMetrics':
158
- metric_args['threshold'] = config.threshold
159
-
160
- if 'resize_to' in config:
161
- metric_args['resize_to'] = config.resize_to
162
-
163
- if 'sigmoid' in config:
164
- metric_args['sigmoid'] = config.sigmoid
165
-
166
- if 'custom_threshold' in config:
167
- metric_args['custom_threshold'] = config.custom_threshold
168
-
169
- if config.test_dataset == 'pascal':
170
-
171
- loss_fn = get_attribute(train_config.loss)
172
- # assume that if no split is specified in train_config, test on all splits,
173
-
174
- if 'splits' in config:
175
- splits = config.splits
176
- else:
177
- if 'split' in train_config and type(train_config.split) == int:
178
- # unless train_config has a split set, in that case assume train mode in training
179
- splits = [train_config.split]
180
- assert train_config.mode == 'train'
181
- else:
182
- splits = [0,1,2,3]
183
-
184
- log.info('Test on these splits', splits)
185
-
186
- scores = dict()
187
- for split in splits:
188
-
189
- shift = config.shift if 'shift' in config else 0
190
-
191
- # automatic shift
192
- if shift == 'auto':
193
- shift_compute_t = time.time()
194
- shift = compute_shift2(model, [get_cached_pascal_pfe(s, config) for s in range(4) if s != split], repetitions=config.compute_shift_fac)
195
- log.info(f'Best threshold is {shift}, computed on splits: {[s for s in range(4) if s != split]}, took {time.time() - shift_compute_t:.1f}s')
196
-
197
- dataset = get_cached_pascal_pfe(split, config)
198
-
199
- eval_start_t = time.time()
200
-
201
- loader = DataLoader(dataset, batch_size=1, num_workers=0, shuffle=False, drop_last=False)
202
-
203
- assert config.batch_size is None or config.batch_size == 1, 'When PFE Dataset is used, batch size must be 1'
204
-
205
- metric = FixedIntervalMetrics(resize_pred=True, sigmoid=True, custom_threshold=shift, **metric_args)
206
-
207
- with torch.no_grad():
208
-
209
- i, losses = 0, []
210
- for i_all, (data_x, data_y) in enumerate(loader):
211
-
212
- data_x = [v.cuda(non_blocking=True) if isinstance(v, torch.Tensor) else v for v in data_x]
213
- data_y = [v.cuda(non_blocking=True) if isinstance(v, torch.Tensor) else v for v in data_y]
214
-
215
- if config.mask == 'separate': # for old CondBase model
216
- pred, = model(data_x[0], data_x[1], data_x[2])
217
- else:
218
- # assert config.mask in {'text', 'highlight'}
219
- pred, _, _, _ = model(data_x[0], data_x[1], return_features=True)
220
-
221
- # loss = loss_fn(pred, data_y[0])
222
- metric.add(pred.unsqueeze(1) + shift, data_y)
223
-
224
- # losses += [float(loss)]
225
-
226
- i += 1
227
- if config.max_iterations and i >= config.max_iterations:
228
- break
229
-
230
- #scores[split] = {m: s for m, s in zip(metric.names(), metric.value())}
231
-
232
- log.info(f'Dataset length: {len(dataset)}, took {time.time() - eval_start_t:.1f}s to evaluate.')
233
-
234
- print(metric.value()['mean_iou_scores'])
235
-
236
- scores[split] = metric.scores()
237
-
238
- log.info(f'Completed split {split}')
239
-
240
- key_prefix = config['name'] if 'name' in config else 'pas'
241
-
242
- all_keys = set.intersection(*[set(v.keys()) for v in scores.values()])
243
-
244
- valid_keys = [k for k in all_keys if all(v[k] is not None and isinstance(v[k], (int, float, np.float)) for v in scores.values())]
245
-
246
- return {key_prefix: {k: np.mean([s[k] for s in scores.values()]) for k in valid_keys}}
247
-
248
-
249
- if config.test_dataset == 'coco':
250
- from datasets.coco_wrapper import COCOWrapper
251
-
252
- coco_dataset = COCOWrapper('test', fold=train_config.fold, image_size=train_config.image_size, mask=config.mask,
253
- with_class_label=True)
254
-
255
- log.info('Dataset length', len(coco_dataset))
256
- loader = DataLoader(coco_dataset, batch_size=config.batch_size, num_workers=2, shuffle=False, drop_last=False)
257
-
258
- metric = get_attribute(config.metric)(resize_pred=True, **metric_args)
259
-
260
- shift = config.shift if 'shift' in config else 0
261
-
262
- with torch.no_grad():
263
-
264
- i, losses = 0, []
265
- for i_all, (data_x, data_y) in enumerate(loader):
266
- data_x = [v.cuda(non_blocking=True) if isinstance(v, torch.Tensor) else v for v in data_x]
267
- data_y = [v.cuda(non_blocking=True) if isinstance(v, torch.Tensor) else v for v in data_y]
268
-
269
- if config.mask == 'separate': # for old CondBase model
270
- pred, = model(data_x[0], data_x[1], data_x[2])
271
- else:
272
- # assert config.mask in {'text', 'highlight'}
273
- pred, _, _, _ = model(data_x[0], data_x[1], return_features=True)
274
-
275
- metric.add([pred + shift], data_y)
276
-
277
- i += 1
278
- if config.max_iterations and i >= config.max_iterations:
279
- break
280
-
281
- key_prefix = config['name'] if 'name' in config else 'coco'
282
- return {key_prefix: metric.scores()}
283
- #return {key_prefix: {k: v for k, v in zip(metric.names(), metric.value())}}
284
-
285
-
286
- if config.test_dataset == 'phrasecut':
287
- from datasets.phrasecut import PhraseCut
288
-
289
- only_visual = config.only_visual is not None and config.only_visual
290
- with_visual = config.with_visual is not None and config.with_visual
291
-
292
- dataset = PhraseCut('test',
293
- image_size=train_config.image_size,
294
- mask=config.mask,
295
- with_visual=with_visual, only_visual=only_visual, aug_crop=False,
296
- aug_color=False)
297
-
298
- loader = DataLoader(dataset, batch_size=config.batch_size, num_workers=2, shuffle=False, drop_last=False)
299
- metric = get_attribute(config.metric)(resize_pred=True, **metric_args)
300
-
301
- shift = config.shift if 'shift' in config else 0
302
-
303
-
304
- with torch.no_grad():
305
-
306
- i, losses = 0, []
307
- for i_all, (data_x, data_y) in enumerate(loader):
308
- data_x = [v.cuda(non_blocking=True) if isinstance(v, torch.Tensor) else v for v in data_x]
309
- data_y = [v.cuda(non_blocking=True) if isinstance(v, torch.Tensor) else v for v in data_y]
310
-
311
- pred, _, _, _ = model(data_x[0], data_x[1], return_features=True)
312
- metric.add([pred + shift], data_y)
313
-
314
- i += 1
315
- if config.max_iterations and i >= config.max_iterations:
316
- break
317
-
318
- key_prefix = config['name'] if 'name' in config else 'phrasecut'
319
- return {key_prefix: metric.scores()}
320
- #return {key_prefix: {k: v for k, v in zip(metric.names(), metric.value())}}
321
-
322
- if config.test_dataset == 'pascal_zs':
323
- from third_party.JoEm.model.metric import Evaluator
324
- from third_party.JoEm.data_loader import get_seen_idx, get_unseen_idx, VOC
325
- from datasets.pascal_zeroshot import PascalZeroShot, PASCAL_VOC_CLASSES_ZS
326
-
327
- from models.clipseg import CLIPSegMultiLabel
328
-
329
- n_unseen = train_config.remove_classes[1]
330
-
331
- pz = PascalZeroShot('val', n_unseen, image_size=352)
332
- m = CLIPSegMultiLabel(model=train_config.name).cuda()
333
- m.eval();
334
-
335
- print(len(pz), n_unseen)
336
- print('training removed', [c for class_set in PASCAL_VOC_CLASSES_ZS[:n_unseen // 2] for c in class_set])
337
-
338
- print('unseen', [VOC[i] for i in get_unseen_idx(n_unseen)])
339
- print('seen', [VOC[i] for i in get_seen_idx(n_unseen)])
340
-
341
- loader = DataLoader(pz, batch_size=8)
342
- evaluator = Evaluator(21, get_unseen_idx(n_unseen), get_seen_idx(n_unseen))
343
-
344
- for i, (data_x, data_y) in enumerate(loader):
345
- pred = m(data_x[0].cuda())
346
- evaluator.add_batch(data_y[0].numpy(), pred.argmax(1).cpu().detach().numpy())
347
-
348
- if config.max_iter is not None and i > config.max_iter:
349
- break
350
-
351
- scores = evaluator.Mean_Intersection_over_Union()
352
- key_prefix = config['name'] if 'name' in config else 'pas_zs'
353
-
354
- return {key_prefix: {k: scores[k] for k in ['seen', 'unseen', 'harmonic', 'overall']}}
355
-
356
- elif config.test_dataset in {'same_as_training', 'affordance'}:
357
- loss_fn = get_attribute(train_config.loss)
358
-
359
- metric_cls = get_attribute(config.metric)
360
- metric = metric_cls(**metric_args)
361
-
362
- if config.test_dataset == 'same_as_training':
363
- dataset_cls = get_attribute(train_config.dataset)
364
- elif config.test_dataset == 'affordance':
365
- dataset_cls = get_attribute('datasets.lvis_oneshot3.LVIS_Affordance')
366
- dataset_name = 'aff'
367
- else:
368
- dataset_cls = get_attribute('datasets.lvis_oneshot3.LVIS_OneShot')
369
- dataset_name = 'lvis'
370
-
371
- _, dataset_args, _ = filter_args(config, inspect.signature(dataset_cls).parameters)
372
-
373
- dataset_args['image_size'] = train_config.image_size # explicitly use training image size for evaluation
374
-
375
- if model.__class__.__name__ == 'PFENetWrapper':
376
- dataset_args['image_size'] = config.image_size
377
-
378
- log.info('init dataset', str(dataset_cls))
379
- dataset = dataset_cls(**dataset_args)
380
-
381
- log.info(f'Score on {model.__class__.__name__} on {dataset_cls.__name__}')
382
-
383
- data_loader = torch.utils.data.DataLoader(dataset, batch_size=config.batch_size, shuffle=config.shuffle)
384
-
385
- # explicitly set prompts
386
- if config.prompt == 'plain':
387
- model.prompt_list = ['{}']
388
- elif config.prompt == 'fixed':
389
- model.prompt_list = ['a photo of a {}.']
390
- elif config.prompt == 'shuffle':
391
- model.prompt_list = ['a photo of a {}.', 'a photograph of a {}.', 'an image of a {}.', '{}.']
392
- elif config.prompt == 'shuffle_clip':
393
- from models.clip_prompts import imagenet_templates
394
- model.prompt_list = imagenet_templates
395
-
396
- config.assume_no_unused_keys(exceptions=['max_iterations'])
397
-
398
- t_start = time.time()
399
-
400
- with torch.no_grad(): # TODO: switch to inference_mode (torch 1.9)
401
- i, losses = 0, []
402
- for data_x, data_y in data_loader:
403
-
404
- data_x = [x.cuda() if isinstance(x, torch.Tensor) else x for x in data_x]
405
- data_y = [x.cuda() if isinstance(x, torch.Tensor) else x for x in data_y]
406
-
407
- if model.__class__.__name__ in {'ConditionBase4', 'PFENetWrapper'}:
408
- pred, = model(data_x[0], data_x[1], data_x[2])
409
- visual_q = None
410
- else:
411
- pred, visual_q, _, _ = model(data_x[0], data_x[1], return_features=True)
412
-
413
- loss = loss_fn(pred, data_y[0])
414
-
415
- metric.add([pred], data_y)
416
-
417
- losses += [float(loss)]
418
-
419
- i += 1
420
- if config.max_iterations and i >= config.max_iterations:
421
- break
422
-
423
- # scores = {m: s for m, s in zip(metric.names(), metric.value())}
424
- scores = metric.scores()
425
-
426
- keys = set(scores.keys())
427
- if dataset.negative_prob > 0 and 'mIoU' in keys:
428
- keys.remove('mIoU')
429
-
430
- name_mask = dataset.mask.replace('text_label', 'txt')[:3]
431
- name_neg = '' if dataset.negative_prob == 0 else '_' + str(dataset.negative_prob)
432
-
433
- score_name = config.name if 'name' in config else f'{dataset_name}_{name_mask}{name_neg}'
434
-
435
- scores = {score_name: {k: v for k,v in scores.items() if k in keys}}
436
- scores[score_name].update({'test_loss': np.mean(losses)})
437
-
438
- log.info(f'Evaluation took {time.time() - t_start:.1f}s')
439
-
440
- return scores
441
- else:
442
- raise ValueError('invalid test dataset')
443
-
444
-
445
-
446
-
447
-
448
-
449
-
450
-
451
-
452
- if __name__ == '__main__':
453
- main()