adaface-neurips committed 8ee7393 (1 parent: 76ccb95)

update code

.gitignore CHANGED
@@ -6,3 +6,5 @@ gradio_cached_examples/
 samples/*
 samples/
 .gradio/certificate.pem
+models/*
+models
ConsistentID/app.py CHANGED
@@ -26,8 +26,8 @@ pipe = ConsistentIDPipeline.from_pretrained(
 
 ### Load consistentID_model checkpoint
 pipe.load_ConsistentID_model(
-    consistentID_weight_path="./models/ConsistentID-v1.bin",
-    bise_net_weight_path="./models/BiSeNet_pretrained_for_ConsistentID.pth",
+    consistentID_weight_path="./models/ConsistentID/ConsistentID-v1.bin",
+    bise_net_weight_path="./models/ConsistentID/BiSeNet_pretrained_for_ConsistentID.pth",
 )
 pipe.scheduler = EulerDiscreteScheduler.from_config(pipe.scheduler.config)
 pipe = pipe.to(device, torch.float16)
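The ConsistentID checkpoints now live one level deeper, under `models/ConsistentID/`. A minimal sketch of a pre-launch sanity check for that layout (this check is ours, not part of the commit):

```python
# Sketch: verify the checkpoint layout assumed by the updated ConsistentID/app.py.
import os

for path in ("models/ConsistentID/ConsistentID-v1.bin",
             "models/ConsistentID/BiSeNet_pretrained_for_ConsistentID.pth"):
    if not os.path.isfile(path):
        raise FileNotFoundError(f"Missing checkpoint: {path}")
```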
ConsistentID/requirements.txt CHANGED
@@ -7,7 +7,7 @@ peft
 opencv-python
 insightface
 diffusers
-torch
+torch==2.4.1
 torchvision
 transformers
 spaces
README2.md CHANGED
@@ -187,9 +187,9 @@ To exclude the effects of AdaFace, we generate a subset of videos with AdaFace-A
 ## Installation
 
 ### Manually Download Model Checkpoints
-- Download Stable Diffusion V1.5 into ``animatediff/sd``:
+- Download Stable Diffusion V1.5 into ``models/animatediff/sd``:
 
-``git clone https://huggingface.co/runwayml/stable-diffusion-v1-5 animatediff/sd``
+``git clone https://huggingface.co/runwayml/stable-diffusion-v1-5 models/animatediff/sd``
 - Download AnimateDiff motion module into ``models/v3_sd15_mm.ckpt``: https://huggingface.co/guoyww/animatediff/blob/main/v3_sd15_mm.ckpt
 - Download Animatediff adapter into ``models/v3_adapter_sd_v15.ckpt``: https://huggingface.co/guoyww/animatediff/blob/main/v3_sd15_adapter.ckpt
 - Download ID-Animator checkpoint into ``models/animator.ckpt`` from: https://huggingface.co/spaces/ID-Animator/ID-Animator/blob/main/animator.ckpt
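The same checkpoints can also be fetched programmatically. A minimal sketch using `huggingface_hub` (the repo IDs and file names come from the URLs above; using `huggingface_hub` rather than manual download or `git clone` is our assumption, not part of this commit):

```python
# Sketch: place the checkpoints listed above into the models/ layout used by README2.md.
import os
import shutil
from huggingface_hub import hf_hub_download, snapshot_download

os.makedirs("models", exist_ok=True)

# Stable Diffusion V1.5 -> models/animatediff/sd (equivalent to the git clone above).
snapshot_download("runwayml/stable-diffusion-v1-5", local_dir="models/animatediff/sd")

# AnimateDiff motion module and adapter -> models/.
for fname, dest in [("v3_sd15_mm.ckpt", "models/v3_sd15_mm.ckpt"),
                    ("v3_sd15_adapter.ckpt", "models/v3_adapter_sd_v15.ckpt")]:
    cached = hf_hub_download("guoyww/animatediff", fname)
    shutil.copy(cached, dest)

# ID-Animator checkpoint (hosted in a Space) -> models/animator.ckpt.
cached = hf_hub_download("ID-Animator/ID-Animator", "animator.ckpt", repo_type="space")
shutil.copy(cached, "models/animator.ckpt")
```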
adaface/__init__.py ADDED
File without changes
adaface/adaface_infer.py CHANGED
@@ -41,42 +41,36 @@ def seed_everything(seed):
 def parse_args():
     parser = argparse.ArgumentParser()
     parser.add_argument("--pipeline", type=str, default="text2img",
-                        choices=["text2img", "img2img", "text2img3", "flux"],
+                        choices=["text2img", "text2imgxl", "img2img", "text2img3", "flux"],
                         help="Type of pipeline to use (default: txt2img)")
     parser.add_argument("--base_model_path", type=str, default=None,
                         help="Type of checkpoints to use (default: None, using the official model)")
-    parser.add_argument('--adaface_ckpt_paths', type=str, nargs="+",
-                        default=['models/adaface/subjects-celebrity2024-05-16T17-22-46_zero3-ada-30000.pt'])
-    parser.add_argument("--adaface_encoder_types", type=str, nargs="+", default=["arc2face"],
+    parser.add_argument('--adaface_ckpt_path', type=str, required=True)
+    parser.add_argument("--adaface_encoder_types", type=str, nargs="+", default=["consistentID", "arc2face"],
                         choices=["arc2face", "consistentID"], help="Type(s) of the ID2Ada prompt encoders")
+    parser.add_argument("--enabled_encoders", type=str, nargs="+", default=None,
+                        choices=["arc2face", "consistentID"],
+                        help="List of enabled encoders (among the list of adaface_encoder_types). Default: None (all enabled)")
     # If adaface_encoder_cfg_scales is not specified, the weights will be set to 6.0 (consistentID) and 1.0 (arc2face).
     parser.add_argument('--adaface_encoder_cfg_scales', type=float, nargs="+", default=None,
                         help="CFG scales of output embeddings of the ID2Ada prompt encoders")
     parser.add_argument("--main_unet_filepath", type=str, default=None,
                         help="Path to the checkpoint of the main UNet model, if you want to replace the default UNet within --base_model_path")
     parser.add_argument("--extra_unet_dirpaths", type=str, nargs="*",
-                        default=['models/ensemble/rv4-unet', 'models/ensemble/ar18-unet'],
+                        default=[],
                         help="Extra paths to the checkpoints of the UNet models")
-    parser.add_argument('--unet_weights', type=float, nargs="+", default=[4, 2, 1],
+    parser.add_argument('--unet_weights_in_ensemble', type=float, nargs="+", default=[1],
                         help="Weights for the UNet models")
     parser.add_argument("--subject", type=str)
     parser.add_argument("--example_image_count", type=int, default=-1, help="Number of example images to use")
     parser.add_argument("--out_image_count", type=int, default=4, help="Number of images to generate")
     parser.add_argument("--prompt", type=str, default="a woman z in superman costume")
-    parser.add_argument("--noise", dest='perturb_std', type=float, default=0)
+    parser.add_argument("--perturb_std", type=float, default=0)
     parser.add_argument("--randface", action="store_true")
     parser.add_argument("--scale", dest='guidance_scale', type=float, default=4,
                         help="Guidance scale for the diffusion model")
-    parser.add_argument("--id_cfg_scale", type=float, default=6,
-                        help="CFG scale when generating the identity embeddings")
-
-    parser.add_argument("--subject_string",
-                        type=str, default="z",
-                        help="Subject placeholder string used in prompts to denote the concept.")
     parser.add_argument("--num_images_per_row", type=int, default=4,
                         help="Number of images to display in a row in the output grid image.")
-    parser.add_argument("--num_inference_steps", type=int, default=50,
-                        help="Number of DDIM inference steps")
     parser.add_argument("--device", type=str, default="cuda", help="Device to run the model on")
     parser.add_argument("--seed", type=int, default=42,
                         help="the seed (for reproducible sampling). Set to -1 to disable.")
@@ -95,16 +89,15 @@ if __name__ == "__main__":
 
     if args.pipeline not in ["text2img", "img2img"]:
         args.extra_unet_dirpaths = None
-        args.unet_weights = None
+        args.unet_weights_in_ensemble = None
 
     adaface = AdaFaceWrapper(args.pipeline, args.base_model_path,
-                             args.adaface_encoder_types, args.adaface_ckpt_paths,
-                             args.adaface_encoder_cfg_scales,
-                             args.subject_string, args.num_inference_steps,
+                             args.adaface_encoder_types, args.adaface_ckpt_path,
+                             args.adaface_encoder_cfg_scales, args.enabled_encoders,
                              unet_types=None,
                              main_unet_filepath=args.main_unet_filepath,
                              extra_unet_dirpaths=args.extra_unet_dirpaths,
-                             unet_weights=args.unet_weights, device=args.device)
+                             unet_weights_in_ensemble=args.unet_weights_in_ensemble, device=args.device)
 
     if not args.randface:
         image_folder = args.subject
@@ -143,7 +136,7 @@ if __name__ == "__main__":
         rand_init_id_embs = torch.randn(1, 512)
 
     init_id_embs = rand_init_id_embs if args.randface else None
-    noise = torch.randn(args.out_image_count, 4, 64, 64).cuda()
+    init_noise = torch.randn(args.out_image_count, 4, 64, 64).cuda()
     # args.perturb_std: the *relative* std of the noise added to the face embeddings.
     # A noise level of 0.08 could change gender, but 0.06 is usually safe.
     # adaface_subj_embs is not used. It is generated for the purpose of updating the text encoder (within this function call).
@@ -151,5 +144,7 @@ if __name__ == "__main__":
     adaface.prepare_adaface_embeddings(image_paths, init_id_embs,
                                        perturb_at_stage='img_prompt_emb',
                                        perturb_std=args.perturb_std, update_text_encoder=True)
-    images = adaface(noise, args.prompt, None, 'append', args.guidance_scale, args.out_image_count, verbose=True)
+    images = adaface(init_noise, args.prompt, None, None,
+                     'append', args.guidance_scale,
+                     args.out_image_count, verbose=True)
     save_images(images, args.num_images_per_row, subject_name, f"guide{args.guidance_scale}", args.perturb_std)
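The renamed options above (`--adaface_ckpt_path`, `--unet_weights_in_ensemble`, `--perturb_std`, the new `--enabled_encoders`) change how `AdaFaceWrapper` is built and called from this script. A minimal usage sketch under those new names (checkpoint and image paths are placeholders, not from this commit):

```python
# Sketch of the updated adaface_infer.py flow; paths are placeholders.
import torch
from adaface.adaface_wrapper import AdaFaceWrapper

adaface = AdaFaceWrapper("text2img", None,
                         ["consistentID", "arc2face"],   # adaface_encoder_types
                         "models/adaface/adaface.pt",    # --adaface_ckpt_path (placeholder)
                         None,                           # adaface_encoder_cfg_scales
                         None,                           # enabled_encoders (None = all enabled)
                         unet_types=None, extra_unet_dirpaths=[],
                         unet_weights_in_ensemble=[1], device="cuda")

adaface.prepare_adaface_embeddings(["subject/1.jpg"], None,
                                   perturb_at_stage='img_prompt_emb',
                                   perturb_std=0, update_text_encoder=True)
init_noise = torch.randn(4, 4, 64, 64).cuda()
images = adaface(init_noise, "a woman z in superman costume", None, None,
                 'append', 4.0, 4, verbose=True)
```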
adaface/adaface_translate.py CHANGED
@@ -25,21 +25,25 @@ def seed_everything(seed):
 
 def parse_args():
     parser = argparse.ArgumentParser()
-    parser.add_argument("--base_model_path", type=str, default='models/realisticvision/realisticVisionV40_v40VAE.safetensors',
-                        help="Path to the UNet checkpoint (default: RealisticVision 4.0)")
-    parser.add_argument('--adaface_ckpt_paths', type=str, nargs="+",
-                        default=['models/adaface/subjects-celebrity2024-05-16T17-22-46_zero3-ada-30000.pt'])
-    parser.add_argument("--adaface_encoder_types", type=str, nargs="+", default=["arc2face"],
+    parser.add_argument("--base_model_path", type=str, default='models/sar/sar.safetensors',
+                        help="Path to the UNet checkpoint (Default: SAR)")
+    parser.add_argument('--adaface_ckpt_path', type=str, required=True)
+    parser.add_argument("--adaface_encoder_types", type=str, nargs="+", default=["consistentID", "arc2face"],
                         choices=["arc2face", "consistentID"], help="Type(s) of the ID2Ada prompt encoders")
+    parser.add_argument("--enabled_encoders", type=str, nargs="+", default=None,
+                        choices=["arc2face", "consistentID"],
+                        help="List of enabled encoders (among the list of adaface_encoder_types). Default: None (all enabled)")
     # If adaface_encoder_cfg_scales is not specified, the weights will be set to 6.0 (consistentID) and 1.0 (arc2face).
     parser.add_argument('--adaface_encoder_cfg_scales', type=float, nargs="+", default=None,
                         help="CFG scales of output embeddings of the ID2Ada prompt encoders")
     parser.add_argument('--extra_unet_dirpaths', type=str, nargs="*",
-                        default=['models/ensemble/rv4-unet', 'models/ensemble/ar18-unet'],
+                        default=[],
                         help="Extra paths to the checkpoints of the UNet models")
-    parser.add_argument('--unet_weights', type=float, nargs="+", default=[4, 2, 1],
+    parser.add_argument('--unet_weights_in_ensemble', type=float, nargs="+", default=[1],
                         help="Weights for the UNet models")
     parser.add_argument("--in_folder", type=str, required=True, help="Path to the folder containing input images")
+    parser.add_argument("--restore_image", type=str, default=None,
+                        help="Path to the image to be restored")
     # If True, the input folder contains images of mixed subjects.
     # If False, the input folder contains multiple subfolders, each of which contains images of the same subject.
     parser.add_argument("--is_mix_subj_folder", type=str2bool, const=True, default=False, nargs="?",
@@ -49,19 +53,14 @@ def parse_args():
     parser.add_argument("--out_folder", type=str, required=True, help="Path to the folder saving output images")
     parser.add_argument("--out_count_per_input_image", type=int, default=1, help="Number of output images to generate per input image")
     parser.add_argument("--copy_masks", action="store_true", help="Copy the mask images to the output folder")
-    parser.add_argument("--noise", dest='perturb_std', type=float, default=0)
+    parser.add_argument("--perturb_std", type=float, default=0)
     parser.add_argument("--scale", dest='guidance_scale', type=float, default=4,
                         help="Guidance scale for the diffusion model")
     parser.add_argument("--ref_img_strength", type=float, default=0.8,
                         help="Strength of the reference image in the output image.")
-    parser.add_argument("--subject_string",
-                        type=str, default="z",
-                        help="Subject placeholder string used in prompts to denote the concept.")
     parser.add_argument("--prompt", type=str, default="a person z")
     parser.add_argument("--num_images_per_row", type=int, default=4,
                         help="Number of images to display in a row in the output grid image.")
-    parser.add_argument("--num_inference_steps", type=int, default=50,
-                        help="Number of DDIM inference steps")
     parser.add_argument("--num_gpus", type=int, default=1, help="Number of GPUs to use. If num_gpus > 1, use accelerate for distributed execution.")
     parser.add_argument("--device", type=str, default="cuda", help="Device to run the model on")
     parser.add_argument("--seed", type=int, default=42,
@@ -90,15 +89,16 @@ if __name__ == "__main__":
         process_index = 0
 
     adaface = AdaFaceWrapper("img2img", args.base_model_path,
-                             args.adaface_encoder_types, args.adaface_ckpt_paths,
-                             args.adaface_encoder_cfg_scales,
-                             args.subject_string, args.num_inference_steps,
+                             args.adaface_encoder_types, args.adaface_ckpt_path,
+                             args.adaface_encoder_cfg_scales, args.enabled_encoders,
                              unet_types=None,
-                             extra_unet_dirpaths=args.extra_unet_dirpaths, unet_weights=args.unet_weights,
+                             extra_unet_dirpaths=args.extra_unet_dirpaths,
+                             unet_weights_in_ensemble=args.unet_weights_in_ensemble,
                              device=args.device)
 
     in_folder = args.in_folder
     if os.path.isfile(in_folder):
+        args.in_folder = os.path.dirname(args.in_folder)
         subject_folders = [ os.path.dirname(in_folder) ]
         images_by_subject = [[in_folder]]
     else:
@@ -154,6 +154,24 @@ if __name__ == "__main__":
         images_by_subject = images_by_subject[process_index::args.num_gpus]
         #subject_folders, images_by_subject = distributed_state.split_between_processes(zip(subject_folders, images_by_subject))
 
+    if args.restore_image is not None:
+        in_images = []
+        for image_path in [args.restore_image]:
+            image = Image.open(image_path).convert("RGB").resize((512, 512))
+            # [512, 512, 3] -> [3, 512, 512].
+            image = np.array(image).transpose(2, 0, 1)
+            # Convert the image to a tensor of shape (1, 3, 512, 512) and move it to the GPU.
+            image = torch.tensor(image).unsqueeze(0).float().cuda()
+            in_images.append(image)
+
+        # Put all input images of the subject into a batch. This assumes max_images_per_subject is small.
+        # NOTE: For simplicity, we do not check overly large batch sizes.
+        in_images = torch.cat(in_images, dim=0)
+        # in_images: [5, 3, 512, 512].
+        # Normalize the pixel values to [0, 1].
+        in_images = in_images / 255.0
+        num_out_images = len(in_images) * args.out_count_per_input_image
+
     for (subject_folder, image_paths) in zip(subject_folders, images_by_subject):
         # If is_mix_subj_folder, then image_paths only contains 1 image, and we use the file name as the signature of the image.
         # Otherwise, we use the folder name as the signature of the images.
@@ -173,29 +191,32 @@ if __name__ == "__main__":
             os.makedirs(subject_out_folder)
         print(f"Output images will be saved to {subject_out_folder}")
 
-        in_images = []
-        for image_path in image_paths:
-            image = Image.open(image_path).convert("RGB").resize((512, 512))
-            # [512, 512, 3] -> [3, 512, 512].
-            image = np.array(image).transpose(2, 0, 1)
-            # Convert the image to a tensor of shape (1, 3, 512, 512) and move it to the GPU.
-            image = torch.tensor(image).unsqueeze(0).float().cuda()
-            in_images.append(image)
-
-        # Put all input images of the subject into a batch. This assumes max_images_per_subject is small.
-        # NOTE: For simplicity, we do not check overly large batch sizes.
-        in_images = torch.cat(in_images, dim=0)
-        # in_images: [5, 3, 512, 512].
-        # Normalize the pixel values to [0, 1].
-        in_images = in_images / 255.0
-        num_out_images = len(in_images) * args.out_count_per_input_image
+        if args.restore_image is None:
+            in_images = []
+            for image_path in image_paths:
+                image = Image.open(image_path).convert("RGB").resize((512, 512))
+                # [512, 512, 3] -> [3, 512, 512].
+                image = np.array(image).transpose(2, 0, 1)
+                # Convert the image to a tensor of shape (1, 3, 512, 512) and move it to the GPU.
+                image = torch.tensor(image).unsqueeze(0).float().cuda()
+                in_images.append(image)
+
+            # Put all input images of the subject into a batch. This assumes max_images_per_subject is small.
+            # NOTE: For simplicity, we do not check overly large batch sizes.
+            in_images = torch.cat(in_images, dim=0)
+            # in_images: [5, 3, 512, 512].
+            # Normalize the pixel values to [0, 1].
+            in_images = in_images / 255.0
+            num_out_images = len(in_images) * args.out_count_per_input_image
 
         with torch.no_grad():
             # args.perturb_std: the *relative* std of the noise added to the face embeddings.
             # A noise level of 0.08 could change gender, but 0.06 is usually safe.
             # The returned adaface_subj_embs are already incorporated in the text encoder, and not used explicitly.
             # NOTE: We assume out_count_per_input_image == 1, so that the output images are of the same number as the input images.
-            out_images = adaface(in_images, args.prompt, None, 'append', args.guidance_scale, num_out_images, ref_img_strength=args.ref_img_strength)
+            out_images = adaface(in_images, args.prompt, None, None,
                                  'append', args.guidance_scale, num_out_images,
+                                 ref_img_strength=args.ref_img_strength)
 
         for img_i, img in enumerate(out_images):
             # out_images: subj_1, subj_2, ..., subj_n, subj_1, subj_2, ..., subj_n, ...
@@ -203,9 +224,11 @@ if __name__ == "__main__":
             copy_i = img_i // len(in_images)
             image_filename_stem, image_fileext = os.path.splitext(os.path.basename(image_paths[subj_i]))
             if copy_i == 0:
-                img.save(os.path.join(subject_out_folder, f"{image_filename_stem}{image_fileext}"))
+                save_path = os.path.join(subject_out_folder, f"{image_filename_stem}{image_fileext}")
             else:
-                img.save(os.path.join(subject_out_folder, f"{image_filename_stem}_{copy_i}{image_fileext}"))
+                save_path = os.path.join(subject_out_folder, f"{image_filename_stem}_{copy_i}{image_fileext}")
+            img.save(save_path)
+            print(f"Saved {save_path}")
 
             if args.copy_masks:
                 mask_path = image_paths[subj_i].replace(image_fileext, "_mask.png")
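Both the new `--restore_image` branch and the per-subject branch above share the same preprocessing: each image is resized to 512x512, converted to a CHW float tensor, stacked into a batch, and normalized to [0, 1]. A small sketch of that block factored into a helper (the helper name is ours for illustration; the script inlines this code):

```python
# Hedged sketch: the image-batch preprocessing used by adaface_translate.py above.
import numpy as np
import torch
from PIL import Image

def load_image_batch(image_paths, size=512, device="cuda"):
    images = []
    for image_path in image_paths:
        image = Image.open(image_path).convert("RGB").resize((size, size))
        image = np.array(image).transpose(2, 0, 1)   # HWC -> CHW
        images.append(torch.tensor(image).unsqueeze(0).float())
    # Returns a [N, 3, size, size] batch with pixel values in [0, 1].
    return torch.cat(images, dim=0).to(device) / 255.0
```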
adaface/adaface_wrapper.py CHANGED
@@ -8,22 +8,29 @@ from diffusers import (
8
  StableDiffusion3Pipeline,
9
  #FluxPipeline,
10
  DDIMScheduler,
 
 
11
  AutoencoderKL,
 
12
  )
13
  from diffusers.loaders.single_file_utils import convert_ldm_unet_checkpoint
14
  from adaface.util import UNetEnsemble
15
  from adaface.face_id_to_ada_prompt import create_id2ada_prompt_encoder
 
16
  from safetensors.torch import load_file as safetensors_load_file
17
  import re, os
18
  import numpy as np
 
19
 
20
  class AdaFaceWrapper(nn.Module):
21
  def __init__(self, pipeline_name, base_model_path, adaface_encoder_types,
22
  adaface_ckpt_paths, adaface_encoder_cfg_scales=None,
23
- enabled_encoders=None,
24
- subject_string='z', num_inference_steps=50, negative_prompt=None,
25
  use_840k_vae=False, use_ds_text_encoder=False,
26
- main_unet_filepath=None, unet_types=None, extra_unet_dirpaths=None, unet_weights=None,
 
 
27
  device='cuda', is_training=False):
28
  '''
29
  pipeline_name: "text2img", "text2imgxl", "img2img", "text2img3", "flux", or None.
@@ -38,15 +45,23 @@ class AdaFaceWrapper(nn.Module):
38
  self.adaface_ckpt_paths = adaface_ckpt_paths
39
  self.adaface_encoder_cfg_scales = adaface_encoder_cfg_scales
40
  self.enabled_encoders = enabled_encoders
 
 
 
 
 
 
41
  self.subject_string = subject_string
 
42
 
43
- self.num_inference_steps = num_inference_steps
 
44
  self.use_840k_vae = use_840k_vae
45
  self.use_ds_text_encoder = use_ds_text_encoder
46
  self.main_unet_filepath = main_unet_filepath
47
  self.unet_types = unet_types
48
  self.extra_unet_dirpaths = extra_unet_dirpaths
49
- self.unet_weights = unet_weights
50
  self.device = device
51
  self.is_training = is_training
52
 
@@ -62,7 +77,14 @@ class AdaFaceWrapper(nn.Module):
62
  self.initialize_pipeline()
63
  # During inference, we never use static image suffix embeddings.
64
  # So num_id_vecs is the length of the returned adaface embeddings for each encoder.
65
- self.encoders_num_id_vecs = self.id2ada_prompt_encoder.encoders_num_id_vecs
 
 
 
 
 
 
 
66
  self.extend_tokenizer_and_text_encoder()
67
 
68
  def to(self, device):
@@ -76,7 +98,8 @@ class AdaFaceWrapper(nn.Module):
76
  self.id2ada_prompt_encoder = create_id2ada_prompt_encoder(self.adaface_encoder_types,
77
  self.adaface_ckpt_paths,
78
  self.adaface_encoder_cfg_scales,
79
- self.enabled_encoders)
 
80
 
81
  self.id2ada_prompt_encoder.to(self.device)
82
  print(f"adaface_encoder_cfg_scales: {self.adaface_encoder_cfg_scales}")
@@ -118,10 +141,10 @@ class AdaFaceWrapper(nn.Module):
118
 
119
  if self.base_model_path is None:
120
  base_model_path_dict = {
121
- 'text2img': 'models/sd15-dste8-vae.safetensors',
122
- 'text2imgxl': 'stabilityai/stable-diffusion-xl-base-1.0',
123
- 'text2img3': 'stabilityai/stable-diffusion-3-medium-diffusers',
124
- 'flux': 'black-forest-labs/FLUX.1-schnell',
125
  }
126
  self.base_model_path = base_model_path_dict[self.pipeline_name]
127
 
@@ -137,6 +160,20 @@ class AdaFaceWrapper(nn.Module):
137
  safety_checker=None
138
  )
139
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
140
  if self.main_unet_filepath is not None:
141
  print(f"Replacing the UNet with the UNet from {self.main_unet_filepath}.")
142
  ret = pipeline.unet.load_state_dict(self.load_unet_from_file(self.main_unet_filepath, device='cpu'))
@@ -147,12 +184,19 @@ class AdaFaceWrapper(nn.Module):
147
 
148
  if (self.unet_types is not None and len(self.unet_types) > 0) \
149
  or (self.extra_unet_dirpaths is not None and len(self.extra_unet_dirpaths) > 0):
150
- unet_ensemble = UNetEnsemble([pipeline.unet], self.unet_types, self.extra_unet_dirpaths, self.unet_weights,
151
  device=self.device, torch_dtype=torch.float16)
152
  pipeline.unet = unet_ensemble
153
 
154
  print(f"Loaded pipeline from {self.base_model_path}.")
155
-
 
 
 
 
 
 
 
156
  if self.use_840k_vae:
157
  pipeline.vae = vae
158
  print("Replaced the VAE with the 840k-step VAE.")
@@ -167,19 +211,56 @@ class AdaFaceWrapper(nn.Module):
167
  pipeline.vae = None
168
  print("Removed UNet and VAE from the pipeline.")
169
 
170
- if self.pipeline_name not in ["text2imgxl", "text2img3", "flux"]:
171
- noise_scheduler = DDIMScheduler(
172
- num_train_timesteps=1000,
173
- beta_start=0.00085,
174
- beta_end=0.012,
175
- beta_schedule="scaled_linear",
176
- clip_sample=False,
177
- set_alpha_to_one=False,
178
- )
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
179
  pipeline.scheduler = noise_scheduler
180
- # Otherwise, pipeline.scheduler == FlowMatchEulerDiscreteScheduler
 
181
  self.pipeline = pipeline.to(self.device)
182
 
 
 
 
 
183
  def load_unet_from_file(self, unet_path, device=None):
184
  if os.path.isfile(unet_path):
185
  if unet_path.endswith(".safetensors"):
@@ -208,7 +289,109 @@ class AdaFaceWrapper(nn.Module):
208
  else:
209
  raise ValueError(f"UNet path {unet_path} is not a file.")
210
  return unet_state_dict
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
211
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
212
  def extend_tokenizer_and_text_encoder(self):
213
  if np.sum(self.encoders_num_id_vecs) < 1:
214
  raise ValueError(f"encoders_num_id_vecs has to be larger or equal to 1, but is {self.encoders_num_id_vecs}")
@@ -218,6 +401,7 @@ class AdaFaceWrapper(nn.Module):
218
  # We add z_0_0, z_0_1, z_0_2, ..., z_0_15, z_1_0, z_1_1, z_1_2, z_1_3 to the tokenizer.
219
  self.all_placeholder_tokens = []
220
  self.placeholder_tokens_strs = []
 
221
  for i in range(len(self.adaface_encoder_types)):
222
  placeholder_tokens = []
223
  for j in range(self.encoders_num_id_vecs[i]):
@@ -225,9 +409,11 @@ class AdaFaceWrapper(nn.Module):
225
  placeholder_tokens_str = " ".join(placeholder_tokens)
226
 
227
  self.all_placeholder_tokens.extend(placeholder_tokens)
 
228
  self.placeholder_tokens_strs.append(placeholder_tokens_str)
229
 
230
  self.all_placeholder_tokens_str = " ".join(self.placeholder_tokens_strs)
 
231
  # all_null_placeholder_tokens_str: ", , , , ..." (20 times).
232
  # It just contains the commas and spaces with the same length, but no actual tokens.
233
  self.all_null_placeholder_tokens_str = " ".join([", "] * len(self.all_placeholder_tokens))
@@ -241,7 +427,7 @@ class AdaFaceWrapper(nn.Module):
241
 
242
  print(f"Added {num_added_tokens} tokens ({self.all_placeholder_tokens_str}) to the tokenizer.")
243
 
244
- # placeholder_token_ids: [49408, ..., 49423].
245
  self.placeholder_token_ids = tokenizer.convert_tokens_to_ids(self.all_placeholder_tokens)
246
  #print("New tokens:", self.placeholder_token_ids)
247
  # Resize the token embeddings as we are adding new special tokens to the tokenizer
@@ -252,24 +438,49 @@ class AdaFaceWrapper(nn.Module):
252
 
253
  # Extend pipeline.text_encoder with the adaface subject emeddings.
254
  # subj_embs: [16, 768].
255
- def update_text_encoder_subj_embeddings(self, subj_embs):
256
  # Initialise the newly added placeholder token with the embeddings of the initializer token
257
  # token_embeds: [49412, 768]
258
  token_embeds = self.pipeline.text_encoder.get_input_embeddings().weight.data
 
 
 
 
259
  with torch.no_grad():
260
- for i, token_id in enumerate(self.placeholder_token_ids):
261
- token_embeds[token_id] = subj_embs[i]
262
- print(f"Updated {len(self.placeholder_token_ids)} tokens ({self.all_placeholder_tokens_str}) in the text encoder.")
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
263
 
264
  def update_prompt(self, prompt, placeholder_tokens_pos='append',
 
265
  use_null_placeholders=False):
266
  if prompt is None:
267
  prompt = ""
268
 
269
  if use_null_placeholders:
270
  all_placeholder_tokens_str = self.all_null_placeholder_tokens_str
 
 
 
271
  else:
272
- all_placeholder_tokens_str = self.all_placeholder_tokens_str
273
 
274
  # Delete the subject_string from the prompt.
275
  prompt = re.sub(r'\b(a|an|the)\s+' + self.subject_string + r'\b,?', "", prompt)
@@ -279,15 +490,29 @@ class AdaFaceWrapper(nn.Module):
279
  # When we do joint training, seems both work better if they are appended to the prompt.
280
  # Therefore we simply appended all placeholder_tokens_str's to the prompt.
281
  # NOTE: Prepending them hurts compositional prompts.
282
- if placeholder_tokens_pos == 'prepend':
283
- prompt = all_placeholder_tokens_str + " " + prompt
284
- elif placeholder_tokens_pos == 'append':
285
- prompt = prompt + " " + all_placeholder_tokens_str
 
 
 
 
 
 
 
286
  else:
287
- breakpoint()
 
 
 
 
 
288
 
289
  return prompt
290
 
 
 
291
  # If face_id_embs is None, then it extracts face_id_embs from the images,
292
  # then map them to ada prompt embeddings.
293
  # avg_at_stage: 'id_emb', 'img_prompt_emb', or None.
@@ -298,27 +523,29 @@ class AdaFaceWrapper(nn.Module):
298
  perturb_at_stage=None, # id_emb, img_prompt_emb, or None.
299
  perturb_std=0, update_text_encoder=True):
300
 
301
- all_adaface_subj_embs = \
302
  self.id2ada_prompt_encoder.generate_adaface_embeddings(\
303
  image_paths, face_id_embs=face_id_embs,
304
  img_prompt_embs=None,
305
  avg_at_stage=avg_at_stage,
306
  perturb_at_stage=perturb_at_stage,
307
  perturb_std=perturb_std,
308
- enable_static_img_suffix_embs=False)
309
 
310
  if all_adaface_subj_embs is None:
311
  return None
312
 
 
 
313
  if all_adaface_subj_embs.ndim == 4:
314
- # [1, 1, 16, 768] -> [16, 768]
315
  all_adaface_subj_embs = all_adaface_subj_embs.squeeze(0).squeeze(0)
316
  elif all_adaface_subj_embs.ndim == 3:
317
- # [1, 16, 768] -> [16, 768]
318
  all_adaface_subj_embs = all_adaface_subj_embs.squeeze(0)
319
 
320
  if update_text_encoder:
321
- self.update_text_encoder_subj_embeddings(all_adaface_subj_embs)
322
  return all_adaface_subj_embs
323
 
324
  def diffusers_encode_prompts(self, prompt, plain_prompt, negative_prompt, device):
@@ -368,6 +595,7 @@ class AdaFaceWrapper(nn.Module):
368
  else:
369
  breakpoint()
370
  else:
 
371
  # prompt_embeds_, negative_prompt_embeds_: [1, 77, 768]
372
  prompt_embeds_, negative_prompt_embeds_ = \
373
  self.pipeline.encode_prompt(prompt, device=device,
@@ -378,9 +606,53 @@ class AdaFaceWrapper(nn.Module):
378
  return prompt_embeds_, negative_prompt_embeds_, \
379
  pooled_prompt_embeds_, negative_pooled_prompt_embeds_
380
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
381
  def encode_prompt(self, prompt, negative_prompt=None,
382
  placeholder_tokens_pos='append',
383
- do_neg_id_prompt_weight=0,
 
 
 
 
384
  device=None, verbose=False):
385
  if negative_prompt is None:
386
  negative_prompt = self.negative_prompt
@@ -389,59 +661,81 @@ class AdaFaceWrapper(nn.Module):
389
  device = self.device
390
 
391
  plain_prompt = prompt
392
- prompt = self.update_prompt(prompt, placeholder_tokens_pos=placeholder_tokens_pos)
 
 
 
 
 
 
393
  if verbose:
394
  print(f"Subject prompt:\n{prompt}")
395
 
396
- if do_neg_id_prompt_weight > 0:
397
- # Use 'prepend' for the negative prompt, since it's long and we want to make sure
398
- # the placeholder tokens are not cut off.
399
- negative_prompt0 = negative_prompt
400
- negative_prompt = self.update_prompt(negative_prompt0, placeholder_tokens_pos='prepend')
401
- null_negative_prompt = self.update_prompt(negative_prompt0, placeholder_tokens_pos='prepend',
402
- use_null_placeholders=True)
403
- ''' if verbose:
404
- print(f"Negative prompt:\n{negative_prompt}")
405
- print(f"Null negative prompt:\n{null_negative_prompt}")
406
-
407
- '''
408
- else:
409
- null_negative_prompt = None
410
-
411
  # For some unknown reason, the text_encoder is still on CPU after self.pipeline.to(self.device).
412
  # So we manually move it to GPU here.
413
  self.pipeline.text_encoder.to(device)
414
 
415
  prompt_embeds_, negative_prompt_embeds_, pooled_prompt_embeds_, negative_pooled_prompt_embeds_ = \
416
  self.diffusers_encode_prompts(prompt, plain_prompt, negative_prompt, device)
417
-
418
- if 0 < do_neg_id_prompt_weight < 1:
419
- _, negative_prompt_embeds_null, _, _ = \
420
- self.diffusers_encode_prompts(prompt, plain_prompt, null_negative_prompt, device)
421
- negative_prompt_embeds_ = negative_prompt_embeds_ * do_neg_id_prompt_weight + \
422
- negative_prompt_embeds_null * (1 - do_neg_id_prompt_weight)
423
-
 
 
 
 
 
 
 
424
  return prompt_embeds_, negative_prompt_embeds_, pooled_prompt_embeds_, negative_pooled_prompt_embeds_
425
 
426
  # ref_img_strength is used only in the img2img pipeline.
427
- def forward(self, noise, prompt, negative_prompt=None,
428
  placeholder_tokens_pos='append',
429
- do_neg_id_prompt_weight=0,
430
  guidance_scale=6.0, out_image_count=4,
431
- ref_img_strength=0.8, generator=None, verbose=False):
 
 
 
 
 
 
432
  noise = noise.to(device=self.device, dtype=torch.float16)
 
 
433
 
434
  if negative_prompt is None:
435
  negative_prompt = self.negative_prompt
436
  # prompt_embeds_, negative_prompt_embeds_: [1, 77, 768]
437
- prompt_embeds_, negative_prompt_embeds_, pooled_prompt_embeds_, \
438
- negative_pooled_prompt_embeds_ = \
439
- self.encode_prompt(prompt, negative_prompt,
440
- placeholder_tokens_pos=placeholder_tokens_pos,
441
- do_neg_id_prompt_weight=do_neg_id_prompt_weight,
442
- device=self.device, verbose=verbose)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
443
  # Repeat the prompt embeddings for all images in the batch.
444
  prompt_embeds_ = prompt_embeds_.repeat(out_image_count, 1, 1)
 
445
  if negative_prompt_embeds_ is not None:
446
  negative_prompt_embeds_ = negative_prompt_embeds_.repeat(out_image_count, 1, 1)
447
 
 
8
  StableDiffusion3Pipeline,
9
  #FluxPipeline,
10
  DDIMScheduler,
11
+ PNDMScheduler,
12
+ DPMSolverSinglestepScheduler,
13
  AutoencoderKL,
14
+ LCMScheduler,
15
  )
16
  from diffusers.loaders.single_file_utils import convert_ldm_unet_checkpoint
17
  from adaface.util import UNetEnsemble
18
  from adaface.face_id_to_ada_prompt import create_id2ada_prompt_encoder
19
+ from adaface.diffusers_attn_lora_capture import set_up_attn_processors, set_up_ffn_loras, set_lora_and_capture_flags
20
  from safetensors.torch import load_file as safetensors_load_file
21
  import re, os
22
  import numpy as np
23
+ from peft.utils.constants import DUMMY_TARGET_MODULES
24
 
25
  class AdaFaceWrapper(nn.Module):
26
  def __init__(self, pipeline_name, base_model_path, adaface_encoder_types,
27
  adaface_ckpt_paths, adaface_encoder_cfg_scales=None,
28
+ enabled_encoders=None, use_lcm=False, default_scheduler_name='ddim',
29
+ num_inference_steps=50, subject_string='z', negative_prompt=None,
30
  use_840k_vae=False, use_ds_text_encoder=False,
31
+ main_unet_filepath=None, unet_types=None, extra_unet_dirpaths=None, unet_weights_in_ensemble=None,
32
+ enable_static_img_suffix_embs=None, unet_uses_attn_lora=False,
33
+ attn_lora_layer_names=['q', 'k', 'v', 'out'], shrink_cross_attn=False, q_lora_updates_query=False,
34
  device='cuda', is_training=False):
35
  '''
36
  pipeline_name: "text2img", "text2imgxl", "img2img", "text2img3", "flux", or None.
 
45
  self.adaface_ckpt_paths = adaface_ckpt_paths
46
  self.adaface_encoder_cfg_scales = adaface_encoder_cfg_scales
47
  self.enabled_encoders = enabled_encoders
48
+ # None, or a list of two bools for two encoders. If None, both are disabled.
49
+ self.enable_static_img_suffix_embs = enable_static_img_suffix_embs
50
+ self.unet_uses_attn_lora = unet_uses_attn_lora
51
+ self.attn_lora_layer_names = attn_lora_layer_names
52
+ self.q_lora_updates_query = q_lora_updates_query
53
+ self.use_lcm = use_lcm
54
  self.subject_string = subject_string
55
+ self.shrink_cross_attn = shrink_cross_attn
56
 
57
+ self.default_scheduler_name = default_scheduler_name
58
+ self.num_inference_steps = num_inference_steps if not use_lcm else 4
59
  self.use_840k_vae = use_840k_vae
60
  self.use_ds_text_encoder = use_ds_text_encoder
61
  self.main_unet_filepath = main_unet_filepath
62
  self.unet_types = unet_types
63
  self.extra_unet_dirpaths = extra_unet_dirpaths
64
+ self.unet_weights_in_ensemble = unet_weights_in_ensemble
65
  self.device = device
66
  self.is_training = is_training
67
 
 
77
  self.initialize_pipeline()
78
  # During inference, we never use static image suffix embeddings.
79
  # So num_id_vecs is the length of the returned adaface embeddings for each encoder.
80
+ self.encoders_num_id_vecs = np.array(self.id2ada_prompt_encoder.encoders_num_id_vecs)
81
+ self.encoders_num_static_img_suffix_embs = np.array(self.id2ada_prompt_encoder.encoders_num_static_img_suffix_embs)
82
+ if self.enable_static_img_suffix_embs is not None:
83
+ assert len(self.enable_static_img_suffix_embs) == len(self.encoders_num_id_vecs)
84
+ self.encoders_num_static_img_suffix_embs *= np.array(self.enable_static_img_suffix_embs)
85
+ self.encoders_num_id_vecs += self.encoders_num_static_img_suffix_embs
86
+
87
+ self.img_prompt_embs = None
88
  self.extend_tokenizer_and_text_encoder()
89
 
90
  def to(self, device):
 
98
  self.id2ada_prompt_encoder = create_id2ada_prompt_encoder(self.adaface_encoder_types,
99
  self.adaface_ckpt_paths,
100
  self.adaface_encoder_cfg_scales,
101
+ self.enabled_encoders,
102
+ num_static_img_suffix_embs=4)
103
 
104
  self.id2ada_prompt_encoder.to(self.device)
105
  print(f"adaface_encoder_cfg_scales: {self.adaface_encoder_cfg_scales}")
 
141
 
142
  if self.base_model_path is None:
143
  base_model_path_dict = {
144
+ 'text2img': 'models/sd15-dste8-vae.safetensors',
145
+ 'text2imgxl': 'stabilityai/stable-diffusion-xl-base-1.0',
146
+ 'text2img3': 'stabilityai/stable-diffusion-3-medium-diffusers',
147
+ 'flux': 'black-forest-labs/FLUX.1-schnell',
148
  }
149
  self.base_model_path = base_model_path_dict[self.pipeline_name]
150
 
 
160
  safety_checker=None
161
  )
162
 
163
+ if self.use_lcm:
164
+ lcm_path_dict = {
165
+ 'text2img': 'latent-consistency/lcm-lora-sdv1-5',
166
+ 'text2imgxl': 'latent-consistency/lcm-lora-sdxl',
167
+ }
168
+ if self.pipeline_name not in lcm_path_dict:
169
+ raise ValueError(f"Pipeline {self.pipeline_name} does not support LCM.")
170
+
171
+ lcm_path = lcm_path_dict[self.pipeline_name]
172
+ pipeline.load_lora_weights(lcm_path)
173
+ pipeline.fuse_lora()
174
+ print(f"Loaded LCM weights from {lcm_path}.")
175
+ pipeline.scheduler = LCMScheduler.from_config(pipeline.scheduler.config)
176
+
177
  if self.main_unet_filepath is not None:
178
  print(f"Replacing the UNet with the UNet from {self.main_unet_filepath}.")
179
  ret = pipeline.unet.load_state_dict(self.load_unet_from_file(self.main_unet_filepath, device='cpu'))
 
184
 
185
  if (self.unet_types is not None and len(self.unet_types) > 0) \
186
  or (self.extra_unet_dirpaths is not None and len(self.extra_unet_dirpaths) > 0):
187
+ unet_ensemble = UNetEnsemble([pipeline.unet], self.unet_types, self.extra_unet_dirpaths, self.unet_weights_in_ensemble,
188
  device=self.device, torch_dtype=torch.float16)
189
  pipeline.unet = unet_ensemble
190
 
191
  print(f"Loaded pipeline from {self.base_model_path}.")
192
+ if not remove_unet and (self.unet_uses_attn_lora or self.shrink_cross_attn):
193
+ unet2 = self.load_unet_lora_weights(pipeline.unet, use_attn_lora=self.unet_uses_attn_lora,
194
+ attn_lora_layer_names=self.attn_lora_layer_names,
195
+ shrink_cross_attn=self.shrink_cross_attn,
196
+ q_lora_updates_query=self.q_lora_updates_query)
197
+
198
+ pipeline.unet = unet2
199
+
200
  if self.use_840k_vae:
201
  pipeline.vae = vae
202
  print("Replaced the VAE with the 840k-step VAE.")
 
211
  pipeline.vae = None
212
  print("Removed UNet and VAE from the pipeline.")
213
 
214
+ if self.pipeline_name not in ["text2imgxl", "text2img3", "flux"] and not self.use_lcm:
215
+ if self.default_scheduler_name == 'ddim':
216
+ noise_scheduler = DDIMScheduler(
217
+ num_train_timesteps=1000,
218
+ beta_start=0.00085,
219
+ beta_end=0.012,
220
+ beta_schedule="scaled_linear",
221
+ clip_sample=False,
222
+ set_alpha_to_one=False,
223
+ steps_offset=1,
224
+ timestep_spacing="leading",
225
+ rescale_betas_zero_snr=False,
226
+ )
227
+ elif self.default_scheduler_name == 'pndm':
228
+ noise_scheduler = PNDMScheduler(
229
+ num_train_timesteps=1000,
230
+ beta_start=0.00085,
231
+ beta_end=0.012,
232
+ beta_schedule="scaled_linear",
233
+ set_alpha_to_one=False,
234
+ steps_offset=1,
235
+ timestep_spacing="leading",
236
+ skip_prk_steps=True,
237
+ )
238
+ elif self.default_scheduler_name == 'dpm++':
239
+ noise_scheduler = DPMSolverSinglestepScheduler(
240
+ beta_start=0.00085,
241
+ beta_end=0.012,
242
+ beta_schedule="scaled_linear",
243
+ prediction_type="epsilon",
244
+ num_train_timesteps=1000,
245
+ trained_betas=None,
246
+ thresholding=False,
247
+ algorithm_type="dpmsolver++",
248
+ solver_type="midpoint",
249
+ lower_order_final=True,
250
+ use_karras_sigmas=True,
251
+ )
252
+ else:
253
+ breakpoint()
254
+
255
  pipeline.scheduler = noise_scheduler
256
+ # Otherwise, if not use_lcm, pipeline.scheduler == FlowMatchEulerDiscreteScheduler
257
+ # if use_lcm, pipeline.scheduler == LCMScheduler
258
  self.pipeline = pipeline.to(self.device)
259
 
260
+ def set_adaface_encoder_cfg_scales(self, adaface_encoder_cfg_scales):
261
+ self.adaface_encoder_cfg_scales = adaface_encoder_cfg_scales
262
+ self.id2ada_prompt_encoder.set_out_id_embs_cfg_scale(adaface_encoder_cfg_scales)
263
+
264
  def load_unet_from_file(self, unet_path, device=None):
265
  if os.path.isfile(unet_path):
266
  if unet_path.endswith(".safetensors"):
 
289
  else:
290
  raise ValueError(f"UNet path {unet_path} is not a file.")
291
  return unet_state_dict
292
+
293
+ # Adapted from ConsistentIDPipeline:set_ip_adapter().
294
+ def load_unet_loras(self, unet, unet_lora_modules_state_dict,
295
+ use_attn_lora=True, use_ffn_lora=False,
296
+ attn_lora_layer_names=['q', 'k', 'v', 'out'],
297
+ shrink_cross_attn=False, cross_attn_shrink_factor=0.5,
298
+ q_lora_updates_query=False):
299
+ attn_capture_procs, attn_opt_modules = \
300
+ set_up_attn_processors(unet, use_attn_lora=True, attn_lora_layer_names=attn_lora_layer_names,
301
+ lora_rank=192, lora_scale_down=8,
302
+ cross_attn_shrink_factor=cross_attn_shrink_factor,
303
+ q_lora_updates_query=q_lora_updates_query)
304
+ # up_blocks.3.resnets.[1~2].conv1, conv2, conv_shortcut. [12] matches 1 or 2.
305
+ if use_ffn_lora:
306
+ target_modules_pat = 'up_blocks.3.resnets.[12].conv[a-z0-9_]+'
307
+ else:
308
+ # A special pattern, "dummy-target-modules" tells PEFT to add loras on NONE of the layers.
309
+ # We couldn't simply skip PEFT initialization (converting unet to a PEFT model),
310
+ # otherwise the attn lora layers will cause nan quickly during a fp16 training.
311
+ target_modules_pat = DUMMY_TARGET_MODULES
312
+
313
+ unet, ffn_lora_layers, ffn_opt_modules = \
314
+ set_up_ffn_loras(unet, target_modules_pat=target_modules_pat, lora_uses_dora=True)
315
+
316
+ # self.attn_capture_procs and ffn_lora_layers will be used in set_lora_and_capture_flags().
317
+ self.attn_capture_procs = list(attn_capture_procs.values())
318
+ self.ffn_lora_layers = list(ffn_lora_layers.values())
319
+ # Combine attn_opt_modules and ffn_opt_modules into unet_lora_modules.
320
+ # unet_lora_modules is for optimization and loading/saving.
321
+ unet_lora_modules = {}
322
+ # attn_opt_modules and ffn_opt_modules have different depths of keys.
323
+ # attn_opt_modules:
324
+ # up_blocks_3_attentions_1_transformer_blocks_0_attn2_processor_std_shrink_factor,
325
+ # up_blocks_3_attentions_1_transformer_blocks_0_attn2_processor_to_q_lora_lora_A, ...
326
+ # ffn_opt_modules:
327
+ # base_model_model_up_blocks_3_resnets_1_conv1_lora_A, ...
328
+ # with the prefix 'base_model_model_'. Because ffn_opt_modules are extracted from the peft-wrapped model,
329
+ # and attn_opt_modules are extracted from the original unet model.
330
+ # To be compatible with old param keys, we append 'base_model_model_' to the keys of attn_opt_modules.
331
+ unet_lora_modules.update({ f'base_model_model_{k}': v for k, v in attn_opt_modules.items() })
332
+ unet_lora_modules.update(ffn_opt_modules)
333
+ # ParameterDict can contain both Parameter and nn.Module.
334
+ # TODO: maybe in the future, we couldn't put nn.Module in nn.ParameterDict.
335
+ self.unet_lora_modules = torch.nn.ParameterDict(unet_lora_modules)
336
+
337
+ missing, unexpected = self.unet_lora_modules.load_state_dict(unet_lora_modules_state_dict, strict=False)
338
+ if len(missing) > 0:
339
+ print(f"Missing Keys: {missing}")
340
+ if len(unexpected) > 0:
341
+ print(f"Unexpected Keys: {unexpected}")
342
+
343
+ print(f"Loaded {len(unet_lora_modules_state_dict)} LoRA weights on the UNet:\n{unet_lora_modules.keys()}")
344
+ self.outfeat_capture_blocks.append(unet.up_blocks[3])
345
+
346
+ # If shrink_cross_attn is True and use_attn_lora is False, we load all these params from ckpt,
347
+ # but since we set use_attn_lora to False, attn loras won't be used during inference nonetheless.
348
+ set_lora_and_capture_flags(unet, None, self.attn_capture_procs, self.outfeat_capture_blocks,
349
+ use_attn_lora, use_ffn_lora, 'recon_loss', capture_ca_activations=False,
350
+ shrink_cross_attn=shrink_cross_attn)
351
+
352
+ return unet
353
+
354
+ def load_unet_lora_weights(self, unet, use_attn_lora=True, attn_lora_layer_names=['q', 'k', 'v', 'out'],
355
+ shrink_cross_attn=False, q_lora_updates_query=False):
356
+ unet_lora_weight_found = False
357
+ if isinstance(self.adaface_ckpt_paths, str):
358
+ adaface_ckpt_paths = [self.adaface_ckpt_paths]
359
+ else:
360
+ adaface_ckpt_paths = self.adaface_ckpt_paths
361
+
362
+ for adaface_ckpt_path in adaface_ckpt_paths:
363
+ ckpt_dict = torch.load(adaface_ckpt_path, map_location='cpu')
364
+ if 'unet_lora_modules' in ckpt_dict:
365
+ unet_lora_modules_state_dict = ckpt_dict['unet_lora_modules']
366
+ print(f"{len(unet_lora_modules_state_dict)} LoRA weights found in {adaface_ckpt_path}.")
367
+ unet_lora_weight_found = True
368
+ break
369
+
370
+ # Since unet lora weights are not found in the adaface ckpt, we give up on loading unet attn processors.
371
+ if not unet_lora_weight_found:
372
+ print(f"LoRA weights not found in {self.adaface_ckpt_paths}.")
373
+ return unet
374
 
375
+ self.outfeat_capture_blocks = []
376
+
377
+ if isinstance(unet, UNetEnsemble):
378
+ for i, unet_ in enumerate(unet.unets):
379
+ unet_ = self.load_unet_loras(unet_, unet_lora_modules_state_dict,
380
+ use_attn_lora=use_attn_lora,
381
+ attn_lora_layer_names=attn_lora_layer_names,
382
+ shrink_cross_attn=shrink_cross_attn,
383
+ q_lora_updates_query=q_lora_updates_query)
384
+ unet.unets[i] = unet_
385
+ print(f"Loaded LoRA processors on UNetEnsemble of {len(unet.unets)} UNets.")
386
+ else:
387
+ unet = self.load_unet_loras(unet, unet_lora_modules_state_dict,
388
+ use_attn_lora=use_attn_lora,
389
+ attn_lora_layer_names=attn_lora_layer_names,
390
+ shrink_cross_attn=shrink_cross_attn,
391
+ q_lora_updates_query=q_lora_updates_query)
392
+
393
+ return unet
394
+
395
  def extend_tokenizer_and_text_encoder(self):
396
  if np.sum(self.encoders_num_id_vecs) < 1:
397
  raise ValueError(f"encoders_num_id_vecs has to be larger or equal to 1, but is {self.encoders_num_id_vecs}")
 
401
  # We add z_0_0, z_0_1, z_0_2, ..., z_0_15, z_1_0, z_1_1, z_1_2, z_1_3 to the tokenizer.
402
  self.all_placeholder_tokens = []
403
  self.placeholder_tokens_strs = []
404
+ self.encoder_placeholder_tokens = []
405
  for i in range(len(self.adaface_encoder_types)):
406
  placeholder_tokens = []
407
  for j in range(self.encoders_num_id_vecs[i]):
 
409
  placeholder_tokens_str = " ".join(placeholder_tokens)
410
 
411
  self.all_placeholder_tokens.extend(placeholder_tokens)
412
+ self.encoder_placeholder_tokens.append(placeholder_tokens)
413
  self.placeholder_tokens_strs.append(placeholder_tokens_str)
414
 
415
  self.all_placeholder_tokens_str = " ".join(self.placeholder_tokens_strs)
416
+ self.updated_tokens_str = self.all_placeholder_tokens_str
417
  # all_null_placeholder_tokens_str: ", , , , ..." (20 times).
418
  # It just contains the commas and spaces with the same length, but no actual tokens.
419
  self.all_null_placeholder_tokens_str = " ".join([", "] * len(self.all_placeholder_tokens))
 
427
 
428
  print(f"Added {num_added_tokens} tokens ({self.all_placeholder_tokens_str}) to the tokenizer.")
429
 
430
+ # placeholder_token_ids: [49408, ..., 49427].
431
  self.placeholder_token_ids = tokenizer.convert_tokens_to_ids(self.all_placeholder_tokens)
432
  #print("New tokens:", self.placeholder_token_ids)
433
  # Resize the token embeddings as we are adding new special tokens to the tokenizer
 
438
 
439
  # Extend pipeline.text_encoder with the adaface subject emeddings.
440
  # subj_embs: [16, 768].
441
+ def update_text_encoder_subj_embeddings(self, subj_embs, lens_subj_emb_segments):
442
  # Initialise the newly added placeholder token with the embeddings of the initializer token
443
  # token_embeds: [49412, 768]
444
  token_embeds = self.pipeline.text_encoder.get_input_embeddings().weight.data
445
+ all_encoders_updated_tokens = []
446
+ all_encoders_updated_token_strs = []
447
+ idx = 0
448
+
449
  with torch.no_grad():
450
+ # sum of lens_subj_emb_segments are probably shorter than self.placeholder_token_ids,
451
+ # when some static_img_suffix_embs are disabled.
452
+ for i, encoder_type in enumerate(self.adaface_encoder_types):
453
+ encoder_updated_tokens = []
454
+ if (self.enabled_encoders is not None) and (encoder_type not in self.enabled_encoders):
455
+ idx += lens_subj_emb_segments[i]
456
+ continue
457
+ for j in range(lens_subj_emb_segments[i]):
458
+ placeholder_token = f"{self.subject_string}_{i}_{j}"
459
+ token_id = self.pipeline.tokenizer.convert_tokens_to_ids(placeholder_token)
460
+ token_embeds[token_id] = subj_embs[idx]
461
+ encoder_updated_tokens.append(placeholder_token)
462
+ idx += 1
463
+
464
+ all_encoders_updated_tokens.extend(encoder_updated_tokens)
465
+ all_encoders_updated_token_strs.append(" ".join(encoder_updated_tokens))
466
+
467
+ self.updated_tokens_str = " ".join(all_encoders_updated_token_strs)
468
+ self.all_encoders_updated_token_strs = all_encoders_updated_token_strs
469
+ print(f"Updated {len(all_encoders_updated_tokens)} tokens ({self.updated_tokens_str}) in the text encoder.")
470
 
471
  def update_prompt(self, prompt, placeholder_tokens_pos='append',
472
+ repeat_prompt_for_each_encoder=True,
473
  use_null_placeholders=False):
474
  if prompt is None:
475
  prompt = ""
476
 
477
  if use_null_placeholders:
478
  all_placeholder_tokens_str = self.all_null_placeholder_tokens_str
479
+ if not re.search(r"\b(man|woman|person|child|girl|boy)\b", prompt.lower()):
480
+ all_placeholder_tokens_str = "person " + all_placeholder_tokens_str
481
+ repeat_prompt_for_each_encoder = False
482
  else:
483
+ all_placeholder_tokens_str = self.updated_tokens_str
484
 
485
  # Delete the subject_string from the prompt.
486
  prompt = re.sub(r'\b(a|an|the)\s+' + self.subject_string + r'\b,?', "", prompt)
 
490
  # When we do joint training, seems both work better if they are appended to the prompt.
491
  # Therefore we simply appended all placeholder_tokens_str's to the prompt.
492
  # NOTE: Prepending them hurts compositional prompts.
493
+ if repeat_prompt_for_each_encoder:
494
+ encoder_prompts = []
495
+ for encoder_updated_token_strs in self.all_encoders_updated_token_strs:
496
+ if placeholder_tokens_pos == 'prepend':
497
+ encoder_prompt = encoder_updated_token_strs + " " + prompt
498
+ elif placeholder_tokens_pos == 'append':
499
+ encoder_prompt = prompt + " " + encoder_updated_token_strs
500
+ else:
501
+ breakpoint()
502
+ encoder_prompts.append(encoder_prompt)
503
+ prompt = ", ".join(encoder_prompts)
504
  else:
505
+ if placeholder_tokens_pos == 'prepend':
506
+ prompt = all_placeholder_tokens_str + " " + prompt
507
+ elif placeholder_tokens_pos == 'append':
508
+ prompt = prompt + " " + all_placeholder_tokens_str
509
+ else:
510
+ breakpoint()
511
 
512
  return prompt
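A standalone sketch of the per-encoder prompt expansion performed above (placeholder strings are hypothetical): the prompt is repeated once per enabled encoder, each copy gets that encoder's placeholder tokens, and the copies are joined with commas.

def expand_prompt(prompt, encoder_token_strs, pos="append"):
    encoder_prompts = []
    for token_str in encoder_token_strs:
        if pos == "prepend":
            encoder_prompts.append(token_str + " " + prompt)
        else:
            encoder_prompts.append(prompt + " " + token_str)
    return ", ".join(encoder_prompts)

print(expand_prompt("a portrait at the beach", ["z_0_0 z_0_1", "z_1_0 z_1_1 z_1_2"]))
# a portrait at the beach z_0_0 z_0_1, a portrait at the beach z_1_0 z_1_1 z_1_2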
513
 
514
+ # NOTE: all_adaface_subj_embs is the input to the CLIP text encoder.
515
+ # ** DO NOT use it as prompt_embeds in the forward() method.
516
  # If face_id_embs is None, then it extracts face_id_embs from the images,
517
  # then map them to ada prompt embeddings.
518
  # avg_at_stage: 'id_emb', 'img_prompt_emb', or None.
 
523
  perturb_at_stage=None, # id_emb, img_prompt_emb, or None.
524
  perturb_std=0, update_text_encoder=True):
525
 
526
+ all_adaface_subj_embs, img_prompt_embs, lens_subj_emb_segments = \
527
  self.id2ada_prompt_encoder.generate_adaface_embeddings(\
528
  image_paths, face_id_embs=face_id_embs,
529
  img_prompt_embs=None,
530
  avg_at_stage=avg_at_stage,
531
  perturb_at_stage=perturb_at_stage,
532
  perturb_std=perturb_std,
533
+ enable_static_img_suffix_embs=self.enable_static_img_suffix_embs)
534
 
535
  if all_adaface_subj_embs is None:
536
  return None
537
 
538
+ self.img_prompt_embs = img_prompt_embs
539
+
540
  if all_adaface_subj_embs.ndim == 4:
541
+ # [1, 1, 20, 768] -> [20, 768]
542
  all_adaface_subj_embs = all_adaface_subj_embs.squeeze(0).squeeze(0)
543
  elif all_adaface_subj_embs.ndim == 3:
544
+ # [1, 20, 768] -> [20, 768]
545
  all_adaface_subj_embs = all_adaface_subj_embs.squeeze(0)
546
 
547
  if update_text_encoder:
548
+ self.update_text_encoder_subj_embeddings(all_adaface_subj_embs, lens_subj_emb_segments)
549
  return all_adaface_subj_embs
550
 
551
  def diffusers_encode_prompts(self, prompt, plain_prompt, negative_prompt, device):
 
595
  else:
596
  breakpoint()
597
  else:
598
+ # "text2img" and "img2img" pipelines.
599
  # prompt_embeds_, negative_prompt_embeds_: [1, 77, 768]
600
  prompt_embeds_, negative_prompt_embeds_ = \
601
  self.pipeline.encode_prompt(prompt, device=device,
 
606
  return prompt_embeds_, negative_prompt_embeds_, \
607
  pooled_prompt_embeds_, negative_pooled_prompt_embeds_
608
 
609
+ # alt_prompt_embed_type: 'ada-nonmix', 'img'
610
+ def mix_ada_embs_with_other_embs(self, prompt, prompt_embeds,
611
+ alt_prompt_embed_type, alt_prompt_emb_weights):
612
+ # Scan prompt and replace tokens in self.placeholder_token_ids
613
+ # with the corresponding image embeddings.
614
+ prompt_tokens = self.pipeline.tokenizer.tokenize(prompt)
615
+ prompt_embeds2 = prompt_embeds.clone()
616
+ if alt_prompt_embed_type == 'img':
617
+ if self.img_prompt_embs is None:
618
+ print("Unable to find img_prompt_embs. Either prepare_adaface_embeddings() hasn't been called, or faceless images were used.")
619
+ return prompt_embeds
620
+ # self.img_prompt_embs: [1, 20, 768]
621
+ repl_embeddings = self.img_prompt_embs
622
+ elif alt_prompt_embed_type == 'ada-nonmix':
623
+ repl_embeddings_, _, _, _ = self.encode_prompt(prompt, ablate_prompt_only_placeholders=True,
624
+ verbose=True)
625
+ # repl_embeddings_: [1, 77, 768] -> [1, 20, 768]
626
+ repl_embeddings = repl_embeddings_[:, 1:len(self.all_placeholder_tokens)+1]
627
+ else:
628
+ breakpoint()
629
+
630
+ repl_tokens = {}
631
+ for i in range(len(prompt_tokens)):
632
+ if prompt_tokens[i] in self.all_placeholder_tokens:
633
+ encoder_idx = next((i for i, sublist in enumerate(self.encoder_placeholder_tokens) \
634
+ if prompt_tokens[i] in sublist), 0)
635
+ alt_prompt_emb_weight = alt_prompt_emb_weights[encoder_idx]
636
+ prompt_embeds2[:, i] = prompt_embeds2[:, i] * (1 - alt_prompt_emb_weight) \
637
+ + repl_embeddings[:, self.all_placeholder_tokens.index(prompt_tokens[i])] * alt_prompt_emb_weight
638
+ repl_tokens[prompt_tokens[i]] = 1
639
+
640
+ repl_token_count = len(repl_tokens)
641
+ if np.all(np.array(alt_prompt_emb_weights) == 1):
642
+ print(f"Replaced {repl_token_count} tokens with {alt_prompt_embed_type} embeddings.")
643
+ else:
644
+ print(f"Mixed {repl_token_count} tokens with {alt_prompt_embed_type} embeddings, weight {alt_prompt_emb_weights}.")
645
+
646
+ return prompt_embeds2
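The blending rule above is a per-position linear interpolation between the encoded prompt embeddings and the replacement embeddings. A minimal sketch with dummy tensors; the positions and weight are illustrative:

import torch

prompt_embeds = torch.randn(1, 77, 768)      # encoded prompt
repl_embeddings = torch.randn(1, 20, 768)    # img or ada-nonmix embeddings
placeholder_positions = list(range(5, 25))   # hypothetical positions of the 20 placeholder tokens
w = 0.5                                      # alt_prompt_emb_weight

mixed = prompt_embeds.clone()
for j, pos in enumerate(placeholder_positions):
    mixed[:, pos] = prompt_embeds[:, pos] * (1 - w) + repl_embeddings[:, j] * w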
647
+
648
+
649
  def encode_prompt(self, prompt, negative_prompt=None,
650
  placeholder_tokens_pos='append',
651
+ ablate_prompt_only_placeholders=False,
652
+ ablate_prompt_no_placeholders=False,
653
+ ablate_prompt_embed_type='ada', # 'ada', 'ada-nonmix', 'img'
654
+ nonmix_prompt_emb_weight=0,
655
+ repeat_prompt_for_each_encoder=True,
656
  device=None, verbose=False):
657
  if negative_prompt is None:
658
  negative_prompt = self.negative_prompt
 
661
  device = self.device
662
 
663
  plain_prompt = prompt
664
+ if ablate_prompt_only_placeholders:
665
+ prompt = self.updated_tokens_str
666
+ else:
667
+ prompt = self.update_prompt(prompt, placeholder_tokens_pos=placeholder_tokens_pos,
668
+ repeat_prompt_for_each_encoder=repeat_prompt_for_each_encoder,
669
+ use_null_placeholders=ablate_prompt_no_placeholders)
670
+
671
  if verbose:
672
  print(f"Subject prompt:\n{prompt}")
673
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
674
  # For some unknown reason, the text_encoder is still on CPU after self.pipeline.to(self.device).
675
  # So we manually move it to GPU here.
676
  self.pipeline.text_encoder.to(device)
677
 
678
  prompt_embeds_, negative_prompt_embeds_, pooled_prompt_embeds_, negative_pooled_prompt_embeds_ = \
679
  self.diffusers_encode_prompts(prompt, plain_prompt, negative_prompt, device)
680
+
681
+ if ablate_prompt_embed_type != 'ada':
682
+ alt_prompt_embed_type = ablate_prompt_embed_type
683
+ alt_prompt_emb_weights = (1, 1)
684
+ elif nonmix_prompt_emb_weight > 0:
685
+ alt_prompt_embed_type = 'ada-nonmix'
686
+ alt_prompt_emb_weights = (nonmix_prompt_emb_weight, nonmix_prompt_emb_weight)
687
+ else:
688
+ alt_prompt_emb_weights = (0, 0)
689
+
690
+ if sum(alt_prompt_emb_weights) > 0:
691
+ prompt_embeds_ = self.mix_ada_embs_with_other_embs(prompt, prompt_embeds_,
692
+ alt_prompt_embed_type, alt_prompt_emb_weights)
693
+
694
  return prompt_embeds_, negative_prompt_embeds_, pooled_prompt_embeds_, negative_pooled_prompt_embeds_
695
 
696
  # ref_img_strength is used only in the img2img pipeline.
697
+ def forward(self, noise, prompt, prompt_embeds=None, negative_prompt=None,
698
  placeholder_tokens_pos='append',
 
699
  guidance_scale=6.0, out_image_count=4,
700
+ ref_img_strength=0.8, generator=None,
701
+ ablate_prompt_only_placeholders=False,
702
+ ablate_prompt_no_placeholders=False,
703
+ ablate_prompt_embed_type='ada', # 'ada', 'ada-nonmix', 'img'
704
+ nonmix_prompt_emb_weight=0,
705
+ repeat_prompt_for_each_encoder=True,
706
+ verbose=False):
707
  noise = noise.to(device=self.device, dtype=torch.float16)
708
+ if self.use_lcm:
709
+ guidance_scale = 0
710
 
711
  if negative_prompt is None:
712
  negative_prompt = self.negative_prompt
713
  # prompt_embeds_, negative_prompt_embeds_: [1, 77, 768]
714
+ if prompt_embeds is None:
715
+ prompt_embeds_, negative_prompt_embeds_, pooled_prompt_embeds_, \
716
+ negative_pooled_prompt_embeds_ = \
717
+ self.encode_prompt(prompt, negative_prompt,
718
+ placeholder_tokens_pos=placeholder_tokens_pos,
719
+ ablate_prompt_only_placeholders=ablate_prompt_only_placeholders,
720
+ ablate_prompt_no_placeholders=ablate_prompt_no_placeholders,
721
+ ablate_prompt_embed_type=ablate_prompt_embed_type,
722
+ nonmix_prompt_emb_weight=nonmix_prompt_emb_weight,
723
+ repeat_prompt_for_each_encoder=repeat_prompt_for_each_encoder,
724
+ device=self.device,
725
+ verbose=verbose)
726
+ else:
727
+ if len(prompt_embeds) == 2:
728
+ prompt_embeds_, negative_prompt_embeds_ = prompt_embeds
729
+ pooled_prompt_embeds_, negative_pooled_prompt_embeds_ = None, None
730
+ elif len(prompt_embeds) == 4:
731
+ prompt_embeds_, negative_prompt_embeds_, pooled_prompt_embeds_, \
732
+ negative_pooled_prompt_embeds_ = prompt_embeds
733
+ else:
734
+ breakpoint()
735
+
736
  # Repeat the prompt embeddings for all images in the batch.
737
  prompt_embeds_ = prompt_embeds_.repeat(out_image_count, 1, 1)
738
+
739
  if negative_prompt_embeds_ is not None:
740
  negative_prompt_embeds_ = negative_prompt_embeds_.repeat(out_image_count, 1, 1)
741
 
adaface/diffusers_attn_lora_capture.py ADDED
@@ -0,0 +1,656 @@
1
+ import torch
2
+ import torch.nn as nn
3
+ import torch.nn.functional as F
4
+ from typing import Optional, Tuple, Dict, Any
5
+ from diffusers.models.attention_processor import Attention, AttnProcessor2_0
6
+ from diffusers.utils import logging, is_torch_version, deprecate
7
+ from diffusers.utils.torch_utils import fourier_filter
8
+ # UNet is a diffusers PeftAdapterMixin instance.
9
+ from diffusers.loaders.peft import PeftAdapterMixin
10
+ from peft import LoraConfig, get_peft_model
11
+ import peft.tuners.lora as peft_lora
12
+ from peft.tuners.lora.dora import DoraLinearLayer
13
+ from einops import rearrange
14
+ import math, re
15
+ import numpy as np
16
+ from peft.tuners.tuners_utils import BaseTunerLayer
17
+
18
+
19
+ logger = logging.get_logger(__name__) # pylint: disable=invalid-name
20
+
21
+ def dummy_func(*args, **kwargs):
22
+ pass
23
+
24
+ # Revised from RevGrad, by removing the grad negation.
25
+ class ScaleGrad(torch.autograd.Function):
26
+ @staticmethod
27
+ def forward(ctx, input_, alpha_, debug=False):
28
+ ctx.save_for_backward(alpha_, debug)
29
+ output = input_
30
+ if debug:
31
+ print(f"input: {input_.abs().mean().item()}")
32
+ return output
33
+
34
+ @staticmethod
35
+ def backward(ctx, grad_output): # pragma: no cover
36
+ # saved_tensors returns a tuple of tensors.
37
+ alpha_, debug = ctx.saved_tensors
38
+ if ctx.needs_input_grad[0]:
39
+ grad_output2 = grad_output * alpha_
40
+ if debug:
41
+ print(f"grad_output2: {grad_output2.abs().mean().item()}")
42
+ else:
43
+ grad_output2 = None
44
+ return grad_output2, None, None
45
+
46
+ class GradientScaler(nn.Module):
47
+ def __init__(self, alpha=1., debug=False, *args, **kwargs):
48
+ """
49
+ A gradient scaling layer.
50
+ This layer has no parameters, and simply scales the gradient in the backward pass.
51
+ """
52
+ super().__init__(*args, **kwargs)
53
+
54
+ self._alpha = torch.tensor(alpha, requires_grad=False)
55
+ self._debug = torch.tensor(debug, requires_grad=False)
56
+
57
+ def forward(self, input_):
58
+ _debug = self._debug if hasattr(self, '_debug') else False
59
+ return ScaleGrad.apply(input_, self._alpha.to(input_.device), _debug)
60
+
61
+ def gen_gradient_scaler(alpha, debug=False):
62
+ if alpha == 1:
63
+ return nn.Identity()
64
+ if alpha > 0:
65
+ return GradientScaler(alpha, debug=debug)
66
+ else:
67
+ assert alpha == 0
68
+ # Don't use lambda function here, otherwise the object can't be pickled.
69
+ return torch.detach
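A small usage sketch of the scaler (with the definitions above in scope): the forward pass is an identity, and gradients flowing back through it are multiplied by alpha.

import torch

scaler = GradientScaler(alpha=0.2)
x = torch.ones(3, requires_grad=True)
loss = (scaler(x) * 2).sum()
loss.backward()
print(x.grad)   # tensor([0.4000, 0.4000, 0.4000]) instead of [2., 2., 2.]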
70
+
71
+ def split_indices_by_instance(indices, as_dict=False):
72
+ indices_B, indices_N = indices
73
+ unique_indices_B = torch.unique(indices_B)
74
+ if not as_dict:
75
+ indices_by_instance = [ (indices_B[indices_B == uib], indices_N[indices_B == uib]) for uib in unique_indices_B ]
76
+ else:
77
+ indices_by_instance = { uib.item(): indices_N[indices_B == uib] for uib in unique_indices_B }
78
+ return indices_by_instance
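For example, with the definition above in scope, the helper groups flat (batch, token) index pairs by batch instance:

import torch

indices = (torch.tensor([0, 0, 1, 1, 1]), torch.tensor([5, 6, 5, 6, 7]))
print(split_indices_by_instance(indices, as_dict=True))
# {0: tensor([5, 6]), 1: tensor([5, 6, 7])}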
79
+
80
+ # If do_sum, returned emb_attns is 3D. Otherwise 4D.
81
+ # indices are applied on the first 2 dims of attn_mat.
82
+ def sel_emb_attns_by_indices(attn_mat, indices, all_token_weights=None, do_sum=True, do_mean=False):
83
+ indices_by_instance = split_indices_by_instance(indices)
84
+
85
+ # emb_attns[0]: [1, 9, 8, 64]
86
+ # 8: 8 attention heads. Last dim 64: number of image tokens.
87
+ emb_attns = [ attn_mat[inst_indices].unsqueeze(0) for inst_indices in indices_by_instance ]
88
+ if all_token_weights is not None:
89
+ # all_token_weights: [4, 77].
90
+ # token_weights_by_instance[0]: [1, 9, 1, 1].
91
+ token_weights = [ all_token_weights[inst_indices].reshape(1, -1, 1, 1) for inst_indices in indices_by_instance ]
92
+ else:
93
+ token_weights = [ 1 ] * len(indices_by_instance)
94
+
95
+ # Apply token weights.
96
+ emb_attns = [ emb_attns[i] * token_weights[i] for i in range(len(indices_by_instance)) ]
97
+
98
+ # sum among K_subj_i subj embeddings -> [1, 8, 64]
99
+ if do_sum:
100
+ emb_attns = [ emb_attns[i].sum(dim=1) for i in range(len(indices_by_instance)) ]
101
+ elif do_mean:
102
+ emb_attns = [ emb_attns[i].mean(dim=1) for i in range(len(indices_by_instance)) ]
103
+
104
+ emb_attns = torch.cat(emb_attns, dim=0)
105
+ return emb_attns
106
+
107
+ # Slow implementation equivalent to F.scaled_dot_product_attention.
108
+ def scaled_dot_product_attention(query, key, value, attn_mask=None, dropout_p=0.0,
109
+ shrink_cross_attn=False, cross_attn_shrink_factor=0.5,
110
+ is_causal=False, scale=None, enable_gqa=False) -> torch.Tensor:
111
+ B, L, S = query.size(0), query.size(-2), key.size(-2)
112
+ scale_factor = 1 / math.sqrt(query.size(-1)) if scale is None else scale
113
+ # 1: head (to be broadcasted). L: query length. S: key length.
114
+ attn_bias = torch.zeros(B, 1, L, S, device=query.device, dtype=query.dtype)
115
+ if is_causal:
116
+ assert attn_mask is None
117
+ temp_mask = torch.ones(B, 1, L, S, device=query.device, dtype=torch.bool).tril(diagonal=0)
118
+ attn_bias.masked_fill_(temp_mask.logical_not(), float("-inf"))
119
+ attn_bias.to(query.dtype)
120
+
121
+ if attn_mask is not None:
122
+ if attn_mask.dtype == torch.bool:
123
+ attn_bias.masked_fill_(attn_mask.logical_not(), float("-inf"))
124
+ else:
125
+ attn_bias += attn_mask
126
+
127
+ if enable_gqa:
128
+ key = key.repeat_interleave(query.size(-3)//key.size(-3), -3)
129
+ value = value.repeat_interleave(query.size(-3)//value.size(-3), -3)
130
+
131
+ attn_weight = query @ key.transpose(-2, -1) * scale_factor
132
+
133
+ if shrink_cross_attn:
134
+ cross_attn_scale = cross_attn_shrink_factor
135
+ else:
136
+ cross_attn_scale = 1
137
+
138
+ # attn_bias: [1, 1, 4096, 77], the same size as a single-head attn_weight.
139
+ attn_weight += attn_bias
140
+ attn_score = attn_weight
141
+ attn_weight = torch.softmax(attn_weight, dim=-1)
142
+ # NOTE: After scaling, the "probabilities" of the subject embeddings will sum to < 1.
143
+ # But this is intended, as we want to scale down the impact of the subject embeddings
144
+ # in the computed attention output tensors.
145
+ attn_weight = attn_weight * cross_attn_scale
146
+ attn_weight = torch.dropout(attn_weight, dropout_p, train=True)
147
+ output = attn_weight @ value
148
+ return output, attn_score, attn_weight
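A quick sanity-check sketch (with the function above in scope): when shrink_cross_attn is False, the slow path should agree with PyTorch's fused implementation up to numerical tolerance.

import torch
import torch.nn.functional as F

q = torch.randn(2, 8, 16, 40)
k = torch.randn(2, 8, 77, 40)
v = torch.randn(2, 8, 77, 40)

out_slow, _, _ = scaled_dot_product_attention(q, k, v, shrink_cross_attn=False)
out_fast = F.scaled_dot_product_attention(q, k, v)
print(torch.allclose(out_slow, out_fast, atol=1e-4))   # expected: True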
149
+
150
+ # All layers share the same attention processor instance.
151
+ class AttnProcessor_LoRA_Capture(nn.Module):
152
+ r"""
153
+ Revised from AttnProcessor2_0
154
+ """
155
+ # lora_proj_layers is a dict of lora_layer_name -> lora_proj_layer.
156
+ def __init__(self, capture_ca_activations: bool = False, enable_lora: bool = False,
157
+ lora_uses_dora=True, lora_proj_layers=None,
158
+ lora_rank: int = 192, lora_alpha: float = 16,
159
+ cross_attn_shrink_factor: float = 0.5,
160
+ q_lora_updates_query=False, attn_proc_idx=-1):
161
+ super().__init__()
162
+
163
+ self.global_enable_lora = enable_lora
164
+ self.attn_proc_idx = attn_proc_idx
165
+ # reset_attn_cache_and_flags() sets the local (call-specific) self.enable_lora flag.
166
+ # By default, shrink_cross_attn is False. Later in layers 22, 23, 24 it will be set to True.
167
+ self.reset_attn_cache_and_flags(capture_ca_activations, False, enable_lora)
168
+ self.lora_rank = lora_rank
169
+ self.lora_alpha = lora_alpha
170
+ self.lora_scale = self.lora_alpha / self.lora_rank
171
+ self.cross_attn_shrink_factor = cross_attn_shrink_factor
172
+ self.q_lora_updates_query = q_lora_updates_query
173
+
174
+ self.to_q_lora = self.to_k_lora = self.to_v_lora = self.to_out_lora = None
175
+ if self.global_enable_lora:
176
+ for lora_layer_name, lora_proj_layer in lora_proj_layers.items():
177
+ if lora_layer_name == 'q':
178
+ self.to_q_lora = peft_lora.Linear(lora_proj_layer, 'default', r=lora_rank, lora_alpha=lora_alpha,
179
+ use_dora=lora_uses_dora, lora_dropout=0.1)
180
+ elif lora_layer_name == 'k':
181
+ self.to_k_lora = peft_lora.Linear(lora_proj_layer, 'default', r=lora_rank, lora_alpha=lora_alpha,
182
+ use_dora=lora_uses_dora, lora_dropout=0.1)
183
+ elif lora_layer_name == 'v':
184
+ self.to_v_lora = peft_lora.Linear(lora_proj_layer, 'default', r=lora_rank, lora_alpha=lora_alpha,
185
+ use_dora=lora_uses_dora, lora_dropout=0.1)
186
+ elif lora_layer_name == 'out':
187
+ self.to_out_lora = peft_lora.Linear(lora_proj_layer, 'default', r=lora_rank, lora_alpha=lora_alpha,
188
+ use_dora=lora_uses_dora, lora_dropout=0.1)
189
+
190
+ # LoRA layers can be enabled/disabled dynamically.
191
+ def reset_attn_cache_and_flags(self, capture_ca_activations, shrink_cross_attn, enable_lora):
192
+ self.capture_ca_activations = capture_ca_activations
193
+ self.shrink_cross_attn = shrink_cross_attn
194
+ self.cached_activations = {}
195
+ # Only enable LoRA for the next call(s) if global_enable_lora is set to True.
196
+ self.enable_lora = enable_lora and self.global_enable_lora
197
+
198
+ def __call__(
199
+ self,
200
+ attn: Attention,
201
+ hidden_states: torch.Tensor,
202
+ encoder_hidden_states: Optional[torch.Tensor] = None,
203
+ attention_mask: Optional[torch.Tensor] = None,
204
+ temb: Optional[torch.Tensor] = None,
205
+ img_mask: Optional[torch.Tensor] = None,
206
+ subj_indices: Optional[Tuple[torch.IntTensor, torch.IntTensor]] = None,
207
+ debug: bool = False,
208
+ *args,
209
+ **kwargs,
210
+ ) -> torch.Tensor:
211
+ if len(args) > 0 or kwargs.get("scale", None) is not None:
212
+ deprecation_message = "The `scale` argument is deprecated and will be ignored. Please remove it, as passing it will raise an error in the future. `scale` should directly be passed while calling the underlying pipeline component i.e., via `cross_attention_kwargs`."
213
+ deprecate("scale", "1.0.0", deprecation_message)
214
+
215
+ # hidden_states: [1, 4096, 320]
216
+ residual = hidden_states
217
+ # attn.spatial_norm is None.
218
+ if attn.spatial_norm is not None:
219
+ hidden_states = attn.spatial_norm(hidden_states, temb)
220
+
221
+ input_ndim = hidden_states.ndim
222
+
223
+ if input_ndim == 4:
224
+ batch_size, channel, height, width = hidden_states.shape
225
+ # Collapse the spatial dimensions to a single token dimension.
226
+ hidden_states = hidden_states.view(batch_size, channel, height * width).transpose(1, 2)
227
+
228
+ batch_size, sequence_length, _ = (
229
+ hidden_states.shape if encoder_hidden_states is None else encoder_hidden_states.shape
230
+ )
231
+
232
+ if attention_mask is not None:
233
+ attention_mask = attn.prepare_attention_mask(attention_mask, sequence_length, batch_size)
234
+ # scaled_dot_product_attention expects attention_mask shape to be
235
+ # (batch, heads, source_length, target_length)
236
+ attention_mask = attention_mask.view(batch_size, attn.heads, -1, attention_mask.shape[-1])
237
+
238
+ if attn.group_norm is not None:
239
+ hidden_states = attn.group_norm(hidden_states.transpose(1, 2)).transpose(1, 2)
240
+
241
+ query = attn.to_q(hidden_states)
242
+ # NOTE: there's an inconsistency between q lora and k, v loras.
243
+ # k, v loras are directly applied to key and value (currently k, v loras are never enabled),
244
+ # while q lora is applied to query2, and we keep the query unchanged.
245
+ if self.enable_lora and self.to_q_lora is not None:
246
+ # query2 will be used in ldm/util.py:calc_elastic_matching_loss() to get more accurate
247
+ # cross attention scores between the latent images of the sc and mc instances.
248
+ query2 = self.to_q_lora(hidden_states)
249
+ # If not q_lora_updates_query, only query2 will be impacted by the LoRA layer.
250
+ # The query, and thus the attention score and attn_out, will be the same
251
+ # as the original ones.
252
+ if self.q_lora_updates_query:
253
+ query = query2
254
+ else:
255
+ query2 = query
256
+
257
+ scale = 1 / math.sqrt(query.size(-1))
258
+
259
+ is_cross_attn = (encoder_hidden_states is not None)
260
+ if (not is_cross_attn) and (img_mask is not None):
261
+ # NOTE: we assume the image is square. But this will fail if the image is not square.
262
+ # hidden_states: [BS, 4096, 320]. img_mask: [BS, 1, 64, 64]
263
+ # Scale the mask to the same size as hidden_states.
264
+ mask_size = int(math.sqrt(hidden_states.shape[-2]))
265
+ img_mask = F.interpolate(img_mask, size=(mask_size, mask_size), mode='nearest')
266
+ if (img_mask.sum(dim=(2, 3)) == 0).any():
267
+ img_mask = None
268
+ else:
269
+ # img_mask: [2, 1, 64, 64] -> [2, 4096]
270
+ img_mask = rearrange(img_mask, 'b ... -> b (...)').contiguous()
271
+ # max_neg_value = -torch.finfo(hidden_states.dtype).max
272
+ # img_mask: [2, 4096] -> [2, 1, 1, 4096]
273
+ img_mask = rearrange(img_mask.bool(), 'b j -> b () () j')
274
+ # attn_score: [16, 4096, 4096]. img_mask will be broadcasted to [16, 4096, 4096].
275
+ # So some rows in dim 1 (e.g. [0, :, 4095]) of attn_score will be masked out (all elements in [0, :, 4095] are -inf).
276
+ # But not all elements in [0, 4095, :] are -inf. Since the softmax is done along dim 2, this is fine.
277
+ # attn_score.masked_fill_(~img_mask, max_neg_value)
278
+ # NOTE: If there's an attention mask, it will be replaced by img_mask.
279
+ attention_mask = img_mask
280
+
281
+ if encoder_hidden_states is None:
282
+ encoder_hidden_states = hidden_states
283
+ elif attn.norm_cross:
284
+ encoder_hidden_states = attn.norm_encoder_hidden_states(encoder_hidden_states)
285
+
286
+ if self.enable_lora and self.to_k_lora is not None:
287
+ key = self.to_k_lora(encoder_hidden_states)
288
+ else:
289
+ key = attn.to_k(encoder_hidden_states)
290
+
291
+ if self.enable_lora and self.to_v_lora is not None:
292
+ value = self.to_v_lora(encoder_hidden_states)
293
+ else:
294
+ value = attn.to_v(encoder_hidden_states)
295
+
296
+ if attn.norm_q is not None:
297
+ query = attn.norm_q(query)
298
+ query2 = attn.norm_q(query2)
299
+ if attn.norm_k is not None:
300
+ key = attn.norm_k(key)
301
+
302
+ inner_dim = key.shape[-1]
303
+ head_dim = inner_dim // attn.heads
304
+
305
+ query = query.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
306
+ query2 = query2.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
307
+
308
+ key = key.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
309
+ value = value.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
310
+
311
+ if debug and self.attn_proc_idx >= 0:
312
+ breakpoint()
313
+
314
+ # the output of sdp = (batch, num_heads, seq_len, head_dim)
315
+ if is_cross_attn and (self.capture_ca_activations or self.shrink_cross_attn):
316
+ hidden_states, attn_score, attn_prob = \
317
+ scaled_dot_product_attention(query, key, value, attn_mask=attention_mask,
318
+ dropout_p=0.0, shrink_cross_attn=self.shrink_cross_attn,
319
+ cross_attn_shrink_factor=self.cross_attn_shrink_factor)
320
+ else:
321
+ # Use the faster implementation of scaled_dot_product_attention
322
+ # when not capturing the activations or suppressing the subject attention.
323
+ hidden_states = \
324
+ F.scaled_dot_product_attention(query, key, value, attn_mask=attention_mask, dropout_p=0.0, is_causal=False)
325
+ attn_prob = attn_score = None
326
+
327
+ hidden_states = hidden_states.transpose(1, 2).reshape(batch_size, -1, attn.heads * head_dim)
328
+ hidden_states = hidden_states.to(query.dtype)
329
+
330
+ # linear proj
331
+ if self.enable_lora and self.to_out_lora is not None:
332
+ hidden_states = self.to_out_lora(hidden_states)
333
+ else:
334
+ hidden_states = attn.to_out[0](hidden_states)
335
+
336
+ # dropout
337
+ hidden_states = attn.to_out[1](hidden_states)
338
+
339
+ if input_ndim == 4:
340
+ hidden_states = hidden_states.transpose(-1, -2).reshape(batch_size, channel, height, width)
341
+
342
+ if attn.residual_connection:
343
+ hidden_states = hidden_states + residual
344
+
345
+ hidden_states = hidden_states / attn.rescale_output_factor
346
+
347
+ if is_cross_attn and self.capture_ca_activations:
348
+ # cached q will be used in ddpm.py:calc_comp_fg_bg_preserve_loss(), in which two qs are multiplied with each other.
349
+ # So sqrt(scale) will scale the product of two qs by scale.
350
+ # ANCHOR[id=attention_caching]
351
+ # query: [2, 8, 4096, 40] -> [2, 320, 4096]
352
+ self.cached_activations['q'] = \
353
+ rearrange(query, 'b h n d -> b (h d) n').contiguous() * math.sqrt(scale)
354
+ self.cached_activations['q2'] = \
355
+ rearrange(query2, 'b h n d -> b (h d) n').contiguous() * math.sqrt(scale)
356
+ self.cached_activations['k'] = \
357
+ rearrange(key, 'b h n d -> b (h d) n').contiguous() * math.sqrt(scale)
358
+ self.cached_activations['v'] = \
359
+ rearrange(value, 'b h n d -> b (h d) n').contiguous() * math.sqrt(scale)
360
+ # attn_prob, attn_score: [2, 8, 4096, 77]
361
+ self.cached_activations['attn'] = attn_prob
362
+ self.cached_activations['attnscore'] = attn_score
363
+ # attn_out: [b, n, h * d] -> [b, h * d, n]
364
+ # [2, 4096, 320] -> [2, 320, 4096].
365
+ self.cached_activations['attn_out'] = hidden_states.permute(0, 2, 1).contiguous()
366
+
367
+ return hidden_states
368
+
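The LoRA wiring above relies on peft's Linear wrapper taking over an existing projection, as done for to_q/to_k/to_v/to_out. A minimal sketch of that wrapping; layer sizes and hyperparameters are illustrative:

import torch
import torch.nn as nn
import peft.tuners.lora as peft_lora

base = nn.Linear(320, 320, bias=False)             # stands in for attn.to_q
wrapped = peft_lora.Linear(base, 'default', r=192, lora_alpha=24,
                           use_dora=True, lora_dropout=0.1)
x = torch.randn(2, 4096, 320)
print(wrapped(x).shape)                            # torch.Size([2, 4096, 320])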
369
+ def CrossAttnUpBlock2D_forward_capture(
370
+ self,
371
+ hidden_states: torch.Tensor,
372
+ res_hidden_states_tuple: Tuple[torch.Tensor, ...],
373
+ temb: Optional[torch.Tensor] = None,
374
+ encoder_hidden_states: Optional[torch.Tensor] = None,
375
+ cross_attention_kwargs: Optional[Dict[str, Any]] = None,
376
+ upsample_size: Optional[int] = None,
377
+ attention_mask: Optional[torch.Tensor] = None,
378
+ encoder_attention_mask: Optional[torch.Tensor] = None,
379
+ ) -> torch.Tensor:
380
+ if cross_attention_kwargs is not None:
381
+ if cross_attention_kwargs.get("scale", None) is not None:
382
+ logger.warning("Passing `scale` to `cross_attention_kwargs` is deprecated. `scale` will be ignored.")
383
+
384
+ self.cached_outfeats = {}
385
+ res_hidden_states_gradscale = getattr(self, "res_hidden_states_gradscale", 1)
386
+ capture_outfeats = getattr(self, "capture_outfeats", False)
387
+ layer_idx = 0
388
+ res_grad_scaler = gen_gradient_scaler(res_hidden_states_gradscale)
389
+
390
+ for resnet, attn in zip(self.resnets, self.attentions):
391
+ # pop res hidden states
392
+ res_hidden_states = res_hidden_states_tuple[-1]
393
+ res_hidden_states_tuple = res_hidden_states_tuple[:-1]
394
+
395
+ # Scale down the magnitudes of gradients to res_hidden_states
396
+ # by res_hidden_states_gradscale=0.2, to match the scale of the cross-attn layer outputs.
397
+ res_hidden_states = res_grad_scaler(res_hidden_states)
398
+
399
+ hidden_states = torch.cat([hidden_states, res_hidden_states], dim=1)
400
+
401
+ if self.training and self.gradient_checkpointing:
402
+ def create_custom_forward(module, return_dict=None):
403
+ def custom_forward(*inputs):
404
+ if return_dict is not None:
405
+ return module(*inputs, return_dict=return_dict)
406
+ else:
407
+ return module(*inputs)
408
+
409
+ return custom_forward
410
+
411
+ ckpt_kwargs: Dict[str, Any] = {"use_reentrant": False} if is_torch_version(">=", "1.11.0") else {}
412
+ hidden_states = torch.utils.checkpoint.checkpoint(
413
+ create_custom_forward(resnet),
414
+ hidden_states,
415
+ temb,
416
+ **ckpt_kwargs,
417
+ )
418
+ hidden_states = attn(
419
+ hidden_states,
420
+ encoder_hidden_states=encoder_hidden_states,
421
+ cross_attention_kwargs=cross_attention_kwargs,
422
+ attention_mask=attention_mask,
423
+ encoder_attention_mask=encoder_attention_mask,
424
+ return_dict=False,
425
+ )[0]
426
+ else:
427
+ # resnet: ResnetBlock2D instance.
428
+ #LINK diffusers.models.resnet.ResnetBlock2D
429
+ # up_blocks.3.resnets.2.conv_shortcut is a module within ResnetBlock2D,
430
+ # it's not transforming the UNet shortcut features.
431
+ hidden_states = resnet(hidden_states, temb)
432
+ hidden_states = attn(
433
+ hidden_states,
434
+ encoder_hidden_states=encoder_hidden_states,
435
+ cross_attention_kwargs=cross_attention_kwargs,
436
+ attention_mask=attention_mask,
437
+ encoder_attention_mask=encoder_attention_mask,
438
+ return_dict=False,
439
+ )[0]
440
+
441
+ if capture_outfeats:
442
+ self.cached_outfeats[layer_idx] = hidden_states
443
+ layer_idx += 1
444
+
445
+ if self.upsamplers is not None:
446
+ for upsampler in self.upsamplers:
447
+ hidden_states = upsampler(hidden_states, upsample_size)
448
+
449
+ return hidden_states
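This function appears intended to replace the forward() of the last up block so that cached_outfeats gets populated. A sketch of how such a binding is typically done; the binding site itself is not part of this diff, and the checkpoint path is a placeholder:

import types
from diffusers import UNet2DConditionModel

sd15_path = "runwayml/stable-diffusion-v1-5"       # placeholder; any SD-1.5 checkpoint works
unet = UNet2DConditionModel.from_pretrained(sd15_path, subfolder="unet")
block = unet.up_blocks[3]                          # the CrossAttnUpBlock2D whose features are captured
block.forward = types.MethodType(CrossAttnUpBlock2D_forward_capture, block)
block.capture_outfeats = True                      # cached_outfeats is filled on the next forward pass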
450
+
451
+
452
+ # Adapted from ConsistentIDPipeline:set_ip_adapter().
453
+ # attn_lora_layer_names: candidates are subsets of ['q', 'k', 'v', 'out'].
454
+ def set_up_attn_processors(unet, use_attn_lora, attn_lora_layer_names=['q', 'k', 'v', 'out'],
455
+ lora_rank=192, lora_scale_down=8, cross_attn_shrink_factor=0.5,
456
+ q_lora_updates_query=False):
457
+ attn_procs = {}
458
+ attn_capture_procs = {}
459
+ unet_modules = dict(unet.named_modules())
460
+ attn_opt_modules = {}
461
+
462
+ attn_proc_idx = 0
463
+
464
+ for name, attn_proc in unet.attn_processors.items():
465
+ # Only capture the activations of the last 3 CA layers.
466
+ if not name.startswith("up_blocks.3"):
467
+ # Not the last 3 CA layers. Don't enable LoRA or capture activations.
468
+ # Then the layer falls back to the original attention mechanism.
469
+ # We still use AttnProcessor_LoRA_Capture, as it can handle img_mask.
470
+ attn_procs[name] = AttnProcessor_LoRA_Capture(
471
+ capture_ca_activations=False, enable_lora=False, attn_proc_idx=-1)
472
+ continue
473
+ # cross_attention_dim: 768.
474
+ cross_attention_dim = None if name.endswith("attn1.processor") else unet.config.cross_attention_dim
475
+ if cross_attention_dim is None:
476
+ # Self attention. Don't enable LoRA or capture activations.
477
+ # We replace the default attn_proc with AttnProcessor_LoRA_Capture,
478
+ # so that it can incorporate img_mask into self-attention.
479
+ attn_procs[name] = AttnProcessor_LoRA_Capture(
480
+ capture_ca_activations=False, enable_lora=False, attn_proc_idx=-1)
481
+ continue
482
+
483
+ # block_id = 3
484
+ # hidden_size: 320
485
+ # hidden_size = list(reversed(unet.config.block_out_channels))[block_id]
486
+ # 'up_blocks.3.attentions.1.transformer_blocks.0.attn2.processor' ->
487
+ # 'up_blocks.3.attentions.1.transformer_blocks.0.attn2.to_q'
488
+ lora_layer_dict = {}
489
+ lora_layer_dict['q'] = unet_modules[name[:-9] + "to_q"]
490
+ lora_layer_dict['k'] = unet_modules[name[:-9] + "to_k"]
491
+ lora_layer_dict['v'] = unet_modules[name[:-9] + "to_v"]
492
+ # to_out is a ModuleList(Linear, Dropout).
493
+ lora_layer_dict['out'] = unet_modules[name[:-9] + "to_out"][0]
494
+
495
+ lora_proj_layers = {}
496
+ # Only apply LoRA to the specified layers.
497
+ for lora_layer_name in attn_lora_layer_names:
498
+ lora_proj_layers[lora_layer_name] = lora_layer_dict[lora_layer_name]
499
+
500
+ attn_capture_proc = AttnProcessor_LoRA_Capture(
501
+ capture_ca_activations=True, enable_lora=use_attn_lora,
502
+ lora_uses_dora=True, lora_proj_layers=lora_proj_layers,
503
+ # LoRA up is initialized to 0. So no need to worry that the LoRA output may be too large.
504
+ lora_rank=lora_rank, lora_alpha=lora_rank // lora_scale_down,
505
+ cross_attn_shrink_factor=cross_attn_shrink_factor,
506
+ q_lora_updates_query=q_lora_updates_query, attn_proc_idx=attn_proc_idx)
507
+
508
+ attn_proc_idx += 1
509
+ # attn_procs has to use the original names.
510
+ attn_procs[name] = attn_capture_proc
511
+ # ModuleDict doesn't allow "." in the key.
512
+ name = name.replace(".", "_")
513
+ attn_capture_procs[name] = attn_capture_proc
514
+
515
+ if use_attn_lora:
516
+ for subname, module in attn_capture_proc.named_modules():
517
+ if isinstance(module, peft_lora.LoraLayer):
518
+ # ModuleDict doesn't allow "." in the key.
519
+ lora_path = name + "_" + subname.replace(".", "_")
520
+ attn_opt_modules[lora_path + "_lora_A"] = module.lora_A
521
+ attn_opt_modules[lora_path + "_lora_B"] = module.lora_B
522
+ # lora_uses_dora is always True, so we don't check it here.
523
+ attn_opt_modules[lora_path + "_lora_magnitude_vector"] = module.lora_magnitude_vector
524
+ # We will manage attn adapters directly. By default, LoraLayer is an instance of BaseTunerLayer,
525
+ # so according to the code logic in diffusers/loaders/peft.py,
526
+ # they will be managed by the diffusers PeftAdapterMixin instance, through the
527
+ # enable_adapters(), and set_adapter() methods.
528
+ # Therefore, we disable these calls on module.
529
+ # disable_adapters() is a property and changing it will cause exceptions.
530
+ module.enable_adapters = dummy_func
531
+ module.set_adapter = dummy_func
532
+
533
+ unet.set_attn_processor(attn_procs)
534
+
535
+ print(f"Set up {len(attn_capture_procs)} CrossAttn processors on {attn_capture_procs.keys()}.")
536
+ print(f"Set up {len(attn_opt_modules)} attn LoRA params: {attn_opt_modules.keys()}.")
537
+ return attn_capture_procs, attn_opt_modules
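A usage sketch (with the definitions above in scope; the checkpoint path is a placeholder). The returned attn_opt_modules holds only the LoRA/DoRA parameters, so it can be handed to an optimizer, e.g.:

import torch
from diffusers import UNet2DConditionModel

sd15_path = "runwayml/stable-diffusion-v1-5"       # placeholder; any SD-1.5 checkpoint works
unet = UNet2DConditionModel.from_pretrained(sd15_path, subfolder="unet")
attn_capture_procs, attn_opt_modules = set_up_attn_processors(
    unet, use_attn_lora=True, attn_lora_layer_names=['q', 'k', 'v', 'out'],
    lora_rank=192, lora_scale_down=8)
optimizer = torch.optim.AdamW(torch.nn.ModuleDict(attn_opt_modules).parameters(), lr=1e-4)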
538
+
539
+ # NOTE: cross-attn layers are included in the returned lora_modules.
540
+ def set_up_ffn_loras(unet, target_modules_pat, lora_uses_dora=False, lora_rank=192, lora_alpha=16):
541
+ # target_modules_pat = 'up_blocks.3.resnets.[12].conv[a-z0-9_]+'
542
+ # up_blocks.3.resnets.[1~2].conv1, conv2, conv_shortcut
543
+ # Cannot set to conv.+ as it will match added adapter module names, including
544
+ # up_blocks.3.resnets.1.conv1.base_layer, up_blocks.3.resnets.1.conv1.lora_dropout
545
+ if target_modules_pat is not None:
546
+ peft_config = LoraConfig(use_dora=lora_uses_dora, inference_mode=False, r=lora_rank,
547
+ lora_alpha=lora_alpha, lora_dropout=0.1,
548
+ target_modules=target_modules_pat)
549
+
550
+ # UNet is a diffusers PeftAdapterMixin instance. Using get_peft_model on it will
551
+ # cause weird errors. Instead, we directly use diffusers peft adapter methods.
552
+ unet.add_adapter(peft_config, "recon_loss")
553
+ unet.add_adapter(peft_config, "unet_distill")
554
+ unet.add_adapter(peft_config, "comp_distill")
555
+ unet.enable_adapters()
556
+
557
+ # lora_layers contain both the LoRA A and B matrices, as well as the original layers.
558
+ # lora_layers are used to set the flag, not used for optimization.
559
+ # lora_modules contain only the LoRA A and B matrices, so they are used for optimization.
560
+ # NOTE: lora_modules contain both ffn and cross-attn lora modules.
561
+ ffn_lora_layers = {}
562
+ ffn_opt_modules = {}
563
+ for name, module in unet.named_modules():
564
+ if isinstance(module, peft_lora.LoraLayer):
565
+ # We don't want to include cross-attn layers in ffn_lora_layers.
566
+ if target_modules_pat is not None and re.search(target_modules_pat, name):
567
+ ffn_lora_layers[name] = module
568
+ # ModuleDict doesn't allow "." in the key.
569
+ name = name.replace(".", "_")
570
+ # Since ModuleDict doesn't allow "." in the key, we manually collect
571
+ # the LoRA matrices in each module.
572
+ # NOTE: We cannot put every sub-module of module into lora_modules,
573
+ # as base_layer is also a sub-module of module, which we shouldn't optimize.
574
+ # Each value in ffn_opt_modules is a ModuleDict:
575
+ '''
576
+ (Pdb) ffn_opt_modules['up_blocks_3_resnets_1_conv1_lora_A']
577
+ ModuleDict(
578
+ (unet_distill): Conv2d(640, 192, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
579
+ (recon_loss): Conv2d(640, 192, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
580
+ )
581
+ '''
582
+ ffn_opt_modules[name + "_lora_A"] = module.lora_A
583
+ ffn_opt_modules[name + "_lora_B"] = module.lora_B
584
+ if lora_uses_dora:
585
+ ffn_opt_modules[name + "_lora_magnitude_vector"] = module.lora_magnitude_vector
586
+
587
+ print(f"Set up {len(ffn_lora_layers)} FFN LoRA layers: {ffn_lora_layers.keys()}.")
588
+ print(f"Set up {len(ffn_opt_modules)} FFN LoRA params: {ffn_opt_modules.keys()}.")
589
+
590
+ return ffn_lora_layers, ffn_opt_modules
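A matching usage sketch for the FFN LoRAs, reusing the target pattern quoted in the comment above; the checkpoint path is a placeholder:

from diffusers import UNet2DConditionModel

sd15_path = "runwayml/stable-diffusion-v1-5"       # placeholder; any SD-1.5 checkpoint works
unet = UNet2DConditionModel.from_pretrained(sd15_path, subfolder="unet")
ffn_lora_layers, ffn_opt_modules = set_up_ffn_loras(
    unet, target_modules_pat='up_blocks.3.resnets.[12].conv[a-z0-9_]+',
    lora_uses_dora=False, lora_rank=192, lora_alpha=16)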
591
+
592
+ def set_lora_and_capture_flags(unet, unet_lora_modules, attn_capture_procs,
593
+ outfeat_capture_blocks, res_hidden_states_gradscale_blocks,
594
+ use_attn_lora, use_ffn_lora, ffn_lora_adapter_name, capture_ca_activations,
595
+ shrink_cross_attn, res_hidden_states_gradscale):
596
+ # For attn capture procs, capture_ca_activations and use_attn_lora are set in reset_attn_cache_and_flags().
597
+ for attn_capture_proc in attn_capture_procs:
598
+ attn_capture_proc.reset_attn_cache_and_flags(capture_ca_activations, shrink_cross_attn, enable_lora=use_attn_lora)
599
+ # outfeat_capture_blocks only contains the last up block, up_blocks[3].
600
+ # It contains 3 FFN layers. We want to capture their output features.
601
+ for block in outfeat_capture_blocks:
602
+ block.capture_outfeats = capture_ca_activations
603
+
604
+ for block in res_hidden_states_gradscale_blocks:
605
+ block.res_hidden_states_gradscale = res_hidden_states_gradscale
606
+
607
+ if not use_ffn_lora:
608
+ unet.disable_adapters()
609
+ else:
610
+ # ffn_lora_adapter_name: 'recon_loss', 'unet_distill', 'comp_distill'.
611
+ if ffn_lora_adapter_name is not None:
612
+ unet.set_adapter(ffn_lora_adapter_name)
613
+ # NOTE: Don't forget to enable_adapters().
614
+ # The adapters are not enabled by default after set_adapter().
615
+ unet.enable_adapters()
616
+ else:
617
+ breakpoint()
618
+
619
+ # During training, disable_adapters() and set_adapter() will set all/inactive adapters to requires_grad=False,
620
+ # which might cause issues during DDP training.
621
+ # So we restore them to requires_grad=True.
622
+ # During test, unet_lora_modules will be passed as None, so this block will be skipped.
623
+ if unet_lora_modules is not None:
624
+ for param in unet_lora_modules.parameters():
625
+ param.requires_grad = True
626
+
627
+ def get_captured_activations(capture_ca_activations, attn_capture_procs, outfeat_capture_blocks,
628
+ captured_layer_indices=[22, 23, 24], out_dtype=torch.float32):
629
+ captured_activations = { k: {} for k in ('outfeat', 'attn', 'attnscore',
630
+ 'q', 'q2', 'k', 'v', 'attn_out') }
631
+
632
+ if not capture_ca_activations:
633
+ return captured_activations
634
+
635
+ all_cached_outfeats = []
636
+ for block in outfeat_capture_blocks:
637
+ all_cached_outfeats.append(block.cached_outfeats)
638
+ # Clear the capture flag and cached outfeats.
639
+ block.cached_outfeats = {}
640
+ block.capture_outfeats = False
641
+
642
+ for layer_idx in captured_layer_indices:
643
+ # Subtract 22 from layer_idx to match the layer index in up_blocks[3].cached_outfeats.
644
+ # 23, 24 -> 1, 2 (!! not 0, 1 !!)
645
+ internal_idx = layer_idx - 22
646
+ for k in captured_activations.keys():
647
+ if k == 'outfeat':
648
+ # Currently we only capture one block, up_blocks.3. So we hard-code the index 0.
649
+ captured_activations['outfeat'][layer_idx] = all_cached_outfeats[0][internal_idx].to(out_dtype)
650
+ else:
651
+ # internal_idx is the index of layers in up_blocks.3.
652
+ # Layers 22, 23 and 24 map to 0, 1 and 2.
653
+ cached_activations = attn_capture_procs[internal_idx].cached_activations
654
+ captured_activations[k][layer_idx] = cached_activations[k].to(out_dtype)
655
+
656
+ return captured_activations
adaface/face_id_to_ada_prompt.py CHANGED
@@ -53,6 +53,8 @@ class FaceID2AdaPrompt(nn.Module):
53
  self.text_to_image_prompt_encoder = None
54
  self.tokenizer = None
55
  self.dtype = kwargs.get('dtype', torch.float16)
 
 
56
 
57
  # Load Img2Ada SubjectBasisGenerator.
58
  self.subject_string = kwargs.get('subject_string', 'z')
@@ -73,12 +75,16 @@ class FaceID2AdaPrompt(nn.Module):
73
 
74
  self.use_clip_embs = False
75
  self.do_contrast_clip_embs_on_bg_features = False
 
 
 
 
76
  # num_id_vecs is the output embeddings of the ID2ImgPrompt module.
77
  # If there's no static image suffix embeddings, then num_id_vecs is also
78
  # the number of ada embeddings returned by the subject basis generator.
79
  # num_id_vecs will be set in each derived class.
80
  self.num_static_img_suffix_embs = kwargs.get('num_static_img_suffix_embs', 0)
81
- print(f'{self.name} Adaface uses {self.num_id_vecs} ID image embeddings and {self.num_static_img_suffix_embs} fixed image embeddings as input.')
82
 
83
  self.id_img_prompt_max_length = 77
84
  self.face_id_dim = 512
@@ -87,36 +93,35 @@ class FaceID2AdaPrompt(nn.Module):
87
  self.clip_embedding_dim = 1024
88
  self.output_dim = 768
89
 
90
- def get_id2img_learnable_modules(self):
91
- raise NotImplementedError
92
-
93
- def load_id2img_learnable_modules(self, id2img_learnable_modules_state_dict_list):
94
- id2img_prompt_encoder_learnable_modules = self.get_id2img_learnable_modules()
95
- for module, state_dict in zip(id2img_prompt_encoder_learnable_modules, id2img_learnable_modules_state_dict_list):
96
- module.load_state_dict(state_dict)
97
- print(f'{len(id2img_prompt_encoder_learnable_modules)} ID2ImgPrompt encoder modules loaded.')
98
-
99
- # init_subj_basis_generator() can only be called after the derived class is initialized,
100
- # when self.num_id_vecs, self.num_static_img_suffix_embs and self.clip_embedding_dim have been set.
101
- def init_subj_basis_generator(self):
102
  self.subj_basis_generator = \
103
- SubjBasisGenerator(num_id_vecs = self.num_id_vecs,
 
104
  num_static_img_suffix_embs = self.num_static_img_suffix_embs,
105
  bg_image_embedding_dim = self.clip_embedding_dim,
106
  output_dim = self.output_dim,
107
  placeholder_is_bg = False,
108
- prompt2token_proj_grad_scale = 1,
109
  bg_prompt_translator_has_to_out_proj=False)
110
 
111
  def load_adaface_ckpt(self, adaface_ckpt_path):
112
- ckpt = torch.load(adaface_ckpt_path, map_location='cpu')
 
 
 
113
  string_to_subj_basis_generator_dict = ckpt["string_to_subj_basis_generator_dict"]
114
  if self.subject_string not in string_to_subj_basis_generator_dict:
115
  print(f"Subject '{self.subject_string}' not found in the embedding manager.")
116
  breakpoint()
117
 
118
  ckpt_subj_basis_generator = string_to_subj_basis_generator_dict[self.subject_string]
119
- ckpt_subj_basis_generator.N_ID = self.num_id_vecs
 
 
 
 
 
120
  # Since we directly use the subject basis generator object from the ckpt,
121
  # fixing the number of static image suffix embeddings is much simpler.
122
  # Otherwise if we want to load the subject basis generator from its state_dict,
@@ -129,7 +134,7 @@ class FaceID2AdaPrompt(nn.Module):
129
  ckpt_subj_basis_generator.initialize_static_img_suffix_embs(self.num_static_img_suffix_embs, img_prompt_dim=self.output_dim)
130
  # Fix missing variables in old ckpt.
131
  ckpt_subj_basis_generator.patch_old_subj_basis_generator_ckpt()
132
-
133
  self.subj_basis_generator.extend_prompt2token_proj_attention(\
134
  ckpt_subj_basis_generator.prompt2token_proj_attention_multipliers, -1, -1, 1, perturb_std=0)
135
  ret = self.subj_basis_generator.load_state_dict(ckpt_subj_basis_generator.state_dict(), strict=False)
@@ -155,6 +160,11 @@ class FaceID2AdaPrompt(nn.Module):
155
 
156
  self.subj_basis_generator.freeze_prompt2token_proj()
157
 
 
 
 
 
 
158
  @torch.no_grad()
159
  def get_clip_neg_features(self, BS):
160
  if self.clip_neg_features is None:
@@ -220,6 +230,7 @@ class FaceID2AdaPrompt(nn.Module):
220
  image_obj, _, _ = pad_image_obj_to_square(image_obj)
221
  image_np = np.array(image_obj.resize(size, Image.NEAREST))
222
  face_info = self.face_app.get(cv2.cvtColor(image_np, cv2.COLOR_RGB2BGR))
 
223
  if len(face_info) > 0:
224
  face_info = sorted(face_info, key=lambda x:(x['bbox'][2]-x['bbox'][0])*x['bbox'][3]-x['bbox'][1])[-1] # only use the maximum face
225
  # id_emb: [512,]
@@ -487,12 +498,20 @@ class FaceID2AdaPrompt(nn.Module):
487
  # avg_at_stage == ada_prompt_emb usually produces the worst results.
488
  # avg_at_stage == id_emb is slightly better than img_prompt_emb, but sometimes img_prompt_emb is better.
489
  # p_dropout and return_zero_embs_for_dropped_encoders are only used by Joint_FaceID2AdaPrompt.
 
490
  def generate_adaface_embeddings(self, image_paths, face_id_embs=None, img_prompt_embs=None,
491
  p_dropout=0,
492
  return_zero_embs_for_dropped_encoders=True,
493
  avg_at_stage='id_emb', # id_emb, img_prompt_emb, or None.
494
  perturb_at_stage=None, # id_emb, img_prompt_emb, or None.
495
- perturb_std=0, enable_static_img_suffix_embs=False):
 
 
 
 
 
 
 
496
  if (avg_at_stage is None) or avg_at_stage.lower() == 'none':
497
  img_prompt_avg_at_stage = None
498
  else:
@@ -509,7 +528,7 @@ class FaceID2AdaPrompt(nn.Module):
509
  id_batch_size = len(image_paths)
510
  else:
511
  id_batch_size = 1
512
-
513
  # faceid_embeds: [BS, 512] is a batch of extracted face analysis embeddings. NOT used later.
514
  # NOTE: If face_id_embs, image_paths and image_objs are all None,
515
  # then get_img_prompt_embs() generates random faceid_embeds/img_prompt_embs,
@@ -532,7 +551,7 @@ class FaceID2AdaPrompt(nn.Module):
532
  verbose=True)
533
 
534
  if face_image_count == 0:
535
- return None
536
 
537
  # No matter whether avg_at_stage is id_emb or img_prompt_emb, we average img_prompt_embs.
538
  elif avg_at_stage is not None and avg_at_stage.lower() != 'none':
@@ -545,19 +564,27 @@ class FaceID2AdaPrompt(nn.Module):
545
  out_id_embs_cfg_scale=self.out_id_embs_cfg_scale,
546
  is_face=True,
547
  enable_static_img_suffix_embs=enable_static_img_suffix_embs)
 
 
 
 
548
  # During training, img_prompt_avg_at_stage is None, and BS >= 1.
549
  # During inference, img_prompt_avg_at_stage is 'id_emb' or 'img_prompt_emb', and BS == 1.
550
  if img_prompt_avg_at_stage is not None:
551
  # adaface_subj_embs: [1, 16, 768] -> [16, 768]
552
  adaface_subj_embs = adaface_subj_embs.squeeze(0)
553
 
554
- return adaface_subj_embs
555
 
556
  class Arc2Face_ID2AdaPrompt(FaceID2AdaPrompt):
557
- def __init__(self, *args, **kwargs):
558
- self.name = 'arc2face'
559
- self.num_id_vecs = 16
 
 
 
560
 
 
561
  super().__init__(*args, **kwargs)
562
 
563
  self.clip_image_encoder = CLIPVisionModelWithMask.from_pretrained('openai/clip-vit-large-patch14')
@@ -576,14 +603,11 @@ class Arc2Face_ID2AdaPrompt(FaceID2AdaPrompt):
576
  '''
577
  # Use the same model as ID2AdaPrompt does.
578
  # FaceAnalysis will try to find the ckpt in: models/insightface/models/antelopev2.
579
- # Note there's a second "model" in the path.
580
- # Note DON'T use CUDAExecutionProvider, as it will hang DDP training.
581
- # Seems when loading insightface onto the GPU, it will only reside on the first GPU.
582
- # Then the process on the second GPU has issue to communicate with insightface on the first GPU, causing hanging.
583
  self.face_app = FaceAnalysis(name='antelopev2', root='models/insightface',
584
  providers=['CPUExecutionProvider'])
585
  self.face_app.prepare(ctx_id=0, det_size=(512, 512))
586
- print(f'Face encoder loaded on CPU.')
587
 
588
  self.text_to_image_prompt_encoder = CLIPTextModelWrapper.from_pretrained(
589
  'models/arc2face', subfolder="encoder",
@@ -594,21 +618,58 @@ class Arc2Face_ID2AdaPrompt(FaceID2AdaPrompt):
594
  if self.out_id_embs_cfg_scale == -1:
595
  self.out_id_embs_cfg_scale = 1
596
  #### Arc2Face pipeline specific configs ####
597
- self.gen_neg_img_prompt = False
598
  # bg CLIP features are used by the bg subject basis generator.
599
- self.use_clip_embs = True
600
  self.do_contrast_clip_embs_on_bg_features = True
601
  # self.num_static_img_suffix_embs is initialized in the parent class.
602
- self.id_img_prompt_max_length = 22
603
- self.clip_embedding_dim = 1024
604
 
605
- self.init_subj_basis_generator()
606
  if self.adaface_ckpt_path is not None:
607
  self.load_adaface_ckpt(self.adaface_ckpt_path)
608
 
609
- print(f"{self.name} ada prompt encoder initialized, "
610
- f"ID vecs: {self.num_id_vecs}, static suffix: {self.num_static_img_suffix_embs}.")
 
 
 
 
611
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
612
  # Arc2Face_ID2AdaPrompt never uses clip_features or called_for_neg_img_prompt.
613
  def map_init_id_to_img_prompt_embs(self, init_id_embs,
614
  clip_features=None,
@@ -656,16 +717,17 @@ class Arc2Face_ID2AdaPrompt(FaceID2AdaPrompt):
656
  # [N, 22, 768] -> [N, 16, 768]
657
  return prompt_embeds[:, 4:20]
658
 
659
- def get_id2img_learnable_modules(self):
660
- return [ self.text_to_image_prompt_encoder ]
661
-
662
  # ConsistentID_ID2AdaPrompt is just a wrapper of ConsistentIDPipeline, so it's not an nn.Module.
663
  class ConsistentID_ID2AdaPrompt(FaceID2AdaPrompt):
 
 
 
 
 
 
664
  def __init__(self, pipe=None, base_model_path="models/sd15-dste8-vae.safetensors",
665
  *args, **kwargs):
666
- self.name = 'consistentID'
667
- self.num_id_vecs = 4
668
-
669
  super().__init__(*args, **kwargs)
670
  if pipe is None:
671
  # The base_model_path is kind of arbitrary, as the UNet and VAE in the model
@@ -712,13 +774,51 @@ class ConsistentID_ID2AdaPrompt(FaceID2AdaPrompt):
712
  self.clip_embedding_dim = 1280
713
  self.s_scale = 1.0
714
  self.shortcut = False
715
-
716
- self.init_subj_basis_generator()
717
  if self.adaface_ckpt_path is not None:
718
  self.load_adaface_ckpt(self.adaface_ckpt_path)
719
 
 
 
 
 
 
 
 
720
  print(f"{self.name} ada prompt encoder initialized, "
721
- f"ID vecs: {self.num_id_vecs}, static suffix: {self.num_static_img_suffix_embs}.")
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
722
 
723
  def map_init_id_to_img_prompt_embs(self, init_id_embs,
724
  clip_features=None,
@@ -757,25 +857,30 @@ class ConsistentID_ID2AdaPrompt(FaceID2AdaPrompt):
757
 
758
  return global_id_embeds
759
 
760
- def get_id2img_learnable_modules(self):
761
- return [ self.image_proj_model ]
762
-
763
  # A wrapper for combining multiple FaceID2AdaPrompt instances.
764
  class Joint_FaceID2AdaPrompt(FaceID2AdaPrompt):
765
  def __init__(self, adaface_encoder_types, adaface_ckpt_paths,
766
  out_id_embs_cfg_scales=None, enabled_encoders=None,
767
  *args, **kwargs):
768
  self.name = 'jointIDs'
 
769
  assert len(adaface_encoder_types) > 0, "adaface_encoder_types should not be empty."
770
- adaface_encoder_types2num_id_vecs = { 'arc2face': 16, 'consistentID': 4 }
771
- self.encoders_num_id_vecs = [ adaface_encoder_types2num_id_vecs[encoder_type] \
 
 
772
  for encoder_type in adaface_encoder_types ]
773
- self.num_id_vecs = sum(self.encoders_num_id_vecs)
 
 
 
 
774
  super().__init__(*args, **kwargs)
775
 
776
  self.num_sub_encoders = len(adaface_encoder_types)
777
  self.id2ada_prompt_encoders = nn.ModuleList()
778
  self.encoders_num_static_img_suffix_embs = []
 
779
 
780
  # TODO: apply adaface_encoder_cfg_scales to influence the final prompt embeddings.
781
  # Now they are just placeholders.
@@ -785,10 +890,12 @@ class Joint_FaceID2AdaPrompt(FaceID2AdaPrompt):
785
  self.out_id_embs_cfg_scales = [-1] * self.num_sub_encoders
786
  else:
787
  # Do not normalize the weights, and just use them as is.
788
- self.out_id_embs_cfg_scales = out_id_embs_cfg_scales
789
 
790
  # Note we don't pass the adaface_ckpt_paths to the base class, but instead,
791
  # we load them once and for all in self.load_adaface_ckpt().
 
 
792
  for i, encoder_type in enumerate(adaface_encoder_types):
793
  kwargs['out_id_embs_cfg_scale'] = self.out_id_embs_cfg_scales[i]
794
  if encoder_type == 'arc2face':
@@ -797,8 +904,10 @@ class Joint_FaceID2AdaPrompt(FaceID2AdaPrompt):
797
  encoder = ConsistentID_ID2AdaPrompt(*args, **kwargs)
798
  else:
799
  breakpoint()
 
800
  self.id2ada_prompt_encoders.append(encoder)
801
  self.encoders_num_static_img_suffix_embs.append(encoder.num_static_img_suffix_embs)
 
802
 
803
  self.num_static_img_suffix_embs = sum(self.encoders_num_static_img_suffix_embs)
804
  # No need to set gen_neg_img_prompt, as we don't access it in this class, but rather
@@ -814,6 +923,7 @@ class Joint_FaceID2AdaPrompt(FaceID2AdaPrompt):
814
  # Therefore, the clip_embedding_dim is the sum of the clip_embedding_dims of all adaface encoders.
815
  self.clip_embedding_dims = [encoder.clip_embedding_dim for encoder in self.id2ada_prompt_encoders]
816
  self.clip_embedding_dim = sum(self.clip_embedding_dims)
 
817
  # The ctors of the derived classes have already initialized encoder.subj_basis_generator.
818
  # If subj_basis_generator expansion params are specified, they are equally applied to all adaface encoders.
819
  # This self.subj_basis_generator is not meant to be called as self.subj_basis_generator(), but instead,
@@ -821,12 +931,13 @@ class Joint_FaceID2AdaPrompt(FaceID2AdaPrompt):
821
  self.subj_basis_generator = \
822
  nn.ModuleList( [encoder.subj_basis_generator for encoder \
823
  in self.id2ada_prompt_encoders] )
824
-
 
825
  if adaface_ckpt_paths is not None:
826
  self.load_adaface_ckpt(adaface_ckpt_paths)
827
-
828
  print(f"{self.name} ada prompt encoder initialized with {self.num_sub_encoders} sub-encoders. "
829
- f"ID vecs: {self.num_id_vecs}, static suffix embs: {self.num_static_img_suffix_embs}.")
830
 
831
  if enabled_encoders is not None:
832
  self.are_encoders_enabled = \
@@ -842,66 +953,79 @@ class Joint_FaceID2AdaPrompt(FaceID2AdaPrompt):
842
  else:
843
  self.are_encoders_enabled = \
844
  torch.tensor([True] * self.num_sub_encoders)
845
-
846
  def load_adaface_ckpt(self, adaface_ckpt_paths):
847
- # If only one adaface ckpt path is provided, then we assume it's the ckpt of the Joint_FaceID2AdaPrompt,
848
- # so we dereference the list to get the actual path and load the subj_basis_generators of all adaface encoders.
849
  if isinstance(adaface_ckpt_paths, (list, tuple, ListConfig)):
850
- if len(adaface_ckpt_paths) == 1 and self.num_sub_encoders > 1:
 
 
 
 
 
 
 
 
851
  adaface_ckpt_paths = adaface_ckpt_paths[0]
852
-
853
- if isinstance(adaface_ckpt_paths, str):
854
- # This is only applicable to newest ckpts of Joint_FaceID2AdaPrompt, where
855
- # the ckpt_subj_basis_generator is an nn.ModuleList of multiple subj_basis_generators.
856
- # Therefore, no need to patch missing variables.
857
- ckpt = torch.load(adaface_ckpt_paths, map_location='cpu')
858
- string_to_subj_basis_generator_dict = ckpt["string_to_subj_basis_generator_dict"]
859
- if self.subject_string not in string_to_subj_basis_generator_dict:
860
- print(f"Subject '{self.subject_string}' not found in the embedding manager.")
861
  breakpoint()
862
 
863
- ckpt_subj_basis_generators = string_to_subj_basis_generator_dict[self.subject_string]
864
- for i, subj_basis_generator in enumerate(self.subj_basis_generator):
865
- ckpt_subj_basis_generator = ckpt_subj_basis_generators[i]
866
- # Handle differences in num_static_img_suffix_embs between the current model and the ckpt.
867
- ckpt_subj_basis_generator.initialize_static_img_suffix_embs(self.encoders_num_static_img_suffix_embs[i],
868
- img_prompt_dim=self.output_dim)
869
-
870
- if subj_basis_generator.prompt2token_proj_attention_multipliers \
871
- == [1] * 12:
872
- subj_basis_generator.extend_prompt2token_proj_attention(\
873
- ckpt_subj_basis_generator.prompt2token_proj_attention_multipliers, -1, -1, 1, perturb_std=0)
874
- elif subj_basis_generator.prompt2token_proj_attention_multipliers \
875
- != ckpt_subj_basis_generator.prompt2token_proj_attention_multipliers:
876
- raise ValueError("Inconsistent prompt2token_proj_attention_multipliers.")
877
-
878
- assert subj_basis_generator.prompt2token_proj_attention_multipliers \
879
- == ckpt_subj_basis_generator.prompt2token_proj_attention_multipliers, \
880
- "Inconsistent prompt2token_proj_attention_multipliers."
881
- subj_basis_generator.load_state_dict(ckpt_subj_basis_generator.state_dict())
882
-
883
- # extend_prompt2token_proj_attention_multiplier is an integer >= 1.
884
- # TODO: extend_prompt2token_proj_attention_multiplier should be a list of integers.
885
- # If extend_prompt2token_proj_attention_multiplier > 1, then after loading state_dict,
886
- # extend subj_basis_generator again.
887
- if self.extend_prompt2token_proj_attention_multiplier > 1:
888
- # During this extension, the added noise does change the extra copies of attention weights, since they are not in the ckpt.
889
- # During training, prompt2token_proj_ext_attention_perturb_ratio == 0.1.
890
- # During inference, prompt2token_proj_ext_attention_perturb_ratio == 0.
891
- subj_basis_generator.extend_prompt2token_proj_attention(\
892
- None, -1, -1, self.extend_prompt2token_proj_attention_multiplier,
893
- perturb_std=self.prompt2token_proj_ext_attention_perturb_ratio)
894
-
895
- subj_basis_generator.freeze_prompt2token_proj()
896
-
897
- print(f"{adaface_ckpt_paths}: {len(self.subj_basis_generator)} subj_basis_generators loaded for {self.name}.")
898
-
899
- elif isinstance(adaface_ckpt_paths, (list, tuple, ListConfig)):
900
- for i, ckpt_path in enumerate(adaface_ckpt_paths):
901
- self.id2ada_prompt_encoders[i].load_adaface_ckpt(ckpt_path)
902
- else:
903
  breakpoint()
904
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
905
  def extract_init_id_embeds_from_images(self, *args, **kwargs):
906
  total_faceless_img_count = 0
907
  all_id_embs = []
@@ -1039,7 +1163,7 @@ class Joint_FaceID2AdaPrompt(FaceID2AdaPrompt):
1039
 
1040
  N_ID = self.encoders_num_id_vecs[i]
1041
  if all_pos_prompt_embs[i] is None:
1042
- # Both pos_prompt_embs and neg_prompt_embs have N_ID == num_id_vecs embeddings.
1043
  all_pos_prompt_embs[i] = torch.zeros((BS, N_ID, 768), dtype=torch.float16, device=device)
1044
  if all_neg_prompt_embs[i] is None:
1045
  all_neg_prompt_embs[i] = torch.zeros((BS, N_ID, 768), dtype=torch.float16, device=device)
@@ -1061,6 +1185,13 @@ class Joint_FaceID2AdaPrompt(FaceID2AdaPrompt):
1061
  # So its .device is the device of its parameters.
1062
  device = self.id2ada_prompt_encoders[0].clip_image_encoder.device
1063
  is_emb_averaged = kwargs.get('avg_at_stage', None) is not None
 
 
 
 
 
 
 
1064
  BS = -1
1065
 
1066
  if face_id_embs is not None:
@@ -1068,13 +1199,17 @@ class Joint_FaceID2AdaPrompt(FaceID2AdaPrompt):
1068
  all_face_id_embs = face_id_embs.split(self.face_id_dims, dim=1)
1069
  else:
1070
  all_face_id_embs = [None] * self.num_sub_encoders
 
1071
  if img_prompt_embs is not None:
1072
  BS = img_prompt_embs.shape[0] if BS == -1 else BS
1073
- if img_prompt_embs.shape[1] != self.num_id_vecs:
1074
  breakpoint()
1075
- all_img_prompt_embs = img_prompt_embs.split(self.encoders_num_id_vecs, dim=1)
 
1076
  else:
1077
  all_img_prompt_embs = [None] * self.num_sub_encoders
 
 
1078
  if image_paths is not None:
1079
  BS = len(image_paths) if BS == -1 else BS
1080
  if BS == -1:
@@ -1097,25 +1232,32 @@ class Joint_FaceID2AdaPrompt(FaceID2AdaPrompt):
1097
  else:
1098
  are_encoders_enabled = self.are_encoders_enabled
1099
 
 
1100
  all_adaface_subj_embs = []
1101
  num_available_id_vecs = 0
 
1102
 
1103
  for i, id2ada_prompt_encoder in enumerate(self.id2ada_prompt_encoders):
1104
  if not are_encoders_enabled[i]:
1105
  adaface_subj_embs = None
1106
- print(f"Encoder {id2ada_prompt_encoder.name} is dropped.")
 
 
1107
  else:
 
1108
  # ddpm.embedding_manager.train() -> id2ada_prompt_encoder.train() -> each sub-encoder's train().
1109
  # -> each sub-encoder's subj_basis_generator.train().
1110
  # Therefore grad for the following call is enabled.
1111
- adaface_subj_embs = \
1112
  id2ada_prompt_encoder.generate_adaface_embeddings(image_paths,
1113
  all_face_id_embs[i],
1114
  all_img_prompt_embs[i],
1115
  *args, **kwargs)
1116
 
1117
- # adaface_subj_embs: [16, 768] or [4, 768].
1118
- N_ID = self.encoders_num_id_vecs[i]
 
 
1119
  if adaface_subj_embs is None:
1120
  if not return_zero_embs_for_dropped_encoders:
1121
  continue
@@ -1126,12 +1268,16 @@ class Joint_FaceID2AdaPrompt(FaceID2AdaPrompt):
1126
  all_adaface_subj_embs.append(adaface_subj_embs)
1127
  else:
1128
  all_adaface_subj_embs.append(adaface_subj_embs)
 
 
1129
  num_available_id_vecs += N_ID
1130
-
 
 
1131
  # No faces are found in the images, so return None embeddings.
1132
  # We don't want to return an all-zero embedding, which is useless.
1133
  if num_available_id_vecs == 0:
1134
- return None
1135
 
1136
  # If id2ada_prompt_encoders are ["arc2face", "consistentID"], then
1137
  # during inference, we average across the batch dim.
@@ -1141,7 +1287,12 @@ class Joint_FaceID2AdaPrompt(FaceID2AdaPrompt):
1141
  # all_adaface_subj_embs[0]: [BS, 4, 768]. all_adaface_subj_embs[1]: [BS, 16, 768].
1142
  # all_adaface_subj_embs: [BS, 20, 768].
1143
  all_adaface_subj_embs = torch.cat(all_adaface_subj_embs, dim=-2)
1144
- return all_adaface_subj_embs
 
 
 
 
 
1145
 
1146
 
1147
  '''
 
53
  self.text_to_image_prompt_encoder = None
54
  self.tokenizer = None
55
  self.dtype = kwargs.get('dtype', torch.float16)
56
+ self.img2txt_dtype = kwargs.get('img2txt_dtype', torch.float16)
57
+ self.device = torch.device("cpu")
58
 
59
  # Load Img2Ada SubjectBasisGenerator.
60
  self.subject_string = kwargs.get('subject_string', 'z')
 
75
 
76
  self.use_clip_embs = False
77
  self.do_contrast_clip_embs_on_bg_features = False
78
+ # Override the default setting in derived classes.
79
+ if 'enable_static_img_suffix_embs' in kwargs:
80
+ self.default_enable_static_img_suffix_embs = kwargs['enable_static_img_suffix_embs']
81
+
82
  # num_id_vecs is the output embeddings of the ID2ImgPrompt module.
83
  # If there's no static image suffix embeddings, then num_id_vecs is also
84
  # the number of ada embeddings returned by the subject basis generator.
85
  # num_id_vecs will be set in each derived class.
86
  self.num_static_img_suffix_embs = kwargs.get('num_static_img_suffix_embs', 0)
87
+ print(f'{self.name} Adaface uses {self.num_id_vecs} ID image embeddings + {self.num_static_img_suffix_embs} fixed image embeddings as input.')
88
 
89
  self.id_img_prompt_max_length = 77
90
  self.face_id_dim = 512
 
93
  self.clip_embedding_dim = 1024
94
  self.output_dim = 768
95
 
96
+ # init_img2txt_projection() can only be called after the derived class is initialized,
97
+ # when self.num_id_vecs0, self.num_static_img_suffix_embs and self.clip_embedding_dim have been set.
98
+ def init_img2txt_projection(self):
 
 
 
 
 
 
 
 
 
99
  self.subj_basis_generator = \
100
+ SubjBasisGenerator(dtype=self.img2txt_dtype,
101
+ num_id_vecs = self.num_id_vecs0,
102
  num_static_img_suffix_embs = self.num_static_img_suffix_embs,
103
  bg_image_embedding_dim = self.clip_embedding_dim,
104
  output_dim = self.output_dim,
105
  placeholder_is_bg = False,
 
106
  bg_prompt_translator_has_to_out_proj=False)
107
 
108
  def load_adaface_ckpt(self, adaface_ckpt_path):
109
+ if isinstance(adaface_ckpt_path, (list, tuple, ListConfig)):
110
+ adaface_ckpt_path = adaface_ckpt_path[0]
111
+
112
+ ckpt = torch.load(adaface_ckpt_path, map_location='cpu', weights_only=False)
113
  string_to_subj_basis_generator_dict = ckpt["string_to_subj_basis_generator_dict"]
114
  if self.subject_string not in string_to_subj_basis_generator_dict:
115
  print(f"Subject '{self.subject_string}' not found in the embedding manager.")
116
  breakpoint()
117
 
118
  ckpt_subj_basis_generator = string_to_subj_basis_generator_dict[self.subject_string]
119
+ if isinstance(ckpt_subj_basis_generator, nn.ModuleList):
120
+ name2idx = { 'consistentID': 0, 'arc2face': 1 }
121
+ subj_basis_generator_idx = name2idx[self.name]
122
+ ckpt_subj_basis_generator = ckpt_subj_basis_generator[subj_basis_generator_idx]
123
+
124
+ ckpt_subj_basis_generator.N_ID = self.num_id_vecs0
125
  # Since we directly use the subject basis generator object from the ckpt,
126
  # fixing the number of static image suffix embeddings is much simpler.
127
  # Otherwise if we want to load the subject basis generator from its state_dict,
 
134
  ckpt_subj_basis_generator.initialize_static_img_suffix_embs(self.num_static_img_suffix_embs, img_prompt_dim=self.output_dim)
135
  # Fix missing variables in old ckpt.
136
  ckpt_subj_basis_generator.patch_old_subj_basis_generator_ckpt()
137
+
138
  self.subj_basis_generator.extend_prompt2token_proj_attention(\
139
  ckpt_subj_basis_generator.prompt2token_proj_attention_multipliers, -1, -1, 1, perturb_std=0)
140
  ret = self.subj_basis_generator.load_state_dict(ckpt_subj_basis_generator.state_dict(), strict=False)
 
160
 
161
  self.subj_basis_generator.freeze_prompt2token_proj()
162
 
163
+ def set_out_id_embs_cfg_scale(self, out_id_embs_cfg_scale):
164
+ if isinstance(out_id_embs_cfg_scale, (list, tuple, ListConfig)):
165
+ out_id_embs_cfg_scale = out_id_embs_cfg_scale[0]
166
+ self.out_id_embs_cfg_scale = out_id_embs_cfg_scale
167
+
168
  @torch.no_grad()
169
  def get_clip_neg_features(self, BS):
170
  if self.clip_neg_features is None:
 
230
  image_obj, _, _ = pad_image_obj_to_square(image_obj)
231
  image_np = np.array(image_obj.resize(size, Image.NEAREST))
232
  face_info = self.face_app.get(cv2.cvtColor(image_np, cv2.COLOR_RGB2BGR))
233
+
234
  if len(face_info) > 0:
235
  face_info = sorted(face_info, key=lambda x:(x['bbox'][2]-x['bbox'][0])*x['bbox'][3]-x['bbox'][1])[-1] # only use the maximum face
236
  # id_emb: [512,]
 
498
  # avg_at_stage == ada_prompt_emb usually produces the worst results.
499
  # avg_at_stage == id_emb is slightly better than img_prompt_emb, but sometimes img_prompt_emb is better.
500
  # p_dropout and return_zero_embs_for_dropped_encoders are only used by Joint_FaceID2AdaPrompt.
501
+ # enable_static_img_suffix_embs=None: use the default setting.
502
  def generate_adaface_embeddings(self, image_paths, face_id_embs=None, img_prompt_embs=None,
503
  p_dropout=0,
504
  return_zero_embs_for_dropped_encoders=True,
505
  avg_at_stage='id_emb', # id_emb, img_prompt_emb, or None.
506
  perturb_at_stage=None, # id_emb, img_prompt_emb, or None.
507
+ perturb_std=0, enable_static_img_suffix_embs=None):
508
+
509
+ if enable_static_img_suffix_embs is None:
510
+ enable_static_img_suffix_embs = self.default_enable_static_img_suffix_embs
511
+
512
+ lens_subj_emb_segments = [ self.num_id_vecs + enable_static_img_suffix_embs \
513
+ * self.num_static_img_suffix_embs ]
514
+
515
  if (avg_at_stage is None) or avg_at_stage.lower() == 'none':
516
  img_prompt_avg_at_stage = None
517
  else:
 
528
  id_batch_size = len(image_paths)
529
  else:
530
  id_batch_size = 1
531
+
532
  # faceid_embeds: [BS, 512] is a batch of extracted face analysis embeddings. NOT used later.
533
  # NOTE: If face_id_embs, image_paths and image_objs are all None,
534
  # then get_img_prompt_embs() generates random faceid_embeds/img_prompt_embs,
 
551
  verbose=True)
552
 
553
  if face_image_count == 0:
554
+ return None, None, lens_subj_emb_segments
555
 
556
  # No matter whether avg_at_stage is id_emb or img_prompt_emb, we average img_prompt_embs.
557
  elif avg_at_stage is not None and avg_at_stage.lower() != 'none':
 
564
  out_id_embs_cfg_scale=self.out_id_embs_cfg_scale,
565
  is_face=True,
566
  enable_static_img_suffix_embs=enable_static_img_suffix_embs)
567
+
568
+ if self.num_id_vecs < self.num_id_vecs0:
569
+ adaface_subj_embs = adaface_subj_embs[:, :self.num_id_vecs, :]
570
+
571
  # During training, img_prompt_avg_at_stage is None, and BS >= 1.
572
  # During inference, img_prompt_avg_at_stage is 'id_emb' or 'img_prompt_emb', and BS == 1.
573
  if img_prompt_avg_at_stage is not None:
574
  # adaface_subj_embs: [1, 16, 768] -> [16, 768]
575
  adaface_subj_embs = adaface_subj_embs.squeeze(0)
576
 
577
+ return adaface_subj_embs, img_prompt_embs, lens_subj_emb_segments
578
 
579
  class Arc2Face_ID2AdaPrompt(FaceID2AdaPrompt):
580
+ name = 'arc2face'
581
+ num_id_vecs0 = 16
582
+ # first 4 are kept, the rest 12 are averaged to another 4.
583
+ # Then concatenated to [8, 768].
584
+ num_id_vecs = 16
585
+ default_enable_static_img_suffix_embs = False
586
 
587
+ def __init__(self, *args, **kwargs):
588
  super().__init__(*args, **kwargs)
589
 
590
  self.clip_image_encoder = CLIPVisionModelWithMask.from_pretrained('openai/clip-vit-large-patch14')
 
603
  '''
604
  # Use the same model as ID2AdaPrompt does.
605
  # FaceAnalysis will try to find the ckpt in: models/insightface/models/antelopev2.
606
+ # Note there are two "models" in the path.
 
 
 
607
  self.face_app = FaceAnalysis(name='antelopev2', root='models/insightface',
608
  providers=['CPUExecutionProvider'])
609
  self.face_app.prepare(ctx_id=0, det_size=(512, 512))
610
+ print(f'Arc2Face Face encoder loaded on CPU.')
611
 
612
  self.text_to_image_prompt_encoder = CLIPTextModelWrapper.from_pretrained(
613
  'models/arc2face', subfolder="encoder",
 
618
  if self.out_id_embs_cfg_scale == -1:
619
  self.out_id_embs_cfg_scale = 1
620
  #### Arc2Face pipeline specific configs ####
621
+ self.gen_neg_img_prompt = False
622
  # bg CLIP features are used by the bg subject basis generator.
623
+ self.use_clip_embs = True
624
  self.do_contrast_clip_embs_on_bg_features = True
625
  # self.num_static_img_suffix_embs is initialized in the parent class.
626
+ self.id_img_prompt_max_length = 22
627
+ self.clip_embedding_dim = 1024
628
 
629
+ self.init_img2txt_projection()
630
  if self.adaface_ckpt_path is not None:
631
  self.load_adaface_ckpt(self.adaface_ckpt_path)
632
 
633
+ for param in self.clip_image_encoder.parameters():
634
+ param.requires_grad = False
635
+ for param in self.text_to_image_prompt_encoder.parameters():
636
+ param.requires_grad = False
637
+ for param in self.subj_basis_generator.parameters():
638
+ param.requires_grad = self.is_training
639
 
640
+ print(f"{self.name} ada prompt encoder initialized, "
641
+ f"ID vecs: {self.num_id_vecs0}, static suffix: {self.num_static_img_suffix_embs}.")
642
+
643
+ def _apply(self, fn):
644
+ super()._apply(fn) # Call the parent _apply to handle parameters and buffers
645
+ return
646
+ # A dirty hack to get the device of the model, passed from
647
+ # parent.model.to(self.root_device) => parent._apply(convert) => module._apply(fn)
648
+ test_tensor = torch.zeros(1) # Create a test tensor
649
+ transformed_tensor = fn(test_tensor) # Apply `fn()` to test it
650
+ device = transformed_tensor.device # Get the device of the transformed tensor
651
+ # No need to reload face_app on the same device.
652
+ if device == self.device:
653
+ return
654
+
655
+ if str(device) == 'cpu':
656
+ self.face_app = FaceAnalysis(name='antelopev2', root='models/insightface',
657
+ providers=['CPUExecutionProvider'])
658
+ self.face_app.prepare(ctx_id=0, det_size=(512, 512))
659
+ else:
660
+ device_id = device.index
661
+ self.face_app = FaceAnalysis(name='antelopev2', root='models/insightface',
662
+ providers=['CUDAExecutionProvider'],
663
+ provider_options=[{"device_id": device_id,
664
+ "cudnn_conv_algo_search": "HEURISTIC",
665
+ "gpu_mem_limit": 2 * 1024**3
666
+ }])
667
+ self.face_app.prepare(ctx_id=device_id, det_size=(512, 512))
668
+
669
+ self.device = device
670
+ print(f'Arc2Face Face encoder reloaded on {device}.')
671
+ return
672
+
673
  # Arc2Face_ID2AdaPrompt never uses clip_features or called_for_neg_img_prompt.
674
  def map_init_id_to_img_prompt_embs(self, init_id_embs,
675
  clip_features=None,
 
717
  # [N, 22, 768] -> [N, 16, 768]
718
  return prompt_embeds[:, 4:20]
719
 
 
 
 
720
  # ConsistentID_ID2AdaPrompt is just a wrapper of ConsistentIDPipeline, so it's not an nn.Module.
721
  class ConsistentID_ID2AdaPrompt(FaceID2AdaPrompt):
722
+ name = 'consistentID'
723
+ num_id_vecs0 = 4
724
+ # No compression for ConsistentID.
725
+ num_id_vecs = 4
726
+ default_enable_static_img_suffix_embs = False
727
+
728
  def __init__(self, pipe=None, base_model_path="models/sd15-dste8-vae.safetensors",
729
  *args, **kwargs):
730
+
 
 
731
  super().__init__(*args, **kwargs)
732
  if pipe is None:
733
  # The base_model_path is kind of arbitrary, as the UNet and VAE in the model
 
774
  self.clip_embedding_dim = 1280
775
  self.s_scale = 1.0
776
  self.shortcut = False
777
+
778
+ self.init_img2txt_projection()
779
  if self.adaface_ckpt_path is not None:
780
  self.load_adaface_ckpt(self.adaface_ckpt_path)
781
 
782
+ for param in self.clip_image_encoder.parameters():
783
+ param.requires_grad = False
784
+ for param in self.image_proj_model.parameters():
785
+ param.requires_grad = False
786
+ for param in self.subj_basis_generator.parameters():
787
+ param.requires_grad = self.is_training
788
+
789
  print(f"{self.name} ada prompt encoder initialized, "
790
+ f"ID vecs: {self.num_id_vecs0}, static suffix: {self.num_static_img_suffix_embs}.")
791
+
792
+ def _apply(self, fn):
793
+ super()._apply(fn) # Call the parent _apply to handle parameters and buffers
794
+ return
795
+ # A dirty hack to get the device of the model, passed from
796
+ # parent.model.to(self.root_device) => parent._apply(convert) => module._apply(fn)
797
+ test_tensor = torch.zeros(1) # Create a test tensor
798
+ transformed_tensor = fn(test_tensor) # Apply `fn()` to test it
799
+ device = transformed_tensor.device # Get the device of the transformed tensor
800
+ # No need to reload face_app on the same device.
801
+ if device == self.device:
802
+ return
803
+
804
+ if str(device) == 'cpu':
805
+ self.face_app = FaceAnalysis(name='buffalo_l', root='models/insightface',
806
+ providers=['CPUExecutionProvider'])
807
+ self.face_app.prepare(ctx_id=0, det_size=(512, 512))
808
+ else:
809
+ device_id = device.index
810
+ self.face_app = FaceAnalysis(name='buffalo_l', root='models/insightface',
811
+ providers=['CUDAExecutionProvider'],
812
+ provider_options=[{"device_id": device_id,
813
+ "cudnn_conv_algo_search": "HEURISTIC",
814
+ "gpu_mem_limit": 2 * 1024**3
815
+ }])
816
+ self.face_app.prepare(ctx_id=device_id, det_size=(512, 512))
817
+
818
+ self.device = device
819
+ self.pipe.face_app = self.face_app
820
+ print(f'ConsistentID Face encoder reloaded on {device}.')
821
+
822
 
823
  def map_init_id_to_img_prompt_embs(self, init_id_embs,
824
  clip_features=None,
 
857
 
858
  return global_id_embeds
859
 
 
 
 
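Note on the _apply() overrides written above: the insightface FaceAnalysis session is not an nn.Module, so .to(device) cannot move it; the override instead probes which device fn() maps tensors to and rebuilds the face detector there. A generic sketch of that pattern, independent of this repository (class and attribute names below are illustrative only):

import torch
import torch.nn as nn

class DeviceAwareWrapper(nn.Module):
    # Sketch of the device-detection trick from the _apply() overrides above:
    # registered parameters/buffers are moved by the parent class, and a probe
    # tensor reveals the target device so non-module resources can be rebuilt on it.
    def __init__(self):
        super().__init__()
        self.device = torch.device("cpu")

    def _apply(self, fn):
        super()._apply(fn)                   # moves registered params and buffers
        device = fn(torch.zeros(1)).device   # probe where fn() sends tensors
        if device != self.device:
            self.device = device
            # Rebuild device-bound, non-module resources here
            # (e.g. an ONNX-runtime face-analysis session pinned to this GPU).
        return self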
860
  # A wrapper for combining multiple FaceID2AdaPrompt instances.
861
  class Joint_FaceID2AdaPrompt(FaceID2AdaPrompt):
862
  def __init__(self, adaface_encoder_types, adaface_ckpt_paths,
863
  out_id_embs_cfg_scales=None, enabled_encoders=None,
864
  *args, **kwargs):
865
  self.name = 'jointIDs'
866
+ name2class = { 'arc2face': Arc2Face_ID2AdaPrompt, 'consistentID': ConsistentID_ID2AdaPrompt }
867
  assert len(adaface_encoder_types) > 0, "adaface_encoder_types should not be empty."
868
+ adaface_encoder_types2num_id_vecs0 = { name: name2class[name].num_id_vecs0 for name in adaface_encoder_types }
869
+ adaface_encoder_types2num_id_vecs = { name: name2class[name].num_id_vecs for name in adaface_encoder_types }
870
+ # self.num_id_vecs0 is used in the parent class. So we need to initialize it here first.
871
+ self.encoders_num_id_vecs0 = [ adaface_encoder_types2num_id_vecs0[encoder_type] \
872
  for encoder_type in adaface_encoder_types ]
873
+ self.encoders_num_id_vecs = [ adaface_encoder_types2num_id_vecs[encoder_type] \
874
+ for encoder_type in adaface_encoder_types ]
875
+ self.num_id_vecs0 = sum(self.encoders_num_id_vecs0)
876
+ self.num_id_vecs = sum(self.encoders_num_id_vecs)
877
+ # super() sets self.is_training.
878
  super().__init__(*args, **kwargs)
879
 
880
  self.num_sub_encoders = len(adaface_encoder_types)
881
  self.id2ada_prompt_encoders = nn.ModuleList()
882
  self.encoders_num_static_img_suffix_embs = []
883
+ self.default_enable_static_img_suffix_embs = []
884
 
885
  # TODO: apply adaface_encoder_cfg_scales to influence the final prompt embeddings.
886
  # Now they are just placeholders.
 
890
  self.out_id_embs_cfg_scales = [-1] * self.num_sub_encoders
891
  else:
892
  # Do not normalize the weights, and just use them as is.
893
+ self.out_id_embs_cfg_scales = list(out_id_embs_cfg_scales)
894
 
895
  # Note we don't pass the adaface_ckpt_paths to the base class, but instead,
896
  # we load them once and for all in self.load_adaface_ckpt().
897
+ # NOTE: during inference, num_static_img_suffix_embs is fixed to be 4 for each encoder.
898
+ # But we can still disable static_img_suffix_embs by setting enable_static_img_suffix_embs to False.
899
  for i, encoder_type in enumerate(adaface_encoder_types):
900
  kwargs['out_id_embs_cfg_scale'] = self.out_id_embs_cfg_scales[i]
901
  if encoder_type == 'arc2face':
 
904
  encoder = ConsistentID_ID2AdaPrompt(*args, **kwargs)
905
  else:
906
  breakpoint()
907
+
908
  self.id2ada_prompt_encoders.append(encoder)
909
  self.encoders_num_static_img_suffix_embs.append(encoder.num_static_img_suffix_embs)
910
+ self.default_enable_static_img_suffix_embs.append(encoder.default_enable_static_img_suffix_embs)
911
 
912
  self.num_static_img_suffix_embs = sum(self.encoders_num_static_img_suffix_embs)
913
  # No need to set gen_neg_img_prompt, as we don't access it in this class, but rather
 
923
  # Therefore, the clip_embedding_dim is the sum of the clip_embedding_dims of all adaface encoders.
924
  self.clip_embedding_dims = [encoder.clip_embedding_dim for encoder in self.id2ada_prompt_encoders]
925
  self.clip_embedding_dim = sum(self.clip_embedding_dims)
926
+
927
  # The ctors of the derived classes have already initialized encoder.subj_basis_generator.
928
  # If subj_basis_generator expansion params are specified, they are equally applied to all adaface encoders.
929
  # This self.subj_basis_generator is not meant to be called as self.subj_basis_generator(), but instead,
 
931
  self.subj_basis_generator = \
932
  nn.ModuleList( [encoder.subj_basis_generator for encoder \
933
  in self.id2ada_prompt_encoders] )
934
+
935
+ # load_adaface_ckpt() loads into self.subj_basis_generator. So we need to initialize self.subj_basis_generator first.
936
  if adaface_ckpt_paths is not None:
937
  self.load_adaface_ckpt(adaface_ckpt_paths)
938
+
939
  print(f"{self.name} ada prompt encoder initialized with {self.num_sub_encoders} sub-encoders. "
940
+ f"ID vecs: {self.num_id_vecs0}, static suffix embs: {self.num_static_img_suffix_embs}.")
941
 
942
  if enabled_encoders is not None:
943
  self.are_encoders_enabled = \
 
953
  else:
954
  self.are_encoders_enabled = \
955
  torch.tensor([True] * self.num_sub_encoders)
956
+
957
  def load_adaface_ckpt(self, adaface_ckpt_paths):
 
 
958
  if isinstance(adaface_ckpt_paths, (list, tuple, ListConfig)):
959
+ # If multiple adaface ckpt paths are provided, then we assume they are the
960
+ # ckpts of the sub-encoders.
961
+ if len(adaface_ckpt_paths) == self.num_sub_encoders:
962
+ for i, ckpt_path in enumerate(adaface_ckpt_paths):
963
+ self.id2ada_prompt_encoders[i].load_adaface_ckpt(ckpt_path)
964
+ return
965
+ # If only one adaface ckpt path is provided, then we assume it's the ckpt of the Joint_FaceID2AdaPrompt,
966
+ # so we dereference the list to get the actual path and load the subj_basis_generators of all adaface encoders.
967
+ elif len(adaface_ckpt_paths) == 1 and self.num_sub_encoders > 1:
968
  adaface_ckpt_paths = adaface_ckpt_paths[0]
969
+ else:
 
 
 
 
 
 
 
 
970
  breakpoint()
971
 
972
+ adaface_ckpt_path = adaface_ckpt_paths
973
+ assert isinstance(adaface_ckpt_path, str), "adaface_ckpt_path should be a string."
974
+ # This is only applicable to newest ckpts of Joint_FaceID2AdaPrompt, where
975
+ # the ckpt_subj_basis_generator is an nn.ModuleList of multiple subj_basis_generators.
976
+ # Therefore, no need to patch missing variables.
977
+ ckpt = torch.load(adaface_ckpt_paths, map_location='cpu', weights_only=False)
978
+ string_to_subj_basis_generator_dict = ckpt["string_to_subj_basis_generator_dict"]
979
+ if self.subject_string not in string_to_subj_basis_generator_dict:
980
+ print(f"Subject '{self.subject_string}' not found in the embedding manager.")
981
+ breakpoint()
982
+
983
+ ckpt_subj_basis_generators = string_to_subj_basis_generator_dict[self.subject_string]
984
+ if len(ckpt_subj_basis_generators) != self.num_sub_encoders:
985
+ print(f"Number of subj_basis_generators in the ckpt ({len(ckpt_subj_basis_generators)}) "
986
+ f"doesn't match the number of adaface encoders ({self.num_sub_encoders}).")
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
987
  breakpoint()
988
 
989
+ for i, subj_basis_generator in enumerate(self.subj_basis_generator):
990
+ ckpt_subj_basis_generator = ckpt_subj_basis_generators[i]
991
+ # Handle differences in num_static_img_suffix_embs between the current model and the ckpt.
992
+ ckpt_subj_basis_generator.initialize_static_img_suffix_embs(self.encoders_num_static_img_suffix_embs[i],
993
+ img_prompt_dim=self.output_dim)
994
+
995
+ if subj_basis_generator.prompt2token_proj_attention_multipliers \
996
+ == [1] * 12:
997
+ subj_basis_generator.extend_prompt2token_proj_attention(\
998
+ ckpt_subj_basis_generator.prompt2token_proj_attention_multipliers, -1, -1, 1, perturb_std=0)
999
+ elif subj_basis_generator.prompt2token_proj_attention_multipliers \
1000
+ != ckpt_subj_basis_generator.prompt2token_proj_attention_multipliers:
1001
+ raise ValueError("Inconsistent prompt2token_proj_attention_multipliers.")
1002
+
1003
+ assert subj_basis_generator.prompt2token_proj_attention_multipliers \
1004
+ == ckpt_subj_basis_generator.prompt2token_proj_attention_multipliers, \
1005
+ "Inconsistent prompt2token_proj_attention_multipliers."
1006
+ subj_basis_generator.load_state_dict(ckpt_subj_basis_generator.state_dict())
1007
+
1008
+ # extend_prompt2token_proj_attention_multiplier is an integer >= 1.
1009
+ # TODO: extend_prompt2token_proj_attention_multiplier should be a list of integers.
1010
+ # If extend_prompt2token_proj_attention_multiplier > 1, then after loading state_dict,
1011
+ # extend subj_basis_generator again.
1012
+ if self.extend_prompt2token_proj_attention_multiplier > 1:
1013
+ # During this extension, the added noise does change the extra copies of attention weights, since they are not in the ckpt.
1014
+ # During training, prompt2token_proj_ext_attention_perturb_ratio == 0.1.
1015
+ # During inference, prompt2token_proj_ext_attention_perturb_ratio == 0.
1016
+ subj_basis_generator.extend_prompt2token_proj_attention(\
1017
+ None, -1, -1, self.extend_prompt2token_proj_attention_multiplier,
1018
+ perturb_std=self.prompt2token_proj_ext_attention_perturb_ratio)
1019
+
1020
+ subj_basis_generator.freeze_prompt2token_proj()
1021
+
1022
+ print(f"{adaface_ckpt_paths}: {len(self.subj_basis_generator)} subj_basis_generators loaded for {self.name}.")
1023
+
1024
+ def set_out_id_embs_cfg_scale(self, out_id_embs_cfg_scales):
1025
+ self.out_id_embs_cfg_scales = list(out_id_embs_cfg_scales)
1026
+ for i, out_id_embs_cfg_scale in enumerate(out_id_embs_cfg_scales):
1027
+ self.id2ada_prompt_encoders[i].set_out_id_embs_cfg_scale(out_id_embs_cfg_scale)
1028
+
1029
  def extract_init_id_embeds_from_images(self, *args, **kwargs):
1030
  total_faceless_img_count = 0
1031
  all_id_embs = []
 
1163
 
1164
  N_ID = self.encoders_num_id_vecs[i]
1165
  if all_pos_prompt_embs[i] is None:
1166
+ # Both pos_prompt_embs and neg_prompt_embs have N_ID == num_id_vecs0 embeddings.
1167
  all_pos_prompt_embs[i] = torch.zeros((BS, N_ID, 768), dtype=torch.float16, device=device)
1168
  if all_neg_prompt_embs[i] is None:
1169
  all_neg_prompt_embs[i] = torch.zeros((BS, N_ID, 768), dtype=torch.float16, device=device)
 
1185
  # So its .device is the device of its parameters.
1186
  device = self.id2ada_prompt_encoders[0].clip_image_encoder.device
1187
  is_emb_averaged = kwargs.get('avg_at_stage', None) is not None
1188
+ if kwargs.get('enable_static_img_suffix_embs', None) is None:
1189
+ enable_static_img_suffix_embs = self.default_enable_static_img_suffix_embs
1190
+ else:
1191
+ enable_static_img_suffix_embs = kwargs['enable_static_img_suffix_embs']
1192
+ if isinstance(enable_static_img_suffix_embs, bool):
1193
+ enable_static_img_suffix_embs = [enable_static_img_suffix_embs] * self.num_sub_encoders
1194
+
1195
  BS = -1
1196
 
1197
  if face_id_embs is not None:
 
1199
  all_face_id_embs = face_id_embs.split(self.face_id_dims, dim=1)
1200
  else:
1201
  all_face_id_embs = [None] * self.num_sub_encoders
1202
+
1203
  if img_prompt_embs is not None:
1204
  BS = img_prompt_embs.shape[0] if BS == -1 else BS
1205
+ if img_prompt_embs.shape[1] != self.num_id_vecs0:
1206
  breakpoint()
1207
+ all_img_prompt_embs = img_prompt_embs.split(self.encoders_num_id_vecs0, dim=1)
1208
+ img_prompt_embs_provided = True
1209
  else:
1210
  all_img_prompt_embs = [None] * self.num_sub_encoders
1211
+ img_prompt_embs_provided = False
1212
+
1213
  if image_paths is not None:
1214
  BS = len(image_paths) if BS == -1 else BS
1215
  if BS == -1:
 
1232
  else:
1233
  are_encoders_enabled = self.are_encoders_enabled
1234
 
1235
+ self.curr_are_encoders_enabled = are_encoders_enabled
1236
  all_adaface_subj_embs = []
1237
  num_available_id_vecs = 0
1238
+ lens_subj_emb_segments = []
1239
 
1240
  for i, id2ada_prompt_encoder in enumerate(self.id2ada_prompt_encoders):
1241
  if not are_encoders_enabled[i]:
1242
  adaface_subj_embs = None
1243
+ print(f"Encoder {id2ada_prompt_encoder.name} is disabled.")
1244
+ N_ID = id2ada_prompt_encoder.num_id_vecs + enable_static_img_suffix_embs[i] \
1245
+ * id2ada_prompt_encoder.num_static_img_suffix_embs
1246
  else:
1247
+ kwargs['enable_static_img_suffix_embs'] = enable_static_img_suffix_embs[i]
1248
  # ddpm.embedding_manager.train() -> id2ada_prompt_encoder.train() -> each sub-encoder's train().
1249
  # -> each sub-encoder's subj_basis_generator.train().
1250
  # Therefore grad for the following call is enabled.
1251
+ adaface_subj_embs, img_prompt_embs, encoder_lens_subj_emb_segments = \
1252
  id2ada_prompt_encoder.generate_adaface_embeddings(image_paths,
1253
  all_face_id_embs[i],
1254
  all_img_prompt_embs[i],
1255
  *args, **kwargs)
1256
 
1257
+ # adaface_subj_embs: arc2face [16, 768] or consistentID [4, 768],
1258
+ # or arc2face [20, 768] or consistentID [8, 768] if enable_static_img_suffix_embs=True.
1259
+ N_ID = encoder_lens_subj_emb_segments[0]
1260
+
1261
  if adaface_subj_embs is None:
1262
  if not return_zero_embs_for_dropped_encoders:
1263
  continue
 
1268
  all_adaface_subj_embs.append(adaface_subj_embs)
1269
  else:
1270
  all_adaface_subj_embs.append(adaface_subj_embs)
1271
+ if not img_prompt_embs_provided:
1272
+ all_img_prompt_embs[i] = img_prompt_embs
1273
  num_available_id_vecs += N_ID
1274
+
1275
+ lens_subj_emb_segments.append(N_ID)
1276
+
1277
  # No faces are found in the images, so return None embeddings.
1278
  # We don't want to return an all-zero embedding, which is useless.
1279
  if num_available_id_vecs == 0:
1280
+ return None, [0]
1281
 
1282
  # If id2ada_prompt_encoders are ["arc2face", "consistentID"], then
1283
  # during inference, we average across the batch dim.
 
1287
  # all_adaface_subj_embs[0]: [BS, 4, 768]. all_adaface_subj_embs[1]: [BS, 16, 768].
1288
  # all_adaface_subj_embs: [BS, 20, 768].
1289
  all_adaface_subj_embs = torch.cat(all_adaface_subj_embs, dim=-2)
1290
+ # Check if some of the img_prompt_embs are None.
1291
+ if None in all_img_prompt_embs:
1292
+ all_img_prompt_embs = None
1293
+ else:
1294
+ all_img_prompt_embs = torch.cat(all_img_prompt_embs, dim=-2)
1295
+ return all_adaface_subj_embs, all_img_prompt_embs, lens_subj_emb_segments
1296
 
1297
 
1298
  '''
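One consequence of the changes in this file: generate_adaface_embeddings() no longer returns only the subject embeddings; the sub-encoders and Joint_FaceID2AdaPrompt now also hand back the image-prompt embeddings and the per-encoder segment lengths. A minimal usage sketch, assuming the constructor and method signatures shown in this diff; the module path, checkpoint path and image path are placeholders:

# Hypothetical usage sketch; additional keyword arguments may be required in practice.
from adaface.face_id_to_ada_prompt import Joint_FaceID2AdaPrompt  # assumed module path

encoder = Joint_FaceID2AdaPrompt(
    adaface_encoder_types=["consistentID", "arc2face"],
    adaface_ckpt_paths=["models/adaface/jointids.pt"],   # placeholder ckpt path
)

# generate_adaface_embeddings() now returns a 3-tuple rather than a single tensor:
# the subject embeddings, the underlying image-prompt embeddings, and the
# per-encoder segment lengths of the subject embeddings.
subj_embs, img_prompt_embs, seg_lens = encoder.generate_adaface_embeddings(
    image_paths=["subject.jpg"],   # placeholder input image
    avg_at_stage="id_emb",
)
# With static image-suffix embeddings disabled (the default in this diff),
# seg_lens is expected to be [4, 16] (4 ConsistentID vectors plus 16 Arc2Face vectors),
# and subj_embs stacks to 20 vectors of width 768.
# (If no face is detected, the method returns early with None embeddings instead.)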
adaface/subj_basis_generator.py CHANGED
@@ -9,7 +9,7 @@ import torch
9
  from torch import nn
10
  from einops import rearrange
11
  from einops.layers.torch import Rearrange
12
- from transformers import CLIPTokenizer, CLIPTextModel, CLIPTextConfig
13
 
14
  from torch import einsum
15
  from adaface.util import gen_gradient_scaler
@@ -57,7 +57,25 @@ class IP_MLPProjModel(nn.Module):
57
  x = x.reshape(-1, self.num_tokens, self.cross_attention_dim)
58
  x = self.norm(x)
59
  return x
60
-
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
61
  # group_dim: the tensor dimension that corresponds to the multiple groups.
62
  class LearnedSoftAggregate(nn.Module):
63
  def __init__(self, num_feat, group_dim, keepdim=False):
@@ -349,23 +367,26 @@ class CrossAttention(nn.Module):
349
  else:
350
  return out
351
 
 
352
  class ImgPrompt2TextPrompt(nn.Module):
353
- def __init__(self, placeholder_is_bg, num_id_vecs, dtype=torch.float32, *args, **kwargs):
 
354
  super().__init__()
355
  self.N_ID = num_id_vecs
356
  # If not placeholder_is_bg, then N_SFX will be updated in initialize_text_components().
357
  self.N_SFX = 0
 
358
 
359
  if not placeholder_is_bg:
360
- self.initialize_text_components(*args, **kwargs)
 
361
 
362
  # prompt2token_proj: arc2face_models.py CLIPTextModelWrapper instance with **custom weights**.
363
  # prompt2token_proj is with the same architecture as the original arc2face text encoder,
364
  # but retrained to do inverse mapping.
365
  # To be initialized in the subclass.
366
  self.prompt2token_proj = None
367
- self.dtype = dtype
368
-
369
  def initialize_static_img_suffix_embs(self, num_static_img_suffix_embs, img_prompt_dim=768):
370
  self.N_SFX = num_static_img_suffix_embs
371
  # We always take the first num_static_img_suffix_embs embeddings out of static_img_suffix_embs.
@@ -376,11 +397,11 @@ class ImgPrompt2TextPrompt(nn.Module):
376
  print(f"static_img_suffix_embs had been initialized to be {self.static_img_suffix_embs.shape[1]} vecs ({self.N_SFX} required). Skip initialization.")
377
  elif self.static_img_suffix_embs.shape[1] < self.N_SFX:
378
  print(f"static_img_suffix_embs had been initialized to be {self.static_img_suffix_embs.shape[1]} vecs (< {self.N_SFX} required). Reinitialize.")
379
- self.static_img_suffix_embs = nn.Parameter(torch.randn(1, self.N_SFX, img_prompt_dim))
380
  elif self.N_SFX > 0:
381
  # self.static_img_suffix_embs.shape[1] > self.N_SFX > 0.
382
  print(f"static_img_suffix_embs had been initialized to be {self.static_img_suffix_embs.shape[1]} vecs (> {self.N_SFX} required). Truncate.")
383
- self.static_img_suffix_embs = nn.Parameter(self.static_img_suffix_embs[:, :self.N_SFX])
384
  else:
385
  # self.static_img_suffix_embs.shape[1] > self.N_SFX == 0.
386
  print(f"static_img_suffix_embs had been initialized to be {self.static_img_suffix_embs.shape[1]} vecs (0 required). Erase.")
@@ -391,7 +412,7 @@ class ImgPrompt2TextPrompt(nn.Module):
391
  # or it's initialized but has fewer than num_static_img_suffix_embs embeddings (this situation should be very rare,
392
  # so we don't consider to reuse and extend a shorter static_img_suffix_embs).
393
  # So we reinitialize it.
394
- self.static_img_suffix_embs = nn.Parameter(torch.randn(1, self.N_SFX, img_prompt_dim))
395
  else:
396
  # If static_img_suffix_embs had been initialized, then it will be set to None, i.e., erased from the SubjBasisGenerator instance.
397
  self.static_img_suffix_embs = None
@@ -399,9 +420,7 @@ class ImgPrompt2TextPrompt(nn.Module):
399
  # Implement a separate initialization function, so that it can be called from SubjBasisGenerator
400
  # after the SubjBasisGenerator is initialized. This can be used to fix old SubjBasisGenerator
401
  # ckpts which were not subclassed from ImgPrompt2TextPrompt.
402
- def initialize_text_components(self, max_prompt_length=77, num_id_vecs=16,
403
- num_static_img_suffix_embs=0, img_prompt_dim=768):
404
- self.initialize_static_img_suffix_embs(num_static_img_suffix_embs, img_prompt_dim)
405
  self.max_prompt_length = max_prompt_length
406
  self.tokenizer = CLIPTokenizer.from_pretrained("openai/clip-vit-large-patch14")
407
  # clip_text_embeddings: CLIPTextEmbeddings instance.
@@ -416,7 +435,7 @@ class ImgPrompt2TextPrompt(nn.Module):
416
  # pad_embeddings is still on CPU. But should be moved to GPU automatically.
417
  # Note: detach pad_embeddings from the computation graph, otherwise
418
  # deepcopy() in embedding_manager.py:make_frozen_copy_of_subj_basis_generators() will fail.
419
- self.pad_embeddings = clip_text_embeddings(pad_tokens)[0].detach()
420
 
421
  # image prompt space -> text prompt space.
422
  # return_emb_types: a list of strings, each string is among
@@ -439,7 +458,7 @@ class ImgPrompt2TextPrompt(nn.Module):
439
  else:
440
  breakpoint()
441
  else:
442
- # len(face_prompt_embs) == 1, this occurs when same_subject_in_batch == True, e.g. in do_comp_prompt_distillation.
443
  # But list_extra_words always corresponds to the actual batch size. So we only take the first element.
444
  list_extra_words = list_extra_words[:1]
445
 
@@ -466,7 +485,7 @@ class ImgPrompt2TextPrompt(nn.Module):
466
  face_prompt_embs_orig_dtype = face_prompt_embs.dtype
467
  face_prompt_embs = face_prompt_embs.to(self.dtype)
468
 
469
- ID_END = 4 + self.N_ID
470
  PAD_BEGIN = ID_END + self.N_SFX + 2
471
 
472
  # token_embs: [1, 77, 768]. This call is only to get the template token embeddings (the shallowest mapping).
@@ -545,6 +564,7 @@ class ImgPrompt2TextPrompt(nn.Module):
545
  class SubjBasisGenerator(ImgPrompt2TextPrompt):
546
  def __init__(
547
  self,
 
548
  # number of cross-attention heads of the bg prompt translator.
549
  # Taken as a half of the number of heads 12 of OpenAI clip-vit-large-patch14:
550
  # https://huggingface.co/openai/clip-vit-large-patch14/blob/main/config.json
@@ -553,22 +573,25 @@ class SubjBasisGenerator(ImgPrompt2TextPrompt):
553
  # or number of background input identity vectors (no matter the subject is face or not).
554
  # 257: 257 CLIP tokens.
555
  num_nonface_in_id_vecs={ 'subj': 77, 'bg': 257 },
 
556
  num_id_vecs=16, # num_id_vecs: subj: 16. bg: 4.
557
  num_static_img_suffix_embs: int = 0, # Number of extra static learnable image embeddings appended to translated ID embeddings.
558
  bg_image_embedding_dim=1024, # CLIP image hidden layer feature dimension, as per config.json above.
559
  obj_embedding_dim=384, # DINO object feature dimension for objects.
560
  output_dim=768, # CLIP text embedding input dimension.
 
561
  placeholder_is_bg: bool = False, # Whether the placeholder is for the image background tokens.
562
- prompt2token_proj_grad_scale: float = 0.4, # Gradient scale for prompt2token_proj.
563
  learnable_hidden_state_weights_scheme: str = 'per-layer', # none, per-layer.
564
- bg_prompt_translator_has_to_out_proj: bool = False, # Whether the prompt_trans_layers have a to_out projection.
565
  ):
566
 
567
  # If not placeholder_is_bg, then it calls initialize_text_components() in the superclass.
568
- super().__init__(placeholder_is_bg=placeholder_is_bg, num_id_vecs=num_id_vecs, max_prompt_length=77,
569
- num_static_img_suffix_embs=num_static_img_suffix_embs, img_prompt_dim=output_dim)
 
570
 
571
  self.placeholder_is_bg = placeholder_is_bg
 
572
  self.num_out_embs = self.N_ID + self.N_SFX
573
  self.output_dim = output_dim
574
  # num_nonface_in_id_vecs should be the number of core ID embs, 16.
@@ -586,14 +609,18 @@ class SubjBasisGenerator(ImgPrompt2TextPrompt):
586
  # self.prompt2token_proj: [1, 16, 768] -> [1, 77, 768] (with paddings) or [1, 16, 768] (without paddings).
587
  # If self.placeholder_is_bg: prompt2token_proj is set to None.
588
  # Use an attention dropout of 0.2 to increase robustness.
589
- clip_dropout_config = None #CLIPTextConfig.from_pretrained('openai/clip-vit-large-patch14', attention_dropout=0.05, dropout=0.05)
590
- self.prompt2token_proj = CLIPTextModelWrapper.from_pretrained('openai/clip-vit-large-patch14',
591
- config=clip_dropout_config)
592
- self.prompt2token_proj_grad_scale = prompt2token_proj_grad_scale
593
- self.prompt2token_proj_grad_scaler = gen_gradient_scaler(prompt2token_proj_grad_scale)
594
- print(f"Subj prompt2token_proj initialized with grad scale of {prompt2token_proj_grad_scale}.")
595
- # If prompt2token_proj_grad_scale is 0, freeze all params in prompt2token_proj.
596
- # Otherwise, only freeze token and positional embeddings of the original CLIPTextModel.
 
 
 
 
597
  self.freeze_prompt2token_proj()
598
 
599
  # These multipliers are relative to the original CLIPTextModel.
@@ -631,6 +658,9 @@ class SubjBasisGenerator(ImgPrompt2TextPrompt):
631
  identity_to_out=identity_to_out,
632
  out_has_skip=out_has_skip)
633
 
 
 
 
634
  self.output_scale = output_dim ** -0.5
635
 
636
  '''
@@ -686,21 +716,20 @@ class SubjBasisGenerator(ImgPrompt2TextPrompt):
686
  hidden_state_layer_weights = self.hidden_state_layer_weights_grad_scaler(self.hidden_state_layer_weights)
687
 
688
  # faceid2img_prompt_embs -> ada_id_embs: image prompt space -> text prompt space.
689
- with torch.set_grad_enabled(self.training and self.prompt2token_proj_grad_scale != 0):
690
- # If list_extra_words is not None, then ada_id_embs: [BS, 18, 768], three leading words, the 16 identity tokens
691
- # and (at most) two extra words in adaface_prompt_embs, without BOS and EOS.
692
- # If list_extra_words is None, then ada_id_embs: [BS, 16, 768], the 16 identity tokens in adaface_prompt_embs.
693
- # hidden_state_layer_weights: [[0.9163], [0.9483], [2.0762]]
694
- # ada_id_embs: [BS, 16, 768].
695
- # return_emb_types: a list of strings, each string is among
696
- # ['full', 'core', 'full_pad', 'full_half_pad'].
697
- ada_id_embs, = \
698
- self.inverse_img_prompt_embs(faceid2img_prompt_embs,
699
- list_extra_words=None,
700
- return_emb_types=['core'],
701
- hidden_state_layer_weights=hidden_state_layer_weights,
702
- enable_static_img_suffix_embs=enable_static_img_suffix_embs)
703
- ada_id_embs = self.prompt2token_proj_grad_scaler(ada_id_embs)
704
  elif raw_id_embs is not None:
705
  # id_embs: [BS, 384] -> [BS, 18, 768].
706
  # obj_proj_in is expected to project the DINO object features to
@@ -726,14 +755,15 @@ class SubjBasisGenerator(ImgPrompt2TextPrompt):
726
 
727
  adaface_out_embs = id_embs_out * self.output_scale # * 0.036
728
  else:
729
- adaface_out_embs = ada_id_embs
 
730
  # If out_id_embs_cfg_scale < 1, adaface_out_embs is a mix of adaface_out_embs and pad_embeddings.
731
  if out_id_embs_cfg_scale != 1:
732
- # pad_embeddings: [77, 768] -> [16, 768] -> [1, 16, 768].
733
  # NOTE: Never do cfg on static image suffix embeddings.
734
  # So we take self.N_ID embeddings, instead of self.N_ID + self.N_SFX,
735
  # even if enable_static_img_suffix_embs=True.
736
- pad_embeddings = self.pad_embeddings[4:4+self.N_ID].unsqueeze(0).to(ada_id_embs.device)
737
  adaface_out_embs[:, :self.N_ID] = ada_id_embs[:, :self.N_ID] * out_id_embs_cfg_scale \
738
  + pad_embeddings * (1 - out_id_embs_cfg_scale)
739
 
@@ -812,37 +842,37 @@ class SubjBasisGenerator(ImgPrompt2TextPrompt):
812
  # Only applicable to fg basis generator.
813
  if self.placeholder_is_bg:
814
  return
815
- # If bg, then prompt2token_proj is set to None. Therefore no need to freeze it.
816
- # Then we don't have to check whether it's for subj or bg.
817
- if self.prompt2token_proj_grad_scale == 0:
818
- frozen_components_name = 'all'
819
- frozen_param_set = self.prompt2token_proj.named_parameters()
820
- else:
821
- frozen_components_name = 'token_pos_embeddings'
822
- frozen_param_set = self.prompt2token_proj.text_model.embeddings.named_parameters()
823
-
824
  if self.prompt2token_proj is not None:
825
  frozen_param_names = []
826
- for param_name, param in frozen_param_set:
827
  if param.requires_grad:
828
  param.requires_grad = False
829
  frozen_param_names.append(param_name)
830
  # If param is already frozen, then no need to freeze it again.
831
- print(f"{frozen_components_name} {len(frozen_param_names)} params in Subj prompt2token_proj is frozen.")
832
  #print(f"Frozen parameters:\n{frozen_param_names}")
833
 
834
  def patch_old_subj_basis_generator_ckpt(self):
835
  # Fix compatibility with the previous version.
836
  if not hasattr(self, 'bg_prompt_translator_has_to_out_proj'):
837
  self.bg_prompt_translator_has_to_out_proj = False
838
- if not hasattr(self, 'num_out_embs'):
839
- self.num_out_embs = -1
840
  if hasattr(self, 'num_id_vecs') and not hasattr(self, 'N_ID'):
841
  self.N_ID = self.num_id_vecs
 
 
 
842
  if not hasattr(self, 'num_nonface_in_id_vecs') and hasattr(self, 'N_ID'):
843
  self.num_nonface_in_id_vecs = self.N_ID
844
  if not hasattr(self, 'dtype'):
845
- self.dtype = torch.float32
 
 
 
 
 
 
 
846
 
847
  if self.placeholder_is_bg:
848
  if not hasattr(self, 'pos_embs') or self.pos_embs is None:
@@ -860,6 +890,14 @@ class SubjBasisGenerator(ImgPrompt2TextPrompt):
860
  num_static_img_suffix_embs=self.N_SFX,
861
  img_prompt_dim=self.output_dim)
862
 
 
 
 
 
 
 
 
 
863
  def __repr__(self):
864
  type_sig = 'subj' if not self.placeholder_is_bg else 'bg'
865
 
 
9
  from torch import nn
10
  from einops import rearrange
11
  from einops.layers.torch import Rearrange
12
+ from transformers import CLIPTokenizer, CLIPTextModel
13
 
14
  from torch import einsum
15
  from adaface.util import gen_gradient_scaler
 
57
  x = x.reshape(-1, self.num_tokens, self.cross_attention_dim)
58
  x = self.norm(x)
59
  return x
60
+
61
+ class LayerwiseMLPProjWithSkip(nn.Module):
62
+ def __init__(self, id_embeddings_dim=768, num_layers=16, dim_mult=2):
63
+ super().__init__()
64
+
65
+ self.proj = nn.Sequential(
66
+ nn.Linear(id_embeddings_dim, id_embeddings_dim*dim_mult*num_layers),
67
+ Rearrange('b n (l d) -> b n l d', l=num_layers, d=id_embeddings_dim*dim_mult),
68
+ nn.GELU(),
69
+ nn.Linear(id_embeddings_dim*dim_mult, id_embeddings_dim),
70
+ )
71
+ self.norm = nn.LayerNorm(id_embeddings_dim)
72
+
73
+ def forward(self, id_embeds):
74
+ # B N D -> B N L D + B N L D -> B N L D
75
+ x = self.proj(id_embeds) + id_embeds.unsqueeze(1)
76
+ x = self.norm(x)
77
+ return x
78
+
79
  # group_dim: the tensor dimension that corresponds to the multiple groups.
80
  class LearnedSoftAggregate(nn.Module):
81
  def __init__(self, num_feat, group_dim, keepdim=False):
 
367
  else:
368
  return out
369
 
370
+
371
  class ImgPrompt2TextPrompt(nn.Module):
372
+ def __init__(self, placeholder_is_bg, num_id_vecs, num_static_img_suffix_embs,
373
+ max_prompt_length=77, img_prompt_dim=768, dtype=torch.float16):
374
  super().__init__()
375
  self.N_ID = num_id_vecs
376
  # If not placeholder_is_bg, then N_SFX will be updated in initialize_text_components().
377
  self.N_SFX = 0
378
+ self.dtype = dtype
379
 
380
  if not placeholder_is_bg:
381
+ self.initialize_static_img_suffix_embs(num_static_img_suffix_embs, img_prompt_dim)
382
+ self.initialize_text_components(max_prompt_length)
383
 
384
  # prompt2token_proj: arc2face_models.py CLIPTextModelWrapper instance with **custom weights**.
385
  # prompt2token_proj is with the same architecture as the original arc2face text encoder,
386
  # but retrained to do inverse mapping.
387
  # To be initialized in the subclass.
388
  self.prompt2token_proj = None
389
+
 
390
  def initialize_static_img_suffix_embs(self, num_static_img_suffix_embs, img_prompt_dim=768):
391
  self.N_SFX = num_static_img_suffix_embs
392
  # We always take the first num_static_img_suffix_embs embeddings out of static_img_suffix_embs.
 
397
  print(f"static_img_suffix_embs had been initialized to be {self.static_img_suffix_embs.shape[1]} vecs ({self.N_SFX} required). Skip initialization.")
398
  elif self.static_img_suffix_embs.shape[1] < self.N_SFX:
399
  print(f"static_img_suffix_embs had been initialized to be {self.static_img_suffix_embs.shape[1]} vecs (< {self.N_SFX} required). Reinitialize.")
400
+ self.static_img_suffix_embs = nn.Parameter(torch.randn(1, self.N_SFX, img_prompt_dim, dtype=self.dtype))
401
  elif self.N_SFX > 0:
402
  # self.static_img_suffix_embs.shape[1] > self.N_SFX > 0.
403
  print(f"static_img_suffix_embs had been initialized to be {self.static_img_suffix_embs.shape[1]} vecs (> {self.N_SFX} required). Truncate.")
404
+ self.static_img_suffix_embs = nn.Parameter(self.static_img_suffix_embs[:, :self.N_SFX].to(dtype=self.dtype))
405
  else:
406
  # self.static_img_suffix_embs.shape[1] > self.N_SFX == 0.
407
  print(f"static_img_suffix_embs had been initialized to be {self.static_img_suffix_embs.shape[1]} vecs (0 required). Erase.")
 
412
  # or it's initialized but has fewer than num_static_img_suffix_embs embeddings (this situation should be very rare,
413
  # so we don't consider to reuse and extend a shorter static_img_suffix_embs).
414
  # So we reinitialize it.
415
+ self.static_img_suffix_embs = nn.Parameter(torch.randn(1, self.N_SFX, img_prompt_dim, dtype=self.dtype))
416
  else:
417
  # If static_img_suffix_embs had been initialized, then it will be set to None, i.e., erased from the SubjBasisGenerator instance.
418
  self.static_img_suffix_embs = None
 
420
  # Implement a separate initialization function, so that it can be called from SubjBasisGenerator
421
  # after the SubjBasisGenerator is initialized. This can be used to fix old SubjBasisGenerator
422
  # ckpts which were not subclassed from ImgPrompt2TextPrompt.
423
+ def initialize_text_components(self, max_prompt_length=77):
 
 
424
  self.max_prompt_length = max_prompt_length
425
  self.tokenizer = CLIPTokenizer.from_pretrained("openai/clip-vit-large-patch14")
426
  # clip_text_embeddings: CLIPTextEmbeddings instance.
 
435
  # pad_embeddings is still on CPU. But should be moved to GPU automatically.
436
  # Note: detach pad_embeddings from the computation graph, otherwise
437
  # deepcopy() in embedding_manager.py:make_frozen_copy_of_subj_basis_generators() will fail.
438
+ self.pad_embeddings = clip_text_embeddings(pad_tokens)[0].detach().to(self.dtype)
439
 
440
  # image prompt space -> text prompt space.
441
  # return_emb_types: a list of strings, each string is among
 
458
  else:
459
  breakpoint()
460
  else:
461
+ # len(face_prompt_embs) == 1, this occurs when same_subject_in_batch == True, e.g. in do_feat_distill_on_comp_prompt.
462
  # But list_extra_words always corresponds to the actual batch size. So we only take the first element.
463
  list_extra_words = list_extra_words[:1]
464
 
 
485
  face_prompt_embs_orig_dtype = face_prompt_embs.dtype
486
  face_prompt_embs = face_prompt_embs.to(self.dtype)
487
 
488
+ ID_END = 4 + self.N_ID
489
  PAD_BEGIN = ID_END + self.N_SFX + 2
490
 
491
  # token_embs: [1, 77, 768]. This call is only to get the template token embeddings (the shallowest mapping).
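The ID_END / PAD_BEGIN arithmetic above encodes the assumed token layout of the inverse-mapped prompt. A small worked example with the values used elsewhere in this diff (the identity of the 4 leading tokens and the 2 trailing tokens is not spelled out in this hunk, so treat this only as an illustration of the indexing):

# Worked example of the index arithmetic above (illustration only).
N_ID, N_SFX = 16, 4              # e.g. Arc2Face ID vectors plus 4 static suffix embeddings
ID_END = 4 + N_ID                # identity tokens occupy positions [4, 20)
PAD_BEGIN = ID_END + N_SFX + 2   # suffix tokens, then 2 more tokens, then padding from 26
assert (ID_END, PAD_BEGIN) == (20, 26)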
 
564
  class SubjBasisGenerator(ImgPrompt2TextPrompt):
565
  def __init__(
566
  self,
567
+ dtype=torch.float16,
568
  # number of cross-attention heads of the bg prompt translator.
569
  # Taken as a half of the number of heads 12 of OpenAI clip-vit-large-patch14:
570
  # https://huggingface.co/openai/clip-vit-large-patch14/blob/main/config.json
 
573
  # or number of background input identity vectors (no matter the subject is face or not).
574
  # 257: 257 CLIP tokens.
575
  num_nonface_in_id_vecs={ 'subj': 77, 'bg': 257 },
576
+ num_ca_layers=16,
577
  num_id_vecs=16, # num_id_vecs: subj: 16. bg: 4.
578
  num_static_img_suffix_embs: int = 0, # Number of extra static learnable image embeddings appended to translated ID embeddings.
579
  bg_image_embedding_dim=1024, # CLIP image hidden layer feature dimension, as per config.json above.
580
  obj_embedding_dim=384, # DINO object feature dimension for objects.
581
  output_dim=768, # CLIP text embedding input dimension.
582
+ use_layerwise_proj: bool = False, # Whether to use layerwise projection.
583
  placeholder_is_bg: bool = False, # Whether the placeholder is for the image background tokens.
 
584
  learnable_hidden_state_weights_scheme: str = 'per-layer', # none, per-layer.
585
+ bg_prompt_translator_has_to_out_proj: bool = False, # Whether the prompt_trans_layers have a to_out projection.
586
  ):
587
 
588
  # If not placeholder_is_bg, then it calls initialize_text_components() in the superclass.
589
+ super().__init__(placeholder_is_bg=placeholder_is_bg, num_id_vecs=num_id_vecs,
590
+ num_static_img_suffix_embs=num_static_img_suffix_embs,
591
+ max_prompt_length=77, img_prompt_dim=output_dim, dtype=dtype)
592
 
593
  self.placeholder_is_bg = placeholder_is_bg
594
+ self.num_ca_layers = num_ca_layers
595
  self.num_out_embs = self.N_ID + self.N_SFX
596
  self.output_dim = output_dim
597
  # num_nonface_in_id_vecs should be the number of core ID embs, 16.
 
609
  # self.prompt2token_proj: [1, 16, 768] -> [1, 77, 768] (with paddings) or [1, 16, 768] (without paddings).
610
  # If self.placeholder_is_bg: prompt2token_proj is set to None.
611
  # Use an attention dropout of 0.2 to increase robustness.
612
+ self.prompt2token_proj = CLIPTextModelWrapper.from_pretrained('openai/clip-vit-large-patch14')
613
+ self.prompt2token_proj.to(dtype=self.dtype)
614
+
615
+ if use_layerwise_proj:
616
+ # MLPProjWithSkip: MLP with skip connection.
617
+ # [BS, 4, 768] -> [BS, 16, 4, 768]. Extra 16: 16 layers.
618
+ self.layerwise_proj = LayerwiseMLPProjWithSkip(output_dim, dim_mult=2)
619
+ else:
620
+ self.layerwise_proj = nn.Identity() #Rearrange('b n d -> b l n d', l=16)
621
+
622
+ print(f"Subj prompt2token_proj initialized.")
623
+ # Only freeze token and positional embeddings of the original CLIPTextModel.
624
  self.freeze_prompt2token_proj()
625
 
626
  # These multipliers are relative to the original CLIPTextModel.
 
658
  identity_to_out=identity_to_out,
659
  out_has_skip=out_has_skip)
660
 
661
+ if self.dtype == torch.float16:
662
+ self.prompt_translator = self.prompt_translator.half()
663
+
664
  self.output_scale = output_dim ** -0.5
665
 
666
  '''
 
716
  hidden_state_layer_weights = self.hidden_state_layer_weights_grad_scaler(self.hidden_state_layer_weights)
717
 
718
  # faceid2img_prompt_embs -> ada_id_embs: image prompt space -> text prompt space.
719
+ # inverse_img_prompt_embs() applies self.prompt2token_proj to faceid2img_prompt_embs.
720
+ # If list_extra_words is not None, then ada_id_embs: [BS, 18, 768], three leading words, the 16 identity tokens
721
+ # and (at most) two extra words in adaface_prompt_embs, without BOS and EOS.
722
+ # If list_extra_words is None, then ada_id_embs: [BS, 16, 768], the 16 identity tokens in adaface_prompt_embs.
723
+ # hidden_state_layer_weights: [[0.9163], [0.9483], [2.0762]]
724
+ # ada_id_embs: [BS, 16, 768].
725
+ # return_emb_types: a list of strings, each string is among
726
+ # ['full', 'core', 'full_pad', 'full_half_pad'].
727
+ ada_id_embs, = \
728
+ self.inverse_img_prompt_embs(faceid2img_prompt_embs,
729
+ list_extra_words=None,
730
+ return_emb_types=['core'],
731
+ hidden_state_layer_weights=hidden_state_layer_weights,
732
+ enable_static_img_suffix_embs=enable_static_img_suffix_embs)
 
733
  elif raw_id_embs is not None:
734
  # id_embs: [BS, 384] -> [BS, 18, 768].
735
  # obj_proj_in is expected to project the DINO object features to
 
755
 
756
  adaface_out_embs = id_embs_out * self.output_scale # * 0.036
757
  else:
758
+ # [BS, 16, 768] -> [BS, layers=16, tokens=16, 768]
759
+ adaface_out_embs = self.layerwise_proj(ada_id_embs)
760
  # If out_id_embs_cfg_scale < 1, adaface_out_embs is a mix of adaface_out_embs and pad_embeddings.
761
  if out_id_embs_cfg_scale != 1:
762
+ # pad_embeddings: [77, 768] -> [16, 768] -> [1, 1, 16, 768].
763
  # NOTE: Never do cfg on static image suffix embeddings.
764
  # So we take self.N_ID embeddings, instead of self.N_ID + self.N_SFX,
765
  # even if enable_static_img_suffix_embs=True.
766
+ pad_embeddings = self.pad_embeddings[4:4+self.N_ID].unsqueeze(0).unsqueeze(1).to(ada_id_embs.device)
767
  adaface_out_embs[:, :self.N_ID] = ada_id_embs[:, :self.N_ID] * out_id_embs_cfg_scale \
768
  + pad_embeddings * (1 - out_id_embs_cfg_scale)
769
 
 
842
  # Only applicable to fg basis generator.
843
  if self.placeholder_is_bg:
844
  return
845
+
 
 
 
 
 
 
 
 
846
  if self.prompt2token_proj is not None:
847
  frozen_param_names = []
848
+ for param_name, param in self.prompt2token_proj.text_model.embeddings.named_parameters():
849
  if param.requires_grad:
850
  param.requires_grad = False
851
  frozen_param_names.append(param_name)
852
  # If param is already frozen, then no need to freeze it again.
853
+ print(f"{len(frozen_param_names)} params of token_pos_embeddings in Subj prompt2token_proj are frozen.")
854
  #print(f"Frozen parameters:\n{frozen_param_names}")
855
 
856
  def patch_old_subj_basis_generator_ckpt(self):
857
  # Fix compatibility with the previous version.
858
  if not hasattr(self, 'bg_prompt_translator_has_to_out_proj'):
859
  self.bg_prompt_translator_has_to_out_proj = False
 
 
860
  if hasattr(self, 'num_id_vecs') and not hasattr(self, 'N_ID'):
861
  self.N_ID = self.num_id_vecs
862
+ # Update the number of output embeddings.
863
+ self.num_out_embs = self.N_ID + self.N_SFX
864
+
865
  if not hasattr(self, 'num_nonface_in_id_vecs') and hasattr(self, 'N_ID'):
866
  self.num_nonface_in_id_vecs = self.N_ID
867
  if not hasattr(self, 'dtype'):
868
+ self.dtype = torch.float16
869
+ if not self.placeholder_is_bg:
870
+ self.prompt2token_proj.to(dtype=self.dtype)
871
+ else:
872
+ self.prompt_translator.half()
873
+
874
+ if not hasattr(self, 'num_ca_layers'):
875
+ self.num_ca_layers = 16
876
 
877
  if self.placeholder_is_bg:
878
  if not hasattr(self, 'pos_embs') or self.pos_embs is None:
 
890
  num_static_img_suffix_embs=self.N_SFX,
891
  img_prompt_dim=self.output_dim)
892
 
893
+ if not hasattr(self, 'use_layerwise_proj'):
894
+ self.use_layerwise_proj = False
895
+ if not hasattr(self, 'layerwise_proj'):
896
+ if self.use_layerwise_proj:
897
+ self.layerwise_proj = LayerwiseMLPProjWithSkip(self.output_dim, dim_mult=2)
898
+ else:
899
+ self.layerwise_proj = nn.Identity()
900
+
901
  def __repr__(self):
902
  type_sig = 'subj' if not self.placeholder_is_bg else 'bg'
903
 
adaface/unet_teachers.py CHANGED
@@ -1,6 +1,6 @@
1
  import torch
 
2
  import numpy as np
3
- import pytorch_lightning as pl
4
  from diffusers import UNet2DConditionModel
5
  from adaface.util import UNetEnsemble, create_consistentid_pipeline
6
  from diffusers import UNet2DConditionModel
@@ -12,9 +12,9 @@ def create_unet_teacher(teacher_type, device='cpu', **kwargs):
12
  teacher_type = teacher_type[0]
13
 
14
  if teacher_type == "arc2face":
15
- return Arc2FaceTeacher(**kwargs)
16
  elif teacher_type == "unet_ensemble":
17
- # unet, extra_unet_dirpaths and unet_weights are passed in kwargs.
18
  # Even if we distill from unet_ensemble, we still need to load arc2face for generating
19
  # arc2face embeddings.
20
  # The first (optional) ctor param of UNetEnsembleTeacher is an instantiated unet,
@@ -22,20 +22,24 @@ def create_unet_teacher(teacher_type, device='cpu', **kwargs):
22
  # However, since the __call__ method of the ddpm unet takes different formats of params,
23
  # for simplicity, we still use the diffusers unet.
24
  # unet_teacher is put on CPU first, then moved to GPU when DDPM is moved to GPU.
25
- return UNetEnsembleTeacher(device=device, **kwargs)
26
  elif teacher_type == "consistentID":
27
- return ConsistentIDTeacher(**kwargs)
28
  elif teacher_type == "simple_unet":
29
- return SimpleUNetTeacher(**kwargs)
30
  # Since we've dereferenced the list if it has only one element,
31
  # the condition holding here implies the list has more than one element. Therefore it's UNetEnsembleTeacher.
32
  elif isinstance(teacher_type, (tuple, list, ListConfig)):
33
  # teacher_type is a list of teacher types. So it's UNetEnsembleTeacher.
34
- return UNetEnsembleTeacher(unet_types=teacher_type, device=device, **kwargs)
35
  else:
36
  raise NotImplementedError(f"Teacher type {teacher_type} not implemented.")
37
 
38
- class UNetTeacher(pl.LightningModule):
 
 
 
 
39
  def __init__(self, **kwargs):
40
  super().__init__()
41
  self.name = None
@@ -56,9 +60,10 @@ class UNetTeacher(pl.LightningModule):
56
  # to be initialized, which will unnecessarily complicate the code.
57
  # noise: the initial noise for the first iteration.
58
  # t: the initial t. We will sample additional (num_denoising_steps - 1) smaller t.
59
- # uses_same_t: when sampling t, use the same t for all instances.
60
- def forward(self, ddpm_model, x_start, noise, t, teacher_context,
61
- num_denoising_steps=1, uses_same_t=False):
 
62
  assert num_denoising_steps <= 10
63
 
64
  if self.p_uses_cfg > 0:
@@ -71,27 +76,22 @@ class UNetTeacher(pl.LightningModule):
71
 
72
  if self.uses_cfg:
73
  print(f"Teacher samples CFG scale {self.cfg_scale:.1f}.")
 
 
 
 
 
 
74
  else:
75
  self.cfg_scale = 1
76
  print("Teacher does not use CFG.")
77
 
78
- # If p_uses_cfg > 0, we always pass both pos_context and neg_context to the teacher.
79
- # But the neg_context is only used when self.uses_cfg is True and cfg_scale > 1.
80
- # So we manually split the teacher_context into pos_context and neg_context, and only keep pos_context.
81
- if self.name == 'unet_ensemble':
82
- teacher_pos_contexts = []
83
- # teacher_context is a list of teacher contexts.
84
- for teacher_context_i in teacher_context:
85
- pos_context, neg_context = torch.chunk(teacher_context_i, 2, dim=0)
86
- if pos_context.shape[0] != x_start.shape[0]:
87
- breakpoint()
88
- teacher_pos_contexts.append(pos_context)
89
- teacher_context = teacher_pos_contexts
90
- else:
91
- pos_context, neg_context = torch.chunk(teacher_context, 2, dim=0)
92
- if pos_context.shape[0] != x_start.shape[0]:
93
- breakpoint()
94
- teacher_context = pos_context
95
  else:
96
  # p_uses_cfg = 0. Never use CFG.
97
  self.uses_cfg = False
@@ -102,15 +102,21 @@ class UNetTeacher(pl.LightningModule):
102
  # in case someday we want to switch from CFG to non-CFG during runtime.
103
  self.cfg_scale = 1
104
 
 
105
  if self.name == 'unet_ensemble':
106
  # teacher_context is a list of teacher contexts.
107
  for teacher_context_i in teacher_context:
108
- if teacher_context_i.shape[0] != x_start.shape[0] * (1 + self.uses_cfg):
109
  breakpoint()
110
  else:
111
- if teacher_context.shape[0] != x_start.shape[0] * (1 + self.uses_cfg):
112
  breakpoint()
113
-
 
 
 
 
 
114
  # Initially, x_starts only contains the original x_start.
115
  x_starts = [ x_start ]
116
  noises = [ noise ]
@@ -125,24 +131,35 @@ class UNetTeacher(pl.LightningModule):
125
  # sqrt_alphas_cumprod[t] * x_start + sqrt_one_minus_alphas_cumprod[t] * noise
126
  x_noisy = ddpm_model.q_sample(x_start, t, noise)
127
 
128
- if self.uses_cfg:
129
  x_noisy2 = x_noisy.repeat(2, 1, 1, 1)
130
  t2 = t.repeat(2)
131
  else:
132
  x_noisy2 = x_noisy
133
- t2 = t
134
 
135
  # If do_arc2face_distill, then pos_context is [BS=6, 21, 768].
136
  noise_pred = self.unet(sample=x_noisy2, timestep=t2, encoder_hidden_states=teacher_context,
137
  return_dict=False)[0]
138
  if self.uses_cfg and self.cfg_scale > 1:
139
- pos_noise_pred, neg_noise_pred = torch.chunk(noise_pred, 2, dim=0)
 
 
 
 
 
 
 
 
 
 
 
 
140
  noise_pred = pos_noise_pred * self.cfg_scale - neg_noise_pred * (self.cfg_scale - 1)
141
 
142
- # sqrt_recip_alphas_cumprod[t] * x_t - sqrt_recipm1_alphas_cumprod[t] * noise
143
- pred_x0 = ddpm_model.predict_start_from_noise(x_noisy, t, noise_pred)
144
  noise_preds.append(noise_pred)
145
-
 
146
  # The predicted x0 is used as the x_start for the next denoising step.
147
  x_starts.append(pred_x0)
148
 
@@ -157,20 +174,43 @@ class UNetTeacher(pl.LightningModule):
157
  # of the current timestep.
158
  t_lb = t * np.power(0.5, np.power(num_denoising_steps - 1, -0.3))
159
  t_ub = t * np.power(0.7, np.power(num_denoising_steps - 1, -0.3))
 
 
160
  earlier_timesteps = (t_ub - t_lb) * relative_ts + t_lb
161
  earlier_timesteps = earlier_timesteps.long()
 
162
 
163
- if uses_same_t:
164
- # If uses_same_t, we use the same earlier_timesteps for all instances.
165
  earlier_timesteps = earlier_timesteps[0].repeat(x_start.shape[0])
 
166
 
167
  # earlier_timesteps = ts[i+1] < ts[i].
168
  ts.append(earlier_timesteps)
169
-
170
- noise = torch.randn_like(pred_x0)
171
  noises.append(noise)
172
 
173
  return noise_preds, x_starts, noises, ts
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
174
 
175
  class Arc2FaceTeacher(UNetTeacher):
176
  def __init__(self, **kwargs):
@@ -185,11 +225,11 @@ class Arc2FaceTeacher(UNetTeacher):
185
  self.cfg_scale_range = [1, 1]
186
 
187
  class UNetEnsembleTeacher(UNetTeacher):
188
- # unet_weights are not model weights, but scalar weights for individual unets.
189
- def __init__(self, unets, unet_types, extra_unet_dirpaths, unet_weights=None, device='cuda', **kwargs):
190
  super().__init__(**kwargs)
191
  self.name = "unet_ensemble"
192
- self.unet = UNetEnsemble(unets, unet_types, extra_unet_dirpaths, unet_weights, device)
193
 
194
  class ConsistentIDTeacher(UNetTeacher):
195
  def __init__(self, base_model_path="models/sd15-dste8-vae.safetensors", **kwargs):
@@ -199,12 +239,9 @@ class ConsistentIDTeacher(UNetTeacher):
199
  # In contrast to Arc2FaceTeacher or UNetEnsembleTeacher, ConsistentIDPipeline is not a torch.nn.Module.
200
  # We couldn't initialize the ConsistentIDPipeline to CPU first and wait for it to be automatically moved to GPU.
201
  # Instead, we have to initialize it to GPU directly.
202
- pipe = create_consistentid_pipeline(base_model_path)
203
- # Compatible with the UNetTeacher interface.
204
- self.unet = pipe.unet
205
- # Release VAE and text_encoder to save memory. UNet is still needed for denoising
206
  # (the unet is implemented in diffusers in fp16, so probably faster than the LDM unet).
207
- pipe.release_components(["vae", "text_encoder"])
208
 
209
  # We use the default cfg_scale_range=[1.3, 2] for SimpleUNetTeacher.
210
  # Note p_uses_cfg=0.5 will also be passed in via kwargs.
 
1
  import torch
2
+ from torch import nn
3
  import numpy as np
 
4
  from diffusers import UNet2DConditionModel
5
  from adaface.util import UNetEnsemble, create_consistentid_pipeline
6
  from diffusers import UNet2DConditionModel
 
12
  teacher_type = teacher_type[0]
13
 
14
  if teacher_type == "arc2face":
15
+ teacher = Arc2FaceTeacher(**kwargs)
16
  elif teacher_type == "unet_ensemble":
17
+ # unet, extra_unet_dirpaths and unet_weights_in_ensemble are passed in kwargs.
18
  # Even if we distill from unet_ensemble, we still need to load arc2face for generating
19
  # arc2face embeddings.
20
  # The first (optional) ctor param of UNetEnsembleTeacher is an instantiated unet,
 
22
  # However, since the __call__ method of the ddpm unet takes different formats of params,
23
  # for simplicity, we still use the diffusers unet.
24
  # unet_teacher is put on CPU first, then moved to GPU when DDPM is moved to GPU.
25
+ teacher = UNetEnsembleTeacher(device=device, **kwargs)
26
  elif teacher_type == "consistentID":
27
+ teacher = ConsistentIDTeacher(**kwargs)
28
  elif teacher_type == "simple_unet":
29
+ teacher = SimpleUNetTeacher(**kwargs)
30
  # Since we've dereferenced the list if it has only one element,
31
  # this holding implies the list has more than one element. Therefore it's UNetEnsembleTeacher.
32
  elif isinstance(teacher_type, (tuple, list, ListConfig)):
33
  # teacher_type is a list of teacher types. So it's UNetEnsembleTeacher.
34
+ teacher = UNetEnsembleTeacher(unet_types=teacher_type, device=device, **kwargs)
35
  else:
36
  raise NotImplementedError(f"Teacher type {teacher_type} not implemented.")
37
 
38
+ for param in teacher.parameters():
39
+ param.requires_grad = False
40
+ return teacher
41
+
42
+ class UNetTeacher(nn.Module):
43
  def __init__(self, **kwargs):
44
  super().__init__()
45
  self.name = None
 
60
  # to be initialized, which will unnecessarily complicate the code.
61
  # noise: the initial noise for the first iteration.
62
  # t: the initial t. We will sample additional (num_denoising_steps - 1) smaller t.
63
+ # same_t_noise_across_instances: when sampling t and noise, use the same t and noise for all instances.
64
+ def forward(self, ddpm_model, x_start, noise, t, teacher_context, negative_context=None,
65
+ num_denoising_steps=1, same_t_noise_across_instances=False,
66
+ global_t_lb=0, global_t_ub=1000):
67
  assert num_denoising_steps <= 10
68
 
69
  if self.p_uses_cfg > 0:
 
76
 
77
  if self.uses_cfg:
78
  print(f"Teacher samples CFG scale {self.cfg_scale:.1f}.")
79
+ if negative_context is not None:
80
+ negative_context = negative_context[:1].repeat(x_start.shape[0], 1, 1)
81
+
82
+ # if negative_context is None, then teacher_context is a combination of
83
+ # (one or multiple if unet_ensemble) pos_context and neg_context.
84
+ # If negative_context is not None, then teacher_context is only pos_context.
85
  else:
86
  self.cfg_scale = 1
87
  print("Teacher does not use CFG.")
88
 
89
+ # If negative_context is None, then teacher_context is a combination of
90
+ # (one or multiple if unet_ensemble) pos_context and neg_context.
91
+ # Since not uses_cfg, we only need pos_context.
92
+ # If negative_context is not None, then teacher_context is only pos_context.
93
+ if negative_context is None:
94
+ teacher_context = self.extract_pos_context(teacher_context, x_start.shape[0])
 
 
 
 
 
 
 
 
 
 
 
95
  else:
96
  # p_uses_cfg = 0. Never use CFG.
97
  self.uses_cfg = False
 
102
  # in case someday we want to switch from CFG to non-CFG during runtime.
103
  self.cfg_scale = 1
104
 
105
+ is_context_doubled = 2 if (self.uses_cfg and negative_context is None) else 1
106
  if self.name == 'unet_ensemble':
107
  # teacher_context is a list of teacher contexts.
108
  for teacher_context_i in teacher_context:
109
+ if teacher_context_i.shape[0] != x_start.shape[0] * is_context_doubled:
110
  breakpoint()
111
  else:
112
+ if teacher_context.shape[0] != x_start.shape[0] * is_context_doubled:
113
  breakpoint()
114
+
115
+ if same_t_noise_across_instances:
116
+ # If same_t_noise_across_instances, we use the same t and noise for all instances.
117
+ t = t[0].repeat(x_start.shape[0])
118
+ noise = noise[:1].repeat(x_start.shape[0], 1, 1, 1)
119
+
120
  # Initially, x_starts only contains the original x_start.
121
  x_starts = [ x_start ]
122
  noises = [ noise ]
 
131
  # sqrt_alphas_cumprod[t] * x_start + sqrt_one_minus_alphas_cumprod[t] * noise
132
  x_noisy = ddpm_model.q_sample(x_start, t, noise)
133
 
134
+ if self.uses_cfg and self.cfg_scale > 1 and negative_context is None:
135
  x_noisy2 = x_noisy.repeat(2, 1, 1, 1)
136
  t2 = t.repeat(2)
137
  else:
138
  x_noisy2 = x_noisy
139
+ t2 = t
140
 
141
  # If do_arc2face_distill, then pos_context is [BS=6, 21, 768].
142
  noise_pred = self.unet(sample=x_noisy2, timestep=t2, encoder_hidden_states=teacher_context,
143
  return_dict=False)[0]
144
  if self.uses_cfg and self.cfg_scale > 1:
145
+ if negative_context is None:
146
+ pos_noise_pred, neg_noise_pred = torch.chunk(noise_pred, 2, dim=0)
147
+ else:
148
+ # If negative_context is not None, then teacher_context is only pos_context.
149
+ pos_noise_pred = noise_pred
150
+ with torch.no_grad():
151
+ if self.name == 'unet_ensemble':
152
+ neg_noise_pred = self.unet.unets[0](sample=x_noisy, timestep=t,
153
+ encoder_hidden_states=negative_context, return_dict=False)[0]
154
+ else:
155
+ neg_noise_pred = self.unet(sample=x_noisy, timestep=t,
156
+ encoder_hidden_states=negative_context, return_dict=False)[0]
157
+
158
  noise_pred = pos_noise_pred * self.cfg_scale - neg_noise_pred * (self.cfg_scale - 1)
159
 
 
 
160
  noise_preds.append(noise_pred)
161
+ # sqrt_recip_alphas_cumprod[t] * x_t - sqrt_recipm1_alphas_cumprod[t] * noise
162
+ pred_x0 = ddpm_model.predict_start_from_noise(x_noisy, t, noise_pred)
163
  # The predicted x0 is used as the x_start for the next denoising step.
164
  x_starts.append(pred_x0)
165
 
 
174
  # of the current timestep.
175
  t_lb = t * np.power(0.5, np.power(num_denoising_steps - 1, -0.3))
176
  t_ub = t * np.power(0.7, np.power(num_denoising_steps - 1, -0.3))
177
+ t_lb = torch.clamp(t_lb, min=global_t_lb)
178
+ t_ub = torch.clamp(t_ub, max=global_t_ub)
179
  earlier_timesteps = (t_ub - t_lb) * relative_ts + t_lb
180
  earlier_timesteps = earlier_timesteps.long()
181
+ noise = torch.randn_like(pred_x0)
182
 
183
+ if same_t_noise_across_instances:
184
+ # If same_t_noise_across_instances, we use the same earlier_timesteps and noise for all instances.
185
  earlier_timesteps = earlier_timesteps[0].repeat(x_start.shape[0])
186
+ noise = noise[:1].repeat(x_start.shape[0], 1, 1, 1)
187
 
188
  # earlier_timesteps = ts[i+1] < ts[i].
189
  ts.append(earlier_timesteps)
 
 
190
  noises.append(noise)
191
 
192
  return noise_preds, x_starts, noises, ts
193
+
194
+ def extract_pos_context(self, teacher_context, BS):
195
+ # If p_uses_cfg > 0, we always pass both pos_context and neg_context to the teacher.
196
+ # But the neg_context is only used when self.uses_cfg is True and cfg_scale > 1.
197
+ # So we manually split the teacher_context into pos_context and neg_context, and only keep pos_context.
198
+ if self.name == 'unet_ensemble':
199
+ teacher_pos_contexts = []
200
+ # teacher_context is a list of teacher contexts.
201
+ for teacher_context_i in teacher_context:
202
+ pos_context, neg_context = torch.chunk(teacher_context_i, 2, dim=0)
203
+ if pos_context.shape[0] != BS:
204
+ breakpoint()
205
+ teacher_pos_contexts.append(pos_context)
206
+ teacher_context = teacher_pos_contexts
207
+ else:
208
+ pos_context, neg_context = torch.chunk(teacher_context, 2, dim=0)
209
+ if pos_context.shape[0] != BS:
210
+ breakpoint()
211
+ teacher_context = pos_context
212
+
213
+ return teacher_context
214
 
215
  class Arc2FaceTeacher(UNetTeacher):
216
  def __init__(self, **kwargs):
 
225
  self.cfg_scale_range = [1, 1]
226
 
227
  class UNetEnsembleTeacher(UNetTeacher):
228
+ # unet_weights_in_ensemble are not model weights, but scalar weights for individual unets.
229
+ def __init__(self, unets, unet_types, extra_unet_dirpaths, unet_weights_in_ensemble=None, device='cuda', **kwargs):
230
  super().__init__(**kwargs)
231
  self.name = "unet_ensemble"
232
+ self.unet = UNetEnsemble(unets, unet_types, extra_unet_dirpaths, unet_weights_in_ensemble, device)
233
 
234
  class ConsistentIDTeacher(UNetTeacher):
235
  def __init__(self, base_model_path="models/sd15-dste8-vae.safetensors", **kwargs):
 
239
  # In contrast to Arc2FaceTeacher or UNetEnsembleTeacher, ConsistentIDPipeline is not a torch.nn.Module.
240
  # We couldn't initialize the ConsistentIDPipeline to CPU first and wait for it to be automatically moved to GPU.
241
  # Instead, we have to initialize it to GPU directly.
242
+ # Release VAE and text_encoder to save memory. UNet is needed for denoising
 
 
 
243
  # (the unet is implemented in diffusers in fp16, so probably faster than the LDM unet).
244
+ self.unet = create_consistentid_pipeline(base_model_path, unet_only=True)
245
 
246
  # We use the default cfg_scale_range=[1.3, 2] for SimpleUNetTeacher.
247
  # Note p_uses_cfg=0.5 will also be passed in via kwargs.
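Two techniques recur in the `unet_teachers.py` hunks above: classifier-free guidance on the teacher's noise prediction, and sampling a strictly smaller timestep for each subsequent denoising step from a shrinking `[t_lb, t_ub]` window. A standalone sketch of both follows; the constants 0.5, 0.7 and the exponent -0.3 are copied from the diff, while the function names and the surrounding plumbing are assumptions for demonstration only.

```python
import numpy as np
import torch

def cfg_combine(pos_noise_pred, neg_noise_pred, cfg_scale=2.0):
    # Standard CFG: amplify the positive branch, subtract the negative one.
    return pos_noise_pred * cfg_scale - neg_noise_pred * (cfg_scale - 1)

def sample_earlier_timesteps(t, num_denoising_steps, global_t_lb=0, global_t_ub=1000):
    # t: [BS] long tensor of current timesteps. Only meaningful for num_denoising_steps > 1.
    # The window shrinks toward 0 as num_denoising_steps (K) grows:
    #   t_lb = t * 0.5**((K-1)**-0.3),  t_ub = t * 0.7**((K-1)**-0.3).
    t_lb = torch.clamp(t * np.power(0.5, np.power(num_denoising_steps - 1, -0.3)), min=global_t_lb)
    t_ub = torch.clamp(t * np.power(0.7, np.power(num_denoising_steps - 1, -0.3)), max=global_t_ub)
    relative_ts = torch.rand_like(t.float())              # uniform in [0, 1)
    return ((t_ub - t_lb) * relative_ts + t_lb).long()    # earlier (smaller) timesteps
```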
adaface/util.py CHANGED
@@ -57,7 +57,7 @@ def perturb_np_array(np_array, perturb_std, perturb_std_is_relative=True, std_di
57
  ts = perturb_tensor(ts, perturb_std, perturb_std_is_relative, std_dim=std_dim)
58
  return ts.numpy().astype(np_array.dtype)
59
 
60
- def calc_stats(emb_name, embeddings, mean_dim=0):
61
  print("%s:" %emb_name)
62
  repeat_count = [1] * embeddings.ndim
63
  repeat_count[mean_dim] = embeddings.shape[mean_dim]
@@ -153,13 +153,14 @@ def pad_image_obj_to_square(image_obj, new_size=-1):
153
 
154
  class UNetEnsemble(nn.Module):
155
  # The first unet is the unet already loaded in a pipeline.
156
- def __init__(self, unets, unet_types, extra_unet_dirpaths, unet_weights=None, device='cuda', torch_dtype=torch.float16):
157
  super().__init__()
158
 
159
- self.unets = nn.ModuleList()
160
  if unets is not None:
161
- self.unets += [ unet.to(device) for unet in unets ]
162
-
 
 
163
  if unet_types is not None:
164
  for unet_type in unet_types:
165
  if unet_type == "arc2face":
@@ -169,25 +170,27 @@ class UNetEnsemble(nn.Module):
169
  unet = create_consistentid_pipeline(unet_only=True)
170
  else:
171
  breakpoint()
172
- self.unets.append(unet.to(device=device))
173
 
174
  if extra_unet_dirpaths is not None:
175
  for unet_path in extra_unet_dirpaths:
176
  unet = UNet2DConditionModel.from_pretrained(unet_path, torch_dtype=torch_dtype)
177
- self.unets.append(unet.to(device=device))
178
 
179
- if unet_weights is None:
180
- unet_weights = [1.] * len(self.unets)
181
- elif len(self.unets) < len(unet_weights):
182
- unet_weights = unet_weights[:len(self.unets)]
183
- elif len(self.unets) > len(unet_weights):
184
  breakpoint()
185
 
186
- unet_weights = torch.tensor(unet_weights, dtype=torch_dtype)
187
- unet_weights = unet_weights / unet_weights.sum()
188
- self.unet_weights = nn.Parameter(unet_weights, requires_grad=False)
189
 
190
- print(f"UNetEnsemble: {len(self.unets)} UNets loaded with weights: {self.unet_weights.data.cpu().numpy()}")
 
 
 
191
  # Set these fields to be compatible with diffusers.
192
  self.dtype = self.unets[0].dtype
193
  self.device = self.unets[0].device
@@ -215,8 +218,8 @@ class UNetEnsemble(nn.Module):
215
  samples.append(sample)
216
 
217
  samples = torch.stack(samples, dim=0)
218
- unet_weights = self.unet_weights.reshape(-1, *([1] * (samples.ndim - 1)))
219
- sample = (samples * unet_weights).sum(dim=0)
220
 
221
  if not return_dict:
222
  return (sample,)
 
57
  ts = perturb_tensor(ts, perturb_std, perturb_std_is_relative, std_dim=std_dim)
58
  return ts.numpy().astype(np_array.dtype)
59
 
60
+ def calc_stats(emb_name, embeddings, mean_dim=-1):
61
  print("%s:" %emb_name)
62
  repeat_count = [1] * embeddings.ndim
63
  repeat_count[mean_dim] = embeddings.shape[mean_dim]
 
153
 
154
  class UNetEnsemble(nn.Module):
155
  # The first unet is the unet already loaded in a pipeline.
156
+ def __init__(self, unets, unet_types, extra_unet_dirpaths, unet_weights_in_ensemble=None, device='cuda', torch_dtype=torch.float16):
157
  super().__init__()
158
 
 
159
  if unets is not None:
160
+ unets = [ unet.to(device) for unet in unets ]
161
+ else:
162
+ unets = []
163
+
164
  if unet_types is not None:
165
  for unet_type in unet_types:
166
  if unet_type == "arc2face":
 
170
  unet = create_consistentid_pipeline(unet_only=True)
171
  else:
172
  breakpoint()
173
+ unets.append(unet.to(device=device))
174
 
175
  if extra_unet_dirpaths is not None:
176
  for unet_path in extra_unet_dirpaths:
177
  unet = UNet2DConditionModel.from_pretrained(unet_path, torch_dtype=torch_dtype)
178
+ unets.append(unet.to(device=device))
179
 
180
+ if unet_weights_in_ensemble is None:
181
+ unet_weights_in_ensemble = [1.] * len(unets)
182
+ elif len(unets) < len(unet_weights_in_ensemble):
183
+ unet_weights_in_ensemble = unet_weights_in_ensemble[:len(unets)]
184
+ elif len(unets) > len(unet_weights_in_ensemble):
185
  breakpoint()
186
 
187
+ unet_weights_in_ensemble = torch.tensor(unet_weights_in_ensemble, dtype=torch_dtype)
188
+ unet_weights_in_ensemble = unet_weights_in_ensemble / unet_weights_in_ensemble.sum()
 
189
 
190
+ self.unets = nn.ModuleList(unets)
191
+ # Put the weights in a Parameter so that they will be moved to the same device as the model.
192
+ self.unet_weights_in_ensemble = nn.Parameter(unet_weights_in_ensemble, requires_grad=False)
193
+ print(f"UNetEnsemble: {len(self.unets)} UNets loaded with weights: {self.unet_weights_in_ensemble.data.cpu().numpy()}")
194
  # Set these fields to be compatible with diffusers.
195
  self.dtype = self.unets[0].dtype
196
  self.device = self.unets[0].device
 
218
  samples.append(sample)
219
 
220
  samples = torch.stack(samples, dim=0)
221
+ unet_weights_in_ensemble = self.unet_weights_in_ensemble.reshape(-1, *([1] * (samples.ndim - 1)))
222
+ sample = (samples * unet_weights_in_ensemble).sum(dim=0)
223
 
224
  if not return_dict:
225
  return (sample,)
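`UNetEnsemble.forward` above averages the member UNets' noise predictions with the normalized `unet_weights_in_ensemble`. A minimal sketch of that combination step; the helper name is illustrative, while the call signature matches the diffusers `UNet2DConditionModel` interface used in the diff.

```python
import torch

def ensemble_noise_pred(unets, weights, sample, timestep, encoder_hidden_states):
    # weights: 1-D tensor with one scalar per UNet; normalize so they sum to 1.
    weights = weights / weights.sum()
    preds = [unet(sample=sample, timestep=timestep,
                  encoder_hidden_states=encoder_hidden_states,
                  return_dict=False)[0]
             for unet in unets]
    preds = torch.stack(preds, dim=0)                          # [N, BS, C, H, W]
    weights = weights.reshape(-1, *([1] * (preds.ndim - 1)))   # [N, 1, 1, 1, 1]
    return (preds * weights).sum(dim=0)                        # weighted average
```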
animatediff/sd/.gitattributes DELETED
@@ -1,35 +0,0 @@
1
- *.7z filter=lfs diff=lfs merge=lfs -text
2
- *.arrow filter=lfs diff=lfs merge=lfs -text
3
- *.bin filter=lfs diff=lfs merge=lfs -text
4
- *.bz2 filter=lfs diff=lfs merge=lfs -text
5
- *.ftz filter=lfs diff=lfs merge=lfs -text
6
- *.gz filter=lfs diff=lfs merge=lfs -text
7
- *.h5 filter=lfs diff=lfs merge=lfs -text
8
- *.joblib filter=lfs diff=lfs merge=lfs -text
9
- *.lfs.* filter=lfs diff=lfs merge=lfs -text
10
- *.mlmodel filter=lfs diff=lfs merge=lfs -text
11
- *.model filter=lfs diff=lfs merge=lfs -text
12
- *.msgpack filter=lfs diff=lfs merge=lfs -text
13
- *.npy filter=lfs diff=lfs merge=lfs -text
14
- *.npz filter=lfs diff=lfs merge=lfs -text
15
- *.onnx filter=lfs diff=lfs merge=lfs -text
16
- *.ot filter=lfs diff=lfs merge=lfs -text
17
- *.parquet filter=lfs diff=lfs merge=lfs -text
18
- *.pb filter=lfs diff=lfs merge=lfs -text
19
- *.pickle filter=lfs diff=lfs merge=lfs -text
20
- *.pkl filter=lfs diff=lfs merge=lfs -text
21
- *.pt filter=lfs diff=lfs merge=lfs -text
22
- *.pth filter=lfs diff=lfs merge=lfs -text
23
- *.rar filter=lfs diff=lfs merge=lfs -text
24
- *.safetensors filter=lfs diff=lfs merge=lfs -text
25
- saved_model/**/* filter=lfs diff=lfs merge=lfs -text
26
- *.tar.* filter=lfs diff=lfs merge=lfs -text
27
- *.tflite filter=lfs diff=lfs merge=lfs -text
28
- *.tgz filter=lfs diff=lfs merge=lfs -text
29
- *.wasm filter=lfs diff=lfs merge=lfs -text
30
- *.xz filter=lfs diff=lfs merge=lfs -text
31
- *.zip filter=lfs diff=lfs merge=lfs -text
32
- *.zst filter=lfs diff=lfs merge=lfs -text
33
- *tfevents* filter=lfs diff=lfs merge=lfs -text
34
- v1-5-pruned-emaonly.ckpt filter=lfs diff=lfs merge=lfs -text
35
- v1-5-pruned.ckpt filter=lfs diff=lfs merge=lfs -text
animatediff/sd/feature_extractor/preprocessor_config.json DELETED
@@ -1,20 +0,0 @@
1
- {
2
- "crop_size": 224,
3
- "do_center_crop": true,
4
- "do_convert_rgb": true,
5
- "do_normalize": true,
6
- "do_resize": true,
7
- "feature_extractor_type": "CLIPFeatureExtractor",
8
- "image_mean": [
9
- 0.48145466,
10
- 0.4578275,
11
- 0.40821073
12
- ],
13
- "image_std": [
14
- 0.26862954,
15
- 0.26130258,
16
- 0.27577711
17
- ],
18
- "resample": 3,
19
- "size": 224
20
- }
animatediff/sd/model_index.json DELETED
@@ -1,32 +0,0 @@
1
- {
2
- "_class_name": "StableDiffusionPipeline",
3
- "_diffusers_version": "0.6.0",
4
- "feature_extractor": [
5
- "transformers",
6
- "CLIPImageProcessor"
7
- ],
8
- "safety_checker": [
9
- "stable_diffusion",
10
- "StableDiffusionSafetyChecker"
11
- ],
12
- "scheduler": [
13
- "diffusers",
14
- "PNDMScheduler"
15
- ],
16
- "text_encoder": [
17
- "transformers",
18
- "CLIPTextModel"
19
- ],
20
- "tokenizer": [
21
- "transformers",
22
- "CLIPTokenizer"
23
- ],
24
- "unet": [
25
- "diffusers",
26
- "UNet2DConditionModel"
27
- ],
28
- "vae": [
29
- "diffusers",
30
- "AutoencoderKL"
31
- ]
32
- }
animatediff/sd/safety_checker/config.json DELETED
@@ -1,175 +0,0 @@
1
- {
2
- "_commit_hash": "4bb648a606ef040e7685bde262611766a5fdd67b",
3
- "_name_or_path": "CompVis/stable-diffusion-safety-checker",
4
- "architectures": [
5
- "StableDiffusionSafetyChecker"
6
- ],
7
- "initializer_factor": 1.0,
8
- "logit_scale_init_value": 2.6592,
9
- "model_type": "clip",
10
- "projection_dim": 768,
11
- "text_config": {
12
- "_name_or_path": "",
13
- "add_cross_attention": false,
14
- "architectures": null,
15
- "attention_dropout": 0.0,
16
- "bad_words_ids": null,
17
- "bos_token_id": 0,
18
- "chunk_size_feed_forward": 0,
19
- "cross_attention_hidden_size": null,
20
- "decoder_start_token_id": null,
21
- "diversity_penalty": 0.0,
22
- "do_sample": false,
23
- "dropout": 0.0,
24
- "early_stopping": false,
25
- "encoder_no_repeat_ngram_size": 0,
26
- "eos_token_id": 2,
27
- "exponential_decay_length_penalty": null,
28
- "finetuning_task": null,
29
- "forced_bos_token_id": null,
30
- "forced_eos_token_id": null,
31
- "hidden_act": "quick_gelu",
32
- "hidden_size": 768,
33
- "id2label": {
34
- "0": "LABEL_0",
35
- "1": "LABEL_1"
36
- },
37
- "initializer_factor": 1.0,
38
- "initializer_range": 0.02,
39
- "intermediate_size": 3072,
40
- "is_decoder": false,
41
- "is_encoder_decoder": false,
42
- "label2id": {
43
- "LABEL_0": 0,
44
- "LABEL_1": 1
45
- },
46
- "layer_norm_eps": 1e-05,
47
- "length_penalty": 1.0,
48
- "max_length": 20,
49
- "max_position_embeddings": 77,
50
- "min_length": 0,
51
- "model_type": "clip_text_model",
52
- "no_repeat_ngram_size": 0,
53
- "num_attention_heads": 12,
54
- "num_beam_groups": 1,
55
- "num_beams": 1,
56
- "num_hidden_layers": 12,
57
- "num_return_sequences": 1,
58
- "output_attentions": false,
59
- "output_hidden_states": false,
60
- "output_scores": false,
61
- "pad_token_id": 1,
62
- "prefix": null,
63
- "problem_type": null,
64
- "pruned_heads": {},
65
- "remove_invalid_values": false,
66
- "repetition_penalty": 1.0,
67
- "return_dict": true,
68
- "return_dict_in_generate": false,
69
- "sep_token_id": null,
70
- "task_specific_params": null,
71
- "temperature": 1.0,
72
- "tf_legacy_loss": false,
73
- "tie_encoder_decoder": false,
74
- "tie_word_embeddings": true,
75
- "tokenizer_class": null,
76
- "top_k": 50,
77
- "top_p": 1.0,
78
- "torch_dtype": null,
79
- "torchscript": false,
80
- "transformers_version": "4.22.0.dev0",
81
- "typical_p": 1.0,
82
- "use_bfloat16": false,
83
- "vocab_size": 49408
84
- },
85
- "text_config_dict": {
86
- "hidden_size": 768,
87
- "intermediate_size": 3072,
88
- "num_attention_heads": 12,
89
- "num_hidden_layers": 12
90
- },
91
- "torch_dtype": "float32",
92
- "transformers_version": null,
93
- "vision_config": {
94
- "_name_or_path": "",
95
- "add_cross_attention": false,
96
- "architectures": null,
97
- "attention_dropout": 0.0,
98
- "bad_words_ids": null,
99
- "bos_token_id": null,
100
- "chunk_size_feed_forward": 0,
101
- "cross_attention_hidden_size": null,
102
- "decoder_start_token_id": null,
103
- "diversity_penalty": 0.0,
104
- "do_sample": false,
105
- "dropout": 0.0,
106
- "early_stopping": false,
107
- "encoder_no_repeat_ngram_size": 0,
108
- "eos_token_id": null,
109
- "exponential_decay_length_penalty": null,
110
- "finetuning_task": null,
111
- "forced_bos_token_id": null,
112
- "forced_eos_token_id": null,
113
- "hidden_act": "quick_gelu",
114
- "hidden_size": 1024,
115
- "id2label": {
116
- "0": "LABEL_0",
117
- "1": "LABEL_1"
118
- },
119
- "image_size": 224,
120
- "initializer_factor": 1.0,
121
- "initializer_range": 0.02,
122
- "intermediate_size": 4096,
123
- "is_decoder": false,
124
- "is_encoder_decoder": false,
125
- "label2id": {
126
- "LABEL_0": 0,
127
- "LABEL_1": 1
128
- },
129
- "layer_norm_eps": 1e-05,
130
- "length_penalty": 1.0,
131
- "max_length": 20,
132
- "min_length": 0,
133
- "model_type": "clip_vision_model",
134
- "no_repeat_ngram_size": 0,
135
- "num_attention_heads": 16,
136
- "num_beam_groups": 1,
137
- "num_beams": 1,
138
- "num_channels": 3,
139
- "num_hidden_layers": 24,
140
- "num_return_sequences": 1,
141
- "output_attentions": false,
142
- "output_hidden_states": false,
143
- "output_scores": false,
144
- "pad_token_id": null,
145
- "patch_size": 14,
146
- "prefix": null,
147
- "problem_type": null,
148
- "pruned_heads": {},
149
- "remove_invalid_values": false,
150
- "repetition_penalty": 1.0,
151
- "return_dict": true,
152
- "return_dict_in_generate": false,
153
- "sep_token_id": null,
154
- "task_specific_params": null,
155
- "temperature": 1.0,
156
- "tf_legacy_loss": false,
157
- "tie_encoder_decoder": false,
158
- "tie_word_embeddings": true,
159
- "tokenizer_class": null,
160
- "top_k": 50,
161
- "top_p": 1.0,
162
- "torch_dtype": null,
163
- "torchscript": false,
164
- "transformers_version": "4.22.0.dev0",
165
- "typical_p": 1.0,
166
- "use_bfloat16": false
167
- },
168
- "vision_config_dict": {
169
- "hidden_size": 1024,
170
- "intermediate_size": 4096,
171
- "num_attention_heads": 16,
172
- "num_hidden_layers": 24,
173
- "patch_size": 14
174
- }
175
- }
animatediff/sd/scheduler/scheduler_config.json DELETED
@@ -1,13 +0,0 @@
1
- {
2
- "_class_name": "PNDMScheduler",
3
- "_diffusers_version": "0.6.0",
4
- "beta_end": 0.012,
5
- "beta_schedule": "scaled_linear",
6
- "beta_start": 0.00085,
7
- "num_train_timesteps": 1000,
8
- "set_alpha_to_one": false,
9
- "skip_prk_steps": true,
10
- "steps_offset": 1,
11
- "trained_betas": null,
12
- "clip_sample": false
13
- }
animatediff/sd/text_encoder/config.json DELETED
@@ -1,25 +0,0 @@
1
- {
2
- "_name_or_path": "openai/clip-vit-large-patch14",
3
- "architectures": [
4
- "CLIPTextModel"
5
- ],
6
- "attention_dropout": 0.0,
7
- "bos_token_id": 0,
8
- "dropout": 0.0,
9
- "eos_token_id": 2,
10
- "hidden_act": "quick_gelu",
11
- "hidden_size": 768,
12
- "initializer_factor": 1.0,
13
- "initializer_range": 0.02,
14
- "intermediate_size": 3072,
15
- "layer_norm_eps": 1e-05,
16
- "max_position_embeddings": 77,
17
- "model_type": "clip_text_model",
18
- "num_attention_heads": 12,
19
- "num_hidden_layers": 12,
20
- "pad_token_id": 1,
21
- "projection_dim": 768,
22
- "torch_dtype": "float32",
23
- "transformers_version": "4.22.0.dev0",
24
- "vocab_size": 49408
25
- }
animatediff/sd/tokenizer/merges.txt DELETED
The diff for this file is too large to render. See raw diff
 
animatediff/sd/tokenizer/special_tokens_map.json DELETED
@@ -1,24 +0,0 @@
1
- {
2
- "bos_token": {
3
- "content": "<|startoftext|>",
4
- "lstrip": false,
5
- "normalized": true,
6
- "rstrip": false,
7
- "single_word": false
8
- },
9
- "eos_token": {
10
- "content": "<|endoftext|>",
11
- "lstrip": false,
12
- "normalized": true,
13
- "rstrip": false,
14
- "single_word": false
15
- },
16
- "pad_token": "<|endoftext|>",
17
- "unk_token": {
18
- "content": "<|endoftext|>",
19
- "lstrip": false,
20
- "normalized": true,
21
- "rstrip": false,
22
- "single_word": false
23
- }
24
- }
animatediff/sd/tokenizer/tokenizer_config.json DELETED
@@ -1,34 +0,0 @@
1
- {
2
- "add_prefix_space": false,
3
- "bos_token": {
4
- "__type": "AddedToken",
5
- "content": "<|startoftext|>",
6
- "lstrip": false,
7
- "normalized": true,
8
- "rstrip": false,
9
- "single_word": false
10
- },
11
- "do_lower_case": true,
12
- "eos_token": {
13
- "__type": "AddedToken",
14
- "content": "<|endoftext|>",
15
- "lstrip": false,
16
- "normalized": true,
17
- "rstrip": false,
18
- "single_word": false
19
- },
20
- "errors": "replace",
21
- "model_max_length": 77,
22
- "name_or_path": "openai/clip-vit-large-patch14",
23
- "pad_token": "<|endoftext|>",
24
- "special_tokens_map_file": "./special_tokens_map.json",
25
- "tokenizer_class": "CLIPTokenizer",
26
- "unk_token": {
27
- "__type": "AddedToken",
28
- "content": "<|endoftext|>",
29
- "lstrip": false,
30
- "normalized": true,
31
- "rstrip": false,
32
- "single_word": false
33
- }
34
- }
animatediff/sd/tokenizer/vocab.json DELETED
The diff for this file is too large to render. See raw diff
 
animatediff/sd/unet/config.json DELETED
@@ -1,36 +0,0 @@
1
- {
2
- "_class_name": "UNet2DConditionModel",
3
- "_diffusers_version": "0.6.0",
4
- "act_fn": "silu",
5
- "attention_head_dim": 8,
6
- "block_out_channels": [
7
- 320,
8
- 640,
9
- 1280,
10
- 1280
11
- ],
12
- "center_input_sample": false,
13
- "cross_attention_dim": 768,
14
- "down_block_types": [
15
- "CrossAttnDownBlock2D",
16
- "CrossAttnDownBlock2D",
17
- "CrossAttnDownBlock2D",
18
- "DownBlock2D"
19
- ],
20
- "downsample_padding": 1,
21
- "flip_sin_to_cos": true,
22
- "freq_shift": 0,
23
- "in_channels": 4,
24
- "layers_per_block": 2,
25
- "mid_block_scale_factor": 1,
26
- "norm_eps": 1e-05,
27
- "norm_num_groups": 32,
28
- "out_channels": 4,
29
- "sample_size": 64,
30
- "up_block_types": [
31
- "UpBlock2D",
32
- "CrossAttnUpBlock2D",
33
- "CrossAttnUpBlock2D",
34
- "CrossAttnUpBlock2D"
35
- ]
36
- }
animatediff/sd/v1-inference.yaml DELETED
@@ -1,70 +0,0 @@
1
- model:
2
- base_learning_rate: 1.0e-04
3
- target: ldm.models.diffusion.ddpm.LatentDiffusion
4
- params:
5
- linear_start: 0.00085
6
- linear_end: 0.0120
7
- num_timesteps_cond: 1
8
- log_every_t: 200
9
- timesteps: 1000
10
- first_stage_key: "jpg"
11
- cond_stage_key: "txt"
12
- image_size: 64
13
- channels: 4
14
- cond_stage_trainable: false # Note: different from the one we trained before
15
- conditioning_key: crossattn
16
- monitor: val/loss_simple_ema
17
- scale_factor: 0.18215
18
- use_ema: False
19
-
20
- scheduler_config: # 10000 warmup steps
21
- target: ldm.lr_scheduler.LambdaLinearScheduler
22
- params:
23
- warm_up_steps: [ 10000 ]
24
- cycle_lengths: [ 10000000000000 ] # incredibly large number to prevent corner cases
25
- f_start: [ 1.e-6 ]
26
- f_max: [ 1. ]
27
- f_min: [ 1. ]
28
-
29
- unet_config:
30
- target: ldm.modules.diffusionmodules.openaimodel.UNetModel
31
- params:
32
- image_size: 32 # unused
33
- in_channels: 4
34
- out_channels: 4
35
- model_channels: 320
36
- attention_resolutions: [ 4, 2, 1 ]
37
- num_res_blocks: 2
38
- channel_mult: [ 1, 2, 4, 4 ]
39
- num_heads: 8
40
- use_spatial_transformer: True
41
- transformer_depth: 1
42
- context_dim: 768
43
- use_checkpoint: True
44
- legacy: False
45
-
46
- first_stage_config:
47
- target: ldm.models.autoencoder.AutoencoderKL
48
- params:
49
- embed_dim: 4
50
- monitor: val/rec_loss
51
- ddconfig:
52
- double_z: true
53
- z_channels: 4
54
- resolution: 256
55
- in_channels: 3
56
- out_ch: 3
57
- ch: 128
58
- ch_mult:
59
- - 1
60
- - 2
61
- - 4
62
- - 4
63
- num_res_blocks: 2
64
- attn_resolutions: []
65
- dropout: 0.0
66
- lossconfig:
67
- target: torch.nn.Identity
68
-
69
- cond_stage_config:
70
- target: ldm.modules.encoders.modules.FrozenCLIPEmbedder
animatediff/sd/vae/config.json DELETED
@@ -1,29 +0,0 @@
1
- {
2
- "_class_name": "AutoencoderKL",
3
- "_diffusers_version": "0.6.0",
4
- "act_fn": "silu",
5
- "block_out_channels": [
6
- 128,
7
- 256,
8
- 512,
9
- 512
10
- ],
11
- "down_block_types": [
12
- "DownEncoderBlock2D",
13
- "DownEncoderBlock2D",
14
- "DownEncoderBlock2D",
15
- "DownEncoderBlock2D"
16
- ],
17
- "in_channels": 3,
18
- "latent_channels": 4,
19
- "layers_per_block": 2,
20
- "norm_num_groups": 32,
21
- "out_channels": 3,
22
- "sample_size": 512,
23
- "up_block_types": [
24
- "UpDecoderBlock2D",
25
- "UpDecoderBlock2D",
26
- "UpDecoderBlock2D",
27
- "UpDecoderBlock2D"
28
- ]
29
- }
animatediff/utils/convert_from_ckpt.py CHANGED
@@ -714,7 +714,7 @@ def convert_ldm_bert_checkpoint(checkpoint, config):
714
 
715
 
716
  def convert_ldm_clip_checkpoint(checkpoint, dtype=torch.float16):
717
- text_model = CLIPTextModel.from_pretrained("animatediff/sd/text_encoder", torch_dtype=dtype)
718
  keys = list(checkpoint.keys())
719
 
720
  text_model_dict = {}
 
714
 
715
 
716
  def convert_ldm_clip_checkpoint(checkpoint, dtype=torch.float16):
717
+ text_model = CLIPTextModel.from_pretrained("models/animatediff/sd/text_encoder", torch_dtype=dtype)
718
  keys = list(checkpoint.keys())
719
 
720
  text_model_dict = {}
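For context, `convert_ldm_clip_checkpoint` goes on to strip the LDM key prefix and load the remapped weights into the locally cached `CLIPTextModel`. A rough sketch under that assumption; the `cond_stage_model.transformer.` prefix is the usual LDM convention and is assumed here, not copied from this repo's code.

```python
import torch
from transformers import CLIPTextModel

def load_ldm_text_encoder(checkpoint: dict, dtype=torch.float16) -> CLIPTextModel:
    # Load the locally cached text encoder, matching the path changed in this commit.
    text_model = CLIPTextModel.from_pretrained(
        "models/animatediff/sd/text_encoder", torch_dtype=dtype)
    prefix = "cond_stage_model.transformer."   # assumed LDM key prefix
    text_model_dict = {k[len(prefix):]: v for k, v in checkpoint.items()
                       if k.startswith(prefix)}
    text_model.load_state_dict(text_model_dict, strict=False)
    return text_model
```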
app.py CHANGED
@@ -24,13 +24,11 @@ parser = argparse.ArgumentParser()
24
  parser.add_argument("--adaface_encoder_types", type=str, nargs="+", default=["consistentID", "arc2face"],
25
  choices=["arc2face", "consistentID"], help="Type(s) of the ID2Ada prompt encoders")
26
  parser.add_argument('--adaface_ckpt_path', type=str,
27
- default='models/adaface/VGGface2_HQ_masks2024-10-14T16-09-24_zero3-ada-3500.pt')
28
- parser.add_argument('--model_style_type', type=str, default='realistic',
29
  choices=["realistic", "anime", "photorealistic"], help="Type of the base model")
30
- parser.add_argument("--guidance_scale", type=float, default=6.0,
31
  help="The guidance scale for the diffusion model. Default: 8.0")
32
- parser.add_argument("--do_neg_id_prompt_weight", type=float, default=0,
33
- help="The weight of added ID prompt embeddings into the negative prompt. Default: 0, disabled.")
34
 
35
  parser.add_argument('--gpu', type=int, default=None)
36
  parser.add_argument('--ip', type=str, default="0.0.0.0")
@@ -41,20 +39,38 @@ def randomize_seed_fn(seed: int, randomize_seed: bool) -> int:
41
  seed = random.randint(0, MAX_SEED)
42
  return seed
43
 
 
 
 
 
 
 
 
 
 
44
  # model = load_model()
45
  # This FaceAnalysis is just to crop the face areas from the uploaded images,
46
  # and is independent of the adaface FaceAnalysis apps.
47
- app = FaceAnalysis(name="buffalo_l", root='models/insightface', providers=['CUDAExecutionProvider', 'CPUExecutionProvider'])
48
  app.prepare(ctx_id=0, det_size=(320, 320))
49
- device = "cuda" if args.gpu is None else f"cuda:{args.gpu}"
 
 
 
 
 
 
 
 
 
50
 
51
  global adaface, id_animator
52
 
53
- base_model_path = model_style_type2base_model_path[args.model_style_type]
54
  id_animator = load_model(model_style_type=args.model_style_type, device='cpu')
55
- adaface = AdaFaceWrapper(pipeline_name="text2img", base_model_path=base_model_path,
56
  adaface_encoder_types=args.adaface_encoder_types,
57
- adaface_ckpt_paths=[args.adaface_ckpt_path], device='cpu')
58
 
59
  basedir = os.getcwd()
60
  savedir = os.path.join(basedir,'samples')
@@ -79,7 +95,7 @@ def get_clicked_image(data: gr.SelectData):
79
  return data.index
80
 
81
  @spaces.GPU
82
- def gen_init_images(uploaded_image_paths, prompt, guidance_scale, do_neg_id_prompt_weight, out_image_count=4):
83
  if uploaded_image_paths is None:
84
  print("No image uploaded")
85
  return None, None, None
@@ -92,9 +108,11 @@ def gen_init_images(uploaded_image_paths, prompt, guidance_scale, do_neg_id_prom
92
  # [('/tmp/gradio/249981e66a7c665aaaf1c7eaeb24949af4366c88/jensen huang.jpg', None)]
93
  # Extract the file paths.
94
  uploaded_image_paths = [path[0] for path in uploaded_image_paths]
95
- adaface_subj_embs = \
96
- adaface.prepare_adaface_embeddings(image_paths=uploaded_image_paths, face_id_embs=None,
97
- update_text_encoder=True)
 
 
98
 
99
  if adaface_subj_embs is None:
100
  raise gr.Error(f"Failed to detect any faces! Please try with other images")
@@ -102,20 +120,22 @@ def gen_init_images(uploaded_image_paths, prompt, guidance_scale, do_neg_id_prom
102
  # Generate two images each time for the user to select from.
103
  noise = torch.randn(out_image_count, 3, 512, 512)
104
 
105
- enhance_face = True
106
- if enhance_face and "face portrait" not in prompt:
107
  if "portrait" in prompt:
108
  # Enhance the face features by replacing "portrait" with "face portrait".
109
  prompt = prompt.replace("portrait", "face portrait")
110
  else:
111
  prompt = "face portrait, " + prompt
112
 
 
 
113
  # samples: A list of PIL Image instances.
114
  with torch.no_grad():
115
  samples = adaface(noise, prompt, placeholder_tokens_pos='append',
116
  guidance_scale=guidance_scale,
117
- do_neg_id_prompt_weight=do_neg_id_prompt_weight,
118
- out_image_count=out_image_count, verbose=True)
 
119
 
120
  face_paths = []
121
  for sample in samples:
@@ -131,9 +151,9 @@ def gen_init_images(uploaded_image_paths, prompt, guidance_scale, do_neg_id_prom
131
  @spaces.GPU(duration=90)
132
  def generate_video(image_container, uploaded_image_paths, init_img_file_paths, init_img_selected_idx,
133
  init_image_strength, init_image_final_weight,
134
- prompt, negative_prompt, num_steps, video_length, guidance_scale, do_neg_id_prompt_weight,
135
  seed, attn_scale, image_embed_cfg_begin_scale, image_embed_cfg_end_scale,
136
- is_adaface_enabled, adaface_ckpt_path, adaface_power_scale,
137
  id_animator_anneal_steps, progress=gr.Progress(track_tqdm=True)):
138
 
139
  global adaface, id_animator
@@ -143,10 +163,17 @@ def generate_video(image_container, uploaded_image_paths, init_img_file_paths, i
143
  if prompt is None:
144
  prompt = ""
145
 
146
- prompt = prompt + " 8k uhd, high quality"
147
- if " shot" not in prompt:
148
- prompt = prompt + ", medium shot"
149
-
 
 
 
 
 
 
 
150
  prompt_img_lists=[]
151
  for path in uploaded_image_paths:
152
  img = cv2.imread(path)
@@ -158,16 +185,11 @@ def generate_video(image_container, uploaded_image_paths, init_img_file_paths, i
158
  # prompt_img_lists is a list of PIL images.
159
  prompt_img_lists.append(load_image(face_path).resize((224,224)))
160
 
161
- if adaface is None or not is_adaface_enabled:
162
  adaface_prompt_embeds, negative_prompt_embeds = None, None
 
163
  image_embed_cfg_scales = (1, 1)
164
  else:
165
- if (adaface_ckpt_path is not None and adaface_ckpt_path.strip() != '') \
166
- and (adaface_ckpt_path != args.adaface_ckpt_path):
167
- args.adaface_ckpt_path = adaface_ckpt_path
168
- # Reload the adaface model weights.
169
- adaface.id2ada_prompt_encoder.load_adaface_ckpt(adaface_ckpt_path)
170
-
171
  with torch.no_grad():
172
  adaface_subj_embs = \
173
  adaface.prepare_adaface_embeddings(image_paths=uploaded_image_paths, face_id_embs=None,
@@ -176,9 +198,10 @@ def generate_video(image_container, uploaded_image_paths, init_img_file_paths, i
176
  # adaface_prompt_embeds: [1, 77, 768].
177
  adaface_prompt_embeds, negative_prompt_embeds, _, _ = \
178
  adaface.encode_prompt(prompt, placeholder_tokens_pos='append',
179
- do_neg_id_prompt_weight=do_neg_id_prompt_weight,
180
  verbose=True)
181
 
 
182
  image_embed_cfg_scales = (image_embed_cfg_begin_scale, image_embed_cfg_end_scale)
183
 
184
  # init_img_file_paths is a list of image paths. If not chose, init_img_file_paths is None.
@@ -198,8 +221,8 @@ def generate_video(image_container, uploaded_image_paths, init_img_file_paths, i
198
  prompt = prompt,
199
  negative_prompt = negative_prompt,
200
  adaface_prompt_embeds = (adaface_prompt_embeds, negative_prompt_embeds),
201
- # adaface_power_scale is not so useful, and when it's set >= 2, weird artifacts appear.
202
- # Here it's limited to 0.7~1.3.
203
  adaface_power_scale = adaface_power_scale,
204
  num_inference_steps = num_steps,
205
  id_animator_anneal_steps = id_animator_anneal_steps,
@@ -216,7 +239,7 @@ def generate_video(image_container, uploaded_image_paths, init_img_file_paths, i
216
  save_videos_grid(sample, save_sample_path)
217
  return save_sample_path
218
 
219
- def check_prompt_and_model_type(prompt, model_style_type):
220
  global adaface, id_animator
221
 
222
  model_style_type = model_style_type.lower()
@@ -236,21 +259,20 @@ def check_prompt_and_model_type(prompt, model_style_type):
236
  with gr.Blocks(css=css, theme=gr.themes.Origin()) as demo:
237
  gr.Markdown(
238
  """
239
- # AdaFace-Animate: Zero-Shot Subject-Driven Video Generation for Humans
240
  """
241
  )
242
  gr.Markdown(
243
  """
244
- <b>Official demo</b> for our working paper <b>AdaFace: A Versatile Face Encoder for Zero-Shot Diffusion Model Personalization</b>.<br>
245
 
246
- ❗️**What's New**❗️
247
- - Support switching between two model styles: **Realistic** and **Anime**.
248
- - If you just changed the model style, the first image/video generation will take extra 20~30 seconds for loading new model weight.
249
 
250
  ❗️**Tips**❗️
251
  - You can upload one or more subject images for generating ID-specific video.
252
- - If the face dominates the video frames, try increasing the 'Weight of ID prompt in the negative prompt'.
253
- - If the face loses focus, try increasing the guidance scale.
254
  - If the motion is weird, e.g., the prompt is "... running", try increasing the number of sampling steps.
255
  - Usage explanations and demos: [Readme](https://huggingface.co/spaces/adaface-neurips/adaface-animate/blob/main/README2.md).
256
  - AdaFace Text-to-Image: <a href="https://huggingface.co/spaces/adaface-neurips/adaface" style="display: inline-flex; align-items: center;">
@@ -258,8 +280,6 @@ with gr.Blocks(css=css, theme=gr.themes.Origin()) as demo:
258
  <img src="https://img.shields.io/badge/%F0%9F%A4%97%20Hugging%20Face-Spaces-yellow" alt="Hugging Face Spaces" style="margin-left: 5px;">
259
  </a>
260
 
261
- **TODO:**
262
- - ControlNet integration.
263
  """
264
  )
265
 
@@ -270,6 +290,7 @@ with gr.Blocks(css=css, theme=gr.themes.Origin()) as demo:
270
  file_types=["image"],
271
  file_count="multiple"
272
  )
 
273
  image_container = gr.Image(label="image container", sources="upload", type="numpy", height=256, visible=False)
274
  uploaded_files_gallery = gr.Gallery(label="Subject images", visible=False, columns=3, rows=2, height=300)
275
  with gr.Column(visible=False) as clear_button_column:
@@ -280,6 +301,7 @@ with gr.Blocks(css=css, theme=gr.themes.Origin()) as demo:
280
  file_types=["image"],
281
  file_count="multiple"
282
  )
 
283
  init_img_container = gr.Image(label="init image container", sources="upload", type="numpy", height=256, visible=False)
284
  # Although there's only one image, we still use columns=3, to scale down the image size.
285
  # Otherwise it will occupy the full width, and the gallery won't show the whole image.
@@ -288,41 +310,47 @@ with gr.Blocks(css=css, theme=gr.themes.Origin()) as demo:
288
  init_img_selected_idx = gr.Textbox(label="Selected init image index", value="0", visible=False)
289
 
290
  with gr.Column(visible=True) as init_gen_button_column:
291
- gen_init = gr.Button(value="Generate 3 new init images")
292
  with gr.Column(visible=False) as init_clear_button_column:
293
  remove_init_and_reupload = gr.ClearButton(value="Upload an old init image", components=init_img_files, size="sm")
294
 
295
  prompt = gr.Dropdown(label="Prompt",
296
- info="Try something like 'man/woman walking on the beach'.",
297
- value="((best quality)), ((masterpiece)), ((realistic)), highlighted hair, futuristic silver armor suit, confident stance, high-resolution, living room, smiling, head tilted, perfect smooth skin",
298
- allow_custom_value=True,
299
- filterable=False,
300
- choices=[
301
- "((best quality)), ((masterpiece)), ((realistic)), highlighted hair, futuristic silver armor suit, confident stance, high-resolution, living room, smiling, head tilted, perfect smooth skin",
302
- "walking on the beach, sunset, orange sky, eye level shot",
303
- "in a white apron and chef hat, garnishing a gourmet dish, full body view, long shot",
304
- "dancing pose among folks in a park, waving hands",
305
- "in iron man costume flying pose, the sky ablaze with hues of orange and purple, full body view, long shot",
306
- "jedi wielding a lightsaber, star wars, full body view, eye level shot",
307
- "playing guitar on a boat, ocean waves",
308
- "with a passion for reading, curled up with a book in a cozy nook near a window",
309
- #"running pose in a park, full body view, eye level shot",
310
- "in superman costume flying pose, the sky ablaze with hues of orange and purple, full body view, long shot"
311
- ])
312
-
 
 
 
 
 
 
313
  init_image_strength = gr.Slider(
314
  label="Init Image Strength",
315
  info="How much the init image should influence each frame. 0: no influence (scenes are more dynamic), 3: strongest influence (scenes are more static).",
316
  minimum=0,
317
- maximum=1.5,
318
- step=0.25,
319
  value=1,
320
  )
321
  init_image_final_weight = gr.Slider(
322
- label="Final Weight of the Init Image",
323
  info="How much the init image should influence the end of the video",
324
  minimum=0,
325
- maximum=0.25,
326
  step=0.025,
327
  value=0.1,
328
  )
@@ -331,7 +359,7 @@ with gr.Blocks(css=css, theme=gr.themes.Origin()) as demo:
331
  label="Base Model Style Type",
332
  info="Switching the base model type will take 10~20 seconds to reload the model",
333
  value=args.model_style_type.capitalize(),
334
- choices=["Realistic", "Anime"], #"Photorealistic"],
335
  allow_custom_value=False,
336
  filterable=False,
337
  )
@@ -344,15 +372,6 @@ with gr.Blocks(css=css, theme=gr.themes.Origin()) as demo:
344
  value=args.guidance_scale,
345
  )
346
 
347
- do_neg_id_prompt_weight = gr.Slider(
348
- label="Weight of ID prompt in the negative prompt",
349
- minimum=0.0,
350
- maximum=0.9,
351
- step=0.1,
352
- value=args.do_neg_id_prompt_weight,
353
- visible=True
354
- )
355
-
356
  seed = gr.Slider(
357
  label="Seed",
358
  minimum=0,
@@ -365,11 +384,6 @@ with gr.Blocks(css=css, theme=gr.themes.Origin()) as demo:
365
  value=True,
366
  info="Uncheck for reproducible results")
367
 
368
- negative_prompt = gr.Textbox(
369
- label="Negative Prompt",
370
- placeholder="low quality",
371
- value="(deformed iris, deformed pupils, semi-realistic, cgi, 3d, render, sketch, cartoon, drawing, anime), text, cropped, out of frame, worst quality, low quality, jpeg artifacts, ugly, duplicate, morbid, mutilated, bare breasts, extra fingers, mutated hands, poorly drawn hands, poorly drawn face, mutation, deformed, blurry, dehydrated, bad anatomy, bad proportions, extra limbs, cloned face, disfigured, gross proportions, malformed limbs, missing arms, missing legs, extra arms, extra legs, fused fingers, long neck, UnrealisticDream",
372
- )
373
  num_steps = gr.Slider(
374
  label="Number of sampling steps. More steps for better composition, but longer time.",
375
  minimum=30,
@@ -394,36 +408,42 @@ with gr.Blocks(css=css, theme=gr.themes.Origin()) as demo:
394
  is_adaface_enabled = gr.Checkbox(label="Enable AdaFace",
395
  info="Enable AdaFace for better face details. If unchecked, it falls back to ID-Animator (https://huggingface.co/spaces/ID-Animator/ID-Animator).",
396
  value=True)
397
- adaface_ckpt_path = gr.Textbox(
398
- label="AdaFace checkpoint path",
399
- placeholder=args.adaface_ckpt_path,
400
- value=args.adaface_ckpt_path,
401
- )
402
 
403
  adaface_power_scale = gr.Slider(
404
  label="AdaFace Embedding Power Scale",
405
  info="Increase this scale slightly only if the face is defocused or the face details are not clear",
406
  minimum=0.8,
407
  maximum=1.2,
 
 
 
 
 
 
 
 
 
 
408
  step=0.1,
409
  value=1,
 
410
  )
411
 
412
  image_embed_cfg_begin_scale = gr.Slider(
413
  label="ID-Animator Image Embedding Initial Scale",
414
  info="The scale of the ID-Animator image embedding (influencing coarse facial features and poses)",
415
- minimum=0.6,
416
- maximum=1.5,
417
  step=0.1,
418
- value=1.0,
419
  )
420
  image_embed_cfg_end_scale = gr.Slider(
421
  label="ID-Animator Image Embedding Final Scale",
422
  info="The scale of the ID-Animator image embedding (influencing coarse facial features and poses)",
423
- minimum=0.3,
424
- maximum=1.5,
425
  step=0.1,
426
- value=0.5,
427
  )
428
 
429
  id_animator_anneal_steps = gr.Slider(
@@ -431,18 +451,15 @@ with gr.Blocks(css=css, theme=gr.themes.Origin()) as demo:
431
  minimum=0,
432
  maximum=40,
433
  step=1,
434
- value=20,
435
  visible=True,
436
  )
437
 
438
- attn_scale = gr.Slider(
439
- label="ID-Animator Attention Processor Scale",
440
- info="The scale of the ID embeddings on the attention (the higher, the more focus on the face, less on the background)" ,
441
- minimum=0,
442
- maximum=2,
443
- step=0.1,
444
- value=1,
445
- )
446
 
447
  with gr.Column():
448
  result_video = gr.Video(label="Generated Animation", interactive=False)
@@ -462,8 +479,8 @@ with gr.Blocks(css=css, theme=gr.themes.Origin()) as demo:
462
  outputs=seed,
463
  queue=False,
464
  api_name=False,
465
- ).then(fn=gen_init_images, inputs=[uploaded_files_gallery, prompt,
466
- guidance_scale, do_neg_id_prompt_weight],
467
  outputs=[uploaded_init_img_gallery, init_img_files, init_clear_button_column])
468
  uploaded_init_img_gallery.select(fn=get_clicked_image, inputs=None, outputs=init_img_selected_idx)
469
 
@@ -478,9 +495,10 @@ with gr.Blocks(css=css, theme=gr.themes.Origin()) as demo:
478
  fn=generate_video,
479
  inputs=[image_container, files,
480
  init_img_files, init_img_selected_idx, init_image_strength, init_image_final_weight,
481
- prompt, negative_prompt, num_steps, video_length, guidance_scale, do_neg_id_prompt_weight,
482
  seed, attn_scale, image_embed_cfg_begin_scale, image_embed_cfg_end_scale,
483
- is_adaface_enabled, adaface_ckpt_path, adaface_power_scale, id_animator_anneal_steps],
 
484
  outputs=[result_video]
485
  )
486
 
 
24
  parser.add_argument("--adaface_encoder_types", type=str, nargs="+", default=["consistentID", "arc2face"],
25
  choices=["arc2face", "consistentID"], help="Type(s) of the ID2Ada prompt encoders")
26
  parser.add_argument('--adaface_ckpt_path', type=str,
27
+ default='models/adaface/VGGface2_HQ_masks2025-03-06T03-31-21_zero3-ada-1000.pt')
28
+ parser.add_argument('--model_style_type', type=str, default='photorealistic',
29
  choices=["realistic", "anime", "photorealistic"], help="Type of the base model")
30
+ parser.add_argument("--guidance_scale", type=float, default=8.0,
31
  help="The guidance scale for the diffusion model. Default: 8.0")
 
 
32
 
33
  parser.add_argument('--gpu', type=int, default=None)
34
  parser.add_argument('--ip', type=str, default="0.0.0.0")
 
39
  seed = random.randint(0, MAX_SEED)
40
  return seed
41
 
42
+ def is_running_on_spaces():
43
+ return os.getenv("SPACE_ID") is not None
44
+
45
+ from huggingface_hub import snapshot_download
46
+ large_files = ["models/*", "models/**/*"]
47
+ snapshot_download(repo_id="adaface-neurips/adaface-animate-models",
48
+ repo_type="model", allow_patterns=large_files, local_dir=".")
49
+ os.makedirs("/tmp/gradio", exist_ok=True)
50
+
51
  # model = load_model()
52
  # This FaceAnalysis is just to crop the face areas from the uploaded images,
53
  # and is independent of the adaface FaceAnalysis apps.
54
+ app = FaceAnalysis(name="buffalo_l", root='models/insightface', providers=['CPUExecutionProvider'])
55
  app.prepare(ctx_id=0, det_size=(320, 320))
56
+
57
+ if is_running_on_spaces():
58
+ device = 'cuda:0'
59
+ else:
60
+ if args.gpu is None:
61
+ device = "cuda"
62
+ else:
63
+ device = f"cuda:{args.gpu}"
64
+
65
+ print(f"Device: {device}")
66
 
67
  global adaface, id_animator
68
 
69
+ adaface_base_model_path = model_style_type2base_model_path["photorealistic"]
70
  id_animator = load_model(model_style_type=args.model_style_type, device='cpu')
71
+ adaface = AdaFaceWrapper(pipeline_name="text2img", base_model_path=adaface_base_model_path,
72
  adaface_encoder_types=args.adaface_encoder_types,
73
+ adaface_ckpt_paths=args.adaface_ckpt_path, device='cpu')
74
 
75
  basedir = os.getcwd()
76
  savedir = os.path.join(basedir,'samples')
 
95
  return data.index
96
 
97
  @spaces.GPU
98
+ def gen_init_images(uploaded_image_paths, prompt, highlight_face, guidance_scale, out_image_count=4):
99
  if uploaded_image_paths is None:
100
  print("No image uploaded")
101
  return None, None, None
 
108
  # [('/tmp/gradio/249981e66a7c665aaaf1c7eaeb24949af4366c88/jensen huang.jpg', None)]
109
  # Extract the file paths.
110
  uploaded_image_paths = [path[0] for path in uploaded_image_paths]
111
+
112
+ with torch.no_grad():
113
+ adaface_subj_embs = \
114
+ adaface.prepare_adaface_embeddings(image_paths=uploaded_image_paths, face_id_embs=None,
115
+ update_text_encoder=True)
116
 
117
  if adaface_subj_embs is None:
118
  raise gr.Error(f"Failed to detect any faces! Please try with other images")
 
120
  # Generate two images each time for the user to select from.
121
  noise = torch.randn(out_image_count, 3, 512, 512)
122
 
123
+ if highlight_face and "face portrait" not in prompt:
 
124
  if "portrait" in prompt:
125
  # Enhance the face features by replacing "portrait" with "face portrait".
126
  prompt = prompt.replace("portrait", "face portrait")
127
  else:
128
  prompt = "face portrait, " + prompt
129
 
130
+ guidance_scale = min(guidance_scale, 5)
131
+
132
  # samples: A list of PIL Image instances.
133
  with torch.no_grad():
134
  samples = adaface(noise, prompt, placeholder_tokens_pos='append',
135
  guidance_scale=guidance_scale,
136
+ out_image_count=out_image_count,
137
+ repeat_prompt_for_each_encoder=True,
138
+ verbose=True)
139
 
140
  face_paths = []
141
  for sample in samples:
 
151
  @spaces.GPU(duration=90)
152
  def generate_video(image_container, uploaded_image_paths, init_img_file_paths, init_img_selected_idx,
153
  init_image_strength, init_image_final_weight,
154
+ prompt, negative_prompt, num_steps, video_length, guidance_scale,
155
  seed, attn_scale, image_embed_cfg_begin_scale, image_embed_cfg_end_scale,
156
+ highlight_face, is_adaface_enabled, adaface_power_scale,
157
  id_animator_anneal_steps, progress=gr.Progress(track_tqdm=True)):
158
 
159
  global adaface, id_animator
 
163
  if prompt is None:
164
  prompt = ""
165
 
166
+ #prompt = prompt + " 8k uhd, high quality"
167
+ #if " shot" not in prompt:
168
+ # prompt = prompt + ", medium shot"
169
+
170
+ if highlight_face and "face portrait" not in prompt:
171
+ if "portrait" in prompt:
172
+ # Enhance the face features by replacing "portrait" with "face portrait".
173
+ prompt = prompt.replace("portrait", "face portrait")
174
+ else:
175
+ prompt = "face portrait, " + prompt
176
+
177
  prompt_img_lists=[]
178
  for path in uploaded_image_paths:
179
  img = cv2.imread(path)
 
185
  # prompt_img_lists is a list of PIL images.
186
  prompt_img_lists.append(load_image(face_path).resize((224,224)))
187
 
188
+ if adaface is None or (not is_adaface_enabled):
189
  adaface_prompt_embeds, negative_prompt_embeds = None, None
190
+ # ID-Animator Image Embedding Initial and Final Scales
191
  image_embed_cfg_scales = (1, 1)
192
  else:
 
 
 
 
 
 
193
  with torch.no_grad():
194
  adaface_subj_embs = \
195
  adaface.prepare_adaface_embeddings(image_paths=uploaded_image_paths, face_id_embs=None,
 
198
  # adaface_prompt_embeds: [1, 77, 768].
199
  adaface_prompt_embeds, negative_prompt_embeds, _, _ = \
200
  adaface.encode_prompt(prompt, placeholder_tokens_pos='append',
201
+ repeat_prompt_for_each_encoder=True,
202
  verbose=True)
203
 
204
+ # ID-Animator Image Embedding Initial and End Scales
205
  image_embed_cfg_scales = (image_embed_cfg_begin_scale, image_embed_cfg_end_scale)
206
 
207
  # init_img_file_paths is a list of image paths. If none is chosen, init_img_file_paths is None.
 
221
  prompt = prompt,
222
  negative_prompt = negative_prompt,
223
  adaface_prompt_embeds = (adaface_prompt_embeds, negative_prompt_embeds),
224
+ # adaface_power_scale has only a subtle effect, and values >= 1.2 introduce visible artifacts.
225
+ # Therefore it's limited to the range 1~1.1 here.
226
  adaface_power_scale = adaface_power_scale,
227
  num_inference_steps = num_steps,
228
  id_animator_anneal_steps = id_animator_anneal_steps,
 
239
  save_videos_grid(sample, save_sample_path)
240
  return save_sample_path
241
 
242
+ def check_prompt_and_model_type(prompt, model_style_type, progress=gr.Progress()):
243
  global adaface, id_animator
244
 
245
  model_style_type = model_style_type.lower()
 
259
  with gr.Blocks(css=css, theme=gr.themes.Origin()) as demo:
260
  gr.Markdown(
261
  """
262
+ # AdaFace-Animate: Zero-Shot Human Subject-Driven Video Generation
263
  """
264
  )
265
  gr.Markdown(
266
  """
267
+ <b>Official demo</b> for our working paper <b>AdaFace: A Versatile Text-space Face Encoder for Face Synthesis and Processing</b>.<br>
268
 
269
+ ❗️**NOTE**❗️
270
+ - Supports switching among three model styles: **Realistic**, **Photorealistic**, and **Anime**. **Realistic** looks less realistic than **Photorealistic** but produces better motion.
271
+ - If you change the model style, please allow 20~30 seconds for the new model weights to load before image/video generation begins.
272
 
273
  ❗️**Tips**❗️
274
  - You can upload one or more subject images to generate an ID-specific video.
275
+ - If the face loses focus, try enabling "Highlight face".
 
276
  - If the motion is weird, e.g., the prompt is "... running", try increasing the number of sampling steps.
277
  - Usage explanations and demos: [Readme](https://huggingface.co/spaces/adaface-neurips/adaface-animate/blob/main/README2.md).
278
  - AdaFace Text-to-Image: <a href="https://huggingface.co/spaces/adaface-neurips/adaface" style="display: inline-flex; align-items: center;">
 
280
  <img src="https://img.shields.io/badge/%F0%9F%A4%97%20Hugging%20Face-Spaces-yellow" alt="Hugging Face Spaces" style="margin-left: 5px;">
281
  </a>
282
 
 
 
283
  """
284
  )
285
 
 
290
  file_types=["image"],
291
  file_count="multiple"
292
  )
293
+ files.GRADIO_CACHE = "/tmp/gradio"
294
  image_container = gr.Image(label="image container", sources="upload", type="numpy", height=256, visible=False)
295
  uploaded_files_gallery = gr.Gallery(label="Subject images", visible=False, columns=3, rows=2, height=300)
296
  with gr.Column(visible=False) as clear_button_column:
 
301
  file_types=["image"],
302
  file_count="multiple"
303
  )
304
+ init_img_files.GRADIO_CACHE = "/tmp/gradio"
305
  init_img_container = gr.Image(label="init image container", sources="upload", type="numpy", height=256, visible=False)
306
  # Although there's only one image, we still use columns=3, to scale down the image size.
307
  # Otherwise it will occupy the full width, and the gallery won't show the whole image.
 
310
  init_img_selected_idx = gr.Textbox(label="Selected init image index", value="0", visible=False)
311
 
312
  with gr.Column(visible=True) as init_gen_button_column:
313
+ gen_init = gr.Button(value="Generate 4 new init images")
314
  with gr.Column(visible=False) as init_clear_button_column:
315
  remove_init_and_reupload = gr.ClearButton(value="Upload an old init image", components=init_img_files, size="sm")
316
 
317
  prompt = gr.Dropdown(label="Prompt",
318
+ info="Try something like 'walking on the beach'.",
319
+ value="highlighted hair, futuristic silver armor suit, confident stance, living room, smiling, head tilted, perfect smooth skin",
320
+ allow_custom_value=True,
321
+ choices=[
322
+ "portrait, highlighted hair, futuristic silver armor suit, confident stance, living room, smiling, head tilted, perfect smooth skin",
323
+ "portrait, walking on the beach, sunset",
324
+ "portrait, in a white apron and chef hat, garnishing a gourmet dish",
325
+ "portrait, dancing pose among folks in a park, waving hands",
326
+ "portrait, in iron man costume, the sky ablaze with hues of orange and purple",
327
+ "portrait, jedi wielding a lightsaber, star wars",
328
+ "portrait, night view of tokyo street, neon light",
329
+ "portrait, playing guitar on a boat, ocean waves",
330
+ "portrait, with a passion for reading, curled up with a book in a cozy nook near a window",
331
+ "portrait, celebrating new year, fireworks",
332
+ "portrait, running pose in a park",
333
+ "portrait, in space suit, space helmet, walking on mars",
334
+ "portrait, in superman costume, the sky ablaze with hues of orange and purple"
335
+ ])
336
+
337
+ highlight_face = gr.Checkbox(label="Highlight face", value=False,
338
+ info="Enhance the facial features by prepending 'face portrait' to the prompt",
339
+ visible=True)
340
+
341
  init_image_strength = gr.Slider(
342
  label="Init Image Strength",
343
  info="How much the init image should influence each frame. 0: no influence (scenes are more dynamic), 3: strongest influence (scenes are more static).",
344
  minimum=0,
345
+ maximum=3,
346
+ step=0.1,
347
  value=1,
348
  )
349
  init_image_final_weight = gr.Slider(
350
+ label="Final Strength of the Init Image",
351
  info="How much the init image should influence the end of the video",
352
  minimum=0,
353
+ maximum=2,
354
  step=0.025,
355
  value=0.1,
356
  )
 
359
  label="Base Model Style Type",
360
  info="Switching the base model type will take 10~20 seconds to reload the model",
361
  value=args.model_style_type.capitalize(),
362
+ choices=["Realistic", "Anime", "Photorealistic"],
363
  allow_custom_value=False,
364
  filterable=False,
365
  )
 
372
  value=args.guidance_scale,
373
  )
374
 
 
 
 
 
 
 
 
 
 
375
  seed = gr.Slider(
376
  label="Seed",
377
  minimum=0,
 
384
  value=True,
385
  info="Uncheck for reproducible results")
386
 
 
 
 
 
 
387
  num_steps = gr.Slider(
388
  label="Number of sampling steps. More steps for better composition, but longer time.",
389
  minimum=30,
 
408
  is_adaface_enabled = gr.Checkbox(label="Enable AdaFace",
409
  info="Enable AdaFace for better face details. If unchecked, it falls back to ID-Animator (https://huggingface.co/spaces/ID-Animator/ID-Animator).",
410
  value=True)
 
 
 
 
 
411
 
412
  adaface_power_scale = gr.Slider(
413
  label="AdaFace Embedding Power Scale",
414
  info="Increase this scale slightly only if the face is defocused or the face details are not clear",
415
  minimum=0.8,
416
  maximum=1.2,
417
+ step=0.05,
418
+ value=1.1,
419
+ visible=True,
420
+ )
421
+
422
+ attn_scale = gr.Slider(
423
+ label="Attention Processor Scale",
424
+ info="The scale of the ID embeddings on the attention (the higher, the more focus on the face and the less on the background)",
425
+ minimum=0.5,
426
+ maximum=2,
427
  step=0.1,
428
  value=1,
429
+ visible=True
430
  )
431
 
432
  image_embed_cfg_begin_scale = gr.Slider(
433
  label="ID-Animator Image Embedding Initial Scale",
434
  info="The scale of the ID-Animator image embedding (influencing coarse facial features and poses)",
435
+ minimum=0,
436
+ maximum=1,
437
  step=0.1,
438
+ value=0.5,
439
  )
440
  image_embed_cfg_end_scale = gr.Slider(
441
  label="ID-Animator Image Embedding Final Scale",
442
  info="The scale of the ID-Animator image embedding (influencing coarse facial features and poses)",
443
+ minimum=0,
444
+ maximum=1,
445
  step=0.1,
446
+ value=0.1,
447
  )
448
 
449
  id_animator_anneal_steps = gr.Slider(
 
451
  minimum=0,
452
  maximum=40,
453
  step=1,
454
+ value=40,
455
  visible=True,
456
  )
457
 
458
+ negative_prompt = gr.Textbox(
459
+ label="Negative Prompt",
460
+ placeholder="low quality",
461
+ value="deformed iris, deformed pupils, semi-realistic, cgi, 3d, render, sketch, cartoon, drawing, anime, text, cropped, out of frame, worst quality, low quality, jpeg artifacts, ugly, duplicate, morbid, mutilated, bare breasts, extra fingers, mutated hands, poorly drawn hands, poorly drawn face, mutation, deformed, blurry, dehydrated, bad anatomy, bad proportions, extra limbs, cloned face, disfigured, gross proportions, malformed limbs, missing arms, missing legs, extra arms, extra legs, fused fingers, long neck, UnrealisticDream, nude, naked, nsfw, topless, bare breasts",
462
+ )
 
 
 
463
 
464
  with gr.Column():
465
  result_video = gr.Video(label="Generated Animation", interactive=False)
 
479
  outputs=seed,
480
  queue=False,
481
  api_name=False,
482
+ ).then(fn=gen_init_images, inputs=[uploaded_files_gallery, prompt, highlight_face,
483
+ guidance_scale],
484
  outputs=[uploaded_init_img_gallery, init_img_files, init_clear_button_column])
485
  uploaded_init_img_gallery.select(fn=get_clicked_image, inputs=None, outputs=init_img_selected_idx)
486
 
 
495
  fn=generate_video,
496
  inputs=[image_container, files,
497
  init_img_files, init_img_selected_idx, init_image_strength, init_image_final_weight,
498
+ prompt, negative_prompt, num_steps, video_length, guidance_scale,
499
  seed, attn_scale, image_embed_cfg_begin_scale, image_embed_cfg_end_scale,
500
+ highlight_face, is_adaface_enabled,
501
+ adaface_power_scale, id_animator_anneal_steps],
502
  outputs=[result_video]
503
  )
504
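
The same "face portrait" prompt enhancement is added by this commit in both `gen_init_images()` and `generate_video()`. A minimal sketch of that shared logic, factored into a standalone helper (the name `enhance_face_prompt` is hypothetical, not part of the commit):

```python
def enhance_face_prompt(prompt: str, highlight_face: bool) -> str:
    # Mirrors the logic added in gen_init_images() and generate_video():
    # when the "Highlight face" checkbox is on, steer the prompt toward facial details.
    if highlight_face and "face portrait" not in prompt:
        if "portrait" in prompt:
            # Strengthen face features by upgrading "portrait" to "face portrait".
            prompt = prompt.replace("portrait", "face portrait")
        else:
            prompt = "face portrait, " + prompt
    return prompt

# Example:
# enhance_face_prompt("portrait, walking on the beach, sunset", True)
# -> "face portrait, walking on the beach, sunset"
```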
 
faceadapter/face_adapter.py CHANGED
@@ -315,10 +315,10 @@ class FaceAdapterPlusForVideoLora(FaceAdapterLora):
315
  negative_prompt_embeds0 = negative_prompt_embeds_
316
  adaface_prompt_embeds, negative_prompt_embeds_ = adaface_prompt_embeds
317
  # self.torch_type == torch.float16. adaface_prompt_embeds is torch.float32.
318
- prompt_embeds_ = adaface_prompt_embeds.repeat(num_samples, 1, 1).to(dtype=self.torch_type)
319
  negative_prompt_embeds_ = negative_prompt_embeds_.repeat(num_samples, 1, 1).to(dtype=self.torch_type)
320
  if adaface_power_scale != 1.0:
321
- prompt_embeds_ = prompt_embeds_ * adaface_power_scale - negative_prompt_embeds0 * (1 - adaface_power_scale)
322
 
323
  # Note to balance image_prompt_embeds with uncond_image_prompt_embeds after scaling.
324
  image_prompt_embeds_begin = image_prompt_embeds * image_embed_cfg_scales[0] + uncond_image_prompt_embeds * (1 - image_embed_cfg_scales[0])
 
315
  negative_prompt_embeds0 = negative_prompt_embeds_
316
  adaface_prompt_embeds, negative_prompt_embeds_ = adaface_prompt_embeds
317
  # self.torch_type == torch.float16. adaface_prompt_embeds is torch.float32.
318
+ prompt_embeds_ = adaface_prompt_embeds.repeat(num_samples, 1, 1).to(dtype=self.torch_type)
319
  negative_prompt_embeds_ = negative_prompt_embeds_.repeat(num_samples, 1, 1).to(dtype=self.torch_type)
320
  if adaface_power_scale != 1.0:
321
+ prompt_embeds_ = prompt_embeds_ * adaface_power_scale + negative_prompt_embeds0 * (1 - adaface_power_scale)
322
 
323
  # Note to balance image_prompt_embeds with uncond_image_prompt_embeds after scaling.
324
  image_prompt_embeds_begin = image_prompt_embeds * image_embed_cfg_scales[0] + uncond_image_prompt_embeds * (1 - image_embed_cfg_scales[0])
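
The `face_adapter.py` change flips the sign of the negative-embedding term, so `adaface_power_scale` now acts as a plain linear blend/extrapolation of the AdaFace embeddings around the negative prompt embeddings. A minimal numeric sketch, with illustrative tensor shapes only:

```python
import torch

s = 1.1                                    # adaface_power_scale, kept to roughly 1~1.1 by the app
adaface_embeds = torch.randn(1, 77, 768)   # illustrative prompt-embedding shape
negative_embeds = torch.randn(1, 77, 768)  # illustrative negative prompt embeddings

# After the fix: s = 1.0 leaves the AdaFace embeddings unchanged, while s > 1.0
# pushes them slightly away from the negative embeddings, mildly amplifying the
# identity signal.
scaled = adaface_embeds * s + negative_embeds * (1 - s)

# Before the fix the second term was subtracted, so s > 1.0 pulled the result
# toward the negative embeddings instead of away from them.
```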
infer.py CHANGED
@@ -17,7 +17,7 @@ model_style_type2base_model_path = {
17
 
18
  def load_model(model_style_type="realistic", device="cuda"):
19
  inference_config = "inference-v2.yaml"
20
- sd_version = "animatediff/sd"
21
  id_ckpt = "models/animator.ckpt"
22
  image_encoder_path = "models/image_encoder"
23
 
@@ -73,7 +73,7 @@ def load_model(model_style_type="realistic", device="cuda"):
73
 
74
  converted_vae_checkpoint = convert_ldm_vae_checkpoint(dreambooth_state_dict, pipeline.vae.config)
75
  # print(vae)
76
- #vae ->to_q,to_k,to_v
77
  # print(converted_vae_checkpoint)
78
  convert_vae_keys = list(converted_vae_checkpoint.keys())
79
  for key in convert_vae_keys:
@@ -93,7 +93,8 @@ def load_model(model_style_type="realistic", device="cuda"):
93
  pipeline.vae.load_state_dict(converted_vae_checkpoint)
94
 
95
  converted_unet_checkpoint = convert_ldm_unet_checkpoint(dreambooth_state_dict, pipeline.unet.config)
96
- pipeline.unet.load_state_dict(converted_unet_checkpoint, strict=False)
 
97
 
98
  pipeline.text_encoder = convert_ldm_clip_checkpoint(dreambooth_state_dict, dtype=torch.float16).to(device=device)
99
 
 
17
 
18
  def load_model(model_style_type="realistic", device="cuda"):
19
  inference_config = "inference-v2.yaml"
20
+ sd_version = "models/animatediff/sd"
21
  id_ckpt = "models/animator.ckpt"
22
  image_encoder_path = "models/image_encoder"
23
 
 
73
 
74
  converted_vae_checkpoint = convert_ldm_vae_checkpoint(dreambooth_state_dict, pipeline.vae.config)
75
  # print(vae)
76
+ # vae ->to_q, to_k, to_v
77
  # print(converted_vae_checkpoint)
78
  convert_vae_keys = list(converted_vae_checkpoint.keys())
79
  for key in convert_vae_keys:
 
93
  pipeline.vae.load_state_dict(converted_vae_checkpoint)
94
 
95
  converted_unet_checkpoint = convert_ldm_unet_checkpoint(dreambooth_state_dict, pipeline.unet.config)
96
+ m, u = pipeline.unet.load_state_dict(converted_unet_checkpoint, strict=False)
97
+ print(f"### custom unet missing keys: {len(m)}; \n### unexpected keys: {len(u)};")
98
 
99
  pipeline.text_encoder = convert_ldm_clip_checkpoint(dreambooth_state_dict, dtype=torch.float16).to(device=device)
100
 
requirements.txt CHANGED
@@ -1,5 +1,5 @@
1
  diffusers==0.29.2
2
- torch
3
  torchvision
4
  imageio
5
  imageio-ffmpeg
 
1
  diffusers==0.29.2
2
+ torch==2.4.1
3
  torchvision
4
  imageio
5
  imageio-ffmpeg