Spaces: Running on A100
Update app.py
app.py CHANGED
@@ -135,9 +135,13 @@ def initialize_models():
 
 def infer(style_description, ref_style_file, caption):
     try:
-        #
-
-
+        # Clear GPU cache before inference
+        clear_gpu_cache()
+
+        # Ensure models are on the correct device
+        models_rbm.to(device)
+        models_b.to(device)
+
         height = 1024
         width = 1024
         batch_size = 1
@@ -145,16 +149,6 @@ def infer(style_description, ref_style_file, caption):
 
         stage_c_latent_shape, stage_b_latent_shape = calculate_latent_sizes(height, width, batch_size=batch_size)
 
-        extras.sampling_configs['cfg'] = 4
-        extras.sampling_configs['shift'] = 2
-        extras.sampling_configs['timesteps'] = 20
-        extras.sampling_configs['t_start'] = 1.0
-
-        extras_b.sampling_configs['cfg'] = 1.1
-        extras_b.sampling_configs['shift'] = 1
-        extras_b.sampling_configs['timesteps'] = 10
-        extras_b.sampling_configs['t_start'] = 1.0
-
         ref_style = resize_image(PIL.Image.open(ref_style_file).convert("RGB")).unsqueeze(0).expand(batch_size, -1, -1, -1).to(device)
 
         batch = {'captions': [caption] * batch_size}
@@ -189,6 +183,9 @@ def infer(style_description, ref_style_file, caption):
 
         clear_gpu_cache()  # Clear cache between stages
 
+        # Ensure models_b is on the correct device
+        models_b.to(device)
+
         # Stage B reverse process
         with torch.no_grad(), torch.cuda.amp.autocast(dtype=torch.bfloat16):
            conditions_b['effnet'] = sampled_c