FrankZxShen's picture
Upload 55 files
aa69275
raw
history blame contribute delete
No virus
1.45 kB
from modules.utils import *
class MaskFormer:
    """Text-prompted mask generator built on a CLIPSeg checkpoint.

    Despite the class name, the model loaded is
    ``CLIPSegForImageSegmentation`` from
    ``<pretrained_model_dir>/clipseg-rd64-refined``.
    """

    def __init__(self, device, pretrained_model_dir):
        """Load the CLIPSeg processor and model and move the model to *device*.

        Args:
            device: Target passed to ``.to(...)`` for the model and, later,
                for the processed inputs.
            pretrained_model_dir: Directory that contains the
                ``clipseg-rd64-refined`` checkpoint folder.
        """
        print(f"Initializing MaskFormer to {device}")
        self.device = device
        checkpoint = f"{pretrained_model_dir}/clipseg-rd64-refined"
        self.processor = CLIPSegProcessor.from_pretrained(checkpoint)
        self.model = CLIPSegForImageSegmentation.from_pretrained(checkpoint).to(device)

    def inference(self, image_path, text, *, threshold=0.5, min_area=0.02, padding=20):
        """Segment the region of an image described by *text*.

        Args:
            image_path: Path handed to ``Image.open``.
            text: Prompt forwarded to the processor as ``text=``.
            threshold: Sigmoid-probability cut-off; pixels strictly above it
                count as foreground.
            min_area: Minimum foreground fraction of the predicted mask.
                Below this the prompt is treated as "not found".
            padding: Number of pixels (in model-output resolution) by which
                every foreground pixel is grown in each direction.

        Returns:
            ``None`` when the foreground fraction is below *min_area*;
            otherwise an image built from a ``uint8`` array of 0/255 values,
            resized to the original image's size.
        """
        # Keep only what is needed from the source file so the handle is
        # released as soon as the resized copy exists.
        with Image.open(image_path) as original_image:
            original_size = original_image.size
            image = original_image.resize((512, 512))

        inputs = self.processor(
            text=text, images=image, padding="max_length", return_tensors="pt"
        ).to(self.device)
        with torch.no_grad():
            outputs = self.model(**inputs)

        # NOTE(review): assumes outputs[0] squeezes to a single 2-D logit map
        # (one prompt, one image) - confirm before passing a list of prompts.
        mask = torch.sigmoid(outputs[0]).squeeze().cpu().numpy() > threshold

        area_ratio = np.count_nonzero(mask) / (mask.shape[0] * mask.shape[1])
        if area_ratio < min_area:
            return None

        mask_array = self._dilate_mask(mask, padding)
        visual_mask = (mask_array * 255).astype(np.uint8)
        image_mask = Image.fromarray(visual_mask)
        return image_mask.resize(original_size)

    @staticmethod
    def _dilate_mask(mask, padding):
        """Grow each True cell of *mask* by *padding* cells along every axis.

        Equivalent to stamping a ``(2 * padding + 1)``-wide box (clipped at
        the array edges) around every True cell, but done with whole-array
        shifts instead of one Python-level iteration per foreground pixel.
        The input array is not modified.
        """
        dilated = mask.copy()
        # A box-shaped neighbourhood is separable: spreading along one axis
        # and then along the next gives the same result as the full box.
        for axis in range(dilated.ndim):
            # Snapshot of the state before this axis is processed, so shifts
            # do not compound within a single axis.
            source = np.moveaxis(dilated.copy(), axis, 0)
            target = np.moveaxis(dilated, axis, 0)  # view: writes land in `dilated`
            for shift in range(1, padding + 1):
                target[shift:] |= source[:-shift]
                target[:-shift] |= source[shift:]
        return dilated