Spaces:
Runtime error
Runtime error
Update model.py
Browse files
model.py
CHANGED
@@ -1121,15 +1121,24 @@ class StreamMultiDiffusion(nn.Module):
|
|
1121 |
else:
|
1122 |
x_t_latent_plus_uc = x_t_latent # (T * p, 4, h, w)
|
1123 |
|
1124 |
-
|
1125 |
-
|
1126 |
-
|
1127 |
-
|
1128 |
-
|
1129 |
-
|
1130 |
-
|
1131 |
-
|
1132 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1133 |
print('222222222222222', model_pred.dtype)
|
1134 |
|
1135 |
if self.bootstrap_steps[0] > 0:
|
|
|
1121 |
else:
|
1122 |
x_t_latent_plus_uc = x_t_latent # (T * p, 4, h, w)
|
1123 |
|
1124 |
+
ns = []
|
1125 |
+
c1, c2, c3 = 0, 0, 0
|
1126 |
+
for n, p in self.unet.named_parameters():
|
1127 |
+
if p.data.dtype == torch.float:
|
1128 |
+
c1 += 1
|
1129 |
+
ns.append(n)
|
1130 |
+
elif p.data.dtype == torch.half:
|
1131 |
+
c2 += 1
|
1132 |
+
else:
|
1133 |
+
c3 += 1
|
1134 |
+
print(c1, c2, c3)
|
1135 |
+
print(ns)
|
1136 |
+
model_pred = self.unet(
|
1137 |
+
x_t_latent_plus_uc.to(self.unet.dtype), # (B, 4, h, w)
|
1138 |
+
t_list, # (B,)
|
1139 |
+
encoder_hidden_states=self.prompt_embeds, # (B, 77, 768)
|
1140 |
+
return_dict=False,
|
1141 |
+
)[0] # (B, 4, h, w)
|
1142 |
print('222222222222222', model_pred.dtype)
|
1143 |
|
1144 |
if self.bootstrap_steps[0] > 0:
|