# Spaces: Runtime error
import tensorflow as tf | |
import numpy as np | |
import gradio as gr | |
import matplotlib.pyplot as plt | |
from huggingface_hub import from_pretrained_keras | |
import os | |
import sys | |
# Fetch the pretrained segmentation model from the Hugging Face Hub once at
# start-up; compile=False is passed because the model is only called for
# inference in this script.
print('Loading model...')
model = from_pretrained_keras(
    "mostafapasha/ribs-segmentation-model",
    compile=False,
)
print('Successfully loaded model...')

# Demo rows for the UI: each entry is [image path, segmentation threshold],
# in the same order as the inputs passed to the interface.
examples = [
    ['examples/VinDr_RibCXR_train_056.png', 0.2],
    ['examples/VinDr_RibCXR_train_179.png', 0.8],
]
def infer(img, threshold):
    """Run the segmentation model on one image and return a binary mask.

    Args:
        img: Image array. A 2-D array is used as is; for any other rank,
            channel index 1 of the third axis is selected.
        threshold: Cut-off applied to the sigmoid probabilities; values
            strictly greater than it become 1.0, all others 0.0.

    Returns:
        A 2-D float32 NumPy array of 0.0/1.0 values taken from the first
        batch item and first channel of the model output.
    """
    # Reduce multi-channel input to a single plane (channel index 1).
    plane = img if np.ndim(img) == 2 else img[:, :, 1]

    # Add leading batch and trailing channel axes: (1, rows, cols, 1).
    batch = plane.reshape(1, plane.shape[0], plane.shape[1], 1)

    # The model output is treated as logits and squashed with a sigmoid.
    # NOTE(review): assumes the model returns a 4-D tensor — confirm.
    probabilities = tf.sigmoid(model(batch, training=False))
    mask = tf.cast(probabilities > threshold, dtype=tf.float32)

    # Drop the batch and channel axes before handing the mask back.
    return np.array(mask.numpy())[0, :, :, 0]
# NOTE(review): gr.inputs / gr.outputs are the legacy Gradio component
# namespaces — confirm the installed gradio version still provides them.

# Input widgets, in the order `infer` receives them: the image (delivered as
# a NumPy array) followed by the threshold slider.
gr_input = [
    gr.inputs.Image(label="Image", type="numpy", shape=(512, 512)),
    gr.inputs.Slider(
        minimum=0,
        maximum=1,
        step=0.05,
        default=0.5,
        label="Segmentation Threshold",
    ),
]

# Single output widget showing the mask returned by `infer`.
gr_output = [
    gr.outputs.Image(type="pil", label="Segmentation Mask"),
]

# Build the interface and start serving immediately. `iface` is bound to the
# value returned by launch(), not to the Interface object itself.
iface = gr.Interface(
    fn=infer,
    title='ribs segmentation model',
    description='Keras implementation of ResUNET++ for xray ribs segmentation',
    inputs=gr_input,
    outputs=gr_output,
    examples=examples,
    flagging_dir="flagged",
).launch()