vilarin committed (verified)
Commit f03332f · 1 Parent(s): 0cffd40

Update app.py

Files changed (1):
  1. app.py +7 -9
app.py CHANGED

@@ -46,10 +46,11 @@ class ModelWrapper:
         self.scheduler = DDIMScheduler.from_pretrained(model_id, subfolder="scheduler")
         self.alphas_cumprod = self.scheduler.alphas_cumprod.to(self.device)
         self.num_step = num_step
-
+
+    @spaces.GPU(enable_queue=True)
     def create_generator(self, model_id, checkpoint_path):
         generator = UNet2DConditionModel.from_pretrained(model_id, subfolder="unet").to(self.DTYPE)
-        state_dict = torch.load(checkpoint_path, map_location="cpu")
+        state_dict = torch.load(checkpoint_path, map_location="cuda")
         generator.load_state_dict(state_dict, strict=True)
         generator.requires_grad_(False)
         return generator
@@ -108,7 +109,7 @@ class ModelWrapper:
         eval_images = ((eval_images + 1.0) * 127.5).clamp(0, 255).to(torch.uint8).permute(0, 2, 3, 1)
         return eval_images
 
-    @spaces.GPU(enable_queue=True)
+
     @torch.no_grad()
     def inference(self, prompt, seed, height, width, num_images, fast_vae_decode):
         print("Running model inference...")
@@ -196,9 +197,6 @@ def create_demo():
     num_step = 4
     revision = None
 
-    torch.backends.cuda.matmul.allow_tf32 = True
-    torch.backends.cudnn.allow_tf32 = True
-
     accelerator = Accelerator()
 
     model = ModelWrapper(model_id, checkpoint_path, precision, image_resolution, latent_resolution, num_train_timesteps, conditioning_timestep, num_step, revision, accelerator)
@@ -211,10 +209,10 @@ def create_demo():
             run_button = gr.Button("Run")
             with gr.Accordion(label="Advanced options", open=True):
                 seed = gr.Slider(label="Seed", minimum=-1, maximum=1000000, step=1, value=0)
-                num_images = gr.Slider(label="Number of generated images", minimum=1, maximum=16, step=1, value=16)
+                num_images = gr.Slider(label="Number of generated images", minimum=1, maximum=16, step=1, value=1)
                 fast_vae_decode = gr.Checkbox(label="Use Tiny VAE for faster decoding", value=True)
-                height = gr.Slider(label="Image Height", minimum=512, maximum=1536, step=64, value=1024)
-                width = gr.Slider(label="Image Width", minimum=512, maximum=1536, step=64, value=1024)
+                height = gr.Slider(label="Image Height", minimum=512, maximum=1536, step=64, value=512)
+                width = gr.Slider(label="Image Width", minimum=512, maximum=1536, step=64, value=512)
         with gr.Column():
             result = gr.Gallery(label="Generated Images", show_label=False, elem_id="gallery", height=1024)
             error_message = gr.Text(label="Job Status")
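
For context, the loader in its new form is sketched below: the ZeroGPU decorator now sits on the checkpoint-loading path and the state dict is mapped straight onto the GPU. This is a minimal sketch assuming the Hugging Face spaces package and diffusers; MODEL_ID, CHECKPOINT_PATH, the float16 dtype, and the standalone-function form (app.py uses a ModelWrapper method) are illustrative assumptions, not part of the commit.

# Sketch of the new create_generator path after this commit. The decorator and
# torch.load(..., map_location="cuda") come from the diff; the constants and
# dtype below are placeholders.
import spaces
import torch
from diffusers import UNet2DConditionModel

MODEL_ID = "stabilityai/stable-diffusion-xl-base-1.0"  # placeholder model id
CHECKPOINT_PATH = "checkpoint.pt"                      # placeholder path

@spaces.GPU(enable_queue=True)  # GPU is allocated while this function runs
def create_generator(model_id: str = MODEL_ID, checkpoint_path: str = CHECKPOINT_PATH):
    generator = UNet2DConditionModel.from_pretrained(model_id, subfolder="unet").to(torch.float16)
    # Load the distilled checkpoint directly onto the allocated GPU
    state_dict = torch.load(checkpoint_path, map_location="cuda")
    generator.load_state_dict(state_dict, strict=True)
    generator.requires_grad_(False)  # inference only; freeze the weights
    return generator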
 
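The UI side of the change can be read as a small Gradio Blocks sketch with the new defaults (1 image at 512x512 instead of 16 at 1024x1024). The widget definitions are taken from the diff; the Row/Column nesting, the omitted prompt textbox, and the omitted run_button.click wiring are assumptions for illustration.

# Hedged sketch of the "Advanced options" block with the defaults from this
# commit; prompt input and event wiring are left out.
import gradio as gr

with gr.Blocks() as demo:
    with gr.Row():
        with gr.Column():
            run_button = gr.Button("Run")
            with gr.Accordion(label="Advanced options", open=True):
                seed = gr.Slider(label="Seed", minimum=-1, maximum=1000000, step=1, value=0)
                num_images = gr.Slider(label="Number of generated images", minimum=1, maximum=16, step=1, value=1)
                fast_vae_decode = gr.Checkbox(label="Use Tiny VAE for faster decoding", value=True)
                height = gr.Slider(label="Image Height", minimum=512, maximum=1536, step=64, value=512)
                width = gr.Slider(label="Image Width", minimum=512, maximum=1536, step=64, value=512)
        with gr.Column():
            result = gr.Gallery(label="Generated Images", show_label=False, elem_id="gallery", height=1024)
            error_message = gr.Text(label="Job Status")

demo.launch()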