Yarflam commited on
Commit
94a0c52
·
1 Parent(s): 4da6292

Fix net CUDA

Browse files
Files changed (1) hide show
  1. models/networks.py +1 -1
models/networks.py CHANGED
@@ -208,7 +208,7 @@ def init_net(net, init_type='normal', init_gain=0.02, gpu_ids=[], debug=False, i
208
  """
209
  if len(gpu_ids) > 0:
210
  assert(torch.cuda.is_available())
211
- net.to(gpu_ids[0])
212
  # if not amp:
213
  # net = torch.nn.DataParallel(net, gpu_ids) # multi-GPUs for non-AMP training
214
  if initialize_weights:
 
208
  """
209
  if len(gpu_ids) > 0:
210
  assert(torch.cuda.is_available())
211
+ net.to('cuda:{}'.format(gpu_ids[0]))
212
  # if not amp:
213
  # net = torch.nn.DataParallel(net, gpu_ids) # multi-GPUs for non-AMP training
214
  if initialize_weights: