Spaces:
Running
on
A10G
Running
on
A10G
File size: 4,525 Bytes
3d4d894 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 |
"""Preprocessing methods"""
import logging
from typing import List, Tuple
import numpy as np
from PIL import Image, ImageFilter
import streamlit as st
from config import COLOR_RGB, WIDTH, HEIGHT
# from enhance_config import ENHANCE_SETTINGS
# Module logger, named after this module via __name__.
LOGGING = logging.getLogger(__name__)
def preprocess_seg_mask(canvas_seg, real_seg: "Image.Image | None" = None) -> Tuple[np.ndarray, np.ndarray]:
    """Preprocess the segmentation mask drawn on the canvas.

    Args:
        canvas_seg: drawing-canvas result object; its ``image_data`` array is
            copied and only its first three channels are used.
        real_seg (Image.Image, optional): segmentation image onto which the
            drawn palette colors are painted. Defaults to None.

    Returns:
        Tuple[np.ndarray, np.ndarray]: the drawn-area mask, and the
        segmentation image. The mask is 1 where the mean of the RGB channels
        is positive; it is an integer array when anything was drawn and an
        all-False boolean array otherwise. The image is the canvas RGB data,
        or a copy of ``real_seg`` with the drawn palette colors painted over
        it when ``real_seg`` is given.
    """
    image_seg = canvas_seg.image_data.copy()[:, :, :3]

    # A pixel counts as drawn when the mean of its RGB channels is positive.
    mask = np.mean(image_seg, axis=2) > 0
    if mask.sum() > 0:
        # Kept for backward compatibility: int array when something is drawn,
        # bool array when nothing is.
        mask = mask * 1

    # Count every distinct color in one pass instead of re-scanning the whole
    # image once per color.
    colors, counts = np.unique(
        image_seg.reshape(-1, image_seg.shape[-1]), axis=0, return_counts=True)
    # Keep colors covering more than 100 pixels (presumably to drop brush
    # anti-aliasing fringes — TODO confirm) that are exact palette colors.
    # The count test runs first, so COLOR_RGB is only consulted for large regions.
    unique_colors_exact = [
        tuple(color) for color, count in zip(colors, counts)
        if count > 100 and tuple(color) in COLOR_RGB
    ]

    if real_seg is not None:
        # np.array copies, so the caller's real_seg is never modified.
        # NOTE(review): assumes real_seg has the canvas's height/width and 3
        # channels; an RGBA input would reject the 3-tuple assignment below.
        overlay_seg = np.array(real_seg)
        for color in unique_colors_exact:
            # White and black are never painted onto the overlay.
            if color not in ((255, 255, 255), (0, 0, 0)):
                overlay_seg[np.all(image_seg == color, axis=-1)] = color
        image_seg = overlay_seg

    return mask, image_seg
def get_mask(image_mask: np.ndarray) -> np.ndarray:
    """Compute the drawn-area mask of a segmentation image.

    A pixel is selected when the mean of ``image_mask`` over axis 2 (the
    channel axis, assuming an (H, W, C) image) is strictly positive.

    Args:
        image_mask (np.ndarray): segmentation image.

    Returns:
        np.ndarray: an integer 0/1 mask when at least one pixel is selected;
        when none is, the all-False boolean array is returned as-is.
    """
    selected = np.mean(image_mask, axis=2) > 0
    if not selected.any():
        return selected
    return selected * 1
def get_image() -> np.ndarray:
    """Return the session's initial image resized to WIDTH x HEIGHT pixels.

    Reads ``st.session_state['initial_image']``, which may be a PIL image or
    anything ``Image.fromarray`` accepts.

    Returns:
        np.ndarray: the resized image as an array, or None when the key is
        missing or holds None.
    """
    state = st.session_state
    stored = state['initial_image'] if 'initial_image' in state else None
    if stored is None:
        return None

    pil_image = stored if isinstance(stored, Image.Image) else Image.fromarray(stored)
    return np.array(pil_image.resize((WIDTH, HEIGHT)))
def make_enhance_config(segmentation, objects=None):
    """Make the enhance config (diffusion conditioning) for the segmentation image.

    Looks up the preset ``ENHANCE_SETTINGS[objects]`` and uses its keys:

    * ``replace`` (optional): recolors pixels of the preset ``colors`` with
      this color in a copy of the segmentation;
    * ``colors`` / ``inverse``: builds a mask with the segmentation's shape
      that is 1 on pixels matching a preset color when ``inverse`` is the
      literal ``False``, and 0 on them for any other ``inverse`` value;
    * ``positive_prompt`` / ``negative_prompt``: written to
      ``st.session_state`` and passed through in the conditioning;
    * ``inpainting``: when the literal ``True``, the mask is blurred and
      dilated into a uint8 inpainting mask; otherwise a ControlNet
      conditioning with the (recolored) segmentation and ``strength`` is built.

    NOTE(review): the ``ENHANCE_SETTINGS`` import is commented out at the top
    of this module, so calling this function raises NameError until that
    import is restored.

    Args:
        segmentation: segmentation image, anything ``np.array`` accepts. The
            3-element per-pixel mask assignments assume (H, W, 3) RGB data —
            TODO confirm. The caller's object is never modified.
        objects: key selecting the preset in ``ENHANCE_SETTINGS``. The default
            ``None`` only works if ``None`` is itself a preset key.

    Returns:
        tuple: ``(conditioning, inpainting)``, a dict of pipeline keyword
        arguments and the preset's ``inpainting`` value.
    """
    info = ENHANCE_SETTINGS[objects]

    # np.array copies, so the recoloring below never touches the caller's data.
    segmentation = np.array(segmentation)

    if 'replace' in info:
        replace_color = info['replace']
        for color in info['colors']:
            segmentation[np.all(segmentation == color, axis=-1)] = replace_color

    # NOTE(review): the mask is built from the *recolored* segmentation, so the
    # replaced regions only match when ``replace`` is itself one of the preset
    # colors. Confirm whether the mask should instead be computed before
    # recoloring.
    if info['inverse'] is False:
        mask = np.zeros(segmentation.shape)
        for color in info['colors']:
            mask[np.all(segmentation == color, axis=-1)] = [1, 1, 1]
    else:
        # Any value other than the literal False selects the inverted mask.
        mask = np.ones(segmentation.shape)
        for color in info['colors']:
            mask[np.all(segmentation == color, axis=-1)] = [0, 0, 0]

    st.session_state['positive_prompt'] = info['positive_prompt']
    st.session_state['negative_prompt'] = info['negative_prompt']

    if info['inpainting'] is True:
        # Gaussian blur (radius 13), then a 9x9 max filter that dilates the
        # masked area.
        mask = Image.fromarray(mask.astype(np.uint8))
        mask = mask.filter(ImageFilter.GaussianBlur(radius=13))
        mask = mask.filter(ImageFilter.MaxFilter(size=9))
        mask = np.array(mask)

        # NOTE(review): the mask holds only 0/1 uint8 values at this point, so
        # this 0.1 threshold is a no-op; scaling the mask by 255 before
        # blurring may have been intended — confirm.
        mask[mask < 0.1] = 0
        mask[mask >= 0.1] = 1
        mask = mask.astype(np.uint8)

        conditioning = dict(
            mask_image=mask,
            positive_prompt=info['positive_prompt'],
            negative_prompt=info['negative_prompt'],
        )
    else:
        conditioning = dict(
            mask_image=mask,
            controlnet_conditioning_image=segmentation,
            positive_prompt=info['positive_prompt'],
            negative_prompt=info['negative_prompt'],
            strength=info['strength'],
        )
    return conditioning, info['inpainting']