vilarin commited on
Commit
d290faa
·
verified ·
1 Parent(s): 9e4e479

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +9 -8
app.py CHANGED
@@ -90,12 +90,12 @@ class ModelWrapper:
90
  raise NotImplementedError()
91
 
92
  DTYPE = prompt_embed.dtype
93
- print(DTYPE)
94
 
95
  for constant in all_timesteps:
96
  current_timesteps = torch.ones(len(prompt_embed), device="cuda", dtype=torch.long) * constant
97
- current_timesteps = current_timesteps.to(torch.float32)
98
- print(current_timesteps.dtype)
99
  eval_images = self.model(noise, current_timesteps, prompt_embed, added_cond_kwargs=unet_added_conditions).sample
100
  print(type(eval_images))
101
 
@@ -123,7 +123,7 @@ class ModelWrapper:
123
 
124
  add_time_ids = self.build_condition_input(height, width).repeat(num_images, 1)
125
 
126
- noise = torch.randn(num_images, 4, height // self.vae_downsample_ratio, width // self.vae_downsample_ratio, generator=generator).to(device="cuda", dtype=torch.float32)
127
 
128
  prompt_inputs = self._encode_prompt(prompt)
129
 
@@ -142,9 +142,10 @@ class ModelWrapper:
142
  }
143
 
144
 
145
- print(noise.dtype)
146
- print(batch_prompt_embeds.dtype)
147
-
 
148
 
149
  eval_images = self.sample(noise=noise, unet_added_conditions=unet_added_conditions, prompt_embed=batch_prompt_embeds, fast_vae_decode=fast_vae_decode)
150
 
@@ -165,7 +166,7 @@ def get_x0_from_noise(sample, model_output, alphas_cumprod, timestep):
165
  return pred_original_sample
166
 
167
  class SDXLTextEncoder(torch.nn.Module):
168
- def __init__(self, model_id, revision, accelerator, dtype=torch.float32):
169
  super().__init__()
170
 
171
  self.text_encoder_one = CLIPTextModel.from_pretrained(model_id, subfolder="text_encoder", revision=revision).to(0).to(dtype=dtype)
 
90
  raise NotImplementedError()
91
 
92
  DTYPE = prompt_embed.dtype
93
+ print(f'prompt_embed: {DTYPE}')
94
 
95
  for constant in all_timesteps:
96
  current_timesteps = torch.ones(len(prompt_embed), device="cuda", dtype=torch.long) * constant
97
+ current_timesteps = current_timesteps.to(torch.float16)
98
+ print(f'current_timestpes: {current_timesteps.dtype}')
99
  eval_images = self.model(noise, current_timesteps, prompt_embed, added_cond_kwargs=unet_added_conditions).sample
100
  print(type(eval_images))
101
 
 
123
 
124
  add_time_ids = self.build_condition_input(height, width).repeat(num_images, 1)
125
 
126
+ noise = torch.randn(num_images, 4, height // self.vae_downsample_ratio, width // self.vae_downsample_ratio, generator=generator).to(device="cuda", dtype=torch.float16)
127
 
128
  prompt_inputs = self._encode_prompt(prompt)
129
 
 
142
  }
143
 
144
 
145
+ print(f'noise: {noise.dtype}')
146
+ print(f'prompt: {batch_prompt_embeds.dtype}')
147
+ print(unet_added_conditions['time_ids'].dtype)
148
+ print(unet_added_conditions['text_embeds'].dtype)
149
 
150
  eval_images = self.sample(noise=noise, unet_added_conditions=unet_added_conditions, prompt_embed=batch_prompt_embeds, fast_vae_decode=fast_vae_decode)
151
 
 
166
  return pred_original_sample
167
 
168
  class SDXLTextEncoder(torch.nn.Module):
169
+ def __init__(self, model_id, revision, accelerator, dtype=torch.float16):
170
  super().__init__()
171
 
172
  self.text_encoder_one = CLIPTextModel.from_pretrained(model_id, subfolder="text_encoder", revision=revision).to(0).to(dtype=dtype)