|
import os |
|
os.chdir('..') |
|
from dataloader import CellLoader |
|
from celle_main import instantiate_from_config |
|
from omegaconf import OmegaConf |
|
|
|
def run_image_prediction(
    sequence_input,
    nucleus_image,
    model_ckpt_path,
    model_config_path,
    device
):
    """
    Sample a protein-localization prediction from a Celle model.

    The function tokenizes the sequence with a ``CellLoader``, builds the
    model described by the config file, runs a single sampling step
    conditioned on the nucleus image, and returns the two predicted maps.
    Nothing is displayed or written to disk here.

    :param sequence_input: Sequence handed to ``CellLoader.tokenize_sequence``
        (presumably a raw amino-acid string; confirm against the loader).
    :param nucleus_image: Conditioning input passed as ``condition`` to
        ``model.celle.sample`` (presumably an image tensor already on
        ``device``; this function does not move it).
    :param model_ckpt_path: Checkpoint path, used only when the config's
        ``model.params.ckpt_path`` entry is ``None``.
    :param model_config_path: Path to the model config read by
        ``OmegaConf.load``.
    :param device: Target passed to ``.to(device)`` on the instantiated model.
    :return: Tuple ``(predicted_threshold, predicted_heatmap)``; each is the
        ``[0, 0]`` element of the corresponding sampler output after it has
        been moved to the CPU.
    """
    # The loader is created only for its tokenizer; no dataset is read here.
    loader_options = dict(
        sequence_mode="embedding",
        vocab="esm2",
        split_key="val",
        crop_method="center",
        resize=600,
        crop_size=256,
        text_seq_len=1000,
        pad_mode="end",
        threshold="median",
    )
    tokenizer_source = CellLoader(**loader_options)
    tokens = tokenizer_source.tokenize_sequence(sequence_input)

    cfg = OmegaConf.load(model_config_path)
    model_params = cfg["model"]["params"]

    # A checkpoint path already present in the config wins over the argument.
    if model_params["ckpt_path"] is None:
        model_params["ckpt_path"] = model_ckpt_path

    # Clear the auxiliary model paths before instantiation (presumably so the
    # sub-models are not loaded from separate files -- TODO confirm).
    model_params["condition_model_path"] = None
    model_params["vqgan_model_path"] = None

    net = instantiate_from_config(cfg).to(device)

    # The sampler returns five values; only the last two are used.
    sampled = net.celle.sample(
        text=tokens,
        condition=nucleus_image,
        timesteps=1,
        temperature=1,
        progress=True,
    )
    _, _, _, threshold_out, heatmap_out = sampled

    # Move to CPU, then keep element [0, 0] of each output.
    return threshold_out.cpu()[0, 0], heatmap_out.cpu()[0, 0]