# Spaces: Running on Zero
import os
import shutil

import cv2
import gradio as gr
import numpy as np
import requests
from tqdm import tqdm

from S2I.logger import logger
from S2I.samer import SegMent, generate_sam_args
class SAMController:
    """Callbacks that drive click-based and automatic segmentation.

    The class wires UI events (``gr.SelectData`` clicks, uploaded files) to a
    ``SegMent`` object built from ``generate_sam_args``. Helpers that do not
    touch instance state are static methods so they can be called both as
    ``SAMController.helper(...)`` and as ``self.helper(...)``.
    """

    def __init__(self):
        # Model type most recently handled by reload_segment (None until then).
        self.current_model_type = None
        # Last refined mask produced by sam_click (None until then).
        self.refine_mask = None

    @staticmethod
    def clean():
        """Return a tuple of empty values: five ``None`` and one ``[[]]``.

        NOTE(review): presumably used to reset UI components/state; the
        meaning of each slot is defined by the caller's output wiring.
        """
        return None, None, None, None, None, [[]]

    @staticmethod
    def save_mask(refined_mask=None, save=False):
        """Write the refined mask to disk, or report where it was written.

        Args:
            refined_mask: Mask to save. It is multiplied by 255 and cast to
                ``uint8`` before being written.
            save: When false the call does nothing.

        Returns:
            The path of ``output_render/refined_mask_result.png`` (relative to
            the current working directory) when ``save`` is true and
            ``refined_mask`` is ``None``; otherwise ``None``.
        """
        if not save:
            return None

        save_path = os.path.join(os.getcwd(), 'output_render')
        mask_path = os.path.join(save_path, 'refined_mask_result.png')
        if refined_mask is None:
            # Query mode: only report where the mask file lives.
            return mask_path

        # The output directory is recreated from scratch on every save, so
        # nothing from a previous save survives in it.
        if os.path.exists(save_path):
            shutil.rmtree(save_path)
        os.makedirs(save_path, exist_ok=True)
        cv2.imwrite(mask_path, (refined_mask * 255).astype('uint8'))
        return None

    @staticmethod
    def _download_file(url, destination, description,
                       chunk_size=1024 * 1024, timeout=30):
        """Stream ``url`` into ``destination`` with a progress bar.

        The payload is written to ``destination + '.part'`` and moved into
        place only after the whole body has been received, so an interrupted
        download never leaves a truncated file at ``destination``.

        Args:
            url: Address to fetch.
            destination: Final path of the downloaded file.
            description: Label shown on the progress bar.
            chunk_size: Number of bytes requested per iteration.
            timeout: Seconds passed to ``requests.get`` as its timeout.

        Raises:
            requests.HTTPError: On a non-2xx response.
        """
        partial_path = destination + '.part'
        completed = False
        try:
            with requests.get(url, stream=True, timeout=timeout) as response:
                response.raise_for_status()  # Raise an exception for non-2xx status codes
                # A missing content-length header yields a bar with total 0.
                total_size = int(response.headers.get('content-length', 0))
                with tqdm(total=total_size, unit="B", unit_scale=True, desc=description) as pbar:
                    with open(partial_path, 'wb') as f:
                        for chunk in response.iter_content(chunk_size=chunk_size):
                            f.write(chunk)
                            pbar.update(len(chunk))
            os.replace(partial_path, destination)
            completed = True
        finally:
            # Drop the partial file when anything above failed.
            if not completed and os.path.exists(partial_path):
                os.remove(partial_path)

    @staticmethod
    def download_models(model_type):
        """Download the checkpoint for ``model_type`` unless it already exists.

        The file is stored as ``root_model/sam_models/<model_type>.pth`` under
        the current working directory.

        Args:
            model_type: One of ``'vit_b'``, ``'vit_l'`` or ``'vit_h'``. Any
                other value is only logged as invalid.

        Returns:
            None.
        """
        dir_path = os.path.join(os.getcwd(), 'root_model')
        sam_models_path = os.path.join(dir_path, 'sam_models')

        # Models URLs
        models_urls = {
            'sam_models': {
                'vit_b': 'https://huggingface.co/ybelkada/segment-anything/resolve/main/checkpoints/sam_vit_b_01ec64.pth?download=true',
                'vit_l': 'https://huggingface.co/segments-arnaud/sam_vit_l/resolve/main/sam_vit_l_0b3195.pth?download=true',
                'vit_h': 'https://huggingface.co/segments-arnaud/sam_vit_h/resolve/main/sam_vit_h_4b8939.pth?download=true'
            }
        }

        model_url = models_urls['sam_models'].get(model_type)
        if model_url is None:
            logger.info(f"Invalid model type: {model_type}")
            return None

        os.makedirs(sam_models_path, exist_ok=True)
        model_path = os.path.join(sam_models_path, model_type + '.pth')

        if os.path.exists(model_path):
            logger.info(f"{model_type} model already exists.")
        else:
            logger.info(f"Downloading {model_type} model...")
            SAMController._download_file(
                model_url, model_path, f"Downloading {model_type} model")
            logger.info(f"{model_type} model downloaded.")

        logger.info(f"{model_type} model download complete.")
        return None

    @staticmethod
    def get_models_path(model_type=None, segment=False):
        """Return ``(sam_args, sam_models_path)`` for the checkpoint folder.

        Args:
            model_type: Forwarded to ``generate_sam_args`` when ``segment`` is
                true.
            segment: When true, build ``sam_args`` with ``generate_sam_args``;
                when false, ``sam_args`` is ``None``.

        Returns:
            A tuple of the generated arguments (or ``None``) and the path
            ``root_model/sam_models`` under the current working directory.
        """
        sam_models_path = os.path.join(os.getcwd(), 'root_model', 'sam_models')
        sam_args = None
        if segment:
            sam_args = generate_sam_args(sam_checkpoint=sam_models_path, model_type=model_type)
        return sam_args, sam_models_path

    @staticmethod
    def get_click_prompt(click_stack, point):
        """Append ``point`` to ``click_stack`` and build the click prompt.

        Args:
            click_stack: Two-element sequence ``[coords, modes]``; both lists
                are extended in place.
            point: Mapping with the keys ``"coord"`` and ``"mode"``.

        Returns:
            A dict with ``"points_coord"``, ``"points_mode"`` (the same list
            objects held by ``click_stack``) and ``"multi_mask"``.
        """
        click_stack[0].append(point["coord"])
        click_stack[1].append(point["mode"])

        prompt = {
            "points_coord": click_stack[0],
            "points_mode": click_stack[1],
            "multi_mask": "True",
        }
        return prompt

    @staticmethod
    def read_temp_file(temp_file_wrapper):
        """Read the file behind ``temp_file_wrapper.name``.

        Args:
            temp_file_wrapper: Any object exposing a ``name`` attribute that
                holds a readable file path.

        Returns:
            A tuple ``(file_content, name)`` with the raw bytes and the path.
        """
        name = temp_file_wrapper.name
        with open(name, 'rb') as f:
            file_content = f.read()
        return file_content, name

    def get_meta_from_image(self, input_img):
        """Decode an uploaded image file into an RGB array.

        Args:
            input_img: Object with a ``name`` attribute pointing at the image
                file, or ``None`` when no file is present.

        Returns:
            The decoded RGB frame twice (``(frame, frame)``), or
            ``(None, None)`` when ``input_img`` is ``None``.
        """
        if input_img is None:
            return None, None

        file_content, _ = self.read_temp_file(input_img)
        np_arr = np.frombuffer(file_content, np.uint8)

        # OpenCV decodes to BGR; convert to RGB before returning.
        img = cv2.imdecode(np_arr, cv2.IMREAD_COLOR)
        first_frame = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)
        return first_frame, first_frame

    def is_sam_model(self, model_type):
        """Make sure the checkpoint file for ``model_type`` is present.

        Args:
            model_type: Checkpoint identifier, used as the file stem of
                ``<model_type>.pth``.

        Returns:
            A tuple ``(status, sam_args)`` where ``status`` is
            ``'Model is downloaded'`` when the file was missing (and
            ``download_models`` was called) or
            ``'Model is already downloaded'`` otherwise.
        """
        sam_args, sam_models_dir = self.get_models_path(model_type=model_type, segment=True)
        model_path = os.path.join(sam_models_dir, model_type + '.pth')
        if os.path.exists(model_path):
            return 'Model is already downloaded', sam_args

        self.download_models(model_type=model_type)
        return 'Model is downloaded', sam_args

    @staticmethod
    def init_segment(
            points_per_side,
            origin_frame,
            sam_args,
            predict_iou_thresh=0.8,
            stability_score_thresh=0.9,
            crop_n_layers=1,
            crop_n_points_downscale_factor=2,
            min_mask_region_area=200):
        """Build a ``SegMent`` from ``sam_args`` with the given generator options.

        The options are written into ``sam_args["generator_args"]`` in place
        before ``SegMent`` is constructed.

        Returns:
            ``(segment, origin_frame, [[], []])``; ``segment`` is ``None`` when
            ``origin_frame`` is ``None``. The last element is a fresh, empty
            click stack.
        """
        if origin_frame is None:
            return None, origin_frame, [[], []]

        generator_args = sam_args["generator_args"]
        generator_args["points_per_side"] = points_per_side
        generator_args["pred_iou_thresh"] = predict_iou_thresh
        generator_args["stability_score_thresh"] = stability_score_thresh
        generator_args["crop_n_layers"] = crop_n_layers
        generator_args["crop_n_points_downscale_factor"] = crop_n_points_downscale_factor
        generator_args["min_mask_region_area"] = min_mask_region_area

        segment = SegMent(sam_args)
        logger.info(f"Model Init: {sam_args}")
        return segment, origin_frame, [[], []]

    @staticmethod
    def seg_acc_click(segment, prompt, origin_frame):
        """Run ``segment.seg_acc_click`` for the clicks stored in ``prompt``.

        Args:
            segment: Object exposing ``seg_acc_click``.
            prompt: Dict as produced by ``get_click_prompt``.
            origin_frame: Frame forwarded to the segmenter.

        Returns:
            The ``(refined_mask, masked_frame)`` pair returned by the segmenter.
        """
        refined_mask, masked_frame = segment.seg_acc_click(
            origin_frame=origin_frame,
            coords=np.array(prompt["points_coord"]),
            modes=np.array(prompt["points_mode"]),
            multimask=prompt["multi_mask"],
        )
        return refined_mask, masked_frame

    def undo_click_stack_and_refine_seg(self, segment, origin_frame, click_stack):
        """Drop the most recent click and re-run segmentation on the rest.

        Returns:
            ``(segment, frame, click_stack)``. When no segmenter exists or no
            clicks remain, ``frame`` is ``origin_frame`` and the click stack is
            a fresh ``[[], []]``; otherwise ``frame`` is the newly masked frame.
        """
        if segment is None:
            return segment, origin_frame, [[], []]

        logger.info("Undo !")
        if len(click_stack[0]) > 0:
            click_stack[0] = click_stack[0][:-1]
            click_stack[1] = click_stack[1][:-1]

        if len(click_stack[0]) == 0:
            return segment, origin_frame, [[], []]

        prompt = {
            "points_coord": click_stack[0],
            "points_mode": click_stack[1],
            "multi_mask": "True",
        }
        _, masked_frame = self.seg_acc_click(segment, prompt, origin_frame)
        return segment, masked_frame, click_stack

    def reload_segment(self,
                       check_sam,
                       segment,
                       model_type,
                       point_per_sides,
                       origin_frame,
                       predict_iou_thresh,
                       stability_score_thresh,
                       crop_n_layers,
                       crop_n_points_downscale_factor,
                       min_mask_region_area):
        """(Re)build the segmenter when none exists or a model was just fetched.

        Args:
            check_sam: Callable taking ``model_type`` and returning
                ``(status, sam_args)``, e.g. ``self.is_sam_model``.
            segment: Current segmenter or ``None``.

        Returns:
            ``(segment, model_type, status)``. ``self.current_model_type`` is
            set to ``model_type`` as a side effect.
        """
        status, sam_args = check_sam(model_type)
        if segment is None or status == 'Model is downloaded':
            segment, _, _ = self.init_segment(point_per_sides,
                                              origin_frame,
                                              sam_args,
                                              predict_iou_thresh,
                                              stability_score_thresh,
                                              crop_n_layers,
                                              crop_n_points_downscale_factor,
                                              min_mask_region_area)
        self.current_model_type = model_type
        return segment, self.current_model_type, status

    def sam_click(self,
                  evt: gr.SelectData,
                  segment,
                  origin_frame,
                  model_type,
                  point_mode,
                  click_stack,
                  point_per_sides,
                  predict_iou_thresh,
                  stability_score_thresh,
                  crop_n_layers,
                  crop_n_points_downscale_factor,
                  min_mask_region_area):
        """Handle a click: record it, segment, save and remember the mask.

        Args:
            evt: Select event; ``evt.index[0]`` and ``evt.index[1]`` are used
                as the click coordinate.
            point_mode: ``"Positive"`` records mode 1, anything else mode 0.

        Returns:
            ``(segment, masked_frame, click_stack, status)``.
        """
        logger.info("Click")
        mode = 1 if point_mode == "Positive" else 0
        point = {"coord": [evt.index[0], evt.index[1]], "mode": mode}
        click_prompt = self.get_click_prompt(click_stack, point)

        # The model switch must be detected BEFORE reload_segment runs,
        # because reload_segment overwrites current_model_type with
        # model_type. Dropping the old segmenter forces a rebuild for the
        # newly selected model. An unknown (None) previous type keeps the
        # existing segmenter.
        if (segment is not None
                and self.current_model_type is not None
                and model_type != self.current_model_type):
            segment = None

        segment, self.current_model_type, status = self.reload_segment(
            self.is_sam_model,
            segment,
            model_type,
            point_per_sides,
            origin_frame,
            predict_iou_thresh,
            stability_score_thresh,
            crop_n_layers,
            crop_n_points_downscale_factor,
            min_mask_region_area)

        refined_mask, masked_frame = self.seg_acc_click(segment, click_prompt, origin_frame)
        self.save_mask(refined_mask, save=True)
        self.refine_mask = refined_mask
        return segment, masked_frame, click_stack, status

    @staticmethod
    def normalize_image(image):
        """Rescale ``image`` linearly so its values span ``[0, 1]``.

        Args:
            image: Array exposing ``min()`` and ``max()``.

        Returns:
            ``(image - min) / (max - min)``. A constant image (``max == min``)
            maps to an all-zero floating-point array of the same shape instead
            of dividing by zero.
        """
        min_val = image.min()
        max_val = image.max()
        shifted = image - min_val
        span = max_val - min_val
        if span == 0:
            return np.zeros_like(shifted, dtype=np.result_type(shifted, np.float32))
        return shifted / span

    @staticmethod
    def compute_probability(masks):
        """Return the element-wise maximum of ``mask['prob']`` over ``masks``.

        Args:
            masks: Iterable of mappings that each carry a ``'prob'`` entry.

        Returns:
            The element-wise maximum, the single ``'prob'`` value itself when
            there is exactly one mask, or ``None`` when ``masks`` is empty.
        """
        p_max = None
        for mask in masks:
            p = mask['prob']
            p_max = p if p_max is None else np.maximum(p_max, p)
        return p_max

    @staticmethod
    def download_opencv_model(model_url):
        """Fetch the edge-detection model file and return its local path.

        The file is stored as ``edges_detection/edges_detection.yml.gz`` under
        the current working directory and is downloaded only when it is not
        already there.

        Args:
            model_url: Address of the model archive.

        Returns:
            Path of the local model file.
        """
        opencv_model_path = os.path.join(os.getcwd(), 'edges_detection')
        os.makedirs(opencv_model_path, exist_ok=True)
        model_path = os.path.join(opencv_model_path, 'edges_detection' + '.yml.gz')

        # Reuse an earlier download instead of hitting the network on every call.
        if not os.path.exists(model_path):
            SAMController._download_file(model_url, model_path, "Downloading opencv model")
        return model_path

    def automatic_sam2sketch(self,
                             segment,
                             image,
                             origin_frame,
                             model_type
                             ):
        """Turn automatically generated masks into a 3-channel edge image.

        Args:
            segment: Existing segmenter or ``None``.
            image: Input passed to ``segment.automatic_generate_mask``.
            origin_frame: Frame forwarded to ``init_segment`` when a segmenter
                has to be built.
            model_type: Checkpoint identifier.

        Returns:
            A ``uint8`` array with the inverted edge map repeated on three
            channels along the last axis.
        """
        _, sam_args = self.is_sam_model(model_type)
        # NOTE(review): sam_args is generated from this same model_type, so
        # the second condition presumably never fires - confirm against
        # generate_sam_args.
        if segment is None or model_type != sam_args['model_type']:
            segment, _, _ = self.init_segment(
                points_per_side=16,
                origin_frame=origin_frame,
                sam_args=sam_args,
                predict_iou_thresh=0.8,
                stability_score_thresh=0.9,
                crop_n_layers=1,
                crop_n_points_downscale_factor=2,
                min_mask_region_area=200)

        model_path = self.download_opencv_model(
            model_url='https://github.com/nipunmanral/Object-Detection-using-OpenCV/raw/master/model.yml.gz')

        masks = segment.automatic_generate_mask(image)
        p_max = self.compute_probability(masks)
        edges = self.normalize_image(p_max)

        # Thin the normalised map with the structured edge detector's
        # orientation estimate and non-maximum suppression.
        edge_detection = cv2.ximgproc.createStructuredEdgeDetection(model_path)
        orimap = edge_detection.computeOrientation(edges)
        edges = edge_detection.edgesNms(edges, orimap)

        # Scale to 8 bit, invert, and replicate onto three channels.
        edges = (edges * 255).astype('uint8')
        edges = 255 - edges
        edges = np.stack((edges,) * 3, axis=-1)
        return edges