|
import argparse
import inspect
import os
import sys
from typing import Tuple

import omegaconf
import torch

from relik.common.utils import from_cache
from relik.reader.lightning_modules.relik_reader_pl_module import RelikReaderPLModule
from relik.reader.relik_reader_core import RelikReaderCoreModel
|
|
|
# File name (inside a converted model directory) of the serialized bare reader
# model written by convert_pl_module and read back by load_model_and_conf.
CKPT_FILE_NAME: str = "model.ckpt"
# File name (inside the same directory) of the YAML dump of the training config.
CONFIG_FILE_NAME: str = "cfg.yaml"
|
|
|
|
|
def convert_pl_module(pl_module_ckpt_path: str, output_dir: str) -> None:
    """Convert a PyTorch Lightning checkpoint into a bare reader model plus config.

    Loads ``pl_module_ckpt_path`` as a ``RelikReaderPLModule`` and writes two
    files into a newly created ``output_dir``:

    * ``CKPT_FILE_NAME``: the module's ``relik_reader_core_model`` attribute,
      serialized with ``torch.save``;
    * ``CONFIG_FILE_NAME``: the module's ``hparams["cfg"]``, dumped as YAML
      through OmegaConf.

    Args:
        pl_module_ckpt_path: Path to the Lightning checkpoint to convert.
        output_dir: Directory to create. It must not exist yet.

    Raises:
        SystemExit: With code 1 if ``output_dir`` already exists.
    """

    def _abort_existing() -> None:
        # Report on stderr and exit with status 1, the same SystemExit(1) that
        # the bare exit(1) builtin raised, but without relying on `site`.
        print(f"{output_dir} already exists, aborting operation", file=sys.stderr)
        sys.exit(1)

    # Fail fast before the (potentially slow) checkpoint load.
    if os.path.exists(output_dir):
        _abort_existing()

    relik_pl_module: RelikReaderPLModule = RelikReaderPLModule.load_from_checkpoint(
        pl_module_ckpt_path
    )

    # Create the directory only after the checkpoint loaded, so a bad checkpoint
    # path does not leave behind an empty directory that makes every retry
    # abort. makedirs() without exist_ok also catches a directory that appeared
    # after the check above.
    try:
        os.makedirs(output_dir)
    except FileExistsError:
        _abort_existing()

    torch.save(
        relik_pl_module.relik_reader_core_model,
        os.path.join(output_dir, CKPT_FILE_NAME),
    )

    # Explicit UTF-8 so the YAML dump does not depend on the platform encoding.
    config_path = os.path.join(output_dir, CONFIG_FILE_NAME)
    with open(config_path, "w", encoding="utf-8") as f:
        omegaconf.OmegaConf.save(
            omegaconf.OmegaConf.create(relik_pl_module.hparams["cfg"]), f
        )
|
|
|
|
|
def load_model_and_conf(
    model_dir_path: str,
) -> Tuple[RelikReaderCoreModel, omegaconf.DictConfig]:
    """Load a bare reader model and its config from a converted model directory.

    ``model_dir_path`` is resolved with ``from_cache``, which is asked for both
    ``CKPT_FILE_NAME`` and ``CONFIG_FILE_NAME``. The model is then unpickled
    onto the CPU, and the YAML config is loaded with OmegaConf. This is the
    file layout written by ``convert_pl_module``.

    Args:
        model_dir_path: Location of the converted model, in any form that
            ``from_cache`` accepts.

    Returns:
        A ``(model, config)`` tuple.
    """
    model_dir = from_cache(
        model_dir_path,
        filenames=[CKPT_FILE_NAME, CONFIG_FILE_NAME],
        cache_dir=None,
        force_download=False,
    )

    ckpt_path = os.path.join(model_dir, CKPT_FILE_NAME)
    load_kwargs = {"map_location": torch.device("cpu")}
    # The checkpoint is a whole pickled module, not a state_dict. Since torch
    # 2.6, torch.load defaults to weights_only=True, which rejects such files,
    # so opt out explicitly when the parameter exists. torch < 1.13 has no
    # such parameter and always unpickles fully.
    # SECURITY: a full unpickle can execute arbitrary code. Only load
    # checkpoints from trusted sources.
    if "weights_only" in inspect.signature(torch.load).parameters:
        load_kwargs["weights_only"] = False
    model = torch.load(ckpt_path, **load_kwargs)

    model_cfg_path = os.path.join(model_dir, CONFIG_FILE_NAME)
    model_conf = omegaconf.OmegaConf.load(model_cfg_path)

    return model, model_conf
|
|
|
|
|
def parse_arg() -> argparse.Namespace:
    """Parse the conversion CLI arguments from ``sys.argv``.

    Both options are required: ``--ckpt`` (the Lightning checkpoint to convert)
    and ``--output-dir`` / ``-o`` (where the bare model and config are written).
    """
    arg_parser = argparse.ArgumentParser()

    # (flags, help text) for each required option, in registration order.
    option_specs = (
        (
            ("--ckpt",),
            "Path to the pytorch lightning ckpt you want to convert.",
        ),
        (
            ("--output-dir", "-o"),
            "The output dir to store the bare models and the config.",
        ),
    )
    for flags, help_text in option_specs:
        arg_parser.add_argument(*flags, help=help_text, required=True)

    return arg_parser.parse_args()
|
|
|
|
|
def main():
    """Command-line entry point: parse arguments and run the checkpoint conversion."""
    cli_args = parse_arg()
    convert_pl_module(
        pl_module_ckpt_path=cli_args.ckpt,
        output_dir=cli_args.output_dir,
    )
|
|
|
|
|
# Run the conversion CLI only when this file is executed as a script, not on import.
if __name__ == "__main__":
    main()
|
|