import gradio as gr
import numpy as np
import torch
from PIL import Image
from transformers import SamModel, SamProcessor
from transformers import CLIPSegProcessor, CLIPSegForImageSegmentation
from typing import List

device = "cuda" if torch.cuda.is_available() else "cpu"

# Load CLIPSeg model and processor
clip_processor = CLIPSegProcessor.from_pretrained("CIDAS/clipseg-rd64-refined")
clip_model = CLIPSegForImageSegmentation.from_pretrained("CIDAS/clipseg-rd64-refined").to(device)

# Load SAM model and processor
model = SamModel.from_pretrained("facebook/sam-vit-base").to(device)
processor = SamProcessor.from_pretrained("facebook/sam-vit-base")

# Cache of (pixel_values, image_embeddings) so the SAM image encoder
# is not re-run for the same image
cache_data = None

# Prompts used to segment the damaged area and the car
prompts = ['damaged', 'car']
damage_threshold = 0.3
vehicle_threshold = 0.5


def bbox_normalization(bbox, width, height):
    """Rescale a bbox from CLIPSeg's 352x352 output grid to the original image size."""
    height_coeff = height / 352
    width_coeff = width / 352
    normalized_bbox = [[bbox[0] * width_coeff, bbox[1] * height_coeff],
                       [bbox[2] * width_coeff, bbox[3] * height_coeff]]
    print(f'Normalized-bbox:: {normalized_bbox}')
    return normalized_bbox


def bbox_area(bbox):
    return (bbox[2] - bbox[0]) * (bbox[3] - bbox[1])


def segment_to_bbox(segment_indexes):
    """Convert a binary segmentation map into a bounding box [xmin, ymin, xmax, ymax]."""
    ys, xs = np.where(np.asarray(segment_indexes) == 1)
    if xs.size and ys.size:
        return [xs.min(), ys.min(), xs.max(), ys.max()]
    return [0.0, 0.0, 0.0, 0.0]


def clipseg_prediction(image):
    print('Clip-Segmentation-started------->')
    img_h, img_w, _ = image.shape  # numpy images are (height, width, channels)
    inputs = clip_processor(text=prompts, images=[image] * len(prompts),
                            padding="max_length", return_tensors="pt").to(device)

    # Predict
    with torch.no_grad():
        outputs = clip_model(**inputs)
    preds = outputs.logits.unsqueeze(1)

    # Per-prompt probability maps, flattened for thresholding
    flat_preds = torch.sigmoid(preds.squeeze()).reshape((preds.shape[0], -1)).cpu()

    # Initialize a dummy "unlabeled" mask with the threshold
    flat_damage_preds_with_threshold = torch.full((2, flat_preds.shape[-1]), damage_threshold)
    flat_vehicle_preds_with_threshold = torch.full((2, flat_preds.shape[-1]), vehicle_threshold)
    flat_damage_preds_with_threshold[1:2, :] = flat_preds[0]   # damage
    flat_vehicle_preds_with_threshold[1:2, :] = flat_preds[1]  # vehicle

    # Get the top mask index for each pixel (1 = above threshold, 0 = below)
    damage_inds = torch.topk(flat_damage_preds_with_threshold, 1, dim=0).indices.reshape(
        (preds.shape[-2], preds.shape[-1]))
    vehicle_inds = torch.topk(flat_vehicle_preds_with_threshold, 1, dim=0).indices.reshape(
        (preds.shape[-2], preds.shape[-1]))

    # Bounding-box creation
    damage_bbox = segment_to_bbox(damage_inds)
    vehicle_bbox = segment_to_bbox(vehicle_inds)

    # Vehicle check: the vehicle region should be larger than the damaged region
    if bbox_area(vehicle_bbox) > bbox_area(damage_bbox):
        return True, [bbox_normalization(damage_bbox, img_w, img_h)]
    else:
        return False, [[]]


@torch.no_grad()
def forward_pass(image_input: np.ndarray, points: List[List[int]]) -> np.ndarray:
    print('SAM-Segmentation-started------->')
    global cache_data
    image_input = Image.fromarray(image_input)
    inputs = processor(image_input, input_boxes=points, return_tensors="pt").to(device)

    # Re-encode the image only if it differs from the cached one
    if not cache_data or not torch.equal(inputs['pixel_values'], cache_data[0]):
        embedding = model.get_image_embeddings(inputs["pixel_values"])
        cache_data = [inputs["pixel_values"], embedding]

    del inputs["pixel_values"]
    outputs = model(image_embeddings=cache_data[1], **inputs)
    masks = processor.image_processor.post_process_masks(
        outputs.pred_masks.cpu(),
        inputs["original_sizes"].cpu(),
        inputs["reshaped_input_sizes"].cpu()
    )
    masks = masks[0][0].squeeze(0).numpy()
    return masks


def main_func(inputs):
    image_input = inputs
    classification, points = clipseg_prediction(image_input)
    if classification:
        masks = forward_pass(image_input, points)
        final_mask = masks[0]
        # Overlay the predicted mask in red on top of the input image
        mask_colors = np.zeros((final_mask.shape[0], final_mask.shape[1], 3), dtype=np.uint8)
        mask_colors[final_mask, :] = np.array([[255, 0, 0]])
        overlay = np.clip(mask_colors.astype(np.uint16) + image_input, 0, 255).astype('uint8')
        return Image.fromarray(overlay, 'RGB')
    else:
        print('Prediction:: No vehicle found in the image')
        return Image.fromarray(image_input)


def reset_data():
    global cache_data
    cache_data = None


with gr.Blocks() as demo:
    gr.Markdown("# Vehicle damage detection")
    gr.Markdown("""This app uses the SAM and CLIPSeg models to segment the damaged area of a vehicle in an image.""")
    with gr.Row():
        image_input = gr.Image()
        image_output = gr.Image()
    image_button = gr.Button("Segment Image", variant='primary')

    image_button.click(main_func, inputs=image_input, outputs=image_output)
    image_input.upload(reset_data)

demo.launch()