ductai199x
commited on
Commit
·
b967cb8
1
Parent(s):
d2752ed
add weight conversion script for other model versions
Browse files- __init__.py +0 -0
- convert_sam_hq_to_hf.py +172 -0
__init__.py
ADDED
File without changes
|
convert_sam_hq_to_hf.py
ADDED
@@ -0,0 +1,172 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# coding=utf-8
|
2 |
+
# Copyright 2023 The HuggingFace Inc. team. All rights reserved.
|
3 |
+
#
|
4 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
5 |
+
# you may not use this file except in compliance with the License.
|
6 |
+
# You may obtain a copy of the License at
|
7 |
+
#
|
8 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
9 |
+
#
|
10 |
+
# Unless required by applicable law or agreed to in writing, software
|
11 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
12 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
13 |
+
# See the License for the specific language governing permissions and
|
14 |
+
# limitations under the License.
|
15 |
+
"""
|
16 |
+
Convert SAM checkpoints from the original repository.
|
17 |
+
|
18 |
+
URL: https://github.com/facebookresearch/segment-anything.
|
19 |
+
|
20 |
+
Also supports converting the SlimSAM checkpoints from https://github.com/czg1225/SlimSAM/tree/master.
|
21 |
+
"""
|
22 |
+
import sys
|
23 |
+
sys.path.append("../")
|
24 |
+
|
25 |
+
import argparse
|
26 |
+
import re
|
27 |
+
import torch
|
28 |
+
from safetensors.torch import save_model
|
29 |
+
from huggingface_hub import hf_hub_download
|
30 |
+
from transformers import (
|
31 |
+
SamImageProcessor,
|
32 |
+
SamProcessor,
|
33 |
+
SamVisionConfig,
|
34 |
+
)
|
35 |
+
from sam_hq_vit_huge.modeling_sam_hq import SamHQModel
|
36 |
+
from sam_hq_vit_huge.configuration_sam_hq import SamHQConfig
|
37 |
+
|
38 |
+
|
39 |
+
def get_config(model_name):
|
40 |
+
if "sam_hq_vit_b" in model_name:
|
41 |
+
vision_config = SamVisionConfig()
|
42 |
+
elif "sam_hq_vit_l" in model_name:
|
43 |
+
vision_config = SamVisionConfig(
|
44 |
+
hidden_size=1024,
|
45 |
+
num_hidden_layers=24,
|
46 |
+
num_attention_heads=16,
|
47 |
+
global_attn_indexes=[5, 11, 17, 23],
|
48 |
+
)
|
49 |
+
elif "sam_hq_vit_h" in model_name:
|
50 |
+
vision_config = SamVisionConfig(
|
51 |
+
hidden_size=1280,
|
52 |
+
num_hidden_layers=32,
|
53 |
+
num_attention_heads=16,
|
54 |
+
global_attn_indexes=[7, 15, 23, 31],
|
55 |
+
)
|
56 |
+
|
57 |
+
config = SamHQConfig(
|
58 |
+
vision_config=vision_config,
|
59 |
+
)
|
60 |
+
|
61 |
+
return config
|
62 |
+
|
63 |
+
|
64 |
+
KEYS_TO_MODIFY_MAPPING = {
|
65 |
+
# Vision Encoder
|
66 |
+
"image_encoder": "vision_encoder",
|
67 |
+
"patch_embed.proj": "patch_embed.projection",
|
68 |
+
"blocks.": "layers.",
|
69 |
+
"neck.0": "neck.conv1",
|
70 |
+
"neck.1": "neck.layer_norm1",
|
71 |
+
"neck.2": "neck.conv2",
|
72 |
+
"neck.3": "neck.layer_norm2",
|
73 |
+
|
74 |
+
# Prompt Encoder
|
75 |
+
"mask_downscaling.0": "mask_embed.conv1",
|
76 |
+
"mask_downscaling.1": "mask_embed.layer_norm1",
|
77 |
+
"mask_downscaling.3": "mask_embed.conv2",
|
78 |
+
"mask_downscaling.4": "mask_embed.layer_norm2",
|
79 |
+
"mask_downscaling.6": "mask_embed.conv3",
|
80 |
+
"point_embeddings": "point_embed",
|
81 |
+
"pe_layer.positional_encoding_gaussian_matrix": "shared_embedding.positional_embedding",
|
82 |
+
|
83 |
+
# Mask Decoder
|
84 |
+
"iou_prediction_head.layers.0": "iou_prediction_head.proj_in",
|
85 |
+
"iou_prediction_head.layers.1": "iou_prediction_head.layers.0",
|
86 |
+
"iou_prediction_head.layers.2": "iou_prediction_head.proj_out",
|
87 |
+
"mask_decoder.output_upscaling.0": "mask_decoder.upscale_conv1",
|
88 |
+
"mask_decoder.output_upscaling.1": "mask_decoder.upscale_layer_norm",
|
89 |
+
"mask_decoder.output_upscaling.3": "mask_decoder.upscale_conv2",
|
90 |
+
".norm": ".layer_norm",
|
91 |
+
|
92 |
+
# SAM HQ Extra (in Mask Decoder)
|
93 |
+
"hf_mlp.layers.0": "hf_mlp.proj_in",
|
94 |
+
"hf_mlp.layers.1": "hf_mlp.layers.0",
|
95 |
+
"hf_mlp.layers.2": "hf_mlp.proj_out",
|
96 |
+
}
|
97 |
+
|
98 |
+
|
99 |
+
def replace_keys(state_dict):
|
100 |
+
model_state_dict = {}
|
101 |
+
state_dict.pop("pixel_mean", None)
|
102 |
+
state_dict.pop("pixel_std", None)
|
103 |
+
|
104 |
+
output_hypernetworks_mlps_pattern = r".*.output_hypernetworks_mlps.(\d+).layers.(\d+).*"
|
105 |
+
|
106 |
+
for key, value in state_dict.items():
|
107 |
+
for key_to_modify, new_key in KEYS_TO_MODIFY_MAPPING.items():
|
108 |
+
if key_to_modify in key:
|
109 |
+
key = key.replace(key_to_modify, new_key)
|
110 |
+
|
111 |
+
if re.match(output_hypernetworks_mlps_pattern, key):
|
112 |
+
layer_nb = int(re.match(output_hypernetworks_mlps_pattern, key).group(2))
|
113 |
+
if layer_nb == 0:
|
114 |
+
key = key.replace("layers.0", "proj_in")
|
115 |
+
elif layer_nb == 1:
|
116 |
+
key = key.replace("layers.1", "layers.0")
|
117 |
+
elif layer_nb == 2:
|
118 |
+
key = key.replace("layers.2", "proj_out")
|
119 |
+
break
|
120 |
+
|
121 |
+
model_state_dict[key] = value.cpu()
|
122 |
+
|
123 |
+
model_state_dict["shared_image_embedding.positional_embedding"] = model_state_dict[
|
124 |
+
"prompt_encoder.shared_embedding.positional_embedding"
|
125 |
+
].cpu().clone()
|
126 |
+
|
127 |
+
return model_state_dict
|
128 |
+
|
129 |
+
|
130 |
+
def convert_sam_checkpoint(model_name, checkpoint_path, output_dir):
|
131 |
+
config = get_config(model_name)
|
132 |
+
|
133 |
+
state_dict = torch.load(checkpoint_path, map_location="cpu")
|
134 |
+
state_dict = replace_keys(state_dict)
|
135 |
+
# print(state_dict.keys())
|
136 |
+
|
137 |
+
hf_model = SamHQModel(config)
|
138 |
+
hf_model.eval()
|
139 |
+
|
140 |
+
hf_model.load_state_dict(state_dict)
|
141 |
+
|
142 |
+
if output_dir is not None:
|
143 |
+
save_model(hf_model, f"{output_dir}/{model_name}.safetensors", metadata={"format": "pt"})
|
144 |
+
|
145 |
+
|
146 |
+
if __name__ == "__main__":
|
147 |
+
parser = argparse.ArgumentParser()
|
148 |
+
choices = ["sam_hq_vit_b", "sam_hq_vit_l", "sam_hq_vit_h"]
|
149 |
+
parser.add_argument(
|
150 |
+
"--model_name",
|
151 |
+
default="sam_hq_vit_h",
|
152 |
+
choices=choices,
|
153 |
+
type=str,
|
154 |
+
help="Name of the original model to convert",
|
155 |
+
)
|
156 |
+
parser.add_argument(
|
157 |
+
"--checkpoint_path",
|
158 |
+
type=str,
|
159 |
+
required=False,
|
160 |
+
help="Path to the original checkpoint",
|
161 |
+
)
|
162 |
+
parser.add_argument("--output_dir", default=".", type=str, help="Path to the output PyTorch model.")
|
163 |
+
|
164 |
+
args = parser.parse_args()
|
165 |
+
|
166 |
+
if args.checkpoint_path is not None:
|
167 |
+
checkpoint_path = args.checkpoint_path
|
168 |
+
else:
|
169 |
+
checkpoint_path = hf_hub_download("lkeab/hq-sam", f"{args.model_name}.pth")
|
170 |
+
print(checkpoint_path)
|
171 |
+
|
172 |
+
convert_sam_checkpoint(args.model_name, checkpoint_path, args.output_dir)
|