import os
from typing import List

import numpy as np
import torch
import torch.nn as nn
from torch import Tensor
from torch.nn.functional import normalize
from transformers import AutoModel, AutoTokenizer, logging


class PositionalEncoding(nn.Module):
    """Standard sinusoidal positional encoding, added to a [seq_len, batch, d_model] input."""

    def __init__(self, d_model, max_len=5000):
        super().__init__()
        pe = torch.zeros(max_len, d_model)
        position = torch.arange(0, max_len, dtype=torch.float).unsqueeze(1)
        div_term = torch.exp(torch.arange(0, d_model, 2).float() * (-np.log(10000.0) / d_model))
        pe[:, 0::2] = torch.sin(position * div_term)
        pe[:, 1::2] = torch.cos(position * div_term)
        # [max_len, 1, d_model] so it broadcasts over the batch dimension
        pe = pe.unsqueeze(0).transpose(0, 1)
        self.register_buffer('pe', pe, persistent=False)

    def forward(self, x):
        return x + self.pe[:x.shape[0], :]
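
# Minimal shape sketch for PositionalEncoding (d_model=256 is an illustrative
# value, not one fixed by this file): the module expects inputs shaped
# [seq_len, batch, d_model] and returns the same shape with the sinusoidal
# encoding added, e.g.
#   pos_enc = PositionalEncoding(d_model=256)
#   out = pos_enc(torch.zeros(10, 4, 256))  # -> torch.Size([10, 4, 256])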


class TMR_textencoder(nn.Module):
    def __init__(self, modelpath: str, latent_dim: int, ff_size: int,
                 num_layers: int, num_heads: int, activation: str, **kwargs) -> None:
        super().__init__()
        logging.set_verbosity_error()

        # Tokenizer
        os.environ["TOKENIZERS_PARALLELISM"] = "false"
        self.tokenizer = AutoTokenizer.from_pretrained(modelpath)

        # Text model
        self.text_model = AutoModel.from_pretrained(modelpath)
        # Dimension of the pretrained text encoder outputs
        self.text_encoded_dim = self.text_model.config.hidden_size

        # Projection of the text outputs into the latent space
        self.projection = nn.Sequential(
            nn.ReLU(),
            nn.Linear(self.text_encoded_dim, latent_dim)
        )

        # Learnable distribution tokens prepended to every sequence
        self.mu_token = nn.Parameter(torch.randn(latent_dim))
        self.logvar_token = nn.Parameter(torch.randn(latent_dim))

        self.sequence_pos_encoding = PositionalEncoding(latent_dim)
        seq_trans_encoder_layer = nn.TransformerEncoderLayer(d_model=latent_dim,
                                                             nhead=num_heads,
                                                             dim_feedforward=ff_size,
                                                             dropout=0.0,
                                                             activation=activation)
        self.seqTransEncoder = nn.TransformerEncoder(
            seq_trans_encoder_layer,
            num_layers=num_layers
        )

    def get_last_hidden_state(self, texts: List[str],
                              return_mask: bool = False):
        encoded_inputs = self.tokenizer(texts, return_tensors="pt", padding=True)
        output = self.text_model(**encoded_inputs.to(self.text_model.device))
        if not return_mask:
            return output.last_hidden_state
        return output.last_hidden_state, encoded_inputs.attention_mask.to(dtype=bool)
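
    # Shape sketch (illustrative): for a batch of 2 captions padded to L tokens,
    # get_last_hidden_state returns a [2, L, hidden_size] tensor, plus a [2, L]
    # boolean attention mask when return_mask=True.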

    def forward(self, texts: List[str]) -> Tensor:
        text_encoded, mask = self.get_last_hidden_state(texts, return_mask=True)

        x = self.projection(text_encoded)
        bs, nframes, _ = x.shape
        # Switch sequence and batch size because the input of the
        # PyTorch Transformer is [sequence, batch_size, ...]
        x = x.permute(1, 0, 2)  # now it is [nframes, bs, latent_dim]

        mu_token = torch.tile(self.mu_token, (bs,)).reshape(bs, -1)
        logvar_token = torch.tile(self.logvar_token, (bs,)).reshape(bs, -1)

        # adding the distribution tokens for all sequences
        xseq = torch.cat((mu_token[None], logvar_token[None], x), 0)

        # create a bigger mask, to allow attending to mu and logvar
        token_mask = torch.ones((bs, 2), dtype=bool, device=x.device)
        aug_mask = torch.cat((token_mask, mask), 1)

        # add positional encoding
        xseq = self.sequence_pos_encoding(xseq)
        final = self.seqTransEncoder(xseq, src_key_padding_mask=~aug_mask)

        # only mu is used for inference
        mu = final[0]
        return mu
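
    # Usage sketch (shapes are illustrative): forward maps a list of captions to
    # one latent vector per caption, e.g. self(["a person jumps"]) returns a
    # [1, latent_dim] tensor, which compute_scores below compares against a bank
    # of precomputed embeddings.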

    # compute score for retrieval
    def compute_scores(self, texts, unit_embs=None, embs=None):
        # exactly one of unit_embs / embs must be provided
        assert not (unit_embs is None and embs is None)
        assert not (unit_embs is not None and embs is not None)

        output_str = False
        # if a single string is given, squeeze the output
        if isinstance(texts, str):
            texts = [texts]
            output_str = True

        # compute unit_embs from embs if not given
        if embs is not None:
            unit_embs = normalize(embs)

        with torch.no_grad():
            latent_unit_texts = normalize(self(texts))

        # rescale the cosine similarity from [-1, 1] to [0, 1]
        scores = (unit_embs @ latent_unit_texts.T).T / 2 + 0.5
        scores = scores.cpu().numpy()

        if output_str:
            scores = scores[0]
        return scores
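

if __name__ == "__main__":
    # Minimal usage sketch. The checkpoint name and hyperparameters below are
    # illustrative assumptions, not values fixed by this file.
    model = TMR_textencoder(
        modelpath="distilbert-base-uncased",  # assumed Hugging Face encoder checkpoint
        latent_dim=256,
        ff_size=1024,
        num_layers=6,
        num_heads=4,
        activation="gelu",
    )
    model.eval()

    with torch.no_grad():
        text_emb = model(["a person walks forward and waves"])  # [1, latent_dim]
    print(text_emb.shape)

    # Given a bank of precomputed motion embeddings `embs` of shape [N, latent_dim]:
    #   scores = model.compute_scores("a person walks forward", embs=embs)  # [N]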