|
import os |
|
# Step up one directory before the remaining module code runs. Presumably this
# is so the ``dataloader`` import below (and any relative paths CellLoader
# reads) resolve from the project root -- confirm.
# NOTE(review): the move is relative to whatever the CWD is at execution time,
# so executing this again (e.g. re-running a notebook cell) climbs one more
# directory level each time.
os.chdir(os.pardir)

# Absolute path of the directory just entered.
base_dir = os.getcwd()
|
from dataloader import CellLoader |
|
|
|
|
|
def _default_dataset():
    """Build the CellLoader used by ``run_image_prediction`` for tokenization.

    Only its ``tokenize_sequence`` method is used there; the constructor
    arguments are the fixed validation-split settings (ESM-2 vocabulary,
    embedding mode, 1000-token text length).
    """
    return CellLoader(
        sequence_mode="embedding",
        vocab="esm2",
        split_key="val",
        crop_method="center",
        resize=600,
        crop_size=256,
        text_seq_len=1000,
        pad_mode="end",
        threshold="median",
    )


def run_image_prediction(
    sequence_input,
    nucleus_image,
    model,
    device,
    *,
    dataset=None,
    timesteps=1,
    temperature=1,
):
    """Run the Celle model on one sequence and nucleus image; return its predictions.

    The sequence is tokenized with a ``CellLoader`` configured for ESM-2
    embeddings, then ``model.celle.sample`` is called once with the tokens as
    ``text`` and the nucleus image as ``condition``. Nothing is displayed or
    saved; the predictions are returned to the caller.

    :param sequence_input: Sequence handed to ``CellLoader.tokenize_sequence``
        (presumably a raw amino-acid string -- confirm against ``dataloader``).
    :param nucleus_image: Tensor passed as ``condition`` to
        ``model.celle.sample``; it is moved to ``device`` here. Assumed to
        carry a leading batch dimension -- TODO confirm expected shape.
    :param model: Object exposing ``model.celle.sample``.
    :param device: Device the tokenized sequence and nucleus image are moved
        to before sampling.
    :param dataset: Optional pre-built ``CellLoader`` to tokenize with. When
        omitted, a new one is constructed on every call; pass one in to reuse
        it across repeated predictions.
    :param timesteps: Forwarded to ``model.celle.sample`` (default ``1``).
    :param temperature: Forwarded to ``model.celle.sample`` (default ``1``).
    :return: ``(predicted_threshold, predicted_heatmap)``, the last two of the
        five outputs of ``model.celle.sample``, each indexed at ``[0, 0]``
        (presumably first batch item, first channel) and moved to the CPU.
    """
    import torch  # local import: only needed for the no-grad context below

    if dataset is None:
        dataset = _default_dataset()

    sequence = dataset.tokenize_sequence(sequence_input)

    # Inference only: disable autograd so no graph is built and the outputs do
    # not require grad (a no-op if ``sample`` already disables it itself).
    with torch.no_grad():
        _, _, _, predicted_threshold, predicted_heatmap = model.celle.sample(
            text=sequence.to(device),
            condition=nucleus_image.to(device),
            timesteps=timesteps,
            temperature=temperature,
            progress=False,
        )

    # Slice before copying so only the selected item is transferred to the
    # CPU rather than the whole output batch.
    predicted_threshold = predicted_threshold[0, 0].cpu()
    predicted_heatmap = predicted_heatmap[0, 0].cpu()

    return predicted_threshold, predicted_heatmap