Spaces:
Paused
Paused
Fix net CUDA
Browse files- models/networks.py +1 -1
models/networks.py
CHANGED
@@ -208,7 +208,7 @@ def init_net(net, init_type='normal', init_gain=0.02, gpu_ids=[], debug=False, i
|
|
208 |
"""
|
209 |
if len(gpu_ids) > 0:
|
210 |
assert(torch.cuda.is_available())
|
211 |
-
net.to(gpu_ids[0])
|
212 |
# if not amp:
|
213 |
# net = torch.nn.DataParallel(net, gpu_ids) # multi-GPUs for non-AMP training
|
214 |
if initialize_weights:
|
|
|
208 |
"""
|
209 |
if len(gpu_ids) > 0:
|
210 |
assert(torch.cuda.is_available())
|
211 |
+
net.to('cuda:{}'.format(gpu_ids[0]))
|
212 |
# if not amp:
|
213 |
# net = torch.nn.DataParallel(net, gpu_ids) # multi-GPUs for non-AMP training
|
214 |
if initialize_weights:
|