import gradio as gr
import numpy as np
import torch
from PIL import Image
from transformers import SamModel, SamProcessor
from transformers import CLIPSegProcessor, CLIPSegForImageSegmentation
from typing import List

device = "cuda" if torch.cuda.is_available() else "cpu"

# Load CLIPSeg model and processor
clip_processor = CLIPSegProcessor.from_pretrained("CIDAS/clipseg-rd64-refined")
clip_model = CLIPSegForImageSegmentation.from_pretrained("CIDAS/clipseg-rd64-refined").to(device)

# Load SAM model and processor
model = SamModel.from_pretrained("facebook/sam-vit-base").to(device)
processor = SamProcessor.from_pretrained("facebook/sam-vit-base")

# Cache of [pixel_values, image_embeddings] for the most recently encoded image
cache_data = None

# Prompts to segment damaged area and car
prompts = ['damaged', 'car']
damage_threshold = 0.3
vehicle_threshold = 0.5
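# Minimum sigmoid confidence a pixel needs before it is treated as "damaged" /
# "car"; clipseg_prediction compares each pixel's probability against these values.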

def bbox_normalization(bbox, width, height):
    # CLIPSeg logits live on a fixed 352x352 grid; rescale the bbox corners
    # back to the original image resolution.
    height_coeff = height/352
    width_coeff = width/352
    normalized_bbox = [[bbox[0]*width_coeff, bbox[1]*height_coeff],
                       [bbox[2]*width_coeff, bbox[3]*height_coeff]]
    print(f'Normalized-bbox:: {normalized_bbox}')
    return normalized_bbox
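
# Worked example (hypothetical numbers): for a 1056x704 (width x height) image,
# width_coeff = 3 and height_coeff = 2, so the grid corner (176, 176) maps back
# to (176*3, 176*2) = (528, 352) in original-image coordinates.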

def bbox_area(bbox):
    area = (bbox[2]-bbox[0])*(bbox[3]-bbox[1])
    return area

def segment_to_bbox(segment_indexs):
    # Convert a 2-D 0/1 class-index map into a bounding box [x_min, y_min, x_max, y_max]
    ys, xs = np.where(np.asarray(segment_indexs) == 1)
    if xs.size and ys.size:
        return [xs.min(), ys.min(), xs.max(), ys.max()]
    return [0.0, 0.0, 0.0, 0.0]
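
# e.g. a mask whose only ones sit at (y, x) = (10, 20) and (30, 40) yields the
# bounding box [20, 10, 40, 30], i.e. (x_min, y_min, x_max, y_max).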

def clipseg_prediction(image):
    print('Clip-Segmentation-started------->')
    img_h, img_w, _ = image.shape  # numpy images are (height, width, channels)
    inputs = clip_processor(text=prompts, images=[image] * len(prompts), padding="max_length", return_tensors="pt").to(device)
    # predict
    with torch.no_grad():
        outputs = clip_model(**inputs)
    preds = outputs.logits.unsqueeze(1)
    # Flatten the per-prompt logits to (num_prompts, H*W) probabilities on the CPU
    flat_preds = torch.sigmoid(preds.squeeze()).reshape((preds.shape[0], -1)).cpu()

    # Build a 2-row tensor per prompt: row 0 holds the threshold (acting as a
    # "background" class), row 1 holds the predicted probabilities. topk over
    # dim 0 then selects class 1 only where the probability exceeds the threshold.
    flat_damage_preds_with_threshold = torch.full((2, flat_preds.shape[-1]), damage_threshold)
    flat_vehicle_preds_with_threshold = torch.full((2, flat_preds.shape[-1]), vehicle_threshold)
    flat_damage_preds_with_threshold[1:2, :] = flat_preds[0]   # damage
    flat_vehicle_preds_with_threshold[1:2, :] = flat_preds[1]  # vehicle

    # Per-pixel class index (0 = below threshold, 1 = above), reshaped to the 352x352 grid
    damage_inds = torch.topk(flat_damage_preds_with_threshold, 1, dim=0).indices.reshape((preds.shape[-2], preds.shape[-1]))
    vehicle_inds = torch.topk(flat_vehicle_preds_with_threshold, 1, dim=0).indices.reshape((preds.shape[-2], preds.shape[-1]))

    # bbox creation
    damage_bbox = segment_to_bbox(damage_inds)
    vehicle_bbox = segment_to_bbox(vehicle_inds)

    # Heuristic vehicle check: if the vehicle region is larger than the damage
    # region, assume a vehicle is present and return the damage bbox (rescaled
    # to the original resolution) as a box prompt for SAM.
    if bbox_area(vehicle_bbox) > bbox_area(damage_bbox):
        return True, [bbox_normalization(damage_bbox, img_w, img_h)]
    else:
        return False, [[]]

@torch.no_grad()
def forward_pass(image_input: np.ndarray, points: List[List[float]]) -> np.ndarray:
    print('SAM-Segmentation-started------->')
    global cache_data
    image_input = Image.fromarray(image_input)
    inputs = processor(image_input, input_boxes=points, return_tensors="pt").to(device)
    # Re-encode the image only when it changed; otherwise reuse the cached embedding
    if not cache_data or not torch.equal(inputs['pixel_values'], cache_data[0]):
        embedding = model.get_image_embeddings(inputs["pixel_values"])
        pixels = inputs["pixel_values"]
        cache_data = [pixels, embedding]
    del inputs["pixel_values"]

    outputs = model(image_embeddings=cache_data[1], **inputs)
    masks = processor.image_processor.post_process_masks(
        outputs.pred_masks.cpu(), inputs["original_sizes"].to(device), inputs["reshaped_input_sizes"].to(device)
    )
    masks = masks[0][0].squeeze(0).numpy()

    return masks
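
# Note: `points` is forwarded to SamProcessor as `input_boxes`; in this app each
# box is supplied as a pair of corner points [[x1, y1], [x2, y2]], which is the
# format produced by clipseg_prediction above.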

def main_func(inputs):
    image_input = inputs
    classification, points = clipseg_prediction(image_input)
    if classification:
        masks = forward_pass(image_input, points)

        # Use the first of the candidate masks returned by SAM
        final_mask = masks[0]
        # Overlay the damage mask in red on top of the original image
        mask_colors = np.zeros((final_mask.shape[0], final_mask.shape[1], 3), dtype=np.uint8)
        mask_colors[final_mask, :] = np.array([255, 0, 0])
        overlay = np.clip(mask_colors.astype(np.uint16) + image_input, 0, 255).astype(np.uint8)
        return 'Prediction: Vehicle damage area is highlighted.', Image.fromarray(overlay, 'RGB')
    else:
        print('Prediction: No vehicle found in the image')
        return 'Prediction: No vehicle or damage found in the image', Image.fromarray(image_input)

def reset_data():
    # Drop the cached SAM embedding when a new image is uploaded
    global cache_data
    cache_data = None

with gr.Blocks() as demo:
    gr.Markdown("# Vehicle damage detection")
    gr.Markdown("""This app uses the SAM model and clipseg model to get a vehicle damage area from image.""")
    with gr.Row():
        image_input = gr.Image(label='Input Image')
        image_output = gr.Image(label='Damage Detection')
    with gr.Row():
        examples = gr.Examples(examples="./examples", inputs=image_input)
        prediction_op = gr.Textbox(label='Prediction')
    
    image_button = gr.Button("Segment Image", variant='primary')

    image_button.click(main_func, inputs=image_input, outputs=[prediction_op, image_output])
    image_input.upload(reset_data)

demo.launch()
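
# To run locally (assumes torch, transformers, gradio and Pillow are installed and
# that an ./examples directory with sample images sits next to this script):
#   python app.py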