fffiloni committed on
Commit
cc91cd8
1 Parent(s): 2b29f24

Update app.py

Files changed (1)
  1. app.py +28 -123
app.py CHANGED
@@ -106,65 +106,10 @@ models_b = WurstCoreB.Models(
 )
 models_b.generator.bfloat16().eval().requires_grad_(False)
 
-# Off-load old generator (low VRAM mode)
-if low_vram:
-    models.generator.to("cpu")
-    torch.cuda.empty_cache()
-
-# Load and configure new generator
-generator_rbm = StageCRBM()
-for param_name, param in load_or_fail(core.config.generator_checkpoint_path).items():
-    set_module_tensor_to_device(generator_rbm, param_name, "cpu", value=param)
-
-generator_rbm = generator_rbm.to(getattr(torch, core.config.dtype)).to(device)
-generator_rbm = core.load_model(generator_rbm, 'generator')
-
-# Create models_rbm instance
-models_rbm = core.Models(
-    effnet=models.effnet,
-    text_model=models.text_model,
-    tokenizer=models.tokenizer,
-    generator=generator_rbm,
-    previewer=models.previewer,
-    image_model=models.image_model # Add this line
-)
-
-def unload_models_and_clear_cache():
-    global models_rbm, models_b, sam_model, extras, extras_b
-
-    # Move all models to CPU
-    models_to(models_rbm, device="cpu")
-
-    # Move SAM model components to CPU if they exist
-    if 'sam_model' in globals():
-        models_to(sam_model, device="cpu")
-        models_to(sam_model.sam, device="cpu")
-
-    # Clear CUDA cache
-    torch.cuda.empty_cache()
-    gc.collect()
-
-    # Ensure all models are in eval mode and don't require gradients
-    for model in [models_rbm.generator, models_b.generator]:
-        model.eval()
-        for param in model.parameters():
-            param.requires_grad = False
-
-    # Clear CUDA cache again
-    torch.cuda.empty_cache()
-    gc.collect()
-
-def reset_inference_state():
-    global models_rbm, models_b, extras, extras_b, device, core, core_b
-
-    # Clear CUDA cache
-    torch.cuda.empty_cache()
-    gc.collect()
-
-    models_to(models_rbm, device=device, excepts=["generator", "previewer"])
-
 def infer(ref_style_file, style_description, caption):
-    global models_rbm, models_b
+    global models_rbm, models_b, device
+    if low_vram:
+        models_to(models_rbm, device=device)
     try:
         caption = f"{caption} in {style_description}"
         height=1024
@@ -189,7 +134,7 @@ def infer(ref_style_file, style_description, caption):
         batch = {'captions': [caption] * batch_size}
         batch['style'] = ref_style
 
-        x0_style_forward = models_rbm.effnet(extras.effnet_preprocess(ref_style))
+        x0_style_forward = models_rbm.effnet(extras.effnet_preprocess(ref_style.to(device)))
 
         conditions = core.get_conditions(batch, models_rbm, extras, is_eval=True, is_unconditional=False, eval_image_embeds=True, eval_style=True, eval_csd=False)
         unconditions = core.get_conditions(batch, models_rbm, extras, is_eval=True, is_unconditional=True, eval_image_embeds=False)
@@ -246,85 +191,48 @@ def infer(ref_style_file, style_description, caption):
         return output_file # Return the path to the saved image
 
     finally:
-        # Reset the state after inference, regardless of success or failure
-        reset_inference_state()
-        # Unload models and clear cache after inference
-        # unload_models_and_clear_cache()
-
-def reset_compo_inference_state():
-    global models_rbm, models_b, extras, extras_b, device, core, core_b, sam_model
-
-    # Reset sampling configurations
-    extras.sampling_configs['cfg'] = 4
-    extras.sampling_configs['shift'] = 2
-    extras.sampling_configs['timesteps'] = 20
-    extras.sampling_configs['t_start'] = 1.0
-
-    extras_b.sampling_configs['cfg'] = 1.1
-    extras_b.sampling_configs['shift'] = 1
-    extras_b.sampling_configs['timesteps'] = 10
-    extras_b.sampling_configs['t_start'] = 1.0
-
-    # Move models to CPU to free up GPU memory
-    models_to(models_rbm, device="cpu")
-    models_b.generator.to("cpu")
-
-    # Clear CUDA cache
-    torch.cuda.empty_cache()
-    gc.collect()
-
-    # Move SAM model components to CPU if they exist
-    models_to(sam_model, device="cpu")
-    models_to(sam_model.sam, device="cpu")
-
-    # Clear CUDA cache
-    torch.cuda.empty_cache()
-    gc.collect()
-
-    # Ensure all models are in eval mode and don't require gradients
-    for model in [models_rbm.generator, models_b.generator]:
-        model.eval()
-        for param in model.parameters():
-            param.requires_grad = False
-
-    # Clear CUDA cache again
-    torch.cuda.empty_cache()
-    gc.collect()
+        # Clear CUDA cache
+        torch.cuda.empty_cache()
 
 def infer_compo(style_description, ref_style_file, caption, ref_sub_file):
     global models_rbm, models_b, device, sam_model
+    if low_vram:
+        models_to(models_rbm, device=device)
+        models_to(sam_model, device=device)
+        models_to(sam_model.sam, device=device)
     try:
         caption = f"{caption} in {style_description}"
         sam_prompt = f"{caption}"
        use_sam_mask = False
-
-        # Ensure all models are on the correct device
-        models_to(models_rbm, device)
-        models_b.generator.to(device)
 
         batch_size = 1
         height, width = 1024, 1024
         stage_c_latent_shape, stage_b_latent_shape = calculate_latent_sizes(height, width, batch_size=batch_size)
-
+
+        extras.sampling_configs['cfg'] = 4
+        extras.sampling_configs['shift'] = 2
+        extras.sampling_configs['timesteps'] = 20
+        extras.sampling_configs['t_start'] = 1.0
+        extras_b.sampling_configs['cfg'] = 1.1
+        extras_b.sampling_configs['shift'] = 1
+        extras_b.sampling_configs['timesteps'] = 10
+        extras_b.sampling_configs['t_start'] = 1.0
+
         ref_style = resize_image(PIL.Image.open(ref_style_file).convert("RGB")).unsqueeze(0).expand(batch_size, -1, -1, -1).to(device)
         ref_images = resize_image(PIL.Image.open(ref_sub_file).convert("RGB")).unsqueeze(0).expand(batch_size, -1, -1, -1).to(device)
 
-        batch = {'captions': [caption] * batch_size, 'style': ref_style, 'images': ref_images}
+        batch = {'captions': [caption] * batch_size}
+        batch['style'] = ref_style
+        batch['images'] = ref_images
 
-        x0_forward = models_rbm.effnet(extras.effnet_preprocess(ref_images))
-        x0_style_forward = models_rbm.effnet(extras.effnet_preprocess(ref_style))
+        x0_forward = models_rbm.effnet(extras.effnet_preprocess(ref_images.to(device)))
+        x0_style_forward = models_rbm.effnet(extras.effnet_preprocess(ref_style.to(device)))
 
         ## SAM Mask for sub
         use_sam_mask = False
         x0_preview = models_rbm.previewer(x0_forward)
         sam_model = LangSAM()
-
-        # Move SAM model components to the correct device
-        models_to(sam_model, device)
-        models_to(sam_model.sam, device)
-
-        x0_preview_pil = T.ToPILImage()(x0_preview[0].cpu())
-        sam_mask, boxes, phrases, logits = sam_model.predict(x0_preview_pil, sam_prompt)
+        sam_mask, boxes, phrases, logits = sam_model.predict(transform(x0_preview[0]), sam_prompt)
         sam_mask = sam_mask.detach().unsqueeze(dim=0).to(device)
 
         conditions = core.get_conditions(batch, models_rbm, extras, is_eval=True, is_unconditional=False, eval_image_embeds=True, eval_subject_style=True, eval_csd=False)
@@ -389,11 +297,8 @@ def infer_compo(style_description, ref_style_file, caption, ref_sub_file):
         return output_file # Return the path to the saved image
 
     finally:
-        # Reset the state after inference, regardless of success or failure
-        # reset_compo_inference_state()
-        # reset_inference_state()
-        # Unload models and clear cache after inference
-        unload_models_and_clear_cache()
+        # Clear CUDA cache
+        torch.cuda.empty_cache()
 
 def run(style_reference_image, style_description, subject_prompt, subject_reference, use_subject_ref):
     result = None
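
Note: the new code path relies on a models_to helper and a low_vram flag that are defined earlier in app.py and are not part of this diff. The snippet below is only a minimal sketch, assuming an attribute-walking helper, of the pattern the commit moves into infer() and infer_compo() — place the models on the GPU only while a request runs, then clear the CUDA cache. The helper and the wrapper function are illustrative assumptions, not the repository's actual implementation.

    import torch
    import torch.nn as nn

    def models_to(container, device="cpu", excepts=None):
        # Hypothetical helper (not the actual app.py implementation):
        # move every nn.Module held by a model container to `device`,
        # skipping attributes named in `excepts`.
        excepts = excepts or []
        for name, attr in vars(container).items():
            if name in excepts or not isinstance(attr, nn.Module):
                continue
            attr.to(device)

    def run_low_vram(models_rbm, device, fn):
        # Mirrors the shape of the updated infer(): move the models onto
        # the GPU on entry and always clear the CUDA cache afterwards.
        try:
            models_to(models_rbm, device=device)
            return fn()
        finally:
            torch.cuda.empty_cache()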