# Author: SerdarHelli
# Commit 1ecf700: "Update app.py"
import gradio as gr
from einops import rearrange
import segmentation_models_pytorch as smp
import numpy as np
import cv2
import torch
from imutils import perspective
def midpoint(ptA, ptB):
    """Return the point halfway between ``ptA`` and ``ptB``.

    Both points are indexed positionally (``[0]`` and ``[1]``); the result is a
    2-tuple holding the averaged first and second coordinates.
    """
    first_a, second_a = ptA[0], ptA[1]
    first_b, second_b = ptB[0], ptB[1]
    return ((first_a + first_b) * 0.5, (second_a + second_b) * 0.5)
# Shared image-processing kernels.
# 5x5 all-ones structuring element; cca_analysis uses it for the morphological
# opening applied to the predicted mask.
kernel1 = np.ones((5, 5), dtype=np.float32)
# NOTE(review): not referenced anywhere else in this file; kept for compatibility.
blur_radius = 0.5
# 3x3 sharpening kernel (centre 9, neighbours -1) scaled by 1/9.
# NOTE(review): also not referenced elsewhere in this file.
kernel_sharpening = (1 / 9) * np.array([
    [-1, -1, -1],
    [-1, 9, -1],
    [-1, -1, -1],
])
def _draw_measurements(canvas, contour):
    """Draw the rotated bounding box of ``contour`` and its two midlines on ``canvas``.

    The box and midlines use a random colour with each channel drawn from
    ``range(150)``; the four edge midpoints are marked with filled circles of
    colour ``(255, 0, 0)``. ``canvas`` is modified in place.
    """
    rect = cv2.minAreaRect(contour)
    box = cv2.boxPoints(rect)
    box = np.array(box, dtype="int")
    box = perspective.order_points(box)
    color = [int(channel) for channel in np.random.choice(range(150), size=3)]
    cv2.drawContours(canvas, [box.astype("int")], 0, color, 2)

    (tl, tr, br, bl) = box
    # Midpoints of the top and bottom edges: endpoints of the first midline.
    (tltrX, tltrY) = midpoint(tl, tr)
    (blbrX, blbrY) = midpoint(bl, br)
    # Midpoints of the left and right edges: endpoints of the second midline.
    (tlblX, tlblY) = midpoint(tl, bl)
    (trbrX, trbrY) = midpoint(tr, br)

    for x, y in ((tltrX, tltrY), (blbrX, blbrY), (tlblX, tlblY), (trbrX, trbrY)):
        cv2.circle(canvas, (int(x), int(y)), 5, (255, 0, 0), -1)
    cv2.line(canvas, (int(tltrX), int(tltrY)), (int(blbrX), int(blbrY)), color, 2)
    cv2.line(canvas, (int(tlblX), int(tlblY)), (int(trbrX), int(trbrY)), color, 2)


def cca_analysis(image, predicted_mask, min_area=100):
    """Outline every connected region of a predicted mask on a copy of ``image``.

    Args:
        image: image to annotate (array-like, H x W x C). It is copied and the
            caller's array is left untouched.
        predicted_mask: 3-channel uint8 mask (as produced by ``to_rgb``). It is
            resized to ``image``'s height and width before analysis.
        min_area: only components whose contour area is strictly greater than
            this value are drawn (default 100, the previous hard-coded value).

    Returns:
        The annotated copy of ``image``. Each kept component gets a rotated
        bounding box, its four edge midpoints and the two midlines joining them.
    """
    canvas = np.array(image)  # private, writable copy to draw on
    height, width = canvas.shape[:2]
    # cv2.resize takes the target size as (width, height).
    mask = cv2.resize(predicted_mask, (width, height), interpolation=cv2.INTER_AREA)
    # Opening with the 5x5 kernel removes small specks before thresholding.
    mask = cv2.morphologyEx(mask, cv2.MORPH_OPEN, kernel1, iterations=1)
    gray = cv2.cvtColor(mask, cv2.COLOR_BGR2GRAY)
    thresh = cv2.threshold(gray, 0, 255, cv2.THRESH_BINARY + cv2.THRESH_OTSU)[1]
    labels = cv2.connectedComponents(thresh, connectivity=8)[1]

    for label in np.unique(labels):
        if label == 0:  # label 0 is the background
            continue
        component = np.zeros(thresh.shape, dtype="uint8")
        component[labels == label] = 255
        # [-2] is the contour list under both OpenCV 3.x (3 return values)
        # and OpenCV 4.x (2 return values).
        contours = cv2.findContours(
            component, cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_SIMPLE
        )[-2]
        contour = contours[0]
        if cv2.contourArea(contour) > min_area:
            _draw_measurements(canvas, contour)
    return canvas
def to_rgb(img):
    """Replicate a single-channel map into a 3-channel ``uint8`` image.

    Args:
        img: 2-D array of shape ``(H, W)``; values are assumed to lie in
            [0, 1] (in ``predict`` it is the model's squeezed output).

    Returns:
        ``uint8`` array of shape ``(H, W, 3)``. Every channel equals
        ``img * 255`` truncated toward zero.
    """
    # Scale in float64 (as the original float64 buffer did) so that
    # multiplying by 255 is exact before the truncating uint8 cast.
    gray = np.asarray(img, dtype=np.float64)
    # Stack along a new last axis so the (H, W) layout is preserved for
    # non-square inputs too.
    stacked = np.stack((gray, gray, gray), axis=-1)
    return np.uint8(stacked * 255)
# Pick the compute device once; `device` is also shown in the UI markdown.
device = 'cuda' if torch.cuda.is_available() else 'cpu'
map_location = torch.device(device)

# Single-class U-Net with an SE-ResNeXt50 encoder and a sigmoid output
# activation. No pretrained encoder weights are requested
# (encoder_weights=None); all weights come from the checkpoint below.
model = smp.Unet(
    encoder_name="se_resnext50_32x4d",
    encoder_weights=None,
    classes=1,
    activation="sigmoid",
)

# NOTE(review): torch.load unpickles the file, so only load trusted checkpoints.
state_dict = torch.load(
    "srx50-f0.pth",
    map_location=map_location,
)["state_dict"]
model.load_state_dict(state_dict)
# NOTE(review): the model itself is never moved to `device`, so it stays where it
# was constructed even though the UI reports `device`. Moving it here would also
# require predict() to place its input on the same device.
model.eval()
def predict(image_path, postpro):
    """Segment an image and write the rendered result to ``output.png``.

    Args:
        image_path: path of the image to segment, read with ``cv2.imread``.
        postpro: ``"Connected Components Labelling"`` to outline each detected
            region on the 768x768 resized input. Any other value writes the
            raw probability mask scaled to 0-255.

    Returns:
        ``(image_path, output_path)``: the unchanged input path and the path of
        the written PNG.

    Raises:
        ValueError: if no path is given or the file cannot be read as an image.
    """
    output_save = "output.png"
    # cv2.imread signals failure by returning None rather than raising.
    bgr = cv2.imread(image_path) if image_path else None
    if bgr is None:
        raise ValueError(f"Could not read an image from {image_path!r}")
    image = cv2.resize(bgr, (768, 768), interpolation=cv2.INTER_AREA)

    batch = np.expand_dims(image, 0) / 255
    tensor = torch.from_numpy(rearrange(batch, 'b h w c -> b c h w')).float()
    # Feed the input to whichever device currently holds the model weights.
    model_device = next(model.parameters()).device
    with torch.no_grad():  # inference only: no autograd graph needed
        # The model is constructed with activation="sigmoid", so its output is
        # already a probability map. Applying .sigmoid() again would squash it
        # into roughly [0.5, 0.73].
        probs = model(tensor.to(model_device))
    result = probs[0].cpu().numpy().squeeze()

    if postpro == "Connected Components Labelling":
        result = cca_analysis(image, to_rgb(result))
    else:
        # Explicit uint8 conversion (round-half-even, saturating), matching what
        # cv2.imwrite would otherwise do implicitly for a float array.
        result = np.clip(np.rint(result * 255), 0, 255).astype(np.uint8)
    cv2.imwrite(output_save, result)
    return image_path, output_save
# Page header (rendered as a centred <h1>) and the markdown note under it.
title = "Deprem ML - A U-Net model of Open Cities Challenge Winner 1st "
markdown = f'''
[For the official implementation.](https://github.com/qubvel/segmentation_models.pytorch)
It is running on {device}
'''

# Dropdown value that makes predict() run connected-component post-processing.
CCA_CHOICE = "Connected Components Labelling"

image_list = [
    "data/1.png",
    "data/2.png",
    "data/3.png",
    "data/4.png",
]
# Every bundled example runs with connected-component post-processing.
examples = [[path, CCA_CHOICE] for path in image_list]

app = gr.Blocks()
with app:
    gr.HTML(f"<h1 style='text-align: center'>{title}</h1>")
    with gr.Row():
        gr.Markdown(markdown)
    with gr.Row():
        with gr.Column():
            input_image = gr.Image(type='filepath')
            cca = gr.Dropdown(
                value=CCA_CHOICE,
                choices=[CCA_CHOICE, "No Post Process"],
                label="Post Process",
            )
            predict_button = gr.Button(value="Predict")
        with gr.Column():
            output_original_image = gr.Image(type='filepath')
        with gr.Column():
            output_mask_image = gr.Image(type='filepath')
    # cache_examples=True has Gradio precompute predict() outputs for the examples.
    gr.Examples(
        examples,
        inputs=[input_image, cca],
        outputs=[output_original_image, output_mask_image],
        fn=predict,
        cache_examples=True,
    )
    predict_button.click(
        predict,
        inputs=[input_image, cca],
        outputs=[output_original_image, output_mask_image],
    )

app.launch(debug=True)