Spaces:
Running
on
A100
Running
on
A100
Update app.py
Browse files
app.py
CHANGED
@@ -37,25 +37,24 @@ print(device)
|
|
37 |
low_vram = True
|
38 |
|
39 |
# Function definition for low VRAM usage
|
40 |
-
|
41 |
-
|
42 |
-
|
43 |
-
|
44 |
-
|
45 |
-
|
46 |
-
|
47 |
-
|
48 |
-
|
49 |
-
|
50 |
-
|
51 |
-
if
|
52 |
-
|
53 |
-
|
54 |
-
|
55 |
-
|
56 |
-
|
57 |
-
|
58 |
-
torch.cuda.empty_cache()
|
59 |
|
60 |
# Stage C model configuration
|
61 |
config_file = 'third_party/StableCascade/configs/inference/stage_c_3b.yaml'
|
@@ -128,9 +127,10 @@ models_rbm = core.Models(
|
|
128 |
generator_ema=models.generator_ema,
|
129 |
tokenizer=models.tokenizer,
|
130 |
text_model=models.text_model,
|
131 |
-
image_model=models.image_model
|
|
|
|
|
132 |
)
|
133 |
-
models_rbm.generator.eval().requires_grad_(False)
|
134 |
|
135 |
def reset_inference_state():
|
136 |
global models_rbm, models_b, extras, extras_b, device, core, core_b
|
@@ -146,29 +146,32 @@ def reset_inference_state():
|
|
146 |
extras_b.sampling_configs['timesteps'] = 10
|
147 |
extras_b.sampling_configs['t_start'] = 1.0
|
148 |
|
149 |
-
#
|
150 |
-
models_rbm =
|
151 |
-
models_b
|
152 |
-
|
153 |
-
|
154 |
-
)
|
|
|
155 |
|
156 |
-
# Move models to the correct device
|
157 |
if low_vram:
|
158 |
models_to(models_rbm, device="cpu", excepts=["generator", "previewer"])
|
159 |
-
|
|
|
160 |
else:
|
161 |
models_to(models_rbm, device=device)
|
162 |
-
|
|
|
163 |
|
164 |
# Ensure effnet is on the correct device
|
165 |
models_rbm.effnet.to(device)
|
166 |
|
167 |
-
#
|
168 |
models_rbm.generator.eval().requires_grad_(False)
|
169 |
models_b.generator.bfloat16().eval().requires_grad_(False)
|
170 |
|
171 |
-
# Clear CUDA cache
|
172 |
torch.cuda.empty_cache()
|
173 |
gc.collect()
|
174 |
|
@@ -197,6 +200,14 @@ def infer(style_description, ref_style_file, caption):
|
|
197 |
batch = {'captions': [caption] * batch_size}
|
198 |
batch['style'] = ref_style
|
199 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
200 |
# Ensure effnet is on the correct device
|
201 |
models_rbm.effnet.to(device)
|
202 |
x0_style_forward = models_rbm.effnet(extras.effnet_preprocess(ref_style))
|
|
|
37 |
low_vram = True
|
38 |
|
39 |
# Function definition for low VRAM usage
|
40 |
+
def models_to(model, device="cpu", excepts=None):
    """
    Change the device of nn.Modules within a class, skipping specified attributes.

    Every non-dunder attribute of ``model`` is inspected; each one that is a
    ``torch.nn.Module`` has ``.to(device)`` called on it unless its attribute
    name is listed in ``excepts``.

    Args:
        model: Container object whose attributes are scanned.
        device: Target handed to ``Module.to`` for every module found.
        excepts: A single attribute name or a collection of attribute names
            to leave untouched. ``None`` or an empty collection skips nothing.
    """
    if excepts is None:
        skipped = frozenset()
    elif isinstance(excepts, str):
        # A bare string would be substring-matched by ``in`` (so
        # "generator" would match excepts="generator_ema"); treat it as
        # exactly one attribute name instead.
        skipped = frozenset((excepts,))
    else:
        skipped = frozenset(excepts)

    for attr_name in dir(model):
        if attr_name.startswith('__') and attr_name.endswith('__'):
            continue  # skip special attributes

        attr_value = getattr(model, attr_name, None)
        if not isinstance(attr_value, torch.nn.Module):
            continue

        if attr_name in skipped:
            print(f"Except '{attr_name}'")
            continue

        print(f"Change device of '{attr_name}' to {device}")
        attr_value.to(device)

    # Called once after all modules have been processed.
    torch.cuda.empty_cache()
|
|
|
58 |
|
59 |
# Stage C model configuration
|
60 |
config_file = 'third_party/StableCascade/configs/inference/stage_c_3b.yaml'
|
|
|
127 |
generator_ema=models.generator_ema,
|
128 |
tokenizer=models.tokenizer,
|
129 |
text_model=models.text_model,
|
130 |
+
image_model=models.image_model,
|
131 |
+
stage_a=models.stage_a,
|
132 |
+
stage_b=models.stage_b,
|
133 |
)
|
|
|
134 |
|
135 |
def reset_inference_state():
|
136 |
global models_rbm, models_b, extras, extras_b, device, core, core_b
|
|
|
146 |
extras_b.sampling_configs['timesteps'] = 10
|
147 |
extras_b.sampling_configs['t_start'] = 1.0
|
148 |
|
149 |
+
# Move models to CPU to free up GPU memory
|
150 |
+
models_to(models_rbm, device="cpu")
|
151 |
+
models_b.generator.to("cpu")
|
152 |
+
|
153 |
+
# Clear CUDA cache
|
154 |
+
torch.cuda.empty_cache()
|
155 |
+
gc.collect()
|
156 |
|
157 |
+
# Move necessary models back to the correct device
|
158 |
if low_vram:
|
159 |
models_to(models_rbm, device="cpu", excepts=["generator", "previewer"])
|
160 |
+
models_rbm.generator.to(device)
|
161 |
+
models_rbm.previewer.to(device)
|
162 |
else:
|
163 |
models_to(models_rbm, device=device)
|
164 |
+
|
165 |
+
models_b.generator.to("cpu") # Keep Stage B generator on CPU for now
|
166 |
|
167 |
# Ensure effnet is on the correct device
|
168 |
models_rbm.effnet.to(device)
|
169 |
|
170 |
+
# Reset model states
|
171 |
models_rbm.generator.eval().requires_grad_(False)
|
172 |
models_b.generator.bfloat16().eval().requires_grad_(False)
|
173 |
|
174 |
+
# Clear CUDA cache again
|
175 |
torch.cuda.empty_cache()
|
176 |
gc.collect()
|
177 |
|
|
|
200 |
batch = {'captions': [caption] * batch_size}
|
201 |
batch['style'] = ref_style
|
202 |
|
203 |
+
# Ensure models are on the correct device before inference
|
204 |
+
if low_vram:
|
205 |
+
models_to(models_rbm, device=device, excepts=["generator", "previewer"])
|
206 |
+
else:
|
207 |
+
models_to(models_rbm, device=device)
|
208 |
+
|
209 |
+
models_b.generator.to(device)
|
210 |
+
|
211 |
# Ensure effnet is on the correct device
|
212 |
models_rbm.effnet.to(device)
|
213 |
x0_style_forward = models_rbm.effnet(extras.effnet_preprocess(ref_style))
|