Spanicin committed on
Commit
3008a78
·
verified ·
1 Parent(s): c9c856d

Update src/facerender/modules/make_animation.py

Browse files
src/facerender/modules/make_animation.py CHANGED
@@ -99,94 +99,94 @@ def keypoint_transformation(kp_canonical, he, wo_exp=False):
99
  return {'value': kp_transformed}
100
 
101
 
102
def make_animation(source_image, source_semantics, target_semantics,
                   generator, kp_detector, he_estimator, mapping,
                   yaw_c_seq=None, pitch_c_seq=None, roll_c_seq=None,
                   use_exp=True):
    """Render one generator prediction per driving frame and stack them.

    Canonical keypoints are detected once from ``source_image``. For every
    index along dimension 1 of ``target_semantics`` the mapped driving
    parameters are turned into keypoints and passed to ``generator``
    together with the source keypoints.

    Args:
        source_image: Input handed to ``kp_detector`` and ``generator``.
        source_semantics: Input to ``mapping`` describing the source.
        target_semantics: Driving inputs; dimension 1 is iterated frame by
            frame.
        generator: Callable returning a mapping with a ``'prediction'`` entry.
        kp_detector: Callable producing the canonical keypoints.
        he_estimator: Not used by this function; kept for the call signature.
        mapping: Callable turning semantics into the dict consumed by
            ``keypoint_transformation``.
        yaw_c_seq: Optional per-frame override stored under ``'yaw_in'``.
        pitch_c_seq: Optional per-frame override stored under ``'pitch_in'``.
        roll_c_seq: Optional per-frame override stored under ``'roll_in'``.
        use_exp: Not used by this function; kept for the call signature.

    Returns:
        The per-frame ``'prediction'`` tensors stacked along a new
        dimension 1.
    """
    # (key written into the driving dict, optional per-frame sequence)
    pose_overrides = (
        ('yaw_in', yaw_c_seq),
        ('pitch_in', pitch_c_seq),
        ('roll_in', roll_c_seq),
    )

    with torch.no_grad():
        frames = []

        canonical_kp = kp_detector(source_image)
        source_kp = keypoint_transformation(canonical_kp, mapping(source_semantics))

        for t in tqdm(range(target_semantics.shape[1]), 'Face Renderer:'):
            driving_he = mapping(target_semantics[:, t])
            for key, sequence in pose_overrides:
                if sequence is not None:
                    driving_he[key] = sequence[:, t]

            # Keypoint normalisation is not applied: the driving keypoints
            # go to the generator unchanged.
            driving_kp = keypoint_transformation(canonical_kp, driving_he)
            rendered = generator(source_image, kp_source=source_kp, kp_driving=driving_kp)

            frames.append(rendered['prediction'])
            torch.cuda.empty_cache()

        stacked = torch.stack(frames, dim=1)
    return stacked
142
-
143
- # import torch
144
- # from torch.cuda.amp import autocast
145
-
146
  # def make_animation(source_image, source_semantics, target_semantics,
147
- # generator, kp_detector, he_estimator, mapping,
148
- # yaw_c_seq=None, pitch_c_seq=None, roll_c_seq=None,
149
- # use_exp=True):
150
-
151
- # # device='cuda'
152
- # # # Move inputs to GPU
153
- # # source_image = source_image.to(device)
154
- # # source_semantics = source_semantics.to(device)
155
- # # target_semantics = target_semantics.to(device)
156
-
157
- # with torch.no_grad(): # No gradients needed
158
  # predictions = []
 
159
  # kp_canonical = kp_detector(source_image)
160
  # he_source = mapping(source_semantics)
161
  # kp_source = keypoint_transformation(kp_canonical, he_source)
162
 
163
- # for frame_idx in tqdm(range(target_semantics.shape[1]), desc='Face Renderer:', unit='frame'):
 
164
  # target_semantics_frame = target_semantics[:, frame_idx]
165
  # he_driving = mapping(target_semantics_frame)
166
-
167
  # if yaw_c_seq is not None:
168
  # he_driving['yaw_in'] = yaw_c_seq[:, frame_idx]
169
  # if pitch_c_seq is not None:
170
- # he_driving['pitch_in'] = pitch_c_seq[:, frame_idx]
171
  # if roll_c_seq is not None:
172
- # he_driving['roll_in'] = roll_c_seq[:, frame_idx]
173
-
174
  # kp_driving = keypoint_transformation(kp_canonical, he_driving)
 
 
 
175
  # kp_norm = kp_driving
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
176
 
177
- # # Use mixed precision for faster computation
178
- # with autocast():
179
- # out = generator(source_image, kp_source=kp_source, kp_driving=kp_norm)
 
 
180
 
181
- # predictions.append(out['prediction'])
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
182
 
183
- # # Optional: Explicitly synchronize (use only if necessary)
184
- # torch.cuda.synchronize()
185
 
186
- # # Stack predictions into a single tensor
187
- # predictions_ts = torch.stack(predictions, dim=1)
188
 
189
- # return predictions_ts
190
 
191
 
192
  class AnimateModel(torch.nn.Module):
 
99
  return {'value': kp_transformed}
100
 
101
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
102
  # def make_animation(source_image, source_semantics, target_semantics,
103
+ # generator, kp_detector, he_estimator, mapping,
104
+ # yaw_c_seq=None, pitch_c_seq=None, roll_c_seq=None,
105
+ # use_exp=True):
106
+ # with torch.no_grad():
 
 
 
 
 
 
 
107
  # predictions = []
108
+
109
  # kp_canonical = kp_detector(source_image)
110
  # he_source = mapping(source_semantics)
111
  # kp_source = keypoint_transformation(kp_canonical, he_source)
112
 
113
+
114
+ # for frame_idx in tqdm(range(target_semantics.shape[1]), 'Face Renderer:'):
115
  # target_semantics_frame = target_semantics[:, frame_idx]
116
  # he_driving = mapping(target_semantics_frame)
 
117
  # if yaw_c_seq is not None:
118
  # he_driving['yaw_in'] = yaw_c_seq[:, frame_idx]
119
  # if pitch_c_seq is not None:
120
+ # he_driving['pitch_in'] = pitch_c_seq[:, frame_idx]
121
  # if roll_c_seq is not None:
122
+ # he_driving['roll_in'] = roll_c_seq[:, frame_idx]
123
+
124
  # kp_driving = keypoint_transformation(kp_canonical, he_driving)
125
+
126
+ # #kp_norm = normalize_kp(kp_source=kp_source, kp_driving=kp_driving,
127
+ # #kp_driving_initial=kp_driving_initial)
128
  # kp_norm = kp_driving
129
+ # out = generator(source_image, kp_source=kp_source, kp_driving=kp_norm)
130
+ # '''
131
+ # source_image_new = out['prediction'].squeeze(1)
132
+ # kp_canonical_new = kp_detector(source_image_new)
133
+ # he_source_new = he_estimator(source_image_new)
134
+ # kp_source_new = keypoint_transformation(kp_canonical_new, he_source_new, wo_exp=True)
135
+ # kp_driving_new = keypoint_transformation(kp_canonical_new, he_driving, wo_exp=True)
136
+ # out = generator(source_image_new, kp_source=kp_source_new, kp_driving=kp_driving_new)
137
+ # '''
138
+ # predictions.append(out['prediction'])
139
+ # torch.cuda.empty_cache()
140
+ # predictions_ts = torch.stack(predictions, dim=1)
141
+ # return predictions_ts
142
+
143
+ import torch
144
+ from torch.cuda.amp import autocast
145
+
146
def make_animation(source_image, source_semantics, target_semantics,
                   generator, kp_detector, he_estimator, mapping,
                   yaw_c_seq=None, pitch_c_seq=None, roll_c_seq=None,
                   use_exp=True):
    """Render one generator prediction per driving frame and stack them.

    Canonical keypoints are detected once from ``source_image``. For every
    index along dimension 1 of ``target_semantics`` the mapped driving
    parameters are turned into keypoints and passed to ``generator``
    together with the source keypoints. The generator call runs under CUDA
    autocast when CUDA is available.

    Args:
        source_image: Input handed to ``kp_detector`` and ``generator``.
        source_semantics: Input to ``mapping`` describing the source.
        target_semantics: Driving inputs; dimension 1 is iterated frame by
            frame.
        generator: Callable returning a mapping with a ``'prediction'`` entry.
        kp_detector: Callable producing the canonical keypoints.
        he_estimator: Not used by this function; kept for the call signature.
        mapping: Callable turning semantics into the dict consumed by
            ``keypoint_transformation``.
        yaw_c_seq: Optional per-frame override stored under ``'yaw_in'``.
        pitch_c_seq: Optional per-frame override stored under ``'pitch_in'``.
        roll_c_seq: Optional per-frame override stored under ``'roll_in'``.
        use_exp: Not used by this function; kept for the call signature.

    Returns:
        The per-frame ``'prediction'`` tensors stacked along a new
        dimension 1.
    """
    # CUDA-specific calls are gated on availability so the function also
    # runs on machines without a GPU (torch.cuda.synchronize() raises there).
    cuda_ready = torch.cuda.is_available()

    with torch.no_grad():  # inference only, no gradients needed
        predictions = []
        kp_canonical = kp_detector(source_image)
        he_source = mapping(source_semantics)
        kp_source = keypoint_transformation(kp_canonical, he_source)

        for frame_idx in tqdm(range(target_semantics.shape[1]), desc='Face Renderer:', unit='frame'):
            target_semantics_frame = target_semantics[:, frame_idx]
            he_driving = mapping(target_semantics_frame)

            if yaw_c_seq is not None:
                he_driving['yaw_in'] = yaw_c_seq[:, frame_idx]
            if pitch_c_seq is not None:
                he_driving['pitch_in'] = pitch_c_seq[:, frame_idx]
            if roll_c_seq is not None:
                he_driving['roll_in'] = roll_c_seq[:, frame_idx]

            kp_driving = keypoint_transformation(kp_canonical, he_driving)
            kp_norm = kp_driving

            # Mixed precision only when CUDA is present; requesting CUDA
            # autocast without a GPU just emits a warning and disables itself.
            with autocast(enabled=cuda_ready):
                out = generator(source_image, kp_source=kp_source, kp_driving=kp_norm)

            predictions.append(out['prediction'])

        # One synchronisation after all frames are queued is sufficient;
        # skipped entirely when there is no CUDA device.
        if cuda_ready:
            torch.cuda.synchronize()

        # Stack predictions into a single tensor
        predictions_ts = torch.stack(predictions, dim=1)

    return predictions_ts
190
 
191
 
192
  class AnimateModel(torch.nn.Module):