LinCIR / encode_with_pseudo_tokens.py
Geonmo's picture
initial commit
cacafc1
raw history blame
No virus
2.52 kB
'''
LinCIR
Copyright (c) 2023-present NAVER Corp.
CC BY-NC-4.0 (https://creativecommons.org/licenses/by-nc/4.0/)
'''
import torch
from clip.model import CLIP
from transformers import CLIPTextModelWithProjection
def _make_causal_mask(
input_ids_shape: torch.Size, dtype: torch.dtype, device: torch.device, past_key_values_length: int = 0
):
"""
Make causal mask used for bi-directional self-attention.
Copy-paste from https://github.com/huggingface/transformers/blob/v4.34.0/src/transformers/models/clip/modeling_clip.py#L679-L693
"""
bsz, tgt_len = input_ids_shape
mask = torch.full((tgt_len, tgt_len), torch.finfo(dtype).min, device=device)
mask_cond = torch.arange(mask.size(-1), device=device)
mask.masked_fill_(mask_cond < (mask_cond + 1).view(mask.size(-1), 1), 0)
mask = mask.to(dtype)
if past_key_values_length > 0:
mask = torch.cat([torch.zeros(tgt_len, past_key_values_length, dtype=dtype, device=device), mask], dim=-1)
return mask[None, None, :, :].expand(bsz, 1, tgt_len, tgt_len + past_key_values_length)
def encode_with_pseudo_tokens_HF(clip_model: CLIPTextModelWithProjection, text: torch.Tensor, pseudo_tokens: torch.Tensor,
                                 num_tokens=1, return_last_states=False, pseudo_token_id: int = 259) -> torch.Tensor:
    """
    Encode token ids with a HuggingFace CLIP text model, injecting pseudo-token embeddings.

    Every position of ``text`` whose id equals ``pseudo_token_id`` has its token
    embedding replaced by that sample's row of ``pseudo_tokens``. The sequence
    then runs through the text encoder with a causal mask. The final-layer-normed
    state is pooled at the position of the largest token id in each row. This
    assumes the end-of-text token has the highest id, as in the original CLIP
    tokenizer; TODO confirm for other tokenizers.

    Args:
        clip_model: HF CLIP text model; its ``text_model`` submodules are called directly.
        text: integer token ids, assumed shape (batch_size, seq_len).
        pseudo_tokens: assumed shape (batch_size, d_model). Each row is broadcast
            to every placeholder position of the corresponding sample.
        num_tokens: unused; kept for call compatibility.
        return_last_states: if True, also return the layer-normed hidden states
            of all positions.
        pseudo_token_id: token id treated as the placeholder. It defaults to 259,
            the value previously hard-coded here. Presumably this is the id of the
            placeholder character in the prompt; verify against the caller's tokenizer.

    Returns:
        The pooled features, passed through ``text_projection`` when the model has
        one. If ``return_last_states`` is True, returns the tuple
        ``(features, last_hidden_states)`` instead.
    """
    embeddings = clip_model.text_model.embeddings
    seq_len = text.shape[-1]

    x = embeddings.token_embedding(text).type(clip_model.dtype)  # [batch_size, n_ctx, d_model]
    # Swap in the pseudo-token embedding wherever the placeholder id appears.
    x = torch.where(text.unsqueeze(-1) == pseudo_token_id,
                    pseudo_tokens.unsqueeze(1).type(clip_model.dtype),
                    x)
    # Slice the position-id buffer to the actual sequence length, as HF's
    # CLIPTextEmbeddings does. Otherwise inputs shorter than the model's maximum
    # context fail to broadcast. This is a no-op for full-length inputs.
    x = x + embeddings.position_embedding(embeddings.position_ids[:, :seq_len])

    _causal_attention_mask = _make_causal_mask(text.shape, x.dtype, device=x.device)
    x = clip_model.text_model.encoder(inputs_embeds=x,
                                      attention_mask=None,
                                      causal_attention_mask=_causal_attention_mask,
                                      output_attentions=False,
                                      output_hidden_states=False,
                                      return_dict=False)
    x = x[0]
    x_last = clip_model.text_model.final_layer_norm(x)

    # Pool each sequence at the index of its largest token id.
    eot_idx = text.to(dtype=torch.int, device=x_last.device).argmax(dim=-1)
    x = x_last[torch.arange(x_last.shape[0], device=x_last.device), eot_idx]

    if hasattr(clip_model, 'text_projection'):
        x = clip_model.text_projection(x)

    if return_last_states:
        return x, x_last
    return x