import os from pathlib import Path import hydra import torch import yaml from omegaconf import OmegaConf from torch import nn from saicinpainting.training.trainers import load_checkpoint from saicinpainting.utils import register_debug_signal_handlers class JITWrapper(nn.Module): def __init__(self, model): super().__init__() self.model = model def forward(self, image, mask): batch = { "image": image, "mask": mask } out = self.model(batch) return out["inpainted"] @hydra.main(config_path="../configs/prediction", config_name="default.yaml") def main(predict_config: OmegaConf): if sys.platform != 'win32': register_debug_signal_handlers() # kill -10 will result in traceback dumped into log train_config_path = os.path.join(predict_config.model.path, "config.yaml") with open(train_config_path, "r") as f: train_config = OmegaConf.create(yaml.safe_load(f)) train_config.training_model.predict_only = True train_config.visualizer.kind = "noop" checkpoint_path = os.path.join( predict_config.model.path, "models", predict_config.model.checkpoint ) model = load_checkpoint( train_config, checkpoint_path, strict=False, map_location="cpu" ) model.eval() jit_model_wrapper = JITWrapper(model) image = torch.rand(1, 3, 120, 120) mask = torch.rand(1, 1, 120, 120) output = jit_model_wrapper(image, mask) if torch.cuda.is_available(): device = torch.device("cuda") else: device = torch.device("cpu") image = image.to(device) mask = mask.to(device) traced_model = torch.jit.trace(jit_model_wrapper, (image, mask), strict=False).to(device) save_path = Path(predict_config.save_path) save_path.parent.mkdir(parents=True, exist_ok=True) print(f"Saving big-lama.pt model to {save_path}") traced_model.save(save_path) print(f"Checking jit model output...") jit_model = torch.jit.load(str(save_path)) jit_output = jit_model(image, mask) diff = (output - jit_output).abs().sum() print(f"diff: {diff}") if __name__ == "__main__": main()