import json

import torch
import torch.cuda.amp as amp
import torch.nn.functional as F
from torch import nn

from detectron2.modeling import META_ARCH_REGISTRY
from detectron2.structures import Boxes, Instances

from .modeling.regionspot import build_regionspot_model
from .util.postprocessing import segmentation_postprocess
from .util.preprocessing import prepare_prompt_infer, prepare_prompt_train

__all__ = ["RegionSpot"]

@META_ARCH_REGISTRY.register()
class RegionSpot(nn.Module):
    """
    Implement RegionSpot.
    """
    def __init__(self, cfg):
        super().__init__()
        self.device = torch.device(cfg.MODEL.DEVICE)
        self.clip_type = cfg.MODEL.CLIP_TYPE
        self.inference_box_type = cfg.MODEL.BOX_TYPE
        self.clip_input_size = cfg.MODEL.CLIP_INPUT_SIZE
        self.clip_target_size = (self.clip_input_size, self.clip_input_size)
        self.model, _ = build_regionspot_model(
            clip_type=self.clip_type,
            is_training=cfg.MODEL.TRAINING,
            image_size=self.clip_input_size,
        )
        self.model.to(self.device)
        if self.inference_box_type != 'GT':
            # When not evaluating with ground-truth boxes, load cached GLIP
            # detections to serve as class-agnostic box proposals.
            path = './datasets/glip_results/nms_results_glip_tiny_model_o365_goldg_cc_sbu_lvis_val.json'
            with open(path, 'r') as file:
                self.pred_results = json.load(file)
        else:
            self.pred_results = None
    def forward_inference(self, batched_inputs, do_postprocess=True):
        with amp.autocast(enabled=True):
            with torch.no_grad():
                logits_per_image, pred_mask = self.model.forward_eval(batched_inputs, multimask_output=False)
        image_sizes = [x["original_size"] for x in batched_inputs]
        if self.inference_box_type == 'GT':
            boxes = torch.stack([x["instances"].gt_boxes.tensor for x in batched_inputs], dim=0)  # (n, n_box, 4)
            if len(boxes[0]) == 0:
                # Fall back to a single whole-image box when the image has no GT instances.
                boxes = torch.tensor([[[0, 0, image_sizes[0][0], image_sizes[0][1]]]])
        else:
            boxes = torch.stack([x["pred_boxes"] for x in batched_inputs], dim=0)  # (n, n_box, 4)
            scores = torch.stack([x["scores"] for x in batched_inputs], dim=0)

        box_cls = logits_per_image
        box_pred = boxes
        if self.inference_box_type == 'GT':
            results = self.inference_gt_box(box_cls, box_pred, pred_mask, image_sizes)
        else:
            results = self.inference_pred_box(box_cls, box_pred, scores, pred_mask, image_sizes)

        if do_postprocess:
            processed_results = []
            for results_per_image, input_per_image, image_size in zip(results, batched_inputs, image_sizes):
                height = input_per_image.get("height", image_size[0])
                width = input_per_image.get("width", image_size[1])
                r = segmentation_postprocess(results_per_image, height, width)
                processed_results.append({"instances": r})
            return processed_results
        else:
            return results
    def forward_train(self, batched_inputs):
        with amp.autocast(enabled=True):
            outputs = self.model.forward_train(batched_inputs)
            loss = {'loss': outputs}
        return loss
    def forward(self, batched_inputs, do_postprocess=True):
        if not self.training:
            # Prepare prompts (boxes mapped to the CLIP input resolution).
            batched_inputs = prepare_prompt_infer(batched_inputs, pred_results=self.pred_results, target_size=self.clip_target_size)
            return self.forward_inference(batched_inputs, do_postprocess=do_postprocess)
        else:
            batched_inputs = prepare_prompt_train(batched_inputs, target_size=self.clip_target_size)
            return self.forward_train(batched_inputs)
    def inference_gt_box(self, box_cls, box_pred, pred_mask, image_sizes=None):
        device = box_cls.device  # assume all tensors share a device
        results = []
        for logits, boxes, masks, img_size in zip(box_cls, box_pred, pred_mask, image_sizes):
            # Calculate class probabilities and flatten (n_box, n_class) -> (n_box * n_class,)
            probs = F.softmax(logits, dim=-1)
            probs_flattened = probs.view(-1)
            # Keep at most 900 predictions per image
            top_num = min(900, len(probs_flattened))
            # Get top-k values and flat indices
            topk_probs, topk_indices = torch.topk(probs_flattened, top_num)
            # Decode the flat indices back into (box, class) pairs
            topk_labels = topk_indices % logits.shape[1]
            topk_boxes_indices = topk_indices // logits.shape[1]
            # Ensure boxes, masks, and indices live on the same device
            topk_boxes_indices = topk_boxes_indices.to(device)
            boxes = boxes.to(device)
            masks = masks.to(device)
            # Gather the predictions selected by top-k
            boxes_for_topk = boxes[topk_boxes_indices]
            masks_for_topk = masks[topk_boxes_indices]
            scores_for_topk = topk_probs  # GT boxes carry no detector score, so the class probability is used directly
            # Pack the top-k predictions into an Instances object
            result = Instances(img_size)
            result.pred_boxes = Boxes(boxes_for_topk)
            result.scores = scores_for_topk
            result.pred_classes = topk_labels
            result.pred_masks = masks_for_topk
            results.append(result)
        return results
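
    # Note on the flat top-k decode used above and in inference_pred_box below:
    # `probs` has shape (n_box, n_class), so after `view(-1)` a flat index k
    # maps back to box k // n_class and class k % n_class. A tiny
    # self-contained illustration (names here are illustrative only):
    #
    #   probs = torch.rand(4, 3)              # 4 boxes, 3 classes
    #   vals, idx = probs.view(-1).topk(5)
    #   box_idx, cls_idx = idx // 3, idx % 3
    #   assert torch.equal(probs[box_idx, cls_idx], vals)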
    def inference_pred_box(self, box_cls, box_pred, box_score, masks, image_sizes=None):
        results = []
        for logits, box_pred_i, box_score_i, mask_i, img_size in zip(box_cls, box_pred, box_score, masks, image_sizes):
            logits = logits.to(self.device)
            box_pred_i = box_pred_i.to(self.device)
            box_score_i = box_score_i.to(self.device)
            # Calculate class probabilities and flatten (n_box, n_class) -> (n_box * n_class,)
            probs = F.softmax(logits, dim=-1)
            probs_flattened = probs.view(-1)
            # Keep at most 900 predictions per image
            top_num = min(900, len(probs_flattened))
            # Get top-k values and flat indices
            topk_probs, topk_indices = torch.topk(probs_flattened, top_num)
            # Decode the flat indices back into (box, class) pairs
            topk_labels = topk_indices % logits.shape[1]
            topk_boxes_indices = topk_indices // logits.shape[1]
            # Gather the predictions selected by top-k
            boxes = box_pred_i[topk_boxes_indices]
            masks_topk = mask_i[topk_boxes_indices]
            # Combine the detector's box confidence with the classification probability
            scores = box_score_i[topk_boxes_indices] * topk_probs
            # Pack the top-k predictions into an Instances object
            result = Instances(img_size)
            result.pred_boxes = Boxes(boxes)
            result.scores = scores
            result.pred_classes = topk_labels
            result.pred_masks = masks_topk
            results.append(result)
        return results
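

if __name__ == "__main__":
    # Minimal usage sketch, not part of the module. It only shows how the cfg
    # keys read in __init__ fit together; the checkpoints that
    # build_regionspot_model loads, and real batched inputs, are assumed to be
    # available and are setup-specific. The CLIP_TYPE value below is an
    # assumption; substitute whatever variant your build supports.
    from detectron2.config import get_cfg

    cfg = get_cfg()
    cfg.set_new_allowed(True)  # the keys below are project-specific additions
    cfg.MODEL.DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
    cfg.MODEL.CLIP_TYPE = "CLIP_400M_Large"  # assumed value
    cfg.MODEL.BOX_TYPE = "GT"  # 'GT' avoids loading the cached GLIP result file
    cfg.MODEL.CLIP_INPUT_SIZE = 224
    cfg.MODEL.TRAINING = False

    model = RegionSpot(cfg).eval()
    # `batched_inputs` follows the detectron2 convention: a list of dicts with
    # "image", "height", "width" and, for BOX_TYPE == 'GT', an "instances"
    # entry carrying gt_boxes. With such a batch prepared:
    #   outputs = model(batched_inputs)  # -> [{"instances": Instances}, ...]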