|
|
|
from math import floor, ceil |
|
import torch |
|
from torch import nn |
|
import torch.nn.functional as F |
|
from axial_positional_embedding import AxialPositionalEmbedding |
|
from einops import rearrange |
|
from celle.utils import ( |
|
exists, |
|
always, |
|
eval_decorator, |
|
gumbel_sample, |
|
top_k, |
|
gamma_func, |
|
DivideMax, |
|
) |
|
from tqdm import tqdm |
|
|
|
|
|
from celle.transformer import Transformer |
|
|
|
|
|
def generate_mask(gamma_func, batch_size, length, device):
    """Draw a random boolean mask with the same number of ``True`` cells per row.

    A single ratio is drawn for the whole batch: ``gamma_func`` is applied to
    one ``torch.rand(1)`` sample, and the result scaled by ``length`` and
    floored gives the per-row count of ``True`` cells. Which cells are set is
    decided independently for every row.

    Args:
        gamma_func: Callable applied to the random draw; its result is
            multiplied by ``length``.
        batch_size: Number of rows of the mask.
        length: Number of columns of the mask.
        device: Device on which the noise and the mask are allocated.

    Returns:
        Bool tensor of shape ``(batch_size, length)``.
    """
    ratio = gamma_func(torch.rand(1))
    true_per_row = floor(ratio * length)

    # Keeping the top-ranked cells of per-row random noise selects a random
    # subset of columns of the requested size in each row.
    noise = torch.rand((batch_size, length), device=device)
    _, chosen_columns = torch.topk(noise, true_per_row, dim=1)

    mask = torch.zeros((batch_size, length), dtype=torch.bool, device=device)
    mask.scatter_(dim=1, index=chosen_columns, value=True)
    return mask
|
|
|
|
|
def match_batch_size(text, condition, image, batch_size):
    """Tile any input whose leading dimension differs from ``batch_size``.

    Each tensor whose ``shape[0]`` is not ``batch_size`` is replaced by
    ``tensor.repeat(batch_size, 1)``; tensors that already match are returned
    untouched. Because ``repeat`` multiplies the existing leading dimension,
    the result only has ``batch_size`` rows when the input had a single row.

    Args:
        text: Text tensor.
        condition: Condition tensor.
        image: Image tensor.
        batch_size: Target size of the leading dimension.

    Returns:
        Tuple ``(text, condition, image)`` after tiling.
    """

    def _tile(tensor):
        if tensor.shape[0] == batch_size:
            return tensor
        return tensor.repeat(batch_size, 1)

    return _tile(text), _tile(condition), _tile(image)
|
|
|
|
|
def calc_unmask_probs(timestep, timesteps, gamma_func):
    """Return the fraction of masked tokens to reveal at ``timestep``.

    Args:
        timestep: Current position in the schedule.
        timesteps: Total number of steps in the schedule.
        gamma_func: Callable evaluated at ``timestep``.

    Returns:
        ``1`` when ``timestep`` equals 1 or the schedule has a single step,
        otherwise ``1 - gamma_func(timestep)``.
    """
    reveal_everything = timestep == 1 or timesteps == 1
    return 1 if reveal_everything else 1 - gamma_func(timestep)
|
|
|
|
|
def calculate_logits(
    input_tokens, input_mask, logits_function, filter_thres, temperature
):
    """Run the model once and draw a filtered, temperature-scaled sample.

    Args:
        input_tokens: Passed to ``logits_function`` as its first argument.
        input_mask: Passed to ``logits_function`` as its second argument.
        logits_function: Callable returning a 3-tuple whose first element is
            the logits; it is called with ``return_encoding=False``.
        filter_thres: Forwarded to ``top_k`` as ``thres``.
        temperature: Forwarded to ``gumbel_sample``.

    Returns:
        Tuple ``(logits, sample)`` where ``logits`` is the unfiltered model
        output and ``sample`` is what ``gumbel_sample`` returned for the
        filtered logits.
    """
    logits, _encoding, _labels = logits_function(
        input_tokens, input_mask, return_encoding=False
    )
    sample = gumbel_sample(
        top_k(logits, thres=filter_thres), temperature=temperature, dim=-1
    )
    return logits, sample
|
|
|
|
|
def unmask_tokens(
    input_tokens,
    input_mask,
    num_masked_tokens,
    logits,
    sample,
    timestep,
    timesteps,
    gamma,
    filter_func=None,
    pad_token=None,
    mask_token=None,
    force_aas=True,
):
    """Commit the highest-scoring predictions at still-masked positions.

    Args:
        input_tokens: Token ids, indexed as ``(batch, position)``.
        input_mask: Bool tensor of the same shape, ``True`` where the token is
            still masked.
        num_masked_tokens: Number of masked positions used to size the step.
        logits: Model logits, indexed as ``(batch, position, vocab)``.
        sample: Scores to rank, same indexing as ``logits``.
        timestep: Current schedule position, forwarded to ``calc_unmask_probs``.
        timesteps: Total number of steps, forwarded to ``calc_unmask_probs``.
        gamma: Schedule function, forwarded to ``calc_unmask_probs``.
        filter_func: Optional callable
            ``f(tokens, scores, pad_token=..., mask_token=..., force_aas=...)``
            returning adjusted scores.
        pad_token: Forwarded to ``filter_func``.
        mask_token: Forwarded to ``filter_func``.
        force_aas: Forwarded to ``filter_func``.

    Returns:
        Tuple ``(unmasked_tokens, full_logits)``: the token tensor with the
        selected positions filled in, and ``logits`` zeroed everywhere except
        at the positions filled in during this call.
    """
    # Already-known positions must never win the ranking below.
    sample = sample.masked_fill(~input_mask.unsqueeze(-1), -torch.inf)
    if filter_func:
        # ``force_aas`` is passed by keyword: positionally it would land in
        # the filter's third parameter, which for the filter defined in this
        # file is ``start_token``.
        sample = filter_func(
            input_tokens,
            sample,
            pad_token=pad_token,
            mask_token=mask_token,
            force_aas=force_aas,
        )
    selected_token_probs, selected_tokens = torch.max(sample, dim=-1)

    unmask_prob = calc_unmask_probs(timestep, timesteps, gamma)
    # Always reveal at least one token, and never ask top-k for more entries
    # than there are positions.
    num_tokens_to_unmask = min(
        max(1, ceil(unmask_prob * num_masked_tokens)),
        selected_token_probs.shape[-1],
    )

    _, top_k_indices = torch.topk(selected_token_probs, num_tokens_to_unmask, dim=-1)

    sample_mask = torch.zeros(
        input_tokens.shape, dtype=torch.bool, device=input_tokens.device
    )
    sample_mask.scatter_(dim=1, index=top_k_indices, value=True)
    # A row with fewer masked positions than the top-k size would otherwise
    # have known tokens picked up and overwritten.
    sample_mask &= input_mask

    unmasked_tokens = torch.where(sample_mask, selected_tokens, input_tokens)
    full_logits = torch.where(
        sample_mask.unsqueeze(-1), logits, torch.zeros_like(logits)
    )
    return unmasked_tokens, full_logits
|
|
|
|
|
# Errors produced when the match lists do not line up with the batch
# (out-of-range row index, non-broadcastable shapes).
_BEST_EFFORT_ERRORS = (IndexError, RuntimeError)


def _token_columns(text, token):
    """Return the column index of every cell of ``text`` equal to ``token``.

    ``token=None`` means "no such token" and yields an empty index tensor
    instead of failing on the comparison.
    """
    if token is None:
        return torch.empty(0, dtype=torch.long, device=text.device)
    return (text == token).nonzero(as_tuple=True)[1]


def suppress_invalid_text_tokens(
    text,
    logits,
    start_token=None,
    end_token=None,
    pad_token=None,
    mask_token=None,
    force_aas=False,
):
    """Set the logits of tokens that may not be emitted to ``-inf``, in place.

    Rules applied, in order:

    * start / end token: disallowed at every position other than where it
      already occurs in ``text``.
    * pad token: disallowed between start and end and forced (``+inf``)
      outside that span; with only a start token it is forced before it, with
      only an end token it is forced after it.
    * ``force_aas``: the pad token, index ``3`` and indices ``29:`` are
      disallowed everywhere (presumably the non-amino-acid entries of the ESM
      alphabet -- TODO confirm against the tokenizer).
    * mask token: disallowed everywhere.

    The positional rules are best effort. Matches are listed in row-major
    order and used as row indices, so they line up with the batch only when
    every sequence contains the token exactly once (assumption inherited from
    the original code -- TODO confirm). A rule whose indexing fails is skipped.

    Args:
        text: Token ids, indexed as ``(batch, position)``.
        logits: Scores indexed as ``(batch, position, vocab)``; modified in place.
        start_token: Id of the start token, or ``None`` to skip its rule.
        end_token: Id of the end token, or ``None`` to skip its rule.
        pad_token: Id of the pad token, or ``None`` to skip the pad rules.
        mask_token: Id of the mask token, or ``None`` to skip its rule.
        force_aas: Apply the hard-coded vocabulary restriction described above.

    Returns:
        The same ``logits`` tensor that was passed in.
    """
    neg_inf, pos_inf = -torch.inf, torch.inf
    positions = torch.arange(logits.size(1), device=text.device)

    idx_start = _token_columns(text, start_token)
    idx_end = _token_columns(text, end_token)
    has_start = idx_start.nelement() != 0
    has_end = idx_end.nelement() != 0

    if has_start:
        try:
            rows, cols = torch.where(idx_start.unsqueeze(1) != positions)
            logits[rows, cols, start_token] = neg_inf
        except _BEST_EFFORT_ERRORS:
            pass

    if has_end:
        try:
            rows, cols = torch.where(idx_end.unsqueeze(1) != positions)
            logits[rows, cols, end_token] = neg_inf
        except _BEST_EFFORT_ERRORS:
            pass

    if pad_token is not None:
        try:
            if has_start and has_end:
                inside = (positions >= idx_start.unsqueeze(1)) & (
                    positions <= idx_end.unsqueeze(1)
                )
                rows, cols = torch.where(inside)
                logits[rows, cols, pad_token] = neg_inf
                rows, cols = torch.where(~inside)
                logits[rows, cols, pad_token] = pos_inf
            elif has_start:
                # Indices are taken from this branch's own mask.
                rows, cols = torch.where(positions < idx_start.unsqueeze(1))
                logits[rows, cols, pad_token] = pos_inf
            elif has_end:
                rows, cols = torch.where(positions > idx_end.unsqueeze(1))
                logits[rows, cols, pad_token] = pos_inf
        except _BEST_EFFORT_ERRORS:
            pass

    if force_aas:
        if pad_token is not None:
            logits[:, :, pad_token] = neg_inf
        logits[:, :, 3] = neg_inf
        logits[:, :, 29:] = neg_inf

    if mask_token is not None:
        logits[:, :, mask_token] = neg_inf

    return logits
|
|
|
|
|
def detokenize_text(text_embedding, sequence):
    """Convert batches of token ids back to strings using the ESM alphabet.

    Args:
        text_embedding: Name of the text embedding; only ``"esm1b"`` and
            ``"esm2"`` are supported.
        sequence: Iterable of token-id sequences, one per batch element.

    Returns:
        List with one string per batch element, built by concatenating the
        alphabet entry of every id.

    Raises:
        NameError: If ``text_embedding`` is not an ESM model. The exception
            type is kept from the original, where it was wrapped in an
            ``assert`` that never fired.
    """
    if text_embedding not in ("esm1b", "esm2"):
        raise NameError("Detokenization only available for ESM models")

    # Imported lazily so the module can be loaded without ``esm`` installed.
    from esm import Alphabet

    alphabet = (
        Alphabet.from_architecture("ESM-1b").get_batch_converter().alphabet.all_toks
    )

    return ["".join(alphabet[idx] for idx in tokens) for tokens in sequence]
|
|
|
class ImageEmbedding(nn.Module):
    """Lookup table mapping discrete image-token ids to ``dim``-sized vectors."""

    def __init__(self, num_tokens, dim):
        """Create the table.

        Args:
            num_tokens: Number of distinct token ids the table holds.
            dim: Width of each embedding vector.
        """
        super().__init__()
        # The attribute name is part of the module's state_dict keys.
        self.image_embedding = nn.Embedding(num_tokens, dim)

    def forward(self, image):
        """Return the embedding of every token id in ``image``."""
        embedded = self.image_embedding(image)
        return embedded
|
|
|
|
|
class ModelExtender(nn.Module):
    """Wrap a pretrained ESM language model and project its output width.

    The wrapped model's last-layer representations are passed through a
    linear layer whenever ``out_features`` differs from the model's own width.
    """

    # ESM-2 checkpoint loader names keyed by the requested ``out_features``.
    _ESM2_LOADERS = {
        320: "esm2_t6_8M_UR50D",
        480: "esm2_t12_35M_UR50D",
        640: "esm2_t30_150M_UR50D",
        1280: "esm2_t33_650M_UR50D",
        2560: "esm2_t36_3B_UR50D",
    }
    # Loader used for any other ``out_features``.
    _ESM2_FALLBACK_LOADER = "esm2_t33_650M_UR50D"

    def __init__(self, vocab, out_features, fixed_embedding=False):
        """Load the pretrained model and build the projection layer.

        Args:
            vocab: ``"esm1b"`` or ``"esm2"``.
            out_features: Width of the vectors returned by ``forward``. For
                ``"esm2"`` it also selects the checkpoint.
            fixed_embedding: When true the wrapped model is put in eval mode
                and ``forward`` runs it without gradient tracking.

        Raises:
            ValueError: If ``vocab`` is not a supported model name.
        """
        super().__init__()

        self.vocab = vocab

        if vocab == "esm1b":
            from esm import pretrained

            self.model, _ = pretrained.esm1b_t33_650M_UR50S()
            self.in_features = 1280
        elif vocab == "esm2":
            from esm import pretrained

            loader_name = self._ESM2_LOADERS.get(
                out_features, self._ESM2_FALLBACK_LOADER
            )
            self.model, _ = getattr(pretrained, loader_name)()
            self.in_features = self.model.embed_dim
        else:
            # Without this the constructor would fail later with an
            # AttributeError on the never-assigned ``in_features``.
            raise ValueError(
                f"Unsupported vocab {vocab!r}; expected 'esm1b' or 'esm2'."
            )

        self.out_features = out_features
        self.scale_layer = nn.Linear(self.in_features, self.out_features)

        self.fixed_embedding = fixed_embedding
        if self.fixed_embedding:
            self.model = self.model.eval()

    def _encode(self, x, **kwargs):
        """Run the wrapped model and return its representation of ``x``."""
        if self.vocab == "esm1b" or self.vocab == "esm2":
            last_layer = self.model.num_layers
            # squeeze(1) drops a singleton second axis before the model call.
            output = self.model(x.squeeze(1), repr_layers=[last_layer])
            return output["representations"][last_layer]
        # Kept from the original for non-ESM models.
        return self.model(x, **kwargs)[0]

    def forward(self, x, **kwargs):
        """Embed ``x`` and project it to ``out_features`` when needed.

        Args:
            x: Token ids passed to the wrapped model.
            **kwargs: Forwarded to the wrapped model for non-ESM vocabularies.

        Returns:
            The model representation, passed through ``scale_layer`` when
            ``out_features`` differs from ``in_features``.
        """
        if self.fixed_embedding:
            with torch.no_grad():
                x = self._encode(x, **kwargs)
        else:
            x = self._encode(x, **kwargs)

        if self.out_features != self.in_features:
            x = self.scale_layer(x)

        return x
|
|
|
class CELLE(nn.Module): |
|
def __init__(
    self,
    *,
    dim,
    vae,
    condition_vae=None,
    num_images=2,
    num_text_tokens=30,
    text_seq_len=1000,
    depth=16,
    heads=16,
    dim_head=64,
    attn_dropout=0.1,
    ff_dropout=0.1,
    attn_types=None,
    causal=False,
    loss_cond_weight=1,
    loss_img_weight=1,
    stable=False,
    rotary_emb=True,
    text_embedding="esm2",
    fixed_embedding=True,
    sampling_mode="cosine",
    linear_project=False,
    **kwargs,
):
    """Build the embeddings, the logits mask, the transformer and the head.

    The model works on one joint sequence laid out as
    ``text | separator | condition | image`` with a joint vocabulary laid out
    as ``text | separator | condition | condition-mask | image | image-mask``.

    Args:
        dim: Width shared by all embeddings and the transformer.
        vae: Image VAE exposing ``image_size``, ``num_tokens`` and
            ``num_layers``; stored in eval mode.
        condition_vae: Optional VAE for the condition image with the same
            attributes; stored in eval mode when given.
        num_images: Forwarded to ``Transformer``.
        num_text_tokens: Size of the text part of the joint vocabulary.
        text_seq_len: Text length; two positions are added for ESM embeddings.
        depth: Forwarded to ``Transformer``.
        heads: Forwarded to ``Transformer``.
        dim_head: Forwarded to ``Transformer``.
        attn_dropout: Forwarded to ``Transformer``.
        ff_dropout: Forwarded to ``Transformer``.
        attn_types: Accepted but not used by this constructor.
        causal: Forwarded to ``Transformer``.
        loss_cond_weight: Stored on the instance.
        loss_img_weight: Stored on the instance.
        stable: Forwarded to ``Transformer``; also creates ``norm_by_max``.
        rotary_emb: Forwarded to ``Transformer``; when true the positional
            embeddings built here are replaced by ``always(0)``.
        text_embedding: ``"esm1b"`` or ``"esm2"``.
        fixed_embedding: Forwarded to ``ModelExtender``.
        sampling_mode: Passed to ``gamma_func`` to build the masking schedule.
        linear_project: Stored on the instance.
        **kwargs: Ignored.

    Raises:
        ValueError: If ``text_embedding`` is not an ESM model.
    """
    super().__init__()

    # Modules are created in the original order so that parameter
    # initialisation and state_dict ordering are unchanged.
    self.stable = stable
    if stable:
        self.norm_by_max = DivideMax(dim=-1)

    self.text_embedding = text_embedding
    self.fixed_embedding = fixed_embedding
    self.num_text_tokens = num_text_tokens
    self.linear_project = linear_project

    if text_embedding.lower() in ("esm1b", "esm2"):
        # Presumably room for the tokenizer's begin/end tokens -- TODO confirm.
        self.text_seq_len = text_seq_len + 2
    else:
        self.text_seq_len = text_seq_len

    self.sep_emb = nn.Embedding(1, dim)
    self.text_pos_emb = (
        always(0) if rotary_emb else nn.Embedding(self.text_seq_len + 1, dim)
    )

    self.num_images = num_images

    if exists(condition_vae):
        num_condition_tokens = condition_vae.num_tokens
        condition_fmap_size = condition_vae.image_size // (
            2**condition_vae.num_layers
        )
        condition_seq_len = condition_fmap_size**2

        # One extra row for the condition mask token.
        self.condition_emb = ImageEmbedding(num_condition_tokens + 1, dim)
        self.condition_pos_emb = (
            always(0)
            if rotary_emb
            else AxialPositionalEmbedding(
                dim, axial_shape=(condition_fmap_size, condition_fmap_size)
            )
        )
    else:
        condition_fmap_size = 0
        condition_seq_len = 0
        num_condition_tokens = 0

    self.image_size = vae.image_size
    num_image_tokens = vae.num_tokens
    image_fmap_size = vae.image_size // (2**vae.num_layers)
    image_seq_len = image_fmap_size**2
    self.image_seq_len = image_seq_len
    self.num_image_tokens = num_image_tokens

    # One extra row for the image mask token.
    self.image_emb = ImageEmbedding(num_image_tokens + 1, dim)
    self.image_pos_emb = (
        always(0)
        if rotary_emb
        else AxialPositionalEmbedding(
            dim, axial_shape=(image_fmap_size, image_fmap_size)
        )
    )

    self.num_condition_tokens = num_condition_tokens
    self.condition_seq_len = condition_seq_len

    seq_len = self.text_seq_len + 1 + self.condition_seq_len + self.image_seq_len
    total_tokens = (
        num_text_tokens + 1 + num_condition_tokens + 1 + num_image_tokens + 1
    )
    self.total_tokens = total_tokens
    self.total_seq_len = seq_len

    self.vae = vae.eval()
    # ``condition_vae`` defaults to None; calling ``.eval()`` on it would
    # raise AttributeError.
    self.condition_vae = condition_vae.eval() if exists(condition_vae) else None

    if text_embedding == "esm1b" or text_embedding == "esm2":
        self.text_mask_token = 32
        self.pad_token = 1
        self.text_emb = ModelExtender(text_embedding, dim, fixed_embedding)
    else:
        raise ValueError("Only ESM models are supported.")

    # Ids of the special tokens inside their own vocabulary sections.
    self.sep_token = num_text_tokens
    self.cond_mask_token = num_condition_tokens
    self.image_mask_token = num_image_tokens

    # Broadcastable position / vocabulary index grids.
    self.seq_range = rearrange(torch.arange(seq_len), "n -> () n ()")
    self.logits_range = rearrange(torch.arange(total_tokens), "d -> () () d")
    seq_pos, vocab_id = self.seq_range, self.logits_range

    cond_seq_start = self.text_seq_len + 1
    image_seq_start = cond_seq_start + condition_seq_len
    cond_vocab_start = num_text_tokens + 1
    image_vocab_start = cond_vocab_start + num_condition_tokens + 1

    # Each section of the sequence may only emit ids from its own section of
    # the vocabulary; mask tokens are never valid outputs.
    text_cells = (
        (seq_pos < self.text_seq_len)
        & (vocab_id < num_text_tokens)
        & (vocab_id != self.text_mask_token)
    )
    sep_cells = (seq_pos == self.text_seq_len) & (vocab_id == num_text_tokens)
    condition_cells = (
        (seq_pos >= cond_seq_start)
        & (seq_pos < image_seq_start)
        & (vocab_id >= cond_vocab_start)
        & (vocab_id < cond_vocab_start + num_condition_tokens)
    )
    image_cells = (
        (seq_pos >= image_seq_start)
        & (vocab_id >= image_vocab_start)
        & (vocab_id < image_vocab_start + num_image_tokens)
    )
    allowed = text_cells | sep_cells | condition_cells | image_cells

    # True marks the cells to be filled with a large negative value.
    self.register_buffer("logits_mask", ~allowed, persistent=False)

    self.transformer = Transformer(
        dim=dim,
        causal=causal,
        seq_len=seq_len,
        depth=depth,
        heads=heads,
        dim_head=dim_head,
        attn_dropout=attn_dropout,
        ff_dropout=ff_dropout,
        image_fmap_size=image_fmap_size + condition_fmap_size,
        num_images=num_images,
        stable=stable,
        rotary_emb=rotary_emb,
    )

    self.to_logits = nn.Sequential(
        nn.LayerNorm(dim),
        nn.Linear(dim, self.total_tokens),
    )

    self.loss_img_weight = loss_img_weight
    self.loss_cond_weight = loss_cond_weight
    self.gamma = gamma_func(sampling_mode)
|
|
|
def embed_and_transform(self, inputs, masks, return_encoding=False):
    """Embed text, condition and image, run the transformer, return logits.

    Args:
        inputs: ``(text, condition, image)``. Condition and image may be
            float images (tokenised here with the matching VAE) or token ids.
        masks: ``(text_mask, condition_mask, image_mask)``; the condition mask
            is not used.
        return_encoding: When true the transformer output is returned as the
            second element instead of ``None``.

    Returns:
        ``(logits, encoding_or_None, [text_labels, condition_labels,
        image_labels])`` where the logits already have every cell selected by
        ``self.logits_mask`` set to the most negative finite value.
    """
    text, condition, image = inputs
    text_mask, _, image_mask = masks
    device = text.device

    # Keep the clean ids as targets, then hide the masked positions.
    text_labels = text.clone()
    text = torch.where(
        text_mask, torch.full_like(text, self.text_mask_token), text
    )

    tokens = self.text_emb(text)

    # Append the separator embedding after the text.
    sep_ids = torch.zeros((tokens.shape[0], 1), dtype=torch.long, device=device)
    tokens = torch.cat((tokens, self.sep_emb(sep_ids)), dim=1)
    tokens = tokens + self.text_pos_emb(
        torch.arange(text.shape[1] + 1, device=device)
    )

    # Condition: obtain token ids without tracking gradients through the VAE.
    with torch.no_grad():
        if self.linear_project:
            batch = condition.shape[0]
            condition, _, [_, _, condition_labels] = self.condition_vae.encode(
                condition
            )
            condition_labels = rearrange(
                condition_labels, "(b n) -> b n", b=batch
            )
        else:
            condition_labels = condition
            if condition.dtype == torch.float:
                condition_labels = self.condition_vae.get_codebook_indices(
                    condition
                )
            condition = condition_labels.clone()

    condition_emb = self.condition_emb(condition)
    condition_emb = condition_emb + self.condition_pos_emb(condition_emb)
    tokens = torch.cat((tokens, condition_emb), dim=1)

    # Image: same as the condition, plus masking of the selected positions.
    with torch.no_grad():
        if self.linear_project:
            batch = image.shape[0]
            image, _, [_, _, image_labels] = self.vae.encode(image)
            image_labels = rearrange(image_labels, "(b n) -> b n", b=batch)
        else:
            image_labels = image
            if image.dtype == torch.float:
                image_labels = self.vae.get_codebook_indices(image)
            image = torch.where(
                image_mask,
                torch.full_like(image_labels, self.image_mask_token),
                image_labels,
            )

    image_emb = self.image_emb(image)
    image_emb = image_emb + self.image_pos_emb(image_emb)
    tokens = torch.cat((tokens, image_emb), dim=1)

    if self.stable:
        # Same forward value, but only a fraction ``alpha`` of the gradient
        # reaches the embeddings.
        alpha = 0.1
        tokens = tokens * alpha + tokens.detach() * (1 - alpha)

    out = self.transformer(tokens)

    if self.stable:
        out = self.norm_by_max(out)

    logits = self.to_logits(out)

    # Rule out vocabulary entries that are invalid at each position.
    most_negative = -torch.finfo(logits.dtype).max
    logits.masked_fill_(self.logits_mask, most_negative)

    labels = [text_labels, condition_labels, image_labels]
    encoding = out if return_encoding else None
    return logits, encoding, labels
|
|
|
def forward(
    self,
    text,
    condition=None,
    image=None,
    return_loss=False,
    return_encoding=False,
):
    """Run one masked-prediction pass and optionally compute the loss.

    Random masks are drawn for the text and the image, the masked inputs are
    passed through the model, and cross-entropy is averaged over the masked
    positions of each modality.

    Args:
        text: Text token ids, indexed as ``(batch, position)``.
        condition: Condition input forwarded to ``embed_and_transform``.
        image: Image tensor of shape
            ``(batch, vae.channels, image_size, image_size)``; required.
        return_loss: When false only the logits are returned.
        return_encoding: Forwarded to ``embed_and_transform``.

    Returns:
        The logits when ``return_loss`` is false, otherwise
        ``(loss, loss_dict, None)`` where ``loss_dict`` holds ``loss_text``,
        ``loss_img`` and a NaN/inf-free copy of ``loss``.

    Raises:
        AssertionError: If ``image`` is missing or has the wrong shape.
    """
    batch_size, device = text.shape[0], text.device

    assert exists(image), "when training, image must be supplied"

    assert tuple(image.shape[1:]) == (
        self.vae.channels,
        self.image_size,
        self.image_size,
    ), f"invalid image of dimensions {image.shape} passed in during training"

    text_mask = generate_mask(
        gamma_func("scaled-cosine"), batch_size, self.text_seq_len, device
    )
    image_mask = generate_mask(self.gamma, batch_size, self.image_seq_len, device)

    # ``embed_and_transform`` takes (inputs, masks, return_encoding) and
    # derives the device from the text tensor itself.
    logits, _, labels = self.embed_and_transform(
        [text, condition, image],
        [text_mask, None, image_mask],
        return_encoding,
    )

    if not return_loss:
        return logits

    text, condition, image = labels

    # Build the joint target sequence: text | separator | condition | image,
    # each shifted into its own section of the joint vocabulary.
    sep_token = torch.full(
        (text.shape[0], 1), self.sep_token, dtype=torch.long, device=device
    )
    labels = torch.cat([text, sep_token], dim=1)

    if exists(condition) and exists(self.condition_vae):
        offsetted_condition = condition + self.num_text_tokens + 1
        labels = torch.cat((labels, offsetted_condition), dim=1)

    offsetted_image = (
        image + self.num_text_tokens + 1 + self.num_condition_tokens + 1
    )
    labels = torch.cat((labels, offsetted_image), dim=1)

    # cross_entropy expects the class axis second.
    logits = rearrange(logits, "b n c -> b c n")

    image_start = self.text_seq_len + 1 + self.condition_seq_len

    # Only masked positions contribute to either loss.
    loss_text = F.cross_entropy(
        logits[:, :, : self.text_seq_len],
        labels[:, : self.text_seq_len],
        reduction="none",
    )[text_mask].mean()

    loss_img = F.cross_entropy(
        logits[:, :, image_start:],
        labels[:, image_start:],
        reduction="none",
    )[image_mask].mean()

    loss = (loss_text + self.loss_img_weight * loss_img) / (
        self.loss_img_weight + 1
    )

    loss_dict = {
        "loss_text": loss_text,
        "loss_img": loss_img,
        "loss": torch.nan_to_num(loss, 0.0, 0.0, 0.0),
    }

    return loss, loss_dict, None
|
|
|
def create_tensors(self, text, condition, image):
    """Fill in missing sampling inputs and tokenise the supplied images.

    An input that is not a tensor is replaced by a single-row tensor filled
    with that modality's mask token. A supplied condition or image tensor is
    converted to codebook indices with its VAE; a supplied text tensor is
    returned unchanged.

    Args:
        text: Text token tensor, or any non-tensor placeholder.
        condition: Condition image tensor, or any non-tensor placeholder.
        image: Image tensor, or any non-tensor placeholder.

    Returns:
        Tuple ``(text, condition, image)`` of tensors.
    """
    # New tensors are placed on the device of the first supplied tensor.
    supplied = (
        item for item in (text, condition, image) if isinstance(item, torch.Tensor)
    )
    device = next(supplied, None).device

    def _fully_masked(length, mask_token):
        return torch.full((1, length), mask_token, dtype=torch.long, device=device)

    if not isinstance(text, torch.Tensor):
        text = _fully_masked(self.text_seq_len, self.text_mask_token)

    if isinstance(condition, torch.Tensor):
        with torch.no_grad():
            condition = self.condition_vae.get_codebook_indices(condition)
    else:
        condition = _fully_masked(self.condition_seq_len, self.cond_mask_token)

    if isinstance(image, torch.Tensor):
        with torch.no_grad():
            image = self.vae.get_codebook_indices(image)
    else:
        image = _fully_masked(self.image_seq_len, self.image_mask_token)

    return text, condition, image
|
|
|
@torch.no_grad()
@eval_decorator
def sample(
    self,
    text=None,
    condition=None,
    image=None,
    temperature=1.0,
    filter_thres=0.9,
    progress=False,
    timesteps=1,
    force_aas=True,
):
    """Iteratively unmask text and image tokens and decode the image.

    Args:
        text: Text token tensor, or ``None`` to start fully masked.
        condition: Condition image tensor, or ``None``.
        image: Image tensor, or ``None`` to start fully masked.
        temperature: Sampling temperature forwarded to ``calculate_logits``.
        filter_thres: Top-k threshold forwarded to ``calculate_logits``.
        progress: ``True`` to wrap the loop in ``tqdm``; a callable is used
            as the wrapper itself; any other value disables the bar.
        timesteps: Number of unmasking iterations; must be positive.
        force_aas: Forwarded to ``unmask_tokens`` for the text branch.

    Returns:
        Tuple ``(text, detokenized_text, full_text_logits, thresholded_image,
        decoded_logit_image)``.

    Raises:
        AssertionError: If ``timesteps`` is not positive, if none of the three
            inputs is a tensor, or if the batch sizes cannot be matched.
    """
    assert int(timesteps) > 0

    self.eval()
    vae = self.vae.eval()

    if callable(progress):
        wrap = progress
    elif progress:
        wrap = tqdm
    else:
        wrap = lambda iterable: iterable

    assert (
        isinstance(text, torch.Tensor)
        or isinstance(condition, torch.Tensor)
        or isinstance(image, torch.Tensor)
    ), "some data must be supplied"

    text, condition, image = self.create_tensors(text, condition, image)
    batch_size = max(text.shape[0], condition.shape[0], image.shape[0])
    text, condition, image = match_batch_size(text, condition, image, batch_size)

    # All three inputs are tensors at this point.
    device = text.device

    assert text.shape[0] == condition.shape[0] == image.shape[0]

    # Start from a one-hot encoding of the inputs.
    full_text_logits = torch.zeros(
        batch_size, self.text_seq_len, self.num_text_tokens, device=device
    )
    full_text_logits.scatter_(dim=-1, index=text.unsqueeze(-1), value=1)

    # One extra class so the image mask token can be scattered, then dropped.
    full_image_logits = torch.zeros(
        batch_size, self.image_seq_len, self.num_image_tokens + 1, device=device
    )
    full_image_logits.scatter_(dim=-1, index=image.unsqueeze(-1), value=1)
    full_image_logits = full_image_logits[:, :, : self.num_image_tokens]

    for timestep in wrap(torch.linspace(0, 1, timesteps)):
        text_mask = text == self.text_mask_token
        cond_mask = condition == self.cond_mask_token
        image_mask = image == self.image_mask_token

        logits, sample = calculate_logits(
            [text, condition, image],
            [text_mask, cond_mask, image_mask],
            self.embed_and_transform,
            filter_thres,
            temperature,
        )

        # The first batch row decides how many tokens are revealed.
        num_masked_text_tokens = torch.sum(text_mask, dim=1)[0]
        num_masked_image_tokens = torch.sum(image_mask, dim=1)[0]

        if num_masked_text_tokens > 0:
            text, full_text_logits = unmask_tokens(
                text,
                text_mask,
                num_masked_text_tokens,
                logits[:, : self.text_seq_len, : self.num_text_tokens],
                sample[:, : self.text_seq_len, : self.num_text_tokens],
                timestep,
                timesteps,
                self.gamma,
                suppress_invalid_text_tokens,
                self.pad_token,
                self.text_mask_token,
                force_aas=force_aas,
            )
            full_text_logits = full_text_logits.masked_fill(
                ~text_mask.unsqueeze(-1), -torch.inf
            )

        if num_masked_image_tokens > 0:
            image, full_image_logits = unmask_tokens(
                image,
                image_mask,
                num_masked_image_tokens,
                logits[:, -self.image_seq_len :, -(self.num_image_tokens + 1) : -1],
                sample[:, -self.image_seq_len :, -(self.num_image_tokens + 1) : -1],
                timestep,
                timesteps,
                self.gamma,
            )
            # NOTE(review): this fills the *text* logits inside the image
            # branch, as the original did. Kept because changing it would
            # alter the returned full_text_logits -- confirm the intent.
            full_text_logits = full_text_logits.masked_fill(
                ~text_mask.unsqueeze(-1), -torch.inf
            )

    # Scale each position's logits by its largest magnitude. The clamp keeps
    # all-zero rows at zero instead of producing NaN from 0 / 0.
    scale = torch.max(
        torch.abs(full_image_logits), dim=-1, keepdim=True
    ).values.clamp_min(torch.finfo(full_image_logits.dtype).tiny)
    full_image_logits = full_image_logits / scale

    # Mix the codebook vectors with the scaled logits as weights.
    full_image_logits = torch.matmul(
        full_image_logits, self.vae.model.quantize.embedding.weight
    )

    h = int(self.image_seq_len**0.5)
    full_image_logits = rearrange(full_image_logits, "b (h w) c -> b c h w", h=h)

    full_image_logits = self.vae.model.decode(full_image_logits)

    # Shift each decoded image so its maximum is at least 1 and its minimum
    # at most 0, then clip to [0, 1].
    flat = full_image_logits.reshape(batch_size, -1)
    max_val = torch.max(flat, dim=-1)[0]
    min_val = torch.min(flat, dim=-1)[0]
    full_image_logits = full_image_logits + torch.clip(
        1 - max_val, 0, float("inf")
    ).view(batch_size, 1, 1, 1)
    full_image_logits = full_image_logits + torch.clip(
        0 - min_val, float("-inf"), 0
    ).view(batch_size, 1, 1, 1)
    full_image_logits = torch.clip(full_image_logits, 0, 1)

    return (
        text,
        detokenize_text(self.text_embedding, text),
        full_text_logits,
        1.0 * (vae.decode(image) > 0.5),
        full_image_logits,
    )
|
|
|
@torch.no_grad()
@eval_decorator
def sample_text(
    self,
    text=False,
    condition=False,
    image=False,
    temperature=1.0,
    filter_thres=0.9,
    progress=False,
    n_unmask=1,
    place_amino=True,
    force_aas=False,
):
    """Fill masked text positions outward from the known part of each row.

    Every iteration re-runs the model and commits a window of ``n_unmask``
    positions per row, starting with the masked position closest to the
    median index of the initially known tokens.

    Args:
        text: Text token tensor containing mask tokens, or a non-tensor.
        condition: Condition image tensor, or a non-tensor.
        image: Image tensor, or a non-tensor.
        temperature: Sampling temperature forwarded to ``calculate_logits``.
        filter_thres: Top-k threshold forwarded to ``calculate_logits``.
        progress: ``True`` for a ``tqdm`` bar, or a tqdm-like factory called
            as ``progress(total=...)``; falsy disables the bar.
        n_unmask: Width of the window committed per row and iteration.
        place_amino: When true the argmax token is written into ``text``.
        force_aas: Forwarded to ``suppress_invalid_text_tokens``.

    Returns:
        Tuple ``(text, detokenized_text, full_text_logits)``.

    Raises:
        AssertionError: If none of the three inputs is a tensor.
        ValueError: If a row's next masked position lies on neither side of
            its centre index, which this schedule cannot handle.
    """
    self.eval()

    assert (
        isinstance(text, torch.Tensor)
        or isinstance(condition, torch.Tensor)
        or isinstance(image, torch.Tensor)
    ), "some data must be supplied"

    text, condition, image = self.create_tensors(text, condition, image)
    batch_size = max(text.shape[0], condition.shape[0], image.shape[0])
    text, condition, image = match_batch_size(text, condition, image, batch_size)

    # All three inputs are tensors at this point.
    device = text.device

    assert text.shape[0] == condition.shape[0] == image.shape[0]

    # Start from a one-hot encoding of the input text.
    full_text_logits = torch.zeros(
        batch_size, self.text_seq_len, self.num_text_tokens, device=device
    )
    full_text_logits.scatter_(dim=-1, index=text.unsqueeze(-1), value=1)

    text_mask = text == self.text_mask_token
    cond_mask = condition == self.cond_mask_token
    image_mask = image == self.image_mask_token

    # (row, column) pairs of the initially masked / known cells.
    mask_indices = text_mask.nonzero()
    non_mask_indices = (~text_mask).nonzero()

    # Median column of the known tokens of every row.
    central_protein_index = torch.tensor(
        [
            torch.median(
                non_mask_indices[non_mask_indices[:, 0] == idx][:, -1]
            )
            for idx in range(batch_size)
        ]
    )

    count = 1

    # ``run_mask`` aliases ``text_mask``: clearing cells below also removes
    # them from the mask handed to the model on the next iteration.
    run_mask = text_mask

    bar_factory = tqdm if progress is True else progress
    pbar = bar_factory(total=torch.sum(run_mask).item()) if bar_factory else None

    while torch.sum(run_mask) > 0:
        logits, sample = calculate_logits(
            [text, condition, image],
            [text_mask, cond_mask, image_mask],
            self.embed_and_transform,
            filter_thres,
            temperature,
        )

        sub_sample = sample[:, : self.text_seq_len, : self.num_text_tokens]
        sub_sample = sub_sample.masked_fill(~text_mask.unsqueeze(-1), -torch.inf)
        # 0 and 2 are passed as the start and end token ids.
        sub_sample = suppress_invalid_text_tokens(
            text, sub_sample, 0, 2, self.pad_token, self.text_mask_token, force_aas
        )

        for idx in range(batch_size):
            selected_mask_indices = mask_indices[mask_indices[:, 0] == idx][:, -1]

            if selected_mask_indices[-count] < central_protein_index[idx]:
                # Grow to the left of the known region.
                unmask_index = selected_mask_indices[-count]
                left_sample = max(0, (unmask_index + 1) - n_unmask)
                right_sample = min(unmask_index + 1, self.text_seq_len - 1)
                central_protein_index[idx] = max(
                    0, central_protein_index[idx] - 0.5 * n_unmask
                )
            elif selected_mask_indices[count - 1] > central_protein_index[idx]:
                # Grow to the right of the known region.
                unmask_index = selected_mask_indices[count - 1]
                left_sample = max(0, unmask_index)
                right_sample = min(unmask_index + n_unmask, self.text_seq_len - 1)
                central_protein_index[idx] = min(
                    central_protein_index[idx] + 0.5 * n_unmask,
                    self.text_seq_len - 1,
                )
            else:
                # Without this the window bounds would be undefined or left
                # over from another row.
                raise ValueError(
                    f"cannot pick the next position to unmask for batch row {idx}: "
                    "masked positions must lie on one side of the known tokens"
                )

            # The last axis is the vocabulary, so it is sliced with the
            # vocabulary size.
            full_text_logits[
                idx, left_sample:right_sample, : self.num_text_tokens
            ] = logits[idx, left_sample:right_sample, : self.num_text_tokens]

            run_mask[idx, left_sample:right_sample] = False

            if place_amino:
                window = text[idx, left_sample:right_sample]
                text[idx, left_sample:right_sample] = torch.where(
                    window == self.text_mask_token,
                    sub_sample[
                        idx, left_sample:right_sample, : self.num_text_tokens
                    ].argmax(dim=-1),
                    window,
                )

        text_mask = run_mask
        count += n_unmask

        if pbar is not None:
            pbar.update(n_unmask)

    if pbar is not None:
        pbar.close()

    return (
        text,
        detokenize_text(self.text_embedding, text),
        full_text_logits,
    )
|
|