vilarin commited on
Commit
9e4e479
·
verified ·
1 Parent(s): b87d3cd

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +1 -0
app.py CHANGED
@@ -94,6 +94,7 @@ class ModelWrapper:
94
 
95
  for constant in all_timesteps:
96
  current_timesteps = torch.ones(len(prompt_embed), device="cuda", dtype=torch.long) * constant
 
97
  print(current_timesteps.dtype)
98
  eval_images = self.model(noise, current_timesteps, prompt_embed, added_cond_kwargs=unet_added_conditions).sample
99
  print(type(eval_images))
 
94
 
95
  for constant in all_timesteps:
96
  current_timesteps = torch.ones(len(prompt_embed), device="cuda", dtype=torch.long) * constant
97
+ current_timesteps = current_timesteps.to(torch.float32)
98
  print(current_timesteps.dtype)
99
  eval_images = self.model(noise, current_timesteps, prompt_embed, added_cond_kwargs=unet_added_conditions).sample
100
  print(type(eval_images))