Spaces:
Build error
Build error
from collections import OrderedDict | |
from text.symbols import symbols | |
import torch | |
from tools.log import logger | |
import utils | |
from models import SynthesizerTrn | |
import os | |
def copyStateDict(state_dict):
    """Return a copy of *state_dict* with a leading ``module`` key component removed.

    If the first key starts with ``"module"``, the first dot-separated component
    is stripped from every key (presumably a prefix added by a
    ``DataParallel``-style wrapper -- verify against how checkpoints are saved).
    Otherwise keys are copied unchanged. Values are not copied, only re-referenced.

    Args:
        state_dict: Mapping from string keys to arbitrary values.

    Returns:
        An ``OrderedDict`` with the same values, in the same order, under the
        (possibly) shortened keys. An empty input yields an empty ``OrderedDict``.
    """
    if not state_dict:
        # Avoid IndexError on the first-key probe below.
        return OrderedDict()
    # Only the first key is inspected; the prefix is assumed to be uniform
    # across all keys.
    first_key = next(iter(state_dict))
    start_idx = 1 if first_key.startswith("module") else 0
    new_state_dict = OrderedDict()
    for k, v in state_dict.items():
        # Re-join with "." so nested names such as "enc_p.proj.weight" stay intact.
        name = ".".join(k.split(".")[start_idx:])
        new_state_dict[name] = v
    return new_state_dict
def removeOptimizer(config: str, input_model: str, ishalf: bool, output_model: str):
    """Write a slimmed-down copy of a training checkpoint.

    Loads the checkpoint at *input_model*, keeps only its ``"model"`` weights
    (dropping every key that contains ``"enc_q"``), optionally casts the
    remaining floating-point tensors to FP16, and saves the result to
    *output_model* together with a freshly constructed optimizer state,
    ``iteration`` 0 and a fixed learning rate. The optimizer state stored in
    the input checkpoint is not carried over.

    Args:
        config: Path to the hyper-parameter file read by
            ``utils.get_hparams_from_file``.
        input_model: Path of the checkpoint to read; it must contain a
            ``"model"`` entry.
        ishalf: If true, cast floating-point weights to ``torch.float16``.
        output_model: Destination path passed to ``torch.save``.
    """
    hps = utils.get_hparams_from_file(config)
    # The network is built only so that an optimizer with matching parameter
    # groups can be created; its weights are never saved. Storing that
    # optimizer's untouched state_dict presumably keeps the output loadable by
    # code that expects an "optimizer" entry -- verify against the loader.
    net_g = SynthesizerTrn(
        len(symbols),
        hps.data.filter_length // 2 + 1,
        hps.train.segment_size // hps.data.hop_length,
        n_speakers=hps.data.n_speakers,
        **hps.model,
    )
    optim_g = torch.optim.AdamW(
        net_g.parameters(),
        hps.train.learning_rate,
        betas=hps.train.betas,
        eps=hps.train.eps,
    )
    # NOTE(security): torch.load unpickles the file; only run this on
    # checkpoints from a trusted source.
    checkpoint = copyStateDict(torch.load(input_model, map_location="cpu"))

    model_weights = {}
    for key, value in checkpoint["model"].items():
        # Presumably the posterior encoder, only needed for training -- TODO confirm.
        if "enc_q" in key:
            continue
        # Only floating-point tensors are halved: .half() on integer/bool
        # buffers would silently (and lossily) turn them into float16.
        if ishalf and isinstance(value, torch.Tensor) and value.is_floating_point():
            value = value.half()
        model_weights[key] = value

    torch.save(
        {
            "model": model_weights,
            "iteration": 0,
            "optimizer": optim_g.state_dict(),
            # NOTE(review): hard-coded instead of hps.train.learning_rate -- confirm intent.
            "learning_rate": 0.0001,
        },
        output_model,
    )
if __name__ == "__main__":
    # Command-line entry point: compress a checkpoint via removeOptimizer().
    import argparse

    parser = argparse.ArgumentParser()
    parser.add_argument("-c", "--config", type=str, default="configs/config.json")
    # Required: the default output path is derived from it, and a missing value
    # would otherwise fail later with an opaque TypeError / torch.load error.
    parser.add_argument("-i", "--input", type=str, required=True)
    parser.add_argument("-o", "--output", type=str, default=None)
    parser.add_argument(
        "-hf", "--half", action="store_true", default=False, help="Save as FP16"
    )
    args = parser.parse_args()

    output = args.output
    if output is None:
        # e.g. "G_1000.pth" -> "G_1000_release.pth" ("G_1000_release_half.pth" with --half).
        # `os` comes from the file-level import.
        filename, ext = os.path.splitext(args.input)
        half = "_half" if args.half else ""
        output = filename + "_release" + half + ext

    removeOptimizer(args.config, args.input, args.half, output)
    # Log message (Chinese): "Model compressed successfully, output model: <path>".
    logger.info(f"压缩模型成功, 输出模型: {os.path.abspath(output)}")