ysr committed
Commit 925ff9f
1 Parent(s): cf1ad32

add blur object button and multiple object support

Files changed (1)
  1. app.py +45 -58
app.py CHANGED
@@ -7,50 +7,13 @@ import random
  import gradio as gr
  import numpy as np
 
+ output_dict = {} # this dict is shared between segment and blur_background functions
+ pred_label_unq = []
+
 
  def random_color_gen(n):
      return [tuple(random.randint(0,255) for i in range(3)) for i in range(n)]
 
- import math
-
- def get_gaussian_kernel(kernel_size=15, sigma=20, channels=3):
-     # Create a x, y coordinate grid of shape (kernel_size, kernel_size, 2)
-     x_coord = torch.arange(kernel_size)
-     x_grid = x_coord.repeat(kernel_size).view(kernel_size, kernel_size)
-     y_grid = x_grid.t()
-     xy_grid = torch.stack([x_grid, y_grid], dim=-1).float()
-
-     mean = (kernel_size - 1)/2.
-     variance = sigma**2.
-
-     # Calculate the 2-dimensional gaussian kernel which is
-     # the product of two gaussian distributions for two different
-     # variables (in this case called x and y)
-     gaussian_kernel = (1./(2.*math.pi*variance)) *\
-         torch.exp(
-             -torch.sum((xy_grid - mean)**2., dim=-1) /\
-             (2*variance)
-         )
-
-     # Make sure sum of values in gaussian kernel equals 1.
-     gaussian_kernel = gaussian_kernel / torch.sum(gaussian_kernel)
-
-     # Reshape to 2d depthwise convolutional weight
-     gaussian_kernel = gaussian_kernel.view(1, 1, kernel_size, kernel_size)
-     gaussian_kernel = gaussian_kernel.repeat(channels, 1, 1, 1)
-
-     gaussian_filter = nn.Conv2d(in_channels=channels, out_channels=channels,
-                                 kernel_size=kernel_size, padding='same', groups=channels, bias=False)
-
-     gaussian_filter.weight.data = gaussian_kernel
-     gaussian_filter.weight.requires_grad = False
-
-     return gaussian_filter
-
-
- output_dict = {} # this dict is shared between segment and blur_background functions
- pred_label_unq = []
-
  def segment(input_image):
 
      # prepare image for display
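The removed get_gaussian_kernel built a fixed depthwise nn.Conv2d by hand; the commit leans on torchvision's built-in transform instead (T.GaussianBlur(15, 20) further down). A minimal sketch of that replacement, assuming T is torchvision.transforms (as the existing T.ToTensor()/T.ToPILImage() calls suggest) and a dummy tensor in place of the real image:

    import torch
    import torchvision.transforms as T

    # Dummy batch of one RGB image, shape (1, 3, H, W);
    # stands in for T.ToTensor()(input_image).unsqueeze(0)
    input_tensor = torch.rand(1, 3, 64, 64)

    # Same settings the new code uses: kernel_size=15, sigma=20
    blur = T.GaussianBlur(15, 20)
    blurred_tensor = blur(input_tensor)

    print(blurred_tensor.shape)  # torch.Size([1, 3, 64, 64]) -- spatial size is unchanged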
@@ -98,26 +61,54 @@ def segment(input_image):
 
      return bounding_box_img;
 
- def blur_background(input_image, label_name):
-     mask = output_dict[label_name]['mask']
-     mask = torch.tensor(mask).unsqueeze(0)
+
+ def blur_object(input_image, label_name):
+
+     label_names = label_name.split(' ')
 
      input_tensor = T.ToTensor()(input_image).unsqueeze(0)
-     blur = get_gaussian_kernel()
+     blur = T.GaussianBlur(15, 20)
      blurred_tensor = blur(input_tensor)
+
+     final_img = input_tensor
+
+     for name in label_names:
+         mask = output_dict[name.strip()]['mask']
+         mask = torch.tensor(mask).unsqueeze(0)
 
-     final_img = blurred_tensor
-     final_img[:, :, mask.squeeze(0)] = input_tensor[:, :, mask.squeeze(0)];
+         final_img[:, :, mask.squeeze(0)] = blurred_tensor[:, :, mask.squeeze(0)];
 
      final_img = T.ToPILImage()(final_img.squeeze(0))
 
      return final_img;
+
+ def blur_background(input_image, label_name):
+     label_names = label_name.split(' ')
+
+     input_tensor = T.ToTensor()(input_image).unsqueeze(0)
+     blur = T.GaussianBlur(15, 20)
+     blurred_tensor = blur(input_tensor)
+
+     final_img = blurred_tensor
+
+     for name in label_names:
+         mask = output_dict[name.strip()]['mask']
+         mask = torch.tensor(mask).unsqueeze(0)
 
+         final_img[:, :, mask.squeeze(0)] = input_tensor[:, :, mask.squeeze(0)];
 
+     final_img = T.ToPILImage()(final_img.squeeze(0))
 
-
+     return final_img;
 
 
+
+
+ ############################
+ """ User Interface """
+ ############################
+
  with gr.Blocks() as app:
 
      gr.Markdown("# Blur an objects background with AI")
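Both new functions share one compositing trick: blur the whole frame once, then use each selected object's boolean mask to copy pixels between the sharp and blurred tensors. blur_object starts from the sharp image and pastes blurred pixels inside every mask; blur_background starts from the blurred image and pastes the sharp pixels back. A self-contained sketch of that step, with a dummy rectangular (H, W) boolean mask standing in for output_dict[name.strip()]['mask'] (the real mask's dtype and shape are assumed here):

    import torch
    import torchvision.transforms as T

    image = torch.rand(1, 3, 64, 64)          # stands in for T.ToTensor()(input_image).unsqueeze(0)
    blurred = T.GaussianBlur(15, 20)(image)   # blur the whole frame once

    # Dummy boolean mask; app.py pulls this from output_dict[name.strip()]['mask']
    mask = torch.zeros(64, 64, dtype=torch.bool)
    mask[16:48, 16:48] = True

    # "Blur Object": keep the sharp frame, overwrite the masked pixels with blurred ones
    obj_blurred = image.clone()
    obj_blurred[:, :, mask] = blurred[:, :, mask]

    # "Blur Background": keep the blurred frame, restore the sharp pixels inside the mask
    bg_blurred = blurred.clone()
    bg_blurred[:, :, mask] = image[:, :, mask]

The .clone() calls only keep the sketch side-effect free; app.py writes straight into input_tensor / blurred_tensor, which is why blur_object can reuse input_tensor as final_img.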
@@ -126,26 +117,22 @@ with gr.Blocks() as app:
      with gr.Column():
          input_image = gr.Image(type='pil')
          b1 = gr.Button("Segment Image")
-
-
-
+
      with gr.Row():
-         # masked_image = gr.Image();
          bounding_box_image = gr.Image();
-
-
+
      gr.Markdown("Now choose a label (eg: person1) from the above image of your desired object and input it below")
+     gr.Markdown("You can also input multiple labels separated by spaces (eg: person1 car1 handbag1)")
      with gr.Column():
          label_name = gr.Textbox()
-     b2 = gr.Button("Blur Backbround")
+     with gr.Row():
+         b2 = gr.Button("Blur Backbround")
+         b3 = gr.Button("Blur Object")
      result = gr.Image()
 
      b1.click(segment, inputs=input_image, outputs=bounding_box_image)
      b2.click(blur_background, inputs=[input_image, label_name], outputs=result)
+     b3.click(blur_object, inputs=[input_image, label_name], outputs=result)
 
-
-
-
-     # instance_segmentation = gr.Interface(segment, inputs=input_image, outputs=['json', 'image'])
 
  app.launch(debug=True)
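On the UI side the change is just a second button dispatching to the new callback with the same inputs and output. A stripped-down sketch of that wiring pattern in gr.Blocks, with placeholder callbacks (and a hypothetical "demo" name) instead of the real segment/blur functions:

    import gradio as gr

    def blur_background(image, labels):
        return image  # placeholder for the real masking + blur logic

    def blur_object(image, labels):
        return image  # placeholder for the real masking + blur logic

    with gr.Blocks() as demo:
        input_image = gr.Image(type='pil')
        label_name = gr.Textbox()
        with gr.Row():
            b2 = gr.Button("Blur Background")
            b3 = gr.Button("Blur Object")
        result = gr.Image()

        # Both buttons share inputs and output; only the callback differs
        b2.click(blur_background, inputs=[input_image, label_name], outputs=result)
        b3.click(blur_object, inputs=[input_image, label_name], outputs=result)

    demo.launch()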