File size: 4,470 Bytes
6a9128d 925ff9f 6a9128d 925ff9f 6a9128d 925ff9f 6a9128d 925ff9f 6a9128d 925ff9f 6a9128d 925ff9f 6a9128d 925ff9f 6a9128d 925ff9f 6a9128d 925ff9f 6a9128d 925ff9f 6a9128d 925ff9f 6a9128d 925ff9f 6a9128d 925ff9f 6a9128d 925ff9f 6a9128d 925ff9f 6a9128d |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 |
from torchvision.models.detection import maskrcnn_resnet50_fpn_v2, MaskRCNN_ResNet50_FPN_V2_Weights
import torch
import torch.nn as nn
import torchvision.transforms as T
from torchvision.utils import draw_segmentation_masks, draw_bounding_boxes
import random
import gradio as gr
import numpy as np
# Segmentation results keyed by unique label (e.g. "person1"). Filled in by
# segment() and read back by blur_object() and blur_background().
output_dict: dict = {}

# Module-level placeholder. segment() assigns a local variable with the same
# name, so this list itself is never filled by it.
pred_label_unq: list = []
def random_color_gen(n):
    """Return ``n`` random RGB colours as ``(r, g, b)`` tuples of ints in 0..255."""
    palette = []
    for _ in range(n):
        # Channels are drawn in red, green, blue order, one randint call each.
        red = random.randint(0, 255)
        green = random.randint(0, 255)
        blue = random.randint(0, 255)
        palette.append((red, green, blue))
    return palette
# Lazily-built Mask R-CNN weights/model pair, kept so that repeated clicks on
# "Segment Image" do not rebuild the network every time.
_SEGMENTATION_CACHE = {}


def _get_segmentation_model():
    """Return ``(weights, model)``, constructing the model on first use only."""
    if "model" not in _SEGMENTATION_CACHE:
        weights = MaskRCNN_ResNet50_FPN_V2_Weights.COCO_V1
        _SEGMENTATION_CACHE["weights"] = weights
        _SEGMENTATION_CACHE["model"] = maskrcnn_resnet50_fpn_v2(weights=weights).eval()
    return _SEGMENTATION_CACHE["weights"], _SEGMENTATION_CACHE["model"]


def segment(input_image):
    """Segment a PIL image with Mask R-CNN and return it annotated.

    Detections scoring above 0.75 are kept. Each one gets a unique label made
    of its category name plus a running number per category ("person1",
    "person2", ...). Spaces in category names are replaced by underscores
    ("traffic_light1") so that every label is a single token in the
    space-separated label textbox.

    Side effect: ``output_dict`` is emptied and refilled with one entry per
    detection, ``{'mask': nested list of bools, 'color': (r, g, b)}``.

    Args:
        input_image: PIL image to segment.

    Returns:
        A PIL image with the masks and labelled bounding boxes drawn on it,
        or the plain RGB image when nothing passes the score threshold.

    Raises:
        gr.Error: if no image was supplied.
    """
    if input_image is None:
        raise gr.Error("Please upload an image before segmenting.")

    # The model normalisation and the mask overlay both expect three channels.
    rgb_image = input_image.convert("RGB")

    # (C, H, W) copy of the image that the overlays are drawn on.
    display_img = torch.tensor(np.asarray(rgb_image)).permute(2, 0, 1)

    weights, model = _get_segmentation_model()
    preprocess = weights.transforms()
    input_tensor = preprocess(rgb_image).unsqueeze(0)

    # Inference only: skip autograd bookkeeping.
    with torch.no_grad():
        output = model(input_tensor)[0]  # first (only) image of the batch

    score_threshold = 0.75
    mask_threshold = 0.5
    keep = output['scores'] > score_threshold
    # Binarise the soft masks and drop their channel dimension.
    masks = (output['masks'][keep] > mask_threshold).squeeze(1)
    boxes = output['boxes'][keep]
    categories = weights.meta["categories"]
    pred_labels = [categories[idx] for idx in output['labels'][keep].tolist()]

    # Number the detections per category in a single pass.
    seen = {}
    pred_label_unq = []
    for label in pred_labels:
        seen[label] = seen.get(label, 0) + 1
        pred_label_unq.append(f"{label.replace(' ', '_')}{seen[label]}")

    colors = random_color_gen(len(pred_label_unq))

    # Forget masks from any previously segmented image; they may belong to a
    # different picture (and a different size) than the current one.
    output_dict.clear()
    for name, mask, color in zip(pred_label_unq, masks, colors):
        output_dict[name] = {'mask': mask.tolist(), 'color': color}

    if not pred_label_unq:
        # Nothing to draw; hand back the unannotated image.
        return rgb_image

    masked_img = draw_segmentation_masks(display_img, masks, alpha=0.9, colors=colors)
    bounding_box_img = draw_bounding_boxes(masked_img, boxes, labels=pred_label_unq, colors='white')
    return T.ToPILImage()(bounding_box_img)
def blur_object(input_image, label_name):
    """Blur only the selected objects and keep the rest of the image sharp.

    Args:
        input_image: PIL image; it must be the image that was last segmented,
            because the masks stored in ``output_dict`` are applied to it
            pixel by pixel.
        label_name: one or more labels from ``output_dict`` separated by
            whitespace, e.g. ``"person1 car1"``.

    Returns:
        A PIL image in which the pixels covered by the selected masks are
        replaced by their Gaussian-blurred version.

    Raises:
        gr.Error: if the image or the labels are missing, a label is unknown,
            or a stored mask does not match the image size.
    """
    if input_image is None:
        raise gr.Error("Please upload and segment an image first.")

    # split() without an argument ignores repeated, leading and trailing
    # whitespace, so no empty label is ever looked up.
    label_names = (label_name or "").split()
    if not label_names:
        raise gr.Error("Please enter at least one label (eg: person1).")

    unknown = [name for name in label_names if name not in output_dict]
    if unknown:
        raise gr.Error(
            f"Unknown label(s): {', '.join(unknown)}. "
            "Segment the image and use the labels shown on it."
        )

    input_tensor = T.ToTensor()(input_image).unsqueeze(0)
    blurred_tensor = T.GaussianBlur(15, 20)(input_tensor)

    # Union of all selected object masks, as an (H, W) boolean tensor.
    selected = torch.zeros(input_tensor.shape[-2:], dtype=torch.bool)
    for name in label_names:
        mask = torch.tensor(output_dict[name]['mask'], dtype=torch.bool)
        if mask.shape != selected.shape:
            raise gr.Error("The stored masks do not match this image; please segment it again.")
        selected |= mask

    # Blurred pixels inside the selection, original pixels everywhere else.
    final_img = torch.where(selected, blurred_tensor, input_tensor)
    return T.ToPILImage()(final_img.squeeze(0))
def blur_background(input_image, label_name):
    """Blur everything except the selected objects.

    Args:
        input_image: PIL image; it must be the image that was last segmented,
            because the masks stored in ``output_dict`` are applied to it
            pixel by pixel.
        label_name: one or more labels from ``output_dict`` separated by
            whitespace, e.g. ``"person1 car1"``.

    Returns:
        A PIL image that is Gaussian-blurred except for the pixels covered by
        the selected masks, which keep their original values.

    Raises:
        gr.Error: if the image or the labels are missing, a label is unknown,
            or a stored mask does not match the image size.
    """
    if input_image is None:
        raise gr.Error("Please upload and segment an image first.")

    # split() without an argument ignores repeated, leading and trailing
    # whitespace, so no empty label is ever looked up.
    label_names = (label_name or "").split()
    if not label_names:
        raise gr.Error("Please enter at least one label (eg: person1).")

    unknown = [name for name in label_names if name not in output_dict]
    if unknown:
        raise gr.Error(
            f"Unknown label(s): {', '.join(unknown)}. "
            "Segment the image and use the labels shown on it."
        )

    input_tensor = T.ToTensor()(input_image).unsqueeze(0)
    blurred_tensor = T.GaussianBlur(15, 20)(input_tensor)

    # Union of all selected object masks, as an (H, W) boolean tensor.
    selected = torch.zeros(input_tensor.shape[-2:], dtype=torch.bool)
    for name in label_names:
        mask = torch.tensor(output_dict[name]['mask'], dtype=torch.bool)
        if mask.shape != selected.shape:
            raise gr.Error("The stored masks do not match this image; please segment it again.")
        selected |= mask

    # Original pixels inside the selection, blurred pixels everywhere else.
    final_img = torch.where(selected, input_tensor, blurred_tensor)
    return T.ToPILImage()(final_img.squeeze(0))
############################
# User Interface
############################
# NOTE(review): the original indentation was lost; the layout nesting below is
# reconstructed from the statement order. Confirm it against the intended UI.
with gr.Blocks() as app:
    gr.Markdown("# Blur an object's background with AI")
    gr.Markdown("First segment the image and create bounding boxes")

    # Step 1: upload an image and segment it.
    with gr.Column():
        input_image = gr.Image(type='pil')
        b1 = gr.Button("Segment Image")
    with gr.Row():
        bounding_box_image = gr.Image()

    gr.Markdown("Now choose a label (eg: person1) from the above image of your desired object and input it below")
    gr.Markdown("You can also input multiple labels separated by spaces (eg: person1 car1 handbag1)")

    # Step 2: pick labels and choose what to blur.
    with gr.Column():
        label_name = gr.Textbox()
        with gr.Row():
            b2 = gr.Button("Blur Background")
            b3 = gr.Button("Blur Object")
        result = gr.Image()

    # Event wiring is kept inside the Blocks context.
    b1.click(segment, inputs=input_image, outputs=bounding_box_image)
    b2.click(blur_background, inputs=[input_image, label_name], outputs=result)
    b3.click(blur_object, inputs=[input_image, label_name], outputs=result)

app.launch(debug=True)