BestWishYsh committed on
Commit dc8d70e · 1 Parent(s): e27c7fb
app.py CHANGED
@@ -1,23 +1,23 @@
1
- import os
2
  import math
3
- import time
4
- import spaces
5
  import random
6
  import threading
7
- import gradio as gr
8
- from moviepy import VideoFileClip
9
  from datetime import datetime, timedelta
10
- from huggingface_hub import hf_hub_download, snapshot_download
11
 
 
 
12
  import torch
13
  from diffusers.image_processor import VaeImageProcessor
14
  from diffusers.training_utils import free_memory
15
 
16
- from util.utils import *
17
- from util.rife_model import load_rife_model, rife_inference_with_latents
18
- from models.utils import process_face_embeddings_infer, prepare_face_models
19
- from models.pipeline_consisid import ConsisIDPipeline
20
-
21
 
22
  # 0. Pre config
23
  model_path = "ckpts"
@@ -28,13 +28,13 @@ dtype = torch.bfloat16
28
  device = "cuda" if torch.cuda.is_available() else "cpu"
29
 
30
  if not os.path.exists(model_path) or not os.path.exists(f"{model_path}/model_real_esran") or not os.path.exists(f"{model_path}/model_rife"):
31
- print(f"Model not found, downloading from Hugging Face...")
32
  hf_hub_download(repo_id="ai-forever/Real-ESRGAN", filename="RealESRGAN_x4.pth", local_dir=f"{model_path}/model_real_esran")
33
  snapshot_download(repo_id="AlexWortega/RIFE", local_dir=f"{model_path}/model_rife")
34
  snapshot_download(repo_id="BestWishYsh/ConsisID-preview", local_dir=f"{model_path}")
35
  else:
36
  print(f"Model already exists in {model_path}, skipping download.")
37
-
38
 
39
  # 1. Prepare all the face models
40
  face_helper_1, face_helper_2, face_clip_model, face_main_model, eva_transform_mean, eva_transform_std = prepare_face_models(model_path, device, dtype)
@@ -79,18 +79,15 @@ def generate(
79
  seed = random.randint(0, 2**8 - 1)
80
 
81
  # 4. Prepare model input
82
- id_cond, id_vit_hidden, image, face_kps = process_face_embeddings_infer(face_helper_1, face_clip_model, face_helper_2,
83
- eva_transform_mean, eva_transform_std,
84
- face_main_model, device, dtype,
85
  image_input, is_align_face=True)
86
 
87
- is_kps = getattr(pipe.transformer.config, 'is_kps', False)
88
- kps_cond = face_kps if is_kps else None
89
-
90
  prompt = prompt.strip('"')
91
  if negative_prompt:
92
  negative_prompt = negative_prompt.strip('"')
93
-
94
  # 5. Generate Identity-Preserving Video
95
  generator = torch.Generator(device).manual_seed(seed) if seed else None
96
  video_pt = pipe(
@@ -105,12 +102,12 @@ def generate(
105
  generator=generator,
106
  id_vit_hidden=id_vit_hidden,
107
  id_cond=id_cond,
108
- kps_cond=kps_cond,
109
  output_type="pt",
110
  ).frames
111
-
112
  free_memory()
113
-
114
  if scale_status:
115
  video_pt = upscale_batch_and_concatenate(upscale_model, video_pt, device)
116
  if rife_status:
@@ -302,7 +299,7 @@ with gr.Blocks() as demo:
302
  seed=seed_value,
303
  scale_status=scale_status,
304
  rife_status=rife_status,
305
- )
306
 
307
  video_path = save_video(batch_video_frames[0], fps=math.ceil((len(batch_video_frames[0]) - 1) / 6))
308
  video_update = gr.update(visible=True, value=video_path)
@@ -311,14 +308,14 @@ with gr.Blocks() as demo:
311
  seed_update = gr.update(visible=True, value=seed)
312
 
313
  return video_path, video_update, gif_update, seed_update
314
-
315
  generate_button.click(
316
  fn=run,
317
  inputs=[prompt, negative_prompt, image_input, seed_param, enable_scale, enable_rife],
318
  outputs=[video_output, download_video_button, download_gif_button, seed_text],
319
  )
320
-
321
 
322
  if __name__ == "__main__":
323
  demo.queue(max_size=15)
324
- demo.launch()
 
 
1
  import math
2
+ import os
 
3
  import random
4
  import threading
5
+ import time
 
6
  from datetime import datetime, timedelta
 
7
 
8
+ import gradio as gr
9
+ import spaces
10
  import torch
11
+ from huggingface_hub import hf_hub_download, snapshot_download
12
+ from models.consisid_utils import prepare_face_models, process_face_embeddings_infer
13
+ from models.pipeline_consisid import ConsisIDPipeline
14
+ from moviepy import VideoFileClip
15
+ from util.rife_model import load_rife_model, rife_inference_with_latents
16
+ from util.utils import load_sd_upscale, save_video, upscale_batch_and_concatenate
17
+
18
  from diffusers.image_processor import VaeImageProcessor
19
  from diffusers.training_utils import free_memory
20
 
 
 
 
 
 
21
 
22
  # 0. Pre config
23
  model_path = "ckpts"
 
28
  device = "cuda" if torch.cuda.is_available() else "cpu"
29
 
30
  if not os.path.exists(model_path) or not os.path.exists(f"{model_path}/model_real_esran") or not os.path.exists(f"{model_path}/model_rife"):
31
+ print("Model not found, downloading from Hugging Face...")
32
  hf_hub_download(repo_id="ai-forever/Real-ESRGAN", filename="RealESRGAN_x4.pth", local_dir=f"{model_path}/model_real_esran")
33
  snapshot_download(repo_id="AlexWortega/RIFE", local_dir=f"{model_path}/model_rife")
34
  snapshot_download(repo_id="BestWishYsh/ConsisID-preview", local_dir=f"{model_path}")
35
  else:
36
  print(f"Model already exists in {model_path}, skipping download.")
37
+
38
 
39
  # 1. Prepare all the face models
40
  face_helper_1, face_helper_2, face_clip_model, face_main_model, eva_transform_mean, eva_transform_std = prepare_face_models(model_path, device, dtype)
 
79
  seed = random.randint(0, 2**8 - 1)
80
 
81
  # 4. Prepare model input
82
+ id_cond, id_vit_hidden, image, face_kps = process_face_embeddings_infer(face_helper_1, face_clip_model, face_helper_2,
83
+ eva_transform_mean, eva_transform_std,
84
+ face_main_model, device, dtype,
85
  image_input, is_align_face=True)
86
 
 
 
 
87
  prompt = prompt.strip('"')
88
  if negative_prompt:
89
  negative_prompt = negative_prompt.strip('"')
90
+
91
  # 5. Generate Identity-Preserving Video
92
  generator = torch.Generator(device).manual_seed(seed) if seed else None
93
  video_pt = pipe(
 
102
  generator=generator,
103
  id_vit_hidden=id_vit_hidden,
104
  id_cond=id_cond,
105
+ kps_cond=face_kps,
106
  output_type="pt",
107
  ).frames
108
+
109
  free_memory()
110
+
111
  if scale_status:
112
  video_pt = upscale_batch_and_concatenate(upscale_model, video_pt, device)
113
  if rife_status:
 
299
  seed=seed_value,
300
  scale_status=scale_status,
301
  rife_status=rife_status,
302
+ )
303
 
304
  video_path = save_video(batch_video_frames[0], fps=math.ceil((len(batch_video_frames[0]) - 1) / 6))
305
  video_update = gr.update(visible=True, value=video_path)
 
308
  seed_update = gr.update(visible=True, value=seed)
309
 
310
  return video_path, video_update, gif_update, seed_update
311
+
312
  generate_button.click(
313
  fn=run,
314
  inputs=[prompt, negative_prompt, image_input, seed_param, enable_scale, enable_rife],
315
  outputs=[video_output, download_video_button, download_gif_button, seed_text],
316
  )
317
+
318
 
319
  if __name__ == "__main__":
320
  demo.queue(max_size=15)
321
+ demo.launch()
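A minimal usage sketch of the updated entry points (illustrative only): it assumes the ConsisID weights have already been downloaded to ./ckpts as in app.py above, and "input.png" is a hypothetical face image path.

    import torch
    from models.consisid_utils import prepare_face_models, process_face_embeddings_infer

    device = "cuda" if torch.cuda.is_available() else "cpu"
    dtype = torch.bfloat16

    # Load the face detection/parsing helpers and the EVA-CLIP visual encoder once.
    (face_helper_1, face_helper_2, face_clip_model,
     face_main_model, eva_mean, eva_std) = prepare_face_models("ckpts", device, dtype)

    # Extract the identity conditioning consumed by ConsisIDPipeline; face_kps is passed as kps_cond.
    id_cond, id_vit_hidden, image, face_kps = process_face_embeddings_infer(
        face_helper_1, face_clip_model, face_helper_2,
        eva_mean, eva_std, face_main_model, device, dtype,
        "input.png",  # hypothetical image path
        is_align_face=True,
    )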
models/{utils.py → consisid_utils.py} RENAMED
@@ -1,154 +1,47 @@
1
  import os
 
 
2
  import cv2
3
- import math
4
  import numpy as np
5
- from PIL import Image, ImageOps
6
-
7
  import torch
8
  from torchvision.transforms import InterpolationMode
9
  from torchvision.transforms.functional import normalize, resize
10
  from transformers import T5EncoderModel, T5Tokenizer
11
- from typing import List, Optional, Tuple, Union
12
  from diffusers.models.embeddings import get_3d_rotary_pos_embed
13
  from diffusers.pipelines.cogvideo.pipeline_cogvideox import get_resize_crop_region_for_grid
14
  from diffusers.utils import load_image
15
 
16
- import insightface
17
- from insightface.app import FaceAnalysis
18
- from facexlib.parsing import init_parsing_model
19
- from facexlib.utils.face_restoration_helper import FaceRestoreHelper
20
-
21
- from models.eva_clip import create_model_and_transforms
22
- from models.eva_clip.constants import OPENAI_DATASET_MEAN, OPENAI_DATASET_STD
23
- from models.eva_clip.utils_qformer import resize_numpy_image_long
24
-
25
- def tensor_to_pil(src_img_tensor):
26
- img = src_img_tensor.clone().detach()
27
- if img.dtype == torch.bfloat16:
28
- img = img.to(torch.float32)
29
- img = img.cpu().numpy()
30
- img = np.transpose(img, (1, 2, 0))
31
- img = img.astype(np.uint8)
32
- pil_image = Image.fromarray(img)
33
- return pil_image
34
-
35
-
36
- def _get_t5_prompt_embeds(
37
- tokenizer: T5Tokenizer,
38
- text_encoder: T5EncoderModel,
39
- prompt: Union[str, List[str]],
40
- num_videos_per_prompt: int = 1,
41
- max_sequence_length: int = 226,
42
- device: Optional[torch.device] = None,
43
- dtype: Optional[torch.dtype] = None,
44
- text_input_ids=None,
45
- ):
46
- prompt = [prompt] if isinstance(prompt, str) else prompt
47
- batch_size = len(prompt)
48
-
49
- if tokenizer is not None:
50
- text_inputs = tokenizer(
51
- prompt,
52
- padding="max_length",
53
- max_length=max_sequence_length,
54
- truncation=True,
55
- add_special_tokens=True,
56
- return_tensors="pt",
57
- )
58
- text_input_ids = text_inputs.input_ids
59
- else:
60
- if text_input_ids is None:
61
- raise ValueError("`text_input_ids` must be provided when the tokenizer is not specified.")
62
-
63
- prompt_embeds = text_encoder(text_input_ids.to(device))[0]
64
- prompt_embeds = prompt_embeds.to(dtype=dtype, device=device)
65
-
66
- # duplicate text embeddings for each generation per prompt, using mps friendly method
67
- _, seq_len, _ = prompt_embeds.shape
68
- prompt_embeds = prompt_embeds.repeat(1, num_videos_per_prompt, 1)
69
- prompt_embeds = prompt_embeds.view(batch_size * num_videos_per_prompt, seq_len, -1)
70
-
71
- return prompt_embeds
72
-
73
-
74
- def encode_prompt(
75
- tokenizer: T5Tokenizer,
76
- text_encoder: T5EncoderModel,
77
- prompt: Union[str, List[str]],
78
- num_videos_per_prompt: int = 1,
79
- max_sequence_length: int = 226,
80
- device: Optional[torch.device] = None,
81
- dtype: Optional[torch.dtype] = None,
82
- text_input_ids=None,
83
- ):
84
- prompt = [prompt] if isinstance(prompt, str) else prompt
85
- prompt_embeds = _get_t5_prompt_embeds(
86
- tokenizer,
87
- text_encoder,
88
- prompt=prompt,
89
- num_videos_per_prompt=num_videos_per_prompt,
90
- max_sequence_length=max_sequence_length,
91
- device=device,
92
- dtype=dtype,
93
- text_input_ids=text_input_ids,
94
- )
95
- return prompt_embeds
96
-
97
-
98
- def compute_prompt_embeddings(
99
- tokenizer, text_encoder, prompt, max_sequence_length, device, dtype, requires_grad: bool = False
100
- ):
101
- if requires_grad:
102
- prompt_embeds = encode_prompt(
103
- tokenizer,
104
- text_encoder,
105
- prompt,
106
- num_videos_per_prompt=1,
107
- max_sequence_length=max_sequence_length,
108
- device=device,
109
- dtype=dtype,
110
- )
111
- else:
112
- with torch.no_grad():
113
- prompt_embeds = encode_prompt(
114
- tokenizer,
115
- text_encoder,
116
- prompt,
117
- num_videos_per_prompt=1,
118
- max_sequence_length=max_sequence_length,
119
- device=device,
120
- dtype=dtype,
121
- )
122
- return prompt_embeds
123
 
 
 
 
 
124
 
125
- def prepare_rotary_positional_embeddings(
126
- height: int,
127
- width: int,
128
- num_frames: int,
129
- vae_scale_factor_spatial: int = 8,
130
- patch_size: int = 2,
131
- attention_head_dim: int = 64,
132
- device: Optional[torch.device] = None,
133
- base_height: int = 480,
134
- base_width: int = 720,
135
- ) -> Tuple[torch.Tensor, torch.Tensor]:
136
- grid_height = height // (vae_scale_factor_spatial * patch_size)
137
- grid_width = width // (vae_scale_factor_spatial * patch_size)
138
- base_size_width = base_width // (vae_scale_factor_spatial * patch_size)
139
- base_size_height = base_height // (vae_scale_factor_spatial * patch_size)
140
 
141
- grid_crops_coords = get_resize_crop_region_for_grid((grid_height, grid_width), base_size_width, base_size_height)
142
- freqs_cos, freqs_sin = get_3d_rotary_pos_embed(
143
- embed_dim=attention_head_dim,
144
- crops_coords=grid_crops_coords,
145
- grid_size=(grid_height, grid_width),
146
- temporal_size=num_frames,
147
- )
148
 
149
- freqs_cos = freqs_cos.to(device=device)
150
- freqs_sin = freqs_sin.to(device=device)
151
- return freqs_cos, freqs_sin
 
152
 
153
 
154
  def img2tensor(imgs, bgr2rgb=True, float32=True):
@@ -166,8 +59,8 @@ def img2tensor(imgs, bgr2rgb=True, float32=True):
166
 
167
  def _totensor(img, bgr2rgb, float32):
168
  if img.shape[2] == 3 and bgr2rgb:
169
- if img.dtype == 'float64':
170
- img = img.astype('float32')
171
  img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)
172
  img = torch.from_numpy(img.transpose(2, 0, 1))
173
  if float32:
@@ -180,55 +73,70 @@ def img2tensor(imgs, bgr2rgb=True, float32=True):
180
 
181
 
182
  def to_gray(img):
183
  x = 0.299 * img[:, 0:1] + 0.587 * img[:, 1:2] + 0.114 * img[:, 2:3]
184
  x = x.repeat(1, 3, 1, 1)
185
  return x
186
 
187
 
188
- def draw_kps(image_pil, kps, color_list=[(255,0,0), (0,255,0), (0,0,255), (255,255,0), (255,0,255)]):
189
- stickwidth = 4
190
- limbSeq = np.array([[0, 2], [1, 2], [3, 2], [4, 2]])
191
- kps = np.array(kps)
192
-
193
- w, h = image_pil.size
194
- out_img = np.zeros([h, w, 3])
195
-
196
- for i in range(len(limbSeq)):
197
- index = limbSeq[i]
198
- color = color_list[index[0]]
199
-
200
- x = kps[index][:, 0]
201
- y = kps[index][:, 1]
202
- length = ((x[0] - x[1]) ** 2 + (y[0] - y[1]) ** 2) ** 0.5
203
- angle = math.degrees(math.atan2(y[0] - y[1], x[0] - x[1]))
204
- polygon = cv2.ellipse2Poly((int(np.mean(x)), int(np.mean(y))), (int(length / 2), stickwidth), int(angle), 0, 360, 1)
205
- out_img = cv2.fillConvexPoly(out_img.copy(), polygon, color)
206
- out_img = (out_img * 0.6).astype(np.uint8)
207
-
208
- for idx_kp, kp in enumerate(kps):
209
- color = color_list[idx_kp]
210
- x, y = kp
211
- out_img = cv2.circle(out_img.copy(), (int(x), int(y)), 10, color, -1)
212
-
213
- out_img_pil = Image.fromarray(out_img.astype(np.uint8))
214
- return out_img_pil
215
-
216
-
217
- def process_face_embeddings(face_helper_1, clip_vision_model, face_helper_2, eva_transform_mean, eva_transform_std, app, device, weight_dtype, image, original_id_image=None, is_align_face=True):
218
  """
 
 
 
219
  Args:
220
- image: numpy rgb image, range [0, 255]
221
  """
 
222
  face_helper_1.clean_all()
223
- image_bgr = cv2.cvtColor(image, cv2.COLOR_RGB2BGR) # (724, 502, 3)
224
  # get antelopev2 embedding
225
  face_info = app.get(image_bgr)
226
  if len(face_info) > 0:
227
- face_info = sorted(face_info, key=lambda x: (x['bbox'][2] - x['bbox'][0]) * (x['bbox'][3] - x['bbox'][1]))[
228
  -1
229
  ] # only use the maximum face
230
- id_ante_embedding = face_info['embedding'] # (512,)
231
- face_kps = face_info['kps']
232
  else:
233
  id_ante_embedding = None
234
  face_kps = None
@@ -240,12 +148,12 @@ def process_face_embeddings(face_helper_1, clip_vision_model, face_helper_2, eva
240
  face_kps = face_helper_1.all_landmarks_5[0]
241
  face_helper_1.align_warp_face()
242
  if len(face_helper_1.cropped_faces) == 0:
243
- raise RuntimeError('facexlib align face fail')
244
  align_face = face_helper_1.cropped_faces[0] # (512, 512, 3) # RGB
245
 
246
  # incase insightface didn't detect face
247
  if id_ante_embedding is None:
248
- print('fail to detect face using insightface, extract embedding on align face')
249
  id_ante_embedding = face_helper_2.get_feat(align_face)
250
 
251
  id_ante_embedding = torch.from_numpy(id_ante_embedding).to(device, weight_dtype) # torch.Size([512])
@@ -271,33 +179,90 @@ def process_face_embeddings(face_helper_1, clip_vision_model, face_helper_2, eva
271
  return_face_features_image = return_face_features_image_2 = input
272
 
273
  # transform img before sending to eva-clip-vit
274
- face_features_image = resize(return_face_features_image, clip_vision_model.image_size,
275
- InterpolationMode.BICUBIC) # torch.Size([1, 3, 336, 336])
 
276
  face_features_image = normalize(face_features_image, eva_transform_mean, eva_transform_std)
277
- id_cond_vit, id_vit_hidden = clip_vision_model(face_features_image.to(weight_dtype), return_all_features=False, return_hidden=True, shuffle=False) # torch.Size([1, 768]), list(torch.Size([1, 577, 1024]))
 
 
278
  id_cond_vit_norm = torch.norm(id_cond_vit, 2, 1, True)
279
  id_cond_vit = torch.div(id_cond_vit, id_cond_vit_norm)
280
 
281
- id_cond = torch.cat([id_ante_embedding, id_cond_vit], dim=-1) # torch.Size([1, 512]), torch.Size([1, 768]) -> torch.Size([1, 1280])
282
-
283
- return id_cond, id_vit_hidden, return_face_features_image_2, face_kps # torch.Size([1, 1280]), list(torch.Size([1, 577, 1024]))
284
-
285
-
286
- def process_face_embeddings_infer(face_helper_1, clip_vision_model, face_helper_2, eva_transform_mean, eva_transform_std, app, device, weight_dtype, img_file_path, is_align_face=True):
287
  """
 
 
 
288
  Args:
289
- image: numpy rgb image, range [0, 255]
290
  """
 
 
291
  if isinstance(img_file_path, str):
292
  image = np.array(load_image(image=img_file_path).convert("RGB"))
293
- else:
294
  image = np.array(ImageOps.exif_transpose(Image.fromarray(img_file_path)).convert("RGB"))
295
-
 
296
  image = resize_numpy_image_long(image, 1024)
297
  original_id_image = image
298
 
299
- id_cond, id_vit_hidden, align_crop_face_image, face_kps = process_face_embeddings(face_helper_1, clip_vision_model, face_helper_2, eva_transform_mean, eva_transform_std, app, device, weight_dtype, image, original_id_image, is_align_face)
300
-
301
  tensor = align_crop_face_image.cpu().detach()
302
  tensor = tensor.squeeze()
303
  tensor = tensor.permute(1, 2, 0)
@@ -307,6 +272,7 @@ def process_face_embeddings_infer(face_helper_1, clip_vision_model, face_helper_
307
 
308
  return id_cond, id_vit_hidden, image, face_kps
309
 
 
310
  def prepare_face_models(model_path, device, dtype):
311
  """
312
  Prepare all face models for the facial recognition task.
@@ -329,21 +295,29 @@ def prepare_face_models(model_path, device, dtype):
329
  upscale_factor=1,
330
  face_size=512,
331
  crop_ratio=(1, 1),
332
- det_model='retinaface_resnet50',
333
- save_ext='png',
334
  device=device,
335
- model_rootpath=os.path.join(model_path, "face_encoder")
336
  )
337
  face_helper_1.face_parse = None
338
- face_helper_1.face_parse = init_parsing_model(model_name='bisenet', device=device, model_rootpath=os.path.join(model_path, "face_encoder"))
339
- face_helper_2 = insightface.model_zoo.get_model(f'{model_path}/face_encoder/models/antelopev2/glintr100.onnx', providers=['CUDAExecutionProvider'])
340
  face_helper_2.prepare(ctx_id=0)
341
 
342
  # get local facial extractor part 1
343
- model, _, _ = create_model_and_transforms('EVA02-CLIP-L-14-336', os.path.join(model_path, "face_encoder", "EVA02_CLIP_L_336_psz14_s6B.pt"), force_custom_clip=True)
344
  face_clip_model = model.visual
345
- eva_transform_mean = getattr(face_clip_model, 'image_mean', OPENAI_DATASET_MEAN)
346
- eva_transform_std = getattr(face_clip_model, 'image_std', OPENAI_DATASET_STD)
347
  if not isinstance(eva_transform_mean, (list, tuple)):
348
  eva_transform_mean = (eva_transform_mean,) * 3
349
  if not isinstance(eva_transform_std, (list, tuple)):
@@ -352,9 +326,11 @@ def prepare_face_models(model_path, device, dtype):
352
  eva_transform_std = eva_transform_std
353
 
354
  # get local facial extractor part 2
355
- face_main_model = FaceAnalysis(name='antelopev2', root=os.path.join(model_path, "face_encoder"), providers=['CUDAExecutionProvider'])
 
 
356
  face_main_model.prepare(ctx_id=0, det_size=(640, 640))
357
-
358
  # move face models to device
359
  face_helper_1.face_det.eval()
360
  face_helper_1.face_parse.eval()
@@ -362,5 +338,230 @@ def prepare_face_models(model_path, device, dtype):
362
  face_helper_1.face_det.to(device)
363
  face_helper_1.face_parse.to(device)
364
  face_clip_model.to(device, dtype=dtype)
365
-
366
- return face_helper_1, face_helper_2, face_clip_model, face_main_model, eva_transform_mean, eva_transform_std
1
  import os
2
+ from typing import List, Optional, Tuple, Union
3
+
4
  import cv2
5
+ import insightface
6
  import numpy as np
 
 
7
  import torch
8
+ from consisid_eva_clip import create_model_and_transforms
9
+ from consisid_eva_clip.constants import OPENAI_DATASET_MEAN, OPENAI_DATASET_STD
10
+ from facexlib.parsing import init_parsing_model
11
+ from facexlib.utils.face_restoration_helper import FaceRestoreHelper
12
+ from insightface.app import FaceAnalysis
13
+ from PIL import Image, ImageOps
14
  from torchvision.transforms import InterpolationMode
15
  from torchvision.transforms.functional import normalize, resize
16
  from transformers import T5EncoderModel, T5Tokenizer
17
+
18
  from diffusers.models.embeddings import get_3d_rotary_pos_embed
19
  from diffusers.pipelines.cogvideo.pipeline_cogvideox import get_resize_crop_region_for_grid
20
  from diffusers.utils import load_image
21
 
22
 
23
+ ###### pipeline ###
24
+ def resize_numpy_image_long(image, resize_long_edge=768):
25
+ """
26
+ Resize the input image to a specified long edge while maintaining aspect ratio.
27
 
28
+ Args:
29
+ image (numpy.ndarray): Input image (H x W x C or H x W).
30
+ resize_long_edge (int): The target size for the long edge of the image. Default is 768.
 
31
 
32
+ Returns:
33
+ numpy.ndarray: Resized image with the long edge matching `resize_long_edge`, while maintaining the aspect
34
+ ratio.
35
+ """
 
 
 
36
 
37
+ h, w = image.shape[:2]
38
+ if max(h, w) <= resize_long_edge:
39
+ return image
40
+ k = resize_long_edge / max(h, w)
41
+ h = int(h * k)
42
+ w = int(w * k)
43
+ image = cv2.resize(image, (w, h), interpolation=cv2.INTER_LANCZOS4)
44
+ return image
45
 
46
 
47
  def img2tensor(imgs, bgr2rgb=True, float32=True):
 
59
 
60
  def _totensor(img, bgr2rgb, float32):
61
  if img.shape[2] == 3 and bgr2rgb:
62
+ if img.dtype == "float64":
63
+ img = img.astype("float32")
64
  img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)
65
  img = torch.from_numpy(img.transpose(2, 0, 1))
66
  if float32:
 
73
 
74
 
75
  def to_gray(img):
76
+ """
77
+ Converts an RGB image to grayscale by applying the standard luminosity formula.
78
+
79
+ Args:
80
+ img (torch.Tensor): The input image tensor with shape (batch_size, channels, height, width).
81
+ The image is expected to be in RGB format (3 channels).
82
+
83
+ Returns:
84
+ torch.Tensor: The grayscale image tensor with shape (batch_size, 3, height, width).
85
+ The grayscale values are replicated across all three channels.
86
+ """
87
  x = 0.299 * img[:, 0:1] + 0.587 * img[:, 1:2] + 0.114 * img[:, 2:3]
88
  x = x.repeat(1, 3, 1, 1)
89
  return x
90
 
91
 
92
+ def process_face_embeddings(
93
+ face_helper_1,
94
+ clip_vision_model,
95
+ face_helper_2,
96
+ eva_transform_mean,
97
+ eva_transform_std,
98
+ app,
99
+ device,
100
+ weight_dtype,
101
+ image,
102
+ original_id_image=None,
103
+ is_align_face=True,
104
+ ):
105
  """
106
+ Process face embeddings from an image, extracting relevant features such as face embeddings, landmarks, and parsed
107
+ face features using a series of face detection and alignment tools.
108
+
109
  Args:
110
+ face_helper_1: Face helper object (first helper) for alignment and landmark detection.
111
+ clip_vision_model: Pre-trained CLIP vision model used for feature extraction.
112
+ face_helper_2: Face helper object (second helper) for embedding extraction.
113
+ eva_transform_mean: Mean values for image normalization before passing to EVA model.
114
+ eva_transform_std: Standard deviation values for image normalization before passing to EVA model.
115
+ app: Application instance used for face detection.
116
+ device: Device (CPU or GPU) where the computations will be performed.
117
+ weight_dtype: Data type of the weights for precision (e.g., `torch.float32`).
118
+ image: Input image in RGB format with pixel values in the range [0, 255].
119
+ original_id_image: (Optional) Original image for feature extraction if `is_align_face` is False.
120
+ is_align_face: Boolean flag indicating whether face alignment should be performed.
121
+
122
+ Returns:
123
+ Tuple:
124
+ - id_cond: Concatenated tensor of Ante face embedding and CLIP vision embedding
125
+ - id_vit_hidden: Hidden state of the CLIP vision model, a list of tensors.
126
+ - return_face_features_image_2: Processed face features image after normalization and parsing.
127
+ - face_kps: Keypoints of the face detected in the image.
128
  """
129
+
130
  face_helper_1.clean_all()
131
+ image_bgr = cv2.cvtColor(image, cv2.COLOR_RGB2BGR)
132
  # get antelopev2 embedding
133
  face_info = app.get(image_bgr)
134
  if len(face_info) > 0:
135
+ face_info = sorted(face_info, key=lambda x: (x["bbox"][2] - x["bbox"][0]) * (x["bbox"][3] - x["bbox"][1]))[
136
  -1
137
  ] # only use the maximum face
138
+ id_ante_embedding = face_info["embedding"] # (512,)
139
+ face_kps = face_info["kps"]
140
  else:
141
  id_ante_embedding = None
142
  face_kps = None
 
148
  face_kps = face_helper_1.all_landmarks_5[0]
149
  face_helper_1.align_warp_face()
150
  if len(face_helper_1.cropped_faces) == 0:
151
+ raise RuntimeError("facexlib align face fail")
152
  align_face = face_helper_1.cropped_faces[0] # (512, 512, 3) # RGB
153
 
154
  # incase insightface didn't detect face
155
  if id_ante_embedding is None:
156
+ print("fail to detect face using insightface, extract embedding on align face")
157
  id_ante_embedding = face_helper_2.get_feat(align_face)
158
 
159
  id_ante_embedding = torch.from_numpy(id_ante_embedding).to(device, weight_dtype) # torch.Size([512])
 
179
  return_face_features_image = return_face_features_image_2 = input
180
 
181
  # transform img before sending to eva-clip-vit
182
+ face_features_image = resize(
183
+ return_face_features_image, clip_vision_model.image_size, InterpolationMode.BICUBIC
184
+ ) # torch.Size([1, 3, 336, 336])
185
  face_features_image = normalize(face_features_image, eva_transform_mean, eva_transform_std)
186
+ id_cond_vit, id_vit_hidden = clip_vision_model(
187
+ face_features_image.to(weight_dtype), return_all_features=False, return_hidden=True, shuffle=False
188
+ ) # torch.Size([1, 768]), list(torch.Size([1, 577, 1024]))
189
  id_cond_vit_norm = torch.norm(id_cond_vit, 2, 1, True)
190
  id_cond_vit = torch.div(id_cond_vit, id_cond_vit_norm)
191
 
192
+ id_cond = torch.cat(
193
+ [id_ante_embedding, id_cond_vit], dim=-1
194
+ ) # torch.Size([1, 512]), torch.Size([1, 768]) -> torch.Size([1, 1280])
195
+
196
+ return (
197
+ id_cond,
198
+ id_vit_hidden,
199
+ return_face_features_image_2,
200
+ face_kps,
201
+ ) # torch.Size([1, 1280]), list(torch.Size([1, 577, 1024]))
202
+
203
+
204
+ def process_face_embeddings_infer(
205
+ face_helper_1,
206
+ clip_vision_model,
207
+ face_helper_2,
208
+ eva_transform_mean,
209
+ eva_transform_std,
210
+ app,
211
+ device,
212
+ weight_dtype,
213
+ img_file_path,
214
+ is_align_face=True,
215
+ ):
216
  """
217
+ Process face embeddings from an input image for inference, including alignment, feature extraction, and embedding
218
+ concatenation.
219
+
220
  Args:
221
+ face_helper_1: Face helper object (first helper) for alignment and landmark detection.
222
+ clip_vision_model: Pre-trained CLIP vision model used for feature extraction.
223
+ face_helper_2: Face helper object (second helper) for embedding extraction.
224
+ eva_transform_mean: Mean values for image normalization before passing to EVA model.
225
+ eva_transform_std: Standard deviation values for image normalization before passing to EVA model.
226
+ app: Application instance used for face detection.
227
+ device: Device (CPU or GPU) where the computations will be performed.
228
+ weight_dtype: Data type of the weights for precision (e.g., `torch.float32`).
229
+ img_file_path: Path to the input image file (string) or a numpy array representing an image.
230
+ is_align_face: Boolean flag indicating whether face alignment should be performed (default: True).
231
+
232
+ Returns:
233
+ Tuple:
234
+ - id_cond: Concatenated tensor of Ante face embedding and CLIP vision embedding.
235
+ - id_vit_hidden: Hidden state of the CLIP vision model, a list of tensors.
236
+ - image: Processed face image after feature extraction and alignment.
237
+ - face_kps: Keypoints of the face detected in the image.
238
  """
239
+
240
+ # Load and preprocess the input image
241
  if isinstance(img_file_path, str):
242
  image = np.array(load_image(image=img_file_path).convert("RGB"))
243
+ else:
244
  image = np.array(ImageOps.exif_transpose(Image.fromarray(img_file_path)).convert("RGB"))
245
+
246
+ # Resize image to ensure the longer side is 1024 pixels
247
  image = resize_numpy_image_long(image, 1024)
248
  original_id_image = image
249
 
250
+ # Process the image to extract face embeddings and related features
251
+ id_cond, id_vit_hidden, align_crop_face_image, face_kps = process_face_embeddings(
252
+ face_helper_1,
253
+ clip_vision_model,
254
+ face_helper_2,
255
+ eva_transform_mean,
256
+ eva_transform_std,
257
+ app,
258
+ device,
259
+ weight_dtype,
260
+ image,
261
+ original_id_image,
262
+ is_align_face,
263
+ )
264
+
265
+ # Convert the aligned cropped face image (torch tensor) to a numpy array
266
  tensor = align_crop_face_image.cpu().detach()
267
  tensor = tensor.squeeze()
268
  tensor = tensor.permute(1, 2, 0)
 
272
 
273
  return id_cond, id_vit_hidden, image, face_kps
274
 
275
+
276
  def prepare_face_models(model_path, device, dtype):
277
  """
278
  Prepare all face models for the facial recognition task.
 
295
  upscale_factor=1,
296
  face_size=512,
297
  crop_ratio=(1, 1),
298
+ det_model="retinaface_resnet50",
299
+ save_ext="png",
300
  device=device,
301
+ model_rootpath=os.path.join(model_path, "face_encoder"),
302
  )
303
  face_helper_1.face_parse = None
304
+ face_helper_1.face_parse = init_parsing_model(
305
+ model_name="bisenet", device=device, model_rootpath=os.path.join(model_path, "face_encoder")
306
+ )
307
+ face_helper_2 = insightface.model_zoo.get_model(
308
+ f"{model_path}/face_encoder/models/antelopev2/glintr100.onnx", providers=["CUDAExecutionProvider"]
309
+ )
310
  face_helper_2.prepare(ctx_id=0)
311
 
312
  # get local facial extractor part 1
313
+ model, _, _ = create_model_and_transforms(
314
+ "EVA02-CLIP-L-14-336",
315
+ os.path.join(model_path, "face_encoder", "EVA02_CLIP_L_336_psz14_s6B.pt"),
316
+ force_custom_clip=True,
317
+ )
318
  face_clip_model = model.visual
319
+ eva_transform_mean = getattr(face_clip_model, "image_mean", OPENAI_DATASET_MEAN)
320
+ eva_transform_std = getattr(face_clip_model, "image_std", OPENAI_DATASET_STD)
321
  if not isinstance(eva_transform_mean, (list, tuple)):
322
  eva_transform_mean = (eva_transform_mean,) * 3
323
  if not isinstance(eva_transform_std, (list, tuple)):
 
326
  eva_transform_std = eva_transform_std
327
 
328
  # get local facial extractor part 2
329
+ face_main_model = FaceAnalysis(
330
+ name="antelopev2", root=os.path.join(model_path, "face_encoder"), providers=["CUDAExecutionProvider"]
331
+ )
332
  face_main_model.prepare(ctx_id=0, det_size=(640, 640))
333
+
334
  # move face models to device
335
  face_helper_1.face_det.eval()
336
  face_helper_1.face_parse.eval()
 
338
  face_helper_1.face_det.to(device)
339
  face_helper_1.face_parse.to(device)
340
  face_clip_model.to(device, dtype=dtype)
341
+
342
+ return face_helper_1, face_helper_2, face_clip_model, face_main_model, eva_transform_mean, eva_transform_std
343
+
344
+
345
+
346
+ ###### train ###
347
+ def _get_t5_prompt_embeds(
348
+ tokenizer: T5Tokenizer,
349
+ text_encoder: T5EncoderModel,
350
+ prompt: Union[str, List[str]],
351
+ num_videos_per_prompt: int = 1,
352
+ max_sequence_length: int = 226,
353
+ device: Optional[torch.device] = None,
354
+ dtype: Optional[torch.dtype] = None,
355
+ text_input_ids=None,
356
+ ):
357
+ """
358
+ Generate prompt embeddings using the T5 model for a given prompt or list of prompts.
359
+
360
+ Args:
361
+ tokenizer (T5Tokenizer): Tokenizer used to encode the text prompt(s).
362
+ text_encoder (T5EncoderModel): Pretrained T5 encoder model to generate embeddings.
363
+ prompt (Union[str, List[str]]): Single prompt or list of prompts to encode.
364
+ num_videos_per_prompt (int, optional): Number of video embeddings to generate per prompt. Defaults to 1.
365
+ max_sequence_length (int, optional): Maximum length for the tokenized prompt. Defaults to 226.
366
+ device (Optional[torch.device], optional): The device on which to run the model (e.g., "cuda", "cpu").
367
+ dtype (Optional[torch.dtype], optional): The data type for the embeddings (e.g., torch.float32).
368
+ text_input_ids (optional): Pre-tokenized input IDs. If not provided, tokenizer is used to encode the prompt.
369
+
370
+ Returns:
371
+ torch.Tensor: The generated prompt embeddings reshaped for the specified number of video generations per prompt.
372
+ """
373
+
374
+ prompt = [prompt] if isinstance(prompt, str) else prompt
375
+ batch_size = len(prompt)
376
+
377
+ if tokenizer is not None:
378
+ text_inputs = tokenizer(
379
+ prompt,
380
+ padding="max_length",
381
+ max_length=max_sequence_length,
382
+ truncation=True,
383
+ add_special_tokens=True,
384
+ return_tensors="pt",
385
+ )
386
+ text_input_ids = text_inputs.input_ids
387
+ else:
388
+ if text_input_ids is None:
389
+ raise ValueError("`text_input_ids` must be provided when the tokenizer is not specified.")
390
+
391
+ prompt_embeds = text_encoder(text_input_ids.to(device))[0]
392
+ prompt_embeds = prompt_embeds.to(dtype=dtype, device=device)
393
+
394
+ # duplicate text embeddings for each generation per prompt, using mps friendly method
395
+ _, seq_len, _ = prompt_embeds.shape
396
+ prompt_embeds = prompt_embeds.repeat(1, num_videos_per_prompt, 1)
397
+ prompt_embeds = prompt_embeds.view(batch_size * num_videos_per_prompt, seq_len, -1)
398
+
399
+ return prompt_embeds
400
+
401
+
402
+ def encode_prompt(
403
+ tokenizer: T5Tokenizer,
404
+ text_encoder: T5EncoderModel,
405
+ prompt: Union[str, List[str]],
406
+ num_videos_per_prompt: int = 1,
407
+ max_sequence_length: int = 226,
408
+ device: Optional[torch.device] = None,
409
+ dtype: Optional[torch.dtype] = None,
410
+ text_input_ids=None,
411
+ ):
412
+ """
413
+ Encode the given prompt(s) into embeddings using the T5 model.
414
+
415
+ This function wraps the _get_t5_prompt_embeds function to generate prompt embeddings
416
+ for a given prompt or list of prompts. It allows for generating multiple embeddings
417
+ per prompt, useful for tasks like video generation.
418
+
419
+ Args:
420
+ tokenizer (T5Tokenizer): Tokenizer used to encode the text prompt(s).
421
+ text_encoder (T5EncoderModel): Pretrained T5 encoder model to generate embeddings.
422
+ prompt (Union[str, List[str]]): Single prompt or list of prompts to encode.
423
+ num_videos_per_prompt (int, optional): Number of video embeddings to generate per prompt. Defaults to 1.
424
+ max_sequence_length (int, optional): Maximum length for the tokenized prompt. Defaults to 226.
425
+ device (Optional[torch.device], optional): The device on which to run the model (e.g., "cuda", "cpu").
426
+ dtype (Optional[torch.dtype], optional): The data type for the embeddings (e.g., torch.float32).
427
+ text_input_ids (optional): Pre-tokenized input IDs. If not provided, tokenizer is used to encode the prompt.
428
+
429
+ Returns:
430
+ torch.Tensor: The generated prompt embeddings reshaped for the specified number of video generations per prompt.
431
+ """
432
+
433
+ prompt = [prompt] if isinstance(prompt, str) else prompt
434
+ prompt_embeds = _get_t5_prompt_embeds(
435
+ tokenizer,
436
+ text_encoder,
437
+ prompt=prompt,
438
+ num_videos_per_prompt=num_videos_per_prompt,
439
+ max_sequence_length=max_sequence_length,
440
+ device=device,
441
+ dtype=dtype,
442
+ text_input_ids=text_input_ids,
443
+ )
444
+ return prompt_embeds
445
+
446
+
447
+ def compute_prompt_embeddings(
448
+ tokenizer, text_encoder, prompt, max_sequence_length, device, dtype, requires_grad: bool = False
449
+ ):
450
+ """
451
+ Compute the prompt embeddings based on whether gradient computation is required.
452
+
453
+ This function generates embeddings for a given prompt or list of prompts, either
454
+ with or without gradient tracking, depending on the `requires_grad` argument. It
455
+ uses the `encode_prompt` function to generate embeddings for the provided prompt(s).
456
+
457
+ Args:
458
+ tokenizer (T5Tokenizer): Tokenizer used to encode the text prompt(s).
459
+ text_encoder (T5EncoderModel): Pretrained T5 encoder model to generate embeddings.
460
+ prompt (Union[str, List[str]]): Single prompt or list of prompts to encode.
461
+ max_sequence_length (int): Maximum length for the tokenized prompt.
462
+ device (torch.device): The device on which to run the model (e.g., "cuda", "cpu").
463
+ dtype (torch.dtype): The data type for the embeddings (e.g., torch.float32).
464
+ requires_grad (bool, optional): Whether the embeddings should require gradient computation. Defaults to False.
465
+
466
+ Returns:
467
+ torch.Tensor: The generated prompt embeddings.
468
+ """
469
+
470
+ if requires_grad:
471
+ prompt_embeds = encode_prompt(
472
+ tokenizer,
473
+ text_encoder,
474
+ prompt,
475
+ num_videos_per_prompt=1,
476
+ max_sequence_length=max_sequence_length,
477
+ device=device,
478
+ dtype=dtype,
479
+ )
480
+ else:
481
+ with torch.no_grad():
482
+ prompt_embeds = encode_prompt(
483
+ tokenizer,
484
+ text_encoder,
485
+ prompt,
486
+ num_videos_per_prompt=1,
487
+ max_sequence_length=max_sequence_length,
488
+ device=device,
489
+ dtype=dtype,
490
+ )
491
+ return prompt_embeds
492
+
493
+
494
+ def prepare_rotary_positional_embeddings(
495
+ height: int,
496
+ width: int,
497
+ num_frames: int,
498
+ vae_scale_factor_spatial: int = 8,
499
+ patch_size: int = 2,
500
+ attention_head_dim: int = 64,
501
+ device: Optional[torch.device] = None,
502
+ base_height: int = 480,
503
+ base_width: int = 720,
504
+ ) -> Tuple[torch.Tensor, torch.Tensor]:
505
+ """
506
+ Prepare rotary positional embeddings for a given input grid size and number of frames.
507
+
508
+ This function computes the rotary positional embeddings for both spatial and temporal dimensions
509
+ given the grid size (height, width) and the number of frames. It also takes into account the scaling
510
+ factors for the spatial resolution, as well as the patch size for the input.
511
+
512
+ Args:
513
+ height (int): Height of the input grid.
514
+ width (int): Width of the input grid.
515
+ num_frames (int): Number of frames in the temporal dimension.
516
+ vae_scale_factor_spatial (int, optional): Scaling factor for the spatial resolution. Defaults to 8.
517
+ patch_size (int, optional): The patch size used for the grid. Defaults to 2.
518
+ attention_head_dim (int, optional): The dimensionality of the attention head. Defaults to 64.
519
+ device (Optional[torch.device], optional): The device to which the tensors should be moved (e.g., "cuda", "cpu").
520
+ base_height (int, optional): Base height for the image, typically the full resolution height. Defaults to 480.
521
+ base_width (int, optional): Base width for the image, typically the full resolution width. Defaults to 720.
522
+
523
+ Returns:
524
+ Tuple[torch.Tensor, torch.Tensor]: Cosine and sine components of the rotary positional embeddings.
525
+ """
526
+ grid_height = height // (vae_scale_factor_spatial * patch_size)
527
+ grid_width = width // (vae_scale_factor_spatial * patch_size)
528
+ base_size_width = base_width // (vae_scale_factor_spatial * patch_size)
529
+ base_size_height = base_height // (vae_scale_factor_spatial * patch_size)
530
+
531
+ grid_crops_coords = get_resize_crop_region_for_grid((grid_height, grid_width), base_size_width, base_size_height)
532
+ freqs_cos, freqs_sin = get_3d_rotary_pos_embed(
533
+ embed_dim=attention_head_dim,
534
+ crops_coords=grid_crops_coords,
535
+ grid_size=(grid_height, grid_width),
536
+ temporal_size=num_frames,
537
+ )
538
+
539
+ freqs_cos = freqs_cos.to(device=device)
540
+ freqs_sin = freqs_sin.to(device=device)
541
+ return freqs_cos, freqs_sin
542
+
543
+
544
+ def tensor_to_pil(src_img_tensor):
545
+ """
546
+ Converts a tensor image to a PIL image.
547
+
548
+ This function takes an input tensor with the shape (C, H, W) and converts it
549
+ into a PIL Image format. It ensures that the tensor is in the correct data
550
+ type and moves it to CPU if necessary.
551
+
552
+ Parameters:
553
+ src_img_tensor (torch.Tensor): Input image tensor with shape (C, H, W),
554
+ where C is the number of channels, H is the height, and W is the width.
555
+
556
+ Returns:
557
+ PIL.Image: The converted image in PIL format.
558
+ """
559
+
560
+ img = src_img_tensor.clone().detach()
561
+ if img.dtype == torch.bfloat16:
562
+ img = img.to(torch.float32)
563
+ img = img.cpu().numpy()
564
+ img = np.transpose(img, (1, 2, 0))
565
+ img = img.astype(np.uint8)
566
+ pil_image = Image.fromarray(img)
567
+ return pil_image
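A minimal sketch of the training-side helpers grouped at the bottom of models/consisid_utils.py, assuming a diffusers-style checkpoint layout under ./ckpts with tokenizer/ and text_encoder/ subfolders (the paths and prompt are illustrative, not part of the commit).

    import torch
    from transformers import T5EncoderModel, T5Tokenizer

    from models.consisid_utils import compute_prompt_embeddings, prepare_rotary_positional_embeddings

    # Assumed layout: ./ckpts/tokenizer and ./ckpts/text_encoder.
    tokenizer = T5Tokenizer.from_pretrained("ckpts", subfolder="tokenizer")
    text_encoder = T5EncoderModel.from_pretrained("ckpts", subfolder="text_encoder")

    # Prompt embeddings without gradient tracking (requires_grad defaults to False, so encode_prompt runs under no_grad).
    prompt_embeds = compute_prompt_embeddings(
        tokenizer, text_encoder, "a person smiling at the camera",
        max_sequence_length=226, device=torch.device("cpu"), dtype=torch.float32,
    )

    # 3D rotary positional embeddings for a 480x720 clip with 13 latent frames.
    freqs_cos, freqs_sin = prepare_rotary_positional_embeddings(
        height=480, width=720, num_frames=13, device=torch.device("cpu"),
    )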
models/eva_clip/__init__.py DELETED
@@ -1,11 +0,0 @@
1
- from .constants import OPENAI_DATASET_MEAN, OPENAI_DATASET_STD
2
- from .factory import create_model, create_model_and_transforms, create_model_from_pretrained, get_tokenizer, create_transforms
3
- from .factory import list_models, add_model_config, get_model_config, load_checkpoint
4
- from .loss import ClipLoss
5
- from .model import CLIP, CustomCLIP, CLIPTextCfg, CLIPVisionCfg,\
6
- convert_weights_to_lp, convert_weights_to_fp16, trace_model, get_cast_dtype
7
- from .openai import load_openai_model, list_openai_models
8
- from .pretrained import list_pretrained, list_pretrained_models_by_tag, list_pretrained_tags_by_model,\
9
- get_pretrained_url, download_pretrained_from_url, is_pretrained_cfg, get_pretrained_cfg, download_pretrained
10
- from .tokenizer import SimpleTokenizer, tokenize
11
- from .transform import image_transform
models/eva_clip/bpe_simple_vocab_16e6.txt.gz DELETED
@@ -1,3 +0,0 @@
1
- version https://git-lfs.github.com/spec/v1
2
- oid sha256:924691ac288e54409236115652ad4aa250f48203de50a9e4722a6ecd48d6804a
3
- size 1356917
models/eva_clip/constants.py DELETED
@@ -1,2 +0,0 @@
1
- OPENAI_DATASET_MEAN = (0.48145466, 0.4578275, 0.40821073)
2
- OPENAI_DATASET_STD = (0.26862954, 0.26130258, 0.27577711)
 
 
 
models/eva_clip/eva_vit_model.py DELETED
@@ -1,548 +0,0 @@
1
- # --------------------------------------------------------
2
- # Adapted from https://github.com/microsoft/unilm/tree/master/beit
3
- # --------------------------------------------------------
4
- import math
5
- import os
6
- from functools import partial
7
- import torch
8
- import torch.nn as nn
9
- import torch.nn.functional as F
10
- try:
11
- from timm.models.layers import drop_path, to_2tuple, trunc_normal_
12
- except:
13
- from timm.layers import drop_path, to_2tuple, trunc_normal_
14
-
15
- from .transformer import PatchDropout
16
- from .rope import VisionRotaryEmbedding, VisionRotaryEmbeddingFast
17
-
18
- if os.getenv('ENV_TYPE') == 'deepspeed':
19
- try:
20
- from deepspeed.runtime.activation_checkpointing.checkpointing import checkpoint
21
- except:
22
- from torch.utils.checkpoint import checkpoint
23
- else:
24
- from torch.utils.checkpoint import checkpoint
25
-
26
- try:
27
- import xformers
28
- import xformers.ops as xops
29
- XFORMERS_IS_AVAILBLE = True
30
- except:
31
- XFORMERS_IS_AVAILBLE = False
32
-
33
- class DropPath(nn.Module):
34
- """Drop paths (Stochastic Depth) per sample (when applied in main path of residual blocks).
35
- """
36
- def __init__(self, drop_prob=None):
37
- super(DropPath, self).__init__()
38
- self.drop_prob = drop_prob
39
-
40
- def forward(self, x):
41
- return drop_path(x, self.drop_prob, self.training)
42
-
43
- def extra_repr(self) -> str:
44
- return 'p={}'.format(self.drop_prob)
45
-
46
-
47
- class Mlp(nn.Module):
48
- def __init__(
49
- self,
50
- in_features,
51
- hidden_features=None,
52
- out_features=None,
53
- act_layer=nn.GELU,
54
- norm_layer=nn.LayerNorm,
55
- drop=0.,
56
- subln=False,
57
-
58
- ):
59
- super().__init__()
60
- out_features = out_features or in_features
61
- hidden_features = hidden_features or in_features
62
- self.fc1 = nn.Linear(in_features, hidden_features)
63
- self.act = act_layer()
64
-
65
- self.ffn_ln = norm_layer(hidden_features) if subln else nn.Identity()
66
-
67
- self.fc2 = nn.Linear(hidden_features, out_features)
68
- self.drop = nn.Dropout(drop)
69
-
70
- def forward(self, x):
71
- x = self.fc1(x)
72
- x = self.act(x)
73
- # x = self.drop(x)
74
- # commit this for the orignal BERT implement
75
- x = self.ffn_ln(x)
76
-
77
- x = self.fc2(x)
78
- x = self.drop(x)
79
- return x
80
-
81
- class SwiGLU(nn.Module):
82
- def __init__(self, in_features, hidden_features=None, out_features=None, act_layer=nn.SiLU, drop=0.,
83
- norm_layer=nn.LayerNorm, subln=False):
84
- super().__init__()
85
- out_features = out_features or in_features
86
- hidden_features = hidden_features or in_features
87
-
88
- self.w1 = nn.Linear(in_features, hidden_features)
89
- self.w2 = nn.Linear(in_features, hidden_features)
90
-
91
- self.act = act_layer()
92
- self.ffn_ln = norm_layer(hidden_features) if subln else nn.Identity()
93
- self.w3 = nn.Linear(hidden_features, out_features)
94
-
95
- self.drop = nn.Dropout(drop)
96
-
97
- def forward(self, x):
98
- x1 = self.w1(x)
99
- x2 = self.w2(x)
100
- hidden = self.act(x1) * x2
101
- x = self.ffn_ln(hidden)
102
- x = self.w3(x)
103
- x = self.drop(x)
104
- return x
105
-
106
- class Attention(nn.Module):
107
- def __init__(
108
- self, dim, num_heads=8, qkv_bias=False, qk_scale=None, attn_drop=0.,
109
- proj_drop=0., window_size=None, attn_head_dim=None, xattn=False, rope=None, subln=False, norm_layer=nn.LayerNorm):
110
- super().__init__()
111
- self.num_heads = num_heads
112
- head_dim = dim // num_heads
113
- if attn_head_dim is not None:
114
- head_dim = attn_head_dim
115
- all_head_dim = head_dim * self.num_heads
116
- self.scale = qk_scale or head_dim ** -0.5
117
-
118
- self.subln = subln
119
- if self.subln:
120
- self.q_proj = nn.Linear(dim, all_head_dim, bias=False)
121
- self.k_proj = nn.Linear(dim, all_head_dim, bias=False)
122
- self.v_proj = nn.Linear(dim, all_head_dim, bias=False)
123
- else:
124
- self.qkv = nn.Linear(dim, all_head_dim * 3, bias=False)
125
-
126
- if qkv_bias:
127
- self.q_bias = nn.Parameter(torch.zeros(all_head_dim))
128
- self.v_bias = nn.Parameter(torch.zeros(all_head_dim))
129
- else:
130
- self.q_bias = None
131
- self.v_bias = None
132
-
133
- if window_size:
134
- self.window_size = window_size
135
- self.num_relative_distance = (2 * window_size[0] - 1) * (2 * window_size[1] - 1) + 3
136
- self.relative_position_bias_table = nn.Parameter(
137
- torch.zeros(self.num_relative_distance, num_heads)) # 2*Wh-1 * 2*Ww-1, nH
138
- # cls to token & token 2 cls & cls to cls
139
-
140
- # get pair-wise relative position index for each token inside the window
141
- coords_h = torch.arange(window_size[0])
142
- coords_w = torch.arange(window_size[1])
143
- coords = torch.stack(torch.meshgrid([coords_h, coords_w])) # 2, Wh, Ww
144
- coords_flatten = torch.flatten(coords, 1) # 2, Wh*Ww
145
- relative_coords = coords_flatten[:, :, None] - coords_flatten[:, None, :] # 2, Wh*Ww, Wh*Ww
146
- relative_coords = relative_coords.permute(1, 2, 0).contiguous() # Wh*Ww, Wh*Ww, 2
147
- relative_coords[:, :, 0] += window_size[0] - 1 # shift to start from 0
148
- relative_coords[:, :, 1] += window_size[1] - 1
149
- relative_coords[:, :, 0] *= 2 * window_size[1] - 1
150
- relative_position_index = \
151
- torch.zeros(size=(window_size[0] * window_size[1] + 1, ) * 2, dtype=relative_coords.dtype)
152
- relative_position_index[1:, 1:] = relative_coords.sum(-1) # Wh*Ww, Wh*Ww
153
- relative_position_index[0, 0:] = self.num_relative_distance - 3
154
- relative_position_index[0:, 0] = self.num_relative_distance - 2
155
- relative_position_index[0, 0] = self.num_relative_distance - 1
156
-
157
- self.register_buffer("relative_position_index", relative_position_index)
158
- else:
159
- self.window_size = None
160
- self.relative_position_bias_table = None
161
- self.relative_position_index = None
162
-
163
- self.attn_drop = nn.Dropout(attn_drop)
164
- self.inner_attn_ln = norm_layer(all_head_dim) if subln else nn.Identity()
165
- # self.proj = nn.Linear(all_head_dim, all_head_dim)
166
- self.proj = nn.Linear(all_head_dim, dim)
167
- self.proj_drop = nn.Dropout(proj_drop)
168
- self.xattn = xattn
169
- self.xattn_drop = attn_drop
170
-
171
- self.rope = rope
172
-
173
- def forward(self, x, rel_pos_bias=None, attn_mask=None):
174
- B, N, C = x.shape
175
- if self.subln:
176
- q = F.linear(input=x, weight=self.q_proj.weight, bias=self.q_bias)
177
- k = F.linear(input=x, weight=self.k_proj.weight, bias=None)
178
- v = F.linear(input=x, weight=self.v_proj.weight, bias=self.v_bias)
179
-
180
- q = q.reshape(B, N, self.num_heads, -1).permute(0, 2, 1, 3) # B, num_heads, N, C
181
- k = k.reshape(B, N, self.num_heads, -1).permute(0, 2, 1, 3)
182
- v = v.reshape(B, N, self.num_heads, -1).permute(0, 2, 1, 3)
183
- else:
184
-
185
- qkv_bias = None
186
- if self.q_bias is not None:
187
- qkv_bias = torch.cat((self.q_bias, torch.zeros_like(self.v_bias, requires_grad=False), self.v_bias))
188
-
189
- qkv = F.linear(input=x, weight=self.qkv.weight, bias=qkv_bias)
190
- qkv = qkv.reshape(B, N, 3, self.num_heads, -1).permute(2, 0, 3, 1, 4) # 3, B, num_heads, N, C
191
- q, k, v = qkv[0], qkv[1], qkv[2]
192
-
193
- if self.rope:
194
- # slightly fast impl
195
- q_t = q[:, :, 1:, :]
196
- ro_q_t = self.rope(q_t)
197
- q = torch.cat((q[:, :, :1, :], ro_q_t), -2).type_as(v)
198
-
199
- k_t = k[:, :, 1:, :]
200
- ro_k_t = self.rope(k_t)
201
- k = torch.cat((k[:, :, :1, :], ro_k_t), -2).type_as(v)
202
-
203
- if self.xattn:
204
- q = q.permute(0, 2, 1, 3) # B, num_heads, N, C -> B, N, num_heads, C
205
- k = k.permute(0, 2, 1, 3)
206
- v = v.permute(0, 2, 1, 3)
207
-
208
- x = xops.memory_efficient_attention(
209
- q, k, v,
210
- p=self.xattn_drop,
211
- scale=self.scale,
212
- )
213
- x = x.reshape(B, N, -1)
214
- x = self.inner_attn_ln(x)
215
- x = self.proj(x)
216
- x = self.proj_drop(x)
217
- else:
218
- q = q * self.scale
219
- attn = (q @ k.transpose(-2, -1))
220
-
221
- if self.relative_position_bias_table is not None:
222
- relative_position_bias = \
223
- self.relative_position_bias_table[self.relative_position_index.view(-1)].view(
224
- self.window_size[0] * self.window_size[1] + 1,
225
- self.window_size[0] * self.window_size[1] + 1, -1) # Wh*Ww,Wh*Ww,nH
226
- relative_position_bias = relative_position_bias.permute(2, 0, 1).contiguous() # nH, Wh*Ww, Wh*Ww
227
- attn = attn + relative_position_bias.unsqueeze(0).type_as(attn)
228
-
229
- if rel_pos_bias is not None:
230
- attn = attn + rel_pos_bias.type_as(attn)
231
-
232
- if attn_mask is not None:
233
- attn_mask = attn_mask.bool()
234
- attn = attn.masked_fill(~attn_mask[:, None, None, :], float("-inf"))
235
-
236
- attn = attn.softmax(dim=-1)
237
- attn = self.attn_drop(attn)
238
-
239
- x = (attn @ v).transpose(1, 2).reshape(B, N, -1)
240
- x = self.inner_attn_ln(x)
241
- x = self.proj(x)
242
- x = self.proj_drop(x)
243
- return x
244
-
245
-
246
- class Block(nn.Module):
247
-
248
- def __init__(self, dim, num_heads, mlp_ratio=4., qkv_bias=False, qk_scale=None, drop=0., attn_drop=0.,
249
- drop_path=0., init_values=None, act_layer=nn.GELU, norm_layer=nn.LayerNorm,
250
- window_size=None, attn_head_dim=None, xattn=False, rope=None, postnorm=False,
251
- subln=False, naiveswiglu=False):
252
- super().__init__()
253
- self.norm1 = norm_layer(dim)
254
- self.attn = Attention(
255
- dim, num_heads=num_heads, qkv_bias=qkv_bias, qk_scale=qk_scale,
256
- attn_drop=attn_drop, proj_drop=drop, window_size=window_size, attn_head_dim=attn_head_dim,
257
- xattn=xattn, rope=rope, subln=subln, norm_layer=norm_layer)
258
- # NOTE: drop path for stochastic depth, we shall see if this is better than dropout here
259
- self.drop_path = DropPath(drop_path) if drop_path > 0. else nn.Identity()
260
- self.norm2 = norm_layer(dim)
261
- mlp_hidden_dim = int(dim * mlp_ratio)
262
-
263
- if naiveswiglu:
264
- self.mlp = SwiGLU(
265
- in_features=dim,
266
- hidden_features=mlp_hidden_dim,
267
- subln=subln,
268
- norm_layer=norm_layer,
269
- )
270
- else:
271
- self.mlp = Mlp(
272
- in_features=dim,
273
- hidden_features=mlp_hidden_dim,
274
- act_layer=act_layer,
275
- subln=subln,
276
- drop=drop
277
- )
278
-
279
- if init_values is not None and init_values > 0:
280
- self.gamma_1 = nn.Parameter(init_values * torch.ones((dim)),requires_grad=True)
281
- self.gamma_2 = nn.Parameter(init_values * torch.ones((dim)),requires_grad=True)
282
- else:
283
- self.gamma_1, self.gamma_2 = None, None
284
-
285
- self.postnorm = postnorm
286
-
287
- def forward(self, x, rel_pos_bias=None, attn_mask=None):
288
- if self.gamma_1 is None:
289
- if self.postnorm:
290
- x = x + self.drop_path(self.norm1(self.attn(x, rel_pos_bias=rel_pos_bias, attn_mask=attn_mask)))
291
- x = x + self.drop_path(self.norm2(self.mlp(x)))
292
- else:
293
- x = x + self.drop_path(self.attn(self.norm1(x), rel_pos_bias=rel_pos_bias, attn_mask=attn_mask))
294
- x = x + self.drop_path(self.mlp(self.norm2(x)))
295
- else:
296
- if self.postnorm:
297
- x = x + self.drop_path(self.gamma_1 * self.norm1(self.attn(x, rel_pos_bias=rel_pos_bias, attn_mask=attn_mask)))
298
- x = x + self.drop_path(self.gamma_2 * self.norm2(self.mlp(x)))
299
- else:
300
- x = x + self.drop_path(self.gamma_1 * self.attn(self.norm1(x), rel_pos_bias=rel_pos_bias, attn_mask=attn_mask))
301
- x = x + self.drop_path(self.gamma_2 * self.mlp(self.norm2(x)))
302
- return x
303
-
304
-
305
- class PatchEmbed(nn.Module):
306
- """ Image to Patch Embedding
307
- """
308
- def __init__(self, img_size=224, patch_size=16, in_chans=3, embed_dim=768):
309
- super().__init__()
310
- img_size = to_2tuple(img_size)
311
- patch_size = to_2tuple(patch_size)
312
- num_patches = (img_size[1] // patch_size[1]) * (img_size[0] // patch_size[0])
313
- self.patch_shape = (img_size[0] // patch_size[0], img_size[1] // patch_size[1])
314
- self.img_size = img_size
315
- self.patch_size = patch_size
316
- self.num_patches = num_patches
317
-
318
- self.proj = nn.Conv2d(in_chans, embed_dim, kernel_size=patch_size, stride=patch_size)
319
-
320
- def forward(self, x, **kwargs):
321
- B, C, H, W = x.shape
322
- # FIXME look at relaxing size constraints
323
- assert H == self.img_size[0] and W == self.img_size[1], \
324
- f"Input image size ({H}*{W}) doesn't match model ({self.img_size[0]}*{self.img_size[1]})."
325
- x = self.proj(x).flatten(2).transpose(1, 2)
326
- return x
327
-
328
-
329
- class RelativePositionBias(nn.Module):
330
-
331
- def __init__(self, window_size, num_heads):
332
- super().__init__()
333
- self.window_size = window_size
334
- self.num_relative_distance = (2 * window_size[0] - 1) * (2 * window_size[1] - 1) + 3
335
- self.relative_position_bias_table = nn.Parameter(
336
- torch.zeros(self.num_relative_distance, num_heads)) # 2*Wh-1 * 2*Ww-1, nH
337
- # cls to token & token 2 cls & cls to cls
338
-
339
- # get pair-wise relative position index for each token inside the window
340
- coords_h = torch.arange(window_size[0])
341
- coords_w = torch.arange(window_size[1])
342
- coords = torch.stack(torch.meshgrid([coords_h, coords_w])) # 2, Wh, Ww
343
- coords_flatten = torch.flatten(coords, 1) # 2, Wh*Ww
344
- relative_coords = coords_flatten[:, :, None] - coords_flatten[:, None, :] # 2, Wh*Ww, Wh*Ww
345
- relative_coords = relative_coords.permute(1, 2, 0).contiguous() # Wh*Ww, Wh*Ww, 2
346
- relative_coords[:, :, 0] += window_size[0] - 1 # shift to start from 0
347
- relative_coords[:, :, 1] += window_size[1] - 1
348
- relative_coords[:, :, 0] *= 2 * window_size[1] - 1
349
- relative_position_index = \
350
- torch.zeros(size=(window_size[0] * window_size[1] + 1,) * 2, dtype=relative_coords.dtype)
351
- relative_position_index[1:, 1:] = relative_coords.sum(-1) # Wh*Ww, Wh*Ww
352
- relative_position_index[0, 0:] = self.num_relative_distance - 3
353
- relative_position_index[0:, 0] = self.num_relative_distance - 2
354
- relative_position_index[0, 0] = self.num_relative_distance - 1
355
-
356
- self.register_buffer("relative_position_index", relative_position_index)
357
-
358
- def forward(self):
359
- relative_position_bias = \
360
- self.relative_position_bias_table[self.relative_position_index.view(-1)].view(
361
- self.window_size[0] * self.window_size[1] + 1,
362
- self.window_size[0] * self.window_size[1] + 1, -1) # Wh*Ww,Wh*Ww,nH
363
- return relative_position_bias.permute(2, 0, 1).contiguous() # nH, Wh*Ww, Wh*Ww
364
-
365
-
366
- class EVAVisionTransformer(nn.Module):
367
- """ Vision Transformer with support for patch or hybrid CNN input stage
368
- """
369
- def __init__(self, img_size=224, patch_size=16, in_chans=3, num_classes=1000, embed_dim=768, depth=12,
370
- num_heads=12, mlp_ratio=4., qkv_bias=False, qk_scale=None, drop_rate=0., attn_drop_rate=0.,
371
- drop_path_rate=0., norm_layer=nn.LayerNorm, init_values=None, patch_dropout=0.,
372
- use_abs_pos_emb=True, use_rel_pos_bias=False, use_shared_rel_pos_bias=False, rope=False,
373
- use_mean_pooling=True, init_scale=0.001, grad_checkpointing=False, xattn=False, postnorm=False,
374
- pt_hw_seq_len=16, intp_freq=False, naiveswiglu=False, subln=False):
375
- super().__init__()
376
-
377
- if not XFORMERS_IS_AVAILBLE:
378
- xattn = False
379
-
380
- self.image_size = img_size
381
- self.num_classes = num_classes
382
- self.num_features = self.embed_dim = embed_dim # num_features for consistency with other models
383
-
384
- self.patch_embed = PatchEmbed(
385
- img_size=img_size, patch_size=patch_size, in_chans=in_chans, embed_dim=embed_dim)
386
- num_patches = self.patch_embed.num_patches
387
-
388
- self.cls_token = nn.Parameter(torch.zeros(1, 1, embed_dim))
389
- # self.mask_token = nn.Parameter(torch.zeros(1, 1, embed_dim))
390
- if use_abs_pos_emb:
391
- self.pos_embed = nn.Parameter(torch.zeros(1, num_patches + 1, embed_dim))
392
- else:
393
- self.pos_embed = None
394
- self.pos_drop = nn.Dropout(p=drop_rate)
395
-
396
- if use_shared_rel_pos_bias:
397
- self.rel_pos_bias = RelativePositionBias(window_size=self.patch_embed.patch_shape, num_heads=num_heads)
398
- else:
399
- self.rel_pos_bias = None
400
-
401
- if rope:
402
- half_head_dim = embed_dim // num_heads // 2
403
- hw_seq_len = img_size // patch_size
404
- self.rope = VisionRotaryEmbeddingFast(
405
- dim=half_head_dim,
406
- pt_seq_len=pt_hw_seq_len,
407
- ft_seq_len=hw_seq_len if intp_freq else None,
408
- # patch_dropout=patch_dropout
409
- )
410
- else:
411
- self.rope = None
412
-
413
- self.naiveswiglu = naiveswiglu
414
-
415
- dpr = [x.item() for x in torch.linspace(0, drop_path_rate, depth)] # stochastic depth decay rule
416
- self.use_rel_pos_bias = use_rel_pos_bias
417
- self.blocks = nn.ModuleList([
418
- Block(
419
- dim=embed_dim, num_heads=num_heads, mlp_ratio=mlp_ratio, qkv_bias=qkv_bias, qk_scale=qk_scale,
420
- drop=drop_rate, attn_drop=attn_drop_rate, drop_path=dpr[i], norm_layer=norm_layer,
421
- init_values=init_values, window_size=self.patch_embed.patch_shape if use_rel_pos_bias else None,
422
- xattn=xattn, rope=self.rope, postnorm=postnorm, subln=subln, naiveswiglu=naiveswiglu)
423
- for i in range(depth)])
424
- self.norm = nn.Identity() if use_mean_pooling else norm_layer(embed_dim)
425
- self.fc_norm = norm_layer(embed_dim) if use_mean_pooling else None
426
- self.head = nn.Linear(embed_dim, num_classes) if num_classes > 0 else nn.Identity()
427
-
428
- if self.pos_embed is not None:
429
- trunc_normal_(self.pos_embed, std=.02)
430
-
431
- trunc_normal_(self.cls_token, std=.02)
432
- # trunc_normal_(self.mask_token, std=.02)
433
-
434
- self.apply(self._init_weights)
435
- self.fix_init_weight()
436
-
437
- if isinstance(self.head, nn.Linear):
438
- trunc_normal_(self.head.weight, std=.02)
439
- self.head.weight.data.mul_(init_scale)
440
- self.head.bias.data.mul_(init_scale)
441
-
442
- # setting a patch_dropout of 0. would mean it is disabled and this function would be the identity fn
443
- self.patch_dropout = PatchDropout(patch_dropout) if patch_dropout > 0. else nn.Identity()
444
-
445
- self.grad_checkpointing = grad_checkpointing
446
-
447
- def fix_init_weight(self):
448
- def rescale(param, layer_id):
449
- param.div_(math.sqrt(2.0 * layer_id))
450
-
451
- for layer_id, layer in enumerate(self.blocks):
452
- rescale(layer.attn.proj.weight.data, layer_id + 1)
453
- if self.naiveswiglu:
454
- rescale(layer.mlp.w3.weight.data, layer_id + 1)
455
- else:
456
- rescale(layer.mlp.fc2.weight.data, layer_id + 1)
457
-
458
- def get_cast_dtype(self) -> torch.dtype:
459
- return self.blocks[0].mlp.fc2.weight.dtype
460
-
461
- def _init_weights(self, m):
462
- if isinstance(m, nn.Linear):
463
- trunc_normal_(m.weight, std=.02)
464
- if m.bias is not None:
465
- nn.init.constant_(m.bias, 0)
466
- elif isinstance(m, nn.LayerNorm):
467
- nn.init.constant_(m.bias, 0)
468
- nn.init.constant_(m.weight, 1.0)
469
-
470
- def get_num_layers(self):
471
- return len(self.blocks)
472
-
473
- def lock(self, unlocked_groups=0, freeze_bn_stats=False):
474
- assert unlocked_groups == 0, 'partial locking not currently supported for this model'
475
- for param in self.parameters():
476
- param.requires_grad = False
477
-
478
- @torch.jit.ignore
479
- def set_grad_checkpointing(self, enable=True):
480
- self.grad_checkpointing = enable
481
-
482
- @torch.jit.ignore
483
- def no_weight_decay(self):
484
- return {'pos_embed', 'cls_token'}
485
-
486
- def get_classifier(self):
487
- return self.head
488
-
489
- def reset_classifier(self, num_classes, global_pool=''):
490
- self.num_classes = num_classes
491
- self.head = nn.Linear(self.embed_dim, num_classes) if num_classes > 0 else nn.Identity()
492
-
493
- def forward_features(self, x, return_all_features=False, return_hidden=False, shuffle=False):
494
-
495
- x = self.patch_embed(x)
496
- batch_size, seq_len, _ = x.size()
497
-
498
- if shuffle:
499
- idx = torch.randperm(x.shape[1]) + 1
500
- zero = torch.LongTensor([0, ])
501
- idx = torch.cat([zero, idx])
502
- pos_embed = self.pos_embed[:, idx]
503
-
504
- cls_tokens = self.cls_token.expand(batch_size, -1, -1) # stole cls_tokens impl from Phil Wang, thanks
505
- x = torch.cat((cls_tokens, x), dim=1)
506
- if shuffle:
507
- x = x + pos_embed
508
- elif self.pos_embed is not None:
509
- x = x + self.pos_embed
510
- x = self.pos_drop(x)
511
-
512
- # a patch_dropout of 0. would mean it is disabled and this function would do nothing but return what was passed in
513
- if os.getenv('RoPE') == '1':
514
- if self.training and not isinstance(self.patch_dropout, nn.Identity):
515
- x, patch_indices_keep = self.patch_dropout(x)
516
- self.rope.forward = partial(self.rope.forward, patch_indices_keep=patch_indices_keep)
517
- else:
518
- self.rope.forward = partial(self.rope.forward, patch_indices_keep=None)
519
- x = self.patch_dropout(x)
520
- else:
521
- x = self.patch_dropout(x)
522
-
523
- rel_pos_bias = self.rel_pos_bias() if self.rel_pos_bias is not None else None
524
- hidden_states = []
525
- for idx, blk in enumerate(self.blocks):
526
- if (0 < idx <= 20) and (idx % 4 == 0) and return_hidden:
527
- hidden_states.append(x)
528
- if self.grad_checkpointing:
529
- x = checkpoint(blk, x, (rel_pos_bias,))
530
- else:
531
- x = blk(x, rel_pos_bias=rel_pos_bias)
532
-
533
- if not return_all_features:
534
- x = self.norm(x)
535
- if self.fc_norm is not None:
536
- return self.fc_norm(x.mean(1)), hidden_states
537
- else:
538
- return x[:, 0], hidden_states
539
- return x
540
-
541
- def forward(self, x, return_all_features=False, return_hidden=False, shuffle=False):
542
- if return_all_features:
543
- return self.forward_features(x, return_all_features, return_hidden, shuffle)
544
- x, hidden_states = self.forward_features(x, return_all_features, return_hidden, shuffle)
545
- x = self.head(x)
546
- if return_hidden:
547
- return x, hidden_states
548
- return x
 
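For reference, the vision tower deleted above exposes intermediate features through `return_hidden`; `forward_features` stashes the running token sequence at blocks 4, 8, 12, 16 and 20. A minimal usage sketch, assuming the module is still importable from an older revision as `models.eva_clip.eva_vit_model` and using an illustrative (not official) configuration:

```python
# Illustrative sketch only -- the import path and layer sizes are assumptions.
import torch
from models.eva_clip.eva_vit_model import EVAVisionTransformer

visual = EVAVisionTransformer(
    img_size=224, patch_size=14, in_chans=3, num_classes=768,
    embed_dim=1024, depth=24, num_heads=16, mlp_ratio=4.0,
)

images = torch.randn(2, 3, 224, 224)
# With depth=24, hidden states are collected at blocks 4, 8, 12, 16 and 20.
features, hidden_states = visual(images, return_all_features=False, return_hidden=True)
print(features.shape)       # torch.Size([2, 768]) -- output of the linear head
print(len(hidden_states))   # 5 intermediate token sequences
```
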
models/eva_clip/factory.py DELETED
@@ -1,517 +0,0 @@
1
- import json
2
- import logging
3
- import os
4
- import pathlib
5
- import re
6
- from copy import deepcopy
7
- from pathlib import Path
8
- from typing import Optional, Tuple, Union, Dict, Any
9
- import torch
10
-
11
- from .constants import OPENAI_DATASET_MEAN, OPENAI_DATASET_STD
12
- from .model import CLIP, CustomCLIP, convert_weights_to_lp, convert_to_custom_text_state_dict,\
13
- get_cast_dtype
14
- from .openai import load_openai_model
15
- from .pretrained import is_pretrained_cfg, get_pretrained_cfg, download_pretrained, list_pretrained_tags_by_model
16
- from .transform import image_transform
17
- from .tokenizer import HFTokenizer, tokenize
18
- from .utils import resize_clip_pos_embed, resize_evaclip_pos_embed, resize_visual_pos_embed, resize_eva_pos_embed
19
-
20
-
21
- _MODEL_CONFIG_PATHS = [Path(__file__).parent / f"model_configs/"]
22
- _MODEL_CONFIGS = {} # directory (model_name: config) of model architecture configs
23
-
24
-
25
- def _natural_key(string_):
26
- return [int(s) if s.isdigit() else s for s in re.split(r'(\d+)', string_.lower())]
27
-
28
-
29
- def _rescan_model_configs():
30
- global _MODEL_CONFIGS
31
-
32
- config_ext = ('.json',)
33
- config_files = []
34
- for config_path in _MODEL_CONFIG_PATHS:
35
- if config_path.is_file() and config_path.suffix in config_ext:
36
- config_files.append(config_path)
37
- elif config_path.is_dir():
38
- for ext in config_ext:
39
- config_files.extend(config_path.glob(f'*{ext}'))
40
-
41
- for cf in config_files:
42
- with open(cf, "r", encoding="utf8") as f:
43
- model_cfg = json.load(f)
44
- if all(a in model_cfg for a in ('embed_dim', 'vision_cfg', 'text_cfg')):
45
- _MODEL_CONFIGS[cf.stem] = model_cfg
46
-
47
- _MODEL_CONFIGS = dict(sorted(_MODEL_CONFIGS.items(), key=lambda x: _natural_key(x[0])))
48
-
49
-
50
- _rescan_model_configs() # initial populate of model config registry
51
-
52
-
53
- def list_models():
54
- """ enumerate available model architectures based on config files """
55
- return list(_MODEL_CONFIGS.keys())
56
-
57
-
58
- def add_model_config(path):
59
- """ add model config path or file and update registry """
60
- if not isinstance(path, Path):
61
- path = Path(path)
62
- _MODEL_CONFIG_PATHS.append(path)
63
- _rescan_model_configs()
64
-
65
-
66
- def get_model_config(model_name):
67
- if model_name in _MODEL_CONFIGS:
68
- return deepcopy(_MODEL_CONFIGS[model_name])
69
- else:
70
- return None
71
-
72
-
73
- def get_tokenizer(model_name):
74
- config = get_model_config(model_name)
75
- tokenizer = HFTokenizer(config['text_cfg']['hf_tokenizer_name']) if 'hf_tokenizer_name' in config['text_cfg'] else tokenize
76
- return tokenizer
77
-
78
-
79
- # loading openai CLIP weights when is_openai=True for training
80
- def load_state_dict(checkpoint_path: str, map_location: str='cpu', model_key: str='model|module|state_dict', is_openai: bool=False, skip_list: list=[]):
81
- if is_openai:
82
- model = torch.jit.load(checkpoint_path, map_location="cpu").eval()
83
- state_dict = model.state_dict()
84
- for key in ["input_resolution", "context_length", "vocab_size"]:
85
- state_dict.pop(key, None)
86
- else:
87
- checkpoint = torch.load(checkpoint_path, map_location=map_location)
88
- for mk in model_key.split('|'):
89
- if isinstance(checkpoint, dict) and mk in checkpoint:
90
- state_dict = checkpoint[mk]
91
- break
92
- else:
93
- state_dict = checkpoint
94
- if next(iter(state_dict.items()))[0].startswith('module'):
95
- state_dict = {k[7:]: v for k, v in state_dict.items()}
96
-
97
- for k in skip_list:
98
- if k in list(state_dict.keys()):
99
- logging.info(f"Removing key {k} from pretrained checkpoint")
100
- del state_dict[k]
101
-
102
- if os.getenv('RoPE') == '1':
103
- for k in list(state_dict.keys()):
104
- if 'freqs_cos' in k or 'freqs_sin' in k:
105
- del state_dict[k]
106
- return state_dict
107
-
108
-
109
-
110
- def load_checkpoint(model, checkpoint_path, model_key="model|module|state_dict", strict=True):
111
- state_dict = load_state_dict(checkpoint_path, model_key=model_key, is_openai=False)
112
- # detect old format and make compatible with new format
113
- if 'positional_embedding' in state_dict and not hasattr(model, 'positional_embedding'):
114
- state_dict = convert_to_custom_text_state_dict(state_dict)
115
- if 'text.logit_scale' in state_dict and hasattr(model, 'logit_scale'):
116
- state_dict['logit_scale'] = state_dict['text.logit_scale']
117
- del state_dict['text.logit_scale']
118
-
119
- # resize_clip_pos_embed for CLIP and open CLIP
120
- if 'visual.positional_embedding' in state_dict:
121
- resize_clip_pos_embed(state_dict, model)
122
- # specified to eva_vit_model
123
- elif 'visual.pos_embed' in state_dict:
124
- resize_evaclip_pos_embed(state_dict, model)
125
-
126
- # resize_clip_pos_embed(state_dict, model)
127
- incompatible_keys = model.load_state_dict(state_dict, strict=strict)
128
- logging.info(f"incompatible_keys.missing_keys: {incompatible_keys.missing_keys}")
129
- return incompatible_keys
130
-
131
- def load_clip_visual_state_dict(checkpoint_path: str, map_location: str='cpu', is_openai: bool=False, skip_list:list=[]):
132
- state_dict = load_state_dict(checkpoint_path, map_location=map_location, is_openai=is_openai, skip_list=skip_list)
133
-
134
- for k in list(state_dict.keys()):
135
- if not k.startswith('visual.'):
136
- del state_dict[k]
137
- for k in list(state_dict.keys()):
138
- if k.startswith('visual.'):
139
- new_k = k[7:]
140
- state_dict[new_k] = state_dict[k]
141
- del state_dict[k]
142
- return state_dict
143
-
144
- def load_clip_text_state_dict(checkpoint_path: str, map_location: str='cpu', is_openai: bool=False, skip_list:list=[]):
145
- state_dict = load_state_dict(checkpoint_path, map_location=map_location, is_openai=is_openai, skip_list=skip_list)
146
-
147
- for k in list(state_dict.keys()):
148
- if k.startswith('visual.'):
149
- del state_dict[k]
150
- return state_dict
151
-
152
- def get_pretrained_tag(pretrained_model):
153
- pretrained_model = pretrained_model.lower()
154
- if "laion" in pretrained_model or "open_clip" in pretrained_model:
155
- return "open_clip"
156
- elif "openai" in pretrained_model:
157
- return "clip"
158
- elif "eva" in pretrained_model and "clip" in pretrained_model:
159
- return "eva_clip"
160
- else:
161
- return "other"
162
-
163
- def load_pretrained_checkpoint(
164
- model,
165
- visual_checkpoint_path,
166
- text_checkpoint_path,
167
- strict=True,
168
- visual_model=None,
169
- text_model=None,
170
- model_key="model|module|state_dict",
171
- skip_list=[]):
172
- visual_tag = get_pretrained_tag(visual_model)
173
- text_tag = get_pretrained_tag(text_model)
174
-
175
- logging.info(f"num of model state_dict keys: {len(model.state_dict().keys())}")
176
- visual_incompatible_keys, text_incompatible_keys = None, None
177
- if visual_checkpoint_path:
178
- if visual_tag == "eva_clip" or visual_tag == "open_clip":
179
- visual_state_dict = load_clip_visual_state_dict(visual_checkpoint_path, is_openai=False, skip_list=skip_list)
180
- elif visual_tag == "clip":
181
- visual_state_dict = load_clip_visual_state_dict(visual_checkpoint_path, is_openai=True, skip_list=skip_list)
182
- else:
183
- visual_state_dict = load_state_dict(visual_checkpoint_path, model_key=model_key, is_openai=False, skip_list=skip_list)
184
-
185
- # resize_clip_pos_embed for CLIP and open CLIP
186
- if 'positional_embedding' in visual_state_dict:
187
- resize_visual_pos_embed(visual_state_dict, model)
188
- # specified to EVA model
189
- elif 'pos_embed' in visual_state_dict:
190
- resize_eva_pos_embed(visual_state_dict, model)
191
-
192
- visual_incompatible_keys = model.visual.load_state_dict(visual_state_dict, strict=strict)
193
- logging.info(f"num of loaded visual_state_dict keys: {len(visual_state_dict.keys())}")
194
- logging.info(f"visual_incompatible_keys.missing_keys: {visual_incompatible_keys.missing_keys}")
195
-
196
- if text_checkpoint_path:
197
- if text_tag == "eva_clip" or text_tag == "open_clip":
198
- text_state_dict = load_clip_text_state_dict(text_checkpoint_path, is_openai=False, skip_list=skip_list)
199
- elif text_tag == "clip":
200
- text_state_dict = load_clip_text_state_dict(text_checkpoint_path, is_openai=True, skip_list=skip_list)
201
- else:
202
- text_state_dict = load_state_dict(visual_checkpoint_path, model_key=model_key, is_openai=False, skip_list=skip_list)
203
-
204
- text_incompatible_keys = model.text.load_state_dict(text_state_dict, strict=strict)
205
-
206
- logging.info(f"num of loaded text_state_dict keys: {len(text_state_dict.keys())}")
207
- logging.info(f"text_incompatible_keys.missing_keys: {text_incompatible_keys.missing_keys}")
208
-
209
- return visual_incompatible_keys, text_incompatible_keys
210
-
211
- def create_model(
212
- model_name: str,
213
- pretrained: Optional[str] = None,
214
- precision: str = 'fp32',
215
- device: Union[str, torch.device] = 'cpu',
216
- jit: bool = False,
217
- force_quick_gelu: bool = False,
218
- force_custom_clip: bool = False,
219
- force_patch_dropout: Optional[float] = None,
220
- pretrained_image: str = '',
221
- pretrained_text: str = '',
222
- pretrained_hf: bool = True,
223
- pretrained_visual_model: str = None,
224
- pretrained_text_model: str = None,
225
- cache_dir: Optional[str] = None,
226
- skip_list: list = [],
227
- ):
228
- model_name = model_name.replace('/', '-') # for callers using old naming with / in ViT names
229
- if isinstance(device, str):
230
- device = torch.device(device)
231
-
232
- if pretrained and pretrained.lower() == 'openai':
233
- logging.info(f'Loading pretrained {model_name} from OpenAI.')
234
- model = load_openai_model(
235
- model_name,
236
- precision=precision,
237
- device=device,
238
- jit=jit,
239
- cache_dir=cache_dir,
240
- )
241
- else:
242
- model_cfg = get_model_config(model_name)
243
- if model_cfg is not None:
244
- logging.info(f'Loaded {model_name} model config.')
245
- else:
246
- logging.error(f'Model config for {model_name} not found; available models {list_models()}.')
247
- raise RuntimeError(f'Model config for {model_name} not found.')
248
-
249
- if 'rope' in model_cfg.get('vision_cfg', {}):
250
- if model_cfg['vision_cfg']['rope']:
251
- os.environ['RoPE'] = "1"
252
- else:
253
- os.environ['RoPE'] = "0"
254
-
255
- if force_quick_gelu:
256
- # override for use of QuickGELU on non-OpenAI transformer models
257
- model_cfg["quick_gelu"] = True
258
-
259
- if force_patch_dropout is not None:
260
- # override the default patch dropout value
261
- model_cfg['vision_cfg']["patch_dropout"] = force_patch_dropout
262
-
263
- cast_dtype = get_cast_dtype(precision)
264
- custom_clip = model_cfg.pop('custom_text', False) or force_custom_clip or ('hf_model_name' in model_cfg['text_cfg'])
265
-
266
-
267
- if custom_clip:
268
- if 'hf_model_name' in model_cfg.get('text_cfg', {}):
269
- model_cfg['text_cfg']['hf_model_pretrained'] = pretrained_hf
270
- model = CustomCLIP(**model_cfg, cast_dtype=cast_dtype)
271
- else:
272
- model = CLIP(**model_cfg, cast_dtype=cast_dtype)
273
-
274
- pretrained_cfg = {}
275
- if pretrained:
276
- checkpoint_path = ''
277
- pretrained_cfg = get_pretrained_cfg(model_name, pretrained)
278
- if pretrained_cfg:
279
- checkpoint_path = download_pretrained(pretrained_cfg, cache_dir=cache_dir)
280
- elif os.path.exists(pretrained):
281
- checkpoint_path = pretrained
282
-
283
- if checkpoint_path:
284
- logging.info(f'Loading pretrained {model_name} weights ({pretrained}).')
285
- load_checkpoint(model,
286
- checkpoint_path,
287
- model_key="model|module|state_dict",
288
- strict=False
289
- )
290
- else:
291
- error_str = (
292
- f'Pretrained weights ({pretrained}) not found for model {model_name}.'
293
- f'Available pretrained tags ({list_pretrained_tags_by_model(model_name)}.')
294
- logging.warning(error_str)
295
- raise RuntimeError(error_str)
296
- else:
297
- visual_checkpoint_path = ''
298
- text_checkpoint_path = ''
299
-
300
- if pretrained_image:
301
- pretrained_visual_model = pretrained_visual_model.replace('/', '-') # for callers using old naming with / in ViT names
302
- pretrained_image_cfg = get_pretrained_cfg(pretrained_visual_model, pretrained_image)
303
- if 'timm_model_name' in model_cfg.get('vision_cfg', {}):
304
- # pretrained weight loading for timm models set via vision_cfg
305
- model_cfg['vision_cfg']['timm_model_pretrained'] = True
306
- elif pretrained_image_cfg:
307
- visual_checkpoint_path = download_pretrained(pretrained_image_cfg, cache_dir=cache_dir)
308
- elif os.path.exists(pretrained_image):
309
- visual_checkpoint_path = pretrained_image
310
- else:
311
- logging.warning(f'Pretrained weights ({visual_checkpoint_path}) not found for model {model_name}.visual.')
312
- raise RuntimeError(f'Pretrained weights ({visual_checkpoint_path}) not found for model {model_name}.visual.')
313
-
314
- if pretrained_text:
315
- pretrained_text_model = pretrained_text_model.replace('/', '-') # for callers using old naming with / in ViT names
316
- pretrained_text_cfg = get_pretrained_cfg(pretrained_text_model, pretrained_text)
317
- if pretrained_image_cfg:
318
- text_checkpoint_path = download_pretrained(pretrained_text_cfg, cache_dir=cache_dir)
319
- elif os.path.exists(pretrained_text):
320
- text_checkpoint_path = pretrained_text
321
- else:
322
- logging.warning(f'Pretrained weights ({text_checkpoint_path}) not found for model {model_name}.text.')
323
- raise RuntimeError(f'Pretrained weights ({text_checkpoint_path}) not found for model {model_name}.text.')
324
-
325
- if visual_checkpoint_path:
326
- logging.info(f'Loading pretrained {model_name}.visual weights ({visual_checkpoint_path}).')
327
- if text_checkpoint_path:
328
- logging.info(f'Loading pretrained {model_name}.text weights ({text_checkpoint_path}).')
329
-
330
- if visual_checkpoint_path or text_checkpoint_path:
331
- load_pretrained_checkpoint(
332
- model,
333
- visual_checkpoint_path,
334
- text_checkpoint_path,
335
- strict=False,
336
- visual_model=pretrained_visual_model,
337
- text_model=pretrained_text_model,
338
- model_key="model|module|state_dict",
339
- skip_list=skip_list
340
- )
341
-
342
- if "fp16" in precision or "bf16" in precision:
343
- logging.info(f'convert precision to {precision}')
344
- model = model.to(torch.bfloat16) if 'bf16' in precision else model.to(torch.float16)
345
-
346
- model.to(device=device)
347
-
348
- # set image / mean metadata from pretrained_cfg if available, or use default
349
- model.visual.image_mean = pretrained_cfg.get('mean', None) or OPENAI_DATASET_MEAN
350
- model.visual.image_std = pretrained_cfg.get('std', None) or OPENAI_DATASET_STD
351
-
352
- if jit:
353
- model = torch.jit.script(model)
354
-
355
- return model
356
-
357
-
358
- def create_model_and_transforms(
359
- model_name: str,
360
- pretrained: Optional[str] = None,
361
- precision: str = 'fp32',
362
- device: Union[str, torch.device] = 'cpu',
363
- jit: bool = False,
364
- force_quick_gelu: bool = False,
365
- force_custom_clip: bool = False,
366
- force_patch_dropout: Optional[float] = None,
367
- pretrained_image: str = '',
368
- pretrained_text: str = '',
369
- pretrained_hf: bool = True,
370
- pretrained_visual_model: str = None,
371
- pretrained_text_model: str = None,
372
- image_mean: Optional[Tuple[float, ...]] = None,
373
- image_std: Optional[Tuple[float, ...]] = None,
374
- cache_dir: Optional[str] = None,
375
- skip_list: list = [],
376
- ):
377
- model = create_model(
378
- model_name,
379
- pretrained,
380
- precision=precision,
381
- device=device,
382
- jit=jit,
383
- force_quick_gelu=force_quick_gelu,
384
- force_custom_clip=force_custom_clip,
385
- force_patch_dropout=force_patch_dropout,
386
- pretrained_image=pretrained_image,
387
- pretrained_text=pretrained_text,
388
- pretrained_hf=pretrained_hf,
389
- pretrained_visual_model=pretrained_visual_model,
390
- pretrained_text_model=pretrained_text_model,
391
- cache_dir=cache_dir,
392
- skip_list=skip_list,
393
- )
394
-
395
- image_mean = image_mean or getattr(model.visual, 'image_mean', None)
396
- image_std = image_std or getattr(model.visual, 'image_std', None)
397
- preprocess_train = image_transform(
398
- model.visual.image_size,
399
- is_train=True,
400
- mean=image_mean,
401
- std=image_std
402
- )
403
- preprocess_val = image_transform(
404
- model.visual.image_size,
405
- is_train=False,
406
- mean=image_mean,
407
- std=image_std
408
- )
409
-
410
- return model, preprocess_train, preprocess_val
411
-
412
-
413
- def create_transforms(
414
- model_name: str,
415
- pretrained: Optional[str] = None,
416
- precision: str = 'fp32',
417
- device: Union[str, torch.device] = 'cpu',
418
- jit: bool = False,
419
- force_quick_gelu: bool = False,
420
- force_custom_clip: bool = False,
421
- force_patch_dropout: Optional[float] = None,
422
- pretrained_image: str = '',
423
- pretrained_text: str = '',
424
- pretrained_hf: bool = True,
425
- pretrained_visual_model: str = None,
426
- pretrained_text_model: str = None,
427
- image_mean: Optional[Tuple[float, ...]] = None,
428
- image_std: Optional[Tuple[float, ...]] = None,
429
- cache_dir: Optional[str] = None,
430
- skip_list: list = [],
431
- ):
432
- model = create_model(
433
- model_name,
434
- pretrained,
435
- precision=precision,
436
- device=device,
437
- jit=jit,
438
- force_quick_gelu=force_quick_gelu,
439
- force_custom_clip=force_custom_clip,
440
- force_patch_dropout=force_patch_dropout,
441
- pretrained_image=pretrained_image,
442
- pretrained_text=pretrained_text,
443
- pretrained_hf=pretrained_hf,
444
- pretrained_visual_model=pretrained_visual_model,
445
- pretrained_text_model=pretrained_text_model,
446
- cache_dir=cache_dir,
447
- skip_list=skip_list,
448
- )
449
-
450
-
451
- image_mean = image_mean or getattr(model.visual, 'image_mean', None)
452
- image_std = image_std or getattr(model.visual, 'image_std', None)
453
- preprocess_train = image_transform(
454
- model.visual.image_size,
455
- is_train=True,
456
- mean=image_mean,
457
- std=image_std
458
- )
459
- preprocess_val = image_transform(
460
- model.visual.image_size,
461
- is_train=False,
462
- mean=image_mean,
463
- std=image_std
464
- )
465
- del model
466
-
467
- return preprocess_train, preprocess_val
468
-
469
- def create_model_from_pretrained(
470
- model_name: str,
471
- pretrained: str,
472
- precision: str = 'fp32',
473
- device: Union[str, torch.device] = 'cpu',
474
- jit: bool = False,
475
- force_quick_gelu: bool = False,
476
- force_custom_clip: bool = False,
477
- force_patch_dropout: Optional[float] = None,
478
- return_transform: bool = True,
479
- image_mean: Optional[Tuple[float, ...]] = None,
480
- image_std: Optional[Tuple[float, ...]] = None,
481
- cache_dir: Optional[str] = None,
482
- is_frozen: bool = False,
483
- ):
484
- if not is_pretrained_cfg(model_name, pretrained) and not os.path.exists(pretrained):
485
- raise RuntimeError(
486
- f'{pretrained} is not a valid pretrained cfg or checkpoint for {model_name}.'
487
- f' Use open_clip.list_pretrained() to find one.')
488
-
489
- model = create_model(
490
- model_name,
491
- pretrained,
492
- precision=precision,
493
- device=device,
494
- jit=jit,
495
- force_quick_gelu=force_quick_gelu,
496
- force_custom_clip=force_custom_clip,
497
- force_patch_dropout=force_patch_dropout,
498
- cache_dir=cache_dir,
499
- )
500
-
501
- if is_frozen:
502
- for param in model.parameters():
503
- param.requires_grad = False
504
-
505
- if not return_transform:
506
- return model
507
-
508
- image_mean = image_mean or getattr(model.visual, 'image_mean', None)
509
- image_std = image_std or getattr(model.visual, 'image_std', None)
510
- preprocess = image_transform(
511
- model.visual.image_size,
512
- is_train=False,
513
- mean=image_mean,
514
- std=image_std
515
- )
516
-
517
- return model, preprocess
 
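For context, `create_model_and_transforms` above was the usual entry point for building an EVA-CLIP model together with its train/val preprocessing. A minimal sketch, assuming a matching JSON config exists under `model_configs/`; the model name and checkpoint path are placeholders:

```python
# Illustrative sketch only -- model name and checkpoint path are placeholders.
import torch
from models.eva_clip.factory import create_model_and_transforms, get_tokenizer

model, preprocess_train, preprocess_val = create_model_and_transforms(
    "EVA02-CLIP-L-14-336",                      # must match a file in model_configs/
    pretrained="/path/to/EVA02_CLIP_L_336.pt",  # local checkpoint or known pretrained tag
    device="cuda" if torch.cuda.is_available() else "cpu",
    force_custom_clip=True,                     # route through the CustomCLIP branch
)
tokenizer = get_tokenizer("EVA02-CLIP-L-14-336")
```
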
models/eva_clip/hf_configs.py DELETED
@@ -1,57 +0,0 @@
1
- # HF architecture dict:
2
- arch_dict = {
3
- # https://huggingface.co/docs/transformers/model_doc/roberta#roberta
4
- "roberta": {
5
- "config_names": {
6
- "context_length": "max_position_embeddings",
7
- "vocab_size": "vocab_size",
8
- "width": "hidden_size",
9
- "heads": "num_attention_heads",
10
- "layers": "num_hidden_layers",
11
- "layer_attr": "layer",
12
- "token_embeddings_attr": "embeddings"
13
- },
14
- "pooler": "mean_pooler",
15
- },
16
- # https://huggingface.co/docs/transformers/model_doc/xlm-roberta#transformers.XLMRobertaConfig
17
- "xlm-roberta": {
18
- "config_names": {
19
- "context_length": "max_position_embeddings",
20
- "vocab_size": "vocab_size",
21
- "width": "hidden_size",
22
- "heads": "num_attention_heads",
23
- "layers": "num_hidden_layers",
24
- "layer_attr": "layer",
25
- "token_embeddings_attr": "embeddings"
26
- },
27
- "pooler": "mean_pooler",
28
- },
29
- # https://huggingface.co/docs/transformers/model_doc/mt5#mt5
30
- "mt5": {
31
- "config_names": {
32
- # unlimited seqlen
33
- # https://github.com/google-research/text-to-text-transfer-transformer/issues/273
34
- # https://github.com/huggingface/transformers/blob/v4.24.0/src/transformers/models/t5/modeling_t5.py#L374
35
- "context_length": "",
36
- "vocab_size": "vocab_size",
37
- "width": "d_model",
38
- "heads": "num_heads",
39
- "layers": "num_layers",
40
- "layer_attr": "block",
41
- "token_embeddings_attr": "embed_tokens"
42
- },
43
- "pooler": "mean_pooler",
44
- },
45
- "bert": {
46
- "config_names": {
47
- "context_length": "max_position_embeddings",
48
- "vocab_size": "vocab_size",
49
- "width": "hidden_size",
50
- "heads": "num_attention_heads",
51
- "layers": "num_hidden_layers",
52
- "layer_attr": "layer",
53
- "token_embeddings_attr": "embeddings"
54
- },
55
- "pooler": "mean_pooler",
56
- }
57
- }
 
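The `arch_dict` table above only maps each Hugging Face architecture's config attribute names onto a common vocabulary; `hf_model.py` below looks values up through it. A small sketch of that lookup (the `roberta-base` checkpoint is just an example):

```python
# Illustrative sketch only -- resolves backbone dimensions through arch_dict.
from transformers import AutoConfig
from models.eva_clip.hf_configs import arch_dict

config = AutoConfig.from_pretrained("roberta-base")   # config.model_type == "roberta"
names = arch_dict[config.model_type]["config_names"]
width = getattr(config, names["width"])     # hidden_size  -> 768
layers = getattr(config, names["layers"])   # num_hidden_layers -> 12
pooler = arch_dict[config.model_type]["pooler"]       # "mean_pooler"
print(width, layers, pooler)
```
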
models/eva_clip/hf_model.py DELETED
@@ -1,248 +0,0 @@
1
- """ huggingface model adapter
2
-
3
- Wraps HuggingFace transformers (https://github.com/huggingface/transformers) models for use as a text tower in CLIP model.
4
- """
5
-
6
- import re
7
-
8
- import torch
9
- import torch.nn as nn
10
- from torch.nn import functional as F
11
- from torch import TensorType
12
- try:
13
- import transformers
14
- from transformers import AutoModel, AutoModelForMaskedLM, AutoTokenizer, AutoConfig, PretrainedConfig
15
- from transformers.modeling_outputs import BaseModelOutput, BaseModelOutputWithPooling, \
16
- BaseModelOutputWithPoolingAndCrossAttentions
17
- except ImportError as e:
18
- transformers = None
19
-
20
-
21
- class BaseModelOutput:
22
- pass
23
-
24
-
25
- class PretrainedConfig:
26
- pass
27
-
28
- from .hf_configs import arch_dict
29
-
30
- # utils
31
- def _camel2snake(s):
32
- return re.sub(r'(?<!^)(?=[A-Z])', '_', s).lower()
33
-
34
- # TODO: ?last - for gpt-like models
35
- _POOLERS = {}
36
-
37
- def register_pooler(cls):
38
- """Decorator registering pooler class"""
39
- _POOLERS[_camel2snake(cls.__name__)] = cls
40
- return cls
41
-
42
-
43
- @register_pooler
44
- class MeanPooler(nn.Module):
45
- """Mean pooling"""
46
- def forward(self, x:BaseModelOutput, attention_mask:TensorType):
47
- masked_output = x.last_hidden_state * attention_mask.unsqueeze(-1)
48
- return masked_output.sum(dim=1) / attention_mask.sum(-1, keepdim=True)
49
-
50
- @register_pooler
51
- class MaxPooler(nn.Module):
52
- """Max pooling"""
53
- def forward(self, x:BaseModelOutput, attention_mask:TensorType):
54
- masked_output = x.last_hidden_state.masked_fill(attention_mask.unsqueeze(-1), -torch.inf)
55
- return masked_output.max(1).values
56
-
57
- @register_pooler
58
- class ClsPooler(nn.Module):
59
- """CLS token pooling"""
60
- def __init__(self, use_pooler_output=True):
61
- super().__init__()
62
- self.cls_token_position = 0
63
- self.use_pooler_output = use_pooler_output
64
-
65
- def forward(self, x:BaseModelOutput, attention_mask:TensorType):
66
-
67
- if (self.use_pooler_output and
68
- isinstance(x, (BaseModelOutputWithPooling, BaseModelOutputWithPoolingAndCrossAttentions)) and
69
- (x.pooler_output is not None)
70
- ):
71
- return x.pooler_output
72
-
73
- return x.last_hidden_state[:, self.cls_token_position, :]
74
-
75
- class HFTextEncoder(nn.Module):
76
- """HuggingFace model adapter"""
77
- def __init__(
78
- self,
79
- model_name_or_path: str,
80
- output_dim: int,
81
- tokenizer_name: str = None,
82
- config: PretrainedConfig = None,
83
- pooler_type: str = None,
84
- proj: str = None,
85
- pretrained: bool = True,
86
- masked_language_modeling: bool = False):
87
- super().__init__()
88
-
89
- self.output_dim = output_dim
90
-
91
- # TODO: find better way to get this information
92
- uses_transformer_pooler = (pooler_type == "cls_pooler")
93
-
94
- if transformers is None:
95
- raise RuntimeError("Please `pip install transformers` to use pre-trained HuggingFace models")
96
- if config is None:
97
- self.config = AutoConfig.from_pretrained(model_name_or_path)
98
- if masked_language_modeling:
99
- create_func, model_args = (AutoModelForMaskedLM.from_pretrained, model_name_or_path) if pretrained else (
100
- AutoModelForMaskedLM.from_config, self.config)
101
- else:
102
- create_func, model_args = (AutoModel.from_pretrained, model_name_or_path) if pretrained else (
103
- AutoModel.from_config, self.config)
104
- # TODO: do all model configs have this attribute? PretrainedConfig does so yes??
105
- if hasattr(self.config, "is_encoder_decoder") and self.config.is_encoder_decoder:
106
- self.transformer = create_func(model_args)
107
- self.transformer = self.transformer.encoder
108
- else:
109
- self.transformer = create_func(model_args, add_pooling_layer=uses_transformer_pooler)
110
- else:
111
- self.config = config
112
- if masked_language_modeling:
113
- self.transformer = AutoModelForMaskedLM.from_config(config)
114
- else:
115
- self.transformer = AutoModel.from_config(config)
116
-
117
- if pooler_type is None: # get default arch pooler
118
- self.pooler = _POOLERS[(arch_dict[self.config.model_type]["pooler"])]()
119
- else:
120
- self.pooler = _POOLERS[pooler_type]()
121
-
122
- d_model = getattr(self.config, arch_dict[self.config.model_type]["config_names"]["width"])
123
- if (d_model == output_dim) and (proj is None): # do we always need a proj?
124
- self.proj = nn.Identity()
125
- elif proj == 'linear':
126
- self.proj = nn.Linear(d_model, output_dim, bias=False)
127
- elif proj == 'mlp':
128
- hidden_size = (d_model + output_dim) // 2
129
- self.proj = nn.Sequential(
130
- nn.Linear(d_model, hidden_size, bias=False),
131
- nn.GELU(),
132
- nn.Linear(hidden_size, output_dim, bias=False),
133
- )
134
-
135
- # self.itm_proj = nn.Linear(d_model, 2, bias=False)
136
- # self.mlm_proj = nn.Linear(d_model, self.config.vocab_size), bias=False)
137
- self.tokenizer = AutoTokenizer.from_pretrained(tokenizer_name)
138
-
139
- # def forward_itm(self, x:TensorType, image_embeds:TensorType) -> TensorType:
140
- # image_atts = torch.ones(image_embeds.size()[:-1],dtype=torch.long).to(x.device)
141
- # attn_mask = (x != self.config.pad_token_id).long()
142
- # out = self.transformer(
143
- # input_ids=x,
144
- # attention_mask=attn_mask,
145
- # encoder_hidden_states = image_embeds,
146
- # encoder_attention_mask = image_atts,
147
- # )
148
- # pooled_out = self.pooler(out, attn_mask)
149
-
150
- # return self.itm_proj(pooled_out)
151
-
152
- def mask(self, input_ids, vocab_size, device, targets=None, masked_indices=None, probability_matrix=None):
153
- if masked_indices is None:
154
- masked_indices = torch.bernoulli(probability_matrix).bool()
155
-
156
- masked_indices[input_ids == self.tokenizer.pad_token_id] = False
157
- masked_indices[input_ids == self.tokenizer.cls_token_id] = False
158
-
159
- if targets is not None:
160
- targets[~masked_indices] = -100 # We only compute loss on masked tokens
161
-
162
- # 80% of the time, we replace masked input tokens with tokenizer.mask_token ([MASK])
163
- indices_replaced = torch.bernoulli(torch.full(input_ids.shape, 0.8)).bool() & masked_indices
164
- input_ids[indices_replaced] = self.tokenizer.mask_token_id
165
-
166
- # 10% of the time, we replace masked input tokens with random word
167
- indices_random = torch.bernoulli(torch.full(input_ids.shape, 0.5)).bool() & masked_indices & ~indices_replaced
168
- random_words = torch.randint(vocab_size, input_ids.shape, dtype=torch.long).to(device)
169
- input_ids[indices_random] = random_words[indices_random]
170
- # The rest of the time (10% of the time) we keep the masked input tokens unchanged
171
-
172
- if targets is not None:
173
- return input_ids, targets
174
- else:
175
- return input_ids
176
-
177
- def forward_mlm(self, input_ids, image_embeds, mlm_probability=0.25):
178
- labels = input_ids.clone()
179
- attn_mask = (input_ids != self.config.pad_token_id).long()
180
- image_atts = torch.ones(image_embeds.size()[:-1],dtype=torch.long).to(input_ids.device)
181
- vocab_size = getattr(self.config, arch_dict[self.config.model_type]["config_names"]["vocab_size"])
182
- probability_matrix = torch.full(labels.shape, mlm_probability)
183
- input_ids, labels = self.mask(input_ids, vocab_size, input_ids.device, targets=labels,
184
- probability_matrix = probability_matrix)
185
- mlm_output = self.transformer(input_ids,
186
- attention_mask = attn_mask,
187
- encoder_hidden_states = image_embeds,
188
- encoder_attention_mask = image_atts,
189
- return_dict = True,
190
- labels = labels,
191
- )
192
- return mlm_output.loss
193
- # mlm_output = self.transformer(input_ids,
194
- # attention_mask = attn_mask,
195
- # encoder_hidden_states = image_embeds,
196
- # encoder_attention_mask = image_atts,
197
- # return_dict = True,
198
- # ).last_hidden_state
199
- # logits = self.mlm_proj(mlm_output)
200
-
201
- # # logits = logits[:, :-1, :].contiguous().view(-1, vocab_size)
202
- # logits = logits[:, 1:, :].contiguous().view(-1, vocab_size)
203
- # labels = labels[:, 1:].contiguous().view(-1)
204
-
205
- # mlm_loss = F.cross_entropy(
206
- # logits,
207
- # labels,
208
- # # label_smoothing=0.1,
209
- # )
210
- # return mlm_loss
211
-
212
-
213
- def forward(self, x:TensorType) -> TensorType:
214
- attn_mask = (x != self.config.pad_token_id).long()
215
- out = self.transformer(input_ids=x, attention_mask=attn_mask)
216
- pooled_out = self.pooler(out, attn_mask)
217
-
218
- return self.proj(pooled_out)
219
-
220
- def lock(self, unlocked_layers:int=0, freeze_layer_norm:bool=True):
221
- if not unlocked_layers: # full freezing
222
- for n, p in self.transformer.named_parameters():
223
- p.requires_grad = (not freeze_layer_norm) if "LayerNorm" in n.split(".") else False
224
- return
225
-
226
- encoder = self.transformer.encoder if hasattr(self.transformer, 'encoder') else self.transformer
227
- layer_list = getattr(encoder, arch_dict[self.config.model_type]["config_names"]["layer_attr"])
228
- print(f"Unlocking {unlocked_layers}/{len(layer_list) + 1} layers of hf model")
229
- embeddings = getattr(
230
- self.transformer, arch_dict[self.config.model_type]["config_names"]["token_embeddings_attr"])
231
- modules = [embeddings, *layer_list][:-unlocked_layers]
232
- # freeze layers
233
- for module in modules:
234
- for n, p in module.named_parameters():
235
- p.requires_grad = (not freeze_layer_norm) if "LayerNorm" in n.split(".") else False
236
-
237
-
238
- @torch.jit.ignore
239
- def set_grad_checkpointing(self, enable=True):
240
- self.transformer.gradient_checkpointing_enable()
241
-
242
- def get_num_layers(self):
243
- encoder = self.transformer.encoder if hasattr(self.transformer, 'encoder') else self.transformer
244
- layer_list = getattr(encoder, arch_dict[self.config.model_type]["config_names"]["layer_attr"])
245
- return len(layer_list)
246
-
247
- def init_parameters(self):
248
- pass
 
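A minimal sketch of using the `HFTextEncoder` adapter above as a standalone text tower; the backbone name, tokenizer name and output dimension are placeholder assumptions:

```python
# Illustrative sketch only -- backbone/tokenizer names and output_dim are placeholders.
import torch
from models.eva_clip.hf_model import HFTextEncoder

text_tower = HFTextEncoder(
    "xlm-roberta-base",
    output_dim=768,
    tokenizer_name="xlm-roberta-base",
    pooler_type="mean_pooler",
    proj="linear",
    pretrained=True,
)

token_ids = text_tower.tokenizer(
    ["a photo of a cat", "a photo of a dog"],
    return_tensors="pt", padding=True,
).input_ids

with torch.no_grad():
    text_features = text_tower(token_ids)   # mean-pooled, projected to (2, 768)
```
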
models/eva_clip/loss.py DELETED
@@ -1,138 +0,0 @@
1
- import math
2
- import torch
3
- import torch.nn as nn
4
- from torch.nn import functional as F
5
-
6
- try:
7
- import torch.distributed.nn
8
- from torch import distributed as dist
9
- has_distributed = True
10
- except ImportError:
11
- has_distributed = False
12
-
13
- try:
14
- import horovod.torch as hvd
15
- except ImportError:
16
- hvd = None
17
-
18
- from timm.loss import LabelSmoothingCrossEntropy
19
-
20
-
21
- def gather_features(
22
- image_features,
23
- text_features,
24
- local_loss=False,
25
- gather_with_grad=False,
26
- rank=0,
27
- world_size=1,
28
- use_horovod=False
29
- ):
30
- assert has_distributed, 'torch.distributed did not import correctly, please use a PyTorch version with support.'
31
- if use_horovod:
32
- assert hvd is not None, 'Please install horovod'
33
- if gather_with_grad:
34
- all_image_features = hvd.allgather(image_features)
35
- all_text_features = hvd.allgather(text_features)
36
- else:
37
- with torch.no_grad():
38
- all_image_features = hvd.allgather(image_features)
39
- all_text_features = hvd.allgather(text_features)
40
- if not local_loss:
41
- # ensure grads for local rank when all_* features don't have a gradient
42
- gathered_image_features = list(all_image_features.chunk(world_size, dim=0))
43
- gathered_text_features = list(all_text_features.chunk(world_size, dim=0))
44
- gathered_image_features[rank] = image_features
45
- gathered_text_features[rank] = text_features
46
- all_image_features = torch.cat(gathered_image_features, dim=0)
47
- all_text_features = torch.cat(gathered_text_features, dim=0)
48
- else:
49
- # We gather tensors from all gpus
50
- if gather_with_grad:
51
- all_image_features = torch.cat(torch.distributed.nn.all_gather(image_features), dim=0)
52
- all_text_features = torch.cat(torch.distributed.nn.all_gather(text_features), dim=0)
53
- # all_image_features = torch.cat(torch.distributed.nn.all_gather(image_features, async_op=True), dim=0)
54
- # all_text_features = torch.cat(torch.distributed.nn.all_gather(text_features, async_op=True), dim=0)
55
- else:
56
- gathered_image_features = [torch.zeros_like(image_features) for _ in range(world_size)]
57
- gathered_text_features = [torch.zeros_like(text_features) for _ in range(world_size)]
58
- dist.all_gather(gathered_image_features, image_features)
59
- dist.all_gather(gathered_text_features, text_features)
60
- if not local_loss:
61
- # ensure grads for local rank when all_* features don't have a gradient
62
- gathered_image_features[rank] = image_features
63
- gathered_text_features[rank] = text_features
64
- all_image_features = torch.cat(gathered_image_features, dim=0)
65
- all_text_features = torch.cat(gathered_text_features, dim=0)
66
-
67
- return all_image_features, all_text_features
68
-
69
-
70
- class ClipLoss(nn.Module):
71
-
72
- def __init__(
73
- self,
74
- local_loss=False,
75
- gather_with_grad=False,
76
- cache_labels=False,
77
- rank=0,
78
- world_size=1,
79
- use_horovod=False,
80
- smoothing=0.,
81
- ):
82
- super().__init__()
83
- self.local_loss = local_loss
84
- self.gather_with_grad = gather_with_grad
85
- self.cache_labels = cache_labels
86
- self.rank = rank
87
- self.world_size = world_size
88
- self.use_horovod = use_horovod
89
- self.label_smoothing_cross_entropy = LabelSmoothingCrossEntropy(smoothing=smoothing) if smoothing > 0 else None
90
-
91
- # cache state
92
- self.prev_num_logits = 0
93
- self.labels = {}
94
-
95
- def forward(self, image_features, text_features, logit_scale=1.):
96
- device = image_features.device
97
- if self.world_size > 1:
98
- all_image_features, all_text_features = gather_features(
99
- image_features, text_features,
100
- self.local_loss, self.gather_with_grad, self.rank, self.world_size, self.use_horovod)
101
-
102
- if self.local_loss:
103
- logits_per_image = logit_scale * image_features @ all_text_features.T
104
- logits_per_text = logit_scale * text_features @ all_image_features.T
105
- else:
106
- logits_per_image = logit_scale * all_image_features @ all_text_features.T
107
- logits_per_text = logits_per_image.T
108
- else:
109
- logits_per_image = logit_scale * image_features @ text_features.T
110
- logits_per_text = logit_scale * text_features @ image_features.T
111
- # calculated ground-truth and cache if enabled
112
- num_logits = logits_per_image.shape[0]
113
- if self.prev_num_logits != num_logits or device not in self.labels:
114
- labels = torch.arange(num_logits, device=device, dtype=torch.long)
115
- if self.world_size > 1 and self.local_loss:
116
- labels = labels + num_logits * self.rank
117
- if self.cache_labels:
118
- self.labels[device] = labels
119
- self.prev_num_logits = num_logits
120
- else:
121
- labels = self.labels[device]
122
-
123
- if self.label_smoothing_cross_entropy:
124
- total_loss = (
125
- self.label_smoothing_cross_entropy(logits_per_image, labels) +
126
- self.label_smoothing_cross_entropy(logits_per_text, labels)
127
- ) / 2
128
- else:
129
- total_loss = (
130
- F.cross_entropy(logits_per_image, labels) +
131
- F.cross_entropy(logits_per_text, labels)
132
- ) / 2
133
-
134
- acc = None
135
- i2t_acc = (logits_per_image.argmax(-1) == labels).sum() / len(logits_per_image)
136
- t2i_acc = (logits_per_text.argmax(-1) == labels).sum() / len(logits_per_text)
137
- acc = {"i2t": i2t_acc, "t2i": t2i_acc}
138
- return total_loss, acc
 
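A minimal single-process sketch of the `ClipLoss` above (`world_size=1`, so no feature gathering); the feature tensors are random placeholders and `logit_scale` stands in for `model.logit_scale.exp()`:

```python
# Illustrative sketch only -- random features stand in for real image/text embeddings.
import torch
import torch.nn.functional as F
from models.eva_clip.loss import ClipLoss

loss_fn = ClipLoss(local_loss=False, world_size=1, rank=0)

image_features = F.normalize(torch.randn(8, 512), dim=-1)
text_features = F.normalize(torch.randn(8, 512), dim=-1)
logit_scale = torch.tensor(100.0)   # stands in for model.logit_scale.exp()

total_loss, acc = loss_fn(image_features, text_features, logit_scale)
print(total_loss.item(), acc["i2t"].item(), acc["t2i"].item())
```
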
models/eva_clip/model.py DELETED
@@ -1,439 +0,0 @@
1
- """ CLIP Model
2
-
3
- Adapted from https://github.com/openai/CLIP. Originally MIT License, Copyright (c) 2021 OpenAI.
4
- """
5
- import os
6
- from dataclasses import dataclass
7
- from typing import Optional, Tuple, Union
8
- from functools import partial
9
-
10
- import numpy as np
11
- import torch
12
- import torch.nn.functional as F
13
- from torch import nn
14
-
15
- try:
16
- from .hf_model import HFTextEncoder
17
- except:
18
- HFTextEncoder = None
19
- from .modified_resnet import ModifiedResNet
20
- from .timm_model import TimmModel
21
- from .eva_vit_model import EVAVisionTransformer
22
- from .transformer import LayerNorm, QuickGELU, Attention, VisionTransformer, TextTransformer
23
-
24
- try:
25
- from apex.normalization import FusedLayerNorm
26
- except:
27
- FusedLayerNorm = LayerNorm
28
- print("Please 'pip install apex'")
29
-
30
- try:
31
- import xformers.ops as xops
32
- except ImportError:
33
- xops = None
34
- print("Please 'pip install xformers'")
35
-
36
- @dataclass
37
- class CLIPVisionCfg:
38
- layers: Union[Tuple[int, int, int, int], int] = 12
39
- width: int = 768
40
- head_width: int = 64
41
- mlp_ratio: float = 4.0
42
- patch_size: int = 16
43
- image_size: Union[Tuple[int, int], int] = 224
44
- ls_init_value: Optional[float] = None # layer scale initial value
45
- patch_dropout: float = 0. # what fraction of patches to dropout during training (0 would mean disabled and no patches dropped) - 0.5 to 0.75 recommended in the paper for optimal results
46
- global_average_pool: bool = False # whether to global average pool the last embedding layer, instead of using CLS token (https://arxiv.org/abs/2205.01580)
47
- drop_path_rate: Optional[float] = None # drop path rate
48
- timm_model_name: str = None # a valid model name overrides layers, width, patch_size
49
- timm_model_pretrained: bool = False # use (imagenet) pretrained weights for named model
50
- timm_pool: str = 'avg' # feature pooling for timm model ('abs_attn', 'rot_attn', 'avg', '')
51
- timm_proj: str = 'linear' # linear projection for timm model output ('linear', 'mlp', '')
52
- timm_proj_bias: bool = False # enable bias final projection
53
- eva_model_name: str = None # a valid eva model name overrides layers, width, patch_size
54
- qkv_bias: bool = True
55
- fusedLN: bool = False
56
- xattn: bool = False
57
- postnorm: bool = False
58
- rope: bool = False
59
- pt_hw_seq_len: int = 16 # 224/14
60
- intp_freq: bool = False
61
- naiveswiglu: bool = False
62
- subln: bool = False
63
-
64
-
65
- @dataclass
66
- class CLIPTextCfg:
67
- context_length: int = 77
68
- vocab_size: int = 49408
69
- width: int = 512
70
- heads: int = 8
71
- layers: int = 12
72
- ls_init_value: Optional[float] = None # layer scale initial value
73
- hf_model_name: str = None
74
- hf_tokenizer_name: str = None
75
- hf_model_pretrained: bool = True
76
- proj: str = 'mlp'
77
- pooler_type: str = 'mean_pooler'
78
- masked_language_modeling: bool = False
79
- fusedLN: bool = False
80
- xattn: bool = False
81
- attn_mask: bool = True
82
-
83
- def get_cast_dtype(precision: str):
84
- cast_dtype = None
85
- if precision == 'bf16':
86
- cast_dtype = torch.bfloat16
87
- elif precision == 'fp16':
88
- cast_dtype = torch.float16
89
- return cast_dtype
90
-
91
-
92
- def _build_vision_tower(
93
- embed_dim: int,
94
- vision_cfg: CLIPVisionCfg,
95
- quick_gelu: bool = False,
96
- cast_dtype: Optional[torch.dtype] = None
97
- ):
98
- if isinstance(vision_cfg, dict):
99
- vision_cfg = CLIPVisionCfg(**vision_cfg)
100
-
101
- # OpenAI models are pretrained w/ QuickGELU but native nn.GELU is both faster and more
102
- # memory efficient in recent PyTorch releases (>= 1.10).
103
- # NOTE: timm models always use native GELU regardless of quick_gelu flag.
104
- act_layer = QuickGELU if quick_gelu else nn.GELU
105
-
106
- if vision_cfg.eva_model_name:
107
- vision_heads = vision_cfg.width // vision_cfg.head_width
108
- norm_layer = LayerNorm
109
-
110
- visual = EVAVisionTransformer(
111
- img_size=vision_cfg.image_size,
112
- patch_size=vision_cfg.patch_size,
113
- num_classes=embed_dim,
114
- use_mean_pooling=vision_cfg.global_average_pool, #False
115
- init_values=vision_cfg.ls_init_value,
116
- patch_dropout=vision_cfg.patch_dropout,
117
- embed_dim=vision_cfg.width,
118
- depth=vision_cfg.layers,
119
- num_heads=vision_heads,
120
- mlp_ratio=vision_cfg.mlp_ratio,
121
- qkv_bias=vision_cfg.qkv_bias,
122
- drop_path_rate=vision_cfg.drop_path_rate,
123
- norm_layer= partial(FusedLayerNorm, eps=1e-6) if vision_cfg.fusedLN else partial(norm_layer, eps=1e-6),
124
- xattn=vision_cfg.xattn,
125
- rope=vision_cfg.rope,
126
- postnorm=vision_cfg.postnorm,
127
- pt_hw_seq_len= vision_cfg.pt_hw_seq_len, # 224/14
128
- intp_freq= vision_cfg.intp_freq,
129
- naiveswiglu= vision_cfg.naiveswiglu,
130
- subln= vision_cfg.subln
131
- )
132
- elif vision_cfg.timm_model_name:
133
- visual = TimmModel(
134
- vision_cfg.timm_model_name,
135
- pretrained=vision_cfg.timm_model_pretrained,
136
- pool=vision_cfg.timm_pool,
137
- proj=vision_cfg.timm_proj,
138
- proj_bias=vision_cfg.timm_proj_bias,
139
- embed_dim=embed_dim,
140
- image_size=vision_cfg.image_size
141
- )
142
- act_layer = nn.GELU # so that text transformer doesn't use QuickGELU w/ timm models
143
- elif isinstance(vision_cfg.layers, (tuple, list)):
144
- vision_heads = vision_cfg.width * 32 // vision_cfg.head_width
145
- visual = ModifiedResNet(
146
- layers=vision_cfg.layers,
147
- output_dim=embed_dim,
148
- heads=vision_heads,
149
- image_size=vision_cfg.image_size,
150
- width=vision_cfg.width
151
- )
152
- else:
153
- vision_heads = vision_cfg.width // vision_cfg.head_width
154
- norm_layer = LayerNormFp32 if cast_dtype in (torch.float16, torch.bfloat16) else LayerNorm
155
- visual = VisionTransformer(
156
- image_size=vision_cfg.image_size,
157
- patch_size=vision_cfg.patch_size,
158
- width=vision_cfg.width,
159
- layers=vision_cfg.layers,
160
- heads=vision_heads,
161
- mlp_ratio=vision_cfg.mlp_ratio,
162
- ls_init_value=vision_cfg.ls_init_value,
163
- patch_dropout=vision_cfg.patch_dropout,
164
- global_average_pool=vision_cfg.global_average_pool,
165
- output_dim=embed_dim,
166
- act_layer=act_layer,
167
- norm_layer=norm_layer,
168
- )
169
-
170
- return visual
171
-
172
-
173
- def _build_text_tower(
174
- embed_dim: int,
175
- text_cfg: CLIPTextCfg,
176
- quick_gelu: bool = False,
177
- cast_dtype: Optional[torch.dtype] = None,
178
- ):
179
- if isinstance(text_cfg, dict):
180
- text_cfg = CLIPTextCfg(**text_cfg)
181
-
182
- if text_cfg.hf_model_name:
183
- text = HFTextEncoder(
184
- text_cfg.hf_model_name,
185
- output_dim=embed_dim,
186
- tokenizer_name=text_cfg.hf_tokenizer_name,
187
- proj=text_cfg.proj,
188
- pooler_type=text_cfg.pooler_type,
189
- masked_language_modeling=text_cfg.masked_language_modeling
190
- )
191
- else:
192
- act_layer = QuickGELU if quick_gelu else nn.GELU
193
- norm_layer = LayerNorm
194
-
195
- text = TextTransformer(
196
- context_length=text_cfg.context_length,
197
- vocab_size=text_cfg.vocab_size,
198
- width=text_cfg.width,
199
- heads=text_cfg.heads,
200
- layers=text_cfg.layers,
201
- ls_init_value=text_cfg.ls_init_value,
202
- output_dim=embed_dim,
203
- act_layer=act_layer,
204
- norm_layer= FusedLayerNorm if text_cfg.fusedLN else norm_layer,
205
- xattn=text_cfg.xattn,
206
- attn_mask=text_cfg.attn_mask,
207
- )
208
- return text
209
-
210
- class CLIP(nn.Module):
211
- def __init__(
212
- self,
213
- embed_dim: int,
214
- vision_cfg: CLIPVisionCfg,
215
- text_cfg: CLIPTextCfg,
216
- quick_gelu: bool = False,
217
- cast_dtype: Optional[torch.dtype] = None,
218
- ):
219
- super().__init__()
220
- self.visual = _build_vision_tower(embed_dim, vision_cfg, quick_gelu, cast_dtype)
221
-
222
- text = _build_text_tower(embed_dim, text_cfg, quick_gelu, cast_dtype)
223
- self.transformer = text.transformer
224
- self.vocab_size = text.vocab_size
225
- self.token_embedding = text.token_embedding
226
- self.positional_embedding = text.positional_embedding
227
- self.ln_final = text.ln_final
228
- self.text_projection = text.text_projection
229
- self.register_buffer('attn_mask', text.attn_mask, persistent=False)
230
-
231
- self.logit_scale = nn.Parameter(torch.ones([]) * np.log(1 / 0.07))
232
-
233
- def lock_image_tower(self, unlocked_groups=0, freeze_bn_stats=False):
234
- # lock image tower as per LiT - https://arxiv.org/abs/2111.07991
235
- self.visual.lock(unlocked_groups=unlocked_groups, freeze_bn_stats=freeze_bn_stats)
236
-
237
- @torch.jit.ignore
238
- def set_grad_checkpointing(self, enable=True):
239
- self.visual.set_grad_checkpointing(enable)
240
- self.transformer.grad_checkpointing = enable
241
-
242
- @torch.jit.ignore
243
- def no_weight_decay(self):
244
- return {'logit_scale'}
245
-
246
- def encode_image(self, image, normalize: bool = False):
247
- features = self.visual(image)
248
- return F.normalize(features, dim=-1) if normalize else features
249
-
250
- def encode_text(self, text, normalize: bool = False):
251
- cast_dtype = self.transformer.get_cast_dtype()
252
-
253
- x = self.token_embedding(text).to(cast_dtype) # [batch_size, n_ctx, d_model]
254
-
255
- x = x + self.positional_embedding.to(cast_dtype)
256
- x = x.permute(1, 0, 2) # NLD -> LND
257
- x = self.transformer(x, attn_mask=self.attn_mask)
258
- x = x.permute(1, 0, 2) # LND -> NLD
259
- x = self.ln_final(x) # [batch_size, n_ctx, transformer.width]
260
- # take features from the eot embedding (eot_token is the highest number in each sequence)
261
- x = x[torch.arange(x.shape[0]), text.argmax(dim=-1)] @ self.text_projection
262
- return F.normalize(x, dim=-1) if normalize else x
263
-
264
- def forward(self, image, text):
265
- image_features = self.encode_image(image, normalize=True)
266
- text_features = self.encode_text(text, normalize=True)
267
- return image_features, text_features, self.logit_scale.exp()
268
-
269
-
270
- class CustomCLIP(nn.Module):
271
- def __init__(
272
- self,
273
- embed_dim: int,
274
- vision_cfg: CLIPVisionCfg,
275
- text_cfg: CLIPTextCfg,
276
- quick_gelu: bool = False,
277
- cast_dtype: Optional[torch.dtype] = None,
278
- itm_task: bool = False,
279
- ):
280
- super().__init__()
281
- self.visual = _build_vision_tower(embed_dim, vision_cfg, quick_gelu, cast_dtype)
282
- self.text = _build_text_tower(embed_dim, text_cfg, quick_gelu, cast_dtype)
283
- self.logit_scale = nn.Parameter(torch.ones([]) * np.log(1 / 0.07))
284
-
285
- def lock_image_tower(self, unlocked_groups=0, freeze_bn_stats=False):
286
- # lock image tower as per LiT - https://arxiv.org/abs/2111.07991
287
- self.visual.lock(unlocked_groups=unlocked_groups, freeze_bn_stats=freeze_bn_stats)
288
-
289
- def lock_text_tower(self, unlocked_layers:int=0, freeze_layer_norm:bool=True):
290
- self.text.lock(unlocked_layers, freeze_layer_norm)
291
-
292
- @torch.jit.ignore
293
- def set_grad_checkpointing(self, enable=True):
294
- self.visual.set_grad_checkpointing(enable)
295
- self.text.set_grad_checkpointing(enable)
296
-
297
- @torch.jit.ignore
298
- def no_weight_decay(self):
299
- return {'logit_scale'}
300
-
301
- def encode_image(self, image, normalize: bool = False):
302
- features = self.visual(image)
303
- return F.normalize(features, dim=-1) if normalize else features
304
-
305
- def encode_text(self, text, normalize: bool = False):
306
- features = self.text(text)
307
- return F.normalize(features, dim=-1) if normalize else features
308
-
309
- def forward(self, image, text):
310
- image_features = self.encode_image(image, normalize=True)
311
- text_features = self.encode_text(text, normalize=True)
312
- return image_features, text_features, self.logit_scale.exp()
313
-
314
-
315
- def convert_weights_to_lp(model: nn.Module, dtype=torch.float16):
316
- """Convert applicable model parameters to low-precision (bf16 or fp16)"""
317
-
318
- def _convert_weights(l):
319
-
320
- if isinstance(l, (nn.Conv1d, nn.Conv2d, nn.Linear)):
321
- l.weight.data = l.weight.data.to(dtype)
322
- if l.bias is not None:
323
- l.bias.data = l.bias.data.to(dtype)
324
-
325
- if isinstance(l, (nn.MultiheadAttention, Attention)):
326
- for attr in [*[f"{s}_proj_weight" for s in ["in", "q", "k", "v"]], "in_proj_bias", "bias_k", "bias_v"]:
327
- tensor = getattr(l, attr, None)
328
- if tensor is not None:
329
- tensor.data = tensor.data.to(dtype)
330
-
331
- if isinstance(l, nn.Parameter):
332
- l.data = l.data.to(dtype)
333
-
334
- for name in ["text_projection", "proj"]:
335
- if hasattr(l, name) and isinstance(getattr(l, name), nn.Parameter):
336
- attr = getattr(l, name, None)
337
- if attr is not None:
338
- attr.data = attr.data.to(dtype)
339
-
340
- model.apply(_convert_weights)
341
-
342
-
343
- convert_weights_to_fp16 = convert_weights_to_lp # backwards compat
344
-
345
-
346
- # used to maintain checkpoint compatibility
347
- def convert_to_custom_text_state_dict(state_dict: dict):
348
- if 'text_projection' in state_dict:
349
- # old format state_dict, move text tower -> .text
350
- new_state_dict = {}
351
- for k, v in state_dict.items():
352
- if any(k.startswith(p) for p in (
353
- 'text_projection',
354
- 'positional_embedding',
355
- 'token_embedding',
356
- 'transformer',
357
- 'ln_final',
358
- 'logit_scale'
359
- )):
360
- k = 'text.' + k
361
- new_state_dict[k] = v
362
- return new_state_dict
363
- return state_dict
364
-
365
-
366
- def build_model_from_openai_state_dict(
367
- state_dict: dict,
368
- quick_gelu=True,
369
- cast_dtype=torch.float16,
370
- ):
371
- vit = "visual.proj" in state_dict
372
-
373
- if vit:
374
- vision_width = state_dict["visual.conv1.weight"].shape[0]
375
- vision_layers = len(
376
- [k for k in state_dict.keys() if k.startswith("visual.") and k.endswith(".attn.in_proj_weight")])
377
- vision_patch_size = state_dict["visual.conv1.weight"].shape[-1]
378
- grid_size = round((state_dict["visual.positional_embedding"].shape[0] - 1) ** 0.5)
379
- image_size = vision_patch_size * grid_size
380
- else:
381
- counts: list = [
382
- len(set(k.split(".")[2] for k in state_dict if k.startswith(f"visual.layer{b}"))) for b in [1, 2, 3, 4]]
383
- vision_layers = tuple(counts)
384
- vision_width = state_dict["visual.layer1.0.conv1.weight"].shape[0]
385
- output_width = round((state_dict["visual.attnpool.positional_embedding"].shape[0] - 1) ** 0.5)
386
- vision_patch_size = None
387
- assert output_width ** 2 + 1 == state_dict["visual.attnpool.positional_embedding"].shape[0]
388
- image_size = output_width * 32
389
-
390
- embed_dim = state_dict["text_projection"].shape[1]
391
- context_length = state_dict["positional_embedding"].shape[0]
392
- vocab_size = state_dict["token_embedding.weight"].shape[0]
393
- transformer_width = state_dict["ln_final.weight"].shape[0]
394
- transformer_heads = transformer_width // 64
395
- transformer_layers = len(set(k.split(".")[2] for k in state_dict if k.startswith(f"transformer.resblocks")))
396
-
397
- vision_cfg = CLIPVisionCfg(
398
- layers=vision_layers,
399
- width=vision_width,
400
- patch_size=vision_patch_size,
401
- image_size=image_size,
402
- )
403
- text_cfg = CLIPTextCfg(
404
- context_length=context_length,
405
- vocab_size=vocab_size,
406
- width=transformer_width,
407
- heads=transformer_heads,
408
- layers=transformer_layers
409
- )
410
- model = CLIP(
411
- embed_dim,
412
- vision_cfg=vision_cfg,
413
- text_cfg=text_cfg,
414
- quick_gelu=quick_gelu, # OpenAI models were trained with QuickGELU
415
- cast_dtype=cast_dtype,
416
- )
417
-
418
- for key in ["input_resolution", "context_length", "vocab_size"]:
419
- state_dict.pop(key, None)
420
-
421
- convert_weights_to_fp16(model) # OpenAI state dicts are partially converted to float16
422
- model.load_state_dict(state_dict)
423
- return model.eval()
424
-
425
-
426
- def trace_model(model, batch_size=256, device=torch.device('cpu')):
427
- model.eval()
428
- image_size = model.visual.image_size
429
- example_images = torch.ones((batch_size, 3, image_size, image_size), device=device)
430
- example_text = torch.zeros((batch_size, model.context_length), dtype=torch.int, device=device)
431
- model = torch.jit.trace_module(
432
- model,
433
- inputs=dict(
434
- forward=(example_images, example_text),
435
- encode_text=(example_text,),
436
- encode_image=(example_images,)
437
- ))
438
- model.visual.image_size = image_size
439
- return model
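
For reference, the deleted model.py above fixes the forward contract of CLIP / CustomCLIP: both return L2-normalized image features, L2-normalized text features, and exp(logit_scale). The following is a minimal consumption sketch, not code from this repository; `model`, `images`, and `token_ids` are placeholder names.

```python
# Hedged sketch of how the CLIP/CustomCLIP forward contract above is typically consumed.
# `model`, `images`, and `token_ids` are placeholders, not names from this repo.
import torch

@torch.no_grad()
def clip_probs(model, images, token_ids):
    # forward() returns L2-normalized features plus the learned temperature exp(logit_scale)
    image_features, text_features, logit_scale = model(images, token_ids)
    # scaled cosine similarities -> per-image probabilities over the text candidates
    logits_per_image = logit_scale * image_features @ text_features.t()
    return logits_per_image.softmax(dim=-1)
```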
 
models/eva_clip/model_configs/EVA01-CLIP-B-16.json DELETED
@@ -1,19 +0,0 @@
- {
-     "embed_dim": 512,
-     "vision_cfg": {
-         "image_size": 224,
-         "layers": 12,
-         "width": 768,
-         "patch_size": 16,
-         "eva_model_name": "eva-clip-b-16",
-         "ls_init_value": 0.1,
-         "drop_path_rate": 0.0
-     },
-     "text_cfg": {
-         "context_length": 77,
-         "vocab_size": 49408,
-         "width": 512,
-         "heads": 8,
-         "layers": 12
-     }
- }
 
models/eva_clip/model_configs/EVA01-CLIP-g-14-plus.json DELETED
@@ -1,24 +0,0 @@
- {
-     "embed_dim": 1024,
-     "vision_cfg": {
-         "image_size": 224,
-         "layers": 40,
-         "width": 1408,
-         "head_width": 88,
-         "mlp_ratio": 4.3637,
-         "patch_size": 14,
-         "eva_model_name": "eva-clip-g-14-x",
-         "drop_path_rate": 0,
-         "xattn": true,
-         "fusedLN": true
-     },
-     "text_cfg": {
-         "context_length": 77,
-         "vocab_size": 49408,
-         "width": 1024,
-         "heads": 16,
-         "layers": 24,
-         "xattn": false,
-         "fusedLN": true
-     }
- }
 
models/eva_clip/model_configs/EVA01-CLIP-g-14.json DELETED
@@ -1,24 +0,0 @@
- {
-     "embed_dim": 1024,
-     "vision_cfg": {
-         "image_size": 224,
-         "layers": 40,
-         "width": 1408,
-         "head_width": 88,
-         "mlp_ratio": 4.3637,
-         "patch_size": 14,
-         "eva_model_name": "eva-clip-g-14-x",
-         "drop_path_rate": 0.4,
-         "xattn": true,
-         "fusedLN": true
-     },
-     "text_cfg": {
-         "context_length": 77,
-         "vocab_size": 49408,
-         "width": 768,
-         "heads": 12,
-         "layers": 12,
-         "xattn": false,
-         "fusedLN": true
-     }
- }
 
models/eva_clip/model_configs/EVA02-CLIP-B-16.json DELETED
@@ -1,29 +0,0 @@
- {
-     "embed_dim": 512,
-     "vision_cfg": {
-         "image_size": 224,
-         "layers": 12,
-         "width": 768,
-         "head_width": 64,
-         "patch_size": 16,
-         "mlp_ratio": 2.6667,
-         "eva_model_name": "eva-clip-b-16-X",
-         "drop_path_rate": 0.0,
-         "xattn": true,
-         "fusedLN": true,
-         "rope": true,
-         "pt_hw_seq_len": 16,
-         "intp_freq": true,
-         "naiveswiglu": true,
-         "subln": true
-     },
-     "text_cfg": {
-         "context_length": 77,
-         "vocab_size": 49408,
-         "width": 512,
-         "heads": 8,
-         "layers": 12,
-         "xattn": true,
-         "fusedLN": true
-     }
- }
 
models/eva_clip/model_configs/EVA02-CLIP-L-14-336.json DELETED
@@ -1,29 +0,0 @@
- {
-     "embed_dim": 768,
-     "vision_cfg": {
-         "image_size": 336,
-         "layers": 24,
-         "width": 1024,
-         "drop_path_rate": 0,
-         "head_width": 64,
-         "mlp_ratio": 2.6667,
-         "patch_size": 14,
-         "eva_model_name": "eva-clip-l-14-336",
-         "xattn": true,
-         "fusedLN": true,
-         "rope": true,
-         "pt_hw_seq_len": 16,
-         "intp_freq": true,
-         "naiveswiglu": true,
-         "subln": true
-     },
-     "text_cfg": {
-         "context_length": 77,
-         "vocab_size": 49408,
-         "width": 768,
-         "heads": 12,
-         "layers": 12,
-         "xattn": false,
-         "fusedLN": true
-     }
- }
 
models/eva_clip/model_configs/EVA02-CLIP-L-14.json DELETED
@@ -1,29 +0,0 @@
- {
-     "embed_dim": 768,
-     "vision_cfg": {
-         "image_size": 224,
-         "layers": 24,
-         "width": 1024,
-         "drop_path_rate": 0,
-         "head_width": 64,
-         "mlp_ratio": 2.6667,
-         "patch_size": 14,
-         "eva_model_name": "eva-clip-l-14",
-         "xattn": true,
-         "fusedLN": true,
-         "rope": true,
-         "pt_hw_seq_len": 16,
-         "intp_freq": true,
-         "naiveswiglu": true,
-         "subln": true
-     },
-     "text_cfg": {
-         "context_length": 77,
-         "vocab_size": 49408,
-         "width": 768,
-         "heads": 12,
-         "layers": 12,
-         "xattn": false,
-         "fusedLN": true
-     }
- }
 
models/eva_clip/model_configs/EVA02-CLIP-bigE-14-plus.json DELETED
@@ -1,25 +0,0 @@
- {
-     "embed_dim": 1024,
-     "vision_cfg": {
-         "image_size": 224,
-         "layers": 64,
-         "width": 1792,
-         "head_width": 112,
-         "mlp_ratio": 8.571428571428571,
-         "patch_size": 14,
-         "eva_model_name": "eva-clip-4b-14-x",
-         "drop_path_rate": 0,
-         "xattn": true,
-         "postnorm": true,
-         "fusedLN": true
-     },
-     "text_cfg": {
-         "context_length": 77,
-         "vocab_size": 49408,
-         "width": 1280,
-         "heads": 20,
-         "layers": 32,
-         "xattn": false,
-         "fusedLN": true
-     }
- }
 
models/eva_clip/model_configs/EVA02-CLIP-bigE-14.json DELETED
@@ -1,25 +0,0 @@
- {
-     "embed_dim": 1024,
-     "vision_cfg": {
-         "image_size": 224,
-         "layers": 64,
-         "width": 1792,
-         "head_width": 112,
-         "mlp_ratio": 8.571428571428571,
-         "patch_size": 14,
-         "eva_model_name": "eva-clip-4b-14-x",
-         "drop_path_rate": 0,
-         "xattn": true,
-         "postnorm": true,
-         "fusedLN": true
-     },
-     "text_cfg": {
-         "context_length": 77,
-         "vocab_size": 49408,
-         "width": 1024,
-         "heads": 16,
-         "layers": 24,
-         "xattn": false,
-         "fusedLN": true
-     }
- }
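
These deleted JSON files mirror the keyword fields of the CLIPVisionCfg / CLIPTextCfg dataclasses consumed by _build_vision_tower and _build_text_tower in models/eva_clip/model.py. Below is a hedged sketch of how such a config was loaded; the file path and the direct CLIP construction are illustrative only. The text-tower builder demonstrably wraps plain dicts via `CLIPTextCfg(**text_cfg)`, and the vision builder is assumed to do the same.

```python
# Hedged sketch, not code from this repo: load one of the deleted configs and pass its
# sections straight to the CLIP constructor defined in models/eva_clip/model.py.
import json

with open("models/eva_clip/model_configs/EVA01-CLIP-B-16.json") as f:  # illustrative path
    cfg = json.load(f)

# The builders accept plain dicts and wrap them in the Cfg dataclasses themselves.
model = CLIP(                      # assumes CLIP was imported from models.eva_clip.model
    embed_dim=cfg["embed_dim"],    # 512 for this config
    vision_cfg=cfg["vision_cfg"],
    text_cfg=cfg["text_cfg"],
)
```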
 
models/eva_clip/modified_resnet.py DELETED
@@ -1,188 +0,0 @@
1
- import os
2
- import sys
3
-
4
- import torch
5
- from torch import nn
6
- from torch.nn import functional as F
7
- from collections import OrderedDict
8
-
9
- current_file_path = os.path.abspath(__file__)
10
- project_roots = [os.path.dirname(current_file_path)]
11
- for project_root in project_roots:
12
- sys.path.insert(0, project_root) if project_root not in sys.path else None
13
-
14
- from utils import freeze_batch_norm_2d
15
-
16
-
17
- class Bottleneck(nn.Module):
18
- expansion = 4
19
-
20
- def __init__(self, inplanes, planes, stride=1):
21
- super().__init__()
22
-
23
- # all conv layers have stride 1. an avgpool is performed after the second convolution when stride > 1
24
- self.conv1 = nn.Conv2d(inplanes, planes, 1, bias=False)
25
- self.bn1 = nn.BatchNorm2d(planes)
26
- self.act1 = nn.ReLU(inplace=True)
27
-
28
- self.conv2 = nn.Conv2d(planes, planes, 3, padding=1, bias=False)
29
- self.bn2 = nn.BatchNorm2d(planes)
30
- self.act2 = nn.ReLU(inplace=True)
31
-
32
- self.avgpool = nn.AvgPool2d(stride) if stride > 1 else nn.Identity()
33
-
34
- self.conv3 = nn.Conv2d(planes, planes * self.expansion, 1, bias=False)
35
- self.bn3 = nn.BatchNorm2d(planes * self.expansion)
36
- self.act3 = nn.ReLU(inplace=True)
37
-
38
- self.downsample = None
39
- self.stride = stride
40
-
41
- if stride > 1 or inplanes != planes * Bottleneck.expansion:
42
- # downsampling layer is prepended with an avgpool, and the subsequent convolution has stride 1
43
- self.downsample = nn.Sequential(OrderedDict([
44
- ("-1", nn.AvgPool2d(stride)),
45
- ("0", nn.Conv2d(inplanes, planes * self.expansion, 1, stride=1, bias=False)),
46
- ("1", nn.BatchNorm2d(planes * self.expansion))
47
- ]))
48
-
49
- def forward(self, x: torch.Tensor):
50
- identity = x
51
-
52
- out = self.act1(self.bn1(self.conv1(x)))
53
- out = self.act2(self.bn2(self.conv2(out)))
54
- out = self.avgpool(out)
55
- out = self.bn3(self.conv3(out))
56
-
57
- if self.downsample is not None:
58
- identity = self.downsample(x)
59
-
60
- out += identity
61
- out = self.act3(out)
62
- return out
63
-
64
-
65
- class AttentionPool2d(nn.Module):
66
- def __init__(self, spacial_dim: int, embed_dim: int, num_heads: int, output_dim: int = None):
67
- super().__init__()
68
- self.positional_embedding = nn.Parameter(torch.randn(spacial_dim ** 2 + 1, embed_dim) / embed_dim ** 0.5)
69
- self.k_proj = nn.Linear(embed_dim, embed_dim)
70
- self.q_proj = nn.Linear(embed_dim, embed_dim)
71
- self.v_proj = nn.Linear(embed_dim, embed_dim)
72
- self.c_proj = nn.Linear(embed_dim, output_dim or embed_dim)
73
- self.num_heads = num_heads
74
-
75
- def forward(self, x):
76
- x = x.reshape(x.shape[0], x.shape[1], x.shape[2] * x.shape[3]).permute(2, 0, 1) # NCHW -> (HW)NC
77
- x = torch.cat([x.mean(dim=0, keepdim=True), x], dim=0) # (HW+1)NC
78
- x = x + self.positional_embedding[:, None, :].to(x.dtype) # (HW+1)NC
79
- x, _ = F.multi_head_attention_forward(
80
- query=x, key=x, value=x,
81
- embed_dim_to_check=x.shape[-1],
82
- num_heads=self.num_heads,
83
- q_proj_weight=self.q_proj.weight,
84
- k_proj_weight=self.k_proj.weight,
85
- v_proj_weight=self.v_proj.weight,
86
- in_proj_weight=None,
87
- in_proj_bias=torch.cat([self.q_proj.bias, self.k_proj.bias, self.v_proj.bias]),
88
- bias_k=None,
89
- bias_v=None,
90
- add_zero_attn=False,
91
- dropout_p=0.,
92
- out_proj_weight=self.c_proj.weight,
93
- out_proj_bias=self.c_proj.bias,
94
- use_separate_proj_weight=True,
95
- training=self.training,
96
- need_weights=False
97
- )
98
-
99
- return x[0]
100
-
101
-
102
- class ModifiedResNet(nn.Module):
103
- """
104
- A ResNet class that is similar to torchvision's but contains the following changes:
105
- - There are now 3 "stem" convolutions as opposed to 1, with an average pool instead of a max pool.
106
- - Performs anti-aliasing strided convolutions, where an avgpool is prepended to convolutions with stride > 1
107
- - The final pooling layer is a QKV attention instead of an average pool
108
- """
109
-
110
- def __init__(self, layers, output_dim, heads, image_size=224, width=64):
111
- super().__init__()
112
- self.output_dim = output_dim
113
- self.image_size = image_size
114
-
115
- # the 3-layer stem
116
- self.conv1 = nn.Conv2d(3, width // 2, kernel_size=3, stride=2, padding=1, bias=False)
117
- self.bn1 = nn.BatchNorm2d(width // 2)
118
- self.act1 = nn.ReLU(inplace=True)
119
- self.conv2 = nn.Conv2d(width // 2, width // 2, kernel_size=3, padding=1, bias=False)
120
- self.bn2 = nn.BatchNorm2d(width // 2)
121
- self.act2 = nn.ReLU(inplace=True)
122
- self.conv3 = nn.Conv2d(width // 2, width, kernel_size=3, padding=1, bias=False)
123
- self.bn3 = nn.BatchNorm2d(width)
124
- self.act3 = nn.ReLU(inplace=True)
125
- self.avgpool = nn.AvgPool2d(2)
126
-
127
- # residual layers
128
- self._inplanes = width # this is a *mutable* variable used during construction
129
- self.layer1 = self._make_layer(width, layers[0])
130
- self.layer2 = self._make_layer(width * 2, layers[1], stride=2)
131
- self.layer3 = self._make_layer(width * 4, layers[2], stride=2)
132
- self.layer4 = self._make_layer(width * 8, layers[3], stride=2)
133
-
134
- embed_dim = width * 32 # the ResNet feature dimension
135
- self.attnpool = AttentionPool2d(image_size // 32, embed_dim, heads, output_dim)
136
-
137
- self.init_parameters()
138
-
139
- def _make_layer(self, planes, blocks, stride=1):
140
- layers = [Bottleneck(self._inplanes, planes, stride)]
141
-
142
- self._inplanes = planes * Bottleneck.expansion
143
- for _ in range(1, blocks):
144
- layers.append(Bottleneck(self._inplanes, planes))
145
-
146
- return nn.Sequential(*layers)
147
-
148
- def init_parameters(self):
149
- if self.attnpool is not None:
150
- std = self.attnpool.c_proj.in_features ** -0.5
151
- nn.init.normal_(self.attnpool.q_proj.weight, std=std)
152
- nn.init.normal_(self.attnpool.k_proj.weight, std=std)
153
- nn.init.normal_(self.attnpool.v_proj.weight, std=std)
154
- nn.init.normal_(self.attnpool.c_proj.weight, std=std)
155
-
156
- for resnet_block in [self.layer1, self.layer2, self.layer3, self.layer4]:
157
- for name, param in resnet_block.named_parameters():
158
- if name.endswith("bn3.weight"):
159
- nn.init.zeros_(param)
160
-
161
- def lock(self, unlocked_groups=0, freeze_bn_stats=False):
162
- assert unlocked_groups == 0, 'partial locking not currently supported for this model'
163
- for param in self.parameters():
164
- param.requires_grad = False
165
- if freeze_bn_stats:
166
- freeze_batch_norm_2d(self)
167
-
168
- @torch.jit.ignore
169
- def set_grad_checkpointing(self, enable=True):
170
- # FIXME support for non-transformer
171
- pass
172
-
173
- def stem(self, x):
174
- x = self.act1(self.bn1(self.conv1(x)))
175
- x = self.act2(self.bn2(self.conv2(x)))
176
- x = self.act3(self.bn3(self.conv3(x)))
177
- x = self.avgpool(x)
178
- return x
179
-
180
- def forward(self, x):
181
- x = self.stem(x)
182
- x = self.layer1(x)
183
- x = self.layer2(x)
184
- x = self.layer3(x)
185
- x = self.layer4(x)
186
- x = self.attnpool(x)
187
-
188
- return x
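
For reference, a hedged sketch of exercising the deleted ModifiedResNet on a dummy batch. The RN50-style layer counts and dimensions below are illustrative, chosen so that `heads` matches the `width * 32 // head_width` rule used in _build_vision_tower.

```python
# Hedged sketch, not code from this repo (assumes ModifiedResNet was imported from
# models/eva_clip/modified_resnet.py as it existed before this deletion).
import torch

visual = ModifiedResNet(
    layers=(3, 4, 6, 3),   # Bottleneck blocks per stage (RN50-style, illustrative)
    output_dim=1024,       # embedding size produced by AttentionPool2d
    heads=32,              # width * 32 // head_width with width=64, head_width=64
    image_size=224,
    width=64,
)
images = torch.randn(2, 3, 224, 224)
features = visual(images)  # stem + 4 stages downsample 32x, attention pool -> (2, 1024)
```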
 
models/eva_clip/openai.py DELETED
@@ -1,144 +0,0 @@
1
- """ OpenAI pretrained model functions
2
-
3
- Adapted from https://github.com/openai/CLIP. Originally MIT License, Copyright (c) 2021 OpenAI.
4
- """
5
-
6
- import os
7
- import warnings
8
- from typing import List, Optional, Union
9
-
10
- import torch
11
-
12
- from .model import build_model_from_openai_state_dict, convert_weights_to_lp, get_cast_dtype
13
- from .pretrained import get_pretrained_url, list_pretrained_models_by_tag, download_pretrained_from_url
14
-
15
- __all__ = ["list_openai_models", "load_openai_model"]
16
-
17
-
18
- def list_openai_models() -> List[str]:
19
- """Returns the names of available CLIP models"""
20
- return list_pretrained_models_by_tag('openai')
21
-
22
-
23
- def load_openai_model(
24
- name: str,
25
- precision: Optional[str] = None,
26
- device: Optional[Union[str, torch.device]] = None,
27
- jit: bool = True,
28
- cache_dir: Optional[str] = None,
29
- ):
30
- """Load a CLIP model
31
-
32
- Parameters
33
- ----------
34
- name : str
35
- A model name listed by `clip.available_models()`, or the path to a model checkpoint containing the state_dict
36
- precision: str
37
- Model precision, if None defaults to 'fp32' if device == 'cpu' else 'fp16'.
38
- device : Union[str, torch.device]
39
- The device to put the loaded model
40
- jit : bool
41
- Whether to load the optimized JIT model (default) or more hackable non-JIT model.
42
- cache_dir : Optional[str]
43
- The directory to cache the downloaded model weights
44
-
45
- Returns
46
- -------
47
- model : torch.nn.Module
48
- The CLIP model
49
- preprocess : Callable[[PIL.Image], torch.Tensor]
50
- A torchvision transform that converts a PIL image into a tensor that the returned model can take as its input
51
- """
52
- if device is None:
53
- device = "cuda" if torch.cuda.is_available() else "cpu"
54
- if precision is None:
55
- precision = 'fp32' if device == 'cpu' else 'fp16'
56
-
57
- if get_pretrained_url(name, 'openai'):
58
- model_path = download_pretrained_from_url(get_pretrained_url(name, 'openai'), cache_dir=cache_dir)
59
- elif os.path.isfile(name):
60
- model_path = name
61
- else:
62
- raise RuntimeError(f"Model {name} not found; available models = {list_openai_models()}")
63
-
64
- try:
65
- # loading JIT archive
66
- model = torch.jit.load(model_path, map_location=device if jit else "cpu").eval()
67
- state_dict = None
68
- except RuntimeError:
69
- # loading saved state dict
70
- if jit:
71
- warnings.warn(f"File {model_path} is not a JIT archive. Loading as a state dict instead")
72
- jit = False
73
- state_dict = torch.load(model_path, map_location="cpu")
74
-
75
- if not jit:
76
- # Build a non-jit model from the OpenAI jitted model state dict
77
- cast_dtype = get_cast_dtype(precision)
78
- try:
79
- model = build_model_from_openai_state_dict(state_dict or model.state_dict(), cast_dtype=cast_dtype)
80
- except KeyError:
81
- sd = {k[7:]: v for k, v in state_dict["state_dict"].items()}
82
- model = build_model_from_openai_state_dict(sd, cast_dtype=cast_dtype)
83
-
84
- # model from OpenAI state dict is in manually cast fp16 mode, must be converted for AMP/fp32/bf16 use
85
- model = model.to(device)
86
- if precision.startswith('amp') or precision == 'fp32':
87
- model.float()
88
- elif precision == 'bf16':
89
- convert_weights_to_lp(model, dtype=torch.bfloat16)
90
-
91
- return model
92
-
93
- # patch the device names
94
- device_holder = torch.jit.trace(lambda: torch.ones([]).to(torch.device(device)), example_inputs=[])
95
- device_node = [n for n in device_holder.graph.findAllNodes("prim::Constant") if "Device" in repr(n)][-1]
96
-
97
- def patch_device(module):
98
- try:
99
- graphs = [module.graph] if hasattr(module, "graph") else []
100
- except RuntimeError:
101
- graphs = []
102
-
103
- if hasattr(module, "forward1"):
104
- graphs.append(module.forward1.graph)
105
-
106
- for graph in graphs:
107
- for node in graph.findAllNodes("prim::Constant"):
108
- if "value" in node.attributeNames() and str(node["value"]).startswith("cuda"):
109
- node.copyAttributes(device_node)
110
-
111
- model.apply(patch_device)
112
- patch_device(model.encode_image)
113
- patch_device(model.encode_text)
114
-
115
- # patch dtype to float32 (typically for CPU)
116
- if precision == 'fp32':
117
- float_holder = torch.jit.trace(lambda: torch.ones([]).float(), example_inputs=[])
118
- float_input = list(float_holder.graph.findNode("aten::to").inputs())[1]
119
- float_node = float_input.node()
120
-
121
- def patch_float(module):
122
- try:
123
- graphs = [module.graph] if hasattr(module, "graph") else []
124
- except RuntimeError:
125
- graphs = []
126
-
127
- if hasattr(module, "forward1"):
128
- graphs.append(module.forward1.graph)
129
-
130
- for graph in graphs:
131
- for node in graph.findAllNodes("aten::to"):
132
- inputs = list(node.inputs())
133
- for i in [1, 2]: # dtype can be the second or third argument to aten::to()
134
- if inputs[i].node()["value"] == 5:
135
- inputs[i].node().copyAttributes(float_node)
136
-
137
- model.apply(patch_float)
138
- patch_float(model.encode_image)
139
- patch_float(model.encode_text)
140
- model.float()
141
-
142
- # ensure image_size attr available at consistent location for both jit and non-jit
143
- model.visual.image_size = model.input_resolution.item()
144
- return model
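
A hedged usage sketch for the deleted loader: it resolves a checkpoint through the pretrained registry, so the first call downloads weights, and the model name is one of the 'openai'-tagged entries from pretrained.py.

```python
# Hedged sketch, not code from this repo (assumes the functions were imported from
# models/eva_clip/openai.py as it existed before this deletion). Downloads weights on first use.
import torch

print(list_openai_models())        # names carrying the 'openai' tag, e.g. 'OpenaiCLIP-B-16'

model = load_openai_model(
    "OpenaiCLIP-B-16",
    precision="fp32",              # keep everything in float32 on CPU
    device="cpu",
    jit=False,                     # build the non-JIT module via build_model_from_openai_state_dict
)
assert isinstance(model, torch.nn.Module)
```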
 
models/eva_clip/pretrained.py DELETED
@@ -1,332 +0,0 @@
1
- import hashlib
2
- import os
3
- import urllib
4
- import warnings
5
- from functools import partial
6
- from typing import Dict, Union
7
-
8
- from tqdm import tqdm
9
-
10
- try:
11
- from huggingface_hub import hf_hub_download
12
- _has_hf_hub = True
13
- except ImportError:
14
- hf_hub_download = None
15
- _has_hf_hub = False
16
-
17
-
18
- def _pcfg(url='', hf_hub='', filename='', mean=None, std=None):
19
- return dict(
20
- url=url,
21
- hf_hub=hf_hub,
22
- mean=mean,
23
- std=std,
24
- )
25
-
26
- _VITB32 = dict(
27
- openai=_pcfg(
28
- "https://openaipublic.azureedge.net/clip/models/40d365715913c9da98579312b702a82c18be219cc2a73407c4526f58eba950af/ViT-B-32.pt"),
29
- laion400m_e31=_pcfg(
30
- "https://github.com/mlfoundations/open_clip/releases/download/v0.2-weights/vit_b_32-quickgelu-laion400m_e31-d867053b.pt"),
31
- laion400m_e32=_pcfg(
32
- "https://github.com/mlfoundations/open_clip/releases/download/v0.2-weights/vit_b_32-quickgelu-laion400m_e32-46683a32.pt"),
33
- laion2b_e16=_pcfg(
34
- "https://github.com/mlfoundations/open_clip/releases/download/v0.2-weights/vit_b_32-laion2b_e16-af8dbd0c.pth"),
35
- laion2b_s34b_b79k=_pcfg(hf_hub='laion/CLIP-ViT-B-32-laion2B-s34B-b79K/')
36
- )
37
-
38
- _VITB32_quickgelu = dict(
39
- openai=_pcfg(
40
- "https://openaipublic.azureedge.net/clip/models/40d365715913c9da98579312b702a82c18be219cc2a73407c4526f58eba950af/ViT-B-32.pt"),
41
- laion400m_e31=_pcfg(
42
- "https://github.com/mlfoundations/open_clip/releases/download/v0.2-weights/vit_b_32-quickgelu-laion400m_e31-d867053b.pt"),
43
- laion400m_e32=_pcfg(
44
- "https://github.com/mlfoundations/open_clip/releases/download/v0.2-weights/vit_b_32-quickgelu-laion400m_e32-46683a32.pt"),
45
- )
46
-
47
- _VITB16 = dict(
48
- openai=_pcfg(
49
- "https://openaipublic.azureedge.net/clip/models/5806e77cd80f8b59890b7e101eabd078d9fb84e6937f9e85e4ecb61988df416f/ViT-B-16.pt"),
50
- laion400m_e31=_pcfg(
51
- "https://github.com/mlfoundations/open_clip/releases/download/v0.2-weights/vit_b_16-laion400m_e31-00efa78f.pt"),
52
- laion400m_e32=_pcfg(
53
- "https://github.com/mlfoundations/open_clip/releases/download/v0.2-weights/vit_b_16-laion400m_e32-55e67d44.pt"),
54
- laion2b_s34b_b88k=_pcfg(hf_hub='laion/CLIP-ViT-B-16-laion2B-s34B-b88K/'),
55
- )
56
-
57
- _EVAB16 = dict(
58
- eva=_pcfg(hf_hub='QuanSun/EVA-CLIP/EVA02_B_psz14to16.pt'),
59
- eva02=_pcfg(hf_hub='QuanSun/EVA-CLIP/EVA02_B_psz14to16.pt'),
60
- eva_clip=_pcfg(hf_hub='QuanSun/EVA-CLIP/EVA02_CLIP_B_psz16_s8B.pt'),
61
- eva02_clip=_pcfg(hf_hub='QuanSun/EVA-CLIP/EVA02_CLIP_B_psz16_s8B.pt'),
62
- )
63
-
64
- _VITB16_PLUS_240 = dict(
65
- laion400m_e31=_pcfg(
66
- "https://github.com/mlfoundations/open_clip/releases/download/v0.2-weights/vit_b_16_plus_240-laion400m_e31-8fb26589.pt"),
67
- laion400m_e32=_pcfg(
68
- "https://github.com/mlfoundations/open_clip/releases/download/v0.2-weights/vit_b_16_plus_240-laion400m_e32-699c4b84.pt"),
69
- )
70
-
71
- _VITL14 = dict(
72
- openai=_pcfg(
73
- "https://openaipublic.azureedge.net/clip/models/b8cca3fd41ae0c99ba7e8951adf17d267cdb84cd88be6f7c2e0eca1737a03836/ViT-L-14.pt"),
74
- laion400m_e31=_pcfg(
75
- "https://github.com/mlfoundations/open_clip/releases/download/v0.2-weights/vit_l_14-laion400m_e31-69988bb6.pt"),
76
- laion400m_e32=_pcfg(
77
- "https://github.com/mlfoundations/open_clip/releases/download/v0.2-weights/vit_l_14-laion400m_e32-3d133497.pt"),
78
- laion2b_s32b_b82k=_pcfg(
79
- hf_hub='laion/CLIP-ViT-L-14-laion2B-s32B-b82K/',
80
- mean=(0.5, 0.5, 0.5), std=(0.5, 0.5, 0.5)),
81
- )
82
-
83
- _EVAL14 = dict(
84
- eva=_pcfg(hf_hub='QuanSun/EVA-CLIP/EVA02_L_psz14.pt'),
85
- eva02=_pcfg(hf_hub='QuanSun/EVA-CLIP/EVA02_L_psz14.pt'),
86
- eva_clip=_pcfg(hf_hub='QuanSun/EVA-CLIP/EVA02_CLIP_L_psz14_s4B.pt'),
87
- eva02_clip=_pcfg(hf_hub='QuanSun/EVA-CLIP/EVA02_CLIP_L_psz14_s4B.pt'),
88
- )
89
-
90
- _VITL14_336 = dict(
91
- openai=_pcfg(
92
- "https://openaipublic.azureedge.net/clip/models/3035c92b350959924f9f00213499208652fc7ea050643e8b385c2dac08641f02/ViT-L-14-336px.pt"),
93
- )
94
-
95
- _EVAL14_336 = dict(
96
- eva_clip=_pcfg(hf_hub='QuanSun/EVA-CLIP/EVA02_CLIP_L_336_psz14_s6B.pt'),
97
- eva02_clip=_pcfg(hf_hub='QuanSun/EVA-CLIP/EVA02_CLIP_L_336_psz14_s6B.pt'),
98
- eva_clip_224to336=_pcfg(hf_hub='QuanSun/EVA-CLIP/EVA02_CLIP_L_psz14_224to336.pt'),
99
- eva02_clip_224to336=_pcfg(hf_hub='QuanSun/EVA-CLIP/EVA02_CLIP_L_psz14_224to336.pt'),
100
- )
101
-
102
- _VITH14 = dict(
103
- laion2b_s32b_b79k=_pcfg(hf_hub='laion/CLIP-ViT-H-14-laion2B-s32B-b79K/'),
104
- )
105
-
106
- _VITg14 = dict(
107
- laion2b_s12b_b42k=_pcfg(hf_hub='laion/CLIP-ViT-g-14-laion2B-s12B-b42K/'),
108
- laion2b_s34b_b88k=_pcfg(hf_hub='laion/CLIP-ViT-g-14-laion2B-s34B-b88K/'),
109
- )
110
-
111
- _EVAg14 = dict(
112
- eva=_pcfg(hf_hub='QuanSun/EVA-CLIP/'),
113
- eva01=_pcfg(hf_hub='QuanSun/EVA-CLIP/EVA01_g_psz14.pt'),
114
- eva_clip=_pcfg(hf_hub='QuanSun/EVA-CLIP/EVA01_CLIP_g_14_psz14_s11B.pt'),
115
- eva01_clip=_pcfg(hf_hub='QuanSun/EVA-CLIP/EVA01_CLIP_g_14_psz14_s11B.pt'),
116
- )
117
-
118
- _EVAg14_PLUS = dict(
119
- eva=_pcfg(hf_hub='QuanSun/EVA-CLIP/'),
120
- eva01=_pcfg(hf_hub='QuanSun/EVA-CLIP/EVA01_g_psz14.pt'),
121
- eva_clip=_pcfg(hf_hub='QuanSun/EVA-CLIP/EVA01_CLIP_g_14_plus_psz14_s11B.pt'),
122
- eva01_clip=_pcfg(hf_hub='QuanSun/EVA-CLIP/EVA01_CLIP_g_14_plus_psz14_s11B.pt'),
123
- )
124
-
125
- _VITbigG14 = dict(
126
- laion2b_s39b_b160k=_pcfg(hf_hub='laion/CLIP-ViT-bigG-14-laion2B-39B-b160k/'),
127
- )
128
-
129
- _EVAbigE14 = dict(
130
- eva=_pcfg(hf_hub='QuanSun/EVA-CLIP/EVA02_E_psz14.pt'),
131
- eva02=_pcfg(hf_hub='QuanSun/EVA-CLIP/EVA02_E_psz14.pt'),
132
- eva_clip=_pcfg(hf_hub='QuanSun/EVA-CLIP/EVA02_CLIP_E_psz14_s4B.pt'),
133
- eva02_clip=_pcfg(hf_hub='QuanSun/EVA-CLIP/EVA02_CLIP_E_psz14_s4B.pt'),
134
- )
135
-
136
- _EVAbigE14_PLUS = dict(
137
- eva=_pcfg(hf_hub='QuanSun/EVA-CLIP/EVA02_E_psz14.pt'),
138
- eva02=_pcfg(hf_hub='QuanSun/EVA-CLIP/EVA02_E_psz14.pt'),
139
- eva_clip=_pcfg(hf_hub='QuanSun/EVA-CLIP/EVA02_CLIP_E_psz14_plus_s9B.pt'),
140
- eva02_clip=_pcfg(hf_hub='QuanSun/EVA-CLIP/EVA02_CLIP_E_psz14_plus_s9B.pt'),
141
- )
142
-
143
-
144
- _PRETRAINED = {
145
- # "ViT-B-32": _VITB32,
146
- "OpenaiCLIP-B-32": _VITB32,
147
- "OpenCLIP-B-32": _VITB32,
148
-
149
- # "ViT-B-32-quickgelu": _VITB32_quickgelu,
150
- "OpenaiCLIP-B-32-quickgelu": _VITB32_quickgelu,
151
- "OpenCLIP-B-32-quickgelu": _VITB32_quickgelu,
152
-
153
- # "ViT-B-16": _VITB16,
154
- "OpenaiCLIP-B-16": _VITB16,
155
- "OpenCLIP-B-16": _VITB16,
156
-
157
- "EVA02-B-16": _EVAB16,
158
- "EVA02-CLIP-B-16": _EVAB16,
159
-
160
- # "ViT-B-16-plus-240": _VITB16_PLUS_240,
161
- "OpenCLIP-B-16-plus-240": _VITB16_PLUS_240,
162
-
163
- # "ViT-L-14": _VITL14,
164
- "OpenaiCLIP-L-14": _VITL14,
165
- "OpenCLIP-L-14": _VITL14,
166
-
167
- "EVA02-L-14": _EVAL14,
168
- "EVA02-CLIP-L-14": _EVAL14,
169
-
170
- # "ViT-L-14-336": _VITL14_336,
171
- "OpenaiCLIP-L-14-336": _VITL14_336,
172
-
173
- "EVA02-CLIP-L-14-336": _EVAL14_336,
174
-
175
- # "ViT-H-14": _VITH14,
176
- # "ViT-g-14": _VITg14,
177
- "OpenCLIP-H-14": _VITH14,
178
- "OpenCLIP-g-14": _VITg14,
179
-
180
- "EVA01-CLIP-g-14": _EVAg14,
181
- "EVA01-CLIP-g-14-plus": _EVAg14_PLUS,
182
-
183
- # "ViT-bigG-14": _VITbigG14,
184
- "OpenCLIP-bigG-14": _VITbigG14,
185
-
186
- "EVA02-CLIP-bigE-14": _EVAbigE14,
187
- "EVA02-CLIP-bigE-14-plus": _EVAbigE14_PLUS,
188
- }
189
-
190
-
191
- def _clean_tag(tag: str):
192
- # normalize pretrained tags
193
- return tag.lower().replace('-', '_')
194
-
195
-
196
- def list_pretrained(as_str: bool = False):
197
- """ returns list of pretrained models
198
- Returns a tuple (model_name, pretrain_tag) by default or 'name:tag' if as_str == True
199
- """
200
- return [':'.join([k, t]) if as_str else (k, t) for k in _PRETRAINED.keys() for t in _PRETRAINED[k].keys()]
201
-
202
-
203
- def list_pretrained_models_by_tag(tag: str):
204
- """ return all models having the specified pretrain tag """
205
- models = []
206
- tag = _clean_tag(tag)
207
- for k in _PRETRAINED.keys():
208
- if tag in _PRETRAINED[k]:
209
- models.append(k)
210
- return models
211
-
212
-
213
- def list_pretrained_tags_by_model(model: str):
214
- """ return all pretrain tags for the specified model architecture """
215
- tags = []
216
- if model in _PRETRAINED:
217
- tags.extend(_PRETRAINED[model].keys())
218
- return tags
219
-
220
-
221
- def is_pretrained_cfg(model: str, tag: str):
222
- if model not in _PRETRAINED:
223
- return False
224
- return _clean_tag(tag) in _PRETRAINED[model]
225
-
226
-
227
- def get_pretrained_cfg(model: str, tag: str):
228
- if model not in _PRETRAINED:
229
- return {}
230
- model_pretrained = _PRETRAINED[model]
231
- return model_pretrained.get(_clean_tag(tag), {})
232
-
233
-
234
- def get_pretrained_url(model: str, tag: str):
235
- cfg = get_pretrained_cfg(model, _clean_tag(tag))
236
- return cfg.get('url', '')
237
-
238
-
239
- def download_pretrained_from_url(
240
- url: str,
241
- cache_dir: Union[str, None] = None,
242
- ):
243
- if not cache_dir:
244
- cache_dir = os.path.expanduser("~/.cache/clip")
245
- os.makedirs(cache_dir, exist_ok=True)
246
- filename = os.path.basename(url)
247
-
248
- if 'openaipublic' in url:
249
- expected_sha256 = url.split("/")[-2]
250
- elif 'mlfoundations' in url:
251
- expected_sha256 = os.path.splitext(filename)[0].split("-")[-1]
252
- else:
253
- expected_sha256 = ''
254
-
255
- download_target = os.path.join(cache_dir, filename)
256
-
257
- if os.path.exists(download_target) and not os.path.isfile(download_target):
258
- raise RuntimeError(f"{download_target} exists and is not a regular file")
259
-
260
- if os.path.isfile(download_target):
261
- if expected_sha256:
262
- if hashlib.sha256(open(download_target, "rb").read()).hexdigest().startswith(expected_sha256):
263
- return download_target
264
- else:
265
- warnings.warn(f"{download_target} exists, but the SHA256 checksum does not match; re-downloading the file")
266
- else:
267
- return download_target
268
-
269
- with urllib.request.urlopen(url) as source, open(download_target, "wb") as output:
270
- with tqdm(total=int(source.headers.get("Content-Length")), ncols=80, unit='iB', unit_scale=True) as loop:
271
- while True:
272
- buffer = source.read(8192)
273
- if not buffer:
274
- break
275
-
276
- output.write(buffer)
277
- loop.update(len(buffer))
278
-
279
- if expected_sha256 and not hashlib.sha256(open(download_target, "rb").read()).hexdigest().startswith(expected_sha256):
280
- raise RuntimeError(f"Model has been downloaded but the SHA256 checksum does not not match")
281
-
282
- return download_target
283
-
284
-
285
- def has_hf_hub(necessary=False):
286
- if not _has_hf_hub and necessary:
287
- # if no HF Hub module installed, and it is necessary to continue, raise error
288
- raise RuntimeError(
289
- 'Hugging Face hub model specified but package not installed. Run `pip install huggingface_hub`.')
290
- return _has_hf_hub
291
-
292
-
293
- def download_pretrained_from_hf(
294
- model_id: str,
295
- filename: str = 'open_clip_pytorch_model.bin',
296
- revision=None,
297
- cache_dir: Union[str, None] = None,
298
- ):
299
- has_hf_hub(True)
300
- cached_file = hf_hub_download(model_id, filename, revision=revision, cache_dir=cache_dir)
301
- return cached_file
302
-
303
-
304
- def download_pretrained(
305
- cfg: Dict,
306
- force_hf_hub: bool = False,
307
- cache_dir: Union[str, None] = None,
308
- ):
309
- target = ''
310
- if not cfg:
311
- return target
312
-
313
- download_url = cfg.get('url', '')
314
- download_hf_hub = cfg.get('hf_hub', '')
315
- if download_hf_hub and force_hf_hub:
316
- # use HF hub even if url exists
317
- download_url = ''
318
-
319
- if download_url:
320
- target = download_pretrained_from_url(download_url, cache_dir=cache_dir)
321
- elif download_hf_hub:
322
- has_hf_hub(True)
323
- # we assume the hf_hub entries in pretrained config combine model_id + filename in
324
- # 'org/model_name/filename.pt' form. To specify just the model id w/o filename and
325
- # use 'open_clip_pytorch_model.bin' default, there must be a trailing slash 'org/model_name/'.
326
- model_id, filename = os.path.split(download_hf_hub)
327
- if filename:
328
- target = download_pretrained_from_hf(model_id, filename=filename, cache_dir=cache_dir)
329
- else:
330
- target = download_pretrained_from_hf(model_id, cache_dir=cache_dir)
331
-
332
- return target
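
A hedged sketch of querying the deleted registry: the model and tag names come from the _PRETRAINED table above, and nothing touches the network until download_pretrained() is called.

```python
# Hedged sketch, not code from this repo (assumes the helpers were imported from
# models/eva_clip/pretrained.py as it existed before this deletion).
print(list_pretrained(as_str=True)[:3])                    # e.g. ['OpenaiCLIP-B-32:openai', ...]
print(list_pretrained_tags_by_model("EVA02-CLIP-L-14"))    # ['eva', 'eva02', 'eva_clip', 'eva02_clip']

cfg = get_pretrained_cfg("EVA02-CLIP-L-14", "eva_clip")    # dict with 'url', 'hf_hub', 'mean', 'std'
ckpt_path = download_pretrained(cfg)                       # resolves the hf_hub entry via huggingface_hub
```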
 
models/eva_clip/rope.py DELETED
@@ -1,137 +0,0 @@
1
- from math import pi
2
- import torch
3
- from torch import nn
4
- from einops import rearrange, repeat
5
- import logging
6
-
7
- def broadcat(tensors, dim = -1):
8
- num_tensors = len(tensors)
9
- shape_lens = set(list(map(lambda t: len(t.shape), tensors)))
10
- assert len(shape_lens) == 1, 'tensors must all have the same number of dimensions'
11
- shape_len = list(shape_lens)[0]
12
- dim = (dim + shape_len) if dim < 0 else dim
13
- dims = list(zip(*map(lambda t: list(t.shape), tensors)))
14
- expandable_dims = [(i, val) for i, val in enumerate(dims) if i != dim]
15
- assert all([*map(lambda t: len(set(t[1])) <= 2, expandable_dims)]), 'invalid dimensions for broadcastable concatenation'
16
- max_dims = list(map(lambda t: (t[0], max(t[1])), expandable_dims))
17
- expanded_dims = list(map(lambda t: (t[0], (t[1],) * num_tensors), max_dims))
18
- expanded_dims.insert(dim, (dim, dims[dim]))
19
- expandable_shapes = list(zip(*map(lambda t: t[1], expanded_dims)))
20
- tensors = list(map(lambda t: t[0].expand(*t[1]), zip(tensors, expandable_shapes)))
21
- return torch.cat(tensors, dim = dim)
22
-
23
- def rotate_half(x):
24
- x = rearrange(x, '... (d r) -> ... d r', r = 2)
25
- x1, x2 = x.unbind(dim = -1)
26
- x = torch.stack((-x2, x1), dim = -1)
27
- return rearrange(x, '... d r -> ... (d r)')
28
-
29
-
30
- class VisionRotaryEmbedding(nn.Module):
31
- def __init__(
32
- self,
33
- dim,
34
- pt_seq_len,
35
- ft_seq_len=None,
36
- custom_freqs = None,
37
- freqs_for = 'lang',
38
- theta = 10000,
39
- max_freq = 10,
40
- num_freqs = 1,
41
- ):
42
- super().__init__()
43
- if custom_freqs:
44
- freqs = custom_freqs
45
- elif freqs_for == 'lang':
46
- freqs = 1. / (theta ** (torch.arange(0, dim, 2)[:(dim // 2)].float() / dim))
47
- elif freqs_for == 'pixel':
48
- freqs = torch.linspace(1., max_freq / 2, dim // 2) * pi
49
- elif freqs_for == 'constant':
50
- freqs = torch.ones(num_freqs).float()
51
- else:
52
- raise ValueError(f'unknown modality {freqs_for}')
53
-
54
- if ft_seq_len is None: ft_seq_len = pt_seq_len
55
- t = torch.arange(ft_seq_len) / ft_seq_len * pt_seq_len
56
-
57
- freqs_h = torch.einsum('..., f -> ... f', t, freqs)
58
- freqs_h = repeat(freqs_h, '... n -> ... (n r)', r = 2)
59
-
60
- freqs_w = torch.einsum('..., f -> ... f', t, freqs)
61
- freqs_w = repeat(freqs_w, '... n -> ... (n r)', r = 2)
62
-
63
- freqs = broadcat((freqs_h[:, None, :], freqs_w[None, :, :]), dim = -1)
64
-
65
- self.register_buffer("freqs_cos", freqs.cos())
66
- self.register_buffer("freqs_sin", freqs.sin())
67
-
68
- logging.info(f'Shape of rope freq: {self.freqs_cos.shape}')
69
-
70
- def forward(self, t, start_index = 0):
71
- rot_dim = self.freqs_cos.shape[-1]
72
- end_index = start_index + rot_dim
73
- assert rot_dim <= t.shape[-1], f'feature dimension {t.shape[-1]} is not of sufficient size to rotate in all the positions {rot_dim}'
74
- t_left, t, t_right = t[..., :start_index], t[..., start_index:end_index], t[..., end_index:]
75
- t = (t * self.freqs_cos) + (rotate_half(t) * self.freqs_sin)
76
-
77
- return torch.cat((t_left, t, t_right), dim = -1)
78
-
79
- class VisionRotaryEmbeddingFast(nn.Module):
80
- def __init__(
81
- self,
82
- dim,
83
- pt_seq_len,
84
- ft_seq_len=None,
85
- custom_freqs = None,
86
- freqs_for = 'lang',
87
- theta = 10000,
88
- max_freq = 10,
89
- num_freqs = 1,
90
- patch_dropout = 0.
91
- ):
92
- super().__init__()
93
- if custom_freqs:
94
- freqs = custom_freqs
95
- elif freqs_for == 'lang':
96
- freqs = 1. / (theta ** (torch.arange(0, dim, 2)[:(dim // 2)].float() / dim))
97
- elif freqs_for == 'pixel':
98
- freqs = torch.linspace(1., max_freq / 2, dim // 2) * pi
99
- elif freqs_for == 'constant':
100
- freqs = torch.ones(num_freqs).float()
101
- else:
102
- raise ValueError(f'unknown modality {freqs_for}')
103
-
104
- if ft_seq_len is None: ft_seq_len = pt_seq_len
105
- t = torch.arange(ft_seq_len) / ft_seq_len * pt_seq_len
106
-
107
- freqs = torch.einsum('..., f -> ... f', t, freqs)
108
- freqs = repeat(freqs, '... n -> ... (n r)', r = 2)
109
- freqs = broadcat((freqs[:, None, :], freqs[None, :, :]), dim = -1)
110
-
111
- freqs_cos = freqs.cos().view(-1, freqs.shape[-1])
112
- freqs_sin = freqs.sin().view(-1, freqs.shape[-1])
113
-
114
- self.patch_dropout = patch_dropout
115
-
116
- self.register_buffer("freqs_cos", freqs_cos)
117
- self.register_buffer("freqs_sin", freqs_sin)
118
-
119
- logging.info(f'Shape of rope freq: {self.freqs_cos.shape}')
120
-
121
- def forward(self, t, patch_indices_keep=None):
122
- if patch_indices_keep is not None:
123
- batch = t.size()[0]
124
- batch_indices = torch.arange(batch)
125
- batch_indices = batch_indices[..., None]
126
-
127
- freqs_cos = repeat(self.freqs_cos, 'i j -> n i m j', n=t.shape[0], m=t.shape[1])
128
- freqs_sin = repeat(self.freqs_sin, 'i j -> n i m j', n=t.shape[0], m=t.shape[1])
129
-
130
- freqs_cos = freqs_cos[batch_indices, patch_indices_keep]
131
- freqs_cos = rearrange(freqs_cos, 'n i m j -> n m i j')
132
- freqs_sin = freqs_sin[batch_indices, patch_indices_keep]
133
- freqs_sin = rearrange(freqs_sin, 'n i m j -> n m i j')
134
-
135
- return t * freqs_cos + rotate_half(t) * freqs_sin
136
-
137
- return t * self.freqs_cos + rotate_half(t) * self.freqs_sin
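
A hedged sketch of applying the deleted VisionRotaryEmbeddingFast to per-head patch tokens: with dim=32 and pt_seq_len=16, the buffers cover a 16x16 patch grid and rotate a 64-dim slice, since each frequency is repeated twice.

```python
# Hedged sketch, not code from this repo (assumes VisionRotaryEmbeddingFast was imported
# from models/eva_clip/rope.py as it existed before this deletion).
import torch

rope = VisionRotaryEmbeddingFast(dim=32, pt_seq_len=16)   # freqs_cos / freqs_sin buffers: (16*16, 64)
tokens = torch.randn(2, 8, 16 * 16, 64)                   # (batch, heads, patches, rotary dim)
rotated = rope(tokens)                                    # same shape, rotary position encoding applied
```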
 
models/eva_clip/timm_model.py DELETED
@@ -1,122 +0,0 @@
1
- """ timm model adapter
2
-
3
- Wraps timm (https://github.com/rwightman/pytorch-image-models) models for use as a vision tower in CLIP model.
4
- """
5
- import logging
6
- from collections import OrderedDict
7
-
8
- import torch
9
- import torch.nn as nn
10
-
11
- try:
12
- import timm
13
- from timm.models.layers import Mlp, to_2tuple
14
- try:
15
- # old timm imports < 0.8.1
16
- from timm.models.layers.attention_pool2d import RotAttentionPool2d
17
- from timm.models.layers.attention_pool2d import AttentionPool2d as AbsAttentionPool2d
18
- except ImportError:
19
- # new timm imports >= 0.8.1
20
- from timm.layers import RotAttentionPool2d
21
- from timm.layers import AttentionPool2d as AbsAttentionPool2d
22
- except ImportError:
23
- timm = None
24
-
25
- from .utils import freeze_batch_norm_2d
26
-
27
-
28
- class TimmModel(nn.Module):
29
- """ timm model adapter
30
- # FIXME this adapter is a work in progress, may change in ways that break weight compat
31
- """
32
-
33
- def __init__(
34
- self,
35
- model_name,
36
- embed_dim,
37
- image_size=224,
38
- pool='avg',
39
- proj='linear',
40
- proj_bias=False,
41
- drop=0.,
42
- pretrained=False):
43
- super().__init__()
44
- if timm is None:
45
- raise RuntimeError("Please `pip install timm` to use timm models.")
46
-
47
- self.image_size = to_2tuple(image_size)
48
- self.trunk = timm.create_model(model_name, pretrained=pretrained)
49
- feat_size = self.trunk.default_cfg.get('pool_size', None)
50
- feature_ndim = 1 if not feat_size else 2
51
- if pool in ('abs_attn', 'rot_attn'):
52
- assert feature_ndim == 2
53
- # if attn pooling used, remove both classifier and default pool
54
- self.trunk.reset_classifier(0, global_pool='')
55
- else:
56
- # reset global pool if pool config set, otherwise leave as network default
57
- reset_kwargs = dict(global_pool=pool) if pool else {}
58
- self.trunk.reset_classifier(0, **reset_kwargs)
59
- prev_chs = self.trunk.num_features
60
-
61
- head_layers = OrderedDict()
62
- if pool == 'abs_attn':
63
- head_layers['pool'] = AbsAttentionPool2d(prev_chs, feat_size=feat_size, out_features=embed_dim)
64
- prev_chs = embed_dim
65
- elif pool == 'rot_attn':
66
- head_layers['pool'] = RotAttentionPool2d(prev_chs, out_features=embed_dim)
67
- prev_chs = embed_dim
68
- else:
69
- assert proj, 'projection layer needed if non-attention pooling is used.'
70
-
71
- # NOTE attention pool ends with a projection layer, so proj should usually be set to '' if such pooling is used
72
- if proj == 'linear':
73
- head_layers['drop'] = nn.Dropout(drop)
74
- head_layers['proj'] = nn.Linear(prev_chs, embed_dim, bias=proj_bias)
75
- elif proj == 'mlp':
76
- head_layers['mlp'] = Mlp(prev_chs, 2 * embed_dim, embed_dim, drop=drop, bias=(True, proj_bias))
77
-
78
- self.head = nn.Sequential(head_layers)
79
-
80
- def lock(self, unlocked_groups=0, freeze_bn_stats=False):
81
- """ lock modules
82
- Args:
83
- unlocked_groups (int): leave last n layer groups unlocked (default: 0)
84
- """
85
- if not unlocked_groups:
86
- # lock full model
87
- for param in self.trunk.parameters():
88
- param.requires_grad = False
89
- if freeze_bn_stats:
90
- freeze_batch_norm_2d(self.trunk)
91
- else:
92
- # NOTE: partial freeze requires latest timm (master) branch and is subject to change
93
- try:
94
- # FIXME import here until API stable and in an official release
95
- from timm.models.helpers import group_parameters, group_modules
96
- except ImportError:
97
- raise RuntimeError(
98
- 'Please install latest timm `pip install git+https://github.com/rwightman/pytorch-image-models`')
99
- matcher = self.trunk.group_matcher()
100
- gparams = group_parameters(self.trunk, matcher)
101
- max_layer_id = max(gparams.keys())
102
- max_layer_id = max_layer_id - unlocked_groups
103
- for group_idx in range(max_layer_id + 1):
104
- group = gparams[group_idx]
105
- for param in group:
106
- self.trunk.get_parameter(param).requires_grad = False
107
- if freeze_bn_stats:
108
- gmodules = group_modules(self.trunk, matcher, reverse=True)
109
- gmodules = {k for k, v in gmodules.items() if v <= max_layer_id}
110
- freeze_batch_norm_2d(self.trunk, gmodules)
111
-
112
- @torch.jit.ignore
113
- def set_grad_checkpointing(self, enable=True):
114
- try:
115
- self.trunk.set_grad_checkpointing(enable)
116
- except Exception as e:
117
- logging.warning('grad checkpointing not supported for this timm image tower, continuing without...')
118
-
119
- def forward(self, x):
120
- x = self.trunk(x)
121
- x = self.head(x)
122
- return x
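
A hedged sketch of wrapping a timm backbone with the deleted adapter; it requires `pip install timm`, and "resnet50" is only an illustrative model name.

```python
# Hedged sketch, not code from this repo (assumes TimmModel was imported from
# models/eva_clip/timm_model.py as it existed before this deletion; requires timm).
import torch

tower = TimmModel(
    "resnet50",          # illustrative timm model name
    embed_dim=512,
    image_size=224,
    pool="avg",          # timm global average pool, then a linear projection head
    proj="linear",
    pretrained=False,
)
images = torch.randn(2, 3, 224, 224)
embeddings = tower(images)   # pooled trunk features projected to (2, 512)
```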
 
models/eva_clip/tokenizer.py DELETED
@@ -1,201 +0,0 @@
1
- """ CLIP tokenizer
2
-
3
- Copied from https://github.com/openai/CLIP. Originally MIT License, Copyright (c) 2021 OpenAI.
4
- """
5
- import gzip
6
- import html
7
- import os
8
- from functools import lru_cache
9
- from typing import Union, List
10
-
11
- import ftfy
12
- import regex as re
13
- import torch
14
-
15
- # https://stackoverflow.com/q/62691279
16
- import os
17
- os.environ["TOKENIZERS_PARALLELISM"] = "false"
18
-
19
-
20
- @lru_cache()
21
- def default_bpe():
22
- return os.path.join(os.path.dirname(os.path.abspath(__file__)), "bpe_simple_vocab_16e6.txt.gz")
23
-
24
-
25
- @lru_cache()
26
- def bytes_to_unicode():
27
- """
28
- Returns list of utf-8 byte and a corresponding list of unicode strings.
29
- The reversible bpe codes work on unicode strings.
30
- This means you need a large # of unicode characters in your vocab if you want to avoid UNKs.
31
- When you're at something like a 10B token dataset you end up needing around 5K for decent coverage.
32
- This is a significant percentage of your normal, say, 32K bpe vocab.
33
- To avoid that, we want lookup tables between utf-8 bytes and unicode strings.
34
- And avoids mapping to whitespace/control characters the bpe code barfs on.
35
- """
36
- bs = list(range(ord("!"), ord("~")+1))+list(range(ord("¡"), ord("¬")+1))+list(range(ord("®"), ord("ÿ")+1))
37
- cs = bs[:]
38
- n = 0
39
- for b in range(2**8):
40
- if b not in bs:
41
- bs.append(b)
42
- cs.append(2**8+n)
43
- n += 1
44
- cs = [chr(n) for n in cs]
45
- return dict(zip(bs, cs))
46
-
47
-
48
- def get_pairs(word):
49
- """Return set of symbol pairs in a word.
50
- Word is represented as tuple of symbols (symbols being variable-length strings).
51
- """
52
- pairs = set()
53
- prev_char = word[0]
54
- for char in word[1:]:
55
- pairs.add((prev_char, char))
56
- prev_char = char
57
- return pairs
58
-
59
-
60
- def basic_clean(text):
61
- text = ftfy.fix_text(text)
62
- text = html.unescape(html.unescape(text))
63
- return text.strip()
64
-
65
-
66
- def whitespace_clean(text):
67
- text = re.sub(r'\s+', ' ', text)
68
- text = text.strip()
69
- return text
70
-
71
-
72
- class SimpleTokenizer(object):
73
- def __init__(self, bpe_path: str = default_bpe(), special_tokens=None):
74
- self.byte_encoder = bytes_to_unicode()
75
- self.byte_decoder = {v: k for k, v in self.byte_encoder.items()}
76
- merges = gzip.open(bpe_path).read().decode("utf-8").split('\n')
77
- merges = merges[1:49152-256-2+1]
78
- merges = [tuple(merge.split()) for merge in merges]
79
- vocab = list(bytes_to_unicode().values())
80
- vocab = vocab + [v+'</w>' for v in vocab]
81
- for merge in merges:
82
- vocab.append(''.join(merge))
83
- if not special_tokens:
84
- special_tokens = ['<start_of_text>', '<end_of_text>']
85
- else:
86
- special_tokens = ['<start_of_text>', '<end_of_text>'] + special_tokens
87
- vocab.extend(special_tokens)
88
- self.encoder = dict(zip(vocab, range(len(vocab))))
89
- self.decoder = {v: k for k, v in self.encoder.items()}
90
- self.bpe_ranks = dict(zip(merges, range(len(merges))))
91
- self.cache = {t:t for t in special_tokens}
92
- special = "|".join(special_tokens)
93
- self.pat = re.compile(special + r"""|'s|'t|'re|'ve|'m|'ll|'d|[\p{L}]+|[\p{N}]|[^\s\p{L}\p{N}]+""", re.IGNORECASE)
94
-
95
- self.vocab_size = len(self.encoder)
96
- self.all_special_ids = [self.encoder[t] for t in special_tokens]
97
-
98
- def bpe(self, token):
99
- if token in self.cache:
100
- return self.cache[token]
101
- word = tuple(token[:-1]) + ( token[-1] + '</w>',)
102
- pairs = get_pairs(word)
103
-
104
- if not pairs:
105
- return token+'</w>'
106
-
107
- while True:
108
- bigram = min(pairs, key = lambda pair: self.bpe_ranks.get(pair, float('inf')))
109
- if bigram not in self.bpe_ranks:
110
- break
111
- first, second = bigram
112
- new_word = []
113
- i = 0
114
- while i < len(word):
115
- try:
116
- j = word.index(first, i)
117
- new_word.extend(word[i:j])
118
- i = j
119
- except ValueError:
120
- new_word.extend(word[i:])
121
- break
122
-
123
- if word[i] == first and i < len(word)-1 and word[i+1] == second:
124
- new_word.append(first+second)
125
- i += 2
126
- else:
127
- new_word.append(word[i])
128
- i += 1
129
- new_word = tuple(new_word)
130
- word = new_word
131
- if len(word) == 1:
132
- break
133
- else:
134
- pairs = get_pairs(word)
135
- word = ' '.join(word)
136
- self.cache[token] = word
137
- return word
138
-
139
- def encode(self, text):
140
- bpe_tokens = []
141
- text = whitespace_clean(basic_clean(text)).lower()
142
- for token in re.findall(self.pat, text):
143
- token = ''.join(self.byte_encoder[b] for b in token.encode('utf-8'))
144
- bpe_tokens.extend(self.encoder[bpe_token] for bpe_token in self.bpe(token).split(' '))
145
- return bpe_tokens
146
-
147
- def decode(self, tokens):
148
- text = ''.join([self.decoder[token] for token in tokens])
149
- text = bytearray([self.byte_decoder[c] for c in text]).decode('utf-8', errors="replace").replace('</w>', ' ')
150
- return text
151
-
152
-
153
- _tokenizer = SimpleTokenizer()
154
-
155
-
156
- def tokenize(texts: Union[str, List[str]], context_length: int = 77) -> torch.LongTensor:
157
- """
158
- Returns the tokenized representation of given input string(s)
159
-
160
- Parameters
161
- ----------
162
- texts : Union[str, List[str]]
163
- An input string or a list of input strings to tokenize
164
- context_length : int
165
- The context length to use; all CLIP models use 77 as the context length
166
-
167
- Returns
168
- -------
169
- A two-dimensional tensor containing the resulting tokens, shape = [number of input strings, context_length]
170
- """
171
- if isinstance(texts, str):
172
- texts = [texts]
173
-
174
- sot_token = _tokenizer.encoder["<start_of_text>"]
175
- eot_token = _tokenizer.encoder["<end_of_text>"]
176
- all_tokens = [[sot_token] + _tokenizer.encode(text) + [eot_token] for text in texts]
177
- result = torch.zeros(len(all_tokens), context_length, dtype=torch.long)
178
-
179
- for i, tokens in enumerate(all_tokens):
180
- if len(tokens) > context_length:
181
- tokens = tokens[:context_length] # Truncate
182
- tokens[-1] = eot_token
183
- result[i, :len(tokens)] = torch.tensor(tokens)
184
-
185
- return result
186
-
187
-
188
- class HFTokenizer:
189
- "HuggingFace tokenizer wrapper"
190
- def __init__(self, tokenizer_name:str):
191
- from transformers import AutoTokenizer
192
- self.tokenizer = AutoTokenizer.from_pretrained(tokenizer_name)
193
-
194
- def __call__(self, texts:Union[str, List[str]], context_length:int=77) -> torch.Tensor:
195
- # same cleaning as for default tokenizer, except lowercasing
196
- # adding lower (for case-sensitive tokenizers) will make it more robust but less sensitive to nuance
197
- if isinstance(texts, str):
198
- texts = [texts]
199
- texts = [whitespace_clean(basic_clean(text)) for text in texts]
200
- input_ids = self.tokenizer(texts, return_tensors='pt', max_length=context_length, padding='max_length', truncation=True).input_ids
201
- return input_ids
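
The tokenizer code above exposes two interchangeable entry points: the module-level tokenize() built on the BPE SimpleTokenizer, and the HFTokenizer wrapper around a Hugging Face tokenizer. Both return a fixed-length LongTensor of shape [batch, context_length]. A minimal usage sketch, assuming the pre-commit import path models.eva_clip.tokenizer and an arbitrary example Hugging Face model name:

    # Illustrative only: how the removed tokenizers were typically called.
    import torch
    from models.eva_clip.tokenizer import tokenize, HFTokenizer  # assumed path

    # Default CLIP-style BPE tokenizer: pads/truncates to the 77-token context window.
    ids = tokenize(["a photo of a cat", "a photo of a dog"])
    assert ids.shape == (2, 77) and ids.dtype == torch.long

    # Hugging Face-backed alternative; the model name here is only an example.
    hf_tok = HFTokenizer("bert-base-uncased")
    hf_ids = hf_tok(["a photo of a cat"], context_length=77)
    assert hf_ids.shape == (1, 77)
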
models/eva_clip/transform.py DELETED
@@ -1,103 +0,0 @@
1
- from typing import Optional, Sequence, Tuple
2
-
3
- import torch
4
- import torch.nn as nn
5
- import torchvision.transforms.functional as F
6
-
7
- from torchvision.transforms import Normalize, Compose, RandomResizedCrop, InterpolationMode, ToTensor, Resize, \
8
- CenterCrop
9
-
10
- from .constants import OPENAI_DATASET_MEAN, OPENAI_DATASET_STD
11
-
12
-
13
- class ResizeMaxSize(nn.Module):
14
-
15
- def __init__(self, max_size, interpolation=InterpolationMode.BICUBIC, fn='max', fill=0):
16
- super().__init__()
17
- if not isinstance(max_size, int):
18
- raise TypeError(f"Size should be int. Got {type(max_size)}")
19
- self.max_size = max_size
20
- self.interpolation = interpolation
21
- self.fn = min if fn == 'min' else min
22
- self.fill = fill
23
-
24
- def forward(self, img):
25
- if isinstance(img, torch.Tensor):
26
- height, width = img.shape[:2]
27
- else:
28
- width, height = img.size
29
- scale = self.max_size / float(max(height, width))
30
- if scale != 1.0:
31
- new_size = tuple(round(dim * scale) for dim in (height, width))
32
- img = F.resize(img, new_size, self.interpolation)
33
- pad_h = self.max_size - new_size[0]
34
- pad_w = self.max_size - new_size[1]
35
- img = F.pad(img, padding=[pad_w//2, pad_h//2, pad_w - pad_w//2, pad_h - pad_h//2], fill=self.fill)
36
- return img
37
-
38
-
39
- def _convert_to_rgb(image):
40
- return image.convert('RGB')
41
-
42
-
43
- # class CatGen(nn.Module):
44
- # def __init__(self, num=4):
45
- # self.num = num
46
- # def mixgen_batch(image, text):
47
- # batch_size = image.shape[0]
48
- # index = np.random.permutation(batch_size)
49
-
50
- # cat_images = []
51
- # for i in range(batch_size):
52
- # # image mixup
53
- # image[i,:] = lam * image[i,:] + (1 - lam) * image[index[i],:]
54
- # # text concat
55
- # text[i] = tokenizer((str(text[i]) + " " + str(text[index[i]])))[0]
56
- # text = torch.stack(text)
57
- # return image, text
58
-
59
-
60
- def image_transform(
61
- image_size: int,
62
- is_train: bool,
63
- mean: Optional[Tuple[float, ...]] = None,
64
- std: Optional[Tuple[float, ...]] = None,
65
- resize_longest_max: bool = False,
66
- fill_color: int = 0,
67
- ):
68
- mean = mean or OPENAI_DATASET_MEAN
69
- if not isinstance(mean, (list, tuple)):
70
- mean = (mean,) * 3
71
-
72
- std = std or OPENAI_DATASET_STD
73
- if not isinstance(std, (list, tuple)):
74
- std = (std,) * 3
75
-
76
- if isinstance(image_size, (list, tuple)) and image_size[0] == image_size[1]:
77
- # for square size, pass size as int so that Resize() uses aspect preserving shortest edge
78
- image_size = image_size[0]
79
-
80
- normalize = Normalize(mean=mean, std=std)
81
- if is_train:
82
- return Compose([
83
- RandomResizedCrop(image_size, scale=(0.9, 1.0), interpolation=InterpolationMode.BICUBIC),
84
- _convert_to_rgb,
85
- ToTensor(),
86
- normalize,
87
- ])
88
- else:
89
- if resize_longest_max:
90
- transforms = [
91
- ResizeMaxSize(image_size, fill=fill_color)
92
- ]
93
- else:
94
- transforms = [
95
- Resize(image_size, interpolation=InterpolationMode.BICUBIC),
96
- CenterCrop(image_size),
97
- ]
98
- transforms.extend([
99
- _convert_to_rgb,
100
- ToTensor(),
101
- normalize,
102
- ])
103
- return Compose(transforms)
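
image_transform() above is a small factory: for training it returns RandomResizedCrop + ToTensor + Normalize, and for evaluation either Resize/CenterCrop or, with resize_longest_max=True, the letterboxing ResizeMaxSize; mean/std default to the OpenAI CLIP constants. A hedged sketch of typical use, again assuming the pre-commit import path:

    # Illustrative usage of the removed transform factory (assumed import path).
    from PIL import Image
    from models.eva_clip.transform import image_transform

    eval_preprocess = image_transform(image_size=224, is_train=False)
    train_preprocess = image_transform(image_size=224, is_train=True)

    img = Image.new("RGB", (640, 480))   # stand-in for a real photo
    x = eval_preprocess(img)             # normalized tensor of shape [3, 224, 224]
    print(x.shape, x.dtype)
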
models/eva_clip/transformer.py DELETED
@@ -1,737 +0,0 @@
1
- import os
2
- import logging
3
- from collections import OrderedDict
4
- import math
5
- from typing import Callable, Optional, Sequence
6
- import numpy as np
7
- import torch
8
- from torch import nn
9
- from torch.nn import functional as F
10
-
11
- try:
12
- from timm.models.layers import trunc_normal_
13
- except:
14
- from timm.layers import trunc_normal_
15
-
16
- from .rope import VisionRotaryEmbedding, VisionRotaryEmbeddingFast
17
- from .utils import to_2tuple
18
-
19
- if os.getenv('ENV_TYPE') == 'deepspeed':
20
- try:
21
- import deepspeed
22
- from deepspeed.runtime.activation_checkpointing.checkpointing import checkpoint
23
- except:
24
- print("Please 'pip install deepspeed'")
25
- deepspeed = None
26
- from torch.utils.checkpoint import checkpoint
27
- else:
28
- from torch.utils.checkpoint import checkpoint
29
-
30
- try:
31
- import xformers.ops as xops
32
- except ImportError:
33
- xops = None
34
- print("Please 'pip install xformers'")
35
-
36
- class LayerNormFp32(nn.LayerNorm):
37
- """Subclass torch's LayerNorm to handle fp16 (by casting to float32 and back)."""
38
- def __init__(self, *args, **kwargs):
39
- super().__init__(*args, **kwargs)
40
-
41
- def forward(self, x: torch.Tensor):
42
- output = F.layer_norm(
43
- x.float(),
44
- self.normalized_shape,
45
- self.weight.float() if self.weight is not None else None,
46
- self.bias.float() if self.bias is not None else None,
47
- self.eps,
48
- )
49
- return output.type_as(x)
50
-
51
-
52
- class LayerNorm(nn.LayerNorm):
53
- """Subclass torch's LayerNorm (with cast back to input dtype)."""
54
-
55
- def forward(self, x: torch.Tensor):
56
- orig_type = x.dtype
57
- x = F.layer_norm(x, self.normalized_shape, self.weight, self.bias, self.eps)
58
- return x.to(orig_type)
59
-
60
- class QuickGELU(nn.Module):
61
- # NOTE This is slower than nn.GELU or nn.SiLU and uses more GPU memory
62
- def forward(self, x: torch.Tensor):
63
- return x * torch.sigmoid(1.702 * x)
64
-
65
-
66
- class LayerScale(nn.Module):
67
- def __init__(self, dim, init_values=1e-5, inplace=False):
68
- super().__init__()
69
- self.inplace = inplace
70
- self.gamma = nn.Parameter(init_values * torch.ones(dim))
71
-
72
- def forward(self, x):
73
- return x.mul_(self.gamma) if self.inplace else x * self.gamma
74
-
75
- class PatchDropout(nn.Module):
76
- """
77
- https://arxiv.org/abs/2212.00794
78
- """
79
-
80
- def __init__(self, prob, exclude_first_token=True):
81
- super().__init__()
82
- assert 0 <= prob < 1.
83
- self.prob = prob
84
- self.exclude_first_token = exclude_first_token # exclude CLS token
85
- logging.info(f"os.getenv('RoPE')={os.getenv('RoPE')}")
86
-
87
- def forward(self, x):
88
- if not self.training or self.prob == 0.:
89
- return x
90
-
91
- if self.exclude_first_token:
92
- cls_tokens, x = x[:, :1], x[:, 1:]
93
- else:
94
- cls_tokens = torch.jit.annotate(torch.Tensor, x[:, :1])
95
-
96
- batch = x.size()[0]
97
- num_tokens = x.size()[1]
98
-
99
- batch_indices = torch.arange(batch)
100
- batch_indices = batch_indices[..., None]
101
-
102
- keep_prob = 1 - self.prob
103
- num_patches_keep = max(1, int(num_tokens * keep_prob))
104
-
105
- rand = torch.randn(batch, num_tokens)
106
- patch_indices_keep = rand.topk(num_patches_keep, dim=-1).indices
107
-
108
- x = x[batch_indices, patch_indices_keep]
109
-
110
- if self.exclude_first_token:
111
- x = torch.cat((cls_tokens, x), dim=1)
112
-
113
- if self.training and os.getenv('RoPE') == '1':
114
- return x, patch_indices_keep
115
-
116
- return x
117
-
118
-
119
- def _in_projection_packed(
120
- q: torch.Tensor,
121
- k: torch.Tensor,
122
- v: torch.Tensor,
123
- w: torch.Tensor,
124
- b: Optional[torch.Tensor] = None,
125
- ):
126
- """
127
- https://github.com/pytorch/pytorch/blob/db2a237763eb8693a20788be94f8c192e762baa8/torch/nn/functional.py#L4726
128
- """
129
- E = q.size(-1)
130
- if k is v:
131
- if q is k:
132
- # self-attention
133
- return F.linear(q, w, b).chunk(3, dim=-1)
134
- else:
135
- # encoder-decoder attention
136
- w_q, w_kv = w.split([E, E * 2])
137
- if b is None:
138
- b_q = b_kv = None
139
- else:
140
- b_q, b_kv = b.split([E, E * 2])
141
- return (F.linear(q, w_q, b_q),) + F.linear(k, w_kv, b_kv).chunk(2, dim=-1)
142
- else:
143
- w_q, w_k, w_v = w.chunk(3)
144
- if b is None:
145
- b_q = b_k = b_v = None
146
- else:
147
- b_q, b_k, b_v = b.chunk(3)
148
- return F.linear(q, w_q, b_q), F.linear(k, w_k, b_k), F.linear(v, w_v, b_v)
149
-
150
- class Attention(nn.Module):
151
- def __init__(
152
- self,
153
- dim,
154
- num_heads=8,
155
- qkv_bias=True,
156
- scaled_cosine=False,
157
- scale_heads=False,
158
- logit_scale_max=math.log(1. / 0.01),
159
- attn_drop=0.,
160
- proj_drop=0.,
161
- xattn=False,
162
- rope=False
163
- ):
164
- super().__init__()
165
- self.scaled_cosine = scaled_cosine
166
- self.scale_heads = scale_heads
167
- assert dim % num_heads == 0, 'dim should be divisible by num_heads'
168
- self.num_heads = num_heads
169
- self.head_dim = dim // num_heads
170
- self.scale = self.head_dim ** -0.5
171
- self.logit_scale_max = logit_scale_max
172
-
173
- # keeping in_proj in this form (instead of nn.Linear) to match weight scheme of original
174
- self.in_proj_weight = nn.Parameter(torch.randn((dim * 3, dim)) * self.scale)
175
- if qkv_bias:
176
- self.in_proj_bias = nn.Parameter(torch.zeros(dim * 3))
177
- else:
178
- self.in_proj_bias = None
179
-
180
- if self.scaled_cosine:
181
- self.logit_scale = nn.Parameter(torch.log(10 * torch.ones((num_heads, 1, 1))))
182
- else:
183
- self.logit_scale = None
184
- self.attn_drop = nn.Dropout(attn_drop)
185
- if self.scale_heads:
186
- self.head_scale = nn.Parameter(torch.ones((num_heads, 1, 1)))
187
- else:
188
- self.head_scale = None
189
- self.out_proj = nn.Linear(dim, dim)
190
- self.out_drop = nn.Dropout(proj_drop)
191
- self.xattn = xattn
192
- self.xattn_drop = attn_drop
193
- self.rope = rope
194
-
195
- def forward(self, x, attn_mask: Optional[torch.Tensor] = None):
196
- L, N, C = x.shape
197
- q, k, v = F.linear(x, self.in_proj_weight, self.in_proj_bias).chunk(3, dim=-1)
198
- if self.xattn:
199
- q = q.contiguous().view(L, N, self.num_heads, -1).transpose(0, 1)
200
- k = k.contiguous().view(L, N, self.num_heads, -1).transpose(0, 1)
201
- v = v.contiguous().view(L, N, self.num_heads, -1).transpose(0, 1)
202
-
203
- x = xops.memory_efficient_attention(
204
- q, k, v,
205
- p=self.xattn_drop,
206
- scale=self.scale if self.logit_scale is None else None,
207
- attn_bias=xops.LowerTriangularMask() if attn_mask is not None else None,
208
- )
209
- else:
210
- q = q.contiguous().view(L, N * self.num_heads, -1).transpose(0, 1)
211
- k = k.contiguous().view(L, N * self.num_heads, -1).transpose(0, 1)
212
- v = v.contiguous().view(L, N * self.num_heads, -1).transpose(0, 1)
213
-
214
- if self.logit_scale is not None:
215
- attn = torch.bmm(F.normalize(q, dim=-1), F.normalize(k, dim=-1).transpose(-1, -2))
216
- logit_scale = torch.clamp(self.logit_scale, max=self.logit_scale_max).exp()
217
- attn = attn.view(N, self.num_heads, L, L) * logit_scale
218
- attn = attn.view(-1, L, L)
219
- else:
220
- q = q * self.scale
221
- attn = torch.bmm(q, k.transpose(-1, -2))
222
-
223
- if attn_mask is not None:
224
- if attn_mask.dtype == torch.bool:
225
- new_attn_mask = torch.zeros_like(attn_mask, dtype=q.dtype)
226
- new_attn_mask.masked_fill_(attn_mask, float("-inf"))
227
- attn_mask = new_attn_mask
228
- attn += attn_mask
229
-
230
- attn = attn.softmax(dim=-1)
231
- attn = self.attn_drop(attn)
232
-
233
- x = torch.bmm(attn, v)
234
-
235
- if self.head_scale is not None:
236
- x = x.view(N, self.num_heads, L, C) * self.head_scale
237
- x = x.view(-1, L, C)
238
- x = x.transpose(0, 1).reshape(L, N, C)
239
- x = self.out_proj(x)
240
- x = self.out_drop(x)
241
- return x
242
-
243
- class CustomAttention(nn.Module):
244
- def __init__(
245
- self,
246
- dim,
247
- num_heads=8,
248
- qkv_bias=True,
249
- scaled_cosine=True,
250
- scale_heads=False,
251
- logit_scale_max=math.log(1. / 0.01),
252
- attn_drop=0.,
253
- proj_drop=0.,
254
- xattn=False
255
- ):
256
- super().__init__()
257
- self.scaled_cosine = scaled_cosine
258
- self.scale_heads = scale_heads
259
- assert dim % num_heads == 0, 'dim should be divisible by num_heads'
260
- self.num_heads = num_heads
261
- self.head_dim = dim // num_heads
262
- self.scale = self.head_dim ** -0.5
263
- self.logit_scale_max = logit_scale_max
264
-
265
- # keeping in_proj in this form (instead of nn.Linear) to match weight scheme of original
266
- self.in_proj_weight = nn.Parameter(torch.randn((dim * 3, dim)) * self.scale)
267
- if qkv_bias:
268
- self.in_proj_bias = nn.Parameter(torch.zeros(dim * 3))
269
- else:
270
- self.in_proj_bias = None
271
-
272
- if self.scaled_cosine:
273
- self.logit_scale = nn.Parameter(torch.log(10 * torch.ones((num_heads, 1, 1))))
274
- else:
275
- self.logit_scale = None
276
- self.attn_drop = nn.Dropout(attn_drop)
277
- if self.scale_heads:
278
- self.head_scale = nn.Parameter(torch.ones((num_heads, 1, 1)))
279
- else:
280
- self.head_scale = None
281
- self.out_proj = nn.Linear(dim, dim)
282
- self.out_drop = nn.Dropout(proj_drop)
283
- self.xattn = xattn
284
- self.xattn_drop = attn_drop
285
-
286
- def forward(self, query: torch.Tensor, key: torch.Tensor, value: torch.Tensor, attn_mask: Optional[torch.Tensor] = None):
287
- q, k, v = _in_projection_packed(query, key, value, self.in_proj_weight, self.in_proj_bias)
288
- N_q, B_q, C_q = q.shape
289
- N_k, B_k, C_k = k.shape
290
- N_v, B_v, C_v = v.shape
291
- if self.xattn:
292
- # B, N, C -> B, N, num_heads, C
293
- q = q.permute(1, 0, 2).reshape(B_q, N_q, self.num_heads, -1)
294
- k = k.permute(1, 0, 2).reshape(B_k, N_k, self.num_heads, -1)
295
- v = v.permute(1, 0, 2).reshape(B_v, N_v, self.num_heads, -1)
296
-
297
- x = xops.memory_efficient_attention(
298
- q, k, v,
299
- p=self.xattn_drop,
300
- scale=self.scale if self.logit_scale is None else None,
301
- attn_bias=xops.LowerTriangularMask() if attn_mask is not None else None
302
- )
303
- else:
304
- # B*H, L, C
305
- q = q.contiguous().view(N_q, B_q * self.num_heads, -1).transpose(0, 1)
306
- k = k.contiguous().view(N_k, B_k * self.num_heads, -1).transpose(0, 1)
307
- v = v.contiguous().view(N_v, B_v * self.num_heads, -1).transpose(0, 1)
308
-
309
- if self.logit_scale is not None:
310
- # B*H, N_q, N_k
311
- attn = torch.bmm(F.normalize(q, dim=-1), F.normalize(k, dim=-1).transpose(-1, -2))
312
- logit_scale = torch.clamp(self.logit_scale, max=self.logit_scale_max).exp()
313
- attn = attn.view(B_q, self.num_heads, N_q, N_k) * logit_scale
314
- attn = attn.view(-1, N_q, N_k)
315
- else:
316
- q = q * self.scale
317
- attn = torch.bmm(q, k.transpose(-1, -2))
318
-
319
- if attn_mask is not None:
320
- if attn_mask.dtype == torch.bool:
321
- new_attn_mask = torch.zeros_like(attn_mask, dtype=q.dtype)
322
- new_attn_mask.masked_fill_(attn_mask, float("-inf"))
323
- attn_mask = new_attn_mask
324
- attn += attn_mask
325
-
326
- attn = attn.softmax(dim=-1)
327
- attn = self.attn_drop(attn)
328
-
329
- x = torch.bmm(attn, v)
330
-
331
- if self.head_scale is not None:
332
- x = x.view(B_q, self.num_heads, N_q, C_q) * self.head_scale
333
- x = x.view(-1, N_q, C_q)
334
- x = x.transpose(0, 1).reshape(N_q, B_q, C_q)
335
- x = self.out_proj(x)
336
- x = self.out_drop(x)
337
- return x
338
-
339
- class CustomResidualAttentionBlock(nn.Module):
340
- def __init__(
341
- self,
342
- d_model: int,
343
- n_head: int,
344
- mlp_ratio: float = 4.0,
345
- ls_init_value: float = None,
346
- act_layer: Callable = nn.GELU,
347
- norm_layer: Callable = LayerNorm,
348
- scale_cosine_attn: bool = False,
349
- scale_heads: bool = False,
350
- scale_attn: bool = False,
351
- scale_fc: bool = False,
352
- cross_attn: bool = False,
353
- xattn: bool = False,
354
- ):
355
- super().__init__()
356
-
357
- self.ln_1 = norm_layer(d_model)
358
- self.ln_1_k = norm_layer(d_model) if cross_attn else self.ln_1
359
- self.ln_1_v = norm_layer(d_model) if cross_attn else self.ln_1
360
- self.attn = CustomAttention(
361
- d_model, n_head,
362
- qkv_bias=True,
363
- attn_drop=0.,
364
- proj_drop=0.,
365
- scaled_cosine=scale_cosine_attn,
366
- scale_heads=scale_heads,
367
- xattn=xattn
368
- )
369
-
370
- self.ln_attn = norm_layer(d_model) if scale_attn else nn.Identity()
371
- self.ls_1 = LayerScale(d_model, ls_init_value) if ls_init_value is not None else nn.Identity()
372
-
373
- self.ln_2 = norm_layer(d_model)
374
- mlp_width = int(d_model * mlp_ratio)
375
- self.mlp = nn.Sequential(OrderedDict([
376
- ("c_fc", nn.Linear(d_model, mlp_width)),
377
- ('ln', norm_layer(mlp_width) if scale_fc else nn.Identity()),
378
- ("gelu", act_layer()),
379
- ("c_proj", nn.Linear(mlp_width, d_model))
380
- ]))
381
-
382
- self.ls_2 = LayerScale(d_model, ls_init_value) if ls_init_value is not None else nn.Identity()
383
-
384
- def forward(self, q: torch.Tensor, k: torch.Tensor, v: torch.Tensor, attn_mask: Optional[torch.Tensor] = None):
385
- q = q + self.ls_1(self.ln_attn(self.attn(self.ln_1(q), self.ln_1_k(k), self.ln_1_v(v), attn_mask=attn_mask)))
386
- q = q + self.ls_2(self.mlp(self.ln_2(q)))
387
- return q
388
-
389
- class CustomTransformer(nn.Module):
390
- def __init__(
391
- self,
392
- width: int,
393
- layers: int,
394
- heads: int,
395
- mlp_ratio: float = 4.0,
396
- ls_init_value: float = None,
397
- act_layer: Callable = nn.GELU,
398
- norm_layer: Callable = LayerNorm,
399
- scale_cosine_attn: bool = True,
400
- scale_heads: bool = False,
401
- scale_attn: bool = False,
402
- scale_fc: bool = False,
403
- cross_attn: bool = False,
404
- xattn: bool = False,
405
- ):
406
- super().__init__()
407
- self.width = width
408
- self.layers = layers
409
- self.grad_checkpointing = False
410
- self.xattn = xattn
411
-
412
- self.resblocks = nn.ModuleList([
413
- CustomResidualAttentionBlock(
414
- width,
415
- heads,
416
- mlp_ratio,
417
- ls_init_value=ls_init_value,
418
- act_layer=act_layer,
419
- norm_layer=norm_layer,
420
- scale_cosine_attn=scale_cosine_attn,
421
- scale_heads=scale_heads,
422
- scale_attn=scale_attn,
423
- scale_fc=scale_fc,
424
- cross_attn=cross_attn,
425
- xattn=xattn)
426
- for _ in range(layers)
427
- ])
428
-
429
- def get_cast_dtype(self) -> torch.dtype:
430
- return self.resblocks[0].mlp.c_fc.weight.dtype
431
-
432
- def forward(self, q: torch.Tensor, k: torch.Tensor = None, v: torch.Tensor = None, attn_mask: Optional[torch.Tensor] = None):
433
- if k is None and v is None:
434
- k = v = q
435
- for r in self.resblocks:
436
- if self.grad_checkpointing and not torch.jit.is_scripting():
437
- q = checkpoint(r, q, k, v, attn_mask)
438
- else:
439
- q = r(q, k, v, attn_mask=attn_mask)
440
- return q
441
-
442
-
443
- class ResidualAttentionBlock(nn.Module):
444
- def __init__(
445
- self,
446
- d_model: int,
447
- n_head: int,
448
- mlp_ratio: float = 4.0,
449
- ls_init_value: float = None,
450
- act_layer: Callable = nn.GELU,
451
- norm_layer: Callable = LayerNorm,
452
- xattn: bool = False,
453
- ):
454
- super().__init__()
455
-
456
- self.ln_1 = norm_layer(d_model)
457
- if xattn:
458
- self.attn = Attention(d_model, n_head, xattn=True)
459
- else:
460
- self.attn = nn.MultiheadAttention(d_model, n_head)
461
- self.ls_1 = LayerScale(d_model, ls_init_value) if ls_init_value is not None else nn.Identity()
462
-
463
- self.ln_2 = norm_layer(d_model)
464
- mlp_width = int(d_model * mlp_ratio)
465
- self.mlp = nn.Sequential(OrderedDict([
466
- ("c_fc", nn.Linear(d_model, mlp_width)),
467
- ("gelu", act_layer()),
468
- ("c_proj", nn.Linear(mlp_width, d_model))
469
- ]))
470
-
471
- self.ls_2 = LayerScale(d_model, ls_init_value) if ls_init_value is not None else nn.Identity()
472
- self.xattn = xattn
473
-
474
- def attention(self, x: torch.Tensor, attn_mask: Optional[torch.Tensor] = None):
475
- attn_mask = attn_mask.to(x.dtype) if attn_mask is not None else None
476
- if self.xattn:
477
- return self.attn(x, attn_mask=attn_mask)
478
- return self.attn(x, x, x, need_weights=False, attn_mask=attn_mask)[0]
479
-
480
- def forward(self, x: torch.Tensor, attn_mask: Optional[torch.Tensor] = None):
481
- x = x + self.ls_1(self.attention(self.ln_1(x), attn_mask=attn_mask))
482
- x = x + self.ls_2(self.mlp(self.ln_2(x)))
483
- return x
484
-
485
- class Transformer(nn.Module):
486
- def __init__(
487
- self,
488
- width: int,
489
- layers: int,
490
- heads: int,
491
- mlp_ratio: float = 4.0,
492
- ls_init_value: float = None,
493
- act_layer: Callable = nn.GELU,
494
- norm_layer: Callable = LayerNorm,
495
- xattn: bool = False,
496
- ):
497
- super().__init__()
498
- self.width = width
499
- self.layers = layers
500
- self.grad_checkpointing = False
501
-
502
- self.resblocks = nn.ModuleList([
503
- ResidualAttentionBlock(
504
- width, heads, mlp_ratio, ls_init_value=ls_init_value, act_layer=act_layer, norm_layer=norm_layer, xattn=xattn)
505
- for _ in range(layers)
506
- ])
507
-
508
- def get_cast_dtype(self) -> torch.dtype:
509
- return self.resblocks[0].mlp.c_fc.weight.dtype
510
-
511
- def forward(self, x: torch.Tensor, attn_mask: Optional[torch.Tensor] = None):
512
- for r in self.resblocks:
513
- if self.grad_checkpointing and not torch.jit.is_scripting():
514
- x = checkpoint(r, x, attn_mask)
515
- else:
516
- x = r(x, attn_mask=attn_mask)
517
- return x
518
-
519
-
520
- class VisionTransformer(nn.Module):
521
- def __init__(
522
- self,
523
- image_size: int,
524
- patch_size: int,
525
- width: int,
526
- layers: int,
527
- heads: int,
528
- mlp_ratio: float,
529
- ls_init_value: float = None,
530
- patch_dropout: float = 0.,
531
- global_average_pool: bool = False,
532
- output_dim: int = 512,
533
- act_layer: Callable = nn.GELU,
534
- norm_layer: Callable = LayerNorm,
535
- xattn: bool = False,
536
- ):
537
- super().__init__()
538
- self.image_size = to_2tuple(image_size)
539
- self.patch_size = to_2tuple(patch_size)
540
- self.grid_size = (self.image_size[0] // self.patch_size[0], self.image_size[1] // self.patch_size[1])
541
- self.output_dim = output_dim
542
- self.conv1 = nn.Conv2d(in_channels=3, out_channels=width, kernel_size=patch_size, stride=patch_size, bias=False)
543
-
544
- scale = width ** -0.5
545
- self.class_embedding = nn.Parameter(scale * torch.randn(width))
546
- self.positional_embedding = nn.Parameter(scale * torch.randn(self.grid_size[0] * self.grid_size[1] + 1, width))
547
-
548
- # setting a patch_dropout of 0. would mean it is disabled and this function would be the identity fn
549
- self.patch_dropout = PatchDropout(patch_dropout) if patch_dropout > 0. else nn.Identity()
550
- self.ln_pre = norm_layer(width)
551
-
552
- self.transformer = Transformer(
553
- width,
554
- layers,
555
- heads,
556
- mlp_ratio,
557
- ls_init_value=ls_init_value,
558
- act_layer=act_layer,
559
- norm_layer=norm_layer,
560
- xattn=xattn
561
- )
562
-
563
- self.global_average_pool = global_average_pool
564
- self.ln_post = norm_layer(width)
565
- self.proj = nn.Parameter(scale * torch.randn(width, output_dim))
566
-
567
- def lock(self, unlocked_groups=0, freeze_bn_stats=False):
568
- for param in self.parameters():
569
- param.requires_grad = False
570
-
571
- if unlocked_groups != 0:
572
- groups = [
573
- [
574
- self.conv1,
575
- self.class_embedding,
576
- self.positional_embedding,
577
- self.ln_pre,
578
- ],
579
- *self.transformer.resblocks[:-1],
580
- [
581
- self.transformer.resblocks[-1],
582
- self.ln_post,
583
- ],
584
- self.proj,
585
- ]
586
-
587
- def _unlock(x):
588
- if isinstance(x, Sequence):
589
- for g in x:
590
- _unlock(g)
591
- else:
592
- if isinstance(x, torch.nn.Parameter):
593
- x.requires_grad = True
594
- else:
595
- for p in x.parameters():
596
- p.requires_grad = True
597
-
598
- _unlock(groups[-unlocked_groups:])
599
-
600
- def get_num_layers(self):
601
- return self.transformer.layers
602
-
603
- @torch.jit.ignore
604
- def set_grad_checkpointing(self, enable=True):
605
- self.transformer.grad_checkpointing = enable
606
-
607
- @torch.jit.ignore
608
- def no_weight_decay(self):
609
- return {'positional_embedding', 'class_embedding'}
610
-
611
- def forward(self, x: torch.Tensor, return_all_features: bool=False):
612
- x = self.conv1(x) # shape = [*, width, grid, grid]
613
- x = x.reshape(x.shape[0], x.shape[1], -1) # shape = [*, width, grid ** 2]
614
- x = x.permute(0, 2, 1) # shape = [*, grid ** 2, width]
615
- x = torch.cat(
616
- [self.class_embedding.to(x.dtype) + torch.zeros(x.shape[0], 1, x.shape[-1], dtype=x.dtype, device=x.device),
617
- x], dim=1) # shape = [*, grid ** 2 + 1, width]
618
- x = x + self.positional_embedding.to(x.dtype)
619
-
620
- # a patch_dropout of 0. would mean it is disabled and this function would do nothing but return what was passed in
621
- x = self.patch_dropout(x)
622
- x = self.ln_pre(x)
623
-
624
- x = x.permute(1, 0, 2) # NLD -> LND
625
- x = self.transformer(x)
626
- x = x.permute(1, 0, 2) # LND -> NLD
627
-
628
- if not return_all_features:
629
- if self.global_average_pool:
630
- x = x.mean(dim=1) #x = x[:,1:,:].mean(dim=1)
631
- else:
632
- x = x[:, 0]
633
-
634
- x = self.ln_post(x)
635
-
636
- if self.proj is not None:
637
- x = x @ self.proj
638
-
639
- return x
640
-
641
-
642
- class TextTransformer(nn.Module):
643
- def __init__(
644
- self,
645
- context_length: int = 77,
646
- vocab_size: int = 49408,
647
- width: int = 512,
648
- heads: int = 8,
649
- layers: int = 12,
650
- ls_init_value: float = None,
651
- output_dim: int = 512,
652
- act_layer: Callable = nn.GELU,
653
- norm_layer: Callable = LayerNorm,
654
- xattn: bool= False,
655
- attn_mask: bool = True
656
- ):
657
- super().__init__()
658
- self.context_length = context_length
659
- self.vocab_size = vocab_size
660
- self.width = width
661
- self.output_dim = output_dim
662
-
663
- self.token_embedding = nn.Embedding(vocab_size, width)
664
- self.positional_embedding = nn.Parameter(torch.empty(self.context_length, width))
665
- self.transformer = Transformer(
666
- width=width,
667
- layers=layers,
668
- heads=heads,
669
- ls_init_value=ls_init_value,
670
- act_layer=act_layer,
671
- norm_layer=norm_layer,
672
- xattn=xattn
673
- )
674
-
675
- self.xattn = xattn
676
- self.ln_final = norm_layer(width)
677
- self.text_projection = nn.Parameter(torch.empty(width, output_dim))
678
-
679
- if attn_mask:
680
- self.register_buffer('attn_mask', self.build_attention_mask(), persistent=False)
681
- else:
682
- self.attn_mask = None
683
-
684
- self.init_parameters()
685
-
686
- def init_parameters(self):
687
- nn.init.normal_(self.token_embedding.weight, std=0.02)
688
- nn.init.normal_(self.positional_embedding, std=0.01)
689
-
690
- proj_std = (self.transformer.width ** -0.5) * ((2 * self.transformer.layers) ** -0.5)
691
- attn_std = self.transformer.width ** -0.5
692
- fc_std = (2 * self.transformer.width) ** -0.5
693
- for block in self.transformer.resblocks:
694
- nn.init.normal_(block.attn.in_proj_weight, std=attn_std)
695
- nn.init.normal_(block.attn.out_proj.weight, std=proj_std)
696
- nn.init.normal_(block.mlp.c_fc.weight, std=fc_std)
697
- nn.init.normal_(block.mlp.c_proj.weight, std=proj_std)
698
-
699
- if self.text_projection is not None:
700
- nn.init.normal_(self.text_projection, std=self.transformer.width ** -0.5)
701
-
702
- @torch.jit.ignore
703
- def set_grad_checkpointing(self, enable=True):
704
- self.transformer.grad_checkpointing = enable
705
-
706
- @torch.jit.ignore
707
- def no_weight_decay(self):
708
- # return {'positional_embedding', 'token_embedding'}
709
- return {'positional_embedding'}
710
-
711
- def get_num_layers(self):
712
- return self.transformer.layers
713
-
714
- def build_attention_mask(self):
715
- # lazily create causal attention mask, with full attention between the vision tokens
716
- # pytorch uses additive attention mask; fill with -inf
717
- mask = torch.empty(self.context_length, self.context_length)
718
- mask.fill_(float("-inf"))
719
- mask.triu_(1) # zero out the lower diagonal
720
- return mask
721
-
722
- def forward(self, text, return_all_features: bool=False):
723
- cast_dtype = self.transformer.get_cast_dtype()
724
- x = self.token_embedding(text).to(cast_dtype) # [batch_size, n_ctx, d_model]
725
-
726
- x = x + self.positional_embedding.to(cast_dtype)
727
- x = x.permute(1, 0, 2) # NLD -> LND
728
- x = self.transformer(x, attn_mask=self.attn_mask)
729
- # x = self.transformer(x) # no attention mask is applied
730
- x = x.permute(1, 0, 2) # LND -> NLD
731
- x = self.ln_final(x)
732
-
733
- if not return_all_features:
734
- # x.shape = [batch_size, n_ctx, transformer.width]
735
- # take features from the eot embedding (eot_token is the highest number in each sequence)
736
- x = x[torch.arange(x.shape[0]), text.argmax(dim=-1)] @ self.text_projection
737
- return x
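
The transformer.py module above defines the two CLIP-style towers: VisionTransformer (patchify, class token, pre-LN residual attention blocks, pooled projection) and TextTransformer (token + positional embeddings, causal mask, EOT pooling). The following shape check is only a sketch: the import path and the toy ViT-B/16-like hyperparameters are assumptions, not values used by this repository, and importing the module needs its own dependencies (e.g. timm, optionally xformers):

    # Rough shape sanity check for the removed towers (assumed import path; toy config).
    import torch
    from models.eva_clip.transformer import VisionTransformer, TextTransformer

    visual = VisionTransformer(
        image_size=224, patch_size=16,
        width=768, layers=12, heads=12, mlp_ratio=4.0,
        output_dim=512,  # joint embedding size
    )
    text = TextTransformer(context_length=77, vocab_size=49408,
                           width=512, heads=8, layers=12, output_dim=512)

    image_features = visual(torch.randn(1, 3, 224, 224))    # -> [1, 512]
    # Real use would feed tokenize() output so argmax lands on the EOT token.
    text_features = text(torch.randint(0, 49408, (1, 77)))  # -> [1, 512]
    print(image_features.shape, text_features.shape)
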
models/eva_clip/utils.py DELETED
@@ -1,326 +0,0 @@
1
- from itertools import repeat
2
- import collections.abc
3
- import logging
4
- import math
5
- import numpy as np
6
-
7
- import torch
8
- from torch import nn as nn
9
- from torchvision.ops.misc import FrozenBatchNorm2d
10
- import torch.nn.functional as F
11
-
12
- # open CLIP
13
- def resize_clip_pos_embed(state_dict, model, interpolation: str = 'bicubic', seq_dim=1):
14
- # Rescale the grid of position embeddings when loading from state_dict
15
- old_pos_embed = state_dict.get('visual.positional_embedding', None)
16
- if old_pos_embed is None or not hasattr(model.visual, 'grid_size'):
17
- return
18
- grid_size = to_2tuple(model.visual.grid_size)
19
- extra_tokens = 1 # FIXME detect different token configs (ie no class token, or more)
20
- new_seq_len = grid_size[0] * grid_size[1] + extra_tokens
21
- if new_seq_len == old_pos_embed.shape[0]:
22
- return
23
-
24
- if extra_tokens:
25
- pos_emb_tok, pos_emb_img = old_pos_embed[:extra_tokens], old_pos_embed[extra_tokens:]
26
- else:
27
- pos_emb_tok, pos_emb_img = None, old_pos_embed
28
- old_grid_size = to_2tuple(int(math.sqrt(len(pos_emb_img))))
29
-
30
- logging.info('Resizing position embedding grid-size from %s to %s', old_grid_size, grid_size)
31
- pos_emb_img = pos_emb_img.reshape(1, old_grid_size[0], old_grid_size[1], -1).permute(0, 3, 1, 2)
32
- pos_emb_img = F.interpolate(
33
- pos_emb_img,
34
- size=grid_size,
35
- mode=interpolation,
36
- align_corners=True,
37
- )
38
- pos_emb_img = pos_emb_img.permute(0, 2, 3, 1).reshape(1, grid_size[0] * grid_size[1], -1)[0]
39
- if pos_emb_tok is not None:
40
- new_pos_embed = torch.cat([pos_emb_tok, pos_emb_img], dim=0)
41
- else:
42
- new_pos_embed = pos_emb_img
43
- state_dict['visual.positional_embedding'] = new_pos_embed
44
-
45
-
46
- def resize_visual_pos_embed(state_dict, model, interpolation: str = 'bicubic', seq_dim=1):
47
- # Rescale the grid of position embeddings when loading from state_dict
48
- old_pos_embed = state_dict.get('positional_embedding', None)
49
- if old_pos_embed is None or not hasattr(model.visual, 'grid_size'):
50
- return
51
- grid_size = to_2tuple(model.visual.grid_size)
52
- extra_tokens = 1 # FIXME detect different token configs (ie no class token, or more)
53
- new_seq_len = grid_size[0] * grid_size[1] + extra_tokens
54
- if new_seq_len == old_pos_embed.shape[0]:
55
- return
56
-
57
- if extra_tokens:
58
- pos_emb_tok, pos_emb_img = old_pos_embed[:extra_tokens], old_pos_embed[extra_tokens:]
59
- else:
60
- pos_emb_tok, pos_emb_img = None, old_pos_embed
61
- old_grid_size = to_2tuple(int(math.sqrt(len(pos_emb_img))))
62
-
63
- logging.info('Resizing position embedding grid-size from %s to %s', old_grid_size, grid_size)
64
- pos_emb_img = pos_emb_img.reshape(1, old_grid_size[0], old_grid_size[1], -1).permute(0, 3, 1, 2)
65
- pos_emb_img = F.interpolate(
66
- pos_emb_img,
67
- size=grid_size,
68
- mode=interpolation,
69
- align_corners=True,
70
- )
71
- pos_emb_img = pos_emb_img.permute(0, 2, 3, 1).reshape(1, grid_size[0] * grid_size[1], -1)[0]
72
- if pos_emb_tok is not None:
73
- new_pos_embed = torch.cat([pos_emb_tok, pos_emb_img], dim=0)
74
- else:
75
- new_pos_embed = pos_emb_img
76
- state_dict['positional_embedding'] = new_pos_embed
77
-
78
- def resize_evaclip_pos_embed(state_dict, model, interpolation: str = 'bicubic', seq_dim=1):
79
- all_keys = list(state_dict.keys())
80
- # interpolate position embedding
81
- if 'visual.pos_embed' in state_dict:
82
- pos_embed_checkpoint = state_dict['visual.pos_embed']
83
- embedding_size = pos_embed_checkpoint.shape[-1]
84
- num_patches = model.visual.patch_embed.num_patches
85
- num_extra_tokens = model.visual.pos_embed.shape[-2] - num_patches
86
- # height (== width) for the checkpoint position embedding
87
- orig_size = int((pos_embed_checkpoint.shape[-2] - num_extra_tokens) ** 0.5)
88
- # height (== width) for the new position embedding
89
- new_size = int(num_patches ** 0.5)
90
- # class_token and dist_token are kept unchanged
91
- if orig_size != new_size:
92
- print("Position interpolate from %dx%d to %dx%d" % (orig_size, orig_size, new_size, new_size))
93
- extra_tokens = pos_embed_checkpoint[:, :num_extra_tokens]
94
- # only the position tokens are interpolated
95
- pos_tokens = pos_embed_checkpoint[:, num_extra_tokens:]
96
- pos_tokens = pos_tokens.reshape(-1, orig_size, orig_size, embedding_size).permute(0, 3, 1, 2)
97
- pos_tokens = torch.nn.functional.interpolate(
98
- pos_tokens, size=(new_size, new_size), mode='bicubic', align_corners=False)
99
- pos_tokens = pos_tokens.permute(0, 2, 3, 1).flatten(1, 2)
100
- new_pos_embed = torch.cat((extra_tokens, pos_tokens), dim=1)
101
- state_dict['visual.pos_embed'] = new_pos_embed
102
-
103
- patch_embed_proj = state_dict['visual.patch_embed.proj.weight']
104
- patch_size = model.visual.patch_embed.patch_size
105
- state_dict['visual.patch_embed.proj.weight'] = torch.nn.functional.interpolate(
106
- patch_embed_proj.float(), size=patch_size, mode='bicubic', align_corners=False)
107
-
108
-
109
- def resize_eva_pos_embed(state_dict, model, interpolation: str = 'bicubic', seq_dim=1):
110
- all_keys = list(state_dict.keys())
111
- # interpolate position embedding
112
- if 'pos_embed' in state_dict:
113
- pos_embed_checkpoint = state_dict['pos_embed']
114
- embedding_size = pos_embed_checkpoint.shape[-1]
115
- num_patches = model.visual.patch_embed.num_patches
116
- num_extra_tokens = model.visual.pos_embed.shape[-2] - num_patches
117
- # height (== width) for the checkpoint position embedding
118
- orig_size = int((pos_embed_checkpoint.shape[-2] - num_extra_tokens) ** 0.5)
119
- # height (== width) for the new position embedding
120
- new_size = int(num_patches ** 0.5)
121
- # class_token and dist_token are kept unchanged
122
- if orig_size != new_size:
123
- print("Position interpolate from %dx%d to %dx%d" % (orig_size, orig_size, new_size, new_size))
124
- extra_tokens = pos_embed_checkpoint[:, :num_extra_tokens]
125
- # only the position tokens are interpolated
126
- pos_tokens = pos_embed_checkpoint[:, num_extra_tokens:]
127
- pos_tokens = pos_tokens.reshape(-1, orig_size, orig_size, embedding_size).permute(0, 3, 1, 2)
128
- pos_tokens = torch.nn.functional.interpolate(
129
- pos_tokens, size=(new_size, new_size), mode='bicubic', align_corners=False)
130
- pos_tokens = pos_tokens.permute(0, 2, 3, 1).flatten(1, 2)
131
- new_pos_embed = torch.cat((extra_tokens, pos_tokens), dim=1)
132
- state_dict['pos_embed'] = new_pos_embed
133
-
134
- patch_embed_proj = state_dict['patch_embed.proj.weight']
135
- patch_size = model.visual.patch_embed.patch_size
136
- state_dict['patch_embed.proj.weight'] = torch.nn.functional.interpolate(
137
- patch_embed_proj.float(), size=patch_size, mode='bicubic', align_corners=False)
138
-
139
-
140
- def resize_rel_pos_embed(state_dict, model, interpolation: str = 'bicubic', seq_dim=1):
141
- all_keys = list(state_dict.keys())
142
- for key in all_keys:
143
- if "relative_position_index" in key:
144
- state_dict.pop(key)
145
-
146
- if "relative_position_bias_table" in key:
147
- rel_pos_bias = state_dict[key]
148
- src_num_pos, num_attn_heads = rel_pos_bias.size()
149
- dst_num_pos, _ = model.visual.state_dict()[key].size()
150
- dst_patch_shape = model.visual.patch_embed.patch_shape
151
- if dst_patch_shape[0] != dst_patch_shape[1]:
152
- raise NotImplementedError()
153
- num_extra_tokens = dst_num_pos - (dst_patch_shape[0] * 2 - 1) * (dst_patch_shape[1] * 2 - 1)
154
- src_size = int((src_num_pos - num_extra_tokens) ** 0.5)
155
- dst_size = int((dst_num_pos - num_extra_tokens) ** 0.5)
156
- if src_size != dst_size:
157
- print("Position interpolate for %s from %dx%d to %dx%d" % (
158
- key, src_size, src_size, dst_size, dst_size))
159
- extra_tokens = rel_pos_bias[-num_extra_tokens:, :]
160
- rel_pos_bias = rel_pos_bias[:-num_extra_tokens, :]
161
-
162
- def geometric_progression(a, r, n):
163
- return a * (1.0 - r ** n) / (1.0 - r)
164
-
165
- left, right = 1.01, 1.5
166
- while right - left > 1e-6:
167
- q = (left + right) / 2.0
168
- gp = geometric_progression(1, q, src_size // 2)
169
- if gp > dst_size // 2:
170
- right = q
171
- else:
172
- left = q
173
-
174
- # if q > 1.090307:
175
- # q = 1.090307
176
-
177
- dis = []
178
- cur = 1
179
- for i in range(src_size // 2):
180
- dis.append(cur)
181
- cur += q ** (i + 1)
182
-
183
- r_ids = [-_ for _ in reversed(dis)]
184
-
185
- x = r_ids + [0] + dis
186
- y = r_ids + [0] + dis
187
-
188
- t = dst_size // 2.0
189
- dx = np.arange(-t, t + 0.1, 1.0)
190
- dy = np.arange(-t, t + 0.1, 1.0)
191
-
192
- print("Original positions = %s" % str(x))
193
- print("Target positions = %s" % str(dx))
194
-
195
- all_rel_pos_bias = []
196
-
197
- for i in range(num_attn_heads):
198
- z = rel_pos_bias[:, i].view(src_size, src_size).float().numpy()
199
- f = F.interpolate.interp2d(x, y, z, kind='cubic')
200
- all_rel_pos_bias.append(
201
- torch.Tensor(f(dx, dy)).contiguous().view(-1, 1).to(rel_pos_bias.device))
202
-
203
- rel_pos_bias = torch.cat(all_rel_pos_bias, dim=-1)
204
-
205
- new_rel_pos_bias = torch.cat((rel_pos_bias, extra_tokens), dim=0)
206
- state_dict[key] = new_rel_pos_bias
207
-
208
- # interpolate position embedding
209
- if 'pos_embed' in state_dict:
210
- pos_embed_checkpoint = state_dict['pos_embed']
211
- embedding_size = pos_embed_checkpoint.shape[-1]
212
- num_patches = model.visual.patch_embed.num_patches
213
- num_extra_tokens = model.visual.pos_embed.shape[-2] - num_patches
214
- # height (== width) for the checkpoint position embedding
215
- orig_size = int((pos_embed_checkpoint.shape[-2] - num_extra_tokens) ** 0.5)
216
- # height (== width) for the new position embedding
217
- new_size = int(num_patches ** 0.5)
218
- # class_token and dist_token are kept unchanged
219
- if orig_size != new_size:
220
- print("Position interpolate from %dx%d to %dx%d" % (orig_size, orig_size, new_size, new_size))
221
- extra_tokens = pos_embed_checkpoint[:, :num_extra_tokens]
222
- # only the position tokens are interpolated
223
- pos_tokens = pos_embed_checkpoint[:, num_extra_tokens:]
224
- pos_tokens = pos_tokens.reshape(-1, orig_size, orig_size, embedding_size).permute(0, 3, 1, 2)
225
- pos_tokens = torch.nn.functional.interpolate(
226
- pos_tokens, size=(new_size, new_size), mode='bicubic', align_corners=False)
227
- pos_tokens = pos_tokens.permute(0, 2, 3, 1).flatten(1, 2)
228
- new_pos_embed = torch.cat((extra_tokens, pos_tokens), dim=1)
229
- state_dict['pos_embed'] = new_pos_embed
230
-
231
- patch_embed_proj = state_dict['patch_embed.proj.weight']
232
- patch_size = model.visual.patch_embed.patch_size
233
- state_dict['patch_embed.proj.weight'] = torch.nn.functional.interpolate(
234
- patch_embed_proj.float(), size=patch_size, mode='bicubic', align_corners=False)
235
-
236
-
237
- def freeze_batch_norm_2d(module, module_match={}, name=''):
238
- """
239
- Converts all `BatchNorm2d` and `SyncBatchNorm` layers of provided module into `FrozenBatchNorm2d`. If `module` is
240
- itself an instance of either `BatchNorm2d` or `SyncBatchNorm`, it is converted into `FrozenBatchNorm2d` and
241
- returned. Otherwise, the module is walked recursively and submodules are converted in place.
242
-
243
- Args:
244
- module (torch.nn.Module): Any PyTorch module.
245
- module_match (dict): Dictionary of full module names to freeze (all if empty)
246
- name (str): Full module name (prefix)
247
-
248
- Returns:
249
- torch.nn.Module: Resulting module
250
-
251
- Inspired by https://github.com/pytorch/pytorch/blob/a5895f85be0f10212791145bfedc0261d364f103/torch/nn/modules/batchnorm.py#L762
252
- """
253
- res = module
254
- is_match = True
255
- if module_match:
256
- is_match = name in module_match
257
- if is_match and isinstance(module, (nn.modules.batchnorm.BatchNorm2d, nn.modules.batchnorm.SyncBatchNorm)):
258
- res = FrozenBatchNorm2d(module.num_features)
259
- res.num_features = module.num_features
260
- res.affine = module.affine
261
- if module.affine:
262
- res.weight.data = module.weight.data.clone().detach()
263
- res.bias.data = module.bias.data.clone().detach()
264
- res.running_mean.data = module.running_mean.data
265
- res.running_var.data = module.running_var.data
266
- res.eps = module.eps
267
- else:
268
- for child_name, child in module.named_children():
269
- full_child_name = '.'.join([name, child_name]) if name else child_name
270
- new_child = freeze_batch_norm_2d(child, module_match, full_child_name)
271
- if new_child is not child:
272
- res.add_module(child_name, new_child)
273
- return res
274
-
275
-
276
- # From PyTorch internals
277
- def _ntuple(n):
278
- def parse(x):
279
- if isinstance(x, collections.abc.Iterable):
280
- return x
281
- return tuple(repeat(x, n))
282
- return parse
283
-
284
-
285
- to_1tuple = _ntuple(1)
286
- to_2tuple = _ntuple(2)
287
- to_3tuple = _ntuple(3)
288
- to_4tuple = _ntuple(4)
289
- to_ntuple = lambda n, x: _ntuple(n)(x)
290
-
291
-
292
- def is_logging(args):
293
- def is_global_master(args):
294
- return args.rank == 0
295
-
296
- def is_local_master(args):
297
- return args.local_rank == 0
298
-
299
- def is_master(args, local=False):
300
- return is_local_master(args) if local else is_global_master(args)
301
- return is_master
302
-
303
-
304
- class AllGather(torch.autograd.Function):
305
- """An autograd function that performs allgather on a tensor.
306
- Performs all_gather operation on the provided tensors.
307
- *** Warning ***: torch.distributed.all_gather has no gradient.
308
- """
309
-
310
- @staticmethod
311
- def forward(ctx, tensor, rank, world_size):
312
- tensors_gather = [torch.empty_like(tensor) for _ in range(world_size)]
313
- torch.distributed.all_gather(tensors_gather, tensor)
314
- ctx.rank = rank
315
- ctx.batch_size = tensor.shape[0]
316
- return torch.cat(tensors_gather, 0)
317
-
318
- @staticmethod
319
- def backward(ctx, grad_output):
320
- return (
321
- grad_output[ctx.batch_size * ctx.rank: ctx.batch_size * (ctx.rank + 1)],
322
- None,
323
- None
324
- )
325
-
326
- allgather = AllGather.apply
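
Besides the positional-embedding resizing helpers, utils.py above ships freeze_batch_norm_2d, which walks a module tree and swaps every BatchNorm2d/SyncBatchNorm for a torchvision FrozenBatchNorm2d carrying the same affine parameters and running statistics. A small sketch of that behaviour (assumed import path; the toy backbone is hypothetical):

    # Illustrative use of the removed freeze_batch_norm_2d helper.
    import torch.nn as nn
    from torchvision.ops.misc import FrozenBatchNorm2d
    from models.eva_clip.utils import freeze_batch_norm_2d  # assumed path

    backbone = nn.Sequential(
        nn.Conv2d(3, 16, 3, padding=1),
        nn.BatchNorm2d(16),
        nn.ReLU(),
    )
    frozen = freeze_batch_norm_2d(backbone)          # converts matching children in place
    assert isinstance(frozen[1], FrozenBatchNorm2d)  # BN stats and affine params are now fixed
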
models/eva_clip/utils_qformer.py DELETED
@@ -1,166 +0,0 @@
1
- import importlib
2
- import math
3
- import os
4
- import random
5
-
6
- import cv2
7
- import numpy as np
8
- import torch
9
- import torch.nn.functional as F
10
- from torchvision.utils import make_grid
11
- from transformers import PretrainedConfig
12
-
13
-
14
- def seed_everything(seed):
15
- os.environ["PL_GLOBAL_SEED"] = str(seed)
16
- random.seed(seed)
17
- np.random.seed(seed)
18
- torch.manual_seed(seed)
19
- torch.cuda.manual_seed_all(seed)
20
-
21
-
22
- def is_torch2_available():
23
- return hasattr(F, "scaled_dot_product_attention")
24
-
25
-
26
- def instantiate_from_config(config):
27
- if "target" not in config:
28
- if config == '__is_first_stage__' or config == "__is_unconditional__":
29
- return None
30
- raise KeyError("Expected key `target` to instantiate.")
31
- return get_obj_from_str(config["target"])(**config.get("params", {}))
32
-
33
-
34
- def get_obj_from_str(string, reload=False):
35
- module, cls = string.rsplit(".", 1)
36
- if reload:
37
- module_imp = importlib.import_module(module)
38
- importlib.reload(module_imp)
39
- return getattr(importlib.import_module(module, package=None), cls)
40
-
41
-
42
- def drop_seq_token(seq, drop_rate=0.5):
43
- idx = torch.randperm(seq.size(1))
44
- num_keep_tokens = int(len(idx) * (1 - drop_rate))
45
- idx = idx[:num_keep_tokens]
46
- seq = seq[:, idx]
47
- return seq
48
-
49
-
50
- def import_model_class_from_model_name_or_path(
51
- pretrained_model_name_or_path: str, revision: str, subfolder: str = "text_encoder"
52
- ):
53
- text_encoder_config = PretrainedConfig.from_pretrained(
54
- pretrained_model_name_or_path, subfolder=subfolder, revision=revision
55
- )
56
- model_class = text_encoder_config.architectures[0]
57
-
58
- if model_class == "CLIPTextModel":
59
- from transformers import CLIPTextModel
60
-
61
- return CLIPTextModel
62
- elif model_class == "CLIPTextModelWithProjection": # noqa RET505
63
- from transformers import CLIPTextModelWithProjection
64
-
65
- return CLIPTextModelWithProjection
66
- else:
67
- raise ValueError(f"{model_class} is not supported.")
68
-
69
-
70
- def resize_numpy_image_long(image, resize_long_edge=768):
71
- h, w = image.shape[:2]
72
- if max(h, w) <= resize_long_edge:
73
- return image
74
- k = resize_long_edge / max(h, w)
75
- h = int(h * k)
76
- w = int(w * k)
77
- image = cv2.resize(image, (w, h), interpolation=cv2.INTER_LANCZOS4)
78
- return image
79
-
80
-
81
- # from basicsr
82
- def img2tensor(imgs, bgr2rgb=True, float32=True):
83
- """Numpy array to tensor.
84
-
85
- Args:
86
- imgs (list[ndarray] | ndarray): Input images.
87
- bgr2rgb (bool): Whether to change bgr to rgb.
88
- float32 (bool): Whether to change to float32.
89
-
90
- Returns:
91
- list[tensor] | tensor: Tensor images. If returned results only have
92
- one element, just return tensor.
93
- """
94
-
95
- def _totensor(img, bgr2rgb, float32):
96
- if img.shape[2] == 3 and bgr2rgb:
97
- if img.dtype == 'float64':
98
- img = img.astype('float32')
99
- img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)
100
- img = torch.from_numpy(img.transpose(2, 0, 1))
101
- if float32:
102
- img = img.float()
103
- return img
104
-
105
- if isinstance(imgs, list):
106
- return [_totensor(img, bgr2rgb, float32) for img in imgs]
107
- return _totensor(imgs, bgr2rgb, float32)
108
-
109
-
110
- def tensor2img(tensor, rgb2bgr=True, out_type=np.uint8, min_max=(0, 1)):
111
- """Convert torch Tensors into image numpy arrays.
112
-
113
- After clamping to [min, max], values will be normalized to [0, 1].
114
-
115
- Args:
116
- tensor (Tensor or list[Tensor]): Accept shapes:
117
- 1) 4D mini-batch Tensor of shape (B x 3/1 x H x W);
118
- 2) 3D Tensor of shape (3/1 x H x W);
119
- 3) 2D Tensor of shape (H x W).
120
- Tensor channel should be in RGB order.
121
- rgb2bgr (bool): Whether to change rgb to bgr.
122
- out_type (numpy type): output types. If ``np.uint8``, transform outputs
123
- to uint8 type with range [0, 255]; otherwise, float type with
124
- range [0, 1]. Default: ``np.uint8``.
125
- min_max (tuple[int]): min and max values for clamp.
126
-
127
- Returns:
128
- (Tensor or list): 3D ndarray of shape (H x W x C) OR 2D ndarray of
129
- shape (H x W). The channel order is BGR.
130
- """
131
- if not (torch.is_tensor(tensor) or (isinstance(tensor, list) and all(torch.is_tensor(t) for t in tensor))):
132
- raise TypeError(f'tensor or list of tensors expected, got {type(tensor)}')
133
-
134
- if torch.is_tensor(tensor):
135
- tensor = [tensor]
136
- result = []
137
- for _tensor in tensor:
138
- _tensor = _tensor.squeeze(0).float().detach().cpu().clamp_(*min_max)
139
- _tensor = (_tensor - min_max[0]) / (min_max[1] - min_max[0])
140
-
141
- n_dim = _tensor.dim()
142
- if n_dim == 4:
143
- img_np = make_grid(_tensor, nrow=int(math.sqrt(_tensor.size(0))), normalize=False).numpy()
144
- img_np = img_np.transpose(1, 2, 0)
145
- if rgb2bgr:
146
- img_np = cv2.cvtColor(img_np, cv2.COLOR_RGB2BGR)
147
- elif n_dim == 3:
148
- img_np = _tensor.numpy()
149
- img_np = img_np.transpose(1, 2, 0)
150
- if img_np.shape[2] == 1: # gray image
151
- img_np = np.squeeze(img_np, axis=2)
152
- else:
153
- if rgb2bgr:
154
- img_np = cv2.cvtColor(img_np, cv2.COLOR_RGB2BGR)
155
- elif n_dim == 2:
156
- img_np = _tensor.numpy()
157
- else:
158
- raise TypeError(f'Only support 4D, 3D or 2D tensor. But received with dimension: {n_dim}')
159
- if out_type == np.uint8:
160
- # Unlike MATLAB, numpy.unit8() WILL NOT round by default.
161
- img_np = (img_np * 255.0).round()
162
- img_np = img_np.astype(out_type)
163
- result.append(img_np)
164
- if len(result) == 1:
165
- result = result[0]
166
- return result
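
utils_qformer.py above mixes seeding/config helpers with the BasicSR-style converters img2tensor (OpenCV BGR uint8 array to float CHW tensor) and tensor2img (clamped float tensor back to a uint8 HWC array in BGR order). A quick round-trip sketch; the import path is assumed and the random image stands in for real data:

    # Round trip through the removed converters (needs numpy and opencv-python).
    import numpy as np
    from models.eva_clip.utils_qformer import img2tensor, tensor2img  # assumed path

    bgr = np.random.randint(0, 256, size=(64, 64, 3), dtype=np.uint8)

    t = img2tensor(bgr, bgr2rgb=True, float32=True)  # float tensor, [3, 64, 64], RGB, values 0..255
    back = tensor2img(t / 255.0, rgb2bgr=True)       # uint8 ndarray, [64, 64, 3], BGR
    print(t.shape, back.shape, back.dtype)
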
models/local_facial_extractor.py DELETED
@@ -1,309 +0,0 @@
1
- import math
2
- import torch
3
- import torch.nn as nn
4
-
5
-
6
- # FFN
7
- def ConsisIDFeedForward(dim, mult=4):
8
- """
9
- Creates a consistent ID feedforward block consisting of layer normalization,
10
- two linear layers, and a GELU activation.
11
-
12
- Args:
13
- dim (int): The input dimension of the tensor.
14
- mult (int, optional): Multiplier for the inner dimension. Default is 4.
15
-
16
- Returns:
17
- nn.Sequential: A sequence of layers comprising LayerNorm, Linear layers, and GELU.
18
- """
19
- inner_dim = int(dim * mult)
20
- return nn.Sequential(
21
- nn.LayerNorm(dim),
22
- nn.Linear(dim, inner_dim, bias=False),
23
- nn.GELU(),
24
- nn.Linear(inner_dim, dim, bias=False),
25
- )
26
-
27
-
28
- def reshape_tensor(x, heads):
29
- """
30
- Reshapes the input tensor for multi-head attention.
31
-
32
- Args:
33
- x (torch.Tensor): The input tensor with shape (batch_size, length, width).
34
- heads (int): The number of attention heads.
35
-
36
- Returns:
37
- torch.Tensor: The reshaped tensor, with shape (batch_size, heads, length, width).
38
- """
39
- bs, length, width = x.shape
40
- x = x.view(bs, length, heads, -1)
41
- x = x.transpose(1, 2)
42
- x = x.reshape(bs, heads, length, -1)
43
- return x
44
-
45
-
46
- class PerceiverAttention(nn.Module):
47
- """
48
- Implements the Perceiver attention mechanism with multi-head attention.
49
-
50
- This layer takes two inputs: 'x' (image features) and 'latents' (latent features),
51
- applying multi-head attention to both and producing an output tensor with the same
52
- dimension as the input tensor 'x'.
53
-
54
- Args:
55
- dim (int): The input dimension.
56
- dim_head (int, optional): The dimension of each attention head. Default is 64.
57
- heads (int, optional): The number of attention heads. Default is 8.
58
- kv_dim (int, optional): The key-value dimension. If None, `dim` is used for both keys and values.
59
- """
60
-
61
- def __init__(self, *, dim, dim_head=64, heads=8, kv_dim=None):
62
- super().__init__()
63
- self.scale = dim_head**-0.5
64
- self.dim_head = dim_head
65
- self.heads = heads
66
- inner_dim = dim_head * heads
67
-
68
- self.norm1 = nn.LayerNorm(dim if kv_dim is None else kv_dim)
69
- self.norm2 = nn.LayerNorm(dim)
70
-
71
- self.to_q = nn.Linear(dim, inner_dim, bias=False)
72
- self.to_kv = nn.Linear(dim if kv_dim is None else kv_dim, inner_dim * 2, bias=False)
73
- self.to_out = nn.Linear(inner_dim, dim, bias=False)
74
-
75
- def forward(self, x, latents):
76
- """
77
- Forward pass for Perceiver attention.
78
-
79
- Args:
80
- x (torch.Tensor): Image features tensor with shape (batch_size, num_pixels, D).
81
- latents (torch.Tensor): Latent features tensor with shape (batch_size, num_latents, D).
82
-
83
- Returns:
84
- torch.Tensor: Output tensor after applying attention and transformation.
85
- """
86
- # Apply normalization
87
- x = self.norm1(x)
88
- latents = self.norm2(latents)
89
-
90
- b, seq_len, _ = latents.shape # Get batch size and sequence length
91
-
92
- # Compute query, key, and value matrices
93
- q = self.to_q(latents)
94
- kv_input = torch.cat((x, latents), dim=-2)
95
- k, v = self.to_kv(kv_input).chunk(2, dim=-1)
96
-
97
- # Reshape the tensors for multi-head attention
98
- q = reshape_tensor(q, self.heads)
99
- k = reshape_tensor(k, self.heads)
100
- v = reshape_tensor(v, self.heads)
101
-
102
- # attention
103
- scale = 1 / math.sqrt(math.sqrt(self.dim_head))
104
- weight = (q * scale) @ (k * scale).transpose(-2, -1) # More stable with f16 than dividing afterwards
105
- weight = torch.softmax(weight.float(), dim=-1).type(weight.dtype)
106
- out = weight @ v
107
-
108
- # Reshape and return the final output
109
- out = out.permute(0, 2, 1, 3).reshape(b, seq_len, -1)
110
-
111
- return self.to_out(out)
112
-
113
-
114
- class LocalFacialExtractor(nn.Module):
115
- def __init__(
116
- self,
117
- dim=1024,
118
- depth=10,
119
- dim_head=64,
120
- heads=16,
121
- num_id_token=5,
122
- num_queries=32,
123
- output_dim=2048,
124
- ff_mult=4,
125
- ):
126
- """
127
- Initializes the LocalFacialExtractor class.
128
-
129
- Parameters:
130
- - dim (int): The dimensionality of latent features.
131
- - depth (int): Total number of PerceiverAttention and ConsisIDFeedForward layers.
132
- - dim_head (int): Dimensionality of each attention head.
133
- - heads (int): Number of attention heads.
134
- - num_id_token (int): Number of tokens used for identity features.
135
- - num_queries (int): Number of query tokens for the latent representation.
136
- - output_dim (int): Output dimension after projection.
137
- - ff_mult (int): Multiplier for the feed-forward network hidden dimension.
138
- """
139
- super().__init__()
140
-
141
- # Storing identity token and query information
142
- self.num_id_token = num_id_token
143
- self.dim = dim
144
- self.num_queries = num_queries
145
- assert depth % 5 == 0
146
- self.depth = depth // 5
147
- scale = dim**-0.5
148
-
149
- # Learnable latent query embeddings
150
- self.latents = nn.Parameter(torch.randn(1, num_queries, dim) * scale)
151
- # Projection layer to map the latent output to the desired dimension
152
- self.proj_out = nn.Parameter(scale * torch.randn(dim, output_dim))
153
-
154
- # Attention and ConsisIDFeedForward layer stack
155
- self.layers = nn.ModuleList([])
156
- for _ in range(depth):
157
- self.layers.append(
158
- nn.ModuleList(
159
- [
160
- PerceiverAttention(dim=dim, dim_head=dim_head, heads=heads), # Perceiver Attention layer
161
- ConsisIDFeedForward(dim=dim, mult=ff_mult), # ConsisIDFeedForward layer
162
- ]
163
- )
164
- )
165
-
166
- # Mappings for each of the 5 different ViT features
167
- for i in range(5):
168
- setattr(
169
- self,
170
- f"mapping_{i}",
171
- nn.Sequential(
172
- nn.Linear(1024, 1024),
173
- nn.LayerNorm(1024),
174
- nn.LeakyReLU(),
175
- nn.Linear(1024, 1024),
176
- nn.LayerNorm(1024),
177
- nn.LeakyReLU(),
178
- nn.Linear(1024, dim),
179
- ),
180
- )
181
-
182
- # Mapping for identity embedding vectors
183
- self.id_embedding_mapping = nn.Sequential(
184
- nn.Linear(1280, 1024),
185
- nn.LayerNorm(1024),
186
- nn.LeakyReLU(),
187
- nn.Linear(1024, 1024),
188
- nn.LayerNorm(1024),
189
- nn.LeakyReLU(),
190
- nn.Linear(1024, dim * num_id_token),
191
- )
192
-
193
- def forward(self, x, y):
194
- """
195
- Forward pass for LocalFacialExtractor.
196
-
197
- Parameters:
198
- - x (Tensor): The input identity embedding tensor of shape (batch_size, 1280).
199
- y (list of Tensor): A list of 5 visual feature tensors, each of shape (batch_size, num_tokens, 1024).
200
-
201
- Returns:
202
- - Tensor: The extracted latent features of shape (batch_size, num_queries, output_dim).
203
- """
204
-
205
- # Repeat latent queries for the batch size
206
- latents = self.latents.repeat(x.size(0), 1, 1)
207
-
208
- # Map the identity embedding to tokens
209
- x = self.id_embedding_mapping(x)
210
- x = x.reshape(-1, self.num_id_token, self.dim)
211
-
212
- # Concatenate identity tokens with the latent queries
213
- latents = torch.cat((latents, x), dim=1)
214
-
215
- # Process each of the 5 visual feature inputs
216
- for i in range(5):
217
- vit_feature = getattr(self, f"mapping_{i}")(y[i])
218
- ctx_feature = torch.cat((x, vit_feature), dim=1)
219
-
220
- # Pass through the PerceiverAttention and ConsisIDFeedForward layers
221
- for attn, ff in self.layers[i * self.depth : (i + 1) * self.depth]:
222
- latents = attn(ctx_feature, latents) + latents
223
- latents = ff(latents) + latents
224
-
225
- # Retain only the query latents
226
- latents = latents[:, : self.num_queries]
227
- # Project the latents to the output dimension
228
- latents = latents @ self.proj_out
229
- return latents
230
-
231
-
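A rough shape walkthrough of `LocalFacialExtractor` may help here (a sketch under assumed input sizes; the ViT token count of 577 and the batch size are purely illustrative and not taken from the commit):

```python
import torch

extractor = LocalFacialExtractor()                    # dim=1024, num_queries=32, output_dim=2048
id_embed = torch.randn(2, 1280)                       # identity embedding -> mapped to 5 id tokens
vit_feats = [torch.randn(2, 577, 1024) for _ in range(5)]  # 5 intermediate ViT feature maps (assumed length)

out = extractor(id_embed, vit_feats)
print(out.shape)                                      # torch.Size([2, 32, 2048])
```

Internally the 32 learnable queries are concatenated with the 5 identity tokens, refined by the attention/feed-forward stack against each ViT feature map in turn, and only the query slots are projected to the 2048-dimensional output.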
232
- class PerceiverCrossAttention(nn.Module):
233
- """
234
- Cross-attention module in which latent tokens attend to externally provided key/value features.
235
- Args:
236
- dim (int): Dimension of the input latent and output. Default is 3072.
237
- dim_head (int): Dimension of each attention head. Default is 128.
238
- heads (int): Number of attention heads. Default is 16.
239
- kv_dim (int): Dimension of the key/value input, allowing flexible cross-attention. Default is 2048.
240
-
241
- Attributes:
242
- scale (float): Scaling factor used in dot-product attention for numerical stability.
243
- norm1 (nn.LayerNorm): Layer normalization applied to the input image features.
244
- norm2 (nn.LayerNorm): Layer normalization applied to the latent features.
245
- to_q (nn.Linear): Linear layer for projecting the latent features into queries.
246
- to_kv (nn.Linear): Linear layer for projecting the input features into keys and values.
247
- to_out (nn.Linear): Linear layer for outputting the final result after attention.
248
-
249
- """
250
-
251
- def __init__(self, *, dim=3072, dim_head=128, heads=16, kv_dim=2048):
252
- super().__init__()
253
- self.scale = dim_head**-0.5
254
- self.dim_head = dim_head
255
- self.heads = heads
256
- inner_dim = dim_head * heads
257
-
258
- # Layer normalization to stabilize training
259
- self.norm1 = nn.LayerNorm(dim if kv_dim is None else kv_dim)
260
- self.norm2 = nn.LayerNorm(dim)
261
-
262
- # Linear transformations to produce queries, keys, and values
263
- self.to_q = nn.Linear(dim, inner_dim, bias=False)
264
- self.to_kv = nn.Linear(dim if kv_dim is None else kv_dim, inner_dim * 2, bias=False)
265
- self.to_out = nn.Linear(inner_dim, dim, bias=False)
266
-
267
- def forward(self, x, latents):
268
- """
269
-
270
- Args:
271
- x (torch.Tensor): Input image features with shape (batch_size, n1, D), where:
272
- - batch_size (b): Number of samples in the batch.
273
- - n1: Sequence length (e.g., number of patches or tokens).
274
- - D: Feature dimension.
275
-
276
- latents (torch.Tensor): Latent feature representations with shape (batch_size, n2, D), where:
277
- - n2: Number of latent elements.
278
-
279
- Returns:
280
- torch.Tensor: Attention-modulated features with shape (batch_size, n2, D).
281
-
282
- """
283
- # Apply layer normalization to the input image and latent features
284
- x = self.norm1(x)
285
- latents = self.norm2(latents)
286
-
287
- b, seq_len, _ = latents.shape
288
-
289
- # Compute queries, keys, and values
290
- q = self.to_q(latents)
291
- k, v = self.to_kv(x).chunk(2, dim=-1)
292
-
293
- # Reshape tensors to split into attention heads
294
- q = reshape_tensor(q, self.heads)
295
- k = reshape_tensor(k, self.heads)
296
- v = reshape_tensor(v, self.heads)
297
-
298
- # Compute attention weights
299
- scale = 1 / math.sqrt(math.sqrt(self.dim_head))
300
- weight = (q * scale) @ (k * scale).transpose(-2, -1) # More stable scaling than post-division
301
- weight = torch.softmax(weight.float(), dim=-1).type(weight.dtype)
302
-
303
- # Compute the output via weighted combination of values
304
- out = weight @ v
305
-
306
- # Reshape and permute to prepare for final linear transformation
307
- out = out.permute(0, 2, 1, 3).reshape(b, seq_len, -1)
308
-
309
- return self.to_out(out)
 
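Similarly, a minimal sketch (illustrative sizes only) of how `PerceiverCrossAttention` lets a latent sequence attend to identity features of a different width:

```python
import torch

xattn = PerceiverCrossAttention(dim=3072, dim_head=128, heads=16, kv_dim=2048)
latents = torch.randn(1, 226, 3072)   # e.g. transformer hidden states (length is illustrative)
id_feats = torch.randn(1, 32, 2048)   # e.g. output of LocalFacialExtractor

out = xattn(id_feats, latents)        # keys/values come from id_feats, queries from latents
print(out.shape)                      # torch.Size([1, 226, 3072])
```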
models/pipeline_cogvideox.py DELETED
@@ -1,748 +0,0 @@
1
- # Copyright 2024 The CogVideoX team, Tsinghua University & ZhipuAI and The HuggingFace Team.
2
- # All rights reserved.
3
- #
4
- # Licensed under the Apache License, Version 2.0 (the "License");
5
- # you may not use this file except in compliance with the License.
6
- # You may obtain a copy of the License at
7
- #
8
- # http://www.apache.org/licenses/LICENSE-2.0
9
- #
10
- # Unless required by applicable law or agreed to in writing, software
11
- # distributed under the License is distributed on an "AS IS" BASIS,
12
- # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
- # See the License for the specific language governing permissions and
14
- # limitations under the License.
15
-
16
- import inspect
17
- import math
18
- from typing import Any, Callable, Dict, List, Optional, Tuple, Union
19
-
20
- import torch
21
- from transformers import T5EncoderModel, T5Tokenizer
22
-
23
- from diffusers.callbacks import MultiPipelineCallbacks, PipelineCallback
24
- from diffusers.loaders import CogVideoXLoraLoaderMixin
25
- from diffusers.models import AutoencoderKLCogVideoX, CogVideoXTransformer3DModel
26
- from diffusers.models.embeddings import get_3d_rotary_pos_embed
27
- from diffusers.pipelines.pipeline_utils import DiffusionPipeline
28
- from diffusers.schedulers import CogVideoXDDIMScheduler, CogVideoXDPMScheduler
29
- from diffusers.utils import logging, replace_example_docstring
30
- from diffusers.utils.torch_utils import randn_tensor
31
- from diffusers.video_processor import VideoProcessor
32
- from diffusers.pipelines.cogvideo.pipeline_output import CogVideoXPipelineOutput
33
-
34
-
35
- logger = logging.get_logger(__name__) # pylint: disable=invalid-name
36
-
37
-
38
- EXAMPLE_DOC_STRING = """
39
- Examples:
40
- ```python
41
- >>> import torch
42
- >>> from diffusers import CogVideoXPipeline
43
- >>> from diffusers.utils import export_to_video
44
-
45
- >>> # Models: "THUDM/CogVideoX-2b" or "THUDM/CogVideoX-5b"
46
- >>> pipe = CogVideoXPipeline.from_pretrained("THUDM/CogVideoX-2b", torch_dtype=torch.float16).to("cuda")
47
- >>> prompt = (
48
- ... "A panda, dressed in a small, red jacket and a tiny hat, sits on a wooden stool in a serene bamboo forest. "
49
- ... "The panda's fluffy paws strum a miniature acoustic guitar, producing soft, melodic tunes. Nearby, a few other "
50
- ... "pandas gather, watching curiously and some clapping in rhythm. Sunlight filters through the tall bamboo, "
51
- ... "casting a gentle glow on the scene. The panda's face is expressive, showing concentration and joy as it plays. "
52
- ... "The background includes a small, flowing stream and vibrant green foliage, enhancing the peaceful and magical "
53
- ... "atmosphere of this unique musical performance."
54
- ... )
55
- >>> video = pipe(prompt=prompt, guidance_scale=6, num_inference_steps=50).frames[0]
56
- >>> export_to_video(video, "output.mp4", fps=8)
57
- ```
58
- """
59
-
60
-
61
- # Similar to diffusers.pipelines.hunyuandit.pipeline_hunyuandit.get_resize_crop_region_for_grid
62
- def get_resize_crop_region_for_grid(src, tgt_width, tgt_height):
63
- tw = tgt_width
64
- th = tgt_height
65
- h, w = src
66
- r = h / w
67
- if r > (th / tw):
68
- resize_height = th
69
- resize_width = int(round(th / h * w))
70
- else:
71
- resize_width = tw
72
- resize_height = int(round(tw / w * h))
73
-
74
- crop_top = int(round((th - resize_height) / 2.0))
75
- crop_left = int(round((tw - resize_width) / 2.0))
76
-
77
- return (crop_top, crop_left), (crop_top + resize_height, crop_left + resize_width)
78
-
79
-
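As a worked example (numbers chosen only for illustration), a 60x80 source grid fitted to a 720x480 target keeps the aspect ratio, pins the height, and centers the crop horizontally:

```python
# h=60, w=80 -> r = 0.75 > 480/720, so height is pinned to 480 and width becomes 640.
top_left, bottom_right = get_resize_crop_region_for_grid((60, 80), 720, 480)
assert top_left == (0, 40)            # (round((480-480)/2), round((720-640)/2))
assert bottom_right == (480, 680)     # (0+480, 40+640)
```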
80
- # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.retrieve_timesteps
81
- def retrieve_timesteps(
82
- scheduler,
83
- num_inference_steps: Optional[int] = None,
84
- device: Optional[Union[str, torch.device]] = None,
85
- timesteps: Optional[List[int]] = None,
86
- sigmas: Optional[List[float]] = None,
87
- **kwargs,
88
- ):
89
- """
90
- Calls the scheduler's `set_timesteps` method and retrieves timesteps from the scheduler after the call. Handles
91
- custom timesteps. Any kwargs will be supplied to `scheduler.set_timesteps`.
92
-
93
- Args:
94
- scheduler (`SchedulerMixin`):
95
- The scheduler to get timesteps from.
96
- num_inference_steps (`int`):
97
- The number of diffusion steps used when generating samples with a pre-trained model. If used, `timesteps`
98
- must be `None`.
99
- device (`str` or `torch.device`, *optional*):
100
- The device to which the timesteps should be moved. If `None`, the timesteps are not moved.
101
- timesteps (`List[int]`, *optional*):
102
- Custom timesteps used to override the timestep spacing strategy of the scheduler. If `timesteps` is passed,
103
- `num_inference_steps` and `sigmas` must be `None`.
104
- sigmas (`List[float]`, *optional*):
105
- Custom sigmas used to override the timestep spacing strategy of the scheduler. If `sigmas` is passed,
106
- `num_inference_steps` and `timesteps` must be `None`.
107
-
108
- Returns:
109
- `Tuple[torch.Tensor, int]`: A tuple where the first element is the timestep schedule from the scheduler and the
110
- second element is the number of inference steps.
111
- """
112
- if timesteps is not None and sigmas is not None:
113
- raise ValueError("Only one of `timesteps` or `sigmas` can be passed. Please choose one to set custom values")
114
- if timesteps is not None:
115
- accepts_timesteps = "timesteps" in set(inspect.signature(scheduler.set_timesteps).parameters.keys())
116
- if not accepts_timesteps:
117
- raise ValueError(
118
- f"The current scheduler class {scheduler.__class__}'s `set_timesteps` does not support custom"
119
- f" timestep schedules. Please check whether you are using the correct scheduler."
120
- )
121
- scheduler.set_timesteps(timesteps=timesteps, device=device, **kwargs)
122
- timesteps = scheduler.timesteps
123
- num_inference_steps = len(timesteps)
124
- elif sigmas is not None:
125
- accept_sigmas = "sigmas" in set(inspect.signature(scheduler.set_timesteps).parameters.keys())
126
- if not accept_sigmas:
127
- raise ValueError(
128
- f"The current scheduler class {scheduler.__class__}'s `set_timesteps` does not support custom"
129
- f" sigmas schedules. Please check whether you are using the correct scheduler."
130
- )
131
- scheduler.set_timesteps(sigmas=sigmas, device=device, **kwargs)
132
- timesteps = scheduler.timesteps
133
- num_inference_steps = len(timesteps)
134
- else:
135
- scheduler.set_timesteps(num_inference_steps, device=device, **kwargs)
136
- timesteps = scheduler.timesteps
137
- return timesteps, num_inference_steps
138
-
139
-
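A minimal usage sketch for `retrieve_timesteps` (the `DDIMScheduler` below is just a stand-in chosen for illustration; any diffusers scheduler exposing `set_timesteps` behaves the same way here):

```python
import torch
from diffusers import DDIMScheduler

scheduler = DDIMScheduler()  # default config, used only to demonstrate the helper
timesteps, num_inference_steps = retrieve_timesteps(scheduler, num_inference_steps=50, device="cpu")

assert num_inference_steps == 50
assert isinstance(timesteps, torch.Tensor) and len(timesteps) == 50
```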
140
- class CogVideoXPipeline(DiffusionPipeline, CogVideoXLoraLoaderMixin):
141
- r"""
142
- Pipeline for text-to-video generation using CogVideoX.
143
-
144
- This model inherits from [`DiffusionPipeline`]. Check the superclass documentation for the generic methods the
145
- library implements for all the pipelines (such as downloading or saving, running on a particular device, etc.)
146
-
147
- Args:
148
- vae ([`AutoencoderKL`]):
149
- Variational Auto-Encoder (VAE) Model to encode and decode videos to and from latent representations.
150
- text_encoder ([`T5EncoderModel`]):
151
- Frozen text-encoder. CogVideoX uses
152
- [T5](https://huggingface.co/docs/transformers/model_doc/t5#transformers.T5EncoderModel); specifically the
153
- [t5-v1_1-xxl](https://huggingface.co/PixArt-alpha/PixArt-alpha/tree/main/t5-v1_1-xxl) variant.
154
- tokenizer (`T5Tokenizer`):
155
- Tokenizer of class
156
- [T5Tokenizer](https://huggingface.co/docs/transformers/model_doc/t5#transformers.T5Tokenizer).
157
- transformer ([`CogVideoXTransformer3DModel`]):
158
- A text conditioned `CogVideoXTransformer3DModel` to denoise the encoded video latents.
159
- scheduler ([`SchedulerMixin`]):
160
- A scheduler to be used in combination with `transformer` to denoise the encoded video latents.
161
- """
162
-
163
- _optional_components = []
164
- model_cpu_offload_seq = "text_encoder->transformer->vae"
165
-
166
- _callback_tensor_inputs = [
167
- "latents",
168
- "prompt_embeds",
169
- "negative_prompt_embeds",
170
- ]
171
-
172
- def __init__(
173
- self,
174
- tokenizer: T5Tokenizer,
175
- text_encoder: T5EncoderModel,
176
- vae: AutoencoderKLCogVideoX,
177
- transformer: CogVideoXTransformer3DModel,
178
- scheduler: Union[CogVideoXDDIMScheduler, CogVideoXDPMScheduler],
179
- ):
180
- super().__init__()
181
-
182
- self.register_modules(
183
- tokenizer=tokenizer, text_encoder=text_encoder, vae=vae, transformer=transformer, scheduler=scheduler
184
- )
185
- self.vae_scale_factor_spatial = (
186
- 2 ** (len(self.vae.config.block_out_channels) - 1) if hasattr(self, "vae") and self.vae is not None else 8
187
- )
188
- self.vae_scale_factor_temporal = (
189
- self.vae.config.temporal_compression_ratio if hasattr(self, "vae") and self.vae is not None else 4
190
- )
191
- self.vae_scaling_factor_image = (
192
- self.vae.config.scaling_factor if hasattr(self, "vae") and self.vae is not None else 0.7
193
- )
194
-
195
- self.video_processor = VideoProcessor(vae_scale_factor=self.vae_scale_factor_spatial)
196
-
197
- def _get_t5_prompt_embeds(
198
- self,
199
- prompt: Union[str, List[str]] = None,
200
- num_videos_per_prompt: int = 1,
201
- max_sequence_length: int = 226,
202
- device: Optional[torch.device] = None,
203
- dtype: Optional[torch.dtype] = None,
204
- ):
205
- device = device or self._execution_device
206
- dtype = dtype or self.text_encoder.dtype
207
-
208
- prompt = [prompt] if isinstance(prompt, str) else prompt
209
- batch_size = len(prompt)
210
-
211
- text_inputs = self.tokenizer(
212
- prompt,
213
- padding="max_length",
214
- max_length=max_sequence_length,
215
- truncation=True,
216
- add_special_tokens=True,
217
- return_tensors="pt",
218
- )
219
- text_input_ids = text_inputs.input_ids
220
- untruncated_ids = self.tokenizer(prompt, padding="longest", return_tensors="pt").input_ids
221
-
222
- if untruncated_ids.shape[-1] >= text_input_ids.shape[-1] and not torch.equal(text_input_ids, untruncated_ids):
223
- removed_text = self.tokenizer.batch_decode(untruncated_ids[:, max_sequence_length - 1 : -1])
224
- logger.warning(
225
- "The following part of your input was truncated because `max_sequence_length` is set to "
226
- f" {max_sequence_length} tokens: {removed_text}"
227
- )
228
-
229
- prompt_embeds = self.text_encoder(text_input_ids.to(device))[0]
230
- prompt_embeds = prompt_embeds.to(dtype=dtype, device=device)
231
-
232
- # duplicate text embeddings for each generation per prompt, using mps friendly method
233
- _, seq_len, _ = prompt_embeds.shape
234
- prompt_embeds = prompt_embeds.repeat(1, num_videos_per_prompt, 1)
235
- prompt_embeds = prompt_embeds.view(batch_size * num_videos_per_prompt, seq_len, -1)
236
-
237
- return prompt_embeds
238
-
239
- def encode_prompt(
240
- self,
241
- prompt: Union[str, List[str]],
242
- negative_prompt: Optional[Union[str, List[str]]] = None,
243
- do_classifier_free_guidance: bool = True,
244
- num_videos_per_prompt: int = 1,
245
- prompt_embeds: Optional[torch.Tensor] = None,
246
- negative_prompt_embeds: Optional[torch.Tensor] = None,
247
- max_sequence_length: int = 226,
248
- device: Optional[torch.device] = None,
249
- dtype: Optional[torch.dtype] = None,
250
- ):
251
- r"""
252
- Encodes the prompt into text encoder hidden states.
253
-
254
- Args:
255
- prompt (`str` or `List[str]`, *optional*):
256
- prompt to be encoded
257
- negative_prompt (`str` or `List[str]`, *optional*):
258
- The prompt or prompts not to guide the image generation. If not defined, one has to pass
259
- `negative_prompt_embeds` instead. Ignored when not using guidance (i.e., ignored if `guidance_scale` is
260
- less than `1`).
261
- do_classifier_free_guidance (`bool`, *optional*, defaults to `True`):
262
- Whether to use classifier free guidance or not.
263
- num_videos_per_prompt (`int`, *optional*, defaults to 1):
264
- Number of videos that should be generated per prompt.
265
- prompt_embeds (`torch.Tensor`, *optional*):
266
- Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not
267
- provided, text embeddings will be generated from `prompt` input argument.
268
- negative_prompt_embeds (`torch.Tensor`, *optional*):
269
- Pre-generated negative text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt
270
- weighting. If not provided, negative_prompt_embeds will be generated from `negative_prompt` input
271
- argument.
272
- device: (`torch.device`, *optional*):
273
- torch device
274
- dtype: (`torch.dtype`, *optional*):
275
- torch dtype
276
- """
277
- device = device or self._execution_device
278
-
279
- prompt = [prompt] if isinstance(prompt, str) else prompt
280
- if prompt is not None:
281
- batch_size = len(prompt)
282
- else:
283
- batch_size = prompt_embeds.shape[0]
284
-
285
- if prompt_embeds is None:
286
- prompt_embeds = self._get_t5_prompt_embeds(
287
- prompt=prompt,
288
- num_videos_per_prompt=num_videos_per_prompt,
289
- max_sequence_length=max_sequence_length,
290
- device=device,
291
- dtype=dtype,
292
- )
293
-
294
- if do_classifier_free_guidance and negative_prompt_embeds is None:
295
- negative_prompt = negative_prompt or ""
296
- negative_prompt = batch_size * [negative_prompt] if isinstance(negative_prompt, str) else negative_prompt
297
-
298
- if prompt is not None and type(prompt) is not type(negative_prompt):
299
- raise TypeError(
300
- f"`negative_prompt` should be the same type as `prompt`, but got {type(negative_prompt)} !="
301
- f" {type(prompt)}."
302
- )
303
- elif batch_size != len(negative_prompt):
304
- raise ValueError(
305
- f"`negative_prompt`: {negative_prompt} has batch size {len(negative_prompt)}, but `prompt`:"
306
- f" {prompt} has batch size {batch_size}. Please make sure that passed `negative_prompt` matches"
307
- " the batch size of `prompt`."
308
- )
309
-
310
- negative_prompt_embeds = self._get_t5_prompt_embeds(
311
- prompt=negative_prompt,
312
- num_videos_per_prompt=num_videos_per_prompt,
313
- max_sequence_length=max_sequence_length,
314
- device=device,
315
- dtype=dtype,
316
- )
317
-
318
- return prompt_embeds, negative_prompt_embeds
319
-
320
- def prepare_latents(
321
- self, batch_size, num_channels_latents, num_frames, height, width, dtype, device, generator, latents=None
322
- ):
323
- if isinstance(generator, list) and len(generator) != batch_size:
324
- raise ValueError(
325
- f"You have passed a list of generators of length {len(generator)}, but requested an effective batch"
326
- f" size of {batch_size}. Make sure the batch size matches the length of the generators."
327
- )
328
-
329
- shape = (
330
- batch_size,
331
- (num_frames - 1) // self.vae_scale_factor_temporal + 1,
332
- num_channels_latents,
333
- height // self.vae_scale_factor_spatial,
334
- width // self.vae_scale_factor_spatial,
335
- )
336
-
337
- if latents is None:
338
- latents = randn_tensor(shape, generator=generator, device=device, dtype=dtype)
339
- else:
340
- latents = latents.to(device)
341
-
342
- # scale the initial noise by the standard deviation required by the scheduler
343
- latents = latents * self.scheduler.init_noise_sigma
344
- return latents
345
-
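To make the latent shape computed in `prepare_latents` concrete, here is a small arithmetic sketch using the compression factors referenced earlier in this pipeline (temporal 4, spatial 8); the 16 latent channels are an assumption about the transformer config, included only for illustration:

```python
# Assumed defaults: 49 frames at 480x720, temporal compression 4, spatial compression 8,
# 16 latent channels (illustrative of the CogVideoX configuration).
batch_size, num_frames, height, width = 1, 49, 480, 720
vae_scale_factor_temporal, vae_scale_factor_spatial, latent_channels = 4, 8, 16

shape = (
    batch_size,
    (num_frames - 1) // vae_scale_factor_temporal + 1,  # 13 latent frames
    latent_channels,
    height // vae_scale_factor_spatial,                 # 60
    width // vae_scale_factor_spatial,                  # 90
)
assert shape == (1, 13, 16, 60, 90)
```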
346
- def decode_latents(self, latents: torch.Tensor) -> torch.Tensor:
347
- latents = latents.permute(0, 2, 1, 3, 4) # [batch_size, num_channels, num_frames, height, width]
348
- latents = 1 / self.vae_scaling_factor_image * latents
349
-
350
- frames = self.vae.decode(latents).sample
351
- return frames
352
-
353
- # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.prepare_extra_step_kwargs
354
- def prepare_extra_step_kwargs(self, generator, eta):
355
- # prepare extra kwargs for the scheduler step, since not all schedulers have the same signature
356
- # eta (η) is only used with the DDIMScheduler, it will be ignored for other schedulers.
357
- # eta corresponds to η in DDIM paper: https://arxiv.org/abs/2010.02502
358
- # and should be between [0, 1]
359
-
360
- accepts_eta = "eta" in set(inspect.signature(self.scheduler.step).parameters.keys())
361
- extra_step_kwargs = {}
362
- if accepts_eta:
363
- extra_step_kwargs["eta"] = eta
364
-
365
- # check if the scheduler accepts generator
366
- accepts_generator = "generator" in set(inspect.signature(self.scheduler.step).parameters.keys())
367
- if accepts_generator:
368
- extra_step_kwargs["generator"] = generator
369
- return extra_step_kwargs
370
-
371
- # Copied from diffusers.pipelines.latte.pipeline_latte.LattePipeline.check_inputs
372
- def check_inputs(
373
- self,
374
- prompt,
375
- height,
376
- width,
377
- negative_prompt,
378
- callback_on_step_end_tensor_inputs,
379
- prompt_embeds=None,
380
- negative_prompt_embeds=None,
381
- ):
382
- if height % 8 != 0 or width % 8 != 0:
383
- raise ValueError(f"`height` and `width` have to be divisible by 8 but are {height} and {width}.")
384
-
385
- if callback_on_step_end_tensor_inputs is not None and not all(
386
- k in self._callback_tensor_inputs for k in callback_on_step_end_tensor_inputs
387
- ):
388
- raise ValueError(
389
- f"`callback_on_step_end_tensor_inputs` has to be in {self._callback_tensor_inputs}, but found {[k for k in callback_on_step_end_tensor_inputs if k not in self._callback_tensor_inputs]}"
390
- )
391
- if prompt is not None and prompt_embeds is not None:
392
- raise ValueError(
393
- f"Cannot forward both `prompt`: {prompt} and `prompt_embeds`: {prompt_embeds}. Please make sure to"
394
- " only forward one of the two."
395
- )
396
- elif prompt is None and prompt_embeds is None:
397
- raise ValueError(
398
- "Provide either `prompt` or `prompt_embeds`. Cannot leave both `prompt` and `prompt_embeds` undefined."
399
- )
400
- elif prompt is not None and (not isinstance(prompt, str) and not isinstance(prompt, list)):
401
- raise ValueError(f"`prompt` has to be of type `str` or `list` but is {type(prompt)}")
402
-
403
- if prompt is not None and negative_prompt_embeds is not None:
404
- raise ValueError(
405
- f"Cannot forward both `prompt`: {prompt} and `negative_prompt_embeds`:"
406
- f" {negative_prompt_embeds}. Please make sure to only forward one of the two."
407
- )
408
-
409
- if negative_prompt is not None and negative_prompt_embeds is not None:
410
- raise ValueError(
411
- f"Cannot forward both `negative_prompt`: {negative_prompt} and `negative_prompt_embeds`:"
412
- f" {negative_prompt_embeds}. Please make sure to only forward one of the two."
413
- )
414
-
415
- if prompt_embeds is not None and negative_prompt_embeds is not None:
416
- if prompt_embeds.shape != negative_prompt_embeds.shape:
417
- raise ValueError(
418
- "`prompt_embeds` and `negative_prompt_embeds` must have the same shape when passed directly, but"
419
- f" got: `prompt_embeds` {prompt_embeds.shape} != `negative_prompt_embeds`"
420
- f" {negative_prompt_embeds.shape}."
421
- )
422
-
423
- def fuse_qkv_projections(self) -> None:
424
- r"""Enables fused QKV projections."""
425
- self.fusing_transformer = True
426
- self.transformer.fuse_qkv_projections()
427
-
428
- def unfuse_qkv_projections(self) -> None:
429
- r"""Disable QKV projection fusion if enabled."""
430
- if not self.fusing_transformer:
431
- logger.warning("The Transformer was not initially fused for QKV projections. Doing nothing.")
432
- else:
433
- self.transformer.unfuse_qkv_projections()
434
- self.fusing_transformer = False
435
-
436
- def _prepare_rotary_positional_embeddings(
437
- self,
438
- height: int,
439
- width: int,
440
- num_frames: int,
441
- device: torch.device,
442
- ) -> Tuple[torch.Tensor, torch.Tensor]:
443
- grid_height = height // (self.vae_scale_factor_spatial * self.transformer.config.patch_size)
444
- grid_width = width // (self.vae_scale_factor_spatial * self.transformer.config.patch_size)
445
- base_size_width = 720 // (self.vae_scale_factor_spatial * self.transformer.config.patch_size)
446
- base_size_height = 480 // (self.vae_scale_factor_spatial * self.transformer.config.patch_size)
447
-
448
- grid_crops_coords = get_resize_crop_region_for_grid(
449
- (grid_height, grid_width), base_size_width, base_size_height
450
- )
451
- freqs_cos, freqs_sin = get_3d_rotary_pos_embed(
452
- embed_dim=self.transformer.config.attention_head_dim,
453
- crops_coords=grid_crops_coords,
454
- grid_size=(grid_height, grid_width),
455
- temporal_size=num_frames,
456
- )
457
-
458
- freqs_cos = freqs_cos.to(device=device)
459
- freqs_sin = freqs_sin.to(device=device)
460
- return freqs_cos, freqs_sin
461
-
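For orientation, the grid sizes handed to `get_3d_rotary_pos_embed` at the default 480x720 resolution work out as below. The `patch_size=2` value is an assumption about the transformer config (it matches the public CogVideoX checkpoints) and is used here purely for illustration:

```python
vae_scale_factor_spatial, patch_size = 8, 2  # patch_size assumed, see note above

grid_height = 480 // (vae_scale_factor_spatial * patch_size)        # 30
grid_width = 720 // (vae_scale_factor_spatial * patch_size)         # 45
base_size_width = 720 // (vae_scale_factor_spatial * patch_size)    # 45
base_size_height = 480 // (vae_scale_factor_spatial * patch_size)   # 30

assert (grid_height, grid_width) == (30, 45)
```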
462
- @property
463
- def guidance_scale(self):
464
- return self._guidance_scale
465
-
466
- @property
467
- def num_timesteps(self):
468
- return self._num_timesteps
469
-
470
- @property
471
- def attention_kwargs(self):
472
- return self._attention_kwargs
473
-
474
- @property
475
- def interrupt(self):
476
- return self._interrupt
477
-
478
- @torch.no_grad()
479
- @replace_example_docstring(EXAMPLE_DOC_STRING)
480
- def __call__(
481
- self,
482
- prompt: Optional[Union[str, List[str]]] = None,
483
- negative_prompt: Optional[Union[str, List[str]]] = None,
484
- height: int = 480,
485
- width: int = 720,
486
- num_frames: int = 49,
487
- num_inference_steps: int = 50,
488
- timesteps: Optional[List[int]] = None,
489
- guidance_scale: float = 6,
490
- use_dynamic_cfg: bool = False,
491
- num_videos_per_prompt: int = 1,
492
- eta: float = 0.0,
493
- generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,
494
- latents: Optional[torch.FloatTensor] = None,
495
- prompt_embeds: Optional[torch.FloatTensor] = None,
496
- negative_prompt_embeds: Optional[torch.FloatTensor] = None,
497
- output_type: str = "pil",
498
- return_dict: bool = True,
499
- attention_kwargs: Optional[Dict[str, Any]] = None,
500
- callback_on_step_end: Optional[
501
- Union[Callable[[int, int, Dict], None], PipelineCallback, MultiPipelineCallbacks]
502
- ] = None,
503
- callback_on_step_end_tensor_inputs: List[str] = ["latents"],
504
- max_sequence_length: int = 226,
505
- id_vit_hidden: Optional[torch.Tensor] = None,
506
- id_cond: Optional[torch.Tensor] = None,
507
- ) -> Union[CogVideoXPipelineOutput, Tuple]:
508
- """
509
- Function invoked when calling the pipeline for generation.
510
-
511
- Args:
512
- prompt (`str` or `List[str]`, *optional*):
513
- The prompt or prompts to guide the image generation. If not defined, one has to pass `prompt_embeds`
514
- instead.
515
- negative_prompt (`str` or `List[str]`, *optional*):
516
- The prompt or prompts not to guide the image generation. If not defined, one has to pass
517
- `negative_prompt_embeds` instead. Ignored when not using guidance (i.e., ignored if `guidance_scale` is
518
- less than `1`).
519
- height (`int`, *optional*, defaults to self.transformer.config.sample_height * self.vae_scale_factor_spatial):
520
- The height in pixels of the generated image. This is set to 480 by default for the best results.
521
- width (`int`, *optional*, defaults to self.transformer.config.sample_width * self.vae_scale_factor_spatial):
522
- The width in pixels of the generated image. This is set to 720 by default for the best results.
523
- num_frames (`int`, defaults to `49`):
524
- Number of frames to generate. Must be divisible by self.vae_scale_factor_temporal. Generated video will
525
- contain 1 extra frame because CogVideoX is conditioned with (num_seconds * fps + 1) frames where
526
- num_seconds is 6 and fps is 4. However, since videos can be saved at any fps, the only condition that
527
- needs to be satisfied is that of divisibility mentioned above.
528
- num_inference_steps (`int`, *optional*, defaults to 50):
529
- The number of denoising steps. More denoising steps usually lead to a higher quality image at the
530
- expense of slower inference.
531
- timesteps (`List[int]`, *optional*):
532
- Custom timesteps to use for the denoising process with schedulers which support a `timesteps` argument
533
- in their `set_timesteps` method. If not defined, the default behavior when `num_inference_steps` is
534
- passed will be used. Must be in descending order.
535
- guidance_scale (`float`, *optional*, defaults to 6):
536
- Guidance scale as defined in [Classifier-Free Diffusion Guidance](https://arxiv.org/abs/2207.12598).
537
- `guidance_scale` is defined as `w` of equation 2. of [Imagen
538
- Paper](https://arxiv.org/pdf/2205.11487.pdf). Guidance scale is enabled by setting `guidance_scale >
539
- 1`. Higher guidance scale encourages the model to generate images that are closely linked to the text `prompt`,
540
- usually at the expense of lower image quality.
541
- num_videos_per_prompt (`int`, *optional*, defaults to 1):
542
- The number of videos to generate per prompt.
543
- generator (`torch.Generator` or `List[torch.Generator]`, *optional*):
544
- One or a list of [torch generator(s)](https://pytorch.org/docs/stable/generated/torch.Generator.html)
545
- to make generation deterministic.
546
- latents (`torch.FloatTensor`, *optional*):
547
- Pre-generated noisy latents, sampled from a Gaussian distribution, to be used as inputs for image
548
- generation. Can be used to tweak the same generation with different prompts. If not provided, a latents
549
- tensor will be generated by sampling using the supplied random `generator`.
550
- prompt_embeds (`torch.FloatTensor`, *optional*):
551
- Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not
552
- provided, text embeddings will be generated from `prompt` input argument.
553
- negative_prompt_embeds (`torch.FloatTensor`, *optional*):
554
- Pre-generated negative text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt
555
- weighting. If not provided, negative_prompt_embeds will be generated from `negative_prompt` input
556
- argument.
557
- output_type (`str`, *optional*, defaults to `"pil"`):
558
- The output format of the generated image. Choose between
559
- [PIL](https://pillow.readthedocs.io/en/stable/): `PIL.Image.Image` or `np.array`.
560
- return_dict (`bool`, *optional*, defaults to `True`):
561
- Whether or not to return a [`~pipelines.cogvideo.pipeline_cogvideox.CogVideoXPipelineOutput`] instead
562
- of a plain tuple.
563
- attention_kwargs (`dict`, *optional*):
564
- A kwargs dictionary that if specified is passed along to the `AttentionProcessor` as defined under
565
- `self.processor` in
566
- [diffusers.models.attention_processor](https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/attention_processor.py).
567
- callback_on_step_end (`Callable`, *optional*):
568
- A function called at the end of each denoising step during inference. The function is called
569
- with the following arguments: `callback_on_step_end(self: DiffusionPipeline, step: int, timestep: int,
570
- callback_kwargs: Dict)`. `callback_kwargs` will include a list of all tensors as specified by
571
- `callback_on_step_end_tensor_inputs`.
572
- callback_on_step_end_tensor_inputs (`List`, *optional*):
573
- The list of tensor inputs for the `callback_on_step_end` function. The tensors specified in the list
574
- will be passed as `callback_kwargs` argument. You will only be able to include variables listed in the
575
- `._callback_tensor_inputs` attribute of your pipeline class.
576
- max_sequence_length (`int`, defaults to `226`):
577
- Maximum sequence length in encoded prompt. Must be consistent with
578
- `self.transformer.config.max_text_seq_length` otherwise may lead to poor results.
579
-
580
- Examples:
581
-
582
- Returns:
583
- [`~pipelines.cogvideo.pipeline_cogvideox.CogVideoXPipelineOutput`] or `tuple`:
584
- [`~pipelines.cogvideo.pipeline_cogvideox.CogVideoXPipelineOutput`] if `return_dict` is True, otherwise a
585
- `tuple`. When returning a tuple, the first element is a list with the generated images.
586
- """
587
-
588
- if num_frames > 49:
589
- raise ValueError(
590
- "The number of frames must not exceed 49 for now due to static positional embeddings. This will be updated in the future to remove this limitation."
591
- )
592
-
593
- if isinstance(callback_on_step_end, (PipelineCallback, MultiPipelineCallbacks)):
594
- callback_on_step_end_tensor_inputs = callback_on_step_end.tensor_inputs
595
-
596
- num_videos_per_prompt = 1
597
-
598
- # 1. Check inputs. Raise error if not correct
599
- self.check_inputs(
600
- prompt,
601
- height,
602
- width,
603
- negative_prompt,
604
- callback_on_step_end_tensor_inputs,
605
- prompt_embeds,
606
- negative_prompt_embeds,
607
- )
608
- self._guidance_scale = guidance_scale
609
- self._attention_kwargs = attention_kwargs
610
- self._interrupt = False
611
-
612
- # 2. Default call parameters
613
- if prompt is not None and isinstance(prompt, str):
614
- batch_size = 1
615
- elif prompt is not None and isinstance(prompt, list):
616
- batch_size = len(prompt)
617
- else:
618
- batch_size = prompt_embeds.shape[0]
619
-
620
- device = self._execution_device
621
-
622
- # here `guidance_scale` is defined analog to the guidance weight `w` of equation (2)
623
- # of the Imagen paper: https://arxiv.org/pdf/2205.11487.pdf . `guidance_scale = 1`
624
- # corresponds to doing no classifier free guidance.
625
- do_classifier_free_guidance = guidance_scale > 1.0
626
-
627
- # 3. Encode input prompt
628
- prompt_embeds, negative_prompt_embeds = self.encode_prompt(
629
- prompt,
630
- negative_prompt,
631
- do_classifier_free_guidance,
632
- num_videos_per_prompt=num_videos_per_prompt,
633
- prompt_embeds=prompt_embeds,
634
- negative_prompt_embeds=negative_prompt_embeds,
635
- max_sequence_length=max_sequence_length,
636
- device=device,
637
- )
638
- if do_classifier_free_guidance:
639
- prompt_embeds = torch.cat([negative_prompt_embeds, prompt_embeds], dim=0)
640
-
641
- # 4. Prepare timesteps
642
- timesteps, num_inference_steps = retrieve_timesteps(self.scheduler, num_inference_steps, device, timesteps)
643
- self._num_timesteps = len(timesteps)
644
-
645
- # 5. Prepare latents.
646
- latent_channels = self.transformer.config.in_channels
647
- latents = self.prepare_latents(
648
- batch_size * num_videos_per_prompt,
649
- latent_channels,
650
- num_frames,
651
- height,
652
- width,
653
- prompt_embeds.dtype,
654
- device,
655
- generator,
656
- latents,
657
- )
658
-
659
- # 6. Prepare extra step kwargs. TODO: Logic should ideally just be moved out of the pipeline
660
- extra_step_kwargs = self.prepare_extra_step_kwargs(generator, eta)
661
-
662
- # 7. Create rotary embeds if required
663
- image_rotary_emb = (
664
- self._prepare_rotary_positional_embeddings(height, width, latents.size(1), device)
665
- if self.transformer.config.use_rotary_positional_embeddings
666
- else None
667
- )
668
-
669
- # 8. Denoising loop
670
- num_warmup_steps = max(len(timesteps) - num_inference_steps * self.scheduler.order, 0)
671
-
672
- with self.progress_bar(total=num_inference_steps) as progress_bar:
673
- # for DPM-solver++
674
- old_pred_original_sample = None
675
- for i, t in enumerate(timesteps):
676
- if self.interrupt:
677
- continue
678
-
679
- latent_model_input = torch.cat([latents] * 2) if do_classifier_free_guidance else latents
680
- latent_model_input = self.scheduler.scale_model_input(latent_model_input, t)
681
-
682
- # broadcast to batch dimension in a way that's compatible with ONNX/Core ML
683
- timestep = t.expand(latent_model_input.shape[0])
684
-
685
- # predict noise model_output
686
- noise_pred = self.transformer(
687
- hidden_states=latent_model_input,
688
- encoder_hidden_states=prompt_embeds,
689
- timestep=timestep,
690
- image_rotary_emb=image_rotary_emb,
691
- attention_kwargs=attention_kwargs,
692
- return_dict=False,
693
- id_vit_hidden=id_vit_hidden,
694
- id_cond=id_cond,
695
- )[0]
696
- noise_pred = noise_pred.float()
697
-
698
- # perform guidance
699
- if use_dynamic_cfg:
700
- self._guidance_scale = 1 + guidance_scale * (
701
- (1 - math.cos(math.pi * ((num_inference_steps - t.item()) / num_inference_steps) ** 5.0)) / 2
702
- )
703
- if do_classifier_free_guidance:
704
- noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)
705
- noise_pred = noise_pred_uncond + self.guidance_scale * (noise_pred_text - noise_pred_uncond)
706
-
707
- # compute the previous noisy sample x_t -> x_t-1
708
- if not isinstance(self.scheduler, CogVideoXDPMScheduler):
709
- latents = self.scheduler.step(noise_pred, t, latents, **extra_step_kwargs, return_dict=False)[0]
710
- else:
711
- latents, old_pred_original_sample = self.scheduler.step(
712
- noise_pred,
713
- old_pred_original_sample,
714
- t,
715
- timesteps[i - 1] if i > 0 else None,
716
- latents,
717
- **extra_step_kwargs,
718
- return_dict=False,
719
- )
720
- latents = latents.to(prompt_embeds.dtype)
721
-
722
- # call the callback, if provided
723
- if callback_on_step_end is not None:
724
- callback_kwargs = {}
725
- for k in callback_on_step_end_tensor_inputs:
726
- callback_kwargs[k] = locals()[k]
727
- callback_outputs = callback_on_step_end(self, i, t, callback_kwargs)
728
-
729
- latents = callback_outputs.pop("latents", latents)
730
- prompt_embeds = callback_outputs.pop("prompt_embeds", prompt_embeds)
731
- negative_prompt_embeds = callback_outputs.pop("negative_prompt_embeds", negative_prompt_embeds)
732
-
733
- if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0):
734
- progress_bar.update()
735
-
736
- if not output_type == "latent":
737
- video = self.decode_latents(latents)
738
- video = self.video_processor.postprocess_video(video=video, output_type=output_type)
739
- else:
740
- video = latents
741
-
742
- # Offload all models
743
- self.maybe_free_model_hooks()
744
-
745
- if not return_dict:
746
- return (video,)
747
-
748
- return CogVideoXPipelineOutput(frames=video)
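The dynamic CFG branch in the denoising loop above follows a cosine ramp. A small standalone sketch of the schedule it computes is shown here; the helper name is hypothetical, and the normalization by `num_inference_steps` simply mirrors the in-loop expression for illustration:

```python
import math

def dynamic_guidance(guidance_scale: float, t: float, num_inference_steps: int) -> float:
    # Mirrors the in-loop expression: 1 + w * (1 - cos(pi * p**5)) / 2, with p = (N - t) / N.
    p = (num_inference_steps - t) / num_inference_steps
    return 1 + guidance_scale * (1 - math.cos(math.pi * p**5.0)) / 2

# Illustration: the effective scale grows from 1 toward 1 + w as p goes from 0 to 1.
print(dynamic_guidance(6.0, t=50, num_inference_steps=50))  # 1.0 (start of the ramp)
print(dynamic_guidance(6.0, t=0, num_inference_steps=50))   # 7.0 (full strength)
```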
 
models/pipeline_consisid.py CHANGED
@@ -16,31 +16,26 @@ import inspect
16
  import math
17
  from typing import Callable, Dict, List, Optional, Tuple, Union
18
 
19
- import os
20
- import sys
21
- import PIL
22
- import numpy as np
23
  import cv2
 
 
24
  import torch
25
- from dataclasses import dataclass
26
  from transformers import T5EncoderModel, T5Tokenizer
27
 
28
  from diffusers.callbacks import MultiPipelineCallbacks, PipelineCallback
29
  from diffusers.image_processor import PipelineImageInput
30
- from diffusers.models import AutoencoderKLCogVideoX
31
  from diffusers.models.embeddings import get_3d_rotary_pos_embed
 
32
  from diffusers.pipelines.pipeline_utils import DiffusionPipeline
33
  from diffusers.schedulers import CogVideoXDDIMScheduler, CogVideoXDPMScheduler
34
- from diffusers.utils import logging, replace_example_docstring, BaseOutput
 
 
 
35
  from diffusers.utils.torch_utils import randn_tensor
36
  from diffusers.video_processor import VideoProcessor
37
 
38
- from models.transformer_consisid import ConsisIDTransformer3DModel
39
-
40
- current_file_path = os.path.abspath(__file__)
41
- project_roots = [os.path.dirname(os.path.dirname(current_file_path))]
42
- for project_root in project_roots:
43
- sys.path.insert(0, project_root) if project_root not in sys.path else None
44
 
45
  logger = logging.get_logger(__name__) # pylint: disable=invalid-name
46
 
@@ -50,22 +45,64 @@ EXAMPLE_DOC_STRING = """
50
  ```py
51
  >>> import torch
52
  >>> from diffusers import ConsisIDPipeline
53
- >>> from diffusers.utils import export_to_video, load_image
 
 
 
 
54
 
55
- >>> pipe = ConsisIDPipeline.from_pretrained("https://huggingface.co/BestWishYsh/ConsisID-preview", torch_dtype=torch.bfloat16)
 
 
 
56
  >>> pipe.to("cuda")
57
 
58
  >>> prompt = "A woman adorned with a delicate flower crown, is standing amidst a field of gently swaying wildflowers. Her eyes sparkle with a serene gaze, and a faint smile graces her lips, suggesting a moment of peaceful contentment. The shot is framed from the waist up, highlighting the gentle breeze lightly tousling her hair. The background reveals an expansive meadow under a bright blue sky, capturing the tranquility of a sunny afternoon."
59
- >>> image = load_image(
60
- ... "https://github.com/PKU-YuanGroup/ConsisID/blob/main/asserts/example_images/1.png?raw=true"
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
61
  ... )
62
- >>> video = pipe(image, prompt, use_dynamic_cfg=True)
63
  >>> export_to_video(video.frames[0], "output.mp4", fps=8)
64
  ```
65
  """
66
 
67
 
68
  def draw_kps(image_pil, kps, color_list=[(255, 0, 0), (0, 255, 0), (0, 0, 255), (255, 255, 0), (255, 0, 255)]):
 
 
 
 
 
 
 
 
 
 
 
 
 
69
  stickwidth = 4
70
  limbSeq = np.array([[0, 2], [1, 2], [3, 2], [4, 2]])
71
  kps = np.array(kps)
@@ -96,17 +133,23 @@ def draw_kps(image_pil, kps, color_list=[(255, 0, 0), (0, 255, 0), (0, 0, 255),
96
  return out_img_pil
97
 
98
 
99
- def process_image(image, vae):
100
- image_noise_sigma = torch.normal(mean=-3.0, std=0.5, size=(1,), device=image.device)
101
- image_noise_sigma = torch.exp(image_noise_sigma).to(dtype=image.dtype)
102
- noisy_image = torch.randn_like(image) * image_noise_sigma[:, None, None, None, None]
103
- input_image = image + noisy_image
104
- image_latent_dist = vae.encode(input_image).latent_dist
105
- return image_latent_dist
106
-
107
-
108
  # Similar to diffusers.pipelines.hunyuandit.pipeline_hunyuandit.get_resize_crop_region_for_grid
109
  def get_resize_crop_region_for_grid(src, tgt_width, tgt_height):
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
110
  tw = tgt_width
111
  th = tgt_height
112
  h, w = src
@@ -133,7 +176,7 @@ def retrieve_timesteps(
133
  sigmas: Optional[List[float]] = None,
134
  **kwargs,
135
  ):
136
- """
137
  Calls the scheduler's `set_timesteps` method and retrieves timesteps from the scheduler after the call. Handles
138
  custom timesteps. Any kwargs will be supplied to `scheduler.set_timesteps`.
139
 
@@ -198,21 +241,6 @@ def retrieve_latents(
198
  raise AttributeError("Could not access latents of provided encoder_output")
199
 
200
 
201
- @dataclass
202
- class ConsisIDPipelineOutput(BaseOutput):
203
- r"""
204
- Output class for ConsisID pipelines.
205
-
206
- Args:
207
- frames (`torch.Tensor`, `np.ndarray`, or List[List[PIL.Image.Image]]):
208
- List of video outputs - It can be a nested list of length `batch_size,` with each sub-list containing
209
- denoised PIL image sequences of length `num_frames.` It can also be a NumPy array or Torch tensor of shape
210
- `(batch_size, num_frames, channels, height, width)`.
211
- """
212
-
213
- frames: torch.Tensor
214
-
215
-
216
  class ConsisIDPipeline(DiffusionPipeline):
217
  r"""
218
  Pipeline for image-to-video generation using ConsisID.
@@ -274,7 +302,7 @@ class ConsisIDPipeline(DiffusionPipeline):
274
 
275
  self.video_processor = VideoProcessor(vae_scale_factor=self.vae_scale_factor_spatial)
276
 
277
- # Copied from diffusers.pipelines.consisid.pipeline_consisID.ConsisIDPipeline._get_t5_prompt_embeds
278
  def _get_t5_prompt_embeds(
279
  self,
280
  prompt: Union[str, List[str]] = None,
@@ -317,7 +345,7 @@ class ConsisIDPipeline(DiffusionPipeline):
317
 
318
  return prompt_embeds
319
 
320
- # Copied from diffusers.pipelines.consisid.pipeline_consisid.ConsisIDPipeline.encode_prompt
321
  def encode_prompt(
322
  self,
323
  prompt: Union[str, List[str]],
@@ -484,7 +512,7 @@ class ConsisIDPipeline(DiffusionPipeline):
484
  latents = latents * self.scheduler.init_noise_sigma
485
  return latents, image_latents
486
 
487
- # Copied from diffusers.pipelines.consisid.pipeline_consisid.ConsisIDPipeline.decode_latents
488
  def decode_latents(self, latents: torch.Tensor) -> torch.Tensor:
489
  latents = latents.permute(0, 2, 1, 3, 4) # [batch_size, num_channels, num_frames, height, width]
490
  latents = 1 / self.vae_scaling_factor_image * latents
@@ -583,13 +611,13 @@ class ConsisIDPipeline(DiffusionPipeline):
583
  f" {negative_prompt_embeds.shape}."
584
  )
585
 
586
- # Copied from diffusers.pipelines.consisid.pipeline_consisid.ConsisIDPipeline.fuse_qkv_projections
587
  def fuse_qkv_projections(self) -> None:
588
  r"""Enables fused QKV projections."""
589
  self.fusing_transformer = True
590
  self.transformer.fuse_qkv_projections()
591
 
592
- # Copied from diffusers.pipelines.consisid.pipeline_consisid.ConsisIDPipeline.unfuse_qkv_projections
593
  def unfuse_qkv_projections(self) -> None:
594
  r"""Disable QKV projection fusion if enabled."""
595
  if not self.fusing_transformer:
@@ -598,7 +626,6 @@ class ConsisIDPipeline(DiffusionPipeline):
598
  self.transformer.unfuse_qkv_projections()
599
  self.fusing_transformer = False
600
 
601
- # Copied from diffusers.pipelines.consisid.pipeline_consisid.ConsisIDPipeline._prepare_rotary_positional_embeddings
602
  def _prepare_rotary_positional_embeddings(
603
  self,
604
  height: int,
@@ -685,7 +712,7 @@ class ConsisIDPipeline(DiffusionPipeline):
685
  The height in pixels of the generated image. This is set to 480 by default for the best results.
686
  width (`int`, *optional*, defaults to self.transformer.config.sample_width * self.vae_scale_factor_spatial):
687
  The width in pixels of the generated image. This is set to 720 by default for the best results.
688
- num_frames (`int`, defaults to `48`):
689
  Number of frames to generate. Must be divisible by self.vae_scale_factor_temporal. Generated video will
690
  contain 1 extra frame because ConsisID is conditioned with (num_seconds * fps + 1) frames where
691
  num_seconds is 6 and fps is 4. However, since videos can be saved at any fps, the only condition that
@@ -697,7 +724,7 @@ class ConsisIDPipeline(DiffusionPipeline):
697
  Custom timesteps to use for the denoising process with schedulers which support a `timesteps` argument
698
  in their `set_timesteps` method. If not defined, the default behavior when `num_inference_steps` is
699
  passed will be used. Must be in descending order.
700
- guidance_scale (`float`, *optional*, defaults to 7.0):
701
  Guidance scale as defined in [Classifier-Free Diffusion Guidance](https://arxiv.org/abs/2207.12598).
702
  `guidance_scale` is defined as `w` of equation 2. of [Imagen
703
  Paper](https://arxiv.org/pdf/2205.11487.pdf). Guidance scale is enabled by setting `guidance_scale >
@@ -804,6 +831,8 @@ class ConsisIDPipeline(DiffusionPipeline):
804
  self._num_timesteps = len(timesteps)
805
 
806
  # 5. Prepare latents
 
 
807
  if kps_cond is not None:
808
  kps_cond = draw_kps(image, kps_cond)
809
  kps_cond = self.video_processor.preprocess(kps_cond, height=height, width=width).to(
@@ -920,4 +949,4 @@ class ConsisIDPipeline(DiffusionPipeline):
920
  if not return_dict:
921
  return (video,)
922
 
923
- return ConsisIDPipelineOutput(frames=video)
 
16
  import math
17
  from typing import Callable, Dict, List, Optional, Tuple, Union
18
 
 
 
 
 
19
  import cv2
20
+ import numpy as np
21
+ import PIL
22
  import torch
 
23
  from transformers import T5EncoderModel, T5Tokenizer
24
 
25
  from diffusers.callbacks import MultiPipelineCallbacks, PipelineCallback
26
  from diffusers.image_processor import PipelineImageInput
27
+ from diffusers.models import AutoencoderKLCogVideoX, ConsisIDTransformer3DModel
28
  from diffusers.models.embeddings import get_3d_rotary_pos_embed
29
+ from diffusers.pipelines.consisid.pipeline_output import ConsisIDPipelineOutput
30
  from diffusers.pipelines.pipeline_utils import DiffusionPipeline
31
  from diffusers.schedulers import CogVideoXDDIMScheduler, CogVideoXDPMScheduler
32
+ from diffusers.utils import (
33
+ logging,
34
+ replace_example_docstring,
35
+ )
36
  from diffusers.utils.torch_utils import randn_tensor
37
  from diffusers.video_processor import VideoProcessor
38
 
 
 
 
 
 
 
39
 
40
  logger = logging.get_logger(__name__) # pylint: disable=invalid-name
41
 
 
45
  ```py
46
  >>> import torch
47
  >>> from diffusers import ConsisIDPipeline
48
+ >>> from diffusers.pipelines.consisid.consisid_utils import prepare_face_models, process_face_embeddings_infer
49
+ >>> from diffusers.utils import export_to_video
50
+ >>> from huggingface_hub import snapshot_download
51
+
52
+ >>> snapshot_download(repo_id="BestWishYsh/ConsisID-preview", local_dir="BestWishYsh/ConsisID-preview")
53
 
54
+ >>> face_helper_1, face_helper_2, face_clip_model, face_main_model, eva_transform_mean, eva_transform_std = (
55
+ ... prepare_face_models("BestWishYsh/ConsisID-preview", device="cuda", dtype=torch.bfloat16)
56
+ ... )
57
+ >>> pipe = ConsisIDPipeline.from_pretrained("BestWishYsh/ConsisID-preview", torch_dtype=torch.bfloat16)
58
  >>> pipe.to("cuda")
59
 
60
  >>> prompt = "A woman adorned with a delicate flower crown, is standing amidst a field of gently swaying wildflowers. Her eyes sparkle with a serene gaze, and a faint smile graces her lips, suggesting a moment of peaceful contentment. The shot is framed from the waist up, highlighting the gentle breeze lightly tousling her hair. The background reveals an expansive meadow under a bright blue sky, capturing the tranquility of a sunny afternoon."
61
+ >>> image = "https://github.com/PKU-YuanGroup/ConsisID/blob/main/asserts/example_images/1.png?raw=true"
62
+
63
+ >>> id_cond, id_vit_hidden, image, face_kps = process_face_embeddings_infer(
64
+ ... face_helper_1,
65
+ ... face_clip_model,
66
+ ... face_helper_2,
67
+ ... eva_transform_mean,
68
+ ... eva_transform_std,
69
+ ... face_main_model,
70
+ ... "cuda",
71
+ ... torch.bfloat16,
72
+ ... image,
73
+ ... is_align_face=True,
74
+ ... )
75
+
76
+ >>> video = pipe(
77
+ ... image=image,
78
+ ... prompt=prompt,
79
+ ... num_inference_steps=50,
80
+ ... guidance_scale=6.0,
81
+ ... use_dynamic_cfg=False,
82
+ ... id_vit_hidden=id_vit_hidden,
83
+ ... id_cond=id_cond,
84
+ ... kps_cond=face_kps,
85
+ ... generator=torch.Generator("cuda").manual_seed(42),
86
  ... )
 
87
  >>> export_to_video(video.frames[0], "output.mp4", fps=8)
88
  ```
89
  """
90
 
91
 
92
  def draw_kps(image_pil, kps, color_list=[(255, 0, 0), (0, 255, 0), (0, 0, 255), (255, 255, 0), (255, 0, 255)]):
93
+ """
94
+ This function draws keypoints and the limbs connecting them on an image.
95
+
96
+ Parameters:
97
+ - image_pil (PIL.Image): Input image as a PIL object.
98
+ - kps (list of tuples): A list of keypoints where each keypoint is a tuple of (x, y) coordinates.
99
+ - color_list (list of tuples, optional): List of colors (in RGB format) for each keypoint. Default is a set of five
100
+ colors.
101
+
102
+ Returns:
103
+ - PIL.Image: Image with the keypoints and limbs drawn.
104
+ """
105
+
106
  stickwidth = 4
107
  limbSeq = np.array([[0, 2], [1, 2], [3, 2], [4, 2]])
108
  kps = np.array(kps)
 
133
  return out_img_pil
134
 
135
 
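A brief usage sketch for `draw_kps` (the five keypoint coordinates below are invented for illustration; in the pipeline they come from `process_face_embeddings_infer`):

```python
import numpy as np
from PIL import Image

# Blank 480x720 canvas and five made-up facial keypoints (x, y): eyes, nose, mouth corners.
canvas = Image.new("RGB", (720, 480))
kps = [(300, 200), (420, 200), (360, 260), (320, 320), (400, 320)]

kps_image = draw_kps(canvas, kps)
print(kps_image.size)  # (720, 480)
```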
 
 
 
 
 
 
 
 
 
136
  # Similar to diffusers.pipelines.hunyuandit.pipeline_hunyuandit.get_resize_crop_region_for_grid
137
  def get_resize_crop_region_for_grid(src, tgt_width, tgt_height):
138
+ """
139
+ This function calculates the resize and crop region for an image to fit a target width and height while preserving
140
+ the aspect ratio.
141
+
142
+ Parameters:
143
+ - src (tuple): A tuple containing the source image's height (h) and width (w).
144
+ - tgt_width (int): The target width to resize the image.
145
+ - tgt_height (int): The target height to resize the image.
146
+
147
+ Returns:
148
+ - tuple: Two tuples representing the crop region:
149
+ 1. The top-left coordinates of the crop region.
150
+ 2. The bottom-right coordinates of the crop region.
151
+ """
152
+
153
  tw = tgt_width
154
  th = tgt_height
155
  h, w = src
 
176
  sigmas: Optional[List[float]] = None,
177
  **kwargs,
178
  ):
179
+ r"""
180
  Calls the scheduler's `set_timesteps` method and retrieves timesteps from the scheduler after the call. Handles
181
  custom timesteps. Any kwargs will be supplied to `scheduler.set_timesteps`.
182
 
 
241
  raise AttributeError("Could not access latents of provided encoder_output")
242
 
243
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
244
  class ConsisIDPipeline(DiffusionPipeline):
245
  r"""
246
  Pipeline for image-to-video generation using ConsisID.
 
302
 
303
  self.video_processor = VideoProcessor(vae_scale_factor=self.vae_scale_factor_spatial)
304
 
305
+ # Copied from diffusers.pipelines.cogvideo.pipeline_cogvideox.CogVideoXPipeline._get_t5_prompt_embeds
306
  def _get_t5_prompt_embeds(
307
  self,
308
  prompt: Union[str, List[str]] = None,
 
345
 
346
  return prompt_embeds
347
 
348
+ # Copied from diffusers.pipelines.cogvideo.pipeline_cogvideox.CogVideoXPipeline.encode_prompt
349
  def encode_prompt(
350
  self,
351
  prompt: Union[str, List[str]],
 
512
  latents = latents * self.scheduler.init_noise_sigma
513
  return latents, image_latents
514
 
515
+ # Copied from diffusers.pipelines.cogvideo.pipeline_cogvideox.CogVideoXPipeline.decode_latents
516
  def decode_latents(self, latents: torch.Tensor) -> torch.Tensor:
517
  latents = latents.permute(0, 2, 1, 3, 4) # [batch_size, num_channels, num_frames, height, width]
518
  latents = 1 / self.vae_scaling_factor_image * latents
 
611
  f" {negative_prompt_embeds.shape}."
612
  )
613
 
614
+ # Copied from diffusers.pipelines.cogvideo.pipeline_cogvideox.CogVideoXPipeline.fuse_qkv_projections
615
  def fuse_qkv_projections(self) -> None:
616
  r"""Enables fused QKV projections."""
617
  self.fusing_transformer = True
618
  self.transformer.fuse_qkv_projections()
619
 
620
+ # Copied from diffusers.pipelines.cogvideo.pipeline_cogvideox.CogVideoXPipeline.unfuse_qkv_projections
621
  def unfuse_qkv_projections(self) -> None:
622
  r"""Disable QKV projection fusion if enabled."""
623
  if not self.fusing_transformer:
 
626
  self.transformer.unfuse_qkv_projections()
627
  self.fusing_transformer = False
628
 
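The fused-QKV helpers are copied from the CogVideoX pipeline; a hedged sketch of toggling them around a generation call (reusing `pipe`, `image`, `prompt`, `id_cond`, and `id_vit_hidden` from the example docstring above):

```python
pipe.fuse_qkv_projections()      # fuse q/k/v projections inside the transformer's attention blocks
video = pipe(
    image=image,
    prompt=prompt,
    id_vit_hidden=id_vit_hidden,
    id_cond=id_cond,
).frames[0]
pipe.unfuse_qkv_projections()    # restore the unfused projections afterwards
```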
 
629
  def _prepare_rotary_positional_embeddings(
630
  self,
631
  height: int,
 
712
  The height in pixels of the generated image. This is set to 480 by default for the best results.
713
  width (`int`, *optional*, defaults to self.transformer.config.sample_height * self.vae_scale_factor_spatial):
714
  The width in pixels of the generated image. This is set to 720 by default for the best results.
715
+ num_frames (`int`, defaults to `49`):
716
  Number of frames to generate. Must be divisible by self.vae_scale_factor_temporal. Generated video will
717
  contain 1 extra frame because ConsisID is conditioned with (num_seconds * fps + 1) frames where
718
  num_seconds is 6 and fps is 4. However, since videos can be saved at any fps, the only condition that
 
724
  Custom timesteps to use for the denoising process with schedulers which support a `timesteps` argument
725
  in their `set_timesteps` method. If not defined, the default behavior when `num_inference_steps` is
726
  passed will be used. Must be in descending order.
727
+ guidance_scale (`float`, *optional*, defaults to 6):
728
  Guidance scale as defined in [Classifier-Free Diffusion Guidance](https://arxiv.org/abs/2207.12598).
729
  `guidance_scale` is defined as `w` of equation 2. of [Imagen
730
  Paper](https://arxiv.org/pdf/2205.11487.pdf). Guidance scale is enabled by setting `guidance_scale >
 
831
  self._num_timesteps = len(timesteps)
832
 
833
  # 5. Prepare latents
834
+ is_kps = getattr(self.transformer.config, "is_kps", False)
835
+ kps_cond = kps_cond if is_kps else None
836
  if kps_cond is not None:
837
  kps_cond = draw_kps(image, kps_cond)
838
  kps_cond = self.video_processor.preprocess(kps_cond, height=height, width=width).to(
 
949
  if not return_dict:
950
  return (video,)
951
 
952
+ return ConsisIDPipelineOutput(frames=video)
models/transformer_consisid.py CHANGED
@@ -12,39 +12,340 @@
12
  # See the License for the specific language governing permissions and
13
  # limitations under the License.
14
 
15
- from typing import Any, Dict, Optional, Tuple, Union
16
- import os
17
- import sys
18
- import json
19
  import glob
20
 
21
  import torch
22
  from torch import nn
23
- from einops import rearrange, reduce
24
 
25
  from diffusers.configuration_utils import ConfigMixin, register_to_config
26
  from diffusers.loaders import PeftAdapterMixin
27
- from diffusers.utils import USE_PEFT_BACKEND, is_torch_version, logging, scale_lora_layers, unscale_lora_layers
28
- from diffusers.utils.torch_utils import maybe_allow_in_graph
29
  from diffusers.models.attention import Attention, FeedForward
30
- from diffusers.models.attention_processor import AttentionProcessor, CogVideoXAttnProcessor2_0, FusedCogVideoXAttnProcessor2_0
31
  from diffusers.models.embeddings import CogVideoXPatchEmbed, TimestepEmbedding, Timesteps
32
  from diffusers.models.modeling_outputs import Transformer2DModelOutput
33
  from diffusers.models.modeling_utils import ModelMixin
34
  from diffusers.models.normalization import AdaLayerNorm, CogVideoXLayerNormZero
 
 
35
 
36
- import os
37
- import sys
38
- current_file_path = os.path.abspath(__file__)
39
- project_roots = [os.path.dirname(current_file_path)]
40
- for project_root in project_roots:
41
- sys.path.insert(0, project_root) if project_root not in sys.path else None
42
-
43
- from local_facial_extractor import LocalFacialExtractor, PerceiverCrossAttention
44
 
45
  logger = logging.get_logger(__name__) # pylint: disable=invalid-name
46
 
47
48
  @maybe_allow_in_graph
49
  class ConsisIDBlock(nn.Module):
50
  r"""
@@ -189,7 +490,7 @@ class ConsisIDTransformer3DModel(ModelMixin, ConfigMixin, PeftAdapterMixin):
189
  dropout (`float`, defaults to `0.0`):
190
  The dropout probability to use.
191
  attention_bias (`bool`, defaults to `True`):
192
- Whether or not to use bias in the attention projection layers.
193
  sample_width (`int`, defaults to `90`):
194
  The width of the input latents.
195
  sample_height (`int`, defaults to `60`):
@@ -210,7 +511,7 @@ class ConsisIDTransformer3DModel(ModelMixin, ConfigMixin, PeftAdapterMixin):
210
  timestep_activation_fn (`str`, defaults to `"silu"`):
211
  Activation function to use when generating the timestep embeddings.
212
  norm_elementwise_affine (`bool`, defaults to `True`):
213
- Whether or not to use elementwise affine in normalization layers.
214
  norm_eps (`float`, defaults to `1e-5`):
215
  The epsilon value to use in normalization layers.
216
  spatial_interpolation_scale (`float`, defaults to `1.875`):
@@ -218,31 +519,57 @@ class ConsisIDTransformer3DModel(ModelMixin, ConfigMixin, PeftAdapterMixin):
218
  temporal_interpolation_scale (`float`, defaults to `1.0`):
219
  Scaling factor to apply in 3D positional embeddings across temporal dimensions.
220
  is_train_face (`bool`, defaults to `False`):
221
- Whether to use enable the identity-preserving module during the training process.
222
- When set to `True`, the model will focus on identity-preserving tasks.
223
  is_kps (`bool`, defaults to `False`):
224
- Whether to enable keypoint for global facial extractor.
225
- If `True`, keypoints will be in the model.
226
- cross_attn_interval (`int`, defaults to `1`):
227
- The interval between cross-attention layers in the Transformer architecture.
228
- A larger value may reduce the frequency of cross-attention computations,
229
- which can help reduce computational overhead.
230
- LFE_num_tokens (`int`, defaults to `32`):
231
- The number of tokens to use in the Local Facial Extractor (LFE).
232
- This module is responsible for capturing high frequency representations
233
- of the face.
234
- LFE_output_dim (`int`, defaults to `768`):
235
- The output dimension of the Local Facial Extractor (LFE) module.
236
- This dimension determines the size of the feature vectors produced
237
- by the LFE module.
238
- LFE_heads (`int`, defaults to `12`):
239
- The number of attention heads used in the Local Facial Extractor (LFE) module.
240
- More heads may improve the ability to capture diverse features, but
241
- can also increase computational complexity.
242
  local_face_scale (`float`, defaults to `1.0`):
243
- A scaling factor used to adjust the importance of local facial features
244
- in the model. This can influence how strongly the model focuses on
245
- high frequency face-related content.
246
  """
247
 
248
  _supports_gradient_checkpointing = True
@@ -277,10 +604,18 @@ class ConsisIDTransformer3DModel(ModelMixin, ConfigMixin, PeftAdapterMixin):
277
  use_learned_positional_embeddings: bool = False,
278
  is_train_face: bool = False,
279
  is_kps: bool = False,
280
- cross_attn_interval: int = 1,
281
- LFE_num_tokens: int = 32,
282
- LFE_output_dim: int = 768,
283
- LFE_heads: int = 12,
284
  local_face_scale: float = 1.0,
285
  ):
286
  super().__init__()
@@ -352,14 +687,25 @@ class ConsisIDTransformer3DModel(ModelMixin, ConfigMixin, PeftAdapterMixin):
352
 
353
  # 5. Define identity-preserving config
354
  if is_train_face:
355
  self.inner_dim = inner_dim
356
  self.cross_attn_interval = cross_attn_interval
357
- self.num_ca = num_layers // cross_attn_interval
358
- self.LFE_num_tokens = LFE_num_tokens
359
- self.LFE_output_dim = LFE_output_dim
360
- self.LFE_heads = LFE_heads
361
- self.LFE_final_output_dim = int(self.inner_dim / 3 * 2)
362
  self.local_face_scale = local_face_scale
 
363
  self._init_face_inputs()
364
 
365
  def _set_gradient_checkpointing(self, module, value=False):
@@ -367,15 +713,28 @@ class ConsisIDTransformer3DModel(ModelMixin, ConfigMixin, PeftAdapterMixin):
367
 
368
  def _init_face_inputs(self):
369
  device = self.device
370
- weight_dtype = next(self.transformer_blocks.parameters()).dtype
371
- self.local_facial_extractor = LocalFacialExtractor()
372
  self.local_facial_extractor.to(device, dtype=weight_dtype)
373
  self.perceiver_cross_attention = nn.ModuleList(
374
  [
375
  PerceiverCrossAttention(
376
- dim=self.inner_dim, dim_head=128, heads=16, kv_dim=self.LFE_final_output_dim
377
  ).to(device, dtype=weight_dtype)
378
- for _ in range(self.num_ca)
379
  ]
380
  )
381
 
@@ -604,7 +963,7 @@ class ConsisIDTransformer3DModel(ModelMixin, ConfigMixin, PeftAdapterMixin):
604
  if not return_dict:
605
  return (output,)
606
  return Transformer2DModelOutput(sample=output)
607
-
608
  @classmethod
609
  def from_pretrained_cus(cls, pretrained_model_path, subfolder=None, config_path=None, transformer_additional_kwargs={}):
610
  if subfolder:
@@ -656,7 +1015,7 @@ class ConsisIDTransformer3DModel(ModelMixin, ConfigMixin, PeftAdapterMixin):
656
  model.state_dict()['patch_embed.proj.weight'][:, :, :, :] = state_dict['patch_embed.proj.weight'][:, :model.state_dict()['patch_embed.proj.weight'].size()[1], :, :]
657
  state_dict['patch_embed.proj.weight'] = model.state_dict()['patch_embed.proj.weight']
658
 
659
- tmp_state_dict = {}
660
  for key in state_dict:
661
  if key in model.state_dict().keys() and model.state_dict()[key].size() == state_dict[key].size():
662
  tmp_state_dict[key] = state_dict[key]
@@ -667,20 +1026,20 @@ class ConsisIDTransformer3DModel(ModelMixin, ConfigMixin, PeftAdapterMixin):
667
  m, u = model.load_state_dict(state_dict, strict=False)
668
  print(f"### missing keys: {len(m)}; \n### unexpected keys: {len(u)};")
669
  print(m)
670
-
671
  params = [p.numel() if "mamba" in n else 0 for n, p in model.named_parameters()]
672
  print(f"### Mamba Parameters: {sum(params) / 1e6} M")
673
 
674
  params = [p.numel() if "attn1." in n else 0 for n, p in model.named_parameters()]
675
  print(f"### attn1 Parameters: {sum(params) / 1e6} M")
676
-
677
  return model
678
-
679
  if __name__ == '__main__':
680
  device = "cuda:0"
681
  weight_dtype = torch.bfloat16
682
  pretrained_model_name_or_path = "BestWishYsh/ConsisID-preview"
683
-
684
  transformer_additional_kwargs={
685
  'torch_dtype': weight_dtype,
686
  'revision': None,
@@ -690,7 +1049,7 @@ if __name__ == '__main__':
690
  'LFE_num_tokens': 32,
691
  'LFE_output_dim': 768,
692
  'LFE_heads': 12,
693
- 'cross_attn_interval': 2,
694
  }
695
 
696
  transformer = ConsisIDTransformer3DModel.from_pretrained_cus(
@@ -723,10 +1082,8 @@ if __name__ == '__main__':
723
  timestep=timesteps,
724
  image_rotary_emb=image_rotary_emb,
725
  return_dict=False,
726
- id_vit_hidden=id_vit_hidden if id_vit_hidden is not None else None,
727
  id_cond=id_cond if id_cond is not None else None,
728
  )[0]
729
-
730
  print(model_output)
731
- # transformer.save_pretrained(os.path.join("./test_ckpt", "transformer"))
732
-
 
12
  # See the License for the specific language governing permissions and
13
  # limitations under the License.
14
 
15
  import glob
16
+ import json
17
+ import math
18
+ import os
19
+ from typing import Any, Dict, Optional, Tuple, Union
20
 
21
  import torch
22
  from torch import nn
 
23
 
24
  from diffusers.configuration_utils import ConfigMixin, register_to_config
25
  from diffusers.loaders import PeftAdapterMixin
 
 
26
  from diffusers.models.attention import Attention, FeedForward
27
+ from diffusers.models.attention_processor import (
28
+ AttentionProcessor,
29
+ CogVideoXAttnProcessor2_0,
30
+ FusedCogVideoXAttnProcessor2_0,
31
+ )
32
  from diffusers.models.embeddings import CogVideoXPatchEmbed, TimestepEmbedding, Timesteps
33
  from diffusers.models.modeling_outputs import Transformer2DModelOutput
34
  from diffusers.models.modeling_utils import ModelMixin
35
  from diffusers.models.normalization import AdaLayerNorm, CogVideoXLayerNormZero
36
+ from diffusers.utils import USE_PEFT_BACKEND, is_torch_version, logging, scale_lora_layers, unscale_lora_layers
37
+ from diffusers.utils.torch_utils import maybe_allow_in_graph
38
39
 
40
  logger = logging.get_logger(__name__) # pylint: disable=invalid-name
41
 
42
 
43
+ def ConsisIDFeedForward(dim, mult=4):
44
+ """
45
+ Creates a ConsisID feed-forward block consisting of layer normalization, two linear layers, and a GELU
46
+ activation.
47
+
48
+ Args:
49
+ dim (int): The input dimension of the tensor.
50
+ mult (int, optional): Multiplier for the inner dimension. Default is 4.
51
+
52
+ Returns:
53
+ nn.Sequential: A sequence of layers comprising LayerNorm, Linear layers, and GELU.
54
+ """
55
+ inner_dim = int(dim * mult)
56
+ return nn.Sequential(
57
+ nn.LayerNorm(dim),
58
+ nn.Linear(dim, inner_dim, bias=False),
59
+ nn.GELU(),
60
+ nn.Linear(inner_dim, dim, bias=False),
61
+ )
62
+
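A quick shape check for `ConsisIDFeedForward` (editorial sketch; tensor sizes are illustrative):

```python
import torch

block = ConsisIDFeedForward(dim=1024, mult=4)  # LayerNorm -> Linear(1024->4096) -> GELU -> Linear(4096->1024)
tokens = torch.randn(2, 32, 1024)
assert block(tokens).shape == (2, 32, 1024)    # output keeps the input width, so it can be added residually
```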
63
+
64
+ def reshape_tensor(x, heads):
65
+ """
66
+ Reshapes the input tensor for multi-head attention.
67
+
68
+ Args:
69
+ x (torch.Tensor): The input tensor with shape (batch_size, length, width).
70
+ heads (int): The number of attention heads.
71
+
72
+ Returns:
73
+ torch.Tensor: The reshaped tensor, with shape (batch_size, heads, length, width).
74
+ """
75
+ bs, length, width = x.shape
76
+ x = x.view(bs, length, heads, -1)
77
+ x = x.transpose(1, 2)
78
+ x = x.reshape(bs, heads, length, -1)
79
+ return x
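An illustrative shape check for `reshape_tensor` (editorial sketch):

```python
import torch

x = torch.randn(2, 10, 512)  # width 512 = 8 heads x 64 dims per head
assert reshape_tensor(x, heads=8).shape == (2, 8, 10, 64)
```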
80
+
81
+
82
+ class PerceiverAttention(nn.Module):
83
+ """
84
+ Implements the Perceiver attention mechanism with multi-head attention.
85
+
86
+ This layer takes two inputs: 'x' (image features) and 'latents' (latent features), applying multi-head attention to
87
+ both and producing an output tensor whose shape matches that of the latent tensor 'latents'.
88
+
89
+ Args:
90
+ dim (int): The input dimension.
91
+ dim_head (int, optional): The dimension of each attention head. Default is 64.
92
+ heads (int, optional): The number of attention heads. Default is 8.
93
+ kv_dim (int, optional): The key-value dimension. If None, `dim` is used for both keys and values.
94
+ """
95
+
96
+ def __init__(self, *, dim, dim_head=64, heads=8, kv_dim=None):
97
+ super().__init__()
98
+ self.scale = dim_head**-0.5
99
+ self.dim_head = dim_head
100
+ self.heads = heads
101
+ inner_dim = dim_head * heads
102
+
103
+ self.norm1 = nn.LayerNorm(dim if kv_dim is None else kv_dim)
104
+ self.norm2 = nn.LayerNorm(dim)
105
+
106
+ self.to_q = nn.Linear(dim, inner_dim, bias=False)
107
+ self.to_kv = nn.Linear(dim if kv_dim is None else kv_dim, inner_dim * 2, bias=False)
108
+ self.to_out = nn.Linear(inner_dim, dim, bias=False)
109
+
110
+ def forward(self, x, latents):
111
+ """
112
+ Forward pass for Perceiver attention.
113
+
114
+ Args:
115
+ x (torch.Tensor): Image features tensor with shape (batch_size, num_pixels, D).
116
+ latents (torch.Tensor): Latent features tensor with shape (batch_size, num_latents, D).
117
+
118
+ Returns:
119
+ torch.Tensor: Output tensor after applying attention and transformation.
120
+ """
121
+ # Apply normalization
122
+ x = self.norm1(x)
123
+ latents = self.norm2(latents)
124
+
125
+ b, seq_len, _ = latents.shape # Get batch size and sequence length
126
+
127
+ # Compute query, key, and value matrices
128
+ q = self.to_q(latents)
129
+ kv_input = torch.cat((x, latents), dim=-2)
130
+ k, v = self.to_kv(kv_input).chunk(2, dim=-1)
131
+
132
+ # Reshape the tensors for multi-head attention
133
+ q = reshape_tensor(q, self.heads)
134
+ k = reshape_tensor(k, self.heads)
135
+ v = reshape_tensor(v, self.heads)
136
+
137
+ # attention
138
+ scale = 1 / math.sqrt(math.sqrt(self.dim_head))
139
+ weight = (q * scale) @ (k * scale).transpose(-2, -1) # More stable with f16 than dividing afterwards
140
+ weight = torch.softmax(weight.float(), dim=-1).type(weight.dtype)
141
+ out = weight @ v
142
+
143
+ # Reshape and return the final output
144
+ out = out.permute(0, 2, 1, 3).reshape(b, seq_len, -1)
145
+
146
+ return self.to_out(out)
147
+
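A shape check for `PerceiverAttention` (editorial sketch; the token counts are illustrative, e.g. ViT patch tokens for `x`):

```python
import torch

attn = PerceiverAttention(dim=1024, dim_head=64, heads=16)
image_feats = torch.randn(2, 577, 1024)  # "x": context features
latents = torch.randn(2, 32, 1024)       # latent queries
assert attn(image_feats, latents).shape == (2, 32, 1024)  # output follows the latent sequence
```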
148
+
149
+ class LocalFacialExtractor(nn.Module):
150
+ def __init__(
151
+ self,
152
+ id_dim=1280,
153
+ vit_dim=1024,
154
+ depth=10,
155
+ dim_head=64,
156
+ heads=16,
157
+ num_id_token=5,
158
+ num_queries=32,
159
+ output_dim=2048,
160
+ ff_mult=4,
161
+ ):
162
+ """
163
+ Initializes the LocalFacialExtractor class.
164
+
165
+ Parameters:
166
+ - id_dim (int): The dimensionality of id features.
167
+ - vit_dim (int): The dimensionality of vit features.
168
+ - depth (int): Total number of PerceiverAttention and ConsisIDFeedForward layers.
169
+ - dim_head (int): Dimensionality of each attention head.
170
+ - heads (int): Number of attention heads.
171
+ - num_id_token (int): Number of tokens used for identity features.
172
+ - num_queries (int): Number of query tokens for the latent representation.
173
+ - output_dim (int): Output dimension after projection.
174
+ - ff_mult (int): Multiplier for the feed-forward network hidden dimension.
175
+ """
176
+ super().__init__()
177
+
178
+ # Storing identity token and query information
179
+ self.num_id_token = num_id_token
180
+ self.vit_dim = vit_dim
181
+ self.num_queries = num_queries
182
+ assert depth % 5 == 0
183
+ self.depth = depth // 5
184
+ scale = vit_dim**-0.5
185
+
186
+ # Learnable latent query embeddings
187
+ self.latents = nn.Parameter(torch.randn(1, num_queries, vit_dim) * scale)
188
+ # Projection layer to map the latent output to the desired dimension
189
+ self.proj_out = nn.Parameter(scale * torch.randn(vit_dim, output_dim))
190
+
191
+ # Attention and ConsisIDFeedForward layer stack
192
+ self.layers = nn.ModuleList([])
193
+ for _ in range(depth):
194
+ self.layers.append(
195
+ nn.ModuleList(
196
+ [
197
+ PerceiverAttention(dim=vit_dim, dim_head=dim_head, heads=heads), # Perceiver Attention layer
198
+ ConsisIDFeedForward(dim=vit_dim, mult=ff_mult), # ConsisIDFeedForward layer
199
+ ]
200
+ )
201
+ )
202
+
203
+ # Mappings for each of the 5 different ViT features
204
+ for i in range(5):
205
+ setattr(
206
+ self,
207
+ f"mapping_{i}",
208
+ nn.Sequential(
209
+ nn.Linear(vit_dim, vit_dim),
210
+ nn.LayerNorm(vit_dim),
211
+ nn.LeakyReLU(),
212
+ nn.Linear(vit_dim, vit_dim),
213
+ nn.LayerNorm(vit_dim),
214
+ nn.LeakyReLU(),
215
+ nn.Linear(vit_dim, vit_dim),
216
+ ),
217
+ )
218
+
219
+ # Mapping for identity embedding vectors
220
+ self.id_embedding_mapping = nn.Sequential(
221
+ nn.Linear(id_dim, vit_dim),
222
+ nn.LayerNorm(vit_dim),
223
+ nn.LeakyReLU(),
224
+ nn.Linear(vit_dim, vit_dim),
225
+ nn.LayerNorm(vit_dim),
226
+ nn.LeakyReLU(),
227
+ nn.Linear(vit_dim, vit_dim * num_id_token),
228
+ )
229
+
230
+ def forward(self, x, y):
231
+ """
232
+ Forward pass for LocalFacialExtractor.
233
+
234
+ Parameters:
235
+ - x (Tensor): The input identity embedding tensor of shape (batch_size, id_dim).
236
+ - y (list of Tensor): A list of 5 visual feature tensors, each of shape (batch_size, seq_len, vit_dim).
237
+
238
+ Returns:
239
+ - Tensor: The extracted latent features of shape (batch_size, num_queries, output_dim).
240
+ """
241
+
242
+ # Repeat latent queries for the batch size
243
+ latents = self.latents.repeat(x.size(0), 1, 1)
244
+
245
+ # Map the identity embedding to tokens
246
+ x = self.id_embedding_mapping(x)
247
+ x = x.reshape(-1, self.num_id_token, self.vit_dim)
248
+
249
+ # Concatenate identity tokens with the latent queries
250
+ latents = torch.cat((latents, x), dim=1)
251
+
252
+ # Process each of the 5 visual feature inputs
253
+ for i in range(5):
254
+ vit_feature = getattr(self, f"mapping_{i}")(y[i])
255
+ ctx_feature = torch.cat((x, vit_feature), dim=1)
256
+
257
+ # Pass through the PerceiverAttention and ConsisIDFeedForward layers
258
+ for attn, ff in self.layers[i * self.depth : (i + 1) * self.depth]:
259
+ latents = attn(ctx_feature, latents) + latents
260
+ latents = ff(latents) + latents
261
+
262
+ # Retain only the query latents
263
+ latents = latents[:, : self.num_queries]
264
+ # Project the latents to the output dimension
265
+ latents = latents @ self.proj_out
266
+ return latents
267
+
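A hedged end-to-end shape check for `LocalFacialExtractor` (editorial sketch; the id dimension 1280 and ViT width 1024 follow the defaults above, while the token length 577 is an assumption):

```python
import torch

lfe = LocalFacialExtractor()                                 # defaults: num_queries=32, output_dim=2048
id_embeds = torch.randn(2, 1280)                             # identity embedding per sample
vit_hidden = [torch.randn(2, 577, 1024) for _ in range(5)]   # five intermediate ViT feature maps
assert lfe(id_embeds, vit_hidden).shape == (2, 32, 2048)
```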
268
+
269
+ class PerceiverCrossAttention(nn.Module):
270
+ """
271
+ Implements Perceiver-style cross-attention in which latent features attend to an external
+ feature sequence (e.g., identity embeddings produced by the Local Facial Extractor).
+
272
+ Args:
273
+ dim (int): Dimension of the input latent and output. Default is 3072.
274
+ dim_head (int): Dimension of each attention head. Default is 128.
275
+ heads (int): Number of attention heads. Default is 16.
276
+ kv_dim (int): Dimension of the key/value input, allowing flexible cross-attention. Default is 2048.
277
+
278
+ Attributes:
279
+ scale (float): Scaling factor used in dot-product attention for numerical stability.
280
+ norm1 (nn.LayerNorm): Layer normalization applied to the input image features.
281
+ norm2 (nn.LayerNorm): Layer normalization applied to the latent features.
282
+ to_q (nn.Linear): Linear layer for projecting the latent features into queries.
283
+ to_kv (nn.Linear): Linear layer for projecting the input features into keys and values.
284
+ to_out (nn.Linear): Linear layer for outputting the final result after attention.
285
+
286
+ """
287
+
288
+ def __init__(self, *, dim=3072, dim_head=128, heads=16, kv_dim=2048):
289
+ super().__init__()
290
+ self.scale = dim_head**-0.5
291
+ self.dim_head = dim_head
292
+ self.heads = heads
293
+ inner_dim = dim_head * heads
294
+
295
+ # Layer normalization to stabilize training
296
+ self.norm1 = nn.LayerNorm(dim if kv_dim is None else kv_dim)
297
+ self.norm2 = nn.LayerNorm(dim)
298
+
299
+ # Linear transformations to produce queries, keys, and values
300
+ self.to_q = nn.Linear(dim, inner_dim, bias=False)
301
+ self.to_kv = nn.Linear(dim if kv_dim is None else kv_dim, inner_dim * 2, bias=False)
302
+ self.to_out = nn.Linear(inner_dim, dim, bias=False)
303
+
304
+ def forward(self, x, latents):
305
+ """
306
+
307
+ Args:
308
+ x (torch.Tensor): Input image features with shape (batch_size, n1, D), where:
309
+ - batch_size (b): Number of samples in the batch.
310
+ - n1: Sequence length (e.g., number of patches or tokens).
311
+ - D: Feature dimension.
312
+
313
+ latents (torch.Tensor): Latent feature representations with shape (batch_size, n2, D), where:
314
+ - n2: Number of latent elements.
315
+
316
+ Returns:
317
+ torch.Tensor: Attention-modulated features with shape (batch_size, n2, D).
318
+
319
+ """
320
+ # Apply layer normalization to the input image and latent features
321
+ x = self.norm1(x)
322
+ latents = self.norm2(latents)
323
+
324
+ b, seq_len, _ = latents.shape
325
+
326
+ # Compute queries, keys, and values
327
+ q = self.to_q(latents)
328
+ k, v = self.to_kv(x).chunk(2, dim=-1)
329
+
330
+ # Reshape tensors to split into attention heads
331
+ q = reshape_tensor(q, self.heads)
332
+ k = reshape_tensor(k, self.heads)
333
+ v = reshape_tensor(v, self.heads)
334
+
335
+ # Compute attention weights
336
+ scale = 1 / math.sqrt(math.sqrt(self.dim_head))
337
+ weight = (q * scale) @ (k * scale).transpose(-2, -1) # More stable scaling than post-division
338
+ weight = torch.softmax(weight.float(), dim=-1).type(weight.dtype)
339
+
340
+ # Compute the output via weighted combination of values
341
+ out = weight @ v
342
+
343
+ # Reshape and permute to prepare for final linear transformation
344
+ out = out.permute(0, 2, 1, 3).reshape(b, seq_len, -1)
345
+
346
+ return self.to_out(out)
347
+
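A shape check for `PerceiverCrossAttention` (editorial sketch; 3072 and 2048 mirror the documented defaults, and the video token count is illustrative):

```python
import torch

xattn = PerceiverCrossAttention(dim=3072, dim_head=128, heads=16, kv_dim=2048)
face_tokens = torch.randn(2, 32, 2048)      # "x": identity features (e.g. LocalFacialExtractor output)
hidden_states = torch.randn(2, 1024, 3072)  # "latents": video tokens from the transformer
assert xattn(face_tokens, hidden_states).shape == (2, 1024, 3072)
```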
348
+
349
  @maybe_allow_in_graph
350
  class ConsisIDBlock(nn.Module):
351
  r"""
 
490
  dropout (`float`, defaults to `0.0`):
491
  The dropout probability to use.
492
  attention_bias (`bool`, defaults to `True`):
493
+ Whether to use bias in the attention projection layers.
494
  sample_width (`int`, defaults to `90`):
495
  The width of the input latents.
496
  sample_height (`int`, defaults to `60`):
 
511
  timestep_activation_fn (`str`, defaults to `"silu"`):
512
  Activation function to use when generating the timestep embeddings.
513
  norm_elementwise_affine (`bool`, defaults to `True`):
514
+ Whether to use elementwise affine in normalization layers.
515
  norm_eps (`float`, defaults to `1e-5`):
516
  The epsilon value to use in normalization layers.
517
  spatial_interpolation_scale (`float`, defaults to `1.875`):
 
519
  temporal_interpolation_scale (`float`, defaults to `1.0`):
520
  Scaling factor to apply in 3D positional embeddings across temporal dimensions.
521
  is_train_face (`bool`, defaults to `False`):
522
+ Whether to enable the identity-preserving module during the training process. When set to `True`, the
523
+ model will focus on identity-preserving tasks.
524
  is_kps (`bool`, defaults to `False`):
525
+ Whether to enable keypoints for the global facial extractor. If `True`, keypoints will be used in the model.
526
+ cross_attn_interval (`int`, defaults to `2`):
527
+ The interval between cross-attention layers in the Transformer architecture. A larger value may reduce the
528
+ frequency of cross-attention computations, which can help reduce computational overhead.
529
+ cross_attn_dim_head (`int`, optional, defaults to `128`):
530
+ The dimensionality of each attention head in the cross-attention layers of the Transformer architecture. A
531
+ larger value increases the capacity to attend to more complex patterns, but also increases memory and
532
+ computation costs.
533
+ cross_attn_num_heads (`int`, optional, defaults to `16`):
534
+ The number of attention heads in the cross-attention layers. More heads allow for more parallel attention
535
+ mechanisms, capturing diverse relationships between different components of the input, but can also
536
+ increase computational requirements.
537
+ LFE_id_dim (`int`, optional, defaults to `1280`):
538
+ The dimensionality of the identity vector used in the Local Facial Extractor (LFE). This vector represents
539
+ the identity features of a face, which are important for tasks like face recognition and identity
540
+ preservation across different frames.
541
+ LFE_vit_dim (`int`, optional, defaults to `1024`):
542
+ The dimension of the vision transformer (ViT) output used in the Local Facial Extractor (LFE). This value
543
+ dictates the size of the transformer-generated feature vectors that will be processed for facial feature
544
+ extraction.
545
+ LFE_depth (`int`, optional, defaults to `10`):
546
+ The number of layers in the Local Facial Extractor (LFE). Increasing the depth allows the model to capture
547
+ more complex representations of facial features, but also increases the computational load.
548
+ LFE_dim_head (`int`, optional, defaults to `64`):
549
+ The dimensionality of each attention head in the Local Facial Extractor (LFE). This parameter affects how
550
+ finely the model can process and focus on different parts of the facial features during the extraction
551
+ process.
552
+ LFE_num_heads (`int`, optional, defaults to `16`):
553
+ The number of attention heads in the Local Facial Extractor (LFE). More heads can improve the model's
554
+ ability to capture diverse facial features, but at the cost of increased computational complexity.
555
+ LFE_num_id_token (`int`, optional, defaults to `5`):
556
+ The number of identity tokens used in the Local Facial Extractor (LFE). This defines how many
557
+ identity-related tokens the model will process to ensure face identity preservation during feature
558
+ extraction.
559
+ LFE_num_querie (`int`, optional, defaults to `32`):
560
+ The number of query tokens used in the Local Facial Extractor (LFE). These tokens are used to capture
561
+ high-frequency face-related information that aids in accurate facial feature extraction.
562
+ LFE_output_dim (`int`, optional, defaults to `2048`):
563
+ The output dimension of the Local Facial Extractor (LFE). This dimension determines the size of the feature
564
+ vectors produced by the LFE module, which will be used for subsequent tasks such as face recognition or
565
+ tracking.
566
+ LFE_ff_mult (`int`, optional, defaults to `4`):
567
+ The multiplication factor applied to the feed-forward network's hidden layer size in the Local Facial
568
+ Extractor (LFE). A higher value increases the model's capacity to learn more complex facial feature
569
+ transformations, but also increases the computation and memory requirements.
570
  local_face_scale (`float`, defaults to `1.0`):
571
+ A scaling factor used to adjust the importance of local facial features in the model. This can influence
572
+ how strongly the model focuses on high frequency face-related content.
 
573
  """
574
 
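A hedged loading sketch for the identity-preserving transformer (editorial addition; the repo id comes from the app above, while the `transformer` subfolder name is an assumption based on the standard diffusers layout):

```python
import torch

transformer = ConsisIDTransformer3DModel.from_pretrained(
    "BestWishYsh/ConsisID-preview",
    subfolder="transformer",          # assumed subfolder name
    torch_dtype=torch.bfloat16,
)
print(transformer.config.is_train_face, transformer.config.cross_attn_interval)
```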
575
  _supports_gradient_checkpointing = True
 
604
  use_learned_positional_embeddings: bool = False,
605
  is_train_face: bool = False,
606
  is_kps: bool = False,
607
+ cross_attn_interval: int = 2,
608
+ cross_attn_dim_head: int = 128,
609
+ cross_attn_num_heads: int = 16,
610
+ LFE_id_dim: int = 1280,
611
+ LFE_vit_dim: int = 1024,
612
+ LFE_depth: int = 10,
613
+ LFE_dim_head: int = 64,
614
+ LFE_num_heads: int = 16,
615
+ LFE_num_id_token: int = 5,
616
+ LFE_num_querie: int = 32,
617
+ LFE_output_dim: int = 2048,
618
+ LFE_ff_mult: int = 4,
619
  local_face_scale: float = 1.0,
620
  ):
621
  super().__init__()
 
687
 
688
  # 5. Define identity-preserving config
689
  if is_train_face:
690
+ # LFE configs
691
+ self.LFE_id_dim = LFE_id_dim
692
+ self.LFE_vit_dim = LFE_vit_dim
693
+ self.LFE_depth = LFE_depth
694
+ self.LFE_dim_head = LFE_dim_head
695
+ self.LFE_num_heads = LFE_num_heads
696
+ self.LFE_num_id_token = LFE_num_id_token
697
+ self.LFE_num_querie = LFE_num_querie
698
+ self.LFE_output_dim = LFE_output_dim
699
+ self.LFE_ff_mult = LFE_ff_mult
700
+ # cross configs
701
  self.inner_dim = inner_dim
702
  self.cross_attn_interval = cross_attn_interval
703
+ self.num_cross_attn = num_layers // cross_attn_interval
704
+ self.cross_attn_dim_head = cross_attn_dim_head
705
+ self.cross_attn_num_heads = cross_attn_num_heads
706
+ self.cross_attn_kv_dim = int(self.inner_dim / 3 * 2)
 
707
  self.local_face_scale = local_face_scale
708
+ # face modules
709
  self._init_face_inputs()
710
 
711
  def _set_gradient_checkpointing(self, module, value=False):
 
713
 
714
  def _init_face_inputs(self):
715
  device = self.device
716
+ weight_dtype = self.dtype
717
+ self.local_facial_extractor = LocalFacialExtractor(
718
+ id_dim=self.LFE_id_dim,
719
+ vit_dim=self.LFE_vit_dim,
720
+ depth=self.LFE_depth,
721
+ dim_head=self.LFE_dim_head,
722
+ heads=self.LFE_num_heads,
723
+ num_id_token=self.LFE_num_id_token,
724
+ num_queries=self.LFE_num_querie,
725
+ output_dim=self.LFE_output_dim,
726
+ ff_mult=self.LFE_ff_mult,
727
+ )
728
  self.local_facial_extractor.to(device, dtype=weight_dtype)
729
  self.perceiver_cross_attention = nn.ModuleList(
730
  [
731
  PerceiverCrossAttention(
732
+ dim=self.inner_dim,
733
+ dim_head=self.cross_attn_dim_head,
734
+ heads=self.cross_attn_num_heads,
735
+ kv_dim=self.cross_attn_kv_dim,
736
  ).to(device, dtype=weight_dtype)
737
+ for _ in range(self.num_cross_attn)
738
  ]
739
  )
740
 
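A small arithmetic check on how many `PerceiverCrossAttention` modules are created (editorial sketch; `num_layers=42` is an assumption matching the CogVideoX-5B backbone ConsisID builds on):

```python
num_layers, cross_attn_interval = 42, 2
num_cross_attn = num_layers // cross_attn_interval
print(num_cross_attn)  # 21 modules, one every `cross_attn_interval` transformer blocks
```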
 
963
  if not return_dict:
964
  return (output,)
965
  return Transformer2DModelOutput(sample=output)
966
+
967
  @classmethod
968
  def from_pretrained_cus(cls, pretrained_model_path, subfolder=None, config_path=None, transformer_additional_kwargs={}):
969
  if subfolder:
 
1015
  model.state_dict()['patch_embed.proj.weight'][:, :, :, :] = state_dict['patch_embed.proj.weight'][:, :model.state_dict()['patch_embed.proj.weight'].size()[1], :, :]
1016
  state_dict['patch_embed.proj.weight'] = model.state_dict()['patch_embed.proj.weight']
1017
 
1018
+ tmp_state_dict = {}
1019
  for key in state_dict:
1020
  if key in model.state_dict().keys() and model.state_dict()[key].size() == state_dict[key].size():
1021
  tmp_state_dict[key] = state_dict[key]
 
1026
  m, u = model.load_state_dict(state_dict, strict=False)
1027
  print(f"### missing keys: {len(m)}; \n### unexpected keys: {len(u)};")
1028
  print(m)
1029
+
1030
  params = [p.numel() if "mamba" in n else 0 for n, p in model.named_parameters()]
1031
  print(f"### Mamba Parameters: {sum(params) / 1e6} M")
1032
 
1033
  params = [p.numel() if "attn1." in n else 0 for n, p in model.named_parameters()]
1034
  print(f"### attn1 Parameters: {sum(params) / 1e6} M")
1035
+
1036
  return model
1037
+
1038
  if __name__ == '__main__':
1039
  device = "cuda:0"
1040
  weight_dtype = torch.bfloat16
1041
  pretrained_model_name_or_path = "BestWishYsh/ConsisID-preview"
1042
+
1043
  transformer_additional_kwargs={
1044
  'torch_dtype': weight_dtype,
1045
  'revision': None,
 
1049
  'LFE_num_tokens': 32,
1050
  'LFE_output_dim': 768,
1051
  'LFE_heads': 12,
1052
+ 'cross_attn_interval': 2,
1053
  }
1054
 
1055
  transformer = ConsisIDTransformer3DModel.from_pretrained_cus(
 
1082
  timestep=timesteps,
1083
  image_rotary_emb=image_rotary_emb,
1084
  return_dict=False,
1085
+ id_vit_hidden=id_vit_hidden if id_vit_hidden is not None else None,
1086
  id_cond=id_cond if id_cond is not None else None,
1087
  )[0]
1088
+
1089
  print(model_output)