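# Provenance (file-viewer header captured with this source):
#   author: ttengwang
#   commit message: clean up code, add langchain for chatbox
#   commit: 9a84ec8
#   viewer links: raw / history / blame
#   size: 17.6 kB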
import os
import cv2
import requests
import numpy as np
from PIL import Image
import time
import sys
import urllib
from tqdm import tqdm
import hashlib
def is_platform_win():
    """Tell whether the interpreter is running on Windows.

    Returns:
        bool: True when ``sys.platform`` equals ``"win32"``, otherwise False.
    """
    running_on_windows = sys.platform == "win32"
    return running_on_windows
def colormap(rgb=True):
    """Build the fixed 81-entry color palette used by the painters.

    Args:
        rgb: When truthy the channels are returned in R, G, B order; when
            falsy the channel axis is reversed (B, G, R).

    Returns:
        A float32 array of shape (81, 3) whose values are scaled to 0..255.
        Entry 0 is black and entry 1 is white.
    """
    # Each row is one color, given as fractions of full intensity.
    unit_palette = (
        (0.000, 0.000, 0.000),
        (1.000, 1.000, 1.000),
        (1.000, 0.498, 0.313),
        (0.392, 0.581, 0.929),
        (0.000, 0.447, 0.741),
        (0.850, 0.325, 0.098),
        (0.929, 0.694, 0.125),
        (0.494, 0.184, 0.556),
        (0.466, 0.674, 0.188),
        (0.301, 0.745, 0.933),
        (0.635, 0.078, 0.184),
        (0.300, 0.300, 0.300),
        (0.600, 0.600, 0.600),
        (1.000, 0.000, 0.000),
        (1.000, 0.500, 0.000),
        (0.749, 0.749, 0.000),
        (0.000, 1.000, 0.000),
        (0.000, 0.000, 1.000),
        (0.667, 0.000, 1.000),
        (0.333, 0.333, 0.000),
        (0.333, 0.667, 0.000),
        (0.333, 1.000, 0.000),
        (0.667, 0.333, 0.000),
        (0.667, 0.667, 0.000),
        (0.667, 1.000, 0.000),
        (1.000, 0.333, 0.000),
        (1.000, 0.667, 0.000),
        (1.000, 1.000, 0.000),
        (0.000, 0.333, 0.500),
        (0.000, 0.667, 0.500),
        (0.000, 1.000, 0.500),
        (0.333, 0.000, 0.500),
        (0.333, 0.333, 0.500),
        (0.333, 0.667, 0.500),
        (0.333, 1.000, 0.500),
        (0.667, 0.000, 0.500),
        (0.667, 0.333, 0.500),
        (0.667, 0.667, 0.500),
        (0.667, 1.000, 0.500),
        (1.000, 0.000, 0.500),
        (1.000, 0.333, 0.500),
        (1.000, 0.667, 0.500),
        (1.000, 1.000, 0.500),
        (0.000, 0.333, 1.000),
        (0.000, 0.667, 1.000),
        (0.000, 1.000, 1.000),
        (0.333, 0.000, 1.000),
        (0.333, 0.333, 1.000),
        (0.333, 0.667, 1.000),
        (0.333, 1.000, 1.000),
        (0.667, 0.000, 1.000),
        (0.667, 0.333, 1.000),
        (0.667, 0.667, 1.000),
        (0.667, 1.000, 1.000),
        (1.000, 0.000, 1.000),
        (1.000, 0.333, 1.000),
        (1.000, 0.667, 1.000),
        (0.167, 0.000, 0.000),
        (0.333, 0.000, 0.000),
        (0.500, 0.000, 0.000),
        (0.667, 0.000, 0.000),
        (0.833, 0.000, 0.000),
        (1.000, 0.000, 0.000),
        (0.000, 0.167, 0.000),
        (0.000, 0.333, 0.000),
        (0.000, 0.500, 0.000),
        (0.000, 0.667, 0.000),
        (0.000, 0.833, 0.000),
        (0.000, 1.000, 0.000),
        (0.000, 0.000, 0.167),
        (0.000, 0.000, 0.333),
        (0.000, 0.000, 0.500),
        (0.000, 0.000, 0.667),
        (0.000, 0.000, 0.833),
        (0.000, 0.000, 1.000),
        (0.143, 0.143, 0.143),
        (0.286, 0.286, 0.286),
        (0.429, 0.429, 0.429),
        (0.571, 0.571, 0.571),
        (0.714, 0.714, 0.714),
        (0.857, 0.857, 0.857),
    )
    # Convert to float32 first, then scale the unit fractions up to 0..255.
    scaled = np.array(unit_palette, dtype=np.float32) * 255
    # Reversing the last axis turns RGB rows into BGR rows.
    return scaled if rgb else scaled[:, ::-1]
# Module-level palette shared by the painters: the colormap() values cast to
# uint8 and converted to nested Python lists, so color_list[i] is one
# three-element color.
color_list = colormap().astype('uint8').tolist()
def vis_add_mask(image, mask, color, alpha, kernel_size):
    """Blend a solid color into ``image`` through a Gaussian-blurred mask.

    Pixels where ``mask`` is 255 keep their original value, pixels where it is
    0 are mixed with ``color`` at opacity ``alpha``; the blur produces a smooth
    transition between the two.

    Args:
        image: Array indexed as [row, col, channel]; its first three channels
            are overwritten IN PLACE.
        mask: 2-D array with values in 0..255 (it is divided by 255 after
            blurring).
        color: Sequence with at least three channel values.
        alpha: Opacity of the color where the mask is 0.
        kernel_size: Side of the square Gaussian kernel; the same value is
            also passed as the blur sigma.

    Returns:
        The same ``image`` object that was passed in.
    """
    rgb = np.array(color)
    # astype() already returns a fresh array, so the caller's mask is untouched.
    soft_mask = mask.astype('float')
    soft_mask = cv2.GaussianBlur(soft_mask, (kernel_size, kernel_size), kernel_size)
    soft_mask = (soft_mask / 255.) * (alpha)
    keep_weight = 1 - alpha + soft_mask   # share of the original pixel
    tint_weight = alpha - soft_mask       # share of the overlay color
    for channel in range(3):
        image[:, :, channel] = image[:, :, channel] * keep_weight + rgb[channel] * tint_weight
    return image
def vis_add_mask_wo_blur(image, mask, color, alpha):
    """Blend a solid color into ``image`` using ``mask`` directly (no blur).

    Each of the first three channels becomes
    ``pixel * (1 - alpha + mask) + color * (alpha - mask)``.

    Args:
        image: Array indexed as [row, col, channel]; modified IN PLACE.
        mask: Array broadcastable against one channel of ``image``. It is used
            as-is (no /255 scaling, no multiplication by ``alpha``), so
            presumably callers pass it already scaled to [0, alpha] -- TODO
            confirm against callers.
        color: Sequence with at least three channel values.
        alpha: Blend strength.

    Returns:
        The same ``image`` object that was passed in.
    """
    rgb = np.array(color)
    weights = mask.astype('float')
    keep_weight = 1 - alpha + weights   # share of the original pixel
    tint_weight = alpha - weights       # share of the overlay color
    for channel in range(3):
        image[:, :, channel] = image[:, :, channel] * keep_weight + rgb[channel] * tint_weight
    return image
def vis_add_mask_wo_gaussian(image, background_mask, contour_mask, background_color, contour_color, background_alpha, contour_alpha):
    """Paint a background tint and then a contour tint onto ``image``.

    For both layers a mask value of 1 leaves the pixel untouched and a mask
    value of 0 mixes in the layer color at the layer's alpha; values in
    between interpolate.

    Args:
        image: Array indexed as [row, col, channel]; its first three channels
            are overwritten IN PLACE (intermediate results are cast to the
            array's own dtype).
        background_mask: Per-pixel weights for the background layer.
        contour_mask: Per-pixel weights for the contour layer.
        background_color: Sequence with at least three channel values.
        contour_color: Sequence with at least three channel values.
        background_alpha: Opacity of the background color.
        contour_alpha: Opacity of the contour color.

    Returns:
        A uint8 copy of the painted image.
    """
    bg_rgb = np.array(background_color)
    edge_rgb = np.array(contour_color)
    # Weights do not depend on the channel, so compute them once.
    bg_keep = 1 - background_alpha + background_mask * background_alpha
    bg_tint = background_alpha - background_mask * background_alpha
    edge_keep = 1 - contour_alpha + contour_mask * contour_alpha
    edge_tint = contour_alpha - contour_mask * contour_alpha
    for channel in range(3):
        # The background goes first and is stored back, so the contour layer
        # is blended on top of the already-tinted pixel.
        image[:, :, channel] = image[:, :, channel] * bg_keep + bg_rgb[channel] * bg_tint
        image[:, :, channel] = image[:, :, channel] * edge_keep + edge_rgb[channel] * edge_tint
    return image.astype('uint8')
def mask_painter(input_image, input_mask, background_alpha=0.7, background_blur_radius=7, contour_width=3, contour_color=3, contour_alpha=1, background_color=0, paint_foreground=False):
    """
    Tint the background (or the foreground) of an image and outline the mask.

    The caller's arrays are left untouched: the mask is binarized into a new
    uint8 array and the painting happens on a copy of the image.

    input_image: numpy array indexed as [row, col, channel]
    input_mask: numpy array whose shape equals input_image.shape[:2]; every value > 0 counts as foreground (any dtype)
    background_alpha: opacity of the tint, [0, 1], 1: fully covered by the tint color, 0: do nothing
    background_blur_radius: Gaussian kernel size used to soften the tint edge, must be an odd number
    contour_width: width of the mask contour, must be an odd number
    contour_color: color index (in color_list) of the mask contour, 0: black, 1: white, >1: others
    contour_alpha: opacity of the mask contour, [0, 1], if 0: no contour highlighted
    background_color: color index (in color_list) of the tint
    paint_foreground: True to tint the foreground, False to tint the background. Default: False

    Output:
    painted_image: numpy array with the same shape and dtype as input_image

    Raises:
    AssertionError: if the shapes differ or one of the two sizes is not odd
    """
    assert input_image.shape[:2] == input_mask.shape, 'different shape'
    assert background_blur_radius % 2 == 1 and contour_width % 2 == 1, \
        'background_blur_radius and contour_width must be ODD'

    # Binarize into a fresh uint8 array (0: background, 255: foreground).
    # This keeps the caller's mask unmodified and lets bool/int/float masks
    # work with the `255 - mask` arithmetic and cv2.Canny below.
    binary_mask = np.where(input_mask > 0, 255, 0).astype(np.uint8)

    # vis_add_mask writes into the array it is given, so paint on a copy.
    canvas = input_image.copy()

    # vis_add_mask tints the pixels where its mask is 0; invert the mask to
    # tint the foreground instead of the background.
    tint_mask = 255 - binary_mask if paint_foreground else binary_mask
    painted_image = vis_add_mask(canvas, tint_mask, color_list[background_color], background_alpha, background_blur_radius)

    # Contour: detect the mask edge, widen it to contour_width pixels, then
    # paint it (inverted, because the painted region is where the mask is 0).
    contour_mask = cv2.Canny(binary_mask, 100, 200)
    kernel = cv2.getStructuringElement(cv2.MORPH_RECT, (contour_width, contour_width))
    contour_mask = cv2.dilate(contour_mask, kernel)
    painted_image = vis_add_mask(painted_image, 255 - contour_mask, color_list[contour_color], contour_alpha, contour_width)
    return painted_image
def mask_painter_foreground_all(input_image, input_masks, background_alpha=0.7, background_blur_radius=7, contour_width=3, contour_color=3, contour_alpha=1):
    """
    Paint every mask's foreground area with its own palette color.

    Masks are painted one after another with mask_painter(paint_foreground=True).
    Colors are taken from color_list starting at index 2 (indices 0 and 1 are
    black and white) and wrap around when there are more masks than remaining
    palette entries, so any number of masks can be painted.

    input_image: numpy array indexed as [row, col, channel]
    input_masks: iterable of masks, each a numpy array whose shape equals input_image.shape[:2]
    background_alpha: opacity of the color painted on each mask, [0, 1], 0: do nothing
    background_blur_radius: Gaussian kernel size used to soften the painted area, must be an odd number
    contour_width: width of the mask contour, must be an odd number
    contour_color: color index (in color_list) of the mask contour, 0: black, 1: white, >1: others
    contour_alpha: opacity of the mask contour, [0, 1], if 0: no contour highlighted

    Output:
    painted_image: numpy array (input_image itself when input_masks is empty)
    """
    first_color = 2  # skip black and white
    painted_image = input_image
    for i, input_mask in enumerate(input_masks):
        # Cycle through the non-black/white palette entries instead of
        # running off the end of color_list.
        color_index = first_color + i % (len(color_list) - first_color)
        painted_image = mask_painter(
            painted_image,
            input_mask,
            background_alpha,
            background_blur_radius,
            contour_width,
            contour_color,
            contour_alpha,
            background_color=color_index,
            paint_foreground=True,
        )
    return painted_image
def mask_generator_00(mask, background_radius, contour_radius):
    """Mode '00': hard background mask and hard contour mask (no blur).

    Args:
        mask: Binary foreground mask; the caller in this file passes a 0/1
            array (assumed to be uint8, as cv2.distanceTransform is applied
            to it directly -- TODO confirm).
        background_radius: Unused in this mode; kept so all generators share
            one signature.
        contour_radius: Half-width of the contour band (2 is added to it).

    Returns:
        (mask, contour_mask): the input mask unchanged, and a float map that
        is below 1 only near the mask boundary.
    """
    # Signed distance to the mask boundary: positive inside, negative outside.
    inside = cv2.distanceTransform(mask, cv2.DIST_L2, 3)
    outside = cv2.distanceTransform(1-mask, cv2.DIST_L2, 3)
    signed_distance = inside - outside
    # ...:::!!!:::...
    band = contour_radius + 2
    contour_mask = np.abs(np.clip(signed_distance, -band, band))
    contour_mask = contour_mask / np.max(contour_mask)
    # Snap everything past half the band to 1 so the contour edge is hard.
    contour_mask[contour_mask>0.5] = 1.
    return mask, contour_mask
def mask_generator_01(mask, background_radius, contour_radius):
    """Mode '01': hard background mask, soft (ramped) contour mask.

    Args:
        mask: Binary foreground mask; the caller in this file passes a 0/1
            array (assumed to be uint8, as cv2.distanceTransform is applied
            to it directly -- TODO confirm).
        background_radius: Unused in this mode; kept so all generators share
            one signature.
        contour_radius: Half-width of the contour band (2 is added to it).

    Returns:
        (mask, contour_mask): the input mask unchanged, and a float map that
        ramps up to 1 with distance from the mask boundary.
    """
    # Signed distance to the mask boundary: positive inside, negative outside.
    inside = cv2.distanceTransform(mask, cv2.DIST_L2, 3)
    outside = cv2.distanceTransform(1-mask, cv2.DIST_L2, 3)
    signed_distance = inside - outside
    # ...:::!!!:::...
    band = contour_radius + 2
    contour_mask = np.abs(np.clip(signed_distance, -band, band))
    # Normalized but not thresholded, which keeps the contour edge soft.
    contour_mask = contour_mask / np.max(contour_mask)
    return mask, contour_mask
def mask_generator_10(mask, background_radius, contour_radius):
    """Mode '10': soft (ramped) background mask, hard contour mask.

    Args:
        mask: Binary foreground mask; the caller in this file passes a 0/1
            array (assumed to be uint8, as cv2.distanceTransform is applied
            to it directly -- TODO confirm).
        background_radius: Half-width of the background transition band.
        contour_radius: Half-width of the contour band (2 is added to it).

    Returns:
        (background_mask, contour_mask): two float maps. The background mask
        is rescaled to span 0..1 across the mask boundary; when no ramp
        exists it is the plain binary mask. The contour mask is below 1 only
        near the boundary.
    """
    # Signed distance to the mask boundary: positive inside, negative outside.
    dist_transform_fore = cv2.distanceTransform(mask, cv2.DIST_L2, 3)
    dist_transform_back = cv2.distanceTransform(1-mask, cv2.DIST_L2, 3)
    dist_map = dist_transform_fore - dist_transform_back
    # .....:::::!!!!!
    background_mask = np.clip(dist_map, -background_radius, background_radius)
    background_mask = background_mask - np.min(background_mask)
    span = np.max(background_mask)
    if span > 0:
        background_mask = background_mask / span
    else:
        # The clipped map is constant (radius 0, or the mask is entirely one
        # class), so min-max scaling would be 0/0 and fill the result with
        # NaN. There is no transition to blur: use the binary mask itself.
        background_mask = mask.astype(dist_map.dtype)
    # ...:::!!!:::...
    contour_radius += 2
    contour_mask = np.abs(np.clip(dist_map, -contour_radius, contour_radius))
    contour_mask = contour_mask / np.max(contour_mask)
    # Snap everything past half the band to 1 so the contour edge is hard.
    contour_mask[contour_mask>0.5] = 1.
    return background_mask, contour_mask
def mask_generator_11(mask, background_radius, contour_radius):
    """Mode '11': soft (ramped) background mask and soft contour mask.

    Args:
        mask: Binary foreground mask; the caller in this file passes a 0/1
            array (assumed to be uint8, as cv2.distanceTransform is applied
            to it directly -- TODO confirm).
        background_radius: Half-width of the background transition band.
        contour_radius: Half-width of the contour band (2 is added to it).

    Returns:
        (background_mask, contour_mask): two float maps. The background mask
        is rescaled to span 0..1 across the mask boundary; when no ramp
        exists it is the plain binary mask. The contour mask ramps up to 1
        with distance from the boundary.
    """
    # Signed distance to the mask boundary: positive inside, negative outside.
    dist_transform_fore = cv2.distanceTransform(mask, cv2.DIST_L2, 3)
    dist_transform_back = cv2.distanceTransform(1-mask, cv2.DIST_L2, 3)
    dist_map = dist_transform_fore - dist_transform_back
    # .....:::::!!!!!
    background_mask = np.clip(dist_map, -background_radius, background_radius)
    background_mask = background_mask - np.min(background_mask)
    span = np.max(background_mask)
    if span > 0:
        background_mask = background_mask / span
    else:
        # The clipped map is constant (radius 0, or the mask is entirely one
        # class), so min-max scaling would be 0/0 and fill the result with
        # NaN. There is no transition to blur: use the binary mask itself.
        background_mask = mask.astype(dist_map.dtype)
    # ...:::!!!:::...
    contour_radius += 2
    contour_mask = np.abs(np.clip(dist_map, -contour_radius, contour_radius))
    # Normalized but not thresholded, which keeps the contour edge soft.
    contour_mask = contour_mask / np.max(contour_mask)
    return background_mask, contour_mask
def mask_painter_wo_gaussian(input_image, input_mask, background_alpha=0.5, background_blur_radius=7, contour_width=3, contour_color=3, contour_alpha=1, mode='11'):
    """
    Darken the background and outline the mask using distance-transform masks
    instead of Gaussian blurs.

    Input:
    input_image: numpy array indexed as [row, col, channel]
    input_mask: numpy array whose shape equals input_image.shape[:2]; values >= 1 count as foreground
    background_alpha: opacity of the black background tint, [0, 1], 1: all black, 0: do nothing
    background_blur_radius: width of the background transition, must be an odd number
    contour_width: width of the mask contour, must be an odd number
    contour_color: color index (in color_list) of the mask contour, 0: black, 1: white, >1: others
    contour_alpha: opacity of the mask contour, [0, 1], if 0: no contour highlighted
    mode: painting mode, '00' no blur, '01' only blur contour, '10' only blur background, '11' blur both

    Output:
    painted_image: uint8 numpy array; when the longer side of the input
    exceeds 1024 pixels the result stays at the downsampled size
    """
    assert input_image.shape[:2] == input_mask.shape, 'different shape'
    # `%` and `*` share precedence and associate left to right, so this reads
    # ((blur % 2) * contour_width) % 2, which is non-zero only if both are odd.
    assert background_blur_radius % 2 * contour_width % 2 > 0, 'background_blur_radius and contour_width must be ODD'
    assert mode in ['00', '01', '10', '11'], 'mode should be 00, 01, 10, or 11'

    # Shrink (never enlarge) so that the longer side is at most `max_side`.
    rows, cols = input_image.shape[0], input_image.shape[1]
    max_side = 1024
    scale = min(1.0 * max_side / max(rows, cols), 1.0)
    target_size = (int(cols*scale), int(rows*scale))  # cv2.resize takes (width, height)
    input_image = cv2.resize(input_image, target_size)
    input_mask = cv2.resize(input_mask, target_size)

    # 0: background, 1: foreground
    binary_mask = np.clip(input_mask, 0, 1)

    # Turn the odd "diameter" style settings into radii for the generators.
    background_radius = (background_blur_radius - 1) // 2
    contour_radius = (contour_width - 1) // 2

    # Pick the mask generator that matches the requested blur combination.
    if mode == '00':
        generate_masks = mask_generator_00
    elif mode == '01':
        generate_masks = mask_generator_01
    elif mode == '10':
        generate_masks = mask_generator_10
    else:
        generate_masks = mask_generator_11
    background_mask, contour_mask = generate_masks(binary_mask, background_radius, contour_radius)

    # Paint: the background is always tinted with palette entry 0 (black).
    return vis_add_mask_wo_gaussian(
        input_image,
        background_mask,
        contour_mask,
        color_list[0],
        color_list[contour_color],
        background_alpha,
        contour_alpha,
    )
if __name__ == '__main__':
    # Benchmark: average runtime of the Gaussian painter versus the four
    # distance-transform modes, then save one output image per mode.
    background_alpha = 0.7  # transparency of background 1: all black, 0: do nothing
    background_blur_radius = 31  # radius of background blur, must be odd number
    contour_width = 11  # contour width, must be odd number
    contour_color = 3  # id in color map, 0: black, 1: white, >1: others
    contour_alpha = 1  # transparency of contour, 0: no contour highlighted

    # load input image and mask
    input_image = np.array(Image.open('./test_images/painter_input_image.jpg').convert('RGB'))
    input_mask = np.array(Image.open('./test_images/painter_input_mask.jpg').convert('P'))

    num_runs = 50
    painter_args = (background_alpha, background_blur_radius, contour_width, contour_color, contour_alpha)
    modes = ('00', '10', '01', '11')

    gaussian_total = 0.0
    mode_totals = {mode: 0.0 for mode in modes}
    painted_by_mode = {}

    # paint
    for _ in range(num_runs):
        for mode in modes:
            start = time.perf_counter()
            painted_by_mode[mode] = mask_painter_wo_gaussian(input_image, input_mask, *painter_args, mode=mode)
            mode_totals[mode] += time.perf_counter() - start

        # mask_painter may modify the arrays it receives in place. Give it
        # fresh copies (made outside the timed region) so repeated runs do not
        # keep repainting the same image and skew the other measurements and
        # the saved outputs.
        image_copy = input_image.copy()
        mask_copy = input_mask.copy()
        start = time.perf_counter()
        painted_image = mask_painter(image_copy, mask_copy, *painter_args)
        gaussian_total += time.perf_counter() - start

    print(f'average time w gaussian: {gaussian_total/num_runs}')
    print(f'average time w/o gaussian00: {mode_totals["00"]/num_runs}')
    print(f'average time w/o gaussian10: {mode_totals["10"]/num_runs}')
    print(f'average time w/o gaussian01: {mode_totals["01"]/num_runs}')
    print(f'average time w/o gaussian11: {mode_totals["11"]/num_runs}')

    # save
    for mode in modes:
        Image.fromarray(painted_by_mode[mode]).save(f'./test_images/painter_output_image_{mode}.png')
# User-facing model size -> SAM backbone identifier.
seg_model_map = dict(
    base='vit_b',
    large='vit_l',
    huge='vit_h',
)

# SAM backbone identifier -> checkpoint download URL.
ckpt_url_map = dict(
    vit_b='https://dl.fbaipublicfiles.com/segment_anything/sam_vit_b_01ec64.pth',
    vit_l='https://dl.fbaipublicfiles.com/segment_anything/sam_vit_l_0b3195.pth',
    vit_h='https://dl.fbaipublicfiles.com/segment_anything/sam_vit_h_4b8939.pth',
)

# SAM backbone identifier -> SHA-256 hex digest the downloaded checkpoint
# file is verified against.
expected_sha256_map = dict(
    vit_b='ec2df62732614e57411cdcf32a23ffdf28910380d03139ee0f4fcbe91eb8c912',
    vit_l='3adcc4315b642a4d2101128f611684e8734c41232a17c648ed1693702a49a622',
    vit_h='a7bf3b02f3ebf1267aba913ff637d9a2d5c33d3173bb679e46d9f338c26f262e',
)
def prepare_segmenter(segmenter = "huge", download_root: str = None):
    """
    Resolve a segmenter size to its SAM model and make sure the checkpoint exists locally.

    As a side effect a 'result' directory is created in the current working
    directory.

    segmenter: key of seg_model_map ('base', 'large' or 'huge'); an unknown key raises KeyError
    download_root: folder for the checkpoint; when empty or None, ~/.cache/SAM is used

    Returns: (seg_model_name, segmenter_checkpoint) -- the model name
    ('vit_b', 'vit_l' or 'vit_h') and the path of the local checkpoint file.
    """
    os.makedirs('result', exist_ok=True)

    seg_model_name = seg_model_map[segmenter]
    url = ckpt_url_map[seg_model_name]
    cache_dir = download_root if download_root else os.path.expanduser("~/.cache/SAM")

    # The checkpoint keeps the file name it has in the download URL.
    checkpoint_path = download_checkpoint(
        url,
        cache_dir,
        os.path.basename(url),
        expected_sha256_map[seg_model_name],
    )
    return seg_model_name, checkpoint_path
def _sha256_of_file(path, chunk_size=1 << 20):
    """Return the hex SHA-256 digest of the file at ``path``, read in chunks."""
    digest = hashlib.sha256()
    with open(path, "rb") as stream:
        # Hash incrementally so a multi-gigabyte checkpoint is never held in
        # memory as a whole; the `with` block closes the handle.
        for chunk in iter(lambda: stream.read(chunk_size), b""):
            digest.update(chunk)
    return digest.hexdigest()


def download_checkpoint(url, folder, filename, expected_sha256):
    """
    Download ``url`` to ``folder/filename`` unless a verified copy already exists.

    url: address of the checkpoint file
    folder: destination directory (created when missing)
    filename: name of the file inside ``folder``
    expected_sha256: hex SHA-256 digest the file must have

    Returns: path of the local checkpoint file.

    Raises:
    RuntimeError: if the server answers with an HTTP error status, or if the
    downloaded file does not match ``expected_sha256``.
    """
    os.makedirs(folder, exist_ok=True)
    download_target = os.path.join(folder, filename)

    # Reuse an existing file only when its checksum is correct; anything else
    # (missing or corrupt/partial file) falls through to a fresh download.
    if os.path.isfile(download_target) and _sha256_of_file(download_target) == expected_sha256:
        return download_target

    print(f'Download SAM checkpoint {url}, saving to {download_target} ...')
    with requests.get(url, stream=True) as response:
        # Check the status before opening the target, so an error page is
        # neither written to disk nor allowed to truncate an existing file.
        try:
            response.raise_for_status()
        except requests.HTTPError as err:
            raise RuntimeError(f'Failed to download SAM checkpoint from {url}: {err}') from err
        total_bytes = int(response.headers.get('content-length', 0))
        with open(download_target, "wb") as output, \
                tqdm(total=total_bytes, unit='B', unit_scale=True) as progress:
            for data in response.iter_content(chunk_size=1 << 20):
                progress.update(output.write(data))

    if _sha256_of_file(download_target) != expected_sha256:
        raise RuntimeError("Model has been downloaded but the SHA256 checksum does not match")
    return download_target