import torch
import torchvision
from torchvision import transforms
from PIL import Image
import matplotlib.pyplot as plt
import numpy as np
from models.modelNetA import Generator as GA
from models.modelNetB import Generator as GB
from models.modelNetC import Generator as GC
# DEVICE='cpu'
DEVICE='cuda'
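# Select which generator variant (and matching checkpoint) to run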
model_type = 'model_c'
modeltype2path = {
    'model_a': 'DTM_exp_train10%_model_a/g-best.pth',
    'model_b': 'DTM_exp_train10%_model_b/g-best.pth',
    'model_c': 'DTM_exp_train10%_model_c/g-best.pth',
}
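# Instantiate the generator architecture that corresponds to the selected checkpoint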
if model_type == 'model_a':
    generator = GA()
elif model_type == 'model_b':
    generator = GB()
elif model_type == 'model_c':
    generator = GC()
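# The checkpoint keys are prefixed with 'module.', so wrap in DataParallel, load the weights, then unwrap via .module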
generator = torch.nn.DataParallel(generator)
state_dict_Gen = torch.load(modeltype2path[model_type], map_location=torch.device('cpu'))
generator.load_state_dict(state_dict_Gen)
generator = generator.module.to(DEVICE)
# generator.to(DEVICE)
generator.eval()
preprocess = transforms.Compose([
    transforms.Grayscale(),
    # transforms.Resize((128, 128)),
    transforms.ToTensor()
])
input_img = Image.open('demo_imgs/fake.jpg')
torch_img = preprocess(input_img).unsqueeze(0).to(DEVICE)
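# Min-max normalize the input tensor to the [0, 1] range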
torch_img = (torch_img - torch.min(torch_img)) / (torch.max(torch_img) - torch.min(torch_img))
with torch.no_grad():
    output = generator(torch_img)
    sr, sr_dem_selected = output[0], output[1]
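# Drop the batch dimension, move the super-resolved image to the CPU, and save it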
sr = sr.squeeze(0).cpu()
print(sr.shape)
torchvision.utils.save_image(sr, 'sr.png')
# sr = Image.fromarray(sr.squeeze(0).detach().numpy() * 255, 'L')
# sr.save('sr2.png')
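# Move the selected DEM prediction to the CPU as a numpy array for plotting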
sr_dem_selected = sr_dem_selected.squeeze().cpu().detach().numpy()
print(sr_dem_selected.shape)
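# Render the DEM prediction as a heat map and save the figure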
plt.imshow(sr_dem_selected, cmap='jet', vmin=0, vmax=np.max(sr_dem_selected))
plt.colorbar()
plt.savefig('test.png')