Spaces:
Running
on
Zero
Running
on
Zero
File size: 3,752 Bytes
ceb87f6 97214bc |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 |
from typing import ClassVar
import torch
from torch import nn
from modeling_florence2 import Florence2ForConditionalGeneration, Florence2VisionLanguageModel
from configuration_florence2 import Florence2Config
class ColFlor2Old(Florence2ForConditionalGeneration):
    """
    ColFlor2 model implementation from the "ColPali: Efficient Document Retrieval with Vision Language Models" paper.

    Extends ``Florence2ForConditionalGeneration`` with a linear head
    (``custom_text_proj``) that maps the encoder's last hidden states to
    ``self.dim``-dimensional, L2-normalized per-position embeddings.
    """

    main_input_name: ClassVar[str] = "doc_input_ids"  # transformers-related

    def __init__(self, config: Florence2Config, use_cache=False):
        """
        Build the base model and attach the projection head.

        Args:
            config: Model configuration; ``text_config.d_model`` is used as the
                input width of the projection head.
            use_cache: Accepted for signature compatibility; this constructor
                does not read it.
        """
        super().__init__(config=config)
        self.dim = 128
        self.custom_text_proj = nn.Linear(self.config.text_config.d_model, self.dim)

        # Explicit initialization of the projection head.
        self.custom_text_proj.weight.data.normal_(mean=0.0, std=0.02)
        self.custom_text_proj.bias.data.zero_()

        self.padding_side = "right"

        # NOTE(review): post_init() runs after the manual init above; if the
        # base class's weight initialization also visits nn.Linear modules it
        # may overwrite these values -- confirm against the base class.
        self.post_init()

    def forward(self, *args, **kwargs) -> torch.Tensor:
        """
        Run the base model and return masked, L2-normalized embeddings.

        Keyword arguments consumed here (everything else is forwarded to the
        parent ``forward`` unchanged):
            output_hidden_states: Dropped if present.
            input_ids: Required as a keyword; also forwarded as
                ``decoder_input_ids``.
            full_attention_mask: Optional. When present it is removed from
                ``kwargs`` and used to mask the output.
            attention_mask: Used to mask the output when
                ``full_attention_mask`` is absent (and still forwarded).

        Returns:
            Tensor whose last dimension is ``self.dim``; positions where the
            mask is 0 are zeroed.
            (assumes shape (batch_size, sequence_length, dim) -- TODO confirm)

        Raises:
            KeyError: If ``input_ids`` is not passed as a keyword, or if
                neither ``full_attention_mask`` nor ``attention_mask`` is.
        """
        # Delete output_hidden_states from kwargs
        kwargs.pop("output_hidden_states", None)

        # TO BE DELETED
        kwargs["decoder_input_ids"] = kwargs["input_ids"]

        # Prefer the full attention mask (the one that includes the image);
        # it must not reach the parent forward, so it is popped here.
        if "full_attention_mask" in kwargs:
            full_attention_mask = kwargs.pop("full_attention_mask")
        else:
            full_attention_mask = kwargs["attention_mask"]

        outputs = super().forward(*args, **kwargs)
        # (batch_size, sequence_length, hidden_size)
        last_hidden_states = outputs["encoder_last_hidden_state"]
        # (batch_size, sequence_length, dim)
        proj = self.custom_text_proj(last_hidden_states)

        # L2 normalization. The norm is clamped to the smallest positive
        # normal value of the dtype so an all-zero row yields zeros instead of
        # NaN (0 / 0); NaN would also survive the mask multiply below.
        norms = proj.norm(dim=-1, keepdim=True)
        proj = proj / norms.clamp(min=torch.finfo(proj.dtype).tiny)

        # Zero out masked positions.
        proj = proj * full_attention_mask.unsqueeze(-1)
        return proj
class ColFlor(Florence2VisionLanguageModel):
    """
    ColFlor model implementation from the "ColPali: Efficient Document Retrieval with Vision Language Models" paper.

    Extends ``Florence2VisionLanguageModel`` with a linear head
    (``custom_text_proj``) that maps the encoder's last hidden states to
    ``self.dim``-dimensional, L2-normalized per-position embeddings.
    """

    main_input_name: ClassVar[str] = "doc_input_ids"  # transformers-related

    def __init__(self, config: Florence2Config, use_cache=False):
        """
        Build the base model and attach the projection head.

        Args:
            config: Model configuration; ``text_config.d_model`` is used as the
                input width of the projection head.
            use_cache: Accepted for signature compatibility; this constructor
                does not read it.
        """
        super().__init__(config=config)
        self.dim = 128
        self.custom_text_proj = nn.Linear(self.config.text_config.d_model, self.dim)

        # Explicit initialization of the projection head.
        self.custom_text_proj.weight.data.normal_(mean=0.0, std=0.02)
        self.custom_text_proj.bias.data.zero_()

        self.padding_side = "right"

        # NOTE(review): post_init() runs after the manual init above; if the
        # base class's weight initialization also visits nn.Linear modules it
        # may overwrite these values -- confirm against the base class.
        self.post_init()

    def forward(self, *args, **kwargs) -> torch.Tensor:
        """
        Run the base model and return masked, L2-normalized embeddings.

        Keyword arguments consumed here (everything else is forwarded to the
        parent ``forward`` unchanged):
            output_hidden_states: Dropped if present.
            full_attention_mask: Optional. When present it is removed from
                ``kwargs`` and used to mask the output.
            attention_mask: Used to mask the output when
                ``full_attention_mask`` is absent (and still forwarded).

        Returns:
            Tensor whose last dimension is ``self.dim``; positions where the
            mask is 0 are zeroed.
            (assumes shape (batch_size, sequence_length, dim) -- TODO confirm)

        Raises:
            KeyError: If neither ``full_attention_mask`` nor
                ``attention_mask`` is passed as a keyword.
        """
        # Delete output_hidden_states from kwargs
        kwargs.pop("output_hidden_states", None)

        # Prefer the full attention mask (the one that includes both the image
        # and text); it must not reach the parent forward, so it is popped.
        if "full_attention_mask" in kwargs:
            full_attention_mask = kwargs.pop("full_attention_mask")
        else:
            full_attention_mask = kwargs["attention_mask"]

        outputs = super().forward(*args, **kwargs)
        # (batch_size, sequence_length, hidden_size)
        last_hidden_states = outputs["encoder_last_hidden_state"]
        # (batch_size, sequence_length, dim)
        proj = self.custom_text_proj(last_hidden_states)

        # L2 normalization. The norm is clamped to the smallest positive
        # normal value of the dtype so an all-zero row yields zeros instead of
        # NaN (0 / 0); NaN would also survive the mask multiply below.
        norms = proj.norm(dim=-1, keepdim=True)
        proj = proj / norms.clamp(min=torch.finfo(proj.dtype).tiny)

        # Zero out masked positions.
        proj = proj * full_attention_mask.unsqueeze(-1)
        return proj