semantic-segmentation / pipeline.py
merve's picture
merve HF staff
Update pipeline.py
176687d
raw
history blame
No virus
1.88 kB
import json
from typing import Any, Dict, List
import tensorflow as tf
from tensorflow import keras
import base64
import io
import os
import numpy as np
from PIL import Image
class PreTrainedPipeline():
    """Semantic-segmentation pipeline around a Keras model saved as ``tf_model.h5``.

    Calling an instance on an image returns one entry per model output class,
    each holding a base64-encoded PNG mask for that class.
    """

    def __init__(self, path: str):
        """Load the Keras model from ``<path>/tf_model.h5``."""
        self.model = keras.models.load_model(os.path.join(path, "tf_model.h5"))

    def __call__(self, inputs: "Image.Image") -> List[Dict[str, Any]]:
        """Segment an image and return one binary mask per class.

        Args:
            inputs: A PIL image. For backward compatibility, anything that
                ``Image.open`` accepts (a path or a binary file object) is
                also accepted.

        Returns:
            A list with one dict per class index ``i`` of the model output:
            ``{"label": "LABEL_<i>", "mask": <base64 PNG>, "score": 1.0}``.
            The mask is white (255) where class ``i`` has the highest
            prediction and black (0) elsewhere. It has the spatial size of
            the model output, not of the original image. ``score`` is a
            constant placeholder, not a model confidence.
        """
        # The annotation promises an already-decoded PIL image, which
        # ``Image.open`` cannot read; only open when given a path/file.
        if isinstance(inputs, Image.Image):
            pixels = np.array(inputs)
        else:
            with Image.open(inputs) as opened:
                pixels = np.array(opened)

        # NOTE(review): the channel count is passed through unchanged, so the
        # image mode must already match what the model expects - confirm.
        resized = tf.image.resize(pixels, (128, 128))
        scaled = tf.cast(resized, tf.float32) / 255.0

        # Add the batch axis the model expects, then drop it again after the
        # per-pixel argmax over the class axis.
        pred_mask = self.model.predict(scaled[tf.newaxis, ...])
        class_map = np.argmax(pred_mask, axis=-1)[0]

        return [
            {
                "label": f"LABEL_{cls}",
                "mask": self._encode_png(self._class_mask(class_map, cls)),
                "score": 1.0,
            }
            for cls in range(pred_mask.shape[-1])
        ]

    @staticmethod
    def _class_mask(class_map, cls: int) -> np.ndarray:
        """Return a uint8 array: 255 where ``class_map == cls``, else 0."""
        # uint8 is the dtype that can actually hold 255; int8 cannot.
        return np.where(np.asarray(class_map) == cls, 255, 0).astype(np.uint8)

    @staticmethod
    def _encode_png(mask: np.ndarray) -> str:
        """Encode a 2-D uint8 mask as a base64 PNG string."""
        with io.BytesIO() as out:
            # Pillow reads a 2-D uint8 array as 8-bit grayscale ("L").
            Image.fromarray(mask).save(out, format="PNG")
            return base64.b64encode(out.getvalue()).decode("utf-8")