Spaces:
Runtime error
Runtime error
import json | |
import cv2 | |
import numpy as np | |
from loguru import logger | |
from lama_cleaner.helper import download_model | |
from lama_cleaner.plugins.base_plugin import BasePlugin | |
from lama_cleaner.plugins.segment_anything import SamPredictor, sam_model_registry | |
# Segment Anything checkpoints, ordered from the smallest model to the largest.
# Each entry maps a model type to the checkpoint download URL and the md5 value
# that is handed to download_model alongside it.
SEGMENT_ANYTHING_MODELS = {
    model_type: {"url": checkpoint_url, "md5": checkpoint_md5}
    for model_type, checkpoint_url, checkpoint_md5 in (
        (
            "vit_b",
            "https://dl.fbaipublicfiles.com/segment_anything/sam_vit_b_01ec64.pth",
            "01ec64d29a2fca3f0661936605ae66f8",
        ),
        (
            "vit_l",
            "https://dl.fbaipublicfiles.com/segment_anything/sam_vit_l_0b3195.pth",
            "0b3195507c641ddb6910d2bb5adee89c",
        ),
        (
            "vit_h",
            "https://dl.fbaipublicfiles.com/segment_anything/sam_vit_h_4b8939.pth",
            "4b8939a88964f0f4ff5f5b2642c598a6",
        ),
    )
}
class InteractiveSeg(BasePlugin):
    """Click-driven interactive segmentation plugin built on Segment Anything.

    Each request carries a list of clicks. The plugin asks the SAM predictor for
    a single mask, dilates it, and returns it painted as a coloured RGBA overlay.
    """

    name = "InteractiveSeg"

    # Side length of the square structuring element used to dilate the predicted
    # mask. A class attribute (default 9, the previous hard-coded value) so it can
    # be tuned per subclass/instance without editing ``forward``.
    dilate_kernel_size = 9

    def __init__(self, model_name, device):
        """Download the checkpoint for *model_name* and load it onto *device*.

        Args:
            model_name: a key of ``SEGMENT_ANYTHING_MODELS`` ("vit_b", "vit_l" or
                "vit_h"); any other value raises ``KeyError``.
            device: forwarded to ``.to()`` on the model built by
                ``sam_model_registry`` (presumably a torch device or device
                string — TODO confirm).
        """
        super().__init__()
        model_info = SEGMENT_ANYTHING_MODELS[model_name]
        model_path = download_model(model_info["url"], model_info["md5"])
        logger.info(f"SegmentAnything model path: {model_path}")
        self.predictor = SamPredictor(
            sam_model_registry[model_name](checkpoint=model_path).to(device)
        )
        # md5 of the image whose embedding is currently loaded in the predictor;
        # None until an image has been embedded.
        self.prev_img_md5 = None

    def __call__(self, rgb_np_img, files, form):
        """Plugin entry point: decode the request form and delegate to ``forward``.

        Args:
            rgb_np_img: the image to segment.
            files: not used by this plugin.
            form: mapping holding a JSON-encoded ``"clicks"`` list and an optional
                ``"img_md5"`` fingerprint of the image. Assumed to provide a
                dict-style ``.get`` (e.g. a Flask/Werkzeug form) — TODO confirm.

        Returns:
            The RGBA overlay produced by ``forward``.
        """
        clicks = json.loads(form["clicks"])
        # A missing md5 is tolerated: forward() then simply re-embeds the image.
        return self.forward(rgb_np_img, clicks, form.get("img_md5"))

    def forward(self, rgb_np_img, clicks, img_md5):
        """Segment the object indicated by *clicks* and return an RGBA overlay.

        Args:
            rgb_np_img: image passed to ``SamPredictor.set_image`` (assumed to be
                an H x W x 3 RGB array — TODO confirm).
            clicks: sequence of ``[x, y, label]`` items. ``x``/``y`` become SAM
                point coordinates and ``label`` the matching point label
                (presumably SAM's 1 = foreground, 0 = background).
            img_md5: fingerprint of *rgb_np_img*. If it equals the previous call's
                value, the embedding already in the predictor is reused. If it is
                falsy, the image is always re-embedded.

        Returns:
            uint8 array of shape (H, W, 4): the dilated mask painted in the brush
            colour, fully transparent elsewhere. With no clicks, an all-zero
            overlay matching the image's height and width is returned.
        """
        if not clicks:
            # SAM cannot predict from an empty prompt (it would fail indexing a
            # 1-D empty coordinate array); return a fully transparent overlay.
            height, width = rgb_np_img.shape[:2]
            return np.zeros((height, width, 4), dtype=np.uint8)

        input_point = np.array([[click[0], click[1]] for click in clicks])
        input_label = np.array([click[2] for click in clicks])

        # Embedding the image is the expensive step, so it is skipped only when
        # the same md5 proves this image is already loaded. Without an md5 there
        # is no such proof, so re-embed instead of predicting on a stale image.
        if not img_md5 or img_md5 != self.prev_img_md5:
            self.predictor.set_image(rgb_np_img)
            # Record the md5 only once set_image succeeded, so a failure cannot
            # leave the cache claiming an image that was never embedded.
            self.prev_img_md5 = img_md5

        masks, _scores, _logits = self.predictor.predict(
            point_coords=input_point,
            point_labels=input_label,
            multimask_output=False,
        )
        # Boolean mask -> 0/255 uint8 so it can be dilated and compared below.
        mask = masks[0].astype(np.uint8) * 255
        kernel = np.ones(
            (self.dilate_kernel_size, self.dilate_kernel_size), np.uint8
        )
        mask = cv2.dilate(mask, kernel, iterations=1)

        # Approximates the frontend brush colour "ffcc00bb" (255, 204, 0, 187).
        res_mask = np.zeros((mask.shape[0], mask.shape[1], 4), dtype=np.uint8)
        res_mask[mask == 255] = [255, 203, 0, int(255 * 0.73)]
        # NOTE(review): this swaps channels 0 and 2, so each masked pixel of the
        # returned array holds (0, 203, 255, alpha). Presumably the caller encodes
        # it with OpenCV (which expects BGRA), restoring the ffcc00-like colour in
        # the final image — confirm against the caller before changing.
        res_mask = cv2.cvtColor(res_mask, cv2.COLOR_BGRA2RGBA)
        return res_mask