Spaces:
Running
on
A10G
Running
on
A10G
File size: 3,273 Bytes
69e8a46 28c720a 69e8a46 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 |
import shutil
from copy import deepcopy
from pathlib import Path
import click
import hydra
import torch
from hydra import compose, initialize
from hydra.utils import instantiate
from loguru import logger
from fish_speech.models.text2semantic.llama import BaseTransformer
from fish_speech.models.text2semantic.lora import get_merged_state_dict
def _unwrap_checkpoint(state_dict: dict) -> dict:
    """Return the bare parameter mapping held by a loaded checkpoint.

    If the mapping has a top-level ``"state_dict"`` entry, that entry is used.
    Then, if any key starts with ``"model."``, only the keys carrying that
    prefix are kept and the prefix is removed from the front of each of them.

    Args:
        state_dict: A mapping as returned by ``torch.load`` or
            ``nn.Module.state_dict()``.

    Returns:
        The unwrapped mapping (may be the input object itself if nothing
        needed changing).
    """
    if "state_dict" in state_dict:
        state_dict = state_dict["state_dict"]

    prefix = "model."
    if any(k.startswith(prefix) for k in state_dict):
        # Strip only the leading prefix; str.replace would also rewrite any
        # later "model." substring inside a key name.
        state_dict = {
            k.removeprefix(prefix): v
            for k, v in state_dict.items()
            if k.startswith(prefix)
        }
    return state_dict


def _validate_merged_checkpoint(checkpoint_path: Path, original_state_dict: dict) -> None:
    """Sanity-check a saved merged checkpoint against the base-model weights.

    Exits the process with status 1 (after logging an error) when the saved
    key set differs from ``original_state_dict`` or when every tensor is
    identical to the original, i.e. nothing was merged in.

    Args:
        checkpoint_path: File written by ``save_pretrained`` to read back.
        original_state_dict: Independent copy of the base weights taken before
            the LoRA weights were loaded.
    """
    # This is the file the script itself just wrote.
    new_state_dict = torch.load(checkpoint_path, map_location="cpu")

    original_keys = set(original_state_dict)
    merged_keys = set(new_state_dict)
    if original_keys != merged_keys:
        # Explicit check rather than `assert`, which `python -O` strips out.
        logger.error(
            f"Keys should be same: {len(original_keys - merged_keys)} missing, "
            f"{len(merged_keys - original_keys)} unexpected"
        )
        raise SystemExit(1)

    # Stops at the first tensor whose L1 difference is non-zero.
    changed = any(
        (new_state_dict[key] - original_state_dict[key]).abs().sum().item() != 0
        for key in original_keys
    )
    if not changed:
        logger.error("Merged model is same as the original model")
        # SystemExit instead of the site-injected `exit`, which is not
        # available under `python -S` or in frozen executables.
        raise SystemExit(1)

    logger.info("Merged model is different from the original model, check passed")


@click.command()
@click.option("--lora-config", type=str, default="r_8_alpha_16")
@click.option("--base-weight", type=str, default="checkpoints/fish-speech-1.4")
@click.option("--lora-weight", type=str, required=True)
@click.option("--output", type=str, required=True)
def merge(lora_config, base_weight, lora_weight, output):
    """Fold a trained LoRA checkpoint into a base model and save the result.

    Options:
        --lora-config: config name looked up under the hydra config path
            ``../../fish_speech/configs/lora``.
        --base-weight: path handed to ``BaseTransformer.from_pretrained``.
        --lora-weight: checkpoint file with the trained weights to merge.
        --output: directory handed to ``save_pretrained``; the merged weights
            are read back from ``<output>/model.pth`` for validation.

    Exits with status 1 if the saved checkpoint's keys differ from the base
    model's, or if no weight changed relative to the base model.
    """
    output = Path(output)
    logger.info(
        f"Merging {base_weight} and {lora_weight} into {output} with {lora_config}"
    )

    with initialize(version_base="1.3", config_path="../../fish_speech/configs/lora"):
        cfg = compose(config_name=lora_config)

    # Keep the CLI string and the instantiated config object under separate names.
    lora_cfg = instantiate(cfg)
    logger.info(f"Loaded lora model with config {lora_cfg}")

    llama_model = BaseTransformer.from_pretrained(
        path=base_weight,
        load_weights=True,
        lora_config=lora_cfg,
    )
    logger.info("Loaded llama model")

    # Model weights excluding any entry whose name contains "lora".
    llama_state_dict = {
        k: v for k, v in llama_model.state_dict().items() if "lora" not in k
    }
    # state_dict() returns references to the live parameters, which
    # load_state_dict()/eval() below may overwrite; keep an independent copy
    # as the reference for validation.
    llama_state_dict_copy = deepcopy(llama_state_dict)

    # NOTE(review): torch.load unpickles arbitrary objects; only pass trusted
    # --lora-weight files.
    lora_state_dict = torch.load(lora_weight, map_location="cpu")

    llama_state_dict = _unwrap_checkpoint(llama_state_dict)
    lora_state_dict = _unwrap_checkpoint(lora_state_dict)

    logger.info(f"Found {len(llama_state_dict)} keys in llama model")
    logger.info(f"Found {len(lora_state_dict)} keys in lora model")

    # On overlapping keys the LoRA checkpoint's values win.
    merged_state_dict = llama_state_dict | lora_state_dict
    llama_model.load_state_dict(merged_state_dict, strict=True)
    logger.info("Merged model loaded")

    # Trigger eval mode to merge lora
    llama_model.eval()
    llama_model.save_pretrained(output, drop_lora=True)
    logger.info(f"Saved merged model to {output}, validating")

    _validate_merged_checkpoint(output / "model.pth", llama_state_dict_copy)
if __name__ == "__main__":
    # `merge` is a click command: it parses sys.argv itself, so no arguments
    # are passed here.
    merge()  # pylint: disable=no-value-for-parameter
|