vilarin committed (verified)
Commit f0841ba · Parent: c15af9a

Update app.py

Files changed (1): app.py (+7 -5)
app.py CHANGED
@@ -77,7 +77,7 @@ class ModelWrapper:
     def _get_time():
         return time.time()
 
-    @spaces.GPU()
+    @spaces.GPU(duration=100)
     def sample(self, noise, unet_added_conditions, prompt_embed, fast_vae_decode):
         alphas_cumprod = self.scheduler.alphas_cumprod.to(self.device)
 
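Context for this hunk: on ZeroGPU Spaces, the spaces.GPU decorator accepts an optional duration argument in seconds, so duration=100 requests a longer GPU allocation per sample() call than the default. A minimal sketch of the pattern, assuming the spaces package of a Hugging Face Space (the function heavy_step is illustrative, not from this repo):

import spaces
import torch

@spaces.GPU(duration=100)  # request up to 100 s of GPU time per call
def heavy_step(x: torch.Tensor) -> torch.Tensor:
    # On ZeroGPU, CUDA is only attached while a GPU-decorated function
    # is running, so device work belongs inside this body.
    return (x.to("cuda") * 2).cpu()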
@@ -89,13 +89,15 @@ class ModelWrapper:
             step_interval = 250
         else:
             raise NotImplementedError()
-
+
+        noise = noise.to(device="cuda", dtype=torch.float16)
+        print(f'noise: {noise.dtype}')
         DTYPE = prompt_embed.dtype
         print(f'prompt_embed: {DTYPE}')
 
         for constant in all_timesteps:
             current_timesteps = torch.ones(len(prompt_embed), device="cuda", dtype=torch.long) * constant
-            current_timesteps = current_timesteps.to(torch.float16)
+            #current_timesteps = current_timesteps.to(torch.float16)
             print(f'current_timestpes: {current_timesteps.dtype}')
             eval_images = self.model(noise, current_timesteps, prompt_embed, added_cond_kwargs=unet_added_conditions).sample
             print(type(eval_images))
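Two related changes land in this hunk. First, noise is now moved to CUDA and cast to float16 inside sample(), i.e. inside the GPU-decorated scope. Second, the float16 cast of current_timesteps is commented out, so the timesteps keep the torch.long dtype they were created with; an integer dtype is what makes them usable as indices into tensors such as alphas_cumprod. A small sketch of the failure mode this revert avoids (values are illustrative, not from the repo):

import torch

alphas_cumprod = torch.linspace(0.999, 0.001, 1000)
timesteps = torch.full((4,), 999, dtype=torch.long)

a_t = alphas_cumprod[timesteps]       # fine: long tensors can be indices
# alphas_cumprod[timesteps.half()]    # IndexError: floating-point tensors
#                                     # cannot be used as indices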
@@ -124,7 +126,7 @@ class ModelWrapper:
 
         add_time_ids = self.build_condition_input(height, width).repeat(num_images, 1)
 
-        noise = torch.randn(num_images, 4, height // self.vae_downsample_ratio, width // self.vae_downsample_ratio, generator=generator).to(device="cuda", dtype=torch.float16)
+        noise = torch.randn(num_images, 4, height // self.vae_downsample_ratio, width // self.vae_downsample_ratio, generator=generator)
 
         prompt_inputs = self._encode_prompt(prompt)
 
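This is the other half of the noise change: the latents were already created on the CPU with the seeded generator, and the .to(device="cuda", dtype=torch.float16) transfer that used to happen here is now deferred into sample(), where a CUDA device is guaranteed to be attached on ZeroGPU. A sketch of the split, with an illustrative shape:

import torch

generator = torch.Generator().manual_seed(0)   # CPU generator
# created on CPU: no GPU needs to be attached at this point
noise = torch.randn(1, 4, 128, 128, generator=generator)

# later, inside the spaces.GPU-decorated sample(), once a device exists:
# noise = noise.to(device="cuda", dtype=torch.float16)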
@@ -143,7 +145,7 @@ class ModelWrapper:
         }
 
 
-        print(f'noise: {noise.dtype}')
+
         print(f'prompt: {batch_prompt_embeds.dtype}')
         print(unet_added_conditions['time_ids'].dtype)
         print(unet_added_conditions['text_embeds'].dtype)