|
import gradio as gr |
|
import spaces |
|
import torch |
|
from loadimg import load_img |
|
from torchvision import transforms |
|
from transformers import AutoModelForImageSegmentation |
|
from diffusers import FluxFillPipeline |
|
from PIL import Image, ImageOps |
|
|
|
# Allow float32 matrix multiplications to use faster reduced-precision
# internal kernels; per the PyTorch docs, "high" permits TF32 (or
# equivalent) where the hardware supports it.
torch.set_float32_matmul_precision("high")
|
|
|
# Background-removal segmentation model. It is loaded once at import time and
# kept on the GPU for every request. ``trust_remote_code=True`` is required
# because the model code is fetched from the Hub repository.
birefnet = AutoModelForImageSegmentation.from_pretrained(
    "ZhengPeng7/BiRefNet",
    trust_remote_code=True,
).to("cuda")  # nn.Module.to moves in place and returns the same module
|
|
|
# Preprocessing applied to every image before it is fed to ``birefnet``:
# a fixed-size resize, conversion to a float tensor, then per-channel
# normalization (these mean/std values are the standard ImageNet statistics).
_RMBG_INPUT_SIZE = (1024, 1024)
_RMBG_MEAN = [0.485, 0.456, 0.406]
_RMBG_STD = [0.229, 0.224, 0.225]

transform_image = transforms.Compose(
    [
        transforms.Resize(_RMBG_INPUT_SIZE),
        transforms.ToTensor(),
        transforms.Normalize(_RMBG_MEAN, _RMBG_STD),
    ]
)
|
|
|
# FLUX.1 Fill pipeline used for outpainting. It is loaded in bfloat16 at
# import time and then moved to the GPU (DiffusionPipeline.to returns the
# pipeline itself).
pipe = FluxFillPipeline.from_pretrained(
    "black-forest-labs/FLUX.1-Fill-dev",
    torch_dtype=torch.bfloat16,
)
pipe = pipe.to("cuda")
|
|
|
|
|
def _coerce_padding(value, name):
    """Return ``value`` as a non-negative integer pixel count (``None`` -> 0)."""
    if value is None:
        return 0
    # UI number widgets may deliver floats such as 64.0; PIL needs ints.
    pixels = int(round(value))
    if pixels < 0:
        raise ValueError(f"{name} must be non-negative, got {value!r}")
    return pixels


def prepare_image_and_mask(
    image,
    padding_top=0,
    padding_bottom=0,
    padding_left=0,
    padding_right=0,
):
    """Pad ``image`` with a white border and build the matching outpaint mask.

    Args:
        image: Anything ``load_img`` accepts (e.g. the Gradio image input).
        padding_top, padding_bottom, padding_left, padding_right: Border
            sizes in pixels. ``None`` is treated as 0 and non-integer numbers
            are rounded to the nearest integer.

    Returns:
        ``(background, mask)``. ``background`` is the RGB image with the white
        border added. ``mask`` is an RGB image of the same size: black over
        the original pixels and white over the added border.

    Raises:
        ValueError: if any padding is negative.
    """
    # ImageOps.expand takes the border as (left, top, right, bottom).
    border = (
        _coerce_padding(padding_left, "padding_left"),
        _coerce_padding(padding_top, "padding_top"),
        _coerce_padding(padding_right, "padding_right"),
        _coerce_padding(padding_bottom, "padding_bottom"),
    )

    image = load_img(image).convert("RGB")

    background = ImageOps.expand(image, border=border, fill="white")
    # Same geometry as ``background``: the original area stays black and only
    # the new border is white.
    mask = ImageOps.expand(
        Image.new("RGB", image.size, "black"),
        border=border,
        fill="white",
    )
    return background, mask
|
|
|
|
|
def inpaint(
    image,
    padding_top=0,
    padding_bottom=0,
    padding_left=0,
    padding_right=0,
    prompt="",
    num_inference_steps=28,
    guidance_scale=30,
):
    """Outpaint ``image`` by generating content for the requested padding.

    Args:
        image: Source image (anything ``prepare_image_and_mask`` accepts).
        padding_top, padding_bottom, padding_left, padding_right: Pixels of
            new content to generate on each side.
        prompt: Text prompt passed to the FLUX Fill pipeline.
        num_inference_steps: Denoising steps for the pipeline (default 28).
        guidance_scale: Guidance scale for the pipeline (default 30).

    Returns:
        An RGBA PIL image the same size as the padded canvas.
    """
    background, mask = prepare_image_and_mask(
        image, padding_top, padding_bottom, padding_left, padding_right
    )

    result = pipe(
        prompt=prompt,
        height=background.height,
        width=background.width,
        image=background,
        mask_image=mask,
        num_inference_steps=num_inference_steps,
        guidance_scale=guidance_scale,
    ).images[0]

    # The Flux pipelines round height/width down to a multiple of 16 (with
    # only a logged warning). Map the output back onto the requested canvas
    # so the caller gets exactly the padded size.
    if result.size != background.size:
        result = result.resize(background.size, Image.LANCZOS)

    return result.convert("RGBA")
|
|
|
|
|
def rmbg(image, url):
    """Remove the background from an uploaded image or an image URL.

    Args:
        image: Uploaded image from the Gradio image input, or ``None``.
        url: Fallback source (anything ``load_img`` accepts), used only when
            ``image`` is ``None``.

    Returns:
        The input converted to RGB, at its original resolution, with
        BiRefNet's sigmoid prediction attached as the alpha channel.

    Raises:
        gr.Error: if neither an image nor a non-empty URL was provided.
    """
    # Use `is None`, not truthiness: the upload may be a numpy array, whose
    # truth value is ambiguous.
    source = image if image is not None else url
    if isinstance(source, str):
        source = source.strip()
    if source is None or source == "":
        raise gr.Error("Please upload an image or provide an image URL.")

    image = load_img(source).convert("RGB")
    original_size = image.size
    batch = transform_image(image).unsqueeze(0).to("cuda")

    with torch.no_grad():
        # Takes the model's last output as the prediction (assumed to be the
        # final/finest map, TODO confirm against BiRefNet).
        preds = birefnet(batch)[-1].sigmoid().cpu()
    # First item of the single-image batch; squeeze drops singleton dims so
    # ToPILImage yields a one-channel image (presumably (H, W), to confirm).
    pred = preds[0].squeeze()
    mask = transforms.ToPILImage()(pred).resize(original_size)
    image.putalpha(mask)
    return image
|
|
|
|
|
@spaces.GPU
def main(*args, progress=gr.Progress(track_tqdm=True)):
    """Shared GPU entry point for both tabs, dispatching on argument count.

    Two positional arguments, ``(image, url)``, are routed to ``rmbg``. Any
    other count (the outpainting tab's image, four paddings and prompt) is
    routed to ``inpaint``. ``progress`` is not read here. Its default asks
    Gradio to mirror tqdm progress bars in the UI.
    """
    handler = rmbg if len(args) == 2 else inpaint
    return handler(*args)
|
|
|
|
|
# Background-removal tab: an uploaded image plus an optional URL fallback.
rmbg_tab = gr.Interface(
    fn=main, inputs=["image", "text"], outputs=["image"], api_name="rmbg"
)

# Outpainting tab. ``precision=0`` makes each gr.Number deliver an int (its
# default precision returns floats, which PIL rejects as image sizes).
# ``value=0`` starts each padding at 0 instead of an empty field that would
# be sent as None.
outpaint_tab = gr.Interface(
    fn=main,
    inputs=[
        "image",
        gr.Number(label="padding top", value=0, precision=0),
        gr.Number(label="padding bottom", value=0, precision=0),
        gr.Number(label="padding left", value=0, precision=0),
        gr.Number(label="padding right", value=0, precision=0),
        gr.Text(label="prompt"),
    ],
    outputs=["image"],
    api_name="outpainting",
)

demo = gr.TabbedInterface(
    [rmbg_tab, outpaint_tab],
    ["remove background", "outpainting"],
    title="Utilities that require GPU",
)

demo.launch()
|
|