# Spaces: Running on Zero
# File size: 2,552 Bytes
import os
import gradio as gr
from gradio_imageslider import ImageSlider
from loadimg import load_img
import spaces
from transformers import AutoModelForImageSegmentation
import torch
from torchvision import transforms
# "high" allows float32 matrix multiplications to use faster, reduced-precision
# kernels where the hardware offers them; use "highest" for full float32.
torch.set_float32_matmul_precision("high")

# Load the segmentation model once, at import time, and move it to the GPU.
# trust_remote_code=True lets transformers run the modelling code that is
# shipped in the "briaai/RMBG-2.0" repository.
birefnet = AutoModelForImageSegmentation.from_pretrained(
    "briaai/RMBG-2.0", trust_remote_code=True
)
birefnet.to("cuda")

# Preprocessing applied to every image before inference: resize to 1024x1024,
# convert to a tensor, then normalise the three channels with the widely used
# ImageNet mean / standard-deviation values.
transform_image = transforms.Compose(
    [
        transforms.Resize((1024, 1024)),
        transforms.ToTensor(),
        transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225]),
    ]
)

# Directory that fn() writes its PNG result into. exist_ok=True makes the
# creation a single call with no check-then-create race.
output_folder = "output_images"
os.makedirs(output_folder, exist_ok=True)
def fn(image):
    """Remove the background from an image and save the result as a PNG.

    Args:
        image: Any input accepted by ``load_img``. In this app it is wired to
            an uploaded image in one tab and to a pasted URL in another.

    Returns:
        A pair ``((cutout, original), png_path)``. ``cutout`` is the image
        returned by ``process`` and ``original`` is an unmodified RGB copy of
        the input (together they feed the before/after slider). ``png_path``
        is the path of the saved PNG inside ``output_folder``.
    """
    source = load_img(image, output_type="pil").convert("RGB")

    # Copy before processing: process() calls putalpha() on the image it is
    # given, so the untouched version has to be captured first.
    untouched = source.copy()

    cutout = process(source)

    png_path = os.path.join(output_folder, "no_bg_image.png")
    cutout.save(png_path)

    return (cutout, untouched), png_path
@spaces.GPU
def process(image):
    """Predict a foreground mask for ``image`` and attach it as alpha.

    NOTE(review): the ``@spaces.GPU`` decorator presumably requests a GPU for
    the duration of the call on Hugging Face Spaces; confirm against the
    ``spaces`` package documentation.

    Args:
        image: A PIL image. Callers in this file convert it to RGB first,
            which matches the three-channel normalisation in
            ``transform_image``.

    Returns:
        The same image object, after ``putalpha`` has been called on it with
        the predicted mask resized to the image's original size.
    """
    original_size = image.size

    # Add a leading batch dimension and move the tensor to the GPU.
    batch = transform_image(image).unsqueeze(0).to("cuda")

    # Inference only, so gradient tracking is disabled.
    with torch.no_grad():
        outputs = birefnet(batch)
        # Sigmoid is applied to the last element of the model output; the
        # result is moved back to the CPU for PIL conversion.
        probabilities = outputs[-1].sigmoid().cpu()

    single_mask = probabilities[0].squeeze()
    # The mask is produced at the network's input resolution; scale it back
    # to the size of the original image before using it as alpha.
    alpha = transforms.ToPILImage()(single_mask).resize(original_size)

    image.putalpha(alpha)
    return image
def process_file(f):
    """Remove the background from an image file and save it as a PNG beside it.

    The output path is the input path with its extension replaced by
    ``.png``. If the input already has a ``.png`` extension, the output path
    equals the input path and the file is overwritten.

    Args:
        f: Path of the image file to process (a string).

    Returns:
        The path of the PNG file that was written.
    """
    # splitext only removes an extension from the final path component.
    # Splitting the whole string on its last "." would truncate the path when
    # a directory name contains a dot and the file itself has no extension
    # (e.g. "/tmp/run.1/photo" would become "/tmp/run.png").
    name_path = os.path.splitext(f)[0] + ".png"

    im = load_img(f, output_type="pil").convert("RGB")
    transparent = process(im)
    transparent.save(name_path)
    return name_path
# Before/after sliders, one per tab that uses fn().
slider1 = ImageSlider(label="RMBG-2.0", type="pil")
slider2 = ImageSlider(label="RMBG-2.0", type="pil")

# Input and output components.
image = gr.Image(label="Upload an image")
image2 = gr.Image(label="Upload an image", type="filepath")
text = gr.Textbox(label="Paste an image URL")
png_file = gr.File(label="output png file")

# Example inputs shown in the UI. Note that the variable is called
# `chameleon` although the file it loads is "giraffe.jpg".
chameleon = load_img("giraffe.jpg", output_type="pil")
url = "http://farm9.staticflickr.com/8488/8228323072_76eeddfea3_z.jpg"

# Tab 1: uploaded image -> slider + downloadable PNG.
tab1 = gr.Interface(
    fn,
    inputs=image,
    outputs=[slider1, gr.File(label="output png file")],
    examples=[chameleon],
    api_name="image",
)

# Tab 2: image URL -> slider + downloadable PNG.
tab2 = gr.Interface(
    fn,
    inputs=text,
    outputs=[slider2, gr.File(label="output png file")],
    examples=[url],
    api_name="text",
)

# File path in -> PNG file out.
tab3 = gr.Interface(
    process_file,
    inputs=image2,
    outputs=png_file,
    examples=["giraffe.jpg"],
    api_name="png",
)

# NOTE(review): tab3 is constructed above but is not part of the tab list
# below, so only the first two interfaces are shown. Confirm whether that
# omission is intentional before adding or removing it.
demo = gr.TabbedInterface(
    [tab1, tab2],
    ["input image", "input url"],
    title="RMBG-2.0 for background removal",
)

if __name__ == "__main__":
    demo.launch(show_error=True)