File size: 2,635 Bytes
8ce07a1
 
 
 
 
 
 
 
b71c9c4
 
 
 
 
 
 
 
 
a505bd1
 
 
 
 
 
 
 
797deb4
 
 
 
b71c9c4
797deb4
 
b71c9c4
 
 
 
2b13f4e
b71c9c4
 
2b13f4e
b71c9c4
2b13f4e
 
 
 
 
 
797deb4
 
 
a505bd1
b71c9c4
797deb4
 
a505bd1
3145448
797deb4
a505bd1
797deb4
 
 
a505bd1
b71c9c4
797deb4
 
a505bd1
797deb4
 
 
a505bd1
797deb4
 
a505bd1
 
 
797deb4
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
import gradio as gr
from gradio_imageslider import ImageSlider
from loadimg import load_img
import spaces
from transformers import AutoModelForImageSegmentation
import torch
from torchvision import transforms

# GPU configuration switched to CPU:
# GPU-specific settings were removed, or "cuda" was changed to "cpu".
# torch.set_float32_matmul_precision("high") is not needed on CPU.

# Load the BiRefNet segmentation model from the Hugging Face Hub repo
# "ZhengPeng7/BiRefNet"; trust_remote_code=True lets the repo's own model
# code be used.
birefnet = AutoModelForImageSegmentation.from_pretrained(
    "ZhengPeng7/BiRefNet", trust_remote_code=True
)
birefnet.to("cpu")  # run the model on CPU instead of GPU

# Preprocessing applied to every input image before inference:
# resize to a fixed 1024x1024, convert to a float tensor, then normalize each
# of the three channels with the standard ImageNet mean/std values.
transform_image = transforms.Compose(
    [
        transforms.Resize((1024, 1024)),
        transforms.ToTensor(),
        transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225]),
    ]
)

def fn(image):
    """Remove the background from *image* and pair the result with the original.

    Args:
        image: Any value ``load_img`` accepts; in this app it is either the
            value of an image upload (tab 1) or a URL string (tab 2).

    Returns:
        A ``(processed, original)`` tuple for the ImageSlider: the cutout
        produced by ``process`` first, then an untouched RGB copy.
    """
    rgb = load_img(image, output_type="pil").convert("RGB")
    # process() attaches the alpha mask to its argument in place, so take the
    # snapshot of the original before calling it.
    untouched = rgb.copy()
    cutout = process(rgb)
    return cutout, untouched

# The @spaces.GPU decorator was removed so this function runs in a CPU-only
# environment.

def process(image):
    """Attach a BiRefNet foreground mask to *image* as its alpha channel.

    The image is preprocessed with ``transform_image``, run through
    ``birefnet`` on the CPU, and the sigmoid of the model's last output is
    converted to a PIL image and resized back to the input's original size.
    That mask is applied with ``putalpha``, which modifies ``image`` in place.

    Args:
        image: A PIL image. Callers in this file convert it to RGB first,
            which matches the three-channel normalization in
            ``transform_image``.

    Returns:
        The same ``image`` object, now carrying the predicted alpha mask.
    """
    original_size = image.size
    # Add a batch dimension; everything stays on the CPU.
    batch = transform_image(image).unsqueeze(0).to("cpu")

    # Inference only: no autograd graph is needed.
    with torch.no_grad():
        outputs = birefnet(batch)
        # Only the last element of the model's output sequence is used.
        probabilities = outputs[-1].sigmoid().cpu()

    # First item of the batch, with singleton dimensions dropped.
    mask_tensor = probabilities[0].squeeze()
    to_pil = transforms.ToPILImage()
    alpha = to_pil(mask_tensor).resize(original_size)

    image.putalpha(alpha)
    return image

def process_file(f):
    """Remove the background from an image file and save the result as a PNG.

    Args:
        f: Filesystem path of the input image.

    Returns:
        Path of the written PNG: ``f`` with its extension replaced by
        ``.png`` (or with ``.png`` appended when it has no extension). If the
        input is already a ``.png`` file, it is overwritten.
    """
    import os

    # os.path.splitext only strips the extension of the final path component,
    # so dots in directory names (e.g. "/tmp/a.b/image") are preserved; the
    # previous f.rsplit(".", 1) would have truncated such paths.
    name_path = os.path.splitext(f)[0] + ".png"
    im = load_img(f, output_type="pil")
    im = im.convert("RGB")
    transparent = process(im)
    transparent.save(name_path)
    return name_path

# UI components: before/after sliders for tabs 1 and 2, inputs for each tab,
# and a file output for tab 3.
slider1 = ImageSlider(label="Processed Image", type="pil")
slider2 = ImageSlider(label="Processed Image from URL", type="pil")
image_upload = gr.Image(label="Upload an image")
image_file_upload = gr.Image(label="Upload an image", type="filepath")
url_input = gr.Textbox(label="Paste an image URL")
output_file = gr.File(label="Output PNG File")

# Example images
# NOTE(review): this variable is named `chameleon` but loads butterfly.jpg;
# consider renaming it for clarity.
chameleon = load_img("butterfly.jpg", output_type="pil")
url_example = "https://hips.hearstapps.com/hmg-prod/images/gettyimages-1229892983-square.jpg"

# Tab 1: uploaded image -> (cutout, original) slider via fn.
# Tab 2: URL text -> slider via fn; the URL string is passed straight to
#        load_img (presumably it fetches the image — verify).
# Tab 3: uploaded image as a file path -> path of the saved PNG via process_file.
tab1 = gr.Interface(fn, inputs=image_upload, outputs=slider1, examples=[chameleon], api_name="image")
tab2 = gr.Interface(fn, inputs=url_input, outputs=slider2, examples=[url_example], api_name="text")
tab3 = gr.Interface(process_file, inputs=image_file_upload, outputs=output_file, examples=["butterfly.jpg"], api_name="png")

# Combine the three interfaces into a single tabbed app.
demo = gr.TabbedInterface(
    [tab1, tab2, tab3], ["Image Upload", "URL Input", "File Output"], title="Background Removal Tool"
)

# Launch only when run as a script; show_error=True asks Gradio to report
# exceptions in the UI.
if __name__ == "__main__":
    demo.launch(show_error=True)