sayanbanerjee32 committed
Commit e3ba844 · verified · 1 Parent(s): de07311

Upload folder using huggingface_hub

Files changed (4)
  1. app.py +58 -0
  2. hue_loss.py +41 -0
  3. requirements.txt +6 -0
  4. sd.py +238 -0
app.py ADDED
@@ -0,0 +1,58 @@
+ import gradio as gr
+ from sd import stl_list, img_size_opt_dict, generate_images
+
+ with gr.Blocks() as demo:
+     gr.HTML("<h1 align='center'>Stable Diffusion - Textual Inversion and additional guidance</h1>")
+     gr.HTML("<h4 align='center'>Generates images from the prompt in 5 different styles, then again with additional guidance from a hue loss</h4>")
+     gr.HTML("<h6 align='center'>!! Image generation may take 5 to 10 minutes on CPU !!</h6>")
+
+     with gr.Row():
+         content = gr.Textbox(label="Enter prompt text here")
+         gr.Examples([
+             "A mouse",
+             "A puppy"
+         ],
+             inputs=content)
+         num_steps = gr.Slider(1, 50, step=1, value=30, label="Number of inference steps", info="Choose between 1 and 50")
+         # gr.Number(value=10, label="Number of inference steps")
+
+
+     with gr.Row():
+         stl_dropdown = gr.Dropdown(
+             stl_list,
+             value=stl_list[:1], multiselect=True, label="Style",
+             info="Styles to be applied to the images"
+         )
+         size_dropdown = gr.Dropdown(
+             [*img_size_opt_dict],
+             value=[*img_size_opt_dict][-1],
+             label="Image size", info="Target size for generated images"
+         )
+
+     inputs = [
+         content,
+         num_steps,
+         stl_dropdown,
+         size_dropdown
+     ]
+
+     generate_btn = gr.Button(value='Generate')
+
+     with gr.Row():
+         with gr.Column(scale=2):
+             wo_add_guide = gr.Gallery(
+                 label="Without additional guidance", show_label=True, elem_id="gallery",
+                 columns=[3], rows=[2], object_fit="contain", height="auto")
+
+         with gr.Column(scale=2):
+             add_guide = gr.Gallery(
+                 label="With hue loss guidance", show_label=True, elem_id="gallery",
+                 columns=[3], rows=[2], object_fit="contain", height="auto")
+     outputs = [wo_add_guide, add_guide]
+     generate_btn.click(fn=generate_images, inputs=inputs, outputs=outputs)
+
+ # for Colab
+ # demo.launch(debug=True)
+
+ if __name__ == '__main__':
+     demo.launch()
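
Usage note (not part of the committed file): outside a Hugging Face Space, the same Blocks app can be launched with a temporary public link for quick testing. Both share and server_port are standard Gradio launch() options; the port value below is only an illustrative choice.

    # optional: expose a temporary public URL while testing locally
    demo.launch(share=True, server_port=7860)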
hue_loss.py ADDED
@@ -0,0 +1,41 @@
+ import torch
+
+ # hue loss: rgb_to_hsv expects an (N, 3, H, W) RGB tensor with values in [0, 1]
+ def rgb_to_hsv(image):
+     r, g, b = image[:, 0, :, :], image[:, 1, :, :], image[:, 2, :, :]
+     maxc = torch.max(image, dim=1)[0]
+     minc = torch.min(image, dim=1)[0]
+
+     v = maxc
+     s = (maxc - minc) / (maxc + 1e-10)
+     deltac = maxc - minc
+
+     # Initialize hue
+     h = torch.zeros_like(maxc)
+
+     mask = maxc == r
+     h[mask] = ((g - b) / (deltac + 1e-10))[mask] % 6  # epsilon avoids NaN/Inf where deltac == 0
+
+     mask = maxc == g
+     h[mask] = ((b - r) / (deltac + 1e-10))[mask] + 2
+
+     mask = maxc == b
+     h[mask] = ((r - g) / (deltac + 1e-10))[mask] + 4
+
+     h = h / 6  # Normalize to [0, 1]
+     h[deltac == 0] = 0  # If there is no color difference, set hue to 0
+
+     return torch.stack([h, s, v], dim=1)
+
+
+ def hue_loss(images, target_hue=0.5):
+     # Convert the images to HSV color space
+     hsv_images = rgb_to_hsv(images)
+
+     # Extract the hue channel
+     hue = hsv_images[:, 0, :, :]
+
+     # Calculate the error as the mean absolute deviation from the target hue
+     error = torch.abs(hue - target_hue).mean()
+
+     return error
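
For reference (not part of the commit), the loss can be sanity-checked on a random batch to confirm that gradients flow back to the pixels, which is exactly how sd.py uses it for guidance below. The tensor shape and target hue here are illustrative choices.

    import torch
    from hue_loss import hue_loss

    # dummy batch of 2 RGB images in [0, 1], shape (N, C, H, W)
    images = torch.rand(2, 3, 64, 64, requires_grad=True)

    loss = hue_loss(images, target_hue=0.5)
    loss.backward()  # gradient of the hue error w.r.t. the pixels

    print(loss.item(), images.grad.shape)  # scalar loss, gradient shaped like the images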
requirements.txt ADDED
@@ -0,0 +1,6 @@
+ pillow
+ torch
+ transformers
+ diffusers
+ huggingface_hub
+ numpy
sd.py ADDED
@@ -0,0 +1,238 @@
+ import numpy
+ import torch
+ from diffusers import AutoencoderKL, LMSDiscreteScheduler, UNet2DConditionModel
+
+ from PIL import Image
+ from tqdm.auto import tqdm
+ from transformers import CLIPTextModel, CLIPTokenizer, logging
+ import os
+
+ from hue_loss import hue_loss
+
+
+ torch.manual_seed(1)
+ # if not (Path.home()/'.cache/huggingface'/'token').exists(): notebook_login()
+
+ # Suppress some unnecessary warnings when loading the CLIPTextModel
+ logging.set_verbosity_error()
+
+ # Set device
+ torch_device = "cuda" if torch.cuda.is_available() else "cpu"
+
+ from huggingface_hub import hf_hub_download
+
+
+ stl_list = [
+     'birb-style',
+     'cute-game-style',
+     'depthmap',
+     'line-art',
+     'low-poly-hd-logos-icons'
+ ]
+
+ for stl in stl_list:
+     if not os.path.exists(stl):
+         os.mkdir(stl)
+         hf_hub_download(repo_id=f"sd-concepts-library/{stl}", filename="learned_embeds.bin", local_dir=f"./{stl}")
+
+ img_size_opt_dict = {
+     "512x512 - best quality but very slow": (512, 512),
+     "256x256 - lower quality but still slow": (256, 256),
+     "128x128 - poor quality but faster": (128, 128),
+ }
+
+ # Load the autoencoder model which will be used to decode the latents into image space.
+ vae = AutoencoderKL.from_pretrained("CompVis/stable-diffusion-v1-4", subfolder="vae")
+
+ # Load the tokenizer and text encoder to tokenize and encode the text.
+ tokenizer = CLIPTokenizer.from_pretrained("openai/clip-vit-large-patch14")
+ text_encoder = CLIPTextModel.from_pretrained("openai/clip-vit-large-patch14")
+
+ # The UNet model for generating the latents.
+ unet = UNet2DConditionModel.from_pretrained("CompVis/stable-diffusion-v1-4", subfolder="unet")
+
+ # The noise scheduler
+ scheduler = LMSDiscreteScheduler(beta_start=0.00085, beta_end=0.012, beta_schedule="scaled_linear", num_train_timesteps=1000)
+
+ # Move the models to the selected device
+ vae = vae.to(torch_device)
+ text_encoder = text_encoder.to(torch_device)
+ unet = unet.to(torch_device)
+
+ # Convert latents to images
+
+ def latents_to_pil(latents):
+     # batch of latents -> list of PIL images
+     latents = (1 / 0.18215) * latents
+     with torch.no_grad():
+         image = vae.decode(latents).sample
+     image = (image / 2 + 0.5).clamp(0, 1)
+     image = image.detach().cpu().permute(0, 2, 3, 1).numpy()
+     images = (image * 255).round().astype("uint8")
+     pil_images = [Image.fromarray(image) for image in images]
+     return pil_images
+
+ # Prep Scheduler
+ def set_timesteps(scheduler, num_inference_steps):
+     scheduler.set_timesteps(num_inference_steps)
+     scheduler.timesteps = scheduler.timesteps.to(torch.float32)  # minor fix to ensure MPS compatibility, fixed in diffusers PR 3925
+
+ # Generate an image from (possibly modified) text embeddings
+ def generate_with_embs(text_embeddings, text_input, loss_fn=None, loss_scale=200, guidance_scale=7.5,
+                        seed_value=1, num_inference_steps=50, additional_guidance=False, height_width=(512, 512)):
+     height, width = height_width  # target size of the generated image
+     # width = 512 # default width of Stable Diffusion
+     # num_inference_steps = 50 # Number of denoising steps
+     # guidance_scale is the scale for classifier-free guidance
+     generator = torch.manual_seed(seed_value)  # Seed generator to create the initial latent noise
+     batch_size = 1
+
+
+     max_length = text_input.input_ids.shape[-1]
+     uncond_input = tokenizer(
+         [""] * batch_size, padding="max_length", max_length=max_length, return_tensors="pt"
+     )
+     with torch.no_grad():
+         uncond_embeddings = text_encoder(uncond_input.input_ids.to(torch_device))[0]
+     text_embeddings = torch.cat([uncond_embeddings, text_embeddings])
+
+     # Prep Scheduler
+     set_timesteps(scheduler, num_inference_steps)
+
+     # Prep latents
+     latents = torch.randn(
+         (batch_size, unet.in_channels, height // 8, width // 8),
+         generator=generator,
+     )
+     latents = latents.to(torch_device)
+     latents = latents * scheduler.init_noise_sigma
+
+     # Denoising loop
+     for i, t in tqdm(enumerate(scheduler.timesteps), total=len(scheduler.timesteps)):
+         # expand the latents if we are doing classifier-free guidance to avoid doing two forward passes.
+         latent_model_input = torch.cat([latents] * 2)
+         sigma = scheduler.sigmas[i]
+         latent_model_input = scheduler.scale_model_input(latent_model_input, t)
+
+         # predict the noise residual
+         with torch.no_grad():
+             noise_pred = unet(latent_model_input, t, encoder_hidden_states=text_embeddings)["sample"]
+
+         # perform classifier-free guidance
+         noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)
+         noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond)
+
+         #### ADDITIONAL GUIDANCE ###
+         if i % 5 == 0 and additional_guidance:
+             # Require grad on the latents
+             latents = latents.detach().requires_grad_()
+
+             # Get the predicted x0:
+             latents_x0 = latents - sigma * noise_pred
+             # latents_x0 = scheduler.step(noise_pred, t, latents).pred_original_sample
+
+             # Decode to image space
+             denoised_images = vae.decode((1 / 0.18215) * latents_x0).sample / 2 + 0.5  # range (0, 1)
+
+             # Calculate loss
+             loss = loss_fn(denoised_images) * loss_scale
+
+             # Occasionally print it out
+             if i % 10 == 0:
+                 print(i, 'loss:', loss.item())
+
+             # Get gradient of the loss w.r.t. the latents
+             cond_grad = torch.autograd.grad(loss, latents)[0]
+
+             # Modify the latents based on this gradient
+             # latents = latents.detach() - cond_grad * sigma**2
+             latents = latents.detach() - cond_grad * sigma**2
+
+         # compute the previous noisy sample x_t -> x_t-1
+         latents = scheduler.step(noise_pred, t, latents).prev_sample
+
+         # Ensure the latents do not lose the grad tracking
+         # latents.requires_grad_()
+
+     return latents_to_pil(latents)[0]
+
+ def get_output_embeds(input_embeddings):
+     # CLIP's text model uses a causal mask, so we prepare it here:
+     bsz, seq_len = input_embeddings.shape[:2]
+     causal_attention_mask = text_encoder.text_model._build_causal_attention_mask(bsz, seq_len, dtype=input_embeddings.dtype)  # private helper available in older transformers releases
+
+     # Getting the output embeddings involves calling the model with output_hidden_states=True
+     # so that it doesn't just return the pooled final predictions:
+     encoder_outputs = text_encoder.text_model.encoder(
+         inputs_embeds=input_embeddings,
+         attention_mask=None,  # We aren't using an attention mask so that can be None
+         causal_attention_mask=causal_attention_mask.to(torch_device),
+         output_attentions=None,
+         output_hidden_states=True,  # We want the output embeddings, not the final pooled output
+         return_dict=None,
+     )
+
+     # We're interested in the output hidden state only
+     output = encoder_outputs[0]
+
+     # There is a final layer norm we need to pass these through
+     output = text_encoder.text_model.final_layer_norm(output)
+
+     # And now they're ready!
+     return output
+
+ # Access the embedding layers
+ token_emb_layer = text_encoder.text_model.embeddings.token_embedding
+ pos_emb_layer = text_encoder.text_model.embeddings.position_embedding
+
+ position_ids = text_encoder.text_model.embeddings.position_ids[:, :77]
+ position_embeddings = pos_emb_layer(position_ids)
+
+ def generate_images(prompt, num_inference_steps, stl_list, img_size):
+     # append a static phrase; its last word acts as a placeholder for the style embedding
+     prompt = prompt + ' in the style of puppy'
+     height_width = img_size_opt_dict[img_size]
+     # Tokenize
+     text_input = tokenizer(prompt, padding="max_length",
+                            max_length=tokenizer.model_max_length,
+                            truncation=True, return_tensors="pt")
+     input_ids = text_input.input_ids.to(torch_device)
+
+     # Get token embeddings
+     token_embeddings = token_emb_layer(input_ids)
+
+     wo_guide_lst = []
+     guide_lst = []
+     for i, stl in enumerate(stl_list):
+         stl_embed = torch.load(f'{stl}/learned_embeds.bin')
+
+         # The learned embedding for this style concept
+         replacement_token_embedding = stl_embed[f'<{stl}>'].to(torch_device)
+
+         # Insert this into the token embeddings in place of the placeholder token
+         token_embeddings[0, min(torch.where(input_ids[0] == tokenizer.eos_token_id)[0]) - 1] = replacement_token_embedding.to(torch_device)
+
+         # Combine with position embeddings
+         input_embeddings = token_embeddings + position_embeddings
+
+         # Feed through to get the final output embeddings
+         modified_output_embeddings = get_output_embeds(input_embeddings)
+
+         # And generate an image with this, first without additional guidance:
+         pil_im = generate_with_embs(modified_output_embeddings,
+                                     num_inference_steps=num_inference_steps,
+                                     text_input=text_input,
+                                     seed_value=i, additional_guidance=False,
+                                     height_width=height_width)
+         wo_guide_lst.append((pil_im, stl))
+
+         pil_im = generate_with_embs(modified_output_embeddings,
+                                     num_inference_steps=num_inference_steps,
+                                     text_input=text_input,
+                                     loss_fn=hue_loss,
+                                     additional_guidance=True,
+                                     height_width=height_width,
+                                     seed_value=i)
+         guide_lst.append((pil_im, stl))
+
+     return wo_guide_lst, guide_lst
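
For reference (not part of the committed files), this is roughly how the Gradio click handler in app.py drives the pipeline when called directly from Python; the prompt, step count, and size key are illustrative values picked from the options defined above, and the output filenames are arbitrary.

    from sd import stl_list, img_size_opt_dict, generate_images

    # one style, few steps, smallest size to keep a CPU run short
    without_guidance, with_guidance = generate_images(
        "A mouse",
        num_inference_steps=10,
        stl_list=stl_list[:1],
        img_size=[*img_size_opt_dict][-1],  # the 128x128 option
    )

    # each result is a list of (PIL image, style name) tuples
    without_guidance[0][0].save("mouse_birb_style.png")
    with_guidance[0][0].save("mouse_birb_style_hue_guided.png")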