HarborYuan
commited on
Commit
•
502989e
1
Parent(s):
a209a56
add rap_sam
Browse filesThis view is limited to 50 files because it contains too many changes.
See raw diff
- .gitattributes +5 -4
- README.md +3 -5
- app/configs/rap_sam_r50_12e_adaptor.py +88 -0
- app/models/detectors/__init__.py +1 -0
- app/models/detectors/mask2former_vid.py +281 -0
- app/models/detectors/rapsam.py +66 -0
- app/models/heads/__init__.py +1 -0
- app/models/heads/mask2former_vid.py +616 -0
- app/models/heads/rapsam_head.py +227 -0
- app/models/heads/yoso_head.py +531 -0
- app/models/necks/__init__.py +1 -0
- app/models/necks/ramsam_neck.py +196 -0
- app/models/utils/__init__.py +3 -0
- app/models/utils/load_checkpoint.py +38 -0
- app/models/utils/mask_pool.py +27 -0
- app/models/utils/no_obj.py +1 -0
- app/models/utils/video_gt_preprocess.py +87 -0
- ext/meta/sam_meta.py +41 -0
- ext/open_clip/__init__.py +15 -0
- ext/open_clip/bpe_simple_vocab_16e6.txt.gz +3 -0
- ext/open_clip/coca_model.py +458 -0
- ext/open_clip/constants.py +2 -0
- ext/open_clip/factory.py +387 -0
- ext/open_clip/generation_utils.py +0 -0
- ext/open_clip/hf_configs.py +56 -0
- ext/open_clip/hf_model.py +193 -0
- ext/open_clip/loss.py +216 -0
- ext/open_clip/model.py +473 -0
- ext/open_clip/model_configs/EVA01-g-14-plus.json +18 -0
- ext/open_clip/model_configs/EVA01-g-14.json +18 -0
- ext/open_clip/model_configs/EVA02-B-16.json +18 -0
- ext/open_clip/model_configs/EVA02-E-14-plus.json +18 -0
- ext/open_clip/model_configs/EVA02-E-14.json +18 -0
- ext/open_clip/model_configs/EVA02-L-14-336.json +18 -0
- ext/open_clip/model_configs/EVA02-L-14.json +18 -0
- ext/open_clip/model_configs/RN101-quickgelu.json +22 -0
- ext/open_clip/model_configs/RN101.json +21 -0
- ext/open_clip/model_configs/RN50-quickgelu.json +22 -0
- ext/open_clip/model_configs/RN50.json +21 -0
- ext/open_clip/model_configs/RN50x16.json +21 -0
- ext/open_clip/model_configs/RN50x4.json +21 -0
- ext/open_clip/model_configs/RN50x64.json +21 -0
- ext/open_clip/model_configs/ViT-B-16-plus-240.json +16 -0
- ext/open_clip/model_configs/ViT-B-16-plus.json +16 -0
- ext/open_clip/model_configs/ViT-B-16.json +16 -0
- ext/open_clip/model_configs/ViT-B-32-plus-256.json +16 -0
- ext/open_clip/model_configs/ViT-B-32-quickgelu.json +17 -0
- ext/open_clip/model_configs/ViT-B-32.json +16 -0
- ext/open_clip/model_configs/ViT-H-14.json +17 -0
- ext/open_clip/model_configs/ViT-H-16.json +17 -0
.gitattributes
CHANGED
@@ -17,10 +17,6 @@
|
|
17 |
*.ot filter=lfs diff=lfs merge=lfs -text
|
18 |
*.parquet filter=lfs diff=lfs merge=lfs -text
|
19 |
*.pb filter=lfs diff=lfs merge=lfs -text
|
20 |
-
*.pickle filter=lfs diff=lfs merge=lfs -text
|
21 |
-
*.pkl filter=lfs diff=lfs merge=lfs -text
|
22 |
-
*.pt filter=lfs diff=lfs merge=lfs -text
|
23 |
-
*.pth filter=lfs diff=lfs merge=lfs -text
|
24 |
*.rar filter=lfs diff=lfs merge=lfs -text
|
25 |
*.safetensors filter=lfs diff=lfs merge=lfs -text
|
26 |
saved_model/**/* filter=lfs diff=lfs merge=lfs -text
|
@@ -33,3 +29,8 @@ saved_model/**/* filter=lfs diff=lfs merge=lfs -text
|
|
33 |
*.zip filter=lfs diff=lfs merge=lfs -text
|
34 |
*.zst filter=lfs diff=lfs merge=lfs -text
|
35 |
*tfevents* filter=lfs diff=lfs merge=lfs -text
|
|
|
|
|
|
|
|
|
|
|
|
17 |
*.ot filter=lfs diff=lfs merge=lfs -text
|
18 |
*.parquet filter=lfs diff=lfs merge=lfs -text
|
19 |
*.pb filter=lfs diff=lfs merge=lfs -text
|
|
|
|
|
|
|
|
|
20 |
*.rar filter=lfs diff=lfs merge=lfs -text
|
21 |
*.safetensors filter=lfs diff=lfs merge=lfs -text
|
22 |
saved_model/**/* filter=lfs diff=lfs merge=lfs -text
|
|
|
29 |
*.zip filter=lfs diff=lfs merge=lfs -text
|
30 |
*.zst filter=lfs diff=lfs merge=lfs -text
|
31 |
*tfevents* filter=lfs diff=lfs merge=lfs -text
|
32 |
+
*.pickle filter=lfs diff=lfs merge=lfs -text
|
33 |
+
*.pkl filter=lfs diff=lfs merge=lfs -text
|
34 |
+
*.pt filter=lfs diff=lfs merge=lfs -text
|
35 |
+
*.pth filter=lfs diff=lfs merge=lfs -text
|
36 |
+
*.jpg filter=lfs diff=lfs merge=lfs -text
|
README.md
CHANGED
@@ -4,10 +4,8 @@ emoji: 🌍
|
|
4 |
colorFrom: purple
|
5 |
colorTo: purple
|
6 |
sdk: gradio
|
7 |
-
sdk_version: 4.
|
8 |
-
app_file:
|
9 |
pinned: false
|
10 |
-
|
11 |
---
|
12 |
-
|
13 |
-
Check out the configuration reference at https://huggingface.co/docs/hub/spaces-config-reference
|
|
|
4 |
colorFrom: purple
|
5 |
colorTo: purple
|
6 |
sdk: gradio
|
7 |
+
sdk_version: 4.7.1
|
8 |
+
app_file: main.py
|
9 |
pinned: false
|
10 |
+
python_version: 3.10
|
11 |
---
|
|
|
|
app/configs/rap_sam_r50_12e_adaptor.py
ADDED
@@ -0,0 +1,88 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
from mmdet.models import ResNet, MaskFormerFusionHead, CrossEntropyLoss, DiceLoss
|
2 |
+
|
3 |
+
from app.models.detectors import YOSOVideoSam
|
4 |
+
from app.models.heads import RapSAMVideoHead
|
5 |
+
from app.models.necks import YOSONeck
|
6 |
+
|
7 |
+
num_things_classes = 80
|
8 |
+
num_stuff_classes = 53
|
9 |
+
ov_model_name = 'convnext_large_d_320'
|
10 |
+
ov_datasets_name = 'CocoPanopticOVDataset'
|
11 |
+
num_classes = num_things_classes + num_stuff_classes
|
12 |
+
model = dict(
|
13 |
+
type=YOSOVideoSam,
|
14 |
+
data_preprocessor=None,
|
15 |
+
backbone=dict(
|
16 |
+
type=ResNet,
|
17 |
+
depth=50,
|
18 |
+
num_stages=4,
|
19 |
+
out_indices=(0, 1, 2, 3),
|
20 |
+
frozen_stages=-1,
|
21 |
+
norm_cfg=dict(type='BN', requires_grad=True),
|
22 |
+
norm_eval=True,
|
23 |
+
init_cfg=None,
|
24 |
+
),
|
25 |
+
neck=dict(
|
26 |
+
type=YOSONeck,
|
27 |
+
agg_dim=128,
|
28 |
+
hidden_dim=256,
|
29 |
+
backbone_shape=[256, 512, 1024, 2048],
|
30 |
+
),
|
31 |
+
panoptic_head=dict(
|
32 |
+
type=RapSAMVideoHead,
|
33 |
+
prompt_with_kernel_updator=False,
|
34 |
+
panoptic_with_kernel_updator=True,
|
35 |
+
use_adaptor=True,
|
36 |
+
use_kernel_updator=True,
|
37 |
+
sphere_cls=True,
|
38 |
+
ov_classifier_name=f'{ov_model_name}_{ov_datasets_name}',
|
39 |
+
num_stages=3,
|
40 |
+
feat_channels=256,
|
41 |
+
num_things_classes=num_things_classes,
|
42 |
+
num_stuff_classes=num_stuff_classes,
|
43 |
+
num_queries=100,
|
44 |
+
loss_cls=dict(
|
45 |
+
type=CrossEntropyLoss,
|
46 |
+
use_sigmoid=False,
|
47 |
+
loss_weight=2.0,
|
48 |
+
reduction='mean',
|
49 |
+
class_weight=[1.0] * num_classes + [0.1]),
|
50 |
+
loss_mask=dict(
|
51 |
+
type=CrossEntropyLoss,
|
52 |
+
use_sigmoid=True,
|
53 |
+
reduction='mean',
|
54 |
+
loss_weight=5.0),
|
55 |
+
loss_dice=dict(
|
56 |
+
type=DiceLoss,
|
57 |
+
use_sigmoid=True,
|
58 |
+
activate=True,
|
59 |
+
reduction='mean',
|
60 |
+
naive_dice=True,
|
61 |
+
eps=1.0,
|
62 |
+
loss_weight=5.0)
|
63 |
+
),
|
64 |
+
panoptic_fusion_head=dict(
|
65 |
+
type=MaskFormerFusionHead,
|
66 |
+
num_things_classes=num_things_classes,
|
67 |
+
num_stuff_classes=num_stuff_classes,
|
68 |
+
loss_panoptic=None,
|
69 |
+
init_cfg=None
|
70 |
+
),
|
71 |
+
train_cfg=None,
|
72 |
+
test_cfg=dict(
|
73 |
+
panoptic_on=True,
|
74 |
+
# For now, the dataset does not support
|
75 |
+
# evaluating semantic segmentation metric.
|
76 |
+
semantic_on=False,
|
77 |
+
instance_on=True,
|
78 |
+
# max_per_image is for instance segmentation.
|
79 |
+
max_per_image=100,
|
80 |
+
iou_thr=0.8,
|
81 |
+
# In Mask2Former's panoptic postprocessing,
|
82 |
+
# it will filter mask area where score is less than 0.5 .
|
83 |
+
filter_low_score=True),
|
84 |
+
init_cfg=dict(
|
85 |
+
type='Pretrained',
|
86 |
+
checkpoint='models/rapsam_r50_12e.pth'
|
87 |
+
)
|
88 |
+
)
|
app/models/detectors/__init__.py
ADDED
@@ -0,0 +1 @@
|
|
|
|
|
1 |
+
from .rapsam import YOSOVideoSam
|
app/models/detectors/mask2former_vid.py
ADDED
@@ -0,0 +1,281 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# Copyright (c) OpenMMLab. All rights reserved.
|
2 |
+
from typing import Dict, List, Tuple
|
3 |
+
|
4 |
+
import torch
|
5 |
+
from mmengine.structures import InstanceData
|
6 |
+
from torch import Tensor
|
7 |
+
import torch.nn.functional as F
|
8 |
+
|
9 |
+
from mmdet.registry import MODELS
|
10 |
+
from mmdet.structures import SampleList, OptSampleList, TrackDataSample
|
11 |
+
from mmdet.utils import ConfigType, OptConfigType, OptMultiConfig
|
12 |
+
from mmdet.models.detectors.single_stage import SingleStageDetector
|
13 |
+
|
14 |
+
from app.models.utils import mask_pool
|
15 |
+
|
16 |
+
|
17 |
+
@MODELS.register_module()
|
18 |
+
class Mask2formerVideo(SingleStageDetector):
|
19 |
+
r"""Implementation of `Per-Pixel Classification is
|
20 |
+
NOT All You Need for Semantic Segmentation
|
21 |
+
<https://arxiv.org/pdf/2107.06278>`_."""
|
22 |
+
OVERLAPPING = None
|
23 |
+
|
24 |
+
def __init__(self,
|
25 |
+
backbone: ConfigType,
|
26 |
+
neck: OptConfigType = None,
|
27 |
+
panoptic_head: OptConfigType = None,
|
28 |
+
panoptic_fusion_head: OptConfigType = None,
|
29 |
+
train_cfg: OptConfigType = None,
|
30 |
+
test_cfg: OptConfigType = None,
|
31 |
+
data_preprocessor: OptConfigType = None,
|
32 |
+
inference_sam: bool = False,
|
33 |
+
init_cfg: OptMultiConfig = None
|
34 |
+
):
|
35 |
+
super(SingleStageDetector, self).__init__(
|
36 |
+
data_preprocessor=data_preprocessor, init_cfg=init_cfg)
|
37 |
+
self.backbone = MODELS.build(backbone)
|
38 |
+
if neck is not None:
|
39 |
+
self.neck = MODELS.build(neck)
|
40 |
+
|
41 |
+
panoptic_head_ = panoptic_head.deepcopy()
|
42 |
+
panoptic_head_.update(train_cfg=train_cfg)
|
43 |
+
panoptic_head_.update(test_cfg=test_cfg)
|
44 |
+
self.panoptic_head = MODELS.build(panoptic_head_)
|
45 |
+
|
46 |
+
panoptic_fusion_head_ = panoptic_fusion_head.deepcopy()
|
47 |
+
panoptic_fusion_head_.update(test_cfg=test_cfg)
|
48 |
+
self.panoptic_fusion_head = MODELS.build(panoptic_fusion_head_)
|
49 |
+
|
50 |
+
self.num_things_classes = self.panoptic_head.num_things_classes
|
51 |
+
self.num_stuff_classes = self.panoptic_head.num_stuff_classes
|
52 |
+
self.num_classes = self.panoptic_head.num_classes
|
53 |
+
|
54 |
+
self.train_cfg = train_cfg
|
55 |
+
self.test_cfg = test_cfg
|
56 |
+
|
57 |
+
self.alpha = 0.4
|
58 |
+
self.beta = 0.8
|
59 |
+
|
60 |
+
self.inference_sam = inference_sam
|
61 |
+
|
62 |
+
def predict(self,
|
63 |
+
batch_inputs: Tensor,
|
64 |
+
batch_data_samples: SampleList,
|
65 |
+
rescale: bool = True) -> SampleList:
|
66 |
+
"""Predict results from a batch of inputs and data samples with post-
|
67 |
+
processing.
|
68 |
+
|
69 |
+
Args:
|
70 |
+
batch_inputs (Tensor): Inputs with shape (N, C, H, W).
|
71 |
+
batch_data_samples (List[:obj:`DetDataSample`]): The Data
|
72 |
+
Samples. It usually includes information such as
|
73 |
+
`gt_instance`, `gt_panoptic_seg` and `gt_sem_seg`.
|
74 |
+
rescale (bool): Whether to rescale the results.
|
75 |
+
Defaults to True.
|
76 |
+
|
77 |
+
Returns:
|
78 |
+
list[:obj:`DetDataSample`]: Detection results of the
|
79 |
+
input images. Each DetDataSample usually contain
|
80 |
+
'pred_instances' and `pred_panoptic_seg`. And the
|
81 |
+
``pred_instances`` usually contains following keys.
|
82 |
+
|
83 |
+
- scores (Tensor): Classification scores, has a shape
|
84 |
+
(num_instance, )
|
85 |
+
- labels (Tensor): Labels of bboxes, has a shape
|
86 |
+
(num_instances, ).
|
87 |
+
- bboxes (Tensor): Has a shape (num_instances, 4),
|
88 |
+
the last dimension 4 arrange as (x1, y1, x2, y2).
|
89 |
+
- masks (Tensor): Has a shape (num_instances, H, W).
|
90 |
+
|
91 |
+
And the ``pred_panoptic_seg`` contains the following key
|
92 |
+
|
93 |
+
- sem_seg (Tensor): panoptic segmentation mask, has a
|
94 |
+
shape (1, h, w).
|
95 |
+
"""
|
96 |
+
if isinstance(batch_data_samples[0], TrackDataSample):
|
97 |
+
bs, num_frames, three, h, w = batch_inputs.shape
|
98 |
+
assert three == 3, "Only supporting images with 3 channels."
|
99 |
+
x = batch_inputs.reshape((bs * num_frames, three, h, w))
|
100 |
+
feats = self.extract_feat(x)
|
101 |
+
else:
|
102 |
+
num_frames = 0
|
103 |
+
bs = batch_inputs.shape[0]
|
104 |
+
feats = self.extract_feat(batch_inputs)
|
105 |
+
|
106 |
+
mask_cls_results, mask_pred_results, iou_results = self.panoptic_head.predict(feats, batch_data_samples)
|
107 |
+
|
108 |
+
if self.inference_sam:
|
109 |
+
for i, data_sample in enumerate(batch_data_samples):
|
110 |
+
meta = data_sample.metainfo
|
111 |
+
img_height, img_width = meta['img_shape'][:2]
|
112 |
+
mask_pred_result = mask_pred_results[i][:, :img_height, :img_width]
|
113 |
+
mask_pred_result = mask_pred_result.view(-1, img_height, img_width) > 0
|
114 |
+
all_pred_instances = InstanceData(masks=mask_pred_result)
|
115 |
+
batch_data_samples[i].pred_instances = all_pred_instances
|
116 |
+
|
117 |
+
return batch_data_samples
|
118 |
+
|
119 |
+
if self.OVERLAPPING is not None:
|
120 |
+
assert len(self.OVERLAPPING) == self.num_classes
|
121 |
+
mask_cls_results = self.open_voc_inference(feats, mask_cls_results, mask_pred_results)
|
122 |
+
|
123 |
+
if num_frames > 0:
|
124 |
+
for frame_id in range(num_frames):
|
125 |
+
results_list_img = self.panoptic_fusion_head.predict(
|
126 |
+
mask_cls_results,
|
127 |
+
mask_pred_results[:, :, frame_id],
|
128 |
+
[batch_data_samples[idx][frame_id] for idx in range(bs)],
|
129 |
+
rescale=rescale
|
130 |
+
)
|
131 |
+
_ = self.add_track_pred_to_datasample(
|
132 |
+
[batch_data_samples[idx][frame_id] for idx in range(bs)], results_list_img
|
133 |
+
)
|
134 |
+
results = batch_data_samples
|
135 |
+
else:
|
136 |
+
results_list = self.panoptic_fusion_head.predict(
|
137 |
+
mask_cls_results,
|
138 |
+
mask_pred_results,
|
139 |
+
batch_data_samples,
|
140 |
+
iou_results=iou_results,
|
141 |
+
rescale=rescale
|
142 |
+
)
|
143 |
+
results = self.add_pred_to_datasample(batch_data_samples, results_list)
|
144 |
+
|
145 |
+
return results
|
146 |
+
|
147 |
+
def add_pred_to_datasample(self, data_samples: SampleList,
|
148 |
+
results_list: List[dict]) -> SampleList:
|
149 |
+
"""Add predictions to `DetDataSample`.
|
150 |
+
|
151 |
+
Args:
|
152 |
+
data_samples (list[:obj:`DetDataSample`], optional): A batch of
|
153 |
+
data samples that contain annotations and predictions.
|
154 |
+
results_list (List[dict]): Instance segmentation, segmantic
|
155 |
+
segmentation and panoptic segmentation results.
|
156 |
+
|
157 |
+
Returns:
|
158 |
+
list[:obj:`DetDataSample`]: Detection results of the
|
159 |
+
input images. Each DetDataSample usually contain
|
160 |
+
'pred_instances' and `pred_panoptic_seg`. And the
|
161 |
+
``pred_instances`` usually contains following keys.
|
162 |
+
|
163 |
+
- scores (Tensor): Classification scores, has a shape
|
164 |
+
(num_instance, )
|
165 |
+
- labels (Tensor): Labels of bboxes, has a shape
|
166 |
+
(num_instances, ).
|
167 |
+
- bboxes (Tensor): Has a shape (num_instances, 4),
|
168 |
+
the last dimension 4 arrange as (x1, y1, x2, y2).
|
169 |
+
- masks (Tensor): Has a shape (num_instances, H, W).
|
170 |
+
|
171 |
+
And the ``pred_panoptic_seg`` contains the following key
|
172 |
+
|
173 |
+
- sem_seg (Tensor): panoptic segmentation mask, has a
|
174 |
+
shape (1, h, w).
|
175 |
+
"""
|
176 |
+
for data_sample, pred_results in zip(data_samples, results_list):
|
177 |
+
if 'pan_results' in pred_results:
|
178 |
+
data_sample.pred_panoptic_seg = pred_results['pan_results']
|
179 |
+
|
180 |
+
if 'ins_results' in pred_results:
|
181 |
+
data_sample.pred_instances = pred_results['ins_results']
|
182 |
+
|
183 |
+
assert 'sem_results' not in pred_results
|
184 |
+
|
185 |
+
return data_samples
|
186 |
+
|
187 |
+
def add_track_pred_to_datasample(self, data_samples: SampleList, results_list: List[dict]) -> SampleList:
|
188 |
+
for data_sample, pred_results in zip(data_samples, results_list):
|
189 |
+
if 'pan_results' in pred_results:
|
190 |
+
assert self.num_stuff_classes > 0
|
191 |
+
data_sample.pred_track_panoptic_seg = pred_results['pan_results']
|
192 |
+
|
193 |
+
if 'ins_results' in pred_results:
|
194 |
+
bboxes = pred_results['ins_results']['bboxes']
|
195 |
+
labels = pred_results['ins_results']['labels']
|
196 |
+
track_ids = torch.arange(len(bboxes), dtype=labels.dtype, device=bboxes.device) + 1
|
197 |
+
pred_results['ins_results']['instances_id'] = track_ids
|
198 |
+
data_sample.pred_track_instances = pred_results['ins_results']
|
199 |
+
|
200 |
+
if 'pro_results' in pred_results:
|
201 |
+
data_sample.pred_track_proposal = pred_results['pro_results']
|
202 |
+
|
203 |
+
assert 'sem_results' not in pred_results
|
204 |
+
|
205 |
+
return data_samples
|
206 |
+
|
207 |
+
def _forward(
|
208 |
+
self,
|
209 |
+
batch_inputs: Tensor,
|
210 |
+
batch_data_samples: OptSampleList = None) -> Tuple[List[Tensor]]:
|
211 |
+
"""Network forward process. Usually includes backbone, neck and head
|
212 |
+
forward without any post-processing.
|
213 |
+
|
214 |
+
Args:
|
215 |
+
batch_inputs (Tensor): Inputs with shape (N, C, H, W).
|
216 |
+
batch_data_samples (list[:obj:`DetDataSample`]): The batch
|
217 |
+
data samples. It usually includes information such
|
218 |
+
as `gt_instance` or `gt_panoptic_seg` or `gt_sem_seg`.
|
219 |
+
|
220 |
+
Returns:
|
221 |
+
tuple[List[Tensor]]: A tuple of features from ``panoptic_head``
|
222 |
+
forward.
|
223 |
+
"""
|
224 |
+
if isinstance(batch_data_samples[0], TrackDataSample):
|
225 |
+
bs, num_frames, three, h, w = batch_inputs.shape
|
226 |
+
assert three == 3, "Only supporting images with 3 channels."
|
227 |
+
|
228 |
+
x = batch_inputs.reshape((bs * num_frames, three, h, w))
|
229 |
+
feats = self.extract_feat(x)
|
230 |
+
else:
|
231 |
+
feats = self.extract_feat(batch_inputs)
|
232 |
+
results = self.panoptic_head.forward(feats, batch_data_samples)
|
233 |
+
return results
|
234 |
+
|
235 |
+
def open_voc_inference(self, feats, mask_cls_results, mask_pred_results):
|
236 |
+
if len(mask_pred_results.shape) == 5:
|
237 |
+
batch_size = mask_cls_results.shape[0]
|
238 |
+
num_frames = mask_pred_results.shape[2]
|
239 |
+
mask_pred_results = mask_pred_results.permute(0, 2, 1, 3, 4).flatten(0, 1)
|
240 |
+
else:
|
241 |
+
batch_size = mask_cls_results.shape[0]
|
242 |
+
num_frames = 0
|
243 |
+
clip_feat = self.backbone.get_clip_feature(feats[-1])
|
244 |
+
clip_feat_mask = F.interpolate(
|
245 |
+
mask_pred_results,
|
246 |
+
size=clip_feat.shape[-2:],
|
247 |
+
mode='bilinear',
|
248 |
+
align_corners=False
|
249 |
+
)
|
250 |
+
if num_frames > 0:
|
251 |
+
clip_feat_mask = clip_feat_mask.unflatten(0, (batch_size, num_frames)).permute(0, 2, 1, 3, 4).flatten(2, 3)
|
252 |
+
clip_feat = clip_feat.unflatten(0, (batch_size, num_frames)).permute(0, 2, 1, 3, 4).flatten(2, 3)
|
253 |
+
instance_feat = mask_pool(clip_feat, clip_feat_mask)
|
254 |
+
instance_feat = self.backbone.forward_feat(instance_feat)
|
255 |
+
clip_logit = self.panoptic_head.forward_logit(instance_feat)
|
256 |
+
clip_logit = clip_logit[..., :-1]
|
257 |
+
query_logit = mask_cls_results[..., :-1]
|
258 |
+
|
259 |
+
clip_logit = clip_logit.softmax(-1)
|
260 |
+
query_logit = query_logit.softmax(-1)
|
261 |
+
overlapping_mask = torch.tensor(self.OVERLAPPING, dtype=torch.float32, device=clip_logit.device)
|
262 |
+
|
263 |
+
valid_masking = ((clip_feat_mask > 0).to(dtype=torch.float32).flatten(-2).sum(-1) > 0).to(
|
264 |
+
torch.float32)[..., None]
|
265 |
+
alpha = torch.ones_like(clip_logit) * self.alpha * valid_masking
|
266 |
+
beta = torch.ones_like(clip_logit) * self.beta * valid_masking
|
267 |
+
|
268 |
+
cls_logits_seen = (
|
269 |
+
(query_logit ** (1 - alpha) * clip_logit ** alpha).log()
|
270 |
+
* overlapping_mask
|
271 |
+
)
|
272 |
+
cls_logits_unseen = (
|
273 |
+
(query_logit ** (1 - beta) * clip_logit ** beta).log()
|
274 |
+
* (1 - overlapping_mask)
|
275 |
+
)
|
276 |
+
cls_results = cls_logits_seen + cls_logits_unseen
|
277 |
+
is_void_prob = F.softmax(mask_cls_results, dim=-1)[..., -1:]
|
278 |
+
mask_cls_results = torch.cat([
|
279 |
+
cls_results.softmax(-1) * (1.0 - is_void_prob), is_void_prob], dim=-1)
|
280 |
+
mask_cls_results = torch.log(mask_cls_results + 1e-8)
|
281 |
+
return mask_cls_results
|
app/models/detectors/rapsam.py
ADDED
@@ -0,0 +1,66 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
from mmdet.registry import MODELS
|
2 |
+
from mmdet.utils import ConfigType, OptConfigType, OptMultiConfig
|
3 |
+
|
4 |
+
from mmdet.models.detectors import SingleStageDetector
|
5 |
+
|
6 |
+
from .mask2former_vid import Mask2formerVideo
|
7 |
+
|
8 |
+
|
9 |
+
@MODELS.register_module()
|
10 |
+
class YOSOVideoSam(Mask2formerVideo):
|
11 |
+
OVERLAPPING = None
|
12 |
+
|
13 |
+
def __init__(self,
|
14 |
+
backbone: ConfigType,
|
15 |
+
neck: OptConfigType = None,
|
16 |
+
panoptic_head: OptConfigType = None,
|
17 |
+
panoptic_fusion_head: OptConfigType = None,
|
18 |
+
train_cfg: OptConfigType = None,
|
19 |
+
test_cfg: OptConfigType = None,
|
20 |
+
data_preprocessor: OptConfigType = None,
|
21 |
+
inference_sam: bool = False,
|
22 |
+
init_cfg: OptMultiConfig = None
|
23 |
+
):
|
24 |
+
super(SingleStageDetector, self).__init__(
|
25 |
+
data_preprocessor=data_preprocessor, init_cfg=init_cfg)
|
26 |
+
self.backbone = MODELS.build(backbone)
|
27 |
+
if neck is not None:
|
28 |
+
self.neck = MODELS.build(neck)
|
29 |
+
|
30 |
+
panoptic_head_ = panoptic_head.deepcopy()
|
31 |
+
panoptic_head_.update(train_cfg=train_cfg)
|
32 |
+
panoptic_head_.update(test_cfg=test_cfg)
|
33 |
+
self.panoptic_head = MODELS.build(panoptic_head_)
|
34 |
+
|
35 |
+
panoptic_fusion_head_ = panoptic_fusion_head.deepcopy()
|
36 |
+
panoptic_fusion_head_.update(test_cfg=test_cfg)
|
37 |
+
self.panoptic_fusion_head = MODELS.build(panoptic_fusion_head_)
|
38 |
+
|
39 |
+
self.num_things_classes = self.panoptic_head.num_things_classes
|
40 |
+
self.num_stuff_classes = self.panoptic_head.num_stuff_classes
|
41 |
+
self.num_classes = self.panoptic_head.num_classes
|
42 |
+
|
43 |
+
self.train_cfg = train_cfg
|
44 |
+
self.test_cfg = test_cfg
|
45 |
+
|
46 |
+
self.alpha = 0.4
|
47 |
+
self.beta = 0.8
|
48 |
+
|
49 |
+
self.inference_sam = inference_sam
|
50 |
+
|
51 |
+
def predict_with_point(self, x, batch_data_samples):
|
52 |
+
feats = self.extract_feat(x)
|
53 |
+
mask_cls_results, mask_pred_results, iou_results = self.panoptic_head.predict(feats, batch_data_samples)
|
54 |
+
|
55 |
+
if 'gt_instances_collected' not in batch_data_samples[0]:
|
56 |
+
results_list = self.panoptic_fusion_head.predict(
|
57 |
+
mask_cls_results,
|
58 |
+
mask_pred_results,
|
59 |
+
batch_data_samples,
|
60 |
+
iou_results=iou_results,
|
61 |
+
rescale=False
|
62 |
+
)
|
63 |
+
mask_pred_results = results_list[0]['pan_results'].sem_seg[None]
|
64 |
+
mask_cls_results = mask_cls_results
|
65 |
+
|
66 |
+
return mask_pred_results.cpu().numpy(), mask_cls_results.cpu().numpy()
|
app/models/heads/__init__.py
ADDED
@@ -0,0 +1 @@
|
|
|
|
|
1 |
+
from .rapsam_head import RapSAMVideoHead
|
app/models/heads/mask2former_vid.py
ADDED
@@ -0,0 +1,616 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# Copyright (c) OpenMMLab. All rights reserved.
|
2 |
+
import copy
|
3 |
+
import os
|
4 |
+
from typing import List, Optional, Tuple
|
5 |
+
|
6 |
+
import torch
|
7 |
+
import torch.nn as nn
|
8 |
+
import torch.nn.functional as F
|
9 |
+
import torch.distributed as dist
|
10 |
+
from mmcv.cnn import Conv2d
|
11 |
+
from mmdet.models import Mask2FormerTransformerDecoder
|
12 |
+
from mmengine.dist import get_dist_info
|
13 |
+
from mmengine.model import caffe2_xavier_init, ModuleList
|
14 |
+
from torch import Tensor
|
15 |
+
from mmdet.models.layers import MLP, inverse_sigmoid
|
16 |
+
from mmdet.models.layers import coordinate_to_encoding
|
17 |
+
from mmdet.structures.bbox import bbox_xyxy_to_cxcywh
|
18 |
+
|
19 |
+
from mmdet.registry import MODELS, TASK_UTILS
|
20 |
+
from mmdet.structures import SampleList, TrackDataSample
|
21 |
+
from mmdet.utils import (ConfigType, OptConfigType, OptMultiConfig)
|
22 |
+
from mmdet.models.layers import SinePositionalEncoding3D
|
23 |
+
from mmdet.models.dense_heads.anchor_free_head import AnchorFreeHead
|
24 |
+
from mmcv.cnn.bricks.transformer import MultiheadAttention
|
25 |
+
from app.models.utils import mask_pool
|
26 |
+
|
27 |
+
|
28 |
+
@MODELS.register_module()
|
29 |
+
class Mask2FormerVideoHead(AnchorFreeHead):
|
30 |
+
"""Implements the Mask2Former head.
|
31 |
+
|
32 |
+
See `Masked-attention Mask Transformer for Universal Image
|
33 |
+
Segmentation <https://arxiv.org/pdf/2112.01527>`_ for details.
|
34 |
+
|
35 |
+
Args:
|
36 |
+
in_channels (list[int]): Number of channels in the input feature map.
|
37 |
+
feat_channels (int): Number of channels for features.
|
38 |
+
out_channels (int): Number of channels for output.
|
39 |
+
num_things_classes (int): Number of things.
|
40 |
+
num_stuff_classes (int): Number of stuff.
|
41 |
+
num_queries (int): Number of query in Transformer decoder.
|
42 |
+
pixel_decoder (:obj:`ConfigDict` or dict): Config for pixel
|
43 |
+
decoder. Defaults to None.
|
44 |
+
enforce_decoder_input_project (bool, optional): Whether to add
|
45 |
+
a layer to change the embed_dim of tranformer encoder in
|
46 |
+
pixel decoder to the embed_dim of transformer decoder.
|
47 |
+
Defaults to False.
|
48 |
+
transformer_decoder (:obj:`ConfigDict` or dict): Config for
|
49 |
+
transformer decoder. Defaults to None.
|
50 |
+
positional_encoding (:obj:`ConfigDict` or dict): Config for
|
51 |
+
transformer decoder position encoding. Defaults to
|
52 |
+
dict(num_feats=128, normalize=True).
|
53 |
+
loss_cls (:obj:`ConfigDict` or dict): Config of the classification
|
54 |
+
loss. Defaults to None.
|
55 |
+
loss_mask (:obj:`ConfigDict` or dict): Config of the mask loss.
|
56 |
+
Defaults to None.
|
57 |
+
loss_dice (:obj:`ConfigDict` or dict): Config of the dice loss.
|
58 |
+
Defaults to None.
|
59 |
+
train_cfg (:obj:`ConfigDict` or dict, optional): Training config of
|
60 |
+
Mask2Former head.
|
61 |
+
test_cfg (:obj:`ConfigDict` or dict, optional): Testing config of
|
62 |
+
Mask2Former head.
|
63 |
+
init_cfg (:obj:`ConfigDict` or dict or list[:obj:`ConfigDict` or \
|
64 |
+
dict], optional): Initialization config dict. Defaults to None.
|
65 |
+
"""
|
66 |
+
|
67 |
+
def __init__(self,
|
68 |
+
in_channels: List[int],
|
69 |
+
feat_channels: int,
|
70 |
+
out_channels: int,
|
71 |
+
num_mask_tokens: int = 1,
|
72 |
+
num_things_classes: int = 80,
|
73 |
+
num_stuff_classes: int = 53,
|
74 |
+
num_queries: int = 100,
|
75 |
+
num_transformer_feat_level: int = 3,
|
76 |
+
pixel_decoder: ConfigType = ...,
|
77 |
+
enforce_decoder_input_project: bool = False,
|
78 |
+
transformer_decoder: ConfigType = ...,
|
79 |
+
positional_encoding: ConfigType = None,
|
80 |
+
loss_cls: ConfigType = None,
|
81 |
+
loss_mask: ConfigType = None,
|
82 |
+
loss_dice: ConfigType = None,
|
83 |
+
train_cfg: OptConfigType = None,
|
84 |
+
test_cfg: OptConfigType = None,
|
85 |
+
init_cfg: OptMultiConfig = None,
|
86 |
+
# ov configs
|
87 |
+
sphere_cls: bool = False,
|
88 |
+
ov_classifier_name: Optional[str] = None,
|
89 |
+
logit: Optional[int] = None,
|
90 |
+
use_adaptor = False,
|
91 |
+
**kwargs) -> None:
|
92 |
+
super(AnchorFreeHead, self).__init__(init_cfg=init_cfg)
|
93 |
+
self.use_adaptor = use_adaptor
|
94 |
+
|
95 |
+
self.num_mask_tokens = num_mask_tokens
|
96 |
+
self.mask_tokens = nn.Embedding(num_mask_tokens, feat_channels)
|
97 |
+
self.pb_embedding = nn.Embedding(2, feat_channels)
|
98 |
+
self.pos_linear = nn.Linear(2 * feat_channels, feat_channels)
|
99 |
+
|
100 |
+
self.num_things_classes = num_things_classes
|
101 |
+
self.num_stuff_classes = num_stuff_classes
|
102 |
+
self.num_classes = self.num_things_classes + self.num_stuff_classes
|
103 |
+
self.num_queries = num_queries
|
104 |
+
self.num_transformer_feat_level = num_transformer_feat_level
|
105 |
+
self.num_heads = transformer_decoder.layer_cfg.cross_attn_cfg.num_heads
|
106 |
+
self.num_transformer_decoder_layers = transformer_decoder.num_layers
|
107 |
+
# assert pixel_decoder.encoder.layer_cfg. \
|
108 |
+
# self_attn_cfg.num_levels == num_transformer_feat_level
|
109 |
+
pixel_decoder_ = copy.deepcopy(pixel_decoder)
|
110 |
+
pixel_decoder_.update(
|
111 |
+
in_channels=in_channels,
|
112 |
+
feat_channels=feat_channels,
|
113 |
+
out_channels=out_channels)
|
114 |
+
self.pixel_decoder = MODELS.build(pixel_decoder_)
|
115 |
+
self.transformer_decoder = Mask2FormerTransformerDecoder(
|
116 |
+
**transformer_decoder)
|
117 |
+
self.decoder_embed_dims = self.transformer_decoder.embed_dims
|
118 |
+
|
119 |
+
self.decoder_input_projs = ModuleList()
|
120 |
+
# from low resolution to high resolution
|
121 |
+
for _ in range(num_transformer_feat_level):
|
122 |
+
if (self.decoder_embed_dims != feat_channels
|
123 |
+
or enforce_decoder_input_project):
|
124 |
+
self.decoder_input_projs.append(
|
125 |
+
Conv2d(
|
126 |
+
feat_channels, self.decoder_embed_dims, kernel_size=1))
|
127 |
+
else:
|
128 |
+
self.decoder_input_projs.append(nn.Identity())
|
129 |
+
self.decoder_positional_encoding = SinePositionalEncoding3D(
|
130 |
+
**positional_encoding)
|
131 |
+
self.query_embed = nn.Embedding(self.num_queries, feat_channels)
|
132 |
+
self.query_feat = nn.Embedding(self.num_queries, feat_channels)
|
133 |
+
# from low resolution to high resolution
|
134 |
+
self.level_embed = nn.Embedding(self.num_transformer_feat_level,
|
135 |
+
feat_channels)
|
136 |
+
|
137 |
+
if not sphere_cls:
|
138 |
+
self.cls_embed = nn.Linear(feat_channels, self.num_classes + 1)
|
139 |
+
self.mask_embed = nn.Sequential(
|
140 |
+
nn.Linear(feat_channels, feat_channels), nn.ReLU(inplace=True),
|
141 |
+
nn.Linear(feat_channels, feat_channels), nn.ReLU(inplace=True),
|
142 |
+
nn.Linear(feat_channels, out_channels))
|
143 |
+
|
144 |
+
self.iou_embed = nn.Sequential(
|
145 |
+
nn.Linear(feat_channels, feat_channels), nn.ReLU(inplace=True),
|
146 |
+
nn.Linear(feat_channels, feat_channels), nn.ReLU(inplace=True),
|
147 |
+
nn.Linear(feat_channels, 1))
|
148 |
+
|
149 |
+
self.test_cfg = test_cfg
|
150 |
+
self.train_cfg = train_cfg
|
151 |
+
if train_cfg:
|
152 |
+
self.assigner = TASK_UTILS.build(self.train_cfg['assigner'])
|
153 |
+
self.sampler = TASK_UTILS.build(
|
154 |
+
self.train_cfg['sampler'], default_args=dict(context=self))
|
155 |
+
self.num_points = self.train_cfg.get('num_points', 12544)
|
156 |
+
self.oversample_ratio = self.train_cfg.get('oversample_ratio', 3.0)
|
157 |
+
self.importance_sample_ratio = self.train_cfg.get(
|
158 |
+
'importance_sample_ratio', 0.75)
|
159 |
+
|
160 |
+
self.class_weight = loss_cls.class_weight
|
161 |
+
self.loss_cls = MODELS.build(loss_cls)
|
162 |
+
self.loss_mask = MODELS.build(loss_mask)
|
163 |
+
self.loss_dice = MODELS.build(loss_dice)
|
164 |
+
|
165 |
+
# prepare OV things
|
166 |
+
# OV cls embed
|
167 |
+
if sphere_cls:
|
168 |
+
rank, world_size = get_dist_info()
|
169 |
+
if ov_classifier_name is None:
|
170 |
+
_dim = 1024 # temporally hard code
|
171 |
+
cls_embed = torch.empty(self.num_classes, _dim)
|
172 |
+
torch.nn.init.orthogonal_(cls_embed)
|
173 |
+
cls_embed = cls_embed[:, None]
|
174 |
+
else:
|
175 |
+
# ov_path = os.path.join(os.path.expanduser('~/.cache/embd'), f"{ov_classifier_name}.pth")
|
176 |
+
ov_path = os.path.join('./models/', f"{ov_classifier_name}.pth")
|
177 |
+
cls_embed = torch.load(ov_path)
|
178 |
+
cls_embed_norm = cls_embed.norm(p=2, dim=-1)
|
179 |
+
assert torch.allclose(cls_embed_norm, torch.ones_like(cls_embed_norm))
|
180 |
+
if self.loss_cls and self.loss_cls.use_sigmoid:
|
181 |
+
pass
|
182 |
+
else:
|
183 |
+
_dim = cls_embed.size(2)
|
184 |
+
_prototypes = cls_embed.size(1)
|
185 |
+
|
186 |
+
if rank == 0:
|
187 |
+
back_token = torch.zeros(1, _dim, dtype=torch.float32, device='cuda')
|
188 |
+
# back_token = back_token / back_token.norm(p=2, dim=-1, keepdim=True)
|
189 |
+
else:
|
190 |
+
back_token = torch.empty(1, _dim, dtype=torch.float32, device='cuda')
|
191 |
+
if world_size > 1:
|
192 |
+
dist.broadcast(back_token, src=0)
|
193 |
+
back_token = back_token.to(device='cpu')
|
194 |
+
cls_embed = torch.cat([
|
195 |
+
cls_embed, back_token.repeat(_prototypes, 1)[None]
|
196 |
+
], dim=0)
|
197 |
+
self.register_buffer('cls_embed', cls_embed.permute(2, 0, 1).contiguous(), persistent=False)
|
198 |
+
|
199 |
+
# cls embd proj
|
200 |
+
cls_embed_dim = self.cls_embed.size(0)
|
201 |
+
self.cls_proj = nn.Sequential(
|
202 |
+
nn.Linear(feat_channels, feat_channels), nn.ReLU(inplace=True),
|
203 |
+
nn.Linear(feat_channels, feat_channels), nn.ReLU(inplace=True),
|
204 |
+
nn.Linear(feat_channels, cls_embed_dim)
|
205 |
+
)
|
206 |
+
|
207 |
+
# Haobo Yuan:
|
208 |
+
# For the logit_scale, I refer to this issue.
|
209 |
+
# https://github.com/openai/CLIP/issues/46#issuecomment-945062212
|
210 |
+
# https://github.com/openai/CLIP/issues/46#issuecomment-782558799
|
211 |
+
# Based on my understanding, it is a mistake of CLIP.
|
212 |
+
# Because they mention that they refer to InstDisc (Wu, 2018) paper.
|
213 |
+
# InstDisc set a non-learnable temperature to np.log(1 / 0.07).
|
214 |
+
# 4.6052 is np.log(1 / 0.01)
|
215 |
+
# np.log(1 / 0.07) will be fast converged to np.log(1 / 0.01)
|
216 |
+
if logit is None:
|
217 |
+
logit_scale = torch.tensor(4.6052, dtype=torch.float32)
|
218 |
+
else:
|
219 |
+
logit_scale = torch.tensor(logit, dtype=torch.float32)
|
220 |
+
self.register_buffer('logit_scale', logit_scale, persistent=False)
|
221 |
+
|
222 |
+
# Mask Pooling
|
223 |
+
self.mask_pooling = mask_pool
|
224 |
+
self.mask_pooling_proj = nn.Sequential(
|
225 |
+
nn.LayerNorm(feat_channels),
|
226 |
+
nn.Linear(feat_channels, feat_channels)
|
227 |
+
)
|
228 |
+
|
229 |
+
if use_adaptor:
|
230 |
+
cross_attn_cfg = dict(embed_dims=256, batch_first=True, num_heads=8)
|
231 |
+
self.panoptic_attn = MultiheadAttention(**cross_attn_cfg)
|
232 |
+
self.panoptic_norm = nn.LayerNorm(256)
|
233 |
+
if sphere_cls:
|
234 |
+
cls_embed_dim = self.cls_embed.size(0)
|
235 |
+
self.panoptic_cls = nn.Sequential(
|
236 |
+
nn.Linear(feat_channels, cls_embed_dim)
|
237 |
+
)
|
238 |
+
else:
|
239 |
+
raise NotImplementedError
|
240 |
+
self.prompt_attn = MultiheadAttention(**cross_attn_cfg)
|
241 |
+
self.prompt_norm = nn.LayerNorm(256)
|
242 |
+
self.prompt_iou = nn.Linear(256, 1)
|
243 |
+
|
244 |
+
def init_weights(self) -> None:
|
245 |
+
for m in self.decoder_input_projs:
|
246 |
+
if isinstance(m, Conv2d):
|
247 |
+
caffe2_xavier_init(m, bias=0)
|
248 |
+
|
249 |
+
self.pixel_decoder.init_weights()
|
250 |
+
|
251 |
+
for p in self.transformer_decoder.parameters():
|
252 |
+
if p.dim() > 1:
|
253 |
+
nn.init.xavier_normal_(p)
|
254 |
+
|
255 |
+
def forward_logit(self, cls_embd):
|
256 |
+
cls_pred = torch.einsum('bnc,ckp->bnkp', F.normalize(cls_embd, dim=-1), self.cls_embed)
|
257 |
+
cls_pred = cls_pred.max(-1).values
|
258 |
+
cls_pred = self.logit_scale.exp() * cls_pred
|
259 |
+
return cls_pred
|
260 |
+
|
261 |
+
def _forward_head(self, decoder_out: Tensor, mask_feature: Tensor,
|
262 |
+
attn_mask_target_size: Tuple[int, int],
|
263 |
+
num_frames: int = 0) -> Tuple[Tensor]:
|
264 |
+
"""Forward for head part which is called after every decoder layer.
|
265 |
+
|
266 |
+
Args:
|
267 |
+
decoder_out (Tensor): in shape (batch_size, num_queries, c).
|
268 |
+
mask_feature (Tensor): in shape (batch_size, c, h, w).
|
269 |
+
attn_mask_target_size (tuple[int, int]): target attention
|
270 |
+
mask size.
|
271 |
+
|
272 |
+
Returns:
|
273 |
+
tuple: A tuple contain three elements.
|
274 |
+
|
275 |
+
- cls_pred (Tensor): Classification scores in shape \
|
276 |
+
(batch_size, num_queries, cls_out_channels). \
|
277 |
+
Note `cls_out_channels` should includes background.
|
278 |
+
- mask_pred (Tensor): Mask scores in shape \
|
279 |
+
(batch_size, num_queries,h, w).
|
280 |
+
- attn_mask (Tensor): Attention mask in shape \
|
281 |
+
(batch_size * num_heads, num_queries, h, w).
|
282 |
+
- num_frames: How many frames are there in video.
|
283 |
+
"""
|
284 |
+
decoder_out = self.transformer_decoder.post_norm(decoder_out)
|
285 |
+
# shape (num_queries, batch_size, c)
|
286 |
+
if isinstance(self.cls_embed, nn.Module):
|
287 |
+
cls_pred = self.cls_embed(decoder_out)
|
288 |
+
# shape (num_queries, batch_size, c)
|
289 |
+
mask_embed = self.mask_embed(decoder_out)
|
290 |
+
# shape (num_queries, batch_size, h, w)
|
291 |
+
mask_pred = torch.einsum('bqc,bchw->bqhw', mask_embed, mask_feature)
|
292 |
+
|
293 |
+
if not isinstance(self.cls_embed, nn.Module):
|
294 |
+
maskpool_embd = self.mask_pooling(x=mask_feature, mask=mask_pred.detach())
|
295 |
+
maskpool_embd = self.mask_pooling_proj(maskpool_embd)
|
296 |
+
cls_embd = self.cls_proj(maskpool_embd + decoder_out)
|
297 |
+
cls_pred = self.forward_logit(cls_embd)
|
298 |
+
|
299 |
+
iou_pred = self.iou_embed(decoder_out)
|
300 |
+
|
301 |
+
if num_frames > 0:
|
302 |
+
assert len(mask_pred.shape) == 4
|
303 |
+
assert mask_pred.shape[2] % num_frames == 0
|
304 |
+
frame_h = mask_pred.shape[2] // num_frames
|
305 |
+
num_q = mask_pred.shape[1]
|
306 |
+
_mask_pred = mask_pred.unflatten(-2, (num_frames, frame_h)).flatten(1, 2)
|
307 |
+
attn_mask = F.interpolate(
|
308 |
+
_mask_pred,
|
309 |
+
attn_mask_target_size,
|
310 |
+
mode='bilinear',
|
311 |
+
align_corners=False)
|
312 |
+
attn_mask = attn_mask.unflatten(1, (num_q, num_frames)).flatten(2, 3)
|
313 |
+
else:
|
314 |
+
attn_mask = F.interpolate(
|
315 |
+
mask_pred,
|
316 |
+
attn_mask_target_size,
|
317 |
+
mode='bilinear',
|
318 |
+
align_corners=False)
|
319 |
+
# shape (num_queries, batch_size, h, w) ->
|
320 |
+
# (batch_size * num_head, num_queries, h, w)
|
321 |
+
attn_mask = attn_mask.flatten(2).unsqueeze(1).repeat(
|
322 |
+
(1, self.num_heads, 1, 1)).flatten(0, 1)
|
323 |
+
attn_mask = attn_mask.sigmoid() < 0.5
|
324 |
+
attn_mask = attn_mask.detach()
|
325 |
+
|
326 |
+
return cls_pred, mask_pred, iou_pred, attn_mask
|
327 |
+
|
328 |
+
def forward(self, x: List[Tensor], batch_data_samples: SampleList) -> Tuple[List[Tensor]]:
|
329 |
+
"""Forward function.
|
330 |
+
|
331 |
+
Args:
|
332 |
+
x (list[Tensor]): Multi scale Features from the
|
333 |
+
upstream network, each is a 4D-tensor.
|
334 |
+
batch_data_samples (List[:obj:`DetDataSample`]): The Data
|
335 |
+
Samples. It usually includes information such as
|
336 |
+
`gt_instance`, `gt_panoptic_seg` and `gt_sem_seg`.
|
337 |
+
|
338 |
+
Returns:
|
339 |
+
tuple[list[Tensor]]: A tuple contains two elements.
|
340 |
+
|
341 |
+
- cls_pred_list (list[Tensor)]: Classification logits \
|
342 |
+
for each decoder layer. Each is a 3D-tensor with shape \
|
343 |
+
(batch_size, num_queries, cls_out_channels). \
|
344 |
+
Note `cls_out_channels` should includes background.
|
345 |
+
- mask_pred_list (list[Tensor]): Mask logits for each \
|
346 |
+
decoder layer. Each with shape (batch_size, num_queries, \
|
347 |
+
h, w).
|
348 |
+
"""
|
349 |
+
batch_img_metas = []
|
350 |
+
if isinstance(batch_data_samples[0], TrackDataSample):
|
351 |
+
for track_sample in batch_data_samples:
|
352 |
+
cur_list = []
|
353 |
+
for det_sample in track_sample:
|
354 |
+
cur_list.append(det_sample.metainfo)
|
355 |
+
batch_img_metas.append(cur_list)
|
356 |
+
num_frames = len(batch_img_metas[0])
|
357 |
+
else:
|
358 |
+
for data_sample in batch_data_samples:
|
359 |
+
batch_img_metas.append(data_sample.metainfo)
|
360 |
+
num_frames = 0
|
361 |
+
batch_size = len(batch_img_metas)
|
362 |
+
#(bs_nf, c, h,w)
|
363 |
+
mask_features, multi_scale_memorys = self.pixel_decoder(x)
|
364 |
+
if num_frames > 0:
|
365 |
+
mask_features = mask_features.unflatten(0, (batch_size, num_frames))
|
366 |
+
mask_features = mask_features.transpose(1, 2).flatten(2, 3) #(bs, c, nf*h,w)
|
367 |
+
decoder_inputs = []
|
368 |
+
decoder_positional_encodings = []
|
369 |
+
for i in range(self.num_transformer_feat_level):
|
370 |
+
decoder_input = self.decoder_input_projs[i](multi_scale_memorys[i]) #(bs_nf, c, h,w)
|
371 |
+
decoder_input = decoder_input.flatten(2).permute(0, 2, 1) #(bs_nf,h*w, c)
|
372 |
+
if num_frames > 0:
|
373 |
+
decoder_input = decoder_input.unflatten(0, (batch_size, num_frames))
|
374 |
+
decoder_input = decoder_input.flatten(1, 2) #(bs, nf*h*w, c)
|
375 |
+
level_embed = self.level_embed.weight[i].view(1, 1, -1)
|
376 |
+
decoder_input = decoder_input + level_embed
|
377 |
+
|
378 |
+
# shape (batch_size, c, h, w) -> (batch_size, h*w, c)
|
379 |
+
num_frames_real = 1 if num_frames == 0 else num_frames
|
380 |
+
mask = decoder_input.new_zeros(
|
381 |
+
(batch_size, num_frames_real) + multi_scale_memorys[i].shape[-2:],
|
382 |
+
dtype=torch.bool)
|
383 |
+
decoder_positional_encoding = self.decoder_positional_encoding(
|
384 |
+
mask)
|
385 |
+
decoder_positional_encoding = decoder_positional_encoding.transpose(
|
386 |
+
1, 2).flatten(2).permute(0, 2, 1)
|
387 |
+
decoder_inputs.append(decoder_input) #(bs, nf*h*w, c)
|
388 |
+
decoder_positional_encodings.append(decoder_positional_encoding) #(bs, nf*h*w, c)
|
389 |
+
|
390 |
+
if self.prompt_training:
|
391 |
+
query_feat, input_query_bbox, self_attn_mask, _ = self.prepare_for_dn_mo(
|
392 |
+
batch_data_samples)
|
393 |
+
query_embed = coordinate_to_encoding(input_query_bbox.sigmoid())
|
394 |
+
query_embed = self.pos_linear(query_embed)
|
395 |
+
else:
|
396 |
+
query_feat = self.query_feat.weight.unsqueeze(0).repeat((batch_size, 1, 1))
|
397 |
+
query_embed = self.query_embed.weight.unsqueeze(0).repeat((batch_size, 1, 1))
|
398 |
+
self_attn_mask = None
|
399 |
+
|
400 |
+
cls_pred_list = []
|
401 |
+
mask_pred_list = []
|
402 |
+
iou_pred_list = []
|
403 |
+
cls_pred, mask_pred, iou_pred, attn_mask = self._forward_head(
|
404 |
+
query_feat, mask_features, multi_scale_memorys[0].shape[-2:],
|
405 |
+
num_frames=num_frames
|
406 |
+
)
|
407 |
+
cls_pred_list.append(cls_pred)
|
408 |
+
iou_pred_list.append(iou_pred)
|
409 |
+
if num_frames > 0: #(bs, 100, nf*h, w)-->(bs, 100, nf, h, w)
|
410 |
+
mask_pred = mask_pred.unflatten(2, (num_frames, -1))
|
411 |
+
mask_pred_list.append(mask_pred)
|
412 |
+
|
413 |
+
for i in range(self.num_transformer_decoder_layers):
|
414 |
+
level_idx = i % self.num_transformer_feat_level
|
415 |
+
# if a mask is all True(all background), then set it all False.
|
416 |
+
attn_mask[torch.where(attn_mask.sum(-1) == attn_mask.shape[-1])] = False
|
417 |
+
|
418 |
+
# cross_attn + self_attn
|
419 |
+
layer = self.transformer_decoder.layers[i]
|
420 |
+
query_feat = layer(
|
421 |
+
query=query_feat, #(bs, 100, c)
|
422 |
+
key=decoder_inputs[level_idx], #(bs, nf*h*w, c)
|
423 |
+
value=decoder_inputs[level_idx],
|
424 |
+
query_pos=query_embed,
|
425 |
+
key_pos=decoder_positional_encodings[level_idx],
|
426 |
+
cross_attn_mask=attn_mask,
|
427 |
+
self_attn_mask=self_attn_mask,
|
428 |
+
query_key_padding_mask=None,
|
429 |
+
# here we do not apply masking on padded region
|
430 |
+
key_padding_mask=None)
|
431 |
+
cls_pred, mask_pred, iou_pred, attn_mask = self._forward_head(
|
432 |
+
query_feat, mask_features, multi_scale_memorys[(i + 1) % self.num_transformer_feat_level].shape[-2:],
|
433 |
+
num_frames=num_frames
|
434 |
+
)
|
435 |
+
|
436 |
+
cls_pred_list.append(cls_pred)
|
437 |
+
iou_pred_list.append(iou_pred)
|
438 |
+
if num_frames > 0:
|
439 |
+
mask_pred = mask_pred.unflatten(2, (num_frames, -1))
|
440 |
+
mask_pred_list.append(mask_pred)
|
441 |
+
|
442 |
+
if self.use_adaptor:
|
443 |
+
keys = mask_features.flatten(2).transpose(1, 2).contiguous()
|
444 |
+
h, w = mask_features.shape[-2] // num_frames_real, mask_features.shape[-1]
|
445 |
+
mask = decoder_input.new_zeros((batch_size, num_frames_real, h, w), dtype=torch.bool)
|
446 |
+
key_pos = self.decoder_positional_encoding(mask)
|
447 |
+
key_pos = key_pos.transpose(1, 2).flatten(2).permute(0, 2, 1)
|
448 |
+
if not self.prompt_training:
|
449 |
+
object_kernels = self.panoptic_attn(query_feat, keys, key_pos=key_pos, query_pos=query_embed)
|
450 |
+
object_kernels = self.panoptic_norm(object_kernels)
|
451 |
+
mask_preds = torch.einsum('bnc,bchw->bnhw', object_kernels, mask_features)
|
452 |
+
|
453 |
+
cls_embd = self.panoptic_cls(object_kernels)
|
454 |
+
cls_scores = torch.einsum('bnc,ckp->bnkp', F.normalize(cls_embd, dim=-1), self.cls_embed)
|
455 |
+
cls_scores = cls_scores.max(-1).values
|
456 |
+
cls_scores = self.logit_scale.exp() * cls_scores
|
457 |
+
|
458 |
+
if num_frames > 0:
|
459 |
+
mask_pred_list.append(mask_preds.unflatten(2, (num_frames, -1)))
|
460 |
+
else:
|
461 |
+
mask_pred_list.append(mask_preds)
|
462 |
+
cls_pred_list.append(cls_scores)
|
463 |
+
iou_pred_list.append(iou_pred_list[-1])
|
464 |
+
else:
|
465 |
+
object_kernels = self.prompt_attn(query_feat, keys, key_pos=key_pos, query_pos=query_embed)
|
466 |
+
object_kernels = self.prompt_norm(object_kernels)
|
467 |
+
iou_preds = self.prompt_iou(object_kernels)
|
468 |
+
mask_preds = torch.einsum('bnc,bchw->bnhw', object_kernels, mask_features)
|
469 |
+
|
470 |
+
if num_frames > 0:
|
471 |
+
mask_pred_list.append(mask_preds.unflatten(2, (num_frames, -1)))
|
472 |
+
else:
|
473 |
+
mask_pred_list.append(mask_preds)
|
474 |
+
cls_pred_list.append(cls_pred_list[-1])
|
475 |
+
iou_pred_list.append(iou_preds)
|
476 |
+
|
477 |
+
return cls_pred_list, mask_pred_list, iou_pred_list, query_feat
|
478 |
+
|
479 |
+
def predict(self, x: Tuple[Tensor],
|
480 |
+
batch_data_samples: SampleList,
|
481 |
+
return_query=False,
|
482 |
+
) -> Tuple[Tensor, ...]:
|
483 |
+
"""Test without augmentaton.
|
484 |
+
|
485 |
+
Args:
|
486 |
+
return_query:
|
487 |
+
x (tuple[Tensor]): Multi-level features from the
|
488 |
+
upstream network, each is a 4D-tensor.
|
489 |
+
batch_data_samples (List[:obj:`DetDataSample`]): The Data
|
490 |
+
Samples. It usually includes information such as
|
491 |
+
`gt_instance`, `gt_panoptic_seg` and `gt_sem_seg`.
|
492 |
+
|
493 |
+
Returns:
|
494 |
+
tuple[Tensor]: A tuple contains two tensors.
|
495 |
+
|
496 |
+
- mask_cls_results (Tensor): Mask classification logits,\
|
497 |
+
shape (batch_size, num_queries, cls_out_channels).
|
498 |
+
Note `cls_out_channels` should includes background.
|
499 |
+
- mask_pred_results (Tensor): Mask logits, shape \
|
500 |
+
(batch_size, num_queries, h, w).
|
501 |
+
"""
|
502 |
+
self.prompt_training = False
|
503 |
+
data_sample = batch_data_samples[0]
|
504 |
+
if isinstance(data_sample, TrackDataSample):
|
505 |
+
img_shape = data_sample[0].metainfo['batch_input_shape']
|
506 |
+
num_frames = len(data_sample)
|
507 |
+
else:
|
508 |
+
if 'gt_instances_collected' in data_sample:
|
509 |
+
self.prompt_training = True
|
510 |
+
img_shape = data_sample.metainfo['batch_input_shape']
|
511 |
+
num_frames = 0
|
512 |
+
all_cls_scores, all_mask_preds, all_iou_preds, query_feat = self(x, batch_data_samples)
|
513 |
+
mask_cls_results = all_cls_scores[-1]
|
514 |
+
mask_pred_results = all_mask_preds[-1]
|
515 |
+
iou_results = all_iou_preds[-1]
|
516 |
+
|
517 |
+
if num_frames > 0:
|
518 |
+
mask_pred_results = mask_pred_results.flatten(1, 2)
|
519 |
+
mask_pred_results = F.interpolate(
|
520 |
+
mask_pred_results,
|
521 |
+
size=(img_shape[0], img_shape[1]),
|
522 |
+
mode='bilinear',
|
523 |
+
align_corners=False)
|
524 |
+
if num_frames > 0:
|
525 |
+
num_queries = mask_cls_results.shape[1]
|
526 |
+
mask_pred_results = mask_pred_results.unflatten(1, (num_queries, num_frames))
|
527 |
+
|
528 |
+
if return_query:
|
529 |
+
return mask_cls_results, mask_pred_results, query_feat, iou_results
|
530 |
+
else:
|
531 |
+
return mask_cls_results, mask_pred_results, iou_results
|
532 |
+
|
533 |
+
def prepare_for_dn_mo(self, batch_data_samples):
|
534 |
+
scalar, noise_scale = 100, 0.4
|
535 |
+
gt_instances = [t.gt_instances_collected for t in batch_data_samples]
|
536 |
+
|
537 |
+
point_coords = torch.stack([inst.point_coords for inst in gt_instances])
|
538 |
+
pb_labels = torch.stack([inst['pb_labels'] for inst in gt_instances])
|
539 |
+
labels = torch.zeros_like(pb_labels).long()
|
540 |
+
|
541 |
+
boxes = point_coords # + boxes
|
542 |
+
|
543 |
+
factors = []
|
544 |
+
for i, data_sample in enumerate(batch_data_samples):
|
545 |
+
h, w, = data_sample.metainfo['img_shape']
|
546 |
+
factor = boxes[i].new_tensor([w, h, w, h]).unsqueeze(0).repeat(boxes[i].size(0), 1)
|
547 |
+
factors.append(factor)
|
548 |
+
factors = torch.stack(factors, 0)
|
549 |
+
|
550 |
+
boxes = bbox_xyxy_to_cxcywh(boxes / factors) # xyxy / factor or xywh / factor ????
|
551 |
+
# box_start = [t['box_start'] for t in targets]
|
552 |
+
box_start = [len(point) for point in point_coords]
|
553 |
+
|
554 |
+
known_labels = labels
|
555 |
+
known_pb_labels = pb_labels
|
556 |
+
known_bboxs = boxes
|
557 |
+
|
558 |
+
known_labels_expaned = known_labels.clone()
|
559 |
+
known_pb_labels_expaned = known_pb_labels.clone()
|
560 |
+
known_bbox_expand = known_bboxs.clone()
|
561 |
+
|
562 |
+
if noise_scale > 0 and self.training:
|
563 |
+
diff = torch.zeros_like(known_bbox_expand)
|
564 |
+
diff[:, :, :2] = known_bbox_expand[:, :, 2:] / 2
|
565 |
+
diff[:, :, 2:] = known_bbox_expand[:, :, 2:]
|
566 |
+
# add very small noise to input points; no box
|
567 |
+
sc = 0.01
|
568 |
+
for i, st in enumerate(box_start):
|
569 |
+
diff[i, :st] = diff[i, :st] * sc
|
570 |
+
known_bbox_expand += torch.mul(
|
571 |
+
(torch.rand_like(known_bbox_expand) * 2 - 1.0),
|
572 |
+
diff) * noise_scale
|
573 |
+
|
574 |
+
known_bbox_expand = known_bbox_expand.clamp(min=0.0, max=1.0)
|
575 |
+
|
576 |
+
input_label_embed = self.pb_embedding(known_pb_labels_expaned)
|
577 |
+
|
578 |
+
input_bbox_embed = inverse_sigmoid(known_bbox_expand)
|
579 |
+
|
580 |
+
input_label_embed = input_label_embed.repeat_interleave(
|
581 |
+
self.num_mask_tokens,
|
582 |
+
1) + self.mask_tokens.weight.unsqueeze(0).repeat(
|
583 |
+
input_label_embed.shape[0], input_label_embed.shape[1], 1)
|
584 |
+
input_bbox_embed = input_bbox_embed.repeat_interleave(
|
585 |
+
self.num_mask_tokens, 1)
|
586 |
+
|
587 |
+
single_pad = self.num_mask_tokens
|
588 |
+
|
589 |
+
# NOTE scalar is modified to 100, each click cannot see each other
|
590 |
+
scalar = int(input_label_embed.shape[1] / self.num_mask_tokens)
|
591 |
+
|
592 |
+
pad_size = input_label_embed.shape[1]
|
593 |
+
|
594 |
+
if input_label_embed.shape[1] > 0:
|
595 |
+
input_query_label = input_label_embed
|
596 |
+
input_query_bbox = input_bbox_embed
|
597 |
+
|
598 |
+
tgt_size = pad_size
|
599 |
+
attn_mask = torch.ones(tgt_size, tgt_size).to('cuda') < 0
|
600 |
+
# match query cannot see the reconstruct
|
601 |
+
attn_mask[pad_size:, :pad_size] = True
|
602 |
+
# reconstruct cannot see each other
|
603 |
+
for i in range(scalar):
|
604 |
+
if i == 0:
|
605 |
+
attn_mask[single_pad * i:single_pad * (i + 1), single_pad * (i + 1):pad_size] = True
|
606 |
+
if i == scalar - 1:
|
607 |
+
attn_mask[single_pad * i:single_pad * (i + 1), :single_pad * i] = True
|
608 |
+
else:
|
609 |
+
attn_mask[single_pad * i:single_pad * (i + 1), single_pad * (i + 1):pad_size] = True
|
610 |
+
attn_mask[single_pad * i:single_pad * (i + 1), :single_pad * i] = True
|
611 |
+
mask_dict = {
|
612 |
+
'known_lbs_bboxes': (known_labels, known_bboxs),
|
613 |
+
'pad_size': pad_size,
|
614 |
+
'scalar': scalar,
|
615 |
+
}
|
616 |
+
return input_query_label, input_query_bbox, attn_mask, mask_dict
|
app/models/heads/rapsam_head.py
ADDED
@@ -0,0 +1,227 @@
# Copyright (c) OpenMMLab. All rights reserved.
from typing import List, Tuple

import torch
import torch.nn as nn
import torch.nn.functional as F
from torch import Tensor
from mmdet.models.layers import coordinate_to_encoding

from mmdet.registry import MODELS, TASK_UTILS
from mmdet.structures import SampleList, TrackDataSample
from mmdet.utils import (ConfigType, OptConfigType, OptMultiConfig)
from mmdet.models.dense_heads.anchor_free_head import AnchorFreeHead

from mmcv.cnn.bricks.transformer import MultiheadAttention

from .mask2former_vid import Mask2FormerVideoHead
from .yoso_head import CrossAttenHead, KernelUpdator


@MODELS.register_module()
class RapSAMVideoHead(Mask2FormerVideoHead):

    def __init__(self,
                 frozen_head=False,
                 frozen_pred=False,
                 use_adaptor=False,
                 prompt_with_kernel_updator=False,
                 panoptic_with_kernel_updator=False,
                 num_mask_tokens=1,
                 num_stages=3,
                 use_kernel_updator=False,
                 sphere_cls=False,
                 ov_classifier_name=None,
                 temperature=0.1,
                 feat_channels=256,
                 num_things_classes: int = 80,
                 num_stuff_classes: int = 53,
                 num_queries: int = 100,
                 loss_cls: ConfigType = None,
                 loss_mask: ConfigType = None,
                 loss_dice: ConfigType = None,
                 train_cfg: OptConfigType = None,
                 test_cfg: OptConfigType = None,
                 init_cfg: OptMultiConfig = None,
                 matching_whole_map: bool = False,
                 enable_box_query: bool = False,
                 **kwargs) -> None:
        super(AnchorFreeHead, self).__init__(init_cfg=init_cfg)
        self.prompt_with_kernel_updator = prompt_with_kernel_updator
        self.panoptic_with_kernel_updator = panoptic_with_kernel_updator
        self.use_adaptor = use_adaptor

        self.num_mask_tokens = num_mask_tokens
        self.mask_tokens = nn.Embedding(num_mask_tokens, feat_channels)
        self.pb_embedding = nn.Embedding(2, feat_channels)
        self.pos_linear = nn.Linear(2 * feat_channels, feat_channels)

        self.matching_whole_map = matching_whole_map
        self.enable_box_query = enable_box_query

        self.num_things_classes = num_things_classes
        self.num_stuff_classes = num_stuff_classes
        self.num_classes = self.num_things_classes + self.num_stuff_classes
        self.num_queries = num_queries
        self.feat_channels = feat_channels
        self.num_stages = num_stages
        self.kernels = nn.Embedding(self.num_queries, feat_channels)
        self.mask_heads = nn.ModuleList()
        for _ in range(self.num_stages):
            self.mask_heads.append(CrossAttenHead(
                self.num_classes, self.feat_channels, self.num_queries,
                use_kernel_updator=use_kernel_updator,
                frozen_head=frozen_head, frozen_pred=frozen_pred,
                sphere_cls=sphere_cls,
                ov_classifier_name=ov_classifier_name, with_iou_pred=True))
        self.temperature = temperature

        if use_adaptor:
            cross_attn_cfg = dict(embed_dims=256, batch_first=True, num_heads=8)
            if self.panoptic_with_kernel_updator:
                self.panoptic_attn = KernelUpdator(feat_channels=256)
                self.panoptic_norm = nn.Identity()
                if sphere_cls:
                    cls_embed_dim = self.mask_heads[0].fc_cls.size(0)
                    self.panoptic_cls = nn.Sequential(
                        nn.Linear(feat_channels, cls_embed_dim)
                    )
                else:
                    raise NotImplementedError
                    self.panoptic_cls = nn.Linear(256, self.num_classes + 1)
            else:
                self.panoptic_attn = MultiheadAttention(**cross_attn_cfg)
                self.panoptic_norm = nn.LayerNorm(256)
                if sphere_cls:
                    cls_embed_dim = self.mask_heads[0].fc_cls.size(0)
                    self.panoptic_cls = nn.Sequential(
                        nn.Linear(feat_channels, cls_embed_dim)
                    )
                else:
                    raise NotImplementedError
                    self.panoptic_cls = nn.Linear(256, self.num_classes + 1)

            if self.prompt_with_kernel_updator:
                self.prompt_attn = KernelUpdator(feat_channels=256)
                self.prompt_norm = nn.Identity()
                self.prompt_iou = nn.Linear(256, 1)
            else:
                self.prompt_attn = MultiheadAttention(**cross_attn_cfg)
                self.prompt_norm = nn.LayerNorm(256)
                self.prompt_iou = nn.Linear(256, 1)

        self.test_cfg = test_cfg
        self.train_cfg = train_cfg
        if train_cfg:
            self.assigner = TASK_UTILS.build(self.train_cfg['assigner'])
            self.sampler = TASK_UTILS.build(
                self.train_cfg['sampler'], default_args=dict(context=self))
            self.num_points = self.train_cfg.get('num_points', 12544)
            self.oversample_ratio = self.train_cfg.get('oversample_ratio', 3.0)
            self.importance_sample_ratio = self.train_cfg.get(
                'importance_sample_ratio', 0.75)

        self.class_weight = loss_cls.class_weight
        self.loss_cls = MODELS.build(loss_cls)
        self.loss_mask = MODELS.build(loss_mask)
        self.loss_dice = MODELS.build(loss_dice)

    def init_weights(self) -> None:
        pass

    def forward(self, x, batch_data_samples: SampleList) -> Tuple[List[Tensor]]:
        batch_img_metas = []
        if isinstance(batch_data_samples[0], TrackDataSample):
            for track_sample in batch_data_samples:
                cur_list = []
                for det_sample in track_sample:
                    cur_list.append(det_sample.metainfo)
                batch_img_metas.append(cur_list)
            num_frames = len(batch_img_metas[0])
        else:
            for data_sample in batch_data_samples:
                batch_img_metas.append(data_sample.metainfo)
            num_frames = 0
        bs = len(batch_img_metas)

        all_cls_scores = []
        all_masks_preds = []
        all_iou_preds = []
        if self.prompt_training:
            input_query_label, input_query_bbox, self_attn_mask, mask_dict = self.prepare_for_dn_mo(
                batch_data_samples)
            pos_embed = coordinate_to_encoding(input_query_bbox.sigmoid())
            pos_embed = self.pos_linear(pos_embed)
            object_kernels = input_query_label + pos_embed
        else:
            object_kernels = self.kernels.weight[None].repeat(bs, 1, 1)
            self_attn_mask = None
        mask_features = x
        if num_frames > 0:  # (bs*num_frames, c, h, w) -> (bs, c, num_frames*h, w)
            mask_features = mask_features.unflatten(0, (bs, num_frames))
            mask_features = mask_features.transpose(1, 2).flatten(2, 3)

        mask_preds = torch.einsum('bnc,bchw->bnhw', object_kernels, mask_features)
        for stage in range(self.num_stages):
            mask_head = self.mask_heads[stage]
            cls_scores, mask_preds, iou_preds, object_kernels = mask_head(
                mask_features, object_kernels, mask_preds, self_attn_mask)
            cls_scores = cls_scores / self.temperature
            all_iou_preds.append(iou_preds)
            all_cls_scores.append(cls_scores)
            if num_frames > 0:
                # (bs, num_query, num_frames*h, w) --> (bs, num_query, num_frames, h, w)
                all_masks_preds.append(mask_preds.unflatten(2, (num_frames, -1)))
            else:
                all_masks_preds.append(mask_preds)

        if self.use_adaptor:
            keys = mask_features.flatten(2).transpose(1, 2).contiguous()
            if not self.prompt_training:
                if self.panoptic_with_kernel_updator:
                    hard_sigmoid_masks = (mask_preds.sigmoid() > 0.5).float()
                    f = torch.einsum('bnhw,bchw->bnc', hard_sigmoid_masks, mask_features)
                    object_kernels = self.panoptic_attn(f, object_kernels)
                    object_kernels = self.panoptic_norm(object_kernels)
                    mask_preds = torch.einsum('bnc,bchw->bnhw', object_kernels, mask_features)
                else:
                    object_kernels = self.panoptic_attn(object_kernels, keys)
                    object_kernels = self.panoptic_norm(object_kernels)
                    mask_preds = torch.einsum('bnc,bchw->bnhw', object_kernels, mask_features)
                cls_embd = self.panoptic_cls(object_kernels)
                cls_scores = torch.einsum('bnc,ckp->bnkp', F.normalize(cls_embd, dim=-1), self.mask_heads[0].fc_cls)
                cls_scores = cls_scores.max(-1).values
                cls_scores = self.mask_heads[0].logit_scale.exp() * cls_scores

                if num_frames > 0:
                    all_masks_preds.append(mask_preds.unflatten(2, (num_frames, -1)))
                else:
                    all_masks_preds.append(mask_preds)
                all_cls_scores.append(cls_scores)
                all_iou_preds.append(all_iou_preds[-1])
            else:
                if self.prompt_with_kernel_updator:
                    hard_sigmoid_masks = (mask_preds.sigmoid() > 0.5).float()
                    f = torch.einsum('bnhw,bchw->bnc', hard_sigmoid_masks, mask_features)
                    object_kernels = self.prompt_attn(f, object_kernels)
                    object_kernels = self.prompt_norm(object_kernels)
                    iou_preds = self.prompt_iou(object_kernels)
                    mask_preds = torch.einsum('bnc,bchw->bnhw', object_kernels, mask_features)
                else:
                    object_kernels = self.prompt_attn(object_kernels, keys)
                    object_kernels = self.prompt_norm(object_kernels)
                    iou_preds = self.prompt_iou(object_kernels)
                    mask_preds = torch.einsum('bnc,bchw->bnhw', object_kernels, mask_features)
                if num_frames > 0:
                    all_masks_preds.append(mask_preds.unflatten(2, (num_frames, -1)))
                else:
                    all_masks_preds.append(mask_preds)
                all_cls_scores.append(all_cls_scores[-1])
                all_iou_preds.append(iou_preds)
        return all_cls_scores, all_masks_preds, all_iou_preds, object_kernels

    def get_targets(self, *args, **kwargs):
        raise NotImplementedError

    def loss_by_feat(self, *args, **kwargs):
        raise NotImplementedError
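A minimal sketch (shapes are illustrative assumptions, not from the commit) of the dynamic-kernel mask prediction used throughout this head: each query kernel is dotted with every pixel embedding to produce one mask-logit map per query.

import torch

object_kernels = torch.randn(2, 100, 256)    # (B, N, C) query kernels
mask_features = torch.randn(2, 256, 64, 64)  # (B, C, H, W) pixel features
mask_preds = torch.einsum('bnc,bchw->bnhw', object_kernels, mask_features)
print(mask_preds.shape)  # torch.Size([2, 100, 64, 64])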
app/models/heads/yoso_head.py
ADDED
@@ -0,0 +1,531 @@
1 |
+
from typing import List, Tuple
|
2 |
+
import os
|
3 |
+
import torch.distributed as dist
|
4 |
+
from torch import Tensor
|
5 |
+
from mmdet.registry import MODELS, TASK_UTILS
|
6 |
+
from mmdet.models.dense_heads import AnchorFreeHead
|
7 |
+
from mmdet.structures import SampleList
|
8 |
+
from mmdet.models.dense_heads import Mask2FormerHead
|
9 |
+
import math
|
10 |
+
from mmengine.model.weight_init import trunc_normal_
|
11 |
+
import torch
|
12 |
+
from torch import nn
|
13 |
+
import torch.nn.functional as F
|
14 |
+
from mmcv.cnn import build_activation_layer, build_norm_layer
|
15 |
+
|
16 |
+
from mmengine.dist import get_dist_info
|
17 |
+
|
18 |
+
|
19 |
+
@MODELS.register_module()
|
20 |
+
class YOSOHead(Mask2FormerHead):
|
21 |
+
def __init__(self,
|
22 |
+
num_cls_fcs=1,
|
23 |
+
num_mask_fcs=1,
|
24 |
+
sphere_cls=False,
|
25 |
+
ov_classifier_name=None,
|
26 |
+
use_kernel_updator=False,
|
27 |
+
num_stages=3,
|
28 |
+
feat_channels=256,
|
29 |
+
out_channels=256,
|
30 |
+
num_things_classes=80,
|
31 |
+
num_stuff_classes=53,
|
32 |
+
num_classes=133,
|
33 |
+
num_queries=100,
|
34 |
+
temperature=0.1,
|
35 |
+
loss_cls=dict(
|
36 |
+
type='CrossEntropyLoss',
|
37 |
+
use_sigmoid=False,
|
38 |
+
loss_weight=2.0,
|
39 |
+
reduction='mean',
|
40 |
+
class_weight=[1.0] * 133 + [0.1]),
|
41 |
+
loss_mask=dict(
|
42 |
+
type='CrossEntropyLoss',
|
43 |
+
use_sigmoid=True,
|
44 |
+
reduction='mean',
|
45 |
+
loss_weight=5.0),
|
46 |
+
loss_dice=dict(
|
47 |
+
type='DiceLoss',
|
48 |
+
use_sigmoid=True,
|
49 |
+
activate=True,
|
50 |
+
reduction='mean',
|
51 |
+
naive_dice=True,
|
52 |
+
eps=1.0,
|
53 |
+
loss_weight=5.0),
|
54 |
+
train_cfg=None,
|
55 |
+
test_cfg=None,
|
56 |
+
init_cfg=None):
|
57 |
+
super(AnchorFreeHead, self).__init__(init_cfg=init_cfg)
|
58 |
+
self.num_stages = num_stages
|
59 |
+
self.feat_channels = feat_channels
|
60 |
+
self.out_channels = out_channels
|
61 |
+
self.num_things_classes = num_things_classes
|
62 |
+
self.num_stuff_classes = num_stuff_classes
|
63 |
+
self.num_classes = num_classes
|
64 |
+
self.num_queries = num_queries
|
65 |
+
self.temperature = temperature
|
66 |
+
|
67 |
+
self.test_cfg = test_cfg
|
68 |
+
self.train_cfg = train_cfg
|
69 |
+
if train_cfg:
|
70 |
+
self.assigner = TASK_UTILS.build(self.train_cfg['assigner'])
|
71 |
+
self.sampler = TASK_UTILS.build(
|
72 |
+
self.train_cfg['sampler'], default_args=dict(context=self))
|
73 |
+
self.num_points = self.train_cfg.get('num_points', 12544)
|
74 |
+
self.oversample_ratio = self.train_cfg.get('oversample_ratio', 3.0)
|
75 |
+
self.importance_sample_ratio = self.train_cfg.get(
|
76 |
+
'importance_sample_ratio', 0.75)
|
77 |
+
|
78 |
+
self.class_weight = loss_cls.class_weight
|
79 |
+
self.loss_cls = MODELS.build(loss_cls)
|
80 |
+
self.loss_mask = MODELS.build(loss_mask)
|
81 |
+
self.loss_dice = MODELS.build(loss_dice)
|
82 |
+
|
83 |
+
self.kernels = nn.Embedding(self.num_queries, self.feat_channels)
|
84 |
+
|
85 |
+
self.mask_heads = nn.ModuleList()
|
86 |
+
for _ in range(self.num_stages):
|
87 |
+
self.mask_heads.append(CrossAttenHead(
|
88 |
+
self.num_classes, self.feat_channels, self.num_queries,
|
89 |
+
use_kernel_updator=use_kernel_updator,
|
90 |
+
sphere_cls=sphere_cls, ov_classifier_name=ov_classifier_name,
|
91 |
+
num_cls_fcs=num_cls_fcs, num_mask_fcs=num_mask_fcs
|
92 |
+
))
|
93 |
+
|
94 |
+
def init_weights(self) -> None:
|
95 |
+
super(AnchorFreeHead, self).init_weights()
|
96 |
+
|
97 |
+
def forward(self, x: List[Tensor],
|
98 |
+
batch_data_samples: SampleList) -> Tuple[List[Tensor]]:
|
99 |
+
all_cls_scores = []
|
100 |
+
all_masks_preds = []
|
101 |
+
proposal_kernels = self.kernels.weight
|
102 |
+
object_kernels = proposal_kernels[None].repeat(x.shape[0], 1, 1)
|
103 |
+
mask_preds = torch.einsum('bnc,bchw->bnhw', object_kernels, x)
|
104 |
+
|
105 |
+
for stage in range(self.num_stages):
|
106 |
+
mask_head = self.mask_heads[stage]
|
107 |
+
cls_scores, mask_preds, iou_pred, object_kernels = mask_head(x, object_kernels, mask_preds)
|
108 |
+
cls_scores = cls_scores / self.temperature
|
109 |
+
|
110 |
+
all_cls_scores.append(cls_scores)
|
111 |
+
all_masks_preds.append(mask_preds)
|
112 |
+
|
113 |
+
return all_cls_scores, all_masks_preds
|
114 |
+
|
115 |
+
def predict(self, x: Tuple[Tensor], batch_data_samples: SampleList) -> Tuple[Tensor]:
|
116 |
+
batch_img_metas = [
|
117 |
+
data_sample.metainfo for data_sample in batch_data_samples
|
118 |
+
]
|
119 |
+
all_cls_scores, all_mask_preds = self(x, batch_data_samples)
|
120 |
+
mask_cls_results = all_cls_scores[-1]
|
121 |
+
mask_pred_results = all_mask_preds[-1]
|
122 |
+
|
123 |
+
# upsample masks
|
124 |
+
img_shape = batch_img_metas[0]['batch_input_shape']
|
125 |
+
mask_pred_results = F.interpolate(
|
126 |
+
mask_pred_results,
|
127 |
+
size=(img_shape[0], img_shape[1]),
|
128 |
+
mode='bilinear',
|
129 |
+
align_corners=False)
|
130 |
+
|
131 |
+
return mask_cls_results, mask_pred_results
|
132 |
+
|
133 |
+
|
134 |
+
class FFN(nn.Module):
|
135 |
+
|
136 |
+
def __init__(self,
|
137 |
+
embed_dims=256,
|
138 |
+
feedforward_channels=1024,
|
139 |
+
num_fcs=2,
|
140 |
+
add_identity=True):
|
141 |
+
super(FFN, self).__init__()
|
142 |
+
self.embed_dims = embed_dims
|
143 |
+
self.feedforward_channels = feedforward_channels
|
144 |
+
self.num_fcs = num_fcs
|
145 |
+
|
146 |
+
layers = []
|
147 |
+
in_channels = embed_dims
|
148 |
+
for _ in range(num_fcs - 1):
|
149 |
+
layers.append(nn.Sequential(
|
150 |
+
nn.Linear(in_channels, feedforward_channels),
|
151 |
+
nn.ReLU(True),
|
152 |
+
nn.Dropout(0.0)))
|
153 |
+
in_channels = feedforward_channels
|
154 |
+
layers.append(nn.Linear(feedforward_channels, embed_dims))
|
155 |
+
layers.append(nn.Dropout(0.0))
|
156 |
+
self.layers = nn.Sequential(*layers)
|
157 |
+
self.add_identity = add_identity
|
158 |
+
self.dropout_layer = nn.Dropout(0.0)
|
159 |
+
|
160 |
+
def forward(self, x, identity=None):
|
161 |
+
out = self.layers(x)
|
162 |
+
if not self.add_identity:
|
163 |
+
return self.dropout_layer(out)
|
164 |
+
if identity is None:
|
165 |
+
identity = x
|
166 |
+
return identity + self.dropout_layer(out)
|
167 |
+
|
168 |
+
|
169 |
+
class DySepConvAtten(nn.Module):
|
170 |
+
def __init__(self, hidden_dim, num_proposals, conv_kernel_size_1d):
|
171 |
+
super(DySepConvAtten, self).__init__()
|
172 |
+
self.hidden_dim = hidden_dim
|
173 |
+
self.num_proposals = num_proposals
|
174 |
+
self.kernel_size = conv_kernel_size_1d
|
175 |
+
|
176 |
+
self.weight_linear = nn.Linear(self.hidden_dim, self.num_proposals + self.kernel_size)
|
177 |
+
self.norm = nn.LayerNorm(self.hidden_dim)
|
178 |
+
|
179 |
+
def forward(self, query, value):
|
180 |
+
assert query.shape == value.shape
|
181 |
+
B, N, C = query.shape
|
182 |
+
|
183 |
+
dy_conv_weight = self.weight_linear(query)
|
184 |
+
dy_depth_conv_weight = dy_conv_weight[:, :, :self.kernel_size].view(B, self.num_proposals, 1, self.kernel_size)
|
185 |
+
dy_point_conv_weight = dy_conv_weight[:, :, self.kernel_size:].view(B, self.num_proposals, self.num_proposals,
|
186 |
+
1)
|
187 |
+
|
188 |
+
res = []
|
189 |
+
value = value.unsqueeze(1)
|
190 |
+
for i in range(B):
|
191 |
+
out = F.relu(F.conv1d(input=value[i], weight=dy_depth_conv_weight[i], groups=N, padding='same'))
|
192 |
+
out = F.conv1d(input=out, weight=dy_point_conv_weight[i], padding='same')
|
193 |
+
res.append(out)
|
194 |
+
|
195 |
+
point_out = torch.cat(res, dim=0)
|
196 |
+
point_out = self.norm(point_out)
|
197 |
+
return point_out
|
198 |
+
|
199 |
+
|
200 |
+
class KernelUpdator(nn.Module):
|
201 |
+
|
202 |
+
def __init__(self,
|
203 |
+
in_channels=256,
|
204 |
+
feat_channels=64,
|
205 |
+
out_channels=None,
|
206 |
+
input_feat_shape=3,
|
207 |
+
gate_sigmoid=True,
|
208 |
+
gate_norm_act=False,
|
209 |
+
activate_out=False,
|
210 |
+
act_cfg=dict(type='ReLU', inplace=True),
|
211 |
+
norm_cfg=dict(type='LN')):
|
212 |
+
super(KernelUpdator, self).__init__()
|
213 |
+
self.in_channels = in_channels
|
214 |
+
self.feat_channels = feat_channels
|
215 |
+
self.out_channels_raw = out_channels
|
216 |
+
self.gate_sigmoid = gate_sigmoid
|
217 |
+
self.gate_norm_act = gate_norm_act
|
218 |
+
self.activate_out = activate_out
|
219 |
+
if isinstance(input_feat_shape, int):
|
220 |
+
input_feat_shape = [input_feat_shape] * 2
|
221 |
+
self.input_feat_shape = input_feat_shape
|
222 |
+
self.act_cfg = act_cfg
|
223 |
+
self.norm_cfg = norm_cfg
|
224 |
+
self.out_channels = out_channels if out_channels else in_channels
|
225 |
+
|
226 |
+
self.num_params_in = self.feat_channels
|
227 |
+
self.num_params_out = self.feat_channels
|
228 |
+
self.dynamic_layer = nn.Linear(
|
229 |
+
self.in_channels, self.num_params_in + self.num_params_out)
|
230 |
+
self.input_layer = nn.Linear(self.in_channels,
|
231 |
+
self.num_params_in + self.num_params_out,
|
232 |
+
1)
|
233 |
+
self.input_gate = nn.Linear(self.in_channels, self.feat_channels, 1)
|
234 |
+
self.update_gate = nn.Linear(self.in_channels, self.feat_channels, 1)
|
235 |
+
if self.gate_norm_act:
|
236 |
+
self.gate_norm = build_norm_layer(norm_cfg, self.feat_channels)[1]
|
237 |
+
|
238 |
+
self.norm_in = build_norm_layer(norm_cfg, self.feat_channels)[1]
|
239 |
+
self.norm_out = build_norm_layer(norm_cfg, self.feat_channels)[1]
|
240 |
+
self.input_norm_in = build_norm_layer(norm_cfg, self.feat_channels)[1]
|
241 |
+
self.input_norm_out = build_norm_layer(norm_cfg, self.feat_channels)[1]
|
242 |
+
|
243 |
+
self.activation = build_activation_layer(act_cfg)
|
244 |
+
|
245 |
+
self.fc_layer = nn.Linear(self.feat_channels, self.out_channels, 1)
|
246 |
+
self.fc_norm = build_norm_layer(norm_cfg, self.out_channels)[1]
|
247 |
+
|
248 |
+
def forward(self, update_feature, input_feature):
|
249 |
+
"""
|
250 |
+
Args:
|
251 |
+
update_feature (torch.Tensor): [bs, num_proposals, in_channels]
|
252 |
+
input_feature (torch.Tensor): [bs, num_proposals, in_channels]
|
253 |
+
"""
|
254 |
+
bs, num_proposals, _ = update_feature.shape
|
255 |
+
|
256 |
+
parameters = self.dynamic_layer(update_feature)
|
257 |
+
param_in = parameters[..., :self.num_params_in]
|
258 |
+
param_out = parameters[..., -self.num_params_out:]
|
259 |
+
|
260 |
+
input_feats = self.input_layer(input_feature)
|
261 |
+
input_in = input_feats[..., :self.num_params_in]
|
262 |
+
input_out = input_feats[..., -self.num_params_out:]
|
263 |
+
|
264 |
+
gate_feats = input_in * param_in
|
265 |
+
if self.gate_norm_act:
|
266 |
+
gate_feats = self.activation(self.gate_norm(gate_feats))
|
267 |
+
|
268 |
+
input_gate = self.input_norm_in(self.input_gate(gate_feats))
|
269 |
+
update_gate = self.norm_in(self.update_gate(gate_feats))
|
270 |
+
if self.gate_sigmoid:
|
271 |
+
input_gate = input_gate.sigmoid()
|
272 |
+
update_gate = update_gate.sigmoid()
|
273 |
+
param_out = self.norm_out(param_out)
|
274 |
+
input_out = self.input_norm_out(input_out)
|
275 |
+
|
276 |
+
if self.activate_out:
|
277 |
+
param_out = self.activation(param_out)
|
278 |
+
input_out = self.activation(input_out)
|
279 |
+
|
280 |
+
# param_out has shape (batch_size, feat_channels, out_channels)
|
281 |
+
features = update_gate * param_out + input_gate * input_out
|
282 |
+
|
283 |
+
features = self.fc_layer(features)
|
284 |
+
features = self.fc_norm(features)
|
285 |
+
features = self.activation(features)
|
286 |
+
|
287 |
+
return features
|
288 |
+
|
289 |
+
|
290 |
+
class CrossAttenHead(nn.Module):
|
291 |
+
|
292 |
+
def __init__(self,
|
293 |
+
num_classes,
|
294 |
+
in_channels,
|
295 |
+
num_proposals,
|
296 |
+
frozen_head=False,
|
297 |
+
frozen_pred=False,
|
298 |
+
with_iou_pred=False,
|
299 |
+
sphere_cls=False,
|
300 |
+
ov_classifier_name=None,
|
301 |
+
num_cls_fcs=1,
|
302 |
+
num_mask_fcs=1,
|
303 |
+
conv_kernel_size_1d=3,
|
304 |
+
conv_kernel_size_2d=1,
|
305 |
+
use_kernel_updator=False):
|
306 |
+
super(CrossAttenHead, self).__init__()
|
307 |
+
self.sphere_cls = sphere_cls
|
308 |
+
self.with_iou_pred = with_iou_pred
|
309 |
+
self.frozen_head = frozen_head
|
310 |
+
self.frozen_pred = frozen_pred
|
311 |
+
self.num_cls_fcs = num_cls_fcs
|
312 |
+
self.num_mask_fcs = num_mask_fcs
|
313 |
+
self.num_classes = num_classes
|
314 |
+
self.conv_kernel_size_2d = conv_kernel_size_2d
|
315 |
+
|
316 |
+
self.hidden_dim = in_channels
|
317 |
+
self.feat_channels = in_channels
|
318 |
+
self.num_proposals = num_proposals
|
319 |
+
self.hard_mask_thr = 0.5
|
320 |
+
self.use_kernel_updator = use_kernel_updator
|
321 |
+
# assert use_kernel_updator
|
322 |
+
if use_kernel_updator:
|
323 |
+
self.kernel_update = KernelUpdator(
|
324 |
+
in_channels=256,
|
325 |
+
feat_channels=256,
|
326 |
+
out_channels=256,
|
327 |
+
input_feat_shape=3,
|
328 |
+
act_cfg=dict(type='ReLU', inplace=True),
|
329 |
+
norm_cfg=dict(type='LN')
|
330 |
+
)
|
331 |
+
else:
|
332 |
+
self.f_atten = DySepConvAtten(self.feat_channels, self.num_proposals, conv_kernel_size_1d)
|
333 |
+
self.f_dropout = nn.Dropout(0.0)
|
334 |
+
self.f_atten_norm = nn.LayerNorm(self.hidden_dim * self.conv_kernel_size_2d ** 2)
|
335 |
+
self.k_atten = DySepConvAtten(self.feat_channels, self.num_proposals, conv_kernel_size_1d)
|
336 |
+
self.k_dropout = nn.Dropout(0.0)
|
337 |
+
self.k_atten_norm = nn.LayerNorm(self.hidden_dim * self.conv_kernel_size_2d ** 2)
|
338 |
+
|
339 |
+
self.s_atten = nn.MultiheadAttention(embed_dim=self.hidden_dim *
|
340 |
+
self.conv_kernel_size_2d ** 2,
|
341 |
+
num_heads=8,
|
342 |
+
dropout=0.0)
|
343 |
+
self.s_dropout = nn.Dropout(0.0)
|
344 |
+
self.s_atten_norm = nn.LayerNorm(self.hidden_dim * self.conv_kernel_size_2d ** 2)
|
345 |
+
|
346 |
+
self.ffn = FFN(self.hidden_dim, feedforward_channels=2048, num_fcs=2)
|
347 |
+
self.ffn_norm = nn.LayerNorm(self.hidden_dim)
|
348 |
+
|
349 |
+
self.cls_fcs = nn.ModuleList()
|
350 |
+
for _ in range(self.num_cls_fcs):
|
351 |
+
self.cls_fcs.append(nn.Linear(self.hidden_dim, self.hidden_dim, bias=False))
|
352 |
+
self.cls_fcs.append(nn.LayerNorm(self.hidden_dim))
|
353 |
+
self.cls_fcs.append(nn.ReLU(True))
|
354 |
+
|
355 |
+
if sphere_cls:
|
356 |
+
rank, world_size = get_dist_info()
|
357 |
+
if ov_classifier_name is None:
|
358 |
+
_dim = 1024 # temporally hard code
|
359 |
+
cls_embed = torch.empty(self.num_classes, _dim)
|
360 |
+
torch.nn.init.orthogonal_(cls_embed)
|
361 |
+
cls_embed = cls_embed[:, None]
|
362 |
+
else:
|
363 |
+
ov_path = os.path.join(os.path.expanduser('~/.cache/embd'), f"{ov_classifier_name}.pth")
|
364 |
+
cls_embed = torch.load(ov_path)
|
365 |
+
cls_embed_norm = cls_embed.norm(p=2, dim=-1)
|
366 |
+
assert torch.allclose(cls_embed_norm, torch.ones_like(cls_embed_norm))
|
367 |
+
|
368 |
+
# background class
|
369 |
+
_dim = cls_embed.size(2)
|
370 |
+
_prototypes = cls_embed.size(1)
|
371 |
+
if rank == 0:
|
372 |
+
back_token = torch.zeros(1, _dim, dtype=torch.float32, device='cuda')
|
373 |
+
else:
|
374 |
+
back_token = torch.empty(1, _dim, dtype=torch.float32, device='cuda')
|
375 |
+
if world_size > 1:
|
376 |
+
dist.broadcast(back_token, src=0)
|
377 |
+
back_token = back_token.to(device='cpu')
|
378 |
+
cls_embed = torch.cat([
|
379 |
+
cls_embed, back_token.repeat(_prototypes, 1)[None]
|
380 |
+
], dim=0)
|
381 |
+
self.register_buffer('fc_cls', cls_embed.permute(2, 0, 1).contiguous(), persistent=False)
|
382 |
+
|
383 |
+
# cls embd proj
|
384 |
+
cls_embed_dim = self.fc_cls.size(0)
|
385 |
+
self.cls_proj = nn.Sequential(
|
386 |
+
nn.Linear(self.hidden_dim, self.hidden_dim), nn.ReLU(inplace=True),
|
387 |
+
nn.Linear(self.hidden_dim, self.hidden_dim), nn.ReLU(inplace=True),
|
388 |
+
nn.Linear(self.hidden_dim, cls_embed_dim)
|
389 |
+
)
|
390 |
+
|
391 |
+
logit_scale = torch.tensor(4.6052, dtype=torch.float32)
|
392 |
+
self.register_buffer('logit_scale', logit_scale, persistent=False)
|
393 |
+
else:
|
394 |
+
self.fc_cls = nn.Linear(self.hidden_dim, self.num_classes + 1)
|
395 |
+
|
396 |
+
self.mask_fcs = nn.ModuleList()
|
397 |
+
for _ in range(self.num_mask_fcs):
|
398 |
+
self.mask_fcs.append(nn.Linear(self.hidden_dim, self.hidden_dim, bias=False))
|
399 |
+
self.mask_fcs.append(nn.LayerNorm(self.hidden_dim))
|
400 |
+
self.mask_fcs.append(nn.ReLU(True))
|
401 |
+
self.fc_mask = nn.Linear(self.hidden_dim, self.hidden_dim)
|
402 |
+
|
403 |
+
if self.with_iou_pred:
|
404 |
+
self.iou_embed = nn.Sequential(
|
405 |
+
nn.Linear(self.hidden_dim, self.hidden_dim),
|
406 |
+
nn.ReLU(inplace=True),
|
407 |
+
nn.Linear(self.hidden_dim, self.hidden_dim),
|
408 |
+
nn.ReLU(inplace=True),
|
409 |
+
nn.Linear(self.hidden_dim, 1),
|
410 |
+
)
|
411 |
+
prior_prob = 0.01
|
412 |
+
self.bias_value = -math.log((1 - prior_prob) / prior_prob)
|
413 |
+
|
414 |
+
self.apply(self._init_weights)
|
415 |
+
if not sphere_cls:
|
416 |
+
nn.init.constant_(self.fc_cls.bias, self.bias_value)
|
417 |
+
|
418 |
+
if self.frozen_head:
|
419 |
+
self._frozen_head()
|
420 |
+
if self.frozen_pred:
|
421 |
+
self._frozen_pred()
|
422 |
+
|
423 |
+
def _init_weights(self, m):
|
424 |
+
# print("init weights")
|
425 |
+
if isinstance(m, nn.Linear):
|
426 |
+
trunc_normal_(m.weight, std=.02)
|
427 |
+
if isinstance(m, nn.Linear) and m.bias is not None:
|
428 |
+
nn.init.constant_(m.bias, 0)
|
429 |
+
elif isinstance(m, nn.LayerNorm):
|
430 |
+
nn.init.constant_(m.bias, 0)
|
431 |
+
nn.init.constant_(m.weight, 1.0)
|
432 |
+
|
433 |
+
def _frozen_head(self):
|
434 |
+
for n, p in self.kernel_update.named_parameters():
|
435 |
+
p.requires_grad = False
|
436 |
+
for n, p in self.s_atten.named_parameters():
|
437 |
+
p.requires_grad = False
|
438 |
+
for n, p in self.s_dropout.named_parameters():
|
439 |
+
p.requires_grad = False
|
440 |
+
for n, p in self.s_atten_norm.named_parameters():
|
441 |
+
p.requires_grad = False
|
442 |
+
for n, p in self.ffn.named_parameters():
|
443 |
+
p.requires_grad = False
|
444 |
+
for n, p in self.ffn_norm.named_parameters():
|
445 |
+
p.requires_grad = False
|
446 |
+
|
447 |
+
def _frozen_pred(self):
|
448 |
+
# frozen cls_fcs, fc_cls, mask_fcs, fc_mask
|
449 |
+
for n, p in self.cls_fcs.named_parameters():
|
450 |
+
p.requires_grad = False
|
451 |
+
for n, p in self.fc_cls.named_parameters():
|
452 |
+
p.requires_grad = False
|
453 |
+
for n, p in self.mask_fcs.named_parameters():
|
454 |
+
p.requires_grad = False
|
455 |
+
for n, p in self.fc_mask.named_parameters():
|
456 |
+
p.requires_grad = False
|
457 |
+
|
458 |
+
def train(self, mode):
|
459 |
+
super().train(mode)
|
460 |
+
if self.frozen_head:
|
461 |
+
self.kernel_update.eval()
|
462 |
+
self.s_atten.eval()
|
463 |
+
self.s_dropout.eval()
|
464 |
+
self.s_atten_norm.eval()
|
465 |
+
self.ffn.eval()
|
466 |
+
self.ffn_norm.eval()
|
467 |
+
if self.frozen_pred:
|
468 |
+
self.cls_fcs.eval()
|
469 |
+
self.fc_cls.eval()
|
470 |
+
self.mask_fcs.eval()
|
471 |
+
self.fc_mask.eval()
|
472 |
+
|
473 |
+
def forward(self, features, proposal_kernels, mask_preds, self_attn_mask=None):
|
474 |
+
B, C, H, W = features.shape
|
475 |
+
|
476 |
+
soft_sigmoid_masks = mask_preds.sigmoid()
|
477 |
+
nonzero_inds = soft_sigmoid_masks > self.hard_mask_thr
|
478 |
+
hard_sigmoid_masks = nonzero_inds.float()
|
479 |
+
|
480 |
+
# [B, N, C]
|
481 |
+
f = torch.einsum('bnhw,bchw->bnc', hard_sigmoid_masks, features)
|
482 |
+
# [B, N, C, K, K] -> [B, N, C * K * K]
|
483 |
+
num_proposals = proposal_kernels.shape[1]
|
484 |
+
k = proposal_kernels.view(B, num_proposals, -1)
|
485 |
+
|
486 |
+
# ----
|
487 |
+
if self.use_kernel_updator:
|
488 |
+
k = self.kernel_update(f, k)
|
489 |
+
else:
|
490 |
+
f_tmp = self.f_atten(k, f)
|
491 |
+
f = f + self.f_dropout(f_tmp)
|
492 |
+
f = self.f_atten_norm(f)
|
493 |
+
|
494 |
+
f_tmp = self.k_atten(k, f)
|
495 |
+
f = f + self.k_dropout(f_tmp)
|
496 |
+
k = self.k_atten_norm(f)
|
497 |
+
|
498 |
+
# [N, B, C]
|
499 |
+
k = k.permute(1, 0, 2)
|
500 |
+
|
501 |
+
k_tmp = self.s_atten(query=k, key=k, value=k, attn_mask=self_attn_mask)[0]
|
502 |
+
k = k + self.s_dropout(k_tmp)
|
503 |
+
k = self.s_atten_norm(k.permute(1, 0, 2))
|
504 |
+
|
505 |
+
obj_feat = self.ffn_norm(self.ffn(k))
|
506 |
+
|
507 |
+
cls_feat = obj_feat
|
508 |
+
mask_feat = obj_feat
|
509 |
+
|
510 |
+
for cls_layer in self.cls_fcs:
|
511 |
+
cls_feat = cls_layer(cls_feat)
|
512 |
+
|
513 |
+
if self.sphere_cls:
|
514 |
+
cls_embd = self.cls_proj(cls_feat) # FIXME Too much cls linear (cls_fcs + cls_proj)
|
515 |
+
cls_score = torch.einsum('bnc,ckp->bnkp', F.normalize(cls_embd, dim=-1), self.fc_cls)
|
516 |
+
cls_score = cls_score.max(-1).values
|
517 |
+
cls_score = self.logit_scale.exp() * cls_score
|
518 |
+
else:
|
519 |
+
cls_score = self.fc_cls(cls_feat)
|
520 |
+
for reg_layer in self.mask_fcs:
|
521 |
+
mask_feat = reg_layer(mask_feat)
|
522 |
+
# [B, N, K * K, C] -> [B, N, C]
|
523 |
+
mask_kernels = self.fc_mask(mask_feat)
|
524 |
+
|
525 |
+
new_mask_preds = torch.einsum("bqc,bchw->bqhw", mask_kernels, features)
|
526 |
+
if self.with_iou_pred:
|
527 |
+
iou_pred = self.iou_embed(mask_feat)
|
528 |
+
iou_pred = iou_pred
|
529 |
+
else:
|
530 |
+
iou_pred = None
|
531 |
+
return cls_score, new_mask_preds, iou_pred, obj_feat
|
app/models/necks/__init__.py
ADDED
@@ -0,0 +1 @@
from .ramsam_neck import YOSONeck
app/models/necks/ramsam_neck.py
ADDED
@@ -0,0 +1,196 @@
import math
import torch
from torch import nn
import torch.nn.functional as F
from mmengine.model import kaiming_init
from mmdet.registry import MODELS
from mmcv.ops import DeformConv2d, ModulatedDeformConv2d


class DeformLayer(nn.Module):

    def __init__(self,
                 in_planes,
                 out_planes,
                 deconv_kernel=4,
                 deconv_stride=2,
                 deconv_pad=1,
                 deconv_out_pad=0,
                 modulate_deform=True,
                 num_groups=1,
                 deform_num_groups=1,
                 dilation=1):
        super(DeformLayer, self).__init__()
        self.deform_modulated = modulate_deform
        if modulate_deform:
            deform_conv_op = ModulatedDeformConv2d
            offset_channels = 27
        else:
            deform_conv_op = DeformConv2d
            offset_channels = 18

        self.dcn_offset = nn.Conv2d(in_planes, offset_channels * deform_num_groups, kernel_size=3, stride=1, padding=1 * dilation, dilation=dilation)
        self.dcn = deform_conv_op(in_planes, out_planes, kernel_size=3, stride=1, padding=1 * dilation, bias=False, groups=num_groups, dilation=dilation, deformable_groups=deform_num_groups)
        for layer in [self.dcn]:
            kaiming_init(layer)

        nn.init.constant_(self.dcn_offset.weight, 0)
        nn.init.constant_(self.dcn_offset.bias, 0)

        # nn.GroupNorm(64, out_planes) # nn.BatchNorm2d(out_planes) #
        self.dcn_bn = nn.SyncBatchNorm(out_planes)
        self.up_sample = nn.ConvTranspose2d(in_channels=out_planes, out_channels=out_planes, kernel_size=deconv_kernel, stride=deconv_stride, padding=deconv_pad, output_padding=deconv_out_pad, bias=False)
        self._deconv_init()
        # nn.GroupNorm(64, out_planes) # nn.BatchNorm2d(out_planes) #
        self.up_bn = nn.SyncBatchNorm(out_planes)
        self.relu = nn.ReLU()

    def forward(self, x):
        out = x
        if self.deform_modulated:
            offset_mask = self.dcn_offset(out)
            offset_x, offset_y, mask = torch.chunk(offset_mask, 3, dim=1)
            offset = torch.cat((offset_x, offset_y), dim=1)
            mask = mask.sigmoid()
            out = self.dcn(out, offset, mask)
        else:
            offset = self.dcn_offset(out)
            out = self.dcn(out, offset)
        x = out

        x = self.dcn_bn(x)
        x = self.relu(x)
        x = self.up_sample(x)
        x = self.up_bn(x)
        x = self.relu(x)
        return x

    def _deconv_init(self):
        w = self.up_sample.weight.data
        f = math.ceil(w.size(2) / 2)
        c = (2 * f - 1 - f % 2) / (2. * f)
        for i in range(w.size(2)):
            for j in range(w.size(3)):
                w[0, 0, i, j] = \
                    (1 - math.fabs(i / f - c)) * (1 - math.fabs(j / f - c))
        for c in range(1, w.size(0)):
            w[c, 0, :, :] = w[0, 0, :, :]


class LiteDeformConv(nn.Module):
    def __init__(self, agg_dim, backbone_shape):
        super(LiteDeformConv, self).__init__()
        in_channels = []
        out_channels = [agg_dim]
        for feat in backbone_shape:
            in_channels.append(feat)
            out_channels.append(feat // 2)

        self.lateral_conv0 = nn.Conv2d(in_channels=in_channels[-1], out_channels=out_channels[-1], kernel_size=1, stride=1, padding=0)
        self.deform_conv1 = DeformLayer(in_planes=out_channels[-1], out_planes=out_channels[-2])
        self.lateral_conv1 = nn.Conv2d(in_channels=in_channels[-2], out_channels=out_channels[-2], kernel_size=1, stride=1, padding=0)
        self.deform_conv2 = DeformLayer(in_planes=out_channels[-2], out_planes=out_channels[-3])
        self.lateral_conv2 = nn.Conv2d(in_channels=in_channels[-3], out_channels=out_channels[-3], kernel_size=1, stride=1, padding=0)
        self.deform_conv3 = DeformLayer(in_planes=out_channels[-3], out_planes=out_channels[-4])
        self.lateral_conv3 = nn.Conv2d(in_channels=in_channels[-4], out_channels=out_channels[-4], kernel_size=1, stride=1, padding=0)

        # self.fuse_conv = nn.Conv2d(in_channels=sum(out_channels[1:]), out_channels=out_channels[-5], kernel_size=3, stride=1, padding=1)
        self.output_conv = nn.Conv2d(in_channels=out_channels[-5], out_channels=out_channels[-5], kernel_size=3, stride=1, padding=1)

        self.bias = nn.Parameter(torch.FloatTensor(1, out_channels[-5], 1, 1), requires_grad=True)
        self.bias.data.fill_(0.0)

        self.conv_a5 = nn.Conv2d(in_channels=out_channels[-1], out_channels=out_channels[-5], kernel_size=1, stride=1, padding=0, bias=False)
        self.conv_a4 = nn.Conv2d(in_channels=out_channels[-2], out_channels=out_channels[-5], kernel_size=1, stride=1, padding=0, bias=False)
        self.conv_a3 = nn.Conv2d(in_channels=out_channels[-3], out_channels=out_channels[-5], kernel_size=1, stride=1, padding=0, bias=False)
        self.conv_a2 = nn.Conv2d(in_channels=out_channels[-4], out_channels=out_channels[-5], kernel_size=1, stride=1, padding=0, bias=False)

    def forward(self, features_list):
        p5 = self.lateral_conv0(features_list[-1])
        x5 = p5
        x = self.deform_conv1(x5)

        p4 = self.lateral_conv1(features_list[-2])
        x4 = p4 + x
        x = self.deform_conv2(x4)

        p3 = self.lateral_conv2(features_list[-3])
        x3 = p3 + x
        x = self.deform_conv3(x3)

        p2 = self.lateral_conv3(features_list[-4])
        x2 = p2 + x

        # CFA
        x5 = self.conv_a5(x5)
        x4 = self.conv_a4(x4)
        x3 = self.conv_a3(x3)
        _x5 = F.interpolate(x5, scale_factor=8, align_corners=False, mode='bilinear')
        _x4 = F.interpolate(x4, scale_factor=4, align_corners=False, mode='bilinear')
        _x3 = F.interpolate(x3, scale_factor=2, align_corners=False, mode='bilinear')
        x2 = self.conv_a2(x2)
        x = _x5 + _x4 + _x3 + x2 + self.bias

        x = self.output_conv(x)

        return x, (x5, x4, x3)


@MODELS.register_module()
class YOSONeck(nn.Module):

    def __init__(self,
                 agg_dim,
                 hidden_dim,
                 backbone_shape,
                 return_multi_scale=False,
                 return_single_scale=False,
                 # Just for compatibility with Mask2Former, not actually used
                 in_channels=None,
                 feat_channels=None,
                 out_channels=None
                 ):
        super().__init__()
        # in_channels == backbone_shape
        # hidden_dim == feat_channels == out_channels == 256
        self.return_single_scale = return_single_scale
        self.return_multi_scale = return_multi_scale
        self.deconv = LiteDeformConv(agg_dim=agg_dim, backbone_shape=backbone_shape)

        self.loc_conv = nn.Conv2d(in_channels=agg_dim + 2, out_channels=hidden_dim, kernel_size=1, stride=1)
        self.init_weights()

    def init_weights(self) -> None:
        for p in self.parameters():
            if p.dim() > 1:
                nn.init.xavier_uniform_(p)

    def generate_coord(self, input_feat):
        x_range = torch.linspace(-1, 1, input_feat.shape[-1], device=input_feat.device)
        y_range = torch.linspace(-1, 1, input_feat.shape[-2], device=input_feat.device)
        y, x = torch.meshgrid(y_range, x_range)
        y = y.expand([input_feat.shape[0], 1, -1, -1])
        x = x.expand([input_feat.shape[0], 1, -1, -1])
        coord_feat = torch.cat([x, y], 1)
        return coord_feat

    def forward(self,
                features_list,
                batch_img_metas=None,
                num_frames=None):
        features, multi_scale = self.deconv(features_list)
        coord_feat = self.generate_coord(features)
        features = torch.cat([features, coord_feat], 1)
        features = self.loc_conv(features)
        if self.return_single_scale:  # maskformer
            return features, multi_scale[0]
        if self.return_multi_scale:  # mask2former
            return features, multi_scale
        return features
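A standalone sketch (shapes are illustrative assumptions) of the normalized coordinate channels that `YOSONeck.generate_coord` concatenates to the fused feature map before `loc_conv`.

import torch

feat = torch.randn(2, 256, 32, 32)  # hypothetical fused FPN feature (B, C, H, W)
x_range = torch.linspace(-1, 1, feat.shape[-1])
y_range = torch.linspace(-1, 1, feat.shape[-2])
y, x = torch.meshgrid(y_range, x_range, indexing='ij')
coord = torch.cat([
    x.expand(feat.shape[0], 1, -1, -1),
    y.expand(feat.shape[0], 1, -1, -1),
], dim=1)                                     # (B, 2, H, W), values in [-1, 1]
feat_with_coord = torch.cat([feat, coord], 1)  # (B, 258, H, W), matching agg_dim + 2
print(feat_with_coord.shape)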
app/models/utils/__init__.py
ADDED
@@ -0,0 +1,3 @@
from .video_gt_preprocess import preprocess_video_panoptic_gt
from .mask_pool import mask_pool
from .no_obj import NO_OBJ
app/models/utils/load_checkpoint.py
ADDED
@@ -0,0 +1,38 @@
from mmengine.runner.checkpoint import CheckpointLoader


def load_checkpoint_with_prefix(filename, prefix=None, map_location='cpu', logger='current'):
    """Load partial pretrained model with specific prefix.

    Args:
        prefix (str): The prefix of sub-module.
        filename (str): Accept local filepath, URL, ``torchvision://xxx``,
            ``open-mmlab://xxx``. Please refer to ``docs/model_zoo.md`` for
            details.
        map_location (str | None): Same as :func:`torch.load`.
            Defaults to None.
        logger: logger

    Returns:
        dict or OrderedDict: The loaded checkpoint.
    """

    checkpoint = CheckpointLoader.load_checkpoint(filename, map_location=map_location, logger=logger)

    if 'state_dict' in checkpoint:
        state_dict = checkpoint['state_dict']
    else:
        state_dict = checkpoint
    if not prefix:
        return state_dict
    if not prefix.endswith('.'):
        prefix += '.'
    prefix_len = len(prefix)

    state_dict = {
        k[prefix_len:]: v
        for k, v in state_dict.items() if k.startswith(prefix)
    }

    assert state_dict, f'{prefix} is not in the pretrained model'
    return state_dict
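A hypothetical usage sketch (the checkpoint path below is a placeholder, not a file shipped in this repo): keep only the weights under a given prefix and load them into the matching sub-module.

import torch.nn as nn
from app.models.utils.load_checkpoint import load_checkpoint_with_prefix

backbone = nn.Conv2d(3, 64, 3)  # stand-in for a real backbone module
state_dict = load_checkpoint_with_prefix('path/to/checkpoint.pth', prefix='backbone')
backbone.load_state_dict(state_dict, strict=False)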
app/models/utils/mask_pool.py
ADDED
@@ -0,0 +1,27 @@
import torch
import torch.nn.functional as F


# https://github.com/NVlabs/ODISE/blob/e97b06c424c575fec9fc5368dd4b3e050d91abc4/odise/modeling/meta_arch/odise.py#L923

def mask_pool(x, mask):
    """
    Args:
        x: [B, C, H, W]
        mask: [B, Q, H, W]
    """
    if not x.shape[-2:] == mask.shape[-2:]:
        # reshape mask to x
        mask = F.interpolate(mask, size=x.shape[-2:], mode='bilinear', align_corners=False)
    with torch.no_grad():
        mask = mask.detach()
        mask = (mask > 0).to(mask.dtype)
        denorm = mask.sum(dim=(-1, -2), keepdim=True) + 1e-8

    mask_pooled_x = torch.einsum(
        "bchw,bqhw->bqc",
        x,
        mask / denorm,
    )
    return mask_pooled_x
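A shape check with random tensors (illustrative only): per-query features are pooled from a feature map using mask logits at a different resolution.

import torch

x = torch.randn(2, 256, 32, 32)     # (B, C, H, W) feature map
mask = torch.randn(2, 100, 64, 64)  # (B, Q, H, W) mask logits at another scale
pooled = mask_pool(x, mask)
print(pooled.shape)  # torch.Size([2, 100, 256])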
app/models/utils/no_obj.py
ADDED
@@ -0,0 +1 @@
NO_OBJ = 65535
app/models/utils/video_gt_preprocess.py
ADDED
@@ -0,0 +1,87 @@
import torch


def preprocess_video_panoptic_gt(
        gt_labels,
        gt_masks,
        gt_semantic_seg,
        gt_instance_ids,
        num_things,
        num_stuff,
):
    num_classes = num_things + num_stuff
    num_frames = len(gt_masks)
    mask_size = gt_masks[0].masks.shape[-2:]

    thing_masks_list = []
    for frame_id in range(num_frames):
        thing_masks_list.append(gt_masks[frame_id].pad(
            mask_size, pad_val=0).to_tensor(
            dtype=torch.bool, device=gt_labels.device)
        )
    instances = torch.unique(gt_instance_ids[:, 1])
    things_masks = []
    labels = []
    for instance in instances:
        pos_ins = torch.nonzero(torch.eq(gt_instance_ids[:, 1], instance), as_tuple=True)[0]  # 0 is for redundant tuple
        labels_instance = gt_labels[:, 1][pos_ins]
        assert torch.allclose(labels_instance, labels_instance[0])
        labels.append(labels_instance[0])
        instance_frame_ids = gt_instance_ids[:, 0][pos_ins].to(dtype=torch.int32).tolist()
        instance_masks = []
        for frame_id in range(num_frames):
            frame_instance_ids = gt_instance_ids[gt_instance_ids[:, 0] == frame_id, 1]
            if frame_id not in instance_frame_ids:
                empty_mask = torch.zeros(
                    mask_size,
                    dtype=thing_masks_list[frame_id].dtype, device=thing_masks_list[frame_id].device
                )
                instance_masks.append(empty_mask)
            else:
                pos_inner_frame = torch.nonzero(torch.eq(frame_instance_ids, instance), as_tuple=True)[0].item()
                frame_mask = thing_masks_list[frame_id][pos_inner_frame]
                instance_masks.append(frame_mask)
        things_masks.append(torch.stack(instance_masks))

    if len(instances) == 0:
        things_masks = torch.stack(thing_masks_list, dim=1)
        labels = torch.empty_like(instances)
    else:
        things_masks = torch.stack(things_masks)
        labels = torch.stack(labels)
    assert torch.all(torch.less(labels, num_things))

    if gt_semantic_seg is not None:
        things_labels = labels
        gt_semantic_seg = gt_semantic_seg.squeeze(1)

        semantic_labels = torch.unique(
            gt_semantic_seg,
            sorted=False,
            return_inverse=False,
            return_counts=False)
        stuff_masks_list = []
        stuff_labels_list = []
        for label in semantic_labels:
            if label < num_things or label >= num_classes:
                continue
            stuff_mask = gt_semantic_seg == label
            stuff_masks_list.append(stuff_mask)
            stuff_labels_list.append(label)

        if len(stuff_masks_list) > 0:
            stuff_masks = torch.stack(stuff_masks_list, dim=0)
            stuff_labels = torch.stack(stuff_labels_list, dim=0)
            assert torch.all(torch.ge(stuff_labels, num_things)) and torch.all(torch.less(stuff_labels, num_classes))
            labels = torch.cat([things_labels, stuff_labels], dim=0)
            masks = torch.cat([things_masks, stuff_masks], dim=0)
        else:
            labels = things_labels
            masks = things_masks
        assert len(labels) == len(masks)
    else:
        masks = things_masks

    labels = labels.to(dtype=torch.long)
    masks = masks.to(dtype=torch.long)
    return labels, masks
ext/meta/sam_meta.py
ADDED
@@ -0,0 +1,41 @@
meta_dict = {
    'vit_h': dict(
        encoder_embed_dim=1280,
        encoder_depth=32,
        encoder_num_heads=16,
        encoder_global_attn_indexes=[7, 15, 23, 31],
        # common
        prompt_embed_dim=256,
        image_size=1024,
        vit_patch_size=16,
        image_embedding_size=64
    ),
    'vit_l': dict(
        encoder_embed_dim=1024,
        encoder_depth=24,
        encoder_num_heads=16,
        encoder_global_attn_indexes=[5, 11, 17, 23],
        # common
        prompt_embed_dim=256,
        image_size=1024,
        vit_patch_size=16,
        image_embedding_size=64
    ),
    'vit_b': dict(
        encoder_embed_dim=768,
        encoder_depth=12,
        encoder_num_heads=12,
        encoder_global_attn_indexes=[2, 5, 8, 11],
        # common
        prompt_embed_dim=256,
        image_size=1024,
        vit_patch_size=16,
        image_embedding_size=64
    )
}

checkpoint_dict = {
    'vit_h': 'https://dl.fbaipublicfiles.com/segment_anything/sam_vit_h_4b8939.pth',
    'vit_l': 'https://dl.fbaipublicfiles.com/segment_anything/sam_vit_l_0b3195.pth',
    'vit_b': 'https://dl.fbaipublicfiles.com/segment_anything/sam_vit_b_01ec64.pth',
}
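An illustrative lookup sketch (no SAM code required): pick a variant key and read its encoder configuration and official checkpoint URL from the tables above.

from ext.meta.sam_meta import meta_dict, checkpoint_dict

variant = 'vit_b'
cfg = meta_dict[variant]
print(cfg['encoder_embed_dim'], cfg['encoder_depth'])  # 768 12
print(checkpoint_dict[variant])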
ext/open_clip/__init__.py
ADDED
@@ -0,0 +1,15 @@
from .coca_model import CoCa
from .constants import OPENAI_DATASET_MEAN, OPENAI_DATASET_STD
from .factory import create_model, create_model_and_transforms, create_model_from_pretrained, get_tokenizer, create_loss
from .factory import list_models, add_model_config, get_model_config, load_checkpoint
from .loss import ClipLoss, DistillClipLoss, CoCaLoss
from .model import CLIP, CustomTextCLIP, CLIPTextCfg, CLIPVisionCfg, \
    convert_weights_to_lp, convert_weights_to_fp16, trace_model, get_cast_dtype, get_input_dtype
from .openai import load_openai_model, list_openai_models
from .pretrained import list_pretrained, list_pretrained_models_by_tag, list_pretrained_tags_by_model, \
    get_pretrained_url, download_pretrained_from_url, is_pretrained_cfg, get_pretrained_cfg, download_pretrained
from .push_to_hf_hub import push_pretrained_to_hf_hub, push_to_hf_hub
from .tokenizer import SimpleTokenizer, tokenize, decode
from .transform import image_transform, AugmentationCfg
from .zero_shot_classifier import build_zero_shot_classifier, build_zero_shot_classifier_legacy
from .zero_shot_metadata import OPENAI_IMAGENET_TEMPLATES, SIMPLE_IMAGENET_TEMPLATES, IMAGENET_CLASSNAMES
ext/open_clip/bpe_simple_vocab_16e6.txt.gz
ADDED
@@ -0,0 +1,3 @@
version https://git-lfs.github.com/spec/v1
oid sha256:924691ac288e54409236115652ad4aa250f48203de50a9e4722a6ecd48d6804a
size 1356917
ext/open_clip/coca_model.py
ADDED
@@ -0,0 +1,458 @@
1 |
+
from typing import Optional
|
2 |
+
|
3 |
+
import torch
|
4 |
+
from torch import nn
|
5 |
+
from torch.nn import functional as F
|
6 |
+
import numpy as np
|
7 |
+
from dataclasses import dataclass
|
8 |
+
|
9 |
+
from .transformer import (
|
10 |
+
LayerNormFp32,
|
11 |
+
LayerNorm,
|
12 |
+
QuickGELU,
|
13 |
+
MultimodalTransformer,
|
14 |
+
)
|
15 |
+
from .model import CLIPTextCfg, CLIPVisionCfg, _build_vision_tower, _build_text_tower
|
16 |
+
|
17 |
+
try:
|
18 |
+
from transformers import (
|
19 |
+
BeamSearchScorer,
|
20 |
+
LogitsProcessorList,
|
21 |
+
TopPLogitsWarper,
|
22 |
+
TopKLogitsWarper,
|
23 |
+
RepetitionPenaltyLogitsProcessor,
|
24 |
+
MinLengthLogitsProcessor,
|
25 |
+
MaxLengthCriteria,
|
26 |
+
StoppingCriteriaList
|
27 |
+
)
|
28 |
+
|
29 |
+
GENERATION_TYPES = {
|
30 |
+
"top_k": TopKLogitsWarper,
|
31 |
+
"top_p": TopPLogitsWarper,
|
32 |
+
"beam_search": "beam_search"
|
33 |
+
}
|
34 |
+
_has_transformers = True
|
35 |
+
except ImportError as e:
|
36 |
+
GENERATION_TYPES = {
|
37 |
+
"top_k": None,
|
38 |
+
"top_p": None,
|
39 |
+
"beam_search": "beam_search"
|
40 |
+
}
|
41 |
+
_has_transformers = False
|
42 |
+
|
43 |
+
|
44 |
+
@dataclass
|
45 |
+
class MultimodalCfg(CLIPTextCfg):
|
46 |
+
mlp_ratio: int = 4
|
47 |
+
dim_head: int = 64
|
48 |
+
heads: int = 8
|
49 |
+
n_queries: int = 256
|
50 |
+
attn_pooler_heads: int = 8
|
51 |
+
|
52 |
+
|
53 |
+
def _build_text_decoder_tower(
|
54 |
+
embed_dim,
|
55 |
+
multimodal_cfg,
|
56 |
+
quick_gelu: bool = False,
|
57 |
+
cast_dtype: Optional[torch.dtype] = None,
|
58 |
+
):
|
59 |
+
multimodal_cfg = MultimodalCfg(**multimodal_cfg) if isinstance(multimodal_cfg, dict) else multimodal_cfg
|
60 |
+
act_layer = QuickGELU if quick_gelu else nn.GELU
|
61 |
+
norm_layer = (
|
62 |
+
LayerNormFp32 if cast_dtype in (torch.float16, torch.bfloat16) else LayerNorm
|
63 |
+
)
|
64 |
+
|
65 |
+
decoder = MultimodalTransformer(
|
66 |
+
context_length=multimodal_cfg.context_length,
|
67 |
+
width=multimodal_cfg.width,
|
68 |
+
heads=multimodal_cfg.heads,
|
69 |
+
layers=multimodal_cfg.layers,
|
70 |
+
ls_init_value=multimodal_cfg.ls_init_value,
|
71 |
+
output_dim=embed_dim,
|
72 |
+
act_layer=act_layer,
|
73 |
+
norm_layer=norm_layer,
|
74 |
+
)
|
75 |
+
|
76 |
+
return decoder
|
77 |
+
|
78 |
+
|
79 |
+
class CoCa(nn.Module):
|
80 |
+
def __init__(
|
81 |
+
self,
|
82 |
+
embed_dim,
|
83 |
+
multimodal_cfg: MultimodalCfg,
|
84 |
+
text_cfg: CLIPTextCfg,
|
85 |
+
vision_cfg: CLIPVisionCfg,
|
86 |
+
quick_gelu: bool = False,
|
87 |
+
cast_dtype: Optional[torch.dtype] = None,
|
88 |
+
pad_id: int = 0,
|
89 |
+
):
|
90 |
+
super().__init__()
|
91 |
+
multimodal_cfg = MultimodalCfg(**multimodal_cfg) if isinstance(multimodal_cfg, dict) else multimodal_cfg
|
92 |
+
text_cfg = CLIPTextCfg(**text_cfg) if isinstance(text_cfg, dict) else text_cfg
|
93 |
+
vision_cfg = CLIPVisionCfg(**vision_cfg) if isinstance(vision_cfg, dict) else vision_cfg
|
94 |
+
|
95 |
+
self.text = _build_text_tower(
|
96 |
+
embed_dim=embed_dim,
|
97 |
+
text_cfg=text_cfg,
|
98 |
+
quick_gelu=quick_gelu,
|
99 |
+
cast_dtype=cast_dtype,
|
100 |
+
)
|
101 |
+
|
102 |
+
vocab_size = (
|
103 |
+
text_cfg.vocab_size # for hf models
|
104 |
+
if hasattr(text_cfg, "hf_model_name") and text_cfg.hf_model_name is not None
|
105 |
+
else text_cfg.vocab_size
|
106 |
+
)
|
107 |
+
|
108 |
+
self.visual = _build_vision_tower(
|
109 |
+
embed_dim=embed_dim,
|
110 |
+
vision_cfg=vision_cfg,
|
111 |
+
quick_gelu=quick_gelu,
|
112 |
+
cast_dtype=cast_dtype,
|
113 |
+
)
|
114 |
+
|
115 |
+
self.text_decoder = _build_text_decoder_tower(
|
116 |
+
vocab_size,
|
117 |
+
multimodal_cfg=multimodal_cfg,
|
118 |
+
quick_gelu=quick_gelu,
|
119 |
+
cast_dtype=cast_dtype,
|
120 |
+
)
|
121 |
+
|
122 |
+
self.logit_scale = nn.Parameter(torch.ones([]) * np.log(1 / 0.07))
|
123 |
+
self.pad_id = pad_id
|
124 |
+
|
125 |
+
@torch.jit.ignore
|
126 |
+
def set_grad_checkpointing(self, enable=True):
|
127 |
+
self.visual.set_grad_checkpointing(enable)
|
128 |
+
self.text.set_grad_checkpointing(enable)
|
129 |
+
self.text_decoder.set_grad_checkpointing(enable)
|
130 |
+
|
131 |
+
def _encode_image(self, images, normalize=True):
|
132 |
+
image_latent, tokens_embs = self.visual(images)
|
133 |
+
image_latent = F.normalize(image_latent, dim=-1) if normalize else image_latent
|
134 |
+
return image_latent, tokens_embs
|
135 |
+
|
136 |
+
def _encode_text(self, text, normalize=True, embed_cls=True):
|
137 |
+
text = text[:, :-1] if embed_cls else text # make space for CLS token
|
138 |
+
text_latent, token_emb = self.text(text)
|
139 |
+
text_latent = F.normalize(text_latent, dim=-1) if normalize else text_latent
|
140 |
+
return text_latent, token_emb
|
141 |
+
|
142 |
+
def encode_image(self, images, normalize=True):
|
143 |
+
image_latent, _ = self._encode_image(images, normalize=normalize)
|
144 |
+
return image_latent
|
145 |
+
|
146 |
+
def encode_text(self, text, normalize=True, embed_cls=True):
|
147 |
+
text_latent, _ = self._encode_text(text, normalize=normalize, embed_cls=embed_cls)
|
148 |
+
return text_latent
|
149 |
+
|
150 |
+
def forward(self, image, text, embed_cls=True, image_latent=None, image_embs=None):
|
151 |
+
text_latent, token_embs = self._encode_text(text, embed_cls=embed_cls)
|
152 |
+
if image_latent is None or image_embs is None:
|
153 |
+
image_latent, image_embs = self._encode_image(image)
|
154 |
+
|
155 |
+
# TODO: add assertion to avoid bugs?
|
156 |
+
labels = text[:, -token_embs.shape[1]:]
|
157 |
+
|
158 |
+
logits = self.text_decoder(image_embs, token_embs)
|
159 |
+
return {
|
160 |
+
"image_features": image_latent,
|
161 |
+
"text_features": text_latent,
|
162 |
+
"logits": logits,
|
163 |
+
"labels": labels,
|
164 |
+
"logit_scale": self.logit_scale.exp()
|
165 |
+
}
|
166 |
+
|
167 |
+
def generate(
|
168 |
+
self,
|
169 |
+
image,
|
170 |
+
text=None,
|
171 |
+
seq_len=30,
|
172 |
+
max_seq_len=77,
|
173 |
+
temperature=1.,
|
174 |
+
generation_type="beam_search",
|
175 |
+
top_p=0.1, # keep tokens in the 1 - top_p quantile
|
176 |
+
top_k=1, # keeps the top_k most probable tokens
|
177 |
+
pad_token_id=None,
|
178 |
+
eos_token_id=None,
|
179 |
+
sot_token_id=None,
|
180 |
+
num_beams=6,
|
181 |
+
num_beam_groups=3,
|
182 |
+
min_seq_len=5,
|
183 |
+
stopping_criteria=None,
|
184 |
+
repetition_penalty=1.0,
|
185 |
+
fixed_output_length=False # if True output.shape == (batch_size, seq_len)
|
186 |
+
):
|
187 |
+
# taking many ideas and components from HuggingFace GenerationMixin
|
188 |
+
# https://huggingface.co/docs/transformers/main/en/main_classes/text_generation
|
189 |
+
assert _has_transformers, "Please install transformers for generate functionality. `pip install transformers`."
|
190 |
+
assert seq_len > min_seq_len, "seq_len must be larger than min_seq_len"
|
191 |
+
|
192 |
+
with torch.no_grad():
|
193 |
+
sot_token_id = 49406 if sot_token_id is None else sot_token_id
|
194 |
+
eos_token_id = 49407 if eos_token_id is None else eos_token_id
|
195 |
+
pad_token_id = self.pad_id if pad_token_id is None else pad_token_id
|
196 |
+
logit_processor = LogitsProcessorList(
|
197 |
+
[
|
198 |
+
MinLengthLogitsProcessor(min_seq_len, eos_token_id),
|
199 |
+
RepetitionPenaltyLogitsProcessor(repetition_penalty),
|
200 |
+
]
|
201 |
+
)
|
202 |
+
|
203 |
+
if stopping_criteria is None:
|
204 |
+
stopping_criteria = [MaxLengthCriteria(max_length=seq_len)]
|
205 |
+
|
206 |
+
stopping_criteria = StoppingCriteriaList(
|
207 |
+
stopping_criteria
|
208 |
+
)
|
209 |
+
|
210 |
+
device = image.device
|
211 |
+
|
212 |
+
if generation_type == "beam_search":
|
213 |
+
output = self._generate_beamsearch(
|
214 |
+
image_inputs=image,
|
215 |
+
pad_token_id=pad_token_id,
|
216 |
+
eos_token_id=eos_token_id,
|
217 |
+
sot_token_id=sot_token_id,
|
218 |
+
num_beams=num_beams,
|
219 |
+
num_beam_groups=num_beam_groups,
|
220 |
+
min_seq_len=min_seq_len,
|
221 |
+
stopping_criteria=stopping_criteria,
|
222 |
+
logit_processor=logit_processor,
|
223 |
+
)
|
224 |
+
if fixed_output_length and output.shape[1] < seq_len:
|
225 |
+
return torch.cat(
|
226 |
+
(output, torch.ones(output.shape[0], seq_len-output.shape[1], device=device, dtype=output.dtype) * self.pad_id),
|
227 |
+
dim=1
|
228 |
+
)
|
229 |
+
return output
|
230 |
+
|
231 |
+
elif generation_type == "top_p":
|
232 |
+
logit_warper = GENERATION_TYPES[generation_type](top_p)
|
233 |
+
elif generation_type == "top_k":
|
234 |
+
logit_warper = GENERATION_TYPES[generation_type](top_k)
|
235 |
+
else:
|
236 |
+
raise ValueError(
|
237 |
+
f"generation_type has to be one of "
|
238 |
+
f"{'| ' + ' | '.join(list(GENERATION_TYPES.keys())) + ' |'}."
|
239 |
+
)
|
240 |
+
|
241 |
+
image_latent, image_embs = self._encode_image(image)
|
242 |
+
|
243 |
+
if text is None:
|
244 |
+
text = torch.ones((image.shape[0], 1), device=device, dtype=torch.long) * sot_token_id
|
245 |
+
|
246 |
+
was_training = self.training
|
247 |
+
num_dims = len(text.shape)
|
248 |
+
|
249 |
+
if num_dims == 1:
|
250 |
+
text = text[None, :]
|
251 |
+
|
252 |
+
cur_len = text.shape[1]
|
253 |
+
self.eval()
|
254 |
+
out = text
|
255 |
+
|
256 |
+
while True:
|
257 |
+
x = out[:, -max_seq_len:]
|
258 |
+
cur_len = x.shape[1]
|
259 |
+
logits = self(image, x, image_latent=image_latent, image_embs=image_embs, embed_cls=False)["logits"][:, -1]
|
260 |
+
mask = (out[:, -1] == eos_token_id) | (out[:, -1] == pad_token_id)
|
261 |
+
sample = torch.ones((out.shape[0], 1), device=device, dtype=torch.long) * pad_token_id
|
262 |
+
|
263 |
+
if mask.all():
|
264 |
+
if not fixed_output_length:
|
265 |
+
break
|
266 |
+
else:
|
267 |
+
logits = logits[~mask, :]
|
268 |
+
filtered_logits = logit_processor(x[~mask, :], logits)
|
269 |
+
filtered_logits = logit_warper(x[~mask, :], filtered_logits)
|
270 |
+
probs = F.softmax(filtered_logits / temperature, dim=-1)
|
271 |
+
|
272 |
+
if (cur_len + 1 == seq_len):
|
273 |
+
sample[~mask, :] = torch.ones((sum(~mask), 1), device=device, dtype=torch.long) * eos_token_id
|
274 |
+
else:
|
275 |
+
sample[~mask, :] = torch.multinomial(probs, 1)
|
276 |
+
|
277 |
+
out = torch.cat((out, sample), dim=-1)
|
278 |
+
|
279 |
+
cur_len += 1
|
280 |
+
|
281 |
+
if stopping_criteria(out, None):
|
282 |
+
break
|
283 |
+
|
284 |
+
if num_dims == 1:
|
285 |
+
out = out.squeeze(0)
|
286 |
+
|
287 |
+
self.train(was_training)
|
288 |
+
return out
|
289 |
+
|
290 |
+
def _generate_beamsearch(
|
291 |
+
self,
|
292 |
+
image_inputs,
|
293 |
+
pad_token_id=None,
|
294 |
+
eos_token_id=None,
|
295 |
+
sot_token_id=None,
|
296 |
+
num_beams=6,
|
297 |
+
num_beam_groups=3,
|
298 |
+
min_seq_len=5,
|
299 |
+
stopping_criteria=None,
|
300 |
+
logit_processor=None,
|
301 |
+
logit_warper=None,
|
302 |
+
):
|
303 |
+
device = image_inputs.device
|
304 |
+
batch_size = image_inputs.shape[0]
|
305 |
+
image_inputs = torch.repeat_interleave(image_inputs, num_beams, dim=0)
|
306 |
+
image_latent, image_embs = self._encode_image(image_inputs)
|
307 |
+
|
308 |
+
input_ids = torch.ones((batch_size * num_beams, 1), device=device, dtype=torch.long)
|
309 |
+
input_ids = input_ids * sot_token_id
|
310 |
+
beam_scorer = BeamSearchScorer(
|
311 |
+
batch_size=batch_size,
|
312 |
+
num_beams=num_beams,
|
313 |
+
device=device,
|
314 |
+
num_beam_groups=num_beam_groups,
|
315 |
+
)
|
316 |
+
# instantiate logits processors
|
317 |
+
logits_processor = (
|
318 |
+
LogitsProcessorList([MinLengthLogitsProcessor(min_seq_len, eos_token_id=eos_token_id)])
|
319 |
+
if logit_processor is None
|
320 |
+
else logit_processor
|
321 |
+
)
|
322 |
+
|
323 |
+
batch_size = len(beam_scorer._beam_hyps)
|
324 |
+
num_beams = beam_scorer.num_beams
|
325 |
+
num_beam_groups = beam_scorer.num_beam_groups
|
326 |
+
num_sub_beams = num_beams // num_beam_groups
|
327 |
+
batch_beam_size, cur_len = input_ids.shape
|
328 |
+
beam_indices = None
|
329 |
+
|
330 |
+
if num_beams * batch_size != batch_beam_size:
|
331 |
+
raise ValueError(
|
332 |
+
f"Batch dimension of `input_ids` should be {num_beams * batch_size}, but is {batch_beam_size}."
|
333 |
+
)
|
334 |
+
|
335 |
+
beam_scores = torch.full((batch_size, num_beams), -1e9, dtype=torch.float, device=device)
|
336 |
+
# initialise score of first beam of each group with 0 and the rest with -1e9. This ensures that the beams in
|
337 |
+
# the same group don't produce the same tokens every time.
|
338 |
+
beam_scores[:, ::num_sub_beams] = 0
|
339 |
+
beam_scores = beam_scores.view((batch_size * num_beams,))
|
340 |
+
|
341 |
+
while True:
|
342 |
+
|
343 |
+
# predicted tokens in cur_len step
|
344 |
+
current_tokens = torch.zeros(batch_size * num_beams, dtype=input_ids.dtype, device=device)
|
345 |
+
|
346 |
+
# indices which will form the beams in the next time step
|
347 |
+
reordering_indices = torch.zeros(batch_size * num_beams, dtype=torch.long, device=device)
|
348 |
+
|
349 |
+
# do one decoder step on all beams of all sentences in batch
|
350 |
+
model_inputs = prepare_inputs_for_generation(input_ids=input_ids, image_inputs=image_inputs)
|
351 |
+
outputs = self(
|
352 |
+
model_inputs['images'],
|
353 |
+
model_inputs['text'],
|
354 |
+
embed_cls=False,
|
355 |
+
image_latent=image_latent,
|
356 |
+
image_embs=image_embs
|
357 |
+
)
|
358 |
+
|
359 |
+
for beam_group_idx in range(num_beam_groups):
|
360 |
+
group_start_idx = beam_group_idx * num_sub_beams
|
361 |
+
group_end_idx = min(group_start_idx + num_sub_beams, num_beams)
|
362 |
+
group_size = group_end_idx - group_start_idx
|
363 |
+
|
364 |
+
# indices of beams of current group among all sentences in batch
|
365 |
+
batch_group_indices = []
|
366 |
+
|
367 |
+
for batch_idx in range(batch_size):
|
368 |
+
batch_group_indices.extend(
|
369 |
+
[batch_idx * num_beams + idx for idx in range(group_start_idx, group_end_idx)]
|
370 |
+
)
|
371 |
+
group_input_ids = input_ids[batch_group_indices]
|
372 |
+
|
373 |
+
# select outputs of beams of current group only
|
374 |
+
next_token_logits = outputs['logits'][batch_group_indices, -1, :]
|
375 |
+
vocab_size = next_token_logits.shape[-1]
|
376 |
+
|
377 |
+
next_token_scores_processed = logits_processor(
|
378 |
+
group_input_ids, next_token_logits, current_tokens=current_tokens, beam_group_idx=beam_group_idx
|
379 |
+
)
|
380 |
+
next_token_scores = next_token_scores_processed + beam_scores[batch_group_indices].unsqueeze(-1)
|
381 |
+
next_token_scores = next_token_scores.expand_as(next_token_scores_processed)
|
382 |
+
|
383 |
+
# reshape for beam search
|
384 |
+
next_token_scores = next_token_scores.view(batch_size, group_size * vocab_size)
|
385 |
+
|
386 |
+
next_token_scores, next_tokens = torch.topk(
|
387 |
+
next_token_scores, 2 * group_size, dim=1, largest=True, sorted=True
|
388 |
+
)
|
389 |
+
|
390 |
+
next_indices = torch.div(next_tokens, vocab_size, rounding_mode="floor")
|
391 |
+
next_tokens = next_tokens % vocab_size
|
392 |
+
|
393 |
+
# stateless
|
394 |
+
process_beam_indices = sum(beam_indices, ()) if beam_indices is not None else None
|
395 |
+
beam_outputs = beam_scorer.process(
|
396 |
+
group_input_ids,
|
397 |
+
next_token_scores,
|
398 |
+
next_tokens,
|
399 |
+
next_indices,
|
400 |
+
pad_token_id=pad_token_id,
|
401 |
+
eos_token_id=eos_token_id,
|
402 |
+
beam_indices=process_beam_indices,
|
403 |
+
)
|
404 |
+
beam_scores[batch_group_indices] = beam_outputs["next_beam_scores"]
|
405 |
+
beam_next_tokens = beam_outputs["next_beam_tokens"]
|
406 |
+
beam_idx = beam_outputs["next_beam_indices"]
|
407 |
+
|
408 |
+
input_ids[batch_group_indices] = group_input_ids[beam_idx]
|
409 |
+
group_input_ids = torch.cat([group_input_ids[beam_idx, :], beam_next_tokens.unsqueeze(-1)], dim=-1)
|
410 |
+
current_tokens[batch_group_indices] = group_input_ids[:, -1]
|
411 |
+
|
412 |
+
# (beam_idx // group_size) -> batch_idx
|
413 |
+
# (beam_idx % group_size) -> offset of idx inside the group
|
414 |
+
reordering_indices[batch_group_indices] = (
|
415 |
+
num_beams * torch.div(beam_idx, group_size, rounding_mode="floor") + group_start_idx + (beam_idx % group_size)
|
416 |
+
)
|
417 |
+
|
418 |
+
input_ids = torch.cat([input_ids, current_tokens.unsqueeze(-1)], dim=-1)
|
419 |
+
|
420 |
+
# increase cur_len
|
421 |
+
cur_len = cur_len + 1
|
422 |
+
if beam_scorer.is_done or stopping_criteria(input_ids, None):
|
423 |
+
break
|
424 |
+
|
425 |
+
final_beam_indices = sum(beam_indices, ()) if beam_indices is not None else None
|
426 |
+
sequence_outputs = beam_scorer.finalize(
|
427 |
+
input_ids,
|
428 |
+
beam_scores,
|
429 |
+
next_tokens,
|
430 |
+
next_indices,
|
431 |
+
pad_token_id=pad_token_id,
|
432 |
+
eos_token_id=eos_token_id,
|
433 |
+
max_length=stopping_criteria.max_length,
|
434 |
+
beam_indices=final_beam_indices,
|
435 |
+
)
|
436 |
+
return sequence_outputs['sequences']
|
437 |
+
|
438 |
+
|
439 |
+
def prepare_inputs_for_generation(input_ids, image_inputs, past=None, **kwargs):
|
440 |
+
if past:
|
441 |
+
input_ids = input_ids[:, -1].unsqueeze(-1)
|
442 |
+
|
443 |
+
attention_mask = kwargs.get("attention_mask", None)
|
444 |
+
position_ids = kwargs.get("position_ids", None)
|
445 |
+
|
446 |
+
if attention_mask is not None and position_ids is None:
|
447 |
+
# create position_ids on the fly for batch generation
|
448 |
+
position_ids = attention_mask.long().cumsum(-1) - 1
|
449 |
+
position_ids.masked_fill_(attention_mask == 0, 1)
|
450 |
+
else:
|
451 |
+
position_ids = None
|
452 |
+
return {
|
453 |
+
"text": input_ids,
|
454 |
+
"images": image_inputs,
|
455 |
+
"past_key_values": past,
|
456 |
+
"position_ids": position_ids,
|
457 |
+
"attention_mask": attention_mask,
|
458 |
+
}
|
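The CoCa class above is the upstream open_clip implementation vendored for this Space. As a hedged illustration only (not part of this commit), caption generation is normally driven through the factory further below; the config name, weights, and image path here are placeholder assumptions.

# Hypothetical usage sketch for CoCa.generate(); the config name, weights and image
# path are placeholders, and both this vendored package and `transformers`
# (required for beam search) are assumed to be installed/importable.
import torch
from PIL import Image
from ext.open_clip.factory import create_model_and_transforms

model, _, preprocess = create_model_and_transforms("coca_ViT-B-32", pretrained=None)
model.eval()  # supply real pretrained weights for meaningful captions

image = preprocess(Image.open("example.jpg")).unsqueeze(0)   # (1, 3, H, W)
with torch.no_grad():
    # beam search is the default; "top_k" / "top_p" sampling are also supported
    token_ids = model.generate(image, seq_len=30, generation_type="beam_search")
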
ext/open_clip/constants.py
ADDED
@@ -0,0 +1,2 @@
OPENAI_DATASET_MEAN = (0.48145466, 0.4578275, 0.40821073)
OPENAI_DATASET_STD = (0.26862954, 0.26130258, 0.27577711)

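These are the RGB mean/std used for OpenAI CLIP preprocessing. A minimal sketch of how they are typically consumed (torchvision assumed available; not part of the diff):

# Normalization sketch using the constants above.
from torchvision import transforms
from ext.open_clip.constants import OPENAI_DATASET_MEAN, OPENAI_DATASET_STD

preprocess = transforms.Compose([
    transforms.Resize(224),
    transforms.CenterCrop(224),
    transforms.ToTensor(),                                  # PIL image -> float CHW in [0, 1]
    transforms.Normalize(OPENAI_DATASET_MEAN, OPENAI_DATASET_STD),
])
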
ext/open_clip/factory.py
ADDED
@@ -0,0 +1,387 @@
1 |
+
import json
|
2 |
+
import logging
|
3 |
+
import os
|
4 |
+
import pathlib
|
5 |
+
import re
|
6 |
+
from copy import deepcopy
|
7 |
+
from pathlib import Path
|
8 |
+
from typing import Any, Dict, Optional, Tuple, Union
|
9 |
+
|
10 |
+
import torch
|
11 |
+
|
12 |
+
from .constants import OPENAI_DATASET_MEAN, OPENAI_DATASET_STD
|
13 |
+
from .model import CLIP, CustomTextCLIP, convert_weights_to_lp, convert_to_custom_text_state_dict,\
|
14 |
+
resize_pos_embed, get_cast_dtype
|
15 |
+
from .coca_model import CoCa
|
16 |
+
from .loss import ClipLoss, DistillClipLoss, CoCaLoss
|
17 |
+
from .openai import load_openai_model
|
18 |
+
from .pretrained import is_pretrained_cfg, get_pretrained_cfg, download_pretrained,\
|
19 |
+
list_pretrained_tags_by_model, download_pretrained_from_hf
|
20 |
+
from .transform import image_transform, AugmentationCfg
|
21 |
+
from .tokenizer import HFTokenizer, tokenize
|
22 |
+
|
23 |
+
|
24 |
+
HF_HUB_PREFIX = 'hf-hub:'
|
25 |
+
_MODEL_CONFIG_PATHS = [Path(__file__).parent / f"model_configs/"]
|
26 |
+
_MODEL_CONFIGS = {} # dictionary (model_name: config) of model architecture configs
|
27 |
+
|
28 |
+
|
29 |
+
def _natural_key(string_):
|
30 |
+
return [int(s) if s.isdigit() else s for s in re.split(r'(\d+)', string_.lower())]
|
31 |
+
|
32 |
+
|
33 |
+
def _rescan_model_configs():
|
34 |
+
global _MODEL_CONFIGS
|
35 |
+
|
36 |
+
config_ext = ('.json',)
|
37 |
+
config_files = []
|
38 |
+
for config_path in _MODEL_CONFIG_PATHS:
|
39 |
+
if config_path.is_file() and config_path.suffix in config_ext:
|
40 |
+
config_files.append(config_path)
|
41 |
+
elif config_path.is_dir():
|
42 |
+
for ext in config_ext:
|
43 |
+
config_files.extend(config_path.glob(f'*{ext}'))
|
44 |
+
|
45 |
+
for cf in config_files:
|
46 |
+
with open(cf, 'r') as f:
|
47 |
+
model_cfg = json.load(f)
|
48 |
+
if all(a in model_cfg for a in ('embed_dim', 'vision_cfg', 'text_cfg')):
|
49 |
+
_MODEL_CONFIGS[cf.stem] = model_cfg
|
50 |
+
|
51 |
+
_MODEL_CONFIGS = {k: v for k, v in sorted(_MODEL_CONFIGS.items(), key=lambda x: _natural_key(x[0]))}
|
52 |
+
|
53 |
+
|
54 |
+
_rescan_model_configs() # initial populate of model config registry
|
55 |
+
|
56 |
+
|
57 |
+
def list_models():
|
58 |
+
""" enumerate available model architectures based on config files """
|
59 |
+
return list(_MODEL_CONFIGS.keys())
|
60 |
+
|
61 |
+
|
62 |
+
def add_model_config(path):
|
63 |
+
""" add model config path or file and update registry """
|
64 |
+
if not isinstance(path, Path):
|
65 |
+
path = Path(path)
|
66 |
+
_MODEL_CONFIG_PATHS.append(path)
|
67 |
+
_rescan_model_configs()
|
68 |
+
|
69 |
+
|
70 |
+
def get_model_config(model_name):
|
71 |
+
if model_name in _MODEL_CONFIGS:
|
72 |
+
return deepcopy(_MODEL_CONFIGS[model_name])
|
73 |
+
else:
|
74 |
+
return None
|
75 |
+
|
76 |
+
|
77 |
+
def get_tokenizer(model_name):
|
78 |
+
if model_name.startswith(HF_HUB_PREFIX):
|
79 |
+
tokenizer = HFTokenizer(model_name[len(HF_HUB_PREFIX):])
|
80 |
+
else:
|
81 |
+
config = get_model_config(model_name)
|
82 |
+
tokenizer = HFTokenizer(
|
83 |
+
config['text_cfg']['hf_tokenizer_name']) if 'hf_tokenizer_name' in config['text_cfg'] else tokenize
|
84 |
+
return tokenizer
|
85 |
+
|
86 |
+
|
87 |
+
def load_state_dict(checkpoint_path: str, map_location='cpu'):
|
88 |
+
checkpoint = torch.load(checkpoint_path, map_location=map_location)
|
89 |
+
if isinstance(checkpoint, dict) and 'state_dict' in checkpoint:
|
90 |
+
state_dict = checkpoint['state_dict']
|
91 |
+
else:
|
92 |
+
state_dict = checkpoint
|
93 |
+
if next(iter(state_dict.items()))[0].startswith('module'):
|
94 |
+
state_dict = {k[7:]: v for k, v in state_dict.items()}
|
95 |
+
return state_dict
|
96 |
+
|
97 |
+
|
98 |
+
def load_checkpoint(model, checkpoint_path, strict=True):
|
99 |
+
state_dict = load_state_dict(checkpoint_path)
|
100 |
+
# detect old format and make compatible with new format
|
101 |
+
if 'positional_embedding' in state_dict and not hasattr(model, 'positional_embedding'):
|
102 |
+
state_dict = convert_to_custom_text_state_dict(state_dict)
|
103 |
+
resize_pos_embed(state_dict, model)
|
104 |
+
incompatible_keys = model.load_state_dict(state_dict, strict=strict)
|
105 |
+
return incompatible_keys
|
106 |
+
|
107 |
+
|
108 |
+
def create_model(
|
109 |
+
model_name: str,
|
110 |
+
pretrained: Optional[str] = None,
|
111 |
+
precision: str = 'fp32',
|
112 |
+
device: Union[str, torch.device] = 'cpu',
|
113 |
+
jit: bool = False,
|
114 |
+
force_quick_gelu: bool = False,
|
115 |
+
force_custom_text: bool = False,
|
116 |
+
force_patch_dropout: Optional[float] = None,
|
117 |
+
force_image_size: Optional[Union[int, Tuple[int, int]]] = None,
|
118 |
+
pretrained_image: bool = False,
|
119 |
+
pretrained_hf: bool = True,
|
120 |
+
cache_dir: Optional[str] = None,
|
121 |
+
output_dict: Optional[bool] = None,
|
122 |
+
require_pretrained: bool = False,
|
123 |
+
logger: logging.Logger = logging,
|
124 |
+
):
|
125 |
+
has_hf_hub_prefix = model_name.startswith(HF_HUB_PREFIX)
|
126 |
+
if has_hf_hub_prefix:
|
127 |
+
model_id = model_name[len(HF_HUB_PREFIX):]
|
128 |
+
checkpoint_path = download_pretrained_from_hf(model_id, cache_dir=cache_dir)
|
129 |
+
config_path = download_pretrained_from_hf(model_id, filename='open_clip_config.json', cache_dir=cache_dir)
|
130 |
+
|
131 |
+
with open(config_path, 'r', encoding='utf-8') as f:
|
132 |
+
config = json.load(f)
|
133 |
+
pretrained_cfg = config['preprocess_cfg']
|
134 |
+
model_cfg = config['model_cfg']
|
135 |
+
else:
|
136 |
+
model_name = model_name.replace('/', '-') # for callers using old naming with / in ViT names
|
137 |
+
checkpoint_path = None
|
138 |
+
pretrained_cfg = {}
|
139 |
+
model_cfg = None
|
140 |
+
|
141 |
+
if isinstance(device, str):
|
142 |
+
device = torch.device(device)
|
143 |
+
|
144 |
+
if pretrained and pretrained.lower() == 'openai':
|
145 |
+
logger.info(f'Loading pretrained {model_name} from OpenAI.')
|
146 |
+
model = load_openai_model(
|
147 |
+
model_name,
|
148 |
+
precision=precision,
|
149 |
+
device=device,
|
150 |
+
cache_dir=cache_dir,
|
151 |
+
)
|
152 |
+
else:
|
153 |
+
model_cfg = model_cfg or get_model_config(model_name)
|
154 |
+
if model_cfg is not None:
|
155 |
+
logger.info(f'Loaded {model_name} model config.')
|
156 |
+
else:
|
157 |
+
logger.error(f'Model config for {model_name} not found; available models {list_models()}.')
|
158 |
+
raise RuntimeError(f'Model config for {model_name} not found.')
|
159 |
+
|
160 |
+
if force_quick_gelu:
|
161 |
+
# override for use of QuickGELU on non-OpenAI transformer models
|
162 |
+
model_cfg["quick_gelu"] = True
|
163 |
+
|
164 |
+
if force_patch_dropout is not None:
|
165 |
+
# override the default patch dropout value
|
166 |
+
model_cfg["vision_cfg"]["patch_dropout"] = force_patch_dropout
|
167 |
+
|
168 |
+
if force_image_size is not None:
|
169 |
+
# override model config's image size
|
170 |
+
model_cfg["vision_cfg"]["image_size"] = force_image_size
|
171 |
+
|
172 |
+
is_timm_model = 'timm_model_name' in model_cfg.get('vision_cfg', {})
|
173 |
+
if pretrained_image:
|
174 |
+
if is_timm_model:
|
175 |
+
# pretrained weight loading for timm models set via vision_cfg
|
176 |
+
model_cfg['vision_cfg']['timm_model_pretrained'] = True
|
177 |
+
else:
|
178 |
+
assert False, 'pretrained image towers currently only supported for timm models'
|
179 |
+
|
180 |
+
# cast_dtype set for fp16 and bf16 (manual mixed-precision), not set for 'amp' or 'pure' modes
|
181 |
+
cast_dtype = get_cast_dtype(precision)
|
182 |
+
is_hf_model = 'hf_model_name' in model_cfg.get('text_cfg', {})
|
183 |
+
custom_text = model_cfg.pop('custom_text', False) or force_custom_text or is_hf_model
|
184 |
+
|
185 |
+
if custom_text:
|
186 |
+
if is_hf_model:
|
187 |
+
model_cfg['text_cfg']['hf_model_pretrained'] = pretrained_hf
|
188 |
+
if "coca" in model_name:
|
189 |
+
model = CoCa(**model_cfg, cast_dtype=cast_dtype)
|
190 |
+
else:
|
191 |
+
model = CustomTextCLIP(**model_cfg, cast_dtype=cast_dtype)
|
192 |
+
else:
|
193 |
+
model = CLIP(**model_cfg, cast_dtype=cast_dtype)
|
194 |
+
|
195 |
+
if precision in ("fp16", "bf16"):
|
196 |
+
dtype = torch.float16 if 'fp16' in precision else torch.bfloat16
|
197 |
+
# manual mixed precision that matches original OpenAI behaviour
|
198 |
+
if is_timm_model:
|
199 |
+
# FIXME this is a bit janky, create timm based model in low-precision and
|
200 |
+
# then cast only LayerNormFp32 instances back to float32 so they don't break.
|
201 |
+
# Why? The convert_weights_to_lp fn only works with native models.
|
202 |
+
model.to(device=device, dtype=dtype)
|
203 |
+
from .transformer import LayerNormFp32
|
204 |
+
def _convert_ln(m):
|
205 |
+
if isinstance(m, LayerNormFp32):
|
206 |
+
m.weight.data = m.weight.data.to(torch.float32)
|
207 |
+
m.bias.data = m.bias.data.to(torch.float32)
|
208 |
+
model.apply(_convert_ln)
|
209 |
+
else:
|
210 |
+
model.to(device=device)
|
211 |
+
convert_weights_to_lp(model, dtype=dtype)
|
212 |
+
elif precision in ("pure_fp16", "pure_bf16"):
|
213 |
+
dtype = torch.float16 if 'fp16' in precision else torch.bfloat16
|
214 |
+
model.to(device=device, dtype=dtype)
|
215 |
+
else:
|
216 |
+
model.to(device=device)
|
217 |
+
|
218 |
+
pretrained_loaded = False
|
219 |
+
if pretrained:
|
220 |
+
checkpoint_path = ''
|
221 |
+
pretrained_cfg = get_pretrained_cfg(model_name, pretrained)
|
222 |
+
if pretrained_cfg:
|
223 |
+
checkpoint_path = download_pretrained(pretrained_cfg, cache_dir=cache_dir)
|
224 |
+
elif os.path.exists(pretrained):
|
225 |
+
checkpoint_path = pretrained
|
226 |
+
|
227 |
+
if checkpoint_path:
|
228 |
+
logger.info(f'Loading pretrained {model_name} weights ({pretrained}).')
|
229 |
+
load_checkpoint(model, checkpoint_path)
|
230 |
+
else:
|
231 |
+
error_str = (
|
232 |
+
f'Pretrained weights ({pretrained}) not found for model {model_name}.'
|
233 |
+
f' Available pretrained tags ({list_pretrained_tags_by_model(model_name)}).')
|
234 |
+
logger.warning(error_str)
|
235 |
+
raise RuntimeError(error_str)
|
236 |
+
pretrained_loaded = True
|
237 |
+
elif has_hf_hub_prefix:
|
238 |
+
logger.info(f'Loading pretrained {model_name} weights ({pretrained}).')
|
239 |
+
load_checkpoint(model, checkpoint_path)
|
240 |
+
pretrained_loaded = True
|
241 |
+
|
242 |
+
if require_pretrained and not pretrained_loaded:
|
243 |
+
# callers of create_model_from_pretrained always expect pretrained weights
|
244 |
+
raise RuntimeError(
|
245 |
+
f'Pretrained weights were required for (model: {model_name}, pretrained: {pretrained}) but not loaded.')
|
246 |
+
|
247 |
+
# set image / mean metadata from pretrained_cfg if available, or use default
|
248 |
+
model.visual.image_mean = pretrained_cfg.get('mean', None) or OPENAI_DATASET_MEAN
|
249 |
+
model.visual.image_std = pretrained_cfg.get('std', None) or OPENAI_DATASET_STD
|
250 |
+
|
251 |
+
if output_dict and hasattr(model, "output_dict"):
|
252 |
+
model.output_dict = True
|
253 |
+
|
254 |
+
if jit:
|
255 |
+
model = torch.jit.script(model)
|
256 |
+
|
257 |
+
return model
|
258 |
+
|
259 |
+
|
260 |
+
def create_loss(args):
|
261 |
+
if args.distill:
|
262 |
+
return DistillClipLoss(
|
263 |
+
local_loss=args.local_loss,
|
264 |
+
gather_with_grad=args.gather_with_grad,
|
265 |
+
cache_labels=True,
|
266 |
+
rank=args.rank,
|
267 |
+
world_size=args.world_size,
|
268 |
+
use_horovod=args.horovod,
|
269 |
+
)
|
270 |
+
elif "coca" in args.model.lower():
|
271 |
+
return CoCaLoss(
|
272 |
+
caption_loss_weight=args.coca_caption_loss_weight,
|
273 |
+
clip_loss_weight=args.coca_contrastive_loss_weight,
|
274 |
+
local_loss=args.local_loss,
|
275 |
+
gather_with_grad=args.gather_with_grad,
|
276 |
+
cache_labels=True,
|
277 |
+
rank=args.rank,
|
278 |
+
world_size=args.world_size,
|
279 |
+
use_horovod=args.horovod,
|
280 |
+
)
|
281 |
+
return ClipLoss(
|
282 |
+
local_loss=args.local_loss,
|
283 |
+
gather_with_grad=args.gather_with_grad,
|
284 |
+
cache_labels=True,
|
285 |
+
rank=args.rank,
|
286 |
+
world_size=args.world_size,
|
287 |
+
use_horovod=args.horovod,
|
288 |
+
)
|
289 |
+
|
290 |
+
|
291 |
+
def create_model_and_transforms(
|
292 |
+
model_name: str,
|
293 |
+
pretrained: Optional[str] = None,
|
294 |
+
precision: str = 'fp32',
|
295 |
+
device: Union[str, torch.device] = 'cpu',
|
296 |
+
jit: bool = False,
|
297 |
+
force_quick_gelu: bool = False,
|
298 |
+
force_custom_text: bool = False,
|
299 |
+
force_patch_dropout: Optional[float] = None,
|
300 |
+
force_image_size: Optional[Union[int, Tuple[int, int]]] = None,
|
301 |
+
pretrained_image: bool = False,
|
302 |
+
pretrained_hf: bool = True,
|
303 |
+
image_mean: Optional[Tuple[float, ...]] = None,
|
304 |
+
image_std: Optional[Tuple[float, ...]] = None,
|
305 |
+
aug_cfg: Optional[Union[Dict[str, Any], AugmentationCfg]] = None,
|
306 |
+
cache_dir: Optional[str] = None,
|
307 |
+
output_dict: Optional[bool] = None,
|
308 |
+
logger: logging.Logger = logging,
|
309 |
+
):
|
310 |
+
model = create_model(
|
311 |
+
model_name,
|
312 |
+
pretrained,
|
313 |
+
precision=precision,
|
314 |
+
device=device,
|
315 |
+
jit=jit,
|
316 |
+
force_quick_gelu=force_quick_gelu,
|
317 |
+
force_custom_text=force_custom_text,
|
318 |
+
force_patch_dropout=force_patch_dropout,
|
319 |
+
force_image_size=force_image_size,
|
320 |
+
pretrained_image=pretrained_image,
|
321 |
+
pretrained_hf=pretrained_hf,
|
322 |
+
cache_dir=cache_dir,
|
323 |
+
output_dict=output_dict,
|
324 |
+
logger=logger,
|
325 |
+
)
|
326 |
+
|
327 |
+
image_mean = image_mean or getattr(model.visual, 'image_mean', None)
|
328 |
+
image_std = image_std or getattr(model.visual, 'image_std', None)
|
329 |
+
preprocess_train = image_transform(
|
330 |
+
model.visual.image_size,
|
331 |
+
is_train=True,
|
332 |
+
mean=image_mean,
|
333 |
+
std=image_std,
|
334 |
+
aug_cfg=aug_cfg,
|
335 |
+
)
|
336 |
+
preprocess_val = image_transform(
|
337 |
+
model.visual.image_size,
|
338 |
+
is_train=False,
|
339 |
+
mean=image_mean,
|
340 |
+
std=image_std,
|
341 |
+
)
|
342 |
+
|
343 |
+
return model, preprocess_train, preprocess_val
|
344 |
+
|
345 |
+
|
346 |
+
def create_model_from_pretrained(
|
347 |
+
model_name: str,
|
348 |
+
pretrained: Optional[str] = None,
|
349 |
+
precision: str = 'fp32',
|
350 |
+
device: Union[str, torch.device] = 'cpu',
|
351 |
+
jit: bool = False,
|
352 |
+
force_quick_gelu: bool = False,
|
353 |
+
force_custom_text: bool = False,
|
354 |
+
force_image_size: Optional[Union[int, Tuple[int, int]]] = None,
|
355 |
+
return_transform: bool = True,
|
356 |
+
image_mean: Optional[Tuple[float, ...]] = None,
|
357 |
+
image_std: Optional[Tuple[float, ...]] = None,
|
358 |
+
cache_dir: Optional[str] = None,
|
359 |
+
logger: logging.Logger = logging,
|
360 |
+
):
|
361 |
+
model = create_model(
|
362 |
+
model_name,
|
363 |
+
pretrained,
|
364 |
+
precision=precision,
|
365 |
+
device=device,
|
366 |
+
jit=jit,
|
367 |
+
force_quick_gelu=force_quick_gelu,
|
368 |
+
force_custom_text=force_custom_text,
|
369 |
+
force_image_size=force_image_size,
|
370 |
+
cache_dir=cache_dir,
|
371 |
+
require_pretrained=True,
|
372 |
+
logger=logger,
|
373 |
+
)
|
374 |
+
|
375 |
+
if not return_transform:
|
376 |
+
return model
|
377 |
+
|
378 |
+
image_mean = image_mean or getattr(model.visual, 'image_mean', None)
|
379 |
+
image_std = image_std or getattr(model.visual, 'image_std', None)
|
380 |
+
preprocess = image_transform(
|
381 |
+
model.visual.image_size,
|
382 |
+
is_train=False,
|
383 |
+
mean=image_mean,
|
384 |
+
std=image_std,
|
385 |
+
)
|
386 |
+
|
387 |
+
return model, preprocess
|
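For orientation, a hedged usage sketch of the factory entry points defined above (create_model_and_transforms, get_tokenizer). The model name, pretrained tag, and image path are illustrative assumptions, not values pinned by this commit.

# Illustrative zero-shot scoring sketch; "ViT-B-16" and the image path are placeholders.
import torch
from PIL import Image
from ext.open_clip.factory import create_model_and_transforms, get_tokenizer

model, _, preprocess_val = create_model_and_transforms("ViT-B-16", pretrained=None, device="cpu")
tokenizer = get_tokenizer("ViT-B-16")

image = preprocess_val(Image.open("example.jpg")).unsqueeze(0)
text = tokenizer(["a diagram", "a dog", "a cat"])

with torch.no_grad():
    image_features = torch.nn.functional.normalize(model.encode_image(image), dim=-1)
    text_features = torch.nn.functional.normalize(model.encode_text(text), dim=-1)
    probs = (100.0 * image_features @ text_features.T).softmax(dim=-1)
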
ext/open_clip/generation_utils.py
ADDED
File without changes
|
ext/open_clip/hf_configs.py
ADDED
@@ -0,0 +1,56 @@
# HF architecture dict:
arch_dict = {
    # https://huggingface.co/docs/transformers/model_doc/roberta#roberta
    "roberta": {
        "config_names": {
            "context_length": "max_position_embeddings",
            "vocab_size": "vocab_size",
            "width": "hidden_size",
            "heads": "num_attention_heads",
            "layers": "num_hidden_layers",
            "layer_attr": "layer",
            "token_embeddings_attr": "embeddings"
        },
        "pooler": "mean_pooler",
    },
    # https://huggingface.co/docs/transformers/model_doc/xlm-roberta#transformers.XLMRobertaConfig
    "xlm-roberta": {
        "config_names": {
            "context_length": "max_position_embeddings",
            "vocab_size": "vocab_size",
            "width": "hidden_size",
            "heads": "num_attention_heads",
            "layers": "num_hidden_layers",
            "layer_attr": "layer",
            "token_embeddings_attr": "embeddings"
        },
        "pooler": "mean_pooler",
    },
    # https://huggingface.co/docs/transformers/model_doc/mt5#mt5
    "mt5": {
        "config_names": {
            # unlimited seqlen
            # https://github.com/google-research/text-to-text-transfer-transformer/issues/273
            # https://github.com/huggingface/transformers/blob/v4.24.0/src/transformers/models/t5/modeling_t5.py#L374
            "context_length": "",
            "vocab_size": "vocab_size",
            "width": "d_model",
            "heads": "num_heads",
            "layers": "num_layers",
            "layer_attr": "block",
            "token_embeddings_attr": "embed_tokens"
        },
        "pooler": "mean_pooler",
    },
    # https://huggingface.co/docs/transformers/model_doc/bert
    "bert": {
        "config_names": {
            "context_length": "max_position_embeddings",
            "vocab_size": "vocab_size",
            "width": "hidden_size",
            "heads": "num_attention_heads",
            "layers": "num_hidden_layers",
        },
        "pooler": "cls_pooler",
    },
}

ext/open_clip/hf_model.py
ADDED
@@ -0,0 +1,193 @@
1 |
+
""" huggingface model adapter
|
2 |
+
|
3 |
+
Wraps HuggingFace transformers (https://github.com/huggingface/transformers) models for use as a text tower in a CLIP model.
|
4 |
+
"""
|
5 |
+
import re
|
6 |
+
|
7 |
+
import torch
|
8 |
+
import torch.nn as nn
|
9 |
+
from torch import TensorType
|
10 |
+
|
11 |
+
try:
|
12 |
+
import transformers
|
13 |
+
from transformers import AutoModel, AutoTokenizer, AutoConfig, PretrainedConfig
|
14 |
+
from transformers.modeling_outputs import BaseModelOutput, BaseModelOutputWithPooling, \
|
15 |
+
BaseModelOutputWithPoolingAndCrossAttentions
|
16 |
+
except ImportError as e:
|
17 |
+
transformers = None
|
18 |
+
|
19 |
+
|
20 |
+
class BaseModelOutput:
|
21 |
+
pass
|
22 |
+
|
23 |
+
|
24 |
+
class PretrainedConfig:
|
25 |
+
pass
|
26 |
+
|
27 |
+
from .hf_configs import arch_dict
|
28 |
+
|
29 |
+
|
30 |
+
# utils
|
31 |
+
def _camel2snake(s):
|
32 |
+
return re.sub(r'(?<!^)(?=[A-Z])', '_', s).lower()
|
33 |
+
|
34 |
+
|
35 |
+
# TODO: ?last - for gpt-like models
|
36 |
+
_POOLERS = {}
|
37 |
+
|
38 |
+
|
39 |
+
def register_pooler(cls):
|
40 |
+
"""Decorator registering pooler class"""
|
41 |
+
_POOLERS[_camel2snake(cls.__name__)] = cls
|
42 |
+
return cls
|
43 |
+
|
44 |
+
|
45 |
+
@register_pooler
|
46 |
+
class MeanPooler(nn.Module):
|
47 |
+
"""Mean pooling"""
|
48 |
+
|
49 |
+
def forward(self, x: BaseModelOutput, attention_mask: TensorType):
|
50 |
+
masked_output = x.last_hidden_state * attention_mask.unsqueeze(-1)
|
51 |
+
return masked_output.sum(dim=1) / attention_mask.sum(-1, keepdim=True)
|
52 |
+
|
53 |
+
|
54 |
+
@register_pooler
|
55 |
+
class MaxPooler(nn.Module):
|
56 |
+
"""Max pooling"""
|
57 |
+
|
58 |
+
def forward(self, x: BaseModelOutput, attention_mask: TensorType):
|
59 |
+
masked_output = x.last_hidden_state.masked_fill(attention_mask.unsqueeze(-1), -torch.inf)
|
60 |
+
return masked_output.max(1).values
|
61 |
+
|
62 |
+
|
63 |
+
@register_pooler
|
64 |
+
class ClsPooler(nn.Module):
|
65 |
+
"""CLS token pooling"""
|
66 |
+
|
67 |
+
def __init__(self, use_pooler_output=True):
|
68 |
+
super().__init__()
|
69 |
+
self.cls_token_position = 0
|
70 |
+
self.use_pooler_output = use_pooler_output
|
71 |
+
|
72 |
+
def forward(self, x: BaseModelOutput, attention_mask: TensorType):
|
73 |
+
if (self.use_pooler_output and
|
74 |
+
isinstance(x, (BaseModelOutputWithPooling, BaseModelOutputWithPoolingAndCrossAttentions)) and
|
75 |
+
(x.pooler_output is not None)
|
76 |
+
):
|
77 |
+
return x.pooler_output
|
78 |
+
|
79 |
+
return x.last_hidden_state[:, self.cls_token_position, :]
|
80 |
+
|
81 |
+
|
82 |
+
@register_pooler
|
83 |
+
class ClsLastHiddenStatePooler(nn.Module):
|
84 |
+
"""CLS token pooling
|
85 |
+
NOTE: this is equivalent to ClsPooler above with use_pooler_output=False
|
86 |
+
"""
|
87 |
+
|
88 |
+
def __init__(self):
|
89 |
+
super().__init__()
|
90 |
+
self.cls_token_position = 0
|
91 |
+
|
92 |
+
def forward(self, x: BaseModelOutput, attention_mask: TensorType):
|
93 |
+
return x.last_hidden_state[:, self.cls_token_position, :]
|
94 |
+
|
95 |
+
|
96 |
+
class HFTextEncoder(nn.Module):
|
97 |
+
"""HuggingFace model adapter"""
|
98 |
+
output_tokens: torch.jit.Final[bool]
|
99 |
+
|
100 |
+
def __init__(
|
101 |
+
self,
|
102 |
+
model_name_or_path: str,
|
103 |
+
output_dim: int,
|
104 |
+
config: PretrainedConfig = None,
|
105 |
+
pooler_type: str = None,
|
106 |
+
proj: str = None,
|
107 |
+
pretrained: bool = True,
|
108 |
+
output_tokens: bool = False,
|
109 |
+
):
|
110 |
+
super().__init__()
|
111 |
+
self.output_tokens = output_tokens
|
112 |
+
self.output_dim = output_dim
|
113 |
+
|
114 |
+
# TODO: find better way to get this information
|
115 |
+
uses_transformer_pooler = (pooler_type == "cls_pooler")
|
116 |
+
|
117 |
+
if transformers is None:
|
118 |
+
raise RuntimeError("Please `pip install transformers` to use pre-trained HuggingFace models")
|
119 |
+
if config is None:
|
120 |
+
self.config = AutoConfig.from_pretrained(model_name_or_path)
|
121 |
+
create_func, model_args = (AutoModel.from_pretrained, model_name_or_path) if pretrained else (
|
122 |
+
AutoModel.from_config, self.config)
|
123 |
+
# TODO: do all model configs have this attribute? PretrainedConfig does so yes??
|
124 |
+
if hasattr(self.config, "is_encoder_decoder") and self.config.is_encoder_decoder:
|
125 |
+
self.transformer = create_func(model_args)
|
126 |
+
self.transformer = self.transformer.encoder
|
127 |
+
else:
|
128 |
+
self.transformer = create_func(model_args, add_pooling_layer=uses_transformer_pooler)
|
129 |
+
else:
|
130 |
+
self.config = config
|
131 |
+
self.transformer = AutoModel.from_config(config)
|
132 |
+
if pooler_type is None: # get default arch pooler
|
133 |
+
pooler_type = (arch_dict[self.config.model_type]["pooler"])
|
134 |
+
|
135 |
+
# FIXME downstream users of OpenCLIP models use these attr, need to verify valid across all models
|
136 |
+
self.vocab_size = getattr(self.config, 'vocab_size', 0)
|
137 |
+
self.context_length = getattr(self.config, 'max_position_embeddings', 0)
|
138 |
+
|
139 |
+
self.pooler = _POOLERS[pooler_type]()
|
140 |
+
|
141 |
+
d_model = getattr(self.config, arch_dict[self.config.model_type]["config_names"]["width"])
|
142 |
+
if (d_model == output_dim) and (proj is None): # do we always need a proj?
|
143 |
+
self.proj = nn.Identity()
|
144 |
+
elif proj == 'linear':
|
145 |
+
self.proj = nn.Linear(d_model, output_dim, bias=False)
|
146 |
+
elif proj == 'mlp':
|
147 |
+
hidden_size = (d_model + output_dim) // 2
|
148 |
+
self.proj = nn.Sequential(
|
149 |
+
nn.Linear(d_model, hidden_size, bias=False),
|
150 |
+
nn.GELU(),
|
151 |
+
nn.Linear(hidden_size, output_dim, bias=False),
|
152 |
+
)
|
153 |
+
|
154 |
+
def forward(self, x: TensorType):
|
155 |
+
attn_mask = (x != self.config.pad_token_id).long()
|
156 |
+
out = self.transformer(input_ids=x, attention_mask=attn_mask)
|
157 |
+
pooled_out = self.pooler(out, attn_mask)
|
158 |
+
projected = self.proj(pooled_out)
|
159 |
+
|
160 |
+
seq_len = out.last_hidden_state.shape[1]
|
161 |
+
tokens = (
|
162 |
+
out.last_hidden_state[:, torch.arange(seq_len) != self.pooler.cls_token_position, :]
|
163 |
+
if type(self.pooler) == ClsPooler
|
164 |
+
else out.last_hidden_state
|
165 |
+
)
|
166 |
+
|
167 |
+
if self.output_tokens:
|
168 |
+
return projected, tokens
|
169 |
+
return projected
|
170 |
+
|
171 |
+
def lock(self, unlocked_layers: int = 0, freeze_layer_norm: bool = True):
|
172 |
+
if not unlocked_layers: # full freezing
|
173 |
+
for n, p in self.transformer.named_parameters():
|
174 |
+
p.requires_grad = (not freeze_layer_norm) if "LayerNorm" in n.split(".") else False
|
175 |
+
return
|
176 |
+
|
177 |
+
encoder = self.transformer.encoder if hasattr(self.transformer, 'encoder') else self.transformer
|
178 |
+
layer_list = getattr(encoder, arch_dict[self.config.model_type]["config_names"]["layer_attr"])
|
179 |
+
print(f"Unlocking {unlocked_layers}/{len(layer_list) + 1} layers of hf model")
|
180 |
+
embeddings = getattr(
|
181 |
+
self.transformer, arch_dict[self.config.model_type]["config_names"]["token_embeddings_attr"])
|
182 |
+
modules = [embeddings, *layer_list][:-unlocked_layers]
|
183 |
+
# freeze layers
|
184 |
+
for module in modules:
|
185 |
+
for n, p in module.named_parameters():
|
186 |
+
p.requires_grad = (not freeze_layer_norm) if "LayerNorm" in n.split(".") else False
|
187 |
+
|
188 |
+
@torch.jit.ignore
|
189 |
+
def set_grad_checkpointing(self, enable=True):
|
190 |
+
self.transformer.gradient_checkpointing_enable()
|
191 |
+
|
192 |
+
def init_parameters(self):
|
193 |
+
pass
|
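A hedged standalone sketch of HFTextEncoder (normally it is constructed by the factory when text_cfg.hf_model_name is set). The checkpoint name is a placeholder and transformers must be installed.

# Standalone HFTextEncoder sketch; "roberta-base" is just an example checkpoint.
import torch
from transformers import AutoTokenizer
from ext.open_clip.hf_model import HFTextEncoder

encoder = HFTextEncoder(
    "roberta-base",              # any architecture covered by hf_configs.arch_dict
    output_dim=512,
    proj="mlp",
    pooler_type="mean_pooler",
    pretrained=True,
)
tokenizer = AutoTokenizer.from_pretrained("roberta-base")

batch = tokenizer(["a photo of a cat"], return_tensors="pt",
                  padding="max_length", max_length=77)
with torch.no_grad():
    text_features = encoder(batch["input_ids"])   # (1, 512) pooled, projected embedding
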
ext/open_clip/loss.py
ADDED
@@ -0,0 +1,216 @@
import torch
import torch.nn as nn
from torch.nn import functional as F

try:
    import torch.distributed.nn
    from torch import distributed as dist

    has_distributed = True
except ImportError:
    has_distributed = False

try:
    import horovod.torch as hvd
except ImportError:
    hvd = None


def gather_features(
        image_features,
        text_features,
        local_loss=False,
        gather_with_grad=False,
        rank=0,
        world_size=1,
        use_horovod=False
):
    assert has_distributed, 'torch.distributed did not import correctly, please use a PyTorch version with support.'
    if use_horovod:
        assert hvd is not None, 'Please install horovod'
        if gather_with_grad:
            all_image_features = hvd.allgather(image_features)
            all_text_features = hvd.allgather(text_features)
        else:
            with torch.no_grad():
                all_image_features = hvd.allgather(image_features)
                all_text_features = hvd.allgather(text_features)
            if not local_loss:
                # ensure grads for local rank when all_* features don't have a gradient
                gathered_image_features = list(all_image_features.chunk(world_size, dim=0))
                gathered_text_features = list(all_text_features.chunk(world_size, dim=0))
                gathered_image_features[rank] = image_features
                gathered_text_features[rank] = text_features
                all_image_features = torch.cat(gathered_image_features, dim=0)
                all_text_features = torch.cat(gathered_text_features, dim=0)
    else:
        # We gather tensors from all gpus
        if gather_with_grad:
            all_image_features = torch.cat(torch.distributed.nn.all_gather(image_features), dim=0)
            all_text_features = torch.cat(torch.distributed.nn.all_gather(text_features), dim=0)
        else:
            gathered_image_features = [torch.zeros_like(image_features) for _ in range(world_size)]
            gathered_text_features = [torch.zeros_like(text_features) for _ in range(world_size)]
            dist.all_gather(gathered_image_features, image_features)
            dist.all_gather(gathered_text_features, text_features)
            if not local_loss:
                # ensure grads for local rank when all_* features don't have a gradient
                gathered_image_features[rank] = image_features
                gathered_text_features[rank] = text_features
            all_image_features = torch.cat(gathered_image_features, dim=0)
            all_text_features = torch.cat(gathered_text_features, dim=0)

    return all_image_features, all_text_features


class ClipLoss(nn.Module):

    def __init__(
            self,
            local_loss=False,
            gather_with_grad=False,
            cache_labels=False,
            rank=0,
            world_size=1,
            use_horovod=False,
    ):
        super().__init__()
        self.local_loss = local_loss
        self.gather_with_grad = gather_with_grad
        self.cache_labels = cache_labels
        self.rank = rank
        self.world_size = world_size
        self.use_horovod = use_horovod

        # cache state
        self.prev_num_logits = 0
        self.labels = {}

    def get_ground_truth(self, device, num_logits) -> torch.Tensor:
        # calculated ground-truth and cache if enabled
        if self.prev_num_logits != num_logits or device not in self.labels:
            labels = torch.arange(num_logits, device=device, dtype=torch.long)
            if self.world_size > 1 and self.local_loss:
                labels = labels + num_logits * self.rank
            if self.cache_labels:
                self.labels[device] = labels
                self.prev_num_logits = num_logits
        else:
            labels = self.labels[device]
        return labels

    def get_logits(self, image_features, text_features, logit_scale):
        if self.world_size > 1:
            all_image_features, all_text_features = gather_features(
                image_features, text_features,
                self.local_loss, self.gather_with_grad, self.rank, self.world_size, self.use_horovod)

            if self.local_loss:
                logits_per_image = logit_scale * image_features @ all_text_features.T
                logits_per_text = logit_scale * text_features @ all_image_features.T
            else:
                logits_per_image = logit_scale * all_image_features @ all_text_features.T
                logits_per_text = logits_per_image.T
        else:
            logits_per_image = logit_scale * image_features @ text_features.T
            logits_per_text = logit_scale * text_features @ image_features.T

        return logits_per_image, logits_per_text

    def forward(self, image_features, text_features, logit_scale, output_dict=False):
        device = image_features.device
        logits_per_image, logits_per_text = self.get_logits(image_features, text_features, logit_scale)

        labels = self.get_ground_truth(device, logits_per_image.shape[0])

        total_loss = (
            F.cross_entropy(logits_per_image, labels) +
            F.cross_entropy(logits_per_text, labels)
        ) / 2

        return {"contrastive_loss": total_loss} if output_dict else total_loss


class CoCaLoss(ClipLoss):
    def __init__(
            self,
            caption_loss_weight,
            clip_loss_weight,
            pad_id=0,  # pad_token for open_clip custom tokenizer
            local_loss=False,
            gather_with_grad=False,
            cache_labels=False,
            rank=0,
            world_size=1,
            use_horovod=False,
    ):
        super().__init__(
            local_loss=local_loss,
            gather_with_grad=gather_with_grad,
            cache_labels=cache_labels,
            rank=rank,
            world_size=world_size,
            use_horovod=use_horovod
        )

        self.clip_loss_weight = clip_loss_weight
        self.caption_loss_weight = caption_loss_weight
        self.caption_loss = nn.CrossEntropyLoss(ignore_index=pad_id)

    def forward(self, image_features, text_features, logits, labels, logit_scale, output_dict=False):

        clip_loss = torch.tensor(0)

        if self.clip_loss_weight:
            clip_loss = super().forward(image_features, text_features, logit_scale)
            clip_loss = self.clip_loss_weight * clip_loss

        caption_loss = self.caption_loss(
            logits.permute(0, 2, 1),
            labels,
        )
        caption_loss = caption_loss * self.caption_loss_weight

        if output_dict:
            return {"contrastive_loss": clip_loss, "caption_loss": caption_loss}

        return clip_loss, caption_loss


class DistillClipLoss(ClipLoss):

    def dist_loss(self, teacher_logits, student_logits):
        return -(teacher_logits.softmax(dim=1) * student_logits.log_softmax(dim=1)).sum(dim=1).mean(dim=0)

    def forward(
            self,
            image_features,
            text_features,
            logit_scale,
            dist_image_features,
            dist_text_features,
            dist_logit_scale,
            output_dict=False,
    ):
        logits_per_image, logits_per_text = \
            self.get_logits(image_features, text_features, logit_scale)

        dist_logits_per_image, dist_logits_per_text = \
            self.get_logits(dist_image_features, dist_text_features, dist_logit_scale)

        labels = self.get_ground_truth(image_features.device, logits_per_image.shape[0])

        contrastive_loss = (
            F.cross_entropy(logits_per_image, labels) +
            F.cross_entropy(logits_per_text, labels)
        ) / 2

        distill_loss = (
            self.dist_loss(dist_logits_per_image, logits_per_image) +
            self.dist_loss(dist_logits_per_text, logits_per_text)
        ) / 2

        if output_dict:
            return {"contrastive_loss": contrastive_loss, "distill_loss": distill_loss}

        return contrastive_loss, distill_loss

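A hedged single-process sketch of ClipLoss (world_size=1, so gather_features is never called); the feature tensors below are random stand-ins for encoder outputs.

# Single-GPU/CPU ClipLoss sketch with random features standing in for model outputs.
import torch
import torch.nn.functional as F
from ext.open_clip.loss import ClipLoss

loss_fn = ClipLoss(local_loss=False, cache_labels=True, rank=0, world_size=1)

batch_size, dim = 8, 512
image_features = F.normalize(torch.randn(batch_size, dim), dim=-1)
text_features = F.normalize(torch.randn(batch_size, dim), dim=-1)
logit_scale = torch.tensor(1 / 0.07)   # the model exposes logit_scale.exp(), initially 1/0.07

loss = loss_fn(image_features, text_features, logit_scale)   # symmetric InfoNCE scalar
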
ext/open_clip/model.py
ADDED
@@ -0,0 +1,473 @@
1 |
+
""" CLIP Model
|
2 |
+
|
3 |
+
Adapted from https://github.com/openai/CLIP. Originally MIT License, Copyright (c) 2021 OpenAI.
|
4 |
+
"""
|
5 |
+
from dataclasses import dataclass
|
6 |
+
import logging
|
7 |
+
import math
|
8 |
+
from typing import Optional, Tuple, Union
|
9 |
+
|
10 |
+
import numpy as np
|
11 |
+
import torch
|
12 |
+
import torch.nn.functional as F
|
13 |
+
from torch import nn
|
14 |
+
from torch.utils.checkpoint import checkpoint
|
15 |
+
|
16 |
+
from .hf_model import HFTextEncoder
|
17 |
+
from .modified_resnet import ModifiedResNet
|
18 |
+
from .timm_model import TimmModel
|
19 |
+
from .transformer import LayerNormFp32, LayerNorm, QuickGELU, Attention, VisionTransformer, TextTransformer
|
20 |
+
from .utils import to_2tuple
|
21 |
+
|
22 |
+
|
23 |
+
@dataclass
|
24 |
+
class CLIPVisionCfg:
|
25 |
+
layers: Union[Tuple[int, int, int, int], int] = 12
|
26 |
+
width: int = 768
|
27 |
+
head_width: int = 64
|
28 |
+
mlp_ratio: float = 4.0
|
29 |
+
patch_size: int = 16
|
30 |
+
image_size: Union[Tuple[int, int], int] = 224
|
31 |
+
|
32 |
+
ls_init_value: Optional[float] = None # layer scale initial value
|
33 |
+
patch_dropout: float = 0. # what fraction of patches to dropout during training (0 would mean disabled and no patches dropped) - 0.5 to 0.75 recommended in the paper for optimal results
|
34 |
+
input_patchnorm: bool = False # whether to use dual patchnorm - would only apply the input layernorm on each patch, as post-layernorm already exist in original clip vit design
|
35 |
+
global_average_pool: bool = False # whether to global average pool the last embedding layer, instead of using CLS token (https://arxiv.org/abs/2205.01580)
|
36 |
+
attentional_pool: bool = False # whether to use attentional pooler in the last embedding layer
|
37 |
+
n_queries: int = 256 # n_queries for attentional pooler
|
38 |
+
attn_pooler_heads: int = 8 # n heads for attentional_pooling
|
39 |
+
output_tokens: bool = False
|
40 |
+
|
41 |
+
timm_model_name: str = None # a valid model name overrides layers, width, patch_size
|
42 |
+
timm_model_pretrained: bool = False # use (imagenet) pretrained weights for named model
|
43 |
+
timm_pool: str = 'avg' # feature pooling for timm model ('abs_attn', 'rot_attn', 'avg', '')
|
44 |
+
timm_proj: str = 'linear' # linear projection for timm model output ('linear', 'mlp', '')
|
45 |
+
timm_proj_bias: bool = False # enable bias final projection
|
46 |
+
timm_drop: float = 0. # head dropout
|
47 |
+
timm_drop_path: Optional[float] = None # backbone stochastic depth
|
48 |
+
|
49 |
+
|
50 |
+
@dataclass
|
51 |
+
class CLIPTextCfg:
|
52 |
+
context_length: int = 77
|
53 |
+
vocab_size: int = 49408
|
54 |
+
width: int = 512
|
55 |
+
heads: int = 8
|
56 |
+
layers: int = 12
|
57 |
+
ls_init_value: Optional[float] = None # layer scale initial value
|
58 |
+
hf_model_name: str = None
|
59 |
+
hf_tokenizer_name: str = None
|
60 |
+
hf_model_pretrained: bool = True
|
61 |
+
proj: str = 'mlp'
|
62 |
+
pooler_type: str = 'mean_pooler'
|
63 |
+
embed_cls: bool = False
|
64 |
+
pad_id: int = 0
|
65 |
+
output_tokens: bool = False
|
66 |
+
|
67 |
+
|
68 |
+
def get_cast_dtype(precision: str):
|
69 |
+
cast_dtype = None
|
70 |
+
if precision == 'bf16':
|
71 |
+
cast_dtype = torch.bfloat16
|
72 |
+
elif precision == 'fp16':
|
73 |
+
cast_dtype = torch.float16
|
74 |
+
return cast_dtype
|
75 |
+
|
76 |
+
|
77 |
+
def get_input_dtype(precision: str):
|
78 |
+
input_dtype = None
|
79 |
+
if precision in ('bf16', 'pure_bf16'):
|
80 |
+
input_dtype = torch.bfloat16
|
81 |
+
elif precision in ('fp16', 'pure_fp16'):
|
82 |
+
input_dtype = torch.float16
|
83 |
+
return input_dtype
|
84 |
+
|
85 |
+
|
86 |
+
def _build_vision_tower(
|
87 |
+
embed_dim: int,
|
88 |
+
vision_cfg: CLIPVisionCfg,
|
89 |
+
quick_gelu: bool = False,
|
90 |
+
cast_dtype: Optional[torch.dtype] = None
|
91 |
+
):
|
92 |
+
if isinstance(vision_cfg, dict):
|
93 |
+
vision_cfg = CLIPVisionCfg(**vision_cfg)
|
94 |
+
|
95 |
+
# OpenAI models are pretrained w/ QuickGELU but native nn.GELU is both faster and more
|
96 |
+
# memory efficient in recent PyTorch releases (>= 1.10).
|
97 |
+
# NOTE: timm models always use native GELU regardless of quick_gelu flag.
|
98 |
+
act_layer = QuickGELU if quick_gelu else nn.GELU
|
99 |
+
|
100 |
+
if vision_cfg.timm_model_name:
|
101 |
+
visual = TimmModel(
|
102 |
+
vision_cfg.timm_model_name,
|
103 |
+
pretrained=vision_cfg.timm_model_pretrained,
|
104 |
+
pool=vision_cfg.timm_pool,
|
105 |
+
proj=vision_cfg.timm_proj,
|
106 |
+
proj_bias=vision_cfg.timm_proj_bias,
|
107 |
+
drop=vision_cfg.timm_drop,
|
108 |
+
drop_path=vision_cfg.timm_drop_path,
|
109 |
+
patch_drop=vision_cfg.patch_dropout if vision_cfg.patch_dropout > 0 else None,
|
110 |
+
embed_dim=embed_dim,
|
111 |
+
image_size=vision_cfg.image_size,
|
112 |
+
)
|
113 |
+
elif isinstance(vision_cfg.layers, (tuple, list)):
|
114 |
+
vision_heads = vision_cfg.width * 32 // vision_cfg.head_width
|
115 |
+
visual = ModifiedResNet(
|
116 |
+
layers=vision_cfg.layers,
|
117 |
+
output_dim=embed_dim,
|
118 |
+
heads=vision_heads,
|
119 |
+
image_size=vision_cfg.image_size,
|
120 |
+
width=vision_cfg.width,
|
121 |
+
)
|
122 |
+
else:
|
123 |
+
vision_heads = vision_cfg.width // vision_cfg.head_width
|
124 |
+
norm_layer = LayerNormFp32 if cast_dtype in (torch.float16, torch.bfloat16) else LayerNorm
|
125 |
+
visual = VisionTransformer(
|
126 |
+
image_size=vision_cfg.image_size,
|
127 |
+
patch_size=vision_cfg.patch_size,
|
128 |
+
width=vision_cfg.width,
|
129 |
+
layers=vision_cfg.layers,
|
130 |
+
heads=vision_heads,
|
131 |
+
mlp_ratio=vision_cfg.mlp_ratio,
|
132 |
+
ls_init_value=vision_cfg.ls_init_value,
|
133 |
+
patch_dropout=vision_cfg.patch_dropout,
|
134 |
+
input_patchnorm=vision_cfg.input_patchnorm,
|
135 |
+
global_average_pool=vision_cfg.global_average_pool,
|
136 |
+
attentional_pool=vision_cfg.attentional_pool,
|
137 |
+
n_queries=vision_cfg.n_queries,
|
138 |
+
attn_pooler_heads=vision_cfg.attn_pooler_heads,
|
139 |
+
output_tokens=vision_cfg.output_tokens,
|
140 |
+
output_dim=embed_dim,
|
141 |
+
act_layer=act_layer,
|
142 |
+
norm_layer=norm_layer,
|
143 |
+
)
|
144 |
+
|
145 |
+
return visual
|
146 |
+
|
147 |
+
|
148 |
+
def _build_text_tower(
|
149 |
+
embed_dim: int,
|
150 |
+
text_cfg: CLIPTextCfg,
|
151 |
+
quick_gelu: bool = False,
|
152 |
+
cast_dtype: Optional[torch.dtype] = None,
|
153 |
+
):
|
154 |
+
if isinstance(text_cfg, dict):
|
155 |
+
text_cfg = CLIPTextCfg(**text_cfg)
|
156 |
+
|
157 |
+
if text_cfg.hf_model_name:
|
158 |
+
text = HFTextEncoder(
|
159 |
+
text_cfg.hf_model_name,
|
160 |
+
output_dim=embed_dim,
|
161 |
+
proj=text_cfg.proj,
|
162 |
+
pooler_type=text_cfg.pooler_type,
|
163 |
+
pretrained=text_cfg.hf_model_pretrained,
|
164 |
+
output_tokens=text_cfg.output_tokens,
|
165 |
+
)
|
166 |
+
else:
|
167 |
+
act_layer = QuickGELU if quick_gelu else nn.GELU
|
168 |
+
norm_layer = LayerNormFp32 if cast_dtype in (torch.float16, torch.bfloat16) else LayerNorm
|
169 |
+
|
170 |
+
text = TextTransformer(
|
171 |
+
context_length=text_cfg.context_length,
|
172 |
+
vocab_size=text_cfg.vocab_size,
|
173 |
+
width=text_cfg.width,
|
174 |
+
heads=text_cfg.heads,
|
175 |
+
layers=text_cfg.layers,
|
176 |
+
ls_init_value=text_cfg.ls_init_value,
|
177 |
+
output_dim=embed_dim,
|
178 |
+
embed_cls=text_cfg.embed_cls,
|
179 |
+
output_tokens=text_cfg.output_tokens,
|
180 |
+
pad_id=text_cfg.pad_id,
|
181 |
+
act_layer=act_layer,
|
182 |
+
norm_layer=norm_layer,
|
183 |
+
)
|
184 |
+
return text
|
185 |
+
|
186 |
+
|
187 |
+
class CLIP(nn.Module):
|
188 |
+
output_dict: torch.jit.Final[bool]
|
189 |
+
|
190 |
+
def __init__(
|
191 |
+
self,
|
192 |
+
embed_dim: int,
|
193 |
+
vision_cfg: CLIPVisionCfg,
|
194 |
+
text_cfg: CLIPTextCfg,
|
195 |
+
quick_gelu: bool = False,
|
196 |
+
cast_dtype: Optional[torch.dtype] = None,
|
197 |
+
output_dict: bool = False,
|
198 |
+
):
|
199 |
+
super().__init__()
|
200 |
+
self.output_dict = output_dict
|
201 |
+
self.visual = _build_vision_tower(embed_dim, vision_cfg, quick_gelu, cast_dtype)
|
202 |
+
|
203 |
+
text = _build_text_tower(embed_dim, text_cfg, quick_gelu, cast_dtype)
|
204 |
+
self.transformer = text.transformer
|
205 |
+
self.context_length = text.context_length
|
206 |
+
self.vocab_size = text.vocab_size
|
207 |
+
self.token_embedding = text.token_embedding
|
208 |
+
self.positional_embedding = text.positional_embedding
|
209 |
+
self.ln_final = text.ln_final
|
210 |
+
self.text_projection = text.text_projection
|
211 |
+
self.register_buffer('attn_mask', text.attn_mask, persistent=False)
|
212 |
+
|
213 |
+
self.logit_scale = nn.Parameter(torch.ones([]) * np.log(1 / 0.07))
|
214 |
+
|
215 |
+
def lock_image_tower(self, unlocked_groups=0, freeze_bn_stats=False):
|
216 |
+
# lock image tower as per LiT - https://arxiv.org/abs/2111.07991
|
217 |
+
self.visual.lock(unlocked_groups=unlocked_groups, freeze_bn_stats=freeze_bn_stats)
|
218 |
+
|
219 |
+
@torch.jit.ignore
|
220 |
+
def set_grad_checkpointing(self, enable=True):
|
221 |
+
self.visual.set_grad_checkpointing(enable)
|
222 |
+
self.transformer.grad_checkpointing = enable
|
223 |
+
|
224 |
+
def encode_image(self, image, normalize: bool = False):
|
225 |
+
features = self.visual(image)
|
226 |
+
return F.normalize(features, dim=-1) if normalize else features
|
227 |
+
|
228 |
+
def encode_text(self, text, normalize: bool = False):
|
229 |
+
cast_dtype = self.transformer.get_cast_dtype()
|
230 |
+
|
231 |
+
x = self.token_embedding(text).to(cast_dtype) # [batch_size, n_ctx, d_model]
|
232 |
+
|
233 |
+
x = x + self.positional_embedding.to(cast_dtype)
|
234 |
+
x = x.permute(1, 0, 2) # NLD -> LND
|
235 |
+
x = self.transformer(x, attn_mask=self.attn_mask)
|
236 |
+
x = x.permute(1, 0, 2) # LND -> NLD
|
237 |
+
x = self.ln_final(x) # [batch_size, n_ctx, transformer.width]
|
238 |
+
# take features from the eot embedding (eot_token is the highest number in each sequence)
|
239 |
+
x = x[torch.arange(x.shape[0]), text.argmax(dim=-1)] @ self.text_projection
|
240 |
+
return F.normalize(x, dim=-1) if normalize else x
|
241 |
+
|
242 |
+
def forward(
|
243 |
+
self,
|
244 |
+
image: Optional[torch.Tensor] = None,
|
245 |
+
text: Optional[torch.Tensor] = None,
|
246 |
+
):
|
247 |
+
image_features = self.encode_image(image, normalize=True) if image is not None else None
|
248 |
+
text_features = self.encode_text(text, normalize=True) if text is not None else None
|
249 |
+
if self.output_dict:
|
250 |
+
return {
|
251 |
+
"image_features": image_features,
|
252 |
+
"text_features": text_features,
|
253 |
+
"logit_scale": self.logit_scale.exp()
|
254 |
+
}
|
255 |
+
return image_features, text_features, self.logit_scale.exp()
|
256 |
+
|
257 |
+
|
258 |
+
class CustomTextCLIP(nn.Module):
|
259 |
+
output_dict: torch.jit.Final[bool]
|
260 |
+
|
261 |
+
def __init__(
|
262 |
+
self,
|
263 |
+
embed_dim: int,
|
264 |
+
vision_cfg: CLIPVisionCfg,
|
265 |
+
text_cfg: CLIPTextCfg,
|
266 |
+
quick_gelu: bool = False,
|
267 |
+
cast_dtype: Optional[torch.dtype] = None,
|
268 |
+
output_dict: bool = False,
|
269 |
+
):
|
270 |
+
super().__init__()
|
271 |
+
self.output_dict = output_dict
|
272 |
+
self.visual = _build_vision_tower(embed_dim, vision_cfg, quick_gelu, cast_dtype)
|
273 |
+
self.text = _build_text_tower(embed_dim, text_cfg, quick_gelu, cast_dtype)
|
274 |
+
self.context_length = self.text.context_length
|
275 |
+
self.vocab_size = self.text.vocab_size
|
276 |
+
self.logit_scale = nn.Parameter(torch.ones([]) * np.log(1 / 0.07))
|
277 |
+
|
278 |
+
def lock_image_tower(self, unlocked_groups=0, freeze_bn_stats=False):
|
279 |
+
# lock image tower as per LiT - https://arxiv.org/abs/2111.07991
|
280 |
+
self.visual.lock(unlocked_groups=unlocked_groups, freeze_bn_stats=freeze_bn_stats)
|
281 |
+
|
282 |
+
def lock_text_tower(self, unlocked_layers: int = 0, freeze_layer_norm: bool = True):
|
283 |
+
self.text.lock(unlocked_layers, freeze_layer_norm)
|
284 |
+
|
285 |
+
@torch.jit.ignore
|
286 |
+
def set_grad_checkpointing(self, enable=True):
|
287 |
+
self.visual.set_grad_checkpointing(enable)
|
288 |
+
self.text.set_grad_checkpointing(enable)
|
289 |
+
|
290 |
+
def encode_image(self, image, normalize: bool = False):
|
291 |
+
features = self.visual(image)
|
292 |
+
return F.normalize(features, dim=-1) if normalize else features
|
293 |
+
|
294 |
+
def encode_text(self, text, normalize: bool = False):
|
295 |
+
features = self.text(text)
|
296 |
+
return F.normalize(features, dim=-1) if normalize else features
|
297 |
+
|
298 |
+
def forward(
|
299 |
+
self,
|
300 |
+
image: Optional[torch.Tensor] = None,
|
301 |
+
text: Optional[torch.Tensor] = None,
|
302 |
+
):
|
303 |
+
image_features = self.encode_image(image, normalize=True) if image is not None else None
|
304 |
+
text_features = self.encode_text(text, normalize=True) if text is not None else None
|
305 |
+
if self.output_dict:
|
306 |
+
return {
|
307 |
+
"image_features": image_features,
|
308 |
+
"text_features": text_features,
|
309 |
+
"logit_scale": self.logit_scale.exp()
|
310 |
+
}
|
311 |
+
return image_features, text_features, self.logit_scale.exp()
|
312 |
+
|
313 |
+
|
314 |
+
def convert_weights_to_lp(model: nn.Module, dtype=torch.float16):
|
315 |
+
"""Convert applicable model parameters to low-precision (bf16 or fp16)"""
|
316 |
+
|
317 |
+
def _convert_weights(l):
|
318 |
+
if isinstance(l, (nn.Conv1d, nn.Conv2d, nn.Linear)):
|
319 |
+
l.weight.data = l.weight.data.to(dtype)
|
320 |
+
if l.bias is not None:
|
321 |
+
l.bias.data = l.bias.data.to(dtype)
|
322 |
+
|
323 |
+
if isinstance(l, (nn.MultiheadAttention, Attention)):
|
324 |
+
for attr in [*[f"{s}_proj_weight" for s in ["in", "q", "k", "v"]], "in_proj_bias", "bias_k", "bias_v"]:
|
325 |
+
tensor = getattr(l, attr)
|
326 |
+
if tensor is not None:
|
327 |
+
tensor.data = tensor.data.to(dtype)
|
328 |
+
|
329 |
+
if isinstance(l, (CLIP, TextTransformer)):
|
330 |
+
# convert text nn.Parameter projections
|
331 |
+
attr = getattr(l, "text_projection", None)
|
332 |
+
if attr is not None:
|
333 |
+
attr.data = attr.data.to(dtype)
|
334 |
+
|
335 |
+
if isinstance(l, VisionTransformer):
|
336 |
+
# convert vision nn.Parameter projections
|
337 |
+
attr = getattr(l, "proj", None)
|
338 |
+
if attr is not None:
|
339 |
+
attr.data = attr.data.to(dtype)
|
340 |
+
|
341 |
+
model.apply(_convert_weights)
|
342 |
+
|
343 |
+
|
344 |
+
convert_weights_to_fp16 = convert_weights_to_lp # backwards compat
|
345 |
+
|
346 |
+
|
347 |
+
# used to maintain checkpoint compatibility
|
348 |
+
def convert_to_custom_text_state_dict(state_dict: dict):
|
349 |
+
if 'text_projection' in state_dict:
|
350 |
+
# old format state_dict, move text tower -> .text
|
351 |
+
new_state_dict = {}
|
352 |
+
for k, v in state_dict.items():
|
353 |
+
if any(k.startswith(p) for p in (
|
354 |
+
'text_projection',
|
355 |
+
'positional_embedding',
|
356 |
+
'token_embedding',
|
357 |
+
'transformer',
|
358 |
+
'ln_final',
|
359 |
+
)):
|
360 |
+
k = 'text.' + k
|
361 |
+
new_state_dict[k] = v
|
362 |
+
return new_state_dict
|
363 |
+
return state_dict
|
364 |
+
|
365 |
+
|
366 |
+
def build_model_from_openai_state_dict(
|
367 |
+
state_dict: dict,
|
368 |
+
quick_gelu=True,
|
369 |
+
cast_dtype=torch.float16,
|
370 |
+
):
|
371 |
+
vit = "visual.proj" in state_dict
|
372 |
+
|
373 |
+
if vit:
|
374 |
+
vision_width = state_dict["visual.conv1.weight"].shape[0]
|
375 |
+
vision_layers = len(
|
376 |
+
[k for k in state_dict.keys() if k.startswith("visual.") and k.endswith(".attn.in_proj_weight")])
|
377 |
+
vision_patch_size = state_dict["visual.conv1.weight"].shape[-1]
|
378 |
+
grid_size = round((state_dict["visual.positional_embedding"].shape[0] - 1) ** 0.5)
|
379 |
+
image_size = vision_patch_size * grid_size
|
380 |
+
else:
|
381 |
+
counts: list = [
|
382 |
+
len(set(k.split(".")[2] for k in state_dict if k.startswith(f"visual.layer{b}"))) for b in [1, 2, 3, 4]]
|
383 |
+
vision_layers = tuple(counts)
|
384 |
+
vision_width = state_dict["visual.layer1.0.conv1.weight"].shape[0]
|
385 |
+
output_width = round((state_dict["visual.attnpool.positional_embedding"].shape[0] - 1) ** 0.5)
|
386 |
+
vision_patch_size = None
|
387 |
+
assert output_width ** 2 + 1 == state_dict["visual.attnpool.positional_embedding"].shape[0]
|
388 |
+
image_size = output_width * 32
|
389 |
+
|
390 |
+
embed_dim = state_dict["text_projection"].shape[1]
|
391 |
+
context_length = state_dict["positional_embedding"].shape[0]
|
392 |
+
vocab_size = state_dict["token_embedding.weight"].shape[0]
|
393 |
+
transformer_width = state_dict["ln_final.weight"].shape[0]
|
394 |
+
transformer_heads = transformer_width // 64
|
395 |
+
transformer_layers = len(set(k.split(".")[2] for k in state_dict if k.startswith(f"transformer.resblocks")))
|
396 |
+
|
397 |
+
vision_cfg = CLIPVisionCfg(
|
398 |
+
layers=vision_layers,
|
399 |
+
width=vision_width,
|
400 |
+
patch_size=vision_patch_size,
|
401 |
+
image_size=image_size,
|
402 |
+
)
|
403 |
+
text_cfg = CLIPTextCfg(
|
404 |
+
context_length=context_length,
|
405 |
+
vocab_size=vocab_size,
|
406 |
+
width=transformer_width,
|
407 |
+
heads=transformer_heads,
|
408 |
+
layers=transformer_layers,
|
409 |
+
)
|
410 |
+
model = CLIP(
|
411 |
+
embed_dim,
|
412 |
+
vision_cfg=vision_cfg,
|
413 |
+
text_cfg=text_cfg,
|
414 |
+
quick_gelu=quick_gelu, # OpenAI models were trained with QuickGELU
|
415 |
+
cast_dtype=cast_dtype,
|
416 |
+
)
|
417 |
+
|
418 |
+
for key in ["input_resolution", "context_length", "vocab_size"]:
|
419 |
+
state_dict.pop(key, None)
|
420 |
+
|
421 |
+
convert_weights_to_fp16(model) # OpenAI state dicts are partially converted to float16
|
422 |
+
model.load_state_dict(state_dict)
|
423 |
+
return model.eval()
|
424 |
+
|
425 |
+
|
426 |
+
def trace_model(model, batch_size=256, device=torch.device('cpu')):
|
427 |
+
model.eval()
|
428 |
+
image_size = model.visual.image_size
|
429 |
+
example_images = torch.ones((batch_size, 3, image_size, image_size), device=device)
|
430 |
+
example_text = torch.zeros((batch_size, model.context_length), dtype=torch.int, device=device)
|
431 |
+
model = torch.jit.trace_module(
|
432 |
+
model,
|
433 |
+
inputs=dict(
|
434 |
+
forward=(example_images, example_text),
|
435 |
+
encode_text=(example_text,),
|
436 |
+
encode_image=(example_images,)
|
437 |
+
))
|
438 |
+
model.visual.image_size = image_size
|
439 |
+
return model
|
440 |
+
|
441 |
+
|
442 |
+
def resize_pos_embed(state_dict, model, interpolation: str = 'bicubic', antialias: bool = True):
|
443 |
+
# Rescale the grid of position embeddings when loading from state_dict
|
444 |
+
old_pos_embed = state_dict.get('visual.positional_embedding', None)
|
445 |
+
if old_pos_embed is None or not hasattr(model.visual, 'grid_size'):
|
446 |
+
return
|
447 |
+
grid_size = to_2tuple(model.visual.grid_size)
|
448 |
+
extra_tokens = 1 # FIXME detect different token configs (ie no class token, or more)
|
449 |
+
new_seq_len = grid_size[0] * grid_size[1] + extra_tokens
|
450 |
+
if new_seq_len == old_pos_embed.shape[0]:
|
451 |
+
return
|
452 |
+
|
453 |
+
if extra_tokens:
|
454 |
+
pos_emb_tok, pos_emb_img = old_pos_embed[:extra_tokens], old_pos_embed[extra_tokens:]
|
455 |
+
else:
|
456 |
+
pos_emb_tok, pos_emb_img = None, old_pos_embed
|
457 |
+
old_grid_size = to_2tuple(int(math.sqrt(len(pos_emb_img))))
|
458 |
+
|
459 |
+
logging.info('Resizing position embedding grid-size from %s to %s', old_grid_size, grid_size)
|
460 |
+
pos_emb_img = pos_emb_img.reshape(1, old_grid_size[0], old_grid_size[1], -1).permute(0, 3, 1, 2)
|
461 |
+
pos_emb_img = F.interpolate(
|
462 |
+
pos_emb_img,
|
463 |
+
size=grid_size,
|
464 |
+
mode=interpolation,
|
465 |
+
antialias=antialias,
|
466 |
+
align_corners=False,
|
467 |
+
)
|
468 |
+
pos_emb_img = pos_emb_img.permute(0, 2, 3, 1).reshape(1, grid_size[0] * grid_size[1], -1)[0]
|
469 |
+
if pos_emb_tok is not None:
|
470 |
+
new_pos_embed = torch.cat([pos_emb_tok, pos_emb_img], dim=0)
|
471 |
+
else:
|
472 |
+
new_pos_embed = pos_emb_img
|
473 |
+
state_dict['visual.positional_embedding'] = new_pos_embed
|
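Note (editorial): a minimal usage sketch of the CLIP class and the config dataclasses defined in ext/open_clip/model.py above. This is illustrative only; it assumes the repository root is on PYTHONPATH so that `ext.open_clip` imports resolve (together with the transformer/modified_resnet modules it imports), and the config values are copied from the ViT-B-32.json config added further below.

import torch
from ext.open_clip.model import CLIP, CLIPVisionCfg, CLIPTextCfg

# Build a ViT-B/32-sized CLIP directly from the dataclasses (values from ViT-B-32.json).
vision_cfg = CLIPVisionCfg(image_size=224, layers=12, width=768, patch_size=32)
text_cfg = CLIPTextCfg(context_length=77, vocab_size=49408, width=512, heads=8, layers=12)
model = CLIP(embed_dim=512, vision_cfg=vision_cfg, text_cfg=text_cfg).eval()

# Dummy inputs; in real use the token ids come from the open_clip tokenizer.
image = torch.randn(1, 3, 224, 224)
text = torch.randint(0, 49408, (1, 77))
with torch.no_grad():
    # output_dict defaults to False, so forward() returns a tuple.
    image_features, text_features, logit_scale = model(image, text)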
ext/open_clip/model_configs/EVA01-g-14-plus.json
ADDED
@@ -0,0 +1,18 @@
1 |
+
{
|
2 |
+
"embed_dim": 1024,
|
3 |
+
"vision_cfg": {
|
4 |
+
"image_size": 224,
|
5 |
+
"timm_model_name": "eva_giant_patch14_224",
|
6 |
+
"timm_model_pretrained": false,
|
7 |
+
"timm_pool": "token",
|
8 |
+
"timm_proj": null
|
9 |
+
},
|
10 |
+
"text_cfg": {
|
11 |
+
"context_length": 77,
|
12 |
+
"vocab_size": 49408,
|
13 |
+
"width": 1024,
|
14 |
+
"heads": 16,
|
15 |
+
"layers": 24
|
16 |
+
},
|
17 |
+
"custom_text": true
|
18 |
+
}
|
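Note (editorial): these model_configs JSON files map directly onto the dataclasses in ext/open_clip/model.py; `_build_vision_tower` and `_build_text_tower` accept the `vision_cfg` / `text_cfg` dicts as-is and wrap them in CLIPVisionCfg / CLIPTextCfg. The top-level `custom_text` and `quick_gelu` keys are presumably consumed by the factory in ext/open_clip/factory.py, which appears elsewhere in this commit. An illustrative sketch of that mapping, run from the repository root:

import json
from ext.open_clip.model import CLIPVisionCfg, CLIPTextCfg

# Load one of the bundled configs and rebuild the config dataclasses from it.
with open("ext/open_clip/model_configs/EVA01-g-14-plus.json") as f:
    cfg = json.load(f)

vision_cfg = CLIPVisionCfg(**cfg["vision_cfg"])   # timm_model_name here selects a TimmModel vision tower
text_cfg = CLIPTextCfg(**cfg["text_cfg"])
print(cfg["embed_dim"], vision_cfg.timm_model_name, text_cfg.width)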
ext/open_clip/model_configs/EVA01-g-14.json
ADDED
@@ -0,0 +1,18 @@
1 |
+
{
|
2 |
+
"embed_dim": 1024,
|
3 |
+
"vision_cfg": {
|
4 |
+
"image_size": 224,
|
5 |
+
"timm_model_name": "eva_giant_patch14_224",
|
6 |
+
"timm_model_pretrained": false,
|
7 |
+
"timm_pool": "token",
|
8 |
+
"timm_proj": null
|
9 |
+
},
|
10 |
+
"text_cfg": {
|
11 |
+
"context_length": 77,
|
12 |
+
"vocab_size": 49408,
|
13 |
+
"width": 768,
|
14 |
+
"heads": 12,
|
15 |
+
"layers": 12
|
16 |
+
},
|
17 |
+
"custom_text": true
|
18 |
+
}
|
ext/open_clip/model_configs/EVA02-B-16.json
ADDED
@@ -0,0 +1,18 @@
1 |
+
{
|
2 |
+
"embed_dim": 512,
|
3 |
+
"vision_cfg": {
|
4 |
+
"image_size": 224,
|
5 |
+
"timm_model_name": "eva02_base_patch16_clip_224",
|
6 |
+
"timm_model_pretrained": false,
|
7 |
+
"timm_pool": "token",
|
8 |
+
"timm_proj": null
|
9 |
+
},
|
10 |
+
"text_cfg": {
|
11 |
+
"context_length": 77,
|
12 |
+
"vocab_size": 49408,
|
13 |
+
"width": 512,
|
14 |
+
"heads": 8,
|
15 |
+
"layers": 12
|
16 |
+
},
|
17 |
+
"custom_text": true
|
18 |
+
}
|
ext/open_clip/model_configs/EVA02-E-14-plus.json
ADDED
@@ -0,0 +1,18 @@
1 |
+
{
|
2 |
+
"embed_dim": 1024,
|
3 |
+
"vision_cfg": {
|
4 |
+
"image_size": 224,
|
5 |
+
"timm_model_name": "eva02_enormous_patch14_clip_224",
|
6 |
+
"timm_model_pretrained": false,
|
7 |
+
"timm_pool": "token",
|
8 |
+
"timm_proj": null
|
9 |
+
},
|
10 |
+
"text_cfg": {
|
11 |
+
"context_length": 77,
|
12 |
+
"vocab_size": 49408,
|
13 |
+
"width": 1280,
|
14 |
+
"heads": 20,
|
15 |
+
"layers": 32
|
16 |
+
},
|
17 |
+
"custom_text": true
|
18 |
+
}
|
ext/open_clip/model_configs/EVA02-E-14.json
ADDED
@@ -0,0 +1,18 @@
1 |
+
{
|
2 |
+
"embed_dim": 1024,
|
3 |
+
"vision_cfg": {
|
4 |
+
"image_size": 224,
|
5 |
+
"timm_model_name": "eva02_enormous_patch14_clip_224",
|
6 |
+
"timm_model_pretrained": false,
|
7 |
+
"timm_pool": "token",
|
8 |
+
"timm_proj": null
|
9 |
+
},
|
10 |
+
"text_cfg": {
|
11 |
+
"context_length": 77,
|
12 |
+
"vocab_size": 49408,
|
13 |
+
"width": 1024,
|
14 |
+
"heads": 16,
|
15 |
+
"layers": 24
|
16 |
+
},
|
17 |
+
"custom_text": true
|
18 |
+
}
|
ext/open_clip/model_configs/EVA02-L-14-336.json
ADDED
@@ -0,0 +1,18 @@
1 |
+
{
|
2 |
+
"embed_dim": 768,
|
3 |
+
"vision_cfg": {
|
4 |
+
"image_size": 336,
|
5 |
+
"timm_model_name": "eva02_large_patch14_clip_336",
|
6 |
+
"timm_model_pretrained": false,
|
7 |
+
"timm_pool": "token",
|
8 |
+
"timm_proj": null
|
9 |
+
},
|
10 |
+
"text_cfg": {
|
11 |
+
"context_length": 77,
|
12 |
+
"vocab_size": 49408,
|
13 |
+
"width": 768,
|
14 |
+
"heads": 12,
|
15 |
+
"layers": 12
|
16 |
+
},
|
17 |
+
"custom_text": true
|
18 |
+
}
|
ext/open_clip/model_configs/EVA02-L-14.json
ADDED
@@ -0,0 +1,18 @@
1 |
+
{
|
2 |
+
"embed_dim": 768,
|
3 |
+
"vision_cfg": {
|
4 |
+
"image_size": 224,
|
5 |
+
"timm_model_name": "eva02_large_patch14_clip_224",
|
6 |
+
"timm_model_pretrained": false,
|
7 |
+
"timm_pool": "token",
|
8 |
+
"timm_proj": null
|
9 |
+
},
|
10 |
+
"text_cfg": {
|
11 |
+
"context_length": 77,
|
12 |
+
"vocab_size": 49408,
|
13 |
+
"width": 768,
|
14 |
+
"heads": 12,
|
15 |
+
"layers": 12
|
16 |
+
},
|
17 |
+
"custom_text": true
|
18 |
+
}
|
ext/open_clip/model_configs/RN101-quickgelu.json
ADDED
@@ -0,0 +1,22 @@
1 |
+
{
|
2 |
+
"embed_dim": 512,
|
3 |
+
"quick_gelu": true,
|
4 |
+
"vision_cfg": {
|
5 |
+
"image_size": 224,
|
6 |
+
"layers": [
|
7 |
+
3,
|
8 |
+
4,
|
9 |
+
23,
|
10 |
+
3
|
11 |
+
],
|
12 |
+
"width": 64,
|
13 |
+
"patch_size": null
|
14 |
+
},
|
15 |
+
"text_cfg": {
|
16 |
+
"context_length": 77,
|
17 |
+
"vocab_size": 49408,
|
18 |
+
"width": 512,
|
19 |
+
"heads": 8,
|
20 |
+
"layers": 12
|
21 |
+
}
|
22 |
+
}
|
ext/open_clip/model_configs/RN101.json
ADDED
@@ -0,0 +1,21 @@
1 |
+
{
|
2 |
+
"embed_dim": 512,
|
3 |
+
"vision_cfg": {
|
4 |
+
"image_size": 224,
|
5 |
+
"layers": [
|
6 |
+
3,
|
7 |
+
4,
|
8 |
+
23,
|
9 |
+
3
|
10 |
+
],
|
11 |
+
"width": 64,
|
12 |
+
"patch_size": null
|
13 |
+
},
|
14 |
+
"text_cfg": {
|
15 |
+
"context_length": 77,
|
16 |
+
"vocab_size": 49408,
|
17 |
+
"width": 512,
|
18 |
+
"heads": 8,
|
19 |
+
"layers": 12
|
20 |
+
}
|
21 |
+
}
|
ext/open_clip/model_configs/RN50-quickgelu.json
ADDED
@@ -0,0 +1,22 @@
1 |
+
{
|
2 |
+
"embed_dim": 1024,
|
3 |
+
"quick_gelu": true,
|
4 |
+
"vision_cfg": {
|
5 |
+
"image_size": 224,
|
6 |
+
"layers": [
|
7 |
+
3,
|
8 |
+
4,
|
9 |
+
6,
|
10 |
+
3
|
11 |
+
],
|
12 |
+
"width": 64,
|
13 |
+
"patch_size": null
|
14 |
+
},
|
15 |
+
"text_cfg": {
|
16 |
+
"context_length": 77,
|
17 |
+
"vocab_size": 49408,
|
18 |
+
"width": 512,
|
19 |
+
"heads": 8,
|
20 |
+
"layers": 12
|
21 |
+
}
|
22 |
+
}
|
ext/open_clip/model_configs/RN50.json
ADDED
@@ -0,0 +1,21 @@
1 |
+
{
|
2 |
+
"embed_dim": 1024,
|
3 |
+
"vision_cfg": {
|
4 |
+
"image_size": 224,
|
5 |
+
"layers": [
|
6 |
+
3,
|
7 |
+
4,
|
8 |
+
6,
|
9 |
+
3
|
10 |
+
],
|
11 |
+
"width": 64,
|
12 |
+
"patch_size": null
|
13 |
+
},
|
14 |
+
"text_cfg": {
|
15 |
+
"context_length": 77,
|
16 |
+
"vocab_size": 49408,
|
17 |
+
"width": 512,
|
18 |
+
"heads": 8,
|
19 |
+
"layers": 12
|
20 |
+
}
|
21 |
+
}
|
ext/open_clip/model_configs/RN50x16.json
ADDED
@@ -0,0 +1,21 @@
1 |
+
{
|
2 |
+
"embed_dim": 768,
|
3 |
+
"vision_cfg": {
|
4 |
+
"image_size": 384,
|
5 |
+
"layers": [
|
6 |
+
6,
|
7 |
+
8,
|
8 |
+
18,
|
9 |
+
8
|
10 |
+
],
|
11 |
+
"width": 96,
|
12 |
+
"patch_size": null
|
13 |
+
},
|
14 |
+
"text_cfg": {
|
15 |
+
"context_length": 77,
|
16 |
+
"vocab_size": 49408,
|
17 |
+
"width": 768,
|
18 |
+
"heads": 12,
|
19 |
+
"layers": 12
|
20 |
+
}
|
21 |
+
}
|
ext/open_clip/model_configs/RN50x4.json
ADDED
@@ -0,0 +1,21 @@
1 |
+
{
|
2 |
+
"embed_dim": 640,
|
3 |
+
"vision_cfg": {
|
4 |
+
"image_size": 288,
|
5 |
+
"layers": [
|
6 |
+
4,
|
7 |
+
6,
|
8 |
+
10,
|
9 |
+
6
|
10 |
+
],
|
11 |
+
"width": 80,
|
12 |
+
"patch_size": null
|
13 |
+
},
|
14 |
+
"text_cfg": {
|
15 |
+
"context_length": 77,
|
16 |
+
"vocab_size": 49408,
|
17 |
+
"width": 640,
|
18 |
+
"heads": 10,
|
19 |
+
"layers": 12
|
20 |
+
}
|
21 |
+
}
|
ext/open_clip/model_configs/RN50x64.json
ADDED
@@ -0,0 +1,21 @@
1 |
+
{
|
2 |
+
"embed_dim": 1024,
|
3 |
+
"vision_cfg": {
|
4 |
+
"image_size": 448,
|
5 |
+
"layers": [
|
6 |
+
3,
|
7 |
+
15,
|
8 |
+
36,
|
9 |
+
10
|
10 |
+
],
|
11 |
+
"width": 128,
|
12 |
+
"patch_size": null
|
13 |
+
},
|
14 |
+
"text_cfg": {
|
15 |
+
"context_length": 77,
|
16 |
+
"vocab_size": 49408,
|
17 |
+
"width": 1024,
|
18 |
+
"heads": 16,
|
19 |
+
"layers": 12
|
20 |
+
}
|
21 |
+
}
|
ext/open_clip/model_configs/ViT-B-16-plus-240.json
ADDED
@@ -0,0 +1,16 @@
1 |
+
{
|
2 |
+
"embed_dim": 640,
|
3 |
+
"vision_cfg": {
|
4 |
+
"image_size": 240,
|
5 |
+
"layers": 12,
|
6 |
+
"width": 896,
|
7 |
+
"patch_size": 16
|
8 |
+
},
|
9 |
+
"text_cfg": {
|
10 |
+
"context_length": 77,
|
11 |
+
"vocab_size": 49408,
|
12 |
+
"width": 640,
|
13 |
+
"heads": 10,
|
14 |
+
"layers": 12
|
15 |
+
}
|
16 |
+
}
|
ext/open_clip/model_configs/ViT-B-16-plus.json
ADDED
@@ -0,0 +1,16 @@
1 |
+
{
|
2 |
+
"embed_dim": 640,
|
3 |
+
"vision_cfg": {
|
4 |
+
"image_size": 224,
|
5 |
+
"layers": 12,
|
6 |
+
"width": 896,
|
7 |
+
"patch_size": 16
|
8 |
+
},
|
9 |
+
"text_cfg": {
|
10 |
+
"context_length": 77,
|
11 |
+
"vocab_size": 49408,
|
12 |
+
"width": 640,
|
13 |
+
"heads": 10,
|
14 |
+
"layers": 12
|
15 |
+
}
|
16 |
+
}
|
ext/open_clip/model_configs/ViT-B-16.json
ADDED
@@ -0,0 +1,16 @@
1 |
+
{
|
2 |
+
"embed_dim": 512,
|
3 |
+
"vision_cfg": {
|
4 |
+
"image_size": 224,
|
5 |
+
"layers": 12,
|
6 |
+
"width": 768,
|
7 |
+
"patch_size": 16
|
8 |
+
},
|
9 |
+
"text_cfg": {
|
10 |
+
"context_length": 77,
|
11 |
+
"vocab_size": 49408,
|
12 |
+
"width": 512,
|
13 |
+
"heads": 8,
|
14 |
+
"layers": 12
|
15 |
+
}
|
16 |
+
}
|
ext/open_clip/model_configs/ViT-B-32-plus-256.json
ADDED
@@ -0,0 +1,16 @@
1 |
+
{
|
2 |
+
"embed_dim": 640,
|
3 |
+
"vision_cfg": {
|
4 |
+
"image_size": 256,
|
5 |
+
"layers": 12,
|
6 |
+
"width": 896,
|
7 |
+
"patch_size": 32
|
8 |
+
},
|
9 |
+
"text_cfg": {
|
10 |
+
"context_length": 77,
|
11 |
+
"vocab_size": 49408,
|
12 |
+
"width": 640,
|
13 |
+
"heads": 10,
|
14 |
+
"layers": 12
|
15 |
+
}
|
16 |
+
}
|
ext/open_clip/model_configs/ViT-B-32-quickgelu.json
ADDED
@@ -0,0 +1,17 @@
1 |
+
{
|
2 |
+
"embed_dim": 512,
|
3 |
+
"quick_gelu": true,
|
4 |
+
"vision_cfg": {
|
5 |
+
"image_size": 224,
|
6 |
+
"layers": 12,
|
7 |
+
"width": 768,
|
8 |
+
"patch_size": 32
|
9 |
+
},
|
10 |
+
"text_cfg": {
|
11 |
+
"context_length": 77,
|
12 |
+
"vocab_size": 49408,
|
13 |
+
"width": 512,
|
14 |
+
"heads": 8,
|
15 |
+
"layers": 12
|
16 |
+
}
|
17 |
+
}
|
ext/open_clip/model_configs/ViT-B-32.json
ADDED
@@ -0,0 +1,16 @@
1 |
+
{
|
2 |
+
"embed_dim": 512,
|
3 |
+
"vision_cfg": {
|
4 |
+
"image_size": 224,
|
5 |
+
"layers": 12,
|
6 |
+
"width": 768,
|
7 |
+
"patch_size": 32
|
8 |
+
},
|
9 |
+
"text_cfg": {
|
10 |
+
"context_length": 77,
|
11 |
+
"vocab_size": 49408,
|
12 |
+
"width": 512,
|
13 |
+
"heads": 8,
|
14 |
+
"layers": 12
|
15 |
+
}
|
16 |
+
}
|
ext/open_clip/model_configs/ViT-H-14.json
ADDED
@@ -0,0 +1,17 @@
1 |
+
{
|
2 |
+
"embed_dim": 1024,
|
3 |
+
"vision_cfg": {
|
4 |
+
"image_size": 224,
|
5 |
+
"layers": 32,
|
6 |
+
"width": 1280,
|
7 |
+
"head_width": 80,
|
8 |
+
"patch_size": 14
|
9 |
+
},
|
10 |
+
"text_cfg": {
|
11 |
+
"context_length": 77,
|
12 |
+
"vocab_size": 49408,
|
13 |
+
"width": 1024,
|
14 |
+
"heads": 16,
|
15 |
+
"layers": 24
|
16 |
+
}
|
17 |
+
}
|
ext/open_clip/model_configs/ViT-H-16.json
ADDED
@@ -0,0 +1,17 @@
1 |
+
{
|
2 |
+
"embed_dim": 1024,
|
3 |
+
"vision_cfg": {
|
4 |
+
"image_size": 224,
|
5 |
+
"layers": 32,
|
6 |
+
"width": 1280,
|
7 |
+
"head_width": 80,
|
8 |
+
"patch_size": 16
|
9 |
+
},
|
10 |
+
"text_cfg": {
|
11 |
+
"context_length": 77,
|
12 |
+
"vocab_size": 49408,
|
13 |
+
"width": 1024,
|
14 |
+
"heads": 16,
|
15 |
+
"layers": 24
|
16 |
+
}
|
17 |
+
}
|