vilarin committed (verified)
Commit 4c6dd33 · Parent(s): 75fcca2

Update app.py

Files changed (1): app.py (+3 −3)
app.py CHANGED
@@ -97,7 +97,7 @@ class ModelWrapper:
 
         for constant in all_timesteps:
             current_timesteps = torch.ones(len(prompt_embed), device="cuda", dtype=torch.long) * constant
-            current_timesteps = current_timesteps.to(torch.float16)
+            #current_timesteps = current_timesteps.to(torch.float16)
             print(f'current_timestpes: {current_timesteps.dtype}')
             eval_images = self.model(noise, current_timesteps, prompt_embed, added_cond_kwargs=unet_added_conditions).sample
             print(type(eval_images))
@@ -135,8 +135,8 @@ class ModelWrapper:
         prompt_embeds, pooled_prompt_embeds = self.text_encoder(prompt_inputs)
 
         batch_prompt_embeds, batch_pooled_prompt_embeds = (
-            prompt_embeds.repeat(num_images, 1, 1).to(torch.float16),
-            pooled_prompt_embeds.repeat(num_images, 1, 1).to(torch.float16)
+            prompt_embeds.repeat(num_images, 1, 1),
+            pooled_prompt_embeds.repeat(num_images, 1, 1)
         )
 
         unet_added_conditions = {
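For context (not part of the commit itself): the first hunk comments out the half-precision cast, so the timestep tensor keeps the torch.long dtype it is constructed with. A minimal sketch of the dtype behavior, with stand-in values for prompt_embed and all_timesteps (shapes assumed, device argument dropped so it runs on CPU):

```python
import torch

# Stand-ins for variables from app.py; shapes are assumptions for illustration.
prompt_embed = torch.randn(2, 77, 2048)  # hypothetical batch of prompt embeddings
constant = 999                           # one hypothetical entry of all_timesteps

# As in the diff: multiplying a long tensor by a Python int keeps torch.long.
current_timesteps = torch.ones(len(prompt_embed), dtype=torch.long) * constant
print(current_timesteps.dtype)  # torch.int64

# The line the commit comments out would have down-cast the timesteps:
print(current_timesteps.to(torch.float16).dtype)  # torch.float16
```

With the cast removed, the print in the hunk reports integer timesteps, matching the dtype on the construction line above.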
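The second hunk drops the explicit float16 casts after repeat(). A small sketch (tensor shapes and dtype here are assumptions, not taken from app.py) showing that Tensor.repeat preserves the input dtype, so after this change the batched embeddings keep whatever precision self.text_encoder produced:

```python
import torch

num_images = 4
# Assumed shapes/dtype; in app.py these come from self.text_encoder(prompt_inputs).
prompt_embeds = torch.randn(1, 77, 2048, dtype=torch.float16)
pooled_prompt_embeds = torch.randn(1, 1, 1280, dtype=torch.float16)

# repeat() tiles the leading dimension and leaves the source dtype unchanged.
batch_prompt_embeds = prompt_embeds.repeat(num_images, 1, 1)
batch_pooled_prompt_embeds = pooled_prompt_embeds.repeat(num_images, 1, 1)
print(batch_prompt_embeds.shape, batch_prompt_embeds.dtype)
# torch.Size([4, 77, 2048]) torch.float16
```

Since repeat() never changes dtype, the removed .to(torch.float16) calls only mattered when the encoder ran in a different precision; after the commit, the encoder's native dtype flows through unchanged.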