File size: 1,493 Bytes
dacc4bf
 
 
5d2263b
 
dacc4bf
5d2263b
 
5cbd5ac
2346297
5d2263b
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
860c3d7
 
5d2263b
 
860c3d7
5d2263b
 
 
 
 
 
5cbd5ac
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
import os
# NOTE(review): this changes the process working directory to the parent
# directory as an import-time side effect. It appears to exist so that the
# `dataloader` import below resolves (e.g. when sys.path contains the cwd,
# as in notebooks) — confirm before moving or removing it.
os.chdir('..')
# Absolute path of the working directory after the chdir above; not
# referenced by the code shown in this file.
base_dir = os.getcwd()
from dataloader import CellLoader


def run_image_prediction(
    sequence_input,
    nucleus_image,
    model,
    device,
    *,
    timesteps=1,
    temperature=1,
):
    """
    Run the Celle model on one sequence and nucleus image and return its predictions.

    The sequence is tokenized with a ``CellLoader`` configured for ESM-2
    embeddings, then ``model.celle.sample`` is called with the tokenized
    sequence as ``text`` and the nucleus image as ``condition``.

    :param sequence_input: Raw sequence passed to
        ``CellLoader.tokenize_sequence`` (presumably an amino-acid string —
        confirm against the dataloader).
    :param nucleus_image: Tensor used as the sampler's conditioning image;
        moved to ``device`` before sampling. Assumed to carry a leading batch
        dimension, since only the ``[0, 0]`` slice of each output is returned.
    :param model: Object exposing a ``celle`` attribute with a ``sample`` method.
    :param device: Target device that the inputs are moved to.
    :param timesteps: Forwarded to ``model.celle.sample`` (default 1).
    :param temperature: Forwarded to ``model.celle.sample`` (default 1).
    :return: Tuple ``(predicted_threshold, predicted_heatmap)`` of CPU tensors,
        each the ``[0, 0]`` slice (presumably first batch item, first channel)
        of the corresponding sampler output.
    """
    # Local import: torch is not imported at module level in this file.
    import torch

    # NOTE(review): the loader is rebuilt on every call only to tokenize the
    # sequence; if its construction is expensive, consider building it once.
    dataset = CellLoader(
        sequence_mode="embedding",
        vocab="esm2",
        split_key="val",
        crop_method="center",
        resize=600,
        crop_size=256,
        text_seq_len=1000,
        pad_mode="end",
        threshold="median",
    )

    # Convert the raw sequence into the tensor the model expects as `text`.
    sequence = dataset.tokenize_sequence(sequence_input)

    # Inference only: disable autograd so no graph is retained for the outputs.
    with torch.no_grad():
        _, _, _, predicted_threshold, predicted_heatmap = model.celle.sample(
            text=sequence.to(device),
            condition=nucleus_image.to(device),
            timesteps=timesteps,
            temperature=temperature,
            progress=False,
        )

    # Slice before moving to the CPU so only the [0, 0] plane is copied to
    # host memory, instead of transferring the full batch and then indexing.
    predicted_threshold = predicted_threshold[0, 0].cpu()
    predicted_heatmap = predicted_heatmap[0, 0].cpu()

    return predicted_threshold, predicted_heatmap