Spaces:
Sleeping
Sleeping
Update app.py
Browse files
app.py
CHANGED
@@ -150,15 +150,14 @@ models_rbm.generator.eval().requires_grad_(False)
|
|
150 |
|
151 |
def infer(style_description, ref_style_file, caption):
|
152 |
try:
|
153 |
-
#
|
154 |
-
models_rbm.effnet.to(device)
|
155 |
-
models_rbm.previewer.to(device)
|
156 |
-
models_rbm.generator.to(device)
|
157 |
-
models_rbm.text_model.to(device)
|
158 |
|
159 |
-
|
160 |
-
models_b.
|
161 |
-
models_b.stage_a.to(device)
|
162 |
|
163 |
clear_gpu_cache() # Clear cache before inference
|
164 |
|
@@ -179,13 +178,11 @@ def infer(style_description, ref_style_file, caption):
|
|
179 |
extras_b.sampling_configs['timesteps'] = 10
|
180 |
extras_b.sampling_configs['t_start'] = 1.0
|
181 |
|
182 |
-
ref_style = resize_image(PIL.Image.open(ref_style_file).convert("RGB")).unsqueeze(0).expand(batch_size, -1, -1, -1).to(device)
|
183 |
|
184 |
batch = {'captions': [caption] * batch_size}
|
185 |
batch['style'] = ref_style
|
186 |
|
187 |
-
# Ensure effnet is on the same device as the input
|
188 |
-
models_rbm.effnet.to(device)
|
189 |
x0_style_forward = models_rbm.effnet(extras.effnet_preprocess(ref_style))
|
190 |
|
191 |
conditions = core.get_conditions(batch, models_rbm, extras, is_eval=True, is_unconditional=False, eval_image_embeds=True, eval_style=True, eval_csd=False)
|
@@ -198,7 +195,7 @@ def infer(style_description, ref_style_file, caption):
|
|
198 |
models_to(models_rbm, device="cpu", excepts=["generator", "previewer"])
|
199 |
|
200 |
# Stage C reverse process
|
201 |
-
with torch.cuda.amp.autocast(): # Use mixed precision
|
202 |
sampling_c = extras.gdf.sample(
|
203 |
models_rbm.generator, conditions, stage_c_latent_shape,
|
204 |
unconditions, device=device,
|
@@ -216,7 +213,7 @@ def infer(style_description, ref_style_file, caption):
|
|
216 |
clear_gpu_cache() # Clear cache between stages
|
217 |
|
218 |
# Ensure all models are on the right device again
|
219 |
-
models_b.generator.to(device)
|
220 |
|
221 |
# Stage B reverse process
|
222 |
with torch.no_grad(), torch.cuda.amp.autocast(dtype=torch.bfloat16):
|
|
|
150 |
|
151 |
def infer(style_description, ref_style_file, caption):
|
152 |
try:
|
153 |
+
# Move all model components to the same device and set to the same precision
|
154 |
+
models_rbm.effnet.to(device).bfloat16()
|
155 |
+
models_rbm.previewer.to(device).bfloat16()
|
156 |
+
models_rbm.generator.to(device).bfloat16()
|
157 |
+
models_rbm.text_model.to(device).bfloat16()
|
158 |
|
159 |
+
models_b.generator.to(device).bfloat16()
|
160 |
+
models_b.stage_a.to(device).bfloat16()
|
|
|
161 |
|
162 |
clear_gpu_cache() # Clear cache before inference
|
163 |
|
|
|
178 |
extras_b.sampling_configs['timesteps'] = 10
|
179 |
extras_b.sampling_configs['t_start'] = 1.0
|
180 |
|
181 |
+
ref_style = resize_image(PIL.Image.open(ref_style_file).convert("RGB")).unsqueeze(0).expand(batch_size, -1, -1, -1).to(device).bfloat16()
|
182 |
|
183 |
batch = {'captions': [caption] * batch_size}
|
184 |
batch['style'] = ref_style
|
185 |
|
|
|
|
|
186 |
x0_style_forward = models_rbm.effnet(extras.effnet_preprocess(ref_style))
|
187 |
|
188 |
conditions = core.get_conditions(batch, models_rbm, extras, is_eval=True, is_unconditional=False, eval_image_embeds=True, eval_style=True, eval_csd=False)
|
|
|
195 |
models_to(models_rbm, device="cpu", excepts=["generator", "previewer"])
|
196 |
|
197 |
# Stage C reverse process
|
198 |
+
with torch.cuda.amp.autocast(dtype=torch.bfloat16): # Use mixed precision with bfloat16
|
199 |
sampling_c = extras.gdf.sample(
|
200 |
models_rbm.generator, conditions, stage_c_latent_shape,
|
201 |
unconditions, device=device,
|
|
|
213 |
clear_gpu_cache() # Clear cache between stages
|
214 |
|
215 |
# Ensure all models are on the right device again
|
216 |
+
models_b.generator.to(device).bfloat16()
|
217 |
|
218 |
# Stage B reverse process
|
219 |
with torch.no_grad(), torch.cuda.amp.autocast(dtype=torch.bfloat16):
|