fffiloni commited on
Commit
8020398
·
verified ·
1 Parent(s): c1fff88

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +96 -65
app.py CHANGED
@@ -1,6 +1,7 @@
1
  import sys
2
  import os
3
  from pathlib import Path
 
4
 
5
  # Add the StableCascade and CSD directories to the Python path
6
  app_dir = Path(__file__).parent
@@ -130,17 +131,12 @@ models_rbm = core.Models(
130
  )
131
  models_rbm.generator.eval().requires_grad_(False)
132
 
133
- def infer(style_description, ref_style_file, caption):
134
-
135
- height=1024
136
- width=1024
137
- batch_size=1
138
- output_file='output.png'
139
 
140
- stage_c_latent_shape, stage_b_latent_shape = calculate_latent_sizes(height, width, batch_size=batch_size)
141
-
142
- extras.sampling_configs['cfg'] = 4
143
- extras.sampling_configs['shift'] = 2
144
  extras.sampling_configs['timesteps'] = 20
145
  extras.sampling_configs['t_start'] = 1.0
146
 
@@ -149,66 +145,101 @@ def infer(style_description, ref_style_file, caption):
149
  extras_b.sampling_configs['timesteps'] = 10
150
  extras_b.sampling_configs['t_start'] = 1.0
151
 
152
- ref_style = resize_image(PIL.Image.open(ref_style_file).convert("RGB")).unsqueeze(0).expand(batch_size, -1, -1, -1).to(device)
153
-
154
- batch = {'captions': [caption] * batch_size}
155
- batch['style'] = ref_style
156
-
157
- x0_style_forward = models_rbm.effnet(extras.effnet_preprocess(ref_style.to(device)))
158
-
159
- conditions = core.get_conditions(batch, models_rbm, extras, is_eval=True, is_unconditional=False, eval_image_embeds=True, eval_style=True, eval_csd=False)
160
- unconditions = core.get_conditions(batch, models_rbm, extras, is_eval=True, is_unconditional=True, eval_image_embeds=False)
161
- conditions_b = core_b.get_conditions(batch, models_b, extras_b, is_eval=True, is_unconditional=False)
162
- unconditions_b = core_b.get_conditions(batch, models_b, extras_b, is_eval=True, is_unconditional=True)
163
-
164
  if low_vram:
165
- # The sampling process uses more vram, so we offload everything except two modules to the cpu.
166
  models_to(models_rbm, device="cpu", excepts=["generator", "previewer"])
 
 
 
 
 
 
 
 
167
 
168
- # Stage C reverse process.
169
- sampling_c = extras.gdf.sample(
170
- models_rbm.generator, conditions, stage_c_latent_shape,
171
- unconditions, device=device,
172
- **extras.sampling_configs,
173
- x0_style_forward=x0_style_forward,
174
- apply_pushforward=False, tau_pushforward=8,
175
- num_iter=3, eta=0.1, tau=20, eval_csd=True,
176
- extras=extras, models=models_rbm,
177
- lam_style=1, lam_txt_alignment=1.0,
178
- use_ddim_sampler=True,
179
- )
180
- for (sampled_c, _, _) in tqdm(sampling_c, total=extras.sampling_configs['timesteps']):
181
- sampled_c = sampled_c
182
-
183
- # Stage B reverse process.
184
- with torch.no_grad(), torch.cuda.amp.autocast(dtype=torch.bfloat16):
185
- conditions_b['effnet'] = sampled_c
186
- unconditions_b['effnet'] = torch.zeros_like(sampled_c)
187
 
188
- sampling_b = extras_b.gdf.sample(
189
- models_b.generator, conditions_b, stage_b_latent_shape,
190
- unconditions_b, device=device, **extras_b.sampling_configs,
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
191
  )
192
- for (sampled_b, _, _) in tqdm(sampling_b, total=extras_b.sampling_configs['timesteps']):
193
- sampled_b = sampled_b
194
- sampled = models_b.stage_a.decode(sampled_b).float()
195
-
196
- sampled = torch.cat([
197
- torch.nn.functional.interpolate(ref_style.cpu(), size=(height, width)),
198
- sampled.cpu(),
199
- ], dim=0)
200
-
201
- # Remove the batch dimension and keep only the generated image
202
- sampled = sampled[1] # This selects the generated image, discarding the reference style image
203
-
204
- # Ensure the tensor is in [C, H, W] format
205
- if sampled.dim() == 3 and sampled.shape[0] == 3:
206
- sampled_image = T.ToPILImage()(sampled) # Convert tensor to PIL image
207
- sampled_image.save(output_file) # Save the image as a PNG
208
- else:
209
- raise ValueError(f"Expected tensor of shape [3, H, W] but got {sampled.shape}")
210
-
211
- return output_file # Return the path to the saved image
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
212
 
213
  import gradio as gr
214
 
 
1
  import sys
2
  import os
3
  from pathlib import Path
4
+ import gc
5
 
6
  # Add the StableCascade and CSD directories to the Python path
7
  app_dir = Path(__file__).parent
 
131
  )
132
  models_rbm.generator.eval().requires_grad_(False)
133
 
134
+ def reset_inference_state():
135
+ global models_rbm, models_b, extras, extras_b
 
 
 
 
136
 
137
+ # Reset sampling configurations
138
+ extras.sampling_configs['cfg'] = 5
139
+ extras.sampling_configs['shift'] = 1
 
140
  extras.sampling_configs['timesteps'] = 20
141
  extras.sampling_configs['t_start'] = 1.0
142
 
 
145
  extras_b.sampling_configs['timesteps'] = 10
146
  extras_b.sampling_configs['t_start'] = 1.0
147
 
148
+ # Move models back to initial state
 
 
 
 
 
 
 
 
 
 
 
149
  if low_vram:
 
150
  models_to(models_rbm, device="cpu", excepts=["generator", "previewer"])
151
+ models_b.generator.to("cpu")
152
+ else:
153
+ models_to(models_rbm, device="cuda")
154
+ models_b.generator.to("cuda")
155
+
156
+ # Clear CUDA cache
157
+ torch.cuda.empty_cache()
158
+ gc.collect()
159
 
160
+ def infer(style_description, ref_style_file, caption):
161
+ try:
162
+ height=1024
163
+ width=1024
164
+ batch_size=1
165
+ output_file='output.png'
 
 
 
 
 
 
 
 
 
 
 
 
 
166
 
167
+ stage_c_latent_shape, stage_b_latent_shape = calculate_latent_sizes(height, width, batch_size=batch_size)
168
+
169
+ extras.sampling_configs['cfg'] = 4
170
+ extras.sampling_configs['shift'] = 2
171
+ extras.sampling_configs['timesteps'] = 20
172
+ extras.sampling_configs['t_start'] = 1.0
173
+
174
+ extras_b.sampling_configs['cfg'] = 1.1
175
+ extras_b.sampling_configs['shift'] = 1
176
+ extras_b.sampling_configs['timesteps'] = 10
177
+ extras_b.sampling_configs['t_start'] = 1.0
178
+
179
+ ref_style = resize_image(PIL.Image.open(ref_style_file).convert("RGB")).unsqueeze(0).expand(batch_size, -1, -1, -1).to(device)
180
+
181
+ batch = {'captions': [caption] * batch_size}
182
+ batch['style'] = ref_style
183
+
184
+ x0_style_forward = models_rbm.effnet(extras.effnet_preprocess(ref_style.to(device)))
185
+
186
+ conditions = core.get_conditions(batch, models_rbm, extras, is_eval=True, is_unconditional=False, eval_image_embeds=True, eval_style=True, eval_csd=False)
187
+ unconditions = core.get_conditions(batch, models_rbm, extras, is_eval=True, is_unconditional=True, eval_image_embeds=False)
188
+ conditions_b = core_b.get_conditions(batch, models_b, extras_b, is_eval=True, is_unconditional=False)
189
+ unconditions_b = core_b.get_conditions(batch, models_b, extras_b, is_eval=True, is_unconditional=True)
190
+
191
+ if low_vram:
192
+ # The sampling process uses more vram, so we offload everything except two modules to the cpu.
193
+ models_to(models_rbm, device="cpu", excepts=["generator", "previewer"])
194
+
195
+ # Stage C reverse process.
196
+ sampling_c = extras.gdf.sample(
197
+ models_rbm.generator, conditions, stage_c_latent_shape,
198
+ unconditions, device=device,
199
+ **extras.sampling_configs,
200
+ x0_style_forward=x0_style_forward,
201
+ apply_pushforward=False, tau_pushforward=8,
202
+ num_iter=3, eta=0.1, tau=20, eval_csd=True,
203
+ extras=extras, models=models_rbm,
204
+ lam_style=1, lam_txt_alignment=1.0,
205
+ use_ddim_sampler=True,
206
  )
207
+ for (sampled_c, _, _) in tqdm(sampling_c, total=extras.sampling_configs['timesteps']):
208
+ sampled_c = sampled_c
209
+
210
+ # Stage B reverse process.
211
+ with torch.no_grad(), torch.cuda.amp.autocast(dtype=torch.bfloat16):
212
+ conditions_b['effnet'] = sampled_c
213
+ unconditions_b['effnet'] = torch.zeros_like(sampled_c)
214
+
215
+ sampling_b = extras_b.gdf.sample(
216
+ models_b.generator, conditions_b, stage_b_latent_shape,
217
+ unconditions_b, device=device, **extras_b.sampling_configs,
218
+ )
219
+ for (sampled_b, _, _) in tqdm(sampling_b, total=extras_b.sampling_configs['timesteps']):
220
+ sampled_b = sampled_b
221
+ sampled = models_b.stage_a.decode(sampled_b).float()
222
+
223
+ sampled = torch.cat([
224
+ torch.nn.functional.interpolate(ref_style.cpu(), size=(height, width)),
225
+ sampled.cpu(),
226
+ ], dim=0)
227
+
228
+ # Remove the batch dimension and keep only the generated image
229
+ sampled = sampled[1] # This selects the generated image, discarding the reference style image
230
+
231
+ # Ensure the tensor is in [C, H, W] format
232
+ if sampled.dim() == 3 and sampled.shape[0] == 3:
233
+ sampled_image = T.ToPILImage()(sampled) # Convert tensor to PIL image
234
+ sampled_image.save(output_file) # Save the image as a PNG
235
+ else:
236
+ raise ValueError(f"Expected tensor of shape [3, H, W] but got {sampled.shape}")
237
+
238
+ return output_file # Return the path to the saved image
239
+
240
+ finally:
241
+ # Reset the state after inference, regardless of success or failure
242
+ reset_inference_state()
243
 
244
  import gradio as gr
245