fffiloni commited on
Commit
b6f23ea
·
verified ·
1 Parent(s): bd18e87

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +13 -1
app.py CHANGED
@@ -133,7 +133,7 @@ models_rbm = core.Models(
133
  models_rbm.generator.eval().requires_grad_(False)
134
 
135
  def reset_inference_state():
136
- global models_rbm, models_b, extras, extras_b, device
137
 
138
  # Reset sampling configurations
139
  extras.sampling_configs['cfg'] = 5
@@ -146,6 +146,13 @@ def reset_inference_state():
146
  extras_b.sampling_configs['timesteps'] = 10
147
  extras_b.sampling_configs['t_start'] = 1.0
148
 
 
 
 
 
 
 
 
149
  # Move models to the correct device
150
  if low_vram:
151
  models_to(models_rbm, device="cpu", excepts=["generator", "previewer"])
@@ -157,11 +164,16 @@ def reset_inference_state():
157
  # Ensure effnet is on the correct device
158
  models_rbm.effnet.to(device)
159
 
 
 
 
 
160
  # Clear CUDA cache
161
  torch.cuda.empty_cache()
162
  gc.collect()
163
 
164
  def infer(style_description, ref_style_file, caption):
 
165
  try:
166
  height=1024
167
  width=1024
 
133
  models_rbm.generator.eval().requires_grad_(False)
134
 
135
  def reset_inference_state():
136
+ global models_rbm, models_b, extras, extras_b, device, core, core_b
137
 
138
  # Reset sampling configurations
139
  extras.sampling_configs['cfg'] = 5
 
146
  extras_b.sampling_configs['timesteps'] = 10
147
  extras_b.sampling_configs['t_start'] = 1.0
148
 
149
+ # Reset models
150
+ models_rbm = core.setup_models(extras)
151
+ models_b = core_b.setup_models(extras_b, skip_clip=True)
152
+ models_b = WurstCoreB.Models(
153
+ **{**models_b.to_dict(), 'tokenizer': models_rbm.tokenizer, 'text_model': models_rbm.text_model}
154
+ )
155
+
156
  # Move models to the correct device
157
  if low_vram:
158
  models_to(models_rbm, device="cpu", excepts=["generator", "previewer"])
 
164
  # Ensure effnet is on the correct device
165
  models_rbm.effnet.to(device)
166
 
167
+ # Set models to eval mode and disable gradients
168
+ models_rbm.generator.eval().requires_grad_(False)
169
+ models_b.generator.bfloat16().eval().requires_grad_(False)
170
+
171
  # Clear CUDA cache
172
  torch.cuda.empty_cache()
173
  gc.collect()
174
 
175
  def infer(style_description, ref_style_file, caption):
176
+ global models_rbm, models_b
177
  try:
178
  height=1024
179
  width=1024