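"""Convert a RelikReader PyTorch Lightning checkpoint into a bare model and config.

The script saves the underlying RelikReaderCoreModel with ``torch.save`` and writes
its OmegaConf configuration to a YAML file, so the reader can later be loaded
without PyTorch Lightning via ``load_model_and_conf``.

Example invocation (script name and paths are illustrative):

    python convert_pl_module.py --ckpt /path/to/lightning.ckpt --output-dir /path/to/out
"""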
import argparse
import os
from typing import Tuple

import omegaconf
import torch

from relik.common.utils import from_cache
from relik.reader.lightning_modules.relik_reader_pl_module import RelikReaderPLModule
from relik.reader.relik_reader_core import RelikReaderCoreModel

CKPT_FILE_NAME = "model.ckpt"
CONFIG_FILE_NAME = "cfg.yaml"


def convert_pl_module(pl_module_ckpt_path: str, output_dir: str) -> None:
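    """Extract the bare RelikReaderCoreModel and its config from a Lightning checkpoint.

    The core model is saved with ``torch.save`` and the hyper-parameter config is
    written as an OmegaConf YAML file inside ``output_dir``.
    """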
    # Refuse to overwrite an existing output directory.
    if not os.path.exists(output_dir):
        os.makedirs(output_dir)
    else:
        print(f"{output_dir} already exists, aborting operation")
        exit(1)

    relik_pl_module: RelikReaderPLModule = RelikReaderPLModule.load_from_checkpoint(
        pl_module_ckpt_path
    )

    # Save only the wrapped core model, dropping the Lightning module around it.
    torch.save(
        relik_pl_module.relik_reader_core_model, f"{output_dir}/{CKPT_FILE_NAME}"
    )

    # Persist the training configuration alongside the model.
    with open(f"{output_dir}/{CONFIG_FILE_NAME}", "w") as f:
        omegaconf.OmegaConf.save(
            omegaconf.OmegaConf.create(relik_pl_module.hparams["cfg"]), f
        )


def load_model_and_conf(
    model_dir_path: str,
) -> Tuple[RelikReaderCoreModel, omegaconf.DictConfig]:
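    """Load the bare reader model and its OmegaConf config from ``model_dir_path``.

    The path may be a local directory or an identifier resolvable by ``from_cache``
    (e.g. a model stored on the Hugging Face hub, per the TODO below).
    """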
    # TODO: quick workaround to load the model from HF hub
    model_dir = from_cache(
        model_dir_path,
        filenames=[CKPT_FILE_NAME, CONFIG_FILE_NAME],
        cache_dir=None,
        force_download=False,
    )

    ckpt_path = f"{model_dir}/{CKPT_FILE_NAME}"
    model = torch.load(ckpt_path, map_location=torch.device("cpu"))

    model_cfg_path = f"{model_dir}/{CONFIG_FILE_NAME}"
    model_conf = omegaconf.OmegaConf.load(model_cfg_path)

    return model, model_conf


def parse_arg() -> argparse.Namespace:
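    """Build and parse the command-line arguments for the conversion script."""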
    parser = argparse.ArgumentParser()
    parser.add_argument(
        "--ckpt",
        help="Path to the PyTorch Lightning ckpt you want to convert.",
        required=True,
    )
    parser.add_argument(
        "--output-dir",
        "-o",
        help="The output dir to store the bare model and the config.",
        required=True,
    )
    return parser.parse_args()


def main():
    args = parse_arg()
    convert_pl_module(args.ckpt, args.output_dir)


if __name__ == "__main__":
    main()