Manli's picture
initial commit
8b55b6a
raw
history blame
6.41 kB
from transformers import PretrainedConfig
from transformers import logging
from transformers import CONFIG_MAPPING
logger = logging.get_logger(__name__)
class XGenMMVisionEncoderConfig(PretrainedConfig):
model_type = "xgenmm_vision_encoder"
def __init__(self,
model_name: str = 'google/siglip-so400m-patch14-384',
anyres_grids: list[int] = [[384, 768],[768, 384],[768, 768],[1152, 384],[384,1152]],
**kwargs):
self.model_name = model_name
self.anyres_grids = anyres_grids
super().__init__(**kwargs)
class XGenMMVisionTokenizerConfig(PretrainedConfig):
model_type = "xgenmm_vision_tokenizer"
def __init__(self,
vis_feature_dim: int = 1152,
lang_embedding_dim: int = 3072,
num_vis_tokens: int = 128,
image_aspect_ratio: str = 'anyres',
**kwargs):
self.vis_feature_dim = vis_feature_dim
self.lang_embedding_dim = lang_embedding_dim
self.num_vis_tokens = num_vis_tokens
self.image_aspect_ratio = image_aspect_ratio
super().__init__(**kwargs)
class XGenMMConfig(PretrainedConfig):
model_type = "xgenmm"
def __init__(self,
vision_encoder_config: dict = None,
vision_tokenizer_config: dict = None,
text_config: dict = None,
**kwargs):
if vision_encoder_config is None:
vision_encoder_config = {'image_aspect_ratio': 'anyres', 'anyres_patch_sampling': True}
logger.info("vision_encoder_config is None. initializing the XGenMMVisionEncoderConfig with default values.")
if vision_tokenizer_config is None:
vision_tokenizer_config = {}
logger.info("vision_tokenizer_config is None. Initializing the XGenMMVisionTokenizerConfig with default values.")
if text_config is None:
text_config = {
'initial_tokenizer_len':32012,
'pad_token_id':32011,
'bos_token_id':1,
'eos_token_id':32000,
'vocab_size': 32064,
'hidden_size': 3072,
'intermediate_size': 8192,
'num_hidden_layers': 32,
'num_attention_heads': 32,
'num_key_value_heads': 32,
'resid_pdrop': 0.0,
'embd_pdrop': 0.0,
'attention_dropout': 0.0,
'hidden_act': 'silu',
'max_position_embeddings': 4096,
'original_max_position_embeddings': 4096,
'initializer_range': 0.02,
'rms_norm_eps': 1e-05,
'use_cache': True,
'rope_theta': 10000.0,
'rope_scaling': None,
'sliding_window': 2047,
'return_dict': True,
'output_hidden_states': False,
'output_attentions': False,
'torchscript': False,
'torch_dtype': 'bfloat16',
'use_bfloat16': False,
'tf_legacy_loss': False,
'pruned_heads': {},
'tie_word_embeddings': False,
'chunk_size_feed_forward': 0,
'is_encoder_decoder': False,
'is_decoder': False,
'cross_attention_hidden_size': None,
'add_cross_attention': False,
'tie_encoder_decoder': False,
'max_length': 20,
'min_length': 0,
'do_sample': False,
'early_stopping': False,
'num_beams': 1,
'num_beam_groups': 1,
'diversity_penalty': 0.0,
'temperature': 1.0,
'top_k': 50,
'top_p': 1.0,
'typical_p': 1.0,
'repetition_penalty': 1.0,
'length_penalty': 1.0,
'no_repeat_ngram_size': 0,
'encoder_no_repeat_ngram_size': 0,
'bad_words_ids': None,
'num_return_sequences': 1,
'output_scores': False,
'return_dict_in_generate': False,
'forced_bos_token_id': None,
'forced_eos_token_id': None,
'remove_invalid_values': False,
'exponential_decay_length_penalty': None,
'suppress_tokens': None,
'begin_suppress_tokens': None,
'finetuning_task': None,
'id2label': {0: 'LABEL_0', 1: 'LABEL_1'},
'label2id': {'LABEL_0': 0, 'LABEL_1': 1},
'tokenizer_class': None,
'prefix': None,
'bos_token_id': 1,
'pad_token_id': 32000,
'eos_token_id': 32000,
'sep_token_id': None,
'decoder_start_token_id': None,
'task_specific_params': None,
'problem_type': None,
'model_type': 'phi3'
}
logger.info("text_config is None. Initializing the text config with default values (`Phi3Config`).")
self.vision_encoder_config = XGenMMVisionEncoderConfig(**vision_encoder_config)
self.vision_tokenizer_config = XGenMMVisionTokenizerConfig(**vision_tokenizer_config)
text_model_type = text_config["model_type"] if "model_type" in text_config else "phi3"
self.text_config = CONFIG_MAPPING[text_model_type](**text_config)
for key in ['initial_tokenizer_len', 'pad_token_id']:
if key not in self.text_config.to_dict():
raise ValueError(f"The key `{key}` is missing in the text_config.")
super().__init__(**kwargs)
@classmethod
def from_vision_encoder_vision_tokenizer_text_configs(
cls,
vision_encoder_config: XGenMMVisionEncoderConfig,
vision_tokenizer_config: XGenMMVisionTokenizerConfig,
text_config: PretrainedConfig,
**kwargs):
return cls(
vision_encoder_config=vision_encoder_config.to_dict(),
vision_tokenizer_config=vision_tokenizer_config.to_dict(),
text_config=text_config.to_dict(),
**kwargs,
)