🎨 Advanced Image Object Extractor
Extract objects from images using text prompts or bounding boxes
import os
import tempfile
import time
from collections.abc import Sequence
from contextlib import contextmanager
from typing import Any, cast

import gradio as gr
import numpy as np
import pillow_heif
import spaces
import torch
from diffusers import FluxPipeline
from gradio_image_annotation import image_annotator
from gradio_imageslider import ImageSlider
from huggingface_hub import hf_hub_download, login
from PIL import Image
from pymatting.foreground.estimate_foreground_ml import estimate_foreground_ml
from refiners.fluxion.utils import no_grad
from refiners.solutions import BoxSegmenter
from transformers import GroundingDinoForObjectDetection, GroundingDinoProcessor

# HF token authentication
HF_TOKEN = os.getenv("HF_TOKEN")
if HF_TOKEN is None:
    raise ValueError("Please set the HF_TOKEN environment variable")
try:
    login(token=HF_TOKEN)
except Exception as e:
    raise ValueError(f"Failed to login to Hugging Face: {str(e)}")


@contextmanager
def timer(name: str):
    # Timing context manager used by generate_background (minimal assumed
    # implementation; `timer` is referenced below but was left undefined).
    t0 = time.time()
    yield
    print(f"{name}: {time.time() - t0:.2f}s")


# FLUX pipeline initialization
def initialize_pipeline():
    try:
        pipe = FluxPipeline.from_pretrained(
            "black-forest-labs/FLUX.1-dev",
            torch_dtype=torch.bfloat16,
            token=HF_TOKEN,
        )
        pipe.load_lora_weights(
            hf_hub_download(
                "ByteDance/Hyper-SD",
                "Hyper-FLUX.1-dev-8steps-lora.safetensors",
                token=HF_TOKEN,
            )
        )
        pipe.fuse_lora(lora_scale=0.125)
        pipe.to(device="cuda", dtype=torch.bfloat16)
        return pipe
    except Exception as e:
        raise ValueError(f"Failed to initialize pipeline: {str(e)}")


try:
    pipe = initialize_pipeline()
except Exception as e:
    raise RuntimeError(f"Failed to setup the model: {str(e)}")

BoundingBox = tuple[int, int, int, int]

pillow_heif.register_heif_opener()
pillow_heif.register_avif_opener()

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Weird dance because of ZeroGPU: build on CPU, then move to the target device.
segmenter = BoxSegmenter(device="cpu")
segmenter.device = device
segmenter.model = segmenter.model.to(device=segmenter.device)

gd_model_path = "IDEA-Research/grounding-dino-base"
gd_processor = GroundingDinoProcessor.from_pretrained(gd_model_path)
gd_model = GroundingDinoForObjectDetection.from_pretrained(gd_model_path, torch_dtype=torch.float32)
gd_model = gd_model.to(device=device)  # type: ignore
assert isinstance(gd_model, GroundingDinoForObjectDetection)


def generate_background(prompt: str, width: int, height: int) -> Image.Image:
    """Generate a background image from a text prompt."""
    try:
        with timer("Background generation"):
            image = pipe(
                prompt=prompt,
                width=width,
                height=height,
                num_inference_steps=8,
                guidance_scale=4.0,
            ).images[0]
        return image
    except Exception as e:
        raise gr.Error(f"Background generation failed: {str(e)}")


def combine_with_background(foreground: Image.Image, background: Image.Image) -> Image.Image:
    """Composite an RGBA foreground over a generated background."""
    background = background.resize(foreground.size)
    return Image.alpha_composite(background.convert("RGBA"), foreground)
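# Illustrative usage sketch for the two helpers above (not part of the app
# flow; the prompt, file name, and variable names are hypothetical):
#
#     backdrop = generate_background("a sunlit wooden table", 1024, 768)
#     cutout = Image.open("cutout.png").convert("RGBA")  # an RGBA foreground
#     composite = combine_with_background(cutout, backdrop)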
def _process(
    img: Image.Image,
    prompt: str | BoundingBox | None,
    bg_prompt: str | None = None,
) -> tuple[tuple[Image.Image, Image.Image, Image.Image], gr.DownloadButton]:
    try:
        # Object extraction: detect/segment, then cut the object out.
        mask, bbox, time_log = _gpu_process(img, prompt)
        masked_alpha = apply_mask(img, mask, defringe=True)

        # Background generation and compositing.
        if bg_prompt:
            background = generate_background(bg_prompt, img.width, img.height)
            combined = combine_with_background(masked_alpha, background)
        else:
            combined = Image.alpha_composite(Image.new("RGBA", masked_alpha.size, "white"), masked_alpha)

        # Save a tightly cropped cut-out for download.
        thresholded = mask.point(lambda p: 255 if p > 10 else 0)
        bbox = thresholded.getbbox()
        to_dl = masked_alpha.crop(bbox)
        temp = tempfile.NamedTemporaryFile(delete=False, suffix=".png")
        to_dl.save(temp, format="PNG")
        temp.close()

        return (img, combined, masked_alpha), gr.DownloadButton(value=temp.name, interactive=True)
    except Exception as e:
        raise gr.Error(f"Processing failed: {str(e)}")


def bbox_union(bboxes: Sequence[list[int]]) -> BoundingBox | None:
    if not bboxes:
        return None
    for bbox in bboxes:
        assert len(bbox) == 4
        assert all(isinstance(x, int) for x in bbox)
    return (
        min(bbox[0] for bbox in bboxes),
        min(bbox[1] for bbox in bboxes),
        max(bbox[2] for bbox in bboxes),
        max(bbox[3] for bbox in bboxes),
    )


def corners_to_pixels_format(bboxes: torch.Tensor, width: int, height: int) -> torch.Tensor:
    x1, y1, x2, y2 = bboxes.round().to(torch.int32).unbind(-1)
    return torch.stack((x1.clamp_(0, width), y1.clamp_(0, height), x2.clamp_(0, width), y2.clamp_(0, height)), dim=-1)


def gd_detect(img: Image.Image, prompt: str) -> BoundingBox | None:
    assert isinstance(gd_processor, GroundingDinoProcessor)
    # Grounding DINO expects a dot after each category.
    inputs = gd_processor(images=img, text=f"{prompt}.", return_tensors="pt").to(device=device)
    with no_grad():
        outputs = gd_model(**inputs)
    width, height = img.size
    results: dict[str, Any] = gd_processor.post_process_grounded_object_detection(
        outputs,
        inputs["input_ids"],
        target_sizes=[(height, width)],
    )[0]
    assert "boxes" in results and isinstance(results["boxes"], torch.Tensor)
    bboxes = corners_to_pixels_format(results["boxes"].cpu(), width, height)
    return bbox_union(bboxes.numpy().tolist())


def apply_mask(
    img: Image.Image,
    mask_img: Image.Image,
    defringe: bool = True,
) -> Image.Image:
    assert img.size == mask_img.size
    img = img.convert("RGB")
    mask_img = mask_img.convert("L")
    if defringe:
        # Mitigate edge halo effects via color decontamination.
        rgb, alpha = np.asarray(img) / 255.0, np.asarray(mask_img) / 255.0
        foreground = cast(np.ndarray[Any, np.dtype[np.uint8]], estimate_foreground_ml(rgb, alpha))
        img = Image.fromarray((foreground * 255).astype("uint8"))
    result = Image.new("RGBA", img.size)
    result.paste(img, (0, 0), mask_img)
    return result
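# Example of the bbox_union semantics above (values are illustrative): the
# union of several boxes is the smallest box enclosing all of them.
#
#     bbox_union([[10, 20, 50, 60], [30, 10, 80, 40]])  # -> (10, 10, 80, 60)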
@spaces.GPU
def _gpu_process(
    img: Image.Image,
    prompt: str | BoundingBox | None,
) -> tuple[Image.Image, BoundingBox | None, list[str]]:
    # Because of ZeroGPU shenanigans, we need a *single* function with the
    # `spaces.GPU` decorator that *does not* contain postprocessing.
    time_log: list[str] = []
    if isinstance(prompt, str):
        t0 = time.time()
        bbox = gd_detect(img, prompt)
        time_log.append(f"detect: {time.time() - t0}")
        if not bbox:
            print(time_log[0])
            raise gr.Error("No object detected")
    else:
        bbox = prompt
    t0 = time.time()
    mask = segmenter(img, bbox)
    time_log.append(f"segment: {time.time() - t0}")
    return mask, bbox, time_log


def process_bbox(prompts: dict[str, Any]) -> tuple[tuple[Image.Image, Image.Image, Image.Image], gr.DownloadButton]:
    assert isinstance(img := prompts["image"], Image.Image)
    assert isinstance(boxes := prompts["boxes"], list)
    if len(boxes) == 1:
        assert isinstance(box := boxes[0], dict)
        bbox = tuple(box[k] for k in ["xmin", "ymin", "xmax", "ymax"])
    else:
        assert len(boxes) == 0
        bbox = None
    return _process(img, bbox)


def on_change_bbox(prompts: dict[str, Any] | None):
    return gr.update(interactive=prompts is not None)


def process_prompt(
    img: Image.Image,
    prompt: str,
    bg_prompt: str | None = None,
) -> tuple[tuple[Image.Image, Image.Image, Image.Image], gr.DownloadButton]:
    return _process(img, prompt, bg_prompt)


def on_change_prompt(img: Image.Image | None, prompt: str | None):
    return gr.update(interactive=bool(img and prompt))
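# Illustrative wiring sketch for the handlers above (component names are
# hypothetical; the actual layout is defined in the Blocks below):
#
#     prompt_btn.click(
#         fn=process_prompt,
#         inputs=[image_input, prompt_box, bg_prompt_box],
#         outputs=[output_slider, download_btn],
#     )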
# Style definitions for the Blocks layout.
css = """
footer {visibility: hidden}
.container {max-width: 1200px; margin: auto; padding: 20px;}
.main-title {text-align: center; color: #2a2a2a; margin-bottom: 2em;}
.tabs {background: #f7f7f7; border-radius: 15px; padding: 20px;}
.input-column {background: white; padding: 20px; border-radius: 10px; box-shadow: 0 2px 6px rgba(0,0,0,0.1);}
.output-column {background: white; padding: 20px; border-radius: 10px; box-shadow: 0 2px 6px rgba(0,0,0,0.1);}
.custom-button {background: #2196F3; color: white; border: none; border-radius: 5px; padding: 10px 20px;}
.custom-button:hover {background: #1976D2;}
.example-region {margin-top: 2em; padding: 20px; background: #f0f0f0; border-radius: 10px;}
"""

with gr.Blocks(theme=gr.themes.Soft(), css=css) as demo:
    gr.HTML(
        """
        <div class="main-title">
            <h1>🎨 Advanced Image Object Extractor</h1>
            <p>Extract objects from images using text prompts or bounding boxes</p>
        </div>
        """
    )