File size: 2,616 Bytes
fd15625
d0bbc42
caa5920
62df097
 
 
 
 
 
 
b18464f
4e95af2
d0bbc42
62df097
 
4240815
62df097
 
 
 
52248eb
 
78ebf61
 
 
52248eb
 
78ebf61
 
52248eb
78ebf61
 
 
52248eb
78ebf61
 
52248eb
78ebf61
52248eb
78ebf61
52248eb
78ebf61
52248eb
78ebf61
52248eb
78ebf61
52248eb
78ebf61
 
 
52248eb
78ebf61
 
52248eb
78ebf61
399fb9d
 
 
6fc2a9e
78ebf61
62df097
 
e256fb4
fd15625
 
 
 
 
 
 
 
 
 
62df097
843c579
62df097
cf9d3c1
 
62df097
 
 
 
 
 
 
 
 
fd15625
62df097
8b41375
62df097
 
fd15625
 
09e7a06
62df097
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
from turtle import title
import gradio as gr
from huggingface_hub import from_pretrained_keras
import tensorflow as tf
import numpy as np
from PIL import Image
import io
import base64


# Load the trained segmentation model once, at import time; predict() reuses it
# for every request instead of reloading the weights.
model = tf.keras.models.load_model(filepath="./tf_model.h5")


def predict(image):
    """Run the segmentation model on ``image`` and return one class mask.

    Args:
        image: the uploaded image (the Gradio input is declared with
            ``type='pil'``); it is converted with ``np.array`` before use.

    Returns:
        A 2-D ``uint8`` NumPy array with the same height and width as
        ``image``. Pixels whose highest-scoring class is the *last* model
        output class are 255, all others 0. Intermediate values can appear
        along mask edges because ``tf.image.resize`` interpolates bilinearly
        by default when scaling the mask back to the input size.
    """
    img = np.array(image)
    original_shape = img.shape[:2]  # (height, width) used to restore the size

    # The model is fed a 128x128 image with pixel values divided by 255.
    im = tf.image.resize(img, (128, 128))
    im = tf.cast(im, tf.float32) / 255.0
    pred_mask = model.predict(im[tf.newaxis, ...])

    # Index of the best-scoring class for every pixel of the single batch item,
    # as an (height, width) integer array. Using the array's own shape (rather
    # than the length of one row) keeps this correct for non-square outputs.
    labels = tf.argmax(pred_mask, axis=-1)[0].numpy()

    # Only the last class's binary mask is returned; the previous per-pixel loop
    # built a mask for every class but kept only the final one.
    # NOTE(review): confirm the last output class is the one meant to be shown.
    last_class = pred_mask.shape[-1] - 1
    mask = np.where(labels == last_class, 255, 0).astype(np.uint8)

    # Scale the mask back to the uploaded image's size and return a 2-D array.
    mask = tf.image.resize(mask[..., tf.newaxis], original_shape)
    mask = tf.cast(mask, tf.uint8)
    return mask.numpy().squeeze()
    

# --- Page text -------------------------------------------------------------
title = '<h1 style="text-align: center;">Segment Pets</h1>'

description = """
## About
This space demonstrates the use of a semantic segmentation model to segment pets and classify them 
according to the pixels.


## 🚀 To run
Upload a pet image and hit submit or select one from the given examples
"""

# --- Widgets ---------------------------------------------------------------
# A single required image upload, handed to predict() as a PIL image.
inputs = gr.inputs.Image(type='pil', optional=False, label="Upload a pet image")

# predict() returns one array, rendered as an image.
outputs = [gr.outputs.Image(label="Segmentation")]

# Sample images offered below the upload widget.
examples = ["./examples/cat_1.jpg", "./examples/cat_2.jpg",
            "./examples/dog_1.jpg", "./examples/dog_2.jpg"]

# --- App -------------------------------------------------------------------
interface = gr.Interface(
    fn=predict,
    inputs=inputs,
    outputs=outputs,
    title=title,
    description=description,
    examples=examples,
)
interface.launch()