"""Gradio app that returns Segment Anything (SAM) image embeddings as base64 text."""

import base64
import os
import subprocess
import typing

import gradio as gr
import imutils
import numpy as np
import segment_anything
import torch

# Where the official SAM checkpoints are downloaded from.
_CHECKPOINT_BASE_URL = "https://dl.fbaipublicfiles.com/segment_anything"

# model_size -> (key in segment_anything.sam_model_registry, checkpoint file name)
_MODELS: typing.Dict[str, typing.Tuple[str, str]] = {
    "base": ("vit_b", "sam_vit_b_01ec64.pth"),  # 375 MB
    "large": ("vit_l", "sam_vit_l_0b3195.pth"),  # 1.25 GB
    "huge": ("vit_h", "sam_vit_h_4b8939.pth"),  # 2.56 GB
}

# model_size -> ready-to-use predictor. Populated once at startup (see the
# ``__main__`` section) so that repeated predictions do not reload weights.
_predictors: typing.Dict[str, segment_anything.SamPredictor] = {}


def image_to_sam_image_embedding(
    image_url: str,
    model_size: typing.Literal["base", "large", "huge"] = "base",
) -> str:
    """Generate a SAM image embedding for the image at ``image_url``.

    Args:
        image_url: URL of the image to download and embed.
        model_size: Which SAM variant to run: "base", "large" or "huge".

    Returns:
        The embedding as float32 values in C order, serialized to raw bytes
        and base64-encoded into a UTF-8 string.

    Raises:
        ValueError: If ``model_size`` is not a supported size, or if no image
            could be decoded from ``image_url``.
        RuntimeError: If the requested model has not been loaded yet.
    """
    # Validate up front: the original if/elif chain left ``predictor`` unbound
    # for an unknown size and failed later with an UnboundLocalError.
    if model_size not in _MODELS:
        raise ValueError(
            f"Unknown model_size {model_size!r}; expected one of {sorted(_MODELS)}"
        )
    try:
        predictor = _predictors[model_size]
    except KeyError:
        raise RuntimeError(
            f"The {model_size!r} SAM model has not been loaded"
        ) from None

    # Load image
    image = imutils.url_to_image(image_url)
    if image is None:
        raise ValueError(f"Could not decode an image from {image_url!r}")

    # Run model. imutils decodes through OpenCV, which yields BGR channel
    # order, while SamPredictor.set_image assumes RGB unless told otherwise;
    # declare the format so the predictor reorders the channels itself.
    predictor.set_image(image, image_format="BGR")

    # Output shape is (1, 256, 64, 64) according to the original author's note.
    image_embedding = predictor.get_image_embedding().cpu().numpy()

    # tobytes() emits the elements in C order, i.e. exactly the bytes of the
    # flattened array, so no explicit flatten() is required first.
    raw_bytes = image_embedding.astype(np.float32).tobytes()

    # Encode the bytes to base64
    return base64.b64encode(raw_bytes).decode("utf-8")


def _download_checkpoint(filename: str) -> None:
    """Fetch ``filename`` with wget unless it already exists in the working directory."""
    if os.path.exists(f"./{filename}"):
        return
    result = subprocess.run(
        ["wget", f"{_CHECKPOINT_BASE_URL}/{filename}"],
        check=True,
    )
    print(f"wget {filename} result = {result}")


if __name__ == "__main__":
    # Load the models into memory to make running multiple predictions efficient
    device = "cuda" if torch.cuda.is_available() else "cpu"

    # Download every checkpoint first so a failed download aborts startup
    # before any weights are loaded.
    for _, checkpoint in _MODELS.values():
        _download_checkpoint(checkpoint)

    for size, (registry_key, checkpoint) in _MODELS.items():
        sam = segment_anything.sam_model_registry[registry_key](checkpoint=checkpoint)
        sam.to(device=device)
        _predictors[size] = segment_anything.SamPredictor(sam)

    # Gradio app. Only the image URL is exposed as an input, so requests made
    # through the UI always use the default model size.
    app = gr.Interface(
        fn=image_to_sam_image_embedding,
        inputs="text",
        outputs="text",
    )
    app.launch()