from Multilingual_CLIP.multilingual_clip import Config_MCLIP
import transformers
import torch


class MultilingualCLIP(transformers.PreTrainedModel):
    config_class = Config_MCLIP.MCLIPConfig

    def __init__(self, config, *args, **kwargs):
        super().__init__(config, *args, **kwargs)
        # Text encoder backbone specified by the M-CLIP config
        self.transformer = transformers.AutoModel.from_pretrained(config.modelBase)
        # Projection from the text encoder's hidden size into the CLIP embedding space
        self.LinearTransformation = torch.nn.Linear(
            in_features=config.transformerDimensions, out_features=config.numDims
        )

    def forward(self, txt, tokenizer, device):
        txt_tok = tokenizer(
            txt, padding='max_length', max_length=77, truncation=True, return_tensors='pt'
        ).to(device)
        # Last hidden state of the text encoder: (batch, seq_len, hidden)
        embs = self.transformer(**txt_tok)[0]
        # Attention-mask-weighted mean pooling over the sequence dimension
        att = txt_tok['attention_mask']
        embs = (embs * att.unsqueeze(2)).sum(dim=1) / att.sum(dim=1)[:, None]
        # Project the pooled sentence embeddings into the CLIP embedding space
        return self.LinearTransformation(embs)

    @classmethod
    def _load_state_dict_into_model(cls, model, state_dict, pretrained_model_name_or_path, _fast_init=True):
        model.load_state_dict(state_dict)
        return model, [], [], []
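

# ---------------------------------------------------------------------------
# Minimal usage sketch. Assumptions: the checkpoint name below is a
# hypothetical M-CLIP weight identifier (substitute whichever pretrained
# checkpoint you actually use), and that checkpoint's config exposes
# modelBase, transformerDimensions, and numDims as expected by MCLIPConfig,
# with a compatible tokenizer hosted alongside it.
# ---------------------------------------------------------------------------
if __name__ == "__main__":
    model_name = 'M-CLIP/M-BERT-Base-ViT-B'  # assumed checkpoint name; replace as needed
    device = 'cuda' if torch.cuda.is_available() else 'cpu'

    model = MultilingualCLIP.from_pretrained(model_name).to(device)
    tokenizer = transformers.AutoTokenizer.from_pretrained(model_name)

    texts = ['a photo of a dog', 'Ein Foto von einem Hund']
    with torch.no_grad():
        text_embeddings = model(texts, tokenizer, device)
    print(text_embeddings.shape)  # (batch_size, numDims)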