fffiloni committed
Commit 6ec8160
1 Parent(s): 8fab8d2

Update app.py

Files changed (1)
app.py +17 -38
app.py CHANGED
@@ -149,25 +149,13 @@ models_rbm = core.Models(
 models_rbm.generator.eval().requires_grad_(False)
 
 def infer(style_description, ref_style_file, caption):
-    # Move models to the correct device
-    models_rbm.effnet.to(device)
-    models_rbm.generator.to(device)
-    if low_vram:
-        models_rbm.previewer.to(device)
-
-    # Also, revalidate data types and devices for key tensors
-    def check_and_move(tensor):
-        if tensor is not None and tensor.device != device:
-            return tensor.to(device)
-        return tensor
-
     clear_gpu_cache()  # Clear cache before inference
 
-    height = 1024
-    width = 1024
-    batch_size = 1
-    output_file = 'output.png'
-
+    height=1024
+    width=1024
+    batch_size=1
+    output_file='output.png'
+
     stage_c_latent_shape, stage_b_latent_shape = calculate_latent_sizes(height, width, batch_size=batch_size)
 
     extras.sampling_configs['cfg'] = 4
@@ -180,26 +168,24 @@ def infer(style_description, ref_style_file, caption):
     extras_b.sampling_configs['timesteps'] = 10
     extras_b.sampling_configs['t_start'] = 1.0
 
-    # Load and preprocess the reference style image
-    ref_style = resize_image(PIL.Image.open(ref_style_file).convert("RGB"))
-    ref_style = ref_style.unsqueeze(0).expand(batch_size, -1, -1, -1).to(device)
+    ref_style = resize_image(PIL.Image.open(ref_style_file).convert("RGB")).unsqueeze(0).expand(batch_size, -1, -1, -1).to(device)
 
     batch = {'captions': [caption] * batch_size}
     batch['style'] = ref_style
 
     x0_style_forward = models_rbm.effnet(extras.effnet_preprocess(ref_style.to(device)))
 
-    conditions = core.get_conditions(batch, models_rbm, extras, is_eval=True, is_unconditional=False, eval_image_embeds=True, eval_style=True, eval_csd=False)
-    unconditions = core.get_conditions(batch, models_rbm, extras, is_eval=True, is_unconditional=True, eval_image_embeds=False)
+    conditions = core.get_conditions(batch, models_rbm, extras, is_eval=True, is_unconditional=False, eval_image_embeds=True, eval_style=True, eval_csd=False)
+    unconditions = core.get_conditions(batch, models_rbm, extras, is_eval=True, is_unconditional=True, eval_image_embeds=False)
     conditions_b = core_b.get_conditions(batch, models_b, extras_b, is_eval=True, is_unconditional=False)
     unconditions_b = core_b.get_conditions(batch, models_b, extras_b, is_eval=True, is_unconditional=True)
 
     if low_vram:
-        # Offload non-essential models to CPU for memory savings
+        # The sampling process uses more vram, so we offload everything except two modules to the cpu.
         models_to(models_rbm, device="cpu", excepts=["generator", "previewer"])
 
-    # Stage C reverse process
-    with torch.cuda.amp.autocast():
+    # Stage C reverse process.
+    with torch.cuda.amp.autocast():  # Use mixed precision
         sampling_c = extras.gdf.sample(
             models_rbm.generator, conditions, stage_c_latent_shape,
             unconditions, device=device,
@@ -216,24 +202,19 @@ def infer(style_description, ref_style_file, caption):
 
     clear_gpu_cache()  # Clear cache between stages
 
-    # Ensure all models are on the right device again
-    models_b.generator.to(device)
-    models_b.stage_a.to(device)
-
-    # Stage B reverse process
-    with torch.no_grad(), torch.cuda.amp.autocast(dtype=torch.bfloat16):
-        conditions_b['effnet'] = sampled_c.to(device)
-        unconditions_b['effnet'] = torch.zeros_like(sampled_c).to(device)
-
+    # Stage B reverse process.
+    with torch.no_grad(), torch.cuda.amp.autocast(dtype=torch.bfloat16):
+        conditions_b['effnet'] = sampled_c
+        unconditions_b['effnet'] = torch.zeros_like(sampled_c)
+
         sampling_b = extras_b.gdf.sample(
             models_b.generator, conditions_b, stage_b_latent_shape,
             unconditions_b, device=device, **extras_b.sampling_configs,
         )
         for (sampled_b, _, _) in tqdm(sampling_b, total=extras_b.sampling_configs['timesteps']):
            sampled_b = sampled_b
-        sampled = models_b.stage_a.decode(sampled_b).float().to(device)
+        sampled = models_b.stage_a.decode(sampled_b).float()
 
-    # Post-process and save the image
     sampled = torch.cat([
         torch.nn.functional.interpolate(ref_style.cpu(), size=(height, width)),
         sampled.cpu(),
@@ -253,8 +234,6 @@ def infer(style_description, ref_style_file, caption):
 
     return output_file  # Return the path to the saved image
 
-
-
 import gradio as gr
 
 gr.Interface(
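
The updated infer() still calls a clear_gpu_cache() helper before inference and between the two stages. That helper is defined elsewhere in app.py and is not part of this diff; the following is only a minimal sketch of what such a helper commonly looks like, not the repository's exact code:

import gc
import torch

def clear_gpu_cache():
    # Drop unreachable Python objects first so their CUDA tensors can be freed,
    # then release the cached allocator memory back to the driver.
    gc.collect()
    if torch.cuda.is_available():
        torch.cuda.empty_cache()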
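
Likewise, the low_vram branch relies on models_to(models_rbm, device="cpu", excepts=["generator", "previewer"]) to offload everything except the two modules needed for Stage C sampling. The real helper also lives outside this diff; a hypothetical sketch, assuming the models container exposes its submodules as plain attributes:

import torch

def models_to(models, device="cpu", excepts=None):
    # Move every torch.nn.Module attribute of the container to `device`,
    # skipping any module named in `excepts`.
    excepts = excepts or []
    for name, value in vars(models).items():
        if name in excepts:
            continue
        if isinstance(value, torch.nn.Module):
            value.to(device)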