Commit 8ee7393 · Parent: 76ccb95
update code

Changed files:
- .gitignore +2 -0
- ConsistentID/app.py +2 -2
- ConsistentID/requirements.txt +1 -1
- README2.md +2 -2
- adaface/__init__.py +0 -0
- adaface/adaface_infer.py +17 -22
- adaface/adaface_translate.py +59 -36
- adaface/adaface_wrapper.py +366 -72
- adaface/diffusers_attn_lora_capture.py +656 -0
- adaface/face_id_to_ada_prompt.py +270 -119
- adaface/subj_basis_generator.py +97 -59
- adaface/unet_teachers.py +86 -49
- adaface/util.py +21 -18
- animatediff/sd/.gitattributes +0 -35
- animatediff/sd/feature_extractor/preprocessor_config.json +0 -20
- animatediff/sd/model_index.json +0 -32
- animatediff/sd/safety_checker/config.json +0 -175
- animatediff/sd/scheduler/scheduler_config.json +0 -13
- animatediff/sd/text_encoder/config.json +0 -25
- animatediff/sd/tokenizer/merges.txt +0 -0
- animatediff/sd/tokenizer/special_tokens_map.json +0 -24
- animatediff/sd/tokenizer/tokenizer_config.json +0 -34
- animatediff/sd/tokenizer/vocab.json +0 -0
- animatediff/sd/unet/config.json +0 -36
- animatediff/sd/v1-inference.yaml +0 -70
- animatediff/sd/vae/config.json +0 -29
- animatediff/utils/convert_from_ckpt.py +1 -1
- app.py +123 -105
- faceadapter/face_adapter.py +2 -2
- infer.py +4 -3
- requirements.txt +1 -1
.gitignore CHANGED
@@ -6,3 +6,5 @@ gradio_cached_examples/
 samples/*
 samples/
 .gradio/certificate.pem
+models/*
+models
ConsistentID/app.py CHANGED
@@ -26,8 +26,8 @@ pipe = ConsistentIDPipeline.from_pretrained(
 
 ### Load consistentID_model checkpoint
 pipe.load_ConsistentID_model(
-    consistentID_weight_path="./models/ConsistentID-v1.bin",
-    bise_net_weight_path="./models/BiSeNet_pretrained_for_ConsistentID.pth",
+    consistentID_weight_path="./models/ConsistentID/ConsistentID-v1.bin",
+    bise_net_weight_path="./models/ConsistentID/BiSeNet_pretrained_for_ConsistentID.pth",
 )
 pipe.scheduler = EulerDiscreteScheduler.from_config(pipe.scheduler.config)
 pipe = pipe.to(device, torch.float16)
ConsistentID/requirements.txt CHANGED
@@ -7,7 +7,7 @@ peft
 opencv-python
 insightface
 diffusers
-torch
+torch==2.4.1
 torchvision
 transformers
 spaces
README2.md CHANGED
@@ -187,9 +187,9 @@ To exclude the effects of AdaFace, we generate a subset of videos with AdaFace-A
 ## Installation
 
 ### Manually Download Model Checkpoints
-- Download Stable Diffusion V1.5 into ``animatediff/sd``:
+- Download Stable Diffusion V1.5 into ``models/animatediff/sd``:
 
-  ``git clone https://huggingface.co/runwayml/stable-diffusion-v1-5 animatediff/sd``
+  ``git clone https://huggingface.co/runwayml/stable-diffusion-v1-5 models/animatediff/sd``
 - Download AnimateDiff motion module into ``models/v3_sd15_mm.ckpt``: https://huggingface.co/guoyww/animatediff/blob/main/v3_sd15_mm.ckpt
 - Download Animatediff adapter into ``models/v3_adapter_sd_v15.ckpt``: https://huggingface.co/guoyww/animatediff/blob/main/v3_sd15_adapter.ckpt
 - Download ID-Animator checkpoint into ``models/animator.ckpt`` from: https://huggingface.co/spaces/ID-Animator/ID-Animator/blob/main/animator.ckpt
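For the single-file checkpoints in the list above, a short download script can replace the manual clicking. A minimal sketch (not part of the commit) using huggingface_hub; the repo IDs and filenames are read off the URLs above:

    # Sketch: fetch the AnimateDiff/ID-Animator checkpoints into models/.
    # Assumes `pip install huggingface_hub`.
    import os
    from huggingface_hub import hf_hub_download

    hf_hub_download(repo_id="guoyww/animatediff", filename="v3_sd15_mm.ckpt",
                    local_dir="models")
    adapter = hf_hub_download(repo_id="guoyww/animatediff", filename="v3_sd15_adapter.ckpt",
                              local_dir="models")
    # The README stores the adapter under a different name.
    os.rename(adapter, "models/v3_adapter_sd_v15.ckpt")
    # The ID-Animator checkpoint lives in a Space, hence repo_type="space".
    hf_hub_download(repo_id="ID-Animator/ID-Animator", repo_type="space",
                    filename="animator.ckpt", local_dir="models")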
adaface/__init__.py ADDED
(empty file)
adaface/adaface_infer.py CHANGED
@@ -41,42 +41,36 @@ def seed_everything(seed):
 def parse_args():
     parser = argparse.ArgumentParser()
     parser.add_argument("--pipeline", type=str, default="text2img",
-                        choices=["text2img", "img2img", "text2img3", "flux"],
+                        choices=["text2img", "text2imgxl", "img2img", "text2img3", "flux"],
                         help="Type of pipeline to use (default: txt2img)")
     parser.add_argument("--base_model_path", type=str, default=None,
                         help="Type of checkpoints to use (default: None, using the official model)")
-    parser.add_argument('--…
-
-    parser.add_argument("--adaface_encoder_types", type=str, nargs="+", default=["arc2face"],
+    parser.add_argument('--adaface_ckpt_path', type=str, required=True)
+    parser.add_argument("--adaface_encoder_types", type=str, nargs="+", default=["consistentID", "arc2face"],
                         choices=["arc2face", "consistentID"], help="Type(s) of the ID2Ada prompt encoders")
+    parser.add_argument("--enabled_encoders", type=str, nargs="+", default=None,
+                        choices=["arc2face", "consistentID"],
+                        help="List of enabled encoders (among the list of adaface_encoder_types). Default: None (all enabled)")
     # If adaface_encoder_cfg_scales is not specified, the weights will be set to 6.0 (consistentID) and 1.0 (arc2face).
     parser.add_argument('--adaface_encoder_cfg_scales', type=float, nargs="+", default=None,
                         help="CFG scales of output embeddings of the ID2Ada prompt encoders")
     parser.add_argument("--main_unet_filepath", type=str, default=None,
                         help="Path to the checkpoint of the main UNet model, if you want to replace the default UNet within --base_model_path")
     parser.add_argument("--extra_unet_dirpaths", type=str, nargs="*",
-                        default=[…
+                        default=[],
                         help="Extra paths to the checkpoints of the UNet models")
-    parser.add_argument('--…
+    parser.add_argument('--unet_weights_in_ensemble', type=float, nargs="+", default=[1],
                         help="Weights for the UNet models")
     parser.add_argument("--subject", type=str)
     parser.add_argument("--example_image_count", type=int, default=-1, help="Number of example images to use")
     parser.add_argument("--out_image_count", type=int, default=4, help="Number of images to generate")
     parser.add_argument("--prompt", type=str, default="a woman z in superman costume")
-    parser.add_argument("--…
+    parser.add_argument("--perturb_std", type=float, default=0)
     parser.add_argument("--randface", action="store_true")
     parser.add_argument("--scale", dest='guidance_scale', type=float, default=4,
                         help="Guidance scale for the diffusion model")
-    parser.add_argument("--id_cfg_scale", type=float, default=6,
-                        help="CFG scale when generating the identity embeddings")
-
-    parser.add_argument("--subject_string",
-                        type=str, default="z",
-                        help="Subject placeholder string used in prompts to denote the concept.")
     parser.add_argument("--num_images_per_row", type=int, default=4,
                         help="Number of images to display in a row in the output grid image.")
-    parser.add_argument("--num_inference_steps", type=int, default=50,
-                        help="Number of DDIM inference steps")
     parser.add_argument("--device", type=str, default="cuda", help="Device to run the model on")
     parser.add_argument("--seed", type=int, default=42,
                         help="the seed (for reproducible sampling). Set to -1 to disable.")
@@ -95,16 +89,15 @@ if __name__ == "__main__":
 
     if args.pipeline not in ["text2img", "img2img"]:
         args.extra_unet_dirpaths = None
-        args.…
+        args.unet_weights_in_ensemble = None
 
     adaface = AdaFaceWrapper(args.pipeline, args.base_model_path,
-                             args.adaface_encoder_types, args.…
-                             args.adaface_encoder_cfg_scales,
-                             args.subject_string, args.num_inference_steps,
+                             args.adaface_encoder_types, args.adaface_ckpt_path,
+                             args.adaface_encoder_cfg_scales, args.enabled_encoders,
                              unet_types=None,
                              main_unet_filepath=args.main_unet_filepath,
                              extra_unet_dirpaths=args.extra_unet_dirpaths,
-                             …
+                             unet_weights_in_ensemble=args.unet_weights_in_ensemble, device=args.device)
 
     if not args.randface:
         image_folder = args.subject
@@ -143,7 +136,7 @@ if __name__ == "__main__":
         rand_init_id_embs = torch.randn(1, 512)
 
     init_id_embs = rand_init_id_embs if args.randface else None
-
+    init_noise = torch.randn(args.out_image_count, 4, 64, 64).cuda()
     # args.perturb_std: the *relative* std of the noise added to the face embeddings.
     # A noise level of 0.08 could change gender, but 0.06 is usually safe.
     # adaface_subj_embs is not used. It is generated for the purpose of updating the text encoder (within this function call).
@@ -151,5 +144,7 @@ if __name__ == "__main__":
     adaface.prepare_adaface_embeddings(image_paths, init_id_embs,
                                        perturb_at_stage='img_prompt_emb',
                                        perturb_std=args.perturb_std, update_text_encoder=True)
-    images = adaface(…
+    images = adaface(init_noise, args.prompt, None, None,
+                     'append', args.guidance_scale,
+                     args.out_image_count, verbose=True)
     save_images(images, args.num_images_per_row, subject_name, f"guide{args.guidance_scale}", args.perturb_std)
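The inference script now passes an explicit initial-noise latent through the wrapper's forward call. A minimal sketch of the new calling convention (the checkpoint and subject-image paths are placeholders; the positional forward arguments mirror the call in the diff above):

    # Sketch of the new inference flow; paths below are placeholders, not shipped files.
    import torch
    from adaface.adaface_wrapper import AdaFaceWrapper

    adaface = AdaFaceWrapper("text2img", None,
                             ["consistentID", "arc2face"], "models/adaface/adaface.pt",
                             None, None, device="cuda")
    # Installs the subject embeddings into the extended text encoder.
    adaface.prepare_adaface_embeddings(["subjects/alice/1.jpg"], None,
                                       perturb_at_stage='img_prompt_emb',
                                       perturb_std=0, update_text_encoder=True)
    # forward() now takes the initial latents explicitly: [N, 4, 64, 64] for SD 1.5 at 512px.
    init_noise = torch.randn(4, 4, 64, 64).cuda()
    images = adaface(init_noise, "a woman z in superman costume", None, None,
                     'append', 4.0, 4, verbose=True)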
adaface/adaface_translate.py CHANGED
@@ -25,21 +25,25 @@ def seed_everything(seed):
 
 def parse_args():
     parser = argparse.ArgumentParser()
-    parser.add_argument("--base_model_path", type=str, default='models/…
-                        help="Path to the UNet checkpoint (…
-    parser.add_argument('--…
-
-    parser.add_argument("--adaface_encoder_types", type=str, nargs="+", default=["arc2face"],
+    parser.add_argument("--base_model_path", type=str, default='models/sar/sar.safetensors',
+                        help="Path to the UNet checkpoint (Default: SAR)")
+    parser.add_argument('--adaface_ckpt_path', type=str, required=True)
+    parser.add_argument("--adaface_encoder_types", type=str, nargs="+", default=["consistentID", "arc2face"],
                         choices=["arc2face", "consistentID"], help="Type(s) of the ID2Ada prompt encoders")
+    parser.add_argument("--enabled_encoders", type=str, nargs="+", default=None,
+                        choices=["arc2face", "consistentID"],
+                        help="List of enabled encoders (among the list of adaface_encoder_types). Default: None (all enabled)")
     # If adaface_encoder_cfg_scales is not specified, the weights will be set to 6.0 (consistentID) and 1.0 (arc2face).
     parser.add_argument('--adaface_encoder_cfg_scales', type=float, nargs="+", default=None,
                         help="CFG scales of output embeddings of the ID2Ada prompt encoders")
     parser.add_argument('--extra_unet_dirpaths', type=str, nargs="*",
-                        default=[…
+                        default=[],
                         help="Extra paths to the checkpoints of the UNet models")
-    parser.add_argument('--…
+    parser.add_argument('--unet_weights_in_ensemble', type=float, nargs="+", default=[1],
                         help="Weights for the UNet models")
     parser.add_argument("--in_folder", type=str, required=True, help="Path to the folder containing input images")
+    parser.add_argument("--restore_image", type=str, default=None,
+                        help="Path to the image to be restored")
     # If True, the input folder contains images of mixed subjects.
     # If False, the input folder contains multiple subfolders, each of which contains images of the same subject.
     parser.add_argument("--is_mix_subj_folder", type=str2bool, const=True, default=False, nargs="?",
@@ -49,19 +53,14 @@ def parse_args():
     parser.add_argument("--out_folder", type=str, required=True, help="Path to the folder saving output images")
     parser.add_argument("--out_count_per_input_image", type=int, default=1, help="Number of output images to generate per input image")
     parser.add_argument("--copy_masks", action="store_true", help="Copy the mask images to the output folder")
-    parser.add_argument("--…
+    parser.add_argument("--perturb_std", type=float, default=0)
     parser.add_argument("--scale", dest='guidance_scale', type=float, default=4,
                         help="Guidance scale for the diffusion model")
     parser.add_argument("--ref_img_strength", type=float, default=0.8,
                         help="Strength of the reference image in the output image.")
-    parser.add_argument("--subject_string",
-                        type=str, default="z",
-                        help="Subject placeholder string used in prompts to denote the concept.")
     parser.add_argument("--prompt", type=str, default="a person z")
     parser.add_argument("--num_images_per_row", type=int, default=4,
                         help="Number of images to display in a row in the output grid image.")
-    parser.add_argument("--num_inference_steps", type=int, default=50,
-                        help="Number of DDIM inference steps")
     parser.add_argument("--num_gpus", type=int, default=1, help="Number of GPUs to use. If num_gpus > 1, use accelerate for distributed execution.")
     parser.add_argument("--device", type=str, default="cuda", help="Device to run the model on")
     parser.add_argument("--seed", type=int, default=42,
@@ -90,15 +89,16 @@ if __name__ == "__main__":
     process_index = 0
 
     adaface = AdaFaceWrapper("img2img", args.base_model_path,
-                             args.adaface_encoder_types, args.…
-                             args.adaface_encoder_cfg_scales,
-                             args.subject_string, args.num_inference_steps,
+                             args.adaface_encoder_types, args.adaface_ckpt_path,
+                             args.adaface_encoder_cfg_scales, args.enabled_encoders,
                              unet_types=None,
-                             extra_unet_dirpaths=args.extra_unet_dirpaths,
+                             extra_unet_dirpaths=args.extra_unet_dirpaths,
+                             unet_weights_in_ensemble=args.unet_weights_in_ensemble,
                              device=args.device)
 
     in_folder = args.in_folder
     if os.path.isfile(in_folder):
+        args.in_folder = os.path.dirname(args.in_folder)
         subject_folders = [ os.path.dirname(in_folder) ]
         images_by_subject = [[in_folder]]
     else:
@@ -154,6 +154,24 @@ if __name__ == "__main__":
     images_by_subject = images_by_subject[process_index::args.num_gpus]
     #subject_folders, images_by_subject = distributed_state.split_between_processes(zip(subject_folders, images_by_subject))
 
+    if args.restore_image is not None:
+        in_images = []
+        for image_path in [args.restore_image]:
+            image = Image.open(image_path).convert("RGB").resize((512, 512))
+            # [512, 512, 3] -> [3, 512, 512].
+            image = np.array(image).transpose(2, 0, 1)
+            # Convert the image to a tensor of shape (1, 3, 512, 512) and move it to the GPU.
+            image = torch.tensor(image).unsqueeze(0).float().cuda()
+            in_images.append(image)
+
+        # Put all input images of the subject into a batch. This assumes max_images_per_subject is small.
+        # NOTE: For simplicity, we do not check overly large batch sizes.
+        in_images = torch.cat(in_images, dim=0)
+        # in_images: [5, 3, 512, 512].
+        # Normalize the pixel values to [0, 1].
+        in_images = in_images / 255.0
+        num_out_images = len(in_images) * args.out_count_per_input_image
+
     for (subject_folder, image_paths) in zip(subject_folders, images_by_subject):
         # If is_mix_subj_folder, then image_paths only contains 1 image, and we use the file name as the signature of the image.
         # Otherwise, we use the folder name as the signature of the images.
@@ -173,29 +191,32 @@ if __name__ == "__main__":
             os.makedirs(subject_out_folder)
         print(f"Output images will be saved to {subject_out_folder}")
 
-        …
+        if args.restore_image is None:
+            in_images = []
+            for image_path in image_paths:
+                image = Image.open(image_path).convert("RGB").resize((512, 512))
+                # [512, 512, 3] -> [3, 512, 512].
+                image = np.array(image).transpose(2, 0, 1)
+                # Convert the image to a tensor of shape (1, 3, 512, 512) and move it to the GPU.
+                image = torch.tensor(image).unsqueeze(0).float().cuda()
+                in_images.append(image)
+
+            # Put all input images of the subject into a batch. This assumes max_images_per_subject is small.
+            # NOTE: For simplicity, we do not check overly large batch sizes.
+            in_images = torch.cat(in_images, dim=0)
+            # in_images: [5, 3, 512, 512].
+            # Normalize the pixel values to [0, 1].
+            in_images = in_images / 255.0
+            num_out_images = len(in_images) * args.out_count_per_input_image
 
         with torch.no_grad():
            # args.perturb_std: the *relative* std of the noise added to the face embeddings.
            # A noise level of 0.08 could change gender, but 0.06 is usually safe.
            # The returned adaface_subj_embs are already incorporated in the text encoder, and not used explicitly.
            # NOTE: We assume out_count_per_input_image == 1, so that the output images are of the same number as the input images.
-            out_images = adaface(in_images, args.prompt, None,…
+            out_images = adaface(in_images, args.prompt, None, None,
+                                 'append', args.guidance_scale, num_out_images,
+                                 ref_img_strength=args.ref_img_strength)
 
             for img_i, img in enumerate(out_images):
                 # out_images: subj_1, subj_2, ..., subj_n, subj_1, subj_2, ..., subj_n, ...
@@ -203,9 +224,11 @@ if __name__ == "__main__":
                 copy_i = img_i // len(in_images)
                 image_filename_stem, image_fileext = os.path.splitext(os.path.basename(image_paths[subj_i]))
                 if copy_i == 0:
-                    …
+                    save_path = os.path.join(subject_out_folder, f"{image_filename_stem}{image_fileext}")
                 else:
-                    …
+                    save_path = os.path.join(subject_out_folder, f"{image_filename_stem}_{copy_i}{image_fileext}")
+                img.save(save_path)
+                print(f"Saved {save_path}")
 
                 if args.copy_masks:
                     mask_path = image_paths[subj_i].replace(image_fileext, "_mask.png")
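Both the new --restore_image branch and the per-subject branch build the img2img input batch the same way. A small helper (hypothetical, not in the commit) makes the intended shapes and normalization explicit:

    # Hypothetical helper equivalent to the duplicated preprocessing above.
    import numpy as np
    import torch
    from PIL import Image

    def load_images_as_batch(image_paths, size=512, device="cuda"):
        tensors = []
        for image_path in image_paths:
            image = Image.open(image_path).convert("RGB").resize((size, size))
            # HWC uint8 -> CHW float, plus a batch dim: [1, 3, 512, 512].
            chw = np.array(image).transpose(2, 0, 1)
            tensors.append(torch.tensor(chw).unsqueeze(0).float())
        # Stack into [N, 3, 512, 512] and normalize pixel values to [0, 1].
        return torch.cat(tensors, dim=0).to(device) / 255.0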
adaface/adaface_wrapper.py
CHANGED
@@ -8,22 +8,29 @@ from diffusers import (
|
|
8 |
StableDiffusion3Pipeline,
|
9 |
#FluxPipeline,
|
10 |
DDIMScheduler,
|
|
|
|
|
11 |
AutoencoderKL,
|
|
|
12 |
)
|
13 |
from diffusers.loaders.single_file_utils import convert_ldm_unet_checkpoint
|
14 |
from adaface.util import UNetEnsemble
|
15 |
from adaface.face_id_to_ada_prompt import create_id2ada_prompt_encoder
|
|
|
16 |
from safetensors.torch import load_file as safetensors_load_file
|
17 |
import re, os
|
18 |
import numpy as np
|
|
|
19 |
|
20 |
class AdaFaceWrapper(nn.Module):
|
21 |
def __init__(self, pipeline_name, base_model_path, adaface_encoder_types,
|
22 |
adaface_ckpt_paths, adaface_encoder_cfg_scales=None,
|
23 |
-
enabled_encoders=None,
|
24 |
-
subject_string='z',
|
25 |
use_840k_vae=False, use_ds_text_encoder=False,
|
26 |
-
main_unet_filepath=None, unet_types=None, extra_unet_dirpaths=None,
|
|
|
|
|
27 |
device='cuda', is_training=False):
|
28 |
'''
|
29 |
pipeline_name: "text2img", "text2imgxl", "img2img", "text2img3", "flux", or None.
|
@@ -38,15 +45,23 @@ class AdaFaceWrapper(nn.Module):
|
|
38 |
self.adaface_ckpt_paths = adaface_ckpt_paths
|
39 |
self.adaface_encoder_cfg_scales = adaface_encoder_cfg_scales
|
40 |
self.enabled_encoders = enabled_encoders
|
|
|
|
|
|
|
|
|
|
|
|
|
41 |
self.subject_string = subject_string
|
|
|
42 |
|
43 |
-
self.
|
|
|
44 |
self.use_840k_vae = use_840k_vae
|
45 |
self.use_ds_text_encoder = use_ds_text_encoder
|
46 |
self.main_unet_filepath = main_unet_filepath
|
47 |
self.unet_types = unet_types
|
48 |
self.extra_unet_dirpaths = extra_unet_dirpaths
|
49 |
-
self.
|
50 |
self.device = device
|
51 |
self.is_training = is_training
|
52 |
|
@@ -62,7 +77,14 @@ class AdaFaceWrapper(nn.Module):
|
|
62 |
self.initialize_pipeline()
|
63 |
# During inference, we never use static image suffix embeddings.
|
64 |
# So num_id_vecs is the length of the returned adaface embeddings for each encoder.
|
65 |
-
self.encoders_num_id_vecs = self.id2ada_prompt_encoder.encoders_num_id_vecs
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
66 |
self.extend_tokenizer_and_text_encoder()
|
67 |
|
68 |
def to(self, device):
|
@@ -76,7 +98,8 @@ class AdaFaceWrapper(nn.Module):
|
|
76 |
self.id2ada_prompt_encoder = create_id2ada_prompt_encoder(self.adaface_encoder_types,
|
77 |
self.adaface_ckpt_paths,
|
78 |
self.adaface_encoder_cfg_scales,
|
79 |
-
self.enabled_encoders
|
|
|
80 |
|
81 |
self.id2ada_prompt_encoder.to(self.device)
|
82 |
print(f"adaface_encoder_cfg_scales: {self.adaface_encoder_cfg_scales}")
|
@@ -118,10 +141,10 @@ class AdaFaceWrapper(nn.Module):
|
|
118 |
|
119 |
if self.base_model_path is None:
|
120 |
base_model_path_dict = {
|
121 |
-
'text2img':
|
122 |
-
'text2imgxl':
|
123 |
-
'text2img3':
|
124 |
-
'flux':
|
125 |
}
|
126 |
self.base_model_path = base_model_path_dict[self.pipeline_name]
|
127 |
|
@@ -137,6 +160,20 @@ class AdaFaceWrapper(nn.Module):
|
|
137 |
safety_checker=None
|
138 |
)
|
139 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
140 |
if self.main_unet_filepath is not None:
|
141 |
print(f"Replacing the UNet with the UNet from {self.main_unet_filepath}.")
|
142 |
ret = pipeline.unet.load_state_dict(self.load_unet_from_file(self.main_unet_filepath, device='cpu'))
|
@@ -147,12 +184,19 @@ class AdaFaceWrapper(nn.Module):
|
|
147 |
|
148 |
if (self.unet_types is not None and len(self.unet_types) > 0) \
|
149 |
or (self.extra_unet_dirpaths is not None and len(self.extra_unet_dirpaths) > 0):
|
150 |
-
unet_ensemble = UNetEnsemble([pipeline.unet], self.unet_types, self.extra_unet_dirpaths, self.
|
151 |
device=self.device, torch_dtype=torch.float16)
|
152 |
pipeline.unet = unet_ensemble
|
153 |
|
154 |
print(f"Loaded pipeline from {self.base_model_path}.")
|
155 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
156 |
if self.use_840k_vae:
|
157 |
pipeline.vae = vae
|
158 |
print("Replaced the VAE with the 840k-step VAE.")
|
@@ -167,19 +211,56 @@ class AdaFaceWrapper(nn.Module):
|
|
167 |
pipeline.vae = None
|
168 |
print("Removed UNet and VAE from the pipeline.")
|
169 |
|
170 |
-
if self.pipeline_name not in ["text2imgxl", "text2img3", "flux"]:
|
171 |
-
|
172 |
-
|
173 |
-
|
174 |
-
|
175 |
-
|
176 |
-
|
177 |
-
|
178 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
179 |
pipeline.scheduler = noise_scheduler
|
180 |
-
# Otherwise, pipeline.scheduler == FlowMatchEulerDiscreteScheduler
|
|
|
181 |
self.pipeline = pipeline.to(self.device)
|
182 |
|
|
|
|
|
|
|
|
|
183 |
def load_unet_from_file(self, unet_path, device=None):
|
184 |
if os.path.isfile(unet_path):
|
185 |
if unet_path.endswith(".safetensors"):
|
@@ -208,7 +289,109 @@ class AdaFaceWrapper(nn.Module):
|
|
208 |
else:
|
209 |
raise ValueError(f"UNet path {unet_path} is not a file.")
|
210 |
return unet_state_dict
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
211 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
212 |
def extend_tokenizer_and_text_encoder(self):
|
213 |
if np.sum(self.encoders_num_id_vecs) < 1:
|
214 |
raise ValueError(f"encoders_num_id_vecs has to be larger or equal to 1, but is {self.encoders_num_id_vecs}")
|
@@ -218,6 +401,7 @@ class AdaFaceWrapper(nn.Module):
|
|
218 |
# We add z_0_0, z_0_1, z_0_2, ..., z_0_15, z_1_0, z_1_1, z_1_2, z_1_3 to the tokenizer.
|
219 |
self.all_placeholder_tokens = []
|
220 |
self.placeholder_tokens_strs = []
|
|
|
221 |
for i in range(len(self.adaface_encoder_types)):
|
222 |
placeholder_tokens = []
|
223 |
for j in range(self.encoders_num_id_vecs[i]):
|
@@ -225,9 +409,11 @@ class AdaFaceWrapper(nn.Module):
|
|
225 |
placeholder_tokens_str = " ".join(placeholder_tokens)
|
226 |
|
227 |
self.all_placeholder_tokens.extend(placeholder_tokens)
|
|
|
228 |
self.placeholder_tokens_strs.append(placeholder_tokens_str)
|
229 |
|
230 |
self.all_placeholder_tokens_str = " ".join(self.placeholder_tokens_strs)
|
|
|
231 |
# all_null_placeholder_tokens_str: ", , , , ..." (20 times).
|
232 |
# It just contains the commas and spaces with the same length, but no actual tokens.
|
233 |
self.all_null_placeholder_tokens_str = " ".join([", "] * len(self.all_placeholder_tokens))
|
@@ -241,7 +427,7 @@ class AdaFaceWrapper(nn.Module):
|
|
241 |
|
242 |
print(f"Added {num_added_tokens} tokens ({self.all_placeholder_tokens_str}) to the tokenizer.")
|
243 |
|
244 |
-
# placeholder_token_ids: [49408, ...,
|
245 |
self.placeholder_token_ids = tokenizer.convert_tokens_to_ids(self.all_placeholder_tokens)
|
246 |
#print("New tokens:", self.placeholder_token_ids)
|
247 |
# Resize the token embeddings as we are adding new special tokens to the tokenizer
|
@@ -252,24 +438,49 @@ class AdaFaceWrapper(nn.Module):
|
|
252 |
|
253 |
# Extend pipeline.text_encoder with the adaface subject emeddings.
|
254 |
# subj_embs: [16, 768].
|
255 |
-
def update_text_encoder_subj_embeddings(self, subj_embs):
|
256 |
# Initialise the newly added placeholder token with the embeddings of the initializer token
|
257 |
# token_embeds: [49412, 768]
|
258 |
token_embeds = self.pipeline.text_encoder.get_input_embeddings().weight.data
|
|
|
|
|
|
|
|
|
259 |
with torch.no_grad():
|
260 |
-
|
261 |
-
|
262 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
263 |
|
264 |
def update_prompt(self, prompt, placeholder_tokens_pos='append',
|
|
|
265 |
use_null_placeholders=False):
|
266 |
if prompt is None:
|
267 |
prompt = ""
|
268 |
|
269 |
if use_null_placeholders:
|
270 |
all_placeholder_tokens_str = self.all_null_placeholder_tokens_str
|
|
|
|
|
|
|
271 |
else:
|
272 |
-
all_placeholder_tokens_str = self.
|
273 |
|
274 |
# Delete the subject_string from the prompt.
|
275 |
prompt = re.sub(r'\b(a|an|the)\s+' + self.subject_string + r'\b,?', "", prompt)
|
@@ -279,15 +490,29 @@ class AdaFaceWrapper(nn.Module):
|
|
279 |
# When we do joint training, seems both work better if they are appended to the prompt.
|
280 |
# Therefore we simply appended all placeholder_tokens_str's to the prompt.
|
281 |
# NOTE: Prepending them hurts compositional prompts.
|
282 |
-
if
|
283 |
-
|
284 |
-
|
285 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
286 |
else:
|
287 |
-
|
|
|
|
|
|
|
|
|
|
|
288 |
|
289 |
return prompt
|
290 |
|
|
|
|
|
291 |
# If face_id_embs is None, then it extracts face_id_embs from the images,
|
292 |
# then map them to ada prompt embeddings.
|
293 |
# avg_at_stage: 'id_emb', 'img_prompt_emb', or None.
|
@@ -298,27 +523,29 @@ class AdaFaceWrapper(nn.Module):
|
|
298 |
perturb_at_stage=None, # id_emb, img_prompt_emb, or None.
|
299 |
perturb_std=0, update_text_encoder=True):
|
300 |
|
301 |
-
all_adaface_subj_embs = \
|
302 |
self.id2ada_prompt_encoder.generate_adaface_embeddings(\
|
303 |
image_paths, face_id_embs=face_id_embs,
|
304 |
img_prompt_embs=None,
|
305 |
avg_at_stage=avg_at_stage,
|
306 |
perturb_at_stage=perturb_at_stage,
|
307 |
perturb_std=perturb_std,
|
308 |
-
enable_static_img_suffix_embs=
|
309 |
|
310 |
if all_adaface_subj_embs is None:
|
311 |
return None
|
312 |
|
|
|
|
|
313 |
if all_adaface_subj_embs.ndim == 4:
|
314 |
-
# [1, 1,
|
315 |
all_adaface_subj_embs = all_adaface_subj_embs.squeeze(0).squeeze(0)
|
316 |
elif all_adaface_subj_embs.ndim == 3:
|
317 |
-
# [1,
|
318 |
all_adaface_subj_embs = all_adaface_subj_embs.squeeze(0)
|
319 |
|
320 |
if update_text_encoder:
|
321 |
-
self.update_text_encoder_subj_embeddings(all_adaface_subj_embs)
|
322 |
return all_adaface_subj_embs
|
323 |
|
324 |
def diffusers_encode_prompts(self, prompt, plain_prompt, negative_prompt, device):
|
@@ -368,6 +595,7 @@ class AdaFaceWrapper(nn.Module):
|
|
368 |
else:
|
369 |
breakpoint()
|
370 |
else:
|
|
|
371 |
# prompt_embeds_, negative_prompt_embeds_: [1, 77, 768]
|
372 |
prompt_embeds_, negative_prompt_embeds_ = \
|
373 |
self.pipeline.encode_prompt(prompt, device=device,
|
@@ -378,9 +606,53 @@ class AdaFaceWrapper(nn.Module):
|
|
378 |
return prompt_embeds_, negative_prompt_embeds_, \
|
379 |
pooled_prompt_embeds_, negative_pooled_prompt_embeds_
|
380 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
381 |
def encode_prompt(self, prompt, negative_prompt=None,
|
382 |
placeholder_tokens_pos='append',
|
383 |
-
|
|
|
|
|
|
|
|
|
384 |
device=None, verbose=False):
|
385 |
if negative_prompt is None:
|
386 |
negative_prompt = self.negative_prompt
|
@@ -389,59 +661,81 @@ class AdaFaceWrapper(nn.Module):
|
|
389 |
device = self.device
|
390 |
|
391 |
plain_prompt = prompt
|
392 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
393 |
if verbose:
|
394 |
print(f"Subject prompt:\n{prompt}")
|
395 |
|
396 |
-
if do_neg_id_prompt_weight > 0:
|
397 |
-
# Use 'prepend' for the negative prompt, since it's long and we want to make sure
|
398 |
-
# the placeholder tokens are not cut off.
|
399 |
-
negative_prompt0 = negative_prompt
|
400 |
-
negative_prompt = self.update_prompt(negative_prompt0, placeholder_tokens_pos='prepend')
|
401 |
-
null_negative_prompt = self.update_prompt(negative_prompt0, placeholder_tokens_pos='prepend',
|
402 |
-
use_null_placeholders=True)
|
403 |
-
''' if verbose:
|
404 |
-
print(f"Negative prompt:\n{negative_prompt}")
|
405 |
-
print(f"Null negative prompt:\n{null_negative_prompt}")
|
406 |
-
|
407 |
-
'''
|
408 |
-
else:
|
409 |
-
null_negative_prompt = None
|
410 |
-
|
411 |
# For some unknown reason, the text_encoder is still on CPU after self.pipeline.to(self.device).
|
412 |
# So we manually move it to GPU here.
|
413 |
self.pipeline.text_encoder.to(device)
|
414 |
|
415 |
prompt_embeds_, negative_prompt_embeds_, pooled_prompt_embeds_, negative_pooled_prompt_embeds_ = \
|
416 |
self.diffusers_encode_prompts(prompt, plain_prompt, negative_prompt, device)
|
417 |
-
|
418 |
-
if
|
419 |
-
|
420 |
-
|
421 |
-
|
422 |
-
|
423 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
424 |
return prompt_embeds_, negative_prompt_embeds_, pooled_prompt_embeds_, negative_pooled_prompt_embeds_
|
425 |
|
426 |
# ref_img_strength is used only in the img2img pipeline.
|
427 |
-
def forward(self, noise, prompt, negative_prompt=None,
|
428 |
placeholder_tokens_pos='append',
|
429 |
-
do_neg_id_prompt_weight=0,
|
430 |
guidance_scale=6.0, out_image_count=4,
|
431 |
-
ref_img_strength=0.8, generator=None,
|
|
|
|
|
|
|
|
|
|
|
|
|
432 |
noise = noise.to(device=self.device, dtype=torch.float16)
|
|
|
|
|
433 |
|
434 |
if negative_prompt is None:
|
435 |
negative_prompt = self.negative_prompt
|
436 |
# prompt_embeds_, negative_prompt_embeds_: [1, 77, 768]
|
437 |
-
|
438 |
-
|
439 |
-
|
440 |
-
|
441 |
-
|
442 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
443 |
# Repeat the prompt embeddings for all images in the batch.
|
444 |
prompt_embeds_ = prompt_embeds_.repeat(out_image_count, 1, 1)
|
|
|
445 |
if negative_prompt_embeds_ is not None:
|
446 |
negative_prompt_embeds_ = negative_prompt_embeds_.repeat(out_image_count, 1, 1)
|
447 |
|
|
|
8 |
StableDiffusion3Pipeline,
|
9 |
#FluxPipeline,
|
10 |
DDIMScheduler,
|
11 |
+
PNDMScheduler,
|
12 |
+
DPMSolverSinglestepScheduler,
|
13 |
AutoencoderKL,
|
14 |
+
LCMScheduler,
|
15 |
)
|
16 |
from diffusers.loaders.single_file_utils import convert_ldm_unet_checkpoint
|
17 |
from adaface.util import UNetEnsemble
|
18 |
from adaface.face_id_to_ada_prompt import create_id2ada_prompt_encoder
|
19 |
+
from adaface.diffusers_attn_lora_capture import set_up_attn_processors, set_up_ffn_loras, set_lora_and_capture_flags
|
20 |
from safetensors.torch import load_file as safetensors_load_file
|
21 |
import re, os
|
22 |
import numpy as np
|
23 |
+
from peft.utils.constants import DUMMY_TARGET_MODULES
|
24 |
|
25 |
class AdaFaceWrapper(nn.Module):
|
26 |
def __init__(self, pipeline_name, base_model_path, adaface_encoder_types,
|
27 |
adaface_ckpt_paths, adaface_encoder_cfg_scales=None,
|
28 |
+
enabled_encoders=None, use_lcm=False, default_scheduler_name='ddim',
|
29 |
+
num_inference_steps=50, subject_string='z', negative_prompt=None,
|
30 |
use_840k_vae=False, use_ds_text_encoder=False,
|
31 |
+
main_unet_filepath=None, unet_types=None, extra_unet_dirpaths=None, unet_weights_in_ensemble=None,
|
32 |
+
enable_static_img_suffix_embs=None, unet_uses_attn_lora=False,
|
33 |
+
attn_lora_layer_names=['q', 'k', 'v', 'out'], shrink_cross_attn=False, q_lora_updates_query=False,
|
34 |
device='cuda', is_training=False):
|
35 |
'''
|
36 |
pipeline_name: "text2img", "text2imgxl", "img2img", "text2img3", "flux", or None.
|
|
|
45 |
self.adaface_ckpt_paths = adaface_ckpt_paths
|
46 |
self.adaface_encoder_cfg_scales = adaface_encoder_cfg_scales
|
47 |
self.enabled_encoders = enabled_encoders
|
48 |
+
# None, or a list of two bools for two encoders. If None, both are disabled.
|
49 |
+
self.enable_static_img_suffix_embs = enable_static_img_suffix_embs
|
50 |
+
self.unet_uses_attn_lora = unet_uses_attn_lora
|
51 |
+
self.attn_lora_layer_names = attn_lora_layer_names
|
52 |
+
self.q_lora_updates_query = q_lora_updates_query
|
53 |
+
self.use_lcm = use_lcm
|
54 |
self.subject_string = subject_string
|
55 |
+
self.shrink_cross_attn = shrink_cross_attn
|
56 |
|
57 |
+
self.default_scheduler_name = default_scheduler_name
|
58 |
+
self.num_inference_steps = num_inference_steps if not use_lcm else 4
|
59 |
self.use_840k_vae = use_840k_vae
|
60 |
self.use_ds_text_encoder = use_ds_text_encoder
|
61 |
self.main_unet_filepath = main_unet_filepath
|
62 |
self.unet_types = unet_types
|
63 |
self.extra_unet_dirpaths = extra_unet_dirpaths
|
64 |
+
self.unet_weights_in_ensemble = unet_weights_in_ensemble
|
65 |
self.device = device
|
66 |
self.is_training = is_training
|
67 |
|
|
|
77 |
self.initialize_pipeline()
|
78 |
# During inference, we never use static image suffix embeddings.
|
79 |
# So num_id_vecs is the length of the returned adaface embeddings for each encoder.
|
80 |
+
self.encoders_num_id_vecs = np.array(self.id2ada_prompt_encoder.encoders_num_id_vecs)
|
81 |
+
self.encoders_num_static_img_suffix_embs = np.array(self.id2ada_prompt_encoder.encoders_num_static_img_suffix_embs)
|
82 |
+
if self.enable_static_img_suffix_embs is not None:
|
83 |
+
assert len(self.enable_static_img_suffix_embs) == len(self.encoders_num_id_vecs)
|
84 |
+
self.encoders_num_static_img_suffix_embs *= np.array(self.enable_static_img_suffix_embs)
|
85 |
+
self.encoders_num_id_vecs += self.encoders_num_static_img_suffix_embs
|
86 |
+
|
87 |
+
self.img_prompt_embs = None
|
88 |
self.extend_tokenizer_and_text_encoder()
|
89 |
|
90 |
def to(self, device):
|
|
|
98 |
self.id2ada_prompt_encoder = create_id2ada_prompt_encoder(self.adaface_encoder_types,
|
99 |
self.adaface_ckpt_paths,
|
100 |
self.adaface_encoder_cfg_scales,
|
101 |
+
self.enabled_encoders,
|
102 |
+
num_static_img_suffix_embs=4)
|
103 |
|
104 |
self.id2ada_prompt_encoder.to(self.device)
|
105 |
print(f"adaface_encoder_cfg_scales: {self.adaface_encoder_cfg_scales}")
|
|
|
141 |
|
142 |
if self.base_model_path is None:
|
143 |
base_model_path_dict = {
|
144 |
+
'text2img': 'models/sd15-dste8-vae.safetensors',
|
145 |
+
'text2imgxl': 'stabilityai/stable-diffusion-xl-base-1.0',
|
146 |
+
'text2img3': 'stabilityai/stable-diffusion-3-medium-diffusers',
|
147 |
+
'flux': 'black-forest-labs/FLUX.1-schnell',
|
148 |
}
|
149 |
self.base_model_path = base_model_path_dict[self.pipeline_name]
|
150 |
|
|
|
160 |
safety_checker=None
|
161 |
)
|
162 |
|
163 |
+
if self.use_lcm:
|
164 |
+
lcm_path_dict = {
|
165 |
+
'text2img': 'latent-consistency/lcm-lora-sdv1-5',
|
166 |
+
'text2imgxl': 'latent-consistency/lcm-lora-sdxl',
|
167 |
+
}
|
168 |
+
if self.pipeline_name not in lcm_path_dict:
|
169 |
+
raise ValueError(f"Pipeline {self.pipeline_name} does not support LCM.")
|
170 |
+
|
171 |
+
lcm_path = lcm_path_dict[self.pipeline_name]
|
172 |
+
pipeline.load_lora_weights(lcm_path)
|
173 |
+
pipeline.fuse_lora()
|
174 |
+
print(f"Loaded LCM weights from {lcm_path}.")
|
175 |
+
pipeline.scheduler = LCMScheduler.from_config(pipeline.scheduler.config)
|
176 |
+
|
177 |
if self.main_unet_filepath is not None:
|
178 |
print(f"Replacing the UNet with the UNet from {self.main_unet_filepath}.")
|
179 |
ret = pipeline.unet.load_state_dict(self.load_unet_from_file(self.main_unet_filepath, device='cpu'))
|
|
|
184 |
|
185 |
if (self.unet_types is not None and len(self.unet_types) > 0) \
|
186 |
or (self.extra_unet_dirpaths is not None and len(self.extra_unet_dirpaths) > 0):
|
187 |
+
unet_ensemble = UNetEnsemble([pipeline.unet], self.unet_types, self.extra_unet_dirpaths, self.unet_weights_in_ensemble,
|
188 |
device=self.device, torch_dtype=torch.float16)
|
189 |
pipeline.unet = unet_ensemble
|
190 |
|
191 |
print(f"Loaded pipeline from {self.base_model_path}.")
|
192 |
+
if not remove_unet and (self.unet_uses_attn_lora or self.shrink_cross_attn):
|
193 |
+
unet2 = self.load_unet_lora_weights(pipeline.unet, use_attn_lora=self.unet_uses_attn_lora,
|
194 |
+
attn_lora_layer_names=self.attn_lora_layer_names,
|
195 |
+
shrink_cross_attn=self.shrink_cross_attn,
|
196 |
+
q_lora_updates_query=self.q_lora_updates_query)
|
197 |
+
|
198 |
+
pipeline.unet = unet2
|
199 |
+
|
200 |
if self.use_840k_vae:
|
201 |
pipeline.vae = vae
|
202 |
print("Replaced the VAE with the 840k-step VAE.")
|
|
|
211 |
pipeline.vae = None
|
212 |
print("Removed UNet and VAE from the pipeline.")
|
213 |
|
214 |
+
if self.pipeline_name not in ["text2imgxl", "text2img3", "flux"] and not self.use_lcm:
|
215 |
+
if self.default_scheduler_name == 'ddim':
|
216 |
+
noise_scheduler = DDIMScheduler(
|
217 |
+
num_train_timesteps=1000,
|
218 |
+
beta_start=0.00085,
|
219 |
+
beta_end=0.012,
|
220 |
+
beta_schedule="scaled_linear",
|
221 |
+
clip_sample=False,
|
222 |
+
set_alpha_to_one=False,
|
223 |
+
steps_offset=1,
|
224 |
+
timestep_spacing="leading",
|
225 |
+
rescale_betas_zero_snr=False,
|
226 |
+
)
|
227 |
+
elif self.default_scheduler_name == 'pndm':
|
228 |
+
noise_scheduler = PNDMScheduler(
|
229 |
+
num_train_timesteps=1000,
|
230 |
+
beta_start=0.00085,
|
231 |
+
beta_end=0.012,
|
232 |
+
beta_schedule="scaled_linear",
|
233 |
+
set_alpha_to_one=False,
|
234 |
+
steps_offset=1,
|
235 |
+
timestep_spacing="leading",
|
236 |
+
skip_prk_steps=True,
|
237 |
+
)
|
238 |
+
elif self.default_scheduler_name == 'dpm++':
|
239 |
+
noise_scheduler = DPMSolverSinglestepScheduler(
|
240 |
+
beta_start=0.00085,
|
241 |
+
beta_end=0.012,
|
242 |
+
beta_schedule="scaled_linear",
|
243 |
+
prediction_type="epsilon",
|
244 |
+
num_train_timesteps=1000,
|
245 |
+
trained_betas=None,
|
246 |
+
thresholding=False,
|
247 |
+
algorithm_type="dpmsolver++",
|
248 |
+
solver_type="midpoint",
|
249 |
+
lower_order_final=True,
|
250 |
+
use_karras_sigmas=True,
|
251 |
+
)
|
252 |
+
else:
|
253 |
+
breakpoint()
|
254 |
+
|
255 |
pipeline.scheduler = noise_scheduler
|
256 |
+
# Otherwise, if not use_lcm, pipeline.scheduler == FlowMatchEulerDiscreteScheduler
|
257 |
+
# if use_lcm, pipeline.scheduler == LCMScheduler
|
258 |
self.pipeline = pipeline.to(self.device)
|
259 |
|
260 |
+
def set_adaface_encoder_cfg_scales(self, adaface_encoder_cfg_scales):
|
261 |
+
self.adaface_encoder_cfg_scales = adaface_encoder_cfg_scales
|
262 |
+
self.id2ada_prompt_encoder.set_out_id_embs_cfg_scale(adaface_encoder_cfg_scales)
|
263 |
+
|
264 |
def load_unet_from_file(self, unet_path, device=None):
|
265 |
if os.path.isfile(unet_path):
|
266 |
if unet_path.endswith(".safetensors"):
|
|
|
289 |
else:
|
290 |
raise ValueError(f"UNet path {unet_path} is not a file.")
|
291 |
return unet_state_dict
|
292 |
+
|
293 |
+
# Adapted from ConsistentIDPipeline:set_ip_adapter().
|
294 |
+
def load_unet_loras(self, unet, unet_lora_modules_state_dict,
|
295 |
+
use_attn_lora=True, use_ffn_lora=False,
|
296 |
+
attn_lora_layer_names=['q', 'k', 'v', 'out'],
|
297 |
+
shrink_cross_attn=False, cross_attn_shrink_factor=0.5,
|
298 |
+
q_lora_updates_query=False):
|
299 |
+
attn_capture_procs, attn_opt_modules = \
|
300 |
+
set_up_attn_processors(unet, use_attn_lora=True, attn_lora_layer_names=attn_lora_layer_names,
|
301 |
+
lora_rank=192, lora_scale_down=8,
|
302 |
+
cross_attn_shrink_factor=cross_attn_shrink_factor,
|
303 |
+
q_lora_updates_query=q_lora_updates_query)
|
304 |
+
# up_blocks.3.resnets.[1~2].conv1, conv2, conv_shortcut. [12] matches 1 or 2.
|
305 |
+
if use_ffn_lora:
|
306 |
+
target_modules_pat = 'up_blocks.3.resnets.[12].conv[a-z0-9_]+'
|
307 |
+
else:
|
308 |
+
# A special pattern, "dummy-target-modules" tells PEFT to add loras on NONE of the layers.
|
309 |
+
# We couldn't simply skip PEFT initialization (converting unet to a PEFT model),
|
310 |
+
# otherwise the attn lora layers will cause nan quickly during a fp16 training.
|
311 |
+
target_modules_pat = DUMMY_TARGET_MODULES
|
312 |
+
|
313 |
+
unet, ffn_lora_layers, ffn_opt_modules = \
|
314 |
+
set_up_ffn_loras(unet, target_modules_pat=target_modules_pat, lora_uses_dora=True)
|
315 |
+
|
316 |
+
# self.attn_capture_procs and ffn_lora_layers will be used in set_lora_and_capture_flags().
|
317 |
+
self.attn_capture_procs = list(attn_capture_procs.values())
|
318 |
+
self.ffn_lora_layers = list(ffn_lora_layers.values())
|
319 |
+
# Combine attn_opt_modules and ffn_opt_modules into unet_lora_modules.
|
320 |
+
# unet_lora_modules is for optimization and loading/saving.
|
321 |
+
unet_lora_modules = {}
|
322 |
+
# attn_opt_modules and ffn_opt_modules have different depths of keys.
|
323 |
+
# attn_opt_modules:
|
324 |
+
# up_blocks_3_attentions_1_transformer_blocks_0_attn2_processor_std_shrink_factor,
|
325 |
+
# up_blocks_3_attentions_1_transformer_blocks_0_attn2_processor_to_q_lora_lora_A, ...
|
326 |
+
# ffn_opt_modules:
|
327 |
+
# base_model_model_up_blocks_3_resnets_1_conv1_lora_A, ...
|
328 |
+
# with the prefix 'base_model_model_'. Because ffn_opt_modules are extracted from the peft-wrapped model,
|
329 |
+
# and attn_opt_modules are extracted from the original unet model.
|
330 |
+
# To be compatible with old param keys, we append 'base_model_model_' to the keys of attn_opt_modules.
|
331 |
+
unet_lora_modules.update({ f'base_model_model_{k}': v for k, v in attn_opt_modules.items() })
|
332 |
+
unet_lora_modules.update(ffn_opt_modules)
|
333 |
+
# ParameterDict can contain both Parameter and nn.Module.
|
334 |
+
# TODO: maybe in the future, we couldn't put nn.Module in nn.ParameterDict.
|
335 |
+
self.unet_lora_modules = torch.nn.ParameterDict(unet_lora_modules)
|
336 |
+
|
337 |
+
missing, unexpected = self.unet_lora_modules.load_state_dict(unet_lora_modules_state_dict, strict=False)
|
338 |
+
if len(missing) > 0:
|
339 |
+
print(f"Missing Keys: {missing}")
|
340 |
+
if len(unexpected) > 0:
|
341 |
+
print(f"Unexpected Keys: {unexpected}")
|
342 |
+
|
343 |
+
print(f"Loaded {len(unet_lora_modules_state_dict)} LoRA weights on the UNet:\n{unet_lora_modules.keys()}")
|
344 |
+
self.outfeat_capture_blocks.append(unet.up_blocks[3])
|
345 |
+
|
346 |
+
# If shrink_cross_attn is True and use_attn_lora is False, we load all these params from ckpt,
|
347 |
+
# but since we set use_attn_lora to False, attn loras won't be used during inference nonetheless.
|
348 |
+
set_lora_and_capture_flags(unet, None, self.attn_capture_procs, self.outfeat_capture_blocks,
|
349 |
+
use_attn_lora, use_ffn_lora, 'recon_loss', capture_ca_activations=False,
|
350 |
+
shrink_cross_attn=shrink_cross_attn)
|
351 |
+
|
352 |
+
return unet
|
353 |
+
|
354 |
+
def load_unet_lora_weights(self, unet, use_attn_lora=True, attn_lora_layer_names=['q', 'k', 'v', 'out'],
|
355 |
+
shrink_cross_attn=False, q_lora_updates_query=False):
|
356 |
+
unet_lora_weight_found = False
|
357 |
+
if isinstance(self.adaface_ckpt_paths, str):
|
358 |
+
adaface_ckpt_paths = [self.adaface_ckpt_paths]
|
359 |
+
else:
|
360 |
+
adaface_ckpt_paths = self.adaface_ckpt_paths
|
361 |
+
|
362 |
+
for adaface_ckpt_path in adaface_ckpt_paths:
|
363 |
+
ckpt_dict = torch.load(adaface_ckpt_path, map_location='cpu')
|
364 |
+
if 'unet_lora_modules' in ckpt_dict:
|
365 |
+
unet_lora_modules_state_dict = ckpt_dict['unet_lora_modules']
|
366 |
+
print(f"{len(unet_lora_modules_state_dict)} LoRA weights found in {adaface_ckpt_path}.")
|
367 |
+
unet_lora_weight_found = True
|
368 |
+
break
|
369 |
+
|
370 |
+
# Since unet lora weights are not found in the adaface ckpt, we give up on loading unet attn processors.
|
371 |
+
if not unet_lora_weight_found:
|
372 |
+
print(f"LoRA weights not found in {self.adaface_ckpt_paths}.")
|
373 |
+
return unet
|
374 |
|
375 |
+
self.outfeat_capture_blocks = []
|
376 |
+
|
377 |
+
if isinstance(unet, UNetEnsemble):
|
378 |
+
for i, unet_ in enumerate(unet.unets):
|
379 |
+
unet_ = self.load_unet_loras(unet_, unet_lora_modules_state_dict,
|
380 +                                          use_attn_lora=use_attn_lora,
381 +                                          attn_lora_layer_names=attn_lora_layer_names,
382 +                                          shrink_cross_attn=shrink_cross_attn,
383 +                                          q_lora_updates_query=q_lora_updates_query)
384 +                 unet.unets[i] = unet_
385 +             print(f"Loaded LoRA processors on UNetEnsemble of {len(unet.unets)} UNets.")
386 +         else:
387 +             unet = self.load_unet_loras(unet, unet_lora_modules_state_dict,
388 +                                         use_attn_lora=use_attn_lora,
389 +                                         attn_lora_layer_names=attn_lora_layer_names,
390 +                                         shrink_cross_attn=shrink_cross_attn,
391 +                                         q_lora_updates_query=q_lora_updates_query)
392 +
393 +         return unet
394 +
395       def extend_tokenizer_and_text_encoder(self):
396           if np.sum(self.encoders_num_id_vecs) < 1:
397               raise ValueError(f"encoders_num_id_vecs has to be larger or equal to 1, but is {self.encoders_num_id_vecs}")
401           # We add z_0_0, z_0_1, z_0_2, ..., z_0_15, z_1_0, z_1_1, z_1_2, z_1_3 to the tokenizer.
402           self.all_placeholder_tokens = []
403           self.placeholder_tokens_strs = []
404 +         self.encoder_placeholder_tokens = []
405           for i in range(len(self.adaface_encoder_types)):
406               placeholder_tokens = []
407               for j in range(self.encoders_num_id_vecs[i]):
409               placeholder_tokens_str = " ".join(placeholder_tokens)
410
411               self.all_placeholder_tokens.extend(placeholder_tokens)
412 +             self.encoder_placeholder_tokens.append(placeholder_tokens)
413               self.placeholder_tokens_strs.append(placeholder_tokens_str)
414
415           self.all_placeholder_tokens_str = " ".join(self.placeholder_tokens_strs)
416 +         self.updated_tokens_str = self.all_placeholder_tokens_str
417           # all_null_placeholder_tokens_str: ", , , , ..." (20 times).
418           # It just contains commas and spaces of the same length, but no actual tokens.
419           self.all_null_placeholder_tokens_str = " ".join([", "] * len(self.all_placeholder_tokens))
427
428           print(f"Added {num_added_tokens} tokens ({self.all_placeholder_tokens_str}) to the tokenizer.")
429
430 +         # placeholder_token_ids: [49408, ..., 49427].
431           self.placeholder_token_ids = tokenizer.convert_tokens_to_ids(self.all_placeholder_tokens)
432           #print("New tokens:", self.placeholder_token_ids)
433           # Resize the token embeddings as we are adding new special tokens to the tokenizer
438
439       # Extend pipeline.text_encoder with the adaface subject embeddings.
440       # subj_embs: [16, 768].
441 +     def update_text_encoder_subj_embeddings(self, subj_embs, lens_subj_emb_segments):
442           # Initialise the newly added placeholder tokens with the embeddings of the initializer token.
443           # token_embeds: [49412, 768]
444           token_embeds = self.pipeline.text_encoder.get_input_embeddings().weight.data
445 +         all_encoders_updated_tokens = []
446 +         all_encoders_updated_token_strs = []
447 +         idx = 0
448 +
449           with torch.no_grad():
450 +             # The sum of lens_subj_emb_segments may be smaller than len(self.placeholder_token_ids),
451 +             # when some static_img_suffix_embs are disabled.
452 +             for i, encoder_type in enumerate(self.adaface_encoder_types):
453 +                 encoder_updated_tokens = []
454 +                 if (self.enabled_encoders is not None) and (encoder_type not in self.enabled_encoders):
455 +                     idx += lens_subj_emb_segments[i]
456 +                     continue
457 +                 for j in range(lens_subj_emb_segments[i]):
458 +                     placeholder_token = f"{self.subject_string}_{i}_{j}"
459 +                     token_id = self.pipeline.tokenizer.convert_tokens_to_ids(placeholder_token)
460 +                     token_embeds[token_id] = subj_embs[idx]
461 +                     encoder_updated_tokens.append(placeholder_token)
462 +                     idx += 1
463 +
464 +                 all_encoders_updated_tokens.extend(encoder_updated_tokens)
465 +                 all_encoders_updated_token_strs.append(" ".join(encoder_updated_tokens))
466 +
467 +         self.updated_tokens_str = " ".join(all_encoders_updated_token_strs)
468 +         self.all_encoders_updated_token_strs = all_encoders_updated_token_strs
469 +         print(f"Updated {len(all_encoders_updated_tokens)} tokens ({self.updated_tokens_str}) in the text encoder.")
470
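For reference, a minimal sketch (not part of the commit) of the in-place update that `update_text_encoder_subj_embeddings()` performs on the text encoder's embedding table; `pipeline` and `subj_embs` stand in for the wrapper's attributes, and `"z_0_3"` is an illustrative placeholder token:

```python
import torch

# Sketch only: overwrite one placeholder token's row in the input embedding table.
token_embeds = pipeline.text_encoder.get_input_embeddings().weight.data
token_id = pipeline.tokenizer.convert_tokens_to_ids("z_0_3")   # illustrative token
with torch.no_grad():
    token_embeds[token_id] = subj_embs[3]   # replace the [768] row in place
```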
471       def update_prompt(self, prompt, placeholder_tokens_pos='append',
472 +                       repeat_prompt_for_each_encoder=True,
473                         use_null_placeholders=False):
474           if prompt is None:
475               prompt = ""
476
477           if use_null_placeholders:
478               all_placeholder_tokens_str = self.all_null_placeholder_tokens_str
479 +             if not re.search(r"\b(man|woman|person|child|girl|boy)\b", prompt.lower()):
480 +                 all_placeholder_tokens_str = "person " + all_placeholder_tokens_str
481 +             repeat_prompt_for_each_encoder = False
482           else:
483 +             all_placeholder_tokens_str = self.updated_tokens_str
484
485           # Delete the subject_string from the prompt.
486           prompt = re.sub(r'\b(a|an|the)\s+' + self.subject_string + r'\b,?', "", prompt)
490           # When we do joint training, both encoders seem to work better if their tokens are appended to the prompt.
491           # Therefore we simply append all placeholder_tokens_strs to the prompt.
492           # NOTE: Prepending them hurts compositional prompts.
493 +         if repeat_prompt_for_each_encoder:
494 +             encoder_prompts = []
495 +             for encoder_updated_token_strs in self.all_encoders_updated_token_strs:
496 +                 if placeholder_tokens_pos == 'prepend':
497 +                     encoder_prompt = encoder_updated_token_strs + " " + prompt
498 +                 elif placeholder_tokens_pos == 'append':
499 +                     encoder_prompt = prompt + " " + encoder_updated_token_strs
500 +                 else:
501 +                     breakpoint()
502 +                 encoder_prompts.append(encoder_prompt)
503 +             prompt = ", ".join(encoder_prompts)
504           else:
505 +             if placeholder_tokens_pos == 'prepend':
506 +                 prompt = all_placeholder_tokens_str + " " + prompt
507 +             elif placeholder_tokens_pos == 'append':
508 +                 prompt = prompt + " " + all_placeholder_tokens_str
509 +             else:
510 +                 breakpoint()
511
512           return prompt
513
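For reference, a minimal sketch of the string that `update_prompt()` produces when the prompt is repeated once per encoder; the token names and counts below are illustrative:

```python
prompt = "portrait of a man at the beach"
# One token string per encoder, e.g. 16 arc2face tokens and 4 consistentID tokens.
all_encoders_updated_token_strs = [
    " ".join(f"z_0_{j}" for j in range(16)),
    " ".join(f"z_1_{j}" for j in range(4)),
]
# With placeholder_tokens_pos == 'append' and repeat_prompt_for_each_encoder == True:
print(", ".join(prompt + " " + s for s in all_encoders_updated_token_strs))
# portrait of a man at the beach z_0_0 ... z_0_15, portrait of a man at the beach z_1_0 ... z_1_3
```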
514 +     # NOTE: all_adaface_subj_embs is the input to the CLIP text encoder.
515 +     # ** DO NOT use it as prompt_embeds in the forward() method.
516       # If face_id_embs is None, then it extracts face_id_embs from the images,
517       # and then maps them to ada prompt embeddings.
518       # avg_at_stage: 'id_emb', 'img_prompt_emb', or None.
523                                    perturb_at_stage=None, # id_emb, img_prompt_emb, or None.
524                                    perturb_std=0, update_text_encoder=True):
525
526 +         all_adaface_subj_embs, img_prompt_embs, lens_subj_emb_segments = \
527               self.id2ada_prompt_encoder.generate_adaface_embeddings(\
528                   image_paths, face_id_embs=face_id_embs,
529                   img_prompt_embs=None,
530                   avg_at_stage=avg_at_stage,
531                   perturb_at_stage=perturb_at_stage,
532                   perturb_std=perturb_std,
533 +                 enable_static_img_suffix_embs=self.enable_static_img_suffix_embs)
534
535           if all_adaface_subj_embs is None:
536               return None
537
538 +         self.img_prompt_embs = img_prompt_embs
539 +
540           if all_adaface_subj_embs.ndim == 4:
541 +             # [1, 1, 20, 768] -> [20, 768]
542               all_adaface_subj_embs = all_adaface_subj_embs.squeeze(0).squeeze(0)
543           elif all_adaface_subj_embs.ndim == 3:
544 +             # [1, 20, 768] -> [20, 768]
545               all_adaface_subj_embs = all_adaface_subj_embs.squeeze(0)
546
547           if update_text_encoder:
548 +             self.update_text_encoder_subj_embeddings(all_adaface_subj_embs, lens_subj_emb_segments)
549           return all_adaface_subj_embs
550
551       def diffusers_encode_prompts(self, prompt, plain_prompt, negative_prompt, device):
595               else:
596                   breakpoint()
597           else:
598 +             # "text2img" and "img2img" pipelines.
599               # prompt_embeds_, negative_prompt_embeds_: [1, 77, 768]
600               prompt_embeds_, negative_prompt_embeds_ = \
601                   self.pipeline.encode_prompt(prompt, device=device,
606           return prompt_embeds_, negative_prompt_embeds_, \
607                  pooled_prompt_embeds_, negative_pooled_prompt_embeds_
608
609 +     # alt_prompt_embed_type: 'ada-nonmix', 'img'
610 +     def mix_ada_embs_with_other_embs(self, prompt, prompt_embeds,
611 +                                      alt_prompt_embed_type, alt_prompt_emb_weights):
612 +         # Scan the prompt and replace tokens in self.placeholder_token_ids
613 +         # with the corresponding image embeddings.
614 +         prompt_tokens = self.pipeline.tokenizer.tokenize(prompt)
615 +         prompt_embeds2 = prompt_embeds.clone()
616 +         if alt_prompt_embed_type == 'img':
617 +             if self.img_prompt_embs is None:
618 +                 print("Unable to find img_prompt_embs. Either prepare_adaface_embeddings() hasn't been called, or faceless images were used.")
619 +                 return prompt_embeds
620 +             # self.img_prompt_embs: [1, 20, 768]
621 +             repl_embeddings = self.img_prompt_embs
622 +         elif alt_prompt_embed_type == 'ada-nonmix':
623 +             repl_embeddings_, _, _, _ = self.encode_prompt(prompt, ablate_prompt_only_placeholders=True,
624 +                                                            verbose=True)
625 +             # repl_embeddings_: [1, 77, 768] -> [1, 20, 768]
626 +             repl_embeddings = repl_embeddings_[:, 1:len(self.all_placeholder_tokens)+1]
627 +         else:
628 +             breakpoint()
629 +
630 +         repl_tokens = {}
631 +         for i in range(len(prompt_tokens)):
632 +             if prompt_tokens[i] in self.all_placeholder_tokens:
633 +                 encoder_idx = next((enc_i for enc_i, sublist in enumerate(self.encoder_placeholder_tokens) \
634 +                                     if prompt_tokens[i] in sublist), 0)   # enc_i avoids shadowing the loop index i
635 +                 alt_prompt_emb_weight = alt_prompt_emb_weights[encoder_idx]
636 +                 prompt_embeds2[:, i] = prompt_embeds2[:, i] * (1 - alt_prompt_emb_weight) \
637 +                                        + repl_embeddings[:, self.all_placeholder_tokens.index(prompt_tokens[i])] * alt_prompt_emb_weight
638 +                 repl_tokens[prompt_tokens[i]] = 1
639 +
640 +         repl_token_count = len(repl_tokens)
641 +         if np.all(np.array(alt_prompt_emb_weights) == 1):
642 +             print(f"Replaced {repl_token_count} tokens with {alt_prompt_embed_type} embeddings.")
643 +         else:
644 +             print(f"Mixed {repl_token_count} tokens with {alt_prompt_embed_type} embeddings, weight {alt_prompt_emb_weights}.")
645 +
646 +         return prompt_embeds2
647 +
648 +
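The per-token rule in `mix_ada_embs_with_other_embs()` is a plain convex combination; a standalone sketch with illustrative shapes:

```python
import torch

w = 0.3                                   # alt_prompt_emb_weight for this encoder
e_ada  = torch.randn(768)                 # prompt_embeds2[:, i] for a placeholder token
e_repl = torch.randn(768)                 # matching row of repl_embeddings
e_mixed = e_ada * (1 - w) + e_repl * w    # the blend computed in the loop above
```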
649       def encode_prompt(self, prompt, negative_prompt=None,
650                         placeholder_tokens_pos='append',
651 +                       ablate_prompt_only_placeholders=False,
652 +                       ablate_prompt_no_placeholders=False,
653 +                       ablate_prompt_embed_type='ada', # 'ada', 'ada-nonmix', 'img'
654 +                       nonmix_prompt_emb_weight=0,
655 +                       repeat_prompt_for_each_encoder=True,
656                         device=None, verbose=False):
657           if negative_prompt is None:
658               negative_prompt = self.negative_prompt
661               device = self.device
662
663           plain_prompt = prompt
664 +         if ablate_prompt_only_placeholders:
665 +             prompt = self.updated_tokens_str
666 +         else:
667 +             prompt = self.update_prompt(prompt, placeholder_tokens_pos=placeholder_tokens_pos,
668 +                                         repeat_prompt_for_each_encoder=repeat_prompt_for_each_encoder,
669 +                                         use_null_placeholders=ablate_prompt_no_placeholders)
670 +
671           if verbose:
672               print(f"Subject prompt:\n{prompt}")
673
674           # For some unknown reason, the text_encoder is still on CPU after self.pipeline.to(self.device).
675           # So we manually move it to GPU here.
676           self.pipeline.text_encoder.to(device)
677
678           prompt_embeds_, negative_prompt_embeds_, pooled_prompt_embeds_, negative_pooled_prompt_embeds_ = \
679               self.diffusers_encode_prompts(prompt, plain_prompt, negative_prompt, device)
680 +
681 +         if ablate_prompt_embed_type != 'ada':
682 +             alt_prompt_embed_type = ablate_prompt_embed_type
683 +             alt_prompt_emb_weights = (1, 1)
684 +         elif nonmix_prompt_emb_weight > 0:
685 +             alt_prompt_embed_type = 'ada-nonmix'
686 +             alt_prompt_emb_weights = (nonmix_prompt_emb_weight, nonmix_prompt_emb_weight)
687 +         else:
688 +             alt_prompt_emb_weights = (0, 0)
689 +
690 +         if sum(alt_prompt_emb_weights) > 0:
691 +             prompt_embeds_ = self.mix_ada_embs_with_other_embs(prompt, prompt_embeds_,
692 +                                                                alt_prompt_embed_type, alt_prompt_emb_weights)
693 +
694           return prompt_embeds_, negative_prompt_embeds_, pooled_prompt_embeds_, negative_pooled_prompt_embeds_
695
696       # ref_img_strength is used only in the img2img pipeline.
697 +     def forward(self, noise, prompt, prompt_embeds=None, negative_prompt=None,
698                   placeholder_tokens_pos='append',
699                   guidance_scale=6.0, out_image_count=4,
700 +                 ref_img_strength=0.8, generator=None,
701 +                 ablate_prompt_only_placeholders=False,
702 +                 ablate_prompt_no_placeholders=False,
703 +                 ablate_prompt_embed_type='ada', # 'ada', 'ada-nonmix', 'img'
704 +                 nonmix_prompt_emb_weight=0,
705 +                 repeat_prompt_for_each_encoder=True,
706 +                 verbose=False):
707           noise = noise.to(device=self.device, dtype=torch.float16)
708 +         if self.use_lcm:
709 +             guidance_scale = 0
710
711           if negative_prompt is None:
712               negative_prompt = self.negative_prompt
713           # prompt_embeds_, negative_prompt_embeds_: [1, 77, 768]
714 +         if prompt_embeds is None:
715 +             prompt_embeds_, negative_prompt_embeds_, pooled_prompt_embeds_, \
716 +                 negative_pooled_prompt_embeds_ = \
717 +                 self.encode_prompt(prompt, negative_prompt,
718 +                                    placeholder_tokens_pos=placeholder_tokens_pos,
719 +                                    ablate_prompt_only_placeholders=ablate_prompt_only_placeholders,
720 +                                    ablate_prompt_no_placeholders=ablate_prompt_no_placeholders,
721 +                                    ablate_prompt_embed_type=ablate_prompt_embed_type,
722 +                                    nonmix_prompt_emb_weight=nonmix_prompt_emb_weight,
723 +                                    repeat_prompt_for_each_encoder=repeat_prompt_for_each_encoder,
724 +                                    device=self.device,
725 +                                    verbose=verbose)
726 +         else:
727 +             if len(prompt_embeds) == 2:
728 +                 prompt_embeds_, negative_prompt_embeds_ = prompt_embeds
729 +                 pooled_prompt_embeds_, negative_pooled_prompt_embeds_ = None, None
730 +             elif len(prompt_embeds) == 4:
731 +                 prompt_embeds_, negative_prompt_embeds_, pooled_prompt_embeds_, \
732 +                     negative_pooled_prompt_embeds_ = prompt_embeds
733 +             else:
734 +                 breakpoint()
735 +
736           # Repeat the prompt embeddings for all images in the batch.
737           prompt_embeds_ = prompt_embeds_.repeat(out_image_count, 1, 1)
738 +
739           if negative_prompt_embeds_ is not None:
740               negative_prompt_embeds_ = negative_prompt_embeds_.repeat(out_image_count, 1, 1)
741
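A usage sketch of the two `prompt_embeds` forms that `forward()` accepts; `wrapper` is assumed to be an initialized AdaFaceWrapper whose subject embeddings have already been prepared:

```python
import torch

noise = torch.randn(4, 4, 64, 64)   # illustrative latent noise for 4 output images
embs4 = wrapper.encode_prompt("portrait of a man at the beach")   # 4-tuple form
images = wrapper(noise, prompt=None, prompt_embeds=embs4, out_image_count=4)
```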
adaface/diffusers_attn_lora_capture.py ADDED @@ -0,0 +1,656 @@
import torch
import torch.nn as nn
import torch.nn.functional as F
from typing import Optional, Tuple, Dict, Any
from diffusers.models.attention_processor import Attention, AttnProcessor2_0
from diffusers.utils import logging, is_torch_version, deprecate
from diffusers.utils.torch_utils import fourier_filter
# UNet is a diffusers PeftAdapterMixin instance.
from diffusers.loaders.peft import PeftAdapterMixin
from peft import LoraConfig, get_peft_model
import peft.tuners.lora as peft_lora
from peft.tuners.lora.dora import DoraLinearLayer
from einops import rearrange
import math, re
import numpy as np
from peft.tuners.tuners_utils import BaseTunerLayer


logger = logging.get_logger(__name__)  # pylint: disable=invalid-name

def dummy_func(*args, **kwargs):
    pass

# Revised from RevGrad, by removing the grad negation.
class ScaleGrad(torch.autograd.Function):
    @staticmethod
    def forward(ctx, input_, alpha_, debug=False):
        ctx.save_for_backward(alpha_, debug)
        output = input_
        if debug:
            print(f"input: {input_.abs().mean().item()}")
        return output

    @staticmethod
    def backward(ctx, grad_output):  # pragma: no cover
        # saved_tensors returns a tuple of tensors.
        alpha_, debug = ctx.saved_tensors
        if ctx.needs_input_grad[0]:
            grad_output2 = grad_output * alpha_
            if debug:
                print(f"grad_output2: {grad_output2.abs().mean().item()}")
        else:
            grad_output2 = None
        return grad_output2, None, None

class GradientScaler(nn.Module):
    def __init__(self, alpha=1., debug=False, *args, **kwargs):
        """
        A gradient scaling layer.
        This layer has no parameters, and simply scales the gradient in the backward pass.
        """
        super().__init__(*args, **kwargs)

        self._alpha = torch.tensor(alpha, requires_grad=False)
        self._debug = torch.tensor(debug, requires_grad=False)

    def forward(self, input_):
        _debug = self._debug if hasattr(self, '_debug') else False
        return ScaleGrad.apply(input_, self._alpha.to(input_.device), _debug)

def gen_gradient_scaler(alpha, debug=False):
    if alpha == 1:
        return nn.Identity()
    if alpha > 0:
        return GradientScaler(alpha, debug=debug)
    else:
        assert alpha == 0
        # Don't use a lambda function here, otherwise the object can't be pickled.
        return torch.detach

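A quick sketch of what `gen_gradient_scaler()` returns: identity in the forward pass, gradient multiplied by `alpha` in the backward pass:

```python
import torch

scaler = gen_gradient_scaler(0.2)
x = torch.randn(3, requires_grad=True)
scaler(x).sum().backward()
print(x.grad)   # tensor([0.2000, 0.2000, 0.2000]) instead of all ones
```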
def split_indices_by_instance(indices, as_dict=False):
    indices_B, indices_N = indices
    unique_indices_B = torch.unique(indices_B)
    if not as_dict:
        indices_by_instance = [ (indices_B[indices_B == uib], indices_N[indices_B == uib]) for uib in unique_indices_B ]
    else:
        indices_by_instance = { uib.item(): indices_N[indices_B == uib] for uib in unique_indices_B }
    return indices_by_instance

|
80 |
+
# If do_sum, returned emb_attns is 3D. Otherwise 4D.
|
81 |
+
# indices are applied on the first 2 dims of attn_mat.
|
82 |
+
def sel_emb_attns_by_indices(attn_mat, indices, all_token_weights=None, do_sum=True, do_mean=False):
|
83 |
+
indices_by_instance = split_indices_by_instance(indices)
|
84 |
+
|
85 |
+
# emb_attns[0]: [1, 9, 8, 64]
|
86 |
+
# 8: 8 attention heads. Last dim 64: number of image tokens.
|
87 |
+
emb_attns = [ attn_mat[inst_indices].unsqueeze(0) for inst_indices in indices_by_instance ]
|
88 |
+
if all_token_weights is not None:
|
89 |
+
# all_token_weights: [4, 77].
|
90 |
+
# token_weights_by_instance[0]: [1, 9, 1, 1].
|
91 |
+
token_weights = [ all_token_weights[inst_indices].reshape(1, -1, 1, 1) for inst_indices in indices_by_instance ]
|
92 |
+
else:
|
93 |
+
token_weights = [ 1 ] * len(indices_by_instance)
|
94 |
+
|
95 |
+
# Apply token weights.
|
96 |
+
emb_attns = [ emb_attns[i] * token_weights[i] for i in range(len(indices_by_instance)) ]
|
97 |
+
|
98 |
+
# sum among K_subj_i subj embeddings -> [1, 8, 64]
|
99 |
+
if do_sum:
|
100 |
+
emb_attns = [ emb_attns[i].sum(dim=1) for i in range(len(indices_by_instance)) ]
|
101 |
+
elif do_mean:
|
102 |
+
emb_attns = [ emb_attns[i].mean(dim=1) for i in range(len(indices_by_instance)) ]
|
103 |
+
|
104 |
+
emb_attns = torch.cat(emb_attns, dim=0)
|
105 |
+
return emb_attns
|
106 |
+
|
107 |
+
# Slow implementation equivalent to F.scaled_dot_product_attention.
|
108 |
+
def scaled_dot_product_attention(query, key, value, attn_mask=None, dropout_p=0.0,
|
109 |
+
shrink_cross_attn=False, cross_attn_shrink_factor=0.5,
|
110 |
+
is_causal=False, scale=None, enable_gqa=False) -> torch.Tensor:
|
111 |
+
B, L, S = query.size(0), query.size(-2), key.size(-2)
|
112 |
+
scale_factor = 1 / math.sqrt(query.size(-1)) if scale is None else scale
|
113 |
+
# 1: head (to be broadcasted). L: query length. S: key length.
|
114 |
+
attn_bias = torch.zeros(B, 1, L, S, device=query.device, dtype=query.dtype)
|
115 |
+
if is_causal:
|
116 |
+
assert attn_mask is None
|
117 |
+
temp_mask = torch.ones(B, 1, L, S, device=query.device, dtype=torch.bool).tril(diagonal=0)
|
118 |
+
attn_bias.masked_fill_(temp_mask.logical_not(), float("-inf"))
|
119 |
+
attn_bias.to(query.dtype)
|
120 |
+
|
121 |
+
if attn_mask is not None:
|
122 |
+
if attn_mask.dtype == torch.bool:
|
123 |
+
attn_bias.masked_fill_(attn_mask.logical_not(), float("-inf"))
|
124 |
+
else:
|
125 |
+
attn_bias += attn_mask
|
126 |
+
|
127 |
+
if enable_gqa:
|
128 |
+
key = key.repeat_interleave(query.size(-3)//key.size(-3), -3)
|
129 |
+
value = value.repeat_interleave(query.size(-3)//value.size(-3), -3)
|
130 |
+
|
131 |
+
attn_weight = query @ key.transpose(-2, -1) * scale_factor
|
132 |
+
|
133 |
+
if shrink_cross_attn:
|
134 |
+
cross_attn_scale = cross_attn_shrink_factor
|
135 |
+
else:
|
136 |
+
cross_attn_scale = 1
|
137 |
+
|
138 |
+
# attn_bias: [1, 1, 4096, 77], the same size as a single-head attn_weight.
|
139 |
+
attn_weight += attn_bias
|
140 |
+
attn_score = attn_weight
|
141 |
+
attn_weight = torch.softmax(attn_weight, dim=-1)
|
142 |
+
# NOTE: After scaling, the "probabilities" of the subject embeddings will sum to < 1.
|
143 |
+
# But this is intended, as we want to scale down the impact of the subject embeddings
|
144 |
+
# in the computed attention output tensors.
|
145 |
+
attn_weight = attn_weight * cross_attn_scale
|
146 |
+
attn_weight = torch.dropout(attn_weight, dropout_p, train=True)
|
147 |
+
output = attn_weight @ value
|
148 |
+
return output, attn_score, attn_weight
|
149 |
+
|
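With `shrink_cross_attn=False` and no mask, the slow path above should match PyTorch's fused kernel up to numerical tolerance; a quick self-check sketch, assuming the function above is in scope:

```python
import torch
import torch.nn.functional as F

q = torch.randn(2, 8, 16, 40)
k = torch.randn(2, 8, 77, 40)
v = torch.randn(2, 8, 77, 40)
out_slow, _, _ = scaled_dot_product_attention(q, k, v, dropout_p=0.0)
out_fast = F.scaled_dot_product_attention(q, k, v, dropout_p=0.0)
print(torch.allclose(out_slow, out_fast, atol=1e-5))   # True
```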
# All layers share the same attention processor instance.
class AttnProcessor_LoRA_Capture(nn.Module):
    r"""
    Revised from AttnProcessor2_0
    """
    # lora_proj_layers is a dict of lora_layer_name -> lora_proj_layer.
    def __init__(self, capture_ca_activations: bool = False, enable_lora: bool = False,
                 lora_uses_dora=True, lora_proj_layers=None,
                 lora_rank: int = 192, lora_alpha: float = 16,
                 cross_attn_shrink_factor: float = 0.5,
                 q_lora_updates_query=False, attn_proc_idx=-1):
        super().__init__()

        self.global_enable_lora = enable_lora
        self.attn_proc_idx = attn_proc_idx
        # reset_attn_cache_and_flags() sets the local (call-specific) self.enable_lora flag.
        # By default, shrink_cross_attn is False. Later in layers 22, 23, 24 it will be set to True.
        self.reset_attn_cache_and_flags(capture_ca_activations, False, enable_lora)
        self.lora_rank = lora_rank
        self.lora_alpha = lora_alpha
        self.lora_scale = self.lora_alpha / self.lora_rank
        self.cross_attn_shrink_factor = cross_attn_shrink_factor
        self.q_lora_updates_query = q_lora_updates_query

        self.to_q_lora = self.to_k_lora = self.to_v_lora = self.to_out_lora = None
        if self.global_enable_lora:
            for lora_layer_name, lora_proj_layer in lora_proj_layers.items():
                if lora_layer_name == 'q':
                    self.to_q_lora = peft_lora.Linear(lora_proj_layer, 'default', r=lora_rank, lora_alpha=lora_alpha,
                                                      use_dora=lora_uses_dora, lora_dropout=0.1)
                elif lora_layer_name == 'k':
                    self.to_k_lora = peft_lora.Linear(lora_proj_layer, 'default', r=lora_rank, lora_alpha=lora_alpha,
                                                      use_dora=lora_uses_dora, lora_dropout=0.1)
                elif lora_layer_name == 'v':
                    self.to_v_lora = peft_lora.Linear(lora_proj_layer, 'default', r=lora_rank, lora_alpha=lora_alpha,
                                                      use_dora=lora_uses_dora, lora_dropout=0.1)
                elif lora_layer_name == 'out':
                    self.to_out_lora = peft_lora.Linear(lora_proj_layer, 'default', r=lora_rank, lora_alpha=lora_alpha,
                                                        use_dora=lora_uses_dora, lora_dropout=0.1)

    # LoRA layers can be enabled/disabled dynamically.
    def reset_attn_cache_and_flags(self, capture_ca_activations, shrink_cross_attn, enable_lora):
        self.capture_ca_activations = capture_ca_activations
        self.shrink_cross_attn = shrink_cross_attn
        self.cached_activations = {}
        # Only enable LoRA for the next call(s) if global_enable_lora is set to True.
        self.enable_lora = enable_lora and self.global_enable_lora

    def __call__(
        self,
        attn: Attention,
        hidden_states: torch.Tensor,
        encoder_hidden_states: Optional[torch.Tensor] = None,
        attention_mask: Optional[torch.Tensor] = None,
        temb: Optional[torch.Tensor] = None,
        img_mask: Optional[torch.Tensor] = None,
        subj_indices: Optional[Tuple[torch.IntTensor, torch.IntTensor]] = None,
        debug: bool = False,
        *args,
        **kwargs,
    ) -> torch.Tensor:
        if len(args) > 0 or kwargs.get("scale", None) is not None:
            deprecation_message = "The `scale` argument is deprecated and will be ignored. Please remove it, as passing it will raise an error in the future. `scale` should directly be passed while calling the underlying pipeline component i.e., via `cross_attention_kwargs`."
            deprecate("scale", "1.0.0", deprecation_message)

        # hidden_states: [1, 4096, 320]
        residual = hidden_states
        # attn.spatial_norm is None.
        if attn.spatial_norm is not None:
            hidden_states = attn.spatial_norm(hidden_states, temb)

        input_ndim = hidden_states.ndim

        if input_ndim == 4:
            batch_size, channel, height, width = hidden_states.shape
            # Collapse the spatial dimensions into a single token dimension.
            hidden_states = hidden_states.view(batch_size, channel, height * width).transpose(1, 2)

        batch_size, sequence_length, _ = (
            hidden_states.shape if encoder_hidden_states is None else encoder_hidden_states.shape
        )

        if attention_mask is not None:
            attention_mask = attn.prepare_attention_mask(attention_mask, sequence_length, batch_size)
            # scaled_dot_product_attention expects attention_mask shape to be
            # (batch, heads, source_length, target_length)
            attention_mask = attention_mask.view(batch_size, attn.heads, -1, attention_mask.shape[-1])

        if attn.group_norm is not None:
            hidden_states = attn.group_norm(hidden_states.transpose(1, 2)).transpose(1, 2)

        query = attn.to_q(hidden_states)
        # NOTE: there's an inconsistency between the q lora and the k, v loras.
        # k, v loras are directly applied to key and value (currently k, v loras are never enabled),
        # while the q lora is applied to query2, and we keep the query unchanged.
        if self.enable_lora and self.to_q_lora is not None:
            # query2 will be used in ldm/util.py:calc_elastic_matching_loss() to get more accurate
            # cross attention scores between the latent images of the sc and mc instances.
            query2 = self.to_q_lora(hidden_states)
            # If not q_lora_updates_query, only query2 will be impacted by the LoRA layer.
            # The query, and thus the attention score and attn_out, will be the same
            # as the original ones.
            if self.q_lora_updates_query:
                query = query2
        else:
            query2 = query

        scale = 1 / math.sqrt(query.size(-1))

        is_cross_attn = (encoder_hidden_states is not None)
        if (not is_cross_attn) and (img_mask is not None):
            # NOTE: we assume the image is square. This will fail if the image is not square.
            # hidden_states: [BS, 4096, 320]. img_mask: [BS, 1, 64, 64]
            # Scale the mask to the same size as hidden_states.
            mask_size = int(math.sqrt(hidden_states.shape[-2]))
            img_mask = F.interpolate(img_mask, size=(mask_size, mask_size), mode='nearest')
            if (img_mask.sum(dim=(2, 3)) == 0).any():
                img_mask = None
            else:
                # img_mask: [2, 1, 64, 64] -> [2, 4096]
                img_mask = rearrange(img_mask, 'b ... -> b (...)').contiguous()
                # max_neg_value = -torch.finfo(hidden_states.dtype).max
                # img_mask: [2, 4096] -> [2, 1, 1, 4096]
                img_mask = rearrange(img_mask.bool(), 'b j -> b () () j')
                # attn_score: [16, 4096, 4096]. img_mask will be broadcasted to [16, 4096, 4096].
                # So some rows in dim 1 (e.g. [0, :, 4095]) of attn_score will be masked out (all elements in [0, :, 4095] are -inf).
                # But not all elements in [0, 4095, :] are -inf. Since the softmax is done along dim 2, this is fine.
                # attn_score.masked_fill_(~img_mask, max_neg_value)
                # NOTE: If there's an attention mask, it will be replaced by img_mask.
                attention_mask = img_mask

        if encoder_hidden_states is None:
            encoder_hidden_states = hidden_states
        elif attn.norm_cross:
            encoder_hidden_states = attn.norm_encoder_hidden_states(encoder_hidden_states)

        if self.enable_lora and self.to_k_lora is not None:
            key = self.to_k_lora(encoder_hidden_states)
        else:
            key = attn.to_k(encoder_hidden_states)

        if self.enable_lora and self.to_v_lora is not None:
            value = self.to_v_lora(encoder_hidden_states)
        else:
            value = attn.to_v(encoder_hidden_states)

        if attn.norm_q is not None:
            query = attn.norm_q(query)
            query2 = attn.norm_q(query2)
        if attn.norm_k is not None:
            key = attn.norm_k(key)

        inner_dim = key.shape[-1]
        head_dim = inner_dim // attn.heads

        query = query.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
        query2 = query2.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)

        key = key.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
        value = value.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)

        if debug and self.attn_proc_idx >= 0:
            breakpoint()

        # the output of sdp = (batch, num_heads, seq_len, head_dim)
        if is_cross_attn and (self.capture_ca_activations or self.shrink_cross_attn):
            hidden_states, attn_score, attn_prob = \
                scaled_dot_product_attention(query, key, value, attn_mask=attention_mask,
                                             dropout_p=0.0, shrink_cross_attn=self.shrink_cross_attn,
                                             cross_attn_shrink_factor=self.cross_attn_shrink_factor)
        else:
            # Use the faster implementation of scaled_dot_product_attention
            # when not capturing the activations or suppressing the subject attention.
            hidden_states = \
                F.scaled_dot_product_attention(query, key, value, attn_mask=attention_mask, dropout_p=0.0, is_causal=False)
            attn_prob = attn_score = None

        hidden_states = hidden_states.transpose(1, 2).reshape(batch_size, -1, attn.heads * head_dim)
        hidden_states = hidden_states.to(query.dtype)

        # linear proj
        if self.enable_lora and self.to_out_lora is not None:
            hidden_states = self.to_out_lora(hidden_states)
        else:
            hidden_states = attn.to_out[0](hidden_states)

        # dropout
        hidden_states = attn.to_out[1](hidden_states)

        if input_ndim == 4:
            hidden_states = hidden_states.transpose(-1, -2).reshape(batch_size, channel, height, width)

        if attn.residual_connection:
            hidden_states = hidden_states + residual

        hidden_states = hidden_states / attn.rescale_output_factor

        if is_cross_attn and self.capture_ca_activations:
            # The cached q will be used in ddpm.py:calc_comp_fg_bg_preserve_loss(), in which two qs will multiply each other.
            # So sqrt(scale) will scale the product of the two qs by scale.
            # ANCHOR[id=attention_caching]
            # query: [2, 8, 4096, 40] -> [2, 320, 4096]
            self.cached_activations['q'] = \
                rearrange(query, 'b h n d -> b (h d) n').contiguous() * math.sqrt(scale)
            self.cached_activations['q2'] = \
                rearrange(query2, 'b h n d -> b (h d) n').contiguous() * math.sqrt(scale)
            self.cached_activations['k'] = \
                rearrange(key, 'b h n d -> b (h d) n').contiguous() * math.sqrt(scale)
            self.cached_activations['v'] = \
                rearrange(value, 'b h n d -> b (h d) n').contiguous() * math.sqrt(scale)
            # attn_prob, attn_score: [2, 8, 4096, 77]
            self.cached_activations['attn'] = attn_prob
            self.cached_activations['attnscore'] = attn_score
            # attn_out: [b, n, h * d] -> [b, h * d, n]
            # [2, 4096, 320] -> [2, 320, 4096].
            self.cached_activations['attn_out'] = hidden_states.permute(0, 2, 1).contiguous()

        return hidden_states

def CrossAttnUpBlock2D_forward_capture(
    self,
    hidden_states: torch.Tensor,
    res_hidden_states_tuple: Tuple[torch.Tensor, ...],
    temb: Optional[torch.Tensor] = None,
    encoder_hidden_states: Optional[torch.Tensor] = None,
    cross_attention_kwargs: Optional[Dict[str, Any]] = None,
    upsample_size: Optional[int] = None,
    attention_mask: Optional[torch.Tensor] = None,
    encoder_attention_mask: Optional[torch.Tensor] = None,
) -> torch.Tensor:
    if cross_attention_kwargs is not None:
        if cross_attention_kwargs.get("scale", None) is not None:
            logger.warning("Passing `scale` to `cross_attention_kwargs` is deprecated. `scale` will be ignored.")

    self.cached_outfeats = {}
    res_hidden_states_gradscale = getattr(self, "res_hidden_states_gradscale", 1)
    capture_outfeats = getattr(self, "capture_outfeats", False)
    layer_idx = 0
    res_grad_scaler = gen_gradient_scaler(res_hidden_states_gradscale)

    for resnet, attn in zip(self.resnets, self.attentions):
        # pop res hidden states
        res_hidden_states = res_hidden_states_tuple[-1]
        res_hidden_states_tuple = res_hidden_states_tuple[:-1]

        # Scale down the magnitudes of gradients to res_hidden_states
        # by res_hidden_states_gradscale=0.2, to match the scale of the cross-attn layer outputs.
        res_hidden_states = res_grad_scaler(res_hidden_states)

        hidden_states = torch.cat([hidden_states, res_hidden_states], dim=1)

        if self.training and self.gradient_checkpointing:
            def create_custom_forward(module, return_dict=None):
                def custom_forward(*inputs):
                    if return_dict is not None:
                        return module(*inputs, return_dict=return_dict)
                    else:
                        return module(*inputs)

                return custom_forward

            ckpt_kwargs: Dict[str, Any] = {"use_reentrant": False} if is_torch_version(">=", "1.11.0") else {}
            hidden_states = torch.utils.checkpoint.checkpoint(
                create_custom_forward(resnet),
                hidden_states,
                temb,
                **ckpt_kwargs,
            )
            hidden_states = attn(
                hidden_states,
                encoder_hidden_states=encoder_hidden_states,
                cross_attention_kwargs=cross_attention_kwargs,
                attention_mask=attention_mask,
                encoder_attention_mask=encoder_attention_mask,
                return_dict=False,
            )[0]
        else:
            # resnet: a ResnetBlock2D instance.
            #LINK diffusers.models.resnet.ResnetBlock2D
            # up_blocks.3.resnets.2.conv_shortcut is a module within ResnetBlock2D,
            # it's not transforming the UNet shortcut features.
            hidden_states = resnet(hidden_states, temb)
            hidden_states = attn(
                hidden_states,
                encoder_hidden_states=encoder_hidden_states,
                cross_attention_kwargs=cross_attention_kwargs,
                attention_mask=attention_mask,
                encoder_attention_mask=encoder_attention_mask,
                return_dict=False,
            )[0]

        if capture_outfeats:
            self.cached_outfeats[layer_idx] = hidden_states
            layer_idx += 1

    if self.upsamplers is not None:
        for upsampler in self.upsamplers:
            hidden_states = upsampler(hidden_states, upsample_size)

    return hidden_states


# Adapted from ConsistentIDPipeline:set_ip_adapter().
# attn_lora_layer_names: candidates are subsets of ['q', 'k', 'v', 'out'].
def set_up_attn_processors(unet, use_attn_lora, attn_lora_layer_names=['q', 'k', 'v', 'out'],
                           lora_rank=192, lora_scale_down=8, cross_attn_shrink_factor=0.5,
                           q_lora_updates_query=False):
    attn_procs = {}
    attn_capture_procs = {}
    unet_modules = dict(unet.named_modules())
    attn_opt_modules = {}

    attn_proc_idx = 0

    for name, attn_proc in unet.attn_processors.items():
        # Only capture the activations of the last 3 CA layers.
        if not name.startswith("up_blocks.3"):
            # Not among the last 3 CA layers. Don't enable LoRA or capture activations.
            # The layer then falls back to the original attention mechanism.
            # We still use AttnProcessor_LoRA_Capture, as it can handle img_mask.
            attn_procs[name] = AttnProcessor_LoRA_Capture(
                capture_ca_activations=False, enable_lora=False, attn_proc_idx=-1)
            continue
        # cross_attention_dim: 768.
        cross_attention_dim = None if name.endswith("attn1.processor") else unet.config.cross_attention_dim
        if cross_attention_dim is None:
            # Self attention. Don't enable LoRA or capture activations.
            # We replace the default attn_proc with AttnProcessor_LoRA_Capture,
            # so that it can incorporate img_mask into self-attention.
            attn_procs[name] = AttnProcessor_LoRA_Capture(
                capture_ca_activations=False, enable_lora=False, attn_proc_idx=-1)
            continue

        # block_id = 3
        # hidden_size: 320
        # hidden_size = list(reversed(unet.config.block_out_channels))[block_id]
        # 'up_blocks.3.attentions.1.transformer_blocks.0.attn2.processor' ->
        # 'up_blocks.3.attentions.1.transformer_blocks.0.attn2.to_q'
        lora_layer_dict = {}
        lora_layer_dict['q'] = unet_modules[name[:-9] + "to_q"]
        lora_layer_dict['k'] = unet_modules[name[:-9] + "to_k"]
        lora_layer_dict['v'] = unet_modules[name[:-9] + "to_v"]
        # to_out is a ModuleList(Linear, Dropout).
        lora_layer_dict['out'] = unet_modules[name[:-9] + "to_out"][0]

        lora_proj_layers = {}
        # Only apply LoRA to the specified layers.
        for lora_layer_name in attn_lora_layer_names:
            lora_proj_layers[lora_layer_name] = lora_layer_dict[lora_layer_name]

        attn_capture_proc = AttnProcessor_LoRA_Capture(
            capture_ca_activations=True, enable_lora=use_attn_lora,
            lora_uses_dora=True, lora_proj_layers=lora_proj_layers,
            # LoRA up is initialized to 0. So no need to worry that the LoRA output may be too large.
            lora_rank=lora_rank, lora_alpha=lora_rank // lora_scale_down,
            cross_attn_shrink_factor=cross_attn_shrink_factor,
            q_lora_updates_query=q_lora_updates_query, attn_proc_idx=attn_proc_idx)

        attn_proc_idx += 1
        # attn_procs has to use the original names.
        attn_procs[name] = attn_capture_proc
        # ModuleDict doesn't allow "." in the key.
        name = name.replace(".", "_")
        attn_capture_procs[name] = attn_capture_proc

        if use_attn_lora:
            for subname, module in attn_capture_proc.named_modules():
                if isinstance(module, peft_lora.LoraLayer):
                    # ModuleDict doesn't allow "." in the key.
                    lora_path = name + "_" + subname.replace(".", "_")
                    attn_opt_modules[lora_path + "_lora_A"] = module.lora_A
                    attn_opt_modules[lora_path + "_lora_B"] = module.lora_B
                    # lora_uses_dora is always True, so we don't check it here.
                    attn_opt_modules[lora_path + "_lora_magnitude_vector"] = module.lora_magnitude_vector
                    # We will manage attn adapters directly. By default, LoraLayer is an instance of BaseTunerLayer,
                    # so according to the code logic in diffusers/loaders/peft.py,
                    # they will be managed by the diffusers PeftAdapterMixin instance, through the
                    # enable_adapters() and set_adapter() methods.
                    # Therefore, we disable these calls on module.
                    # disable_adapters() is a property and changing it will cause exceptions.
                    module.enable_adapters = dummy_func
                    module.set_adapter = dummy_func

    unet.set_attn_processor(attn_procs)

    print(f"Set up {len(attn_capture_procs)} CrossAttn processors on {attn_capture_procs.keys()}.")
    print(f"Set up {len(attn_opt_modules)} attn LoRA params: {attn_opt_modules.keys()}.")
    return attn_capture_procs, attn_opt_modules

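A usage sketch for `set_up_attn_processors()` on a SD1.5 UNet; the model id below is an assumption (it matches the SD 1.5 checkpoint referenced elsewhere in this repo):

```python
from diffusers import UNet2DConditionModel

unet = UNet2DConditionModel.from_pretrained(
    "runwayml/stable-diffusion-v1-5", subfolder="unet")
attn_capture_procs, attn_opt_modules = set_up_attn_processors(
    unet, use_attn_lora=True, attn_lora_layer_names=['q', 'k', 'v', 'out'],
    lora_rank=192, lora_scale_down=8)
# Only the cross-attn layers of up_blocks.3 capture activations and carry LoRA;
# every other layer gets a pass-through AttnProcessor_LoRA_Capture.
```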
# NOTE: cross-attn layers are included in the returned lora_modules.
def set_up_ffn_loras(unet, target_modules_pat, lora_uses_dora=False, lora_rank=192, lora_alpha=16):
    # target_modules_pat = 'up_blocks.3.resnets.[12].conv[a-z0-9_]+'
    # up_blocks.3.resnets.[1~2].conv1, conv2, conv_shortcut
    # Cannot set to conv.+ as it will match added adapter module names, including
    # up_blocks.3.resnets.1.conv1.base_layer, up_blocks.3.resnets.1.conv1.lora_dropout
    if target_modules_pat is not None:
        peft_config = LoraConfig(use_dora=lora_uses_dora, inference_mode=False, r=lora_rank,
                                 lora_alpha=lora_alpha, lora_dropout=0.1,
                                 target_modules=target_modules_pat)

        # UNet is a diffusers PeftAdapterMixin instance. Using get_peft_model on it will
        # cause weird errors. Instead, we directly use the diffusers peft adapter methods.
        unet.add_adapter(peft_config, "recon_loss")
        unet.add_adapter(peft_config, "unet_distill")
        unet.add_adapter(peft_config, "comp_distill")
        unet.enable_adapters()

    # lora_layers contain both the LoRA A and B matrices, as well as the original layers.
    # lora_layers are used to set the flag, not used for optimization.
    # lora_modules contain only the LoRA A and B matrices, so they are used for optimization.
    # NOTE: lora_modules contain both ffn and cross-attn lora modules.
    ffn_lora_layers = {}
    ffn_opt_modules = {}
    for name, module in unet.named_modules():
        if isinstance(module, peft_lora.LoraLayer):
            # We don't want to include cross-attn layers in ffn_lora_layers.
            if target_modules_pat is not None and re.search(target_modules_pat, name):
                ffn_lora_layers[name] = module
                # ModuleDict doesn't allow "." in the key.
                name = name.replace(".", "_")
                # Since ModuleDict doesn't allow "." in the key, we manually collect
                # the LoRA matrices in each module.
                # NOTE: We cannot put every sub-module of module into lora_modules,
                # as base_layer is also a sub-module of module, which we shouldn't optimize.
                # Each value in ffn_opt_modules is a ModuleDict:
                '''
                (Pdb) ffn_opt_modules['up_blocks_3_resnets_1_conv1_lora_A']
                ModuleDict(
                  (unet_distill): Conv2d(640, 192, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
                  (recon_loss): Conv2d(640, 192, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
                )
                '''
                ffn_opt_modules[name + "_lora_A"] = module.lora_A
                ffn_opt_modules[name + "_lora_B"] = module.lora_B
                if lora_uses_dora:
                    ffn_opt_modules[name + "_lora_magnitude_vector"] = module.lora_magnitude_vector

    print(f"Set up {len(ffn_lora_layers)} FFN LoRA layers: {ffn_lora_layers.keys()}.")
    print(f"Set up {len(ffn_opt_modules)} FFN LoRA params: {ffn_opt_modules.keys()}.")

    return ffn_lora_layers, ffn_opt_modules

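A companion sketch for `set_up_ffn_loras()`, reusing the `unet` from the previous sketch; the regex is the one quoted in the comments above:

```python
ffn_lora_layers, ffn_opt_modules = set_up_ffn_loras(
    unet, target_modules_pat='up_blocks.3.resnets.[12].conv[a-z0-9_]+',
    lora_uses_dora=False, lora_rank=192, lora_alpha=16)
# Three adapters are registered: 'recon_loss', 'unet_distill', 'comp_distill'.
```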
def set_lora_and_capture_flags(unet, unet_lora_modules, attn_capture_procs,
                               outfeat_capture_blocks, res_hidden_states_gradscale_blocks,
                               use_attn_lora, use_ffn_lora, ffn_lora_adapter_name, capture_ca_activations,
                               shrink_cross_attn, res_hidden_states_gradscale):
    # For attn capture procs, capture_ca_activations and use_attn_lora are set in reset_attn_cache_and_flags().
    for attn_capture_proc in attn_capture_procs:
        attn_capture_proc.reset_attn_cache_and_flags(capture_ca_activations, shrink_cross_attn, enable_lora=use_attn_lora)
    # outfeat_capture_blocks only contains the last up block, up_blocks[3].
    # It contains 3 FFN layers. We want to capture their output features.
    for block in outfeat_capture_blocks:
        block.capture_outfeats = capture_ca_activations

    for block in res_hidden_states_gradscale_blocks:
        block.res_hidden_states_gradscale = res_hidden_states_gradscale

    if not use_ffn_lora:
        unet.disable_adapters()
    else:
        # ffn_lora_adapter_name: 'recon_loss', 'unet_distill', or 'comp_distill'.
        if ffn_lora_adapter_name is not None:
            unet.set_adapter(ffn_lora_adapter_name)
            # NOTE: Don't forget to enable_adapters().
            # The adapters are not enabled by default after set_adapter().
            unet.enable_adapters()
        else:
            breakpoint()

    # During training, disable_adapters() and set_adapter() will set all/inactive adapters with requires_grad=False,
    # which might cause issues during DDP training.
    # So we restore them to requires_grad=True.
    # During test, unet_lora_modules will be passed as None, so this block will be skipped.
    if unet_lora_modules is not None:
        for param in unet_lora_modules.parameters():
            param.requires_grad = True

def get_captured_activations(capture_ca_activations, attn_capture_procs, outfeat_capture_blocks,
                             captured_layer_indices=[22, 23, 24], out_dtype=torch.float32):
    captured_activations = { k: {} for k in ('outfeat', 'attn', 'attnscore',
                                             'q', 'q2', 'k', 'v', 'attn_out') }

    if not capture_ca_activations:
        return captured_activations

    all_cached_outfeats = []
    for block in outfeat_capture_blocks:
        all_cached_outfeats.append(block.cached_outfeats)
        # Clear the capture flag and the cached outfeats.
        block.cached_outfeats = {}
        block.capture_outfeats = False

    for layer_idx in captured_layer_indices:
        # Subtract 22 from layer_idx to match the layer index in up_blocks[3].cached_outfeats.
        # 23, 24 -> 1, 2 (!! not 0, 1 !!)
        internal_idx = layer_idx - 22
        for k in captured_activations.keys():
            if k == 'outfeat':
                # Currently we only capture one block, up_blocks.3. So we hard-code the index 0.
                captured_activations['outfeat'][layer_idx] = all_cached_outfeats[0][internal_idx].to(out_dtype)
            else:
                # internal_idx is the index of layers in up_blocks.3.
                # Layers 22, 23 and 24 map to 0, 1 and 2.
                cached_activations = attn_capture_procs[internal_idx].cached_activations
                captured_activations[k][layer_idx] = cached_activations[k].to(out_dtype)

    return captured_activations
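Putting the helpers together, a sketch of one capture cycle; `unet` and `attn_capture_procs` are assumed to come from the two set-up calls above, and the processor list order is assumed to follow layers 22-24:

```python
import torch

outfeat_capture_blocks = [unet.up_blocks[3]]
set_lora_and_capture_flags(unet, None, list(attn_capture_procs.values()),
                           outfeat_capture_blocks, [unet.up_blocks[3]],
                           use_attn_lora=True, use_ffn_lora=True,
                           ffn_lora_adapter_name='recon_loss',
                           capture_ca_activations=True,
                           shrink_cross_attn=False,
                           res_hidden_states_gradscale=0.2)
# ... run one UNet forward pass here ...
acts = get_captured_activations(True, list(attn_capture_procs.values()),
                                outfeat_capture_blocks)
# acts['attn'][22], acts['outfeat'][24], ...: per-layer tensors in float32.
```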
adaface/face_id_to_ada_prompt.py
CHANGED
@@ -53,6 +53,8 @@ class FaceID2AdaPrompt(nn.Module):
|
|
53 |
self.text_to_image_prompt_encoder = None
|
54 |
self.tokenizer = None
|
55 |
self.dtype = kwargs.get('dtype', torch.float16)
|
|
|
|
|
56 |
|
57 |
# Load Img2Ada SubjectBasisGenerator.
|
58 |
self.subject_string = kwargs.get('subject_string', 'z')
|
@@ -73,12 +75,16 @@ class FaceID2AdaPrompt(nn.Module):
|
|
73 |
|
74 |
self.use_clip_embs = False
|
75 |
self.do_contrast_clip_embs_on_bg_features = False
|
|
|
|
|
|
|
|
|
76 |
# num_id_vecs is the output embeddings of the ID2ImgPrompt module.
|
77 |
# If there's no static image suffix embeddings, then num_id_vecs is also
|
78 |
# the number of ada embeddings returned by the subject basis generator.
|
79 |
# num_id_vecs will be set in each derived class.
|
80 |
self.num_static_img_suffix_embs = kwargs.get('num_static_img_suffix_embs', 0)
|
81 |
-
print(f'{self.name} Adaface uses {self.num_id_vecs} ID image embeddings
|
82 |
|
83 |
self.id_img_prompt_max_length = 77
|
84 |
self.face_id_dim = 512
|
@@ -87,36 +93,35 @@ class FaceID2AdaPrompt(nn.Module):
|
|
87 |
self.clip_embedding_dim = 1024
|
88 |
self.output_dim = 768
|
89 |
|
90 |
-
|
91 |
-
|
92 |
-
|
93 |
-
def load_id2img_learnable_modules(self, id2img_learnable_modules_state_dict_list):
|
94 |
-
id2img_prompt_encoder_learnable_modules = self.get_id2img_learnable_modules()
|
95 |
-
for module, state_dict in zip(id2img_prompt_encoder_learnable_modules, id2img_learnable_modules_state_dict_list):
|
96 |
-
module.load_state_dict(state_dict)
|
97 |
-
print(f'{len(id2img_prompt_encoder_learnable_modules)} ID2ImgPrompt encoder modules loaded.')
|
98 |
-
|
99 |
-
# init_subj_basis_generator() can only be called after the derived class is initialized,
|
100 |
-
# when self.num_id_vecs, self.num_static_img_suffix_embs and self.clip_embedding_dim have been set.
|
101 |
-
def init_subj_basis_generator(self):
|
102 |
self.subj_basis_generator = \
|
103 |
-
SubjBasisGenerator(
|
|
|
104 |
num_static_img_suffix_embs = self.num_static_img_suffix_embs,
|
105 |
bg_image_embedding_dim = self.clip_embedding_dim,
|
106 |
output_dim = self.output_dim,
|
107 |
placeholder_is_bg = False,
|
108 |
-
prompt2token_proj_grad_scale = 1,
|
109 |
bg_prompt_translator_has_to_out_proj=False)
|
110 |
|
111 |
def load_adaface_ckpt(self, adaface_ckpt_path):
|
112 |
-
|
|
|
|
|
|
|
113 |
string_to_subj_basis_generator_dict = ckpt["string_to_subj_basis_generator_dict"]
|
114 |
if self.subject_string not in string_to_subj_basis_generator_dict:
|
115 |
print(f"Subject '{self.subject_string}' not found in the embedding manager.")
|
116 |
breakpoint()
|
117 |
|
118 |
ckpt_subj_basis_generator = string_to_subj_basis_generator_dict[self.subject_string]
|
119 |
-
ckpt_subj_basis_generator
|
|
|
|
|
|
|
|
|
|
|
120 |
# Since we directly use the subject basis generator object from the ckpt,
|
121 |
# fixing the number of static image suffix embeddings is much simpler.
|
122 |
# Otherwise if we want to load the subject basis generator from its state_dict,
|
@@ -129,7 +134,7 @@ class FaceID2AdaPrompt(nn.Module):
|
|
129 |
ckpt_subj_basis_generator.initialize_static_img_suffix_embs(self.num_static_img_suffix_embs, img_prompt_dim=self.output_dim)
|
130 |
# Fix missing variables in old ckpt.
|
131 |
ckpt_subj_basis_generator.patch_old_subj_basis_generator_ckpt()
|
132 |
-
|
133 |
self.subj_basis_generator.extend_prompt2token_proj_attention(\
|
134 |
ckpt_subj_basis_generator.prompt2token_proj_attention_multipliers, -1, -1, 1, perturb_std=0)
|
135 |
ret = self.subj_basis_generator.load_state_dict(ckpt_subj_basis_generator.state_dict(), strict=False)
|
@@ -155,6 +160,11 @@ class FaceID2AdaPrompt(nn.Module):
|
|
155 |
|
156 |
self.subj_basis_generator.freeze_prompt2token_proj()
|
157 |
|
|
|
|
|
|
|
|
|
|
|
158 |
@torch.no_grad()
|
159 |
def get_clip_neg_features(self, BS):
|
160 |
if self.clip_neg_features is None:
|
@@ -220,6 +230,7 @@ class FaceID2AdaPrompt(nn.Module):
|
|
220 |
image_obj, _, _ = pad_image_obj_to_square(image_obj)
|
221 |
image_np = np.array(image_obj.resize(size, Image.NEAREST))
|
222 |
face_info = self.face_app.get(cv2.cvtColor(image_np, cv2.COLOR_RGB2BGR))
|
|
|
223 |
if len(face_info) > 0:
|
224 |
face_info = sorted(face_info, key=lambda x:(x['bbox'][2]-x['bbox'][0])*x['bbox'][3]-x['bbox'][1])[-1] # only use the maximum face
|
225 |
# id_emb: [512,]
|
@@ -487,12 +498,20 @@ class FaceID2AdaPrompt(nn.Module):
|
|
487 |
# avg_at_stage == ada_prompt_emb usually produces the worst results.
|
488 |
# avg_at_stage == id_emb is slightly better than img_prompt_emb, but sometimes img_prompt_emb is better.
|
489 |
# p_dropout and return_zero_embs_for_dropped_encoders are only used by Joint_FaceID2AdaPrompt.
|
|
|
490 |
def generate_adaface_embeddings(self, image_paths, face_id_embs=None, img_prompt_embs=None,
|
491 |
p_dropout=0,
|
492 |
return_zero_embs_for_dropped_encoders=True,
|
493 |
avg_at_stage='id_emb', # id_emb, img_prompt_emb, or None.
|
494 |
perturb_at_stage=None, # id_emb, img_prompt_emb, or None.
|
495 |
-
perturb_std=0, enable_static_img_suffix_embs=
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
496 |
if (avg_at_stage is None) or avg_at_stage.lower() == 'none':
|
497 |
img_prompt_avg_at_stage = None
|
498 |
else:
|
@@ -509,7 +528,7 @@ class FaceID2AdaPrompt(nn.Module):
|
|
509 |
id_batch_size = len(image_paths)
|
510 |
else:
|
511 |
id_batch_size = 1
|
512 |
-
|
513 |
# faceid_embeds: [BS, 512] is a batch of extracted face analysis embeddings. NOT used later.
|
514 |
# NOTE: If face_id_embs, image_paths and image_objs are all None,
|
515 |
# then get_img_prompt_embs() generates random faceid_embeds/img_prompt_embs,
|
@@ -532,7 +551,7 @@ class FaceID2AdaPrompt(nn.Module):
|
|
532 |
verbose=True)
|
533 |
|
534 |
if face_image_count == 0:
|
535 |
-
return None
|
536 |
|
537 |
# No matter whether avg_at_stage is id_emb or img_prompt_emb, we average img_prompt_embs.
|
538 |
elif avg_at_stage is not None and avg_at_stage.lower() != 'none':
|
@@ -545,19 +564,27 @@ class FaceID2AdaPrompt(nn.Module):
|
|
545 |
out_id_embs_cfg_scale=self.out_id_embs_cfg_scale,
|
546 |
is_face=True,
|
547 |
enable_static_img_suffix_embs=enable_static_img_suffix_embs)
|
|
|
|
|
|
|
|
|
548 |
# During training, img_prompt_avg_at_stage is None, and BS >= 1.
|
549 |
# During inference, img_prompt_avg_at_stage is 'id_emb' or 'img_prompt_emb', and BS == 1.
|
550 |
if img_prompt_avg_at_stage is not None:
|
551 |
# adaface_subj_embs: [1, 16, 768] -> [16, 768]
|
552 |
adaface_subj_embs = adaface_subj_embs.squeeze(0)
|
553 |
|
554 |
-
return adaface_subj_embs
|
555 |
|
556 |
class Arc2Face_ID2AdaPrompt(FaceID2AdaPrompt):
|
557 |
-
|
558 |
-
|
559 |
-
|
|
|
|
|
|
|
560 |
|
|
|
561 |
super().__init__(*args, **kwargs)
|
562 |
|
563 |
self.clip_image_encoder = CLIPVisionModelWithMask.from_pretrained('openai/clip-vit-large-patch14')
|
@@ -576,14 +603,11 @@ class Arc2Face_ID2AdaPrompt(FaceID2AdaPrompt):
|
|
576 |
'''
|
577 |
# Use the same model as ID2AdaPrompt does.
|
578 |
# FaceAnalysis will try to find the ckpt in: models/insightface/models/antelopev2.
|
579 |
-
# Note there
|
580 |
-
# Note DON'T use CUDAExecutionProvider, as it will hang DDP training.
|
581 |
-
# Seems when loading insightface onto the GPU, it will only reside on the first GPU.
|
582 |
-
# Then the process on the second GPU has issue to communicate with insightface on the first GPU, causing hanging.
|
583 |
self.face_app = FaceAnalysis(name='antelopev2', root='models/insightface',
|
584 |
providers=['CPUExecutionProvider'])
|
585 |
self.face_app.prepare(ctx_id=0, det_size=(512, 512))
|
586 |
-
print(f'Face encoder loaded on CPU.')
|
587 |
|
588 |
self.text_to_image_prompt_encoder = CLIPTextModelWrapper.from_pretrained(
|
589 |
'models/arc2face', subfolder="encoder",
|
@@ -594,21 +618,58 @@ class Arc2Face_ID2AdaPrompt(FaceID2AdaPrompt):
|
|
594 |
if self.out_id_embs_cfg_scale == -1:
|
595 |
self.out_id_embs_cfg_scale = 1
|
596 |
#### Arc2Face pipeline specific configs ####
|
597 |
-
self.gen_neg_img_prompt
|
598 |
# bg CLIP features are used by the bg subject basis generator.
|
599 |
-
self.use_clip_embs
|
600 |
self.do_contrast_clip_embs_on_bg_features = True
|
601 |
# self.num_static_img_suffix_embs is initialized in the parent class.
|
602 |
-
self.id_img_prompt_max_length
|
603 |
-
self.clip_embedding_dim
|
604 |
|
605 |
-
self.
|
606 |
if self.adaface_ckpt_path is not None:
|
607 |
self.load_adaface_ckpt(self.adaface_ckpt_path)
|
608 |
|
609 |
-
|
610 |
-
|
|
|
|
|
|
|
|
|
611 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
612 |
# Arc2Face_ID2AdaPrompt never uses clip_features or called_for_neg_img_prompt.
|
613 |
def map_init_id_to_img_prompt_embs(self, init_id_embs,
|
614 |
clip_features=None,
|
@@ -656,16 +717,17 @@ class Arc2Face_ID2AdaPrompt(FaceID2AdaPrompt):
|
|
656 |
# [N, 22, 768] -> [N, 16, 768]
|
657 |
return prompt_embeds[:, 4:20]
|
658 |
|
659 |
-
def get_id2img_learnable_modules(self):
|
660 |
-
return [ self.text_to_image_prompt_encoder ]
|
661 |
-
|
662 |
# ConsistentID_ID2AdaPrompt is just a wrapper of ConsistentIDPipeline, so it's not an nn.Module.
|
663 |
class ConsistentID_ID2AdaPrompt(FaceID2AdaPrompt):
|
|
|
|
|
|
|
|
|
|
|
|
|
664 |
def __init__(self, pipe=None, base_model_path="models/sd15-dste8-vae.safetensors",
|
665 |
*args, **kwargs):
|
666 |
-
|
667 |
-
self.num_id_vecs = 4
|
668 |
-
|
669 |
super().__init__(*args, **kwargs)
|
670 |
if pipe is None:
|
671 |
# The base_model_path is kind of arbitrary, as the UNet and VAE in the model
|
@@ -712,13 +774,51 @@ class ConsistentID_ID2AdaPrompt(FaceID2AdaPrompt):
|
|
712 |
self.clip_embedding_dim = 1280
|
713 |
self.s_scale = 1.0
|
714 |
self.shortcut = False
|
715 |
-
|
716 |
-
self.
|
717 |
if self.adaface_ckpt_path is not None:
|
718 |
self.load_adaface_ckpt(self.adaface_ckpt_path)
|
719 |
|
|
|
720 |
print(f"{self.name} ada prompt encoder initialized, "
|
721 |
-
f"ID vecs: {self.
|
|
|
|
722 |
|
723 |
def map_init_id_to_img_prompt_embs(self, init_id_embs,
|
724 |
clip_features=None,
|
@@ -757,25 +857,30 @@ class ConsistentID_ID2AdaPrompt(FaceID2AdaPrompt):
|
|
757 |
|
758 |
return global_id_embeds
|
759 |
|
760 |
-
def get_id2img_learnable_modules(self):
|
761 |
-
return [ self.image_proj_model ]
|
762 |
-
|
763 |
# A wrapper for combining multiple FaceID2AdaPrompt instances.
|
764 |
class Joint_FaceID2AdaPrompt(FaceID2AdaPrompt):
|
765 |
def __init__(self, adaface_encoder_types, adaface_ckpt_paths,
|
766 |
out_id_embs_cfg_scales=None, enabled_encoders=None,
|
767 |
*args, **kwargs):
|
768 |
self.name = 'jointIDs'
|
|
|
769 |
assert len(adaface_encoder_types) > 0, "adaface_encoder_types should not be empty."
|
770 |
-
|
771 |
-
|
|
|
|
|
772 |
for encoder_type in adaface_encoder_types ]
|
773 |
-
self.
|
|
|
774 |
super().__init__(*args, **kwargs)
|
775 |
|
776 |
self.num_sub_encoders = len(adaface_encoder_types)
|
777 |
self.id2ada_prompt_encoders = nn.ModuleList()
|
778 |
self.encoders_num_static_img_suffix_embs = []
|
|
|
779 |
|
780 |
# TODO: apply adaface_encoder_cfg_scales to influence the final prompt embeddings.
|
781 |
# Now they are just placeholders.
|
@@ -785,10 +890,12 @@ class Joint_FaceID2AdaPrompt(FaceID2AdaPrompt):
|
|
785 |
self.out_id_embs_cfg_scales = [-1] * self.num_sub_encoders
|
786 |
else:
|
787 |
# Do not normalize the weights, and just use them as is.
|
788 |
-
self.out_id_embs_cfg_scales = out_id_embs_cfg_scales
|
789 |
|
790 |
# Note we don't pass the adaface_ckpt_paths to the base class, but instead,
|
791 |
# we load them once and for all in self.load_adaface_ckpt().
|
|
792 |
for i, encoder_type in enumerate(adaface_encoder_types):
|
793 |
kwargs['out_id_embs_cfg_scale'] = self.out_id_embs_cfg_scales[i]
|
794 |
if encoder_type == 'arc2face':
|
@@ -797,8 +904,10 @@ class Joint_FaceID2AdaPrompt(FaceID2AdaPrompt):
|
|
797 |
encoder = ConsistentID_ID2AdaPrompt(*args, **kwargs)
|
798 |
else:
|
799 |
breakpoint()
|
|
|
800 |
self.id2ada_prompt_encoders.append(encoder)
|
801 |
self.encoders_num_static_img_suffix_embs.append(encoder.num_static_img_suffix_embs)
|
|
|
802 |
|
803 |
self.num_static_img_suffix_embs = sum(self.encoders_num_static_img_suffix_embs)
|
804 |
# No need to set gen_neg_img_prompt, as we don't access it in this class, but rather
|
@@ -814,6 +923,7 @@ class Joint_FaceID2AdaPrompt(FaceID2AdaPrompt):
|
|
814 |
# Therefore, the clip_embedding_dim is the sum of the clip_embedding_dims of all adaface encoders.
|
815 |
self.clip_embedding_dims = [encoder.clip_embedding_dim for encoder in self.id2ada_prompt_encoders]
|
816 |
self.clip_embedding_dim = sum(self.clip_embedding_dims)
|
|
|
817 |
# The ctors of the derived classes have already initialized encoder.subj_basis_generator.
|
818 |
# If subj_basis_generator expansion params are specified, they are equally applied to all adaface encoders.
|
819 |
# This self.subj_basis_generator is not meant to be called as self.subj_basis_generator(), but instead,
|
@@ -821,12 +931,13 @@ class Joint_FaceID2AdaPrompt(FaceID2AdaPrompt):
|
|
821 |
self.subj_basis_generator = \
|
822 |
nn.ModuleList( [encoder.subj_basis_generator for encoder \
|
823 |
in self.id2ada_prompt_encoders] )
|
824 |
-
|
|
|
825 |
if adaface_ckpt_paths is not None:
|
826 |
self.load_adaface_ckpt(adaface_ckpt_paths)
|
827 |
-
|
828 |
print(f"{self.name} ada prompt encoder initialized with {self.num_sub_encoders} sub-encoders. "
|
829 |
-
f"ID vecs: {self.
|
830 |
|
831 |
if enabled_encoders is not None:
|
832 |
self.are_encoders_enabled = \
|
@@ -842,66 +953,79 @@ class Joint_FaceID2AdaPrompt(FaceID2AdaPrompt):
|
|
842 |
else:
|
843 |
self.are_encoders_enabled = \
|
844 |
torch.tensor([True] * self.num_sub_encoders)
|
845 |
-
|
846 |
def load_adaface_ckpt(self, adaface_ckpt_paths):
|
847 |
-
# If only one adaface ckpt path is provided, then we assume it's the ckpt of the Joint_FaceID2AdaPrompt,
|
848 |
-
# so we dereference the list to get the actual path and load the subj_basis_generators of all adaface encoders.
|
849 |
if isinstance(adaface_ckpt_paths, (list, tuple, ListConfig)):
|
850 |
-
|
|
|
851 |
adaface_ckpt_paths = adaface_ckpt_paths[0]
|
852 |
-
|
853 |
-
if isinstance(adaface_ckpt_paths, str):
|
854 |
-
# This is only applicable to newest ckpts of Joint_FaceID2AdaPrompt, where
|
855 |
-
# the ckpt_subj_basis_generator is an nn.ModuleList of multiple subj_basis_generators.
|
856 |
-
# Therefore, no need to patch missing variables.
|
857 |
-
ckpt = torch.load(adaface_ckpt_paths, map_location='cpu')
|
858 |
-
string_to_subj_basis_generator_dict = ckpt["string_to_subj_basis_generator_dict"]
|
859 |
-
if self.subject_string not in string_to_subj_basis_generator_dict:
|
860 |
-
print(f"Subject '{self.subject_string}' not found in the embedding manager.")
|
861 |
breakpoint()
|
862 |
|
863 |
-
|
864 |
-
|
865 |
-
|
866 |
-
|
867 |
-
|
868 |
-
|
869 |
-
|
870 |
-
|
871 |
-
|
872 |
-
|
873 |
-
|
874 |
-
|
875 |
-
|
876 |
-
|
877 |
-
|
878 |
-
assert subj_basis_generator.prompt2token_proj_attention_multipliers \
|
879 |
-
== ckpt_subj_basis_generator.prompt2token_proj_attention_multipliers, \
|
880 |
-
"Inconsistent prompt2token_proj_attention_multipliers."
|
881 |
-
subj_basis_generator.load_state_dict(ckpt_subj_basis_generator.state_dict())
|
882 |
-
|
883 |
-
# extend_prompt2token_proj_attention_multiplier is an integer >= 1.
|
884 |
-
# TODO: extend_prompt2token_proj_attention_multiplier should be a list of integers.
|
885 |
-
# If extend_prompt2token_proj_attention_multiplier > 1, then after loading state_dict,
|
886 |
-
# extend subj_basis_generator again.
|
887 |
-
if self.extend_prompt2token_proj_attention_multiplier > 1:
|
888 |
-
# During this extension, the added noise does change the extra copies of attention weights, since they are not in the ckpt.
|
889 |
-
# During training, prompt2token_proj_ext_attention_perturb_ratio == 0.1.
|
890 |
-
# During inference, prompt2token_proj_ext_attention_perturb_ratio == 0.
|
891 |
-
subj_basis_generator.extend_prompt2token_proj_attention(\
|
892 |
-
None, -1, -1, self.extend_prompt2token_proj_attention_multiplier,
|
893 |
-
perturb_std=self.prompt2token_proj_ext_attention_perturb_ratio)
|
894 |
-
|
895 |
-
subj_basis_generator.freeze_prompt2token_proj()
|
896 |
-
|
897 |
-
print(f"{adaface_ckpt_paths}: {len(self.subj_basis_generator)} subj_basis_generators loaded for {self.name}.")
|
898 |
-
|
899 |
-
elif isinstance(adaface_ckpt_paths, (list, tuple, ListConfig)):
|
900 |
-
for i, ckpt_path in enumerate(adaface_ckpt_paths):
|
901 |
-
self.id2ada_prompt_encoders[i].load_adaface_ckpt(ckpt_path)
|
902 |
-
else:
|
903 |
breakpoint()
|
904 |
|
|
|
|
905 |
def extract_init_id_embeds_from_images(self, *args, **kwargs):
|
906 |
total_faceless_img_count = 0
|
907 |
all_id_embs = []
|
@@ -1039,7 +1163,7 @@ class Joint_FaceID2AdaPrompt(FaceID2AdaPrompt):
|
|
1039 |
|
1040 |
N_ID = self.encoders_num_id_vecs[i]
|
1041 |
if all_pos_prompt_embs[i] is None:
|
1042 |
-
# Both pos_prompt_embs and neg_prompt_embs have N_ID ==
|
1043 |
all_pos_prompt_embs[i] = torch.zeros((BS, N_ID, 768), dtype=torch.float16, device=device)
|
1044 |
if all_neg_prompt_embs[i] is None:
|
1045 |
all_neg_prompt_embs[i] = torch.zeros((BS, N_ID, 768), dtype=torch.float16, device=device)
|
@@ -1061,6 +1185,13 @@ class Joint_FaceID2AdaPrompt(FaceID2AdaPrompt):
|
|
1061 |
# So its .device is the device of its parameters.
|
1062 |
device = self.id2ada_prompt_encoders[0].clip_image_encoder.device
|
1063 |
is_emb_averaged = kwargs.get('avg_at_stage', None) is not None
|
|
|
|
1064 |
BS = -1
|
1065 |
|
1066 |
if face_id_embs is not None:
|
@@ -1068,13 +1199,17 @@ class Joint_FaceID2AdaPrompt(FaceID2AdaPrompt):
|
|
1068 |
all_face_id_embs = face_id_embs.split(self.face_id_dims, dim=1)
|
1069 |
else:
|
1070 |
all_face_id_embs = [None] * self.num_sub_encoders
|
|
|
1071 |
if img_prompt_embs is not None:
|
1072 |
BS = img_prompt_embs.shape[0] if BS == -1 else BS
|
1073 |
-
if img_prompt_embs.shape[1] != self.
|
1074 |
breakpoint()
|
1075 |
-
all_img_prompt_embs = img_prompt_embs.split(self.
|
|
|
1076 |
else:
|
1077 |
all_img_prompt_embs = [None] * self.num_sub_encoders
|
|
|
|
|
1078 |
if image_paths is not None:
|
1079 |
BS = len(image_paths) if BS == -1 else BS
|
1080 |
if BS == -1:
|
@@ -1097,25 +1232,32 @@ class Joint_FaceID2AdaPrompt(FaceID2AdaPrompt):
|
|
1097 |
else:
|
1098 |
are_encoders_enabled = self.are_encoders_enabled
|
1099 |
|
|
|
1100 |
all_adaface_subj_embs = []
|
1101 |
num_available_id_vecs = 0
|
|
|
1102 |
|
1103 |
for i, id2ada_prompt_encoder in enumerate(self.id2ada_prompt_encoders):
|
1104 |
if not are_encoders_enabled[i]:
|
1105 |
adaface_subj_embs = None
|
1106 |
-
print(f"Encoder {id2ada_prompt_encoder.name} is
|
|
|
|
|
1107 |
else:
|
|
|
1108 |
# ddpm.embedding_manager.train() -> id2ada_prompt_encoder.train() -> each sub-encoder's train().
|
1109 |
# -> each sub-encoder's subj_basis_generator.train().
|
1110 |
# Therefore grad for the following call is enabled.
|
1111 |
-
adaface_subj_embs = \
|
1112 |
id2ada_prompt_encoder.generate_adaface_embeddings(image_paths,
|
1113 |
all_face_id_embs[i],
|
1114 |
all_img_prompt_embs[i],
|
1115 |
*args, **kwargs)
|
1116 |
|
1117 |
-
|
1118 |
-
|
|
|
|
|
1119 |
if adaface_subj_embs is None:
|
1120 |
if not return_zero_embs_for_dropped_encoders:
|
1121 |
continue
|
@@ -1126,12 +1268,16 @@ class Joint_FaceID2AdaPrompt(FaceID2AdaPrompt):
|
|
1126 |
all_adaface_subj_embs.append(adaface_subj_embs)
|
1127 |
else:
|
1128 |
all_adaface_subj_embs.append(adaface_subj_embs)
|
|
|
|
|
1129 |
num_available_id_vecs += N_ID
|
1130 |
-
|
|
|
|
|
1131 |
# No faces are found in the images, so return None embeddings.
|
1132 |
# We don't want to return an all-zero embedding, which is useless.
|
1133 |
if num_available_id_vecs == 0:
|
1134 |
-
return None
|
1135 |
|
1136 |
# If id2ada_prompt_encoders are ["arc2face", "consistentID"], then
|
1137 |
# during inference, we average across the batch dim.
|
@@ -1141,7 +1287,12 @@ class Joint_FaceID2AdaPrompt(FaceID2AdaPrompt):
|
|
1141 |
# all_adaface_subj_embs[0]: [BS, 4, 768]. all_adaface_subj_embs[1]: [BS, 16, 768].
|
1142 |
# all_adaface_subj_embs: [BS, 20, 768].
|
1143 |
all_adaface_subj_embs = torch.cat(all_adaface_subj_embs, dim=-2)
|
1144 |
-
|
|
|
|
|
|
|
|
|
|
|
1145 |
|
1146 |
|
1147 |
'''
|
|
|
53 |
self.text_to_image_prompt_encoder = None
|
54 |
self.tokenizer = None
|
55 |
self.dtype = kwargs.get('dtype', torch.float16)
|
56 |
+
self.img2txt_dtype = kwargs.get('img2txt_dtype', torch.float16)
|
57 |
+
self.device = torch.device("cpu")
|
58 |
|
59 |
# Load Img2Ada SubjectBasisGenerator.
|
60 |
self.subject_string = kwargs.get('subject_string', 'z')
|
|
|
75 |
|
76 |
self.use_clip_embs = False
|
77 |
self.do_contrast_clip_embs_on_bg_features = False
|
78 |
+
# Override the default setting in derived classes.
|
79 |
+
if 'enable_static_img_suffix_embs' in kwargs:
|
80 |
+
self.default_enable_static_img_suffix_embs = kwargs['enable_static_img_suffix_embs']
|
81 |
+
|
82 |
# num_id_vecs is the output embeddings of the ID2ImgPrompt module.
|
83 |
# If there's no static image suffix embeddings, then num_id_vecs is also
|
84 |
# the number of ada embeddings returned by the subject basis generator.
|
85 |
# num_id_vecs will be set in each derived class.
|
86 |
self.num_static_img_suffix_embs = kwargs.get('num_static_img_suffix_embs', 0)
|
87 |
+
print(f'{self.name} Adaface uses {self.num_id_vecs} ID image embeddings + {self.num_static_img_suffix_embs} fixed image embeddings as input.')
|
88 |
|
89 |
self.id_img_prompt_max_length = 77
|
90 |
self.face_id_dim = 512
|
|
|
93 |
self.clip_embedding_dim = 1024
|
94 |
self.output_dim = 768
|
95 |
|
96 |
+
# init_img2txt_projection() can only be called after the derived class is initialized,
|
97 |
+
# when self.num_id_vecs0, self.num_static_img_suffix_embs and self.clip_embedding_dim have been set.
|
98 |
+
def init_img2txt_projection(self):
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
99 |
self.subj_basis_generator = \
|
100 |
+
SubjBasisGenerator(dtype=self.img2txt_dtype,
|
101 |
+
num_id_vecs = self.num_id_vecs0,
|
102 |
num_static_img_suffix_embs = self.num_static_img_suffix_embs,
|
103 |
bg_image_embedding_dim = self.clip_embedding_dim,
|
104 |
output_dim = self.output_dim,
|
105 |
placeholder_is_bg = False,
|
|
|
106 |
bg_prompt_translator_has_to_out_proj=False)
|
107 |
|
108 |
def load_adaface_ckpt(self, adaface_ckpt_path):
|
109 |
+
if isinstance(adaface_ckpt_path, (list, tuple, ListConfig)):
|
110 |
+
adaface_ckpt_path = adaface_ckpt_path[0]
|
111 |
+
|
112 |
+
ckpt = torch.load(adaface_ckpt_path, map_location='cpu', weights_only=False)
|
113 |
string_to_subj_basis_generator_dict = ckpt["string_to_subj_basis_generator_dict"]
|
114 |
if self.subject_string not in string_to_subj_basis_generator_dict:
|
115 |
print(f"Subject '{self.subject_string}' not found in the embedding manager.")
|
116 |
breakpoint()
|
117 |
|
118 |
ckpt_subj_basis_generator = string_to_subj_basis_generator_dict[self.subject_string]
|
119 |
+
if isinstance(ckpt_subj_basis_generator, nn.ModuleList):
|
120 |
+
name2idx = { 'consistentID': 0, 'arc2face': 1 }
|
121 |
+
subj_basis_generator_idx = name2idx[self.name]
|
122 |
+
ckpt_subj_basis_generator = ckpt_subj_basis_generator[subj_basis_generator_idx]
|
123 |
+
|
124 |
+
ckpt_subj_basis_generator.N_ID = self.num_id_vecs0
|
125 |
# Since we directly use the subject basis generator object from the ckpt,
|
126 |
# fixing the number of static image suffix embeddings is much simpler.
|
127 |
# Otherwise if we want to load the subject basis generator from its state_dict,
|
|
|
134 |
ckpt_subj_basis_generator.initialize_static_img_suffix_embs(self.num_static_img_suffix_embs, img_prompt_dim=self.output_dim)
|
135 |
# Fix missing variables in old ckpt.
|
136 |
ckpt_subj_basis_generator.patch_old_subj_basis_generator_ckpt()
|
137 |
+
|
138 |
self.subj_basis_generator.extend_prompt2token_proj_attention(\
|
139 |
ckpt_subj_basis_generator.prompt2token_proj_attention_multipliers, -1, -1, 1, perturb_std=0)
|
140 |
ret = self.subj_basis_generator.load_state_dict(ckpt_subj_basis_generator.state_dict(), strict=False)
|
|
|
160 |
|
161 |
self.subj_basis_generator.freeze_prompt2token_proj()
|
162 |
|
163 |
+
def set_out_id_embs_cfg_scale(self, out_id_embs_cfg_scale):
|
164 |
+
if isinstance(out_id_embs_cfg_scale, (list, tuple, ListConfig)):
|
165 |
+
out_id_embs_cfg_scale = out_id_embs_cfg_scale[0]
|
166 |
+
self.out_id_embs_cfg_scale = out_id_embs_cfg_scale
|
167 |
+
|
168 |
@torch.no_grad()
|
169 |
def get_clip_neg_features(self, BS):
|
170 |
if self.clip_neg_features is None:
|
|
|
230 |
image_obj, _, _ = pad_image_obj_to_square(image_obj)
|
231 |
image_np = np.array(image_obj.resize(size, Image.NEAREST))
|
232 |
face_info = self.face_app.get(cv2.cvtColor(image_np, cv2.COLOR_RGB2BGR))
|
233 |
+
|
234 |
if len(face_info) > 0:
|
235 |
face_info = sorted(face_info, key=lambda x:(x['bbox'][2]-x['bbox'][0])*(x['bbox'][3]-x['bbox'][1]))[-1] # only keep the largest face (by bbox area)
|
236 |
# id_emb: [512,]
|
|
|
498 |
# avg_at_stage == ada_prompt_emb usually produces the worst results.
|
499 |
# avg_at_stage == id_emb is slightly better than img_prompt_emb, but sometimes img_prompt_emb is better.
|
500 |
# p_dropout and return_zero_embs_for_dropped_encoders are only used by Joint_FaceID2AdaPrompt.
|
501 |
+
# enable_static_img_suffix_embs=None: use the default setting.
|
502 |
def generate_adaface_embeddings(self, image_paths, face_id_embs=None, img_prompt_embs=None,
|
503 |
p_dropout=0,
|
504 |
return_zero_embs_for_dropped_encoders=True,
|
505 |
avg_at_stage='id_emb', # id_emb, img_prompt_emb, or None.
|
506 |
perturb_at_stage=None, # id_emb, img_prompt_emb, or None.
|
507 |
+
perturb_std=0, enable_static_img_suffix_embs=None):
|
508 |
+
|
509 |
+
if enable_static_img_suffix_embs is None:
|
510 |
+
enable_static_img_suffix_embs = self.default_enable_static_img_suffix_embs
|
511 |
+
|
512 |
+
lens_subj_emb_segments = [ self.num_id_vecs + enable_static_img_suffix_embs \
|
513 |
+
* self.num_static_img_suffix_embs ]
|
514 |
+
|
515 |
if (avg_at_stage is None) or avg_at_stage.lower() == 'none':
|
516 |
img_prompt_avg_at_stage = None
|
517 |
else:
|
|
|
528 |
id_batch_size = len(image_paths)
|
529 |
else:
|
530 |
id_batch_size = 1
|
531 |
+
|
532 |
# faceid_embeds: [BS, 512] is a batch of extracted face analysis embeddings. NOT used later.
|
533 |
# NOTE: If face_id_embs, image_paths and image_objs are all None,
|
534 |
# then get_img_prompt_embs() generates random faceid_embeds/img_prompt_embs,
|
|
|
551 |
verbose=True)
|
552 |
|
553 |
if face_image_count == 0:
|
554 |
+
return None, None, lens_subj_emb_segments
|
555 |
|
556 |
# No matter whether avg_at_stage is id_emb or img_prompt_emb, we average img_prompt_embs.
|
557 |
elif avg_at_stage is not None and avg_at_stage.lower() != 'none':
|
|
|
564 |
out_id_embs_cfg_scale=self.out_id_embs_cfg_scale,
|
565 |
is_face=True,
|
566 |
enable_static_img_suffix_embs=enable_static_img_suffix_embs)
|
567 |
+
|
568 |
+
if self.num_id_vecs < self.num_id_vecs0:
|
569 |
+
adaface_subj_embs = adaface_subj_embs[:, :self.num_id_vecs, :]
|
570 |
+
|
571 |
# During training, img_prompt_avg_at_stage is None, and BS >= 1.
|
572 |
# During inference, img_prompt_avg_at_stage is 'id_emb' or 'img_prompt_emb', and BS == 1.
|
573 |
if img_prompt_avg_at_stage is not None:
|
574 |
# adaface_subj_embs: [1, 16, 768] -> [16, 768]
|
575 |
adaface_subj_embs = adaface_subj_embs.squeeze(0)
|
576 |
|
577 |
+
return adaface_subj_embs, img_prompt_embs, lens_subj_emb_segments
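
generate_adaface_embeddings() now returns a 3-tuple instead of a bare tensor, so every caller must unpack three values. A minimal calling sketch (`encoder` stands for any FaceID2AdaPrompt subclass instance; the image paths are hypothetical):

# Sketch only, not repo code.
subj_embs, img_prompt_embs, seg_lens = encoder.generate_adaface_embeddings(
    image_paths=['face1.jpg', 'face2.jpg'],
    avg_at_stage='id_emb')     # average the two faces at the ID-embedding stage
if subj_embs is None:
    raise ValueError("No face detected in the input images.")
# With avg_at_stage set, the batch dim is squeezed, e.g. [16, 768] for arc2face;
# seg_lens records the number of subject vectors returned, e.g. [16].
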
|
578 |
|
579 |
class Arc2Face_ID2AdaPrompt(FaceID2AdaPrompt):
|
580 |
+
name = 'arc2face'
|
581 |
+
num_id_vecs0 = 16
|
582 |
+
# If compressed, the first 4 vecs are kept and the remaining 12 are averaged into another 4,
|
583 |
+
# then concatenated to [8, 768]. Here num_id_vecs == num_id_vecs0 == 16, i.e., no compression.
|
584 |
+
num_id_vecs = 16
|
585 |
+
default_enable_static_img_suffix_embs = False
|
586 |
|
587 |
+
def __init__(self, *args, **kwargs):
|
588 |
super().__init__(*args, **kwargs)
|
589 |
|
590 |
self.clip_image_encoder = CLIPVisionModelWithMask.from_pretrained('openai/clip-vit-large-patch14')
|
|
|
603 |
'''
|
604 |
# Use the same model as ID2AdaPrompt does.
|
605 |
# FaceAnalysis will try to find the ckpt in: models/insightface/models/antelopev2.
|
606 |
+
# Note there are two "models" in the path.
|
|
607 |
self.face_app = FaceAnalysis(name='antelopev2', root='models/insightface',
|
608 |
providers=['CPUExecutionProvider'])
|
609 |
self.face_app.prepare(ctx_id=0, det_size=(512, 512))
|
610 |
+
print(f'Arc2Face Face encoder loaded on CPU.')
|
611 |
|
612 |
self.text_to_image_prompt_encoder = CLIPTextModelWrapper.from_pretrained(
|
613 |
'models/arc2face', subfolder="encoder",
|
|
|
618 |
if self.out_id_embs_cfg_scale == -1:
|
619 |
self.out_id_embs_cfg_scale = 1
|
620 |
#### Arc2Face pipeline specific configs ####
|
621 |
+
self.gen_neg_img_prompt = False
|
622 |
# bg CLIP features are used by the bg subject basis generator.
|
623 |
+
self.use_clip_embs = True
|
624 |
self.do_contrast_clip_embs_on_bg_features = True
|
625 |
# self.num_static_img_suffix_embs is initialized in the parent class.
|
626 |
+
self.id_img_prompt_max_length = 22
|
627 |
+
self.clip_embedding_dim = 1024
|
628 |
|
629 |
+
self.init_img2txt_projection()
|
630 |
if self.adaface_ckpt_path is not None:
|
631 |
self.load_adaface_ckpt(self.adaface_ckpt_path)
|
632 |
|
633 |
+
for param in self.clip_image_encoder.parameters():
|
634 |
+
param.requires_grad = False
|
635 |
+
for param in self.text_to_image_prompt_encoder.parameters():
|
636 |
+
param.requires_grad = False
|
637 |
+
for param in self.subj_basis_generator.parameters():
|
638 |
+
param.requires_grad = self.is_training
|
639 |
|
640 |
+
print(f"{self.name} ada prompt encoder initialized, "
|
641 |
+
f"ID vecs: {self.num_id_vecs0}, static suffix: {self.num_static_img_suffix_embs}.")
|
642 |
+
|
643 |
+
def _apply(self, fn):
|
644 |
+
super()._apply(fn) # Call the parent _apply to handle parameters and buffers
|
645 |
+
return  # NOTE: this early return leaves the device-reload hack below as dead code.
|
646 |
+
# A dirty hack to get the device of the model, passed from
|
647 |
+
# parent.model.to(self.root_device) => parent._apply(convert) => module._apply(fn)
|
648 |
+
test_tensor = torch.zeros(1) # Create a test tensor
|
649 |
+
transformed_tensor = fn(test_tensor) # Apply `fn()` to test it
|
650 |
+
device = transformed_tensor.device # Get the device of the transformed tensor
|
651 |
+
# No need to reload face_app on the same device.
|
652 |
+
if device == self.device:
|
653 |
+
return
|
654 |
+
|
655 |
+
if str(device) == 'cpu':
|
656 |
+
self.face_app = FaceAnalysis(name='antelopev2', root='models/insightface',
|
657 |
+
providers=['CPUExecutionProvider'])
|
658 |
+
self.face_app.prepare(ctx_id=0, det_size=(512, 512))
|
659 |
+
else:
|
660 |
+
device_id = device.index
|
661 |
+
self.face_app = FaceAnalysis(name='antelopev2', root='models/insightface',
|
662 |
+
providers=['CUDAExecutionProvider'],
|
663 |
+
provider_options=[{"device_id": device_id,
|
664 |
+
"cudnn_conv_algo_search": "HEURISTIC",
|
665 |
+
"gpu_mem_limit": 2 * 1024**3
|
666 |
+
}])
|
667 |
+
self.face_app.prepare(ctx_id=device_id, det_size=(512, 512))
|
668 |
+
|
669 |
+
self.device = device
|
670 |
+
print(f'Arc2Face Face encoder reloaded on {device}.')
|
671 |
+
return
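
The override above works because Module.to() routes every parameter and buffer through the fn passed to _apply(); probing fn with a dummy tensor therefore reveals the target device. A self-contained sketch of the same trick (not the repo's code):

import torch
import torch.nn as nn

class DeviceAware(nn.Module):
    def __init__(self):
        super().__init__()
        self.lin = nn.Linear(4, 4)
        self.device = torch.device('cpu')

    def _apply(self, fn):
        super()._apply(fn)                        # moves parameters/buffers as usual
        self.device = fn(torch.zeros(1)).device   # probe where fn() sends tensors
        return self

m = DeviceAware().to('cuda:0' if torch.cuda.is_available() else 'cpu')
print(m.device)
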
|
672 |
+
|
673 |
# Arc2Face_ID2AdaPrompt never uses clip_features or called_for_neg_img_prompt.
|
674 |
def map_init_id_to_img_prompt_embs(self, init_id_embs,
|
675 |
clip_features=None,
|
|
|
717 |
# [N, 22, 768] -> [N, 16, 768]
|
718 |
return prompt_embeds[:, 4:20]
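
The slice relies on the fixed layout of the 22-token Arc2Face image prompt: per the shape comment above, positions 4..19 carry the 16 ID vectors. A shape-only sketch:

import torch
prompt_embeds = torch.randn(2, 22, 768)   # stand-in for the encoder output
id_embs = prompt_embeds[:, 4:20]          # [2, 16, 768]: the 16 ID vectors
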
|
719 |
|
|
|
720 |
# ConsistentID_ID2AdaPrompt is just a wrapper of ConsistentIDPipeline, so it's not an nn.Module.
|
721 |
class ConsistentID_ID2AdaPrompt(FaceID2AdaPrompt):
|
722 |
+
name = 'consistentID'
|
723 |
+
num_id_vecs0 = 4
|
724 |
+
# No compression for ConsistentID.
|
725 |
+
num_id_vecs = 4
|
726 |
+
default_enable_static_img_suffix_embs = False
|
727 |
+
|
728 |
def __init__(self, pipe=None, base_model_path="models/sd15-dste8-vae.safetensors",
|
729 |
*args, **kwargs):
|
730 |
+
|
|
|
731 |
super().__init__(*args, **kwargs)
|
732 |
if pipe is None:
|
733 |
# The base_model_path is kind of arbitrary, as the UNet and VAE in the model
|
|
|
774 |
self.clip_embedding_dim = 1280
|
775 |
self.s_scale = 1.0
|
776 |
self.shortcut = False
|
777 |
+
|
778 |
+
self.init_img2txt_projection()
|
779 |
if self.adaface_ckpt_path is not None:
|
780 |
self.load_adaface_ckpt(self.adaface_ckpt_path)
|
781 |
|
782 |
+
for param in self.clip_image_encoder.parameters():
|
783 |
+
param.requires_grad = False
|
784 |
+
for param in self.image_proj_model.parameters():
|
785 |
+
param.requires_grad = False
|
786 |
+
for param in self.subj_basis_generator.parameters():
|
787 |
+
param.requires_grad = self.is_training
|
788 |
+
|
789 |
print(f"{self.name} ada prompt encoder initialized, "
|
790 |
+
f"ID vecs: {self.num_id_vecs0}, static suffix: {self.num_static_img_suffix_embs}.")
|
791 |
+
|
792 |
+
def _apply(self, fn):
|
793 |
+
super()._apply(fn) # Call the parent _apply to handle parameters and buffers
|
794 |
+
return  # NOTE: this early return leaves the device-reload hack below as dead code.
|
795 |
+
# A dirty hack to get the device of the model, passed from
|
796 |
+
# parent.model.to(self.root_device) => parent._apply(convert) => module._apply(fn)
|
797 |
+
test_tensor = torch.zeros(1) # Create a test tensor
|
798 |
+
transformed_tensor = fn(test_tensor) # Apply `fn()` to test it
|
799 |
+
device = transformed_tensor.device # Get the device of the transformed tensor
|
800 |
+
# No need to reload face_app on the same device.
|
801 |
+
if device == self.device:
|
802 |
+
return
|
803 |
+
|
804 |
+
if str(device) == 'cpu':
|
805 |
+
self.face_app = FaceAnalysis(name='buffalo_l', root='models/insightface',
|
806 |
+
providers=['CPUExecutionProvider'])
|
807 |
+
self.face_app.prepare(ctx_id=0, det_size=(512, 512))
|
808 |
+
else:
|
809 |
+
device_id = device.index
|
810 |
+
self.face_app = FaceAnalysis(name='buffalo_l', root='models/insightface',
|
811 |
+
providers=['CUDAExecutionProvider'],
|
812 |
+
provider_options=[{"device_id": device_id,
|
813 |
+
"cudnn_conv_algo_search": "HEURISTIC",
|
814 |
+
"gpu_mem_limit": 2 * 1024**3
|
815 |
+
}])
|
816 |
+
self.face_app.prepare(ctx_id=device_id, det_size=(512, 512))
|
817 |
+
|
818 |
+
self.device = device
|
819 |
+
self.pipe.face_app = self.face_app
|
820 |
+
print(f'ConsistentID Face encoder reloaded on {device}.')
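
Whether CUDAExecutionProvider is usable at all depends on the installed onnxruntime build; it can be checked before constructing FaceAnalysis (standard onnxruntime API):

import onnxruntime as ort
# e.g. ['CUDAExecutionProvider', 'CPUExecutionProvider'] for an onnxruntime-gpu build.
print(ort.get_available_providers())
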
|
821 |
+
|
822 |
|
823 |
def map_init_id_to_img_prompt_embs(self, init_id_embs,
|
824 |
clip_features=None,
|
|
|
857 |
|
858 |
return global_id_embeds
|
859 |
|
|
860 |
# A wrapper for combining multiple FaceID2AdaPrompt instances.
|
861 |
class Joint_FaceID2AdaPrompt(FaceID2AdaPrompt):
|
862 |
def __init__(self, adaface_encoder_types, adaface_ckpt_paths,
|
863 |
out_id_embs_cfg_scales=None, enabled_encoders=None,
|
864 |
*args, **kwargs):
|
865 |
self.name = 'jointIDs'
|
866 |
+
name2class = { 'arc2face': Arc2Face_ID2AdaPrompt, 'consistentID': ConsistentID_ID2AdaPrompt }
|
867 |
assert len(adaface_encoder_types) > 0, "adaface_encoder_types should not be empty."
|
868 |
+
adaface_encoder_types2num_id_vecs0 = { name: name2class[name].num_id_vecs0 for name in adaface_encoder_types }
|
869 |
+
adaface_encoder_types2num_id_vecs = { name: name2class[name].num_id_vecs for name in adaface_encoder_types }
|
870 |
+
# self.num_id_vecs0 is used in the parent class. So we need to initialize it here first.
|
871 |
+
self.encoders_num_id_vecs0 = [ adaface_encoder_types2num_id_vecs0[encoder_type] \
|
872 |
for encoder_type in adaface_encoder_types ]
|
873 |
+
self.encoders_num_id_vecs = [ adaface_encoder_types2num_id_vecs[encoder_type] \
|
874 |
+
for encoder_type in adaface_encoder_types ]
|
875 |
+
self.num_id_vecs0 = sum(self.encoders_num_id_vecs0)
|
876 |
+
self.num_id_vecs = sum(self.encoders_num_id_vecs)
|
877 |
+
# super() sets self.is_training.
|
878 |
super().__init__(*args, **kwargs)
|
879 |
|
880 |
self.num_sub_encoders = len(adaface_encoder_types)
|
881 |
self.id2ada_prompt_encoders = nn.ModuleList()
|
882 |
self.encoders_num_static_img_suffix_embs = []
|
883 |
+
self.default_enable_static_img_suffix_embs = []
|
884 |
|
885 |
# TODO: apply adaface_encoder_cfg_scales to influence the final prompt embeddings.
|
886 |
# Now they are just placeholders.
|
|
|
890 |
self.out_id_embs_cfg_scales = [-1] * self.num_sub_encoders
|
891 |
else:
|
892 |
# Do not normalize the weights, and just use them as is.
|
893 |
+
self.out_id_embs_cfg_scales = list(out_id_embs_cfg_scales)
|
894 |
|
895 |
# Note we don't pass the adaface_ckpt_paths to the base class, but instead,
|
896 |
# we load them once and for all in self.load_adaface_ckpt().
|
897 |
+
# NOTE: during inference, num_static_img_suffix_embs is fixed to be 4 for each encoder.
|
898 |
+
# But we can still disable static_img_suffix_embs by setting enable_static_img_suffix_embs to False.
|
899 |
for i, encoder_type in enumerate(adaface_encoder_types):
|
900 |
kwargs['out_id_embs_cfg_scale'] = self.out_id_embs_cfg_scales[i]
|
901 |
if encoder_type == 'arc2face':
|
|
|
904 |
encoder = ConsistentID_ID2AdaPrompt(*args, **kwargs)
|
905 |
else:
|
906 |
breakpoint()
|
907 |
+
|
908 |
self.id2ada_prompt_encoders.append(encoder)
|
909 |
self.encoders_num_static_img_suffix_embs.append(encoder.num_static_img_suffix_embs)
|
910 |
+
self.default_enable_static_img_suffix_embs.append(encoder.default_enable_static_img_suffix_embs)
|
911 |
|
912 |
self.num_static_img_suffix_embs = sum(self.encoders_num_static_img_suffix_embs)
|
913 |
# No need to set gen_neg_img_prompt, as we don't access it in this class, but rather
|
|
|
923 |
# Therefore, the clip_embedding_dim is the sum of the clip_embedding_dims of all adaface encoders.
|
924 |
self.clip_embedding_dims = [encoder.clip_embedding_dim for encoder in self.id2ada_prompt_encoders]
|
925 |
self.clip_embedding_dim = sum(self.clip_embedding_dims)
|
926 |
+
|
927 |
# The ctors of the derived classes have already initialized encoder.subj_basis_generator.
|
928 |
# If subj_basis_generator expansion params are specified, they are equally applied to all adaface encoders.
|
929 |
# This self.subj_basis_generator is not meant to be called as self.subj_basis_generator(), but instead,
|
|
|
931 |
self.subj_basis_generator = \
|
932 |
nn.ModuleList( [encoder.subj_basis_generator for encoder \
|
933 |
in self.id2ada_prompt_encoders] )
|
934 |
+
|
935 |
+
# load_adaface_ckpt() loads into self.subj_basis_generator. So we need to initialize self.subj_basis_generator first.
|
936 |
if adaface_ckpt_paths is not None:
|
937 |
self.load_adaface_ckpt(adaface_ckpt_paths)
|
938 |
+
|
939 |
print(f"{self.name} ada prompt encoder initialized with {self.num_sub_encoders} sub-encoders. "
|
940 |
+
f"ID vecs: {self.num_id_vecs0}, static suffix embs: {self.num_static_img_suffix_embs}.")
|
941 |
|
942 |
if enabled_encoders is not None:
|
943 |
self.are_encoders_enabled = \
|
|
|
953 |
else:
|
954 |
self.are_encoders_enabled = \
|
955 |
torch.tensor([True] * self.num_sub_encoders)
|
956 |
+
|
957 |
def load_adaface_ckpt(self, adaface_ckpt_paths):
|
|
958 |
if isinstance(adaface_ckpt_paths, (list, tuple, ListConfig)):
|
959 |
+
# If multiple adaface ckpt paths are provided, then we assume they are the
|
960 |
+
# ckpts of the sub-encoders.
|
961 |
+
if len(adaface_ckpt_paths) == self.num_sub_encoders:
|
962 |
+
for i, ckpt_path in enumerate(adaface_ckpt_paths):
|
963 |
+
self.id2ada_prompt_encoders[i].load_adaface_ckpt(ckpt_path)
|
964 |
+
return
|
965 |
+
# If only one adaface ckpt path is provided, then we assume it's the ckpt of the Joint_FaceID2AdaPrompt,
|
966 |
+
# so we dereference the list to get the actual path and load the subj_basis_generators of all adaface encoders.
|
967 |
+
elif len(adaface_ckpt_paths) == 1 and self.num_sub_encoders > 1:
|
968 |
adaface_ckpt_paths = adaface_ckpt_paths[0]
|
969 |
+
else:
|
|
|
|
|
970 |
breakpoint()
|
971 |
|
972 |
+
adaface_ckpt_path = adaface_ckpt_paths
|
973 |
+
assert isinstance(adaface_ckpt_path, str), "adaface_ckpt_path should be a string."
|
974 |
+
# This is only applicable to newest ckpts of Joint_FaceID2AdaPrompt, where
|
975 |
+
# the ckpt_subj_basis_generator is an nn.ModuleList of multiple subj_basis_generators.
|
976 |
+
# Therefore, no need to patch missing variables.
|
977 |
+
ckpt = torch.load(adaface_ckpt_paths, map_location='cpu', weights_only=False)
|
978 |
+
string_to_subj_basis_generator_dict = ckpt["string_to_subj_basis_generator_dict"]
|
979 |
+
if self.subject_string not in string_to_subj_basis_generator_dict:
|
980 |
+
print(f"Subject '{self.subject_string}' not found in the embedding manager.")
|
981 |
+
breakpoint()
|
982 |
+
|
983 |
+
ckpt_subj_basis_generators = string_to_subj_basis_generator_dict[self.subject_string]
|
984 |
+
if len(ckpt_subj_basis_generators) != self.num_sub_encoders:
|
985 |
+
print(f"Number of subj_basis_generators in the ckpt ({len(ckpt_subj_basis_generators)}) "
|
986 |
+
f"doesn't match the number of adaface encoders ({self.num_sub_encoders}).")
|
|
|
|
987 |
breakpoint()
|
988 |
|
989 |
+
for i, subj_basis_generator in enumerate(self.subj_basis_generator):
|
990 |
+
ckpt_subj_basis_generator = ckpt_subj_basis_generators[i]
|
991 |
+
# Handle differences in num_static_img_suffix_embs between the current model and the ckpt.
|
992 |
+
ckpt_subj_basis_generator.initialize_static_img_suffix_embs(self.encoders_num_static_img_suffix_embs[i],
|
993 |
+
img_prompt_dim=self.output_dim)
|
994 |
+
|
995 |
+
if subj_basis_generator.prompt2token_proj_attention_multipliers \
|
996 |
+
== [1] * 12:
|
997 |
+
subj_basis_generator.extend_prompt2token_proj_attention(\
|
998 |
+
ckpt_subj_basis_generator.prompt2token_proj_attention_multipliers, -1, -1, 1, perturb_std=0)
|
999 |
+
elif subj_basis_generator.prompt2token_proj_attention_multipliers \
|
1000 |
+
!= ckpt_subj_basis_generator.prompt2token_proj_attention_multipliers:
|
1001 |
+
raise ValueError("Inconsistent prompt2token_proj_attention_multipliers.")
|
1002 |
+
|
1003 |
+
assert subj_basis_generator.prompt2token_proj_attention_multipliers \
|
1004 |
+
== ckpt_subj_basis_generator.prompt2token_proj_attention_multipliers, \
|
1005 |
+
"Inconsistent prompt2token_proj_attention_multipliers."
|
1006 |
+
subj_basis_generator.load_state_dict(ckpt_subj_basis_generator.state_dict())
|
1007 |
+
|
1008 |
+
# extend_prompt2token_proj_attention_multiplier is an integer >= 1.
|
1009 |
+
# TODO: extend_prompt2token_proj_attention_multiplier should be a list of integers.
|
1010 |
+
# If extend_prompt2token_proj_attention_multiplier > 1, then after loading state_dict,
|
1011 |
+
# extend subj_basis_generator again.
|
1012 |
+
if self.extend_prompt2token_proj_attention_multiplier > 1:
|
1013 |
+
# During this extension, the added noise does change the extra copies of attention weights, since they are not in the ckpt.
|
1014 |
+
# During training, prompt2token_proj_ext_attention_perturb_ratio == 0.1.
|
1015 |
+
# During inference, prompt2token_proj_ext_attention_perturb_ratio == 0.
|
1016 |
+
subj_basis_generator.extend_prompt2token_proj_attention(\
|
1017 |
+
None, -1, -1, self.extend_prompt2token_proj_attention_multiplier,
|
1018 |
+
perturb_std=self.prompt2token_proj_ext_attention_perturb_ratio)
|
1019 |
+
|
1020 |
+
subj_basis_generator.freeze_prompt2token_proj()
|
1021 |
+
|
1022 |
+
print(f"{adaface_ckpt_paths}: {len(self.subj_basis_generator)} subj_basis_generators loaded for {self.name}.")
|
1023 |
+
|
1024 |
+
def set_out_id_embs_cfg_scale(self, out_id_embs_cfg_scales):
|
1025 |
+
self.out_id_embs_cfg_scales = list(out_id_embs_cfg_scales)
|
1026 |
+
for i, out_id_embs_cfg_scale in enumerate(out_id_embs_cfg_scales):
|
1027 |
+
self.id2ada_prompt_encoders[i].set_out_id_embs_cfg_scale(out_id_embs_cfg_scale)
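
A one-line usage sketch: the scales are given per sub-encoder, in the same order as adaface_encoder_types (values are illustrative):

joint.set_out_id_embs_cfg_scale([6.0, 1.0])   # e.g. consistentID, arc2face
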
|
1028 |
+
|
1029 |
def extract_init_id_embeds_from_images(self, *args, **kwargs):
|
1030 |
total_faceless_img_count = 0
|
1031 |
all_id_embs = []
|
|
|
1163 |
|
1164 |
N_ID = self.encoders_num_id_vecs[i]
|
1165 |
if all_pos_prompt_embs[i] is None:
|
1166 |
+
# Both pos_prompt_embs and neg_prompt_embs have N_ID == num_id_vecs0 embeddings.
|
1167 |
all_pos_prompt_embs[i] = torch.zeros((BS, N_ID, 768), dtype=torch.float16, device=device)
|
1168 |
if all_neg_prompt_embs[i] is None:
|
1169 |
all_neg_prompt_embs[i] = torch.zeros((BS, N_ID, 768), dtype=torch.float16, device=device)
|
|
|
1185 |
# So its .device is the device of its parameters.
|
1186 |
device = self.id2ada_prompt_encoders[0].clip_image_encoder.device
|
1187 |
is_emb_averaged = kwargs.get('avg_at_stage', None) is not None
|
1188 |
+
if kwargs.get('enable_static_img_suffix_embs', None) is None:
|
1189 |
+
enable_static_img_suffix_embs = self.default_enable_static_img_suffix_embs
|
1190 |
+
else:
|
1191 |
+
enable_static_img_suffix_embs = kwargs['enable_static_img_suffix_embs']
|
1192 |
+
if isinstance(enable_static_img_suffix_embs, bool):
|
1193 |
+
enable_static_img_suffix_embs = [enable_static_img_suffix_embs] * self.num_sub_encoders
|
1194 |
+
|
1195 |
BS = -1
|
1196 |
|
1197 |
if face_id_embs is not None:
|
|
|
1199 |
all_face_id_embs = face_id_embs.split(self.face_id_dims, dim=1)
|
1200 |
else:
|
1201 |
all_face_id_embs = [None] * self.num_sub_encoders
|
1202 |
+
|
1203 |
if img_prompt_embs is not None:
|
1204 |
BS = img_prompt_embs.shape[0] if BS == -1 else BS
|
1205 |
+
if img_prompt_embs.shape[1] != self.num_id_vecs0:
|
1206 |
breakpoint()
|
1207 |
+
all_img_prompt_embs = img_prompt_embs.split(self.encoders_num_id_vecs0, dim=1)
|
1208 |
+
img_prompt_embs_provided = True
|
1209 |
else:
|
1210 |
all_img_prompt_embs = [None] * self.num_sub_encoders
|
1211 |
+
img_prompt_embs_provided = False
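
The fused tensors are split back per sub-encoder with torch.split(), which accepts a list of segment sizes. A shape sketch (segment sizes follow the num_id_vecs0 class attributes; encoder order follows adaface_encoder_types):

import torch
encoders_num_id_vecs0 = [4, 16]            # e.g. consistentID, arc2face
img_prompt_embs = torch.randn(2, sum(encoders_num_id_vecs0), 768)
per_encoder = img_prompt_embs.split(encoders_num_id_vecs0, dim=1)
print([t.shape for t in per_encoder])      # [2, 4, 768] and [2, 16, 768]
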
|
1212 |
+
|
1213 |
if image_paths is not None:
|
1214 |
BS = len(image_paths) if BS == -1 else BS
|
1215 |
if BS == -1:
|
|
|
1232 |
else:
|
1233 |
are_encoders_enabled = self.are_encoders_enabled
|
1234 |
|
1235 |
+
self.curr_are_encoders_enabled = are_encoders_enabled
|
1236 |
all_adaface_subj_embs = []
|
1237 |
num_available_id_vecs = 0
|
1238 |
+
lens_subj_emb_segments = []
|
1239 |
|
1240 |
for i, id2ada_prompt_encoder in enumerate(self.id2ada_prompt_encoders):
|
1241 |
if not are_encoders_enabled[i]:
|
1242 |
adaface_subj_embs = None
|
1243 |
+
print(f"Encoder {id2ada_prompt_encoder.name} is disabled.")
|
1244 |
+
N_ID = id2ada_prompt_encoder.num_id_vecs + enable_static_img_suffix_embs[i] \
|
1245 |
+
* id2ada_prompt_encoder.num_static_img_suffix_embs
|
1246 |
else:
|
1247 |
+
kwargs['enable_static_img_suffix_embs'] = enable_static_img_suffix_embs[i]
|
1248 |
# ddpm.embedding_manager.train() -> id2ada_prompt_encoder.train() -> each sub-encoder's train().
|
1249 |
# -> each sub-encoder's subj_basis_generator.train().
|
1250 |
# Therefore grad for the following call is enabled.
|
1251 |
+
adaface_subj_embs, img_prompt_embs, encoder_lens_subj_emb_segments = \
|
1252 |
id2ada_prompt_encoder.generate_adaface_embeddings(image_paths,
|
1253 |
all_face_id_embs[i],
|
1254 |
all_img_prompt_embs[i],
|
1255 |
*args, **kwargs)
|
1256 |
|
1257 |
+
# adaface_subj_embs: arc2face [16, 768] or consistentID [4, 768],
|
1258 |
+
# or arc2face [20, 768] or consistentID [8, 768] if enable_static_img_suffix_embs=True.
|
1259 |
+
N_ID = encoder_lens_subj_emb_segments[0]
|
1260 |
+
|
1261 |
if adaface_subj_embs is None:
|
1262 |
if not return_zero_embs_for_dropped_encoders:
|
1263 |
continue
|
|
|
1268 |
all_adaface_subj_embs.append(adaface_subj_embs)
|
1269 |
else:
|
1270 |
all_adaface_subj_embs.append(adaface_subj_embs)
|
1271 |
+
if not img_prompt_embs_provided:
|
1272 |
+
all_img_prompt_embs[i] = img_prompt_embs
|
1273 |
num_available_id_vecs += N_ID
|
1274 |
+
|
1275 |
+
lens_subj_emb_segments.append(N_ID)
|
1276 |
+
|
1277 |
# No faces are found in the images, so return None embeddings.
|
1278 |
# We don't want to return an all-zero embedding, which is useless.
|
1279 |
if num_available_id_vecs == 0:
|
1280 |
+
return None, None, [0]  # keep the 3-tuple signature: (subj_embs, img_prompt_embs, lens_subj_emb_segments)
|
1281 |
|
1282 |
# If id2ada_prompt_encoders are ["arc2face", "consistentID"], then
|
1283 |
# during inference, we average across the batch dim.
|
|
|
1287 |
# all_adaface_subj_embs[0]: [BS, 4, 768]. all_adaface_subj_embs[1]: [BS, 16, 768].
|
1288 |
# all_adaface_subj_embs: [BS, 20, 768].
|
1289 |
all_adaface_subj_embs = torch.cat(all_adaface_subj_embs, dim=-2)
|
1290 |
+
# Check if some of the img_prompt_embs are None.
|
1291 |
+
if None in all_img_prompt_embs:
|
1292 |
+
all_img_prompt_embs = None
|
1293 |
+
else:
|
1294 |
+
all_img_prompt_embs = torch.cat(all_img_prompt_embs, dim=-2)
|
1295 |
+
return all_adaface_subj_embs, all_img_prompt_embs, lens_subj_emb_segments
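
A shape-level sketch of the joint return value for the default two-encoder setup (tensors are stand-ins):

import torch
consistentID_embs = torch.randn(1, 4, 768)
arc2face_embs     = torch.randn(1, 16, 768)
joint_subj_embs = torch.cat([consistentID_embs, arc2face_embs], dim=-2)  # [1, 20, 768]
lens_subj_emb_segments = [4, 16]   # lets callers slice joint_subj_embs per encoder
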
|
1296 |
|
1297 |
|
1298 |
'''
|
adaface/subj_basis_generator.py
CHANGED
@@ -9,7 +9,7 @@ import torch
|
|
9 |
from torch import nn
|
10 |
from einops import rearrange
|
11 |
from einops.layers.torch import Rearrange
|
12 |
-
from transformers import CLIPTokenizer, CLIPTextModel
|
13 |
|
14 |
from torch import einsum
|
15 |
from adaface.util import gen_gradient_scaler
|
@@ -57,7 +57,25 @@ class IP_MLPProjModel(nn.Module):
|
|
57 |
x = x.reshape(-1, self.num_tokens, self.cross_attention_dim)
|
58 |
x = self.norm(x)
|
59 |
return x
|
60 |
-
|
|
|
|
61 |
# group_dim: the tensor dimension that corresponds to the multiple groups.
|
62 |
class LearnedSoftAggregate(nn.Module):
|
63 |
def __init__(self, num_feat, group_dim, keepdim=False):
|
@@ -349,23 +367,26 @@ class CrossAttention(nn.Module):
|
|
349 |
else:
|
350 |
return out
|
351 |
|
|
|
352 |
class ImgPrompt2TextPrompt(nn.Module):
|
353 |
-
def __init__(self, placeholder_is_bg, num_id_vecs,
|
|
|
354 |
super().__init__()
|
355 |
self.N_ID = num_id_vecs
|
356 |
# If not placeholder_is_bg, then N_SFX will be updated in initialize_text_components().
|
357 |
self.N_SFX = 0
|
|
|
358 |
|
359 |
if not placeholder_is_bg:
|
360 |
-
self.
|
|
|
361 |
|
362 |
# prompt2token_proj: arc2face_models.py CLIPTextModelWrapper instance with **custom weights**.
|
363 |
# prompt2token_proj is with the same architecture as the original arc2face text encoder,
|
364 |
# but retrained to do inverse mapping.
|
365 |
# To be initialized in the subclass.
|
366 |
self.prompt2token_proj = None
|
367 |
-
|
368 |
-
|
369 |
def initialize_static_img_suffix_embs(self, num_static_img_suffix_embs, img_prompt_dim=768):
|
370 |
self.N_SFX = num_static_img_suffix_embs
|
371 |
# We always take the first num_static_img_suffix_embs embeddings out of static_img_suffix_embs.
|
@@ -376,11 +397,11 @@ class ImgPrompt2TextPrompt(nn.Module):
|
|
376 |
print(f"static_img_suffix_embs had been initialized to be {self.static_img_suffix_embs.shape[1]} vecs ({self.N_SFX} required). Skip initialization.")
|
377 |
elif self.static_img_suffix_embs.shape[1] < self.N_SFX:
|
378 |
print(f"static_img_suffix_embs had been initialized to be {self.static_img_suffix_embs.shape[1]} vecs (< {self.N_SFX} required). Reinitialize.")
|
379 |
-
self.static_img_suffix_embs = nn.Parameter(torch.randn(1, self.N_SFX, img_prompt_dim))
|
380 |
elif self.N_SFX > 0:
|
381 |
# self.static_img_suffix_embs.shape[1] > self.N_SFX > 0.
|
382 |
print(f"static_img_suffix_embs had been initialized to be {self.static_img_suffix_embs.shape[1]} vecs (> {self.N_SFX} required). Truncate.")
|
383 |
-
self.static_img_suffix_embs = nn.Parameter(self.static_img_suffix_embs[:, :self.N_SFX])
|
384 |
else:
|
385 |
# self.static_img_suffix_embs.shape[1] > self.N_SFX == 0.
|
386 |
print(f"static_img_suffix_embs had been initialized to be {self.static_img_suffix_embs.shape[1]} vecs (0 required). Erase.")
|
@@ -391,7 +412,7 @@ class ImgPrompt2TextPrompt(nn.Module):
|
|
391 |
# or it's initialized but has fewer than num_static_img_suffix_embs embeddings (this situation should be very rare,
|
392 |
# so we don't consider to reuse and extend a shorter static_img_suffix_embs).
|
393 |
# So we reinitialize it.
|
394 |
-
self.static_img_suffix_embs = nn.Parameter(torch.randn(1, self.N_SFX, img_prompt_dim))
|
395 |
else:
|
396 |
# If static_img_suffix_embs had been initialized, then it will be set to None, i.e., erased from the SubjBasisGenerator instance.
|
397 |
self.static_img_suffix_embs = None
|
@@ -399,9 +420,7 @@ class ImgPrompt2TextPrompt(nn.Module):
|
|
399 |
# Implement a separate initialization function, so that it can be called from SubjBasisGenerator
|
400 |
# after the SubjBasisGenerator is initialized. This can be used to fix old SubjBasisGenerator
|
401 |
# ckpts which were not subclassed from ImgPrompt2TextPrompt.
|
402 |
-
def initialize_text_components(self, max_prompt_length=77,
|
403 |
-
num_static_img_suffix_embs=0, img_prompt_dim=768):
|
404 |
-
self.initialize_static_img_suffix_embs(num_static_img_suffix_embs, img_prompt_dim)
|
405 |
self.max_prompt_length = max_prompt_length
|
406 |
self.tokenizer = CLIPTokenizer.from_pretrained("openai/clip-vit-large-patch14")
|
407 |
# clip_text_embeddings: CLIPTextEmbeddings instance.
|
@@ -416,7 +435,7 @@ class ImgPrompt2TextPrompt(nn.Module):
|
|
416 |
# pad_embeddings is still on CPU. But should be moved to GPU automatically.
|
417 |
# Note: detach pad_embeddings from the computation graph, otherwise
|
418 |
# deepcopy() in embedding_manager.py:make_frozen_copy_of_subj_basis_generators() will fail.
|
419 |
-
self.pad_embeddings = clip_text_embeddings(pad_tokens)[0].detach()
|
420 |
|
421 |
# image prompt space -> text prompt space.
|
422 |
# return_emb_types: a list of strings, each string is among
|
@@ -439,7 +458,7 @@ class ImgPrompt2TextPrompt(nn.Module):
|
|
439 |
else:
|
440 |
breakpoint()
|
441 |
else:
|
442 |
-
# len(face_prompt_embs) == 1, this occurs when same_subject_in_batch == True, e.g. in
|
443 |
# But list_extra_words always corresponds to the actual batch size. So we only take the first element.
|
444 |
list_extra_words = list_extra_words[:1]
|
445 |
|
@@ -466,7 +485,7 @@ class ImgPrompt2TextPrompt(nn.Module):
|
|
466 |
face_prompt_embs_orig_dtype = face_prompt_embs.dtype
|
467 |
face_prompt_embs = face_prompt_embs.to(self.dtype)
|
468 |
|
469 |
-
ID_END = 4
|
470 |
PAD_BEGIN = ID_END + self.N_SFX + 2
|
471 |
|
472 |
# token_embs: [1, 77, 768]. This call is only to get the template token embeddings (the shallowest mapping).
|
@@ -545,6 +564,7 @@ class ImgPrompt2TextPrompt(nn.Module):
|
|
545 |
class SubjBasisGenerator(ImgPrompt2TextPrompt):
|
546 |
def __init__(
|
547 |
self,
|
|
|
548 |
# number of cross-attention heads of the bg prompt translator.
|
549 |
# Taken as a half of the number of heads 12 of OpenAI clip-vit-large-patch14:
|
550 |
# https://huggingface.co/openai/clip-vit-large-patch14/blob/main/config.json
|
@@ -553,22 +573,25 @@ class SubjBasisGenerator(ImgPrompt2TextPrompt):
|
|
553 |
# or number of background input identity vectors (no matter the subject is face or not).
|
554 |
# 257: 257 CLIP tokens.
|
555 |
num_nonface_in_id_vecs={ 'subj': 77, 'bg': 257 },
|
|
|
556 |
num_id_vecs=16, # num_id_vecs: subj: 16. bg: 4.
|
557 |
num_static_img_suffix_embs: int = 0, # Number of extra static learnable image embeddings appended to translated ID embeddings.
|
558 |
bg_image_embedding_dim=1024, # CLIP image hidden layer feature dimension, as per config.json above.
|
559 |
obj_embedding_dim=384, # DINO object feature dimension for objects.
|
560 |
output_dim=768, # CLIP text embedding input dimension.
|
|
|
561 |
placeholder_is_bg: bool = False, # Whether the placeholder is for the image background tokens.
|
562 |
-
prompt2token_proj_grad_scale: float = 0.4, # Gradient scale for prompt2token_proj.
|
563 |
learnable_hidden_state_weights_scheme: str = 'per-layer', # none, per-layer.
|
564 |
-
bg_prompt_translator_has_to_out_proj:
|
565 |
):
|
566 |
|
567 |
# If not placeholder_is_bg, then it calls initialize_text_components() in the superclass.
|
568 |
-
super().__init__(placeholder_is_bg=placeholder_is_bg, num_id_vecs=num_id_vecs,
|
569 |
-
num_static_img_suffix_embs=num_static_img_suffix_embs,
|
|
|
570 |
|
571 |
self.placeholder_is_bg = placeholder_is_bg
|
|
|
572 |
self.num_out_embs = self.N_ID + self.N_SFX
|
573 |
self.output_dim = output_dim
|
574 |
# num_nonface_in_id_vecs should be the number of core ID embs, 16.
|
@@ -586,14 +609,18 @@ class SubjBasisGenerator(ImgPrompt2TextPrompt):
|
|
586 |
# self.prompt2token_proj: [1, 16, 768] -> [1, 77, 768] (with paddings) or [1, 16, 768] (without paddings).
|
587 |
# If self.placeholder_is_bg: prompt2token_proj is set to None.
|
588 |
# Use an attention dropout of 0.2 to increase robustness.
|
589 |
-
|
590 |
-
self.prompt2token_proj
|
591 |
-
|
592 |
-
|
593 |
-
|
594 |
-
|
595 |
-
|
596 |
-
|
|
|
|
|
|
|
|
|
597 |
self.freeze_prompt2token_proj()
|
598 |
|
599 |
# These multipliers are relative to the original CLIPTextModel.
|
@@ -631,6 +658,9 @@ class SubjBasisGenerator(ImgPrompt2TextPrompt):
|
|
631 |
identity_to_out=identity_to_out,
|
632 |
out_has_skip=out_has_skip)
|
633 |
|
|
|
634 |
self.output_scale = output_dim ** -0.5
|
635 |
|
636 |
'''
|
@@ -686,21 +716,20 @@ class SubjBasisGenerator(ImgPrompt2TextPrompt):
|
|
686 |
hidden_state_layer_weights = self.hidden_state_layer_weights_grad_scaler(self.hidden_state_layer_weights)
|
687 |
|
688 |
# faceid2img_prompt_embs -> ada_id_embs: image prompt space -> text prompt space.
|
689 |
-
|
690 |
-
|
691 |
-
|
692 |
-
|
693 |
-
|
694 |
-
|
695 |
-
|
696 |
-
|
697 |
-
|
698 |
-
|
699 |
-
|
700 |
-
|
701 |
-
|
702 |
-
|
703 |
-
ada_id_embs = self.prompt2token_proj_grad_scaler(ada_id_embs)
|
704 |
elif raw_id_embs is not None:
|
705 |
# id_embs: [BS, 384] -> [BS, 18, 768].
|
706 |
# obj_proj_in is expected to project the DINO object features to
|
@@ -726,14 +755,15 @@ class SubjBasisGenerator(ImgPrompt2TextPrompt):
|
|
726 |
|
727 |
adaface_out_embs = id_embs_out * self.output_scale # * 0.036
|
728 |
else:
|
729 |
-
|
|
|
730 |
# If out_id_embs_cfg_scale < 1, adaface_out_embs is a mix of adaface_out_embs and pad_embeddings.
|
731 |
if out_id_embs_cfg_scale != 1:
|
732 |
-
# pad_embeddings: [77, 768] -> [16, 768] -> [1, 16, 768].
|
733 |
# NOTE: Never do cfg on static image suffix embeddings.
|
734 |
# So we take self.N_ID embeddings, instead of self.N_ID + self.N_SFX,
|
735 |
# even if enable_static_img_suffix_embs=True.
|
736 |
-
pad_embeddings = self.pad_embeddings[4:4+self.N_ID].unsqueeze(0).to(ada_id_embs.device)
|
737 |
adaface_out_embs[:, :self.N_ID] = ada_id_embs[:, :self.N_ID] * out_id_embs_cfg_scale \
|
738 |
+ pad_embeddings * (1 - out_id_embs_cfg_scale)
|
739 |
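
The two lines above implement a CFG-style linear blend: a scale of 1 keeps the full ID signal, while a scale of 0 collapses the subject tokens to plain padding embeddings. A numeric sketch with stand-in tensors:

import torch
scale = 0.8                               # out_id_embs_cfg_scale
ada_id_embs    = torch.randn(1, 16, 768)
pad_embeddings = torch.randn(1, 16, 768)
mixed = ada_id_embs * scale + pad_embeddings * (1 - scale)
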
|
@@ -812,37 +842,37 @@ class SubjBasisGenerator(ImgPrompt2TextPrompt):
|
|
812 |
# Only applicable to fg basis generator.
|
813 |
if self.placeholder_is_bg:
|
814 |
return
|
815 |
-
|
816 |
-
# Then we don't have to check whether it's for subj or bg.
|
817 |
-
if self.prompt2token_proj_grad_scale == 0:
|
818 |
-
frozen_components_name = 'all'
|
819 |
-
frozen_param_set = self.prompt2token_proj.named_parameters()
|
820 |
-
else:
|
821 |
-
frozen_components_name = 'token_pos_embeddings'
|
822 |
-
frozen_param_set = self.prompt2token_proj.text_model.embeddings.named_parameters()
|
823 |
-
|
824 |
if self.prompt2token_proj is not None:
|
825 |
frozen_param_names = []
|
826 |
-
for param_name, param in frozen_param_set:
|
827 |
if param.requires_grad:
|
828 |
param.requires_grad = False
|
829 |
frozen_param_names.append(param_name)
|
830 |
# If param is already frozen, then no need to freeze it again.
|
831 |
-
print(f"{
|
832 |
#print(f"Frozen parameters:\n{frozen_param_names}")
|
833 |
|
834 |
def patch_old_subj_basis_generator_ckpt(self):
|
835 |
# Fix compatibility with the previous version.
|
836 |
if not hasattr(self, 'bg_prompt_translator_has_to_out_proj'):
|
837 |
self.bg_prompt_translator_has_to_out_proj = False
|
838 |
-
if not hasattr(self, 'num_out_embs'):
|
839 |
-
self.num_out_embs = -1
|
840 |
if hasattr(self, 'num_id_vecs') and not hasattr(self, 'N_ID'):
|
841 |
self.N_ID = self.num_id_vecs
|
|
|
842 |
if not hasattr(self, 'num_nonface_in_id_vecs') and hasattr(self, 'N_ID'):
|
843 |
self.num_nonface_in_id_vecs = self.N_ID
|
844 |
if not hasattr(self, 'dtype'):
|
845 |
-
self.dtype = torch.
|
|
|
846 |
|
847 |
if self.placeholder_is_bg:
|
848 |
if not hasattr(self, 'pos_embs') or self.pos_embs is None:
|
@@ -860,6 +890,14 @@ class SubjBasisGenerator(ImgPrompt2TextPrompt):
|
|
860 |
num_static_img_suffix_embs=self.N_SFX,
|
861 |
img_prompt_dim=self.output_dim)
|
862 |
|
|
|
|
|
|
|
863 |
def __repr__(self):
|
864 |
type_sig = 'subj' if not self.placeholder_is_bg else 'bg'
|
865 |
|
|
|
9 |
from torch import nn
|
10 |
from einops import rearrange
|
11 |
from einops.layers.torch import Rearrange
|
12 |
+
from transformers import CLIPTokenizer, CLIPTextModel
|
13 |
|
14 |
from torch import einsum
|
15 |
from adaface.util import gen_gradient_scaler
|
|
|
57 |
x = x.reshape(-1, self.num_tokens, self.cross_attention_dim)
|
58 |
x = self.norm(x)
|
59 |
return x
|
60 |
+
|
61 |
+
class LayerwiseMLPProjWithSkip(nn.Module):
|
62 |
+
def __init__(self, id_embeddings_dim=768, num_layers=16, dim_mult=2):
|
63 |
+
super().__init__()
|
64 |
+
|
65 |
+
self.proj = nn.Sequential(
|
66 |
+
nn.Linear(id_embeddings_dim, id_embeddings_dim*dim_mult*num_layers),
|
67 |
+
Rearrange('b n (l d) -> b n l d', l=num_layers, d=id_embeddings_dim*dim_mult),
|
68 |
+
nn.GELU(),
|
69 |
+
nn.Linear(id_embeddings_dim*dim_mult, id_embeddings_dim),
|
70 |
+
)
|
71 |
+
self.norm = nn.LayerNorm(id_embeddings_dim)
|
72 |
+
|
73 |
+
def forward(self, id_embeds):
|
74 |
+
# B N D -> B N L D  +  B N 1 D (broadcast over the layer dim) -> B N L D
|
75 |
+
x = self.proj(id_embeds) + id_embeds.unsqueeze(2)  # unsqueeze(2) so the skip broadcasts over the L (layer) dim
|
76 |
+
x = self.norm(x)
|
77 |
+
return x
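
A shape sketch through LayerwiseMLPProjWithSkip, assuming the broadcast fix above (one projected copy of each ID embedding per cross-attention layer):

import torch
proj = LayerwiseMLPProjWithSkip(id_embeddings_dim=768, num_layers=16, dim_mult=2)
id_embeds = torch.randn(2, 4, 768)    # [B, N, D]
out = proj(id_embeds)                 # [B, N, L, D] = [2, 4, 16, 768]
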
|
78 |
+
|
79 |
# group_dim: the tensor dimension that corresponds to the multiple groups.
|
80 |
class LearnedSoftAggregate(nn.Module):
|
81 |
def __init__(self, num_feat, group_dim, keepdim=False):
|
|
|
367 |
else:
|
368 |
return out
|
369 |
|
370 |
+
|
371 |
class ImgPrompt2TextPrompt(nn.Module):
|
372 |
+
def __init__(self, placeholder_is_bg, num_id_vecs, num_static_img_suffix_embs,
|
373 |
+
max_prompt_length=77, img_prompt_dim=768, dtype=torch.float16):
|
374 |
super().__init__()
|
375 |
self.N_ID = num_id_vecs
|
376 |
# If not placeholder_is_bg, then N_SFX will be updated in initialize_text_components().
|
377 |
self.N_SFX = 0
|
378 |
+
self.dtype = dtype
|
379 |
|
380 |
if not placeholder_is_bg:
|
381 |
+
self.initialize_static_img_suffix_embs(num_static_img_suffix_embs, img_prompt_dim)
|
382 |
+
self.initialize_text_components(max_prompt_length)
|
383 |
|
384 |
# prompt2token_proj: arc2face_models.py CLIPTextModelWrapper instance with **custom weights**.
|
385 |
# prompt2token_proj is with the same architecture as the original arc2face text encoder,
|
386 |
# but retrained to do inverse mapping.
|
387 |
# To be initialized in the subclass.
|
388 |
self.prompt2token_proj = None
|
389 |
+
|
|
|
390 |
def initialize_static_img_suffix_embs(self, num_static_img_suffix_embs, img_prompt_dim=768):
|
391 |
self.N_SFX = num_static_img_suffix_embs
|
392 |
# We always take the first num_static_img_suffix_embs embeddings out of static_img_suffix_embs.
|
|
|
397 |
print(f"static_img_suffix_embs had been initialized to be {self.static_img_suffix_embs.shape[1]} vecs ({self.N_SFX} required). Skip initialization.")
|
398 |
elif self.static_img_suffix_embs.shape[1] < self.N_SFX:
|
399 |
print(f"static_img_suffix_embs had been initialized to be {self.static_img_suffix_embs.shape[1]} vecs (< {self.N_SFX} required). Reinitialize.")
|
400 |
+
self.static_img_suffix_embs = nn.Parameter(torch.randn(1, self.N_SFX, img_prompt_dim, dtype=self.dtype))
|
401 |
elif self.N_SFX > 0:
|
402 |
# self.static_img_suffix_embs.shape[1] > self.N_SFX > 0.
|
403 |
print(f"static_img_suffix_embs had been initialized to be {self.static_img_suffix_embs.shape[1]} vecs (> {self.N_SFX} required). Truncate.")
|
404 |
+
self.static_img_suffix_embs = nn.Parameter(self.static_img_suffix_embs[:, :self.N_SFX].to(dtype=self.dtype))
|
405 |
else:
|
406 |
# self.static_img_suffix_embs.shape[1] > self.N_SFX == 0.
|
407 |
print(f"static_img_suffix_embs had been initialized to be {self.static_img_suffix_embs.shape[1]} vecs (0 required). Erase.")
|
|
|
412 |
# or it's initialized but has fewer than num_static_img_suffix_embs embeddings (this situation should be very rare,
|
413 |
# so we don't consider to reuse and extend a shorter static_img_suffix_embs).
|
414 |
# So we reinitialize it.
|
415 |
+
self.static_img_suffix_embs = nn.Parameter(torch.randn(1, self.N_SFX, img_prompt_dim, dtype=self.dtype))
|
416 |
else:
|
417 |
# If static_img_suffix_embs had been initialized, then it will be set to None, i.e., erased from the SubjBasisGenerator instance.
|
418 |
self.static_img_suffix_embs = None
|
|
|
420 |
# Implement a separate initialization function, so that it can be called from SubjBasisGenerator
|
421 |
# after the SubjBasisGenerator is initialized. This can be used to fix old SubjBasisGenerator
|
422 |
# ckpts which were not subclassed from ImgPrompt2TextPrompt.
|
423 |
+
def initialize_text_components(self, max_prompt_length=77):
|
|
|
|
|
424 |
self.max_prompt_length = max_prompt_length
|
425 |
self.tokenizer = CLIPTokenizer.from_pretrained("openai/clip-vit-large-patch14")
|
426 |
# clip_text_embeddings: CLIPTextEmbeddings instance.
|
|
|
435 |
# pad_embeddings is still on CPU. But should be moved to GPU automatically.
|
436 |
# Note: detach pad_embeddings from the computation graph, otherwise
|
437 |
# deepcopy() in embedding_manager.py:make_frozen_copy_of_subj_basis_generators() will fail.
|
438 |
+
self.pad_embeddings = clip_text_embeddings(pad_tokens)[0].detach().to(self.dtype)
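
The pad embeddings come straight from CLIP's own embedding table. A sketch of producing such a [77, 768] pad-embedding tensor with HuggingFace transformers (the repo builds pad_tokens elsewhere, so the tokenization step here is an assumption):

import torch
from transformers import CLIPTokenizer, CLIPTextModel

tokenizer  = CLIPTokenizer.from_pretrained("openai/clip-vit-large-patch14")
text_model = CLIPTextModel.from_pretrained("openai/clip-vit-large-patch14")
# An empty prompt tokenizes to BOS, EOS, then pad tokens up to max_length.
pad_tokens = tokenizer("", padding="max_length", max_length=77,
                       return_tensors="pt").input_ids
# CLIPTextEmbeddings returns token + positional embeddings: [1, 77, 768].
pad_embeddings = text_model.text_model.embeddings(pad_tokens)[0].detach()
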
|
439 |
|
440 |
# image prompt space -> text prompt space.
|
441 |
# return_emb_types: a list of strings, each string is among
|
|
|
458 |
else:
|
459 |
breakpoint()
|
460 |
else:
|
461 |
+
# len(face_prompt_embs) == 1, this occurs when same_subject_in_batch == True, e.g. in do_feat_distill_on_comp_prompt.
|
462 |
# But list_extra_words always corresponds to the actual batch size. So we only take the first element.
|
463 |
list_extra_words = list_extra_words[:1]
|
464 |
|
|
|
485 |
face_prompt_embs_orig_dtype = face_prompt_embs.dtype
|
486 |
face_prompt_embs = face_prompt_embs.to(self.dtype)
|
487 |
|
488 |
+
ID_END = 4 + self.N_ID
|
489 |
PAD_BEGIN = ID_END + self.N_SFX + 2
|
490 |
|
491 |
# token_embs: [1, 77, 768]. This call is only to get the template token embeddings (the shallowest mapping).
|
|
|
564 |
class SubjBasisGenerator(ImgPrompt2TextPrompt):
|
565 |
def __init__(
|
566 |
self,
|
567 |
+
dtype=torch.float16,
|
568 |
# number of cross-attention heads of the bg prompt translator.
|
569 |
# Taken as a half of the number of heads 12 of OpenAI clip-vit-large-patch14:
|
570 |
# https://huggingface.co/openai/clip-vit-large-patch14/blob/main/config.json
|
|
|
573 |
# or number of background input identity vectors (no matter the subject is face or not).
|
574 |
# 257: 257 CLIP tokens.
|
575 |
num_nonface_in_id_vecs={ 'subj': 77, 'bg': 257 },
|
576 |
+
num_ca_layers=16,
|
577 |
num_id_vecs=16, # num_id_vecs: subj: 16. bg: 4.
|
578 |
num_static_img_suffix_embs: int = 0, # Number of extra static learnable image embeddings appended to translated ID embeddings.
|
579 |
bg_image_embedding_dim=1024, # CLIP image hidden layer feature dimension, as per config.json above.
|
580 |
obj_embedding_dim=384, # DINO object feature dimension for objects.
|
581 |
output_dim=768, # CLIP text embedding input dimension.
|
582 |
+
use_layerwise_proj: bool = False, # Whether to use layerwise projection.
|
583 |
placeholder_is_bg: bool = False, # Whether the placeholder is for the image background tokens.
|
|
|
584 |
learnable_hidden_state_weights_scheme: str = 'per-layer', # none, per-layer.
|
585 |
+
bg_prompt_translator_has_to_out_proj: bool = False, # Whether the prompt_trans_layers have a to_out projection.
|
586 |
):
|
587 |
|
588 |
# If not placeholder_is_bg, then it calls initialize_text_components() in the superclass.
|
589 |
+
super().__init__(placeholder_is_bg=placeholder_is_bg, num_id_vecs=num_id_vecs,
|
590 |
+
num_static_img_suffix_embs=num_static_img_suffix_embs,
|
591 |
+
max_prompt_length=77, img_prompt_dim=output_dim, dtype=dtype)
|
592 |
|
593 |
self.placeholder_is_bg = placeholder_is_bg
|
594 |
+
self.num_ca_layers = num_ca_layers
|
595 |
self.num_out_embs = self.N_ID + self.N_SFX
|
596 |
self.output_dim = output_dim
|
597 |
# num_nonface_in_id_vecs should be the number of core ID embs, 16.
|
|
|
609 |
# self.prompt2token_proj: [1, 16, 768] -> [1, 77, 768] (with paddings) or [1, 16, 768] (without paddings).
|
610 |
# If self.placeholder_is_bg: prompt2token_proj is set to None.
|
611 |
# Use an attention dropout of 0.2 to increase robustness.
|
612 |
+
self.prompt2token_proj = CLIPTextModelWrapper.from_pretrained('openai/clip-vit-large-patch14')
|
613 |
+
self.prompt2token_proj.to(dtype=self.dtype)
|
614 |
+
|
615 |
+
if use_layerwise_proj:
|
616 |
+
# MLPProjWithSkip: MLP with skip connection.
|
617 |
+
# [BS, 4, 768] -> [BS, 16, 4, 768]. Extra 16: 16 layers.
|
618 |
+
self.layerwise_proj = LayerwiseMLPProjWithSkip(output_dim, dim_mult=2)
|
619 |
+
else:
|
620 |
+
self.layerwise_proj = nn.Identity() #Rearrange('b n d -> b l n d', l=16)
|
621 |
+
|
622 |
+
print(f"Subj prompt2token_proj initialized.")
|
623 |
+
# Only freeze token and positional embeddings of the original CLIPTextModel.
|
624 |
self.freeze_prompt2token_proj()
|
625 |
|
626 |
# These multipliers are relative to the original CLIPTextModel.
|
|
|
658 |
identity_to_out=identity_to_out,
|
659 |
out_has_skip=out_has_skip)
|
660 |
|
661 |
+
if self.dtype == torch.float16:
|
662 |
+
self.prompt_translator = self.prompt_translator.half()
|
663 |
+
|
664 |
self.output_scale = output_dim ** -0.5
|
665 |
|
666 |
'''
|
|
|
716 |
hidden_state_layer_weights = self.hidden_state_layer_weights_grad_scaler(self.hidden_state_layer_weights)
|
717 |
|
718 |
# faceid2img_prompt_embs -> ada_id_embs: image prompt space -> text prompt space.
|
719 |
+
# inverse_img_prompt_embs() applies self.prompt2token_proj to faceid2img_prompt_embs.
|
720 |
+
# If list_extra_words is not None, then ada_id_embs: [BS, 18, 768], three leading words, the 16 identity tokens
|
721 |
+
# and (at most) two extra words in adaface_prompt_embs, without BOS and EOS.
|
722 |
+
# If list_extra_words is None, then ada_id_embs: [BS, 16, 768], the 16 identity tokens in adaface_prompt_embs.
|
723 |
+
# hidden_state_layer_weights: [[0.9163], [0.9483], [2.0762]]
|
724 |
+
# ada_id_embs: [BS, 16, 768].
|
725 |
+
# return_emb_types: a list of strings, each string is among
|
726 |
+
# ['full', 'core', 'full_pad', 'full_half_pad'].
|
727 |
+
ada_id_embs, = \
|
728 |
+
self.inverse_img_prompt_embs(faceid2img_prompt_embs,
|
729 |
+
list_extra_words=None,
|
730 |
+
return_emb_types=['core'],
|
731 |
+
hidden_state_layer_weights=hidden_state_layer_weights,
|
732 |
+
enable_static_img_suffix_embs=enable_static_img_suffix_embs)
|
|
|
733 |
elif raw_id_embs is not None:
|
734 |
# id_embs: [BS, 384] -> [BS, 18, 768].
|
735 |
# obj_proj_in is expected to project the DINO object features to
|
|
|
755 |
|
756 |
adaface_out_embs = id_embs_out * self.output_scale # * 0.036
|
757 |
else:
|
758 |
+
# [BS, 16, 768] -> [BS, layers=16, tokens=16, 768]
|
759 |
+
adaface_out_embs = self.layerwise_proj(ada_id_embs)
|
760 |
# If out_id_embs_cfg_scale < 1, adaface_out_embs is a mix of adaface_out_embs and pad_embeddings.
|
761 |
if out_id_embs_cfg_scale != 1:
|
762 |
+
# pad_embeddings: [77, 768] -> [16, 768] -> [1, 1, 16, 768].
|
763 |
# NOTE: Never do cfg on static image suffix embeddings.
|
764 |
# So we take self.N_ID embeddings, instead of self.N_ID + self.N_SFX,
|
765 |
# even if enable_static_img_suffix_embs=True.
|
766 |
+
pad_embeddings = self.pad_embeddings[4:4+self.N_ID].unsqueeze(0).unsqueeze(1).to(ada_id_embs.device)
|
767 |
adaface_out_embs[:, :self.N_ID] = ada_id_embs[:, :self.N_ID] * out_id_embs_cfg_scale \
|
768 |
+ pad_embeddings * (1 - out_id_embs_cfg_scale)
|
769 |
|
|
|
842 |
# Only applicable to fg basis generator.
|
843 |
if self.placeholder_is_bg:
|
844 |
return
|
845 |
+
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
846 |
if self.prompt2token_proj is not None:
|
847 |
frozen_param_names = []
|
848 |
+
for param_name, param in self.prompt2token_proj.text_model.embeddings.named_parameters():
|
849 |
if param.requires_grad:
|
850 |
param.requires_grad = False
|
851 |
frozen_param_names.append(param_name)
|
852 |
# If param is already frozen, then no need to freeze it again.
|
853 |
+
print(f"{len(frozen_param_names)} params of token_pos_embeddings in Subj prompt2token_proj is frozen.")
|
854 |
#print(f"Frozen parameters:\n{frozen_param_names}")
|
855 |
|
856 |
def patch_old_subj_basis_generator_ckpt(self):
|
857 |
# Fix compatability with the previous version.
|
858 |
if not hasattr(self, 'bg_prompt_translator_has_to_out_proj'):
|
859 |
self.bg_prompt_translator_has_to_out_proj = False
|
|
|
|
|
860 |
if hasattr(self, 'num_id_vecs') and not hasattr(self, 'N_ID'):
|
861 |
self.N_ID = self.num_id_vecs
|
862 |
+
# Update the number of output embeddings.
|
863 |
+
self.num_out_embs = self.N_ID + self.N_SFX
|
864 |
+
|
865 |
if not hasattr(self, 'num_nonface_in_id_vecs') and hasattr(self, 'N_ID'):
|
866 |
self.num_nonface_in_id_vecs = self.N_ID
|
867 |
if not hasattr(self, 'dtype'):
|
868 |
+
self.dtype = torch.float16
|
869 |
+
if not self.placeholder_is_bg:
|
870 |
+
self.prompt2token_proj.to(dtype=self.dtype)
|
871 |
+
else:
|
872 |
+
self.prompt_translator.half()
|
873 |
+
|
874 |
+
if not hasattr(self, 'num_ca_layers'):
|
875 |
+
self.num_ca_layers = 16
|
876 |
|
877 |
if self.placeholder_is_bg:
|
878 |
if not hasattr(self, 'pos_embs') or self.pos_embs is None:
|
|
|
890 |
num_static_img_suffix_embs=self.N_SFX,
|
891 |
img_prompt_dim=self.output_dim)
|
892 |
|
893 |
+
if not hasattr(self, 'use_layerwise_proj'):
|
894 |
+
self.use_layerwise_proj = False
|
895 |
+
if not hasattr(self, 'layerwise_proj'):
|
896 |
+
if self.use_layerwise_proj:
|
897 |
+
self.layerwise_proj = LayerwiseMLPProjWithSkip(self.output_dim, dim_mult=2)
|
898 |
+
else:
|
899 |
+
self.layerwise_proj = nn.Identity()
|
900 |
+
|
901 |
def __repr__(self):
|
902 |
type_sig = 'subj' if not self.placeholder_is_bg else 'bg'
|
903 |
|
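A note on the out_id_embs_cfg_scale mixing above: when the scale is below 1, the translated ID tokens are linearly blended with the fixed CLIP pad embeddings, which weakens the identity signal in a controlled way. Below is a minimal sketch of just that step, assuming the tensor shapes stated in the diff's comments; all tensors and names (N_ID, scale) are illustrative stand-ins, not repo code.

import torch

# Sketch of the out_id_embs_cfg_scale mixing (shapes taken from the comments
# in the diff above; the tensors here are random stand-ins).
N_ID, D = 16, 768
ada_id_embs    = torch.randn(2, N_ID, D)   # [BS, 16, 768] translated ID embeddings
pad_embeddings = torch.randn(N_ID, D)      # the [4:4+N_ID] slice of the 77 pad embeddings

scale = 0.8                                # out_id_embs_cfg_scale < 1
mixed = ada_id_embs * scale + pad_embeddings.unsqueeze(0) * (1 - scale)
print(mixed.shape)                         # torch.Size([2, 16, 768])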
adaface/unet_teachers.py CHANGED
@@ -1,6 +1,6 @@
 import torch
+from torch import nn
 import numpy as np
-import pytorch_lightning as pl
 from diffusers import UNet2DConditionModel
 from adaface.util import UNetEnsemble, create_consistentid_pipeline
 from diffusers import UNet2DConditionModel
@@ -12,9 +12,9 @@ def create_unet_teacher(teacher_type, device='cpu', **kwargs):
         teacher_type = teacher_type[0]

     if teacher_type == "arc2face":
+        teacher = Arc2FaceTeacher(**kwargs)
     elif teacher_type == "unet_ensemble":
-        # unet, extra_unet_dirpaths and
+        # unet, extra_unet_dirpaths and unet_weights_in_ensemble are passed in kwargs.
         # Even if we distill from unet_ensemble, we still need to load arc2face for generating
         # arc2face embeddings.
         # The first (optional) ctor param of UNetEnsembleTeacher is an instantiated unet,
@@ -22,20 +22,24 @@
         # However, since the __call__ method of the ddpm unet takes different formats of params,
         # for simplicity, we still use the diffusers unet.
         # unet_teacher is put on CPU first, then moved to GPU when DDPM is moved to GPU.
+        teacher = UNetEnsembleTeacher(device=device, **kwargs)
     elif teacher_type == "consistentID":
+        teacher = ConsistentIDTeacher(**kwargs)
     elif teacher_type == "simple_unet":
+        teacher = SimpleUNetTeacher(**kwargs)
     # Since we've dereferenced the list if it has only one element,
     # this holding implies the list has more than one element. Therefore it's UNetEnsembleTeacher.
     elif isinstance(teacher_type, (tuple, list, ListConfig)):
         # teacher_type is a list of teacher types. So it's UNetEnsembleTeacher.
+        teacher = UNetEnsembleTeacher(unet_types=teacher_type, device=device, **kwargs)
     else:
         raise NotImplementedError(f"Teacher type {teacher_type} not implemented.")

+    for param in teacher.parameters():
+        param.requires_grad = False
+    return teacher

-class UNetTeacher(pl.LightningModule):
+class UNetTeacher(nn.Module):
     def __init__(self, **kwargs):
         super().__init__()
         self.name = None
@@ -56,9 +60,10 @@ class UNetTeacher(pl.LightningModule):
     # to be initialized, which will unnecessarily complicate the code.
     # noise: the initial noise for the first iteration.
     # t: the initial t. We will sample additional (num_denoising_steps - 1) smaller t.
-    #
-    def forward(self, ddpm_model, x_start, noise, t, teacher_context,
-                num_denoising_steps=1,
+    # same_t_noise_across_instances: when sampling t and noise, use the same t and noise for all instances.
+    def forward(self, ddpm_model, x_start, noise, t, teacher_context, negative_context=None,
+                num_denoising_steps=1, same_t_noise_across_instances=False,
+                global_t_lb=0, global_t_ub=1000):
         assert num_denoising_steps <= 10

         if self.p_uses_cfg > 0:
@@ -71,27 +76,22 @@

             if self.uses_cfg:
                 print(f"Teacher samples CFG scale {self.cfg_scale:.1f}.")
+                if negative_context is not None:
+                    negative_context = negative_context[:1].repeat(x_start.shape[0], 1, 1)
+
+                # if negative_context is None, then teacher_context is a combination of
+                # (one or multiple if unet_ensemble) pos_context and neg_context.
+                # If negative_context is not None, then teacher_context is only pos_context.
             else:
                 self.cfg_scale = 1
                 print("Teacher does not use CFG.")

-                # If
-                for teacher_context_i in teacher_context:
-                    pos_context, neg_context = torch.chunk(teacher_context_i, 2, dim=0)
-                    if pos_context.shape[0] != x_start.shape[0]:
-                        breakpoint()
-                    teacher_pos_contexts.append(pos_context)
-                teacher_context = teacher_pos_contexts
-            else:
-                pos_context, neg_context = torch.chunk(teacher_context, 2, dim=0)
-                if pos_context.shape[0] != x_start.shape[0]:
-                    breakpoint()
-                teacher_context = pos_context
+            # If negative_context is None, then teacher_context is a combination of
+            # (one or multiple if unet_ensemble) pos_context and neg_context.
+            # Since not uses_cfg, we only need pos_context.
+            # If negative_context is not None, then teacher_context is only pos_context.
+            if negative_context is None:
+                teacher_context = self.extract_pos_context(teacher_context, x_start.shape[0])
         else:
             # p_uses_cfg = 0. Never use CFG.
             self.uses_cfg = False
@@ -102,15 +102,21 @@
             # in case someday we want to switch from CFG to non-CFG during runtime.
             self.cfg_scale = 1

+        is_context_doubled = 2 if (self.uses_cfg and negative_context is None) else 1
         if self.name == 'unet_ensemble':
             # teacher_context is a list of teacher contexts.
             for teacher_context_i in teacher_context:
-                if teacher_context_i.shape[0] != x_start.shape[0] *
+                if teacher_context_i.shape[0] != x_start.shape[0] * is_context_doubled:
                     breakpoint()
         else:
-            if teacher_context.shape[0] != x_start.shape[0] *
+            if teacher_context.shape[0] != x_start.shape[0] * is_context_doubled:
                 breakpoint()
+
+        if same_t_noise_across_instances:
+            # If same_t_noise_across_instances, we use the same t and noise for all instances.
+            t = t[0].repeat(x_start.shape[0])
+            noise = noise[:1].repeat(x_start.shape[0], 1, 1, 1)
+
         # Initially, x_starts only contains the original x_start.
         x_starts = [ x_start ]
         noises = [ noise ]
@@ -125,24 +131,35 @@
             # sqrt_alphas_cumprod[t] * x_start + sqrt_one_minus_alphas_cumprod[t] * noise
             x_noisy = ddpm_model.q_sample(x_start, t, noise)

-            if self.uses_cfg:
+            if self.uses_cfg and self.cfg_scale > 1 and negative_context is None:
                 x_noisy2 = x_noisy.repeat(2, 1, 1, 1)
                 t2 = t.repeat(2)
             else:
                 x_noisy2 = x_noisy
-                t2
+                t2 = t

             # If do_arc2face_distill, then pos_context is [BS=6, 21, 768].
             noise_pred = self.unet(sample=x_noisy2, timestep=t2, encoder_hidden_states=teacher_context,
                                    return_dict=False)[0]
             if self.uses_cfg and self.cfg_scale > 1:
+                if negative_context is None:
+                    pos_noise_pred, neg_noise_pred = torch.chunk(noise_pred, 2, dim=0)
+                else:
+                    # If negative_context is not None, then teacher_context is only pos_context.
+                    pos_noise_pred = noise_pred
+                    with torch.no_grad():
+                        if self.name == 'unet_ensemble':
+                            neg_noise_pred = self.unet.unets[0](sample=x_noisy, timestep=t,
+                                                                encoder_hidden_states=negative_context, return_dict=False)[0]
+                        else:
+                            neg_noise_pred = self.unet(sample=x_noisy, timestep=t,
+                                                       encoder_hidden_states=negative_context, return_dict=False)[0]
+
                 noise_pred = pos_noise_pred * self.cfg_scale - neg_noise_pred * (self.cfg_scale - 1)

-            # sqrt_recip_alphas_cumprod[t] * x_t - sqrt_recipm1_alphas_cumprod[t] * noise
-            pred_x0 = ddpm_model.predict_start_from_noise(x_noisy, t, noise_pred)
             noise_preds.append(noise_pred)
+            # sqrt_recip_alphas_cumprod[t] * x_t - sqrt_recipm1_alphas_cumprod[t] * noise
+            pred_x0 = ddpm_model.predict_start_from_noise(x_noisy, t, noise_pred)
             # The predicted x0 is used as the x_start for the next denoising step.
             x_starts.append(pred_x0)
@@ -157,20 +174,43 @@
             # of the current timestep.
             t_lb = t * np.power(0.5, np.power(num_denoising_steps - 1, -0.3))
             t_ub = t * np.power(0.7, np.power(num_denoising_steps - 1, -0.3))
+            t_lb = torch.clamp(t_lb, min=global_t_lb)
+            t_ub = torch.clamp(t_ub, max=global_t_ub)
             earlier_timesteps = (t_ub - t_lb) * relative_ts + t_lb
             earlier_timesteps = earlier_timesteps.long()
+            noise = torch.randn_like(pred_x0)

-            if
-                # If
+            if same_t_noise_across_instances:
+                # If same_t_noise_across_instances, we use the same earlier_timesteps and noise for all instances.
                 earlier_timesteps = earlier_timesteps[0].repeat(x_start.shape[0])
+                noise = noise[:1].repeat(x_start.shape[0], 1, 1, 1)

             # earlier_timesteps = ts[i+1] < ts[i].
             ts.append(earlier_timesteps)
-
-            noise = torch.randn_like(pred_x0)
             noises.append(noise)

         return noise_preds, x_starts, noises, ts
+
+    def extract_pos_context(self, teacher_context, BS):
+        # If p_uses_cfg > 0, we always pass both pos_context and neg_context to the teacher.
+        # But the neg_context is only used when self.uses_cfg is True and cfg_scale > 1.
+        # So we manually split the teacher_context into pos_context and neg_context, and only keep pos_context.
+        if self.name == 'unet_ensemble':
+            teacher_pos_contexts = []
+            # teacher_context is a list of teacher contexts.
+            for teacher_context_i in teacher_context:
+                pos_context, neg_context = torch.chunk(teacher_context_i, 2, dim=0)
+                if pos_context.shape[0] != BS:
+                    breakpoint()
+                teacher_pos_contexts.append(pos_context)
+            teacher_context = teacher_pos_contexts
+        else:
+            pos_context, neg_context = torch.chunk(teacher_context, 2, dim=0)
+            if pos_context.shape[0] != BS:
+                breakpoint()
+            teacher_context = pos_context
+
+        return teacher_context

 class Arc2FaceTeacher(UNetTeacher):
     def __init__(self, **kwargs):
@@ -185,11 +225,11 @@ class Arc2FaceTeacher(UNetTeacher):
         self.cfg_scale_range = [1, 1]

 class UNetEnsembleTeacher(UNetTeacher):
-    #
-    def __init__(self, unets, unet_types, extra_unet_dirpaths,
+    # unet_weights_in_ensemble are not model weights, but scalar weights for individual unets.
+    def __init__(self, unets, unet_types, extra_unet_dirpaths, unet_weights_in_ensemble=None, device='cuda', **kwargs):
         super().__init__(**kwargs)
         self.name = "unet_ensemble"
-        self.unet = UNetEnsemble(unets, unet_types, extra_unet_dirpaths,
+        self.unet = UNetEnsemble(unets, unet_types, extra_unet_dirpaths, unet_weights_in_ensemble, device)

 class ConsistentIDTeacher(UNetTeacher):
     def __init__(self, base_model_path="models/sd15-dste8-vae.safetensors", **kwargs):
@@ -199,12 +239,9 @@ class ConsistentIDTeacher(UNetTeacher):
         # In contrast to Arc2FaceTeacher or UNetEnsembleTeacher, ConsistentIDPipeline is not a torch.nn.Module.
         # We couldn't initialize the ConsistentIDPipeline to CPU first and wait it to be automatically moved to GPU.
         # Instead, we have to initialize it to GPU directly.
-        # Compatible with the UNetTeacher interface.
-        self.unet = pipe.unet
-        # Release VAE and text_encoder to save memory. UNet is still needed for denoising
+        # Release VAE and text_encoder to save memory. UNet is needed for denoising
         # (the unet is implemented in diffusers in fp16, so probably faster than the LDM unet).
+        self.unet = create_consistentid_pipeline(base_model_path, unet_only=True)

     # We use the default cfg_scale_range=[1.3, 2] for SimpleUNetTeacher.
     # Note p_uses_cfg=0.5 will also be passed in in kwargs.
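Two formulas in the teacher's forward() above are worth unpacking: the CFG combination of positive/negative noise predictions, and the annealed sampling of strictly smaller timesteps for each extra denoising step. Here is a self-contained sketch of both, with made-up tensors and values; nothing here is repo code.

import numpy as np
import torch

# 1) Teacher CFG: extrapolate away from the negative prediction.
#    pos * s - neg * (s - 1) is the same as neg + s * (pos - neg).
cfg_scale = 2.0
pos_noise_pred = torch.randn(4, 4, 64, 64)
neg_noise_pred = torch.randn(4, 4, 64, 64)
noise_pred = pos_noise_pred * cfg_scale - neg_noise_pred * (cfg_scale - 1)

# 2) Timestep annealing: each extra step samples a smaller t from
#    [t * 0.5**k, t * 0.7**k], where k = (num_denoising_steps - 1) ** -0.3,
#    optionally clamped to [global_t_lb, global_t_ub].
t = torch.tensor([800., 700., 650., 750.])
num_denoising_steps = 4
k = np.power(num_denoising_steps - 1, -0.3)
t_lb = torch.clamp(t * np.power(0.5, k), min=0)
t_ub = torch.clamp(t * np.power(0.7, k), max=1000)
relative_ts = torch.rand(4)
earlier_timesteps = ((t_ub - t_lb) * relative_ts + t_lb).long()

Since 0.7**k < 1, the sampled earlier_timesteps are always below the current t, so repeated iterations walk the teacher down the noise schedule.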
adaface/util.py CHANGED
@@ -57,7 +57,7 @@ def perturb_np_array(np_array, perturb_std, perturb_std_is_relative=True, std_dim
     ts = perturb_tensor(ts, perturb_std, perturb_std_is_relative, std_dim=std_dim)
     return ts.numpy().astype(np_array.dtype)

-def calc_stats(emb_name, embeddings, mean_dim
+def calc_stats(emb_name, embeddings, mean_dim=-1):
     print("%s:" %emb_name)
     repeat_count = [1] * embeddings.ndim
     repeat_count[mean_dim] = embeddings.shape[mean_dim]
@@ -153,13 +153,14 @@ def pad_image_obj_to_square(image_obj, new_size=-1):

 class UNetEnsemble(nn.Module):
     # The first unet is the unet already loaded in a pipeline.
-    def __init__(self, unets, unet_types, extra_unet_dirpaths,
+    def __init__(self, unets, unet_types, extra_unet_dirpaths, unet_weights_in_ensemble=None, device='cuda', torch_dtype=torch.float16):
         super().__init__()

-        self.unets = nn.ModuleList()
         if unets is not None:
+            unets = [ unet.to(device) for unet in unets ]
+        else:
+            unets = []
+
         if unet_types is not None:
             for unet_type in unet_types:
                 if unet_type == "arc2face":
@@ -169,25 +170,27 @@
                     unet = create_consistentid_pipeline(unet_only=True)
                 else:
                     breakpoint()
+                unets.append(unet.to(device=device))

         if extra_unet_dirpaths is not None:
             for unet_path in extra_unet_dirpaths:
                 unet = UNet2DConditionModel.from_pretrained(unet_path, torch_dtype=torch_dtype)
+                unets.append(unet.to(device=device))

-        if
-        elif len(
-        elif len(
+        if unet_weights_in_ensemble is None:
+            unet_weights_in_ensemble = [1.] * len(unets)
+        elif len(unets) < len(unet_weights_in_ensemble):
+            unet_weights_in_ensemble = unet_weights_in_ensemble[:len(unets)]
+        elif len(unets) > len(unet_weights_in_ensemble):
             breakpoint()

-        self.unet_weights = nn.Parameter(unet_weights, requires_grad=False)
+        unet_weights_in_ensemble = torch.tensor(unet_weights_in_ensemble, dtype=torch_dtype)
+        unet_weights_in_ensemble = unet_weights_in_ensemble / unet_weights_in_ensemble.sum()

+        self.unets = nn.ModuleList(unets)
+        # Put the weights in a Parameter so that they will be moved to the same device as the model.
+        self.unet_weights_in_ensemble = nn.Parameter(unet_weights_in_ensemble, requires_grad=False)
+        print(f"UNetEnsemble: {len(self.unets)} UNets loaded with weights: {self.unet_weights_in_ensemble.data.cpu().numpy()}")
         # Set these fields to be compatible with diffusers.
         self.dtype = self.unets[0].dtype
         self.device = self.unets[0].device
@@ -215,8 +218,8 @@
             samples.append(sample)

         samples = torch.stack(samples, dim=0)
-        sample = (samples *
+        unet_weights_in_ensemble = self.unet_weights_in_ensemble.reshape(-1, *([1] * (samples.ndim - 1)))
+        sample = (samples * unet_weights_in_ensemble).sum(dim=0)

         if not return_dict:
             return (sample,)
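UNetEnsemble's forward pass above stacks the per-UNet noise predictions and combines them with normalized scalar weights. A standalone sketch of just that weighting, with illustrative tensor shapes (not repo code):

import torch

# Two unets with relative weights 1 and 2, normalized to sum to 1.
weights = torch.tensor([1.0, 2.0])
weights = weights / weights.sum()          # -> tensor([0.3333, 0.6667])

# Stack per-unet predictions: [num_unets, BS, C, H, W].
samples = torch.stack([torch.randn(2, 4, 64, 64) for _ in weights], dim=0)

# Reshape weights to [num_unets, 1, 1, 1, 1] so they broadcast over each sample.
w = weights.reshape(-1, *([1] * (samples.ndim - 1)))
sample = (samples * w).sum(dim=0)
print(sample.shape)                        # torch.Size([2, 4, 64, 64])

Storing the normalized weights in a non-trainable nn.Parameter (as the diff does) is a convenient way to have them follow the module across .to(device) / .half() calls.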
animatediff/sd/.gitattributes DELETED
@@ -1,35 +0,0 @@
-*.7z filter=lfs diff=lfs merge=lfs -text
-*.arrow filter=lfs diff=lfs merge=lfs -text
-*.bin filter=lfs diff=lfs merge=lfs -text
-*.bz2 filter=lfs diff=lfs merge=lfs -text
-*.ftz filter=lfs diff=lfs merge=lfs -text
-*.gz filter=lfs diff=lfs merge=lfs -text
-*.h5 filter=lfs diff=lfs merge=lfs -text
-*.joblib filter=lfs diff=lfs merge=lfs -text
-*.lfs.* filter=lfs diff=lfs merge=lfs -text
-*.mlmodel filter=lfs diff=lfs merge=lfs -text
-*.model filter=lfs diff=lfs merge=lfs -text
-*.msgpack filter=lfs diff=lfs merge=lfs -text
-*.npy filter=lfs diff=lfs merge=lfs -text
-*.npz filter=lfs diff=lfs merge=lfs -text
-*.onnx filter=lfs diff=lfs merge=lfs -text
-*.ot filter=lfs diff=lfs merge=lfs -text
-*.parquet filter=lfs diff=lfs merge=lfs -text
-*.pb filter=lfs diff=lfs merge=lfs -text
-*.pickle filter=lfs diff=lfs merge=lfs -text
-*.pkl filter=lfs diff=lfs merge=lfs -text
-*.pt filter=lfs diff=lfs merge=lfs -text
-*.pth filter=lfs diff=lfs merge=lfs -text
-*.rar filter=lfs diff=lfs merge=lfs -text
-*.safetensors filter=lfs diff=lfs merge=lfs -text
-saved_model/**/* filter=lfs diff=lfs merge=lfs -text
-*.tar.* filter=lfs diff=lfs merge=lfs -text
-*.tflite filter=lfs diff=lfs merge=lfs -text
-*.tgz filter=lfs diff=lfs merge=lfs -text
-*.wasm filter=lfs diff=lfs merge=lfs -text
-*.xz filter=lfs diff=lfs merge=lfs -text
-*.zip filter=lfs diff=lfs merge=lfs -text
-*.zst filter=lfs diff=lfs merge=lfs -text
-*tfevents* filter=lfs diff=lfs merge=lfs -text
-v1-5-pruned-emaonly.ckpt filter=lfs diff=lfs merge=lfs -text
-v1-5-pruned.ckpt filter=lfs diff=lfs merge=lfs -text
animatediff/sd/feature_extractor/preprocessor_config.json DELETED
@@ -1,20 +0,0 @@
-{
-  "crop_size": 224,
-  "do_center_crop": true,
-  "do_convert_rgb": true,
-  "do_normalize": true,
-  "do_resize": true,
-  "feature_extractor_type": "CLIPFeatureExtractor",
-  "image_mean": [
-    0.48145466,
-    0.4578275,
-    0.40821073
-  ],
-  "image_std": [
-    0.26862954,
-    0.26130258,
-    0.27577711
-  ],
-  "resample": 3,
-  "size": 224
-}
animatediff/sd/model_index.json DELETED
@@ -1,32 +0,0 @@
-{
-  "_class_name": "StableDiffusionPipeline",
-  "_diffusers_version": "0.6.0",
-  "feature_extractor": [
-    "transformers",
-    "CLIPImageProcessor"
-  ],
-  "safety_checker": [
-    "stable_diffusion",
-    "StableDiffusionSafetyChecker"
-  ],
-  "scheduler": [
-    "diffusers",
-    "PNDMScheduler"
-  ],
-  "text_encoder": [
-    "transformers",
-    "CLIPTextModel"
-  ],
-  "tokenizer": [
-    "transformers",
-    "CLIPTokenizer"
-  ],
-  "unet": [
-    "diffusers",
-    "UNet2DConditionModel"
-  ],
-  "vae": [
-    "diffusers",
-    "AutoencoderKL"
-  ]
-}
animatediff/sd/safety_checker/config.json DELETED
@@ -1,175 +0,0 @@
-{
-  "_commit_hash": "4bb648a606ef040e7685bde262611766a5fdd67b",
-  "_name_or_path": "CompVis/stable-diffusion-safety-checker",
-  "architectures": [
-    "StableDiffusionSafetyChecker"
-  ],
-  "initializer_factor": 1.0,
-  "logit_scale_init_value": 2.6592,
-  "model_type": "clip",
-  "projection_dim": 768,
-  "text_config": {
-    "_name_or_path": "",
-    "add_cross_attention": false,
-    "architectures": null,
-    "attention_dropout": 0.0,
-    "bad_words_ids": null,
-    "bos_token_id": 0,
-    "chunk_size_feed_forward": 0,
-    "cross_attention_hidden_size": null,
-    "decoder_start_token_id": null,
-    "diversity_penalty": 0.0,
-    "do_sample": false,
-    "dropout": 0.0,
-    "early_stopping": false,
-    "encoder_no_repeat_ngram_size": 0,
-    "eos_token_id": 2,
-    "exponential_decay_length_penalty": null,
-    "finetuning_task": null,
-    "forced_bos_token_id": null,
-    "forced_eos_token_id": null,
-    "hidden_act": "quick_gelu",
-    "hidden_size": 768,
-    "id2label": {
-      "0": "LABEL_0",
-      "1": "LABEL_1"
-    },
-    "initializer_factor": 1.0,
-    "initializer_range": 0.02,
-    "intermediate_size": 3072,
-    "is_decoder": false,
-    "is_encoder_decoder": false,
-    "label2id": {
-      "LABEL_0": 0,
-      "LABEL_1": 1
-    },
-    "layer_norm_eps": 1e-05,
-    "length_penalty": 1.0,
-    "max_length": 20,
-    "max_position_embeddings": 77,
-    "min_length": 0,
-    "model_type": "clip_text_model",
-    "no_repeat_ngram_size": 0,
-    "num_attention_heads": 12,
-    "num_beam_groups": 1,
-    "num_beams": 1,
-    "num_hidden_layers": 12,
-    "num_return_sequences": 1,
-    "output_attentions": false,
-    "output_hidden_states": false,
-    "output_scores": false,
-    "pad_token_id": 1,
-    "prefix": null,
-    "problem_type": null,
-    "pruned_heads": {},
-    "remove_invalid_values": false,
-    "repetition_penalty": 1.0,
-    "return_dict": true,
-    "return_dict_in_generate": false,
-    "sep_token_id": null,
-    "task_specific_params": null,
-    "temperature": 1.0,
-    "tf_legacy_loss": false,
-    "tie_encoder_decoder": false,
-    "tie_word_embeddings": true,
-    "tokenizer_class": null,
-    "top_k": 50,
-    "top_p": 1.0,
-    "torch_dtype": null,
-    "torchscript": false,
-    "transformers_version": "4.22.0.dev0",
-    "typical_p": 1.0,
-    "use_bfloat16": false,
-    "vocab_size": 49408
-  },
-  "text_config_dict": {
-    "hidden_size": 768,
-    "intermediate_size": 3072,
-    "num_attention_heads": 12,
-    "num_hidden_layers": 12
-  },
-  "torch_dtype": "float32",
-  "transformers_version": null,
-  "vision_config": {
-    "_name_or_path": "",
-    "add_cross_attention": false,
-    "architectures": null,
-    "attention_dropout": 0.0,
-    "bad_words_ids": null,
-    "bos_token_id": null,
-    "chunk_size_feed_forward": 0,
-    "cross_attention_hidden_size": null,
-    "decoder_start_token_id": null,
-    "diversity_penalty": 0.0,
-    "do_sample": false,
-    "dropout": 0.0,
-    "early_stopping": false,
-    "encoder_no_repeat_ngram_size": 0,
-    "eos_token_id": null,
-    "exponential_decay_length_penalty": null,
-    "finetuning_task": null,
-    "forced_bos_token_id": null,
-    "forced_eos_token_id": null,
-    "hidden_act": "quick_gelu",
-    "hidden_size": 1024,
-    "id2label": {
-      "0": "LABEL_0",
-      "1": "LABEL_1"
-    },
-    "image_size": 224,
-    "initializer_factor": 1.0,
-    "initializer_range": 0.02,
-    "intermediate_size": 4096,
-    "is_decoder": false,
-    "is_encoder_decoder": false,
-    "label2id": {
-      "LABEL_0": 0,
-      "LABEL_1": 1
-    },
-    "layer_norm_eps": 1e-05,
-    "length_penalty": 1.0,
-    "max_length": 20,
-    "min_length": 0,
-    "model_type": "clip_vision_model",
-    "no_repeat_ngram_size": 0,
-    "num_attention_heads": 16,
-    "num_beam_groups": 1,
-    "num_beams": 1,
-    "num_channels": 3,
-    "num_hidden_layers": 24,
-    "num_return_sequences": 1,
-    "output_attentions": false,
-    "output_hidden_states": false,
-    "output_scores": false,
-    "pad_token_id": null,
-    "patch_size": 14,
-    "prefix": null,
-    "problem_type": null,
-    "pruned_heads": {},
-    "remove_invalid_values": false,
-    "repetition_penalty": 1.0,
-    "return_dict": true,
-    "return_dict_in_generate": false,
-    "sep_token_id": null,
-    "task_specific_params": null,
-    "temperature": 1.0,
-    "tf_legacy_loss": false,
-    "tie_encoder_decoder": false,
-    "tie_word_embeddings": true,
-    "tokenizer_class": null,
-    "top_k": 50,
-    "top_p": 1.0,
-    "torch_dtype": null,
-    "torchscript": false,
-    "transformers_version": "4.22.0.dev0",
-    "typical_p": 1.0,
-    "use_bfloat16": false
-  },
-  "vision_config_dict": {
-    "hidden_size": 1024,
-    "intermediate_size": 4096,
-    "num_attention_heads": 16,
-    "num_hidden_layers": 24,
-    "patch_size": 14
-  }
-}
animatediff/sd/scheduler/scheduler_config.json DELETED
@@ -1,13 +0,0 @@
-{
-  "_class_name": "PNDMScheduler",
-  "_diffusers_version": "0.6.0",
-  "beta_end": 0.012,
-  "beta_schedule": "scaled_linear",
-  "beta_start": 0.00085,
-  "num_train_timesteps": 1000,
-  "set_alpha_to_one": false,
-  "skip_prk_steps": true,
-  "steps_offset": 1,
-  "trained_betas": null,
-  "clip_sample": false
-}
animatediff/sd/text_encoder/config.json DELETED
@@ -1,25 +0,0 @@
-{
-  "_name_or_path": "openai/clip-vit-large-patch14",
-  "architectures": [
-    "CLIPTextModel"
-  ],
-  "attention_dropout": 0.0,
-  "bos_token_id": 0,
-  "dropout": 0.0,
-  "eos_token_id": 2,
-  "hidden_act": "quick_gelu",
-  "hidden_size": 768,
-  "initializer_factor": 1.0,
-  "initializer_range": 0.02,
-  "intermediate_size": 3072,
-  "layer_norm_eps": 1e-05,
-  "max_position_embeddings": 77,
-  "model_type": "clip_text_model",
-  "num_attention_heads": 12,
-  "num_hidden_layers": 12,
-  "pad_token_id": 1,
-  "projection_dim": 768,
-  "torch_dtype": "float32",
-  "transformers_version": "4.22.0.dev0",
-  "vocab_size": 49408
-}
animatediff/sd/tokenizer/merges.txt DELETED
(The diff for this file is too large to render. See raw diff.)
animatediff/sd/tokenizer/special_tokens_map.json DELETED
@@ -1,24 +0,0 @@
-{
-  "bos_token": {
-    "content": "<|startoftext|>",
-    "lstrip": false,
-    "normalized": true,
-    "rstrip": false,
-    "single_word": false
-  },
-  "eos_token": {
-    "content": "<|endoftext|>",
-    "lstrip": false,
-    "normalized": true,
-    "rstrip": false,
-    "single_word": false
-  },
-  "pad_token": "<|endoftext|>",
-  "unk_token": {
-    "content": "<|endoftext|>",
-    "lstrip": false,
-    "normalized": true,
-    "rstrip": false,
-    "single_word": false
-  }
-}
animatediff/sd/tokenizer/tokenizer_config.json DELETED
@@ -1,34 +0,0 @@
-{
-  "add_prefix_space": false,
-  "bos_token": {
-    "__type": "AddedToken",
-    "content": "<|startoftext|>",
-    "lstrip": false,
-    "normalized": true,
-    "rstrip": false,
-    "single_word": false
-  },
-  "do_lower_case": true,
-  "eos_token": {
-    "__type": "AddedToken",
-    "content": "<|endoftext|>",
-    "lstrip": false,
-    "normalized": true,
-    "rstrip": false,
-    "single_word": false
-  },
-  "errors": "replace",
-  "model_max_length": 77,
-  "name_or_path": "openai/clip-vit-large-patch14",
-  "pad_token": "<|endoftext|>",
-  "special_tokens_map_file": "./special_tokens_map.json",
-  "tokenizer_class": "CLIPTokenizer",
-  "unk_token": {
-    "__type": "AddedToken",
-    "content": "<|endoftext|>",
-    "lstrip": false,
-    "normalized": true,
-    "rstrip": false,
-    "single_word": false
-  }
-}
animatediff/sd/tokenizer/vocab.json DELETED
(The diff for this file is too large to render. See raw diff.)
animatediff/sd/unet/config.json DELETED
@@ -1,36 +0,0 @@
-{
-  "_class_name": "UNet2DConditionModel",
-  "_diffusers_version": "0.6.0",
-  "act_fn": "silu",
-  "attention_head_dim": 8,
-  "block_out_channels": [
-    320,
-    640,
-    1280,
-    1280
-  ],
-  "center_input_sample": false,
-  "cross_attention_dim": 768,
-  "down_block_types": [
-    "CrossAttnDownBlock2D",
-    "CrossAttnDownBlock2D",
-    "CrossAttnDownBlock2D",
-    "DownBlock2D"
-  ],
-  "downsample_padding": 1,
-  "flip_sin_to_cos": true,
-  "freq_shift": 0,
-  "in_channels": 4,
-  "layers_per_block": 2,
-  "mid_block_scale_factor": 1,
-  "norm_eps": 1e-05,
-  "norm_num_groups": 32,
-  "out_channels": 4,
-  "sample_size": 64,
-  "up_block_types": [
-    "UpBlock2D",
-    "CrossAttnUpBlock2D",
-    "CrossAttnUpBlock2D",
-    "CrossAttnUpBlock2D"
-  ]
-}
animatediff/sd/v1-inference.yaml DELETED
@@ -1,70 +0,0 @@
-model:
-  base_learning_rate: 1.0e-04
-  target: ldm.models.diffusion.ddpm.LatentDiffusion
-  params:
-    linear_start: 0.00085
-    linear_end: 0.0120
-    num_timesteps_cond: 1
-    log_every_t: 200
-    timesteps: 1000
-    first_stage_key: "jpg"
-    cond_stage_key: "txt"
-    image_size: 64
-    channels: 4
-    cond_stage_trainable: false   # Note: different from the one we trained before
-    conditioning_key: crossattn
-    monitor: val/loss_simple_ema
-    scale_factor: 0.18215
-    use_ema: False
-
-    scheduler_config: # 10000 warmup steps
-      target: ldm.lr_scheduler.LambdaLinearScheduler
-      params:
-        warm_up_steps: [ 10000 ]
-        cycle_lengths: [ 10000000000000 ] # incredibly large number to prevent corner cases
-        f_start: [ 1.e-6 ]
-        f_max: [ 1. ]
-        f_min: [ 1. ]
-
-    unet_config:
-      target: ldm.modules.diffusionmodules.openaimodel.UNetModel
-      params:
-        image_size: 32 # unused
-        in_channels: 4
-        out_channels: 4
-        model_channels: 320
-        attention_resolutions: [ 4, 2, 1 ]
-        num_res_blocks: 2
-        channel_mult: [ 1, 2, 4, 4 ]
-        num_heads: 8
-        use_spatial_transformer: True
-        transformer_depth: 1
-        context_dim: 768
-        use_checkpoint: True
-        legacy: False
-
-    first_stage_config:
-      target: ldm.models.autoencoder.AutoencoderKL
-      params:
-        embed_dim: 4
-        monitor: val/rec_loss
-        ddconfig:
-          double_z: true
-          z_channels: 4
-          resolution: 256
-          in_channels: 3
-          out_ch: 3
-          ch: 128
-          ch_mult:
-          - 1
-          - 2
-          - 4
-          - 4
-          num_res_blocks: 2
-          attn_resolutions: []
-          dropout: 0.0
-        lossconfig:
-          target: torch.nn.Identity
-
-    cond_stage_config:
-      target: ldm.modules.encoders.modules.FrozenCLIPEmbedder
animatediff/sd/vae/config.json DELETED
@@ -1,29 +0,0 @@
-{
-  "_class_name": "AutoencoderKL",
-  "_diffusers_version": "0.6.0",
-  "act_fn": "silu",
-  "block_out_channels": [
-    128,
-    256,
-    512,
-    512
-  ],
-  "down_block_types": [
-    "DownEncoderBlock2D",
-    "DownEncoderBlock2D",
-    "DownEncoderBlock2D",
-    "DownEncoderBlock2D"
-  ],
-  "in_channels": 3,
-  "latent_channels": 4,
-  "layers_per_block": 2,
-  "norm_num_groups": 32,
-  "out_channels": 3,
-  "sample_size": 512,
-  "up_block_types": [
-    "UpDecoderBlock2D",
-    "UpDecoderBlock2D",
-    "UpDecoderBlock2D",
-    "UpDecoderBlock2D"
-  ]
-}
animatediff/utils/convert_from_ckpt.py CHANGED
@@ -714,7 +714,7 @@ def convert_ldm_bert_checkpoint(checkpoint, config):


 def convert_ldm_clip_checkpoint(checkpoint, dtype=torch.float16):
-    text_model = CLIPTextModel.from_pretrained("animatediff/sd/text_encoder", torch_dtype=dtype)
+    text_model = CLIPTextModel.from_pretrained("models/animatediff/sd/text_encoder", torch_dtype=dtype)
     keys = list(checkpoint.keys())

     text_model_dict = {}
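The only change in this hunk is the checkpoint path: the vendored SD text encoder is now expected under models/ (consistent with the new models/ entries in .gitignore). For illustration, loading a CLIP text encoder from a local directory works the same as from the Hub, assuming the directory holds a standard transformers checkpoint (config.json plus weights); the snippet below is a hedged sketch, not part of the commit.

import torch
from transformers import CLIPTextModel

# Assumes models/animatediff/sd/text_encoder contains a standard CLIP text
# encoder checkpoint; any local directory laid out this way loads identically.
text_model = CLIPTextModel.from_pretrained("models/animatediff/sd/text_encoder",
                                           torch_dtype=torch.float16)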
app.py
CHANGED
@@ -24,13 +24,11 @@ parser = argparse.ArgumentParser()
|
|
24 |
parser.add_argument("--adaface_encoder_types", type=str, nargs="+", default=["consistentID", "arc2face"],
|
25 |
choices=["arc2face", "consistentID"], help="Type(s) of the ID2Ada prompt encoders")
|
26 |
parser.add_argument('--adaface_ckpt_path', type=str,
|
27 |
-
default='models/adaface/
|
28 |
-
parser.add_argument('--model_style_type', type=str, default='
|
29 |
choices=["realistic", "anime", "photorealistic"], help="Type of the base model")
|
30 |
-
parser.add_argument("--guidance_scale", type=float, default=
|
31 |
help="The guidance scale for the diffusion model. Default: 8.0")
|
32 |
-
parser.add_argument("--do_neg_id_prompt_weight", type=float, default=0,
|
33 |
-
help="The weight of added ID prompt embeddings into the negative prompt. Default: 0, disabled.")
|
34 |
|
35 |
parser.add_argument('--gpu', type=int, default=None)
|
36 |
parser.add_argument('--ip', type=str, default="0.0.0.0")
|
@@ -41,20 +39,38 @@ def randomize_seed_fn(seed: int, randomize_seed: bool) -> int:
|
|
41 |
seed = random.randint(0, MAX_SEED)
|
42 |
return seed
|
43 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
44 |
# model = load_model()
|
45 |
# This FaceAnalysis is just to crop the face areas from the uploaded images,
|
46 |
# and is independent of the adaface FaceAnalysis apps.
|
47 |
-
app = FaceAnalysis(name="buffalo_l", root='models/insightface', providers=['
|
48 |
app.prepare(ctx_id=0, det_size=(320, 320))
|
49 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
50 |
|
51 |
global adaface, id_animator
|
52 |
|
53 |
-
|
54 |
id_animator = load_model(model_style_type=args.model_style_type, device='cpu')
|
55 |
-
adaface = AdaFaceWrapper(pipeline_name="text2img", base_model_path=
|
56 |
adaface_encoder_types=args.adaface_encoder_types,
|
57 |
-
adaface_ckpt_paths=
|
58 |
|
59 |
basedir = os.getcwd()
|
60 |
savedir = os.path.join(basedir,'samples')
|
@@ -79,7 +95,7 @@ def get_clicked_image(data: gr.SelectData):
|
|
79 |
return data.index
|
80 |
|
81 |
@spaces.GPU
|
82 |
-
def gen_init_images(uploaded_image_paths, prompt,
|
83 |
if uploaded_image_paths is None:
|
84 |
print("No image uploaded")
|
85 |
return None, None, None
|
@@ -92,9 +108,11 @@ def gen_init_images(uploaded_image_paths, prompt, guidance_scale, do_neg_id_prom
|
|
92 |
# [('/tmp/gradio/249981e66a7c665aaaf1c7eaeb24949af4366c88/jensen huang.jpg', None)]
|
93 |
# Extract the file paths.
|
94 |
uploaded_image_paths = [path[0] for path in uploaded_image_paths]
|
95 |
-
|
96 |
-
|
97 |
-
|
|
|
|
|
98 |
|
99 |
if adaface_subj_embs is None:
|
100 |
raise gr.Error(f"Failed to detect any faces! Please try with other images")
|
@@ -102,20 +120,22 @@ def gen_init_images(uploaded_image_paths, prompt, guidance_scale, do_neg_id_prom
|
|
102 |
# Generate two images each time for the user to select from.
|
103 |
noise = torch.randn(out_image_count, 3, 512, 512)
|
104 |
|
105 |
-
|
106 |
-
if enhance_face and "face portrait" not in prompt:
|
107 |
if "portrait" in prompt:
|
108 |
# Enhance the face features by replacing "portrait" with "face portrait".
|
109 |
prompt = prompt.replace("portrait", "face portrait")
|
110 |
else:
|
111 |
prompt = "face portrait, " + prompt
|
112 |
|
|
|
|
|
113 |
# samples: A list of PIL Image instances.
|
114 |
with torch.no_grad():
|
115 |
samples = adaface(noise, prompt, placeholder_tokens_pos='append',
|
116 |
guidance_scale=guidance_scale,
|
117 |
-
|
118 |
-
|
|
|
119 |
|
120 |
face_paths = []
|
121 |
for sample in samples:
|
@@ -131,9 +151,9 @@ def gen_init_images(uploaded_image_paths, prompt, guidance_scale, do_neg_id_prom
|
|
131 |
@spaces.GPU(duration=90)
|
132 |
def generate_video(image_container, uploaded_image_paths, init_img_file_paths, init_img_selected_idx,
|
133 |
init_image_strength, init_image_final_weight,
|
134 |
-
prompt, negative_prompt, num_steps, video_length, guidance_scale,
|
135 |
seed, attn_scale, image_embed_cfg_begin_scale, image_embed_cfg_end_scale,
|
136 |
-
|
137 |
id_animator_anneal_steps, progress=gr.Progress(track_tqdm=True)):
|
138 |
|
139 |
global adaface, id_animator
|
@@ -143,10 +163,17 @@ def generate_video(image_container, uploaded_image_paths, init_img_file_paths, i
|
|
143 |
if prompt is None:
|
144 |
prompt = ""
|
145 |
|
146 |
-
prompt = prompt + " 8k uhd, high quality"
|
147 |
-
if " shot" not in prompt:
|
148 |
-
|
149 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
150 |
prompt_img_lists=[]
|
151 |
for path in uploaded_image_paths:
|
152 |
img = cv2.imread(path)
|
@@ -158,16 +185,11 @@ def generate_video(image_container, uploaded_image_paths, init_img_file_paths, i
|
|
158 |
# prompt_img_lists is a list of PIL images.
|
159 |
prompt_img_lists.append(load_image(face_path).resize((224,224)))
|
160 |
|
161 |
-
if adaface is None or not is_adaface_enabled:
|
162 |
adaface_prompt_embeds, negative_prompt_embeds = None, None
|
|
|
163 |
image_embed_cfg_scales = (1, 1)
|
164 |
else:
|
165 |
-
if (adaface_ckpt_path is not None and adaface_ckpt_path.strip() != '') \
|
166 |
-
and (adaface_ckpt_path != args.adaface_ckpt_path):
|
167 |
-
args.adaface_ckpt_path = adaface_ckpt_path
|
168 |
-
# Reload the adaface model weights.
|
169 |
-
adaface.id2ada_prompt_encoder.load_adaface_ckpt(adaface_ckpt_path)
|
170 |
-
|
171 |
with torch.no_grad():
|
172 |
adaface_subj_embs = \
|
173 |
adaface.prepare_adaface_embeddings(image_paths=uploaded_image_paths, face_id_embs=None,
|
@@ -176,9 +198,10 @@ def generate_video(image_container, uploaded_image_paths, init_img_file_paths, i
|
|
176 |
# adaface_prompt_embeds: [1, 77, 768].
|
177 |
adaface_prompt_embeds, negative_prompt_embeds, _, _ = \
|
178 |
adaface.encode_prompt(prompt, placeholder_tokens_pos='append',
|
179 |
-
|
180 |
verbose=True)
|
181 |
|
|
|
182 |
image_embed_cfg_scales = (image_embed_cfg_begin_scale, image_embed_cfg_end_scale)
|
183 |
|
184 |
# init_img_file_paths is a list of image paths. If not chose, init_img_file_paths is None.
|
@@ -198,8 +221,8 @@ def generate_video(image_container, uploaded_image_paths, init_img_file_paths, i
|
|
198 |
prompt = prompt,
|
199 |
negative_prompt = negative_prompt,
|
200 |
adaface_prompt_embeds = (adaface_prompt_embeds, negative_prompt_embeds),
|
201 |
-
# adaface_power_scale is not so useful, and when it's set >= 2, weird artifacts appear.
|
202 |
-
# Here it's limited to
|
203 |
adaface_power_scale = adaface_power_scale,
|
204 |
num_inference_steps = num_steps,
|
205 |
id_animator_anneal_steps = id_animator_anneal_steps,
|
@@ -216,7 +239,7 @@ def generate_video(image_container, uploaded_image_paths, init_img_file_paths, i
|
|
216 |
save_videos_grid(sample, save_sample_path)
|
217 |
return save_sample_path
|
218 |
|
219 |
-
def check_prompt_and_model_type(prompt, model_style_type):
|
220 |
global adaface, id_animator
|
221 |
|
222 |
model_style_type = model_style_type.lower()
|
@@ -236,21 +259,20 @@ def check_prompt_and_model_type(prompt, model_style_type):
|
|
236 |
with gr.Blocks(css=css, theme=gr.themes.Origin()) as demo:
|
237 |
gr.Markdown(
|
238 |
"""
|
239 |
-
# AdaFace-Animate: Zero-Shot Subject-Driven Video Generation
|
240 |
"""
|
241 |
)
|
242 |
gr.Markdown(
|
243 |
"""
|
244 |
-
<b>Official demo</b> for our working paper <b>AdaFace: A Versatile Face Encoder for
|
245 |
|
246 |
-
❗️**
|
247 |
-
- Support switching between
|
248 |
-
- If you
|
249 |
|
250 |
❗️**Tips**❗️
|
251 |
- You can upload one or more subject images for generating ID-specific video.
|
252 |
-
- If the face
|
253 |
-
- If the face loses focus, try increasing the guidance scale.
|
254 |
- If the motion is weird, e.g., the prompt is "... running", try increasing the number of sampling steps.
|
255 |
- Usage explanations and demos: [Readme](https://huggingface.co/spaces/adaface-neurips/adaface-animate/blob/main/README2.md).
|
256 |
- AdaFace Text-to-Image: <a href="https://huggingface.co/spaces/adaface-neurips/adaface" style="display: inline-flex; align-items: center;">
|
@@ -258,8 +280,6 @@ with gr.Blocks(css=css, theme=gr.themes.Origin()) as demo:
|
|
258 |
<img src="https://img.shields.io/badge/%F0%9F%A4%97%20Hugging%20Face-Spaces-yellow" alt="Hugging Face Spaces" style="margin-left: 5px;">
|
259 |
</a>
|
260 |
|
261 |
-
**TODO:**
|
262 |
-
- ControlNet integration.
|
263 |
"""
|
264 |
)
|
265 |
|
@@ -270,6 +290,7 @@ with gr.Blocks(css=css, theme=gr.themes.Origin()) as demo:
270 |           file_types=["image"],
271 |           file_count="multiple"
272 |       )
273 |       image_container = gr.Image(label="image container", sources="upload", type="numpy", height=256, visible=False)
274 |       uploaded_files_gallery = gr.Gallery(label="Subject images", visible=False, columns=3, rows=2, height=300)
275 |       with gr.Column(visible=False) as clear_button_column:
@@ -280,6 +301,7 @@ with gr.Blocks(css=css, theme=gr.themes.Origin()) as demo:
280 |           file_types=["image"],
281 |           file_count="multiple"
282 |       )
283 |       init_img_container = gr.Image(label="init image container", sources="upload", type="numpy", height=256, visible=False)
284 |       # Although there's only one image, we still use columns=3, to scale down the image size.
285 |       # Otherwise it will occupy the full width, and the gallery won't show the whole image.
@@ -288,41 +310,47 @@ with gr.Blocks(css=css, theme=gr.themes.Origin()) as demo:
288 |       init_img_selected_idx = gr.Textbox(label="Selected init image index", value="0", visible=False)
289 |
290 |       with gr.Column(visible=True) as init_gen_button_column:
291 | -         gen_init = gr.Button(value="Generate
292 |       with gr.Column(visible=False) as init_clear_button_column:
293 |           remove_init_and_reupload = gr.ClearButton(value="Upload an old init image", components=init_img_files, size="sm")
294 |
295 |       prompt = gr.Dropdown(label="Prompt",
296 | -
297 | -
298 | -
299 | -
300 | -
301 | -         "
302 | -         "
303 | -         "
304 | -         "
305 | -         "
306 | -         "
307 | -         "playing guitar on a boat, ocean waves",
308 | -         "with a passion for reading, curled up with a book in a cozy nook near a window",
309 | -
310 | -         "
311 | -
312 | -
313 |       init_image_strength = gr.Slider(
314 |           label="Init Image Strength",
315 |           info="How much the init image should influence each frame. 0: no influence (scenes are more dynamic), 3: strongest influence (scenes are more static).",
316 |           minimum=0,
317 | -         maximum=
318 | -         step=0.
319 |           value=1,
320 |       )
321 |       init_image_final_weight = gr.Slider(
322 | -         label="Final
323 |           info="How much the init image should influence the end of the video",
324 |           minimum=0,
325 | -         maximum=
326 |           step=0.025,
327 |           value=0.1,
328 |       )
@@ -331,7 +359,7 @@ with gr.Blocks(css=css, theme=gr.themes.Origin()) as demo:
331 |           label="Base Model Style Type",
332 |           info="Switching the base model type will take 10~20 seconds to reload the model",
333 |           value=args.model_style_type.capitalize(),
334 | -         choices=["Realistic", "Anime"
335 |           allow_custom_value=False,
336 |           filterable=False,
337 |       )
@@ -344,15 +372,6 @@ with gr.Blocks(css=css, theme=gr.themes.Origin()) as demo:
344 |           value=args.guidance_scale,
345 |       )
346 |
347 | -     do_neg_id_prompt_weight = gr.Slider(
348 | -         label="Weight of ID prompt in the negative prompt",
349 | -         minimum=0.0,
350 | -         maximum=0.9,
351 | -         step=0.1,
352 | -         value=args.do_neg_id_prompt_weight,
353 | -         visible=True
354 | -     )
355 | -
356 |       seed = gr.Slider(
357 |           label="Seed",
358 |           minimum=0,
@@ -365,11 +384,6 @@ with gr.Blocks(css=css, theme=gr.themes.Origin()) as demo:
365 |           value=True,
366 |           info="Uncheck for reproducible results")
367 |
368 | -     negative_prompt = gr.Textbox(
369 | -         label="Negative Prompt",
370 | -         placeholder="low quality",
371 | -         value="(deformed iris, deformed pupils, semi-realistic, cgi, 3d, render, sketch, cartoon, drawing, anime), text, cropped, out of frame, worst quality, low quality, jpeg artifacts, ugly, duplicate, morbid, mutilated, bare breasts, extra fingers, mutated hands, poorly drawn hands, poorly drawn face, mutation, deformed, blurry, dehydrated, bad anatomy, bad proportions, extra limbs, cloned face, disfigured, gross proportions, malformed limbs, missing arms, missing legs, extra arms, extra legs, fused fingers, long neck, UnrealisticDream",
372 | -     )
373 |       num_steps = gr.Slider(
374 |           label="Number of sampling steps. More steps for better composition, but longer time.",
375 |           minimum=30,
@@ -394,36 +408,42 @@ with gr.Blocks(css=css, theme=gr.themes.Origin()) as demo:
394 |       is_adaface_enabled = gr.Checkbox(label="Enable AdaFace",
395 |           info="Enable AdaFace for better face details. If unchecked, it falls back to ID-Animator (https://huggingface.co/spaces/ID-Animator/ID-Animator).",
396 |           value=True)
397 | -     adaface_ckpt_path = gr.Textbox(
398 | -         label="AdaFace checkpoint path",
399 | -         placeholder=args.adaface_ckpt_path,
400 | -         value=args.adaface_ckpt_path,
401 | -     )
402 |
403 |       adaface_power_scale = gr.Slider(
404 |           label="AdaFace Embedding Power Scale",
405 |           info="Increase this scale slightly only if the face is defocused or the face details are not clear",
406 |           minimum=0.8,
407 |           maximum=1.2,
408 |           step=0.1,
409 |           value=1,
410 |       )
411 |
412 |       image_embed_cfg_begin_scale = gr.Slider(
413 |           label="ID-Animator Image Embedding Initial Scale",
414 |           info="The scale of the ID-Animator image embedding (influencing coarse facial features and poses)",
415 | -         minimum=0
416 | -         maximum=1
417 |           step=0.1,
418 | -         value=
419 |       )
420 |       image_embed_cfg_end_scale = gr.Slider(
421 |           label="ID-Animator Image Embedding Final Scale",
422 |           info="The scale of the ID-Animator image embedding (influencing coarse facial features and poses)",
423 | -         minimum=0
424 | -         maximum=1
425 |           step=0.1,
426 | -         value=0.
427 |       )
428 |
429 |       id_animator_anneal_steps = gr.Slider(
@@ -431,18 +451,15 @@ with gr.Blocks(css=css, theme=gr.themes.Origin()) as demo:
431 |           minimum=0,
432 |           maximum=40,
433 |           step=1,
434 | -         value=
435 |           visible=True,
436 |       )
437 |
438 | -
439 | -
440 | -
441 | -
442 | -
443 | -         step=0.1,
444 | -         value=1,
445 | -     )
446 |
447 |       with gr.Column():
448 |           result_video = gr.Video(label="Generated Animation", interactive=False)
@@ -462,8 +479,8 @@ with gr.Blocks(css=css, theme=gr.themes.Origin()) as demo:
462 |           outputs=seed,
463 |           queue=False,
464 |           api_name=False,
465 | -     ).then(fn=gen_init_images, inputs=[uploaded_files_gallery, prompt,
466 | -                                        guidance_scale
467 |           outputs=[uploaded_init_img_gallery, init_img_files, init_clear_button_column])
468 |       uploaded_init_img_gallery.select(fn=get_clicked_image, inputs=None, outputs=init_img_selected_idx)
469 |
@@ -478,9 +495,10 @@ with gr.Blocks(css=css, theme=gr.themes.Origin()) as demo:
478 |       fn=generate_video,
479 |       inputs=[image_container, files,
480 |               init_img_files, init_img_selected_idx, init_image_strength, init_image_final_weight,
481 | -             prompt, negative_prompt, num_steps, video_length, guidance_scale,
482 |               seed, attn_scale, image_embed_cfg_begin_scale, image_embed_cfg_end_scale,
483 | -
484 |       outputs=[result_video]
485 |       )
486 |
24 |   parser.add_argument("--adaface_encoder_types", type=str, nargs="+", default=["consistentID", "arc2face"],
25 |                       choices=["arc2face", "consistentID"], help="Type(s) of the ID2Ada prompt encoders")
26 |   parser.add_argument('--adaface_ckpt_path', type=str,
27 | +                     default='models/adaface/VGGface2_HQ_masks2025-03-06T03-31-21_zero3-ada-1000.pt')
28 | + parser.add_argument('--model_style_type', type=str, default='photorealistic',
29 |                       choices=["realistic", "anime", "photorealistic"], help="Type of the base model")
30 | + parser.add_argument("--guidance_scale", type=float, default=8.0,
31 |                       help="The guidance scale for the diffusion model. Default: 8.0")
32 |
33 |   parser.add_argument('--gpu', type=int, default=None)
34 |   parser.add_argument('--ip', type=str, default="0.0.0.0")
39 |       seed = random.randint(0, MAX_SEED)
40 |       return seed
41 |
42 | + def is_running_on_spaces():
43 | +     return os.getenv("SPACE_ID") is not None
44 | +
45 | + from huggingface_hub import snapshot_download
46 | + large_files = ["models/*", "models/**/*"]
47 | + snapshot_download(repo_id="adaface-neurips/adaface-animate-models",
48 | +                   repo_type="model", allow_patterns=large_files, local_dir=".")
49 | + os.makedirs("/tmp/gradio", exist_ok=True)
50 | +
51 |   # model = load_model()
52 |   # This FaceAnalysis is just to crop the face areas from the uploaded images,
53 |   # and is independent of the adaface FaceAnalysis apps.
54 | + app = FaceAnalysis(name="buffalo_l", root='models/insightface', providers=['CPUExecutionProvider'])
55 |   app.prepare(ctx_id=0, det_size=(320, 320))
56 | +
57 | + if is_running_on_spaces():
58 | +     device = 'cuda:0'
59 | + else:
60 | +     if args.gpu is None:
61 | +         device = "cuda"
62 | +     else:
63 | +         device = f"cuda:{args.gpu}"
64 | +
65 | + print(f"Device: {device}")
66 |
67 |   global adaface, id_animator
68 |
69 | + adaface_base_model_path = model_style_type2base_model_path["photorealistic"]
70 |   id_animator = load_model(model_style_type=args.model_style_type, device='cpu')
71 | + adaface = AdaFaceWrapper(pipeline_name="text2img", base_model_path=adaface_base_model_path,
72 |                            adaface_encoder_types=args.adaface_encoder_types,
73 | +                          adaface_ckpt_paths=args.adaface_ckpt_path, device='cpu')
74 |
75 |   basedir = os.getcwd()
76 |   savedir = os.path.join(basedir,'samples')
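The startup block above pulls only the `models/` tree from the companion weights repo and picks a CUDA device, using the `SPACE_ID` environment variable to detect Hugging Face Spaces. For readers adapting this elsewhere, a minimal standalone sketch of the same pattern (the helper names are mine; the repo id and patterns are copied from the diff):

```python
import os
from huggingface_hub import snapshot_download

def fetch_model_checkpoints(local_dir="."):
    # Download only files under models/ from the weights repo;
    # allow_patterns skips everything else in the snapshot.
    snapshot_download(repo_id="adaface-neurips/adaface-animate-models",
                      repo_type="model",
                      allow_patterns=["models/*", "models/**/*"],
                      local_dir=local_dir)

def pick_device(gpu=None):
    # On HF Spaces only one GPU is visible, so cuda:0 is always correct there.
    if os.getenv("SPACE_ID") is not None:
        return "cuda:0"
    return "cuda" if gpu is None else f"cuda:{gpu}"
```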
95 |       return data.index
96 |
97 |   @spaces.GPU
98 | + def gen_init_images(uploaded_image_paths, prompt, highlight_face, guidance_scale, out_image_count=4):
99 |       if uploaded_image_paths is None:
100 |          print("No image uploaded")
101 |          return None, None, None

108 |      # [('/tmp/gradio/249981e66a7c665aaaf1c7eaeb24949af4366c88/jensen huang.jpg', None)]
109 |      # Extract the file paths.
110 |      uploaded_image_paths = [path[0] for path in uploaded_image_paths]
111 | +
112 | +    with torch.no_grad():
113 | +        adaface_subj_embs = \
114 | +            adaface.prepare_adaface_embeddings(image_paths=uploaded_image_paths, face_id_embs=None,
115 | +                                               update_text_encoder=True)
116 |
117 |      if adaface_subj_embs is None:
118 |          raise gr.Error(f"Failed to detect any faces! Please try with other images")

120 |      # Generate a few images each time for the user to select from.
121 |      noise = torch.randn(out_image_count, 3, 512, 512)
122 |
123 | +    if highlight_face and "face portrait" not in prompt:
124 |          if "portrait" in prompt:
125 |              # Enhance the face features by replacing "portrait" with "face portrait".
126 |              prompt = prompt.replace("portrait", "face portrait")
127 |          else:
128 |              prompt = "face portrait, " + prompt
129 |
130 | +    guidance_scale = min(guidance_scale, 5)
131 | +
132 |      # samples: A list of PIL Image instances.
133 |      with torch.no_grad():
134 |          samples = adaface(noise, prompt, placeholder_tokens_pos='append',
135 |                            guidance_scale=guidance_scale,
136 | +                          out_image_count=out_image_count,
137 | +                          repeat_prompt_for_each_encoder=True,
138 | +                          verbose=True)
139 |
140 |      face_paths = []
141 |      for sample in samples:
@spaces.GPU(duration=90)
|
152 |
def generate_video(image_container, uploaded_image_paths, init_img_file_paths, init_img_selected_idx,
|
153 |
init_image_strength, init_image_final_weight,
|
154 |
+
prompt, negative_prompt, num_steps, video_length, guidance_scale,
|
155 |
seed, attn_scale, image_embed_cfg_begin_scale, image_embed_cfg_end_scale,
|
156 |
+
highlight_face, is_adaface_enabled, adaface_power_scale,
|
157 |
id_animator_anneal_steps, progress=gr.Progress(track_tqdm=True)):
|
158 |
|
159 |
global adaface, id_animator
|
|
|
163 |
if prompt is None:
|
164 |
prompt = ""
|
165 |
|
166 |
+
#prompt = prompt + " 8k uhd, high quality"
|
167 |
+
#if " shot" not in prompt:
|
168 |
+
# prompt = prompt + ", medium shot"
|
169 |
+
|
170 |
+
if highlight_face and "face portrait" not in prompt:
|
171 |
+
if "portrait" in prompt:
|
172 |
+
# Enhance the face features by replacing "portrait" with "face portrait".
|
173 |
+
prompt = prompt.replace("portrait", "face portrait")
|
174 |
+
else:
|
175 |
+
prompt = "face portrait, " + prompt
|
176 |
+
|
177 |
prompt_img_lists=[]
|
178 |
for path in uploaded_image_paths:
|
179 |
img = cv2.imread(path)
|
|
|
185 |
# prompt_img_lists is a list of PIL images.
|
186 |
prompt_img_lists.append(load_image(face_path).resize((224,224)))
|
187 |
|
188 |
+
if adaface is None or (not is_adaface_enabled):
|
189 |
adaface_prompt_embeds, negative_prompt_embeds = None, None
|
190 |
+
# ID-Animator Image Embedding Initial and End Scales
|
191 |
image_embed_cfg_scales = (1, 1)
|
192 |
else:
|
|
|
|
|
|
|
|
|
|
|
|
|
193 |
with torch.no_grad():
|
194 |
adaface_subj_embs = \
|
195 |
adaface.prepare_adaface_embeddings(image_paths=uploaded_image_paths, face_id_embs=None,
|
|
|
198 |
# adaface_prompt_embeds: [1, 77, 768].
|
199 |
adaface_prompt_embeds, negative_prompt_embeds, _, _ = \
|
200 |
adaface.encode_prompt(prompt, placeholder_tokens_pos='append',
|
201 |
+
repeat_prompt_for_each_encoder=True,
|
202 |
verbose=True)
|
203 |
|
204 |
+
# ID-Animator Image Embedding Initial and End Scales
|
205 |
image_embed_cfg_scales = (image_embed_cfg_begin_scale, image_embed_cfg_end_scale)
|
206 |
|
207 |
# init_img_file_paths is a list of image paths. If not chose, init_img_file_paths is None.
|
|
|
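For readers unfamiliar with how the `(adaface_prompt_embeds, negative_prompt_embeds)` pair is consumed downstream: under classifier-free guidance, the positive and negative embeddings are batched together for a single UNet pass. A generic sketch of that batching (standard diffusion practice, not code taken verbatim from this repo):

```python
import torch

def build_cfg_embeddings(prompt_embeds, negative_prompt_embeds, num_samples):
    # prompt_embeds / negative_prompt_embeds: [1, 77, 768] text embeddings.
    # Classifier-free guidance runs the UNet on [negative; positive] pairs,
    # then combines the two noise predictions with the guidance scale.
    pos = prompt_embeds.repeat(num_samples, 1, 1)
    neg = negative_prompt_embeds.repeat(num_samples, 1, 1)
    return torch.cat([neg, pos], dim=0)  # [2 * num_samples, 77, 768]
```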
221 |           prompt = prompt,
222 |           negative_prompt = negative_prompt,
223 |           adaface_prompt_embeds = (adaface_prompt_embeds, negative_prompt_embeds),
224 | +         # adaface_power_scale is not very useful; when it's set >= 1.2, weird artifacts appear.
225 | +         # Here it's limited to 1~1.1.
226 |           adaface_power_scale = adaface_power_scale,
227 |           num_inference_steps = num_steps,
228 |           id_animator_anneal_steps = id_animator_anneal_steps,
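`image_embed_cfg_scales` carries an initial and a final scale, and `id_animator_anneal_steps` bounds the transition between them. One plausible schedule, shown here as an assumption rather than the repo's exact implementation, is a linear interpolation over the first anneal steps:

```python
def image_embed_scale_at(step, anneal_steps, begin_scale, end_scale):
    # Linearly anneal from begin_scale to end_scale over the first
    # anneal_steps sampling steps, then hold end_scale afterwards.
    if anneal_steps <= 0 or step >= anneal_steps:
        return end_scale
    t = step / anneal_steps
    return begin_scale + (end_scale - begin_scale) * t
```

With the UI defaults below (0.5 falling to 0.1 over 40 steps), this would let the ID-Animator image embedding dominate early, fixing coarse facial geometry and pose, and fade out so AdaFace governs the fine details, which matches the slider descriptions.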
239 |       save_videos_grid(sample, save_sample_path)
240 |       return save_sample_path
241 |
242 | + def check_prompt_and_model_type(prompt, model_style_type, progress=gr.Progress()):
243 |       global adaface, id_animator
244 |
245 |       model_style_type = model_style_type.lower()

259 |   with gr.Blocks(css=css, theme=gr.themes.Origin()) as demo:
260 |       gr.Markdown(
261 |           """
262 | +         # AdaFace-Animate: Zero-Shot Human Subject-Driven Video Generation
263 |           """
264 |       )
265 |       gr.Markdown(
266 |           """
267 | +         <b>Official demo</b> for our working paper <b>AdaFace: A Versatile Text-space Face Encoder for Face Synthesis and Processing</b>.<br>
268 |
269 | +         ❗️**NOTE**❗️
270 | +         - Supports switching among three model styles: **Realistic**, **Photorealistic** and **Anime**. **Realistic** is less realistic than **Photorealistic** but has better motions.
271 | +         - If you change the model style, please wait 20~30 seconds for the new model weights to load before the model begins to generate images/videos.
272 |
273 |           ❗️**Tips**❗️
274 |           - You can upload one or more subject images for generating ID-specific video.
275 | +         - If the face loses focus, try enabling "Highlight face".
276 |           - If the motion is weird, e.g., the prompt is "... running", try increasing the number of sampling steps.
277 |           - Usage explanations and demos: [Readme](https://huggingface.co/spaces/adaface-neurips/adaface-animate/blob/main/README2.md).
278 |           - AdaFace Text-to-Image: <a href="https://huggingface.co/spaces/adaface-neurips/adaface" style="display: inline-flex; align-items: center;">

280 |           <img src="https://img.shields.io/badge/%F0%9F%A4%97%20Hugging%20Face-Spaces-yellow" alt="Hugging Face Spaces" style="margin-left: 5px;">
281 |           </a>
282 |
283 |           """
284 |       )
285 |
290 |           file_types=["image"],
291 |           file_count="multiple"
292 |       )
293 | +     files.GRADIO_CACHE = "/tmp/gradio"
294 |       image_container = gr.Image(label="image container", sources="upload", type="numpy", height=256, visible=False)
295 |       uploaded_files_gallery = gr.Gallery(label="Subject images", visible=False, columns=3, rows=2, height=300)
296 |       with gr.Column(visible=False) as clear_button_column:

301 |           file_types=["image"],
302 |           file_count="multiple"
303 |       )
304 | +     init_img_files.GRADIO_CACHE = "/tmp/gradio"
305 |       init_img_container = gr.Image(label="init image container", sources="upload", type="numpy", height=256, visible=False)
306 |       # Although there's only one image, we still use columns=3, to scale down the image size.
307 |       # Otherwise it will occupy the full width, and the gallery won't show the whole image.
310 |       init_img_selected_idx = gr.Textbox(label="Selected init image index", value="0", visible=False)
311 |
312 |       with gr.Column(visible=True) as init_gen_button_column:
313 | +         gen_init = gr.Button(value="Generate 4 new init images")
314 |       with gr.Column(visible=False) as init_clear_button_column:
315 |           remove_init_and_reupload = gr.ClearButton(value="Upload an old init image", components=init_img_files, size="sm")
316 |
317 |       prompt = gr.Dropdown(label="Prompt",
318 | +         info="Try something like 'walking on the beach'.",
319 | +         value="highlighted hair, futuristic silver armor suit, confident stance, living room, smiling, head tilted, perfect smooth skin",
320 | +         allow_custom_value=True,
321 | +         choices=[
322 | +             "portrait, highlighted hair, futuristic silver armor suit, confident stance, living room, smiling, head tilted, perfect smooth skin",
323 | +             "portrait, walking on the beach, sunset",
324 | +             "portrait, in a white apron and chef hat, garnishing a gourmet dish",
325 | +             "portrait, dancing pose among folks in a park, waving hands",
326 | +             "portrait, in iron man costume, the sky ablaze with hues of orange and purple",
327 | +             "portrait, jedi wielding a lightsaber, star wars",
328 | +             "portrait, night view of tokyo street, neon light",
329 | +             "portrait, playing guitar on a boat, ocean waves",
330 | +             "portrait, with a passion for reading, curled up with a book in a cozy nook near a window",
331 | +             "portrait, celebrating new year, fireworks",
332 | +             "portrait, running pose in a park",
333 | +             "portrait, in space suit, space helmet, walking on mars",
334 | +             "portrait, in superman costume, the sky ablaze with hues of orange and purple"
335 | +         ])
336 | +
337 | +     highlight_face = gr.Checkbox(label="Highlight face", value=False,
338 | +         info="Enhance the facial features by prepending 'face portrait' to the prompt",
339 | +         visible=True)
340 | +
init_image_strength = gr.Slider(
|
342 |
label="Init Image Strength",
|
343 |
info="How much the init image should influence each frame. 0: no influence (scenes are more dynamic), 3: strongest influence (scenes are more static).",
|
344 |
minimum=0,
|
345 |
+
maximum=3,
|
346 |
+
step=0.1,
|
347 |
value=1,
|
348 |
)
|
349 |
init_image_final_weight = gr.Slider(
|
350 |
+
label="Final Strength of the Init Image",
|
351 |
info="How much the init image should influence the end of the video",
|
352 |
minimum=0,
|
353 |
+
maximum=2,
|
354 |
step=0.025,
|
355 |
value=0.1,
|
356 |
)
|
|
|
359 |
label="Base Model Style Type",
|
360 |
info="Switching the base model type will take 10~20 seconds to reload the model",
|
361 |
value=args.model_style_type.capitalize(),
|
362 |
+
choices=["Realistic", "Anime", "Photorealistic"],
|
363 |
allow_custom_value=False,
|
364 |
filterable=False,
|
365 |
)
|
|
|
372 |
value=args.guidance_scale,
|
373 |
)
|
374 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
375 |
seed = gr.Slider(
|
376 |
label="Seed",
|
377 |
minimum=0,
|
|
|
384 |
value=True,
|
385 |
info="Uncheck for reproducible results")
|
386 |
|
|
|
|
|
|
|
|
|
|
|
387 |
num_steps = gr.Slider(
|
388 |
label="Number of sampling steps. More steps for better composition, but longer time.",
|
389 |
minimum=30,
|
|
|
408 |       is_adaface_enabled = gr.Checkbox(label="Enable AdaFace",
409 |           info="Enable AdaFace for better face details. If unchecked, it falls back to ID-Animator (https://huggingface.co/spaces/ID-Animator/ID-Animator).",
410 |           value=True)
411 |
412 |       adaface_power_scale = gr.Slider(
413 |           label="AdaFace Embedding Power Scale",
414 |           info="Increase this scale slightly only if the face is defocused or the face details are not clear",
415 |           minimum=0.8,
416 |           maximum=1.2,
417 | +         step=0.05,
418 | +         value=1.1,
419 | +         visible=True,
420 | +     )
421 | +
422 | +     attn_scale = gr.Slider(
423 | +         label="Attention Processor Scale",
424 | +         info="The scale of the ID embeddings on the attention (the higher, the more focus on the face, less on the background)",
425 | +         minimum=0.5,
426 | +         maximum=2,
427 |           step=0.1,
428 |           value=1,
429 | +         visible=True
430 |       )
431 |
432 |       image_embed_cfg_begin_scale = gr.Slider(
433 |           label="ID-Animator Image Embedding Initial Scale",
434 |           info="The scale of the ID-Animator image embedding (influencing coarse facial features and poses)",
435 | +         minimum=0,
436 | +         maximum=1,
437 |           step=0.1,
438 | +         value=0.5,
439 |       )
440 |       image_embed_cfg_end_scale = gr.Slider(
441 |           label="ID-Animator Image Embedding Final Scale",
442 |           info="The scale of the ID-Animator image embedding (influencing coarse facial features and poses)",
443 | +         minimum=0,
444 | +         maximum=1,
445 |           step=0.1,
446 | +         value=0.1,
447 |       )
448 |
449 |       id_animator_anneal_steps = gr.Slider(

451 |           minimum=0,
452 |           maximum=40,
453 |           step=1,
454 | +         value=40,
455 |           visible=True,
456 |       )
457 |
458 | +     negative_prompt = gr.Textbox(
459 | +         label="Negative Prompt",
460 | +         placeholder="low quality",
461 | +         value="deformed iris, deformed pupils, semi-realistic, cgi, 3d, render, sketch, cartoon, drawing, anime, text, cropped, out of frame, worst quality, low quality, jpeg artifacts, ugly, duplicate, morbid, mutilated, bare breasts, extra fingers, mutated hands, poorly drawn hands, poorly drawn face, mutation, deformed, blurry, dehydrated, bad anatomy, bad proportions, extra limbs, cloned face, disfigured, gross proportions, malformed limbs, missing arms, missing legs, extra arms, extra legs, fused fingers, long neck, UnrealisticDream, nude, naked, nsfw, topless, bare breasts",
462 | +     )
463 |
464 |       with gr.Column():
465 |           result_video = gr.Video(label="Generated Animation", interactive=False)
479 |           outputs=seed,
480 |           queue=False,
481 |           api_name=False,
482 | +     ).then(fn=gen_init_images, inputs=[uploaded_files_gallery, prompt, highlight_face,
483 | +                                        guidance_scale],
484 |           outputs=[uploaded_init_img_gallery, init_img_files, init_clear_button_column])
485 |       uploaded_init_img_gallery.select(fn=get_clicked_image, inputs=None, outputs=init_img_selected_idx)
486 |
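The `.then()` chain above guarantees ordering: the seed is randomized first, and `gen_init_images` only runs once that update has landed. A minimal self-contained illustration of the same Gradio pattern (component and function names are mine):

```python
import random
import gradio as gr

def new_seed():
    return random.randint(0, 2**31 - 1)

def render(seed):
    return f"generated with seed {seed}"

with gr.Blocks() as demo:
    seed_box = gr.Number(label="Seed")
    result = gr.Textbox(label="Result")
    btn = gr.Button("Generate")
    # .then() schedules the second callback strictly after the first
    # finishes, so render() always sees the freshly randomized seed.
    btn.click(fn=new_seed, inputs=None, outputs=seed_box,
              queue=False).then(fn=render, inputs=seed_box, outputs=result)

# demo.launch()
```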
495 |       fn=generate_video,
496 |       inputs=[image_container, files,
497 |               init_img_files, init_img_selected_idx, init_image_strength, init_image_final_weight,
498 | +             prompt, negative_prompt, num_steps, video_length, guidance_scale,
499 |               seed, attn_scale, image_embed_cfg_begin_scale, image_embed_cfg_end_scale,
500 | +             highlight_face, is_adaface_enabled,
501 | +             adaface_power_scale, id_animator_anneal_steps],
502 |       outputs=[result_video]
503 |       )
504 |
faceadapter/face_adapter.py
CHANGED
@@ -315,10 +315,10 @@ class FaceAdapterPlusForVideoLora(FaceAdapterLora):
315 |   negative_prompt_embeds0 = negative_prompt_embeds_
316 |   adaface_prompt_embeds, negative_prompt_embeds_ = adaface_prompt_embeds
317 |   # self.torch_type == torch.float16. adaface_prompt_embeds is torch.float32.
318 | - prompt_embeds_
319 |   negative_prompt_embeds_ = negative_prompt_embeds_.repeat(num_samples, 1, 1).to(dtype=self.torch_type)
320 |   if adaface_power_scale != 1.0:
321 | -     prompt_embeds_ = prompt_embeds_ * adaface_power_scale
322 |
323 |   # Note to balance image_prompt_embeds with uncond_image_prompt_embeds after scaling.
324 |   image_prompt_embeds_begin = image_prompt_embeds * image_embed_cfg_scales[0] + uncond_image_prompt_embeds * (1 - image_embed_cfg_scales[0])

315 |   negative_prompt_embeds0 = negative_prompt_embeds_
316 |   adaface_prompt_embeds, negative_prompt_embeds_ = adaface_prompt_embeds
317 |   # self.torch_type == torch.float16. adaface_prompt_embeds is torch.float32.
318 | + prompt_embeds_ = adaface_prompt_embeds.repeat(num_samples, 1, 1).to(dtype=self.torch_type)
319 |   negative_prompt_embeds_ = negative_prompt_embeds_.repeat(num_samples, 1, 1).to(dtype=self.torch_type)
320 |   if adaface_power_scale != 1.0:
321 | +     prompt_embeds_ = prompt_embeds_ * adaface_power_scale + negative_prompt_embeds0 * (1 - adaface_power_scale)
322 |
323 |   # Note to balance image_prompt_embeds with uncond_image_prompt_embeds after scaling.
324 |   image_prompt_embeds_begin = image_prompt_embeds * image_embed_cfg_scales[0] + uncond_image_prompt_embeds * (1 - image_embed_cfg_scales[0])
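The change above swaps a bare multiplicative boost for an interpolation against the negative embeddings, which keeps the scaled embedding from drifting in magnitude. A sketch of the new rule in isolation (the function name is mine; the formula is taken from the diff):

```python
import torch

def scale_adaface_embeddings(prompt_embeds, negative_prompt_embeds, power_scale=1.1):
    # Old behavior: prompt_embeds * power_scale (inflates the embedding norm).
    # New behavior: extrapolate away from the negative embeddings instead:
    #   e' = e * s + e_neg * (1 - s)
    # With s slightly above 1 this pushes e away from e_neg, sharpening face
    # details, while the two terms roughly balance the overall magnitude.
    # Per the app.py comment, s >= 1.2 produces visible artifacts.
    return prompt_embeds * power_scale + negative_prompt_embeds * (1.0 - power_scale)
```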
infer.py
CHANGED
@@ -17,7 +17,7 @@ model_style_type2base_model_path = {
17 |
18 |   def load_model(model_style_type="realistic", device="cuda"):
19 |       inference_config = "inference-v2.yaml"
20 | -     sd_version = "animatediff/sd"
21 |       id_ckpt = "models/animator.ckpt"
22 |       image_encoder_path = "models/image_encoder"
23 |
@@ -73,7 +73,7 @@ def load_model(model_style_type="realistic", device="cuda"):
73 |
74 |       converted_vae_checkpoint = convert_ldm_vae_checkpoint(dreambooth_state_dict, pipeline.vae.config)
75 |       # print(vae)
76 | -     #vae ->to_q,to_k,to_v
77 |       # print(converted_vae_checkpoint)
78 |       convert_vae_keys = list(converted_vae_checkpoint.keys())
79 |       for key in convert_vae_keys:
@@ -93,7 +93,8 @@ def load_model(model_style_type="realistic", device="cuda"):
93 |       pipeline.vae.load_state_dict(converted_vae_checkpoint)
94 |
95 |       converted_unet_checkpoint = convert_ldm_unet_checkpoint(dreambooth_state_dict, pipeline.unet.config)
96 | -     pipeline.unet.load_state_dict(converted_unet_checkpoint, strict=False)
97 |
98 |       pipeline.text_encoder = convert_ldm_clip_checkpoint(dreambooth_state_dict, dtype=torch.float16).to(device=device)
99 |

17 |
18 |   def load_model(model_style_type="realistic", device="cuda"):
19 |       inference_config = "inference-v2.yaml"
20 | +     sd_version = "models/animatediff/sd"
21 |       id_ckpt = "models/animator.ckpt"
22 |       image_encoder_path = "models/image_encoder"
23 |

73 |
74 |       converted_vae_checkpoint = convert_ldm_vae_checkpoint(dreambooth_state_dict, pipeline.vae.config)
75 |       # print(vae)
76 | +     # vae ->to_q, to_k, to_v
77 |       # print(converted_vae_checkpoint)
78 |       convert_vae_keys = list(converted_vae_checkpoint.keys())
79 |       for key in convert_vae_keys:

93 |       pipeline.vae.load_state_dict(converted_vae_checkpoint)
94 |
95 |       converted_unet_checkpoint = convert_ldm_unet_checkpoint(dreambooth_state_dict, pipeline.unet.config)
96 | +     m, u = pipeline.unet.load_state_dict(converted_unet_checkpoint, strict=False)
97 | +     print(f"### custom unet missing keys: {len(m)}; \n### unexpected keys: {len(u)};")
98 |
99 |       pipeline.text_encoder = convert_ldm_clip_checkpoint(dreambooth_state_dict, dtype=torch.float16).to(device=device)
100 |
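The infer.py change surfaces what `load_state_dict(strict=False)` silently tolerates: instead of raising on mismatched keys, PyTorch returns the lists of missing and unexpected keys. A small self-contained demonstration of that return value:

```python
import torch
import torch.nn as nn

net = nn.Linear(4, 2)
# A checkpoint with one key missing ("bias") and one extra key.
ckpt = {"weight": torch.zeros(2, 4), "stale_param": torch.zeros(1)}

# strict=False skips mismatches and reports them instead of raising.
result = net.load_state_dict(ckpt, strict=False)
print(f"missing keys: {result.missing_keys}")        # ['bias']
print(f"unexpected keys: {result.unexpected_keys}")  # ['stale_param']
```

Logging the counts, as the diff now does, makes it obvious when a converted checkpoint fails to cover the UNet rather than failing silently.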
requirements.txt
CHANGED
@@ -1,5 +1,5 @@
1 |   diffusers==0.29.2
2 | - torch
3 |   torchvision
4 |   imageio
5 |   imageio-ffmpeg

1 |   diffusers==0.29.2
2 | + torch==2.4.1
3 |   torchvision
4 |   imageio
5 |   imageio-ffmpeg