fffiloni commited on
Commit
c24728c
·
verified ·
1 Parent(s): 95a7614

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +10 -13
app.py CHANGED
@@ -135,9 +135,13 @@ def initialize_models():
135
 
136
  def infer(style_description, ref_style_file, caption):
137
  try:
138
- # Initialize (or reinitialize) models before each inference
139
- initialize_models()
140
-
 
 
 
 
141
  height = 1024
142
  width = 1024
143
  batch_size = 1
@@ -145,16 +149,6 @@ def infer(style_description, ref_style_file, caption):
145
 
146
  stage_c_latent_shape, stage_b_latent_shape = calculate_latent_sizes(height, width, batch_size=batch_size)
147
 
148
- extras.sampling_configs['cfg'] = 4
149
- extras.sampling_configs['shift'] = 2
150
- extras.sampling_configs['timesteps'] = 20
151
- extras.sampling_configs['t_start'] = 1.0
152
-
153
- extras_b.sampling_configs['cfg'] = 1.1
154
- extras_b.sampling_configs['shift'] = 1
155
- extras_b.sampling_configs['timesteps'] = 10
156
- extras_b.sampling_configs['t_start'] = 1.0
157
-
158
  ref_style = resize_image(PIL.Image.open(ref_style_file).convert("RGB")).unsqueeze(0).expand(batch_size, -1, -1, -1).to(device)
159
 
160
  batch = {'captions': [caption] * batch_size}
@@ -189,6 +183,9 @@ def infer(style_description, ref_style_file, caption):
189
 
190
  clear_gpu_cache() # Clear cache between stages
191
 
 
 
 
192
  # Stage B reverse process
193
  with torch.no_grad(), torch.cuda.amp.autocast(dtype=torch.bfloat16):
194
  conditions_b['effnet'] = sampled_c
 
135
 
136
  def infer(style_description, ref_style_file, caption):
137
  try:
138
+ # Clear GPU cache before inference
139
+ clear_gpu_cache()
140
+
141
+ # Ensure models are on the correct device
142
+ models_rbm.to(device)
143
+ models_b.to(device)
144
+
145
  height = 1024
146
  width = 1024
147
  batch_size = 1
 
149
 
150
  stage_c_latent_shape, stage_b_latent_shape = calculate_latent_sizes(height, width, batch_size=batch_size)
151
 
 
 
 
 
 
 
 
 
 
 
152
  ref_style = resize_image(PIL.Image.open(ref_style_file).convert("RGB")).unsqueeze(0).expand(batch_size, -1, -1, -1).to(device)
153
 
154
  batch = {'captions': [caption] * batch_size}
 
183
 
184
  clear_gpu_cache() # Clear cache between stages
185
 
186
+ # Ensure models_b is on the correct device
187
+ models_b.to(device)
188
+
189
  # Stage B reverse process
190
  with torch.no_grad(), torch.cuda.amp.autocast(dtype=torch.bfloat16):
191
  conditions_b['effnet'] = sampled_c