fffiloni committed · verified
Commit e429857 · Parent(s): 2c0e7f7

Update app.py

Files changed (1): app.py (+82 -82)
app.py CHANGED
@@ -131,92 +131,92 @@ def infer(ref_style_file, style_description, caption, progress):
     if low_vram:
         models_to(models_rbm, device=device)
     try:
-        with progress:
-            caption = f"{caption} in {style_description}"
-            height=1024
-            width=1024
-            batch_size=1
-            output_file='output.png'
-
-            stage_c_latent_shape, stage_b_latent_shape = calculate_latent_sizes(height, width, batch_size=batch_size)
-
-            extras.sampling_configs['cfg'] = 4
-            extras.sampling_configs['shift'] = 2
-            extras.sampling_configs['timesteps'] = 20
-            extras.sampling_configs['t_start'] = 1.0
-
-            extras_b.sampling_configs['cfg'] = 1.1
-            extras_b.sampling_configs['shift'] = 1
-            extras_b.sampling_configs['timesteps'] = 10
-            extras_b.sampling_configs['t_start'] = 1.0
-
-            progress(0.1, "Loading style reference image")
-            ref_style = resize_image(PIL.Image.open(ref_style_file).convert("RGB")).unsqueeze(0).expand(batch_size, -1, -1, -1).to(device)
-
-            batch = {'captions': [caption] * batch_size}
-            batch['style'] = ref_style
-
-            progress(0.2, "Processing style reference image")
-            x0_style_forward = models_rbm.effnet(extras.effnet_preprocess(ref_style.to(device)))
-
-            progress(0.3, "Generating conditions")
-            conditions = core.get_conditions(batch, models_rbm, extras, is_eval=True, is_unconditional=False, eval_image_embeds=True, eval_style=True, eval_csd=False)
-            unconditions = core.get_conditions(batch, models_rbm, extras, is_eval=True, is_unconditional=True, eval_image_embeds=False)
-            conditions_b = core_b.get_conditions(batch, models_b, extras_b, is_eval=True, is_unconditional=False)
-            unconditions_b = core_b.get_conditions(batch, models_b, extras_b, is_eval=True, is_unconditional=True)
-
-            if low_vram:
-                # The sampling process uses more vram, so we offload everything except two modules to the cpu.
-                models_to(models_rbm, device="cpu", excepts=["generator", "previewer"])
-
-            progress(0.4, "Starting Stage C reverse process")
-            # Stage C reverse process.
-            sampling_c = extras.gdf.sample(
-                models_rbm.generator, conditions, stage_c_latent_shape,
-                unconditions, device=device,
-                **extras.sampling_configs,
-                x0_style_forward=x0_style_forward,
-                apply_pushforward=False, tau_pushforward=8,
-                num_iter=3, eta=0.1, tau=20, eval_csd=True,
-                extras=extras, models=models_rbm,
-                lam_style=1, lam_txt_alignment=1.0,
-                use_ddim_sampler=True,
-            )
-            for (sampled_c, _, _) in progress.track(tqdm(sampling_c, total=extras.sampling_configs['timesteps']), description="Stage C reverse process"):
-                sampled_c = sampled_c
-
-            progress(0.7, "Starting Stage B reverse process")
-            # Stage B reverse process.
-            with torch.no_grad(), torch.cuda.amp.autocast(dtype=torch.bfloat16):
-                conditions_b['effnet'] = sampled_c
-                unconditions_b['effnet'] = torch.zeros_like(sampled_c)
-
-                sampling_b = extras_b.gdf.sample(
-                    models_b.generator, conditions_b, stage_b_latent_shape,
-                    unconditions_b, device=device, **extras_b.sampling_configs,
-                )
-                for (sampled_b, _, _) in progress.track(tqdm(sampling_b, total=extras_b.sampling_configs['timesteps']), description="Stage B reverse process"):
-                    sampled_b = sampled_b
-                sampled = models_b.stage_a.decode(sampled_b).float()
-
-            progress(0.9, "Finalizing the output image")
-            sampled = torch.cat([
-                torch.nn.functional.interpolate(ref_style.cpu(), size=(height, width)),
-                sampled.cpu(),
-            ], dim=0)
-
-            # Remove the batch dimension and keep only the generated image
-            sampled = sampled[1]  # This selects the generated image, discarding the reference style image
-
-            # Ensure the tensor is in [C, H, W] format
-            if sampled.dim() == 3 and sampled.shape[0] == 3:
-                sampled_image = T.ToPILImage()(sampled)  # Convert tensor to PIL image
-                sampled_image.save(output_file)  # Save the image as a PNG
-            else:
-                raise ValueError(f"Expected tensor of shape [3, H, W] but got {sampled.shape}")
-
-            progress(1.0, "Inference complete")
-            return output_file  # Return the path to the saved image
+
+        caption = f"{caption} in {style_description}"
+        height=1024
+        width=1024
+        batch_size=1
+        output_file='output.png'
+
+        stage_c_latent_shape, stage_b_latent_shape = calculate_latent_sizes(height, width, batch_size=batch_size)
+
+        extras.sampling_configs['cfg'] = 4
+        extras.sampling_configs['shift'] = 2
+        extras.sampling_configs['timesteps'] = 20
+        extras.sampling_configs['t_start'] = 1.0
+
+        extras_b.sampling_configs['cfg'] = 1.1
+        extras_b.sampling_configs['shift'] = 1
+        extras_b.sampling_configs['timesteps'] = 10
+        extras_b.sampling_configs['t_start'] = 1.0
+
+        progress(0.1, "Loading style reference image")
+        ref_style = resize_image(PIL.Image.open(ref_style_file).convert("RGB")).unsqueeze(0).expand(batch_size, -1, -1, -1).to(device)
+
+        batch = {'captions': [caption] * batch_size}
+        batch['style'] = ref_style
+
+        progress(0.2, "Processing style reference image")
+        x0_style_forward = models_rbm.effnet(extras.effnet_preprocess(ref_style.to(device)))
+
+        progress(0.3, "Generating conditions")
+        conditions = core.get_conditions(batch, models_rbm, extras, is_eval=True, is_unconditional=False, eval_image_embeds=True, eval_style=True, eval_csd=False)
+        unconditions = core.get_conditions(batch, models_rbm, extras, is_eval=True, is_unconditional=True, eval_image_embeds=False)
+        conditions_b = core_b.get_conditions(batch, models_b, extras_b, is_eval=True, is_unconditional=False)
+        unconditions_b = core_b.get_conditions(batch, models_b, extras_b, is_eval=True, is_unconditional=True)
+
+        if low_vram:
+            # The sampling process uses more vram, so we offload everything except two modules to the cpu.
+            models_to(models_rbm, device="cpu", excepts=["generator", "previewer"])
+
+        progress(0.4, "Starting Stage C reverse process")
+        # Stage C reverse process.
+        sampling_c = extras.gdf.sample(
+            models_rbm.generator, conditions, stage_c_latent_shape,
+            unconditions, device=device,
+            **extras.sampling_configs,
+            x0_style_forward=x0_style_forward,
+            apply_pushforward=False, tau_pushforward=8,
+            num_iter=3, eta=0.1, tau=20, eval_csd=True,
+            extras=extras, models=models_rbm,
+            lam_style=1, lam_txt_alignment=1.0,
+            use_ddim_sampler=True,
+        )
+        for (sampled_c, _, _) in progress.track(tqdm(sampling_c, total=extras.sampling_configs['timesteps']), description="Stage C reverse process"):
+            sampled_c = sampled_c
+
+        progress(0.7, "Starting Stage B reverse process")
+        # Stage B reverse process.
+        with torch.no_grad(), torch.cuda.amp.autocast(dtype=torch.bfloat16):
+            conditions_b['effnet'] = sampled_c
+            unconditions_b['effnet'] = torch.zeros_like(sampled_c)
+
+            sampling_b = extras_b.gdf.sample(
+                models_b.generator, conditions_b, stage_b_latent_shape,
+                unconditions_b, device=device, **extras_b.sampling_configs,
+            )
+            for (sampled_b, _, _) in progress.track(tqdm(sampling_b, total=extras_b.sampling_configs['timesteps']), description="Stage B reverse process"):
+                sampled_b = sampled_b
+            sampled = models_b.stage_a.decode(sampled_b).float()
+
+        progress(0.9, "Finalizing the output image")
+        sampled = torch.cat([
+            torch.nn.functional.interpolate(ref_style.cpu(), size=(height, width)),
+            sampled.cpu(),
+        ], dim=0)
+
+        # Remove the batch dimension and keep only the generated image
+        sampled = sampled[1]  # This selects the generated image, discarding the reference style image
+
+        # Ensure the tensor is in [C, H, W] format
+        if sampled.dim() == 3 and sampled.shape[0] == 3:
+            sampled_image = T.ToPILImage()(sampled)  # Convert tensor to PIL image
+            sampled_image.save(output_file)  # Save the image as a PNG
+        else:
+            raise ValueError(f"Expected tensor of shape [3, H, W] but got {sampled.shape}")
+
+        progress(1.0, "Inference complete")
+        return output_file  # Return the path to the saved image

     finally:
         # Clear CUDA cache
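
Note on the change: the commit dedents the body of infer() out of the `with progress:` block, so the progress object is now used only as a callable (progress(0.1, "Loading style reference image")) and as an iterator wrapper (progress.track(...)); it no longer has to behave as a context manager. Below is a minimal sketch of a stand-in object satisfying that interface, useful for running infer() outside the app; ConsoleProgress and the file/prompt arguments are hypothetical, not part of this repo:

    # Hypothetical stand-in for the `progress` argument of infer().
    # After this commit the object only needs __call__ and .track();
    # it is never entered as a context manager.
    class ConsoleProgress:
        def __call__(self, fraction, message):
            # Mirrors calls like progress(0.1, "Loading style reference image")
            print(f"[{fraction:4.0%}] {message}")

        def track(self, iterable, description=""):
            # Mirrors progress.track(tqdm(...), description="Stage C reverse process")
            print(description)
            yield from iterable

    # Usage sketch (paths and prompt are placeholders):
    # infer("style_ref.png", "watercolor style", "a lighthouse at dusk", ConsoleProgress())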