|
|
|
|
|
import logging |
|
import math |
|
import os |
|
|
|
import torch |
|
from torch import nn |
|
from transformers import CLIPTokenizer, T5TokenizerFast |
|
|
|
|
|
|
|
|
|
|
|
|
|
def attention(q, k, v, heads, mask=None): |
|
"""Convenience wrapper around a basic attention operation""" |
|
b, _, dim_head = q.shape |
|
dim_head //= heads |
|
q, k, v = map(lambda t: t.view(b, -1, heads, dim_head).transpose(1, 2), (q, k, v)) |
|
out = torch.nn.functional.scaled_dot_product_attention( |
|
q, k, v, attn_mask=mask, dropout_p=0.0, is_causal=False |
|
) |
|
return out.transpose(1, 2).reshape(b, -1, heads * dim_head) |
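
# Shape sketch for the helper above (illustrative, not part of the reference
# flow): q, k and v enter as (batch, seq, heads * dim_head) and the output
# keeps that shape.
#
#   q = k = v = torch.randn(2, 77, 8 * 64)
#   out = attention(q, k, v, heads=8)
#   assert out.shape == (2, 77, 8 * 64)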
|
|
|
|
|
class Mlp(nn.Module): |
|
"""MLP as used in Vision Transformer, MLP-Mixer and related networks""" |
|
|
|
def __init__( |
|
self, |
|
in_features, |
|
hidden_features=None, |
|
out_features=None, |
|
act_layer=nn.GELU, |
|
bias=True, |
|
dtype=None, |
|
device=None, |
|
): |
|
super().__init__() |
|
out_features = out_features or in_features |
|
hidden_features = hidden_features or in_features |
|
|
|
self.fc1 = nn.Linear( |
|
in_features, hidden_features, bias=bias, dtype=dtype, device=device |
|
) |
|
        # act_layer is applied directly in forward (e.g. an entry from the
        # ACTIVATIONS table below), so it must already be a callable rather
        # than an nn.Module class to instantiate.
        self.act = act_layer
|
self.fc2 = nn.Linear( |
|
hidden_features, out_features, bias=bias, dtype=dtype, device=device |
|
) |
|
|
|
def forward(self, x): |
|
x = self.fc1(x) |
|
x = self.act(x) |
|
x = self.fc2(x) |
|
return x |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
class CLIPAttention(torch.nn.Module): |
|
def __init__(self, embed_dim, heads, dtype, device): |
|
super().__init__() |
|
self.heads = heads |
|
self.q_proj = nn.Linear( |
|
embed_dim, embed_dim, bias=True, dtype=dtype, device=device |
|
) |
|
self.k_proj = nn.Linear( |
|
embed_dim, embed_dim, bias=True, dtype=dtype, device=device |
|
) |
|
self.v_proj = nn.Linear( |
|
embed_dim, embed_dim, bias=True, dtype=dtype, device=device |
|
) |
|
self.out_proj = nn.Linear( |
|
embed_dim, embed_dim, bias=True, dtype=dtype, device=device |
|
) |
|
|
|
def forward(self, x, mask=None): |
|
q = self.q_proj(x) |
|
k = self.k_proj(x) |
|
v = self.v_proj(x) |
|
out = attention(q, k, v, self.heads, mask) |
|
return self.out_proj(out) |
|
|
|
|
|
ACTIVATIONS = { |
|
"quick_gelu": lambda a: a * torch.sigmoid(1.702 * a), |
|
"gelu": torch.nn.functional.gelu, |
|
} |
|
|
|
|
|
class CLIPLayer(torch.nn.Module): |
|
def __init__( |
|
self, |
|
embed_dim, |
|
heads, |
|
intermediate_size, |
|
intermediate_activation, |
|
dtype, |
|
device, |
|
): |
|
super().__init__() |
|
self.layer_norm1 = nn.LayerNorm(embed_dim, dtype=dtype, device=device) |
|
self.self_attn = CLIPAttention(embed_dim, heads, dtype, device) |
|
self.layer_norm2 = nn.LayerNorm(embed_dim, dtype=dtype, device=device) |
|
|
|
self.mlp = Mlp( |
|
embed_dim, |
|
intermediate_size, |
|
embed_dim, |
|
act_layer=ACTIVATIONS[intermediate_activation], |
|
dtype=dtype, |
|
device=device, |
|
) |
|
|
|
def forward(self, x, mask=None): |
|
x += self.self_attn(self.layer_norm1(x), mask) |
|
x += self.mlp(self.layer_norm2(x)) |
|
return x |
|
|
|
|
|
class CLIPEncoder(torch.nn.Module): |
|
def __init__( |
|
self, |
|
num_layers, |
|
embed_dim, |
|
heads, |
|
intermediate_size, |
|
intermediate_activation, |
|
dtype, |
|
device, |
|
): |
|
super().__init__() |
|
self.layers = torch.nn.ModuleList( |
|
[ |
|
CLIPLayer( |
|
embed_dim, |
|
heads, |
|
intermediate_size, |
|
intermediate_activation, |
|
dtype, |
|
device, |
|
) |
|
for i in range(num_layers) |
|
] |
|
) |
|
|
|
def forward(self, x, mask=None, intermediate_output=None): |
|
if intermediate_output is not None: |
|
if intermediate_output < 0: |
|
intermediate_output = len(self.layers) + intermediate_output |
|
intermediate = None |
|
for i, l in enumerate(self.layers): |
|
x = l(x, mask) |
|
if i == intermediate_output: |
|
intermediate = x.clone() |
|
return x, intermediate |
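
# Note: intermediate_output selects which layer's hidden states are returned
# alongside the final output; e.g. intermediate_output=-2 yields the
# penultimate layer, which is how the "penultimate" CLIP layer is requested by
# SDXLClipG further below.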
|
|
|
|
|
class CLIPEmbeddings(torch.nn.Module): |
|
def __init__( |
|
self, embed_dim, vocab_size=49408, num_positions=77, dtype=None, device=None |
|
): |
|
super().__init__() |
|
self.token_embedding = torch.nn.Embedding( |
|
vocab_size, embed_dim, dtype=dtype, device=device |
|
) |
|
self.position_embedding = torch.nn.Embedding( |
|
num_positions, embed_dim, dtype=dtype, device=device |
|
) |
|
|
|
    def forward(self, input_tokens):
        # Broadcasting adds the full learned position table, so inputs are
        # assumed to be padded/truncated to exactly num_positions tokens.
        return self.token_embedding(input_tokens) + self.position_embedding.weight
|
|
|
|
|
class CLIPTextModel_(torch.nn.Module): |
|
def __init__(self, config_dict, dtype, device): |
|
num_layers = config_dict["num_hidden_layers"] |
|
embed_dim = config_dict["hidden_size"] |
|
heads = config_dict["num_attention_heads"] |
|
intermediate_size = config_dict["intermediate_size"] |
|
intermediate_activation = config_dict["hidden_act"] |
|
super().__init__() |
|
self.embeddings = CLIPEmbeddings(embed_dim, dtype=torch.float32, device=device) |
|
self.encoder = CLIPEncoder( |
|
num_layers, |
|
embed_dim, |
|
heads, |
|
intermediate_size, |
|
intermediate_activation, |
|
dtype, |
|
device, |
|
) |
|
self.final_layer_norm = nn.LayerNorm(embed_dim, dtype=dtype, device=device) |
|
|
|
def forward( |
|
self, input_tokens, intermediate_output=None, final_layer_norm_intermediate=True |
|
): |
|
x = self.embeddings(input_tokens) |
|
causal_mask = ( |
|
torch.empty(x.shape[1], x.shape[1], dtype=x.dtype, device=x.device) |
|
.fill_(float("-inf")) |
|
.triu_(1) |
|
) |
|
x, i = self.encoder( |
|
x, mask=causal_mask, intermediate_output=intermediate_output |
|
) |
|
x = self.final_layer_norm(x) |
|
if i is not None and final_layer_norm_intermediate: |
|
i = self.final_layer_norm(i) |
|
        # Pool at the end-of-text token; argmax finds it because the EOT id is
        # the largest token id in the CLIP vocabulary.
        pooled_output = x[
            torch.arange(x.shape[0], device=x.device),
            input_tokens.to(dtype=torch.int, device=x.device).argmax(dim=-1),
        ]
|
return x, i, pooled_output |
|
|
|
|
|
class CLIPTextModel(torch.nn.Module): |
|
def __init__(self, config_dict, dtype, device): |
|
super().__init__() |
|
self.num_layers = config_dict["num_hidden_layers"] |
|
self.text_model = CLIPTextModel_(config_dict, dtype, device) |
|
embed_dim = config_dict["hidden_size"] |
|
self.text_projection = nn.Linear( |
|
embed_dim, embed_dim, bias=False, dtype=dtype, device=device |
|
) |
|
        # Initialise the projection to identity; copy_ on a Parameter must run
        # under no_grad to avoid an in-place-op-on-leaf autograd error.
        with torch.no_grad():
            self.text_projection.weight.copy_(torch.eye(embed_dim))
|
self.dtype = dtype |
|
|
|
def get_input_embeddings(self): |
|
return self.text_model.embeddings.token_embedding |
|
|
|
def set_input_embeddings(self, embeddings): |
|
self.text_model.embeddings.token_embedding = embeddings |
|
|
|
def forward(self, *args, **kwargs): |
|
x = self.text_model(*args, **kwargs) |
|
out = self.text_projection(x[2]) |
|
return (x[0], x[1], out, x[2]) |
|
|
|
|
|
def parse_parentheses(string): |
|
result = [] |
|
current_item = "" |
|
nesting_level = 0 |
|
for char in string: |
|
if char == "(": |
|
if nesting_level == 0: |
|
if current_item: |
|
result.append(current_item) |
|
current_item = "(" |
|
else: |
|
current_item = "(" |
|
else: |
|
current_item += char |
|
nesting_level += 1 |
|
elif char == ")": |
|
nesting_level -= 1 |
|
if nesting_level == 0: |
|
result.append(current_item + ")") |
|
current_item = "" |
|
else: |
|
current_item += char |
|
else: |
|
current_item += char |
|
if current_item: |
|
result.append(current_item) |
|
return result |
|
|
|
|
|
def token_weights(string, current_weight): |
|
a = parse_parentheses(string) |
|
out = [] |
|
for x in a: |
|
weight = current_weight |
|
if len(x) >= 2 and x[-1] == ")" and x[0] == "(": |
|
x = x[1:-1] |
|
xx = x.rfind(":") |
|
weight *= 1.1 |
|
if xx > 0: |
|
try: |
|
weight = float(x[xx + 1 :]) |
|
x = x[:xx] |
|
                except ValueError:
                    pass
|
out += token_weights(x, weight) |
|
else: |
|
out += [(x, current_weight)] |
|
return out |
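
# Worked example of the weighting syntax handled above (illustrative):
#
#   token_weights("a (red:1.5) ball", 1.0)
#   -> [("a ", 1.0), ("red", 1.5), (" ball", 1.0)]
#
# A parenthesised segment without an explicit ":weight" is multiplied by 1.1,
# and nested parentheses compound the multiplier.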
|
|
|
|
|
def escape_important(text): |
|
text = text.replace("\\)", "\0\1") |
|
text = text.replace("\\(", "\0\2") |
|
return text |
|
|
|
|
|
def unescape_important(text): |
|
text = text.replace("\0\1", ")") |
|
text = text.replace("\0\2", "(") |
|
return text |
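
# Escaped parentheses are protected from the weighting parser, e.g.
# (illustrative):
#
#   segments = token_weights(escape_important(r"a \(literal\) paren"), 1.0)
#   # -> one segment at weight 1.0 whose unescape_important() text is
#   #    "a (literal) paren"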
|
|
|
|
|
class SDTokenizer: |
|
def __init__( |
|
self, |
|
max_length=77, |
|
pad_with_end=True, |
|
tokenizer=None, |
|
has_start_token=True, |
|
pad_to_max_length=True, |
|
min_length=None, |
|
extra_padding_token=None, |
|
): |
|
self.tokenizer = tokenizer |
|
self.max_length = max_length |
|
self.min_length = min_length |
|
|
|
empty = self.tokenizer("")["input_ids"] |
|
if has_start_token: |
|
self.tokens_start = 1 |
|
self.start_token = empty[0] |
|
self.end_token = empty[1] |
|
else: |
|
self.tokens_start = 0 |
|
self.start_token = None |
|
self.end_token = empty[0] |
|
self.pad_with_end = pad_with_end |
|
self.pad_to_max_length = pad_to_max_length |
|
self.extra_padding_token = extra_padding_token |
|
|
|
vocab = self.tokenizer.get_vocab() |
|
self.inv_vocab = {v: k for k, v in vocab.items()} |
|
self.max_word_length = 8 |
|
|
|
def tokenize_with_weights(self, text: str, return_word_ids=False): |
|
""" |
|
Tokenize the text, with weight values - presume 1.0 for all and ignore other features here. |
|
        The details aren't relevant for a reference impl, and the weights themselves have only a weak effect on SD3.
|
""" |
|
if self.pad_with_end: |
|
pad_token = self.end_token |
|
else: |
|
pad_token = 0 |
|
|
|
text = escape_important(text) |
|
parsed_weights = token_weights(text, 1.0) |
|
|
|
|
|
tokens = [] |
|
for weighted_segment, weight in parsed_weights: |
|
to_tokenize = ( |
|
unescape_important(weighted_segment).replace("\n", " ").split(" ") |
|
) |
|
to_tokenize = [x for x in to_tokenize if x != ""] |
|
for word in to_tokenize: |
|
|
|
tokens.append( |
|
[ |
|
(t, weight) |
|
for t in self.tokenizer(word)["input_ids"][ |
|
self.tokens_start : -1 |
|
] |
|
] |
|
) |
|
|
|
|
|
batched_tokens = [] |
|
batch = [] |
|
if self.start_token is not None: |
|
batch.append((self.start_token, 1.0, 0)) |
|
batched_tokens.append(batch) |
|
for i, t_group in enumerate(tokens): |
|
|
|
            # A "large" word is one whose token run reaches max_word_length;
            # such words may be split across chunk boundaries instead of being
            # pushed whole into the next chunk.
            is_large = len(t_group) >= self.max_word_length
|
|
|
while len(t_group) > 0: |
|
                # The token group doesn't fit in the current chunk (one slot is
                # reserved for the end token).
                if len(t_group) + len(batch) > self.max_length - 1:
                    remaining_length = self.max_length - len(batch) - 1
|
|
|
if is_large: |
|
batch.extend( |
|
[(t, w, i + 1) for t, w in t_group[:remaining_length]] |
|
) |
|
batch.append((self.end_token, 1.0, 0)) |
|
t_group = t_group[remaining_length:] |
|
|
|
else: |
|
batch.append((self.end_token, 1.0, 0)) |
|
if self.pad_to_max_length: |
|
batch.extend([(pad_token, 1.0, 0)] * (remaining_length)) |
|
|
|
batch = [] |
|
if self.start_token is not None: |
|
batch.append((self.start_token, 1.0, 0)) |
|
batched_tokens.append(batch) |
|
else: |
|
batch.extend([(t, w, i + 1) for t, w in t_group]) |
|
t_group = [] |
|
|
|
|
|
        # Optionally pad with a dedicated padding token before the end token
        # (assumes min_length is set whenever extra_padding_token is used).
        if self.extra_padding_token is not None:
            batch.extend(
                [(self.extra_padding_token, 1.0, 0)]
                * (self.min_length - len(batch) - 1)
            )
|
|
|
batch.append((self.end_token, 1.0, 0)) |
|
if self.pad_to_max_length: |
|
batch.extend([(pad_token, 1.0, 0)] * (self.max_length - len(batch))) |
|
if self.min_length is not None and len(batch) < self.min_length: |
|
batch.extend([(pad_token, 1.0, 0)] * (self.min_length - len(batch))) |
|
|
|
if not return_word_ids: |
|
batched_tokens = [[(t, w) for t, w, _ in x] for x in batched_tokens] |
|
|
|
return batched_tokens |
|
|
|
def untokenize(self, token_weight_pair): |
|
return list(map(lambda a: (a, self.inv_vocab[a[0]]), token_weight_pair)) |
|
|
|
|
|
class SDXLClipGTokenizer(SDTokenizer): |
|
def __init__(self, tokenizer): |
|
super().__init__(pad_with_end=False, tokenizer=tokenizer) |
|
|
|
|
|
class SD3Tokenizer: |
|
def __init__(self): |
|
clip_tokenizer = CLIPTokenizer.from_pretrained("openai/clip-vit-large-patch14") |
|
self.clip_l = SDTokenizer(tokenizer=clip_tokenizer) |
|
self.clip_g = SDXLClipGTokenizer(clip_tokenizer) |
|
self.t5xxl = T5XXLTokenizer() |
|
|
|
def tokenize_with_weights(self, text: str): |
|
out = {} |
|
out["l"] = self.clip_l.tokenize_with_weights(text) |
|
out["g"] = self.clip_g.tokenize_with_weights(text) |
|
out["t5xxl"] = self.t5xxl.tokenize_with_weights(text[:226]) |
|
return out |
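
# Minimal usage sketch (illustrative; downloads the pretrained tokenizers on
# first use):
#
#   tokenizer = SD3Tokenizer()
#   tokens = tokenizer.tokenize_with_weights("a photo of a cat")
#   # tokens["l"] and tokens["g"] hold one 77-entry chunk of (token, weight)
#   # pairs each; tokens["t5xxl"] is padded to at least 77 tokens but has no
#   # fixed upper length.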
|
|
|
|
|
class ClipTokenWeightEncoder: |
|
    def encode_token_weights(self, token_weight_pairs):
        # Only the first chunk is encoded and the per-token weights are
        # ignored, which is enough for this reference implementation.
        tokens = list(map(lambda a: a[0], token_weight_pairs[0]))
|
out, pooled = self([tokens]) |
|
if pooled is not None: |
|
first_pooled = pooled[0:1].cpu() |
|
else: |
|
first_pooled = pooled |
|
output = [out[0:1]] |
|
return torch.cat(output, dim=-2).cpu(), first_pooled |
|
|
|
|
|
class SDClipModel(torch.nn.Module, ClipTokenWeightEncoder): |
|
"""Uses the CLIP transformer encoder for text (from huggingface)""" |
|
|
|
LAYERS = ["last", "pooled", "hidden"] |
|
|
|
def __init__( |
|
self, |
|
device="cpu", |
|
max_length=77, |
|
layer="last", |
|
layer_idx=None, |
|
textmodel_json_config=None, |
|
dtype=None, |
|
model_class=CLIPTextModel, |
|
special_tokens={"start": 49406, "end": 49407, "pad": 49407}, |
|
layer_norm_hidden_state=True, |
|
return_projected_pooled=True, |
|
): |
|
super().__init__() |
|
assert layer in self.LAYERS |
|
self.transformer = model_class(textmodel_json_config, dtype, device) |
|
self.num_layers = self.transformer.num_layers |
|
self.max_length = max_length |
|
self.transformer = self.transformer.eval() |
|
for param in self.parameters(): |
|
param.requires_grad = False |
|
self.layer = layer |
|
self.layer_idx = None |
|
self.special_tokens = special_tokens |
|
self.logit_scale = torch.nn.Parameter(torch.tensor(4.6055)) |
|
self.layer_norm_hidden_state = layer_norm_hidden_state |
|
self.return_projected_pooled = return_projected_pooled |
|
if layer == "hidden": |
|
assert layer_idx is not None |
|
assert abs(layer_idx) < self.num_layers |
|
self.set_clip_options({"layer": layer_idx}) |
|
self.options_default = ( |
|
self.layer, |
|
self.layer_idx, |
|
self.return_projected_pooled, |
|
) |
|
|
|
def set_clip_options(self, options): |
|
layer_idx = options.get("layer", self.layer_idx) |
|
self.return_projected_pooled = options.get( |
|
"projected_pooled", self.return_projected_pooled |
|
) |
|
if layer_idx is None or abs(layer_idx) > self.num_layers: |
|
self.layer = "last" |
|
else: |
|
self.layer = "hidden" |
|
self.layer_idx = layer_idx |
|
|
|
def forward(self, tokens): |
|
backup_embeds = self.transformer.get_input_embeddings() |
|
device = backup_embeds.weight.device |
|
tokens = torch.LongTensor(tokens).to(device) |
|
outputs = self.transformer( |
|
tokens, |
|
intermediate_output=self.layer_idx, |
|
final_layer_norm_intermediate=self.layer_norm_hidden_state, |
|
) |
|
self.transformer.set_input_embeddings(backup_embeds) |
|
if self.layer == "last": |
|
z = outputs[0] |
|
else: |
|
z = outputs[1] |
|
pooled_output = None |
|
if len(outputs) >= 3: |
|
if ( |
|
not self.return_projected_pooled |
|
and len(outputs) >= 4 |
|
and outputs[3] is not None |
|
): |
|
pooled_output = outputs[3].float() |
|
elif outputs[2] is not None: |
|
pooled_output = outputs[2].float() |
|
return z.float(), pooled_output |
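
# Encoding sketch (illustrative; `clip_l_config` stands in for the CLIP-L
# config dict shipped with the checkpoint and is not defined in this file):
#
#   model = SDClipModel(
#       layer="hidden", layer_idx=-2, textmodel_json_config=clip_l_config
#   )
#   pairs = SD3Tokenizer().tokenize_with_weights("a photo of a cat")["l"]
#   hidden, pooled = model.encode_token_weights(pairs)
#   # hidden: (1, 77, hidden_size); pooled: (1, hidden_size)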
|
|
|
|
|
class SDXLClipG(SDClipModel): |
|
"""Wraps the CLIP-G model into the SD-CLIP-Model interface""" |
|
|
|
def __init__( |
|
self, config, device="cpu", layer="penultimate", layer_idx=None, dtype=None |
|
): |
|
if layer == "penultimate": |
|
layer = "hidden" |
|
layer_idx = -2 |
|
super().__init__( |
|
device=device, |
|
layer=layer, |
|
layer_idx=layer_idx, |
|
textmodel_json_config=config, |
|
dtype=dtype, |
|
special_tokens={"start": 49406, "end": 49407, "pad": 0}, |
|
layer_norm_hidden_state=False, |
|
) |
|
|
|
|
|
class T5XXLModel(SDClipModel): |
|
"""Wraps the T5-XXL model into the SD-CLIP-Model interface for convenience""" |
|
|
|
def __init__(self, config, device="cpu", layer="last", layer_idx=None, dtype=None): |
|
super().__init__( |
|
device=device, |
|
layer=layer, |
|
layer_idx=layer_idx, |
|
textmodel_json_config=config, |
|
dtype=dtype, |
|
special_tokens={"end": 1, "pad": 0}, |
|
model_class=T5, |
|
) |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
class T5XXLTokenizer(SDTokenizer): |
|
"""Wraps the T5 Tokenizer from HF into the SDTokenizer interface""" |
|
|
|
def __init__(self): |
|
super().__init__( |
|
pad_with_end=False, |
|
tokenizer=T5TokenizerFast.from_pretrained("google/t5-v1_1-xxl"), |
|
has_start_token=False, |
|
pad_to_max_length=False, |
|
max_length=99999999, |
|
min_length=77, |
|
) |
|
|
|
|
|
class T5LayerNorm(torch.nn.Module): |
|
def __init__(self, hidden_size, eps=1e-6, dtype=None, device=None): |
|
super().__init__() |
|
self.weight = torch.nn.Parameter( |
|
torch.ones(hidden_size, dtype=dtype, device=device) |
|
) |
|
self.variance_epsilon = eps |
|
|
|
    def forward(self, x):
        # T5 uses RMSNorm: scale by the root-mean-square only, with no mean
        # subtraction and no bias.
        variance = x.pow(2).mean(-1, keepdim=True)
        x = x * torch.rsqrt(variance + self.variance_epsilon)
        return self.weight.to(device=x.device, dtype=x.dtype) * x
|
|
|
|
|
class T5DenseGatedActDense(torch.nn.Module): |
|
def __init__(self, model_dim, ff_dim, dtype, device): |
|
super().__init__() |
|
self.wi_0 = nn.Linear(model_dim, ff_dim, bias=False, dtype=dtype, device=device) |
|
self.wi_1 = nn.Linear(model_dim, ff_dim, bias=False, dtype=dtype, device=device) |
|
self.wo = nn.Linear(ff_dim, model_dim, bias=False, dtype=dtype, device=device) |
|
|
|
def forward(self, x): |
|
hidden_gelu = torch.nn.functional.gelu(self.wi_0(x), approximate="tanh") |
|
hidden_linear = self.wi_1(x) |
|
x = hidden_gelu * hidden_linear |
|
x = self.wo(x) |
|
return x |
|
|
|
|
|
class T5LayerFF(torch.nn.Module): |
|
def __init__(self, model_dim, ff_dim, dtype, device): |
|
super().__init__() |
|
self.DenseReluDense = T5DenseGatedActDense(model_dim, ff_dim, dtype, device) |
|
self.layer_norm = T5LayerNorm(model_dim, dtype=dtype, device=device) |
|
|
|
def forward(self, x): |
|
forwarded_states = self.layer_norm(x) |
|
forwarded_states = self.DenseReluDense(forwarded_states) |
|
x += forwarded_states |
|
return x |
|
|
|
|
|
class T5Attention(torch.nn.Module): |
|
def __init__( |
|
self, model_dim, inner_dim, num_heads, relative_attention_bias, dtype, device |
|
): |
|
super().__init__() |
|
|
|
self.q = nn.Linear(model_dim, inner_dim, bias=False, dtype=dtype, device=device) |
|
self.k = nn.Linear(model_dim, inner_dim, bias=False, dtype=dtype, device=device) |
|
self.v = nn.Linear(model_dim, inner_dim, bias=False, dtype=dtype, device=device) |
|
self.o = nn.Linear(inner_dim, model_dim, bias=False, dtype=dtype, device=device) |
|
self.num_heads = num_heads |
|
self.relative_attention_bias = None |
|
if relative_attention_bias: |
|
self.relative_attention_num_buckets = 32 |
|
self.relative_attention_max_distance = 128 |
|
self.relative_attention_bias = torch.nn.Embedding( |
|
self.relative_attention_num_buckets, self.num_heads, device=device |
|
) |
|
|
|
@staticmethod |
|
def _relative_position_bucket( |
|
relative_position, bidirectional=True, num_buckets=32, max_distance=128 |
|
): |
|
""" |
|
Adapted from Mesh Tensorflow: |
|
https://github.com/tensorflow/mesh/blob/0cb87fe07da627bf0b7e60475d59f95ed6b5be3d/mesh_tensorflow/transformer/transformer_layers.py#L593 |
|
|
|
Translate relative position to a bucket number for relative attention. The relative position is defined as |
|
memory_position - query_position, i.e. the distance in tokens from the attending position to the attended-to |
|
position. If bidirectional=False, then positive relative positions are invalid. We use smaller buckets for |
|
small absolute relative_position and larger buckets for larger absolute relative_positions. All relative |
|
positions >=max_distance map to the same bucket. All relative positions <=-max_distance map to the same bucket. |
|
This should allow for more graceful generalization to longer sequences than the model has been trained on |
|
|
|
Args: |
|
relative_position: an int32 Tensor |
|
bidirectional: a boolean - whether the attention is bidirectional |
|
num_buckets: an integer |
|
max_distance: an integer |
|
|
|
Returns: |
|
a Tensor with the same shape as relative_position, containing int32 values in the range [0, num_buckets) |
|
""" |
|
relative_buckets = 0 |
|
if bidirectional: |
|
num_buckets //= 2 |
|
relative_buckets += (relative_position > 0).to(torch.long) * num_buckets |
|
relative_position = torch.abs(relative_position) |
|
else: |
|
relative_position = -torch.min( |
|
relative_position, torch.zeros_like(relative_position) |
|
) |
|
|
|
|
|
max_exact = num_buckets // 2 |
|
is_small = relative_position < max_exact |
|
|
|
relative_position_if_large = max_exact + ( |
|
torch.log(relative_position.float() / max_exact) |
|
/ math.log(max_distance / max_exact) |
|
* (num_buckets - max_exact) |
|
).to(torch.long) |
|
relative_position_if_large = torch.min( |
|
relative_position_if_large, |
|
torch.full_like(relative_position_if_large, num_buckets - 1), |
|
) |
|
relative_buckets += torch.where( |
|
is_small, relative_position, relative_position_if_large |
|
) |
|
return relative_buckets |
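
    # Illustrative bucket layout for the defaults above (bidirectional=True,
    # num_buckets=32, max_distance=128):
    #   distance 0..7    -> buckets 0..7 (exact)
    #   distance 8..127  -> buckets 8..15 (log-spaced)
    #   distance >= 128  -> bucket 15
    # Positive relative positions (attending forward) use the mirrored range
    # 16..31.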
|
|
|
def compute_bias(self, query_length, key_length, device): |
|
"""Compute binned relative position bias""" |
|
context_position = torch.arange(query_length, dtype=torch.long, device=device)[ |
|
:, None |
|
] |
|
memory_position = torch.arange(key_length, dtype=torch.long, device=device)[ |
|
None, : |
|
] |
|
        relative_position = (
            memory_position - context_position
        )  # shape (query_length, key_length)
        relative_position_bucket = self._relative_position_bucket(
            relative_position,
            bidirectional=True,
            num_buckets=self.relative_attention_num_buckets,
            max_distance=self.relative_attention_max_distance,
        )
        values = self.relative_attention_bias(
            relative_position_bucket
        )  # shape (query_length, key_length, num_heads)
        values = values.permute([2, 0, 1]).unsqueeze(
            0
        )  # shape (1, num_heads, query_length, key_length)
|
return values |
|
|
|
def forward(self, x, past_bias=None): |
|
q = self.q(x) |
|
k = self.k(x) |
|
v = self.v(x) |
|
if self.relative_attention_bias is not None: |
|
past_bias = self.compute_bias(x.shape[1], x.shape[1], x.device) |
|
        # Reuse the relative-position bias (if any) as the additive attention
        # mask; assigning unconditionally keeps `mask` defined even when
        # neither this layer nor the caller supplies a bias.
        mask = past_bias
        # T5 uses unscaled dot-product attention, while scaled_dot_product_attention
        # divides scores by sqrt(dim_head); pre-scaling k by sqrt(dim_head)
        # cancels that factor.
        out = attention(
            q, k * ((k.shape[-1] / self.num_heads) ** 0.5), v, self.num_heads, mask
        )
|
return self.o(out), past_bias |
|
|
|
|
|
class T5LayerSelfAttention(torch.nn.Module): |
|
def __init__( |
|
self, |
|
model_dim, |
|
inner_dim, |
|
ff_dim, |
|
num_heads, |
|
relative_attention_bias, |
|
dtype, |
|
device, |
|
): |
|
super().__init__() |
|
self.SelfAttention = T5Attention( |
|
model_dim, inner_dim, num_heads, relative_attention_bias, dtype, device |
|
) |
|
self.layer_norm = T5LayerNorm(model_dim, dtype=dtype, device=device) |
|
|
|
def forward(self, x, past_bias=None): |
|
output, past_bias = self.SelfAttention(self.layer_norm(x), past_bias=past_bias) |
|
x += output |
|
return x, past_bias |
|
|
|
|
|
class T5Block(torch.nn.Module): |
|
def __init__( |
|
self, |
|
model_dim, |
|
inner_dim, |
|
ff_dim, |
|
num_heads, |
|
relative_attention_bias, |
|
dtype, |
|
device, |
|
): |
|
super().__init__() |
|
self.layer = torch.nn.ModuleList() |
|
self.layer.append( |
|
T5LayerSelfAttention( |
|
model_dim, |
|
inner_dim, |
|
ff_dim, |
|
num_heads, |
|
relative_attention_bias, |
|
dtype, |
|
device, |
|
) |
|
) |
|
self.layer.append(T5LayerFF(model_dim, ff_dim, dtype, device)) |
|
|
|
def forward(self, x, past_bias=None): |
|
x, past_bias = self.layer[0](x, past_bias) |
|
x = self.layer[-1](x) |
|
return x, past_bias |
|
|
|
|
|
class T5Stack(torch.nn.Module): |
|
def __init__( |
|
self, |
|
num_layers, |
|
model_dim, |
|
inner_dim, |
|
ff_dim, |
|
num_heads, |
|
vocab_size, |
|
dtype, |
|
device, |
|
): |
|
super().__init__() |
|
self.embed_tokens = torch.nn.Embedding(vocab_size, model_dim, device=device) |
|
self.block = torch.nn.ModuleList( |
|
[ |
|
T5Block( |
|
model_dim, |
|
inner_dim, |
|
ff_dim, |
|
num_heads, |
|
relative_attention_bias=(i == 0), |
|
dtype=dtype, |
|
device=device, |
|
) |
|
for i in range(num_layers) |
|
] |
|
) |
|
self.final_layer_norm = T5LayerNorm(model_dim, dtype=dtype, device=device) |
|
|
|
def forward( |
|
self, input_ids, intermediate_output=None, final_layer_norm_intermediate=True |
|
): |
|
intermediate = None |
|
x = self.embed_tokens(input_ids) |
|
past_bias = None |
|
for i, l in enumerate(self.block): |
|
x, past_bias = l(x, past_bias) |
|
if i == intermediate_output: |
|
intermediate = x.clone() |
|
x = self.final_layer_norm(x) |
|
if intermediate is not None and final_layer_norm_intermediate: |
|
intermediate = self.final_layer_norm(intermediate) |
|
return x, intermediate |
|
|
|
|
|
class T5(torch.nn.Module): |
|
def __init__(self, config_dict, dtype, device): |
|
super().__init__() |
|
self.num_layers = config_dict["num_layers"] |
|
self.encoder = T5Stack( |
|
self.num_layers, |
|
config_dict["d_model"], |
|
config_dict["d_model"], |
|
config_dict["d_ff"], |
|
config_dict["num_heads"], |
|
config_dict["vocab_size"], |
|
dtype, |
|
device, |
|
) |
|
self.dtype = dtype |
|
|
|
def get_input_embeddings(self): |
|
return self.encoder.embed_tokens |
|
|
|
def set_input_embeddings(self, embeddings): |
|
self.encoder.embed_tokens = embeddings |
|
|
|
def forward(self, *args, **kwargs): |
|
return self.encoder(*args, **kwargs) |
|
|