EmaadKhwaja committed • Commit 64212e0 • Parent(s): 86d2765

update app.py

Browse files:
- .gitignore +2 -0
- app.py +114 -4
- celle/celle.py +1061 -0
- celle/utils.py +230 -0
- dataloader.py +308 -0
- requirements.txt +13 -0
.gitignore ADDED
@@ -0,0 +1,2 @@
+__pycache__
+env
app.py CHANGED
@@ -1,7 +1,117 @@
 import gradio as gr
 
-def greet(name):
-    return "Hello " + name + "!!"
-
-
+from huggingface_hub import hf_hub_download
+from prediction import run_image_prediction
+import torch
+import torchvision.transforms as T
+from celle.utils import process_image
+from PIL import Image
+from matplotlib import pyplot as plt
+
+
+def gradio_demo(model_name, sequence_input, nucleus_image, protein_image):
+    model = hf_hub_download(repo_id=f"HuangLab/{model_name}", filename="model.ckpt")
+    config = hf_hub_download(repo_id=f"HuangLab/{model_name}", filename="config.yaml")
+    hf_hub_download(repo_id=f"HuangLab/{model_name}", filename="nucleus_vqgan.yaml")
+    hf_hub_download(repo_id=f"HuangLab/{model_name}", filename="threshold_vqgan.yaml")
+    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
+
+    if 'Finetuned' in model_name:
+        dataset = 'OpenCell'
+    else:
+        dataset = 'HPA'
+
+    nucleus_image = process_image(nucleus_image, dataset, 'nucleus')
+    if protein_image:
+        protein_image = process_image(protein_image, dataset, 'protein')
+        protein_image = protein_image > torch.median(protein_image)
+        protein_image = protein_image[0, 0]
+        protein_image = protein_image * 1.0
+    else:
+        protein_image = torch.ones((256, 256))
+
+    threshold, heatmap = run_image_prediction(sequence_input=sequence_input,
+                                              nucleus_image=nucleus_image,
+                                              model_ckpt_path=model,
+                                              model_config_path=config,
+                                              device=device)
+
+    # Plot the heatmap
+    plt.imshow(heatmap.cpu(), cmap='rainbow', interpolation='bicubic')
+    plt.axis('off')
+
+    # Save the plot to a temporary file
+    plt.savefig('temp.png', bbox_inches='tight', dpi=256)
+
+    # Open the temporary file as a PIL image
+    heatmap = Image.open('temp.png')
+
+    return T.ToPILImage()(nucleus_image[0, 0]), T.ToPILImage()(protein_image), T.ToPILImage()(threshold), heatmap
+
+
+with gr.Blocks() as demo:
+    gr.Markdown("Select the prediction model.")
+    gr.Markdown("CELL-E_2_HPA_480 is a good general purpose model for various cell types using ICC-IF.")
+    gr.Markdown("CELL-E_2_HPA_Finetuned_480 is finetuned on OpenCell and is better for live-cell predictions on HEK cells.")
+    with gr.Row():
+        model_name = gr.Dropdown(['CELL-E_2_HPA_480', 'CELL-E_2_HPA_Finetuned_480'],
+                                 value='CELL-E_2_HPA_480', label='Model Name')
+    with gr.Row():
+        gr.Markdown("Input the desired amino acid sequence. GFP is shown below by default.")
+
+    with gr.Row():
+        sequence_input = gr.Textbox(value='MSKGEELFTGVVPILVELDGDVNGHKFSVSGEGEGDATYGKLTLKFICTTGKLPVPWPTLVTTFSYGVQCFSRYPDHMKQHDFFKSAMPEGYVQERTIFFKDDGNYKTRAEVKFEGDTLVNRIELKGIDFKEDGNILGHKLEYNYNSHNVYIMADKQKNGIKVNFKIRHNIEDGSVQLADHYQQNTPIGDGPVLLPDNHYLSTQSALSKDPNEKRDHMVLLEFVTAAGITHGMDELYK',
+                                    label='Sequence')
+    with gr.Row():
+        gr.Markdown("Uploading a nucleus image is necessary. A random crop of 256 x 256 will be applied if larger.")
+        gr.Markdown("The protein image is optional and is just used for display.")
+
+    with gr.Row().style(equal_height=True):
+        nucleus_image = gr.Image(value='images/Armadillo repeat-containing X-linked protein 5 nucleus.jpg',
+                                 type='pil',
+                                 label='Nucleus Image')
+
+        protein_image = gr.Image(type='pil', label='Protein Image (Optional)')
+
+    with gr.Row():
+        gr.Markdown("Image predictions are shown below.")
+
+    with gr.Row().style(equal_height=True):
+        nucleus_image_crop = gr.Image(type='pil',
+                                      label='Nucleus Image')
+
+        protein_threshold_image = gr.Image(type='pil',
+                                           label='Protein Threshold Image')
+
+        predicted_threshold_image = gr.Image(type='pil',
+                                             label='Predicted Threshold Image')
+
+        predicted_heatmap = gr.Image(type='pil',
+                                     label='Predicted Heatmap')
+    with gr.Row():
+        button = gr.Button("Run Model")
+
+    inputs = [model_name,
+              sequence_input,
+              nucleus_image,
+              protein_image]
+
+    outputs = [nucleus_image_crop,
+               protein_threshold_image,
+               predicted_threshold_image,
+               predicted_heatmap]
+
+    button.click(gradio_demo, inputs, outputs)
+
+    examples = [['CELL-E_2_HPA_Finetuned_480',
+                 'MSKGEELFTGVVPILVELDGDVNGHKFSVSGEGEGDATYGKLTLKFICTTGKLPVPWPTLVTTFSYGVQCFSRYPDHMKQHDFFKSAMPEGYVQERTIFFKDDGNYKTRAEVKFEGDTLVNRIELKGIDFKEDGNILGHKLEYNYNSHNVYIMADKQKNGIKVNFKIRHNIEDGSVQLADHYQQNTPIGDGPVLLPDNHYLSTQSALSKDPNEKRDHMVLLEFVTAAGITHGMDELYK',
+                 'images/Proteasome activator complex subunit 3 nucleus.png',
+                 'images/Proteasome activator complex subunit 3 protein.png'],
+                ['CELL-E_2_HPA_480',
+                 'MSKGEELFTGVVPILVELDGDVNGHKFSVSGEGEGDATYGKLTLKFICTTGKLPVPWPTLVTTFSYGVQCFSRYPDHMKQHDFFKSAMPEGYVQERTIFFKDDGNYKTRAEVKFEGDTLVNRIELKGIDFKEDGNILGHKLEYNYNSHNVYIMADKQKNGIKVNFKIRHNIEDGSVQLADHYQQNTPIGDGPVLLPDNHYLSTQSALSKDPNEKRDHMVLLEFVTAAGITHGMDELYK',
+                 'images/Armadillo repeat-containing X-linked protein 5 nucleus.jpg',
+                 'images/Armadillo repeat-containing X-linked protein 5 protein.jpg']]
+
+# demo = gr.Interface(gradio_demo, inputs, outputs, examples, cache_examples=True, layout = layout)
+demo.launch(share=True)
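For context, a minimal sketch (not part of this commit) of how the new gradio_demo function could be smoke-tested outside the Blocks UI. It assumes the HuangLab checkpoints are reachable on the Hub and that the bundled example image in images/ exists locally; the function and file names come from the diff above.

# Hypothetical local check of gradio_demo; not part of the committed app.py.
from PIL import Image

nucleus = Image.open("images/Armadillo repeat-containing X-linked protein 5 nucleus.jpg")
gfp = "MSKGEELFTGVVPILVELDGDVNGHKFSVSGEGEGDATYGKLTLKFICTTGKLPVPWPTLVTTFSYGVQCFSRYPDHMKQHDFFKSAMPEGYVQERTIFFKDDGNYKTRAEVKFEGDTLVNRIELKGIDFKEDGNILGHKLEYNYNSHNVYIMADKQKNGIKVNFKIRHNIEDGSVQLADHYQQNTPIGDGPVLLPDNHYLSTQSALSKDPNEKRDHMVLLEFVTAAGITHGMDELYK"

# Returns (cropped nucleus, protein threshold, predicted threshold, predicted heatmap) as PIL images.
nucleus_crop, protein_thresh, pred_thresh, pred_heatmap = gradio_demo(
    "CELL-E_2_HPA_480", gfp, nucleus, None
)
pred_heatmap.save("predicted_heatmap.png")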
celle/celle.py ADDED
@@ -0,0 +1,1061 @@
+# Import necessary packages and modules
+from math import floor, ceil
+import torch
+from torch import nn
+import torch.nn.functional as F
+from axial_positional_embedding import AxialPositionalEmbedding
+from einops import rearrange
+from celle.utils import (
+    exists,
+    always,
+    eval_decorator,
+    gumbel_sample,
+    top_k,
+    gamma_func,
+    DivideMax,
+)
+from tqdm import tqdm
+
+# Import additional modules from within the codebase
+from celle.transformer import Transformer
+
+
+def generate_mask(gamma_func, batch_size, length, device):
+    # Get the number of `True` values in the mask for each batch element
+    num_true_values = floor(gamma_func(torch.rand(1)) * length)
+
+    # Generate a random sample of indices to set to `True` in the mask
+    # The number of indices in the sample is determined by `num_true_values`
+    indices = (
+        torch.rand((batch_size, length), device=device)
+        .topk(num_true_values, dim=1)
+        .indices
+    )
+
+    # Create a binary mask tensor with `True` values at the sampled indices
+    mask = torch.zeros((batch_size, length), dtype=torch.bool, device=device)
+    mask.scatter_(dim=1, index=indices, value=True)
+
+    return mask
+
+
+def match_batch_size(text, condition, image, batch_size):
+    """
+    This function ensures all inputs to the sample function have the same batch size.
+    """
+    if text.shape[0] != batch_size:
+        text = text.repeat(batch_size, 1)
+
+    if condition.shape[0] != batch_size:
+        condition = condition.repeat(batch_size, 1)
+
+    if image.shape[0] != batch_size:
+        image = image.repeat(batch_size, 1)
+
+    return text, condition, image
+
+
+def calc_unmask_probs(timestep, timesteps, gamma_func):
+    if timestep == 1 or timesteps == 1:
+        unmask_prob = 1
+    else:
+        unmask_prob = 1 - gamma_func(timestep)
+    return unmask_prob
+
+
+def calculate_logits(
+    input_tokens, input_mask, logits_function, filter_thres, temperature
+):
+    logits, _, _ = logits_function(input_tokens, input_mask, return_encoding=False)
+    filtered_logits = top_k(logits, thres=filter_thres)
+    sample = gumbel_sample(filtered_logits, temperature=temperature, dim=-1)
+
+    return logits, sample
+
+
+def unmask_tokens(
+    input_tokens,
+    input_mask,
+    num_masked_tokens,
+    logits,
+    sample,
+    timestep,
+    timesteps,
+    gamma,
+    filter_func=None,
+    pad_token=None,
+    mask_token=None,
+    force_aas=True,
+):
+    sample = sample.masked_fill(~input_mask.unsqueeze(-1), -torch.inf)
+    if filter_func:
+        sample = filter_func(
+            input_tokens, sample, force_aas, pad_token=pad_token, mask_token=mask_token
+        )
+    selected_token_probs, selected_tokens = torch.max(sample, dim=-1)
+
+    unmask_prob = calc_unmask_probs(timestep, timesteps, gamma)
+    num_tokens_to_unmask = max(1, ceil(unmask_prob * num_masked_tokens))
+
+    _, top_k_indices = torch.topk(selected_token_probs, num_tokens_to_unmask, dim=-1)
+
+    sample_mask = torch.zeros(
+        input_tokens.shape, dtype=torch.bool, device=input_tokens.device
+    )
+    sample_mask.scatter_(dim=1, index=top_k_indices, value=True)
+
+    unmasked_tokens = torch.where(sample_mask, selected_tokens, input_tokens)
+    full_logits = torch.where(
+        sample_mask.unsqueeze(-1), logits, torch.zeros_like(logits)
+    )
+    return unmasked_tokens, full_logits
+
+
+def suppress_invalid_text_tokens(
+    text,
+    logits,
+    start_token=None,
+    end_token=None,
+    pad_token=None,
+    mask_token=None,
+    force_aas=False,
+):
+    # Find the indices of start_token and end_token in tensor text along axis=1
+    idx_start = (text == start_token).nonzero(as_tuple=True)[1]
+    idx_end = (text == end_token).nonzero(as_tuple=True)[1]
+
+    # For every position other than the index corresponding to the start index, set the values on the start index of dimension=2 to -torch.inf
+    if idx_start.nelement() != start_token:
+        try:
+            mask = idx_start.unsqueeze(1) != torch.arange(
+                logits.size(1), device=text.device
+            )
+            indices = torch.where(mask)
+            logits[indices[0], indices[1], start_token] = -torch.inf
+        except:
+            pass
+
+    # else:
+    #     idx_start = torch.zeros(text.size(0), dtype=torch.long)
+
+    # Similarly, for every position other than the index corresponding to the end index, set the values on the end index of dimension=2 to -torch.inf
+    if idx_end.nelement() != 0:
+        try:
+            mask = idx_end.unsqueeze(1) != torch.arange(
+                logits.size(1), device=text.device
+            )
+            indices = torch.where(mask)
+            logits[indices[0], indices[1], end_token] = -torch.inf
+        except:
+            pass
+
+    # else:
+    #     idx_end = torch.full((text.size(0),), text.size(1) - 1, dtype=torch.long)
+
+    if pad_token:
+        if idx_start.nelement() != 0 and idx_end.nelement() != 0:
+            try:
+                # For every position between the indices of start_token and end_token, set the values for 1st index of dimension=2 equal to -torch.inf. Any value outside of that range should be set to torch.inf.
+                mask = (
+                    torch.arange(logits.size(1), device=text.device)
+                    >= idx_start.unsqueeze(1)
+                ) & (
+                    torch.arange(logits.size(1), device=text.device)
+                    <= idx_end.unsqueeze(1)
+                )
+
+                indices = torch.where(mask)
+                logits[indices[0], indices[1], pad_token] = -torch.inf
+
+                indices = torch.where(~mask)
+                logits[indices[0], indices[1], pad_token] = torch.inf
+
+            except:
+                pass
+
+        elif idx_start.nelement() != 0:
+            try:
+                mask = torch.arange(
+                    logits.size(1), device=text.device
+                ) < idx_start.unsqueeze(1)
+                logits[indices[0], indices[1], pad_token] = torch.inf
+            except:
+                pass
+
+        elif idx_end.nelement() != 0:
+            try:
+                mask = torch.arange(
+                    logits.size(1), device=text.device
+                ) > idx_end.unsqueeze(1)
+                logits[indices[0], indices[1], pad_token] = torch.inf
+            except:
+                pass
+
+    if force_aas:
+        if pad_token:
+            logits[:, :, pad_token] = -torch.inf
+        logits[:, :, 3] = -torch.inf
+        logits[:, :, 29:] = -torch.inf
+
+    if mask_token:
+        logits[:, :, mask_token] = -torch.inf
+
+    return logits
+
+
+def detokenize_text(text_embedding, sequence):
+    if text_embedding == "esm1b" or text_embedding == "esm2":
+        from esm import Alphabet
+
+        alphabet = (
+            Alphabet.from_architecture("ESM-1b").get_batch_converter().alphabet.all_toks
+        )
+    else:
+        assert NameError("Detokenization only available for ESM models")
+
+    output_seqs = []
+
+    for batch in sequence:
+        converted_seq = [alphabet[idx] for idx in batch]
+        converted_seq = "".join(converted_seq)
+        output_seqs.append(converted_seq)
+
+    return output_seqs
+
+class ImageEmbedding(nn.Module):
+    def __init__(self, num_tokens, dim):
+        super(ImageEmbedding, self).__init__()
+        self.image_embedding = nn.Embedding(num_tokens, dim)
+
+    def forward(self, image):
+        return self.image_embedding(image)
+
+
+class ModelExtender(nn.Module):
+    def __init__(self, vocab, out_features, fixed_embedding=False):
+        super(ModelExtender, self).__init__()
+
+        # Initialize the model according to the given vocabulary
+        self.vocab = vocab
+
+        if vocab == "esm1b":
+            from esm import pretrained
+
+            self.model, _ = pretrained.esm1b_t33_650M_UR50S()
+            self.in_features = 1280
+        elif vocab == "esm2":
+            from esm import pretrained
+
+            if out_features == 320:
+                self.model, _ = pretrained.esm2_t6_8M_UR50D()
+            elif out_features == 480:
+                self.model, _ = pretrained.esm2_t12_35M_UR50D()
+            elif out_features == 640:
+                self.model, _ = pretrained.esm2_t30_150M_UR50D()
+            elif out_features == 1280:
+                self.model, _ = pretrained.esm2_t33_650M_UR50D()
+            elif out_features == 2560:
+                self.model, _ = pretrained.esm2_t36_3B_UR50D()
+            else:
+                self.model, _ = pretrained.esm2_t33_650M_UR50D()
+            self.in_features = self.model.embed_dim
+
+        # Set the number of output features and initialize the scaling layer
+        self.out_features = out_features
+        self.scale_layer = nn.Linear(self.in_features, self.out_features)
+
+        # Determine whether to freeze the model's parameters
+        self.fixed_embedding = fixed_embedding
+        if self.fixed_embedding:
+            self.model = self.model.eval()
+
+    def forward(self, x, **kwargs):
+        # If the model's parameters are fixed, use torch.no_grad()
+        if self.fixed_embedding:
+            with torch.no_grad():
+                if self.vocab == "esm1b" or self.vocab == "esm2":
+                    # Reduce sequence length dimension, get top layer representation tensor
+                    x = self.model(x.squeeze(1), repr_layers=[self.model.num_layers])[
+                        "representations"
+                    ][self.model.num_layers]
+                    # Tensor shape: (batch_size, hidden_size)
+                else:
+                    # Get top layer representation tensor
+                    x = self.model(x, **kwargs)[0]
+                    # Tensor shape: (batch_size, sequence_length, hidden_size)
+        else:
+            if self.vocab == "esm1b" or self.vocab == "esm2":
+                # Reduce sequence length dimension, get top layer representation tensor
+                x = self.model(x.squeeze(1), repr_layers=[self.model.num_layers])[
+                    "representations"
+                ][self.model.num_layers]
+                # Tensor shape: (batch_size, hidden_size)
+            else:
+                # Get top layer representation tensor
+                x = self.model(x, **kwargs)[0]
+                # Tensor shape: (batch_size, sequence_length, hidden_size)
+
+        # Scale the representation tensor if necessary
+        if self.out_features != self.in_features:
+            x = self.scale_layer(x)
+            # Tensor shape: (batch_size, out_features)
+
+        return x
+
+class CELLE(nn.Module):
+    def __init__(
+        self,
+        *,
+        dim,
+        vae,  # The VAE model used to encode/decode images
+        condition_vae=None,  # An optional VAE model used to condition the image generation
+        num_images=2,  # Number of images to generate
+        num_text_tokens=30,  # Number of tokens in the text vocabulary
+        text_seq_len=1000,  # Maximum length of input text sequence
+        depth=16,  # Number of layers in the transformer model
+        heads=16,  # Number of attention heads
+        dim_head=64,  # Dimensionality of each attention head
+        attn_dropout=0.1,  # Dropout rate for attention weights
+        ff_dropout=0.1,  # Dropout rate for feedforward layers
+        attn_types=None,  # Types of attention to use in the transformer
+        causal=False,  # Whether to use causal attention
+        loss_cond_weight=1,  # Weight of conditioning loss
+        loss_img_weight=1,  # Weight of image generation loss
+        stable=False,  # Whether to use divide-by-max normalization in the transformer
+        rotary_emb=True,  # Whether to use rotary positional embeddings
+        text_embedding="esm2",  # Text embedding to use (esm1b, esm2)
+        fixed_embedding=True,  # Whether to fix the text embedding or learn it
+        sampling_mode="cosine",  # Sampling mode for the VAE
+        linear_project=False,  # Whether to project embeddings linearly
+        **kwargs,
+    ):
+        super().__init__()
+
+        # Set the stable flag
+        self.stable = stable
+
+        # If the stable flag is set, initialize the DivideMax layer for normalization
+        if stable:
+            self.norm_by_max = DivideMax(dim=-1)
+
+        ### Initializing text parameters ###
+
+        # Initialize the text and fixed embeddings
+        self.text_embedding = text_embedding
+        self.fixed_embedding = fixed_embedding
+
+        # Offset logits index and calculate cross entropy loss
+        self.num_text_tokens = num_text_tokens
+        self.linear_project = linear_project
+
+        # Add <BOS> and <EOS> tokens to the beginning and end of text sequences
+        if text_embedding.lower() in ("esm1b", "esm2"):
+            self.text_seq_len = text_seq_len + 2
+        else:
+            self.text_seq_len = text_seq_len
+
+        # Initialize embeddings for <SEP> token
+        self.sep_emb = nn.Embedding(1, dim)
+
+        # Initialize positional embeddings for text sequences and <SEP> token
+        self.text_pos_emb = (
+            nn.Embedding(self.text_seq_len + 1, dim) if not rotary_emb else always(0)
+        )  # +1 for <SEP>
+
+        ### ###
+
+        self.num_images = num_images
+
+        ### Initializing condition parameters ###
+
+        # Initialize the number of condition tokens, condition sequence length, and condition embedding
+        if exists(condition_vae):
+            condition_size = condition_vae.image_size
+            num_condition_tokens = condition_vae.num_tokens
+            self.num_condition_tokens = num_condition_tokens
+            condition_fmap_size = condition_vae.image_size // (
+                2**condition_vae.num_layers
+            )
+            condition_seq_len = condition_fmap_size**2
+
+            # Initialize ImageEmbedding for condition embedding
+            self.condition_emb = ImageEmbedding(num_condition_tokens + 1, dim)
+
+            # Initialize positional embeddings for condition embedding
+            self.condition_pos_emb = (
+                AxialPositionalEmbedding(
+                    dim, axial_shape=(condition_fmap_size, condition_fmap_size)
+                )
+                if not rotary_emb
+                else always(0)
+            )
+
+        else:
+            condition_fmap_size = 0
+            condition_seq_len = 0
+            num_condition_tokens = 0
+
+        ### ####
+
+        ### Initializing image parameters ###
+
+        # Initialize the image size, image token size, and sequence length
+        self.image_size = vae.image_size
+        num_image_tokens = vae.num_tokens
+        image_fmap_size = vae.image_size // (2**vae.num_layers)
+        image_seq_len = image_fmap_size**2
+        self.image_seq_len = image_seq_len
+        self.num_image_tokens = num_image_tokens
+
+        # Initialize ImageEmbedding and positional embeddings for image embedding
+        self.image_emb = ImageEmbedding(num_image_tokens + 1, dim)  # +1 for <IM_MASK>
+
+        self.image_pos_emb = (
+            AxialPositionalEmbedding(
+                dim, axial_shape=(image_fmap_size, image_fmap_size)
+            )
+            if not rotary_emb
+            else always(0)
+        )
+
+        # Set total sequence length and total tokens
+        self.num_condition_tokens = num_condition_tokens
+        self.condition_seq_len = condition_seq_len
+        # Text Length + <SEP> + Condition Tokens + Image Tokens
+        seq_len = self.text_seq_len + 1 + self.condition_seq_len + self.image_seq_len
+        total_tokens = (
+            num_text_tokens + 1 + num_condition_tokens + 1 + num_image_tokens + 1
+        )
+        self.total_tokens = total_tokens
+        self.total_seq_len = seq_len
+
+        # Set the VAE and condition VAE for the model
+        self.vae = vae.eval()
+        self.condition_vae = condition_vae.eval()
+
+        ### ###
+
+        ### Setting discrete ids ###
+        # Initialize text embedding based on the given text_embedding parameter
+        if text_embedding == "esm1b" or text_embedding == "esm2":
+            self.text_mask_token = 32
+            self.pad_token = 1
+            self.text_emb = ModelExtender(text_embedding, dim, fixed_embedding)
+        else:
+            raise ValueError("Only ESM models are supported.")
+
+        # Set token indices for text, condition, and image sequences
+        self.sep_token = num_text_tokens
+        self.cond_mask_token = num_condition_tokens
+        self.image_mask_token = num_image_tokens
+
+        # Create indices for sequence and logits dimensions
+        self.seq_range = torch.arange(seq_len)
+        self.logits_range = torch.arange(total_tokens)
+
+        # Reshape sequence and logits indices
+        self.seq_range = rearrange(self.seq_range, "n -> () n ()")
+        self.logits_range = rearrange(self.logits_range, "d -> () () d")
+
+        # Create a mask to exclude invalid token positions from the model output
+        # e.g. no image tokens where sequence tokens should be
+        logits_mask = (
+            # Mask text tokens beyond text_seq_len and invalid logits_range
+            (
+                (self.seq_range < self.text_seq_len)
+                & (self.logits_range < num_text_tokens)
+                & (self.logits_range != self.text_mask_token)
+            )
+            |
+            # Mask [SEP] token after text
+            (
+                (self.seq_range == self.text_seq_len)
+                & (self.logits_range == num_text_tokens)
+            )
+            |
+            # Mask condition tokens beyond text_seq_len+1 ([SEP]) and invalid logits_range
+            (
+                (self.seq_range >= self.text_seq_len + 1)
+                & (self.seq_range < self.text_seq_len + 1 + condition_seq_len)
+                & (self.logits_range >= num_text_tokens + 1)
+                & (self.logits_range < num_text_tokens + 1 + num_condition_tokens)
+            )
+            |
+            # Mask image tokens beyond num_text_tokens+num_condition_tokens+1
+            (
+                (self.seq_range >= self.text_seq_len + 1 + condition_seq_len)
+                & (self.logits_range >= num_text_tokens + 1 + num_condition_tokens + 1)
+                & (
+                    self.logits_range
+                    < num_text_tokens + 1 + num_condition_tokens + 1 + num_image_tokens
+                )
+            )
+        )
+
+        # Invert the mask
+        logits_mask = ~logits_mask
+
+        # Register the buffer with the logits_mask
+        self.register_buffer("logits_mask", logits_mask, persistent=False)
+
+        ### ###
+
+        # Initialize the Transformer model with given parameters
+        self.transformer = Transformer(
+            dim=dim,
+            causal=causal,
+            seq_len=seq_len,
+            depth=depth,
+            heads=heads,
+            dim_head=dim_head,
+            attn_dropout=attn_dropout,
+            ff_dropout=ff_dropout,
+            image_fmap_size=image_fmap_size + condition_fmap_size,
+            num_images=num_images,
+            stable=stable,
+            rotary_emb=rotary_emb,
+        )
+
+        # Initialize the linear layers for converting transformer output to logits
+        self.to_logits = nn.Sequential(
+            nn.LayerNorm(dim),
+            nn.Linear(dim, self.total_tokens),
+        )
+
+        # Set instance variables for weights and critic
+        self.loss_img_weight = loss_img_weight
+        self.loss_cond_weight = loss_cond_weight
+        self.gamma = gamma_func(sampling_mode)
+
+    def embed_and_transform(self, inputs, masks, return_encoding=False):
+        text, condition, image = inputs
+        device = text.device
+        text_mask, _, image_mask = masks
+
+        text_labels = text.clone()
+        text = torch.where(
+            text_mask, self.text_mask_token * torch.ones_like(text, device=device), text
+        )
+
+        tokens = self.text_emb(text)
+
+        # Add SEP token
+
+        sep_token_emb = self.sep_emb(
+            torch.zeros((tokens.shape[0], 1), dtype=torch.long, device=device)
+        )
+        tokens = torch.cat((tokens, sep_token_emb), dim=1)
+        tokens += self.text_pos_emb(torch.arange(text.shape[1] + 1, device=device))
+
+        with torch.no_grad():
+            if self.linear_project:
+                b = condition.shape[0]
+                condition, _, [_, _, condition_labels] = self.condition_vae.encode(
+                    condition
+                )
+                condition_labels = rearrange(condition_labels, "(b n) -> b n", b=b)
+
+            else:
+                condition_labels = condition
+                if condition.dtype == torch.float:
+                    condition_labels = self.condition_vae.get_codebook_indices(
+                        condition
+                    )
+                condition = condition_labels.clone()
+
+        condition_emb = self.condition_emb(condition)
+        condition_emb += self.condition_pos_emb(condition_emb)
+        tokens = torch.cat((tokens, condition_emb), dim=1)
+
+        with torch.no_grad():
+            if self.linear_project:
+                b = image.shape[0]
+                image, _, [_, _, image_labels] = self.vae.encode(image)
+                image_labels = rearrange(image_labels, "(b n) -> b n", b=b)
+
+            else:
+                image_labels = image
+                if image.dtype == torch.float:
+                    image_labels = self.vae.get_codebook_indices(image)
+                image = torch.where(
+                    image_mask,
+                    self.image_mask_token
+                    * torch.ones_like(image_labels, device=device),
+                    image_labels,
+                )
+
+        image_emb = self.image_emb(image)
+
+        image_emb += self.image_pos_emb(image_emb)
+        tokens = torch.cat((tokens, image_emb), dim=1)
+
+        if self.stable:
+            alpha = 0.1
+            tokens = tokens * alpha + tokens.detach() * (1 - alpha)
+
+        out = self.transformer(tokens)
+
+        if self.stable:
+            out = self.norm_by_max(out)
+
+        logits = self.to_logits(out)
+
+        max_neg_value = -torch.finfo(logits.dtype).max
+        logits.masked_fill_(self.logits_mask, max_neg_value)
+
+        if return_encoding:
+            return logits, out, [text_labels, condition_labels, image_labels]
+        else:
+            return logits, None, [text_labels, condition_labels, image_labels]
+
+    def forward(
+        self,
+        text,
+        condition=None,
+        image=None,
+        return_loss=False,
+        return_encoding=False,
+    ):
+        batch_size, device = text.shape[0], text.device
+
+        # Check that image is supplied when training
+        assert exists(image), "when training, image must be supplied"
+
+        # Check that image dimensions match the expected dimensions
+        assert tuple(image.shape[1:]) == (
+            self.vae.channels,
+            self.image_size,
+            self.image_size,
+        ), f"invalid image of dimensions {image.shape} passed in during training"
+
+        # Generate masks for text, condition, and image
+
+        # text_mask = generate_mask(self.gamma, batch_size, self.text_seq_len, device)
+
+        text_mask = generate_mask(
+            gamma_func("scaled-cosine"), batch_size, self.text_seq_len, device
+        )
+
+        image_mask = generate_mask(self.gamma, batch_size, self.image_seq_len, device)
+
+        # Embed and transform inputs
+        logits, _, labels = self.embed_and_transform(
+            [text, condition, image],
+            [text_mask, None, image_mask],
+            return_encoding,
+            device,
+        )
+
+        # If not returning loss, return the logits
+        if not return_loss:
+            return logits
+
+        # Separate labels
+        text, condition, image = labels
+
+        # Add SEP token to end of text label
+        sep_token = torch.tensor(self.sep_token, device=device).repeat(
+            labels.shape[0], 1
+        )
+        labels = torch.cat([labels, sep_token], dim=1)
+
+        # If condition exists and condition vae is defined, add the condition to the labels
+        if exists(condition) and exists(self.condition_vae):
+            offsetted_condition = condition + self.num_text_tokens + 1
+            labels = torch.cat((labels, offsetted_condition), dim=1)
+
+        # Add image to the labels
+        offsetted_image = (
+            image + self.num_text_tokens + 1 + self.num_condition_tokens + 1
+        )
+        labels = torch.cat((labels, offsetted_image), dim=1)
+
+        # Rearrange logits for cross-entropy loss calculation
+        # Logits size: (batch_size, vocab_size, total_seq_len)
+        # Labels size: (batch_size, total_seq_len)
+        logits = rearrange(logits, "b n c -> b c n")
+
+        # Calculate cross-entropy loss for text and image
+        loss_text = F.cross_entropy(
+            logits[:, :, : self.text_seq_len],
+            labels[:, : self.text_seq_len],
+            reduction="none",
+        )[text_mask].mean()
+
+        loss_img = F.cross_entropy(
+            logits[:, :, self.text_seq_len + 1 + self.condition_seq_len :],
+            labels[:, self.text_seq_len + 1 + self.condition_seq_len :],
+            reduction="none",
+        )[image_mask].mean()
+
+        # Calculate total loss
+        loss = (loss_text + self.loss_img_weight * loss_img) / (
+            self.loss_img_weight + 1
+        )
+
+        loss_dict = {
+            "loss_text": loss_text,
+            # "loss_cond": loss_cond,
+            "loss_img": loss_img,
+            "loss": torch.nan_to_num(loss, 0.0, 0.0, 0.0),
+        }
+
+        return loss, loss_dict, None
+
+    def create_tensors(self, text, condition, image):
+        """
+        This function creates tensors for text, condition, and image when they are not provided as inputs to the sample function.
+        """
+        device = next(
+            filter(lambda x: isinstance(x, torch.Tensor), [text, condition, image]),
+            None,
+        ).device
+
+        if not isinstance(text, torch.Tensor):
+            text = (
+                torch.ones(1, self.text_seq_len, device=device, dtype=torch.long)
+                * self.text_mask_token
+            )
+
+        if not isinstance(condition, torch.Tensor):
+            condition = (
+                torch.ones(1, self.condition_seq_len, device=device, dtype=torch.long)
+                * self.cond_mask_token
+            )
+        else:
+            with torch.no_grad():
+                condition = self.condition_vae.get_codebook_indices(condition)
+
+        if not isinstance(image, torch.Tensor):
+            image = (
+                torch.ones(1, self.image_seq_len, device=device, dtype=torch.long)
+                * self.image_mask_token
+            )
+        else:
+            with torch.no_grad():
+                image = self.vae.get_codebook_indices(image)
+
+        return text, condition, image
+
+    @torch.no_grad()
+    @eval_decorator
+    def sample(
+        self,
+        text=None,
+        condition=None,
+        image=None,
+        temperature=1.0,
+        filter_thres=0.9,
+        progress=False,
+        timesteps=1,
+        force_aas=True,
+    ):
+        # ensure timesteps is a positive integer
+        assert int(timesteps) > 0
+        # set model and VAEs to evaluation mode
+        self.eval()
+        vae = self.vae.eval()
+        if progress == True:
+            progress = tqdm
+        else:
+            progress = lambda x: x
+
+        # ensure that at least one of text, condition, or image is supplied
+        assert (
+            isinstance(text, torch.Tensor)
+            or isinstance(condition, torch.Tensor)
+            or isinstance(image, torch.Tensor)
+        ), "some data must be supplied"
+
+        # convert text, condition, and image to tensors if they aren't already
+        text, condition, image = self.create_tensors(text, condition, image)
+
+        # determine the maximum batch size of the input tensors
+        batch_size = max(text.shape[0], condition.shape[0], image.shape[0])
+
+        # match the batch sizes of text, condition, and image
+        text, condition, image = match_batch_size(text, condition, image, batch_size)
+
+        # determine the device of the tensors
+        device = next(
+            filter(lambda x: isinstance(x, torch.Tensor), [text, condition, image]),
+            None,
+        ).device
+
+        assert text.shape[0] == condition.shape[0] == image.shape[0]
+
+        # Create a tensor of zeros of size (batch_size, image_seq_len, num_image_tokens + 1) and set it to device
+
+        # full_text_logits = torch.zeros(batch_size, self.text_seq_len, self.num_text_tokens+3).to(device)
+        full_text_logits = torch.zeros(
+            batch_size, self.text_seq_len, self.num_text_tokens
+        ).to(device)
+
+        # Use scatter_ to fill the tensor with 1 values at the indices given by the image tensor
+        full_text_logits = full_text_logits.scatter_(
+            dim=-1, index=text.unsqueeze(-1), value=1
+        )
+        # Use scatter_ to fill the tensor with 1 values at the indices given by the image tensor
+        full_image_logits = torch.zeros(
+            batch_size, self.image_seq_len, self.num_image_tokens + 1
+        ).to(device)
+
+        # Remove the last token from each image sequence by setting full_image_logits to its first num_image_tokens elements
+        full_image_logits = full_image_logits.scatter_(
+            dim=-1, index=image.unsqueeze(-1), value=1
+        )
+
+        # cut off mask token
+        full_image_logits = full_image_logits[:, :, : self.num_image_tokens]
+
+        count = 0
+
+        for timestep in progress(torch.linspace(0, 1, timesteps)):
+            # Create masks for the text, condition, and image tensors
+            text_mask = text == self.text_mask_token
+            cond_mask = condition == self.cond_mask_token
+            image_mask = image == self.image_mask_token
+
+            # Calculate logits and samples using the calculate_logits function
+            logits, sample = calculate_logits(
+                [text, condition, image],
+                [text_mask, cond_mask, image_mask],
+                self.embed_and_transform,
+                filter_thres,
+                temperature,
+            )
+
+            # Calculate the number of masked tokens in the text and image tensors
+            num_masked_text_tokens = torch.sum(text_mask, dim=1)[0]
+            num_masked_image_tokens = torch.sum(image_mask, dim=1)[0]
+
+            # If there are masked text tokens, unmask them using unmask_tokens and fill the full text logits tensor with -inf for unmasked tokens
+            if num_masked_text_tokens.any() > 0:
+                text, full_text_logits = unmask_tokens(
+                    text,
+                    text_mask,
+                    num_masked_text_tokens,
+                    logits[:, : self.text_seq_len, : self.num_text_tokens],
+                    sample[:, : self.text_seq_len, : self.num_text_tokens],
+                    timestep,
+                    timesteps,
+                    self.gamma,
+                    suppress_invalid_text_tokens,
+                    self.pad_token,
+                    self.text_mask_token,
+                    force_aas=force_aas,
+                )
+                full_text_logits = full_text_logits.masked_fill(
+                    ~text_mask.unsqueeze(-1), -torch.inf
+                )
+
+            # If there are masked image tokens, unmask them using unmask_tokens and fill the full image logits tensor with -inf for unmasked tokens
+            if num_masked_image_tokens > 0:
+                image, full_image_logits = unmask_tokens(
+                    image,
+                    image_mask,
+                    num_masked_image_tokens,
+                    logits[:, -self.image_seq_len :, -(self.num_image_tokens + 1) : -1],
+                    sample[:, -self.image_seq_len :, -(self.num_image_tokens + 1) : -1],
+                    timestep,
+                    timesteps,
+                    self.gamma,
+                )
+                full_text_logits = full_text_logits.masked_fill(
+                    ~text_mask.unsqueeze(-1), -torch.inf
+                )
+
+        # Generate heatmap
+        with torch.no_grad():
+            # Normalize full image logits tensor
+            full_image_logits /= torch.max(
+                torch.abs(full_image_logits), dim=-1, keepdim=True
+            ).values
+
+            # Apply quantize embedding to full image logits tensor
+            full_image_logits = torch.matmul(
+                full_image_logits, self.vae.model.quantize.embedding.weight
+            )
+
+            # Rearrange full image logits tensor
+            h = int(self.image_seq_len**0.5)
+            full_image_logits = rearrange(
+                full_image_logits, "b (h w) c -> b c h w", h=h
+            )
+
+            # Decode full image logits tensor
+            full_image_logits = self.vae.model.decode(full_image_logits)
+
+            # Add clipping to full image logits tensor
+            max_val = torch.max(full_image_logits.view(batch_size, -1), dim=-1)[0]
+            min_val = torch.min(full_image_logits.view(batch_size, -1), dim=-1)[0]
+            full_image_logits += torch.clip(1 - max_val, 0, float("inf")).view(
+                batch_size, 1, 1, 1
+            )
+            full_image_logits += torch.clip(0 - min_val, float("-inf"), 0).view(
+                batch_size, 1, 1, 1
+            )
+
+            # Clip full image logits tensor values to the range [0, 1]
+            full_image_logits = torch.clip(full_image_logits, 0, 1)
+
+        # Return text tensor, detokenized text tensor, full text logits tensor,
+        # binary image tensor, and full image logits tensor
+        return (
+            text,
+            detokenize_text(self.text_embedding, text),
+            full_text_logits,
+            1.0 * (vae.decode(image) > 0.5),
+            full_image_logits,
+        )
+
+    @torch.no_grad()
+    @eval_decorator
+    def sample_text(
+        self,
+        text=False,
+        condition=False,
+        image=False,
+        temperature=1.0,
+        filter_thres=0.9,
+        progress=False,
+        n_unmask=1,
+        place_amino=True,
+        force_aas=False,
+    ):
+        # set model and VAEs to evaluation mode
+        self.eval()
+
+        # ensure that at least one of text, condition, or image is supplied
+        assert (
+            isinstance(text, torch.Tensor)
+            or isinstance(condition, torch.Tensor)
+            or isinstance(image, torch.Tensor)
+        ), "some data must be supplied"
+
+        # convert text, condition, and image to tensors if they aren't already
+        text, condition, image = self.create_tensors(text, condition, image)
+
+        # determine the maximum batch size of the input tensors
+        batch_size = max(text.shape[0], condition.shape[0], image.shape[0])
+
+        # match the batch sizes of text, condition, and image
+        text, condition, image = match_batch_size(text, condition, image, batch_size)
+
+        # determine the device of the tensors
+        device = next(
+            filter(lambda x: isinstance(x, torch.Tensor), [text, condition, image]),
+            None,
+        ).device
+
+        assert text.shape[0] == condition.shape[0] == image.shape[0]
+
+        # Create a tensor of zeros of size (batch_size, image_seq_len, num_image_tokens + 1) and set it to device
+
+        # full_text_logits = torch.zeros(batch_size, self.text_seq_len, self.num_text_tokens+3).to(device)
+        full_text_logits = torch.zeros(
+            batch_size, self.text_seq_len, self.num_text_tokens
+        ).to(device)
+
+        # Use scatter_ to fill the tensor with 1 values at the indices given by the image tensor
+        full_text_logits = full_text_logits.scatter_(
+            dim=-1, index=text.unsqueeze(-1), value=1
+        )
+
+        text_mask = text == self.text_mask_token
+        cond_mask = condition == self.cond_mask_token
+        image_mask = image == self.image_mask_token
+
+        mask_indices = text_mask.nonzero()
+        non_mask_indices = (~text_mask).nonzero()
+
+        # figure out the center of the amino acids to determine generation direction
+        central_protein_index = torch.tensor(
+            [
+                torch.median(
+                    non_mask_indices[torch.where(non_mask_indices[:, 0] == idx)][:, -1]
+                )
+                for idx in range(batch_size)
+            ]
+        )
+
+        count = 1
+
+        run_mask = text_mask
+        if progress:
+            pbar = progress(total=torch.sum(run_mask).item())
+        while torch.sum(run_mask) > 0:
+            logits, sample = calculate_logits(
+                [text, condition, image],
+                [text_mask, cond_mask, image_mask],
+                self.embed_and_transform,
+                filter_thres,
+                temperature,
+            )
+
+            # sub_sample: [batch_size, text_seq_len, num_text_tokens]
+            sub_sample = sample[:, : self.text_seq_len, : self.num_text_tokens]
+            sub_sample = sub_sample.masked_fill(~text_mask.unsqueeze(-1), -torch.inf)
+            sub_sample = suppress_invalid_text_tokens(
+                text, sub_sample, 0, 2, self.pad_token, self.text_mask_token, force_aas
+            )
+            # calculate % to unmasked
+            # get most likely token and probability for each position
+
+            for idx in range(batch_size):
+                selected_mask_indices = mask_indices[
+                    torch.where(mask_indices[:, 0] == idx)
+                ][:, -1]
+
+                # Generate to the left
+                if selected_mask_indices[-count] < central_protein_index[idx]:
+                    unmask_index = selected_mask_indices[-count]
+                    left_sample = max(0, (unmask_index + 1) - n_unmask)
+                    right_sample = min(unmask_index + 1, self.text_seq_len - 1)
+                    central_protein_index[idx] = max(
+                        0, central_protein_index[idx] - 0.5 * n_unmask
+                    )
+
+                # Generate to the right
+                elif selected_mask_indices[count - 1] > central_protein_index[idx]:
+                    unmask_index = selected_mask_indices[count - 1]
+                    left_sample = max(0, unmask_index)
+                    right_sample = min(unmask_index + n_unmask, self.text_seq_len - 1)
+                    central_protein_index[idx] = min(
+                        central_protein_index[idx] + 0.5 * n_unmask,
+                        self.text_seq_len - 1,
+                    )
+
+                # save logits for relevant position
+                full_text_logits[
+                    idx, left_sample:right_sample, : self.text_seq_len - 1
+                ] = logits[idx, left_sample:right_sample, : self.num_text_tokens]
+
+                run_mask[idx, left_sample:right_sample] = False
+
+                # you may want to resample the amino acids or calculate marginal probs
+                # if so, set place_amino to false
+                if place_amino:
+                    text[idx, left_sample:right_sample] = torch.where(
+                        text[idx, left_sample:right_sample] == self.text_mask_token,
+                        sub_sample[
+                            idx, left_sample:right_sample, : self.num_text_tokens
+                        ].argmax(dim=-1),
+                        text[idx, left_sample:right_sample],
+                    )
+
+            text_mask = run_mask
+
+            count += n_unmask
+
+            if progress:
+                pbar.update(n_unmask)
+        if progress:
+            pbar.close()
+
+        return (
+            text,
+            detokenize_text(self.text_embedding, text),
+            full_text_logits,
+        )
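As a side note (not part of the commit), the sample loop above reveals masked tokens iteratively: gamma_func supplies the masking schedule and calc_unmask_probs turns it into the fraction of still-masked tokens revealed at each timestep. A small self-contained sketch of that schedule, assuming the default cosine mode and an illustrative 16 x 16 image-token grid:

# Standalone sketch of the unmasking schedule used by CELLE.sample (illustrative only).
import torch
from math import ceil

def cosine_gamma(r):
    # mirrors gamma_func("cosine") from celle/utils.py
    return torch.cos(r * torch.pi / 2)

num_masked = 256   # e.g. a fully masked 16 x 16 grid of image tokens (assumed size)
timesteps = 5
for t in torch.linspace(0, 1, timesteps):
    # same rule as calc_unmask_probs: reveal everything on the final step
    unmask_prob = 1.0 if float(t) == 1.0 else 1.0 - float(cosine_gamma(t))
    n_unmask = max(1, ceil(unmask_prob * num_masked))
    print(f"t={float(t):.2f}: unmask {n_unmask} of {num_masked} masked tokens")
    num_masked = max(num_masked - n_unmask, 0)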
celle/utils.py
ADDED
@@ -0,0 +1,230 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import torch
|
2 |
+
from torchvision import transforms
|
3 |
+
from PIL import Image, ImageSequence
|
4 |
+
from math import pi
|
5 |
+
import torchvision.transforms.functional as TF
|
6 |
+
|
7 |
+
|
8 |
+
# Define helper functions
|
9 |
+
def exists(val):
|
10 |
+
"""Check if a variable exists"""
|
11 |
+
return val is not None
|
12 |
+
|
13 |
+
|
14 |
+
def uniq(arr):
|
15 |
+
return {el: True for el in arr}.keys()
|
16 |
+
|
17 |
+
|
18 |
+
def default(val, d):
|
19 |
+
"""If a value exists, return it; otherwise, return a default value"""
|
20 |
+
return val if exists(val) else d
|
21 |
+
|
22 |
+
|
23 |
+
def max_neg_value(t):
|
24 |
+
return -torch.finfo(t.dtype).max
|
25 |
+
|
26 |
+
|
27 |
+
def cast_tuple(val, depth=1):
|
28 |
+
if isinstance(val, list):
|
29 |
+
val = tuple(val)
|
30 |
+
return val if isinstance(val, tuple) else (val,) * depth
|
31 |
+
|
32 |
+
|
33 |
+
def is_empty(t):
|
34 |
+
"""Check if a tensor is empty"""
|
35 |
+
# Return True if the number of elements in the tensor is zero, else False
|
36 |
+
return t.nelement() == 0
|
37 |
+
|
38 |
+
|
39 |
+
def masked_mean(t, mask, dim=1):
|
40 |
+
"""
|
41 |
+
Compute the mean of a tensor, masked by a given mask
|
42 |
+
|
43 |
+
Args:
|
44 |
+
t (torch.Tensor): input tensor of shape (batch_size, seq_len, hidden_dim)
|
45 |
+
mask (torch.Tensor): mask tensor of shape (batch_size, seq_len)
|
46 |
+
dim (int): dimension along which to compute the mean (default=1)
|
47 |
+
|
48 |
+
Returns:
|
49 |
+
torch.Tensor: masked mean tensor of shape (batch_size, hidden_dim)
|
50 |
+
"""
|
51 |
+
t = t.masked_fill(~mask[:, :, None], 0.0)
|
52 |
+
return t.sum(dim=1) / mask.sum(dim=1)[..., None]
|
53 |
+
|
54 |
+
|
55 |
+
def set_requires_grad(model, value):
|
56 |
+
"""
|
57 |
+
Set whether or not the model's parameters require gradients
|
58 |
+
|
59 |
+
Args:
|
60 |
+
model (torch.nn.Module): the PyTorch model to modify
|
61 |
+
value (bool): whether or not to require gradients
|
62 |
+
"""
|
63 |
+
for param in model.parameters():
|
64 |
+
param.requires_grad = value
|
65 |
+
|
66 |
+
|
67 |
+
def eval_decorator(fn):
|
68 |
+
"""
|
69 |
+
Decorator function to evaluate a given function
|
70 |
+
|
71 |
+
Args:
|
72 |
+
fn (callable): function to evaluate
|
73 |
+
|
74 |
+
Returns:
|
75 |
+
callable: the decorated function
|
76 |
+
"""
|
77 |
+
|
78 |
+
def inner(model, *args, **kwargs):
|
79 |
+
was_training = model.training
|
80 |
+
model.eval()
|
81 |
+
out = fn(model, *args, **kwargs)
|
82 |
+
model.train(was_training)
|
83 |
+
return out
|
84 |
+
|
85 |
+
return inner
|
86 |
+
|
87 |
+
|
88 |
+
def log(t, eps=1e-20):
|
89 |
+
"""
|
90 |
+
Compute the natural logarithm of a tensor
|
91 |
+
|
92 |
+
Args:
|
93 |
+
t (torch.Tensor): input tensor
|
94 |
+
eps (float): small value to add to prevent taking the log of 0 (default=1e-20)
|
95 |
+
|
96 |
+
Returns:
|
97 |
+
torch.Tensor: the natural logarithm of the input tensor
|
98 |
+
"""
|
99 |
+
return torch.log(t + eps)
|
100 |
+
|
101 |
+
|
102 |
+
def gumbel_noise(t):
|
103 |
+
"""
|
104 |
+
Generate Gumbel noise
|
105 |
+
|
106 |
+
Args:
|
107 |
+
t (torch.Tensor): input tensor
|
108 |
+
|
109 |
+
Returns:
|
110 |
+
torch.Tensor: a tensor of Gumbel noise with the same shape as the input tensor
|
111 |
+
"""
|
112 |
+
noise = torch.zeros_like(t).uniform_(0, 1)
|
113 |
+
return -log(-log(noise))
|
114 |
+
|
115 |
+
|
116 |
+
def gumbel_sample(t, temperature=0.9, dim=-1):
|
117 |
+
"""
|
118 |
+
Sample from a Gumbel-softmax distribution
|
119 |
+
|
120 |
+
Args:
|
121 |
+
t (torch.Tensor): input tensor of shape (batch_size, num_classes)
|
122 |
+
temperature (float): temperature for the Gumbel-softmax distribution (default=0.9)
|
123 |
+
dim (int): dimension along which to sample (default=-1)
|
124 |
+
|
125 |
+
Returns:
|
126 |
+
torch.Tensor: a tensor of samples from the Gumbel-softmax distribution with the same shape as the input tensor
|
127 |
+
"""
|
128 |
+
return (t / max(temperature, 1e-10)) + gumbel_noise(t)
|
129 |
+
|
130 |
+
|
131 |
+
def top_k(logits, thres=0.5):
|
132 |
+
"""
|
133 |
+
Return a tensor where all but the top k values are set to negative infinity
|
134 |
+
|
135 |
+
Args:
|
136 |
+
logits (torch.Tensor): input tensor of shape (batch_size, num_classes)
|
137 |
+
thres (float): threshold for the top k values (default=0.5)
|
138 |
+
|
139 |
+
Returns:
|
140 |
+
torch.Tensor: a tensor with the same shape as the input tensor, where all but the top k values are set to negative infinity
|
141 |
+
"""
|
142 |
+
num_logits = logits.shape[-1]
|
143 |
+
k = max(int((1 - thres) * num_logits), 1)
|
144 |
+
val, ind = torch.topk(logits, k)
|
145 |
+
probs = torch.full_like(logits, float("-inf"))
|
146 |
+
probs.scatter_(-1, ind, val)
|
147 |
+
return probs
|
148 |
+
|
149 |
+
|
150 |
+
def gamma_func(mode="cosine", scale=0.15):
|
151 |
+
"""Return a function that takes a single input r and returns a value based on the selected mode"""
|
152 |
+
|
153 |
+
# Define a different function based on the selected mode
|
154 |
+
if mode == "linear":
|
155 |
+
return lambda r: 1 - r
|
156 |
+
elif mode == "cosine":
|
157 |
+
return lambda r: torch.cos(r * pi / 2)
|
158 |
+
elif mode == "square":
|
159 |
+
return lambda r: 1 - r**2
|
160 |
+
elif mode == "cubic":
|
161 |
+
return lambda r: 1 - r**3
|
162 |
+
elif mode == "scaled-cosine":
|
163 |
+
return lambda r: scale * (torch.cos(r * pi / 2))
|
164 |
+
else:
|
165 |
+
# Raise an error if the selected mode is not implemented
|
166 |
+
raise NotImplementedError
|
167 |
+
|
168 |
+
|
169 |
+
class always:
|
170 |
+
"""Helper class to always return a given value"""
|
171 |
+
|
172 |
+
def __init__(self, val):
|
173 |
+
self.val = val
|
174 |
+
|
175 |
+
def __call__(self, x, *args, **kwargs):
|
176 |
+
return self.val
|
177 |
+
|
178 |
+
|
179 |
+
class DivideMax(torch.nn.Module):
|
180 |
+
def __init__(self, dim):
|
181 |
+
super().__init__()
|
182 |
+
self.dim = dim
|
183 |
+
|
184 |
+
def forward(self, x):
|
185 |
+
maxes = x.amax(dim=self.dim, keepdim=True).detach()
|
186 |
+
return x / maxes
|
187 |
+
|
188 |
+
def replace_outliers(image, percentile=0.0001):
|
189 |
+
|
190 |
+
lower_bound, upper_bound = torch.quantile(image, percentile), torch.quantile(
|
191 |
+
image, 1 - percentile
|
192 |
+
)
|
193 |
+
mask = (image <= upper_bound) & (image >= lower_bound)
|
194 |
+
|
195 |
+
valid_pixels = image[mask]
|
196 |
+
|
197 |
+
image[~mask] = torch.clip(image[~mask], min(valid_pixels), max(valid_pixels))
|
198 |
+
|
199 |
+
return image
|
200 |
+
|
201 |
+
|
202 |
+
def process_image(image, dataset, image_type=None):
|
203 |
+
image = TF.to_tensor(image)[0].unsqueeze(0).unsqueeze(0)
|
204 |
+
image /= image.max()
|
205 |
+
|
206 |
+
if dataset == "HPA":
|
207 |
+
if image_type == 'nucleus':
|
208 |
+
normalize = (0.0655, 0.0650)
|
209 |
+
|
210 |
+
elif image_type == 'protein':
|
211 |
+
normalize = (0.1732, 0.1208)
|
212 |
+
|
213 |
+
elif dataset == "OpenCell":
|
214 |
+
|
215 |
+
if image_type == 'nucleus':
|
216 |
+
normalize = (0.0272, 0.0244)
|
217 |
+
|
218 |
+
elif image_type == 'protein':
|
219 |
+
normalize = (0.0486, 0.0671)
|
220 |
+
|
221 |
+
t_forms = []
|
222 |
+
|
223 |
+
t_forms.append(transforms.RandomCrop(256))
|
224 |
+
|
225 |
+
# t_forms.append(transforms.Normalize(normalize[0],normalize[1]))
|
226 |
+
|
227 |
+
|
228 |
+
image = transforms.Compose(t_forms)(image)
|
229 |
+
|
230 |
+
return image
|
dataloader.py
ADDED
@@ -0,0 +1,308 @@
import os
import numpy as np
from PIL import Image, ImageSequence
import json
import pandas as pd

import torch
from torch.utils.data import Dataset
from torchvision import transforms
import torchvision.transforms.functional as TF

from celle.utils import replace_outliers


def simple_conversion(seq):
    """Create 26-dim embedding"""
    chars = [
        "-",
        "M",
        "R",
        "H",
        "K",
        "D",
        "E",
        "S",
        "T",
        "N",
        "Q",
        "C",
        "U",
        "G",
        "P",
        "A",
        "V",
        "I",
        "F",
        "Y",
        "W",
        "L",
        "O",
        "X",
        "Z",
        "B",
        "J",
    ]

    nums = range(len(chars))

    seqs_x = np.zeros(len(seq))

    for idx, char in enumerate(seq):
        lui = chars.index(char)
        seqs_x[idx] = nums[lui]

    return torch.tensor([seqs_x]).long()


class CellLoader(Dataset):
    """imports mined opencell images with protein sequence"""

    def __init__(
        self,
        data_csv=None,
        dataset=None,
        split_key=None,
        resize=600,
        crop_size=600,
        crop_method="random",
        sequence_mode="simple",
        vocab="bert",
        threshold="median",
        text_seq_len=0,
        pad_mode="random",
    ):
        self.data_csv = data_csv
        self.dataset = dataset
        self.image_folders = []
        self.crop_method = crop_method
        self.resize = resize
        self.crop_size = crop_size
        self.sequence_mode = sequence_mode
        self.threshold = threshold
        self.text_seq_len = int(text_seq_len)
        self.vocab = vocab
        self.pad_mode = pad_mode

        if self.sequence_mode == "embedding" or self.sequence_mode == "onehot":
            if self.vocab == "esm1b" or self.vocab == "esm2":
                from esm import Alphabet

                self.tokenizer = Alphabet.from_architecture(
                    "ESM-1b"
                ).get_batch_converter()
                self.text_seq_len += 2

        if data_csv:
            data = pd.read_csv(data_csv)

            self.parent_path = os.path.dirname(data_csv).split(data_csv)[0]

            if split_key == "train":
                self.data = data[data["split"] == "train"]
            elif split_key == "val":
                self.data = data[data["split"] == "val"]
            else:
                self.data = data

            self.data = self.data.reset_index(drop=True)

    def __len__(self):
        return len(self.data)

    def __getitem__(
        self,
        idx,
        get_sequence=True,
        get_images=True,
    ):
        if get_sequence and self.text_seq_len > 0:
            protein_vector = self.get_protein_vector(idx)
        else:
            protein_vector = torch.zeros((1, 1))

        if get_images:
            nucleus, target, threshold = self.get_images(idx, self.dataset)
        else:
            nucleus, target, threshold = torch.zeros((3, 1))

        data_dict = {
            "nucleus": nucleus.float(),
            "target": target.float(),
            "threshold": threshold.float(),
            "sequence": protein_vector.long(),
        }

        return data_dict

    def get_protein_vector(self, idx):
        if "protein_sequence" not in self.data.columns:
            metadata = self.retrieve_metadata(idx)
            protein_sequence = metadata["sequence"]
        else:
            protein_sequence = self.data.iloc[idx]["protein_sequence"]

        protein_vector = self.tokenize_sequence(protein_sequence)

        return protein_vector

    def get_images(self, idx, dataset):
        if dataset == "HPA":
            nucleus = Image.open(
                os.path.join(
                    self.parent_path, self.data.iloc[idx]["nucleus_image_path"]
                )
            )

            target = Image.open(
                os.path.join(self.parent_path, self.data.iloc[idx]["target_image_path"])
            )

            nucleus = TF.to_tensor(nucleus)[0]
            target = TF.to_tensor(target)[0]

            image = torch.stack([nucleus, target], axis=0)

            normalize = (0.0655, 0.0650), (0.1732, 0.1208)

        elif dataset == "OpenCell":
            image = Image.open(
                os.path.join(self.parent_path, self.data.iloc[idx]["image_path"])
            )
            nucleus, target = [page.copy() for page in ImageSequence.Iterator(image)]

            nucleus = replace_outliers(torch.divide(TF.to_tensor(nucleus), 65536))[0]
            target = replace_outliers(torch.divide(TF.to_tensor(target), 65536))[0]

            image = torch.stack([nucleus, target], axis=0)

            normalize = (
                (0.0272, 0.0244),
                (0.0486, 0.0671),
            )

        # from https://discuss.pytorch.org/t/how-to-apply-same-transform-on-a-pair-of-picture/14914

        t_forms = [transforms.Resize(self.resize, antialias=None)]

        if self.crop_method == "random":
            t_forms.append(transforms.RandomCrop(self.crop_size))
            t_forms.append(transforms.RandomHorizontalFlip(p=0.5))
            t_forms.append(transforms.RandomVerticalFlip(p=0.5))

        elif self.crop_method == "center":
            t_forms.append(transforms.CenterCrop(self.crop_size))

        t_forms.append(transforms.Normalize(normalize[0], normalize[1]))

        image = transforms.Compose(t_forms)(image)

        nucleus, target = image

        nucleus /= torch.abs(nucleus).max()
        target -= target.min()
        target /= target.max()

        nucleus = nucleus.unsqueeze(0)
        target = target.unsqueeze(0)

        threshold = target

        if self.threshold == "mean":
            threshold = 1.0 * (threshold > (torch.mean(threshold)))

        elif self.threshold == "median":
            threshold = 1.0 * (threshold > (torch.median(threshold)))

        elif self.threshold == "1090_IQR":
            p10 = torch.quantile(threshold, 0.1, None)
            p90 = torch.quantile(threshold, 0.9, None)
            threshold = torch.clip(threshold, p10, p90)

        nucleus = torch.nan_to_num(nucleus, 0.0, 1.0, 0.0)
        target = torch.nan_to_num(target, 0.0, 1.0, 0.0)
        threshold = torch.nan_to_num(threshold, 0.0, 1.0, 0.0)

        return nucleus, target, threshold

    def retrieve_metadata(self, idx):
        with open(
            os.path.join(self.parent_path, self.data.iloc[idx]["metadata_path"])
        ) as f:
            metadata = json.load(f)

        return metadata

    def tokenize_sequence(self, protein_sequence):
        pad_token = 0

        if self.sequence_mode == "simple":
            protein_vector = simple_conversion(protein_sequence)

        elif self.sequence_mode == "center":
            protein_sequence = protein_sequence.center(self.text_seq_len, "-")
            protein_vector = simple_conversion(protein_sequence)

        elif self.sequence_mode == "alternating":
            protein_sequence = protein_sequence.center(self.text_seq_len, "-")
            protein_sequence = protein_sequence[::18]
            protein_sequence = protein_sequence.center(
                int(self.text_seq_len / 18) + 1, "-"
            )
            protein_vector = simple_conversion(protein_sequence)

        elif self.sequence_mode == "embedding":
            if self.vocab == "esm1b" or self.vocab == "esm2":
                pad_token = 1
                protein_vector = self.tokenizer([("", protein_sequence)])[-1]

        if protein_vector.shape[-1] < self.text_seq_len:
            diff = self.text_seq_len - protein_vector.shape[-1]

            if self.pad_mode == "end":
                protein_vector = torch.nn.functional.pad(
                    protein_vector, (0, diff), "constant", pad_token
                )
            elif self.pad_mode == "random":
                split = diff - np.random.randint(0, diff + 1)

                protein_vector = torch.cat(
                    [torch.ones(1, split) * 0, protein_vector], dim=1
                )

                protein_vector = torch.nn.functional.pad(
                    protein_vector, (0, diff - split), "constant", pad_token
                )

        elif protein_vector.shape[-1] > self.text_seq_len:
            start_int = np.random.randint(
                0, protein_vector.shape[-1] - self.text_seq_len
            )

            protein_vector = protein_vector[
                :, start_int : start_int + self.text_seq_len
            ]

        return protein_vector.long()
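
Usage sketch (illustrative only, not part of the committed files): the CSV path and batch size below are assumptions; the constructor arguments and the returned dictionary keys follow the CellLoader definition above.

from torch.utils.data import DataLoader
from dataloader import CellLoader

# Hypothetical CSV path; the file is expected to provide a "split" column plus the
# image/metadata path columns used in get_images() and retrieve_metadata().
dataset = CellLoader(
    data_csv="data/HPA_data.csv",
    dataset="HPA",
    split_key="train",
    resize=600,
    crop_size=256,
    crop_method="random",
    sequence_mode="embedding",
    vocab="esm2",
    threshold="median",
    text_seq_len=1000,
)

loader = DataLoader(dataset, batch_size=4, shuffle=True)
batch = next(iter(loader))
# batch["nucleus"], batch["target"], batch["threshold"] are (B, 1, crop, crop) image tensors;
# batch["sequence"] holds the padded/truncated tokenized protein sequence.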
requirements.txt
ADDED
@@ -0,0 +1,13 @@
os
torch
torchvision
huggingface_hub
gradio
OmegaConf
axial-positional-embedding
einops
rotary_embedding_torch
fair-esm
tqdm
importlib
pytorch-lightning==1.9.0