# Canvas / app.py — Hugging Face Space by ginipick
# Commit eabee32 ("Update app.py", verified), 16.1 kB
import tempfile
import time
from collections.abc import Sequence
from typing import Any, cast
import gradio as gr
import numpy as np
import pillow_heif
import spaces
import torch
from gradio_image_annotation import image_annotator
from gradio_imageslider import ImageSlider
from PIL import Image
from pymatting.foreground.estimate_foreground_ml import estimate_foreground_ml
from refiners.fluxion.utils import no_grad
from refiners.solutions import BoxSegmenter
from transformers import GroundingDinoForObjectDetection, GroundingDinoProcessor
import spaces
import argparse
import os
from os import path
import shutil
from datetime import datetime
from safetensors.torch import load_file
from huggingface_hub import hf_hub_download
import gradio as gr
from diffusers import FluxPipeline
from PIL import Image
from huggingface_hub import login
# Hugging Face authentication: log in with the token from the HF_TOKEN
# environment variable before any model is downloaded.
HF_TOKEN = os.getenv("HF_TOKEN")
if not HF_TOKEN:
    # An empty value is treated the same as an unset one; otherwise login("")
    # would fail later with a less helpful message.
    raise ValueError("Please set the HF_TOKEN environment variable")
try:
    login(token=HF_TOKEN)
except Exception as e:
    raise ValueError(f"Failed to login to Hugging Face: {e}") from e
# FLUX pipeline initialization.
def initialize_pipeline():
    """Load FLUX.1-dev with the Hyper-SD 8-step LoRA fused in.

    Returns:
        The ``FluxPipeline``, moved to CUDA in bfloat16.

    Raises:
        ValueError: if downloading, loading/fusing the LoRA, or device
            placement fails (the original exception is chained).
    """
    try:
        pipe = FluxPipeline.from_pretrained(
            "black-forest-labs/FLUX.1-dev",
            torch_dtype=torch.bfloat16,
            token=HF_TOKEN,
        )
        lora_path = hf_hub_download(
            "ByteDance/Hyper-SD",
            "Hyper-FLUX.1-dev-8steps-lora.safetensors",
            token=HF_TOKEN,
        )
        pipe.load_lora_weights(lora_path)
        # Bake the LoRA into the base weights at a fixed scale of 0.125.
        pipe.fuse_lora(lora_scale=0.125)
        pipe.to(device="cuda", dtype=torch.bfloat16)
        return pipe
    except Exception as e:
        raise ValueError(f"Failed to initialize pipeline: {e}") from e
# Pipeline initialization: build FLUX once at import time; any failure
# aborts startup with a RuntimeError.
try:
    pipe = initialize_pipeline()
except Exception as exc:
    raise RuntimeError(f"Failed to setup the model: {exc}")
# Pixel-space box in (xmin, ymin, xmax, ymax) order.
BoundingBox = tuple[int, int, int, int]
# Let PIL open HEIF/HEIC and AVIF uploads.
pillow_heif.register_heif_opener()
pillow_heif.register_avif_opener()
# Run models on CUDA when available, otherwise on CPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
# weird dance because ZeroGPU
# The segmenter is constructed on CPU and only afterwards are its device
# attribute and model moved to `device`.
# NOTE(review): presumably required by ZeroGPU's import-time GPU handling — confirm
# before reordering these three statements.
segmenter = BoxSegmenter(device="cpu")
segmenter.device = device
segmenter.model = segmenter.model.to(device=segmenter.device)
# Grounding DINO (float32) is used by gd_detect() to turn a text prompt into boxes.
gd_model_path = "IDEA-Research/grounding-dino-base"
gd_processor = GroundingDinoProcessor.from_pretrained(gd_model_path)
gd_model = GroundingDinoForObjectDetection.from_pretrained(gd_model_path, torch_dtype=torch.float32)
gd_model = gd_model.to(device=device) # type: ignore
# Narrow the type after `.to()` for static checkers.
assert isinstance(gd_model, GroundingDinoForObjectDetection)
# The FLUX pipeline (`pipe`) is created exactly once by initialize_pipeline()
# above. Do not load it again here: a second from_pretrained + LoRA fuse +
# move to CUDA would keep two full copies of the model alive and double the
# startup time.
def generate_background(prompt: str, width: int, height: int) -> Image.Image:
    """Generate a background image from a text prompt with the FLUX pipeline.

    Runs 8 inference steps with guidance scale 4.0.

    Args:
        prompt: Text description of the desired background.
        width: Requested output width in pixels.
        height: Requested output height in pixels.

    Returns:
        The first image produced by the pipeline.

    Raises:
        gr.Error: if generation fails for any reason (original error chained).
    """
    # This module defines no `timer` helper, so time the call explicitly;
    # referencing an undefined name here made every generation fail.
    start = time.perf_counter()
    try:
        image = pipe(
            prompt=prompt,
            width=width,
            height=height,
            num_inference_steps=8,
            guidance_scale=4.0,
        ).images[0]
    except Exception as e:
        raise gr.Error(f"Background generation failed: {str(e)}") from e
    print(f"Background generation: {time.perf_counter() - start:.2f}s")
    return image
def combine_with_background(foreground: Image.Image, background: Image.Image) -> Image.Image:
    """Composite ``foreground`` over ``background``.

    The background is resized to the foreground's size and converted to RGBA
    first; ``Image.alpha_composite`` requires both inputs to be RGBA, so the
    foreground must already be RGBA.
    """
    base = background.resize(foreground.size).convert('RGBA')
    return Image.alpha_composite(base, foreground)
def _process(
    img: Image.Image,
    prompt: str | BoundingBox | None,
    bg_prompt: str | None = None,
) -> tuple[tuple[Image.Image, Image.Image, Image.Image], gr.DownloadButton]:
    """Cut an object out of ``img`` and optionally place it on a generated background.

    Args:
        img: Input image.
        prompt: Text naming the object (detected with Grounding DINO), an
            explicit bounding box, or None.
        bg_prompt: Optional background description. When empty or None the
            cutout is composited onto plain white.

    Returns:
        ``((img, combined, cutout), download_button)``. ``cutout`` is the RGBA
        object, and the button points at a PNG of the cutout cropped to the
        mask's bounding box.

    Raises:
        gr.Error: on any failure. Errors already raised as ``gr.Error`` (e.g.
            "No object detected") are propagated unchanged.
    """
    try:
        # Object extraction (detection + segmentation, then alpha matting).
        mask, _bbox, _time_log = _gpu_process(img, prompt)
        masked_alpha = apply_mask(img, mask, defringe=True)

        # Background generation and compositing.
        if bg_prompt:
            background = generate_background(bg_prompt, img.width, img.height)
            combined = combine_with_background(masked_alpha, background)
        else:
            white = Image.new("RGBA", masked_alpha.size, "white")
            combined = Image.alpha_composite(white, masked_alpha)

        # Crop the download to where the mask is meaningfully on (> 10/255).
        # getbbox() is None for an empty mask; crop(None) then keeps the whole image.
        thresholded = mask.point(lambda p: 255 if p > 10 else 0)
        to_dl = masked_alpha.crop(thresholded.getbbox())
        # delete=False: the file must outlive this call so it can be downloaded.
        with tempfile.NamedTemporaryFile(delete=False, suffix=".png") as temp:
            to_dl.save(temp, format="PNG")
        return (img, combined, masked_alpha), gr.DownloadButton(value=temp.name, interactive=True)
    except gr.Error:
        raise
    except Exception as e:
        raise gr.Error(f"Processing failed: {str(e)}") from e
def bbox_union(bboxes: Sequence[list[int]]) -> BoundingBox | None:
    """Return the smallest box that encloses every box in ``bboxes``.

    Each box is ``[xmin, ymin, xmax, ymax]`` with integer coordinates.
    Returns None when ``bboxes`` is empty.
    """
    if not bboxes:
        return None
    for box in bboxes:
        assert len(box) == 4
        assert all(isinstance(coord, int) for coord in box)
    x_mins, y_mins, x_maxs, y_maxs = zip(*bboxes)
    return (min(x_mins), min(y_mins), max(x_maxs), max(y_maxs))
def corners_to_pixels_format(bboxes: torch.Tensor, width: int, height: int) -> torch.Tensor:
    """Round corner-format boxes to int32 pixels clipped to the image.

    The trailing dimension of ``bboxes`` holds ``(x1, y1, x2, y2)``. X values
    are clamped to ``[0, width]`` and y values to ``[0, height]``. The result
    has the same shape as ``bboxes`` with dtype int32.
    """
    pixels = bboxes.round().to(torch.int32)
    x1, y1, x2, y2 = pixels.unbind(-1)
    limits = (width, height, width, height)
    clamped = [coord.clamp(0, limit) for coord, limit in zip((x1, y1, x2, y2), limits)]
    return torch.stack(clamped, dim=-1)
def gd_detect(img: Image.Image, prompt: str) -> BoundingBox | None:
    """Detect ``prompt`` in ``img`` with Grounding DINO and merge all detections.

    Returns:
        The union box of every detection in pixel coordinates, or None when
        nothing was detected.
    """
    assert isinstance(gd_processor, GroundingDinoProcessor)
    # Grounding DINO expects each category to be terminated by a dot.
    query = f"{prompt}."
    inputs = gd_processor(images=img, text=query, return_tensors="pt").to(device=device)
    with no_grad():
        outputs = gd_model(**inputs)
    width, height = img.size
    detections: dict[str, Any] = gd_processor.post_process_grounded_object_detection(
        outputs,
        inputs["input_ids"],
        target_sizes=[(height, width)],
    )[0]
    assert "boxes" in detections and isinstance(detections["boxes"], torch.Tensor)
    pixel_boxes = corners_to_pixels_format(detections["boxes"].cpu(), width, height)
    return bbox_union(pixel_boxes.numpy().tolist())
def apply_mask(
    img: Image.Image,
    mask_img: Image.Image,
    defringe: bool = True,
) -> Image.Image:
    """Return an RGBA image whose colors come from ``img`` and alpha from ``mask_img``.

    Args:
        img: Source image (converted to RGB).
        mask_img: Mask of the same size (converted to 8-bit grayscale "L").
        defringe: When True, re-estimate foreground colors with pymatting's
            ``estimate_foreground_ml`` to reduce halo colors at soft mask edges.

    Raises:
        ValueError: if ``img`` and ``mask_img`` differ in size.
    """
    if img.size != mask_img.size:
        raise ValueError(f"Image size {img.size} does not match mask size {mask_img.size}")
    img = img.convert("RGB")
    mask_img = mask_img.convert("L")
    if defringe:
        # Mitigate edge halo effects via color decontamination. Inputs are
        # scaled to [0, 1] floats, so the estimate is a float image as well.
        rgb = np.asarray(img) / 255.0
        alpha = np.asarray(mask_img) / 255.0
        foreground = cast(np.ndarray[Any, np.dtype[np.float64]], estimate_foreground_ml(rgb, alpha))
        # Clip before the uint8 cast so any out-of-range estimate saturates
        # instead of wrapping around.
        img = Image.fromarray((np.clip(foreground, 0.0, 1.0) * 255).astype("uint8"))
    result = Image.new("RGBA", img.size)
    result.paste(img, (0, 0), mask_img)
    return result
@spaces.GPU
def _gpu_process(
    img: Image.Image,
    prompt: str | BoundingBox | None,
) -> tuple[Image.Image, BoundingBox | None, list[str]]:
    """Run detection (for text prompts) and segmentation.

    Because of ZeroGPU, this must be the *single* ``spaces.GPU``-decorated
    function, and it must not contain postprocessing.

    Returns:
        ``(mask, bbox, time_log)``, where ``time_log`` holds timing strings.

    Raises:
        gr.Error: if a text prompt yields no detection.
    """
    time_log: list[str] = []
    if not isinstance(prompt, str):
        bbox = prompt
    else:
        started = time.time()
        bbox = gd_detect(img, prompt)
        time_log.append(f"detect: {time.time() - started}")
        if not bbox:
            print(time_log[0])
            raise gr.Error("No object detected")
    started = time.time()
    mask = segmenter(img, bbox)
    time_log.append(f"segment: {time.time() - started}")
    return mask, bbox, time_log
def process_bbox(prompts: dict[str, Any]) -> tuple[tuple[Image.Image, Image.Image, Image.Image], gr.DownloadButton]:
    """Handle the "Extract by Box" tab: cut out the object inside the drawn box.

    Args:
        prompts: Value of the ``image_annotator`` component: an ``"image"``
            (PIL image) and a ``"boxes"`` list holding zero or one dict with
            ``xmin``/``ymin``/``xmax``/``ymax`` keys.
    """
    img = prompts["image"]
    boxes = prompts["boxes"]
    assert isinstance(img, Image.Image)
    assert isinstance(boxes, list)
    if len(boxes) == 1:
        box = boxes[0]
        assert isinstance(box, dict)
        bbox = tuple(box[k] for k in ("xmin", "ymin", "xmax", "ymax"))
    else:
        # The annotator is configured with single_box=True, so the only other
        # valid state is "no box".
        assert len(boxes) == 0
        bbox = None
    # This tab has no background prompt. Pass None explicitly: _process takes
    # a bg_prompt argument, and omitting it raised TypeError.
    return _process(img, bbox, None)
def on_change_bbox(prompts: dict[str, Any] | None):
    """Enable the box-extraction button only once the annotator holds a value."""
    has_value = prompts is not None
    return gr.update(interactive=has_value)
# process_prompt and on_change_prompt are defined further down, with
# background-prompt (bg_prompt) support. Two-argument versions defined here
# would only be shadowed by those definitions before any use.
# Page styling passed to gr.Blocks(css=css). (An earlier, smaller `css`
# assignment was overwritten by this one before use and has been dropped.)
css = """
footer {visibility: hidden}
.container {max-width: 1200px; margin: auto; padding: 20px;}
.main-title {text-align: center; color: #2a2a2a; margin-bottom: 2em;}
.tabs {background: #f7f7f7; border-radius: 15px; padding: 20px;}
.input-column {background: white; padding: 20px; border-radius: 10px; box-shadow: 0 2px 6px rgba(0,0,0,0.1);}
.output-column {background: white; padding: 20px; border-radius: 10px; box-shadow: 0 2px 6px rgba(0,0,0,0.1);}
.custom-button {background: #2196F3; color: white; border: none; border-radius: 5px; padding: 10px 20px;}
.custom-button:hover {background: #1976D2;}
.example-region {margin-top: 2em; padding: 20px; background: #f0f0f0; border-radius: 10px;}
"""
def process_prompt(
    img: Image.Image,
    prompt: str,
    bg_prompt: str | None = None,
) -> tuple[tuple[Image.Image, Image.Image, Image.Image], gr.DownloadButton]:
    """Handle the "Extract by Text" tab.

    Args:
        img: Uploaded image.
        prompt: Text naming the object to extract.
        bg_prompt: Optional background description; None or empty keeps a
            plain white background.
    """
    return _process(img, prompt, bg_prompt)


def on_change_prompt(img: Image.Image | None, prompt: str | None, bg_prompt: str | None = None):
    """Enable the process button only when both an image and an object prompt are set.

    ``bg_prompt`` is accepted because it is wired as an event input, but it is
    optional and does not affect whether the button is enabled.
    """
    return gr.update(interactive=bool(img and prompt))
with gr.Blocks(theme=gr.themes.Soft(), css=css) as demo:
    gr.HTML("""
    <div class="main-title">
        <h1>🎨 Advanced Image Object Extractor</h1>
        <p>Extract objects from images using text prompts or bounding boxes</p>
    </div>
    """)
    with gr.Tabs() as tabs:
        # Tab 1: extract by text prompt, optionally onto a generated background.
        with gr.Tab("✨ Extract by Text", id="tab_prompt"):
            with gr.Row(equal_height=True):
                with gr.Column(scale=1, min_width=400):
                    gr.HTML("<h3>📥 Input Section</h3>")
                    iimg = gr.Image(
                        type="pil",
                        label="Upload Image",
                    )
                    with gr.Group():
                        prompt = gr.Textbox(
                            label="🎯 Object to Extract",
                            placeholder="Enter what you want to extract...",
                        )
                        bg_prompt = gr.Textbox(
                            label="🖼️ Background Generation Prompt (optional)",
                            placeholder="Describe the background you want...",
                        )
                    # ClearButton (not Button): only ClearButton provides the
                    # `.add()` used below to clear the previous result on click.
                    btn = gr.ClearButton(
                        value="🚀 Process Image",
                        variant="primary",
                        interactive=False,
                    )
                with gr.Column(scale=1, min_width=400):
                    gr.HTML("<h3>📤 Output Section</h3>")
                    oimg = ImageSlider(
                        label="Results Preview",
                        show_download_button=False,
                    )
                    dlbt = gr.DownloadButton(
                        "💾 Download Result",
                        interactive=False,
                    )
            with gr.Accordion("📚 Examples", open=False):
                examples = [
                    ["examples/text.jpg", "text"],
                    ["examples/potted-plant.jpg", "potted plant"],
                    ["examples/chair.jpg", "chair"],
                    ["examples/black-lamp.jpg", "black lamp"],
                ]
                ex = gr.Examples(
                    examples=examples,
                    inputs=[iimg, prompt],
                    outputs=[oimg, dlbt],
                    fn=process_prompt,
                    cache_examples=True,
                )
        # Tab 2: extract the object inside a hand-drawn bounding box.
        with gr.Tab("📐 Extract by Box", id="tab_bb"):
            with gr.Row(equal_height=True):
                with gr.Column(scale=1, min_width=400):
                    gr.HTML("<h3>📥 Input Section</h3>")
                    annotator = image_annotator(
                        image_type="pil",
                        disable_edit_boxes=True,
                        show_download_button=False,
                        show_share_button=False,
                        single_box=True,
                        label="Draw Box Around Object",
                    )
                    btn_bb = gr.ClearButton(
                        value="✂️ Extract Selection",
                        variant="primary",
                        interactive=False,
                    )
                with gr.Column(scale=1, min_width=400):
                    gr.HTML("<h3>📤 Output Section</h3>")
                    oimg_bb = ImageSlider(
                        label="Results Preview",
                        show_download_button=False,
                    )
                    dlbt_bb = gr.DownloadButton(
                        "💾 Download Result",
                        interactive=False,
                    )
            with gr.Accordion("📚 Examples", open=False):
                # Each example is a single annotator value (image + boxes). This
                # tab has no background-prompt input and process_bbox takes only
                # the annotator value, so examples carry no background text.
                examples_bb = [
                    {
                        "image": "examples/text.jpg",
                        "boxes": [{"xmin": 51, "ymin": 511, "xmax": 639, "ymax": 1255}],
                    },
                    {
                        "image": "examples/potted-plant.jpg",
                        "boxes": [{"xmin": 51, "ymin": 511, "xmax": 639, "ymax": 1255}],
                    },
                    {
                        "image": "examples/chair.jpg",
                        "boxes": [{"xmin": 98, "ymin": 330, "xmax": 973, "ymax": 1468}],
                    },
                    {
                        "image": "examples/black-lamp.jpg",
                        "boxes": [{"xmin": 88, "ymin": 148, "xmax": 700, "ymax": 1414}],
                    },
                ]
                ex_bb = gr.Examples(
                    examples=examples_bb,
                    inputs=[annotator],
                    outputs=[oimg_bb, dlbt_bb],
                    fn=process_bbox,
                    cache_examples=True,
                )

    # Event handlers: text tab.
    btn.add(oimg)  # clear the previous result when the button is clicked
    for inp in [iimg, prompt]:
        inp.change(
            fn=on_change_prompt,
            inputs=[iimg, prompt, bg_prompt],
            outputs=[btn],
        )
    btn.click(
        fn=process_prompt,
        inputs=[iimg, prompt, bg_prompt],
        outputs=[oimg, dlbt],
        api_name=False,
    )

    # Event handlers: box tab.
    btn_bb.add(oimg_bb)
    annotator.change(
        fn=on_change_bbox,
        inputs=[annotator],
        outputs=[btn_bb],
    )
    btn_bb.click(
        fn=process_bbox,
        inputs=[annotator],
        outputs=[oimg_bb, dlbt_bb],
        api_name=False,
    )
# NOTE: a further `css` string was assigned here after gr.Blocks(css=css) had
# already been built, so it never took effect and has been removed. To apply
# its title styling (`.main-title h1 {color: #2196F3; font-size: 2.5em;}`,
# `footer {display: none}`), add those rules to the `css` string that is
# passed to gr.Blocks(...).
# Launch settings: queue up to 30 pending requests with the API closed, and
# serve on all interfaces at port 7860. Launch exactly once; a second
# queue()/launch() pair would start the server again after the first exits.
demo.queue(max_size=30, api_open=False)
demo.launch(
    show_api=False,
    share=False,
    server_name="0.0.0.0",
    server_port=7860,
    show_error=True,
)