Spaces:
Running
on
A10G
Running
on
A10G
Commit
•
a722e19
1
Parent(s):
f34dfad
Update app.py
Browse files
app.py
CHANGED
@@ -119,8 +119,8 @@ def load_and_invert(
|
|
119 |
skip=skip,
|
120 |
eta=1.0,
|
121 |
)
|
122 |
-
wts =
|
123 |
-
zs =
|
124 |
do_inversion = False
|
125 |
|
126 |
return wts, zs, do_inversion, gr.update(visible=False)
|
@@ -173,8 +173,8 @@ def edit(input_image,
|
|
173 |
skip = skip,
|
174 |
eta = 1.0,
|
175 |
)
|
176 |
-
wts =
|
177 |
-
zs =
|
178 |
do_inversion = False
|
179 |
|
180 |
if image_caption.lower() == tar_prompt.lower(): # if image caption was not changed, run pure sega
|
@@ -194,7 +194,7 @@ def edit(input_image,
|
|
194 |
use_intersect_mask=use_intersect_mask
|
195 |
)
|
196 |
|
197 |
-
latnets = wts
|
198 |
sega_out = pipe(prompt=tar_prompt,
|
199 |
init_latents=latnets,
|
200 |
guidance_scale = tar_cfg_scale,
|
@@ -202,7 +202,7 @@ def edit(input_image,
|
|
202 |
# num_inference_steps=steps,
|
203 |
# use_ddpm=True,
|
204 |
# wts=wts.value,
|
205 |
-
zs=zs
|
206 |
|
207 |
return sega_out.images[0], gr.update(visible=True), do_reconstruction, reconstruction, wts, zs, do_inversion, show_share_button
|
208 |
|
@@ -210,12 +210,12 @@ def edit(input_image,
|
|
210 |
else: # if sega concepts were not added, performs regular ddpm sampling
|
211 |
|
212 |
if do_reconstruction: # if ddpm sampling wasn't computed
|
213 |
-
pure_ddpm_img = sample(zs
|
214 |
-
reconstruction =
|
215 |
do_reconstruction = False
|
216 |
return pure_ddpm_img, gr.update(visible=False), do_reconstruction, reconstruction, wts, zs, do_inversion, show_share_button
|
217 |
|
218 |
-
return reconstruction
|
219 |
|
220 |
|
221 |
def randomize_seed_fn(seed, is_random):
|
@@ -872,6 +872,5 @@ with gr.Blocks(css="style.css") as demo:
|
|
872 |
cache_examples=True
|
873 |
)
|
874 |
|
875 |
-
|
876 |
demo.queue()
|
877 |
demo.launch()
|
|
|
119 |
skip=skip,
|
120 |
eta=1.0,
|
121 |
)
|
122 |
+
wts = wts_tensor
|
123 |
+
zs = zs_tensor
|
124 |
do_inversion = False
|
125 |
|
126 |
return wts, zs, do_inversion, gr.update(visible=False)
|
|
|
173 |
skip = skip,
|
174 |
eta = 1.0,
|
175 |
)
|
176 |
+
wts = wts_tensor
|
177 |
+
zs = zs_tensor
|
178 |
do_inversion = False
|
179 |
|
180 |
if image_caption.lower() == tar_prompt.lower(): # if image caption was not changed, run pure sega
|
|
|
194 |
use_intersect_mask=use_intersect_mask
|
195 |
)
|
196 |
|
197 |
+
latnets = wts[-1].expand(1, -1, -1, -1)
|
198 |
sega_out = pipe(prompt=tar_prompt,
|
199 |
init_latents=latnets,
|
200 |
guidance_scale = tar_cfg_scale,
|
|
|
202 |
# num_inference_steps=steps,
|
203 |
# use_ddpm=True,
|
204 |
# wts=wts.value,
|
205 |
+
zs=zs, **editing_args)
|
206 |
|
207 |
return sega_out.images[0], gr.update(visible=True), do_reconstruction, reconstruction, wts, zs, do_inversion, show_share_button
|
208 |
|
|
|
210 |
else: # if sega concepts were not added, performs regular ddpm sampling
|
211 |
|
212 |
if do_reconstruction: # if ddpm sampling wasn't computed
|
213 |
+
pure_ddpm_img = sample(zs, wts, prompt_tar=tar_prompt, skip=skip, cfg_scale_tar=tar_cfg_scale)
|
214 |
+
reconstruction = pure_ddpm_img
|
215 |
do_reconstruction = False
|
216 |
return pure_ddpm_img, gr.update(visible=False), do_reconstruction, reconstruction, wts, zs, do_inversion, show_share_button
|
217 |
|
218 |
+
return reconstruction, gr.update(visible=False), do_reconstruction, reconstruction, wts, zs, do_inversion, show_share_button
|
219 |
|
220 |
|
221 |
def randomize_seed_fn(seed, is_random):
|
|
|
872 |
cache_examples=True
|
873 |
)
|
874 |
|
|
|
875 |
demo.queue()
|
876 |
demo.launch()
|