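"""
Convert an original LLaVA-style checkpoint into the Hugging Face `LlavaForConditionalGeneration`
format defined by the local `configuration_llava.py` and `modeling_llava.py`.

Example invocation (the script name, model IDs and paths are illustrative placeholders):

    python convert_llava_weights_to_hf.py \
        --text_model_id <text-model-hub-id> \
        --vision_model_id <vision-tower-hub-id> \
        --old_state_dict_path ./model_state_dict.bin \
        --output_path ./llava-hf \
        --tokens_num 1
"""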
import argparse

import torch
from transformers import (
    AddedToken,
    AutoConfig,
    AutoTokenizer,
)

from configuration_llava import LlavaConfig
from modeling_llava import LlavaForConditionalGeneration


KEYS_TO_MODIFY_MAPPING = {
    "transformer.vision_tower.vision_tower": "vision_model",
    "transformer.mm_projector": "multi_modal_projector",
    "transformer": "language_model.transformer",
    "lm_head": "language_model.lm_head",
    "model.model": "language_model.transformer",
    "multi_modal_projector.0": "multi_modal_projector.linear_1",
    "multi_modal_projector.2": "multi_modal_projector.linear_2",
}
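# Example: a checkpoint key such as "transformer.mm_projector.0.weight" is first renamed to
# "multi_modal_projector.0.weight" and then to "multi_modal_projector.linear_1.weight"
# by convert_state_dict_to_hf below (the mapping is applied in insertion order).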


def convert_state_dict_to_hf(state_dict):
    """Rename keys of the original state dict to the names expected by LlavaForConditionalGeneration."""
    new_state_dict = {}
    for key, value in state_dict.items():
        for key_to_modify, new_key in KEYS_TO_MODIFY_MAPPING.items():
            if key_to_modify in key:
                key = key.replace(key_to_modify, new_key)

        new_state_dict[key] = value
    return new_state_dict


def convert_llava_llama_to_hf(text_model_id, vision_model_id, projector_tokens_num, output_path, old_state_dict_path):
    # Create the model (and all new tensors) in float16.
    torch.set_default_dtype(torch.float16)
    text_config = AutoConfig.from_pretrained(text_model_id, trust_remote_code=True)

    tokenizer = AutoTokenizer.from_pretrained(text_model_id)
    tokenizer.add_tokens(AddedToken("<image>", special=True, normalized=False), special_tokens=True)
    tokenizer.add_special_tokens({"pad_token": "<pad>"})

    config = LlavaConfig(
        text_config=text_config,
        vocab_size=51200,
        vision_tower_name=vision_model_id,
        projector_tokens_num=projector_tokens_num,
    )
    config.text_config.vocab_size = config.vocab_size

    # Instantiate the model on the GPU; its parameters are replaced below via `assign=True`.
    with torch.device("cuda"):
        model = LlavaForConditionalGeneration(config)

    # Load the original weights, rename them, and assign them to the model.
    state_dict = torch.load(old_state_dict_path, map_location="cpu")
    state_dict = convert_state_dict_to_hf(state_dict)
    model.load_state_dict(state_dict, strict=True, assign=True)

    model.save_pretrained(output_path)
    tokenizer.save_pretrained(output_path)


def main():
    parser = argparse.ArgumentParser()
    parser.add_argument(
        "--text_model_id",
        help="Hub location of the text model",
    )
    parser.add_argument(
        "--vision_model_id",
        help="Hub location of the vision model",
    )
    parser.add_argument(
        "--output_path",
        help="Location to save the converted model",
    )
    parser.add_argument(
        "--old_state_dict_path",
        help="Local path to the raw state dict of the original model, e.g. `model_state_dict.bin`",
    )
    parser.add_argument(
        "--tokens_num",
        type=int,
        default=1,
        help="Number of projector tokens, passed to LlavaConfig as `projector_tokens_num`",
    )
    args = parser.parse_args()
    convert_llava_llama_to_hf(
        args.text_model_id, args.vision_model_id, args.tokens_num, args.output_path, args.old_state_dict_path
    )


if __name__ == "__main__":
    main()