import torch.nn as nn

from .xattn import CrossAttentionBlock
from .utils import getattr_recursive, setattr_recursive


class WrapperLayer(nn.Module):
    """
    WrapperLayer is a wrapper around a CrossAttentionBlock and a decoder layer.
    """

    def __init__(
        self, cross_attn_layer, decoder_layer, gradient_checkpointing=False
    ):
        super().__init__()
        self.cross_attn_layer = cross_attn_layer
        self.decoder_layer = decoder_layer
        self.vis_x = None  # cached visual features used for cross-attention
        if self.cross_attn_layer is not None:
            self.cross_attn_layer._use_gradient_checkpointing = (
                gradient_checkpointing
            )
        self.decoder_layer._use_gradient_checkpointing = gradient_checkpointing

    def is_conditioned(self) -> bool:
        """Check whether the layer is conditioned."""
        return self.vis_x is not None

    # Conditioning-by-caching idea borrowed from the flamingo-mini implementation
    # of Flamingo (https://github.com/dhansmair/flamingo-mini/).
    def condition_vis_x(self, vis_x):
        self.vis_x = vis_x

    def forward(
        self,
        lang_x,
        attention_mask=None,
        **decoder_layer_kwargs,
    ):
        # Cross attention over the cached visual features
        if self.cross_attn_layer is not None:
            if self.vis_x is None:
                raise ValueError("vis_x must be conditioned before forward pass")
            lang_x = self.cross_attn_layer(
                lang_x,
                self.vis_x,
            )

        # Normal decoder layer
        lang_x = self.decoder_layer(
            lang_x, attention_mask=attention_mask, **decoder_layer_kwargs
        )
        return lang_x
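

# A minimal usage sketch (not part of the original file) of how a single WrapperLayer
# is driven, assuming `cross_attn` is a CrossAttentionBlock and `decoder` is a
# Hugging Face-style decoder layer; tensor names are illustrative only:
#
#   layer = WrapperLayer(cross_attn, decoder)
#   layer.condition_vis_x(vis_x)            # cache visual features first
#   assert layer.is_conditioned()
#   out = layer(lang_x, attention_mask=attention_mask)
#   layer.condition_vis_x(None)             # clear the cache when done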


class phEYELMMixin(nn.Module):
    """
    Mixin to add cross-attention layers to a language model.
    """

    def set_decoder_layers_attr_name(self, decoder_layers_attr_name):
        self.decoder_layers_attr_name = decoder_layers_attr_name

    def _get_decoder_layers(self):
        return getattr_recursive(self, self.decoder_layers_attr_name)

    def _set_decoder_layers(self, value):
        setattr_recursive(self, self.decoder_layers_attr_name, value)

    def init_pheye(
        self,
        lang_hidden_size,
        vis_hidden_size,
        dtype,
        cross_attn_every_n_layers,
        gradient_checkpointing,
        reduce_factor=1,
        from_layer=0,
    ):
        """
        Initialize phEYE by adding new cross-attention layers to the decoder.
        """
        self.old_decoder_blocks = self._get_decoder_layers()
        self.cross_attn_layers = nn.ModuleList(
            [
                CrossAttentionBlock(
                    dim_text=lang_hidden_size,
                    dim_visual=vis_hidden_size,
                    reduce_factor=reduce_factor,
                    layer_idx=layer_idx,
                    n_decoder_layers=len(self.old_decoder_blocks),
                    dtype=dtype,
                )
                if (layer_idx + 1) % cross_attn_every_n_layers == 0
                and layer_idx >= from_layer
                else None
                for layer_idx, _ in enumerate(self._get_decoder_layers())
            ]
        )
        self.init_pheye_layers(gradient_checkpointing)
        self.initialized_pheye = True
        self._use_cached_vision_x = False
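        # A minimal sketch (not part of the original file) of how the selection rule
        # above behaves; the layer count and hyperparameters here are hypothetical:
        #
        #   >>> n_layers, every_n, from_layer = 12, 3, 0
        #   >>> [i for i in range(n_layers) if (i + 1) % every_n == 0 and i >= from_layer]
        #   [2, 5, 8, 11]
        #
        # Only those indices receive a CrossAttentionBlock; the remaining entries stay
        # None and their decoder layers run unchanged.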

    def init_pheye_layers(self, gradient_checkpointing):
        """
        Re-initializes the WrapperLayers.
        Propagates any changes made to self.cross_attn_layers or self.old_decoder_blocks.
        """
        self._set_decoder_layers(
            nn.ModuleList(
                [
                    WrapperLayer(
                        cross_attn_layer, decoder_layer, gradient_checkpointing
                    )
                    for cross_attn_layer, decoder_layer in zip(
                        self.cross_attn_layers, self.old_decoder_blocks
                    )
                ]
            )
        )

    def forward(self, input_ids, attention_mask, **kwargs):
        # Use getattr so that calling forward before init_pheye raises the intended
        # ValueError instead of an AttributeError.
        if not getattr(self, "initialized_pheye", False):
            raise ValueError(
                "phEYE layers are not initialized. Please call `init_pheye` first."
            )
        kwargs["input_ids"] = input_ids
        kwargs["attention_mask"] = attention_mask
        return super().forward(**kwargs)  # Call the other parent's forward method

    def is_conditioned(self) -> bool:
        """Check whether all decoder layers are already conditioned."""
        return all(layer.is_conditioned() for layer in self._get_decoder_layers())

    def clear_conditioned_layers(self):
        for layer in self._get_decoder_layers():
            layer.condition_vis_x(None)
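

# A minimal end-to-end sketch (not part of the original file) of how this mixin could be
# applied to a decoder-only language model. `extend_instance`, the attribute path
# "model.layers", and all hyperparameter values below are assumptions for illustration
# and may differ from the actual codebase:
#
#   lang_model = ...  # e.g. a Hugging Face causal LM
#   extend_instance(lang_model, phEYELMMixin)    # make the instance inherit the mixin
#   lang_model.set_decoder_layers_attr_name("model.layers")
#   lang_model.init_pheye(
#       lang_hidden_size=lang_model.config.hidden_size,
#       vis_hidden_size=1024,
#       dtype=torch.float32,
#       cross_attn_every_n_layers=2,
#       gradient_checkpointing=False,
#   )
#   for layer in lang_model._get_decoder_layers():
#       layer.condition_vis_x(vis_x)             # vis_x: precomputed visual features
#   out = lang_model(input_ids, attention_mask)
#   lang_model.clear_conditioned_layers()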