from collections import OrderedDict import torch from .layers.synthesizers import SynthesizerTrnMsNSFsid from .jit import load_inputs, export_jit_model, save_pickle def get_synthesizer(cpt: OrderedDict, device=torch.device("cpu")): cpt["config"][-3] = cpt["weight"]["emb_g.weight"].shape[0] if_f0 = cpt.get("f0", 1) version = cpt.get("version", "v1") if version == "v1": encoder_dim = 256 elif version == "v2": encoder_dim = 768 net_g = SynthesizerTrnMsNSFsid( *cpt["config"], encoder_dim=encoder_dim, use_f0=if_f0 == 1, ) del net_g.enc_q net_g.load_state_dict(cpt["weight"], strict=False) net_g = net_g.float() net_g.eval().to(device) net_g.remove_weight_norm() return net_g, cpt def load_synthesizer( pth_path: torch.serialization.FILE_LIKE, device=torch.device("cpu") ): return get_synthesizer( torch.load(pth_path, map_location=torch.device("cpu")), device, ) def synthesizer_jit_export( model_path: str, mode: str = "script", inputs_path: str = None, save_path: str = None, device=torch.device("cpu"), is_half=False, ): if not save_path: save_path = model_path.rstrip(".pth") save_path += ".half.jit" if is_half else ".jit" if "cuda" in str(device) and ":" not in str(device): device = torch.device("cuda:0") from rvc.synthesizer import load_synthesizer model, cpt = load_synthesizer(model_path, device) assert isinstance(cpt, dict) model.forward = model.infer inputs = None if mode == "trace": inputs = load_inputs(inputs_path, device, is_half) ckpt = export_jit_model(model, mode, inputs, device, is_half) cpt.pop("weight") cpt["model"] = ckpt["model"] cpt["device"] = device save_pickle(cpt, save_path) return cpt