sam-quant / app.py

import gradio as gr
import numpy as np
import onnxruntime
import torch
from torch.nn import functional as F

from transforms import ResizeLongestSide

IMAGE_SIZE = 1024
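
# Resize the image so its longest side is IMAGE_SIZE, normalize it with SAM's
# pixel mean/std, and zero-pad it to a square IMAGE_SIZE x IMAGE_SIZE tensor
# in NCHW layout, as expected by the ONNX encoder.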
def preprocess_image(image):
    transform = ResizeLongestSide(IMAGE_SIZE)
    input_image = transform.apply_image(image)
    input_image_torch = torch.as_tensor(input_image, device="cpu")
    input_image_torch = input_image_torch.permute(2, 0, 1).contiguous()[None, :, :, :]

    pixel_mean = torch.Tensor([123.675, 116.28, 103.53]).view(-1, 1, 1)
    pixel_std = torch.Tensor([58.395, 57.12, 57.375]).view(-1, 1, 1)
    x = (input_image_torch - pixel_mean) / pixel_std

    h, w = x.shape[-2:]
    padh = IMAGE_SIZE - h
    padw = IMAGE_SIZE - w
    x = F.pad(x, (0, padw, 0, padh))

    x = x.numpy()
    return x
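
# Build the input dict for the SAM ONNX decoder: the clicked point plus a
# padding point, their labels, an empty low-res mask prior, and the original
# image size used to rescale the output mask.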
def prepare_inputs(image_embedding, input_point, image_shape):
    transform = ResizeLongestSide(IMAGE_SIZE)

    input_label = np.array([1])
    onnx_coord = np.concatenate([input_point, np.array([[0.0, 0.0]])], axis=0)[None, :, :]
    onnx_label = np.concatenate([input_label, np.array([-1])], axis=0)[None, :].astype(np.float32)
    onnx_coord = transform.apply_coords(onnx_coord, image_shape).astype(np.float32)

    onnx_mask_input = np.zeros((1, 1, 256, 256), dtype=np.float32)
    onnx_has_mask_input = np.zeros(1, dtype=np.float32)

    decoder_inputs = {
        "image_embeddings": image_embedding,
        "point_coords": onnx_coord,
        "point_labels": onnx_label,
        "mask_input": onnx_mask_input,
        "has_mask_input": onnx_has_mask_input,
        "orig_im_size": np.array(image_shape, dtype=np.float32),
    }
    return decoder_inputs
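
# Load the quantized SAM encoder and decoder as ONNX Runtime sessions.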
enc_session = onnxruntime.InferenceSession("encoder-quant.onnx")
dec_session = onnxruntime.InferenceSession("decoder-quant.onnx")
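
# Full pipeline for the "Segment" button: preprocess the image, run the
# encoder once, then decode a mask for a point at the centre of the photo.
# The embedding and image shape are returned so later clicks can reuse them.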
def predict_image(img):
    x = preprocess_image(img)
    encoder_inputs = {
        "x": x,
    }
    output = enc_session.run(None, encoder_inputs)
    image_embedding = output[0]

    middle_of_photo = np.array([[img.shape[1] / 2, img.shape[0] / 2]])
    decoder_inputs = prepare_inputs(image_embedding, middle_of_photo, img.shape[:2])
    masks, _, low_res_logits = dec_session.run(None, decoder_inputs)

    # clip negative logits and scale the mask to the [0, 1] range for display
    masks = masks[0][0]
    masks[masks < 0] = 0
    masks = masks / np.max(masks)

    return masks, image_embedding, img.shape[:2]
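
# Callback for clicks on the output image: reuse the cached embedding and
# decode a new mask for the clicked point, without re-running the encoder.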
def segment_image(image_embedding, shape, evt: gr.SelectData):
    image_embedding = np.array(image_embedding)
    middle_of_photo = np.array([evt.index])
    decoder_inputs = prepare_inputs(image_embedding, middle_of_photo, shape)
    masks, _, low_res_logits = dec_session.run(None, decoder_inputs)

    # clip negative logits and scale the mask to the [0, 1] range for display
    masks = masks[0][0]
    masks[masks < 0] = 0
    masks = masks / np.max(masks)
    return masks
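
# Gradio UI: input image and "Segment" button on the left, the predicted mask
# on the right; gr.State holds the embedding and image shape between clicks.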
with gr.Blocks() as demo:
    gr.Markdown("# SAM quantized (Segment Anything Model)")
    markdown = """
This is a demo of SAM (Segment Anything Model), a model for segmenting anything in an image.
It returns the segmentation mask of the region that overlaps the clicked point.
The model is quantized with ONNX Runtime.
"""
    gr.Markdown(markdown)

    embedding = gr.State()
    shape = gr.State()

    with gr.Row():
        with gr.Column():
            inputs = gr.Image()
            start_segmentation = gr.Button("Segment")
        with gr.Column():
            outputs = gr.Image(label="Segmentation Mask")

    start_segmentation.click(
        predict_image,
        inputs,
        [outputs, embedding, shape],
    )
    outputs.select(
        segment_image,
        [embedding, shape],
        outputs,
    )

demo.launch()