SkalskiP's picture
debug
b83cf94
raw
history blame
7.3 kB
import random
from typing import Tuple
import gradio as gr
import numpy as np
import spaces
import torch
from PIL import Image, ImageFilter
from diffusers import FluxInpaintPipeline
from gradio_client import Client, handle_file
MARKDOWN = """
# FLUX.1 Inpainting 🔥
Shoutout to [Black Forest Labs](https://huggingface.co/black-forest-labs) team for
creating this amazing model, and a big thanks to [Gothos](https://github.com/Gothos)
for taking it to the next level by enabling inpainting with the FLUX.
"""

# Upper bound for the seed slider and for random seed sampling in ``process``.
MAX_SEED = np.iinfo(np.int32).max
# Target length of the longer image side (default of ``resize_image_dimensions``).
IMAGE_SIZE = 1024
DEVICE = "cuda" if torch.cuda.is_available() else "cpu"

# The florence-sam-masking client is created per session by
# ``set_client_for_session``, so no module-level client or token lives here.
pipe = FluxInpaintPipeline.from_pretrained(
    "black-forest-labs/FLUX.1-schnell", torch_dtype=torch.bfloat16).to(DEVICE)
def resize_image_dimensions(
    original_resolution_wh: Tuple[int, int],
    maximum_dimension: int = IMAGE_SIZE
) -> Tuple[int, int]:
    """Scale ``(width, height)`` so the longer side becomes ``maximum_dimension``.

    The aspect ratio is preserved up to rounding, and both sides are then
    rounded down to a multiple of 32 (presumably a size constraint of the
    inpainting pipeline — TODO confirm). Inputs smaller than
    ``maximum_dimension`` are scaled up, not left as-is.

    Args:
        original_resolution_wh: Source size as ``(width, height)``.
        maximum_dimension: Desired length of the longer side.

    Returns:
        ``(new_width, new_height)``, each a multiple of 32 and at least 32.
    """
    width, height = original_resolution_wh
    longest_side = max(width, height)
    # Exact integer floor instead of ``int(width * (maximum_dimension / width))``:
    # the float form yields 1023.999... for widths like 1568 or 3136, which the
    # %32 step below would then turn into a 32-pixel loss (992 instead of 1024).
    new_width = width * maximum_dimension // longest_side
    new_height = height * maximum_dimension // longest_side
    # Snap down to a multiple of 32, but never to 0: a very elongated image
    # (aspect ratio above 32:1) would otherwise get an unusable 0-pixel side.
    new_width = max(32, new_width - new_width % 32)
    new_height = max(32, new_height - new_height % 32)
    return new_width, new_height
def is_image_empty(image: Image.Image) -> bool:
    """Return True when every pixel of ``image`` is 0 after grayscale conversion.

    Used to decide whether the user painted a mask in the editor layer.

    Args:
        image: Any PIL image; it is converted to mode "L" before the check.

    Returns:
        True if the grayscale image has no non-zero pixel, else False.
    """
    # ``getbbox`` returns None when the image has no non-zero pixels, which on
    # an "L" image is exactly "all values are 0". It runs in C instead of
    # materialising a Python list of every pixel.
    return image.convert("L").getbbox() is None
def set_client_for_session(request: gr.Request):
    """Create the per-session client for the florence-sam-masking Space.

    When the incoming request carries an ``x-ip-token`` header, it is forwarded
    as ``X-IP-Token`` on the client (presumably so the remote Space attributes
    usage to this user — verify against the Space's docs). Without the header
    (e.g. when running locally), a plain client is returned instead of failing
    with KeyError.

    Args:
        request: The Gradio request for the session being loaded.

    Returns:
        A ``gradio_client.Client`` for ``SkalskiP/florence-sam-masking``.
    """
    x_ip_token = request.headers.get('x-ip-token')
    if x_ip_token is None:
        return Client("SkalskiP/florence-sam-masking")
    return Client("SkalskiP/florence-sam-masking", headers={"X-IP-Token": x_ip_token})
@spaces.GPU(duration=100)
def process(
    client,
    input_image_editor: dict,
    inpainting_prompt_text: str,
    masking_prompt_text: str,
    seed_slicer: int,
    randomize_seed_checkbox: bool,
    strength_slider: float,
    num_inference_steps_slider: int,
    progress=gr.Progress(track_tqdm=True)
):
    """Inpaint the uploaded image with FLUX using a drawn or generated mask.

    Exactly one mask source must be provided: either a mask drawn in the
    editor's first layer, or ``masking_prompt_text``. The latter is sent to
    the florence-sam-masking Space via ``client`` to produce a mask. The mask
    is blurred, image and mask are resized with ``resize_image_dimensions``,
    and ``pipe`` generates the result.

    Args:
        client: Per-session ``gradio_client.Client`` (from ``set_client_for_session``).
        input_image_editor: ImageEditor value; ``'background'`` is the image
            path and ``'layers'`` a list of layer paths.
        inpainting_prompt_text: Prompt for the inpainted content.
        masking_prompt_text: Optional prompt used to generate the mask.
        seed_slicer: Seed used when ``randomize_seed_checkbox`` is False.
        randomize_seed_checkbox: Draw a random seed in ``[0, MAX_SEED]``.
        strength_slider: ``strength`` passed to the pipeline.
        num_inference_steps_slider: ``num_inference_steps`` passed to the pipeline.
        progress: Gradio progress tracker (tracks tqdm bars).

    Returns:
        ``(result_image, resized_mask)``, or ``(None, None)`` after showing a
        ``gr.Info`` message when the inputs are invalid.
    """
    if not inpainting_prompt_text:
        gr.Info("Please enter a text prompt.")
        return None, None

    print(input_image_editor)
    # Validate before opening: ``Image.open(None)`` raises, and a PIL image is
    # always truthy, so a check after opening can never detect a missing upload.
    image_path = (input_image_editor or {}).get('background')
    if not image_path:
        gr.Info("Please upload an image.")
        return None, None
    image = Image.open(image_path)

    layers = input_image_editor.get('layers') or []
    mask = Image.open(layers[0]) if layers else None
    # Computed once: is_image_empty scans every pixel of the mask.
    mask_is_empty = mask is None or is_image_empty(mask)

    if mask_is_empty and not masking_prompt_text:
        gr.Info("Please draw a mask or enter a masking prompt.")
        return None, None

    if not mask_is_empty and masking_prompt_text:
        gr.Info("Both mask and masking prompt are provided. Please provide only one.")
        return None, None

    if mask_is_empty:
        generated_mask_path = client.predict(
            image_input=handle_file(image_path),
            text_input=masking_prompt_text,
            api_name="/process_image")
        mask = Image.open(generated_mask_path)

    # Soften the mask edges so the inpainted region blends into its surroundings.
    mask = mask.filter(ImageFilter.GaussianBlur(radius=5))

    width, height = resize_image_dimensions(original_resolution_wh=image.size)
    resized_image = image.resize((width, height), Image.LANCZOS)
    resized_mask = mask.resize((width, height), Image.LANCZOS)

    if randomize_seed_checkbox:
        seed_slicer = random.randint(0, MAX_SEED)
    # Slider values may arrive as floats; torch and diffusers expect ints here.
    generator = torch.Generator().manual_seed(int(seed_slicer))
    result = pipe(
        prompt=inpainting_prompt_text,
        image=resized_image,
        mask_image=resized_mask,
        width=width,
        height=height,
        strength=strength_slider,
        generator=generator,
        num_inference_steps=int(num_inference_steps_slider)
    ).images[0]
    print('INFERENCE DONE')
    return result, resized_mask
with gr.Blocks() as demo:
    # Holds the per-session gradio_client.Client created by ``demo.load``
    # below; it is fed to ``process`` as its first argument.
    client_component = gr.State()
    gr.Markdown(MARKDOWN)
    with gr.Row():
        with gr.Column():
            input_image_editor_component = gr.ImageEditor(
                label='Image',
                type='filepath',
                sources=["upload", "webcam"],
                image_mode='RGB',
                layers=False,
                brush=gr.Brush(colors=["#FFFFFF"], color_mode="fixed"))

            with gr.Row():
                inpainting_prompt_text_component = gr.Text(
                    label="Prompt",
                    show_label=False,
                    max_lines=1,
                    placeholder="Enter text to generate inpainting",
                    container=False,
                )
                submit_button_component = gr.Button(
                    value='Submit', variant='primary', scale=0)

            with gr.Accordion("Advanced Settings", open=False):
                masking_prompt_text_component = gr.Text(
                    label="Prompt",
                    show_label=False,
                    max_lines=1,
                    placeholder="Enter text to generate masking",
                    container=False,
                )
                seed_slicer_component = gr.Slider(
                    label="Seed",
                    minimum=0,
                    maximum=MAX_SEED,
                    step=1,
                    value=42,
                )

                randomize_seed_checkbox_component = gr.Checkbox(
                    label="Randomize seed", value=True)

                with gr.Row():
                    strength_slider_component = gr.Slider(
                        label="Strength",
                        info="Indicates extent to transform the reference `image`. "
                             "Must be between 0 and 1. `image` is used as a starting "
                             "point and more noise is added the higher the `strength`.",
                        minimum=0,
                        maximum=1,
                        step=0.01,
                        value=0.85,
                    )

                    num_inference_steps_slider_component = gr.Slider(
                        label="Number of inference steps",
                        info="The number of denoising steps. More denoising steps "
                             "usually lead to a higher quality image at the "
                             "expense of slower inference.",
                        minimum=1,
                        maximum=50,
                        step=1,
                        value=20,
                    )
        with gr.Column():
            output_image_component = gr.Image(
                type='pil', image_mode='RGB', label='Generated image', format="png")
            with gr.Accordion("Debug", open=False):
                output_mask_component = gr.Image(
                    type='pil', image_mode='RGB', label='Input mask', format="png")

    submit_button_component.click(
        fn=process,
        # Order must match process()'s positional parameters; ``client`` comes
        # first, so the session State must be the first input.
        inputs=[
            client_component,
            input_image_editor_component,
            inpainting_prompt_text_component,
            masking_prompt_text_component,
            seed_slicer_component,
            randomize_seed_checkbox_component,
            strength_slider_component,
            num_inference_steps_slider_component
        ],
        outputs=[
            output_image_component,
            output_mask_component
        ]
    )
    # Build the per-session masking client when the page loads.
    demo.load(set_client_for_session, None, client_component)

demo.launch(debug=False, show_error=True)