fffiloni commited on
Commit
024ee6a
1 Parent(s): 04793a7

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +74 -1
app.py CHANGED
@@ -73,7 +73,80 @@ if low_vram:
73
 
74
  clear_gpu_cache()
75
 
76
- # ... (rest of your setup code remains the same)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
77
 
78
  def infer(style_description, ref_style_file, caption):
79
  clear_gpu_cache() # Clear cache before inference
 
73
 
74
  clear_gpu_cache()
75
 
76
+ # Stage C model configuration
77
+ config_file = 'third_party/StableCascade/configs/inference/stage_c_3b.yaml'
78
+ with open(config_file, "r", encoding="utf-8") as file:
79
+ loaded_config = yaml.safe_load(file)
80
+
81
+ core = WurstCoreCRBM(config_dict=loaded_config, device=device, training=False)
82
+
83
+ # Stage B model configuration
84
+ config_file_b = 'third_party/StableCascade/configs/inference/stage_b_3b.yaml'
85
+ with open(config_file_b, "r", encoding="utf-8") as file:
86
+ config_file_b = yaml.safe_load(file)
87
+
88
+ core_b = WurstCoreB(config_dict=config_file_b, device=device, training=False)
89
+
90
+ # Setup extras and models for Stage C
91
+ extras = core.setup_extras_pre()
92
+
93
+ gdf_rbm = RBM(
94
+ schedule=CosineSchedule(clamp_range=[0.0001, 0.9999]),
95
+ input_scaler=VPScaler(), target=EpsilonTarget(),
96
+ noise_cond=CosineTNoiseCond(),
97
+ loss_weight=AdaptiveLossWeight(),
98
+ )
99
+
100
+ sampling_configs = {
101
+ "cfg": 5,
102
+ "sampler": DDPMSampler(gdf_rbm),
103
+ "shift": 1,
104
+ "timesteps": 20
105
+ }
106
+
107
+ extras = core.Extras(
108
+ gdf=gdf_rbm,
109
+ sampling_configs=sampling_configs,
110
+ transforms=extras.transforms,
111
+ effnet_preprocess=extras.effnet_preprocess,
112
+ clip_preprocess=extras.clip_preprocess
113
+ )
114
+
115
+ models = core.setup_models(extras)
116
+ models.generator.eval().requires_grad_(False)
117
+
118
+ # Setup extras and models for Stage B
119
+ extras_b = core_b.setup_extras_pre()
120
+ models_b = core_b.setup_models(extras_b, skip_clip=True)
121
+ models_b = WurstCoreB.Models(
122
+ **{**models_b.to_dict(), 'tokenizer': models.tokenizer, 'text_model': models.text_model}
123
+ )
124
+ models_b.generator.bfloat16().eval().requires_grad_(False)
125
+
126
+ # Off-load old generator (low VRAM mode)
127
+ if low_vram:
128
+ models.generator.to("cpu")
129
+ clear_gpu_cache()
130
+
131
+ # Load and configure new generator
132
+ generator_rbm = StageCRBM()
133
+ for param_name, param in load_or_fail(core.config.generator_checkpoint_path).items():
134
+ set_module_tensor_to_device(generator_rbm, param_name, "cpu", value=param)
135
+
136
+ generator_rbm = generator_rbm.to(getattr(torch, core.config.dtype)).to(device)
137
+ generator_rbm = core.load_model(generator_rbm, 'generator')
138
+
139
+ # Create models_rbm instance
140
+ models_rbm = core.Models(
141
+ effnet=models.effnet,
142
+ previewer=models.previewer,
143
+ generator=generator_rbm,
144
+ generator_ema=models.generator_ema,
145
+ tokenizer=models.tokenizer,
146
+ text_model=models.text_model,
147
+ image_model=models.image_model
148
+ )
149
+ models_rbm.generator.eval().requires_grad_(False)
150
 
151
  def infer(style_description, ref_style_file, caption):
152
  clear_gpu_cache() # Clear cache before inference