Zai committed on
Commit
15e78ee
1 Parent(s): 447cd0b

project setup

Files changed (5)
  1. app.py +40 -0
  2. inpainting.py +227 -0
  3. requirements.txt +8 -0
  4. test.py +12 -0
  5. utils.py +13 -0
app.py ADDED
@@ -0,0 +1,40 @@
+ import gradio as gr
+ from io import BytesIO
+
+ from torch import autocast
+ import requests
+ import PIL
+ import torch
+ from diffusers import StableDiffusionInpaintPipeline
+
+ pipe = StableDiffusionInpaintPipeline.from_pretrained(
+     "CompVis/stable-diffusion-v1-4",
+     revision="fp16",
+     torch_dtype=torch.float16,
+     use_auth_token=True,
+ )
+
+
+ def process_image(img_dict, prompt):
+     init_img = img_dict["image"].convert("RGB").resize((512, 512))
+     mask_img = img_dict["mask"].convert("RGB").resize((512, 512))
+     images = pipe(
+         prompt=prompt, init_image=init_img, mask_image=mask_img, strength=0.75
+     )["sample"]
+     return images[0]
+
+
+ iface = gr.Interface(
+     fn=process_image,
+     title="Stable Diffusion In-Painting Tool on Colab with Gradio",
+     inputs=[
+         gr.Image(source="upload", tool="sketch", type="pil"),
+         gr.Textbox(label="prompt"),
+     ],
+     outputs=[gr.Image()],
+     description="Choose a feature and upload an image to see the processed result.",
+     article="<p style='text-align: center;'>Built with Gradio</p>",
+ )
+
+
+ iface.launch()
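Note: the pipe(...) call in app.py targets the older diffusers inpainting API (the init_image= keyword and the "sample" output key). On recent diffusers releases the equivalent call would look roughly like the sketch below. This is a hedged sketch only, not part of this commit; it assumes a current diffusers install and the runwayml/stable-diffusion-inpainting checkpoint.

    # sketch for newer diffusers versions (assumed API; not part of this commit)
    import torch
    from diffusers import StableDiffusionInpaintPipeline

    pipe = StableDiffusionInpaintPipeline.from_pretrained(
        "runwayml/stable-diffusion-inpainting", torch_dtype=torch.float16
    ).to("cuda")

    def process_image(img_dict, prompt):
        init_img = img_dict["image"].convert("RGB").resize((512, 512))
        mask_img = img_dict["mask"].convert("RGB").resize((512, 512))
        # newer API: image= / mask_image= keywords, and the output exposes an .images list
        return pipe(prompt=prompt, image=init_img, mask_image=mask_img).images[0]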
inpainting.py ADDED
@@ -0,0 +1,227 @@
+ # credit : Hugging Face Team
+ import inspect
+ from typing import List, Optional, Union
+
+ import numpy as np
+ import torch
+
+ import PIL
+ from diffusers import (
+     AutoencoderKL,
+     DDIMScheduler,
+     DiffusionPipeline,
+     PNDMScheduler,
+     UNet2DConditionModel,
+ )
+ from diffusers.pipelines.stable_diffusion import StableDiffusionSafetyChecker
+ from tqdm.auto import tqdm
+ from transformers import CLIPFeatureExtractor, CLIPTextModel, CLIPTokenizer
+
+
+ def preprocess_image(image):
+     w, h = image.size
+     w, h = map(lambda x: x - x % 32, (w, h))  # resize to integer multiple of 32
+     image = image.resize((w, h), resample=PIL.Image.LANCZOS)
+     image = np.array(image).astype(np.float32) / 255.0
+     image = image[None].transpose(0, 3, 1, 2)
+     image = torch.from_numpy(image)
+     return 2.0 * image - 1.0
+
+
+ def preprocess_mask(mask):
+     mask = mask.convert("L")
+     w, h = mask.size
+     w, h = map(lambda x: x - x % 32, (w, h))  # resize to integer multiple of 32
+     mask = mask.resize((w // 8, h // 8), resample=PIL.Image.NEAREST)
+     mask = np.array(mask).astype(np.float32) / 255.0
+     mask = np.tile(mask, (4, 1, 1))
+     mask = mask[None].transpose(0, 1, 2, 3)  # add a batch dimension (the transpose itself is a no-op)
+     mask = 1 - mask  # repaint white, keep black
+     mask = torch.from_numpy(mask)
+     return mask
+
+
+ class StableDiffusionInpaintingPipeline(DiffusionPipeline):
+     def __init__(
+         self,
+         vae: AutoencoderKL,
+         text_encoder: CLIPTextModel,
+         tokenizer: CLIPTokenizer,
+         unet: UNet2DConditionModel,
+         scheduler: Union[DDIMScheduler, PNDMScheduler],
+         safety_checker: StableDiffusionSafetyChecker,
+         feature_extractor: CLIPFeatureExtractor,
+     ):
+         super().__init__()
+         scheduler = scheduler.set_format("pt")
+         self.register_modules(
+             vae=vae,
+             text_encoder=text_encoder,
+             tokenizer=tokenizer,
+             unet=unet,
+             scheduler=scheduler,
+             safety_checker=safety_checker,
+             feature_extractor=feature_extractor,
+         )
+
+     @torch.no_grad()
+     def __call__(
+         self,
+         prompt: Union[str, List[str]],
+         init_image: torch.FloatTensor,
+         mask_image: torch.FloatTensor,
+         strength: float = 0.8,
+         num_inference_steps: Optional[int] = 50,
+         guidance_scale: Optional[float] = 7.5,
+         eta: Optional[float] = 0.0,
+         generator: Optional[torch.Generator] = None,
+         output_type: Optional[str] = "pil",
+     ):
+         if isinstance(prompt, str):
+             batch_size = 1
+         elif isinstance(prompt, list):
+             batch_size = len(prompt)
+         else:
+             raise ValueError(
+                 f"`prompt` has to be of type `str` or `list` but is {type(prompt)}"
+             )
+
+         if strength < 0 or strength > 1:
+             raise ValueError(
+                 f"The value of strength should be in [0.0, 1.0] but is {strength}"
+             )
+
+         # set timesteps
+         accepts_offset = "offset" in set(
+             inspect.signature(self.scheduler.set_timesteps).parameters.keys()
+         )
+         extra_set_kwargs = {}
+         offset = 0
+         if accepts_offset:
+             offset = 1
+             extra_set_kwargs["offset"] = 1
+
+         self.scheduler.set_timesteps(num_inference_steps, **extra_set_kwargs)
+
+         # preprocess image
+         init_image = preprocess_image(init_image).to(self.device)
+
+         # encode the init image into latents and scale the latents
+         init_latents = self.vae.encode(init_image).sample()
+         init_latents = 0.18215 * init_latents
+
+         # prepare init_latents noise to latents
+         init_latents = torch.cat([init_latents] * batch_size)
+         init_latents_orig = init_latents
+
+         # preprocess mask
+         mask = preprocess_mask(mask_image).to(self.device)
+         mask = torch.cat([mask] * batch_size)
+
+         # check sizes
+         if not mask.shape == init_latents.shape:
+             raise ValueError("The mask and init_image should be the same size!")
+
+         # get the original timestep using init_timestep
+         init_timestep = int(num_inference_steps * strength) + offset
+         init_timestep = min(init_timestep, num_inference_steps)
+         timesteps = self.scheduler.timesteps[-init_timestep]
+         timesteps = torch.tensor(
+             [timesteps] * batch_size, dtype=torch.long, device=self.device
+         )
+
+         # add noise to latents using the timesteps
+         noise = torch.randn(init_latents.shape, generator=generator, device=self.device)
+         init_latents = self.scheduler.add_noise(init_latents, noise, timesteps)
+
+         # get prompt text embeddings
+         text_input = self.tokenizer(
+             prompt,
+             padding="max_length",
+             max_length=self.tokenizer.model_max_length,
+             truncation=True,
+             return_tensors="pt",
+         )
+         text_embeddings = self.text_encoder(text_input.input_ids.to(self.device))[0]
+
+         # here `guidance_scale` is defined analogously to the guidance weight `w` of equation (2)
+         # of the Imagen paper: https://arxiv.org/pdf/2205.11487.pdf . `guidance_scale = 1`
+         # corresponds to doing no classifier-free guidance.
+         do_classifier_free_guidance = guidance_scale > 1.0
+         # get unconditional embeddings for classifier-free guidance
+         if do_classifier_free_guidance:
+             max_length = text_input.input_ids.shape[-1]
+             uncond_input = self.tokenizer(
+                 [""] * batch_size,
+                 padding="max_length",
+                 max_length=max_length,
+                 return_tensors="pt",
+             )
+             uncond_embeddings = self.text_encoder(
+                 uncond_input.input_ids.to(self.device)
+             )[0]
+
+             # For classifier-free guidance, we need to do two forward passes.
+             # Here we concatenate the unconditional and text embeddings into a single batch
+             # to avoid doing two forward passes.
+             text_embeddings = torch.cat([uncond_embeddings, text_embeddings])
+
+         # prepare extra kwargs for the scheduler step, since not all schedulers have the same signature
+         # eta (η) is only used with the DDIMScheduler; it will be ignored for other schedulers.
+         # eta corresponds to η in the DDIM paper: https://arxiv.org/abs/2010.02502
+         # and should be in [0, 1]
+         accepts_eta = "eta" in set(
+             inspect.signature(self.scheduler.step).parameters.keys()
+         )
+         extra_step_kwargs = {}
+         if accepts_eta:
+             extra_step_kwargs["eta"] = eta
+
+         latents = init_latents
+         t_start = max(num_inference_steps - init_timestep + offset, 0)
+         for i, t in tqdm(enumerate(self.scheduler.timesteps[t_start:])):
+             # expand the latents if we are doing classifier-free guidance
+             latent_model_input = (
+                 torch.cat([latents] * 2) if do_classifier_free_guidance else latents
+             )
+
+             # predict the noise residual
+             noise_pred = self.unet(
+                 latent_model_input, t, encoder_hidden_states=text_embeddings
+             )["sample"]
+
+             # perform guidance
+             if do_classifier_free_guidance:
+                 noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)
+                 noise_pred = noise_pred_uncond + guidance_scale * (
+                     noise_pred_text - noise_pred_uncond
+                 )
+
+             # compute the previous noisy sample x_t -> x_t-1
+             latents = self.scheduler.step(noise_pred, t, latents, **extra_step_kwargs)[
+                 "prev_sample"
+             ]
+
+             # masking: keep the re-noised original latents wherever the mask marks "do not repaint"
+             init_latents_proper = self.scheduler.add_noise(init_latents_orig, noise, t)
+             latents = (init_latents_proper * mask) + (latents * (1 - mask))
+
+         # scale and decode the image latents with the vae
+         latents = 1 / 0.18215 * latents
+         image = self.vae.decode(latents)
+
+         image = (image / 2 + 0.5).clamp(0, 1)
+         image = image.cpu().permute(0, 2, 3, 1).numpy()
+
+         # run safety checker
+         safety_checker_input = self.feature_extractor(
+             self.numpy_to_pil(image), return_tensors="pt"
+         ).to(self.device)
+         image, has_nsfw_concept = self.safety_checker(
+             images=image, clip_input=safety_checker_input.pixel_values
+         )
+
+         if output_type == "pil":
+             image = self.numpy_to_pil(image)
+
+         return {"sample": image, "nsfw_content_detected": has_nsfw_concept}
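The custom pipeline above is loaded the same way app.py loads the stock one. Below is a minimal usage sketch, not part of this commit: it assumes the 0.3.x-era diffusers API this file targets, a valid Hugging Face token, and hypothetical photo.png / mask.png inputs (white mask regions are repainted, per preprocess_mask).

    # illustrative usage sketch for StableDiffusionInpaintingPipeline (not part of this commit)
    import torch
    from torch import autocast
    from PIL import Image
    from inpainting import StableDiffusionInpaintingPipeline

    pipe = StableDiffusionInpaintingPipeline.from_pretrained(
        "CompVis/stable-diffusion-v1-4",
        revision="fp16",
        torch_dtype=torch.float16,
        use_auth_token=True,
    ).to("cuda")

    init_img = Image.open("photo.png").convert("RGB").resize((512, 512))   # hypothetical file
    mask_img = Image.open("mask.png").convert("L").resize((512, 512))      # hypothetical file

    # autocast keeps the fp32 preprocessed tensors compatible with the fp16 weights
    with autocast("cuda"):
        result = pipe(
            prompt="a vase of flowers on the table",  # hypothetical prompt
            init_image=init_img,
            mask_image=mask_img,
            strength=0.75,
        )["sample"][0]
    result.save("out.png")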
requirements.txt ADDED
@@ -0,0 +1,8 @@
+ torch
+ requests
+ Pillow
+ diffusers
+ transformers
+ gradio
+ numpy
+ tqdm
test.py ADDED
@@ -0,0 +1,12 @@
+ import gradio as gr
+
+ def greet(name, intensity):
+     return "Hello " * int(intensity) + name + "!"
+
+ demo = gr.Interface(
+     fn=greet,
+     inputs=["text", "slider"],
+     outputs=["text"],
+ )
+
+ demo.launch()
utils.py ADDED
@@ -0,0 +1,13 @@
+ def add_feature(image):
+     # TODO: add an element to the image via inpainting
+     pass
+
+
+ def remove_feature(image):
+     # TODO: remove an element from the image via inpainting
+     pass
+
+
+ def enhance_feature(image):
+     # TODO: enhance part of the image via inpainting
+     pass