fffiloni committed
Commit 2b29f24 (verified) · Parent(s): 467adb1

Update app.py

Files changed (1):
  app.py +3 -64
app.py CHANGED
@@ -131,21 +131,9 @@ models_rbm = core.Models(
 
 def unload_models_and_clear_cache():
     global models_rbm, models_b, sam_model, extras, extras_b
-
-    # Reset sampling configurations
-    extras.sampling_configs['cfg'] = 5
-    extras.sampling_configs['shift'] = 1
-    extras.sampling_configs['timesteps'] = 20
-    extras.sampling_configs['t_start'] = 1.0
-
-    extras_b.sampling_configs['cfg'] = 1.1
-    extras_b.sampling_configs['shift'] = 1
-    extras_b.sampling_configs['timesteps'] = 10
-    extras_b.sampling_configs['t_start'] = 1.0
 
     # Move all models to CPU
     models_to(models_rbm, device="cpu")
-    models_b.generator.to("cpu")
 
     # Move SAM model components to CPU if they exist
     if 'sam_model' in globals():
@@ -168,48 +156,12 @@ def unload_models_and_clear_cache():
 
 def reset_inference_state():
     global models_rbm, models_b, extras, extras_b, device, core, core_b
-
-    # Reset sampling configurations
-    extras.sampling_configs['cfg'] = 5
-    extras.sampling_configs['shift'] = 1
-    extras.sampling_configs['timesteps'] = 20
-    extras.sampling_configs['t_start'] = 1.0
-
-    extras_b.sampling_configs['cfg'] = 1.1
-    extras_b.sampling_configs['shift'] = 1
-    extras_b.sampling_configs['timesteps'] = 10
-    extras_b.sampling_configs['t_start'] = 1.0
-
-    # Move models to CPU to free up GPU memory
-    models_to(models_rbm, device="cpu")
-    models_b.generator.to("cpu")
 
     # Clear CUDA cache
     torch.cuda.empty_cache()
     gc.collect()
 
-    # Move necessary models back to the correct device
-    if low_vram:
-        models_to(models_rbm, device="cpu", excepts=["generator", "previewer"])
-        models_rbm.generator.to(device)
-        models_rbm.previewer.to(device)
-    else:
-        models_to(models_rbm, device=device)
-
-    models_b.generator.to("cpu")  # Keep Stage B generator on CPU for now
-
-    # Ensure effnet and image_model are on the correct device
-    models_rbm.effnet.to(device)
-    if models_rbm.image_model is not None:
-        models_rbm.image_model.to(device)
-
-    # Reset model states
-    models_rbm.generator.eval().requires_grad_(False)
-    models_b.generator.bfloat16().eval().requires_grad_(False)
-
-    # Clear CUDA cache again
-    torch.cuda.empty_cache()
-    gc.collect()
+    models_to(models_rbm, device=device, excepts=["generator", "previewer"])
 
 def infer(ref_style_file, style_description, caption):
     global models_rbm, models_b
@@ -237,19 +189,6 @@ def infer(ref_style_file, style_description, caption):
     batch = {'captions': [caption] * batch_size}
     batch['style'] = ref_style
 
-    # Ensure models are on the correct device before inference
-    if low_vram:
-        models_to(models_rbm, device=device, excepts=["generator", "previewer"])
-    else:
-        models_to(models_rbm, device=device)
-
-    models_b.generator.to(device)
-
-    # Ensure effnet and image_model are on the correct device
-    models_rbm.effnet.to(device)
-    if models_rbm.image_model is not None:
-        models_rbm.image_model.to(device)
-
     x0_style_forward = models_rbm.effnet(extras.effnet_preprocess(ref_style))
 
     conditions = core.get_conditions(batch, models_rbm, extras, is_eval=True, is_unconditional=False, eval_image_embeds=True, eval_style=True, eval_csd=False)
@@ -308,9 +247,9 @@ def infer(ref_style_file, style_description, caption):
 
     finally:
         # Reset the state after inference, regardless of success or failure
-        # reset_inference_state()
+        reset_inference_state()
         # Unload models and clear cache after inference
-        unload_models_and_clear_cache()
+        # unload_models_and_clear_cache()
 
 def reset_compo_inference_state():
     global models_rbm, models_b, extras, extras_b, device, core, core_b, sam_model
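
Both sides of the diff lean on a models_to helper that is defined elsewhere in app.py and is not part of this commit. As rough orientation only, here is a minimal sketch of what such a helper plausibly does; the attribute-iteration strategy and the excepts semantics are assumptions inferred from the call sites above, not the repository's actual implementation:

import torch

def models_to(models, device="cpu", excepts=None):
    # Hypothetical sketch: move every torch.nn.Module held as an
    # attribute of the `models` container onto `device`, skipping any
    # attribute whose name appears in `excepts` (e.g. "generator" and
    # "previewer" in the low-VRAM path above).
    excepts = excepts or []
    for name, module in vars(models).items():
        if name in excepts or not isinstance(module, torch.nn.Module):
            continue
        module.to(device)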
 
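One usage note on the memory-release idiom that the slimmed-down reset_inference_state now reduces to: torch.cuda.empty_cache() only returns the caching allocator's unused blocks to the driver; tensors still referenced from Python are not freed, which is why the code pairs it with gc.collect(). A self-contained sketch of the same idiom (the function name and the reporting lines are mine, not the app's):

import gc
import torch

def release_cached_gpu_memory():
    # Drop unreachable Python objects first so their CUDA tensors can
    # actually be freed, then hand the allocator's cached blocks back
    # to the driver.
    gc.collect()
    torch.cuda.empty_cache()
    if torch.cuda.is_available():
        # memory_allocated: live tensors; memory_reserved: allocator cache.
        print(f"allocated: {torch.cuda.memory_allocated() / 1e6:.1f} MB")
        print(f"reserved:  {torch.cuda.memory_reserved() / 1e6:.1f} MB")

The net effect of the finally-block swap in the last hunk is a design change: infer now performs the lighter reset_inference_state() after each request, keeping models resident on their devices, instead of fully unloading everything via unload_models_and_clear_cache().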