fffiloni committed · verified
Commit e76ae74 · Parent(s): 1034c70

Update app.py

Files changed (1): app.py (+10 −13)
app.py CHANGED
@@ -150,15 +150,14 @@ models_rbm.generator.eval().requires_grad_(False)
 
 def infer(style_description, ref_style_file, caption):
     try:
-        # Instead of trying to move the entire models_rbm object, move individual components
-        models_rbm.effnet.to(device)
-        models_rbm.previewer.to(device)
-        models_rbm.generator.to(device)
-        models_rbm.text_model.to(device)
+        # Move all model components to the same device and set to the same precision
+        models_rbm.effnet.to(device).bfloat16()
+        models_rbm.previewer.to(device).bfloat16()
+        models_rbm.generator.to(device).bfloat16()
+        models_rbm.text_model.to(device).bfloat16()
 
-        # For models_b, we need to move its components as well
-        models_b.generator.to(device)
-        models_b.stage_a.to(device)
+        models_b.generator.to(device).bfloat16()
+        models_b.stage_a.to(device).bfloat16()
 
         clear_gpu_cache()  # Clear cache before inference
 
@@ -179,13 +178,11 @@ def infer(style_description, ref_style_file, caption):
         extras_b.sampling_configs['timesteps'] = 10
         extras_b.sampling_configs['t_start'] = 1.0
 
-        ref_style = resize_image(PIL.Image.open(ref_style_file).convert("RGB")).unsqueeze(0).expand(batch_size, -1, -1, -1).to(device)
+        ref_style = resize_image(PIL.Image.open(ref_style_file).convert("RGB")).unsqueeze(0).expand(batch_size, -1, -1, -1).to(device).bfloat16()
 
         batch = {'captions': [caption] * batch_size}
         batch['style'] = ref_style
 
-        # Ensure effnet is on the same device as the input
-        models_rbm.effnet.to(device)
         x0_style_forward = models_rbm.effnet(extras.effnet_preprocess(ref_style))
 
         conditions = core.get_conditions(batch, models_rbm, extras, is_eval=True, is_unconditional=False, eval_image_embeds=True, eval_style=True, eval_csd=False)
@@ -198,7 +195,7 @@ def infer(style_description, ref_style_file, caption):
         models_to(models_rbm, device="cpu", excepts=["generator", "previewer"])
 
         # Stage C reverse process
-        with torch.cuda.amp.autocast():  # Use mixed precision
+        with torch.cuda.amp.autocast(dtype=torch.bfloat16):  # Use mixed precision with bfloat16
             sampling_c = extras.gdf.sample(
                 models_rbm.generator, conditions, stage_c_latent_shape,
                 unconditions, device=device,
@@ -216,7 +213,7 @@ def infer(style_description, ref_style_file, caption):
         clear_gpu_cache()  # Clear cache between stages
 
         # Ensure all models are on the right device again
-        models_b.generator.to(device)
+        models_b.generator.to(device).bfloat16()
 
         # Stage B reverse process
         with torch.no_grad(), torch.cuda.amp.autocast(dtype=torch.bfloat16):
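
The gist of the change: Stage B already sampled under `torch.cuda.amp.autocast(dtype=torch.bfloat16)`, but Stage C used the autocast default (float16) and the model weights themselves stayed float32 while parts of the pipeline were offloaded. Casting every component and the `ref_style` input to bfloat16 keeps all matmuls in one dtype. A minimal sketch (toy module, not app.py code) of the kind of mismatch this avoids:

```python
import torch

lin = torch.nn.Linear(4, 4).bfloat16()  # weights cast in place, as .to(device).bfloat16() does
x = torch.randn(1, 4)                   # input left at the float32 default

try:
    lin(x)  # outside autocast, weight and input dtypes must already agree
except RuntimeError as err:
    print(err)  # e.g. "mat1 and mat2 must have the same dtype"

print(lin(x.bfloat16()).dtype)  # torch.bfloat16 once the input is cast to match
```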
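`clear_gpu_cache()` is a helper defined elsewhere in app.py and not shown in this diff; assuming it does the usual cleanup between stages, it would look roughly like:

```python
import gc
import torch

def clear_gpu_cache():
    # Drop unreachable Python objects first so their CUDA tensors can be freed,
    # then return cached allocator blocks to the driver.
    gc.collect()
    if torch.cuda.is_available():
        torch.cuda.empty_cache()
```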
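`models_to(models_rbm, device="cpu", excepts=["generator", "previewer"])` likewise comes from outside this hunk; it appears to offload every submodule except the ones still needed for Stage C sampling. A hypothetical sketch under that assumption:

```python
import torch

def models_to(models, device="cpu", excepts=None):
    # Move each nn.Module attribute of the container to `device`,
    # skipping any submodule named in `excepts`.
    excepts = excepts or []
    for name, module in vars(models).items():
        if isinstance(module, torch.nn.Module) and name not in excepts:
            module.to(device)
```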