import torch
import transformers
from transformers import T5Tokenizer, T5EncoderModel, T5Config
from einops import rearrange

transformers.logging.set_verbosity_error()


def exists(val):
    return val is not None


# config
MAX_LENGTH = 256
DEFAULT_T5_NAME = 'google/t5-v1_1-base'
T5_CONFIGS = {}


# singleton globals
def get_tokenizer(name):
    tokenizer = T5Tokenizer.from_pretrained(name)
    return tokenizer


def get_model(name):
    model = T5EncoderModel.from_pretrained(name)
    return model


def get_model_and_tokenizer(name):
    global T5_CONFIGS

    if name not in T5_CONFIGS:
        T5_CONFIGS[name] = dict()
    if "model" not in T5_CONFIGS[name]:
        T5_CONFIGS[name]["model"] = get_model(name)
    if "tokenizer" not in T5_CONFIGS[name]:
        T5_CONFIGS[name]["tokenizer"] = get_tokenizer(name)

    return T5_CONFIGS[name]['model'], T5_CONFIGS[name]['tokenizer']


def get_encoded_dim(name):
    if name not in T5_CONFIGS:
        # avoids loading the model if we only want to get the dim
        config = T5Config.from_pretrained(name)
        T5_CONFIGS[name] = dict(config=config)
    elif "config" in T5_CONFIGS[name]:
        config = T5_CONFIGS[name]["config"]
    elif "model" in T5_CONFIGS[name]:
        config = T5_CONFIGS[name]["model"].config
    else:
        assert False
    return config.d_model


class T5Encoder(torch.nn.Module):
    def __init__(self, name=DEFAULT_T5_NAME, max_length=MAX_LENGTH, padding='longest', masked_mean=False):
        super().__init__()
        self.name = name
        self.t5, self.tokenizer = get_model_and_tokenizer(name)
        self.max_length = max_length
        self.output_size = get_encoded_dim(name)
        self.padding = padding
        self.masked_mean = masked_mean

    def forward(self, x, return_only_pooled=True):
        encoded = self.tokenizer.batch_encode_plus(
            x,
            return_tensors="pt",
            padding=self.padding,
            max_length=self.max_length,
            truncation=True,
        )
        device = next(self.t5.parameters()).device
        input_ids = encoded.input_ids.to(device)
        attn_mask = encoded.attention_mask.to(device).bool()

        output = self.t5(input_ids=input_ids, attention_mask=attn_mask)
        encoded_text = output.last_hidden_state.detach()

        # pool the per-token embeddings with a mask-aware mean so that padding
        # tokens do not dilute the average (the `masked_mean` flag stored in
        # __init__ is kept for compatibility; masked pooling is always used here)
        pooled = masked_mean(encoded_text, dim=1, mask=attn_mask)
        if return_only_pooled:
            return pooled
        return pooled, encoded_text, attn_mask


def masked_mean(t, *, dim, mask=None):
    # mean over `dim`; positions where `mask` is False (e.g. padding) are excluded
    if not exists(mask):
        return t.mean(dim=dim)

    denom = mask.sum(dim=dim, keepdim=True)
    mask = rearrange(mask, 'b n -> b n 1')
    masked_t = t.masked_fill(~mask, 0.)
    return masked_t.sum(dim=dim) / denom.clamp(min=1e-5)
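

# Minimal usage sketch (assumes torch, transformers, sentencepiece and einops
# are installed; the example prompts below are purely illustrative and not part
# of the module's API):
if __name__ == "__main__":
    encoder = T5Encoder(name=DEFAULT_T5_NAME, max_length=MAX_LENGTH)
    encoder.eval()

    prompts = ["a photo of a cat", "an oil painting of a windmill at sunset"]
    with torch.no_grad():
        pooled, per_token, mask = encoder(prompts, return_only_pooled=False)

    # pooled:    (batch, d_model)          mask-aware mean over tokens
    # per_token: (batch, seq_len, d_model) last hidden state of the T5 encoder
    # mask:      (batch, seq_len)          boolean attention mask
    print(pooled.shape, per_token.shape, mask.shape, encoder.output_size)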