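"""Vehicle damage detection Gradio Space.

CLIPSeg localizes the vehicle and the damaged region from text prompts; the damage
mask is converted to a bounding box, which prompts SAM for a fine-grained damage
mask that is overlaid in red on the input image.
"""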
import gradio as gr
import numpy as np
import torch
from PIL import Image
from transformers import SamModel, SamProcessor
from transformers import CLIPSegProcessor, CLIPSegForImageSegmentation
from typing import List
device = "cuda" if torch.cuda.is_available() else "cpu"
# Load CLIPSeg model and processor
clip_processor = CLIPSegProcessor.from_pretrained("CIDAS/clipseg-rd64-refined")
clip_model = CLIPSegForImageSegmentation.from_pretrained("CIDAS/clipseg-rd64-refined").to(device)
# Load SAM model and processor
model = SamModel.from_pretrained("facebook/sam-vit-base").to(device)
processor = SamProcessor.from_pretrained("facebook/sam-vit-base")
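# Cache of the latest SAM inputs, [pixel_values, image_embedding], so the image
# embedding is only recomputed when the input image changes.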
cache_data = None
# Prompts to segment damaged area and car
prompts = ['damaged', 'car']
damage_threshold = 0.3
vehicle_threshold = 0.5
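# Pipeline: CLIPSeg produces coarse masks for the text prompts above; the "damaged"
# mask is converted to a bounding box, which then prompts SAM for a precise damage mask.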
def bbox_normalization(bbox, width, height):
    # CLIPSeg masks are 352x352, so rescale the bbox to the original image size.
    height_coeff = height / 352
    width_coeff = width / 352
    normalized_bbox = [[bbox[0] * width_coeff, bbox[1] * height_coeff],
                       [bbox[2] * width_coeff, bbox[3] * height_coeff]]
    print(f'Normalized bbox: {normalized_bbox}')
    return normalized_bbox
def bbox_area(bbox):
    # Area of an [xmin, ymin, xmax, ymax] box.
    return (bbox[2] - bbox[0]) * (bbox[3] - bbox[1])
def segment_to_bbox(segment_indexs):
    # Collect the coordinates of every pixel labelled 1 and return their bounding box.
    x_points = []
    y_points = []
    for y, list_val in enumerate(segment_indexs):
        for x, val in enumerate(list_val):
            if val == 1:
                x_points.append(x)
                y_points.append(y)
    if x_points and y_points:
        return [np.min(x_points), np.min(y_points), np.max(x_points), np.max(y_points)]
    return [0.0, 0.0, 0.0, 0.0]
def clipseg_prediction(image):
    print('CLIPSeg segmentation started')
    img_h, img_w, _ = image.shape  # numpy images are (height, width, channels)
    inputs = clip_processor(text=prompts, images=[image] * len(prompts),
                            padding="max_length", return_tensors="pt").to(device)
    # Predict
    with torch.no_grad():
        outputs = clip_model(**inputs)
    preds = outputs.logits.unsqueeze(1)
    # Apply the thresholds to decide whether the image contains damage / a vehicle
    flat_preds = torch.sigmoid(preds.squeeze()).reshape((preds.shape[0], -1)).cpu()
    # Initialize dummy "unlabeled" masks with the thresholds
    flat_damage_preds_with_threshold = torch.full((2, flat_preds.shape[-1]), damage_threshold)
    flat_vehicle_preds_with_threshold = torch.full((2, flat_preds.shape[-1]), vehicle_threshold)
    flat_damage_preds_with_threshold[1:2, :] = flat_preds[0]   # damage
    flat_vehicle_preds_with_threshold[1:2, :] = flat_preds[1]  # vehicle
    # Get the top mask index for each pixel
    damage_inds = torch.topk(flat_damage_preds_with_threshold, 1, dim=0).indices.reshape((preds.shape[-2], preds.shape[-1]))
    vehicle_inds = torch.topk(flat_vehicle_preds_with_threshold, 1, dim=0).indices.reshape((preds.shape[-2], preds.shape[-1]))
    # Bounding-box creation
    damage_bbox = segment_to_bbox(damage_inds)
    vehicle_bbox = segment_to_bbox(vehicle_inds)
    # Vehicle check: the vehicle region should be larger than the damaged region
    if bbox_area(vehicle_bbox) > bbox_area(damage_bbox):
        return True, [bbox_normalization(damage_bbox, img_w, img_h)]
    else:
        return False, [[]]
@torch.no_grad()
def forward_pass(image_input: np.ndarray, points: List[List[int]]) -> np.ndarray:
    print('SAM segmentation started')
    global cache_data
    image_input = Image.fromarray(image_input)
    inputs = processor(image_input, input_boxes=points, return_tensors="pt").to(device)
    # Recompute the image embedding only when the input image changes
    if not cache_data or not torch.equal(inputs['pixel_values'], cache_data[0]):
        embedding = model.get_image_embeddings(inputs["pixel_values"])
        pixels = inputs["pixel_values"]
        cache_data = [pixels, embedding]
    del inputs["pixel_values"]
    outputs = model(image_embeddings=cache_data[1], **inputs)
    masks = processor.image_processor.post_process_masks(
        outputs.pred_masks.cpu(), inputs["original_sizes"].cpu(), inputs["reshaped_input_sizes"].cpu()
    )
    masks = masks[0][0].numpy()
    return masks
def main_func(inputs):
    image_input = inputs
    classification, points = clipseg_prediction(image_input)
    if classification:
        masks = forward_pass(image_input, points)
        final_mask = masks[0]
        # Paint the predicted damage mask red and overlay it on the input image
        mask_colors = np.zeros((final_mask.shape[0], final_mask.shape[1], 3), dtype=np.uint8)
        mask_colors[final_mask, :] = np.array([255, 0, 0])
        overlay = np.clip(mask_colors.astype(np.int32) + image_input.astype(np.int32), 0, 255).astype(np.uint8)
        return 'Prediction: Vehicle damage prediction is given.', Image.fromarray(overlay, 'RGB')
    else:
        print('Prediction: No vehicle found in the image')
        return 'Prediction: No vehicle or damage found in the image', Image.fromarray(image_input)
def reset_data():
    # Clear the cached SAM embedding when a new image is uploaded
    global cache_data
    cache_data = None
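# Gradio UI: input/output images side by side, an example gallery, a textbox for the
# prediction message, and a button that runs the CLIPSeg + SAM pipeline.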
with gr.Blocks() as demo:
    gr.Markdown("# Vehicle damage detection")
    gr.Markdown("""This app uses the SAM and CLIPSeg models to segment the damaged area of a vehicle in an image.""")
    with gr.Row():
        image_input = gr.Image(label='Input Image')
        image_output = gr.Image(label='Damage Detection')
    with gr.Row():
        examples = gr.Examples(examples="./examples", inputs=image_input)
        prediction_op = gr.Textbox(label='Prediction')
    image_button = gr.Button("Segment Image", variant='primary')
    image_button.click(main_func, inputs=image_input, outputs=[prediction_op, image_output])
    image_input.upload(reset_data)

demo.launch()