# OMG_Seg/seg/models/backbones/openclip_backbone.py
from typing import Optional, List
import torch
import torch.distributed as dist
import torch.nn as nn
import torch.nn.functional as F
from mmdet.registry import MODELS
from mmengine.model import BaseModule
from mmengine.dist import get_dist_info
from mmengine.logging import MMLogger
from timm.layers import resample_abs_pos_embed
import ext.open_clip as open_clip
from seg.models.utils.load_checkpoint import load_checkpoint_with_prefix
def flatten_permute(x):
x = x.reshape(x.shape[0], x.shape[1], -1) # shape = [*, width, grid ** 2]
x = x.permute(0, 2, 1) # shape = [*, grid ** 2, width]
return x
@MODELS.register_module()
class OpenCLIPBackbone(BaseModule):
"""OpenCLIPBackbone,
Please refer to:
https://github.com/mlfoundations/open_clip/tree/5f7892b672b21e6853d0f6c11b18dda9bcf36c8d#pretrained-model-interface
for the supported models and checkpoints.
"""
STAGES = 4
def __init__(
self,
img_size: int = 1024,
model_name: str = '',
fix: bool = True,
fix_layers: Optional[List] = None,
init_cfg=None,
):
assert init_cfg is not None and init_cfg['type'] in ['clip_pretrain', 'image_pretrain', 'Pretrained'], \
f"{init_cfg['type']} is not supported."
pretrained = init_cfg['checkpoint']
super().__init__(init_cfg=None)
self.init_cfg = init_cfg
self.logger = MMLogger.get_current_instance()
rank, world_size = get_dist_info()
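        # In distributed training, let rank 0 download/cache the pretrained weights first;
        # the other ranks wait at the barrier below and then load from the local cache.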
if world_size > 1:
if rank == 0:
if init_cfg['type'] == 'clip_pretrain':
_ = open_clip.create_model_from_pretrained(model_name, pretrained=pretrained,
return_transform=False, logger=self.logger)
elif init_cfg['type'] == 'image_pretrain':
_ = open_clip.create_model(model_name, pretrained_image=True, logger=self.logger)
else:
pass
dist.barrier()
# Get the clip model
if init_cfg['type'] == 'clip_pretrain':
clip_model = open_clip.create_model_from_pretrained(model_name, pretrained=pretrained,
return_transform=False, logger=self.logger)
elif init_cfg['type'] == 'image_pretrain':
clip_model = open_clip.create_model(model_name, pretrained_image=True, logger=self.logger)
elif init_cfg['type'] == 'Pretrained':
clip_model = open_clip.create_model(model_name, pretrained_image=False, logger=self.logger)
else:
raise NotImplementedError
self.out_indices = (0, 1, 2, 3)
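        # Infer the backbone family and its per-stage output channels from the model name.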
model_name_lower = model_name.lower()
if 'convnext_' in model_name_lower:
model_type = 'convnext'
if '_base' in model_name_lower:
output_channels = [128, 256, 512, 1024]
feat_size = 0
elif '_large' in model_name_lower:
output_channels = [192, 384, 768, 1536]
feat_size = 0
elif '_xxlarge' in model_name_lower:
output_channels = [384, 768, 1536, 3072]
feat_size = 0
else:
raise NotImplementedError(f"{model_name} not supported yet.")
elif 'rn' in model_name_lower:
model_type = 'resnet'
if model_name_lower.replace('-quickgelu', '') in ['rn50', 'rn101']:
output_channels = [256, 512, 1024, 2048]
feat_size = 7
elif model_name_lower == 'rn50x4':
output_channels = [320, 640, 1280, 2560]
feat_size = 9
elif model_name_lower == 'rn50x16':
output_channels = [384, 768, 1536, 3072]
feat_size = 12
elif model_name_lower == 'rn50x64':
output_channels = [512, 1024, 2048, 4096]
feat_size = 14
else:
raise NotImplementedError(f"{model_name} not supported yet.")
elif "vit" in model_name_lower:
model_type = 'vit'
if model_name_lower == 'vit-l-14':
output_channels = [1024, 1024, 1024, 1024]
feat_size = 0
assert not clip_model.visual.input_patchnorm
assert clip_model.visual.attn_pool is None
else:
raise NotImplementedError(f"{model_name} not supported yet.")
else:
raise NotImplementedError(f"{model_name} not supported yet.")
self.model_name = model_name
self.fix = fix
self.model_type = model_type
self.output_channels = output_channels
self.feat_size = feat_size
# Get the visual model
if self.model_type == 'resnet':
self.stem = nn.Sequential(*[
clip_model.visual.conv1, clip_model.visual.bn1, clip_model.visual.act1,
clip_model.visual.conv2, clip_model.visual.bn2, clip_model.visual.act2,
clip_model.visual.conv3, clip_model.visual.bn3, clip_model.visual.act3,
])
elif self.model_type == 'convnext':
self.stem = clip_model.visual.trunk.stem
elif self.model_type == 'vit':
self.stem = clip_model.visual.conv1
else:
raise ValueError
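        # 'avgpool' adapts the stem output per family: ResNet keeps its own avgpool,
        # ConvNeXt needs nothing extra, and ViT flattens the patch grid into a token sequence.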
if self.model_type == 'resnet':
self.avgpool = clip_model.visual.avgpool
elif self.model_type == 'convnext':
self.avgpool = nn.Identity()
elif self.model_type == 'vit':
self.avgpool = flatten_permute
else:
raise ValueError
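        # ResNet/ConvNeXt expose four stages; the ViT instead keeps its class/positional
        # embeddings and the full transformer, and a feature pyramid is synthesised in forward_func.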
self.res_layers = []
if self.model_type in ['vit']:
self.t_class_embedding = clip_model.visual.class_embedding
self.t_positional_embedding = clip_model.visual.positional_embedding
self.t_ln_pre_trans = clip_model.visual.ln_pre
self.t_transformer = clip_model.visual.transformer
else:
for i in range(self.STAGES):
if self.model_type == 'resnet':
layer_name = f'layer{i + 1}'
layer = getattr(clip_model.visual, layer_name)
elif self.model_type == 'convnext':
layer_name = f'layer{i + 1}'
layer = clip_model.visual.trunk.stages[i]
else:
raise ValueError
self.add_module(layer_name, layer)
self.res_layers.append(layer_name)
if self.model_type == 'resnet':
self.norm_pre = nn.Identity()
elif self.model_type == 'convnext':
self.norm_pre = clip_model.visual.trunk.norm_pre
elif self.model_type == 'vit':
self.norm_pre = nn.Identity()
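        # 'head' keeps the remaining CLIP layers (attnpool / trunk+proj head / ln_post); it is
        # used by forward_feat (ResNet/ConvNeXt) to map pooled features into the CLIP embedding space.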
if self.model_type == 'resnet':
self.head = clip_model.visual.attnpool
elif self.model_type == 'convnext':
self.head = nn.Sequential(*[
clip_model.visual.trunk.head,
clip_model.visual.head,
])
elif self.model_type == 'vit':
self.head = clip_model.visual.ln_post
if self.init_cfg['type'] == 'Pretrained':
checkpoint_path = pretrained
state_dict = load_checkpoint_with_prefix(checkpoint_path, prefix=self.init_cfg['prefix'])
self.load_state_dict(state_dict, strict=True)
self.fix_layers = fix_layers
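        # Freezing policy: norm_pre and the CLIP head are always frozen; stages listed in
        # fix_layers are additionally frozen (the stem is frozen together with stage 0),
        # and fix=True freezes the entire backbone.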
if not self.fix:
self.train()
for name, param in self.norm_pre.named_parameters():
param.requires_grad = False
for name, param in self.head.named_parameters():
param.requires_grad = False
if self.fix_layers is not None:
for i, layer_name in enumerate(self.res_layers):
if i in self.fix_layers:
res_layer = getattr(self, layer_name)
for name, param in res_layer.named_parameters():
param.requires_grad = False
if i == 0:
for name, param in self.stem.named_parameters():
param.requires_grad = False
if self.fix:
self.train(mode=False)
for name, param in self.named_parameters():
param.requires_grad = False
def init_weights(self):
self.logger.info(f"Init Config for {self.model_name}")
self.logger.info(self.init_cfg)
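    # Keep frozen sub-modules in eval mode even when the surrounding model is switched to train().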
def train(self: torch.nn.Module, mode: bool = True) -> torch.nn.Module:
if not isinstance(mode, bool):
raise ValueError("training mode is expected to be boolean")
if self.fix:
super().train(mode=False)
else:
super().train(mode=mode)
if self.fix_layers is not None:
for i, layer_name in enumerate(self.res_layers):
if i in self.fix_layers:
res_layer = getattr(self, layer_name)
res_layer.train(mode=False)
if i == 0:
self.stem.train(mode=False)
return self
def forward_func(self, x):
x = self.stem(x)
h, w = x.shape[-2:]
x = self.avgpool(x)
outs = []
if self.model_type == 'vit':
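            # Prepend the class token (broadcast over the batch) to the patch tokens.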
x = torch.cat(
[self.t_class_embedding.to(x.dtype) +
torch.zeros(x.shape[0], 1, x.shape[-1], dtype=x.dtype, device=x.device),
x], dim=1
) # shape = [*, grid ** 2 + 1, width]
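            # Resample the pretrained positional embedding to the current token grid so that
            # inputs larger than CLIP's training resolution are supported.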
new_pos_embed = resample_abs_pos_embed(
self.t_positional_embedding[None],
[h, w],
num_prefix_tokens=1
)
x = x + new_pos_embed.to(x.dtype)
x = self.t_ln_pre_trans(x)
x = x.permute(1, 0, 2) # NLD -> LND
x = self.t_transformer(x)
x = x.permute(1, 0, 2) # LND -> NLD
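            # Drop the class token and fold the patch tokens back into a BCHW feature map.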
x = x[:, 1:]
x = x.permute(0, 2, 1).unflatten(2, (h, w)) # BCHW
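            # A plain ViT is single-resolution, so up/down-sample the map by factors 4, 2, 1, 0.5
            # to mimic the 4-level pyramid produced by the CNN backbones.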
for i in range(self.STAGES):
outs.append(
F.interpolate(
x,
scale_factor=2 ** (2 - i),
mode='bilinear',
align_corners=False
)
)
else:
for i, layer_name in enumerate(self.res_layers):
res_layer = getattr(self, layer_name)
x = res_layer(x).contiguous()
if i in self.out_indices:
outs.append(x)
return tuple(outs)
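    # Return the backbone feature in the form expected by forward_feat; ConvNeXt applies its
    # pre-head norm first, and ViT is not supported here.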
def get_clip_feature(self, backbone_feat):
if self.model_type == 'resnet':
return backbone_feat
elif self.model_type == 'convnext':
return self.norm_pre(backbone_feat)
raise NotImplementedError
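    # Run pooled per-query features through the frozen CLIP head to obtain embeddings in the
    # shared CLIP image-text space.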
def forward_feat(self, features):
if self.model_type == 'convnext':
batch, num_query, channel = features.shape
features = features.reshape(batch * num_query, channel, 1, 1)
features = self.head(features)
return features.view(batch, num_query, features.shape[-1])
elif self.model_type == 'resnet':
        num_query, channel, _, _ = features.shape
features = self.head(features)
return features
def forward(self, x):
if self.fix:
with torch.no_grad():
outs = self.forward_func(x)
else:
outs = self.forward_func(x)
return outs
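    # Build the paired text encoder from the same init_cfg so image and text embeddings share
    # the same CLIP space.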
def get_text_model(self):
return OpenCLIPBackboneText(
self.model_name,
init_cfg=self.init_cfg
)
@MODELS.register_module()
class OpenCLIPBackboneText(BaseModule):
def __init__(
self,
model_name: str = '',
init_cfg=None,
):
assert init_cfg is not None and init_cfg['type'] == 'clip_pretrain', f"{init_cfg['type']} is not supported."
pretrained = init_cfg['checkpoint']
super().__init__(init_cfg=None)
self.init_cfg = init_cfg
self.logger = MMLogger.get_current_instance()
rank, world_size = get_dist_info()
if world_size > 1:
if rank == 0:
_ = open_clip.create_model_from_pretrained(model_name, pretrained=pretrained, return_transform=False,
logger=self.logger)
else:
pass
dist.barrier()
# Get the clip model
clip_model = open_clip.create_model_from_pretrained(model_name, pretrained=pretrained, return_transform=False,
logger=self.logger)
# Get the textual model
self.text_tokenizer = open_clip.get_tokenizer(model_name)
self.text_transformer = clip_model.transformer
self.text_token_embedding = clip_model.token_embedding
self.text_pe = clip_model.positional_embedding
self.text_ln_final = clip_model.ln_final
self.text_proj = clip_model.text_projection
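        # CLIP's causal attention mask for the text transformer, registered as a buffer so it
        # follows the module across devices.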
self.register_buffer('text_attn_mask', clip_model.attn_mask)
self.param_dtype = torch.float32
self.model_name = model_name
def init_weights(self):
self.logger.info(f"Init Config for {self.model_name}")
self.logger.info(self.init_cfg)
# Copied from
# https://github.com/openai/CLIP/blob/d50d76daa670286dd6cacf3bcd80b5e4823fc8e1/clip/model.py#L343
@torch.no_grad()
def forward(self, text):
text_tokens = self.text_tokenizer(text).to(device=self.text_proj.device)
x = self.text_token_embedding(text_tokens).to(self.param_dtype)
x = x + self.text_pe.to(self.param_dtype)
x = x.permute(1, 0, 2)
x = self.text_transformer(x, attn_mask=self.text_attn_mask)
x = x.permute(1, 0, 2)
x = self.text_ln_final(x) # [batch_size, n_ctx, transformer.width]
# take features from the eot embedding (eot_token is the highest number in each sequence)
x = x[torch.arange(x.shape[0]), text_tokens.argmax(dim=-1)] @ self.text_proj
return x
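# --- Illustrative usage (not part of the original file) ---------------------------------------
# A minimal sketch of how the two backbones might be instantiated directly; the model name and
# pretrained tag below are assumptions and must correspond to weights available to open_clip
# (in the repo these values normally come from an mmengine config and are built via MODELS.build()).
#
#     import torch
#
#     backbone = OpenCLIPBackbone(
#         img_size=1024,
#         model_name='convnext_large_d_320',
#         fix=True,
#         init_cfg=dict(type='clip_pretrain', checkpoint='laion2b_s29b_b131k_ft_soup'),
#     )
#     feats = backbone(torch.randn(1, 3, 1024, 1024))   # tuple of 4 multi-scale feature maps
#
#     text_encoder = backbone.get_text_model()
#     class_emb = text_encoder(['a photo of a cat', 'a photo of a dog'])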