Spaces:
Runtime error
Runtime error
import io
import random
from typing import Dict, Tuple

import gradio as gr
import numpy as np
import requests
import torch
from diffusers import FluxInpaintPipeline, StableDiffusionInpaintPipeline
from PIL import Image
INFO = """
# FLUX-Based Inpainting 🎨
This interface utilizes a FLUX model variant for precise inpainting. Special thanks to the [Black Forest Labs](https://huggingface.co/black-forest-labs) team
and [Gothos](https://github.com/Gothos) for contributing to this advanced solution.
"""
# Constants
# Upper bound for the seed slider and for randomly drawn seeds: the largest
# signed 32-bit integer.
MAX_SEED_VALUE = np.iinfo(np.int32).max
# Default length, in pixels, of the longer image side before inference
# (used as get_scaled_dimensions' default max_dim).
TARGET_DIM = 1024
# Run on the GPU when one is visible to torch, otherwise fall back to the CPU.
DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
# Function to clear background
def clear_background(image: Image.Image, threshold: int = 50) -> Image.Image:
    """Return an RGBA copy of ``image`` with dark pixels made fully transparent.

    A pixel is replaced by ``(0, 0, 0, 0)`` when the mean of its R, G and B
    channels is below ``threshold``; every other pixel is kept unchanged.

    Args:
        image: Source image in any mode. It is converted to RGBA first and is
            not modified in place.
        threshold: Mean-brightness cutoff on the 0-255 channel scale.

    Returns:
        A new RGBA image.
    """
    # One (H, W, 4) uint8 array instead of a per-pixel Python list of tuples.
    rgba = np.array(image.convert("RGBA"))
    # Widen before summing so three uint8 channels cannot overflow.
    brightness = rgba[..., :3].sum(axis=-1, dtype=np.int32) / 3
    # Boolean (H, W) mask zeroes all four channels of each dark pixel.
    rgba[brightness < threshold] = 0
    return Image.fromarray(rgba)
# Sample data examples
def _fetch_image(url: str, timeout: float = 30.0) -> Image.Image:
    """Download ``url`` and decode the response body as a PIL image.

    Raises:
        requests.RequestException: on connection failure, timeout, or a
            non-2xx status (via ``raise_for_status``), so a bad URL fails with a
            clear HTTP error rather than a PIL decode error on an error page.
    """
    # Without a timeout, an unresponsive host would hang app start-up forever.
    response = requests.get(url, timeout=timeout)
    response.raise_for_status()
    # ``.content`` is decoded per Content-Encoding (e.g. gzip); ``.raw`` is not.
    return Image.open(io.BytesIO(response.content))


# Each row matches generate_inpainting's positional inputs:
# (editor value, prompt, seed, randomize seed, strength, steps).
# NOTE(review): example.com is a reserved placeholder domain; these URLs must be
# replaced with real assets, otherwise importing this module fails while
# fetching them.
EXAMPLES = [
    [
        {
            "background": _fetch_image("https://example.com/doge-1.png"),
            "layers": [clear_background(_fetch_image("https://example.com/mask-1.png"))],
            "composite": _fetch_image("https://example.com/composite-1.png"),
        },
        "desert mirage",
        42,
        False,
        0.75,
        25,
    ],
    [
        {
            "background": _fetch_image("https://example.com/doge-2.png"),
            "layers": [clear_background(_fetch_image("https://example.com/mask-2.png"))],
            "composite": _fetch_image("https://example.com/composite-2.png"),
        },
        "neon city",
        100,
        True,
        0.9,
        35,
    ],
]
# Load model
# FLUX checkpoints use a transformer denoiser rather than a UNet, so they must
# be loaded through the FLUX inpainting pipeline; the Stable Diffusion
# inpainting class cannot assemble them.
inpainting_pipeline = FluxInpaintPipeline.from_pretrained(
    "black-forest-labs/FLUX.1-schnell", torch_dtype=torch.bfloat16
).to(DEVICE)
# Utility to adjust image size
def get_scaled_dimensions(
    original_size: Tuple[int, int], max_dim: int = TARGET_DIM, multiple: int = 32
) -> Tuple[int, int]:
    """Scale ``original_size`` so its longer side is ``max_dim``, snapped to ``multiple``.

    The aspect ratio is preserved, and small images are scaled *up*. Each side
    is rounded down to a multiple of ``multiple``, but never below
    ``multiple``, so extreme aspect ratios cannot produce a zero-sized side.

    Args:
        original_size: ``(width, height)`` in pixels.
        max_dim: Target length of the longer side before snapping.
        multiple: Granularity that both output sides are snapped to.

    Returns:
        ``(width, height)`` of the scaled size.

    Raises:
        ValueError: if either side of ``original_size`` is not positive.
    """
    width, height = original_size
    if width <= 0 or height <= 0:
        raise ValueError(f"Image size must be positive, got {original_size!r}")
    scaling_factor = max_dim / max(width, height)

    def snap(side: int) -> int:
        # Floor to the grid, then clamp so a very thin side never becomes 0.
        return max(multiple, int(side * scaling_factor) // multiple * multiple)

    return snap(width), snap(height)
def generate_inpainting(
    input_data: Dict,
    prompt_text: str,
    chosen_seed: int,
    use_random_seed: bool,
    inpainting_strength: float,
    steps: int,
    progress=gr.Progress(track_tqdm=True),
):
    """Inpaint the brushed region of the editor image according to a text prompt.

    Args:
        input_data: ImageEditor value, read for its "background" image and its
            "layers" list, whose first entry is used as the mask. May be None
            when nothing has been uploaded yet.
        prompt_text: Description of what to paint into the masked region.
        chosen_seed: Seed used when ``use_random_seed`` is false.
        use_random_seed: Draw a fresh seed in ``[0, MAX_SEED_VALUE]`` instead.
        inpainting_strength: Passed to the pipeline as ``strength``; must be > 0.
        steps: Number of inference steps passed to the pipeline.
        progress: Gradio progress tracker that mirrors the pipeline's tqdm bars.

    Returns:
        ``(generated_image, resized_mask)`` on success. When an input is
        missing or invalid, a ``gr.Info`` toast is shown and ``(None, None)``
        is returned.
    """
    if not prompt_text:
        gr.Info("Provide a prompt to proceed.")
        return None, None

    # The editor yields None before any upload, and may have no layers yet.
    input_data = input_data or {}
    background = input_data.get("background")
    layers = input_data.get("layers") or []
    mask_layer = layers[0] if layers else None
    if background is None:
        gr.Info("Background image is missing.")
        return None, None
    if mask_layer is None:
        gr.Info("Mask layer is missing.")
        return None, None
    # The strength slider allows 0, which leaves the pipeline with no
    # denoising steps to run.
    if inpainting_strength <= 0:
        gr.Info("Inpainting strength must be greater than 0.")
        return None, None

    new_width, new_height = get_scaled_dimensions(background.size)
    # The editor may hand back RGBA images; the pipeline's VAE expects RGB.
    resized_background = background.convert("RGB").resize(
        (new_width, new_height), Image.LANCZOS
    )
    resized_mask = mask_layer.resize((new_width, new_height), Image.LANCZOS)

    if use_random_seed:
        chosen_seed = random.randint(0, MAX_SEED_VALUE)
    # A per-call generator instead of reseeding torch's global RNG, so that
    # concurrent requests do not disturb each other's randomness.
    generator = torch.Generator(device=DEVICE).manual_seed(int(chosen_seed))

    generated_image = inpainting_pipeline(
        prompt=prompt_text,
        image=resized_background,
        mask_image=resized_mask,
        # Explicit size keeps the pipeline from resizing to its default square.
        height=new_height,
        width=new_width,
        strength=inpainting_strength,
        # Sliders can deliver floats; the step count must be an integer.
        num_inference_steps=int(steps),
        generator=generator,
    ).images[0]
    return generated_image, resized_mask
# Build the Gradio interface
with gr.Blocks() as flux_app:
    gr.Markdown(INFO)
    with gr.Row():
        with gr.Column():
            image_editor = gr.ImageEditor(
                label="Edit Image",
                type="pil",
                sources=["upload", "webcam"],
                # A single fixed white brush; the drawn strokes form the first
                # layer, which generate_inpainting uses as the mask.
                brush=gr.Brush(colors=["#FFF"], color_mode="fixed"),
            )
            prompt_box = gr.Text(
                label="Inpainting Prompt", placeholder="Describe the change you'd like."
            )
            run_button = gr.Button(value="Run Inpainting")
            with gr.Accordion("Settings"):
                seed_slider = gr.Slider(0, MAX_SEED_VALUE, step=1, value=42, label="Seed")
                random_seed_toggle = gr.Checkbox(label="Randomize Seed", value=True)
                inpainting_strength_slider = gr.Slider(
                    0.0, 1.0, step=0.01, value=0.85, label="Inpainting Strength"
                )
                steps_slider = gr.Slider(1, 50, step=1, value=25, label="Inference Steps")
        with gr.Column():
            output_image = gr.Image(label="Output Image")
            output_mask = gr.Image(label="Processed Mask")

    # Shared by the button and the examples; the order must match
    # generate_inpainting's positional parameters.
    inference_inputs = [
        image_editor,
        prompt_box,
        seed_slider,
        random_seed_toggle,
        inpainting_strength_slider,
        steps_slider,
    ]
    inference_outputs = [output_image, output_mask]

    run_button.click(
        generate_inpainting,
        inputs=inference_inputs,
        outputs=inference_outputs,
    )
    gr.Examples(
        examples=EXAMPLES,
        fn=generate_inpainting,
        inputs=inference_inputs,
        outputs=inference_outputs,
        run_on_click=True,
    )

# Start the server only when run as a script, not whenever the module is imported.
if __name__ == "__main__":
    flux_app.launch(debug=False, show_error=True)