File size: 4,492 Bytes
4974490
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
from transformers import PreTrainedModel, AutoModelForCausalLM, AutoModel
import torch
import open_clip
from typing import List, Optional, Tuple, Union
from utils import check_embedding_fns
from vlm import PerceiverResampler, XGenMMPerceiver
from configuration_xgenmm import XGenMMVisionEncoderConfig, XGenMMVisionTokenizerConfig, XGenMMConfig

class XGenMMVisionEncoder(PreTrainedModel):
    """Standalone ``PreTrainedModel`` wrapper around a pretrained vision backbone.

    Only ``google/siglip-so400m-patch14-384`` is accepted; any other
    ``config.model_name`` raises ``ValueError`` at construction time.
    """

    main_input_name = "pixel_values"
    config_class = XGenMMVisionEncoderConfig

    def __init__(self, config: XGenMMVisionEncoderConfig):
        """Load the pretrained backbone named by ``config.model_name``.

        Raises:
            ValueError: if ``config.model_name`` is not the supported SigLIP
                checkpoint.
        """
        super().__init__(config)
        if config.model_name != 'google/siglip-so400m-patch14-384':
            raise ValueError(f"Unsupported model {config.model_name}. New vision models will be added soon.")
        self.model = AutoModel.from_pretrained(config.model_name)

    def forward(self, pixel_values: torch.Tensor) -> torch.Tensor:
        """Encode a batch of images into image features.

        Args:
            pixel_values: image batch; presumably ``(batch, channels, height,
                width)`` -- TODO confirm against the image processor.

        Returns:
            The backbone's image-feature tensor.
        """
        # ``AutoModel`` resolves the SigLIP checkpoint to a Hugging Face model,
        # which exposes ``get_image_features`` and not the open_clip-style
        # ``encode_image``; calling ``encode_image`` unconditionally raised
        # AttributeError. Prefer ``encode_image`` when the backbone provides
        # it so open_clip-style backbones keep working.
        encode_image = getattr(self.model, "encode_image", None)
        if encode_image is not None:
            return encode_image(pixel_values)
        return self.model.get_image_features(pixel_values=pixel_values)
        

# vision tokenizer        
class XGenMMVisionTokenizer(PreTrainedModel):
    """Expose a ``PerceiverResampler`` as a ``PreTrainedModel``.

    The resampler is configured from ``XGenMMVisionTokenizerConfig``:
    ``dim`` <- ``vis_feature_dim``, ``dim_inner`` <- ``lang_embedding_dim``,
    ``num_latents`` <- ``num_vis_tokens``.
    """

    config_class = XGenMMVisionTokenizerConfig

    def __init__(self, config: XGenMMVisionTokenizerConfig):
        super().__init__(config)
        resampler_kwargs = dict(
            dim=config.vis_feature_dim,
            dim_inner=config.lang_embedding_dim,
            num_latents=config.num_vis_tokens,
        )
        self.model = PerceiverResampler(**resampler_kwargs)

    def forward(self,
                vision_features: torch.Tensor,
                vision_attn_masks: torch.Tensor):
        """Run the wrapped resampler on vision features and their masks.

        Both arguments are passed positionally, unchanged, to the resampler,
        and its output is returned as-is.
        """
        resampler = self.model
        return resampler(vision_features, vision_attn_masks)
    
# XGenMM model
class XGenMMModelForConditionalGeneration(PreTrainedModel):
    """XGen-MM vision-language model built around an ``XGenMMPerceiver``.

    ``__init__`` assembles three components and hands them to
    ``XGenMMPerceiver`` (stored as ``self.vlm``):

    * vision encoder -- the ``vision_model`` sub-module of the pretrained
      checkpoint named by ``config.vision_encoder_config.model_name``;
    * language model -- a causal LM built from ``config.text_config`` with
      ``from_config`` (architecture only, no pretrained weights loaded here);
    * vision tokenizer -- the ``PerceiverResampler`` built by
      ``XGenMMVisionTokenizer`` from ``config.vision_tokenizer_config``.
    """

    config_class = XGenMMConfig

    def __init__(self, config: XGenMMConfig):
        """Build the sub-models described by ``config`` and wrap them.

        Side effect: ``config.vision_tokenizer_config.lang_embedding_dim`` is
        overwritten in place when it disagrees with the language model's
        input-embedding width (a warning is emitted in that case).
        """
        import warnings

        super().__init__(config)

        # Vision encoder: load the full pretrained checkpoint and keep only
        # its vision tower.
        vision_encoder = AutoModel.from_pretrained(
            config.vision_encoder_config.model_name
        ).vision_model

        # Language model: architecture from config only.
        language_model = AutoModelForCausalLM.from_config(config.text_config)
        check_embedding_fns(language_model)
        # Mirror the base LM's tied-weight keys under a prefix.
        # NOTE(review): the LM ends up at ``self.vlm.lang_model`` (see
        # ``update_special_tokens``), so the real parameter paths likely start
        # with ``vlm.lang_model.`` rather than ``language_model.`` -- confirm
        # whether these keys ever match.
        if language_model._tied_weights_keys is not None:
            self._tied_weights_keys = [
                f"language_model.{k}" for k in language_model._tied_weights_keys
            ]

        # Vision tokenizer: keep its ``lang_embedding_dim`` in sync with the
        # LM's input-embedding width, overwriting the config when they differ.
        lm_embedding_dim = language_model.get_input_embeddings().weight.shape[1]
        tokenizer_config = config.vision_tokenizer_config
        if tokenizer_config.lang_embedding_dim != lm_embedding_dim:
            overwrite = lm_embedding_dim
            tokenizer_config.lang_embedding_dim = overwrite
            # Use the warnings machinery (filterable, goes to stderr) instead
            # of printing to stdout.
            warnings.warn(
                f"Warning: The language embedding dimension in the vision tokenizer config is different from the language model's embedding dimension. Overwriting the language embedding dimension in the vision tokenizer config to {overwrite}.",
                stacklevel=2,
            )

        vision_tokenizer = XGenMMVisionTokenizer(tokenizer_config).model

        self.vlm = XGenMMPerceiver(
            vision_encoder=vision_encoder,
            vision_tokenizer=vision_tokenizer,
            lang_model=language_model,
            initial_tokenizer_len=config.text_config.initial_tokenizer_len,
            pad_token_id=config.text_config.pad_token_id,
            image_aspect_ratio=config.vision_encoder_config.image_aspect_ratio,
        )
        # Initialize weights and apply final processing
        self.post_init()

    @torch.no_grad()
    def generate(
        self,
        pixel_values: torch.FloatTensor,
        input_ids: Optional[torch.LongTensor] = None,
        attention_mask: Optional[torch.LongTensor] = None,
        **generate_kwargs,
    ) -> torch.LongTensor:
        """Generate token ids conditioned on images and an optional prompt.

        Args:
            pixel_values: image input, forwarded as ``vision_x``.
            input_ids: prompt token ids, forwarded as ``lang_x``.
            attention_mask: attention mask for ``input_ids``.
            **generate_kwargs: passed through to ``self.vlm.generate``.

        Returns:
            Whatever ``self.vlm.generate`` returns (annotated as token ids).

        Note: puts ``self.vlm`` into eval mode; the mode is not restored
        afterwards. Runs without gradient tracking.
        """
        self.vlm.eval()
        return self.vlm.generate(
            vision_x=pixel_values,
            lang_x=input_ids,
            attention_mask=attention_mask,
            **generate_kwargs,
        )

    def update_special_tokens(self, tokenizer):
        """Register the model's special tokens with ``tokenizer``.

        Adds ``self.vlm.special_tokens`` values as additional special tokens,
        sets the LM config's ``vocab_size`` to the new tokenizer length, and
        passes the token -> id mapping to ``self.vlm.set_special_token_ids``.

        Returns:
            The same ``tokenizer`` object, mutated in place.
        """
        special_tokens = list(self.vlm.special_tokens.values())
        tokenizer.add_special_tokens(
            {"additional_special_tokens": special_tokens}
        )
        self.vlm.lang_model.config.vocab_size = len(tokenizer)
        self.vlm.set_special_token_ids(
            {token: tokenizer.convert_tokens_to_ids(token) for token in special_tokens}
        )
        return tokenizer