from transformers import VisionTextDualEncoderConfig


class VTDEConfig(VisionTextDualEncoderConfig):
    """VisionTextDualEncoderConfig extended with per-tower pooling modes.

    Adds ``text_pooling_mode`` and ``vision_pooling_mode`` attributes on top of
    the parent configuration. This class only validates and records them; the
    model consuming the config is responsible for applying the pooling.

    Pooling strategies follow:
      - https://arxiv.org/pdf/2210.09996.pdf
      - https://github.com/kahnchana/clippy/blob/3c102c29c32f7c66c6e52e09b795fe9c061bbb03/src/open_clip/hf_model.py#L56
    """

    # Accepted values for both pooling-mode arguments. This is a class
    # attribute, so it is not serialized into the config dict.
    _POOLING_MODES = ("mean", "max", "cls")

    def __init__(
        self,
        projection_dim: int = 512,
        logit_scale_init_value: float = 2.6592,
        text_pooling_mode: str = "mean",
        vision_pooling_mode: str = "max",
        **kwargs,
    ):
        """Build the config.

        Args:
            projection_dim: Forwarded to the parent config unchanged.
            logit_scale_init_value: Forwarded to the parent config unchanged
                (2.6592 ~= ln(1 / 0.07); presumably the CLIP-style initial
                temperature -- confirm against the model code).
            text_pooling_mode: Pooling for the text tower; one of
                ``'mean'``, ``'max'``, ``'cls'``.
            vision_pooling_mode: Pooling for the vision tower; one of
                ``'mean'``, ``'max'``, ``'cls'``.
            **kwargs: Forwarded unchanged to
                ``VisionTextDualEncoderConfig.__init__``.

        Raises:
            ValueError: If either pooling mode is not one of the accepted values.
        """
        # Fail fast on an unsupported mode instead of storing it silently and
        # failing later inside the model.
        for arg_name, mode in (
            ("text_pooling_mode", text_pooling_mode),
            ("vision_pooling_mode", vision_pooling_mode),
        ):
            if mode not in self._POOLING_MODES:
                raise ValueError(
                    f"{arg_name} must be one of {self._POOLING_MODES}, got {mode!r}"
                )

        self.text_pooling_mode = text_pooling_mode
        self.vision_pooling_mode = vision_pooling_mode
        # Pass by keyword so a change in the parent's positional parameter
        # order cannot silently swap these values.
        super().__init__(
            projection_dim=projection_dim,
            logit_scale_init_value=logit_scale_init_value,
            **kwargs,
        )