from datetime import datetime
from typing import Dict

import deepspeed
import torch
from torch import Tensor
from transformers import AutoConfig, AutoModel
from transformers import CLIPVisionModel, CLIPImageProcessor
from transformers.integrations import is_deepspeed_zero3_enabled

from .utils import BEGIN_LINE, END_LINE, rank0_print
from .base_visual_tokenizer import BaseVisualTokenizerConfig, BaseVisualTokenizer

MODEL_TYPE = "clip_visual_tokenizer"


class ClipVisualTokenizerConfig(BaseVisualTokenizerConfig):
    model_type = MODEL_TYPE

    def __init__(self, **kwargs):
        super().__init__(**kwargs)
        if self.depths:
            # A CLIP vision backbone is a single layer stack, so at most one depth
            # may be given; it overrides the number of encoder layers to keep.
            assert len(self.depths) == 1
            self.backbone_kwargs['num_hidden_layers'] = self.depths[0]
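

# Illustrative use of `depths` (the vocab size below is a hypothetical value):
# a single entry truncates the CLIP encoder to that many layers.
#
#   config = ClipVisualTokenizerConfig(vocab_size=8192, depths=[12])
#   assert config.backbone_kwargs['num_hidden_layers'] == 12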


class ClipVisualTokenizer(BaseVisualTokenizer):
    config_class = ClipVisualTokenizerConfig
    supports_gradient_checkpointing = True
    _no_split_modules = ["CLIPEncoderLayer"]
    _image_processor_class = CLIPImageProcessor
    _image_processor_kwargs = dict(do_center_crop=False)
    _backbone_class = CLIPVisionModel
    _backbone_name_or_path = "openai/clip-vit-large-patch14-336"

    def __init__(self, config: ClipVisualTokenizerConfig = None, *inputs, **kwargs):
        super().__init__(config, *inputs, **kwargs)
        head_dim = self.config.vocab_size
        if self.config.use_indicators:
            head_dim -= 2  # reserve the last two vocabulary ids for the begin/end indicators
        self.head = torch.nn.Sequential(
            torch.nn.Linear(self.backbone.config.hidden_size, head_dim, bias=False),
            torch.nn.LayerNorm(head_dim)
        )
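
    # Dimension sketch (illustrative numbers): CLIP ViT-L/14 has hidden_size 1024,
    # so with vocab_size=8192 and use_indicators=True the head maps each patch
    # feature 1024 -> 8190, and forward() later pads the two reserved indicator
    # columns back in.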

    def re_init_layers(self, re_init_layer_begin):
        layer_dict = self.get_re_init_layer_dict(re_init_layer_begin)
        for name, layer in layer_dict.items():
            rank0_print(BEGIN_LINE)
            rank0_print(f'[{datetime.now()}] Before layer re-initialization of {name}: ')
            for k, v in layer.named_parameters():
                # Under DeepSpeed ZeRO-3, parameters are sharded across ranks;
                # gather each one before printing its full value.
                with deepspeed.zero.GatheredParameters([v]):
                    rank0_print(f'{k}: {v}')
            # Gather the layer's parameters on rank 0, re-initialize them there only,
            # and let DeepSpeed redistribute the updated shards on context exit.
            with deepspeed.zero.GatheredParameters(list(layer.parameters(recurse=True)), modifier_rank=0):
                if not is_deepspeed_zero3_enabled() or deepspeed.comm.get_rank() == 0:
                    layer.apply(self.backbone._init_weights)
            rank0_print(f'[{datetime.now()}] After layer re-initialization of {name}:')
            for k, v in layer.named_parameters():
                with deepspeed.zero.GatheredParameters([v]):
                    rank0_print(f'{k}: {v}')
            rank0_print(END_LINE)

    def get_re_init_layer_dict(self, re_init_layer_begin: int) -> Dict[str, torch.nn.Module]:
        assert re_init_layer_begin >= 0, "negative index is prohibited"
        layer_dict = dict()
        for i in range(re_init_layer_begin, self.backbone.config.num_hidden_layers):
            layer_dict[f'backbone.vision_model.encoder.layers.{i}'] = self.backbone.vision_model.encoder.layers[i]
        return layer_dict
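
    # Example (assuming the default 24-layer ViT-L backbone):
    # get_re_init_layer_dict(20) returns layers 20..23 keyed as
    # 'backbone.vision_model.encoder.layers.{i}', so re_init_layers(20)
    # re-initializes only the top four encoder layers.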

    def get_monitor_tensors(self):
        return dict(
            backbone_bottom=self.backbone.vision_model.encoder.layers[0].self_attn.k_proj.weight,
            backbone_top=self.backbone.vision_model.encoder.layers[-1].self_attn.out_proj.weight,
            head=self.head[0].weight
        )

    def get_image_size(self):
        height = self.image_processor.crop_size["height"]
        width = self.image_processor.crop_size["width"]
        return height, width

    def forward(self, pixel_values) -> Tensor:  # [BatchSize, ImageShape] -> [BatchSize, #Token, VocabSize]
        output = self.backbone(
            pixel_values, output_hidden_states=True, return_dict=True)
        features = output.last_hidden_state
        if self.config.drop_cls_token:
            features = features[:, 1:, :]
        logits = self.head(features)
        tokens = self.tokenize(logits)
        if self.config.use_indicators:
            # `tokens` has shape [BatchSize, #Token, VocabSize - 2]; pad the last
            # dimension with two zero columns so it becomes [BatchSize, #Token, VocabSize].
            batch_size, token_len, _ = tokens.shape
            padding_tensor = torch.zeros(size=(batch_size, token_len, 2),
                                         dtype=tokens.dtype,
                                         device=tokens.device,
                                         layout=tokens.layout,
                                         requires_grad=False)
            tokens = torch.cat((tokens, padding_tensor), dim=2)

            # Prepend/append one-hot indicator tokens (the last two vocabulary ids),
            # after which `tokens` has shape [BatchSize, 1 + #Token + 1, VocabSize].
            begin_indicator = torch.zeros(size=(batch_size, 1),
                                          dtype=torch.long,
                                          device=tokens.device,
                                          requires_grad=False) + self.config.vocab_size - 2
            begin_indicator_token = torch.nn.functional.one_hot(begin_indicator,
                                                                num_classes=self.config.vocab_size).to(dtype=tokens.dtype)
            end_indicator = torch.zeros(size=(batch_size, 1),
                                        dtype=torch.long,
                                        device=tokens.device,
                                        requires_grad=False) + self.config.vocab_size - 1
            end_indicator_token = torch.nn.functional.one_hot(end_indicator,
                                                              num_classes=self.config.vocab_size).to(dtype=tokens.dtype)
            tokens = torch.cat((begin_indicator_token, tokens, end_indicator_token), dim=1)
        return tokens
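
    # Shape walk-through for the default 336x336, patch-14 backbone: the encoder
    # yields (336 / 14)^2 = 576 patch features plus a CLS token. With
    # drop_cls_token=True and use_indicators=True, `tokens` goes from
    # (B, 576, vocab_size - 2) to (B, 576, vocab_size) after zero-padding, then
    # to (B, 578, vocab_size) once the begin/end indicators are attached.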


AutoConfig.register(MODEL_TYPE, ClipVisualTokenizerConfig)
AutoModel.register(ClipVisualTokenizerConfig, ClipVisualTokenizer)
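

# Minimal smoke-test sketch, not part of the library: it assumes the base class
# builds the CLIP backbone from `_backbone_name_or_path` (downloading weights on
# first use) and that `vocab_size` / `drop_cls_token` / `use_indicators` are
# accepted as config kwargs; the values below are illustrative.
if __name__ == "__main__":
    config = ClipVisualTokenizerConfig(vocab_size=8192, drop_cls_token=True, use_indicators=True)
    tokenizer = ClipVisualTokenizer(config)
    height, width = tokenizer.get_image_size()
    pixel_values = torch.randn(1, 3, height, width)
    tokens = tokenizer(pixel_values)
    print(tokens.shape)  # expected (1, 578, 8192); see the shape walk-through after forward()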