Spanicin committed
Commit a60f0f9 · verified · 1 Parent(s): f6d22ad

Update src/facerender/modules/make_animation.py

src/facerender/modules/make_animation.py CHANGED
@@ -145,50 +145,25 @@ def keypoint_transformation(kp_canonical, he, wo_exp=False):
     # return predictions_ts
 
 import torch
-from torch.cuda import CUDAGraph
+from torch.cuda.amp import autocast
 
 def make_animation(source_image, source_semantics, target_semantics,
                    generator, kp_detector, he_estimator, mapping,
                    yaw_c_seq=None, pitch_c_seq=None, roll_c_seq=None,
-                   use_exp=True, device='cuda'):
+                   use_exp=True):
+
+    device = 'cuda'
+    # Move inputs to GPU
     source_image = source_image.to(device)
     source_semantics = source_semantics.to(device)
     target_semantics = target_semantics.to(device)
 
-    # Prepare for CUDA Graph capture
-    with torch.no_grad():
+    with torch.no_grad():  # No gradients needed
         predictions = []
         kp_canonical = kp_detector(source_image)
         he_source = mapping(source_semantics)
         kp_source = keypoint_transformation(kp_canonical, he_source)
 
-        # Use a non-default CUDA stream for graph capture
-        capture_stream = torch.cuda.Stream()
-        graph = CUDAGraph()
-
-        # Warm-up to ensure proper graph capturing
-        torch.cuda.synchronize()
-
-        with torch.cuda.stream(capture_stream):
-            target_semantics_frame = target_semantics[:, 0]
-            he_driving = mapping(target_semantics_frame)
-
-            if yaw_c_seq is not None:
-                he_driving['yaw_in'] = yaw_c_seq[:, 0]
-            if pitch_c_seq is not None:
-                he_driving['pitch_in'] = pitch_c_seq[:, 0]
-            if roll_c_seq is not None:
-                he_driving['roll_in'] = roll_c_seq[:, 0]
-
-            kp_driving = keypoint_transformation(kp_canonical, he_driving)
-            kp_norm = kp_driving
-
-            # Begin capturing the graph
-            graph.capture_begin()
-            out = generator(source_image, kp_source=kp_source, kp_driving=kp_norm)
-            graph.capture_end()
-
-        # Execute the graph on the default stream
         for frame_idx in range(target_semantics.shape[1]):
             target_semantics_frame = target_semantics[:, frame_idx]
             he_driving = mapping(target_semantics_frame)
@@ -203,22 +178,22 @@ def make_animation(source_image, source_semantics, target_semantics,
             kp_driving = keypoint_transformation(kp_canonical, he_driving)
             kp_norm = kp_driving
 
-            # Replay the captured graph
-            with torch.cuda.stream(torch.cuda.current_stream()):
+            # Use mixed precision for faster computation
+            with autocast():
                 out = generator(source_image, kp_source=kp_source, kp_driving=kp_norm)
 
             predictions.append(out['prediction'])
-
-        # Optional: Explicitly synchronize if needed
+
+        # Optional: Explicitly synchronize (use only if necessary)
        torch.cuda.synchronize()
 
+        # Stack predictions into a single tensor
         predictions_ts = torch.stack(predictions, dim=1)
 
     return predictions_ts
 
 
 
-
 class AnimateModel(torch.nn.Module):
     """
     Merge all generator related updates into single model for better multi-gpu usage
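
The net effect of the commit is to drop the hand-rolled CUDA Graph capture/replay path and instead run the per-frame generator call under torch.no_grad() with torch.cuda.amp.autocast() mixed precision, with device now hard-coded to 'cuda' inside the function rather than passed as a keyword argument. For readers unfamiliar with that pattern, below is a minimal, self-contained sketch of autocast inference; the DummyGenerator module, tensor shapes, and variable names are illustrative placeholders, not part of SadTalker's API:

    import torch
    from torch.cuda.amp import autocast  # same import the commit adds

    class DummyGenerator(torch.nn.Module):
        """Placeholder stand-in for the real generator network."""
        def __init__(self):
            super().__init__()
            self.conv = torch.nn.Conv2d(3, 3, kernel_size=3, padding=1)

        def forward(self, x):
            return self.conv(x)

    device = 'cuda' if torch.cuda.is_available() else 'cpu'
    model = DummyGenerator().to(device).eval()
    frames = torch.randn(8, 3, 256, 256, device=device)  # fake per-frame inputs

    predictions = []
    with torch.no_grad():  # inference only, no gradients needed
        for idx in range(frames.shape[0]):
            # run the forward pass in mixed precision on GPU; disabled on CPU
            with autocast(enabled=(device == 'cuda')):
                out = model(frames[idx:idx + 1])
            predictions.append(out.float())  # cast back to fp32 before stacking

    predictions_ts = torch.stack(predictions, dim=1)

Two notes on the sketch: recent PyTorch releases spell the same context manager torch.amp.autocast('cuda'), and torch.cuda.amp.autocast, while still functional, emits a deprecation warning there. Also, tensors produced under autocast may be float16, so casting back to float32 before stacking keeps the stacked result's dtype consistent with a non-autocast run.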