Spanicin commited on
Commit
687b854
·
verified ·
1 Parent(s): 1e1159f

Update src/facerender/modules/make_animation.py

Browse files
src/facerender/modules/make_animation.py CHANGED
@@ -148,10 +148,15 @@ def make_animation(source_image, source_semantics, target_semantics,
148
  yaw_c_seq=None, pitch_c_seq=None, roll_c_seq=None,
149
  use_exp=True):
150
 
151
- device='cuda'
152
- generator = torch.nn.DataParallel(generator).to(device)
153
- kp_detector = torch.nn.DataParallel(kp_detector).to(device)
154
- mapping = torch.nn.DataParallel(mapping).to(device)
 
 
 
 
 
155
 
156
  source_image = source_image.to(device)
157
  source_semantics = source_semantics.to(device)
 
148
  yaw_c_seq=None, pitch_c_seq=None, roll_c_seq=None,
149
  use_exp=True):
150
 
151
+ device = 'cuda:0'
152
+ generator = generator.to(device)
153
+ kp_detector = kp_detector.to(device)
154
+ mapping = mapping.to(device)
155
+
156
+ # Wrap the models in DataParallel to use all available GPUs
157
+ generator = torch.nn.DataParallel(generator)
158
+ kp_detector = torch.nn.DataParallel(kp_detector)
159
+ mapping = torch.nn.DataParallel(mapping)
160
 
161
  source_image = source_image.to(device)
162
  source_semantics = source_semantics.to(device)