# Ovis-Clip-Qwen1_5-7B / visual_tokenizer.py
from datetime import datetime
from typing import Dict, Union, Optional
import deepspeed
import torch
import PIL.Image
from torch.nn.functional import softmax, gumbel_softmax
from torch import Tensor
from transformers import PretrainedConfig, PreTrainedModel, AutoImageProcessor, AutoConfig, AutoModel
from transformers import CLIPVisionModel, CLIPImageProcessor
from transformers.integrations import is_deepspeed_zero3_enabled
from .utils import BEGIN_LINE, END_LINE, rank0_print
MODEL_TYPE = "clip_visual_tokenizer"
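
# Configuration shared by all visual tokenizers. `vocab_size` is the size of the visual
# vocabulary the head projects into, `tokenize_function` selects how logits are turned
# into (soft or hard) token distributions, and `depths`, when given as a '|'-separated
# string, is parsed into a list used to truncate the backbone depth.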
class BaseVisualTokenizerConfig(PretrainedConfig):
def __init__(self,
vocab_size=16384,
tokenize_function="softmax",
tau=1.0,
depths=None,
use_indicators=False,
drop_cls_token=False,
backbone_config: Optional[Union[PretrainedConfig, dict]] = None,
hidden_stride: int = 1,
**kwargs):
super().__init__(**kwargs)
self.vocab_size = vocab_size
self.tokenize_function = tokenize_function
self.tau = tau
if isinstance(depths, str):
depths = [int(x) for x in depths.split('|')]
self.depths = depths
self.backbone_kwargs = {}
self.use_indicators = use_indicators
self.drop_cls_token = drop_cls_token
if backbone_config is not None:
            assert isinstance(backbone_config, (PretrainedConfig, dict)), \
                f"expected `backbone_config` to be an instance of PretrainedConfig or dict, but got {type(backbone_config)}"
if not isinstance(backbone_config, PretrainedConfig):
model_type = backbone_config['model_type']
backbone_config.pop('model_type')
backbone_config = AutoConfig.for_model(model_type, **backbone_config)
self.backbone_config = backbone_config
self.hidden_stride = hidden_stride
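
# Abstract base class: subclasses bind a concrete image processor and backbone via the
# class attributes below. With `train_from_scratch=True`, both are loaded from
# `_backbone_name_or_path`; otherwise the processor comes from
# `image_processor_name_or_path` and the backbone is built from the stored backbone config.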
class BaseVisualTokenizer(PreTrainedModel):
base_model_prefix = "backbone"
main_input_name = None
_image_processor_class = None
_image_processor_kwargs = {}
_backbone_class = None
_backbone_name_or_path = None
def __init__(self, config: BaseVisualTokenizerConfig, *inputs, **kwargs):
super().__init__(config, *inputs, **kwargs)
if kwargs.get('train_from_scratch'):
self.image_processor = self._image_processor_class.from_pretrained(self._backbone_name_or_path,
**self._image_processor_kwargs)
self.backbone = self._backbone_class.from_pretrained(self._backbone_name_or_path,
**self.config.backbone_kwargs)
self.config.backbone_config = self.backbone.config
else:
self.image_processor = AutoImageProcessor.from_pretrained(kwargs['image_processor_name_or_path'])
self.backbone = AutoModel.from_config(self.config.backbone_config)
self.head = None
assert all((self.image_processor.do_resize,
not getattr(self.image_processor, 'do_center_crop', False),
self.image_processor.do_rescale,
self.image_processor.do_normalize
)), f"image_processor `{self.image_processor}` is not supported currently"
def get_backbone(self):
return self.backbone
def get_monitor_tensors(self):
raise NotImplementedError
def get_image_processor(self):
return self.image_processor
def get_head(self):
return self.head
def get_image_size(self):
raise NotImplementedError
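
    # Resize the image so its longer side matches the model's (square) input size while
    # keeping the aspect ratio, run the backbone's image processor, then zero-pad the
    # shorter side so the result is a centered square tensor.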
def preprocess_image(self, image: PIL.Image.Image, convert_to_rgb=True):
if convert_to_rgb and image.mode != 'RGB':
image = image.convert('RGB')
# first resize and preprocess
sides = self.get_image_size()
if sides[0] != sides[1]:
raise ValueError('get_image_size() returns non-square size')
side = sides[0]
width, height = image.size
if width == height:
new_width = new_height = side
elif width > height:
new_width = side
new_height = int(height / width * new_width)
else:
new_height = side
new_width = int(width / height * new_height)
new_size = dict(height=new_height, width=new_width)
pixel_values = self.image_processor.preprocess(image, size=new_size, return_tensors='pt')['pixel_values']
# then pad to square
square_values = torch.zeros([1, 3, side, side], dtype=pixel_values.dtype, device=pixel_values.device)
new_height, new_width = pixel_values.shape[2:]
if new_height == new_width:
square_values[:, :, :, :] = pixel_values
elif new_height > new_width:
from_index = (side - new_width) // 2
square_values[:, :, :, from_index:from_index + new_width] = pixel_values
else:
from_index = (side - new_height) // 2
square_values[:, :, from_index:from_index + new_height, :] = pixel_values
return square_values
    def get_layer_norm(self):
        # note: `layer_norm` is not defined in this base class; tokenizers that use this
        # accessor are expected to register `self.layer_norm` themselves
        return self.layer_norm
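
    # Convert head logits over the visual vocabulary into token distributions:
    #   * 'softmax'       - soft probabilities (differentiable, no discretization)
    #   * 'gumbel_argmax' - hard one-hot samples with straight-through gradients
    #                       (gumbel_softmax with hard=True)
    #   * 'st_argmax'     - deterministic argmax with a straight-through estimator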
def tokenize(self, logits):
def st_argmax(y_soft, dim): # straight-through softmax
index = y_soft.max(dim, keepdim=True)[1]
y_hard = torch.zeros_like(y_soft, memory_format=torch.legacy_contiguous_format).scatter_(dim, index, 1.0)
ret = y_hard - y_soft.detach() + y_soft
return ret
if self.config.tokenize_function == 'softmax':
tokens = softmax(logits, dim=-1)
elif self.config.tokenize_function == 'gumbel_argmax':
tokens = gumbel_softmax(logits, tau=self.config.tau, hard=True)
elif self.config.tokenize_function == 'st_argmax':
tokens = st_argmax(logits, dim=-1)
else:
            raise ValueError(
                f'Invalid `tokenize_function`, expected one of [softmax, gumbel_argmax, st_argmax], but got {self.config.tokenize_function}')
return tokens
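
# CLIP-specific config: a single entry in `depths` truncates the CLIP vision tower to
# that many hidden layers via `num_hidden_layers`.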
class ClipVisualTokenizerConfig(BaseVisualTokenizerConfig):
model_type = MODEL_TYPE
def __init__(self, **kwargs):
super().__init__(**kwargs)
if self.depths:
assert len(self.depths) == 1
self.backbone_kwargs['num_hidden_layers'] = self.depths[0]
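
# Visual tokenizer built on CLIP ViT-L/14-336: backbone features are projected by a
# bias-free linear layer (plus LayerNorm) into logits over the visual vocabulary. When
# indicator tokens are enabled, the last two vocabulary ids are reserved for the
# begin/end-of-image markers appended in `forward`.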
class ClipVisualTokenizer(BaseVisualTokenizer):
config_class = ClipVisualTokenizerConfig
supports_gradient_checkpointing = True
_no_split_modules = ["CLIPEncoderLayer"]
_image_processor_class = CLIPImageProcessor
_image_processor_kwargs = dict(do_center_crop=False)
_backbone_class = CLIPVisionModel
_backbone_name_or_path = "openai/clip-vit-large-patch14-336"
def __init__(self, config: ClipVisualTokenizerConfig = None, *inputs, **kwargs):
super().__init__(config, *inputs, **kwargs)
head_dim = self.config.vocab_size
if self.config.use_indicators:
head_dim -= 2 # reserved for two image indicator tokens
self.head = torch.nn.Sequential(
torch.nn.Linear(self.backbone.config.hidden_size, head_dim, bias=False),
torch.nn.LayerNorm(head_dim)
)
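
    # Re-initialize encoder layers from `re_init_layer_begin` onward. Parameters are
    # gathered under DeepSpeed ZeRO-3 so that re-initialization (and the before/after
    # logging) happens on rank 0 with the full weights materialized.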
def re_init_layers(self, re_init_layer_begin):
layer_dict = self.get_re_init_layer_dict(re_init_layer_begin)
for name, layer in layer_dict.items():
rank0_print(BEGIN_LINE)
rank0_print(f'[{datetime.now()}] Before layer re-initialization of {name}: ')
for k, v in layer.named_parameters():
with deepspeed.zero.GatheredParameters([v]):
rank0_print(f'{k}: {v}')
with deepspeed.zero.GatheredParameters(list(layer.parameters(recurse=True)), modifier_rank=0):
if not is_deepspeed_zero3_enabled() or deepspeed.comm.get_rank() == 0:
layer.apply(self.backbone._init_weights)
rank0_print(f'[{datetime.now()}] After layer re-initialization of {name}:')
for k, v in layer.named_parameters():
with deepspeed.zero.GatheredParameters([v]):
rank0_print(f'{k}: {v}')
rank0_print(END_LINE)
def get_re_init_layer_dict(self, re_init_layer_begin: int) -> Dict[str, torch.nn.Module]:
assert re_init_layer_begin >= 0, "negative index is prohibited"
layer_dict = dict()
for i in range(re_init_layer_begin, self.backbone.config.num_hidden_layers):
layer_dict[f'backbone.vision_model.encoder.layers.{i}'] = self.backbone.vision_model.encoder.layers[i]
return layer_dict
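
    # Representative weights to monitor during training: the bottom and top encoder
    # attention projections plus the head's linear projection.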
def get_monitor_tensors(self):
return dict(
backbone_bottom=self.backbone.vision_model.encoder.layers[0].self_attn.k_proj.weight,
backbone_top=self.backbone.vision_model.encoder.layers[-1].self_attn.out_proj.weight,
head=self.head[0].weight
)
def get_image_size(self):
height = self.image_processor.crop_size["height"]
width = self.image_processor.crop_size["width"]
return height, width
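
    # Forward pass: CLIP backbone -> per-patch features -> head logits -> tokenize.
    # With `use_indicators`, the (VocabSize-2)-dim tokens are zero-padded to VocabSize
    # and wrapped with one-hot begin/end indicator tokens (ids VocabSize-2 and VocabSize-1).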
def forward(self, pixel_values) -> Tensor: # [BatchSize, ImageShape] -> [BatchSize, #Token, VocabSize]
output = self.backbone(
pixel_values, output_hidden_states=True, return_dict=True)
features = output.last_hidden_state
if self.config.drop_cls_token:
features = features[:, 1:, :]
logits = self.head(features)
tokens = self.tokenize(logits)
if self.config.use_indicators:
            # tokens' shape is [BatchSize, #Token, VocabSize-2]; pad with a zero tensor of
            # shape [BatchSize, #Token, 2] so tokens' shape becomes [BatchSize, #Token, VocabSize]
batch_size, token_len, _ = tokens.shape
padding_tensor = torch.zeros(size=(batch_size, token_len, 2),
dtype=tokens.dtype,
device=tokens.device,
layout=tokens.layout,
requires_grad=False)
tokens = torch.cat((tokens, padding_tensor), dim=2)
# adding indicator tokens, after which tokens' shape should become [BatchSize, 1+#Token+1, VocabSize]
begin_indicator = torch.zeros(size=(batch_size, 1),
dtype=torch.long,
device=tokens.device,
requires_grad=False) + self.config.vocab_size - 2
begin_indicator_token = torch.nn.functional.one_hot(begin_indicator,
num_classes=self.config.vocab_size).to(
dtype=tokens.dtype)
end_indicator = torch.zeros(size=(batch_size, 1),
dtype=torch.long,
device=tokens.device,
requires_grad=False) + self.config.vocab_size - 1
end_indicator_token = torch.nn.functional.one_hot(end_indicator,
num_classes=self.config.vocab_size).to(dtype=tokens.dtype)
tokens = torch.cat((begin_indicator_token, tokens, end_indicator_token), dim=1)
return tokens
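
# Register with the Auto* factories so this tokenizer can be loaded through
# AutoConfig/AutoModel with model_type "clip_visual_tokenizer".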
AutoConfig.register(MODEL_TYPE, ClipVisualTokenizerConfig)
AutoModel.register(ClipVisualTokenizerConfig, ClipVisualTokenizer)
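
# Illustrative usage sketch (not part of the module; the flags and the image path are
# assumptions). It presumes the package context required by the relative `.utils` import
# and network access to download the CLIP weights; "sample.jpg" is a hypothetical file.
#
#   config = ClipVisualTokenizerConfig(vocab_size=16384, tokenize_function="softmax",
#                                      use_indicators=True, drop_cls_token=True)
#   tokenizer = ClipVisualTokenizer(config, train_from_scratch=True)
#   pixel_values = tokenizer.preprocess_image(PIL.Image.open("sample.jpg"))
#   visual_tokens = tokenizer(pixel_values)  # [1, 1 + #Token + 1, vocab_size]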