# coding=utf-8
# Copyright 2024 The Google Research Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""Evaluate CaR on segmentation benchmarks."""
# pylint: disable=g-importing-member
import argparse

import numpy as np
import torch
from torch.utils import tensorboard
import torch.utils.data
from torch.utils.data import Subset
import torchvision.transforms as T

# pylint: disable=g-bad-import-order
from modeling.model.car import CaR
from sam.utils import build_sam_config
from utils.utils import Config
from utils.utils import load_yaml
from utils.utils import MetricLogger
from utils.utils import SmoothedValue
from utils.inference_pipeline import inference_car
from utils.merge_mask import merge_masks_simple

# Datasets
# pylint: disable=g-multiple-import
from data.ade import ADE_THING_CLASS, ADE_STUFF_CLASS, ADE_THING_CLASS_ID, ADE_STUFF_CLASS_ID, ADEDataset
from data.ade847 import ADE_847_THING_CLASS_ID, ADE_847_STUFF_CLASS_ID, ADE_847_THING_CLASS, ADE_847_STUFF_CLASS, ADE847Dataset
from data.coco import COCO_OBJECT_CLASSES, COCODataset
from data.context import PASCAL_CONTEXT_STUFF_CLASS_ID, PASCAL_CONTEXT_THING_CLASS_ID, PASCAL_CONTEXT_STUFF_CLASS, PASCAL_CONTEXT_THING_CLASS, CONTEXTDataset
from data.gres import GReferDataset
from data.pascal459 import PASCAL_459_THING_CLASS_ID, PASCAL_459_STUFF_CLASS_ID, PASCAL_459_THING_CLASS, PASCAL_459_STUFF_CLASS, Pascal459Dataset
from data.refcoco import ReferDataset
from data.voc import VOC_CLASSES, VOCDataset

IMAGE_WIDTH, IMAGE_HEIGHT = 512, 512

# Set random seeds for reproducibility.
torch.manual_seed(0)
np.random.seed(0)
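
# A minimal sketch of the config fields this script reads (field names appear
# in the code below; the values are illustrative, not shipped defaults):
#   test:
#     ds_name: voc           # any dataset handled by get_dataset()
#     split: val
#     seg_mode: semantic     # or 'refer'
#     algo: car
#     n_class: 21
#     ignore_background: true
#   car:
#     bg_factor: 1.0
#     visual_prompt_type: ...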

def get_dataset(cfg, ds_name, split, transform, data_root=None):
  """Build the evaluation dataset for the given dataset name."""
  data_args = dict(root=data_root) if data_root is not None else {}
  if 'refcoco' in ds_name:
    splitby = cfg.test.splitby if hasattr(cfg.test, 'splitby') else 'unc'
    ds = ReferDataset(
        dataset=ds_name,
        splitBy=splitby,
        split=split,
        image_transforms=transform,
        target_transforms=transform,
        eval_mode=True,
        prompts_augment=cfg.test.prompts_augment,
        **data_args,
    )
  elif ds_name == 'gres':
    ds = GReferDataset(split=split, transform=transform, **data_args)
  elif ds_name == 'voc':
    ds = VOCDataset(
        year='2012',
        split=split,
        transform=transform,
        target_transform=transform,
        **data_args,
    )
  elif ds_name == 'cocostuff':
    ds = COCODataset(transform=transform, **data_args)
  elif ds_name == 'context':
    ds = CONTEXTDataset(
        year='2010', transform=transform, split=split, **data_args
    )
  elif ds_name == 'ade':
    ds = ADEDataset(split=split, transform=transform, **data_args)
  elif ds_name == 'pascal_459':
    ds = Pascal459Dataset(split=split, transform=transform, **data_args)
  elif ds_name == 'ade_847':
    ds = ADE847Dataset(split=split, transform=transform, **data_args)
  else:
    raise ValueError(f'Dataset {ds_name} not implemented')
  return ds
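
# Example call (the root path and split below are hypothetical):
#   ds = get_dataset(cfg, 'voc', 'val', get_transform(), data_root='/data/voc')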

def get_transform():
  # Note: torchvision's T.Resize expects (height, width); with a square
  # 512x512 target the argument order does not matter here.
  transforms = [
      T.Resize((IMAGE_WIDTH, IMAGE_HEIGHT)),
      T.ToTensor(),
  ]
  return T.Compose(transforms)

def assign_label(
    all_masks,
    scores,
    stuff_masks=None,
    stuff_scores=None,
    id_mapping=None,
    stuff_id_mapping=None,
):
  """Assign class labels to pixels from per-class masks.

  Masks are painted in ascending score order, so higher-scoring masks
  overwrite lower-scoring ones wherever they overlap. Stuff masks are
  painted first, then thing masks on top.
  """
  label_preds = np.zeros_like(all_masks[0]).astype(np.int32)
  if stuff_masks is not None:
    sorted_idxs = np.argsort(stuff_scores.detach().cpu().numpy())
    stuff_masks = stuff_masks[sorted_idxs]
    stuff_scores = stuff_scores.detach().cpu().numpy()[sorted_idxs]
    for sorted_idx, mask, score in zip(sorted_idxs, stuff_masks, stuff_scores):
      if score > 0:
        # Convert the soft mask to a boolean mask.
        mask = mask > 0.5
        # Assign the (1-indexed) class label; 0 is background.
        if stuff_id_mapping is not None:
          label_preds[mask] = stuff_id_mapping[sorted_idx] + 1
        else:
          label_preds[mask] = sorted_idx + 1
  sorted_idxs = np.argsort(scores.detach().cpu().numpy())
  all_masks = all_masks[sorted_idxs]
  scores = scores.detach().cpu().numpy()[sorted_idxs]
  for sorted_idx, mask, score in zip(sorted_idxs, all_masks, scores):
    if score > 0:
      # Convert the soft mask to a boolean mask.
      mask = mask > 0.5
      # Assign the (1-indexed) class label; 0 is background.
      if id_mapping is not None:
        label_preds[mask] = id_mapping[sorted_idx] + 1
      else:
        label_preds[mask] = sorted_idx + 1
  return label_preds
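
# A minimal sketch of the overwrite order in assign_label (hypothetical
# values): with scores [0.9, 0.2], np.argsort gives [1, 0], so the mask with
# score 0.2 is painted first and the 0.9 mask is painted last, winning any
# overlapping pixels.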

def eval_semantic(
    label_space,
    algo,
    cfg,
    model,
    image_path,
    stuff_label_space=None,
    sam_pipeline=None,
):
  """Semantic segmentation evaluation for a single image."""
  if label_space is None:
    raise ValueError(
        'label_space must be provided for semantic segmentation evaluation'
    )
  if algo == 'car':
    all_masks, scores = inference_car(
        cfg, model, image_path, label_space, sam_pipeline=sam_pipeline
    )
    if stuff_label_space is not None:
      if cfg.test.ds_name == 'context':
        thing_id_mapping = PASCAL_CONTEXT_THING_CLASS_ID
        stuff_id_mapping = PASCAL_CONTEXT_STUFF_CLASS_ID
      elif cfg.test.ds_name == 'ade':
        thing_id_mapping = ADE_THING_CLASS_ID
        stuff_id_mapping = ADE_STUFF_CLASS_ID
      elif cfg.test.ds_name == 'pascal_459':
        thing_id_mapping = PASCAL_459_THING_CLASS_ID
        stuff_id_mapping = PASCAL_459_STUFF_CLASS_ID
      elif cfg.test.ds_name == 'ade_847':
        thing_id_mapping = ADE_847_THING_CLASS_ID
        stuff_id_mapping = ADE_847_STUFF_CLASS_ID
      else:
        raise ValueError(f'Dataset {cfg.test.ds_name} not supported')
      # Run a second inference pass for stuff classes: treat the thing
      # classes as background and use the stuff-specific prompt settings,
      # then restore the original settings.
      model.mask_generator.set_bg_cls(label_space)
      model.set_visual_prompt_type(cfg.car.stuff_visual_prompt_type)
      model.set_bg_factor(cfg.car.stuff_bg_factor)
      stuff_masks, stuff_scores = inference_car(
          cfg, model, image_path, stuff_label_space, sam_pipeline=sam_pipeline
      )
      model.mask_generator.set_bg_cls(cfg.car.bg_cls)
      model.set_visual_prompt_type(cfg.car.visual_prompt_type)
      model.set_bg_factor(cfg.car.bg_factor)
      all_masks = all_masks.detach().cpu().numpy()
      stuff_masks = stuff_masks.detach().cpu().numpy()
      label_preds = assign_label(
          all_masks,
          scores,
          stuff_masks=stuff_masks,
          stuff_scores=stuff_scores,
          id_mapping=thing_id_mapping,
          stuff_id_mapping=stuff_id_mapping,
      )
    else:
      all_masks = all_masks.detach().cpu().numpy()
      label_preds = assign_label(all_masks, scores)
    return label_preds.squeeze()
  else:
    raise NotImplementedError(f'algo {algo} not implemented')

def _fast_hist(label_true, label_pred, n_class=21):
  """Accumulate a confusion matrix (rows: true class, cols: predicted)."""
  mask = (label_true >= 0) & (label_true < n_class)
  hist = np.bincount(
      n_class * label_true[mask].astype(int) + label_pred[mask],
      minlength=n_class**2,
  ).reshape(n_class, n_class)
  return hist
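
# A worked example of the bincount trick above (hypothetical labels): with
# n_class=2, label_true=[0, 1, 1], label_pred=[0, 1, 0], the flat indices
# n_class * true + pred are [0, 3, 2], so bincount gives [1, 0, 1, 1] and
# the reshaped hist is [[1, 0], [1, 1]].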

def semantic_iou(label_trues, label_preds, n_class=21, ignore_background=False):
  """Compute semantic segmentation metrics from label maps."""
  hist = np.zeros((n_class, n_class))
  for lt, lp in zip(label_trues, label_preds):
    hist += _fast_hist(lt.flatten(), lp.flatten(), n_class)
  if ignore_background:
    hist = hist[1:, 1:]
  acc = np.diag(hist).sum() / hist.sum()
  acc_cls = np.diag(hist) / hist.sum(axis=1)
  acc_cls = np.nanmean(acc_cls)
  iu = np.diag(hist) / (hist.sum(axis=1) + hist.sum(axis=0) - np.diag(hist))
  # Only average IoU over classes that actually appear in the ground truth.
  valid = hist.sum(axis=1) > 0
  if valid.sum() == 0:
    mean_iu = 0.0
  else:
    mean_iu = np.nanmean(iu[valid])
  freq = hist.sum(axis=1) / hist.sum()
  fwavacc = (freq[freq > 0] * iu[freq > 0]).sum()
  if ignore_background:
    cls_iu = dict(zip(range(1, n_class), iu))
  else:
    cls_iu = dict(zip(range(n_class), iu))
  return {
      'Pixel Accuracy': acc,
      'Mean Accuracy': acc_cls,
      'Frequency Weighted IoU': fwavacc,
      'mIoU': mean_iu,
      'Class IoU': cls_iu,
  }
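
# Continuing the example above, hist [[1, 0], [1, 1]] yields per-class IoU
# diag / (row_sum + col_sum - diag) = [0.5, 0.5], so mIoU is 0.5 and pixel
# accuracy is 2/3.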

def evaluate(
    data_loader,
    cfg,
    model,
    test_cfg,
    label_space=None,
    stuff_label_space=None,
    sam_pipeline=None,
):
  """Run evaluation."""
  if (
      test_cfg.ds_name
      not in ['voc', 'cocostuff', 'context', 'ade', 'pascal_459', 'ade_847']
      and test_cfg.seg_mode == 'semantic'
  ):
    raise ValueError((
        'Semantic segmentation evaluation is only implemented for the voc, '
        'cocostuff, context, ade, pascal_459, and ade_847 datasets'
    ))
  metric_logger = MetricLogger(delimiter=' ')
  metric_logger.add_meter(
      'mIoU', SmoothedValue(window_size=1, fmt='{value:.4f} ({global_avg:.4f})')
  )
  # Evaluation accumulators.
  cum_i, cum_u = 0, 0
  eval_seg_iou_list = [0.5, 0.6, 0.7, 0.8, 0.9]
  seg_correct = np.zeros(len(eval_seg_iou_list), dtype=np.int32)
  seg_total = 0
  mean_iou = []
  header = 'Test:'
  label_preds, label_gts = [], []
  print(f'Number of batches: {len(data_loader)}')
  cc = 0
  use_tensorboard = False
  if hasattr(cfg.test, 'use_tensorboard'):
    use_tensorboard = cfg.test.use_tensorboard
  if use_tensorboard:
    writer = tensorboard.SummaryWriter(log_dir=cfg.test.output_path)
  for data in metric_logger.log_every(data_loader, 1, header):
    _, image_paths, target_list, sentences_list = data
    if not isinstance(target_list, list):
      target_list, sentences_list = [target_list], [sentences_list]
    for target, sentences in zip(target_list, sentences_list):
      image_path = image_paths[0]
      if test_cfg.seg_mode == 'refer':
        all_masks, all_scores = inference_car(
            cfg, model, image_path, sentences, sam_pipeline=sam_pipeline
        )
        final_mask = merge_masks_simple(
            all_masks, *target.shape[1:], scores=all_scores
        )
        intersection, union, cur_iou = compute_iou(final_mask, target)
        metric_logger.update(mIoU=cur_iou)
        mean_iou.append(cur_iou)
        if use_tensorboard:
          writer.add_scalar('Mean IoU', cur_iou, cc)
        cum_i += intersection
        cum_u += union
        for n_eval_iou in range(len(eval_seg_iou_list)):
          eval_seg_iou = eval_seg_iou_list[n_eval_iou]
          seg_correct[n_eval_iou] += cur_iou >= eval_seg_iou
        seg_total += 1
      elif test_cfg.seg_mode == 'semantic':
        label_pred = eval_semantic(
            label_space,
            test_cfg.algo,
            cfg,
            model,
            image_path,
            stuff_label_space,
        )
        label_gt = target.squeeze().cpu().numpy()
        cur_iou = semantic_iou(
            [label_gt],
            [label_pred],
            n_class=cfg.test.n_class,
            ignore_background=cfg.test.ignore_background,
        )['mIoU']
        metric_logger.update(mIoU=cur_iou)
        label_preds.append(label_pred)
        label_gts.append(label_gt)
      cc += 1
  if test_cfg.seg_mode == 'refer':
    mean_iou = np.array(mean_iou)
    m_iou = np.mean(mean_iou)
    if use_tensorboard:
      writer.add_scalar('mIoU', m_iou.item(), len(data_loader))
    print('Final results:')
    print('Mean IoU is %.2f\n' % (m_iou * 100.0))
    results_str = ''
    for n_eval_iou in range(len(eval_seg_iou_list)):
      results_str += ' precision@%s = %.2f\n' % (
          str(eval_seg_iou_list[n_eval_iou]),
          seg_correct[n_eval_iou] * 100.0 / seg_total,
      )
    o_iou = cum_i * 100.0 / cum_u
    results_str += ' overall IoU = %.2f\n' % o_iou
    if use_tensorboard:
      writer.add_scalar('oIoU', o_iou, 0)
    print(results_str)
  elif test_cfg.seg_mode == 'semantic':
    iou_score = semantic_iou(
        label_gts,
        label_preds,
        n_class=cfg.test.n_class,
        ignore_background=cfg.test.ignore_background,
    )
    if use_tensorboard:
      writer.add_scalar('mIoU', float(iou_score['mIoU']), len(data_loader))
    print(iou_score)
  if use_tensorboard:
    writer.close()

def compute_iou(pred_seg, gd_seg):
  """Compute intersection, union, and IoU of two binary masks."""
  intersection = torch.sum(torch.logical_and(pred_seg, gd_seg))
  union = torch.sum(torch.logical_or(pred_seg, gd_seg))
  # Guard against an empty union before dividing to avoid a NaN IoU.
  if union == 0:
    return intersection, union, 0
  iou = intersection * 1.0 / union
  return intersection, union, iou
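
# A minimal sketch (hypothetical masks): for pred [[1, 1], [0, 0]] and
# ground truth [[1, 0], [0, 0]], intersection is 1, union is 2, IoU is 0.5.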

def list_of_strings(arg):
  """Parse a comma-separated CLI argument into a list of stripped strings."""
  return [a.strip() for a in arg.split(',')]

# pylint: disable=redefined-outer-name
def parse_args():
  """Parse command-line arguments."""
  parser = argparse.ArgumentParser(description='Evaluation')
  parser.add_argument(
      '--cfg-path',
      default='configs/refcoco_test_prompt.yaml',
      help='path to configuration file.',
  )
  parser.add_argument('--index', default=0, type=int, help='split task')
  parser.add_argument('--mask_threshold', default=0.0, type=float)
  parser.add_argument('--confidence_threshold', default=0.0, type=float)
  parser.add_argument('--clipes_threshold', default=0.0, type=float)
  parser.add_argument('--stuff_bg_factor', default=0.0, type=float)
  parser.add_argument('--bg_factor', default=0.0, type=float)
  parser.add_argument('--output_path', default=None, type=str)
  parser.add_argument(
      '--visual_prompt_type', default=None, type=list_of_strings
  )
  parser.add_argument(
      '--stuff_visual_prompt_type', default=None, type=list_of_strings
  )
  args = parser.parse_args()
  return args

def main(args):
  cfg = Config(**load_yaml(args.cfg_path))
  # Non-zero / non-None CLI flags override the corresponding config values.
  if args.mask_threshold > 0:
    cfg.car.mask_threshold = args.mask_threshold
  if args.confidence_threshold > 0:
    cfg.car.confidence_threshold = args.confidence_threshold
  if args.clipes_threshold > 0:
    cfg.car.clipes_threshold = args.clipes_threshold
  if args.bg_factor > 0:
    cfg.car.bg_factor = args.bg_factor
  if args.stuff_bg_factor > 0:
    cfg.car.stuff_bg_factor = args.stuff_bg_factor
  if args.output_path is not None:
    cfg.test.output_path = args.output_path
  if args.visual_prompt_type is not None:
    cfg.car.visual_prompt_type = args.visual_prompt_type
  if args.stuff_visual_prompt_type is not None:
    cfg.car.stuff_visual_prompt_type = args.stuff_visual_prompt_type
  data_root = cfg.test.data_root if hasattr(cfg.test, 'data_root') else None
  dataset_test = get_dataset(
      cfg, cfg.test.ds_name, cfg.test.split, get_transform(), data_root
  )
  device = 'cuda' if torch.cuda.is_available() else 'cpu'
  stuff_label_space = None
  if cfg.test.ds_name == 'voc':
    label_space = VOC_CLASSES
  elif cfg.test.ds_name == 'cocostuff':
    label_space = COCO_OBJECT_CLASSES
  elif cfg.test.ds_name == 'context':
    label_space = PASCAL_CONTEXT_THING_CLASS
    stuff_label_space = PASCAL_CONTEXT_STUFF_CLASS
  elif cfg.test.ds_name == 'ade':
    label_space = ADE_THING_CLASS
    stuff_label_space = ADE_STUFF_CLASS
  elif cfg.test.ds_name == 'pascal_459':
    label_space = PASCAL_459_THING_CLASS
    stuff_label_space = PASCAL_459_STUFF_CLASS
  elif cfg.test.ds_name == 'ade_847':
    label_space = ADE_847_THING_CLASS
    stuff_label_space = ADE_847_STUFF_CLASS
  else:
    label_space = None
  num_chunks, chunk_index = 1, 0
  if hasattr(cfg.test, 'num_chunks'):
    num_chunks = cfg.test.num_chunks
  if hasattr(cfg.test, 'chunk_index'):
    chunk_index = cfg.test.chunk_index
  # Size of each chunk. Note that when the dataset size is not divisible by
  # num_chunks, the remainder samples at the end are skipped.
  chunk_size = len(dataset_test) // num_chunks
  # Evaluate only the (0-indexed) chunk_index-th chunk of the dataset.
  subset_indices = range(
      chunk_index * chunk_size, (chunk_index + 1) * chunk_size
  )
  subset_dataset = Subset(dataset_test, indices=subset_indices)
  data_loader_test = torch.utils.data.DataLoader(
      subset_dataset, batch_size=1, shuffle=False, num_workers=1
  )
  car_model = CaR(cfg, device=device, seg_mode=cfg.test.seg_mode)
  car_model = car_model.to(device)
  if not cfg.test.use_pseudo and cfg.test.sam_mask_root is None:
    print('Using SAM online')
    build_sam_config(cfg)
  evaluate(
      data_loader_test,
      cfg,
      car_model,
      test_cfg=cfg.test,
      label_space=label_space,
      stuff_label_space=stuff_label_space,
  )

if __name__ == '__main__':
  args = parse_args()
  main(args)