vilarin commited on
Commit
b87d3cd
·
verified ·
1 Parent(s): b20a23b

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +2 -2
app.py CHANGED
@@ -89,12 +89,12 @@ class ModelWrapper:
89
  else:
90
  raise NotImplementedError()
91
 
92
- prompt_embed = prompt_embed.to(dtype=torch.float16)
93
  DTYPE = prompt_embed.dtype
94
  print(DTYPE)
95
 
96
  for constant in all_timesteps:
97
  current_timesteps = torch.ones(len(prompt_embed), device="cuda", dtype=torch.long) * constant
 
98
  eval_images = self.model(noise, current_timesteps, prompt_embed, added_cond_kwargs=unet_added_conditions).sample
99
  print(type(eval_images))
100
 
@@ -122,7 +122,7 @@ class ModelWrapper:
122
 
123
  add_time_ids = self.build_condition_input(height, width).repeat(num_images, 1)
124
 
125
- noise = torch.randn(num_images, 4, height // self.vae_downsample_ratio, width // self.vae_downsample_ratio, generator=generator).to(device="cuda", dtype=torch.float16)
126
 
127
  prompt_inputs = self._encode_prompt(prompt)
128
 
 
89
  else:
90
  raise NotImplementedError()
91
 
 
92
  DTYPE = prompt_embed.dtype
93
  print(DTYPE)
94
 
95
  for constant in all_timesteps:
96
  current_timesteps = torch.ones(len(prompt_embed), device="cuda", dtype=torch.long) * constant
97
+ print(current_timesteps.dtype)
98
  eval_images = self.model(noise, current_timesteps, prompt_embed, added_cond_kwargs=unet_added_conditions).sample
99
  print(type(eval_images))
100
 
 
122
 
123
  add_time_ids = self.build_condition_input(height, width).repeat(num_images, 1)
124
 
125
+ noise = torch.randn(num_images, 4, height // self.vae_downsample_ratio, width // self.vae_downsample_ratio, generator=generator).to(device="cuda", dtype=torch.float32)
126
 
127
  prompt_inputs = self._encode_prompt(prompt)
128