fffiloni committed
Commit 2c0e7f7
1 Parent(s): 241d1e2

Update app.py

Files changed (1)
  1. app.py +84 -75
app.py CHANGED
@@ -126,89 +126,97 @@ models_rbm.generator.eval().requires_grad_(False)
 
 sam_model = LangSAM()
 
-def infer(ref_style_file, style_description, caption):
+def infer(ref_style_file, style_description, caption, progress):
     global models_rbm, models_b, device
     if low_vram:
         models_to(models_rbm, device=device)
     try:
-        caption = f"{caption} in {style_description}"
-        height=1024
-        width=1024
-        batch_size=1
-        output_file='output.png'
-
-        stage_c_latent_shape, stage_b_latent_shape = calculate_latent_sizes(height, width, batch_size=batch_size)
-
-        extras.sampling_configs['cfg'] = 4
-        extras.sampling_configs['shift'] = 2
-        extras.sampling_configs['timesteps'] = 20
-        extras.sampling_configs['t_start'] = 1.0
-
-        extras_b.sampling_configs['cfg'] = 1.1
-        extras_b.sampling_configs['shift'] = 1
-        extras_b.sampling_configs['timesteps'] = 10
-        extras_b.sampling_configs['t_start'] = 1.0
-
-        ref_style = resize_image(PIL.Image.open(ref_style_file).convert("RGB")).unsqueeze(0).expand(batch_size, -1, -1, -1).to(device)
-
-        batch = {'captions': [caption] * batch_size}
-        batch['style'] = ref_style
-
-        x0_style_forward = models_rbm.effnet(extras.effnet_preprocess(ref_style.to(device)))
-
-        conditions = core.get_conditions(batch, models_rbm, extras, is_eval=True, is_unconditional=False, eval_image_embeds=True, eval_style=True, eval_csd=False)
-        unconditions = core.get_conditions(batch, models_rbm, extras, is_eval=True, is_unconditional=True, eval_image_embeds=False)
-        conditions_b = core_b.get_conditions(batch, models_b, extras_b, is_eval=True, is_unconditional=False)
-        unconditions_b = core_b.get_conditions(batch, models_b, extras_b, is_eval=True, is_unconditional=True)
-
-        if low_vram:
-            # The sampling process uses more vram, so we offload everything except two modules to the cpu.
-            models_to(models_rbm, device="cpu", excepts=["generator", "previewer"])
-
-        # Stage C reverse process.
-        sampling_c = extras.gdf.sample(
-            models_rbm.generator, conditions, stage_c_latent_shape,
-            unconditions, device=device,
-            **extras.sampling_configs,
-            x0_style_forward=x0_style_forward,
-            apply_pushforward=False, tau_pushforward=8,
-            num_iter=3, eta=0.1, tau=20, eval_csd=True,
-            extras=extras, models=models_rbm,
-            lam_style=1, lam_txt_alignment=1.0,
-            use_ddim_sampler=True,
-        )
-        for (sampled_c, _, _) in tqdm(sampling_c, total=extras.sampling_configs['timesteps']):
-            sampled_c = sampled_c
-
-        # Stage B reverse process.
-        with torch.no_grad(), torch.cuda.amp.autocast(dtype=torch.bfloat16):
-            conditions_b['effnet'] = sampled_c
-            unconditions_b['effnet'] = torch.zeros_like(sampled_c)
-
-            sampling_b = extras_b.gdf.sample(
-                models_b.generator, conditions_b, stage_b_latent_shape,
-                unconditions_b, device=device, **extras_b.sampling_configs,
-            )
-            for (sampled_b, _, _) in tqdm(sampling_b, total=extras_b.sampling_configs['timesteps']):
-                sampled_b = sampled_b
-            sampled = models_b.stage_a.decode(sampled_b).float()
-
-        sampled = torch.cat([
-            torch.nn.functional.interpolate(ref_style.cpu(), size=(height, width)),
-            sampled.cpu(),
-        ], dim=0)
-
-        # Remove the batch dimension and keep only the generated image
-        sampled = sampled[1] # This selects the generated image, discarding the reference style image
-
-        # Ensure the tensor is in [C, H, W] format
-        if sampled.dim() == 3 and sampled.shape[0] == 3:
-            sampled_image = T.ToPILImage()(sampled) # Convert tensor to PIL image
-            sampled_image.save(output_file) # Save the image as a PNG
-        else:
-            raise ValueError(f"Expected tensor of shape [3, H, W] but got {sampled.shape}")
-
-        return output_file # Return the path to the saved image
+        with progress:
+            caption = f"{caption} in {style_description}"
+            height=1024
+            width=1024
+            batch_size=1
+            output_file='output.png'
+
+            stage_c_latent_shape, stage_b_latent_shape = calculate_latent_sizes(height, width, batch_size=batch_size)
+
+            extras.sampling_configs['cfg'] = 4
+            extras.sampling_configs['shift'] = 2
+            extras.sampling_configs['timesteps'] = 20
+            extras.sampling_configs['t_start'] = 1.0
+
+            extras_b.sampling_configs['cfg'] = 1.1
+            extras_b.sampling_configs['shift'] = 1
+            extras_b.sampling_configs['timesteps'] = 10
+            extras_b.sampling_configs['t_start'] = 1.0
+
+            progress(0.1, "Loading style reference image")
+            ref_style = resize_image(PIL.Image.open(ref_style_file).convert("RGB")).unsqueeze(0).expand(batch_size, -1, -1, -1).to(device)
+
+            batch = {'captions': [caption] * batch_size}
+            batch['style'] = ref_style
+
+            progress(0.2, "Processing style reference image")
+            x0_style_forward = models_rbm.effnet(extras.effnet_preprocess(ref_style.to(device)))
+
+            progress(0.3, "Generating conditions")
+            conditions = core.get_conditions(batch, models_rbm, extras, is_eval=True, is_unconditional=False, eval_image_embeds=True, eval_style=True, eval_csd=False)
+            unconditions = core.get_conditions(batch, models_rbm, extras, is_eval=True, is_unconditional=True, eval_image_embeds=False)
+            conditions_b = core_b.get_conditions(batch, models_b, extras_b, is_eval=True, is_unconditional=False)
+            unconditions_b = core_b.get_conditions(batch, models_b, extras_b, is_eval=True, is_unconditional=True)
+
+            if low_vram:
+                # The sampling process uses more vram, so we offload everything except two modules to the cpu.
+                models_to(models_rbm, device="cpu", excepts=["generator", "previewer"])
+
+            progress(0.4, "Starting Stage C reverse process")
+            # Stage C reverse process.
+            sampling_c = extras.gdf.sample(
+                models_rbm.generator, conditions, stage_c_latent_shape,
+                unconditions, device=device,
+                **extras.sampling_configs,
+                x0_style_forward=x0_style_forward,
+                apply_pushforward=False, tau_pushforward=8,
+                num_iter=3, eta=0.1, tau=20, eval_csd=True,
+                extras=extras, models=models_rbm,
+                lam_style=1, lam_txt_alignment=1.0,
+                use_ddim_sampler=True,
+            )
+            for (sampled_c, _, _) in progress.track(tqdm(sampling_c, total=extras.sampling_configs['timesteps']), description="Stage C reverse process"):
+                sampled_c = sampled_c
+
+            progress(0.7, "Starting Stage B reverse process")
+            # Stage B reverse process.
+            with torch.no_grad(), torch.cuda.amp.autocast(dtype=torch.bfloat16):
+                conditions_b['effnet'] = sampled_c
+                unconditions_b['effnet'] = torch.zeros_like(sampled_c)
+
+                sampling_b = extras_b.gdf.sample(
+                    models_b.generator, conditions_b, stage_b_latent_shape,
+                    unconditions_b, device=device, **extras_b.sampling_configs,
+                )
+                for (sampled_b, _, _) in progress.track(tqdm(sampling_b, total=extras_b.sampling_configs['timesteps']), description="Stage B reverse process"):
+                    sampled_b = sampled_b
+                sampled = models_b.stage_a.decode(sampled_b).float()
+
+            progress(0.9, "Finalizing the output image")
+            sampled = torch.cat([
+                torch.nn.functional.interpolate(ref_style.cpu(), size=(height, width)),
+                sampled.cpu(),
+            ], dim=0)
+
+            # Remove the batch dimension and keep only the generated image
+            sampled = sampled[1] # This selects the generated image, discarding the reference style image
+
+            # Ensure the tensor is in [C, H, W] format
+            if sampled.dim() == 3 and sampled.shape[0] == 3:
+                sampled_image = T.ToPILImage()(sampled) # Convert tensor to PIL image
+                sampled_image.save(output_file) # Save the image as a PNG
+            else:
+                raise ValueError(f"Expected tensor of shape [3, H, W] but got {sampled.shape}")
+
+            progress(1.0, "Inference complete")
+            return output_file # Return the path to the saved image
 
     finally:
         # Clear CUDA cache
@@ -324,10 +332,11 @@ def infer_compo(style_description, ref_style_file, caption, ref_sub_file):
 
 def run(style_reference_image, style_description, subject_prompt, subject_reference, use_subject_ref):
     result = None
+    progress = gr.Progress(track_tqdm=True)
     if use_subject_ref is True:
         result = infer_compo(style_description, style_reference_image, subject_prompt, subject_reference)
     else:
-        result = infer(style_reference_image, style_description, subject_prompt)
+        result = infer(style_reference_image, style_description, subject_prompt, progress)
     return result
 
 def show_hide_subject_image_component(use_subject_ref):
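For context on the Gradio progress API this commit wires in, below is a minimal, self-contained sketch of the documented pattern: declare the tracker with gr.Progress(track_tqdm=True), report explicit milestones with progress(fraction, desc=...), and let any tqdm loop inside the handler be mirrored in the UI. Every name in the sketch (fake_infer, the Blocks layout) is illustrative and not taken from app.py.

import time

import gradio as gr
from tqdm import tqdm

def fake_infer(prompt, progress=gr.Progress(track_tqdm=True)):
    # Explicit milestones, analogous to the progress(0.1, ...) calls in the hunk above.
    progress(0.1, desc="Preparing inputs")
    time.sleep(0.5)

    # Because track_tqdm=True, this tqdm loop is mirrored in the Gradio progress bar.
    for _ in tqdm(range(20), desc="Sampling"):
        time.sleep(0.05)

    progress(1.0, desc="Done")
    return f"finished: {prompt}"

with gr.Blocks() as demo:
    prompt_box = gr.Textbox(label="Prompt")
    result_box = gr.Textbox(label="Result")
    gr.Button("Run").click(fake_infer, inputs=prompt_box, outputs=result_box)

if __name__ == "__main__":
    demo.launch()

In this pattern the caller never passes the tracker explicitly; Gradio injects it when the click event fires.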
 
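The first hunk sets extras.sampling_configs['cfg'] = 4 for Stage C and 1.1 for Stage B. For reference, the textbook classifier-free guidance combination that such values feed into looks like the sketch below; it is a generic illustration, not necessarily the exact form implemented inside extras.gdf.sample.

import torch

def cfg_combine(pred_cond: torch.Tensor, pred_uncond: torch.Tensor, cfg: float) -> torch.Tensor:
    # Textbook classifier-free guidance: start from the unconditional prediction
    # and extrapolate toward (and past) the conditional one by a factor of cfg.
    return pred_uncond + cfg * (pred_cond - pred_uncond)

# cfg = 4 (Stage C) leans strongly on the caption and style conditions;
# cfg = 1.1 (Stage B) stays close to the plain conditional prediction.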