Spaces:
Running
Running
File size: 2,635 Bytes
8ce07a1 b71c9c4 a505bd1 797deb4 b71c9c4 797deb4 b71c9c4 2b13f4e b71c9c4 2b13f4e b71c9c4 2b13f4e 797deb4 a505bd1 b71c9c4 797deb4 a505bd1 3145448 797deb4 a505bd1 797deb4 a505bd1 b71c9c4 797deb4 a505bd1 797deb4 a505bd1 797deb4 a505bd1 797deb4 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 |
import gradio as gr
from gradio_imageslider import ImageSlider
from loadimg import load_img
import spaces
from transformers import AutoModelForImageSegmentation
import torch
from torchvision import transforms
# Run BiRefNet on the CPU. The GPU-specific setup from the original Space
# (moving the model to "cuda" and calling
# torch.set_float32_matmul_precision("high")) is not needed on CPU and was removed.
# NOTE(review): trust_remote_code=True executes Python code shipped with the
# "ZhengPeng7/BiRefNet" repository; consider pinning a revision.
birefnet = AutoModelForImageSegmentation.from_pretrained(
    "ZhengPeng7/BiRefNet", trust_remote_code=True
).to("cpu")  # nn.Module.to() returns the module itself
# Preprocessing applied to every input before it reaches BiRefNet:
#   1. resize to exactly 1024x1024 (aspect ratio is not preserved),
#   2. convert the PIL image to a float tensor scaled to [0, 1],
#   3. normalize each of the three channels with the mean/std values
#      commonly used for ImageNet-pretrained backbones.
transform_image = transforms.Compose([
    transforms.Resize((1024, 1024)),
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
])
def fn(image):
    """Remove the background from an uploaded image or an image URL.

    Args:
        image: The raw value from the tab's input component. On the upload tab
            this is the ``gr.Image`` value, presumably a numpy array with
            Gradio's default ``type``. On the URL tab it is the textbox
            string. It is passed unchanged to ``load_img``.

    Returns:
        A ``(processed, original)`` pair for the ImageSlider output. The first
        item is the RGBA cut-out; the second is an unmodified RGB copy of the
        input.

    Raises:
        gr.Error: If no image was supplied or the URL textbox is blank.
    """
    # Fail early with a readable message instead of an opaque error from
    # load_img. Check for None / blank string explicitly: `image` may be a
    # numpy array, whose truth value is ambiguous.
    if image is None or (isinstance(image, str) and not image.strip()):
        raise gr.Error("Please provide an image or an image URL.")

    im = load_img(image, output_type="pil").convert("RGB")
    # process() writes the alpha channel into its argument in place,
    # so keep an untouched copy to show alongside the result.
    origin = im.copy()
    processed_image = process(im)
    return (processed_image, origin)
# The `@spaces.GPU` decorator from the original Space was removed so that this
# function runs on a CPU-only host.
def process(image):
    """Cut out the foreground of ``image`` with BiRefNet.

    The predicted mask is attached to ``image`` as its alpha channel *in
    place*, and that same object is returned. Callers that need the original
    pixels must copy the image first.

    Args:
        image: A PIL image. The callers in this file always pass RGB images.
            ``transform_image`` normalizes with three channel statistics, so
            other modes presumably fail. TODO confirm.

    Returns:
        The same PIL image object, now carrying the mask as its alpha channel.
    """
    width_height = image.size  # PIL size, used to scale the mask back up
    batch = transform_image(image).unsqueeze(0)  # add a leading batch dim of 1
    batch = batch.to("cpu")

    with torch.no_grad():
        # Take the final element of the model's output and squash it into
        # [0, 1] with a sigmoid to get a per-pixel foreground probability.
        final_output = birefnet(batch)[-1]
        probabilities = final_output.sigmoid().cpu()

    # First (only) item of the batch -> grayscale PIL mask at the input's size.
    mask_tensor = probabilities[0].squeeze()
    alpha = transforms.ToPILImage()(mask_tensor).resize(width_height)

    image.putalpha(alpha)
    return image
def process_file(f):
    """Remove the background from an image file and save the result as PNG.

    The output is written next to the input, with the same name and a
    ``.png`` suffix. An input that already ends in ``.png`` is therefore
    overwritten by the result.

    Args:
        f: Path to the input image, i.e. the value of a ``gr.Image`` with
            ``type="filepath"``.

    Returns:
        The output PNG path as a string.

    Raises:
        gr.Error: If no file was provided.
    """
    from pathlib import Path  # local import keeps the file's import block untouched

    if not f:
        raise gr.Error("Please upload an image.")

    # Path.with_suffix only touches the final path component. The previous
    # `f.rsplit(".", 1)[0]` cut at a dot in a parent directory whenever the
    # file name itself had no extension (e.g. "/tmp/a.b/img" -> "/tmp/a.png").
    name_path = str(Path(f).with_suffix(".png"))

    im = load_img(f, output_type="pil").convert("RGB")
    transparent = process(im)  # RGB + predicted alpha
    transparent.save(name_path)
    return name_path
# Output components: a before/after ImageSlider for each tab that returns
# images (slider1 -> upload tab, slider2 -> URL tab), both carrying PIL images.
slider1 = ImageSlider(label="Processed Image", type="pil")
slider2 = ImageSlider(label="Processed Image from URL", type="pil")
# Input components. image_upload uses gr.Image's default `type`, while
# image_file_upload hands its handler a file path on disk.
image_upload = gr.Image(label="Upload an image")
image_file_upload = gr.Image(label="Upload an image", type="filepath")
url_input = gr.Textbox(label="Paste an image URL")
# Download link for the PNG written by the file-output tab.
output_file = gr.File(label="Output PNG File")
# Example images
# NOTE(review): `chameleon` actually holds butterfly.jpg; consider renaming.
# The file is read at import time from a path relative to the current working
# directory, so startup presumably fails if it is missing.
chameleon = load_img("butterfly.jpg", output_type="pil")
url_example = "https://hips.hearstapps.com/hmg-prod/images/gettyimages-1229892983-square.jpg"
# One Interface per tab. The upload and URL tabs share `fn`; the file tab uses
# `process_file`. Each api_name sets the name of that tab's API endpoint.
tab1 = gr.Interface(fn, inputs=image_upload, outputs=slider1, examples=[chameleon], api_name="image")
tab2 = gr.Interface(fn, inputs=url_input, outputs=slider2, examples=[url_example], api_name="text")
tab3 = gr.Interface(process_file, inputs=image_file_upload, outputs=output_file, examples=["butterfly.jpg"], api_name="png")
# Combine the three interfaces into one app with a tab per input mode.
demo = gr.TabbedInterface(
    interface_list=[tab1, tab2, tab3],
    tab_names=["Image Upload", "URL Input", "File Output"],
    title="Background Removal Tool",
)
if __name__ == "__main__":
    # show_error=True makes exceptions raised by the handlers visible in the UI
    # instead of only in the server log.
    demo.launch(show_error=True)