from S2I.samer import SegMent, generate_sam_args
from S2I.logger import logger
from tqdm import tqdm
import gradio as gr
import numpy as np
import os
import shutil
import cv2
import requests
class SAMController:
def __init__(self):
self.current_model_type = None
self.refine_mask = None
@staticmethod
def clean():
return None, None, None, None, None, [[]]
@staticmethod
def save_mask(refined_mask=None, save=False):
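        # Two modes: with a mask and save=True, (re)create ./output_render and write the
        # refined mask as an 8-bit PNG; with no mask and save=True, return the path of
        # the previously saved PNG instead.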
if refined_mask is not None and save:
if os.path.exists(os.path.join(os.getcwd(), 'output_render')):
shutil.rmtree(os.path.join(os.getcwd(), 'output_render'))
save_path = os.path.join(os.getcwd(), 'output_render')
os.makedirs(save_path, exist_ok=True)
            cv2.imwrite(os.path.join(save_path, 'refined_mask_result.png'), (refined_mask * 255).astype('uint8'))
        elif refined_mask is None and save:
            return os.path.join(os.getcwd(), 'output_render', 'refined_mask_result.png')
@staticmethod
def download_models(model_type):
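        # Download the requested SAM checkpoint (vit_b / vit_l / vit_h) into
        # ./root_model/sam_models, skipping the download when the file already exists.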
dir_path = os.path.join(os.getcwd(), 'root_model')
sam_models_path = os.path.join(dir_path, 'sam_models')
# Models URLs
models_urls = {
'sam_models': {
'vit_b': 'https://huggingface.co/ybelkada/segment-anything/resolve/main/checkpoints/sam_vit_b_01ec64.pth?download=true',
'vit_l': 'https://huggingface.co/segments-arnaud/sam_vit_l/resolve/main/sam_vit_l_0b3195.pth?download=true',
'vit_h': 'https://huggingface.co/segments-arnaud/sam_vit_h/resolve/main/sam_vit_h_4b8939.pth?download=true'
}
}
# Download specified model type
if model_type in models_urls['sam_models']:
model_url = models_urls['sam_models'][model_type]
os.makedirs(sam_models_path, exist_ok=True)
model_path = os.path.join(sam_models_path, model_type + '.pth')
if not os.path.exists(model_path):
logger.info(f"Downloading {model_type} model...")
response = requests.get(model_url, stream=True)
response.raise_for_status() # Raise an exception for non-2xx status codes
total_size = int(response.headers.get('content-length', 0)) # Get file size from headers
with tqdm(total=total_size, unit="B", unit_scale=True, desc=f"Downloading {model_type} model") as pbar:
with open(model_path, 'wb') as f:
for chunk in response.iter_content(chunk_size=1024):
f.write(chunk)
pbar.update(len(chunk))
logger.info(f"{model_type} model downloaded.")
else:
logger.info(f"{model_type} model already exists.")
            logger.info(f"{model_type} model download complete.")
        else:
            logger.info(f"Invalid model type: {model_type}")
@staticmethod
def get_models_path(model_type=None, segment=False):
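        # Build SAM arguments pointing at the local checkpoint directory; a value is
        # only returned when segment=True, which is how the method is called here.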
sam_models_path = os.path.join(os.getcwd(), 'root_model', 'sam_models')
if segment:
sam_args = generate_sam_args(sam_checkpoint=sam_models_path, model_type=model_type)
return sam_args, sam_models_path
@staticmethod
def get_click_prompt(click_stack, point):
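        # Record the clicked coordinate and its mode (1 = positive, 0 = negative) and
        # wrap the accumulated click history into a prompt dictionary.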
        click_stack[0].append(point["coord"])
        click_stack[1].append(point["mode"])
prompt = {
"points_coord": click_stack[0],
"points_mode": click_stack[1],
"multi_mask": "True",
}
return prompt
@staticmethod
def read_temp_file(temp_file_wrapper):
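        # The input is a temp-file-like wrapper (e.g. a Gradio upload); read its raw
        # bytes and return them together with the underlying file name.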
name = temp_file_wrapper.name
with open(temp_file_wrapper.name, 'rb') as f:
# Read the content of the file
file_content = f.read()
return file_content, name
def get_meta_from_image(self, input_img):
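        # Decode the uploaded image and convert BGR -> RGB; the same frame is returned
        # twice so it can seed both the working image and the displayed origin frame.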
file_content, _ = self.read_temp_file(input_img)
np_arr = np.frombuffer(file_content, np.uint8)
img = cv2.imdecode(np_arr, cv2.IMREAD_COLOR)
first_frame = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)
return first_frame, first_frame
def is_sam_model(self, model_type):
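        # Ensure the checkpoint for the requested model type exists locally, downloading
        # it on first use, and return a status message plus the matching SAM arguments.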
sam_args, sam_models_dir = self.get_models_path(model_type=model_type, segment=True)
model_path = os.path.join(sam_models_dir, model_type + '.pth')
if not os.path.exists(model_path):
self.download_models(model_type=model_type)
return 'Model is downloaded', sam_args
else:
return 'Model is already downloaded', sam_args
@staticmethod
def init_segment(
points_per_side,
origin_frame,
sam_args,
predict_iou_thresh=0.8,
stability_score_thresh=0.9,
crop_n_layers=1,
crop_n_points_downscale_factor=2,
min_mask_region_area=200):
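        # Fill in the automatic-mask-generator arguments and build a SegMent instance
        # for the current frame; bail out early when no frame has been loaded yet.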
if origin_frame is None:
return None, origin_frame, [[], []]
sam_args["generator_args"]["points_per_side"] = points_per_side
sam_args["generator_args"]["pred_iou_thresh"] = predict_iou_thresh
sam_args["generator_args"]["stability_score_thresh"] = stability_score_thresh
sam_args["generator_args"]["crop_n_layers"] = crop_n_layers
sam_args["generator_args"]["crop_n_points_downscale_factor"] = crop_n_points_downscale_factor
sam_args["generator_args"]["min_mask_region_area"] = min_mask_region_area
segment = SegMent(sam_args)
logger.info(f"Model Init: {sam_args}")
return segment, origin_frame, [[], []]
@staticmethod
def seg_acc_click(segment, prompt, origin_frame):
# seg acc to click
refined_mask, masked_frame = segment.seg_acc_click(
origin_frame=origin_frame,
coords=np.array(prompt["points_coord"]),
modes=np.array(prompt["points_mode"]),
multimask=prompt["multi_mask"],
)
return refined_mask, masked_frame
def undo_click_stack_and_refine_seg(self, segment, origin_frame, click_stack):
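        # Drop the most recent click from the stack and re-run segmentation on the
        # remaining clicks; when nothing is left, show the unmodified frame again.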
if segment is None:
return segment, origin_frame, [[], []]
logger.info("Undo !")
if len(click_stack[0]) > 0:
click_stack[0] = click_stack[0][: -1]
click_stack[1] = click_stack[1][: -1]
if len(click_stack[0]) > 0:
prompt = {
"points_coord": click_stack[0],
"points_mode": click_stack[1],
"multi_mask": "True",
}
_, masked_frame = self.seg_acc_click(segment, prompt, origin_frame)
return segment, masked_frame, click_stack
else:
return segment, origin_frame, [[], []]
def reload_segment(self,
check_sam,
segment,
model_type,
point_per_sides,
origin_frame,
predict_iou_thresh,
stability_score_thresh,
crop_n_layers,
crop_n_points_downscale_factor,
min_mask_region_area):
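        # Re-create the segmenter when it does not exist yet or when the checkpoint was
        # just downloaded; check_sam is expected to behave like is_sam_model.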
status, sam_args = check_sam(model_type)
if segment is None or status == 'Model is downloaded':
segment, _, _ = self.init_segment(point_per_sides,
origin_frame,
sam_args,
predict_iou_thresh,
stability_score_thresh,
crop_n_layers,
crop_n_points_downscale_factor,
min_mask_region_area)
self.current_model_type = model_type
return segment, self.current_model_type, status
def sam_click(self,
evt: gr.SelectData,
segment,
origin_frame,
model_type,
point_mode,
click_stack,
point_per_sides,
predict_iou_thresh,
stability_score_thresh,
crop_n_layers,
crop_n_points_downscale_factor,
min_mask_region_area):
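        # Handle a click on the Gradio image: record the point, (re)load the segmenter
        # when needed (including when the model type changed), run click-based
        # segmentation, and persist the refined mask to disk.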
logger.info("Click")
if point_mode == "Positive":
point = {"coord": [evt.index[0], evt.index[1]], "mode": 1}
else:
point = {"coord": [evt.index[0], evt.index[1]], "mode": 0}
click_prompt = self.get_click_prompt(click_stack, point)
segment, self.current_model_type, status = self.reload_segment(
self.is_sam_model,
segment,
model_type,
point_per_sides,
origin_frame,
predict_iou_thresh,
stability_score_thresh,
crop_n_layers,
crop_n_points_downscale_factor,
min_mask_region_area)
if segment is not None and model_type != self.current_model_type:
segment = None
segment, _, status = self.reload_segment(
self.is_sam_model,
segment,
model_type,
point_per_sides,
origin_frame,
predict_iou_thresh,
stability_score_thresh,
crop_n_layers,
crop_n_points_downscale_factor,
min_mask_region_area)
refined_mask, masked_frame = self.seg_acc_click(segment, click_prompt, origin_frame)
self.save_mask(refined_mask, save=True)
self.refine_mask = refined_mask
return segment, masked_frame, click_stack, status
@staticmethod
def normalize_image(image):
# Normalize the image to the range [0, 1]
min_val = image.min()
max_val = image.max()
image = (image - min_val) / (max_val - min_val)
return image
@staticmethod
def compute_probability(masks):
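        # Reduce the per-mask probability maps to a single map by taking the
        # element-wise maximum across all masks.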
p_max = None
for mask in masks:
p = mask['prob']
if p_max is None:
p_max = p
else:
p_max = np.maximum(p_max, p)
return p_max
@staticmethod
def download_opencv_model(model_url):
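        # Fetch the OpenCV structured-edge-detection model (.yml.gz) used to turn the
        # SAM probability map into a sketch-like edge image.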
opencv_model_path = os.path.join(os.getcwd(), 'edges_detection')
os.makedirs(opencv_model_path, exist_ok=True)
model_path = os.path.join(opencv_model_path, 'edges_detection' + '.yml.gz')
response = requests.get(model_url, stream=True)
response.raise_for_status() # Raise an exception for non-2xx status codes
total_size = int(response.headers.get('content-length', 0)) # Get file size from headers
with tqdm(total=total_size, unit="B", unit_scale=True, desc=f"Downloading opencv model") as pbar:
with open(model_path, 'wb') as f:
for chunk in response.iter_content(chunk_size=1024):
f.write(chunk)
pbar.update(len(chunk))
return model_path
def automatic_sam2sketch(self,
segment,
image,
origin_frame,
model_type
):
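        # Run SAM's automatic mask generation on the image, merge the per-mask
        # probabilities, and post-process them with OpenCV structured edge detection
        # into an inverted, 3-channel sketch image.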
_, sam_args = self.is_sam_model(model_type)
if segment is None or model_type != sam_args['model_type']:
segment, _, _ = self.init_segment(
points_per_side=16,
origin_frame=origin_frame,
sam_args=sam_args,
predict_iou_thresh=0.8,
stability_score_thresh=0.9,
crop_n_layers=1,
crop_n_points_downscale_factor=2,
min_mask_region_area=200)
model_path = self.download_opencv_model(model_url='https://github.com/nipunmanral/Object-Detection-using-OpenCV/raw/master/model.yml.gz')
masks = segment.automatic_generate_mask(image)
p_max = self.compute_probability(masks)
edges = self.normalize_image(p_max)
edge_detection = cv2.ximgproc.createStructuredEdgeDetection(model_path)
orimap = edge_detection.computeOrientation(edges)
edges = edge_detection.edgesNms(edges, orimap)
edges = (edges * 255).astype('uint8')
edges = 255 - edges
edges = np.stack((edges,) * 3, axis=-1)
return edges
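

# Minimal usage sketch (illustrative only): driving the controller outside the Gradio
# UI. Assumes a local 'example.png' and that the vit_b checkpoint can be downloaded;
# the file name and model type below are placeholder choices, not part of the app.
if __name__ == '__main__':
    controller = SAMController()
    # A plain file object works here because read_temp_file only needs a .name attribute.
    with open('example.png', 'rb') as f:
        frame, origin_frame = controller.get_meta_from_image(f)
    # segment=None forces automatic_sam2sketch to build a fresh SegMent for 'vit_b'.
    sketch = controller.automatic_sam2sketch(segment=None,
                                             image=frame,
                                             origin_frame=origin_frame,
                                             model_type='vit_b')
    cv2.imwrite('sketch_result.png', sketch)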