Spaces:
Build error
Build error
import sys | |
import random | |
import torch | |
import torch.nn as nn | |
from .point import Point | |
from .polygon import Polygon | |
from .scribble import Scribble | |
from .circle import Circle | |
from modeling.utils import configurable | |
class ShapeSampler(nn.Module):
    """Pick ground-truth instances and draw a random shape prompt for each one.

    Each call selects ground-truth masks (a random subset of at most
    ``max_candidate`` in training mode, all of them otherwise). For every
    selected mask it picks one drawer from ``shape_candidate``, weighted by
    ``shape_prob``, and calls ``drawer.draw(mask, box)``. Drawers are the
    classes named in the config, resolved from this module's namespace
    (e.g. the imported ``Point``, ``Polygon``, ``Scribble``, ``Circle``).
    """

    # NOTE(review): presumably the detectron2-style decorator: a call whose
    # first argument is a config is routed through ``from_config`` (this is how
    # ``build_shape_sampler`` passes ``cfg`` positionally).
    @configurable
    def __init__(self, max_candidate=1, shape_prob=None, shape_candidate=None, is_train=True):
        """Store the sampling configuration.

        Args:
            max_candidate: maximum number of instances kept per call when
                ``is_train`` is true.
            shape_prob: sampling weights for ``shape_candidate``, passed to
                ``random.choices``. ``None`` means an empty list.
            shape_candidate: drawer objects; each must provide
                ``draw(mask, box)`` and a ``repr`` used as the type label.
                ``None`` means an empty list.
            is_train: whether to shuffle and subsample the instances.
        """
        super().__init__()
        self.max_candidate = max_candidate
        # None rather than [] defaults so instances never share one list object.
        self.shape_prob = [] if shape_prob is None else shape_prob
        self.shape_candidate = [] if shape_candidate is None else shape_candidate
        self.is_train = is_train

    @classmethod
    def from_config(cls, cfg, is_train=True, mode=None):
        """Build the ``__init__`` keyword arguments from ``cfg['STROKE_SAMPLER']``.

        Args:
            cfg: mapping with a ``'STROKE_SAMPLER'`` entry holding
                ``MAX_CANDIDATE``, ``CANDIDATE_PROBS`` and ``CANDIDATE_NAMES``.
                The whole ``cfg`` is also passed to each drawer's constructor.
            is_train: forwarded to each drawer. When false (and ``mode`` is not
                ``'hack_train'``), the probabilities are replaced by a one-hot
                vector selecting the drawer named ``mode``.
            mode: ``'hack_train'`` builds every drawer with ``is_train=True``;
                otherwise, in evaluation, the name of the drawer to use.

        Returns:
            dict with keys ``max_candidate``, ``shape_prob``,
            ``shape_candidate`` and ``is_train``.

        Raises:
            ValueError: in evaluation when ``mode`` is not one of the
                configured candidate names.
        """
        sampler_cfg = cfg['STROKE_SAMPLER']
        max_candidate = sampler_cfg['MAX_CANDIDATE']
        candidate_probs = sampler_cfg['CANDIDATE_PROBS']
        candidate_names = sampler_cfg['CANDIDATE_NAMES']

        # Drawer classes are looked up by name in this module's namespace.
        module = sys.modules[__name__]
        if mode == 'hack_train':
            candidate_classes = [getattr(module, class_name)(cfg, True) for class_name in candidate_names]
        else:
            if not is_train:
                # Evaluation uses exactly one shape type: the one named by `mode`.
                if mode not in candidate_names:
                    raise ValueError(
                        f"mode must be one of {list(candidate_names)} when is_train is False, got {mode!r}"
                    )
                candidate_probs = [0.0] * len(candidate_names)
                candidate_probs[candidate_names.index(mode)] = 1.0
            candidate_classes = [getattr(module, class_name)(cfg, is_train) for class_name in candidate_names]

        return {
            "max_candidate": max_candidate,
            "shape_prob": candidate_probs,
            "shape_candidate": candidate_classes,
            "is_train": is_train,
        }

    def forward(self, instances):
        """Select instances and draw one random shape per selected mask.

        Args:
            instances: object exposing ``gt_masks.tensor`` and
                ``gt_boxes.tensor`` (presumably a detectron2-style
                ``Instances``; TODO confirm). It is not modified.

        Returns:
            dict with ``'gt_masks'``, ``'rand_shape'``, ``'types'`` and
            ``'sampler'`` (this module). A mask whose drawn shape is empty is
            zeroed in ``'gt_masks'`` and labelled ``'none'`` in ``'types'``.
        """
        masks = instances.gt_masks.tensor
        boxes = instances.gt_boxes.tensor

        if len(masks) == 0:
            # Placeholder so callers still receive one (empty) entry. Created on
            # the masks' device to match the non-empty path.
            # NOTE(review): rand_shape here is (1, h, w), a different rank from
            # the (c, iter, h, w) documented below — confirm callers handle it.
            empty = torch.zeros(masks.shape[-2:], dtype=torch.bool, device=masks.device)
            return {
                'gt_masks': empty[None, :],
                'rand_shape': torch.stack([empty]),
                'types': ['none'],
                'sampler': self,
            }

        if self.is_train:
            indices = list(range(len(masks)))
            random.shuffle(indices)
            keep = indices[:self.max_candidate]
            # List indexing copies, so the caller's masks are never aliased.
            candidate_mask = masks[keep]
            candidate_box = boxes[keep]
        else:
            # Clone: the loop below zeroes entries in place, which would
            # otherwise write into the caller's instances.gt_masks.
            candidate_mask = masks.clone()
            candidate_box = boxes

        draw_funcs = random.choices(self.shape_candidate, weights=self.shape_prob, k=len(candidate_mask))
        rand_shapes = [func.draw(mask, box) for func, mask, box in zip(draw_funcs, candidate_mask, candidate_box)]
        types = [repr(func) for func in draw_funcs]

        for i, shape in enumerate(rand_shapes):
            if shape.sum() == 0:
                # Nothing was drawn: blank the target so prompt and target agree.
                candidate_mask[i] = 0
                types[i] = 'none'

        # candidate_mask: (c,h,w), bool. rand_shape: (c, iter, h, w), bool. types: list(c)
        return {'gt_masks': candidate_mask, 'rand_shape': torch.stack(rand_shapes).bool(), 'types': types, 'sampler': self}
def build_shape_sampler(cfg, **kwargs):
    """Construct a ``ShapeSampler`` from a configuration.

    ``cfg`` is passed positionally and every keyword argument (for example
    ``is_train`` or ``mode``) is forwarded unchanged to the ``ShapeSampler``
    constructor. The positional ``cfg`` is expected to be routed through
    ``ShapeSampler.from_config`` by the ``configurable`` decorator on
    ``ShapeSampler.__init__``.
    """
    sampler = ShapeSampler(cfg, **kwargs)
    return sampler