import math

import cv2
import gradio as gr
import mediapipe as mp
import numpy as np
import torch
from diffusers import StableDiffusionInpaintPipeline
from mediapipe.tasks import python
from mediapipe.tasks.python import vision
from mediapipe.tasks.python.components import containers
from PIL import Image
from skimage.measure import label, regionprops

def _normalized_to_pixel_coordinates(
        normalized_x: float, normalized_y: float, image_width: int,
        image_height: int):
    """Converts a normalized coordinate pair to pixel coordinates."""

    # Checks if the float value is between 0 and 1.
    def is_valid_normalized_value(value: float) -> bool:
        return (value > 0 or math.isclose(0, value)) and (value < 1 or
                                                          math.isclose(1, value))

    if not (is_valid_normalized_value(normalized_x) and
            is_valid_normalized_value(normalized_y)):
        # TODO: Draw coordinates even if it's outside of the image bounds.
        return None
    x_px = min(math.floor(normalized_x * image_width), image_width - 1)
    y_px = min(math.floor(normalized_y * image_height), image_height - 1)
    return x_px, y_px
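
# Example (hypothetical values): a normalized click at (0.5, 0.25) on a
# 640x480 image maps to pixel coordinates (320, 120):
#   _normalized_to_pixel_coordinates(0.5, 0.25, 640, 480)  # -> (320, 120)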

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
# float16 weights are only supported on GPU; fall back to float32 on CPU.
pipe = StableDiffusionInpaintPipeline.from_pretrained(
    "stabilityai/stable-diffusion-2-inpainting",
    torch_dtype=torch.float16 if device.type == "cuda" else torch.float32,
).to(device)
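# Note: the checkpoint above is downloaded from the Hugging Face Hub on first
# run (several GB); later runs reuse the local cache.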

BG_COLOR = (192, 192, 192)  # gray
MASK_COLOR = (255, 255, 255)  # white

RegionOfInterest = vision.InteractiveSegmenterRegionOfInterest
NormalizedKeypoint = containers.keypoint.NormalizedKeypoint

# Create the options that will be used for InteractiveSegmenter.
base_options = python.BaseOptions(model_asset_path='model.tflite')
options = vision.ImageSegmenterOptions(base_options=base_options,
                                       output_category_mask=True)
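# 'model.tflite' must sit next to this script. MediaPipe's published
# interactive-segmenter model ("magic_touch") is one option; the download URL
# below is an assumption based on the MediaPipe model listing:
# https://storage.googleapis.com/mediapipe-models/interactive_segmenter/magic_touch/float32/1/magic_touch.tflite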

def get_bounding_box(mask):
    """Generate bounding box coordinates from a binary mask."""
    contours, _ = cv2.findContours(mask, cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_SIMPLE)
    if contours:
        # Use the largest contour so stray specks don't decide the box.
        x, y, w, h = cv2.boundingRect(max(contours, key=cv2.contourArea))
        return x, y, x + w, y + h
    # Fall back to the full image when the mask is empty.
    return 0, 0, mask.shape[1], mask.shape[0]
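
# Example (hypothetical mask): a filled 21x21 square with its top-left corner
# at (20, 20) yields the box (20, 20, 41, 41):
#   demo = np.zeros((100, 100), dtype=np.uint8)
#   demo[20:41, 20:41] = 255
#   get_bounding_box(demo)  # -> (20, 20, 41, 41)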

def example_segmentation_function(image_file_path, x, y):
    """Standalone example: segment the object under a normalized (x, y) click."""
    base_options = python.BaseOptions(model_asset_path='model.tflite')
    options = vision.ImageSegmenterOptions(base_options=base_options,
                                           output_category_mask=True)
    with python.vision.InteractiveSegmenter.create_from_options(options) as segmenter:
        image = mp.Image.create_from_file(image_file_path)
        roi = vision.InteractiveSegmenterRegionOfInterest(
            format=vision.InteractiveSegmenterRegionOfInterest.Format.KEYPOINT,
            keypoint=containers.keypoint.NormalizedKeypoint(x, y)
        )
        segmentation_result = segmenter.segment(image, roi)
        category_mask = segmentation_result.category_mask
        segmentation_mask = category_mask.numpy_view().astype(np.uint8)
        return segmentation_mask, image
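
# Example usage (assumes 'model.tflite' and a local 'photo.jpg' exist):
#   mask, mp_image = example_segmentation_function("photo.jpg", 0.5, 0.5)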

def segment(image_file_name, x, y):
    OVERLAY_COLOR = (255, 105, 180)  # rose
    # Create the segmenter.
    with python.vision.InteractiveSegmenter.create_from_options(options) as segmenter:
        # Create the MediaPipe image.
        image = mp.Image.create_from_file(image_file_name)
        # Retrieve the category mask for the clicked region.
        roi = RegionOfInterest(format=RegionOfInterest.Format.KEYPOINT,
                               keypoint=NormalizedKeypoint(x, y))
        segmentation_result = segmenter.segment(image, roi)
        category_mask = segmentation_result.category_mask
        mask = category_mask.numpy_view().astype(np.uint8)
        # Convert the BGR image to RGB.
        image_data = cv2.cvtColor(image.numpy_view(), cv2.COLOR_BGR2RGB)
        # Create an overlay image filled with the desired color.
        overlay_image = np.zeros(image_data.shape, dtype=np.uint8)
        overlay_image[:] = OVERLAY_COLOR
        # Build the blending condition from the category mask.
        alpha = np.stack((category_mask.numpy_view(),) * 3, axis=-1) <= 0.1
        # Turn the condition into an alpha channel with 50% opacity.
        alpha = alpha.astype(float) * 0.5
        # Blend the original image and the overlay using the alpha channel.
        output_image = image_data * (1 - alpha) + overlay_image * alpha
        output_image = output_image.astype(np.uint8)
        # Draw a white dot with a black border to mark the point of interest.
        thickness, radius = 6, -1
        keypoint_px = _normalized_to_pixel_coordinates(x, y, image.width, image.height)
        cv2.circle(output_image, keypoint_px, thickness + 5, (0, 0, 0), radius)
        cv2.circle(output_image, keypoint_px, thickness, (255, 255, 255), radius)
        # Convert the mask to binary if it is not already.
        binary_mask = (mask == 255).astype(np.uint8)
        # Label the connected regions in the mask.
        labels = label(binary_mask)
        # Obtain properties of the labeled regions.
        props = regionprops(labels)
        # Initialize bounding box coordinates.
        minr, minc, maxr, maxc = 0, 0, 0, 0
        if props:
            # Keep the bounding box of the largest region.
            minr, minc, maxr, maxc = max(props, key=lambda p: p.area).bbox
        # Expand the box by a generous margin so the inpainting has context.
        minr = max(0, minr - 300)
        minc = max(0, minc - 300)
        maxr = min(binary_mask.shape[0], maxr + 400)
        maxc = min(binary_mask.shape[1], maxc + 400)
        # Create a new black image and draw the bounding box in white.
        bbox_image = np.zeros_like(binary_mask)
        bbox_image[minr:maxr, minc:maxc] = 255
        return output_image, bbox_image

def generate(image_file_path, x, y, prompt):
    output_image, bbox_image = segment(image_file_path, x, y)
    # Check inputs before generating.
    if image_file_path is None or bbox_image is None:
        return None
    # Read the source image.
    img = Image.open(image_file_path).convert("RGB")
    # The diffusers inpainting pipeline accepts a PIL mask.
    mask = Image.fromarray(bbox_image)
    # Generate images from the image, mask, and prompt with a fixed seed.
    images = pipe(prompt=prompt,
                  image=img,
                  mask_image=mask,
                  generator=torch.Generator(device=device).manual_seed(0),
                  num_images_per_prompt=3).images

    # Paste the generated images into a single grid.
    def image_grid(imgs, rows, cols):
        assert len(imgs) == rows * cols
        w, h = imgs[0].size
        grid = Image.new('RGB', size=(cols * w, rows * h))
        for i, img in enumerate(imgs):
            grid.paste(img, box=(i % cols * w, i // cols * h))
        return grid

    grid_image = image_grid(images, 1, 3)
    return output_image, grid_image
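
# Note: the original call passed plms=True, which is not an argument of the
# inpainting pipeline's __call__. In diffusers the sampler is chosen by
# swapping the scheduler instead; a sketch, assuming the standard scheduler
# API (PNDMScheduler implements PLMS-style sampling):
#   from diffusers import PNDMScheduler
#   pipe.scheduler = PNDMScheduler.from_config(pipe.scheduler.config)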

webapp = gr.Interface(fn=generate,
                      inputs=[
                          gr.Image(type="filepath", label="Upload an image"),
                          gr.Slider(minimum=0, maximum=1, step=0.01, label="x"),
                          gr.Slider(minimum=0, maximum=1, step=0.01, label="y"),
                          gr.Textbox(label="Prompt")],
                      outputs=[
                          gr.Image(type="pil", label="Segmented Image"),
                          gr.Image(type="pil", label="Generated Image Grid")])
webapp.launch(debug=True)