import tempfile

import cv2
import gradio as gr
import numpy as np
import segmentation_models_pytorch as smp
import torch
from einops import rearrange
from imutils import perspective

# 5x5 structuring element for the morphological opening that removes small
# speckles from the predicted mask before connected-component labelling.
kernel1 = np.ones((5, 5), dtype=np.float32)

# Components whose outer contour area is not above this value are treated as
# noise and are not annotated.
MIN_CONTOUR_AREA = 100


def midpoint(ptA, ptB):
    """Return the midpoint of two 2-D points as an ``(x, y)`` float tuple."""
    return ((ptA[0] + ptB[0]) * 0.5, (ptA[1] + ptB[1]) * 0.5)


def _int_point(pt):
    """Truncate a float ``(x, y)`` point to the int tuple OpenCV drawing expects."""
    return int(pt[0]), int(pt[1])


def _draw_component(canvas, contour):
    """Draw one component's min-area box, edge midpoints and mid-lines on ``canvas`` in place."""
    rect = cv2.minAreaRect(contour)
    box = np.array(cv2.boxPoints(rect), dtype="int")
    box = perspective.order_points(box)
    color = [int(c) for c in np.random.choice(range(150), size=3)]
    cv2.drawContours(canvas, [box.astype("int")], 0, color, 2)

    tl, tr, br, bl = box
    top = midpoint(tl, tr)
    bottom = midpoint(bl, br)
    left = midpoint(tl, bl)
    right = midpoint(tr, br)

    # Filled dots on the four edge midpoints ((255, 0, 0) is blue in BGR).
    for pt in (top, bottom, left, right):
        cv2.circle(canvas, _int_point(pt), 5, (255, 0, 0), -1)
    # Lines joining opposite midpoints, in the component's box colour.
    cv2.line(canvas, _int_point(top), _int_point(bottom), color, 2)
    cv2.line(canvas, _int_point(left), _int_point(right), color, 2)


def cca_analysis(image, predicted_mask):
    """Annotate each connected component of a predicted mask on a copy of ``image``.

    Args:
        image: Image to annotate, an ``H x W x 3`` array (BGR in this app).
            It is copied, so the caller's array is not modified.
        predicted_mask: 3-channel uint8 mask (as produced by ``to_rgb``);
            it is resized to the image size before analysis.

    Returns:
        A copy of ``image`` with, for every 8-connected foreground component
        whose outer contour area exceeds ``MIN_CONTOUR_AREA``, its minimum-area
        rectangle, the four edge midpoints and the two mid-lines drawn on it.
    """
    annotated = np.array(image)  # copy: drawing must not mutate the caller's array
    height, width = annotated.shape[:2]
    # cv2.resize takes dsize as (width, height).
    mask = cv2.resize(predicted_mask, (width, height), interpolation=cv2.INTER_AREA)
    mask = cv2.morphologyEx(mask, cv2.MORPH_OPEN, kernel1, iterations=1)
    gray = cv2.cvtColor(mask, cv2.COLOR_BGR2GRAY)
    # Otsu picks the binarisation threshold from the mask's own histogram.
    thresh = cv2.threshold(gray, 0, 255, cv2.THRESH_BINARY + cv2.THRESH_OTSU)[1]
    num_labels, labels = cv2.connectedComponents(thresh, connectivity=8)

    for label in range(1, num_labels):  # label 0 is the background
        component = np.zeros(thresh.shape, dtype="uint8")
        component[labels == label] = 255
        # [-2] is the contour list on both OpenCV 3 (3 return values) and 4 (2).
        contours = cv2.findContours(component, cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_SIMPLE)[-2]
        # A single 8-connected component has exactly one external contour.
        contour = contours[0]
        if cv2.contourArea(contour) > MIN_CONTOUR_AREA:
            _draw_component(annotated, contour)
    return annotated


def to_rgb(img):
    """Replicate a single-channel mask into 3 channels and scale it to uint8.

    Values are multiplied by 255 and cast to uint8; inputs are expected to lie
    in [0, 1] (the sigmoid output in this app). Works for non-square masks.
    """
    gray = np.asarray(img, dtype=np.float64)
    stacked = np.repeat(gray[:, :, np.newaxis], 3, axis=2)
    return np.uint8(stacked * 255)


device = 'cuda' if torch.cuda.is_available() else 'cpu'
map_location = torch.device(device)

# activation="sigmoid" makes the model's forward pass return probabilities.
model = smp.Unet(
    encoder_name="se_resnext50_32x4d",
    encoder_weights=None,
    classes=1,
    activation="sigmoid",
)
# NOTE: torch.load unpickles the file; only load trusted checkpoints.
state_dict = torch.load("srx50-f0.pth", map_location=map_location)["state_dict"]
model.load_state_dict(state_dict)
# load_state_dict copies into the existing (CPU) parameters, so the model must
# be moved explicitly for inference to actually run on `device`.
model.to(device)
model.eval()


def predict(image_path, postpro):
    """Segment the image at ``image_path`` and save the result as a PNG.

    Args:
        image_path: Path of the input image (read with OpenCV, i.e. BGR).
        postpro: ``"Connected Components Labelling"`` to draw per-component
            boxes on the 768x768 resized input; any other value saves the raw
            probability map scaled to 0-255.

    Returns:
        ``(image_path, output_path)`` for the two output image components.

    Raises:
        ValueError: if OpenCV cannot read ``image_path``.
    """
    im = cv2.imread(image_path)
    if im is None:
        raise ValueError(f"Could not read image: {image_path}")
    image = cv2.resize(im, (768, 768), interpolation=cv2.INTER_AREA)
    batch = np.expand_dims(image, 0) / 255
    tensor = torch.from_numpy(rearrange(batch, 'b h w c -> b c h w')).float().to(device)
    with torch.no_grad():
        # The model already ends in a sigmoid; applying .sigmoid() again would
        # squash every probability into [0.5, 0.73].
        probs = model(tensor)
    result = probs[0].cpu().numpy().squeeze()

    if postpro == "Connected Components Labelling":
        result = cca_analysis(image, to_rgb(result))
    else:
        result = result * 255

    # A unique file per request, so concurrent users don't overwrite each other.
    with tempfile.NamedTemporaryFile(suffix=".png", delete=False) as tmp:
        output_save = tmp.name
    cv2.imwrite(output_save, result)
    return image_path, output_save


title = "Deprem ML - A U-Net model of Open Cities Challenge Winner 1st"
markdown = f'''
[For the official implementation.](https://github.com/qubvel/segmentation_models.pytorch)
It is running on {device}
'''

image_list = [
    "data/1.png",
    "data/2.png",
    "data/3.png",
    "data/4.png",
]
examples = [[path, "Connected Components Labelling"] for path in image_list]

app = gr.Blocks()
with app:
    # A plain "..." literal cannot span physical lines; render the title as a
    # single centred heading instead.
    gr.HTML("<h1 style='text-align: center'>{}</h1>".format(title))
    with gr.Row():
        gr.Markdown(markdown)
    with gr.Row():
        with gr.Column():
            input_image = gr.Image(type='filepath')
            cca = gr.Dropdown(
                value="Connected Components Labelling",
                choices=["Connected Components Labelling", "No Post Process"],
                label="Post Process",
            )
            input_image_button = gr.Button(value="Predict")
        with gr.Column():
            output_orijinal_image = gr.Image(type='filepath')
        with gr.Column():
            output_mask_image = gr.Image(type='filepath')

    gr.Examples(
        examples,
        inputs=[input_image, cca],
        outputs=[output_orijinal_image, output_mask_image],
        fn=predict,
        cache_examples=True,
    )
    input_image_button.click(
        predict,
        inputs=[input_image, cca],
        outputs=[output_orijinal_image, output_mask_image],
    )

app.launch(debug=True)