pqt commited on
Commit
26d60e1
1 Parent(s): 5bdf4bb

srgan: enable model to run with gpu

Browse files
Files changed (1) hide show
  1. models/SRGAN/srgan.py +3 -2
models/SRGAN/srgan.py CHANGED
@@ -75,12 +75,13 @@ class GeneratorResnet(nn.Module):
75
 
76
  if __name__ == '__main__':
77
  current_dir = os.path.dirname(os.path.realpath(__file__))
78
-
79
  model = GeneratorResnet()
80
- model = torch.load(current_dir + '/srgan_checkpoint.pth', map_location=torch.device('cpu'))
81
  model.eval()
82
  with torch.no_grad():
83
  input_image = Image.open('images/demo.png')
84
  input_image = ToTensor()(input_image).unsqueeze(0)
 
85
  output_image = model.test(input_image)
86
  print(output_image.max())
 
75
 
76
  if __name__ == '__main__':
77
  current_dir = os.path.dirname(os.path.realpath(__file__))
78
+ DEVICE = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
79
  model = GeneratorResnet()
80
+ model = torch.load(current_dir + '/srgan_checkpoint.pth', map_location=torch.device('cpu')).to(DEVICE)
81
  model.eval()
82
  with torch.no_grad():
83
  input_image = Image.open('images/demo.png')
84
  input_image = ToTensor()(input_image).unsqueeze(0)
85
+ input_image = input_image.to(DEVICE)
86
  output_image = model.test(input_image)
87
  print(output_image.max())