# satellite-test / app.py (author: Jayem-11)
# Commit 286c232: "Update app.py" (2.12 kB)
import os
import cv2
from PIL import Image
import numpy as np
from patchify import patchify
from sklearn.preprocessing import MinMaxScaler, StandardScaler
from matplotlib import pyplot as plt
import random
from keras.utils import to_categorical
from keras import backend as K
import gradio as gr
def jaccard_coef(y_true, y_pred, smooth=1.0):
    """Return the smoothed Jaccard index (IoU) of two tensors over all elements.

    Passed to ``load_model`` as a custom object so the metric name stored in
    the saved model can be resolved.

    Args:
        y_true: Ground-truth tensor.
        y_pred: Prediction tensor with the same number of elements as y_true.
        smooth: Constant added to both numerator and denominator so the ratio
            stays defined (and equals 1) when both tensors are all zeros.

    Returns:
        Scalar tensor ``(|A & B| + smooth) / (|A | B| + smooth)``.
    """
    y_true_flatten = K.flatten(y_true)
    y_pred_flatten = K.flatten(y_pred)
    # Keep the smoothing term out of the raw intersection: adding it here and
    # then subtracting the intersection inside the union would cancel the
    # denominator's smoothing and allow a division by zero on empty masks.
    intersection = K.sum(y_true_flatten * y_pred_flatten)
    # |A | B| = |A| + |B| - |A & B|
    union = K.sum(y_true_flatten) + K.sum(y_pred_flatten) - intersection
    return (intersection + smooth) / (union + smooth)
# The loss objects below are used here only to populate `custom_objects`, so
# that `load_model` can resolve the custom loss/metric names saved in the file.
import segmentation_models as sm
from keras.models import load_model

# Six equal class weights (~1/6 each) for the Dice component of the loss.
weights = [0.166] * 6
dice_loss = sm.losses.DiceLoss(class_weights=weights)
focal_loss = sm.losses.CategoricalFocalLoss()
# Presumably segmentation_models names this combined loss
# 'dice_loss_plus_1focal_loss' — the key used below; verify if loading fails.
total_loss = dice_loss + (1 * focal_loss)

saved_model = load_model(
    'model/satellite_segmentation_full.h5',
    custom_objects={
        'dice_loss_plus_1focal_loss': total_loss,
        'jaccard_coef': jaccard_coef,
    },
)
def process_input_image(image_source, model=None):
    """Run the segmentation model on one image and return a displayable mask.

    Args:
        image_source: A single image as a NumPy array (as delivered by the
            ``gr.Image`` input), or ``None`` when no image has been selected.
        model: Object with a Keras-style ``predict`` method. Defaults to the
            module-level ``saved_model``; exposed so another model (or a stub)
            can be supplied.

    Returns:
        A ``(label, mask)`` pair for the ``gr.Label`` / ``gr.Image`` outputs.
        ``mask`` is a 2-D ``uint8`` array holding ``argmax_index * 50`` per
        pixel (clipped to 255), or ``None`` when ``image_source`` is ``None``.
    """
    if image_source is None:
        # The button can be clicked before any image is loaded; without this
        # guard predict() would receive an object array and fail obscurely.
        return 'Please select a source image first', None
    if model is None:
        model = saved_model
    # NOTE(review): raw pixel values go straight to the model. If it was
    # trained on scaled inputs, the same scaling is needed here — confirm
    # against the training pipeline.
    batch = np.expand_dims(image_source, 0)
    prediction = model.predict(batch)
    # Index of the highest-scoring channel for every pixel of the single
    # batch item.
    class_map = np.argmax(prediction, axis=-1)[0]
    # Spread indices across grey levels so they are visible. argmax yields
    # int64; clip and cast so the output is a proper 8-bit image without
    # wrap-around if there are more than six channels.
    predicted_image = np.clip(class_map * 50, 0, 255).astype(np.uint8)
    return 'Predicted Masked Image', predicted_image
# Gradio UI: a single tab with an image picker and a button on the left, and
# the prediction label and mask image on the right.
with gr.Blocks() as my_app:
    gr.Markdown("Satellite Image Segmentation Application UI with Gradio")
    with gr.Tabs():
        with gr.TabItem("Select your image"):
            with gr.Row():
                with gr.Column():
                    # Per the Gradio 3.x docs, `shape` crops/resizes the
                    # upload to 256x256 before it reaches the handler.
                    img_source = gr.Image(
                        label="Please select source Image", shape=(256, 256)
                    )
                    source_image_loader = gr.Button("Load above Image")
                with gr.Column():
                    output_label = gr.Label(label="Image Info")
                    img_output = gr.Image(label="Image Output")
    # Clicking the button sends the selected image to process_input_image and
    # routes its (label, mask) return values to the two output components.
    # Registered inside the Blocks context, where Gradio expects listeners.
    source_image_loader.click(
        process_input_image,
        [img_source],
        [output_label, img_output],
    )

my_app.launch(debug=True)