ysr committed on
Commit
6a9128d
1 Parent(s): e0661a9

initial commit

Browse files
Files changed (1) hide show
  1. app.py +151 -0
app.py CHANGED
@@ -0,0 +1,151 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from torchvision.models.detection import maskrcnn_resnet50_fpn_v2, MaskRCNN_ResNet50_FPN_V2_Weights
2
+ import torch
3
+ import torch.nn as nn
4
+ import torchvision.transforms as T
5
+ from torchvision.utils import draw_segmentation_masks, draw_bounding_boxes
6
+ import random
7
+ import gradio as gr
8
+ import numpy as np
9
+
10
+
11
def random_color_gen(n):
    """Return ``n`` random RGB colours.

    Each colour is a 3-tuple of integers drawn uniformly from 0..255
    (inclusive) using the module-level ``random`` generator.
    """
    palette = []
    for _ in range(n):
        # Draw the three components in order so the sequence of random
        # numbers consumed is the same for a given seed.
        red = random.randint(0, 255)
        green = random.randint(0, 255)
        blue = random.randint(0, 255)
        palette.append((red, green, blue))
    return palette
13
+
14
+ import math
15
+
16
def get_gaussian_kernel(kernel_size=15, sigma=20, channels=3):
    """Build a fixed depthwise Gaussian blur layer.

    Args:
        kernel_size: Side length of the square kernel.
        sigma: Standard deviation of the Gaussian.
        channels: Number of input/output channels; each channel is
            convolved with its own copy of the same kernel.

    Returns:
        An ``nn.Conv2d`` (``padding='same'``, ``groups=channels``, no bias)
        whose weights are the normalised 2-D Gaussian and have
        ``requires_grad`` set to ``False``.
    """
    center = (kernel_size - 1) / 2.
    var = sigma ** 2.

    # Squared distance of every kernel cell from the kernel centre, built by
    # broadcasting a row of x-offsets against a column of y-offsets.
    offsets = torch.arange(kernel_size).float() - center
    sq_dist = offsets.view(1, kernel_size) ** 2. + offsets.view(kernel_size, 1) ** 2.

    # 2-D Gaussian = product of two independent 1-D Gaussians in x and y.
    weights_2d = (1. / (2. * math.pi * var)) * torch.exp(-sq_dist / (2 * var))

    # Normalise so the kernel sums to 1.
    weights_2d = weights_2d / weights_2d.sum()

    # One (1, k, k) filter per channel: the depthwise convolution weight layout.
    depthwise = weights_2d[None, None].repeat(channels, 1, 1, 1)

    blur = nn.Conv2d(
        in_channels=channels,
        out_channels=channels,
        kernel_size=kernel_size,
        padding='same',
        groups=channels,
        bias=False,
    )
    blur.weight.data = depthwise
    blur.weight.requires_grad = False

    return blur
49
+
50
+
51
# Shared state: segment() stores one entry per detected object here, keyed by
# its unique label, and blur_background() reads the chosen entry back.
output_dict = {}  # this dict is shared between segment and blur_background functions
pred_label_unq = []  # module-level name kept as-is; segment() builds its own local list

# Holds the Mask R-CNN model, its preprocessing transform and the category
# names once they have been built, so repeated calls do not rebuild them.
_segmentation_cache = {}


def _get_segmentation_model():
    """Return ``(model, preprocess, categories)``, building them on first use."""
    if not _segmentation_cache:
        weights = MaskRCNN_ResNet50_FPN_V2_Weights.COCO_V1
        model = maskrcnn_resnet50_fpn_v2(weights=weights)
        model.eval()
        _segmentation_cache['model'] = model
        _segmentation_cache['preprocess'] = weights.transforms()
        _segmentation_cache['categories'] = weights.meta["categories"]
    return (
        _segmentation_cache['model'],
        _segmentation_cache['preprocess'],
        _segmentation_cache['categories'],
    )


def segment(input_image):
    """Run instance segmentation and return the annotated image.

    Detections scoring above 0.75 are kept. Each one gets a unique label made
    of its category name plus a running number (``person1``, ``person2``, ...)
    and its mask and colour are stored in the module-level ``output_dict``
    under that label, replacing whatever a previous call stored.

    Args:
        input_image: PIL image to segment (the UI passes ``type='pil'``).

    Returns:
        A PIL image with the masks overlaid and labelled bounding boxes drawn.

    Raises:
        ValueError: If ``input_image`` is ``None``.
    """
    if input_image is None:
        raise ValueError("No image provided: upload an image before segmenting.")

    score_threshold = 0.75
    mask_threshold = 0.5

    # Channel-first copy of the image for the torchvision drawing utilities.
    display_img = torch.tensor(np.asarray(input_image)).permute(2, 0, 1)

    model, preprocess, categories = _get_segmentation_model()

    # Inference only: skip autograd bookkeeping.
    with torch.no_grad():
        # idx 0 to get the first (and only) dictionary of the returned list
        output = model(preprocess(input_image).unsqueeze(0))[0]

    # Filter by score once and reuse the selection for masks, boxes and labels.
    keep = output['scores'] > score_threshold
    masks = (output['masks'][keep] > mask_threshold).squeeze(1)
    boxes = output['boxes'][keep]
    pred_labels = [categories[label] for label in output['labels'][keep].tolist()]
    n_pred = len(pred_labels)

    # Give a unique id to every prediction by numbering repeats per category.
    seen = {}
    pred_label_unq = []
    for name in pred_labels:
        seen[name] = seen.get(name, 0) + 1
        pred_label_unq.append(name + str(seen[name]))

    colors = random_color_gen(n_pred)

    # Drop entries from any previous image so a stale label cannot be used
    # against a different picture, then store the new detections.
    output_dict.clear()
    for label, mask, color in zip(pred_label_unq, masks, colors):
        output_dict[label] = {'mask': mask.tolist(), 'color': color}

    if n_pred == 0:
        # Nothing to draw: hand back the image unchanged.
        return T.ToPILImage()(display_img)

    masked_img = draw_segmentation_masks(display_img, masks, alpha=0.9, colors=colors)
    bounding_box_img = draw_bounding_boxes(masked_img, boxes, labels=pred_label_unq, colors='white')

    return T.ToPILImage()(bounding_box_img)
100
+
101
def blur_background(input_image, label_name):
    """Blur everything in the image except the object with the given label.

    The object's mask is looked up in the module-level ``output_dict``, which
    ``segment`` fills in, so ``segment`` must have been run on the same image
    first.

    Args:
        input_image: PIL image to process.
        label_name: Label of the object to keep sharp, e.g. ``person1``.
            Surrounding whitespace is ignored.

    Returns:
        A PIL image where pixels inside the mask come from the input and all
        other pixels come from the Gaussian-blurred input.

    Raises:
        KeyError: If ``label_name`` is not a label stored in ``output_dict``.
        ValueError: If the stored mask size does not match the image size.
    """
    key = label_name.strip() if isinstance(label_name, str) else label_name
    if key not in output_dict:
        available = ", ".join(sorted(output_dict)) or "none - segment an image first"
        raise KeyError(
            "Unknown label %r; available labels: %s" % (label_name, available)
        )

    # Accepts the nested-list form stored by segment() as well as a tensor.
    mask = torch.as_tensor(output_dict[key]['mask'], dtype=torch.bool)

    input_tensor = T.ToTensor()(input_image).unsqueeze(0)
    if tuple(mask.shape) != tuple(input_tensor.shape[-2:]):
        raise ValueError(
            "Mask size %s does not match image size %s; segment this image again."
            % (tuple(mask.shape), tuple(input_tensor.shape[-2:]))
        )

    # Size the depthwise kernel to the image's channel count so images that
    # are not 3-channel are handled as well.
    blur = get_gaussian_kernel(channels=input_tensor.shape[1])
    with torch.no_grad():
        blurred_tensor = blur(input_tensor)

    # Keep original pixels inside the mask and blurred pixels elsewhere;
    # the (H, W) mask broadcasts over the batch and channel dimensions.
    final_img = torch.where(mask, input_tensor, blurred_tensor)

    return T.ToPILImage()(final_img.squeeze(0))
115
+
116
+
117
+
118
+
119
+
120
+
121
# Two-step UI: first segment the uploaded image, then blur the background
# around one of the labelled objects.
with gr.Blocks() as app:

    gr.Markdown("# Blur an object's background with AI")

    gr.Markdown("First segment the image and create bounding boxes")
    with gr.Column():
        input_image = gr.Image(type='pil')
        b1 = gr.Button("Segment Image")

    with gr.Row():
        bounding_box_image = gr.Image()

    gr.Markdown("Now choose a label (eg: person1) from the above image of your desired object and input it below")
    with gr.Column():
        label_name = gr.Textbox()
        b2 = gr.Button("Blur Background")
        result = gr.Image()

    # Event listeners are registered inside the Blocks context.
    b1.click(segment, inputs=input_image, outputs=bounding_box_image)
    b2.click(blur_background, inputs=[input_image, label_name], outputs=result)

app.launch(debug=True)