fffiloni committed on
Commit 95a7614 · verified · 1 Parent(s): e76ae74

Update app.py

Files changed (1)
  1. app.py +59 -81
app.py CHANGED
@@ -2,6 +2,7 @@ import sys
 import os
 from pathlib import Path
 import gc
+import traceback
 
 # Add the StableCascade and CSD directories to the Python path
 app_dir = Path(__file__).parent
@@ -27,6 +28,7 @@ from utils import WurstCoreCRBM
 from gdf.schedulers import CosineSchedule
 from gdf import VPScaler, CosineTNoiseCond, DDPMSampler, P2LossWeight, AdaptiveLossWeight
 from gdf.targets import EpsilonTarget
+import PIL
 
 # Enable mixed precision
 torch.backends.cuda.matmul.allow_tf32 = True
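
Despite the "# Enable mixed precision" comment, allow_tf32 only lets float32 matmuls use TF32 tensor cores on Ampere or newer GPUs; the actual mixed precision comes from the autocast contexts used during sampling. A minimal sketch of the related flags; the cudnn line is an assumption, not part of this file:

    import torch

    # TF32 trades a few mantissa bits for large matmul/conv speedups on
    # Ampere+ GPUs; it is independent of autocast-based mixed precision.
    torch.backends.cuda.matmul.allow_tf32 = True
    torch.backends.cudnn.allow_tf32 = True  # assumption: enable TF32 in cuDNN too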
@@ -73,94 +75,69 @@ if low_vram:
 
 clear_gpu_cache()
 
-# Stage C model configuration
+# Load configurations
 config_file = 'third_party/StableCascade/configs/inference/stage_c_3b.yaml'
 with open(config_file, "r", encoding="utf-8") as file:
     loaded_config = yaml.safe_load(file)
 
-core = WurstCoreCRBM(config_dict=loaded_config, device=device, training=False)
-
-# Stage B model configuration
 config_file_b = 'third_party/StableCascade/configs/inference/stage_b_3b.yaml'
 with open(config_file_b, "r", encoding="utf-8") as file:
     config_file_b = yaml.safe_load(file)
-
-core_b = WurstCoreB(config_dict=config_file_b, device=device, training=False)
-
-# Setup extras and models for Stage C
-extras = core.setup_extras_pre()
-
-gdf_rbm = RBM(
-    schedule=CosineSchedule(clamp_range=[0.0001, 0.9999]),
-    input_scaler=VPScaler(), target=EpsilonTarget(),
-    noise_cond=CosineTNoiseCond(),
-    loss_weight=AdaptiveLossWeight(),
-)
-
-sampling_configs = {
-    "cfg": 5,
-    "sampler": DDPMSampler(gdf_rbm),
-    "shift": 1,
-    "timesteps": 20
-}
-
-extras = core.Extras(
-    gdf=gdf_rbm,
-    sampling_configs=sampling_configs,
-    transforms=extras.transforms,
-    effnet_preprocess=extras.effnet_preprocess,
-    clip_preprocess=extras.clip_preprocess
-)
-
-models = core.setup_models(extras)
-models.generator.eval().requires_grad_(False)
-
-# Setup extras and models for Stage B
-extras_b = core_b.setup_extras_pre()
-models_b = core_b.setup_models(extras_b, skip_clip=True)
-models_b = WurstCoreB.Models(
-    **{**models_b.to_dict(), 'tokenizer': models.tokenizer, 'text_model': models.text_model}
-)
-models_b.generator.bfloat16().eval().requires_grad_(False)
-
-# Off-load old generator (low VRAM mode)
-if low_vram:
-    models.generator.to("cpu")
+
+def initialize_models():
+    global models_rbm, models_b, extras, extras_b, core, core_b
+
+    # Clear any existing models from memory
+    models_rbm = None
+    models_b = None
+    extras = None
+    extras_b = None
+
+    # Clear GPU cache
+    clear_gpu_cache()
+
+    # Initialize models
+    core = WurstCoreCRBM(config_dict=loaded_config, device=device, training=False)
+    core_b = WurstCoreB(config_dict=config_file_b, device=device, training=False)
+
+    extras = core.setup_extras_pre()
+    models = core.setup_models(extras)
+
+    extras_b = core_b.setup_extras_pre()
+    models_b = core_b.setup_models(extras_b, skip_clip=True)
+    models_b = WurstCoreB.Models(
+        **{**models_b.to_dict(), 'tokenizer': models.tokenizer, 'text_model': models.text_model}
+    )
+
+    # Initialize models_rbm
+    generator_rbm = StageCRBM()
+    for param_name, param in load_or_fail(core.config.generator_checkpoint_path).items():
+        set_module_tensor_to_device(generator_rbm, param_name, "cpu", value=param)
+
+    generator_rbm = generator_rbm.to(getattr(torch, core.config.dtype)).to(device)
+    generator_rbm = core.load_model(generator_rbm, 'generator')
+
+    models_rbm = core.Models(
+        effnet=models.effnet,
+        previewer=models.previewer,
+        generator=generator_rbm,
+        generator_ema=models.generator_ema,
+        tokenizer=models.tokenizer,
+        text_model=models.text_model,
+        image_model=models.image_model
+    )
+
+    # Move models to appropriate devices
+    models_rbm.generator.to(device).eval().requires_grad_(False)
+    models_b.generator.to(device).eval().requires_grad_(False)
+
     clear_gpu_cache()
-
-# Load and configure new generator
-generator_rbm = StageCRBM()
-for param_name, param in load_or_fail(core.config.generator_checkpoint_path).items():
-    set_module_tensor_to_device(generator_rbm, param_name, "cpu", value=param)
-
-generator_rbm = generator_rbm.to(getattr(torch, core.config.dtype)).to(device)
-generator_rbm = core.load_model(generator_rbm, 'generator')
-
-# Create models_rbm instance
-models_rbm = core.Models(
-    effnet=models.effnet,
-    previewer=models.previewer,
-    generator=generator_rbm,
-    generator_ema=models.generator_ema,
-    tokenizer=models.tokenizer,
-    text_model=models.text_model,
-    image_model=models.image_model
-)
-models_rbm.generator.eval().requires_grad_(False)
 
 def infer(style_description, ref_style_file, caption):
     try:
-        # Move all model components to the same device and set to the same precision
-        models_rbm.effnet.to(device).bfloat16()
-        models_rbm.previewer.to(device).bfloat16()
-        models_rbm.generator.to(device).bfloat16()
-        models_rbm.text_model.to(device).bfloat16()
+        # Initialize (or reinitialize) models before each inference
+        initialize_models()
 
-        models_b.generator.to(device).bfloat16()
-        models_b.stage_a.to(device).bfloat16()
-
-        clear_gpu_cache() # Clear cache before inference
-
         height = 1024
         width = 1024
         batch_size = 1
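
This hunk is the heart of the commit: module-level model construction moves into initialize_models(), which infer() now calls at the top of every request, so each run starts from a freshly loaded state instead of inheriting whatever device placement and precision the previous request left behind, at the cost of reloading weights each time. The clear_gpu_cache() helper is defined earlier in app.py and not shown in this diff; a plausible implementation, assuming it pairs Python garbage collection with draining the CUDA allocator cache:

    import gc
    import torch

    def clear_gpu_cache():
        # Free unreferenced Python objects so their CUDA tensors are released,
        # then return cached allocator blocks to the driver.
        gc.collect()
        if torch.cuda.is_available():
            torch.cuda.empty_cache()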
@@ -178,7 +155,7 @@ def infer(style_description, ref_style_file, caption):
         extras_b.sampling_configs['timesteps'] = 10
         extras_b.sampling_configs['t_start'] = 1.0
 
-        ref_style = resize_image(PIL.Image.open(ref_style_file).convert("RGB")).unsqueeze(0).expand(batch_size, -1, -1, -1).to(device).bfloat16()
+        ref_style = resize_image(PIL.Image.open(ref_style_file).convert("RGB")).unsqueeze(0).expand(batch_size, -1, -1, -1).to(device)
 
         batch = {'captions': [caption] * batch_size}
         batch['style'] = ref_style
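
The rewritten line also drops the trailing .bfloat16() cast, so ref_style stays in float32 and precision is left to the autocast contexts below. resize_image is a helper defined outside this diff; a hypothetical stand-in consistent with its call site (PIL image in, CHW float tensor out):

    import torchvision.transforms.functional as TF

    def resize_image(pil_image, size=1024):
        # Hypothetical helper: scale the style reference to the working
        # resolution and return a [0, 1] CHW float tensor.
        return TF.to_tensor(TF.resize(pil_image, [size, size]))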
@@ -195,7 +172,7 @@ def infer(style_description, ref_style_file, caption):
         models_to(models_rbm, device="cpu", excepts=["generator", "previewer"])
 
         # Stage C reverse process
-        with torch.cuda.amp.autocast(dtype=torch.bfloat16): # Use mixed precision with bfloat16
+        with torch.cuda.amp.autocast():
             sampling_c = extras.gdf.sample(
                 models_rbm.generator, conditions, stage_c_latent_shape,
                 unconditions, device=device,
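
With no dtype argument, torch.cuda.amp.autocast() defaults to float16 on CUDA, so this edit quietly moves Stage C sampling from bfloat16 to float16 autocast (Stage B below still requests bfloat16 explicitly). Both behaviors spelled out with the equivalent torch.autocast form:

    import torch

    # Old line: Stage C sampling under bfloat16 autocast.
    with torch.autocast(device_type="cuda", dtype=torch.bfloat16):
        ...  # stage C reverse process

    # New line: default CUDA autocast dtype, i.e. float16.
    with torch.autocast(device_type="cuda", dtype=torch.float16):
        ...  # stage C reverse process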
@@ -212,9 +189,6 @@ def infer(style_description, ref_style_file, caption):
 
         clear_gpu_cache() # Clear cache between stages
 
-        # Ensure all models are on the right device again
-        models_b.generator.to(device).bfloat16()
-
         # Stage B reverse process
         with torch.no_grad(), torch.cuda.amp.autocast(dtype=torch.bfloat16):
             conditions_b['effnet'] = sampled_c
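
Device juggling in these hunks relies on models_to(models_rbm, device="cpu", excepts=["generator", "previewer"]), another helper that lives outside the diff; a hypothetical sketch consistent with that call:

    import torch.nn as nn

    def models_to(models, device="cpu", excepts=None):
        # Hypothetical helper: move every nn.Module held by the container to
        # the target device, skipping the attribute names listed in excepts.
        excepts = excepts or []
        for name, module in models.to_dict().items():
            if isinstance(module, nn.Module) and name not in excepts:
                module.to(device)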
@@ -243,6 +217,7 @@ def infer(style_description, ref_style_file, caption):
 
     except Exception as e:
         print(f"An error occurred during inference: {str(e)}")
+        traceback.print_exc() # This will print the full traceback
         return None
 
     finally:
@@ -252,8 +227,11 @@ def infer(style_description, ref_style_file, caption):
 
 import gradio as gr
 
+def gradio_interface(style_description, ref_style_file, caption):
+    return infer(style_description, ref_style_file, caption)
+
 gr.Interface(
-    fn = infer,
+    fn=gradio_interface,
     inputs=[gr.Textbox(label="style description"), gr.Image(label="Ref Style File", type="filepath"), gr.Textbox(label="caption")],
     outputs=[gr.Image()]
 ).launch()
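
The new gradio_interface wrapper is a plain pass-through around infer, giving Gradio a stable top-level callable without changing behavior. Since every request now reloads both stages, calls are long-running; an optional variation (not part of this commit) would serialize them through Gradio's request queue:

    import gradio as gr

    demo = gr.Interface(
        fn=gradio_interface,
        inputs=[gr.Textbox(label="style description"),
                gr.Image(label="Ref Style File", type="filepath"),
                gr.Textbox(label="caption")],
        outputs=[gr.Image()],
    )
    demo.queue().launch()  # queue() runs long GPU jobs one at a time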
 