|
import argparse |
|
import torch |
|
import os |
|
os.chdir('..') |
|
from dataloader import CellLoader |
|
from matplotlib import pyplot as plt |
|
from celle_main import instantiate_from_config |
|
from omegaconf import OmegaConf |
|
from celle.utils import process_image |
|
|
|
def run_model(mode, sequence,
              nucleus_image_path,
              protein_image_path,
              model_ckpt_path,
              model_config_path,
              device):
    """
    Dispatch a CELL-E prediction in either "image" or "sequence" mode.

    :param mode: ``"image"`` to run ``run_image_prediction`` or ``"sequence"``
        to run ``run_sequence_prediction``.
    :param sequence: Amino-acid sequence string forwarded to the selected
        prediction function.
    :param nucleus_image_path: Path to the nucleus image.
    :param protein_image_path: Path to the protein image (only used in
        ``"sequence"`` mode).
    :param model_ckpt_path: Path to the model checkpoint.
    :param model_config_path: Path to the model config.
    :param device: Device the model is moved to.
    :return: The return value of the selected prediction function.
    :raises ValueError: If ``mode`` is neither ``"image"`` nor ``"sequence"``.
    """
    if mode == "image":
        # run_image_prediction takes an already-processed nucleus image and no
        # protein image, so load the image here and pass only the five
        # arguments it accepts.
        nucleus_image = process_image(nucleus_image_path)
        return run_image_prediction(
            sequence,
            nucleus_image,
            model_ckpt_path,
            model_config_path,
            device,
        )

    if mode == "sequence":
        return run_sequence_prediction(
            sequence,
            nucleus_image_path,
            protein_image_path,
            model_ckpt_path,
            model_config_path,
            device,
        )

    # Fail loudly instead of silently doing nothing on a typo'd mode.
    raise ValueError(f"Unknown mode {mode!r}; expected 'image' or 'sequence'.")
|
|
|
def _resolve_image_path(image_path, default_path, label):
    """Return ``image_path`` if it exists, otherwise warn and return ``default_path``."""
    if image_path and os.path.exists(image_path):
        return image_path
    print(
        f"Warning: No {label} image provided. Using default {label} image from dataset."
    )
    return default_path


def _binarize_at_median(image):
    """Return a float 0/1 tensor marking values strictly above the tensor's median."""
    # torch.median(x) without ``dim`` returns a 0-d tensor; with ``dim=...`` it
    # returns a (values, indices) namedtuple, which cannot be compared to a tensor.
    return (image > torch.median(image)) * 1.0


def _highlight_differences(predicted, reference):
    """Render ``predicted`` with every position that differs from ``reference``
    (and any overhang beyond its length) wrapped in ``**``."""
    overlap = min(len(predicted), len(reference))
    parts = [
        f"**{predicted[i]}**" if predicted[i] != reference[i] else predicted[i]
        for i in range(overlap)
    ]
    if len(predicted) > len(reference):
        parts.append(f"**{predicted[len(reference):]}**")
    return "".join(parts)


def run_sequence_prediction(
    sequence_input,
    nucleus_image_path,
    protein_image_path,
    model_ckpt_path,
    model_config_path,
    device
):
    """
    Fill masked positions of a sequence with the CELL-E model and print the result.

    :param sequence_input: Amino-acid sequence string; positions to predict are
        marked with ``<mask>``.
    :param nucleus_image_path: Path to nucleus image; falls back to
        ``images/nucleus.jpg`` when missing or not found.
    :param protein_image_path: Path to protein image; falls back to
        ``images/protein.jpg`` when missing or not found.
    :param model_ckpt_path: Path to model checkpoint (used only when the config
        does not already set ``ckpt_path``).
    :param model_config_path: Path to the OmegaConf model config.
    :param device: Device the model and tensor inputs are moved to.
    :return: The predicted sequence with positions that differ from the
        tokenized input wrapped in ``**`` (also printed).
    :raises ValueError: If ``sequence_input`` is empty or None.
    """
    dataset = CellLoader(
        sequence_mode="embedding",
        vocab="esm2",
        split_key="val",
        crop_method="center",
        resize=600,
        crop_size=256,
        text_seq_len=1000,
        pad_mode="end",
        threshold="median",
    )

    if not sequence_input:
        raise ValueError("Sequence must be provided.")

    if "<mask>" not in sequence_input:
        print("Warning: Sequence does not contain any masked positions to predict.")

    sequence = dataset.tokenize_sequence(sequence_input)

    # Resolve the path first and then load, so the fallback image is actually
    # processed instead of leaving the image variable undefined.
    nucleus_image = process_image(
        _resolve_image_path(nucleus_image_path, 'images/nucleus.jpg', "nucleus")
    )
    # Binarize the protein image at its median whether it came from the caller
    # or from the fallback path.
    protein_image = _binarize_at_median(
        process_image(
            _resolve_image_path(protein_image_path, 'images/protein.jpg', "protein")
        )
    )

    config = OmegaConf.load(model_config_path)
    if config["model"]["params"]["ckpt_path"] is None:
        config["model"]["params"]["ckpt_path"] = model_ckpt_path

    config["model"]["params"]["condition_model_path"] = None
    config["model"]["params"]["vqgan_model_path"] = None

    model = instantiate_from_config(config).to(device)

    def on_device(value):
        # Inputs must live on the model's device; non-tensors pass through.
        return value.to(device) if isinstance(value, torch.Tensor) else value

    _, predicted_sequence, _ = model.celle.sample_text(
        text=on_device(sequence),
        condition=on_device(nucleus_image),
        image=on_device(protein_image),
        force_aas=True,
        timesteps=1,
        temperature=1,
        progress=True,
    )

    # NOTE(review): `sequence` is the tokenized input, so this compares the
    # prediction against tokenizer output — confirm both share one representation.
    formatted_predicted_sequence = _highlight_differences(predicted_sequence, sequence)

    print("predicted_sequence:", formatted_predicted_sequence)
    return formatted_predicted_sequence
|
|
|
|
|
def run_image_prediction(
    sequence_input,
    nucleus_image,
    model_ckpt_path,
    model_config_path,
    device
):
    """
    Predict protein-image outputs for a sequence with the CELL-E model.

    :param sequence_input: Amino-acid sequence string. When empty or None, a
        default GFP sequence is used instead.
    :param nucleus_image: Nucleus image to condition on — either an image
        already prepared with ``process_image`` or a path to one.
    :param model_ckpt_path: Path to model checkpoint (used only when the config
        does not already set ``ckpt_path``).
    :param model_config_path: Path to the OmegaConf model config.
    :param device: Device the model and tensor inputs are moved to.
    :return: ``(predicted_threshold, predicted_heatmap)`` — element ``[0, 0]``
        of the corresponding ``model.celle.sample`` outputs, moved to CPU.
    """
    dataset = CellLoader(
        sequence_mode="embedding",
        vocab="esm2",
        split_key="val",
        crop_method="center",
        resize=600,
        crop_size=256,
        text_seq_len=1000,
        pad_mode="end",
        threshold="median",
    )

    if not sequence_input:
        # Assign to sequence_input so the default is what gets tokenized below.
        sequence_input = "MSKGEELFTGVVPILVELDGDVNGHKFSVSGEGEGDATYGKLTLKFICTTGKLPVPWPTLVTTFSYGVQCFSRYPDHMKQHDFFKSAMPEGYVQERTIFFKDDGNYKTRAEVKFEGDTLVNRIELKGIDFKEDGNILGHKLEYNYNSHNVYIMADKQKNGIKVNFKIRHNIEDGSVQLADHYQQNTPIGDGPVLLPDNHYLSTQSALSKDPNEKRDHMVLLEFVTAAGITHGMDELYK"
        print("Warning: No sequence provided. Using default sequence for GFP.")

    sequence = dataset.tokenize_sequence(sequence_input)

    # Accept a path as well as a pre-processed image (backward-compatible).
    if isinstance(nucleus_image, (str, os.PathLike)):
        nucleus_image = process_image(nucleus_image)

    config = OmegaConf.load(model_config_path)
    if config["model"]["params"]["ckpt_path"] is None:
        config["model"]["params"]["ckpt_path"] = model_ckpt_path

    config["model"]["params"]["condition_model_path"] = None
    config["model"]["params"]["vqgan_model_path"] = None

    model = instantiate_from_config(config).to(device)

    def on_device(value):
        # Inputs must live on the model's device; non-tensors pass through.
        return value.to(device) if isinstance(value, torch.Tensor) else value

    _, _, _, predicted_threshold, predicted_heatmap = model.celle.sample(
        text=on_device(sequence),
        condition=on_device(nucleus_image),
        timesteps=1,
        temperature=1,
        progress=True,
    )

    predicted_threshold = predicted_threshold.cpu()[0, 0]
    predicted_heatmap = predicted_heatmap.cpu()[0, 0]

    return predicted_threshold, predicted_heatmap