RegionSpot / regionspot /detector.py
bklg's picture
Upload 114 files
a153c95
raw
history blame
7.29 kB
from collections import namedtuple
from .modeling.regionspot import build_regionspot_model
import torch.cuda.amp as amp
import matplotlib.pyplot as plt
import numpy as np
import torch
import torch.nn.functional as F
from torch import nn
from einops import rearrange
import json
from detectron2.modeling import META_ARCH_REGISTRY
from .util.postprocessing import segmentation_postprocess
from detectron2.structures import Boxes, Instances
from .util.preprocessing import prepare_prompt_infer, prepare_prompt_train
__all__ = ["RegionSpot"]
@META_ARCH_REGISTRY.register()
class RegionSpot(nn.Module):
    """Detectron2 meta-architecture wrapper around the RegionSpot model.

    In training mode ``forward`` returns a ``{'loss': ...}`` dict. In eval
    mode it returns per-image detections, either scored on ground-truth
    boxes (``cfg.MODEL.BOX_TYPE == 'GT'``) or on externally predicted boxes
    loaded from a JSON results file.

    The method names ``foward_inference`` / ``foward_train`` are misspelled
    in the original API and are kept as-is so existing callers keep working.
    """

    # Upper bound on the number of (box, class) pairs kept per image.
    MAX_DETECTIONS_PER_IMAGE = 900

    def __init__(self, cfg):
        """Build the underlying model and, if needed, load box proposals.

        Args:
            cfg: config node; reads ``MODEL.DEVICE``, ``MODEL.CLIP_TYPE``,
                ``MODEL.BOX_TYPE``, ``MODEL.CLIP_INPUT_SIZE`` and
                ``MODEL.TRAINING``.
        """
        super().__init__()
        self.device = torch.device(cfg.MODEL.DEVICE)
        self.clip_type = cfg.MODEL.CLIP_TYPE
        self.inference_box_type = cfg.MODEL.BOX_TYPE
        self.clip_input_size = cfg.MODEL.CLIP_INPUT_SIZE
        self.clip_target_size = (self.clip_input_size, self.clip_input_size)

        # The builder returns a pair; only the first element is used here.
        self.model, _ = build_regionspot_model(
            clip_type=self.clip_type,
            is_training=cfg.MODEL.TRAINING,
            image_size=self.clip_input_size,
        )
        self.model.to(self.device)

        if self.inference_box_type != 'GT':
            # Box prompts come from a precomputed results file rather than
            # from ground-truth annotations.
            path = './datasets/glip_results/nms_results_glip_tiny_model_o365_goldg_cc_sbu_lvis_val.json'
            with open(path, 'r', encoding='utf-8') as file:
                self.pred_results = json.load(file)
        else:
            self.pred_results = None

    @torch.no_grad()
    def foward_inference(self, batched_inputs, do_postprocess=True):
        """Run the model in eval mode and assemble per-image detections.

        Args:
            batched_inputs: list of per-image dicts. Each must provide
                ``"original_size"`` and either ``"instances"`` with
                ``gt_boxes`` (GT mode) or ``"pred_boxes"`` and ``"scores"``.
            do_postprocess: when True, each result is passed through
                ``segmentation_postprocess`` and wrapped as
                ``{"instances": ...}``; otherwise the raw ``Instances`` list
                is returned.

        Returns:
            A list of ``{"instances": Instances}`` dicts, or a list of
            ``Instances`` when ``do_postprocess`` is False.
        """
        # Gradients are already disabled by the decorator, so only autocast
        # is needed around the model call.
        with amp.autocast(enabled=True):
            logits_per_image, pred_mask = self.model.forward_eval(
                batched_inputs, multimask_output=False
            )

        image_sizes = [x["original_size"] for x in batched_inputs]

        if self.inference_box_type == 'GT':
            # presumably (n_images, n_boxes, 4) -- TODO confirm
            boxes = torch.stack(
                [x["instances"].gt_boxes.tensor for x in batched_inputs], dim=0
            )
            if len(boxes[0]) == 0:
                # No GT boxes for the first image: fall back to one box
                # derived from the image size.
                # NOTE(review): image sizes are used as (height, width) below,
                # so this box is [0, 0, height, width]; if boxes are XYXY this
                # looks transposed -- confirm before changing.
                boxes = torch.tensor([[[0, 0, image_sizes[0][0], image_sizes[0][1]]]])
            results = self.inference_gt_box(
                logits_per_image, boxes, pred_mask, image_sizes
            )
        else:
            # presumably (n_images, n_boxes, 4) -- TODO confirm
            boxes = torch.stack([x["pred_boxes"] for x in batched_inputs], dim=0)
            scores = torch.stack([x["scores"] for x in batched_inputs], dim=0)
            results = self.inference_pred_box(
                logits_per_image, boxes, scores, pred_mask, image_sizes
            )

        if not do_postprocess:
            return results

        processed_results = []
        for results_per_image, input_per_image, image_size in zip(
            results, batched_inputs, image_sizes
        ):
            # Prefer the explicitly requested output size; otherwise use the
            # image's original size.
            height = input_per_image.get("height", image_size[0])
            width = input_per_image.get("width", image_size[1])
            r = segmentation_postprocess(results_per_image, height, width)
            processed_results.append({"instances": r})
        return processed_results

    def foward_train(self, batched_inputs):
        """Run one training forward pass under autocast.

        Returns:
            ``{'loss': outputs}`` where ``outputs`` is whatever
            ``self.model.forward_train`` returns.
        """
        with amp.autocast(enabled=True):
            outputs = self.model.forward_train(batched_inputs)
        return {'loss': outputs}

    def forward(self, batched_inputs, do_postprocess=True):
        """Dispatch to training or inference based on ``self.training``.

        Args:
            batched_inputs: list of per-image input dicts.
            do_postprocess: forwarded to ``foward_inference``; ignored while
                training.

        Returns:
            The loss dict in training mode, otherwise the inference results.
        """
        if self.training:
            batched_inputs = prepare_prompt_train(
                batched_inputs, target_size=self.clip_target_size
            )
            return self.foward_train(batched_inputs)

        # Prepare Prompt.
        batched_inputs = prepare_prompt_infer(
            batched_inputs,
            pred_results=self.pred_results,
            target_size=self.clip_target_size,
        )
        # Honour the caller's do_postprocess flag instead of always
        # post-processing.
        return self.foward_inference(batched_inputs, do_postprocess=do_postprocess)

    @staticmethod
    def _select_topk(logits, max_num):
        """Pick the highest-probability (box, class) pairs for one image.

        Args:
            logits: 2-D tensor indexed as ``[box, class]``.
            max_num: maximum number of pairs to keep.

        Returns:
            ``(probs, labels, box_indices)`` for the selected pairs, where
            ``labels`` and ``box_indices`` decode the flattened top-k index
            into its class column and box row.
        """
        probs_flattened = F.softmax(logits, dim=-1).reshape(-1)
        top_num = min(max_num, len(probs_flattened))
        topk_probs, topk_indices = torch.topk(probs_flattened, top_num)
        num_classes = logits.shape[1]
        return topk_probs, topk_indices % num_classes, topk_indices // num_classes

    def inference_gt_box(self, box_cls, box_pred, pred_mask, image_sizes=None):
        """Build detections for ground-truth boxes.

        Each detection's score is the softmax class probability alone.

        Args:
            box_cls: per-image classification logits.
            box_pred: per-image boxes.
            pred_mask: per-image masks, indexed by box.
            image_sizes: per-image sizes passed to ``Instances``.

        Returns:
            A list with one ``Instances`` per image.
        """
        results = []
        for logits, boxes, masks, img_size in zip(
            box_cls, box_pred, pred_mask, image_sizes
        ):
            topk_probs, topk_labels, topk_box_idx = self._select_topk(
                logits, self.MAX_DETECTIONS_PER_IMAGE
            )

            # Index tensors must live on the same device as what they index.
            device = logits.device
            boxes = boxes.to(device)
            masks = masks.to(device)

            result = Instances(img_size)
            result.pred_boxes = Boxes(boxes[topk_box_idx])
            result.scores = topk_probs
            result.pred_classes = topk_labels
            result.pred_masks = masks[topk_box_idx]
            results.append(result)
        return results

    def inference_pred_box(self, box_cls, box_pred, box_score, masks, image_sizes=None):
        """Build detections for externally predicted boxes.

        Each detection's score is the box's own score multiplied by the
        softmax class probability.

        Args:
            box_cls: per-image classification logits.
            box_pred: per-image boxes.
            box_score: per-image box scores.
            masks: per-image masks, indexed by box.
            image_sizes: per-image sizes passed to ``Instances``.

        Returns:
            A list with one ``Instances`` per image.
        """
        results = []
        for logits, boxes_i, box_score_i, mask_i, img_size in zip(
            box_cls, box_pred, box_score, masks, image_sizes
        ):
            # Follow the logits' device rather than forcing CUDA, so CPU
            # inference works and all indexed tensors share one device.
            device = logits.device
            boxes_i = boxes_i.to(device)
            box_score_i = box_score_i.to(device)
            mask_i = mask_i.to(device)

            topk_probs, topk_labels, topk_box_idx = self._select_topk(
                logits, self.MAX_DETECTIONS_PER_IMAGE
            )

            result = Instances(img_size)
            result.pred_boxes = Boxes(boxes_i[topk_box_idx])
            result.scores = box_score_i[topk_box_idx] * topk_probs
            result.pred_classes = topk_labels
            result.pred_masks = mask_i[topk_box_idx]
            results.append(result)
        return results