vilarin commited on
Commit
d79f1a1
·
verified ·
1 Parent(s): d290faa

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +3 -3
app.py CHANGED
@@ -96,13 +96,13 @@ class ModelWrapper:
96
  current_timesteps = torch.ones(len(prompt_embed), device="cuda", dtype=torch.long) * constant
97
  current_timesteps = current_timesteps.to(torch.float16)
98
  print(f'current_timestpes: {current_timesteps.dtype}')
99
- eval_images = self.model(noise, current_timesteps, prompt_embed, added_cond_kwargs=unet_added_conditions).sample
100
  print(type(eval_images))
101
 
102
  eval_images = get_x0_from_noise(noise, eval_images, alphas_cumprod, current_timesteps).to(self.DTYPE)
103
 
104
  next_timestep = current_timesteps - step_interval
105
- noise = self.scheduler.add_noise(eval_images, torch.randn_like(eval_images), next_timestep).to(DTYPE)
106
 
107
  if fast_vae_decode:
108
  eval_images = self.tiny_vae.decode(eval_images.to(self.tiny_vae_dtype) / self.tiny_vae.config.scaling_factor, return_dict=False)[0]
@@ -138,7 +138,7 @@ class ModelWrapper:
138
 
139
  unet_added_conditions = {
140
  "time_ids": add_time_ids,
141
- "text_embeds": batch_pooled_prompt_embeds.squeeze(1)
142
  }
143
 
144
 
 
96
  current_timesteps = torch.ones(len(prompt_embed), device="cuda", dtype=torch.long) * constant
97
  current_timesteps = current_timesteps.to(torch.float16)
98
  print(f'current_timestpes: {current_timesteps.dtype}')
99
+ eval_images = self.model(noise.to(torch.float16), current_timesteps, prompt_embed.to(torch.float16), added_cond_kwargs=unet_added_conditions).sample
100
  print(type(eval_images))
101
 
102
  eval_images = get_x0_from_noise(noise, eval_images, alphas_cumprod, current_timesteps).to(self.DTYPE)
103
 
104
  next_timestep = current_timesteps - step_interval
105
+ noise = self.scheduler.add_noise(eval_images, torch.randn_like(eval_images), next_timestep).to(torch.float16)
106
 
107
  if fast_vae_decode:
108
  eval_images = self.tiny_vae.decode(eval_images.to(self.tiny_vae_dtype) / self.tiny_vae.config.scaling_factor, return_dict=False)[0]
 
138
 
139
  unet_added_conditions = {
140
  "time_ids": add_time_ids,
141
+ "text_embeds": batch_pooled_prompt_embeds.squeeze(1).to(torch.float16)
142
  }
143
 
144