Emaad commited on
Commit
2346297
1 Parent(s): 9276a39

Update prediction.py

Browse files
Files changed (1) hide show
  1. prediction.py +1 -23
prediction.py CHANGED
@@ -1,15 +1,9 @@
1
- import os
2
- os.chdir('..')
3
- base_dir = os.getcwd()
4
  from dataloader import CellLoader
5
- from celle_main import instantiate_from_config
6
- from omegaconf import OmegaConf
7
 
8
  def run_image_prediction(
9
  sequence_input,
10
  nucleus_image,
11
- model_ckpt_path,
12
- model_config_path,
13
  device
14
  ):
15
  """
@@ -37,22 +31,6 @@ def run_image_prediction(
37
  # Convert SEQUENCE to sequence using dataset.tokenize_sequence()
38
  sequence = dataset.tokenize_sequence(sequence_input)
39
 
40
- # Load model config and set ckpt_path if not provided in config
41
- config = OmegaConf.load(model_config_path)
42
- if config["model"]["params"]["ckpt_path"] is None:
43
- config["model"]["params"]["ckpt_path"] = model_ckpt_path
44
-
45
- # Set condition_model_path and vqgan_model_path to None
46
- config["model"]["params"]["condition_model_path"] = None
47
- config["model"]["params"]["vqgan_model_path"] = None
48
-
49
- os.chdir(os.path.dirname(model_ckpt_path))
50
-
51
- # Instantiate model from config and move to device
52
- model = instantiate_from_config(config.model).to(device)
53
-
54
- os.chdir(base_dir)
55
-
56
  # Sample from model using provided sequence and nucleus image
57
  _, _, _, predicted_threshold, predicted_heatmap = model.celle.sample(
58
  text=sequence.to(device),
 
 
 
 
1
  from dataloader import CellLoader
 
 
2
 
3
  def run_image_prediction(
4
  sequence_input,
5
  nucleus_image,
6
+ model,
 
7
  device
8
  ):
9
  """
 
31
  # Convert SEQUENCE to sequence using dataset.tokenize_sequence()
32
  sequence = dataset.tokenize_sequence(sequence_input)
33
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
34
  # Sample from model using provided sequence and nucleus image
35
  _, _, _, predicted_threshold, predicted_heatmap = model.celle.sample(
36
  text=sequence.to(device),