fffiloni commited on
Commit
52cb438
1 Parent(s): b6f23ea

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +43 -32
app.py CHANGED
@@ -37,25 +37,24 @@ print(device)
37
  low_vram = True
38
 
39
  # Function definition for low VRAM usage
40
- if low_vram:
41
- def models_to(model, device="cpu", excepts=None):
42
- """
43
- Change the device of nn.Modules within a class, skipping specified attributes.
44
- """
45
- for attr_name in dir(model):
46
- if attr_name.startswith('__') and attr_name.endswith('__'):
47
- continue # skip special attributes
48
-
49
- attr_value = getattr(model, attr_name, None)
50
-
51
- if isinstance(attr_value, torch.nn.Module):
52
- if excepts and attr_name in excepts:
53
- print(f"Except '{attr_name}'")
54
- continue
55
- print(f"Change device of '{attr_name}' to {device}")
56
- attr_value.to(device)
57
-
58
- torch.cuda.empty_cache()
59
 
60
  # Stage C model configuration
61
  config_file = 'third_party/StableCascade/configs/inference/stage_c_3b.yaml'
@@ -128,9 +127,10 @@ models_rbm = core.Models(
128
  generator_ema=models.generator_ema,
129
  tokenizer=models.tokenizer,
130
  text_model=models.text_model,
131
- image_model=models.image_model
 
 
132
  )
133
- models_rbm.generator.eval().requires_grad_(False)
134
 
135
  def reset_inference_state():
136
  global models_rbm, models_b, extras, extras_b, device, core, core_b
@@ -146,29 +146,32 @@ def reset_inference_state():
146
  extras_b.sampling_configs['timesteps'] = 10
147
  extras_b.sampling_configs['t_start'] = 1.0
148
 
149
- # Reset models
150
- models_rbm = core.setup_models(extras)
151
- models_b = core_b.setup_models(extras_b, skip_clip=True)
152
- models_b = WurstCoreB.Models(
153
- **{**models_b.to_dict(), 'tokenizer': models_rbm.tokenizer, 'text_model': models_rbm.text_model}
154
- )
 
155
 
156
- # Move models to the correct device
157
  if low_vram:
158
  models_to(models_rbm, device="cpu", excepts=["generator", "previewer"])
159
- models_b.generator.to("cpu")
 
160
  else:
161
  models_to(models_rbm, device=device)
162
- models_b.generator.to(device)
 
163
 
164
  # Ensure effnet is on the correct device
165
  models_rbm.effnet.to(device)
166
 
167
- # Set models to eval mode and disable gradients
168
  models_rbm.generator.eval().requires_grad_(False)
169
  models_b.generator.bfloat16().eval().requires_grad_(False)
170
 
171
- # Clear CUDA cache
172
  torch.cuda.empty_cache()
173
  gc.collect()
174
 
@@ -197,6 +200,14 @@ def infer(style_description, ref_style_file, caption):
197
  batch = {'captions': [caption] * batch_size}
198
  batch['style'] = ref_style
199
 
 
 
 
 
 
 
 
 
200
  # Ensure effnet is on the correct device
201
  models_rbm.effnet.to(device)
202
  x0_style_forward = models_rbm.effnet(extras.effnet_preprocess(ref_style))
 
37
  low_vram = True
38
 
39
  # Function definition for low VRAM usage
40
+ def models_to(model, device="cpu", excepts=None):
41
+ """
42
+ Change the device of nn.Modules within a class, skipping specified attributes.
43
+ """
44
+ for attr_name in dir(model):
45
+ if attr_name.startswith('__') and attr_name.endswith('__'):
46
+ continue # skip special attributes
47
+
48
+ attr_value = getattr(model, attr_name, None)
49
+
50
+ if isinstance(attr_value, torch.nn.Module):
51
+ if excepts and attr_name in excepts:
52
+ print(f"Except '{attr_name}'")
53
+ continue
54
+ print(f"Change device of '{attr_name}' to {device}")
55
+ attr_value.to(device)
56
+
57
+ torch.cuda.empty_cache()
 
58
 
59
  # Stage C model configuration
60
  config_file = 'third_party/StableCascade/configs/inference/stage_c_3b.yaml'
 
127
  generator_ema=models.generator_ema,
128
  tokenizer=models.tokenizer,
129
  text_model=models.text_model,
130
+ image_model=models.image_model,
131
+ stage_a=models.stage_a,
132
+ stage_b=models.stage_b,
133
  )
 
134
 
135
  def reset_inference_state():
136
  global models_rbm, models_b, extras, extras_b, device, core, core_b
 
146
  extras_b.sampling_configs['timesteps'] = 10
147
  extras_b.sampling_configs['t_start'] = 1.0
148
 
149
+ # Move models to CPU to free up GPU memory
150
+ models_to(models_rbm, device="cpu")
151
+ models_b.generator.to("cpu")
152
+
153
+ # Clear CUDA cache
154
+ torch.cuda.empty_cache()
155
+ gc.collect()
156
 
157
+ # Move necessary models back to the correct device
158
  if low_vram:
159
  models_to(models_rbm, device="cpu", excepts=["generator", "previewer"])
160
+ models_rbm.generator.to(device)
161
+ models_rbm.previewer.to(device)
162
  else:
163
  models_to(models_rbm, device=device)
164
+
165
+ models_b.generator.to("cpu") # Keep Stage B generator on CPU for now
166
 
167
  # Ensure effnet is on the correct device
168
  models_rbm.effnet.to(device)
169
 
170
+ # Reset model states
171
  models_rbm.generator.eval().requires_grad_(False)
172
  models_b.generator.bfloat16().eval().requires_grad_(False)
173
 
174
+ # Clear CUDA cache again
175
  torch.cuda.empty_cache()
176
  gc.collect()
177
 
 
200
  batch = {'captions': [caption] * batch_size}
201
  batch['style'] = ref_style
202
 
203
+ # Ensure models are on the correct device before inference
204
+ if low_vram:
205
+ models_to(models_rbm, device=device, excepts=["generator", "previewer"])
206
+ else:
207
+ models_to(models_rbm, device=device)
208
+
209
+ models_b.generator.to(device)
210
+
211
  # Ensure effnet is on the correct device
212
  models_rbm.effnet.to(device)
213
  x0_style_forward = models_rbm.effnet(extras.effnet_preprocess(ref_style))