diff --git a/.gitattributes b/.gitattributes index a6344aac8c09253b3b630fb776ae94478aa0275b..e7147d9df2efd3675b8c80f5675ed35089916467 100644 --- a/.gitattributes +++ b/.gitattributes @@ -17,10 +17,6 @@ *.ot filter=lfs diff=lfs merge=lfs -text *.parquet filter=lfs diff=lfs merge=lfs -text *.pb filter=lfs diff=lfs merge=lfs -text -*.pickle filter=lfs diff=lfs merge=lfs -text -*.pkl filter=lfs diff=lfs merge=lfs -text -*.pt filter=lfs diff=lfs merge=lfs -text -*.pth filter=lfs diff=lfs merge=lfs -text *.rar filter=lfs diff=lfs merge=lfs -text *.safetensors filter=lfs diff=lfs merge=lfs -text saved_model/**/* filter=lfs diff=lfs merge=lfs -text @@ -33,3 +29,8 @@ saved_model/**/* filter=lfs diff=lfs merge=lfs -text *.zip filter=lfs diff=lfs merge=lfs -text *.zst filter=lfs diff=lfs merge=lfs -text *tfevents* filter=lfs diff=lfs merge=lfs -text +*.pickle filter=lfs diff=lfs merge=lfs -text +*.pkl filter=lfs diff=lfs merge=lfs -text +*.pt filter=lfs diff=lfs merge=lfs -text +*.pth filter=lfs diff=lfs merge=lfs -text +*.jpg filter=lfs diff=lfs merge=lfs -text diff --git a/README.md b/README.md index 0c63a6656b3e4975c6012ef62b91c1b295bad1fe..85f6a4cb2726537f192f16f78bda0889d572d648 100644 --- a/README.md +++ b/README.md @@ -4,10 +4,8 @@ emoji: 🌍 colorFrom: purple colorTo: purple sdk: gradio -sdk_version: 4.13.0 -app_file: app.py +sdk_version: 4.7.1 +app_file: main.py pinned: false -license: mit +python_version: 3.10 --- - -Check out the configuration reference at https://huggingface.co/docs/hub/spaces-config-reference diff --git a/app/configs/rap_sam_r50_12e_adaptor.py b/app/configs/rap_sam_r50_12e_adaptor.py new file mode 100644 index 0000000000000000000000000000000000000000..d31128b7bf91b0d49bd7377eb72a39dba33ac39f --- /dev/null +++ b/app/configs/rap_sam_r50_12e_adaptor.py @@ -0,0 +1,88 @@ +from mmdet.models import ResNet, MaskFormerFusionHead, CrossEntropyLoss, DiceLoss + +from app.models.detectors import YOSOVideoSam +from app.models.heads import RapSAMVideoHead +from app.models.necks import YOSONeck + +num_things_classes = 80 +num_stuff_classes = 53 +ov_model_name = 'convnext_large_d_320' +ov_datasets_name = 'CocoPanopticOVDataset' +num_classes = num_things_classes + num_stuff_classes +model = dict( + type=YOSOVideoSam, + data_preprocessor=None, + backbone=dict( + type=ResNet, + depth=50, + num_stages=4, + out_indices=(0, 1, 2, 3), + frozen_stages=-1, + norm_cfg=dict(type='BN', requires_grad=True), + norm_eval=True, + init_cfg=None, + ), + neck=dict( + type=YOSONeck, + agg_dim=128, + hidden_dim=256, + backbone_shape=[256, 512, 1024, 2048], + ), + panoptic_head=dict( + type=RapSAMVideoHead, + prompt_with_kernel_updator=False, + panoptic_with_kernel_updator=True, + use_adaptor=True, + use_kernel_updator=True, + sphere_cls=True, + ov_classifier_name=f'{ov_model_name}_{ov_datasets_name}', + num_stages=3, + feat_channels=256, + num_things_classes=num_things_classes, + num_stuff_classes=num_stuff_classes, + num_queries=100, + loss_cls=dict( + type=CrossEntropyLoss, + use_sigmoid=False, + loss_weight=2.0, + reduction='mean', + class_weight=[1.0] * num_classes + [0.1]), + loss_mask=dict( + type=CrossEntropyLoss, + use_sigmoid=True, + reduction='mean', + loss_weight=5.0), + loss_dice=dict( + type=DiceLoss, + use_sigmoid=True, + activate=True, + reduction='mean', + naive_dice=True, + eps=1.0, + loss_weight=5.0) + ), + panoptic_fusion_head=dict( + type=MaskFormerFusionHead, + num_things_classes=num_things_classes, + num_stuff_classes=num_stuff_classes, + loss_panoptic=None, + init_cfg=None + ), + 
train_cfg=None, + test_cfg=dict( + panoptic_on=True, + # For now, the dataset does not support + # evaluating semantic segmentation metric. + semantic_on=False, + instance_on=True, + # max_per_image is for instance segmentation. + max_per_image=100, + iou_thr=0.8, + # In Mask2Former's panoptic postprocessing, + # it will filter mask area where score is less than 0.5 . + filter_low_score=True), + init_cfg=dict( + type='Pretrained', + checkpoint='models/rapsam_r50_12e.pth' + ) +) diff --git a/app/models/detectors/__init__.py b/app/models/detectors/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..447f33d6db3e446b8e1ede80de1957b6753702b6 --- /dev/null +++ b/app/models/detectors/__init__.py @@ -0,0 +1 @@ +from .rapsam import YOSOVideoSam diff --git a/app/models/detectors/mask2former_vid.py b/app/models/detectors/mask2former_vid.py new file mode 100644 index 0000000000000000000000000000000000000000..d71bd56128a615cffdc1ae63ed35671fba96a81f --- /dev/null +++ b/app/models/detectors/mask2former_vid.py @@ -0,0 +1,281 @@ +# Copyright (c) OpenMMLab. All rights reserved. +from typing import Dict, List, Tuple + +import torch +from mmengine.structures import InstanceData +from torch import Tensor +import torch.nn.functional as F + +from mmdet.registry import MODELS +from mmdet.structures import SampleList, OptSampleList, TrackDataSample +from mmdet.utils import ConfigType, OptConfigType, OptMultiConfig +from mmdet.models.detectors.single_stage import SingleStageDetector + +from app.models.utils import mask_pool + + +@MODELS.register_module() +class Mask2formerVideo(SingleStageDetector): + r"""Implementation of `Per-Pixel Classification is + NOT All You Need for Semantic Segmentation + `_.""" + OVERLAPPING = None + + def __init__(self, + backbone: ConfigType, + neck: OptConfigType = None, + panoptic_head: OptConfigType = None, + panoptic_fusion_head: OptConfigType = None, + train_cfg: OptConfigType = None, + test_cfg: OptConfigType = None, + data_preprocessor: OptConfigType = None, + inference_sam: bool = False, + init_cfg: OptMultiConfig = None + ): + super(SingleStageDetector, self).__init__( + data_preprocessor=data_preprocessor, init_cfg=init_cfg) + self.backbone = MODELS.build(backbone) + if neck is not None: + self.neck = MODELS.build(neck) + + panoptic_head_ = panoptic_head.deepcopy() + panoptic_head_.update(train_cfg=train_cfg) + panoptic_head_.update(test_cfg=test_cfg) + self.panoptic_head = MODELS.build(panoptic_head_) + + panoptic_fusion_head_ = panoptic_fusion_head.deepcopy() + panoptic_fusion_head_.update(test_cfg=test_cfg) + self.panoptic_fusion_head = MODELS.build(panoptic_fusion_head_) + + self.num_things_classes = self.panoptic_head.num_things_classes + self.num_stuff_classes = self.panoptic_head.num_stuff_classes + self.num_classes = self.panoptic_head.num_classes + + self.train_cfg = train_cfg + self.test_cfg = test_cfg + + self.alpha = 0.4 + self.beta = 0.8 + + self.inference_sam = inference_sam + + def predict(self, + batch_inputs: Tensor, + batch_data_samples: SampleList, + rescale: bool = True) -> SampleList: + """Predict results from a batch of inputs and data samples with post- + processing. + + Args: + batch_inputs (Tensor): Inputs with shape (N, C, H, W). + batch_data_samples (List[:obj:`DetDataSample`]): The Data + Samples. It usually includes information such as + `gt_instance`, `gt_panoptic_seg` and `gt_sem_seg`. + rescale (bool): Whether to rescale the results. + Defaults to True. 
+ + Returns: + list[:obj:`DetDataSample`]: Detection results of the + input images. Each DetDataSample usually contain + 'pred_instances' and `pred_panoptic_seg`. And the + ``pred_instances`` usually contains following keys. + + - scores (Tensor): Classification scores, has a shape + (num_instance, ) + - labels (Tensor): Labels of bboxes, has a shape + (num_instances, ). + - bboxes (Tensor): Has a shape (num_instances, 4), + the last dimension 4 arrange as (x1, y1, x2, y2). + - masks (Tensor): Has a shape (num_instances, H, W). + + And the ``pred_panoptic_seg`` contains the following key + + - sem_seg (Tensor): panoptic segmentation mask, has a + shape (1, h, w). + """ + if isinstance(batch_data_samples[0], TrackDataSample): + bs, num_frames, three, h, w = batch_inputs.shape + assert three == 3, "Only supporting images with 3 channels." + x = batch_inputs.reshape((bs * num_frames, three, h, w)) + feats = self.extract_feat(x) + else: + num_frames = 0 + bs = batch_inputs.shape[0] + feats = self.extract_feat(batch_inputs) + + mask_cls_results, mask_pred_results, iou_results = self.panoptic_head.predict(feats, batch_data_samples) + + if self.inference_sam: + for i, data_sample in enumerate(batch_data_samples): + meta = data_sample.metainfo + img_height, img_width = meta['img_shape'][:2] + mask_pred_result = mask_pred_results[i][:, :img_height, :img_width] + mask_pred_result = mask_pred_result.view(-1, img_height, img_width) > 0 + all_pred_instances = InstanceData(masks=mask_pred_result) + batch_data_samples[i].pred_instances = all_pred_instances + + return batch_data_samples + + if self.OVERLAPPING is not None: + assert len(self.OVERLAPPING) == self.num_classes + mask_cls_results = self.open_voc_inference(feats, mask_cls_results, mask_pred_results) + + if num_frames > 0: + for frame_id in range(num_frames): + results_list_img = self.panoptic_fusion_head.predict( + mask_cls_results, + mask_pred_results[:, :, frame_id], + [batch_data_samples[idx][frame_id] for idx in range(bs)], + rescale=rescale + ) + _ = self.add_track_pred_to_datasample( + [batch_data_samples[idx][frame_id] for idx in range(bs)], results_list_img + ) + results = batch_data_samples + else: + results_list = self.panoptic_fusion_head.predict( + mask_cls_results, + mask_pred_results, + batch_data_samples, + iou_results=iou_results, + rescale=rescale + ) + results = self.add_pred_to_datasample(batch_data_samples, results_list) + + return results + + def add_pred_to_datasample(self, data_samples: SampleList, + results_list: List[dict]) -> SampleList: + """Add predictions to `DetDataSample`. + + Args: + data_samples (list[:obj:`DetDataSample`], optional): A batch of + data samples that contain annotations and predictions. + results_list (List[dict]): Instance segmentation, segmantic + segmentation and panoptic segmentation results. + + Returns: + list[:obj:`DetDataSample`]: Detection results of the + input images. Each DetDataSample usually contain + 'pred_instances' and `pred_panoptic_seg`. And the + ``pred_instances`` usually contains following keys. + + - scores (Tensor): Classification scores, has a shape + (num_instance, ) + - labels (Tensor): Labels of bboxes, has a shape + (num_instances, ). + - bboxes (Tensor): Has a shape (num_instances, 4), + the last dimension 4 arrange as (x1, y1, x2, y2). + - masks (Tensor): Has a shape (num_instances, H, W). + + And the ``pred_panoptic_seg`` contains the following key + + - sem_seg (Tensor): panoptic segmentation mask, has a + shape (1, h, w). 
+ """ + for data_sample, pred_results in zip(data_samples, results_list): + if 'pan_results' in pred_results: + data_sample.pred_panoptic_seg = pred_results['pan_results'] + + if 'ins_results' in pred_results: + data_sample.pred_instances = pred_results['ins_results'] + + assert 'sem_results' not in pred_results + + return data_samples + + def add_track_pred_to_datasample(self, data_samples: SampleList, results_list: List[dict]) -> SampleList: + for data_sample, pred_results in zip(data_samples, results_list): + if 'pan_results' in pred_results: + assert self.num_stuff_classes > 0 + data_sample.pred_track_panoptic_seg = pred_results['pan_results'] + + if 'ins_results' in pred_results: + bboxes = pred_results['ins_results']['bboxes'] + labels = pred_results['ins_results']['labels'] + track_ids = torch.arange(len(bboxes), dtype=labels.dtype, device=bboxes.device) + 1 + pred_results['ins_results']['instances_id'] = track_ids + data_sample.pred_track_instances = pred_results['ins_results'] + + if 'pro_results' in pred_results: + data_sample.pred_track_proposal = pred_results['pro_results'] + + assert 'sem_results' not in pred_results + + return data_samples + + def _forward( + self, + batch_inputs: Tensor, + batch_data_samples: OptSampleList = None) -> Tuple[List[Tensor]]: + """Network forward process. Usually includes backbone, neck and head + forward without any post-processing. + + Args: + batch_inputs (Tensor): Inputs with shape (N, C, H, W). + batch_data_samples (list[:obj:`DetDataSample`]): The batch + data samples. It usually includes information such + as `gt_instance` or `gt_panoptic_seg` or `gt_sem_seg`. + + Returns: + tuple[List[Tensor]]: A tuple of features from ``panoptic_head`` + forward. + """ + if isinstance(batch_data_samples[0], TrackDataSample): + bs, num_frames, three, h, w = batch_inputs.shape + assert three == 3, "Only supporting images with 3 channels." 
+ + x = batch_inputs.reshape((bs * num_frames, three, h, w)) + feats = self.extract_feat(x) + else: + feats = self.extract_feat(batch_inputs) + results = self.panoptic_head.forward(feats, batch_data_samples) + return results + + def open_voc_inference(self, feats, mask_cls_results, mask_pred_results): + if len(mask_pred_results.shape) == 5: + batch_size = mask_cls_results.shape[0] + num_frames = mask_pred_results.shape[2] + mask_pred_results = mask_pred_results.permute(0, 2, 1, 3, 4).flatten(0, 1) + else: + batch_size = mask_cls_results.shape[0] + num_frames = 0 + clip_feat = self.backbone.get_clip_feature(feats[-1]) + clip_feat_mask = F.interpolate( + mask_pred_results, + size=clip_feat.shape[-2:], + mode='bilinear', + align_corners=False + ) + if num_frames > 0: + clip_feat_mask = clip_feat_mask.unflatten(0, (batch_size, num_frames)).permute(0, 2, 1, 3, 4).flatten(2, 3) + clip_feat = clip_feat.unflatten(0, (batch_size, num_frames)).permute(0, 2, 1, 3, 4).flatten(2, 3) + instance_feat = mask_pool(clip_feat, clip_feat_mask) + instance_feat = self.backbone.forward_feat(instance_feat) + clip_logit = self.panoptic_head.forward_logit(instance_feat) + clip_logit = clip_logit[..., :-1] + query_logit = mask_cls_results[..., :-1] + + clip_logit = clip_logit.softmax(-1) + query_logit = query_logit.softmax(-1) + overlapping_mask = torch.tensor(self.OVERLAPPING, dtype=torch.float32, device=clip_logit.device) + + valid_masking = ((clip_feat_mask > 0).to(dtype=torch.float32).flatten(-2).sum(-1) > 0).to( + torch.float32)[..., None] + alpha = torch.ones_like(clip_logit) * self.alpha * valid_masking + beta = torch.ones_like(clip_logit) * self.beta * valid_masking + + cls_logits_seen = ( + (query_logit ** (1 - alpha) * clip_logit ** alpha).log() + * overlapping_mask + ) + cls_logits_unseen = ( + (query_logit ** (1 - beta) * clip_logit ** beta).log() + * (1 - overlapping_mask) + ) + cls_results = cls_logits_seen + cls_logits_unseen + is_void_prob = F.softmax(mask_cls_results, dim=-1)[..., -1:] + mask_cls_results = torch.cat([ + cls_results.softmax(-1) * (1.0 - is_void_prob), is_void_prob], dim=-1) + mask_cls_results = torch.log(mask_cls_results + 1e-8) + return mask_cls_results diff --git a/app/models/detectors/rapsam.py b/app/models/detectors/rapsam.py new file mode 100644 index 0000000000000000000000000000000000000000..cfe72ed18ffe0aa82305c69b7f5c8eadc8174449 --- /dev/null +++ b/app/models/detectors/rapsam.py @@ -0,0 +1,66 @@ +from mmdet.registry import MODELS +from mmdet.utils import ConfigType, OptConfigType, OptMultiConfig + +from mmdet.models.detectors import SingleStageDetector + +from .mask2former_vid import Mask2formerVideo + + +@MODELS.register_module() +class YOSOVideoSam(Mask2formerVideo): + OVERLAPPING = None + + def __init__(self, + backbone: ConfigType, + neck: OptConfigType = None, + panoptic_head: OptConfigType = None, + panoptic_fusion_head: OptConfigType = None, + train_cfg: OptConfigType = None, + test_cfg: OptConfigType = None, + data_preprocessor: OptConfigType = None, + inference_sam: bool = False, + init_cfg: OptMultiConfig = None + ): + super(SingleStageDetector, self).__init__( + data_preprocessor=data_preprocessor, init_cfg=init_cfg) + self.backbone = MODELS.build(backbone) + if neck is not None: + self.neck = MODELS.build(neck) + + panoptic_head_ = panoptic_head.deepcopy() + panoptic_head_.update(train_cfg=train_cfg) + panoptic_head_.update(test_cfg=test_cfg) + self.panoptic_head = MODELS.build(panoptic_head_) + + panoptic_fusion_head_ = panoptic_fusion_head.deepcopy() + 
panoptic_fusion_head_.update(test_cfg=test_cfg) + self.panoptic_fusion_head = MODELS.build(panoptic_fusion_head_) + + self.num_things_classes = self.panoptic_head.num_things_classes + self.num_stuff_classes = self.panoptic_head.num_stuff_classes + self.num_classes = self.panoptic_head.num_classes + + self.train_cfg = train_cfg + self.test_cfg = test_cfg + + self.alpha = 0.4 + self.beta = 0.8 + + self.inference_sam = inference_sam + + def predict_with_point(self, x, batch_data_samples): + feats = self.extract_feat(x) + mask_cls_results, mask_pred_results, iou_results = self.panoptic_head.predict(feats, batch_data_samples) + + if 'gt_instances_collected' not in batch_data_samples[0]: + results_list = self.panoptic_fusion_head.predict( + mask_cls_results, + mask_pred_results, + batch_data_samples, + iou_results=iou_results, + rescale=False + ) + mask_pred_results = results_list[0]['pan_results'].sem_seg[None] + mask_cls_results = mask_cls_results + + return mask_pred_results.cpu().numpy(), mask_cls_results.cpu().numpy() diff --git a/app/models/heads/__init__.py b/app/models/heads/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..66b90fcf8403bfbdac0515602aad5c639b512f92 --- /dev/null +++ b/app/models/heads/__init__.py @@ -0,0 +1 @@ +from .rapsam_head import RapSAMVideoHead diff --git a/app/models/heads/mask2former_vid.py b/app/models/heads/mask2former_vid.py new file mode 100644 index 0000000000000000000000000000000000000000..e48c4111e7022a588e4ff815cd3ec7702ffddb00 --- /dev/null +++ b/app/models/heads/mask2former_vid.py @@ -0,0 +1,616 @@ +# Copyright (c) OpenMMLab. All rights reserved. +import copy +import os +from typing import List, Optional, Tuple + +import torch +import torch.nn as nn +import torch.nn.functional as F +import torch.distributed as dist +from mmcv.cnn import Conv2d +from mmdet.models import Mask2FormerTransformerDecoder +from mmengine.dist import get_dist_info +from mmengine.model import caffe2_xavier_init, ModuleList +from torch import Tensor +from mmdet.models.layers import MLP, inverse_sigmoid +from mmdet.models.layers import coordinate_to_encoding +from mmdet.structures.bbox import bbox_xyxy_to_cxcywh + +from mmdet.registry import MODELS, TASK_UTILS +from mmdet.structures import SampleList, TrackDataSample +from mmdet.utils import (ConfigType, OptConfigType, OptMultiConfig) +from mmdet.models.layers import SinePositionalEncoding3D +from mmdet.models.dense_heads.anchor_free_head import AnchorFreeHead +from mmcv.cnn.bricks.transformer import MultiheadAttention +from app.models.utils import mask_pool + + +@MODELS.register_module() +class Mask2FormerVideoHead(AnchorFreeHead): + """Implements the Mask2Former head. + + See `Masked-attention Mask Transformer for Universal Image + Segmentation `_ for details. + + Args: + in_channels (list[int]): Number of channels in the input feature map. + feat_channels (int): Number of channels for features. + out_channels (int): Number of channels for output. + num_things_classes (int): Number of things. + num_stuff_classes (int): Number of stuff. + num_queries (int): Number of query in Transformer decoder. + pixel_decoder (:obj:`ConfigDict` or dict): Config for pixel + decoder. Defaults to None. + enforce_decoder_input_project (bool, optional): Whether to add + a layer to change the embed_dim of tranformer encoder in + pixel decoder to the embed_dim of transformer decoder. + Defaults to False. + transformer_decoder (:obj:`ConfigDict` or dict): Config for + transformer decoder. Defaults to None. 
+ positional_encoding (:obj:`ConfigDict` or dict): Config for + transformer decoder position encoding. Defaults to + dict(num_feats=128, normalize=True). + loss_cls (:obj:`ConfigDict` or dict): Config of the classification + loss. Defaults to None. + loss_mask (:obj:`ConfigDict` or dict): Config of the mask loss. + Defaults to None. + loss_dice (:obj:`ConfigDict` or dict): Config of the dice loss. + Defaults to None. + train_cfg (:obj:`ConfigDict` or dict, optional): Training config of + Mask2Former head. + test_cfg (:obj:`ConfigDict` or dict, optional): Testing config of + Mask2Former head. + init_cfg (:obj:`ConfigDict` or dict or list[:obj:`ConfigDict` or \ + dict], optional): Initialization config dict. Defaults to None. + """ + + def __init__(self, + in_channels: List[int], + feat_channels: int, + out_channels: int, + num_mask_tokens: int = 1, + num_things_classes: int = 80, + num_stuff_classes: int = 53, + num_queries: int = 100, + num_transformer_feat_level: int = 3, + pixel_decoder: ConfigType = ..., + enforce_decoder_input_project: bool = False, + transformer_decoder: ConfigType = ..., + positional_encoding: ConfigType = None, + loss_cls: ConfigType = None, + loss_mask: ConfigType = None, + loss_dice: ConfigType = None, + train_cfg: OptConfigType = None, + test_cfg: OptConfigType = None, + init_cfg: OptMultiConfig = None, + # ov configs + sphere_cls: bool = False, + ov_classifier_name: Optional[str] = None, + logit: Optional[int] = None, + use_adaptor = False, + **kwargs) -> None: + super(AnchorFreeHead, self).__init__(init_cfg=init_cfg) + self.use_adaptor = use_adaptor + + self.num_mask_tokens = num_mask_tokens + self.mask_tokens = nn.Embedding(num_mask_tokens, feat_channels) + self.pb_embedding = nn.Embedding(2, feat_channels) + self.pos_linear = nn.Linear(2 * feat_channels, feat_channels) + + self.num_things_classes = num_things_classes + self.num_stuff_classes = num_stuff_classes + self.num_classes = self.num_things_classes + self.num_stuff_classes + self.num_queries = num_queries + self.num_transformer_feat_level = num_transformer_feat_level + self.num_heads = transformer_decoder.layer_cfg.cross_attn_cfg.num_heads + self.num_transformer_decoder_layers = transformer_decoder.num_layers + # assert pixel_decoder.encoder.layer_cfg. 
\ + # self_attn_cfg.num_levels == num_transformer_feat_level + pixel_decoder_ = copy.deepcopy(pixel_decoder) + pixel_decoder_.update( + in_channels=in_channels, + feat_channels=feat_channels, + out_channels=out_channels) + self.pixel_decoder = MODELS.build(pixel_decoder_) + self.transformer_decoder = Mask2FormerTransformerDecoder( + **transformer_decoder) + self.decoder_embed_dims = self.transformer_decoder.embed_dims + + self.decoder_input_projs = ModuleList() + # from low resolution to high resolution + for _ in range(num_transformer_feat_level): + if (self.decoder_embed_dims != feat_channels + or enforce_decoder_input_project): + self.decoder_input_projs.append( + Conv2d( + feat_channels, self.decoder_embed_dims, kernel_size=1)) + else: + self.decoder_input_projs.append(nn.Identity()) + self.decoder_positional_encoding = SinePositionalEncoding3D( + **positional_encoding) + self.query_embed = nn.Embedding(self.num_queries, feat_channels) + self.query_feat = nn.Embedding(self.num_queries, feat_channels) + # from low resolution to high resolution + self.level_embed = nn.Embedding(self.num_transformer_feat_level, + feat_channels) + + if not sphere_cls: + self.cls_embed = nn.Linear(feat_channels, self.num_classes + 1) + self.mask_embed = nn.Sequential( + nn.Linear(feat_channels, feat_channels), nn.ReLU(inplace=True), + nn.Linear(feat_channels, feat_channels), nn.ReLU(inplace=True), + nn.Linear(feat_channels, out_channels)) + + self.iou_embed = nn.Sequential( + nn.Linear(feat_channels, feat_channels), nn.ReLU(inplace=True), + nn.Linear(feat_channels, feat_channels), nn.ReLU(inplace=True), + nn.Linear(feat_channels, 1)) + + self.test_cfg = test_cfg + self.train_cfg = train_cfg + if train_cfg: + self.assigner = TASK_UTILS.build(self.train_cfg['assigner']) + self.sampler = TASK_UTILS.build( + self.train_cfg['sampler'], default_args=dict(context=self)) + self.num_points = self.train_cfg.get('num_points', 12544) + self.oversample_ratio = self.train_cfg.get('oversample_ratio', 3.0) + self.importance_sample_ratio = self.train_cfg.get( + 'importance_sample_ratio', 0.75) + + self.class_weight = loss_cls.class_weight + self.loss_cls = MODELS.build(loss_cls) + self.loss_mask = MODELS.build(loss_mask) + self.loss_dice = MODELS.build(loss_dice) + + # prepare OV things + # OV cls embed + if sphere_cls: + rank, world_size = get_dist_info() + if ov_classifier_name is None: + _dim = 1024 # temporally hard code + cls_embed = torch.empty(self.num_classes, _dim) + torch.nn.init.orthogonal_(cls_embed) + cls_embed = cls_embed[:, None] + else: + # ov_path = os.path.join(os.path.expanduser('~/.cache/embd'), f"{ov_classifier_name}.pth") + ov_path = os.path.join('./models/', f"{ov_classifier_name}.pth") + cls_embed = torch.load(ov_path) + cls_embed_norm = cls_embed.norm(p=2, dim=-1) + assert torch.allclose(cls_embed_norm, torch.ones_like(cls_embed_norm)) + if self.loss_cls and self.loss_cls.use_sigmoid: + pass + else: + _dim = cls_embed.size(2) + _prototypes = cls_embed.size(1) + + if rank == 0: + back_token = torch.zeros(1, _dim, dtype=torch.float32, device='cuda') + # back_token = back_token / back_token.norm(p=2, dim=-1, keepdim=True) + else: + back_token = torch.empty(1, _dim, dtype=torch.float32, device='cuda') + if world_size > 1: + dist.broadcast(back_token, src=0) + back_token = back_token.to(device='cpu') + cls_embed = torch.cat([ + cls_embed, back_token.repeat(_prototypes, 1)[None] + ], dim=0) + self.register_buffer('cls_embed', cls_embed.permute(2, 0, 1).contiguous(), persistent=False) + + # cls embd proj + 
cls_embed_dim = self.cls_embed.size(0) + self.cls_proj = nn.Sequential( + nn.Linear(feat_channels, feat_channels), nn.ReLU(inplace=True), + nn.Linear(feat_channels, feat_channels), nn.ReLU(inplace=True), + nn.Linear(feat_channels, cls_embed_dim) + ) + + # Haobo Yuan: + # For the logit_scale, I refer to this issue. + # https://github.com/openai/CLIP/issues/46#issuecomment-945062212 + # https://github.com/openai/CLIP/issues/46#issuecomment-782558799 + # Based on my understanding, it is a mistake of CLIP. + # Because they mention that they refer to InstDisc (Wu, 2018) paper. + # InstDisc set a non-learnable temperature to np.log(1 / 0.07). + # 4.6052 is np.log(1 / 0.01) + # np.log(1 / 0.07) will be fast converged to np.log(1 / 0.01) + if logit is None: + logit_scale = torch.tensor(4.6052, dtype=torch.float32) + else: + logit_scale = torch.tensor(logit, dtype=torch.float32) + self.register_buffer('logit_scale', logit_scale, persistent=False) + + # Mask Pooling + self.mask_pooling = mask_pool + self.mask_pooling_proj = nn.Sequential( + nn.LayerNorm(feat_channels), + nn.Linear(feat_channels, feat_channels) + ) + + if use_adaptor: + cross_attn_cfg = dict(embed_dims=256, batch_first=True, num_heads=8) + self.panoptic_attn = MultiheadAttention(**cross_attn_cfg) + self.panoptic_norm = nn.LayerNorm(256) + if sphere_cls: + cls_embed_dim = self.cls_embed.size(0) + self.panoptic_cls = nn.Sequential( + nn.Linear(feat_channels, cls_embed_dim) + ) + else: + raise NotImplementedError + self.prompt_attn = MultiheadAttention(**cross_attn_cfg) + self.prompt_norm = nn.LayerNorm(256) + self.prompt_iou = nn.Linear(256, 1) + + def init_weights(self) -> None: + for m in self.decoder_input_projs: + if isinstance(m, Conv2d): + caffe2_xavier_init(m, bias=0) + + self.pixel_decoder.init_weights() + + for p in self.transformer_decoder.parameters(): + if p.dim() > 1: + nn.init.xavier_normal_(p) + + def forward_logit(self, cls_embd): + cls_pred = torch.einsum('bnc,ckp->bnkp', F.normalize(cls_embd, dim=-1), self.cls_embed) + cls_pred = cls_pred.max(-1).values + cls_pred = self.logit_scale.exp() * cls_pred + return cls_pred + + def _forward_head(self, decoder_out: Tensor, mask_feature: Tensor, + attn_mask_target_size: Tuple[int, int], + num_frames: int = 0) -> Tuple[Tensor]: + """Forward for head part which is called after every decoder layer. + + Args: + decoder_out (Tensor): in shape (batch_size, num_queries, c). + mask_feature (Tensor): in shape (batch_size, c, h, w). + attn_mask_target_size (tuple[int, int]): target attention + mask size. + + Returns: + tuple: A tuple contain three elements. + + - cls_pred (Tensor): Classification scores in shape \ + (batch_size, num_queries, cls_out_channels). \ + Note `cls_out_channels` should includes background. + - mask_pred (Tensor): Mask scores in shape \ + (batch_size, num_queries,h, w). + - attn_mask (Tensor): Attention mask in shape \ + (batch_size * num_heads, num_queries, h, w). + - num_frames: How many frames are there in video. 
+ """ + decoder_out = self.transformer_decoder.post_norm(decoder_out) + # shape (num_queries, batch_size, c) + if isinstance(self.cls_embed, nn.Module): + cls_pred = self.cls_embed(decoder_out) + # shape (num_queries, batch_size, c) + mask_embed = self.mask_embed(decoder_out) + # shape (num_queries, batch_size, h, w) + mask_pred = torch.einsum('bqc,bchw->bqhw', mask_embed, mask_feature) + + if not isinstance(self.cls_embed, nn.Module): + maskpool_embd = self.mask_pooling(x=mask_feature, mask=mask_pred.detach()) + maskpool_embd = self.mask_pooling_proj(maskpool_embd) + cls_embd = self.cls_proj(maskpool_embd + decoder_out) + cls_pred = self.forward_logit(cls_embd) + + iou_pred = self.iou_embed(decoder_out) + + if num_frames > 0: + assert len(mask_pred.shape) == 4 + assert mask_pred.shape[2] % num_frames == 0 + frame_h = mask_pred.shape[2] // num_frames + num_q = mask_pred.shape[1] + _mask_pred = mask_pred.unflatten(-2, (num_frames, frame_h)).flatten(1, 2) + attn_mask = F.interpolate( + _mask_pred, + attn_mask_target_size, + mode='bilinear', + align_corners=False) + attn_mask = attn_mask.unflatten(1, (num_q, num_frames)).flatten(2, 3) + else: + attn_mask = F.interpolate( + mask_pred, + attn_mask_target_size, + mode='bilinear', + align_corners=False) + # shape (num_queries, batch_size, h, w) -> + # (batch_size * num_head, num_queries, h, w) + attn_mask = attn_mask.flatten(2).unsqueeze(1).repeat( + (1, self.num_heads, 1, 1)).flatten(0, 1) + attn_mask = attn_mask.sigmoid() < 0.5 + attn_mask = attn_mask.detach() + + return cls_pred, mask_pred, iou_pred, attn_mask + + def forward(self, x: List[Tensor], batch_data_samples: SampleList) -> Tuple[List[Tensor]]: + """Forward function. + + Args: + x (list[Tensor]): Multi scale Features from the + upstream network, each is a 4D-tensor. + batch_data_samples (List[:obj:`DetDataSample`]): The Data + Samples. It usually includes information such as + `gt_instance`, `gt_panoptic_seg` and `gt_sem_seg`. + + Returns: + tuple[list[Tensor]]: A tuple contains two elements. + + - cls_pred_list (list[Tensor)]: Classification logits \ + for each decoder layer. Each is a 3D-tensor with shape \ + (batch_size, num_queries, cls_out_channels). \ + Note `cls_out_channels` should includes background. + - mask_pred_list (list[Tensor]): Mask logits for each \ + decoder layer. Each with shape (batch_size, num_queries, \ + h, w). 
+ """ + batch_img_metas = [] + if isinstance(batch_data_samples[0], TrackDataSample): + for track_sample in batch_data_samples: + cur_list = [] + for det_sample in track_sample: + cur_list.append(det_sample.metainfo) + batch_img_metas.append(cur_list) + num_frames = len(batch_img_metas[0]) + else: + for data_sample in batch_data_samples: + batch_img_metas.append(data_sample.metainfo) + num_frames = 0 + batch_size = len(batch_img_metas) + #(bs_nf, c, h,w) + mask_features, multi_scale_memorys = self.pixel_decoder(x) + if num_frames > 0: + mask_features = mask_features.unflatten(0, (batch_size, num_frames)) + mask_features = mask_features.transpose(1, 2).flatten(2, 3) #(bs, c, nf*h,w) + decoder_inputs = [] + decoder_positional_encodings = [] + for i in range(self.num_transformer_feat_level): + decoder_input = self.decoder_input_projs[i](multi_scale_memorys[i]) #(bs_nf, c, h,w) + decoder_input = decoder_input.flatten(2).permute(0, 2, 1) #(bs_nf,h*w, c) + if num_frames > 0: + decoder_input = decoder_input.unflatten(0, (batch_size, num_frames)) + decoder_input = decoder_input.flatten(1, 2) #(bs, nf*h*w, c) + level_embed = self.level_embed.weight[i].view(1, 1, -1) + decoder_input = decoder_input + level_embed + + # shape (batch_size, c, h, w) -> (batch_size, h*w, c) + num_frames_real = 1 if num_frames == 0 else num_frames + mask = decoder_input.new_zeros( + (batch_size, num_frames_real) + multi_scale_memorys[i].shape[-2:], + dtype=torch.bool) + decoder_positional_encoding = self.decoder_positional_encoding( + mask) + decoder_positional_encoding = decoder_positional_encoding.transpose( + 1, 2).flatten(2).permute(0, 2, 1) + decoder_inputs.append(decoder_input) #(bs, nf*h*w, c) + decoder_positional_encodings.append(decoder_positional_encoding) #(bs, nf*h*w, c) + + if self.prompt_training: + query_feat, input_query_bbox, self_attn_mask, _ = self.prepare_for_dn_mo( + batch_data_samples) + query_embed = coordinate_to_encoding(input_query_bbox.sigmoid()) + query_embed = self.pos_linear(query_embed) + else: + query_feat = self.query_feat.weight.unsqueeze(0).repeat((batch_size, 1, 1)) + query_embed = self.query_embed.weight.unsqueeze(0).repeat((batch_size, 1, 1)) + self_attn_mask = None + + cls_pred_list = [] + mask_pred_list = [] + iou_pred_list = [] + cls_pred, mask_pred, iou_pred, attn_mask = self._forward_head( + query_feat, mask_features, multi_scale_memorys[0].shape[-2:], + num_frames=num_frames + ) + cls_pred_list.append(cls_pred) + iou_pred_list.append(iou_pred) + if num_frames > 0: #(bs, 100, nf*h, w)-->(bs, 100, nf, h, w) + mask_pred = mask_pred.unflatten(2, (num_frames, -1)) + mask_pred_list.append(mask_pred) + + for i in range(self.num_transformer_decoder_layers): + level_idx = i % self.num_transformer_feat_level + # if a mask is all True(all background), then set it all False. 
+ attn_mask[torch.where(attn_mask.sum(-1) == attn_mask.shape[-1])] = False + + # cross_attn + self_attn + layer = self.transformer_decoder.layers[i] + query_feat = layer( + query=query_feat, #(bs, 100, c) + key=decoder_inputs[level_idx], #(bs, nf*h*w, c) + value=decoder_inputs[level_idx], + query_pos=query_embed, + key_pos=decoder_positional_encodings[level_idx], + cross_attn_mask=attn_mask, + self_attn_mask=self_attn_mask, + query_key_padding_mask=None, + # here we do not apply masking on padded region + key_padding_mask=None) + cls_pred, mask_pred, iou_pred, attn_mask = self._forward_head( + query_feat, mask_features, multi_scale_memorys[(i + 1) % self.num_transformer_feat_level].shape[-2:], + num_frames=num_frames + ) + + cls_pred_list.append(cls_pred) + iou_pred_list.append(iou_pred) + if num_frames > 0: + mask_pred = mask_pred.unflatten(2, (num_frames, -1)) + mask_pred_list.append(mask_pred) + + if self.use_adaptor: + keys = mask_features.flatten(2).transpose(1, 2).contiguous() + h, w = mask_features.shape[-2] // num_frames_real, mask_features.shape[-1] + mask = decoder_input.new_zeros((batch_size, num_frames_real, h, w), dtype=torch.bool) + key_pos = self.decoder_positional_encoding(mask) + key_pos = key_pos.transpose(1, 2).flatten(2).permute(0, 2, 1) + if not self.prompt_training: + object_kernels = self.panoptic_attn(query_feat, keys, key_pos=key_pos, query_pos=query_embed) + object_kernels = self.panoptic_norm(object_kernels) + mask_preds = torch.einsum('bnc,bchw->bnhw', object_kernels, mask_features) + + cls_embd = self.panoptic_cls(object_kernels) + cls_scores = torch.einsum('bnc,ckp->bnkp', F.normalize(cls_embd, dim=-1), self.cls_embed) + cls_scores = cls_scores.max(-1).values + cls_scores = self.logit_scale.exp() * cls_scores + + if num_frames > 0: + mask_pred_list.append(mask_preds.unflatten(2, (num_frames, -1))) + else: + mask_pred_list.append(mask_preds) + cls_pred_list.append(cls_scores) + iou_pred_list.append(iou_pred_list[-1]) + else: + object_kernels = self.prompt_attn(query_feat, keys, key_pos=key_pos, query_pos=query_embed) + object_kernels = self.prompt_norm(object_kernels) + iou_preds = self.prompt_iou(object_kernels) + mask_preds = torch.einsum('bnc,bchw->bnhw', object_kernels, mask_features) + + if num_frames > 0: + mask_pred_list.append(mask_preds.unflatten(2, (num_frames, -1))) + else: + mask_pred_list.append(mask_preds) + cls_pred_list.append(cls_pred_list[-1]) + iou_pred_list.append(iou_preds) + + return cls_pred_list, mask_pred_list, iou_pred_list, query_feat + + def predict(self, x: Tuple[Tensor], + batch_data_samples: SampleList, + return_query=False, + ) -> Tuple[Tensor, ...]: + """Test without augmentaton. + + Args: + return_query: + x (tuple[Tensor]): Multi-level features from the + upstream network, each is a 4D-tensor. + batch_data_samples (List[:obj:`DetDataSample`]): The Data + Samples. It usually includes information such as + `gt_instance`, `gt_panoptic_seg` and `gt_sem_seg`. + + Returns: + tuple[Tensor]: A tuple contains two tensors. + + - mask_cls_results (Tensor): Mask classification logits,\ + shape (batch_size, num_queries, cls_out_channels). + Note `cls_out_channels` should includes background. + - mask_pred_results (Tensor): Mask logits, shape \ + (batch_size, num_queries, h, w). 
+ """ + self.prompt_training = False + data_sample = batch_data_samples[0] + if isinstance(data_sample, TrackDataSample): + img_shape = data_sample[0].metainfo['batch_input_shape'] + num_frames = len(data_sample) + else: + if 'gt_instances_collected' in data_sample: + self.prompt_training = True + img_shape = data_sample.metainfo['batch_input_shape'] + num_frames = 0 + all_cls_scores, all_mask_preds, all_iou_preds, query_feat = self(x, batch_data_samples) + mask_cls_results = all_cls_scores[-1] + mask_pred_results = all_mask_preds[-1] + iou_results = all_iou_preds[-1] + + if num_frames > 0: + mask_pred_results = mask_pred_results.flatten(1, 2) + mask_pred_results = F.interpolate( + mask_pred_results, + size=(img_shape[0], img_shape[1]), + mode='bilinear', + align_corners=False) + if num_frames > 0: + num_queries = mask_cls_results.shape[1] + mask_pred_results = mask_pred_results.unflatten(1, (num_queries, num_frames)) + + if return_query: + return mask_cls_results, mask_pred_results, query_feat, iou_results + else: + return mask_cls_results, mask_pred_results, iou_results + + def prepare_for_dn_mo(self, batch_data_samples): + scalar, noise_scale = 100, 0.4 + gt_instances = [t.gt_instances_collected for t in batch_data_samples] + + point_coords = torch.stack([inst.point_coords for inst in gt_instances]) + pb_labels = torch.stack([inst['pb_labels'] for inst in gt_instances]) + labels = torch.zeros_like(pb_labels).long() + + boxes = point_coords # + boxes + + factors = [] + for i, data_sample in enumerate(batch_data_samples): + h, w, = data_sample.metainfo['img_shape'] + factor = boxes[i].new_tensor([w, h, w, h]).unsqueeze(0).repeat(boxes[i].size(0), 1) + factors.append(factor) + factors = torch.stack(factors, 0) + + boxes = bbox_xyxy_to_cxcywh(boxes / factors) # xyxy / factor or xywh / factor ???? 
+ # box_start = [t['box_start'] for t in targets] + box_start = [len(point) for point in point_coords] + + known_labels = labels + known_pb_labels = pb_labels + known_bboxs = boxes + + known_labels_expaned = known_labels.clone() + known_pb_labels_expaned = known_pb_labels.clone() + known_bbox_expand = known_bboxs.clone() + + if noise_scale > 0 and self.training: + diff = torch.zeros_like(known_bbox_expand) + diff[:, :, :2] = known_bbox_expand[:, :, 2:] / 2 + diff[:, :, 2:] = known_bbox_expand[:, :, 2:] + # add very small noise to input points; no box + sc = 0.01 + for i, st in enumerate(box_start): + diff[i, :st] = diff[i, :st] * sc + known_bbox_expand += torch.mul( + (torch.rand_like(known_bbox_expand) * 2 - 1.0), + diff) * noise_scale + + known_bbox_expand = known_bbox_expand.clamp(min=0.0, max=1.0) + + input_label_embed = self.pb_embedding(known_pb_labels_expaned) + + input_bbox_embed = inverse_sigmoid(known_bbox_expand) + + input_label_embed = input_label_embed.repeat_interleave( + self.num_mask_tokens, + 1) + self.mask_tokens.weight.unsqueeze(0).repeat( + input_label_embed.shape[0], input_label_embed.shape[1], 1) + input_bbox_embed = input_bbox_embed.repeat_interleave( + self.num_mask_tokens, 1) + + single_pad = self.num_mask_tokens + + # NOTE scalar is modified to 100, each click cannot see each other + scalar = int(input_label_embed.shape[1] / self.num_mask_tokens) + + pad_size = input_label_embed.shape[1] + + if input_label_embed.shape[1] > 0: + input_query_label = input_label_embed + input_query_bbox = input_bbox_embed + + tgt_size = pad_size + attn_mask = torch.ones(tgt_size, tgt_size).to('cuda') < 0 + # match query cannot see the reconstruct + attn_mask[pad_size:, :pad_size] = True + # reconstruct cannot see each other + for i in range(scalar): + if i == 0: + attn_mask[single_pad * i:single_pad * (i + 1), single_pad * (i + 1):pad_size] = True + if i == scalar - 1: + attn_mask[single_pad * i:single_pad * (i + 1), :single_pad * i] = True + else: + attn_mask[single_pad * i:single_pad * (i + 1), single_pad * (i + 1):pad_size] = True + attn_mask[single_pad * i:single_pad * (i + 1), :single_pad * i] = True + mask_dict = { + 'known_lbs_bboxes': (known_labels, known_bboxs), + 'pad_size': pad_size, + 'scalar': scalar, + } + return input_query_label, input_query_bbox, attn_mask, mask_dict \ No newline at end of file diff --git a/app/models/heads/rapsam_head.py b/app/models/heads/rapsam_head.py new file mode 100644 index 0000000000000000000000000000000000000000..9ed45041e7687455c7fd0f4e0430ac9be2613e37 --- /dev/null +++ b/app/models/heads/rapsam_head.py @@ -0,0 +1,227 @@ +# Copyright (c) OpenMMLab. All rights reserved. 
+from typing import List, Tuple + +import torch +import torch.nn as nn +import torch.nn.functional as F +from torch import Tensor +from mmdet.models.layers import coordinate_to_encoding + +from mmdet.registry import MODELS, TASK_UTILS +from mmdet.structures import SampleList, TrackDataSample +from mmdet.utils import (ConfigType, OptConfigType, OptMultiConfig) +from mmdet.models.dense_heads.anchor_free_head import AnchorFreeHead + +from mmcv.cnn.bricks.transformer import MultiheadAttention + +from .mask2former_vid import Mask2FormerVideoHead +from .yoso_head import CrossAttenHead, KernelUpdator + +@MODELS.register_module() +class RapSAMVideoHead(Mask2FormerVideoHead): + + def __init__(self, + frozen_head=False, + frozen_pred=False, + use_adaptor=False, + prompt_with_kernel_updator=False, + panoptic_with_kernel_updator=False, + num_mask_tokens = 1, + num_stages = 3, + use_kernel_updator=False, + sphere_cls = False, + ov_classifier_name = None, + temperature=0.1, + feat_channels=256, + num_things_classes: int = 80, + num_stuff_classes: int = 53, + num_queries: int = 100, + loss_cls: ConfigType = None, + loss_mask: ConfigType = None, + loss_dice: ConfigType = None, + train_cfg: OptConfigType = None, + test_cfg: OptConfigType = None, + init_cfg: OptMultiConfig = None, + matching_whole_map: bool = False, + enable_box_query: bool = False, + **kwargs) -> None: + super(AnchorFreeHead, self).__init__(init_cfg=init_cfg) + self.prompt_with_kernel_updator = prompt_with_kernel_updator + self.panoptic_with_kernel_updator = panoptic_with_kernel_updator + self.use_adaptor = use_adaptor + + self.num_mask_tokens = num_mask_tokens + self.mask_tokens = nn.Embedding(num_mask_tokens, feat_channels) + self.pb_embedding = nn.Embedding(2, feat_channels) + self.pos_linear = nn.Linear(2 * feat_channels, feat_channels) + + self.matching_whole_map = matching_whole_map + self.enable_box_query = enable_box_query + + self.num_things_classes = num_things_classes + self.num_stuff_classes = num_stuff_classes + self.num_classes = self.num_things_classes + self.num_stuff_classes + self.num_queries = num_queries + self.feat_channels = feat_channels + self.num_stages = num_stages + self.kernels = nn.Embedding(self.num_queries, feat_channels) + self.mask_heads = nn.ModuleList() + for _ in range(self.num_stages): + self.mask_heads.append(CrossAttenHead( + self.num_classes, self.feat_channels, self.num_queries, + use_kernel_updator=use_kernel_updator, + frozen_head=frozen_head, frozen_pred=frozen_pred, + sphere_cls=sphere_cls, + ov_classifier_name=ov_classifier_name, with_iou_pred=True)) + self.temperature = temperature + + if use_adaptor: + cross_attn_cfg = dict(embed_dims=256, batch_first=True, num_heads=8) + if self.panoptic_with_kernel_updator: + self.panoptic_attn = KernelUpdator(feat_channels=256) + self.panoptic_norm = nn.Identity() + if sphere_cls: + cls_embed_dim = self.mask_heads[0].fc_cls.size(0) + self.panoptic_cls = nn.Sequential( + nn.Linear(feat_channels, cls_embed_dim) + ) + else: + raise NotImplementedError + self.panoptic_cls = nn.Linear(256, self.num_classes+1) + else: + self.panoptic_attn = MultiheadAttention(**cross_attn_cfg) + self.panoptic_norm = nn.LayerNorm(256) + if sphere_cls: + cls_embed_dim = self.mask_heads[0].fc_cls.size(0) + self.panoptic_cls = nn.Sequential( + nn.Linear(feat_channels, cls_embed_dim) + ) + else: + raise NotImplementedError + self.panoptic_cls = nn.Linear(256, self.num_classes+1) + + if self.prompt_with_kernel_updator: + self.prompt_attn = KernelUpdator(feat_channels=256) + 
self.prompt_norm = nn.Identity() + self.prompt_iou = nn.Linear(256, 1) + else: + self.prompt_attn = MultiheadAttention(**cross_attn_cfg) + self.prompt_norm = nn.LayerNorm(256) + self.prompt_iou = nn.Linear(256, 1) + + self.test_cfg = test_cfg + self.train_cfg = train_cfg + if train_cfg: + self.assigner = TASK_UTILS.build(self.train_cfg['assigner']) + self.sampler = TASK_UTILS.build( + self.train_cfg['sampler'], default_args=dict(context=self)) + self.num_points = self.train_cfg.get('num_points', 12544) + self.oversample_ratio = self.train_cfg.get('oversample_ratio', 3.0) + self.importance_sample_ratio = self.train_cfg.get( + 'importance_sample_ratio', 0.75) + + self.class_weight = loss_cls.class_weight + self.loss_cls = MODELS.build(loss_cls) + self.loss_mask = MODELS.build(loss_mask) + self.loss_dice = MODELS.build(loss_dice) + + + def init_weights(self) -> None: + pass + + def forward(self, x, batch_data_samples: SampleList) -> Tuple[List[Tensor]]: + batch_img_metas = [] + if isinstance(batch_data_samples[0], TrackDataSample): + for track_sample in batch_data_samples: + cur_list = [] + for det_sample in track_sample: + cur_list.append(det_sample.metainfo) + batch_img_metas.append(cur_list) + num_frames = len(batch_img_metas[0]) + else: + for data_sample in batch_data_samples: + batch_img_metas.append(data_sample.metainfo) + num_frames = 0 + bs = len(batch_img_metas) + + all_cls_scores = [] + all_masks_preds = [] + all_iou_preds = [] + if self.prompt_training: + input_query_label, input_query_bbox, self_attn_mask, mask_dict = self.prepare_for_dn_mo( + batch_data_samples) + pos_embed = coordinate_to_encoding(input_query_bbox.sigmoid()) + pos_embed = self.pos_linear(pos_embed) + object_kernels = input_query_label + pos_embed + else: + object_kernels = self.kernels.weight[None].repeat(bs, 1, 1) + self_attn_mask = None + mask_features = x + if num_frames > 0: # (bs*num_frames, c, h, w) -> (bs, c, num_frames*h, w) + mask_features = mask_features.unflatten(0, (bs, num_frames)) + mask_features = mask_features.transpose(1, 2).flatten(2, 3) + + mask_preds = torch.einsum('bnc,bchw->bnhw', object_kernels, mask_features) + for stage in range(self.num_stages): + mask_head = self.mask_heads[stage] + cls_scores, mask_preds, iou_preds, object_kernels = mask_head( + mask_features, object_kernels, mask_preds, self_attn_mask) + cls_scores = cls_scores / self.temperature + all_iou_preds.append(iou_preds) + all_cls_scores.append(cls_scores) + if num_frames > 0: + #(bs,num_query, num_frames*h, w) --> (bs,num_query,num_frames,h,w) + all_masks_preds.append(mask_preds.unflatten(2, (num_frames, -1))) + else: + all_masks_preds.append(mask_preds) + + if self.use_adaptor: + keys = mask_features.flatten(2).transpose(1, 2).contiguous() + if not self.prompt_training: + if self.panoptic_with_kernel_updator: + hard_sigmoid_masks = (mask_preds.sigmoid() > 0.5).float() + f = torch.einsum('bnhw,bchw->bnc', hard_sigmoid_masks, mask_features) + object_kernels = self.panoptic_attn(f, object_kernels) + object_kernels = self.panoptic_norm(object_kernels) + mask_preds = torch.einsum('bnc,bchw->bnhw', object_kernels, mask_features) + else: + object_kernels = self.panoptic_attn(object_kernels, keys) + object_kernels = self.panoptic_norm(object_kernels) + mask_preds = torch.einsum('bnc,bchw->bnhw', object_kernels, mask_features) + cls_embd = self.panoptic_cls(object_kernels) + cls_scores = torch.einsum('bnc,ckp->bnkp', F.normalize(cls_embd, dim=-1), self.mask_heads[0].fc_cls) + cls_scores = cls_scores.max(-1).values + cls_scores = 
self.mask_heads[0].logit_scale.exp() * cls_scores + + if num_frames > 0: + all_masks_preds.append(mask_preds.unflatten(2, (num_frames, -1))) + else: + all_masks_preds.append(mask_preds) + all_cls_scores.append(cls_scores) + all_iou_preds.append(all_iou_preds[-1]) + else: + if self.prompt_with_kernel_updator: + hard_sigmoid_masks = (mask_preds.sigmoid() > 0.5).float() + f = torch.einsum('bnhw,bchw->bnc', hard_sigmoid_masks, mask_features) + object_kernels = self.prompt_attn(f, object_kernels) + object_kernels = self.prompt_norm(object_kernels) + iou_preds = self.prompt_iou(object_kernels) + mask_preds = torch.einsum('bnc,bchw->bnhw', object_kernels, mask_features) + else: + object_kernels = self.prompt_attn(object_kernels, keys) + object_kernels = self.prompt_norm(object_kernels) + iou_preds = self.prompt_iou(object_kernels) + mask_preds = torch.einsum('bnc,bchw->bnhw', object_kernels, mask_features) + if num_frames > 0: + all_masks_preds.append(mask_preds.unflatten(2, (num_frames, -1))) + else: + all_masks_preds.append(mask_preds) + all_cls_scores.append(all_cls_scores[-1]) + all_iou_preds.append(iou_preds) + return all_cls_scores, all_masks_preds, all_iou_preds, object_kernels + + def get_targets(self, *args, **kwargs): + raise NotImplementedError + + def loss_by_feat(self, *args, **kwargs): + raise NotImplementedError diff --git a/app/models/heads/yoso_head.py b/app/models/heads/yoso_head.py new file mode 100644 index 0000000000000000000000000000000000000000..faee5120520b1b892a04dfb52af1f402fc0775f0 --- /dev/null +++ b/app/models/heads/yoso_head.py @@ -0,0 +1,531 @@ +from typing import List, Tuple +import os +import torch.distributed as dist +from torch import Tensor +from mmdet.registry import MODELS, TASK_UTILS +from mmdet.models.dense_heads import AnchorFreeHead +from mmdet.structures import SampleList +from mmdet.models.dense_heads import Mask2FormerHead +import math +from mmengine.model.weight_init import trunc_normal_ +import torch +from torch import nn +import torch.nn.functional as F +from mmcv.cnn import build_activation_layer, build_norm_layer + +from mmengine.dist import get_dist_info + + +@MODELS.register_module() +class YOSOHead(Mask2FormerHead): + def __init__(self, + num_cls_fcs=1, + num_mask_fcs=1, + sphere_cls=False, + ov_classifier_name=None, + use_kernel_updator=False, + num_stages=3, + feat_channels=256, + out_channels=256, + num_things_classes=80, + num_stuff_classes=53, + num_classes=133, + num_queries=100, + temperature=0.1, + loss_cls=dict( + type='CrossEntropyLoss', + use_sigmoid=False, + loss_weight=2.0, + reduction='mean', + class_weight=[1.0] * 133 + [0.1]), + loss_mask=dict( + type='CrossEntropyLoss', + use_sigmoid=True, + reduction='mean', + loss_weight=5.0), + loss_dice=dict( + type='DiceLoss', + use_sigmoid=True, + activate=True, + reduction='mean', + naive_dice=True, + eps=1.0, + loss_weight=5.0), + train_cfg=None, + test_cfg=None, + init_cfg=None): + super(AnchorFreeHead, self).__init__(init_cfg=init_cfg) + self.num_stages = num_stages + self.feat_channels = feat_channels + self.out_channels = out_channels + self.num_things_classes = num_things_classes + self.num_stuff_classes = num_stuff_classes + self.num_classes = num_classes + self.num_queries = num_queries + self.temperature = temperature + + self.test_cfg = test_cfg + self.train_cfg = train_cfg + if train_cfg: + self.assigner = TASK_UTILS.build(self.train_cfg['assigner']) + self.sampler = TASK_UTILS.build( + self.train_cfg['sampler'], default_args=dict(context=self)) + self.num_points = 
self.train_cfg.get('num_points', 12544) + self.oversample_ratio = self.train_cfg.get('oversample_ratio', 3.0) + self.importance_sample_ratio = self.train_cfg.get( + 'importance_sample_ratio', 0.75) + + self.class_weight = loss_cls.class_weight + self.loss_cls = MODELS.build(loss_cls) + self.loss_mask = MODELS.build(loss_mask) + self.loss_dice = MODELS.build(loss_dice) + + self.kernels = nn.Embedding(self.num_queries, self.feat_channels) + + self.mask_heads = nn.ModuleList() + for _ in range(self.num_stages): + self.mask_heads.append(CrossAttenHead( + self.num_classes, self.feat_channels, self.num_queries, + use_kernel_updator=use_kernel_updator, + sphere_cls=sphere_cls, ov_classifier_name=ov_classifier_name, + num_cls_fcs=num_cls_fcs, num_mask_fcs=num_mask_fcs + )) + + def init_weights(self) -> None: + super(AnchorFreeHead, self).init_weights() + + def forward(self, x: List[Tensor], + batch_data_samples: SampleList) -> Tuple[List[Tensor]]: + all_cls_scores = [] + all_masks_preds = [] + proposal_kernels = self.kernels.weight + object_kernels = proposal_kernels[None].repeat(x.shape[0], 1, 1) + mask_preds = torch.einsum('bnc,bchw->bnhw', object_kernels, x) + + for stage in range(self.num_stages): + mask_head = self.mask_heads[stage] + cls_scores, mask_preds, iou_pred, object_kernels = mask_head(x, object_kernels, mask_preds) + cls_scores = cls_scores / self.temperature + + all_cls_scores.append(cls_scores) + all_masks_preds.append(mask_preds) + + return all_cls_scores, all_masks_preds + + def predict(self, x: Tuple[Tensor], batch_data_samples: SampleList) -> Tuple[Tensor]: + batch_img_metas = [ + data_sample.metainfo for data_sample in batch_data_samples + ] + all_cls_scores, all_mask_preds = self(x, batch_data_samples) + mask_cls_results = all_cls_scores[-1] + mask_pred_results = all_mask_preds[-1] + + # upsample masks + img_shape = batch_img_metas[0]['batch_input_shape'] + mask_pred_results = F.interpolate( + mask_pred_results, + size=(img_shape[0], img_shape[1]), + mode='bilinear', + align_corners=False) + + return mask_cls_results, mask_pred_results + + +class FFN(nn.Module): + + def __init__(self, + embed_dims=256, + feedforward_channels=1024, + num_fcs=2, + add_identity=True): + super(FFN, self).__init__() + self.embed_dims = embed_dims + self.feedforward_channels = feedforward_channels + self.num_fcs = num_fcs + + layers = [] + in_channels = embed_dims + for _ in range(num_fcs - 1): + layers.append(nn.Sequential( + nn.Linear(in_channels, feedforward_channels), + nn.ReLU(True), + nn.Dropout(0.0))) + in_channels = feedforward_channels + layers.append(nn.Linear(feedforward_channels, embed_dims)) + layers.append(nn.Dropout(0.0)) + self.layers = nn.Sequential(*layers) + self.add_identity = add_identity + self.dropout_layer = nn.Dropout(0.0) + + def forward(self, x, identity=None): + out = self.layers(x) + if not self.add_identity: + return self.dropout_layer(out) + if identity is None: + identity = x + return identity + self.dropout_layer(out) + + +class DySepConvAtten(nn.Module): + def __init__(self, hidden_dim, num_proposals, conv_kernel_size_1d): + super(DySepConvAtten, self).__init__() + self.hidden_dim = hidden_dim + self.num_proposals = num_proposals + self.kernel_size = conv_kernel_size_1d + + self.weight_linear = nn.Linear(self.hidden_dim, self.num_proposals + self.kernel_size) + self.norm = nn.LayerNorm(self.hidden_dim) + + def forward(self, query, value): + assert query.shape == value.shape + B, N, C = query.shape + + dy_conv_weight = self.weight_linear(query) + 
dy_depth_conv_weight = dy_conv_weight[:, :, :self.kernel_size].view(B, self.num_proposals, 1, self.kernel_size) + dy_point_conv_weight = dy_conv_weight[:, :, self.kernel_size:].view(B, self.num_proposals, self.num_proposals, + 1) + + res = [] + value = value.unsqueeze(1) + for i in range(B): + out = F.relu(F.conv1d(input=value[i], weight=dy_depth_conv_weight[i], groups=N, padding='same')) + out = F.conv1d(input=out, weight=dy_point_conv_weight[i], padding='same') + res.append(out) + + point_out = torch.cat(res, dim=0) + point_out = self.norm(point_out) + return point_out + + +class KernelUpdator(nn.Module): + + def __init__(self, + in_channels=256, + feat_channels=64, + out_channels=None, + input_feat_shape=3, + gate_sigmoid=True, + gate_norm_act=False, + activate_out=False, + act_cfg=dict(type='ReLU', inplace=True), + norm_cfg=dict(type='LN')): + super(KernelUpdator, self).__init__() + self.in_channels = in_channels + self.feat_channels = feat_channels + self.out_channels_raw = out_channels + self.gate_sigmoid = gate_sigmoid + self.gate_norm_act = gate_norm_act + self.activate_out = activate_out + if isinstance(input_feat_shape, int): + input_feat_shape = [input_feat_shape] * 2 + self.input_feat_shape = input_feat_shape + self.act_cfg = act_cfg + self.norm_cfg = norm_cfg + self.out_channels = out_channels if out_channels else in_channels + + self.num_params_in = self.feat_channels + self.num_params_out = self.feat_channels + self.dynamic_layer = nn.Linear( + self.in_channels, self.num_params_in + self.num_params_out) + self.input_layer = nn.Linear(self.in_channels, + self.num_params_in + self.num_params_out, + 1) + self.input_gate = nn.Linear(self.in_channels, self.feat_channels, 1) + self.update_gate = nn.Linear(self.in_channels, self.feat_channels, 1) + if self.gate_norm_act: + self.gate_norm = build_norm_layer(norm_cfg, self.feat_channels)[1] + + self.norm_in = build_norm_layer(norm_cfg, self.feat_channels)[1] + self.norm_out = build_norm_layer(norm_cfg, self.feat_channels)[1] + self.input_norm_in = build_norm_layer(norm_cfg, self.feat_channels)[1] + self.input_norm_out = build_norm_layer(norm_cfg, self.feat_channels)[1] + + self.activation = build_activation_layer(act_cfg) + + self.fc_layer = nn.Linear(self.feat_channels, self.out_channels, 1) + self.fc_norm = build_norm_layer(norm_cfg, self.out_channels)[1] + + def forward(self, update_feature, input_feature): + """ + Args: + update_feature (torch.Tensor): [bs, num_proposals, in_channels] + input_feature (torch.Tensor): [bs, num_proposals, in_channels] + """ + bs, num_proposals, _ = update_feature.shape + + parameters = self.dynamic_layer(update_feature) + param_in = parameters[..., :self.num_params_in] + param_out = parameters[..., -self.num_params_out:] + + input_feats = self.input_layer(input_feature) + input_in = input_feats[..., :self.num_params_in] + input_out = input_feats[..., -self.num_params_out:] + + gate_feats = input_in * param_in + if self.gate_norm_act: + gate_feats = self.activation(self.gate_norm(gate_feats)) + + input_gate = self.input_norm_in(self.input_gate(gate_feats)) + update_gate = self.norm_in(self.update_gate(gate_feats)) + if self.gate_sigmoid: + input_gate = input_gate.sigmoid() + update_gate = update_gate.sigmoid() + param_out = self.norm_out(param_out) + input_out = self.input_norm_out(input_out) + + if self.activate_out: + param_out = self.activation(param_out) + input_out = self.activation(input_out) + + # param_out has shape (batch_size, feat_channels, out_channels) + features = update_gate * 
param_out + input_gate * input_out + + features = self.fc_layer(features) + features = self.fc_norm(features) + features = self.activation(features) + + return features + + +class CrossAttenHead(nn.Module): + + def __init__(self, + num_classes, + in_channels, + num_proposals, + frozen_head=False, + frozen_pred=False, + with_iou_pred=False, + sphere_cls=False, + ov_classifier_name=None, + num_cls_fcs=1, + num_mask_fcs=1, + conv_kernel_size_1d=3, + conv_kernel_size_2d=1, + use_kernel_updator=False): + super(CrossAttenHead, self).__init__() + self.sphere_cls = sphere_cls + self.with_iou_pred = with_iou_pred + self.frozen_head = frozen_head + self.frozen_pred = frozen_pred + self.num_cls_fcs = num_cls_fcs + self.num_mask_fcs = num_mask_fcs + self.num_classes = num_classes + self.conv_kernel_size_2d = conv_kernel_size_2d + + self.hidden_dim = in_channels + self.feat_channels = in_channels + self.num_proposals = num_proposals + self.hard_mask_thr = 0.5 + self.use_kernel_updator = use_kernel_updator + # assert use_kernel_updator + if use_kernel_updator: + self.kernel_update = KernelUpdator( + in_channels=256, + feat_channels=256, + out_channels=256, + input_feat_shape=3, + act_cfg=dict(type='ReLU', inplace=True), + norm_cfg=dict(type='LN') + ) + else: + self.f_atten = DySepConvAtten(self.feat_channels, self.num_proposals, conv_kernel_size_1d) + self.f_dropout = nn.Dropout(0.0) + self.f_atten_norm = nn.LayerNorm(self.hidden_dim * self.conv_kernel_size_2d ** 2) + self.k_atten = DySepConvAtten(self.feat_channels, self.num_proposals, conv_kernel_size_1d) + self.k_dropout = nn.Dropout(0.0) + self.k_atten_norm = nn.LayerNorm(self.hidden_dim * self.conv_kernel_size_2d ** 2) + + self.s_atten = nn.MultiheadAttention(embed_dim=self.hidden_dim * + self.conv_kernel_size_2d ** 2, + num_heads=8, + dropout=0.0) + self.s_dropout = nn.Dropout(0.0) + self.s_atten_norm = nn.LayerNorm(self.hidden_dim * self.conv_kernel_size_2d ** 2) + + self.ffn = FFN(self.hidden_dim, feedforward_channels=2048, num_fcs=2) + self.ffn_norm = nn.LayerNorm(self.hidden_dim) + + self.cls_fcs = nn.ModuleList() + for _ in range(self.num_cls_fcs): + self.cls_fcs.append(nn.Linear(self.hidden_dim, self.hidden_dim, bias=False)) + self.cls_fcs.append(nn.LayerNorm(self.hidden_dim)) + self.cls_fcs.append(nn.ReLU(True)) + + if sphere_cls: + rank, world_size = get_dist_info() + if ov_classifier_name is None: + _dim = 1024 # temporally hard code + cls_embed = torch.empty(self.num_classes, _dim) + torch.nn.init.orthogonal_(cls_embed) + cls_embed = cls_embed[:, None] + else: + ov_path = os.path.join(os.path.expanduser('~/.cache/embd'), f"{ov_classifier_name}.pth") + cls_embed = torch.load(ov_path) + cls_embed_norm = cls_embed.norm(p=2, dim=-1) + assert torch.allclose(cls_embed_norm, torch.ones_like(cls_embed_norm)) + + # background class + _dim = cls_embed.size(2) + _prototypes = cls_embed.size(1) + if rank == 0: + back_token = torch.zeros(1, _dim, dtype=torch.float32, device='cuda') + else: + back_token = torch.empty(1, _dim, dtype=torch.float32, device='cuda') + if world_size > 1: + dist.broadcast(back_token, src=0) + back_token = back_token.to(device='cpu') + cls_embed = torch.cat([ + cls_embed, back_token.repeat(_prototypes, 1)[None] + ], dim=0) + self.register_buffer('fc_cls', cls_embed.permute(2, 0, 1).contiguous(), persistent=False) + + # cls embd proj + cls_embed_dim = self.fc_cls.size(0) + self.cls_proj = nn.Sequential( + nn.Linear(self.hidden_dim, self.hidden_dim), nn.ReLU(inplace=True), + nn.Linear(self.hidden_dim, self.hidden_dim), 
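A usage sketch for the adaptive KernelUpdator defined above (hypothetical sizes), assuming the class and its mmcv helpers are importable: mask-pooled features gate how much of the old kernel is kept versus replaced.

```python
import torch

updator = KernelUpdator(in_channels=256, feat_channels=256, out_channels=256,
                        act_cfg=dict(type='ReLU', inplace=True), norm_cfg=dict(type='LN'))
pooled_feats = torch.randn(2, 100, 256)   # f: features pooled under each mask
kernels = torch.randn(2, 100, 256)        # k: current proposal kernels
new_kernels = updator(pooled_feats, kernels)
print(new_kernels.shape)                  # torch.Size([2, 100, 256])
```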
nn.ReLU(inplace=True), + nn.Linear(self.hidden_dim, cls_embed_dim) + ) + + logit_scale = torch.tensor(4.6052, dtype=torch.float32) + self.register_buffer('logit_scale', logit_scale, persistent=False) + else: + self.fc_cls = nn.Linear(self.hidden_dim, self.num_classes + 1) + + self.mask_fcs = nn.ModuleList() + for _ in range(self.num_mask_fcs): + self.mask_fcs.append(nn.Linear(self.hidden_dim, self.hidden_dim, bias=False)) + self.mask_fcs.append(nn.LayerNorm(self.hidden_dim)) + self.mask_fcs.append(nn.ReLU(True)) + self.fc_mask = nn.Linear(self.hidden_dim, self.hidden_dim) + + if self.with_iou_pred: + self.iou_embed = nn.Sequential( + nn.Linear(self.hidden_dim, self.hidden_dim), + nn.ReLU(inplace=True), + nn.Linear(self.hidden_dim, self.hidden_dim), + nn.ReLU(inplace=True), + nn.Linear(self.hidden_dim, 1), + ) + prior_prob = 0.01 + self.bias_value = -math.log((1 - prior_prob) / prior_prob) + + self.apply(self._init_weights) + if not sphere_cls: + nn.init.constant_(self.fc_cls.bias, self.bias_value) + + if self.frozen_head: + self._frozen_head() + if self.frozen_pred: + self._frozen_pred() + + def _init_weights(self, m): + # print("init weights") + if isinstance(m, nn.Linear): + trunc_normal_(m.weight, std=.02) + if isinstance(m, nn.Linear) and m.bias is not None: + nn.init.constant_(m.bias, 0) + elif isinstance(m, nn.LayerNorm): + nn.init.constant_(m.bias, 0) + nn.init.constant_(m.weight, 1.0) + + def _frozen_head(self): + for n, p in self.kernel_update.named_parameters(): + p.requires_grad = False + for n, p in self.s_atten.named_parameters(): + p.requires_grad = False + for n, p in self.s_dropout.named_parameters(): + p.requires_grad = False + for n, p in self.s_atten_norm.named_parameters(): + p.requires_grad = False + for n, p in self.ffn.named_parameters(): + p.requires_grad = False + for n, p in self.ffn_norm.named_parameters(): + p.requires_grad = False + + def _frozen_pred(self): + # frozen cls_fcs, fc_cls, mask_fcs, fc_mask + for n, p in self.cls_fcs.named_parameters(): + p.requires_grad = False + for n, p in self.fc_cls.named_parameters(): + p.requires_grad = False + for n, p in self.mask_fcs.named_parameters(): + p.requires_grad = False + for n, p in self.fc_mask.named_parameters(): + p.requires_grad = False + + def train(self, mode): + super().train(mode) + if self.frozen_head: + self.kernel_update.eval() + self.s_atten.eval() + self.s_dropout.eval() + self.s_atten_norm.eval() + self.ffn.eval() + self.ffn_norm.eval() + if self.frozen_pred: + self.cls_fcs.eval() + self.fc_cls.eval() + self.mask_fcs.eval() + self.fc_mask.eval() + + def forward(self, features, proposal_kernels, mask_preds, self_attn_mask=None): + B, C, H, W = features.shape + + soft_sigmoid_masks = mask_preds.sigmoid() + nonzero_inds = soft_sigmoid_masks > self.hard_mask_thr + hard_sigmoid_masks = nonzero_inds.float() + + # [B, N, C] + f = torch.einsum('bnhw,bchw->bnc', hard_sigmoid_masks, features) + # [B, N, C, K, K] -> [B, N, C * K * K] + num_proposals = proposal_kernels.shape[1] + k = proposal_kernels.view(B, num_proposals, -1) + + # ---- + if self.use_kernel_updator: + k = self.kernel_update(f, k) + else: + f_tmp = self.f_atten(k, f) + f = f + self.f_dropout(f_tmp) + f = self.f_atten_norm(f) + + f_tmp = self.k_atten(k, f) + f = f + self.k_dropout(f_tmp) + k = self.k_atten_norm(f) + + # [N, B, C] + k = k.permute(1, 0, 2) + + k_tmp = self.s_atten(query=k, key=k, value=k, attn_mask=self_attn_mask)[0] + k = k + self.s_dropout(k_tmp) + k = self.s_atten_norm(k.permute(1, 0, 2)) + + obj_feat = 
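A shape sketch (hypothetical sizes) of the hard-mask pooling step at the start of CrossAttenHead.forward above: each query gathers the features under its thresholded mask before the kernel update.

```python
import torch

B, N, C, H, W = 2, 100, 256, 64, 64
features = torch.randn(B, C, H, W)
mask_preds = torch.randn(B, N, H, W)
hard_masks = (mask_preds.sigmoid() > 0.5).float()          # hard_mask_thr
f = torch.einsum('bnhw,bchw->bnc', hard_masks, features)   # [B, N, C]
print(f.shape)
```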
self.ffn_norm(self.ffn(k)) + + cls_feat = obj_feat + mask_feat = obj_feat + + for cls_layer in self.cls_fcs: + cls_feat = cls_layer(cls_feat) + + if self.sphere_cls: + cls_embd = self.cls_proj(cls_feat) # FIXME Too much cls linear (cls_fcs + cls_proj) + cls_score = torch.einsum('bnc,ckp->bnkp', F.normalize(cls_embd, dim=-1), self.fc_cls) + cls_score = cls_score.max(-1).values + cls_score = self.logit_scale.exp() * cls_score + else: + cls_score = self.fc_cls(cls_feat) + for reg_layer in self.mask_fcs: + mask_feat = reg_layer(mask_feat) + # [B, N, K * K, C] -> [B, N, C] + mask_kernels = self.fc_mask(mask_feat) + + new_mask_preds = torch.einsum("bqc,bchw->bqhw", mask_kernels, features) + if self.with_iou_pred: + iou_pred = self.iou_embed(mask_feat) + iou_pred = iou_pred + else: + iou_pred = None + return cls_score, new_mask_preds, iou_pred, obj_feat diff --git a/app/models/necks/__init__.py b/app/models/necks/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..fc1b5ccddb7c9c67e676471774b4ee85dccf84e3 --- /dev/null +++ b/app/models/necks/__init__.py @@ -0,0 +1 @@ +from .ramsam_neck import YOSONeck diff --git a/app/models/necks/ramsam_neck.py b/app/models/necks/ramsam_neck.py new file mode 100644 index 0000000000000000000000000000000000000000..827e1c52156adf9f8aa97221f2a1724a4d742c37 --- /dev/null +++ b/app/models/necks/ramsam_neck.py @@ -0,0 +1,196 @@ +import math +import torch +from torch import nn +import torch.nn.functional as F +from mmengine.model import kaiming_init +from mmdet.registry import MODELS +from mmcv.ops import DeformConv2d, ModulatedDeformConv2d + + +class DeformLayer(nn.Module): + + def __init__(self, + in_planes, + out_planes, + deconv_kernel=4, + deconv_stride=2, + deconv_pad=1, + deconv_out_pad=0, + modulate_deform=True, + num_groups=1, + deform_num_groups=1, + dilation=1): + super(DeformLayer, self).__init__() + self.deform_modulated = modulate_deform + if modulate_deform: + deform_conv_op = ModulatedDeformConv2d + offset_channels = 27 + else: + deform_conv_op = DeformConv2d + offset_channels = 18 + + self.dcn_offset = nn.Conv2d(in_planes, offset_channels * deform_num_groups, kernel_size=3, stride=1, padding=1 * dilation, dilation=dilation) + self.dcn = deform_conv_op(in_planes, out_planes, kernel_size=3, stride=1, padding=1 * dilation, bias=False, groups=num_groups, dilation=dilation, deformable_groups=deform_num_groups) + for layer in [self.dcn]: + kaiming_init(layer) + + nn.init.constant_(self.dcn_offset.weight, 0) + nn.init.constant_(self.dcn_offset.bias, 0) + + # nn.GroupNorm(64, out_planes) # nn.BatchNorm2d(out_planes) # + self.dcn_bn = nn.SyncBatchNorm(out_planes) + self.up_sample = nn.ConvTranspose2d(in_channels=out_planes, out_channels=out_planes, kernel_size=deconv_kernel, stride=deconv_stride, padding=deconv_pad, output_padding=deconv_out_pad, bias=False) + self._deconv_init() + # nn.GroupNorm(64, out_planes) # nn.BatchNorm2d(out_planes) # + self.up_bn = nn.SyncBatchNorm(out_planes) + self.relu = nn.ReLU() + + def forward(self, x): + out = x + if self.deform_modulated: + offset_mask = self.dcn_offset(out) + offset_x, offset_y, mask = torch.chunk(offset_mask, 3, dim=1) + offset = torch.cat((offset_x, offset_y), dim=1) + mask = mask.sigmoid() + out = self.dcn(out, offset, mask) + else: + offset = self.dcn_offset(out) + out = self.dcn(out, offset) + x = out + + x = self.dcn_bn(x) + x = self.relu(x) + x = self.up_sample(x) + x = self.up_bn(x) + x = self.relu(x) + return x + + def _deconv_init(self): + w = self.up_sample.weight.data + f 
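A sketch of the open-vocabulary classification used above when sphere_cls=True (hypothetical sizes and random prototypes): cosine similarity between projected query embeddings and frozen text prototypes, max over prototypes, scaled by exp(logit_scale) = 100.

```python
import torch
import torch.nn.functional as F

B, N, D, K, P = 2, 100, 1024, 134, 1        # queries, embed dim, classes (+bg), prototypes
cls_embd = torch.randn(B, N, D)             # output of cls_proj
fc_cls = F.normalize(torch.randn(D, K, P), dim=0)
logit_scale = torch.tensor(4.6052)          # ln(100)

scores = torch.einsum('bnc,ckp->bnkp', F.normalize(cls_embd, dim=-1), fc_cls)
scores = scores.max(-1).values * logit_scale.exp()
print(scores.shape)                         # torch.Size([2, 100, 134])
```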
= math.ceil(w.size(2) / 2) + c = (2 * f - 1 - f % 2) / (2. * f) + for i in range(w.size(2)): + for j in range(w.size(3)): + w[0, 0, i, j] = \ + (1 - math.fabs(i / f - c)) * (1 - math.fabs(j / f - c)) + for c in range(1, w.size(0)): + w[c, 0, :, :] = w[0, 0, :, :] + +class LiteDeformConv(nn.Module): + def __init__(self, agg_dim, backbone_shape): + super(LiteDeformConv, self).__init__() + in_channels = [] + out_channels = [agg_dim] + for feat in backbone_shape: + in_channels.append(feat) + out_channels.append(feat//2) + + self.lateral_conv0 = nn.Conv2d(in_channels=in_channels[-1], out_channels=out_channels[-1], kernel_size=1, stride=1, padding=0) + + self.deform_conv1 = DeformLayer(in_planes=out_channels[-1], out_planes=out_channels[-2]) + + self.lateral_conv1 = nn.Conv2d(in_channels=in_channels[-2], out_channels=out_channels[-2], kernel_size=1, stride=1, padding=0) + + self.deform_conv2 = DeformLayer(in_planes=out_channels[-2], out_planes=out_channels[-3]) + + self.lateral_conv2 = nn.Conv2d(in_channels=in_channels[-3], out_channels=out_channels[-3], kernel_size=1, stride=1, padding=0) + + self.deform_conv3 = DeformLayer(in_planes=out_channels[-3], out_planes=out_channels[-4]) + + self.lateral_conv3 = nn.Conv2d(in_channels=in_channels[-4], out_channels=out_channels[-4], kernel_size=1, stride=1, padding=0) + + # self.fuse_conv = nn.Conv2d(in_channels=sum(out_channels[1:]), out_channels=out_channels[-5], kernel_size=3, stride=1, padding=1) + self.output_conv = nn.Conv2d(in_channels=out_channels[-5], out_channels=out_channels[-5], kernel_size=3, stride=1, padding=1) + + self.bias = nn.Parameter(torch.FloatTensor(1,out_channels[-5],1,1), requires_grad=True) + self.bias.data.fill_(0.0) + + self.conv_a5 = nn.Conv2d(in_channels=out_channels[-1], out_channels=out_channels[-5], kernel_size=1, stride=1, padding=0, bias=False) + self.conv_a4 = nn.Conv2d(in_channels=out_channels[-2], out_channels=out_channels[-5], kernel_size=1, stride=1, padding=0, bias=False) + self.conv_a3 = nn.Conv2d(in_channels=out_channels[-3], out_channels=out_channels[-5], kernel_size=1, stride=1, padding=0, bias=False) + self.conv_a2 = nn.Conv2d(in_channels=out_channels[-4], out_channels=out_channels[-5], kernel_size=1, stride=1, padding=0, bias=False) + + def forward(self, features_list): + p5 = self.lateral_conv0(features_list[-1]) + x5 = p5 + x = self.deform_conv1(x5) + + p4 = self.lateral_conv1(features_list[-2]) + x4 = p4 + x + x = self.deform_conv2(x4) + + p3 = self.lateral_conv2(features_list[-3]) + x3 = p3 + x + x = self.deform_conv3(x3) + + p2 = self.lateral_conv3(features_list[-4]) + x2 = p2 + x + + # CFA + x5 = self.conv_a5(x5) + x4 = self.conv_a4(x4) + x3 = self.conv_a3(x3) + + _x5 = F.interpolate(x5, scale_factor=8, align_corners=False, mode='bilinear') + _x4 = F.interpolate(x4, scale_factor=4, align_corners=False, mode='bilinear') + _x3 = F.interpolate(x3, scale_factor=2, align_corners=False, mode='bilinear') + x2 = self.conv_a2(x2) + x = _x5 + _x4 + _x3 + x2 + self.bias + + x = self.output_conv(x) + + return x, (x5, x4, x3) + + + +@MODELS.register_module() +class YOSONeck(nn.Module): + + def __init__(self, + agg_dim, + hidden_dim, + backbone_shape, + return_multi_scale=False, + return_single_scale=False, + #Just for compatibility with Mask2Former, not actually used + in_channels=None, + feat_channels=None, + out_channels=None + ): + super().__init__() + # in_channels == backbone_shape + # hidden_dim == feat_channels == out_channels == 256 + self.return_single_scale = return_single_scale + self.return_multi_scale 
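A quick check of the bilinear initialisation in DeformLayer._deconv_init above: for the default deconv_kernel=4 it yields the classic 1-D profile [0.25, 0.75, 0.75, 0.25], and the 2-D transposed-conv kernel is its outer product, copied to every channel.

```python
import math
import torch

k = 4
f = math.ceil(k / 2)
c = (2 * f - 1 - f % 2) / (2. * f)
w1d = torch.tensor([1 - math.fabs(i / f - c) for i in range(k)])
print(w1d)                    # tensor([0.2500, 0.7500, 0.7500, 0.2500])
print(torch.outer(w1d, w1d))  # per-channel upsampling weight
```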
= return_multi_scale + self.deconv = LiteDeformConv(agg_dim=agg_dim, backbone_shape=backbone_shape) + + self.loc_conv = nn.Conv2d(in_channels=agg_dim + 2, out_channels=hidden_dim, kernel_size=1, stride=1) + self.init_weights() + + def init_weights(self) -> None: + for p in self.parameters(): + if p.dim() > 1: + nn.init.xavier_uniform_(p) + + def generate_coord(self, input_feat): + x_range = torch.linspace(-1, 1, input_feat.shape[-1], device=input_feat.device) + y_range = torch.linspace(-1, 1, input_feat.shape[-2], device=input_feat.device) + y, x = torch.meshgrid(y_range, x_range) + y = y.expand([input_feat.shape[0], 1, -1, -1]) + x = x.expand([input_feat.shape[0], 1, -1, -1]) + coord_feat = torch.cat([x, y], 1) + return coord_feat + + def forward(self, + features_list, + batch_img_metas = None, + num_frames = None): + features, multi_scale = self.deconv(features_list) + coord_feat = self.generate_coord(features) + features = torch.cat([features, coord_feat], 1) + features = self.loc_conv(features) + if self.return_single_scale: # maskformer + return features, multi_scale[0] + if self.return_multi_scale: # mask2former + return features, multi_scale + return features \ No newline at end of file diff --git a/app/models/utils/__init__.py b/app/models/utils/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..0d113e32d37cb2ffe129336966fd09f369f8b0c2 --- /dev/null +++ b/app/models/utils/__init__.py @@ -0,0 +1,3 @@ +from .video_gt_preprocess import preprocess_video_panoptic_gt +from .mask_pool import mask_pool +from .no_obj import NO_OBJ diff --git a/app/models/utils/load_checkpoint.py b/app/models/utils/load_checkpoint.py new file mode 100644 index 0000000000000000000000000000000000000000..12e165a3f58814d7f9b933a68857c68a09a0971b --- /dev/null +++ b/app/models/utils/load_checkpoint.py @@ -0,0 +1,38 @@ +from mmengine.runner.checkpoint import CheckpointLoader + + +def load_checkpoint_with_prefix(filename, prefix=None, map_location='cpu', logger='current'): + """Load partial pretrained model with specific prefix. + + Args: + prefix (str): The prefix of sub-module. + filename (str): Accept local filepath, URL, ``torchvision://xxx``, + ``open-mmlab://xxx``. Please refer to ``docs/model_zoo.md`` for + details. + map_location (str | None): Same as :func:`torch.load`. + Defaults to None. + logger: logger + + Returns: + dict or OrderedDict: The loaded checkpoint. + """ + + checkpoint = CheckpointLoader.load_checkpoint(filename, map_location=map_location, logger=logger) + + if 'state_dict' in checkpoint: + state_dict = checkpoint['state_dict'] + else: + state_dict = checkpoint + if not prefix: + return state_dict + if not prefix.endswith('.'): + prefix += '.' 
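A stand-alone version of YOSONeck.generate_coord above (hypothetical sizes): two extra channels holding normalised x/y coordinates in [-1, 1] are concatenated to the aggregated features before loc_conv.

```python
import torch

feat = torch.randn(2, 128, 4, 6)
x_range = torch.linspace(-1, 1, feat.shape[-1])
y_range = torch.linspace(-1, 1, feat.shape[-2])
y, x = torch.meshgrid(y_range, x_range, indexing='ij')   # explicit indexing for clarity
coord = torch.cat([x.expand(2, 1, -1, -1), y.expand(2, 1, -1, -1)], dim=1)
print(torch.cat([feat, coord], dim=1).shape)              # torch.Size([2, 130, 4, 6])
```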
+ prefix_len = len(prefix) + + state_dict = { + k[prefix_len:]: v + for k, v in state_dict.items() if k.startswith(prefix) + } + + assert state_dict, f'{prefix} is not in the pretrained model' + return state_dict diff --git a/app/models/utils/mask_pool.py b/app/models/utils/mask_pool.py new file mode 100644 index 0000000000000000000000000000000000000000..06a315b50921703b60ba402a87f047a4e4033119 --- /dev/null +++ b/app/models/utils/mask_pool.py @@ -0,0 +1,27 @@ +import torch +import torch.nn.functional as F + + +# https://github.com/NVlabs/ODISE/blob/e97b06c424c575fec9fc5368dd4b3e050d91abc4/odise/modeling/meta_arch/odise.py#L923 + +def mask_pool(x, mask): + """ + Args: + x: [B, C, H, W] + mask: [B, Q, H, W] + """ + if not x.shape[-2:] == mask.shape[-2:]: + # reshape mask to x + mask = F.interpolate(mask, size=x.shape[-2:], mode='bilinear', align_corners=False) + with torch.no_grad(): + mask = mask.detach() + mask = (mask > 0).to(mask.dtype) + denorm = mask.sum(dim=(-1, -2), keepdim=True) + 1e-8 + + mask_pooled_x = torch.einsum( + "bchw,bqhw->bqc", + x, + mask / denorm, + ) + return mask_pooled_x + diff --git a/app/models/utils/no_obj.py b/app/models/utils/no_obj.py new file mode 100644 index 0000000000000000000000000000000000000000..8f4788486b1589905399634f1f484063fb2eee15 --- /dev/null +++ b/app/models/utils/no_obj.py @@ -0,0 +1 @@ +NO_OBJ = 65535 diff --git a/app/models/utils/video_gt_preprocess.py b/app/models/utils/video_gt_preprocess.py new file mode 100644 index 0000000000000000000000000000000000000000..83dfa2a433f04592e662e9191a6c06d3c64ad0dc --- /dev/null +++ b/app/models/utils/video_gt_preprocess.py @@ -0,0 +1,87 @@ +import torch + + +def preprocess_video_panoptic_gt( + gt_labels, + gt_masks, + gt_semantic_seg, + gt_instance_ids, + num_things, + num_stuff, +): + num_classes = num_things + num_stuff + num_frames = len(gt_masks) + mask_size = gt_masks[0].masks.shape[-2:] + + thing_masks_list = [] + for frame_id in range(num_frames): + thing_masks_list.append(gt_masks[frame_id].pad( + mask_size, pad_val=0).to_tensor( + dtype=torch.bool, device=gt_labels.device) + ) + instances = torch.unique(gt_instance_ids[:, 1]) + things_masks = [] + labels = [] + for instance in instances: + pos_ins = torch.nonzero(torch.eq(gt_instance_ids[:, 1], instance), as_tuple=True)[0] # 0 is for redundant tuple + labels_instance = gt_labels[:, 1][pos_ins] + assert torch.allclose(labels_instance, labels_instance[0]) + labels.append(labels_instance[0]) + instance_frame_ids = gt_instance_ids[:, 0][pos_ins].to(dtype=torch.int32).tolist() + instance_masks = [] + for frame_id in range(num_frames): + frame_instance_ids = gt_instance_ids[gt_instance_ids[:, 0] == frame_id, 1] + if frame_id not in instance_frame_ids: + empty_mask = torch.zeros( + mask_size, + dtype=thing_masks_list[frame_id].dtype, device=thing_masks_list[frame_id].device + ) + instance_masks.append(empty_mask) + else: + pos_inner_frame = torch.nonzero(torch.eq(frame_instance_ids, instance), as_tuple=True)[0].item() + frame_mask = thing_masks_list[frame_id][pos_inner_frame] + instance_masks.append(frame_mask) + things_masks.append(torch.stack(instance_masks)) + + if len(instances) == 0: + things_masks = torch.stack(thing_masks_list, dim=1) + labels = torch.empty_like(instances) + else: + things_masks = torch.stack(things_masks) + labels = torch.stack(labels) + assert torch.all(torch.less(labels, num_things)) + + if gt_semantic_seg is not None: + things_labels = labels + gt_semantic_seg = gt_semantic_seg.squeeze(1) + + semantic_labels = torch.unique( + 
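A usage sketch for mask_pool above (hypothetical sizes), assuming the function is in scope: mask logits are resized to the feature resolution, binarised at 0, and used to average-pool one feature vector per query.

```python
import torch

x = torch.randn(2, 256, 32, 32)        # [B, C, H, W] features
masks = torch.randn(2, 100, 64, 64)    # [B, Q, H', W'] mask logits
pooled = mask_pool(x, masks)
print(pooled.shape)                    # torch.Size([2, 100, 256])
```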
gt_semantic_seg, + sorted=False, + return_inverse=False, + return_counts=False) + stuff_masks_list = [] + stuff_labels_list = [] + for label in semantic_labels: + if label < num_things or label >= num_classes: + continue + stuff_mask = gt_semantic_seg == label + stuff_masks_list.append(stuff_mask) + stuff_labels_list.append(label) + + if len(stuff_masks_list) > 0: + stuff_masks = torch.stack(stuff_masks_list, dim=0) + stuff_labels = torch.stack(stuff_labels_list, dim=0) + assert torch.all(torch.ge(stuff_labels, num_things)) and torch.all(torch.less(stuff_labels, num_classes)) + labels = torch.cat([things_labels, stuff_labels], dim=0) + masks = torch.cat([things_masks, stuff_masks], dim=0) + else: + labels = things_labels + masks = things_masks + assert len(labels) == len(masks) + else: + masks = things_masks + + labels = labels.to(dtype=torch.long) + masks = masks.to(dtype=torch.long) + return labels, masks diff --git a/ext/meta/sam_meta.py b/ext/meta/sam_meta.py new file mode 100644 index 0000000000000000000000000000000000000000..aba1fa2daf3d24ba83cc54996ac7b5bf305e02e1 --- /dev/null +++ b/ext/meta/sam_meta.py @@ -0,0 +1,41 @@ +meta_dict = { + 'vit_h': dict( + encoder_embed_dim=1280, + encoder_depth=32, + encoder_num_heads=16, + encoder_global_attn_indexes=[7, 15, 23, 31], + # common + prompt_embed_dim=256, + image_size=1024, + vit_patch_size=16, + image_embedding_size=64 + ), + 'vit_l': dict( + encoder_embed_dim=1024, + encoder_depth=24, + encoder_num_heads=16, + encoder_global_attn_indexes=[5, 11, 17, 23], + # common + prompt_embed_dim=256, + image_size=1024, + vit_patch_size=16, + image_embedding_size=64 + ), + 'vit_b': dict( + encoder_embed_dim=768, + encoder_depth=12, + encoder_num_heads=12, + encoder_global_attn_indexes=[2, 5, 8, 11], + # common + prompt_embed_dim=256, + image_size=1024, + vit_patch_size=16, + image_embedding_size=64 + ) +} + +checkpoint_dict = { + 'vit_h': 'https://dl.fbaipublicfiles.com/segment_anything/sam_vit_h_4b8939.pth', + 'vit_l': 'https://dl.fbaipublicfiles.com/segment_anything/sam_vit_l_0b3195.pth', + 'vit_b': 'https://dl.fbaipublicfiles.com/segment_anything/sam_vit_b_01ec64.pth', +} diff --git a/ext/open_clip/__init__.py b/ext/open_clip/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..fdb1199b8aa87a919abff1bd0020c6624757ac62 --- /dev/null +++ b/ext/open_clip/__init__.py @@ -0,0 +1,15 @@ +from .coca_model import CoCa +from .constants import OPENAI_DATASET_MEAN, OPENAI_DATASET_STD +from .factory import create_model, create_model_and_transforms, create_model_from_pretrained, get_tokenizer, create_loss +from .factory import list_models, add_model_config, get_model_config, load_checkpoint +from .loss import ClipLoss, DistillClipLoss, CoCaLoss +from .model import CLIP, CustomTextCLIP, CLIPTextCfg, CLIPVisionCfg, \ + convert_weights_to_lp, convert_weights_to_fp16, trace_model, get_cast_dtype, get_input_dtype +from .openai import load_openai_model, list_openai_models +from .pretrained import list_pretrained, list_pretrained_models_by_tag, list_pretrained_tags_by_model, \ + get_pretrained_url, download_pretrained_from_url, is_pretrained_cfg, get_pretrained_cfg, download_pretrained +from .push_to_hf_hub import push_pretrained_to_hf_hub, push_to_hf_hub +from .tokenizer import SimpleTokenizer, tokenize, decode +from .transform import image_transform, AugmentationCfg +from .zero_shot_classifier import build_zero_shot_classifier, build_zero_shot_classifier_legacy +from .zero_shot_metadata import OPENAI_IMAGENET_TEMPLATES, 
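A quick sanity check on the SAM backbone settings declared above (the import path is an assumption about how the Space is laid out; the values are taken directly from meta_dict).

```python
from ext.meta.sam_meta import meta_dict, checkpoint_dict

cfg = meta_dict['vit_b']
print(cfg['encoder_embed_dim'], cfg['encoder_depth'], cfg['image_size'])  # 768 12 1024
print(checkpoint_dict['vit_b'])   # official SAM ViT-B checkpoint URL
```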
SIMPLE_IMAGENET_TEMPLATES, IMAGENET_CLASSNAMES diff --git a/ext/open_clip/bpe_simple_vocab_16e6.txt.gz b/ext/open_clip/bpe_simple_vocab_16e6.txt.gz new file mode 100644 index 0000000000000000000000000000000000000000..36a15856e00a06a9fbed8cdd34d2393fea4a3113 --- /dev/null +++ b/ext/open_clip/bpe_simple_vocab_16e6.txt.gz @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:924691ac288e54409236115652ad4aa250f48203de50a9e4722a6ecd48d6804a +size 1356917 diff --git a/ext/open_clip/coca_model.py b/ext/open_clip/coca_model.py new file mode 100644 index 0000000000000000000000000000000000000000..039453af70d1c865dd7cc6016f732aff2f7dc3d2 --- /dev/null +++ b/ext/open_clip/coca_model.py @@ -0,0 +1,458 @@ +from typing import Optional + +import torch +from torch import nn +from torch.nn import functional as F +import numpy as np +from dataclasses import dataclass + +from .transformer import ( + LayerNormFp32, + LayerNorm, + QuickGELU, + MultimodalTransformer, +) +from .model import CLIPTextCfg, CLIPVisionCfg, _build_vision_tower, _build_text_tower + +try: + from transformers import ( + BeamSearchScorer, + LogitsProcessorList, + TopPLogitsWarper, + TopKLogitsWarper, + RepetitionPenaltyLogitsProcessor, + MinLengthLogitsProcessor, + MaxLengthCriteria, + StoppingCriteriaList + ) + + GENERATION_TYPES = { + "top_k": TopKLogitsWarper, + "top_p": TopPLogitsWarper, + "beam_search": "beam_search" + } + _has_transformers = True +except ImportError as e: + GENERATION_TYPES = { + "top_k": None, + "top_p": None, + "beam_search": "beam_search" + } + _has_transformers = False + + +@dataclass +class MultimodalCfg(CLIPTextCfg): + mlp_ratio: int = 4 + dim_head: int = 64 + heads: int = 8 + n_queries: int = 256 + attn_pooler_heads: int = 8 + + +def _build_text_decoder_tower( + embed_dim, + multimodal_cfg, + quick_gelu: bool = False, + cast_dtype: Optional[torch.dtype] = None, +): + multimodal_cfg = MultimodalCfg(**multimodal_cfg) if isinstance(multimodal_cfg, dict) else multimodal_cfg + act_layer = QuickGELU if quick_gelu else nn.GELU + norm_layer = ( + LayerNormFp32 if cast_dtype in (torch.float16, torch.bfloat16) else LayerNorm + ) + + decoder = MultimodalTransformer( + context_length=multimodal_cfg.context_length, + width=multimodal_cfg.width, + heads=multimodal_cfg.heads, + layers=multimodal_cfg.layers, + ls_init_value=multimodal_cfg.ls_init_value, + output_dim=embed_dim, + act_layer=act_layer, + norm_layer=norm_layer, + ) + + return decoder + + +class CoCa(nn.Module): + def __init__( + self, + embed_dim, + multimodal_cfg: MultimodalCfg, + text_cfg: CLIPTextCfg, + vision_cfg: CLIPVisionCfg, + quick_gelu: bool = False, + cast_dtype: Optional[torch.dtype] = None, + pad_id: int = 0, + ): + super().__init__() + multimodal_cfg = MultimodalCfg(**multimodal_cfg) if isinstance(multimodal_cfg, dict) else multimodal_cfg + text_cfg = CLIPTextCfg(**text_cfg) if isinstance(text_cfg, dict) else text_cfg + vision_cfg = CLIPVisionCfg(**vision_cfg) if isinstance(vision_cfg, dict) else vision_cfg + + self.text = _build_text_tower( + embed_dim=embed_dim, + text_cfg=text_cfg, + quick_gelu=quick_gelu, + cast_dtype=cast_dtype, + ) + + vocab_size = ( + text_cfg.vocab_size # for hf models + if hasattr(text_cfg, "hf_model_name") and text_cfg.hf_model_name is not None + else text_cfg.vocab_size + ) + + self.visual = _build_vision_tower( + embed_dim=embed_dim, + vision_cfg=vision_cfg, + quick_gelu=quick_gelu, + cast_dtype=cast_dtype, + ) + + self.text_decoder = _build_text_decoder_tower( + vocab_size, + 
multimodal_cfg=multimodal_cfg, + quick_gelu=quick_gelu, + cast_dtype=cast_dtype, + ) + + self.logit_scale = nn.Parameter(torch.ones([]) * np.log(1 / 0.07)) + self.pad_id = pad_id + + @torch.jit.ignore + def set_grad_checkpointing(self, enable=True): + self.visual.set_grad_checkpointing(enable) + self.text.set_grad_checkpointing(enable) + self.text_decoder.set_grad_checkpointing(enable) + + def _encode_image(self, images, normalize=True): + image_latent, tokens_embs = self.visual(images) + image_latent = F.normalize(image_latent, dim=-1) if normalize else image_latent + return image_latent, tokens_embs + + def _encode_text(self, text, normalize=True, embed_cls=True): + text = text[:, :-1] if embed_cls else text # make space for CLS token + text_latent, token_emb = self.text(text) + text_latent = F.normalize(text_latent, dim=-1) if normalize else text_latent + return text_latent, token_emb + + def encode_image(self, images, normalize=True): + image_latent, _ = self._encode_image(images, normalize=normalize) + return image_latent + + def encode_text(self, text, normalize=True, embed_cls=True): + text_latent, _ = self._encode_text(text, normalize=normalize, embed_cls=embed_cls) + return text_latent + + def forward(self, image, text, embed_cls=True, image_latent=None, image_embs=None): + text_latent, token_embs = self._encode_text(text, embed_cls=embed_cls) + if image_latent is None or image_embs is None: + image_latent, image_embs = self._encode_image(image) + + # TODO: add assertion to avoid bugs? + labels = text[:, -token_embs.shape[1]:] + + logits = self.text_decoder(image_embs, token_embs) + return { + "image_features": image_latent, + "text_features": text_latent, + "logits": logits, + "labels": labels, + "logit_scale": self.logit_scale.exp() + } + + def generate( + self, + image, + text=None, + seq_len=30, + max_seq_len=77, + temperature=1., + generation_type="beam_search", + top_p=0.1, # keep tokens in the 1 - top_p quantile + top_k=1, # keeps the top_k most probable tokens + pad_token_id=None, + eos_token_id=None, + sot_token_id=None, + num_beams=6, + num_beam_groups=3, + min_seq_len=5, + stopping_criteria=None, + repetition_penalty=1.0, + fixed_output_length=False # if True output.shape == (batch_size, seq_len) + ): + # taking many ideas and components from HuggingFace GenerationMixin + # https://huggingface.co/docs/transformers/main/en/main_classes/text_generation + assert _has_transformers, "Please install transformers for generate functionality. `pip install transformers`." 
+ assert seq_len > min_seq_len, "seq_len must be larger than min_seq_len" + + with torch.no_grad(): + sot_token_id = 49406 if sot_token_id is None else sot_token_id + eos_token_id = 49407 if eos_token_id is None else eos_token_id + pad_token_id = self.pad_id if pad_token_id is None else pad_token_id + logit_processor = LogitsProcessorList( + [ + MinLengthLogitsProcessor(min_seq_len, eos_token_id), + RepetitionPenaltyLogitsProcessor(repetition_penalty), + ] + ) + + if stopping_criteria is None: + stopping_criteria = [MaxLengthCriteria(max_length=seq_len)] + + stopping_criteria = StoppingCriteriaList( + stopping_criteria + ) + + device = image.device + + if generation_type == "beam_search": + output = self._generate_beamsearch( + image_inputs = image, + pad_token_id=pad_token_id, + eos_token_id=eos_token_id, + sot_token_id=sot_token_id, + num_beams=num_beams, + num_beam_groups=num_beam_groups, + min_seq_len=min_seq_len, + stopping_criteria=stopping_criteria, + logit_processor=logit_processor, + ) + if fixed_output_length and output.shape[1] < seq_len: + return torch.cat( + (output, torch.ones(output.shape[0], seq_len-output.shape[1], device=device, dtype=output.dtype) * self.pad_id), + dim=1 + ) + return output + + elif generation_type == "top_p": + logit_warper = GENERATION_TYPES[generation_type](top_p) + elif generation_type == "top_k": + logit_warper = GENERATION_TYPES[generation_type](top_k) + else: + raise ValueError( + f"generation_type has to be one of " + f"{'| ' + ' | '.join(list(GENERATION_TYPES.keys())) + ' |'}." + ) + + image_latent, image_embs = self._encode_image(image) + + if text is None: + text = torch.ones((image.shape[0], 1), device=device, dtype=torch.long) * sot_token_id + + was_training = self.training + num_dims = len(text.shape) + + if num_dims == 1: + text = text[None, :] + + cur_len = text.shape[1] + self.eval() + out = text + + while True: + x = out[:, -max_seq_len:] + cur_len = x.shape[1] + logits = self(image, x, image_latent=image_latent, image_embs=image_embs, embed_cls=False)["logits"][:, -1] + mask = (out[:, -1] == eos_token_id) | (out[:, -1] == pad_token_id) + sample = torch.ones((out.shape[0], 1), device=device, dtype=torch.long) * pad_token_id + + if mask.all(): + if not fixed_output_length: + break + else: + logits = logits[~mask, :] + filtered_logits = logit_processor(x[~mask, :], logits) + filtered_logits = logit_warper(x[~mask, :], filtered_logits) + probs = F.softmax(filtered_logits / temperature, dim=-1) + + if (cur_len + 1 == seq_len): + sample[~mask, :] = torch.ones((sum(~mask), 1), device=device, dtype=torch.long) * eos_token_id + else: + sample[~mask, :] = torch.multinomial(probs, 1) + + out = torch.cat((out, sample), dim=-1) + + cur_len += 1 + + if stopping_criteria(out, None): + break + + if num_dims == 1: + out = out.squeeze(0) + + self.train(was_training) + return out + + def _generate_beamsearch( + self, + image_inputs, + pad_token_id=None, + eos_token_id=None, + sot_token_id=None, + num_beams=6, + num_beam_groups=3, + min_seq_len=5, + stopping_criteria=None, + logit_processor=None, + logit_warper=None, + ): + device = image_inputs.device + batch_size = image_inputs.shape[0] + image_inputs = torch.repeat_interleave(image_inputs, num_beams, dim=0) + image_latent, image_embs = self._encode_image(image_inputs) + + input_ids = torch.ones((batch_size * num_beams, 1), device=device, dtype=torch.long) + input_ids = input_ids * sot_token_id + beam_scorer = BeamSearchScorer( + batch_size=batch_size, + num_beams=num_beams, + device=device, + 
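The non-beam path of CoCa.generate above filters next-token logits with HuggingFace logits warpers before multinomial sampling; a minimal stand-alone demonstration with a hypothetical vocabulary size:

```python
import torch
from transformers import TopKLogitsWarper, TopPLogitsWarper

logits = torch.randn(1, 49408)                      # one step of vocabulary logits
ids = torch.ones(1, 1, dtype=torch.long)            # running token ids
top_k_logits = TopKLogitsWarper(top_k=1)(ids, logits.clone())
top_p_logits = TopPLogitsWarper(top_p=0.1)(ids, logits.clone())
probs = torch.softmax(top_k_logits, dim=-1)         # with top_k=1 sampling reduces to greedy
print(int(torch.multinomial(probs, 1)))
```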
num_beam_groups=num_beam_groups, + ) + # instantiate logits processors + logits_processor = ( + LogitsProcessorList([MinLengthLogitsProcessor(min_seq_len, eos_token_id=eos_token_id)]) + if logit_processor is None + else logit_processor + ) + + batch_size = len(beam_scorer._beam_hyps) + num_beams = beam_scorer.num_beams + num_beam_groups = beam_scorer.num_beam_groups + num_sub_beams = num_beams // num_beam_groups + batch_beam_size, cur_len = input_ids.shape + beam_indices = None + + if num_beams * batch_size != batch_beam_size: + raise ValueError( + f"Batch dimension of `input_ids` should be {num_beams * batch_size}, but is {batch_beam_size}." + ) + + beam_scores = torch.full((batch_size, num_beams), -1e9, dtype=torch.float, device=device) + # initialise score of first beam of each group with 0 and the rest with 1e-9. This ensures that the beams in + # the same group don't produce same tokens everytime. + beam_scores[:, ::num_sub_beams] = 0 + beam_scores = beam_scores.view((batch_size * num_beams,)) + + while True: + + # predicted tokens in cur_len step + current_tokens = torch.zeros(batch_size * num_beams, dtype=input_ids.dtype, device=device) + + # indices which will form the beams in the next time step + reordering_indices = torch.zeros(batch_size * num_beams, dtype=torch.long, device=device) + + # do one decoder step on all beams of all sentences in batch + model_inputs = prepare_inputs_for_generation(input_ids=input_ids, image_inputs=image_inputs) + outputs = self( + model_inputs['images'], + model_inputs['text'], + embed_cls=False, + image_latent=image_latent, + image_embs=image_embs + ) + + for beam_group_idx in range(num_beam_groups): + group_start_idx = beam_group_idx * num_sub_beams + group_end_idx = min(group_start_idx + num_sub_beams, num_beams) + group_size = group_end_idx - group_start_idx + + # indices of beams of current group among all sentences in batch + batch_group_indices = [] + + for batch_idx in range(batch_size): + batch_group_indices.extend( + [batch_idx * num_beams + idx for idx in range(group_start_idx, group_end_idx)] + ) + group_input_ids = input_ids[batch_group_indices] + + # select outputs of beams of currentg group only + next_token_logits = outputs['logits'][batch_group_indices, -1, :] + vocab_size = next_token_logits.shape[-1] + + next_token_scores_processed = logits_processor( + group_input_ids, next_token_logits, current_tokens=current_tokens, beam_group_idx=beam_group_idx + ) + next_token_scores = next_token_scores_processed + beam_scores[batch_group_indices].unsqueeze(-1) + next_token_scores = next_token_scores.expand_as(next_token_scores_processed) + + # reshape for beam search + next_token_scores = next_token_scores.view(batch_size, group_size * vocab_size) + + next_token_scores, next_tokens = torch.topk( + next_token_scores, 2 * group_size, dim=1, largest=True, sorted=True + ) + + next_indices = torch.div(next_tokens, vocab_size, rounding_mode="floor") + next_tokens = next_tokens % vocab_size + + # stateless + process_beam_indices = sum(beam_indices, ()) if beam_indices is not None else None + beam_outputs = beam_scorer.process( + group_input_ids, + next_token_scores, + next_tokens, + next_indices, + pad_token_id=pad_token_id, + eos_token_id=eos_token_id, + beam_indices=process_beam_indices, + ) + beam_scores[batch_group_indices] = beam_outputs["next_beam_scores"] + beam_next_tokens = beam_outputs["next_beam_tokens"] + beam_idx = beam_outputs["next_beam_indices"] + + input_ids[batch_group_indices] = group_input_ids[beam_idx] + group_input_ids = 
torch.cat([group_input_ids[beam_idx, :], beam_next_tokens.unsqueeze(-1)], dim=-1) + current_tokens[batch_group_indices] = group_input_ids[:, -1] + + # (beam_idx // group_size) -> batch_idx + # (beam_idx % group_size) -> offset of idx inside the group + reordering_indices[batch_group_indices] = ( + num_beams * torch.div(beam_idx, group_size, rounding_mode="floor") + group_start_idx + (beam_idx % group_size) + ) + + input_ids = torch.cat([input_ids, current_tokens.unsqueeze(-1)], dim=-1) + + # increase cur_len + cur_len = cur_len + 1 + if beam_scorer.is_done or stopping_criteria(input_ids, None): + break + + final_beam_indices = sum(beam_indices, ()) if beam_indices is not None else None + sequence_outputs = beam_scorer.finalize( + input_ids, + beam_scores, + next_tokens, + next_indices, + pad_token_id=pad_token_id, + eos_token_id=eos_token_id, + max_length=stopping_criteria.max_length, + beam_indices=final_beam_indices, + ) + return sequence_outputs['sequences'] + + +def prepare_inputs_for_generation(input_ids, image_inputs, past=None, **kwargs): + if past: + input_ids = input_ids[:, -1].unsqueeze(-1) + + attention_mask = kwargs.get("attention_mask", None) + position_ids = kwargs.get("position_ids", None) + + if attention_mask is not None and position_ids is None: + # create position_ids on the fly for batch generation + position_ids = attention_mask.long().cumsum(-1) - 1 + position_ids.masked_fill_(attention_mask == 0, 1) + else: + position_ids = None + return { + "text": input_ids, + "images": image_inputs, + "past_key_values": past, + "position_ids": position_ids, + "attention_mask": attention_mask, + } diff --git a/ext/open_clip/constants.py b/ext/open_clip/constants.py new file mode 100644 index 0000000000000000000000000000000000000000..a670bb3fab442baeb9af53b91c312e6982af57ee --- /dev/null +++ b/ext/open_clip/constants.py @@ -0,0 +1,2 @@ +OPENAI_DATASET_MEAN = (0.48145466, 0.4578275, 0.40821073) +OPENAI_DATASET_STD = (0.26862954, 0.26130258, 0.27577711) diff --git a/ext/open_clip/factory.py b/ext/open_clip/factory.py new file mode 100644 index 0000000000000000000000000000000000000000..dce3e9fbb089804edb9f63775ccb2d832cab2500 --- /dev/null +++ b/ext/open_clip/factory.py @@ -0,0 +1,387 @@ +import json +import logging +import os +import pathlib +import re +from copy import deepcopy +from pathlib import Path +from typing import Any, Dict, Optional, Tuple, Union + +import torch + +from .constants import OPENAI_DATASET_MEAN, OPENAI_DATASET_STD +from .model import CLIP, CustomTextCLIP, convert_weights_to_lp, convert_to_custom_text_state_dict,\ + resize_pos_embed, get_cast_dtype +from .coca_model import CoCa +from .loss import ClipLoss, DistillClipLoss, CoCaLoss +from .openai import load_openai_model +from .pretrained import is_pretrained_cfg, get_pretrained_cfg, download_pretrained,\ + list_pretrained_tags_by_model, download_pretrained_from_hf +from .transform import image_transform, AugmentationCfg +from .tokenizer import HFTokenizer, tokenize + + +HF_HUB_PREFIX = 'hf-hub:' +_MODEL_CONFIG_PATHS = [Path(__file__).parent / f"model_configs/"] +_MODEL_CONFIGS = {} # directory (model_name: config) of model architecture configs + + +def _natural_key(string_): + return [int(s) if s.isdigit() else s for s in re.split(r'(\d+)', string_.lower())] + + +def _rescan_model_configs(): + global _MODEL_CONFIGS + + config_ext = ('.json',) + config_files = [] + for config_path in _MODEL_CONFIG_PATHS: + if config_path.is_file() and config_path.suffix in config_ext: + config_files.append(config_path) + elif 
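The config registry above sorts model names with _natural_key so numeric parts compare as integers rather than strings; a small illustration, assuming the helper is in scope:

```python
names = ['ViT-L-14', 'ViT-B-32', 'ViT-B-16']
print(sorted(names, key=_natural_key))   # ['ViT-B-16', 'ViT-B-32', 'ViT-L-14']
```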
config_path.is_dir(): + for ext in config_ext: + config_files.extend(config_path.glob(f'*{ext}')) + + for cf in config_files: + with open(cf, 'r') as f: + model_cfg = json.load(f) + if all(a in model_cfg for a in ('embed_dim', 'vision_cfg', 'text_cfg')): + _MODEL_CONFIGS[cf.stem] = model_cfg + + _MODEL_CONFIGS = {k: v for k, v in sorted(_MODEL_CONFIGS.items(), key=lambda x: _natural_key(x[0]))} + + +_rescan_model_configs() # initial populate of model config registry + + +def list_models(): + """ enumerate available model architectures based on config files """ + return list(_MODEL_CONFIGS.keys()) + + +def add_model_config(path): + """ add model config path or file and update registry """ + if not isinstance(path, Path): + path = Path(path) + _MODEL_CONFIG_PATHS.append(path) + _rescan_model_configs() + + +def get_model_config(model_name): + if model_name in _MODEL_CONFIGS: + return deepcopy(_MODEL_CONFIGS[model_name]) + else: + return None + + +def get_tokenizer(model_name): + if model_name.startswith(HF_HUB_PREFIX): + tokenizer = HFTokenizer(model_name[len(HF_HUB_PREFIX):]) + else: + config = get_model_config(model_name) + tokenizer = HFTokenizer( + config['text_cfg']['hf_tokenizer_name']) if 'hf_tokenizer_name' in config['text_cfg'] else tokenize + return tokenizer + + +def load_state_dict(checkpoint_path: str, map_location='cpu'): + checkpoint = torch.load(checkpoint_path, map_location=map_location) + if isinstance(checkpoint, dict) and 'state_dict' in checkpoint: + state_dict = checkpoint['state_dict'] + else: + state_dict = checkpoint + if next(iter(state_dict.items()))[0].startswith('module'): + state_dict = {k[7:]: v for k, v in state_dict.items()} + return state_dict + + +def load_checkpoint(model, checkpoint_path, strict=True): + state_dict = load_state_dict(checkpoint_path) + # detect old format and make compatible with new format + if 'positional_embedding' in state_dict and not hasattr(model, 'positional_embedding'): + state_dict = convert_to_custom_text_state_dict(state_dict) + resize_pos_embed(state_dict, model) + incompatible_keys = model.load_state_dict(state_dict, strict=strict) + return incompatible_keys + + +def create_model( + model_name: str, + pretrained: Optional[str] = None, + precision: str = 'fp32', + device: Union[str, torch.device] = 'cpu', + jit: bool = False, + force_quick_gelu: bool = False, + force_custom_text: bool = False, + force_patch_dropout: Optional[float] = None, + force_image_size: Optional[Union[int, Tuple[int, int]]] = None, + pretrained_image: bool = False, + pretrained_hf: bool = True, + cache_dir: Optional[str] = None, + output_dict: Optional[bool] = None, + require_pretrained: bool = False, + logger: logging.Logger = logging, +): + has_hf_hub_prefix = model_name.startswith(HF_HUB_PREFIX) + if has_hf_hub_prefix: + model_id = model_name[len(HF_HUB_PREFIX):] + checkpoint_path = download_pretrained_from_hf(model_id, cache_dir=cache_dir) + config_path = download_pretrained_from_hf(model_id, filename='open_clip_config.json', cache_dir=cache_dir) + + with open(config_path, 'r', encoding='utf-8') as f: + config = json.load(f) + pretrained_cfg = config['preprocess_cfg'] + model_cfg = config['model_cfg'] + else: + model_name = model_name.replace('/', '-') # for callers using old naming with / in ViT names + checkpoint_path = None + pretrained_cfg = {} + model_cfg = None + + if isinstance(device, str): + device = torch.device(device) + + if pretrained and pretrained.lower() == 'openai': + logger.info(f'Loading pretrained {model_name} from OpenAI.') + 
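A hypothetical tokenizer lookup via the factory above (the model name is illustrative; any name returned by list_models() works). Configs without an hf_tokenizer_name fall back to the built-in CLIP BPE tokenizer with a 77-token context.

```python
tokens = get_tokenizer('convnext_large_d_320')(['a photo of a cat', 'a photo of a dog'])
print(tokens.shape)   # torch.Size([2, 77]) padded token ids
```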
model = load_openai_model( + model_name, + precision=precision, + device=device, + cache_dir=cache_dir, + ) + else: + model_cfg = model_cfg or get_model_config(model_name) + if model_cfg is not None: + logger.info(f'Loaded {model_name} model config.') + else: + logger.error(f'Model config for {model_name} not found; available models {list_models()}.') + raise RuntimeError(f'Model config for {model_name} not found.') + + if force_quick_gelu: + # override for use of QuickGELU on non-OpenAI transformer models + model_cfg["quick_gelu"] = True + + if force_patch_dropout is not None: + # override the default patch dropout value + model_cfg["vision_cfg"]["patch_dropout"] = force_patch_dropout + + if force_image_size is not None: + # override model config's image size + model_cfg["vision_cfg"]["image_size"] = force_image_size + + is_timm_model = 'timm_model_name' in model_cfg.get('vision_cfg', {}) + if pretrained_image: + if is_timm_model: + # pretrained weight loading for timm models set via vision_cfg + model_cfg['vision_cfg']['timm_model_pretrained'] = True + else: + assert False, 'pretrained image towers currently only supported for timm models' + + # cast_dtype set for fp16 and bf16 (manual mixed-precision), not set for 'amp' or 'pure' modes + cast_dtype = get_cast_dtype(precision) + is_hf_model = 'hf_model_name' in model_cfg.get('text_cfg', {}) + custom_text = model_cfg.pop('custom_text', False) or force_custom_text or is_hf_model + + if custom_text: + if is_hf_model: + model_cfg['text_cfg']['hf_model_pretrained'] = pretrained_hf + if "coca" in model_name: + model = CoCa(**model_cfg, cast_dtype=cast_dtype) + else: + model = CustomTextCLIP(**model_cfg, cast_dtype=cast_dtype) + else: + model = CLIP(**model_cfg, cast_dtype=cast_dtype) + + if precision in ("fp16", "bf16"): + dtype = torch.float16 if 'fp16' in precision else torch.bfloat16 + # manual mixed precision that matches original OpenAI behaviour + if is_timm_model: + # FIXME this is a bit janky, create timm based model in low-precision and + # then cast only LayerNormFp32 instances back to float32 so they don't break. + # Why? The convert_weights_to_lp fn only works with native models. + model.to(device=device, dtype=dtype) + from .transformer import LayerNormFp32 + def _convert_ln(m): + if isinstance(m, LayerNormFp32): + m.weight.data = m.weight.data.to(torch.float32) + m.bias.data = m.bias.data.to(torch.float32) + model.apply(_convert_ln) + else: + model.to(device=device) + convert_weights_to_lp(model, dtype=dtype) + elif precision in ("pure_fp16", "pure_bf16"): + dtype = torch.float16 if 'fp16' in precision else torch.bfloat16 + model.to(device=device, dtype=dtype) + else: + model.to(device=device) + + pretrained_loaded = False + if pretrained: + checkpoint_path = '' + pretrained_cfg = get_pretrained_cfg(model_name, pretrained) + if pretrained_cfg: + checkpoint_path = download_pretrained(pretrained_cfg, cache_dir=cache_dir) + elif os.path.exists(pretrained): + checkpoint_path = pretrained + + if checkpoint_path: + logger.info(f'Loading pretrained {model_name} weights ({pretrained}).') + load_checkpoint(model, checkpoint_path) + else: + error_str = ( + f'Pretrained weights ({pretrained}) not found for model {model_name}.' 
+ f'Available pretrained tags ({list_pretrained_tags_by_model(model_name)}.') + logger.warning(error_str) + raise RuntimeError(error_str) + pretrained_loaded = True + elif has_hf_hub_prefix: + logger.info(f'Loading pretrained {model_name} weights ({pretrained}).') + load_checkpoint(model, checkpoint_path) + pretrained_loaded = True + + if require_pretrained and not pretrained_loaded: + # callers of create_model_from_pretrained always expect pretrained weights + raise RuntimeError( + f'Pretrained weights were required for (model: {model_name}, pretrained: {pretrained}) but not loaded.') + + # set image / mean metadata from pretrained_cfg if available, or use default + model.visual.image_mean = pretrained_cfg.get('mean', None) or OPENAI_DATASET_MEAN + model.visual.image_std = pretrained_cfg.get('std', None) or OPENAI_DATASET_STD + + if output_dict and hasattr(model, "output_dict"): + model.output_dict = True + + if jit: + model = torch.jit.script(model) + + return model + + +def create_loss(args): + if args.distill: + return DistillClipLoss( + local_loss=args.local_loss, + gather_with_grad=args.gather_with_grad, + cache_labels=True, + rank=args.rank, + world_size=args.world_size, + use_horovod=args.horovod, + ) + elif "coca" in args.model.lower(): + return CoCaLoss( + caption_loss_weight=args.coca_caption_loss_weight, + clip_loss_weight=args.coca_contrastive_loss_weight, + local_loss=args.local_loss, + gather_with_grad=args.gather_with_grad, + cache_labels=True, + rank=args.rank, + world_size=args.world_size, + use_horovod=args.horovod, + ) + return ClipLoss( + local_loss=args.local_loss, + gather_with_grad=args.gather_with_grad, + cache_labels=True, + rank=args.rank, + world_size=args.world_size, + use_horovod=args.horovod, + ) + + +def create_model_and_transforms( + model_name: str, + pretrained: Optional[str] = None, + precision: str = 'fp32', + device: Union[str, torch.device] = 'cpu', + jit: bool = False, + force_quick_gelu: bool = False, + force_custom_text: bool = False, + force_patch_dropout: Optional[float] = None, + force_image_size: Optional[Union[int, Tuple[int, int]]] = None, + pretrained_image: bool = False, + pretrained_hf: bool = True, + image_mean: Optional[Tuple[float, ...]] = None, + image_std: Optional[Tuple[float, ...]] = None, + aug_cfg: Optional[Union[Dict[str, Any], AugmentationCfg]] = None, + cache_dir: Optional[str] = None, + output_dict: Optional[bool] = None, + logger: logging.Logger = logging, +): + model = create_model( + model_name, + pretrained, + precision=precision, + device=device, + jit=jit, + force_quick_gelu=force_quick_gelu, + force_custom_text=force_custom_text, + force_patch_dropout=force_patch_dropout, + force_image_size=force_image_size, + pretrained_image=pretrained_image, + pretrained_hf=pretrained_hf, + cache_dir=cache_dir, + output_dict=output_dict, + logger=logger, + ) + + image_mean = image_mean or getattr(model.visual, 'image_mean', None) + image_std = image_std or getattr(model.visual, 'image_std', None) + preprocess_train = image_transform( + model.visual.image_size, + is_train=True, + mean=image_mean, + std=image_std, + aug_cfg=aug_cfg, + ) + preprocess_val = image_transform( + model.visual.image_size, + is_train=False, + mean=image_mean, + std=image_std, + ) + + return model, preprocess_train, preprocess_val + + +def create_model_from_pretrained( + model_name: str, + pretrained: Optional[str] = None, + precision: str = 'fp32', + device: Union[str, torch.device] = 'cpu', + jit: bool = False, + force_quick_gelu: bool = False, + 
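An end-to-end sketch of the factory above (model name, pretrained tag, and output dimension are illustrative; check list_pretrained() for valid pairs, and note the weights are downloaded on first use): build the towers and preprocessing, then encode a text prompt.

```python
import torch

model, _, preprocess_val = create_model_and_transforms(
    'convnext_large_d_320', pretrained='laion2b_s29b_b131k_ft_soup', device='cpu')
tokenizer = get_tokenizer('convnext_large_d_320')
with torch.no_grad():
    text_feat = model.encode_text(tokenizer(['a photo of a person']))
print(text_feat.shape)   # e.g. torch.Size([1, 768])
```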
force_custom_text: bool = False, + force_image_size: Optional[Union[int, Tuple[int, int]]] = None, + return_transform: bool = True, + image_mean: Optional[Tuple[float, ...]] = None, + image_std: Optional[Tuple[float, ...]] = None, + cache_dir: Optional[str] = None, + logger: logging.Logger = logging, +): + model = create_model( + model_name, + pretrained, + precision=precision, + device=device, + jit=jit, + force_quick_gelu=force_quick_gelu, + force_custom_text=force_custom_text, + force_image_size=force_image_size, + cache_dir=cache_dir, + require_pretrained=True, + logger=logger, + ) + + if not return_transform: + return model + + image_mean = image_mean or getattr(model.visual, 'image_mean', None) + image_std = image_std or getattr(model.visual, 'image_std', None) + preprocess = image_transform( + model.visual.image_size, + is_train=False, + mean=image_mean, + std=image_std, + ) + + return model, preprocess diff --git a/ext/open_clip/generation_utils.py b/ext/open_clip/generation_utils.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/ext/open_clip/hf_configs.py b/ext/open_clip/hf_configs.py new file mode 100644 index 0000000000000000000000000000000000000000..13c9bfd8c660eac59f1fbc1912b9fccc9c0c625a --- /dev/null +++ b/ext/open_clip/hf_configs.py @@ -0,0 +1,56 @@ +# HF architecture dict: +arch_dict = { + # https://huggingface.co/docs/transformers/model_doc/roberta#roberta + "roberta": { + "config_names": { + "context_length": "max_position_embeddings", + "vocab_size": "vocab_size", + "width": "hidden_size", + "heads": "num_attention_heads", + "layers": "num_hidden_layers", + "layer_attr": "layer", + "token_embeddings_attr": "embeddings" + }, + "pooler": "mean_pooler", + }, + # https://huggingface.co/docs/transformers/model_doc/xlm-roberta#transformers.XLMRobertaConfig + "xlm-roberta": { + "config_names": { + "context_length": "max_position_embeddings", + "vocab_size": "vocab_size", + "width": "hidden_size", + "heads": "num_attention_heads", + "layers": "num_hidden_layers", + "layer_attr": "layer", + "token_embeddings_attr": "embeddings" + }, + "pooler": "mean_pooler", + }, + # https://huggingface.co/docs/transformers/model_doc/mt5#mt5 + "mt5": { + "config_names": { + # unlimited seqlen + # https://github.com/google-research/text-to-text-transfer-transformer/issues/273 + # https://github.com/huggingface/transformers/blob/v4.24.0/src/transformers/models/t5/modeling_t5.py#L374 + "context_length": "", + "vocab_size": "vocab_size", + "width": "d_model", + "heads": "num_heads", + "layers": "num_layers", + "layer_attr": "block", + "token_embeddings_attr": "embed_tokens" + }, + "pooler": "mean_pooler", + }, + # https://huggingface.co/docs/transformers/model_doc/bert + "bert": { + "config_names": { + "context_length": "max_position_embeddings", + "vocab_size": "vocab_size", + "width": "hidden_size", + "heads": "num_attention_heads", + "layers": "num_hidden_layers", + }, + "pooler": "cls_pooler", + }, +} diff --git a/ext/open_clip/hf_model.py b/ext/open_clip/hf_model.py new file mode 100644 index 0000000000000000000000000000000000000000..08dbdbcde02b550ca765ca9bcb0b667be2c0443d --- /dev/null +++ b/ext/open_clip/hf_model.py @@ -0,0 +1,193 @@ +""" huggingface model adapter + +Wraps HuggingFace transformers (https://github.com/huggingface/transformers) models for use as a text tower in CLIP model. 
+""" +import re + +import torch +import torch.nn as nn +from torch import TensorType + +try: + import transformers + from transformers import AutoModel, AutoTokenizer, AutoConfig, PretrainedConfig + from transformers.modeling_outputs import BaseModelOutput, BaseModelOutputWithPooling, \ + BaseModelOutputWithPoolingAndCrossAttentions +except ImportError as e: + transformers = None + + + class BaseModelOutput: + pass + + + class PretrainedConfig: + pass + +from .hf_configs import arch_dict + + +# utils +def _camel2snake(s): + return re.sub(r'(? torch.Tensor: + # calculated ground-truth and cache if enabled + if self.prev_num_logits != num_logits or device not in self.labels: + labels = torch.arange(num_logits, device=device, dtype=torch.long) + if self.world_size > 1 and self.local_loss: + labels = labels + num_logits * self.rank + if self.cache_labels: + self.labels[device] = labels + self.prev_num_logits = num_logits + else: + labels = self.labels[device] + return labels + + def get_logits(self, image_features, text_features, logit_scale): + if self.world_size > 1: + all_image_features, all_text_features = gather_features( + image_features, text_features, + self.local_loss, self.gather_with_grad, self.rank, self.world_size, self.use_horovod) + + if self.local_loss: + logits_per_image = logit_scale * image_features @ all_text_features.T + logits_per_text = logit_scale * text_features @ all_image_features.T + else: + logits_per_image = logit_scale * all_image_features @ all_text_features.T + logits_per_text = logits_per_image.T + else: + logits_per_image = logit_scale * image_features @ text_features.T + logits_per_text = logit_scale * text_features @ image_features.T + + return logits_per_image, logits_per_text + + def forward(self, image_features, text_features, logit_scale, output_dict=False): + device = image_features.device + logits_per_image, logits_per_text = self.get_logits(image_features, text_features, logit_scale) + + labels = self.get_ground_truth(device, logits_per_image.shape[0]) + + total_loss = ( + F.cross_entropy(logits_per_image, labels) + + F.cross_entropy(logits_per_text, labels) + ) / 2 + + return {"contrastive_loss": total_loss} if output_dict else total_loss + + +class CoCaLoss(ClipLoss): + def __init__( + self, + caption_loss_weight, + clip_loss_weight, + pad_id=0, # pad_token for open_clip custom tokenizer + local_loss=False, + gather_with_grad=False, + cache_labels=False, + rank=0, + world_size=1, + use_horovod=False, + ): + super().__init__( + local_loss=local_loss, + gather_with_grad=gather_with_grad, + cache_labels=cache_labels, + rank=rank, + world_size=world_size, + use_horovod=use_horovod + ) + + self.clip_loss_weight = clip_loss_weight + self.caption_loss_weight = caption_loss_weight + self.caption_loss = nn.CrossEntropyLoss(ignore_index=pad_id) + + def forward(self, image_features, text_features, logits, labels, logit_scale, output_dict=False): + + clip_loss = torch.tensor(0) + + if self.clip_loss_weight: + clip_loss = super().forward(image_features, text_features, logit_scale) + clip_loss = self.clip_loss_weight * clip_loss + + caption_loss = self.caption_loss( + logits.permute(0, 2, 1), + labels, + ) + caption_loss = caption_loss * self.caption_loss_weight + + if output_dict: + return {"contrastive_loss": clip_loss, "caption_loss": caption_loss} + + return clip_loss, caption_loss + + +class DistillClipLoss(ClipLoss): + + def dist_loss(self, teacher_logits, student_logits): + return -(teacher_logits.softmax(dim=1) * 
student_logits.log_softmax(dim=1)).sum(dim=1).mean(dim=0) + + def forward( + self, + image_features, + text_features, + logit_scale, + dist_image_features, + dist_text_features, + dist_logit_scale, + output_dict=False, + ): + logits_per_image, logits_per_text = \ + self.get_logits(image_features, text_features, logit_scale) + + dist_logits_per_image, dist_logits_per_text = \ + self.get_logits(dist_image_features, dist_text_features, dist_logit_scale) + + labels = self.get_ground_truth(image_features.device, logits_per_image.shape[0]) + + contrastive_loss = ( + F.cross_entropy(logits_per_image, labels) + + F.cross_entropy(logits_per_text, labels) + ) / 2 + + distill_loss = ( + self.dist_loss(dist_logits_per_image, logits_per_image) + + self.dist_loss(dist_logits_per_text, logits_per_text) + ) / 2 + + if output_dict: + return {"contrastive_loss": contrastive_loss, "distill_loss": distill_loss} + + return contrastive_loss, distill_loss diff --git a/ext/open_clip/model.py b/ext/open_clip/model.py new file mode 100644 index 0000000000000000000000000000000000000000..f85b68ba23117cb65d082cf5cd4cf7528bab4619 --- /dev/null +++ b/ext/open_clip/model.py @@ -0,0 +1,473 @@ +""" CLIP Model + +Adapted from https://github.com/openai/CLIP. Originally MIT License, Copyright (c) 2021 OpenAI. +""" +from dataclasses import dataclass +import logging +import math +from typing import Optional, Tuple, Union + +import numpy as np +import torch +import torch.nn.functional as F +from torch import nn +from torch.utils.checkpoint import checkpoint + +from .hf_model import HFTextEncoder +from .modified_resnet import ModifiedResNet +from .timm_model import TimmModel +from .transformer import LayerNormFp32, LayerNorm, QuickGELU, Attention, VisionTransformer, TextTransformer +from .utils import to_2tuple + + +@dataclass +class CLIPVisionCfg: + layers: Union[Tuple[int, int, int, int], int] = 12 + width: int = 768 + head_width: int = 64 + mlp_ratio: float = 4.0 + patch_size: int = 16 + image_size: Union[Tuple[int, int], int] = 224 + + ls_init_value: Optional[float] = None # layer scale initial value + patch_dropout: float = 0. # what fraction of patches to dropout during training (0 would mean disabled and no patches dropped) - 0.5 to 0.75 recommended in the paper for optimal results + input_patchnorm: bool = False # whether to use dual patchnorm - would only apply the input layernorm on each patch, as post-layernorm already exist in original clip vit design + global_average_pool: bool = False # whether to global average pool the last embedding layer, instead of using CLS token (https://arxiv.org/abs/2205.01580) + attentional_pool: bool = False # whether to use attentional pooler in the last embedding layer + n_queries: int = 256 # n_queries for attentional pooler + attn_pooler_heads: int = 8 # n heads for attentional_pooling + output_tokens: bool = False + + timm_model_name: str = None # a valid model name overrides layers, width, patch_size + timm_model_pretrained: bool = False # use (imagenet) pretrained weights for named model + timm_pool: str = 'avg' # feature pooling for timm model ('abs_attn', 'rot_attn', 'avg', '') + timm_proj: str = 'linear' # linear projection for timm model output ('linear', 'mlp', '') + timm_proj_bias: bool = False # enable bias final projection + timm_drop: float = 0. 
# head dropout + timm_drop_path: Optional[float] = None # backbone stochastic depth + + +@dataclass +class CLIPTextCfg: + context_length: int = 77 + vocab_size: int = 49408 + width: int = 512 + heads: int = 8 + layers: int = 12 + ls_init_value: Optional[float] = None # layer scale initial value + hf_model_name: str = None + hf_tokenizer_name: str = None + hf_model_pretrained: bool = True + proj: str = 'mlp' + pooler_type: str = 'mean_pooler' + embed_cls: bool = False + pad_id: int = 0 + output_tokens: bool = False + + +def get_cast_dtype(precision: str): + cast_dtype = None + if precision == 'bf16': + cast_dtype = torch.bfloat16 + elif precision == 'fp16': + cast_dtype = torch.float16 + return cast_dtype + + +def get_input_dtype(precision: str): + input_dtype = None + if precision in ('bf16', 'pure_bf16'): + input_dtype = torch.bfloat16 + elif precision in ('fp16', 'pure_fp16'): + input_dtype = torch.float16 + return input_dtype + + +def _build_vision_tower( + embed_dim: int, + vision_cfg: CLIPVisionCfg, + quick_gelu: bool = False, + cast_dtype: Optional[torch.dtype] = None +): + if isinstance(vision_cfg, dict): + vision_cfg = CLIPVisionCfg(**vision_cfg) + + # OpenAI models are pretrained w/ QuickGELU but native nn.GELU is both faster and more + # memory efficient in recent PyTorch releases (>= 1.10). + # NOTE: timm models always use native GELU regardless of quick_gelu flag. + act_layer = QuickGELU if quick_gelu else nn.GELU + + if vision_cfg.timm_model_name: + visual = TimmModel( + vision_cfg.timm_model_name, + pretrained=vision_cfg.timm_model_pretrained, + pool=vision_cfg.timm_pool, + proj=vision_cfg.timm_proj, + proj_bias=vision_cfg.timm_proj_bias, + drop=vision_cfg.timm_drop, + drop_path=vision_cfg.timm_drop_path, + patch_drop=vision_cfg.patch_dropout if vision_cfg.patch_dropout > 0 else None, + embed_dim=embed_dim, + image_size=vision_cfg.image_size, + ) + elif isinstance(vision_cfg.layers, (tuple, list)): + vision_heads = vision_cfg.width * 32 // vision_cfg.head_width + visual = ModifiedResNet( + layers=vision_cfg.layers, + output_dim=embed_dim, + heads=vision_heads, + image_size=vision_cfg.image_size, + width=vision_cfg.width, + ) + else: + vision_heads = vision_cfg.width // vision_cfg.head_width + norm_layer = LayerNormFp32 if cast_dtype in (torch.float16, torch.bfloat16) else LayerNorm + visual = VisionTransformer( + image_size=vision_cfg.image_size, + patch_size=vision_cfg.patch_size, + width=vision_cfg.width, + layers=vision_cfg.layers, + heads=vision_heads, + mlp_ratio=vision_cfg.mlp_ratio, + ls_init_value=vision_cfg.ls_init_value, + patch_dropout=vision_cfg.patch_dropout, + input_patchnorm=vision_cfg.input_patchnorm, + global_average_pool=vision_cfg.global_average_pool, + attentional_pool=vision_cfg.attentional_pool, + n_queries=vision_cfg.n_queries, + attn_pooler_heads=vision_cfg.attn_pooler_heads, + output_tokens=vision_cfg.output_tokens, + output_dim=embed_dim, + act_layer=act_layer, + norm_layer=norm_layer, + ) + + return visual + + +def _build_text_tower( + embed_dim: int, + text_cfg: CLIPTextCfg, + quick_gelu: bool = False, + cast_dtype: Optional[torch.dtype] = None, +): + if isinstance(text_cfg, dict): + text_cfg = CLIPTextCfg(**text_cfg) + + if text_cfg.hf_model_name: + text = HFTextEncoder( + text_cfg.hf_model_name, + output_dim=embed_dim, + proj=text_cfg.proj, + pooler_type=text_cfg.pooler_type, + pretrained=text_cfg.hf_model_pretrained, + output_tokens=text_cfg.output_tokens, + ) + else: + act_layer = QuickGELU if quick_gelu else nn.GELU + norm_layer = 
LayerNormFp32 if cast_dtype in (torch.float16, torch.bfloat16) else LayerNorm + + text = TextTransformer( + context_length=text_cfg.context_length, + vocab_size=text_cfg.vocab_size, + width=text_cfg.width, + heads=text_cfg.heads, + layers=text_cfg.layers, + ls_init_value=text_cfg.ls_init_value, + output_dim=embed_dim, + embed_cls=text_cfg.embed_cls, + output_tokens=text_cfg.output_tokens, + pad_id=text_cfg.pad_id, + act_layer=act_layer, + norm_layer=norm_layer, + ) + return text + + +class CLIP(nn.Module): + output_dict: torch.jit.Final[bool] + + def __init__( + self, + embed_dim: int, + vision_cfg: CLIPVisionCfg, + text_cfg: CLIPTextCfg, + quick_gelu: bool = False, + cast_dtype: Optional[torch.dtype] = None, + output_dict: bool = False, + ): + super().__init__() + self.output_dict = output_dict + self.visual = _build_vision_tower(embed_dim, vision_cfg, quick_gelu, cast_dtype) + + text = _build_text_tower(embed_dim, text_cfg, quick_gelu, cast_dtype) + self.transformer = text.transformer + self.context_length = text.context_length + self.vocab_size = text.vocab_size + self.token_embedding = text.token_embedding + self.positional_embedding = text.positional_embedding + self.ln_final = text.ln_final + self.text_projection = text.text_projection + self.register_buffer('attn_mask', text.attn_mask, persistent=False) + + self.logit_scale = nn.Parameter(torch.ones([]) * np.log(1 / 0.07)) + + def lock_image_tower(self, unlocked_groups=0, freeze_bn_stats=False): + # lock image tower as per LiT - https://arxiv.org/abs/2111.07991 + self.visual.lock(unlocked_groups=unlocked_groups, freeze_bn_stats=freeze_bn_stats) + + @torch.jit.ignore + def set_grad_checkpointing(self, enable=True): + self.visual.set_grad_checkpointing(enable) + self.transformer.grad_checkpointing = enable + + def encode_image(self, image, normalize: bool = False): + features = self.visual(image) + return F.normalize(features, dim=-1) if normalize else features + + def encode_text(self, text, normalize: bool = False): + cast_dtype = self.transformer.get_cast_dtype() + + x = self.token_embedding(text).to(cast_dtype) # [batch_size, n_ctx, d_model] + + x = x + self.positional_embedding.to(cast_dtype) + x = x.permute(1, 0, 2) # NLD -> LND + x = self.transformer(x, attn_mask=self.attn_mask) + x = x.permute(1, 0, 2) # LND -> NLD + x = self.ln_final(x) # [batch_size, n_ctx, transformer.width] + # take features from the eot embedding (eot_token is the highest number in each sequence) + x = x[torch.arange(x.shape[0]), text.argmax(dim=-1)] @ self.text_projection + return F.normalize(x, dim=-1) if normalize else x + + def forward( + self, + image: Optional[torch.Tensor] = None, + text: Optional[torch.Tensor] = None, + ): + image_features = self.encode_image(image, normalize=True) if image is not None else None + text_features = self.encode_text(text, normalize=True) if text is not None else None + if self.output_dict: + return { + "image_features": image_features, + "text_features": text_features, + "logit_scale": self.logit_scale.exp() + } + return image_features, text_features, self.logit_scale.exp() + + +class CustomTextCLIP(nn.Module): + output_dict: torch.jit.Final[bool] + + def __init__( + self, + embed_dim: int, + vision_cfg: CLIPVisionCfg, + text_cfg: CLIPTextCfg, + quick_gelu: bool = False, + cast_dtype: Optional[torch.dtype] = None, + output_dict: bool = False, + ): + super().__init__() + self.output_dict = output_dict + self.visual = _build_vision_tower(embed_dim, vision_cfg, quick_gelu, cast_dtype) + self.text = 
_build_text_tower(embed_dim, text_cfg, quick_gelu, cast_dtype) + self.context_length = self.text.context_length + self.vocab_size = self.text.vocab_size + self.logit_scale = nn.Parameter(torch.ones([]) * np.log(1 / 0.07)) + + def lock_image_tower(self, unlocked_groups=0, freeze_bn_stats=False): + # lock image tower as per LiT - https://arxiv.org/abs/2111.07991 + self.visual.lock(unlocked_groups=unlocked_groups, freeze_bn_stats=freeze_bn_stats) + + def lock_text_tower(self, unlocked_layers: int = 0, freeze_layer_norm: bool = True): + self.text.lock(unlocked_layers, freeze_layer_norm) + + @torch.jit.ignore + def set_grad_checkpointing(self, enable=True): + self.visual.set_grad_checkpointing(enable) + self.text.set_grad_checkpointing(enable) + + def encode_image(self, image, normalize: bool = False): + features = self.visual(image) + return F.normalize(features, dim=-1) if normalize else features + + def encode_text(self, text, normalize: bool = False): + features = self.text(text) + return F.normalize(features, dim=-1) if normalize else features + + def forward( + self, + image: Optional[torch.Tensor] = None, + text: Optional[torch.Tensor] = None, + ): + image_features = self.encode_image(image, normalize=True) if image is not None else None + text_features = self.encode_text(text, normalize=True) if text is not None else None + if self.output_dict: + return { + "image_features": image_features, + "text_features": text_features, + "logit_scale": self.logit_scale.exp() + } + return image_features, text_features, self.logit_scale.exp() + + +def convert_weights_to_lp(model: nn.Module, dtype=torch.float16): + """Convert applicable model parameters to low-precision (bf16 or fp16)""" + + def _convert_weights(l): + if isinstance(l, (nn.Conv1d, nn.Conv2d, nn.Linear)): + l.weight.data = l.weight.data.to(dtype) + if l.bias is not None: + l.bias.data = l.bias.data.to(dtype) + + if isinstance(l, (nn.MultiheadAttention, Attention)): + for attr in [*[f"{s}_proj_weight" for s in ["in", "q", "k", "v"]], "in_proj_bias", "bias_k", "bias_v"]: + tensor = getattr(l, attr) + if tensor is not None: + tensor.data = tensor.data.to(dtype) + + if isinstance(l, (CLIP, TextTransformer)): + # convert text nn.Parameter projections + attr = getattr(l, "text_projection", None) + if attr is not None: + attr.data = attr.data.to(dtype) + + if isinstance(l, VisionTransformer): + # convert vision nn.Parameter projections + attr = getattr(l, "proj", None) + if attr is not None: + attr.data = attr.data.to(dtype) + + model.apply(_convert_weights) + + +convert_weights_to_fp16 = convert_weights_to_lp # backwards compat + + +# used to maintain checkpoint compatibility +def convert_to_custom_text_state_dict(state_dict: dict): + if 'text_projection' in state_dict: + # old format state_dict, move text tower -> .text + new_state_dict = {} + for k, v in state_dict.items(): + if any(k.startswith(p) for p in ( + 'text_projection', + 'positional_embedding', + 'token_embedding', + 'transformer', + 'ln_final', + )): + k = 'text.' 
+ k + new_state_dict[k] = v + return new_state_dict + return state_dict + + +def build_model_from_openai_state_dict( + state_dict: dict, + quick_gelu=True, + cast_dtype=torch.float16, +): + vit = "visual.proj" in state_dict + + if vit: + vision_width = state_dict["visual.conv1.weight"].shape[0] + vision_layers = len( + [k for k in state_dict.keys() if k.startswith("visual.") and k.endswith(".attn.in_proj_weight")]) + vision_patch_size = state_dict["visual.conv1.weight"].shape[-1] + grid_size = round((state_dict["visual.positional_embedding"].shape[0] - 1) ** 0.5) + image_size = vision_patch_size * grid_size + else: + counts: list = [ + len(set(k.split(".")[2] for k in state_dict if k.startswith(f"visual.layer{b}"))) for b in [1, 2, 3, 4]] + vision_layers = tuple(counts) + vision_width = state_dict["visual.layer1.0.conv1.weight"].shape[0] + output_width = round((state_dict["visual.attnpool.positional_embedding"].shape[0] - 1) ** 0.5) + vision_patch_size = None + assert output_width ** 2 + 1 == state_dict["visual.attnpool.positional_embedding"].shape[0] + image_size = output_width * 32 + + embed_dim = state_dict["text_projection"].shape[1] + context_length = state_dict["positional_embedding"].shape[0] + vocab_size = state_dict["token_embedding.weight"].shape[0] + transformer_width = state_dict["ln_final.weight"].shape[0] + transformer_heads = transformer_width // 64 + transformer_layers = len(set(k.split(".")[2] for k in state_dict if k.startswith(f"transformer.resblocks"))) + + vision_cfg = CLIPVisionCfg( + layers=vision_layers, + width=vision_width, + patch_size=vision_patch_size, + image_size=image_size, + ) + text_cfg = CLIPTextCfg( + context_length=context_length, + vocab_size=vocab_size, + width=transformer_width, + heads=transformer_heads, + layers=transformer_layers, + ) + model = CLIP( + embed_dim, + vision_cfg=vision_cfg, + text_cfg=text_cfg, + quick_gelu=quick_gelu, # OpenAI models were trained with QuickGELU + cast_dtype=cast_dtype, + ) + + for key in ["input_resolution", "context_length", "vocab_size"]: + state_dict.pop(key, None) + + convert_weights_to_fp16(model) # OpenAI state dicts are partially converted to float16 + model.load_state_dict(state_dict) + return model.eval() + + +def trace_model(model, batch_size=256, device=torch.device('cpu')): + model.eval() + image_size = model.visual.image_size + example_images = torch.ones((batch_size, 3, image_size, image_size), device=device) + example_text = torch.zeros((batch_size, model.context_length), dtype=torch.int, device=device) + model = torch.jit.trace_module( + model, + inputs=dict( + forward=(example_images, example_text), + encode_text=(example_text,), + encode_image=(example_images,) + )) + model.visual.image_size = image_size + return model + + +def resize_pos_embed(state_dict, model, interpolation: str = 'bicubic', antialias: bool = True): + # Rescale the grid of position embeddings when loading from state_dict + old_pos_embed = state_dict.get('visual.positional_embedding', None) + if old_pos_embed is None or not hasattr(model.visual, 'grid_size'): + return + grid_size = to_2tuple(model.visual.grid_size) + extra_tokens = 1 # FIXME detect different token configs (ie no class token, or more) + new_seq_len = grid_size[0] * grid_size[1] + extra_tokens + if new_seq_len == old_pos_embed.shape[0]: + return + + if extra_tokens: + pos_emb_tok, pos_emb_img = old_pos_embed[:extra_tokens], old_pos_embed[extra_tokens:] + else: + pos_emb_tok, pos_emb_img = None, old_pos_embed + old_grid_size = 
to_2tuple(int(math.sqrt(len(pos_emb_img)))) + + logging.info('Resizing position embedding grid-size from %s to %s', old_grid_size, grid_size) + pos_emb_img = pos_emb_img.reshape(1, old_grid_size[0], old_grid_size[1], -1).permute(0, 3, 1, 2) + pos_emb_img = F.interpolate( + pos_emb_img, + size=grid_size, + mode=interpolation, + antialias=antialias, + align_corners=False, + ) + pos_emb_img = pos_emb_img.permute(0, 2, 3, 1).reshape(1, grid_size[0] * grid_size[1], -1)[0] + if pos_emb_tok is not None: + new_pos_embed = torch.cat([pos_emb_tok, pos_emb_img], dim=0) + else: + new_pos_embed = pos_emb_img + state_dict['visual.positional_embedding'] = new_pos_embed diff --git a/ext/open_clip/model_configs/EVA01-g-14-plus.json b/ext/open_clip/model_configs/EVA01-g-14-plus.json new file mode 100644 index 0000000000000000000000000000000000000000..73f46a71e664fce987218b8eb48903e7bd895f41 --- /dev/null +++ b/ext/open_clip/model_configs/EVA01-g-14-plus.json @@ -0,0 +1,18 @@ +{ + "embed_dim": 1024, + "vision_cfg": { + "image_size": 224, + "timm_model_name": "eva_giant_patch14_224", + "timm_model_pretrained": false, + "timm_pool": "token", + "timm_proj": null + }, + "text_cfg": { + "context_length": 77, + "vocab_size": 49408, + "width": 1024, + "heads": 16, + "layers": 24 + }, + "custom_text": true +} \ No newline at end of file diff --git a/ext/open_clip/model_configs/EVA01-g-14.json b/ext/open_clip/model_configs/EVA01-g-14.json new file mode 100644 index 0000000000000000000000000000000000000000..9d0e80f290d9491b7c46fafd576201b1258165aa --- /dev/null +++ b/ext/open_clip/model_configs/EVA01-g-14.json @@ -0,0 +1,18 @@ +{ + "embed_dim": 1024, + "vision_cfg": { + "image_size": 224, + "timm_model_name": "eva_giant_patch14_224", + "timm_model_pretrained": false, + "timm_pool": "token", + "timm_proj": null + }, + "text_cfg": { + "context_length": 77, + "vocab_size": 49408, + "width": 768, + "heads": 12, + "layers": 12 + }, + "custom_text": true +} \ No newline at end of file diff --git a/ext/open_clip/model_configs/EVA02-B-16.json b/ext/open_clip/model_configs/EVA02-B-16.json new file mode 100644 index 0000000000000000000000000000000000000000..3f92357287e1f6600da1e7f391cb6370d7f66de4 --- /dev/null +++ b/ext/open_clip/model_configs/EVA02-B-16.json @@ -0,0 +1,18 @@ +{ + "embed_dim": 512, + "vision_cfg": { + "image_size": 224, + "timm_model_name": "eva02_base_patch16_clip_224", + "timm_model_pretrained": false, + "timm_pool": "token", + "timm_proj": null + }, + "text_cfg": { + "context_length": 77, + "vocab_size": 49408, + "width": 512, + "heads": 8, + "layers": 12 + }, + "custom_text": true +} \ No newline at end of file diff --git a/ext/open_clip/model_configs/EVA02-E-14-plus.json b/ext/open_clip/model_configs/EVA02-E-14-plus.json new file mode 100644 index 0000000000000000000000000000000000000000..e250c2a404c86ff168c54cfcf71bc2492be1b74c --- /dev/null +++ b/ext/open_clip/model_configs/EVA02-E-14-plus.json @@ -0,0 +1,18 @@ +{ + "embed_dim": 1024, + "vision_cfg": { + "image_size": 224, + "timm_model_name": "eva02_enormous_patch14_clip_224", + "timm_model_pretrained": false, + "timm_pool": "token", + "timm_proj": null + }, + "text_cfg": { + "context_length": 77, + "vocab_size": 49408, + "width": 1280, + "heads": 20, + "layers": 32 + }, + "custom_text": true +} \ No newline at end of file diff --git a/ext/open_clip/model_configs/EVA02-E-14.json b/ext/open_clip/model_configs/EVA02-E-14.json new file mode 100644 index 0000000000000000000000000000000000000000..4b6648e25092b151a9095e0a66956c7ebf835b16 --- /dev/null +++ 
b/ext/open_clip/model_configs/EVA02-E-14.json @@ -0,0 +1,18 @@ +{ + "embed_dim": 1024, + "vision_cfg": { + "image_size": 224, + "timm_model_name": "eva02_enormous_patch14_clip_224", + "timm_model_pretrained": false, + "timm_pool": "token", + "timm_proj": null + }, + "text_cfg": { + "context_length": 77, + "vocab_size": 49408, + "width": 1024, + "heads": 16, + "layers": 24 + }, + "custom_text": true +} \ No newline at end of file diff --git a/ext/open_clip/model_configs/EVA02-L-14-336.json b/ext/open_clip/model_configs/EVA02-L-14-336.json new file mode 100644 index 0000000000000000000000000000000000000000..2bb07f3c082fd88c4e86131b272163aaacfaef9e --- /dev/null +++ b/ext/open_clip/model_configs/EVA02-L-14-336.json @@ -0,0 +1,18 @@ +{ + "embed_dim": 768, + "vision_cfg": { + "image_size": 336, + "timm_model_name": "eva02_large_patch14_clip_336", + "timm_model_pretrained": false, + "timm_pool": "token", + "timm_proj": null + }, + "text_cfg": { + "context_length": 77, + "vocab_size": 49408, + "width": 768, + "heads": 12, + "layers": 12 + }, + "custom_text": true +} \ No newline at end of file diff --git a/ext/open_clip/model_configs/EVA02-L-14.json b/ext/open_clip/model_configs/EVA02-L-14.json new file mode 100644 index 0000000000000000000000000000000000000000..b4c7f377bc543aa92a145358f2630a58ae9be989 --- /dev/null +++ b/ext/open_clip/model_configs/EVA02-L-14.json @@ -0,0 +1,18 @@ +{ + "embed_dim": 768, + "vision_cfg": { + "image_size": 224, + "timm_model_name": "eva02_large_patch14_clip_224", + "timm_model_pretrained": false, + "timm_pool": "token", + "timm_proj": null + }, + "text_cfg": { + "context_length": 77, + "vocab_size": 49408, + "width": 768, + "heads": 12, + "layers": 12 + }, + "custom_text": true +} \ No newline at end of file diff --git a/ext/open_clip/model_configs/RN101-quickgelu.json b/ext/open_clip/model_configs/RN101-quickgelu.json new file mode 100644 index 0000000000000000000000000000000000000000..d0db2c161d13138788c4609d373b023b8454d624 --- /dev/null +++ b/ext/open_clip/model_configs/RN101-quickgelu.json @@ -0,0 +1,22 @@ +{ + "embed_dim": 512, + "quick_gelu": true, + "vision_cfg": { + "image_size": 224, + "layers": [ + 3, + 4, + 23, + 3 + ], + "width": 64, + "patch_size": null + }, + "text_cfg": { + "context_length": 77, + "vocab_size": 49408, + "width": 512, + "heads": 8, + "layers": 12 + } +} \ No newline at end of file diff --git a/ext/open_clip/model_configs/RN101.json b/ext/open_clip/model_configs/RN101.json new file mode 100644 index 0000000000000000000000000000000000000000..b88b4d3acbaa701c614ab0ea65fc88fcfe289c32 --- /dev/null +++ b/ext/open_clip/model_configs/RN101.json @@ -0,0 +1,21 @@ +{ + "embed_dim": 512, + "vision_cfg": { + "image_size": 224, + "layers": [ + 3, + 4, + 23, + 3 + ], + "width": 64, + "patch_size": null + }, + "text_cfg": { + "context_length": 77, + "vocab_size": 49408, + "width": 512, + "heads": 8, + "layers": 12 + } +} \ No newline at end of file diff --git a/ext/open_clip/model_configs/RN50-quickgelu.json b/ext/open_clip/model_configs/RN50-quickgelu.json new file mode 100644 index 0000000000000000000000000000000000000000..8c2f91260cdeb043434dc1e893cce81d4ce7f0d1 --- /dev/null +++ b/ext/open_clip/model_configs/RN50-quickgelu.json @@ -0,0 +1,22 @@ +{ + "embed_dim": 1024, + "quick_gelu": true, + "vision_cfg": { + "image_size": 224, + "layers": [ + 3, + 4, + 6, + 3 + ], + "width": 64, + "patch_size": null + }, + "text_cfg": { + "context_length": 77, + "vocab_size": 49408, + "width": 512, + "heads": 8, + "layers": 12 + } +} diff --git 
a/ext/open_clip/model_configs/RN50.json b/ext/open_clip/model_configs/RN50.json new file mode 100644 index 0000000000000000000000000000000000000000..33aa884d54fee0076c33676831e49d5e1ffcb8f2 --- /dev/null +++ b/ext/open_clip/model_configs/RN50.json @@ -0,0 +1,21 @@ +{ + "embed_dim": 1024, + "vision_cfg": { + "image_size": 224, + "layers": [ + 3, + 4, + 6, + 3 + ], + "width": 64, + "patch_size": null + }, + "text_cfg": { + "context_length": 77, + "vocab_size": 49408, + "width": 512, + "heads": 8, + "layers": 12 + } +} \ No newline at end of file diff --git a/ext/open_clip/model_configs/RN50x16.json b/ext/open_clip/model_configs/RN50x16.json new file mode 100644 index 0000000000000000000000000000000000000000..3161e1a2c9a839161e652a4d729c2cdc971161db --- /dev/null +++ b/ext/open_clip/model_configs/RN50x16.json @@ -0,0 +1,21 @@ +{ + "embed_dim": 768, + "vision_cfg": { + "image_size": 384, + "layers": [ + 6, + 8, + 18, + 8 + ], + "width": 96, + "patch_size": null + }, + "text_cfg": { + "context_length": 77, + "vocab_size": 49408, + "width": 768, + "heads": 12, + "layers": 12 + } +} \ No newline at end of file diff --git a/ext/open_clip/model_configs/RN50x4.json b/ext/open_clip/model_configs/RN50x4.json new file mode 100644 index 0000000000000000000000000000000000000000..e155237f8ce1026aaaeecc80751eabe6f329f0bb --- /dev/null +++ b/ext/open_clip/model_configs/RN50x4.json @@ -0,0 +1,21 @@ +{ + "embed_dim": 640, + "vision_cfg": { + "image_size": 288, + "layers": [ + 4, + 6, + 10, + 6 + ], + "width": 80, + "patch_size": null + }, + "text_cfg": { + "context_length": 77, + "vocab_size": 49408, + "width": 640, + "heads": 10, + "layers": 12 + } +} \ No newline at end of file diff --git a/ext/open_clip/model_configs/RN50x64.json b/ext/open_clip/model_configs/RN50x64.json new file mode 100644 index 0000000000000000000000000000000000000000..f5aaa2ee3de21ddb03cbd12766a3419bf34898c7 --- /dev/null +++ b/ext/open_clip/model_configs/RN50x64.json @@ -0,0 +1,21 @@ +{ + "embed_dim": 1024, + "vision_cfg": { + "image_size": 448, + "layers": [ + 3, + 15, + 36, + 10 + ], + "width": 128, + "patch_size": null + }, + "text_cfg": { + "context_length": 77, + "vocab_size": 49408, + "width": 1024, + "heads": 16, + "layers": 12 + } +} \ No newline at end of file diff --git a/ext/open_clip/model_configs/ViT-B-16-plus-240.json b/ext/open_clip/model_configs/ViT-B-16-plus-240.json new file mode 100644 index 0000000000000000000000000000000000000000..5bbd12bcd01f64d6d0a0aa8316b129327a0d169a --- /dev/null +++ b/ext/open_clip/model_configs/ViT-B-16-plus-240.json @@ -0,0 +1,16 @@ +{ + "embed_dim": 640, + "vision_cfg": { + "image_size": 240, + "layers": 12, + "width": 896, + "patch_size": 16 + }, + "text_cfg": { + "context_length": 77, + "vocab_size": 49408, + "width": 640, + "heads": 10, + "layers": 12 + } +} \ No newline at end of file diff --git a/ext/open_clip/model_configs/ViT-B-16-plus.json b/ext/open_clip/model_configs/ViT-B-16-plus.json new file mode 100644 index 0000000000000000000000000000000000000000..5dc1e09baccef2b15055c1bffeb9903e760101c6 --- /dev/null +++ b/ext/open_clip/model_configs/ViT-B-16-plus.json @@ -0,0 +1,16 @@ +{ + "embed_dim": 640, + "vision_cfg": { + "image_size": 224, + "layers": 12, + "width": 896, + "patch_size": 16 + }, + "text_cfg": { + "context_length": 77, + "vocab_size": 49408, + "width": 640, + "heads": 10, + "layers": 12 + } +} \ No newline at end of file diff --git a/ext/open_clip/model_configs/ViT-B-16.json b/ext/open_clip/model_configs/ViT-B-16.json new file mode 100644 index 
0000000000000000000000000000000000000000..395eea77ec3907c0611531aba63459b193e67b9c --- /dev/null +++ b/ext/open_clip/model_configs/ViT-B-16.json @@ -0,0 +1,16 @@ +{ + "embed_dim": 512, + "vision_cfg": { + "image_size": 224, + "layers": 12, + "width": 768, + "patch_size": 16 + }, + "text_cfg": { + "context_length": 77, + "vocab_size": 49408, + "width": 512, + "heads": 8, + "layers": 12 + } +} \ No newline at end of file diff --git a/ext/open_clip/model_configs/ViT-B-32-plus-256.json b/ext/open_clip/model_configs/ViT-B-32-plus-256.json new file mode 100644 index 0000000000000000000000000000000000000000..2f09c857de9a4c01ae51297a7e2451984879f9de --- /dev/null +++ b/ext/open_clip/model_configs/ViT-B-32-plus-256.json @@ -0,0 +1,16 @@ +{ + "embed_dim": 640, + "vision_cfg": { + "image_size": 256, + "layers": 12, + "width": 896, + "patch_size": 32 + }, + "text_cfg": { + "context_length": 77, + "vocab_size": 49408, + "width": 640, + "heads": 10, + "layers": 12 + } +} \ No newline at end of file diff --git a/ext/open_clip/model_configs/ViT-B-32-quickgelu.json b/ext/open_clip/model_configs/ViT-B-32-quickgelu.json new file mode 100644 index 0000000000000000000000000000000000000000..ce6bd923593293ed50dfcfb28b73ca7403bcf3c5 --- /dev/null +++ b/ext/open_clip/model_configs/ViT-B-32-quickgelu.json @@ -0,0 +1,17 @@ +{ + "embed_dim": 512, + "quick_gelu": true, + "vision_cfg": { + "image_size": 224, + "layers": 12, + "width": 768, + "patch_size": 32 + }, + "text_cfg": { + "context_length": 77, + "vocab_size": 49408, + "width": 512, + "heads": 8, + "layers": 12 + } +} \ No newline at end of file diff --git a/ext/open_clip/model_configs/ViT-B-32.json b/ext/open_clip/model_configs/ViT-B-32.json new file mode 100644 index 0000000000000000000000000000000000000000..07c8e28eb06fa1813ba932fe4eec668262d1c47f --- /dev/null +++ b/ext/open_clip/model_configs/ViT-B-32.json @@ -0,0 +1,16 @@ +{ + "embed_dim": 512, + "vision_cfg": { + "image_size": 224, + "layers": 12, + "width": 768, + "patch_size": 32 + }, + "text_cfg": { + "context_length": 77, + "vocab_size": 49408, + "width": 512, + "heads": 8, + "layers": 12 + } +} \ No newline at end of file diff --git a/ext/open_clip/model_configs/ViT-H-14.json b/ext/open_clip/model_configs/ViT-H-14.json new file mode 100644 index 0000000000000000000000000000000000000000..3e3a7e934e7f02e41f4829996c4950e05f015a74 --- /dev/null +++ b/ext/open_clip/model_configs/ViT-H-14.json @@ -0,0 +1,17 @@ +{ + "embed_dim": 1024, + "vision_cfg": { + "image_size": 224, + "layers": 32, + "width": 1280, + "head_width": 80, + "patch_size": 14 + }, + "text_cfg": { + "context_length": 77, + "vocab_size": 49408, + "width": 1024, + "heads": 16, + "layers": 24 + } +} \ No newline at end of file diff --git a/ext/open_clip/model_configs/ViT-H-16.json b/ext/open_clip/model_configs/ViT-H-16.json new file mode 100644 index 0000000000000000000000000000000000000000..588485455fdf8193ec16474450b94e31c91ea93c --- /dev/null +++ b/ext/open_clip/model_configs/ViT-H-16.json @@ -0,0 +1,17 @@ +{ + "embed_dim": 1024, + "vision_cfg": { + "image_size": 224, + "layers": 32, + "width": 1280, + "head_width": 80, + "patch_size": 16 + }, + "text_cfg": { + "context_length": 77, + "vocab_size": 49408, + "width": 1024, + "heads": 16, + "layers": 24 + } +} \ No newline at end of file diff --git a/ext/open_clip/model_configs/ViT-L-14-280.json b/ext/open_clip/model_configs/ViT-L-14-280.json new file mode 100644 index 0000000000000000000000000000000000000000..2262deaefa82792d35d73c0d7c8e620525092581 --- /dev/null +++ 
b/ext/open_clip/model_configs/ViT-L-14-280.json @@ -0,0 +1,16 @@ +{ + "embed_dim": 768, + "vision_cfg": { + "image_size": 280, + "layers": 24, + "width": 1024, + "patch_size": 14 + }, + "text_cfg": { + "context_length": 77, + "vocab_size": 49408, + "width": 768, + "heads": 12, + "layers": 12 + } +} \ No newline at end of file diff --git a/ext/open_clip/model_configs/ViT-L-14-336.json b/ext/open_clip/model_configs/ViT-L-14-336.json new file mode 100644 index 0000000000000000000000000000000000000000..8d1f74c2639c3a3705df9865b9c08215675ddc97 --- /dev/null +++ b/ext/open_clip/model_configs/ViT-L-14-336.json @@ -0,0 +1,16 @@ +{ + "embed_dim": 768, + "vision_cfg": { + "image_size": 336, + "layers": 24, + "width": 1024, + "patch_size": 14 + }, + "text_cfg": { + "context_length": 77, + "vocab_size": 49408, + "width": 768, + "heads": 12, + "layers": 12 + } +} \ No newline at end of file diff --git a/ext/open_clip/model_configs/ViT-L-14.json b/ext/open_clip/model_configs/ViT-L-14.json new file mode 100644 index 0000000000000000000000000000000000000000..d4a4bbb1dd4ed4edb317d3ace4f3ad13b211c241 --- /dev/null +++ b/ext/open_clip/model_configs/ViT-L-14.json @@ -0,0 +1,16 @@ +{ + "embed_dim": 768, + "vision_cfg": { + "image_size": 224, + "layers": 24, + "width": 1024, + "patch_size": 14 + }, + "text_cfg": { + "context_length": 77, + "vocab_size": 49408, + "width": 768, + "heads": 12, + "layers": 12 + } +} \ No newline at end of file diff --git a/ext/open_clip/model_configs/ViT-L-16-320.json b/ext/open_clip/model_configs/ViT-L-16-320.json new file mode 100644 index 0000000000000000000000000000000000000000..fc2d13ca9ec7f0b56a886ddaf66c4a7ba7a442ba --- /dev/null +++ b/ext/open_clip/model_configs/ViT-L-16-320.json @@ -0,0 +1,16 @@ +{ + "embed_dim": 768, + "vision_cfg": { + "image_size": 320, + "layers": 24, + "width": 1024, + "patch_size": 16 + }, + "text_cfg": { + "context_length": 77, + "vocab_size": 49408, + "width": 768, + "heads": 12, + "layers": 12 + } +} \ No newline at end of file diff --git a/ext/open_clip/model_configs/ViT-L-16.json b/ext/open_clip/model_configs/ViT-L-16.json new file mode 100644 index 0000000000000000000000000000000000000000..82a1cedfa290adacbbdc02bc5d589734c22d41d3 --- /dev/null +++ b/ext/open_clip/model_configs/ViT-L-16.json @@ -0,0 +1,16 @@ +{ + "embed_dim": 768, + "vision_cfg": { + "image_size": 224, + "layers": 24, + "width": 1024, + "patch_size": 16 + }, + "text_cfg": { + "context_length": 77, + "vocab_size": 49408, + "width": 768, + "heads": 12, + "layers": 12 + } +} \ No newline at end of file diff --git a/ext/open_clip/model_configs/ViT-M-16-alt.json b/ext/open_clip/model_configs/ViT-M-16-alt.json new file mode 100644 index 0000000000000000000000000000000000000000..1a317aad8e02d9c26d2decc7cc49a18dfdf9e0d8 --- /dev/null +++ b/ext/open_clip/model_configs/ViT-M-16-alt.json @@ -0,0 +1,17 @@ +{ + "embed_dim": 384, + "vision_cfg": { + "image_size": 224, + "layers": 12, + "width": 512, + "patch_size": 16, + "ls_init_value": 1e-4 + }, + "text_cfg": { + "context_length": 77, + "vocab_size": 49408, + "width": 384, + "heads": 6, + "layers": 12 + } +} \ No newline at end of file diff --git a/ext/open_clip/model_configs/ViT-M-16.json b/ext/open_clip/model_configs/ViT-M-16.json new file mode 100644 index 0000000000000000000000000000000000000000..f2f3225a46e09237730a151d161f70c86b985172 --- /dev/null +++ b/ext/open_clip/model_configs/ViT-M-16.json @@ -0,0 +1,16 @@ +{ + "embed_dim": 512, + "vision_cfg": { + "image_size": 224, + "layers": 12, + "width": 512, + "patch_size": 16 + }, + 
"text_cfg": { + "context_length": 77, + "vocab_size": 49408, + "width": 512, + "heads": 8, + "layers": 12 + } +} \ No newline at end of file diff --git a/ext/open_clip/model_configs/ViT-M-32-alt.json b/ext/open_clip/model_configs/ViT-M-32-alt.json new file mode 100644 index 0000000000000000000000000000000000000000..fd222aeac0f582ef6a1a33f1b3fec70a5b386ac0 --- /dev/null +++ b/ext/open_clip/model_configs/ViT-M-32-alt.json @@ -0,0 +1,16 @@ +{ + "embed_dim": 384, + "vision_cfg": { + "image_size": 224, + "layers": 12, + "width": 512, + "patch_size": 32 + }, + "text_cfg": { + "context_length": 77, + "vocab_size": 49408, + "width": 384, + "heads": 6, + "layers": 12 + } +} \ No newline at end of file diff --git a/ext/open_clip/model_configs/ViT-M-32.json b/ext/open_clip/model_configs/ViT-M-32.json new file mode 100644 index 0000000000000000000000000000000000000000..4f718642821035d9776d1e006817d65ede074366 --- /dev/null +++ b/ext/open_clip/model_configs/ViT-M-32.json @@ -0,0 +1,16 @@ +{ + "embed_dim": 512, + "vision_cfg": { + "image_size": 224, + "layers": 12, + "width": 512, + "patch_size": 32 + }, + "text_cfg": { + "context_length": 77, + "vocab_size": 49408, + "width": 512, + "heads": 8, + "layers": 12 + } +} \ No newline at end of file diff --git a/ext/open_clip/model_configs/ViT-S-16-alt.json b/ext/open_clip/model_configs/ViT-S-16-alt.json new file mode 100644 index 0000000000000000000000000000000000000000..a8c056555e4da3ba0d1475a61fc316362ecce76f --- /dev/null +++ b/ext/open_clip/model_configs/ViT-S-16-alt.json @@ -0,0 +1,16 @@ +{ + "embed_dim": 256, + "vision_cfg": { + "image_size": 224, + "layers": 12, + "width": 384, + "patch_size": 16 + }, + "text_cfg": { + "context_length": 77, + "vocab_size": 49408, + "width": 256, + "heads": 4, + "layers": 10 + } +} \ No newline at end of file diff --git a/ext/open_clip/model_configs/ViT-S-16.json b/ext/open_clip/model_configs/ViT-S-16.json new file mode 100644 index 0000000000000000000000000000000000000000..1d8504e59658803f3093e5b05de45f30a09b8185 --- /dev/null +++ b/ext/open_clip/model_configs/ViT-S-16.json @@ -0,0 +1,16 @@ +{ + "embed_dim": 384, + "vision_cfg": { + "image_size": 224, + "layers": 12, + "width": 384, + "patch_size": 16 + }, + "text_cfg": { + "context_length": 77, + "vocab_size": 49408, + "width": 384, + "heads": 6, + "layers": 12 + } +} \ No newline at end of file diff --git a/ext/open_clip/model_configs/ViT-S-32-alt.json b/ext/open_clip/model_configs/ViT-S-32-alt.json new file mode 100644 index 0000000000000000000000000000000000000000..e1dfdec9824df09a2010e991ccfa1d9ee2f45807 --- /dev/null +++ b/ext/open_clip/model_configs/ViT-S-32-alt.json @@ -0,0 +1,16 @@ +{ + "embed_dim": 256, + "vision_cfg": { + "image_size": 224, + "layers": 12, + "width": 384, + "patch_size": 32 + }, + "text_cfg": { + "context_length": 77, + "vocab_size": 49408, + "width": 256, + "heads": 4, + "layers": 10 + } +} \ No newline at end of file diff --git a/ext/open_clip/model_configs/ViT-S-32.json b/ext/open_clip/model_configs/ViT-S-32.json new file mode 100644 index 0000000000000000000000000000000000000000..9b8b4191b268de267268cfcb90fc01c6b9df07d8 --- /dev/null +++ b/ext/open_clip/model_configs/ViT-S-32.json @@ -0,0 +1,16 @@ +{ + "embed_dim": 384, + "vision_cfg": { + "image_size": 224, + "layers": 12, + "width": 384, + "patch_size": 32 + }, + "text_cfg": { + "context_length": 77, + "vocab_size": 49408, + "width": 384, + "heads": 6, + "layers": 12 + } +} \ No newline at end of file diff --git a/ext/open_clip/model_configs/ViT-bigG-14.json 
b/ext/open_clip/model_configs/ViT-bigG-14.json new file mode 100644 index 0000000000000000000000000000000000000000..2cfba479a2e8f3737e71ce240732bf3bc743d8b7 --- /dev/null +++ b/ext/open_clip/model_configs/ViT-bigG-14.json @@ -0,0 +1,18 @@ +{ + "embed_dim": 1280, + "vision_cfg": { + "image_size": 224, + "layers": 48, + "width": 1664, + "head_width": 104, + "mlp_ratio": 4.9231, + "patch_size": 14 + }, + "text_cfg": { + "context_length": 77, + "vocab_size": 49408, + "width": 1280, + "heads": 20, + "layers": 32 + } +} \ No newline at end of file diff --git a/ext/open_clip/model_configs/ViT-e-14.json b/ext/open_clip/model_configs/ViT-e-14.json new file mode 100644 index 0000000000000000000000000000000000000000..91a0fe14d25a107fb8ec48dd7faae313fd26ed7b --- /dev/null +++ b/ext/open_clip/model_configs/ViT-e-14.json @@ -0,0 +1,18 @@ +{ + "embed_dim": 1280, + "vision_cfg": { + "image_size": 224, + "layers": 56, + "width": 1792, + "head_width": 112, + "mlp_ratio": 8.5715, + "patch_size": 14 + }, + "text_cfg": { + "context_length": 77, + "vocab_size": 49408, + "width": 1280, + "heads": 20, + "layers": 36 + } +} \ No newline at end of file diff --git a/ext/open_clip/model_configs/ViT-g-14.json b/ext/open_clip/model_configs/ViT-g-14.json new file mode 100644 index 0000000000000000000000000000000000000000..8c4b7325cc75b6112be7107d36ae2cb5762d9091 --- /dev/null +++ b/ext/open_clip/model_configs/ViT-g-14.json @@ -0,0 +1,18 @@ +{ + "embed_dim": 1024, + "vision_cfg": { + "image_size": 224, + "layers": 40, + "width": 1408, + "head_width": 88, + "mlp_ratio": 4.3637, + "patch_size": 14 + }, + "text_cfg": { + "context_length": 77, + "vocab_size": 49408, + "width": 1024, + "heads": 16, + "layers": 24 + } +} \ No newline at end of file diff --git a/ext/open_clip/model_configs/coca_ViT-B-32.json b/ext/open_clip/model_configs/coca_ViT-B-32.json new file mode 100644 index 0000000000000000000000000000000000000000..7e7eb520a6a0096e5602d509ecd6186e278f4725 --- /dev/null +++ b/ext/open_clip/model_configs/coca_ViT-B-32.json @@ -0,0 +1,30 @@ +{ + "embed_dim": 512, + "vision_cfg": { + "image_size": 224, + "layers": 12, + "width": 768, + "patch_size": 32, + "attentional_pool": true, + "attn_pooler_heads": 8, + "output_tokens": true + }, + "text_cfg": { + "context_length": 76, + "vocab_size": 49408, + "width": 512, + "heads": 8, + "layers": 12, + "embed_cls": true, + "output_tokens": true + }, + "multimodal_cfg": { + "context_length": 76, + "vocab_size": 49408, + "width": 512, + "heads": 8, + "layers": 12, + "attn_pooler_heads": 8 + }, + "custom_text": true +} \ No newline at end of file diff --git a/ext/open_clip/model_configs/coca_ViT-L-14.json b/ext/open_clip/model_configs/coca_ViT-L-14.json new file mode 100644 index 0000000000000000000000000000000000000000..3d5ca4ca2338540f06852df5ff35ea6277e64555 --- /dev/null +++ b/ext/open_clip/model_configs/coca_ViT-L-14.json @@ -0,0 +1,30 @@ +{ + "embed_dim": 768, + "vision_cfg": { + "image_size": 224, + "layers": 24, + "width": 1024, + "patch_size": 14, + "attentional_pool": true, + "attn_pooler_heads": 8, + "output_tokens": true + }, + "text_cfg": { + "context_length": 76, + "vocab_size": 49408, + "width": 768, + "heads": 12, + "layers": 12, + "embed_cls": true, + "output_tokens": true + }, + "multimodal_cfg": { + "context_length": 76, + "vocab_size": 49408, + "width": 768, + "heads": 12, + "layers": 12, + "attn_pooler_heads": 12 + }, + "custom_text": true +} diff --git a/ext/open_clip/model_configs/coca_base.json b/ext/open_clip/model_configs/coca_base.json new file mode 100644 
index 0000000000000000000000000000000000000000..cf8c6cecb78a49d7e7140145a0307cbd561077c2 --- /dev/null +++ b/ext/open_clip/model_configs/coca_base.json @@ -0,0 +1,31 @@ +{ + "embed_dim": 512, + "multimodal_cfg": { + "width": 768, + "context_length": 76, + "vocab_size": 64000, + "mlp_ratio": 4, + "layers": 12, + "dim_head": 64, + "heads": 12, + "n_queries": 256, + "attn_pooler_heads": 8 + }, + "vision_cfg": { + "image_size": 288, + "layers": 12, + "width": 768, + "patch_size": 18, + "output_tokens": true + }, + "text_cfg": { + "context_length": 76, + "vocab_size": 64000, + "layers": 12, + "heads": 12, + "width": 768, + "embed_cls": true, + "output_tokens": true + }, + "custom_text": true +} \ No newline at end of file diff --git a/ext/open_clip/model_configs/coca_roberta-ViT-B-32.json b/ext/open_clip/model_configs/coca_roberta-ViT-B-32.json new file mode 100644 index 0000000000000000000000000000000000000000..fb46354b95a17a46d7fcfd9d504e917ee6c1608c --- /dev/null +++ b/ext/open_clip/model_configs/coca_roberta-ViT-B-32.json @@ -0,0 +1,24 @@ +{ + "embed_dim": 512, + "vision_cfg": { + "image_size": 224, + "layers": 12, + "width": 768, + "patch_size": 32, + "output_tokens": true + }, + "text_cfg": { + "hf_model_name": "roberta-base", + "hf_tokenizer_name": "roberta-base", + "proj": "linear", + "width": 768, + "output_tokens": true + }, + "multimodal_cfg": { + "context_length": 76, + "width": 768, + "heads": 8, + "layers": 12 + }, + "custom_text": true +} diff --git a/ext/open_clip/model_configs/convnext_base.json b/ext/open_clip/model_configs/convnext_base.json new file mode 100644 index 0000000000000000000000000000000000000000..bb6dba181d950ea5081155c90d47e72c94816b80 --- /dev/null +++ b/ext/open_clip/model_configs/convnext_base.json @@ -0,0 +1,19 @@ +{ + "embed_dim": 512, + "vision_cfg": { + "timm_model_name": "convnext_base", + "timm_model_pretrained": false, + "timm_pool": "", + "timm_proj": "linear", + "timm_drop": 0.0, + "timm_drop_path": 0.1, + "image_size": 224 + }, + "text_cfg": { + "context_length": 77, + "vocab_size": 49408, + "width": 512, + "heads": 8, + "layers": 12 + } +} \ No newline at end of file diff --git a/ext/open_clip/model_configs/convnext_base_w.json b/ext/open_clip/model_configs/convnext_base_w.json new file mode 100644 index 0000000000000000000000000000000000000000..82ea7ae3659e5514f37ff982f0ab1141dff4bd18 --- /dev/null +++ b/ext/open_clip/model_configs/convnext_base_w.json @@ -0,0 +1,19 @@ +{ + "embed_dim": 640, + "vision_cfg": { + "timm_model_name": "convnext_base", + "timm_model_pretrained": false, + "timm_pool": "", + "timm_proj": "linear", + "timm_drop": 0.0, + "timm_drop_path": 0.1, + "image_size": 256 + }, + "text_cfg": { + "context_length": 77, + "vocab_size": 49408, + "width": 640, + "heads": 10, + "layers": 12 + } +} \ No newline at end of file diff --git a/ext/open_clip/model_configs/convnext_base_w_320.json b/ext/open_clip/model_configs/convnext_base_w_320.json new file mode 100644 index 0000000000000000000000000000000000000000..0a07c4e16abaa4015ecc5f82ec845de16e1f9d88 --- /dev/null +++ b/ext/open_clip/model_configs/convnext_base_w_320.json @@ -0,0 +1,19 @@ +{ + "embed_dim": 640, + "vision_cfg": { + "timm_model_name": "convnext_base", + "timm_model_pretrained": false, + "timm_pool": "", + "timm_proj": "linear", + "timm_drop": 0.0, + "timm_drop_path": 0.1, + "image_size": 320 + }, + "text_cfg": { + "context_length": 77, + "vocab_size": 49408, + "width": 640, + "heads": 10, + "layers": 12 + } +} \ No newline at end of file diff --git 
a/ext/open_clip/model_configs/convnext_large.json b/ext/open_clip/model_configs/convnext_large.json new file mode 100644 index 0000000000000000000000000000000000000000..c4a1fea73dbead71c218a0e74b9b15f9b252e3ef --- /dev/null +++ b/ext/open_clip/model_configs/convnext_large.json @@ -0,0 +1,19 @@ +{ + "embed_dim": 768, + "vision_cfg": { + "timm_model_name": "convnext_large", + "timm_model_pretrained": false, + "timm_pool": "", + "timm_proj": "linear", + "timm_drop": 0.0, + "timm_drop_path": 0.1, + "image_size": 224 + }, + "text_cfg": { + "context_length": 77, + "vocab_size": 49408, + "width": 768, + "heads": 12, + "layers": 12 + } +} \ No newline at end of file diff --git a/ext/open_clip/model_configs/convnext_large_d.json b/ext/open_clip/model_configs/convnext_large_d.json new file mode 100644 index 0000000000000000000000000000000000000000..ae8fed21b58e1a6a411daf8b792ee50f0ab42346 --- /dev/null +++ b/ext/open_clip/model_configs/convnext_large_d.json @@ -0,0 +1,19 @@ +{ + "embed_dim": 768, + "vision_cfg": { + "timm_model_name": "convnext_large", + "timm_model_pretrained": false, + "timm_pool": "", + "timm_proj": "mlp", + "timm_drop": 0.0, + "timm_drop_path": 0.1, + "image_size": 256 + }, + "text_cfg": { + "context_length": 77, + "vocab_size": 49408, + "width": 768, + "heads": 12, + "layers": 16 + } +} \ No newline at end of file diff --git a/ext/open_clip/model_configs/convnext_large_d_320.json b/ext/open_clip/model_configs/convnext_large_d_320.json new file mode 100644 index 0000000000000000000000000000000000000000..54c3df36a6f56ace0b12ada24c13058de96feed8 --- /dev/null +++ b/ext/open_clip/model_configs/convnext_large_d_320.json @@ -0,0 +1,19 @@ +{ + "embed_dim": 768, + "vision_cfg": { + "timm_model_name": "convnext_large", + "timm_model_pretrained": false, + "timm_pool": "", + "timm_proj": "mlp", + "timm_drop": 0.0, + "timm_drop_path": 0.1, + "image_size": 320 + }, + "text_cfg": { + "context_length": 77, + "vocab_size": 49408, + "width": 768, + "heads": 12, + "layers": 16 + } +} \ No newline at end of file diff --git a/ext/open_clip/model_configs/convnext_small.json b/ext/open_clip/model_configs/convnext_small.json new file mode 100644 index 0000000000000000000000000000000000000000..3592c2a5cd21aae8d2544931773cf7603f67ea28 --- /dev/null +++ b/ext/open_clip/model_configs/convnext_small.json @@ -0,0 +1,19 @@ +{ + "embed_dim": 512, + "vision_cfg": { + "timm_model_name": "convnext_small", + "timm_model_pretrained": false, + "timm_pool": "", + "timm_proj": "linear", + "timm_drop": 0.0, + "timm_drop_path": 0.1, + "image_size": 224 + }, + "text_cfg": { + "context_length": 77, + "vocab_size": 49408, + "width": 512, + "heads": 8, + "layers": 12 + } +} \ No newline at end of file diff --git a/ext/open_clip/model_configs/convnext_tiny.json b/ext/open_clip/model_configs/convnext_tiny.json new file mode 100644 index 0000000000000000000000000000000000000000..ad11470f5ec40ffec771096971ce58d3d5b9249b --- /dev/null +++ b/ext/open_clip/model_configs/convnext_tiny.json @@ -0,0 +1,19 @@ +{ + "embed_dim": 1024, + "vision_cfg": { + "timm_model_name": "convnext_tiny", + "timm_model_pretrained": false, + "timm_pool": "", + "timm_proj": "linear", + "timm_drop": 0.0, + "timm_drop_path": 0.1, + "image_size": 224 + }, + "text_cfg": { + "context_length": 77, + "vocab_size": 49408, + "width": 512, + "heads": 8, + "layers": 12 + } +} \ No newline at end of file diff --git a/ext/open_clip/model_configs/convnext_xlarge.json b/ext/open_clip/model_configs/convnext_xlarge.json new file mode 100644 index 
0000000000000000000000000000000000000000..2a909965932eef994177c829fefc2bdc1c219b3f --- /dev/null +++ b/ext/open_clip/model_configs/convnext_xlarge.json @@ -0,0 +1,19 @@ +{ + "embed_dim": 1024, + "vision_cfg": { + "timm_model_name": "convnext_xlarge", + "timm_model_pretrained": false, + "timm_pool": "", + "timm_proj": "linear", + "timm_drop": 0.0, + "timm_drop_path": 0.1, + "image_size": 256 + }, + "text_cfg": { + "context_length": 77, + "vocab_size": 49408, + "width": 1024, + "heads": 16, + "layers": 20 + } +} \ No newline at end of file diff --git a/ext/open_clip/model_configs/convnext_xxlarge.json b/ext/open_clip/model_configs/convnext_xxlarge.json new file mode 100644 index 0000000000000000000000000000000000000000..23a55a681c346d1a315d8a163c1cb6ad495e6a91 --- /dev/null +++ b/ext/open_clip/model_configs/convnext_xxlarge.json @@ -0,0 +1,19 @@ +{ + "embed_dim": 1024, + "vision_cfg": { + "timm_model_name": "convnext_xxlarge", + "timm_model_pretrained": false, + "timm_pool": "", + "timm_proj": "linear", + "timm_drop": 0.0, + "timm_drop_path": 0.1, + "image_size": 256 + }, + "text_cfg": { + "context_length": 77, + "vocab_size": 49408, + "width": 1024, + "heads": 16, + "layers": 24 + } +} \ No newline at end of file diff --git a/ext/open_clip/model_configs/convnext_xxlarge_320.json b/ext/open_clip/model_configs/convnext_xxlarge_320.json new file mode 100644 index 0000000000000000000000000000000000000000..ac5134ca12cbaa97772cde059270d345386a74c7 --- /dev/null +++ b/ext/open_clip/model_configs/convnext_xxlarge_320.json @@ -0,0 +1,19 @@ +{ + "embed_dim": 1024, + "vision_cfg": { + "timm_model_name": "convnext_xxlarge", + "timm_model_pretrained": false, + "timm_pool": "", + "timm_proj": "linear", + "timm_drop": 0.0, + "timm_drop_path": 0.1, + "image_size": 320 + }, + "text_cfg": { + "context_length": 77, + "vocab_size": 49408, + "width": 1024, + "heads": 16, + "layers": 24 + } +} \ No newline at end of file diff --git a/ext/open_clip/model_configs/mt5-base-ViT-B-32.json b/ext/open_clip/model_configs/mt5-base-ViT-B-32.json new file mode 100644 index 0000000000000000000000000000000000000000..58cad89cf0f446bbe15e4e25b1ac43424a828017 --- /dev/null +++ b/ext/open_clip/model_configs/mt5-base-ViT-B-32.json @@ -0,0 +1,15 @@ +{ + "embed_dim": 512, + "vision_cfg": { + "image_size": 224, + "layers": 12, + "width": 768, + "patch_size": 32 + }, + "text_cfg": { + "hf_model_name": "google/mt5-base", + "hf_tokenizer_name": "google/mt5-base", + "proj": "mlp", + "pooler_type": "mean_pooler" + } +} diff --git a/ext/open_clip/model_configs/mt5-xl-ViT-H-14.json b/ext/open_clip/model_configs/mt5-xl-ViT-H-14.json new file mode 100644 index 0000000000000000000000000000000000000000..b432810777ba7269dbb0e89edfe65cdd27e7d255 --- /dev/null +++ b/ext/open_clip/model_configs/mt5-xl-ViT-H-14.json @@ -0,0 +1,16 @@ +{ + "embed_dim": 1024, + "vision_cfg": { + "image_size": 224, + "layers": 32, + "width": 1280, + "head_width": 80, + "patch_size": 14 + }, + "text_cfg": { + "hf_model_name": "google/mt5-xl", + "hf_tokenizer_name": "google/mt5-xl", + "proj": "mlp", + "pooler_type": "mean_pooler" + } +} diff --git a/ext/open_clip/model_configs/roberta-ViT-B-32.json b/ext/open_clip/model_configs/roberta-ViT-B-32.json new file mode 100644 index 0000000000000000000000000000000000000000..ed687d472a73bb2ac96025f355f80437ab14c260 --- /dev/null +++ b/ext/open_clip/model_configs/roberta-ViT-B-32.json @@ -0,0 +1,16 @@ +{ + "embed_dim": 512, + "quick_gelu": true, + "vision_cfg": { + "image_size": 224, + "layers": 12, + "width": 768, + 
"patch_size": 32 + }, + "text_cfg": { + "hf_model_name": "roberta-base", + "hf_tokenizer_name": "roberta-base", + "proj": "mlp", + "pooler_type": "mean_pooler" + } +} diff --git a/ext/open_clip/model_configs/swin_base_patch4_window7_224.json b/ext/open_clip/model_configs/swin_base_patch4_window7_224.json new file mode 100644 index 0000000000000000000000000000000000000000..bd6820f0cf2aa655e0a2723287f4b78895a58e6a --- /dev/null +++ b/ext/open_clip/model_configs/swin_base_patch4_window7_224.json @@ -0,0 +1,17 @@ +{ + "embed_dim": 640, + "vision_cfg": { + "timm_model_name": "swin_base_patch4_window7_224", + "timm_model_pretrained": false, + "timm_pool": "", + "timm_proj": "linear", + "image_size": 224 + }, + "text_cfg": { + "context_length": 77, + "vocab_size": 49408, + "width": 640, + "heads": 10, + "layers": 12 + } +} \ No newline at end of file diff --git a/ext/open_clip/model_configs/vit_medium_patch16_gap_256.json b/ext/open_clip/model_configs/vit_medium_patch16_gap_256.json new file mode 100644 index 0000000000000000000000000000000000000000..8843eaf08cad16c3e7b5f496fd650715c9573f65 --- /dev/null +++ b/ext/open_clip/model_configs/vit_medium_patch16_gap_256.json @@ -0,0 +1,17 @@ +{ + "embed_dim": 512, + "vision_cfg": { + "timm_model_name": "vit_medium_patch16_gap_256", + "timm_model_pretrained": false, + "timm_pool": "", + "timm_proj": "linear", + "image_size": 256 + }, + "text_cfg": { + "context_length": 77, + "vocab_size": 49408, + "width": 512, + "heads": 8, + "layers": 12 + } +} \ No newline at end of file diff --git a/ext/open_clip/model_configs/vit_relpos_medium_patch16_cls_224.json b/ext/open_clip/model_configs/vit_relpos_medium_patch16_cls_224.json new file mode 100644 index 0000000000000000000000000000000000000000..ed217b202d5e6071c5307f4547c97ff4cfe2abd1 --- /dev/null +++ b/ext/open_clip/model_configs/vit_relpos_medium_patch16_cls_224.json @@ -0,0 +1,17 @@ +{ + "embed_dim": 512, + "vision_cfg": { + "timm_model_name": "vit_relpos_medium_patch16_cls_224", + "timm_model_pretrained": false, + "timm_pool": "", + "timm_proj": "linear", + "image_size": 224 + }, + "text_cfg": { + "context_length": 77, + "vocab_size": 49408, + "width": 512, + "heads": 8, + "layers": 12 + } +} \ No newline at end of file diff --git a/ext/open_clip/model_configs/xlm-roberta-base-ViT-B-32.json b/ext/open_clip/model_configs/xlm-roberta-base-ViT-B-32.json new file mode 100644 index 0000000000000000000000000000000000000000..751bccc2c6fc41bc4ff20182de88d86739d518d9 --- /dev/null +++ b/ext/open_clip/model_configs/xlm-roberta-base-ViT-B-32.json @@ -0,0 +1,15 @@ +{ + "embed_dim": 512, + "vision_cfg": { + "image_size": 224, + "layers": 12, + "width": 768, + "patch_size": 32 + }, + "text_cfg": { + "hf_model_name": "xlm-roberta-base", + "hf_tokenizer_name": "xlm-roberta-base", + "proj": "mlp", + "pooler_type": "mean_pooler" + } +} diff --git a/ext/open_clip/model_configs/xlm-roberta-large-ViT-H-14.json b/ext/open_clip/model_configs/xlm-roberta-large-ViT-H-14.json new file mode 100644 index 0000000000000000000000000000000000000000..31f271faa9bbb7a9da53900b483a4c00a16f3c4a --- /dev/null +++ b/ext/open_clip/model_configs/xlm-roberta-large-ViT-H-14.json @@ -0,0 +1,16 @@ +{ + "embed_dim": 1024, + "vision_cfg": { + "image_size": 224, + "layers": 32, + "width": 1280, + "head_width": 80, + "patch_size": 14 + }, + "text_cfg": { + "hf_model_name": "xlm-roberta-large", + "hf_tokenizer_name": "xlm-roberta-large", + "proj": "mlp", + "pooler_type": "mean_pooler" + } +} diff --git a/ext/open_clip/modified_resnet.py 
b/ext/open_clip/modified_resnet.py new file mode 100644 index 0000000000000000000000000000000000000000..6a8d3aeda91ecb394303becbbfccc8acd8cddcd9 --- /dev/null +++ b/ext/open_clip/modified_resnet.py @@ -0,0 +1,181 @@ +from collections import OrderedDict + +import torch +from torch import nn +from torch.nn import functional as F + +from .utils import freeze_batch_norm_2d + + +class Bottleneck(nn.Module): + expansion = 4 + + def __init__(self, inplanes, planes, stride=1): + super().__init__() + + # all conv layers have stride 1. an avgpool is performed after the second convolution when stride > 1 + self.conv1 = nn.Conv2d(inplanes, planes, 1, bias=False) + self.bn1 = nn.BatchNorm2d(planes) + self.act1 = nn.ReLU(inplace=True) + + self.conv2 = nn.Conv2d(planes, planes, 3, padding=1, bias=False) + self.bn2 = nn.BatchNorm2d(planes) + self.act2 = nn.ReLU(inplace=True) + + self.avgpool = nn.AvgPool2d(stride) if stride > 1 else nn.Identity() + + self.conv3 = nn.Conv2d(planes, planes * self.expansion, 1, bias=False) + self.bn3 = nn.BatchNorm2d(planes * self.expansion) + self.act3 = nn.ReLU(inplace=True) + + self.downsample = None + self.stride = stride + + if stride > 1 or inplanes != planes * Bottleneck.expansion: + # downsampling layer is prepended with an avgpool, and the subsequent convolution has stride 1 + self.downsample = nn.Sequential(OrderedDict([ + ("-1", nn.AvgPool2d(stride)), + ("0", nn.Conv2d(inplanes, planes * self.expansion, 1, stride=1, bias=False)), + ("1", nn.BatchNorm2d(planes * self.expansion)) + ])) + + def forward(self, x: torch.Tensor): + identity = x + + out = self.act1(self.bn1(self.conv1(x))) + out = self.act2(self.bn2(self.conv2(out))) + out = self.avgpool(out) + out = self.bn3(self.conv3(out)) + + if self.downsample is not None: + identity = self.downsample(x) + + out += identity + out = self.act3(out) + return out + + +class AttentionPool2d(nn.Module): + def __init__(self, spacial_dim: int, embed_dim: int, num_heads: int, output_dim: int = None): + super().__init__() + self.positional_embedding = nn.Parameter(torch.randn(spacial_dim ** 2 + 1, embed_dim) / embed_dim ** 0.5) + self.k_proj = nn.Linear(embed_dim, embed_dim) + self.q_proj = nn.Linear(embed_dim, embed_dim) + self.v_proj = nn.Linear(embed_dim, embed_dim) + self.c_proj = nn.Linear(embed_dim, output_dim or embed_dim) + self.num_heads = num_heads + + def forward(self, x): + x = x.reshape(x.shape[0], x.shape[1], x.shape[2] * x.shape[3]).permute(2, 0, 1) # NCHW -> (HW)NC + x = torch.cat([x.mean(dim=0, keepdim=True), x], dim=0) # (HW+1)NC + x = x + self.positional_embedding[:, None, :].to(x.dtype) # (HW+1)NC + x, _ = F.multi_head_attention_forward( + query=x, key=x, value=x, + embed_dim_to_check=x.shape[-1], + num_heads=self.num_heads, + q_proj_weight=self.q_proj.weight, + k_proj_weight=self.k_proj.weight, + v_proj_weight=self.v_proj.weight, + in_proj_weight=None, + in_proj_bias=torch.cat([self.q_proj.bias, self.k_proj.bias, self.v_proj.bias]), + bias_k=None, + bias_v=None, + add_zero_attn=False, + dropout_p=0., + out_proj_weight=self.c_proj.weight, + out_proj_bias=self.c_proj.bias, + use_separate_proj_weight=True, + training=self.training, + need_weights=False + ) + + return x[0] + + +class ModifiedResNet(nn.Module): + """ + A ResNet class that is similar to torchvision's but contains the following changes: + - There are now 3 "stem" convolutions as opposed to 1, with an average pool instead of a max pool. 
+ - Performs anti-aliasing strided convolutions, where an avgpool is prepended to convolutions with stride > 1 + - The final pooling layer is a QKV attention instead of an average pool + """ + + def __init__(self, layers, output_dim, heads, image_size=224, width=64): + super().__init__() + self.output_dim = output_dim + self.image_size = image_size + + # the 3-layer stem + self.conv1 = nn.Conv2d(3, width // 2, kernel_size=3, stride=2, padding=1, bias=False) + self.bn1 = nn.BatchNorm2d(width // 2) + self.act1 = nn.ReLU(inplace=True) + self.conv2 = nn.Conv2d(width // 2, width // 2, kernel_size=3, padding=1, bias=False) + self.bn2 = nn.BatchNorm2d(width // 2) + self.act2 = nn.ReLU(inplace=True) + self.conv3 = nn.Conv2d(width // 2, width, kernel_size=3, padding=1, bias=False) + self.bn3 = nn.BatchNorm2d(width) + self.act3 = nn.ReLU(inplace=True) + self.avgpool = nn.AvgPool2d(2) + + # residual layers + self._inplanes = width # this is a *mutable* variable used during construction + self.layer1 = self._make_layer(width, layers[0]) + self.layer2 = self._make_layer(width * 2, layers[1], stride=2) + self.layer3 = self._make_layer(width * 4, layers[2], stride=2) + self.layer4 = self._make_layer(width * 8, layers[3], stride=2) + + embed_dim = width * 32 # the ResNet feature dimension + self.attnpool = AttentionPool2d(image_size // 32, embed_dim, heads, output_dim) + + self.init_parameters() + + def _make_layer(self, planes, blocks, stride=1): + layers = [Bottleneck(self._inplanes, planes, stride)] + + self._inplanes = planes * Bottleneck.expansion + for _ in range(1, blocks): + layers.append(Bottleneck(self._inplanes, planes)) + + return nn.Sequential(*layers) + + def init_parameters(self): + if self.attnpool is not None: + std = self.attnpool.c_proj.in_features ** -0.5 + nn.init.normal_(self.attnpool.q_proj.weight, std=std) + nn.init.normal_(self.attnpool.k_proj.weight, std=std) + nn.init.normal_(self.attnpool.v_proj.weight, std=std) + nn.init.normal_(self.attnpool.c_proj.weight, std=std) + + for resnet_block in [self.layer1, self.layer2, self.layer3, self.layer4]: + for name, param in resnet_block.named_parameters(): + if name.endswith("bn3.weight"): + nn.init.zeros_(param) + + def lock(self, unlocked_groups=0, freeze_bn_stats=False): + assert unlocked_groups == 0, 'partial locking not currently supported for this model' + for param in self.parameters(): + param.requires_grad = False + if freeze_bn_stats: + freeze_batch_norm_2d(self) + + @torch.jit.ignore + def set_grad_checkpointing(self, enable=True): + # FIXME support for non-transformer + pass + + def stem(self, x): + x = self.act1(self.bn1(self.conv1(x))) + x = self.act2(self.bn2(self.conv2(x))) + x = self.act3(self.bn3(self.conv3(x))) + x = self.avgpool(x) + return x + + def forward(self, x): + x = self.stem(x) + x = self.layer1(x) + x = self.layer2(x) + x = self.layer3(x) + x = self.layer4(x) + x = self.attnpool(x) + + return x diff --git a/ext/open_clip/openai.py b/ext/open_clip/openai.py new file mode 100644 index 0000000000000000000000000000000000000000..6c2c0235245c2e4f1217b3b2bfaf2acf78e74981 --- /dev/null +++ b/ext/open_clip/openai.py @@ -0,0 +1,90 @@ +""" OpenAI pretrained model functions + +Adapted from https://github.com/openai/CLIP. Originally MIT License, Copyright (c) 2021 OpenAI. 
+""" + +import os +import warnings +from typing import List, Optional, Union + +import torch + +from .constants import OPENAI_DATASET_MEAN, OPENAI_DATASET_STD +from .model import build_model_from_openai_state_dict, convert_weights_to_lp, get_cast_dtype +from .pretrained import get_pretrained_url, list_pretrained_models_by_tag, download_pretrained_from_url + +__all__ = ["list_openai_models", "load_openai_model"] + + +def list_openai_models() -> List[str]: + """Returns the names of available CLIP models""" + return list_pretrained_models_by_tag('openai') + + +def load_openai_model( + name: str, + precision: Optional[str] = None, + device: Optional[Union[str, torch.device]] = None, + cache_dir: Optional[str] = None, +): + """Load a CLIP model + + Parameters + ---------- + name : str + A model name listed by `clip.available_models()`, or the path to a model checkpoint containing the state_dict + precision: str + Model precision, if None defaults to 'fp32' if device == 'cpu' else 'fp16'. + device : Union[str, torch.device] + The device to put the loaded model + cache_dir : Optional[str] + The directory to cache the downloaded model weights + + Returns + ------- + model : torch.nn.Module + The CLIP model + preprocess : Callable[[PIL.Image], torch.Tensor] + A torchvision transform that converts a PIL image into a tensor that the returned model can take as its input + """ + if device is None: + device = "cuda" if torch.cuda.is_available() else "cpu" + if precision is None: + precision = 'fp32' if device == 'cpu' else 'fp16' + + if get_pretrained_url(name, 'openai'): + model_path = download_pretrained_from_url(get_pretrained_url(name, 'openai'), cache_dir=cache_dir) + elif os.path.isfile(name): + model_path = name + else: + raise RuntimeError(f"Model {name} not found; available models = {list_openai_models()}") + + try: + # loading JIT archive + model = torch.jit.load(model_path, map_location="cpu").eval() + state_dict = None + except RuntimeError: + # loading saved state dict + state_dict = torch.load(model_path, map_location="cpu") + + # Build a non-jit model from the OpenAI jitted model state dict + cast_dtype = get_cast_dtype(precision) + try: + model = build_model_from_openai_state_dict(state_dict or model.state_dict(), cast_dtype=cast_dtype) + except KeyError: + sd = {k[7:]: v for k, v in state_dict["state_dict"].items()} + model = build_model_from_openai_state_dict(sd, cast_dtype=cast_dtype) + + # model from OpenAI state dict is in manually cast fp16 mode, must be converted for AMP/fp32/bf16 use + model = model.to(device) + # FIXME support pure fp16/bf16 precision modes + if precision != 'fp16': + model.float() + if precision == 'bf16': + # for bf16, convert back to low-precision + convert_weights_to_lp(model, dtype=torch.bfloat16) + + # add mean / std attributes for consistency with OpenCLIP models + model.visual.image_mean = OPENAI_DATASET_MEAN + model.visual.image_std = OPENAI_DATASET_STD + return model diff --git a/ext/open_clip/pretrained.py b/ext/open_clip/pretrained.py new file mode 100644 index 0000000000000000000000000000000000000000..1465a2325652be7e7a1d7563698e38b9ec408cc6 --- /dev/null +++ b/ext/open_clip/pretrained.py @@ -0,0 +1,427 @@ +import hashlib +import os +import urllib +import warnings +from functools import partial +from typing import Dict, Union + +from tqdm import tqdm + +from .version import __version__ + +try: + from huggingface_hub import hf_hub_download + hf_hub_download = partial(hf_hub_download, library_name="open_clip", library_version=__version__) + 
_has_hf_hub = True +except ImportError: + hf_hub_download = None + _has_hf_hub = False + + +def _pcfg(url='', hf_hub='', mean=None, std=None): + return dict( + url=url, + hf_hub=hf_hub, + mean=mean, + std=std, + ) + + +_RN50 = dict( + openai=_pcfg( + "https://openaipublic.azureedge.net/clip/models/afeb0e10f9e5a86da6080e35cf09123aca3b358a0c3e3b6c78a7b63bc04b6762/RN50.pt"), + yfcc15m=_pcfg( + "https://github.com/mlfoundations/open_clip/releases/download/v0.2-weights/rn50-quickgelu-yfcc15m-455df137.pt"), + cc12m=_pcfg( + "https://github.com/mlfoundations/open_clip/releases/download/v0.2-weights/rn50-quickgelu-cc12m-f000538c.pt"), +) + +_RN50_quickgelu = dict( + openai=_pcfg( + "https://openaipublic.azureedge.net/clip/models/afeb0e10f9e5a86da6080e35cf09123aca3b358a0c3e3b6c78a7b63bc04b6762/RN50.pt"), + yfcc15m=_pcfg( + "https://github.com/mlfoundations/open_clip/releases/download/v0.2-weights/rn50-quickgelu-yfcc15m-455df137.pt"), + cc12m=_pcfg( + "https://github.com/mlfoundations/open_clip/releases/download/v0.2-weights/rn50-quickgelu-cc12m-f000538c.pt"), +) + +_RN101 = dict( + openai=_pcfg( + "https://openaipublic.azureedge.net/clip/models/8fa8567bab74a42d41c5915025a8e4538c3bdbe8804a470a72f30b0d94fab599/RN101.pt"), + yfcc15m=_pcfg( + "https://github.com/mlfoundations/open_clip/releases/download/v0.2-weights/rn101-quickgelu-yfcc15m-3e04b30e.pt"), +) + +_RN101_quickgelu = dict( + openai=_pcfg( + "https://openaipublic.azureedge.net/clip/models/8fa8567bab74a42d41c5915025a8e4538c3bdbe8804a470a72f30b0d94fab599/RN101.pt"), + yfcc15m=_pcfg( + "https://github.com/mlfoundations/open_clip/releases/download/v0.2-weights/rn101-quickgelu-yfcc15m-3e04b30e.pt"), +) + +_RN50x4 = dict( + openai=_pcfg( + "https://openaipublic.azureedge.net/clip/models/7e526bd135e493cef0776de27d5f42653e6b4c8bf9e0f653bb11773263205fdd/RN50x4.pt"), +) + +_RN50x16 = dict( + openai=_pcfg( + "https://openaipublic.azureedge.net/clip/models/52378b407f34354e150460fe41077663dd5b39c54cd0bfd2b27167a4a06ec9aa/RN50x16.pt"), +) + +_RN50x64 = dict( + openai=_pcfg( + "https://openaipublic.azureedge.net/clip/models/be1cfb55d75a9666199fb2206c106743da0f6468c9d327f3e0d0a543a9919d9c/RN50x64.pt"), +) + +_VITB32 = dict( + openai=_pcfg( + "https://openaipublic.azureedge.net/clip/models/40d365715913c9da98579312b702a82c18be219cc2a73407c4526f58eba950af/ViT-B-32.pt"), + laion400m_e31=_pcfg( + "https://github.com/mlfoundations/open_clip/releases/download/v0.2-weights/vit_b_32-quickgelu-laion400m_e31-d867053b.pt"), + laion400m_e32=_pcfg( + "https://github.com/mlfoundations/open_clip/releases/download/v0.2-weights/vit_b_32-quickgelu-laion400m_e32-46683a32.pt"), + laion2b_e16=_pcfg( + "https://github.com/mlfoundations/open_clip/releases/download/v0.2-weights/vit_b_32-laion2b_e16-af8dbd0c.pth"), + laion2b_s34b_b79k=_pcfg(hf_hub='laion/CLIP-ViT-B-32-laion2B-s34B-b79K/'), + # DataComp-M models + datacomp_m_s128m_b4k=_pcfg(hf_hub='laion/CLIP-ViT-B-32-DataComp.M-s128M-b4K/'), + commonpool_m_clip_s128m_b4k=_pcfg(hf_hub='laion/CLIP-ViT-B-32-CommonPool.M.clip-s128M-b4K/'), + commonpool_m_laion_s128m_b4k=_pcfg(hf_hub='laion/CLIP-ViT-B-32-CommonPool.M.laion-s128M-b4K/'), + commonpool_m_image_s128m_b4k=_pcfg(hf_hub='laion/CLIP-ViT-B-32-CommonPool.M.image-s128M-b4K/'), + commonpool_m_text_s128m_b4k=_pcfg(hf_hub='laion/CLIP-ViT-B-32-CommonPool.M.text-s128M-b4K/'), + commonpool_m_basic_s128m_b4k=_pcfg(hf_hub='laion/CLIP-ViT-B-32-CommonPool.M.basic-s128M-b4K/'), + commonpool_m_s128m_b4k=_pcfg(hf_hub='laion/CLIP-ViT-B-32-CommonPool.M-s128M-b4K/'), + # DataComp-S models + 
datacomp_s_s13m_b4k=_pcfg(hf_hub='laion/CLIP-ViT-B-32-DataComp.S-s13M-b4K/'), + commonpool_s_clip_s13m_b4k=_pcfg(hf_hub='laion/CLIP-ViT-B-32-CommonPool.S.clip-s13M-b4K/'), + commonpool_s_laion_s13m_b4k=_pcfg(hf_hub='laion/CLIP-ViT-B-32-CommonPool.S.laion-s13M-b4K/'), + commonpool_s_image_s13m_b4k=_pcfg(hf_hub='laion/CLIP-ViT-B-32-CommonPool.S.image-s13M-b4K/'), + commonpool_s_text_s13m_b4k=_pcfg(hf_hub='laion/CLIP-ViT-B-32-CommonPool.S.text-s13M-b4K/'), + commonpool_s_basic_s13m_b4k=_pcfg(hf_hub='laion/CLIP-ViT-B-32-CommonPool.S.basic-s13M-b4K/'), + commonpool_s_s13m_b4k=_pcfg(hf_hub='laion/CLIP-ViT-B-32-CommonPool.S-s13M-b4K/'), +) + +_VITB32_quickgelu = dict( + openai=_pcfg( + "https://openaipublic.azureedge.net/clip/models/40d365715913c9da98579312b702a82c18be219cc2a73407c4526f58eba950af/ViT-B-32.pt"), + laion400m_e31=_pcfg( + "https://github.com/mlfoundations/open_clip/releases/download/v0.2-weights/vit_b_32-quickgelu-laion400m_e31-d867053b.pt"), + laion400m_e32=_pcfg( + "https://github.com/mlfoundations/open_clip/releases/download/v0.2-weights/vit_b_32-quickgelu-laion400m_e32-46683a32.pt"), +) + +_VITB16 = dict( + openai=_pcfg( + "https://openaipublic.azureedge.net/clip/models/5806e77cd80f8b59890b7e101eabd078d9fb84e6937f9e85e4ecb61988df416f/ViT-B-16.pt"), + laion400m_e31=_pcfg( + "https://github.com/mlfoundations/open_clip/releases/download/v0.2-weights/vit_b_16-laion400m_e31-00efa78f.pt"), + laion400m_e32=_pcfg( + "https://github.com/mlfoundations/open_clip/releases/download/v0.2-weights/vit_b_16-laion400m_e32-55e67d44.pt"), + laion2b_s34b_b88k=_pcfg(hf_hub='laion/CLIP-ViT-B-16-laion2B-s34B-b88K/'), + # DataComp-L models + datacomp_l_s1b_b8k=_pcfg(hf_hub='laion/CLIP-ViT-B-16-DataComp.L-s1B-b8K/'), + commonpool_l_clip_s1b_b8k=_pcfg(hf_hub='laion/CLIP-ViT-B-16-CommonPool.L.clip-s1B-b8K/'), + commonpool_l_laion_s1b_b8k=_pcfg(hf_hub='laion/CLIP-ViT-B-16-CommonPool.L.laion-s1B-b8K/'), + commonpool_l_image_s1b_b8k=_pcfg(hf_hub='laion/CLIP-ViT-B-16-CommonPool.L.image-s1B-b8K/'), + commonpool_l_text_s1b_b8k=_pcfg(hf_hub='laion/CLIP-ViT-B-16-CommonPool.L.text-s1B-b8K/'), + commonpool_l_basic_s1b_b8k=_pcfg(hf_hub='laion/CLIP-ViT-B-16-CommonPool.L.basic-s1B-b8K/'), + commonpool_l_s1b_b8k=_pcfg(hf_hub='laion/CLIP-ViT-B-16-CommonPool.L-s1B-b8K/'), +) + +_VITB16_PLUS_240 = dict( + laion400m_e31=_pcfg( + "https://github.com/mlfoundations/open_clip/releases/download/v0.2-weights/vit_b_16_plus_240-laion400m_e31-8fb26589.pt"), + laion400m_e32=_pcfg( + "https://github.com/mlfoundations/open_clip/releases/download/v0.2-weights/vit_b_16_plus_240-laion400m_e32-699c4b84.pt"), +) + +_VITL14 = dict( + openai=_pcfg( + "https://openaipublic.azureedge.net/clip/models/b8cca3fd41ae0c99ba7e8951adf17d267cdb84cd88be6f7c2e0eca1737a03836/ViT-L-14.pt"), + laion400m_e31=_pcfg( + "https://github.com/mlfoundations/open_clip/releases/download/v0.2-weights/vit_l_14-laion400m_e31-69988bb6.pt"), + laion400m_e32=_pcfg( + "https://github.com/mlfoundations/open_clip/releases/download/v0.2-weights/vit_l_14-laion400m_e32-3d133497.pt"), + laion2b_s32b_b82k=_pcfg( + hf_hub='laion/CLIP-ViT-L-14-laion2B-s32B-b82K/', + mean=(0.5, 0.5, 0.5), std=(0.5, 0.5, 0.5)), + # DataComp-XL models + datacomp_xl_s13b_b90k=_pcfg(hf_hub='laion/CLIP-ViT-L-14-DataComp.XL-s13B-b90K/'), + commonpool_xl_clip_s13b_b90k=_pcfg(hf_hub='laion/CLIP-ViT-L-14-CommonPool.XL.clip-s13B-b90K/'), + commonpool_xl_laion_s13b_b90k=_pcfg(hf_hub='laion/CLIP-ViT-L-14-CommonPool.XL.laion-s13B-b90K/'), + 
commonpool_xl_s13b_b90k=_pcfg(hf_hub='laion/CLIP-ViT-L-14-CommonPool.XL-s13B-b90K/'), +) + +_VITL14_336 = dict( + openai=_pcfg( + "https://openaipublic.azureedge.net/clip/models/3035c92b350959924f9f00213499208652fc7ea050643e8b385c2dac08641f02/ViT-L-14-336px.pt"), +) + +_VITH14 = dict( + laion2b_s32b_b79k=_pcfg(hf_hub='laion/CLIP-ViT-H-14-laion2B-s32B-b79K/'), +) + +_VITg14 = dict( + laion2b_s12b_b42k=_pcfg(hf_hub='laion/CLIP-ViT-g-14-laion2B-s12B-b42K/'), + laion2b_s34b_b88k=_pcfg(hf_hub='laion/CLIP-ViT-g-14-laion2B-s34B-b88K/'), +) + +_VITbigG14 = dict( + laion2b_s39b_b160k=_pcfg(hf_hub='laion/CLIP-ViT-bigG-14-laion2B-39B-b160k/'), +) + +_robertaViTB32 = dict( + laion2b_s12b_b32k=_pcfg(hf_hub='laion/CLIP-ViT-B-32-roberta-base-laion2B-s12B-b32k/'), +) + +_xlmRobertaBaseViTB32 = dict( + laion5b_s13b_b90k=_pcfg(hf_hub='laion/CLIP-ViT-B-32-xlm-roberta-base-laion5B-s13B-b90k/'), +) + +_xlmRobertaLargeFrozenViTH14 = dict( + frozen_laion5b_s13b_b90k=_pcfg(hf_hub='laion/CLIP-ViT-H-14-frozen-xlm-roberta-large-laion5B-s13B-b90k/'), +) + +_convnext_base = dict( + laion400m_s13b_b51k=_pcfg(hf_hub='laion/CLIP-convnext_base-laion400M-s13B-b51K/'), +) + +_convnext_base_w = dict( + laion2b_s13b_b82k=_pcfg(hf_hub='laion/CLIP-convnext_base_w-laion2B-s13B-b82K/'), + laion2b_s13b_b82k_augreg=_pcfg(hf_hub='laion/CLIP-convnext_base_w-laion2B-s13B-b82K-augreg/'), + laion_aesthetic_s13b_b82k=_pcfg(hf_hub='laion/CLIP-convnext_base_w-laion_aesthetic-s13B-b82K/'), +) + +_convnext_base_w_320 = dict( + laion_aesthetic_s13b_b82k=_pcfg(hf_hub='laion/CLIP-convnext_base_w_320-laion_aesthetic-s13B-b82K/'), + laion_aesthetic_s13b_b82k_augreg=_pcfg(hf_hub='laion/CLIP-convnext_base_w_320-laion_aesthetic-s13B-b82K-augreg/'), +) + +_convnext_large_d = dict( + laion2b_s26b_b102k_augreg=_pcfg(hf_hub='laion/CLIP-convnext_large_d.laion2B-s26B-b102K-augreg/'), +) + +_convnext_large_d_320 = dict( + laion2b_s29b_b131k_ft=_pcfg(hf_hub='laion/CLIP-convnext_large_d_320.laion2B-s29B-b131K-ft/'), + laion2b_s29b_b131k_ft_soup=_pcfg(hf_hub='laion/CLIP-convnext_large_d_320.laion2B-s29B-b131K-ft-soup/'), +) + +_convnext_xxlarge = dict( + laion2b_s34b_b82k_augreg=_pcfg(hf_hub='laion/CLIP-convnext_xxlarge-laion2B-s34B-b82K-augreg/'), + laion2b_s34b_b82k_augreg_rewind=_pcfg(hf_hub='laion/CLIP-convnext_xxlarge-laion2B-s34B-b82K-augreg-rewind/'), + laion2b_s34b_b82k_augreg_soup=_pcfg(hf_hub='laion/CLIP-convnext_xxlarge-laion2B-s34B-b82K-augreg-soup/'), +) + +_coca_VITB32 = dict( + laion2b_s13b_b90k=_pcfg(hf_hub='laion/CoCa-ViT-B-32-laion2B-s13B-b90k/'), + mscoco_finetuned_laion2b_s13b_b90k=_pcfg(hf_hub='laion/mscoco_finetuned_CoCa-ViT-B-32-laion2B-s13B-b90k/') +) + +_coca_VITL14 = dict( + laion2b_s13b_b90k=_pcfg(hf_hub='laion/CoCa-ViT-L-14-laion2B-s13B-b90k/'), + mscoco_finetuned_laion2b_s13b_b90k=_pcfg(hf_hub='laion/mscoco_finetuned_CoCa-ViT-L-14-laion2B-s13B-b90k/') +) + + +_PRETRAINED = { + "RN50": _RN50, + "RN50-quickgelu": _RN50_quickgelu, + "RN101": _RN101, + "RN101-quickgelu": _RN101_quickgelu, + "RN50x4": _RN50x4, + "RN50x16": _RN50x16, + "RN50x64": _RN50x64, + "ViT-B-32": _VITB32, + "ViT-B-32-quickgelu": _VITB32_quickgelu, + "ViT-B-16": _VITB16, + "ViT-B-16-plus-240": _VITB16_PLUS_240, + "ViT-L-14": _VITL14, + "ViT-L-14-336": _VITL14_336, + "ViT-H-14": _VITH14, + "ViT-g-14": _VITg14, + "ViT-bigG-14": _VITbigG14, + "roberta-ViT-B-32": _robertaViTB32, + "xlm-roberta-base-ViT-B-32": _xlmRobertaBaseViTB32, + "xlm-roberta-large-ViT-H-14": _xlmRobertaLargeFrozenViTH14, + "convnext_base": _convnext_base, + "convnext_base_w": _convnext_base_w, + 
"convnext_base_w_320": _convnext_base_w_320, + "convnext_large_d": _convnext_large_d, + "convnext_large_d_320": _convnext_large_d_320, + "convnext_xxlarge": _convnext_xxlarge, + "coca_ViT-B-32": _coca_VITB32, + "coca_ViT-L-14": _coca_VITL14, + "EVA01-g-14": dict( + # from QuanSun/EVA-CLIP/EVA01_CLIP_g_14_psz14_s11B.pt + laion400m_s11b_b41k=_pcfg(hf_hub='timm/eva_giant_patch14_clip_224.laion400m_s11b_b41k/'), + ), + "EVA01-g-14-plus": dict( + # from QuanSun/EVA-CLIP/EVA01_CLIP_g_14_plus_psz14_s11B.pt + merged2b_s11b_b114k=_pcfg(hf_hub='timm/eva_giant_patch14_plus_clip_224.merged2b_s11b_b114k/'), + ), + "EVA02-B-16": dict( + # from QuanSun/EVA-CLIP/EVA02_CLIP_B_psz16_s8B.pt + merged2b_s8b_b131k=_pcfg(hf_hub='timm/eva02_base_patch16_clip_224.merged2b_s8b_b131k/'), + ), + "EVA02-L-14": dict( + # from QuanSun/EVA-CLIP/EVA02_CLIP_L_psz14_s4B.pt + merged2b_s4b_b131k=_pcfg(hf_hub='timm/eva02_large_patch14_clip_224.merged2b_s4b_b131k/'), + ), + "EVA02-L-14-336": dict( + # from QuanSun/EVA-CLIP/EVA02_CLIP_L_336_psz14_s6B.pt + merged2b_s6b_b61k=_pcfg(hf_hub='timm/eva02_large_patch14_clip_336.merged2b_s6b_b61k/'), + ), + "EVA02-E-14": dict( + # from QuanSun/EVA-CLIP/EVA02_CLIP_E_psz14_s4B.pt + laion2b_s4b_b115k=_pcfg(hf_hub='timm/eva02_enormous_patch14_clip_224.laion2b_s4b_b115k/'), + ), + "EVA02-E-14-plus": dict( + # from QuanSun/EVA-CLIP/EVA02_CLIP_E_psz14_plus_s9B.pt + laion2b_s9b_b144k=_pcfg(hf_hub='timm/eva02_enormous_patch14_plus_clip_224.laion2b_s9b_b144k/'), + ) +} + + +def _clean_tag(tag: str): + # normalize pretrained tags + return tag.lower().replace('-', '_') + + +def list_pretrained(as_str: bool = False): + """ returns list of pretrained models + Returns a tuple (model_name, pretrain_tag) by default or 'name:tag' if as_str == True + """ + return [':'.join([k, t]) if as_str else (k, t) for k in _PRETRAINED.keys() for t in _PRETRAINED[k].keys()] + + +def list_pretrained_models_by_tag(tag: str): + """ return all models having the specified pretrain tag """ + models = [] + tag = _clean_tag(tag) + for k in _PRETRAINED.keys(): + if tag in _PRETRAINED[k]: + models.append(k) + return models + + +def list_pretrained_tags_by_model(model: str): + """ return all pretrain tags for the specified model architecture """ + tags = [] + if model in _PRETRAINED: + tags.extend(_PRETRAINED[model].keys()) + return tags + + +def is_pretrained_cfg(model: str, tag: str): + if model not in _PRETRAINED: + return False + return _clean_tag(tag) in _PRETRAINED[model] + + +def get_pretrained_cfg(model: str, tag: str): + if model not in _PRETRAINED: + return {} + model_pretrained = _PRETRAINED[model] + return model_pretrained.get(_clean_tag(tag), {}) + + +def get_pretrained_url(model: str, tag: str): + cfg = get_pretrained_cfg(model, _clean_tag(tag)) + return cfg.get('url', '') + + +def download_pretrained_from_url( + url: str, + cache_dir: Union[str, None] = None, +): + if not cache_dir: + cache_dir = os.path.expanduser("~/.cache/clip") + os.makedirs(cache_dir, exist_ok=True) + filename = os.path.basename(url) + + if 'openaipublic' in url: + expected_sha256 = url.split("/")[-2] + elif 'mlfoundations' in url: + expected_sha256 = os.path.splitext(filename)[0].split("-")[-1] + else: + expected_sha256 = '' + + download_target = os.path.join(cache_dir, filename) + + if os.path.exists(download_target) and not os.path.isfile(download_target): + raise RuntimeError(f"{download_target} exists and is not a regular file") + + if os.path.isfile(download_target): + if expected_sha256: + if hashlib.sha256(open(download_target, 
"rb").read()).hexdigest().startswith(expected_sha256): + return download_target + else: + warnings.warn(f"{download_target} exists, but the SHA256 checksum does not match; re-downloading the file") + else: + return download_target + + with urllib.request.urlopen(url) as source, open(download_target, "wb") as output: + with tqdm(total=int(source.headers.get("Content-Length")), ncols=80, unit='iB', unit_scale=True) as loop: + while True: + buffer = source.read(8192) + if not buffer: + break + + output.write(buffer) + loop.update(len(buffer)) + + if expected_sha256 and not hashlib.sha256(open(download_target, "rb").read()).hexdigest().startswith(expected_sha256): + raise RuntimeError(f"Model has been downloaded but the SHA256 checksum does not not match") + + return download_target + + +def has_hf_hub(necessary=False): + if not _has_hf_hub and necessary: + # if no HF Hub module installed, and it is necessary to continue, raise error + raise RuntimeError( + 'Hugging Face hub model specified but package not installed. Run `pip install huggingface_hub`.') + return _has_hf_hub + + +def download_pretrained_from_hf( + model_id: str, + filename: str = 'open_clip_pytorch_model.bin', + revision=None, + cache_dir: Union[str, None] = None, +): + has_hf_hub(True) + cached_file = hf_hub_download(model_id, filename, revision=revision, cache_dir=cache_dir) + return cached_file + + +def download_pretrained( + cfg: Dict, + force_hf_hub: bool = False, + cache_dir: Union[str, None] = None, +): + target = '' + if not cfg: + return target + + download_url = cfg.get('url', '') + download_hf_hub = cfg.get('hf_hub', '') + if download_hf_hub and force_hf_hub: + # use HF hub even if url exists + download_url = '' + + if download_url: + target = download_pretrained_from_url(download_url, cache_dir=cache_dir) + elif download_hf_hub: + has_hf_hub(True) + # we assume the hf_hub entries in pretrained config combine model_id + filename in + # 'org/model_name/filename.pt' form. To specify just the model id w/o filename and + # use 'open_clip_pytorch_model.bin' default, there must be a trailing slash 'org/model_name/'. + model_id, filename = os.path.split(download_hf_hub) + if filename: + target = download_pretrained_from_hf(model_id, filename=filename, cache_dir=cache_dir) + else: + target = download_pretrained_from_hf(model_id, cache_dir=cache_dir) + + return target diff --git a/ext/open_clip/push_to_hf_hub.py b/ext/open_clip/push_to_hf_hub.py new file mode 100644 index 0000000000000000000000000000000000000000..6e6271da1d35e36ea22e92d339dc9465d0793249 --- /dev/null +++ b/ext/open_clip/push_to_hf_hub.py @@ -0,0 +1,280 @@ +import argparse +import json +import os +from pathlib import Path +from tempfile import TemporaryDirectory +from typing import Optional, Tuple, Union + +import torch + +try: + from huggingface_hub import ( + create_repo, + get_hf_file_metadata, + hf_hub_download, + hf_hub_url, + repo_type_and_id_from_hf_id, + upload_folder, + list_repo_files, + ) + from huggingface_hub.utils import EntryNotFoundError + _has_hf_hub = True +except ImportError: + _has_hf_hub = False + +try: + import safetensors.torch + _has_safetensors = True +except ImportError: + _has_safetensors = False + +from .factory import create_model_from_pretrained, get_model_config, get_tokenizer +from .tokenizer import HFTokenizer + +# Default name for a weights file hosted on the Huggingface Hub. 
+HF_WEIGHTS_NAME = "open_clip_pytorch_model.bin" # default pytorch pkl +HF_SAFE_WEIGHTS_NAME = "open_clip_model.safetensors" # safetensors version +HF_CONFIG_NAME = 'open_clip_config.json' + +def save_config_for_hf( + model, + config_path: str, + model_config: Optional[dict] +): + preprocess_cfg = { + 'mean': model.visual.image_mean, + 'std': model.visual.image_std, + } + hf_config = { + 'model_cfg': model_config, + 'preprocess_cfg': preprocess_cfg, + } + + with config_path.open('w') as f: + json.dump(hf_config, f, indent=2) + + +def save_for_hf( + model, + tokenizer: HFTokenizer, + model_config: dict, + save_directory: str, + safe_serialization: Union[bool, str] = False, + skip_weights : bool = False, +): + config_filename = HF_CONFIG_NAME + + save_directory = Path(save_directory) + save_directory.mkdir(exist_ok=True, parents=True) + + if not skip_weights: + tensors = model.state_dict() + if safe_serialization is True or safe_serialization == "both": + assert _has_safetensors, "`pip install safetensors` to use .safetensors" + safetensors.torch.save_file(tensors, save_directory / HF_SAFE_WEIGHTS_NAME) + if safe_serialization is False or safe_serialization == "both": + torch.save(tensors, save_directory / HF_WEIGHTS_NAME) + + tokenizer.save_pretrained(save_directory) + + config_path = save_directory / config_filename + save_config_for_hf(model, config_path, model_config=model_config) + + +def push_to_hf_hub( + model, + tokenizer, + model_config: Optional[dict], + repo_id: str, + commit_message: str = 'Add model', + token: Optional[str] = None, + revision: Optional[str] = None, + private: bool = False, + create_pr: bool = False, + model_card: Optional[dict] = None, + safe_serialization: Union[bool, str] = False, +): + if not isinstance(tokenizer, HFTokenizer): + # default CLIP tokenizers use https://huggingface.co/openai/clip-vit-large-patch14 + tokenizer = HFTokenizer('openai/clip-vit-large-patch14') + + # Create repo if it doesn't exist yet + repo_url = create_repo(repo_id, token=token, private=private, exist_ok=True) + + # Infer complete repo_id from repo_url + # Can be different from the input `repo_id` if repo_owner was implicit + _, repo_owner, repo_name = repo_type_and_id_from_hf_id(repo_url) + repo_id = f"{repo_owner}/{repo_name}" + + # Check if repo already exists and determine what needs updating + repo_exists = False + repo_files = {} + try: + repo_files = set(list_repo_files(repo_id)) + repo_exists = True + except Exception as e: + print('Repo does not exist', e) + + try: + get_hf_file_metadata(hf_hub_url(repo_id=repo_id, filename="README.md", revision=revision)) + has_readme = True + except EntryNotFoundError: + has_readme = False + + # Dump model and push to Hub + with TemporaryDirectory() as tmpdir: + # Save model weights and config. 
+ save_for_hf( + model, + tokenizer=tokenizer, + model_config=model_config, + save_directory=tmpdir, + safe_serialization=safe_serialization, + ) + + # Add readme if it does not exist + if not has_readme: + model_card = model_card or {} + model_name = repo_id.split('/')[-1] + readme_path = Path(tmpdir) / "README.md" + readme_text = generate_readme(model_card, model_name) + readme_path.write_text(readme_text) + + # Upload model and return + return upload_folder( + repo_id=repo_id, + folder_path=tmpdir, + revision=revision, + create_pr=create_pr, + commit_message=commit_message, + ) + + +def push_pretrained_to_hf_hub( + model_name, + pretrained: str, + repo_id: str, + precision: str = 'fp32', + image_mean: Optional[Tuple[float, ...]] = None, + image_std: Optional[Tuple[float, ...]] = None, + commit_message: str = 'Add model', + token: Optional[str] = None, + revision: Optional[str] = None, + private: bool = False, + create_pr: bool = False, + model_card: Optional[dict] = None, +): + model, preprocess_eval = create_model_from_pretrained( + model_name, + pretrained=pretrained, + precision=precision, + image_mean=image_mean, + image_std=image_std, + ) + + model_config = get_model_config(model_name) + assert model_config + + tokenizer = get_tokenizer(model_name) + + push_to_hf_hub( + model=model, + tokenizer=tokenizer, + model_config=model_config, + repo_id=repo_id, + commit_message=commit_message, + token=token, + revision=revision, + private=private, + create_pr=create_pr, + model_card=model_card, + safe_serialization='both', + ) + + +def generate_readme(model_card: dict, model_name: str): + readme_text = "---\n" + readme_text += "tags:\n- clip\n" + readme_text += "library_name: open_clip\n" + readme_text += "pipeline_tag: zero-shot-image-classification\n" + readme_text += f"license: {model_card.get('license', 'mit')}\n" + if 'details' in model_card and 'Dataset' in model_card['details']: + readme_text += 'datasets:\n' + readme_text += f"- {model_card['details']['Dataset'].lower()}\n" + readme_text += "---\n" + readme_text += f"# Model card for {model_name}\n" + if 'description' in model_card: + readme_text += f"\n{model_card['description']}\n" + if 'details' in model_card: + readme_text += f"\n## Model Details\n" + for k, v in model_card['details'].items(): + if isinstance(v, (list, tuple)): + readme_text += f"- **{k}:**\n" + for vi in v: + readme_text += f" - {vi}\n" + elif isinstance(v, dict): + readme_text += f"- **{k}:**\n" + for ki, vi in v.items(): + readme_text += f" - {ki}: {vi}\n" + else: + readme_text += f"- **{k}:** {v}\n" + if 'usage' in model_card: + readme_text += f"\n## Model Usage\n" + readme_text += model_card['usage'] + readme_text += '\n' + + if 'comparison' in model_card: + readme_text += f"\n## Model Comparison\n" + readme_text += model_card['comparison'] + readme_text += '\n' + + if 'citation' in model_card: + readme_text += f"\n## Citation\n" + if not isinstance(model_card['citation'], (list, tuple)): + citations = [model_card['citation']] + else: + citations = model_card['citation'] + for c in citations: + readme_text += f"```bibtex\n{c}\n```\n" + + return readme_text + + +if __name__ == "__main__": + parser = argparse.ArgumentParser(description="Push to Hugging Face Hub") + parser.add_argument( + "--model", type=str, help="Name of the model to use.", + ) + parser.add_argument( + "--pretrained", type=str, + help="Use a pretrained CLIP model weights with the specified tag or file path.", + ) + parser.add_argument( + "--repo-id", type=str, + help="Destination HF Hub 
repo-id ie 'organization/model_id'.", + ) + parser.add_argument( + "--precision", type=str, default='fp32', + ) + parser.add_argument( + '--image-mean', type=float, nargs='+', default=None, metavar='MEAN', + help='Override default image mean value of dataset') + parser.add_argument( + '--image-std', type=float, nargs='+', default=None, metavar='STD', + help='Override default image std deviation of of dataset') + args = parser.parse_args() + + print(f'Saving model {args.model} with pretrained weights {args.pretrained} to Hugging Face Hub at {args.repo_id}') + + # FIXME add support to pass model_card json / template from file via cmd line + + push_pretrained_to_hf_hub( + args.model, + args.pretrained, + args.repo_id, + precision=args.precision, + image_mean=args.image_mean, # override image mean/std if trained w/ non defaults + image_std=args.image_std, + ) + + print(f'{args.model} saved.') diff --git a/ext/open_clip/timm_model.py b/ext/open_clip/timm_model.py new file mode 100644 index 0000000000000000000000000000000000000000..3d3f595d67cdedd142b6312d26924e8e58c67086 --- /dev/null +++ b/ext/open_clip/timm_model.py @@ -0,0 +1,149 @@ +""" timm model adapter + +Wraps timm (https://github.com/rwightman/pytorch-image-models) models for use as a vision tower in CLIP model. +""" +import logging +from collections import OrderedDict + +import torch +import torch.nn as nn + +try: + import timm + from timm.models.layers import Mlp, to_2tuple + try: + # old timm imports < 0.8.1 + from timm.models.layers.attention_pool2d import RotAttentionPool2d + from timm.models.layers.attention_pool2d import AttentionPool2d as AbsAttentionPool2d + except ImportError: + # new timm imports >= 0.8.1 + from timm.layers import RotAttentionPool2d + from timm.layers import AttentionPool2d as AbsAttentionPool2d +except ImportError: + timm = None + +from .utils import freeze_batch_norm_2d + + +class TimmModel(nn.Module): + """ timm model adapter + """ + + def __init__( + self, + model_name, + embed_dim, + image_size=224, + pool='avg', + proj='linear', + proj_bias=False, + drop=0., + drop_path=None, + patch_drop=None, + pretrained=False, + ): + super().__init__() + if timm is None: + raise RuntimeError("Please `pip install timm` to use timm models.") + self.image_size = to_2tuple(image_size) + + # setup kwargs that may not be common across all models + timm_kwargs = {} + if drop_path is not None: + timm_kwargs['drop_path_rate'] = drop_path + if patch_drop is not None: + timm_kwargs['patch_drop_rate'] = patch_drop + + custom_pool = pool in ('abs_attn', 'rot_attn') + if not proj and not custom_pool: + # use network classifier head as projection if no proj specified and no custom pooling used + self.trunk = timm.create_model( + model_name, + num_classes=embed_dim, + global_pool=pool, + pretrained=pretrained, + **timm_kwargs, + ) + prev_chs = embed_dim + else: + self.trunk = timm.create_model( + model_name, + pretrained=pretrained, + **timm_kwargs, + ) + feat_size = self.trunk.default_cfg.get('pool_size', None) + feature_ndim = 1 if not feat_size else 2 + if custom_pool: + assert feature_ndim == 2 + # if attn pooling used, remove both classifier and default pool + self.trunk.reset_classifier(0, global_pool='') + else: + # reset global pool if pool config set, otherwise leave as network default + reset_kwargs = dict(global_pool=pool) if pool else {} + self.trunk.reset_classifier(0, **reset_kwargs) + prev_chs = self.trunk.num_features + + head_layers = OrderedDict() + + # Add custom pooling to head + if pool == 'abs_attn': + 
head_layers['pool'] = AbsAttentionPool2d(prev_chs, feat_size=feat_size, out_features=embed_dim) + prev_chs = embed_dim + elif pool == 'rot_attn': + head_layers['pool'] = RotAttentionPool2d(prev_chs, out_features=embed_dim) + prev_chs = embed_dim + + # NOTE attention pool ends with a projection layer, so proj should usually be set to '' if such pooling is used + if proj == 'linear': + head_layers['drop'] = nn.Dropout(drop) + head_layers['proj'] = nn.Linear(prev_chs, embed_dim, bias=proj_bias) + elif proj == 'mlp': + head_layers['mlp'] = Mlp(prev_chs, 2 * embed_dim, embed_dim, drop=(drop, 0), bias=(True, proj_bias)) + else: + assert not proj, f'Unknown projection type {proj}.' + + self.head = nn.Sequential(head_layers) + + def lock(self, unlocked_groups=0, freeze_bn_stats=False): + """ lock modules + Args: + unlocked_groups (int): leave last n layer groups unlocked (default: 0) + """ + if not unlocked_groups: + # lock full model + for param in self.trunk.parameters(): + param.requires_grad = False + if freeze_bn_stats: + freeze_batch_norm_2d(self.trunk) + else: + # NOTE: partial freeze requires latest timm (master) branch and is subject to change + try: + # FIXME import here until API stable and in an official release + from timm.models.helpers import group_parameters, group_modules + except ImportError: + raise RuntimeError( + 'Please install latest timm `pip install git+https://github.com/rwightman/pytorch-image-models`') + matcher = self.trunk.group_matcher() + gparams = group_parameters(self.trunk, matcher) + max_layer_id = max(gparams.keys()) + max_layer_id = max_layer_id - unlocked_groups + for group_idx in range(max_layer_id + 1): + group = gparams[group_idx] + for param in group: + self.trunk.get_parameter(param).requires_grad = False + if freeze_bn_stats: + gmodules = group_modules(self.trunk, matcher, reverse=True) + gmodules = {k for k, v in gmodules.items() if v <= max_layer_id} + freeze_batch_norm_2d(self.trunk, gmodules) + + @torch.jit.ignore + def set_grad_checkpointing(self, enable=True): + try: + self.trunk.set_grad_checkpointing(enable) + except Exception as e: + logging.warning('grad checkpointing not supported for this timm image tower, continuing without...') + + def forward(self, x): + x = self.trunk(x) + x = self.head(x) + return x diff --git a/ext/open_clip/tokenizer.py b/ext/open_clip/tokenizer.py new file mode 100644 index 0000000000000000000000000000000000000000..23fcfcbcb4ca051ba5bba7520918693001999282 --- /dev/null +++ b/ext/open_clip/tokenizer.py @@ -0,0 +1,214 @@ +""" CLIP tokenizer + +Copied from https://github.com/openai/CLIP. Originally MIT License, Copyright (c) 2021 OpenAI. +""" +import gzip +import html +import os +from functools import lru_cache +from typing import Union, List + +import ftfy +import regex as re +import torch + +# https://stackoverflow.com/q/62691279 +import os +os.environ["TOKENIZERS_PARALLELISM"] = "false" + + +@lru_cache() +def default_bpe(): + return os.path.join(os.path.dirname(os.path.abspath(__file__)), "bpe_simple_vocab_16e6.txt.gz") + + +@lru_cache() +def bytes_to_unicode(): + """ + Returns list of utf-8 byte and a corresponding list of unicode strings. + The reversible bpe codes work on unicode strings. + This means you need a large # of unicode characters in your vocab if you want to avoid UNKs. + When you're at something like a 10B token dataset you end up needing around 5K for decent coverage. + This is a significant percentage of your normal, say, 32K bpe vocab. 
+ To avoid that, we want lookup tables between utf-8 bytes and unicode strings. + And avoids mapping to whitespace/control characters the bpe code barfs on. + """ + bs = list(range(ord("!"), ord("~")+1))+list(range(ord("¡"), ord("¬")+1))+list(range(ord("®"), ord("ÿ")+1)) + cs = bs[:] + n = 0 + for b in range(2**8): + if b not in bs: + bs.append(b) + cs.append(2**8+n) + n += 1 + cs = [chr(n) for n in cs] + return dict(zip(bs, cs)) + + +def get_pairs(word): + """Return set of symbol pairs in a word. + Word is represented as tuple of symbols (symbols being variable-length strings). + """ + pairs = set() + prev_char = word[0] + for char in word[1:]: + pairs.add((prev_char, char)) + prev_char = char + return pairs + + +def basic_clean(text): + text = ftfy.fix_text(text) + text = html.unescape(html.unescape(text)) + return text.strip() + + +def whitespace_clean(text): + text = re.sub(r'\s+', ' ', text) + text = text.strip() + return text + + +class SimpleTokenizer(object): + def __init__(self, bpe_path: str = default_bpe(), special_tokens=None): + self.byte_encoder = bytes_to_unicode() + self.byte_decoder = {v: k for k, v in self.byte_encoder.items()} + merges = gzip.open(bpe_path).read().decode("utf-8").split('\n') + merges = merges[1:49152-256-2+1] + merges = [tuple(merge.split()) for merge in merges] + vocab = list(bytes_to_unicode().values()) + vocab = vocab + [v+'' for v in vocab] + for merge in merges: + vocab.append(''.join(merge)) + if not special_tokens: + special_tokens = ['', ''] + else: + special_tokens = ['', ''] + special_tokens + vocab.extend(special_tokens) + self.encoder = dict(zip(vocab, range(len(vocab)))) + self.decoder = {v: k for k, v in self.encoder.items()} + self.bpe_ranks = dict(zip(merges, range(len(merges)))) + self.cache = {t:t for t in special_tokens} + special = "|".join(special_tokens) + self.pat = re.compile(special + r"""|'s|'t|'re|'ve|'m|'ll|'d|[\p{L}]+|[\p{N}]|[^\s\p{L}\p{N}]+""", re.IGNORECASE) + + self.vocab_size = len(self.encoder) + self.all_special_ids = [self.encoder[t] for t in special_tokens] + + def bpe(self, token): + if token in self.cache: + return self.cache[token] + word = tuple(token[:-1]) + ( token[-1] + '',) + pairs = get_pairs(word) + + if not pairs: + return token+'' + + while True: + bigram = min(pairs, key = lambda pair: self.bpe_ranks.get(pair, float('inf'))) + if bigram not in self.bpe_ranks: + break + first, second = bigram + new_word = [] + i = 0 + while i < len(word): + try: + j = word.index(first, i) + new_word.extend(word[i:j]) + i = j + except: + new_word.extend(word[i:]) + break + + if word[i] == first and i < len(word)-1 and word[i+1] == second: + new_word.append(first+second) + i += 2 + else: + new_word.append(word[i]) + i += 1 + new_word = tuple(new_word) + word = new_word + if len(word) == 1: + break + else: + pairs = get_pairs(word) + word = ' '.join(word) + self.cache[token] = word + return word + + def encode(self, text): + bpe_tokens = [] + text = whitespace_clean(basic_clean(text)).lower() + for token in re.findall(self.pat, text): + token = ''.join(self.byte_encoder[b] for b in token.encode('utf-8')) + bpe_tokens.extend(self.encoder[bpe_token] for bpe_token in self.bpe(token).split(' ')) + return bpe_tokens + + def decode(self, tokens): + text = ''.join([self.decoder[token] for token in tokens]) + text = bytearray([self.byte_decoder[c] for c in text]).decode('utf-8', errors="replace").replace('', ' ') + return text + + +_tokenizer = SimpleTokenizer() + +def decode(output_ids: torch.Tensor): + output_ids = 
output_ids.cpu().numpy() + return _tokenizer.decode(output_ids) + +def tokenize(texts: Union[str, List[str]], context_length: int = 77) -> torch.LongTensor: + """ + Returns the tokenized representation of given input string(s) + + Parameters + ---------- + texts : Union[str, List[str]] + An input string or a list of input strings to tokenize + context_length : int + The context length to use; all CLIP models use 77 as the context length + + Returns + ------- + A two-dimensional tensor containing the resulting tokens, shape = [number of input strings, context_length] + """ + if isinstance(texts, str): + texts = [texts] + + sot_token = _tokenizer.encoder[""] + eot_token = _tokenizer.encoder[""] + all_tokens = [[sot_token] + _tokenizer.encode(text) + [eot_token] for text in texts] + result = torch.zeros(len(all_tokens), context_length, dtype=torch.long) + + for i, tokens in enumerate(all_tokens): + if len(tokens) > context_length: + tokens = tokens[:context_length] # Truncate + tokens[-1] = eot_token + result[i, :len(tokens)] = torch.tensor(tokens) + + return result + + +class HFTokenizer: + """HuggingFace tokenizer wrapper""" + + def __init__(self, tokenizer_name: str): + from transformers import AutoTokenizer + self.tokenizer = AutoTokenizer.from_pretrained(tokenizer_name) + + def save_pretrained(self, dest): + self.tokenizer.save_pretrained(dest) + + def __call__(self, texts: Union[str, List[str]], context_length: int = 77) -> torch.Tensor: + # same cleaning as for default tokenizer, except lowercasing + # adding lower (for case-sensitive tokenizers) will make it more robust but less sensitive to nuance + if isinstance(texts, str): + texts = [texts] + texts = [whitespace_clean(basic_clean(text)) for text in texts] + input_ids = self.tokenizer( + texts, + return_tensors='pt', + max_length=context_length, + padding='max_length', + truncation=True, + ).input_ids + return input_ids diff --git a/ext/open_clip/transform.py b/ext/open_clip/transform.py new file mode 100644 index 0000000000000000000000000000000000000000..748884a3c7cb7ece1ca521ca1dbf40bb74855007 --- /dev/null +++ b/ext/open_clip/transform.py @@ -0,0 +1,133 @@ +import warnings +from dataclasses import dataclass, asdict +from typing import Any, Dict, Optional, Sequence, Tuple, Union + +import torch +import torch.nn as nn +import torchvision.transforms.functional as F + +from torchvision.transforms import Normalize, Compose, RandomResizedCrop, InterpolationMode, ToTensor, Resize, \ + CenterCrop + +from .constants import OPENAI_DATASET_MEAN, OPENAI_DATASET_STD + + +@dataclass +class AugmentationCfg: + scale: Tuple[float, float] = (0.9, 1.0) + ratio: Optional[Tuple[float, float]] = None + color_jitter: Optional[Union[float, Tuple[float, float, float]]] = None + interpolation: Optional[str] = None + re_prob: Optional[float] = None + re_count: Optional[int] = None + use_timm: bool = False + + +class ResizeMaxSize(nn.Module): + + def __init__(self, max_size, interpolation=InterpolationMode.BICUBIC, fn='max', fill=0): + super().__init__() + if not isinstance(max_size, int): + raise TypeError(f"Size should be int. 
Got {type(max_size)}") + self.max_size = max_size + self.interpolation = interpolation + self.fn = min if fn == 'min' else min + self.fill = fill + + def forward(self, img): + if isinstance(img, torch.Tensor): + height, width = img.shape[:2] + else: + width, height = img.size + scale = self.max_size / float(max(height, width)) + if scale != 1.0: + new_size = tuple(round(dim * scale) for dim in (height, width)) + img = F.resize(img, new_size, self.interpolation) + pad_h = self.max_size - new_size[0] + pad_w = self.max_size - new_size[1] + img = F.pad(img, padding=[pad_w//2, pad_h//2, pad_w - pad_w//2, pad_h - pad_h//2], fill=self.fill) + return img + + +def _convert_to_rgb(image): + return image.convert('RGB') + + +def image_transform( + image_size: int, + is_train: bool, + mean: Optional[Tuple[float, ...]] = None, + std: Optional[Tuple[float, ...]] = None, + resize_longest_max: bool = False, + fill_color: int = 0, + aug_cfg: Optional[Union[Dict[str, Any], AugmentationCfg]] = None, +): + mean = mean or OPENAI_DATASET_MEAN + if not isinstance(mean, (list, tuple)): + mean = (mean,) * 3 + + std = std or OPENAI_DATASET_STD + if not isinstance(std, (list, tuple)): + std = (std,) * 3 + + if isinstance(image_size, (list, tuple)) and image_size[0] == image_size[1]: + # for square size, pass size as int so that Resize() uses aspect preserving shortest edge + image_size = image_size[0] + + if isinstance(aug_cfg, dict): + aug_cfg = AugmentationCfg(**aug_cfg) + else: + aug_cfg = aug_cfg or AugmentationCfg() + normalize = Normalize(mean=mean, std=std) + if is_train: + aug_cfg_dict = {k: v for k, v in asdict(aug_cfg).items() if v is not None} + use_timm = aug_cfg_dict.pop('use_timm', False) + if use_timm: + from timm.data import create_transform # timm can still be optional + if isinstance(image_size, (tuple, list)): + assert len(image_size) >= 2 + input_size = (3,) + image_size[-2:] + else: + input_size = (3, image_size, image_size) + # by default, timm aug randomly alternates bicubic & bilinear for better robustness at inference time + aug_cfg_dict.setdefault('interpolation', 'random') + aug_cfg_dict.setdefault('color_jitter', None) # disable by default + train_transform = create_transform( + input_size=input_size, + is_training=True, + hflip=0., + mean=mean, + std=std, + re_mode='pixel', + **aug_cfg_dict, + ) + else: + train_transform = Compose([ + RandomResizedCrop( + image_size, + scale=aug_cfg_dict.pop('scale'), + interpolation=InterpolationMode.BICUBIC, + ), + _convert_to_rgb, + ToTensor(), + normalize, + ]) + if aug_cfg_dict: + warnings.warn(f'Unused augmentation cfg items, specify `use_timm` to use ({list(aug_cfg_dict.keys())}).') + return train_transform + else: + if resize_longest_max: + transforms = [ + ResizeMaxSize(image_size, fill=fill_color) + ] + else: + transforms = [ + Resize(image_size, interpolation=InterpolationMode.BICUBIC), + CenterCrop(image_size), + ] + transforms.extend([ + _convert_to_rgb, + ToTensor(), + normalize, + ]) + return Compose(transforms) diff --git a/ext/open_clip/transformer.py b/ext/open_clip/transformer.py new file mode 100644 index 0000000000000000000000000000000000000000..0a30e94664a2dd890a373eb0a0f640818836baaa --- /dev/null +++ b/ext/open_clip/transformer.py @@ -0,0 +1,726 @@ +from collections import OrderedDict +import math +from typing import Callable, Optional, Sequence, Tuple + +import torch +from torch import nn +from torch.nn import functional as F +from torch.utils.checkpoint import checkpoint + +from .utils import to_2tuple + + +class 
LayerNormFp32(nn.LayerNorm): + """Subclass torch's LayerNorm to handle fp16 (by casting to float32 and back).""" + + def forward(self, x: torch.Tensor): + orig_type = x.dtype + x = F.layer_norm(x.to(torch.float32), self.normalized_shape, self.weight, self.bias, self.eps) + return x.to(orig_type) + + +class LayerNorm(nn.LayerNorm): + """Subclass torch's LayerNorm (with cast back to input dtype).""" + + def forward(self, x: torch.Tensor): + orig_type = x.dtype + x = F.layer_norm(x, self.normalized_shape, self.weight, self.bias, self.eps) + return x.to(orig_type) + + +class QuickGELU(nn.Module): + # NOTE This is slower than nn.GELU or nn.SiLU and uses more GPU memory + def forward(self, x: torch.Tensor): + return x * torch.sigmoid(1.702 * x) + + +class LayerScale(nn.Module): + def __init__(self, dim, init_values=1e-5, inplace=False): + super().__init__() + self.inplace = inplace + self.gamma = nn.Parameter(init_values * torch.ones(dim)) + + def forward(self, x): + return x.mul_(self.gamma) if self.inplace else x * self.gamma + + +class PatchDropout(nn.Module): + """ + https://arxiv.org/abs/2212.00794 + """ + + def __init__(self, prob, exclude_first_token=True): + super().__init__() + assert 0 <= prob < 1. + self.prob = prob + self.exclude_first_token = exclude_first_token # exclude CLS token + + def forward(self, x): + if not self.training or self.prob == 0.: + return x + + if self.exclude_first_token: + cls_tokens, x = x[:, :1], x[:, 1:] + else: + cls_tokens = torch.jit.annotate(torch.Tensor, x[:, :1]) + + batch = x.size()[0] + num_tokens = x.size()[1] + + batch_indices = torch.arange(batch) + batch_indices = batch_indices[..., None] + + keep_prob = 1 - self.prob + num_patches_keep = max(1, int(num_tokens * keep_prob)) + + rand = torch.randn(batch, num_tokens) + patch_indices_keep = rand.topk(num_patches_keep, dim=-1).indices + + x = x[batch_indices, patch_indices_keep] + + if self.exclude_first_token: + x = torch.cat((cls_tokens, x), dim=1) + + return x + + +class Attention(nn.Module): + def __init__( + self, + dim, + num_heads=8, + qkv_bias=True, + scaled_cosine=False, + scale_heads=False, + logit_scale_max=math.log(1. / 0.01), + attn_drop=0., + proj_drop=0. 
+ ): + super().__init__() + self.scaled_cosine = scaled_cosine + self.scale_heads = scale_heads + assert dim % num_heads == 0, 'dim should be divisible by num_heads' + self.num_heads = num_heads + self.head_dim = dim // num_heads + self.scale = self.head_dim ** -0.5 + self.logit_scale_max = logit_scale_max + + # keeping in_proj in this form (instead of nn.Linear) to match weight scheme of original + self.in_proj_weight = nn.Parameter(torch.randn((dim * 3, dim)) * self.scale) + if qkv_bias: + self.in_proj_bias = nn.Parameter(torch.zeros(dim * 3)) + else: + self.in_proj_bias = None + + if self.scaled_cosine: + self.logit_scale = nn.Parameter(torch.log(10 * torch.ones((num_heads, 1, 1)))) + else: + self.logit_scale = None + self.attn_drop = nn.Dropout(attn_drop) + if self.scale_heads: + self.head_scale = nn.Parameter(torch.ones((num_heads, 1, 1))) + else: + self.head_scale = None + self.out_proj = nn.Linear(dim, dim) + self.out_drop = nn.Dropout(proj_drop) + + def forward(self, x, attn_mask: Optional[torch.Tensor] = None): + L, N, C = x.shape + q, k, v = F.linear(x, self.in_proj_weight, self.in_proj_bias).chunk(3, dim=-1) + q = q.contiguous().view(L, N * self.num_heads, -1).transpose(0, 1) + k = k.contiguous().view(L, N * self.num_heads, -1).transpose(0, 1) + v = v.contiguous().view(L, N * self.num_heads, -1).transpose(0, 1) + + if self.logit_scale is not None: + attn = torch.bmm(F.normalize(q, dim=-1), F.normalize(k, dim=-1).transpose(-1, -2)) + logit_scale = torch.clamp(self.logit_scale, max=self.logit_scale_max).exp() + attn = attn.view(N, self.num_heads, L, L) * logit_scale + attn = attn.view(-1, L, L) + else: + q = q * self.scale + attn = torch.bmm(q, k.transpose(-1, -2)) + + if attn_mask is not None: + if attn_mask.dtype == torch.bool: + new_attn_mask = torch.zeros_like(attn_mask, dtype=q.dtype) + new_attn_mask.masked_fill_(attn_mask, float("-inf")) + attn_mask = new_attn_mask + attn += attn_mask + + attn = attn.softmax(dim=-1) + attn = self.attn_drop(attn) + + x = torch.bmm(attn, v) + if self.head_scale is not None: + x = x.view(N, self.num_heads, L, C) * self.head_scale + x = x.view(-1, L, C) + x = x.transpose(0, 1).reshape(L, N, C) + x = self.out_proj(x) + x = self.out_drop(x) + return x + + +class AttentionalPooler(nn.Module): + def __init__( + self, + d_model: int, + context_dim: int, + n_head: int = 8, + n_queries: int = 256, + norm_layer: Callable = LayerNorm + ): + super().__init__() + self.query = nn.Parameter(torch.randn(n_queries, d_model)) + self.attn = nn.MultiheadAttention(d_model, n_head, kdim=context_dim, vdim=context_dim) + self.ln_q = norm_layer(d_model) + self.ln_k = norm_layer(context_dim) + + def forward(self, x: torch.Tensor): + x = self.ln_k(x).permute(1, 0, 2) # NLD -> LND + N = x.shape[1] + q = self.ln_q(self.query) + out = self.attn(self._repeat(q, N), x, x, need_weights=False)[0] + return out.permute(1, 0, 2) # LND -> NLD + + def _repeat(self, query, N: int): + return query.unsqueeze(1).repeat(1, N, 1) + + +class ResidualAttentionBlock(nn.Module): + def __init__( + self, + d_model: int, + n_head: int, + mlp_ratio: float = 4.0, + ls_init_value: float = None, + act_layer: Callable = nn.GELU, + norm_layer: Callable = LayerNorm, + is_cross_attention: bool = False, + ): + super().__init__() + + self.ln_1 = norm_layer(d_model) + self.attn = nn.MultiheadAttention(d_model, n_head) + self.ls_1 = LayerScale(d_model, ls_init_value) if ls_init_value is not None else nn.Identity() + if is_cross_attention: + self.ln_1_kv = norm_layer(d_model) + + self.ln_2 = 
norm_layer(d_model) + mlp_width = int(d_model * mlp_ratio) + self.mlp = nn.Sequential(OrderedDict([ + ("c_fc", nn.Linear(d_model, mlp_width)), + ("gelu", act_layer()), + ("c_proj", nn.Linear(mlp_width, d_model)) + ])) + self.ls_2 = LayerScale(d_model, ls_init_value) if ls_init_value is not None else nn.Identity() + + def attention( + self, + q_x: torch.Tensor, + k_x: Optional[torch.Tensor] = None, + v_x: Optional[torch.Tensor] = None, + attn_mask: Optional[torch.Tensor] = None, + ): + k_x = k_x if k_x is not None else q_x + v_x = v_x if v_x is not None else q_x + + attn_mask = attn_mask.to(q_x.dtype) if attn_mask is not None else None + return self.attn( + q_x, k_x, v_x, need_weights=False, attn_mask=attn_mask + )[0] + + def forward( + self, + q_x: torch.Tensor, + k_x: Optional[torch.Tensor] = None, + v_x: Optional[torch.Tensor] = None, + attn_mask: Optional[torch.Tensor] = None, + ): + k_x = self.ln_1_kv(k_x) if hasattr(self, "ln_1_kv") and k_x is not None else None + v_x = self.ln_1_kv(v_x) if hasattr(self, "ln_1_kv") and v_x is not None else None + + x = q_x + self.ls_1(self.attention(q_x=self.ln_1(q_x), k_x=k_x, v_x=v_x, attn_mask=attn_mask)) + x = x + self.ls_2(self.mlp(self.ln_2(x))) + return x + + +class CustomResidualAttentionBlock(nn.Module): + def __init__( + self, + d_model: int, + n_head: int, + mlp_ratio: float = 4.0, + ls_init_value: float = None, + act_layer: Callable = nn.GELU, + norm_layer: Callable = LayerNorm, + scale_cosine_attn: bool = False, + scale_heads: bool = False, + scale_attn: bool = False, + scale_fc: bool = False, + ): + super().__init__() + + self.ln_1 = norm_layer(d_model) + self.attn = Attention( + d_model, n_head, + scaled_cosine=scale_cosine_attn, + scale_heads=scale_heads, + ) + self.ln_attn = norm_layer(d_model) if scale_attn else nn.Identity() + self.ls_1 = LayerScale(d_model, ls_init_value) if ls_init_value is not None else nn.Identity() + + self.ln_2 = norm_layer(d_model) + mlp_width = int(d_model * mlp_ratio) + self.mlp = nn.Sequential(OrderedDict([ + ("c_fc", nn.Linear(d_model, mlp_width)), + ('ln', norm_layer(mlp_width) if scale_fc else nn.Identity()), + ("gelu", act_layer()), + ("c_proj", nn.Linear(mlp_width, d_model)) + ])) + self.ls_2 = LayerScale(d_model, ls_init_value) if ls_init_value is not None else nn.Identity() + + def forward(self, x: torch.Tensor, attn_mask: Optional[torch.Tensor] = None): + x = x + self.ls_1(self.ln_attn(self.attn(self.ln_1(x), attn_mask=attn_mask))) + x = x + self.ls_2(self.mlp(self.ln_2(x))) + return x + + +class Transformer(nn.Module): + def __init__( + self, + width: int, + layers: int, + heads: int, + mlp_ratio: float = 4.0, + ls_init_value: float = None, + act_layer: Callable = nn.GELU, + norm_layer: Callable = LayerNorm, + ): + super().__init__() + self.width = width + self.layers = layers + self.grad_checkpointing = False + + self.resblocks = nn.ModuleList([ + ResidualAttentionBlock( + width, heads, mlp_ratio, ls_init_value=ls_init_value, act_layer=act_layer, norm_layer=norm_layer) + for _ in range(layers) + ]) + + def get_cast_dtype(self) -> torch.dtype: + if hasattr(self.resblocks[0].mlp.c_fc, 'int8_original_dtype'): + return self.resblocks[0].mlp.c_fc.int8_original_dtype + return self.resblocks[0].mlp.c_fc.weight.dtype + + def forward(self, x: torch.Tensor, attn_mask: Optional[torch.Tensor] = None): + for r in self.resblocks: + if self.grad_checkpointing and not torch.jit.is_scripting(): + # TODO: handle kwargs https://github.com/pytorch/pytorch/issues/79887#issuecomment-1161758372 + x = checkpoint(r, x, 
None, None, attn_mask) + else: + x = r(x, attn_mask=attn_mask) + return x + + +class VisionTransformer(nn.Module): + output_tokens: torch.jit.Final[bool] + + def __init__( + self, + image_size: int, + patch_size: int, + width: int, + layers: int, + heads: int, + mlp_ratio: float, + ls_init_value: float = None, + global_average_pool: bool = False, + attentional_pool: bool = False, + n_queries: int = 256, + attn_pooler_heads: int = 8, + output_dim: int = 512, + patch_dropout: float = 0., + input_patchnorm: bool = False, + act_layer: Callable = nn.GELU, + norm_layer: Callable = LayerNorm, + output_tokens: bool = False + ): + super().__init__() + self.output_tokens = output_tokens + image_height, image_width = self.image_size = to_2tuple(image_size) + patch_height, patch_width = self.patch_size = to_2tuple(patch_size) + self.grid_size = (image_height // patch_height, image_width // patch_width) + self.output_dim = output_dim + + # whether to layernorm each patch, as done in dual patchnorm paper - https://arxiv.org/abs/2302.01327v1 + self.input_patchnorm = input_patchnorm + + if input_patchnorm: + patch_input_dim = patch_height * patch_width * 3 + self.patchnorm_pre_ln = LayerNorm(patch_input_dim) + self.conv1 = nn.Linear(patch_input_dim, width) + else: + self.patchnorm_pre_ln = nn.Identity() + self.conv1 = nn.Conv2d(in_channels=3, out_channels=width, kernel_size=patch_size, stride=patch_size, bias=False) + + # class embeddings and positional embeddings + scale = width ** -0.5 + self.class_embedding = nn.Parameter(scale * torch.randn(width)) + self.positional_embedding = nn.Parameter(scale * torch.randn(self.grid_size[0] * self.grid_size[1] + 1, width)) + + # setting a patch_dropout of 0. would mean it is disabled and this function would be the identity fn + self.patch_dropout = PatchDropout(patch_dropout) if patch_dropout > 0. else nn.Identity() + + self.ln_pre = norm_layer(width) + self.transformer = Transformer( + width, + layers, + heads, + mlp_ratio, + ls_init_value=ls_init_value, + act_layer=act_layer, + norm_layer=norm_layer, + ) + + self.global_average_pool = global_average_pool + if attentional_pool: + self.attn_pool = AttentionalPooler(output_dim, width, n_head=attn_pooler_heads, n_queries=n_queries) + self.ln_post = norm_layer(output_dim) + self.proj = nn.Parameter(scale * torch.randn(output_dim, output_dim)) + else: + self.attn_pool = None + self.ln_post = norm_layer(width) + self.proj = nn.Parameter(scale * torch.randn(width, output_dim)) + + self.init_parameters() + + def lock(self, unlocked_groups=0, freeze_bn_stats=False): + for param in self.parameters(): + param.requires_grad = False + + if unlocked_groups != 0: + groups = [ + [ + self.conv1, + self.class_embedding, + self.positional_embedding, + self.ln_pre, + ], + *self.transformer.resblocks[:-1], + [ + self.transformer.resblocks[-1], + self.ln_post, + ], + self.proj, + ] + + def _unlock(x): + if isinstance(x, Sequence): + for g in x: + _unlock(g) + else: + if isinstance(x, torch.nn.Parameter): + x.requires_grad = True + else: + for p in x.parameters(): + p.requires_grad = True + + _unlock(groups[-unlocked_groups:]) + + def init_parameters(self): + # FIXME OpenAI CLIP did not define an init for the VisualTransformer + # TODO experiment if default PyTorch init, below, or alternate init is best. 
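+        # Note: the commented-out block below mirrors the width/depth-scaled normal
+        # init used by TextTransformer.init_parameters later in this file; with it
+        # disabled, the visual tower keeps PyTorch's default initialization.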
+ + # nn.init.normal_(self.class_embedding, std=self.scale) + # nn.init.normal_(self.positional_embedding, std=self.scale) + # + # proj_std = (self.transformer.width ** -0.5) * ((2 * self.transformer.layers) ** -0.5) + # attn_std = self.transformer.width ** -0.5 + # fc_std = (2 * self.transformer.width) ** -0.5 + # for block in self.transformer.resblocks: + # nn.init.normal_(block.attn.in_proj_weight, std=attn_std) + # nn.init.normal_(block.attn.out_proj.weight, std=proj_std) + # nn.init.normal_(block.mlp.c_fc.weight, std=fc_std) + # nn.init.normal_(block.mlp.c_proj.weight, std=proj_std) + # + # if self.text_projection is not None: + # nn.init.normal_(self.text_projection, std=self.scale) + pass + + @torch.jit.ignore + def set_grad_checkpointing(self, enable=True): + self.transformer.grad_checkpointing = enable + + def _global_pool(self, x: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]: + if self.global_average_pool: + return x.mean(dim=1), x + else: + return x[:, 0], x[:, 1:] + + def forward(self, x: torch.Tensor): + + # to patches - whether to use dual patchnorm - https://arxiv.org/abs/2302.01327v1 + if self.input_patchnorm: + # einops - rearrange(x, 'b c (h p1) (w p2) -> b (h w) (c p1 p2)') + x = x.reshape(x.shape[0], x.shape[1], self.grid_size[0], self.patch_size[0], self.grid_size[1], self.patch_size[1]) + x = x.permute(0, 2, 4, 1, 3, 5) + x = x.reshape(x.shape[0], self.grid_size[0] * self.grid_size[1], -1) + x = self.patchnorm_pre_ln(x) + x = self.conv1(x) + else: + x = self.conv1(x) # shape = [*, width, grid, grid] + x = x.reshape(x.shape[0], x.shape[1], -1) # shape = [*, width, grid ** 2] + x = x.permute(0, 2, 1) # shape = [*, grid ** 2, width] + + # class embeddings and positional embeddings + x = torch.cat( + [self.class_embedding.to(x.dtype) + torch.zeros(x.shape[0], 1, x.shape[-1], dtype=x.dtype, device=x.device), + x], dim=1) # shape = [*, grid ** 2 + 1, width] + x = x + self.positional_embedding.to(x.dtype) + + # a patch_dropout of 0. 
would mean it is disabled and this function would do nothing but return what was passed in + x = self.patch_dropout(x) + x = self.ln_pre(x) + + x = x.permute(1, 0, 2) # NLD -> LND + x = self.transformer(x) + x = x.permute(1, 0, 2) # LND -> NLD + + if self.attn_pool is not None: + x = self.attn_pool(x) + x = self.ln_post(x) + pooled, tokens = self._global_pool(x) + else: + pooled, tokens = self._global_pool(x) + pooled = self.ln_post(pooled) + + if self.proj is not None: + pooled = pooled @ self.proj + + if self.output_tokens: + return pooled, tokens + + return pooled + + +class TextTransformer(nn.Module): + output_tokens: torch.jit.Final[bool] + + def __init__( + self, + context_length: int = 77, + vocab_size: int = 49408, + width: int = 512, + heads: int = 8, + layers: int = 12, + ls_init_value: float = None, + output_dim: int = 512, + act_layer: Callable = nn.GELU, + norm_layer: Callable = LayerNorm, + embed_cls: bool = False, + pad_id: int = 0, + output_tokens: bool = False, + ): + super().__init__() + self.output_tokens = output_tokens + self.num_pos = self.context_length = context_length + self.vocab_size = vocab_size + self.width = width + self.output_dim = output_dim + self.heads = heads + self.pad_id = pad_id + + self.text_projection = nn.Parameter(torch.empty(width, output_dim)) + + if embed_cls: + self.cls_emb = nn.Parameter(torch.empty(width)) + self.num_pos += 1 + else: + self.cls_emb = None + + self.token_embedding = nn.Embedding(vocab_size, width) + self.positional_embedding = nn.Parameter(torch.empty(self.num_pos, width)) + self.transformer = Transformer( + width=width, + layers=layers, + heads=heads, + ls_init_value=ls_init_value, + act_layer=act_layer, + norm_layer=norm_layer, + ) + self.ln_final = norm_layer(width) + + self.register_buffer('attn_mask', self.build_attention_mask(), persistent=False) + + self.init_parameters() + + def init_parameters(self): + nn.init.normal_(self.token_embedding.weight, std=0.02) + nn.init.normal_(self.positional_embedding, std=0.01) + if self.cls_emb is not None: + nn.init.normal_(self.cls_emb, std=0.01) + + proj_std = (self.transformer.width ** -0.5) * ((2 * self.transformer.layers) ** -0.5) + attn_std = self.transformer.width ** -0.5 + fc_std = (2 * self.transformer.width) ** -0.5 + for block in self.transformer.resblocks: + nn.init.normal_(block.attn.in_proj_weight, std=attn_std) + nn.init.normal_(block.attn.out_proj.weight, std=proj_std) + nn.init.normal_(block.mlp.c_fc.weight, std=fc_std) + nn.init.normal_(block.mlp.c_proj.weight, std=proj_std) + + if self.text_projection is not None: + nn.init.normal_(self.text_projection, std=self.transformer.width ** -0.5) + + @torch.jit.ignore + def set_grad_checkpointing(self, enable=True): + self.transformer.grad_checkpointing = enable + + def build_attention_mask(self): + # lazily create causal attention mask, with full attention between the tokens + # pytorch uses additive attention mask; fill with -inf + mask = torch.empty(self.num_pos, self.num_pos) + mask.fill_(float("-inf")) + mask.triu_(1) # zero out the lower diagonal + return mask + + def build_cls_mask(self, text, cast_dtype: torch.dtype): + cls_mask = (text != self.pad_id).unsqueeze(1) + cls_mask = F.pad(cls_mask, (1, 0, cls_mask.shape[2], 0), value=1.0) + additive_mask = torch.empty(cls_mask.shape, dtype=cast_dtype, device=cls_mask.device) + additive_mask.fill_(0) + additive_mask.masked_fill_(~cls_mask, float("-inf")) + additive_mask = torch.repeat_interleave(additive_mask, self.heads, 0) + return additive_mask + + def _repeat(self, 
t, N: int): + return t.reshape(1, 1, -1).repeat(N, 1, 1) + + def forward(self, text): + cast_dtype = self.transformer.get_cast_dtype() + seq_len = text.shape[1] + + x = self.token_embedding(text).to(cast_dtype) # [batch_size, n_ctx, d_model] + attn_mask = self.attn_mask + if self.cls_emb is not None: + seq_len += 1 + x = torch.cat([x, self._repeat(self.cls_emb, x.shape[0])], dim=1) + cls_mask = self.build_cls_mask(text, cast_dtype) + attn_mask = attn_mask[None, :seq_len, :seq_len] + cls_mask[:, :seq_len, :seq_len] + + x = x + self.positional_embedding[:seq_len].to(cast_dtype) + x = x.permute(1, 0, 2) # NLD -> LND + x = self.transformer(x, attn_mask=attn_mask) + x = x.permute(1, 0, 2) # LND -> NLD + + # x.shape = [batch_size, n_ctx, transformer.width] + # take features from the eot embedding (eot_token is the highest number in each sequence) + if self.cls_emb is not None: + pooled, tokens = x[:, -1], x[:, :-1] + pooled = self.ln_final(pooled) + else: + x = self.ln_final(x) + pooled, tokens = x[torch.arange(x.shape[0]), text.argmax(dim=-1)], x + + if self.text_projection is not None: + pooled = pooled @ self.text_projection + + if self.output_tokens: + return pooled, tokens + + return pooled + + +class MultimodalTransformer(Transformer): + def __init__( + self, + width: int, + layers: int, + heads: int, + context_length: int = 77, + mlp_ratio: float = 4.0, + ls_init_value: float = None, + act_layer: Callable = nn.GELU, + norm_layer: Callable = LayerNorm, + output_dim: int = 512, + ): + + super().__init__( + width=width, + layers=layers, + heads=heads, + mlp_ratio=mlp_ratio, + ls_init_value=ls_init_value, + act_layer=act_layer, + norm_layer=norm_layer, + ) + self.context_length = context_length + self.cross_attn = nn.ModuleList([ + ResidualAttentionBlock( + width, + heads, + mlp_ratio, + ls_init_value=ls_init_value, + act_layer=act_layer, + norm_layer=norm_layer, + is_cross_attention=True, + ) + for _ in range(layers) + ]) + + self.register_buffer('attn_mask', self.build_attention_mask(), persistent=False) + + self.ln_final = norm_layer(width) + self.text_projection = nn.Parameter(torch.empty(width, output_dim)) + + def init_parameters(self): + proj_std = (self.transformer.width ** -0.5) * ((2 * self.transformer.layers) ** -0.5) + attn_std = self.transformer.width ** -0.5 + fc_std = (2 * self.transformer.width) ** -0.5 + for block in self.transformer.resblocks: + nn.init.normal_(block.attn.in_proj_weight, std=attn_std) + nn.init.normal_(block.attn.out_proj.weight, std=proj_std) + nn.init.normal_(block.mlp.c_fc.weight, std=fc_std) + nn.init.normal_(block.mlp.c_proj.weight, std=proj_std) + for block in self.transformer.cross_attn: + nn.init.normal_(block.attn.in_proj_weight, std=attn_std) + nn.init.normal_(block.attn.out_proj.weight, std=proj_std) + nn.init.normal_(block.mlp.c_fc.weight, std=fc_std) + nn.init.normal_(block.mlp.c_proj.weight, std=proj_std) + + if self.text_projection is not None: + nn.init.normal_(self.text_projection, std=self.transformer.width ** -0.5) + + def build_attention_mask(self): + # lazily create causal attention mask, with full attention between the tokens + # pytorch uses additive attention mask; fill with -inf + mask = torch.empty(self.context_length, self.context_length) + mask.fill_(float("-inf")) + mask.triu_(1) # zero out the lower diagonal + return mask + + def forward(self, image_embs, text_embs): + text_embs = text_embs.permute(1, 0, 2) # NLD -> LNDsq + image_embs = image_embs.permute(1, 0, 2) # NLD -> LND + seq_len = text_embs.shape[0] + + for resblock, 
cross_attn in zip(self.resblocks, self.cross_attn): + if self.grad_checkpointing and not torch.jit.is_scripting(): + # TODO: handle kwargs https://github.com/pytorch/pytorch/issues/79887#issuecomment-1161758372 + text_embs = checkpoint(resblock, text_embs, None, None, self.attn_mask[:seq_len, :seq_len]) + text_embs = checkpoint(cross_attn, text_embs, image_embs, image_embs, None) + else: + text_embs = resblock(text_embs, attn_mask=self.attn_mask[:seq_len, :seq_len]) + text_embs = cross_attn(text_embs, k_x=image_embs, v_x=image_embs) + + x = text_embs.permute(1, 0, 2) # LND -> NLD + x = self.ln_final(x) + + if self.text_projection is not None: + x = x @ self.text_projection + + return x + + @torch.jit.ignore + def set_grad_checkpointing(self, enable=True): + self.grad_checkpointing = enable diff --git a/ext/open_clip/utils.py b/ext/open_clip/utils.py new file mode 100644 index 0000000000000000000000000000000000000000..bb0bb8868ae1f2d31493ca32b73accd6bf1d3cdb --- /dev/null +++ b/ext/open_clip/utils.py @@ -0,0 +1,89 @@ +from itertools import repeat +import collections.abc + +import torch +from torch import nn as nn +from torchvision.ops.misc import FrozenBatchNorm2d + + +def freeze_batch_norm_2d(module, module_match={}, name=''): + """ + Converts all `BatchNorm2d` and `SyncBatchNorm` layers of provided module into `FrozenBatchNorm2d`. If `module` is + itself an instance of either `BatchNorm2d` or `SyncBatchNorm`, it is converted into `FrozenBatchNorm2d` and + returned. Otherwise, the module is walked recursively and submodules are converted in place. + + Args: + module (torch.nn.Module): Any PyTorch module. + module_match (dict): Dictionary of full module names to freeze (all if empty) + name (str): Full module name (prefix) + + Returns: + torch.nn.Module: Resulting module + + Inspired by https://github.com/pytorch/pytorch/blob/a5895f85be0f10212791145bfedc0261d364f103/torch/nn/modules/batchnorm.py#L762 + """ + res = module + is_match = True + if module_match: + is_match = name in module_match + if is_match and isinstance(module, (nn.modules.batchnorm.BatchNorm2d, nn.modules.batchnorm.SyncBatchNorm)): + res = FrozenBatchNorm2d(module.num_features) + res.num_features = module.num_features + res.affine = module.affine + if module.affine: + res.weight.data = module.weight.data.clone().detach() + res.bias.data = module.bias.data.clone().detach() + res.running_mean.data = module.running_mean.data + res.running_var.data = module.running_var.data + res.eps = module.eps + else: + for child_name, child in module.named_children(): + full_child_name = '.'.join([name, child_name]) if name else child_name + new_child = freeze_batch_norm_2d(child, module_match, full_child_name) + if new_child is not child: + res.add_module(child_name, new_child) + return res + + +# From PyTorch internals +def _ntuple(n): + def parse(x): + if isinstance(x, collections.abc.Iterable): + return x + return tuple(repeat(x, n)) + return parse + + +to_1tuple = _ntuple(1) +to_2tuple = _ntuple(2) +to_3tuple = _ntuple(3) +to_4tuple = _ntuple(4) +to_ntuple = lambda n, x: _ntuple(n)(x) + +# Replaces all linear layers with linear_replacement +# TODO: add int8 support for other linear layers including attn and convnets +def replace_linear(model, linear_replacement, include_modules=['c_fc', 'c_proj'], copy_weights=True): + for name, module in model.named_children(): + if len(list(module.children())) > 0: + replace_linear(module, linear_replacement, include_modules, copy_weights) + + if isinstance(module, torch.nn.Linear) and name in 
include_modules: + old_module = model._modules[name] + model._modules[name] = linear_replacement( + module.in_features, + module.out_features, + module.bias is not None, + ) + if copy_weights: + model._modules[name].weight.data.copy_(old_module.weight.data) + if model._modules[name].bias is not None: + model._modules[name].bias.data.copy_(old_module.bias) + + return model + +def convert_int8_model_to_inference_mode(model): + for m in model.modules(): + if hasattr(m, 'prepare_for_eval'): + int8_original_dtype = m.weight.dtype + m.prepare_for_eval() + m.int8_original_dtype = int8_original_dtype \ No newline at end of file diff --git a/ext/open_clip/version.py b/ext/open_clip/version.py new file mode 100644 index 0000000000000000000000000000000000000000..a910817da22d06aa0244c6d488b40d30da2bfb7e --- /dev/null +++ b/ext/open_clip/version.py @@ -0,0 +1 @@ +__version__ = '2.20.0' diff --git a/ext/open_clip/zero_shot_classifier.py b/ext/open_clip/zero_shot_classifier.py new file mode 100644 index 0000000000000000000000000000000000000000..12b58f65bb0875b164946a9ee73e938aef255382 --- /dev/null +++ b/ext/open_clip/zero_shot_classifier.py @@ -0,0 +1,110 @@ +from functools import partial +from itertools import islice +from typing import Callable, List, Optional, Sequence, Union + +import torch +import torch.nn.functional as F + + +def batched(iterable, n): + """Batch data into lists of length *n*. The last batch may be shorter. + NOTE based on more-itertools impl, to be replaced by python 3.12 itertools.batched impl + """ + it = iter(iterable) + while True: + batch = list(islice(it, n)) + if not batch: + break + yield batch + + +def build_zero_shot_classifier( + model, + tokenizer, + classnames: Sequence[str], + templates: Sequence[Union[Callable, str]], + num_classes_per_batch: Optional[int] = 10, + device: Union[str, torch.device] = 'cpu', + use_tqdm: bool = False, +): + """ Build zero-shot classifier weights by iterating over class names in batches + Args: + model: CLIP model instance + tokenizer: CLIP tokenizer instance + classnames: A sequence of class (label) names + templates: A sequence of callables or format() friendly strings to produce templates per class name + num_classes_per_batch: The number of classes to batch together in each forward, all if None + device: Device to use. + use_tqdm: Enable TQDM progress bar. 
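+
+    Example (illustrative sketch, not part of the upstream docs; the model name is a
+    placeholder and the import path depends on where this copy of open_clip lives,
+    e.g. ``ext.open_clip`` in this repo):
+
+        import open_clip
+        from open_clip.zero_shot_metadata import IMAGENET_CLASSNAMES, OPENAI_IMAGENET_TEMPLATES
+
+        model, _, _ = open_clip.create_model_and_transforms('ViT-B-32')
+        tokenizer = open_clip.get_tokenizer('ViT-B-32')
+        classifier = build_zero_shot_classifier(
+            model, tokenizer,
+            classnames=IMAGENET_CLASSNAMES,
+            templates=OPENAI_IMAGENET_TEMPLATES,
+            num_classes_per_batch=10,
+            device='cpu',
+        )
+        # classifier: (embed_dim, num_classes); apply as logits = image_features @ classifier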
+ """ + assert isinstance(templates, Sequence) and len(templates) > 0 + assert isinstance(classnames, Sequence) and len(classnames) > 0 + use_format = isinstance(templates[0], str) + num_templates = len(templates) + num_classes = len(classnames) + if use_tqdm: + import tqdm + num_iter = 1 if num_classes_per_batch is None else ((num_classes - 1) // num_classes_per_batch + 1) + iter_wrap = partial(tqdm.tqdm, total=num_iter, unit_scale=num_classes_per_batch) + else: + iter_wrap = iter + + def _process_batch(batch_classnames): + num_batch_classes = len(batch_classnames) + texts = [template.format(c) if use_format else template(c) for c in batch_classnames for template in templates] + texts = tokenizer(texts).to(device) + class_embeddings = F.normalize(model.encode_text(texts), dim=-1) + class_embeddings = class_embeddings.reshape(num_batch_classes, num_templates, -1).mean(dim=1) + class_embeddings = class_embeddings / class_embeddings.norm(dim=1, keepdim=True) + class_embeddings = class_embeddings.T + return class_embeddings + + with torch.no_grad(): + if num_classes_per_batch: + batched_embeds = [_process_batch(batch) for batch in iter_wrap(batched(classnames, num_classes_per_batch))] + zeroshot_weights = torch.cat(batched_embeds, dim=1) + else: + zeroshot_weights = _process_batch(classnames) + return zeroshot_weights + + +def build_zero_shot_classifier_legacy( + model, + tokenizer, + classnames: Sequence[str], + templates: Sequence[Union[Callable, str]], + device: Union[str, torch.device] = 'cpu', + use_tqdm: bool = False, +): + """ Build zero-shot classifier weights by iterating over class names 1 by 1 + Args: + model: CLIP model instance + tokenizer: CLIP tokenizer instance + classnames: A sequence of class (label) names + templates: A sequence of callables or format() friendly strings to produce templates per class name + device: Device to use. + use_tqdm: Enable TQDM progress bar. 
+ """ + assert isinstance(templates, Sequence) and len(templates) > 0 + assert isinstance(classnames, Sequence) and len(classnames) > 0 + if use_tqdm: + import tqdm + iter_wrap = tqdm.tqdm + else: + iter_wrap = iter + + use_format = isinstance(templates[0], str) + + with torch.no_grad(): + zeroshot_weights = [] + for classname in iter_wrap(classnames): + texts = [template.format(classname) if use_format else template(classname) for template in templates] + texts = tokenizer(texts).to(device) # tokenize + class_embeddings = model.encode_text(texts) + class_embedding = F.normalize(class_embeddings, dim=-1).mean(dim=0) + class_embedding /= class_embedding.norm() + zeroshot_weights.append(class_embedding) + zeroshot_weights = torch.stack(zeroshot_weights, dim=1).to(device) + + return zeroshot_weights + diff --git a/ext/open_clip/zero_shot_metadata.py b/ext/open_clip/zero_shot_metadata.py new file mode 100644 index 0000000000000000000000000000000000000000..ccb452bbb6e27b71cff1dd27e2bb263259b9363f --- /dev/null +++ b/ext/open_clip/zero_shot_metadata.py @@ -0,0 +1,266 @@ + +OPENAI_IMAGENET_TEMPLATES = ( + lambda c: f'a bad photo of a {c}.', + lambda c: f'a photo of many {c}.', + lambda c: f'a sculpture of a {c}.', + lambda c: f'a photo of the hard to see {c}.', + lambda c: f'a low resolution photo of the {c}.', + lambda c: f'a rendering of a {c}.', + lambda c: f'graffiti of a {c}.', + lambda c: f'a bad photo of the {c}.', + lambda c: f'a cropped photo of the {c}.', + lambda c: f'a tattoo of a {c}.', + lambda c: f'the embroidered {c}.', + lambda c: f'a photo of a hard to see {c}.', + lambda c: f'a bright photo of a {c}.', + lambda c: f'a photo of a clean {c}.', + lambda c: f'a photo of a dirty {c}.', + lambda c: f'a dark photo of the {c}.', + lambda c: f'a drawing of a {c}.', + lambda c: f'a photo of my {c}.', + lambda c: f'the plastic {c}.', + lambda c: f'a photo of the cool {c}.', + lambda c: f'a close-up photo of a {c}.', + lambda c: f'a black and white photo of the {c}.', + lambda c: f'a painting of the {c}.', + lambda c: f'a painting of a {c}.', + lambda c: f'a pixelated photo of the {c}.', + lambda c: f'a sculpture of the {c}.', + lambda c: f'a bright photo of the {c}.', + lambda c: f'a cropped photo of a {c}.', + lambda c: f'a plastic {c}.', + lambda c: f'a photo of the dirty {c}.', + lambda c: f'a jpeg corrupted photo of a {c}.', + lambda c: f'a blurry photo of the {c}.', + lambda c: f'a photo of the {c}.', + lambda c: f'a good photo of the {c}.', + lambda c: f'a rendering of the {c}.', + lambda c: f'a {c} in a video game.', + lambda c: f'a photo of one {c}.', + lambda c: f'a doodle of a {c}.', + lambda c: f'a close-up photo of the {c}.', + lambda c: f'a photo of a {c}.', + lambda c: f'the origami {c}.', + lambda c: f'the {c} in a video game.', + lambda c: f'a sketch of a {c}.', + lambda c: f'a doodle of the {c}.', + lambda c: f'a origami {c}.', + lambda c: f'a low resolution photo of a {c}.', + lambda c: f'the toy {c}.', + lambda c: f'a rendition of the {c}.', + lambda c: f'a photo of the clean {c}.', + lambda c: f'a photo of a large {c}.', + lambda c: f'a rendition of a {c}.', + lambda c: f'a photo of a nice {c}.', + lambda c: f'a photo of a weird {c}.', + lambda c: f'a blurry photo of a {c}.', + lambda c: f'a cartoon {c}.', + lambda c: f'art of a {c}.', + lambda c: f'a sketch of the {c}.', + lambda c: f'a embroidered {c}.', + lambda c: f'a pixelated photo of a {c}.', + lambda c: f'itap of the {c}.', + lambda c: f'a jpeg corrupted photo of the {c}.', + lambda c: f'a good photo of a 
{c}.', + lambda c: f'a plushie {c}.', + lambda c: f'a photo of the nice {c}.', + lambda c: f'a photo of the small {c}.', + lambda c: f'a photo of the weird {c}.', + lambda c: f'the cartoon {c}.', + lambda c: f'art of the {c}.', + lambda c: f'a drawing of the {c}.', + lambda c: f'a photo of the large {c}.', + lambda c: f'a black and white photo of a {c}.', + lambda c: f'the plushie {c}.', + lambda c: f'a dark photo of a {c}.', + lambda c: f'itap of a {c}.', + lambda c: f'graffiti of the {c}.', + lambda c: f'a toy {c}.', + lambda c: f'itap of my {c}.', + lambda c: f'a photo of a cool {c}.', + lambda c: f'a photo of a small {c}.', + lambda c: f'a tattoo of the {c}.', +) + + +# a much smaller subset of above prompts +# from https://github.com/openai/CLIP/blob/main/notebooks/Prompt_Engineering_for_ImageNet.ipynb +SIMPLE_IMAGENET_TEMPLATES = ( + lambda c: f'itap of a {c}.', + lambda c: f'a bad photo of the {c}.', + lambda c: f'a origami {c}.', + lambda c: f'a photo of the large {c}.', + lambda c: f'a {c} in a video game.', + lambda c: f'art of the {c}.', + lambda c: f'a photo of the small {c}.', +) + + +IMAGENET_CLASSNAMES = ( + "tench", "goldfish", "great white shark", "tiger shark", "hammerhead shark", "electric ray", + "stingray", "rooster", "hen", "ostrich", "brambling", "goldfinch", "house finch", "junco", + "indigo bunting", "American robin", "bulbul", "jay", "magpie", "chickadee", "American dipper", + "kite (bird of prey)", "bald eagle", "vulture", "great grey owl", "fire salamander", + "smooth newt", "newt", "spotted salamander", "axolotl", "American bullfrog", "tree frog", + "tailed frog", "loggerhead sea turtle", "leatherback sea turtle", "mud turtle", "terrapin", + "box turtle", "banded gecko", "green iguana", "Carolina anole", + "desert grassland whiptail lizard", "agama", "frilled-necked lizard", "alligator lizard", + "Gila monster", "European green lizard", "chameleon", "Komodo dragon", "Nile crocodile", + "American alligator", "triceratops", "worm snake", "ring-necked snake", + "eastern hog-nosed snake", "smooth green snake", "kingsnake", "garter snake", "water snake", + "vine snake", "night snake", "boa constrictor", "African rock python", "Indian cobra", + "green mamba", "sea snake", "Saharan horned viper", "eastern diamondback rattlesnake", + "sidewinder rattlesnake", "trilobite", "harvestman", "scorpion", "yellow garden spider", + "barn spider", "European garden spider", "southern black widow", "tarantula", "wolf spider", + "tick", "centipede", "black grouse", "ptarmigan", "ruffed grouse", "prairie grouse", "peafowl", + "quail", "partridge", "african grey parrot", "macaw", "sulphur-crested cockatoo", "lorikeet", + "coucal", "bee eater", "hornbill", "hummingbird", "jacamar", "toucan", "duck", + "red-breasted merganser", "goose", "black swan", "tusker", "echidna", "platypus", "wallaby", + "koala", "wombat", "jellyfish", "sea anemone", "brain coral", "flatworm", "nematode", "conch", + "snail", "slug", "sea slug", "chiton", "chambered nautilus", "Dungeness crab", "rock crab", + "fiddler crab", "red king crab", "American lobster", "spiny lobster", "crayfish", "hermit crab", + "isopod", "white stork", "black stork", "spoonbill", "flamingo", "little blue heron", + "great egret", "bittern bird", "crane bird", "limpkin", "common gallinule", "American coot", + "bustard", "ruddy turnstone", "dunlin", "common redshank", "dowitcher", "oystercatcher", + "pelican", "king penguin", "albatross", "grey whale", "killer whale", "dugong", "sea lion", + "Chihuahua", "Japanese Chin", "Maltese", 
"Pekingese", "Shih Tzu", "King Charles Spaniel", + "Papillon", "toy terrier", "Rhodesian Ridgeback", "Afghan Hound", "Basset Hound", "Beagle", + "Bloodhound", "Bluetick Coonhound", "Black and Tan Coonhound", "Treeing Walker Coonhound", + "English foxhound", "Redbone Coonhound", "borzoi", "Irish Wolfhound", "Italian Greyhound", + "Whippet", "Ibizan Hound", "Norwegian Elkhound", "Otterhound", "Saluki", "Scottish Deerhound", + "Weimaraner", "Staffordshire Bull Terrier", "American Staffordshire Terrier", + "Bedlington Terrier", "Border Terrier", "Kerry Blue Terrier", "Irish Terrier", + "Norfolk Terrier", "Norwich Terrier", "Yorkshire Terrier", "Wire Fox Terrier", + "Lakeland Terrier", "Sealyham Terrier", "Airedale Terrier", "Cairn Terrier", + "Australian Terrier", "Dandie Dinmont Terrier", "Boston Terrier", "Miniature Schnauzer", + "Giant Schnauzer", "Standard Schnauzer", "Scottish Terrier", "Tibetan Terrier", + "Australian Silky Terrier", "Soft-coated Wheaten Terrier", "West Highland White Terrier", + "Lhasa Apso", "Flat-Coated Retriever", "Curly-coated Retriever", "Golden Retriever", + "Labrador Retriever", "Chesapeake Bay Retriever", "German Shorthaired Pointer", "Vizsla", + "English Setter", "Irish Setter", "Gordon Setter", "Brittany dog", "Clumber Spaniel", + "English Springer Spaniel", "Welsh Springer Spaniel", "Cocker Spaniel", "Sussex Spaniel", + "Irish Water Spaniel", "Kuvasz", "Schipperke", "Groenendael dog", "Malinois", "Briard", + "Australian Kelpie", "Komondor", "Old English Sheepdog", "Shetland Sheepdog", "collie", + "Border Collie", "Bouvier des Flandres dog", "Rottweiler", "German Shepherd Dog", "Dobermann", + "Miniature Pinscher", "Greater Swiss Mountain Dog", "Bernese Mountain Dog", + "Appenzeller Sennenhund", "Entlebucher Sennenhund", "Boxer", "Bullmastiff", "Tibetan Mastiff", + "French Bulldog", "Great Dane", "St. 
Bernard", "husky", "Alaskan Malamute", "Siberian Husky", + "Dalmatian", "Affenpinscher", "Basenji", "pug", "Leonberger", "Newfoundland dog", + "Great Pyrenees dog", "Samoyed", "Pomeranian", "Chow Chow", "Keeshond", "brussels griffon", + "Pembroke Welsh Corgi", "Cardigan Welsh Corgi", "Toy Poodle", "Miniature Poodle", + "Standard Poodle", "Mexican hairless dog (xoloitzcuintli)", "grey wolf", "Alaskan tundra wolf", + "red wolf or maned wolf", "coyote", "dingo", "dhole", "African wild dog", "hyena", "red fox", + "kit fox", "Arctic fox", "grey fox", "tabby cat", "tiger cat", "Persian cat", "Siamese cat", + "Egyptian Mau", "cougar", "lynx", "leopard", "snow leopard", "jaguar", "lion", "tiger", + "cheetah", "brown bear", "American black bear", "polar bear", "sloth bear", "mongoose", + "meerkat", "tiger beetle", "ladybug", "ground beetle", "longhorn beetle", "leaf beetle", + "dung beetle", "rhinoceros beetle", "weevil", "fly", "bee", "ant", "grasshopper", + "cricket insect", "stick insect", "cockroach", "praying mantis", "cicada", "leafhopper", + "lacewing", "dragonfly", "damselfly", "red admiral butterfly", "ringlet butterfly", + "monarch butterfly", "small white butterfly", "sulphur butterfly", "gossamer-winged butterfly", + "starfish", "sea urchin", "sea cucumber", "cottontail rabbit", "hare", "Angora rabbit", + "hamster", "porcupine", "fox squirrel", "marmot", "beaver", "guinea pig", "common sorrel horse", + "zebra", "pig", "wild boar", "warthog", "hippopotamus", "ox", "water buffalo", "bison", + "ram (adult male sheep)", "bighorn sheep", "Alpine ibex", "hartebeest", "impala (antelope)", + "gazelle", "arabian camel", "llama", "weasel", "mink", "European polecat", + "black-footed ferret", "otter", "skunk", "badger", "armadillo", "three-toed sloth", "orangutan", + "gorilla", "chimpanzee", "gibbon", "siamang", "guenon", "patas monkey", "baboon", "macaque", + "langur", "black-and-white colobus", "proboscis monkey", "marmoset", "white-headed capuchin", + "howler monkey", "titi monkey", "Geoffroy's spider monkey", "common squirrel monkey", + "ring-tailed lemur", "indri", "Asian elephant", "African bush elephant", "red panda", + "giant panda", "snoek fish", "eel", "silver salmon", "rock beauty fish", "clownfish", + "sturgeon", "gar fish", "lionfish", "pufferfish", "abacus", "abaya", "academic gown", + "accordion", "acoustic guitar", "aircraft carrier", "airliner", "airship", "altar", "ambulance", + "amphibious vehicle", "analog clock", "apiary", "apron", "trash can", "assault rifle", + "backpack", "bakery", "balance beam", "balloon", "ballpoint pen", "Band-Aid", "banjo", + "baluster / handrail", "barbell", "barber chair", "barbershop", "barn", "barometer", "barrel", + "wheelbarrow", "baseball", "basketball", "bassinet", "bassoon", "swimming cap", "bath towel", + "bathtub", "station wagon", "lighthouse", "beaker", "military hat (bearskin or shako)", + "beer bottle", "beer glass", "bell tower", "baby bib", "tandem bicycle", "bikini", + "ring binder", "binoculars", "birdhouse", "boathouse", "bobsleigh", "bolo tie", "poke bonnet", + "bookcase", "bookstore", "bottle cap", "hunting bow", "bow tie", "brass memorial plaque", "bra", + "breakwater", "breastplate", "broom", "bucket", "buckle", "bulletproof vest", + "high-speed train", "butcher shop", "taxicab", "cauldron", "candle", "cannon", "canoe", + "can opener", "cardigan", "car mirror", "carousel", "tool kit", "cardboard box / carton", + "car wheel", "automated teller machine", "cassette", "cassette player", "castle", "catamaran", + "CD player", "cello", 
"mobile phone", "chain", "chain-link fence", "chain mail", "chainsaw", + "storage chest", "chiffonier", "bell or wind chime", "china cabinet", "Christmas stocking", + "church", "movie theater", "cleaver", "cliff dwelling", "cloak", "clogs", "cocktail shaker", + "coffee mug", "coffeemaker", "spiral or coil", "combination lock", "computer keyboard", + "candy store", "container ship", "convertible", "corkscrew", "cornet", "cowboy boot", + "cowboy hat", "cradle", "construction crane", "crash helmet", "crate", "infant bed", + "Crock Pot", "croquet ball", "crutch", "cuirass", "dam", "desk", "desktop computer", + "rotary dial telephone", "diaper", "digital clock", "digital watch", "dining table", + "dishcloth", "dishwasher", "disc brake", "dock", "dog sled", "dome", "doormat", "drilling rig", + "drum", "drumstick", "dumbbell", "Dutch oven", "electric fan", "electric guitar", + "electric locomotive", "entertainment center", "envelope", "espresso machine", "face powder", + "feather boa", "filing cabinet", "fireboat", "fire truck", "fire screen", "flagpole", "flute", + "folding chair", "football helmet", "forklift", "fountain", "fountain pen", "four-poster bed", + "freight car", "French horn", "frying pan", "fur coat", "garbage truck", + "gas mask or respirator", "gas pump", "goblet", "go-kart", "golf ball", "golf cart", "gondola", + "gong", "gown", "grand piano", "greenhouse", "radiator grille", "grocery store", "guillotine", + "hair clip", "hair spray", "half-track", "hammer", "hamper", "hair dryer", "hand-held computer", + "handkerchief", "hard disk drive", "harmonica", "harp", "combine harvester", "hatchet", + "holster", "home theater", "honeycomb", "hook", "hoop skirt", "gymnastic horizontal bar", + "horse-drawn vehicle", "hourglass", "iPod", "clothes iron", "carved pumpkin", "jeans", "jeep", + "T-shirt", "jigsaw puzzle", "rickshaw", "joystick", "kimono", "knee pad", "knot", "lab coat", + "ladle", "lampshade", "laptop computer", "lawn mower", "lens cap", "letter opener", "library", + "lifeboat", "lighter", "limousine", "ocean liner", "lipstick", "slip-on shoe", "lotion", + "music speaker", "loupe magnifying glass", "sawmill", "magnetic compass", "messenger bag", + "mailbox", "tights", "one-piece bathing suit", "manhole cover", "maraca", "marimba", "mask", + "matchstick", "maypole", "maze", "measuring cup", "medicine cabinet", "megalith", "microphone", + "microwave oven", "military uniform", "milk can", "minibus", "miniskirt", "minivan", "missile", + "mitten", "mixing bowl", "mobile home", "ford model t", "modem", "monastery", "monitor", + "moped", "mortar and pestle", "graduation cap", "mosque", "mosquito net", "vespa", + "mountain bike", "tent", "computer mouse", "mousetrap", "moving van", "muzzle", "metal nail", + "neck brace", "necklace", "baby pacifier", "notebook computer", "obelisk", "oboe", "ocarina", + "odometer", "oil filter", "pipe organ", "oscilloscope", "overskirt", "bullock cart", + "oxygen mask", "product packet / packaging", "paddle", "paddle wheel", "padlock", "paintbrush", + "pajamas", "palace", "pan flute", "paper towel", "parachute", "parallel bars", "park bench", + "parking meter", "railroad car", "patio", "payphone", "pedestal", "pencil case", + "pencil sharpener", "perfume", "Petri dish", "photocopier", "plectrum", "Pickelhaube", + "picket fence", "pickup truck", "pier", "piggy bank", "pill bottle", "pillow", "ping-pong ball", + "pinwheel", "pirate ship", "drink pitcher", "block plane", "planetarium", "plastic bag", + "plate rack", "farm plow", "plunger", "Polaroid 
camera", "pole", "police van", "poncho", + "pool table", "soda bottle", "plant pot", "potter's wheel", "power drill", "prayer rug", + "printer", "prison", "missile", "projector", "hockey puck", "punching bag", "purse", "quill", + "quilt", "race car", "racket", "radiator", "radio", "radio telescope", "rain barrel", + "recreational vehicle", "fishing casting reel", "reflex camera", "refrigerator", + "remote control", "restaurant", "revolver", "rifle", "rocking chair", "rotisserie", "eraser", + "rugby ball", "ruler measuring stick", "sneaker", "safe", "safety pin", "salt shaker", "sandal", + "sarong", "saxophone", "scabbard", "weighing scale", "school bus", "schooner", "scoreboard", + "CRT monitor", "screw", "screwdriver", "seat belt", "sewing machine", "shield", "shoe store", + "shoji screen / room divider", "shopping basket", "shopping cart", "shovel", "shower cap", + "shower curtain", "ski", "balaclava ski mask", "sleeping bag", "slide rule", "sliding door", + "slot machine", "snorkel", "snowmobile", "snowplow", "soap dispenser", "soccer ball", "sock", + "solar thermal collector", "sombrero", "soup bowl", "keyboard space bar", "space heater", + "space shuttle", "spatula", "motorboat", "spider web", "spindle", "sports car", "spotlight", + "stage", "steam locomotive", "through arch bridge", "steel drum", "stethoscope", "scarf", + "stone wall", "stopwatch", "stove", "strainer", "tram", "stretcher", "couch", "stupa", + "submarine", "suit", "sundial", "sunglasses", "sunglasses", "sunscreen", "suspension bridge", + "mop", "sweatshirt", "swim trunks / shorts", "swing", "electrical switch", "syringe", + "table lamp", "tank", "tape player", "teapot", "teddy bear", "television", "tennis ball", + "thatched roof", "front curtain", "thimble", "threshing machine", "throne", "tile roof", + "toaster", "tobacco shop", "toilet seat", "torch", "totem pole", "tow truck", "toy store", + "tractor", "semi-trailer truck", "tray", "trench coat", "tricycle", "trimaran", "tripod", + "triumphal arch", "trolleybus", "trombone", "hot tub", "turnstile", "typewriter keyboard", + "umbrella", "unicycle", "upright piano", "vacuum cleaner", "vase", "vaulted or arched ceiling", + "velvet fabric", "vending machine", "vestment", "viaduct", "violin", "volleyball", + "waffle iron", "wall clock", "wallet", "wardrobe", "military aircraft", "sink", + "washing machine", "water bottle", "water jug", "water tower", "whiskey jug", "whistle", + "hair wig", "window screen", "window shade", "Windsor tie", "wine bottle", "airplane wing", + "wok", "wooden spoon", "wool", "split-rail fence", "shipwreck", "sailboat", "yurt", "website", + "comic book", "crossword", "traffic or street sign", "traffic light", "dust jacket", "menu", + "plate", "guacamole", "consomme", "hot pot", "trifle", "ice cream", "popsicle", "baguette", + "bagel", "pretzel", "cheeseburger", "hot dog", "mashed potatoes", "cabbage", "broccoli", + "cauliflower", "zucchini", "spaghetti squash", "acorn squash", "butternut squash", "cucumber", + "artichoke", "bell pepper", "cardoon", "mushroom", "Granny Smith apple", "strawberry", "orange", + "lemon", "fig", "pineapple", "banana", "jackfruit", "cherimoya (custard apple)", "pomegranate", + "hay", "carbonara", "chocolate syrup", "dough", "meatloaf", "pizza", "pot pie", "burrito", + "red wine", "espresso", "tea cup", "eggnog", "mountain", "bubble", "cliff", "coral reef", + "geyser", "lakeshore", "promontory", "sandbar", "beach", "valley", "volcano", "baseball player", + "bridegroom", "scuba diver", "rapeseed", "daisy", "yellow lady's 
slipper", "corn", "acorn", + "rose hip", "horse chestnut seed", "coral fungus", "agaric", "gyromitra", "stinkhorn mushroom", + "earth star fungus", "hen of the woods mushroom", "bolete", "corn cob", "toilet paper" +) + diff --git a/ext/sam/__init__.py b/ext/sam/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..bd9636764f4a9e6500fac352c388063d9b629aab --- /dev/null +++ b/ext/sam/__init__.py @@ -0,0 +1,3 @@ +from .image_encoder import ImageEncoderViT +from .prompt_encoder import PromptEncoder +from .mask_decoder import MaskDecoder diff --git a/ext/sam/common.py b/ext/sam/common.py new file mode 100644 index 0000000000000000000000000000000000000000..2bf15236a3eb24d8526073bc4fa2b274cccb3f96 --- /dev/null +++ b/ext/sam/common.py @@ -0,0 +1,43 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. + +# This source code is licensed under the license found in the +# LICENSE file in the root directory of this source tree. + +import torch +import torch.nn as nn + +from typing import Type + + +class MLPBlock(nn.Module): + def __init__( + self, + embedding_dim: int, + mlp_dim: int, + act: Type[nn.Module] = nn.GELU, + ) -> None: + super().__init__() + self.lin1 = nn.Linear(embedding_dim, mlp_dim) + self.lin2 = nn.Linear(mlp_dim, embedding_dim) + self.act = act() + + def forward(self, x: torch.Tensor) -> torch.Tensor: + return self.lin2(self.act(self.lin1(x))) + + +# From https://github.com/facebookresearch/detectron2/blob/main/detectron2/layers/batch_norm.py # noqa +# Itself from https://github.com/facebookresearch/ConvNeXt/blob/d1fa8f6fef0a165b27399986cc2bdacc92777e40/models/convnext.py#L119 # noqa +class LayerNorm2d(nn.Module): + def __init__(self, num_channels: int, eps: float = 1e-6) -> None: + super().__init__() + self.weight = nn.Parameter(torch.ones(num_channels)) + self.bias = nn.Parameter(torch.zeros(num_channels)) + self.eps = eps + + def forward(self, x: torch.Tensor) -> torch.Tensor: + u = x.mean(1, keepdim=True) + s = (x - u).pow(2).mean(1, keepdim=True) + x = (x - u) / torch.sqrt(s + self.eps) + x = self.weight[:, None, None] * x + self.bias[:, None, None] + return x diff --git a/ext/sam/image_encoder.py b/ext/sam/image_encoder.py new file mode 100644 index 0000000000000000000000000000000000000000..66351d9d7c589be693f4b3485901d3bdfed54d4a --- /dev/null +++ b/ext/sam/image_encoder.py @@ -0,0 +1,395 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. + +# This source code is licensed under the license found in the +# LICENSE file in the root directory of this source tree. + +import torch +import torch.nn as nn +import torch.nn.functional as F + +from typing import Optional, Tuple, Type + +from .common import LayerNorm2d, MLPBlock + + +# This class and its supporting functions below lightly adapted from the ViTDet backbone available at: https://github.com/facebookresearch/detectron2/blob/main/detectron2/modeling/backbone/vit.py # noqa +class ImageEncoderViT(nn.Module): + def __init__( + self, + img_size: int = 1024, + patch_size: int = 16, + in_chans: int = 3, + embed_dim: int = 768, + depth: int = 12, + num_heads: int = 12, + mlp_ratio: float = 4.0, + out_chans: int = 256, + qkv_bias: bool = True, + norm_layer: Type[nn.Module] = nn.LayerNorm, + act_layer: Type[nn.Module] = nn.GELU, + use_abs_pos: bool = True, + use_rel_pos: bool = False, + rel_pos_zero_init: bool = True, + window_size: int = 0, + global_attn_indexes: Tuple[int, ...] = (), + ) -> None: + """ + Args: + img_size (int): Input image size. 
+ patch_size (int): Patch size. + in_chans (int): Number of input image channels. + embed_dim (int): Patch embedding dimension. + depth (int): Depth of ViT. + num_heads (int): Number of attention heads in each ViT block. + mlp_ratio (float): Ratio of mlp hidden dim to embedding dim. + qkv_bias (bool): If True, add a learnable bias to query, key, value. + norm_layer (nn.Module): Normalization layer. + act_layer (nn.Module): Activation layer. + use_abs_pos (bool): If True, use absolute positional embeddings. + use_rel_pos (bool): If True, add relative positional embeddings to the attention map. + rel_pos_zero_init (bool): If True, zero initialize relative positional parameters. + window_size (int): Window size for window attention blocks. + global_attn_indexes (list): Indexes for blocks using global attention. + """ + super().__init__() + self.img_size = img_size + + self.patch_embed = PatchEmbed( + kernel_size=(patch_size, patch_size), + stride=(patch_size, patch_size), + in_chans=in_chans, + embed_dim=embed_dim, + ) + + self.pos_embed: Optional[nn.Parameter] = None + if use_abs_pos: + # Initialize absolute positional embedding with pretrain image size. + self.pos_embed = nn.Parameter( + torch.zeros(1, img_size // patch_size, img_size // patch_size, embed_dim) + ) + + self.blocks = nn.ModuleList() + for i in range(depth): + block = Block( + dim=embed_dim, + num_heads=num_heads, + mlp_ratio=mlp_ratio, + qkv_bias=qkv_bias, + norm_layer=norm_layer, + act_layer=act_layer, + use_rel_pos=use_rel_pos, + rel_pos_zero_init=rel_pos_zero_init, + window_size=window_size if i not in global_attn_indexes else 0, + input_size=(img_size // patch_size, img_size // patch_size), + ) + self.blocks.append(block) + + self.neck = nn.Sequential( + nn.Conv2d( + embed_dim, + out_chans, + kernel_size=1, + bias=False, + ), + LayerNorm2d(out_chans), + nn.Conv2d( + out_chans, + out_chans, + kernel_size=3, + padding=1, + bias=False, + ), + LayerNorm2d(out_chans), + ) + + def forward(self, x: torch.Tensor) -> torch.Tensor: + x = self.patch_embed(x) + if self.pos_embed is not None: + x = x + self.pos_embed + + for blk in self.blocks: + x = blk(x) + + x = self.neck(x.permute(0, 3, 1, 2)) + + return x + + +class Block(nn.Module): + """Transformer blocks with support of window attention and residual propagation blocks""" + + def __init__( + self, + dim: int, + num_heads: int, + mlp_ratio: float = 4.0, + qkv_bias: bool = True, + norm_layer: Type[nn.Module] = nn.LayerNorm, + act_layer: Type[nn.Module] = nn.GELU, + use_rel_pos: bool = False, + rel_pos_zero_init: bool = True, + window_size: int = 0, + input_size: Optional[Tuple[int, int]] = None, + ) -> None: + """ + Args: + dim (int): Number of input channels. + num_heads (int): Number of attention heads in each ViT block. + mlp_ratio (float): Ratio of mlp hidden dim to embedding dim. + qkv_bias (bool): If True, add a learnable bias to query, key, value. + norm_layer (nn.Module): Normalization layer. + act_layer (nn.Module): Activation layer. + use_rel_pos (bool): If True, add relative positional embeddings to the attention map. + rel_pos_zero_init (bool): If True, zero initialize relative positional parameters. + window_size (int): Window size for window attention blocks. If it equals 0, then + use global attention. + input_size (tuple(int, int) or None): Input resolution for calculating the relative + positional parameter size. 
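+
+        Example (illustrative only; the values are typical SAM ViT-B settings, not
+        defaults of this class):
+
+            blk = Block(dim=768, num_heads=12, window_size=14, input_size=(64, 64))
+            y = blk(torch.zeros(1, 64, 64, 768))  # windowed attention, output stays [B, H, W, C]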
+ """ + super().__init__() + self.norm1 = norm_layer(dim) + self.attn = Attention( + dim, + num_heads=num_heads, + qkv_bias=qkv_bias, + use_rel_pos=use_rel_pos, + rel_pos_zero_init=rel_pos_zero_init, + input_size=input_size if window_size == 0 else (window_size, window_size), + ) + + self.norm2 = norm_layer(dim) + self.mlp = MLPBlock(embedding_dim=dim, mlp_dim=int(dim * mlp_ratio), act=act_layer) + + self.window_size = window_size + + def forward(self, x: torch.Tensor) -> torch.Tensor: + shortcut = x + x = self.norm1(x) + # Window partition + if self.window_size > 0: + H, W = x.shape[1], x.shape[2] + x, pad_hw = window_partition(x, self.window_size) + + x = self.attn(x) + # Reverse window partition + if self.window_size > 0: + x = window_unpartition(x, self.window_size, pad_hw, (H, W)) + + x = shortcut + x + x = x + self.mlp(self.norm2(x)) + + return x + + +class Attention(nn.Module): + """Multi-head Attention block with relative position embeddings.""" + + def __init__( + self, + dim: int, + num_heads: int = 8, + qkv_bias: bool = True, + use_rel_pos: bool = False, + rel_pos_zero_init: bool = True, + input_size: Optional[Tuple[int, int]] = None, + ) -> None: + """ + Args: + dim (int): Number of input channels. + num_heads (int): Number of attention heads. + qkv_bias (bool): If True, add a learnable bias to query, key, value. + rel_pos (bool): If True, add relative positional embeddings to the attention map. + rel_pos_zero_init (bool): If True, zero initialize relative positional parameters. + input_size (tuple(int, int) or None): Input resolution for calculating the relative + positional parameter size. + """ + super().__init__() + self.num_heads = num_heads + head_dim = dim // num_heads + self.scale = head_dim**-0.5 + + self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias) + self.proj = nn.Linear(dim, dim) + + self.use_rel_pos = use_rel_pos + if self.use_rel_pos: + assert ( + input_size is not None + ), "Input size must be provided if using relative positional encoding." + # initialize relative positional embeddings + self.rel_pos_h = nn.Parameter(torch.zeros(2 * input_size[0] - 1, head_dim)) + self.rel_pos_w = nn.Parameter(torch.zeros(2 * input_size[1] - 1, head_dim)) + + def forward(self, x: torch.Tensor) -> torch.Tensor: + B, H, W, _ = x.shape + # qkv with shape (3, B, nHead, H * W, C) + qkv = self.qkv(x).reshape(B, H * W, 3, self.num_heads, -1).permute(2, 0, 3, 1, 4) + # q, k, v with shape (B * nHead, H * W, C) + q, k, v = qkv.reshape(3, B * self.num_heads, H * W, -1).unbind(0) + + attn = (q * self.scale) @ k.transpose(-2, -1) + + if self.use_rel_pos: + attn = add_decomposed_rel_pos(attn, q, self.rel_pos_h, self.rel_pos_w, (H, W), (H, W)) + + attn = attn.softmax(dim=-1) + x = (attn @ v).view(B, self.num_heads, H, W, -1).permute(0, 2, 3, 1, 4).reshape(B, H, W, -1) + x = self.proj(x) + + return x + + +def window_partition(x: torch.Tensor, window_size: int) -> Tuple[torch.Tensor, Tuple[int, int]]: + """ + Partition into non-overlapping windows with padding if needed. + Args: + x (tensor): input tokens with [B, H, W, C]. + window_size (int): window size. + + Returns: + windows: windows after partition with [B * num_windows, window_size, window_size, C]. 
+ (Hp, Wp): padded height and width before partition + """ + B, H, W, C = x.shape + + pad_h = (window_size - H % window_size) % window_size + pad_w = (window_size - W % window_size) % window_size + if pad_h > 0 or pad_w > 0: + x = F.pad(x, (0, 0, 0, pad_w, 0, pad_h)) + Hp, Wp = H + pad_h, W + pad_w + + x = x.view(B, Hp // window_size, window_size, Wp // window_size, window_size, C) + windows = x.permute(0, 1, 3, 2, 4, 5).contiguous().view(-1, window_size, window_size, C) + return windows, (Hp, Wp) + + +def window_unpartition( + windows: torch.Tensor, window_size: int, pad_hw: Tuple[int, int], hw: Tuple[int, int] +) -> torch.Tensor: + """ + Window unpartition into original sequences and removing padding. + Args: + windows (tensor): input tokens with [B * num_windows, window_size, window_size, C]. + window_size (int): window size. + pad_hw (Tuple): padded height and width (Hp, Wp). + hw (Tuple): original height and width (H, W) before padding. + + Returns: + x: unpartitioned sequences with [B, H, W, C]. + """ + Hp, Wp = pad_hw + H, W = hw + B = windows.shape[0] // (Hp * Wp // window_size // window_size) + x = windows.view(B, Hp // window_size, Wp // window_size, window_size, window_size, -1) + x = x.permute(0, 1, 3, 2, 4, 5).contiguous().view(B, Hp, Wp, -1) + + if Hp > H or Wp > W: + x = x[:, :H, :W, :].contiguous() + return x + + +def get_rel_pos(q_size: int, k_size: int, rel_pos: torch.Tensor) -> torch.Tensor: + """ + Get relative positional embeddings according to the relative positions of + query and key sizes. + Args: + q_size (int): size of query q. + k_size (int): size of key k. + rel_pos (Tensor): relative position embeddings (L, C). + + Returns: + Extracted positional embeddings according to relative positions. + """ + max_rel_dist = int(2 * max(q_size, k_size) - 1) + # Interpolate rel pos if needed. + if rel_pos.shape[0] != max_rel_dist: + # Interpolate rel pos. + rel_pos_resized = F.interpolate( + rel_pos.reshape(1, rel_pos.shape[0], -1).permute(0, 2, 1), + size=max_rel_dist, + mode="linear", + ) + rel_pos_resized = rel_pos_resized.reshape(-1, max_rel_dist).permute(1, 0) + else: + rel_pos_resized = rel_pos + + # Scale the coords with short length if shapes for q and k are different. + q_coords = torch.arange(q_size)[:, None] * max(k_size / q_size, 1.0) + k_coords = torch.arange(k_size)[None, :] * max(q_size / k_size, 1.0) + relative_coords = (q_coords - k_coords) + (k_size - 1) * max(q_size / k_size, 1.0) + + return rel_pos_resized[relative_coords.long()] + + +def add_decomposed_rel_pos( + attn: torch.Tensor, + q: torch.Tensor, + rel_pos_h: torch.Tensor, + rel_pos_w: torch.Tensor, + q_size: Tuple[int, int], + k_size: Tuple[int, int], +) -> torch.Tensor: + """ + Calculate decomposed Relative Positional Embeddings from :paper:`mvitv2`. + https://github.com/facebookresearch/mvit/blob/19786631e330df9f3622e5402b4a419a263a2c80/mvit/models/attention.py # noqa B950 + Args: + attn (Tensor): attention map. + q (Tensor): query q in the attention layer with shape (B, q_h * q_w, C). + rel_pos_h (Tensor): relative position embeddings (Lh, C) for height axis. + rel_pos_w (Tensor): relative position embeddings (Lw, C) for width axis. + q_size (Tuple): spatial sequence size of query q with (q_h, q_w). + k_size (Tuple): spatial sequence size of key k with (k_h, k_w). + + Returns: + attn (Tensor): attention map with added relative positional embeddings. 
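+
+    In effect, for query position (qh, qw) and key position (kh, kw):
+        attn[b, (qh, qw), (kh, kw)] += q[b, qh, qw] . Rh[qh, kh] + q[b, qh, qw] . Rw[qw, kw]
+    i.e. two per-axis embedding tables replace a full (q_h*q_w, k_h*k_w) relative
+    position table; the two einsum calls below compute exactly these dot products.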
+ """ + q_h, q_w = q_size + k_h, k_w = k_size + Rh = get_rel_pos(q_h, k_h, rel_pos_h) + Rw = get_rel_pos(q_w, k_w, rel_pos_w) + + B, _, dim = q.shape + r_q = q.reshape(B, q_h, q_w, dim) + rel_h = torch.einsum("bhwc,hkc->bhwk", r_q, Rh) + rel_w = torch.einsum("bhwc,wkc->bhwk", r_q, Rw) + + attn = ( + attn.view(B, q_h, q_w, k_h, k_w) + rel_h[:, :, :, :, None] + rel_w[:, :, :, None, :] + ).view(B, q_h * q_w, k_h * k_w) + + return attn + + +class PatchEmbed(nn.Module): + """ + Image to Patch Embedding. + """ + + def __init__( + self, + kernel_size: Tuple[int, int] = (16, 16), + stride: Tuple[int, int] = (16, 16), + padding: Tuple[int, int] = (0, 0), + in_chans: int = 3, + embed_dim: int = 768, + ) -> None: + """ + Args: + kernel_size (Tuple): kernel size of the projection layer. + stride (Tuple): stride of the projection layer. + padding (Tuple): padding size of the projection layer. + in_chans (int): Number of input image channels. + embed_dim (int): Patch embedding dimension. + """ + super().__init__() + + self.proj = nn.Conv2d( + in_chans, embed_dim, kernel_size=kernel_size, stride=stride, padding=padding + ) + + def forward(self, x: torch.Tensor) -> torch.Tensor: + x = self.proj(x) + # B C H W -> B H W C + x = x.permute(0, 2, 3, 1) + return x diff --git a/ext/sam/mask_decoder.py b/ext/sam/mask_decoder.py new file mode 100644 index 0000000000000000000000000000000000000000..967fbdfd3024aed01aa604a0420cb3240720fcbc --- /dev/null +++ b/ext/sam/mask_decoder.py @@ -0,0 +1,185 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. + +# This source code is licensed under the license found in the +# LICENSE file in the root directory of this source tree. + +import torch +from torch import nn +from torch.nn import functional as F + +from typing import List, Tuple, Type + +from .common import LayerNorm2d + +from .transformer import TwoWayTransformer + + +class MaskDecoder(nn.Module): + def __init__( + self, + *, + transformer_dim: int, + num_multimask_outputs: int = 3, + activation: Type[nn.Module] = nn.GELU, + iou_head_depth: int = 3, + iou_head_hidden_dim: int = 256, + with_iou: bool = True + ) -> None: + """ + Predicts masks given an image and prompt embeddings, using a + transformer architecture. 
+ + Arguments: + transformer_dim (int): the channel dimension of the transformer + transformer (nn.Module): the transformer used to predict masks + num_multimask_outputs (int): the number of masks to predict + when disambiguating masks + activation (nn.Module): the type of activation to use when + upscaling masks + iou_head_depth (int): the depth of the MLP used to predict + mask quality + iou_head_hidden_dim (int): the hidden dimension of the MLP + used to predict mask quality + """ + super().__init__() + self.transformer_dim = transformer_dim + self.transformer = TwoWayTransformer( + depth=2, + embedding_dim=transformer_dim, + mlp_dim=2048, + num_heads=8, + ) + + self.num_multimask_outputs = num_multimask_outputs + + if with_iou: + self.iou_token = nn.Embedding(1, transformer_dim) + self.num_mask_tokens = num_multimask_outputs + 1 + self.mask_tokens = nn.Embedding(self.num_mask_tokens, transformer_dim) + + self.output_upscaling = nn.Sequential( + nn.ConvTranspose2d(transformer_dim, transformer_dim // 4, kernel_size=2, stride=2), + LayerNorm2d(transformer_dim // 4), + activation(), + nn.ConvTranspose2d(transformer_dim // 4, transformer_dim // 8, kernel_size=2, stride=2), + activation(), + ) + self.output_hypernetworks_mlps = nn.ModuleList( + [ + MLP(transformer_dim, transformer_dim, transformer_dim // 8, 3) + for i in range(self.num_mask_tokens) + ] + ) + + if with_iou: + self.iou_prediction_head = MLP( + transformer_dim, iou_head_hidden_dim, self.num_mask_tokens, iou_head_depth + ) + + def forward( + self, + image_embeddings: torch.Tensor, + image_pe: torch.Tensor, + sparse_prompt_embeddings: torch.Tensor, + dense_prompt_embeddings: torch.Tensor, + multimask_output: bool, + ) -> Tuple[torch.Tensor, torch.Tensor]: + """ + Predict masks given image and prompt embeddings. + + Arguments: + image_embeddings (torch.Tensor): the embeddings from the image encoder + image_pe (torch.Tensor): positional encoding with the shape of image_embeddings + sparse_prompt_embeddings (torch.Tensor): the embeddings of the points and boxes + dense_prompt_embeddings (torch.Tensor): the embeddings of the mask inputs + multimask_output (bool): Whether to return multiple masks or a single + mask. + + Returns: + torch.Tensor: batched predicted masks + torch.Tensor: batched predictions of mask quality + """ + masks, iou_pred = self.predict_masks( + image_embeddings=image_embeddings, + image_pe=image_pe, + sparse_prompt_embeddings=sparse_prompt_embeddings, + dense_prompt_embeddings=dense_prompt_embeddings, + ) + + # Select the correct mask or masks for output + if multimask_output: + mask_slice = slice(1, None) + else: + mask_slice = slice(0, 1) + masks = masks[:, mask_slice, :, :] + iou_pred = iou_pred[:, mask_slice] + + # Prepare output + return masks, iou_pred + + def predict_masks( + self, + image_embeddings: torch.Tensor, + image_pe: torch.Tensor, + sparse_prompt_embeddings: torch.Tensor, + dense_prompt_embeddings: torch.Tensor, + ) -> Tuple[torch.Tensor, torch.Tensor]: + """Predicts masks. 
See 'forward' for more details.""" + # Concatenate output tokens + output_tokens = torch.cat([self.iou_token.weight, self.mask_tokens.weight], dim=0) + output_tokens = output_tokens.unsqueeze(0).expand(sparse_prompt_embeddings.size(0), -1, -1) + tokens = torch.cat((output_tokens, sparse_prompt_embeddings), dim=1) + + # Expand per-image data in batch direction to be per-mask + src = torch.repeat_interleave(image_embeddings, tokens.shape[0], dim=0) + src = src + dense_prompt_embeddings + pos_src = torch.repeat_interleave(image_pe, tokens.shape[0], dim=0) + b, c, h, w = src.shape + + # Run the transformer + hs, src = self.transformer(src, pos_src, tokens) + iou_token_out = hs[:, 0, :] + mask_tokens_out = hs[:, 1 : (1 + self.num_mask_tokens), :] + + # Upscale mask embeddings and predict masks using the mask tokens + src = src.transpose(1, 2).view(b, c, h, w) + upscaled_embedding = self.output_upscaling(src) + hyper_in_list: List[torch.Tensor] = [] + for i in range(self.num_mask_tokens): + hyper_in_list.append(self.output_hypernetworks_mlps[i](mask_tokens_out[:, i, :])) + hyper_in = torch.stack(hyper_in_list, dim=1) + b, c, h, w = upscaled_embedding.shape + masks = (hyper_in @ upscaled_embedding.view(b, c, h * w)).view(b, -1, h, w) + + # Generate mask quality predictions + iou_pred = self.iou_prediction_head(iou_token_out) + + return masks, iou_pred + + +# Lightly adapted from +# https://github.com/facebookresearch/MaskFormer/blob/main/mask_former/modeling/transformer/transformer_predictor.py # noqa +class MLP(nn.Module): + def __init__( + self, + input_dim: int, + hidden_dim: int, + output_dim: int, + num_layers: int, + sigmoid_output: bool = False, + ) -> None: + super().__init__() + self.num_layers = num_layers + h = [hidden_dim] * (num_layers - 1) + self.layers = nn.ModuleList( + nn.Linear(n, k) for n, k in zip([input_dim] + h, h + [output_dim]) + ) + self.sigmoid_output = sigmoid_output + + def forward(self, x): + for i, layer in enumerate(self.layers): + x = F.relu(layer(x)) if i < self.num_layers - 1 else layer(x) + if self.sigmoid_output: + x = F.sigmoid(x) + return x diff --git a/ext/sam/prompt_encoder.py b/ext/sam/prompt_encoder.py new file mode 100644 index 0000000000000000000000000000000000000000..c3143f4f8e02ddd7ca8587b40ff5d47c3a6b7ef3 --- /dev/null +++ b/ext/sam/prompt_encoder.py @@ -0,0 +1,214 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. + +# This source code is licensed under the license found in the +# LICENSE file in the root directory of this source tree. + +import numpy as np +import torch +from torch import nn + +from typing import Any, Optional, Tuple, Type + +from .common import LayerNorm2d + + +class PromptEncoder(nn.Module): + def __init__( + self, + embed_dim: int, + image_embedding_size: Tuple[int, int], + input_image_size: Tuple[int, int], + mask_in_chans: int, + activation: Type[nn.Module] = nn.GELU, + ) -> None: + """ + Encodes prompts for input to SAM's mask decoder. + + Arguments: + embed_dim (int): The prompts' embedding dimension + image_embedding_size (tuple(int, int)): The spatial size of the + image embedding, as (H, W). + input_image_size (int): The padded size of the image as input + to the image encoder, as (H, W). + mask_in_chans (int): The number of hidden channels used for + encoding input masks. + activation (nn.Module): The activation to use when encoding + input masks. 
+ """ + super().__init__() + self.embed_dim = embed_dim + self.input_image_size = input_image_size + self.image_embedding_size = image_embedding_size + self.pe_layer = PositionEmbeddingRandom(embed_dim // 2) + + self.num_point_embeddings: int = 4 # pos/neg point + 2 box corners + point_embeddings = [nn.Embedding(1, embed_dim) for i in range(self.num_point_embeddings)] + self.point_embeddings = nn.ModuleList(point_embeddings) + self.not_a_point_embed = nn.Embedding(1, embed_dim) + + self.mask_input_size = (4 * image_embedding_size[0], 4 * image_embedding_size[1]) + self.mask_downscaling = nn.Sequential( + nn.Conv2d(1, mask_in_chans // 4, kernel_size=2, stride=2), + LayerNorm2d(mask_in_chans // 4), + activation(), + nn.Conv2d(mask_in_chans // 4, mask_in_chans, kernel_size=2, stride=2), + LayerNorm2d(mask_in_chans), + activation(), + nn.Conv2d(mask_in_chans, embed_dim, kernel_size=1), + ) + self.no_mask_embed = nn.Embedding(1, embed_dim) + + def get_dense_pe(self) -> torch.Tensor: + """ + Returns the positional encoding used to encode point prompts, + applied to a dense set of points the shape of the image encoding. + + Returns: + torch.Tensor: Positional encoding with shape + 1x(embed_dim)x(embedding_h)x(embedding_w) + """ + return self.pe_layer(self.image_embedding_size).unsqueeze(0) + + def _embed_points( + self, + points: torch.Tensor, + labels: torch.Tensor, + pad: bool, + ) -> torch.Tensor: + """Embeds point prompts.""" + points = points + 0.5 # Shift to center of pixel + if pad: + padding_point = torch.zeros((points.shape[0], 1, 2), device=points.device) + padding_label = -torch.ones((labels.shape[0], 1), device=labels.device) + points = torch.cat([points, padding_point], dim=1) + labels = torch.cat([labels, padding_label], dim=1) + point_embedding = self.pe_layer.forward_with_coords(points, self.input_image_size) + point_embedding[labels == -1] = 0.0 + point_embedding[labels == -1] += self.not_a_point_embed.weight + point_embedding[labels == 0] += self.point_embeddings[0].weight + point_embedding[labels == 1] += self.point_embeddings[1].weight + return point_embedding + + def _embed_boxes(self, boxes: torch.Tensor) -> torch.Tensor: + """Embeds box prompts.""" + boxes = boxes + 0.5 # Shift to center of pixel + coords = boxes.reshape(-1, 2, 2) + corner_embedding = self.pe_layer.forward_with_coords(coords, self.input_image_size) + corner_embedding[:, 0, :] += self.point_embeddings[2].weight + corner_embedding[:, 1, :] += self.point_embeddings[3].weight + return corner_embedding + + def _embed_masks(self, masks: torch.Tensor) -> torch.Tensor: + """Embeds mask inputs.""" + mask_embedding = self.mask_downscaling(masks) + return mask_embedding + + def _get_batch_size( + self, + points: Optional[Tuple[torch.Tensor, torch.Tensor]], + boxes: Optional[torch.Tensor], + masks: Optional[torch.Tensor], + ) -> int: + """ + Gets the batch size of the output given the batch size of the input prompts. + """ + if points is not None: + return points[0].shape[0] + elif boxes is not None: + return boxes.shape[0] + elif masks is not None: + return masks.shape[0] + else: + return 1 + + def _get_device(self) -> torch.device: + return self.point_embeddings[0].weight.device + + def forward( + self, + points: Optional[Tuple[torch.Tensor, torch.Tensor]], + boxes: Optional[torch.Tensor], + masks: Optional[torch.Tensor], + ) -> Tuple[torch.Tensor, torch.Tensor]: + """ + Embeds different types of prompts, returning both sparse and dense + embeddings. 
+ + Arguments: + points (tuple(torch.Tensor, torch.Tensor) or none): point coordinates + and labels to embed. + boxes (torch.Tensor or none): boxes to embed + masks (torch.Tensor or none): masks to embed + + Returns: + torch.Tensor: sparse embeddings for the points and boxes, with shape + BxNx(embed_dim), where N is determined by the number of input points + and boxes. + torch.Tensor: dense embeddings for the masks, in the shape + Bx(embed_dim)x(embed_H)x(embed_W) + """ + bs = self._get_batch_size(points, boxes, masks) + sparse_embeddings = torch.empty((bs, 0, self.embed_dim), device=self._get_device()) + if points is not None: + coords, labels = points + point_embeddings = self._embed_points(coords, labels, pad=(boxes is None)) + sparse_embeddings = torch.cat([sparse_embeddings, point_embeddings], dim=1) + if boxes is not None: + box_embeddings = self._embed_boxes(boxes) + sparse_embeddings = torch.cat([sparse_embeddings, box_embeddings], dim=1) + + if masks is not None: + dense_embeddings = self._embed_masks(masks) + else: + dense_embeddings = self.no_mask_embed.weight.reshape(1, -1, 1, 1).expand( + bs, -1, self.image_embedding_size[0], self.image_embedding_size[1] + ) + + return sparse_embeddings, dense_embeddings + + +class PositionEmbeddingRandom(nn.Module): + """ + Positional encoding using random spatial frequencies. + """ + + def __init__(self, num_pos_feats: int = 64, scale: Optional[float] = None) -> None: + super().__init__() + if scale is None or scale <= 0.0: + scale = 1.0 + self.register_buffer( + "positional_encoding_gaussian_matrix", + scale * torch.randn((2, num_pos_feats)), + ) + + def _pe_encoding(self, coords: torch.Tensor) -> torch.Tensor: + """Positionally encode points that are normalized to [0,1].""" + # assuming coords are in [0, 1]^2 square and have d_1 x ... x d_n x 2 shape + coords = 2 * coords - 1 + coords = coords @ self.positional_encoding_gaussian_matrix + coords = 2 * np.pi * coords + # outputs d_1 x ... x d_n x C shape + return torch.cat([torch.sin(coords), torch.cos(coords)], dim=-1) + + def forward(self, size: Tuple[int, int]) -> torch.Tensor: + """Generate positional encoding for a grid of the specified size.""" + h, w = size + device: Any = self.positional_encoding_gaussian_matrix.device + grid = torch.ones((h, w), device=device, dtype=torch.float32) + y_embed = grid.cumsum(dim=0) - 0.5 + x_embed = grid.cumsum(dim=1) - 0.5 + y_embed = y_embed / h + x_embed = x_embed / w + + pe = self._pe_encoding(torch.stack([x_embed, y_embed], dim=-1)) + return pe.permute(2, 0, 1) # C x H x W + + def forward_with_coords( + self, coords_input: torch.Tensor, image_size: Tuple[int, int] + ) -> torch.Tensor: + """Positionally encode points that are not normalized to [0,1].""" + coords = coords_input.clone() + coords[:, :, 0] = coords[:, :, 0] / image_size[1] + coords[:, :, 1] = coords[:, :, 1] / image_size[0] + return self._pe_encoding(coords.to(torch.float)) # B x N x C diff --git a/ext/sam/transformer.py b/ext/sam/transformer.py new file mode 100644 index 0000000000000000000000000000000000000000..28fafea52288603fea275f3a100790471825c34a --- /dev/null +++ b/ext/sam/transformer.py @@ -0,0 +1,240 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. + +# This source code is licensed under the license found in the +# LICENSE file in the root directory of this source tree. 
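+# Provides the TwoWayTransformer consumed by MaskDecoder (mask_decoder.py above), plus its building blocks TwoWayAttentionBlock and Attention.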
+ +import torch +from torch import Tensor, nn + +import math +from typing import Tuple, Type + +from .common import MLPBlock + + +class TwoWayTransformer(nn.Module): + def __init__( + self, + depth: int, + embedding_dim: int, + num_heads: int, + mlp_dim: int, + activation: Type[nn.Module] = nn.ReLU, + attention_downsample_rate: int = 2, + ) -> None: + """ + A transformer decoder that attends to an input image using + queries whose positional embedding is supplied. + + Args: + depth (int): number of layers in the transformer + embedding_dim (int): the channel dimension for the input embeddings + num_heads (int): the number of heads for multihead attention. Must + divide embedding_dim + mlp_dim (int): the channel dimension internal to the MLP block + activation (nn.Module): the activation to use in the MLP block + """ + super().__init__() + self.depth = depth + self.embedding_dim = embedding_dim + self.num_heads = num_heads + self.mlp_dim = mlp_dim + self.layers = nn.ModuleList() + + for i in range(depth): + self.layers.append( + TwoWayAttentionBlock( + embedding_dim=embedding_dim, + num_heads=num_heads, + mlp_dim=mlp_dim, + activation=activation, + attention_downsample_rate=attention_downsample_rate, + skip_first_layer_pe=(i == 0), + ) + ) + + self.final_attn_token_to_image = Attention( + embedding_dim, num_heads, downsample_rate=attention_downsample_rate + ) + self.norm_final_attn = nn.LayerNorm(embedding_dim) + + def forward( + self, + image_embedding: Tensor, + image_pe: Tensor, + point_embedding: Tensor, + ) -> Tuple[Tensor, Tensor]: + """ + Args: + image_embedding (torch.Tensor): image to attend to. Should be shape + B x embedding_dim x h x w for any h and w. + image_pe (torch.Tensor): the positional encoding to add to the image. Must + have the same shape as image_embedding. + point_embedding (torch.Tensor): the embedding to add to the query points. + Must have shape B x N_points x embedding_dim for any N_points. + + Returns: + torch.Tensor: the processed point_embedding + torch.Tensor: the processed image_embedding + """ + # BxCxHxW -> BxHWxC == B x N_image_tokens x C + bs, c, h, w = image_embedding.shape + image_embedding = image_embedding.flatten(2).permute(0, 2, 1) + image_pe = image_pe.flatten(2).permute(0, 2, 1) + + # Prepare queries + queries = point_embedding + keys = image_embedding + + # Apply transformer blocks and final layernorm + for layer in self.layers: + queries, keys = layer( + queries=queries, + keys=keys, + query_pe=point_embedding, + key_pe=image_pe, + ) + + # Apply the final attention layer from the points to the image + q = queries + point_embedding + k = keys + image_pe + attn_out = self.final_attn_token_to_image(q=q, k=k, v=keys) + queries = queries + attn_out + queries = self.norm_final_attn(queries) + + return queries, keys + + +class TwoWayAttentionBlock(nn.Module): + def __init__( + self, + embedding_dim: int, + num_heads: int, + mlp_dim: int = 2048, + activation: Type[nn.Module] = nn.ReLU, + attention_downsample_rate: int = 2, + skip_first_layer_pe: bool = False, + ) -> None: + """ + A transformer block with four layers: (1) self-attention of sparse + inputs, (2) cross attention of sparse inputs to dense inputs, (3) mlp + block on sparse inputs, and (4) cross attention of dense inputs to sparse + inputs. 
+ + Arguments: + embedding_dim (int): the channel dimension of the embeddings + num_heads (int): the number of heads in the attention layers + mlp_dim (int): the hidden dimension of the mlp block + activation (nn.Module): the activation of the mlp block + skip_first_layer_pe (bool): skip the PE on the first layer + """ + super().__init__() + self.self_attn = Attention(embedding_dim, num_heads) + self.norm1 = nn.LayerNorm(embedding_dim) + + self.cross_attn_token_to_image = Attention( + embedding_dim, num_heads, downsample_rate=attention_downsample_rate + ) + self.norm2 = nn.LayerNorm(embedding_dim) + + self.mlp = MLPBlock(embedding_dim, mlp_dim, activation) + self.norm3 = nn.LayerNorm(embedding_dim) + + self.norm4 = nn.LayerNorm(embedding_dim) + self.cross_attn_image_to_token = Attention( + embedding_dim, num_heads, downsample_rate=attention_downsample_rate + ) + + self.skip_first_layer_pe = skip_first_layer_pe + + def forward( + self, queries: Tensor, keys: Tensor, query_pe: Tensor, key_pe: Tensor + ) -> Tuple[Tensor, Tensor]: + # Self attention block + if self.skip_first_layer_pe: + queries = self.self_attn(q=queries, k=queries, v=queries) + else: + q = queries + query_pe + attn_out = self.self_attn(q=q, k=q, v=queries) + queries = queries + attn_out + queries = self.norm1(queries) + + # Cross attention block, tokens attending to image embedding + q = queries + query_pe + k = keys + key_pe + attn_out = self.cross_attn_token_to_image(q=q, k=k, v=keys) + queries = queries + attn_out + queries = self.norm2(queries) + + # MLP block + mlp_out = self.mlp(queries) + queries = queries + mlp_out + queries = self.norm3(queries) + + # Cross attention block, image embedding attending to tokens + q = queries + query_pe + k = keys + key_pe + attn_out = self.cross_attn_image_to_token(q=k, k=q, v=queries) + keys = keys + attn_out + keys = self.norm4(keys) + + return queries, keys + + +class Attention(nn.Module): + """ + An attention layer that allows for downscaling the size of the embedding + after projection to queries, keys, and values. + """ + + def __init__( + self, + embedding_dim: int, + num_heads: int, + downsample_rate: int = 1, + ) -> None: + super().__init__() + self.embedding_dim = embedding_dim + self.internal_dim = embedding_dim // downsample_rate + self.num_heads = num_heads + assert self.internal_dim % num_heads == 0, "num_heads must divide embedding_dim." 
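+        # q/k/v are projected down to internal_dim (= embedding_dim // downsample_rate) and split evenly across num_heads.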
+ + self.q_proj = nn.Linear(embedding_dim, self.internal_dim) + self.k_proj = nn.Linear(embedding_dim, self.internal_dim) + self.v_proj = nn.Linear(embedding_dim, self.internal_dim) + self.out_proj = nn.Linear(self.internal_dim, embedding_dim) + + def _separate_heads(self, x: Tensor, num_heads: int) -> Tensor: + b, n, c = x.shape + x = x.reshape(b, n, num_heads, c // num_heads) + return x.transpose(1, 2) # B x N_heads x N_tokens x C_per_head + + def _recombine_heads(self, x: Tensor) -> Tensor: + b, n_heads, n_tokens, c_per_head = x.shape + x = x.transpose(1, 2) + return x.reshape(b, n_tokens, n_heads * c_per_head) # B x N_tokens x C + + def forward(self, q: Tensor, k: Tensor, v: Tensor) -> Tensor: + # Input projections + q = self.q_proj(q) + k = self.k_proj(k) + v = self.v_proj(v) + + # Separate into heads + q = self._separate_heads(q, self.num_heads) + k = self._separate_heads(k, self.num_heads) + v = self._separate_heads(v, self.num_heads) + + # Attention + _, _, _, c_per_head = q.shape + attn = q @ k.permute(0, 1, 3, 2) # B x N_heads x N_tokens x N_tokens + attn = attn / math.sqrt(c_per_head) + attn = torch.softmax(attn, dim=-1) + + # Get output + out = attn @ v + out = self._recombine_heads(out) + out = self.out_proj(out) + + return out diff --git a/main.py b/main.py new file mode 100644 index 0000000000000000000000000000000000000000..a076d24bc7faee1bb544738c7329976e923ab807 --- /dev/null +++ b/main.py @@ -0,0 +1,205 @@ +import gradio as gr + +import numpy as np + +import torch +import torch.nn.functional as F +from PIL import Image + +# mm libs +from mmdet.registry import MODELS +from mmdet.structures import DetDataSample +from mmengine import Config, print_log +from mmengine.structures import InstanceData + +from PIL import ImageDraw + +IMG_SIZE = 1024 + +TITLE = "
🚀RAP-SAM: Towards Real-Time All-Purpose Segment Anything
" +CSS = "h1 { text-align: center } .about { text-align: justify; padding-left: 10%; padding-right: 10%; }" + +model_cfg = Config.fromfile('app/configs/rap_sam_r50_12e_adaptor.py') + +model = MODELS.build(model_cfg.model) +device = torch.device("cuda" if torch.cuda.is_available() else "cpu") +model = model.to(device=device) +model = model.eval() +model.init_weights() + +mean = torch.tensor([123.675, 116.28, 103.53], device=device)[:, None, None] +std = torch.tensor([58.395, 57.12, 57.375], device=device)[:, None, None] + + +class IMGState: + def __init__(self): + self.img = None + self.selected_points = [] + self.available_to_set = True + + def set_img(self, img): + self.img = img + self.available_to_set = False + + def clear(self): + self.img = None + self.selected_points = [] + self.available_to_set = True + + def clean(self): + self.selected_points = [] + + @property + def available(self): + return self.available_to_set + + @classmethod + def cls_clean(cls, state): + state.clean() + return Image.fromarray(state.img), None + + @classmethod + def cls_clear(cls, state): + state.clear() + return None, None + + +def store_img(img, img_state): + w, h = img.size + scale = IMG_SIZE / max(w, h) + new_w = int(w * scale) + new_h = int(h * scale) + img = img.resize((new_w, new_h), resample=Image.Resampling.BILINEAR) + img_numpy = np.array(img) + img_state.set_img(img_numpy) + print_log(f"Successfully loaded an image with size {new_w} x {new_h}", logger='current') + + return img, None + + +def get_points_with_draw(image, img_state, evt: gr.SelectData): + x, y = evt.index[0], evt.index[1] + print_log(f"Point: {x}_{y}", logger='current') + point_radius, point_color = 10, (97, 217, 54) + + img_state.selected_points.append([x, y]) + if len(img_state.selected_points) > 0: + img_state.selected_points = img_state.selected_points[-1:] + image = Image.fromarray(img_state.img) + + draw = ImageDraw.Draw(image) + draw.ellipse( + [(x - point_radius, y - point_radius), (x + point_radius, y + point_radius)], + fill=point_color, + ) + return image + + +def segment_point(image, img_state): + output_img = img_state.img + h, w = output_img.shape[:2] + + img_tensor = torch.tensor(output_img, device=device, dtype=torch.float32).permute((2, 0, 1))[None] + img_tensor = (img_tensor - mean) / std + + im_w = w if w % 32 == 0 else w // 32 * 32 + 32 + im_h = h if h % 32 == 0 else h // 32 * 32 + 32 + img_tensor = F.pad(img_tensor, (0, im_w - w, 0, im_h - h), 'constant', 0) + + if len(img_state.selected_points) > 0: + input_points = torch.tensor(img_state.selected_points, dtype=torch.float32, device=device) + batch_data_samples = [DetDataSample()] + selected_point = torch.cat([input_points - 3, input_points + 3], 1) + gt_instances = InstanceData( + point_coords=selected_point, + ) + pb_labels = torch.ones(len(gt_instances), dtype=torch.long, device=device) + gt_instances.pb_labels = pb_labels + batch_data_samples[0].gt_instances_collected = gt_instances + batch_data_samples[0].set_metainfo(dict(batch_input_shape=(im_h, im_w))) + batch_data_samples[0].set_metainfo(dict(img_shape=(h, w))) + else: + batch_data_samples = [DetDataSample()] + batch_data_samples[0].set_metainfo(dict(batch_input_shape=(im_h, im_w))) + batch_data_samples[0].set_metainfo(dict(img_shape=(h, w))) + with torch.no_grad(): + masks, cls_pred = model.predict_with_point(img_tensor, batch_data_samples) + + masks = masks[0, 0, :h, :w] + masks = masks > 0. 
+ rgb_shape = tuple(list(masks.shape) + [3]) + color = np.zeros(rgb_shape, dtype=np.uint8) + color[masks] = np.array([97, 217, 54]) + # color[masks] = np.array([217, 90, 54]) + output_img = (output_img * 0.7 + color * 0.3).astype(np.uint8) + + output_img = Image.fromarray(output_img) + return image, output_img + + +def register_title(): + with gr.Row(): + with gr.Column(scale=1): + gr.Markdown(TITLE) + + +def register_point_mode(): + with gr.Tab("Point mode"): + img_state = gr.State(IMGState()) + with gr.Row(variant="panel"): + with gr.Column(scale=1): + img_p = gr.Image(label="Input Image", type="pil") + + with gr.Column(scale=1): + segm_p = gr.Image(label="Segment", interactive=False, type="pil") + + with gr.Row(): + with gr.Column(): + with gr.Row(): + with gr.Column(): + segment_btn = gr.Button("Segment", variant="primary") + clean_btn = gr.Button("Clean Prompts", variant="secondary") + + img_p.upload( + store_img, + [img_p, img_state], + [img_p, segm_p] + ) + + img_p.select( + get_points_with_draw, + [img_p, img_state], + img_p + ) + + segment_btn.click( + segment_point, + [img_p, img_state], + [img_p, segm_p] + ) + + clean_btn.click( + IMGState.cls_clean, + img_state, + [img_p, segm_p] + ) + + img_p.clear( + IMGState.cls_clear, + img_state, + [img_p, segm_p] + ) + + +def build_demo(): + with gr.Blocks(css=CSS, title="RAP-SAM") as _demo: + register_title() + register_point_mode() + return _demo + + +if __name__ == '__main__': + demo = build_demo() + + demo.queue(api_open=False) + demo.launch(server_name='0.0.0.0') diff --git a/models/rapsam_r50_12e.pth b/models/rapsam_r50_12e.pth new file mode 100644 index 0000000000000000000000000000000000000000..9ac0ff0301cd023d9c8d8306487375696eca2baf --- /dev/null +++ b/models/rapsam_r50_12e.pth @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:1a73c67e71767059ac2a5222c506721c79dd4c70086af17b00e38f513679f3de +size 189436775 diff --git a/requirements.txt b/requirements.txt new file mode 100644 index 0000000000000000000000000000000000000000..d83463f2cccded00a9c26cd34c423b1654b9e68d --- /dev/null +++ b/requirements.txt @@ -0,0 +1,10 @@ +-f https://download.openmmlab.com/mmcv/dist/cu121/torch2.1.0/index.html + +torch==2.1.2 +torchvision +mmengine==0.10.2 +mmcv==2.1.0 +mmdet==3.3.0 +ftfy +timm +regex
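For readers who want to exercise the point-prompt path outside the Gradio UI, below is a minimal sketch, not part of the diff, that mirrors segment_point in main.py: build the model from the config and checkpoint added above ('app/configs/rap_sam_r50_12e_adaptor.py', 'models/rapsam_r50_12e.pth'), normalize with the same mean/std, pad to a multiple of 32, wrap a click as a small box in gt_instances_collected, and call model.predict_with_point. The input image path ('demo.jpg') and the click coordinates are hypothetical placeholders.

import numpy as np
import torch
import torch.nn.functional as F
from PIL import Image
from mmdet.registry import MODELS
from mmdet.structures import DetDataSample
from mmengine import Config
from mmengine.structures import InstanceData

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
cfg = Config.fromfile('app/configs/rap_sam_r50_12e_adaptor.py')
model = MODELS.build(cfg.model).to(device).eval()
model.init_weights()  # loads models/rapsam_r50_12e.pth via the config's init_cfg

# Load and resize so the longest side is 1024, as store_img does.
pil = Image.open('demo.jpg').convert('RGB')  # hypothetical input image
scale = 1024 / max(pil.size)
pil = pil.resize((int(pil.width * scale), int(pil.height * scale)), Image.Resampling.BILINEAR)
img = np.array(pil)
h, w = img.shape[:2]

# Same normalization and pad-to-multiple-of-32 as segment_point.
mean = torch.tensor([123.675, 116.28, 103.53], device=device)[:, None, None]
std = torch.tensor([58.395, 57.12, 57.375], device=device)[:, None, None]
x = torch.tensor(img, device=device, dtype=torch.float32).permute(2, 0, 1)[None]
x = (x - mean) / std
im_w = w if w % 32 == 0 else w // 32 * 32 + 32
im_h = h if h % 32 == 0 else h // 32 * 32 + 32
x = F.pad(x, (0, im_w - w, 0, im_h - h), 'constant', 0)

# Wrap a single click (x, y) as a small box, exactly as segment_point does.
click = torch.tensor([[200.0, 300.0]], device=device)  # hypothetical click location
gt = InstanceData(point_coords=torch.cat([click - 3, click + 3], dim=1))
gt.pb_labels = torch.ones(len(gt), dtype=torch.long, device=device)
sample = DetDataSample()
sample.gt_instances_collected = gt
sample.set_metainfo(dict(batch_input_shape=(im_h, im_w), img_shape=(h, w)))

with torch.no_grad():
    masks, cls_pred = model.predict_with_point(x, [sample])
binary_mask = (masks[0, 0, :h, :w] > 0).cpu().numpy()  # boolean mask at the resized resolution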