Spaces:
Running
Running
srgan: enable model to run with gpu
Browse files- models/SRGAN/srgan.py +3 -2
models/SRGAN/srgan.py
CHANGED
@@ -75,12 +75,13 @@ class GeneratorResnet(nn.Module):
|
|
75 |
|
76 |
if __name__ == '__main__':
|
77 |
current_dir = os.path.dirname(os.path.realpath(__file__))
|
78 |
-
|
79 |
model = GeneratorResnet()
|
80 |
-
model = torch.load(current_dir + '/srgan_checkpoint.pth', map_location=torch.device('cpu'))
|
81 |
model.eval()
|
82 |
with torch.no_grad():
|
83 |
input_image = Image.open('images/demo.png')
|
84 |
input_image = ToTensor()(input_image).unsqueeze(0)
|
|
|
85 |
output_image = model.test(input_image)
|
86 |
print(output_image.max())
|
|
|
75 |
|
76 |
if __name__ == '__main__':
|
77 |
current_dir = os.path.dirname(os.path.realpath(__file__))
|
78 |
+
DEVICE = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
|
79 |
model = GeneratorResnet()
|
80 |
+
model = torch.load(current_dir + '/srgan_checkpoint.pth', map_location=torch.device('cpu')).to(DEVICE)
|
81 |
model.eval()
|
82 |
with torch.no_grad():
|
83 |
input_image = Image.open('images/demo.png')
|
84 |
input_image = ToTensor()(input_image).unsqueeze(0)
|
85 |
+
input_image = input_image.to(DEVICE)
|
86 |
output_image = model.test(input_image)
|
87 |
print(output_image.max())
|