OriLib commited on
Commit
542c815
1 Parent(s): d505a8a

Create app.py

Browse files
Files changed (1) hide show
  1. app.py +99 -0
app.py ADDED
@@ -0,0 +1,99 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import numpy as np
2
+ import torch.nn.functional as F
3
+ from torchvision.transforms.functional import normalize
4
+ from models import BriaRMBG
5
+ import gradio as gr
6
+ import git # pip install gitpython
7
+
8
+ net=BriaRMBG()
9
+ model_path = "./model.pth"
10
+ git.Git(".").clone("https://huggingface.co/spaces/briaai/BRIA-RMBG-1.4")
11
+
12
+ if torch.cuda.is_available():
13
+ net.load_state_dict(torch.load(model_path))
14
+ net=net.cuda()
15
+ else:
16
+ net.load_state_dict(torch.load(model_path,map_location="cpu"))
17
+ net.eval()
18
+
19
+ def image_size_by_min_resolution(
20
+ image: Image.Image,
21
+ resolution: Tuple,
22
+ resample=None,
23
+ ):
24
+ w, h = image.size
25
+
26
+ image_min = min(w, h)
27
+ resolution_min = min(resolution)
28
+
29
+ scale_factor = image_min / resolution_min
30
+
31
+ resize_to: Tuple[int, int] = (
32
+ int(w // scale_factor),
33
+ int(h // scale_factor),
34
+ )
35
+ return resize_to
36
+
37
+
38
+ def resize_image(image):
39
+ image = image.convert('RGB')
40
+ new_image_size = image_size_by_min_resolution(image=image,resolution=(1024, 1024))
41
+ image = image.resize(new_image_size, Image.BILINEAR)
42
+ return image
43
+
44
+
45
+ def process(input_image):
46
+
47
+ # prepare input
48
+ orig_image = Image.open(im_path)
49
+ w,h = orig_im_size = orig_image.size
50
+ image = resize_image(orig_image)
51
+ im_np = np.array(image)
52
+ im_tensor = torch.tensor(im_np, dtype=torch.float32).permute(2,0,1)
53
+ im_tensor = torch.unsqueeze(im_tensor,0)
54
+ im_tensor = torch.divide(im_tensor,255.0)
55
+ im_tensor = normalize(im_tensor,[0.5,0.5,0.5],[1.0,1.0,1.0])
56
+ if torch.cuda.is_available():
57
+ im_tensor=im_tensor.cuda()
58
+
59
+ #inference
60
+ result=net(im_tensor)
61
+
62
+ # post process
63
+ result = torch.squeeze(F.interpolate(result[0][0], size=(h,w), mode='bilinear') ,0)
64
+ ma = torch.max(result)
65
+ mi = torch.min(result)
66
+ result = (result-mi)/(ma-mi)
67
+
68
+ # save result
69
+ im_array = (result*255).cpu().data.numpy().astype(np.uint8)
70
+ pil_im = Image.fromarray(np.squeeze(im_array))
71
+ # paste the mask on the original image
72
+ new_im = Image.new("RGBA", pil_im.size, (0,0,0))
73
+ new_im.paste(orig_image, mask=pil_im)
74
+
75
+ return new_im
76
+
77
+
78
+ block = gr.Blocks().queue()
79
+
80
+ with block:
81
+ gr.Markdown("## BRIA RMBG 1.4")
82
+ gr.HTML('''
83
+ <p style="margin-bottom: 10px; font-size: 94%">
84
+ This is a demo for BRIA RMBG 1.4 that using
85
+ <a href="https://huggingface.co/briaai/RMBG-1.4" target="_blank">BRIA RMBG-1.4 image matting model</a> as backbone.
86
+ </p>
87
+ ''')
88
+ with gr.Row():
89
+ with gr.Column():
90
+ # input_image = gr.Image(sources=None, type="pil") # None for upload, ctrl+v and webcam
91
+ input_image = gr.Image(sources=None, type="numpy") # None for upload, ctrl+v and webcam
92
+ run_button = gr.Button(value="Run")
93
+
94
+ with gr.Column():
95
+ result_gallery = gr.Gallery(label='Output', show_label=False, elem_id="gallery", columns=[2], height='auto')
96
+ ips = [input_image]
97
+ run_button.click(fn=process, inputs=ips, outputs=[result_gallery])
98
+
99
+ block.launch(debug = True)