from modules.utils import *
# NOTE: the wildcard import above is assumed to provide Image (PIL), np (NumPy), torch,
# CLIPSegProcessor, and CLIPSegForImageSegmentation (transformers), which this module uses.


class MaskFormer:

    def __init__(self, device, pretrained_model_dir):
        print("Initializing MaskFormer to %s" % device)
        self.device = device
        # Load the CLIPSeg (rd64, refined) processor and segmentation model from the
        # local pretrained model directory and move the model to the target device.
        self.processor = CLIPSegProcessor.from_pretrained(f"{pretrained_model_dir}/clipseg-rd64-refined")
        self.model = CLIPSegForImageSegmentation.from_pretrained(f"{pretrained_model_dir}/clipseg-rd64-refined").to(device)

    def inference(self, image_path, text):
        # Pixel probability cutoff, minimum mask-area fraction, and dilation padding (in pixels).
        threshold = 0.5
        min_area = 0.02
        padding = 20
        original_image = Image.open(image_path)
        image = original_image.resize((512, 512))
        # Encode the text prompt and the resized image for CLIPSeg.
        inputs = self.processor(text=text, images=image, padding="max_length", return_tensors="pt").to(self.device)
        with torch.no_grad():
            outputs = self.model(**inputs)
        # Threshold the sigmoid of the segmentation logits to get a boolean mask.
        mask = torch.sigmoid(outputs[0]).squeeze().cpu().numpy() > threshold
        # If the matched region covers less than min_area of the image, treat it as no match.
        area_ratio = len(np.argwhere(mask)) / (mask.shape[0] * mask.shape[1])
        if area_ratio < min_area:
            return None
        # Dilate the mask by marking a (2 * padding + 1)-sized square around every positive pixel.
        true_indices = np.argwhere(mask)
        mask_array = np.zeros_like(mask, dtype=bool)
        for idx in true_indices:
            padded_slice = tuple(slice(max(0, i - padding), i + padding + 1) for i in idx)
            mask_array[padded_slice] = True
        # Convert to an 8-bit grayscale mask and scale it back to the original image size.
        visual_mask = (mask_array * 255).astype(np.uint8)
        image_mask = Image.fromarray(visual_mask)
        return image_mask.resize(original_image.size)
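

# Usage sketch (not part of the original module): the device string, checkpoint
# directory, image path, and text prompt below are illustrative assumptions.
if __name__ == "__main__":
    segmenter = MaskFormer(device="cuda:0", pretrained_model_dir="./checkpoints")
    mask = segmenter.inference("example.jpg", "a dog")  # hypothetical inputs
    if mask is None:
        print("No region matched the prompt (mask area below min_area).")
    else:
        mask.save("example_mask.png")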