arxivgpt kim
committed on
Update app.py
app.py CHANGED
@@ -4,103 +4,116 @@ import torch.nn.functional as F
Before:

 from torchvision.transforms.functional import normalize
 from huggingface_hub import hf_hub_download
 import gradio as gr
-from gradio_imageslider import ImageSlider
 from briarmbg import BriaRMBG
 import PIL
 from PIL import Image
-from typing import Tuple

-net=BriaRMBG()
-# model_path = "./model1.pth"
 model_path = hf_hub_download("briaai/RMBG-1.4", 'model.pth')
 if torch.cuda.is_available():
     net.load_state_dict(torch.load(model_path))
-    net=net.cuda()
 else:
-    net.load_state_dict(torch.load(model_path,map_location="cpu"))
-net.eval()

 def resize_image(image):
     image = image.convert('RGB')
     model_input_size = (1024, 1024)
     image = image.resize(model_input_size, Image.BILINEAR)
     return image

 def process(image):
-    # prepare input
     orig_image = Image.fromarray(image)
-    w,h =
     image = resize_image(orig_image)
     im_np = np.array(image)
-    im_tensor = torch.tensor(im_np, dtype=torch.float32).permute(2,0,1)
-    im_tensor = torch.unsqueeze(im_tensor,0)
-    im_tensor = torch.divide(im_tensor,255.0)
-    im_tensor = normalize(im_tensor,[0.5,0.5,0.5],[1.0,1.0,1.0])
     if torch.cuda.is_available():
-        im_tensor=im_tensor.cuda()

-    result=
-    # post process
-    result = torch.squeeze(F.interpolate(result[0][0], size=(h,w), mode='bilinear') ,0)
     ma = torch.max(result)
     mi = torch.min(result)
-    result = (result-mi)/(ma-mi)
-    im_array = (result*255).cpu().data.numpy().astype(np.uint8)
     pil_im = Image.fromarray(np.squeeze(im_array))
-    new_im = Image.new("RGBA", pil_im.size, (0,0,0,0))
     new_im.paste(orig_image, mask=pil_im)
-    # new_orig_image = orig_image.convert('RGBA')

     return new_im
-    # return [new_orig_image, new_im]

-# block = gr.Blocks().queue()
- … (removed lines truncated in the diff view)

-    gr.Markdown("## BRIA RMBG 1.4")
-    gr.HTML('''
-      <p style="margin-bottom: 10px; font-size: 94%">
-        This is a demo for BRIA RMBG 1.4 that using
-        <a href="https://huggingface.co/briaai/RMBG-1.4" target="_blank">BRIA RMBG-1.4 image matting model</a> as backbone.
-      </p>
-    ''')
 title = "Background Removal"
-description =
 """
- … (truncated)

 if __name__ == "__main__":
-    demo.launch(share=False)
After:

 from torchvision.transforms.functional import normalize
 from huggingface_hub import hf_hub_download
 import gradio as gr
 from briarmbg import BriaRMBG
 import PIL
 from PIL import Image

+net = BriaRMBG()
 model_path = hf_hub_download("briaai/RMBG-1.4", 'model.pth')
 if torch.cuda.is_available():
     net.load_state_dict(torch.load(model_path))
+    net = net.cuda()
 else:
+    net.load_state_dict(torch.load(model_path, map_location="cpu"))
+net.eval()

 def resize_image(image):
     image = image.convert('RGB')
     model_input_size = (1024, 1024)
     image = image.resize(model_input_size, Image.BILINEAR)
     return image

 def process(image):
     orig_image = Image.fromarray(image)
+    w, h = orig_image.size
     image = resize_image(orig_image)
     im_np = np.array(image)
+    im_tensor = torch.tensor(im_np, dtype=torch.float32).permute(2, 0, 1)
+    im_tensor = torch.unsqueeze(im_tensor, 0)
+    im_tensor = torch.divide(im_tensor, 255.0)
+    im_tensor = normalize(im_tensor, [0.5, 0.5, 0.5], [1.0, 1.0, 1.0])
     if torch.cuda.is_available():
+        im_tensor = im_tensor.cuda()

+    result = net(im_tensor)
+    result = torch.squeeze(F.interpolate(result[0][0], size=(h, w), mode='bilinear'), 0)
     ma = torch.max(result)
     mi = torch.min(result)
+    result = (result - mi) / (ma - mi)
+    im_array = (result * 255).cpu().data.numpy().astype(np.uint8)
     pil_im = Image.fromarray(np.squeeze(im_array))
+    new_im = Image.new("RGBA", pil_im.size, (0, 0, 0, 0))
     new_im.paste(orig_image, mask=pil_im)

     return new_im

+css = """
+body {
+    font-family: 'Arial', sans-serif;
+    margin: 0;
+    padding: 0;
+    background-color: #f0f2f5;
+    color: #333;
+}
+h1 {
+    color: #0000ff;
+}
+p {
+    color: #000000;
+}
+.gradio-app, .gradio-content {
+    background-color: #ffffff;
+    border-radius: 8px;
+    border: 1px solid #ccc;
+    box-shadow: 0 10px 25px 0 rgba(0,0,0,0.1);
+    padding: 20px;
+}
+button {
+    border: none;
+    color: white;
+    padding: 10px 20px;
+    margin: 10px 0;
+    cursor: pointer;
+    border-radius: 5px;
+    background-image: linear-gradient(to right, #6a11cb 0%, #2575fc 100%);
+    box-shadow: 0 4px 6px rgba(0, 0, 0, 0.2);
+    transition: all 0.2s ease-in-out;
+}
+button:hover {
+    box-shadow: 0 6px 8px rgba(0, 0, 0, 0.3);
+}
+input, textarea {
+    border: 2px solid #2575fc;
+    border-radius: 4px;
+    padding: 10px;
+    margin: 10px 0;
+    width: 100%;
+    box-sizing: border-box;
+    box-shadow: inset 0 2px 4px rgba(0, 0, 0, 0.1);
+}
+.gradio-toolbar {
+    background-color: #f0f2f5;
+}
+footer {
+    visibility: hidden;
+}
+"""

 title = "Background Removal"
+description = """
+This is a demo for BRIA RMBG 1.4 using the BRIA RMBG-1.4 image matting model as a backbone.<br>
+Background removal model developed by <a href='https://BRIA.AI' target='_blank'><b>BRIA.AI</b></a>, trained on a carefully selected dataset, and available as an open-source model for non-commercial use.<br>
+To test it, upload an image and wait. Read more on the model card <a href='https://huggingface.co/briaai/RMBG-1.4' target='_blank'><b>briaai/RMBG-1.4</b></a>.<br>
 """
|
108 |
+
|
109 |
+
demo = gr.Interface(
|
110 |
+
fn=process,
|
111 |
+
inputs=gr.Image(type="pil"),
|
112 |
+
outputs=gr.Image(type="pil"),
|
113 |
+
title=title,
|
114 |
+
description=description,
|
115 |
+
css=css
|
116 |
+
)
|
117 |
|
118 |
if __name__ == "__main__":
|
119 |
+
demo.launch(share=False)
|
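
The two numeric steps in process() are easy to skim past in diff form: normalize() with mean 0.5 and std 1.0 shifts the [0, 1] input into [-0.5, 0.5], and the post-processing is a plain min-max rescale of the raw matte back into [0, 1] before it becomes an 8-bit mask. A minimal standalone sketch of just that arithmetic (stand-in shapes, not the model's real 1024x1024 tensors):

import torch
from torchvision.transforms.functional import normalize

x = torch.rand(1, 3, 4, 4)                              # stand-in batch of [0, 1] pixels
x = normalize(x, [0.5, 0.5, 0.5], [1.0, 1.0, 1.0])      # (x - 0.5) / 1.0 -> [-0.5, 0.5]
assert -0.5 <= x.min().item() and x.max().item() <= 0.5

mask = torch.randn(1, 4, 4)                             # stand-in raw matte from the net
mask = (mask - mask.min()) / (mask.max() - mask.min())  # min-max rescale to [0, 1]
assert mask.min().item() == 0.0 and mask.max().item() == 1.0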
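
To smoke-test the new handler without launching the Gradio UI, something along these lines should work (a sketch, not part of the commit: "input.png" and "output.png" are hypothetical paths, and `from app import process` assumes this file is saved as app.py with briarmbg importable; importing it downloads the checkpoint):

import numpy as np
from PIL import Image

from app import process  # hypothetical import of the file in this commit

# process() starts with Image.fromarray(), so hand it a NumPy RGB array.
arr = np.array(Image.open("input.png").convert("RGB"))
cutout = process(arr)      # RGBA image: the original pasted onto a transparent
                           # canvas, with the predicted matte as the paste mask
cutout.save("output.png")  # PNG preserves the alpha channel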