|
import cv2
import gradio as gr
import numpy as np
import segmentation_models_pytorch as smp
import torch
from einops import rearrange
from imutils import perspective


def midpoint(ptA, ptB):
    # Midpoint of the segment between two (x, y) points.
    return ((ptA[0] + ptB[0]) * 0.5, (ptA[1] + ptB[1]) * 0.5)


# Structuring element for the morphological opening in cca_analysis.
kernel1 = np.ones((5, 5), dtype=np.uint8)
|
|
def cca_analysis(image, predicted_mask):
    # Label connected components in the predicted mask and draw an oriented
    # bounding box (plus edge midpoints) for each detected building.
    image2 = np.asarray(image)
    # Resize the mask to the input image's (width, height).
    mask = cv2.resize(predicted_mask, (image2.shape[1], image2.shape[0]),
                      interpolation=cv2.INTER_AREA)

    # Morphological opening removes small speckles before labelling.
    mask = cv2.morphologyEx(mask, cv2.MORPH_OPEN, kernel1, iterations=1)
    gray = cv2.cvtColor(mask, cv2.COLOR_BGR2GRAY)
    thresh = cv2.threshold(gray, 0, 255, cv2.THRESH_BINARY + cv2.THRESH_OTSU)[1]

    labels = cv2.connectedComponents(thresh, connectivity=8)[1]
    count = 0  # number of kept components (for debugging; not returned)
    for label in np.unique(labels):
        if label == 0:  # 0 is the background label
            continue

        component = np.zeros(thresh.shape, dtype="uint8")
        component[labels == label] = 255

        cnts, _ = cv2.findContours(component, cv2.RETR_EXTERNAL,
                                   cv2.CHAIN_APPROX_SIMPLE)
        if not cnts:
            continue
        cnt = cnts[0]
        c_area = cv2.contourArea(cnt)

        # Skip tiny components (likely noise).
        if c_area > 100:
            count += 1

            rect = cv2.minAreaRect(cnt)
            box = cv2.boxPoints(rect)
            box = perspective.order_points(np.array(box, dtype="int"))
            color1 = list(np.random.choice(range(150), size=3))
            color = [int(color1[0]), int(color1[1]), int(color1[2])]
            cv2.drawContours(image2, [box.astype("int")], 0, color, 2)
            (tl, tr, br, bl) = box

            # Midpoints of the four box edges.
            (tltrX, tltrY) = midpoint(tl, tr)
            (blbrX, blbrY) = midpoint(bl, br)
            (tlblX, tlblY) = midpoint(tl, bl)
            (trbrX, trbrY) = midpoint(tr, br)

            cv2.circle(image2, (int(tltrX), int(tltrY)), 5, (255, 0, 0), -1)
            cv2.circle(image2, (int(blbrX), int(blbrY)), 5, (255, 0, 0), -1)
            cv2.circle(image2, (int(tlblX), int(tlblY)), 5, (255, 0, 0), -1)
            cv2.circle(image2, (int(trbrX), int(trbrY)), 5, (255, 0, 0), -1)
            cv2.line(image2, (int(tltrX), int(tltrY)), (int(blbrX), int(blbrY)), color, 2)
            cv2.line(image2, (int(tlblX), int(tlblY)), (int(trbrX), int(trbrY)), color, 2)
    return image2
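# Note: cv2.connectedComponentsWithStats could supply per-component areas and
# axis-aligned bounding boxes directly, replacing the per-label findContours
# pass above; minAreaRect is used here because it yields oriented boxes.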
|
|
def to_rgb(img):
    # Stack a single-channel float mask in [0, 1] into a 3-channel uint8 image.
    result_new = np.zeros((img.shape[0], img.shape[1], 3))
    result_new[:, :, 0] = img
    result_new[:, :, 1] = img
    result_new[:, :, 2] = img
    return np.uint8(result_new * 255)
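# An equivalent, more compact sketch using np.stack:
#   def to_rgb(img): return np.uint8(np.stack([img] * 3, axis=-1) * 255)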
|
|
# U-Net with an SE-ResNeXt50 encoder. encoder_weights=None because all weights
# are restored from the checkpoint below; activation="sigmoid" means the model
# already outputs probabilities in [0, 1].
model = smp.Unet(
    encoder_name="se_resnext50_32x4d",
    encoder_weights=None,
    classes=1,
    activation="sigmoid",
)

device = "cuda" if torch.cuda.is_available() else "cpu"
map_location = torch.device(device)

state_dict = torch.load("srx50-f0.pth", map_location=map_location)["state_dict"]
model.load_state_dict(state_dict)
model.to(device)
model.eval()
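# Assumption: "srx50-f0.pth" sits next to this script and stores its weights
# under a "state_dict" key (the layout used by PyTorch Lightning checkpoints).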
|
|
|
def predict(image_path, postpro):
    output_save = "output.png"
    im = cv2.imread(image_path)
    # The network expects 768x768 inputs.
    image = cv2.resize(im, (768, 768), interpolation=cv2.INTER_AREA)
    im = np.expand_dims(image, 0) / 255

    # HWC -> CHW, then run inference without tracking gradients.
    batch = torch.from_numpy(rearrange(im, "b h w c -> b c h w")).float().to(device)
    with torch.no_grad():
        # activation="sigmoid" is baked into the model, so the output is
        # already a probability map; applying sigmoid again would be a bug.
        probs = model(batch)
    result = probs[0].cpu().numpy().squeeze()

    if postpro == "Connected Components Labelling":
        result = to_rgb(result)
        result = cca_analysis(image, result)
    else:
        result = (result * 255).astype(np.uint8)

    cv2.imwrite(output_save, result)
    return image_path, output_save
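# Quick sanity check (a sketch; assumes the checkpoint and sample image exist):
#   original_path, mask_path = predict("data/1.png", "No Post Process")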
|
|
title = "Deprem ML - A U-Net Model from the Open Cities Challenge 1st Place Winner"
|
|
|
markdown = f'''
[Official segmentation_models.pytorch implementation.](https://github.com/qubvel/segmentation_models.pytorch)

Running on **{device}**.
'''
|
|
image_list = [
    "data/1.png",
    "data/2.png",
    "data/3.png",
    "data/4.png",
]
# Every example uses the default post-processing option.
examples = [[path, "Connected Components Labelling"] for path in image_list]
|
|
|
app = gr.Blocks()
with app:
    gr.HTML("<h1 style='text-align: center'>{}</h1>".format(title))

    with gr.Row():
        gr.Markdown(markdown)
    with gr.Row():
        with gr.Column():
            input_image = gr.Image(type="filepath")
            cca = gr.Dropdown(
                value="Connected Components Labelling",
                choices=["Connected Components Labelling", "No Post Process"],
                label="Post Process",
            )
            predict_button = gr.Button(value="Predict")
        with gr.Column():
            output_original_image = gr.Image(type="filepath")
        with gr.Column():
            output_mask_image = gr.Image(type="filepath")

    gr.Examples(
        examples,
        inputs=[input_image, cca],
        outputs=[output_original_image, output_mask_image],
        fn=predict,
        cache_examples=True,
    )
    predict_button.click(predict, inputs=[input_image, cca],
                         outputs=[output_original_image, output_mask_image])

app.launch(debug=True)
|
|