# NOTE: "Spaces: Sleeping" is web-page banner text that was captured together
# with this listing; it is not Python code.
from typing import Optional, List | |
import torch | |
import torch.distributed as dist | |
import torch.nn as nn | |
import torch.nn.functional as F | |
from mmdet.registry import MODELS | |
from mmengine.model import BaseModule | |
from mmengine.dist import get_dist_info | |
from mmengine.logging import MMLogger | |
from timm.layers import resample_abs_pos_embed | |
import ext.open_clip as open_clip | |
from seg.models.utils.load_checkpoint import load_checkpoint_with_prefix | |
def flatten_permute(x):
    """Flatten every dimension after the second one and move it in front of dim 1.

    Args:
        x: Tensor with at least two dimensions, read here as
            ``[batch, width, *rest]``.

    Returns:
        Tensor of shape ``[batch, prod(rest), width]``.
    """
    x = x.reshape(x.shape[0], x.shape[1], -1)  # shape = [*, width, grid ** 2]
    x = x.permute(0, 2, 1)  # shape = [*, grid ** 2, width]
    return x
class OpenCLIPBackbone(BaseModule):
    """Multi-stage visual backbone assembled from an OpenCLIP image tower.

    Please refer to:
    https://github.com/mlfoundations/open_clip/tree/5f7892b672b21e6853d0f6c11b18dda9bcf36c8d#pretrained-model-interface
    for the supported models and checkpoints.

    Args:
        img_size: Accepted in the signature but not read by this class.
        model_name: OpenCLIP model name. Names containing ``convnext_``
            (``_base`` / ``_large`` / ``_xxlarge``), the ResNet names
            ``RN50`` / ``RN101`` (optionally with ``-quickgelu``),
            ``RN50x4`` / ``RN50x16`` / ``RN50x64`` and ``ViT-L-14`` are
            recognised (case-insensitive).
        fix: If True, every parameter is frozen, the module is kept in eval
            mode and ``forward`` runs under ``torch.no_grad()``.
        fix_layers: Optional collection of stage indices to freeze when
            ``fix`` is False. Freezing stage 0 also freezes the stem.
        init_cfg: Dict with a ``'type'`` (``'clip_pretrain'``,
            ``'image_pretrain'`` or ``'Pretrained'``) and a ``'checkpoint'``
            entry. ``'Pretrained'`` additionally reads ``'prefix'``.

    Raises:
        AssertionError: If ``init_cfg`` is missing or its type is unsupported.
        NotImplementedError: If ``model_name`` is not recognised.
    """
    STAGES = 4

    # Initialisation modes accepted in ``init_cfg['type']``.
    _INIT_TYPES = ('clip_pretrain', 'image_pretrain', 'Pretrained')

    def __init__(
            self,
            img_size: int = 1024,
            model_name: str = '',
            fix: bool = True,
            fix_layers: Optional[List] = None,
            init_cfg=None,
    ):
        # Checked in two steps so that a missing init_cfg yields a readable
        # assertion instead of a TypeError raised while formatting the message.
        assert init_cfg is not None, "init_cfg must be provided."
        assert init_cfg['type'] in self._INIT_TYPES, f"{init_cfg['type']} is not supported."
        pretrained = init_cfg['checkpoint']
        # The parent receives no init_cfg; the config is kept on the instance
        # and only logged by ``init_weights``.
        super().__init__(init_cfg=None)
        self.init_cfg = init_cfg
        self.logger = MMLogger.get_current_instance()

        rank, world_size = get_dist_info()
        if world_size > 1:
            # Rank 0 builds the model once while the other ranks wait at the
            # barrier (presumably so any checkpoint download/caching happens
            # only once). Nothing is pre-built for the 'Pretrained' type.
            if rank == 0 and init_cfg['type'] in ('clip_pretrain', 'image_pretrain'):
                self._create_clip_model(model_name, pretrained)
            dist.barrier()

        # Get the clip model
        clip_model = self._create_clip_model(model_name, pretrained)

        self.out_indices = (0, 1, 2, 3)
        model_type, output_channels, feat_size = self._resolve_arch(model_name)
        if model_type == 'vit':
            assert not clip_model.visual.input_patchnorm
            assert clip_model.visual.attn_pool is None

        self.model_name = model_name
        self.fix = fix
        self.model_type = model_type
        self.output_channels = output_channels
        self.feat_size = feat_size

        # Get the visual model. Within each branch the attributes are assigned
        # in the order stem -> avgpool -> stages -> norm_pre -> head so that
        # submodule registration order stays stable.
        visual = clip_model.visual
        self.res_layers = []
        if self.model_type == 'resnet':
            self.stem = nn.Sequential(*[
                visual.conv1, visual.bn1, visual.act1,
                visual.conv2, visual.bn2, visual.act2,
                visual.conv3, visual.bn3, visual.act3,
            ])
            self.avgpool = visual.avgpool
            for i in range(self.STAGES):
                layer_name = f'layer{i + 1}'
                self.add_module(layer_name, getattr(visual, layer_name))
                self.res_layers.append(layer_name)
            self.norm_pre = nn.Identity()
            self.head = visual.attnpool
        elif self.model_type == 'convnext':
            self.stem = visual.trunk.stem
            self.avgpool = nn.Identity()
            for i in range(self.STAGES):
                layer_name = f'layer{i + 1}'
                self.add_module(layer_name, visual.trunk.stages[i])
                self.res_layers.append(layer_name)
            self.norm_pre = visual.trunk.norm_pre
            self.head = nn.Sequential(*[
                visual.trunk.head,
                visual.head,
            ])
        elif self.model_type == 'vit':
            self.stem = visual.conv1
            # Plain function, not a module: turns the patch grid into tokens.
            self.avgpool = flatten_permute
            self.t_class_embedding = visual.class_embedding
            self.t_positional_embedding = visual.positional_embedding
            self.t_ln_pre_trans = visual.ln_pre
            self.t_transformer = visual.transformer
            self.norm_pre = nn.Identity()
            self.head = visual.ln_post
        else:
            raise ValueError(f"Unknown model type: {self.model_type}")

        if self.init_cfg['type'] == 'Pretrained':
            checkpoint_path = pretrained
            state_dict = load_checkpoint_with_prefix(checkpoint_path, prefix=self.init_cfg['prefix'])
            self.load_state_dict(state_dict, strict=True)

        self.fix_layers = fix_layers

        if not self.fix:
            self.train()
            # norm_pre and head are always frozen, even in the trainable setup.
            for frozen_module in (self.norm_pre, self.head):
                for param in frozen_module.parameters():
                    param.requires_grad = False
            if self.fix_layers is not None:
                for i, layer_name in enumerate(self.res_layers):
                    if i in self.fix_layers:
                        for param in getattr(self, layer_name).parameters():
                            param.requires_grad = False
                        if i == 0:
                            # A frozen first stage implies a frozen stem.
                            for param in self.stem.parameters():
                                param.requires_grad = False

        if self.fix:
            self.train(mode=False)
            for param in self.parameters():
                param.requires_grad = False

    def _create_clip_model(self, model_name, pretrained):
        """Build the OpenCLIP model according to ``self.init_cfg['type']``."""
        init_type = self.init_cfg['type']
        if init_type == 'clip_pretrain':
            return open_clip.create_model_from_pretrained(model_name, pretrained=pretrained,
                                                          return_transform=False, logger=self.logger)
        if init_type == 'image_pretrain':
            return open_clip.create_model(model_name, pretrained_image=True, logger=self.logger)
        if init_type == 'Pretrained':
            # Weights are loaded afterwards from the prefixed checkpoint.
            return open_clip.create_model(model_name, pretrained_image=False, logger=self.logger)
        raise NotImplementedError

    @staticmethod
    def _resolve_arch(model_name):
        """Map a model name to ``(model_type, output_channels, feat_size)``."""
        model_name_lower = model_name.lower()
        unsupported = NotImplementedError(f"{model_name} not supported yet.")

        if 'convnext_' in model_name_lower:
            if '_base' in model_name_lower:
                return 'convnext', [128, 256, 512, 1024], 0
            if '_large' in model_name_lower:
                return 'convnext', [192, 384, 768, 1536], 0
            if '_xxlarge' in model_name_lower:
                return 'convnext', [384, 768, 1536, 3072], 0
            raise unsupported

        if 'rn' in model_name_lower:
            if model_name_lower.replace('-quickgelu', '') in ['rn50', 'rn101']:
                return 'resnet', [256, 512, 1024, 2048], 7
            if model_name_lower == 'rn50x4':
                return 'resnet', [320, 640, 1280, 2560], 9
            if model_name_lower == 'rn50x16':
                return 'resnet', [384, 768, 1536, 3072], 12
            if model_name_lower == 'rn50x64':
                return 'resnet', [512, 1024, 2048, 4096], 14
            raise unsupported

        if "vit" in model_name_lower:
            if model_name_lower == 'vit-l-14':
                return 'vit', [1024, 1024, 1024, 1024], 0
            raise unsupported

        raise unsupported

    def init_weights(self):
        """Log the init config; weights are already loaded in ``__init__``."""
        self.logger.info(f"Init Config for {self.model_name}")
        self.logger.info(self.init_cfg)

    def train(self: torch.nn.Module, mode: bool = True) -> torch.nn.Module:
        """Set the training mode while keeping frozen parts in eval mode.

        Args:
            mode: Requested training mode. Ignored (forced to False) when
                ``self.fix`` is True.

        Returns:
            The module itself.

        Raises:
            ValueError: If ``mode`` is not a bool.
        """
        if not isinstance(mode, bool):
            raise ValueError("training mode is expected to be boolean")
        if self.fix:
            super().train(mode=False)
        else:
            super().train(mode=mode)
            if self.fix_layers is not None:
                for i, layer_name in enumerate(self.res_layers):
                    if i in self.fix_layers:
                        getattr(self, layer_name).train(mode=False)
                        if i == 0:
                            self.stem.train(mode=False)
        return self

    def forward_func(self, x):
        """Run the backbone and return a tuple of feature maps.

        For the ViT model the transformer output (class token dropped) is
        reshaped to the patch grid produced by the stem and resampled with
        scale factors ``2 ** (2 - i)`` for ``i`` in ``range(STAGES)``. For the
        other models the output of every stage listed in ``out_indices`` is
        collected.
        """
        x = self.stem(x)
        h, w = x.shape[-2:]
        x = self.avgpool(x)
        outs = []
        if self.model_type == 'vit':
            # Broadcast the class embedding over the batch and prepend it.
            cls_tokens = self.t_class_embedding.to(x.dtype) + torch.zeros(
                x.shape[0], 1, x.shape[-1], dtype=x.dtype, device=x.device)
            x = torch.cat([cls_tokens, x], dim=1)  # shape = [*, grid ** 2 + 1, width]
            # Resample the positional embedding to the current (h, w) grid;
            # the single prefix token is the class token.
            new_pos_embed = resample_abs_pos_embed(
                self.t_positional_embedding[None],
                [h, w],
                num_prefix_tokens=1
            )
            x = x + new_pos_embed.to(x.dtype)
            x = self.t_ln_pre_trans(x)
            x = x.permute(1, 0, 2)  # NLD -> LND
            x = self.t_transformer(x)
            x = x.permute(1, 0, 2)  # LND -> NLD
            x = x[:, 1:]  # drop the class token
            x = x.permute(0, 2, 1).unflatten(2, (h, w))  # BCHW
            for i in range(self.STAGES):
                outs.append(
                    F.interpolate(
                        x,
                        scale_factor=2 ** (2 - i),
                        mode='bilinear',
                        align_corners=False
                    )
                )
        else:
            for i, layer_name in enumerate(self.res_layers):
                res_layer = getattr(self, layer_name)
                x = res_layer(x).contiguous()
                if i in self.out_indices:
                    outs.append(x)
        return tuple(outs)

    def get_clip_feature(self, backbone_feat):
        """Apply the pre-head normalisation to a backbone feature.

        ResNet features are returned unchanged; ConvNeXt features go through
        ``norm_pre``.

        Raises:
            NotImplementedError: For any other model type.
        """
        if self.model_type == 'resnet':
            return backbone_feat
        elif self.model_type == 'convnext':
            return self.norm_pre(backbone_feat)
        raise NotImplementedError

    def forward_feat(self, features):
        """Project features through the CLIP head.

        ConvNeXt: ``features`` is unpacked as ``[batch, num_query, channel]``,
        reshaped to ``[batch * num_query, channel, 1, 1]`` for the head and
        viewed back to ``[batch, num_query, -1]``.
        ResNet: ``features`` must be 4-D and is passed to the head directly.

        Raises:
            NotImplementedError: For any other model type (previously this
                fell through and returned ``None``).
        """
        if self.model_type == 'convnext':
            batch, num_query, channel = features.shape
            features = features.reshape(batch * num_query, channel, 1, 1)
            features = self.head(features)
            return features.view(batch, num_query, features.shape[-1])
        elif self.model_type == 'resnet':
            # The unpacking is kept only as a check that the input is 4-D.
            _num_query, _channel, _height, _width = features.shape
            features = self.head(features)
            return features
        raise NotImplementedError(f"forward_feat is not implemented for {self.model_type}.")

    def forward(self, x):
        """Run ``forward_func``, without gradient tracking when ``fix`` is set."""
        if self.fix:
            with torch.no_grad():
                outs = self.forward_func(x)
        else:
            outs = self.forward_func(x)
        return outs

    def get_text_model(self):
        """Create the text-side module for the same model name and init config."""
        return OpenCLIPBackboneText(
            self.model_name,
            init_cfg=self.init_cfg
        )
class OpenCLIPBackboneText(BaseModule):
    """Text encoder assembled from the text side of an OpenCLIP model.

    Args:
        model_name: OpenCLIP model name, also used to pick the tokenizer.
        init_cfg: Dict whose ``'type'`` must be ``'clip_pretrain'`` and whose
            ``'checkpoint'`` entry is passed to OpenCLIP as ``pretrained``.

    Raises:
        AssertionError: If ``init_cfg`` is missing or its type is not
            ``'clip_pretrain'``.
    """

    def __init__(
            self,
            model_name: str = '',
            init_cfg=None,
    ):
        # Checked in two steps so that a missing init_cfg yields a readable
        # assertion instead of a TypeError raised while formatting the message.
        assert init_cfg is not None, "init_cfg must be provided."
        assert init_cfg['type'] == 'clip_pretrain', f"{init_cfg['type']} is not supported."
        pretrained = init_cfg['checkpoint']
        # The parent receives no init_cfg; the config is kept on the instance
        # and only logged by ``init_weights``.
        super().__init__(init_cfg=None)
        self.init_cfg = init_cfg
        self.logger = MMLogger.get_current_instance()

        rank, world_size = get_dist_info()
        if world_size > 1:
            # Rank 0 builds the model once while the other ranks wait at the
            # barrier (presumably so any checkpoint download/caching happens
            # only once).
            if rank == 0:
                open_clip.create_model_from_pretrained(model_name, pretrained=pretrained, return_transform=False,
                                                       logger=self.logger)
            dist.barrier()

        # Get the clip model
        clip_model = open_clip.create_model_from_pretrained(model_name, pretrained=pretrained, return_transform=False,
                                                            logger=self.logger)

        # Get the textual model
        self.text_tokenizer = open_clip.get_tokenizer(model_name)
        self.text_transformer = clip_model.transformer
        self.text_token_embedding = clip_model.token_embedding
        self.text_pe = clip_model.positional_embedding
        self.text_ln_final = clip_model.ln_final
        self.text_proj = clip_model.text_projection
        self.register_buffer('text_attn_mask', clip_model.attn_mask)

        # dtype the embeddings are cast to inside ``forward``.
        self.param_dtype = torch.float32
        self.model_name = model_name

    def init_weights(self):
        """Log the init config; weights are already loaded in ``__init__``."""
        self.logger.info(f"Init Config for {self.model_name}")
        self.logger.info(self.init_cfg)

    # Copied from
    # https://github.com/openai/CLIP/blob/d50d76daa670286dd6cacf3bcd80b5e4823fc8e1/clip/model.py#L343
    def forward(self, text):
        """Encode ``text`` into projected embeddings.

        Args:
            text: Input accepted by the OpenCLIP tokenizer for this model.

        Returns:
            The final-layer feature at each sequence's end-of-text position,
            multiplied by the text projection.
        """
        # Tokens are moved to wherever the projection parameter lives.
        text_tokens = self.text_tokenizer(text).to(device=self.text_proj.device)
        x = self.text_token_embedding(text_tokens).to(self.param_dtype)
        x = x + self.text_pe.to(self.param_dtype)
        x = x.permute(1, 0, 2)  # NLD -> LND
        x = self.text_transformer(x, attn_mask=self.text_attn_mask)
        x = x.permute(1, 0, 2)  # LND -> NLD
        x = self.text_ln_final(x)  # [batch_size, n_ctx, transformer.width]
        # take features from the eot embedding (eot_token is the highest number in each sequence)
        x = x[torch.arange(x.shape[0]), text_tokens.argmax(dim=-1)] @ self.text_proj
        return x