vilarin commited on
Commit
36d0a3b
·
verified ·
1 Parent(s): c2b3f2d

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +4 -5
app.py CHANGED
@@ -21,7 +21,7 @@ class ModelWrapper:
21
  torch.set_grad_enabled(False)
22
 
23
  self.DTYPE = torch.float16
24
- self.device = accelerator.device
25
 
26
  self.tokenizer_one = AutoTokenizer.from_pretrained(model_id, subfolder="tokenizer", revision=revision, use_fast=False)
27
  self.tokenizer_two = AutoTokenizer.from_pretrained(model_id, subfolder="tokenizer", revision=revision, use_fast=False)
@@ -60,7 +60,7 @@ class ModelWrapper:
60
  crop_top_left = (0, 0)
61
 
62
  add_time_ids = list(original_size + crop_top_left + target_size)
63
- add_time_ids = torch.tensor([add_time_ids], device=self.device, dtype=self.DTYPE)
64
  return add_time_ids
65
 
66
  def _encode_prompt(self, prompt):
@@ -92,7 +92,7 @@ class ModelWrapper:
92
  DTYPE = prompt_embed.dtype
93
 
94
  for constant in all_timesteps:
95
- current_timesteps = torch.ones(len(prompt_embed), device=self.device, dtype=torch.long) * constant
96
  eval_images = self.model(noise, current_timesteps, prompt_embed, added_cond_kwargs=unet_added_conditions).sample
97
 
98
  eval_images = get_x0_from_noise(noise, eval_images, alphas_cumprod, current_timesteps).to(self.DTYPE)
@@ -120,7 +120,7 @@ class ModelWrapper:
120
 
121
  add_time_ids = self.build_condition_input(height, width).repeat(num_images, 1)
122
 
123
- noise = torch.randn(num_images, 4, height // self.vae_downsample_ratio, width // self.vae_downsample_ratio, generator=generator).to(device=self.device, dtype=self.DTYPE)
124
 
125
  prompt_inputs = self._encode_prompt(prompt)
126
 
@@ -148,7 +148,6 @@ class ModelWrapper:
148
 
149
  return output_image_list, f"Run successfully in {(end_time-start_time):.2f} seconds"
150
 
151
- @spaces.GPU()
152
  def get_x0_from_noise(sample, model_output, alphas_cumprod, timestep):
153
  alpha_prod_t = alphas_cumprod[timestep].reshape(-1, 1, 1, 1)
154
  beta_prod_t = 1 - alpha_prod_t
 
21
  torch.set_grad_enabled(False)
22
 
23
  self.DTYPE = torch.float16
24
+ self.device = 0
25
 
26
  self.tokenizer_one = AutoTokenizer.from_pretrained(model_id, subfolder="tokenizer", revision=revision, use_fast=False)
27
  self.tokenizer_two = AutoTokenizer.from_pretrained(model_id, subfolder="tokenizer", revision=revision, use_fast=False)
 
60
  crop_top_left = (0, 0)
61
 
62
  add_time_ids = list(original_size + crop_top_left + target_size)
63
+ add_time_ids = torch.tensor([add_time_ids], device="cuda", dtype=self.DTYPE)
64
  return add_time_ids
65
 
66
  def _encode_prompt(self, prompt):
 
92
  DTYPE = prompt_embed.dtype
93
 
94
  for constant in all_timesteps:
95
+ current_timesteps = torch.ones(len(prompt_embed), device="cuda", dtype=torch.long) * constant
96
  eval_images = self.model(noise, current_timesteps, prompt_embed, added_cond_kwargs=unet_added_conditions).sample
97
 
98
  eval_images = get_x0_from_noise(noise, eval_images, alphas_cumprod, current_timesteps).to(self.DTYPE)
 
120
 
121
  add_time_ids = self.build_condition_input(height, width).repeat(num_images, 1)
122
 
123
+ noise = torch.randn(num_images, 4, height // self.vae_downsample_ratio, width // self.vae_downsample_ratio, generator=generator).to(device="cuda", dtype=self.DTYPE)
124
 
125
  prompt_inputs = self._encode_prompt(prompt)
126
 
 
148
 
149
  return output_image_list, f"Run successfully in {(end_time-start_time):.2f} seconds"
150
 
 
151
  def get_x0_from_noise(sample, model_output, alphas_cumprod, timestep):
152
  alpha_prod_t = alphas_cumprod[timestep].reshape(-1, 1, 1, 1)
153
  beta_prod_t = 1 - alpha_prod_t