Spaces:
Paused
Paused
from io import BytesIO | |
import pickle | |
import time | |
import torch | |
from tqdm import tqdm | |
from collections import OrderedDict | |
def load_inputs(path, device, is_half=False):
    """Load a saved dict of example inputs and prepare it for a model call.

    The file at ``path`` is read with ``torch.load`` onto the CPU first, then
    every entry is moved to ``device``.  Floating-point precision is aligned
    with ``is_half``:

    * ``is_half`` true:  float32 entries are converted to float16.
    * ``is_half`` false: float16 entries are converted to float32.

    Entries with any other dtype keep their dtype.

    Returns:
        The dict produced by ``torch.load``, with its values replaced by the
        moved/converted versions.
    """
    tensors = torch.load(path, map_location=torch.device("cpu"))
    # Only one source dtype ever needs converting for a given precision mode.
    dtype_to_convert = torch.float32 if is_half else torch.float16
    for name in tensors.keys():
        value = tensors[name].to(device)
        if value.dtype == dtype_to_convert:
            value = value.half() if is_half else value.float()
        tensors[name] = value
    return tensors
def benchmark(
    model, inputs_path, device=torch.device("cpu"), epoch=1000, is_half=False
):
    """Time repeated forward passes of ``model`` and print the mean latency.

    Args:
        model: Callable invoked as ``model(**inputs)``.
        inputs_path: Path to a saved dict of inputs, read via ``load_inputs``.
        device: Device the inputs are moved to before the calls.
        epoch: Number of timed calls.
        is_half: Forwarded to ``load_inputs`` to select float16 or float32.

    Prints the number of calls and the average time per call in
    milliseconds.  When ``epoch`` is not positive no call is made and the
    reported average is ``0.0`` instead of raising ``ZeroDivisionError``.
    """
    parm = load_inputs(inputs_path, device, is_half)
    # CUDA kernels are launched asynchronously; without synchronising, the
    # wall-clock delta measures launch overhead rather than execution time.
    on_cuda = torch.device(device).type == "cuda"
    total_ts = 0.0
    for _ in tqdm(range(epoch)):
        if on_cuda:
            torch.cuda.synchronize(device)
        start_time = time.perf_counter()
        model(**parm)
        if on_cuda:
            torch.cuda.synchronize(device)
        total_ts += time.perf_counter() - start_time
    avg_ms = (total_ts * 1000) / epoch if epoch > 0 else 0.0
    print(f"num_epoch: {epoch} | avg time(ms): {avg_ms}")
def jit_warm_up(model, inputs_path, device=torch.device("cpu"), epoch=5, is_half=False):
    """Run ``model`` a few times on the saved inputs by delegating to ``benchmark``.

    Identical to ``benchmark`` except for the smaller default ``epoch`` (5).
    Presumably used to let a JIT-compiled model settle before real use; the
    timing summary printed by ``benchmark`` is still emitted.
    """
    run_options = {"epoch": epoch, "is_half": is_half}
    benchmark(model, inputs_path, device, **run_options)
def to_jit_model(
    model_path,
    model_type: str,
    mode: str = "trace",
    inputs_path: str = None,
    device=torch.device("cpu"),
    is_half=False,
):
    """Load a model of the given type and compile it with TorchScript.

    Args:
        model_path: Path handed to the model-specific loader.
        model_type: ``"synthesizer"``, ``"rmvpe"`` or ``"hubert"``
            (case-insensitive).
        mode: ``"trace"`` to use ``torch.jit.trace`` with example inputs, or
            ``"script"`` to use ``torch.jit.script``.
        inputs_path: Path to saved example inputs; required for ``"trace"``.
        device: Device the model and example inputs are placed on.
        is_half: Convert the model (and inputs) to float16 when true,
            float32 otherwise.

    Returns:
        A tuple ``(model, model_jit)`` of the eager model and its compiled
        counterpart.

    Raises:
        ValueError: If ``model_type`` or ``mode`` is not recognised, or if
            ``mode`` is ``"trace"`` and ``inputs_path`` is missing.
    """
    # Validate every argument before loading anything, so a bad call fails
    # immediately instead of after an expensive model load.
    kind = model_type.lower()
    if kind not in ("synthesizer", "rmvpe", "hubert"):
        raise ValueError(f"No model type named {model_type}")
    if mode not in ("trace", "script"):
        raise ValueError(f"No jit mode named {mode}")
    if mode == "trace" and not inputs_path:
        # Tracing needs example inputs; the original check was inverted and
        # rejected exactly the calls that supplied them.
        raise ValueError("inputs_path is required when mode is 'trace'")

    if kind == "synthesizer":
        from .get_synthesizer import get_synthesizer

        model, _ = get_synthesizer(model_path, device)
        # Compile the inference entry point rather than the default forward.
        model.forward = model.infer
    elif kind == "rmvpe":
        from .get_rmvpe import get_rmvpe

        model = get_rmvpe(model_path, device)
    else:  # kind == "hubert"
        from .get_hubert import get_hubert_model

        model = get_hubert_model(model_path, device)
        model.forward = model.infer

    model = model.eval()
    model = model.half() if is_half else model.float()

    if mode == "trace":
        inputs = load_inputs(inputs_path, device, is_half)
        model_jit = torch.jit.trace(model, example_kwarg_inputs=inputs)
    else:  # mode == "script"
        model_jit = torch.jit.script(model)

    model_jit.to(device)
    model_jit = model_jit.half() if is_half else model_jit.float()
    return (model, model_jit)
def export(
    model: torch.nn.Module,
    mode: str = "trace",
    inputs: dict = None,
    device=torch.device("cpu"),
    is_half: bool = False,
) -> dict:
    """Compile ``model`` with TorchScript and return it serialised in a dict.

    Args:
        model: Module to compile.  It is converted in place to float16 or
            float32 (per ``is_half``) and switched to eval mode.
        mode: ``"trace"`` to use ``torch.jit.trace`` with ``inputs`` as
            keyword example inputs, or ``"script"`` to use
            ``torch.jit.script``.
        inputs: Example keyword inputs; required for ``"trace"``.
        device: Device the compiled module is moved to before saving.
        is_half: Use float16 when true, float32 otherwise.

    Returns:
        An ``OrderedDict`` with ``"model"`` (the bytes written by
        ``torch.jit.save``) and ``"is_half"``.

    Raises:
        ValueError: If ``mode`` is neither ``"trace"`` nor ``"script"``.
        AssertionError: If ``mode`` is ``"trace"`` and ``inputs`` is ``None``.
    """
    if mode not in ("trace", "script"):
        # Previously an unknown mode fell through both branches and crashed
        # later with UnboundLocalError on ``model_jit``.
        raise ValueError(f"No jit mode named {mode}")

    model = model.half() if is_half else model.float()
    model.eval()

    if mode == "trace":
        assert inputs is not None, "inputs are required when mode is 'trace'"
        model_jit = torch.jit.trace(model, example_kwarg_inputs=inputs)
    else:  # mode == "script"
        model_jit = torch.jit.script(model)

    model_jit.to(device)
    model_jit = model_jit.half() if is_half else model_jit.float()

    # Serialise into memory so the caller can embed the bytes in a checkpoint.
    buffer = BytesIO()
    torch.jit.save(model_jit, buffer)
    del model_jit

    cpt = OrderedDict()
    cpt["model"] = buffer.getvalue()
    cpt["is_half"] = is_half
    return cpt
def load(path: str):
    """Read the pickled object stored at ``path`` and return it.

    NOTE: unpickling can execute arbitrary code, so this must only be used
    on files from a trusted source.
    """
    with open(path, "rb") as ckpt_file:
        ckpt = pickle.Unpickler(ckpt_file).load()
    return ckpt
def save(ckpt: dict, save_path: str):
    """Pickle ``ckpt`` to ``save_path``, overwriting any existing file."""
    with open(save_path, "wb") as out_file:
        pickle.Pickler(out_file).dump(ckpt)
def rmvpe_jit_export(
    model_path: str,
    mode: str = "script",
    inputs_path: str = None,
    save_path: str = None,
    device=torch.device("cpu"),
    is_half=False,
):
    """Compile an RMVPE model with TorchScript and pickle it to disk.

    Args:
        model_path: Path handed to ``get_rmvpe``.
        mode: ``"script"`` or ``"trace"``; forwarded to ``export``.
        inputs_path: Saved example inputs, loaded only when ``mode`` is
            ``"trace"``.
        save_path: Destination file.  When empty, it is derived from
            ``model_path`` by dropping a trailing ``".pth"`` (if present)
            and appending ``".half.jit"`` or ``".jit"`` depending on
            ``is_half``.
        device: Target device.  A CUDA device without an index is
            normalised to ``cuda:0``.
        is_half: Export in float16 when true, float32 otherwise.

    Returns:
        The checkpoint dict produced by ``export`` with an added
        ``"device"`` entry (the device as a string); the same dict is
        written to ``save_path``.
    """
    if not save_path:
        # ``str.rstrip(".pth")`` strips any trailing '.', 'p', 't' or 'h'
        # characters (e.g. "depth.pth" -> "de"), so remove the exact suffix.
        suffix = ".pth"
        stem = model_path[: -len(suffix)] if model_path.endswith(suffix) else model_path
        save_path = stem + (".half.jit" if is_half else ".jit")
    if "cuda" in str(device) and ":" not in str(device):
        device = torch.device("cuda:0")
    from .get_rmvpe import get_rmvpe

    model = get_rmvpe(model_path, device)
    inputs = None
    if mode == "trace":
        inputs = load_inputs(inputs_path, device, is_half)
    ckpt = export(model, mode, inputs, device, is_half)
    ckpt["device"] = str(device)
    save(ckpt, save_path)
    return ckpt
def synthesizer_jit_export(
    model_path: str,
    mode: str = "script",
    inputs_path: str = None,
    save_path: str = None,
    device=torch.device("cpu"),
    is_half=False,
):
    """Compile a synthesizer model with TorchScript and pickle it to disk.

    Args:
        model_path: Path handed to ``get_synthesizer``.
        mode: ``"script"`` or ``"trace"``; forwarded to ``export``.
        inputs_path: Saved example inputs, loaded only when ``mode`` is
            ``"trace"``.
        save_path: Destination file.  When empty, it is derived from
            ``model_path`` by dropping a trailing ``".pth"`` (if present)
            and appending ``".half.jit"`` or ``".jit"`` depending on
            ``is_half``.
        device: Target device.  A CUDA device without an index is
            normalised to ``cuda:0``.
        is_half: Export in float16 when true, float32 otherwise.

    Returns:
        The checkpoint dict returned by ``get_synthesizer`` with its
        ``"weight"`` entry removed, ``"model"`` replaced by the serialised
        TorchScript bytes, and ``"device"`` set to the device as a string;
        the same dict is written to ``save_path``.
    """
    if not save_path:
        # ``str.rstrip(".pth")`` strips any trailing '.', 'p', 't' or 'h'
        # characters (e.g. "depth.pth" -> "de"), so remove the exact suffix.
        suffix = ".pth"
        stem = model_path[: -len(suffix)] if model_path.endswith(suffix) else model_path
        save_path = stem + (".half.jit" if is_half else ".jit")
    if "cuda" in str(device) and ":" not in str(device):
        device = torch.device("cuda:0")
    from .get_synthesizer import get_synthesizer

    model, cpt = get_synthesizer(model_path, device)
    assert isinstance(cpt, dict)
    # Compile the inference entry point rather than the default forward.
    model.forward = model.infer
    inputs = None
    if mode == "trace":
        inputs = load_inputs(inputs_path, device, is_half)
    ckpt = export(model, mode, inputs, device, is_half)
    # The raw weights are superseded by the serialised TorchScript module.
    cpt.pop("weight")
    cpt["model"] = ckpt["model"]
    # Store the device as a string, matching rmvpe_jit_export, so the pickled
    # checkpoint always carries a plain string regardless of how ``device``
    # was passed in.
    cpt["device"] = str(device)
    save(cpt, save_path)
    return cpt