'''
LinCIR
Copyright (c) 2023-present NAVER Corp.
CC BY-NC-4.0 (https://creativecommons.org/licenses/by-nc/4.0/)
'''
import torch
from clip.model import CLIP
from transformers import CLIPTextModelWithProjection


def _make_causal_mask(
    input_ids_shape: torch.Size, dtype: torch.dtype, device: torch.device, past_key_values_length: int = 0
):
    """
    Make the causal (uni-directional) self-attention mask.
    Copied from https://github.com/huggingface/transformers/blob/v4.34.0/src/transformers/models/clip/modeling_clip.py#L679-L693
    """
    bsz, tgt_len = input_ids_shape
    mask = torch.full((tgt_len, tgt_len), torch.finfo(dtype).min, device=device)
    mask_cond = torch.arange(mask.size(-1), device=device)
    mask.masked_fill_(mask_cond < (mask_cond + 1).view(mask.size(-1), 1), 0)
    mask = mask.to(dtype)

    if past_key_values_length > 0:
        mask = torch.cat([torch.zeros(tgt_len, past_key_values_length, dtype=dtype, device=device), mask], dim=-1)
    return mask[None, None, :, :].expand(bsz, 1, tgt_len, tgt_len + past_key_values_length)


def encode_with_pseudo_tokens_HF(clip_model: CLIPTextModelWithProjection, text: torch.Tensor,
                                 pseudo_tokens: torch.Tensor, num_tokens=1,
                                 return_last_states=False) -> torch.Tensor:
    """
    Encode tokenized text with the HF CLIP text encoder, replacing the embedding of
    the placeholder token (id 259, i.e. '$' in the CLIP vocabulary) with the given
    pseudo tokens.
    """
    # Token embeddings: [batch_size, n_ctx, d_model]
    x = clip_model.text_model.embeddings.token_embedding(text).type(clip_model.dtype)
    # Swap the placeholder token's embedding for the pseudo token embedding.
    x = torch.where(text.unsqueeze(-1) == 259,
                    pseudo_tokens.unsqueeze(1).type(clip_model.dtype),
                    x)
    # Add positional embeddings and run the transformer encoder with a causal mask.
    x = x + clip_model.text_model.embeddings.position_embedding(clip_model.text_model.embeddings.position_ids)
    _causal_attention_mask = _make_causal_mask(text.shape, x.dtype, device=x.device)
    x = clip_model.text_model.encoder(inputs_embeds=x,
                                      attention_mask=None,
                                      causal_attention_mask=_causal_attention_mask,
                                      output_attentions=False,
                                      output_hidden_states=False,
                                      return_dict=False)
    x = x[0]
    x_last = clip_model.text_model.final_layer_norm(x)
    # Pool the features at the EOT token (the position holding the highest token id).
    x = x_last[torch.arange(x_last.shape[0], device=x_last.device),
               text.to(dtype=torch.int, device=x_last.device).argmax(dim=-1)]
    if hasattr(clip_model, 'text_projection'):
        x = clip_model.text_projection(x)

    if return_last_states:
        return x, x_last
    return x
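

# ---------------------------------------------------------------------------
# Usage sketch (not part of the original module): a minimal, hedged example of
# how encode_with_pseudo_tokens_HF might be called. The checkpoint name
# 'openai/clip-vit-large-patch14' is an assumption, and random vectors stand in
# for the pseudo tokens that LinCIR's textual-inversion network would produce.
# ---------------------------------------------------------------------------
if __name__ == '__main__':
    from transformers import CLIPTokenizer

    model_name = 'openai/clip-vit-large-patch14'  # assumed checkpoint
    tokenizer = CLIPTokenizer.from_pretrained(model_name)
    text_encoder = CLIPTextModelWithProjection.from_pretrained(model_name).eval()

    # '$' acts as the placeholder word; it tokenizes to id 259, which the
    # function above replaces with the given pseudo token embedding.
    captions = ['a photo of $ on the beach']
    tokens = tokenizer(captions, padding='max_length', truncation=True,
                       max_length=77, return_tensors='pt').input_ids

    # Stand-in pseudo tokens: one vector per caption, matching the text
    # encoder's hidden size.
    pseudo_tokens = torch.randn(len(captions), text_encoder.config.hidden_size)

    with torch.no_grad():
        text_features = encode_with_pseudo_tokens_HF(text_encoder, tokens, pseudo_tokens)
    print(text_features.shape)  # (1, projection_dim), e.g. (1, 768) for ViT-L/14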