"""Run a GLPDepth checkpoint on one demo image and save the predicted map.

The script loads ``best_model.ckpt``, feeds ``demo_imgs/fake.jpg`` (resized
to 512x512) through the network, prints the shape of the prediction and
writes a colour-mapped rendering of it to ``test.png``.
"""
from collections import OrderedDict

import matplotlib.pyplot as plt
import numpy as np
import torch
from PIL import Image
from torchvision import transforms

from models.model import GLPDepth

DEVICE = 'cpu'

# Key prefix stripped from checkpoint entries before loading. Presumably it
# comes from saving a model wrapped in nn.DataParallel /
# DistributedDataParallel -- confirm against the training code.
_PARALLEL_PREFIX = 'module.'


def load_mde_model(path) -> GLPDepth:
    """Build a GLPDepth network and load weights from a checkpoint file.

    Args:
        path: Location of the checkpoint, passed straight to ``torch.load``.
            The loaded object must be a mapping with a
            ``'model_state_dict'`` entry.

    Returns:
        The model placed on ``DEVICE`` and switched to evaluation mode.

    Raises:
        KeyError: If the checkpoint has no ``'model_state_dict'`` entry.
        RuntimeError: Propagated from ``load_state_dict`` when the stored
            keys do not match the model's parameters.
    """
    model = GLPDepth(max_depth=700.0, is_train=False).to(DEVICE)

    # NOTE(security): torch.load unpickles the file, which can execute
    # arbitrary code. Only load checkpoints from a trusted source.
    checkpoint = torch.load(path, map_location=DEVICE)
    model_weight = checkpoint['model_state_dict']

    # Strip the prefix per key, and only when it really is a leading
    # 'module.'. A plain substring test on the first key followed by an
    # unconditional k[7:] slice would mangle keys that merely contain the
    # word "module" somewhere, or that are not prefixed at all.
    model_weight = OrderedDict(
        (k[len(_PARALLEL_PREFIX):] if k.startswith(_PARALLEL_PREFIX) else k, v)
        for k, v in model_weight.items()
    )

    model.load_state_dict(model_weight)
    model.eval()
    return model


model = load_mde_model('best_model.ckpt')

# Resize to a fixed 512x512 input and convert the PIL image to a float tensor.
preprocess = transforms.Compose([
    transforms.Resize((512, 512)),
    transforms.ToTensor(),
])

# ToTensor keeps whatever channel count the file has, so force RGB to avoid
# feeding a 1-channel (grayscale) or 4-channel (RGBA/CMYK) tensor.
# Assumes the network expects 3-channel input -- TODO confirm.
input_img = Image.open('demo_imgs/fake.jpg').convert('RGB')

# unsqueeze(0) adds the leading batch dimension.
torch_img = preprocess(input_img).to(DEVICE).unsqueeze(0)

# Inference only: no autograd graph is needed.
with torch.no_grad():
    output_patch = model(torch_img)

# The model returns a mapping; the prediction is stored under 'pred_d'.
output_patch = output_patch['pred_d'].squeeze().cpu().detach().numpy()
print(output_patch.shape)

# Colour scale runs from 0 up to the largest predicted value.
plt.imshow(output_patch, cmap='jet', vmin=0, vmax=np.max(output_patch))
plt.colorbar()
plt.savefig('test.png')