|
from segment_anything import sam_model_registry |
|
import torch.nn as nn |
|
import torch |
|
import argparse |
|
import os |
|
from utils import select_random_points, FocalDiceloss_IoULoss, generate_point, setting_prompt_none, save_masks |
|
from torch.utils.data import DataLoader |
|
from DataLoader import TestingDataset |
|
from metrics import SegMetrics |
|
import time |
|
from tqdm import tqdm |
|
import numpy as np |
|
from torch.nn import functional as F |
|
import logging |
|
import datetime |
|
import cv2 |
|
import random |
|
import csv |
|
import json |
|
|
|
|
|
def _str2bool(value):
    """Convert a command-line token into a real boolean.

    ``argparse`` with ``type=bool`` calls ``bool()`` on the raw string, so any
    non-empty text (including ``"False"``) becomes ``True``. This converter
    interprets the usual spellings instead.

    Args:
        value: The raw option value (a string from the command line, or an
            already-parsed bool).

    Returns:
        The boolean the token stands for.

    Raises:
        argparse.ArgumentTypeError: If the token is not a recognised boolean
            spelling.
    """
    if isinstance(value, bool):
        return value
    token = str(value).strip().lower()
    if token in ("true", "t", "yes", "y", "1", "on"):
        return True
    # The empty string stays falsy, as it was under ``bool("")``.
    if token in ("false", "f", "no", "n", "0", "off", ""):
        return False
    raise argparse.ArgumentTypeError(f"expected a boolean value, got {value!r}")


def parse_args():
    """Parse the command-line options of the evaluation script.

    Boolean options (``--boxes_prompt``, ``--multimask``, ``--encoder_adapter``,
    ``--save_pred``) accept explicit values such as ``True``/``False``.

    Returns:
        argparse.Namespace: The parsed options. When ``iter_point`` is greater
        than 1, ``point_num`` is forced to 1.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument("--work_dir", type=str, default="workdir", help="work dir")
    parser.add_argument("--run_name", type=str, default="sammed", help="run model name")
    parser.add_argument("--batch_size", type=int, default=1, help="batch size")
    parser.add_argument("--image_size", type=int, default=256, help="image_size")
    parser.add_argument('--device', type=str, default='cuda')
    parser.add_argument("--data_path", type=str, default="data_demo", help="train data path")
    parser.add_argument("--metrics", nargs='+', default=['iou', 'dice'], help="metrics")
    parser.add_argument("--model_type", type=str, default="vit_b", help="sam model_type")
    parser.add_argument("--sam_checkpoint", type=str, default="pretrain_model/sam-med2d_b.pth", help="sam checkpoint")
    parser.add_argument("--boxes_prompt", type=_str2bool, default=True, help="use boxes prompt")
    parser.add_argument("--point_num", type=int, default=1, help="point num")
    parser.add_argument("--iter_point", type=int, default=1, help="iter num")
    parser.add_argument("--multimask", type=_str2bool, default=True, help="ouput multimask")
    parser.add_argument("--encoder_adapter", type=_str2bool, default=True, help="use adapter")
    parser.add_argument("--prompt_path", type=str, default=None, help="fix prompt path")
    parser.add_argument("--save_pred", type=_str2bool, default=False, help="save reslut")
    args = parser.parse_args()

    # Iterative point prompting adds points one round at a time, so the
    # per-round point count is pinned to 1.
    if args.iter_point > 1:
        args.point_num = 1
    return args
|
|
|
|
|
def to_device(batch_input, device):
    """Return a copy of ``batch_input`` with its tensor entries moved to ``device``.

    Args:
        batch_input: Mapping of batch field names to values.
        device: Target device, passed straight to ``.to()``.

    Returns:
        dict: A new dict with the same keys in the same order, where

        * ``None`` values are kept as ``None``;
        * the ``"image"`` and ``"label"`` entries are cast to float and moved;
        * lists and tuples (``torch.Size`` included) are passed through
          untouched;
        * every other value is moved with ``value.to(device)``.
    """
    device_input = {}
    for key, value in batch_input.items():
        if value is None:
            device_input[key] = None
        elif key in ("image", "label"):
            device_input[key] = value.float().to(device)
        elif isinstance(value, (list, tuple)):
            # Plain Python sequences have no ``.to``; torch.Size is a tuple
            # subclass, so it is covered here too.
            device_input[key] = value
        else:
            device_input[key] = value.to(device)
    return device_input
|
|
|
|
|
def postprocess_masks(low_res_masks, image_size, original_size):
    """Bring decoder masks back to the original image resolution.

    The masks are first resized bilinearly to ``image_size x image_size``.
    If both original dimensions are smaller than ``image_size``, a centred
    window of the original size is cropped out; otherwise the masks are
    resized bilinearly to ``original_size``.

    Args:
        low_res_masks: Mask tensor accepted by ``F.interpolate`` with a 2-D
            target size.
        image_size: Side length of the square model input.
        original_size: ``(height, width)`` of the original image.
            NOTE(review): the crop branch uses ``torch.div``, so the entries
            are presumably tensors there; confirm against the dataset.

    Returns:
        tuple: ``(masks, pad)`` where ``pad`` is ``(top, left)``, the crop
        offsets, when the crop branch ran, and ``None`` otherwise.
    """
    height, width = original_size
    upsampled = F.interpolate(
        low_res_masks,
        (image_size, image_size),
        mode="bilinear",
        align_corners=False,
    )

    if not (height < image_size and width < image_size):
        # At least one side reaches image_size: plain resize, no crop offsets.
        resized = F.interpolate(upsampled, original_size, mode="bilinear", align_corners=False)
        return resized, None

    # Offsets of the centred window (truncating division).
    top = torch.div((image_size - height), 2, rounding_mode='trunc')
    left = torch.div((image_size - width), 2, rounding_mode='trunc')
    cropped = upsampled[..., top : height + top, left : width + left]
    return cropped, (top, left)
|
|
|
|
|
def prompt_and_decoder(args, batched_input, ddp_model, image_embeddings):
    """Encode the prompts in ``batched_input`` and decode masks from them.

    Args:
        args: Options object; ``multimask`` and ``image_size`` are read.
        batched_input: Dict holding ``"point_coords"`` and ``"point_labels"``,
            and optionally ``"boxes"`` and ``"mask_inputs"``.
        ddp_model: Model exposing ``prompt_encoder`` (with ``get_dense_pe``)
            and ``mask_decoder``.
        image_embeddings: Image encoder output passed to the mask decoder.

    Returns:
        tuple: ``(masks, low_res_masks, iou_predictions)``, where ``masks`` is
        ``low_res_masks`` resized bilinearly to
        ``args.image_size x args.image_size``. With ``args.multimask`` set,
        only the candidate with the highest predicted IoU is kept per sample.
    """
    coords = batched_input["point_coords"]
    points = (coords, batched_input["point_labels"]) if coords is not None else None

    # Inference only: no autograd graph is needed for either sub-module.
    with torch.no_grad():
        sparse_embeddings, dense_embeddings = ddp_model.prompt_encoder(
            points=points,
            boxes=batched_input.get("boxes", None),
            masks=batched_input.get("mask_inputs", None),
        )
        low_res_masks, iou_predictions = ddp_model.mask_decoder(
            image_embeddings=image_embeddings,
            image_pe=ddp_model.prompt_encoder.get_dense_pe(),
            sparse_prompt_embeddings=sparse_embeddings,
            dense_prompt_embeddings=dense_embeddings,
            multimask_output=args.multimask,
        )

    if args.multimask:
        best_scores, best_indices = torch.max(iou_predictions, dim=1)
        iou_predictions = best_scores.unsqueeze(1)
        # Keep one mask per sample: the candidate with the best predicted IoU.
        low_res_masks = torch.stack(
            [
                low_res_masks[sample:sample + 1, choice]
                for sample, choice in enumerate(best_indices)
            ],
            0,
        )

    target_size = (args.image_size, args.image_size)
    masks = F.interpolate(low_res_masks, target_size, mode="bilinear", align_corners=False)
    return masks, low_res_masks, iou_predictions
|
|
|
|
|
def is_not_saved(save_path, mask_name):
    """Tell whether ``mask_name`` is still missing from ``save_path``.

    Args:
        save_path: Directory in which to look.
        mask_name: Name of the entry to look for inside ``save_path``.

    Returns:
        bool: ``True`` if nothing exists at ``save_path/mask_name``,
        ``False`` if the path already exists.
    """
    candidate = os.path.join(save_path, f"{mask_name}")
    return not os.path.exists(candidate)
|
|
|
def main(args):
    """Run the model over the test set and print the mean loss and metrics.

    For each sample (the loader uses batch size 1) the image is encoded once,
    then masks are decoded either from the box prompt
    (``args.boxes_prompt``) or from point prompts refined over
    ``args.iter_point`` rounds. The masks are mapped back to the original
    resolution, optionally saved, and scored against ``ori_label``.

    When ``args.prompt_path`` is ``None``, the prompts supplied by the dataset
    are collected and written to ``<work_dir>/<image_size>_prompt.json``.

    Args:
        args: Parsed options as returned by ``parse_args``.
    """
    print('*'*100)
    for key, value in vars(args).items():
        print(key + ': ' + str(value))
    print('*'*100)

    model = sam_model_registry[args.model_type](args).to(args.device)

    criterion = FocalDiceloss_IoULoss()
    test_dataset = TestingDataset(data_path=args.data_path, image_size=args.image_size, mode='test', requires_name=True, point_num=args.point_num, return_ori_mask=True, prompt_path=args.prompt_path)
    # Batch size stays 1: the loop below reads a single image name per batch.
    test_loader = DataLoader(dataset=test_dataset, batch_size=1, shuffle=False, num_workers=4)
    print('Test data:', len(test_loader))

    test_pbar = tqdm(test_loader)
    num_batches = len(test_loader)

    model.eval()
    test_loss = []
    test_iter_metrics = [0] * len(args.metrics)
    prompt_dict = {}

    for batched_input in test_pbar:
        batched_input = to_device(batched_input, args.device)
        ori_labels = batched_input["ori_label"]
        original_size = batched_input["original_size"]
        labels = batched_input["label"]
        img_name = batched_input['name'][0]

        if args.prompt_path is None:
            # Record the prompts used for this image so they can be written
            # out at the end of the run.
            prompt_dict[img_name] = {
                "boxes": batched_input["boxes"].squeeze(1).cpu().numpy().tolist(),
                "point_coords": batched_input["point_coords"].squeeze(1).cpu().numpy().tolist(),
                "point_labels": batched_input["point_labels"].squeeze(1).cpu().numpy().tolist()
            }

        with torch.no_grad():
            image_embeddings = model.image_encoder(batched_input["image"])

        if args.boxes_prompt:
            save_path = os.path.join(args.work_dir, args.run_name, "boxes_prompt")
            # Box-only prompting: drop the point prompts.
            batched_input["point_coords"], batched_input["point_labels"] = None, None
            masks, low_res_masks, iou_predictions = prompt_and_decoder(args, batched_input, model, image_embeddings)
            points_show = None
        else:
            save_path = os.path.join(f"{args.work_dir}", args.run_name, f"iter{args.iter_point if args.iter_point > 1 else args.point_num}_prompt")
            # Point-only prompting: drop the box prompt.
            batched_input["boxes"] = None
            point_coords, point_labels = [batched_input["point_coords"]], [batched_input["point_labels"]]

            for step in range(args.iter_point):
                masks, low_res_masks, iou_predictions = prompt_and_decoder(args, batched_input, model, image_embeddings)
                if step != args.iter_point - 1:
                    # Derive new points from the current prediction and feed
                    # the accumulated point set into the next round.
                    batched_input = generate_point(masks, labels, low_res_masks, batched_input, args.point_num)
                    batched_input = to_device(batched_input, args.device)
                    point_coords.append(batched_input["point_coords"])
                    point_labels.append(batched_input["point_labels"])
                    batched_input["point_coords"] = torch.concat(point_coords, dim=1)
                    batched_input["point_labels"] = torch.concat(point_labels, dim=1)

            points_show = (torch.concat(point_coords, dim=1), torch.concat(point_labels, dim=1))

        masks, pad = postprocess_masks(low_res_masks, args.image_size, original_size)
        if args.save_pred:
            save_masks(masks, save_path, img_name, args.image_size, original_size, pad, batched_input.get("boxes", None), points_show)

        loss = criterion(masks, ori_labels, iou_predictions)
        test_loss.append(loss.item())

        test_batch_metrics = SegMetrics(masks, ori_labels, args.metrics)
        test_batch_metrics = [float('{:.4f}'.format(metric)) for metric in test_batch_metrics]

        for j in range(len(args.metrics)):
            test_iter_metrics[j] += test_batch_metrics[j]

    test_iter_metrics = [metric / num_batches for metric in test_iter_metrics]
    test_metrics = {name: '{:.4f}'.format(value) for name, value in zip(args.metrics, test_iter_metrics)}

    average_loss = np.mean(test_loss)
    if args.prompt_path is None:
        # Make sure the output directory exists so the run does not fail
        # after the whole evaluation has finished. An empty work_dir means
        # the current directory, which needs no creation.
        if args.work_dir:
            os.makedirs(args.work_dir, exist_ok=True)
        with open(os.path.join(args.work_dir,f'{args.image_size}_prompt.json'), 'w') as f:
            json.dump(prompt_dict, f, indent=2)
    print(f"Test loss: {average_loss:.4f}, metrics: {test_metrics}")
|
|
|
if __name__ == '__main__':
    # Script entry point: parse the command-line options and run the evaluation.
    main(parse_args())
|
|