multimodalart committed
Commit c1ee41f
1 Parent(s): 86a6a3c

Update app.py

Files changed (1)
  1. app.py +4 -72
app.py CHANGED
@@ -6,12 +6,9 @@ import random
  from io import BytesIO
  from utils import *
  from constants import *
- # from inversion_utils import *
- # from inversion_utils_dpmplusplus import *
- #from modified_pipeline_semantic_stable_diffusion import SemanticStableDiffusionPipeline
  from pipeline_semantic_stable_diffusion_img2img_solver import SemanticStableDiffusionImg2ImgPipeline_DPMSolver
  from torch import autocast, inference_mode
- from diffusers import StableDiffusionPipeline
+ from diffusers import StableDiffusionPipeline, AutoencoderKL
  from diffusers.schedulers import DDIMScheduler
  from scheduling_dpmsolver_multistep_inject import DPMSolverMultistepSchedulerInject
  from transformers import AutoProcessor, BlipForConditionalGeneration
@@ -20,17 +17,14 @@ from share_btn import community_icon_html, loading_icon_html, share_js
  # load pipelines
  sd_model_id = "runwayml/stable-diffusion-v1-5"
  device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
-
- pipe = SemanticStableDiffusionImg2ImgPipeline_DPMSolver.from_pretrained(sd_model_id,torch_dtype=torch.float16).to(device)
- # pipe.scheduler = DDIMScheduler.from_config(sd_model_id, subfolder = "scheduler")
+ vae = AutoencoderKL.from_pretrained("stabilityai/sd-vae-ft-mse", torch_dtype=torch.float16)
+ pipe = SemanticStableDiffusionImg2ImgPipeline_DPMSolver.from_pretrained(sd_model_id,vae=vae,torch_dtype=torch.float16).to(device)
  pipe.scheduler = DPMSolverMultistepSchedulerInject.from_pretrained(sd_model_id, subfolder="scheduler"
      , algorithm_type="sde-dpmsolver++", solver_order=2)
 
  blip_processor = AutoProcessor.from_pretrained("Salesforce/blip-image-captioning-base")
  blip_model = BlipForConditionalGeneration.from_pretrained("Salesforce/blip-image-captioning-base",torch_dtype=torch.float16).to(device)
 
-
-
  ## IMAGE CPATIONING ##
  def caption_image(input_image):
      inputs = blip_processor(images=input_image, return_tensors="pt").to(device, torch.float16)
@@ -40,65 +34,6 @@ def caption_image(input_image):
      generated_caption = blip_processor.batch_decode(generated_ids, skip_special_tokens=True)[0]
      return generated_caption, generated_caption
 
-
-
- ## DDPM INVERSION AND SAMPLING ##
- # def invert(x0, prompt_src="", num_diffusion_steps=100, cfg_scale_src = 3.5, eta = 1):
-
- #     # inverts a real image according to Algorihm 1 in https://arxiv.org/pdf/2304.06140.pdf,
- #     # based on the code in https://github.com/inbarhub/DDPM_inversion
-
- #     # returns wt, zs, wts:
- #     # wt - inverted latent
- #     # wts - intermediate inverted latents
- #     # zs - noise maps
-
- #     sd_pipe.scheduler.set_timesteps(num_diffusion_steps)
-
- #     # vae encode image
- #     with inference_mode():
- #         w0 = (sd_pipe.vae.encode(x0).latent_dist.mode() * 0.18215)
-
- #     # find Zs and wts - forward process
- #     wt, zs, wts = inversion_forward_process(sd_pipe, w0, etas=eta, prompt=prompt_src, cfg_scale=cfg_scale_src, prog_bar=True, num_inference_steps=num_diffusion_steps)
- #     return zs, wts
-
-
- # def sample(zs, wts, prompt_tar="", cfg_scale_tar=15, skip=36, eta = 1):
-
- #     # reverse process (via Zs and wT)
- #     w0, _ = inversion_reverse_process(sd_pipe, xT=wts[skip], etas=eta, prompts=[prompt_tar], cfg_scales=[cfg_scale_tar], prog_bar=True, zs=zs[skip:])
-
- #     # vae decode image
- #     with inference_mode():
- #         x0_dec = sd_pipe.vae.decode(1 / 0.18215 * w0).sample
- #     if x0_dec.dim()<4:
- #         x0_dec = x0_dec[None,:,:,:]
- #     img = image_grid(x0_dec)
- #     return img
-
- # def reconstruct(tar_prompt,
- #                 image_caption,
- #                 tar_cfg_scale,
- #                 skip,
- #                 wts, zs,
- #                 do_reconstruction,
- #                 reconstruction,
- #                 reconstruct_button
- #                 ):
-
- #     if reconstruct_button == "Hide Reconstruction":
- #         return reconstruction.value, reconstruction, ddpm_edited_image.update(visible=False), do_reconstruction, "Show Reconstruction"
-
- #     else:
- #         if do_reconstruction:
- #             if image_caption.lower() == tar_prompt.lower(): # if image caption was not changed, run actual reconstruction
- #                 tar_prompt = ""
- #             reconstruction_img = sample(zs.value, wts.value, prompt_tar=tar_prompt, skip=skip, cfg_scale_tar=tar_cfg_scale)
- #             reconstruction = gr.State(value=reconstruction_img)
- #             do_reconstruction = False
- #         return reconstruction.value, reconstruction, ddpm_edited_image.update(visible=True), do_reconstruction, "Hide Reconstruction"
-
  def sample(zs, wts, prompt_tar="", cfg_scale_tar=15, skip=36, eta = 1):
 
      latnets = wts.value[-1].expand(1, -1, -1, -1)
@@ -112,8 +47,6 @@ def sample(zs, wts, prompt_tar="", cfg_scale_tar=15, skip=36, eta = 1):
                  zs=zs.value).images[0]
      return img
 
-
-
  def reconstruct(tar_prompt,
                  image_caption,
                  tar_cfg_scale,
@@ -903,5 +836,4 @@ with gr.Blocks(css="style.css") as demo:
 
 
  demo.queue()
- demo.launch()
- # demo.launch(share=True)
+ demo.launch()
 
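For anyone who wants to reproduce the new setup outside the Space UI, here is a minimal sketch of the pipeline loading that app.py arrives at after this commit. It simply assembles the added lines from the diff into a self-contained snippet; it assumes the Space's local modules pipeline_semantic_stable_diffusion_img2img_solver.py and scheduling_dpmsolver_multistep_inject.py are importable from the working directory, and it uses stabilityai/sd-vae-ft-mse, a fine-tuned drop-in VAE for Stable Diffusion 1.5.

# Sketch only: mirrors the loading code added in this commit, not a new API.
import torch
from diffusers import AutoencoderKL
from pipeline_semantic_stable_diffusion_img2img_solver import SemanticStableDiffusionImg2ImgPipeline_DPMSolver
from scheduling_dpmsolver_multistep_inject import DPMSolverMultistepSchedulerInject

sd_model_id = "runwayml/stable-diffusion-v1-5"
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Load the fine-tuned MSE VAE and build the semantic img2img pipeline around it.
vae = AutoencoderKL.from_pretrained("stabilityai/sd-vae-ft-mse", torch_dtype=torch.float16)
pipe = SemanticStableDiffusionImg2ImgPipeline_DPMSolver.from_pretrained(
    sd_model_id, vae=vae, torch_dtype=torch.float16
).to(device)

# Inject the custom SDE DPM-Solver++ (order 2) scheduler, as in the updated app.py.
pipe.scheduler = DPMSolverMultistepSchedulerInject.from_pretrained(
    sd_model_id, subfolder="scheduler", algorithm_type="sde-dpmsolver++", solver_order=2
)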