File size: 1,879 Bytes
a7ce59e
 
 
 
 
 
d299b84
9471efd
c6b5997
 
a7ce59e
 
 
c6b5997
9471efd
a7ce59e
9471efd
c6b5997
 
 
 
 
 
 
 
 
 
a7ce59e
6cb57f7
c6b5997
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
d299b84
 
 
 
 
 
 
176687d
c6b5997
d299b84
c6b5997
 
7aad423
6cb57f7
 
7aad423
 
6cb57f7
 
d299b84
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
import json
from typing import Any, Dict, List

import tensorflow as tf
from tensorflow import keras
import base64
import io
import os
import numpy as np
from PIL import Image



class PreTrainedPipeline():
    """Semantic-segmentation inference pipeline wrapping a Keras model.

    The model is loaded from ``tf_model.h5`` inside the directory passed to
    the constructor. Calling the pipeline runs one forward pass and returns one
    base64-encoded PNG binary mask per output channel (class) of the model.
    """

    # Spatial size (height, width) every input image is resized to before
    # inference. NOTE(review): presumably the model's training resolution —
    # confirm against the model config.
    input_size = (128, 128)

    def __init__(self, path: str):
        """Load the Keras model stored at ``<path>/tf_model.h5``."""
        self.model = keras.models.load_model(os.path.join(path, "tf_model.h5"))

    def __call__(self, inputs: "Image.Image") -> List[Dict[str, Any]]:
        """Segment ``inputs`` and return one binary mask per class.

        Args:
            inputs: A PIL image, or anything ``PIL.Image.open`` accepts (a
                filesystem path or a binary file-like object).

        Returns:
            One dict per model output channel, in class-index order, with keys
            ``"label"`` (``"LABEL_<i>"``), ``"mask"`` (base64-encoded 8-bit
            grayscale PNG: 255 where class ``i`` is the per-pixel argmax, 0
            elsewhere) and ``"score"`` (always ``1.0``; no confidence is
            computed).
        """
        pixels = self._load_pixels(inputs)

        # NOTE(review): assumes the image has the channel count the model was
        # trained on (likely RGB); grayscale/RGBA inputs are not converted.
        im = tf.image.resize(pixels, self.input_size)
        im = tf.cast(im, tf.float32) / 255.0
        pred_mask = self.model.predict(im[tf.newaxis, ...])

        # pred_mask is indexed as (batch, H, W, num_classes); take the single
        # batch item and pick the highest-scoring class per pixel.
        class_map = np.argmax(pred_mask[0], axis=-1)
        masks = self._class_masks(class_map, pred_mask.shape[-1])

        return [
            {
                "label": f"LABEL_{cls}",
                "mask": self._encode_png(mask),
                "score": 1.0,
            }
            for cls, mask in enumerate(masks)
        ]

    @staticmethod
    def _load_pixels(inputs) -> np.ndarray:
        """Return ``inputs`` as a NumPy pixel array.

        PIL images are converted directly; anything else is opened with
        ``PIL.Image.open`` and the file is closed once the pixels are read.
        """
        if isinstance(inputs, Image.Image):
            return np.array(inputs)
        with Image.open(inputs) as img:
            return np.array(img)

    @staticmethod
    def _class_masks(class_map: np.ndarray, num_classes: int) -> List[np.ndarray]:
        """Return one ``uint8`` mask per class: 255 where ``class_map == cls``, else 0.

        Works for any 2-D ``class_map`` shape (height need not equal width).
        """
        return [
            np.where(class_map == cls, 255, 0).astype(np.uint8)
            for cls in range(num_classes)
        ]

    @staticmethod
    def _encode_png(mask: np.ndarray) -> str:
        """Encode a 2-D ``uint8`` array as a base64 PNG string.

        A 2-D ``uint8`` array becomes an 8-bit grayscale ("L") image. ``uint8``
        is required: the former ``int8`` cast overflowed 255 to -1.
        """
        with io.BytesIO() as out:
            Image.fromarray(mask).save(out, format="PNG")
            return base64.b64encode(out.getvalue()).decode("utf-8")