Spaces:
Runtime error
Runtime error
import gradio as gr | |
from PIL import Image | |
from patchify import patchify, unpatchify | |
import numpy as np | |
from skimage.io import imshow, imsave | |
import tensorflow | |
import tensorflow as tf | |
from tensorflow.keras import backend as K | |
def jacard(y_true, y_pred):
    """Return the smoothed Jaccard index (intersection over union) of two tensors.

    Both inputs are flattened before comparison, so any shape works as long as
    the flattened tensors line up element for element. The overlap is the sum
    of the element-wise product, which makes this a "soft" Jaccard when
    ``y_pred`` holds probabilities. A constant 1.0 is added to both numerator
    and denominator, so two all-zero inputs give 1.0 instead of 0/0.

    The spelling of the name is load-bearing: the function is registered under
    the key ``"jacard"`` in the ``custom_objects`` used to load the model.
    """
    smooth = 1.0
    flat_true = K.flatten(y_true)
    flat_pred = K.flatten(y_pred)
    overlap = K.sum(flat_true * flat_pred)
    union = K.sum(flat_true) + K.sum(flat_pred) - overlap
    return (overlap + smooth) / (union + smooth)
def bce_dice(y_true, y_pred):
    """Combined segmentation loss: binary cross-entropy plus ``-log(jaccard)``.

    Despite the name, the overlap term is the smoothed Jaccard index computed
    by ``jacard``, not the Dice coefficient. The ``-log`` term is 0 when the
    index is 1 (perfect overlap) and grows as the overlap shrinks.
    The cross-entropy uses a default-constructed Keras ``BinaryCrossentropy``.
    """
    cross_entropy = tf.keras.losses.BinaryCrossentropy()(y_true, y_pred)
    overlap_penalty = -K.log(jacard(y_true, y_pred))
    return cross_entropy + overlap_penalty
# Inputs are resized to a size x size square before being cut into
# pach_size x pach_size patches for the model. (The spelling "pach_size" is
# kept because predict_2 refers to it by that name.)
size, pach_size = 1024, 256
def predict_2(image):
    """Segment ``image`` and return a colour-coded class mask for display.

    Pipeline:
    1. Resize the image to ``size`` x ``size`` RGB.
    2. Cut it into ``pach_size`` x ``pach_size`` patches taken every
       ``pach_size // 2`` pixels, so neighbouring patches overlap by half a
       patch.
    3. Classify every pixel of every patch with ``model`` in a single batched
       ``predict`` call.
    4. Stitch the per-patch class maps back together with ``unpatchify``.
    5. Paint each class id with its colour from ``targets_classes_colors``.

    Args:
        image: The array handed over by the Gradio ``Image`` input, or ``None``
            when the button is pressed before an image is selected. Grayscale
            and RGBA inputs are converted to 3-channel RGB.

    Returns:
        A ``(text, mask)`` pair for the ``Label`` and ``Image`` outputs.
        ``mask`` is a ``(size, size, 3)`` ``uint8`` RGB array, or ``None`` when
        no image was supplied.
    """
    if image is None:
        return 'No image provided', None

    # Force RGB so the (pach_size, pach_size, 3) patch shape below always fits;
    # a 2-D grayscale array would not match a 3-D patch shape.
    resized = Image.fromarray(image).convert("RGB").resize((size, size))
    image = np.array(resized)

    # A step of half a patch gives overlapping patches (a step equal to
    # pach_size would give non-overlapping tiles).
    step = pach_size // 2
    patches = patchify(image, (pach_size, pach_size, 3), step=step)
    # patchify also windows the channel axis; with exactly 3 channels and a
    # 3-deep window there is a single position along it.
    patches = patches[:, :, 0]
    n_rows, n_cols = patches.shape[:2]

    # Predict all patches in one batched call instead of one call per patch.
    # Row-major flattening keeps the (row, col) order used by the reshape below.
    batch = patches.reshape(-1, pach_size, pach_size, 3).astype(np.float32) / 255.0
    probabilities = model.predict(batch)
    # Per-pixel class id = arg-max over the last axis of the model output.
    class_maps = np.argmax(probabilities, axis=-1)
    class_maps = class_maps.reshape(n_rows, n_cols, pach_size, pach_size)

    # NOTE(review): overlapping regions are not averaged; unpatchify keeps a
    # single patch's values there. Confirm this is the intended blending.
    mask = unpatchify(class_maps, image.shape[:2])

    # The colour table uses NumPy's default (wider than 8-bit) integer dtype;
    # cast so the image output receives a standard uint8 RGB array.
    return 'Predicted Masked Image', targets_classes_colors[mask].astype(np.uint8)
# Display colour (RGB) per class id: row k is painted on every pixel predicted
# as class k (predict_2 indexes this table with the per-pixel arg-max map).
# NOTE(review): this table has 24 rows (ids 0-23) while class_weights lists 23
# classes (ids 0-22); confirm how many classes the model actually outputs.
targets_classes_colors = np.array([
    (0, 0, 0),        # 0
    (128, 64, 128),   # 1
    (130, 76, 0),     # 2
    (0, 102, 0),      # 3
    (112, 103, 87),   # 4
    (28, 42, 168),    # 5
    (48, 41, 30),     # 6
    (0, 50, 89),      # 7
    (107, 142, 35),   # 8
    (70, 70, 70),     # 9
    (102, 102, 156),  # 10
    (254, 228, 12),   # 11
    (254, 148, 12),   # 12
    (190, 153, 153),  # 13
    (153, 153, 153),  # 14
    (255, 22, 96),    # 15
    (102, 51, 0),     # 16
    (9, 143, 150),    # 17
    (119, 11, 32),    # 18
    (51, 51, 0),      # 19
    (190, 250, 190),  # 20
    (112, 150, 146),  # 21
    (2, 135, 115),    # 22
    (255, 0, 0),      # 23
])
# Per-class weights keyed by class id (0-22). The wcce loss multiplies each
# position by the weight of its true class, so larger values make that class
# count more during training; 1.0 is the baseline.
class_weights = {
    0: 1.0,
    1: 1.0,
    2: 2.171655596616696,
    3: 1.0,
    4: 1.0,
    5: 2.2101197049812593,
    6: 11.601519937899578,
    7: 7.99072122367673,
    8: 1.0,
    9: 1.0,
    10: 2.5426918173402457,
    11: 11.187574445057574,
    12: 241.57620214903147,
    13: 9.234779790464515,
    14: 1077.2745952165694,
    15: 7.396021659003857,
    16: 855.6730643687165,
    17: 6.410869993189135,
    18: 42.0186736125025,
    19: 2.5648760196752947,
    20: 4.089194047656931,
    21: 27.984593442818955,
    22: 2.0509251319694712,
}
# The same weights as a plain list: position i holds the weight of class i.
weight_list = [class_weights[class_id] for class_id in sorted(class_weights)]
def weighted_categorical_crossentropy(weights=None):
    """Build the class-weighted ``wcce`` loss.

    Args:
        weights: Per-class weights, position ``i`` holding the weight of class
            ``i``. When omitted or ``None``, the module-level ``weight_list``
            is used. (Previously the argument was unconditionally overwritten
            with ``weight_list``, so callers' weights were silently ignored.)

    Returns:
        A loss function ``wcce(y_true, y_pred)``. It returns ``bce_dice``
        multiplied by the weight of each position's true class.
        The inner function must keep the name ``wcce``: that is the key used
        for this loss in the ``custom_objects`` passed to ``load_model``.
    """
    if weights is None:
        weights = weight_list

    def wcce(y_true, y_pred):
        Kweights = K.constant(weights)
        if not tf.is_tensor(y_pred):
            y_pred = K.constant(y_pred)
        y_true = K.cast(y_true, y_pred.dtype)
        # Sum over the last axis picks out the true class's weight, assuming
        # y_true is one-hot along that axis (TODO confirm).
        # NOTE(review): bce_dice appears to reduce to a single value, so this
        # scales one batch-level loss by per-position weights rather than
        # weighting per-pixel cross-entropy. Confirm this is intended.
        return bce_dice(y_true, y_pred) * K.sum(y_true * Kweights, axis=-1)

    return wcce
# Load the trained network from model.h5. The saved model references the custom
# objects "jacard" and "wcce", and Keras needs both callables to deserialize it.
# "wcce" must be the loss function itself (the closure returned by the
# factory), not the factory, which takes a weights list rather than
# (y_true, y_pred).
model = tf.keras.models.load_model(
    "model.h5",
    custom_objects={
        "jacard": jacard,
        "wcce": weighted_categorical_crossentropy(weight_list),
    },
)
# Build the Gradio UI: one tab with a source image and a button. Clicking the
# button runs predict_2 and shows its text in a Label and its mask in an Image.
my_app = gr.Blocks()
with my_app:
    gr.Markdown("Satellite Image Segmentation Application UI with Gradio")
    with gr.Tabs():
        with gr.TabItem("Select your image"):
            with gr.Row():
                with gr.Column():
                    img_source = gr.Image(label="Please select source Image")
                    source_image_loader = gr.Button("Load above Image")
                with gr.Column():
                    output_label = gr.Label(label="Image Info")
                    img_output = gr.Image(label="Image Output")
    # predict_2 returns (text, mask), matching the order of these outputs.
    source_image_loader.click(
        predict_2,
        inputs=[img_source],
        outputs=[output_label, img_output],
    )

# debug=True keeps launch() in the foreground; close() runs once it returns.
my_app.launch(debug=True, share=True)
my_app.close()