|
import pickle |
|
from io import BytesIO |
|
from collections import OrderedDict |
|
import os |
|
|
|
import torch |
|
|
|
|
|
def load_pickle(path: str):
    """Read *path* in binary mode and return the object unpickled from it.

    Security note: unpickling can execute arbitrary code, so this must only be
    used on files this program produced itself, never on untrusted input.
    """
    with open(path, mode="rb") as handle:
        obj = pickle.load(handle)
    return obj
|
|
|
|
|
def save_pickle(ckpt: dict, save_path: str):
    """Pickle *ckpt* into *save_path*, creating or truncating that file."""
    with open(save_path, mode="wb") as sink:
        pickle.dump(ckpt, sink)
|
|
|
|
|
def load_inputs(path: torch.serialization.FILE_LIKE, device: str, is_half=False):
    """Load a saved mapping of tensors and place each one on *device*.

    The file is deserialized onto the CPU first, then every value is moved to
    *device*. Floating-point precision is normalized to the requested width:
    float32 values become float16 when *is_half* is true, and float16 values
    become float32 when it is false. Values of any other dtype are only moved.
    The loaded mapping is updated in place and returned.

    NOTE(review): assumes every value in the loaded mapping is a tensor.
    """
    cpu = torch.device("cpu")
    tensors = torch.load(path, map_location=cpu)

    # Only the float width that disagrees with the requested precision is cast.
    if is_half:
        source_dtype, cast = torch.float32, torch.Tensor.half
    else:
        source_dtype, cast = torch.float16, torch.Tensor.float

    for name in tensors.keys():
        value = tensors[name].to(device)
        tensors[name] = cast(value) if value.dtype == source_dtype else value
    return tensors
|
|
|
|
|
def export_jit_model(
    model: torch.nn.Module,
    mode: str = "trace",
    inputs: dict = None,
    device=torch.device("cpu"),
    is_half: bool = False,
) -> dict:
    """Compile *model* to TorchScript and return it serialized in a checkpoint dict.

    Args:
        model: Module to export. It is converted with ``.half()`` or
            ``.float()`` (nn.Module methods that convert the module in place)
            and switched to eval mode.
        mode: ``"trace"`` to compile with ``torch.jit.trace`` using *inputs*
            as keyword example inputs, or ``"script"`` to compile with
            ``torch.jit.script``.
        inputs: Keyword example inputs for tracing. Required when
            ``mode == "trace"`` and ignored otherwise.
        device: Device the compiled module is moved to before serialization.
        is_half: Export in float16 when true, float32 otherwise.

    Returns:
        An ``OrderedDict`` with ``"model"`` (the bytes written by
        ``torch.jit.save``) and ``"is_half"``.

    Raises:
        ValueError: If *mode* is not ``"trace"`` or ``"script"``, or if
            tracing is requested without *inputs*.
    """
    # Validate before touching the model. Previously an unknown mode fell
    # through and failed with UnboundLocalError on ``model_jit``, and a bare
    # assert (stripped under ``python -O``) guarded the missing-inputs case.
    if mode not in ("trace", "script"):
        raise ValueError(f"mode must be 'trace' or 'script', got {mode!r}")
    if mode == "trace" and inputs is None:
        raise ValueError("inputs are required when mode='trace'")

    model = model.half() if is_half else model.float()
    model.eval()
    if mode == "trace":
        model_jit = torch.jit.trace(model, example_kwarg_inputs=inputs)
    else:
        model_jit = torch.jit.script(model)
    model_jit.to(device)
    # Re-apply the precision on the compiled module before saving.
    model_jit = model_jit.half() if is_half else model_jit.float()

    buffer = BytesIO()
    torch.jit.save(model_jit, buffer)
    # Drop the compiled module reference before building the result.
    del model_jit

    cpt = OrderedDict()
    cpt["model"] = buffer.getvalue()
    cpt["is_half"] = is_half
    return cpt
|
|
|
|
|
def get_jit_model(model_path: str, is_half: bool, device: str, exporter):
    """Return a TorchScript checkpoint for *model_path*, reusing a cached one if valid.

    The cache path is *model_path* with a trailing ``.pth`` suffix removed, plus
    ``.half.jit`` (when *is_half*) or ``.jit``. A cached file is reused only
    when its ``"device"`` entry equals ``str(device)``. Otherwise, including
    when the file is missing or has no ``"device"`` entry, *exporter* is called
    in ``"script"`` mode and its return value is returned.

    Args:
        model_path: Path to the source model file.
        is_half: Whether the half-precision variant is wanted.
        device: Target device, compared as ``str(device)`` with the cache entry.
        exporter: Callable invoked with the keyword arguments ``model_path``,
            ``mode``, ``inputs_path``, ``save_path``, ``device`` and ``is_half``.
            This function does not save anything itself. Presumably the
            exporter writes the cache file to ``save_path`` (confirm with callers).

    Returns:
        The checkpoint loaded from the cache, or the one returned by *exporter*.
    """
    # str.rstrip(".pth") strips any trailing '.', 'p', 't' or 'h' characters
    # ("hubert.pth" -> "huber"), so remove the exact suffix instead.
    suffix = ".pth"
    if model_path.endswith(suffix):
        base_path = model_path[: -len(suffix)]
    else:
        base_path = model_path
    jit_model_path = base_path + (".half.jit" if is_half else ".jit")

    ckpt = None
    if os.path.exists(jit_model_path):
        # NOTE: unpickling can run arbitrary code; this trusts the cache file.
        cached = load_pickle(jit_model_path)
        # .get: a cache without a "device" entry is re-exported, not a KeyError.
        if cached.get("device") == str(device):
            ckpt = cached

    if ckpt is None:
        ckpt = exporter(
            model_path=model_path,
            mode="script",
            inputs_path=None,
            save_path=jit_model_path,
            device=device,
            is_half=is_half,
        )
    return ckpt
|
|