Spaces:
Running
Running
File size: 2,267 Bytes
564565f |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 |
import os
import sys
from pathlib import Path

import hydra
import torch
import yaml
from omegaconf import OmegaConf
from torch import nn

from saicinpainting.training.trainers import load_checkpoint
from saicinpainting.utils import register_debug_signal_handlers


class JITWrapper(nn.Module):
    """Expose a batch-dict model through a plain ``(image, mask)`` tensor signature.

    The wrapped model is called with a single dict holding ``"image"`` and
    ``"mask"`` and returns a dict; only its ``"inpainted"`` entry is passed
    back. The positional signature is what ``torch.jit.trace`` is given as
    example inputs in this script.
    """

    def __init__(self, model):
        super().__init__()
        # When `model` is an nn.Module, this assignment registers it as a submodule,
        # so `.to()` / `.eval()` on the wrapper propagate to it.
        self.model = model

    def forward(self, image, mask):
        predictions = self.model(dict(image=image, mask=mask))
        return predictions["inpainted"]
@hydra.main(config_path="../configs/prediction", config_name="default.yaml")
def main(predict_config: OmegaConf):
    """Export a trained inpainting checkpoint to TorchScript and sanity-check it.

    Loads ``config.yaml`` and ``models/<checkpoint>`` from
    ``predict_config.model.path``. It then traces a ``JITWrapper`` around the
    model using random example inputs and saves the traced module to
    ``predict_config.save_path``. Finally it reloads that file and prints the
    summed absolute difference between the eager and TorchScript outputs.

    Tracing happens on CUDA when available, otherwise on CPU. The saved
    module may therefore be specialised to that device.

    NOTE(review): depending on the Hydra version / ``hydra.job.chdir``,
    Hydra may change the working directory before this runs. In that case,
    relative ``model.path`` / ``save_path`` values resolve against the Hydra
    run directory; consider ``hydra.utils.to_absolute_path`` — confirm.

    Args:
        predict_config: Hydra-composed config. This function reads
            ``model.path``, ``model.checkpoint`` and ``save_path``.
    """
    if sys.platform != "win32":
        # kill -10 <pid> will result in traceback dumped into log
        register_debug_signal_handlers()

    train_config_path = os.path.join(predict_config.model.path, "config.yaml")
    with open(train_config_path, "r") as f:
        train_config = OmegaConf.create(yaml.safe_load(f))

    # As the flag names suggest, build the model for prediction only and
    # without visualisation (exact effect lives in load_checkpoint's code).
    train_config.training_model.predict_only = True
    train_config.visualizer.kind = "noop"

    checkpoint_path = os.path.join(
        predict_config.model.path, "models", predict_config.model.checkpoint
    )
    model = load_checkpoint(
        train_config, checkpoint_path, strict=False, map_location="cpu"
    )
    model.eval()

    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    # The model and the example inputs must live on the same device before
    # tracing. Moving only the inputs to CUDA makes the trace fail on GPU hosts.
    jit_model_wrapper = JITWrapper(model).to(device)

    image = torch.rand(1, 3, 120, 120, device=device)
    mask = torch.rand(1, 1, 120, 120, device=device)

    with torch.no_grad():
        # The eager reference output is computed on the same device as the trace,
        # so the final comparison never mixes CPU and CUDA tensors.
        output = jit_model_wrapper(image, mask)
        traced_model = torch.jit.trace(jit_model_wrapper, (image, mask), strict=False)

    save_path = Path(predict_config.save_path)
    save_path.parent.mkdir(parents=True, exist_ok=True)

    print(f"Saving big-lama.pt model to {save_path}")
    traced_model.save(str(save_path))

    print("Checking jit model output...")
    jit_model = torch.jit.load(str(save_path), map_location=device)
    with torch.no_grad():
        jit_output = jit_model(image, mask)
    diff = (output - jit_output).abs().sum()
    print(f"diff: {diff}")
if __name__ == "__main__":
    # Called with no arguments: the @hydra.main decorator composes
    # predict_config itself from ../configs/prediction/default.yaml plus
    # any command-line overrides.
    main()
|