fffiloni committed
Commit 8fab8d2
1 Parent(s): fac5f0b

Update app.py

Files changed (1)
  1. app.py +102 -93
app.py CHANGED
@@ -149,103 +149,112 @@ models_rbm = core.Models(
 models_rbm.generator.eval().requires_grad_(False)
 
 def infer(style_description, ref_style_file, caption):
-    try:
-        # Ensure all models are moved back to the correct device
-        models_rbm.generator.to(device)
-        models_b.generator.to(device)
-
-        clear_gpu_cache() # Clear cache before inference
-
-        height = 1024
-        width = 1024
-        batch_size = 1
-        output_file = 'output.png'
-
-        stage_c_latent_shape, stage_b_latent_shape = calculate_latent_sizes(height, width, batch_size=batch_size)
-
-        extras.sampling_configs['cfg'] = 4
-        extras.sampling_configs['shift'] = 2
-        extras.sampling_configs['timesteps'] = 20
-        extras.sampling_configs['t_start'] = 1.0
-
-        extras_b.sampling_configs['cfg'] = 1.1
-        extras_b.sampling_configs['shift'] = 1
-        extras_b.sampling_configs['timesteps'] = 10
-        extras_b.sampling_configs['t_start'] = 1.0
-
-        ref_style = resize_image(PIL.Image.open(ref_style_file).convert("RGB")).unsqueeze(0).expand(batch_size, -1, -1, -1).to(device)
-
-        batch = {'captions': [caption] * batch_size}
-        batch['style'] = ref_style
-
-        x0_style_forward = models_rbm.effnet(extras.effnet_preprocess(ref_style.to(device)))
-
-        conditions = core.get_conditions(batch, models_rbm, extras, is_eval=True, is_unconditional=False, eval_image_embeds=True, eval_style=True, eval_csd=False)
-        unconditions = core.get_conditions(batch, models_rbm, extras, is_eval=True, is_unconditional=True, eval_image_embeds=False)
-        conditions_b = core_b.get_conditions(batch, models_b, extras_b, is_eval=True, is_unconditional=False)
-        unconditions_b = core_b.get_conditions(batch, models_b, extras_b, is_eval=True, is_unconditional=True)
-
-        if low_vram:
-            # Offload non-essential models to CPU for memory savings
-            models_to(models_rbm, device="cpu", excepts=["generator", "previewer"])
-
-        # Stage C reverse process
-        with torch.cuda.amp.autocast(): # Use mixed precision
-            sampling_c = extras.gdf.sample(
-                models_rbm.generator, conditions, stage_c_latent_shape,
-                unconditions, device=device,
-                **extras.sampling_configs,
-                x0_style_forward=x0_style_forward,
-                apply_pushforward=False, tau_pushforward=8,
-                num_iter=3, eta=0.1, tau=20, eval_csd=True,
-                extras=extras, models=models_rbm,
-                lam_style=1, lam_txt_alignment=1.0,
-                use_ddim_sampler=True,
-            )
-            for (sampled_c, _, _) in tqdm(sampling_c, total=extras.sampling_configs['timesteps']):
-                sampled_c = sampled_c
-
-        clear_gpu_cache() # Clear cache between stages
-
-        # Ensure all models are on the right device again
-        models_b.generator.to(device)
-
-        # Stage B reverse process
-        with torch.no_grad(), torch.cuda.amp.autocast(dtype=torch.bfloat16):
-            conditions_b['effnet'] = sampled_c
-            unconditions_b['effnet'] = torch.zeros_like(sampled_c)
-
-            sampling_b = extras_b.gdf.sample(
-                models_b.generator, conditions_b, stage_b_latent_shape,
-                unconditions_b, device=device, **extras_b.sampling_configs,
-            )
-            for (sampled_b, _, _) in tqdm(sampling_b, total=extras_b.sampling_configs['timesteps']):
-                sampled_b = sampled_b
-            sampled = models_b.stage_a.decode(sampled_b).float()
-
-        # Post-process and save the image
-        sampled = sampled.cpu() # Move to CPU before processing
-
-        # Ensure the tensor is in [C, H, W] format
-        if sampled.dim() == 4 and sampled.size(0) == 1:
-            sampled = sampled.squeeze(0)
-
-        if sampled.dim() == 3 and sampled.shape[0] == 3:
-            sampled_image = T.ToPILImage()(sampled) # Convert tensor to PIL image
-            sampled_image.save(output_file) # Save the image as a PNG
-        else:
-            raise ValueError(f"Expected tensor of shape [3, H, W] but got {sampled.shape}")
-
-    except Exception as e:
-        print(f"An error occurred during inference: {str(e)}")
-        return None
-
-    finally:
-        clear_gpu_cache() # Always clear cache after inference
+    # Move models to the correct device
+    models_rbm.effnet.to(device)
+    models_rbm.generator.to(device)
+    if low_vram:
+        models_rbm.previewer.to(device)
+
+    # Also, revalidate data types and devices for key tensors
+    def check_and_move(tensor):
+        if tensor is not None and tensor.device != device:
+            return tensor.to(device)
+        return tensor
+
+    clear_gpu_cache() # Clear cache before inference
+
+    height = 1024
+    width = 1024
+    batch_size = 1
+    output_file = 'output.png'
+
+    stage_c_latent_shape, stage_b_latent_shape = calculate_latent_sizes(height, width, batch_size=batch_size)
+
+    extras.sampling_configs['cfg'] = 4
+    extras.sampling_configs['shift'] = 2
+    extras.sampling_configs['timesteps'] = 20
+    extras.sampling_configs['t_start'] = 1.0
+
+    extras_b.sampling_configs['cfg'] = 1.1
+    extras_b.sampling_configs['shift'] = 1
+    extras_b.sampling_configs['timesteps'] = 10
+    extras_b.sampling_configs['t_start'] = 1.0
+
+    # Load and preprocess the reference style image
+    ref_style = resize_image(PIL.Image.open(ref_style_file).convert("RGB"))
+    ref_style = ref_style.unsqueeze(0).expand(batch_size, -1, -1, -1).to(device)
+
+    batch = {'captions': [caption] * batch_size}
+    batch['style'] = ref_style
+
+    x0_style_forward = models_rbm.effnet(extras.effnet_preprocess(ref_style.to(device)))
+
+    conditions = core.get_conditions(batch, models_rbm, extras, is_eval=True, is_unconditional=False, eval_image_embeds=True, eval_style=True, eval_csd=False)
+    unconditions = core.get_conditions(batch, models_rbm, extras, is_eval=True, is_unconditional=True, eval_image_embeds=False)
+    conditions_b = core_b.get_conditions(batch, models_b, extras_b, is_eval=True, is_unconditional=False)
+    unconditions_b = core_b.get_conditions(batch, models_b, extras_b, is_eval=True, is_unconditional=True)
+
+    if low_vram:
+        # Offload non-essential models to CPU for memory savings
+        models_to(models_rbm, device="cpu", excepts=["generator", "previewer"])
+
+    # Stage C reverse process
+    with torch.cuda.amp.autocast():
+        sampling_c = extras.gdf.sample(
+            models_rbm.generator, conditions, stage_c_latent_shape,
+            unconditions, device=device,
+            **extras.sampling_configs,
+            x0_style_forward=x0_style_forward,
+            apply_pushforward=False, tau_pushforward=8,
+            num_iter=3, eta=0.1, tau=20, eval_csd=True,
+            extras=extras, models=models_rbm,
+            lam_style=1, lam_txt_alignment=1.0,
+            use_ddim_sampler=True,
+        )
+        for (sampled_c, _, _) in tqdm(sampling_c, total=extras.sampling_configs['timesteps']):
+            sampled_c = sampled_c
+
+    clear_gpu_cache() # Clear cache between stages
+
+    # Ensure all models are on the right device again
+    models_b.generator.to(device)
+    models_b.stage_a.to(device)
+
+    # Stage B reverse process
+    with torch.no_grad(), torch.cuda.amp.autocast(dtype=torch.bfloat16):
+        conditions_b['effnet'] = sampled_c.to(device)
+        unconditions_b['effnet'] = torch.zeros_like(sampled_c).to(device)
+
+        sampling_b = extras_b.gdf.sample(
+            models_b.generator, conditions_b, stage_b_latent_shape,
+            unconditions_b, device=device, **extras_b.sampling_configs,
+        )
+        for (sampled_b, _, _) in tqdm(sampling_b, total=extras_b.sampling_configs['timesteps']):
+            sampled_b = sampled_b
+        sampled = models_b.stage_a.decode(sampled_b).float().to(device)
+
+    # Post-process and save the image
+    sampled = torch.cat([
+        torch.nn.functional.interpolate(ref_style.cpu(), size=(height, width)),
+        sampled.cpu(),
+    ], dim=0)
+
+    # Remove the batch dimension and keep only the generated image
+    sampled = sampled[1] # This selects the generated image, discarding the reference style image
+
+    # Ensure the tensor is in [C, H, W] format
+    if sampled.dim() == 3 and sampled.shape[0] == 3:
+        sampled_image = T.ToPILImage()(sampled) # Convert tensor to PIL image
+        sampled_image.save(output_file) # Save the image as a PNG
+    else:
+        raise ValueError(f"Expected tensor of shape [3, H, W] but got {sampled.shape}")
+
+    clear_gpu_cache() # Clear cache after inference
 
     return output_file # Return the path to the saved image
 
 
+
 import gradio as gr
 
 gr.Interface(
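
Note for readers without the full file: clear_gpu_cache() and models_to() are helpers defined elsewhere in app.py and are not touched by this commit. A minimal sketch of what such helpers typically look like, assuming they wrap gc.collect()/torch.cuda.empty_cache() and a simple .to(device) loop over the models container (the repo's actual implementations may differ):

import gc
import torch

def clear_gpu_cache():
    # Drop unreachable Python objects, then release cached CUDA memory.
    gc.collect()
    if torch.cuda.is_available():
        torch.cuda.empty_cache()

def models_to(models, device="cpu", excepts=None):
    # Move every nn.Module held by the models container to `device`,
    # skipping attributes listed in `excepts` (e.g. "generator", "previewer").
    excepts = excepts or []
    for name, module in vars(models).items():
        if name in excepts or not isinstance(module, torch.nn.Module):
            continue
        module.to(device)

For a quick smoke test of the updated function outside the Gradio UI (the reference-image path and caption below are placeholders):

out_path = infer(
    style_description="",            # currently unused inside infer()
    ref_style_file="style_ref.png",  # hypothetical local reference image
    caption="an astronaut riding a horse",
)
print(out_path)  # 'output.png'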