fffiloni committed on
Commit 769e433
1 Parent(s): 74284f5

Update app.py

Files changed (1)
  1. app.py +40 -16
app.py CHANGED
@@ -107,28 +107,32 @@ models_b = WurstCoreB.Models(
 )
 models_b.generator.bfloat16().eval().requires_grad_(False)
 
-if low_vram:
-    # Off-load old generator (which is not used in models_rbm)
-    models.generator.to("cpu")
-    torch.cuda.empty_cache()
 
-generator_rbm = StageCRBM()
-for param_name, param in load_or_fail(core.config.generator_checkpoint_path).items():
-    set_module_tensor_to_device(generator_rbm, param_name, "cpu", value=param)
-generator_rbm = generator_rbm.to(getattr(torch, core.config.dtype)).to(device)
-generator_rbm = core.load_model(generator_rbm, 'generator')
-
-models_rbm = core.Models(
-    effnet=models.effnet, previewer=models.previewer,
-    generator=generator_rbm, generator_ema=models.generator_ema,
-    tokenizer=models.tokenizer, text_model=models.text_model, image_model=models.image_model
-)
-models_rbm.generator.eval().requires_grad_(False)
 
 sam_model = LangSAM()
 
 def infer(ref_style_file, style_description, caption, progress):
     global models_rbm, models_b, device
+
+    if low_vram:
+        # Off-load old generator (which is not used in models_rbm)
+        models.generator.to("cpu")
+        torch.cuda.empty_cache()
+        gc.collect()
+
+    generator_rbm = StageCRBM()
+    for param_name, param in load_or_fail(core.config.generator_checkpoint_path).items():
+        set_module_tensor_to_device(generator_rbm, param_name, "cpu", value=param)
+    generator_rbm = generator_rbm.to(getattr(torch, core.config.dtype)).to(device)
+    generator_rbm = core.load_model(generator_rbm, 'generator')
+
+    models_rbm = core.Models(
+        effnet=models.effnet, previewer=models.previewer,
+        generator=generator_rbm, generator_ema=models.generator_ema,
+        tokenizer=models.tokenizer, text_model=models.text_model, image_model=models.image_model
+    )
+    models_rbm.generator.eval().requires_grad_(False)
+
     if low_vram:
         models_to(models_rbm, device=device, excepts=["generator", "previewer"])
     try:
@@ -236,6 +240,26 @@ def infer(ref_style_file, style_description, caption, progress):
 
 def infer_compo(style_description, ref_style_file, caption, ref_sub_file, progress):
     global models_rbm, models_b, device, sam_model
+
+    if low_vram:
+        # Off-load old generator (which is not used in models_rbm)
+        models.generator.to("cpu")
+        torch.cuda.empty_cache()
+        gc.collect()
+
+    generator_rbm = StageCRBM()
+    for param_name, param in load_or_fail(core.config.generator_checkpoint_path).items():
+        set_module_tensor_to_device(generator_rbm, param_name, "cpu", value=param)
+    generator_rbm = generator_rbm.to(getattr(torch, core.config.dtype)).to(device)
+    generator_rbm = core.load_model(generator_rbm, 'generator')
+
+    models_rbm = core.Models(
+        effnet=models.effnet, previewer=models.previewer,
+        generator=generator_rbm, generator_ema=models.generator_ema,
+        tokenizer=models.tokenizer, text_model=models.text_model, image_model=models.image_model
+    )
+    models_rbm.generator.eval().requires_grad_(False)
+
     if low_vram:
         models_to(models_rbm, device=device, excepts=["generator", "previewer"])
         models_to(sam_model, device=device)
 
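For reference, the block this commit adds to both infer() and infer_compo() reads as one unit. Below is a minimal sketch of that sequence factored into a single helper; the name build_models_rbm is hypothetical (the commit inlines the code in each function), and the sketch assumes app.py's existing globals and imports (models, core, device, low_vram, StageCRBM, load_or_fail, gc, torch, and accelerate's set_module_tensor_to_device):

    def build_models_rbm():
        # Hypothetical helper; mirrors the block the commit inlines per function.
        if low_vram:
            # Off-load the original Stage C generator; models_rbm replaces it.
            models.generator.to("cpu")
            torch.cuda.empty_cache()  # release cached CUDA allocator blocks
            gc.collect()              # run Python garbage collection

        # Materialize the RBM generator on CPU tensor-by-tensor, so the
        # checkpoint never occupies GPU memory during loading.
        generator_rbm = StageCRBM()
        for param_name, param in load_or_fail(core.config.generator_checkpoint_path).items():
            set_module_tensor_to_device(generator_rbm, param_name, "cpu", value=param)

        # Cast to the configured dtype, then move the assembled module to the GPU.
        generator_rbm = generator_rbm.to(getattr(torch, core.config.dtype)).to(device)
        generator_rbm = core.load_model(generator_rbm, 'generator')

        # Reuse the untouched sub-models; only the generator is new.
        models_rbm = core.Models(
            effnet=models.effnet, previewer=models.previewer,
            generator=generator_rbm, generator_ema=models.generator_ema,
            tokenizer=models.tokenizer, text_model=models.text_model, image_model=models.image_model
        )
        models_rbm.generator.eval().requires_grad_(False)
        return models_rbm

Rebuilding models_rbm at the start of each call, rather than once at import time as before, trades some per-request latency for a smaller steady-state VRAM footprint, which appears to be the intent of the low_vram path.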