from collections import namedtuple
from .modeling.regionspot import build_regionspot_model
import torch.cuda.amp as amp
import matplotlib.pyplot as plt
import numpy as np
import torch
import torch.nn.functional as F
from torch import nn
from einops import rearrange
import json
from detectron2.modeling import META_ARCH_REGISTRY
from .util.postprocessing import segmentation_postprocess
from detectron2.structures import Boxes, Instances
from .util.preprocessing import prepare_prompt_infer, prepare_prompt_train
__all__ = ["RegionSpot"]
@META_ARCH_REGISTRY.register()
class RegionSpot(nn.Module):
"""
Implement RegionSpot
"""
    def __init__(self, cfg):
        super().__init__()
        self.device = torch.device(cfg.MODEL.DEVICE)
        self.clip_type = cfg.MODEL.CLIP_TYPE
        self.inference_box_type = cfg.MODEL.BOX_TYPE
        self.clip_input_size = cfg.MODEL.CLIP_INPUT_SIZE
        self.clip_target_size = (self.clip_input_size, self.clip_input_size)
        self.model, _ = build_regionspot_model(clip_type=self.clip_type, is_training=cfg.MODEL.TRAINING, image_size=self.clip_input_size)
        self.model.to(self.device)
        if self.inference_box_type != 'GT':
            # Pre-computed GLIP detections (after NMS) used as box prompts when
            # not evaluating with ground-truth boxes.
            path = './datasets/glip_results/nms_results_glip_tiny_model_o365_goldg_cc_sbu_lvis_val.json'
            with open(path, 'r') as file:
                self.pred_results = json.load(file)
        else:
            self.pred_results = None

    @torch.no_grad()
    def foward_inference(self, batched_inputs, do_postprocess=True):
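        """
        Inference path: get per-box classification logits and masks from
        `self.model.forward_eval`, decode the top-k predictions per image into
        detectron2 `Instances`, and optionally rescale them to the original
        image size with `segmentation_postprocess`.
        """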
        with amp.autocast(enabled=True):
            with torch.no_grad():
                logits_per_image, pred_mask = self.model.forward_eval(batched_inputs, multimask_output=False)
        image_sizes = [x["original_size"] for x in batched_inputs]
        if self.inference_box_type == 'GT':
            boxes = torch.stack([x["instances"].gt_boxes.tensor for x in batched_inputs], dim=0)  # n, n_box, n_token, 256
            if len(boxes[0]) == 0:
                boxes = torch.tensor([[[0, 0, image_sizes[0][0], image_sizes[0][1]]]])
        else:
            boxes = torch.stack([x["pred_boxes"] for x in batched_inputs], dim=0)  # n, n_box, n_token, 256
            scores = torch.stack([x["scores"] for x in batched_inputs], dim=0)
        box_cls = logits_per_image
        box_pred = boxes
        if self.inference_box_type == 'GT':
            results = self.inference_gt_box(box_cls, box_pred, pred_mask, image_sizes)
        else:
            results = self.inference_pred_box(box_cls, box_pred, scores, pred_mask, image_sizes)
        if do_postprocess:
            processed_results = []
            for results_per_image, input_per_image, image_size in zip(results, batched_inputs, image_sizes):
                height = input_per_image.get("height", image_size[0])
                width = input_per_image.get("width", image_size[1])
                r = segmentation_postprocess(results_per_image, height, width)
                processed_results.append({"instances": r})
            return processed_results
        else:
            return results

    def foward_train(self, batched_inputs):
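        """
        Single training step under autocast; returns a loss dict in the format
        expected by the detectron2 training loop.
        """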
        with amp.autocast(enabled=True):
            outputs = self.model.forward_train(batched_inputs)
        loss = {'loss': outputs}
        return loss

    def forward(self, batched_inputs, do_postprocess=True):
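        """
        Dispatch to inference or training. Inputs are first converted into box
        prompts by `prepare_prompt_infer` / `prepare_prompt_train`, then passed
        to the corresponding forward helper.
        """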
        if not self.training:
            # Prepare Prompt.
            batched_inputs = prepare_prompt_infer(batched_inputs, pred_results=self.pred_results, target_size=self.clip_target_size)
            results = self.foward_inference(batched_inputs)
            return results

        if self.training:
            batched_inputs = prepare_prompt_train(batched_inputs, target_size=self.clip_target_size)
            loss_dict = self.foward_train(batched_inputs)
            return loss_dict

    def inference_gt_box(self, box_cls, box_pred, pred_mask, image_sizes=None):
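        """
        Decode predictions when the box prompts are ground-truth boxes: scores
        are the softmaxed classification logits, and at most the top 900
        (box, class) pairs per image are kept.
        """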
        device = box_cls.device  # assuming all tensors are on the same device
        results = []
        for logits, boxes, masks, img_size in zip(box_cls, box_pred, pred_mask, image_sizes):
            # Calculate probabilities and flatten them
            probs = F.softmax(logits, dim=-1)
            probs_flattened = probs.view(-1)
            # Determine number of top predictions to consider
            top_num = min(900, len(probs_flattened))
            # Get top-k values and indices
            topk_probs, topk_indices = torch.topk(probs_flattened, top_num)
            # Decode the top-k indices to get corresponding labels and boxes
            topk_labels = topk_indices % logits.shape[1]
            topk_boxes_indices = topk_indices // logits.shape[1]
            # Ensure boxes, masks and topk_boxes_indices are on the same device
            topk_boxes_indices = topk_boxes_indices.to(device)
            boxes = boxes.to(device)
            masks = masks.to(device)
            # Retrieve predictions using the top-k indices
            boxes_for_topk = boxes[topk_boxes_indices]
            masks_for_topk = masks[topk_boxes_indices]
            scores_for_topk = topk_probs  # Modify accordingly if you have another score tensor
            # Create Instances object for top-k predictions
            result = Instances(img_size)
            result.pred_boxes = Boxes(boxes_for_topk)
            result.scores = scores_for_topk
            result.pred_classes = topk_labels
            result.pred_masks = masks_for_topk  # Added masks to the result
            results.append(result)
        return results

    def inference_pred_box(self, box_cls, box_pred, box_score, masks, image_sizes=None):
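        """
        Decode predictions for detector box proposals: same top-k selection as
        `inference_gt_box`, but the final score is the proposal score multiplied
        by the classification probability.
        """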
        results = []
        for logits, box_pred_i, box_score_i, mask_i, img_size in zip(box_cls, box_pred, box_score, masks, image_sizes):
            logits = logits.cuda()
            box_pred_i = box_pred_i.cuda()
            box_score_i = box_score_i.cuda()
            # Calculate probabilities and flatten them
            probs = F.softmax(logits, dim=-1)
            probs_flattened = probs.view(-1)
            # Determine number of top predictions to consider
            top_num = min(900, len(probs_flattened))
            # Get top-k values and indices
            topk_probs, topk_indices = torch.topk(probs_flattened, top_num)
            # Decode the top-k indices to get corresponding labels and boxes
            topk_labels = topk_indices % logits.shape[1]
            topk_boxes_indices = topk_indices // logits.shape[1]
            # Retrieve predictions using the top-k indices
            topk_boxes = box_pred_i[topk_boxes_indices]
            topk_masks = mask_i[topk_boxes_indices]
            # Final score: proposal score weighted by the classification probability
            scores = box_score_i[topk_boxes_indices] * topk_probs
            # Construct result for the current image
            result = Instances(img_size)
            result.pred_boxes = Boxes(topk_boxes)
            result.scores = scores
            result.pred_classes = topk_labels
            result.pred_masks = topk_masks
            results.append(result)
        return results
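
# Usage sketch: how this meta-architecture is typically built through the
# detectron2 registry. Only the cfg keys read in `__init__` are taken from this
# file; the example values below are placeholders and the project's own config
# should be used instead.
#
#   from detectron2.config import get_cfg
#   from detectron2.modeling import build_model
#
#   cfg = get_cfg()
#   cfg.MODEL.META_ARCHITECTURE = "RegionSpot"  # selects this class via META_ARCH_REGISTRY
#   cfg.MODEL.CLIP_TYPE = "..."                 # placeholder: CLIP variant passed to build_regionspot_model
#   cfg.MODEL.BOX_TYPE = "GT"                   # 'GT' uses ground-truth boxes; otherwise GLIP proposals are loaded
#   cfg.MODEL.CLIP_INPUT_SIZE = 224             # placeholder: CLIP input resolution
#   cfg.MODEL.TRAINING = False
#
#   model = build_model(cfg)         # instantiates RegionSpot(cfg) and moves it to cfg.MODEL.DEVICE
#   outputs = model(batched_inputs)  # at inference: a list of {"instances": Instances}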