fffiloni commited on
Commit
eb756c3
1 Parent(s): 769e433

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +17 -38
app.py CHANGED
@@ -107,31 +107,29 @@ models_b = WurstCoreB.Models(
107
  )
108
  models_b.generator.bfloat16().eval().requires_grad_(False)
109
 
 
 
 
 
 
110
 
 
 
 
 
 
 
 
 
 
 
 
 
111
 
112
  sam_model = LangSAM()
113
 
114
  def infer(ref_style_file, style_description, caption, progress):
115
  global models_rbm, models_b, device
116
-
117
- if low_vram:
118
- # Off-load old generator (which is not used in models_rbm)
119
- models.generator.to("cpu")
120
- torch.cuda.empty_cache()
121
- gc.collect()
122
-
123
- generator_rbm = StageCRBM()
124
- for param_name, param in load_or_fail(core.config.generator_checkpoint_path).items():
125
- set_module_tensor_to_device(generator_rbm, param_name, "cpu", value=param)
126
- generator_rbm = generator_rbm.to(getattr(torch, core.config.dtype)).to(device)
127
- generator_rbm = core.load_model(generator_rbm, 'generator')
128
-
129
- models_rbm = core.Models(
130
- effnet=models.effnet, previewer=models.previewer,
131
- generator=generator_rbm, generator_ema=models.generator_ema,
132
- tokenizer=models.tokenizer, text_model=models.text_model, image_model=models.image_model
133
- )
134
- models_rbm.generator.eval().requires_grad_(False)
135
 
136
  if low_vram:
137
  models_to(models_rbm, device=device, excepts=["generator", "previewer"])
@@ -240,25 +238,6 @@ def infer(ref_style_file, style_description, caption, progress):
240
 
241
  def infer_compo(style_description, ref_style_file, caption, ref_sub_file, progress):
242
  global models_rbm, models_b, device, sam_model
243
-
244
- if low_vram:
245
- # Off-load old generator (which is not used in models_rbm)
246
- models.generator.to("cpu")
247
- torch.cuda.empty_cache()
248
- gc.collect()
249
-
250
- generator_rbm = StageCRBM()
251
- for param_name, param in load_or_fail(core.config.generator_checkpoint_path).items():
252
- set_module_tensor_to_device(generator_rbm, param_name, "cpu", value=param)
253
- generator_rbm = generator_rbm.to(getattr(torch, core.config.dtype)).to(device)
254
- generator_rbm = core.load_model(generator_rbm, 'generator')
255
-
256
- models_rbm = core.Models(
257
- effnet=models.effnet, previewer=models.previewer,
258
- generator=generator_rbm, generator_ema=models.generator_ema,
259
- tokenizer=models.tokenizer, text_model=models.text_model, image_model=models.image_model
260
- )
261
- models_rbm.generator.eval().requires_grad_(False)
262
 
263
  if low_vram:
264
  models_to(models_rbm, device=device, excepts=["generator", "previewer"])
 
107
  )
108
  models_b.generator.bfloat16().eval().requires_grad_(False)
109
 
110
+ if low_vram:
111
+ # Off-load old generator (which is not used in models_rbm)
112
+ models.generator.to("cpu")
113
+ torch.cuda.empty_cache()
114
+ gc.collect()
115
 
116
+ generator_rbm = StageCRBM()
117
+ for param_name, param in load_or_fail(core.config.generator_checkpoint_path).items():
118
+ set_module_tensor_to_device(generator_rbm, param_name, "cpu", value=param)
119
+ generator_rbm = generator_rbm.to(getattr(torch, core.config.dtype)).to(device)
120
+ generator_rbm = core.load_model(generator_rbm, 'generator')
121
+
122
+ models_rbm = core.Models(
123
+ effnet=models.effnet, previewer=models.previewer,
124
+ generator=generator_rbm, generator_ema=models.generator_ema,
125
+ tokenizer=models.tokenizer, text_model=models.text_model, image_model=models.image_model
126
+ )
127
+ models_rbm.generator.eval().requires_grad_(False)
128
 
129
  sam_model = LangSAM()
130
 
131
  def infer(ref_style_file, style_description, caption, progress):
132
  global models_rbm, models_b, device
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
133
 
134
  if low_vram:
135
  models_to(models_rbm, device=device, excepts=["generator", "previewer"])
 
238
 
239
  def infer_compo(style_description, ref_style_file, caption, ref_sub_file, progress):
240
  global models_rbm, models_b, device, sam_model
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
241
 
242
  if low_vram:
243
  models_to(models_rbm, device=device, excepts=["generator", "previewer"])