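"""SAM-Med2D / test.py

Evaluate a trained SAM-Med2D checkpoint on a test set. Supports box prompts
or (optionally iterative) point prompts, reports the metrics given by
--metrics (IoU and Dice by default), and can save predicted masks and the
prompts used for each image.
"""
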
import argparse
import json
import os

import numpy as np
import torch
from torch.nn import functional as F
from torch.utils.data import DataLoader
from tqdm import tqdm

from segment_anything import sam_model_registry
from DataLoader import TestingDataset
from metrics import SegMetrics
from utils import FocalDiceloss_IoULoss, generate_point, save_masks

def str2bool(v):
    # argparse's type=bool treats any non-empty string (including "False") as
    # True, so boolean flags are parsed explicitly instead.
    if isinstance(v, bool):
        return v
    if v.lower() in ("true", "1", "yes"):
        return True
    if v.lower() in ("false", "0", "no"):
        return False
    raise argparse.ArgumentTypeError(f"invalid boolean value: {v!r}")


def parse_args():
    parser = argparse.ArgumentParser()
    parser.add_argument("--work_dir", type=str, default="workdir", help="work dir")
    parser.add_argument("--run_name", type=str, default="sammed", help="run model name")
    parser.add_argument("--batch_size", type=int, default=1, help="batch size")
    parser.add_argument("--image_size", type=int, default=256, help="image size")
    parser.add_argument("--device", type=str, default="cuda")
    parser.add_argument("--data_path", type=str, default="data_demo", help="test data path")
    parser.add_argument("--metrics", nargs='+', default=['iou', 'dice'], help="metrics")
    parser.add_argument("--model_type", type=str, default="vit_b", help="sam model_type")
    parser.add_argument("--sam_checkpoint", type=str, default="pretrain_model/sam-med2d_b.pth", help="sam checkpoint")
    parser.add_argument("--boxes_prompt", type=str2bool, default=True, help="use boxes prompt")
    parser.add_argument("--point_num", type=int, default=1, help="point num")
    parser.add_argument("--iter_point", type=int, default=1, help="iter num")
    parser.add_argument("--multimask", type=str2bool, default=True, help="output multimask")
    parser.add_argument("--encoder_adapter", type=str2bool, default=True, help="use adapter")
    parser.add_argument("--prompt_path", type=str, default=None, help="fixed prompt path")
    parser.add_argument("--save_pred", type=str2bool, default=False, help="save result")
    args = parser.parse_args()
    if args.iter_point > 1:
        # Iterative refinement adds one click per iteration, so start from one point.
        args.point_num = 1
    return args
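
# Example invocation (paths are illustrative; point them at your own checkpoint
# and data layout):
#   python test.py --run_name sammed --image_size 256 \
#       --sam_checkpoint pretrain_model/sam-med2d_b.pth --data_path data_demo \
#       --boxes_prompt True --save_pred True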

def to_device(batch_input, device):
    """Move tensor entries of a batch dict to `device`; leave lists, sizes, and None as-is."""
    device_input = {}
    for key, value in batch_input.items():
        if value is not None:
            if key == 'image' or key == 'label':
                device_input[key] = value.float().to(device)
            elif isinstance(value, (list, torch.Size)):
                device_input[key] = value
            else:
                device_input[key] = value.to(device)
        else:
            device_input[key] = value
    return device_input

def postprocess_masks(low_res_masks, image_size, original_size):
    """Upsample low-res masks to the model input size, then undo preprocessing.

    If the original image was smaller than `image_size` (i.e. it was padded
    before encoding), center-crop back to `original_size`; otherwise resize
    directly. Returns the masks and the (top, left) crop offsets, or None.
    """
    ori_h, ori_w = original_size
    masks = F.interpolate(
        low_res_masks,
        (image_size, image_size),
        mode="bilinear",
        align_corners=False,
    )
    if ori_h < image_size and ori_w < image_size:
        top = torch.div((image_size - ori_h), 2, rounding_mode='trunc')  # (image_size - ori_h) // 2
        left = torch.div((image_size - ori_w), 2, rounding_mode='trunc')  # (image_size - ori_w) // 2
        masks = masks[..., top : ori_h + top, left : ori_w + left]
        pad = (top, left)
    else:
        masks = F.interpolate(masks, original_size, mode="bilinear", align_corners=False)
        pad = None
    return masks, pad
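
# Worked example of the crop branch above: with image_size=256 and an original
# 200x180 image (assumed to have been center-padded to 256x256 upstream), the
# offsets are top=(256-200)//2=28 and left=(256-180)//2=38, so the slice
# masks[..., 28:228, 38:218] recovers the original 200x180 extent.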

def prompt_and_decoder(args, batched_input, ddp_model, image_embeddings):
    """Encode the current prompts and decode masks from cached image embeddings."""
    if batched_input["point_coords"] is not None:
        points = (batched_input["point_coords"], batched_input["point_labels"])
    else:
        points = None

    with torch.no_grad():
        sparse_embeddings, dense_embeddings = ddp_model.prompt_encoder(
            points=points,
            boxes=batched_input.get("boxes", None),
            masks=batched_input.get("mask_inputs", None),
        )

        low_res_masks, iou_predictions = ddp_model.mask_decoder(
            image_embeddings=image_embeddings,
            image_pe=ddp_model.prompt_encoder.get_dense_pe(),
            sparse_prompt_embeddings=sparse_embeddings,
            dense_prompt_embeddings=dense_embeddings,
            multimask_output=args.multimask,
        )

    if args.multimask:
        # Keep only the candidate mask with the highest predicted IoU per sample.
        max_values, max_indices = torch.max(iou_predictions, dim=1)
        iou_predictions = max_values.unsqueeze(1)
        low_res = []
        for i, idx in enumerate(max_indices):
            low_res.append(low_res_masks[i:i+1, idx])
        low_res_masks = torch.stack(low_res, 0)

    masks = F.interpolate(low_res_masks, (args.image_size, args.image_size), mode="bilinear", align_corners=False)
    return masks, low_res_masks, iou_predictions
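
# Note: with multimask_output=True the decoder proposes multiple candidate masks
# (three in the original SAM design); the selection above keeps only the
# best-scoring one, so downstream code always sees a single mask per image.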

def is_not_saved(save_path, mask_name):
    """Return True if a prediction for `mask_name` has not been written yet."""
    masks_path = os.path.join(save_path, f"{mask_name}")
    return not os.path.exists(masks_path)

def main(args):
    print('*' * 100)
    for key, value in vars(args).items():
        print(key + ': ' + str(value))
    print('*' * 100)

    model = sam_model_registry[args.model_type](args).to(args.device)
    criterion = FocalDiceloss_IoULoss()
    test_dataset = TestingDataset(data_path=args.data_path, image_size=args.image_size, mode='test', requires_name=True, point_num=args.point_num, return_ori_mask=True, prompt_path=args.prompt_path)
    test_loader = DataLoader(dataset=test_dataset, batch_size=1, shuffle=False, num_workers=4)
    print('Test data:', len(test_loader))

    test_pbar = tqdm(test_loader)
    num_batches = len(test_loader)

    model.eval()
    test_loss = []
    test_iter_metrics = [0] * len(args.metrics)
    test_metrics = {}
    prompt_dict = {}

    for i, batched_input in enumerate(test_pbar):
        batched_input = to_device(batched_input, args.device)
        ori_labels = batched_input["ori_label"]
        original_size = batched_input["original_size"]
        labels = batched_input["label"]
        img_name = batched_input['name'][0]

        # Record the prompts actually used so a later run can replay them via --prompt_path.
        if args.prompt_path is None:
            prompt_dict[img_name] = {
                "boxes": batched_input["boxes"].squeeze(1).cpu().numpy().tolist(),
                "point_coords": batched_input["point_coords"].squeeze(1).cpu().numpy().tolist(),
                "point_labels": batched_input["point_labels"].squeeze(1).cpu().numpy().tolist()
            }

        with torch.no_grad():
            image_embeddings = model.image_encoder(batched_input["image"])

        if args.boxes_prompt:
            # Box prompt: a single decoder pass with point prompts disabled.
            save_path = os.path.join(args.work_dir, args.run_name, "boxes_prompt")
            batched_input["point_coords"], batched_input["point_labels"] = None, None
            masks, low_res_masks, iou_predictions = prompt_and_decoder(args, batched_input, model, image_embeddings)
            points_show = None
        else:
            # Point prompt: optionally refine over several iterations, sampling a
            # new click after each pass and accumulating all prompts so far.
            save_path = os.path.join(f"{args.work_dir}", args.run_name, f"iter{args.iter_point if args.iter_point > 1 else args.point_num}_prompt")
            batched_input["boxes"] = None
            point_coords, point_labels = [batched_input["point_coords"]], [batched_input["point_labels"]]

            for iter_idx in range(args.iter_point):
                masks, low_res_masks, iou_predictions = prompt_and_decoder(args, batched_input, model, image_embeddings)
                if iter_idx != args.iter_point - 1:
                    batched_input = generate_point(masks, labels, low_res_masks, batched_input, args.point_num)
                    batched_input = to_device(batched_input, args.device)
                    point_coords.append(batched_input["point_coords"])
                    point_labels.append(batched_input["point_labels"])
                    batched_input["point_coords"] = torch.concat(point_coords, dim=1)
                    batched_input["point_labels"] = torch.concat(point_labels, dim=1)

            points_show = (torch.concat(point_coords, dim=1), torch.concat(point_labels, dim=1))

        masks, pad = postprocess_masks(low_res_masks, args.image_size, original_size)
        if args.save_pred:
            save_masks(masks, save_path, img_name, args.image_size, original_size, pad, batched_input.get("boxes", None), points_show)

        loss = criterion(masks, ori_labels, iou_predictions)
        test_loss.append(loss.item())

        test_batch_metrics = SegMetrics(masks, ori_labels, args.metrics)
        test_batch_metrics = [float('{:.4f}'.format(metric)) for metric in test_batch_metrics]
        for j in range(len(args.metrics)):
            test_iter_metrics[j] += test_batch_metrics[j]

    test_iter_metrics = [metric / num_batches for metric in test_iter_metrics]
    test_metrics = {args.metrics[i]: '{:.4f}'.format(test_iter_metrics[i]) for i in range(len(test_iter_metrics))}
    average_loss = np.mean(test_loss)

    if args.prompt_path is None:
        with open(os.path.join(args.work_dir, f'{args.image_size}_prompt.json'), 'w') as f:
            json.dump(prompt_dict, f, indent=2)

    print(f"Test loss: {average_loss:.4f}, metrics: {test_metrics}")

if __name__ == '__main__':
    args = parse_args()
    main(args)