HarborYuan committed
Commit 502989e
Parent: a209a56

add rap_sam

This view is limited to 50 files because the commit contains too many changes; see the raw diff for the full list.

Files changed (50)
  1. .gitattributes +5 -4
  2. README.md +3 -5
  3. app/configs/rap_sam_r50_12e_adaptor.py +88 -0
  4. app/models/detectors/__init__.py +1 -0
  5. app/models/detectors/mask2former_vid.py +281 -0
  6. app/models/detectors/rapsam.py +66 -0
  7. app/models/heads/__init__.py +1 -0
  8. app/models/heads/mask2former_vid.py +616 -0
  9. app/models/heads/rapsam_head.py +227 -0
  10. app/models/heads/yoso_head.py +531 -0
  11. app/models/necks/__init__.py +1 -0
  12. app/models/necks/ramsam_neck.py +196 -0
  13. app/models/utils/__init__.py +3 -0
  14. app/models/utils/load_checkpoint.py +38 -0
  15. app/models/utils/mask_pool.py +27 -0
  16. app/models/utils/no_obj.py +1 -0
  17. app/models/utils/video_gt_preprocess.py +87 -0
  18. ext/meta/sam_meta.py +41 -0
  19. ext/open_clip/__init__.py +15 -0
  20. ext/open_clip/bpe_simple_vocab_16e6.txt.gz +3 -0
  21. ext/open_clip/coca_model.py +458 -0
  22. ext/open_clip/constants.py +2 -0
  23. ext/open_clip/factory.py +387 -0
  24. ext/open_clip/generation_utils.py +0 -0
  25. ext/open_clip/hf_configs.py +56 -0
  26. ext/open_clip/hf_model.py +193 -0
  27. ext/open_clip/loss.py +216 -0
  28. ext/open_clip/model.py +473 -0
  29. ext/open_clip/model_configs/EVA01-g-14-plus.json +18 -0
  30. ext/open_clip/model_configs/EVA01-g-14.json +18 -0
  31. ext/open_clip/model_configs/EVA02-B-16.json +18 -0
  32. ext/open_clip/model_configs/EVA02-E-14-plus.json +18 -0
  33. ext/open_clip/model_configs/EVA02-E-14.json +18 -0
  34. ext/open_clip/model_configs/EVA02-L-14-336.json +18 -0
  35. ext/open_clip/model_configs/EVA02-L-14.json +18 -0
  36. ext/open_clip/model_configs/RN101-quickgelu.json +22 -0
  37. ext/open_clip/model_configs/RN101.json +21 -0
  38. ext/open_clip/model_configs/RN50-quickgelu.json +22 -0
  39. ext/open_clip/model_configs/RN50.json +21 -0
  40. ext/open_clip/model_configs/RN50x16.json +21 -0
  41. ext/open_clip/model_configs/RN50x4.json +21 -0
  42. ext/open_clip/model_configs/RN50x64.json +21 -0
  43. ext/open_clip/model_configs/ViT-B-16-plus-240.json +16 -0
  44. ext/open_clip/model_configs/ViT-B-16-plus.json +16 -0
  45. ext/open_clip/model_configs/ViT-B-16.json +16 -0
  46. ext/open_clip/model_configs/ViT-B-32-plus-256.json +16 -0
  47. ext/open_clip/model_configs/ViT-B-32-quickgelu.json +17 -0
  48. ext/open_clip/model_configs/ViT-B-32.json +16 -0
  49. ext/open_clip/model_configs/ViT-H-14.json +17 -0
  50. ext/open_clip/model_configs/ViT-H-16.json +17 -0
.gitattributes CHANGED
@@ -17,10 +17,6 @@
  *.ot filter=lfs diff=lfs merge=lfs -text
  *.parquet filter=lfs diff=lfs merge=lfs -text
  *.pb filter=lfs diff=lfs merge=lfs -text
- *.pickle filter=lfs diff=lfs merge=lfs -text
- *.pkl filter=lfs diff=lfs merge=lfs -text
- *.pt filter=lfs diff=lfs merge=lfs -text
- *.pth filter=lfs diff=lfs merge=lfs -text
  *.rar filter=lfs diff=lfs merge=lfs -text
  *.safetensors filter=lfs diff=lfs merge=lfs -text
  saved_model/**/* filter=lfs diff=lfs merge=lfs -text
@@ -33,3 +29,8 @@ saved_model/**/* filter=lfs diff=lfs merge=lfs -text
  *.zip filter=lfs diff=lfs merge=lfs -text
  *.zst filter=lfs diff=lfs merge=lfs -text
  *tfevents* filter=lfs diff=lfs merge=lfs -text
+ *.pickle filter=lfs diff=lfs merge=lfs -text
+ *.pkl filter=lfs diff=lfs merge=lfs -text
+ *.pt filter=lfs diff=lfs merge=lfs -text
+ *.pth filter=lfs diff=lfs merge=lfs -text
+ *.jpg filter=lfs diff=lfs merge=lfs -text
README.md CHANGED
@@ -4,10 +4,8 @@ emoji: 🌍
  colorFrom: purple
  colorTo: purple
  sdk: gradio
- sdk_version: 4.13.0
- app_file: app.py
+ sdk_version: 4.7.1
+ app_file: main.py
  pinned: false
- license: mit
+ python_version: 3.10
  ---
-
- Check out the configuration reference at https://huggingface.co/docs/hub/spaces-config-reference
app/configs/rap_sam_r50_12e_adaptor.py ADDED
@@ -0,0 +1,88 @@
+ from mmdet.models import ResNet, MaskFormerFusionHead, CrossEntropyLoss, DiceLoss
+
+ from app.models.detectors import YOSOVideoSam
+ from app.models.heads import RapSAMVideoHead
+ from app.models.necks import YOSONeck
+
+ num_things_classes = 80
+ num_stuff_classes = 53
+ ov_model_name = 'convnext_large_d_320'
+ ov_datasets_name = 'CocoPanopticOVDataset'
+ num_classes = num_things_classes + num_stuff_classes
+ model = dict(
+     type=YOSOVideoSam,
+     data_preprocessor=None,
+     backbone=dict(
+         type=ResNet,
+         depth=50,
+         num_stages=4,
+         out_indices=(0, 1, 2, 3),
+         frozen_stages=-1,
+         norm_cfg=dict(type='BN', requires_grad=True),
+         norm_eval=True,
+         init_cfg=None,
+     ),
+     neck=dict(
+         type=YOSONeck,
+         agg_dim=128,
+         hidden_dim=256,
+         backbone_shape=[256, 512, 1024, 2048],
+     ),
+     panoptic_head=dict(
+         type=RapSAMVideoHead,
+         prompt_with_kernel_updator=False,
+         panoptic_with_kernel_updator=True,
+         use_adaptor=True,
+         use_kernel_updator=True,
+         sphere_cls=True,
+         ov_classifier_name=f'{ov_model_name}_{ov_datasets_name}',
+         num_stages=3,
+         feat_channels=256,
+         num_things_classes=num_things_classes,
+         num_stuff_classes=num_stuff_classes,
+         num_queries=100,
+         loss_cls=dict(
+             type=CrossEntropyLoss,
+             use_sigmoid=False,
+             loss_weight=2.0,
+             reduction='mean',
+             class_weight=[1.0] * num_classes + [0.1]),
+         loss_mask=dict(
+             type=CrossEntropyLoss,
+             use_sigmoid=True,
+             reduction='mean',
+             loss_weight=5.0),
+         loss_dice=dict(
+             type=DiceLoss,
+             use_sigmoid=True,
+             activate=True,
+             reduction='mean',
+             naive_dice=True,
+             eps=1.0,
+             loss_weight=5.0)
+     ),
+     panoptic_fusion_head=dict(
+         type=MaskFormerFusionHead,
+         num_things_classes=num_things_classes,
+         num_stuff_classes=num_stuff_classes,
+         loss_panoptic=None,
+         init_cfg=None
+     ),
+     train_cfg=None,
+     test_cfg=dict(
+         panoptic_on=True,
+         # For now, the dataset does not support
+         # evaluating semantic segmentation metric.
+         semantic_on=False,
+         instance_on=True,
+         # max_per_image is for instance segmentation.
+         max_per_image=100,
+         iou_thr=0.8,
+         # In Mask2Former's panoptic postprocessing,
+         # it will filter mask area where score is less than 0.5.
+         filter_low_score=True),
+     init_cfg=dict(
+         type='Pretrained',
+         checkpoint='models/rapsam_r50_12e.pth'
+     )
+ )
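Since the config above references model classes directly (`type=YOSOVideoSam`, `type=ResNet`, ...) rather than registry strings, it can be consumed by importing the module and handing the `model` dict to the registry. A minimal usage sketch, assuming mmdet/mmengine and this repo are importable and that the checkpoint and the `ov_classifier_name` embedding file exist under `models/`:

```python
# Minimal sketch (not part of the commit): build and initialise the model from
# the config dict above. MMEngine's build_from_cfg accepts class-valued `type`
# fields, so no string registry lookup is needed.
import importlib

from mmdet.registry import MODELS

cfg = importlib.import_module('app.configs.rap_sam_r50_12e_adaptor')
model = MODELS.build(cfg.model)  # YOSOVideoSam -> ResNet-50 backbone, YOSONeck, RapSAMVideoHead
model.init_weights()             # applies init_cfg, i.e. loads 'models/rapsam_r50_12e.pth'
model.eval()
```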
app/models/detectors/__init__.py ADDED
@@ -0,0 +1 @@
+ from .rapsam import YOSOVideoSam
app/models/detectors/mask2former_vid.py ADDED
@@ -0,0 +1,281 @@
1
+ # Copyright (c) OpenMMLab. All rights reserved.
2
+ from typing import Dict, List, Tuple
3
+
4
+ import torch
5
+ from mmengine.structures import InstanceData
6
+ from torch import Tensor
7
+ import torch.nn.functional as F
8
+
9
+ from mmdet.registry import MODELS
10
+ from mmdet.structures import SampleList, OptSampleList, TrackDataSample
11
+ from mmdet.utils import ConfigType, OptConfigType, OptMultiConfig
12
+ from mmdet.models.detectors.single_stage import SingleStageDetector
13
+
14
+ from app.models.utils import mask_pool
15
+
16
+
17
+ @MODELS.register_module()
18
+ class Mask2formerVideo(SingleStageDetector):
19
+ r"""Implementation of `Per-Pixel Classification is
20
+ NOT All You Need for Semantic Segmentation
21
+ <https://arxiv.org/pdf/2107.06278>`_."""
22
+ OVERLAPPING = None
23
+
24
+ def __init__(self,
25
+ backbone: ConfigType,
26
+ neck: OptConfigType = None,
27
+ panoptic_head: OptConfigType = None,
28
+ panoptic_fusion_head: OptConfigType = None,
29
+ train_cfg: OptConfigType = None,
30
+ test_cfg: OptConfigType = None,
31
+ data_preprocessor: OptConfigType = None,
32
+ inference_sam: bool = False,
33
+ init_cfg: OptMultiConfig = None
34
+ ):
35
+ super(SingleStageDetector, self).__init__(
36
+ data_preprocessor=data_preprocessor, init_cfg=init_cfg)
37
+ self.backbone = MODELS.build(backbone)
38
+ if neck is not None:
39
+ self.neck = MODELS.build(neck)
40
+
41
+ panoptic_head_ = panoptic_head.deepcopy()
42
+ panoptic_head_.update(train_cfg=train_cfg)
43
+ panoptic_head_.update(test_cfg=test_cfg)
44
+ self.panoptic_head = MODELS.build(panoptic_head_)
45
+
46
+ panoptic_fusion_head_ = panoptic_fusion_head.deepcopy()
47
+ panoptic_fusion_head_.update(test_cfg=test_cfg)
48
+ self.panoptic_fusion_head = MODELS.build(panoptic_fusion_head_)
49
+
50
+ self.num_things_classes = self.panoptic_head.num_things_classes
51
+ self.num_stuff_classes = self.panoptic_head.num_stuff_classes
52
+ self.num_classes = self.panoptic_head.num_classes
53
+
54
+ self.train_cfg = train_cfg
55
+ self.test_cfg = test_cfg
56
+
57
+ self.alpha = 0.4
58
+ self.beta = 0.8
59
+
60
+ self.inference_sam = inference_sam
61
+
62
+ def predict(self,
63
+ batch_inputs: Tensor,
64
+ batch_data_samples: SampleList,
65
+ rescale: bool = True) -> SampleList:
66
+ """Predict results from a batch of inputs and data samples with post-
67
+ processing.
68
+
69
+ Args:
70
+ batch_inputs (Tensor): Inputs with shape (N, C, H, W).
71
+ batch_data_samples (List[:obj:`DetDataSample`]): The Data
72
+ Samples. It usually includes information such as
73
+ `gt_instance`, `gt_panoptic_seg` and `gt_sem_seg`.
74
+ rescale (bool): Whether to rescale the results.
75
+ Defaults to True.
76
+
77
+ Returns:
78
+ list[:obj:`DetDataSample`]: Detection results of the
79
+ input images. Each DetDataSample usually contain
80
+ 'pred_instances' and `pred_panoptic_seg`. And the
81
+ ``pred_instances`` usually contains following keys.
82
+
83
+ - scores (Tensor): Classification scores, has a shape
84
+ (num_instance, )
85
+ - labels (Tensor): Labels of bboxes, has a shape
86
+ (num_instances, ).
87
+ - bboxes (Tensor): Has a shape (num_instances, 4),
88
+ the last dimension 4 arrange as (x1, y1, x2, y2).
89
+ - masks (Tensor): Has a shape (num_instances, H, W).
90
+
91
+ And the ``pred_panoptic_seg`` contains the following key
92
+
93
+ - sem_seg (Tensor): panoptic segmentation mask, has a
94
+ shape (1, h, w).
95
+ """
96
+ if isinstance(batch_data_samples[0], TrackDataSample):
97
+ bs, num_frames, three, h, w = batch_inputs.shape
98
+ assert three == 3, "Only supporting images with 3 channels."
99
+ x = batch_inputs.reshape((bs * num_frames, three, h, w))
100
+ feats = self.extract_feat(x)
101
+ else:
102
+ num_frames = 0
103
+ bs = batch_inputs.shape[0]
104
+ feats = self.extract_feat(batch_inputs)
105
+
106
+ mask_cls_results, mask_pred_results, iou_results = self.panoptic_head.predict(feats, batch_data_samples)
107
+
108
+ if self.inference_sam:
109
+ for i, data_sample in enumerate(batch_data_samples):
110
+ meta = data_sample.metainfo
111
+ img_height, img_width = meta['img_shape'][:2]
112
+ mask_pred_result = mask_pred_results[i][:, :img_height, :img_width]
113
+ mask_pred_result = mask_pred_result.view(-1, img_height, img_width) > 0
114
+ all_pred_instances = InstanceData(masks=mask_pred_result)
115
+ batch_data_samples[i].pred_instances = all_pred_instances
116
+
117
+ return batch_data_samples
118
+
119
+ if self.OVERLAPPING is not None:
120
+ assert len(self.OVERLAPPING) == self.num_classes
121
+ mask_cls_results = self.open_voc_inference(feats, mask_cls_results, mask_pred_results)
122
+
123
+ if num_frames > 0:
124
+ for frame_id in range(num_frames):
125
+ results_list_img = self.panoptic_fusion_head.predict(
126
+ mask_cls_results,
127
+ mask_pred_results[:, :, frame_id],
128
+ [batch_data_samples[idx][frame_id] for idx in range(bs)],
129
+ rescale=rescale
130
+ )
131
+ _ = self.add_track_pred_to_datasample(
132
+ [batch_data_samples[idx][frame_id] for idx in range(bs)], results_list_img
133
+ )
134
+ results = batch_data_samples
135
+ else:
136
+ results_list = self.panoptic_fusion_head.predict(
137
+ mask_cls_results,
138
+ mask_pred_results,
139
+ batch_data_samples,
140
+ iou_results=iou_results,
141
+ rescale=rescale
142
+ )
143
+ results = self.add_pred_to_datasample(batch_data_samples, results_list)
144
+
145
+ return results
146
+
147
+ def add_pred_to_datasample(self, data_samples: SampleList,
148
+ results_list: List[dict]) -> SampleList:
149
+ """Add predictions to `DetDataSample`.
150
+
151
+ Args:
152
+ data_samples (list[:obj:`DetDataSample`], optional): A batch of
153
+ data samples that contain annotations and predictions.
154
+ results_list (List[dict]): Instance segmentation, semantic
155
+ segmentation and panoptic segmentation results.
156
+
157
+ Returns:
158
+ list[:obj:`DetDataSample`]: Detection results of the
159
+ input images. Each DetDataSample usually contain
160
+ 'pred_instances' and `pred_panoptic_seg`. And the
161
+ ``pred_instances`` usually contains following keys.
162
+
163
+ - scores (Tensor): Classification scores, has a shape
164
+ (num_instance, )
165
+ - labels (Tensor): Labels of bboxes, has a shape
166
+ (num_instances, ).
167
+ - bboxes (Tensor): Has a shape (num_instances, 4),
168
+ the last dimension 4 arrange as (x1, y1, x2, y2).
169
+ - masks (Tensor): Has a shape (num_instances, H, W).
170
+
171
+ And the ``pred_panoptic_seg`` contains the following key
172
+
173
+ - sem_seg (Tensor): panoptic segmentation mask, has a
174
+ shape (1, h, w).
175
+ """
176
+ for data_sample, pred_results in zip(data_samples, results_list):
177
+ if 'pan_results' in pred_results:
178
+ data_sample.pred_panoptic_seg = pred_results['pan_results']
179
+
180
+ if 'ins_results' in pred_results:
181
+ data_sample.pred_instances = pred_results['ins_results']
182
+
183
+ assert 'sem_results' not in pred_results
184
+
185
+ return data_samples
186
+
187
+ def add_track_pred_to_datasample(self, data_samples: SampleList, results_list: List[dict]) -> SampleList:
188
+ for data_sample, pred_results in zip(data_samples, results_list):
189
+ if 'pan_results' in pred_results:
190
+ assert self.num_stuff_classes > 0
191
+ data_sample.pred_track_panoptic_seg = pred_results['pan_results']
192
+
193
+ if 'ins_results' in pred_results:
194
+ bboxes = pred_results['ins_results']['bboxes']
195
+ labels = pred_results['ins_results']['labels']
196
+ track_ids = torch.arange(len(bboxes), dtype=labels.dtype, device=bboxes.device) + 1
197
+ pred_results['ins_results']['instances_id'] = track_ids
198
+ data_sample.pred_track_instances = pred_results['ins_results']
199
+
200
+ if 'pro_results' in pred_results:
201
+ data_sample.pred_track_proposal = pred_results['pro_results']
202
+
203
+ assert 'sem_results' not in pred_results
204
+
205
+ return data_samples
206
+
207
+ def _forward(
208
+ self,
209
+ batch_inputs: Tensor,
210
+ batch_data_samples: OptSampleList = None) -> Tuple[List[Tensor]]:
211
+ """Network forward process. Usually includes backbone, neck and head
212
+ forward without any post-processing.
213
+
214
+ Args:
215
+ batch_inputs (Tensor): Inputs with shape (N, C, H, W).
216
+ batch_data_samples (list[:obj:`DetDataSample`]): The batch
217
+ data samples. It usually includes information such
218
+ as `gt_instance` or `gt_panoptic_seg` or `gt_sem_seg`.
219
+
220
+ Returns:
221
+ tuple[List[Tensor]]: A tuple of features from ``panoptic_head``
222
+ forward.
223
+ """
224
+ if isinstance(batch_data_samples[0], TrackDataSample):
225
+ bs, num_frames, three, h, w = batch_inputs.shape
226
+ assert three == 3, "Only supporting images with 3 channels."
227
+
228
+ x = batch_inputs.reshape((bs * num_frames, three, h, w))
229
+ feats = self.extract_feat(x)
230
+ else:
231
+ feats = self.extract_feat(batch_inputs)
232
+ results = self.panoptic_head.forward(feats, batch_data_samples)
233
+ return results
234
+
235
+ def open_voc_inference(self, feats, mask_cls_results, mask_pred_results):
236
+ if len(mask_pred_results.shape) == 5:
237
+ batch_size = mask_cls_results.shape[0]
238
+ num_frames = mask_pred_results.shape[2]
239
+ mask_pred_results = mask_pred_results.permute(0, 2, 1, 3, 4).flatten(0, 1)
240
+ else:
241
+ batch_size = mask_cls_results.shape[0]
242
+ num_frames = 0
243
+ clip_feat = self.backbone.get_clip_feature(feats[-1])
244
+ clip_feat_mask = F.interpolate(
245
+ mask_pred_results,
246
+ size=clip_feat.shape[-2:],
247
+ mode='bilinear',
248
+ align_corners=False
249
+ )
250
+ if num_frames > 0:
251
+ clip_feat_mask = clip_feat_mask.unflatten(0, (batch_size, num_frames)).permute(0, 2, 1, 3, 4).flatten(2, 3)
252
+ clip_feat = clip_feat.unflatten(0, (batch_size, num_frames)).permute(0, 2, 1, 3, 4).flatten(2, 3)
253
+ instance_feat = mask_pool(clip_feat, clip_feat_mask)
254
+ instance_feat = self.backbone.forward_feat(instance_feat)
255
+ clip_logit = self.panoptic_head.forward_logit(instance_feat)
256
+ clip_logit = clip_logit[..., :-1]
257
+ query_logit = mask_cls_results[..., :-1]
258
+
259
+ clip_logit = clip_logit.softmax(-1)
260
+ query_logit = query_logit.softmax(-1)
261
+ overlapping_mask = torch.tensor(self.OVERLAPPING, dtype=torch.float32, device=clip_logit.device)
262
+
263
+ valid_masking = ((clip_feat_mask > 0).to(dtype=torch.float32).flatten(-2).sum(-1) > 0).to(
264
+ torch.float32)[..., None]
265
+ alpha = torch.ones_like(clip_logit) * self.alpha * valid_masking
266
+ beta = torch.ones_like(clip_logit) * self.beta * valid_masking
267
+
268
+ cls_logits_seen = (
269
+ (query_logit ** (1 - alpha) * clip_logit ** alpha).log()
270
+ * overlapping_mask
271
+ )
272
+ cls_logits_unseen = (
273
+ (query_logit ** (1 - beta) * clip_logit ** beta).log()
274
+ * (1 - overlapping_mask)
275
+ )
276
+ cls_results = cls_logits_seen + cls_logits_unseen
277
+ is_void_prob = F.softmax(mask_cls_results, dim=-1)[..., -1:]
278
+ mask_cls_results = torch.cat([
279
+ cls_results.softmax(-1) * (1.0 - is_void_prob), is_void_prob], dim=-1)
280
+ mask_cls_results = torch.log(mask_cls_results + 1e-8)
281
+ return mask_cls_results
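For reference, `open_voc_inference` above implements a geometric ensemble: the head's own class probabilities are mixed with CLIP probabilities pooled under each predicted mask, with exponent `alpha` (0.4) for classes in the training vocabulary and `beta` (0.8) for the remaining open-vocabulary classes, as flagged by the per-class `OVERLAPPING` mask. A toy sketch of that fusion on made-up probability vectors:

```python
# Toy illustration (not part of the commit) of the seen/unseen geometric
# ensemble used in open_voc_inference; the probabilities below are invented.
import torch

alpha, beta = 0.4, 0.8
query_p = torch.tensor([0.70, 0.20, 0.10])   # in-vocabulary head probabilities
clip_p = torch.tensor([0.30, 0.60, 0.10])    # CLIP probabilities pooled over the mask
seen = torch.tensor([1.0, 0.0, 1.0])         # 1 = class seen in training (OVERLAPPING)

fused_log = (query_p ** (1 - alpha) * clip_p ** alpha).log() * seen \
          + (query_p ** (1 - beta) * clip_p ** beta).log() * (1 - seen)
print(fused_log)  # the real code then renormalises and re-appends the void probability
```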
app/models/detectors/rapsam.py ADDED
@@ -0,0 +1,66 @@
+ from mmdet.registry import MODELS
+ from mmdet.utils import ConfigType, OptConfigType, OptMultiConfig
+
+ from mmdet.models.detectors import SingleStageDetector
+
+ from .mask2former_vid import Mask2formerVideo
+
+
+ @MODELS.register_module()
+ class YOSOVideoSam(Mask2formerVideo):
+     OVERLAPPING = None
+
+     def __init__(self,
+                  backbone: ConfigType,
+                  neck: OptConfigType = None,
+                  panoptic_head: OptConfigType = None,
+                  panoptic_fusion_head: OptConfigType = None,
+                  train_cfg: OptConfigType = None,
+                  test_cfg: OptConfigType = None,
+                  data_preprocessor: OptConfigType = None,
+                  inference_sam: bool = False,
+                  init_cfg: OptMultiConfig = None
+                  ):
+         super(SingleStageDetector, self).__init__(
+             data_preprocessor=data_preprocessor, init_cfg=init_cfg)
+         self.backbone = MODELS.build(backbone)
+         if neck is not None:
+             self.neck = MODELS.build(neck)
+
+         panoptic_head_ = panoptic_head.deepcopy()
+         panoptic_head_.update(train_cfg=train_cfg)
+         panoptic_head_.update(test_cfg=test_cfg)
+         self.panoptic_head = MODELS.build(panoptic_head_)
+
+         panoptic_fusion_head_ = panoptic_fusion_head.deepcopy()
+         panoptic_fusion_head_.update(test_cfg=test_cfg)
+         self.panoptic_fusion_head = MODELS.build(panoptic_fusion_head_)
+
+         self.num_things_classes = self.panoptic_head.num_things_classes
+         self.num_stuff_classes = self.panoptic_head.num_stuff_classes
+         self.num_classes = self.panoptic_head.num_classes
+
+         self.train_cfg = train_cfg
+         self.test_cfg = test_cfg
+
+         self.alpha = 0.4
+         self.beta = 0.8
+
+         self.inference_sam = inference_sam
+
+     def predict_with_point(self, x, batch_data_samples):
+         feats = self.extract_feat(x)
+         mask_cls_results, mask_pred_results, iou_results = self.panoptic_head.predict(feats, batch_data_samples)
+
+         if 'gt_instances_collected' not in batch_data_samples[0]:
+             results_list = self.panoptic_fusion_head.predict(
+                 mask_cls_results,
+                 mask_pred_results,
+                 batch_data_samples,
+                 iou_results=iou_results,
+                 rescale=False
+             )
+             mask_pred_results = results_list[0]['pan_results'].sem_seg[None]
+             mask_cls_results = mask_cls_results
+
+         return mask_pred_results.cpu().numpy(), mask_cls_results.cpu().numpy()
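`predict_with_point` has two return modes: with no `gt_instances_collected` on the sample it runs the fusion head and returns the panoptic `sem_seg`, otherwise it returns the raw per-prompt mask and class logits. A hypothetical sketch of attaching a single click prompt before calling it; the field names (`point_coords`, `pb_labels`) mirror what `prepare_for_dn_mo` in the head consumes, while the image size, coordinates and the point/box label convention are assumptions:

```python
# Hypothetical prompt-encoding sketch (not part of the commit).
import torch
from mmengine.structures import InstanceData
from mmdet.structures import DetDataSample

sample = DetDataSample(metainfo=dict(img_shape=(512, 512), batch_input_shape=(512, 512)))
click = torch.tensor([[256., 256., 256., 256.]])  # a point stored as a degenerate x1,y1,x2,y2 box
sample.gt_instances_collected = InstanceData(
    point_coords=click,
    pb_labels=torch.tensor([0]),  # assumed: 0 for point prompts, 1 for box prompts
)
# masks_np, cls_np = detector.predict_with_point(img_tensor, [sample])  # detector/img_tensor assumed
```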
app/models/heads/__init__.py ADDED
@@ -0,0 +1 @@
+ from .rapsam_head import RapSAMVideoHead
app/models/heads/mask2former_vid.py ADDED
@@ -0,0 +1,616 @@
1
+ # Copyright (c) OpenMMLab. All rights reserved.
2
+ import copy
3
+ import os
4
+ from typing import List, Optional, Tuple
5
+
6
+ import torch
7
+ import torch.nn as nn
8
+ import torch.nn.functional as F
9
+ import torch.distributed as dist
10
+ from mmcv.cnn import Conv2d
11
+ from mmdet.models import Mask2FormerTransformerDecoder
12
+ from mmengine.dist import get_dist_info
13
+ from mmengine.model import caffe2_xavier_init, ModuleList
14
+ from torch import Tensor
15
+ from mmdet.models.layers import MLP, inverse_sigmoid
16
+ from mmdet.models.layers import coordinate_to_encoding
17
+ from mmdet.structures.bbox import bbox_xyxy_to_cxcywh
18
+
19
+ from mmdet.registry import MODELS, TASK_UTILS
20
+ from mmdet.structures import SampleList, TrackDataSample
21
+ from mmdet.utils import (ConfigType, OptConfigType, OptMultiConfig)
22
+ from mmdet.models.layers import SinePositionalEncoding3D
23
+ from mmdet.models.dense_heads.anchor_free_head import AnchorFreeHead
24
+ from mmcv.cnn.bricks.transformer import MultiheadAttention
25
+ from app.models.utils import mask_pool
26
+
27
+
28
+ @MODELS.register_module()
29
+ class Mask2FormerVideoHead(AnchorFreeHead):
30
+ """Implements the Mask2Former head.
31
+
32
+ See `Masked-attention Mask Transformer for Universal Image
33
+ Segmentation <https://arxiv.org/pdf/2112.01527>`_ for details.
34
+
35
+ Args:
36
+ in_channels (list[int]): Number of channels in the input feature map.
37
+ feat_channels (int): Number of channels for features.
38
+ out_channels (int): Number of channels for output.
39
+ num_things_classes (int): Number of things.
40
+ num_stuff_classes (int): Number of stuff.
41
+ num_queries (int): Number of query in Transformer decoder.
42
+ pixel_decoder (:obj:`ConfigDict` or dict): Config for pixel
43
+ decoder. Defaults to None.
44
+ enforce_decoder_input_project (bool, optional): Whether to add
45
+ a layer to change the embed_dim of transformer encoder in
46
+ pixel decoder to the embed_dim of transformer decoder.
47
+ Defaults to False.
48
+ transformer_decoder (:obj:`ConfigDict` or dict): Config for
49
+ transformer decoder. Defaults to None.
50
+ positional_encoding (:obj:`ConfigDict` or dict): Config for
51
+ transformer decoder position encoding. Defaults to
52
+ dict(num_feats=128, normalize=True).
53
+ loss_cls (:obj:`ConfigDict` or dict): Config of the classification
54
+ loss. Defaults to None.
55
+ loss_mask (:obj:`ConfigDict` or dict): Config of the mask loss.
56
+ Defaults to None.
57
+ loss_dice (:obj:`ConfigDict` or dict): Config of the dice loss.
58
+ Defaults to None.
59
+ train_cfg (:obj:`ConfigDict` or dict, optional): Training config of
60
+ Mask2Former head.
61
+ test_cfg (:obj:`ConfigDict` or dict, optional): Testing config of
62
+ Mask2Former head.
63
+ init_cfg (:obj:`ConfigDict` or dict or list[:obj:`ConfigDict` or \
64
+ dict], optional): Initialization config dict. Defaults to None.
65
+ """
66
+
67
+ def __init__(self,
68
+ in_channels: List[int],
69
+ feat_channels: int,
70
+ out_channels: int,
71
+ num_mask_tokens: int = 1,
72
+ num_things_classes: int = 80,
73
+ num_stuff_classes: int = 53,
74
+ num_queries: int = 100,
75
+ num_transformer_feat_level: int = 3,
76
+ pixel_decoder: ConfigType = ...,
77
+ enforce_decoder_input_project: bool = False,
78
+ transformer_decoder: ConfigType = ...,
79
+ positional_encoding: ConfigType = None,
80
+ loss_cls: ConfigType = None,
81
+ loss_mask: ConfigType = None,
82
+ loss_dice: ConfigType = None,
83
+ train_cfg: OptConfigType = None,
84
+ test_cfg: OptConfigType = None,
85
+ init_cfg: OptMultiConfig = None,
86
+ # ov configs
87
+ sphere_cls: bool = False,
88
+ ov_classifier_name: Optional[str] = None,
89
+ logit: Optional[int] = None,
90
+ use_adaptor = False,
91
+ **kwargs) -> None:
92
+ super(AnchorFreeHead, self).__init__(init_cfg=init_cfg)
93
+ self.use_adaptor = use_adaptor
94
+
95
+ self.num_mask_tokens = num_mask_tokens
96
+ self.mask_tokens = nn.Embedding(num_mask_tokens, feat_channels)
97
+ self.pb_embedding = nn.Embedding(2, feat_channels)
98
+ self.pos_linear = nn.Linear(2 * feat_channels, feat_channels)
99
+
100
+ self.num_things_classes = num_things_classes
101
+ self.num_stuff_classes = num_stuff_classes
102
+ self.num_classes = self.num_things_classes + self.num_stuff_classes
103
+ self.num_queries = num_queries
104
+ self.num_transformer_feat_level = num_transformer_feat_level
105
+ self.num_heads = transformer_decoder.layer_cfg.cross_attn_cfg.num_heads
106
+ self.num_transformer_decoder_layers = transformer_decoder.num_layers
107
+ # assert pixel_decoder.encoder.layer_cfg. \
108
+ # self_attn_cfg.num_levels == num_transformer_feat_level
109
+ pixel_decoder_ = copy.deepcopy(pixel_decoder)
110
+ pixel_decoder_.update(
111
+ in_channels=in_channels,
112
+ feat_channels=feat_channels,
113
+ out_channels=out_channels)
114
+ self.pixel_decoder = MODELS.build(pixel_decoder_)
115
+ self.transformer_decoder = Mask2FormerTransformerDecoder(
116
+ **transformer_decoder)
117
+ self.decoder_embed_dims = self.transformer_decoder.embed_dims
118
+
119
+ self.decoder_input_projs = ModuleList()
120
+ # from low resolution to high resolution
121
+ for _ in range(num_transformer_feat_level):
122
+ if (self.decoder_embed_dims != feat_channels
123
+ or enforce_decoder_input_project):
124
+ self.decoder_input_projs.append(
125
+ Conv2d(
126
+ feat_channels, self.decoder_embed_dims, kernel_size=1))
127
+ else:
128
+ self.decoder_input_projs.append(nn.Identity())
129
+ self.decoder_positional_encoding = SinePositionalEncoding3D(
130
+ **positional_encoding)
131
+ self.query_embed = nn.Embedding(self.num_queries, feat_channels)
132
+ self.query_feat = nn.Embedding(self.num_queries, feat_channels)
133
+ # from low resolution to high resolution
134
+ self.level_embed = nn.Embedding(self.num_transformer_feat_level,
135
+ feat_channels)
136
+
137
+ if not sphere_cls:
138
+ self.cls_embed = nn.Linear(feat_channels, self.num_classes + 1)
139
+ self.mask_embed = nn.Sequential(
140
+ nn.Linear(feat_channels, feat_channels), nn.ReLU(inplace=True),
141
+ nn.Linear(feat_channels, feat_channels), nn.ReLU(inplace=True),
142
+ nn.Linear(feat_channels, out_channels))
143
+
144
+ self.iou_embed = nn.Sequential(
145
+ nn.Linear(feat_channels, feat_channels), nn.ReLU(inplace=True),
146
+ nn.Linear(feat_channels, feat_channels), nn.ReLU(inplace=True),
147
+ nn.Linear(feat_channels, 1))
148
+
149
+ self.test_cfg = test_cfg
150
+ self.train_cfg = train_cfg
151
+ if train_cfg:
152
+ self.assigner = TASK_UTILS.build(self.train_cfg['assigner'])
153
+ self.sampler = TASK_UTILS.build(
154
+ self.train_cfg['sampler'], default_args=dict(context=self))
155
+ self.num_points = self.train_cfg.get('num_points', 12544)
156
+ self.oversample_ratio = self.train_cfg.get('oversample_ratio', 3.0)
157
+ self.importance_sample_ratio = self.train_cfg.get(
158
+ 'importance_sample_ratio', 0.75)
159
+
160
+ self.class_weight = loss_cls.class_weight
161
+ self.loss_cls = MODELS.build(loss_cls)
162
+ self.loss_mask = MODELS.build(loss_mask)
163
+ self.loss_dice = MODELS.build(loss_dice)
164
+
165
+ # prepare OV things
166
+ # OV cls embed
167
+ if sphere_cls:
168
+ rank, world_size = get_dist_info()
169
+ if ov_classifier_name is None:
170
+ _dim = 1024 # temporally hard code
171
+ cls_embed = torch.empty(self.num_classes, _dim)
172
+ torch.nn.init.orthogonal_(cls_embed)
173
+ cls_embed = cls_embed[:, None]
174
+ else:
175
+ # ov_path = os.path.join(os.path.expanduser('~/.cache/embd'), f"{ov_classifier_name}.pth")
176
+ ov_path = os.path.join('./models/', f"{ov_classifier_name}.pth")
177
+ cls_embed = torch.load(ov_path)
178
+ cls_embed_norm = cls_embed.norm(p=2, dim=-1)
179
+ assert torch.allclose(cls_embed_norm, torch.ones_like(cls_embed_norm))
180
+ if self.loss_cls and self.loss_cls.use_sigmoid:
181
+ pass
182
+ else:
183
+ _dim = cls_embed.size(2)
184
+ _prototypes = cls_embed.size(1)
185
+
186
+ if rank == 0:
187
+ back_token = torch.zeros(1, _dim, dtype=torch.float32, device='cuda')
188
+ # back_token = back_token / back_token.norm(p=2, dim=-1, keepdim=True)
189
+ else:
190
+ back_token = torch.empty(1, _dim, dtype=torch.float32, device='cuda')
191
+ if world_size > 1:
192
+ dist.broadcast(back_token, src=0)
193
+ back_token = back_token.to(device='cpu')
194
+ cls_embed = torch.cat([
195
+ cls_embed, back_token.repeat(_prototypes, 1)[None]
196
+ ], dim=0)
197
+ self.register_buffer('cls_embed', cls_embed.permute(2, 0, 1).contiguous(), persistent=False)
198
+
199
+ # cls embd proj
200
+ cls_embed_dim = self.cls_embed.size(0)
201
+ self.cls_proj = nn.Sequential(
202
+ nn.Linear(feat_channels, feat_channels), nn.ReLU(inplace=True),
203
+ nn.Linear(feat_channels, feat_channels), nn.ReLU(inplace=True),
204
+ nn.Linear(feat_channels, cls_embed_dim)
205
+ )
206
+
207
+ # Haobo Yuan:
208
+ # For the logit_scale, I refer to this issue.
209
+ # https://github.com/openai/CLIP/issues/46#issuecomment-945062212
210
+ # https://github.com/openai/CLIP/issues/46#issuecomment-782558799
211
+ # Based on my understanding, it is a mistake of CLIP.
212
+ # Because they mention that they refer to InstDisc (Wu, 2018) paper.
213
+ # InstDisc set a non-learnable temperature to np.log(1 / 0.07).
214
+ # 4.6052 is np.log(1 / 0.01)
215
+ # np.log(1 / 0.07) will be fast converged to np.log(1 / 0.01)
216
+ if logit is None:
217
+ logit_scale = torch.tensor(4.6052, dtype=torch.float32)
218
+ else:
219
+ logit_scale = torch.tensor(logit, dtype=torch.float32)
220
+ self.register_buffer('logit_scale', logit_scale, persistent=False)
221
+
222
+ # Mask Pooling
223
+ self.mask_pooling = mask_pool
224
+ self.mask_pooling_proj = nn.Sequential(
225
+ nn.LayerNorm(feat_channels),
226
+ nn.Linear(feat_channels, feat_channels)
227
+ )
228
+
229
+ if use_adaptor:
230
+ cross_attn_cfg = dict(embed_dims=256, batch_first=True, num_heads=8)
231
+ self.panoptic_attn = MultiheadAttention(**cross_attn_cfg)
232
+ self.panoptic_norm = nn.LayerNorm(256)
233
+ if sphere_cls:
234
+ cls_embed_dim = self.cls_embed.size(0)
235
+ self.panoptic_cls = nn.Sequential(
236
+ nn.Linear(feat_channels, cls_embed_dim)
237
+ )
238
+ else:
239
+ raise NotImplementedError
240
+ self.prompt_attn = MultiheadAttention(**cross_attn_cfg)
241
+ self.prompt_norm = nn.LayerNorm(256)
242
+ self.prompt_iou = nn.Linear(256, 1)
243
+
244
+ def init_weights(self) -> None:
245
+ for m in self.decoder_input_projs:
246
+ if isinstance(m, Conv2d):
247
+ caffe2_xavier_init(m, bias=0)
248
+
249
+ self.pixel_decoder.init_weights()
250
+
251
+ for p in self.transformer_decoder.parameters():
252
+ if p.dim() > 1:
253
+ nn.init.xavier_normal_(p)
254
+
255
+ def forward_logit(self, cls_embd):
256
+ cls_pred = torch.einsum('bnc,ckp->bnkp', F.normalize(cls_embd, dim=-1), self.cls_embed)
257
+ cls_pred = cls_pred.max(-1).values
258
+ cls_pred = self.logit_scale.exp() * cls_pred
259
+ return cls_pred
260
+
261
+ def _forward_head(self, decoder_out: Tensor, mask_feature: Tensor,
262
+ attn_mask_target_size: Tuple[int, int],
263
+ num_frames: int = 0) -> Tuple[Tensor]:
264
+ """Forward for head part which is called after every decoder layer.
265
+
266
+ Args:
267
+ decoder_out (Tensor): in shape (batch_size, num_queries, c).
268
+ mask_feature (Tensor): in shape (batch_size, c, h, w).
269
+ attn_mask_target_size (tuple[int, int]): target attention
270
+ mask size.
271
+
272
+ Returns:
273
+ tuple: A tuple contain three elements.
274
+
275
+ - cls_pred (Tensor): Classification scores in shape \
276
+ (batch_size, num_queries, cls_out_channels). \
277
+ Note `cls_out_channels` should include background.
278
+ - mask_pred (Tensor): Mask scores in shape \
279
+ (batch_size, num_queries,h, w).
280
+ - attn_mask (Tensor): Attention mask in shape \
281
+ (batch_size * num_heads, num_queries, h, w).
282
+ - num_frames: How many frames are there in video.
283
+ """
284
+ decoder_out = self.transformer_decoder.post_norm(decoder_out)
285
+ # shape (num_queries, batch_size, c)
286
+ if isinstance(self.cls_embed, nn.Module):
287
+ cls_pred = self.cls_embed(decoder_out)
288
+ # shape (num_queries, batch_size, c)
289
+ mask_embed = self.mask_embed(decoder_out)
290
+ # shape (num_queries, batch_size, h, w)
291
+ mask_pred = torch.einsum('bqc,bchw->bqhw', mask_embed, mask_feature)
292
+
293
+ if not isinstance(self.cls_embed, nn.Module):
294
+ maskpool_embd = self.mask_pooling(x=mask_feature, mask=mask_pred.detach())
295
+ maskpool_embd = self.mask_pooling_proj(maskpool_embd)
296
+ cls_embd = self.cls_proj(maskpool_embd + decoder_out)
297
+ cls_pred = self.forward_logit(cls_embd)
298
+
299
+ iou_pred = self.iou_embed(decoder_out)
300
+
301
+ if num_frames > 0:
302
+ assert len(mask_pred.shape) == 4
303
+ assert mask_pred.shape[2] % num_frames == 0
304
+ frame_h = mask_pred.shape[2] // num_frames
305
+ num_q = mask_pred.shape[1]
306
+ _mask_pred = mask_pred.unflatten(-2, (num_frames, frame_h)).flatten(1, 2)
307
+ attn_mask = F.interpolate(
308
+ _mask_pred,
309
+ attn_mask_target_size,
310
+ mode='bilinear',
311
+ align_corners=False)
312
+ attn_mask = attn_mask.unflatten(1, (num_q, num_frames)).flatten(2, 3)
313
+ else:
314
+ attn_mask = F.interpolate(
315
+ mask_pred,
316
+ attn_mask_target_size,
317
+ mode='bilinear',
318
+ align_corners=False)
319
+ # shape (num_queries, batch_size, h, w) ->
320
+ # (batch_size * num_head, num_queries, h, w)
321
+ attn_mask = attn_mask.flatten(2).unsqueeze(1).repeat(
322
+ (1, self.num_heads, 1, 1)).flatten(0, 1)
323
+ attn_mask = attn_mask.sigmoid() < 0.5
324
+ attn_mask = attn_mask.detach()
325
+
326
+ return cls_pred, mask_pred, iou_pred, attn_mask
327
+
328
+ def forward(self, x: List[Tensor], batch_data_samples: SampleList) -> Tuple[List[Tensor]]:
329
+ """Forward function.
330
+
331
+ Args:
332
+ x (list[Tensor]): Multi scale Features from the
333
+ upstream network, each is a 4D-tensor.
334
+ batch_data_samples (List[:obj:`DetDataSample`]): The Data
335
+ Samples. It usually includes information such as
336
+ `gt_instance`, `gt_panoptic_seg` and `gt_sem_seg`.
337
+
338
+ Returns:
339
+ tuple[list[Tensor]]: A tuple contains two elements.
340
+
341
+ - cls_pred_list (list[Tensor]): Classification logits \
342
+ for each decoder layer. Each is a 3D-tensor with shape \
343
+ (batch_size, num_queries, cls_out_channels). \
344
+ Note `cls_out_channels` should include background.
345
+ - mask_pred_list (list[Tensor]): Mask logits for each \
346
+ decoder layer. Each with shape (batch_size, num_queries, \
347
+ h, w).
348
+ """
349
+ batch_img_metas = []
350
+ if isinstance(batch_data_samples[0], TrackDataSample):
351
+ for track_sample in batch_data_samples:
352
+ cur_list = []
353
+ for det_sample in track_sample:
354
+ cur_list.append(det_sample.metainfo)
355
+ batch_img_metas.append(cur_list)
356
+ num_frames = len(batch_img_metas[0])
357
+ else:
358
+ for data_sample in batch_data_samples:
359
+ batch_img_metas.append(data_sample.metainfo)
360
+ num_frames = 0
361
+ batch_size = len(batch_img_metas)
362
+ #(bs_nf, c, h,w)
363
+ mask_features, multi_scale_memorys = self.pixel_decoder(x)
364
+ if num_frames > 0:
365
+ mask_features = mask_features.unflatten(0, (batch_size, num_frames))
366
+ mask_features = mask_features.transpose(1, 2).flatten(2, 3) #(bs, c, nf*h,w)
367
+ decoder_inputs = []
368
+ decoder_positional_encodings = []
369
+ for i in range(self.num_transformer_feat_level):
370
+ decoder_input = self.decoder_input_projs[i](multi_scale_memorys[i]) #(bs_nf, c, h,w)
371
+ decoder_input = decoder_input.flatten(2).permute(0, 2, 1) #(bs_nf,h*w, c)
372
+ if num_frames > 0:
373
+ decoder_input = decoder_input.unflatten(0, (batch_size, num_frames))
374
+ decoder_input = decoder_input.flatten(1, 2) #(bs, nf*h*w, c)
375
+ level_embed = self.level_embed.weight[i].view(1, 1, -1)
376
+ decoder_input = decoder_input + level_embed
377
+
378
+ # shape (batch_size, c, h, w) -> (batch_size, h*w, c)
379
+ num_frames_real = 1 if num_frames == 0 else num_frames
380
+ mask = decoder_input.new_zeros(
381
+ (batch_size, num_frames_real) + multi_scale_memorys[i].shape[-2:],
382
+ dtype=torch.bool)
383
+ decoder_positional_encoding = self.decoder_positional_encoding(
384
+ mask)
385
+ decoder_positional_encoding = decoder_positional_encoding.transpose(
386
+ 1, 2).flatten(2).permute(0, 2, 1)
387
+ decoder_inputs.append(decoder_input) #(bs, nf*h*w, c)
388
+ decoder_positional_encodings.append(decoder_positional_encoding) #(bs, nf*h*w, c)
389
+
390
+ if self.prompt_training:
391
+ query_feat, input_query_bbox, self_attn_mask, _ = self.prepare_for_dn_mo(
392
+ batch_data_samples)
393
+ query_embed = coordinate_to_encoding(input_query_bbox.sigmoid())
394
+ query_embed = self.pos_linear(query_embed)
395
+ else:
396
+ query_feat = self.query_feat.weight.unsqueeze(0).repeat((batch_size, 1, 1))
397
+ query_embed = self.query_embed.weight.unsqueeze(0).repeat((batch_size, 1, 1))
398
+ self_attn_mask = None
399
+
400
+ cls_pred_list = []
401
+ mask_pred_list = []
402
+ iou_pred_list = []
403
+ cls_pred, mask_pred, iou_pred, attn_mask = self._forward_head(
404
+ query_feat, mask_features, multi_scale_memorys[0].shape[-2:],
405
+ num_frames=num_frames
406
+ )
407
+ cls_pred_list.append(cls_pred)
408
+ iou_pred_list.append(iou_pred)
409
+ if num_frames > 0: #(bs, 100, nf*h, w)-->(bs, 100, nf, h, w)
410
+ mask_pred = mask_pred.unflatten(2, (num_frames, -1))
411
+ mask_pred_list.append(mask_pred)
412
+
413
+ for i in range(self.num_transformer_decoder_layers):
414
+ level_idx = i % self.num_transformer_feat_level
415
+ # if a mask is all True(all background), then set it all False.
416
+ attn_mask[torch.where(attn_mask.sum(-1) == attn_mask.shape[-1])] = False
417
+
418
+ # cross_attn + self_attn
419
+ layer = self.transformer_decoder.layers[i]
420
+ query_feat = layer(
421
+ query=query_feat, #(bs, 100, c)
422
+ key=decoder_inputs[level_idx], #(bs, nf*h*w, c)
423
+ value=decoder_inputs[level_idx],
424
+ query_pos=query_embed,
425
+ key_pos=decoder_positional_encodings[level_idx],
426
+ cross_attn_mask=attn_mask,
427
+ self_attn_mask=self_attn_mask,
428
+ query_key_padding_mask=None,
429
+ # here we do not apply masking on padded region
430
+ key_padding_mask=None)
431
+ cls_pred, mask_pred, iou_pred, attn_mask = self._forward_head(
432
+ query_feat, mask_features, multi_scale_memorys[(i + 1) % self.num_transformer_feat_level].shape[-2:],
433
+ num_frames=num_frames
434
+ )
435
+
436
+ cls_pred_list.append(cls_pred)
437
+ iou_pred_list.append(iou_pred)
438
+ if num_frames > 0:
439
+ mask_pred = mask_pred.unflatten(2, (num_frames, -1))
440
+ mask_pred_list.append(mask_pred)
441
+
442
+ if self.use_adaptor:
443
+ keys = mask_features.flatten(2).transpose(1, 2).contiguous()
444
+ h, w = mask_features.shape[-2] // num_frames_real, mask_features.shape[-1]
445
+ mask = decoder_input.new_zeros((batch_size, num_frames_real, h, w), dtype=torch.bool)
446
+ key_pos = self.decoder_positional_encoding(mask)
447
+ key_pos = key_pos.transpose(1, 2).flatten(2).permute(0, 2, 1)
448
+ if not self.prompt_training:
449
+ object_kernels = self.panoptic_attn(query_feat, keys, key_pos=key_pos, query_pos=query_embed)
450
+ object_kernels = self.panoptic_norm(object_kernels)
451
+ mask_preds = torch.einsum('bnc,bchw->bnhw', object_kernels, mask_features)
452
+
453
+ cls_embd = self.panoptic_cls(object_kernels)
454
+ cls_scores = torch.einsum('bnc,ckp->bnkp', F.normalize(cls_embd, dim=-1), self.cls_embed)
455
+ cls_scores = cls_scores.max(-1).values
456
+ cls_scores = self.logit_scale.exp() * cls_scores
457
+
458
+ if num_frames > 0:
459
+ mask_pred_list.append(mask_preds.unflatten(2, (num_frames, -1)))
460
+ else:
461
+ mask_pred_list.append(mask_preds)
462
+ cls_pred_list.append(cls_scores)
463
+ iou_pred_list.append(iou_pred_list[-1])
464
+ else:
465
+ object_kernels = self.prompt_attn(query_feat, keys, key_pos=key_pos, query_pos=query_embed)
466
+ object_kernels = self.prompt_norm(object_kernels)
467
+ iou_preds = self.prompt_iou(object_kernels)
468
+ mask_preds = torch.einsum('bnc,bchw->bnhw', object_kernels, mask_features)
469
+
470
+ if num_frames > 0:
471
+ mask_pred_list.append(mask_preds.unflatten(2, (num_frames, -1)))
472
+ else:
473
+ mask_pred_list.append(mask_preds)
474
+ cls_pred_list.append(cls_pred_list[-1])
475
+ iou_pred_list.append(iou_preds)
476
+
477
+ return cls_pred_list, mask_pred_list, iou_pred_list, query_feat
478
+
479
+ def predict(self, x: Tuple[Tensor],
480
+ batch_data_samples: SampleList,
481
+ return_query=False,
482
+ ) -> Tuple[Tensor, ...]:
483
+ """Test without augmentaton.
484
+
485
+ Args:
486
+ return_query:
487
+ x (tuple[Tensor]): Multi-level features from the
488
+ upstream network, each is a 4D-tensor.
489
+ batch_data_samples (List[:obj:`DetDataSample`]): The Data
490
+ Samples. It usually includes information such as
491
+ `gt_instance`, `gt_panoptic_seg` and `gt_sem_seg`.
492
+
493
+ Returns:
494
+ tuple[Tensor]: A tuple contains two tensors.
495
+
496
+ - mask_cls_results (Tensor): Mask classification logits,\
497
+ shape (batch_size, num_queries, cls_out_channels).
498
+ Note `cls_out_channels` should include background.
499
+ - mask_pred_results (Tensor): Mask logits, shape \
500
+ (batch_size, num_queries, h, w).
501
+ """
502
+ self.prompt_training = False
503
+ data_sample = batch_data_samples[0]
504
+ if isinstance(data_sample, TrackDataSample):
505
+ img_shape = data_sample[0].metainfo['batch_input_shape']
506
+ num_frames = len(data_sample)
507
+ else:
508
+ if 'gt_instances_collected' in data_sample:
509
+ self.prompt_training = True
510
+ img_shape = data_sample.metainfo['batch_input_shape']
511
+ num_frames = 0
512
+ all_cls_scores, all_mask_preds, all_iou_preds, query_feat = self(x, batch_data_samples)
513
+ mask_cls_results = all_cls_scores[-1]
514
+ mask_pred_results = all_mask_preds[-1]
515
+ iou_results = all_iou_preds[-1]
516
+
517
+ if num_frames > 0:
518
+ mask_pred_results = mask_pred_results.flatten(1, 2)
519
+ mask_pred_results = F.interpolate(
520
+ mask_pred_results,
521
+ size=(img_shape[0], img_shape[1]),
522
+ mode='bilinear',
523
+ align_corners=False)
524
+ if num_frames > 0:
525
+ num_queries = mask_cls_results.shape[1]
526
+ mask_pred_results = mask_pred_results.unflatten(1, (num_queries, num_frames))
527
+
528
+ if return_query:
529
+ return mask_cls_results, mask_pred_results, query_feat, iou_results
530
+ else:
531
+ return mask_cls_results, mask_pred_results, iou_results
532
+
533
+ def prepare_for_dn_mo(self, batch_data_samples):
534
+ scalar, noise_scale = 100, 0.4
535
+ gt_instances = [t.gt_instances_collected for t in batch_data_samples]
536
+
537
+ point_coords = torch.stack([inst.point_coords for inst in gt_instances])
538
+ pb_labels = torch.stack([inst['pb_labels'] for inst in gt_instances])
539
+ labels = torch.zeros_like(pb_labels).long()
540
+
541
+ boxes = point_coords # + boxes
542
+
543
+ factors = []
544
+ for i, data_sample in enumerate(batch_data_samples):
545
+ h, w, = data_sample.metainfo['img_shape']
546
+ factor = boxes[i].new_tensor([w, h, w, h]).unsqueeze(0).repeat(boxes[i].size(0), 1)
547
+ factors.append(factor)
548
+ factors = torch.stack(factors, 0)
549
+
550
+ boxes = bbox_xyxy_to_cxcywh(boxes / factors) # xyxy / factor or xywh / factor ????
551
+ # box_start = [t['box_start'] for t in targets]
552
+ box_start = [len(point) for point in point_coords]
553
+
554
+ known_labels = labels
555
+ known_pb_labels = pb_labels
556
+ known_bboxs = boxes
557
+
558
+ known_labels_expaned = known_labels.clone()
559
+ known_pb_labels_expaned = known_pb_labels.clone()
560
+ known_bbox_expand = known_bboxs.clone()
561
+
562
+ if noise_scale > 0 and self.training:
563
+ diff = torch.zeros_like(known_bbox_expand)
564
+ diff[:, :, :2] = known_bbox_expand[:, :, 2:] / 2
565
+ diff[:, :, 2:] = known_bbox_expand[:, :, 2:]
566
+ # add very small noise to input points; no box
567
+ sc = 0.01
568
+ for i, st in enumerate(box_start):
569
+ diff[i, :st] = diff[i, :st] * sc
570
+ known_bbox_expand += torch.mul(
571
+ (torch.rand_like(known_bbox_expand) * 2 - 1.0),
572
+ diff) * noise_scale
573
+
574
+ known_bbox_expand = known_bbox_expand.clamp(min=0.0, max=1.0)
575
+
576
+ input_label_embed = self.pb_embedding(known_pb_labels_expaned)
577
+
578
+ input_bbox_embed = inverse_sigmoid(known_bbox_expand)
579
+
580
+ input_label_embed = input_label_embed.repeat_interleave(
581
+ self.num_mask_tokens,
582
+ 1) + self.mask_tokens.weight.unsqueeze(0).repeat(
583
+ input_label_embed.shape[0], input_label_embed.shape[1], 1)
584
+ input_bbox_embed = input_bbox_embed.repeat_interleave(
585
+ self.num_mask_tokens, 1)
586
+
587
+ single_pad = self.num_mask_tokens
588
+
589
+ # NOTE scalar is modified to 100, each click cannot see each other
590
+ scalar = int(input_label_embed.shape[1] / self.num_mask_tokens)
591
+
592
+ pad_size = input_label_embed.shape[1]
593
+
594
+ if input_label_embed.shape[1] > 0:
595
+ input_query_label = input_label_embed
596
+ input_query_bbox = input_bbox_embed
597
+
598
+ tgt_size = pad_size
599
+ attn_mask = torch.ones(tgt_size, tgt_size).to('cuda') < 0
600
+ # match query cannot see the reconstruct
601
+ attn_mask[pad_size:, :pad_size] = True
602
+ # reconstruct cannot see each other
603
+ for i in range(scalar):
604
+ if i == 0:
605
+ attn_mask[single_pad * i:single_pad * (i + 1), single_pad * (i + 1):pad_size] = True
606
+ if i == scalar - 1:
607
+ attn_mask[single_pad * i:single_pad * (i + 1), :single_pad * i] = True
608
+ else:
609
+ attn_mask[single_pad * i:single_pad * (i + 1), single_pad * (i + 1):pad_size] = True
610
+ attn_mask[single_pad * i:single_pad * (i + 1), :single_pad * i] = True
611
+ mask_dict = {
612
+ 'known_lbs_bboxes': (known_labels, known_bboxs),
613
+ 'pad_size': pad_size,
614
+ 'scalar': scalar,
615
+ }
616
+ return input_query_label, input_query_bbox, attn_mask, mask_dict
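Two details of the head above worth calling out: `forward_logit` classifies each query by cosine similarity against a frozen bank of class embeddings, scaled by `exp(logit_scale)`, and the default `logit_scale` of 4.6052 is ln(1/0.01), as the inline comment explains. A minimal sketch with invented shapes (80 + 53 classes plus one background row, single prototype per class):

```python
# Minimal sketch (not part of the commit) of the cosine classifier in
# forward_logit; the embedding bank here is random, purely for shape checking.
import torch
import torch.nn.functional as F

logit_scale = torch.tensor(4.6052)
print(torch.log(torch.tensor(1 / 0.01)))  # ~4.6052, matching the comment in the head

cls_embd = torch.randn(2, 100, 1024)                      # (batch, num_queries, dim)
cls_bank = F.normalize(torch.randn(1024, 134, 1), dim=0)  # (dim, num_classes + 1, prototypes)
cls_pred = torch.einsum('bnc,ckp->bnkp', F.normalize(cls_embd, dim=-1), cls_bank)
cls_pred = logit_scale.exp() * cls_pred.max(-1).values    # (2, 100, 134) class logits
```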
app/models/heads/rapsam_head.py ADDED
@@ -0,0 +1,227 @@
1
+ # Copyright (c) OpenMMLab. All rights reserved.
2
+ from typing import List, Tuple
3
+
4
+ import torch
5
+ import torch.nn as nn
6
+ import torch.nn.functional as F
7
+ from torch import Tensor
8
+ from mmdet.models.layers import coordinate_to_encoding
9
+
10
+ from mmdet.registry import MODELS, TASK_UTILS
11
+ from mmdet.structures import SampleList, TrackDataSample
12
+ from mmdet.utils import (ConfigType, OptConfigType, OptMultiConfig)
13
+ from mmdet.models.dense_heads.anchor_free_head import AnchorFreeHead
14
+
15
+ from mmcv.cnn.bricks.transformer import MultiheadAttention
16
+
17
+ from .mask2former_vid import Mask2FormerVideoHead
18
+ from .yoso_head import CrossAttenHead, KernelUpdator
19
+
20
+ @MODELS.register_module()
21
+ class RapSAMVideoHead(Mask2FormerVideoHead):
22
+
23
+ def __init__(self,
24
+ frozen_head=False,
25
+ frozen_pred=False,
26
+ use_adaptor=False,
27
+ prompt_with_kernel_updator=False,
28
+ panoptic_with_kernel_updator=False,
29
+ num_mask_tokens = 1,
30
+ num_stages = 3,
31
+ use_kernel_updator=False,
32
+ sphere_cls = False,
33
+ ov_classifier_name = None,
34
+ temperature=0.1,
35
+ feat_channels=256,
36
+ num_things_classes: int = 80,
37
+ num_stuff_classes: int = 53,
38
+ num_queries: int = 100,
39
+ loss_cls: ConfigType = None,
40
+ loss_mask: ConfigType = None,
41
+ loss_dice: ConfigType = None,
42
+ train_cfg: OptConfigType = None,
43
+ test_cfg: OptConfigType = None,
44
+ init_cfg: OptMultiConfig = None,
45
+ matching_whole_map: bool = False,
46
+ enable_box_query: bool = False,
47
+ **kwargs) -> None:
48
+ super(AnchorFreeHead, self).__init__(init_cfg=init_cfg)
49
+ self.prompt_with_kernel_updator = prompt_with_kernel_updator
50
+ self.panoptic_with_kernel_updator = panoptic_with_kernel_updator
51
+ self.use_adaptor = use_adaptor
52
+
53
+ self.num_mask_tokens = num_mask_tokens
54
+ self.mask_tokens = nn.Embedding(num_mask_tokens, feat_channels)
55
+ self.pb_embedding = nn.Embedding(2, feat_channels)
56
+ self.pos_linear = nn.Linear(2 * feat_channels, feat_channels)
57
+
58
+ self.matching_whole_map = matching_whole_map
59
+ self.enable_box_query = enable_box_query
60
+
61
+ self.num_things_classes = num_things_classes
62
+ self.num_stuff_classes = num_stuff_classes
63
+ self.num_classes = self.num_things_classes + self.num_stuff_classes
64
+ self.num_queries = num_queries
65
+ self.feat_channels = feat_channels
66
+ self.num_stages = num_stages
67
+ self.kernels = nn.Embedding(self.num_queries, feat_channels)
68
+ self.mask_heads = nn.ModuleList()
69
+ for _ in range(self.num_stages):
70
+ self.mask_heads.append(CrossAttenHead(
71
+ self.num_classes, self.feat_channels, self.num_queries,
72
+ use_kernel_updator=use_kernel_updator,
73
+ frozen_head=frozen_head, frozen_pred=frozen_pred,
74
+ sphere_cls=sphere_cls,
75
+ ov_classifier_name=ov_classifier_name, with_iou_pred=True))
76
+ self.temperature = temperature
77
+
78
+ if use_adaptor:
79
+ cross_attn_cfg = dict(embed_dims=256, batch_first=True, num_heads=8)
80
+ if self.panoptic_with_kernel_updator:
81
+ self.panoptic_attn = KernelUpdator(feat_channels=256)
82
+ self.panoptic_norm = nn.Identity()
83
+ if sphere_cls:
84
+ cls_embed_dim = self.mask_heads[0].fc_cls.size(0)
85
+ self.panoptic_cls = nn.Sequential(
86
+ nn.Linear(feat_channels, cls_embed_dim)
87
+ )
88
+ else:
89
+ raise NotImplementedError
90
+ self.panoptic_cls = nn.Linear(256, self.num_classes+1)
91
+ else:
92
+ self.panoptic_attn = MultiheadAttention(**cross_attn_cfg)
93
+ self.panoptic_norm = nn.LayerNorm(256)
94
+ if sphere_cls:
95
+ cls_embed_dim = self.mask_heads[0].fc_cls.size(0)
96
+ self.panoptic_cls = nn.Sequential(
97
+ nn.Linear(feat_channels, cls_embed_dim)
98
+ )
99
+ else:
100
+ raise NotImplementedError
101
+ self.panoptic_cls = nn.Linear(256, self.num_classes+1)
102
+
103
+ if self.prompt_with_kernel_updator:
104
+ self.prompt_attn = KernelUpdator(feat_channels=256)
105
+ self.prompt_norm = nn.Identity()
106
+ self.prompt_iou = nn.Linear(256, 1)
107
+ else:
108
+ self.prompt_attn = MultiheadAttention(**cross_attn_cfg)
109
+ self.prompt_norm = nn.LayerNorm(256)
110
+ self.prompt_iou = nn.Linear(256, 1)
111
+
112
+ self.test_cfg = test_cfg
113
+ self.train_cfg = train_cfg
114
+ if train_cfg:
115
+ self.assigner = TASK_UTILS.build(self.train_cfg['assigner'])
116
+ self.sampler = TASK_UTILS.build(
117
+ self.train_cfg['sampler'], default_args=dict(context=self))
118
+ self.num_points = self.train_cfg.get('num_points', 12544)
119
+ self.oversample_ratio = self.train_cfg.get('oversample_ratio', 3.0)
120
+ self.importance_sample_ratio = self.train_cfg.get(
121
+ 'importance_sample_ratio', 0.75)
122
+
123
+ self.class_weight = loss_cls.class_weight
124
+ self.loss_cls = MODELS.build(loss_cls)
125
+ self.loss_mask = MODELS.build(loss_mask)
126
+ self.loss_dice = MODELS.build(loss_dice)
127
+
128
+
129
+ def init_weights(self) -> None:
130
+ pass
131
+
132
+ def forward(self, x, batch_data_samples: SampleList) -> Tuple[List[Tensor]]:
133
+ batch_img_metas = []
134
+ if isinstance(batch_data_samples[0], TrackDataSample):
135
+ for track_sample in batch_data_samples:
136
+ cur_list = []
137
+ for det_sample in track_sample:
138
+ cur_list.append(det_sample.metainfo)
139
+ batch_img_metas.append(cur_list)
140
+ num_frames = len(batch_img_metas[0])
141
+ else:
142
+ for data_sample in batch_data_samples:
143
+ batch_img_metas.append(data_sample.metainfo)
144
+ num_frames = 0
145
+ bs = len(batch_img_metas)
146
+
147
+ all_cls_scores = []
148
+ all_masks_preds = []
149
+ all_iou_preds = []
150
+ if self.prompt_training:
151
+ input_query_label, input_query_bbox, self_attn_mask, mask_dict = self.prepare_for_dn_mo(
152
+ batch_data_samples)
153
+ pos_embed = coordinate_to_encoding(input_query_bbox.sigmoid())
154
+ pos_embed = self.pos_linear(pos_embed)
155
+ object_kernels = input_query_label + pos_embed
156
+ else:
157
+ object_kernels = self.kernels.weight[None].repeat(bs, 1, 1)
158
+ self_attn_mask = None
159
+ mask_features = x
160
+ if num_frames > 0: # (bs*num_frames, c, h, w) -> (bs, c, num_frames*h, w)
161
+ mask_features = mask_features.unflatten(0, (bs, num_frames))
162
+ mask_features = mask_features.transpose(1, 2).flatten(2, 3)
163
+
164
+ mask_preds = torch.einsum('bnc,bchw->bnhw', object_kernels, mask_features)
165
+ for stage in range(self.num_stages):
166
+ mask_head = self.mask_heads[stage]
167
+ cls_scores, mask_preds, iou_preds, object_kernels = mask_head(
168
+ mask_features, object_kernels, mask_preds, self_attn_mask)
169
+ cls_scores = cls_scores / self.temperature
170
+ all_iou_preds.append(iou_preds)
171
+ all_cls_scores.append(cls_scores)
172
+ if num_frames > 0:
173
+ #(bs,num_query, num_frames*h, w) --> (bs,num_query,num_frames,h,w)
174
+ all_masks_preds.append(mask_preds.unflatten(2, (num_frames, -1)))
175
+ else:
176
+ all_masks_preds.append(mask_preds)
177
+
178
+ if self.use_adaptor:
179
+ keys = mask_features.flatten(2).transpose(1, 2).contiguous()
180
+ if not self.prompt_training:
181
+ if self.panoptic_with_kernel_updator:
182
+ hard_sigmoid_masks = (mask_preds.sigmoid() > 0.5).float()
183
+ f = torch.einsum('bnhw,bchw->bnc', hard_sigmoid_masks, mask_features)
184
+ object_kernels = self.panoptic_attn(f, object_kernels)
185
+ object_kernels = self.panoptic_norm(object_kernels)
186
+ mask_preds = torch.einsum('bnc,bchw->bnhw', object_kernels, mask_features)
187
+ else:
188
+ object_kernels = self.panoptic_attn(object_kernels, keys)
189
+ object_kernels = self.panoptic_norm(object_kernels)
190
+ mask_preds = torch.einsum('bnc,bchw->bnhw', object_kernels, mask_features)
191
+ cls_embd = self.panoptic_cls(object_kernels)
192
+ cls_scores = torch.einsum('bnc,ckp->bnkp', F.normalize(cls_embd, dim=-1), self.mask_heads[0].fc_cls)
193
+ cls_scores = cls_scores.max(-1).values
194
+ cls_scores = self.mask_heads[0].logit_scale.exp() * cls_scores
195
+
196
+ if num_frames > 0:
197
+ all_masks_preds.append(mask_preds.unflatten(2, (num_frames, -1)))
198
+ else:
199
+ all_masks_preds.append(mask_preds)
200
+ all_cls_scores.append(cls_scores)
201
+ all_iou_preds.append(all_iou_preds[-1])
202
+ else:
203
+ if self.prompt_with_kernel_updator:
204
+ hard_sigmoid_masks = (mask_preds.sigmoid() > 0.5).float()
205
+ f = torch.einsum('bnhw,bchw->bnc', hard_sigmoid_masks, mask_features)
206
+ object_kernels = self.prompt_attn(f, object_kernels)
207
+ object_kernels = self.prompt_norm(object_kernels)
208
+ iou_preds = self.prompt_iou(object_kernels)
209
+ mask_preds = torch.einsum('bnc,bchw->bnhw', object_kernels, mask_features)
210
+ else:
211
+ object_kernels = self.prompt_attn(object_kernels, keys)
212
+ object_kernels = self.prompt_norm(object_kernels)
213
+ iou_preds = self.prompt_iou(object_kernels)
214
+ mask_preds = torch.einsum('bnc,bchw->bnhw', object_kernels, mask_features)
215
+ if num_frames > 0:
216
+ all_masks_preds.append(mask_preds.unflatten(2, (num_frames, -1)))
217
+ else:
218
+ all_masks_preds.append(mask_preds)
219
+ all_cls_scores.append(all_cls_scores[-1])
220
+ all_iou_preds.append(iou_preds)
221
+ return all_cls_scores, all_masks_preds, all_iou_preds, object_kernels
222
+
223
+ def get_targets(self, *args, **kwargs):
224
+ raise NotImplementedError
225
+
226
+ def loss_by_feat(self, *args, **kwargs):
227
+ raise NotImplementedError
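
Note (illustrative sketch, not part of the committed file): the head above repeatedly turns per-query kernels into masks with a single einsum against the dense feature map. A minimal shape check with made-up sizes:

    import torch

    bs, num_queries, channels, h, w = 2, 100, 256, 32, 32
    object_kernels = torch.randn(bs, num_queries, channels)   # one kernel per query
    mask_features = torch.randn(bs, channels, h, w)           # dense per-pixel features
    mask_preds = torch.einsum('bnc,bchw->bnhw', object_kernels, mask_features)
    assert mask_preds.shape == (bs, num_queries, h, w)        # one logit map per query
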
app/models/heads/yoso_head.py ADDED
@@ -0,0 +1,531 @@
1
+ from typing import List, Tuple
2
+ import os
3
+ import torch.distributed as dist
4
+ from torch import Tensor
5
+ from mmdet.registry import MODELS, TASK_UTILS
6
+ from mmdet.models.dense_heads import AnchorFreeHead
7
+ from mmdet.structures import SampleList
8
+ from mmdet.models.dense_heads import Mask2FormerHead
9
+ import math
10
+ from mmengine.model.weight_init import trunc_normal_
11
+ import torch
12
+ from torch import nn
13
+ import torch.nn.functional as F
14
+ from mmcv.cnn import build_activation_layer, build_norm_layer
15
+
16
+ from mmengine.dist import get_dist_info
17
+
18
+
19
+ @MODELS.register_module()
20
+ class YOSOHead(Mask2FormerHead):
21
+ def __init__(self,
22
+ num_cls_fcs=1,
23
+ num_mask_fcs=1,
24
+ sphere_cls=False,
25
+ ov_classifier_name=None,
26
+ use_kernel_updator=False,
27
+ num_stages=3,
28
+ feat_channels=256,
29
+ out_channels=256,
30
+ num_things_classes=80,
31
+ num_stuff_classes=53,
32
+ num_classes=133,
33
+ num_queries=100,
34
+ temperature=0.1,
35
+ loss_cls=dict(
36
+ type='CrossEntropyLoss',
37
+ use_sigmoid=False,
38
+ loss_weight=2.0,
39
+ reduction='mean',
40
+ class_weight=[1.0] * 133 + [0.1]),
41
+ loss_mask=dict(
42
+ type='CrossEntropyLoss',
43
+ use_sigmoid=True,
44
+ reduction='mean',
45
+ loss_weight=5.0),
46
+ loss_dice=dict(
47
+ type='DiceLoss',
48
+ use_sigmoid=True,
49
+ activate=True,
50
+ reduction='mean',
51
+ naive_dice=True,
52
+ eps=1.0,
53
+ loss_weight=5.0),
54
+ train_cfg=None,
55
+ test_cfg=None,
56
+ init_cfg=None):
57
+ super(AnchorFreeHead, self).__init__(init_cfg=init_cfg)
58
+ self.num_stages = num_stages
59
+ self.feat_channels = feat_channels
60
+ self.out_channels = out_channels
61
+ self.num_things_classes = num_things_classes
62
+ self.num_stuff_classes = num_stuff_classes
63
+ self.num_classes = num_classes
64
+ self.num_queries = num_queries
65
+ self.temperature = temperature
66
+
67
+ self.test_cfg = test_cfg
68
+ self.train_cfg = train_cfg
69
+ if train_cfg:
70
+ self.assigner = TASK_UTILS.build(self.train_cfg['assigner'])
71
+ self.sampler = TASK_UTILS.build(
72
+ self.train_cfg['sampler'], default_args=dict(context=self))
73
+ self.num_points = self.train_cfg.get('num_points', 12544)
74
+ self.oversample_ratio = self.train_cfg.get('oversample_ratio', 3.0)
75
+ self.importance_sample_ratio = self.train_cfg.get(
76
+ 'importance_sample_ratio', 0.75)
77
+
78
+ self.class_weight = loss_cls.class_weight
79
+ self.loss_cls = MODELS.build(loss_cls)
80
+ self.loss_mask = MODELS.build(loss_mask)
81
+ self.loss_dice = MODELS.build(loss_dice)
82
+
83
+ self.kernels = nn.Embedding(self.num_queries, self.feat_channels)
84
+
85
+ self.mask_heads = nn.ModuleList()
86
+ for _ in range(self.num_stages):
87
+ self.mask_heads.append(CrossAttenHead(
88
+ self.num_classes, self.feat_channels, self.num_queries,
89
+ use_kernel_updator=use_kernel_updator,
90
+ sphere_cls=sphere_cls, ov_classifier_name=ov_classifier_name,
91
+ num_cls_fcs=num_cls_fcs, num_mask_fcs=num_mask_fcs
92
+ ))
93
+
94
+ def init_weights(self) -> None:
95
+ super(AnchorFreeHead, self).init_weights()
96
+
97
+ def forward(self, x: List[Tensor],
98
+ batch_data_samples: SampleList) -> Tuple[List[Tensor]]:
99
+ all_cls_scores = []
100
+ all_masks_preds = []
101
+ proposal_kernels = self.kernels.weight
102
+ object_kernels = proposal_kernels[None].repeat(x.shape[0], 1, 1)
103
+ mask_preds = torch.einsum('bnc,bchw->bnhw', object_kernels, x)
104
+
105
+ for stage in range(self.num_stages):
106
+ mask_head = self.mask_heads[stage]
107
+ cls_scores, mask_preds, iou_pred, object_kernels = mask_head(x, object_kernels, mask_preds)
108
+ cls_scores = cls_scores / self.temperature
109
+
110
+ all_cls_scores.append(cls_scores)
111
+ all_masks_preds.append(mask_preds)
112
+
113
+ return all_cls_scores, all_masks_preds
114
+
115
+ def predict(self, x: Tuple[Tensor], batch_data_samples: SampleList) -> Tuple[Tensor]:
116
+ batch_img_metas = [
117
+ data_sample.metainfo for data_sample in batch_data_samples
118
+ ]
119
+ all_cls_scores, all_mask_preds = self(x, batch_data_samples)
120
+ mask_cls_results = all_cls_scores[-1]
121
+ mask_pred_results = all_mask_preds[-1]
122
+
123
+ # upsample masks
124
+ img_shape = batch_img_metas[0]['batch_input_shape']
125
+ mask_pred_results = F.interpolate(
126
+ mask_pred_results,
127
+ size=(img_shape[0], img_shape[1]),
128
+ mode='bilinear',
129
+ align_corners=False)
130
+
131
+ return mask_cls_results, mask_pred_results
132
+
133
+
134
+ class FFN(nn.Module):
135
+
136
+ def __init__(self,
137
+ embed_dims=256,
138
+ feedforward_channels=1024,
139
+ num_fcs=2,
140
+ add_identity=True):
141
+ super(FFN, self).__init__()
142
+ self.embed_dims = embed_dims
143
+ self.feedforward_channels = feedforward_channels
144
+ self.num_fcs = num_fcs
145
+
146
+ layers = []
147
+ in_channels = embed_dims
148
+ for _ in range(num_fcs - 1):
149
+ layers.append(nn.Sequential(
150
+ nn.Linear(in_channels, feedforward_channels),
151
+ nn.ReLU(True),
152
+ nn.Dropout(0.0)))
153
+ in_channels = feedforward_channels
154
+ layers.append(nn.Linear(feedforward_channels, embed_dims))
155
+ layers.append(nn.Dropout(0.0))
156
+ self.layers = nn.Sequential(*layers)
157
+ self.add_identity = add_identity
158
+ self.dropout_layer = nn.Dropout(0.0)
159
+
160
+ def forward(self, x, identity=None):
161
+ out = self.layers(x)
162
+ if not self.add_identity:
163
+ return self.dropout_layer(out)
164
+ if identity is None:
165
+ identity = x
166
+ return identity + self.dropout_layer(out)
167
+
168
+
169
+ class DySepConvAtten(nn.Module):
170
+ def __init__(self, hidden_dim, num_proposals, conv_kernel_size_1d):
171
+ super(DySepConvAtten, self).__init__()
172
+ self.hidden_dim = hidden_dim
173
+ self.num_proposals = num_proposals
174
+ self.kernel_size = conv_kernel_size_1d
175
+
176
+ self.weight_linear = nn.Linear(self.hidden_dim, self.num_proposals + self.kernel_size)
177
+ self.norm = nn.LayerNorm(self.hidden_dim)
178
+
179
+ def forward(self, query, value):
180
+ assert query.shape == value.shape
181
+ B, N, C = query.shape
182
+
183
+ dy_conv_weight = self.weight_linear(query)
184
+ dy_depth_conv_weight = dy_conv_weight[:, :, :self.kernel_size].view(B, self.num_proposals, 1, self.kernel_size)
185
+ dy_point_conv_weight = dy_conv_weight[:, :, self.kernel_size:].view(B, self.num_proposals, self.num_proposals,
186
+ 1)
187
+
188
+ res = []
189
+ value = value.unsqueeze(1)
190
+ for i in range(B):
191
+ out = F.relu(F.conv1d(input=value[i], weight=dy_depth_conv_weight[i], groups=N, padding='same'))
192
+ out = F.conv1d(input=out, weight=dy_point_conv_weight[i], padding='same')
193
+ res.append(out)
194
+
195
+ point_out = torch.cat(res, dim=0)
196
+ point_out = self.norm(point_out)
197
+ return point_out
198
+
199
+
200
+ class KernelUpdator(nn.Module):
201
+
202
+ def __init__(self,
203
+ in_channels=256,
204
+ feat_channels=64,
205
+ out_channels=None,
206
+ input_feat_shape=3,
207
+ gate_sigmoid=True,
208
+ gate_norm_act=False,
209
+ activate_out=False,
210
+ act_cfg=dict(type='ReLU', inplace=True),
211
+ norm_cfg=dict(type='LN')):
212
+ super(KernelUpdator, self).__init__()
213
+ self.in_channels = in_channels
214
+ self.feat_channels = feat_channels
215
+ self.out_channels_raw = out_channels
216
+ self.gate_sigmoid = gate_sigmoid
217
+ self.gate_norm_act = gate_norm_act
218
+ self.activate_out = activate_out
219
+ if isinstance(input_feat_shape, int):
220
+ input_feat_shape = [input_feat_shape] * 2
221
+ self.input_feat_shape = input_feat_shape
222
+ self.act_cfg = act_cfg
223
+ self.norm_cfg = norm_cfg
224
+ self.out_channels = out_channels if out_channels else in_channels
225
+
226
+ self.num_params_in = self.feat_channels
227
+ self.num_params_out = self.feat_channels
228
+ self.dynamic_layer = nn.Linear(
229
+ self.in_channels, self.num_params_in + self.num_params_out)
230
+ self.input_layer = nn.Linear(self.in_channels,
231
+ self.num_params_in + self.num_params_out,
232
+ 1)
233
+ self.input_gate = nn.Linear(self.in_channels, self.feat_channels, 1)
234
+ self.update_gate = nn.Linear(self.in_channels, self.feat_channels, 1)
235
+ if self.gate_norm_act:
236
+ self.gate_norm = build_norm_layer(norm_cfg, self.feat_channels)[1]
237
+
238
+ self.norm_in = build_norm_layer(norm_cfg, self.feat_channels)[1]
239
+ self.norm_out = build_norm_layer(norm_cfg, self.feat_channels)[1]
240
+ self.input_norm_in = build_norm_layer(norm_cfg, self.feat_channels)[1]
241
+ self.input_norm_out = build_norm_layer(norm_cfg, self.feat_channels)[1]
242
+
243
+ self.activation = build_activation_layer(act_cfg)
244
+
245
+ self.fc_layer = nn.Linear(self.feat_channels, self.out_channels, 1)
246
+ self.fc_norm = build_norm_layer(norm_cfg, self.out_channels)[1]
247
+
248
+ def forward(self, update_feature, input_feature):
249
+ """
250
+ Args:
251
+ update_feature (torch.Tensor): [bs, num_proposals, in_channels]
252
+ input_feature (torch.Tensor): [bs, num_proposals, in_channels]
253
+ """
254
+ bs, num_proposals, _ = update_feature.shape
255
+
256
+ parameters = self.dynamic_layer(update_feature)
257
+ param_in = parameters[..., :self.num_params_in]
258
+ param_out = parameters[..., -self.num_params_out:]
259
+
260
+ input_feats = self.input_layer(input_feature)
261
+ input_in = input_feats[..., :self.num_params_in]
262
+ input_out = input_feats[..., -self.num_params_out:]
263
+
264
+ gate_feats = input_in * param_in
265
+ if self.gate_norm_act:
266
+ gate_feats = self.activation(self.gate_norm(gate_feats))
267
+
268
+ input_gate = self.input_norm_in(self.input_gate(gate_feats))
269
+ update_gate = self.norm_in(self.update_gate(gate_feats))
270
+ if self.gate_sigmoid:
271
+ input_gate = input_gate.sigmoid()
272
+ update_gate = update_gate.sigmoid()
273
+ param_out = self.norm_out(param_out)
274
+ input_out = self.input_norm_out(input_out)
275
+
276
+ if self.activate_out:
277
+ param_out = self.activation(param_out)
278
+ input_out = self.activation(input_out)
279
+
280
+ # param_out/input_out have shape (bs, num_proposals, feat_channels); fuse them with the two gates
281
+ features = update_gate * param_out + input_gate * input_out
282
+
283
+ features = self.fc_layer(features)
284
+ features = self.fc_norm(features)
285
+ features = self.activation(features)
286
+
287
+ return features
288
+
289
+
290
+ class CrossAttenHead(nn.Module):
291
+
292
+ def __init__(self,
293
+ num_classes,
294
+ in_channels,
295
+ num_proposals,
296
+ frozen_head=False,
297
+ frozen_pred=False,
298
+ with_iou_pred=False,
299
+ sphere_cls=False,
300
+ ov_classifier_name=None,
301
+ num_cls_fcs=1,
302
+ num_mask_fcs=1,
303
+ conv_kernel_size_1d=3,
304
+ conv_kernel_size_2d=1,
305
+ use_kernel_updator=False):
306
+ super(CrossAttenHead, self).__init__()
307
+ self.sphere_cls = sphere_cls
308
+ self.with_iou_pred = with_iou_pred
309
+ self.frozen_head = frozen_head
310
+ self.frozen_pred = frozen_pred
311
+ self.num_cls_fcs = num_cls_fcs
312
+ self.num_mask_fcs = num_mask_fcs
313
+ self.num_classes = num_classes
314
+ self.conv_kernel_size_2d = conv_kernel_size_2d
315
+
316
+ self.hidden_dim = in_channels
317
+ self.feat_channels = in_channels
318
+ self.num_proposals = num_proposals
319
+ self.hard_mask_thr = 0.5
320
+ self.use_kernel_updator = use_kernel_updator
321
+ # assert use_kernel_updator
322
+ if use_kernel_updator:
323
+ self.kernel_update = KernelUpdator(
324
+ in_channels=256,
325
+ feat_channels=256,
326
+ out_channels=256,
327
+ input_feat_shape=3,
328
+ act_cfg=dict(type='ReLU', inplace=True),
329
+ norm_cfg=dict(type='LN')
330
+ )
331
+ else:
332
+ self.f_atten = DySepConvAtten(self.feat_channels, self.num_proposals, conv_kernel_size_1d)
333
+ self.f_dropout = nn.Dropout(0.0)
334
+ self.f_atten_norm = nn.LayerNorm(self.hidden_dim * self.conv_kernel_size_2d ** 2)
335
+ self.k_atten = DySepConvAtten(self.feat_channels, self.num_proposals, conv_kernel_size_1d)
336
+ self.k_dropout = nn.Dropout(0.0)
337
+ self.k_atten_norm = nn.LayerNorm(self.hidden_dim * self.conv_kernel_size_2d ** 2)
338
+
339
+ self.s_atten = nn.MultiheadAttention(embed_dim=self.hidden_dim *
340
+ self.conv_kernel_size_2d ** 2,
341
+ num_heads=8,
342
+ dropout=0.0)
343
+ self.s_dropout = nn.Dropout(0.0)
344
+ self.s_atten_norm = nn.LayerNorm(self.hidden_dim * self.conv_kernel_size_2d ** 2)
345
+
346
+ self.ffn = FFN(self.hidden_dim, feedforward_channels=2048, num_fcs=2)
347
+ self.ffn_norm = nn.LayerNorm(self.hidden_dim)
348
+
349
+ self.cls_fcs = nn.ModuleList()
350
+ for _ in range(self.num_cls_fcs):
351
+ self.cls_fcs.append(nn.Linear(self.hidden_dim, self.hidden_dim, bias=False))
352
+ self.cls_fcs.append(nn.LayerNorm(self.hidden_dim))
353
+ self.cls_fcs.append(nn.ReLU(True))
354
+
355
+ if sphere_cls:
356
+ rank, world_size = get_dist_info()
357
+ if ov_classifier_name is None:
358
+ _dim = 1024 # temporally hard code
359
+ cls_embed = torch.empty(self.num_classes, _dim)
360
+ torch.nn.init.orthogonal_(cls_embed)
361
+ cls_embed = cls_embed[:, None]
362
+ else:
363
+ ov_path = os.path.join(os.path.expanduser('~/.cache/embd'), f"{ov_classifier_name}.pth")
364
+ cls_embed = torch.load(ov_path)
365
+ cls_embed_norm = cls_embed.norm(p=2, dim=-1)
366
+ assert torch.allclose(cls_embed_norm, torch.ones_like(cls_embed_norm))
367
+
368
+ # background class
369
+ _dim = cls_embed.size(2)
370
+ _prototypes = cls_embed.size(1)
371
+ if rank == 0:
372
+ back_token = torch.zeros(1, _dim, dtype=torch.float32, device='cuda')
373
+ else:
374
+ back_token = torch.empty(1, _dim, dtype=torch.float32, device='cuda')
375
+ if world_size > 1:
376
+ dist.broadcast(back_token, src=0)
377
+ back_token = back_token.to(device='cpu')
378
+ cls_embed = torch.cat([
379
+ cls_embed, back_token.repeat(_prototypes, 1)[None]
380
+ ], dim=0)
381
+ self.register_buffer('fc_cls', cls_embed.permute(2, 0, 1).contiguous(), persistent=False)
382
+
383
+ # cls embd proj
384
+ cls_embed_dim = self.fc_cls.size(0)
385
+ self.cls_proj = nn.Sequential(
386
+ nn.Linear(self.hidden_dim, self.hidden_dim), nn.ReLU(inplace=True),
387
+ nn.Linear(self.hidden_dim, self.hidden_dim), nn.ReLU(inplace=True),
388
+ nn.Linear(self.hidden_dim, cls_embed_dim)
389
+ )
390
+
391
+ logit_scale = torch.tensor(4.6052, dtype=torch.float32)
392
+ self.register_buffer('logit_scale', logit_scale, persistent=False)
393
+ else:
394
+ self.fc_cls = nn.Linear(self.hidden_dim, self.num_classes + 1)
395
+
396
+ self.mask_fcs = nn.ModuleList()
397
+ for _ in range(self.num_mask_fcs):
398
+ self.mask_fcs.append(nn.Linear(self.hidden_dim, self.hidden_dim, bias=False))
399
+ self.mask_fcs.append(nn.LayerNorm(self.hidden_dim))
400
+ self.mask_fcs.append(nn.ReLU(True))
401
+ self.fc_mask = nn.Linear(self.hidden_dim, self.hidden_dim)
402
+
403
+ if self.with_iou_pred:
404
+ self.iou_embed = nn.Sequential(
405
+ nn.Linear(self.hidden_dim, self.hidden_dim),
406
+ nn.ReLU(inplace=True),
407
+ nn.Linear(self.hidden_dim, self.hidden_dim),
408
+ nn.ReLU(inplace=True),
409
+ nn.Linear(self.hidden_dim, 1),
410
+ )
411
+ prior_prob = 0.01
412
+ self.bias_value = -math.log((1 - prior_prob) / prior_prob)
413
+
414
+ self.apply(self._init_weights)
415
+ if not sphere_cls:
416
+ nn.init.constant_(self.fc_cls.bias, self.bias_value)
417
+
418
+ if self.frozen_head:
419
+ self._frozen_head()
420
+ if self.frozen_pred:
421
+ self._frozen_pred()
422
+
423
+ def _init_weights(self, m):
424
+ # print("init weights")
425
+ if isinstance(m, nn.Linear):
426
+ trunc_normal_(m.weight, std=.02)
427
+ if isinstance(m, nn.Linear) and m.bias is not None:
428
+ nn.init.constant_(m.bias, 0)
429
+ elif isinstance(m, nn.LayerNorm):
430
+ nn.init.constant_(m.bias, 0)
431
+ nn.init.constant_(m.weight, 1.0)
432
+
433
+ def _frozen_head(self):
434
+ for n, p in self.kernel_update.named_parameters():
435
+ p.requires_grad = False
436
+ for n, p in self.s_atten.named_parameters():
437
+ p.requires_grad = False
438
+ for n, p in self.s_dropout.named_parameters():
439
+ p.requires_grad = False
440
+ for n, p in self.s_atten_norm.named_parameters():
441
+ p.requires_grad = False
442
+ for n, p in self.ffn.named_parameters():
443
+ p.requires_grad = False
444
+ for n, p in self.ffn_norm.named_parameters():
445
+ p.requires_grad = False
446
+
447
+ def _frozen_pred(self):
448
+ # frozen cls_fcs, fc_cls, mask_fcs, fc_mask
449
+ for n, p in self.cls_fcs.named_parameters():
450
+ p.requires_grad = False
451
+ for n, p in self.fc_cls.named_parameters():
452
+ p.requires_grad = False
453
+ for n, p in self.mask_fcs.named_parameters():
454
+ p.requires_grad = False
455
+ for n, p in self.fc_mask.named_parameters():
456
+ p.requires_grad = False
457
+
458
+ def train(self, mode):
459
+ super().train(mode)
460
+ if self.frozen_head:
461
+ self.kernel_update.eval()
462
+ self.s_atten.eval()
463
+ self.s_dropout.eval()
464
+ self.s_atten_norm.eval()
465
+ self.ffn.eval()
466
+ self.ffn_norm.eval()
467
+ if self.frozen_pred:
468
+ self.cls_fcs.eval()
469
+ self.fc_cls.eval()
470
+ self.mask_fcs.eval()
471
+ self.fc_mask.eval()
472
+
473
+ def forward(self, features, proposal_kernels, mask_preds, self_attn_mask=None):
474
+ B, C, H, W = features.shape
475
+
476
+ soft_sigmoid_masks = mask_preds.sigmoid()
477
+ nonzero_inds = soft_sigmoid_masks > self.hard_mask_thr
478
+ hard_sigmoid_masks = nonzero_inds.float()
479
+
480
+ # [B, N, C]
481
+ f = torch.einsum('bnhw,bchw->bnc', hard_sigmoid_masks, features)
482
+ # [B, N, C, K, K] -> [B, N, C * K * K]
483
+ num_proposals = proposal_kernels.shape[1]
484
+ k = proposal_kernels.view(B, num_proposals, -1)
485
+
486
+ # ----
487
+ if self.use_kernel_updator:
488
+ k = self.kernel_update(f, k)
489
+ else:
490
+ f_tmp = self.f_atten(k, f)
491
+ f = f + self.f_dropout(f_tmp)
492
+ f = self.f_atten_norm(f)
493
+
494
+ f_tmp = self.k_atten(k, f)
495
+ f = f + self.k_dropout(f_tmp)
496
+ k = self.k_atten_norm(f)
497
+
498
+ # [N, B, C]
499
+ k = k.permute(1, 0, 2)
500
+
501
+ k_tmp = self.s_atten(query=k, key=k, value=k, attn_mask=self_attn_mask)[0]
502
+ k = k + self.s_dropout(k_tmp)
503
+ k = self.s_atten_norm(k.permute(1, 0, 2))
504
+
505
+ obj_feat = self.ffn_norm(self.ffn(k))
506
+
507
+ cls_feat = obj_feat
508
+ mask_feat = obj_feat
509
+
510
+ for cls_layer in self.cls_fcs:
511
+ cls_feat = cls_layer(cls_feat)
512
+
513
+ if self.sphere_cls:
514
+ cls_embd = self.cls_proj(cls_feat)  # FIXME: too many cls linear layers (cls_fcs + cls_proj)
515
+ cls_score = torch.einsum('bnc,ckp->bnkp', F.normalize(cls_embd, dim=-1), self.fc_cls)
516
+ cls_score = cls_score.max(-1).values
517
+ cls_score = self.logit_scale.exp() * cls_score
518
+ else:
519
+ cls_score = self.fc_cls(cls_feat)
520
+ for reg_layer in self.mask_fcs:
521
+ mask_feat = reg_layer(mask_feat)
522
+ # [B, N, K * K, C] -> [B, N, C]
523
+ mask_kernels = self.fc_mask(mask_feat)
524
+
525
+ new_mask_preds = torch.einsum("bqc,bchw->bqhw", mask_kernels, features)
526
+ if self.with_iou_pred:
527
+ iou_pred = self.iou_embed(mask_feat)
528
+ iou_pred = iou_pred
529
+ else:
530
+ iou_pred = None
531
+ return cls_score, new_mask_preds, iou_pred, obj_feat
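
Note (illustrative sketch, not part of the committed file): when sphere_cls is enabled, classification is cosine similarity between projected query embeddings and fixed class prototypes, followed by a max over prototypes. A toy version with made-up sizes:

    import torch
    import torch.nn.functional as F

    bs, n_queries, d = 2, 100, 1024
    n_classes, n_prototypes = 134, 1                        # e.g. 133 classes + background
    cls_embd = torch.randn(bs, n_queries, d)                # output of cls_proj
    fc_cls = F.normalize(torch.randn(d, n_classes, n_prototypes), dim=0)
    logit_scale = torch.tensor(4.6052)

    scores = torch.einsum('bnc,ckp->bnkp', F.normalize(cls_embd, dim=-1), fc_cls)
    cls_score = logit_scale.exp() * scores.max(-1).values   # (bs, n_queries, n_classes)
    assert cls_score.shape == (bs, n_queries, n_classes)
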
app/models/necks/__init__.py ADDED
@@ -0,0 +1 @@
1
+ from .ramsam_neck import YOSONeck
app/models/necks/ramsam_neck.py ADDED
@@ -0,0 +1,196 @@
1
+ import math
2
+ import torch
3
+ from torch import nn
4
+ import torch.nn.functional as F
5
+ from mmengine.model import kaiming_init
6
+ from mmdet.registry import MODELS
7
+ from mmcv.ops import DeformConv2d, ModulatedDeformConv2d
8
+
9
+
10
+ class DeformLayer(nn.Module):
11
+
12
+ def __init__(self,
13
+ in_planes,
14
+ out_planes,
15
+ deconv_kernel=4,
16
+ deconv_stride=2,
17
+ deconv_pad=1,
18
+ deconv_out_pad=0,
19
+ modulate_deform=True,
20
+ num_groups=1,
21
+ deform_num_groups=1,
22
+ dilation=1):
23
+ super(DeformLayer, self).__init__()
24
+ self.deform_modulated = modulate_deform
25
+ if modulate_deform:
26
+ deform_conv_op = ModulatedDeformConv2d
27
+ offset_channels = 27
28
+ else:
29
+ deform_conv_op = DeformConv2d
30
+ offset_channels = 18
31
+
32
+ self.dcn_offset = nn.Conv2d(in_planes, offset_channels * deform_num_groups, kernel_size=3, stride=1, padding=1 * dilation, dilation=dilation)
33
+ self.dcn = deform_conv_op(in_planes, out_planes, kernel_size=3, stride=1, padding=1 * dilation, bias=False, groups=num_groups, dilation=dilation, deformable_groups=deform_num_groups)
34
+ for layer in [self.dcn]:
35
+ kaiming_init(layer)
36
+
37
+ nn.init.constant_(self.dcn_offset.weight, 0)
38
+ nn.init.constant_(self.dcn_offset.bias, 0)
39
+
40
+ # nn.GroupNorm(64, out_planes) # nn.BatchNorm2d(out_planes) #
41
+ self.dcn_bn = nn.SyncBatchNorm(out_planes)
42
+ self.up_sample = nn.ConvTranspose2d(in_channels=out_planes, out_channels=out_planes, kernel_size=deconv_kernel, stride=deconv_stride, padding=deconv_pad, output_padding=deconv_out_pad, bias=False)
43
+ self._deconv_init()
44
+ # nn.GroupNorm(64, out_planes) # nn.BatchNorm2d(out_planes) #
45
+ self.up_bn = nn.SyncBatchNorm(out_planes)
46
+ self.relu = nn.ReLU()
47
+
48
+ def forward(self, x):
49
+ out = x
50
+ if self.deform_modulated:
51
+ offset_mask = self.dcn_offset(out)
52
+ offset_x, offset_y, mask = torch.chunk(offset_mask, 3, dim=1)
53
+ offset = torch.cat((offset_x, offset_y), dim=1)
54
+ mask = mask.sigmoid()
55
+ out = self.dcn(out, offset, mask)
56
+ else:
57
+ offset = self.dcn_offset(out)
58
+ out = self.dcn(out, offset)
59
+ x = out
60
+
61
+ x = self.dcn_bn(x)
62
+ x = self.relu(x)
63
+ x = self.up_sample(x)
64
+ x = self.up_bn(x)
65
+ x = self.relu(x)
66
+ return x
67
+
68
+ def _deconv_init(self):
69
+ w = self.up_sample.weight.data
70
+ f = math.ceil(w.size(2) / 2)
71
+ c = (2 * f - 1 - f % 2) / (2. * f)
72
+ for i in range(w.size(2)):
73
+ for j in range(w.size(3)):
74
+ w[0, 0, i, j] = \
75
+ (1 - math.fabs(i / f - c)) * (1 - math.fabs(j / f - c))
76
+ for c in range(1, w.size(0)):
77
+ w[c, 0, :, :] = w[0, 0, :, :]
78
+
79
+ class LiteDeformConv(nn.Module):
80
+ def __init__(self, agg_dim, backbone_shape):
81
+ super(LiteDeformConv, self).__init__()
82
+ in_channels = []
83
+ out_channels = [agg_dim]
84
+ for feat in backbone_shape:
85
+ in_channels.append(feat)
86
+ out_channels.append(feat//2)
87
+
88
+ self.lateral_conv0 = nn.Conv2d(in_channels=in_channels[-1], out_channels=out_channels[-1], kernel_size=1, stride=1, padding=0)
89
+
90
+ self.deform_conv1 = DeformLayer(in_planes=out_channels[-1], out_planes=out_channels[-2])
91
+
92
+ self.lateral_conv1 = nn.Conv2d(in_channels=in_channels[-2], out_channels=out_channels[-2], kernel_size=1, stride=1, padding=0)
93
+
94
+ self.deform_conv2 = DeformLayer(in_planes=out_channels[-2], out_planes=out_channels[-3])
95
+
96
+ self.lateral_conv2 = nn.Conv2d(in_channels=in_channels[-3], out_channels=out_channels[-3], kernel_size=1, stride=1, padding=0)
97
+
98
+ self.deform_conv3 = DeformLayer(in_planes=out_channels[-3], out_planes=out_channels[-4])
99
+
100
+ self.lateral_conv3 = nn.Conv2d(in_channels=in_channels[-4], out_channels=out_channels[-4], kernel_size=1, stride=1, padding=0)
101
+
102
+ # self.fuse_conv = nn.Conv2d(in_channels=sum(out_channels[1:]), out_channels=out_channels[-5], kernel_size=3, stride=1, padding=1)
103
+ self.output_conv = nn.Conv2d(in_channels=out_channels[-5], out_channels=out_channels[-5], kernel_size=3, stride=1, padding=1)
104
+
105
+ self.bias = nn.Parameter(torch.FloatTensor(1,out_channels[-5],1,1), requires_grad=True)
106
+ self.bias.data.fill_(0.0)
107
+
108
+ self.conv_a5 = nn.Conv2d(in_channels=out_channels[-1], out_channels=out_channels[-5], kernel_size=1, stride=1, padding=0, bias=False)
109
+ self.conv_a4 = nn.Conv2d(in_channels=out_channels[-2], out_channels=out_channels[-5], kernel_size=1, stride=1, padding=0, bias=False)
110
+ self.conv_a3 = nn.Conv2d(in_channels=out_channels[-3], out_channels=out_channels[-5], kernel_size=1, stride=1, padding=0, bias=False)
111
+ self.conv_a2 = nn.Conv2d(in_channels=out_channels[-4], out_channels=out_channels[-5], kernel_size=1, stride=1, padding=0, bias=False)
112
+
113
+ def forward(self, features_list):
114
+ p5 = self.lateral_conv0(features_list[-1])
115
+ x5 = p5
116
+ x = self.deform_conv1(x5)
117
+
118
+ p4 = self.lateral_conv1(features_list[-2])
119
+ x4 = p4 + x
120
+ x = self.deform_conv2(x4)
121
+
122
+ p3 = self.lateral_conv2(features_list[-3])
123
+ x3 = p3 + x
124
+ x = self.deform_conv3(x3)
125
+
126
+ p2 = self.lateral_conv3(features_list[-4])
127
+ x2 = p2 + x
128
+
129
+ # CFA
130
+ x5 = self.conv_a5(x5)
131
+ x4 = self.conv_a4(x4)
132
+ x3 = self.conv_a3(x3)
133
+
134
+ _x5 = F.interpolate(x5, scale_factor=8, align_corners=False, mode='bilinear')
135
+ _x4 = F.interpolate(x4, scale_factor=4, align_corners=False, mode='bilinear')
136
+ _x3 = F.interpolate(x3, scale_factor=2, align_corners=False, mode='bilinear')
137
+ x2 = self.conv_a2(x2)
138
+ x = _x5 + _x4 + _x3 + x2 + self.bias
139
+
140
+ x = self.output_conv(x)
141
+
142
+ return x, (x5, x4, x3)
143
+
144
+
145
+
146
+ @MODELS.register_module()
147
+ class YOSONeck(nn.Module):
148
+
149
+ def __init__(self,
150
+ agg_dim,
151
+ hidden_dim,
152
+ backbone_shape,
153
+ return_multi_scale=False,
154
+ return_single_scale=False,
155
+ # Just for compatibility with Mask2Former; not actually used
156
+ in_channels=None,
157
+ feat_channels=None,
158
+ out_channels=None
159
+ ):
160
+ super().__init__()
161
+ # in_channels == backbone_shape
162
+ # hidden_dim == feat_channels == out_channels == 256
163
+ self.return_single_scale = return_single_scale
164
+ self.return_multi_scale = return_multi_scale
165
+ self.deconv = LiteDeformConv(agg_dim=agg_dim, backbone_shape=backbone_shape)
166
+
167
+ self.loc_conv = nn.Conv2d(in_channels=agg_dim + 2, out_channels=hidden_dim, kernel_size=1, stride=1)
168
+ self.init_weights()
169
+
170
+ def init_weights(self) -> None:
171
+ for p in self.parameters():
172
+ if p.dim() > 1:
173
+ nn.init.xavier_uniform_(p)
174
+
175
+ def generate_coord(self, input_feat):
176
+ x_range = torch.linspace(-1, 1, input_feat.shape[-1], device=input_feat.device)
177
+ y_range = torch.linspace(-1, 1, input_feat.shape[-2], device=input_feat.device)
178
+ y, x = torch.meshgrid(y_range, x_range)
179
+ y = y.expand([input_feat.shape[0], 1, -1, -1])
180
+ x = x.expand([input_feat.shape[0], 1, -1, -1])
181
+ coord_feat = torch.cat([x, y], 1)
182
+ return coord_feat
183
+
184
+ def forward(self,
185
+ features_list,
186
+ batch_img_metas = None,
187
+ num_frames = None):
188
+ features, multi_scale = self.deconv(features_list)
189
+ coord_feat = self.generate_coord(features)
190
+ features = torch.cat([features, coord_feat], 1)
191
+ features = self.loc_conv(features)
192
+ if self.return_single_scale: # maskformer
193
+ return features, multi_scale[0]
194
+ if self.return_multi_scale: # mask2former
195
+ return features, multi_scale
196
+ return features
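
Note (illustrative sketch, not part of the committed file): generate_coord simply appends a two-channel map of x/y positions normalised to [-1, 1] before the final 1x1 convolution. A standalone re-statement for clarity:

    import torch

    def coord_like(feat):
        b, _, h, w = feat.shape
        ys = torch.linspace(-1, 1, h, device=feat.device)
        xs = torch.linspace(-1, 1, w, device=feat.device)
        y, x = torch.meshgrid(ys, xs, indexing='ij')
        return torch.cat([x.expand(b, 1, h, w), y.expand(b, 1, h, w)], dim=1)

    feat = torch.randn(2, 128, 64, 64)
    assert coord_like(feat).shape == (2, 2, 64, 64)
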
app/models/utils/__init__.py ADDED
@@ -0,0 +1,3 @@
1
+ from .video_gt_preprocess import preprocess_video_panoptic_gt
2
+ from .mask_pool import mask_pool
3
+ from .no_obj import NO_OBJ
app/models/utils/load_checkpoint.py ADDED
@@ -0,0 +1,38 @@
1
+ from mmengine.runner.checkpoint import CheckpointLoader
2
+
3
+
4
+ def load_checkpoint_with_prefix(filename, prefix=None, map_location='cpu', logger='current'):
5
+ """Load partial pretrained model with specific prefix.
6
+
7
+ Args:
8
+ prefix (str): The prefix of sub-module.
9
+ filename (str): Accept local filepath, URL, ``torchvision://xxx``,
10
+ ``open-mmlab://xxx``. Please refer to ``docs/model_zoo.md`` for
11
+ details.
12
+ map_location (str | None): Same as :func:`torch.load`.
13
+ Defaults to None.
14
+ logger: logger
15
+
16
+ Returns:
17
+ dict or OrderedDict: The loaded checkpoint.
18
+ """
19
+
20
+ checkpoint = CheckpointLoader.load_checkpoint(filename, map_location=map_location, logger=logger)
21
+
22
+ if 'state_dict' in checkpoint:
23
+ state_dict = checkpoint['state_dict']
24
+ else:
25
+ state_dict = checkpoint
26
+ if not prefix:
27
+ return state_dict
28
+ if not prefix.endswith('.'):
29
+ prefix += '.'
30
+ prefix_len = len(prefix)
31
+
32
+ state_dict = {
33
+ k[prefix_len:]: v
34
+ for k, v in state_dict.items() if k.startswith(prefix)
35
+ }
36
+
37
+ assert state_dict, f'{prefix} is not in the pretrained model'
38
+ return state_dict
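
Note (illustrative usage, not part of the committed file): the helper keeps only the keys under a given prefix and strips that prefix, which is handy for loading a single sub-module. The checkpoint below is a throwaway file created just for the demonstration:

    import os, tempfile
    import torch

    state = {'state_dict': {'backbone.conv.weight': torch.zeros(1),
                            'head.fc.weight': torch.zeros(1)}}
    path = os.path.join(tempfile.mkdtemp(), 'demo.pth')
    torch.save(state, path)

    backbone_only = load_checkpoint_with_prefix(path, prefix='backbone')
    assert list(backbone_only) == ['conv.weight']
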
app/models/utils/mask_pool.py ADDED
@@ -0,0 +1,27 @@
1
+ import torch
2
+ import torch.nn.functional as F
3
+
4
+
5
+ # https://github.com/NVlabs/ODISE/blob/e97b06c424c575fec9fc5368dd4b3e050d91abc4/odise/modeling/meta_arch/odise.py#L923
6
+
7
+ def mask_pool(x, mask):
8
+ """
9
+ Args:
10
+ x: [B, C, H, W]
11
+ mask: [B, Q, H, W]
12
+ """
13
+ if not x.shape[-2:] == mask.shape[-2:]:
14
+ # reshape mask to x
15
+ mask = F.interpolate(mask, size=x.shape[-2:], mode='bilinear', align_corners=False)
16
+ with torch.no_grad():
17
+ mask = mask.detach()
18
+ mask = (mask > 0).to(mask.dtype)
19
+ denorm = mask.sum(dim=(-1, -2), keepdim=True) + 1e-8
20
+
21
+ mask_pooled_x = torch.einsum(
22
+ "bchw,bqhw->bqc",
23
+ x,
24
+ mask / denorm,
25
+ )
26
+ return mask_pooled_x
27
app/models/utils/no_obj.py ADDED
@@ -0,0 +1 @@
1
+ NO_OBJ = 65535
app/models/utils/video_gt_preprocess.py ADDED
@@ -0,0 +1,87 @@
1
+ import torch
2
+
3
+
4
+ def preprocess_video_panoptic_gt(
5
+ gt_labels,
6
+ gt_masks,
7
+ gt_semantic_seg,
8
+ gt_instance_ids,
9
+ num_things,
10
+ num_stuff,
11
+ ):
12
+ num_classes = num_things + num_stuff
13
+ num_frames = len(gt_masks)
14
+ mask_size = gt_masks[0].masks.shape[-2:]
15
+
16
+ thing_masks_list = []
17
+ for frame_id in range(num_frames):
18
+ thing_masks_list.append(gt_masks[frame_id].pad(
19
+ mask_size, pad_val=0).to_tensor(
20
+ dtype=torch.bool, device=gt_labels.device)
21
+ )
22
+ instances = torch.unique(gt_instance_ids[:, 1])
23
+ things_masks = []
24
+ labels = []
25
+ for instance in instances:
26
+ pos_ins = torch.nonzero(torch.eq(gt_instance_ids[:, 1], instance), as_tuple=True)[0]  # [0] unpacks the one-element tuple returned by nonzero
27
+ labels_instance = gt_labels[:, 1][pos_ins]
28
+ assert torch.allclose(labels_instance, labels_instance[0])
29
+ labels.append(labels_instance[0])
30
+ instance_frame_ids = gt_instance_ids[:, 0][pos_ins].to(dtype=torch.int32).tolist()
31
+ instance_masks = []
32
+ for frame_id in range(num_frames):
33
+ frame_instance_ids = gt_instance_ids[gt_instance_ids[:, 0] == frame_id, 1]
34
+ if frame_id not in instance_frame_ids:
35
+ empty_mask = torch.zeros(
36
+ mask_size,
37
+ dtype=thing_masks_list[frame_id].dtype, device=thing_masks_list[frame_id].device
38
+ )
39
+ instance_masks.append(empty_mask)
40
+ else:
41
+ pos_inner_frame = torch.nonzero(torch.eq(frame_instance_ids, instance), as_tuple=True)[0].item()
42
+ frame_mask = thing_masks_list[frame_id][pos_inner_frame]
43
+ instance_masks.append(frame_mask)
44
+ things_masks.append(torch.stack(instance_masks))
45
+
46
+ if len(instances) == 0:
47
+ things_masks = torch.stack(thing_masks_list, dim=1)
48
+ labels = torch.empty_like(instances)
49
+ else:
50
+ things_masks = torch.stack(things_masks)
51
+ labels = torch.stack(labels)
52
+ assert torch.all(torch.less(labels, num_things))
53
+
54
+ if gt_semantic_seg is not None:
55
+ things_labels = labels
56
+ gt_semantic_seg = gt_semantic_seg.squeeze(1)
57
+
58
+ semantic_labels = torch.unique(
59
+ gt_semantic_seg,
60
+ sorted=False,
61
+ return_inverse=False,
62
+ return_counts=False)
63
+ stuff_masks_list = []
64
+ stuff_labels_list = []
65
+ for label in semantic_labels:
66
+ if label < num_things or label >= num_classes:
67
+ continue
68
+ stuff_mask = gt_semantic_seg == label
69
+ stuff_masks_list.append(stuff_mask)
70
+ stuff_labels_list.append(label)
71
+
72
+ if len(stuff_masks_list) > 0:
73
+ stuff_masks = torch.stack(stuff_masks_list, dim=0)
74
+ stuff_labels = torch.stack(stuff_labels_list, dim=0)
75
+ assert torch.all(torch.ge(stuff_labels, num_things)) and torch.all(torch.less(stuff_labels, num_classes))
76
+ labels = torch.cat([things_labels, stuff_labels], dim=0)
77
+ masks = torch.cat([things_masks, stuff_masks], dim=0)
78
+ else:
79
+ labels = things_labels
80
+ masks = things_masks
81
+ assert len(labels) == len(masks)
82
+ else:
83
+ masks = things_masks
84
+
85
+ labels = labels.to(dtype=torch.long)
86
+ masks = masks.to(dtype=torch.long)
87
+ return labels, masks
ext/meta/sam_meta.py ADDED
@@ -0,0 +1,41 @@
1
+ meta_dict = {
2
+ 'vit_h': dict(
3
+ encoder_embed_dim=1280,
4
+ encoder_depth=32,
5
+ encoder_num_heads=16,
6
+ encoder_global_attn_indexes=[7, 15, 23, 31],
7
+ # common
8
+ prompt_embed_dim=256,
9
+ image_size=1024,
10
+ vit_patch_size=16,
11
+ image_embedding_size=64
12
+ ),
13
+ 'vit_l': dict(
14
+ encoder_embed_dim=1024,
15
+ encoder_depth=24,
16
+ encoder_num_heads=16,
17
+ encoder_global_attn_indexes=[5, 11, 17, 23],
18
+ # common
19
+ prompt_embed_dim=256,
20
+ image_size=1024,
21
+ vit_patch_size=16,
22
+ image_embedding_size=64
23
+ ),
24
+ 'vit_b': dict(
25
+ encoder_embed_dim=768,
26
+ encoder_depth=12,
27
+ encoder_num_heads=12,
28
+ encoder_global_attn_indexes=[2, 5, 8, 11],
29
+ # common
30
+ prompt_embed_dim=256,
31
+ image_size=1024,
32
+ vit_patch_size=16,
33
+ image_embedding_size=64
34
+ )
35
+ }
36
+
37
+ checkpoint_dict = {
38
+ 'vit_h': 'https://dl.fbaipublicfiles.com/segment_anything/sam_vit_h_4b8939.pth',
39
+ 'vit_l': 'https://dl.fbaipublicfiles.com/segment_anything/sam_vit_l_0b3195.pth',
40
+ 'vit_b': 'https://dl.fbaipublicfiles.com/segment_anything/sam_vit_b_01ec64.pth',
41
+ }
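
Note (illustrative usage, not part of the committed file): the two dictionaries are meant to be looked up together, e.g. to configure a SAM image encoder of a given size and fetch its official weights:

    cfg = meta_dict['vit_b']
    print(cfg['encoder_embed_dim'], cfg['encoder_depth'], cfg['image_size'])
    # -> 768 12 1024
    print(checkpoint_dict['vit_b'])
    # -> https://dl.fbaipublicfiles.com/segment_anything/sam_vit_b_01ec64.pth
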
ext/open_clip/__init__.py ADDED
@@ -0,0 +1,15 @@
1
+ from .coca_model import CoCa
2
+ from .constants import OPENAI_DATASET_MEAN, OPENAI_DATASET_STD
3
+ from .factory import create_model, create_model_and_transforms, create_model_from_pretrained, get_tokenizer, create_loss
4
+ from .factory import list_models, add_model_config, get_model_config, load_checkpoint
5
+ from .loss import ClipLoss, DistillClipLoss, CoCaLoss
6
+ from .model import CLIP, CustomTextCLIP, CLIPTextCfg, CLIPVisionCfg, \
7
+ convert_weights_to_lp, convert_weights_to_fp16, trace_model, get_cast_dtype, get_input_dtype
8
+ from .openai import load_openai_model, list_openai_models
9
+ from .pretrained import list_pretrained, list_pretrained_models_by_tag, list_pretrained_tags_by_model, \
10
+ get_pretrained_url, download_pretrained_from_url, is_pretrained_cfg, get_pretrained_cfg, download_pretrained
11
+ from .push_to_hf_hub import push_pretrained_to_hf_hub, push_to_hf_hub
12
+ from .tokenizer import SimpleTokenizer, tokenize, decode
13
+ from .transform import image_transform, AugmentationCfg
14
+ from .zero_shot_classifier import build_zero_shot_classifier, build_zero_shot_classifier_legacy
15
+ from .zero_shot_metadata import OPENAI_IMAGENET_TEMPLATES, SIMPLE_IMAGENET_TEMPLATES, IMAGENET_CLASSNAMES
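
Note (illustrative usage, not part of the committed file): the package exposes the usual open_clip entry points; assuming this vendored copy is importable as ext.open_clip and the weights can be downloaded, text features are obtained as follows:

    import torch
    from ext.open_clip import create_model_and_transforms, get_tokenizer

    model, _, preprocess = create_model_and_transforms('ViT-B-16', pretrained='openai')
    tokenizer = get_tokenizer('ViT-B-16')

    with torch.no_grad():
        text_feat = model.encode_text(tokenizer(['a photo of a cat']))
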
ext/open_clip/bpe_simple_vocab_16e6.txt.gz ADDED
@@ -0,0 +1,3 @@
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:924691ac288e54409236115652ad4aa250f48203de50a9e4722a6ecd48d6804a
3
+ size 1356917
ext/open_clip/coca_model.py ADDED
@@ -0,0 +1,458 @@
1
+ from typing import Optional
2
+
3
+ import torch
4
+ from torch import nn
5
+ from torch.nn import functional as F
6
+ import numpy as np
7
+ from dataclasses import dataclass
8
+
9
+ from .transformer import (
10
+ LayerNormFp32,
11
+ LayerNorm,
12
+ QuickGELU,
13
+ MultimodalTransformer,
14
+ )
15
+ from .model import CLIPTextCfg, CLIPVisionCfg, _build_vision_tower, _build_text_tower
16
+
17
+ try:
18
+ from transformers import (
19
+ BeamSearchScorer,
20
+ LogitsProcessorList,
21
+ TopPLogitsWarper,
22
+ TopKLogitsWarper,
23
+ RepetitionPenaltyLogitsProcessor,
24
+ MinLengthLogitsProcessor,
25
+ MaxLengthCriteria,
26
+ StoppingCriteriaList
27
+ )
28
+
29
+ GENERATION_TYPES = {
30
+ "top_k": TopKLogitsWarper,
31
+ "top_p": TopPLogitsWarper,
32
+ "beam_search": "beam_search"
33
+ }
34
+ _has_transformers = True
35
+ except ImportError as e:
36
+ GENERATION_TYPES = {
37
+ "top_k": None,
38
+ "top_p": None,
39
+ "beam_search": "beam_search"
40
+ }
41
+ _has_transformers = False
42
+
43
+
44
+ @dataclass
45
+ class MultimodalCfg(CLIPTextCfg):
46
+ mlp_ratio: int = 4
47
+ dim_head: int = 64
48
+ heads: int = 8
49
+ n_queries: int = 256
50
+ attn_pooler_heads: int = 8
51
+
52
+
53
+ def _build_text_decoder_tower(
54
+ embed_dim,
55
+ multimodal_cfg,
56
+ quick_gelu: bool = False,
57
+ cast_dtype: Optional[torch.dtype] = None,
58
+ ):
59
+ multimodal_cfg = MultimodalCfg(**multimodal_cfg) if isinstance(multimodal_cfg, dict) else multimodal_cfg
60
+ act_layer = QuickGELU if quick_gelu else nn.GELU
61
+ norm_layer = (
62
+ LayerNormFp32 if cast_dtype in (torch.float16, torch.bfloat16) else LayerNorm
63
+ )
64
+
65
+ decoder = MultimodalTransformer(
66
+ context_length=multimodal_cfg.context_length,
67
+ width=multimodal_cfg.width,
68
+ heads=multimodal_cfg.heads,
69
+ layers=multimodal_cfg.layers,
70
+ ls_init_value=multimodal_cfg.ls_init_value,
71
+ output_dim=embed_dim,
72
+ act_layer=act_layer,
73
+ norm_layer=norm_layer,
74
+ )
75
+
76
+ return decoder
77
+
78
+
79
+ class CoCa(nn.Module):
80
+ def __init__(
81
+ self,
82
+ embed_dim,
83
+ multimodal_cfg: MultimodalCfg,
84
+ text_cfg: CLIPTextCfg,
85
+ vision_cfg: CLIPVisionCfg,
86
+ quick_gelu: bool = False,
87
+ cast_dtype: Optional[torch.dtype] = None,
88
+ pad_id: int = 0,
89
+ ):
90
+ super().__init__()
91
+ multimodal_cfg = MultimodalCfg(**multimodal_cfg) if isinstance(multimodal_cfg, dict) else multimodal_cfg
92
+ text_cfg = CLIPTextCfg(**text_cfg) if isinstance(text_cfg, dict) else text_cfg
93
+ vision_cfg = CLIPVisionCfg(**vision_cfg) if isinstance(vision_cfg, dict) else vision_cfg
94
+
95
+ self.text = _build_text_tower(
96
+ embed_dim=embed_dim,
97
+ text_cfg=text_cfg,
98
+ quick_gelu=quick_gelu,
99
+ cast_dtype=cast_dtype,
100
+ )
101
+
102
+ vocab_size = (
103
+ text_cfg.vocab_size # for hf models
104
+ if hasattr(text_cfg, "hf_model_name") and text_cfg.hf_model_name is not None
105
+ else text_cfg.vocab_size
106
+ )
107
+
108
+ self.visual = _build_vision_tower(
109
+ embed_dim=embed_dim,
110
+ vision_cfg=vision_cfg,
111
+ quick_gelu=quick_gelu,
112
+ cast_dtype=cast_dtype,
113
+ )
114
+
115
+ self.text_decoder = _build_text_decoder_tower(
116
+ vocab_size,
117
+ multimodal_cfg=multimodal_cfg,
118
+ quick_gelu=quick_gelu,
119
+ cast_dtype=cast_dtype,
120
+ )
121
+
122
+ self.logit_scale = nn.Parameter(torch.ones([]) * np.log(1 / 0.07))
123
+ self.pad_id = pad_id
124
+
125
+ @torch.jit.ignore
126
+ def set_grad_checkpointing(self, enable=True):
127
+ self.visual.set_grad_checkpointing(enable)
128
+ self.text.set_grad_checkpointing(enable)
129
+ self.text_decoder.set_grad_checkpointing(enable)
130
+
131
+ def _encode_image(self, images, normalize=True):
132
+ image_latent, tokens_embs = self.visual(images)
133
+ image_latent = F.normalize(image_latent, dim=-1) if normalize else image_latent
134
+ return image_latent, tokens_embs
135
+
136
+ def _encode_text(self, text, normalize=True, embed_cls=True):
137
+ text = text[:, :-1] if embed_cls else text # make space for CLS token
138
+ text_latent, token_emb = self.text(text)
139
+ text_latent = F.normalize(text_latent, dim=-1) if normalize else text_latent
140
+ return text_latent, token_emb
141
+
142
+ def encode_image(self, images, normalize=True):
143
+ image_latent, _ = self._encode_image(images, normalize=normalize)
144
+ return image_latent
145
+
146
+ def encode_text(self, text, normalize=True, embed_cls=True):
147
+ text_latent, _ = self._encode_text(text, normalize=normalize, embed_cls=embed_cls)
148
+ return text_latent
149
+
150
+ def forward(self, image, text, embed_cls=True, image_latent=None, image_embs=None):
151
+ text_latent, token_embs = self._encode_text(text, embed_cls=embed_cls)
152
+ if image_latent is None or image_embs is None:
153
+ image_latent, image_embs = self._encode_image(image)
154
+
155
+ # TODO: add assertion to avoid bugs?
156
+ labels = text[:, -token_embs.shape[1]:]
157
+
158
+ logits = self.text_decoder(image_embs, token_embs)
159
+ return {
160
+ "image_features": image_latent,
161
+ "text_features": text_latent,
162
+ "logits": logits,
163
+ "labels": labels,
164
+ "logit_scale": self.logit_scale.exp()
165
+ }
166
+
167
+ def generate(
168
+ self,
169
+ image,
170
+ text=None,
171
+ seq_len=30,
172
+ max_seq_len=77,
173
+ temperature=1.,
174
+ generation_type="beam_search",
175
+ top_p=0.1, # keep tokens in the 1 - top_p quantile
176
+ top_k=1, # keeps the top_k most probable tokens
177
+ pad_token_id=None,
178
+ eos_token_id=None,
179
+ sot_token_id=None,
180
+ num_beams=6,
181
+ num_beam_groups=3,
182
+ min_seq_len=5,
183
+ stopping_criteria=None,
184
+ repetition_penalty=1.0,
185
+ fixed_output_length=False # if True output.shape == (batch_size, seq_len)
186
+ ):
187
+ # taking many ideas and components from HuggingFace GenerationMixin
188
+ # https://huggingface.co/docs/transformers/main/en/main_classes/text_generation
189
+ assert _has_transformers, "Please install transformers for generate functionality. `pip install transformers`."
190
+ assert seq_len > min_seq_len, "seq_len must be larger than min_seq_len"
191
+
192
+ with torch.no_grad():
193
+ sot_token_id = 49406 if sot_token_id is None else sot_token_id
194
+ eos_token_id = 49407 if eos_token_id is None else eos_token_id
195
+ pad_token_id = self.pad_id if pad_token_id is None else pad_token_id
196
+ logit_processor = LogitsProcessorList(
197
+ [
198
+ MinLengthLogitsProcessor(min_seq_len, eos_token_id),
199
+ RepetitionPenaltyLogitsProcessor(repetition_penalty),
200
+ ]
201
+ )
202
+
203
+ if stopping_criteria is None:
204
+ stopping_criteria = [MaxLengthCriteria(max_length=seq_len)]
205
+
206
+ stopping_criteria = StoppingCriteriaList(
207
+ stopping_criteria
208
+ )
209
+
210
+ device = image.device
211
+
212
+ if generation_type == "beam_search":
213
+ output = self._generate_beamsearch(
214
+ image_inputs = image,
215
+ pad_token_id=pad_token_id,
216
+ eos_token_id=eos_token_id,
217
+ sot_token_id=sot_token_id,
218
+ num_beams=num_beams,
219
+ num_beam_groups=num_beam_groups,
220
+ min_seq_len=min_seq_len,
221
+ stopping_criteria=stopping_criteria,
222
+ logit_processor=logit_processor,
223
+ )
224
+ if fixed_output_length and output.shape[1] < seq_len:
225
+ return torch.cat(
226
+ (output, torch.ones(output.shape[0], seq_len-output.shape[1], device=device, dtype=output.dtype) * self.pad_id),
227
+ dim=1
228
+ )
229
+ return output
230
+
231
+ elif generation_type == "top_p":
232
+ logit_warper = GENERATION_TYPES[generation_type](top_p)
233
+ elif generation_type == "top_k":
234
+ logit_warper = GENERATION_TYPES[generation_type](top_k)
235
+ else:
236
+ raise ValueError(
237
+ f"generation_type has to be one of "
238
+ f"{'| ' + ' | '.join(list(GENERATION_TYPES.keys())) + ' |'}."
239
+ )
240
+
241
+ image_latent, image_embs = self._encode_image(image)
242
+
243
+ if text is None:
244
+ text = torch.ones((image.shape[0], 1), device=device, dtype=torch.long) * sot_token_id
245
+
246
+ was_training = self.training
247
+ num_dims = len(text.shape)
248
+
249
+ if num_dims == 1:
250
+ text = text[None, :]
251
+
252
+ cur_len = text.shape[1]
253
+ self.eval()
254
+ out = text
255
+
256
+ while True:
257
+ x = out[:, -max_seq_len:]
258
+ cur_len = x.shape[1]
259
+ logits = self(image, x, image_latent=image_latent, image_embs=image_embs, embed_cls=False)["logits"][:, -1]
260
+ mask = (out[:, -1] == eos_token_id) | (out[:, -1] == pad_token_id)
261
+ sample = torch.ones((out.shape[0], 1), device=device, dtype=torch.long) * pad_token_id
262
+
263
+ if mask.all():
264
+ if not fixed_output_length:
265
+ break
266
+ else:
267
+ logits = logits[~mask, :]
268
+ filtered_logits = logit_processor(x[~mask, :], logits)
269
+ filtered_logits = logit_warper(x[~mask, :], filtered_logits)
270
+ probs = F.softmax(filtered_logits / temperature, dim=-1)
271
+
272
+ if (cur_len + 1 == seq_len):
273
+ sample[~mask, :] = torch.ones((sum(~mask), 1), device=device, dtype=torch.long) * eos_token_id
274
+ else:
275
+ sample[~mask, :] = torch.multinomial(probs, 1)
276
+
277
+ out = torch.cat((out, sample), dim=-1)
278
+
279
+ cur_len += 1
280
+
281
+ if stopping_criteria(out, None):
282
+ break
283
+
284
+ if num_dims == 1:
285
+ out = out.squeeze(0)
286
+
287
+ self.train(was_training)
288
+ return out
289
+
290
+ def _generate_beamsearch(
291
+ self,
292
+ image_inputs,
293
+ pad_token_id=None,
294
+ eos_token_id=None,
295
+ sot_token_id=None,
296
+ num_beams=6,
297
+ num_beam_groups=3,
298
+ min_seq_len=5,
299
+ stopping_criteria=None,
300
+ logit_processor=None,
301
+ logit_warper=None,
302
+ ):
303
+ device = image_inputs.device
304
+ batch_size = image_inputs.shape[0]
305
+ image_inputs = torch.repeat_interleave(image_inputs, num_beams, dim=0)
306
+ image_latent, image_embs = self._encode_image(image_inputs)
307
+
308
+ input_ids = torch.ones((batch_size * num_beams, 1), device=device, dtype=torch.long)
309
+ input_ids = input_ids * sot_token_id
310
+ beam_scorer = BeamSearchScorer(
311
+ batch_size=batch_size,
312
+ num_beams=num_beams,
313
+ device=device,
314
+ num_beam_groups=num_beam_groups,
315
+ )
316
+ # instantiate logits processors
317
+ logits_processor = (
318
+ LogitsProcessorList([MinLengthLogitsProcessor(min_seq_len, eos_token_id=eos_token_id)])
319
+ if logit_processor is None
320
+ else logit_processor
321
+ )
322
+
323
+ batch_size = len(beam_scorer._beam_hyps)
324
+ num_beams = beam_scorer.num_beams
325
+ num_beam_groups = beam_scorer.num_beam_groups
326
+ num_sub_beams = num_beams // num_beam_groups
327
+ batch_beam_size, cur_len = input_ids.shape
328
+ beam_indices = None
329
+
330
+ if num_beams * batch_size != batch_beam_size:
331
+ raise ValueError(
332
+ f"Batch dimension of `input_ids` should be {num_beams * batch_size}, but is {batch_beam_size}."
333
+ )
334
+
335
+ beam_scores = torch.full((batch_size, num_beams), -1e9, dtype=torch.float, device=device)
336
+ # initialise the score of the first beam of each group with 0 and the rest with -1e9. This ensures that the beams in
337
+ # the same group don't produce the same tokens every time.
338
+ beam_scores[:, ::num_sub_beams] = 0
339
+ beam_scores = beam_scores.view((batch_size * num_beams,))
340
+
341
+ while True:
342
+
343
+ # predicted tokens in cur_len step
344
+ current_tokens = torch.zeros(batch_size * num_beams, dtype=input_ids.dtype, device=device)
345
+
346
+ # indices which will form the beams in the next time step
347
+ reordering_indices = torch.zeros(batch_size * num_beams, dtype=torch.long, device=device)
348
+
349
+ # do one decoder step on all beams of all sentences in batch
350
+ model_inputs = prepare_inputs_for_generation(input_ids=input_ids, image_inputs=image_inputs)
351
+ outputs = self(
352
+ model_inputs['images'],
353
+ model_inputs['text'],
354
+ embed_cls=False,
355
+ image_latent=image_latent,
356
+ image_embs=image_embs
357
+ )
358
+
359
+ for beam_group_idx in range(num_beam_groups):
360
+ group_start_idx = beam_group_idx * num_sub_beams
361
+ group_end_idx = min(group_start_idx + num_sub_beams, num_beams)
362
+ group_size = group_end_idx - group_start_idx
363
+
364
+ # indices of beams of current group among all sentences in batch
365
+ batch_group_indices = []
366
+
367
+ for batch_idx in range(batch_size):
368
+ batch_group_indices.extend(
369
+ [batch_idx * num_beams + idx for idx in range(group_start_idx, group_end_idx)]
370
+ )
371
+ group_input_ids = input_ids[batch_group_indices]
372
+
373
+ # select outputs of beams of the current group only
374
+ next_token_logits = outputs['logits'][batch_group_indices, -1, :]
375
+ vocab_size = next_token_logits.shape[-1]
376
+
377
+ next_token_scores_processed = logits_processor(
378
+ group_input_ids, next_token_logits, current_tokens=current_tokens, beam_group_idx=beam_group_idx
379
+ )
380
+ next_token_scores = next_token_scores_processed + beam_scores[batch_group_indices].unsqueeze(-1)
381
+ next_token_scores = next_token_scores.expand_as(next_token_scores_processed)
382
+
383
+ # reshape for beam search
384
+ next_token_scores = next_token_scores.view(batch_size, group_size * vocab_size)
385
+
386
+ next_token_scores, next_tokens = torch.topk(
387
+ next_token_scores, 2 * group_size, dim=1, largest=True, sorted=True
388
+ )
389
+
390
+ next_indices = torch.div(next_tokens, vocab_size, rounding_mode="floor")
391
+ next_tokens = next_tokens % vocab_size
392
+
393
+ # stateless
394
+ process_beam_indices = sum(beam_indices, ()) if beam_indices is not None else None
395
+ beam_outputs = beam_scorer.process(
396
+ group_input_ids,
397
+ next_token_scores,
398
+ next_tokens,
399
+ next_indices,
400
+ pad_token_id=pad_token_id,
401
+ eos_token_id=eos_token_id,
402
+ beam_indices=process_beam_indices,
403
+ )
404
+ beam_scores[batch_group_indices] = beam_outputs["next_beam_scores"]
405
+ beam_next_tokens = beam_outputs["next_beam_tokens"]
406
+ beam_idx = beam_outputs["next_beam_indices"]
407
+
408
+ input_ids[batch_group_indices] = group_input_ids[beam_idx]
409
+ group_input_ids = torch.cat([group_input_ids[beam_idx, :], beam_next_tokens.unsqueeze(-1)], dim=-1)
410
+ current_tokens[batch_group_indices] = group_input_ids[:, -1]
411
+
412
+ # (beam_idx // group_size) -> batch_idx
413
+ # (beam_idx % group_size) -> offset of idx inside the group
414
+ reordering_indices[batch_group_indices] = (
415
+ num_beams * torch.div(beam_idx, group_size, rounding_mode="floor") + group_start_idx + (beam_idx % group_size)
416
+ )
417
+
418
+ input_ids = torch.cat([input_ids, current_tokens.unsqueeze(-1)], dim=-1)
419
+
420
+ # increase cur_len
421
+ cur_len = cur_len + 1
422
+ if beam_scorer.is_done or stopping_criteria(input_ids, None):
423
+ break
424
+
425
+ final_beam_indices = sum(beam_indices, ()) if beam_indices is not None else None
426
+ sequence_outputs = beam_scorer.finalize(
427
+ input_ids,
428
+ beam_scores,
429
+ next_tokens,
430
+ next_indices,
431
+ pad_token_id=pad_token_id,
432
+ eos_token_id=eos_token_id,
433
+ max_length=stopping_criteria.max_length,
434
+ beam_indices=final_beam_indices,
435
+ )
436
+ return sequence_outputs['sequences']
437
+
438
+
439
+ def prepare_inputs_for_generation(input_ids, image_inputs, past=None, **kwargs):
440
+ if past:
441
+ input_ids = input_ids[:, -1].unsqueeze(-1)
442
+
443
+ attention_mask = kwargs.get("attention_mask", None)
444
+ position_ids = kwargs.get("position_ids", None)
445
+
446
+ if attention_mask is not None and position_ids is None:
447
+ # create position_ids on the fly for batch generation
448
+ position_ids = attention_mask.long().cumsum(-1) - 1
449
+ position_ids.masked_fill_(attention_mask == 0, 1)
450
+ else:
451
+ position_ids = None
452
+ return {
453
+ "text": input_ids,
454
+ "images": image_inputs,
455
+ "past_key_values": past,
456
+ "position_ids": position_ids,
457
+ "attention_mask": attention_mask,
458
+ }
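
Note (illustrative usage, not part of the committed file): generate() supports beam search (the default) as well as top-k / top-p sampling and requires the transformers package. Assuming `model` is a CoCa instance built elsewhere:

    import torch

    images = torch.randn(2, 3, 224, 224)          # placeholder batch
    with torch.no_grad():
        tokens = model.generate(images, generation_type='top_k', top_k=5, seq_len=20)
    # tokens holds generated token ids; decode them with the matching tokenizer.
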
ext/open_clip/constants.py ADDED
@@ -0,0 +1,2 @@
1
+ OPENAI_DATASET_MEAN = (0.48145466, 0.4578275, 0.40821073)
2
+ OPENAI_DATASET_STD = (0.26862954, 0.26130258, 0.27577711)
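
Note (illustrative usage, not part of the committed file): these are the RGB statistics used to normalise inputs to OpenAI CLIP models; image_transform in transform.py applies them by default, and a manual preprocessing pipeline would look like:

    from torchvision import transforms

    to_clip_input = transforms.Compose([
        transforms.Resize(224),
        transforms.CenterCrop(224),
        transforms.ToTensor(),
        transforms.Normalize(mean=OPENAI_DATASET_MEAN, std=OPENAI_DATASET_STD),
    ])
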
ext/open_clip/factory.py ADDED
@@ -0,0 +1,387 @@
1
+ import json
2
+ import logging
3
+ import os
4
+ import pathlib
5
+ import re
6
+ from copy import deepcopy
7
+ from pathlib import Path
8
+ from typing import Any, Dict, Optional, Tuple, Union
9
+
10
+ import torch
11
+
12
+ from .constants import OPENAI_DATASET_MEAN, OPENAI_DATASET_STD
13
+ from .model import CLIP, CustomTextCLIP, convert_weights_to_lp, convert_to_custom_text_state_dict,\
14
+ resize_pos_embed, get_cast_dtype
15
+ from .coca_model import CoCa
16
+ from .loss import ClipLoss, DistillClipLoss, CoCaLoss
17
+ from .openai import load_openai_model
18
+ from .pretrained import is_pretrained_cfg, get_pretrained_cfg, download_pretrained,\
19
+ list_pretrained_tags_by_model, download_pretrained_from_hf
20
+ from .transform import image_transform, AugmentationCfg
21
+ from .tokenizer import HFTokenizer, tokenize
22
+
23
+
24
+ HF_HUB_PREFIX = 'hf-hub:'
25
+ _MODEL_CONFIG_PATHS = [Path(__file__).parent / "model_configs/"]
26
+ _MODEL_CONFIGS = {} # dictionary (model_name: config) of model architecture configs
27
+
28
+
29
+ def _natural_key(string_):
30
+ return [int(s) if s.isdigit() else s for s in re.split(r'(\d+)', string_.lower())]
31
+
32
+
33
+ def _rescan_model_configs():
34
+ global _MODEL_CONFIGS
35
+
36
+ config_ext = ('.json',)
37
+ config_files = []
38
+ for config_path in _MODEL_CONFIG_PATHS:
39
+ if config_path.is_file() and config_path.suffix in config_ext:
40
+ config_files.append(config_path)
41
+ elif config_path.is_dir():
42
+ for ext in config_ext:
43
+ config_files.extend(config_path.glob(f'*{ext}'))
44
+
45
+ for cf in config_files:
46
+ with open(cf, 'r') as f:
47
+ model_cfg = json.load(f)
48
+ if all(a in model_cfg for a in ('embed_dim', 'vision_cfg', 'text_cfg')):
49
+ _MODEL_CONFIGS[cf.stem] = model_cfg
50
+
51
+ _MODEL_CONFIGS = {k: v for k, v in sorted(_MODEL_CONFIGS.items(), key=lambda x: _natural_key(x[0]))}
52
+
53
+
54
+ _rescan_model_configs() # initial populate of model config registry
55
+
56
+
57
+ def list_models():
58
+ """ enumerate available model architectures based on config files """
59
+ return list(_MODEL_CONFIGS.keys())
60
+
61
+
62
+ def add_model_config(path):
63
+ """ add model config path or file and update registry """
64
+ if not isinstance(path, Path):
65
+ path = Path(path)
66
+ _MODEL_CONFIG_PATHS.append(path)
67
+ _rescan_model_configs()
68
+
69
+
70
+ def get_model_config(model_name):
71
+ if model_name in _MODEL_CONFIGS:
72
+ return deepcopy(_MODEL_CONFIGS[model_name])
73
+ else:
74
+ return None
75
+
76
+
77
+ def get_tokenizer(model_name):
78
+ if model_name.startswith(HF_HUB_PREFIX):
79
+ tokenizer = HFTokenizer(model_name[len(HF_HUB_PREFIX):])
80
+ else:
81
+ config = get_model_config(model_name)
82
+ tokenizer = HFTokenizer(
83
+ config['text_cfg']['hf_tokenizer_name']) if 'hf_tokenizer_name' in config['text_cfg'] else tokenize
84
+ return tokenizer
85
+
86
+
87
+ def load_state_dict(checkpoint_path: str, map_location='cpu'):
88
+ checkpoint = torch.load(checkpoint_path, map_location=map_location)
89
+ if isinstance(checkpoint, dict) and 'state_dict' in checkpoint:
90
+ state_dict = checkpoint['state_dict']
91
+ else:
92
+ state_dict = checkpoint
93
+ if next(iter(state_dict.items()))[0].startswith('module'):
94
+ state_dict = {k[7:]: v for k, v in state_dict.items()}
95
+ return state_dict
96
+
97
+
98
+ def load_checkpoint(model, checkpoint_path, strict=True):
99
+ state_dict = load_state_dict(checkpoint_path)
100
+ # detect old format and make compatible with new format
101
+ if 'positional_embedding' in state_dict and not hasattr(model, 'positional_embedding'):
102
+ state_dict = convert_to_custom_text_state_dict(state_dict)
103
+ resize_pos_embed(state_dict, model)
104
+ incompatible_keys = model.load_state_dict(state_dict, strict=strict)
105
+ return incompatible_keys
106
+
107
+
108
+ def create_model(
109
+ model_name: str,
110
+ pretrained: Optional[str] = None,
111
+ precision: str = 'fp32',
112
+ device: Union[str, torch.device] = 'cpu',
113
+ jit: bool = False,
114
+ force_quick_gelu: bool = False,
115
+ force_custom_text: bool = False,
116
+ force_patch_dropout: Optional[float] = None,
117
+ force_image_size: Optional[Union[int, Tuple[int, int]]] = None,
118
+ pretrained_image: bool = False,
119
+ pretrained_hf: bool = True,
120
+ cache_dir: Optional[str] = None,
121
+ output_dict: Optional[bool] = None,
122
+ require_pretrained: bool = False,
123
+ logger: logging.Logger = logging,
124
+ ):
125
+ has_hf_hub_prefix = model_name.startswith(HF_HUB_PREFIX)
126
+ if has_hf_hub_prefix:
127
+ model_id = model_name[len(HF_HUB_PREFIX):]
128
+ checkpoint_path = download_pretrained_from_hf(model_id, cache_dir=cache_dir)
129
+ config_path = download_pretrained_from_hf(model_id, filename='open_clip_config.json', cache_dir=cache_dir)
130
+
131
+ with open(config_path, 'r', encoding='utf-8') as f:
132
+ config = json.load(f)
133
+ pretrained_cfg = config['preprocess_cfg']
134
+ model_cfg = config['model_cfg']
135
+ else:
136
+ model_name = model_name.replace('/', '-') # for callers using old naming with / in ViT names
137
+ checkpoint_path = None
138
+ pretrained_cfg = {}
139
+ model_cfg = None
140
+
141
+ if isinstance(device, str):
142
+ device = torch.device(device)
143
+
144
+ if pretrained and pretrained.lower() == 'openai':
145
+ logger.info(f'Loading pretrained {model_name} from OpenAI.')
146
+ model = load_openai_model(
147
+ model_name,
148
+ precision=precision,
149
+ device=device,
150
+ cache_dir=cache_dir,
151
+ )
152
+ else:
153
+ model_cfg = model_cfg or get_model_config(model_name)
154
+ if model_cfg is not None:
155
+ logger.info(f'Loaded {model_name} model config.')
156
+ else:
157
+ logger.error(f'Model config for {model_name} not found; available models {list_models()}.')
158
+ raise RuntimeError(f'Model config for {model_name} not found.')
159
+
160
+ if force_quick_gelu:
161
+ # override for use of QuickGELU on non-OpenAI transformer models
162
+ model_cfg["quick_gelu"] = True
163
+
164
+ if force_patch_dropout is not None:
165
+ # override the default patch dropout value
166
+ model_cfg["vision_cfg"]["patch_dropout"] = force_patch_dropout
167
+
168
+ if force_image_size is not None:
169
+ # override model config's image size
170
+ model_cfg["vision_cfg"]["image_size"] = force_image_size
171
+
172
+ is_timm_model = 'timm_model_name' in model_cfg.get('vision_cfg', {})
173
+ if pretrained_image:
174
+ if is_timm_model:
175
+ # pretrained weight loading for timm models set via vision_cfg
176
+ model_cfg['vision_cfg']['timm_model_pretrained'] = True
177
+ else:
178
+ assert False, 'pretrained image towers currently only supported for timm models'
179
+
180
+ # cast_dtype set for fp16 and bf16 (manual mixed-precision), not set for 'amp' or 'pure' modes
181
+ cast_dtype = get_cast_dtype(precision)
182
+ is_hf_model = 'hf_model_name' in model_cfg.get('text_cfg', {})
183
+ custom_text = model_cfg.pop('custom_text', False) or force_custom_text or is_hf_model
184
+
185
+ if custom_text:
186
+ if is_hf_model:
187
+ model_cfg['text_cfg']['hf_model_pretrained'] = pretrained_hf
188
+ if "coca" in model_name:
189
+ model = CoCa(**model_cfg, cast_dtype=cast_dtype)
190
+ else:
191
+ model = CustomTextCLIP(**model_cfg, cast_dtype=cast_dtype)
192
+ else:
193
+ model = CLIP(**model_cfg, cast_dtype=cast_dtype)
194
+
195
+ if precision in ("fp16", "bf16"):
196
+ dtype = torch.float16 if 'fp16' in precision else torch.bfloat16
197
+ # manual mixed precision that matches original OpenAI behaviour
198
+ if is_timm_model:
199
+ # FIXME this is a bit janky, create timm based model in low-precision and
200
+ # then cast only LayerNormFp32 instances back to float32 so they don't break.
201
+ # Why? The convert_weights_to_lp fn only works with native models.
202
+ model.to(device=device, dtype=dtype)
203
+ from .transformer import LayerNormFp32
204
+ def _convert_ln(m):
205
+ if isinstance(m, LayerNormFp32):
206
+ m.weight.data = m.weight.data.to(torch.float32)
207
+ m.bias.data = m.bias.data.to(torch.float32)
208
+ model.apply(_convert_ln)
209
+ else:
210
+ model.to(device=device)
211
+ convert_weights_to_lp(model, dtype=dtype)
212
+ elif precision in ("pure_fp16", "pure_bf16"):
213
+ dtype = torch.float16 if 'fp16' in precision else torch.bfloat16
214
+ model.to(device=device, dtype=dtype)
215
+ else:
216
+ model.to(device=device)
217
+
218
+ pretrained_loaded = False
219
+ if pretrained:
220
+ checkpoint_path = ''
221
+ pretrained_cfg = get_pretrained_cfg(model_name, pretrained)
222
+ if pretrained_cfg:
223
+ checkpoint_path = download_pretrained(pretrained_cfg, cache_dir=cache_dir)
224
+ elif os.path.exists(pretrained):
225
+ checkpoint_path = pretrained
226
+
227
+ if checkpoint_path:
228
+ logger.info(f'Loading pretrained {model_name} weights ({pretrained}).')
229
+ load_checkpoint(model, checkpoint_path)
230
+ else:
231
+ error_str = (
232
+ f'Pretrained weights ({pretrained}) not found for model {model_name}. '
233
+ f'Available pretrained tags ({list_pretrained_tags_by_model(model_name)}).')
234
+ logger.warning(error_str)
235
+ raise RuntimeError(error_str)
236
+ pretrained_loaded = True
237
+ elif has_hf_hub_prefix:
238
+ logger.info(f'Loading pretrained {model_name} weights ({pretrained}).')
239
+ load_checkpoint(model, checkpoint_path)
240
+ pretrained_loaded = True
241
+
242
+ if require_pretrained and not pretrained_loaded:
243
+ # callers of create_model_from_pretrained always expect pretrained weights
244
+ raise RuntimeError(
245
+ f'Pretrained weights were required for (model: {model_name}, pretrained: {pretrained}) but not loaded.')
246
+
247
+ # set image mean / std metadata from pretrained_cfg if available, or use defaults
248
+ model.visual.image_mean = pretrained_cfg.get('mean', None) or OPENAI_DATASET_MEAN
249
+ model.visual.image_std = pretrained_cfg.get('std', None) or OPENAI_DATASET_STD
250
+
251
+ if output_dict and hasattr(model, "output_dict"):
252
+ model.output_dict = True
253
+
254
+ if jit:
255
+ model = torch.jit.script(model)
256
+
257
+ return model
258
+
259
+
260
+ def create_loss(args):
261
+ if args.distill:
262
+ return DistillClipLoss(
263
+ local_loss=args.local_loss,
264
+ gather_with_grad=args.gather_with_grad,
265
+ cache_labels=True,
266
+ rank=args.rank,
267
+ world_size=args.world_size,
268
+ use_horovod=args.horovod,
269
+ )
270
+ elif "coca" in args.model.lower():
271
+ return CoCaLoss(
272
+ caption_loss_weight=args.coca_caption_loss_weight,
273
+ clip_loss_weight=args.coca_contrastive_loss_weight,
274
+ local_loss=args.local_loss,
275
+ gather_with_grad=args.gather_with_grad,
276
+ cache_labels=True,
277
+ rank=args.rank,
278
+ world_size=args.world_size,
279
+ use_horovod=args.horovod,
280
+ )
281
+ return ClipLoss(
282
+ local_loss=args.local_loss,
283
+ gather_with_grad=args.gather_with_grad,
284
+ cache_labels=True,
285
+ rank=args.rank,
286
+ world_size=args.world_size,
287
+ use_horovod=args.horovod,
288
+ )
289
+
290
+
291
+ def create_model_and_transforms(
292
+ model_name: str,
293
+ pretrained: Optional[str] = None,
294
+ precision: str = 'fp32',
295
+ device: Union[str, torch.device] = 'cpu',
296
+ jit: bool = False,
297
+ force_quick_gelu: bool = False,
298
+ force_custom_text: bool = False,
299
+ force_patch_dropout: Optional[float] = None,
300
+ force_image_size: Optional[Union[int, Tuple[int, int]]] = None,
301
+ pretrained_image: bool = False,
302
+ pretrained_hf: bool = True,
303
+ image_mean: Optional[Tuple[float, ...]] = None,
304
+ image_std: Optional[Tuple[float, ...]] = None,
305
+ aug_cfg: Optional[Union[Dict[str, Any], AugmentationCfg]] = None,
306
+ cache_dir: Optional[str] = None,
307
+ output_dict: Optional[bool] = None,
308
+ logger: logging.Logger = logging,
309
+ ):
310
+ model = create_model(
311
+ model_name,
312
+ pretrained,
313
+ precision=precision,
314
+ device=device,
315
+ jit=jit,
316
+ force_quick_gelu=force_quick_gelu,
317
+ force_custom_text=force_custom_text,
318
+ force_patch_dropout=force_patch_dropout,
319
+ force_image_size=force_image_size,
320
+ pretrained_image=pretrained_image,
321
+ pretrained_hf=pretrained_hf,
322
+ cache_dir=cache_dir,
323
+ output_dict=output_dict,
324
+ logger=logger,
325
+ )
326
+
327
+ image_mean = image_mean or getattr(model.visual, 'image_mean', None)
328
+ image_std = image_std or getattr(model.visual, 'image_std', None)
329
+ preprocess_train = image_transform(
330
+ model.visual.image_size,
331
+ is_train=True,
332
+ mean=image_mean,
333
+ std=image_std,
334
+ aug_cfg=aug_cfg,
335
+ )
336
+ preprocess_val = image_transform(
337
+ model.visual.image_size,
338
+ is_train=False,
339
+ mean=image_mean,
340
+ std=image_std,
341
+ )
342
+
343
+ return model, preprocess_train, preprocess_val
344
+
345
+
346
+ def create_model_from_pretrained(
347
+ model_name: str,
348
+ pretrained: Optional[str] = None,
349
+ precision: str = 'fp32',
350
+ device: Union[str, torch.device] = 'cpu',
351
+ jit: bool = False,
352
+ force_quick_gelu: bool = False,
353
+ force_custom_text: bool = False,
354
+ force_image_size: Optional[Union[int, Tuple[int, int]]] = None,
355
+ return_transform: bool = True,
356
+ image_mean: Optional[Tuple[float, ...]] = None,
357
+ image_std: Optional[Tuple[float, ...]] = None,
358
+ cache_dir: Optional[str] = None,
359
+ logger: logging.Logger = logging,
360
+ ):
361
+ model = create_model(
362
+ model_name,
363
+ pretrained,
364
+ precision=precision,
365
+ device=device,
366
+ jit=jit,
367
+ force_quick_gelu=force_quick_gelu,
368
+ force_custom_text=force_custom_text,
369
+ force_image_size=force_image_size,
370
+ cache_dir=cache_dir,
371
+ require_pretrained=True,
372
+ logger=logger,
373
+ )
374
+
375
+ if not return_transform:
376
+ return model
377
+
378
+ image_mean = image_mean or getattr(model.visual, 'image_mean', None)
379
+ image_std = image_std or getattr(model.visual, 'image_std', None)
380
+ preprocess = image_transform(
381
+ model.visual.image_size,
382
+ is_train=False,
383
+ mean=image_mean,
384
+ std=image_std,
385
+ )
386
+
387
+ return model, preprocess
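The factory above is the main entry point of the vendored open_clip package: create_model resolves a JSON config by name and optionally loads pretrained weights, while create_model_and_transforms pairs the model with matching image transforms. A minimal usage sketch, assuming the repository root is on PYTHONPATH so the package imports as ext.open_clip, with a random tensor standing in for a preprocessed image:

    import torch
    from ext.open_clip.factory import create_model_and_transforms, get_tokenizer

    # 'ViT-B-32' resolves to model_configs/ViT-B-32.json below; no pretrained weights here.
    model, _, preprocess_val = create_model_and_transforms('ViT-B-32', pretrained=None, device='cpu')
    tokenizer = get_tokenizer('ViT-B-32')

    text = tokenizer(['a diagram', 'a dog', 'a cat'])  # (3, 77) token ids
    image = torch.randn(1, 3, 224, 224)                # stand-in for a preprocessed image

    with torch.no_grad():
        image_features = model.encode_image(image, normalize=True)
        text_features = model.encode_text(text, normalize=True)
        probs = (100.0 * image_features @ text_features.T).softmax(dim=-1)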
ext/open_clip/generation_utils.py ADDED
File without changes
ext/open_clip/hf_configs.py ADDED
@@ -0,0 +1,56 @@
1
+ # HF architecture dict:
2
+ arch_dict = {
3
+ # https://huggingface.co/docs/transformers/model_doc/roberta#roberta
4
+ "roberta": {
5
+ "config_names": {
6
+ "context_length": "max_position_embeddings",
7
+ "vocab_size": "vocab_size",
8
+ "width": "hidden_size",
9
+ "heads": "num_attention_heads",
10
+ "layers": "num_hidden_layers",
11
+ "layer_attr": "layer",
12
+ "token_embeddings_attr": "embeddings"
13
+ },
14
+ "pooler": "mean_pooler",
15
+ },
16
+ # https://huggingface.co/docs/transformers/model_doc/xlm-roberta#transformers.XLMRobertaConfig
17
+ "xlm-roberta": {
18
+ "config_names": {
19
+ "context_length": "max_position_embeddings",
20
+ "vocab_size": "vocab_size",
21
+ "width": "hidden_size",
22
+ "heads": "num_attention_heads",
23
+ "layers": "num_hidden_layers",
24
+ "layer_attr": "layer",
25
+ "token_embeddings_attr": "embeddings"
26
+ },
27
+ "pooler": "mean_pooler",
28
+ },
29
+ # https://huggingface.co/docs/transformers/model_doc/mt5#mt5
30
+ "mt5": {
31
+ "config_names": {
32
+ # unlimited seqlen
33
+ # https://github.com/google-research/text-to-text-transfer-transformer/issues/273
34
+ # https://github.com/huggingface/transformers/blob/v4.24.0/src/transformers/models/t5/modeling_t5.py#L374
35
+ "context_length": "",
36
+ "vocab_size": "vocab_size",
37
+ "width": "d_model",
38
+ "heads": "num_heads",
39
+ "layers": "num_layers",
40
+ "layer_attr": "block",
41
+ "token_embeddings_attr": "embed_tokens"
42
+ },
43
+ "pooler": "mean_pooler",
44
+ },
45
+ # https://huggingface.co/docs/transformers/model_doc/bert
46
+ "bert": {
47
+ "config_names": {
48
+ "context_length": "max_position_embeddings",
49
+ "vocab_size": "vocab_size",
50
+ "width": "hidden_size",
51
+ "heads": "num_attention_heads",
52
+ "layers": "num_hidden_layers",
53
+ },
54
+ "pooler": "cls_pooler",
55
+ },
56
+ }
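arch_dict only records where each HuggingFace architecture keeps its size attributes; consumers resolve them with getattr, as hf_model.py does below. A tiny self-contained sketch (the BERT-base numbers are illustrative, not something this file asserts), again assuming the repository root is importable:

    from ext.open_clip.hf_configs import arch_dict

    class FakeBertConfig:
        model_type = "bert"
        hidden_size = 768
        num_hidden_layers = 12

    names = arch_dict[FakeBertConfig.model_type]["config_names"]
    width = getattr(FakeBertConfig, names["width"])    # 768 via "hidden_size"
    layers = getattr(FakeBertConfig, names["layers"])  # 12 via "num_hidden_layers"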
ext/open_clip/hf_model.py ADDED
@@ -0,0 +1,193 @@
1
+ """ huggingface model adapter
2
+
3
+ Wraps HuggingFace transformers (https://github.com/huggingface/transformers) models for use as a text tower in a CLIP model.
4
+ """
5
+ import re
6
+
7
+ import torch
8
+ import torch.nn as nn
9
+ from torch import TensorType
10
+
11
+ try:
12
+ import transformers
13
+ from transformers import AutoModel, AutoTokenizer, AutoConfig, PretrainedConfig
14
+ from transformers.modeling_outputs import BaseModelOutput, BaseModelOutputWithPooling, \
15
+ BaseModelOutputWithPoolingAndCrossAttentions
16
+ except ImportError as e:
17
+ transformers = None
18
+
19
+
20
+ class BaseModelOutput:
21
+ pass
22
+
23
+
24
+ class PretrainedConfig:
25
+ pass
26
+
27
+ from .hf_configs import arch_dict
28
+
29
+
30
+ # utils
31
+ def _camel2snake(s):
32
+ return re.sub(r'(?<!^)(?=[A-Z])', '_', s).lower()
33
+
34
+
35
+ # TODO: ?last - for gpt-like models
36
+ _POOLERS = {}
37
+
38
+
39
+ def register_pooler(cls):
40
+ """Decorator registering pooler class"""
41
+ _POOLERS[_camel2snake(cls.__name__)] = cls
42
+ return cls
43
+
44
+
45
+ @register_pooler
46
+ class MeanPooler(nn.Module):
47
+ """Mean pooling"""
48
+
49
+ def forward(self, x: BaseModelOutput, attention_mask: TensorType):
50
+ masked_output = x.last_hidden_state * attention_mask.unsqueeze(-1)
51
+ return masked_output.sum(dim=1) / attention_mask.sum(-1, keepdim=True)
52
+
53
+
54
+ @register_pooler
55
+ class MaxPooler(nn.Module):
56
+ """Max pooling"""
57
+
58
+ def forward(self, x: BaseModelOutput, attention_mask: TensorType):
59
+ masked_output = x.last_hidden_state.masked_fill(attention_mask.unsqueeze(-1) == 0, -torch.inf)  # mask out padding positions before max-pooling
60
+ return masked_output.max(1).values
61
+
62
+
63
+ @register_pooler
64
+ class ClsPooler(nn.Module):
65
+ """CLS token pooling"""
66
+
67
+ def __init__(self, use_pooler_output=True):
68
+ super().__init__()
69
+ self.cls_token_position = 0
70
+ self.use_pooler_output = use_pooler_output
71
+
72
+ def forward(self, x: BaseModelOutput, attention_mask: TensorType):
73
+ if (self.use_pooler_output and
74
+ isinstance(x, (BaseModelOutputWithPooling, BaseModelOutputWithPoolingAndCrossAttentions)) and
75
+ (x.pooler_output is not None)
76
+ ):
77
+ return x.pooler_output
78
+
79
+ return x.last_hidden_state[:, self.cls_token_position, :]
80
+
81
+
82
+ @register_pooler
83
+ class ClsLastHiddenStatePooler(nn.Module):
84
+ """CLS token pooling
85
+ NOTE: this is equivalent to ClsPooler above with use_pooler_output=False
86
+ """
87
+
88
+ def __init__(self):
89
+ super().__init__()
90
+ self.cls_token_position = 0
91
+
92
+ def forward(self, x: BaseModelOutput, attention_mask: TensorType):
93
+ return x.last_hidden_state[:, self.cls_token_position, :]
94
+
95
+
96
+ class HFTextEncoder(nn.Module):
97
+ """HuggingFace model adapter"""
98
+ output_tokens: torch.jit.Final[bool]
99
+
100
+ def __init__(
101
+ self,
102
+ model_name_or_path: str,
103
+ output_dim: int,
104
+ config: PretrainedConfig = None,
105
+ pooler_type: str = None,
106
+ proj: str = None,
107
+ pretrained: bool = True,
108
+ output_tokens: bool = False,
109
+ ):
110
+ super().__init__()
111
+ self.output_tokens = output_tokens
112
+ self.output_dim = output_dim
113
+
114
+ # TODO: find better way to get this information
115
+ uses_transformer_pooler = (pooler_type == "cls_pooler")
116
+
117
+ if transformers is None:
118
+ raise RuntimeError("Please `pip install transformers` to use pre-trained HuggingFace models")
119
+ if config is None:
120
+ self.config = AutoConfig.from_pretrained(model_name_or_path)
121
+ create_func, model_args = (AutoModel.from_pretrained, model_name_or_path) if pretrained else (
122
+ AutoModel.from_config, self.config)
123
+ # TODO: do all model configs have this attribute? PretrainedConfig does so yes??
124
+ if hasattr(self.config, "is_encoder_decoder") and self.config.is_encoder_decoder:
125
+ self.transformer = create_func(model_args)
126
+ self.transformer = self.transformer.encoder
127
+ else:
128
+ self.transformer = create_func(model_args, add_pooling_layer=uses_transformer_pooler)
129
+ else:
130
+ self.config = config
131
+ self.transformer = AutoModel.from_config(config)
132
+ if pooler_type is None: # get default arch pooler
133
+ pooler_type = (arch_dict[self.config.model_type]["pooler"])
134
+
135
+ # FIXME downstream users of OpenCLIP models use these attr, need to verify valid across all models
136
+ self.vocab_size = getattr(self.config, 'vocab_size', 0)
137
+ self.context_length = getattr(self.config, 'max_position_embeddings', 0)
138
+
139
+ self.pooler = _POOLERS[pooler_type]()
140
+
141
+ d_model = getattr(self.config, arch_dict[self.config.model_type]["config_names"]["width"])
142
+ if (d_model == output_dim) and (proj is None): # do we always need a proj?
143
+ self.proj = nn.Identity()
144
+ elif proj == 'linear':
145
+ self.proj = nn.Linear(d_model, output_dim, bias=False)
146
+ elif proj == 'mlp':
147
+ hidden_size = (d_model + output_dim) // 2
148
+ self.proj = nn.Sequential(
149
+ nn.Linear(d_model, hidden_size, bias=False),
150
+ nn.GELU(),
151
+ nn.Linear(hidden_size, output_dim, bias=False),
152
+ )
153
+
154
+ def forward(self, x: TensorType):
155
+ attn_mask = (x != self.config.pad_token_id).long()
156
+ out = self.transformer(input_ids=x, attention_mask=attn_mask)
157
+ pooled_out = self.pooler(out, attn_mask)
158
+ projected = self.proj(pooled_out)
159
+
160
+ seq_len = out.last_hidden_state.shape[1]
161
+ tokens = (
162
+ out.last_hidden_state[:, torch.arange(seq_len) != self.pooler.cls_token_position, :]
163
+ if type(self.pooler) == ClsPooler
164
+ else out.last_hidden_state
165
+ )
166
+
167
+ if self.output_tokens:
168
+ return projected, tokens
169
+ return projected
170
+
171
+ def lock(self, unlocked_layers: int = 0, freeze_layer_norm: bool = True):
172
+ if not unlocked_layers: # full freezing
173
+ for n, p in self.transformer.named_parameters():
174
+ p.requires_grad = (not freeze_layer_norm) if "LayerNorm" in n.split(".") else False
175
+ return
176
+
177
+ encoder = self.transformer.encoder if hasattr(self.transformer, 'encoder') else self.transformer
178
+ layer_list = getattr(encoder, arch_dict[self.config.model_type]["config_names"]["layer_attr"])
179
+ print(f"Unlocking {unlocked_layers}/{len(layer_list) + 1} layers of hf model")
180
+ embeddings = getattr(
181
+ self.transformer, arch_dict[self.config.model_type]["config_names"]["token_embeddings_attr"])
182
+ modules = [embeddings, *layer_list][:-unlocked_layers]
183
+ # freeze layers
184
+ for module in modules:
185
+ for n, p in module.named_parameters():
186
+ p.requires_grad = (not freeze_layer_norm) if "LayerNorm" in n.split(".") else False
187
+
188
+ @torch.jit.ignore
189
+ def set_grad_checkpointing(self, enable=True):
190
+ self.transformer.gradient_checkpointing_enable()
191
+
192
+ def init_parameters(self):
193
+ pass
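HFTextEncoder wraps an arbitrary HuggingFace checkpoint as a CLIP text tower: it builds its own attention mask from pad_token_id, pools with one of the registered poolers, and projects to output_dim. A sketch of standalone use, assuming transformers is installed and treating 'roberta-base' purely as an example checkpoint:

    import torch
    from transformers import AutoTokenizer
    from ext.open_clip.hf_model import HFTextEncoder

    encoder = HFTextEncoder('roberta-base', output_dim=512, proj='linear',
                            pooler_type='mean_pooler', pretrained=True)
    tok = AutoTokenizer.from_pretrained('roberta-base')
    batch = tok(['a photo of a cat', 'a photo of a dog'], padding=True, return_tensors='pt')

    with torch.no_grad():
        text_features = encoder(batch['input_ids'])  # (2, 512) pooled + projected embeddings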
ext/open_clip/loss.py ADDED
@@ -0,0 +1,216 @@
1
+ import torch
2
+ import torch.nn as nn
3
+ from torch.nn import functional as F
4
+
5
+ try:
6
+ import torch.distributed.nn
7
+ from torch import distributed as dist
8
+
9
+ has_distributed = True
10
+ except ImportError:
11
+ has_distributed = False
12
+
13
+ try:
14
+ import horovod.torch as hvd
15
+ except ImportError:
16
+ hvd = None
17
+
18
+
19
+ def gather_features(
20
+ image_features,
21
+ text_features,
22
+ local_loss=False,
23
+ gather_with_grad=False,
24
+ rank=0,
25
+ world_size=1,
26
+ use_horovod=False
27
+ ):
28
+ assert has_distributed, 'torch.distributed did not import correctly, please use a PyTorch version with support.'
29
+ if use_horovod:
30
+ assert hvd is not None, 'Please install horovod'
31
+ if gather_with_grad:
32
+ all_image_features = hvd.allgather(image_features)
33
+ all_text_features = hvd.allgather(text_features)
34
+ else:
35
+ with torch.no_grad():
36
+ all_image_features = hvd.allgather(image_features)
37
+ all_text_features = hvd.allgather(text_features)
38
+ if not local_loss:
39
+ # ensure grads for local rank when all_* features don't have a gradient
40
+ gathered_image_features = list(all_image_features.chunk(world_size, dim=0))
41
+ gathered_text_features = list(all_text_features.chunk(world_size, dim=0))
42
+ gathered_image_features[rank] = image_features
43
+ gathered_text_features[rank] = text_features
44
+ all_image_features = torch.cat(gathered_image_features, dim=0)
45
+ all_text_features = torch.cat(gathered_text_features, dim=0)
46
+ else:
47
+ # We gather tensors from all gpus
48
+ if gather_with_grad:
49
+ all_image_features = torch.cat(torch.distributed.nn.all_gather(image_features), dim=0)
50
+ all_text_features = torch.cat(torch.distributed.nn.all_gather(text_features), dim=0)
51
+ else:
52
+ gathered_image_features = [torch.zeros_like(image_features) for _ in range(world_size)]
53
+ gathered_text_features = [torch.zeros_like(text_features) for _ in range(world_size)]
54
+ dist.all_gather(gathered_image_features, image_features)
55
+ dist.all_gather(gathered_text_features, text_features)
56
+ if not local_loss:
57
+ # ensure grads for local rank when all_* features don't have a gradient
58
+ gathered_image_features[rank] = image_features
59
+ gathered_text_features[rank] = text_features
60
+ all_image_features = torch.cat(gathered_image_features, dim=0)
61
+ all_text_features = torch.cat(gathered_text_features, dim=0)
62
+
63
+ return all_image_features, all_text_features
64
+
65
+
66
+ class ClipLoss(nn.Module):
67
+
68
+ def __init__(
69
+ self,
70
+ local_loss=False,
71
+ gather_with_grad=False,
72
+ cache_labels=False,
73
+ rank=0,
74
+ world_size=1,
75
+ use_horovod=False,
76
+ ):
77
+ super().__init__()
78
+ self.local_loss = local_loss
79
+ self.gather_with_grad = gather_with_grad
80
+ self.cache_labels = cache_labels
81
+ self.rank = rank
82
+ self.world_size = world_size
83
+ self.use_horovod = use_horovod
84
+
85
+ # cache state
86
+ self.prev_num_logits = 0
87
+ self.labels = {}
88
+
89
+ def get_ground_truth(self, device, num_logits) -> torch.Tensor:
90
+ # calculate ground-truth labels and cache them if enabled
91
+ if self.prev_num_logits != num_logits or device not in self.labels:
92
+ labels = torch.arange(num_logits, device=device, dtype=torch.long)
93
+ if self.world_size > 1 and self.local_loss:
94
+ labels = labels + num_logits * self.rank
95
+ if self.cache_labels:
96
+ self.labels[device] = labels
97
+ self.prev_num_logits = num_logits
98
+ else:
99
+ labels = self.labels[device]
100
+ return labels
101
+
102
+ def get_logits(self, image_features, text_features, logit_scale):
103
+ if self.world_size > 1:
104
+ all_image_features, all_text_features = gather_features(
105
+ image_features, text_features,
106
+ self.local_loss, self.gather_with_grad, self.rank, self.world_size, self.use_horovod)
107
+
108
+ if self.local_loss:
109
+ logits_per_image = logit_scale * image_features @ all_text_features.T
110
+ logits_per_text = logit_scale * text_features @ all_image_features.T
111
+ else:
112
+ logits_per_image = logit_scale * all_image_features @ all_text_features.T
113
+ logits_per_text = logits_per_image.T
114
+ else:
115
+ logits_per_image = logit_scale * image_features @ text_features.T
116
+ logits_per_text = logit_scale * text_features @ image_features.T
117
+
118
+ return logits_per_image, logits_per_text
119
+
120
+ def forward(self, image_features, text_features, logit_scale, output_dict=False):
121
+ device = image_features.device
122
+ logits_per_image, logits_per_text = self.get_logits(image_features, text_features, logit_scale)
123
+
124
+ labels = self.get_ground_truth(device, logits_per_image.shape[0])
125
+
126
+ total_loss = (
127
+ F.cross_entropy(logits_per_image, labels) +
128
+ F.cross_entropy(logits_per_text, labels)
129
+ ) / 2
130
+
131
+ return {"contrastive_loss": total_loss} if output_dict else total_loss
132
+
133
+
134
+ class CoCaLoss(ClipLoss):
135
+ def __init__(
136
+ self,
137
+ caption_loss_weight,
138
+ clip_loss_weight,
139
+ pad_id=0, # pad_token for open_clip custom tokenizer
140
+ local_loss=False,
141
+ gather_with_grad=False,
142
+ cache_labels=False,
143
+ rank=0,
144
+ world_size=1,
145
+ use_horovod=False,
146
+ ):
147
+ super().__init__(
148
+ local_loss=local_loss,
149
+ gather_with_grad=gather_with_grad,
150
+ cache_labels=cache_labels,
151
+ rank=rank,
152
+ world_size=world_size,
153
+ use_horovod=use_horovod
154
+ )
155
+
156
+ self.clip_loss_weight = clip_loss_weight
157
+ self.caption_loss_weight = caption_loss_weight
158
+ self.caption_loss = nn.CrossEntropyLoss(ignore_index=pad_id)
159
+
160
+ def forward(self, image_features, text_features, logits, labels, logit_scale, output_dict=False):
161
+
162
+ clip_loss = torch.tensor(0)
163
+
164
+ if self.clip_loss_weight:
165
+ clip_loss = super().forward(image_features, text_features, logit_scale)
166
+ clip_loss = self.clip_loss_weight * clip_loss
167
+
168
+ caption_loss = self.caption_loss(
169
+ logits.permute(0, 2, 1),
170
+ labels,
171
+ )
172
+ caption_loss = caption_loss * self.caption_loss_weight
173
+
174
+ if output_dict:
175
+ return {"contrastive_loss": clip_loss, "caption_loss": caption_loss}
176
+
177
+ return clip_loss, caption_loss
178
+
179
+
180
+ class DistillClipLoss(ClipLoss):
181
+
182
+ def dist_loss(self, teacher_logits, student_logits):
183
+ return -(teacher_logits.softmax(dim=1) * student_logits.log_softmax(dim=1)).sum(dim=1).mean(dim=0)
184
+
185
+ def forward(
186
+ self,
187
+ image_features,
188
+ text_features,
189
+ logit_scale,
190
+ dist_image_features,
191
+ dist_text_features,
192
+ dist_logit_scale,
193
+ output_dict=False,
194
+ ):
195
+ logits_per_image, logits_per_text = \
196
+ self.get_logits(image_features, text_features, logit_scale)
197
+
198
+ dist_logits_per_image, dist_logits_per_text = \
199
+ self.get_logits(dist_image_features, dist_text_features, dist_logit_scale)
200
+
201
+ labels = self.get_ground_truth(image_features.device, logits_per_image.shape[0])
202
+
203
+ contrastive_loss = (
204
+ F.cross_entropy(logits_per_image, labels) +
205
+ F.cross_entropy(logits_per_text, labels)
206
+ ) / 2
207
+
208
+ distill_loss = (
209
+ self.dist_loss(dist_logits_per_image, logits_per_image) +
210
+ self.dist_loss(dist_logits_per_text, logits_per_text)
211
+ ) / 2
212
+
213
+ if output_dict:
214
+ return {"contrastive_loss": contrastive_loss, "distill_loss": distill_loss}
215
+
216
+ return contrastive_loss, distill_loss
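ClipLoss computes the symmetric InfoNCE objective over the image/text similarity matrix; the distributed branches only change how features are gathered. A single-process sketch with random stand-ins for the normalized encoder outputs:

    import torch
    import torch.nn.functional as F
    from ext.open_clip.loss import ClipLoss

    loss_fn = ClipLoss()  # defaults: world_size=1, so no gathering
    image_features = F.normalize(torch.randn(8, 512), dim=-1)
    text_features = F.normalize(torch.randn(8, 512), dim=-1)
    logit_scale = torch.tensor(1.0 / 0.07)  # exp of the logit_scale parameter at CLIP's init

    loss = loss_fn(image_features, text_features, logit_scale)
    # cross-entropy over the 8x8 logits in both directions, averaged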
ext/open_clip/model.py ADDED
@@ -0,0 +1,473 @@
1
+ """ CLIP Model
2
+
3
+ Adapted from https://github.com/openai/CLIP. Originally MIT License, Copyright (c) 2021 OpenAI.
4
+ """
5
+ from dataclasses import dataclass
6
+ import logging
7
+ import math
8
+ from typing import Optional, Tuple, Union
9
+
10
+ import numpy as np
11
+ import torch
12
+ import torch.nn.functional as F
13
+ from torch import nn
14
+ from torch.utils.checkpoint import checkpoint
15
+
16
+ from .hf_model import HFTextEncoder
17
+ from .modified_resnet import ModifiedResNet
18
+ from .timm_model import TimmModel
19
+ from .transformer import LayerNormFp32, LayerNorm, QuickGELU, Attention, VisionTransformer, TextTransformer
20
+ from .utils import to_2tuple
21
+
22
+
23
+ @dataclass
24
+ class CLIPVisionCfg:
25
+ layers: Union[Tuple[int, int, int, int], int] = 12
26
+ width: int = 768
27
+ head_width: int = 64
28
+ mlp_ratio: float = 4.0
29
+ patch_size: int = 16
30
+ image_size: Union[Tuple[int, int], int] = 224
31
+
32
+ ls_init_value: Optional[float] = None # layer scale initial value
33
+ patch_dropout: float = 0. # what fraction of patches to dropout during training (0 would mean disabled and no patches dropped) - 0.5 to 0.75 recommended in the paper for optimal results
34
+ input_patchnorm: bool = False # whether to use dual patchnorm - would only apply the input layernorm on each patch, as post-layernorm already exist in original clip vit design
35
+ global_average_pool: bool = False # whether to global average pool the last embedding layer, instead of using CLS token (https://arxiv.org/abs/2205.01580)
36
+ attentional_pool: bool = False # whether to use attentional pooler in the last embedding layer
37
+ n_queries: int = 256 # n_queries for attentional pooler
38
+ attn_pooler_heads: int = 8 # n heads for attentional_pooling
39
+ output_tokens: bool = False
40
+
41
+ timm_model_name: str = None # a valid model name overrides layers, width, patch_size
42
+ timm_model_pretrained: bool = False # use (imagenet) pretrained weights for named model
43
+ timm_pool: str = 'avg' # feature pooling for timm model ('abs_attn', 'rot_attn', 'avg', '')
44
+ timm_proj: str = 'linear' # linear projection for timm model output ('linear', 'mlp', '')
45
+ timm_proj_bias: bool = False # enable bias final projection
46
+ timm_drop: float = 0. # head dropout
47
+ timm_drop_path: Optional[float] = None # backbone stochastic depth
48
+
49
+
50
+ @dataclass
51
+ class CLIPTextCfg:
52
+ context_length: int = 77
53
+ vocab_size: int = 49408
54
+ width: int = 512
55
+ heads: int = 8
56
+ layers: int = 12
57
+ ls_init_value: Optional[float] = None # layer scale initial value
58
+ hf_model_name: str = None
59
+ hf_tokenizer_name: str = None
60
+ hf_model_pretrained: bool = True
61
+ proj: str = 'mlp'
62
+ pooler_type: str = 'mean_pooler'
63
+ embed_cls: bool = False
64
+ pad_id: int = 0
65
+ output_tokens: bool = False
66
+
67
+
68
+ def get_cast_dtype(precision: str):
69
+ cast_dtype = None
70
+ if precision == 'bf16':
71
+ cast_dtype = torch.bfloat16
72
+ elif precision == 'fp16':
73
+ cast_dtype = torch.float16
74
+ return cast_dtype
75
+
76
+
77
+ def get_input_dtype(precision: str):
78
+ input_dtype = None
79
+ if precision in ('bf16', 'pure_bf16'):
80
+ input_dtype = torch.bfloat16
81
+ elif precision in ('fp16', 'pure_fp16'):
82
+ input_dtype = torch.float16
83
+ return input_dtype
84
+
85
+
86
+ def _build_vision_tower(
87
+ embed_dim: int,
88
+ vision_cfg: CLIPVisionCfg,
89
+ quick_gelu: bool = False,
90
+ cast_dtype: Optional[torch.dtype] = None
91
+ ):
92
+ if isinstance(vision_cfg, dict):
93
+ vision_cfg = CLIPVisionCfg(**vision_cfg)
94
+
95
+ # OpenAI models are pretrained w/ QuickGELU but native nn.GELU is both faster and more
96
+ # memory efficient in recent PyTorch releases (>= 1.10).
97
+ # NOTE: timm models always use native GELU regardless of quick_gelu flag.
98
+ act_layer = QuickGELU if quick_gelu else nn.GELU
99
+
100
+ if vision_cfg.timm_model_name:
101
+ visual = TimmModel(
102
+ vision_cfg.timm_model_name,
103
+ pretrained=vision_cfg.timm_model_pretrained,
104
+ pool=vision_cfg.timm_pool,
105
+ proj=vision_cfg.timm_proj,
106
+ proj_bias=vision_cfg.timm_proj_bias,
107
+ drop=vision_cfg.timm_drop,
108
+ drop_path=vision_cfg.timm_drop_path,
109
+ patch_drop=vision_cfg.patch_dropout if vision_cfg.patch_dropout > 0 else None,
110
+ embed_dim=embed_dim,
111
+ image_size=vision_cfg.image_size,
112
+ )
113
+ elif isinstance(vision_cfg.layers, (tuple, list)):
114
+ vision_heads = vision_cfg.width * 32 // vision_cfg.head_width
115
+ visual = ModifiedResNet(
116
+ layers=vision_cfg.layers,
117
+ output_dim=embed_dim,
118
+ heads=vision_heads,
119
+ image_size=vision_cfg.image_size,
120
+ width=vision_cfg.width,
121
+ )
122
+ else:
123
+ vision_heads = vision_cfg.width // vision_cfg.head_width
124
+ norm_layer = LayerNormFp32 if cast_dtype in (torch.float16, torch.bfloat16) else LayerNorm
125
+ visual = VisionTransformer(
126
+ image_size=vision_cfg.image_size,
127
+ patch_size=vision_cfg.patch_size,
128
+ width=vision_cfg.width,
129
+ layers=vision_cfg.layers,
130
+ heads=vision_heads,
131
+ mlp_ratio=vision_cfg.mlp_ratio,
132
+ ls_init_value=vision_cfg.ls_init_value,
133
+ patch_dropout=vision_cfg.patch_dropout,
134
+ input_patchnorm=vision_cfg.input_patchnorm,
135
+ global_average_pool=vision_cfg.global_average_pool,
136
+ attentional_pool=vision_cfg.attentional_pool,
137
+ n_queries=vision_cfg.n_queries,
138
+ attn_pooler_heads=vision_cfg.attn_pooler_heads,
139
+ output_tokens=vision_cfg.output_tokens,
140
+ output_dim=embed_dim,
141
+ act_layer=act_layer,
142
+ norm_layer=norm_layer,
143
+ )
144
+
145
+ return visual
146
+
147
+
148
+ def _build_text_tower(
149
+ embed_dim: int,
150
+ text_cfg: CLIPTextCfg,
151
+ quick_gelu: bool = False,
152
+ cast_dtype: Optional[torch.dtype] = None,
153
+ ):
154
+ if isinstance(text_cfg, dict):
155
+ text_cfg = CLIPTextCfg(**text_cfg)
156
+
157
+ if text_cfg.hf_model_name:
158
+ text = HFTextEncoder(
159
+ text_cfg.hf_model_name,
160
+ output_dim=embed_dim,
161
+ proj=text_cfg.proj,
162
+ pooler_type=text_cfg.pooler_type,
163
+ pretrained=text_cfg.hf_model_pretrained,
164
+ output_tokens=text_cfg.output_tokens,
165
+ )
166
+ else:
167
+ act_layer = QuickGELU if quick_gelu else nn.GELU
168
+ norm_layer = LayerNormFp32 if cast_dtype in (torch.float16, torch.bfloat16) else LayerNorm
169
+
170
+ text = TextTransformer(
171
+ context_length=text_cfg.context_length,
172
+ vocab_size=text_cfg.vocab_size,
173
+ width=text_cfg.width,
174
+ heads=text_cfg.heads,
175
+ layers=text_cfg.layers,
176
+ ls_init_value=text_cfg.ls_init_value,
177
+ output_dim=embed_dim,
178
+ embed_cls=text_cfg.embed_cls,
179
+ output_tokens=text_cfg.output_tokens,
180
+ pad_id=text_cfg.pad_id,
181
+ act_layer=act_layer,
182
+ norm_layer=norm_layer,
183
+ )
184
+ return text
185
+
186
+
187
+ class CLIP(nn.Module):
188
+ output_dict: torch.jit.Final[bool]
189
+
190
+ def __init__(
191
+ self,
192
+ embed_dim: int,
193
+ vision_cfg: CLIPVisionCfg,
194
+ text_cfg: CLIPTextCfg,
195
+ quick_gelu: bool = False,
196
+ cast_dtype: Optional[torch.dtype] = None,
197
+ output_dict: bool = False,
198
+ ):
199
+ super().__init__()
200
+ self.output_dict = output_dict
201
+ self.visual = _build_vision_tower(embed_dim, vision_cfg, quick_gelu, cast_dtype)
202
+
203
+ text = _build_text_tower(embed_dim, text_cfg, quick_gelu, cast_dtype)
204
+ self.transformer = text.transformer
205
+ self.context_length = text.context_length
206
+ self.vocab_size = text.vocab_size
207
+ self.token_embedding = text.token_embedding
208
+ self.positional_embedding = text.positional_embedding
209
+ self.ln_final = text.ln_final
210
+ self.text_projection = text.text_projection
211
+ self.register_buffer('attn_mask', text.attn_mask, persistent=False)
212
+
213
+ self.logit_scale = nn.Parameter(torch.ones([]) * np.log(1 / 0.07))
214
+
215
+ def lock_image_tower(self, unlocked_groups=0, freeze_bn_stats=False):
216
+ # lock image tower as per LiT - https://arxiv.org/abs/2111.07991
217
+ self.visual.lock(unlocked_groups=unlocked_groups, freeze_bn_stats=freeze_bn_stats)
218
+
219
+ @torch.jit.ignore
220
+ def set_grad_checkpointing(self, enable=True):
221
+ self.visual.set_grad_checkpointing(enable)
222
+ self.transformer.grad_checkpointing = enable
223
+
224
+ def encode_image(self, image, normalize: bool = False):
225
+ features = self.visual(image)
226
+ return F.normalize(features, dim=-1) if normalize else features
227
+
228
+ def encode_text(self, text, normalize: bool = False):
229
+ cast_dtype = self.transformer.get_cast_dtype()
230
+
231
+ x = self.token_embedding(text).to(cast_dtype) # [batch_size, n_ctx, d_model]
232
+
233
+ x = x + self.positional_embedding.to(cast_dtype)
234
+ x = x.permute(1, 0, 2) # NLD -> LND
235
+ x = self.transformer(x, attn_mask=self.attn_mask)
236
+ x = x.permute(1, 0, 2) # LND -> NLD
237
+ x = self.ln_final(x) # [batch_size, n_ctx, transformer.width]
238
+ # take features from the eot embedding (eot_token is the highest number in each sequence)
239
+ x = x[torch.arange(x.shape[0]), text.argmax(dim=-1)] @ self.text_projection
240
+ return F.normalize(x, dim=-1) if normalize else x
241
+
242
+ def forward(
243
+ self,
244
+ image: Optional[torch.Tensor] = None,
245
+ text: Optional[torch.Tensor] = None,
246
+ ):
247
+ image_features = self.encode_image(image, normalize=True) if image is not None else None
248
+ text_features = self.encode_text(text, normalize=True) if text is not None else None
249
+ if self.output_dict:
250
+ return {
251
+ "image_features": image_features,
252
+ "text_features": text_features,
253
+ "logit_scale": self.logit_scale.exp()
254
+ }
255
+ return image_features, text_features, self.logit_scale.exp()
256
+
257
+
258
+ class CustomTextCLIP(nn.Module):
259
+ output_dict: torch.jit.Final[bool]
260
+
261
+ def __init__(
262
+ self,
263
+ embed_dim: int,
264
+ vision_cfg: CLIPVisionCfg,
265
+ text_cfg: CLIPTextCfg,
266
+ quick_gelu: bool = False,
267
+ cast_dtype: Optional[torch.dtype] = None,
268
+ output_dict: bool = False,
269
+ ):
270
+ super().__init__()
271
+ self.output_dict = output_dict
272
+ self.visual = _build_vision_tower(embed_dim, vision_cfg, quick_gelu, cast_dtype)
273
+ self.text = _build_text_tower(embed_dim, text_cfg, quick_gelu, cast_dtype)
274
+ self.context_length = self.text.context_length
275
+ self.vocab_size = self.text.vocab_size
276
+ self.logit_scale = nn.Parameter(torch.ones([]) * np.log(1 / 0.07))
277
+
278
+ def lock_image_tower(self, unlocked_groups=0, freeze_bn_stats=False):
279
+ # lock image tower as per LiT - https://arxiv.org/abs/2111.07991
280
+ self.visual.lock(unlocked_groups=unlocked_groups, freeze_bn_stats=freeze_bn_stats)
281
+
282
+ def lock_text_tower(self, unlocked_layers: int = 0, freeze_layer_norm: bool = True):
283
+ self.text.lock(unlocked_layers, freeze_layer_norm)
284
+
285
+ @torch.jit.ignore
286
+ def set_grad_checkpointing(self, enable=True):
287
+ self.visual.set_grad_checkpointing(enable)
288
+ self.text.set_grad_checkpointing(enable)
289
+
290
+ def encode_image(self, image, normalize: bool = False):
291
+ features = self.visual(image)
292
+ return F.normalize(features, dim=-1) if normalize else features
293
+
294
+ def encode_text(self, text, normalize: bool = False):
295
+ features = self.text(text)
296
+ return F.normalize(features, dim=-1) if normalize else features
297
+
298
+ def forward(
299
+ self,
300
+ image: Optional[torch.Tensor] = None,
301
+ text: Optional[torch.Tensor] = None,
302
+ ):
303
+ image_features = self.encode_image(image, normalize=True) if image is not None else None
304
+ text_features = self.encode_text(text, normalize=True) if text is not None else None
305
+ if self.output_dict:
306
+ return {
307
+ "image_features": image_features,
308
+ "text_features": text_features,
309
+ "logit_scale": self.logit_scale.exp()
310
+ }
311
+ return image_features, text_features, self.logit_scale.exp()
312
+
313
+
314
+ def convert_weights_to_lp(model: nn.Module, dtype=torch.float16):
315
+ """Convert applicable model parameters to low-precision (bf16 or fp16)"""
316
+
317
+ def _convert_weights(l):
318
+ if isinstance(l, (nn.Conv1d, nn.Conv2d, nn.Linear)):
319
+ l.weight.data = l.weight.data.to(dtype)
320
+ if l.bias is not None:
321
+ l.bias.data = l.bias.data.to(dtype)
322
+
323
+ if isinstance(l, (nn.MultiheadAttention, Attention)):
324
+ for attr in [*[f"{s}_proj_weight" for s in ["in", "q", "k", "v"]], "in_proj_bias", "bias_k", "bias_v"]:
325
+ tensor = getattr(l, attr)
326
+ if tensor is not None:
327
+ tensor.data = tensor.data.to(dtype)
328
+
329
+ if isinstance(l, (CLIP, TextTransformer)):
330
+ # convert text nn.Parameter projections
331
+ attr = getattr(l, "text_projection", None)
332
+ if attr is not None:
333
+ attr.data = attr.data.to(dtype)
334
+
335
+ if isinstance(l, VisionTransformer):
336
+ # convert vision nn.Parameter projections
337
+ attr = getattr(l, "proj", None)
338
+ if attr is not None:
339
+ attr.data = attr.data.to(dtype)
340
+
341
+ model.apply(_convert_weights)
342
+
343
+
344
+ convert_weights_to_fp16 = convert_weights_to_lp # backwards compat
345
+
346
+
347
+ # used to maintain checkpoint compatibility
348
+ def convert_to_custom_text_state_dict(state_dict: dict):
349
+ if 'text_projection' in state_dict:
350
+ # old format state_dict, move text tower -> .text
351
+ new_state_dict = {}
352
+ for k, v in state_dict.items():
353
+ if any(k.startswith(p) for p in (
354
+ 'text_projection',
355
+ 'positional_embedding',
356
+ 'token_embedding',
357
+ 'transformer',
358
+ 'ln_final',
359
+ )):
360
+ k = 'text.' + k
361
+ new_state_dict[k] = v
362
+ return new_state_dict
363
+ return state_dict
364
+
365
+
366
+ def build_model_from_openai_state_dict(
367
+ state_dict: dict,
368
+ quick_gelu=True,
369
+ cast_dtype=torch.float16,
370
+ ):
371
+ vit = "visual.proj" in state_dict
372
+
373
+ if vit:
374
+ vision_width = state_dict["visual.conv1.weight"].shape[0]
375
+ vision_layers = len(
376
+ [k for k in state_dict.keys() if k.startswith("visual.") and k.endswith(".attn.in_proj_weight")])
377
+ vision_patch_size = state_dict["visual.conv1.weight"].shape[-1]
378
+ grid_size = round((state_dict["visual.positional_embedding"].shape[0] - 1) ** 0.5)
379
+ image_size = vision_patch_size * grid_size
380
+ else:
381
+ counts: list = [
382
+ len(set(k.split(".")[2] for k in state_dict if k.startswith(f"visual.layer{b}"))) for b in [1, 2, 3, 4]]
383
+ vision_layers = tuple(counts)
384
+ vision_width = state_dict["visual.layer1.0.conv1.weight"].shape[0]
385
+ output_width = round((state_dict["visual.attnpool.positional_embedding"].shape[0] - 1) ** 0.5)
386
+ vision_patch_size = None
387
+ assert output_width ** 2 + 1 == state_dict["visual.attnpool.positional_embedding"].shape[0]
388
+ image_size = output_width * 32
389
+
390
+ embed_dim = state_dict["text_projection"].shape[1]
391
+ context_length = state_dict["positional_embedding"].shape[0]
392
+ vocab_size = state_dict["token_embedding.weight"].shape[0]
393
+ transformer_width = state_dict["ln_final.weight"].shape[0]
394
+ transformer_heads = transformer_width // 64
395
+ transformer_layers = len(set(k.split(".")[2] for k in state_dict if k.startswith("transformer.resblocks")))
396
+
397
+ vision_cfg = CLIPVisionCfg(
398
+ layers=vision_layers,
399
+ width=vision_width,
400
+ patch_size=vision_patch_size,
401
+ image_size=image_size,
402
+ )
403
+ text_cfg = CLIPTextCfg(
404
+ context_length=context_length,
405
+ vocab_size=vocab_size,
406
+ width=transformer_width,
407
+ heads=transformer_heads,
408
+ layers=transformer_layers,
409
+ )
410
+ model = CLIP(
411
+ embed_dim,
412
+ vision_cfg=vision_cfg,
413
+ text_cfg=text_cfg,
414
+ quick_gelu=quick_gelu, # OpenAI models were trained with QuickGELU
415
+ cast_dtype=cast_dtype,
416
+ )
417
+
418
+ for key in ["input_resolution", "context_length", "vocab_size"]:
419
+ state_dict.pop(key, None)
420
+
421
+ convert_weights_to_fp16(model) # OpenAI state dicts are partially converted to float16
422
+ model.load_state_dict(state_dict)
423
+ return model.eval()
424
+
425
+
426
+ def trace_model(model, batch_size=256, device=torch.device('cpu')):
427
+ model.eval()
428
+ image_size = model.visual.image_size
429
+ example_images = torch.ones((batch_size, 3, image_size, image_size), device=device)
430
+ example_text = torch.zeros((batch_size, model.context_length), dtype=torch.int, device=device)
431
+ model = torch.jit.trace_module(
432
+ model,
433
+ inputs=dict(
434
+ forward=(example_images, example_text),
435
+ encode_text=(example_text,),
436
+ encode_image=(example_images,)
437
+ ))
438
+ model.visual.image_size = image_size
439
+ return model
440
+
441
+
442
+ def resize_pos_embed(state_dict, model, interpolation: str = 'bicubic', antialias: bool = True):
443
+ # Rescale the grid of position embeddings when loading from state_dict
444
+ old_pos_embed = state_dict.get('visual.positional_embedding', None)
445
+ if old_pos_embed is None or not hasattr(model.visual, 'grid_size'):
446
+ return
447
+ grid_size = to_2tuple(model.visual.grid_size)
448
+ extra_tokens = 1 # FIXME detect different token configs (ie no class token, or more)
449
+ new_seq_len = grid_size[0] * grid_size[1] + extra_tokens
450
+ if new_seq_len == old_pos_embed.shape[0]:
451
+ return
452
+
453
+ if extra_tokens:
454
+ pos_emb_tok, pos_emb_img = old_pos_embed[:extra_tokens], old_pos_embed[extra_tokens:]
455
+ else:
456
+ pos_emb_tok, pos_emb_img = None, old_pos_embed
457
+ old_grid_size = to_2tuple(int(math.sqrt(len(pos_emb_img))))
458
+
459
+ logging.info('Resizing position embedding grid-size from %s to %s', old_grid_size, grid_size)
460
+ pos_emb_img = pos_emb_img.reshape(1, old_grid_size[0], old_grid_size[1], -1).permute(0, 3, 1, 2)
461
+ pos_emb_img = F.interpolate(
462
+ pos_emb_img,
463
+ size=grid_size,
464
+ mode=interpolation,
465
+ antialias=antialias,
466
+ align_corners=False,
467
+ )
468
+ pos_emb_img = pos_emb_img.permute(0, 2, 3, 1).reshape(1, grid_size[0] * grid_size[1], -1)[0]
469
+ if pos_emb_tok is not None:
470
+ new_pos_embed = torch.cat([pos_emb_tok, pos_emb_img], dim=0)
471
+ else:
472
+ new_pos_embed = pos_emb_img
473
+ state_dict['visual.positional_embedding'] = new_pos_embed
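The CLIP class above can be instantiated directly from the two config dataclasses, which is what create_model does after reading a JSON config. A small sketch with deliberately tiny, illustrative sizes; the random token ids only exercise the shapes (a real tokenizer places the EOT token as the largest id, which encode_text relies on):

    import torch
    from ext.open_clip.model import CLIP, CLIPVisionCfg, CLIPTextCfg

    vision_cfg = CLIPVisionCfg(layers=2, width=64, head_width=32, patch_size=16, image_size=64)
    text_cfg = CLIPTextCfg(context_length=16, vocab_size=1000, width=64, heads=2, layers=2)
    model = CLIP(embed_dim=32, vision_cfg=vision_cfg, text_cfg=text_cfg)

    images = torch.randn(2, 3, 64, 64)
    texts = torch.randint(0, 1000, (2, 16))
    with torch.no_grad():
        img_f = model.encode_image(images, normalize=True)  # (2, 32)
        txt_f = model.encode_text(texts, normalize=True)    # (2, 32)
        logits = model.logit_scale.exp() * img_f @ txt_f.T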
ext/open_clip/model_configs/EVA01-g-14-plus.json ADDED
@@ -0,0 +1,18 @@
1
+ {
2
+ "embed_dim": 1024,
3
+ "vision_cfg": {
4
+ "image_size": 224,
5
+ "timm_model_name": "eva_giant_patch14_224",
6
+ "timm_model_pretrained": false,
7
+ "timm_pool": "token",
8
+ "timm_proj": null
9
+ },
10
+ "text_cfg": {
11
+ "context_length": 77,
12
+ "vocab_size": 49408,
13
+ "width": 1024,
14
+ "heads": 16,
15
+ "layers": 24
16
+ },
17
+ "custom_text": true
18
+ }
ext/open_clip/model_configs/EVA01-g-14.json ADDED
@@ -0,0 +1,18 @@
1
+ {
2
+ "embed_dim": 1024,
3
+ "vision_cfg": {
4
+ "image_size": 224,
5
+ "timm_model_name": "eva_giant_patch14_224",
6
+ "timm_model_pretrained": false,
7
+ "timm_pool": "token",
8
+ "timm_proj": null
9
+ },
10
+ "text_cfg": {
11
+ "context_length": 77,
12
+ "vocab_size": 49408,
13
+ "width": 768,
14
+ "heads": 12,
15
+ "layers": 12
16
+ },
17
+ "custom_text": true
18
+ }
ext/open_clip/model_configs/EVA02-B-16.json ADDED
@@ -0,0 +1,18 @@
1
+ {
2
+ "embed_dim": 512,
3
+ "vision_cfg": {
4
+ "image_size": 224,
5
+ "timm_model_name": "eva02_base_patch16_clip_224",
6
+ "timm_model_pretrained": false,
7
+ "timm_pool": "token",
8
+ "timm_proj": null
9
+ },
10
+ "text_cfg": {
11
+ "context_length": 77,
12
+ "vocab_size": 49408,
13
+ "width": 512,
14
+ "heads": 8,
15
+ "layers": 12
16
+ },
17
+ "custom_text": true
18
+ }
ext/open_clip/model_configs/EVA02-E-14-plus.json ADDED
@@ -0,0 +1,18 @@
1
+ {
2
+ "embed_dim": 1024,
3
+ "vision_cfg": {
4
+ "image_size": 224,
5
+ "timm_model_name": "eva02_enormous_patch14_clip_224",
6
+ "timm_model_pretrained": false,
7
+ "timm_pool": "token",
8
+ "timm_proj": null
9
+ },
10
+ "text_cfg": {
11
+ "context_length": 77,
12
+ "vocab_size": 49408,
13
+ "width": 1280,
14
+ "heads": 20,
15
+ "layers": 32
16
+ },
17
+ "custom_text": true
18
+ }
ext/open_clip/model_configs/EVA02-E-14.json ADDED
@@ -0,0 +1,18 @@
1
+ {
2
+ "embed_dim": 1024,
3
+ "vision_cfg": {
4
+ "image_size": 224,
5
+ "timm_model_name": "eva02_enormous_patch14_clip_224",
6
+ "timm_model_pretrained": false,
7
+ "timm_pool": "token",
8
+ "timm_proj": null
9
+ },
10
+ "text_cfg": {
11
+ "context_length": 77,
12
+ "vocab_size": 49408,
13
+ "width": 1024,
14
+ "heads": 16,
15
+ "layers": 24
16
+ },
17
+ "custom_text": true
18
+ }
ext/open_clip/model_configs/EVA02-L-14-336.json ADDED
@@ -0,0 +1,18 @@
1
+ {
2
+ "embed_dim": 768,
3
+ "vision_cfg": {
4
+ "image_size": 336,
5
+ "timm_model_name": "eva02_large_patch14_clip_336",
6
+ "timm_model_pretrained": false,
7
+ "timm_pool": "token",
8
+ "timm_proj": null
9
+ },
10
+ "text_cfg": {
11
+ "context_length": 77,
12
+ "vocab_size": 49408,
13
+ "width": 768,
14
+ "heads": 12,
15
+ "layers": 12
16
+ },
17
+ "custom_text": true
18
+ }
ext/open_clip/model_configs/EVA02-L-14.json ADDED
@@ -0,0 +1,18 @@
1
+ {
2
+ "embed_dim": 768,
3
+ "vision_cfg": {
4
+ "image_size": 224,
5
+ "timm_model_name": "eva02_large_patch14_clip_224",
6
+ "timm_model_pretrained": false,
7
+ "timm_pool": "token",
8
+ "timm_proj": null
9
+ },
10
+ "text_cfg": {
11
+ "context_length": 77,
12
+ "vocab_size": 49408,
13
+ "width": 768,
14
+ "heads": 12,
15
+ "layers": 12
16
+ },
17
+ "custom_text": true
18
+ }
ext/open_clip/model_configs/RN101-quickgelu.json ADDED
@@ -0,0 +1,22 @@
1
+ {
2
+ "embed_dim": 512,
3
+ "quick_gelu": true,
4
+ "vision_cfg": {
5
+ "image_size": 224,
6
+ "layers": [
7
+ 3,
8
+ 4,
9
+ 23,
10
+ 3
11
+ ],
12
+ "width": 64,
13
+ "patch_size": null
14
+ },
15
+ "text_cfg": {
16
+ "context_length": 77,
17
+ "vocab_size": 49408,
18
+ "width": 512,
19
+ "heads": 8,
20
+ "layers": 12
21
+ }
22
+ }
ext/open_clip/model_configs/RN101.json ADDED
@@ -0,0 +1,21 @@
1
+ {
2
+ "embed_dim": 512,
3
+ "vision_cfg": {
4
+ "image_size": 224,
5
+ "layers": [
6
+ 3,
7
+ 4,
8
+ 23,
9
+ 3
10
+ ],
11
+ "width": 64,
12
+ "patch_size": null
13
+ },
14
+ "text_cfg": {
15
+ "context_length": 77,
16
+ "vocab_size": 49408,
17
+ "width": 512,
18
+ "heads": 8,
19
+ "layers": 12
20
+ }
21
+ }
ext/open_clip/model_configs/RN50-quickgelu.json ADDED
@@ -0,0 +1,22 @@
1
+ {
2
+ "embed_dim": 1024,
3
+ "quick_gelu": true,
4
+ "vision_cfg": {
5
+ "image_size": 224,
6
+ "layers": [
7
+ 3,
8
+ 4,
9
+ 6,
10
+ 3
11
+ ],
12
+ "width": 64,
13
+ "patch_size": null
14
+ },
15
+ "text_cfg": {
16
+ "context_length": 77,
17
+ "vocab_size": 49408,
18
+ "width": 512,
19
+ "heads": 8,
20
+ "layers": 12
21
+ }
22
+ }
ext/open_clip/model_configs/RN50.json ADDED
@@ -0,0 +1,21 @@
1
+ {
2
+ "embed_dim": 1024,
3
+ "vision_cfg": {
4
+ "image_size": 224,
5
+ "layers": [
6
+ 3,
7
+ 4,
8
+ 6,
9
+ 3
10
+ ],
11
+ "width": 64,
12
+ "patch_size": null
13
+ },
14
+ "text_cfg": {
15
+ "context_length": 77,
16
+ "vocab_size": 49408,
17
+ "width": 512,
18
+ "heads": 8,
19
+ "layers": 12
20
+ }
21
+ }
ext/open_clip/model_configs/RN50x16.json ADDED
@@ -0,0 +1,21 @@
1
+ {
2
+ "embed_dim": 768,
3
+ "vision_cfg": {
4
+ "image_size": 384,
5
+ "layers": [
6
+ 6,
7
+ 8,
8
+ 18,
9
+ 8
10
+ ],
11
+ "width": 96,
12
+ "patch_size": null
13
+ },
14
+ "text_cfg": {
15
+ "context_length": 77,
16
+ "vocab_size": 49408,
17
+ "width": 768,
18
+ "heads": 12,
19
+ "layers": 12
20
+ }
21
+ }
ext/open_clip/model_configs/RN50x4.json ADDED
@@ -0,0 +1,21 @@
1
+ {
2
+ "embed_dim": 640,
3
+ "vision_cfg": {
4
+ "image_size": 288,
5
+ "layers": [
6
+ 4,
7
+ 6,
8
+ 10,
9
+ 6
10
+ ],
11
+ "width": 80,
12
+ "patch_size": null
13
+ },
14
+ "text_cfg": {
15
+ "context_length": 77,
16
+ "vocab_size": 49408,
17
+ "width": 640,
18
+ "heads": 10,
19
+ "layers": 12
20
+ }
21
+ }
ext/open_clip/model_configs/RN50x64.json ADDED
@@ -0,0 +1,21 @@
1
+ {
2
+ "embed_dim": 1024,
3
+ "vision_cfg": {
4
+ "image_size": 448,
5
+ "layers": [
6
+ 3,
7
+ 15,
8
+ 36,
9
+ 10
10
+ ],
11
+ "width": 128,
12
+ "patch_size": null
13
+ },
14
+ "text_cfg": {
15
+ "context_length": 77,
16
+ "vocab_size": 49408,
17
+ "width": 1024,
18
+ "heads": 16,
19
+ "layers": 12
20
+ }
21
+ }
ext/open_clip/model_configs/ViT-B-16-plus-240.json ADDED
@@ -0,0 +1,16 @@
1
+ {
2
+ "embed_dim": 640,
3
+ "vision_cfg": {
4
+ "image_size": 240,
5
+ "layers": 12,
6
+ "width": 896,
7
+ "patch_size": 16
8
+ },
9
+ "text_cfg": {
10
+ "context_length": 77,
11
+ "vocab_size": 49408,
12
+ "width": 640,
13
+ "heads": 10,
14
+ "layers": 12
15
+ }
16
+ }
ext/open_clip/model_configs/ViT-B-16-plus.json ADDED
@@ -0,0 +1,16 @@
1
+ {
2
+ "embed_dim": 640,
3
+ "vision_cfg": {
4
+ "image_size": 224,
5
+ "layers": 12,
6
+ "width": 896,
7
+ "patch_size": 16
8
+ },
9
+ "text_cfg": {
10
+ "context_length": 77,
11
+ "vocab_size": 49408,
12
+ "width": 640,
13
+ "heads": 10,
14
+ "layers": 12
15
+ }
16
+ }
ext/open_clip/model_configs/ViT-B-16.json ADDED
@@ -0,0 +1,16 @@
1
+ {
2
+ "embed_dim": 512,
3
+ "vision_cfg": {
4
+ "image_size": 224,
5
+ "layers": 12,
6
+ "width": 768,
7
+ "patch_size": 16
8
+ },
9
+ "text_cfg": {
10
+ "context_length": 77,
11
+ "vocab_size": 49408,
12
+ "width": 512,
13
+ "heads": 8,
14
+ "layers": 12
15
+ }
16
+ }
ext/open_clip/model_configs/ViT-B-32-plus-256.json ADDED
@@ -0,0 +1,16 @@
1
+ {
2
+ "embed_dim": 640,
3
+ "vision_cfg": {
4
+ "image_size": 256,
5
+ "layers": 12,
6
+ "width": 896,
7
+ "patch_size": 32
8
+ },
9
+ "text_cfg": {
10
+ "context_length": 77,
11
+ "vocab_size": 49408,
12
+ "width": 640,
13
+ "heads": 10,
14
+ "layers": 12
15
+ }
16
+ }
ext/open_clip/model_configs/ViT-B-32-quickgelu.json ADDED
@@ -0,0 +1,17 @@
1
+ {
2
+ "embed_dim": 512,
3
+ "quick_gelu": true,
4
+ "vision_cfg": {
5
+ "image_size": 224,
6
+ "layers": 12,
7
+ "width": 768,
8
+ "patch_size": 32
9
+ },
10
+ "text_cfg": {
11
+ "context_length": 77,
12
+ "vocab_size": 49408,
13
+ "width": 512,
14
+ "heads": 8,
15
+ "layers": 12
16
+ }
17
+ }
ext/open_clip/model_configs/ViT-B-32.json ADDED
@@ -0,0 +1,16 @@
1
+ {
2
+ "embed_dim": 512,
3
+ "vision_cfg": {
4
+ "image_size": 224,
5
+ "layers": 12,
6
+ "width": 768,
7
+ "patch_size": 32
8
+ },
9
+ "text_cfg": {
10
+ "context_length": 77,
11
+ "vocab_size": 49408,
12
+ "width": 512,
13
+ "heads": 8,
14
+ "layers": 12
15
+ }
16
+ }
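Each JSON under model_configs/ is indexed by its file stem in _rescan_model_configs above, so the file name is exactly the model name the factory accepts. A tiny lookup sketch, assuming the repository root is on PYTHONPATH:

    from ext.open_clip.factory import get_model_config, list_models

    cfg = get_model_config('ViT-B-32')
    assert cfg['embed_dim'] == 512
    assert cfg['vision_cfg']['patch_size'] == 32
    print(len(list_models()), 'registered model configs')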
ext/open_clip/model_configs/ViT-H-14.json ADDED
@@ -0,0 +1,17 @@
1
+ {
2
+ "embed_dim": 1024,
3
+ "vision_cfg": {
4
+ "image_size": 224,
5
+ "layers": 32,
6
+ "width": 1280,
7
+ "head_width": 80,
8
+ "patch_size": 14
9
+ },
10
+ "text_cfg": {
11
+ "context_length": 77,
12
+ "vocab_size": 49408,
13
+ "width": 1024,
14
+ "heads": 16,
15
+ "layers": 24
16
+ }
17
+ }
ext/open_clip/model_configs/ViT-H-16.json ADDED
@@ -0,0 +1,17 @@
1
+ {
2
+ "embed_dim": 1024,
3
+ "vision_cfg": {
4
+ "image_size": 224,
5
+ "layers": 32,
6
+ "width": 1280,
7
+ "head_width": 80,
8
+ "patch_size": 16
9
+ },
10
+ "text_cfg": {
11
+ "context_length": 77,
12
+ "vocab_size": 49408,
13
+ "width": 1024,
14
+ "heads": 16,
15
+ "layers": 24
16
+ }
17
+ }