# HelpingAI-Vision / convert_model.py
# Uploaded by Abhaykoul using huggingface_hub (commit 983f690, verified), 3.6 kB.
# Copyright 2023 The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import argparse
import torch
from transformers import (
AddedToken,
AutoConfig,
AutoTokenizer,
)
from configuration_llava import LlavaConfig
from modeling_llava import LlavaForConditionalGeneration
# Substring renames applied to every checkpoint key by `convert_state_dict_to_hf`.
# The entries are applied in the order written here and each one operates on the
# result of the previous ones, so the order is significant:
#   * the two specific "transformer.*" prefixes come before the bare
#     "transformer" rule, so they are rewritten before the generic rule runs;
#   * the "multi_modal_projector.N" rules come last because they match names
#     that the "transformer.mm_projector" rule produces.
KEYS_TO_MODIFY_MAPPING = {
    "transformer.vision_tower.vision_tower": "vision_model",
    "transformer.mm_projector": "multi_modal_projector",
    "transformer": "language_model.transformer",
    "lm_head": "language_model.lm_head",
    "model.model": "language_model.transformer",
    "multi_modal_projector.0": "multi_modal_projector.linear_1",
    "multi_modal_projector.2": "multi_modal_projector.linear_2",
}
def convert_state_dict_to_hf(state_dict, key_mapping=None):
    """Rename the keys of an original checkpoint to the converted model's layout.

    Every ``old -> new`` pair of the mapping is applied to each key with
    ``str.replace``, in the mapping's iteration order, so a later pair also
    sees the text produced by an earlier one.

    Args:
        state_dict: Mapping from parameter name to value. It is not modified.
        key_mapping: Optional ``{old_substring: new_substring}`` mapping.
            When ``None`` the module-level ``KEYS_TO_MODIFY_MAPPING`` is used.

    Returns:
        A new dict holding the same values under the renamed keys.

    Raises:
        ValueError: If two different source keys are renamed to the same
            target key; keeping only the last one would silently drop a
            tensor from the checkpoint.
    """
    if key_mapping is None:
        key_mapping = KEYS_TO_MODIFY_MAPPING

    new_state_dict = {}
    for old_key, value in state_dict.items():
        new_key = old_key
        for key_to_modify, replacement in key_mapping.items():
            # `str.replace` leaves the key untouched when the substring is absent.
            new_key = new_key.replace(key_to_modify, replacement)
        if new_key in new_state_dict:
            raise ValueError(
                f"Key collision while renaming state dict: {old_key!r} maps to "
                f"{new_key!r}, which was already produced by another key."
            )
        new_state_dict[new_key] = value
    return new_state_dict
def convert_llava_llama_to_hf(text_model_id, vision_model_id, projector_tokens_num, output_path, old_state_dict_path):
    """Build a ``LlavaForConditionalGeneration`` from an original checkpoint and save it.

    The tokenizer of ``text_model_id`` is extended with an ``<image>`` token and a
    ``<pad>`` padding token, a ``LlavaConfig`` with a fixed ``vocab_size`` of 51200
    is created, the original weights are loaded under their renamed keys, and
    both model and tokenizer are written to ``output_path``.

    Args:
        text_model_id: Identifier passed to ``AutoConfig.from_pretrained`` and
            ``AutoTokenizer.from_pretrained`` for the text model.
        vision_model_id: Stored in the config as ``vision_tower_name``.
        projector_tokens_num: Forwarded to ``LlavaConfig`` as ``projector_tokens_num``.
        output_path: Destination given to ``save_pretrained`` for model and tokenizer.
        old_state_dict_path: File read with ``torch.load`` to obtain the original
            state dict.
    """
    # The default dtype is process-global state: remember it so it can be
    # restored once the conversion is finished (or has failed).
    previous_dtype = torch.get_default_dtype()
    torch.set_default_dtype(torch.float16)
    try:
        text_config = AutoConfig.from_pretrained(text_model_id, trust_remote_code=True)

        tokenizer = AutoTokenizer.from_pretrained(text_model_id)
        tokenizer.add_tokens(AddedToken("<image>", special=True, normalized=False), special_tokens=True)
        tokenizer.add_special_tokens({"pad_token": "<pad>"})

        config = LlavaConfig(
            text_config=text_config,
            vocab_size=51200,
            vision_tower_name=vision_model_id,
            projector_tokens_num=projector_tokens_num,
        )
        # Keep the text model's vocabulary size in sync with the top-level config.
        config.text_config.vocab_size = config.vocab_size

        # Instantiate on the GPU when there is one, but do not crash on CPU-only hosts.
        device = "cuda" if torch.cuda.is_available() else "cpu"
        with torch.device(device):
            model = LlavaForConditionalGeneration(config)

        # NOTE(security): `torch.load` unpickles the file, which can execute
        # arbitrary code. Only convert checkpoints from a trusted source.
        state_dict = torch.load(old_state_dict_path, map_location="cpu")
        state_dict = convert_state_dict_to_hf(state_dict)
        # strict=True makes any missing or unexpected key an error.
        model.load_state_dict(state_dict, strict=True, assign=True)

        model.save_pretrained(output_path)
        tokenizer.save_pretrained(output_path)
    finally:
        torch.set_default_dtype(previous_dtype)
def main(argv=None):
    """Parse the command-line arguments and run the checkpoint conversion.

    Args:
        argv: Optional list of argument strings. When ``None`` (the default),
            argparse reads the arguments from ``sys.argv[1:]``.

    Raises:
        SystemExit: Raised by argparse when a required option is missing or an
            argument cannot be parsed.
    """
    parser = argparse.ArgumentParser()
    # These four options are required: without them ``None`` would be passed
    # on to the conversion and fail with a far less helpful error.
    parser.add_argument(
        "--text_model_id",
        required=True,
        help="Hub location of the text model",
    )
    parser.add_argument(
        "--vision_model_id",
        required=True,
        help="Hub location of the vision model",
    )
    parser.add_argument(
        "--output_path",
        required=True,
        help="Location of the converted model",
    )
    parser.add_argument(
        "--old_state_dict_path",
        required=True,
        help="Path to the raw state dict of the original model; the file is read with `torch.load`",
    )
    parser.add_argument(
        "--tokens_num",
        type=int,
        default=1,
        help="Value passed to LlavaConfig as `projector_tokens_num` (default: %(default)s)",
    )
    args = parser.parse_args(argv)
    convert_llava_llama_to_hf(
        args.text_model_id,
        args.vision_model_id,
        args.tokens_num,
        args.output_path,
        args.old_state_dict_path,
    )
# Run the conversion only when this file is executed as a script, not when it is imported.
if __name__ == "__main__":
    main()