Files changed (5)
  1. README (4).md +13 -0
  2. app (8).py +316 -0
  3. gitattributes (8) +36 -0
  4. pulid_pipeline_flux.py +188 -0
  5. requirements (2).txt +20 -0
README (4).md ADDED
@@ -0,0 +1,13 @@
+ ---
+ title: PuLID-FLUX
+ emoji: 🤗
+ colorFrom: blue
+ colorTo: indigo
+ sdk: gradio
+ sdk_version: 4.44.0
+ app_file: app.py
+ pinned: false
+ license: apache-2.0
+ ---
+
+ Check out the configuration reference at https://huggingface.co/docs/hub/spaces-config-reference
app (8).py ADDED
@@ -0,0 +1,316 @@
+ import spaces
+ import time
+ import os
+
+ import gradio as gr
+ import torch
+ from einops import rearrange
+ from PIL import Image
+ from transformers import pipeline
+
+ from flux.cli import SamplingOptions
+ from flux.sampling import denoise, get_noise, get_schedule, prepare, unpack
+ from flux.util import load_ae, load_clip, load_flow_model, load_t5
+ from pulid.pipeline_flux import PuLIDPipeline
+ from pulid.utils import resize_numpy_image_long
+
+ NSFW_THRESHOLD = 0.85
+
+ def get_models(name: str, device: torch.device, offload: bool):
+     t5 = load_t5(device, max_length=128)
+     clip = load_clip(device)
+     model = load_flow_model(name, device="cpu" if offload else device)
+     model.eval()
+     ae = load_ae(name, device="cpu" if offload else device)
+     nsfw_classifier = pipeline("image-classification", model="Falconsai/nsfw_image_detection", device=device)
+     return model, ae, t5, clip, nsfw_classifier
+
+
+ class FluxGenerator:
+     def __init__(self):
+         self.device = torch.device('cuda')
+         self.offload = False
+         self.model_name = 'flux-dev'
+         self.model, self.ae, self.t5, self.clip, self.nsfw_classifier = get_models(
+             self.model_name,
+             device=self.device,
+             offload=self.offload,
+         )
+         self.pulid_model = PuLIDPipeline(self.model, 'cuda', weight_dtype=torch.bfloat16)
+         self.pulid_model.load_pretrain()
+
+
+ flux_generator = FluxGenerator()
+
+
+ @spaces.GPU
+ @torch.inference_mode()
+ def generate_image(
+     prompt,
+     id_image,
+     start_step,
+     guidance,
+     seed,
+     true_cfg,
+     width=896,
+     height=1152,
+     num_steps=20,
+     id_weight=1.0,
+     neg_prompt="bad quality, worst quality, text, signature, watermark, extra limbs",
+     timestep_to_start_cfg=1,
+     max_sequence_length=128,
+ ):
+     flux_generator.t5.max_length = max_sequence_length
+
+     seed = int(seed)
+     if seed == -1:
+         seed = None
+
+     opts = SamplingOptions(
+         prompt=prompt,
+         width=width,
+         height=height,
+         num_steps=num_steps,
+         guidance=guidance,
+         seed=seed,
+     )
+
+     if opts.seed is None:
+         opts.seed = torch.Generator(device="cpu").seed()
+     print(f"Generating '{opts.prompt}' with seed {opts.seed}")
+     t0 = time.perf_counter()
+
+     use_true_cfg = abs(true_cfg - 1.0) > 1e-2
+
+     if id_image is not None:
+         id_image = resize_numpy_image_long(id_image, 1024)
+         id_embeddings, uncond_id_embeddings = flux_generator.pulid_model.get_id_embedding(id_image, cal_uncond=use_true_cfg)
+     else:
+         id_embeddings = None
+         uncond_id_embeddings = None
+
+
+     # prepare input
+     x = get_noise(
+         1,
+         opts.height,
+         opts.width,
+         device=flux_generator.device,
+         dtype=torch.bfloat16,
+         seed=opts.seed,
+     )
+     timesteps = get_schedule(
+         opts.num_steps,
+         x.shape[-1] * x.shape[-2] // 4,
+         shift=True,
+     )
+
+     if flux_generator.offload:
+         flux_generator.t5, flux_generator.clip = flux_generator.t5.to(flux_generator.device), flux_generator.clip.to(flux_generator.device)
+     inp = prepare(t5=flux_generator.t5, clip=flux_generator.clip, img=x, prompt=opts.prompt)
+     inp_neg = prepare(t5=flux_generator.t5, clip=flux_generator.clip, img=x, prompt=neg_prompt) if use_true_cfg else None
+
+     # offload TEs to CPU, load model to gpu
+     if flux_generator.offload:
+         flux_generator.t5, flux_generator.clip = flux_generator.t5.cpu(), flux_generator.clip.cpu()
+         torch.cuda.empty_cache()
+         flux_generator.model = flux_generator.model.to(flux_generator.device)
+
+     # denoise initial noise
+     x = denoise(
+         flux_generator.model, **inp, timesteps=timesteps, guidance=opts.guidance, id=id_embeddings, id_weight=id_weight,
+         start_step=start_step, uncond_id=uncond_id_embeddings, true_cfg=true_cfg,
+         timestep_to_start_cfg=timestep_to_start_cfg,
+         neg_txt=inp_neg["txt"] if use_true_cfg else None,
+         neg_txt_ids=inp_neg["txt_ids"] if use_true_cfg else None,
+         neg_vec=inp_neg["vec"] if use_true_cfg else None,
+     )
+
+     # offload model, load autoencoder to gpu
+     if flux_generator.offload:
+         flux_generator.model.cpu()
+         torch.cuda.empty_cache()
+         flux_generator.ae.decoder.to(x.device)
+
+     # decode latents to pixel space
+     x = unpack(x.float(), opts.height, opts.width)
+     with torch.autocast(device_type=flux_generator.device.type, dtype=torch.bfloat16):
+         x = flux_generator.ae.decode(x)
+
+     if flux_generator.offload:
+         flux_generator.ae.decoder.cpu()
+         torch.cuda.empty_cache()
+
+     t1 = time.perf_counter()
+
+     print(f"Done in {t1 - t0:.1f}s.")
+     # bring into PIL format
+     x = x.clamp(-1, 1)
+     # x = embed_watermark(x.float())
+     x = rearrange(x[0], "c h w -> h w c")
+
+     img = Image.fromarray((127.5 * (x + 1.0)).cpu().byte().numpy())
+     nsfw_score = [x["score"] for x in flux_generator.nsfw_classifier(img) if x["label"] == "nsfw"][0]
+     if nsfw_score < NSFW_THRESHOLD:
+         return img, str(opts.seed), flux_generator.pulid_model.debug_img_list
+     else:
+         return (None, f"Your generated image may contain NSFW content (nsfw_score: {nsfw_score})",
+                 flux_generator.pulid_model.debug_img_list)
+
+ _HEADER_ = '''
+ <div style="text-align: center; max-width: 650px; margin: 0 auto;">
+ <h1 style="font-size: 2.5rem; font-weight: 700; margin-bottom: 1rem; display: contents;">PuLID for FLUX</h1>
+ <p style="font-size: 1rem; margin-bottom: 1.5rem;">Paper: <a href='https://arxiv.org/abs/2404.16022' target='_blank'>PuLID: Pure and Lightning ID Customization via Contrastive Alignment</a> | Code: <a href='https://github.com/ToTheBeginning/PuLID' target='_blank'>GitHub</a></p>
+ </div>
+
+ ❗️❗️❗️**Tips:**
+
+ - `timestep to start inserting ID:` The smaller the value, the higher the fidelity but the lower the editability; the larger the value, the lower the fidelity but the higher the editability. **The recommended range for this value is between 0 and 4.** For photorealistic scenes, we recommend 4; for stylized scenes, we recommend 0-1. If you are not satisfied with the similarity, lower this value; conversely, if you are not satisfied with the editability, increase it.
+ - `true CFG scale:` In most scenarios, we recommend using fake CFG, i.e., setting the true CFG scale to 1 and only adjusting the guidance scale. This is also more efficient. However, in a few cases, using true CFG yields better results. For more details, please refer to the [doc](https://github.com/ToTheBeginning/PuLID/blob/main/docs/pulid_for_flux.md#useful-tips).
+ - `Learn more about the model:` please refer to the <a href='https://github.com/ToTheBeginning/PuLID/blob/main/docs/pulid_for_flux.md' target='_blank'>GitHub doc</a> for more details about the model; it also provides a detailed explanation of the two parameters above.
+ - `Examples:` we provide some examples at the bottom (they are cached, so just click one to see what the model can do); you can try these example prompts first.
+ ''' # noqa E501
+
+ _CITE_ = r"""
+ If PuLID is helpful, please help to ⭐ the <a href='https://github.com/ToTheBeginning/PuLID' target='_blank'>GitHub Repo</a>. Thanks!
+ ---
+ 📧 **Contact**
+ If you have any questions or feedback, feel free to open a discussion or contact <b>wuyanze123@gmail.com</b>.
+ """ # noqa E501
+
+ _DEV_DES = '''
+ * Please refer to our repo for instructions on running the gradio demo [locally](https://github.com/ToTheBeginning/PuLID/blob/main/docs/pulid_for_flux.md#local-gradio-demo).
+ '''
+
+
+ def create_demo(args, model_name: str, device: str = "cuda" if torch.cuda.is_available() else "cpu",
+                 offload: bool = False):
+     with gr.Blocks() as demo:
+         with gr.Accordion("For Developers", open=False):
+             gr.Markdown(_DEV_DES)
+
+         gr.Markdown(_HEADER_)
+
+         with gr.Row():
+             with gr.Column():
+                 prompt = gr.Textbox(label="Prompt", value="portrait, color, cinematic")
+                 id_image = gr.Image(label="ID Image")
+                 id_weight = gr.Slider(0.0, 3.0, 1, step=0.05, label="id weight")
+
+                 width = gr.Slider(256, 1536, 896, step=16, label="Width")
+                 height = gr.Slider(256, 1536, 1152, step=16, label="Height")
+                 num_steps = gr.Slider(1, 20, 20, step=1, label="Number of steps")
+                 start_step = gr.Slider(0, 10, 0, step=1, label="timestep to start inserting ID")
+                 guidance = gr.Slider(1.0, 10.0, 4, step=0.1, label="Guidance")
+                 seed = gr.Textbox(-1, label="Seed (-1 for random)")
+                 max_sequence_length = gr.Slider(128, 512, 128, step=128,
+                                                 label="max_sequence_length for prompt (T5), small will be faster")
+
+                 with gr.Accordion("Advanced Options (True CFG: true_cfg_scale=1 means fake CFG, >1 means true CFG; if using true CFG, we recommend setting the guidance scale to 1)", open=False):  # noqa E501
+                     neg_prompt = gr.Textbox(
+                         label="Negative Prompt",
+                         value="bad quality, worst quality, text, signature, watermark, extra limbs")
+                     true_cfg = gr.Slider(1.0, 10.0, 1, step=0.1, label="true CFG scale")
+                     timestep_to_start_cfg = gr.Slider(0, 20, 1, step=1, label="timestep to start cfg", visible=args.dev)
+
+                 generate_btn = gr.Button("Generate")
+
+             with gr.Column():
+                 output_image = gr.Image(label="Generated Image", format='png')
+                 seed_output = gr.Textbox(label="Used Seed")
+                 intermediate_output = gr.Gallery(label='Output', elem_id="gallery", visible=args.dev)
+                 gr.Markdown(_CITE_)
+
+         with gr.Row(), gr.Column():
+             gr.Markdown("## Examples")
+             example_inps = [
+                 [
+                     'a woman holding sign with glowing green text \"PuLID for FLUX\"',
+                     'example_inputs/liuyifei.png',
+                     4, 4, 2680261499100305976, 1
+                 ],
+                 [
+                     'portrait, side view',
+                     'example_inputs/liuyifei.png',
+                     4, 4, 180825677246321775, 1
+                 ],
+                 [
+                     'white-haired woman with vr technology atmosphere, revolutionary exceptional magnum with remarkable details',  # noqa E501
+                     'example_inputs/liuyifei.png',
+                     4, 4, 16942328329935464989, 1
+                 ],
+                 [
+                     'a young child is eating Icecream',
+                     'example_inputs/liuyifei.png',
+                     4, 4, 4527590969012358757, 1
+                 ],
+                 [
+                     'a man is holding a sign with text \"PuLID for FLUX\", winter, snowing, top of the mountain',
+                     'example_inputs/pengwei.jpg',
+                     4, 4, 6273700647573240909, 1
+                 ],
+                 [
+                     'portrait, candle light',
+                     'example_inputs/pengwei.jpg',
+                     4, 4, 17522759474323955700, 1
+                 ],
+                 [
+                     'profile shot dark photo of a 25-year-old male with smoke escaping from his mouth, the backlit smoke gives the image an ephemeral quality, natural face, natural eyebrows, natural skin texture, award winning photo, highly detailed face, atmospheric lighting, film grain, monochrome',  # noqa E501
+                     'example_inputs/pengwei.jpg',
+                     4, 4, 17733156847328193625, 1
+                 ],
+                 [
+                     'American Comics, 1boy',
+                     'example_inputs/pengwei.jpg',
+                     1, 4, 13223174453874179686, 1
+                 ],
+                 [
+                     'portrait, pixar',
+                     'example_inputs/pengwei.jpg',
+                     1, 4, 9445036702517583939, 1
+                 ],
+             ]
+             gr.Examples(examples=example_inps, inputs=[prompt, id_image, start_step, guidance, seed, true_cfg],
+                         label='fake CFG', cache_examples='lazy', outputs=[output_image, seed_output],
+                         fn=generate_image)
+
+             example_inps = [
+                 [
+                     'portrait, made of ice sculpture',
+                     'example_inputs/lecun.jpg',
+                     1, 1, 7717391560531186077, 5
+                 ],
+             ]
+             gr.Examples(examples=example_inps, inputs=[prompt, id_image, start_step, guidance, seed, true_cfg],
+                         label='true CFG', cache_examples='lazy', outputs=[output_image, seed_output],
+                         fn=generate_image)
+
+         generate_btn.click(
+             fn=generate_image,
+             inputs=[prompt, id_image, start_step, guidance, seed, true_cfg, width, height, num_steps, id_weight,
+                     neg_prompt, timestep_to_start_cfg, max_sequence_length],
+             outputs=[output_image, seed_output, intermediate_output],
+         )
+
+     return demo
+
+
+ if __name__ == "__main__":
+     import argparse
+
+     parser = argparse.ArgumentParser(description="PuLID for FLUX.1-dev")
+     parser.add_argument("--name", type=str, default="flux-dev", choices=['flux-dev'],
+                         help="currently only support flux-dev")
+     parser.add_argument("--device", type=str, default="cuda" if torch.cuda.is_available() else "cpu",
+                         help="Device to use")
+     parser.add_argument("--offload", action="store_true", help="Offload model to CPU when not in use")
+     parser.add_argument("--port", type=int, default=8080, help="Port to use")
+     parser.add_argument("--dev", action='store_true', help="Development mode")
+     parser.add_argument("--pretrained_model", type=str, help='for development')
+     args = parser.parse_args()
+
+     import huggingface_hub
+     huggingface_hub.login(os.getenv('HF_TOKEN'))
+
+     demo = create_demo(args, args.name, args.device, args.offload)
+     demo.launch()
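
The tips in `_HEADER_` above describe how `timestep to start inserting ID` (the `start_step` argument) and the `true CFG scale` interact with the guidance scale. The snippet below is a minimal sketch, not part of the Space, of how those recommendations could map onto direct calls to `generate_image()`; it assumes the `flux` and `pulid` packages, the example images, and a CUDA GPU are available, since importing the app builds `FluxGenerator` eagerly, and that the file is saved locally as `app.py`.

```python
import numpy as np
from PIL import Image

from app import generate_image  # hypothetical local import of the file above

# Gradio passes the ID image as a numpy RGB array in [0, 255]
id_image = np.array(Image.open("example_inputs/liuyifei.png").convert("RGB"))

# fake CFG (recommended default): true_cfg=1, only the guidance scale matters
img, used_seed, debug_imgs = generate_image(
    prompt="portrait, color, cinematic",
    id_image=id_image,
    start_step=4,   # 0-4; smaller = higher ID fidelity, larger = more editable
    guidance=4,
    seed=-1,        # -1 picks a random seed
    true_cfg=1,
)

# true CFG: true_cfg > 1, and the tips recommend dropping guidance to 1
img_cfg, _, _ = generate_image(
    prompt="portrait, made of ice sculpture",
    id_image=id_image,
    start_step=1,
    guidance=1,
    seed=-1,
    true_cfg=5,
)

if img is not None:  # None means the NSFW filter triggered
    img.save("output.png")
```

On the Space itself, the same call path runs behind the Gradio button via `generate_btn.click(...)`, with `true_cfg=1` (fake CFG) as the slider default.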
gitattributes (8) ADDED
@@ -0,0 +1,36 @@
+ *.7z filter=lfs diff=lfs merge=lfs -text
+ *.arrow filter=lfs diff=lfs merge=lfs -text
+ *.bin filter=lfs diff=lfs merge=lfs -text
+ *.bz2 filter=lfs diff=lfs merge=lfs -text
+ *.ckpt filter=lfs diff=lfs merge=lfs -text
+ *.ftz filter=lfs diff=lfs merge=lfs -text
+ *.gz filter=lfs diff=lfs merge=lfs -text
+ *.h5 filter=lfs diff=lfs merge=lfs -text
+ *.joblib filter=lfs diff=lfs merge=lfs -text
+ *.lfs.* filter=lfs diff=lfs merge=lfs -text
+ *.mlmodel filter=lfs diff=lfs merge=lfs -text
+ *.model filter=lfs diff=lfs merge=lfs -text
+ *.msgpack filter=lfs diff=lfs merge=lfs -text
+ *.npy filter=lfs diff=lfs merge=lfs -text
+ *.npz filter=lfs diff=lfs merge=lfs -text
+ *.onnx filter=lfs diff=lfs merge=lfs -text
+ *.ot filter=lfs diff=lfs merge=lfs -text
+ *.parquet filter=lfs diff=lfs merge=lfs -text
+ *.pb filter=lfs diff=lfs merge=lfs -text
+ *.pickle filter=lfs diff=lfs merge=lfs -text
+ *.pkl filter=lfs diff=lfs merge=lfs -text
+ *.pt filter=lfs diff=lfs merge=lfs -text
+ *.pth filter=lfs diff=lfs merge=lfs -text
+ *.rar filter=lfs diff=lfs merge=lfs -text
+ *.safetensors filter=lfs diff=lfs merge=lfs -text
+ saved_model/**/* filter=lfs diff=lfs merge=lfs -text
+ *.tar.* filter=lfs diff=lfs merge=lfs -text
+ *.tar filter=lfs diff=lfs merge=lfs -text
+ *.tflite filter=lfs diff=lfs merge=lfs -text
+ *.tgz filter=lfs diff=lfs merge=lfs -text
+ *.wasm filter=lfs diff=lfs merge=lfs -text
+ *.xz filter=lfs diff=lfs merge=lfs -text
+ *.zip filter=lfs diff=lfs merge=lfs -text
+ *.zst filter=lfs diff=lfs merge=lfs -text
+ *tfevents* filter=lfs diff=lfs merge=lfs -text
+ example_inputs/pengwei.jpg filter=lfs diff=lfs merge=lfs -text
pulid_pipeline_flux.py ADDED
@@ -0,0 +1,188 @@
+ import gc
+
+ import cv2
+ import insightface
+ import torch
+ import torch.nn as nn
+ from basicsr.utils import img2tensor, tensor2img
+ from facexlib.parsing import init_parsing_model
+ from facexlib.utils.face_restoration_helper import FaceRestoreHelper
+ from huggingface_hub import hf_hub_download, snapshot_download
+ from insightface.app import FaceAnalysis
+ from safetensors.torch import load_file
+ from torchvision.transforms import InterpolationMode
+ from torchvision.transforms.functional import normalize, resize
+
+ from eva_clip import create_model_and_transforms
+ from eva_clip.constants import OPENAI_DATASET_MEAN, OPENAI_DATASET_STD
+ from pulid.encoders_flux import IDFormer, PerceiverAttentionCA
+
+
+ class PuLIDPipeline(nn.Module):
+     def __init__(self, dit, device, weight_dtype=torch.bfloat16, *args, **kwargs):
+         super().__init__()
+         self.device = device
+         self.weight_dtype = weight_dtype
+         double_interval = 2
+         single_interval = 4
+
+         # init encoder
+         self.pulid_encoder = IDFormer().to(self.device, self.weight_dtype)
+
+         num_ca = 19 // double_interval + 38 // single_interval
+         if 19 % double_interval != 0:
+             num_ca += 1
+         if 38 % single_interval != 0:
+             num_ca += 1
+         self.pulid_ca = nn.ModuleList([
+             PerceiverAttentionCA().to(self.device, self.weight_dtype) for _ in range(num_ca)
+         ])
+
+         dit.pulid_ca = self.pulid_ca
+         dit.pulid_double_interval = double_interval
+         dit.pulid_single_interval = single_interval
+
+         # preprocessors
+         # face align and parsing
+         self.face_helper = FaceRestoreHelper(
+             upscale_factor=1,
+             face_size=512,
+             crop_ratio=(1, 1),
+             det_model='retinaface_resnet50',
+             save_ext='png',
+             device=self.device,
+         )
+         self.face_helper.face_parse = None
+         self.face_helper.face_parse = init_parsing_model(model_name='bisenet', device=self.device)
+         # clip-vit backbone
+         model, _, _ = create_model_and_transforms('EVA02-CLIP-L-14-336', 'eva_clip', force_custom_clip=True)
+         model = model.visual
+         self.clip_vision_model = model.to(self.device, dtype=self.weight_dtype)
+         eva_transform_mean = getattr(self.clip_vision_model, 'image_mean', OPENAI_DATASET_MEAN)
+         eva_transform_std = getattr(self.clip_vision_model, 'image_std', OPENAI_DATASET_STD)
+         if not isinstance(eva_transform_mean, (list, tuple)):
+             eva_transform_mean = (eva_transform_mean,) * 3
+         if not isinstance(eva_transform_std, (list, tuple)):
+             eva_transform_std = (eva_transform_std,) * 3
+         self.eva_transform_mean = eva_transform_mean
+         self.eva_transform_std = eva_transform_std
+         # antelopev2
+         snapshot_download('DIAMONIK7777/antelopev2', local_dir='models/antelopev2')
+         self.app = FaceAnalysis(
+             name='antelopev2', root='.', providers=['CPUExecutionProvider']
+         )
+         self.app.prepare(ctx_id=0, det_size=(640, 640))
+         self.handler_ante = insightface.model_zoo.get_model('models/antelopev2/glintr100.onnx', providers=['CPUExecutionProvider'])
+         self.handler_ante.prepare(ctx_id=0)
+
+         gc.collect()
+         torch.cuda.empty_cache()
+
+         # self.load_pretrain()
+
+         # other configs
+         self.debug_img_list = []
+
+     def load_pretrain(self, pretrain_path=None):
+         hf_hub_download('guozinan/PuLID', 'pulid_flux_v0.9.0.safetensors', local_dir='models')
+         ckpt_path = 'models/pulid_flux_v0.9.0.safetensors'
+         if pretrain_path is not None:
+             ckpt_path = pretrain_path
+         state_dict = load_file(ckpt_path)
+         state_dict_dict = {}
+         for k, v in state_dict.items():
+             module = k.split('.')[0]
+             state_dict_dict.setdefault(module, {})
+             new_k = k[len(module) + 1:]
+             state_dict_dict[module][new_k] = v
+
+         for module in state_dict_dict:
+             print(f'loading from {module}')
+             getattr(self, module).load_state_dict(state_dict_dict[module], strict=True)
+
+         del state_dict
+         del state_dict_dict
+
+     def to_gray(self, img):
+         x = 0.299 * img[:, 0:1] + 0.587 * img[:, 1:2] + 0.114 * img[:, 2:3]
+         x = x.repeat(1, 3, 1, 1)
+         return x
+
+     def get_id_embedding(self, image, cal_uncond=False):
+         """
+         Args:
+             image: numpy rgb image, range [0, 255]
+         """
+         self.face_helper.clean_all()
+         self.debug_img_list = []
+         image_bgr = cv2.cvtColor(image, cv2.COLOR_RGB2BGR)
+         # get antelopev2 embedding
+         # for k in self.app.models.keys():
+         #     self.app.models[k].session.set_providers(['CUDAExecutionProvider'])
+         face_info = self.app.get(image_bgr)
+         if len(face_info) > 0:
+             face_info = sorted(face_info, key=lambda x: (x['bbox'][2] - x['bbox'][0]) * (x['bbox'][3] - x['bbox'][1]))[
+                 -1
+             ]  # only use the largest face
+             id_ante_embedding = face_info['embedding']
+             self.debug_img_list.append(
+                 image[
+                     int(face_info['bbox'][1]) : int(face_info['bbox'][3]),
+                     int(face_info['bbox'][0]) : int(face_info['bbox'][2]),
+                 ]
+             )
+         else:
+             id_ante_embedding = None
+
+         # use facexlib to detect and align the face
+         self.face_helper.read_image(image_bgr)
+         self.face_helper.get_face_landmarks_5(only_center_face=True)
+         self.face_helper.align_warp_face()
+         if len(self.face_helper.cropped_faces) == 0:
+             raise RuntimeError('facexlib failed to align the face')
+         align_face = self.face_helper.cropped_faces[0]
+         # in case insightface did not detect a face
+         if id_ante_embedding is None:
+             print('failed to detect a face with insightface, extracting the embedding from the aligned face instead')
+             # self.handler_ante.session.set_providers(['CUDAExecutionProvider'])
+             id_ante_embedding = self.handler_ante.get_feat(align_face)
+
+         id_ante_embedding = torch.from_numpy(id_ante_embedding).to(self.device, self.weight_dtype)
+         if id_ante_embedding.ndim == 1:
+             id_ante_embedding = id_ante_embedding.unsqueeze(0)
+
+         # parsing
+         input = img2tensor(align_face, bgr2rgb=True).unsqueeze(0) / 255.0
+         input = input.to(self.device)
+         parsing_out = self.face_helper.face_parse(normalize(input, [0.485, 0.456, 0.406], [0.229, 0.224, 0.225]))[0]
+         parsing_out = parsing_out.argmax(dim=1, keepdim=True)
+         bg_label = [0, 16, 18, 7, 8, 9, 14, 15]
+         bg = sum(parsing_out == i for i in bg_label).bool()
+         white_image = torch.ones_like(input)
+         # only keep the face features
+         face_features_image = torch.where(bg, white_image, self.to_gray(input))
+         self.debug_img_list.append(tensor2img(face_features_image, rgb2bgr=False))
+
+         # transform img before sending to eva-clip-vit
+         face_features_image = resize(face_features_image, self.clip_vision_model.image_size, InterpolationMode.BICUBIC)
+         face_features_image = normalize(face_features_image, self.eva_transform_mean, self.eva_transform_std)
+         id_cond_vit, id_vit_hidden = self.clip_vision_model(
+             face_features_image.to(self.weight_dtype), return_all_features=False, return_hidden=True, shuffle=False
+         )
+         id_cond_vit_norm = torch.norm(id_cond_vit, 2, 1, True)
+         id_cond_vit = torch.div(id_cond_vit, id_cond_vit_norm)
+
+         id_cond = torch.cat([id_ante_embedding, id_cond_vit], dim=-1)
+
+         id_embedding = self.pulid_encoder(id_cond, id_vit_hidden)
+
+         if not cal_uncond:
+             return id_embedding, None
+
+         id_uncond = torch.zeros_like(id_cond)
+         id_vit_hidden_uncond = []
+         for layer_idx in range(0, len(id_vit_hidden)):
+             id_vit_hidden_uncond.append(torch.zeros_like(id_vit_hidden[layer_idx]))
+         uncond_id_embedding = self.pulid_encoder(id_uncond, id_vit_hidden_uncond)
+
+         return id_embedding, uncond_id_embedding
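
`load_pretrain()` above regroups the flat safetensors state dict by the prefix before the first dot, routing `pulid_encoder.*` keys to the IDFormer and `pulid_ca.*` keys to the list of `PerceiverAttentionCA` modules. A small self-contained sketch of that regrouping is shown below; the key names are hypothetical placeholders for illustration, not the real checkpoint keys.

```python
# Sketch of the regrouping done in load_pretrain(): the prefix before the first
# '.' picks the submodule, the remainder becomes the key inside its state_dict.
flat_state_dict = {
    "pulid_encoder.proj_out.weight": "tensor-A",  # placeholder values instead of tensors
    "pulid_ca.0.to_q.weight": "tensor-B",
    "pulid_ca.19.to_kv.weight": "tensor-C",
}

grouped = {}
for k, v in flat_state_dict.items():
    module = k.split('.')[0]       # e.g. 'pulid_ca'
    new_k = k[len(module) + 1:]    # e.g. '0.to_q.weight'
    grouped.setdefault(module, {})[new_k] = v

print(sorted(grouped))       # ['pulid_ca', 'pulid_encoder']
print(grouped["pulid_ca"])   # {'0.to_q.weight': 'tensor-B', '19.to_kv.weight': 'tensor-C'}
```

Each grouped sub-dict is then loaded with `getattr(self, module).load_state_dict(..., strict=True)`, so an unexpected or missing key fails loudly rather than being silently skipped.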
requirements (2).txt ADDED
@@ -0,0 +1,20 @@
+ diffusers==0.25.0
+ torch==2.1.0
+ torchvision==0.16.0
+ transformers==4.43.3
+ opencv-python
+ einops
+ ftfy
+ basicsr
+ facexlib
+ insightface
+ onnx==1.13.1
+ onnxruntime-gpu
+ onnxruntime==1.14.1
+ accelerate
+ huggingface-hub
+ timm
+ SentencePiece
+ fire
+ safetensors
+ numpy==1.24.1