Spaces:
Runtime error
Runtime error
# Copyright (c) Facebook, Inc. and its affiliates.
from typing import List, Optional, Tuple

import numpy as np
import torch
from einops import rearrange
from torch import nn
from torch.nn import functional as F

from detectron2.config import configurable
from detectron2.data import MetadataCatalog
from detectron2.modeling import META_ARCH_REGISTRY, build_backbone, build_sem_seg_head
from detectron2.modeling.backbone import Backbone
from detectron2.modeling.postprocessing import sem_seg_postprocess
from detectron2.structures import ImageList
from detectron2.utils.memory import _ignore_torch_cuda_oom
from segment_anything import SamAutomaticMaskGenerator, SamPredictor, sam_model_registry
@META_ARCH_REGISTRY.register()
class CATSeg(nn.Module):
    """CAT-Seg semantic segmentation meta-architecture with optional SAM post-processing.

    Dense per-class scores are produced by ``sem_seg_head`` from CLIP dense image
    features and backbone features. At test time the dense prediction can additionally
    be re-assigned per SAM mask (``use_sam``) and turned into a panoptic result
    (``panoptic_on``). Both flags are plain attributes that ``__init__`` sets to False.
    """

    # Class ids that panoptic_inference treats as "things"; every other class is merged
    # per class as "stuff".
    # NOTE(review): hard-coded for one particular label space -- confirm against the
    # dataset metadata before enabling panoptic output.
    thing_class_ids = (3, 6)

    @configurable
    def __init__(
        self,
        *,
        backbone: Backbone,
        sem_seg_head: nn.Module,
        size_divisibility: int,
        pixel_mean: Tuple[float],
        pixel_std: Tuple[float],
        clip_pixel_mean: Tuple[float],
        clip_pixel_std: Tuple[float],
        train_class_json: str,
        test_class_json: str,
        sliding_window: bool,
        clip_finetune: str,
        backbone_multiplier: float,
        clip_pretrained: str,
        sam_model_type: str = "vit_h",
        sam_checkpoint: str = "sam_vit_h_4b8939.pth",
        sam_points_per_side: int = 32,
    ):
        """
        Args:
            backbone: a backbone module, must follow detectron2's backbone interface
            sem_seg_head: a module that predicts semantic segmentation from backbone features
            size_divisibility: inputs are padded to a multiple of this; a negative value
                means "use ``backbone.size_divisibility``"
            pixel_mean, pixel_std: per-channel normalization of the backbone input
            clip_pixel_mean, clip_pixel_std: per-channel normalization of the CLIP input
            train_class_json, test_class_json: class-name files; only stored here
            sliding_window: use sliding-window inference at test time
            clip_finetune: which CLIP visual parameters stay trainable: "prompt",
                "attention" or "full"; any other value freezes all of them
            backbone_multiplier: backbone parameters (except "norm0") are trainable
                only when this is > 0
            clip_pretrained: CLIP variant name; "ViT-B/16" uses a 384x384 CLIP input,
                anything else 336x336
            sam_model_type: key into ``segment_anything.sam_model_registry``
            sam_checkpoint: path of the SAM checkpoint loaded at construction time
            sam_points_per_side: ``points_per_side`` of the SAM automatic mask generator
        """
        super().__init__()
        self.backbone = backbone
        self.sem_seg_head = sem_seg_head
        if size_divisibility < 0:
            size_divisibility = self.backbone.size_divisibility
        self.size_divisibility = size_divisibility

        # Non-persistent buffers: they follow .to(device) but are not saved in the state dict.
        self.register_buffer("pixel_mean", torch.Tensor(pixel_mean).view(-1, 1, 1), False)
        self.register_buffer("pixel_std", torch.Tensor(pixel_std).view(-1, 1, 1), False)
        self.register_buffer("clip_pixel_mean", torch.Tensor(clip_pixel_mean).view(-1, 1, 1), False)
        self.register_buffer("clip_pixel_std", torch.Tensor(clip_pixel_std).view(-1, 1, 1), False)

        self.train_class_json = train_class_json
        self.test_class_json = test_class_json
        self.clip_finetune = clip_finetune

        # Only "visual" CLIP parameters can be trainable, selected by clip_finetune.
        for name, params in self.sem_seg_head.predictor.clip_model.named_parameters():
            if "visual" in name:
                if clip_finetune == "prompt":
                    params.requires_grad = "prompt" in name
                elif clip_finetune == "attention":
                    params.requires_grad = "attn" in name or "position" in name
                elif clip_finetune == "full":
                    params.requires_grad = True
                else:
                    params.requires_grad = False
            else:
                params.requires_grad = False

        finetune_backbone = backbone_multiplier > 0.0
        for name, params in self.backbone.named_parameters():
            # "norm0" parameters stay frozen even when the backbone is fine-tuned.
            params.requires_grad = finetune_backbone and "norm0" not in name

        self.sliding_window = sliding_window
        self.clip_resolution = (384, 384) if clip_pretrained == "ViT-B/16" else (336, 336)
        # Set to True by forward() when batched sliding-window inference hits CUDA OOM.
        self.sequential = False
        self.use_sam = False
        self.panoptic_on = False
        # Minimum fraction of a SAM mask that must survive the per-pixel assignment in
        # panoptic_inference for its segment to be kept.
        self.overlap_threshold = 0.8

        # SAM is always loaded, even though use_sam / panoptic_on start out False.
        self.sam = sam_model_registry[sam_model_type](checkpoint=sam_checkpoint).to(self.device)
        self.mask = SamAutomaticMaskGenerator(
            self.sam, output_mode="binary_mask", points_per_side=sam_points_per_side
        )

    @classmethod
    def from_config(cls, cfg):
        """Build the keyword arguments of ``__init__`` from a detectron2 config."""
        backbone = build_backbone(cfg)
        sem_seg_head = build_sem_seg_head(cfg, backbone.output_shape())
        return {
            "backbone": backbone,
            "sem_seg_head": sem_seg_head,
            "size_divisibility": cfg.MODEL.MASK_FORMER.SIZE_DIVISIBILITY,
            "pixel_mean": cfg.MODEL.PIXEL_MEAN,
            "pixel_std": cfg.MODEL.PIXEL_STD,
            "clip_pixel_mean": cfg.MODEL.CLIP_PIXEL_MEAN,
            "clip_pixel_std": cfg.MODEL.CLIP_PIXEL_STD,
            "train_class_json": cfg.MODEL.SEM_SEG_HEAD.TRAIN_CLASS_JSON,
            "test_class_json": cfg.MODEL.SEM_SEG_HEAD.TEST_CLASS_JSON,
            "sliding_window": cfg.TEST.SLIDING_WINDOW,
            "clip_finetune": cfg.MODEL.SEM_SEG_HEAD.CLIP_FINETUNE,
            "backbone_multiplier": cfg.SOLVER.BACKBONE_MULTIPLIER,
            "clip_pretrained": cfg.MODEL.SEM_SEG_HEAD.CLIP_PRETRAINED,
        }

    @property
    def device(self):
        """Device holding the model's normalization buffers; inputs are moved here."""
        return self.pixel_mean.device

    def _generate_sam_masks(self, image):
        """Run the SAM automatic mask generator on one (C, H, W) image tensor.

        The tensor is moved to CPU, permuted to (H, W, C) and cast with ``np.uint8``.
        """
        return self.mask.generate(np.uint8(image.permute(1, 2, 0).cpu().numpy()))

    def forward(self, batched_inputs):
        """
        Args:
            batched_inputs: a list, batched outputs of :class:`DatasetMapper`.
                Each item in the list contains the inputs for one image.
                For now, each item in the list is a dict that contains:
                   * "image": Tensor, image in (C, H, W) format.
                   * "sem_seg": per-pixel ground-truth class ids (read in training).
                   * Other information that's included in the original dicts, such as:
                     "height", "width" (int): the output resolution of the model (may be different
                     from input resolution), used in inference.
        Returns:
            In training: a dict ``{"loss_sem_seg": loss}``.
            In inference: a list with ONE dict, for the first image only, containing
                * "sem_seg": per-pixel class scores of shape (K, height, width)
                  predicted by the head (raw logits on this path, sigmoid
                  probabilities on the sliding-window path);
                * "panoptic_seg": only from sliding-window inference with ``panoptic_on``.
        """
        images = [x["image"].to(self.device) for x in batched_inputs]
        # SAM is fed the un-normalized images.
        sam_images = images
        if not self.training and self.sliding_window:
            if not self.sequential:
                # If batched windows run out of CUDA memory, the OOM is swallowed and
                # we permanently fall back to one-window-at-a-time inference.
                with _ignore_torch_cuda_oom():
                    return self.inference_sliding_window(batched_inputs)
                self.sequential = True
            return self.inference_sliding_window(batched_inputs)

        clip_images = [(x - self.clip_pixel_mean) / self.clip_pixel_std for x in images]
        clip_images = ImageList.from_tensors(clip_images, self.size_divisibility)
        images = [(x - self.pixel_mean) / self.pixel_std for x in images]
        images = ImageList.from_tensors(images, self.size_divisibility)

        clip_images = F.interpolate(clip_images.tensor, size=self.clip_resolution, mode="bilinear", align_corners=False)
        clip_features = self.sem_seg_head.predictor.clip_model.encode_image(clip_images, dense=True)

        images_resized = F.interpolate(images.tensor, size=(384, 384), mode="bilinear", align_corners=False)
        features = self.backbone(images_resized)

        outputs = self.sem_seg_head(clip_features, features)

        if self.training:
            targets = torch.stack([x["sem_seg"].to(self.device) for x in batched_inputs], dim=0)
            outputs = F.interpolate(
                outputs, size=(targets.shape[-2], targets.shape[-1]), mode="bilinear", align_corners=False
            )
            num_classes = outputs.shape[1]
            # Per-class binary cross-entropy against one-hot targets. Pixels equal to
            # ignore_value keep an all-zero target; they are NOT excluded from the loss.
            mask = targets != self.sem_seg_head.ignore_value
            outputs = outputs.permute(0, 2, 3, 1)
            _targets = torch.zeros(outputs.shape, device=self.device)
            _onehot = F.one_hot(targets[mask], num_classes=num_classes).float()
            _targets[mask] = _onehot
            loss = F.binary_cross_entropy_with_logits(outputs, _targets)
            return {"loss_sem_seg": loss}

        image_size = images.image_sizes[0]
        if self.use_sam:
            masks = self._generate_sam_masks(sam_images[0])
            outputs, _ = self.discrete_semantic_inference(outputs, masks, image_size)
        height = batched_inputs[0].get("height", image_size[0])
        width = batched_inputs[0].get("width", image_size[1])
        output = sem_seg_postprocess(outputs[0], image_size, height, width)
        return [{"sem_seg": output}]

    def inference_sliding_window(self, batched_inputs, kernel=384, overlap=0.333,
                                 out_res: Optional[List[int]] = None):
        """Sliding-window inference on the first image of ``batched_inputs``.

        The image is resized to ``out_res`` and cut into ``kernel`` x ``kernel`` windows
        overlapping by ``overlap``; a downscaled copy of the whole image is appended as
        one extra "global" window. Window scores are passed through a sigmoid, averaged
        where windows overlap, then averaged with the upsampled global prediction.

        NOTE: ``out_res``, ``kernel`` and ``overlap`` should tile the image exactly
        (e.g. 640 / 384 / 0.333); border pixels no window covers would divide by zero.

        Args:
            batched_inputs: same format as in :meth:`forward`.
            kernel: window size in pixels.
            overlap: fraction of overlap between neighbouring windows.
            out_res: (H, W) working resolution; defaults to [640, 640].

        Returns:
            list[dict]: one dict with "sem_seg" of shape (C, height, width), plus
            "panoptic_seg" (a ``(panoptic_seg, segments_info)`` tuple) when
            ``self.panoptic_on`` is set.
        """
        if out_res is None:
            out_res = [640, 640]
        out_res = list(out_res)  # `[1] + out_res` below requires a list

        images = [x["image"].to(self.device, dtype=torch.float32) for x in batched_inputs]
        stride = int(kernel * (1 - overlap))
        unfold = nn.Unfold(kernel_size=kernel, stride=stride)
        fold = nn.Fold(out_res, kernel_size=kernel, stride=stride)

        image = F.interpolate(images[0].unsqueeze(0), size=out_res, mode="bilinear", align_corners=False).squeeze()
        # SAM runs on the resized but un-normalized image.
        sam_image = image
        image = rearrange(unfold(image), "(C H W) L-> L C H W", C=3, H=kernel)
        global_image = F.interpolate(images[0].unsqueeze(0), size=(kernel, kernel), mode="bilinear", align_corners=False)
        image = torch.cat((image, global_image), dim=0)

        images = (image - self.pixel_mean) / self.pixel_std
        clip_images = (image - self.clip_pixel_mean) / self.clip_pixel_std
        clip_images = F.interpolate(clip_images, size=self.clip_resolution, mode="bilinear", align_corners=False)
        clip_features = self.sem_seg_head.predictor.clip_model.encode_image(clip_images, dense=True)

        if self.sequential:
            # One window at a time (fallback after a CUDA OOM in forward()).
            outputs = []
            for clip_feat, window in zip(clip_features, images):
                feature = self.backbone(window.unsqueeze(0))
                output = self.sem_seg_head(clip_feat.unsqueeze(0), feature)
                outputs.append(output[0])
            outputs = torch.stack(outputs, dim=0)
        else:
            features = self.backbone(images)
            outputs = self.sem_seg_head(clip_features, features)

        outputs = F.interpolate(outputs, size=kernel, mode="bilinear", align_corners=False)
        outputs = outputs.sigmoid()

        # The last entry is the global (whole-image) window.
        global_output = outputs[-1:]
        global_output = F.interpolate(global_output, size=out_res, mode="bilinear", align_corners=False)
        outputs = outputs[:-1]
        # Stitch the windows back together; dividing by the fold of ones averages overlaps.
        outputs = fold(outputs.flatten(1).T) / fold(unfold(torch.ones([1] + out_res, device=self.device)))
        outputs = (outputs + global_output) / 2.0

        height = batched_inputs[0].get("height", out_res[0])
        width = batched_inputs[0].get("width", out_res[1])

        # The SAM mask generator is expensive; only run it when a consumer is enabled.
        masks = self._generate_sam_masks(sam_image) if (self.use_sam or self.panoptic_on) else None
        dense_outputs = outputs
        sam_cls = None
        if self.use_sam:
            outputs, sam_cls = self.discrete_semantic_inference(outputs, masks, out_res)
        output = sem_seg_postprocess(outputs[0], out_res, height, width)
        ret = [{"sem_seg": output}]
        if self.panoptic_on:
            if sam_cls is None:
                # Per-mask classes are needed for panoptic output even without use_sam.
                _, sam_cls = self.discrete_semantic_inference(dense_outputs, masks, out_res)
            catseg_outputs = sem_seg_postprocess(dense_outputs[0], out_res, height, width)
            ret[0]["panoptic_seg"] = self.panoptic_inference(
                catseg_outputs, masks, sam_cls, size=output.shape[-2:]
            )
        return ret

    def discrete_semantic_inference(self, outputs, masks, image_size):
        """Give every SAM mask the majority CAT-Seg class of its pixels.

        Args:
            outputs: class scores with a leading batch dimension; only the first image
                is used. They are bilinearly resized to ``image_size``.
            masks: output of ``SamAutomaticMaskGenerator.generate`` with
                ``output_mode="binary_mask"``: dicts holding a boolean
                ``"segmentation"`` array and a ``"stability_score"``.
            image_size: (H, W) target size.

        Returns:
            (sam_outputs, sam_classes): ``sam_outputs`` is a CPU tensor shaped like the
            resized ``outputs``. It is zero except that each mask's pixels carry the
            mask's stability score in the channel of its class; where masks of the same
            class overlap, the later mask wins. ``sam_classes`` is a float tensor holding
            one class index per mask. Empty masks keep class 0.
        """
        resized = F.interpolate(outputs, size=image_size, mode="bilinear", align_corners=True)
        sam_outputs = torch.zeros_like(resized).cpu()
        class_map = resized.argmax(dim=1)[0].cpu()
        sam_classes = torch.zeros(len(masks))
        for i, mask_info in enumerate(masks):
            m = mask_info["segmentation"]
            if not m.any():
                # bincount().argmax() raises on an empty selection.
                continue
            idx = class_map[m].bincount().argmax()
            sam_outputs[0, idx][m] = mask_info["stability_score"]
            sam_classes[i] = idx
        return sam_outputs, sam_classes

    def continuous_semantic_inference(self, outputs, masks, image_size, scale=100 / 7.0):
        """Soft variant of :meth:`discrete_semantic_inference`.

        Each mask's class scores are the mean of the resized prediction over its
        pixels, L1-normalized over classes. The output is the sum over masks of
        ``segmentation * predicted_iou`` weighted by those class scores.
        ``scale`` is currently unused.

        Returns:
            (output, mask_cls): a (1, C, H, W) CPU tensor and the (N, C) per-mask
            class scores.
        """
        dense = F.interpolate(outputs, size=image_size, mode="bilinear", align_corners=True)[0].cpu()
        mask_pred = torch.tensor(np.asarray([x["segmentation"] for x in masks]), dtype=torch.float32)  # N H W
        mask_score = torch.tensor(np.asarray([x["predicted_iou"] for x in masks]), dtype=torch.float32)  # N
        mask_cls = torch.einsum("nhw, chw -> nc", mask_pred, dense)
        mask_area = mask_pred.sum(-1).sum(-1)
        mask_cls = mask_cls / mask_area[:, None]
        mask_cls = mask_cls / mask_cls.norm(p=1, dim=1)[:, None]
        mask_logits = mask_pred * mask_score[:, None, None]
        output = torch.einsum("nhw, nc -> chw", mask_logits, mask_cls)
        return output.unsqueeze(0), mask_cls

    def continuous_semantic_inference2(self, outputs, masks, image_size, scale=100 / 7.0, img=None, text=None):
        """Variant that classifies SAM masks from pooled ``img`` features.

        ``img`` is resized to ``image_size`` and sum-pooled inside each mask. The pooled
        features are L2-normalized, compared with ``text`` (x100, softmax), then
        area-divided and L1-normalized. ``outputs`` and ``scale`` are unused.
        NOTE(review): appears to be an experimental path; it is not called by this class.

        Returns:
            (output, sam_classes): a (1, C, H, W) CPU tensor and an all-zero tensor of
            length N.
        """
        assert img is not None and text is not None
        img = F.interpolate(img, size=image_size, mode="bilinear", align_corners=True)[0].cpu()
        img = img.permute(1, 2, 0)
        sam_classes = torch.zeros(len(masks))
        mask_pred = torch.tensor(np.asarray([x["segmentation"] for x in masks]), dtype=torch.float32)  # N H W
        mask_score = torch.tensor(np.asarray([x["predicted_iou"] for x in masks]), dtype=torch.float32)  # N
        mask_pool = torch.einsum("nhw, hwd -> nd ", mask_pred, img)
        mask_pool = mask_pool / mask_pool.norm(dim=1, keepdim=True)
        mask_cls = torch.einsum("nd, cd -> nc", 100 * mask_pool, text.cpu())
        mask_cls = mask_cls.softmax(dim=1)
        mask_area = mask_pred.sum(-1).sum(-1)
        mask_cls = mask_cls / mask_area[:, None]
        mask_cls = mask_cls / mask_cls.norm(p=1, dim=1)[:, None]
        mask_logits = mask_pred * mask_score[:, None, None]
        output = torch.einsum("nhw, nc -> chw", mask_logits, mask_cls)
        return output.unsqueeze(0), sam_classes

    def panoptic_inference(self, outputs, masks, sam_classes, size=None):
        """Merge SAM masks into a panoptic segmentation.

        Each pixel covered by some mask goes to the covering mask with the highest
        ``predicted_iou``; pixels that no mask covers stay 0 (void). A mask is dropped
        when less than ``self.overlap_threshold`` of it survives that assignment. Masks
        of the same class outside ``thing_class_ids`` share one segment id.

        Args:
            outputs: dense prediction whose last two dims give the output (H, W); SAM
                masks are resized to it with nearest-neighbour sampling.
            masks: SAM automatic-mask-generator output (``"segmentation"``,
                ``"predicted_iou"``).
            sam_classes: per-mask class indices (N,), as returned by
                :meth:`discrete_semantic_inference`, or per-mask class scores (N, C),
                as returned by :meth:`continuous_semantic_inference`.
            size: unused.

        Returns:
            (panoptic_seg, segments_info): an int32 (H, W) map of segment ids and a
            list of ``{"id", "isthing", "category_id"}`` dicts.
        """
        h, w = outputs.shape[-2:]
        if len(masks) == 0:
            # np.asarray([]) would not have the (N, H, W) shape the code below needs.
            return torch.zeros((h, w), dtype=torch.int32), []

        cur_scores = torch.tensor(np.asarray([x["predicted_iou"] for x in masks]))
        cur_masks = torch.tensor(np.asarray([x["segmentation"] for x in masks]))
        cur_masks = F.interpolate(cur_masks.unsqueeze(0).float(), size=(h, w), mode="nearest")[0]
        if sam_classes.dim() == 1:
            cur_classes = sam_classes.long()
        else:
            cur_classes = sam_classes.argmax(dim=-1)

        cur_prob_masks = cur_scores.view(-1, 1, 1) * cur_masks
        panoptic_seg = torch.zeros((h, w), dtype=torch.int32, device=cur_masks.device)
        segments_info = []
        current_segment_id = 0

        cur_mask_ids = cur_prob_masks.argmax(0)
        stuff_memory_list = {}
        for k in range(cur_classes.shape[0]):
            pred_class = int(cur_classes[k].item())
            isthing = pred_class in self.thing_class_ids
            mask_area = (cur_mask_ids == k).sum().item()
            original_area = (cur_masks[k] >= 0.5).sum().item()
            # Restrict to pixels this mask actually covers, so uncovered pixels (whose
            # argmax defaults to mask 0) stay void.
            mask = (cur_mask_ids == k) & (cur_masks[k] >= 0.5)
            if mask_area == 0 or original_area == 0 or mask.sum().item() == 0:
                continue
            if mask_area / original_area < self.overlap_threshold:
                continue
            # Merge stuff regions of the same class into one segment.
            if not isthing:
                if pred_class in stuff_memory_list:
                    panoptic_seg[mask] = stuff_memory_list[pred_class]
                    continue
                stuff_memory_list[pred_class] = current_segment_id + 1
            current_segment_id += 1
            panoptic_seg[mask] = current_segment_id
            segments_info.append(
                {
                    "id": current_segment_id,
                    "isthing": bool(isthing),
                    "category_id": pred_class,
                }
            )
        return panoptic_seg, segments_info