Christina Theodoris commited on
Commit
5fcf2b8
1 Parent(s): c34ead6

Fix filter_data to allow value of None for no filtering

Browse files
Files changed (1) hide show
  1. geneformer/in_silico_perturber.py +16 -14
geneformer/in_silico_perturber.py CHANGED
@@ -484,12 +484,13 @@ class InSilicoPerturber:
484
  "quartile to shift by must be specified.")
485
  raise
486
 
487
- for key,value in self.filter_data.items():
488
- if type(value) != list:
489
- self.filter_data[key] = [value]
490
- logger.warning(
491
- "Values in filter_data dict must be lists. " \
492
- f"Changing {key} value to list ([{value}]).")
 
493
 
494
  def perturb_data(self,
495
  model_directory,
@@ -543,14 +544,15 @@ class InSilicoPerturber:
543
  # load data and filter by defined criteria
544
  def load_and_filter(self, input_data_file):
545
  data = load_from_disk(input_data_file)
546
- for key,value in self.filter_data.items():
547
- def filter_data(example):
548
- return example[key] in value
549
- data = data.filter(filter_data, num_proc=self.nproc)
550
- if len(data) == 0:
551
- logger.error(
552
- "No cells remain after filtering. Check filtering criteria.")
553
- raise
 
554
  data_shuffled = data.shuffle(seed=42)
555
  num_cells = len(data_shuffled)
556
  # if max number of cells is defined, then subsample to this max number
 
484
  "quartile to shift by must be specified.")
485
  raise
486
 
487
+ if self.filter_data is not None:
488
+ for key,value in self.filter_data.items():
489
+ if type(value) != list:
490
+ self.filter_data[key] = [value]
491
+ logger.warning(
492
+ "Values in filter_data dict must be lists. " \
493
+ f"Changing {key} value to list ([{value}]).")
494
 
495
  def perturb_data(self,
496
  model_directory,
 
544
  # load data and filter by defined criteria
545
  def load_and_filter(self, input_data_file):
546
  data = load_from_disk(input_data_file)
547
+ if self.filter_data is not None:
548
+ for key,value in self.filter_data.items():
549
+ def filter_data_by_criteria(example):
550
+ return example[key] in value
551
+ data = data.filter(filter_data_by_criteria, num_proc=self.nproc)
552
+ if len(data) == 0:
553
+ logger.error(
554
+ "No cells remain after filtering. Check filtering criteria.")
555
+ raise
556
  data_shuffled = data.shuffle(seed=42)
557
  num_cells = len(data_shuffled)
558
  # if max number of cells is defined, then subsample to this max number