davidjwen commited on
Commit
9151b3e
·
1 Parent(s): 5488d17

Update geneformer/in_silico_perturber.py

Browse files
Files changed (1) hide show
  1. geneformer/in_silico_perturber.py +17 -17
geneformer/in_silico_perturber.py CHANGED
@@ -604,7 +604,7 @@ class InSilicoPerturber:
604
  "filter_data": {None, dict},
605
  "cell_states_to_model": {None, dict},
606
  "max_ncells": {None, int},
607
- "inds_to_perturb": {"all", dict},
608
  "emb_layer": {-1, 0},
609
  "forward_batch_size": {int},
610
  "nproc": {int},
@@ -623,7 +623,7 @@ class InSilicoPerturber:
623
  filter_data=None,
624
  cell_states_to_model=None,
625
  max_ncells=None,
626
- inds_to_perturb="all",
627
  emb_layer=-1,
628
  forward_batch_size=100,
629
  nproc=4,
@@ -689,9 +689,9 @@ class InSilicoPerturber:
689
  max_ncells : None, int
690
  Maximum number of cells to test.
691
  If None, will test all cells.
692
- inds_to_perturb : "all", list
693
  Default is perturbing each cell in the dataset.
694
- Otherwise, may provide a dict of indices of genes to perturb with keys start_ind and end_ind.
695
  start_ind: the first index to perturb.
696
  end_ind: the last index to perturb (exclusive).
697
  Indices will be selected *after* the filter_data criteria and sorting.
@@ -732,7 +732,7 @@ class InSilicoPerturber:
732
  self.filter_data = filter_data
733
  self.cell_states_to_model = cell_states_to_model
734
  self.max_ncells = max_ncells
735
- self.inds_to_perturb = inds_to_perturb
736
  self.emb_layer = emb_layer
737
  self.forward_batch_size = forward_batch_size
738
  self.nproc = nproc
@@ -908,15 +908,15 @@ class InSilicoPerturber:
908
  "Values in filter_data dict must be lists. " \
909
  f"Changing {key} value to list ([{value}]).")
910
 
911
- if self.inds_to_perturb != "all":
912
- if set(self.inds_to_perturb.keys()) != {"start", "end"}:
913
  logger.error(
914
- "If inds_to_perturb is a dictionary, keys must be 'start' and 'end'."
915
  )
916
  raise
917
- if self.inds_to_perturb["start"] < 0 or self.inds_to_perturb["end"] < 0:
918
  logger.error(
919
- 'inds_to_perturb must be positive.'
920
  )
921
  raise
922
 
@@ -1017,15 +1017,15 @@ class InSilicoPerturber:
1017
  cos_sims_dict = defaultdict(list)
1018
  pickle_batch = -1
1019
  filtered_input_data = downsample_and_sort(filtered_input_data, self.max_ncells)
1020
- if self.inds_to_perturb != "all":
1021
- if self.inds_to_perturb["start"] >= len(filtered_input_data):
1022
- logger.error("inds_to_perturb['start'] is larger than the filtered dataset.")
1023
  raise
1024
- if self.inds_to_perturb["end"] > len(filtered_input_data):
1025
- logger.warning("inds_to_perturb['end'] is larger than the filtered dataset. \
1026
  Setting to the end of the filtered dataset.")
1027
- self.inds_to_perturb["end"] = len(filtered_input_data)
1028
- filtered_input_data = filtered_input_data.select([i for i in range(self.inds_to_perturb["start"], self.inds_to_perturb["end"])])
1029
 
1030
  # make perturbation batch w/ single perturbation in multiple cells
1031
  if self.perturb_group == True:
 
604
  "filter_data": {None, dict},
605
  "cell_states_to_model": {None, dict},
606
  "max_ncells": {None, int},
607
+ "cell_inds_to_perturb": {"all", dict},
608
  "emb_layer": {-1, 0},
609
  "forward_batch_size": {int},
610
  "nproc": {int},
 
623
  filter_data=None,
624
  cell_states_to_model=None,
625
  max_ncells=None,
626
+ cell_inds_to_perturb="all",
627
  emb_layer=-1,
628
  forward_batch_size=100,
629
  nproc=4,
 
689
  max_ncells : None, int
690
  Maximum number of cells to test.
691
  If None, will test all cells.
692
+ cell_inds_to_perturb : "all", list
693
  Default is perturbing each cell in the dataset.
694
+ Otherwise, may provide a dict of indices of cells to perturb with keys start_ind and end_ind.
695
  start_ind: the first index to perturb.
696
  end_ind: the last index to perturb (exclusive).
697
  Indices will be selected *after* the filter_data criteria and sorting.
 
732
  self.filter_data = filter_data
733
  self.cell_states_to_model = cell_states_to_model
734
  self.max_ncells = max_ncells
735
+ self.cell_inds_to_perturb = cell_inds_to_perturb
736
  self.emb_layer = emb_layer
737
  self.forward_batch_size = forward_batch_size
738
  self.nproc = nproc
 
908
  "Values in filter_data dict must be lists. " \
909
  f"Changing {key} value to list ([{value}]).")
910
 
911
+ if self.cell_inds_to_perturb != "all":
912
+ if set(self.cell_inds_to_perturb.keys()) != {"start", "end"}:
913
  logger.error(
914
+ "If cell_inds_to_perturb is a dictionary, keys must be 'start' and 'end'."
915
  )
916
  raise
917
+ if self.cell_inds_to_perturb["start"] < 0 or self.cell_inds_to_perturb["end"] < 0:
918
  logger.error(
919
+ 'cell_inds_to_perturb must be positive.'
920
  )
921
  raise
922
 
 
1017
  cos_sims_dict = defaultdict(list)
1018
  pickle_batch = -1
1019
  filtered_input_data = downsample_and_sort(filtered_input_data, self.max_ncells)
1020
+ if self.cell_inds_to_perturb != "all":
1021
+ if self.cell_inds_to_perturb["start"] >= len(filtered_input_data):
1022
+ logger.error("cell_inds_to_perturb['start'] is larger than the filtered dataset.")
1023
  raise
1024
+ if self.cell_inds_to_perturb["end"] > len(filtered_input_data):
1025
+ logger.warning("cell_inds_to_perturb['end'] is larger than the filtered dataset. \
1026
  Setting to the end of the filtered dataset.")
1027
+ self.cell_inds_to_perturb["end"] = len(filtered_input_data)
1028
+ filtered_input_data = filtered_input_data.select([i for i in range(self.cell_inds_to_perturb["start"], self.cell_inds_to_perturb["end"])])
1029
 
1030
  # make perturbation batch w/ single perturbation in multiple cells
1031
  if self.perturb_group == True: