# sam_hq_vit_huge / convert_sam_hq_to_hf.py
# Repository page header captured with this file: ductai199x · "minor" · c2e7f35
# coding=utf-8
# Copyright 2023 The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
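Convert SAM-HQ checkpoints (`sam_hq_vit_b`, `sam_hq_vit_l`, `sam_hq_vit_h`) into a `SamHQModel` safetensors file.
When no local checkpoint path is given, the weights are downloaded from the `lkeab/hq-sam` repository on the Hugging Face Hub.
Based on the conversion script for the original SAM repository (URL: https://github.com/facebookresearch/segment-anything),
which also supports converting the SlimSAM checkpoints from https://github.com/czg1225/SlimSAM/tree/master.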
"""
import sys
sys.path.append("../")
import argparse
import re
import torch
from safetensors.torch import save_model
from huggingface_hub import hf_hub_download
from transformers import SamVisionConfig
from sam_hq_vit_huge.modeling_sam_hq import SamHQModel
from sam_hq_vit_huge.configuration_sam_hq import SamHQConfig
def get_config(model_name: str) -> "SamHQConfig":
    """Build the ``SamHQConfig`` for a SAM-HQ checkpoint name.

    The vision-encoder size is selected by substring match on ``model_name``,
    checked in the order ``sam_hq_vit_b``, ``sam_hq_vit_l``, ``sam_hq_vit_h``.
    The ``b`` variant uses the ``SamVisionConfig`` defaults; the ``l`` and
    ``h`` variants override the hidden size, depth, head count and the
    indexes of the global-attention layers.

    Args:
        model_name: Checkpoint name containing one of the supported markers.

    Returns:
        A ``SamHQConfig`` wrapping the selected vision configuration.

    Raises:
        ValueError: If ``model_name`` contains none of the supported markers.
    """
    if "sam_hq_vit_b" in model_name:
        vision_kwargs = {}
    elif "sam_hq_vit_l" in model_name:
        vision_kwargs = {
            "hidden_size": 1024,
            "num_hidden_layers": 24,
            "num_attention_heads": 16,
            "global_attn_indexes": [5, 11, 17, 23],
        }
    elif "sam_hq_vit_h" in model_name:
        vision_kwargs = {
            "hidden_size": 1280,
            "num_hidden_layers": 32,
            "num_attention_heads": 16,
            "global_attn_indexes": [7, 15, 23, 31],
        }
    else:
        # Fail with a clear message instead of an UnboundLocalError further down.
        raise ValueError(
            f"Unknown model name {model_name!r}; expected a name containing one of: "
            "sam_hq_vit_b, sam_hq_vit_l, sam_hq_vit_h."
        )

    vision_config = SamVisionConfig(**vision_kwargs)
    return SamHQConfig(vision_config=vision_config)
# Substring renames from original checkpoint keys to the converted model's keys.
# `replace_keys` walks this dict in insertion order and applies every matching
# entry to the same key, so order matters for the chained `layers.N` renames:
# `layers.0` must become `proj_in` before `layers.1` is renamed to `layers.0`.
KEYS_TO_MODIFY_MAPPING = {
    # --- Vision encoder ---
    "image_encoder": "vision_encoder",
    "patch_embed.proj": "patch_embed.projection",
    "blocks.": "layers.",
    "neck.0": "neck.conv1",
    "neck.1": "neck.layer_norm1",
    "neck.2": "neck.conv2",
    "neck.3": "neck.layer_norm2",
    # --- Prompt encoder ---
    "mask_downscaling.0": "mask_embed.conv1",
    "mask_downscaling.1": "mask_embed.layer_norm1",
    "mask_downscaling.3": "mask_embed.conv2",
    "mask_downscaling.4": "mask_embed.layer_norm2",
    "mask_downscaling.6": "mask_embed.conv3",
    "point_embeddings": "point_embed",
    "pe_layer.positional_encoding_gaussian_matrix": "shared_embedding.positional_embedding",
    # --- Mask decoder ---
    "iou_prediction_head.layers.0": "iou_prediction_head.proj_in",
    "iou_prediction_head.layers.1": "iou_prediction_head.layers.0",
    "iou_prediction_head.layers.2": "iou_prediction_head.proj_out",
    "mask_decoder.output_upscaling.0": "mask_decoder.upscale_conv1",
    "mask_decoder.output_upscaling.1": "mask_decoder.upscale_layer_norm",
    "mask_decoder.output_upscaling.3": "mask_decoder.upscale_conv2",
    ".norm": ".layer_norm",
    # --- SAM-HQ additions (inside the mask decoder) ---
    "hf_mlp.layers.0": "hf_mlp.proj_in",
    "hf_mlp.layers.1": "hf_mlp.layers.0",
    "hf_mlp.layers.2": "hf_mlp.proj_out",
}
def replace_keys(state_dict):
    """Rename the keys of an original checkpoint to the converted model's names.

    Every substring rename in ``KEYS_TO_MODIFY_MAPPING`` is applied in
    insertion order. Keys of the ``output_hypernetworks_mlps`` MLPs
    additionally get their ``layers.{0,1,2}`` component renamed to
    ``proj_in`` / ``layers.0`` / ``proj_out``. The ``pixel_mean`` and
    ``pixel_std`` entries are left out of the result, and the prompt encoder's
    positional embedding is duplicated under
    ``shared_image_embedding.positional_embedding``.

    Args:
        state_dict: Mapping from original parameter names to values. Each
            value must provide ``.cpu()`` (presumably a ``torch.Tensor``).
            The mapping is not modified.

    Returns:
        A new dict mapping converted names to the values moved to CPU.

    Raises:
        KeyError: If no ``prompt_encoder.shared_embedding.positional_embedding``
            entry exists after renaming.
    """
    # Skipped rather than popped so the caller's dict is left untouched.
    ignored_keys = ("pixel_mean", "pixel_std")

    # Compiled once instead of being re-evaluated (twice) on every iteration.
    output_hypernetworks_mlps_pattern = re.compile(r".*.output_hypernetworks_mlps.(\d+).layers.(\d+).*")
    # MLP layer index -> (fragment to replace, replacement).
    hypernetwork_layer_renames = {
        0: ("layers.0", "proj_in"),
        1: ("layers.1", "layers.0"),
        2: ("layers.2", "proj_out"),
    }

    model_state_dict = {}
    for key, value in state_dict.items():
        if key in ignored_keys:
            continue

        for key_to_modify, new_key in KEYS_TO_MODIFY_MAPPING.items():
            if key_to_modify in key:
                key = key.replace(key_to_modify, new_key)

            match = output_hypernetworks_mlps_pattern.match(key)
            if match:
                layer_nb = int(match.group(2))
                if layer_nb in hypernetwork_layer_renames:
                    old_fragment, new_fragment = hypernetwork_layer_renames[layer_nb]
                    key = key.replace(old_fragment, new_fragment)
                # Stop here: on the next pass the renamed key would match the
                # pattern again (layers.1 -> layers.0 -> proj_in).
                break

        model_state_dict[key] = value.cpu()

    source_key = "prompt_encoder.shared_embedding.positional_embedding"
    try:
        positional_embedding = model_state_dict[source_key]
    except KeyError:
        raise KeyError(
            f"{source_key} is missing after key renaming; "
            "the checkpoint does not contain the prompt encoder positional encoding."
        ) from None
    # Cloned so the image-side entry does not alias the prompt encoder's tensor.
    model_state_dict["shared_image_embedding.positional_embedding"] = positional_embedding.cpu().clone()

    return model_state_dict
def convert_sam_checkpoint(model_name, checkpoint_path, output_dir):
    """Load an original checkpoint into ``SamHQModel`` and save it as safetensors.

    Args:
        model_name: Checkpoint name; selects the configuration via
            ``get_config`` and names the output file.
        checkpoint_path: Path to the original checkpoint file.
        output_dir: Directory that receives ``{model_name}.safetensors``. It is
            created if it does not exist. Pass ``None`` to skip saving.

    Raises:
        Whatever ``get_config``, ``torch.load``, ``replace_keys``,
        ``load_state_dict`` or ``save_model`` raise (for example on an unknown
        model name or on mismatched parameter names).
    """
    import os  # local import: the module-level import block does not provide it

    config = get_config(model_name)

    # SECURITY NOTE(review): torch.load unpickles the file, which can execute
    # arbitrary code. Only convert checkpoints from a trusted source; consider
    # passing weights_only=True if the installed torch version supports it.
    state_dict = torch.load(checkpoint_path, map_location="cpu")
    state_dict = replace_keys(state_dict)

    hf_model = SamHQModel(config)
    hf_model.eval()
    # Default strict loading: any missing or unexpected key raises here.
    hf_model.load_state_dict(state_dict)

    if output_dir is not None:
        if output_dir:
            # Saving would otherwise fail when the target directory is missing.
            os.makedirs(output_dir, exist_ok=True)
        save_model(hf_model, f"{output_dir}/{model_name}.safetensors", metadata={"format": "pt"})
if __name__ == "__main__":
    # Command-line entry point: parse the options, resolve the checkpoint
    # location, then run the conversion.
    arg_parser = argparse.ArgumentParser()
    arg_parser.add_argument(
        "--model_name",
        default="sam_hq_vit_h",
        choices=["sam_hq_vit_b", "sam_hq_vit_l", "sam_hq_vit_h"],
        type=str,
        help="Name of the original model to convert",
    )
    arg_parser.add_argument(
        "--checkpoint_path",
        type=str,
        required=False,
        help="Path to the original checkpoint",
    )
    arg_parser.add_argument("--output_dir", default=".", type=str, help="Path to the output PyTorch model.")
    cli_args = arg_parser.parse_args()

    # Use the local checkpoint when one was given; otherwise download
    # "<model_name>.pth" from the lkeab/hq-sam repository on the Hub.
    source_checkpoint = cli_args.checkpoint_path
    if source_checkpoint is None:
        source_checkpoint = hf_hub_download("lkeab/hq-sam", f"{cli_args.model_name}.pth")

    convert_sam_checkpoint(cli_args.model_name, source_checkpoint, cli_args.output_dir)