# Ovis-Clip-Qwen1_5-7B / clip_visual_tokenizer.py
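"""CLIP-based visual tokenizer.

Wraps a CLIPVisionModel backbone (openai/clip-vit-large-patch14-336 by default)
with a Linear + LayerNorm head that maps patch features to a distribution over a
visual vocabulary; the last two vocabulary entries can be reserved for image
begin/end indicator tokens.
"""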
from datetime import datetime
from typing import Dict
import deepspeed
import torch
from torch import Tensor
from transformers import AutoConfig, AutoModel
from transformers import CLIPVisionModel, CLIPImageProcessor
from transformers.integrations import is_deepspeed_zero3_enabled
from .utils import BEGIN_LINE, END_LINE, rank0_print
from .base_visual_tokenizer import BaseVisualTokenizerConfig, BaseVisualTokenizer

MODEL_TYPE = "clip_visual_tokenizer"


class ClipVisualTokenizerConfig(BaseVisualTokenizerConfig):
    model_type = MODEL_TYPE

    def __init__(self, **kwargs):
super().__init__(**kwargs)
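        # A single-element `depths` overrides the number of hidden layers in the CLIP backbone.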
if self.depths:
assert len(self.depths) == 1
self.backbone_kwargs['num_hidden_layers'] = self.depths[0]


class ClipVisualTokenizer(BaseVisualTokenizer):
config_class = ClipVisualTokenizerConfig
supports_gradient_checkpointing = True
_no_split_modules = ["CLIPEncoderLayer"]
_image_processor_class = CLIPImageProcessor
_image_processor_kwargs = dict(do_center_crop=False)
_backbone_class = CLIPVisionModel
    _backbone_name_or_path = "openai/clip-vit-large-patch14-336"

    def __init__(self, config: ClipVisualTokenizerConfig = None, *inputs, **kwargs):
super().__init__(config, *inputs, **kwargs)
head_dim = self.config.vocab_size
if self.config.use_indicators:
head_dim -= 2 # reserved for two image indicator tokens
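        # Project backbone hidden states to head_dim logits, followed by a LayerNorm.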
self.head = torch.nn.Sequential(
torch.nn.Linear(self.backbone.config.hidden_size, head_dim, bias=False),
torch.nn.LayerNorm(head_dim)
)

    def re_init_layers(self, re_init_layer_begin):
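        """Re-initialize backbone encoder layers from index `re_init_layer_begin` onward.

        Parameters are gathered before printing/re-initializing so this also works
        under DeepSpeed ZeRO-3, where weights are sharded across ranks.
        """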
layer_dict = self.get_re_init_layer_dict(re_init_layer_begin)
for name, layer in layer_dict.items():
rank0_print(BEGIN_LINE)
rank0_print(f'[{datetime.now()}] Before layer re-initialization of {name}: ')
for k, v in layer.named_parameters():
with deepspeed.zero.GatheredParameters([v]):
rank0_print(f'{k}: {v}')
with deepspeed.zero.GatheredParameters(list(layer.parameters(recurse=True)), modifier_rank=0):
if not is_deepspeed_zero3_enabled() or deepspeed.comm.get_rank() == 0:
layer.apply(self.backbone._init_weights)
rank0_print(f'[{datetime.now()}] After layer re-initialization of {name}:')
for k, v in layer.named_parameters():
with deepspeed.zero.GatheredParameters([v]):
rank0_print(f'{k}: {v}')
rank0_print(END_LINE)

    def get_re_init_layer_dict(self, re_init_layer_begin: int) -> Dict[str, torch.nn.Module]:
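        """Map qualified layer names to the encoder layers to be re-initialized."""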
assert re_init_layer_begin >= 0, "negative index is prohibited"
layer_dict = dict()
for i in range(re_init_layer_begin, self.backbone.config.num_hidden_layers):
layer_dict[f'backbone.vision_model.encoder.layers.{i}'] = self.backbone.vision_model.encoder.layers[i]
return layer_dict

    def get_monitor_tensors(self):
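        """Representative weights (bottom/top encoder layers and the head) for training-time monitoring."""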
return dict(
backbone_bottom=self.backbone.vision_model.encoder.layers[0].self_attn.k_proj.weight,
backbone_top=self.backbone.vision_model.encoder.layers[-1].self_attn.out_proj.weight,
head=self.head[0].weight
)

    def get_image_size(self):
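        """Return the (height, width) configured in the image processor's crop_size."""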
height = self.image_processor.crop_size["height"]
width = self.image_processor.crop_size["width"]
return height, width

    def forward(self, pixel_values) -> Tensor:  # [BatchSize, ImageShape] -> [BatchSize, #Token, VocabSize]
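        """Encode images into (soft) visual tokens.

        Pipeline: backbone features -> optional CLS-token drop -> head logits ->
        tokenize -> optional begin/end indicator tokens.
        """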
output = self.backbone(
pixel_values, output_hidden_states=True, return_dict=True)
features = output.last_hidden_state
if self.config.drop_cls_token:
features = features[:, 1:, :]
logits = self.head(features)
tokens = self.tokenize(logits)
if self.config.use_indicators:
            # tokens has shape [BatchSize, #Token, VocabSize-2]; pad with zeros of shape
            # [BatchSize, #Token, 2] so that it becomes [BatchSize, #Token, VocabSize].
batch_size, token_len, _ = tokens.shape
padding_tensor = torch.zeros(size=(batch_size, token_len, 2),
dtype=tokens.dtype,
device=tokens.device,
layout=tokens.layout,
requires_grad=False)
tokens = torch.cat((tokens, padding_tensor), dim=2)
            # Prepend/append the begin/end indicator tokens, giving shape [BatchSize, 1+#Token+1, VocabSize].
begin_indicator = torch.zeros(size=(batch_size, 1),
dtype=torch.long,
device=tokens.device,
requires_grad=False) + self.config.vocab_size - 2
begin_indicator_token = torch.nn.functional.one_hot(begin_indicator,
num_classes=self.config.vocab_size).to(
dtype=tokens.dtype)
end_indicator = torch.zeros(size=(batch_size, 1),
dtype=torch.long,
device=tokens.device,
requires_grad=False) + self.config.vocab_size - 1
end_indicator_token = torch.nn.functional.one_hot(end_indicator,
num_classes=self.config.vocab_size).to(dtype=tokens.dtype)
tokens = torch.cat((begin_indicator_token, tokens, end_indicator_token), dim=1)
return tokens


AutoConfig.register(MODEL_TYPE, ClipVisualTokenizerConfig)
AutoModel.register(ClipVisualTokenizerConfig, ClipVisualTokenizer)
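

# Usage sketch (illustrative only; the vocab_size value below is an assumption, not taken from this file):
#
#   from transformers import AutoModel
#
#   config = ClipVisualTokenizerConfig(vocab_size=8192, use_indicators=True)
#   visual_tokenizer = AutoModel.from_config(config)   # resolves to ClipVisualTokenizer via the registration above
#   tokens = visual_tokenizer(pixel_values)            # pixel_values preprocessed by its CLIPImageProcessor
#   # tokens: [BatchSize, 1+#Token+1, VocabSize] when use_indicators is True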