"""
Convert SAM-HQ checkpoints from the original repository.

URL: https://github.com/SysCV/sam-hq.

The original checkpoints are also mirrored on the Hugging Face Hub under lkeab/hq-sam.
"""
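# Example invocation (a sketch; the script filename below is an assumption, adjust it to this
# file's actual name). When --checkpoint_path is omitted, the checkpoint is downloaded from the
# lkeab/hq-sam Hub repository:
#
#   python convert_sam_hq_to_hf.py --model_name sam_hq_vit_h --output_dir ./converted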

import sys

sys.path.append("../")

import argparse
import re

import torch
from huggingface_hub import hf_hub_download
from safetensors.torch import save_model
from transformers import SamVisionConfig

from sam_hq_vit_huge.configuration_sam_hq import SamHQConfig
from sam_hq_vit_huge.modeling_sam_hq import SamHQModel


def get_config(model_name):
    # Pick the ViT image-encoder configuration that matches the requested checkpoint.
    if "sam_hq_vit_b" in model_name:
        vision_config = SamVisionConfig()
    elif "sam_hq_vit_l" in model_name:
        vision_config = SamVisionConfig(
            hidden_size=1024,
            num_hidden_layers=24,
            num_attention_heads=16,
            global_attn_indexes=[5, 11, 17, 23],
        )
    elif "sam_hq_vit_h" in model_name:
        vision_config = SamVisionConfig(
            hidden_size=1280,
            num_hidden_layers=32,
            num_attention_heads=16,
            global_attn_indexes=[7, 15, 23, 31],
        )
    else:
        raise ValueError(f"Unknown model name: {model_name}")

    config = SamHQConfig(
        vision_config=vision_config,
    )

    return config


KEYS_TO_MODIFY_MAPPING = {
    # Vision encoder
    "image_encoder": "vision_encoder",
    "patch_embed.proj": "patch_embed.projection",
    "blocks.": "layers.",
    "neck.0": "neck.conv1",
    "neck.1": "neck.layer_norm1",
    "neck.2": "neck.conv2",
    "neck.3": "neck.layer_norm2",
    # Prompt encoder
    "mask_downscaling.0": "mask_embed.conv1",
    "mask_downscaling.1": "mask_embed.layer_norm1",
    "mask_downscaling.3": "mask_embed.conv2",
    "mask_downscaling.4": "mask_embed.layer_norm2",
    "mask_downscaling.6": "mask_embed.conv3",
    "point_embeddings": "point_embed",
    "pe_layer.positional_encoding_gaussian_matrix": "shared_embedding.positional_embedding",
    # Mask decoder
    "iou_prediction_head.layers.0": "iou_prediction_head.proj_in",
    "iou_prediction_head.layers.1": "iou_prediction_head.layers.0",
    "iou_prediction_head.layers.2": "iou_prediction_head.proj_out",
    "mask_decoder.output_upscaling.0": "mask_decoder.upscale_conv1",
    "mask_decoder.output_upscaling.1": "mask_decoder.upscale_layer_norm",
    "mask_decoder.output_upscaling.3": "mask_decoder.upscale_conv2",
    ".norm": ".layer_norm",
    # SAM-HQ specific: MLP for the high-quality output token
    "hf_mlp.layers.0": "hf_mlp.proj_in",
    "hf_mlp.layers.1": "hf_mlp.layers.0",
    "hf_mlp.layers.2": "hf_mlp.proj_out",
}
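# Illustrative example (not from the original script): with the mapping above,
# "image_encoder.blocks.0.norm1.weight" becomes "vision_encoder.layers.0.layer_norm1.weight".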


def replace_keys(state_dict):
    model_state_dict = {}
    # pixel_mean / pixel_std are preprocessing buffers, not model weights, so drop them.
    state_dict.pop("pixel_mean", None)
    state_dict.pop("pixel_std", None)

    output_hypernetworks_mlps_pattern = r".*.output_hypernetworks_mlps.(\d+).layers.(\d+).*"

    for key, value in state_dict.items():
        for key_to_modify, new_key in KEYS_TO_MODIFY_MAPPING.items():
            if key_to_modify in key:
                key = key.replace(key_to_modify, new_key)

        # The hypernetwork MLP layers are named proj_in / layers.0 / proj_out in the HF model.
        if re.match(output_hypernetworks_mlps_pattern, key):
            layer_nb = int(re.match(output_hypernetworks_mlps_pattern, key).group(2))
            if layer_nb == 0:
                key = key.replace("layers.0", "proj_in")
            elif layer_nb == 1:
                key = key.replace("layers.1", "layers.0")
            elif layer_nb == 2:
                key = key.replace("layers.2", "proj_out")

        model_state_dict[key] = value.cpu()

    model_state_dict["shared_image_embedding.positional_embedding"] = (
        model_state_dict["prompt_encoder.shared_embedding.positional_embedding"].cpu().clone()
    )

    return model_state_dict


def convert_sam_checkpoint(model_name, checkpoint_path, output_dir):
    config = get_config(model_name)

    # Load the original checkpoint on CPU and rename its keys to the Hugging Face layout.
    state_dict = torch.load(checkpoint_path, map_location="cpu")
    state_dict = replace_keys(state_dict)

    hf_model = SamHQModel(config)
    hf_model.eval()

    hf_model.load_state_dict(state_dict)

    if output_dir is not None:
        save_model(hf_model, f"{output_dir}/{model_name}.safetensors", metadata={"format": "pt"})
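

# Optional sanity check (a sketch, not part of the original conversion flow): reload the exported
# safetensors file into a freshly initialized SamHQModel. `load_model` is strict by default, so it
# raises if any tensor is missing or unexpected.
def verify_converted_checkpoint(model_name, output_dir):
    from safetensors.torch import load_model

    model = SamHQModel(get_config(model_name))
    load_model(model, f"{output_dir}/{model_name}.safetensors")
    return model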


if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    choices = ["sam_hq_vit_b", "sam_hq_vit_l", "sam_hq_vit_h"]
    parser.add_argument(
        "--model_name",
        default="sam_hq_vit_h",
        choices=choices,
        type=str,
        help="Name of the original model to convert",
    )
    parser.add_argument(
        "--checkpoint_path",
        type=str,
        required=False,
        help="Path to the original checkpoint. If not provided, it is downloaded from the Hub.",
    )
    parser.add_argument(
        "--output_dir", default=".", type=str, help="Directory where the converted safetensors weights are saved."
    )

    args = parser.parse_args()

    if args.checkpoint_path is not None:
        checkpoint_path = args.checkpoint_path
    else:
        # Fall back to the official SAM-HQ checkpoints mirrored on the Hugging Face Hub.
        checkpoint_path = hf_hub_download("lkeab/hq-sam", f"{args.model_name}.pth")

    convert_sam_checkpoint(args.model_name, checkpoint_path, args.output_dir)