BestWishYsh committed on
Commit
c32f190
·
verified ·
1 Parent(s): 6163efe

Upload 57 files

Browse files
This view is limited to 50 files because it contains too many changes.   See raw diff
Files changed (50) hide show
  1. .gitattributes +1 -0
  2. app.py +369 -0
  3. asserts/example_images/1.png +3 -0
  4. asserts/example_images/2.png +0 -0
  5. asserts/example_images/3.png +0 -0
  6. models/eva_clip/__init__.py +11 -0
  7. models/eva_clip/bpe_simple_vocab_16e6.txt.gz +3 -0
  8. models/eva_clip/constants.py +2 -0
  9. models/eva_clip/eva_vit_model.py +548 -0
  10. models/eva_clip/factory.py +517 -0
  11. models/eva_clip/hf_configs.py +57 -0
  12. models/eva_clip/hf_model.py +248 -0
  13. models/eva_clip/loss.py +138 -0
  14. models/eva_clip/model.py +439 -0
  15. models/eva_clip/model_configs/EVA01-CLIP-B-16.json +19 -0
  16. models/eva_clip/model_configs/EVA01-CLIP-g-14-plus.json +24 -0
  17. models/eva_clip/model_configs/EVA01-CLIP-g-14.json +24 -0
  18. models/eva_clip/model_configs/EVA02-CLIP-B-16.json +29 -0
  19. models/eva_clip/model_configs/EVA02-CLIP-L-14-336.json +29 -0
  20. models/eva_clip/model_configs/EVA02-CLIP-L-14.json +29 -0
  21. models/eva_clip/model_configs/EVA02-CLIP-bigE-14-plus.json +25 -0
  22. models/eva_clip/model_configs/EVA02-CLIP-bigE-14.json +25 -0
  23. models/eva_clip/modified_resnet.py +188 -0
  24. models/eva_clip/openai.py +144 -0
  25. models/eva_clip/pretrained.py +332 -0
  26. models/eva_clip/rope.py +137 -0
  27. models/eva_clip/timm_model.py +122 -0
  28. models/eva_clip/tokenizer.py +201 -0
  29. models/eva_clip/transform.py +103 -0
  30. models/eva_clip/transformer.py +737 -0
  31. models/eva_clip/utils.py +326 -0
  32. models/eva_clip/utils_qformer.py +166 -0
  33. models/local_facial_extractor.py +269 -0
  34. models/pipeline_cogvideox.py +748 -0
  35. models/pipeline_consisid.py +894 -0
  36. models/transformer_consisid.py +697 -0
  37. models/utils.py +273 -0
  38. requirements.txt +36 -0
  39. util/dataloader.py +1010 -0
  40. util/deepspeed_configs/accelerate_config_machine_multi.yaml +18 -0
  41. util/deepspeed_configs/accelerate_config_machine_single.yaml +13 -0
  42. util/deepspeed_configs/hostfile.txt +2 -0
  43. util/deepspeed_configs/zero_stage2_config.json +17 -0
  44. util/rife/IFNet.py +123 -0
  45. util/rife/IFNet_2R.py +123 -0
  46. util/rife/IFNet_HDv3.py +138 -0
  47. util/rife/IFNet_m.py +127 -0
  48. util/rife/RIFE.py +95 -0
  49. util/rife/RIFE_HDv3.py +86 -0
  50. util/rife/__init__.py +0 -0
.gitattributes CHANGED
@@ -33,3 +33,4 @@ saved_model/**/* filter=lfs diff=lfs merge=lfs -text
33
  *.zip filter=lfs diff=lfs merge=lfs -text
34
  *.zst filter=lfs diff=lfs merge=lfs -text
35
  *tfevents* filter=lfs diff=lfs merge=lfs -text
 
 
33
  *.zip filter=lfs diff=lfs merge=lfs -text
34
  *.zst filter=lfs diff=lfs merge=lfs -text
35
  *tfevents* filter=lfs diff=lfs merge=lfs -text
36
+ asserts/example_images/1.png filter=lfs diff=lfs merge=lfs -text
app.py ADDED
@@ -0,0 +1,369 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ import math
3
+ import time
4
+ import numpy
5
+ import random
6
+ import threading
7
+ import gradio as gr
8
+ from PIL import Image, ImageOps
9
+ from moviepy import VideoFileClip
10
+ from datetime import datetime, timedelta
11
+ from huggingface_hub import hf_hub_download, snapshot_download
12
+
13
+ import insightface
14
+ from insightface.app import FaceAnalysis
15
+ from facexlib.parsing import init_parsing_model
16
+ from facexlib.utils.face_restoration_helper import FaceRestoreHelper
17
+
18
+ import torch
19
+ from diffusers import CogVideoXDPMScheduler
20
+ from diffusers.utils import load_image
21
+ from diffusers.image_processor import VaeImageProcessor
22
+ from diffusers.training_utils import free_memory
23
+
24
+ from util.utils import *
25
+ from util.rife_model import load_rife_model, rife_inference_with_latents
26
+ from models.utils import process_face_embeddings
27
+ from models.transformer_consisid import ConsisIDTransformer3DModel
28
+ from models.pipeline_consisid import ConsisIDPipeline
29
+ from models.eva_clip import create_model_and_transforms
30
+ from models.eva_clip.constants import OPENAI_DATASET_MEAN, OPENAI_DATASET_STD
31
+ from models.eva_clip.utils_qformer import resize_numpy_image_long
32
+
33
# ---------------------------------------------------------------------------
# Module-level setup: download auxiliary weights, load every model once, and
# build the generation pipeline that the Gradio callbacks below reuse.
# ---------------------------------------------------------------------------
device = "cuda" if torch.cuda.is_available() else "cpu"

# Weights for the optional post-processing steps (super-resolution, frame interpolation).
hf_hub_download(repo_id="ai-forever/Real-ESRGAN", filename="RealESRGAN_x4.pth", local_dir="model_real_esran")
snapshot_download(repo_id="AlexWortega/RIFE", local_dir="model_rife")

model_path = "BestWishYsh/ConsisID-preview"
lora_path = None
lora_rank = 128
dtype = torch.bfloat16

# NOTE(review): model_path is passed to from_pretrained() below but is also
# joined into local filesystem paths (transformer_ema, face_encoder/...).
# Nothing in this section downloads the repository to that directory --
# confirm the checkpoint is present on disk under this path.
face_encoder_dir = os.path.join(model_path, "face_encoder")

# Prefer the EMA transformer weights when that subfolder exists locally.
if os.path.exists(os.path.join(model_path, "transformer_ema")):
    subfolder = "transformer_ema"
else:
    subfolder = "transformer"

transformer = ConsisIDTransformer3DModel.from_pretrained_cus(model_path, subfolder=subfolder)
scheduler = CogVideoXDPMScheduler.from_pretrained(model_path, subfolder="scheduler")

# Configs without an `is_kps` entry mean key-point conditioning is disabled.
# A targeted attribute lookup replaces the former bare `except:`, which would
# also have hidden unrelated failures.
is_kps = getattr(getattr(transformer, "config", None), "is_kps", False)

# 1. load face helper models
face_helper = FaceRestoreHelper(
    upscale_factor=1,
    face_size=512,
    crop_ratio=(1, 1),
    det_model='retinaface_resnet50',
    save_ext='png',
    device=device,
    model_rootpath=face_encoder_dir,
)
# Drop the reference to whatever parser the helper created before loading the
# BiSeNet one (presumably so the old model can be freed first -- TODO confirm).
face_helper.face_parse = None
face_helper.face_parse = init_parsing_model(model_name='bisenet', device=device, model_rootpath=face_encoder_dir)
face_helper.face_det.eval()
face_helper.face_parse.eval()

# 2. EVA-CLIP visual tower used for face features.
model, _, _ = create_model_and_transforms(
    'EVA02-CLIP-L-14-336',
    os.path.join(model_path, "face_encoder", "EVA02_CLIP_L_336_psz14_s6B.pt"),
    force_custom_clip=True,
)
face_clip_model = model.visual
face_clip_model.eval()

# Normalisation statistics: use the model's own if it exposes them, otherwise
# the OpenAI CLIP defaults; scalars are expanded to one value per channel.
eva_transform_mean = getattr(face_clip_model, 'image_mean', OPENAI_DATASET_MEAN)
eva_transform_std = getattr(face_clip_model, 'image_std', OPENAI_DATASET_STD)
if not isinstance(eva_transform_mean, (list, tuple)):
    eva_transform_mean = (eva_transform_mean,) * 3
if not isinstance(eva_transform_std, (list, tuple)):
    eva_transform_std = (eva_transform_std,) * 3

# 3. InsightFace detector and recognition model.
face_main_model = FaceAnalysis(name='antelopev2', root=face_encoder_dir, providers=['CUDAExecutionProvider'])
handler_ante = insightface.model_zoo.get_model(f'{model_path}/face_encoder/models/antelopev2/glintr100.onnx', providers=['CUDAExecutionProvider'])
face_main_model.prepare(ctx_id=0, det_size=(640, 640))
handler_ante.prepare(ctx_id=0)

face_clip_model.to(device, dtype=dtype)
face_helper.face_det.to(device)
face_helper.face_parse.to(device)
transformer.to(device, dtype=dtype)
free_memory()

# 4. Assemble the generation pipeline around the pre-loaded transformer and scheduler.
pipe = ConsisIDPipeline.from_pretrained(model_path, transformer=transformer, scheduler=scheduler, torch_dtype=dtype)
# If you're using with lora, add this code
if lora_path:
    pipe.load_lora_weights(lora_path, weight_name="pytorch_lora_weights.safetensors", adapter_name="test_1")
    pipe.fuse_lora(lora_scale=1 / lora_rank)

# Rebuild the scheduler from its config, replacing learned variance types with
# "fixed_small" when a variance_type entry is present.
scheduler_args = {}
if "variance_type" in pipe.scheduler.config:
    variance_type = pipe.scheduler.config.variance_type
    if variance_type in ["learned", "learned_range"]:
        variance_type = "fixed_small"
    scheduler_args["variance_type"] = variance_type

pipe.scheduler = CogVideoXDPMScheduler.from_config(pipe.scheduler.config, **scheduler_args)
pipe.to(device)

# Working directories; the background cleanup thread prunes both.
os.makedirs("./output", exist_ok=True)
os.makedirs("./gradio_tmp", exist_ok=True)

upscale_model = load_sd_upscale("model_real_esran/RealESRGAN_x4.pth", device)
frame_interpolation_model = load_rife_model("model_rife")
116
+
117
+
118
def infer(
    prompt: str,
    image_input: "numpy.ndarray",
    num_inference_steps: int,
    guidance_scale: float,
    seed: int = 42,
    progress=gr.Progress(track_tqdm=True),
):
    """Generate a video that keeps the identity of the face in ``image_input``.

    Args:
        prompt: Text prompt; leading/trailing double quotes are stripped.
        image_input: Reference image as an array accepted by
            ``Image.fromarray`` (what the ``gr.Image`` input delivers --
            presumably H x W x C uint8; TODO confirm).
        num_inference_steps: Forwarded to the pipeline unchanged.
        guidance_scale: Forwarded to the pipeline unchanged.
        seed: RNG seed. ``-1`` (or ``None``, e.g. a cleared number field)
            draws a random seed. Non-integer values are truncated to ``int``.
        progress: Gradio progress tracker; not used directly in this function.

    Returns:
        ``(frames, seed)``: the pipeline's ``.frames`` output (requested with
        ``output_type="pt"``) and the integer seed that was actually used.

    Raises:
        gr.Error: If no image was supplied.
    """
    if image_input is None:
        # Fail with a readable UI message instead of an obscure PIL error.
        raise gr.Error("Please upload an input image that contains a clear face.")

    # gr.Number may deliver a float or None; torch's manual_seed needs an int.
    if seed is None or int(seed) == -1:
        # NOTE(review): only 256 distinct random seeds -- widen if more variety is wanted.
        seed = random.randint(0, 2**8 - 1)
    seed = int(seed)

    # `numpy` is the name this file imports explicitly; `np` was only reachable
    # through the wildcard import from util.utils.
    id_image = numpy.array(ImageOps.exif_transpose(Image.fromarray(image_input)).convert("RGB"))
    id_image = resize_numpy_image_long(id_image, 1024)
    id_cond, id_vit_hidden, align_crop_face_image, face_kps = process_face_embeddings(
        face_helper, face_clip_model, handler_ante,
        eva_transform_mean, eva_transform_std,
        face_main_model, device, dtype, id_image,
        original_id_image=id_image, is_align_face=True,
        cal_uncond=False,
    )

    # Key points are only passed on when the transformer was configured for them.
    kps_cond = face_kps if is_kps else None

    # Turn the aligned face crop into a PIL image for the pipeline:
    # squeeze, move the first remaining axis last, scale by 255, cast to uint8.
    face = align_crop_face_image.cpu().detach().squeeze().permute(1, 2, 0)
    face = (face.numpy() * 255).astype(numpy.uint8)
    image = ImageOps.exif_transpose(Image.fromarray(face))

    prompt = prompt.strip('"')

    # Always seed the generator: the former `if seed else None` left seed 0
    # unseeded even though 0 was reported back as the seed used.
    generator = torch.Generator(device).manual_seed(seed)

    video_pt = pipe(
        prompt=prompt,
        image=image,
        num_videos_per_prompt=1,
        num_inference_steps=num_inference_steps,
        num_frames=49,
        use_dynamic_cfg=False,
        guidance_scale=guidance_scale,
        generator=generator,
        id_vit_hidden=id_vit_hidden,
        id_cond=id_cond,
        kps_cond=kps_cond,
        output_type="pt",
    ).frames

    free_memory()
    return (video_pt, seed)
170
+
171
+
172
def convert_to_gif(video_path):
    """Write an 8 fps GIF next to ``video_path`` and return the GIF's path.

    Only the file extension is swapped for ``.gif``. The previous
    ``str.replace(".mp4", ".gif")`` rewrote every occurrence in the path and
    returned the input path unchanged when the extension was not exactly
    ``.mp4``, so the GIF would have been written over the video.

    Args:
        video_path: Path of the source video file.

    Returns:
        Path of the written GIF.
    """
    gif_path = os.path.splitext(video_path)[0] + ".gif"
    clip = VideoFileClip(video_path)
    try:
        clip.write_gif(gif_path, fps=8)
    finally:
        # Close the clip even if encoding fails so its resources are not leaked.
        clip.close()
    return gif_path
177
+
178
+
179
def _remove_stale_files(directories, cutoff_ts):
    """Delete regular files in *directories* last modified before *cutoff_ts* (epoch seconds)."""
    for directory in directories:
        try:
            filenames = os.listdir(directory)
        except OSError:
            # Directory missing or unreadable: nothing to clean this round.
            continue
        for filename in filenames:
            file_path = os.path.join(directory, filename)
            try:
                if os.path.isfile(file_path) and os.path.getmtime(file_path) < cutoff_ts:
                    os.remove(file_path)
            except OSError:
                # Best-effort cleanup: the file may have vanished or be in use.
                # Skipping it keeps the cleanup thread alive.
                continue


def delete_old_files():
    """Run forever, pruning files older than 10 minutes every 10 minutes.

    Scans ``./output`` and ``./gradio_tmp``. Filesystem errors are tolerated
    per file and per directory; previously any ``OSError`` ended the thread
    running this loop and cleanup stopped for good. Ages are compared as
    epoch seconds rather than naive local datetimes.
    """
    max_age_seconds = 10 * 60
    directories = ["./output", "./gradio_tmp"]
    while True:
        _remove_stale_files(directories, time.time() - max_age_seconds)
        time.sleep(600)
193
+
194
+
195
+ threading.Thread(target=delete_old_files, daemon=True).start()
196
+ examples_images = [
197
+ ["asserts/example_images/1.png", "A woman adorned with a delicate flower crown, is standing amidst a field of gently swaying wildflowers. Her eyes sparkle with a serene gaze, and a faint smile graces her lips, suggesting a moment of peaceful contentment. The shot is framed from the waist up, highlighting the gentle breeze lightly tousling her hair. The background reveals an expansive meadow under a bright blue sky, capturing the tranquility of a sunny afternoon."],
198
+ ["asserts/example_images/2.png", "The video captures a boy walking along a city street, filmed in black and white on a classic 35mm camera. His expression is thoughtful, his brow slightly furrowed as if he's lost in contemplation. The film grain adds a textured, timeless quality to the image, evoking a sense of nostalgia. Around him, the cityscape is filled with vintage buildings, cobblestone sidewalks, and softly blurred figures passing by, their outlines faint and indistinct. Streetlights cast a gentle glow, while shadows play across the boy's path, adding depth to the scene. The lighting highlights the boy's subtle smile, hinting at a fleeting moment of curiosity. The overall cinematic atmosphere, complete with classic film still aesthetics and dramatic contrasts, gives the scene an evocative and introspective feel."],
199
+ ["asserts/example_images/3.png", "The video depicts a man sitting at an office desk, engaged in his work. He is dressed in a formal suit and appears to be focused on his computer screen. The office environment is well-organized, with shelves filled with binders and other office supplies neatly arranged. The man is holding a red cup, possibly containing a beverage, which he drinks from before setting it down on the desk. He then proceeds to type on the keyboard, indicating that he is working on something on his computer. The overall atmosphere of the video suggests a professional setting where the man is diligently working on his tasks."]
200
+ ]
201
+
202
+ with gr.Blocks() as demo:
203
+ gr.Markdown("""
204
+ <div style="text-align: center; font-size: 32px; font-weight: bold; margin-bottom: 20px;">
205
+ ConsisID Space🤗
206
+ </div>
207
+ <div style="text-align: center;">
208
+ <a href="https://huggingface.co/BestWishYsh/ConsisID">🤗 Model Hub</a> |
209
+ <a href="https://huggingface.co/datasets/BestWishYsh/ConsisID-preview-Data">📚 Dataset</a> |
210
+ <a href="https://github.com/PKU-YuanGroup/ConsisID">🌐 Github</a> |
211
+ <a href="https://pku-yuangroup.github.io/ConsisID">📝 Page</a> |
212
+ <a href="https://arxiv.org/pdf/2408.06072">📜 arxiv </a>
213
+ </div>
214
+ <div style="text-align: center;display: flex;justify-content: center;align-items: center;margin-top: 1em;margin-bottom: .5em;">
215
+ <span>If the Space is too busy, duplicate it to use privately</span>
216
+ <a href="https://huggingface.co/spaces/BestWishYsh/ConsisID-Space?duplicate=true"><img src="https://huggingface.co/datasets/huggingface/badges/resolve/main/duplicate-this-space-lg.svg" width="160" style="
217
+ margin-left: .75em;
218
+ "></a>
219
+ </div>
220
+ <div style="text-align: center; font-size: 15px; font-weight: bold; color: red; margin-bottom: 20px;">
221
+ ⚠️ This demo is for academic research and experiential use only.
222
+ </div>
223
+ """)
224
+ with gr.Row():
225
+ with gr.Column():
226
+ with gr.Accordion("IPT2V: Face Input", open=True):
227
+ image_input = gr.Image(label="Input Image (should contain clear face)")
228
+ prompt = gr.Textbox(label="Prompt (Less than 200 Words)", placeholder="Enter your prompt here", lines=5)
229
+ with gr.Accordion("Examples", open=False):
230
+ examples_component_images = gr.Examples(
231
+ examples_images,
232
+ inputs=[image_input, prompt],
233
+ cache_examples=False,
234
+ )
235
+
236
+ with gr.Group():
237
+ with gr.Column():
238
+ with gr.Row():
239
+ seed_param = gr.Number(
240
+ label="Inference Seed (Enter a positive number, -1 for random)", value=42
241
+ )
242
+ with gr.Row():
243
+ enable_scale = gr.Checkbox(label="Super-Resolution (720 × 480 -> 2880 × 1920)", value=False)
244
+ enable_rife = gr.Checkbox(label="Frame Interpolation (8fps -> 16fps)", value=False)
245
+ gr.Markdown(
246
+ "✨In this demo, we use [RIFE](https://github.com/hzwer/ECCV2022-RIFE) for frame interpolation and [Real-ESRGAN](https://github.com/xinntao/Real-ESRGAN) for upscaling(Super-Resolution)."
247
+ )
248
+
249
+ generate_button = gr.Button("🎬 Generate Video")
250
+
251
+ with gr.Column():
252
+ video_output = gr.Video(label="ConsisID Generate Video", width=720, height=480)
253
+ with gr.Row():
254
+ download_video_button = gr.File(label="📥 Download Video", visible=False)
255
+ download_gif_button = gr.File(label="📥 Download GIF", visible=False)
256
+ seed_text = gr.Number(label="Seed Used for Video Generation", visible=False)
257
+
258
+ gr.Markdown("""
259
+ <table border="0" style="width: 100%; text-align: left; margin-top: 20px;">
260
+ <div style="text-align: center; font-size: 32px; font-weight: bold; margin-bottom: 20px;">
261
+ 🎥 Video Gallery
262
+ </div>
263
+ <tr>
264
+ <td style="width: 25%; vertical-align: top; font-size: 0.9em;">
265
+ <p>The video features a woman in exquisite hybrid armor adorned with iridescent gemstones, standing amidst gently falling cherry blossoms. Her piercing yet serene gaze hints at quiet determination, as a breeze catches a loose strand of her hair. She stands in a tranquil courtyard framed by moss-covered stone walls and wooden arches, with blossoms casting soft shadows on the ground. The petals swirl around her, adding a dreamlike quality, while the blurred backdrop emphasizes her poised figure. The scene conveys elegance, strength, and tranquil readiness, capturing a moment of peace before an upcoming challenge.</p>
266
+ </td>
267
+ <td style="width: 25%; vertical-align: top;">
268
+ <video src="https://github.com/user-attachments/assets/97fa0710-4f14-4a6d-b6f7-f3a2e9f7486e" width="100%" controls autoplay loop></video>
269
+ </td>
270
+ <td style="width: 25%; vertical-align: top; font-size: 0.9em;">
271
+ <p>The video features a baby wearing a bright superhero cape, standing confidently with arms raised in a powerful pose. The baby has a determined look on their face, with eyes wide and lips pursed in concentration, as if ready to take on a challenge. The setting appears playful, with colorful toys scattered around and a soft rug underfoot, while sunlight streams through a nearby window, highlighting the fluttering cape and adding to the impression of heroism. The overall atmosphere is lighthearted and fun, with the baby's expressions capturing a mix of innocence and an adorable attempt at bravery, as if truly ready to save the day.</p>
272
+ </td>
273
+ <td style="width: 25%; vertical-align: top;">
274
+ <video src="https://github.com/user-attachments/assets/90b547a3-247c-4bb0-abae-ba53483b7b6e" width="100%" controls autoplay loop></video>
275
+ </td>
276
+ </tr>
277
+ <tr>
278
+ <td style="width: 25%; vertical-align: top; font-size: 0.9em;">
279
+ <p>The video features a man standing next to an airplane, engaged in a conversation on his cell phone. he is wearing sunglasses and a black top, and he appears to be talking seriously. The airplane has a green stripe running along its side, and there is a large engine visible behind his. The man seems to be standing near the entrance of the airplane, possibly preparing to board or just having disembarked. The setting suggests that he might be at an airport or a private airfield. The overall atmosphere of the video is professional and focused, with the man's attire and the presence of the airplane indicating a business or travel context.</p>
280
+ </td>
281
+ <td style="width: 25%; vertical-align: top;">
282
+ <video src="https://github.com/user-attachments/assets/55680c58-de86-48b4-8d86-e9906a3185c3" width="100%" controls autoplay loop></video>
283
+ </td>
284
+ <td style="width: 25%; vertical-align: top; font-size: 0.9em;">
285
+ <p>The video features a woman with blonde hair standing on a beach near the water's edge. She is wearing a black swimsuit and appears to be enjoying her time by the sea. The sky above is clear with some clouds, and the ocean waves gently lap against the shore. The woman seems to be holding something white in her hand, possibly a piece of driftwood or a small object found on the beach. The overall atmosphere of the video is serene and relaxing, capturing the beauty of nature and the simple pleasure of being by the ocean.</p>
286
+ </td>
287
+ <td style="width: 25%; vertical-align: top;">
288
+ <video src="https://github.com/user-attachments/assets/8d06e702-f80e-4cb2-abc2-b0f519ec3f11" width="100%" controls autoplay loop></video>
289
+ </td>
290
+ </tr>
291
+ <tr>
292
+ <td style="width: 25%; vertical-align: top; font-size: 0.9em;">
293
+ <p>The video features a man sitting in a red armchair, enjoying a cup of coffee or tea. he is dressed in a light-colored outfit and has long dark-haired hair. The setting appears to be indoors, with large windows providing a view of a misty or foggy coastal landscape outside. The room has a modern design with geometric structures visible in the background. There is a small round table next to the armchair, also holding a cup. The overall atmosphere suggests a calm and serene moment, possibly during a cold or rainy day by the sea.</p>
294
+ </td>
295
+ <td style="width: 25%; vertical-align: top;">
296
+ <video src="https://github.com/user-attachments/assets/ab9c655e-84c2-47ed-85d9-039a7f64adfe" width="100%" controls autoplay loop></video>
297
+ </td>
298
+ <td style="width: 25%; vertical-align: top; font-size: 0.9em;">
299
+ <p>The video shows a young boy sitting at a table, eating a piece of food. He appears to be enjoying his meal, as he takes a bite and chews it. The boy is wearing a blue shirt and has short hair. The background is dark, with some light coming from the left side of the frame. There is a straw visible on the right side of the frame, suggesting that there may be a drink next to the boy's plate. The overall atmosphere of the video seems casual and relaxed, with the focus on the boy's enjoyment of his food.</p>
300
+ </td>
301
+ <td style="width: 25%; vertical-align: top;">
302
+ <video src="https://github.com/user-attachments/assets/8014b02e-e1c4-4df7-b7f3-cebfb01fa373" width="100%" controls autoplay loop></video>
303
+ </td>
304
+ </tr>
305
+ <tr>
306
+ <td style="width: 25%; vertical-align: top; font-size: 0.9em;">
307
+ <p>The video captures a boy walking along a city street, filmed in black and white on a classic 35mm camera. His expression is thoughtful, his brow slightly furrowed as if he's lost in contemplation. The film grain adds a textured, timeless quality to the image, evoking a sense of nostalgia. Around him, the cityscape is filled with vintage buildings, cobblestone sidewalks, and softly blurred figures passing by, their outlines faint and indistinct. Streetlights cast a gentle glow, while shadows play across the boy's path, adding depth to the scene. The lighting highlights the boy's subtle smile, hinting at a fleeting moment of curiosity. The overall cinematic atmosphere, complete with classic film still aesthetics and dramatic contrasts, gives the scene an evocative and introspective feel.</p>
308
+ </td>
309
+ <td style="width: 25%; vertical-align: top;">
310
+ <video src="https://github.com/user-attachments/assets/e4bc3169-d3d4-46e2-a667-8b456ead9465" width="100%" controls autoplay loop></video>
311
+ </td>
312
+ <td style="width: 25%; vertical-align: top; font-size: 0.9em;">
313
+ <p>The video features a young man standing outdoors in a snowy park. he is wearing a colorful winter jacket with a floral pattern and a white knit hat. The background shows a snowy landscape with trees, benches, and a metal fence. The ground is covered in snow, and there is a light snowfall in the air. The man appears to be enjoying the winter weather, as he smiles and gives a thumbs-up gesture towards the camera. The overall atmosphere of the video is cheerful and festive, capturing the beauty of a snowy day in a park.</p>
314
+ </td>
315
+ <td style="width: 25%; vertical-align: top;">
316
+ <video src="https://github.com/user-attachments/assets/e4e3e519-95d4-44e0-afa7-9a833f99e090" width="100%" controls autoplay loop></video>
317
+ </td>
318
+ </tr>
319
+ </table>
320
+ """)
321
+
322
def generate(
    prompt,
    image_input,
    seed_value,
    scale_status,
    rife_status,
    progress=gr.Progress(track_tqdm=True)
):
    """Gradio click handler: run inference, post-process, and save the outputs.

    Calls ``infer`` with 50 steps and guidance scale 7.0, optionally applies
    the upscaler (``scale_status``) and frame interpolation (``rife_status``),
    converts every batch entry to PIL frames, then saves the first entry as a
    video and a GIF.

    Returns:
        ``(video_path, video_update, gif_update, seed_update)`` -- the video
        path for the player plus three ``gr.update`` objects that reveal the
        download widgets and the seed field.
    """
    frames_pt, used_seed = infer(
        prompt,
        image_input,
        num_inference_steps=50,
        guidance_scale=7.0,
        seed=seed_value,
        progress=progress,
    )
    if scale_status:
        frames_pt = upscale_batch_and_concatenate(upscale_model, frames_pt, device)
    if rife_status:
        frames_pt = rife_inference_with_latents(frame_interpolation_model, frames_pt)

    def _to_pil_frames(video):
        # Re-stack the entries along the first axis, then tensor -> numpy -> PIL.
        restacked = torch.stack([video[k] for k in range(video.shape[0])])
        return VaeImageProcessor.numpy_to_pil(VaeImageProcessor.pt_to_numpy(restacked))

    all_videos = [_to_pil_frames(frames_pt[idx]) for idx in range(frames_pt.shape[0])]

    # Only the first batch entry is written out.
    first_video = all_videos[0]
    video_path = save_video(first_video, fps=math.ceil((len(first_video) - 1) / 6))
    gif_path = convert_to_gif(video_path)

    return (
        video_path,
        gr.update(visible=True, value=video_path),
        gr.update(visible=True, value=gif_path),
        gr.update(visible=True, value=used_seed),
    )
360
+
361
+ generate_button.click(
362
+ generate,
363
+ inputs=[prompt, image_input, seed_param, enable_scale, enable_rife],
364
+ outputs=[video_output, download_video_button, download_gif_button, seed_text],
365
+ )
366
+
367
+ if __name__ == "__main__":
368
+ demo.queue(max_size=15)
369
+ demo.launch()
asserts/example_images/1.png ADDED

Git LFS Details

  • SHA256: 434856739faadf1c89bc38d8b940fcdbe027595de89645f721d8585fc2fe2459
  • Pointer size: 132 Bytes
  • Size of remote file: 1.64 MB
asserts/example_images/2.png ADDED
asserts/example_images/3.png ADDED
models/eva_clip/__init__.py ADDED
@@ -0,0 +1,11 @@
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from .constants import OPENAI_DATASET_MEAN, OPENAI_DATASET_STD
2
+ from .factory import create_model, create_model_and_transforms, create_model_from_pretrained, get_tokenizer, create_transforms
3
+ from .factory import list_models, add_model_config, get_model_config, load_checkpoint
4
+ from .loss import ClipLoss
5
+ from .model import CLIP, CustomCLIP, CLIPTextCfg, CLIPVisionCfg,\
6
+ convert_weights_to_lp, convert_weights_to_fp16, trace_model, get_cast_dtype
7
+ from .openai import load_openai_model, list_openai_models
8
+ from .pretrained import list_pretrained, list_pretrained_models_by_tag, list_pretrained_tags_by_model,\
9
+ get_pretrained_url, download_pretrained_from_url, is_pretrained_cfg, get_pretrained_cfg, download_pretrained
10
+ from .tokenizer import SimpleTokenizer, tokenize
11
+ from .transform import image_transform
models/eva_clip/bpe_simple_vocab_16e6.txt.gz ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:924691ac288e54409236115652ad4aa250f48203de50a9e4722a6ecd48d6804a
3
+ size 1356917
models/eva_clip/constants.py ADDED
@@ -0,0 +1,2 @@
 
 
 
1
+ OPENAI_DATASET_MEAN = (0.48145466, 0.4578275, 0.40821073)
2
+ OPENAI_DATASET_STD = (0.26862954, 0.26130258, 0.27577711)
models/eva_clip/eva_vit_model.py ADDED
@@ -0,0 +1,548 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # --------------------------------------------------------
2
+ # Adapted from https://github.com/microsoft/unilm/tree/master/beit
3
+ # --------------------------------------------------------
4
+ import math
5
+ import os
6
+ from functools import partial
7
+ import torch
8
+ import torch.nn as nn
9
+ import torch.nn.functional as F
10
+ try:
11
+ from timm.models.layers import drop_path, to_2tuple, trunc_normal_
12
+ except:
13
+ from timm.layers import drop_path, to_2tuple, trunc_normal_
14
+
15
+ from .transformer import PatchDropout
16
+ from .rope import VisionRotaryEmbedding, VisionRotaryEmbeddingFast
17
+
18
+ if os.getenv('ENV_TYPE') == 'deepspeed':
19
+ try:
20
+ from deepspeed.runtime.activation_checkpointing.checkpointing import checkpoint
21
+ except:
22
+ from torch.utils.checkpoint import checkpoint
23
+ else:
24
+ from torch.utils.checkpoint import checkpoint
25
+
26
+ try:
27
+ import xformers
28
+ import xformers.ops as xops
29
+ XFORMERS_IS_AVAILBLE = True
30
+ except:
31
+ XFORMERS_IS_AVAILBLE = False
32
+
33
class DropPath(nn.Module):
    """Stochastic-depth wrapper: applies ``drop_path`` per sample on the main
    path of a residual block.
    """

    def __init__(self, drop_prob=None):
        super().__init__()
        # Probability handed to drop_path(); stored as given (may be None).
        self.drop_prob = drop_prob

    def forward(self, x):
        # self.training is passed through so drop_path sees the module's current mode.
        dropped = drop_path(x, self.drop_prob, self.training)
        return dropped

    def extra_repr(self) -> str:
        # Rendered inside the module's repr().
        return f'p={self.drop_prob}'
45
+
46
+
47
class Mlp(nn.Module):
    """Feed-forward block: Linear -> activation -> optional norm -> Linear -> dropout.

    Args:
        in_features: Width of the input.
        hidden_features: Width of the hidden layer; falsy values fall back to
            ``in_features``.
        out_features: Width of the output; falsy values fall back to
            ``in_features``.
        act_layer: Zero-argument factory for the activation module.
        norm_layer: Factory called with the hidden width when ``subln`` is set.
        drop: Dropout probability applied after the second linear layer.
        subln: If true, normalise the hidden activations before ``fc2``;
            otherwise ``ffn_ln`` is an identity.
    """

    def __init__(
            self,
            in_features,
            hidden_features=None,
            out_features=None,
            act_layer=nn.GELU,
            norm_layer=nn.LayerNorm,
            drop=0.,
            subln=False,
    ):
        super().__init__()
        if not out_features:
            out_features = in_features
        if not hidden_features:
            hidden_features = in_features

        # Submodules are created in this order: fc1, act, ffn_ln, fc2, drop.
        self.fc1 = nn.Linear(in_features, hidden_features)
        self.act = act_layer()
        if subln:
            self.ffn_ln = norm_layer(hidden_features)
        else:
            self.ffn_ln = nn.Identity()
        self.fc2 = nn.Linear(hidden_features, out_features)
        self.drop = nn.Dropout(drop)

    def forward(self, x):
        hidden = self.act(self.fc1(x))
        # No dropout after the activation; it is applied once, after fc2.
        hidden = self.ffn_ln(hidden)
        return self.drop(self.fc2(hidden))
80
+
81
class SwiGLU(nn.Module):
    """Gated feed-forward block: ``w3(norm(act(w1(x)) * w2(x)))`` followed by dropout.

    Args:
        in_features: Width of the input.
        hidden_features: Width of both gate branches; falsy values fall back
            to ``in_features``.
        out_features: Width of the output; falsy values fall back to
            ``in_features``.
        act_layer: Zero-argument factory for the gate activation.
        drop: Dropout probability applied to the output.
        norm_layer: Factory called with the hidden width when ``subln`` is set.
        subln: If true, normalise the gated product before ``w3``; otherwise
            ``ffn_ln`` is an identity.
    """

    def __init__(self, in_features, hidden_features=None, out_features=None, act_layer=nn.SiLU, drop=0.,
                 norm_layer=nn.LayerNorm, subln=False):
        super().__init__()
        if not out_features:
            out_features = in_features
        if not hidden_features:
            hidden_features = in_features

        # Two parallel input projections: w1 feeds the activation, w2 is the multiplier.
        self.w1 = nn.Linear(in_features, hidden_features)
        self.w2 = nn.Linear(in_features, hidden_features)

        self.act = act_layer()
        if subln:
            self.ffn_ln = norm_layer(hidden_features)
        else:
            self.ffn_ln = nn.Identity()
        self.w3 = nn.Linear(hidden_features, out_features)

        self.drop = nn.Dropout(drop)

    def forward(self, x):
        activated = self.act(self.w1(x))
        gated = activated * self.w2(x)
        projected = self.w3(self.ffn_ln(gated))
        return self.drop(projected)
105
+
106
class Attention(nn.Module):
    """Multi-head self-attention with optional extras selected at construction.

    * ``subln``: separate bias-free ``q_proj``/``k_proj``/``v_proj`` layers and
      a ``norm_layer`` (``inner_attn_ln``) in front of the output projection;
      otherwise one fused ``qkv`` projection and an identity are used.
    * ``qkv_bias``: learnable biases for queries and values (keys never
      receive a bias).
    * ``window_size``: a learnable relative position bias table for a
      ``window_size[0] x window_size[1]`` token grid plus one leading cls
      token.
    * ``rope``: callable applied to the queries and keys of every token except
      the first one.
    * ``xattn``: compute attention with ``xops.memory_efficient_attention``.
      That path does not use ``rel_pos_bias``, ``attn_mask`` or the relative
      position bias table.
    """

    def __init__(
            self, dim, num_heads=8, qkv_bias=False, qk_scale=None, attn_drop=0.,
            proj_drop=0., window_size=None, attn_head_dim=None, xattn=False, rope=None, subln=False,
            norm_layer=nn.LayerNorm):
        super().__init__()
        self.num_heads = num_heads
        head_dim = dim // num_heads
        if attn_head_dim is not None:
            head_dim = attn_head_dim
        all_head_dim = head_dim * self.num_heads
        self.scale = qk_scale or head_dim ** -0.5

        self.subln = subln
        if self.subln:
            self.q_proj = nn.Linear(dim, all_head_dim, bias=False)
            self.k_proj = nn.Linear(dim, all_head_dim, bias=False)
            self.v_proj = nn.Linear(dim, all_head_dim, bias=False)
        else:
            self.qkv = nn.Linear(dim, all_head_dim * 3, bias=False)

        # Biases live outside the Linear layers so that keys can stay bias-free.
        if qkv_bias:
            self.q_bias = nn.Parameter(torch.zeros(all_head_dim))
            self.v_bias = nn.Parameter(torch.zeros(all_head_dim))
        else:
            self.q_bias = None
            self.v_bias = None

        if window_size:
            self.window_size = window_size
            # The 3 extra entries are: cls to token, token to cls, cls to cls.
            self.num_relative_distance = (2 * window_size[0] - 1) * (2 * window_size[1] - 1) + 3
            self.relative_position_bias_table = nn.Parameter(
                torch.zeros(self.num_relative_distance, num_heads))  # 2*Wh-1 * 2*Ww-1, nH

            # get pair-wise relative position index for each token inside the window
            coords_h = torch.arange(window_size[0])
            coords_w = torch.arange(window_size[1])
            # "ij" is what the implicit default produced; stating it avoids the deprecation warning.
            coords = torch.stack(torch.meshgrid([coords_h, coords_w], indexing="ij"))  # 2, Wh, Ww
            coords_flatten = torch.flatten(coords, 1)  # 2, Wh*Ww
            relative_coords = coords_flatten[:, :, None] - coords_flatten[:, None, :]  # 2, Wh*Ww, Wh*Ww
            relative_coords = relative_coords.permute(1, 2, 0).contiguous()  # Wh*Ww, Wh*Ww, 2
            relative_coords[:, :, 0] += window_size[0] - 1  # shift to start from 0
            relative_coords[:, :, 1] += window_size[1] - 1
            relative_coords[:, :, 0] *= 2 * window_size[1] - 1
            relative_position_index = \
                torch.zeros(size=(window_size[0] * window_size[1] + 1,) * 2, dtype=relative_coords.dtype)
            relative_position_index[1:, 1:] = relative_coords.sum(-1)  # Wh*Ww, Wh*Ww
            relative_position_index[0, 0:] = self.num_relative_distance - 3
            relative_position_index[0:, 0] = self.num_relative_distance - 2
            relative_position_index[0, 0] = self.num_relative_distance - 1

            self.register_buffer("relative_position_index", relative_position_index)
        else:
            self.window_size = None
            self.relative_position_bias_table = None
            self.relative_position_index = None

        self.attn_drop = nn.Dropout(attn_drop)
        self.inner_attn_ln = norm_layer(all_head_dim) if subln else nn.Identity()
        self.proj = nn.Linear(all_head_dim, dim)
        self.proj_drop = nn.Dropout(proj_drop)
        self.xattn = xattn
        self.xattn_drop = attn_drop

        self.rope = rope

    def forward(self, x, rel_pos_bias=None, attn_mask=None):
        """Attend over the tokens of ``x``.

        Args:
            x: tensor unpacked as ``(B, N, C)``.
            rel_pos_bias: optional bias added to the attention logits
                (non-xattn path only).
            attn_mask: optional per-token mask, converted to bool and indexed
                as ``[:, None, None, :]``; positions that are false are filled
                with ``-inf`` before the softmax (non-xattn path only).

        Returns:
            Tensor produced by the output projection and its dropout.
        """
        B, N, _ = x.shape
        if self.subln:
            q = F.linear(input=x, weight=self.q_proj.weight, bias=self.q_bias)
            k = F.linear(input=x, weight=self.k_proj.weight, bias=None)
            v = F.linear(input=x, weight=self.v_proj.weight, bias=self.v_bias)

            q = q.reshape(B, N, self.num_heads, -1).permute(0, 2, 1, 3)  # B, num_heads, N, C
            k = k.reshape(B, N, self.num_heads, -1).permute(0, 2, 1, 3)
            v = v.reshape(B, N, self.num_heads, -1).permute(0, 2, 1, 3)
        else:
            qkv_bias = None
            if self.q_bias is not None:
                # Zero block in the middle keeps the key projection bias-free.
                qkv_bias = torch.cat((self.q_bias, torch.zeros_like(self.v_bias, requires_grad=False), self.v_bias))

            qkv = F.linear(input=x, weight=self.qkv.weight, bias=qkv_bias)
            qkv = qkv.reshape(B, N, 3, self.num_heads, -1).permute(2, 0, 3, 1, 4)  # 3, B, num_heads, N, C
            q, k, v = qkv[0], qkv[1], qkv[2]

        if self.rope:
            # The first token is passed through untouched; only the remaining tokens are rotated.
            ro_q_t = self.rope(q[:, :, 1:, :])
            q = torch.cat((q[:, :, :1, :], ro_q_t), -2).type_as(v)

            ro_k_t = self.rope(k[:, :, 1:, :])
            k = torch.cat((k[:, :, :1, :], ro_k_t), -2).type_as(v)

        if self.xattn:
            q = q.permute(0, 2, 1, 3)  # B, num_heads, N, C -> B, N, num_heads, C
            k = k.permute(0, 2, 1, 3)
            v = v.permute(0, 2, 1, 3)

            x = xops.memory_efficient_attention(
                q, k, v,
                # The kernel has no notion of train/eval mode, so dropout must be
                # switched off explicitly outside of training.
                p=self.xattn_drop if self.training else 0.,
                scale=self.scale,
            )
            x = x.reshape(B, N, -1)
        else:
            q = q * self.scale
            attn = (q @ k.transpose(-2, -1))

            if self.relative_position_bias_table is not None:
                relative_position_bias = \
                    self.relative_position_bias_table[self.relative_position_index.view(-1)].view(
                        self.window_size[0] * self.window_size[1] + 1,
                        self.window_size[0] * self.window_size[1] + 1, -1)  # Wh*Ww,Wh*Ww,nH
                relative_position_bias = relative_position_bias.permute(2, 0, 1).contiguous()  # nH, Wh*Ww, Wh*Ww
                attn = attn + relative_position_bias.unsqueeze(0).type_as(attn)

            if rel_pos_bias is not None:
                attn = attn + rel_pos_bias.type_as(attn)

            if attn_mask is not None:
                attn_mask = attn_mask.bool()
                attn = attn.masked_fill(~attn_mask[:, None, None, :], float("-inf"))

            attn = attn.softmax(dim=-1)
            attn = self.attn_drop(attn)

            x = (attn @ v).transpose(1, 2).reshape(B, N, -1)

        # Output tail shared by both attention implementations.
        x = self.inner_attn_ln(x)
        x = self.proj(x)
        x = self.proj_drop(x)
        return x
244
+
245
+
246
class Block(nn.Module):
    """Transformer block: an attention sub-layer and an MLP sub-layer, each residual.

    ``postnorm`` chooses whether ``norm1``/``norm2`` run on the sub-layer
    output (true) or on its input (false). A positive ``init_values`` adds the
    learnable per-channel scales ``gamma_1``/``gamma_2`` on the two branches.
    ``naiveswiglu`` selects ``SwiGLU`` instead of ``Mlp`` for the MLP sub-layer.
    """

    def __init__(self, dim, num_heads, mlp_ratio=4., qkv_bias=False, qk_scale=None, drop=0., attn_drop=0.,
                 drop_path=0., init_values=None, act_layer=nn.GELU, norm_layer=nn.LayerNorm,
                 window_size=None, attn_head_dim=None, xattn=False, rope=None, postnorm=False,
                 subln=False, naiveswiglu=False):
        super().__init__()
        self.norm1 = norm_layer(dim)
        self.attn = Attention(
            dim,
            num_heads=num_heads,
            qkv_bias=qkv_bias,
            qk_scale=qk_scale,
            attn_drop=attn_drop,
            proj_drop=drop,
            window_size=window_size,
            attn_head_dim=attn_head_dim,
            xattn=xattn,
            rope=rope,
            subln=subln,
            norm_layer=norm_layer,
        )
        # Stochastic depth on the residual branches; identity when disabled.
        if drop_path > 0.:
            self.drop_path = DropPath(drop_path)
        else:
            self.drop_path = nn.Identity()
        self.norm2 = norm_layer(dim)
        hidden_width = int(dim * mlp_ratio)

        if naiveswiglu:
            self.mlp = SwiGLU(in_features=dim, hidden_features=hidden_width, subln=subln, norm_layer=norm_layer)
        else:
            self.mlp = Mlp(in_features=dim, hidden_features=hidden_width, act_layer=act_layer, subln=subln,
                           drop=drop)

        use_layer_scale = init_values is not None and init_values > 0
        if use_layer_scale:
            self.gamma_1 = nn.Parameter(init_values * torch.ones((dim)), requires_grad=True)
            self.gamma_2 = nn.Parameter(init_values * torch.ones((dim)), requires_grad=True)
        else:
            self.gamma_1 = None
            self.gamma_2 = None

        self.postnorm = postnorm

    def forward(self, x, rel_pos_bias=None, attn_mask=None):
        """Run the attention and MLP sub-layers, adding each result to ``x``."""
        scaled = self.gamma_1 is not None

        # Attention sub-layer.
        if self.postnorm:
            branch = self.norm1(self.attn(x, rel_pos_bias=rel_pos_bias, attn_mask=attn_mask))
        else:
            branch = self.attn(self.norm1(x), rel_pos_bias=rel_pos_bias, attn_mask=attn_mask)
        if scaled:
            branch = self.gamma_1 * branch
        x = x + self.drop_path(branch)

        # MLP sub-layer.
        if self.postnorm:
            branch = self.norm2(self.mlp(x))
        else:
            branch = self.mlp(self.norm2(x))
        if scaled:
            branch = self.gamma_2 * branch
        x = x + self.drop_path(branch)
        return x
303
+
304
+
305
class PatchEmbed(nn.Module):
    """Image to patch embedding.

    A convolution whose kernel size and stride both equal ``patch_size`` maps
    the image to one ``embed_dim`` vector per patch.
    """

    def __init__(self, img_size=224, patch_size=16, in_chans=3, embed_dim=768):
        super().__init__()
        img_size = to_2tuple(img_size)
        patch_size = to_2tuple(patch_size)
        rows = img_size[0] // patch_size[0]
        cols = img_size[1] // patch_size[1]
        self.patch_shape = (rows, cols)
        self.img_size = img_size
        self.patch_size = patch_size
        self.num_patches = cols * rows

        self.proj = nn.Conv2d(in_chans, embed_dim, kernel_size=patch_size, stride=patch_size)

    def forward(self, x, **kwargs):
        """Embed ``x`` (unpacked as ``B, C, H, W``); extra keyword arguments are ignored."""
        B, C, H, W = x.shape
        # FIXME look at relaxing size constraints
        assert H == self.img_size[0] and W == self.img_size[1], \
            f"Input image size ({H}*{W}) doesn't match model ({self.img_size[0]}*{self.img_size[1]})."
        embedded = self.proj(x)
        # Flatten the spatial grid and move it in front of the channel axis.
        return embedded.flatten(2).transpose(1, 2)
327
+
328
+
329
class RelativePositionBias(nn.Module):
    """Learnable relative position bias for a token grid plus one leading cls token.

    Args:
        window_size: pair ``(Wh, Ww)`` giving the token grid size.
        num_heads: number of bias values stored per relative offset.

    The bias table has ``(2*Wh-1) * (2*Ww-1) + 3`` rows; the three extra rows
    are used for cls-to-token, token-to-cls and cls-to-cls.
    """

    def __init__(self, window_size, num_heads):
        super().__init__()
        self.window_size = window_size
        self.num_relative_distance = (2 * window_size[0] - 1) * (2 * window_size[1] - 1) + 3
        self.relative_position_bias_table = nn.Parameter(
            torch.zeros(self.num_relative_distance, num_heads))  # 2*Wh-1 * 2*Ww-1, nH

        # get pair-wise relative position index for each token inside the window
        coords_h = torch.arange(window_size[0])
        coords_w = torch.arange(window_size[1])
        # "ij" is what the implicit default produced; stating it avoids the deprecation warning.
        coords = torch.stack(torch.meshgrid([coords_h, coords_w], indexing="ij"))  # 2, Wh, Ww
        coords_flatten = torch.flatten(coords, 1)  # 2, Wh*Ww
        relative_coords = coords_flatten[:, :, None] - coords_flatten[:, None, :]  # 2, Wh*Ww, Wh*Ww
        relative_coords = relative_coords.permute(1, 2, 0).contiguous()  # Wh*Ww, Wh*Ww, 2
        relative_coords[:, :, 0] += window_size[0] - 1  # shift to start from 0
        relative_coords[:, :, 1] += window_size[1] - 1
        relative_coords[:, :, 0] *= 2 * window_size[1] - 1  # row-major flattening of the 2-D offset

        num_tokens = window_size[0] * window_size[1] + 1
        relative_position_index = torch.zeros(size=(num_tokens,) * 2, dtype=relative_coords.dtype)
        relative_position_index[1:, 1:] = relative_coords.sum(-1)  # Wh*Ww, Wh*Ww
        relative_position_index[0, 0:] = self.num_relative_distance - 3
        relative_position_index[0:, 0] = self.num_relative_distance - 2
        relative_position_index[0, 0] = self.num_relative_distance - 1

        self.register_buffer("relative_position_index", relative_position_index)

    def forward(self):
        """Return the bias gathered from the table, laid out as ``(nH, Wh*Ww+1, Wh*Ww+1)``."""
        num_tokens = self.window_size[0] * self.window_size[1] + 1
        relative_position_bias = \
            self.relative_position_bias_table[self.relative_position_index.view(-1)].view(
                num_tokens, num_tokens, -1)  # Wh*Ww+1, Wh*Ww+1, nH
        return relative_position_bias.permute(2, 0, 1).contiguous()  # nH, Wh*Ww+1, Wh*Ww+1
364
+
365
+
366
class EVAVisionTransformer(nn.Module):
    """Vision transformer: patch embedding, a cls token, a stack of ``Block`` layers and a head.

    Constructor switches visible in this class:

    * ``use_abs_pos_emb``: learnable absolute position embedding.
    * ``use_shared_rel_pos_bias``: one ``RelativePositionBias`` shared by all
      blocks; ``use_rel_pos_bias``: a bias table inside every block.
    * ``rope``: build a ``VisionRotaryEmbeddingFast`` and hand it to every block.
    * ``use_mean_pooling``: pool by averaging all tokens and applying
      ``fc_norm``; otherwise the final ``norm`` is applied and the first token
      is used.
    * ``xattn``: forced to false when ``XFORMERS_IS_AVAILBLE`` is false.
    * ``patch_dropout``: ``PatchDropout`` when positive, identity otherwise.
    """

    def __init__(self, img_size=224, patch_size=16, in_chans=3, num_classes=1000, embed_dim=768, depth=12,
                 num_heads=12, mlp_ratio=4., qkv_bias=False, qk_scale=None, drop_rate=0., attn_drop_rate=0.,
                 drop_path_rate=0., norm_layer=nn.LayerNorm, init_values=None, patch_dropout=0.,
                 use_abs_pos_emb=True, use_rel_pos_bias=False, use_shared_rel_pos_bias=False, rope=False,
                 use_mean_pooling=True, init_scale=0.001, grad_checkpointing=False, xattn=False, postnorm=False,
                 pt_hw_seq_len=16, intp_freq=False, naiveswiglu=False, subln=False):
        super().__init__()

        if not XFORMERS_IS_AVAILBLE:
            xattn = False

        self.image_size = img_size
        self.num_classes = num_classes
        self.num_features = self.embed_dim = embed_dim  # num_features for consistency with other models

        self.patch_embed = PatchEmbed(
            img_size=img_size, patch_size=patch_size, in_chans=in_chans, embed_dim=embed_dim)
        num_patches = self.patch_embed.num_patches

        self.cls_token = nn.Parameter(torch.zeros(1, 1, embed_dim))
        if use_abs_pos_emb:
            # One extra position for the cls token.
            self.pos_embed = nn.Parameter(torch.zeros(1, num_patches + 1, embed_dim))
        else:
            self.pos_embed = None
        self.pos_drop = nn.Dropout(p=drop_rate)

        if use_shared_rel_pos_bias:
            self.rel_pos_bias = RelativePositionBias(window_size=self.patch_embed.patch_shape, num_heads=num_heads)
        else:
            self.rel_pos_bias = None

        if rope:
            half_head_dim = embed_dim // num_heads // 2
            hw_seq_len = img_size // patch_size
            self.rope = VisionRotaryEmbeddingFast(
                dim=half_head_dim,
                pt_seq_len=pt_hw_seq_len,
                ft_seq_len=hw_seq_len if intp_freq else None,
            )
        else:
            self.rope = None

        self.naiveswiglu = naiveswiglu

        dpr = [x.item() for x in torch.linspace(0, drop_path_rate, depth)]  # stochastic depth decay rule
        self.use_rel_pos_bias = use_rel_pos_bias
        self.blocks = nn.ModuleList([
            Block(
                dim=embed_dim, num_heads=num_heads, mlp_ratio=mlp_ratio, qkv_bias=qkv_bias, qk_scale=qk_scale,
                drop=drop_rate, attn_drop=attn_drop_rate, drop_path=dpr[i], norm_layer=norm_layer,
                init_values=init_values, window_size=self.patch_embed.patch_shape if use_rel_pos_bias else None,
                xattn=xattn, rope=self.rope, postnorm=postnorm, subln=subln, naiveswiglu=naiveswiglu)
            for i in range(depth)])
        self.norm = nn.Identity() if use_mean_pooling else norm_layer(embed_dim)
        self.fc_norm = norm_layer(embed_dim) if use_mean_pooling else None
        self.head = nn.Linear(embed_dim, num_classes) if num_classes > 0 else nn.Identity()

        if self.pos_embed is not None:
            trunc_normal_(self.pos_embed, std=.02)

        trunc_normal_(self.cls_token, std=.02)

        self.apply(self._init_weights)
        self.fix_init_weight()

        if isinstance(self.head, nn.Linear):
            trunc_normal_(self.head.weight, std=.02)
            self.head.weight.data.mul_(init_scale)
            self.head.bias.data.mul_(init_scale)

        # setting a patch_dropout of 0. would mean it is disabled and this function would be the identity fn
        self.patch_dropout = PatchDropout(patch_dropout) if patch_dropout > 0. else nn.Identity()

        self.grad_checkpointing = grad_checkpointing

    def _mlp_out_proj(self, block):
        """Return the output projection of ``block``'s MLP (``w3`` for SwiGLU, ``fc2`` otherwise)."""
        return block.mlp.w3 if self.naiveswiglu else block.mlp.fc2

    def fix_init_weight(self):
        """Divide each block's output projections by ``sqrt(2 * layer_number)`` (1-based)."""
        def rescale(param, layer_id):
            param.div_(math.sqrt(2.0 * layer_id))

        for layer_id, layer in enumerate(self.blocks):
            rescale(layer.attn.proj.weight.data, layer_id + 1)
            rescale(self._mlp_out_proj(layer).weight.data, layer_id + 1)

    def get_cast_dtype(self) -> torch.dtype:
        """Return the dtype of the first block's MLP output projection weight."""
        # SwiGLU blocks have no ``fc2``; pick the projection that actually exists.
        return self._mlp_out_proj(self.blocks[0]).weight.dtype

    def _init_weights(self, m):
        """Initialise Linear layers (truncated normal, zero bias) and LayerNorm (weight 1, bias 0)."""
        if isinstance(m, nn.Linear):
            trunc_normal_(m.weight, std=.02)
            if m.bias is not None:
                nn.init.constant_(m.bias, 0)
        elif isinstance(m, nn.LayerNorm):
            nn.init.constant_(m.bias, 0)
            nn.init.constant_(m.weight, 1.0)

    def get_num_layers(self):
        """Return the number of transformer blocks."""
        return len(self.blocks)

    def lock(self, unlocked_groups=0, freeze_bn_stats=False):
        """Disable gradients for every parameter; ``freeze_bn_stats`` is accepted but unused."""
        assert unlocked_groups == 0, 'partial locking not currently supported for this model'
        for param in self.parameters():
            param.requires_grad = False

    @torch.jit.ignore
    def set_grad_checkpointing(self, enable=True):
        """Turn gradient checkpointing of the blocks on or off."""
        self.grad_checkpointing = enable

    @torch.jit.ignore
    def no_weight_decay(self):
        """Return the parameter names reported as exempt from weight decay."""
        return {'pos_embed', 'cls_token'}

    def get_classifier(self):
        """Return the classification head module."""
        return self.head

    def reset_classifier(self, num_classes, global_pool=''):
        """Replace the head with a new Linear layer, or an identity when ``num_classes`` is not positive."""
        self.num_classes = num_classes
        self.head = nn.Linear(self.embed_dim, num_classes) if num_classes > 0 else nn.Identity()

    def forward_features(self, x, return_all_features=False, return_hidden=False, shuffle=False):
        """Encode an image batch.

        Args:
            x: input passed to the patch embedding.
            return_all_features: return the token sequence from the last block
                as a bare tensor (no final norm, no pooling).
            return_hidden: collect the input of blocks 4, 8, 12, 16 and 20.
            shuffle: randomly permute the patch position embeddings (the cls
                position stays first); reads ``self.pos_embed``.

        Returns:
            ``x`` alone when ``return_all_features`` is true, otherwise the
            tuple ``(pooled, hidden_states)``.
        """
        x = self.patch_embed(x)
        batch_size, seq_len, _ = x.size()

        if shuffle:
            idx = torch.randperm(x.shape[1]) + 1
            zero = torch.LongTensor([0, ])
            idx = torch.cat([zero, idx])
            pos_embed = self.pos_embed[:, idx]

        cls_tokens = self.cls_token.expand(batch_size, -1, -1)  # stole cls_tokens impl from Phil Wang, thanks
        x = torch.cat((cls_tokens, x), dim=1)
        if shuffle:
            x = x + pos_embed
        elif self.pos_embed is not None:
            x = x + self.pos_embed
        x = self.pos_drop(x)

        # a patch_dropout of 0. would mean it is disabled and this function would do nothing but return what was passed in
        # The RoPE switch is read from the process environment, not from this instance.
        if os.getenv('RoPE') == '1':
            if self.training and not isinstance(self.patch_dropout, nn.Identity):
                x, patch_indices_keep = self.patch_dropout(x)
                self.rope.forward = partial(self.rope.forward, patch_indices_keep=patch_indices_keep)
            else:
                self.rope.forward = partial(self.rope.forward, patch_indices_keep=None)
                x = self.patch_dropout(x)
        else:
            x = self.patch_dropout(x)

        rel_pos_bias = self.rel_pos_bias() if self.rel_pos_bias is not None else None
        hidden_states = []
        for idx, blk in enumerate(self.blocks):
            if (0 < idx <= 20) and (idx % 4 == 0) and return_hidden:
                hidden_states.append(x)
            if self.grad_checkpointing:
                # Pass the bias itself: wrapping it in a tuple would reach the block as a
                # non-None tuple and break the bias addition in the attention layer.
                x = checkpoint(blk, x, rel_pos_bias)
            else:
                x = blk(x, rel_pos_bias=rel_pos_bias)

        if not return_all_features:
            x = self.norm(x)
            if self.fc_norm is not None:
                return self.fc_norm(x.mean(1)), hidden_states
            else:
                return x[:, 0], hidden_states
        return x

    def forward(self, x, return_all_features=False, return_hidden=False, shuffle=False):
        """Run the encoder and the head.

        Returns the raw token features when ``return_all_features`` is true
        (the head is skipped), ``(head_output, hidden_states)`` when
        ``return_hidden`` is true, and ``head_output`` otherwise.
        """
        if return_all_features:
            return self.forward_features(x, return_all_features, return_hidden, shuffle)
        x, hidden_states = self.forward_features(x, return_all_features, return_hidden, shuffle)
        x = self.head(x)
        if return_hidden:
            return x, hidden_states
        return x
models/eva_clip/factory.py ADDED
@@ -0,0 +1,517 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import json
2
+ import logging
3
+ import os
4
+ import pathlib
5
+ import re
6
+ from copy import deepcopy
7
+ from pathlib import Path
8
+ from typing import Optional, Tuple, Union, Dict, Any
9
+ import torch
10
+
11
+ from .constants import OPENAI_DATASET_MEAN, OPENAI_DATASET_STD
12
+ from .model import CLIP, CustomCLIP, convert_weights_to_lp, convert_to_custom_text_state_dict,\
13
+ get_cast_dtype
14
+ from .openai import load_openai_model
15
+ from .pretrained import is_pretrained_cfg, get_pretrained_cfg, download_pretrained, list_pretrained_tags_by_model
16
+ from .transform import image_transform
17
+ from .tokenizer import HFTokenizer, tokenize
18
+ from .utils import resize_clip_pos_embed, resize_evaclip_pos_embed, resize_visual_pos_embed, resize_eva_pos_embed
19
+
20
+
21
# Locations (directories or single .json files) scanned for model architecture configs.
_MODEL_CONFIG_PATHS = [Path(__file__).parent / "model_configs/"]
_MODEL_CONFIGS = {}  # dictionary (model_name: config) of model architecture configs
23
+
24
+
25
def _natural_key(string_):
    """Return a sort key for ``string_``: lower-cased text pieces with digit runs turned into ints."""
    key = []
    for piece in re.split(r'(\d+)', string_.lower()):
        key.append(int(piece) if piece.isdigit() else piece)
    return key
27
+
28
+
29
def _rescan_model_configs():
    """Reload the model config registry from every path in ``_MODEL_CONFIG_PATHS``.

    A path may be a single ``.json`` file or a directory whose ``*.json``
    files are collected. Only configs containing ``embed_dim``,
    ``vision_cfg`` and ``text_cfg`` are registered, keyed by file stem. The
    registry is then rebuilt in natural-sort order of the names.
    """
    global _MODEL_CONFIGS

    extensions = ('.json',)
    required_keys = ('embed_dim', 'vision_cfg', 'text_cfg')

    candidates = []
    for location in _MODEL_CONFIG_PATHS:
        if location.is_file() and location.suffix in extensions:
            candidates.append(location)
        elif location.is_dir():
            for extension in extensions:
                candidates += location.glob(f'*{extension}')

    for candidate in candidates:
        with open(candidate, "r", encoding="utf8") as handle:
            parsed = json.load(handle)
        if all(key in parsed for key in required_keys):
            _MODEL_CONFIGS[candidate.stem] = parsed

    ordered_names = sorted(_MODEL_CONFIGS, key=_natural_key)
    _MODEL_CONFIGS = {name: _MODEL_CONFIGS[name] for name in ordered_names}
48
+
49
+
50
+ _rescan_model_configs() # initial populate of model config registry
51
+
52
+
53
def list_models():
    """Return the names of all registered model architectures, in registry order."""
    return list(_MODEL_CONFIGS)
56
+
57
+
58
def add_model_config(path):
    """Register an extra config file or directory and rescan the registry."""
    config_path = path if isinstance(path, Path) else Path(path)
    _MODEL_CONFIG_PATHS.append(config_path)
    _rescan_model_configs()
64
+
65
+
66
def get_model_config(model_name):
    """Return a deep copy of the registered config for ``model_name``, or ``None`` if unknown."""
    if model_name not in _MODEL_CONFIGS:
        return None
    # Copy so callers can modify the config without touching the registry.
    return deepcopy(_MODEL_CONFIGS[model_name])
71
+
72
+
73
def get_tokenizer(model_name):
    """Return the tokenizer for ``model_name``.

    The name is normalised the same way ``create_model`` does (``/`` replaced
    by ``-``). When the model's ``text_cfg`` names an ``hf_tokenizer_name`` an
    ``HFTokenizer`` for it is returned, otherwise the module-level
    ``tokenize`` function.

    Raises:
        RuntimeError: no config is registered for ``model_name``.
    """
    model_name = model_name.replace('/', '-')  # for callers using old naming with / in ViT names
    config = get_model_config(model_name)
    if config is None:
        # Fail with a readable message instead of a TypeError from subscripting None.
        raise RuntimeError(f'Model config for {model_name} not found.')

    text_cfg = config['text_cfg']
    if 'hf_tokenizer_name' in text_cfg:
        return HFTokenizer(text_cfg['hf_tokenizer_name'])
    return tokenize
77
+
78
+
79
+ # loading openai CLIP weights when is_openai=True for training
80
def load_state_dict(checkpoint_path: str, map_location: str='cpu', model_key: str='model|module|state_dict', is_openai: bool=False, skip_list: list=[]):
    """Load a state dict from ``checkpoint_path`` and clean it up.

    Args:
        checkpoint_path: file to load.
        map_location: forwarded to ``torch.load`` (the TorchScript path always
            loads on ``"cpu"``).
        model_key: ``|``-separated candidate keys; the first one present in a
            dict checkpoint selects the nested state dict, otherwise the
            checkpoint itself is used.
        is_openai: load the file as a TorchScript archive and drop its
            ``input_resolution``/``context_length``/``vocab_size`` entries.
        skip_list: keys removed from the result; the list is only read.

    Returns:
        The state dict. For non-TorchScript checkpoints a leading ``module.``
        prefix is stripped when the first key carries it. When the ``RoPE``
        environment variable is ``'1'`` every key containing ``freqs_cos`` or
        ``freqs_sin`` is removed.
    """
    if is_openai:
        model = torch.jit.load(checkpoint_path, map_location="cpu").eval()
        state_dict = model.state_dict()
        for key in ["input_resolution", "context_length", "vocab_size"]:
            state_dict.pop(key, None)
    else:
        # NOTE: torch.load unpickles the file; only load checkpoints from trusted sources.
        checkpoint = torch.load(checkpoint_path, map_location=map_location)
        state_dict = checkpoint
        if isinstance(checkpoint, dict):
            for mk in model_key.split('|'):
                if mk in checkpoint:
                    state_dict = checkpoint[mk]
                    break

        # Match the full 'module.' prefix (with the dot) so that names such as
        # 'module_list.0.weight' are left intact, and tolerate an empty dict.
        prefix = 'module.'
        first_key = next(iter(state_dict), '')
        if first_key.startswith(prefix):
            state_dict = {
                (k[len(prefix):] if k.startswith(prefix) else k): v
                for k, v in state_dict.items()
            }

    for k in skip_list:
        if k in state_dict:
            logging.info(f"Removing key {k} from pretrained checkpoint")
            del state_dict[k]

    if os.getenv('RoPE') == '1':
        for k in list(state_dict.keys()):
            if 'freqs_cos' in k or 'freqs_sin' in k:
                del state_dict[k]
    return state_dict
107
+
108
+
109
+
110
def load_checkpoint(model, checkpoint_path, model_key="model|module|state_dict", strict=True):
    """Load a full checkpoint into ``model`` and return the incompatible-keys result.

    The state dict is first adapted to the model: an old-format text tower is
    converted, a ``text.logit_scale`` entry is renamed to ``logit_scale`` when
    the model has that attribute, and the visual position embedding is resized
    with the helper matching the key that is present.
    """
    state_dict = load_state_dict(checkpoint_path, model_key=model_key, is_openai=False)

    # detect old format and make compatible with new format
    old_text_format = 'positional_embedding' in state_dict and not hasattr(model, 'positional_embedding')
    if old_text_format:
        state_dict = convert_to_custom_text_state_dict(state_dict)
    if 'text.logit_scale' in state_dict and hasattr(model, 'logit_scale'):
        state_dict['logit_scale'] = state_dict.pop('text.logit_scale')

    if 'visual.positional_embedding' in state_dict:
        # CLIP and open CLIP checkpoints
        resize_clip_pos_embed(state_dict, model)
    elif 'visual.pos_embed' in state_dict:
        # eva_vit_model checkpoints
        resize_evaclip_pos_embed(state_dict, model)

    incompatible_keys = model.load_state_dict(state_dict, strict=strict)
    logging.info(f"incompatible_keys.missing_keys: {incompatible_keys.missing_keys}")
    return incompatible_keys
130
+
131
def load_clip_visual_state_dict(checkpoint_path: str, map_location: str='cpu', is_openai: bool=False, skip_list:list=[]):
    """Load a checkpoint and keep only its ``visual.`` entries, with that prefix removed."""
    state_dict = load_state_dict(checkpoint_path, map_location=map_location, is_openai=is_openai, skip_list=skip_list)

    # First pass: drop everything that does not belong to the visual tower.
    non_visual = [name for name in state_dict.keys() if not name.startswith('visual.')]
    for name in non_visual:
        del state_dict[name]

    # Second pass: re-key the remaining entries without the 'visual.' prefix.
    for name in list(state_dict.keys()):
        if name.startswith('visual.'):
            state_dict[name[7:]] = state_dict.pop(name)
    return state_dict
143
+
144
def load_clip_text_state_dict(checkpoint_path: str, map_location: str='cpu', is_openai: bool=False, skip_list:list=[]):
    """Load a checkpoint and remove every ``visual.`` entry from it."""
    state_dict = load_state_dict(checkpoint_path, map_location=map_location, is_openai=is_openai, skip_list=skip_list)

    visual_names = [name for name in state_dict.keys() if name.startswith('visual.')]
    for name in visual_names:
        del state_dict[name]
    return state_dict
151
+
152
def get_pretrained_tag(pretrained_model):
    """Classify a pretrained model name by checkpoint family.

    Returns ``"open_clip"`` when the lower-cased name contains ``laion`` or
    ``open_clip``, ``"clip"`` when it contains ``openai``, ``"eva_clip"`` when
    it contains both ``eva`` and ``clip``, and ``"other"`` otherwise. ``None``
    or an empty name also yields ``"other"``.
    """
    if not pretrained_model:
        # Callers pass None when no model name was given for a tower.
        return "other"

    name = pretrained_model.lower()
    if "laion" in name or "open_clip" in name:
        return "open_clip"
    if "openai" in name:
        return "clip"
    if "eva" in name and "clip" in name:
        return "eva_clip"
    return "other"
162
+
163
def load_pretrained_checkpoint(
        model,
        visual_checkpoint_path,
        text_checkpoint_path,
        strict=True,
        visual_model=None,
        text_model=None,
        model_key="model|module|state_dict",
        skip_list=[]):
    """Load separate pretrained weights into ``model.visual`` and/or ``model.text``.

    Args:
        model: object exposing ``visual`` and ``text`` sub-modules.
        visual_checkpoint_path: checkpoint for the visual tower; falsy to skip.
        text_checkpoint_path: checkpoint for the text tower; falsy to skip.
        strict: forwarded to ``load_state_dict`` of each tower.
        visual_model: name used to classify the visual checkpoint family.
        text_model: name used to classify the text checkpoint family.
        model_key: candidate keys for checkpoints of an unrecognised family.
        skip_list: keys dropped from the loaded state dicts.

    Returns:
        ``(visual_incompatible_keys, text_incompatible_keys)``; an entry is
        ``None`` for a tower that was not loaded.
    """
    logging.info(f"num of model state_dict keys: {len(model.state_dict().keys())}")
    visual_incompatible_keys, text_incompatible_keys = None, None

    if visual_checkpoint_path:
        # A missing model name is treated as an unrecognised family.
        visual_tag = get_pretrained_tag(visual_model) if visual_model else "other"
        if visual_tag == "eva_clip" or visual_tag == "open_clip":
            visual_state_dict = load_clip_visual_state_dict(visual_checkpoint_path, is_openai=False, skip_list=skip_list)
        elif visual_tag == "clip":
            visual_state_dict = load_clip_visual_state_dict(visual_checkpoint_path, is_openai=True, skip_list=skip_list)
        else:
            visual_state_dict = load_state_dict(visual_checkpoint_path, model_key=model_key, is_openai=False, skip_list=skip_list)

        # resize_clip_pos_embed for CLIP and open CLIP
        if 'positional_embedding' in visual_state_dict:
            resize_visual_pos_embed(visual_state_dict, model)
        # specified to EVA model
        elif 'pos_embed' in visual_state_dict:
            resize_eva_pos_embed(visual_state_dict, model)

        visual_incompatible_keys = model.visual.load_state_dict(visual_state_dict, strict=strict)
        logging.info(f"num of loaded visual_state_dict keys: {len(visual_state_dict.keys())}")
        logging.info(f"visual_incompatible_keys.missing_keys: {visual_incompatible_keys.missing_keys}")

    if text_checkpoint_path:
        text_tag = get_pretrained_tag(text_model) if text_model else "other"
        if text_tag == "eva_clip" or text_tag == "open_clip":
            text_state_dict = load_clip_text_state_dict(text_checkpoint_path, is_openai=False, skip_list=skip_list)
        elif text_tag == "clip":
            text_state_dict = load_clip_text_state_dict(text_checkpoint_path, is_openai=True, skip_list=skip_list)
        else:
            # Load the text checkpoint here, not the visual one.
            text_state_dict = load_state_dict(text_checkpoint_path, model_key=model_key, is_openai=False, skip_list=skip_list)

        text_incompatible_keys = model.text.load_state_dict(text_state_dict, strict=strict)

        logging.info(f"num of loaded text_state_dict keys: {len(text_state_dict.keys())}")
        logging.info(f"text_incompatible_keys.missing_keys: {text_incompatible_keys.missing_keys}")

    return visual_incompatible_keys, text_incompatible_keys
210
+
211
def create_model(
        model_name: str,
        pretrained: Optional[str] = None,
        precision: str = 'fp32',
        device: Union[str, torch.device] = 'cpu',
        jit: bool = False,
        force_quick_gelu: bool = False,
        force_custom_clip: bool = False,
        force_patch_dropout: Optional[float] = None,
        pretrained_image: str = '',
        pretrained_text: str = '',
        pretrained_hf: bool = True,
        pretrained_visual_model: str = None,
        pretrained_text_model: str = None,
        cache_dir: Optional[str] = None,
        skip_list: list = [],
):
    """Build a CLIP model from a registered config and optionally load weights.

    Args:
        model_name: registered architecture name (``/`` is replaced by ``-``).
        pretrained: ``'openai'`` delegates to ``load_openai_model``; any other
            value is looked up as a pretrained tag and, failing that, used as
            a checkpoint path.
        precision: the model is cast to bfloat16 / float16 when this contains
            ``bf16`` / ``fp16``.
        device: device the model is moved to.
        jit: script the model with ``torch.jit.script`` before returning it.
        force_quick_gelu: set ``quick_gelu`` in the model config.
        force_custom_clip: build ``CustomCLIP`` even if the config does not
            ask for it.
        force_patch_dropout: override ``vision_cfg.patch_dropout``.
        pretrained_image: tag or path of visual-tower weights (used only when
            ``pretrained`` is falsy).
        pretrained_text: tag or path of text-tower weights (used only when
            ``pretrained`` is falsy).
        pretrained_hf: stored as ``text_cfg.hf_model_pretrained`` for configs
            naming an ``hf_model_name``.
        pretrained_visual_model: model name used to resolve
            ``pretrained_image``.
        pretrained_text_model: model name used to resolve ``pretrained_text``.
        cache_dir: forwarded to the download helpers.
        skip_list: state-dict keys to drop when loading per-tower weights.

    Returns:
        The constructed model.

    Raises:
        RuntimeError: the model config or requested weights cannot be found.

    Side effect: sets the ``RoPE`` environment variable to ``"1"`` or ``"0"``
    according to ``vision_cfg.rope`` of the selected config.
    """
    model_name = model_name.replace('/', '-')  # for callers using old naming with / in ViT names
    if isinstance(device, str):
        device = torch.device(device)

    if pretrained and pretrained.lower() == 'openai':
        logging.info(f'Loading pretrained {model_name} from OpenAI.')
        model = load_openai_model(
            model_name,
            precision=precision,
            device=device,
            jit=jit,
            cache_dir=cache_dir,
        )
    else:
        model_cfg = get_model_config(model_name)
        if model_cfg is not None:
            logging.info(f'Loaded {model_name} model config.')
        else:
            logging.error(f'Model config for {model_name} not found; available models {list_models()}.')
            raise RuntimeError(f'Model config for {model_name} not found.')

        # Always write the flag so a value left behind by a previously created
        # model in this process cannot leak into this one.
        os.environ['RoPE'] = "1" if model_cfg.get('vision_cfg', {}).get('rope') else "0"

        if force_quick_gelu:
            # override for use of QuickGELU on non-OpenAI transformer models
            model_cfg["quick_gelu"] = True

        if force_patch_dropout is not None:
            # override the default patch dropout value
            model_cfg['vision_cfg']["patch_dropout"] = force_patch_dropout

        cast_dtype = get_cast_dtype(precision)
        custom_clip = model_cfg.pop('custom_text', False) or force_custom_clip or ('hf_model_name' in model_cfg['text_cfg'])

        if custom_clip:
            if 'hf_model_name' in model_cfg.get('text_cfg', {}):
                model_cfg['text_cfg']['hf_model_pretrained'] = pretrained_hf
            model = CustomCLIP(**model_cfg, cast_dtype=cast_dtype)
        else:
            model = CLIP(**model_cfg, cast_dtype=cast_dtype)

        pretrained_cfg = {}
        if pretrained:
            checkpoint_path = ''
            pretrained_cfg = get_pretrained_cfg(model_name, pretrained)
            if pretrained_cfg:
                checkpoint_path = download_pretrained(pretrained_cfg, cache_dir=cache_dir)
            elif os.path.exists(pretrained):
                checkpoint_path = pretrained

            if checkpoint_path:
                logging.info(f'Loading pretrained {model_name} weights ({pretrained}).')
                load_checkpoint(
                    model,
                    checkpoint_path,
                    model_key="model|module|state_dict",
                    strict=False,
                )
            else:
                error_str = (
                    f'Pretrained weights ({pretrained}) not found for model {model_name}. '
                    f'Available pretrained tags ({list_pretrained_tags_by_model(model_name)}).')
                logging.warning(error_str)
                raise RuntimeError(error_str)
        else:
            visual_checkpoint_path = ''
            text_checkpoint_path = ''

            if pretrained_image:
                pretrained_visual_model = pretrained_visual_model.replace('/', '-')  # for callers using old naming with / in ViT names
                pretrained_image_cfg = get_pretrained_cfg(pretrained_visual_model, pretrained_image)
                if 'timm_model_name' in model_cfg.get('vision_cfg', {}):
                    # pretrained weight loading for timm models set via vision_cfg
                    # NOTE(review): the model was already built from model_cfg above, so this
                    # assignment presumably has no effect on it - confirm.
                    model_cfg['vision_cfg']['timm_model_pretrained'] = True
                elif pretrained_image_cfg:
                    visual_checkpoint_path = download_pretrained(pretrained_image_cfg, cache_dir=cache_dir)
                elif os.path.exists(pretrained_image):
                    visual_checkpoint_path = pretrained_image
                else:
                    # Report what was requested; the checkpoint path is still empty here.
                    logging.warning(f'Pretrained weights ({pretrained_image}) not found for model {model_name}.visual.')
                    raise RuntimeError(f'Pretrained weights ({pretrained_image}) not found for model {model_name}.visual.')

            if pretrained_text:
                pretrained_text_model = pretrained_text_model.replace('/', '-')  # for callers using old naming with / in ViT names
                pretrained_text_cfg = get_pretrained_cfg(pretrained_text_model, pretrained_text)
                # Decide on the text config; the image config may not even exist here.
                if pretrained_text_cfg:
                    text_checkpoint_path = download_pretrained(pretrained_text_cfg, cache_dir=cache_dir)
                elif os.path.exists(pretrained_text):
                    text_checkpoint_path = pretrained_text
                else:
                    logging.warning(f'Pretrained weights ({pretrained_text}) not found for model {model_name}.text.')
                    raise RuntimeError(f'Pretrained weights ({pretrained_text}) not found for model {model_name}.text.')

            if visual_checkpoint_path:
                logging.info(f'Loading pretrained {model_name}.visual weights ({visual_checkpoint_path}).')
            if text_checkpoint_path:
                logging.info(f'Loading pretrained {model_name}.text weights ({text_checkpoint_path}).')

            if visual_checkpoint_path or text_checkpoint_path:
                load_pretrained_checkpoint(
                    model,
                    visual_checkpoint_path,
                    text_checkpoint_path,
                    strict=False,
                    visual_model=pretrained_visual_model,
                    text_model=pretrained_text_model,
                    model_key="model|module|state_dict",
                    skip_list=skip_list,
                )

        if "fp16" in precision or "bf16" in precision:
            logging.info(f'convert precision to {precision}')
            model = model.to(torch.bfloat16) if 'bf16' in precision else model.to(torch.float16)

        model.to(device=device)

        # set image / mean metadata from pretrained_cfg if available, or use default
        model.visual.image_mean = pretrained_cfg.get('mean', None) or OPENAI_DATASET_MEAN
        model.visual.image_std = pretrained_cfg.get('std', None) or OPENAI_DATASET_STD

        if jit:
            model = torch.jit.script(model)

    return model
356
+
357
+
358
def create_model_and_transforms(
        model_name: str,
        pretrained: Optional[str] = None,
        precision: str = 'fp32',
        device: Union[str, torch.device] = 'cpu',
        jit: bool = False,
        force_quick_gelu: bool = False,
        force_custom_clip: bool = False,
        force_patch_dropout: Optional[float] = None,
        pretrained_image: str = '',
        pretrained_text: str = '',
        pretrained_hf: bool = True,
        pretrained_visual_model: str = None,
        pretrained_text_model: str = None,
        image_mean: Optional[Tuple[float, ...]] = None,
        image_std: Optional[Tuple[float, ...]] = None,
        cache_dir: Optional[str] = None,
        skip_list: Optional[list] = None,
):
    """Create a model with ``create_model`` plus its train and eval image transforms.

    All model-related arguments are forwarded unchanged to ``create_model``.

    Args:
        image_mean: Normalisation mean for the transforms. When falsy, the
            ``image_mean`` attribute of ``model.visual`` is used (``None`` if
            the attribute is missing).
        image_std: Normalisation std, resolved the same way as ``image_mean``.
        skip_list: Forwarded to ``create_model``. ``None`` (the default) is
            replaced by a new empty list on every call.

    Returns:
        ``(model, preprocess_train, preprocess_val)`` where both transforms are
        built by ``image_transform`` for ``model.visual.image_size``.
    """
    # The previous ``skip_list=[]`` default was a single list object shared by
    # every call; hand downstream code a fresh list instead.
    if skip_list is None:
        skip_list = []

    model = create_model(
        model_name,
        pretrained,
        precision=precision,
        device=device,
        jit=jit,
        force_quick_gelu=force_quick_gelu,
        force_custom_clip=force_custom_clip,
        force_patch_dropout=force_patch_dropout,
        pretrained_image=pretrained_image,
        pretrained_text=pretrained_text,
        pretrained_hf=pretrained_hf,
        pretrained_visual_model=pretrained_visual_model,
        pretrained_text_model=pretrained_text_model,
        cache_dir=cache_dir,
        skip_list=skip_list,
    )

    # Explicit arguments win; otherwise use the statistics stored on the visual tower.
    image_mean = image_mean or getattr(model.visual, 'image_mean', None)
    image_std = image_std or getattr(model.visual, 'image_std', None)
    preprocess_train = image_transform(
        model.visual.image_size,
        is_train=True,
        mean=image_mean,
        std=image_std
    )
    preprocess_val = image_transform(
        model.visual.image_size,
        is_train=False,
        mean=image_mean,
        std=image_std
    )

    return model, preprocess_train, preprocess_val
411
+
412
+
413
def create_transforms(
        model_name: str,
        pretrained: Optional[str] = None,
        precision: str = 'fp32',
        device: Union[str, torch.device] = 'cpu',
        jit: bool = False,
        force_quick_gelu: bool = False,
        force_custom_clip: bool = False,
        force_patch_dropout: Optional[float] = None,
        pretrained_image: str = '',
        pretrained_text: str = '',
        pretrained_hf: bool = True,
        pretrained_visual_model: str = None,
        pretrained_text_model: str = None,
        image_mean: Optional[Tuple[float, ...]] = None,
        image_std: Optional[Tuple[float, ...]] = None,
        cache_dir: Optional[str] = None,
        skip_list: Optional[list] = None,
):
    """Return only the train and eval image transforms for a model.

    The model is still instantiated (its ``visual.image_size`` and stored
    mean/std are needed to build the transforms) and is dropped before
    returning. Arguments have the same meaning as in
    ``create_model_and_transforms``, which does the actual work.

    Returns:
        ``(preprocess_train, preprocess_val)``.
    """
    # This function used to duplicate create_model_and_transforms line for
    # line; delegate so the two cannot drift apart.
    model, preprocess_train, preprocess_val = create_model_and_transforms(
        model_name,
        pretrained,
        precision=precision,
        device=device,
        jit=jit,
        force_quick_gelu=force_quick_gelu,
        force_custom_clip=force_custom_clip,
        force_patch_dropout=force_patch_dropout,
        pretrained_image=pretrained_image,
        pretrained_text=pretrained_text,
        pretrained_hf=pretrained_hf,
        pretrained_visual_model=pretrained_visual_model,
        pretrained_text_model=pretrained_text_model,
        image_mean=image_mean,
        image_std=image_std,
        cache_dir=cache_dir,
        # A fresh list per call replaces the old shared mutable ``[]`` default.
        skip_list=[] if skip_list is None else skip_list,
    )
    del model

    return preprocess_train, preprocess_val
468
+
469
def create_model_from_pretrained(
        model_name: str,
        pretrained: str,
        precision: str = 'fp32',
        device: Union[str, torch.device] = 'cpu',
        jit: bool = False,
        force_quick_gelu: bool = False,
        force_custom_clip: bool = False,
        force_patch_dropout: Optional[float] = None,
        return_transform: bool = True,
        image_mean: Optional[Tuple[float, ...]] = None,
        image_std: Optional[Tuple[float, ...]] = None,
        cache_dir: Optional[str] = None,
        is_frozen: bool = False,
):
    """Create a model from a pretrained tag or checkpoint path.

    ``pretrained`` must either be accepted by ``is_pretrained_cfg`` for
    ``model_name`` or be an existing filesystem path.

    Returns:
        ``model`` when ``return_transform`` is false, otherwise
        ``(model, preprocess)`` with ``preprocess`` being the eval-mode
        (``is_train=False``) image transform.

    Raises:
        RuntimeError: if ``pretrained`` is neither a known tag nor an existing path.
    """
    if not (is_pretrained_cfg(model_name, pretrained) or os.path.exists(pretrained)):
        raise RuntimeError(
            f'{pretrained} is not a valid pretrained cfg or checkpoint for {model_name}.'
            f' Use open_clip.list_pretrained() to find one.')

    model = create_model(
        model_name,
        pretrained,
        precision=precision,
        device=device,
        jit=jit,
        force_quick_gelu=force_quick_gelu,
        force_custom_clip=force_custom_clip,
        force_patch_dropout=force_patch_dropout,
        cache_dir=cache_dir,
    )

    if is_frozen:
        # Turn off gradient tracking for every parameter of the model.
        for weight in model.parameters():
            weight.requires_grad = False

    if return_transform:
        # Caller-supplied statistics take precedence over those stored on the visual tower.
        mean = image_mean or getattr(model.visual, 'image_mean', None)
        std = image_std or getattr(model.visual, 'image_std', None)
        preprocess = image_transform(
            model.visual.image_size,
            is_train=False,
            mean=mean,
            std=std
        )
        return model, preprocess

    return model
models/eva_clip/hf_configs.py ADDED
@@ -0,0 +1,57 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
# HF architecture dict:
# Maps a HuggingFace ``model_type`` to
#   "config_names": generic field name -> attribute name on that architecture
#   "pooler":       registry key of the default pooler for that architecture


def _bert_style_entry():
    """Return a new entry for architectures that use BERT-style config attribute names."""
    return {
        "config_names": {
            "context_length": "max_position_embeddings",
            "vocab_size": "vocab_size",
            "width": "hidden_size",
            "heads": "num_attention_heads",
            "layers": "num_hidden_layers",
            "layer_attr": "layer",
            "token_embeddings_attr": "embeddings"
        },
        "pooler": "mean_pooler",
    }


arch_dict = {
    # https://huggingface.co/docs/transformers/model_doc/roberta#roberta
    "roberta": _bert_style_entry(),
    # https://huggingface.co/docs/transformers/model_doc/xlm-roberta#transformers.XLMRobertaConfig
    "xlm-roberta": _bert_style_entry(),
    # https://huggingface.co/docs/transformers/model_doc/mt5#mt5
    "mt5": {
        "config_names": {
            # unlimited seqlen
            # https://github.com/google-research/text-to-text-transfer-transformer/issues/273
            # https://github.com/huggingface/transformers/blob/v4.24.0/src/transformers/models/t5/modeling_t5.py#L374
            "context_length": "",
            "vocab_size": "vocab_size",
            "width": "d_model",
            "heads": "num_heads",
            "layers": "num_layers",
            "layer_attr": "block",
            "token_embeddings_attr": "embed_tokens"
        },
        "pooler": "mean_pooler",
    },
    "bert": _bert_style_entry(),
}
models/eva_clip/hf_model.py ADDED
@@ -0,0 +1,248 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """ huggingface model adapter
2
+
3
+ Wraps HuggingFace transformers (https://github.com/huggingface/transformers) models for use as a text tower in CLIP model.
4
+ """
5
+
6
+ import re
7
+
8
+ import torch
9
+ import torch.nn as nn
10
+ from torch.nn import functional as F
11
+ from torch import TensorType
12
+ try:
13
+ import transformers
14
+ from transformers import AutoModel, AutoModelForMaskedLM, AutoTokenizer, AutoConfig, PretrainedConfig
15
+ from transformers.modeling_outputs import BaseModelOutput, BaseModelOutputWithPooling, \
16
+ BaseModelOutputWithPoolingAndCrossAttentions
17
+ except ImportError as e:
18
+ transformers = None
19
+
20
+
21
+ class BaseModelOutput:
22
+ pass
23
+
24
+
25
+ class PretrainedConfig:
26
+ pass
27
+
28
+ from .hf_configs import arch_dict
29
+
30
+ # utils
31
+ def _camel2snake(s):
32
+ return re.sub(r'(?<!^)(?=[A-Z])', '_', s).lower()
33
+
34
+ # TODO: ?last - for gpt-like models
35
+ _POOLERS = {}
36
+
37
def register_pooler(cls):
    """Class decorator: store ``cls`` in ``_POOLERS`` under its snake_case class name."""
    registry_key = _camel2snake(cls.__name__)
    _POOLERS[registry_key] = cls
    # Return the class untouched so the decorated name still refers to it.
    return cls
41
+
42
+
43
@register_pooler
class MeanPooler(nn.Module):
    """Average ``last_hidden_state`` over the sequence, weighted by the attention mask."""

    def forward(self, x:BaseModelOutput, attention_mask:TensorType):
        # Zero out hidden states wherever the mask is 0.
        weights = attention_mask.unsqueeze(-1)
        weighted_sum = (x.last_hidden_state * weights).sum(dim=1)
        # Divide by the per-row sum of the mask.
        token_counts = attention_mask.sum(-1, keepdim=True)
        return weighted_sum / token_counts
49
+
50
@register_pooler
class MaxPooler(nn.Module):
    """Max pooling over the sequence dimension, ignoring masked-out positions."""

    def forward(self, x:BaseModelOutput, attention_mask:TensorType):
        """Return the element-wise maximum of ``x.last_hidden_state`` along dim 1.

        Positions where ``attention_mask`` is 0 are set to ``-inf`` first so
        they cannot win the max.
        """
        # masked_fill writes where its mask is True and requires a boolean
        # mask. Passing the raw attention mask (1 = keep) was both the wrong
        # dtype and the wrong polarity: it targeted the real tokens.
        ignored_positions = attention_mask.unsqueeze(-1) == 0
        masked_output = x.last_hidden_state.masked_fill(ignored_positions, -torch.inf)
        return masked_output.max(1).values
56
+
57
@register_pooler
class ClsPooler(nn.Module):
    """Pool by taking the model's ``pooler_output`` or the hidden state at position 0."""

    def __init__(self, use_pooler_output=True):
        super().__init__()
        self.use_pooler_output = use_pooler_output
        # Sequence index whose hidden state is returned when pooler_output is not used.
        self.cls_token_position = 0

    def forward(self, x:BaseModelOutput, attention_mask:TensorType):
        # attention_mask is accepted for interface parity with the other poolers; it is unused.
        if self.use_pooler_output:
            carries_pooler_output = isinstance(
                x, (BaseModelOutputWithPooling, BaseModelOutputWithPoolingAndCrossAttentions))
            if carries_pooler_output and x.pooler_output is not None:
                return x.pooler_output

        # Fall back to the hidden state at the CLS position.
        return x.last_hidden_state[:, self.cls_token_position, :]
74
+
75
class HFTextEncoder(nn.Module):
    """HuggingFace model adapter.

    Wraps a ``transformers`` model, a pooler taken from ``_POOLERS`` and an
    optional projection so the result can serve as a text tower.
    """

    def __init__(
            self,
            model_name_or_path: str,
            output_dim: int,
            tokenizer_name: str = None,
            config: PretrainedConfig = None,
            pooler_type: str = None,
            proj: str = None,
            pretrained: bool = True,
            masked_language_modeling: bool = False):
        """Build the transformer, pooler, projection and tokenizer.

        Args:
            model_name_or_path: Passed to ``AutoConfig`` / ``AutoModel*`` when
                ``config`` is None.
            output_dim: Output feature size of the projection.
            tokenizer_name: Passed to ``AutoTokenizer.from_pretrained``.
            config: Ready-made config; when given, the model is built with
                ``from_config`` and ``pretrained`` is not consulted.
            pooler_type: Key into ``_POOLERS``; None selects the architecture
                default from ``arch_dict``.
            proj: ``'linear'``, ``'mlp'`` or None.
            pretrained: With ``config`` None, choose ``from_pretrained``
                (True) or ``from_config`` (False).
            masked_language_modeling: Use ``AutoModelForMaskedLM`` instead of
                ``AutoModel``.

        Raises:
            RuntimeError: if the ``transformers`` package could not be imported.
        """
        super().__init__()

        self.output_dim = output_dim

        # TODO: find better way to get this information
        uses_transformer_pooler = (pooler_type == "cls_pooler")

        if transformers is None:
            raise RuntimeError("Please `pip install transformers` to use pre-trained HuggingFace models")
        if config is None:
            self.config = AutoConfig.from_pretrained(model_name_or_path)
            if masked_language_modeling:
                create_func, model_args = (AutoModelForMaskedLM.from_pretrained, model_name_or_path) if pretrained else (
                    AutoModelForMaskedLM.from_config, self.config)
            else:
                create_func, model_args = (AutoModel.from_pretrained, model_name_or_path) if pretrained else (
                    AutoModel.from_config, self.config)
            # TODO: do all model configs have this attribute? PretrainedConfig does so yes??
            if hasattr(self.config, "is_encoder_decoder") and self.config.is_encoder_decoder:
                # Encoder-decoder models: keep only the encoder.
                self.transformer = create_func(model_args)
                self.transformer = self.transformer.encoder
            else:
                self.transformer = create_func(model_args, add_pooling_layer=uses_transformer_pooler)
        else:
            self.config = config
            if masked_language_modeling:
                self.transformer = AutoModelForMaskedLM.from_config(config)
            else:
                self.transformer = AutoModel.from_config(config)

        if pooler_type is None:  # get default arch pooler
            self.pooler = _POOLERS[(arch_dict[self.config.model_type]["pooler"])]()
        else:
            self.pooler = _POOLERS[pooler_type]()

        d_model = getattr(self.config, arch_dict[self.config.model_type]["config_names"]["width"])
        # NOTE(review): no branch assigns self.proj when proj is None and
        # d_model != output_dim (or proj is some other string) - confirm that
        # callers never hit that combination.
        if (d_model == output_dim) and (proj is None):  # do we always need a proj?
            self.proj = nn.Identity()
        elif proj == 'linear':
            self.proj = nn.Linear(d_model, output_dim, bias=False)
        elif proj == 'mlp':
            hidden_size = (d_model + output_dim) // 2
            self.proj = nn.Sequential(
                nn.Linear(d_model, hidden_size, bias=False),
                nn.GELU(),
                nn.Linear(hidden_size, output_dim, bias=False),
            )

        # self.itm_proj = nn.Linear(d_model, 2, bias=False)
        # self.mlm_proj = nn.Linear(d_model, self.config.vocab_size), bias=False)
        self.tokenizer = AutoTokenizer.from_pretrained(tokenizer_name)

    # def forward_itm(self, x:TensorType, image_embeds:TensorType) -> TensorType:
    #     image_atts = torch.ones(image_embeds.size()[:-1],dtype=torch.long).to(x.device)
    #     attn_mask = (x != self.config.pad_token_id).long()
    #     out = self.transformer(
    #         input_ids=x,
    #         attention_mask=attn_mask,
    #         encoder_hidden_states = image_embeds,
    #         encoder_attention_mask = image_atts,
    #         )
    #     pooled_out = self.pooler(out, attn_mask)

    #     return self.itm_proj(pooled_out)

    def mask(self, input_ids, vocab_size, device, targets=None, masked_indices=None, probability_matrix=None):
        """Apply masked-language-modeling corruption to ``input_ids`` in place.

        Args:
            input_ids: Token ids; modified in place and returned.
            vocab_size: Exclusive upper bound for random replacement ids.
            device: Device on which the random replacement ids are created.
            targets: Optional labels; positions that are not masked are set
                to -100 in place.
            masked_indices: Optional boolean selection of positions to corrupt.
            probability_matrix: Per-position selection probabilities, sampled
                with ``torch.bernoulli`` when ``masked_indices`` is None.

        Returns:
            ``(input_ids, targets)`` when ``targets`` is given, else ``input_ids``.
        """
        if masked_indices is None:
            masked_indices = torch.bernoulli(probability_matrix).bool()
        # Keep the selection on the same device as the ids it indexes against;
        # this is a no-op when it is already there.
        masked_indices = masked_indices.to(input_ids.device)

        # Never corrupt padding or CLS tokens.
        masked_indices[input_ids == self.tokenizer.pad_token_id] = False
        masked_indices[input_ids == self.tokenizer.cls_token_id] = False

        if targets is not None:
            targets[~masked_indices] = -100  # We only compute loss on masked tokens

        # 80% of the time, we replace masked input tokens with tokenizer.mask_token ([MASK])
        # The sampling tensors are created on input_ids' device: building them
        # on the default (CPU) device broke the "&" below for non-CPU inputs.
        indices_replaced = torch.bernoulli(
            torch.full(input_ids.shape, 0.8, device=input_ids.device)).bool() & masked_indices
        input_ids[indices_replaced] = self.tokenizer.mask_token_id

        # 10% of the time, we replace masked input tokens with random word
        # (half of the 20% that were not replaced by the mask token)
        indices_random = torch.bernoulli(
            torch.full(input_ids.shape, 0.5, device=input_ids.device)).bool() & masked_indices & ~indices_replaced
        random_words = torch.randint(vocab_size, input_ids.shape, dtype=torch.long, device=device)
        input_ids[indices_random] = random_words[indices_random]
        # The rest of the time (10% of the time) we keep the masked input tokens unchanged

        if targets is not None:
            return input_ids, targets
        else:
            return input_ids

    def forward_mlm(self, input_ids, image_embeds, mlm_probability=0.25):
        """Return the masked-language-modeling loss computed by the wrapped transformer.

        ``image_embeds`` is passed as ``encoder_hidden_states`` together with an
        all-ones attention mask. ``input_ids`` is corrupted in place by ``mask``.
        """
        labels = input_ids.clone()
        attn_mask = (input_ids != self.config.pad_token_id).long()
        image_atts = torch.ones(image_embeds.size()[:-1],dtype=torch.long).to(input_ids.device)
        vocab_size = getattr(self.config, arch_dict[self.config.model_type]["config_names"]["vocab_size"])
        # Create the probabilities on the inputs' device so the sampled
        # selection can be combined with input_ids without a device mismatch.
        probability_matrix = torch.full(labels.shape, mlm_probability, device=input_ids.device)
        input_ids, labels = self.mask(input_ids, vocab_size, input_ids.device, targets=labels,
                                      probability_matrix = probability_matrix)
        mlm_output = self.transformer(input_ids,
                        attention_mask = attn_mask,
                        encoder_hidden_states = image_embeds,
                        encoder_attention_mask = image_atts,
                        return_dict = True,
                        labels = labels,
                    )
        return mlm_output.loss
        # mlm_output = self.transformer(input_ids,
        #                 attention_mask = attn_mask,
        #                 encoder_hidden_states = image_embeds,
        #                 encoder_attention_mask = image_atts,
        #                 return_dict = True,
        #             ).last_hidden_state
        # logits = self.mlm_proj(mlm_output)

        # # logits = logits[:, :-1, :].contiguous().view(-1, vocab_size)
        # logits = logits[:, 1:, :].contiguous().view(-1, vocab_size)
        # labels = labels[:, 1:].contiguous().view(-1)

        # mlm_loss = F.cross_entropy(
        #     logits,
        #     labels,
        #     # label_smoothing=0.1,
        # )
        # return mlm_loss

    def forward(self, x:TensorType) -> TensorType:
        """Encode token ids ``x``: transformer, then pooler, then projection."""
        # Every position that is not the pad token is attended to.
        attn_mask = (x != self.config.pad_token_id).long()
        out = self.transformer(input_ids=x, attention_mask=attn_mask)
        pooled_out = self.pooler(out, attn_mask)

        return self.proj(pooled_out)

    def lock(self, unlocked_layers:int=0, freeze_layer_norm:bool=True):
        """Freeze transformer parameters, optionally leaving the last layers trainable.

        Args:
            unlocked_layers: Number of trailing entries of
                ``[embeddings, *layers]`` left untouched; 0 freezes everything.
            freeze_layer_norm: When False, parameters whose dotted name has a
                ``LayerNorm`` component get ``requires_grad=True``.
        """
        if not unlocked_layers:  # full freezing
            for n, p in self.transformer.named_parameters():
                p.requires_grad = (not freeze_layer_norm) if "LayerNorm" in n.split(".") else False
            return

        encoder = self.transformer.encoder if hasattr(self.transformer, 'encoder') else self.transformer
        layer_list = getattr(encoder, arch_dict[self.config.model_type]["config_names"]["layer_attr"])
        print(f"Unlocking {unlocked_layers}/{len(layer_list) + 1} layers of hf model")
        embeddings = getattr(
            self.transformer, arch_dict[self.config.model_type]["config_names"]["token_embeddings_attr"])
        # Everything except the last `unlocked_layers` modules gets frozen.
        modules = [embeddings, *layer_list][:-unlocked_layers]
        # freeze layers
        for module in modules:
            for n, p in module.named_parameters():
                p.requires_grad = (not freeze_layer_norm) if "LayerNorm" in n.split(".") else False

    @torch.jit.ignore
    def set_grad_checkpointing(self, enable=True):
        """Turn gradient checkpointing of the wrapped transformer on or off."""
        # The flag used to be ignored, so enable=False still switched it on.
        if enable:
            self.transformer.gradient_checkpointing_enable()
        else:
            self.transformer.gradient_checkpointing_disable()

    def get_num_layers(self):
        """Return the number of layers found under the architecture's ``layer_attr``."""
        encoder = self.transformer.encoder if hasattr(self.transformer, 'encoder') else self.transformer
        layer_list = getattr(encoder, arch_dict[self.config.model_type]["config_names"]["layer_attr"])
        return len(layer_list)

    def init_parameters(self):
        """No-op placeholder; this adapter performs no extra initialisation."""
        pass
models/eva_clip/loss.py ADDED
@@ -0,0 +1,138 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import math
2
+ import torch
3
+ import torch.nn as nn
4
+ from torch.nn import functional as F
5
+
6
+ try:
7
+ import torch.distributed.nn
8
+ from torch import distributed as dist
9
+ has_distributed = True
10
+ except ImportError:
11
+ has_distributed = False
12
+
13
+ try:
14
+ import horovod.torch as hvd
15
+ except ImportError:
16
+ hvd = None
17
+
18
+ from timm.loss import LabelSmoothingCrossEntropy
19
+
20
+
21
def gather_features(
        image_features,
        text_features,
        local_loss=False,
        gather_with_grad=False,
        rank=0,
        world_size=1,
        use_horovod=False
):
    """Collect image and text features from all workers.

    Uses ``hvd.allgather`` when ``use_horovod`` is set and torch.distributed
    otherwise. With ``gather_with_grad`` false, and unless ``local_loss`` is
    set, the slot belonging to ``rank`` is replaced by the local tensors.

    Returns:
        ``(all_image_features, all_text_features)``, each concatenated along dim 0.
    """
    assert has_distributed, 'torch.distributed did not import correctly, please use a PyTorch version with support.'

    if use_horovod:
        assert hvd is not None, 'Please install horovod'
        if gather_with_grad:
            gathered_images = hvd.allgather(image_features)
            gathered_texts = hvd.allgather(text_features)
            return gathered_images, gathered_texts

        with torch.no_grad():
            gathered_images = hvd.allgather(image_features)
            gathered_texts = hvd.allgather(text_features)
        if local_loss:
            return gathered_images, gathered_texts

        # ensure grads for local rank when the gathered features don't have a gradient
        image_parts = list(gathered_images.chunk(world_size, dim=0))
        text_parts = list(gathered_texts.chunk(world_size, dim=0))
        image_parts[rank] = image_features
        text_parts[rank] = text_features
        return torch.cat(image_parts, dim=0), torch.cat(text_parts, dim=0)

    # We gather tensors from all gpus
    if gather_with_grad:
        gathered_images = torch.cat(torch.distributed.nn.all_gather(image_features), dim=0)
        gathered_texts = torch.cat(torch.distributed.nn.all_gather(text_features), dim=0)
        return gathered_images, gathered_texts

    image_parts = [torch.zeros_like(image_features) for _ in range(world_size)]
    text_parts = [torch.zeros_like(text_features) for _ in range(world_size)]
    dist.all_gather(image_parts, image_features)
    dist.all_gather(text_parts, text_features)
    if not local_loss:
        # ensure grads for local rank when the gathered features don't have a gradient
        image_parts[rank] = image_features
        text_parts[rank] = text_features
    return torch.cat(image_parts, dim=0), torch.cat(text_parts, dim=0)
68
+
69
+
70
class ClipLoss(nn.Module):
    """Symmetric image/text cross-entropy loss with optional label smoothing.

    ``forward`` returns the loss together with image-to-text and
    text-to-image top-1 accuracies.
    """

    def __init__(
            self,
            local_loss=False,
            gather_with_grad=False,
            cache_labels=False,
            rank=0,
            world_size=1,
            use_horovod=False,
            smoothing=0.,
    ):
        super().__init__()
        # Distributed settings, forwarded to gather_features when world_size > 1.
        self.local_loss = local_loss
        self.gather_with_grad = gather_with_grad
        self.rank = rank
        self.world_size = world_size
        self.use_horovod = use_horovod
        self.cache_labels = cache_labels
        # A smoothing criterion is only created for a positive smoothing value.
        self.label_smoothing_cross_entropy = LabelSmoothingCrossEntropy(smoothing=smoothing) if smoothing > 0 else None

        # cache state
        self.prev_num_logits = 0
        self.labels = {}

    def forward(self, image_features, text_features, logit_scale=1.):
        """Return ``(total_loss, {"i2t": ..., "t2i": ...})`` for one batch."""
        device = image_features.device
        if self.world_size > 1:
            all_image_features, all_text_features = gather_features(
                image_features, text_features,
                self.local_loss, self.gather_with_grad, self.rank, self.world_size, self.use_horovod)

            if self.local_loss:
                logits_per_image = logit_scale * image_features @ all_text_features.T
                logits_per_text = logit_scale * text_features @ all_image_features.T
            else:
                logits_per_image = logit_scale * all_image_features @ all_text_features.T
                logits_per_text = logits_per_image.T
        else:
            logits_per_image = logit_scale * image_features @ text_features.T
            logits_per_text = logit_scale * text_features @ image_features.T

        # calculated ground-truth and cache if enabled
        num_logits = logits_per_image.shape[0]
        must_rebuild = self.prev_num_logits != num_logits or device not in self.labels
        if must_rebuild:
            labels = torch.arange(num_logits, device=device, dtype=torch.long)
            if self.world_size > 1 and self.local_loss:
                # Shift the targets by this rank's offset.
                labels = labels + num_logits * self.rank
            if self.cache_labels:
                self.labels[device] = labels
                self.prev_num_logits = num_logits
        else:
            labels = self.labels[device]

        # Use the smoothing criterion when one was configured, plain cross-entropy otherwise.
        criterion = self.label_smoothing_cross_entropy or F.cross_entropy
        total_loss = (
            criterion(logits_per_image, labels) +
            criterion(logits_per_text, labels)
        ) / 2

        i2t_acc = (logits_per_image.argmax(-1) == labels).sum() / len(logits_per_image)
        t2i_acc = (logits_per_text.argmax(-1) == labels).sum() / len(logits_per_text)
        return total_loss, {"i2t": i2t_acc, "t2i": t2i_acc}
models/eva_clip/model.py ADDED
@@ -0,0 +1,439 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """ CLIP Model
2
+
3
+ Adapted from https://github.com/openai/CLIP. Originally MIT License, Copyright (c) 2021 OpenAI.
4
+ """
5
+ import os
6
+ from dataclasses import dataclass
7
+ from typing import Optional, Tuple, Union
8
+ from functools import partial
9
+
10
+ import numpy as np
11
+ import torch
12
+ import torch.nn.functional as F
13
+ from torch import nn
14
+
15
+ try:
16
+ from .hf_model import HFTextEncoder
17
+ except:
18
+ HFTextEncoder = None
19
+ from .modified_resnet import ModifiedResNet
20
+ from .timm_model import TimmModel
21
+ from .eva_vit_model import EVAVisionTransformer
22
+ from .transformer import LayerNorm, QuickGELU, Attention, VisionTransformer, TextTransformer
23
+
24
+ try:
25
+ from apex.normalization import FusedLayerNorm
26
+ except:
27
+ FusedLayerNorm = LayerNorm
28
+ print("Please 'pip install apex'")
29
+
30
+ try:
31
+ import xformers.ops as xops
32
+ except ImportError:
33
+ xops = None
34
+ print("Please 'pip install xformers'")
35
+
36
@dataclass
class CLIPVisionCfg:
    # Settings for the vision tower; consumed by _build_vision_tower.

    # --- core architecture ---
    layers: Union[Tuple[int, int, int, int], int] = 12
    width: int = 768
    head_width: int = 64
    mlp_ratio: float = 4.0
    patch_size: int = 16
    image_size: Union[Tuple[int, int], int] = 224
    ls_init_value: Optional[float] = None  # initial value for layer scale
    patch_dropout: float = 0.  # fraction of patches dropped during training; 0 disables it (the paper recommends 0.5 to 0.75)
    global_average_pool: bool = False  # global-average-pool the last embedding layer instead of using the CLS token (https://arxiv.org/abs/2205.01580)
    drop_path_rate: Optional[float] = None  # drop path rate

    # --- timm-backed tower ---
    timm_model_name: str = None  # a valid model name here overrides layers, width, patch_size
    timm_model_pretrained: bool = False  # load (imagenet) pretrained weights for the named model
    timm_pool: str = 'avg'  # feature pooling for the timm model ('abs_attn', 'rot_attn', 'avg', '')
    timm_proj: str = 'linear'  # projection applied to the timm model output ('linear', 'mlp', '')
    timm_proj_bias: bool = False  # add a bias to the final projection

    # --- EVA-backed tower ---
    eva_model_name: str = None  # a valid eva model name here overrides layers, width, patch_size
    qkv_bias: bool = True
    fusedLN: bool = False
    xattn: bool = False
    postnorm: bool = False
    rope: bool = False
    pt_hw_seq_len: int = 16  # 224/14
    intp_freq: bool = False
    naiveswiglu: bool = False
    subln: bool = False
63
+
64
+
65
@dataclass
class CLIPTextCfg:
    # Settings for the text tower; consumed by _build_text_tower.

    # --- built-in text transformer ---
    context_length: int = 77
    vocab_size: int = 49408
    width: int = 512
    heads: int = 8
    layers: int = 12
    ls_init_value: Optional[float] = None  # initial value for layer scale

    # --- HuggingFace-backed text encoder (selected when hf_model_name is set) ---
    hf_model_name: str = None
    hf_tokenizer_name: str = None
    hf_model_pretrained: bool = True
    proj: str = 'mlp'
    pooler_type: str = 'mean_pooler'
    masked_language_modeling: bool = False

    # --- options passed to the built-in text transformer ---
    fusedLN: bool = False
    xattn: bool = False
    attn_mask: bool = True
82
+
83
def get_cast_dtype(precision: str):
    """Return the torch dtype for ``'bf16'`` or ``'fp16'``; None for any other value."""
    if precision == 'bf16':
        return torch.bfloat16
    if precision == 'fp16':
        return torch.float16
    # Any other precision string means "no cast".
    return None
90
+
91
+
92
def _build_vision_tower(
        embed_dim: int,
        vision_cfg: CLIPVisionCfg,
        quick_gelu: bool = False,
        cast_dtype: Optional[torch.dtype] = None
):
    """Instantiate the vision tower described by ``vision_cfg``.

    The first matching rule decides the implementation:
      1. ``eva_model_name`` set        -> ``EVAVisionTransformer``
      2. ``timm_model_name`` set       -> ``TimmModel``
      3. ``layers`` is a tuple or list -> ``ModifiedResNet``
      4. otherwise                     -> ``VisionTransformer``

    Args:
        embed_dim: Output embedding size of the tower.
        vision_cfg: A ``CLIPVisionCfg`` or a dict of its fields.
        quick_gelu: Use ``QuickGELU`` instead of ``nn.GELU`` as the activation
            of the plain ``VisionTransformer``.
        cast_dtype: For the plain ``VisionTransformer``, float16/bfloat16
            selects ``LayerNormFp32`` as the norm layer.

    Returns:
        The constructed vision module.
    """
    if isinstance(vision_cfg, dict):
        vision_cfg = CLIPVisionCfg(**vision_cfg)

    # OpenAI models are pretrained w/ QuickGELU but native nn.GELU is both faster and more
    # memory efficient in recent PyTorch releases (>= 1.10).
    # NOTE: timm models always use native GELU regardless of quick_gelu flag.
    act_layer = QuickGELU if quick_gelu else nn.GELU

    if vision_cfg.eva_model_name:
        vision_heads = vision_cfg.width // vision_cfg.head_width
        norm_layer = LayerNorm

        visual = EVAVisionTransformer(
            img_size=vision_cfg.image_size,
            patch_size=vision_cfg.patch_size,
            num_classes=embed_dim,
            use_mean_pooling=vision_cfg.global_average_pool,  # False
            init_values=vision_cfg.ls_init_value,
            patch_dropout=vision_cfg.patch_dropout,
            embed_dim=vision_cfg.width,
            depth=vision_cfg.layers,
            num_heads=vision_heads,
            mlp_ratio=vision_cfg.mlp_ratio,
            qkv_bias=vision_cfg.qkv_bias,
            drop_path_rate=vision_cfg.drop_path_rate,
            norm_layer=partial(FusedLayerNorm, eps=1e-6) if vision_cfg.fusedLN else partial(norm_layer, eps=1e-6),
            xattn=vision_cfg.xattn,
            rope=vision_cfg.rope,
            postnorm=vision_cfg.postnorm,
            pt_hw_seq_len=vision_cfg.pt_hw_seq_len,  # 224/14
            intp_freq=vision_cfg.intp_freq,
            naiveswiglu=vision_cfg.naiveswiglu,
            subln=vision_cfg.subln
        )
    elif vision_cfg.timm_model_name:
        visual = TimmModel(
            vision_cfg.timm_model_name,
            pretrained=vision_cfg.timm_model_pretrained,
            pool=vision_cfg.timm_pool,
            proj=vision_cfg.timm_proj,
            proj_bias=vision_cfg.timm_proj_bias,
            embed_dim=embed_dim,
            image_size=vision_cfg.image_size
        )
        act_layer = nn.GELU  # so that text transformer doesn't use QuickGELU w/ timm models
    elif isinstance(vision_cfg.layers, (tuple, list)):
        vision_heads = vision_cfg.width * 32 // vision_cfg.head_width
        visual = ModifiedResNet(
            layers=vision_cfg.layers,
            output_dim=embed_dim,
            heads=vision_heads,
            image_size=vision_cfg.image_size,
            width=vision_cfg.width
        )
    else:
        vision_heads = vision_cfg.width // vision_cfg.head_width
        if cast_dtype in (torch.float16, torch.bfloat16):
            # LayerNormFp32 is not among the names this module imports from
            # .transformer, so referencing it raised NameError. Import it here,
            # and only when it is actually needed.
            from .transformer import LayerNormFp32
            norm_layer = LayerNormFp32
        else:
            norm_layer = LayerNorm
        visual = VisionTransformer(
            image_size=vision_cfg.image_size,
            patch_size=vision_cfg.patch_size,
            width=vision_cfg.width,
            layers=vision_cfg.layers,
            heads=vision_heads,
            mlp_ratio=vision_cfg.mlp_ratio,
            ls_init_value=vision_cfg.ls_init_value,
            patch_dropout=vision_cfg.patch_dropout,
            global_average_pool=vision_cfg.global_average_pool,
            output_dim=embed_dim,
            act_layer=act_layer,
            norm_layer=norm_layer,
        )

    return visual
171
+
172
+
173
def _build_text_tower(
        embed_dim: int,
        text_cfg: CLIPTextCfg,
        quick_gelu: bool = False,
        cast_dtype: Optional[torch.dtype] = None,
):
    """Instantiate the text tower described by ``text_cfg``.

    ``hf_model_name`` selects ``HFTextEncoder``; otherwise a
    ``TextTransformer`` is built from the remaining fields.

    Args:
        embed_dim: Output embedding size of the tower.
        text_cfg: A ``CLIPTextCfg`` or a dict of its fields.
        quick_gelu: Use ``QuickGELU`` instead of ``nn.GELU`` in ``TextTransformer``.
        cast_dtype: Accepted for signature parity with ``_build_vision_tower``;
            not used here.

    Returns:
        The constructed text module.

    Raises:
        RuntimeError: if a HuggingFace model is requested but ``HFTextEncoder``
            is None (this module sets it to None when its import fails).
    """
    if isinstance(text_cfg, dict):
        text_cfg = CLIPTextCfg(**text_cfg)

    if text_cfg.hf_model_name:
        if HFTextEncoder is None:
            # Calling None would only surface as an opaque TypeError.
            raise RuntimeError(
                f"text_cfg.hf_model_name is set ({text_cfg.hf_model_name}) "
                "but HFTextEncoder could not be imported.")
        text = HFTextEncoder(
            text_cfg.hf_model_name,
            output_dim=embed_dim,
            tokenizer_name=text_cfg.hf_tokenizer_name,
            proj=text_cfg.proj,
            pooler_type=text_cfg.pooler_type,
            # Forward the config flag; it used to be dropped, so the encoder
            # always fell back to its own default (True, same as the cfg default).
            pretrained=text_cfg.hf_model_pretrained,
            masked_language_modeling=text_cfg.masked_language_modeling
        )
    else:
        act_layer = QuickGELU if quick_gelu else nn.GELU
        norm_layer = LayerNorm

        text = TextTransformer(
            context_length=text_cfg.context_length,
            vocab_size=text_cfg.vocab_size,
            width=text_cfg.width,
            heads=text_cfg.heads,
            layers=text_cfg.layers,
            ls_init_value=text_cfg.ls_init_value,
            output_dim=embed_dim,
            act_layer=act_layer,
            norm_layer=FusedLayerNorm if text_cfg.fusedLN else norm_layer,
            xattn=text_cfg.xattn,
            attn_mask=text_cfg.attn_mask,
        )
    return text
209
+
210
class CLIP(nn.Module):
    """Image/text model whose text-tower components are stored directly on the module.

    The vision tower lives in ``self.visual``; the text tower is built once and
    its transformer, embeddings, final norm and projection are re-attached here
    under their own attribute names.
    """

    def __init__(
            self,
            embed_dim: int,
            vision_cfg: CLIPVisionCfg,
            text_cfg: CLIPTextCfg,
            quick_gelu: bool = False,
            cast_dtype: Optional[torch.dtype] = None,
    ):
        super().__init__()
        self.visual = _build_vision_tower(embed_dim, vision_cfg, quick_gelu, cast_dtype)

        text_tower = _build_text_tower(embed_dim, text_cfg, quick_gelu, cast_dtype)
        # The order of this tuple is the registration order of the sub-modules/parameters.
        for attr_name in (
                'transformer',
                'vocab_size',
                'token_embedding',
                'positional_embedding',
                'ln_final',
                'text_projection',
        ):
            setattr(self, attr_name, getattr(text_tower, attr_name))
        # Non-persistent: the mask is kept out of the state dict.
        self.register_buffer('attn_mask', text_tower.attn_mask, persistent=False)

        initial_logit_scale = np.log(1 / 0.07)
        self.logit_scale = nn.Parameter(torch.ones([]) * initial_logit_scale)

    def lock_image_tower(self, unlocked_groups=0, freeze_bn_stats=False):
        """Delegate to ``self.visual.lock`` (LiT-style locking)."""
        # lock image tower as per LiT - https://arxiv.org/abs/2111.07991
        self.visual.lock(unlocked_groups=unlocked_groups, freeze_bn_stats=freeze_bn_stats)

    @torch.jit.ignore
    def set_grad_checkpointing(self, enable=True):
        """Toggle gradient checkpointing on the vision tower and the text transformer."""
        self.visual.set_grad_checkpointing(enable)
        self.transformer.grad_checkpointing = enable

    @torch.jit.ignore
    def no_weight_decay(self):
        """Return the set of parameter names to exclude from weight decay."""
        return {'logit_scale'}

    def encode_image(self, image, normalize: bool = False):
        """Run the vision tower; L2-normalize along the last dim when ``normalize`` is set."""
        image_features = self.visual(image)
        if normalize:
            return F.normalize(image_features, dim=-1)
        return image_features

    def encode_text(self, text, normalize: bool = False):
        """Encode token ids and pool at the position of the largest token id per row."""
        dtype = self.transformer.get_cast_dtype()

        hidden = self.token_embedding(text).to(dtype)  # [batch_size, n_ctx, d_model]
        hidden = hidden + self.positional_embedding.to(dtype)

        hidden = hidden.permute(1, 0, 2)  # NLD -> LND
        hidden = self.transformer(hidden, attn_mask=self.attn_mask)
        hidden = hidden.permute(1, 0, 2)  # LND -> NLD
        hidden = self.ln_final(hidden)  # [batch_size, n_ctx, transformer.width]

        # take features from the eot embedding (eot_token is the highest number in each sequence)
        row_index = torch.arange(hidden.shape[0])
        eot_index = text.argmax(dim=-1)
        pooled = hidden[row_index, eot_index] @ self.text_projection
        if normalize:
            return F.normalize(pooled, dim=-1)
        return pooled

    def forward(self, image, text):
        """Return normalized image features, normalized text features and ``exp(logit_scale)``."""
        image_features = self.encode_image(image, normalize=True)
        text_features = self.encode_text(text, normalize=True)
        return image_features, text_features, self.logit_scale.exp()
268
+
269
+
270
class CustomCLIP(nn.Module):
    """Image/text model that keeps the whole text encoder as the ``self.text`` sub-module."""

    def __init__(
            self,
            embed_dim: int,
            vision_cfg: CLIPVisionCfg,
            text_cfg: CLIPTextCfg,
            quick_gelu: bool = False,
            cast_dtype: Optional[torch.dtype] = None,
            itm_task: bool = False,
    ):
        # NOTE: `itm_task` is accepted but not read anywhere in this class.
        super().__init__()
        shared_args = (quick_gelu, cast_dtype)
        self.visual = _build_vision_tower(embed_dim, vision_cfg, *shared_args)
        self.text = _build_text_tower(embed_dim, text_cfg, *shared_args)
        initial_logit_scale = np.log(1 / 0.07)
        self.logit_scale = nn.Parameter(torch.ones([]) * initial_logit_scale)

    def lock_image_tower(self, unlocked_groups=0, freeze_bn_stats=False):
        """Delegate to ``self.visual.lock`` (LiT-style locking)."""
        # lock image tower as per LiT - https://arxiv.org/abs/2111.07991
        self.visual.lock(unlocked_groups=unlocked_groups, freeze_bn_stats=freeze_bn_stats)

    def lock_text_tower(self, unlocked_layers:int=0, freeze_layer_norm:bool=True):
        """Delegate to ``self.text.lock`` with the same two arguments."""
        self.text.lock(unlocked_layers, freeze_layer_norm)

    @torch.jit.ignore
    def set_grad_checkpointing(self, enable=True):
        """Toggle gradient checkpointing on both towers."""
        for tower in (self.visual, self.text):
            tower.set_grad_checkpointing(enable)

    @torch.jit.ignore
    def no_weight_decay(self):
        """Return the set of parameter names to exclude from weight decay."""
        return {'logit_scale'}

    def encode_image(self, image, normalize: bool = False):
        """Run the vision tower; L2-normalize along the last dim when ``normalize`` is set."""
        image_features = self.visual(image)
        if normalize:
            return F.normalize(image_features, dim=-1)
        return image_features

    def encode_text(self, text, normalize: bool = False):
        """Run the text tower; L2-normalize along the last dim when ``normalize`` is set."""
        text_features = self.text(text)
        if normalize:
            return F.normalize(text_features, dim=-1)
        return text_features

    def forward(self, image, text):
        """Return normalized image features, normalized text features and ``exp(logit_scale)``."""
        image_features = self.encode_image(image, normalize=True)
        text_features = self.encode_text(text, normalize=True)
        return image_features, text_features, self.logit_scale.exp()
313
+
314
+
315
def convert_weights_to_lp(model: nn.Module, dtype=torch.float16):
    """Convert applicable model parameters to low-precision (bf16 or fp16).

    The conversion is done in place by visiting every sub-module with
    ``model.apply``:

    * ``weight``/``bias`` of ``Conv1d``, ``Conv2d`` and ``Linear`` layers;
    * the projection weights/biases of ``nn.MultiheadAttention`` and ``Attention``;
    * bare ``nn.Parameter`` attributes named ``text_projection`` or ``proj``.

    Args:
        model: Module whose parameters are cast in place.
        dtype: Target dtype (defaults to ``torch.float16``).
    """

    def _convert_weights(l):

        if isinstance(l, (nn.Conv1d, nn.Conv2d, nn.Linear)):
            l.weight.data = l.weight.data.to(dtype)
            if l.bias is not None:
                l.bias.data = l.bias.data.to(dtype)

        if isinstance(l, (nn.MultiheadAttention, Attention)):
            for attr in [*[f"{s}_proj_weight" for s in ["in", "q", "k", "v"]], "in_proj_bias", "bias_k", "bias_v"]:
                tensor = getattr(l, attr, None)
                if tensor is not None:
                    tensor.data = tensor.data.to(dtype)

        for name in ["text_projection", "proj"]:
            attr = getattr(l, name, None)
            # `apply` only ever hands us modules, so the type test must be on the
            # attribute, not on `l`. Restricting it to bare parameters skips
            # sub-modules that happen to share one of these names (they are
            # visited by `apply` on their own).
            if isinstance(attr, nn.Parameter):
                attr.data = attr.data.to(dtype)

    model.apply(_convert_weights)


convert_weights_to_fp16 = convert_weights_to_lp  # backwards compat
344
+
345
+
346
# used to maintain checkpoint compatibility
def convert_to_custom_text_state_dict(state_dict: dict):
    """Re-key an old-format state dict so text-tower entries sit under ``text.``.

    A state dict is treated as old-format when it has a top-level
    ``'text_projection'`` key. In that case a new dict is returned in which
    every key starting with one of the listed prefixes gets ``'text.'``
    prepended; all other keys are copied unchanged. Otherwise the input
    object itself is returned.
    """
    if 'text_projection' not in state_dict:
        return state_dict

    # old format state_dict, move text tower -> .text
    # NOTE(review): 'logit_scale' is moved under 'text.' as well — confirm the
    # consuming model expects it there.
    text_prefixes = (
        'text_projection',
        'positional_embedding',
        'token_embedding',
        'transformer',
        'ln_final',
        'logit_scale',
    )
    return {
        ('text.' + key if key.startswith(text_prefixes) else key): value
        for key, value in state_dict.items()
    }
364
+
365
+
366
def build_model_from_openai_state_dict(
        state_dict: dict,
        quick_gelu=True,
        cast_dtype=torch.float16,
):
    """Instantiate a ``CLIP`` model whose shape is inferred from ``state_dict``.

    The vision and text hyper-parameters are read off the tensor shapes and
    key names of the checkpoint. A checkpoint containing ``"visual.proj"`` is
    treated as a ViT; anything else is treated as a ResNet with four stages
    and an attention pool.

    Note that ``state_dict`` is modified: the ``input_resolution``,
    ``context_length`` and ``vocab_size`` entries are popped before loading.

    Returns:
        The model in eval mode, with the weights loaded.

    Raises:
        KeyError: If an expected checkpoint key is missing.
    """
    is_vit = "visual.proj" in state_dict

    if is_vit:
        conv1_weight = state_dict["visual.conv1.weight"]
        vision_width = conv1_weight.shape[0]
        # one `attn.in_proj_weight` per transformer block
        vision_layers = sum(
            1 for key in state_dict.keys()
            if key.startswith("visual.") and key.endswith(".attn.in_proj_weight")
        )
        vision_patch_size = conv1_weight.shape[-1]
        # positional embedding has grid_size**2 + 1 rows
        grid_size = round((state_dict["visual.positional_embedding"].shape[0] - 1) ** 0.5)
        image_size = vision_patch_size * grid_size
    else:
        # number of distinct block indices under each of visual.layer1..layer4
        vision_layers = tuple(
            len({key.split(".")[2] for key in state_dict if key.startswith(f"visual.layer{stage}")})
            for stage in [1, 2, 3, 4]
        )
        vision_width = state_dict["visual.layer1.0.conv1.weight"].shape[0]
        attnpool_rows = state_dict["visual.attnpool.positional_embedding"].shape[0]
        output_width = round((attnpool_rows - 1) ** 0.5)
        vision_patch_size = None
        assert output_width ** 2 + 1 == attnpool_rows
        image_size = output_width * 32

    embed_dim = state_dict["text_projection"].shape[1]
    context_length = state_dict["positional_embedding"].shape[0]
    vocab_size = state_dict["token_embedding.weight"].shape[0]
    transformer_width = state_dict["ln_final.weight"].shape[0]
    transformer_heads = transformer_width // 64
    transformer_layers = len(
        {key.split(".")[2] for key in state_dict if key.startswith("transformer.resblocks")}
    )

    vision_cfg = CLIPVisionCfg(
        layers=vision_layers,
        width=vision_width,
        patch_size=vision_patch_size,
        image_size=image_size,
    )
    text_cfg = CLIPTextCfg(
        context_length=context_length,
        vocab_size=vocab_size,
        width=transformer_width,
        heads=transformer_heads,
        layers=transformer_layers
    )
    model = CLIP(
        embed_dim,
        vision_cfg=vision_cfg,
        text_cfg=text_cfg,
        quick_gelu=quick_gelu,  # OpenAI models were trained with QuickGELU
        cast_dtype=cast_dtype,
    )

    # These entries are metadata, not parameters; drop them before loading.
    for metadata_key in ("input_resolution", "context_length", "vocab_size"):
        state_dict.pop(metadata_key, None)

    convert_weights_to_fp16(model)  # OpenAI state dicts are partially converted to float16
    model.load_state_dict(state_dict)
    return model.eval()
424
+
425
+
426
def trace_model(model, batch_size=256, device=torch.device('cpu')):
    """Trace ``forward``, ``encode_text`` and ``encode_image`` with ``torch.jit.trace_module``.

    Args:
        model: Module exposing ``visual.image_size`` plus the three traced
            methods. Its text context length is read from
            ``model.context_length`` when present, otherwise from the number
            of rows of ``model.positional_embedding``.
        batch_size: Batch dimension of the example inputs.
        device: Device on which the example inputs are created.

    Returns:
        The traced module, with ``visual.image_size`` re-attached.
    """
    model.eval()
    image_size = model.visual.image_size

    # Not every model sets `context_length` (the `CLIP` class only keeps the
    # positional embedding), so fall back to the embedding's row count, which
    # has one row per token position.
    context_length = getattr(model, 'context_length', None)
    if context_length is None:
        context_length = model.positional_embedding.shape[0]

    example_images = torch.ones((batch_size, 3, image_size, image_size), device=device)
    example_text = torch.zeros((batch_size, context_length), dtype=torch.int, device=device)
    model = torch.jit.trace_module(
        model,
        inputs=dict(
            forward=(example_images, example_text),
            encode_text=(example_text,),
            encode_image=(example_images,)
        ))
    # tracing returns a new module; put the attribute back where callers look for it
    model.visual.image_size = image_size
    return model
models/eva_clip/model_configs/EVA01-CLIP-B-16.json ADDED
@@ -0,0 +1,19 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ {
2
+ "embed_dim": 512,
3
+ "vision_cfg": {
4
+ "image_size": 224,
5
+ "layers": 12,
6
+ "width": 768,
7
+ "patch_size": 16,
8
+ "eva_model_name": "eva-clip-b-16",
9
+ "ls_init_value": 0.1,
10
+ "drop_path_rate": 0.0
11
+ },
12
+ "text_cfg": {
13
+ "context_length": 77,
14
+ "vocab_size": 49408,
15
+ "width": 512,
16
+ "heads": 8,
17
+ "layers": 12
18
+ }
19
+ }
models/eva_clip/model_configs/EVA01-CLIP-g-14-plus.json ADDED
@@ -0,0 +1,24 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ {
2
+ "embed_dim": 1024,
3
+ "vision_cfg": {
4
+ "image_size": 224,
5
+ "layers": 40,
6
+ "width": 1408,
7
+ "head_width": 88,
8
+ "mlp_ratio": 4.3637,
9
+ "patch_size": 14,
10
+ "eva_model_name": "eva-clip-g-14-x",
11
+ "drop_path_rate": 0,
12
+ "xattn": true,
13
+ "fusedLN": true
14
+ },
15
+ "text_cfg": {
16
+ "context_length": 77,
17
+ "vocab_size": 49408,
18
+ "width": 1024,
19
+ "heads": 16,
20
+ "layers": 24,
21
+ "xattn": false,
22
+ "fusedLN": true
23
+ }
24
+ }
models/eva_clip/model_configs/EVA01-CLIP-g-14.json ADDED
@@ -0,0 +1,24 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ {
2
+ "embed_dim": 1024,
3
+ "vision_cfg": {
4
+ "image_size": 224,
5
+ "layers": 40,
6
+ "width": 1408,
7
+ "head_width": 88,
8
+ "mlp_ratio": 4.3637,
9
+ "patch_size": 14,
10
+ "eva_model_name": "eva-clip-g-14-x",
11
+ "drop_path_rate": 0.4,
12
+ "xattn": true,
13
+ "fusedLN": true
14
+ },
15
+ "text_cfg": {
16
+ "context_length": 77,
17
+ "vocab_size": 49408,
18
+ "width": 768,
19
+ "heads": 12,
20
+ "layers": 12,
21
+ "xattn": false,
22
+ "fusedLN": true
23
+ }
24
+ }
models/eva_clip/model_configs/EVA02-CLIP-B-16.json ADDED
@@ -0,0 +1,29 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ {
2
+ "embed_dim": 512,
3
+ "vision_cfg": {
4
+ "image_size": 224,
5
+ "layers": 12,
6
+ "width": 768,
7
+ "head_width": 64,
8
+ "patch_size": 16,
9
+ "mlp_ratio": 2.6667,
10
+ "eva_model_name": "eva-clip-b-16-X",
11
+ "drop_path_rate": 0.0,
12
+ "xattn": true,
13
+ "fusedLN": true,
14
+ "rope": true,
15
+ "pt_hw_seq_len": 16,
16
+ "intp_freq": true,
17
+ "naiveswiglu": true,
18
+ "subln": true
19
+ },
20
+ "text_cfg": {
21
+ "context_length": 77,
22
+ "vocab_size": 49408,
23
+ "width": 512,
24
+ "heads": 8,
25
+ "layers": 12,
26
+ "xattn": true,
27
+ "fusedLN": true
28
+ }
29
+ }
models/eva_clip/model_configs/EVA02-CLIP-L-14-336.json ADDED
@@ -0,0 +1,29 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ {
2
+ "embed_dim": 768,
3
+ "vision_cfg": {
4
+ "image_size": 336,
5
+ "layers": 24,
6
+ "width": 1024,
7
+ "drop_path_rate": 0,
8
+ "head_width": 64,
9
+ "mlp_ratio": 2.6667,
10
+ "patch_size": 14,
11
+ "eva_model_name": "eva-clip-l-14-336",
12
+ "xattn": true,
13
+ "fusedLN": true,
14
+ "rope": true,
15
+ "pt_hw_seq_len": 16,
16
+ "intp_freq": true,
17
+ "naiveswiglu": true,
18
+ "subln": true
19
+ },
20
+ "text_cfg": {
21
+ "context_length": 77,
22
+ "vocab_size": 49408,
23
+ "width": 768,
24
+ "heads": 12,
25
+ "layers": 12,
26
+ "xattn": false,
27
+ "fusedLN": true
28
+ }
29
+ }
models/eva_clip/model_configs/EVA02-CLIP-L-14.json ADDED
@@ -0,0 +1,29 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ {
2
+ "embed_dim": 768,
3
+ "vision_cfg": {
4
+ "image_size": 224,
5
+ "layers": 24,
6
+ "width": 1024,
7
+ "drop_path_rate": 0,
8
+ "head_width": 64,
9
+ "mlp_ratio": 2.6667,
10
+ "patch_size": 14,
11
+ "eva_model_name": "eva-clip-l-14",
12
+ "xattn": true,
13
+ "fusedLN": true,
14
+ "rope": true,
15
+ "pt_hw_seq_len": 16,
16
+ "intp_freq": true,
17
+ "naiveswiglu": true,
18
+ "subln": true
19
+ },
20
+ "text_cfg": {
21
+ "context_length": 77,
22
+ "vocab_size": 49408,
23
+ "width": 768,
24
+ "heads": 12,
25
+ "layers": 12,
26
+ "xattn": false,
27
+ "fusedLN": true
28
+ }
29
+ }
models/eva_clip/model_configs/EVA02-CLIP-bigE-14-plus.json ADDED
@@ -0,0 +1,25 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ {
2
+ "embed_dim": 1024,
3
+ "vision_cfg": {
4
+ "image_size": 224,
5
+ "layers": 64,
6
+ "width": 1792,
7
+ "head_width": 112,
8
+ "mlp_ratio": 8.571428571428571,
9
+ "patch_size": 14,
10
+ "eva_model_name": "eva-clip-4b-14-x",
11
+ "drop_path_rate": 0,
12
+ "xattn": true,
13
+ "postnorm": true,
14
+ "fusedLN": true
15
+ },
16
+ "text_cfg": {
17
+ "context_length": 77,
18
+ "vocab_size": 49408,
19
+ "width": 1280,
20
+ "heads": 20,
21
+ "layers": 32,
22
+ "xattn": false,
23
+ "fusedLN": true
24
+ }
25
+ }
models/eva_clip/model_configs/EVA02-CLIP-bigE-14.json ADDED
@@ -0,0 +1,25 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ {
2
+ "embed_dim": 1024,
3
+ "vision_cfg": {
4
+ "image_size": 224,
5
+ "layers": 64,
6
+ "width": 1792,
7
+ "head_width": 112,
8
+ "mlp_ratio": 8.571428571428571,
9
+ "patch_size": 14,
10
+ "eva_model_name": "eva-clip-4b-14-x",
11
+ "drop_path_rate": 0,
12
+ "xattn": true,
13
+ "postnorm": true,
14
+ "fusedLN": true
15
+ },
16
+ "text_cfg": {
17
+ "context_length": 77,
18
+ "vocab_size": 49408,
19
+ "width": 1024,
20
+ "heads": 16,
21
+ "layers": 24,
22
+ "xattn": false,
23
+ "fusedLN": true
24
+ }
25
+ }
models/eva_clip/modified_resnet.py ADDED
@@ -0,0 +1,188 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ import sys
3
+
4
+ import torch
5
+ from torch import nn
6
+ from torch.nn import functional as F
7
+ from collections import OrderedDict
8
+
9
# Put this file's directory at the front of sys.path so that sibling modules
# (e.g. `utils`) can be imported by bare name.
current_file_path = os.path.abspath(__file__)
project_roots = [os.path.dirname(current_file_path)]
for project_root in project_roots:
    if project_root not in sys.path:
        sys.path.insert(0, project_root)
13
+
14
+ from utils import freeze_batch_norm_2d
15
+
16
+
17
class Bottleneck(nn.Module):
    """Residual bottleneck block (1x1 -> 3x3 -> 1x1 convolutions).

    Spatial down-sampling is done with average pooling rather than strided
    convolutions. The output has ``planes * expansion`` channels.
    """
    expansion = 4

    def __init__(self, inplanes, planes, stride=1):
        super().__init__()
        out_planes = planes * self.expansion

        # all conv layers have stride 1. an avgpool is performed after the second convolution when stride > 1
        self.conv1 = nn.Conv2d(inplanes, planes, 1, bias=False)
        self.bn1 = nn.BatchNorm2d(planes)
        self.act1 = nn.ReLU(inplace=True)

        self.conv2 = nn.Conv2d(planes, planes, 3, padding=1, bias=False)
        self.bn2 = nn.BatchNorm2d(planes)
        self.act2 = nn.ReLU(inplace=True)

        if stride > 1:
            self.avgpool = nn.AvgPool2d(stride)
        else:
            self.avgpool = nn.Identity()

        self.conv3 = nn.Conv2d(planes, out_planes, 1, bias=False)
        self.bn3 = nn.BatchNorm2d(out_planes)
        self.act3 = nn.ReLU(inplace=True)

        self.stride = stride

        # The skip path needs its own projection whenever shape changes.
        needs_projection = stride > 1 or inplanes != planes * Bottleneck.expansion
        if needs_projection:
            # downsampling layer is prepended with an avgpool, and the subsequent convolution has stride 1
            shortcut_layers = OrderedDict()
            shortcut_layers["-1"] = nn.AvgPool2d(stride)
            shortcut_layers["0"] = nn.Conv2d(inplanes, out_planes, 1, stride=1, bias=False)
            shortcut_layers["1"] = nn.BatchNorm2d(out_planes)
            self.downsample = nn.Sequential(shortcut_layers)
        else:
            self.downsample = None

    def forward(self, x: torch.Tensor):
        """Apply the three conv stages, add the (possibly projected) input, then ReLU."""
        out = self.conv1(x)
        out = self.act1(self.bn1(out))
        out = self.conv2(out)
        out = self.act2(self.bn2(out))
        out = self.avgpool(out)
        out = self.conv3(out)
        out = self.bn3(out)

        shortcut = x if self.downsample is None else self.downsample(x)

        out += shortcut
        return self.act3(out)
63
+
64
+
65
class AttentionPool2d(nn.Module):
    """Pool an NCHW feature map to one vector per sample with multi-head attention.

    The spatial positions are flattened into a sequence, their mean is
    prepended as an extra token, a learned positional embedding is added, and
    the attention output at the mean-token position is returned.
    """

    def __init__(self, spacial_dim: int, embed_dim: int, num_heads: int, output_dim: int = None):
        super().__init__()
        # one position per spatial cell plus one for the prepended mean token
        num_positions = spacial_dim ** 2 + 1
        init_scale = embed_dim ** 0.5
        self.positional_embedding = nn.Parameter(torch.randn(num_positions, embed_dim) / init_scale)
        self.k_proj = nn.Linear(embed_dim, embed_dim)
        self.q_proj = nn.Linear(embed_dim, embed_dim)
        self.v_proj = nn.Linear(embed_dim, embed_dim)
        self.c_proj = nn.Linear(embed_dim, output_dim or embed_dim)
        self.num_heads = num_heads

    def forward(self, x):
        """Return the pooled features, one row per input sample."""
        batch, channels = x.shape[0], x.shape[1]
        num_cells = x.shape[2] * x.shape[3]
        tokens = x.reshape(batch, channels, num_cells).permute(2, 0, 1)  # NCHW -> (HW)NC

        mean_token = tokens.mean(dim=0, keepdim=True)
        tokens = torch.cat([mean_token, tokens], dim=0)  # (HW+1)NC
        tokens = tokens + self.positional_embedding[:, None, :].to(tokens.dtype)  # (HW+1)NC

        # q/k/v use separate weight matrices but a single concatenated bias
        qkv_bias = torch.cat([self.q_proj.bias, self.k_proj.bias, self.v_proj.bias])
        attended, _ = F.multi_head_attention_forward(
            query=tokens, key=tokens, value=tokens,
            embed_dim_to_check=tokens.shape[-1],
            num_heads=self.num_heads,
            q_proj_weight=self.q_proj.weight,
            k_proj_weight=self.k_proj.weight,
            v_proj_weight=self.v_proj.weight,
            in_proj_weight=None,
            in_proj_bias=qkv_bias,
            bias_k=None,
            bias_v=None,
            add_zero_attn=False,
            dropout_p=0.,
            out_proj_weight=self.c_proj.weight,
            out_proj_bias=self.c_proj.bias,
            use_separate_proj_weight=True,
            training=self.training,
            need_weights=False
        )

        # index 0 is the prepended mean token
        return attended[0]
100
+
101
+
102
class ModifiedResNet(nn.Module):
    """
    A ResNet class that is similar to torchvision's but contains the following changes:
    - There are now 3 "stem" convolutions as opposed to 1, with an average pool instead of a max pool.
    - Performs anti-aliasing strided convolutions, where an avgpool is prepended to convolutions with stride > 1
    - The final pooling layer is a QKV attention instead of an average pool
    """

    def __init__(self, layers, output_dim, heads, image_size=224, width=64):
        super().__init__()
        self.output_dim = output_dim
        self.image_size = image_size

        # the 3-layer stem
        stem_width = width // 2
        self.conv1 = nn.Conv2d(3, stem_width, kernel_size=3, stride=2, padding=1, bias=False)
        self.bn1 = nn.BatchNorm2d(stem_width)
        self.act1 = nn.ReLU(inplace=True)
        self.conv2 = nn.Conv2d(stem_width, stem_width, kernel_size=3, padding=1, bias=False)
        self.bn2 = nn.BatchNorm2d(stem_width)
        self.act2 = nn.ReLU(inplace=True)
        self.conv3 = nn.Conv2d(stem_width, width, kernel_size=3, padding=1, bias=False)
        self.bn3 = nn.BatchNorm2d(width)
        self.act3 = nn.ReLU(inplace=True)
        self.avgpool = nn.AvgPool2d(2)

        # residual layers
        self._inplanes = width  # this is a *mutable* variable used during construction
        self.layer1 = self._make_layer(width, layers[0])
        self.layer2 = self._make_layer(width * 2, layers[1], stride=2)
        self.layer3 = self._make_layer(width * 4, layers[2], stride=2)
        self.layer4 = self._make_layer(width * 8, layers[3], stride=2)

        feature_dim = width * 32  # the ResNet feature dimension
        self.attnpool = AttentionPool2d(image_size // 32, feature_dim, heads, output_dim)

        self.init_parameters()

    def _make_layer(self, planes, blocks, stride=1):
        """Build one stage of ``blocks`` bottlenecks; only the first one uses ``stride``."""
        stage = [Bottleneck(self._inplanes, planes, stride)]

        # every following block consumes the expanded output of the previous one
        self._inplanes = planes * Bottleneck.expansion
        stage.extend(Bottleneck(self._inplanes, planes) for _ in range(1, blocks))

        return nn.Sequential(*stage)

    def init_parameters(self):
        """Initialize the attention-pool projections and zero the last BN scale of each block."""
        if self.attnpool is not None:
            std = self.attnpool.c_proj.in_features ** -0.5
            for projection in (
                    self.attnpool.q_proj,
                    self.attnpool.k_proj,
                    self.attnpool.v_proj,
                    self.attnpool.c_proj,
            ):
                nn.init.normal_(projection.weight, std=std)

        for resnet_block in [self.layer1, self.layer2, self.layer3, self.layer4]:
            for name, param in resnet_block.named_parameters():
                if not name.endswith("bn3.weight"):
                    continue
                nn.init.zeros_(param)

    def lock(self, unlocked_groups=0, freeze_bn_stats=False):
        """Disable gradients for every parameter; optionally freeze batch-norm statistics."""
        assert unlocked_groups == 0, 'partial locking not currently supported for this model'
        for param in self.parameters():
            param.requires_grad = False
        if freeze_bn_stats:
            freeze_batch_norm_2d(self)

    @torch.jit.ignore
    def set_grad_checkpointing(self, enable=True):
        """No-op: gradient checkpointing is not implemented for this model."""
        # FIXME support for non-transformer
        pass

    def stem(self, x):
        """Run the three stem conv/BN/ReLU stages followed by the average pool."""
        x = self.conv1(x)
        x = self.act1(self.bn1(x))
        x = self.conv2(x)
        x = self.act2(self.bn2(x))
        x = self.conv3(x)
        x = self.act3(self.bn3(x))
        return self.avgpool(x)

    def forward(self, x):
        """Stem, four residual stages, then attention pooling."""
        x = self.stem(x)
        for stage in (self.layer1, self.layer2, self.layer3, self.layer4):
            x = stage(x)
        return self.attnpool(x)
models/eva_clip/openai.py ADDED
@@ -0,0 +1,144 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """ OpenAI pretrained model functions
2
+
3
+ Adapted from https://github.com/openai/CLIP. Originally MIT License, Copyright (c) 2021 OpenAI.
4
+ """
5
+
6
+ import os
7
+ import warnings
8
+ from typing import List, Optional, Union
9
+
10
+ import torch
11
+
12
+ from .model import build_model_from_openai_state_dict, convert_weights_to_lp, get_cast_dtype
13
+ from .pretrained import get_pretrained_url, list_pretrained_models_by_tag, download_pretrained_from_url
14
+
15
+ __all__ = ["list_openai_models", "load_openai_model"]
16
+
17
+
18
def list_openai_models() -> List[str]:
    """List the model names that have pretrained weights under the ``'openai'`` tag."""
    openai_model_names = list_pretrained_models_by_tag('openai')
    return openai_model_names
21
+
22
+
23
def load_openai_model(
        name: str,
        precision: Optional[str] = None,
        device: Optional[Union[str, torch.device]] = None,
        jit: bool = True,
        cache_dir: Optional[str] = None,
):
    """Load a CLIP model

    Parameters
    ----------
    name : str
        A model name listed by `list_openai_models()`, or the path to a model checkpoint containing the state_dict
    precision: str
        Model precision, if None defaults to 'fp32' if device is a CPU device else 'fp16'.
    device : Union[str, torch.device]
        The device to put the loaded model
    jit : bool
        Whether to load the optimized JIT model (default) or more hackable non-JIT model.
    cache_dir : Optional[str]
        The directory to cache the downloaded model weights

    Returns
    -------
    model : torch.nn.Module
        The CLIP model

    Raises
    ------
    RuntimeError
        If `name` is neither a known 'openai' pretrained model nor an existing file.
    """
    if device is None:
        device = "cuda" if torch.cuda.is_available() else "cpu"
    if precision is None:
        # Compare on the device *type*: a `torch.device('cpu')` object does not
        # compare equal to the string 'cpu', which would wrongly select fp16.
        precision = 'fp32' if torch.device(device).type == 'cpu' else 'fp16'

    if get_pretrained_url(name, 'openai'):
        model_path = download_pretrained_from_url(get_pretrained_url(name, 'openai'), cache_dir=cache_dir)
    elif os.path.isfile(name):
        model_path = name
    else:
        raise RuntimeError(f"Model {name} not found; available models = {list_openai_models()}")

    try:
        # loading JIT archive
        model = torch.jit.load(model_path, map_location=device if jit else "cpu").eval()
        state_dict = None
    except RuntimeError:
        # loading saved state dict
        if jit:
            warnings.warn(f"File {model_path} is not a JIT archive. Loading as a state dict instead")
            jit = False
        # SECURITY NOTE: torch.load unpickles the file; only use checkpoints from trusted sources.
        state_dict = torch.load(model_path, map_location="cpu")

    if not jit:
        # Build a non-jit model from the OpenAI jitted model state dict
        cast_dtype = get_cast_dtype(precision)
        try:
            model = build_model_from_openai_state_dict(state_dict or model.state_dict(), cast_dtype=cast_dtype)
        except KeyError:
            # Checkpoint wraps the weights under "state_dict"; the first 7
            # characters of every key are stripped before retrying.
            sd = {k[7:]: v for k, v in state_dict["state_dict"].items()}
            model = build_model_from_openai_state_dict(sd, cast_dtype=cast_dtype)

        # model from OpenAI state dict is in manually cast fp16 mode, must be converted for AMP/fp32/bf16 use
        model = model.to(device)
        if precision.startswith('amp') or precision == 'fp32':
            model.float()
        elif precision == 'bf16':
            convert_weights_to_lp(model, dtype=torch.bfloat16)

        return model

    # patch the device names
    device_holder = torch.jit.trace(lambda: torch.ones([]).to(torch.device(device)), example_inputs=[])
    device_node = [n for n in device_holder.graph.findAllNodes("prim::Constant") if "Device" in repr(n)][-1]

    def patch_device(module):
        """Rewrite every 'cuda*' device constant in the module's graphs to the target device."""
        try:
            graphs = [module.graph] if hasattr(module, "graph") else []
        except RuntimeError:
            graphs = []

        if hasattr(module, "forward1"):
            graphs.append(module.forward1.graph)

        for graph in graphs:
            for node in graph.findAllNodes("prim::Constant"):
                if "value" in node.attributeNames() and str(node["value"]).startswith("cuda"):
                    node.copyAttributes(device_node)

    model.apply(patch_device)
    patch_device(model.encode_image)
    patch_device(model.encode_text)

    # patch dtype to float32 (typically for CPU)
    if precision == 'fp32':
        float_holder = torch.jit.trace(lambda: torch.ones([]).float(), example_inputs=[])
        float_input = list(float_holder.graph.findNode("aten::to").inputs())[1]
        float_node = float_input.node()

        def patch_float(module):
            """Rewrite `aten::to` dtype constants equal to 5 so they cast to float32 instead."""
            try:
                graphs = [module.graph] if hasattr(module, "graph") else []
            except RuntimeError:
                graphs = []

            if hasattr(module, "forward1"):
                graphs.append(module.forward1.graph)

            for graph in graphs:
                for node in graph.findAllNodes("aten::to"):
                    inputs = list(node.inputs())
                    for i in [1, 2]:  # dtype can be the second or third argument to aten::to()
                        if inputs[i].node()["value"] == 5:
                            inputs[i].node().copyAttributes(float_node)

        model.apply(patch_float)
        patch_float(model.encode_image)
        patch_float(model.encode_text)
        model.float()

    # ensure image_size attr available at consistent location for both jit and non-jit
    model.visual.image_size = model.input_resolution.item()
    return model
models/eva_clip/pretrained.py ADDED
@@ -0,0 +1,332 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import hashlib
2
+ import os
3
+ import urllib
4
+ import warnings
5
+ from functools import partial
6
+ from typing import Dict, Union
7
+
8
+ from tqdm import tqdm
9
+
10
+ try:
11
+ from huggingface_hub import hf_hub_download
12
+ _has_hf_hub = True
13
+ except ImportError:
14
+ hf_hub_download = None
15
+ _has_hf_hub = False
16
+
17
+
18
def _pcfg(url='', hf_hub='', filename='', mean=None, std=None):
    """Build one pretrained-weights config entry.

    The returned dict carries ``url``, ``hf_hub``, ``mean`` and ``std``.
    ``filename`` is accepted by the signature but is not stored in the entry.
    """
    return {
        'url': url,
        'hf_hub': hf_hub,
        'mean': mean,
        'std': std,
    }
25
+
26
+ _VITB32 = dict(
27
+ openai=_pcfg(
28
+ "https://openaipublic.azureedge.net/clip/models/40d365715913c9da98579312b702a82c18be219cc2a73407c4526f58eba950af/ViT-B-32.pt"),
29
+ laion400m_e31=_pcfg(
30
+ "https://github.com/mlfoundations/open_clip/releases/download/v0.2-weights/vit_b_32-quickgelu-laion400m_e31-d867053b.pt"),
31
+ laion400m_e32=_pcfg(
32
+ "https://github.com/mlfoundations/open_clip/releases/download/v0.2-weights/vit_b_32-quickgelu-laion400m_e32-46683a32.pt"),
33
+ laion2b_e16=_pcfg(
34
+ "https://github.com/mlfoundations/open_clip/releases/download/v0.2-weights/vit_b_32-laion2b_e16-af8dbd0c.pth"),
35
+ laion2b_s34b_b79k=_pcfg(hf_hub='laion/CLIP-ViT-B-32-laion2B-s34B-b79K/')
36
+ )
37
+
38
+ _VITB32_quickgelu = dict(
39
+ openai=_pcfg(
40
+ "https://openaipublic.azureedge.net/clip/models/40d365715913c9da98579312b702a82c18be219cc2a73407c4526f58eba950af/ViT-B-32.pt"),
41
+ laion400m_e31=_pcfg(
42
+ "https://github.com/mlfoundations/open_clip/releases/download/v0.2-weights/vit_b_32-quickgelu-laion400m_e31-d867053b.pt"),
43
+ laion400m_e32=_pcfg(
44
+ "https://github.com/mlfoundations/open_clip/releases/download/v0.2-weights/vit_b_32-quickgelu-laion400m_e32-46683a32.pt"),
45
+ )
46
+
47
+ _VITB16 = dict(
48
+ openai=_pcfg(
49
+ "https://openaipublic.azureedge.net/clip/models/5806e77cd80f8b59890b7e101eabd078d9fb84e6937f9e85e4ecb61988df416f/ViT-B-16.pt"),
50
+ laion400m_e31=_pcfg(
51
+ "https://github.com/mlfoundations/open_clip/releases/download/v0.2-weights/vit_b_16-laion400m_e31-00efa78f.pt"),
52
+ laion400m_e32=_pcfg(
53
+ "https://github.com/mlfoundations/open_clip/releases/download/v0.2-weights/vit_b_16-laion400m_e32-55e67d44.pt"),
54
+ laion2b_s34b_b88k=_pcfg(hf_hub='laion/CLIP-ViT-B-16-laion2B-s34B-b88K/'),
55
+ )
56
+
57
+ _EVAB16 = dict(
58
+ eva=_pcfg(hf_hub='QuanSun/EVA-CLIP/EVA02_B_psz14to16.pt'),
59
+ eva02=_pcfg(hf_hub='QuanSun/EVA-CLIP/EVA02_B_psz14to16.pt'),
60
+ eva_clip=_pcfg(hf_hub='QuanSun/EVA-CLIP/EVA02_CLIP_B_psz16_s8B.pt'),
61
+ eva02_clip=_pcfg(hf_hub='QuanSun/EVA-CLIP/EVA02_CLIP_B_psz16_s8B.pt'),
62
+ )
63
+
64
+ _VITB16_PLUS_240 = dict(
65
+ laion400m_e31=_pcfg(
66
+ "https://github.com/mlfoundations/open_clip/releases/download/v0.2-weights/vit_b_16_plus_240-laion400m_e31-8fb26589.pt"),
67
+ laion400m_e32=_pcfg(
68
+ "https://github.com/mlfoundations/open_clip/releases/download/v0.2-weights/vit_b_16_plus_240-laion400m_e32-699c4b84.pt"),
69
+ )
70
+
71
+ _VITL14 = dict(
72
+ openai=_pcfg(
73
+ "https://openaipublic.azureedge.net/clip/models/b8cca3fd41ae0c99ba7e8951adf17d267cdb84cd88be6f7c2e0eca1737a03836/ViT-L-14.pt"),
74
+ laion400m_e31=_pcfg(
75
+ "https://github.com/mlfoundations/open_clip/releases/download/v0.2-weights/vit_l_14-laion400m_e31-69988bb6.pt"),
76
+ laion400m_e32=_pcfg(
77
+ "https://github.com/mlfoundations/open_clip/releases/download/v0.2-weights/vit_l_14-laion400m_e32-3d133497.pt"),
78
+ laion2b_s32b_b82k=_pcfg(
79
+ hf_hub='laion/CLIP-ViT-L-14-laion2B-s32B-b82K/',
80
+ mean=(0.5, 0.5, 0.5), std=(0.5, 0.5, 0.5)),
81
+ )
82
+
83
+ _EVAL14 = dict(
84
+ eva=_pcfg(hf_hub='QuanSun/EVA-CLIP/EVA02_L_psz14.pt'),
85
+ eva02=_pcfg(hf_hub='QuanSun/EVA-CLIP/EVA02_L_psz14.pt'),
86
+ eva_clip=_pcfg(hf_hub='QuanSun/EVA-CLIP/EVA02_CLIP_L_psz14_s4B.pt'),
87
+ eva02_clip=_pcfg(hf_hub='QuanSun/EVA-CLIP/EVA02_CLIP_L_psz14_s4B.pt'),
88
+ )
89
+
90
+ _VITL14_336 = dict(
91
+ openai=_pcfg(
92
+ "https://openaipublic.azureedge.net/clip/models/3035c92b350959924f9f00213499208652fc7ea050643e8b385c2dac08641f02/ViT-L-14-336px.pt"),
93
+ )
94
+
95
+ _EVAL14_336 = dict(
96
+ eva_clip=_pcfg(hf_hub='QuanSun/EVA-CLIP/EVA02_CLIP_L_336_psz14_s6B.pt'),
97
+ eva02_clip=_pcfg(hf_hub='QuanSun/EVA-CLIP/EVA02_CLIP_L_336_psz14_s6B.pt'),
98
+ eva_clip_224to336=_pcfg(hf_hub='QuanSun/EVA-CLIP/EVA02_CLIP_L_psz14_224to336.pt'),
99
+ eva02_clip_224to336=_pcfg(hf_hub='QuanSun/EVA-CLIP/EVA02_CLIP_L_psz14_224to336.pt'),
100
+ )
101
+
102
+ _VITH14 = dict(
103
+ laion2b_s32b_b79k=_pcfg(hf_hub='laion/CLIP-ViT-H-14-laion2B-s32B-b79K/'),
104
+ )
105
+
106
+ _VITg14 = dict(
107
+ laion2b_s12b_b42k=_pcfg(hf_hub='laion/CLIP-ViT-g-14-laion2B-s12B-b42K/'),
108
+ laion2b_s34b_b88k=_pcfg(hf_hub='laion/CLIP-ViT-g-14-laion2B-s34B-b88K/'),
109
+ )
110
+
111
+ _EVAg14 = dict(
112
+ eva=_pcfg(hf_hub='QuanSun/EVA-CLIP/'),
113
+ eva01=_pcfg(hf_hub='QuanSun/EVA-CLIP/EVA01_g_psz14.pt'),
114
+ eva_clip=_pcfg(hf_hub='QuanSun/EVA-CLIP/EVA01_CLIP_g_14_psz14_s11B.pt'),
115
+ eva01_clip=_pcfg(hf_hub='QuanSun/EVA-CLIP/EVA01_CLIP_g_14_psz14_s11B.pt'),
116
+ )
117
+
118
+ _EVAg14_PLUS = dict(
119
+ eva=_pcfg(hf_hub='QuanSun/EVA-CLIP/'),
120
+ eva01=_pcfg(hf_hub='QuanSun/EVA-CLIP/EVA01_g_psz14.pt'),
121
+ eva_clip=_pcfg(hf_hub='QuanSun/EVA-CLIP/EVA01_CLIP_g_14_plus_psz14_s11B.pt'),
122
+ eva01_clip=_pcfg(hf_hub='QuanSun/EVA-CLIP/EVA01_CLIP_g_14_plus_psz14_s11B.pt'),
123
+ )
124
+
125
+ _VITbigG14 = dict(
126
+ laion2b_s39b_b160k=_pcfg(hf_hub='laion/CLIP-ViT-bigG-14-laion2B-39B-b160k/'),
127
+ )
128
+
129
+ _EVAbigE14 = dict(
130
+ eva=_pcfg(hf_hub='QuanSun/EVA-CLIP/EVA02_E_psz14.pt'),
131
+ eva02=_pcfg(hf_hub='QuanSun/EVA-CLIP/EVA02_E_psz14.pt'),
132
+ eva_clip=_pcfg(hf_hub='QuanSun/EVA-CLIP/EVA02_CLIP_E_psz14_s4B.pt'),
133
+ eva02_clip=_pcfg(hf_hub='QuanSun/EVA-CLIP/EVA02_CLIP_E_psz14_s4B.pt'),
134
+ )
135
+
136
+ _EVAbigE14_PLUS = dict(
137
+ eva=_pcfg(hf_hub='QuanSun/EVA-CLIP/EVA02_E_psz14.pt'),
138
+ eva02=_pcfg(hf_hub='QuanSun/EVA-CLIP/EVA02_E_psz14.pt'),
139
+ eva_clip=_pcfg(hf_hub='QuanSun/EVA-CLIP/EVA02_CLIP_E_psz14_plus_s9B.pt'),
140
+ eva02_clip=_pcfg(hf_hub='QuanSun/EVA-CLIP/EVA02_CLIP_E_psz14_plus_s9B.pt'),
141
+ )
142
+
143
+
144
+ _PRETRAINED = {
145
+ # "ViT-B-32": _VITB32,
146
+ "OpenaiCLIP-B-32": _VITB32,
147
+ "OpenCLIP-B-32": _VITB32,
148
+
149
+ # "ViT-B-32-quickgelu": _VITB32_quickgelu,
150
+ "OpenaiCLIP-B-32-quickgelu": _VITB32_quickgelu,
151
+ "OpenCLIP-B-32-quickgelu": _VITB32_quickgelu,
152
+
153
+ # "ViT-B-16": _VITB16,
154
+ "OpenaiCLIP-B-16": _VITB16,
155
+ "OpenCLIP-B-16": _VITB16,
156
+
157
+ "EVA02-B-16": _EVAB16,
158
+ "EVA02-CLIP-B-16": _EVAB16,
159
+
160
+ # "ViT-B-16-plus-240": _VITB16_PLUS_240,
161
+ "OpenCLIP-B-16-plus-240": _VITB16_PLUS_240,
162
+
163
+ # "ViT-L-14": _VITL14,
164
+ "OpenaiCLIP-L-14": _VITL14,
165
+ "OpenCLIP-L-14": _VITL14,
166
+
167
+ "EVA02-L-14": _EVAL14,
168
+ "EVA02-CLIP-L-14": _EVAL14,
169
+
170
+ # "ViT-L-14-336": _VITL14_336,
171
+ "OpenaiCLIP-L-14-336": _VITL14_336,
172
+
173
+ "EVA02-CLIP-L-14-336": _EVAL14_336,
174
+
175
+ # "ViT-H-14": _VITH14,
176
+ # "ViT-g-14": _VITg14,
177
+ "OpenCLIP-H-14": _VITH14,
178
+ "OpenCLIP-g-14": _VITg14,
179
+
180
+ "EVA01-CLIP-g-14": _EVAg14,
181
+ "EVA01-CLIP-g-14-plus": _EVAg14_PLUS,
182
+
183
+ # "ViT-bigG-14": _VITbigG14,
184
+ "OpenCLIP-bigG-14": _VITbigG14,
185
+
186
+ "EVA02-CLIP-bigE-14": _EVAbigE14,
187
+ "EVA02-CLIP-bigE-14-plus": _EVAbigE14_PLUS,
188
+ }
189
+
190
+
191
def _clean_tag(tag: str):
    """Normalize a pretrained tag: lower-case it, then turn dashes into underscores."""
    lowered = tag.lower()
    return lowered.replace('-', '_')
194
+
195
+
196
def list_pretrained(as_str: bool = False):
    """List every registered (model, pretrain tag) combination.

    Each entry is a ``(model_name, pretrain_tag)`` tuple, or a
    ``'name:tag'`` string when ``as_str`` is true.
    """
    entries = []
    for model_name in _PRETRAINED.keys():
        for tag_name in _PRETRAINED[model_name].keys():
            if as_str:
                entries.append(':'.join([model_name, tag_name]))
            else:
                entries.append((model_name, tag_name))
    return entries
201
+
202
+
203
def list_pretrained_models_by_tag(tag: str):
    """Return the names of all models that have an entry for the given pretrain tag."""
    wanted = _clean_tag(tag)
    return [name for name, tag_table in _PRETRAINED.items() if wanted in tag_table]
211
+
212
+
213
def list_pretrained_tags_by_model(model: str):
    """Return the pretrain tags registered for ``model`` (empty list if unknown)."""
    if model not in _PRETRAINED:
        return []
    return list(_PRETRAINED[model].keys())
219
+
220
+
221
def is_pretrained_cfg(model: str, tag: str):
    """Tell whether ``model`` is registered and has an entry for the normalized ``tag``."""
    if model in _PRETRAINED:
        return _clean_tag(tag) in _PRETRAINED[model]
    return False
225
+
226
+
227
def get_pretrained_cfg(model: str, tag: str):
    """Look up the config entry for ``model``/``tag``; ``{}`` when either is unknown."""
    if model in _PRETRAINED:
        tag_table = _PRETRAINED[model]
        return tag_table.get(_clean_tag(tag), {})
    return {}
232
+
233
+
234
def get_pretrained_url(model: str, tag: str):
    """Return the ``url`` field of the matching config entry, or ``''`` if there is none."""
    normalized_tag = _clean_tag(tag)
    entry = get_pretrained_cfg(model, normalized_tag)
    return entry.get('url', '')
237
+
238
+
239
def _sha256_of_file(path: str, chunk_size: int = 1 << 20) -> str:
    """Return the hex SHA256 digest of the file at ``path``, reading it in chunks."""
    digest = hashlib.sha256()
    with open(path, "rb") as handle:
        for chunk in iter(lambda: handle.read(chunk_size), b""):
            digest.update(chunk)
    return digest.hexdigest()


def download_pretrained_from_url(
        url: str,
        cache_dir: Union[str, None] = None,
):
    """Download ``url`` into ``cache_dir`` (default ``~/.cache/clip``) and return the local path.

    An expected checksum prefix is derived from the URL when possible: the
    second-to-last path component for 'openaipublic' URLs, or the last
    dash-separated piece of the file stem for 'mlfoundations' URLs.
    A file already present in the cache is reused when no checksum is known
    or when its SHA256 starts with the expected prefix; otherwise it is
    downloaded again.

    Raises:
        RuntimeError: if the target path exists but is not a regular file, or
            if the freshly downloaded file does not match the expected checksum.
    """
    # `import urllib` alone does not load the `request` submodule, so import it
    # explicitly here instead of relying on some other module having done so.
    import urllib.request

    if not cache_dir:
        cache_dir = os.path.expanduser("~/.cache/clip")
    os.makedirs(cache_dir, exist_ok=True)
    filename = os.path.basename(url)

    if 'openaipublic' in url:
        expected_sha256 = url.split("/")[-2]
    elif 'mlfoundations' in url:
        expected_sha256 = os.path.splitext(filename)[0].split("-")[-1]
    else:
        expected_sha256 = ''

    download_target = os.path.join(cache_dir, filename)

    if os.path.exists(download_target) and not os.path.isfile(download_target):
        raise RuntimeError(f"{download_target} exists and is not a regular file")

    if os.path.isfile(download_target):
        if not expected_sha256:
            return download_target
        if _sha256_of_file(download_target).startswith(expected_sha256):
            return download_target
        warnings.warn(f"{download_target} exists, but the SHA256 checksum does not match; re-downloading the file")

    with urllib.request.urlopen(url) as source, open(download_target, "wb") as output:
        # Content-Length is optional; without it the progress bar simply has no total.
        content_length = source.headers.get("Content-Length")
        total = int(content_length) if content_length is not None else None
        with tqdm(total=total, ncols=80, unit='iB', unit_scale=True) as loop:
            while True:
                buffer = source.read(8192)
                if not buffer:
                    break

                output.write(buffer)
                loop.update(len(buffer))

    if expected_sha256 and not _sha256_of_file(download_target).startswith(expected_sha256):
        raise RuntimeError("Model has been downloaded but the SHA256 checksum does not match")

    return download_target
283
+
284
+
285
def has_hf_hub(necessary=False):
    """Report whether ``huggingface_hub`` was importable.

    When ``necessary`` is true and the package is missing, raise a
    ``RuntimeError`` instead of returning ``False``.
    """
    missing = not _has_hf_hub
    if missing and necessary:
        # the caller cannot proceed without the HF Hub package
        raise RuntimeError(
            'Hugging Face hub model specified but package not installed. Run `pip install huggingface_hub`.')
    return _has_hf_hub
291
+
292
+
293
def download_pretrained_from_hf(
        model_id: str,
        filename: str = 'open_clip_pytorch_model.bin',
        revision=None,
        cache_dir: Union[str, None] = None,
):
    """Fetch ``filename`` from the Hugging Face repo ``model_id`` and return the cached path.

    Raises ``RuntimeError`` (via ``has_hf_hub``) when ``huggingface_hub`` is not installed.
    """
    has_hf_hub(True)
    return hf_hub_download(model_id, filename, revision=revision, cache_dir=cache_dir)
302
+
303
+
304
def download_pretrained(
        cfg: Dict,
        force_hf_hub: bool = False,
        cache_dir: Union[str, None] = None,
):
    """Resolve a pretrained config entry to a local checkpoint path.

    The entry's ``url`` is preferred unless ``force_hf_hub`` is set and an
    ``hf_hub`` value exists. Returns ``''`` for an empty config or when the
    entry names neither source.
    """
    if not cfg:
        return ''

    direct_url = cfg.get('url', '')
    hub_spec = cfg.get('hf_hub', '')
    if hub_spec and force_hf_hub:
        # the hub wins even though a direct url is available
        direct_url = ''

    if direct_url:
        return download_pretrained_from_url(direct_url, cache_dir=cache_dir)

    if not hub_spec:
        return ''

    has_hf_hub(True)
    # hf_hub entries are 'org/model_name/filename.pt'. A trailing slash
    # ('org/model_name/') leaves the filename empty, which selects the
    # downloader's default file name.
    model_id, filename = os.path.split(hub_spec)
    if not filename:
        return download_pretrained_from_hf(model_id, cache_dir=cache_dir)
    return download_pretrained_from_hf(model_id, filename=filename, cache_dir=cache_dir)
models/eva_clip/rope.py ADDED
@@ -0,0 +1,137 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from math import pi
2
+ import torch
3
+ from torch import nn
4
+ from einops import rearrange, repeat
5
+ import logging
6
+
7
def broadcat(tensors, dim=-1):
    """Concatenate tensors along ``dim`` after broadcasting every other axis.

    All tensors must have the same number of dimensions. On each axis other
    than ``dim`` the inputs may show at most two distinct sizes; each tensor
    is expanded to the largest of them before ``torch.cat`` is applied.
    """
    count = len(tensors)
    ranks = {len(t.shape) for t in tensors}
    assert len(ranks) == 1, 'tensors must all have the same number of dimensions'
    rank = next(iter(ranks))
    if dim < 0:
        dim = dim + rank

    # sizes_per_axis[a] holds the size of axis `a` for every input tensor
    sizes_per_axis = list(zip(*(list(t.shape) for t in tensors)))
    broadcast_axes = [(axis, sizes) for axis, sizes in enumerate(sizes_per_axis) if axis != dim]
    assert all(len(set(sizes)) <= 2 for _, sizes in broadcast_axes), 'invalid dimensions for broadcastable concatentation'

    # every broadcast axis grows to its maximum; the concat axis keeps each tensor's own size
    target_sizes = [(max(sizes),) * count for _, sizes in broadcast_axes]
    target_sizes.insert(dim, sizes_per_axis[dim])

    per_tensor_shapes = zip(*target_sizes)
    expanded = [t.expand(*shape) for t, shape in zip(tensors, per_tensor_shapes)]
    return torch.cat(expanded, dim=dim)
22
+
23
def rotate_half(x):
    """Map each consecutive pair ``(a, b)`` on the last axis to ``(-b, a)``."""
    pairs = rearrange(x, '... (d r) -> ... d r', r = 2)
    first, second = pairs.unbind(dim = -1)
    rotated = torch.stack((-second, first), dim = -1)
    return rearrange(rotated, '... d r -> ... (d r)')
28
+
29
+
30
class VisionRotaryEmbedding(nn.Module):
    """Rotary position embedding over a square 2-D grid of positions.

    Cosine and sine tables are precomputed in ``__init__`` and stored as the
    buffers ``freqs_cos`` and ``freqs_sin``. ``forward`` rotates a slice of
    the last axis of its input with them.
    """

    def __init__(
        self,
        dim,
        pt_seq_len,
        ft_seq_len=None,
        custom_freqs = None,
        freqs_for = 'lang',
        theta = 10000,
        max_freq = 10,
        num_freqs = 1,
    ):
        """Precompute the rotation tables.

        Args:
            dim: sets how many base frequencies are generated (``dim // 2``
                for the 'lang' and 'pixel' schemes).
            pt_seq_len: grid side length that positions are rescaled to.
            ft_seq_len: grid side length actually used; defaults to ``pt_seq_len``.
            custom_freqs: explicit base frequencies; used whenever it is not ``None``.
            freqs_for: one of 'lang', 'pixel' or 'constant'.
            theta: base of the 'lang' frequency schedule.
            max_freq: upper bound used by the 'pixel' schedule.
            num_freqs: number of frequencies for the 'constant' schedule.

        Raises:
            ValueError: if ``freqs_for`` is not a known scheme and no
                ``custom_freqs`` is given.
        """
        super().__init__()
        # Test against None: truth-testing a tensor raises for multi-element
        # tensors and silently ignores one that evaluates falsy.
        if custom_freqs is not None:
            freqs = custom_freqs
        elif freqs_for == 'lang':
            freqs = 1. / (theta ** (torch.arange(0, dim, 2)[:(dim // 2)].float() / dim))
        elif freqs_for == 'pixel':
            freqs = torch.linspace(1., max_freq / 2, dim // 2) * pi
        elif freqs_for == 'constant':
            freqs = torch.ones(num_freqs).float()
        else:
            raise ValueError(f'unknown modality {freqs_for}')

        if ft_seq_len is None: ft_seq_len = pt_seq_len
        # positions 0..ft_seq_len-1 rescaled into the pt_seq_len range
        t = torch.arange(ft_seq_len) / ft_seq_len * pt_seq_len

        # Both grid axes use the same position/frequency table, so build it once.
        axis_freqs = torch.einsum('..., f -> ... f', t, freqs)
        axis_freqs = repeat(axis_freqs, '... n -> ... (n r)', r = 2)

        # One copy varies along the first grid axis, the other along the second;
        # they are concatenated on the feature axis.
        freqs = broadcat((axis_freqs[:, None, :], axis_freqs[None, :, :]), dim = -1)

        self.register_buffer("freqs_cos", freqs.cos())
        self.register_buffer("freqs_sin", freqs.sin())

        logging.info(f'Shape of rope freq: {self.freqs_cos.shape}')

    def forward(self, t, start_index = 0):
        """Rotate ``t[..., start_index:start_index + rot_dim]`` and leave the rest untouched.

        ``rot_dim`` is the size of the last axis of the precomputed tables.
        The rotated slice is multiplied against the stored tables directly,
        so it must broadcast with them.
        """
        rot_dim = self.freqs_cos.shape[-1]
        end_index = start_index + rot_dim
        assert rot_dim <= t.shape[-1], f'feature dimension {t.shape[-1]} is not of sufficient size to rotate in all the positions {rot_dim}'
        t_left, t, t_right = t[..., :start_index], t[..., start_index:end_index], t[..., end_index:]
        t = (t * self.freqs_cos) + (rotate_half(t) * self.freqs_sin)

        return torch.cat((t_left, t, t_right), dim = -1)
78
+
79
class VisionRotaryEmbeddingFast(nn.Module):
    """Rotary position embedding over a square 2-D grid, with flattened tables.

    The cosine and sine tables are stored as 2-D buffers ``freqs_cos`` and
    ``freqs_sin`` (one row per grid position) and applied to the whole last
    axis of the input in ``forward``.
    """

    def __init__(
        self,
        dim,
        pt_seq_len,
        ft_seq_len=None,
        custom_freqs = None,
        freqs_for = 'lang',
        theta = 10000,
        max_freq = 10,
        num_freqs = 1,
        patch_dropout = 0.
    ):
        """Precompute the flattened rotation tables.

        Args:
            dim: sets how many base frequencies are generated (``dim // 2``
                for the 'lang' and 'pixel' schemes).
            pt_seq_len: grid side length that positions are rescaled to.
            ft_seq_len: grid side length actually used; defaults to ``pt_seq_len``.
            custom_freqs: explicit base frequencies; used whenever it is not ``None``.
            freqs_for: one of 'lang', 'pixel' or 'constant'.
            theta: base of the 'lang' frequency schedule.
            max_freq: upper bound used by the 'pixel' schedule.
            num_freqs: number of frequencies for the 'constant' schedule.
            patch_dropout: stored on the module as ``self.patch_dropout``;
                not otherwise used in this class.

        Raises:
            ValueError: if ``freqs_for`` is not a known scheme and no
                ``custom_freqs`` is given.
        """
        super().__init__()
        # Test against None: truth-testing a tensor raises for multi-element
        # tensors and silently ignores one that evaluates falsy.
        if custom_freqs is not None:
            freqs = custom_freqs
        elif freqs_for == 'lang':
            freqs = 1. / (theta ** (torch.arange(0, dim, 2)[:(dim // 2)].float() / dim))
        elif freqs_for == 'pixel':
            freqs = torch.linspace(1., max_freq / 2, dim // 2) * pi
        elif freqs_for == 'constant':
            freqs = torch.ones(num_freqs).float()
        else:
            raise ValueError(f'unknown modality {freqs_for}')

        if ft_seq_len is None: ft_seq_len = pt_seq_len
        # positions 0..ft_seq_len-1 rescaled into the pt_seq_len range
        t = torch.arange(ft_seq_len) / ft_seq_len * pt_seq_len

        freqs = torch.einsum('..., f -> ... f', t, freqs)
        freqs = repeat(freqs, '... n -> ... (n r)', r = 2)
        freqs = broadcat((freqs[:, None, :], freqs[None, :, :]), dim = -1)

        # collapse the two grid axes into one row index
        freqs_cos = freqs.cos().view(-1, freqs.shape[-1])
        freqs_sin = freqs.sin().view(-1, freqs.shape[-1])

        self.patch_dropout = patch_dropout

        self.register_buffer("freqs_cos", freqs_cos)
        self.register_buffer("freqs_sin", freqs_sin)

        logging.info(f'Shape of rope freq: {self.freqs_cos.shape}')

    def forward(self, t, patch_indices_keep=None):
        """Apply the rotation to ``t``.

        With ``patch_indices_keep`` given, the tables are first tiled over
        ``t.shape[0]`` and ``t.shape[1]`` and then gathered per batch element
        at the kept indices, so each sample is rotated with the rows of its
        surviving positions. Without it the stored tables are broadcast
        against ``t`` directly.

        NOTE(review): presumably ``t`` is (batch, heads, tokens, dim) and
        ``patch_indices_keep`` is (batch, kept_tokens) — confirm against the caller.
        """
        if patch_indices_keep is not None:
            batch = t.size()[0]
            batch_indices = torch.arange(batch)
            batch_indices = batch_indices[..., None]

            freqs_cos = repeat(self.freqs_cos, 'i j -> n i m j', n=t.shape[0], m=t.shape[1])
            freqs_sin = repeat(self.freqs_sin, 'i j -> n i m j', n=t.shape[0], m=t.shape[1])

            freqs_cos = freqs_cos[batch_indices, patch_indices_keep]
            freqs_cos = rearrange(freqs_cos, 'n i m j -> n m i j')
            freqs_sin = freqs_sin[batch_indices, patch_indices_keep]
            freqs_sin = rearrange(freqs_sin, 'n i m j -> n m i j')

            return t * freqs_cos + rotate_half(t) * freqs_sin

        return t * self.freqs_cos + rotate_half(t) * self.freqs_sin
models/eva_clip/timm_model.py ADDED
@@ -0,0 +1,122 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """ timm model adapter
2
+
3
+ Wraps timm (https://github.com/rwightman/pytorch-image-models) models for use as a vision tower in CLIP model.
4
+ """
5
+ import logging
6
+ from collections import OrderedDict
7
+
8
+ import torch
9
+ import torch.nn as nn
10
+
11
+ try:
12
+ import timm
13
+ from timm.models.layers import Mlp, to_2tuple
14
+ try:
15
+ # old timm imports < 0.8.1
16
+ from timm.models.layers.attention_pool2d import RotAttentionPool2d
17
+ from timm.models.layers.attention_pool2d import AttentionPool2d as AbsAttentionPool2d
18
+ except ImportError:
19
+ # new timm imports >= 0.8.1
20
+ from timm.layers import RotAttentionPool2d
21
+ from timm.layers import AttentionPool2d as AbsAttentionPool2d
22
+ except ImportError:
23
+ timm = None
24
+
25
+ from .utils import freeze_batch_norm_2d
26
+
27
+
28
class TimmModel(nn.Module):
    """Adapter that wraps a timm model as ``trunk`` and adds a pooling/projection ``head``.

    FIXME (from the original authors): this adapter is a work in progress and
    may change in ways that break weight compatibility.
    """

    def __init__(
            self,
            model_name,
            embed_dim,
            image_size=224,
            pool='avg',
            proj='linear',
            proj_bias=False,
            drop=0.,
            pretrained=False):
        """Create the timm trunk and build the head.

        Args:
            model_name: name passed to ``timm.create_model``.
            embed_dim: output width of the head.
            image_size: stored as a 2-tuple on ``self.image_size``.
            pool: 'abs_attn' or 'rot_attn' for attention pooling; any other
                truthy value is forwarded as ``global_pool`` to the trunk; a
                falsy value leaves the trunk's default pooling in place.
            proj: 'linear', 'mlp', or a falsy value for no projection
                (only allowed together with attention pooling).
            proj_bias: whether the final projection has a bias.
            drop: dropout rate used in the head.
            pretrained: forwarded to ``timm.create_model``.

        Raises:
            RuntimeError: if timm could not be imported.
        """
        super().__init__()
        if timm is None:
            raise RuntimeError("Please `pip install timm` to use timm models.")

        self.image_size = to_2tuple(image_size)
        self.trunk = timm.create_model(model_name, pretrained=pretrained)
        # a configured 'pool_size' is taken as the sign that the trunk emits a 2-D feature map
        feat_size = self.trunk.default_cfg.get('pool_size', None)
        feature_ndim = 1 if not feat_size else 2
        if pool in ('abs_attn', 'rot_attn'):
            assert feature_ndim == 2
            # if attn pooling used, remove both classifier and default pool
            self.trunk.reset_classifier(0, global_pool='')
        else:
            # reset global pool if pool config set, otherwise leave as network default
            reset_kwargs = dict(global_pool=pool) if pool else {}
            self.trunk.reset_classifier(0, **reset_kwargs)
        # read the feature width only after the classifier has been reset
        prev_chs = self.trunk.num_features

        head_layers = OrderedDict()
        if pool == 'abs_attn':
            head_layers['pool'] = AbsAttentionPool2d(prev_chs, feat_size=feat_size, out_features=embed_dim)
            prev_chs = embed_dim
        elif pool == 'rot_attn':
            head_layers['pool'] = RotAttentionPool2d(prev_chs, out_features=embed_dim)
            prev_chs = embed_dim
        else:
            assert proj, 'projection layer needed if non-attention pooling is used.'

        # NOTE attention pool ends with a projection layer, so proj should usually be set to '' if such pooling is used
        if proj == 'linear':
            head_layers['drop'] = nn.Dropout(drop)
            head_layers['proj'] = nn.Linear(prev_chs, embed_dim, bias=proj_bias)
        elif proj == 'mlp':
            head_layers['mlp'] = Mlp(prev_chs, 2 * embed_dim, embed_dim, drop=drop, bias=(True, proj_bias))

        self.head = nn.Sequential(head_layers)

    def lock(self, unlocked_groups=0, freeze_bn_stats=False):
        """Disable gradients on the trunk, fully or up to a layer group.

        Args:
            unlocked_groups (int): leave last n layer groups unlocked (default: 0)
            freeze_bn_stats: when true, also call ``freeze_batch_norm_2d`` on
                the locked part of the trunk.

        Raises:
            RuntimeError: on a partial lock when timm's grouping helpers
                cannot be imported.
        """
        if not unlocked_groups:
            # lock full model
            for param in self.trunk.parameters():
                param.requires_grad = False
            if freeze_bn_stats:
                freeze_batch_norm_2d(self.trunk)
        else:
            # NOTE: partial freeze requires latest timm (master) branch and is subject to change
            try:
                # FIXME import here until API stable and in an official release
                from timm.models.helpers import group_parameters, group_modules
            except ImportError:
                raise RuntimeError(
                    'Please install latest timm `pip install git+https://github.com/rwightman/pytorch-image-models`')
            matcher = self.trunk.group_matcher()
            gparams = group_parameters(self.trunk, matcher)
            # highest group index that is still locked
            max_layer_id = max(gparams.keys())
            max_layer_id = max_layer_id - unlocked_groups
            for group_idx in range(max_layer_id + 1):
                group = gparams[group_idx]
                for param in group:
                    self.trunk.get_parameter(param).requires_grad = False
            if freeze_bn_stats:
                gmodules = group_modules(self.trunk, matcher, reverse=True)
                # keep only the modules whose group index falls in the locked range
                gmodules = {k for k, v in gmodules.items() if v <= max_layer_id}
                freeze_batch_norm_2d(self.trunk, gmodules)

    @torch.jit.ignore
    def set_grad_checkpointing(self, enable=True):
        """Forward the setting to the trunk; log a warning and continue if the trunk raises."""
        try:
            self.trunk.set_grad_checkpointing(enable)
        except Exception as e:
            logging.warning('grad checkpointing not supported for this timm image tower, continuing without...')

    def forward(self, x):
        """Run ``x`` through the trunk and then the head."""
        x = self.trunk(x)
        x = self.head(x)
        return x
models/eva_clip/tokenizer.py ADDED
@@ -0,0 +1,201 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """ CLIP tokenizer
2
+
3
+ Copied from https://github.com/openai/CLIP. Originally MIT License, Copyright (c) 2021 OpenAI.
4
+ """
5
+ import gzip
6
+ import html
7
+ import os
8
+ from functools import lru_cache
9
+ from typing import Union, List
10
+
11
+ import ftfy
12
+ import regex as re
13
+ import torch
14
+
15
+ # https://stackoverflow.com/q/62691279
16
+ import os
17
+ os.environ["TOKENIZERS_PARALLELISM"] = "false"
18
+
19
+
20
@lru_cache()
def default_bpe():
    """Return the path of the BPE vocabulary archive located next to this module."""
    module_dir = os.path.dirname(os.path.abspath(__file__))
    return os.path.join(module_dir, "bpe_simple_vocab_16e6.txt.gz")
23
+
24
+
25
@lru_cache()
def bytes_to_unicode():
    """
    Build the reversible byte -> unicode character table used by the BPE code.

    Bytes in the ranges '!'..'~', '¡'..'¬' and '®'..'ÿ' map to the character
    with the same code point. Every other byte value is assigned, in
    ascending order, the next code point starting at 256. This gives every
    one of the 256 byte values its own character, none of them below 256
    unless it comes from the three ranges above.
    """
    kept = (
        list(range(ord("!"), ord("~") + 1))
        + list(range(ord("¡"), ord("¬") + 1))
        + list(range(ord("®"), ord("ÿ") + 1))
    )
    kept_lookup = set(kept)
    byte_values = list(kept)
    code_points = list(kept)
    shifted = 0
    for byte in range(2**8):
        if byte in kept_lookup:
            continue
        byte_values.append(byte)
        code_points.append(2**8 + shifted)
        shifted += 1
    return dict(zip(byte_values, map(chr, code_points)))
46
+
47
+
48
def get_pairs(word):
    """Return the set of adjacent symbol pairs in a word.

    The word is a sequence of symbols (variable-length strings); an empty
    word raises ``IndexError``.
    """
    word[0]  # an empty word is an error, not an empty result
    return set(zip(word, word[1:]))
58
+
59
+
60
def basic_clean(text):
    """Run ftfy's text fixer, unescape HTML entities twice, and strip surrounding whitespace."""
    fixed = ftfy.fix_text(text)
    # two passes so that doubly escaped entities are resolved as well
    once = html.unescape(fixed)
    twice = html.unescape(once)
    return twice.strip()
64
+
65
+
66
def whitespace_clean(text):
    """Collapse every run of whitespace into one space and strip both ends."""
    collapsed = re.sub(r'\s+', ' ', text)
    return collapsed.strip()
70
+
71
+
72
class SimpleTokenizer(object):
    """Byte-level BPE tokenizer driven by a gzip-compressed merge table.

    Attributes set in ``__init__``:
        encoder / decoder: token string <-> integer id.
        bpe_ranks: merge pair -> priority (lower merges first).
        cache: memo of already-segmented tokens, pre-seeded with the special tokens.
        pat: compiled pattern that splits text into pre-tokens.
        vocab_size, all_special_ids: size of ``encoder`` and ids of the special tokens.
    """

    def __init__(self, bpe_path: str = default_bpe(), special_tokens=None):
        """Load the merge table from ``bpe_path`` and build the vocabulary.

        ``special_tokens``, if given, are appended after the built-in
        '<start_of_text>' and '<end_of_text>' tokens.
        """
        self.byte_encoder = bytes_to_unicode()
        self.byte_decoder = {v: k for k, v in self.byte_encoder.items()}
        # Use a context manager so the gzip handle is closed once the table is read.
        with gzip.open(bpe_path) as bpe_file:
            merges = bpe_file.read().decode("utf-8").split('\n')
        # drop the first line and keep at most the next 49152-256-2 entries
        merges = merges[1:49152-256-2+1]
        merges = [tuple(merge.split()) for merge in merges]
        # vocabulary order: byte symbols, byte symbols + '</w>', merges, special tokens
        vocab = list(bytes_to_unicode().values())
        vocab = vocab + [v+'</w>' for v in vocab]
        for merge in merges:
            vocab.append(''.join(merge))
        if not special_tokens:
            special_tokens = ['<start_of_text>', '<end_of_text>']
        else:
            special_tokens = ['<start_of_text>', '<end_of_text>'] + special_tokens
        vocab.extend(special_tokens)
        self.encoder = dict(zip(vocab, range(len(vocab))))
        self.decoder = {v: k for k, v in self.encoder.items()}
        self.bpe_ranks = dict(zip(merges, range(len(merges))))
        self.cache = {t:t for t in special_tokens}
        special = "|".join(special_tokens)
        self.pat = re.compile(special + r"""|'s|'t|'re|'ve|'m|'ll|'d|[\p{L}]+|[\p{N}]|[^\s\p{L}\p{N}]+""", re.IGNORECASE)

        self.vocab_size = len(self.encoder)
        self.all_special_ids = [self.encoder[t] for t in special_tokens]

    def bpe(self, token):
        """Segment one pre-token into BPE symbols, returned as a space-joined string.

        The last character is tagged with '</w>' before merging. Results are
        memoized in ``self.cache``, except for a single-character token, which
        is returned directly as ``token + '</w>'``.
        """
        if token in self.cache:
            return self.cache[token]
        word = tuple(token[:-1]) + ( token[-1] + '</w>',)
        pairs = get_pairs(word)

        if not pairs:
            return token+'</w>'

        while True:
            # merge the pair with the best (lowest) rank; stop when none is mergeable
            bigram = min(pairs, key = lambda pair: self.bpe_ranks.get(pair, float('inf')))
            if bigram not in self.bpe_ranks:
                break
            first, second = bigram
            new_word = []
            i = 0
            while i < len(word):
                try:
                    j = word.index(first, i)
                    new_word.extend(word[i:j])
                    i = j
                except ValueError:
                    # `first` does not occur again: copy the remainder unchanged
                    new_word.extend(word[i:])
                    break

                if word[i] == first and i < len(word)-1 and word[i+1] == second:
                    new_word.append(first+second)
                    i += 2
                else:
                    new_word.append(word[i])
                    i += 1
            new_word = tuple(new_word)
            word = new_word
            if len(word) == 1:
                break
            else:
                pairs = get_pairs(word)
        word = ' '.join(word)
        self.cache[token] = word
        return word

    def encode(self, text):
        """Clean and lower-case ``text``, then return its list of BPE token ids."""
        bpe_tokens = []
        text = whitespace_clean(basic_clean(text)).lower()
        for token in re.findall(self.pat, text):
            # re-express the token's UTF-8 bytes in the byte-level alphabet
            token = ''.join(self.byte_encoder[b] for b in token.encode('utf-8'))
            bpe_tokens.extend(self.encoder[bpe_token] for bpe_token in self.bpe(token).split(' '))
        return bpe_tokens

    def decode(self, tokens):
        """Turn token ids back into text; each '</w>' marker becomes a space."""
        text = ''.join([self.decoder[token] for token in tokens])
        text = bytearray([self.byte_decoder[c] for c in text]).decode('utf-8', errors="replace").replace('</w>', ' ')
        return text
151
+
152
+
153
+ _tokenizer = SimpleTokenizer()
154
+
155
+
156
def tokenize(texts: Union[str, List[str]], context_length: int = 77) -> torch.LongTensor:
    """
    Tokenize one string or a list of strings into a padded id tensor.

    Parameters
    ----------
    texts : Union[str, List[str]]
        A single input string or a list of input strings
    context_length : int
        Number of columns in the result; all CLIP models use 77

    Returns
    -------
    A two-dimensional tensor of token ids, shape = [number of input strings, context_length].
    Each row is start token + encoded text + end token, zero-padded on the right.
    Rows that would be too long are cut to context_length and their last id is
    replaced by the end token.
    """
    if isinstance(texts, str):
        texts = [texts]

    start_id = _tokenizer.encoder["<start_of_text>"]
    end_id = _tokenizer.encoder["<end_of_text>"]
    token_rows = [[start_id] + _tokenizer.encode(text) + [end_id] for text in texts]
    result = torch.zeros(len(token_rows), context_length, dtype=torch.long)

    for row, ids in enumerate(token_rows):
        if len(ids) > context_length:
            ids = ids[:context_length]  # Truncate
            ids[-1] = end_id
        result[row, :len(ids)] = torch.tensor(ids)

    return result
186
+
187
+
188
class HFTokenizer:
    "Thin wrapper around a HuggingFace ``AutoTokenizer``"
    def __init__(self, tokenizer_name:str):
        # imported here so that transformers is only needed when this wrapper is used
        from transformers import AutoTokenizer
        self.tokenizer = AutoTokenizer.from_pretrained(tokenizer_name)

    def __call__(self, texts:Union[str, List[str]], context_length:int=77) -> torch.Tensor:
        # Same cleaning as the default tokenizer but without lower-casing;
        # lower-casing (for case-sensitive tokenizers) would be more robust
        # at the cost of nuance.
        batch = [texts] if isinstance(texts, str) else texts
        cleaned = [whitespace_clean(basic_clean(item)) for item in batch]
        encoded = self.tokenizer(
            cleaned,
            return_tensors='pt',
            max_length=context_length,
            padding='max_length',
            truncation=True,
        )
        return encoded.input_ids
models/eva_clip/transform.py ADDED
@@ -0,0 +1,103 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from typing import Optional, Sequence, Tuple
2
+
3
+ import torch
4
+ import torch.nn as nn
5
+ import torchvision.transforms.functional as F
6
+
7
+ from torchvision.transforms import Normalize, Compose, RandomResizedCrop, InterpolationMode, ToTensor, Resize, \
8
+ CenterCrop
9
+
10
+ from .constants import OPENAI_DATASET_MEAN, OPENAI_DATASET_STD
11
+
12
+
13
class ResizeMaxSize(nn.Module):
    """Resize an image so its longest side equals ``max_size``, then pad it to a square.

    Args:
        max_size: target side length of the square output; must be an ``int``.
        interpolation: interpolation mode passed to ``F.resize``.
        fn: ``'min'`` or ``'max'``; stored on ``self.fn``. ``forward`` always scales
            by the longest side and does not consult this attribute.
        fill: fill value passed to ``F.pad`` for the padded border.

    Raises:
        TypeError: if ``max_size`` is not an ``int``.
    """

    def __init__(self, max_size, interpolation=InterpolationMode.BICUBIC, fn='max', fill=0):
        super().__init__()
        if not isinstance(max_size, int):
            raise TypeError(f"Size should be int. Got {type(max_size)}")
        self.max_size = max_size
        self.interpolation = interpolation
        # Previously `min if fn == 'min' else min`, which stored `min` for every value of `fn`.
        self.fn = min if fn == 'min' else max
        self.fill = fill

    def forward(self, img):
        if isinstance(img, torch.Tensor):
            # torchvision's functional ops treat tensors as [..., H, W], so the
            # spatial size is the last two dims (the first two would be C, H).
            height, width = img.shape[-2:]
        else:
            width, height = img.size
        scale = self.max_size / float(max(height, width))
        new_size = tuple(round(dim * scale) for dim in (height, width))
        if scale != 1.0:
            img = F.resize(img, new_size, self.interpolation)
        # Pad even when no resize was needed: an image whose longest side already
        # equals max_size but is not square must still come out square.
        pad_h = self.max_size - new_size[0]
        pad_w = self.max_size - new_size[1]
        if pad_h or pad_w:
            img = F.pad(img, padding=[pad_w//2, pad_h//2, pad_w - pad_w//2, pad_h - pad_h//2], fill=self.fill)
        return img
37
+
38
+
39
def _convert_to_rgb(image):
    """Return ``image.convert('RGB')``; used as a plain callable inside a transform pipeline."""
    rgb_image = image.convert('RGB')
    return rgb_image
41
+
42
+
43
+ # class CatGen(nn.Module):
44
+ # def __init__(self, num=4):
45
+ # self.num = num
46
+ # def mixgen_batch(image, text):
47
+ # batch_size = image.shape[0]
48
+ # index = np.random.permutation(batch_size)
49
+
50
+ # cat_images = []
51
+ # for i in range(batch_size):
52
+ # # image mixup
53
+ # image[i,:] = lam * image[i,:] + (1 - lam) * image[index[i],:]
54
+ # # text concat
55
+ # text[i] = tokenizer((str(text[i]) + " " + str(text[index[i]])))[0]
56
+ # text = torch.stack(text)
57
+ # return image, text
58
+
59
+
60
def image_transform(
        image_size: int,
        is_train: bool,
        mean: Optional[Tuple[float, ...]] = None,
        std: Optional[Tuple[float, ...]] = None,
        resize_longest_max: bool = False,
        fill_color: int = 0,
):
    """Build the torchvision preprocessing pipeline for training or evaluation.

    Args:
        image_size: target size; a square ``(s, s)`` list/tuple is collapsed to ``s``.
        is_train: if true, use ``RandomResizedCrop`` with scale ``(0.9, 1.0)``.
        mean: normalisation mean; falls back to ``OPENAI_DATASET_MEAN`` when falsy.
            A scalar is repeated for three channels.
        std: normalisation std; falls back to ``OPENAI_DATASET_STD`` when falsy.
            A scalar is repeated for three channels.
        resize_longest_max: evaluation only; use ``ResizeMaxSize`` instead of
            ``Resize`` followed by ``CenterCrop``.
        fill_color: fill value handed to ``ResizeMaxSize``.

    Returns:
        A ``Compose`` ending with RGB conversion, ``ToTensor`` and ``Normalize``.
    """
    def _per_channel(stats):
        # A scalar statistic is broadcast to the three colour channels.
        return stats if isinstance(stats, (list, tuple)) else (stats,) * 3

    mean = _per_channel(mean or OPENAI_DATASET_MEAN)
    std = _per_channel(std or OPENAI_DATASET_STD)

    is_square_pair = isinstance(image_size, (list, tuple)) and image_size[0] == image_size[1]
    if is_square_pair:
        # for square size, pass size as int so that Resize() uses aspect preserving shortest edge
        image_size = image_size[0]

    normalize = Normalize(mean=mean, std=std)

    if is_train:
        random_crop = RandomResizedCrop(image_size, scale=(0.9, 1.0), interpolation=InterpolationMode.BICUBIC)
        return Compose([random_crop, _convert_to_rgb, ToTensor(), normalize])

    if resize_longest_max:
        steps = [ResizeMaxSize(image_size, fill=fill_color)]
    else:
        steps = [
            Resize(image_size, interpolation=InterpolationMode.BICUBIC),
            CenterCrop(image_size),
        ]
    steps += [_convert_to_rgb, ToTensor(), normalize]
    return Compose(steps)
models/eva_clip/transformer.py ADDED
@@ -0,0 +1,737 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ import logging
3
+ from collections import OrderedDict
4
+ import math
5
+ from typing import Callable, Optional, Sequence
6
+ import numpy as np
7
+ import torch
8
+ from torch import nn
9
+ from torch.nn import functional as F
10
+
11
+ try:
12
+ from timm.models.layers import trunc_normal_
13
+ except:
14
+ from timm.layers import trunc_normal_
15
+
16
+ from .rope import VisionRotaryEmbedding, VisionRotaryEmbeddingFast
17
+ from .utils import to_2tuple
18
+
19
+ if os.getenv('ENV_TYPE') == 'deepspeed':
20
+ try:
21
+ import deepspeed
22
+ from deepspeed.runtime.activation_checkpointing.checkpointing import checkpoint
23
+ except:
24
+ print("Please 'pip install deepspeed'")
25
+ deepspeed = None
26
+ from torch.utils.checkpoint import checkpoint
27
+ else:
28
+ from torch.utils.checkpoint import checkpoint
29
+
30
+ try:
31
+ import xformers.ops as xops
32
+ except ImportError:
33
+ xops = None
34
+ print("Please 'pip install xformers'")
35
+
36
class LayerNormFp32(nn.LayerNorm):
    """LayerNorm that computes in float32 and casts the result back to the input's type."""

    def forward(self, x: torch.Tensor):
        weight = None if self.weight is None else self.weight.float()
        bias = None if self.bias is None else self.bias.float()
        normed = F.layer_norm(x.float(), self.normalized_shape, weight, bias, self.eps)
        return normed.type_as(x)
50
+
51
+
52
class LayerNorm(nn.LayerNorm):
    """torch LayerNorm whose output is cast back to the dtype of its input."""

    def forward(self, x: torch.Tensor):
        input_dtype = x.dtype
        normed = F.layer_norm(x, self.normalized_shape, self.weight, self.bias, self.eps)
        return normed.to(input_dtype)
59
+
60
class QuickGELU(nn.Module):
    """Sigmoid-gated activation ``x * sigmoid(1.702 * x)``."""

    # NOTE This is slower than nn.GELU or nn.SiLU and uses more GPU memory
    def forward(self, x: torch.Tensor):
        gate = torch.sigmoid(1.702 * x)
        return x * gate
64
+
65
+
66
class LayerScale(nn.Module):
    """Multiply the input by a learned vector ``gamma`` of length ``dim``.

    ``gamma`` starts at ``init_values`` in every position. With ``inplace=True``
    the input tensor itself is scaled via ``mul_`` and returned.
    """

    def __init__(self, dim, init_values=1e-5, inplace=False):
        super().__init__()
        self.inplace = inplace
        self.gamma = nn.Parameter(init_values * torch.ones(dim))

    def forward(self, x):
        if self.inplace:
            return x.mul_(self.gamma)
        return x * self.gamma
74
+
75
class PatchDropout(nn.Module):
    """
    https://arxiv.org/abs/2212.00794

    During training, keeps a random subset of the tokens along dim 1 of each
    sample. In eval mode, or with ``prob == 0``, the input is returned unchanged.
    """

    def __init__(self, prob, exclude_first_token=True):
        super().__init__()
        assert 0 <= prob < 1.
        self.prob = prob
        self.exclude_first_token = exclude_first_token  # exclude CLS token
        logging.info(f"os.getenv('RoPE')={os.getenv('RoPE')}")

    def forward(self, x):
        """Return the kept tokens; also return the kept indices when env ``RoPE`` is ``'1'``."""
        if not self.training or self.prob == 0.:
            return x

        if self.exclude_first_token:
            # The first token is set aside and never dropped.
            cls_tokens, x = x[:, :1], x[:, 1:]
        else:
            cls_tokens = torch.jit.annotate(torch.Tensor, x[:, :1])

        batch, num_tokens = x.size()[0], x.size()[1]
        row_index = torch.arange(batch)[..., None]

        keep_prob = 1 - self.prob
        num_patches_keep = max(1, int(num_tokens * keep_prob))

        # Random scores per token; the top-k scores select which tokens survive.
        scores = torch.randn(batch, num_tokens)
        patch_indices_keep = scores.topk(num_patches_keep, dim=-1).indices

        x = x[row_index, patch_indices_keep]

        if self.exclude_first_token:
            x = torch.cat((cls_tokens, x), dim=1)

        if self.training and os.getenv('RoPE') == '1':
            return x, patch_indices_keep

        return x
117
+
118
+
119
def _in_projection_packed(
        q: torch.Tensor,
        k: torch.Tensor,
        v: torch.Tensor,
        w: torch.Tensor,
        b: Optional[torch.Tensor] = None,
):
    """
    Project q, k and v with one packed weight ``w`` (and optional packed bias ``b``).

    ``w`` is split along dim 0 into query/key/value parts. Which fused form is
    used depends on object identity of the inputs: all three identical, only
    ``k is v``, or three separate inputs.

    https://github.com/pytorch/pytorch/blob/db2a237763eb8693a20788be94f8c192e762baa8/torch/nn/functional.py#L4726
    """
    embed_dim = q.size(-1)

    if k is not v:
        # Three separate inputs: one linear map per input.
        w_q, w_k, w_v = w.chunk(3)
        if b is None:
            b_q = b_k = b_v = None
        else:
            b_q, b_k, b_v = b.chunk(3)
        return F.linear(q, w_q, b_q), F.linear(k, w_k, b_k), F.linear(v, w_v, b_v)

    if q is k:
        # self-attention
        return F.linear(q, w, b).chunk(3, dim=-1)

    # encoder-decoder attention
    w_q, w_kv = w.split([embed_dim, embed_dim * 2])
    if b is None:
        b_q = b_kv = None
    else:
        b_q, b_kv = b.split([embed_dim, embed_dim * 2])
    return (F.linear(q, w_q, b_q),) + F.linear(k, w_kv, b_kv).chunk(2, dim=-1)
149
+
150
class Attention(nn.Module):
    """Multi-head self-attention over sequence-first input.

    ``forward`` unpacks its input as ``(L, N, C)`` and returns a tensor of the
    same shape. Optional variants:

    * ``scaled_cosine``: cosine-similarity logits multiplied by a learned
      per-head scale whose log is clamped at ``logit_scale_max``.
    * ``scale_heads``: a learned per-head multiplier on the attention output.
    * ``xattn``: use ``xformers.ops.memory_efficient_attention``. On this path any
      non-``None`` ``attn_mask`` is replaced by a lower-triangular (causal) mask
      and the cosine variant is not applied.

    ``rope`` is stored on the instance but not used by ``forward``.
    """

    def __init__(
            self,
            dim,
            num_heads=8,
            qkv_bias=True,
            scaled_cosine=False,
            scale_heads=False,
            logit_scale_max=math.log(1. / 0.01),
            attn_drop=0.,
            proj_drop=0.,
            xattn=False,
            rope=False
    ):
        super().__init__()
        self.scaled_cosine = scaled_cosine
        self.scale_heads = scale_heads
        assert dim % num_heads == 0, 'dim should be divisible by num_heads'
        self.num_heads = num_heads
        self.head_dim = dim // num_heads
        self.scale = self.head_dim ** -0.5
        self.logit_scale_max = logit_scale_max

        # keeping in_proj in this form (instead of nn.Linear) to match weight scheme of original
        self.in_proj_weight = nn.Parameter(torch.randn((dim * 3, dim)) * self.scale)
        if qkv_bias:
            self.in_proj_bias = nn.Parameter(torch.zeros(dim * 3))
        else:
            self.in_proj_bias = None

        if self.scaled_cosine:
            self.logit_scale = nn.Parameter(torch.log(10 * torch.ones((num_heads, 1, 1))))
        else:
            self.logit_scale = None
        self.attn_drop = nn.Dropout(attn_drop)
        if self.scale_heads:
            self.head_scale = nn.Parameter(torch.ones((num_heads, 1, 1)))
        else:
            self.head_scale = None
        self.out_proj = nn.Linear(dim, dim)
        self.out_drop = nn.Dropout(proj_drop)
        self.xattn = xattn
        self.xattn_drop = attn_drop
        self.rope = rope

    def forward(self, x, attn_mask: Optional[torch.Tensor] = None):
        """Apply self-attention to ``x`` of shape ``(L, N, C)``.

        ``attn_mask`` may be a boolean mask (``True`` entries are filled with
        ``-inf``) or an additive mask; see the class docstring for the xformers path.

        Raises:
            RuntimeError: if ``xattn`` is enabled but xformers could not be imported.
        """
        L, N, C = x.shape
        q, k, v = F.linear(x, self.in_proj_weight, self.in_proj_bias).chunk(3, dim=-1)
        if self.xattn:
            if xops is None:
                raise RuntimeError("xattn=True requires xformers. Please 'pip install xformers'")
            # (L, N, C) -> (N, L, heads, head_dim)
            q = q.contiguous().view(L, N, self.num_heads, -1).transpose(0, 1)
            k = k.contiguous().view(L, N, self.num_heads, -1).transpose(0, 1)
            v = v.contiguous().view(L, N, self.num_heads, -1).transpose(0, 1)

            x = xops.memory_efficient_attention(
                q, k, v,
                # xformers applies `p` unconditionally, so switch dropout off in eval mode.
                p=self.xattn_drop if self.training else 0.,
                scale=self.scale if self.logit_scale is None else None,
                attn_bias=xops.LowerTriangularMask() if attn_mask is not None else None,
            )
            if self.head_scale is not None:
                # Output is (N, L, heads, head_dim); the head axis is dim 2 here.
                x = x * self.head_scale.view(1, 1, self.num_heads, 1)
        else:
            # (L, N, C) -> (N * heads, L, head_dim)
            q = q.contiguous().view(L, N * self.num_heads, -1).transpose(0, 1)
            k = k.contiguous().view(L, N * self.num_heads, -1).transpose(0, 1)
            v = v.contiguous().view(L, N * self.num_heads, -1).transpose(0, 1)

            if self.logit_scale is not None:
                attn = torch.bmm(F.normalize(q, dim=-1), F.normalize(k, dim=-1).transpose(-1, -2))
                logit_scale = torch.clamp(self.logit_scale, max=self.logit_scale_max).exp()
                attn = attn.view(N, self.num_heads, L, L) * logit_scale
                attn = attn.view(-1, L, L)
            else:
                q = q * self.scale
                attn = torch.bmm(q, k.transpose(-1, -2))

            if attn_mask is not None:
                if attn_mask.dtype == torch.bool:
                    new_attn_mask = torch.zeros_like(attn_mask, dtype=q.dtype)
                    new_attn_mask.masked_fill_(attn_mask, float("-inf"))
                    attn_mask = new_attn_mask
                attn += attn_mask

            attn = attn.softmax(dim=-1)
            attn = self.attn_drop(attn)

            x = torch.bmm(attn, v)  # (N * heads, L, head_dim)

            if self.head_scale is not None:
                # The last dim is the per-head width, not the full model width C;
                # viewing with C fails for num_heads > 1.
                x = x.view(N, self.num_heads, L, self.head_dim) * self.head_scale
                x = x.view(-1, L, self.head_dim)

        x = x.transpose(0, 1).reshape(L, N, C)
        x = self.out_proj(x)
        x = self.out_drop(x)
        return x
242
+
243
class CustomAttention(nn.Module):
    """Multi-head attention with separate query, key and value inputs.

    Inputs are projected with a single packed weight via ``_in_projection_packed``
    and are treated as sequence-first: the projected tensors are unpacked as
    ``(N, B, C)``. The output has the shape of the projected query.
    ``scaled_cosine`` defaults to ``True`` in this class.

    With ``xattn=True`` attention runs through
    ``xformers.ops.memory_efficient_attention``; on that path any non-``None``
    ``attn_mask`` is replaced by a lower-triangular (causal) mask and the cosine
    variant is not applied.
    """

    def __init__(
            self,
            dim,
            num_heads=8,
            qkv_bias=True,
            scaled_cosine=True,
            scale_heads=False,
            logit_scale_max=math.log(1. / 0.01),
            attn_drop=0.,
            proj_drop=0.,
            xattn=False
    ):
        super().__init__()
        self.scaled_cosine = scaled_cosine
        self.scale_heads = scale_heads
        assert dim % num_heads == 0, 'dim should be divisible by num_heads'
        self.num_heads = num_heads
        self.head_dim = dim // num_heads
        self.scale = self.head_dim ** -0.5
        self.logit_scale_max = logit_scale_max

        # keeping in_proj in this form (instead of nn.Linear) to match weight scheme of original
        self.in_proj_weight = nn.Parameter(torch.randn((dim * 3, dim)) * self.scale)
        if qkv_bias:
            self.in_proj_bias = nn.Parameter(torch.zeros(dim * 3))
        else:
            self.in_proj_bias = None

        if self.scaled_cosine:
            self.logit_scale = nn.Parameter(torch.log(10 * torch.ones((num_heads, 1, 1))))
        else:
            self.logit_scale = None
        self.attn_drop = nn.Dropout(attn_drop)
        if self.scale_heads:
            self.head_scale = nn.Parameter(torch.ones((num_heads, 1, 1)))
        else:
            self.head_scale = None
        self.out_proj = nn.Linear(dim, dim)
        self.out_drop = nn.Dropout(proj_drop)
        self.xattn = xattn
        self.xattn_drop = attn_drop

    def forward(self, query: torch.Tensor, key: torch.Tensor, value: torch.Tensor, attn_mask: Optional[torch.Tensor] = None):
        """Attend from ``query`` to ``key``/``value``.

        ``attn_mask`` may be a boolean mask (``True`` entries are filled with
        ``-inf``) or an additive mask; see the class docstring for the xformers path.

        Raises:
            RuntimeError: if ``xattn`` is enabled but xformers could not be imported.
        """
        q, k, v = _in_projection_packed(query, key, value, self.in_proj_weight, self.in_proj_bias)
        N_q, B_q, C_q = q.shape
        N_k, B_k, C_k = k.shape
        N_v, B_v, C_v = v.shape
        if self.xattn:
            if xops is None:
                raise RuntimeError("xattn=True requires xformers. Please 'pip install xformers'")
            # (N, B, C) -> (B, N, heads, head_dim)
            q = q.permute(1, 0, 2).reshape(B_q, N_q, self.num_heads, -1)
            k = k.permute(1, 0, 2).reshape(B_k, N_k, self.num_heads, -1)
            v = v.permute(1, 0, 2).reshape(B_v, N_v, self.num_heads, -1)

            x = xops.memory_efficient_attention(
                q, k, v,
                # xformers applies `p` unconditionally, so switch dropout off in eval mode.
                p=self.xattn_drop if self.training else 0.,
                scale=self.scale if self.logit_scale is None else None,
                attn_bias=xops.LowerTriangularMask() if attn_mask is not None else None
            )
            if self.head_scale is not None:
                # Output is (B, N_q, heads, head_dim); the head axis is dim 2 here.
                x = x * self.head_scale.view(1, 1, self.num_heads, 1)
        else:
            # (N, B, C) -> (B * heads, N, head_dim)
            q = q.contiguous().view(N_q, B_q * self.num_heads, -1).transpose(0, 1)
            k = k.contiguous().view(N_k, B_k * self.num_heads, -1).transpose(0, 1)
            v = v.contiguous().view(N_v, B_v * self.num_heads, -1).transpose(0, 1)

            if self.logit_scale is not None:
                # B*H, N_q, N_k
                attn = torch.bmm(F.normalize(q, dim=-1), F.normalize(k, dim=-1).transpose(-1, -2))
                logit_scale = torch.clamp(self.logit_scale, max=self.logit_scale_max).exp()
                attn = attn.view(B_q, self.num_heads, N_q, N_k) * logit_scale
                attn = attn.view(-1, N_q, N_k)
            else:
                q = q * self.scale
                attn = torch.bmm(q, k.transpose(-1, -2))

            if attn_mask is not None:
                if attn_mask.dtype == torch.bool:
                    new_attn_mask = torch.zeros_like(attn_mask, dtype=q.dtype)
                    new_attn_mask.masked_fill_(attn_mask, float("-inf"))
                    attn_mask = new_attn_mask
                attn += attn_mask

            attn = attn.softmax(dim=-1)
            attn = self.attn_drop(attn)

            x = torch.bmm(attn, v)  # (B * heads, N_q, head_dim)

            if self.head_scale is not None:
                # The last dim is the per-head width, not the full width C_q;
                # viewing with C_q fails for num_heads > 1.
                x = x.view(B_q, self.num_heads, N_q, self.head_dim) * self.head_scale
                x = x.view(-1, N_q, self.head_dim)

        x = x.transpose(0, 1).reshape(N_q, B_q, C_q)
        x = self.out_proj(x)
        x = self.out_drop(x)
        return x
338
+
339
class CustomResidualAttentionBlock(nn.Module):
    """Pre-norm residual block: ``CustomAttention`` followed by an MLP.

    With ``cross_attn=True`` the key and value inputs get their own norm layers;
    otherwise ``ln_1_k`` and ``ln_1_v`` are the same module object as ``ln_1``.
    ``scale_attn`` / ``scale_fc`` insert an extra norm after the attention and
    inside the MLP respectively; ``ls_init_value`` enables ``LayerScale`` on
    both residual branches.
    """

    def __init__(
            self,
            d_model: int,
            n_head: int,
            mlp_ratio: float = 4.0,
            ls_init_value: float = None,
            act_layer: Callable = nn.GELU,
            norm_layer: Callable = LayerNorm,
            scale_cosine_attn: bool = False,
            scale_heads: bool = False,
            scale_attn: bool = False,
            scale_fc: bool = False,
            cross_attn: bool = False,
            xattn: bool = False,
    ):
        super().__init__()

        use_layer_scale = ls_init_value is not None

        self.ln_1 = norm_layer(d_model)
        if cross_attn:
            self.ln_1_k = norm_layer(d_model)
            self.ln_1_v = norm_layer(d_model)
        else:
            self.ln_1_k = self.ln_1
            self.ln_1_v = self.ln_1
        self.attn = CustomAttention(
            d_model, n_head,
            qkv_bias=True,
            attn_drop=0.,
            proj_drop=0.,
            scaled_cosine=scale_cosine_attn,
            scale_heads=scale_heads,
            xattn=xattn
        )

        self.ln_attn = norm_layer(d_model) if scale_attn else nn.Identity()
        self.ls_1 = LayerScale(d_model, ls_init_value) if use_layer_scale else nn.Identity()

        self.ln_2 = norm_layer(d_model)
        mlp_width = int(d_model * mlp_ratio)
        mlp_layers = OrderedDict()
        mlp_layers["c_fc"] = nn.Linear(d_model, mlp_width)
        mlp_layers['ln'] = norm_layer(mlp_width) if scale_fc else nn.Identity()
        mlp_layers["gelu"] = act_layer()
        mlp_layers["c_proj"] = nn.Linear(mlp_width, d_model)
        self.mlp = nn.Sequential(mlp_layers)

        self.ls_2 = LayerScale(d_model, ls_init_value) if use_layer_scale else nn.Identity()

    def forward(self, q: torch.Tensor, k: torch.Tensor, v: torch.Tensor, attn_mask: Optional[torch.Tensor] = None):
        """Return ``q`` after the attention and MLP residual updates."""
        attended = self.attn(self.ln_1(q), self.ln_1_k(k), self.ln_1_v(v), attn_mask=attn_mask)
        q = q + self.ls_1(self.ln_attn(attended))
        hidden = self.mlp(self.ln_2(q))
        q = q + self.ls_2(hidden)
        return q
388
+
389
class CustomTransformer(nn.Module):
    """Stack of ``layers`` ``CustomResidualAttentionBlock`` modules.

    ``scale_cosine_attn`` defaults to ``True`` here. Set ``grad_checkpointing`` to
    run each block through ``checkpoint``.
    """

    def __init__(
            self,
            width: int,
            layers: int,
            heads: int,
            mlp_ratio: float = 4.0,
            ls_init_value: float = None,
            act_layer: Callable = nn.GELU,
            norm_layer: Callable = LayerNorm,
            scale_cosine_attn: bool = True,
            scale_heads: bool = False,
            scale_attn: bool = False,
            scale_fc: bool = False,
            cross_attn: bool = False,
            xattn: bool = False,
    ):
        super().__init__()
        self.width = width
        self.layers = layers
        self.grad_checkpointing = False
        self.xattn = xattn

        block_options = dict(
            ls_init_value=ls_init_value,
            act_layer=act_layer,
            norm_layer=norm_layer,
            scale_cosine_attn=scale_cosine_attn,
            scale_heads=scale_heads,
            scale_attn=scale_attn,
            scale_fc=scale_fc,
            cross_attn=cross_attn,
            xattn=xattn,
        )
        blocks = []
        for _ in range(layers):
            blocks.append(CustomResidualAttentionBlock(width, heads, mlp_ratio, **block_options))
        self.resblocks = nn.ModuleList(blocks)

    def get_cast_dtype(self) -> torch.dtype:
        """Return the dtype of the first block's ``mlp.c_fc`` weight."""
        first_block = self.resblocks[0]
        return first_block.mlp.c_fc.weight.dtype

    def forward(self, q: torch.Tensor, k: torch.Tensor = None, v: torch.Tensor = None, attn_mask: Optional[torch.Tensor] = None):
        """Run ``q`` through every block.

        When both ``k`` and ``v`` are ``None`` they are set to the incoming ``q``.
        Only ``q`` is reassigned in the loop, so every block receives the same
        ``k`` and ``v``.
        """
        if k is None and v is None:
            k = v = q
        for block in self.resblocks:
            if self.grad_checkpointing and not torch.jit.is_scripting():
                q = checkpoint(block, q, k, v, attn_mask)
            else:
                q = block(q, k, v, attn_mask=attn_mask)
        return q
441
+
442
+
443
class ResidualAttentionBlock(nn.Module):
    """Pre-norm residual block: self-attention followed by an MLP.

    Uses the local ``Attention`` (with ``xattn=True``) when ``xattn`` is set,
    otherwise ``nn.MultiheadAttention``. ``ls_init_value`` enables ``LayerScale``
    on both residual branches.
    """

    def __init__(
            self,
            d_model: int,
            n_head: int,
            mlp_ratio: float = 4.0,
            ls_init_value: float = None,
            act_layer: Callable = nn.GELU,
            norm_layer: Callable = LayerNorm,
            xattn: bool = False,
    ):
        super().__init__()

        use_layer_scale = ls_init_value is not None

        self.ln_1 = norm_layer(d_model)
        self.attn = Attention(d_model, n_head, xattn=True) if xattn else nn.MultiheadAttention(d_model, n_head)
        self.ls_1 = LayerScale(d_model, ls_init_value) if use_layer_scale else nn.Identity()

        self.ln_2 = norm_layer(d_model)
        mlp_width = int(d_model * mlp_ratio)
        mlp_layers = OrderedDict()
        mlp_layers["c_fc"] = nn.Linear(d_model, mlp_width)
        mlp_layers["gelu"] = act_layer()
        mlp_layers["c_proj"] = nn.Linear(mlp_width, d_model)
        self.mlp = nn.Sequential(mlp_layers)

        self.ls_2 = LayerScale(d_model, ls_init_value) if use_layer_scale else nn.Identity()
        self.xattn = xattn

    def attention(self, x: torch.Tensor, attn_mask: Optional[torch.Tensor] = None):
        """Self-attend over ``x``; a given mask is first cast to ``x``'s dtype."""
        mask = None if attn_mask is None else attn_mask.to(x.dtype)
        if not self.xattn:
            # nn.MultiheadAttention returns (output, weights); keep only the output.
            return self.attn(x, x, x, need_weights=False, attn_mask=mask)[0]
        return self.attn(x, attn_mask=mask)

    def forward(self, x: torch.Tensor, attn_mask: Optional[torch.Tensor] = None):
        """Return ``x`` after the attention and MLP residual updates."""
        attended = self.attention(self.ln_1(x), attn_mask=attn_mask)
        x = x + self.ls_1(attended)
        hidden = self.mlp(self.ln_2(x))
        x = x + self.ls_2(hidden)
        return x
484
+
485
class Transformer(nn.Module):
    """Stack of ``layers`` ``ResidualAttentionBlock`` modules.

    Set ``grad_checkpointing`` to run each block through ``checkpoint``.
    """

    def __init__(
            self,
            width: int,
            layers: int,
            heads: int,
            mlp_ratio: float = 4.0,
            ls_init_value: float = None,
            act_layer: Callable = nn.GELU,
            norm_layer: Callable = LayerNorm,
            xattn: bool = False,
    ):
        super().__init__()
        self.width = width
        self.layers = layers
        self.grad_checkpointing = False

        blocks = []
        for _ in range(layers):
            blocks.append(
                ResidualAttentionBlock(
                    width,
                    heads,
                    mlp_ratio,
                    ls_init_value=ls_init_value,
                    act_layer=act_layer,
                    norm_layer=norm_layer,
                    xattn=xattn,
                )
            )
        self.resblocks = nn.ModuleList(blocks)

    def get_cast_dtype(self) -> torch.dtype:
        """Return the dtype of the first block's ``mlp.c_fc`` weight."""
        first_block = self.resblocks[0]
        return first_block.mlp.c_fc.weight.dtype

    def forward(self, x: torch.Tensor, attn_mask: Optional[torch.Tensor] = None):
        """Run ``x`` through every block in order, passing ``attn_mask`` to each."""
        for block in self.resblocks:
            if self.grad_checkpointing and not torch.jit.is_scripting():
                x = checkpoint(block, x, attn_mask)
            else:
                x = block(x, attn_mask=attn_mask)
        return x
518
+
519
+
520
class VisionTransformer(nn.Module):
    """ViT image encoder: patch convolution, class token, ``Transformer``, projection.

    ``forward`` returns a pooled, normalised and projected embedding per image,
    or the raw token sequence from the transformer when ``return_all_features``
    is true.
    """

    def __init__(
            self,
            image_size: int,
            patch_size: int,
            width: int,
            layers: int,
            heads: int,
            mlp_ratio: float,
            ls_init_value: float = None,
            patch_dropout: float = 0.,
            global_average_pool: bool = False,
            output_dim: int = 512,
            act_layer: Callable = nn.GELU,
            norm_layer: Callable = LayerNorm,
            xattn: bool = False,
    ):
        super().__init__()
        self.image_size = to_2tuple(image_size)
        self.patch_size = to_2tuple(patch_size)
        grid_h = self.image_size[0] // self.patch_size[0]
        grid_w = self.image_size[1] // self.patch_size[1]
        self.grid_size = (grid_h, grid_w)
        self.output_dim = output_dim
        self.conv1 = nn.Conv2d(in_channels=3, out_channels=width, kernel_size=patch_size, stride=patch_size, bias=False)

        scale = width ** -0.5
        self.class_embedding = nn.Parameter(scale * torch.randn(width))
        # One position per patch plus one for the class token.
        self.positional_embedding = nn.Parameter(scale * torch.randn(grid_h * grid_w + 1, width))

        # setting a patch_dropout of 0. would mean it is disabled and this function would be the identity fn
        self.patch_dropout = PatchDropout(patch_dropout) if patch_dropout > 0. else nn.Identity()
        self.ln_pre = norm_layer(width)

        self.transformer = Transformer(
            width,
            layers,
            heads,
            mlp_ratio,
            ls_init_value=ls_init_value,
            act_layer=act_layer,
            norm_layer=norm_layer,
            xattn=xattn
        )

        self.global_average_pool = global_average_pool
        self.ln_post = norm_layer(width)
        self.proj = nn.Parameter(scale * torch.randn(width, output_dim))

    def lock(self, unlocked_groups=0, freeze_bn_stats=False):
        """Freeze all parameters, then re-enable the last ``unlocked_groups`` groups.

        Groups, in order: the stem (conv, embeddings, ``ln_pre``), each transformer
        block except the last, the last block with ``ln_post``, and ``proj``.
        ``freeze_bn_stats`` is accepted but not used.
        """
        for param in self.parameters():
            param.requires_grad = False

        if unlocked_groups == 0:
            return

        stem = [
            self.conv1,
            self.class_embedding,
            self.positional_embedding,
            self.ln_pre,
        ]
        head = [
            self.transformer.resblocks[-1],
            self.ln_post,
        ]
        groups = [stem, *self.transformer.resblocks[:-1], head, self.proj]

        def _unlock(node):
            # Recurse through nested lists; enable grads on parameters and modules.
            if isinstance(node, Sequence):
                for child in node:
                    _unlock(child)
            elif isinstance(node, torch.nn.Parameter):
                node.requires_grad = True
            else:
                for p in node.parameters():
                    p.requires_grad = True

        _unlock(groups[-unlocked_groups:])

    def get_num_layers(self):
        return self.transformer.layers

    @torch.jit.ignore
    def set_grad_checkpointing(self, enable=True):
        self.transformer.grad_checkpointing = enable

    @torch.jit.ignore
    def no_weight_decay(self):
        return {'positional_embedding', 'class_embedding'}

    def forward(self, x: torch.Tensor, return_all_features: bool=False):
        """Encode a batch of images.

        With ``return_all_features`` the token sequence is returned as it leaves
        the transformer; otherwise tokens are pooled (mean over all tokens, or the
        first token), passed through ``ln_post`` and multiplied by ``proj``.
        """
        x = self.conv1(x)  # shape = [*, width, grid, grid]
        batch, channels = x.shape[0], x.shape[1]
        x = x.reshape(batch, channels, -1)  # shape = [*, width, grid ** 2]
        x = x.permute(0, 2, 1)  # shape = [*, grid ** 2, width]
        cls_token = self.class_embedding.to(x.dtype) + torch.zeros(
            batch, 1, x.shape[-1], dtype=x.dtype, device=x.device)
        x = torch.cat([cls_token, x], dim=1)  # shape = [*, grid ** 2 + 1, width]
        x = x + self.positional_embedding.to(x.dtype)

        # a patch_dropout of 0. would mean it is disabled and this function would do nothing but return what was passed in
        x = self.patch_dropout(x)
        x = self.ln_pre(x)

        x = x.permute(1, 0, 2)  # NLD -> LND
        x = self.transformer(x)
        x = x.permute(1, 0, 2)  # LND -> NLD

        if return_all_features:
            return x

        if self.global_average_pool:
            x = x.mean(dim=1)  # x = x[:,1:,:].mean(dim=1)
        else:
            x = x[:, 0]

        x = self.ln_post(x)

        if self.proj is not None:
            x = x @ self.proj

        return x
640
+
641
+
642
class TextTransformer(nn.Module):
    """Text encoder: token and positional embeddings, ``Transformer``, projection.

    With ``attn_mask=True`` a causal additive mask is registered as a
    non-persistent buffer and passed to the transformer on every forward call.
    """

    def __init__(
            self,
            context_length: int = 77,
            vocab_size: int = 49408,
            width: int = 512,
            heads: int = 8,
            layers: int = 12,
            ls_init_value: float = None,
            output_dim: int = 512,
            act_layer: Callable = nn.GELU,
            norm_layer: Callable = LayerNorm,
            xattn: bool= False,
            attn_mask: bool = True
    ):
        super().__init__()
        self.context_length = context_length
        self.vocab_size = vocab_size
        self.width = width
        self.output_dim = output_dim

        self.token_embedding = nn.Embedding(vocab_size, width)
        self.positional_embedding = nn.Parameter(torch.empty(self.context_length, width))
        self.transformer = Transformer(
            width=width,
            layers=layers,
            heads=heads,
            ls_init_value=ls_init_value,
            act_layer=act_layer,
            norm_layer=norm_layer,
            xattn=xattn
        )

        self.xattn = xattn
        self.ln_final = norm_layer(width)
        self.text_projection = nn.Parameter(torch.empty(width, output_dim))

        if attn_mask:
            self.register_buffer('attn_mask', self.build_attention_mask(), persistent=False)
        else:
            self.attn_mask = None

        self.init_parameters()

    def init_parameters(self):
        """Initialise embeddings, attention/MLP weights and the projection with normal noise."""
        nn.init.normal_(self.token_embedding.weight, std=0.02)
        nn.init.normal_(self.positional_embedding, std=0.01)

        width = self.transformer.width
        depth = self.transformer.layers
        attn_std = width ** -0.5
        proj_std = (width ** -0.5) * ((2 * depth) ** -0.5)
        fc_std = (2 * width) ** -0.5
        for block in self.transformer.resblocks:
            nn.init.normal_(block.attn.in_proj_weight, std=attn_std)
            nn.init.normal_(block.attn.out_proj.weight, std=proj_std)
            nn.init.normal_(block.mlp.c_fc.weight, std=fc_std)
            nn.init.normal_(block.mlp.c_proj.weight, std=proj_std)

        if self.text_projection is not None:
            nn.init.normal_(self.text_projection, std=self.transformer.width ** -0.5)

    @torch.jit.ignore
    def set_grad_checkpointing(self, enable=True):
        self.transformer.grad_checkpointing = enable

    @torch.jit.ignore
    def no_weight_decay(self):
        # return {'positional_embedding', 'token_embedding'}
        return {'positional_embedding'}

    def get_num_layers(self):
        return self.transformer.layers

    def build_attention_mask(self):
        """Return a ``(context_length, context_length)`` mask: ``-inf`` above the diagonal, 0 elsewhere."""
        # lazily create causal attention mask, with full attention between the vision tokens
        # pytorch uses additive attention mask; fill with -inf
        size = self.context_length
        mask = torch.empty(size, size)
        mask.fill_(float("-inf"))
        mask.triu_(1)  # zero out the lower diagonal
        return mask

    def forward(self, text, return_all_features: bool=False):
        """Encode token ids ``text``.

        Returns the normalised token features when ``return_all_features`` is
        true; otherwise the feature at each row's ``argmax`` position multiplied
        by ``text_projection``.
        """
        cast_dtype = self.transformer.get_cast_dtype()
        x = self.token_embedding(text).to(cast_dtype)  # [batch_size, n_ctx, d_model]

        x = x + self.positional_embedding.to(cast_dtype)
        x = x.permute(1, 0, 2)  # NLD -> LND
        x = self.transformer(x, attn_mask=self.attn_mask)
        # x = self.transformer(x) # no attention mask is applied
        x = x.permute(1, 0, 2)  # LND -> NLD
        x = self.ln_final(x)

        if return_all_features:
            return x

        # x.shape = [batch_size, n_ctx, transformer.width]
        # take features from the eot embedding (eot_token is the highest number in each sequence)
        rows = torch.arange(x.shape[0])
        eot_positions = text.argmax(dim=-1)
        return x[rows, eot_positions] @ self.text_projection
models/eva_clip/utils.py ADDED
@@ -0,0 +1,326 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from itertools import repeat
2
+ import collections.abc
3
+ import logging
4
+ import math
5
+ import numpy as np
6
+
7
+ import torch
8
+ from torch import nn as nn
9
+ from torchvision.ops.misc import FrozenBatchNorm2d
10
+ import torch.nn.functional as F
11
+
12
+ # open CLIP
13
def resize_clip_pos_embed(state_dict, model, interpolation: str = 'bicubic', seq_dim=1):
    """Resize ``state_dict['visual.positional_embedding']`` to the model's grid, in place.

    Nothing happens when the key is missing, when ``model.visual`` has no
    ``grid_size`` attribute, or when the checkpoint length already equals
    ``grid_h * grid_w + 1``.

    Args:
        state_dict: Checkpoint mapping; the resized tensor is written back into it.
        model: Object whose ``visual.grid_size`` gives the target grid.
        interpolation: ``mode`` passed to ``F.interpolate``.
        seq_dim: Not used by this function.
    """
    key = 'visual.positional_embedding'
    checkpoint_embed = state_dict.get(key, None)
    if checkpoint_embed is None:
        return
    if not hasattr(model.visual, 'grid_size'):
        return

    target_grid = to_2tuple(model.visual.grid_size)
    num_prefix = 1  # FIXME detect different token configs (ie no class token, or more)
    if target_grid[0] * target_grid[1] + num_prefix == checkpoint_embed.shape[0]:
        return

    if num_prefix:
        prefix_embed = checkpoint_embed[:num_prefix]
        grid_embed = checkpoint_embed[num_prefix:]
    else:
        prefix_embed, grid_embed = None, checkpoint_embed
    source_grid = to_2tuple(int(math.sqrt(len(grid_embed))))

    logging.info('Resizing position embedding grid-size from %s to %s', source_grid, target_grid)
    # (tokens, channels) -> (1, channels, H, W) so the grid can be interpolated spatially.
    grid_embed = grid_embed.reshape(1, source_grid[0], source_grid[1], -1).permute(0, 3, 1, 2)
    grid_embed = F.interpolate(grid_embed, size=target_grid, mode=interpolation, align_corners=True)
    # Back to (tokens, channels).
    grid_embed = grid_embed.permute(0, 2, 3, 1).reshape(1, target_grid[0] * target_grid[1], -1)[0]

    if prefix_embed is None:
        state_dict[key] = grid_embed
    else:
        state_dict[key] = torch.cat([prefix_embed, grid_embed], dim=0)
44
+
45
+
46
def resize_visual_pos_embed(state_dict, model, interpolation: str = 'bicubic', seq_dim=1):
    """Resize ``state_dict['positional_embedding']`` to the model's grid, in place.

    Nothing happens when the key is missing, when ``model.visual`` has no
    ``grid_size`` attribute, or when the checkpoint length already equals
    ``grid_h * grid_w + 1``.

    Args:
        state_dict: Checkpoint mapping; the resized tensor is written back into it.
        model: Object whose ``visual.grid_size`` gives the target grid.
        interpolation: ``mode`` passed to ``F.interpolate``.
        seq_dim: Not used by this function.
    """
    key = 'positional_embedding'
    stored = state_dict.get(key, None)
    if stored is None or not hasattr(model.visual, 'grid_size'):
        return

    new_grid = to_2tuple(model.visual.grid_size)
    n_extra = 1  # FIXME detect different token configs (ie no class token, or more)
    wanted_len = new_grid[0] * new_grid[1] + n_extra
    if wanted_len == stored.shape[0]:
        return

    tok_part = stored[:n_extra] if n_extra else None
    img_part = stored[n_extra:] if n_extra else stored
    old_grid = to_2tuple(int(math.sqrt(len(img_part))))

    logging.info('Resizing position embedding grid-size from %s to %s', old_grid, new_grid)
    # Lay the patch embeddings out as a (1, C, H, W) image for interpolation.
    as_image = img_part.reshape(1, old_grid[0], old_grid[1], -1).permute(0, 3, 1, 2)
    as_image = F.interpolate(
        as_image,
        size=new_grid,
        mode=interpolation,
        align_corners=True,
    )
    img_part = as_image.permute(0, 2, 3, 1).reshape(1, new_grid[0] * new_grid[1], -1)[0]

    resized = img_part if tok_part is None else torch.cat([tok_part, img_part], dim=0)
    state_dict[key] = resized
77
+
78
def resize_evaclip_pos_embed(state_dict, model, interpolation: str = 'bicubic', seq_dim=1):
    """Interpolate ``visual.pos_embed`` (and the patch projection) to the model's size, in place.

    When the square grid implied by the checkpoint's position embedding differs
    from the one implied by ``model.visual.patch_embed.num_patches``, the patch
    position tokens are bicubically resized, the leading extra tokens are kept
    unchanged, and ``visual.patch_embed.proj.weight`` is bicubically resized to
    ``model.visual.patch_embed.patch_size`` (after a cast to float).

    Args:
        state_dict: Checkpoint mapping, updated in place. Nothing happens if it
            has no ``'visual.pos_embed'`` entry.
        model: Object exposing ``visual.patch_embed.num_patches``,
            ``visual.patch_embed.patch_size`` and ``visual.pos_embed``.
        interpolation: Accepted for signature parity; resizing is always bicubic.
        seq_dim: Not used by this function.

    Raises:
        KeyError: If a resize is needed but ``'visual.patch_embed.proj.weight'``
            is missing from ``state_dict``.
    """
    pos_key = 'visual.pos_embed'
    proj_key = 'visual.patch_embed.proj.weight'
    if pos_key not in state_dict:
        return

    pos_embed_checkpoint = state_dict[pos_key]
    embedding_size = pos_embed_checkpoint.shape[-1]
    num_patches = model.visual.patch_embed.num_patches
    num_extra_tokens = model.visual.pos_embed.shape[-2] - num_patches
    # height (== width) of the checkpoint grid and of the model grid
    orig_size = int((pos_embed_checkpoint.shape[-2] - num_extra_tokens) ** 0.5)
    new_size = int(num_patches ** 0.5)
    if orig_size == new_size:
        return

    print("Position interpolate from %dx%d to %dx%d" % (orig_size, orig_size, new_size, new_size))
    extra_tokens = pos_embed_checkpoint[:, :num_extra_tokens]
    # only the position tokens are interpolated
    pos_tokens = pos_embed_checkpoint[:, num_extra_tokens:]
    pos_tokens = pos_tokens.reshape(-1, orig_size, orig_size, embedding_size).permute(0, 3, 1, 2)
    pos_tokens = torch.nn.functional.interpolate(
        pos_tokens, size=(new_size, new_size), mode='bicubic', align_corners=False)
    pos_tokens = pos_tokens.permute(0, 2, 3, 1).flatten(1, 2)
    state_dict[pos_key] = torch.cat((extra_tokens, pos_tokens), dim=1)

    patch_embed_proj = state_dict[proj_key]
    patch_size = model.visual.patch_embed.patch_size
    state_dict[proj_key] = torch.nn.functional.interpolate(
        patch_embed_proj.float(), size=patch_size, mode='bicubic', align_corners=False)
107
+
108
+
109
def resize_eva_pos_embed(state_dict, model, interpolation: str = 'bicubic', seq_dim=1):
    """Interpolate ``pos_embed`` (and the patch projection) to the model's size, in place.

    Same procedure as for the ``visual.``-prefixed checkpoint layout, but the
    keys read and written are ``'pos_embed'`` and ``'patch_embed.proj.weight'``.
    When the checkpoint grid and the model grid differ, the patch position
    tokens are bicubically resized, the leading extra tokens are kept
    unchanged, and the patch projection weight is bicubically resized to
    ``model.visual.patch_embed.patch_size`` (after a cast to float).

    Args:
        state_dict: Checkpoint mapping, updated in place. Nothing happens if it
            has no ``'pos_embed'`` entry.
        model: Object exposing ``visual.patch_embed.num_patches``,
            ``visual.patch_embed.patch_size`` and ``visual.pos_embed``.
        interpolation: Accepted for signature parity; resizing is always bicubic.
        seq_dim: Not used by this function.

    Raises:
        KeyError: If a resize is needed but ``'patch_embed.proj.weight'`` is
            missing from ``state_dict``.
    """
    if 'pos_embed' not in state_dict:
        return

    pos_embed_checkpoint = state_dict['pos_embed']
    embedding_size = pos_embed_checkpoint.shape[-1]
    num_patches = model.visual.patch_embed.num_patches
    num_extra_tokens = model.visual.pos_embed.shape[-2] - num_patches
    # height (== width) of the checkpoint grid and of the model grid
    orig_size = int((pos_embed_checkpoint.shape[-2] - num_extra_tokens) ** 0.5)
    new_size = int(num_patches ** 0.5)
    if orig_size == new_size:
        return

    print("Position interpolate from %dx%d to %dx%d" % (orig_size, orig_size, new_size, new_size))
    extra_tokens = pos_embed_checkpoint[:, :num_extra_tokens]
    # only the position tokens are interpolated
    pos_tokens = pos_embed_checkpoint[:, num_extra_tokens:]
    pos_tokens = pos_tokens.reshape(-1, orig_size, orig_size, embedding_size).permute(0, 3, 1, 2)
    pos_tokens = torch.nn.functional.interpolate(
        pos_tokens, size=(new_size, new_size), mode='bicubic', align_corners=False)
    pos_tokens = pos_tokens.permute(0, 2, 3, 1).flatten(1, 2)
    state_dict['pos_embed'] = torch.cat((extra_tokens, pos_tokens), dim=1)

    patch_embed_proj = state_dict['patch_embed.proj.weight']
    patch_size = model.visual.patch_embed.patch_size
    state_dict['patch_embed.proj.weight'] = torch.nn.functional.interpolate(
        patch_embed_proj.float(), size=patch_size, mode='bicubic', align_corners=False)
138
+
139
+
140
def _rel_pos_source_coords(src_size, dst_size):
    """Return the geometrically spaced source coordinates for bias-table resizing."""
    half_src = src_size // 2
    half_dst = dst_size // 2

    # Bisect for the ratio q whose geometric series over half_src terms
    # reaches half_dst.
    left, right = 1.01, 1.5
    while right - left > 1e-6:
        q = (left + right) / 2.0
        gp = (1.0 - q ** half_src) / (1.0 - q)
        if gp > half_dst:
            right = q
        else:
            left = q

    dis = []
    cur = 1
    for i in range(half_src):
        dis.append(cur)
        cur += q ** (i + 1)
    return [-d for d in reversed(dis)] + [0] + dis


def _resize_rel_pos_bias_table(state_dict, model, key):
    """Resize one ``relative_position_bias_table`` entry of ``state_dict`` in place."""
    # SciPy is only needed when a table really has to be resized.
    from scipy.interpolate import RectBivariateSpline

    rel_pos_bias = state_dict[key]
    src_num_pos, num_attn_heads = rel_pos_bias.size()
    dst_num_pos, _ = model.visual.state_dict()[key].size()
    dst_patch_shape = model.visual.patch_embed.patch_shape
    if dst_patch_shape[0] != dst_patch_shape[1]:
        raise NotImplementedError()
    num_extra_tokens = dst_num_pos - (dst_patch_shape[0] * 2 - 1) * (dst_patch_shape[1] * 2 - 1)
    src_size = int((src_num_pos - num_extra_tokens) ** 0.5)
    dst_size = int((dst_num_pos - num_extra_tokens) ** 0.5)
    if src_size == dst_size:
        return

    print("Position interpolate for %s from %dx%d to %dx%d" % (
        key, src_size, src_size, dst_size, dst_size))
    # Positive indexing: a negative slice would select the whole table when
    # there are no extra tokens.
    num_grid_pos = src_num_pos - num_extra_tokens
    extra_tokens = rel_pos_bias[num_grid_pos:, :]
    rel_pos_bias = rel_pos_bias[:num_grid_pos, :]

    # The source grid is square, so x and y share the same coordinates.
    x = _rel_pos_source_coords(src_size, dst_size)
    t = dst_size // 2.0
    dx = np.arange(-t, t + 0.1, 1.0)

    print("Original positions = %s" % str(x))
    print("Target positions = %s" % str(dx))

    all_rel_pos_bias = []
    for i in range(num_attn_heads):
        z = rel_pos_bias[:, i].reshape(src_size, src_size).float().detach().cpu().numpy()
        # Bicubic spline through the source samples, evaluated on the target grid.
        spline = RectBivariateSpline(x, x, z, kx=3, ky=3)
        resized = torch.tensor(spline(dx, dx), dtype=torch.float32)
        all_rel_pos_bias.append(resized.contiguous().view(-1, 1).to(rel_pos_bias.device))

    rel_pos_bias = torch.cat(all_rel_pos_bias, dim=-1)
    state_dict[key] = torch.cat((rel_pos_bias, extra_tokens), dim=0)


def resize_rel_pos_embed(state_dict, model, interpolation: str = 'bicubic', seq_dim=1):
    """Adapt relative-position and absolute-position entries of a checkpoint, in place.

    Three things happen:

    * every key containing ``"relative_position_index"`` is removed;
    * every ``"relative_position_bias_table"`` whose grid size differs from
      the matching entry of ``model.visual.state_dict()`` is resized with a
      bicubic spline over geometrically spaced source coordinates; the
      trailing extra-token rows are kept unchanged;
    * ``'pos_embed'`` and ``'patch_embed.proj.weight'`` are bicubically
      resized when the checkpoint grid differs from the model grid.

    Args:
        state_dict: Checkpoint mapping, updated in place.
        model: Object exposing ``visual.state_dict()``,
            ``visual.patch_embed`` (``patch_shape``, ``num_patches``,
            ``patch_size``) and ``visual.pos_embed``.
        interpolation: Accepted for signature parity; not used.
        seq_dim: Not used by this function.

    Raises:
        NotImplementedError: If the model's patch grid is not square.
    """
    for key in list(state_dict.keys()):
        if "relative_position_index" in key:
            state_dict.pop(key)
            continue
        if "relative_position_bias_table" in key:
            _resize_rel_pos_bias_table(state_dict, model, key)

    # interpolate position embedding
    if 'pos_embed' not in state_dict:
        return
    pos_embed_checkpoint = state_dict['pos_embed']
    embedding_size = pos_embed_checkpoint.shape[-1]
    num_patches = model.visual.patch_embed.num_patches
    num_extra_tokens = model.visual.pos_embed.shape[-2] - num_patches
    # height (== width) of the checkpoint grid and of the model grid
    orig_size = int((pos_embed_checkpoint.shape[-2] - num_extra_tokens) ** 0.5)
    new_size = int(num_patches ** 0.5)
    if orig_size == new_size:
        return

    print("Position interpolate from %dx%d to %dx%d" % (orig_size, orig_size, new_size, new_size))
    extra_tokens = pos_embed_checkpoint[:, :num_extra_tokens]
    # only the position tokens are interpolated
    pos_tokens = pos_embed_checkpoint[:, num_extra_tokens:]
    pos_tokens = pos_tokens.reshape(-1, orig_size, orig_size, embedding_size).permute(0, 3, 1, 2)
    pos_tokens = torch.nn.functional.interpolate(
        pos_tokens, size=(new_size, new_size), mode='bicubic', align_corners=False)
    pos_tokens = pos_tokens.permute(0, 2, 3, 1).flatten(1, 2)
    state_dict['pos_embed'] = torch.cat((extra_tokens, pos_tokens), dim=1)

    patch_embed_proj = state_dict['patch_embed.proj.weight']
    patch_size = model.visual.patch_embed.patch_size
    state_dict['patch_embed.proj.weight'] = torch.nn.functional.interpolate(
        patch_embed_proj.float(), size=patch_size, mode='bicubic', align_corners=False)
235
+
236
+
237
def freeze_batch_norm_2d(module, module_match=None, name=''):
    """
    Converts all `BatchNorm2d` and `SyncBatchNorm` layers of provided module into `FrozenBatchNorm2d`. If `module` is
    itself an instance of either `BatchNorm2d` or `SyncBatchNorm`, it is converted into `FrozenBatchNorm2d` and
    returned. Otherwise, the module is walked recursively and submodules are converted in place.

    Args:
        module (torch.nn.Module): Any PyTorch module.
        module_match (dict | None): Full module names to freeze (all if empty or None)
        name (str): Full module name (prefix)

    Returns:
        torch.nn.Module: Resulting module

    Inspired by https://github.com/pytorch/pytorch/blob/a5895f85be0f10212791145bfedc0261d364f103/torch/nn/modules/batchnorm.py#L762
    """
    # ``None`` replaces the former mutable ``{}`` default; both are falsy and
    # therefore mean "freeze every batch-norm layer".
    res = module
    is_match = True
    if module_match:
        is_match = name in module_match
    if is_match and isinstance(module, (nn.modules.batchnorm.BatchNorm2d, nn.modules.batchnorm.SyncBatchNorm)):
        res = FrozenBatchNorm2d(module.num_features)
        res.num_features = module.num_features
        res.affine = module.affine
        if module.affine:
            res.weight.data = module.weight.data.clone().detach()
            res.bias.data = module.bias.data.clone().detach()
        res.running_mean.data = module.running_mean.data
        res.running_var.data = module.running_var.data
        res.eps = module.eps
    else:
        for child_name, child in module.named_children():
            full_child_name = '.'.join([name, child_name]) if name else child_name
            new_child = freeze_batch_norm_2d(child, module_match, full_child_name)
            # Only re-register children that were actually replaced.
            if new_child is not child:
                res.add_module(child_name, new_child)
    return res
274
+
275
+
276
+ # From PyTorch internals
277
+ def _ntuple(n):
278
+ def parse(x):
279
+ if isinstance(x, collections.abc.Iterable):
280
+ return x
281
+ return tuple(repeat(x, n))
282
+ return parse
283
+
284
+
285
+ to_1tuple = _ntuple(1)
286
+ to_2tuple = _ntuple(2)
287
+ to_3tuple = _ntuple(3)
288
+ to_4tuple = _ntuple(4)
289
+ to_ntuple = lambda n, x: _ntuple(n)(x)
290
+
291
+
292
def is_logging(args):
    """Return a predicate ``is_master(args, local=False)`` over rank attributes.

    The returned callable reports ``args.local_rank == 0`` when ``local`` is
    true and ``args.rank == 0`` otherwise.

    NOTE(review): the ``args`` given to ``is_logging`` itself is never
    inspected, and a function object (always truthy) is returned rather than
    a bool — confirm that callers invoke the result.
    """

    def is_master(args, local=False):
        if local:
            return args.local_rank == 0
        return args.rank == 0

    return is_master
302
+
303
+
304
class AllGather(torch.autograd.Function):
    """Autograd-aware wrapper around ``torch.distributed.all_gather``.

    ``torch.distributed.all_gather`` itself has no gradient; ``backward`` here
    returns the slice of the incoming gradient that belongs to this rank.
    """

    @staticmethod
    def forward(ctx, tensor, rank, world_size):
        """Gather ``tensor`` from ``world_size`` ranks and concatenate along dim 0."""
        gathered = []
        for _ in range(world_size):
            gathered.append(torch.empty_like(tensor))
        torch.distributed.all_gather(gathered, tensor)
        # Remembered so backward can cut out this rank's rows.
        ctx.rank = rank
        ctx.batch_size = tensor.shape[0]
        return torch.cat(gathered, 0)

    @staticmethod
    def backward(ctx, grad_output):
        """Return this rank's gradient rows; ``rank``/``world_size`` get no gradient."""
        start = ctx.batch_size * ctx.rank
        stop = start + ctx.batch_size
        return grad_output[start:stop], None, None


allgather = AllGather.apply
models/eva_clip/utils_qformer.py ADDED
@@ -0,0 +1,166 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import importlib
2
+ import math
3
+ import os
4
+ import random
5
+
6
+ import cv2
7
+ import numpy as np
8
+ import torch
9
+ import torch.nn.functional as F
10
+ from torchvision.utils import make_grid
11
+ from transformers import PretrainedConfig
12
+
13
+
14
def seed_everything(seed):
    """Seed the Python, NumPy and PyTorch (CPU and all CUDA devices) RNGs.

    Also exports the seed as the ``PL_GLOBAL_SEED`` environment variable.
    """
    os.environ["PL_GLOBAL_SEED"] = str(seed)
    seeders = (random.seed, np.random.seed, torch.manual_seed, torch.cuda.manual_seed_all)
    for apply_seed in seeders:
        apply_seed(seed)
20
+
21
+
22
def is_torch2_available():
    """Report whether ``torch.nn.functional`` exposes ``scaled_dot_product_attention``."""
    attribute = "scaled_dot_product_attention"
    return hasattr(F, attribute)
24
+
25
+
26
def instantiate_from_config(config):
    """Build the object described by ``config``.

    Args:
        config: Mapping with a dotted import path under ``"target"`` and
            optional keyword arguments under ``"params"``.

    Returns:
        The instantiated object, or ``None`` when ``config`` is one of the
        sentinels ``'__is_first_stage__'`` / ``"__is_unconditional__"``.

    Raises:
        KeyError: If ``config`` has no ``"target"`` and is not a sentinel.
    """
    if "target" in config:
        factory = get_obj_from_str(config["target"])
        kwargs = config.get("params", {})
        return factory(**kwargs)

    if config == '__is_first_stage__' or config == "__is_unconditional__":
        return None
    raise KeyError("Expected key `target` to instantiate.")
32
+
33
+
34
def get_obj_from_str(string, reload=False):
    """Resolve a dotted path such as ``"package.module.Name"`` to the named attribute.

    Args:
        string: Dotted path; everything before the last dot is imported as a
            module and the remainder is looked up on it.
        reload: When true, the module is reloaded before the lookup.
    """
    module_name, attr_name = string.rsplit(".", 1)
    if reload:
        importlib.reload(importlib.import_module(module_name))
    owner = importlib.import_module(module_name, package=None)
    return getattr(owner, attr_name)
40
+
41
+
42
def drop_seq_token(seq, drop_rate=0.5):
    """Keep a random subset of positions along dimension 1 of ``seq``.

    Args:
        seq: Tensor indexed as ``seq[:, positions]``.
        drop_rate: Fraction of positions to discard;
            ``int(n * (1 - drop_rate))`` positions are kept, in shuffled order.
    """
    order = torch.randperm(seq.size(1))
    keep_count = int(len(order) * (1 - drop_rate))
    return seq[:, order[:keep_count]]
48
+
49
+
50
def import_model_class_from_model_name_or_path(
    pretrained_model_name_or_path: str, revision: str, subfolder: str = "text_encoder"
):
    """Return the ``transformers`` text-encoder class named by a pretrained config.

    The config is loaded with ``PretrainedConfig.from_pretrained`` and the
    first entry of its ``architectures`` list selects the class.

    Raises:
        ValueError: If the architecture is neither ``CLIPTextModel`` nor
            ``CLIPTextModelWithProjection``.
    """
    config = PretrainedConfig.from_pretrained(
        pretrained_model_name_or_path, subfolder=subfolder, revision=revision
    )
    model_class = config.architectures[0]

    # The model classes are imported lazily, only for the branch that needs them.
    if model_class == "CLIPTextModel":
        from transformers import CLIPTextModel

        return CLIPTextModel
    if model_class == "CLIPTextModelWithProjection":
        from transformers import CLIPTextModelWithProjection

        return CLIPTextModelWithProjection
    raise ValueError(f"{model_class} is not supported.")
68
+
69
+
70
def resize_numpy_image_long(image, resize_long_edge=768):
    """Shrink ``image`` so that its longer side is at most ``resize_long_edge``.

    Images that already fit are returned unchanged (same object); larger ones
    are resized with Lanczos interpolation, keeping the aspect ratio up to
    integer truncation.
    """
    height, width = image.shape[:2]
    long_edge = max(height, width)
    if long_edge <= resize_long_edge:
        return image

    ratio = resize_long_edge / long_edge
    # cv2.resize takes the target size as (width, height).
    target = (int(width * ratio), int(height * ratio))
    return cv2.resize(image, target, interpolation=cv2.INTER_LANCZOS4)
79
+
80
+
81
+ # from basicsr
82
def img2tensor(imgs, bgr2rgb=True, float32=True):
    """Convert HWC numpy image(s) to CHW tensor(s).

    Args:
        imgs (list[ndarray] | ndarray): Input images.
        bgr2rgb (bool): Whether to convert 3-channel images from BGR to RGB.
        float32 (bool): Whether to cast the resulting tensor to float32.

    Returns:
        list[tensor] | tensor: A list of tensors for list input, otherwise a
        single tensor.
    """

    def _convert(img):
        if img.shape[2] == 3 and bgr2rgb:
            # float64 images are cast to float32 before the colour conversion.
            if img.dtype == 'float64':
                img = img.astype('float32')
            img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)
        chw = torch.from_numpy(img.transpose(2, 0, 1))
        return chw.float() if float32 else chw

    if not isinstance(imgs, list):
        return _convert(imgs)
    return [_convert(img) for img in imgs]
108
+
109
+
110
def tensor2img(tensor, rgb2bgr=True, out_type=np.uint8, min_max=(0, 1)):
    """Convert torch Tensors into image numpy arrays.

    Values are clamped to [min, max] and then normalized to [0, 1]. The input
    tensors are not modified.

    Args:
        tensor (Tensor or list[Tensor]): Accept shapes:
            1) 4D mini-batch Tensor of shape (B x 3/1 x H x W);
            2) 3D Tensor of shape (3/1 x H x W);
            3) 2D Tensor of shape (H x W).
            Tensor channel should be in RGB order.
        rgb2bgr (bool): Whether to change rgb to bgr.
        out_type (numpy type): output types. If ``np.uint8``, transform outputs
            to uint8 type with range [0, 255]; otherwise, float type with
            range [0, 1]. Default: ``np.uint8``.
        min_max (tuple[int]): min and max values for clamp.

    Returns:
        (ndarray or list[ndarray]): 3D ndarray of shape (H x W x C) OR 2D
        ndarray of shape (H x W). The channel order is BGR when ``rgb2bgr`` is
        true. A single array is returned when there is exactly one result.

    Raises:
        TypeError: If the input is not a tensor or list of tensors, or if a
            tensor is not 2D, 3D or 4D after removing a leading size-1 dim.
    """
    if not (torch.is_tensor(tensor) or (isinstance(tensor, list) and all(torch.is_tensor(t) for t in tensor))):
        raise TypeError(f'tensor or list of tensors expected, got {type(tensor)}')

    if torch.is_tensor(tensor):
        tensor = [tensor]
    result = []
    for _tensor in tensor:
        # Out-of-place clamp: for a float32 CPU input the preceding calls return
        # a tensor sharing the caller's storage, so ``clamp_`` would overwrite
        # the caller's data.
        _tensor = _tensor.squeeze(0).float().detach().cpu().clamp(*min_max)
        _tensor = (_tensor - min_max[0]) / (min_max[1] - min_max[0])

        n_dim = _tensor.dim()
        if n_dim == 4:
            # Tile the batch into a roughly square grid image.
            img_np = make_grid(_tensor, nrow=int(math.sqrt(_tensor.size(0))), normalize=False).numpy()
            img_np = img_np.transpose(1, 2, 0)
            if rgb2bgr:
                img_np = cv2.cvtColor(img_np, cv2.COLOR_RGB2BGR)
        elif n_dim == 3:
            img_np = _tensor.numpy()
            img_np = img_np.transpose(1, 2, 0)
            if img_np.shape[2] == 1:  # gray image
                img_np = np.squeeze(img_np, axis=2)
            else:
                if rgb2bgr:
                    img_np = cv2.cvtColor(img_np, cv2.COLOR_RGB2BGR)
        elif n_dim == 2:
            img_np = _tensor.numpy()
        else:
            raise TypeError(f'Only support 4D, 3D or 2D tensor. But received with dimension: {n_dim}')
        if out_type == np.uint8:
            # Unlike MATLAB, numpy.uint8() WILL NOT round by default.
            img_np = (img_np * 255.0).round()
        img_np = img_np.astype(out_type)
        result.append(img_np)
    if len(result) == 1:
        result = result[0]
    return result
models/local_facial_extractor.py ADDED
@@ -0,0 +1,269 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import math
2
+ import torch
3
+ import torch.nn as nn
4
+
5
+
6
+ # FFN
7
def FeedForward(dim, mult=4):
    """Build a LayerNorm -> Linear -> GELU -> Linear block without biases.

    Args:
        dim: Input and output feature size.
        mult: The hidden size is ``int(dim * mult)``.

    Returns:
        An ``nn.Sequential`` holding the four layers in that order.
    """
    hidden_dim = int(dim * mult)
    stages = [
        nn.LayerNorm(dim),
        nn.Linear(dim, hidden_dim, bias=False),
        nn.GELU(),
        nn.Linear(hidden_dim, dim, bias=False),
    ]
    return nn.Sequential(*stages)
15
+
16
+
17
def reshape_tensor(x, heads):
    """Split the last dimension of ``x`` into attention heads.

    Args:
        x: Tensor of shape ``(bs, length, width)``.
        heads: Number of heads; ``width`` must be divisible by it.

    Returns:
        Tensor of shape ``(bs, heads, length, width // heads)``.
    """
    bs, length, _ = x.shape
    # (bs, length, width) -> (bs, length, heads, dim_per_head) -> (bs, heads, length, dim_per_head)
    per_head = x.view(bs, length, heads, -1).transpose(1, 2)
    # The final reshape keeps the (bs, heads, length, dim_per_head) layout.
    return per_head.reshape(bs, heads, length, -1)
26
+
27
+
28
class PerceiverAttention(nn.Module):
    """Attention in which latent queries attend to image features plus the latents themselves."""

    def __init__(self, *, dim, dim_head=64, heads=8, kv_dim=None):
        """
        Args:
            dim: Feature size of the latents and of the output.
            dim_head: Feature size of each attention head.
            heads: Number of attention heads.
            kv_dim: Feature size used for ``norm1`` and the key/value
                projection; falls back to ``dim`` when ``None``.
        """
        super().__init__()
        kv_features = dim if kv_dim is None else kv_dim
        inner_dim = dim_head * heads

        self.scale = dim_head ** -0.5
        self.dim_head = dim_head
        self.heads = heads

        self.norm1 = nn.LayerNorm(kv_features)
        self.norm2 = nn.LayerNorm(dim)

        self.to_q = nn.Linear(dim, inner_dim, bias=False)
        self.to_kv = nn.Linear(kv_features, inner_dim * 2, bias=False)
        self.to_out = nn.Linear(inner_dim, dim, bias=False)

    def forward(self, x, latents):
        """
        Args:
            x (torch.Tensor): image features
                shape (b, n1, D)
            latents (torch.Tensor): latent features
                shape (b, n2, D)

        Returns:
            torch.Tensor: the ``to_out`` projection of the attended latents,
            with ``n2`` positions per batch element.
        """
        x = self.norm1(x)
        latents = self.norm2(latents)
        batch, num_latents, _ = latents.shape

        # Queries come from the latents; keys/values from features and latents together.
        queries = self.to_q(latents)
        keys, values = self.to_kv(torch.cat((x, latents), dim=-2)).chunk(2, dim=-1)
        queries, keys, values = (reshape_tensor(t, self.heads) for t in (queries, keys, values))

        # Scaling q and k separately is more stable with f16 than dividing afterwards.
        scale = 1 / math.sqrt(math.sqrt(self.dim_head))
        logits = (queries * scale) @ (keys * scale).transpose(-2, -1)
        attn = torch.softmax(logits.float(), dim=-1).type(logits.dtype)

        # (b, heads, n2, dim_head) -> (b, n2, heads * dim_head)
        merged = (attn @ values).permute(0, 2, 1, 3).reshape(batch, num_latents, -1)
        return self.to_out(merged)
73
+
74
+
75
class LocalFacialExtractor(nn.Module):
    """Fuses an identity embedding with five visual feature maps into latent query tokens."""

    def __init__(
        self,
        dim=1024,
        depth=10,
        dim_head=64,
        heads=16,
        num_id_token=5,
        num_queries=32,
        output_dim=2048,
        ff_mult=4,
    ):
        """
        Initializes the LocalFacialExtractor class.

        Parameters:
        - dim (int): The dimensionality of latent features.
        - depth (int): Total number of PerceiverAttention and FeedForward layer pairs;
          must be a multiple of 5 (one equal share per visual feature).
        - dim_head (int): Dimensionality of each attention head.
        - heads (int): Number of attention heads.
        - num_id_token (int): Number of tokens used for identity features.
        - num_queries (int): Number of query tokens for the latent representation.
        - output_dim (int): Output dimension after projection.
        - ff_mult (int): Multiplier for the feed-forward network hidden dimension.
        """
        super().__init__()

        def build_mapper(in_features, out_features):
            # Three-layer MLP with a fixed hidden width of 1024.
            return nn.Sequential(
                nn.Linear(in_features, 1024),
                nn.LayerNorm(1024),
                nn.LeakyReLU(),
                nn.Linear(1024, 1024),
                nn.LayerNorm(1024),
                nn.LeakyReLU(),
                nn.Linear(1024, out_features),
            )

        self.num_id_token = num_id_token
        self.dim = dim
        self.num_queries = num_queries
        assert depth % 5 == 0
        # Number of layer pairs applied per visual feature.
        self.depth = depth // 5
        scale = dim ** -0.5

        # Learnable latent queries and the final output projection.
        self.latents = nn.Parameter(torch.randn(1, num_queries, dim) * scale)
        self.proj_out = nn.Parameter(scale * torch.randn(dim, output_dim))

        # Stack of (attention, feed-forward) pairs.
        self.layers = nn.ModuleList(
            nn.ModuleList(
                [
                    PerceiverAttention(dim=dim, dim_head=dim_head, heads=heads),
                    FeedForward(dim=dim, mult=ff_mult),
                ]
            )
            for _ in range(depth)
        )

        # One mapper per visual feature, registered as mapping_0 .. mapping_4.
        for index in range(5):
            setattr(self, f'mapping_{index}', build_mapper(1024, dim))

        # Maps the identity embedding to ``num_id_token`` tokens of size ``dim``.
        self.id_embedding_mapping = build_mapper(1280, dim * num_id_token)

    def forward(self, x, y):
        """
        Forward pass for LocalFacialExtractor.

        Parameters:
        - x (Tensor): The input identity embedding tensor of shape (batch_size, 1280).
        - y (list of Tensor): Five visual feature tensors with 1024 features each.
          Each is concatenated with the identity tokens along dim 1, so presumably
          shaped (batch_size, n_tokens, 1024) — TODO confirm against the caller.

        Returns:
        - Tensor: The extracted latent features of shape (batch_size, num_queries, output_dim).
        """
        # One copy of the latent queries per batch element.
        latents = self.latents.repeat(x.size(0), 1, 1)

        # Identity embedding -> (batch, num_id_token, dim) tokens, appended to the queries.
        id_tokens = self.id_embedding_mapping(x).reshape(-1, self.num_id_token, self.dim)
        latents = torch.cat((latents, id_tokens), dim=1)

        for stage in range(5):
            mapper = getattr(self, f'mapping_{stage}')
            context = torch.cat((id_tokens, mapper(y[stage])), dim=1)

            # Each visual feature gets its own slice of the layer stack.
            first = stage * self.depth
            for attn, ff in self.layers[first: first + self.depth]:
                latents = attn(context, latents) + latents
                latents = ff(latents) + latents

        # Keep only the query latents and project them to the output size.
        return latents[:, :self.num_queries] @ self.proj_out
191
+
192
+
193
class PerceiverCrossAttention(nn.Module):
    """
    Cross-attention in which latent queries attend to a separate key/value input.

    Args:
        dim (int): Dimension of the input latent and output. Default is 3072.
        dim_head (int): Dimension of each attention head. Default is 128.
        heads (int): Number of attention heads. Default is 16.
        kv_dim (int): Dimension of the key/value input, allowing flexible cross-attention. Default is 2048.

    Attributes:
        scale (float): ``dim_head ** -0.5``; stored but not read in ``forward``.
        norm1 (nn.LayerNorm): Layer normalization applied to the input image features.
        norm2 (nn.LayerNorm): Layer normalization applied to the latent features.
        to_q (nn.Linear): Linear layer for projecting the latent features into queries.
        to_kv (nn.Linear): Linear layer for projecting the input features into keys and values.
        to_out (nn.Linear): Linear layer for outputting the final result after attention.
    """

    def __init__(self, *, dim=3072, dim_head=128, heads=16, kv_dim=2048):
        super().__init__()
        kv_features = dim if kv_dim is None else kv_dim
        inner_dim = dim_head * heads

        self.scale = dim_head ** -0.5
        self.dim_head = dim_head
        self.heads = heads

        # Separate normalisation for the key/value input and for the latents.
        self.norm1 = nn.LayerNorm(kv_features)
        self.norm2 = nn.LayerNorm(dim)

        # Bias-free projections for queries, fused keys/values, and the output.
        self.to_q = nn.Linear(dim, inner_dim, bias=False)
        self.to_kv = nn.Linear(kv_features, inner_dim * 2, bias=False)
        self.to_out = nn.Linear(inner_dim, dim, bias=False)

    def forward(self, x, latents):
        """
        Args:
            x (torch.Tensor): Input image features with shape (batch_size, n1, kv_dim).
            latents (torch.Tensor): Latent features with shape (batch_size, n2, dim).

        Returns:
            torch.Tensor: Attention-modulated features with shape (batch_size, n2, dim).
        """
        x = self.norm1(x)
        latents = self.norm2(latents)
        batch, num_latents, _ = latents.shape

        # Queries from the latents; keys and values from the image features only.
        queries = self.to_q(latents)
        keys, values = self.to_kv(x).chunk(2, dim=-1)
        queries, keys, values = (reshape_tensor(t, self.heads) for t in (queries, keys, values))

        # Scaling q and k separately is more stable than dividing the product afterwards.
        scale = 1 / math.sqrt(math.sqrt(self.dim_head))
        logits = (queries * scale) @ (keys * scale).transpose(-2, -1)
        attn = torch.softmax(logits.float(), dim=-1).type(logits.dtype)

        # (b, heads, n2, dim_head) -> (b, n2, heads * dim_head)
        merged = (attn @ values).permute(0, 2, 1, 3).reshape(batch, num_latents, -1)
        return self.to_out(merged)
models/pipeline_cogvideox.py ADDED
@@ -0,0 +1,748 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright 2024 The CogVideoX team, Tsinghua University & ZhipuAI and The HuggingFace Team.
2
+ # All rights reserved.
3
+ #
4
+ # Licensed under the Apache License, Version 2.0 (the "License");
5
+ # you may not use this file except in compliance with the License.
6
+ # You may obtain a copy of the License at
7
+ #
8
+ # http://www.apache.org/licenses/LICENSE-2.0
9
+ #
10
+ # Unless required by applicable law or agreed to in writing, software
11
+ # distributed under the License is distributed on an "AS IS" BASIS,
12
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
+ # See the License for the specific language governing permissions and
14
+ # limitations under the License.
15
+
16
+ import inspect
17
+ import math
18
+ from typing import Any, Callable, Dict, List, Optional, Tuple, Union
19
+
20
+ import torch
21
+ from transformers import T5EncoderModel, T5Tokenizer
22
+
23
+ from diffusers.callbacks import MultiPipelineCallbacks, PipelineCallback
24
+ from diffusers.loaders import CogVideoXLoraLoaderMixin
25
+ from diffusers.models import AutoencoderKLCogVideoX, CogVideoXTransformer3DModel
26
+ from diffusers.models.embeddings import get_3d_rotary_pos_embed
27
+ from diffusers.pipelines.pipeline_utils import DiffusionPipeline
28
+ from diffusers.schedulers import CogVideoXDDIMScheduler, CogVideoXDPMScheduler
29
+ from diffusers.utils import logging, replace_example_docstring
30
+ from diffusers.utils.torch_utils import randn_tensor
31
+ from diffusers.video_processor import VideoProcessor
32
+ from diffusers.pipelines.cogvideo.pipeline_output import CogVideoXPipelineOutput
33
+
34
+
35
+ logger = logging.get_logger(__name__) # pylint: disable=invalid-name
36
+
37
+
38
+ EXAMPLE_DOC_STRING = """
39
+ Examples:
40
+ ```python
41
+ >>> import torch
42
+ >>> from diffusers import CogVideoXPipeline
43
+ >>> from diffusers.utils import export_to_video
44
+
45
+ >>> # Models: "THUDM/CogVideoX-2b" or "THUDM/CogVideoX-5b"
46
+ >>> pipe = CogVideoXPipeline.from_pretrained("THUDM/CogVideoX-2b", torch_dtype=torch.float16).to("cuda")
47
+ >>> prompt = (
48
+ ... "A panda, dressed in a small, red jacket and a tiny hat, sits on a wooden stool in a serene bamboo forest. "
49
+ ... "The panda's fluffy paws strum a miniature acoustic guitar, producing soft, melodic tunes. Nearby, a few other "
50
+ ... "pandas gather, watching curiously and some clapping in rhythm. Sunlight filters through the tall bamboo, "
51
+ ... "casting a gentle glow on the scene. The panda's face is expressive, showing concentration and joy as it plays. "
52
+ ... "The background includes a small, flowing stream and vibrant green foliage, enhancing the peaceful and magical "
53
+ ... "atmosphere of this unique musical performance."
54
+ ... )
55
+ >>> video = pipe(prompt=prompt, guidance_scale=6, num_inference_steps=50).frames[0]
56
+ >>> export_to_video(video, "output.mp4", fps=8)
57
+ ```
58
+ """
59
+
60
+
61
+ # Similar to diffusers.pipelines.hunyuandit.pipeline_hunyuandit.get_resize_crop_region_for_grid
62
def get_resize_crop_region_for_grid(src, tgt_width, tgt_height):
    """Fit a source grid inside a target grid and centre it.

    The source is rescaled with its aspect ratio preserved (sizes rounded to integers) so that it
    spans the target along one axis, then centred along the other axis.

    Args:
        src: `(height, width)` pair describing the source grid.
        tgt_width: Width of the target grid.
        tgt_height: Height of the target grid.

    Returns:
        `((top, left), (bottom, right))`: corner coordinates of the rescaled region within the target grid.
    """
    src_height, src_width = src

    # A source that is relatively taller than the target is limited by the target height,
    # otherwise by the target width.
    height_limited = (src_height / src_width) > (tgt_height / tgt_width)
    if height_limited:
        fitted_height = tgt_height
        fitted_width = int(round(tgt_height / src_height * src_width))
    else:
        fitted_width = tgt_width
        fitted_height = int(round(tgt_width / src_width * src_height))

    # Centre the fitted region inside the target.
    top = int(round((tgt_height - fitted_height) / 2.0))
    left = int(round((tgt_width - fitted_width) / 2.0))
    bottom = top + fitted_height
    right = left + fitted_width

    return (top, left), (bottom, right)
78
+
79
+
80
+ # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.retrieve_timesteps
81
def retrieve_timesteps(
    scheduler,
    num_inference_steps: Optional[int] = None,
    device: Optional[Union[str, torch.device]] = None,
    timesteps: Optional[List[int]] = None,
    sigmas: Optional[List[float]] = None,
    **kwargs,
):
    """
    Configure `scheduler` through its `set_timesteps` method and return the resulting schedule.

    A custom schedule may be supplied either as `timesteps` or as `sigmas` (never both). Any extra keyword
    arguments are forwarded to `scheduler.set_timesteps`.

    Args:
        scheduler (`SchedulerMixin`):
            The scheduler whose timesteps are configured and read back.
        num_inference_steps (`int`):
            Number of diffusion steps to use when no custom schedule is given. If used, `timesteps` must be
            `None`.
        device (`str` or `torch.device`, *optional*):
            Device passed to `set_timesteps`. If `None`, the timesteps are not moved.
        timesteps (`List[int]`, *optional*):
            Custom timesteps overriding the scheduler's spacing strategy. If passed, `num_inference_steps` and
            `sigmas` must be `None`.
        sigmas (`List[float]`, *optional*):
            Custom sigmas overriding the scheduler's spacing strategy. If passed, `num_inference_steps` and
            `timesteps` must be `None`.

    Returns:
        `Tuple[torch.Tensor, int]`: The scheduler's timestep schedule and the number of inference steps.

    Raises:
        ValueError: If both `timesteps` and `sigmas` are given, or if the scheduler's `set_timesteps` does not
            accept the custom schedule that was supplied.
    """
    if timesteps is not None and sigmas is not None:
        raise ValueError("Only one of `timesteps` or `sigmas` can be passed. Please choose one to set custom values")

    # Default path: let the scheduler derive its own schedule from the step count.
    if timesteps is None and sigmas is None:
        scheduler.set_timesteps(num_inference_steps, device=device, **kwargs)
        return scheduler.timesteps, num_inference_steps

    # Custom path: the scheduler must explicitly support the keyword that was supplied.
    supported_params = set(inspect.signature(scheduler.set_timesteps).parameters.keys())
    if timesteps is not None:
        if "timesteps" not in supported_params:
            raise ValueError(
                f"The current scheduler class {scheduler.__class__}'s `set_timesteps` does not support custom"
                f" timestep schedules. Please check whether you are using the correct scheduler."
            )
        scheduler.set_timesteps(timesteps=timesteps, device=device, **kwargs)
    else:
        if "sigmas" not in supported_params:
            raise ValueError(
                f"The current scheduler class {scheduler.__class__}'s `set_timesteps` does not support custom"
                f" sigmas schedules. Please check whether you are using the correct scheduler."
            )
        scheduler.set_timesteps(sigmas=sigmas, device=device, **kwargs)

    # With a custom schedule the step count is whatever the scheduler ended up with.
    schedule = scheduler.timesteps
    return schedule, len(schedule)
138
+
139
+
140
class CogVideoXPipeline(DiffusionPipeline, CogVideoXLoraLoaderMixin):
    r"""
    Pipeline for text-to-video generation using CogVideoX.

    This model inherits from [`DiffusionPipeline`]. Check the superclass documentation for the generic methods the
    library implements for all the pipelines (such as downloading or saving, running on a particular device, etc.)

    In addition to the usual text-to-video arguments, `__call__` accepts `id_vit_hidden` and `id_cond` and forwards
    them unchanged to `transformer` on every denoising step, so the registered transformer must accept those two
    keyword arguments.

    Args:
        vae ([`AutoencoderKLCogVideoX`]):
            Variational Auto-Encoder (VAE) Model to encode and decode videos to and from latent representations.
        text_encoder ([`T5EncoderModel`]):
            Frozen text-encoder. CogVideoX uses
            [T5](https://huggingface.co/docs/transformers/model_doc/t5#transformers.T5EncoderModel); specifically the
            [t5-v1_1-xxl](https://huggingface.co/PixArt-alpha/PixArt-alpha/tree/main/t5-v1_1-xxl) variant.
        tokenizer (`T5Tokenizer`):
            Tokenizer of class
            [T5Tokenizer](https://huggingface.co/docs/transformers/model_doc/t5#transformers.T5Tokenizer).
        transformer ([`CogVideoXTransformer3DModel`]):
            A text conditioned `CogVideoXTransformer3DModel` to denoise the encoded video latents.
        scheduler ([`SchedulerMixin`]):
            A scheduler to be used in combination with `transformer` to denoise the encoded video latents.
    """

    _optional_components = []
    model_cpu_offload_seq = "text_encoder->transformer->vae"

    # Names of `__call__` locals that a step-end callback is allowed to request and overwrite.
    _callback_tensor_inputs = [
        "latents",
        "prompt_embeds",
        "negative_prompt_embeds",
    ]

    def __init__(
        self,
        tokenizer: T5Tokenizer,
        text_encoder: T5EncoderModel,
        vae: AutoencoderKLCogVideoX,
        transformer: CogVideoXTransformer3DModel,
        scheduler: Union[CogVideoXDDIMScheduler, CogVideoXDPMScheduler],
    ):
        """Register the sub-models and derive the VAE scale factors (with fallbacks when no VAE is present)."""
        super().__init__()

        self.register_modules(
            tokenizer=tokenizer, text_encoder=text_encoder, vae=vae, transformer=transformer, scheduler=scheduler
        )

        # Fall back to fixed values (8 / 4 / 0.7) when the pipeline has no VAE.
        has_vae = hasattr(self, "vae") and self.vae is not None
        self.vae_scale_factor_spatial = 2 ** (len(self.vae.config.block_out_channels) - 1) if has_vae else 8
        self.vae_scale_factor_temporal = self.vae.config.temporal_compression_ratio if has_vae else 4
        self.vae_scaling_factor_image = self.vae.config.scaling_factor if has_vae else 0.7

        self.video_processor = VideoProcessor(vae_scale_factor=self.vae_scale_factor_spatial)

    def _get_t5_prompt_embeds(
        self,
        prompt: Union[str, List[str]] = None,
        num_videos_per_prompt: int = 1,
        max_sequence_length: int = 226,
        device: Optional[torch.device] = None,
        dtype: Optional[torch.dtype] = None,
    ):
        """
        Tokenize `prompt`, run it through the text encoder and repeat the result per requested video.

        Prompts are padded/truncated to `max_sequence_length`; a warning is logged when text was truncated.

        Returns:
            `torch.Tensor`: embeddings viewed as `(batch_size * num_videos_per_prompt, seq_len, -1)`.
        """
        device = device or self._execution_device
        dtype = dtype or self.text_encoder.dtype

        prompt = [prompt] if isinstance(prompt, str) else prompt
        batch_size = len(prompt)

        text_inputs = self.tokenizer(
            prompt,
            padding="max_length",
            max_length=max_sequence_length,
            truncation=True,
            add_special_tokens=True,
            return_tensors="pt",
        )
        text_input_ids = text_inputs.input_ids
        # Tokenize again without truncation purely to detect (and report) what was cut off.
        untruncated_ids = self.tokenizer(prompt, padding="longest", return_tensors="pt").input_ids

        if untruncated_ids.shape[-1] >= text_input_ids.shape[-1] and not torch.equal(text_input_ids, untruncated_ids):
            removed_text = self.tokenizer.batch_decode(untruncated_ids[:, max_sequence_length - 1 : -1])
            logger.warning(
                "The following part of your input was truncated because `max_sequence_length` is set to "
                f" {max_sequence_length} tokens: {removed_text}"
            )

        prompt_embeds = self.text_encoder(text_input_ids.to(device))[0]
        prompt_embeds = prompt_embeds.to(dtype=dtype, device=device)

        # duplicate text embeddings for each generation per prompt, using mps friendly method
        _, seq_len, _ = prompt_embeds.shape
        prompt_embeds = prompt_embeds.repeat(1, num_videos_per_prompt, 1)
        prompt_embeds = prompt_embeds.view(batch_size * num_videos_per_prompt, seq_len, -1)

        return prompt_embeds

    def encode_prompt(
        self,
        prompt: Union[str, List[str]],
        negative_prompt: Optional[Union[str, List[str]]] = None,
        do_classifier_free_guidance: bool = True,
        num_videos_per_prompt: int = 1,
        prompt_embeds: Optional[torch.Tensor] = None,
        negative_prompt_embeds: Optional[torch.Tensor] = None,
        max_sequence_length: int = 226,
        device: Optional[torch.device] = None,
        dtype: Optional[torch.dtype] = None,
    ):
        r"""
        Encodes the prompt into text encoder hidden states.

        Args:
            prompt (`str` or `List[str]`, *optional*):
                prompt to be encoded
            negative_prompt (`str` or `List[str]`, *optional*):
                The prompt or prompts not to guide the image generation. If not defined, one has to pass
                `negative_prompt_embeds` instead. Only encoded when `do_classifier_free_guidance` is `True`; an
                empty string is used when it is not given.
            do_classifier_free_guidance (`bool`, *optional*, defaults to `True`):
                Whether to use classifier free guidance or not.
            num_videos_per_prompt (`int`, *optional*, defaults to 1):
                Number of videos that should be generated per prompt.
            prompt_embeds (`torch.Tensor`, *optional*):
                Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not
                provided, text embeddings will be generated from `prompt` input argument.
            negative_prompt_embeds (`torch.Tensor`, *optional*):
                Pre-generated negative text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt
                weighting. If not provided, negative_prompt_embeds will be generated from `negative_prompt` input
                argument.
            max_sequence_length (`int`, defaults to `226`):
                Maximum sequence length used when tokenizing the prompts.
            device: (`torch.device`, *optional*):
                torch device to place the resulting embeddings on
            dtype: (`torch.dtype`, *optional*):
                torch dtype

        Returns:
            `Tuple`: `(prompt_embeds, negative_prompt_embeds)`. The second element is returned as passed in
            (possibly `None`) when classifier free guidance is disabled.

        Raises:
            TypeError: If `negative_prompt` and `prompt` are of different types.
            ValueError: If `negative_prompt` and `prompt` have different batch sizes.
        """
        device = device or self._execution_device

        prompt = [prompt] if isinstance(prompt, str) else prompt
        if prompt is not None:
            batch_size = len(prompt)
        else:
            batch_size = prompt_embeds.shape[0]

        if prompt_embeds is None:
            prompt_embeds = self._get_t5_prompt_embeds(
                prompt=prompt,
                num_videos_per_prompt=num_videos_per_prompt,
                max_sequence_length=max_sequence_length,
                device=device,
                dtype=dtype,
            )

        if do_classifier_free_guidance and negative_prompt_embeds is None:
            negative_prompt = negative_prompt or ""
            negative_prompt = batch_size * [negative_prompt] if isinstance(negative_prompt, str) else negative_prompt

            if prompt is not None and type(prompt) is not type(negative_prompt):
                raise TypeError(
                    f"`negative_prompt` should be the same type to `prompt`, but got {type(negative_prompt)} !="
                    f" {type(prompt)}."
                )
            elif batch_size != len(negative_prompt):
                raise ValueError(
                    f"`negative_prompt`: {negative_prompt} has batch size {len(negative_prompt)}, but `prompt`:"
                    f" {prompt} has batch size {batch_size}. Please make sure that passed `negative_prompt` matches"
                    " the batch size of `prompt`."
                )

            negative_prompt_embeds = self._get_t5_prompt_embeds(
                prompt=negative_prompt,
                num_videos_per_prompt=num_videos_per_prompt,
                max_sequence_length=max_sequence_length,
                device=device,
                dtype=dtype,
            )

        return prompt_embeds, negative_prompt_embeds

    def prepare_latents(
        self, batch_size, num_channels_latents, num_frames, height, width, dtype, device, generator, latents=None
    ):
        """
        Create (or move to `device`) the initial latents and scale them by the scheduler's `init_noise_sigma`.

        Freshly sampled latents have shape `(batch_size, (num_frames - 1) // vae_scale_factor_temporal + 1,
        num_channels_latents, height // vae_scale_factor_spatial, width // vae_scale_factor_spatial)`.

        Raises:
            ValueError: If `generator` is a list whose length differs from `batch_size`.
        """
        if isinstance(generator, list) and len(generator) != batch_size:
            raise ValueError(
                f"You have passed a list of generators of length {len(generator)}, but requested an effective batch"
                f" size of {batch_size}. Make sure the batch size matches the length of the generators."
            )

        shape = (
            batch_size,
            (num_frames - 1) // self.vae_scale_factor_temporal + 1,
            num_channels_latents,
            height // self.vae_scale_factor_spatial,
            width // self.vae_scale_factor_spatial,
        )

        if latents is None:
            latents = randn_tensor(shape, generator=generator, device=device, dtype=dtype)
        else:
            latents = latents.to(device)

        # scale the initial noise by the standard deviation required by the scheduler
        latents = latents * self.scheduler.init_noise_sigma
        return latents

    def decode_latents(self, latents: torch.Tensor) -> torch.Tensor:
        """Undo the latent scaling and decode `latents` with the VAE, returning the decoded frames."""
        latents = latents.permute(0, 2, 1, 3, 4)  # [batch_size, num_channels, num_frames, height, width]
        latents = 1 / self.vae_scaling_factor_image * latents

        frames = self.vae.decode(latents).sample
        return frames

    # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.prepare_extra_step_kwargs
    def prepare_extra_step_kwargs(self, generator, eta):
        """Build the optional kwargs (`eta`, `generator`) that the current scheduler's `step` accepts."""
        # prepare extra kwargs for the scheduler step, since not all schedulers have the same signature
        # eta (η) is only used with the DDIMScheduler, it will be ignored for other schedulers.
        # eta corresponds to η in DDIM paper: https://arxiv.org/abs/2010.02502
        # and should be between [0, 1]
        step_params = set(inspect.signature(self.scheduler.step).parameters.keys())

        extra_step_kwargs = {}
        if "eta" in step_params:
            extra_step_kwargs["eta"] = eta

        # check if the scheduler accepts generator
        if "generator" in step_params:
            extra_step_kwargs["generator"] = generator
        return extra_step_kwargs

    # Copied from diffusers.pipelines.latte.pipeline_latte.LattePipeline.check_inputs
    def check_inputs(
        self,
        prompt,
        height,
        width,
        negative_prompt,
        callback_on_step_end_tensor_inputs,
        prompt_embeds=None,
        negative_prompt_embeds=None,
    ):
        """
        Validate the user-facing arguments of `__call__`.

        Raises:
            ValueError: If `height`/`width` are not divisible by 8, if an unknown callback tensor name is requested,
                if text and pre-computed embeddings are combined inconsistently, or if `prompt_embeds` and
                `negative_prompt_embeds` differ in shape.
        """
        if height % 8 != 0 or width % 8 != 0:
            raise ValueError(f"`height` and `width` have to be divisible by 8 but are {height} and {width}.")

        if callback_on_step_end_tensor_inputs is not None and not all(
            k in self._callback_tensor_inputs for k in callback_on_step_end_tensor_inputs
        ):
            raise ValueError(
                f"`callback_on_step_end_tensor_inputs` has to be in {self._callback_tensor_inputs}, but found {[k for k in callback_on_step_end_tensor_inputs if k not in self._callback_tensor_inputs]}"
            )
        if prompt is not None and prompt_embeds is not None:
            raise ValueError(
                f"Cannot forward both `prompt`: {prompt} and `prompt_embeds`: {prompt_embeds}. Please make sure to"
                " only forward one of the two."
            )
        elif prompt is None and prompt_embeds is None:
            raise ValueError(
                "Provide either `prompt` or `prompt_embeds`. Cannot leave both `prompt` and `prompt_embeds` undefined."
            )
        elif prompt is not None and (not isinstance(prompt, str) and not isinstance(prompt, list)):
            raise ValueError(f"`prompt` has to be of type `str` or `list` but is {type(prompt)}")

        if prompt is not None and negative_prompt_embeds is not None:
            raise ValueError(
                f"Cannot forward both `prompt`: {prompt} and `negative_prompt_embeds`:"
                f" {negative_prompt_embeds}. Please make sure to only forward one of the two."
            )

        if negative_prompt is not None and negative_prompt_embeds is not None:
            raise ValueError(
                f"Cannot forward both `negative_prompt`: {negative_prompt} and `negative_prompt_embeds`:"
                f" {negative_prompt_embeds}. Please make sure to only forward one of the two."
            )

        if prompt_embeds is not None and negative_prompt_embeds is not None:
            if prompt_embeds.shape != negative_prompt_embeds.shape:
                raise ValueError(
                    "`prompt_embeds` and `negative_prompt_embeds` must have the same shape when passed directly, but"
                    f" got: `prompt_embeds` {prompt_embeds.shape} != `negative_prompt_embeds`"
                    f" {negative_prompt_embeds.shape}."
                )

    def fuse_qkv_projections(self) -> None:
        r"""Enables fused QKV projections."""
        self.fusing_transformer = True
        self.transformer.fuse_qkv_projections()

    def unfuse_qkv_projections(self) -> None:
        r"""Disable QKV projection fusion if enabled."""
        # `fusing_transformer` is only created by `fuse_qkv_projections`; treat a missing flag as "not fused"
        # instead of failing with an AttributeError when this is called first.
        if not getattr(self, "fusing_transformer", False):
            logger.warning("The Transformer was not initially fused for QKV projections. Doing nothing.")
        else:
            self.transformer.unfuse_qkv_projections()
            self.fusing_transformer = False

    def _prepare_rotary_positional_embeddings(
        self,
        height: int,
        width: int,
        num_frames: int,
        device: torch.device,
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        """
        Build the 3D rotary positional embedding tables for the given video size.

        The patch grid for `height` x `width` is fitted against a 720 x 480 base grid to obtain crop coordinates,
        which are then passed to `get_3d_rotary_pos_embed` together with `num_frames` as the temporal size.

        Returns:
            `Tuple[torch.Tensor, torch.Tensor]`: `(freqs_cos, freqs_sin)` moved to `device`.
        """
        # Number of pixels covered by one transformer patch along a spatial axis.
        pixels_per_patch = self.vae_scale_factor_spatial * self.transformer.config.patch_size

        grid_height = height // pixels_per_patch
        grid_width = width // pixels_per_patch
        base_size_width = 720 // pixels_per_patch
        base_size_height = 480 // pixels_per_patch

        grid_crops_coords = get_resize_crop_region_for_grid(
            (grid_height, grid_width), base_size_width, base_size_height
        )
        freqs_cos, freqs_sin = get_3d_rotary_pos_embed(
            embed_dim=self.transformer.config.attention_head_dim,
            crops_coords=grid_crops_coords,
            grid_size=(grid_height, grid_width),
            temporal_size=num_frames,
        )

        freqs_cos = freqs_cos.to(device=device)
        freqs_sin = freqs_sin.to(device=device)
        return freqs_cos, freqs_sin

    @property
    def guidance_scale(self):
        """Guidance scale currently in effect (updated per step when dynamic CFG is enabled)."""
        return self._guidance_scale

    @property
    def num_timesteps(self):
        """Number of timesteps of the most recent `__call__`."""
        return self._num_timesteps

    @property
    def attention_kwargs(self):
        """The `attention_kwargs` passed to the most recent `__call__`."""
        return self._attention_kwargs

    @property
    def interrupt(self):
        """Flag checked at the start of every denoising step; when set, remaining steps are skipped."""
        return self._interrupt

    @torch.no_grad()
    @replace_example_docstring(EXAMPLE_DOC_STRING)
    def __call__(
        self,
        prompt: Optional[Union[str, List[str]]] = None,
        negative_prompt: Optional[Union[str, List[str]]] = None,
        height: int = 480,
        width: int = 720,
        num_frames: int = 49,
        num_inference_steps: int = 50,
        timesteps: Optional[List[int]] = None,
        guidance_scale: float = 6,
        use_dynamic_cfg: bool = False,
        num_videos_per_prompt: int = 1,
        eta: float = 0.0,
        generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,
        latents: Optional[torch.FloatTensor] = None,
        prompt_embeds: Optional[torch.FloatTensor] = None,
        negative_prompt_embeds: Optional[torch.FloatTensor] = None,
        output_type: str = "pil",
        return_dict: bool = True,
        attention_kwargs: Optional[Dict[str, Any]] = None,
        callback_on_step_end: Optional[
            Union[Callable[[int, int, Dict], None], PipelineCallback, MultiPipelineCallbacks]
        ] = None,
        callback_on_step_end_tensor_inputs: List[str] = ["latents"],
        max_sequence_length: int = 226,
        id_vit_hidden: Optional[torch.Tensor] = None,
        id_cond: Optional[torch.Tensor] = None,
    ) -> Union[CogVideoXPipelineOutput, Tuple]:
        """
        Function invoked when calling the pipeline for generation.

        Args:
            prompt (`str` or `List[str]`, *optional*):
                The prompt or prompts to guide the image generation. If not defined, one has to pass `prompt_embeds`
                instead.
            negative_prompt (`str` or `List[str]`, *optional*):
                The prompt or prompts not to guide the image generation. If not defined, one has to pass
                `negative_prompt_embeds` instead. Ignored when not using guidance (i.e., ignored if `guidance_scale`
                is not greater than `1`).
            height (`int`, *optional*, defaults to `480`):
                The height in pixels of the generated image. Must be divisible by 8.
            width (`int`, *optional*, defaults to `720`):
                The width in pixels of the generated image. Must be divisible by 8.
            num_frames (`int`, defaults to `49`):
                Number of frames to generate. Values above 49 are rejected. The latents are created with
                `(num_frames - 1) // self.vae_scale_factor_temporal + 1` frames.
            num_inference_steps (`int`, *optional*, defaults to 50):
                The number of denoising steps. More denoising steps usually lead to a higher quality image at the
                expense of slower inference.
            timesteps (`List[int]`, *optional*):
                Custom timesteps to use for the denoising process with schedulers which support a `timesteps` argument
                in their `set_timesteps` method. If not defined, the default behavior when `num_inference_steps` is
                passed will be used. Must be in descending order.
            guidance_scale (`float`, *optional*, defaults to 6):
                Guidance scale as defined in [Classifier-Free Diffusion Guidance](https://arxiv.org/abs/2207.12598).
                `guidance_scale` is defined as `w` of equation 2. of [Imagen
                Paper](https://arxiv.org/pdf/2205.11487.pdf). Guidance scale is enabled by setting `guidance_scale >
                1`. Higher guidance scale encourages to generate images that are closely linked to the text `prompt`,
                usually at the expense of lower image quality.
            use_dynamic_cfg (`bool`, *optional*, defaults to `False`):
                If `True`, the guidance scale applied at each step is recomputed from `guidance_scale`, the current
                timestep and `num_inference_steps` using a cosine schedule instead of staying constant.
            num_videos_per_prompt (`int`, *optional*, defaults to 1):
                The number of videos to generate per prompt. NOTE: this implementation overrides the value with `1`
                before it is used, so any other value is currently ignored.
            eta (`float`, *optional*, defaults to 0.0):
                Forwarded to `scheduler.step` when the scheduler's `step` accepts an `eta` argument; ignored
                otherwise.
            generator (`torch.Generator` or `List[torch.Generator]`, *optional*):
                One or a list of [torch generator(s)](https://pytorch.org/docs/stable/generated/torch.Generator.html)
                to make generation deterministic.
            latents (`torch.FloatTensor`, *optional*):
                Pre-generated noisy latents, sampled from a Gaussian distribution, to be used as inputs for image
                generation. Can be used to tweak the same generation with different prompts. If not provided, a latents
                tensor will be generated by sampling using the supplied random `generator`.
            prompt_embeds (`torch.FloatTensor`, *optional*):
                Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not
                provided, text embeddings will be generated from `prompt` input argument.
            negative_prompt_embeds (`torch.FloatTensor`, *optional*):
                Pre-generated negative text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt
                weighting. If not provided, negative_prompt_embeds will be generated from `negative_prompt` input
                argument.
            output_type (`str`, *optional*, defaults to `"pil"`):
                The output format of the generate image. Choose between
                [PIL](https://pillow.readthedocs.io/en/stable/): `PIL.Image.Image` or `np.array`. With `"latent"` the
                denoised latents are returned without decoding.
            return_dict (`bool`, *optional*, defaults to `True`):
                Whether or not to return a [`~pipelines.cogvideo.pipeline_output.CogVideoXPipelineOutput`] instead
                of a plain tuple.
            attention_kwargs (`dict`, *optional*):
                A kwargs dictionary that if specified is passed along to the `AttentionProcessor` as defined under
                `self.processor` in
                [diffusers.models.attention_processor](https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/attention_processor.py).
            callback_on_step_end (`Callable`, *optional*):
                A function that calls at the end of each denoising steps during the inference. The function is called
                with the following arguments: `callback_on_step_end(self: DiffusionPipeline, step: int, timestep: int,
                callback_kwargs: Dict)`. `callback_kwargs` will include a list of all tensors as specified by
                `callback_on_step_end_tensor_inputs`. The callback may return a dict with replacement values for
                `latents`, `prompt_embeds` and `negative_prompt_embeds`; returning `None` leaves them unchanged.
            callback_on_step_end_tensor_inputs (`List`, *optional*):
                The list of tensor inputs for the `callback_on_step_end` function. The tensors specified in the list
                will be passed as `callback_kwargs` argument. You will only be able to include variables listed in the
                `._callback_tensor_inputs` attribute of your pipeline class.
            max_sequence_length (`int`, defaults to `226`):
                Maximum sequence length in encoded prompt. Must be consistent with
                `self.transformer.config.max_text_seq_length` otherwise may lead to poor results.
            id_vit_hidden (`torch.Tensor`, *optional*):
                Forwarded unchanged to `self.transformer` as the `id_vit_hidden` keyword argument on every denoising
                step. Presumably identity features from a vision encoder; confirm the expected format against the
                transformer implementation.
            id_cond (`torch.Tensor`, *optional*):
                Forwarded unchanged to `self.transformer` as the `id_cond` keyword argument on every denoising step.
                Presumably an identity conditioning embedding; confirm the expected format against the transformer
                implementation.

        Examples:

        Returns:
            [`~pipelines.cogvideo.pipeline_cogvideox.CogVideoXPipelineOutput`] or `tuple`:
            [`~pipelines.cogvideo.pipeline_cogvideox.CogVideoXPipelineOutput`] if `return_dict` is True, otherwise a
            `tuple`. When returning a tuple, the first element is a list with the generated images.
        """

        # 49 itself is allowed; only larger values are rejected.
        if num_frames > 49:
            raise ValueError(
                "The number of frames must be less than or equal to 49 for now due to static positional embeddings. This will be updated in the future to remove this limitation."
            )

        if isinstance(callback_on_step_end, (PipelineCallback, MultiPipelineCallbacks)):
            callback_on_step_end_tensor_inputs = callback_on_step_end.tensor_inputs

        # Only a single video per prompt is generated, regardless of the value passed by the caller.
        num_videos_per_prompt = 1

        # 1. Check inputs. Raise error if not correct
        self.check_inputs(
            prompt,
            height,
            width,
            negative_prompt,
            callback_on_step_end_tensor_inputs,
            prompt_embeds,
            negative_prompt_embeds,
        )
        self._guidance_scale = guidance_scale
        self._attention_kwargs = attention_kwargs
        self._interrupt = False

        # 2. Default call parameters
        if prompt is not None and isinstance(prompt, str):
            batch_size = 1
        elif prompt is not None and isinstance(prompt, list):
            batch_size = len(prompt)
        else:
            batch_size = prompt_embeds.shape[0]

        device = self._execution_device

        # here `guidance_scale` is defined analog to the guidance weight `w` of equation (2)
        # of the Imagen paper: https://arxiv.org/pdf/2205.11487.pdf . `guidance_scale = 1`
        # corresponds to doing no classifier free guidance.
        do_classifier_free_guidance = guidance_scale > 1.0

        # 3. Encode input prompt
        prompt_embeds, negative_prompt_embeds = self.encode_prompt(
            prompt,
            negative_prompt,
            do_classifier_free_guidance,
            num_videos_per_prompt=num_videos_per_prompt,
            prompt_embeds=prompt_embeds,
            negative_prompt_embeds=negative_prompt_embeds,
            max_sequence_length=max_sequence_length,
            device=device,
        )
        if do_classifier_free_guidance:
            # Unconditional half first, conditional half second; `chunk(2)` below relies on this order.
            prompt_embeds = torch.cat([negative_prompt_embeds, prompt_embeds], dim=0)

        # 4. Prepare timesteps
        timesteps, num_inference_steps = retrieve_timesteps(self.scheduler, num_inference_steps, device, timesteps)
        self._num_timesteps = len(timesteps)

        # 5. Prepare latents.
        latent_channels = self.transformer.config.in_channels
        latents = self.prepare_latents(
            batch_size * num_videos_per_prompt,
            latent_channels,
            num_frames,
            height,
            width,
            prompt_embeds.dtype,
            device,
            generator,
            latents,
        )

        # 6. Prepare extra step kwargs. TODO: Logic should ideally just be moved out of the pipeline
        extra_step_kwargs = self.prepare_extra_step_kwargs(generator, eta)

        # 7. Create rotary embeds if required
        image_rotary_emb = (
            self._prepare_rotary_positional_embeddings(height, width, latents.size(1), device)
            if self.transformer.config.use_rotary_positional_embeddings
            else None
        )

        # 8. Denoising loop
        num_warmup_steps = max(len(timesteps) - num_inference_steps * self.scheduler.order, 0)

        with self.progress_bar(total=num_inference_steps) as progress_bar:
            # for DPM-solver++
            old_pred_original_sample = None
            for i, t in enumerate(timesteps):
                if self.interrupt:
                    continue

                latent_model_input = torch.cat([latents] * 2) if do_classifier_free_guidance else latents
                latent_model_input = self.scheduler.scale_model_input(latent_model_input, t)

                # broadcast to batch dimension in a way that's compatible with ONNX/Core ML
                timestep = t.expand(latent_model_input.shape[0])

                # predict noise model_output
                noise_pred = self.transformer(
                    hidden_states=latent_model_input,
                    encoder_hidden_states=prompt_embeds,
                    timestep=timestep,
                    image_rotary_emb=image_rotary_emb,
                    attention_kwargs=attention_kwargs,
                    return_dict=False,
                    id_vit_hidden=id_vit_hidden,
                    id_cond=id_cond,
                )[0]
                noise_pred = noise_pred.float()

                # perform guidance
                if use_dynamic_cfg:
                    self._guidance_scale = 1 + guidance_scale * (
                        (1 - math.cos(math.pi * ((num_inference_steps - t.item()) / num_inference_steps) ** 5.0)) / 2
                    )
                if do_classifier_free_guidance:
                    noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)
                    noise_pred = noise_pred_uncond + self.guidance_scale * (noise_pred_text - noise_pred_uncond)

                # compute the previous noisy sample x_t -> x_t-1
                if not isinstance(self.scheduler, CogVideoXDPMScheduler):
                    latents = self.scheduler.step(noise_pred, t, latents, **extra_step_kwargs, return_dict=False)[0]
                else:
                    latents, old_pred_original_sample = self.scheduler.step(
                        noise_pred,
                        old_pred_original_sample,
                        t,
                        timesteps[i - 1] if i > 0 else None,
                        latents,
                        **extra_step_kwargs,
                        return_dict=False,
                    )
                latents = latents.to(prompt_embeds.dtype)

                # call the callback, if provided
                if callback_on_step_end is not None:
                    callback_kwargs = {}
                    # NOTE: keep this a plain loop at function scope; `locals()` must see `__call__`'s variables.
                    for k in callback_on_step_end_tensor_inputs:
                        callback_kwargs[k] = locals()[k]
                    callback_outputs = callback_on_step_end(self, i, t, callback_kwargs)

                    # A callback annotated as returning `None` must not break the loop.
                    if callback_outputs is None:
                        callback_outputs = {}

                    latents = callback_outputs.pop("latents", latents)
                    prompt_embeds = callback_outputs.pop("prompt_embeds", prompt_embeds)
                    negative_prompt_embeds = callback_outputs.pop("negative_prompt_embeds", negative_prompt_embeds)

                if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0):
                    progress_bar.update()

        if not output_type == "latent":
            video = self.decode_latents(latents)
            video = self.video_processor.postprocess_video(video=video, output_type=output_type)
        else:
            video = latents

        # Offload all models
        self.maybe_free_model_hooks()

        if not return_dict:
            return (video,)

        return CogVideoXPipelineOutput(frames=video)
models/pipeline_consisid.py ADDED
@@ -0,0 +1,894 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright (c) Meta Platforms, Inc. and affiliates.
2
+ # All rights reserved.
3
+
4
+ # This source code is licensed under the license found in the
5
+ # LICENSE file in the root directory of this source tree.
6
+
7
+ import inspect
8
+ import math
9
+ from typing import Callable, Dict, List, Optional, Tuple, Union
10
+
11
+ import os
12
+ import sys
13
+ import PIL
14
+ import numpy as np
15
+ import cv2
16
+ from PIL import Image
17
+ import torch
18
+ from transformers import T5EncoderModel, T5Tokenizer
19
+
20
+ from diffusers.callbacks import MultiPipelineCallbacks, PipelineCallback
21
+ from diffusers.image_processor import PipelineImageInput
22
+ from diffusers.models import AutoencoderKLCogVideoX, CogVideoXTransformer3DModel
23
+ from diffusers.models.embeddings import get_3d_rotary_pos_embed
24
+ from diffusers.pipelines.pipeline_utils import DiffusionPipeline
25
+ from diffusers.schedulers import CogVideoXDDIMScheduler, CogVideoXDPMScheduler
26
+ from diffusers.utils import logging, replace_example_docstring
27
+ from diffusers.utils.torch_utils import randn_tensor
28
+ from diffusers.video_processor import VideoProcessor
29
+ from diffusers.pipelines.cogvideo.pipeline_output import CogVideoXPipelineOutput
30
+
31
# Put the repository root (the parent of this file's directory) on
# `sys.path` *before* importing project-local modules; previously the
# `models.*` import ran first, so this path setup could not help it.
current_file_path = os.path.abspath(__file__)
project_roots = [os.path.dirname(os.path.dirname(current_file_path))]
for project_root in project_roots:
    if project_root not in sys.path:
        sys.path.insert(0, project_root)

from models.transformer_consisid import ConsisIDTransformer3DModel  # noqa: E402
37
+
38
+ logger = logging.get_logger(__name__) # pylint: disable=invalid-name
39
+
40
+ EXAMPLE_DOC_STRING = """
41
+ Examples:
42
+ ```py
43
+ >>> import torch
44
+ >>> from diffusers import CogVideoXImageToVideoPipeline
45
+ >>> from diffusers.utils import export_to_video, load_image
46
+
47
+ >>> pipe = CogVideoXImageToVideoPipeline.from_pretrained("THUDM/CogVideoX-5b-I2V", torch_dtype=torch.bfloat16)
48
+ >>> pipe.to("cuda")
49
+
50
+ >>> prompt = "An astronaut hatching from an egg, on the surface of the moon, the darkness and depth of space realised in the background. High quality, ultrarealistic detail and breath-taking movie-like camera shot."
51
+ >>> image = load_image(
52
+ ... "https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/diffusers/astronaut.jpg"
53
+ ... )
54
+ >>> video = pipe(image, prompt, use_dynamic_cfg=True)
55
+ >>> export_to_video(video.frames[0], "output.mp4", fps=8)
56
+ ```
57
+ """
58
+
59
def draw_kps(image_pil, kps, color_list=[(255,0,0), (0,255,0), (0,0,255), (255,255,0), (255,0,255)]):
    """Draw keypoints and their connecting limbs on a black canvas.

    Args:
        image_pil: PIL image; only its ``size`` is read, to size the canvas.
        kps: Sequence of ``(x, y)`` keypoints. Keypoints 0, 1, 3 and 4 are
            each joined to keypoint 2, so at least five are required.
            (Presumably five facial landmarks -- confirm with the caller.)
        color_list: One colour per keypoint; a limb uses the colour of its
            first endpoint.

    Returns:
        A new ``PIL.Image`` the same size as ``image_pil`` holding the
        rendered limbs (dimmed to 60 %) and keypoint discs.
    """
    stick_half_width = 4
    limb_pairs = np.array([[0, 2], [1, 2], [3, 2], [4, 2]])
    kps = np.array(kps)

    width, height = image_pil.size
    canvas = np.zeros([height, width, 3])

    # Each limb is a filled ellipse stretched between its two endpoints.
    for pair in limb_pairs:
        limb_color = color_list[pair[0]]

        endpoints = kps[pair]
        xs = endpoints[:, 0]
        ys = endpoints[:, 1]
        dx = xs[0] - xs[1]
        dy = ys[0] - ys[1]
        limb_length = (dx ** 2 + dy ** 2) ** 0.5
        limb_angle = math.degrees(math.atan2(dy, dx))

        centre = (int(np.mean(xs)), int(np.mean(ys)))
        half_axes = (int(limb_length / 2), stick_half_width)
        outline = cv2.ellipse2Poly(centre, half_axes, int(limb_angle), 0, 360, 1)
        canvas = cv2.fillConvexPoly(canvas.copy(), outline, limb_color)

    # Dim the limbs so the keypoint discs drawn next stand out.
    canvas = (canvas * 0.6).astype(np.uint8)

    for kp_index, kp in enumerate(kps):
        dot_color = color_list[kp_index]
        x, y = kp
        canvas = cv2.circle(canvas.copy(), (int(x), int(y)), 10, dot_color, -1)

    return Image.fromarray(canvas.astype(np.uint8))
86
+
87
def process_image(image, vae):
    """Add small random Gaussian noise to ``image`` and VAE-encode it.

    A single noise scale is sampled as ``exp(N(-3.0, 0.5))`` and applied to
    standard-normal noise of the same shape as ``image``.

    Args:
        image: Tensor to encode. The scale is broadcast through five axes,
            so a 5-D tensor is expected -- TODO confirm layout with caller.
        vae: Object whose ``encode(x)`` result exposes ``latent_dist``.

    Returns:
        The ``latent_dist`` attribute of ``vae.encode(image + noise)``.
    """
    # Draw the scale first, then the noise, so RNG consumption is stable.
    log_sigma = torch.normal(mean=-3.0, std=0.5, size=(1,), device=image.device)
    sigma = log_sigma.exp().to(dtype=image.dtype)
    noise = torch.randn_like(image) * sigma.view(1, 1, 1, 1, 1)
    return vae.encode(image + noise).latent_dist
94
+
95
+ # Similar to diffusers.pipelines.hunyuandit.pipeline_hunyuandit.get_resize_crop_region_for_grid
96
def get_resize_crop_region_for_grid(src, tgt_width, tgt_height):
    """Fit a ``(height, width)`` grid inside a target box and centre it.

    The source is scaled, keeping its aspect ratio, until it touches the
    target box on one axis; the result is then centred on the other axis.

    Args:
        src: ``(height, width)`` of the source grid.
        tgt_width: Width of the target box.
        tgt_height: Height of the target box.

    Returns:
        ``((top, left), (bottom, right))`` corners of the fitted region
        inside the target box.
    """
    src_h, src_w = src

    if src_h / src_w > tgt_height / tgt_width:
        # Source is relatively taller than the target: height is the limit.
        fit_h = tgt_height
        fit_w = int(round(tgt_height / src_h * src_w))
    else:
        # Source is relatively wider (or equal): width is the limit.
        fit_w = tgt_width
        fit_h = int(round(tgt_width / src_w * src_h))

    top = int(round((tgt_height - fit_h) / 2.0))
    left = int(round((tgt_width - fit_w) / 2.0))
    return (top, left), (top + fit_h, left + fit_w)
112
+
113
+
114
+ # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.retrieve_timesteps
115
def retrieve_timesteps(
    scheduler,
    num_inference_steps: Optional[int] = None,
    device: Optional[Union[str, torch.device]] = None,
    timesteps: Optional[List[int]] = None,
    sigmas: Optional[List[float]] = None,
    **kwargs,
):
    """
    Configure the scheduler via `set_timesteps` and return the resulting schedule.

    Exactly one of three modes is used: a custom `timesteps` list, a custom `sigmas` list, or a plain
    `num_inference_steps` count. Extra `kwargs` are forwarded to `scheduler.set_timesteps`.

    Args:
        scheduler (`SchedulerMixin`):
            The scheduler to configure and read timesteps from.
        num_inference_steps (`int`):
            Number of diffusion steps. Only used when neither `timesteps` nor `sigmas` is given.
        device (`str` or `torch.device`, *optional*):
            Device forwarded to `set_timesteps`.
        timesteps (`List[int]`, *optional*):
            Custom timestep schedule. Mutually exclusive with `sigmas`.
        sigmas (`List[float]`, *optional*):
            Custom sigma schedule. Mutually exclusive with `timesteps`.

    Returns:
        `Tuple[torch.Tensor, int]`: the scheduler's timestep schedule and the number of inference steps (the
        schedule length when a custom schedule was supplied).

    Raises:
        ValueError: If both `timesteps` and `sigmas` are passed, or the scheduler's `set_timesteps` does not
            accept the custom schedule that was requested.
    """
    if timesteps is not None and sigmas is not None:
        raise ValueError("Only one of `timesteps` or `sigmas` can be passed. Please choose one to set custom values")

    # Default mode: let the scheduler space the steps itself.
    if timesteps is None and sigmas is None:
        scheduler.set_timesteps(num_inference_steps, device=device, **kwargs)
        return scheduler.timesteps, num_inference_steps

    # Custom schedules are only valid if `set_timesteps` has the matching keyword.
    supported = set(inspect.signature(scheduler.set_timesteps).parameters.keys())
    if timesteps is not None:
        if "timesteps" not in supported:
            raise ValueError(
                f"The current scheduler class {scheduler.__class__}'s `set_timesteps` does not support custom"
                f" timestep schedules. Please check whether you are using the correct scheduler."
            )
        scheduler.set_timesteps(timesteps=timesteps, device=device, **kwargs)
    else:
        if "sigmas" not in supported:
            raise ValueError(
                f"The current scheduler class {scheduler.__class__}'s `set_timesteps` does not support custom"
                f" sigmas schedules. Please check whether you are using the correct scheduler."
            )
        scheduler.set_timesteps(sigmas=sigmas, device=device, **kwargs)

    schedule = scheduler.timesteps
    return schedule, len(schedule)
172
+
173
+
174
+ # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion_img2img.retrieve_latents
175
def retrieve_latents(
    encoder_output: torch.Tensor, generator: Optional[torch.Generator] = None, sample_mode: str = "sample"
):
    """Extract latents from an encoder output.

    Args:
        encoder_output: Object exposing ``latent_dist`` and/or ``latents``.
        generator: Passed to ``latent_dist.sample`` in ``"sample"`` mode.
        sample_mode: ``"sample"`` draws from ``latent_dist``; ``"argmax"``
            takes its mode. Any other value skips ``latent_dist``.

    Returns:
        The sampled/mode latents, or ``encoder_output.latents`` as fallback.

    Raises:
        AttributeError: If no latents can be obtained from ``encoder_output``.
    """
    if hasattr(encoder_output, "latent_dist"):
        if sample_mode == "sample":
            return encoder_output.latent_dist.sample(generator)
        if sample_mode == "argmax":
            return encoder_output.latent_dist.mode()

    if hasattr(encoder_output, "latents"):
        return encoder_output.latents

    raise AttributeError("Could not access latents of provided encoder_output")
186
+
187
+
188
+ class ConsisIDPipeline(DiffusionPipeline):
189
+ r"""
190
+ Pipeline for image-to-video generation using CogVideoX.
191
+
192
+ This model inherits from [`DiffusionPipeline`]. Check the superclass documentation for the generic methods the
193
+ library implements for all the pipelines (such as downloading or saving, running on a particular device, etc.)
194
+
195
+ Args:
196
+ vae ([`AutoencoderKL`]):
197
+ Variational Auto-Encoder (VAE) Model to encode and decode videos to and from latent representations.
198
+ text_encoder ([`T5EncoderModel`]):
199
+ Frozen text-encoder. CogVideoX uses
200
+ [T5](https://huggingface.co/docs/transformers/model_doc/t5#transformers.T5EncoderModel); specifically the
201
+ [t5-v1_1-xxl](https://huggingface.co/PixArt-alpha/PixArt-alpha/tree/main/t5-v1_1-xxl) variant.
202
+ tokenizer (`T5Tokenizer`):
203
+ Tokenizer of class
204
+ [T5Tokenizer](https://huggingface.co/docs/transformers/model_doc/t5#transformers.T5Tokenizer).
205
+ transformer ([`ConsisIDTransformer3DModel`]):
206
+ A text conditioned `ConsisIDTransformer3DModel` to denoise the encoded video latents.
207
+ scheduler ([`SchedulerMixin`]):
208
+ A scheduler to be used in combination with `transformer` to denoise the encoded video latents.
209
+ """
210
+
211
+ _optional_components = []
212
+ model_cpu_offload_seq = "text_encoder->transformer->vae"
213
+
214
+ _callback_tensor_inputs = [
215
+ "latents",
216
+ "prompt_embeds",
217
+ "negative_prompt_embeds",
218
+ ]
219
+
220
def __init__(
    self,
    tokenizer: T5Tokenizer,
    text_encoder: T5EncoderModel,
    vae: AutoencoderKLCogVideoX,
    transformer: Union[ConsisIDTransformer3DModel, CogVideoXTransformer3DModel],
    scheduler: Union[CogVideoXDDIMScheduler, CogVideoXDPMScheduler],
):
    """Register the pipeline components and cache VAE-derived constants.

    Args:
        tokenizer: Tokenizer for the text encoder.
        text_encoder: Text encoder producing prompt embeddings.
        vae: Video autoencoder; its config supplies the scale factors.
        transformer: Denoising transformer.
        scheduler: Diffusion scheduler.
    """
    super().__init__()

    self.register_modules(
        tokenizer=tokenizer,
        text_encoder=text_encoder,
        vae=vae,
        transformer=transformer,
        scheduler=scheduler,
    )

    # Read the scale factors from the VAE config when a VAE is registered;
    # otherwise fall back to fixed defaults.
    vae_module = getattr(self, "vae", None)
    if vae_module is not None:
        spatial_factor = 2 ** (len(vae_module.config.block_out_channels) - 1)
        temporal_factor = vae_module.config.temporal_compression_ratio
        latent_scaling = vae_module.config.scaling_factor
    else:
        spatial_factor, temporal_factor, latent_scaling = 8, 4, 0.7

    self.vae_scale_factor_spatial = spatial_factor
    self.vae_scale_factor_temporal = temporal_factor
    self.vae_scaling_factor_image = latent_scaling

    self.video_processor = VideoProcessor(vae_scale_factor=self.vae_scale_factor_spatial)
248
+
249
+ # Copied from diffusers.pipelines.cogvideo.pipeline_cogvideox.CogVideoXPipeline._get_t5_prompt_embeds
250
+ def _get_t5_prompt_embeds(
251
+ self,
252
+ prompt: Union[str, List[str]] = None,
253
+ num_videos_per_prompt: int = 1,
254
+ max_sequence_length: int = 226,
255
+ device: Optional[torch.device] = None,
256
+ dtype: Optional[torch.dtype] = None,
257
+ ):
258
+ device = device or self._execution_device
259
+ dtype = dtype or self.text_encoder.dtype
260
+
261
+ prompt = [prompt] if isinstance(prompt, str) else prompt
262
+ batch_size = len(prompt)
263
+
264
+ text_inputs = self.tokenizer(
265
+ prompt,
266
+ padding="max_length",
267
+ max_length=max_sequence_length,
268
+ truncation=True,
269
+ add_special_tokens=True,
270
+ return_tensors="pt",
271
+ )
272
+ text_input_ids = text_inputs.input_ids
273
+ untruncated_ids = self.tokenizer(prompt, padding="longest", return_tensors="pt").input_ids
274
+
275
+ if untruncated_ids.shape[-1] >= text_input_ids.shape[-1] and not torch.equal(text_input_ids, untruncated_ids):
276
+ removed_text = self.tokenizer.batch_decode(untruncated_ids[:, max_sequence_length - 1 : -1])
277
+ logger.warning(
278
+ "The following part of your input was truncated because `max_sequence_length` is set to "
279
+ f" {max_sequence_length} tokens: {removed_text}"
280
+ )
281
+
282
+ prompt_embeds = self.text_encoder(text_input_ids.to(device))[0]
283
+ prompt_embeds = prompt_embeds.to(dtype=dtype, device=device)
284
+
285
+ # duplicate text embeddings for each generation per prompt, using mps friendly method
286
+ _, seq_len, _ = prompt_embeds.shape
287
+ prompt_embeds = prompt_embeds.repeat(1, num_videos_per_prompt, 1)
288
+ prompt_embeds = prompt_embeds.view(batch_size * num_videos_per_prompt, seq_len, -1)
289
+
290
+ return prompt_embeds
291
+
292
+ # Copied from diffusers.pipelines.cogvideo.pipeline_cogvideox.CogVideoXPipeline.encode_prompt
293
def encode_prompt(
    self,
    prompt: Union[str, List[str]],
    negative_prompt: Optional[Union[str, List[str]]] = None,
    do_classifier_free_guidance: bool = True,
    num_videos_per_prompt: int = 1,
    prompt_embeds: Optional[torch.Tensor] = None,
    negative_prompt_embeds: Optional[torch.Tensor] = None,
    max_sequence_length: int = 226,
    device: Optional[torch.device] = None,
    dtype: Optional[torch.dtype] = None,
):
    r"""
    Produce positive and (optionally) negative text embeddings.

    Args:
        prompt (`str` or `List[str]`, *optional*):
            Prompt(s) to encode. May be `None` when `prompt_embeds` is given.
        negative_prompt (`str` or `List[str]`, *optional*):
            Negative prompt(s). A string is broadcast to the batch size; an empty string is used when omitted.
            Only consulted when guidance is on and `negative_prompt_embeds` is not supplied.
        do_classifier_free_guidance (`bool`, *optional*, defaults to `True`):
            Whether negative embeddings are needed.
        num_videos_per_prompt (`int`, *optional*, defaults to 1):
            Number of videos generated per prompt; embeddings are duplicated accordingly.
        prompt_embeds (`torch.Tensor`, *optional*):
            Pre-computed prompt embeddings; returned unchanged when given.
        negative_prompt_embeds (`torch.Tensor`, *optional*):
            Pre-computed negative embeddings; returned unchanged when given.
        max_sequence_length (`int`, defaults to 226):
            Token length passed to the encoder helper.
        device: (`torch.device`, *optional*):
            torch device; defaults to the pipeline's execution device.
        dtype: (`torch.dtype`, *optional*):
            torch dtype

    Returns:
        `(prompt_embeds, negative_prompt_embeds)`; the second item stays `None` when guidance is disabled and
        none was supplied.

    Raises:
        TypeError: If `prompt` and `negative_prompt` end up with different types.
        ValueError: If `negative_prompt` does not match the batch size.
    """
    device = device or self._execution_device

    if isinstance(prompt, str):
        prompt = [prompt]
    batch_size = len(prompt) if prompt is not None else prompt_embeds.shape[0]

    if prompt_embeds is None:
        prompt_embeds = self._get_t5_prompt_embeds(
            prompt=prompt,
            num_videos_per_prompt=num_videos_per_prompt,
            max_sequence_length=max_sequence_length,
            device=device,
            dtype=dtype,
        )

    # Nothing more to do unless guidance needs freshly encoded negatives.
    if not do_classifier_free_guidance or negative_prompt_embeds is not None:
        return prompt_embeds, negative_prompt_embeds

    negative_prompt = negative_prompt or ""
    if isinstance(negative_prompt, str):
        negative_prompt = batch_size * [negative_prompt]

    if prompt is not None and type(prompt) is not type(negative_prompt):
        raise TypeError(
            f"`negative_prompt` should be the same type to `prompt`, but got {type(negative_prompt)} !="
            f" {type(prompt)}."
        )
    if batch_size != len(negative_prompt):
        raise ValueError(
            f"`negative_prompt`: {negative_prompt} has batch size {len(negative_prompt)}, but `prompt`:"
            f" {prompt} has batch size {batch_size}. Please make sure that passed `negative_prompt` matches"
            " the batch size of `prompt`."
        )

    negative_prompt_embeds = self._get_t5_prompt_embeds(
        prompt=negative_prompt,
        num_videos_per_prompt=num_videos_per_prompt,
        max_sequence_length=max_sequence_length,
        device=device,
        dtype=dtype,
    )
    return prompt_embeds, negative_prompt_embeds
373
+
374
def prepare_latents(
    self,
    image: torch.Tensor,
    batch_size: int = 1,
    num_channels_latents: int = 16,
    num_frames: int = 13,
    height: int = 60,
    width: int = 90,
    dtype: Optional[torch.dtype] = None,
    device: Optional[torch.device] = None,
    generator: Optional[torch.Generator] = None,
    latents: Optional[torch.Tensor] = None,
    kps_cond: Optional[torch.Tensor] = None,
):
    """Build the initial noise latents and the conditioning latents.

    The conditioning tensor is the VAE-encoded ``image`` (one latent frame),
    optionally followed by the VAE-encoded ``kps_cond`` (one more frame),
    zero-padded along the frame axis to the latent frame count.

    Args:
        image: Batch of conditioning images; a frame axis is inserted at
            dim 2 before encoding.
        batch_size: Effective batch size.
        num_channels_latents: Channel count of the latent tensors.
        num_frames: Number of video frames (pixel space).
        height: Frame height in pixels.
        width: Frame width in pixels.
        dtype: Dtype of the returned tensors.
        device: Device for the noise and padding tensors.
        generator: A generator, or a list with one generator per sample.
        latents: Optional pre-made noise latents; moved to ``device``.
        kps_cond: Optional keypoint conditioning images, handled like
            ``image``.

    Returns:
        ``(latents, image_latents)``: noise latents scaled by the
        scheduler's ``init_noise_sigma``, and the conditioning latents.

    Raises:
        ValueError: If ``generator`` is a list whose length differs from
            ``batch_size``.
    """
    per_sample_generators = isinstance(generator, list)
    if per_sample_generators and len(generator) != batch_size:
        raise ValueError(
            f"You have passed a list of generators of length {len(generator)}, but requested an effective batch"
            f" size of {batch_size}. Make sure the batch size matches the length of the generators."
        )

    latent_frames = (num_frames - 1) // self.vae_scale_factor_temporal + 1
    latent_height = height // self.vae_scale_factor_spatial
    latent_width = width // self.vae_scale_factor_spatial
    shape = (batch_size, latent_frames, num_channels_latents, latent_height, latent_width)
    use_kps = kps_cond is not None

    def _encode_each(frames):
        # Encode sample by sample so each one can use its own generator.
        if per_sample_generators:
            return [
                retrieve_latents(self.vae.encode(frames[idx].unsqueeze(0)), generator[idx])
                for idx in range(batch_size)
            ]
        return [retrieve_latents(self.vae.encode(frame.unsqueeze(0)), generator) for frame in frames]

    def _stack_and_scale(chunks):
        stacked = torch.cat(chunks, dim=0).to(dtype).permute(0, 2, 1, 3, 4)  # [B, F, C, H, W]
        return self.vae_scaling_factor_image * stacked

    # Image is encoded before the keypoints so generator draws happen in
    # the same order as before.
    image = image.unsqueeze(2)  # [B, C, F, H, W]
    image_chunks = _encode_each(image)
    if use_kps:
        kps_cond = kps_cond.unsqueeze(2)
        kps_chunks = _encode_each(kps_cond)

    conditioning = [_stack_and_scale(image_chunks)]
    if use_kps:
        conditioning.append(_stack_and_scale(kps_chunks))

    # One latent frame per conditioning input; the rest is zero padding.
    padding_frames = latent_frames - (2 if use_kps else 1)
    padding_shape = (batch_size, padding_frames, num_channels_latents, latent_height, latent_width)
    conditioning.append(torch.zeros(padding_shape, device=device, dtype=dtype))
    image_latents = torch.cat(conditioning, dim=1)

    if latents is None:
        latents = randn_tensor(shape, generator=generator, device=device, dtype=dtype)
    else:
        latents = latents.to(device)

    # scale the initial noise by the standard deviation required by the scheduler
    latents = latents * self.scheduler.init_noise_sigma
    return latents, image_latents
457
+
458
+ # Copied from diffusers.pipelines.cogvideo.pipeline_cogvideox.CogVideoXPipeline.decode_latents
459
def decode_latents(self, latents: torch.Tensor) -> torch.Tensor:
    """Decode latents into frames with the VAE.

    Args:
        latents: 5-D latent tensor; dims 1 and 2 are swapped before
            decoding (to ``[batch, channels, frames, height, width]``).

    Returns:
        The ``sample`` attribute of the VAE decode output.
    """
    channels_first = latents.permute(0, 2, 1, 3, 4)
    # Undo the scaling applied when the latents were produced.
    unscaled = 1 / self.vae_scaling_factor_image * channels_first
    return self.vae.decode(unscaled).sample
465
+
466
+ # Copied from diffusers.pipelines.animatediff.pipeline_animatediff_video2video.AnimateDiffVideoToVideoPipeline.get_timesteps
467
def get_timesteps(self, num_inference_steps, timesteps, strength, device):
    """Trim the schedule to its final ``strength`` fraction.

    Args:
        num_inference_steps: Total number of steps in the schedule.
        timesteps: Full timestep sequence (must support slicing).
        strength: Fraction of the steps to actually run; values above 1
            are clamped to the full schedule.
        device: Unused.

    Returns:
        ``(timesteps, steps)``: the remaining timesteps and how many
        inference steps they represent.
    """
    steps_to_run = min(int(num_inference_steps * strength), num_inference_steps)
    first_step = max(num_inference_steps - steps_to_run, 0)

    # Higher-order schedulers hold `order` entries per inference step.
    remaining = timesteps[first_step * self.scheduler.order :]
    return remaining, num_inference_steps - first_step
475
+
476
+ # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.prepare_extra_step_kwargs
477
def prepare_extra_step_kwargs(self, generator, eta):
    """Collect the optional keyword arguments ``scheduler.step`` accepts.

    Schedulers differ in signature, so ``eta`` and ``generator`` are only
    forwarded when ``self.scheduler.step`` declares a parameter of that
    name. ``eta`` corresponds to η in the DDIM paper
    (https://arxiv.org/abs/2010.02502) and should be between [0, 1].

    Args:
        generator: Random generator to forward, if supported.
        eta: DDIM eta value to forward, if supported.

    Returns:
        Dict holding the subset of ``{"eta", "generator"}`` that is accepted.
    """
    step_params = inspect.signature(self.scheduler.step).parameters
    candidates = {"eta": eta, "generator": generator}
    return {name: value for name, value in candidates.items() if name in step_params}
493
+
494
def check_inputs(
    self,
    image,
    prompt,
    height,
    width,
    negative_prompt,
    callback_on_step_end_tensor_inputs,
    latents=None,
    prompt_embeds=None,
    negative_prompt_embeds=None,
):
    """Validate the arguments of a pipeline call.

    Checks, in order: the type of ``image``; that ``height`` and ``width``
    are multiples of 8; that requested callback tensors are known; that
    exactly one of ``prompt`` / ``prompt_embeds`` is given and ``prompt``
    is a ``str`` or ``list``; that raw prompts are not mixed with
    ``negative_prompt_embeds``; and that pre-computed positive and negative
    embeddings share a shape. ``latents`` is accepted but not inspected.

    Raises:
        ValueError: On the first check that fails.
    """
    image_type_ok = (
        isinstance(image, torch.Tensor) or isinstance(image, PIL.Image.Image) or isinstance(image, list)
    )
    if not image_type_ok:
        raise ValueError(
            "`image` has to be of type `torch.Tensor` or `PIL.Image.Image` or `List[PIL.Image.Image]` but is"
            f" {type(image)}"
        )

    if height % 8 != 0 or width % 8 != 0:
        raise ValueError(f"`height` and `width` have to be divisible by 8 but are {height} and {width}.")

    if callback_on_step_end_tensor_inputs is not None:
        unknown_inputs = [k for k in callback_on_step_end_tensor_inputs if k not in self._callback_tensor_inputs]
        if unknown_inputs:
            raise ValueError(
                f"`callback_on_step_end_tensor_inputs` has to be in {self._callback_tensor_inputs}, but found {unknown_inputs}"
            )

    has_prompt = prompt is not None
    has_prompt_embeds = prompt_embeds is not None
    has_negative_embeds = negative_prompt_embeds is not None

    if has_prompt and has_prompt_embeds:
        raise ValueError(
            f"Cannot forward both `prompt`: {prompt} and `prompt_embeds`: {prompt_embeds}. Please make sure to"
            " only forward one of the two."
        )
    if not has_prompt and not has_prompt_embeds:
        raise ValueError(
            "Provide either `prompt` or `prompt_embeds`. Cannot leave both `prompt` and `prompt_embeds` undefined."
        )
    if has_prompt and not (isinstance(prompt, str) or isinstance(prompt, list)):
        raise ValueError(f"`prompt` has to be of type `str` or `list` but is {type(prompt)}")

    if has_prompt and has_negative_embeds:
        raise ValueError(
            f"Cannot forward both `prompt`: {prompt} and `negative_prompt_embeds`:"
            f" {negative_prompt_embeds}. Please make sure to only forward one of the two."
        )

    if negative_prompt is not None and has_negative_embeds:
        raise ValueError(
            f"Cannot forward both `negative_prompt`: {negative_prompt} and `negative_prompt_embeds`:"
            f" {negative_prompt_embeds}. Please make sure to only forward one of the two."
        )

    if has_prompt_embeds and has_negative_embeds and prompt_embeds.shape != negative_prompt_embeds.shape:
        raise ValueError(
            "`prompt_embeds` and `negative_prompt_embeds` must have the same shape when passed directly, but"
            f" got: `prompt_embeds` {prompt_embeds.shape} != `negative_prompt_embeds`"
            f" {negative_prompt_embeds.shape}."
        )
556
+
557
+ # Copied from diffusers.pipelines.cogvideo.pipeline_cogvideox.CogVideoXPipeline.fuse_qkv_projections
558
def fuse_qkv_projections(self) -> None:
    r"""Enable fused QKV projections on the transformer.

    Sets ``self.fusing_transformer`` and delegates to
    ``self.transformer.fuse_qkv_projections()``.
    """
    # Flag is raised before delegating, so it stays set even if the
    # transformer call fails.
    self.fusing_transformer = True
    self.transformer.fuse_qkv_projections()
562
+
563
+ # Copied from diffusers.pipelines.cogvideo.pipeline_cogvideox.CogVideoXPipeline.unfuse_qkv_projections
564
def unfuse_qkv_projections(self) -> None:
    r"""Disable QKV projection fusion if enabled.

    If the transformer is not currently marked as fused, a warning is
    logged and nothing else happens. Otherwise
    ``self.transformer.unfuse_qkv_projections()`` is called and the
    ``fusing_transformer`` flag is cleared.
    """
    # `fusing_transformer` is not initialised in `__init__`; it first
    # appears when `fuse_qkv_projections` runs. Treat a missing attribute
    # as "not fused" so calling this first warns instead of raising
    # AttributeError.
    if not getattr(self, "fusing_transformer", False):
        logger.warning("The Transformer was not initially fused for QKV projections. Doing nothing.")
        return

    self.transformer.unfuse_qkv_projections()
    self.fusing_transformer = False
571
+
572
+ # Copied from diffusers.pipelines.cogvideo.pipeline_cogvideox.CogVideoXPipeline._prepare_rotary_positional_embeddings
573
def _prepare_rotary_positional_embeddings(
    self,
    height: int,
    width: int,
    num_frames: int,
    device: torch.device,
) -> Tuple[torch.Tensor, torch.Tensor]:
    """Compute 3-D rotary position embeddings for the transformer.

    Pixel sizes are converted to patch-grid sizes by dividing by
    ``vae_scale_factor_spatial * patch_size``; the crop region is derived
    relative to a 720x480 base resolution.

    Args:
        height: Frame height in pixels.
        width: Frame width in pixels.
        num_frames: Temporal size passed to the embedding helper.
        device: Device the returned tensors are moved to.

    Returns:
        ``(freqs_cos, freqs_sin)`` on ``device``.
    """
    pixels_per_patch = self.vae_scale_factor_spatial * self.transformer.config.patch_size

    grid_size = (height // pixels_per_patch, width // pixels_per_patch)
    base_grid_width = 720 // pixels_per_patch
    base_grid_height = 480 // pixels_per_patch

    crop_region = get_resize_crop_region_for_grid(grid_size, base_grid_width, base_grid_height)
    rotary = get_3d_rotary_pos_embed(
        embed_dim=self.transformer.config.attention_head_dim,
        crops_coords=crop_region,
        grid_size=grid_size,
        temporal_size=num_frames,
    )

    freqs_cos, freqs_sin = rotary
    return freqs_cos.to(device=device), freqs_sin.to(device=device)
598
+
599
@property
def guidance_scale(self):
    """Read-only view of ``self._guidance_scale``."""
    return self._guidance_scale
602
+
603
@property
def num_timesteps(self):
    """Read-only view of ``self._num_timesteps``."""
    return self._num_timesteps
606
+
607
@property
def interrupt(self):
    """Read-only view of ``self._interrupt``."""
    return self._interrupt
610
+
611
+ @torch.no_grad()
612
+ @replace_example_docstring(EXAMPLE_DOC_STRING)
613
+ def __call__(
614
+ self,
615
+ image: PipelineImageInput,
616
+ prompt: Optional[Union[str, List[str]]] = None,
617
+ negative_prompt: Optional[Union[str, List[str]]] = None,
618
+ height: int = 480,
619
+ width: int = 720,
620
+ num_frames: int = 49,
621
+ num_inference_steps: int = 50,
622
+ timesteps: Optional[List[int]] = None,
623
+ guidance_scale: float = 6,
624
+ use_dynamic_cfg: bool = False,
625
+ num_videos_per_prompt: int = 1,
626
+ eta: float = 0.0,
627
+ generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,
628
+ latents: Optional[torch.FloatTensor] = None,
629
+ prompt_embeds: Optional[torch.FloatTensor] = None,
630
+ negative_prompt_embeds: Optional[torch.FloatTensor] = None,
631
+ output_type: str = "pil",
632
+ return_dict: bool = True,
633
+ callback_on_step_end: Optional[
634
+ Union[Callable[[int, int, Dict], None], PipelineCallback, MultiPipelineCallbacks]
635
+ ] = None,
636
+ callback_on_step_end_tensor_inputs: List[str] = ["latents"],
637
+ max_sequence_length: int = 226,
638
+ id_vit_hidden: Optional[torch.Tensor] = None,
639
+ id_cond: Optional[torch.Tensor] = None,
640
+ kps_cond: Optional[torch.Tensor] = None,
641
+ ) -> Union[CogVideoXPipelineOutput, Tuple]:
642
+ """
643
+ Function invoked when calling the pipeline for generation.
644
+
645
+ Args:
646
+ image (`PipelineImageInput`):
647
+ The input image to condition the generation on. Must be an image, a list of images or a `torch.Tensor`.
648
+ prompt (`str` or `List[str]`, *optional*):
649
+ The prompt or prompts to guide the image generation. If not defined, one has to pass `prompt_embeds`.
650
+ instead.
651
+ negative_prompt (`str` or `List[str]`, *optional*):
652
+ The prompt or prompts not to guide the image generation. If not defined, one has to pass
653
+ `negative_prompt_embeds` instead. Ignored when not using guidance (i.e., ignored if `guidance_scale` is
654
+ less than `1`).
655
+ height (`int`, *optional*, defaults to self.transformer.config.sample_height * self.vae_scale_factor_spatial):
656
+ The height in pixels of the generated image. This is set to 480 by default for the best results.
657
+ width (`int`, *optional*, defaults to self.transformer.config.sample_height * self.vae_scale_factor_spatial):
658
+ The width in pixels of the generated image. This is set to 720 by default for the best results.
659
+ num_frames (`int`, defaults to `49`):
660
+ Number of frames to generate. Must be divisible by self.vae_scale_factor_temporal. Generated video will
661
+ contain 1 extra frame because CogVideoX is conditioned with (num_seconds * fps + 1) frames where
662
+ num_seconds is 6 and fps is 4. However, since videos can be saved at any fps, the only condition that
663
+ needs to be satisfied is that of divisibility mentioned above.
664
+ num_inference_steps (`int`, *optional*, defaults to 50):
665
+ The number of denoising steps. More denoising steps usually lead to a higher quality image at the
666
+ expense of slower inference.
667
+ timesteps (`List[int]`, *optional*):
668
+ Custom timesteps to use for the denoising process with schedulers which support a `timesteps` argument
669
+ in their `set_timesteps` method. If not defined, the default behavior when `num_inference_steps` is
670
+ passed will be used. Must be in descending order.
671
+ guidance_scale (`float`, *optional*, defaults to 6):
672
+ Guidance scale as defined in [Classifier-Free Diffusion Guidance](https://arxiv.org/abs/2207.12598).
673
+ `guidance_scale` is defined as `w` of equation 2. of [Imagen
674
+ Paper](https://arxiv.org/pdf/2205.11487.pdf). Guidance scale is enabled by setting `guidance_scale >
675
+ 1`. Higher guidance scale encourages to generate images that are closely linked to the text `prompt`,
676
+ usually at the expense of lower image quality.
677
+ num_videos_per_prompt (`int`, *optional*, defaults to 1):
678
+ The number of videos to generate per prompt.
679
+ generator (`torch.Generator` or `List[torch.Generator]`, *optional*):
680
+ One or a list of [torch generator(s)](https://pytorch.org/docs/stable/generated/torch.Generator.html)
681
+ to make generation deterministic.
682
+ latents (`torch.FloatTensor`, *optional*):
683
+ Pre-generated noisy latents, sampled from a Gaussian distribution, to be used as inputs for image
684
+ generation. Can be used to tweak the same generation with different prompts. If not provided, a latents
685
+ tensor will be generated by sampling using the supplied random `generator`.
686
+ prompt_embeds (`torch.FloatTensor`, *optional*):
687
+ Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not
688
+ provided, text embeddings will be generated from `prompt` input argument.
689
+ negative_prompt_embeds (`torch.FloatTensor`, *optional*):
690
+ Pre-generated negative text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt
691
+ weighting. If not provided, negative_prompt_embeds will be generated from `negative_prompt` input
692
+ argument.
693
+ output_type (`str`, *optional*, defaults to `"pil"`):
694
+ The output format of the generate image. Choose between
695
+ [PIL](https://pillow.readthedocs.io/en/stable/): `PIL.Image.Image` or `np.array`.
696
+ return_dict (`bool`, *optional*, defaults to `True`):
697
+ Whether or not to return a [`~pipelines.cogvideo.pipeline_output.CogVideoXPipelineOutput`] instead
698
+ of a plain tuple.
699
+ callback_on_step_end (`Callable`, *optional*):
700
+ A function that calls at the end of each denoising steps during the inference. The function is called
701
+ with the following arguments: `callback_on_step_end(self: DiffusionPipeline, step: int, timestep: int,
702
+ callback_kwargs: Dict)`. `callback_kwargs` will include a list of all tensors as specified by
703
+ `callback_on_step_end_tensor_inputs`.
704
+ callback_on_step_end_tensor_inputs (`List`, *optional*):
705
+ The list of tensor inputs for the `callback_on_step_end` function. The tensors specified in the list
706
+ will be passed as `callback_kwargs` argument. You will only be able to include variables listed in the
707
+ `._callback_tensor_inputs` attribute of your pipeline class.
708
+ max_sequence_length (`int`, defaults to `226`):
709
+ Maximum sequence length in encoded prompt. Must be consistent with
710
+ `self.transformer.config.max_text_seq_length` otherwise may lead to poor results.
711
+
712
+ Examples:
713
+
714
+ Returns:
715
+ [`~pipelines.cogvideo.pipeline_output.CogVideoXPipelineOutput`] or `tuple`:
716
+ [`~pipelines.cogvideo.pipeline_output.CogVideoXPipelineOutput`] if `return_dict` is True, otherwise a
717
+ `tuple`. When returning a tuple, the first element is a list with the generated images.
718
+ """
719
+ if num_frames > 49:
720
+ raise ValueError(
721
+ "The number of frames must be less than 49 for now due to static positional embeddings. This will be updated in the future to remove this limitation."
722
+ )
723
+
724
+ if isinstance(callback_on_step_end, (PipelineCallback, MultiPipelineCallbacks)):
725
+ callback_on_step_end_tensor_inputs = callback_on_step_end.tensor_inputs
726
+
727
+ num_videos_per_prompt = 1
728
+
729
+ # 1. Check inputs. Raise error if not correct
730
+ self.check_inputs(
731
+ image=image,
732
+ prompt=prompt,
733
+ height=height,
734
+ width=width,
735
+ negative_prompt=negative_prompt,
736
+ callback_on_step_end_tensor_inputs=callback_on_step_end_tensor_inputs,
737
+ latents=latents,
738
+ prompt_embeds=prompt_embeds,
739
+ negative_prompt_embeds=negative_prompt_embeds,
740
+ )
741
+ self._guidance_scale = guidance_scale
742
+ self._interrupt = False
743
+
744
+ # 2. Default call parameters
745
+ if prompt is not None and isinstance(prompt, str):
746
+ batch_size = 1
747
+ elif prompt is not None and isinstance(prompt, list):
748
+ batch_size = len(prompt)
749
+ else:
750
+ batch_size = prompt_embeds.shape[0]
751
+
752
+ device = self._execution_device
753
+
754
+ # here `guidance_scale` is defined analog to the guidance weight `w` of equation (2)
755
+ # of the Imagen paper: https://arxiv.org/pdf/2205.11487.pdf . `guidance_scale = 1`
756
+ # corresponds to doing no classifier free guidance.
757
+ do_classifier_free_guidance = guidance_scale > 1.0
758
+
759
+ # 3. Encode input prompt
760
+ prompt_embeds, negative_prompt_embeds = self.encode_prompt(
761
+ prompt=prompt,
762
+ negative_prompt=negative_prompt,
763
+ do_classifier_free_guidance=do_classifier_free_guidance,
764
+ num_videos_per_prompt=num_videos_per_prompt,
765
+ prompt_embeds=prompt_embeds,
766
+ negative_prompt_embeds=negative_prompt_embeds,
767
+ max_sequence_length=max_sequence_length,
768
+ device=device,
769
+ )
770
+ if do_classifier_free_guidance:
771
+ prompt_embeds = torch.cat([negative_prompt_embeds, prompt_embeds], dim=0)
772
+
773
+ # 4. Prepare timesteps
774
+ timesteps, num_inference_steps = retrieve_timesteps(self.scheduler, num_inference_steps, device, timesteps)
775
+ self._num_timesteps = len(timesteps)
776
+
777
+ # 5. Prepare latents
778
+ if kps_cond is not None:
779
+ kps_cond = draw_kps(image, kps_cond)
780
+ kps_cond = self.video_processor.preprocess(kps_cond, height=height, width=width).to(
781
+ device, dtype=prompt_embeds.dtype
782
+ )
783
+
784
+ image = self.video_processor.preprocess(image, height=height, width=width).to(
785
+ device, dtype=prompt_embeds.dtype
786
+ )
787
+
788
+ latent_channels = self.transformer.config.in_channels // 2
789
+ latents, image_latents = self.prepare_latents(
790
+ image,
791
+ batch_size * num_videos_per_prompt,
792
+ latent_channels,
793
+ num_frames,
794
+ height,
795
+ width,
796
+ prompt_embeds.dtype,
797
+ device,
798
+ generator,
799
+ latents,
800
+ kps_cond
801
+ )
802
+
803
+ # 6. Prepare extra step kwargs. TODO: Logic should ideally just be moved out of the pipeline
804
+ extra_step_kwargs = self.prepare_extra_step_kwargs(generator, eta)
805
+
806
+ # 7. Create rotary embeds if required
807
+ image_rotary_emb = (
808
+ self._prepare_rotary_positional_embeddings(height, width, latents.size(1), device)
809
+ if self.transformer.config.use_rotary_positional_embeddings
810
+ else None
811
+ )
812
+
813
+ # 8. Denoising loop
814
+ num_warmup_steps = max(len(timesteps) - num_inference_steps * self.scheduler.order, 0)
815
+
816
+ with self.progress_bar(total=num_inference_steps) as progress_bar:
817
+ # for DPM-solver++
818
+ old_pred_original_sample = None
819
+ for i, t in enumerate(timesteps):
820
+ if self.interrupt:
821
+ continue
822
+
823
+ latent_model_input = torch.cat([latents] * 2) if do_classifier_free_guidance else latents
824
+ latent_model_input = self.scheduler.scale_model_input(latent_model_input, t)
825
+
826
+ latent_image_input = torch.cat([image_latents] * 2) if do_classifier_free_guidance else image_latents
827
+ latent_model_input = torch.cat([latent_model_input, latent_image_input], dim=2)
828
+
829
+ # broadcast to batch dimension in a way that's compatible with ONNX/Core ML
830
+ timestep = t.expand(latent_model_input.shape[0])
831
+
832
+ # predict noise model_output
833
+ noise_pred = self.transformer(
834
+ hidden_states=latent_model_input,
835
+ encoder_hidden_states=prompt_embeds,
836
+ timestep=timestep,
837
+ image_rotary_emb=image_rotary_emb,
838
+ return_dict=False,
839
+ id_vit_hidden = id_vit_hidden,
840
+ id_cond = id_cond,
841
+ )[0]
842
+ noise_pred = noise_pred.float()
843
+
844
+ # perform guidance
845
+ if use_dynamic_cfg:
846
+ self._guidance_scale = 1 + guidance_scale * (
847
+ (1 - math.cos(math.pi * ((num_inference_steps - t.item()) / num_inference_steps) ** 5.0)) / 2
848
+ )
849
+ if do_classifier_free_guidance:
850
+ noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)
851
+ noise_pred = noise_pred_uncond + self.guidance_scale * (noise_pred_text - noise_pred_uncond)
852
+
853
+ # compute the previous noisy sample x_t -> x_t-1
854
+ if not isinstance(self.scheduler, CogVideoXDPMScheduler):
855
+ latents = self.scheduler.step(noise_pred, t, latents, **extra_step_kwargs, return_dict=False)[0]
856
+ else:
857
+ latents, old_pred_original_sample = self.scheduler.step(
858
+ noise_pred,
859
+ old_pred_original_sample,
860
+ t,
861
+ timesteps[i - 1] if i > 0 else None,
862
+ latents,
863
+ **extra_step_kwargs,
864
+ return_dict=False,
865
+ )
866
+ latents = latents.to(prompt_embeds.dtype)
867
+
868
+ # call the callback, if provided
869
+ if callback_on_step_end is not None:
870
+ callback_kwargs = {}
871
+ for k in callback_on_step_end_tensor_inputs:
872
+ callback_kwargs[k] = locals()[k]
873
+ callback_outputs = callback_on_step_end(self, i, t, callback_kwargs)
874
+
875
+ latents = callback_outputs.pop("latents", latents)
876
+ prompt_embeds = callback_outputs.pop("prompt_embeds", prompt_embeds)
877
+ negative_prompt_embeds = callback_outputs.pop("negative_prompt_embeds", negative_prompt_embeds)
878
+
879
+ if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0):
880
+ progress_bar.update()
881
+
882
+ if not output_type == "latent":
883
+ video = self.decode_latents(latents)
884
+ video = self.video_processor.postprocess_video(video=video, output_type=output_type)
885
+ else:
886
+ video = latents
887
+
888
+ # Offload all models
889
+ self.maybe_free_model_hooks()
890
+
891
+ if not return_dict:
892
+ return (video,)
893
+
894
+ return CogVideoXPipelineOutput(frames=video)
models/transformer_consisid.py ADDED
@@ -0,0 +1,697 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright (c) Meta Platforms, Inc. and affiliates.
2
+ # All rights reserved.
3
+
4
+ # This source code is licensed under the license found in the
5
+ # LICENSE file in the root directory of this source tree.
6
+
7
+ from typing import Any, Dict, Optional, Tuple, Union
8
+ import os
9
+ import sys
10
+ import json
11
+ import glob
12
+
13
+ import torch
14
+ from torch import nn
15
+ from einops import rearrange, reduce
16
+
17
+ from diffusers.configuration_utils import ConfigMixin, register_to_config
18
+ from diffusers.loaders import PeftAdapterMixin
19
+ from diffusers.utils import USE_PEFT_BACKEND, is_torch_version, logging, scale_lora_layers, unscale_lora_layers
20
+ from diffusers.utils.torch_utils import maybe_allow_in_graph
21
+ from diffusers.models.attention import Attention, FeedForward
22
+ from diffusers.models.attention_processor import AttentionProcessor, CogVideoXAttnProcessor2_0, FusedCogVideoXAttnProcessor2_0
23
+ from diffusers.models.embeddings import CogVideoXPatchEmbed, TimestepEmbedding, Timesteps
24
+ from diffusers.models.modeling_outputs import Transformer2DModelOutput
25
+ from diffusers.models.modeling_utils import ModelMixin
26
+ from diffusers.models.normalization import AdaLayerNorm, CogVideoXLayerNormZero
27
+
28
+ import os
29
+ import sys
30
# Make this file's own directory importable so that the sibling module
# `local_facial_extractor` (imported just below) resolves no matter which
# working directory the interpreter was started from.
current_file_path = os.path.abspath(__file__)
project_roots = [os.path.dirname(current_file_path)]
for project_root in project_roots:
    # Prepend only when absent, so re-importing never piles up duplicates.
    if project_root not in sys.path:
        sys.path.insert(0, project_root)
34
+
35
+ from local_facial_extractor import LocalFacialExtractor, PerceiverCrossAttention
36
+
37
+ logger = logging.get_logger(__name__) # pylint: disable=invalid-name
38
+
39
+
40
@maybe_allow_in_graph
class CogVideoXBlock(nn.Module):
    r"""
    Transformer block used in [CogVideoX](https://github.com/THUDM/CogVideo) model.

    The block applies joint attention over video and text tokens followed by a shared feed-forward layer. Both
    sub-layers are wrapped in timestep-conditioned normalization (`CogVideoXLayerNormZero`) that also yields the
    gates used to scale each residual update.

    Parameters:
        dim (`int`):
            The number of channels in the input and output.
        num_attention_heads (`int`):
            The number of heads to use for multi-head attention.
        attention_head_dim (`int`):
            The number of channels in each head.
        time_embed_dim (`int`):
            The number of channels in timestep embedding.
        dropout (`float`, defaults to `0.0`):
            The dropout probability to use.
        activation_fn (`str`, defaults to `"gelu-approximate"`):
            Activation function to be used in feed-forward.
        attention_bias (`bool`, defaults to `False`):
            Whether or not to use bias in attention projection layers.
        qk_norm (`bool`, defaults to `True`):
            Whether or not to use normalization after query and key projections in Attention.
        norm_elementwise_affine (`bool`, defaults to `True`):
            Whether to use learnable elementwise affine parameters for normalization.
        norm_eps (`float`, defaults to `1e-5`):
            Epsilon value for normalization layers.
        final_dropout (`bool`, defaults to `True`):
            Whether to apply a final dropout after the last feed-forward layer.
        ff_inner_dim (`int`, *optional*, defaults to `None`):
            Custom hidden dimension of Feed-forward layer. If not provided, `4 * dim` is used.
        ff_bias (`bool`, defaults to `True`):
            Whether or not to use bias in Feed-forward layer.
        attention_out_bias (`bool`, defaults to `True`):
            Whether or not to use bias in Attention output projection layer.
    """

    def __init__(
        self,
        dim: int,
        num_attention_heads: int,
        attention_head_dim: int,
        time_embed_dim: int,
        dropout: float = 0.0,
        activation_fn: str = "gelu-approximate",
        attention_bias: bool = False,
        qk_norm: bool = True,
        norm_elementwise_affine: bool = True,
        norm_eps: float = 1e-5,
        final_dropout: bool = True,
        ff_inner_dim: Optional[int] = None,
        ff_bias: bool = True,
        attention_out_bias: bool = True,
    ):
        super().__init__()

        # NOTE: sub-modules are created in a fixed order (norm1, attn1, norm2, ff); keep it, since the
        # construction order determines how the random number generator is consumed during weight init.

        # 1. Self Attention
        self.norm1 = CogVideoXLayerNormZero(time_embed_dim, dim, norm_elementwise_affine, norm_eps, bias=True)

        self.attn1 = Attention(
            query_dim=dim,
            dim_head=attention_head_dim,
            heads=num_attention_heads,
            qk_norm="layer_norm" if qk_norm else None,
            # The attention epsilon is fixed here and does not follow `norm_eps`.
            eps=1e-6,
            bias=attention_bias,
            out_bias=attention_out_bias,
            processor=CogVideoXAttnProcessor2_0(),
        )

        # 2. Feed Forward
        self.norm2 = CogVideoXLayerNormZero(time_embed_dim, dim, norm_elementwise_affine, norm_eps, bias=True)

        self.ff = FeedForward(
            dim,
            dropout=dropout,
            activation_fn=activation_fn,
            final_dropout=final_dropout,
            inner_dim=ff_inner_dim,
            bias=ff_bias,
        )

    def forward(
        self,
        hidden_states: torch.Tensor,
        encoder_hidden_states: torch.Tensor,
        temb: torch.Tensor,
        image_rotary_emb: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        """
        Run one attention + feed-forward step over the video and text streams.

        Args:
            hidden_states (`torch.Tensor`):
                Video-token stream. The sequence axis is dimension 1.
            encoder_hidden_states (`torch.Tensor`):
                Text-token stream. Its size along dimension 1 is used as the text sequence length.
            temb (`torch.Tensor`):
                Timestep embedding passed to both conditioning norms.
            image_rotary_emb (`Tuple[torch.Tensor, torch.Tensor]`, *optional*):
                Rotary embedding forwarded unchanged to the attention layer.

        Returns:
            `Tuple[torch.Tensor, torch.Tensor]`: the updated `(hidden_states, encoder_hidden_states)`.
        """
        text_seq_length = encoder_hidden_states.size(1)

        # norm & modulate
        norm_hidden_states, norm_encoder_hidden_states, gate_msa, enc_gate_msa = self.norm1(
            hidden_states, encoder_hidden_states, temb
        )

        # attention
        attn_hidden_states, attn_encoder_hidden_states = self.attn1(
            hidden_states=norm_hidden_states,
            encoder_hidden_states=norm_encoder_hidden_states,
            image_rotary_emb=image_rotary_emb,
        )

        # Gated residual update; deliberately out-of-place so the caller's tensors are never modified.
        hidden_states = hidden_states + gate_msa * attn_hidden_states
        encoder_hidden_states = encoder_hidden_states + enc_gate_msa * attn_encoder_hidden_states

        # norm & modulate
        norm_hidden_states, norm_encoder_hidden_states, gate_ff, enc_gate_ff = self.norm2(
            hidden_states, encoder_hidden_states, temb
        )

        # feed-forward: text tokens are placed first along the sequence axis so that a single pass through
        # `self.ff` handles both streams; the result is split back at `text_seq_length`.
        norm_hidden_states = torch.cat([norm_encoder_hidden_states, norm_hidden_states], dim=1)
        ff_output = self.ff(norm_hidden_states)

        hidden_states = hidden_states + gate_ff * ff_output[:, text_seq_length:]
        encoder_hidden_states = encoder_hidden_states + enc_gate_ff * ff_output[:, :text_seq_length]

        return hidden_states, encoder_hidden_states
161
+
162
+
163
+ class ConsisIDTransformer3DModel(ModelMixin, ConfigMixin, PeftAdapterMixin):
164
+ """
165
+ A Transformer model for video-like data in [CogVideoX](https://github.com/THUDM/CogVideo).
166
+
167
+ Parameters:
168
+ num_attention_heads (`int`, defaults to `30`):
169
+ The number of heads to use for multi-head attention.
170
+ attention_head_dim (`int`, defaults to `64`):
171
+ The number of channels in each head.
172
+ in_channels (`int`, defaults to `16`):
173
+ The number of channels in the input.
174
+ out_channels (`int`, *optional*, defaults to `16`):
175
+ The number of channels in the output.
176
+ flip_sin_to_cos (`bool`, defaults to `True`):
177
+ Whether to flip the sin to cos in the time embedding.
178
+ time_embed_dim (`int`, defaults to `512`):
179
+ Output dimension of timestep embeddings.
180
+ text_embed_dim (`int`, defaults to `4096`):
181
+ Input dimension of text embeddings from the text encoder.
182
+ num_layers (`int`, defaults to `30`):
183
+ The number of layers of Transformer blocks to use.
184
+ dropout (`float`, defaults to `0.0`):
185
+ The dropout probability to use.
186
+ attention_bias (`bool`, defaults to `True`):
187
+ Whether or not to use bias in the attention projection layers.
188
+ sample_width (`int`, defaults to `90`):
189
+ The width of the input latents.
190
+ sample_height (`int`, defaults to `60`):
191
+ The height of the input latents.
192
+ sample_frames (`int`, defaults to `49`):
193
+ The number of frames in the input latents. Note that this parameter was incorrectly initialized to 49
194
+ instead of 13 because CogVideoX processed 13 latent frames at once in its default and recommended settings,
195
+ but cannot be changed to the correct value to ensure backwards compatibility. To create a transformer with
196
+ K latent frames, the correct value to pass here would be: ((K - 1) * temporal_compression_ratio + 1).
197
+ patch_size (`int`, defaults to `2`):
198
+ The size of the patches to use in the patch embedding layer.
199
+ temporal_compression_ratio (`int`, defaults to `4`):
200
+ The compression ratio across the temporal dimension. See documentation for `sample_frames`.
201
+ max_text_seq_length (`int`, defaults to `226`):
202
+ The maximum sequence length of the input text embeddings.
203
+ activation_fn (`str`, defaults to `"gelu-approximate"`):
204
+ Activation function to use in feed-forward.
205
+ timestep_activation_fn (`str`, defaults to `"silu"`):
206
+ Activation function to use when generating the timestep embeddings.
207
+ norm_elementwise_affine (`bool`, defaults to `True`):
208
+ Whether or not to use elementwise affine in normalization layers.
209
+ norm_eps (`float`, defaults to `1e-5`):
210
+ The epsilon value to use in normalization layers.
211
+ spatial_interpolation_scale (`float`, defaults to `1.875`):
212
+ Scaling factor to apply in 3D positional embeddings across spatial dimensions.
213
+ temporal_interpolation_scale (`float`, defaults to `1.0`):
214
+ Scaling factor to apply in 3D positional embeddings across temporal dimensions.
215
+ """
216
+
217
+ _supports_gradient_checkpointing = True
218
+
219
@register_to_config
def __init__(
    self,
    num_attention_heads: int = 30,
    attention_head_dim: int = 64,
    in_channels: int = 16,
    out_channels: Optional[int] = 16,
    flip_sin_to_cos: bool = True,
    freq_shift: int = 0,
    time_embed_dim: int = 512,
    text_embed_dim: int = 4096,
    num_layers: int = 30,
    dropout: float = 0.0,
    attention_bias: bool = True,
    sample_width: int = 90,
    sample_height: int = 60,
    sample_frames: int = 49,
    patch_size: int = 2,
    temporal_compression_ratio: int = 4,
    max_text_seq_length: int = 226,
    activation_fn: str = "gelu-approximate",
    timestep_activation_fn: str = "silu",
    norm_elementwise_affine: bool = True,
    norm_eps: float = 1e-5,
    spatial_interpolation_scale: float = 1.875,
    temporal_interpolation_scale: float = 1.0,
    use_rotary_positional_embeddings: bool = False,
    use_learned_positional_embeddings: bool = False,
    is_train_face: bool = False,
    is_kps: bool = False,
    cross_attn_interval: int = 1,
    LFE_num_tokens: int = 32,
    LFE_output_dim: int = 768,
    LFE_heads: int = 12,
    local_face_scale: float = 1.0,
):
    """
    Build the patch/time embeddings, the transformer blocks, the output head and, optionally, the face modules.

    Face-related arguments:
        is_train_face (`bool`, defaults to `False`):
            When `True`, the face-specific attributes are stored and `_init_face_inputs()` is called to create
            the local facial extractor and the perceiver cross-attention layers.
        is_kps (`bool`, defaults to `False`):
            Stored on the instance as `self.is_kps`; not otherwise used by this constructor.
        cross_attn_interval (`int`, defaults to `1`):
            `forward` applies a face cross-attention layer after every block whose index `i` satisfies
            `i % cross_attn_interval == 0`.
        LFE_num_tokens, LFE_output_dim, LFE_heads (`int`):
            Stored on the instance when `is_train_face` is set.
        local_face_scale (`float`, defaults to `1.0`):
            Factor applied to the face cross-attention output before it is added to the hidden states.

    Raises:
        ValueError: if learned positional embeddings are requested without rotary positional embeddings.
    """
    super().__init__()
    inner_dim = num_attention_heads * attention_head_dim

    if not use_rotary_positional_embeddings and use_learned_positional_embeddings:
        raise ValueError(
            "There are no CogVideoX checkpoints available with disable rotary embeddings and learned positional "
            "embeddings. If you're using a custom model and/or believe this should be supported, please open an "
            "issue at https://github.com/huggingface/diffusers/issues."
        )

    # 1. Patch embedding
    self.patch_embed = CogVideoXPatchEmbed(
        patch_size=patch_size,
        in_channels=in_channels,
        embed_dim=inner_dim,
        text_embed_dim=text_embed_dim,
        bias=True,
        sample_width=sample_width,
        sample_height=sample_height,
        sample_frames=sample_frames,
        temporal_compression_ratio=temporal_compression_ratio,
        max_text_seq_length=max_text_seq_length,
        spatial_interpolation_scale=spatial_interpolation_scale,
        temporal_interpolation_scale=temporal_interpolation_scale,
        use_positional_embeddings=not use_rotary_positional_embeddings,
        use_learned_positional_embeddings=use_learned_positional_embeddings,
    )
    self.embedding_dropout = nn.Dropout(dropout)

    # 2. Time embeddings
    self.time_proj = Timesteps(inner_dim, flip_sin_to_cos, freq_shift)
    self.time_embedding = TimestepEmbedding(inner_dim, time_embed_dim, timestep_activation_fn)

    # 3. Define spatio-temporal transformers blocks
    self.transformer_blocks = nn.ModuleList(
        [
            CogVideoXBlock(
                dim=inner_dim,
                num_attention_heads=num_attention_heads,
                attention_head_dim=attention_head_dim,
                time_embed_dim=time_embed_dim,
                dropout=dropout,
                activation_fn=activation_fn,
                attention_bias=attention_bias,
                norm_elementwise_affine=norm_elementwise_affine,
                norm_eps=norm_eps,
            )
            for _ in range(num_layers)
        ]
    )
    self.norm_final = nn.LayerNorm(inner_dim, norm_eps, norm_elementwise_affine)

    # 4. Output blocks
    self.norm_out = AdaLayerNorm(
        embedding_dim=time_embed_dim,
        output_dim=2 * inner_dim,
        norm_elementwise_affine=norm_elementwise_affine,
        norm_eps=norm_eps,
        chunk_dim=1,
    )
    self.proj_out = nn.Linear(inner_dim, patch_size * patch_size * out_channels)

    self.gradient_checkpointing = False

    self.is_train_face = is_train_face
    self.is_kps = is_kps

    if is_train_face:
        self.inner_dim = inner_dim
        self.cross_attn_interval = cross_attn_interval
        # `forward` consumes one cross-attention layer for every block index i in [0, num_layers) with
        # `i % cross_attn_interval == 0`, i.e. ceil(num_layers / cross_attn_interval) layers. Floor division
        # allocated one layer too few whenever num_layers was not a multiple of the interval, which ended in
        # an IndexError during the forward pass. The count is unchanged for evenly divisible configurations.
        self.num_ca = (num_layers + cross_attn_interval - 1) // cross_attn_interval
        self.LFE_num_tokens = LFE_num_tokens
        self.LFE_output_dim = LFE_output_dim
        self.LFE_heads = LFE_heads
        # Key/value width handed to the perceiver cross-attention layers: two thirds of the model width.
        self.LFE_final_output_dim = int(self.inner_dim / 3 * 2)
        self.local_face_scale = local_face_scale
        self._init_face_inputs()
332
+
333
+ def _set_gradient_checkpointing(self, module, value=False):
334
+ self.gradient_checkpointing = value
335
+
336
def _init_face_inputs(self):
    """
    Create the identity-conditioning modules and place them next to the transformer blocks.

    Builds `self.local_facial_extractor` and a `ModuleList` of `self.num_ca` perceiver cross-attention layers,
    each moved to the model's device and to the dtype of the first transformer-block parameter.
    """
    target_device = self.device
    target_dtype = next(self.transformer_blocks.parameters()).dtype

    # NOTE(review): the extractor is built with its own defaults; the LFE_* values stored on the instance
    # are not forwarded here - confirm whether that is intended.
    self.local_facial_extractor = LocalFacialExtractor()
    self.local_facial_extractor.to(target_device, dtype=target_dtype)

    cross_attention_layers = nn.ModuleList()
    for _ in range(self.num_ca):
        layer = PerceiverCrossAttention(
            dim=self.inner_dim,
            dim_head=128,
            heads=16,
            kv_dim=self.LFE_final_output_dim,
        )
        cross_attention_layers.append(layer.to(target_device, dtype=target_dtype))
    self.perceiver_cross_attention = cross_attention_layers
344
+
345
def save_face_modules(self, path: str):
    """
    Write the weights of the face modules to `path` with `torch.save`.

    The saved object is a dict with two entries: `'local_facial_extractor'` (one state dict) and
    `'perceiver_cross_attention'` (a list with one state dict per cross-attention layer, in layer order).

    Args:
        path: Destination handed to `torch.save`.
    """
    extractor_state = self.local_facial_extractor.state_dict()

    cross_attention_states = []
    for layer in self.perceiver_cross_attention:
        cross_attention_states.append(layer.state_dict())

    torch.save(
        {
            'local_facial_extractor': extractor_state,
            'perceiver_cross_attention': cross_attention_states,
        },
        path,
    )
351
+
352
def load_face_modules(self, path: str):
    """
    Restore the face modules from a checkpoint laid out like the one `save_face_modules` writes.

    The checkpoint must be a dict with `'local_facial_extractor'` (a state dict) and
    `'perceiver_cross_attention'` (a list of state dicts, one per cross-attention layer).

    Args:
        path: Checkpoint location handed to `torch.load`; tensors are mapped onto `self.device`.

    Raises:
        ValueError: if the checkpoint holds a different number of cross-attention state dicts than this
            model has cross-attention layers.
    """
    # NOTE(security): `torch.load` unpickles the file; only load checkpoints from a trusted source.
    checkpoint = torch.load(path, map_location=self.device)

    cross_attention_states = checkpoint['perceiver_cross_attention']
    expected = len(self.perceiver_cross_attention)
    # Validate before touching any weights: `zip` would otherwise stop at the shorter sequence and silently
    # leave some layers at their current values (or drop checkpoint entries).
    if len(cross_attention_states) != expected:
        raise ValueError(
            f"Checkpoint contains {len(cross_attention_states)} perceiver cross-attention state dicts, "
            f"but the model has {expected} cross-attention layers."
        )

    self.local_facial_extractor.load_state_dict(checkpoint['local_facial_extractor'])
    for ca, state_dict in zip(self.perceiver_cross_attention, cross_attention_states):
        ca.load_state_dict(state_dict)
357
+
358
@property
# Adapted from diffusers.models.unets.unet_2d_condition.UNet2DConditionModel.attn_processors
def attn_processors(self) -> Dict[str, AttentionProcessor]:
    r"""
    Returns:
        `dict` of attention processors: every attention processor used in the model, keyed by the dotted path
        of the owning module followed by `.processor`.
    """

    def iter_processors(prefix: str, module):
        # Depth-first, parent before children, so keys appear in module registration order.
        if hasattr(module, "get_processor"):
            yield f"{prefix}.processor", module.get_processor()
        for child_name, child in module.named_children():
            yield from iter_processors(f"{prefix}.{child_name}", child)

    collected = {}
    for top_name, top_module in self.named_children():
        collected.update(iter_processors(top_name, top_module))
    return collected
382
+
383
+ # Copied from diffusers.models.unets.unet_2d_condition.UNet2DConditionModel.set_attn_processor
384
def set_attn_processor(self, processor: Union[AttentionProcessor, Dict[str, AttentionProcessor]]):
    r"""
    Sets the attention processor to use to compute attention.

    Parameters:
        processor (`dict` of `AttentionProcessor` or only `AttentionProcessor`):
            The instantiated processor class or a dictionary of processor classes that will be set as the processor
            for **all** `Attention` layers.

            If `processor` is a dict, the key needs to define the path to the corresponding cross attention
            processor. This is strongly recommended when setting trainable attention processors.

    Raises:
        ValueError: if a dict is passed whose size differs from the number of attention processors in the model.
        KeyError: if a dict is passed that lacks the `"<module path>.processor"` key of some attention layer.

    The dict passed by the caller is left unmodified.
    """
    count = len(self.attn_processors.keys())

    if isinstance(processor, dict):
        if len(processor) != count:
            raise ValueError(
                f"A dict of processors was passed, but the number of processors {len(processor)} does not match the"
                f" number of attention layers: {count}. Please make sure to pass {count} processor classes."
            )
        # Entries are popped as they are assigned, so work on a shallow copy. Consuming the caller's dict
        # would, for example, leave `original_attn_processors` empty after the first
        # `unfuse_qkv_projections()` call and make every later call fail the size check above.
        processor = dict(processor)

    def fn_recursive_attn_processor(name: str, module: torch.nn.Module, processor):
        if hasattr(module, "set_processor"):
            if not isinstance(processor, dict):
                module.set_processor(processor)
            else:
                module.set_processor(processor.pop(f"{name}.processor"))

        for sub_name, child in module.named_children():
            fn_recursive_attn_processor(f"{name}.{sub_name}", child, processor)

    for name, module in self.named_children():
        fn_recursive_attn_processor(name, module, processor)
417
+
418
+ # Copied from diffusers.models.unets.unet_2d_condition.UNet2DConditionModel.fuse_qkv_projections with FusedAttnProcessor2_0->FusedCogVideoXAttnProcessor2_0
419
def fuse_qkv_projections(self):
    """
    Enables fused QKV projections. For self-attention modules, all projection matrices (i.e., query, key, value)
    are fused. For cross-attention modules, key and value projection matrices are fused.

    The processors that were active beforehand are remembered in `self.original_attn_processors`, which
    `unfuse_qkv_projections()` reads to restore them.

    <Tip warning={true}>

    This API is 🧪 experimental.

    </Tip>

    Raises:
        ValueError: if any active attention processor's class name contains `"Added"`.
    """
    self.original_attn_processors = None

    has_added_kv = any(
        "Added" in str(attn_processor.__class__.__name__) for attn_processor in self.attn_processors.values()
    )
    if has_added_kv:
        raise ValueError("`fuse_qkv_projections()` is not supported for models having added KV projections.")

    # Snapshot the current processors before they are replaced below.
    self.original_attn_processors = self.attn_processors

    # Filter lazily: the module tree is walked while the layers are being fused, as in a plain loop.
    attention_layers = (candidate for candidate in self.modules() if isinstance(candidate, Attention))
    for layer in attention_layers:
        layer.fuse_projections(fuse=True)

    self.set_attn_processor(FusedCogVideoXAttnProcessor2_0())
443
+
444
+ # Copied from diffusers.models.unets.unet_2d_condition.UNet2DConditionModel.unfuse_qkv_projections
445
def unfuse_qkv_projections(self):
    """Disables the fused QKV projection if enabled.

    Restores the processors recorded in `self.original_attn_processors`. Calling this when
    `fuse_qkv_projections()` was never called, or when no processors were recorded, is a no-op.

    <Tip warning={true}>

    This API is 🧪 experimental.

    </Tip>

    """
    # `original_attn_processors` only exists after `fuse_qkv_projections()` has run; a plain attribute
    # access would raise AttributeError on a model that was never fused.
    original_attn_processors = getattr(self, "original_attn_processors", None)
    if original_attn_processors is not None:
        self.set_attn_processor(original_attn_processors)
457
+
458
+ def forward(
459
+ self,
460
+ hidden_states: torch.Tensor,
461
+ encoder_hidden_states: torch.Tensor,
462
+ timestep: Union[int, float, torch.LongTensor],
463
+ timestep_cond: Optional[torch.Tensor] = None,
464
+ image_rotary_emb: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,
465
+ attention_kwargs: Optional[Dict[str, Any]] = None,
466
+ id_cond: Optional[torch.Tensor] = None,
467
+ id_vit_hidden: Optional[torch.Tensor] = None,
468
+ return_dict: bool = True,
469
+ ):
470
+ # fuse clip and insightface
471
+ if self.is_train_face:
472
+ assert id_cond is not None and id_vit_hidden is not None
473
+ valid_face_emb = self.local_facial_extractor(id_cond, id_vit_hidden) # torch.Size([1, 1280]), list[5](torch.Size([1, 577, 1024])) -> torch.Size([1, 32, 2048])
474
+
475
+ if attention_kwargs is not None:
476
+ attention_kwargs = attention_kwargs.copy()
477
+ lora_scale = attention_kwargs.pop("scale", 1.0)
478
+ else:
479
+ lora_scale = 1.0
480
+
481
+ if USE_PEFT_BACKEND:
482
+ # weight the lora layers by setting `lora_scale` for each PEFT layer
483
+ scale_lora_layers(self, lora_scale)
484
+ else:
485
+ if attention_kwargs is not None and attention_kwargs.get("scale", None) is not None:
486
+ logger.warning(
487
+ "Passing `scale` via `attention_kwargs` when not using the PEFT backend is ineffective."
488
+ )
489
+
490
+ batch_size, num_frames, channels, height, width = hidden_states.shape
491
+
492
+ # 1. Time embedding
493
+ timesteps = timestep
494
+ t_emb = self.time_proj(timesteps)
495
+
496
+ # timesteps does not contain any weights and will always return f32 tensors
497
+ # but time_embedding might actually be running in fp16. so we need to cast here.
498
+ # there might be better ways to encapsulate this.
499
+ t_emb = t_emb.to(dtype=hidden_states.dtype)
500
+ emb = self.time_embedding(t_emb, timestep_cond)
501
+
502
+ # 2. Patch embedding
503
+ # torch.Size([1, 226, 4096]) torch.Size([1, 13, 32, 60, 90])
504
+ hidden_states = self.patch_embed(encoder_hidden_states, hidden_states) # torch.Size([1, 17776, 3072])
505
+ hidden_states = self.embedding_dropout(hidden_states) # torch.Size([1, 17776, 3072])
506
+
507
+ text_seq_length = encoder_hidden_states.shape[1]
508
+ encoder_hidden_states = hidden_states[:, :text_seq_length] # torch.Size([1, 226, 3072])
509
+ hidden_states = hidden_states[:, text_seq_length:] # torch.Size([1, 17550, 3072])
510
+
511
+ # 3. Transformer blocks
512
+ ca_idx = 0
513
+ for i, block in enumerate(self.transformer_blocks):
514
+ if self.training and self.gradient_checkpointing:
515
+
516
+ def create_custom_forward(module):
517
+ def custom_forward(*inputs):
518
+ return module(*inputs)
519
+
520
+ return custom_forward
521
+
522
+ ckpt_kwargs: Dict[str, Any] = {"use_reentrant": False} if is_torch_version(">=", "1.11.0") else {}
523
+ hidden_states, encoder_hidden_states = torch.utils.checkpoint.checkpoint(
524
+ create_custom_forward(block),
525
+ hidden_states,
526
+ encoder_hidden_states,
527
+ emb,
528
+ image_rotary_emb,
529
+ **ckpt_kwargs,
530
+ )
531
+ else:
532
+ hidden_states, encoder_hidden_states = block(
533
+ hidden_states=hidden_states,
534
+ encoder_hidden_states=encoder_hidden_states,
535
+ temb=emb,
536
+ image_rotary_emb=image_rotary_emb,
537
+ )
538
+
539
+ if self.is_train_face:
540
+ if i % self.cross_attn_interval == 0 and valid_face_emb is not None:
541
+ hidden_states = hidden_states + self.local_face_scale * self.perceiver_cross_attention[ca_idx](valid_face_emb, hidden_states) # torch.Size([2, 32, 2048]) torch.Size([2, 17550, 3072])
542
+ ca_idx += 1
543
+
544
+ if not self.config.use_rotary_positional_embeddings:
545
+ # CogVideoX-2B
546
+ hidden_states = self.norm_final(hidden_states)
547
+ else:
548
+ # CogVideoX-5B
549
+ hidden_states = torch.cat([encoder_hidden_states, hidden_states], dim=1)
550
+ hidden_states = self.norm_final(hidden_states)
551
+ hidden_states = hidden_states[:, text_seq_length:]
552
+
553
+ # 4. Final block
554
+ hidden_states = self.norm_out(hidden_states, temb=emb)
555
+ hidden_states = self.proj_out(hidden_states)
556
+
557
+ # 5. Unpatchify
558
+ # Note: we use `-1` instead of `channels`:
559
+ # - It is okay to `channels` use for CogVideoX-2b and CogVideoX-5b (number of input channels is equal to output channels)
560
+ # - However, for CogVideoX-5b-I2V also takes concatenated input image latents (number of input channels is twice the output channels)
561
+ p = self.config.patch_size
562
+ output = hidden_states.reshape(batch_size, num_frames, height // p, width // p, -1, p, p)
563
+ output = output.permute(0, 1, 4, 2, 5, 3, 6).flatten(5, 6).flatten(3, 4)
564
+
565
+ if USE_PEFT_BACKEND:
566
+ # remove `lora_scale` from each PEFT layer
567
+ unscale_lora_layers(self, lora_scale)
568
+
569
+ if not return_dict:
570
+ return (output,)
571
+ return Transformer2DModelOutput(sample=output)
572
+
573
@classmethod
def from_pretrained_cus(cls, pretrained_model_path, subfolder=None, config_path=None, transformer_additional_kwargs=None):
    """Instantiate the model from ``config.json`` and load the checkpoint weights that fit it.

    Args:
        pretrained_model_path: Directory holding the checkpoint (joined with
            ``subfolder`` when one is given).
        subfolder: Optional sub-directory containing both config and weights.
        config_path: Optional directory to read ``config.json`` from instead
            of ``pretrained_model_path``.
        transformer_additional_kwargs: Extra keyword arguments forwarded to
            ``cls.from_config``. ``None`` (the default) means no extras.

    Returns:
        The model, with every checkpoint tensor whose name and shape match
        loaded non-strictly.

    Raises:
        RuntimeError: If ``config.json`` is missing, or if no weight file can
            be found in the checkpoint directory.
    """
    # A literal ``{}`` default would be one dict object shared by every call.
    if transformer_additional_kwargs is None:
        transformer_additional_kwargs = {}

    if subfolder:
        config_path = config_path or pretrained_model_path
        config_file = os.path.join(config_path, subfolder, 'config.json')
        pretrained_model_path = os.path.join(pretrained_model_path, subfolder)
    else:
        config_file = os.path.join(config_path or pretrained_model_path, 'config.json')

    print(f"Loading 3D transformer's pretrained weights from {pretrained_model_path} ...")

    # Check if config file exists
    if not os.path.isfile(config_file):
        raise RuntimeError(f"Configuration file '{config_file}' does not exist")

    # Load the configuration
    with open(config_file, "r") as f:
        config = json.load(f)

    from diffusers.utils import WEIGHTS_NAME
    model = cls.from_config(config, **transformer_additional_kwargs)
    model_file = os.path.join(pretrained_model_path, WEIGHTS_NAME)
    model_file_safetensors = model_file.replace(".bin", ".safetensors")
    if os.path.exists(model_file):
        # NOTE(review): torch.load unpickles the file; only point this at
        # trusted checkpoints (consider `weights_only=True` if supported).
        state_dict = torch.load(model_file, map_location="cpu")
    elif os.path.exists(model_file_safetensors):
        from safetensors.torch import load_file
        state_dict = load_file(model_file_safetensors)
    else:
        # Fall back to merging every *.safetensors file in the directory.
        from safetensors.torch import load_file
        state_dict = {}
        for shard_path in glob.glob(os.path.join(pretrained_model_path, "*.safetensors")):
            state_dict.update(load_file(shard_path))
        if not state_dict:
            raise RuntimeError(f"No checkpoint weights found in '{pretrained_model_path}'")

    # state_dict() rebuilds the whole mapping on each call, so build it once.
    # Its tensors share storage with the parameters, which is what makes the
    # in-place writes below land in the model.
    model_state = model.state_dict()
    patch_key = 'patch_embed.proj.weight'
    target_weight = model_state[patch_key]
    if target_weight.size() != state_dict[patch_key].size():
        new_shape = target_weight.size()
        if len(new_shape) == 5:
            # Expand along a new dim 2 and keep only the last slice non-zero.
            state_dict[patch_key] = state_dict[patch_key].unsqueeze(2).expand(new_shape).clone()
            state_dict[patch_key][:, :, :-1] = 0
        else:
            loaded_channels = state_dict[patch_key].size()[1]
            if new_shape[1] > loaded_channels:
                # Model has more input channels: copy what exists, zero the rest.
                target_weight[:, :loaded_channels, :, :] = state_dict[patch_key]
                target_weight[:, loaded_channels:, :, :] = 0
            else:
                # Model has fewer input channels: keep only the leading ones.
                target_weight[:, :, :, :] = state_dict[patch_key][:, :new_shape[1], :, :]
            state_dict[patch_key] = target_weight

    # Keep only tensors the model knows about and whose shapes agree.
    compatible_state = {}
    for key, tensor in state_dict.items():
        if key in model_state and model_state[key].size() == tensor.size():
            compatible_state[key] = tensor
        else:
            print(key, "Size don't match, skip")
    state_dict = compatible_state

    m, u = model.load_state_dict(state_dict, strict=False)
    print(f"### missing keys: {len(m)}; \n### unexpected keys: {len(u)};")
    print(m)

    params = [p.numel() if "mamba" in n else 0 for n, p in model.named_parameters()]
    print(f"### Mamba Parameters: {sum(params) / 1e6} M")

    params = [p.numel() if "attn1." in n else 0 for n, p in model.named_parameters()]
    print(f"### attn1 Parameters: {sum(params) / 1e6} M")

    return model
643
+
644
if __name__ == '__main__':
    # Smoke test: load the transformer and run one forward pass on constant inputs.
    device = "cuda:0"
    weight_dtype = torch.bfloat16
    # NOTE(review): from_pretrained_cus resolves this with os.path.join /
    # os.path.isfile, so it appears to need a local directory of this name
    # rather than a Hub repo id - confirm.
    pretrained_model_name_or_path = "BestWishYsh/ConsisID-preview"

    transformer_additional_kwargs = {
        'torch_dtype': weight_dtype,
        'revision': None,
        'variant': None,
        'is_train_face': True,
        'is_kps': False,
        'LFE_num_tokens': 32,
        'LFE_output_dim': 768,
        'LFE_heads': 12,
        'cross_attn_interval': 2,
    }

    transformer = ConsisIDTransformer3DModel.from_pretrained_cus(
        pretrained_model_name_or_path,
        subfolder="transformer",
        transformer_additional_kwargs=transformer_additional_kwargs,
    )

    transformer.to(device, dtype=weight_dtype)
    for param in transformer.parameters():
        param.requires_grad = False
    transformer.eval()

    b = 1
    dim = 32
    # Only the tensors actually passed to the model are allocated.
    noisy_latents = torch.ones(b, 13, dim, 60, 90).to(device, dtype=weight_dtype)
    prompt_embeds = torch.ones(b, 226, 4096).to(device, dtype=weight_dtype)
    image_rotary_emb = (torch.ones(17550, 64).to(device, dtype=weight_dtype), torch.ones(17550, 64).to(device, dtype=weight_dtype))
    timesteps = torch.tensor([311]).to(device, dtype=weight_dtype)
    id_vit_hidden = [torch.ones([1, 577, 1024]).to(device, dtype=weight_dtype)] * 5
    id_cond = torch.ones(b, 1280).to(device, dtype=weight_dtype)
    assert len(timesteps) == b

    model_output = transformer(
        hidden_states=noisy_latents,
        encoder_hidden_states=prompt_embeds,
        timestep=timesteps,
        image_rotary_emb=image_rotary_emb,
        return_dict=False,
        id_vit_hidden=id_vit_hidden,
        id_cond=id_cond,
    )[0]

    print(model_output)
    # transformer.save_pretrained(os.path.join("./test_ckpt", "transformer"))
697
+
models/utils.py ADDED
@@ -0,0 +1,273 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import cv2
2
+ import math
3
+ import numpy as np
4
+ from PIL import Image
5
+
6
+ import torch
7
+ from torchvision.transforms import InterpolationMode
8
+ from torchvision.transforms.functional import normalize, resize
9
+ from transformers import T5EncoderModel, T5Tokenizer
10
+ from typing import List, Optional, Tuple, Union
11
+ from diffusers.models.embeddings import get_3d_rotary_pos_embed
12
+ from diffusers.pipelines.cogvideo.pipeline_cogvideox import get_resize_crop_region_for_grid
13
+
14
+
15
def tensor_to_pil(src_img_tensor):
    """Convert a channel-first image tensor into a PIL image.

    The tensor is copied and detached, moved to CPU, transposed with
    ``(1, 2, 0)`` and cast to ``uint8`` without any rescaling or clipping.
    """
    chw = src_img_tensor.clone().detach()
    # bfloat16 is upcast before the ``.numpy()`` call.
    if chw.dtype == torch.bfloat16:
        chw = chw.to(torch.float32)
    hwc = np.transpose(chw.cpu().numpy(), (1, 2, 0))
    return Image.fromarray(hwc.astype(np.uint8))
24
+
25
+
26
+ def _get_t5_prompt_embeds(
27
+ tokenizer: T5Tokenizer,
28
+ text_encoder: T5EncoderModel,
29
+ prompt: Union[str, List[str]],
30
+ num_videos_per_prompt: int = 1,
31
+ max_sequence_length: int = 226,
32
+ device: Optional[torch.device] = None,
33
+ dtype: Optional[torch.dtype] = None,
34
+ text_input_ids=None,
35
+ ):
36
+ prompt = [prompt] if isinstance(prompt, str) else prompt
37
+ batch_size = len(prompt)
38
+
39
+ if tokenizer is not None:
40
+ text_inputs = tokenizer(
41
+ prompt,
42
+ padding="max_length",
43
+ max_length=max_sequence_length,
44
+ truncation=True,
45
+ add_special_tokens=True,
46
+ return_tensors="pt",
47
+ )
48
+ text_input_ids = text_inputs.input_ids
49
+ else:
50
+ if text_input_ids is None:
51
+ raise ValueError("`text_input_ids` must be provided when the tokenizer is not specified.")
52
+
53
+ prompt_embeds = text_encoder(text_input_ids.to(device))[0]
54
+ prompt_embeds = prompt_embeds.to(dtype=dtype, device=device)
55
+
56
+ # duplicate text embeddings for each generation per prompt, using mps friendly method
57
+ _, seq_len, _ = prompt_embeds.shape
58
+ prompt_embeds = prompt_embeds.repeat(1, num_videos_per_prompt, 1)
59
+ prompt_embeds = prompt_embeds.view(batch_size * num_videos_per_prompt, seq_len, -1)
60
+
61
+ return prompt_embeds
62
+
63
+
64
def encode_prompt(
    tokenizer: T5Tokenizer,
    text_encoder: T5EncoderModel,
    prompt: Union[str, List[str]],
    num_videos_per_prompt: int = 1,
    max_sequence_length: int = 226,
    device: Optional[torch.device] = None,
    dtype: Optional[torch.dtype] = None,
    text_input_ids=None,
):
    """Wrap a single prompt in a list and delegate to ``_get_t5_prompt_embeds``.

    All other arguments are forwarded unchanged; the helper's result is returned as is.
    """
    prompt_list = [prompt] if isinstance(prompt, str) else prompt
    return _get_t5_prompt_embeds(
        tokenizer,
        text_encoder,
        prompt=prompt_list,
        num_videos_per_prompt=num_videos_per_prompt,
        max_sequence_length=max_sequence_length,
        device=device,
        dtype=dtype,
        text_input_ids=text_input_ids,
    )
86
+
87
+
88
def compute_prompt_embeddings(
    tokenizer, text_encoder, prompt, max_sequence_length, device, dtype, requires_grad: bool = False
):
    """Encode ``prompt`` via ``encode_prompt`` with one video per prompt.

    With ``requires_grad=False`` (the default) the call runs inside
    ``torch.no_grad()``; otherwise it runs under the caller's grad mode.
    """
    encode_kwargs = dict(
        num_videos_per_prompt=1,
        max_sequence_length=max_sequence_length,
        device=device,
        dtype=dtype,
    )
    if requires_grad:
        return encode_prompt(tokenizer, text_encoder, prompt, **encode_kwargs)
    with torch.no_grad():
        return encode_prompt(tokenizer, text_encoder, prompt, **encode_kwargs)
113
+
114
+
115
def prepare_rotary_positional_embeddings(
    height: int,
    width: int,
    num_frames: int,
    vae_scale_factor_spatial: int = 8,
    patch_size: int = 2,
    attention_head_dim: int = 64,
    device: Optional[torch.device] = None,
    base_height: int = 480,
    base_width: int = 720,
) -> Tuple[torch.Tensor, torch.Tensor]:
    """Build the 3D rotary embedding pair (cos, sin) for a given video size.

    Pixel sizes are floor-divided by ``vae_scale_factor_spatial * patch_size``
    to get the grid; the crop region comes from
    ``get_resize_crop_region_for_grid`` and the frequencies from
    ``get_3d_rotary_pos_embed``. Both results are moved to ``device``.
    """
    # Pixels covered by one grid cell along each spatial axis.
    cell = vae_scale_factor_spatial * patch_size
    grid = (height // cell, width // cell)
    base_grid_width = base_width // cell
    base_grid_height = base_height // cell

    crop_region = get_resize_crop_region_for_grid(grid, base_grid_width, base_grid_height)
    cos_part, sin_part = get_3d_rotary_pos_embed(
        embed_dim=attention_head_dim,
        crops_coords=crop_region,
        grid_size=grid,
        temporal_size=num_frames,
    )

    return cos_part.to(device=device), sin_part.to(device=device)
142
+
143
+
144
def img2tensor(imgs, bgr2rgb=True, float32=True):
    """Convert HWC numpy image(s) to CHW torch tensor(s).

    Args:
        imgs (list[ndarray] | ndarray): Input images.
        bgr2rgb (bool): Swap BGR to RGB (applied only to 3-channel images).
        float32 (bool): Cast the resulting tensor to float32.

    Returns:
        list[tensor] | tensor: A list when ``imgs`` is a list (one tensor per
            image), otherwise a single tensor.
    """

    def _convert_one(array):
        if array.shape[2] == 3 and bgr2rgb:
            # float64 input is downcast before the colour conversion.
            if array.dtype == 'float64':
                array = array.astype('float32')
            array = cv2.cvtColor(array, cv2.COLOR_BGR2RGB)
        tensor = torch.from_numpy(array.transpose(2, 0, 1))
        return tensor.float() if float32 else tensor

    if not isinstance(imgs, list):
        return _convert_one(imgs)
    return [_convert_one(array) for array in imgs]
170
+
171
+
172
def to_gray(img):
    """Return a 3-channel grayscale copy of a 4-D batch.

    Channels 0/1/2 along dim 1 are weighted 0.299/0.587/0.114 and the
    single-channel result is repeated three times along dim 1.
    """
    first, second, third = img[:, 0:1], img[:, 1:2], img[:, 2:3]
    luma = 0.299 * first + 0.587 * second + 0.114 * third
    return luma.repeat(1, 3, 1, 1)
176
+
177
+
178
def draw_kps(image_pil, kps, color_list=[(255,0,0), (0,255,0), (0,0,255), (255,255,0), (255,0,255)]):
    """Render keypoints on a black canvas the size of ``image_pil``.

    Keypoints 0, 1, 3 and 4 are each joined to keypoint 2 by a filled
    ellipse; the canvas is then scaled by 0.6 and every keypoint is drawn as
    a filled circle of radius 10 in its own colour.

    Returns:
        A PIL image built from the resulting uint8 canvas.
    """
    stick_half_width = 4
    limb_pairs = np.array([[0, 2], [1, 2], [3, 2], [4, 2]])
    kps = np.array(kps)

    width, height = image_pil.size
    canvas = np.zeros([height, width, 3])

    for pair in limb_pairs:
        limb_color = color_list[pair[0]]

        xs = kps[pair][:, 0]
        ys = kps[pair][:, 1]
        limb_length = ((xs[0] - xs[1]) ** 2 + (ys[0] - ys[1]) ** 2) ** 0.5
        limb_angle = math.degrees(math.atan2(ys[0] - ys[1], xs[0] - xs[1]))
        center = (int(np.mean(xs)), int(np.mean(ys)))
        axes = (int(limb_length / 2), stick_half_width)
        outline = cv2.ellipse2Poly(center, axes, int(limb_angle), 0, 360, 1)
        canvas = cv2.fillConvexPoly(canvas.copy(), outline, limb_color)
    # Limbs are dimmed; the circles drawn afterwards keep full intensity.
    canvas = (canvas * 0.6).astype(np.uint8)

    for point_idx, point in enumerate(kps):
        point_color = color_list[point_idx]
        px, py = point
        canvas = cv2.circle(canvas.copy(), (int(px), int(py)), 10, point_color, -1)

    return Image.fromarray(canvas.astype(np.uint8))
205
+
206
+
207
def process_face_embeddings(face_helper, clip_vision_model, handler_ante, eva_transform_mean, eva_transform_std, app, device, weight_dtype, image, original_id_image=None, is_align_face=True, cal_uncond=False):
    """
    Build the identity conditioning for one reference image.

    Args:
        image: numpy rgb image, range [0, 255]
        original_id_image: RGB image used instead of the aligned crop when
            `is_align_face` is False.
        is_align_face: When True, the aligned crop is parsed and background
            labels are replaced by white; when False, `original_id_image`
            is used unchanged.
        cal_uncond: Accepted but not used by this function.

    Returns:
        (id_cond, id_vit_hidden, face_image, face_kps): the concatenation of
        the face-recognition embedding with the L2-normalised vision
        embedding, the vision model's hidden states, the (optionally
        background-masked) colour face tensor, and the face keypoints.

    Raises:
        RuntimeError: If facexlib finds no face to align.
    """
    face_helper.clean_all()
    image_bgr = cv2.cvtColor(image, cv2.COLOR_RGB2BGR)  # (724, 502, 3)
    # get antelopev2 embedding
    face_info = app.get(image_bgr)
    if len(face_info) > 0:
        face_info = sorted(face_info, key=lambda x: (x['bbox'][2] - x['bbox'][0]) * (x['bbox'][3] - x['bbox'][1]))[
            -1
        ]  # only use the maximum face
        id_ante_embedding = face_info['embedding']  # (512,)
        face_kps = face_info['kps']
    else:
        id_ante_embedding = None
        face_kps = None

    # using facexlib to detect and align face
    face_helper.read_image(image_bgr)
    face_helper.get_face_landmarks_5(only_center_face=True)
    if face_kps is None:
        # Guard before indexing: with no landmarks the lookup below would
        # raise IndexError instead of the RuntimeError used for this failure.
        if len(face_helper.all_landmarks_5) == 0:
            raise RuntimeError('facexlib align face fail')
        face_kps = face_helper.all_landmarks_5[0]
    face_helper.align_warp_face()
    if len(face_helper.cropped_faces) == 0:
        raise RuntimeError('facexlib align face fail')
    align_face = face_helper.cropped_faces[0]  # (512, 512, 3)  # RGB

    # incase insightface didn't detect face
    if id_ante_embedding is None:
        print('fail to detect face using insightface, extract embedding on align face')
        id_ante_embedding = handler_ante.get_feat(align_face)

    id_ante_embedding = torch.from_numpy(id_ante_embedding).to(device, weight_dtype)  # torch.Size([512])
    if id_ante_embedding.ndim == 1:
        id_ante_embedding = id_ante_embedding.unsqueeze(0)  # torch.Size([1, 512])

    # parsing
    if is_align_face:
        face_tensor = img2tensor(align_face, bgr2rgb=True).unsqueeze(0) / 255.0  # torch.Size([1, 3, 512, 512])
        face_tensor = face_tensor.to(device)
        parsing_out = face_helper.face_parse(normalize(face_tensor, [0.485, 0.456, 0.406], [0.229, 0.224, 0.225]))[0]
        parsing_out = parsing_out.argmax(dim=1, keepdim=True)  # torch.Size([1, 1, 512, 512])
        # Parsing labels treated as background (meanings defined by the parser).
        bg_label = [0, 16, 18, 7, 8, 9, 14, 15]
        bg = sum(parsing_out == i for i in bg_label).bool()
        white_image = torch.ones_like(face_tensor)  # torch.Size([1, 3, 512, 512])
        # only keep the face features
        return_face_features_image = torch.where(bg, white_image, to_gray(face_tensor))  # torch.Size([1, 3, 512, 512])
        return_face_features_image_2 = torch.where(bg, white_image, face_tensor)  # torch.Size([1, 3, 512, 512])
    else:
        original_image_bgr = cv2.cvtColor(original_id_image, cv2.COLOR_RGB2BGR)
        face_tensor = img2tensor(original_image_bgr, bgr2rgb=True).unsqueeze(0) / 255.0  # torch.Size([1, 3, 512, 512])
        face_tensor = face_tensor.to(device)
        return_face_features_image = return_face_features_image_2 = face_tensor

    # transform img before sending to eva-clip-vit
    face_features_image = resize(return_face_features_image, clip_vision_model.image_size,
                                 InterpolationMode.BICUBIC)  # torch.Size([1, 3, 336, 336])
    face_features_image = normalize(face_features_image, eva_transform_mean, eva_transform_std)
    id_cond_vit, id_vit_hidden = clip_vision_model(face_features_image.to(weight_dtype), return_all_features=False, return_hidden=True, shuffle=False)  # torch.Size([1, 768]),  list(torch.Size([1, 577, 1024]))
    # L2-normalise the vision embedding along dim 1.
    id_cond_vit_norm = torch.norm(id_cond_vit, 2, 1, True)
    id_cond_vit = torch.div(id_cond_vit, id_cond_vit_norm)

    id_cond = torch.cat([id_ante_embedding, id_cond_vit], dim=-1)  # torch.Size([1, 512]), torch.Size([1, 768])  ->  torch.Size([1, 1280])

    return id_cond, id_vit_hidden, return_face_features_image_2, face_kps  # torch.Size([1, 1280]), list(torch.Size([1, 577, 1024]))
requirements.txt ADDED
@@ -0,0 +1,36 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ torch==2.5.1
2
+ torchaudio==2.5.1
3
+ torchvision==0.20.1
4
+ xformers==0.0.28.post3
5
+ onnx==1.17.0
6
+ onnxruntime-gpu==1.19.2
7
+ deepspeed==0.15.2
8
+ accelerate==1.1.1
9
+ diffusers==0.31.0
10
+ transformers==4.46.3
11
+ tokenizers==0.20.1
12
+ peft==0.12.0
13
+ decord==0.6.0
14
+ sentencepiece==0.2.0
15
+ opencv-python==4.10.0.84
16
+ pyfacer==0.0.4
17
+ numpy==1.26.4
18
+ numba==0.60.0
19
+ insightface==0.7.3
20
+ huggingface-hub==0.26.1
21
+ facexlib==0.3.0
22
+ timm==1.0.9
23
+ func_timeout==4.3.5
24
+ tensorboard==2.17.1
25
+ gradio==5.6.0
26
+ spaces==0.30.4
27
+ pillow==10.4.0
28
+ spandrel==0.4.0
29
+ scikit-video==1.1.11
30
+ moviepy
31
+ wandb
32
+ imageio-ffmpeg
33
+ ftfy
34
+ Jinja2
35
+ einops
36
+ nvitop
util/dataloader.py ADDED
@@ -0,0 +1,1010 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ import gc
3
+ import cv2
4
+ import json
5
+ import math
6
+ import decord
7
+ import random
8
+ import numpy as np
9
+ from PIL import Image
10
+ from tqdm import tqdm
11
+ from decord import VideoReader
12
+ from contextlib import contextmanager
13
+ from func_timeout import FunctionTimedOut
14
+ from typing import Optional, Sized, Iterator
15
+
16
+ import torch
17
+ from torch.utils.data import Dataset, Sampler
18
+ import torch.nn.functional as F
19
+ from torchvision.transforms import ToPILImage
20
+ from torchvision import transforms
21
+ from accelerate.logging import get_logger
22
+
23
+ logger = get_logger(__name__)
24
+
25
+ import threading
26
+ log_lock = threading.Lock()
27
+
28
def log_error_to_file(error_message, video_path):
    """Append one error record to ``error_log.txt`` while holding ``log_lock``.

    The record is the error line, the video path line and a 50-dash separator.
    """
    with log_lock, open("error_log.txt", "a") as log_file:
        log_file.writelines((
            f"Error: {error_message}\n",
            f"Video Path: {video_path}\n",
            "-" * 50 + "\n",
        ))
34
+
35
def draw_kps(image_pil, kps, color_list=[(255,0,0), (0,255,0), (0,0,255), (255,255,0), (255,0,255)]):
    """Draw keypoints as a stick figure on a black image sized like ``image_pil``.

    Each of keypoints 0, 1, 3, 4 is connected to keypoint 2 with a filled
    ellipse. After all limbs are drawn the canvas is multiplied by 0.6, then
    a filled radius-10 circle is drawn at every keypoint.

    Returns:
        The canvas as a PIL image (uint8).
    """
    limb_thickness = 4
    connections = np.array([[0, 2], [1, 2], [3, 2], [4, 2]])
    kps = np.array(kps)

    img_w, img_h = image_pil.size
    board = np.zeros([img_h, img_w, 3])

    for ends in connections:
        stroke = color_list[ends[0]]

        end_x = kps[ends][:, 0]
        end_y = kps[ends][:, 1]
        dx = end_x[0] - end_x[1]
        dy = end_y[0] - end_y[1]
        span = (dx ** 2 + dy ** 2) ** 0.5
        tilt = math.degrees(math.atan2(dy, dx))
        midpoint = (int(np.mean(end_x)), int(np.mean(end_y)))
        ellipse_pts = cv2.ellipse2Poly(midpoint, (int(span / 2), limb_thickness), int(tilt), 0, 360, 1)
        board = cv2.fillConvexPoly(board.copy(), ellipse_pts, stroke)
    # Dim the limbs once, before the keypoint circles are added.
    board = (board * 0.6).astype(np.uint8)

    for kp_index, kp in enumerate(kps):
        fill = color_list[kp_index]
        cx, cy = kp
        board = cv2.circle(board.copy(), (int(cx), int(cy)), 10, fill, -1)

    return Image.fromarray(board.astype(np.uint8))
62
+
63
@contextmanager
def VideoReader_contextmanager(*args, **kwargs):
    """Yield a ``VideoReader`` built from the given arguments.

    On exit (normal or via exception) the local reference is deleted and a
    garbage-collection pass is forced.
    """
    reader = VideoReader(*args, **kwargs)
    try:
        yield reader
    finally:
        # Drop our reference first so the collection pass can reclaim it.
        del reader
        gc.collect()
71
+
72
def get_valid_segments(valid_frame, tolerance=5):
    """Group the union of face and head frame indices into runs.

    Args:
        valid_frame: Mapping with ``'face'`` and ``'head'`` iterables of
            frame indices.
        tolerance: Largest gap between consecutive sorted indices that still
            keeps them in the same segment.

    Returns:
        A list of segments (each a sorted list of indices), or ``[]`` when
        there are no indices at all.
    """
    positions = sorted(set(valid_frame['face']).union(valid_frame['head']))
    # No frames at all: indexing positions[0] would raise IndexError.
    if not positions:
        return []

    segments = [[positions[0]]]
    for previous, current in zip(positions, positions[1:]):
        if current - previous <= tolerance:
            segments[-1].append(current)
        else:
            segments.append([current])
    return segments
89
+
90
+
91
def get_frame_indices_adjusted_for_face(valid_frames, n_frames):
    """Return exactly ``n_frames`` indices drawn from ``valid_frames``.

    If there are enough frames the first ``n_frames`` are returned. Otherwise
    the list is padded by cycling through ``valid_frames`` from its start and
    the combined list is returned sorted.
    """
    available = len(valid_frames)
    if available >= n_frames:
        return valid_frames[:n_frames]

    # Cycle through the existing frames to fill the shortfall.
    padding = [valid_frames[k % available] for k in range(n_frames - available)]
    return sorted(valid_frames + padding)
107
+
108
+
109
def generate_frame_indices_for_face(n_frames, sample_stride, valid_frame, tolerance=7, skip_frames_start_percent=0.0, skip_frames_end_percent=1.0, skip_frames_start=0, skip_frames_end=0):
    """Sample ``n_frames`` frame indices from the longest valid segment.

    The longest segment from ``get_valid_segments`` is used. A sampling
    window inside it is set by the percent arguments when either differs from
    its default, otherwise by the absolute skip counts, otherwise it is the
    whole segment. A random start position is drawn from that window and
    frames are taken every ``sample_stride`` until one is missing from the
    segment; any shortfall is filled by
    ``get_frame_indices_adjusted_for_face``.

    Returns:
        (indices, count): the frame indices, and the number of frames
        available before any padding by repetition.

    Raises:
        ValueError: If the skip settings leave an empty sampling window.
    """
    segment = max(get_valid_segments(valid_frame, tolerance), key=len)
    segment_len = len(segment)

    if skip_frames_start_percent != 0.0 or skip_frames_end_percent != 1.0:
        # print("use skip frame percent")
        window_start = int(segment_len * skip_frames_start_percent)
        window_end = int(segment_len * skip_frames_end_percent)
    elif skip_frames_start != 0 or skip_frames_end != 0:
        # print("use skip frame")
        window_start, window_end = skip_frames_start, segment_len - skip_frames_end
    else:
        # print("no use skip frame")
        window_start, window_end = 0, segment_len

    # Segment no longer than the request: use it whole (padded if needed).
    if segment_len <= n_frames:
        return get_frame_indices_adjusted_for_face(segment, n_frames), segment_len

    window_len = window_end - window_start
    if window_len <= 0:
        print(f"video_length: {segment_len}, adjusted_length: {window_len}, valid_start:{window_start}, skip_frames_end: {window_end}")
        raise ValueError("Skipping too many frames results in no frames left to sample.")

    clip_len = min(window_len, (n_frames - 1) * sample_stride + 1)
    start_position = random.randint(window_start, window_end - clip_len)
    start_frame = segment[start_position]

    # Walk forward by the stride and stop at the first frame not in the segment.
    picked = []
    for step in range(n_frames):
        candidate = start_frame + step * sample_stride
        if candidate not in segment:
            break
        picked.append(candidate)

    if len(picked) < n_frames:
        return get_frame_indices_adjusted_for_face(picked, n_frames), len(picked)

    return picked, len(picked)
151
+
152
def frame_has_required_confidence(bbox_data, frame, ID, conf_threshold=0.88):
    """Check that ``frame`` has both a face and a head detection for track ``ID``.

    A detection counts when its ``'confidence'`` is strictly greater than
    ``conf_threshold`` and its ``'new_track_id'`` equals ``ID``. Frames are
    looked up in ``bbox_data`` by ``str(frame)``; a missing frame gives False.
    """
    frame_key = str(frame)
    if frame_key not in bbox_data:
        return False
    detections = bbox_data[frame_key]

    def _confident_match(kind):
        for item in detections.get(kind, []):
            if item['confidence'] > conf_threshold and item['new_track_id'] == ID:
                return True
        return False

    # Both kinds are evaluated before combining, as in the original.
    face_ok = _confident_match('face')
    head_ok = _confident_match('head')
    return face_ok and head_ok
170
+
171
+ def select_mask_frames_from_index(batch_frame, original_batch_frame, valid_id, corresponding_data, control_sam2_frame,
172
+ valid_frame, bbox_data, base_dir, min_distance=3, min_frames=1, max_frames=5,
173
+ mask_type='face', control_mask_type='head', dense_masks=False,
174
+ ensure_control_frame=True):
175
+ """
176
+ Selects frames with corresponding mask images while ensuring a minimum distance constraint between frames,
177
+ and that the frames exist in both batch_frame and valid_frame.
178
+
179
+ Parameters:
180
+ base_path (str): Base directory where the JSON files and mask results are located.
181
+ min_distance (int): Minimum distance between selected frames.
182
+ min_frames (int): Minimum number of frames to select.
183
+ max_frames (int): Maximum number of frames to select.
184
+ mask_type (str): Type of mask to select frames for ('face' or 'head').
185
+ control_mask_type (str): Type of mask used for control frame selection ('face' or 'head').
186
+
187
+ Returns:
188
+ dict: A dictionary where keys are IDs and values are lists of selected mask PNG paths.
189
+ """
190
+ # Helper function to randomly select frames with at least X frames apart
191
+ def select_frames_with_distance_constraint(frames, num_frames, min_distance, control_frame, bbox_data, ID,
192
+ ensure_control_frame=True, fallback=True):
193
+ """
194
+ Selects frames with a minimum distance constraint. If not enough frames can be selected, a fallback plan is applied.
195
+
196
+ Parameters:
197
+ frames (list): List of frame indices to select from.
198
+ num_frames (int): Number of frames to select.
199
+ min_distance (int): Minimum distance between selected frames.
200
+ control_frame (int): The control frame that must always be included.
201
+ fallback (bool): Whether to apply a fallback strategy if not enough frames meet the distance constraint.
202
+
203
+ Returns:
204
+ list: List of selected frames.
205
+ """
206
+ conf_thresholds = [0.95, 0.94, 0.93, 0.92, 0.91, 0.90]
207
+ if ensure_control_frame:
208
+ selected_frames = [control_frame] # Ensure control frame is always included
209
+ else:
210
+ valid_initial_frames = []
211
+ for conf_threshold in conf_thresholds:
212
+ valid_initial_frames = [
213
+ f for f in frames
214
+ if frame_has_required_confidence(bbox_data, f, ID, conf_threshold=conf_threshold)
215
+ ]
216
+ if valid_initial_frames:
217
+ break
218
+ if valid_initial_frames:
219
+ selected_frames = [random.choice(valid_initial_frames)]
220
+ else:
221
+ # If no frame meets the initial confidence, fall back to a random frame (or handle as per your preference)
222
+ selected_frames = [random.choice(frames)]
223
+
224
+ available_frames = [f for f in frames if f != selected_frames[0]] # Exclude control frame for random selection
225
+
226
+ random.shuffle(available_frames) # Shuffle to introduce randomness
227
+
228
+ while available_frames and len(selected_frames) < num_frames:
229
+ last_selected_frame = selected_frames[-1]
230
+
231
+ valid_choices = []
232
+ for conf_threshold in conf_thresholds:
233
+ valid_choices = [
234
+ f for f in available_frames
235
+ if abs(f - last_selected_frame) >= min_distance and
236
+ frame_has_required_confidence(bbox_data, f, ID, conf_threshold=conf_threshold)
237
+ ]
238
+ if valid_choices:
239
+ break
240
+
241
+ if valid_choices:
242
+ frame = random.choice(valid_choices)
243
+ available_frames.remove(frame)
244
+ selected_frames.append(frame)
245
+ else:
246
+ if fallback:
247
+ # Fallback strategy: uniformly distribute remaining frames if distance constraint cannot be met
248
+ remaining_needed = num_frames - len(selected_frames)
249
+ remaining_frames = available_frames[:remaining_needed]
250
+
251
+ # Distribute the remaining frames evenly if possible
252
+ if remaining_frames:
253
+ step = max(1, len(remaining_frames) // remaining_needed)
254
+ evenly_selected = remaining_frames[::step][:remaining_needed]
255
+ selected_frames.extend(evenly_selected)
256
+ break
257
+ else:
258
+ break # No valid choices remain and no fallback strategy is allowed
259
+
260
+ if len(selected_frames) < num_frames:
261
+ return None
262
+
263
+ return selected_frames
264
+
265
+ # Convert batch_frame list to a set to remove duplicates
266
+ batch_frame_set = set(batch_frame)
267
+
268
+ # Dictionary to store selected mask PNGs
269
+ selected_masks_dict = {}
270
+ selected_bboxs_dict = {}
271
+ dense_masks_dict = {}
272
+ selected_frames_dict = {}
273
+
274
+ # ID
275
+ try:
276
+ mask_valid_frames = valid_frame[mask_type] # Select frames based on the specified mask type
277
+ control_valid_frames = valid_frame[control_mask_type] # Control frames for control_mask_type
278
+ except KeyError:
279
+ if mask_type not in valid_frame.keys():
280
+ print(f"no valid {mask_type}")
281
+ if control_mask_type not in valid_frame.keys():
282
+ print(f"no valid {control_mask_type}")
283
+
284
+ # Get the control frame for the control mask type
285
+ control_frame = control_sam2_frame[valid_id][control_mask_type]
286
+
287
+ # Filter frames to only those which are in both valid_frame and batch_frame_set
288
+ valid_frames = []
289
+ # valid_frames = [frame for frame in mask_valid_frames if frame in control_valid_frames and frame in batch_frame_set]
290
+ for frame in mask_valid_frames:
291
+ if frame in control_valid_frames and frame in batch_frame_set:
292
+ # Check if bbox_data has 'head' or 'face' for the frame
293
+ if str(frame) in bbox_data:
294
+ frame_data = bbox_data[str(frame)]
295
+ if 'head' in frame_data or 'face' in frame_data:
296
+ valid_frames.append(frame)
297
+
298
+ # Ensure the control frame is included in the valid frames
299
+ if ensure_control_frame and (control_frame not in valid_frames):
300
+ valid_frames.append(control_frame)
301
+
302
+ # Select a random number of frames between min_frames and max_frames
303
+ num_frames_to_select = random.randint(min_frames, max_frames)
304
+ selected_frames = select_frames_with_distance_constraint(valid_frames, num_frames_to_select, min_distance,
305
+ control_frame, bbox_data, valid_id, ensure_control_frame)
306
+
307
+ # Store the selected frames as mask PNGs and bbox
308
+ selected_masks_dict[valid_id] = []
309
+ selected_bboxs_dict[valid_id] = []
310
+
311
+ # Initialize the dense_masks_dict entry for the current ID
312
+ dense_masks_dict[valid_id] = []
313
+
314
+ # Store the selected frames in the dictionary
315
+ selected_frames_dict[valid_id] = selected_frames
316
+
317
+ if dense_masks:
318
+ for frame in original_batch_frame:
319
+ mask_data_path = f"{base_dir}/{valid_id}/annotated_frame_{int(frame):05d}.png"
320
+ mask_array = np.array(Image.open(mask_data_path))
321
+ binary_mask = np.where(mask_array > 0, 1, 0).astype(np.uint8)
322
+ dense_masks_dict[valid_id].append(binary_mask)
323
+
324
+ for frame in selected_frames:
325
+ mask_data_path = f"{base_dir}/{valid_id}/annotated_frame_{frame:05d}.png"
326
+ mask_array = np.array(Image.open(mask_data_path))
327
+ binary_mask = np.where(mask_array > 0, 1, 0).astype(np.uint8)
328
+ selected_masks_dict[valid_id].append(binary_mask)
329
+
330
+ try:
331
+ for item in bbox_data[f"{frame}"]["head"]:
332
+ if item['new_track_id'] == int(valid_id):
333
+ temp_bbox = item['box']
334
+ break
335
+ except (KeyError, StopIteration):
336
+ try:
337
+ for item in bbox_data[f"{frame}"]["face"]:
338
+ if item['new_track_id'] == int(valid_id):
339
+ temp_bbox = item['box']
340
+ break
341
+ except (KeyError, StopIteration):
342
+ temp_bbox = None
343
+
344
+ selected_bboxs_dict[valid_id].append(temp_bbox)
345
+
346
+ return selected_frames_dict, selected_masks_dict, selected_bboxs_dict, dense_masks_dict
347
+
348
def pad_tensor(tensor, target_size, dim=0):
    """Zero-pad or truncate ``tensor`` so dimension ``dim`` has ``target_size`` entries.

    Args:
        tensor (torch.Tensor): Tensor to resize along one dimension.
        target_size (int): Desired length of dimension ``dim``.
        dim (int): Dimension to pad or truncate. Defaults to 0.

    Returns:
        torch.Tensor: ``tensor`` followed by zeros (same dtype and device) along
        ``dim`` when it is shorter than ``target_size``; otherwise the first
        ``target_size`` entries of ``tensor`` along ``dim``.
    """
    padding_size = target_size - tensor.size(dim)
    if padding_size > 0:
        pad_shape = list(tensor.shape)
        pad_shape[dim] = padding_size
        padding_tensor = torch.zeros(pad_shape, dtype=tensor.dtype, device=tensor.device)
        return torch.cat([tensor, padding_tensor], dim=dim)

    # Truncate along the requested dimension. The previous implementation
    # always sliced dimension 0, which was wrong for any other ``dim``.
    index = [slice(None)] * tensor.dim()
    index[dim] = slice(None, target_size)
    return tensor[tuple(index)]
357
+
358
def _square_span(center, side_length, limit):
    """Return ``(start, end)`` of a span of ``side_length`` around ``center``, clamped to ``[0, limit]``."""
    start = max(0, center - side_length // 2)
    end = min(limit, start + side_length)
    # If clamping at the far edge shortened the span, slide the start back
    # (never below 0) to recover as much of the requested length as possible.
    if end - start < side_length and start != 0:
        start = max(0, end - side_length)
    return start, end


def crop_images(selected_frame_index, selected_bboxs_dict, video_reader, return_ori=False):
    """
    Crop square face regions out of video frames.

    For every ``(frame index, bbox)`` pair the bounding box is first turned into
    a square (side = longer bbox edge, centred on the bbox, clamped to the
    frame), then that square is grown by 20% on each side, clamped to the frame
    and trimmed back to a square.

    Args:
        selected_frame_index (list): Frame indices to crop; paired with
            ``selected_bboxs_dict`` via ``zip``.
        selected_bboxs_dict (list of dict): One dict per frame with the keys
            'x1', 'y1', 'x2', 'y2' (values are converted with ``int``).
        video_reader: Object indexable by frame index that returns a frame
            laid out as (H, W, C) and convertible with ``np.array``.
        return_ori (bool): When True, also return the un-expanded square crops.

    Returns:
        tuple: ``(expanded_images, original_images)``. ``expanded_images`` is a
        list of RGB PIL images of the expanded crops. ``original_images`` is a
        list of RGB PIL images of the tight square crops when ``return_ori`` is
        True, otherwise ``None``.
    """
    expansion_ratio = 0.2  # Grow the square crop by 20% of its side on every edge
    expanded_cropped_images = []
    original_cropped_images = []

    for frame_idx, bbox in zip(selected_frame_index, selected_bboxs_dict):
        frame = video_reader[frame_idx]  # (H, W, C)
        frame_height, frame_width = frame.shape[0], frame.shape[1]

        x1, y1, x2, y2 = (int(bbox[key]) for key in ('x1', 'y1', 'x2', 'y2'))
        side_length = max(x2 - x1, y2 - y1)

        # Tight square centred on the bounding box
        new_x1, new_x2 = _square_span((x1 + x2) // 2, side_length, frame_width)
        new_y1, new_y2 = _square_span((y1 + y2) // 2, side_length, frame_height)

        # Expanded region, clamped to the frame
        expansion_amount = int(side_length * expansion_ratio)
        expanded_x1 = max(0, new_x1 - expansion_amount)
        expanded_y1 = max(0, new_y1 - expansion_amount)
        expanded_x2 = min(frame_width, new_x2 + expansion_amount)
        expanded_y2 = min(frame_height, new_y2 + expansion_amount)

        # Trim the longer edge so the expanded region stays square
        expanded_width = expanded_x2 - expanded_x1
        expanded_height = expanded_y2 - expanded_y1
        final_side_length = min(expanded_width, expanded_height)
        if expanded_width > expanded_height:
            expanded_x2 = expanded_x1 + final_side_length
        elif expanded_height > expanded_width:
            expanded_y2 = expanded_y1 + final_side_length

        expanded_crop = frame[expanded_y1:expanded_y2, expanded_x1:expanded_x2, :]
        expanded_cropped_images.append(Image.fromarray(np.array(expanded_crop)).convert('RGB'))

        if return_ori:
            original_crop = frame[new_y1:new_y2, new_x1:new_x2, :]
            original_cropped_images.append(Image.fromarray(np.array(original_crop)).convert('RGB'))

    # Return only after every frame has been processed, so that both lists
    # contain one crop per requested frame.
    if return_ori:
        return expanded_cropped_images, original_cropped_images
    return expanded_cropped_images, None
445
+
446
def process_cropped_images(expand_images_pil, original_images_pil, target_size=(480, 480)):
    """
    Resize cropped PIL images and stack them into float tensors.

    Parameters:
        expand_images_pil (list of PIL.Image): Expanded crops.
        original_images_pil (list of PIL.Image or None): Tight crops paired
            with ``expand_images_pil``. ``None`` or an empty list means that
            only the expanded crops are processed.
        target_size (tuple of int): Size passed to ``Image.resize``
            (width, height), default is (480, 480).

    Returns:
        tuple: ``(expand_face_imgs, original_face_imgs)``. Each tensor has the
        layout [N, C, H, W] and dtype float32 with the pixel values unchanged
        (no normalisation). ``original_face_imgs`` is ``None`` when no original
        crops were supplied.
    """
    def _to_batch(images):
        # Convert a list of PIL images into a single [N, C, H, W] float tensor.
        if not images:
            # NOTE(review): assumes 3-channel (RGB) crops, which is what
            # crop_images produces; an empty batch lets callers test len() == 0
            # instead of hitting torch.cat on an empty list.
            return torch.empty((0, 3, target_size[1], target_size[0]), dtype=torch.float32)
        batch = []
        for img in images:
            resized = np.array(img.resize(target_size, Image.LANCZOS))
            chw = np.transpose(resized, (2, 0, 1))  # (H, W, C) -> (C, H, W)
            batch.append(torch.from_numpy(chw).unsqueeze(0).float())
        return torch.cat(batch, dim=0)

    if not original_images_pil:
        return _to_batch(list(expand_images_pil)), None

    # Keep the pairing behaviour of zip(): extra images on either side are dropped.
    pairs = list(zip(expand_images_pil, original_images_pil))
    expand_face_imgs = _to_batch([expand_img for expand_img, _ in pairs])
    original_face_imgs = _to_batch([original_img for _, original_img in pairs])
    return expand_face_imgs, original_face_imgs
486
+
487
class RandomSampler(Sampler[int]):
    r"""Samples elements randomly. If without replacement, then sample from a shuffled dataset.

    If with replacement, then user can specify :attr:`num_samples` to draw.

    When sampling without replacement, the sampler keeps a cursor
    (``_pos_start``) that is advanced as items are consumed. If ``_pos_start``
    is set before iteration starts, the first permutation is entered at that
    offset, so only the remaining ``n - _pos_start`` items of that pass are
    yielded. The cursor is reset to 0 at the end of each full pass.

    Args:
        data_source (Dataset): dataset to sample from
        replacement (bool): samples are drawn on-demand with replacement if ``True``, default=``False``
        num_samples (int): number of samples to draw, default=`len(dataset)`.
        generator (Generator): Generator used in sampling.

    Raises:
        TypeError: if ``replacement`` is not a bool.
        ValueError: if the effective ``num_samples`` is not a positive integer.
    """

    data_source: Sized
    replacement: bool

    def __init__(self, data_source: Sized, replacement: bool = False,
                 num_samples: Optional[int] = None, generator=None) -> None:
        self.data_source = data_source
        self.replacement = replacement
        self._num_samples = num_samples
        self.generator = generator
        self._pos_start = 0  # Resume cursor used by the without-replacement path

        if not isinstance(self.replacement, bool):
            raise TypeError(f"replacement should be a boolean value, but got replacement={self.replacement}")

        if not isinstance(self.num_samples, int) or self.num_samples <= 0:
            raise ValueError(f"num_samples should be a positive integer value, but got num_samples={self.num_samples}")

    @property
    def num_samples(self) -> int:
        # dataset size might change at runtime
        if self._num_samples is None:
            return len(self.data_source)
        return self._num_samples

    def __iter__(self) -> Iterator[int]:
        n = len(self.data_source)
        if self.generator is None:
            # Draw a fresh seed so every iteration gets its own private generator
            seed = int(torch.empty((), dtype=torch.int64).random_().item())
            generator = torch.Generator()
            generator.manual_seed(seed)
        else:
            generator = self.generator

        if self.replacement:
            # Draw in chunks of 32, then the remainder
            for _ in range(self.num_samples // 32):
                yield from torch.randint(high=n, size=(32,), dtype=torch.int64, generator=generator).tolist()
            yield from torch.randint(high=n, size=(self.num_samples % 32,), dtype=torch.int64, generator=generator).tolist()
        else:
            for _ in range(self.num_samples // n):
                permutation = torch.randperm(n, generator=generator).tolist()
                # The dataset may have shrunk since the cursor was stored
                if self._pos_start >= n:
                    self._pos_start = 0
                for idx in range(self._pos_start, n):
                    yield permutation[idx]
                    # Runs when the consumer asks for the next item, i.e. the
                    # cursor counts items that have actually been consumed.
                    self._pos_start = (self._pos_start + 1) % n
                self._pos_start = 0
            yield from torch.randperm(n, generator=generator).tolist()[:self.num_samples % n]

    def __len__(self) -> int:
        return self.num_samples
550
+
551
class SequentialSampler(Sampler[int]):
    r"""Yields dataset indices in ascending order.

    Iteration begins at the stored cursor ``_pos_start`` (0 unless it was set
    from outside), which makes it possible to continue a partially consumed
    pass. The cursor advances as items are consumed and is reset to 0 once a
    pass completes.

    Args:
        data_source (Dataset): dataset to sample from
    """

    data_source: Sized

    def __init__(self, data_source: Sized) -> None:
        self._pos_start = 0
        self.data_source = data_source

    def __len__(self) -> int:
        return len(self.data_source)

    def __iter__(self) -> Iterator[int]:
        total = len(self.data_source)
        position = self._pos_start
        while position < total:
            yield position
            # Executed when the consumer requests the next index
            self._pos_start = (self._pos_start + 1) % total
            position += 1
        self._pos_start = 0
573
+
574
class ConsisID_Dataset(Dataset):
    """Video/caption dataset with optional face-identity conditioning data.

    The index file ``instance_data_root`` is a text file whose non-empty lines
    are split on ``,`` into three fields: the directory holding the videos, the
    path of a JSON annotation list, and the base directory of the face
    annotations. Every entry of the JSON list provides ``path`` and optionally
    ``cap``, ``fps`` and ``duration``; entries with ``fps * duration < 49`` are
    skipped.

    Each item is a dict with ``instance_prompt``, ``instance_video`` and
    ``video_path``; with ``is_train_face=True`` it additionally contains
    ``expand_face_imgs``, ``dense_masks_tensor``, ``selected_frame_index`` and,
    when available, ``reserve_face_imgs`` and ``original_face_imgs``.
    """

    def __init__(
        self,
        instance_data_root: Optional[str] = None,
        id_token: Optional[str] = None,
        height=480,
        width=640,
        max_num_frames=49,
        sample_stride=3,
        skip_frames_start_percent=0.0,
        skip_frames_end_percent=1.0,
        skip_frames_start=0,
        skip_frames_end=0,
        text_drop_ratio=-1,
        is_train_face=False,
        is_single_face=False,
        miss_tolerance=6,
        min_distance=3,
        min_frames=1,
        max_frames=5,
        is_cross_face=False,
        is_reserve_face=False,
    ):
        if instance_data_root is None:
            raise ValueError("instance_data_root must be the path of the annotation index file.")

        self.id_token = id_token or ""

        # ConsisID
        self.skip_frames_start_percent = skip_frames_start_percent
        self.skip_frames_end_percent = skip_frames_end_percent
        self.skip_frames_start = skip_frames_start
        self.skip_frames_end = skip_frames_end
        self.is_train_face = is_train_face
        self.is_single_face = is_single_face

        # Face-sampling parameters are stored unconditionally so the attributes
        # always exist; they are only read when is_train_face is True.
        self.miss_tolerance = miss_tolerance
        self.min_distance = min_distance
        self.min_frames = min_frames
        self.max_frames = max_frames
        self.is_cross_face = is_cross_face
        self.is_reserve_face = is_reserve_face

        # Loading annotations from files
        print(f"loading annotations from {instance_data_root} ...")
        with open(instance_data_root, 'r') as f:
            folder_anno = [line.strip().split(',') for line in f if line.strip()]

        self.instance_prompts = []
        self.instance_video_paths = []
        self.instance_annotation_base_paths = []
        for sub_root, anno, anno_base in tqdm(folder_anno):
            print(anno)
            with open(anno, 'r') as f:
                sub_list = json.load(f)
            for item in tqdm(sub_list):
                path = os.path.join(sub_root, os.path.basename(item['path']))
                cap = item.get('cap', None)
                fps = item.get('fps', 0)
                duration = item.get('duration', 0)

                # NOTE(review): 49 equals the default max_num_frames but is not
                # tied to the max_num_frames argument - confirm this is intended.
                if fps * duration < 49.0:
                    continue

                self.instance_prompts.append(cap)
                self.instance_video_paths.append(path)
                # One entry per kept video: get_batch() indexes this list with
                # the video index, so it must stay aligned with the two lists above.
                self.instance_annotation_base_paths.append(anno_base)

        self.num_instance_videos = len(self.instance_video_paths)

        self.text_drop_ratio = text_drop_ratio

        # Video params
        self.sample_stride = sample_stride
        self.max_num_frames = max_num_frames
        self.height = height
        self.width = width

    def _get_frame_indices_adjusted(self, video_length, n_frames):
        """Return ``n_frames`` sorted indices for a short video by repeating frames.

        All ``video_length`` indices are used once; the missing
        ``n_frames - video_length`` indices are taken cyclically from the start
        of the video and the result is sorted.
        """
        indices = list(range(video_length))
        additional_frames_needed = n_frames - video_length

        repeat_indices = [indices[i % video_length] for i in range(additional_frames_needed)]

        all_indices = indices + repeat_indices
        all_indices.sort()

        return all_indices

    def _generate_frame_indices(self, video_length, n_frames, sample_stride, skip_frames_start_percent=0.0, skip_frames_end_percent=1.0, skip_frames_start=0, skip_frames_end=0):
        """Pick ``n_frames`` frame indices from a random clip inside the valid window.

        The valid window is defined by the percentage arguments when either of
        them differs from its default, otherwise by the absolute skip counts,
        otherwise it is the whole video. Videos with at most ``n_frames``
        frames are handled by ``_get_frame_indices_adjusted`` (the window is
        not applied in that case).

        Raises:
            ValueError: if the valid window is empty.
        """
        if skip_frames_start_percent != 0.0 or skip_frames_end_percent != 1.0:
            print("use skip frame percent")
            valid_start = int(video_length * skip_frames_start_percent)
            valid_end = int(video_length * skip_frames_end_percent)
        elif skip_frames_start != 0 or skip_frames_end != 0:
            print("use skip frame")
            valid_start = skip_frames_start
            valid_end = video_length - skip_frames_end
        else:
            print("no use skip frame")
            valid_start = 0
            valid_end = video_length

        adjusted_length = valid_end - valid_start

        if adjusted_length <= 0:
            print(f"video_length: {video_length}, adjusted_length: {adjusted_length}, valid_start:{valid_start}, skip_frames_end: {valid_end}")
            raise ValueError("Skipping too many frames results in no frames left to sample.")

        if video_length <= n_frames:
            return self._get_frame_indices_adjusted(video_length, n_frames)

        # Clip spans at most the valid window; frames are spread evenly over it
        clip_length = min(adjusted_length, (n_frames - 1) * sample_stride + 1)
        start_idx = random.randint(valid_start, valid_end - clip_length)
        frame_indices = np.linspace(start_idx, start_idx + clip_length - 1, n_frames, dtype=int).tolist()
        return frame_indices

    def _short_resize_and_crop(self, frames, target_width, target_height):
        """
        Resize frames and crop to the specified size.

        Args:
            frames (torch.Tensor): Input frames of shape [T, H, W, C].
            target_width (int): Desired width.
            target_height (int): Desired height.

        Returns:
            torch.Tensor: Cropped frames of shape [T, target_height, target_width, C].
        """
        T, H, W, C = frames.shape
        aspect_ratio = W / H

        # Determine new dimensions ensuring they are at least target size
        if aspect_ratio > target_width / target_height:
            new_width = target_width
            new_height = int(target_width / aspect_ratio)
            if new_height < target_height:
                new_height = target_height
                new_width = int(target_height * aspect_ratio)
        else:
            new_height = target_height
            new_width = int(target_height * aspect_ratio)
            if new_width < target_width:
                new_width = target_width
                new_height = int(target_width / aspect_ratio)

        resize_transform = transforms.Resize((new_height, new_width))
        crop_transform = transforms.CenterCrop((target_height, target_width))

        frames_tensor = frames.permute(0, 3, 1, 2)  # (T, H, W, C) -> (T, C, H, W)
        resized_frames = resize_transform(frames_tensor)
        cropped_frames = crop_transform(resized_frames)
        sample = cropped_frames.permute(0, 2, 3, 1)

        return sample

    def _resize_with_aspect_ratio(self, frames, target_width, target_height):
        """
        Resize frames while maintaining the aspect ratio, then zero-pad to the target size.

        Args:
            frames (torch.Tensor): Input frames of shape [T, H, W, C].
            target_width (int): Desired width.
            target_height (int): Desired height.

        Returns:
            torch.Tensor: Resized and padded frames of shape [T, target_height, target_width, C].
        """
        T, frame_height, frame_width, C = frames.shape
        aspect_ratio = frame_width / frame_height
        target_aspect_ratio = target_width / target_height

        # If the frame is wider than the target, resize based on width
        if aspect_ratio > target_aspect_ratio:
            new_width = target_width
            new_height = int(target_width / aspect_ratio)
        else:
            new_height = target_height
            new_width = int(target_height * aspect_ratio)

        # Resize using batch processing
        frames = frames.permute(0, 3, 1, 2)  # [T, C, H, W]
        frames = F.interpolate(frames, size=(new_height, new_width), mode='bilinear', align_corners=False)

        # Split the remaining space evenly; the odd pixel goes to bottom/right
        pad_top = (target_height - new_height) // 2
        pad_bottom = target_height - new_height - pad_top
        pad_left = (target_width - new_width) // 2
        pad_right = target_width - new_width - pad_left

        frames = F.pad(frames, (pad_left, pad_right, pad_top, pad_bottom), mode='constant', value=0)

        frames = frames.permute(0, 2, 3, 1)  # [T, H, W, C]

        return frames

    def _save_frame(self, frame, name="1.png"):
        """Save a single [H, W, C] frame tensor as an image file (debug helper)."""
        # [H, W, C] -> [C, H, W]
        img = frame.permute(2, 0, 1)
        img = ToPILImage()(img)
        img.save(name)

    def _save_video(self, torch_frames, name="output.mp4"):
        """Write a frame tensor to ``name`` as a 24 fps libx264 video (debug helper)."""
        from moviepy.editor import ImageSequenceClip
        frames_np = torch_frames.cpu().numpy()
        if frames_np.dtype != 'uint8':
            frames_np = frames_np.astype('uint8')
        frames_list = list(frames_np)
        desired_fps = 24
        clip = ImageSequenceClip(frames_list, fps=desired_fps)
        clip.write_videofile(name, codec="libx264")

    def get_batch(self, idx):
        """Load one training example.

        Returns:
            tuple: ``(pixel_values, text, 'video', video_dir)`` and, when
            ``is_train_face`` is True, additionally ``expand_face_imgs``,
            ``dense_masks_tensor``, ``selected_frame_index``,
            ``reserve_face_imgs`` (``None`` unless ``is_reserve_face``) and
            ``original_face_imgs``. ``pixel_values`` is [T, C, H, W] scaled
            to [-1, 1].

        Raises:
            ValueError: if no usable identity or face is found, or the video
                frames cannot be read.
        """
        decord.bridge.set_bridge("torch")

        video_dir = self.instance_video_paths[idx]
        text = self.instance_prompts[idx]

        train_transforms = transforms.Compose(
            [
                transforms.Lambda(lambda x: x / 255.0 * 2.0 - 1.0),
            ]
        )

        with VideoReader_contextmanager(video_dir, num_threads=2) as video_reader:
            video_num_frames = len(video_reader)

            if self.is_train_face:
                reserve_face_imgs = None
                file_base_name = os.path.basename(video_dir.replace(".mp4", ""))

                anno_base_path = self.instance_annotation_base_paths[idx]
                track_dir = os.path.join(anno_base_path, "track_masks_data", file_base_name)
                valid_frame_path = os.path.join(track_dir, "valid_frame.json")
                control_sam2_frame_path = os.path.join(track_dir, "control_sam2_frame.json")
                corresponding_data_path = os.path.join(track_dir, "corresponding_data.json")
                masks_data_path = os.path.join(track_dir, "tracking_mask_results")
                bboxs_data_path = os.path.join(anno_base_path, "refine_bbox_jsons", f"{file_base_name}.json")

                with open(corresponding_data_path, 'r') as f:
                    corresponding_data = json.load(f)

                with open(control_sam2_frame_path, 'r') as f:
                    control_sam2_frame = json.load(f)

                with open(valid_frame_path, 'r') as f:
                    valid_frame = json.load(f)

                with open(bboxs_data_path, 'r') as f:
                    bbox_data = json.load(f)

                if self.is_single_face:
                    if len(corresponding_data) != 1:
                        raise ValueError(f"Using single face, but {idx} is multi person.")

                # Pick a random identity that has both face and head annotations
                valid_ids = [
                    id_key for id_key, data in corresponding_data.items()
                    if 'face' in data and 'head' in data
                ]
                if not valid_ids:
                    raise ValueError("No valid ID found: both valid_ids and backup_ids are empty.")
                valid_id = random.choice(valid_ids)

                # get video
                batch_index, _ = generate_frame_indices_for_face(self.max_num_frames, self.sample_stride, valid_frame[valid_id],
                                                                 self.miss_tolerance, self.skip_frames_start_percent, self.skip_frames_end_percent,
                                                                 self.skip_frames_start, self.skip_frames_end)

                # Arguments shared by every select_mask_frames_from_index call below
                common_args = (
                    batch_index, valid_id,
                    corresponding_data, control_sam2_frame,
                    valid_frame[valid_id], bbox_data, masks_data_path,
                )
                face_kwargs = dict(
                    min_distance=self.min_distance, min_frames=self.min_frames,
                    max_frames=self.max_frames, dense_masks=True,
                )

                if self.is_cross_face:
                    # Prefer reference faces from frames outside the sampled clip
                    remaining_batch_index_index = [i for i in range(video_num_frames) if i not in batch_index]
                    try:
                        selection = select_mask_frames_from_index(
                            remaining_batch_index_index, *common_args,
                            ensure_control_frame=False, **face_kwargs,
                        )
                    except Exception:
                        # Fall back to frames inside the clip when selection from
                        # the remaining frames fails.
                        selection = select_mask_frames_from_index(
                            batch_index, *common_args,
                            ensure_control_frame=False, **face_kwargs,
                        )
                else:
                    selection = select_mask_frames_from_index(
                        batch_index, *common_args,
                        ensure_control_frame=True, **face_kwargs,
                    )
                selected_frame_index, _, selected_bboxs_dict, dense_masks_dict = selection

                if self.is_reserve_face:
                    reserve_frame_index, _, reserve_bboxs_dict, _ = select_mask_frames_from_index(
                        batch_index, *common_args,
                        min_distance=3, min_frames=4,
                        max_frames=4, dense_masks=False,
                        ensure_control_frame=False,
                    )

                # get mask and aligned_face_img
                selected_frame_index = selected_frame_index[valid_id]
                selected_bboxs_dict = selected_bboxs_dict[valid_id]
                dense_masks_dict = dense_masks_dict[valid_id]

                if self.is_reserve_face:
                    reserve_frame_index = reserve_frame_index[valid_id]
                    reserve_bboxs_dict = reserve_bboxs_dict[valid_id]

                temp_dense_masks_tensor = torch.stack([torch.tensor(mask) for mask in dense_masks_dict])
                # [T, H, W] -> [T, H, W, 1] -> [T, H, W]
                dense_masks_tensor = self._short_resize_and_crop(temp_dense_masks_tensor.unsqueeze(-1), self.width, self.height).squeeze(-1)

                expand_images_pil, original_images_pil = crop_images(selected_frame_index, selected_bboxs_dict, video_reader, return_ori=True)
                expand_face_imgs, original_face_imgs = process_cropped_images(expand_images_pil, original_images_pil, target_size=(480, 480))
                if self.is_reserve_face:
                    reserve_images_pil, _ = crop_images(reserve_frame_index, reserve_bboxs_dict, video_reader, return_ori=False)
                    reserve_face_imgs, _ = process_cropped_images(reserve_images_pil, [], target_size=(480, 480))

                if len(expand_face_imgs) == 0 or original_face_imgs is None or len(original_face_imgs) == 0:
                    raise ValueError("No face detected in input image pool")

                # post process id related data: fixed length of max_frames along dim 0
                expand_face_imgs = pad_tensor(expand_face_imgs, self.max_frames, dim=0)
                original_face_imgs = pad_tensor(original_face_imgs, self.max_frames, dim=0)
                selected_frame_index = torch.tensor(selected_frame_index)
                selected_frame_index = pad_tensor(selected_frame_index, self.max_frames, dim=0)
            else:
                batch_index = self._generate_frame_indices(video_num_frames, self.max_num_frames, self.sample_stride,
                                                           self.skip_frames_start_percent, self.skip_frames_end_percent,
                                                           self.skip_frames_start, self.skip_frames_end)

            try:
                frames = video_reader.get_batch(batch_index)  # torch [T, H, W, C]
                frames = self._short_resize_and_crop(frames, self.width, self.height)  # [T, H, W, C]
            except FunctionTimedOut as e:
                raise ValueError(f"Read {idx} timeout.") from e
            except Exception as e:
                raise ValueError(f"Failed to extract frames from video. Error is {e}.") from e

            # Apply training transforms in batch
            frames = frames.float()
            frames = train_transforms(frames)
            pixel_values = frames.permute(0, 3, 1, 2).contiguous()  # [T, C, H, W]
            del video_reader

        # Random use no text generation
        if random.random() < self.text_drop_ratio:
            text = ''

        if self.is_train_face:
            return pixel_values, text, 'video', video_dir, expand_face_imgs, dense_masks_tensor, selected_frame_index, reserve_face_imgs, original_face_imgs
        else:
            return pixel_values, text, 'video', video_dir

    def __len__(self):
        return self.num_instance_videos

    def __getitem__(self, idx):
        """Return the sample dict for ``idx``; errors from ``get_batch`` propagate to the caller."""
        sample = {}
        if self.is_train_face:
            (pixel_values, cap, _data_type, video_dir, expand_face_imgs, dense_masks_tensor,
             selected_frame_index, reserve_face_imgs, original_face_imgs) = self.get_batch(idx)
        else:
            pixel_values, cap, _data_type, video_dir = self.get_batch(idx)

        sample["instance_prompt"] = self.id_token + cap
        sample["instance_video"] = pixel_values
        sample["video_path"] = video_dir

        if self.is_train_face:
            sample["expand_face_imgs"] = expand_face_imgs
            sample["dense_masks_tensor"] = dense_masks_tensor
            sample["selected_frame_index"] = selected_frame_index
            if reserve_face_imgs is not None:
                sample["reserve_face_imgs"] = reserve_face_imgs
            if original_face_imgs is not None:
                sample["original_face_imgs"] = original_face_imgs
        return sample
util/deepspeed_configs/accelerate_config_machine_multi.yaml ADDED
@@ -0,0 +1,18 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ compute_environment: LOCAL_MACHINE
2
+ distributed_type: DEEPSPEED
3
+ deepspeed_config:
4
+ deepspeed_config_file: util/deepspeed_configs/zero_stage2_config.json
5
+ deepspeed_hostfile: util/deepspeed_configs/hostfile.txt
6
+ fsdp_config: {}
7
+ machine_rank: 0
8
+ main_process_ip: 100.64.24.6
9
+ main_process_port: 12343
10
+ main_training_function: main
11
+ num_machines: 2
12
+ num_processes: 16
13
+ rdzv_backend: static
14
+ same_network: true
15
+ tpu_env: []
16
+ tpu_use_cluster: false
17
+ tpu_use_sudo: false
18
+ use_cpu: false
util/deepspeed_configs/accelerate_config_machine_single.yaml ADDED
@@ -0,0 +1,13 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ compute_environment: LOCAL_MACHINE
2
+ distributed_type: DEEPSPEED
3
+ deepspeed_config:
4
+ deepspeed_config_file: util/deepspeed_configs/zero_stage2_config.json
5
+ fsdp_config: {}
6
+ machine_rank: 0
7
+ main_process_ip: null
8
+ main_process_port: 12345
9
+ main_training_function: main
10
+ num_machines: 1
11
+ num_processes: 8
12
+ gpu_ids: 0,1,2,3,4,5,6,7
13
+ use_cpu: false
util/deepspeed_configs/hostfile.txt ADDED
@@ -0,0 +1,2 @@
 
 
 
1
+ node-user@100.64.24.6 slots=8
2
+ node-user@100.64.24.3 slots=8
util/deepspeed_configs/zero_stage2_config.json ADDED
@@ -0,0 +1,17 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ {
2
+ "bf16": {
3
+ "enabled": true
4
+ },
5
+ "train_micro_batch_size_per_gpu": "auto",
6
+ "train_batch_size": "auto",
7
+ "gradient_clipping": 1.0,
8
+ "gradient_accumulation_steps": "auto",
9
+ "dump_state": true,
10
+ "zero_optimization": {
11
+ "stage": 2,
12
+ "overlap_comm": true,
13
+ "contiguous_gradients": true,
14
+ "sub_group_size": 1e9,
15
+ "reduce_bucket_size": 5e8
16
+ }
17
+ }
util/rife/IFNet.py ADDED
@@ -0,0 +1,123 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from .refine import *
2
+
3
+
4
def deconv(in_planes, out_planes, kernel_size=4, stride=2, padding=1):
    """Build a transposed-convolution layer followed by a per-channel PReLU.

    Args:
        in_planes: Number of input channels.
        out_planes: Number of output channels (also the PReLU parameter count).
        kernel_size: Kernel size of the transposed convolution.
        stride: Stride of the transposed convolution.
        padding: Padding of the transposed convolution.

    Returns:
        ``nn.Sequential(ConvTranspose2d, PReLU)``.
    """
    # The geometry arguments are forwarded rather than hard-coded to 4/2/1, so
    # overriding them has an effect; the defaults keep the previous behaviour.
    return nn.Sequential(
        torch.nn.ConvTranspose2d(
            in_channels=in_planes,
            out_channels=out_planes,
            kernel_size=kernel_size,
            stride=stride,
            padding=padding,
        ),
        nn.PReLU(out_planes),
    )
9
+
10
+
11
def conv(in_planes, out_planes, kernel_size=3, stride=1, padding=1, dilation=1):
    """Return a biased 2-D convolution followed by a per-channel PReLU."""
    convolution = nn.Conv2d(
        in_planes,
        out_planes,
        kernel_size=kernel_size,
        stride=stride,
        padding=padding,
        dilation=dilation,
        bias=True,
    )
    activation = nn.PReLU(out_planes)
    return nn.Sequential(convolution, activation)
24
+
25
+
26
class IFBlock(nn.Module):
    """One stage of the flow estimator: predicts a 4-channel flow and a 1-channel mask."""

    def __init__(self, in_planes, c=64):
        """Create the stage.

        Args:
            in_planes: Channel count of the (optionally flow-augmented) input.
            c: Width of the internal feature maps.
        """
        super().__init__()
        # Two stride-2 convolutions: features live at 1/4 of the input size.
        self.conv0 = nn.Sequential(
            conv(in_planes, c // 2, 3, 2, 1),
            conv(c // 2, c, 3, 2, 1),
        )
        # Eight conv+PReLU layers, used as one residual branch in forward().
        self.convblock = nn.Sequential(*(conv(c, c) for _ in range(8)))
        # 5 output channels: 4 for flow, 1 for the mask (see the slicing in forward()).
        self.lastconv = nn.ConvTranspose2d(c, 5, 4, 2, 1)

    def forward(self, x, flow, scale):
        """Estimate flow and mask at the working resolution ``1 / scale``.

        Args:
            x: Input feature tensor (N, C, H, W).
            flow: Current flow estimate, or ``None`` for the first stage.
            scale: Down-scaling factor applied before the convolutions.

        Returns:
            ``(flow, mask)`` with 4 and 1 channels respectively.
        """
        if scale != 1:
            x = F.interpolate(x, scale_factor=1.0 / scale, mode="bilinear", align_corners=False)
        # Identity test instead of ``!= None``: ``flow`` is either a tensor or None.
        if flow is not None:
            # Flow magnitudes are rescaled together with the spatial size.
            flow = F.interpolate(flow, scale_factor=1.0 / scale, mode="bilinear", align_corners=False) * 1.0 / scale
            x = torch.cat((x, flow), 1)
        x = self.conv0(x)
        x = self.convblock(x) + x
        tmp = self.lastconv(x)
        # lastconv undoes one of the two stride-2 steps; interpolation covers the rest.
        tmp = F.interpolate(tmp, scale_factor=scale * 2, mode="bilinear", align_corners=False)
        flow = tmp[:, :4] * scale * 2
        mask = tmp[:, 4:5]
        return flow, mask
58
+
59
+
60
class IFNet(nn.Module):
    """Three student ``IFBlock`` stages plus a teacher stage, a context encoder and a U-Net."""

    def __init__(self):
        super().__init__()
        self.block0 = IFBlock(6, c=240)
        self.block1 = IFBlock(13 + 4, c=150)
        self.block2 = IFBlock(13 + 4, c=90)
        self.block_tea = IFBlock(16 + 4, c=90)
        self.contextnet = Contextnet()
        self.unet = Unet()

    def forward(self, x, scale=(4, 2, 1), timestep=0.5):
        """Run the coarse-to-fine estimation.

        Args:
            x: Tensor whose channels are ``img0`` (0-2), ``img1`` (3-5) and,
                optionally, a 3-channel ground-truth frame (6-8).
            scale: Per-stage down-scaling factors; only indices 0-2 are read.
            timestep: Accepted for interface compatibility; not used here.

        Returns:
            ``(flow_list, mask_list[2], merged, flow_teacher, merged_teacher,
            loss_distill)``. The teacher outputs are ``None`` and
            ``loss_distill`` stays ``0`` when no ground truth is supplied.
        """
        img0 = x[:, :3]
        img1 = x[:, 3:6]
        # Without ground truth this slice is a zero-channel tensor, not None.
        gt = x[:, 6:]
        has_gt = gt.shape[1] == 3
        flow_list = []
        merged = []
        mask_list = []
        warped_img0 = img0
        warped_img1 = img1
        flow = None
        loss_distill = 0
        stu = [self.block0, self.block1, self.block2]
        for i, block in enumerate(stu):
            if flow is not None:
                flow_d, mask_d = block(
                    torch.cat((img0, img1, warped_img0, warped_img1, mask), 1), flow, scale=scale[i]
                )
                # Later stages predict residual corrections.
                flow = flow + flow_d
                mask = mask + mask_d
            else:
                flow, mask = block(torch.cat((img0, img1), 1), None, scale=scale[i])
            mask_list.append(torch.sigmoid(mask))
            flow_list.append(flow)
            warped_img0 = warp(img0, flow[:, :2])
            warped_img1 = warp(img1, flow[:, 2:4])
            merged.append((warped_img0, warped_img1))
        if has_gt:
            # The teacher stage additionally sees the ground-truth frame.
            flow_d, mask_d = self.block_tea(
                torch.cat((img0, img1, warped_img0, warped_img1, mask, gt), 1), flow, scale=1
            )
            flow_teacher = flow + flow_d
            warped_img0_teacher = warp(img0, flow_teacher[:, :2])
            warped_img1_teacher = warp(img1, flow_teacher[:, 2:4])
            mask_teacher = torch.sigmoid(mask + mask_d)
            merged_teacher = warped_img0_teacher * mask_teacher + warped_img1_teacher * (1 - mask_teacher)
        else:
            flow_teacher = None
            merged_teacher = None
        for i in range(3):
            merged[i] = merged[i][0] * mask_list[i] + merged[i][1] * (1 - mask_list[i])
            if has_gt:
                # Distil only where the student is worse than the teacher by a margin.
                loss_mask = (
                    ((merged[i] - gt).abs().mean(1, True) > (merged_teacher - gt).abs().mean(1, True) + 0.01)
                    .float()
                    .detach()
                )
                loss_distill += (((flow_teacher.detach() - flow_list[i]) ** 2).mean(1, True) ** 0.5 * loss_mask).mean()
        c0 = self.contextnet(img0, flow[:, :2])
        c1 = self.contextnet(img1, flow[:, 2:4])
        tmp = self.unet(img0, img1, warped_img0, warped_img1, mask, flow, c0, c1)
        # Map the first three U-Net channels to a residual in [-1, 1] (assuming outputs in [0, 1]).
        res = tmp[:, :3] * 2 - 1
        merged[2] = torch.clamp(merged[2] + res, 0, 1)
        return flow_list, mask_list[2], merged, flow_teacher, merged_teacher, loss_distill
util/rife/IFNet_2R.py ADDED
@@ -0,0 +1,123 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from .refine_2R import *
2
+
3
+
4
def deconv(in_planes, out_planes, kernel_size=4, stride=2, padding=1):
    """Build a transposed-convolution layer followed by a per-channel PReLU.

    Args:
        in_planes: Number of input channels.
        out_planes: Number of output channels (also the PReLU parameter count).
        kernel_size: Kernel size of the transposed convolution.
        stride: Stride of the transposed convolution.
        padding: Padding of the transposed convolution.

    Returns:
        ``nn.Sequential(ConvTranspose2d, PReLU)``.
    """
    # Honour the caller's geometry instead of silently forcing 4/2/1; the
    # defaults reproduce the previously hard-coded values.
    upsample = torch.nn.ConvTranspose2d(
        in_channels=in_planes,
        out_channels=out_planes,
        kernel_size=kernel_size,
        stride=stride,
        padding=padding,
    )
    return nn.Sequential(upsample, nn.PReLU(out_planes))
9
+
10
+
11
def conv(in_planes, out_planes, kernel_size=3, stride=1, padding=1, dilation=1):
    """Build ``Conv2d`` (with bias) + ``PReLU`` as a two-element ``nn.Sequential``."""
    layers = [
        nn.Conv2d(
            in_planes,
            out_planes,
            kernel_size=kernel_size,
            stride=stride,
            padding=padding,
            dilation=dilation,
            bias=True,
        ),
        nn.PReLU(out_planes),
    ]
    return nn.Sequential(*layers)
24
+
25
+
26
class IFBlock(nn.Module):
    """Flow/mask estimation stage whose features live at half the input resolution."""

    def __init__(self, in_planes, c=64):
        """Create the stage.

        Args:
            in_planes: Channel count of the (optionally flow-augmented) input.
            c: Width of the internal feature maps.
        """
        super().__init__()
        # Only the second convolution is strided: features are at 1/2 resolution.
        self.conv0 = nn.Sequential(
            conv(in_planes, c // 2, 3, 1, 1),
            conv(c // 2, c, 3, 2, 1),
        )
        # Eight conv+PReLU layers, used as one residual branch in forward().
        self.convblock = nn.Sequential(*(conv(c, c) for _ in range(8)))
        # 5 output channels: 4 for flow, 1 for the mask (see the slicing in forward()).
        self.lastconv = nn.ConvTranspose2d(c, 5, 4, 2, 1)

    def forward(self, x, flow, scale):
        """Estimate flow and mask at the working resolution ``1 / scale``.

        Args:
            x: Input feature tensor (N, C, H, W).
            flow: Current flow estimate, or ``None`` for the first stage.
            scale: Down-scaling factor applied before the convolutions.

        Returns:
            ``(flow, mask)`` with 4 and 1 channels respectively.
        """
        if scale != 1:
            x = F.interpolate(x, scale_factor=1.0 / scale, mode="bilinear", align_corners=False)
        # Identity test instead of ``!= None``: ``flow`` is either a tensor or None.
        if flow is not None:
            # Flow magnitudes are rescaled together with the spatial size.
            flow = F.interpolate(flow, scale_factor=1.0 / scale, mode="bilinear", align_corners=False) * 1.0 / scale
            x = torch.cat((x, flow), 1)
        x = self.conv0(x)
        x = self.convblock(x) + x
        tmp = self.lastconv(x)
        # lastconv already restores the stride-2 step, so only ``scale`` remains.
        tmp = F.interpolate(tmp, scale_factor=scale, mode="bilinear", align_corners=False)
        flow = tmp[:, :4] * scale
        mask = tmp[:, 4:5]
        return flow, mask
58
+
59
+
60
class IFNet(nn.Module):
    """Three student ``IFBlock`` stages plus a teacher stage, a context encoder and a U-Net."""

    def __init__(self):
        super().__init__()
        self.block0 = IFBlock(6, c=240)
        self.block1 = IFBlock(13 + 4, c=150)
        self.block2 = IFBlock(13 + 4, c=90)
        self.block_tea = IFBlock(16 + 4, c=90)
        self.contextnet = Contextnet()
        self.unet = Unet()

    def forward(self, x, scale=(4, 2, 1), timestep=0.5):
        """Run the coarse-to-fine estimation.

        Args:
            x: Tensor whose channels are ``img0`` (0-2), ``img1`` (3-5) and,
                optionally, a 3-channel ground-truth frame (6-8).
            scale: Per-stage down-scaling factors; only indices 0-2 are read.
            timestep: Accepted for interface compatibility; not used here.

        Returns:
            ``(flow_list, mask_list[2], merged, flow_teacher, merged_teacher,
            loss_distill)``. The teacher outputs are ``None`` and
            ``loss_distill`` stays ``0`` when no ground truth is supplied.
        """
        img0 = x[:, :3]
        img1 = x[:, 3:6]
        # Without ground truth this slice is a zero-channel tensor, not None.
        gt = x[:, 6:]
        has_gt = gt.shape[1] == 3
        flow_list = []
        merged = []
        mask_list = []
        warped_img0 = img0
        warped_img1 = img1
        flow = None
        loss_distill = 0
        stu = [self.block0, self.block1, self.block2]
        for i, block in enumerate(stu):
            if flow is not None:
                flow_d, mask_d = block(
                    torch.cat((img0, img1, warped_img0, warped_img1, mask), 1), flow, scale=scale[i]
                )
                # Later stages predict residual corrections.
                flow = flow + flow_d
                mask = mask + mask_d
            else:
                flow, mask = block(torch.cat((img0, img1), 1), None, scale=scale[i])
            mask_list.append(torch.sigmoid(mask))
            flow_list.append(flow)
            warped_img0 = warp(img0, flow[:, :2])
            warped_img1 = warp(img1, flow[:, 2:4])
            merged.append((warped_img0, warped_img1))
        if has_gt:
            # The teacher stage additionally sees the ground-truth frame.
            flow_d, mask_d = self.block_tea(
                torch.cat((img0, img1, warped_img0, warped_img1, mask, gt), 1), flow, scale=1
            )
            flow_teacher = flow + flow_d
            warped_img0_teacher = warp(img0, flow_teacher[:, :2])
            warped_img1_teacher = warp(img1, flow_teacher[:, 2:4])
            mask_teacher = torch.sigmoid(mask + mask_d)
            merged_teacher = warped_img0_teacher * mask_teacher + warped_img1_teacher * (1 - mask_teacher)
        else:
            flow_teacher = None
            merged_teacher = None
        for i in range(3):
            merged[i] = merged[i][0] * mask_list[i] + merged[i][1] * (1 - mask_list[i])
            if has_gt:
                # Distil only where the student is worse than the teacher by a margin.
                loss_mask = (
                    ((merged[i] - gt).abs().mean(1, True) > (merged_teacher - gt).abs().mean(1, True) + 0.01)
                    .float()
                    .detach()
                )
                loss_distill += (((flow_teacher.detach() - flow_list[i]) ** 2).mean(1, True) ** 0.5 * loss_mask).mean()
        c0 = self.contextnet(img0, flow[:, :2])
        c1 = self.contextnet(img1, flow[:, 2:4])
        tmp = self.unet(img0, img1, warped_img0, warped_img1, mask, flow, c0, c1)
        # Map the first three U-Net channels to a residual in [-1, 1] (assuming outputs in [0, 1]).
        res = tmp[:, :3] * 2 - 1
        merged[2] = torch.clamp(merged[2] + res, 0, 1)
        return flow_list, mask_list[2], merged, flow_teacher, merged_teacher, loss_distill
util/rife/IFNet_HDv3.py ADDED
@@ -0,0 +1,138 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ import torch.nn as nn
3
+ import torch.nn.functional as F
4
+ from .warplayer import warp
5
+
6
# Use the GPU when CUDA is available, otherwise fall back to the CPU.
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
7
+
8
+
9
def conv(in_planes, out_planes, kernel_size=3, stride=1, padding=1, dilation=1):
    """Return a biased 2-D convolution followed by a per-channel PReLU."""
    conv_kwargs = {
        "kernel_size": kernel_size,
        "stride": stride,
        "padding": padding,
        "dilation": dilation,
        "bias": True,
    }
    return nn.Sequential(nn.Conv2d(in_planes, out_planes, **conv_kwargs), nn.PReLU(out_planes))
22
+
23
+
24
def conv_bn(in_planes, out_planes, kernel_size=3, stride=1, padding=1, dilation=1):
    """Return a bias-free 2-D convolution followed by batch norm and a per-channel PReLU."""
    convolution = nn.Conv2d(
        in_planes,
        out_planes,
        kernel_size=kernel_size,
        stride=stride,
        padding=padding,
        dilation=dilation,
        bias=False,  # the following BatchNorm2d carries its own shift term
    )
    normalisation = nn.BatchNorm2d(out_planes)
    activation = nn.PReLU(out_planes)
    return nn.Sequential(convolution, normalisation, activation)
38
+
39
+
40
class IFBlock(nn.Module):
    """Flow/mask estimation stage with four residual conv pairs and separate flow and mask heads."""

    def __init__(self, in_planes, c=64):
        super().__init__()
        half = c // 2
        # Two stride-2 convolutions: features live at 1/4 of the input size.
        self.conv0 = nn.Sequential(
            conv(in_planes, half, 3, 2, 1),
            conv(half, c, 3, 2, 1),
        )
        self.convblock0 = nn.Sequential(conv(c, c), conv(c, c))
        self.convblock1 = nn.Sequential(conv(c, c), conv(c, c))
        self.convblock2 = nn.Sequential(conv(c, c), conv(c, c))
        self.convblock3 = nn.Sequential(conv(c, c), conv(c, c))
        # Flow head: two stride-2 transposed convolutions, 4 output channels.
        self.conv1 = nn.Sequential(
            nn.ConvTranspose2d(c, half, 4, 2, 1),
            nn.PReLU(half),
            nn.ConvTranspose2d(half, 4, 4, 2, 1),
        )
        # Mask head: same layout, 1 output channel.
        self.conv2 = nn.Sequential(
            nn.ConvTranspose2d(c, half, 4, 2, 1),
            nn.PReLU(half),
            nn.ConvTranspose2d(half, 1, 4, 2, 1),
        )

    def forward(self, x, flow, scale=1):
        """Return ``(flow, mask)`` estimated at resolution ``1 / scale`` and resized back."""
        resize = {"mode": "bilinear", "align_corners": False, "recompute_scale_factor": False}
        x = F.interpolate(x, scale_factor=1.0 / scale, **resize)
        # Flow magnitudes are rescaled together with the spatial size.
        flow = F.interpolate(flow, scale_factor=1.0 / scale, **resize) * 1.0 / scale
        feat = self.conv0(torch.cat((x, flow), 1))
        for residual in (self.convblock0, self.convblock1, self.convblock2, self.convblock3):
            feat = residual(feat) + feat
        flow_out = self.conv1(feat)
        mask_out = self.conv2(feat)
        flow_out = F.interpolate(flow_out, scale_factor=scale, **resize) * scale
        mask_out = F.interpolate(mask_out, scale_factor=scale, **resize)
        return flow_out, mask_out
88
+
89
+
90
class IFNet(nn.Module):
    """Three-stage flow network that runs every stage in both frame orders and averages them."""

    def __init__(self):
        super().__init__()
        self.block0 = IFBlock(7 + 4, c=90)
        self.block1 = IFBlock(7 + 4, c=90)
        self.block2 = IFBlock(7 + 4, c=90)
        self.block_tea = IFBlock(10 + 4, c=90)
        # No context encoder / U-Net refinement is instantiated in this variant.

    def forward(self, x, scale_list=(4, 2, 1), training=False):
        """Estimate flow, mask and merged frames for a frame pair.

        Args:
            x: With ``training=False``, the two frames concatenated on the
                channel axis and split in half. With ``training=True``, the
                frames are read from channels 0-2 and 3-5; any further
                channels are ignored.
            scale_list: Per-stage down-scaling factors; indices 0-2 are read.
            training: Selects the channel layout of ``x`` (see above).

        Returns:
            ``(flow_list, mask_list[2], merged)`` where ``merged[i]`` is the
            mask-weighted blend of the two warped frames at stage ``i``.
        """
        if training:
            # Previously img0/img1 were never assigned on this path, so any
            # training=True call raised UnboundLocalError.
            # NOTE(review): layout assumed to be img0 | img1 | gt, matching how
            # the older IFNet variants slice their input - confirm with caller.
            img0 = x[:, :3]
            img1 = x[:, 3:6]
        else:
            channel = x.shape[1] // 2
            img0 = x[:, :channel]
            img1 = x[:, channel:]
        flow_list = []
        merged = []
        mask_list = []
        warped_img0 = img0
        warped_img1 = img1
        # Zero initial flow (4 ch) and mask (1 ch), shaped after the input.
        flow = (x[:, :4]).detach() * 0
        mask = (x[:, :1]).detach() * 0
        block = [self.block0, self.block1, self.block2]
        for i in range(3):
            f0, m0 = block[i](torch.cat((warped_img0[:, :3], warped_img1[:, :3], mask), 1), flow, scale=scale_list[i])
            # Same stage with the frames (and flow halves) swapped and the mask negated.
            f1, m1 = block[i](
                torch.cat((warped_img1[:, :3], warped_img0[:, :3], -mask), 1),
                torch.cat((flow[:, 2:4], flow[:, :2]), 1),
                scale=scale_list[i],
            )
            # Average the forward result with the un-swapped reverse result.
            flow = flow + (f0 + torch.cat((f1[:, 2:4], f1[:, :2]), 1)) / 2
            mask = mask + (m0 + (-m1)) / 2
            mask_list.append(mask)
            flow_list.append(flow)
            warped_img0 = warp(img0, flow[:, :2])
            warped_img1 = warp(img1, flow[:, 2:4])
            merged.append((warped_img0, warped_img1))
        for i in range(3):
            mask_list[i] = torch.sigmoid(mask_list[i])
            merged[i] = merged[i][0] * mask_list[i] + merged[i][1] * (1 - mask_list[i])
        return flow_list, mask_list[2], merged
util/rife/IFNet_m.py ADDED
@@ -0,0 +1,127 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from .refine import *
2
+
3
+
4
def deconv(in_planes, out_planes, kernel_size=4, stride=2, padding=1):
    """Build a transposed-convolution layer followed by a per-channel PReLU.

    Args:
        in_planes: Number of input channels.
        out_planes: Number of output channels (also the PReLU parameter count).
        kernel_size: Kernel size of the transposed convolution.
        stride: Stride of the transposed convolution.
        padding: Padding of the transposed convolution.

    Returns:
        ``nn.Sequential(ConvTranspose2d, PReLU)``.
    """
    # Pass the geometry through; it used to be fixed at 4/2/1 regardless of
    # the arguments. Default calls build exactly the same layer as before.
    geometry = {"kernel_size": kernel_size, "stride": stride, "padding": padding}
    return nn.Sequential(
        torch.nn.ConvTranspose2d(in_channels=in_planes, out_channels=out_planes, **geometry),
        nn.PReLU(out_planes),
    )
9
+
10
+
11
def conv(in_planes, out_planes, kernel_size=3, stride=1, padding=1, dilation=1):
    """Build the conv + PReLU pair used as the basic unit of the network."""
    stage = nn.Sequential(
        nn.Conv2d(
            in_channels=in_planes,
            out_channels=out_planes,
            kernel_size=kernel_size,
            stride=stride,
            padding=padding,
            dilation=dilation,
            bias=True,
        ),
        nn.PReLU(num_parameters=out_planes),
    )
    return stage
24
+
25
+
26
class IFBlock(nn.Module):
    """One stage of the flow estimator: predicts a 4-channel flow and a 1-channel mask."""

    def __init__(self, in_planes, c=64):
        """Create the stage.

        Args:
            in_planes: Channel count of the (optionally flow-augmented) input.
            c: Width of the internal feature maps.
        """
        super().__init__()
        # Two stride-2 convolutions: features live at 1/4 of the input size.
        self.conv0 = nn.Sequential(
            conv(in_planes, c // 2, 3, 2, 1),
            conv(c // 2, c, 3, 2, 1),
        )
        # Eight conv+PReLU layers, used as one residual branch in forward().
        self.convblock = nn.Sequential(*(conv(c, c) for _ in range(8)))
        # 5 output channels: 4 for flow, 1 for the mask (see the slicing in forward()).
        self.lastconv = nn.ConvTranspose2d(c, 5, 4, 2, 1)

    def forward(self, x, flow, scale):
        """Estimate flow and mask at the working resolution ``1 / scale``.

        Args:
            x: Input feature tensor (N, C, H, W).
            flow: Current flow estimate, or ``None`` for the first stage.
            scale: Down-scaling factor applied before the convolutions.

        Returns:
            ``(flow, mask)`` with 4 and 1 channels respectively.
        """
        if scale != 1:
            x = F.interpolate(x, scale_factor=1.0 / scale, mode="bilinear", align_corners=False)
        # Identity test instead of ``!= None``: ``flow`` is either a tensor or None.
        if flow is not None:
            # Flow magnitudes are rescaled together with the spatial size.
            flow = F.interpolate(flow, scale_factor=1.0 / scale, mode="bilinear", align_corners=False) * 1.0 / scale
            x = torch.cat((x, flow), 1)
        x = self.conv0(x)
        x = self.convblock(x) + x
        tmp = self.lastconv(x)
        # lastconv undoes one of the two stride-2 steps; interpolation covers the rest.
        tmp = F.interpolate(tmp, scale_factor=scale * 2, mode="bilinear", align_corners=False)
        flow = tmp[:, :4] * scale * 2
        mask = tmp[:, 4:5]
        return flow, mask
58
+
59
+
60
class IFNet_m(nn.Module):
    """Variant of the three-stage flow network that feeds a timestep map to every stage."""

    def __init__(self):
        super().__init__()
        # The extra "+ 1" input channel in every block is the timestep map.
        self.block0 = IFBlock(6 + 1, c=240)
        self.block1 = IFBlock(13 + 4 + 1, c=150)
        self.block2 = IFBlock(13 + 4 + 1, c=90)
        self.block_tea = IFBlock(16 + 4 + 1, c=90)
        self.contextnet = Contextnet()
        self.unet = Unet()

    def forward(self, x, scale=(4, 2, 1), timestep=0.5, returnflow=False):
        """Run the coarse-to-fine estimation.

        Args:
            x: Tensor whose channels are ``img0`` (0-2), ``img1`` (3-5) and,
                optionally, a 3-channel ground-truth frame (6-8).
            scale: Per-stage down-scaling factors; only indices 0-2 are read.
            timestep: Scalar broadcast into a one-channel map fed to each stage.
            returnflow: When true, return only the final flow tensor.

        Returns:
            The final flow if ``returnflow`` is true; otherwise
            ``(flow_list, mask_list[2], merged, flow_teacher, merged_teacher,
            loss_distill)``. The teacher outputs are ``None`` and
            ``loss_distill`` stays ``0`` when no ground truth is supplied.
        """
        # One-channel map filled with ``timestep``, shaped after the input.
        timestep = (x[:, :1].clone() * 0 + 1) * timestep
        img0 = x[:, :3]
        img1 = x[:, 3:6]
        # Without ground truth this slice is a zero-channel tensor, not None.
        gt = x[:, 6:]
        has_gt = gt.shape[1] == 3
        flow_list = []
        merged = []
        mask_list = []
        warped_img0 = img0
        warped_img1 = img1
        flow = None
        loss_distill = 0
        stu = [self.block0, self.block1, self.block2]
        for i, block in enumerate(stu):
            if flow is not None:
                flow_d, mask_d = block(
                    torch.cat((img0, img1, timestep, warped_img0, warped_img1, mask), 1), flow, scale=scale[i]
                )
                # Later stages predict residual corrections.
                flow = flow + flow_d
                mask = mask + mask_d
            else:
                flow, mask = block(torch.cat((img0, img1, timestep), 1), None, scale=scale[i])
            mask_list.append(torch.sigmoid(mask))
            flow_list.append(flow)
            warped_img0 = warp(img0, flow[:, :2])
            warped_img1 = warp(img1, flow[:, 2:4])
            merged.append((warped_img0, warped_img1))
        if has_gt:
            # The teacher stage additionally sees the ground-truth frame.
            flow_d, mask_d = self.block_tea(
                torch.cat((img0, img1, timestep, warped_img0, warped_img1, mask, gt), 1), flow, scale=1
            )
            flow_teacher = flow + flow_d
            warped_img0_teacher = warp(img0, flow_teacher[:, :2])
            warped_img1_teacher = warp(img1, flow_teacher[:, 2:4])
            mask_teacher = torch.sigmoid(mask + mask_d)
            merged_teacher = warped_img0_teacher * mask_teacher + warped_img1_teacher * (1 - mask_teacher)
        else:
            flow_teacher = None
            merged_teacher = None
        for i in range(3):
            merged[i] = merged[i][0] * mask_list[i] + merged[i][1] * (1 - mask_list[i])
            if has_gt:
                # Distil only where the student is worse than the teacher by a margin.
                loss_mask = (
                    ((merged[i] - gt).abs().mean(1, True) > (merged_teacher - gt).abs().mean(1, True) + 0.01)
                    .float()
                    .detach()
                )
                loss_distill += (((flow_teacher.detach() - flow_list[i]) ** 2).mean(1, True) ** 0.5 * loss_mask).mean()
        if returnflow:
            return flow
        c0 = self.contextnet(img0, flow[:, :2])
        c1 = self.contextnet(img1, flow[:, 2:4])
        tmp = self.unet(img0, img1, warped_img0, warped_img1, mask, flow, c0, c1)
        # Map the first three U-Net channels to a residual in [-1, 1] (assuming outputs in [0, 1]).
        res = tmp[:, :3] * 2 - 1
        merged[2] = torch.clamp(merged[2] + res, 0, 1)
        return flow_list, mask_list[2], merged, flow_teacher, merged_teacher, loss_distill
util/rife/RIFE.py ADDED
@@ -0,0 +1,95 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from torch.optim import AdamW
2
+ from torch.nn.parallel import DistributedDataParallel as DDP
3
+ from .IFNet import *
4
+ from .IFNet_m import *
5
+ from .loss import *
6
+ from .laplacian import *
7
+ from .refine import *
8
+
9
# Use the GPU when CUDA is available, otherwise fall back to the CPU.
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
10
+
11
+
12
class Model:
    """Training / inference wrapper around the ``IFNet`` or ``IFNet_m`` flow network."""

    def __init__(self, local_rank=-1, arbitrary=False):
        """Build the network, optimizer and losses.

        Args:
            local_rank: When not ``-1``, the network is wrapped in
                ``DistributedDataParallel`` on that device.
            arbitrary: Use ``IFNet_m`` (timestep-conditioned) instead of ``IFNet``.
        """
        if arbitrary:
            self.flownet = IFNet_m()
        else:
            self.flownet = IFNet()
        self.device()
        # A large weight decay may help avoid NaN losses.
        self.optimG = AdamW(self.flownet.parameters(), lr=1e-6, weight_decay=1e-3)
        self.epe = EPE()
        self.lap = LapLoss()
        self.sobel = SOBEL()
        if local_rank != -1:
            self.flownet = DDP(self.flownet, device_ids=[local_rank], output_device=local_rank)

    def train(self):
        """Put the flow network in training mode."""
        self.flownet.train()

    def eval(self):
        """Put the flow network in evaluation mode."""
        self.flownet.eval()

    def device(self):
        """Move the flow network to the module-level ``device``."""
        self.flownet.to(device)

    def load_model(self, path, rank=0):
        """Load ``{path}/flownet.pkl`` into the flow network when ``rank <= 0``.

        Only keys containing ``"module."`` are kept, with that prefix removed.
        """

        def convert(param):
            return {k.replace("module.", ""): v for k, v in param.items() if "module." in k}

        if rank <= 0:
            checkpoint = "{}/flownet.pkl".format(path)
            if torch.cuda.is_available():
                state = torch.load(checkpoint)
            else:
                # Without a GPU, remap storages so a CUDA-saved checkpoint loads.
                state = torch.load(checkpoint, map_location="cpu")
            self.flownet.load_state_dict(convert(state))

    def save_model(self, path, rank=0):
        """Save the flow network's state dict to ``{path}/flownet.pkl`` on rank 0."""
        if rank == 0:
            torch.save(self.flownet.state_dict(), "{}/flownet.pkl".format(path))

    def inference(self, img0, img1, scale=1, scale_list=(4, 2, 1), TTA=False, timestep=0.5):
        """Interpolate a frame between ``img0`` and ``img1``.

        Args:
            img0: First frame tensor.
            img1: Second frame tensor.
            scale: Divisor applied to every entry of ``scale_list``.
            scale_list: Per-stage down-scaling factors before ``scale`` is applied.
            TTA: When true, average with the prediction on the input flipped
                along dims 2 and 3.
            timestep: Forwarded to the flow network.

        Returns:
            The final merged frame (``merged[2]``), or the flip-averaged frame.
        """
        # Build a fresh list: the old in-place update mutated the shared default
        # (and any list passed in), so ``scale`` compounded across calls.
        scale_list = [s * 1.0 / scale for s in scale_list]
        imgs = torch.cat((img0, img1), 1)
        flow, mask, merged, flow_teacher, merged_teacher, loss_distill = self.flownet(
            imgs, scale_list, timestep=timestep
        )
        if not TTA:
            return merged[2]
        flow2, mask2, merged2, flow_teacher2, merged_teacher2, loss_distill2 = self.flownet(
            imgs.flip(2).flip(3), scale_list, timestep=timestep
        )
        return (merged[2] + merged2[2].flip(2).flip(3)) / 2

    def update(self, imgs, gt, learning_rate=0, mul=1, training=True, flow_gt=None):
        """Run one forward pass and, when ``training``, one optimizer step.

        Args:
            imgs: The two input frames concatenated on the channel axis.
            gt: Ground-truth intermediate frame.
            learning_rate: Learning rate written into every param group.
            mul: Unused.
            training: Selects train/eval mode and whether to back-propagate.
            flow_gt: Unused.

        Returns:
            ``(merged[2], info)`` where ``info`` holds the teacher output,
            mask, flows and the individual loss terms.
        """
        for param_group in self.optimG.param_groups:
            param_group["lr"] = learning_rate
        if training:
            self.train()
        else:
            self.eval()
        flow, mask, merged, flow_teacher, merged_teacher, loss_distill = self.flownet(
            torch.cat((imgs, gt), 1), scale=[4, 2, 1]
        )
        loss_l1 = (self.lap(merged[2], gt)).mean()
        loss_tea = (self.lap(merged_teacher, gt)).mean()
        if training:
            self.optimG.zero_grad()
            # When training RIFEm, the weight of loss_distill should be 0.005 or 0.002.
            loss_G = loss_l1 + loss_tea + loss_distill * 0.01
            loss_G.backward()
            self.optimG.step()
        else:
            flow_teacher = flow[2]
        return merged[2], {
            "merged_tea": merged_teacher,
            "mask": mask,
            # NOTE(review): reports the student mask under the teacher key.
            "mask_tea": mask,
            "flow": flow[2][:, :2],
            "flow_tea": flow_teacher,
            "loss_l1": loss_l1,
            "loss_tea": loss_tea,
            "loss_distill": loss_distill,
        }
util/rife/RIFE_HDv3.py ADDED
@@ -0,0 +1,86 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ import torch.nn as nn
3
+ import numpy as np
4
+ from torch.optim import AdamW
5
+ import torch.optim as optim
6
+ import itertools
7
+ from .warplayer import warp
8
+ from torch.nn.parallel import DistributedDataParallel as DDP
9
+ from .IFNet_HDv3 import *
10
+ import torch.nn.functional as F
11
+ from .loss import *
12
+
13
# Use the GPU when CUDA is available, otherwise fall back to the CPU.
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
14
+
15
+
16
class Model:
    """Training / inference wrapper around the HDv3 ``IFNet`` flow network."""

    def __init__(self, local_rank=-1):
        """Build the network, optimizer and losses.

        Args:
            local_rank: When not ``-1``, the network is wrapped in
                ``DistributedDataParallel`` on that device.
        """
        self.flownet = IFNet()
        self.device()
        self.optimG = AdamW(self.flownet.parameters(), lr=1e-6, weight_decay=1e-4)
        self.epe = EPE()
        self.sobel = SOBEL()
        if local_rank != -1:
            self.flownet = DDP(self.flownet, device_ids=[local_rank], output_device=local_rank)

    def train(self):
        """Put the flow network in training mode."""
        self.flownet.train()

    def eval(self):
        """Put the flow network in evaluation mode."""
        self.flownet.eval()

    def device(self):
        """Move the flow network to the module-level ``device``."""
        self.flownet.to(device)

    def load_model(self, path, rank=0):
        """Load ``{path}/flownet.pkl`` into the flow network when ``rank <= 0``.

        With ``rank == -1`` only keys containing ``"module."`` are kept, with
        that prefix removed; otherwise the state dict is loaded unchanged.
        """

        def convert(param):
            if rank == -1:
                return {k.replace("module.", ""): v for k, v in param.items() if "module." in k}
            return param

        if rank <= 0:
            checkpoint = "{}/flownet.pkl".format(path)
            if torch.cuda.is_available():
                state = torch.load(checkpoint)
            else:
                # Without a GPU, remap storages so a CUDA-saved checkpoint loads.
                state = torch.load(checkpoint, map_location="cpu")
            self.flownet.load_state_dict(convert(state))

    def save_model(self, path, rank=0):
        """Save the flow network's state dict to ``{path}/flownet.pkl`` on rank 0."""
        if rank == 0:
            torch.save(self.flownet.state_dict(), "{}/flownet.pkl".format(path))

    def inference(self, img0, img1, scale=1.0):
        """Return the final merged frame interpolated between ``img0`` and ``img1``.

        ``scale`` divides the per-stage down-scaling factors 4, 2 and 1.
        """
        imgs = torch.cat((img0, img1), 1)
        scale_list = [4 / scale, 2 / scale, 1 / scale]
        flow, mask, merged = self.flownet(imgs, scale_list)
        return merged[2]

    def update(self, imgs, gt, learning_rate=0, mul=1, training=True, flow_gt=None):
        """Run one forward pass and, when ``training``, one optimizer step.

        Args:
            imgs: The two input frames concatenated on the channel axis.
            gt: Ground-truth intermediate frame.
            learning_rate: Learning rate written into every param group.
            mul: Unused.
            training: Selects train/eval mode and whether to back-propagate.
            flow_gt: Unused.

        Returns:
            ``(merged[2], info)`` where ``info`` holds the mask, the first two
            flow channels and the individual loss terms.
        """
        for param_group in self.optimG.param_groups:
            param_group["lr"] = learning_rate
        if training:
            self.train()
        else:
            self.eval()
        scale = [4, 2, 1]
        # IFNet.forward names this argument ``scale_list`` (``scale=`` raised a
        # TypeError) and splits its input into two equal halves, so only the
        # frame pair is passed; ``gt`` is used by the losses below.
        flow, mask, merged = self.flownet(imgs, scale_list=scale)
        loss_l1 = (merged[2] - gt).abs().mean()
        loss_smooth = self.sobel(flow[2], flow[2] * 0).mean()
        # NOTE(review): ``loss_cons`` was referenced but never defined (NameError
        # on every call). The network returns no consistency term, so it is
        # reported as zero and the L1 reconstruction loss drives training -
        # confirm the intended objective.
        loss_cons = torch.zeros_like(loss_l1)
        if training:
            self.optimG.zero_grad()
            loss_G = loss_l1 + loss_cons + loss_smooth * 0.1
            loss_G.backward()
            self.optimG.step()
        return merged[2], {
            "mask": mask,
            "flow": flow[2][:, :2],
            "loss_l1": loss_l1,
            "loss_cons": loss_cons,
            "loss_smooth": loss_smooth,
        }
util/rife/__init__.py ADDED
File without changes