ctheodoris davidjwen committed on
Commit
f115e8f
1 Parent(s): 78517d8

Added feature to perturb a set of indices to help with debugging and with very large runtimes (#175)

Browse files

- Added feature to perturb a set of indices to help with debugging and with very large runtimes (5488d176961bcee6c66ea1494151429fb570ba9c)
- Update geneformer/in_silico_perturber.py (9151b3e978c1c138039da1090be2147b4631d903)


Co-authored-by: David Wen <davidjwen@users.noreply.huggingface.co>

Files changed (1) hide show
  1. geneformer/in_silico_perturber.py +32 -1
geneformer/in_silico_perturber.py CHANGED
@@ -604,6 +604,7 @@ class InSilicoPerturber:
604
  "filter_data": {None, dict},
605
  "cell_states_to_model": {None, dict},
606
  "max_ncells": {None, int},
 
607
  "emb_layer": {-1, 0},
608
  "forward_batch_size": {int},
609
  "nproc": {int},
@@ -622,6 +623,7 @@ class InSilicoPerturber:
622
  filter_data=None,
623
  cell_states_to_model=None,
624
  max_ncells=None,
 
625
  emb_layer=-1,
626
  forward_batch_size=100,
627
  nproc=4,
@@ -687,6 +689,13 @@ class InSilicoPerturber:
687
  max_ncells : None, int
688
  Maximum number of cells to test.
689
  If None, will test all cells.
 
 
 
 
 
 
 
690
  emb_layer : {-1, 0}
691
  Embedding layer to use for quantification.
692
  -1: 2nd to last layer (recommended for pretrained Geneformer)
@@ -723,6 +732,7 @@ class InSilicoPerturber:
723
  self.filter_data = filter_data
724
  self.cell_states_to_model = cell_states_to_model
725
  self.max_ncells = max_ncells
 
726
  self.emb_layer = emb_layer
727
  self.forward_batch_size = forward_batch_size
728
  self.nproc = nproc
@@ -886,7 +896,7 @@ class InSilicoPerturber:
886
  if self.perturb_type in ["inhibit","activate"]:
887
  if self.perturb_rank_shift is None:
888
  logger.error(
889
- "If perturb type is inhibit or activate then " \
890
  "quartile to shift by must be specified.")
891
  raise
892
 
@@ -897,6 +907,18 @@ class InSilicoPerturber:
897
  logger.warning(
898
  "Values in filter_data dict must be lists. " \
899
  f"Changing {key} value to list ([{value}]).")
 
 
 
 
 
 
 
 
 
 
 
 
900
 
901
  def perturb_data(self,
902
  model_directory,
@@ -995,6 +1017,15 @@ class InSilicoPerturber:
995
  cos_sims_dict = defaultdict(list)
996
  pickle_batch = -1
997
  filtered_input_data = downsample_and_sort(filtered_input_data, self.max_ncells)
 
 
 
 
 
 
 
 
 
998
 
999
  # make perturbation batch w/ single perturbation in multiple cells
1000
  if self.perturb_group == True:
 
604
  "filter_data": {None, dict},
605
  "cell_states_to_model": {None, dict},
606
  "max_ncells": {None, int},
607
+ "cell_inds_to_perturb": {"all", dict},
608
  "emb_layer": {-1, 0},
609
  "forward_batch_size": {int},
610
  "nproc": {int},
 
623
  filter_data=None,
624
  cell_states_to_model=None,
625
  max_ncells=None,
626
+ cell_inds_to_perturb="all",
627
  emb_layer=-1,
628
  forward_batch_size=100,
629
  nproc=4,
 
689
  max_ncells : None, int
690
  Maximum number of cells to test.
691
  If None, will test all cells.
692
+ cell_inds_to_perturb : "all", dict
693
+ Default is perturbing each cell in the dataset.
694
+ Otherwise, may provide a dict of indices of cells to perturb with keys "start" and "end".
695
+ start: the first index to perturb.
696
+ end: the index at which to stop perturbing (exclusive).
697
+ Indices will be selected *after* the filter_data criteria and sorting.
698
+ Useful for splitting extremely large datasets across separate GPUs.
699
  emb_layer : {-1, 0}
700
  Embedding layer to use for quantification.
701
  -1: 2nd to last layer (recommended for pretrained Geneformer)
 
732
  self.filter_data = filter_data
733
  self.cell_states_to_model = cell_states_to_model
734
  self.max_ncells = max_ncells
735
+ self.cell_inds_to_perturb = cell_inds_to_perturb
736
  self.emb_layer = emb_layer
737
  self.forward_batch_size = forward_batch_size
738
  self.nproc = nproc
 
896
  if self.perturb_type in ["inhibit","activate"]:
897
  if self.perturb_rank_shift is None:
898
  logger.error(
899
+ "If perturb_type is inhibit or activate then " \
900
  "quartile to shift by must be specified.")
901
  raise
902
 
 
907
  logger.warning(
908
  "Values in filter_data dict must be lists. " \
909
  f"Changing {key} value to list ([{value}]).")
910
+
911
+ if self.cell_inds_to_perturb != "all":
912
+ if set(self.cell_inds_to_perturb.keys()) != {"start", "end"}:
913
+ logger.error(
914
+ "If cell_inds_to_perturb is a dictionary, keys must be 'start' and 'end'."
915
+ )
916
+ raise
917
+ if self.cell_inds_to_perturb["start"] < 0 or self.cell_inds_to_perturb["end"] < 0:
918
+ logger.error(
919
+ 'cell_inds_to_perturb must be positive.'
920
+ )
921
+ raise
922
 
923
  def perturb_data(self,
924
  model_directory,
 
1017
  cos_sims_dict = defaultdict(list)
1018
  pickle_batch = -1
1019
  filtered_input_data = downsample_and_sort(filtered_input_data, self.max_ncells)
1020
+ if self.cell_inds_to_perturb != "all":
1021
+ if self.cell_inds_to_perturb["start"] >= len(filtered_input_data):
1022
+ logger.error("cell_inds_to_perturb['start'] is larger than the filtered dataset.")
1023
+ raise
1024
+ if self.cell_inds_to_perturb["end"] > len(filtered_input_data):
1025
+ logger.warning("cell_inds_to_perturb['end'] is larger than the filtered dataset. \
1026
+ Setting to the end of the filtered dataset.")
1027
+ self.cell_inds_to_perturb["end"] = len(filtered_input_data)
1028
+ filtered_input_data = filtered_input_data.select([i for i in range(self.cell_inds_to_perturb["start"], self.cell_inds_to_perturb["end"])])
1029
 
1030
  # make perturbation batch w/ single perturbation in multiple cells
1031
  if self.perturb_group == True: