dawood HF staff commited on
Commit
9328b97
1 Parent(s): 9bd1184

Create app.py

Browse files
Files changed (1) hide show
  1. app.py +75 -0
app.py ADDED
@@ -0,0 +1,75 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import gradio as gr
2
+ import cv2
3
+ import torch
4
+ import numpy as np
5
+ from torchvision import transforms
6
+
7
+ description = "Automatically remove the image background from a profile photo. Based on a [Space by eugenesiow](https://huggingface.co/spaces/eugenesiow/remove-bg)."
8
+
9
+
10
+ def make_transparent_foreground(pic, mask):
11
+ # split the image into channels
12
+ b, g, r = cv2.split(np.array(pic).astype('uint8'))
13
+ # add an alpha channel with and fill all with transparent pixels (max 255)
14
+ a = np.ones(mask.shape, dtype='uint8') * 255
15
+ # merge the alpha channel back
16
+ alpha_im = cv2.merge([b, g, r, a], 4)
17
+ # create a transparent background
18
+ bg = np.zeros(alpha_im.shape)
19
+ # setup the new mask
20
+ new_mask = np.stack([mask, mask, mask, mask], axis=2)
21
+ # copy only the foreground color pixels from the original image where mask is set
22
+ foreground = np.where(new_mask, alpha_im, bg).astype(np.uint8)
23
+
24
+ return foreground
25
+
26
+
27
+ def remove_background(input_image):
28
+ preprocess = transforms.Compose([
29
+ transforms.ToTensor(),
30
+ transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
31
+ ])
32
+
33
+ input_tensor = preprocess(input_image)
34
+ input_batch = input_tensor.unsqueeze(0) # create a mini-batch as expected by the model
35
+
36
+ # move the input and model to GPU for speed if available
37
+ if torch.cuda.is_available():
38
+ input_batch = input_batch.to('cuda')
39
+ model.to('cuda')
40
+
41
+ with torch.no_grad():
42
+ output = model(input_batch)['out'][0]
43
+ output_predictions = output.argmax(0)
44
+
45
+ # create a binary (black and white) mask of the profile foreground
46
+ mask = output_predictions.byte().cpu().numpy()
47
+ background = np.zeros(mask.shape)
48
+ bin_mask = np.where(mask, 255, background).astype(np.uint8)
49
+
50
+ foreground = make_transparent_foreground(input_image, bin_mask)
51
+
52
+ return foreground, bin_mask
53
+
54
+
55
+ def inference(img):
56
+ foreground, _ = remove_background(img)
57
+ return foreground
58
+
59
+
60
+ torch.hub.download_url_to_file('https://pbs.twimg.com/profile_images/691700243809718272/z7XZUARB_400x400.jpg',
61
+ 'demis.jpg')
62
+ torch.hub.download_url_to_file('https://hai.stanford.edu/sites/default/files/styles/person_medium/public/2020-03/hai_1512feifei.png?itok=INFuLABp',
63
+ 'lifeifei.png')
64
+ model = torch.hub.load('pytorch/vision:v0.6.0', 'deeplabv3_resnet101', pretrained=True)
65
+ model.eval()
66
+
67
+ gr.Interface(
68
+ inference,
69
+ gr.Image(type="pil", label="Input"),
70
+ gr.Image(type="pil", label="Output"),
71
+ description=description,
72
+ examples=[['demis.jpg'], ['lifeifei.png']],
73
+ enable_queue=True,
74
+ css=".footer{display:none !important}"
75
+ ).launch(debug=False)