Multiprocessing in InSilicoPerturber

#369
by nlapier2 - opened

Hello,

Following the resolution of my previous issue (#368), I was able to run the in_silico_perturbation example successfully. However, to do so I had to reduce nproc to 1 in InSilicoPerturber, which of course slowed things down. When I ran it with nproc > 1, I got a series of warnings:

RuntimeError: 
        An attempt has been made to start a new process before the
        current process has finished its bootstrapping phase.

        This probably means that you are not using fork to start your
        child processes and you have forgotten to use the proper idiom
        in the main module:

            if __name__ == '__main__':
                freeze_support()
                ...

...followed by a fatal error that crashed the program:

line 433, in _recv_bytes
    buf = self._recv(4)
          ^^^^^^^^^^^^^
  File "/home/nlapier2/project-xhe/miniconda3/envs/geneformer/lib/python3.11/site-packages/multiprocess/connection.py", line 402, in _recv
    raise EOFError
EOFError

Notably, I did not have this issue when using multiprocessing for embedding extraction in the earlier step.
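For reference, the fix the RuntimeError message points to is the standard `__main__` guard, which keeps process-spawning code from re-executing when a spawned child re-imports the script. This is a generic sketch of that idiom, not Geneformer code; the `square` worker is a placeholder:

```python
import multiprocessing as mp

def square(x):
    # Placeholder worker; in the real script this would be the workload
    # dispatched by datasets' map() inside the perturber.
    return x * x

def main():
    # Pool creation happens only inside main(), never at import time.
    with mp.Pool(processes=2) as pool:
        return pool.map(square, [1, 2, 3])

if __name__ == "__main__":
    # Without this guard, a spawn-based start method re-runs the module
    # top-level in each child and raises the "bootstrapping phase" error.
    print(main())
```

With the guard in place, a spawn start method can safely re-import the module in each worker because the pool is only created when the script is run directly.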

Here is the full traceback of the warning and error:

Traceback (most recent call last):                                                                                      
  File "<string>", line 1, in <module>
  File "/home/nlapier2/project-xhe/miniconda3/envs/geneformer/lib/python3.11/site-packages/multiprocess/spawn.py", line 122, in spawn_main
    exitcode = _main(fd, parent_sentinel)
               ^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/nlapier2/project-xhe/miniconda3/envs/geneformer/lib/python3.11/site-packages/multiprocess/spawn.py", line 131, in _main
    prepare(preparation_data)
  File "/home/nlapier2/project-xhe/miniconda3/envs/geneformer/lib/python3.11/site-packages/multiprocess/spawn.py", line 246, in prepare
    _fixup_main_from_path(data['init_main_from_path'])
  File "/home/nlapier2/project-xhe/miniconda3/envs/geneformer/lib/python3.11/site-packages/multiprocess/spawn.py", line 297, in _fixup_main_from_path
    main_content = runpy.run_path(main_path,
                   ^^^^^^^^^^^^^^^^^^^^^^^^^
  File "<frozen runpy>", line 291, in run_path
  File "<frozen runpy>", line 98, in _run_module_code
  File "<frozen runpy>", line 88, in _run_code
  File "/gpfs/data/xhe-lab/nlapier2/geneformer/Geneformer_test/examples/new_in_silico_perturbation.py", line 57, in <module>
    isp.perturb_data("../fine_tuned_models/geneformer-6L-30M_CellClassifier_cardiomyopathies_220224",
  File "/home/nlapier2/project-xhe/miniconda3/envs/geneformer/lib/python3.11/site-packages/geneformer/in_silico_perturber.py", line 445, in perturb_data
    self.isp_perturb_all(
  File "/home/nlapier2/project-xhe/miniconda3/envs/geneformer/lib/python3.11/site-packages/geneformer/in_silico_perturber.py", line 756, in isp_perturb_all
    perturbation_batch, indices_to_perturb = pu.make_perturbation_batch(
                                             ^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/nlapier2/project-xhe/miniconda3/envs/geneformer/lib/python3.11/site-packages/geneformer/perturber_utils.py", line 386, in make_perturbation_batch
    perturbation_dataset = perturbation_dataset.map(
                           ^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/nlapier2/project-xhe/miniconda3/envs/geneformer/lib/python3.11/site-packages/datasets/arrow_dataset.py", line 602, in wrapper
    out: Union["Dataset", "DatasetDict"] = func(self, *args, **kwargs)
                                           ^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/nlapier2/project-xhe/miniconda3/envs/geneformer/lib/python3.11/site-packages/datasets/arrow_dataset.py", line 567, in wrapper
    out: Union["Dataset", "DatasetDict"] = func(self, *args, **kwargs)
                                           ^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/nlapier2/project-xhe/miniconda3/envs/geneformer/lib/python3.11/site-packages/datasets/arrow_dataset.py", line 3245, in map
    with Pool(len(kwargs_per_job)) as pool:
         ^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/nlapier2/project-xhe/miniconda3/envs/geneformer/lib/python3.11/site-packages/multiprocess/context.py", line 119, in Pool
    return Pool(processes, initializer, initargs, maxtasksperchild,
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/nlapier2/project-xhe/miniconda3/envs/geneformer/lib/python3.11/site-packages/multiprocess/pool.py", line 215, in __init__
    self._repopulate_pool()
  File "/home/nlapier2/project-xhe/miniconda3/envs/geneformer/lib/python3.11/site-packages/multiprocess/pool.py", line 306, in _repopulate_pool
    return self._repopulate_pool_static(self._ctx, self.Process,
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/nlapier2/project-xhe/miniconda3/envs/geneformer/lib/python3.11/site-packages/multiprocess/pool.py", line 329, in _repopulate_pool_static
    w.start()
  File "/home/nlapier2/project-xhe/miniconda3/envs/geneformer/lib/python3.11/site-packages/multiprocess/process.py", line 121, in start
    self._popen = self._Popen(self)
                  ^^^^^^^^^^^^^^^^^
  File "/home/nlapier2/project-xhe/miniconda3/envs/geneformer/lib/python3.11/site-packages/multiprocess/context.py", line 288, in _Popen
    return Popen(process_obj)
           ^^^^^^^^^^^^^^^^^^
  File "/home/nlapier2/project-xhe/miniconda3/envs/geneformer/lib/python3.11/site-packages/multiprocess/popen_spawn_posix.py", line 32, in __init__
    super().__init__(process_obj)
  File "/home/nlapier2/project-xhe/miniconda3/envs/geneformer/lib/python3.11/site-packages/multiprocess/popen_fork.py", line 19, in __init__
    self._launch(process_obj)
  File "/home/nlapier2/project-xhe/miniconda3/envs/geneformer/lib/python3.11/site-packages/multiprocess/popen_spawn_posix.py", line 42, in _launch
    prep_data = spawn.get_preparation_data(process_obj._name)
                ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/nlapier2/project-xhe/miniconda3/envs/geneformer/lib/python3.11/site-packages/multiprocess/spawn.py", line 164, in get_preparation_data
    _check_not_importing_main()
  File "/home/nlapier2/project-xhe/miniconda3/envs/geneformer/lib/python3.11/site-packages/multiprocess/spawn.py", line 140, in _check_not_importing_main
    raise RuntimeError('''
RuntimeError: 
        An attempt has been made to start a new process before the
        current process has finished its bootstrapping phase.

        This probably means that you are not using fork to start your
        child processes and you have forgotten to use the proper idiom
        in the main module:

            if __name__ == '__main__':
                freeze_support()
                ...

        The "freeze_support()" line can be omitted if the program
        is not going to be frozen to produce an executable.
        
        To fix this issue, refer to the "Safe importing of main module"
        section in https://docs.python.org/3/library/multiprocessing.html
Traceback (most recent call last):
  File "/gpfs/data/xhe-lab/nlapier2/geneformer/Geneformer_test/examples/new_in_silico_perturbation.py", line 57, in <module>
    isp.perturb_data("../fine_tuned_models/geneformer-6L-30M_CellClassifier_cardiomyopathies_220224",
  File "/home/nlapier2/project-xhe/miniconda3/envs/geneformer/lib/python3.11/site-packages/geneformer/in_silico_perturber.py", line 445, in perturb_data
    self.isp_perturb_all(
  File "/home/nlapier2/project-xhe/miniconda3/envs/geneformer/lib/python3.11/site-packages/geneformer/in_silico_perturber.py", line 756, in isp_perturb_all
    perturbation_batch, indices_to_perturb = pu.make_perturbation_batch(
                                             ^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/nlapier2/project-xhe/miniconda3/envs/geneformer/lib/python3.11/site-packages/geneformer/perturber_utils.py", line 386, in make_perturbation_batch
    perturbation_dataset = perturbation_dataset.map(
                           ^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/nlapier2/project-xhe/miniconda3/envs/geneformer/lib/python3.11/site-packages/datasets/arrow_dataset.py", line 602, in wrapper
    out: Union["Dataset", "DatasetDict"] = func(self, *args, **kwargs)
                                           ^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/nlapier2/project-xhe/miniconda3/envs/geneformer/lib/python3.11/site-packages/datasets/arrow_dataset.py", line 567, in wrapper
    out: Union["Dataset", "DatasetDict"] = func(self, *args, **kwargs)
                                           ^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/nlapier2/project-xhe/miniconda3/envs/geneformer/lib/python3.11/site-packages/datasets/arrow_dataset.py", line 3253, in map
    for rank, done, content in iflatmap_unordered(
  File "/home/nlapier2/project-xhe/miniconda3/envs/geneformer/lib/python3.11/site-packages/datasets/utils/py_utils.py", line 696, in iflatmap_unordered
    with manager_cls() as manager:
         ^^^^^^^^^^^^^
  File "/home/nlapier2/project-xhe/miniconda3/envs/geneformer/lib/python3.11/site-packages/multiprocess/context.py", line 57, in Manager
    m.start()
  File "/home/nlapier2/project-xhe/miniconda3/envs/geneformer/lib/python3.11/site-packages/multiprocess/managers.py", line 567, in start
    self._address = reader.recv()
                    ^^^^^^^^^^^^^
  File "/home/nlapier2/project-xhe/miniconda3/envs/geneformer/lib/python3.11/site-packages/multiprocess/connection.py", line 253, in recv
    buf = self._recv_bytes()
          ^^^^^^^^^^^^^^^^^^
  File "/home/nlapier2/project-xhe/miniconda3/envs/geneformer/lib/python3.11/site-packages/multiprocess/connection.py", line 433, in _recv_bytes
    buf = self._recv(4)
          ^^^^^^^^^^^^^
  File "/home/nlapier2/project-xhe/miniconda3/envs/geneformer/lib/python3.11/site-packages/multiprocess/connection.py", line 402, in _recv
    raise EOFError
EOFError

Thank you for your question! We essentially always run the in silico perturbation with nproc > 1. Could you please provide the code you are using to call the in silico perturber so we can reproduce this? It appears the issue is coming from the map function in Hugging Face Datasets.

Here it is -- I modified the example to be a regular Python script:

from geneformer import InSilicoPerturber
from geneformer import InSilicoPerturberStats
from geneformer import EmbExtractor
from datasets import load_dataset, load_from_disk

import torch
with torch.no_grad():
  # first obtain start, goal, and alt embedding positions
  # this function was changed to be separate from perturb_data
  # to avoid repeating calculations when parallelizing perturb_data
  cell_states_to_model={"state_key": "disease",
                        "start_state": "dcm",
                        "goal_state": "nf",
                        "alt_states": ["hcm"]}

  filter_data_dict={"cell_type":["Cardiomyocyte1","Cardiomyocyte2","Cardiomyocyte3"]}

  embex = EmbExtractor(model_type="CellClassifier",
                       num_classes=3,
                       filter_data=filter_data_dict,
                       max_ncells=1000,
                       emb_layer=0,
                       summary_stat="exact_mean",
                       forward_batch_size=32,  # 256
                       nproc=16)

  state_embs_dict = embex.get_state_embs(cell_states_to_model,
                                         "../fine_tuned_models/geneformer-6L-30M_CellClassifier_cardiomyopathies_220224",
                                         "../../Genecorpus-30M/example_input_files/cell_classification/disease_classification/human_dcm_hcm_nf.dataset",
                                         "test_output",
                                         "output_prefix")

  isp = InSilicoPerturber(perturb_type="delete",
                          perturb_rank_shift=None,
                          genes_to_perturb="all",
                          combos=0,
                          anchor_gene=None,
                          model_type="CellClassifier",
                          num_classes=3,
                          emb_mode="cell",
                          cell_emb_style="mean_pool",
                          filter_data=filter_data_dict,
                          cell_states_to_model=cell_states_to_model,
                          state_embs_dict=state_embs_dict,
                          max_ncells=2000,
                          emb_layer=0,
                          forward_batch_size=32,  # 400,
                          nproc=16)

  # outputs intermediate files from in silico perturbation
  isp.perturb_data("../fine_tuned_models/geneformer-6L-30M_CellClassifier_cardiomyopathies_220224",
                   "../../Genecorpus-30M/example_input_files/cell_classification/disease_classification/human_dcm_hcm_nf.dataset",
                   "test_output",
                   "output_prefix")

  ispstats = InSilicoPerturberStats(mode="goal_state_shift",
                                    genes_perturbed="all",
                                    combos=0,
                                    anchor_gene=None,
                                    cell_states_to_model=cell_states_to_model)

  # extracts data from intermediate files and processes stats to output in final .csv
  ispstats.get_stats("test_output",
                     None,
                     "test_output",
                     "output_prefix")

If you are able to consistently reproduce this error, could you please try adding load_from_cache_file=False to the three perturbation_dataset.map() calls in lines 381-390 of perturber_utils.py?

Also, please let us know what Datasets version you are running. Testing upgrading Datasets may also be helpful. Thank you!
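A quick way to report the relevant versions (this assumes the package is installed under its PyPI name, datasets; the Python version matters too, since it affects the default multiprocessing start method):

```python
import sys
import importlib.metadata as md

# Python version (multiprocessing defaults vary by version and platform).
print("python:", sys.version.split()[0])

# Installed Hugging Face Datasets version, if present.
try:
    print("datasets:", md.version("datasets"))
except md.PackageNotFoundError:
    print("datasets: not installed")
```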

Recent changes should have addressed this issue, but if not, please feel free to reopen.

ctheodoris changed discussion status to closed
