import gradio as gr
import numpy as np
import torch
from torch.nn import functional as F
import onnxruntime

from transforms import ResizeLongestSide
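# NOTE (assumption): `transforms.ResizeLongestSide` is taken to be the resize helper
# from the segment-anything repository (segment_anything/utils/transforms.py),
# vendored next to this script as transforms.py.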

IMAGE_SIZE = 1024

def preprocess_image(image):
    transform = ResizeLongestSide(IMAGE_SIZE)
    input_image = transform.apply_image(image)
    input_image_torch = torch.as_tensor(input_image, device="cpu")
    input_image_torch = input_image_torch.permute(2, 0, 1).contiguous()[None, :, :, :]
    # Normalize with SAM's pixel mean/std, then pad to an IMAGE_SIZE x IMAGE_SIZE square.
    pixel_mean = torch.Tensor([123.675, 116.28, 103.53]).view(-1, 1, 1)
    pixel_std = torch.Tensor([58.395, 57.12, 57.375]).view(-1, 1, 1)
    x = (input_image_torch - pixel_mean) / pixel_std
    h, w = x.shape[-2:]
    padh = IMAGE_SIZE - h
    padw = IMAGE_SIZE - w
    x = F.pad(x, (0, padw, 0, padh))
    x = x.numpy()
    return x

def prepare_inputs(image_embedding, input_point, image_shape):
    transform = ResizeLongestSide(IMAGE_SIZE)
    input_label = np.array([1])
    # Append a padding point with label -1, as the exported SAM decoder expects.
    onnx_coord = np.concatenate([input_point, np.array([[0.0, 0.0]])], axis=0)[None, :, :]
    onnx_label = np.concatenate([input_label, np.array([-1])], axis=0)[None, :].astype(np.float32)
    onnx_coord = transform.apply_coords(onnx_coord, image_shape).astype(np.float32)
    onnx_mask_input = np.zeros((1, 1, 256, 256), dtype=np.float32)
    onnx_has_mask_input = np.zeros(1, dtype=np.float32)
    decoder_inputs = {
        "image_embeddings": image_embedding,
        "point_coords": onnx_coord,
        "point_labels": onnx_label,
        "mask_input": onnx_mask_input,
        "has_mask_input": onnx_has_mask_input,
        "orig_im_size": np.array(image_shape, dtype=np.float32),
    }
    return decoder_inputs

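# The quantized SAM image encoder and mask decoder are loaded as ONNX files expected
# next to this script (presumably exported from a SAM checkpoint and quantized with
# ONNX Runtime, as the demo description states).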
enc_session = onnxruntime.InferenceSession("encoder-quant.onnx")
dec_session = onnxruntime.InferenceSession("decoder-quant.onnx")

def predict_image(img):
    x = preprocess_image(img)
    encoder_inputs = {
        "x": x,
    }
    output = enc_session.run(None, encoder_inputs)
    image_embedding = output[0]
    # For the "Segment" button, prompt the decoder with the center of the photo.
    middle_of_photo = np.array([[img.shape[1] / 2, img.shape[0] / 2]])
    decoder_inputs = prepare_inputs(image_embedding, middle_of_photo, img.shape[:2])
    masks, _, low_res_logits = dec_session.run(None, decoder_inputs)
    # Clip negative logits and rescale the mask to the [0, 1] range.
    masks = masks[0][0]
    masks[masks < 0] = 0
    masks = masks / np.max(masks)
    return masks, image_embedding, img.shape[:2]

def segment_image(image_embedding, shape, evt: gr.SelectData):
    image_embedding = np.array(image_embedding)
    # Re-run only the decoder, prompted with the clicked point.
    middle_of_photo = np.array([evt.index])
    decoder_inputs = prepare_inputs(image_embedding, middle_of_photo, shape)
    masks, _, low_res_logits = dec_session.run(None, decoder_inputs)
    # Clip negative logits and rescale the mask to the [0, 1] range.
    masks = masks[0][0]
    masks[masks < 0] = 0
    masks = masks / np.max(masks)
    return masks

with gr.Blocks() as demo:
    gr.Markdown("# SAM quantized (Segment Anything Model)")
    markdown = """
    This is a demo of SAM (Segment Anything Model), a model for segmenting anything in an image.
    It returns the segmentation mask that overlaps with the clicked point.
    The model is quantized using ONNX Runtime.
    """
    gr.Markdown(markdown)
    embedding = gr.State()
    shape = gr.State()
    with gr.Row():
        with gr.Column():
            inputs = gr.Image()
            start_segmentation = gr.Button("Segment")
        with gr.Column():
            outputs = gr.Image(label="Segmentation Mask")
    start_segmentation.click(
        predict_image,
        inputs,
        [outputs, embedding, shape],
    )
    outputs.select(
        segment_image,
        [embedding, shape],
        outputs,
    )

demo.launch()
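
# NOTE (assumption): in a Hugging Face Space this script is typically saved as app.py
# alongside transforms.py, encoder-quant.onnx and decoder-quant.onnx, with a
# requirements.txt listing gradio, numpy, torch and onnxruntime.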