import torch
import transformers
from transformers import T5Tokenizer, T5EncoderModel, T5Config
from einops import rearrange

transformers.logging.set_verbosity_error()


def exists(val):
    return val is not None


# config
MAX_LENGTH = 256
DEFAULT_T5_NAME = 'google/t5-v1_1-base'
T5_CONFIGS = {}


# singleton globals
def get_tokenizer(name):
    tokenizer = T5Tokenizer.from_pretrained(name)
    return tokenizer


def get_model(name):
    model = T5EncoderModel.from_pretrained(name)
    return model


def get_model_and_tokenizer(name):
    # lazily load the model and tokenizer, caching them per checkpoint name
    global T5_CONFIGS

    if name not in T5_CONFIGS:
        T5_CONFIGS[name] = dict()
    if "model" not in T5_CONFIGS[name]:
        T5_CONFIGS[name]["model"] = get_model(name)
    if "tokenizer" not in T5_CONFIGS[name]:
        T5_CONFIGS[name]["tokenizer"] = get_tokenizer(name)

    return T5_CONFIGS[name]['model'], T5_CONFIGS[name]['tokenizer']
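
# Usage note (sketch): repeated calls with the same checkpoint name hit the
# in-memory cache above, so weights and tokenizer are only loaded once per process.
#   m1, tok1 = get_model_and_tokenizer(DEFAULT_T5_NAME)
#   m2, tok2 = get_model_and_tokenizer(DEFAULT_T5_NAME)
#   assert m1 is m2 and tok1 is tok2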


def get_encoded_dim(name):
    if name not in T5_CONFIGS:
        # avoids loading the model if we only want to get the dim
        config = T5Config.from_pretrained(name)
        T5_CONFIGS[name] = dict(config=config)
    elif "config" in T5_CONFIGS[name]:
        config = T5_CONFIGS[name]["config"]
    elif "model" in T5_CONFIGS[name]:
        config = T5_CONFIGS[name]["model"].config
    else:
        assert False
    return config.d_model
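
# Usage note (sketch): get_encoded_dim only fetches the checkpoint config, so it
# is cheap to call before any weights are downloaded. For 'google/t5-v1_1-base'
# the reported dimension (config.d_model) is 768.
#   dim = get_encoded_dim(DEFAULT_T5_NAME)  # -> 768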


class T5Encoder(torch.nn.Module):
    def __init__(self, name=DEFAULT_T5_NAME, max_length=MAX_LENGTH, padding='longest', masked_mean=False):
        super().__init__()
        self.name = name
        self.t5, self.tokenizer = get_model_and_tokenizer(name)
        self.max_length = max_length
        self.output_size = get_encoded_dim(name)
        self.padding = padding
        # note: pooling in forward() is always mask-aware; this flag is currently unused
        self.masked_mean = masked_mean

    def forward(self, x, return_only_pooled=True):
        # tokenize a batch of strings, padding and truncating to at most max_length tokens
        encoded = self.tokenizer.batch_encode_plus(
            x,
            return_tensors="pt",
            padding=self.padding,
            max_length=self.max_length,
            truncation=True
        )

        device = next(self.t5.parameters()).device
        input_ids = encoded.input_ids.to(device)
        attn_mask = encoded.attention_mask.to(device).bool()

        # encode with the T5 encoder; gradients are not propagated through the text encoder
        output = self.t5(input_ids=input_ids, attention_mask=attn_mask)
        encoded_text = output.last_hidden_state.detach()

        # mean-pool the token embeddings, ignoring padded positions
        pooled = masked_mean(encoded_text, dim=1, mask=attn_mask)
        if return_only_pooled:
            return pooled
        return pooled, encoded_text, attn_mask


def masked_mean(t, *, dim, mask=None):
    # average over `dim`, counting only the positions where mask is True
    if not exists(mask):
        return t.mean(dim=dim)

    denom = mask.sum(dim=dim, keepdim=True)
    mask = rearrange(mask, 'b n -> b n 1')
    masked_t = t.masked_fill(~mask, 0.)
    return masked_t.sum(dim=dim) / denom.clamp(min=1e-5)
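

# Minimal usage sketch: assumes the default checkpoint ('google/t5-v1_1-base'),
# which is downloaded from the Hugging Face hub on first use; the example
# sentences below are arbitrary placeholders.
if __name__ == '__main__':
    encoder = T5Encoder()
    texts = ['a photo of a dog', 'a painting of a castle at sunset']
    with torch.no_grad():
        pooled, tokens, mask = encoder(texts, return_only_pooled=False)
    print(pooled.shape)  # (2, encoder.output_size), i.e. (2, 768) for t5-v1_1-base
    print(tokens.shape)  # (2, seq_len, 768), padded to the longest example
    print(mask.shape)    # (2, seq_len) boolean attention mask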