vilarin commited on
Commit
47c81d5
·
verified ·
1 Parent(s): 4131205

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +3 -1
app.py CHANGED
@@ -94,6 +94,8 @@ class ModelWrapper:
94
 
95
  for constant in all_timesteps:
96
  current_timesteps = torch.ones(len(prompt_embed), device="cuda", dtype=torch.long) * constant
 
 
97
  eval_images = self.model(noise, current_timesteps, prompt_embed, added_cond_kwargs=unet_added_conditions).sample
98
  print(type(eval_images))
99
 
@@ -158,7 +160,7 @@ def get_x0_from_noise(sample, model_output, alphas_cumprod, timestep):
158
  return pred_original_sample
159
 
160
  class SDXLTextEncoder(torch.nn.Module):
161
- def __init__(self, model_id, revision, accelerator, dtype=torch.float16):
162
  super().__init__()
163
 
164
  self.text_encoder_one = CLIPTextModel.from_pretrained(model_id, subfolder="text_encoder", revision=revision).to(0).to(dtype=dtype)
 
94
 
95
  for constant in all_timesteps:
96
  current_timesteps = torch.ones(len(prompt_embed), device="cuda", dtype=torch.long) * constant
97
+ print(type(current_timesteps))
98
+ print(type(noise))
99
  eval_images = self.model(noise, current_timesteps, prompt_embed, added_cond_kwargs=unet_added_conditions).sample
100
  print(type(eval_images))
101
 
 
160
  return pred_original_sample
161
 
162
  class SDXLTextEncoder(torch.nn.Module):
163
+ def __init__(self, model_id, revision, accelerator, dtype=torch.float32):
164
  super().__init__()
165
 
166
  self.text_encoder_one = CLIPTextModel.from_pretrained(model_id, subfolder="text_encoder", revision=revision).to(0).to(dtype=dtype)