File size: 2,552 Bytes
5de8b22
a3f48ee
8939ec2
a3f48ee
 
 
 
 
 
 
 
 
9281027
a3f48ee
 
 
 
 
 
 
 
 
 
e4a0378
 
 
c2df784
8939ec2
 
 
 
 
e4a0378
8939ec2
 
 
a3f48ee
 
 
 
 
 
 
 
 
 
8939ec2
 
 
 
e4a0378
8939ec2
a3f48ee
8939ec2
 
 
2ccc78e
8939ec2
 
2ccc78e
e4a0378
8939ec2
 
 
e4a0378
8939ec2
 
 
a973e8e
a3f48ee
8939ec2
a3f48ee
 
8939ec2
 
 
e4a0378
a3f48ee
8939ec2
a3f48ee
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
import os
import gradio as gr
from gradio_imageslider import ImageSlider
from loadimg import load_img
import spaces
from transformers import AutoModelForImageSegmentation
import torch
from torchvision import transforms

# Select the "high" float32 matmul precision setting for torch.
torch.set_float32_matmul_precision("high")

# Load the RMBG-2.0 segmentation model and place it on the GPU.
# trust_remote_code=True lets transformers run the model code shipped in the
# "briaai/RMBG-2.0" repository.
birefnet = AutoModelForImageSegmentation.from_pretrained(
    "briaai/RMBG-2.0",
    trust_remote_code=True,
).to("cuda")

# Preprocessing applied to every PIL image before it is fed to the model:
# fixed 1024x1024 resize, conversion to a tensor, per-channel normalisation.
# NOTE(review): the mean/std values look like the usual ImageNet statistics.
transform_image = transforms.Compose(
    [
        transforms.Resize(size=(1024, 1024)),
        transforms.ToTensor(),
        transforms.Normalize(
            mean=[0.485, 0.456, 0.406],
            std=[0.229, 0.224, 0.225],
        ),
    ]
)

# Directory that receives the PNG written by fn().
output_folder = 'output_images'
# exist_ok=True creates the directory if needed without a separate existence
# check, so there is no window in which another process can create it between
# the check and the call.
os.makedirs(output_folder, exist_ok=True)

def fn(image):
    """Remove the background from an image and save the result as a PNG.

    Args:
        image: The value delivered by the Gradio input component (an uploaded
            image in the "input image" tab, a URL string in the "input url"
            tab). It is handed to ``load_img`` to obtain a PIL image.

    Returns:
        A pair ``((result, origin), image_path)`` where ``result`` is the
        image with the predicted alpha channel, ``origin`` is the untouched
        RGB copy (both feed the image slider), and ``image_path`` is the path
        of the PNG written inside ``output_folder``.
    """
    # Local import: only needed here to build a unique file name.
    import uuid

    im = load_img(image, output_type="pil")
    im = im.convert("RGB")
    # Keep an unmodified copy: process() adds the alpha channel in place.
    origin = im.copy()
    image = process(im)
    # Each call gets its own file. A single fixed name would let concurrent
    # requests overwrite each other's result before it is served.
    # NOTE(review): files accumulate in output_folder; add cleanup if needed.
    image_path = os.path.join(
        output_folder, f"no_bg_image_{uuid.uuid4().hex}.png"
    )
    image.save(image_path)
    return (image, origin), image_path

@spaces.GPU
def process(image):
    """Predict a foreground mask for ``image`` and attach it as alpha.

    The image is run through ``transform_image``, pushed through ``birefnet``
    on the GPU, and the last model output (after a sigmoid) is turned into a
    PIL mask resized back to the input's size.

    Args:
        image: A PIL image; callers in this file pass it in RGB mode.

    Returns:
        The same ``image`` object, modified in place by ``putalpha``.
    """
    original_size = image.size

    # Build a single-item batch on the GPU.
    batch = transform_image(image)
    batch = batch.unsqueeze(0).to("cuda")

    # Inference only: no gradients are needed.
    with torch.no_grad():
        outputs = birefnet(batch)
        probabilities = outputs[-1].sigmoid().cpu()

    # Take the first (only) item of the batch and drop singleton dimensions.
    matte = probabilities[0].squeeze()
    to_pil = transforms.ToPILImage()
    alpha = to_pil(matte).resize(original_size)

    image.putalpha(alpha)
    return image
  
def process_file(f):
    """Remove the background from an image file and write a PNG next to it.

    Args:
        f: Filesystem path of the input image (the interface that uses this
            function passes a ``gr.Image`` value with ``type="filepath"``).

    Returns:
        The path of the written PNG: the input path with its extension
        replaced by ``.png``.
    """
    # splitext only strips a real file extension. Splitting on the last "."
    # of the whole path would cut inside a directory name such as
    # "/tmp/cache.d/upload" when the file itself has no extension.
    name_path = os.path.splitext(f)[0] + ".png"
    im = load_img(f, output_type="pil")
    im = im.convert("RGB")
    transparent = process(im)
    transparent.save(name_path)
    return name_path

# Output components: before/after sliders, one per tab.
slider1 = ImageSlider(type="pil", label="RMBG-2.0")
slider2 = ImageSlider(type="pil", label="RMBG-2.0")

# Input components.
image = gr.Image(label="Upload an image")
image2 = gr.Image(type="filepath", label="Upload an image")
text = gr.Textbox(label="Paste an image URL")
png_file = gr.File(label="output png file")

# Example inputs offered in the UI.
chameleon = load_img("giraffe.jpg", output_type="pil")
url = "http://farm9.staticflickr.com/8488/8228323072_76eeddfea3_z.jpg"

# Tab fed by an uploaded image.
tab1 = gr.Interface(
    fn=fn,
    inputs=image,
    outputs=[slider1, gr.File(label="output png file")],
    examples=[chameleon],
    api_name="image",
)

# Tab fed by an image URL.
tab2 = gr.Interface(
    fn=fn,
    inputs=text,
    outputs=[slider2, gr.File(label="output png file")],
    examples=[url],
    api_name="text",
)

# File-in / PNG-out interface.
# NOTE(review): tab3 is constructed but not included in `demo` below;
# confirm whether it is meant to be exposed.
tab3 = gr.Interface(
    fn=process_file,
    inputs=image2,
    outputs=png_file,
    examples=["giraffe.jpg"],
    api_name="png",
)

demo = gr.TabbedInterface(
    interface_list=[tab1, tab2],
    tab_names=["input image", "input url"],
    title="RMBG-2.0 for background removal",
)

if __name__ == "__main__":
    demo.launch(show_error=True)