fffiloni commited on
Commit
cb4eb0e
1 Parent(s): 8cbd2c7

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +9 -3
app.py CHANGED
@@ -125,7 +125,8 @@ models_rbm = core.Models(
125
  text_model=models.text_model,
126
  tokenizer=models.tokenizer,
127
  generator=generator_rbm,
128
- previewer=models.previewer
 
129
  )
130
 
131
  def reset_inference_state():
@@ -160,8 +161,10 @@ def reset_inference_state():
160
 
161
  models_b.generator.to("cpu") # Keep Stage B generator on CPU for now
162
 
163
- # Ensure effnet is on the correct device
164
  models_rbm.effnet.to(device)
 
 
165
 
166
  # Reset model states
167
  models_rbm.generator.eval().requires_grad_(False)
@@ -204,8 +207,11 @@ def infer(style_description, ref_style_file, caption):
204
 
205
  models_b.generator.to(device)
206
 
207
- # Ensure effnet is on the correct device
208
  models_rbm.effnet.to(device)
 
 
 
209
  x0_style_forward = models_rbm.effnet(extras.effnet_preprocess(ref_style))
210
 
211
  conditions = core.get_conditions(batch, models_rbm, extras, is_eval=True, is_unconditional=False, eval_image_embeds=True, eval_style=True, eval_csd=False)
 
125
  text_model=models.text_model,
126
  tokenizer=models.tokenizer,
127
  generator=generator_rbm,
128
+ previewer=models.previewer,
129
+ image_model=models.image_model # Add this line
130
  )
131
 
132
  def reset_inference_state():
 
161
 
162
  models_b.generator.to("cpu") # Keep Stage B generator on CPU for now
163
 
164
+ # Ensure effnet and image_model are on the correct device
165
  models_rbm.effnet.to(device)
166
+ if models_rbm.image_model is not None:
167
+ models_rbm.image_model.to(device)
168
 
169
  # Reset model states
170
  models_rbm.generator.eval().requires_grad_(False)
 
207
 
208
  models_b.generator.to(device)
209
 
210
+ # Ensure effnet and image_model are on the correct device
211
  models_rbm.effnet.to(device)
212
+ if models_rbm.image_model is not None:
213
+ models_rbm.image_model.to(device)
214
+
215
  x0_style_forward = models_rbm.effnet(extras.effnet_preprocess(ref_style))
216
 
217
  conditions = core.get_conditions(batch, models_rbm, extras, is_eval=True, is_unconditional=False, eval_image_embeds=True, eval_style=True, eval_csd=False)