Vikram Voleti
SD3.5 Large, SD3.5 Large Turbo
19bf11c
### This file contains impls for underlying related models (CLIP, T5, etc)
import logging
import math
import os
import torch
from torch import nn
from transformers import CLIPTokenizer, T5TokenizerFast
#################################################################################################
### Core/Utility
#################################################################################################
def attention(q, k, v, heads, mask=None):
    """Run multi-head scaled dot-product attention on packed projections.

    ``q``, ``k`` and ``v`` carry all heads packed into their last dimension.
    That dimension is split evenly across ``heads``, attention is computed
    with ``torch.nn.functional.scaled_dot_product_attention`` (no dropout,
    not causal), and the heads are packed back together.

    Args:
        q, k, v: 3-D tensors whose last dimension is ``heads * dim_head``.
        heads: number of attention heads.
        mask: optional value forwarded unchanged as ``attn_mask``.

    Returns:
        Tensor of shape ``(batch, -1, heads * dim_head)``.
    """
    batch, _, packed_dim = q.shape
    head_dim = packed_dim // heads

    def split_heads(t):
        # (batch, seq, heads * head_dim) -> (batch, heads, seq, head_dim)
        return t.view(batch, -1, heads, head_dim).transpose(1, 2)

    attended = torch.nn.functional.scaled_dot_product_attention(
        split_heads(q),
        split_heads(k),
        split_heads(v),
        attn_mask=mask,
        dropout_p=0.0,
        is_causal=False,
    )
    return attended.transpose(1, 2).reshape(batch, -1, heads * head_dim)
class Mlp(nn.Module):
    """MLP as used in Vision Transformer, MLP-Mixer and related networks.

    Computes ``fc2(act(fc1(x)))``.

    Args:
        in_features: size of the last dimension of the input.
        hidden_features: width of the hidden layer; falls back to
            ``in_features`` when falsy.
        out_features: size of the last dimension of the output; falls back to
            ``in_features`` when falsy.
        act_layer: either an ``nn.Module`` subclass (instantiated with no
            arguments) or an already-callable activation such as a function.
        bias: whether both linear layers carry a bias term.
        dtype, device: forwarded to both ``nn.Linear`` layers.
    """

    def __init__(
        self,
        in_features,
        hidden_features=None,
        out_features=None,
        act_layer=nn.GELU,
        bias=True,
        dtype=None,
        device=None,
    ):
        super().__init__()
        out_features = out_features or in_features
        hidden_features = hidden_features or in_features
        self.fc1 = nn.Linear(
            in_features, hidden_features, bias=bias, dtype=dtype, device=device
        )
        # The default ``nn.GELU`` is a class, not a callable activation.
        # Calling the class on a tensor would construct a module rather than
        # apply it, so module classes are instantiated here. Plain callables
        # (functions, lambdas, module instances) are stored untouched.
        if isinstance(act_layer, type) and issubclass(act_layer, nn.Module):
            act_layer = act_layer()
        self.act = act_layer
        self.fc2 = nn.Linear(
            hidden_features, out_features, bias=bias, dtype=dtype, device=device
        )

    def forward(self, x):
        """Apply the first linear layer, the activation, then the second linear layer."""
        x = self.fc1(x)
        x = self.act(x)
        x = self.fc2(x)
        return x
#################################################################################################
### CLIP
#################################################################################################
class CLIPAttention(torch.nn.Module):
    """Multi-head self-attention with biased q/k/v/out linear projections."""

    def __init__(self, embed_dim, heads, dtype, device):
        super().__init__()
        self.heads = heads
        # All four projections are square (embed_dim -> embed_dim) and share
        # the same construction arguments; they are created in q, k, v, out
        # order.
        for name in ("q_proj", "k_proj", "v_proj", "out_proj"):
            projection = nn.Linear(
                embed_dim, embed_dim, bias=True, dtype=dtype, device=device
            )
            setattr(self, name, projection)

    def forward(self, x, mask=None):
        """Project ``x`` to queries/keys/values, attend, and project the result."""
        queries = self.q_proj(x)
        keys = self.k_proj(x)
        values = self.v_proj(x)
        attended = attention(queries, keys, values, self.heads, mask)
        return self.out_proj(attended)
def _quick_gelu(a):
    """Sigmoid-based GELU approximation: ``a * sigmoid(1.702 * a)``."""
    return a * torch.sigmoid(1.702 * a)


# Maps an activation name to the callable that implements it.
ACTIVATIONS = {
    "quick_gelu": _quick_gelu,
    "gelu": torch.nn.functional.gelu,
}
class CLIPLayer(torch.nn.Module):
    """Pre-norm transformer layer: self-attention then MLP, each added residually.

    ``intermediate_activation`` is a key of ``ACTIVATIONS`` selecting the MLP
    activation.
    """

    def __init__(
        self,
        embed_dim,
        heads,
        intermediate_size,
        intermediate_activation,
        dtype,
        device,
    ):
        super().__init__()
        factory = dict(dtype=dtype, device=device)
        self.layer_norm1 = nn.LayerNorm(embed_dim, **factory)
        self.self_attn = CLIPAttention(embed_dim, heads, dtype, device)
        self.layer_norm2 = nn.LayerNorm(embed_dim, **factory)
        activation = ACTIVATIONS[intermediate_activation]
        self.mlp = Mlp(
            in_features=embed_dim,
            hidden_features=intermediate_size,
            out_features=embed_dim,
            act_layer=activation,
            **factory,
        )

    def forward(self, x, mask=None):
        """Apply the layer to ``x``.

        Both residual additions are performed in place, so the tensor passed
        in is the one that gets modified and returned.
        """
        attn_update = self.self_attn(self.layer_norm1(x), mask)
        x += attn_update
        mlp_update = self.mlp(self.layer_norm2(x))
        x += mlp_update
        return x
class CLIPEncoder(torch.nn.Module):
    """A stack of ``num_layers`` identical ``CLIPLayer`` modules."""

    def __init__(
        self,
        num_layers,
        embed_dim,
        heads,
        intermediate_size,
        intermediate_activation,
        dtype,
        device,
    ):
        super().__init__()
        stack = []
        for _ in range(num_layers):
            stack.append(
                CLIPLayer(
                    embed_dim,
                    heads,
                    intermediate_size,
                    intermediate_activation,
                    dtype,
                    device,
                )
            )
        self.layers = torch.nn.ModuleList(stack)

    def forward(self, x, mask=None, intermediate_output=None):
        """Run every layer and optionally capture one layer's output.

        Args:
            x: input passed to the first layer.
            mask: forwarded to each layer.
            intermediate_output: index of the layer whose output should also
                be returned; a negative index counts from the end of the
                stack. ``None`` captures nothing.

        Returns:
            ``(final_output, captured)`` where ``captured`` is a clone of the
            selected layer's output, or ``None`` if no layer index matched.
        """
        capture_at = intermediate_output
        if capture_at is not None and capture_at < 0:
            capture_at = len(self.layers) + capture_at
        captured = None
        for index, layer in enumerate(self.layers):
            x = layer(x, mask)
            if index == capture_at:
                # Clone so the snapshot is a separate tensor from ``x``,
                # which the remaining layers continue to process.
                captured = x.clone()
        return x, captured
class CLIPEmbeddings(torch.nn.Module):
    """Token embedding table plus a learned position embedding table."""

    def __init__(
        self, embed_dim, vocab_size=49408, num_positions=77, dtype=None, device=None
    ):
        super().__init__()
        factory = dict(dtype=dtype, device=device)
        self.token_embedding = torch.nn.Embedding(vocab_size, embed_dim, **factory)
        self.position_embedding = torch.nn.Embedding(
            num_positions, embed_dim, **factory
        )

    def forward(self, input_tokens):
        """Embed ``input_tokens`` and add the position table.

        The whole position table is added by broadcasting, so the token
        sequence length has to be broadcast-compatible with ``num_positions``.
        """
        token_vectors = self.token_embedding(input_tokens)
        return token_vectors + self.position_embedding.weight
class CLIPTextModel_(torch.nn.Module):
    """CLIP text transformer: embeddings, causal encoder and final layer norm.

    ``config_dict`` must provide ``num_hidden_layers``, ``hidden_size``,
    ``num_attention_heads``, ``intermediate_size`` and ``hidden_act``.
    """

    def __init__(self, config_dict, dtype, device):
        config_keys = (
            "num_hidden_layers",
            "hidden_size",
            "num_attention_heads",
            "intermediate_size",
            "hidden_act",
        )
        num_layers, embed_dim, heads, intermediate_size, activation = (
            config_dict[key] for key in config_keys
        )
        super().__init__()
        # The embeddings are always created in float32, regardless of ``dtype``.
        self.embeddings = CLIPEmbeddings(embed_dim, dtype=torch.float32, device=device)
        self.encoder = CLIPEncoder(
            num_layers,
            embed_dim,
            heads,
            intermediate_size,
            activation,
            dtype,
            device,
        )
        self.final_layer_norm = nn.LayerNorm(embed_dim, dtype=dtype, device=device)

    def forward(
        self, input_tokens, intermediate_output=None, final_layer_norm_intermediate=True
    ):
        """Encode ``input_tokens``.

        Args:
            input_tokens: tensor of token ids.
            intermediate_output: layer index forwarded to the encoder to also
                capture that layer's output.
            final_layer_norm_intermediate: whether the final layer norm is
                applied to the captured intermediate output as well.

        Returns:
            ``(hidden, intermediate, pooled_output)``. ``pooled_output`` is the
            normalised hidden state at the position of the largest token id in
            each row (presumably the end-of-text token; confirm against the
            tokenizer's vocabulary).
        """
        hidden = self.embeddings(input_tokens)
        seq_len = hidden.shape[1]
        # -inf strictly above the diagonal, 0 elsewhere.
        causal_mask = torch.empty(
            seq_len, seq_len, dtype=hidden.dtype, device=hidden.device
        )
        causal_mask = causal_mask.fill_(float("-inf")).triu_(1)
        hidden, intermediate = self.encoder(
            hidden, mask=causal_mask, intermediate_output=intermediate_output
        )
        hidden = self.final_layer_norm(hidden)
        if final_layer_norm_intermediate and intermediate is not None:
            intermediate = self.final_layer_norm(intermediate)
        batch_index = torch.arange(hidden.shape[0], device=hidden.device)
        pooled_position = input_tokens.to(
            dtype=torch.int, device=hidden.device
        ).argmax(dim=-1)
        pooled_output = hidden[batch_index, pooled_position]
        return hidden, intermediate, pooled_output
class CLIPTextModel(torch.nn.Module):
    """CLIP text transformer followed by a bias-free projection of the pooled output.

    Args:
        config_dict: model configuration; ``num_hidden_layers`` and
            ``hidden_size`` are read here, the rest by ``CLIPTextModel_``.
        dtype, device: forwarded to the submodules.
    """

    def __init__(self, config_dict, dtype, device):
        super().__init__()
        self.num_layers = config_dict["num_hidden_layers"]
        self.text_model = CLIPTextModel_(config_dict, dtype, device)
        embed_dim = config_dict["hidden_size"]
        self.text_projection = nn.Linear(
            embed_dim, embed_dim, bias=False, dtype=dtype, device=device
        )
        # Start the projection as the identity. An in-place copy into a leaf
        # parameter that requires grad raises a RuntimeError unless autograd
        # recording is disabled, so do it under no_grad instead of relying on
        # the caller to construct the model inside such a context.
        with torch.no_grad():
            self.text_projection.weight.copy_(torch.eye(embed_dim))
        self.dtype = dtype

    def get_input_embeddings(self):
        """Return the token embedding module of the wrapped text model."""
        return self.text_model.embeddings.token_embedding

    def set_input_embeddings(self, embeddings):
        """Replace the token embedding module of the wrapped text model."""
        self.text_model.embeddings.token_embedding = embeddings

    def forward(self, *args, **kwargs):
        """Run the text model and project its pooled output.

        Returns:
            ``(hidden, intermediate, projected_pooled, pooled)`` where the
            first, second and last entries come straight from the text model.
        """
        x = self.text_model(*args, **kwargs)
        out = self.text_projection(x[2])
        return (x[0], x[1], out, x[2])
def parse_parentheses(string):
    """Split ``string`` into top-level parenthesised groups and the text between them.

    A group runs from a ``(`` seen at nesting depth 0 to the ``)`` that brings
    the depth back to 0 and is emitted with its parentheses; nested
    parentheses stay inside their group. Text outside any group is emitted as
    its own piece. Whatever is still buffered at the end of the input (for
    example an unclosed group) is emitted as the final piece.

    Returns:
        List of non-empty string pieces in input order.
    """
    pieces = []
    buffer = []
    depth = 0
    for ch in string:
        if ch == "(":
            if depth == 0:
                # A new top-level group starts: flush the plain text before it.
                if buffer:
                    pieces.append("".join(buffer))
                buffer = ["("]
            else:
                buffer.append(ch)
            depth += 1
        elif ch == ")":
            depth -= 1
            buffer.append(ch)
            if depth == 0:
                pieces.append("".join(buffer))
                buffer = []
        else:
            buffer.append(ch)
    if buffer:
        pieces.append("".join(buffer))
    return pieces
def token_weights(string, current_weight):
    """Resolve ``(text)`` / ``(text:weight)`` prompt syntax into weighted segments.

    The string is split with ``parse_parentheses``. Text outside parentheses
    keeps ``current_weight``. A parenthesised group has its weight multiplied
    by 1.1, unless it ends in ``:<number>``, in which case that number
    replaces the weight. Groups are processed recursively so nested groups
    compound.

    Args:
        string: prompt text to parse.
        current_weight: weight inherited from the enclosing group.

    Returns:
        List of ``(text, weight)`` tuples in input order.
    """
    out = []
    for segment in parse_parentheses(string):
        is_group = len(segment) >= 2 and segment[0] == "(" and segment[-1] == ")"
        if not is_group:
            out.append((segment, current_weight))
            continue
        inner = segment[1:-1]
        weight = current_weight * 1.1
        colon = inner.rfind(":")
        if colon > 0:
            try:
                weight = float(inner[colon + 1 :])
            except ValueError:
                # The text after the last colon is not a number, so the colon
                # is ordinary prompt text and the 1.1 multiplier stands.
                pass
            else:
                inner = inner[:colon]
        out += token_weights(inner, weight)
    return out
def escape_important(text):
    """Replace backslash-escaped parentheses with two-character placeholders.

    ``\\)`` becomes ``"\\0\\1"`` and ``\\(`` becomes ``"\\0\\2"``, so that a
    later pass looking for literal parentheses does not see them.
    """
    return text.replace("\\)", "\0\1").replace("\\(", "\0\2")
def unescape_important(text):
    """Turn the ``"\\0\\1"`` / ``"\\0\\2"`` placeholders back into ``)`` / ``(``."""
    return text.replace("\0\1", ")").replace("\0\2", "(")
class SDTokenizer:
    """Turn a weighted prompt into batches of ``(token, weight)`` pairs.

    Wraps a tokenizer object that is called as ``tokenizer(text)["input_ids"]``
    and exposes ``get_vocab()``.

    Args:
        max_length: maximum number of entries per batch, special tokens
            included.
        pad_with_end: pad with the end token when true, with token id 0
            otherwise.
        tokenizer: the wrapped tokenizer.
        has_start_token: whether the tokenizer emits a start token before the
            end token.
        pad_to_max_length: pad every batch up to ``max_length``.
        min_length: if set, the last batch is padded to at least this length.
        extra_padding_token: if set together with ``min_length``, inserted
            before the end token of the last batch until the batch, counting
            its end token, reaches ``min_length``.
    """

    def __init__(
        self,
        max_length=77,
        pad_with_end=True,
        tokenizer=None,
        has_start_token=True,
        pad_to_max_length=True,
        min_length=None,
        extra_padding_token=None,
    ):
        self.tokenizer = tokenizer
        self.max_length = max_length
        self.min_length = min_length
        # The special token ids are read from the tokenization of the empty
        # string.
        empty = self.tokenizer("")["input_ids"]
        if has_start_token:
            self.tokens_start = 1
            self.start_token = empty[0]
            self.end_token = empty[1]
        else:
            self.tokens_start = 0
            self.start_token = None
            self.end_token = empty[0]
        self.pad_with_end = pad_with_end
        self.pad_to_max_length = pad_to_max_length
        self.extra_padding_token = extra_padding_token
        vocab = self.tokenizer.get_vocab()
        self.inv_vocab = {v: k for k, v in vocab.items()}
        # Words with at least this many tokens may be split across batches.
        self.max_word_length = 8

    def tokenize_with_weights(self, text: str, return_word_ids=False):
        """Tokenize ``text`` word by word, attaching prompt weights to each token.

        Weights come from the ``(text)`` / ``(text:weight)`` syntax handled by
        ``token_weights``; backslash-escaped parentheses are treated as
        literal characters. Tokens are packed into batches of at most
        ``max_length`` entries, each starting with the start token (when there
        is one) and terminated by the end token. A word that does not fit in
        the current batch is moved whole to the next batch, unless it has at
        least ``max_word_length`` tokens, in which case it is split.

        Args:
            text: the prompt.
            return_word_ids: when true, each entry also carries a word id.

        Returns:
            List of batches. Each batch is a list of ``(token, weight)``
            tuples, or ``(token, weight, word_id)`` when ``return_word_ids``
            is true; ``word_id`` is the 1-based index of the source word and
            0 for special and padding tokens.
        """
        if self.pad_with_end:
            pad_token = self.end_token
        else:
            pad_token = 0
        text = escape_important(text)
        parsed_weights = token_weights(text, 1.0)
        # tokenize words
        tokens = []
        for weighted_segment, weight in parsed_weights:
            to_tokenize = (
                unescape_important(weighted_segment).replace("\n", " ").split(" ")
            )
            to_tokenize = [x for x in to_tokenize if x != ""]
            for word in to_tokenize:
                # Tokenize the word on its own and drop the special tokens
                # the tokenizer wraps around it.
                tokens.append(
                    [
                        (t, weight)
                        for t in self.tokenizer(word)["input_ids"][
                            self.tokens_start : -1
                        ]
                    ]
                )
        # reshape token array to CLIP input size
        batched_tokens = []
        batch = []
        if self.start_token is not None:
            batch.append((self.start_token, 1.0, 0))
        batched_tokens.append(batch)
        for i, t_group in enumerate(tokens):
            # determine if we're going to try and keep the tokens in a single batch
            is_large = len(t_group) >= self.max_word_length
            while len(t_group) > 0:
                if len(t_group) + len(batch) > self.max_length - 1:
                    remaining_length = self.max_length - len(batch) - 1
                    # break word in two and add end token
                    if is_large:
                        batch.extend(
                            [(t, w, i + 1) for t, w in t_group[:remaining_length]]
                        )
                        batch.append((self.end_token, 1.0, 0))
                        t_group = t_group[remaining_length:]
                    # add end token and pad
                    else:
                        batch.append((self.end_token, 1.0, 0))
                        if self.pad_to_max_length:
                            batch.extend([(pad_token, 1.0, 0)] * (remaining_length))
                    # start new batch
                    batch = []
                    if self.start_token is not None:
                        batch.append((self.start_token, 1.0, 0))
                    batched_tokens.append(batch)
                else:
                    batch.extend([(t, w, i + 1) for t, w in t_group])
                    t_group = []
        # Pad with the extra padding token before appending the end token.
        # Without a ``min_length`` there is no target length to pad to, so
        # the step is skipped rather than failing on ``None - int``.
        if self.extra_padding_token is not None and self.min_length is not None:
            batch.extend(
                [(self.extra_padding_token, 1.0, 0)]
                * (self.min_length - len(batch) - 1)
            )
        # fill last batch
        batch.append((self.end_token, 1.0, 0))
        if self.pad_to_max_length:
            batch.extend([(pad_token, 1.0, 0)] * (self.max_length - len(batch)))
        if self.min_length is not None and len(batch) < self.min_length:
            batch.extend([(pad_token, 1.0, 0)] * (self.min_length - len(batch)))
        if not return_word_ids:
            batched_tokens = [[(t, w) for t, w, _ in x] for x in batched_tokens]
        return batched_tokens

    def untokenize(self, token_weight_pair):
        """Pair every entry with the vocabulary string of its token id.

        Args:
            token_weight_pair: iterable of entries whose first element is a
                token id.

        Returns:
            List of ``(entry, vocab_string)`` tuples.
        """
        return [(pair, self.inv_vocab[pair[0]]) for pair in token_weight_pair]
class SDXLClipGTokenizer(SDTokenizer):
    """``SDTokenizer`` configured with ``pad_with_end=False`` (pads with token id 0)."""

    def __init__(self, tokenizer):
        super().__init__(tokenizer=tokenizer, pad_with_end=False)
class SD3Tokenizer:
    """Bundles the three tokenizers used here: CLIP-L, CLIP-G and T5-XXL."""

    def __init__(self):
        # Both CLIP wrappers share one underlying Hugging Face tokenizer.
        shared_clip = CLIPTokenizer.from_pretrained("openai/clip-vit-large-patch14")
        self.clip_l = SDTokenizer(tokenizer=shared_clip)
        self.clip_g = SDXLClipGTokenizer(shared_clip)
        self.t5xxl = T5XXLTokenizer()

    def tokenize_with_weights(self, text: str):
        """Tokenize ``text`` with every tokenizer.

        Returns:
            Dict with keys ``"l"``, ``"g"`` and ``"t5xxl"`` holding each
            tokenizer's output.
        """
        return {
            "l": self.clip_l.tokenize_with_weights(text),
            "g": self.clip_g.tokenize_with_weights(text),
            # The T5 tokenizer only receives the first 226 characters.
            "t5xxl": self.t5xxl.tokenize_with_weights(text[:226]),
        }
class ClipTokenWeightEncoder:
    """Mixin that encodes the first batch of ``(token, weight)`` pairs.

    The class it is mixed into must be callable as ``self([token_ids])`` and
    return ``(output, pooled)``.
    """

    def encode_token_weights(self, token_weight_pairs):
        """Encode ``token_weight_pairs[0]``; the weights themselves are not used.

        Returns:
            ``(output, pooled)``, both moved to the CPU and restricted to the
            first batch entry; ``pooled`` is ``None`` when the model returns
            no pooled output.
        """
        token_ids = [pair[0] for pair in token_weight_pairs[0]]
        out, pooled = self([token_ids])
        first_pooled = None if pooled is None else pooled[0:1].cpu()
        first_out = torch.cat([out[0:1]], dim=-2).cpu()
        return first_out, first_pooled
class SDClipModel(torch.nn.Module, ClipTokenWeightEncoder):
    """Uses the CLIP transformer encoder for text (from huggingface).

    Args:
        device: device the wrapped model is created on.
        max_length: stored on the instance as ``max_length``.
        layer: one of ``LAYERS``; ``"hidden"`` selects an intermediate layer
            and requires ``layer_idx``.
        layer_idx: index of the intermediate layer used when ``layer`` is
            ``"hidden"``.
        textmodel_json_config: config dict passed to ``model_class``.
        dtype: dtype passed to ``model_class``.
        model_class: class constructed as
            ``model_class(config, dtype, device)``.
        special_tokens: dict of special token ids; defaults to
            ``{"start": 49406, "end": 49407, "pad": 49407}``.
        layer_norm_hidden_state: whether the final layer norm is applied to
            the intermediate output.
        return_projected_pooled: whether the projected pooled output is
            preferred over the unprojected one.
    """

    LAYERS = ["last", "pooled", "hidden"]

    def __init__(
        self,
        device="cpu",
        max_length=77,
        layer="last",
        layer_idx=None,
        textmodel_json_config=None,
        dtype=None,
        model_class=CLIPTextModel,
        special_tokens=None,
        layer_norm_hidden_state=True,
        return_projected_pooled=True,
    ):
        super().__init__()
        assert layer in self.LAYERS
        # Build the default per instance; a dict literal in the signature
        # would be one object shared by every instance.
        if special_tokens is None:
            special_tokens = {"start": 49406, "end": 49407, "pad": 49407}
        self.transformer = model_class(textmodel_json_config, dtype, device)
        self.num_layers = self.transformer.num_layers
        self.max_length = max_length
        self.transformer = self.transformer.eval()
        # Freeze every parameter registered so far.
        for param in self.parameters():
            param.requires_grad = False
        self.layer = layer
        self.layer_idx = None
        self.special_tokens = special_tokens
        # Created after the freeze loop above, so it is not frozen by it.
        self.logit_scale = torch.nn.Parameter(torch.tensor(4.6055))
        self.layer_norm_hidden_state = layer_norm_hidden_state
        self.return_projected_pooled = return_projected_pooled
        if layer == "hidden":
            assert layer_idx is not None
            assert abs(layer_idx) < self.num_layers
            self.set_clip_options({"layer": layer_idx})
        self.options_default = (
            self.layer,
            self.layer_idx,
            self.return_projected_pooled,
        )

    def set_clip_options(self, options):
        """Update the output layer selection from an options dict.

        Args:
            options: may contain ``"layer"`` (intermediate layer index) and
                ``"projected_pooled"`` (bool); missing keys keep the current
                values.

        A layer index outside ``[-num_layers, num_layers)`` cannot address
        any encoder layer, so it selects the last layer instead. Accepting
        ``layer_idx == num_layers`` would leave the intermediate output
        ``None`` and break ``forward``.
        """
        layer_idx = options.get("layer", self.layer_idx)
        self.return_projected_pooled = options.get(
            "projected_pooled", self.return_projected_pooled
        )
        addressable = (
            layer_idx is not None
            and -self.num_layers <= layer_idx < self.num_layers
        )
        if addressable:
            self.layer = "hidden"
            self.layer_idx = layer_idx
        else:
            self.layer = "last"

    def forward(self, tokens):
        """Encode batches of token ids.

        Args:
            tokens: nested list of token ids, converted to a ``LongTensor``.

        Returns:
            ``(z, pooled_output)`` as float tensors. ``z`` is the final
            hidden state when ``self.layer == "last"`` and the intermediate
            output otherwise. ``pooled_output`` is ``None`` when the wrapped
            model returns fewer than three outputs or the selected pooled
            entry is ``None``.
        """
        backup_embeds = self.transformer.get_input_embeddings()
        device = backup_embeds.weight.device
        tokens = torch.LongTensor(tokens).to(device)
        outputs = self.transformer(
            tokens,
            intermediate_output=self.layer_idx,
            final_layer_norm_intermediate=self.layer_norm_hidden_state,
        )
        self.transformer.set_input_embeddings(backup_embeds)
        if self.layer == "last":
            z = outputs[0]
        else:
            z = outputs[1]
        pooled_output = None
        if len(outputs) >= 3:
            # outputs[2] is the projected pooled output and outputs[3], when
            # present, the unprojected one.
            if (
                not self.return_projected_pooled
                and len(outputs) >= 4
                and outputs[3] is not None
            ):
                pooled_output = outputs[3].float()
            elif outputs[2] is not None:
                pooled_output = outputs[2].float()
        return z.float(), pooled_output
class SDXLClipG(SDClipModel):
    """Wraps the CLIP-G model into the SD-CLIP-Model interface"""

    def __init__(
        self, config, device="cpu", layer="penultimate", layer_idx=None, dtype=None
    ):
        # "penultimate" is shorthand for the hidden state at layer index -2.
        if layer == "penultimate":
            layer, layer_idx = "hidden", -2
        clip_g_tokens = {"start": 49406, "end": 49407, "pad": 0}
        super().__init__(
            device=device,
            layer=layer,
            layer_idx=layer_idx,
            textmodel_json_config=config,
            dtype=dtype,
            special_tokens=clip_g_tokens,
            layer_norm_hidden_state=False,
        )
class T5XXLModel(SDClipModel):
    """Wraps the T5-XXL model into the SD-CLIP-Model interface for convenience"""

    def __init__(self, config, device="cpu", layer="last", layer_idx=None, dtype=None):
        # T5 has no start token, so only "end" and "pad" are provided.
        t5_tokens = {"end": 1, "pad": 0}
        super().__init__(
            device=device,
            layer=layer,
            layer_idx=layer_idx,
            textmodel_json_config=config,
            dtype=dtype,
            model_class=T5,
            special_tokens=t5_tokens,
        )
#################################################################################################
### T5 implementation, for the T5-XXL text encoder portion, largely pulled from upstream impl
#################################################################################################
class T5XXLTokenizer(SDTokenizer):
    """Wraps the T5 Tokenizer from HF into the SDTokenizer interface"""

    def __init__(self):
        hf_tokenizer = T5TokenizerFast.from_pretrained("google/t5-v1_1-xxl")
        # No start token, no padding up to max_length, and a max_length
        # large enough to be effectively unlimited; the last batch is still
        # padded to at least 77 entries.
        super().__init__(
            tokenizer=hf_tokenizer,
            has_start_token=False,
            pad_with_end=False,
            pad_to_max_length=False,
            max_length=99999999,
            min_length=77,
        )
class T5LayerNorm(torch.nn.Module):
    """Scale-only layer norm: divides by the root mean square, no mean subtraction or bias."""

    def __init__(self, hidden_size, eps=1e-6, dtype=None, device=None):
        super().__init__()
        initial_scale = torch.ones(hidden_size, dtype=dtype, device=device)
        self.weight = torch.nn.Parameter(initial_scale)
        self.variance_epsilon = eps

    def forward(self, x):
        """Normalise ``x`` over its last dimension and apply the learned scale."""
        mean_square = x.pow(2).mean(-1, keepdim=True)
        normalised = x * torch.rsqrt(mean_square + self.variance_epsilon)
        # The scale is moved to the input's device and dtype before use.
        scale = self.weight.to(device=x.device, dtype=x.dtype)
        return scale * normalised
class T5DenseGatedActDense(torch.nn.Module):
    """Gated feed-forward block: ``wo(gelu_tanh(wi_0(x)) * wi_1(x))``, all bias-free."""

    def __init__(self, model_dim, ff_dim, dtype, device):
        super().__init__()
        linear_kwargs = dict(bias=False, dtype=dtype, device=device)
        self.wi_0 = nn.Linear(model_dim, ff_dim, **linear_kwargs)
        self.wi_1 = nn.Linear(model_dim, ff_dim, **linear_kwargs)
        self.wo = nn.Linear(ff_dim, model_dim, **linear_kwargs)

    def forward(self, x):
        """Apply the gated projection to ``x``."""
        gate = torch.nn.functional.gelu(self.wi_0(x), approximate="tanh")
        gated = gate * self.wi_1(x)
        return self.wo(gated)
class T5LayerFF(torch.nn.Module):
    """Pre-norm feed-forward sublayer with a residual connection."""

    def __init__(self, model_dim, ff_dim, dtype, device):
        super().__init__()
        self.DenseReluDense = T5DenseGatedActDense(model_dim, ff_dim, dtype, device)
        self.layer_norm = T5LayerNorm(model_dim, dtype=dtype, device=device)

    def forward(self, x):
        """Add the feed-forward output of the normalised input to ``x``.

        The addition is in place, so the tensor passed in is the one that
        gets modified and returned.
        """
        update = self.DenseReluDense(self.layer_norm(x))
        x += update
        return x
class T5Attention(torch.nn.Module):
    """T5 self-attention with an optional learned relative position bias.

    Args:
        model_dim: input and output feature size.
        inner_dim: total size of the query/key/value projections.
        num_heads: number of attention heads.
        relative_attention_bias: when truthy, this layer owns the bias
            embedding table and computes the bias itself.
        dtype, device: forwarded to the linear projections.
    """

    def __init__(
        self, model_dim, inner_dim, num_heads, relative_attention_bias, dtype, device
    ):
        super().__init__()
        # Mesh TensorFlow initialization to avoid scaling before softmax
        self.q = nn.Linear(model_dim, inner_dim, bias=False, dtype=dtype, device=device)
        self.k = nn.Linear(model_dim, inner_dim, bias=False, dtype=dtype, device=device)
        self.v = nn.Linear(model_dim, inner_dim, bias=False, dtype=dtype, device=device)
        self.o = nn.Linear(inner_dim, model_dim, bias=False, dtype=dtype, device=device)
        self.num_heads = num_heads
        self.relative_attention_bias = None
        if relative_attention_bias:
            self.relative_attention_num_buckets = 32
            self.relative_attention_max_distance = 128
            # NOTE: the table is created without ``dtype``, so it uses the
            # default dtype rather than the one given to the projections.
            self.relative_attention_bias = torch.nn.Embedding(
                self.relative_attention_num_buckets, self.num_heads, device=device
            )

    @staticmethod
    def _relative_position_bucket(
        relative_position, bidirectional=True, num_buckets=32, max_distance=128
    ):
        """
        Adapted from Mesh Tensorflow:
        https://github.com/tensorflow/mesh/blob/0cb87fe07da627bf0b7e60475d59f95ed6b5be3d/mesh_tensorflow/transformer/transformer_layers.py#L593

        Translate relative position to a bucket number for relative attention. The relative position is defined as
        memory_position - query_position, i.e. the distance in tokens from the attending position to the attended-to
        position. If bidirectional=False, then positive relative positions are invalid. We use smaller buckets for
        small absolute relative_position and larger buckets for larger absolute relative_positions. All relative
        positions >=max_distance map to the same bucket. All relative positions <=-max_distance map to the same bucket.
        This should allow for more graceful generalization to longer sequences than the model has been trained on

        Args:
            relative_position: an int32 Tensor
            bidirectional: a boolean - whether the attention is bidirectional
            num_buckets: an integer
            max_distance: an integer

        Returns:
            a Tensor with the same shape as relative_position, containing int32 values in the range [0, num_buckets)
        """
        relative_buckets = 0
        if bidirectional:
            # Half of the buckets are for positive offsets, half for the rest.
            num_buckets //= 2
            relative_buckets += (relative_position > 0).to(torch.long) * num_buckets
            relative_position = torch.abs(relative_position)
        else:
            relative_position = -torch.min(
                relative_position, torch.zeros_like(relative_position)
            )
        # now relative_position is in the range [0, inf)
        # half of the buckets are for exact increments in positions
        max_exact = num_buckets // 2
        is_small = relative_position < max_exact
        # The other half of the buckets are for logarithmically bigger bins in positions up to max_distance
        relative_position_if_large = max_exact + (
            torch.log(relative_position.float() / max_exact)
            / math.log(max_distance / max_exact)
            * (num_buckets - max_exact)
        ).to(torch.long)
        relative_position_if_large = torch.min(
            relative_position_if_large,
            torch.full_like(relative_position_if_large, num_buckets - 1),
        )
        relative_buckets += torch.where(
            is_small, relative_position, relative_position_if_large
        )
        return relative_buckets

    def compute_bias(self, query_length, key_length, device):
        """Compute binned relative position bias.

        Returns:
            Tensor of shape ``(1, num_heads, query_length, key_length)``.
        """
        context_position = torch.arange(query_length, dtype=torch.long, device=device)[
            :, None
        ]
        memory_position = torch.arange(key_length, dtype=torch.long, device=device)[
            None, :
        ]
        relative_position = (
            memory_position - context_position
        )  # shape (query_length, key_length)
        relative_position_bucket = self._relative_position_bucket(
            relative_position,  # shape (query_length, key_length)
            bidirectional=True,
            num_buckets=self.relative_attention_num_buckets,
            max_distance=self.relative_attention_max_distance,
        )
        values = self.relative_attention_bias(
            relative_position_bucket
        )  # shape (query_length, key_length, num_heads)
        values = values.permute([2, 0, 1]).unsqueeze(
            0
        )  # shape (1, num_heads, query_length, key_length)
        return values

    def forward(self, x, past_bias=None):
        """Attend over ``x``.

        Args:
            x: input tensor; its second dimension is used as the sequence
                length for the bias.
            past_bias: bias computed by an earlier layer. It is replaced when
                this layer owns a bias table.

        Returns:
            ``(output, bias)`` where ``bias`` is the bias that was applied,
            or ``None`` when no bias was available.
        """
        q = self.q(x)
        k = self.k(x)
        v = self.v(x)
        if self.relative_attention_bias is not None:
            past_bias = self.compute_bias(x.shape[1], x.shape[1], x.device)
        # The bias, if any, is passed as the attention mask. Binding ``mask``
        # unconditionally keeps a layer with neither its own table nor an
        # inherited bias from failing with an unbound local.
        mask = past_bias
        # Keys are multiplied by sqrt(head_dim) to cancel the 1/sqrt(head_dim)
        # scaling that scaled_dot_product_attention applies by default.
        out = attention(
            q, k * ((k.shape[-1] / self.num_heads) ** 0.5), v, self.num_heads, mask
        )
        return self.o(out), past_bias
class T5LayerSelfAttention(torch.nn.Module):
    """Pre-norm self-attention sublayer with a residual connection.

    ``ff_dim`` is accepted but not used by this sublayer.
    """

    def __init__(
        self,
        model_dim,
        inner_dim,
        ff_dim,
        num_heads,
        relative_attention_bias,
        dtype,
        device,
    ):
        super().__init__()
        self.SelfAttention = T5Attention(
            model_dim, inner_dim, num_heads, relative_attention_bias, dtype, device
        )
        self.layer_norm = T5LayerNorm(model_dim, dtype=dtype, device=device)

    def forward(self, x, past_bias=None):
        """Add the attention output of the normalised input to ``x`` in place.

        Returns:
            ``(x, bias)`` where ``bias`` is the one returned by the attention
            module.
        """
        normalised = self.layer_norm(x)
        attn_out, bias = self.SelfAttention(normalised, past_bias=past_bias)
        x += attn_out
        return x, bias
class T5Block(torch.nn.Module):
    """One T5 encoder block: a self-attention sublayer followed by a feed-forward sublayer."""

    def __init__(
        self,
        model_dim,
        inner_dim,
        ff_dim,
        num_heads,
        relative_attention_bias,
        dtype,
        device,
    ):
        super().__init__()
        self_attention = T5LayerSelfAttention(
            model_dim,
            inner_dim,
            ff_dim,
            num_heads,
            relative_attention_bias,
            dtype,
            device,
        )
        feed_forward = T5LayerFF(model_dim, ff_dim, dtype, device)
        # Index 0 is the attention sublayer, index 1 the feed-forward one.
        self.layer = torch.nn.ModuleList([self_attention, feed_forward])

    def forward(self, x, past_bias=None):
        """Run both sublayers and return ``(x, bias)`` from the attention sublayer."""
        attention_sublayer = self.layer[0]
        feed_forward_sublayer = self.layer[-1]
        x, bias = attention_sublayer(x, past_bias)
        return feed_forward_sublayer(x), bias
class T5Stack(torch.nn.Module):
    """Token embedding, a stack of ``T5Block`` modules and a final layer norm.

    Only the first block owns a relative attention bias table; the bias it
    returns is handed to every later block.

    Args:
        num_layers: number of blocks.
        model_dim: embedding and hidden size.
        inner_dim: total attention projection size.
        ff_dim: feed-forward hidden size.
        num_heads: number of attention heads.
        vocab_size: number of rows in the token embedding table.
        dtype, device: forwarded to the blocks and the final norm.
    """

    def __init__(
        self,
        num_layers,
        model_dim,
        inner_dim,
        ff_dim,
        num_heads,
        vocab_size,
        dtype,
        device,
    ):
        super().__init__()
        # NOTE: the embedding table is created without ``dtype``.
        self.embed_tokens = torch.nn.Embedding(vocab_size, model_dim, device=device)
        self.block = torch.nn.ModuleList(
            [
                T5Block(
                    model_dim,
                    inner_dim,
                    ff_dim,
                    num_heads,
                    relative_attention_bias=(i == 0),
                    dtype=dtype,
                    device=device,
                )
                for i in range(num_layers)
            ]
        )
        self.final_layer_norm = T5LayerNorm(model_dim, dtype=dtype, device=device)

    def forward(
        self, input_ids, intermediate_output=None, final_layer_norm_intermediate=True
    ):
        """Encode ``input_ids``.

        Args:
            input_ids: tensor of token ids.
            intermediate_output: index of the block whose output should also
                be returned; a negative index counts from the end of the
                stack, matching ``CLIPEncoder``. ``None`` captures nothing.
            final_layer_norm_intermediate: whether the final layer norm is
                applied to the captured output as well.

        Returns:
            ``(x, intermediate)`` where ``intermediate`` is ``None`` if no
            block index matched.
        """
        # Without this normalisation a negative index could never equal a
        # loop index, and the intermediate output would silently stay None.
        if intermediate_output is not None and intermediate_output < 0:
            intermediate_output = len(self.block) + intermediate_output
        intermediate = None
        x = self.embed_tokens(input_ids)
        past_bias = None
        for i, l in enumerate(self.block):
            x, past_bias = l(x, past_bias)
            if i == intermediate_output:
                # Clone so the snapshot is a separate tensor from ``x``,
                # which the remaining blocks continue to process.
                intermediate = x.clone()
        x = self.final_layer_norm(x)
        if intermediate is not None and final_layer_norm_intermediate:
            intermediate = self.final_layer_norm(intermediate)
        return x, intermediate
class T5(torch.nn.Module):
    """T5 encoder built from a config dict.

    Reads ``num_layers``, ``d_model``, ``d_ff``, ``num_heads`` and
    ``vocab_size``; ``d_model`` is used both as the model dimension and as
    the attention inner dimension.
    """

    def __init__(self, config_dict, dtype, device):
        super().__init__()
        self.num_layers = config_dict["num_layers"]
        model_dim = config_dict["d_model"]
        self.encoder = T5Stack(
            self.num_layers,
            model_dim,
            model_dim,
            config_dict["d_ff"],
            config_dict["num_heads"],
            config_dict["vocab_size"],
            dtype,
            device,
        )
        self.dtype = dtype

    def get_input_embeddings(self):
        """Return the encoder's token embedding module."""
        return self.encoder.embed_tokens

    def set_input_embeddings(self, embeddings):
        """Replace the encoder's token embedding module."""
        self.encoder.embed_tokens = embeddings

    def forward(self, *args, **kwargs):
        """Delegate to the encoder and return its result unchanged."""
        return self.encoder(*args, **kwargs)