ViTMatte / app.py
hysts's picture
hysts HF Staff
Fix type annotation
5bdb8db
#!/usr/bin/env python
import os
import gradio as gr
import numpy as np
import PIL.Image
import spaces
import torch
from transformers import VitMatteForImageMatting, VitMatteImageProcessor
# Markdown rendered at the top of the page by gr.Markdown. Quoted names such as
# "Draw trimap" refer to UI labels defined below and must match them exactly.
# Blank lines separate paragraphs/sections so Markdown does not merge them.
DESCRIPTION = """\
# [ViTMatte](https://github.com/hustvl/ViTMatte)

This is a demo of [ViTMatte](https://github.com/hustvl/ViTMatte), an image matting method that uses Vision Transformers (ViT) to accurately extract the foreground from an image.
It predicts a soft alpha matte to help separate the subject from the background — even tricky areas like hair and fur!

You've got two ways to get started:

### 🖼️ Option 1: Upload Image & Trimap

- Upload your original image.
- Upload a **trimap**: a helper image that labels regions as **foreground (white)**, **background (black)**, and **unknown (gray)**.
- The trimap must be a **grayscale image** with the same size as your image, containing only three pixel values:
  - `0` for **background**
  - `128` for **unknown**
  - `255` for **foreground**
- The model will use this trimap to generate the alpha matte and extract the foreground.

### ✏️ Option 2: Draw Your Own Trimap

- Upload just your image.
- Go to the **"Draw trimap"** tab to start drawing masks.
- Use the tools to mark:
  - **Foreground** (e.g. the subject),
  - **Unknown** (areas where the boundary is unclear).
- Anything left unmarked is treated as **background**.
- Once you're done, click the **"Generate trimap"** button to generate the trimap from your drawing, then click **"Run"**.

### ✨ Optional: Replace Background

Want to swap the background? Just check the **"Replace background"** option and choose a new background image.
The app will blend your extracted subject with the new background seamlessly!
"""
# Use the first CUDA device when available; otherwise run on the CPU.
device = torch.device("cuda:0") if torch.cuda.is_available() else torch.device("cpu")

# Both settings can be overridden through environment variables.
MAX_IMAGE_SIZE = int(os.environ.get("MAX_IMAGE_SIZE", "1500"))  # longest allowed image side, in pixels
MODEL_ID = os.environ.get("MODEL_ID", "hustvl/vitmatte-small-distinctions-646")

# Load the preprocessor and the matting network once at import time so that
# every request reuses the same instances.
processor = VitMatteImageProcessor.from_pretrained(MODEL_ID)
model = VitMatteForImageMatting.from_pretrained(MODEL_ID)
model = model.to(device)
def resize_input_image(image: PIL.Image.Image | None) -> PIL.Image.Image | None:
    """Downscale an uploaded image so its longer side is at most ``MAX_IMAGE_SIZE`` pixels.

    The aspect ratio is preserved, and the user is notified with ``gr.Info`` when a
    resize happens. Images already within the limit are returned unchanged.

    Args:
        image: Value of the ``gr.Image`` input, or ``None`` when the input was cleared.

    Returns:
        The (possibly resized) image, or ``None`` if ``image`` is ``None``.
    """
    if image is None:
        return None
    w, h = image.size
    longest = max(w, h)
    if longest <= MAX_IMAGE_SIZE:
        return image
    scale = MAX_IMAGE_SIZE / longest
    # round() rather than int(): truncation can turn a product that is a hair below an
    # integer (float error) into one pixel less. max(1, ...) keeps very thin images
    # from collapsing to a zero-sized dimension.
    new_w = max(1, round(w * scale))
    new_h = max(1, round(h * scale))
    gr.Info(
        f"The uploaded image exceeded the maximum resolution limit of {MAX_IMAGE_SIZE}px. It has been resized to {new_w}x{new_h}."
    )
    return image.resize((new_w, new_h))
def binarize_mask(mask: np.ndarray) -> np.ndarray:
    """Return a copy of ``mask`` in which every positive value is set to 1.

    Non-positive values are kept as they are and the dtype is preserved. The input is
    left untouched, so callers can keep using their original array.

    Args:
        mask: Array (or array-like) of mask values.

    Returns:
        A new array of the same shape and dtype as ``mask``.
    """
    binary = np.array(mask, copy=True)
    binary[binary > 0] = 1
    return binary
def update_trimap(foreground_mask_editor: dict, unknown_mask_editor: dict) -> np.ndarray:
foreground = foreground_mask_editor["layers"][0]
foreground = binarize_mask(foreground)
unknown = unknown_mask_editor["layers"][0]
unknown = binarize_mask(unknown)
trimap = np.zeros_like(foreground)
trimap[unknown > 0] = 128
trimap[foreground > 0] = 255
return trimap
def adjust_background_image(background_image: PIL.Image.Image, target_size: tuple[int, int]) -> PIL.Image.Image:
    """Scale ``background_image`` to cover ``target_size``, then center-crop it to that size.

    The image is scaled uniformly (aspect ratio preserved) by the smallest factor that
    makes both sides at least as large as the target. The crop therefore always stays
    inside the scaled image.

    Args:
        background_image: Image to fit.
        target_size: ``(width, height)`` of the result, in pixels.

    Returns:
        A ``target_size`` image taken from the center of the scaled background.
    """
    target_w, target_h = target_size
    bg_w, bg_h = background_image.size
    scale = max(target_w / bg_w, target_h / bg_h)
    # Round, and never go below the target: with int() truncation, float error could
    # leave a side one pixel short (e.g. int(49 * (1 / 49)) == 0), pushing the crop box
    # past the edge of the resized image.
    new_bg_w = max(target_w, round(bg_w * scale))
    new_bg_h = max(target_h, round(bg_h * scale))
    resized = background_image.resize((new_bg_w, new_bg_h))
    left = (new_bg_w - target_w) // 2
    top = (new_bg_h - target_h) // 2
    return resized.crop((left, top, left + target_w, top + target_h))
def replace_background(
    image: PIL.Image.Image, alpha: np.ndarray, background_image: PIL.Image.Image | None
) -> np.ndarray | None:
    """Composite the subject of ``image`` over ``background_image`` using ``alpha``.

    Args:
        image: RGB source image.
        alpha: Per-pixel foreground opacity with the same height/width as ``image``
            (broadcast as ``alpha[:, :, None]``); presumably in ``[0, 1]``.
        background_image: New background. It is converted to RGB, scaled to cover
            ``image`` and center-cropped to its size. ``None`` disables replacement.

    Returns:
        A ``uint8`` RGB array of the blended result, or ``None`` when no background
        image was given.

    Raises:
        gr.Error: If ``image`` is not in RGB mode.
    """
    if background_image is None:
        return None
    if image.mode != "RGB":
        raise gr.Error("Image must be RGB.")
    background = adjust_background_image(background_image.convert("RGB"), image.size)
    fg = np.asarray(image, dtype=np.float64) / 255
    bg = np.asarray(background, dtype=np.float64) / 255
    weight = alpha[:, :, None]
    blended = fg * weight + bg * (1 - weight)
    # Round to nearest (truncation biases every pixel down) and clip so out-of-range
    # values cannot wrap around when cast to uint8.
    return np.clip(np.rint(blended * 255), 0, 255).astype(np.uint8)
@spaces.GPU
@torch.inference_mode()
def run(
    image: PIL.Image.Image,
    trimap: PIL.Image.Image,
    apply_background_replacement: bool,
    background_image: PIL.Image.Image | None,
) -> tuple[np.ndarray, PIL.Image.Image, np.ndarray | None]:
    """Predict an alpha matte for ``image`` guided by ``trimap`` and build the outputs.

    Args:
        image: RGB input image, at most ``MAX_IMAGE_SIZE`` pixels on its longer side.
        trimap: Grayscale ("L") trimap with the same size as ``image``.
        apply_background_replacement: Whether to also composite onto ``background_image``.
        background_image: Optional replacement background.

    Returns:
        A tuple ``(alpha, foreground, replaced)``:
        - ``alpha``: the predicted matte, cropped to the image size;
        - ``foreground``: the subject composited over a white background;
        - ``replaced``: the background-replaced ``uint8`` array, or ``None`` when
          replacement is off or no background image was given.

    Raises:
        gr.Error: If an input is missing, or the inputs fail the size/mode checks.
    """
    # Fail with a readable message instead of an AttributeError on ``None.size``.
    if image is None:
        raise gr.Error("Please upload an input image.")
    if trimap is None:
        raise gr.Error('Please upload a trimap or generate one in the "Draw trimap" tab.')
    if image.size != trimap.size:
        raise gr.Error("Image and trimap must have the same size.")
    if max(image.size) > MAX_IMAGE_SIZE:
        error_message = f"Image size is too large. Max image size is {MAX_IMAGE_SIZE} pixels."
        raise gr.Error(error_message)
    if image.mode != "RGB":
        raise gr.Error("Image must be RGB.")
    if trimap.mode != "L":
        raise gr.Error("Trimap must be grayscale.")

    pixel_values = processor(images=image, trimaps=trimap, return_tensors="pt").to(device).pixel_values
    out = model(pixel_values=pixel_values)
    alpha = out.alphas[0, 0].to("cpu").numpy()
    # The predicted matte can be larger than the input (presumably padding added by
    # the processor), so crop it back to the original height/width.
    w, h = image.size
    alpha = alpha[:h, :w]

    # Foreground over white: rgb * alpha + 1.0 * (1 - alpha), in [0, 1] space.
    weight = alpha[:, :, None]
    foreground = np.asarray(image, dtype=np.float64) / 255 * weight + (1 - weight)
    # Round instead of truncating, and clip so the uint8 cast cannot wrap around.
    foreground = np.clip(np.rint(foreground * 255), 0, 255).astype(np.uint8)
    foreground = PIL.Image.fromarray(foreground)

    res_bg_replacement = replace_background(image, alpha, background_image) if apply_background_replacement else None
    return alpha, foreground, res_bg_replacement
# Assemble the Gradio UI. Layout follows the order and nesting in which components are
# created inside the context managers below: inputs on the left, results on the right.
with gr.Blocks(css_paths="style.css") as demo:
    gr.Markdown(DESCRIPTION)
    with gr.Row():
        # Left column: input image, trimap (uploaded or drawn) and background options.
        with gr.Column():
            with gr.Group():
                image = gr.Image(label="Input image", type="pil")
                with gr.Tabs():
                    # Option 1: upload a ready-made trimap, requested in grayscale ("L") mode,
                    # which is the mode `run` checks for.
                    with gr.Tab(label="Trimap"):
                        trimap = gr.Image(label="Trimap", type="pil", image_mode="L")
                    # Option 2: draw foreground / unknown strokes over the input image; the
                    # "Generate trimap" button turns them into a trimap via `update_trimap`.
                    with gr.Tab(label="Draw trimap"):
                        foreground_mask = gr.ImageEditor(
                            label="Foreground",
                            type="numpy",
                            sources=("upload",),
                            transforms=(),
                            image_mode="L",
                            height=500,
                            brush=gr.Brush(default_color=("#00ff00", 0.6)),
                            # One fixed layer: `update_trimap` reads only value["layers"][0].
                            layers=gr.LayerOptions(allow_additional_layers=False, layers=["Foreground mask"]),
                        )
                        unknown_mask = gr.ImageEditor(
                            label="Unknown",
                            type="numpy",
                            sources=("upload",),
                            transforms=(),
                            image_mode="L",
                            height=500,
                            brush=gr.Brush(default_color=("#00ff00", 0.6)),
                            layers=gr.LayerOptions(allow_additional_layers=False, layers=["Unknown mask"]),
                        )
                        generate_trimap_button = gr.Button("Generate trimap")
                apply_background_replacement = gr.Checkbox(label="Replace background", value=False)
                # Hidden until the checkbox is ticked (see the .change handler below).
                background_image = gr.Image(label="Background image", type="pil", visible=False)
            run_button = gr.Button("Run")
        # Right column: the three values returned by `run`.
        with gr.Column():
            with gr.Group():
                out_alpha = gr.Image(label="Alpha")
                out_foreground = gr.Image(label="Foreground")
                out_background_replacement = gr.Image(label="Background replacement", visible=False)

    # Same order as the parameters of `run` ...
    inputs = [
        image,
        trimap,
        apply_background_replacement,
        background_image,
    ]
    # ... and as the tuple it returns.
    outputs = [
        out_alpha,
        out_foreground,
        out_background_replacement,
    ]
    # Each example row supplies one value per entry of `inputs`.
    gr.Examples(
        examples=[
            ["assets/retriever_rgb.png", "assets/retriever_trimap.png", False, None],
            ["assets/bulb_rgb.png", "assets/bulb_trimap.png", True, "assets/new_bg.jpg"],
        ],
        inputs=inputs,
        outputs=outputs,
        fn=run,
        cache_examples=False,
    )

    # On user upload: downscale oversized images, then copy the (resized) image into both
    # mask editors so the user can draw over it. api_name=False keeps these UI helper
    # events out of the public API.
    image.input(
        fn=resize_input_image,
        inputs=image,
        outputs=image,
        api_name=False,
    ).then(
        fn=lambda image: (image, image),
        inputs=image,
        outputs=[foreground_mask, unknown_mask],
        api_name=False,
    )
    # Convert the drawn masks into a trimap and place it in the "Trimap" input.
    generate_trimap_button.click(
        fn=update_trimap,
        inputs=[foreground_mask, unknown_mask],
        outputs=trimap,
        api_name=False,
    )
    # Show/hide the background input and the replacement output together with the checkbox.
    apply_background_replacement.change(
        fn=lambda checked: (gr.Image(visible=checked), gr.Image(visible=checked)),
        inputs=apply_background_replacement,
        outputs=[background_image, out_background_replacement],
        api_name=False,
    )
    run_button.click(
        fn=run,
        inputs=inputs,
        outputs=outputs,
    )

if __name__ == "__main__":
    demo.launch()