reibs committed
Commit e96ef91
Parent: 1903533

Update demo.ipynb

Files changed (1): demo.ipynb (+644, -0)

demo.ipynb
# Textual-inversion fine-tuning for Stable Diffusion using d🧨ffusers

This notebook shows how to "teach" Stable Diffusion a new concept via textual inversion using the 🤗 Hugging Face [🧨 Diffusers library](https://github.com/huggingface/diffusers).

![Textual Inversion example](https://textual-inversion.github.io/static/images/editing/colorful_teapot.JPG)
_By using just 3-5 images you can teach new concepts to Stable Diffusion and personalize the model on your own images_

For a general introduction to the Stable Diffusion model please refer to this [colab](https://colab.research.google.com/github/huggingface/notebooks/blob/main/diffusers/stable_diffusion.ipynb).


## Initial setup

#@title Install the required libs
!pip install -U -qq git+https://github.com/huggingface/diffusers.git
!pip install -qq accelerate transformers ftfy

#@title [Optional] Install xformers for faster and memory-efficient training
#@markdown Acknowledgement: The xformers wheels are taken from [TheLastBen/fast-stable-diffusion](https://github.com/TheLastBen/fast-stable-diffusion). Thanks a lot for building these wheels!
%%time

!pip install -U --pre triton

from subprocess import getoutput
from IPython.display import HTML
from IPython.display import clear_output
import time

s = getoutput('nvidia-smi')
if 'T4' in s:
    gpu = 'T4'
elif 'P100' in s:
    gpu = 'P100'
elif 'V100' in s:
    gpu = 'V100'
elif 'A100' in s:
    gpu = 'A100'

while True:
    try:
        gpu == 'T4' or gpu == 'P100' or gpu == 'V100' or gpu == 'A100'
        break
    except NameError:
        print('\033[1;31mIt seems that your GPU is not supported at the moment')
        time.sleep(5)

if gpu == 'T4':
    %pip install -q https://github.com/TheLastBen/fast-stable-diffusion/raw/main/precompiled/T4/xformers-0.0.13.dev0-py3-none-any.whl
elif gpu == 'P100':
    %pip install -q https://github.com/TheLastBen/fast-stable-diffusion/raw/main/precompiled/P100/xformers-0.0.13.dev0-py3-none-any.whl
elif gpu == 'V100':
    %pip install -q https://github.com/TheLastBen/fast-stable-diffusion/raw/main/precompiled/V100/xformers-0.0.13.dev0-py3-none-any.whl
elif gpu == 'A100':
    %pip install -q https://github.com/TheLastBen/fast-stable-diffusion/raw/main/precompiled/A100/xformers-0.0.13.dev0-py3-none-any.whl

#@title [Optional] Login to the Hugging Face Hub
#@markdown Add a token with the "Write Access" role to be able to add your trained concept to the [Library of Concepts](https://huggingface.co/sd-concepts-library)
from huggingface_hub import notebook_login

notebook_login()

#@title Import required libraries
import argparse
import itertools
import math
import os
import random

import numpy as np
import torch
import torch.nn.functional as F
import torch.utils.checkpoint
from torch.utils.data import Dataset

import PIL
from accelerate import Accelerator
from accelerate.logging import get_logger
from accelerate.utils import set_seed
from diffusers import AutoencoderKL, DDPMScheduler, PNDMScheduler, StableDiffusionPipeline, UNet2DConditionModel
from diffusers.optimization import get_scheduler
from diffusers.pipelines.stable_diffusion import StableDiffusionSafetyChecker
from PIL import Image
from torchvision import transforms
from tqdm.auto import tqdm
from transformers import CLIPFeatureExtractor, CLIPTextModel, CLIPTokenizer

def image_grid(imgs, rows, cols):
    assert len(imgs) == rows * cols

    w, h = imgs[0].size
    grid = Image.new('RGB', size=(cols * w, rows * h))
    grid_w, grid_h = grid.size

    for i, img in enumerate(imgs):
        grid.paste(img, box=(i % cols * w, i // cols * h))
    return grid
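`image_grid` simply pastes the images row by row onto a single canvas. A quick illustrative sketch (not part of the original notebook) using solid-color placeholder tiles:

# Hypothetical demo: lay out four solid-color 64x64 tiles in a 2x2 grid
demo_tiles = [Image.new("RGB", (64, 64), color) for color in ("red", "green", "blue", "white")]
demo_grid = image_grid(demo_tiles, rows=2, cols=2)
print(demo_grid.size)  # (128, 128)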
## Settings for teaching your new concept

#@markdown `pretrained_model_name_or_path` which Stable Diffusion checkpoint you want to use
pretrained_model_name_or_path = "stabilityai/stable-diffusion-2" #@param ["stabilityai/stable-diffusion-2", "stabilityai/stable-diffusion-2-base", "CompVis/stable-diffusion-v1-4", "runwayml/stable-diffusion-v1-5"] {allow-input: true}

### Get the training images:

#### Download the images from the internet and save them locally.

You can also upload the images to Colab or load them from Google Drive; please check the next section if you want to do that.

#@markdown Add here the URLs to the images of the concept you are adding. 3-5 should be fine
urls = [
    "https://huggingface.co/datasets/valhalla/images/resolve/main/2.jpeg",
    "https://huggingface.co/datasets/valhalla/images/resolve/main/3.jpeg",
    "https://huggingface.co/datasets/valhalla/images/resolve/main/5.jpeg",
    "https://huggingface.co/datasets/valhalla/images/resolve/main/6.jpeg",
    ## You can add additional images here
]

#@title Download
import requests
import glob
from io import BytesIO

def download_image(url):
    try:
        response = requests.get(url)
    except requests.exceptions.RequestException:
        return None
    return Image.open(BytesIO(response.content)).convert("RGB")

images = list(filter(None, [download_image(url) for url in urls]))
save_path = "./my_concept"
if not os.path.exists(save_path):
    os.mkdir(save_path)
[image.save(f"{save_path}/{i}.jpeg") for i, image in enumerate(images)]

#### Load images from local folder or google drive

You can also load your own training images from Google Drive, or upload them to Colab using the Files tab, and then provide the path to the directory containing the images.

*Make sure that the directory only contains images, as the following cells will read all the files from the provided directory.*

from google.colab import drive
drive.mount('/content/gdrive')

#@markdown `images_path` is the path to the directory containing the training images (it can also point to a folder in your mounted Google Drive).
images_path = "/content/sample_data/food" #@param {type:"string"}
while not os.path.exists(str(images_path)):
    print('The images_path specified does not exist, use the colab file explorer to copy the path:')
    images_path = input("")
save_path = images_path

#### Setup and check the images you have just added

images = []
for file_path in os.listdir(save_path):
    try:
        image_path = os.path.join(save_path, file_path)
        images.append(Image.open(image_path).resize((512, 512)))
    except Exception:
        print(f"{image_path} is not a valid image; please remove this file from the directory, otherwise the training could fail.")
image_grid(images, 1, len(images))

#@title Settings for your newly created concept
#@markdown `what_to_teach`: what is it that you are teaching? `object` teaches the model a new object, `style` teaches it a new artistic style.
what_to_teach = "object" #@param ["object", "style"]
#@markdown `placeholder_token` is the token you are going to use to represent your new concept (so when you prompt the model, you will say "A `<my-placeholder-token>` in an amusement park"). We use angle brackets to differentiate a token from other words/tokens, to avoid collision.
placeholder_token = "<japanese-oysters>" #@param {type:"string"}
#@markdown `initializer_token` is a word that can summarise what your new concept is, to be used as a starting point
initializer_token = "food" #@param {type:"string"}

## Teach the model a new concept (fine-tuning with textual inversion)

Execute this sequence of cells to run the training process. The whole process may take 1-4 hours. (Open this block if you are interested in how this process works under the hood, or if you want to change advanced training settings or hyperparameters.)

### Create Dataset

#@title Setup the prompt templates for training
imagenet_templates_small = [
    "a photo of a {}",
    "a rendering of a {}",
    "a cropped photo of the {}",
    "the photo of a {}",
    "a photo of a clean {}",
    "a photo of a dirty {}",
    "a dark photo of the {}",
    "a photo of my {}",
    "a photo of the cool {}",
    "a close-up photo of a {}",
    "a bright photo of the {}",
    "a cropped photo of a {}",
    "a photo of the {}",
    "a good photo of the {}",
    "a photo of one {}",
    "a close-up photo of the {}",
    "a rendition of the {}",
    "a photo of the clean {}",
    "a rendition of a {}",
    "a photo of a nice {}",
    "a good photo of a {}",
    "a photo of the nice {}",
    "a photo of the small {}",
    "a photo of the weird {}",
    "a photo of the large {}",
    "a photo of a cool {}",
    "a photo of a small {}",
]

imagenet_style_templates_small = [
    "a painting in the style of {}",
    "a rendering in the style of {}",
    "a cropped painting in the style of {}",
    "the painting in the style of {}",
    "a clean painting in the style of {}",
    "a dirty painting in the style of {}",
    "a dark painting in the style of {}",
    "a picture in the style of {}",
    "a cool painting in the style of {}",
    "a close-up painting in the style of {}",
    "a bright painting in the style of {}",
    "a cropped painting in the style of {}",
    "a good painting in the style of {}",
    "a close-up painting in the style of {}",
    "a rendition in the style of {}",
    "a nice painting in the style of {}",
    "a small painting in the style of {}",
    "a weird painting in the style of {}",
    "a large painting in the style of {}",
]
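During training, one of these templates is sampled at random and the `{}` placeholder is filled with your `placeholder_token`, so every caption mentions the new token. A quick sketch (illustrative only, not part of the original notebook) of what a training caption looks like:

# Hypothetical preview of a training caption built from the templates above
templates = imagenet_style_templates_small if what_to_teach == "style" else imagenet_templates_small
print(random.choice(templates).format(placeholder_token))
# e.g. "a close-up photo of a <japanese-oysters>"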
#@title Setup the dataset
class TextualInversionDataset(Dataset):
    def __init__(
        self,
        data_root,
        tokenizer,
        learnable_property="object",  # [object, style]
        size=512,
        repeats=100,
        interpolation="bicubic",
        flip_p=0.5,
        set="train",
        placeholder_token="*",
        center_crop=False,
    ):
        self.data_root = data_root
        self.tokenizer = tokenizer
        self.learnable_property = learnable_property
        self.size = size
        self.placeholder_token = placeholder_token
        self.center_crop = center_crop
        self.flip_p = flip_p

        self.image_paths = [os.path.join(self.data_root, file_path) for file_path in os.listdir(self.data_root)]

        self.num_images = len(self.image_paths)
        self._length = self.num_images

        if set == "train":
            self._length = self.num_images * repeats

        self.interpolation = {
            "linear": PIL.Image.LINEAR,
            "bilinear": PIL.Image.BILINEAR,
            "bicubic": PIL.Image.BICUBIC,
            "lanczos": PIL.Image.LANCZOS,
        }[interpolation]

        self.templates = imagenet_style_templates_small if learnable_property == "style" else imagenet_templates_small
        self.flip_transform = transforms.RandomHorizontalFlip(p=self.flip_p)

    def __len__(self):
        return self._length

    def __getitem__(self, i):
        example = {}
        image = Image.open(self.image_paths[i % self.num_images])

        if not image.mode == "RGB":
            image = image.convert("RGB")

        placeholder_string = self.placeholder_token
        text = random.choice(self.templates).format(placeholder_string)

        example["input_ids"] = self.tokenizer(
            text,
            padding="max_length",
            truncation=True,
            max_length=self.tokenizer.model_max_length,
            return_tensors="pt",
        ).input_ids[0]

        # default to score-sde preprocessing
        img = np.array(image).astype(np.uint8)

        if self.center_crop:
            crop = min(img.shape[0], img.shape[1])
            h, w = img.shape[0], img.shape[1]
            img = img[(h - crop) // 2 : (h + crop) // 2, (w - crop) // 2 : (w + crop) // 2]

        image = Image.fromarray(img)
        image = image.resize((self.size, self.size), resample=self.interpolation)

        image = self.flip_transform(image)
        image = np.array(image).astype(np.uint8)
        image = (image / 127.5 - 1.0).astype(np.float32)

        example["pixel_values"] = torch.from_numpy(image).permute(2, 0, 1)
        return example

### Setting up the model

#@title Load the tokenizer and add the placeholder token as an additional special token.
tokenizer = CLIPTokenizer.from_pretrained(
    pretrained_model_name_or_path,
    subfolder="tokenizer",
)

# Add the placeholder token in tokenizer
num_added_tokens = tokenizer.add_tokens(placeholder_token)
if num_added_tokens == 0:
    raise ValueError(
        f"The tokenizer already contains the token {placeholder_token}. Please pass a different"
        " `placeholder_token` that is not already in the tokenizer."
    )

#@title Get token ids for our placeholder and initializer token. This code block will complain if the initializer string is not a single token
# Convert the initializer_token, placeholder_token to ids
token_ids = tokenizer.encode(initializer_token, add_special_tokens=False)
# Check if initializer_token is a single token or a sequence of tokens
if len(token_ids) > 1:
    raise ValueError("The initializer token must be a single token.")

initializer_token_id = token_ids[0]
placeholder_token_id = tokenizer.convert_tokens_to_ids(placeholder_token)
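As a quick sanity check (a minimal sketch, not part of the original notebook), you can print both ids and confirm that the placeholder id maps back to the exact token you registered:

# Hypothetical sanity check before training
print("initializer_token_id:", initializer_token_id)
print("placeholder_token_id:", placeholder_token_id)
# The placeholder id should decode back to the string we added with add_tokens
assert tokenizer.convert_ids_to_tokens(placeholder_token_id) == placeholder_token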
#@title Load the Stable Diffusion model
# Load models and create wrapper for stable diffusion
# pipeline = StableDiffusionPipeline.from_pretrained(pretrained_model_name_or_path)
# del pipeline
text_encoder = CLIPTextModel.from_pretrained(
    pretrained_model_name_or_path, subfolder="text_encoder"
)
vae = AutoencoderKL.from_pretrained(
    pretrained_model_name_or_path, subfolder="vae"
)
unet = UNet2DConditionModel.from_pretrained(
    pretrained_model_name_or_path, subfolder="unet"
)

We have added the `placeholder_token` to the `tokenizer`, so we resize the token embeddings here; this adds a new embedding vector in the token embeddings for our `placeholder_token`.

text_encoder.resize_token_embeddings(len(tokenizer))

Initialise the newly added placeholder token with the embedding of the initializer token.

token_embeds = text_encoder.get_input_embeddings().weight.data
token_embeds[placeholder_token_id] = token_embeds[initializer_token_id]
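To double-check the initialisation (a small sketch, not in the original notebook), you can verify that the new row of the embedding matrix now equals the initializer's row:

# Hypothetical check: the placeholder embedding should start out identical to the initializer embedding
assert torch.equal(token_embeds[placeholder_token_id], token_embeds[initializer_token_id])
print("placeholder embedding shape:", token_embeds[placeholder_token_id].shape)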
In textual inversion we only train the newly added embedding vector, so let's freeze the rest of the model parameters here.

def freeze_params(params):
    for param in params:
        param.requires_grad = False

# Freeze vae and unet
freeze_params(vae.parameters())
freeze_params(unet.parameters())
# Freeze all parameters except for the token embeddings in text encoder
params_to_freeze = itertools.chain(
    text_encoder.text_model.encoder.parameters(),
    text_encoder.text_model.final_layer_norm.parameters(),
    text_encoder.text_model.embeddings.position_embedding.parameters(),
)
freeze_params(params_to_freeze)
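To confirm that only the token embedding table of the text encoder remains trainable, here is a small sketch (not part of the original notebook) that lists the parameters which still require gradients:

# Hypothetical check: only the token embedding matrix should still require grad
trainable = [(name, p.numel()) for name, p in text_encoder.named_parameters() if p.requires_grad]
for name, numel in trainable:
    print(f"trainable: {name} ({numel} parameters)")
print("total trainable parameters:", sum(numel for _, numel in trainable))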
### Creating our training data

Let's create the Dataset and DataLoader.

train_dataset = TextualInversionDataset(
    data_root=save_path,
    tokenizer=tokenizer,
    size=vae.sample_size,
    placeholder_token=placeholder_token,
    repeats=100,
    learnable_property=what_to_teach,  # Option selected above between object and style
    center_crop=False,
    set="train",
)

def create_dataloader(train_batch_size=1):
    return torch.utils.data.DataLoader(train_dataset, batch_size=train_batch_size, shuffle=True)
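If you want to see exactly what the model will be fed, the following sketch (an illustrative addition, not in the original notebook) pulls a single batch and prints its shapes:

# Hypothetical inspection: one batch contains the tokenized prompt and the normalized image tensor
sample_batch = next(iter(create_dataloader(train_batch_size=2)))
print("input_ids:", sample_batch["input_ids"].shape)        # (batch, sequence length)
print("pixel_values:", sample_batch["pixel_values"].shape)  # (batch, 3, size, size)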
Create the noise scheduler for training.

noise_scheduler = DDPMScheduler.from_config(pretrained_model_name_or_path, subfolder="scheduler")
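The scheduler implements the forward diffusion process used during training: given clean latents, it adds noise according to a randomly sampled timestep. A minimal sketch of that step (illustrative only, using dummy tensors rather than real image latents):

# Hypothetical illustration of the forward diffusion step with dummy latents
dummy_latents = torch.randn(1, 4, 64, 64)
dummy_noise = torch.randn_like(dummy_latents)
dummy_timesteps = torch.randint(0, noise_scheduler.num_train_timesteps, (1,)).long()
dummy_noisy = noise_scheduler.add_noise(dummy_latents, dummy_noise, dummy_timesteps)
print("noisy latents shape:", dummy_noisy.shape)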
### Training

Define the hyperparameters for our training.

If you are not happy with your results, you can tune the `learning_rate` and the `max_train_steps`.

#@title Setting up all training args
hyperparameters = {
    "learning_rate": 5e-04,
    "scale_lr": True,
    "max_train_steps": 2000,
    "save_steps": 250,
    "train_batch_size": 4,
    "gradient_accumulation_steps": 1,
    "gradient_checkpointing": True,
    "mixed_precision": "fp16",
    "seed": 42,
    "output_dir": "sd-concept-output"
}
!mkdir -p sd-concept-output
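Because `scale_lr` is enabled, the training function below multiplies the learning rate by the effective batch size. A small worked sketch of that arithmetic (illustrative only; `num_processes` is assumed to be 1 on a single-GPU Colab runtime):

# Hypothetical calculation of the scaled learning rate used when scale_lr is True
base_lr = hyperparameters["learning_rate"]               # 5e-4
effective_lr = (
    base_lr
    * hyperparameters["gradient_accumulation_steps"]     # 1
    * hyperparameters["train_batch_size"]                # 4
    * 1                                                  # num_processes (assumed single GPU)
)
print("effective learning rate:", effective_lr)          # 2e-3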
Train!

#@title Training function
logger = get_logger(__name__)

def save_progress(text_encoder, placeholder_token_id, accelerator, save_path):
    logger.info("Saving embeddings")
    learned_embeds = accelerator.unwrap_model(text_encoder).get_input_embeddings().weight[placeholder_token_id]
    learned_embeds_dict = {placeholder_token: learned_embeds.detach().cpu()}
    torch.save(learned_embeds_dict, save_path)

def training_function(text_encoder, vae, unet):
    train_batch_size = hyperparameters["train_batch_size"]
    gradient_accumulation_steps = hyperparameters["gradient_accumulation_steps"]
    learning_rate = hyperparameters["learning_rate"]
    max_train_steps = hyperparameters["max_train_steps"]
    output_dir = hyperparameters["output_dir"]
    gradient_checkpointing = hyperparameters["gradient_checkpointing"]

    accelerator = Accelerator(
        gradient_accumulation_steps=gradient_accumulation_steps,
        mixed_precision=hyperparameters["mixed_precision"]
    )

    if gradient_checkpointing:
        text_encoder.gradient_checkpointing_enable()
        unet.enable_gradient_checkpointing()

    train_dataloader = create_dataloader(train_batch_size)

    if hyperparameters["scale_lr"]:
        learning_rate = (
            learning_rate * gradient_accumulation_steps * train_batch_size * accelerator.num_processes
        )

    # Initialize the optimizer
    optimizer = torch.optim.AdamW(
        text_encoder.get_input_embeddings().parameters(),  # only optimize the embeddings
        lr=learning_rate,
    )

    text_encoder, optimizer, train_dataloader = accelerator.prepare(
        text_encoder, optimizer, train_dataloader
    )

    weight_dtype = torch.float32
    if accelerator.mixed_precision == "fp16":
        weight_dtype = torch.float16
    elif accelerator.mixed_precision == "bf16":
        weight_dtype = torch.bfloat16

    # Move vae and unet to device
    vae.to(accelerator.device, dtype=weight_dtype)
    unet.to(accelerator.device, dtype=weight_dtype)

    # Keep vae in eval mode as we don't train it
    vae.eval()
    # Keep unet in train mode to enable gradient checkpointing
    unet.train()

    # We need to recalculate our total training steps as the size of the training dataloader may have changed.
    num_update_steps_per_epoch = math.ceil(len(train_dataloader) / gradient_accumulation_steps)
    num_train_epochs = math.ceil(max_train_steps / num_update_steps_per_epoch)

    # Train!
    total_batch_size = train_batch_size * accelerator.num_processes * gradient_accumulation_steps

    logger.info("***** Running training *****")
    logger.info(f" Num examples = {len(train_dataset)}")
    logger.info(f" Instantaneous batch size per device = {train_batch_size}")
    logger.info(f" Total train batch size (w. parallel, distributed & accumulation) = {total_batch_size}")
    logger.info(f" Gradient Accumulation steps = {gradient_accumulation_steps}")
    logger.info(f" Total optimization steps = {max_train_steps}")
    # Only show the progress bar once on each machine.
    progress_bar = tqdm(range(max_train_steps), disable=not accelerator.is_local_main_process)
    progress_bar.set_description("Steps")
    global_step = 0

    for epoch in range(num_train_epochs):
        text_encoder.train()
        for step, batch in enumerate(train_dataloader):
            with accelerator.accumulate(text_encoder):
                # Convert images to latent space
                latents = vae.encode(batch["pixel_values"].to(dtype=weight_dtype)).latent_dist.sample().detach()
                latents = latents * 0.18215

                # Sample noise that we'll add to the latents
                noise = torch.randn_like(latents)
                bsz = latents.shape[0]
                # Sample a random timestep for each image
                timesteps = torch.randint(0, noise_scheduler.num_train_timesteps, (bsz,), device=latents.device).long()

                # Add noise to the latents according to the noise magnitude at each timestep
                # (this is the forward diffusion process)
                noisy_latents = noise_scheduler.add_noise(latents, noise, timesteps)

                # Get the text embedding for conditioning
                encoder_hidden_states = text_encoder(batch["input_ids"])[0]

                # Predict the noise residual
                noise_pred = unet(noisy_latents, timesteps, encoder_hidden_states.to(weight_dtype)).sample

                # Get the target for loss depending on the prediction type
                if noise_scheduler.config.prediction_type == "epsilon":
                    target = noise
                elif noise_scheduler.config.prediction_type == "v_prediction":
                    target = noise_scheduler.get_velocity(latents, noise, timesteps)
                else:
                    raise ValueError(f"Unknown prediction type {noise_scheduler.config.prediction_type}")

                loss = F.mse_loss(noise_pred, target, reduction="none").mean([1, 2, 3]).mean()
                accelerator.backward(loss)

                # Zero out the gradients for all token embeddings except the newly added
                # embeddings for the concept, as we only want to optimize the concept embeddings
                if accelerator.num_processes > 1:
                    grads = text_encoder.module.get_input_embeddings().weight.grad
                else:
                    grads = text_encoder.get_input_embeddings().weight.grad
                # Get the index for tokens that we want to zero the grads for
                index_grads_to_zero = torch.arange(len(tokenizer)) != placeholder_token_id
                grads.data[index_grads_to_zero, :] = grads.data[index_grads_to_zero, :].fill_(0)

                optimizer.step()
                optimizer.zero_grad()

            # Checks if the accelerator has performed an optimization step behind the scenes
            if accelerator.sync_gradients:
                progress_bar.update(1)
                global_step += 1
                if global_step % hyperparameters["save_steps"] == 0:
                    save_path = os.path.join(output_dir, f"learned_embeds-step-{global_step}.bin")
                    save_progress(text_encoder, placeholder_token_id, accelerator, save_path)

            logs = {"loss": loss.detach().item()}
            progress_bar.set_postfix(**logs)

            if global_step >= max_train_steps:
                break

        accelerator.wait_for_everyone()

    # Create the pipeline using the trained modules and save it.
    if accelerator.is_main_process:
        pipeline = StableDiffusionPipeline.from_pretrained(
            pretrained_model_name_or_path,
            text_encoder=accelerator.unwrap_model(text_encoder),
            tokenizer=tokenizer,
            vae=vae,
            unet=unet,
        )
        pipeline.save_pretrained(output_dir)
        # Also save the newly trained embeddings
        save_path = os.path.join(output_dir, "learned_embeds.bin")
        save_progress(text_encoder, placeholder_token_id, accelerator, save_path)

import accelerate
accelerate.notebook_launcher(training_function, args=(text_encoder, vae, unet))

for param in itertools.chain(unet.parameters(), text_encoder.parameters()):
    if param.grad is not None:
        del param.grad  # free some memory
torch.cuda.empty_cache()
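If you only kept the lightweight `learned_embeds.bin` file instead of the full saved pipeline, you can graft the learned vector onto a freshly downloaded checkpoint. A minimal sketch (not part of the original notebook) that assumes the same `pretrained_model_name_or_path` as above:

# Hypothetical sketch: load a saved textual-inversion embedding into a fresh text encoder
learned_embeds_path = os.path.join(hyperparameters["output_dir"], "learned_embeds.bin")
learned_embeds_dict = torch.load(learned_embeds_path, map_location="cpu")
token, embedding = next(iter(learned_embeds_dict.items()))

fresh_tokenizer = CLIPTokenizer.from_pretrained(pretrained_model_name_or_path, subfolder="tokenizer")
fresh_text_encoder = CLIPTextModel.from_pretrained(pretrained_model_name_or_path, subfolder="text_encoder")

fresh_tokenizer.add_tokens(token)                              # register the placeholder token
fresh_text_encoder.resize_token_embeddings(len(fresh_tokenizer))
token_id = fresh_tokenizer.convert_tokens_to_ids(token)
fresh_text_encoder.get_input_embeddings().weight.data[token_id] = embedding  # copy the learned vector
print(f"loaded {token} into the text encoder")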
## Run the code with your newly trained model

If you have just trained your model with the code above, use the block below to run it.

To save this concept for re-use, download the `learned_embeds.bin` file or save it to the library of concepts.

Use the [Stable Conceptualizer notebook](https://colab.research.google.com/github/huggingface/notebooks/blob/main/diffusers/stable_conceptualizer_inference.ipynb) for inference with persistently saved pre-trained concepts.

#@title Save your newly created concept to the [library of concepts](https://huggingface.co/sd-concepts-library)?

save_concept_to_public_library = True #@param {type:"boolean"}
name_of_your_concept = "Japanese oysters" #@param {type:"string"}
#@markdown `hf_token_write`: leave blank if you logged in with a token with `write access` in the [Initial Setup](#scrollTo=KbzZ9xe6dWwf). If not, [go to your tokens settings and create a write access token](https://huggingface.co/settings/tokens)
hf_token_write = "" #@param {type:"string"}

if save_concept_to_public_library:
    from slugify import slugify
    from huggingface_hub import HfApi, HfFolder, CommitOperationAdd
    from huggingface_hub import create_repo
    repo_id = f"sd-concepts-library/{slugify(name_of_your_concept)}"
    output_dir = hyperparameters["output_dir"]
    if not hf_token_write:
        with open(HfFolder.path_token, 'r') as fin:
            hf_token = fin.read()
    else:
        hf_token = hf_token_write
    # Join the Concepts Library organization if you aren't part of it already
    !curl -X POST -H 'Authorization: Bearer '$hf_token -H 'Content-Type: application/json' https://huggingface.co/organizations/sd-concepts-library/share/VcLXJtzwwxnHYCkNMLpSJCdnNFZHQwWywv
    images_upload = os.listdir("my_concept")
    image_string = ""
    repo_id = f"sd-concepts-library/{slugify(name_of_your_concept)}"
    for i, image in enumerate(images_upload):
        image_string = f'''{image_string}![{placeholder_token} {i}](https://huggingface.co/{repo_id}/resolve/main/concept_images/{image})
'''
    if what_to_teach == "style":
        what_to_teach_article = f"a `{what_to_teach}`"
    else:
        what_to_teach_article = f"an `{what_to_teach}`"
    readme_text = f'''---
license: mit
---
### {name_of_your_concept} on Stable Diffusion
This is the `{placeholder_token}` concept taught to Stable Diffusion via Textual Inversion. You can load this concept into the [Stable Conceptualizer](https://colab.research.google.com/github/huggingface/notebooks/blob/main/diffusers/stable_conceptualizer_inference.ipynb) notebook. You can also train your own concepts and load them into the concept libraries using [this notebook](https://colab.research.google.com/github/huggingface/notebooks/blob/main/diffusers/sd_textual_inversion_training.ipynb).

Here is the new concept you will be able to use as {what_to_teach_article}:
{image_string}
'''
    # Save the readme to a file
    readme_file = open("README.md", "w")
    readme_file.write(readme_text)
    readme_file.close()
    # Save the token identifier to a file
    text_file = open("token_identifier.txt", "w")
    text_file.write(placeholder_token)
    text_file.close()
    # Save the type of concept to a file
    type_file = open("type_of_concept.txt", "w")
    type_file.write(what_to_teach)
    type_file.close()
    operations = [
        CommitOperationAdd(path_in_repo="learned_embeds.bin", path_or_fileobj=f"{output_dir}/learned_embeds.bin"),
        CommitOperationAdd(path_in_repo="token_identifier.txt", path_or_fileobj="token_identifier.txt"),
        CommitOperationAdd(path_in_repo="type_of_concept.txt", path_or_fileobj="type_of_concept.txt"),
        CommitOperationAdd(path_in_repo="README.md", path_or_fileobj="README.md"),
    ]
    create_repo(repo_id, private=True, token=hf_token)
    api = HfApi()
    api.create_commit(
        repo_id=repo_id,
        operations=operations,
        commit_message=f"Upload the concept {name_of_your_concept} embeds and token",
        token=hf_token
    )
    api.upload_folder(
        folder_path=save_path,
        path_in_repo="concept_images",
        repo_id=repo_id,
        token=hf_token
    )

#@title Set up the pipeline
from diffusers import DPMSolverMultistepScheduler
pipe = StableDiffusionPipeline.from_pretrained(
    hyperparameters["output_dir"],
    scheduler=DPMSolverMultistepScheduler.from_pretrained(hyperparameters["output_dir"], subfolder="scheduler"),
    torch_dtype=torch.float16,
).to("cuda")

#@title Run the Stable Diffusion pipeline
#@markdown Don't forget to use the placeholder token in your prompt

prompt = "a <japanese-oysters> inside a ramen bowl" #@param {type:"string"}

num_samples = 2 #@param {type:"number"}
num_rows = 1 #@param {type:"number"}

all_images = []
for _ in range(num_rows):
    images = pipe([prompt] * num_samples, num_inference_steps=30, guidance_scale=7.5).images
    all_images.extend(images)

grid = image_grid(all_images, num_rows, num_samples)
grid
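To keep the results around after the Colab session ends, you can save the grid and the individual samples to disk; a small sketch, not part of the original notebook (the file names are arbitrary):

# Hypothetical: persist the generated grid and the individual images
grid.save("japanese_oysters_grid.png")
for idx, img in enumerate(all_images):
    img.save(f"japanese_oysters_sample_{idx}.png")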