Alejandro Velez commited on
Commit
47990ca
·
1 Parent(s): 5f1a697

tdc geneformer

Browse files
This view is limited to 50 files because it contains too many changes.   See raw diff
Files changed (50) hide show
  1. .gitattributes +25 -0
  2. .gitignore +160 -0
  3. .pre-commit-config.yaml +26 -0
  4. .readthedocs.yaml +19 -0
  5. MANIFEST.in +4 -0
  6. README.md +96 -0
  7. config.json +24 -0
  8. docs/Makefile +20 -0
  9. docs/make.bat +35 -0
  10. docs/requirements.txt +3 -0
  11. docs/source/_static/css/custom.css +40 -0
  12. docs/source/_static/gf_logo.png +0 -0
  13. docs/source/about.rst +49 -0
  14. docs/source/api.rst +51 -0
  15. docs/source/conf.py +80 -0
  16. docs/source/geneformer.classifier.rst +10 -0
  17. docs/source/geneformer.emb_extractor.rst +26 -0
  18. docs/source/geneformer.in_silico_perturber.rst +8 -0
  19. docs/source/geneformer.in_silico_perturber_stats.rst +25 -0
  20. docs/source/geneformer.mtl_classifier.rst +11 -0
  21. docs/source/geneformer.tokenizer.rst +15 -0
  22. docs/source/getstarted.rst +36 -0
  23. docs/source/index.rst +16 -0
  24. examples/cell_classification.ipynb +0 -0
  25. examples/extract_and_plot_cell_embeddings.ipynb +0 -0
  26. examples/gene_classification.ipynb +0 -0
  27. examples/in_silico_perturbation.ipynb +159 -0
  28. examples/multitask_cell_classification.ipynb +420 -0
  29. examples/pretraining_new_model/obtain_nonzero_median_digests.ipynb +365 -0
  30. examples/pretraining_new_model/pretrain_geneformer_w_deepspeed.py +167 -0
  31. examples/tokenizing_scRNAseq_data.ipynb +91 -0
  32. fine_tuned_models/gf-12L-95M-i4096_MTLCellClassifier_CELLxGENE_240522/config.json +24 -0
  33. fine_tuned_models/gf-6L-30M-i2048_CellClassifier_cardiomyopathies_220224/config.json +35 -0
  34. fine_tuned_models/gf-6L-30M-i2048_CellClassifier_cardiomyopathies_220224/trainer_state.json +150 -0
  35. geneformer/__init__.py +34 -0
  36. geneformer/classifier.py +1563 -0
  37. geneformer/classifier_utils.py +648 -0
  38. geneformer/collator_for_classification.py +667 -0
  39. geneformer/emb_extractor.py +863 -0
  40. geneformer/evaluation_utils.py +287 -0
  41. geneformer/in_silico_perturber.py +1579 -0
  42. geneformer/in_silico_perturber_stats.py +1104 -0
  43. geneformer/mtl/__init__.py +1 -0
  44. geneformer/mtl/collators.py +76 -0
  45. geneformer/mtl/data.py +162 -0
  46. geneformer/mtl/eval_utils.py +88 -0
  47. geneformer/mtl/imports.py +43 -0
  48. geneformer/mtl/model.py +121 -0
  49. geneformer/mtl/optuna_utils.py +27 -0
  50. geneformer/mtl/train.py +380 -0
.gitattributes CHANGED
@@ -1,27 +1,41 @@
1
  *.7z filter=lfs diff=lfs merge=lfs -text
2
  *.arrow filter=lfs diff=lfs merge=lfs -text
3
  *.bin filter=lfs diff=lfs merge=lfs -text
 
4
  *.bz2 filter=lfs diff=lfs merge=lfs -text
5
  *.ckpt filter=lfs diff=lfs merge=lfs -text
 
 
 
 
6
  *.ftz filter=lfs diff=lfs merge=lfs -text
7
  *.gz filter=lfs diff=lfs merge=lfs -text
8
  *.h5 filter=lfs diff=lfs merge=lfs -text
9
  *.joblib filter=lfs diff=lfs merge=lfs -text
10
  *.lfs.* filter=lfs diff=lfs merge=lfs -text
 
11
  *.mlmodel filter=lfs diff=lfs merge=lfs -text
12
  *.model filter=lfs diff=lfs merge=lfs -text
13
  *.msgpack filter=lfs diff=lfs merge=lfs -text
14
  *.npy filter=lfs diff=lfs merge=lfs -text
15
  *.npz filter=lfs diff=lfs merge=lfs -text
 
 
 
 
16
  *.onnx filter=lfs diff=lfs merge=lfs -text
17
  *.ot filter=lfs diff=lfs merge=lfs -text
18
  *.parquet filter=lfs diff=lfs merge=lfs -text
19
  *.pb filter=lfs diff=lfs merge=lfs -text
 
20
  *.pickle filter=lfs diff=lfs merge=lfs -text
 
 
21
  *.pkl filter=lfs diff=lfs merge=lfs -text
22
  *.pt filter=lfs diff=lfs merge=lfs -text
23
  *.pth filter=lfs diff=lfs merge=lfs -text
24
  *.rar filter=lfs diff=lfs merge=lfs -text
 
25
  *.safetensors filter=lfs diff=lfs merge=lfs -text
26
  saved_model/**/* filter=lfs diff=lfs merge=lfs -text
27
  *.tar.* filter=lfs diff=lfs merge=lfs -text
@@ -33,3 +47,14 @@ saved_model/**/* filter=lfs diff=lfs merge=lfs -text
33
  *.zip filter=lfs diff=lfs merge=lfs -text
34
  *.zst filter=lfs diff=lfs merge=lfs -text
35
  *tfevents* filter=lfs diff=lfs merge=lfs -text
 
 
 
 
 
 
 
 
 
 
 
 
1
  *.7z filter=lfs diff=lfs merge=lfs -text
2
  *.arrow filter=lfs diff=lfs merge=lfs -text
3
  *.bin filter=lfs diff=lfs merge=lfs -text
4
+ <<<<<<< HEAD
5
  *.bz2 filter=lfs diff=lfs merge=lfs -text
6
  *.ckpt filter=lfs diff=lfs merge=lfs -text
7
+ =======
8
+ *.bin.* filter=lfs diff=lfs merge=lfs -text
9
+ *.bz2 filter=lfs diff=lfs merge=lfs -text
10
+ >>>>>>> 09de19734bf3da83050abc74408517ba15b5b185
11
  *.ftz filter=lfs diff=lfs merge=lfs -text
12
  *.gz filter=lfs diff=lfs merge=lfs -text
13
  *.h5 filter=lfs diff=lfs merge=lfs -text
14
  *.joblib filter=lfs diff=lfs merge=lfs -text
15
  *.lfs.* filter=lfs diff=lfs merge=lfs -text
16
+ <<<<<<< HEAD
17
  *.mlmodel filter=lfs diff=lfs merge=lfs -text
18
  *.model filter=lfs diff=lfs merge=lfs -text
19
  *.msgpack filter=lfs diff=lfs merge=lfs -text
20
  *.npy filter=lfs diff=lfs merge=lfs -text
21
  *.npz filter=lfs diff=lfs merge=lfs -text
22
+ =======
23
+ *.model filter=lfs diff=lfs merge=lfs -text
24
+ *.msgpack filter=lfs diff=lfs merge=lfs -text
25
+ >>>>>>> 09de19734bf3da83050abc74408517ba15b5b185
26
  *.onnx filter=lfs diff=lfs merge=lfs -text
27
  *.ot filter=lfs diff=lfs merge=lfs -text
28
  *.parquet filter=lfs diff=lfs merge=lfs -text
29
  *.pb filter=lfs diff=lfs merge=lfs -text
30
+ <<<<<<< HEAD
31
  *.pickle filter=lfs diff=lfs merge=lfs -text
32
+ =======
33
+ >>>>>>> 09de19734bf3da83050abc74408517ba15b5b185
34
  *.pkl filter=lfs diff=lfs merge=lfs -text
35
  *.pt filter=lfs diff=lfs merge=lfs -text
36
  *.pth filter=lfs diff=lfs merge=lfs -text
37
  *.rar filter=lfs diff=lfs merge=lfs -text
38
+ <<<<<<< HEAD
39
  *.safetensors filter=lfs diff=lfs merge=lfs -text
40
  saved_model/**/* filter=lfs diff=lfs merge=lfs -text
41
  *.tar.* filter=lfs diff=lfs merge=lfs -text
 
47
  *.zip filter=lfs diff=lfs merge=lfs -text
48
  *.zst filter=lfs diff=lfs merge=lfs -text
49
  *tfevents* filter=lfs diff=lfs merge=lfs -text
50
+ =======
51
+ saved_model/**/* filter=lfs diff=lfs merge=lfs -text
52
+ *.tar.* filter=lfs diff=lfs merge=lfs -text
53
+ *.tflite filter=lfs diff=lfs merge=lfs -text
54
+ *.tgz filter=lfs diff=lfs merge=lfs -text
55
+ *.xz filter=lfs diff=lfs merge=lfs -text
56
+ *.zip filter=lfs diff=lfs merge=lfs -text
57
+ *.zstandard filter=lfs diff=lfs merge=lfs -text
58
+ *tfevents* filter=lfs diff=lfs merge=lfs -text
59
+ model.safetensors filter=lfs diff=lfs merge=lfs -text
60
+ >>>>>>> 09de19734bf3da83050abc74408517ba15b5b185
.gitignore ADDED
@@ -0,0 +1,160 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Byte-compiled / optimized / DLL files
2
+ __pycache__/
3
+ *.py[cod]
4
+ *$py.class
5
+
6
+ # C extensions
7
+ *.so
8
+
9
+ # Distribution / packaging
10
+ .Python
11
+ build/
12
+ develop-eggs/
13
+ dist/
14
+ downloads/
15
+ eggs/
16
+ .eggs/
17
+ lib/
18
+ lib64/
19
+ parts/
20
+ sdist/
21
+ var/
22
+ wheels/
23
+ share/python-wheels/
24
+ *.egg-info/
25
+ .installed.cfg
26
+ *.egg
27
+ MANIFEST
28
+
29
+ # PyInstaller
30
+ # Usually these files are written by a python script from a template
31
+ # before PyInstaller builds the exe, so as to inject date/other infos into it.
32
+ *.manifest
33
+ *.spec
34
+
35
+ # Installer logs
36
+ pip-log.txt
37
+ pip-delete-this-directory.txt
38
+
39
+ # Unit test / coverage reports
40
+ htmlcov/
41
+ .tox/
42
+ .nox/
43
+ .coverage
44
+ .coverage.*
45
+ .cache
46
+ nosetests.xml
47
+ coverage.xml
48
+ *.cover
49
+ *.py,cover
50
+ .hypothesis/
51
+ .pytest_cache/
52
+ cover/
53
+
54
+ # Translations
55
+ *.mo
56
+ *.pot
57
+
58
+ # Django stuff:
59
+ *.log
60
+ local_settings.py
61
+ db.sqlite3
62
+ db.sqlite3-journal
63
+
64
+ # Flask stuff:
65
+ instance/
66
+ .webassets-cache
67
+
68
+ # Scrapy stuff:
69
+ .scrapy
70
+
71
+ # Sphinx documentation
72
+ docs/_build/
73
+
74
+ # PyBuilder
75
+ .pybuilder/
76
+ target/
77
+
78
+ # Jupyter Notebook
79
+ .ipynb_checkpoints
80
+
81
+ # IPython
82
+ profile_default/
83
+ ipython_config.py
84
+
85
+ # pyenv
86
+ # For a library or package, you might want to ignore these files since the code is
87
+ # intended to run in multiple environments; otherwise, check them in:
88
+ # .python-version
89
+
90
+ # pipenv
91
+ # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control.
92
+ # However, in case of collaboration, if having platform-specific dependencies or dependencies
93
+ # having no cross-platform support, pipenv may install dependencies that don't work, or not
94
+ # install all needed dependencies.
95
+ #Pipfile.lock
96
+
97
+ # poetry
98
+ # Similar to Pipfile.lock, it is generally recommended to include poetry.lock in version control.
99
+ # This is especially recommended for binary packages to ensure reproducibility, and is more
100
+ # commonly ignored for libraries.
101
+ # https://python-poetry.org/docs/basic-usage/#commit-your-poetrylock-file-to-version-control
102
+ #poetry.lock
103
+
104
+ # pdm
105
+ # Similar to Pipfile.lock, it is generally recommended to include pdm.lock in version control.
106
+ #pdm.lock
107
+ # pdm stores project-wide configurations in .pdm.toml, but it is recommended to not include it
108
+ # in version control.
109
+ # https://pdm.fming.dev/#use-with-ide
110
+ .pdm.toml
111
+
112
+ # PEP 582; used by e.g. github.com/David-OConnor/pyflow and github.com/pdm-project/pdm
113
+ __pypackages__/
114
+
115
+ # Celery stuff
116
+ celerybeat-schedule
117
+ celerybeat.pid
118
+
119
+ # SageMath parsed files
120
+ *.sage.py
121
+
122
+ # Environments
123
+ .env
124
+ .venv
125
+ env/
126
+ venv/
127
+ ENV/
128
+ env.bak/
129
+ venv.bak/
130
+
131
+ # Spyder project settings
132
+ .spyderproject
133
+ .spyproject
134
+
135
+ # Rope project settings
136
+ .ropeproject
137
+
138
+ # mkdocs documentation
139
+ /site
140
+
141
+ # mypy
142
+ .mypy_cache/
143
+ .dmypy.json
144
+ dmypy.json
145
+
146
+ # Pyre type checker
147
+ .pyre/
148
+
149
+ # pytype static type analyzer
150
+ .pytype/
151
+
152
+ # Cython debug symbols
153
+ cython_debug/
154
+
155
+ # PyCharm
156
+ # JetBrains specific template is maintained in a separate JetBrains.gitignore that can
157
+ # be found at https://github.com/github/gitignore/blob/main/Global/JetBrains.gitignore
158
+ # and can be added to the global gitignore or merged into this file. For a more nuclear
159
+ # option (not recommended) you can uncomment the following to ignore the entire idea folder.
160
+ #.idea/
.pre-commit-config.yaml ADDED
@@ -0,0 +1,26 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # See https://pre-commit.com for more information
2
+ # See https://pre-commit.com/hooks.html for more hooks
3
+ repos:
4
+ - repo: https://github.com/pre-commit/pre-commit-hooks
5
+ rev: v3.2.0
6
+ hooks:
7
+ - id: trailing-whitespace
8
+ - id: end-of-file-fixer
9
+ - id: check-yaml
10
+ - id: check-added-large-files
11
+ - id: check-merge-conflict
12
+ - id: mixed-line-ending
13
+ - id: check-docstring-first
14
+ - repo: https://github.com/pycqa/isort
15
+ rev: 5.12.0
16
+ hooks:
17
+ - id: isort
18
+ args: ["--profile", "black"]
19
+ - repo: https://github.com/astral-sh/ruff-pre-commit
20
+ # Ruff version.
21
+ rev: v0.1.4
22
+ hooks:
23
+ # Run the Ruff linter.
24
+ - id: ruff
25
+ # Run the Ruff formatter.
26
+ - id: ruff-format
.readthedocs.yaml ADDED
@@ -0,0 +1,19 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Read the Docs configuration file
2
+
3
+ # Required
4
+ version: 2
5
+
6
+ # Set the OS, Python version and other tools you might need
7
+ build:
8
+ os: ubuntu-22.04
9
+ tools:
10
+ python: "3.10"
11
+
12
+ # Build documentation in the "docs/" directory with Sphinx
13
+ sphinx:
14
+ configuration: docs/source/conf.py
15
+
16
+ # Python requirements required build your documentation
17
+ python:
18
+ install:
19
+ - requirements: docs/requirements.txt
MANIFEST.in ADDED
@@ -0,0 +1,4 @@
 
 
 
 
 
1
+ include geneformer/gene_median_dictionary_gc95M.pkl
2
+ include geneformer/gene_name_id_dict_gc95M.pkl
3
+ include geneformer/ensembl_mapping_dict_gc95M.pkl
4
+ include geneformer/token_dictionary_gc95M.pkl
README.md CHANGED
@@ -1,4 +1,5 @@
1
  ---
 
2
  license: apache-2.0
3
  datasets:
4
  - ctheodoris/Genecorpus-30M
@@ -20,3 +21,98 @@ model = AutoModelForMaskedLM.from_pretrained("ctheodoris/Geneformer")
20
  ```
21
 
22
  For further details see: https://huggingface.co/ctheodoris/Geneformer
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
  ---
2
+ <<<<<<< HEAD
3
  license: apache-2.0
4
  datasets:
5
  - ctheodoris/Genecorpus-30M
 
21
  ```
22
 
23
  For further details see: https://huggingface.co/ctheodoris/Geneformer
24
+ =======
25
+ datasets: ctheodoris/Genecorpus-30M
26
+ license: apache-2.0
27
+ tags:
28
+ - single-cell
29
+ - genomics
30
+ ---
31
+ # Geneformer
32
+ Geneformer is a foundational transformer model pretrained on a large-scale corpus of single cell transcriptomes to enable context-aware predictions in settings with limited data in network biology.
33
+
34
+ - See [our manuscript](https://rdcu.be/ddrx0) for details of the original model trained on ~30 million transcriptomes in June 2021 and the initial report of our in silico perturbation and cell and gene classification strategies.
35
+ - See [our manuscript](https://www.biorxiv.org/content/10.1101/2024.08.16.608180v1.full.pdf) for details of the expanded model trained on ~95 million transcriptomes in April 2024 and our continual learning, multitask learning, and quantization strategies.
36
+ - See [geneformer.readthedocs.io](https://geneformer.readthedocs.io) for documentation.
37
+
38
+ # Model Description
39
+ Geneformer is a foundational transformer model pretrained on a large-scale corpus of single cell transcriptomes representing a broad range of human tissues. Geneformer was originally pretrained in June 2021 on [Genecorpus-30M](https://huggingface.co/datasets/ctheodoris/Genecorpus-30M), a corpus comprised of ~30 million single cell transcriptomes. We excluded cells with high mutational burdens (e.g. malignant cells and immortalized cell lines) that could lead to substantial network rewiring without companion genome sequencing to facilitate interpretation. Then, in April 2024, Geneformer was pretrained on ~95 million non-cancer transcriptomes, followed by continual learning on ~14 million cancer transcriptomes to yield a cancer domain-tuned model.
40
+
41
+ Each single cell’s transcriptome is presented to the model as a rank value encoding where genes are ranked by their expression in that cell scaled by their expression across the entire Genecorpus-30M. The rank value encoding provides a nonparametric representation of that cell’s transcriptome and takes advantage of the many observations of each gene’s expression across the pretraining corpus to prioritize genes that distinguish cell state. Specifically, this method will deprioritize ubiquitously highly-expressed housekeeping genes by scaling them to a lower rank. Conversely, genes such as transcription factors that may be lowly expressed when they are expressed but highly distinguish cell state will move to a higher rank within the encoding. Furthermore, this rank-based approach may be more robust against technical artifacts that may systematically bias the absolute transcript counts value while the overall relative ranking of genes within each cell remains more stable.
42
+
43
+ The rank value encoding of each single cell’s transcriptome then proceeds through N layers of transformer encoder units, where N varies dependent on the model size. Pretraining was accomplished using a masked learning objective where 15% of the genes within each transcriptome were masked and the model was trained to predict which gene should be within each masked position in that specific cell state using the context of the remaining unmasked genes. A major strength of this approach is that it is entirely self-supervised and can be accomplished on completely unlabeled data, which allows the inclusion of large amounts of training data without being restricted to samples with accompanying labels.
44
+
45
+ We detail applications and results in [our manuscript](https://rdcu.be/ddrx0).
46
+
47
+ During pretraining, Geneformer gained a fundamental understanding of network dynamics, encoding network hierarchy in the model’s attention weights in a completely self-supervised manner. With both zero-shot learning and fine-tuning with limited task-specific data, Geneformer consistently boosted predictive accuracy in a diverse panel of downstream tasks relevant to chromatin and network dynamics. In silico perturbation with zero-shot learning identified a novel transcription factor in cardiomyocytes that we experimentally validated to be critical to their ability to generate contractile force. In silico treatment with limited patient data revealed candidate therapeutic targets for cardiomyopathy that we experimentally validated to significantly improve the ability of cardiomyocytes to generate contractile force in an induced pluripotent stem cell (iPSC) model of the disease. Overall, Geneformer represents a foundational deep learning model pretrained on a large-scale corpus human single cell transcriptomes to gain a fundamental understanding of gene network dynamics that can now be democratized to a vast array of downstream tasks to accelerate discovery of key network regulators and candidate therapeutic targets.
48
+
49
+ The repository includes the following pretrained models:
50
+
51
+ L=layers\
52
+ M=millions of cells used for pretraining\
53
+ i=input size\
54
+ (pretraining date)
55
+
56
+ - GF-6L-30M-i2048 (June 2021)
57
+ - GF-12L-30M-i2048 (June 2021)
58
+ - GF-12L-95M-i4096 (April 2024)
59
+ - GF-20L-95M-i4096 (April 2024)
60
+
61
+ The current default model in the main directory of the repository is GF-12L-95M-i4096.
62
+
63
+ The repository also contains fined tuned models in the fine_tuned_models directory and the cancer-tuned model following continual learning on ~14 million cancer cells, GF-12L-95M-i4096_CLcancer.
64
+
65
+ # Application
66
+ The pretrained Geneformer model can be used directly for zero-shot learning, for example for in silico perturbation analysis, or by fine-tuning towards the relevant downstream task, such as gene or cell state classification.
67
+
68
+ Example applications demonstrated in [our manuscript](https://rdcu.be/ddrx0) include:
69
+
70
+ *Fine-tuning*:
71
+ - transcription factor dosage sensitivity
72
+ - chromatin dynamics (bivalently marked promoters)
73
+ - transcription factor regulatory range
74
+ - gene network centrality
75
+ - transcription factor targets
76
+ - cell type annotation
77
+ - batch integration
78
+ - cell state classification across differentiation
79
+ - disease classification
80
+ - in silico perturbation to determine disease-driving genes
81
+ - in silico treatment to determine candidate therapeutic targets
82
+
83
+ *Zero-shot learning*:
84
+ - batch integration
85
+ - gene context specificity
86
+ - in silico reprogramming
87
+ - in silico differentiation
88
+ - in silico perturbation to determine impact on cell state
89
+ - in silico perturbation to determine transcription factor targets
90
+ - in silico perturbation to determine transcription factor cooperativity
91
+
92
+ # Installation
93
+ In addition to the pretrained model, contained herein are functions for tokenizing and collating data specific to single cell transcriptomics, pretraining the model, fine-tuning the model, extracting and plotting cell embeddings, and performing in silico pertrubation with either the pretrained or fine-tuned models. To install (~20s):
94
+
95
+ ```bash
96
+ # Make sure you have git-lfs installed (https://git-lfs.com)
97
+ git lfs install
98
+ git clone https://huggingface.co/ctheodoris/Geneformer
99
+ cd Geneformer
100
+ pip install .
101
+ ```
102
+
103
+ For usage, see [examples](https://huggingface.co/ctheodoris/Geneformer/tree/main/examples) for:
104
+ - tokenizing transcriptomes
105
+ - pretraining
106
+ - hyperparameter tuning
107
+ - fine-tuning
108
+ - extracting and plotting cell embeddings
109
+ - in silico perturbation
110
+
111
+ Please note that the fine-tuning examples are meant to be generally applicable and the input datasets and labels will vary dependent on the downstream task. Example input files for a few of the downstream tasks demonstrated in the manuscript are located within the [example_input_files directory](https://huggingface.co/datasets/ctheodoris/Genecorpus-30M/tree/main/example_input_files) in the dataset repository, but these only represent a few example fine-tuning applications.
112
+
113
+ Please note that GPU resources are required for efficient usage of Geneformer. Additionally, we strongly recommend tuning hyperparameters for each downstream fine-tuning application as this can significantly boost predictive potential in the downstream task (e.g. max learning rate, learning schedule, number of layers to freeze, etc.).
114
+
115
+ # Citations
116
+ - C V Theodoris#, L Xiao, A Chopra, M D Chaffin, Z R Al Sayed, M C Hill, H Mantineo, E Brydon, Z Zeng, X S Liu, P T Ellinor#. Transfer learning enables predictions in network biology. _**Nature**_, 31 May 2023. (#co-corresponding authors)
117
+ - H Chen*, M S Venkatesh*, J Gomez Ortega, S V Mahesh, T Nandi, R Madduri, K Pelka†, C V Theodoris†#. Quantized multi-task learning for context-specific representations of gene network dynamics. _**bioRxiv**_, 19 Aug 2024. (*co-first authors, †co-senior authors, #corresponding author)
118
+ >>>>>>> 09de19734bf3da83050abc74408517ba15b5b185
config.json ADDED
@@ -0,0 +1,24 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ {
2
+ "architectures": [
3
+ "BertForMaskedLM"
4
+ ],
5
+ "attention_probs_dropout_prob": 0.02,
6
+ "classifier_dropout": null,
7
+ "hidden_act": "relu",
8
+ "hidden_dropout_prob": 0.02,
9
+ "hidden_size": 512,
10
+ "initializer_range": 0.02,
11
+ "intermediate_size": 1024,
12
+ "layer_norm_eps": 1e-12,
13
+ "max_position_embeddings": 4096,
14
+ "model_type": "bert",
15
+ "num_attention_heads": 8,
16
+ "num_hidden_layers": 12,
17
+ "pad_token_id": 0,
18
+ "position_embedding_type": "absolute",
19
+ "torch_dtype": "float32",
20
+ "transformers_version": "4.37.1",
21
+ "type_vocab_size": 2,
22
+ "use_cache": true,
23
+ "vocab_size": 20275
24
+ }
docs/Makefile ADDED
@@ -0,0 +1,20 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Minimal makefile for Sphinx documentation
2
+ #
3
+
4
+ # You can set these variables from the command line, and also
5
+ # from the environment for the first two.
6
+ SPHINXOPTS ?=
7
+ SPHINXBUILD ?= sphinx-build
8
+ SOURCEDIR = source
9
+ BUILDDIR = build
10
+
11
+ # Put it first so that "make" without argument is like "make help".
12
+ help:
13
+ @$(SPHINXBUILD) -M help "$(SOURCEDIR)" "$(BUILDDIR)" $(SPHINXOPTS) $(O)
14
+
15
+ .PHONY: help Makefile
16
+
17
+ # Catch-all target: route all unknown targets to Sphinx using the new
18
+ # "make mode" option. $(O) is meant as a shortcut for $(SPHINXOPTS).
19
+ %: Makefile
20
+ @$(SPHINXBUILD) -M $@ "$(SOURCEDIR)" "$(BUILDDIR)" $(SPHINXOPTS) $(O)
docs/make.bat ADDED
@@ -0,0 +1,35 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ @ECHO OFF
2
+
3
+ pushd %~dp0
4
+
5
+ REM Command file for Sphinx documentation
6
+
7
+ if "%SPHINXBUILD%" == "" (
8
+ set SPHINXBUILD=sphinx-build
9
+ )
10
+ set SOURCEDIR=source
11
+ set BUILDDIR=build
12
+
13
+ %SPHINXBUILD% >NUL 2>NUL
14
+ if errorlevel 9009 (
15
+ echo.
16
+ echo.The 'sphinx-build' command was not found. Make sure you have Sphinx
17
+ echo.installed, then set the SPHINXBUILD environment variable to point
18
+ echo.to the full path of the 'sphinx-build' executable. Alternatively you
19
+ echo.may add the Sphinx directory to PATH.
20
+ echo.
21
+ echo.If you don't have Sphinx installed, grab it from
22
+ echo.https://www.sphinx-doc.org/
23
+ exit /b 1
24
+ )
25
+
26
+ if "%1" == "" goto help
27
+
28
+ %SPHINXBUILD% -M %1 %SOURCEDIR% %BUILDDIR% %SPHINXOPTS% %O%
29
+ goto end
30
+
31
+ :help
32
+ %SPHINXBUILD% -M help %SOURCEDIR% %BUILDDIR% %SPHINXOPTS% %O%
33
+
34
+ :end
35
+ popd
docs/requirements.txt ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ .
2
+ sphinx_rtd_theme==2.0.0
3
+ nbsphinx==0.9.3
docs/source/_static/css/custom.css ADDED
@@ -0,0 +1,40 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ /* top left logo */
2
+ .wy-side-nav-search, .wy-nav-top {
3
+ background: linear-gradient(15deg, #13547a 0%, #80d0c7 100%);
4
+ }
5
+
6
+
7
+ /* unvisited link */
8
+ .wy-nav-content a:link {
9
+ color: #067abd;
10
+ }
11
+
12
+ /* visited link */
13
+ .wy-nav-content a:visited {
14
+ color: #4b827c;
15
+ }
16
+
17
+ /* mouse over link */
18
+ .wy-nav-content a:hover {
19
+ color: #80d0c7;
20
+ }
21
+
22
+ /* selected link */
23
+ .wy-nav-content a:active {
24
+ color: #4b827c;
25
+ }
26
+
27
+ /* class object */
28
+ .sig.sig-object {
29
+ padding: 5px 5px 5px 5px;
30
+ background-color: #ececec;
31
+ border-style: solid;
32
+ border-color: black;
33
+ border-width: 1px 0;
34
+ }
35
+
36
+ /* parameter object */
37
+ dt {
38
+ padding: 5px 5px 5px 5px;
39
+ background-color: #ececec;
40
+ }
docs/source/_static/gf_logo.png ADDED
docs/source/about.rst ADDED
@@ -0,0 +1,49 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ About
2
+ =====
3
+
4
+ Model Description
5
+ -----------------
6
+
7
+ **Geneformer** is a context-aware, attention-based deep learning model pretrained on a large-scale corpus of single-cell transcriptomes to enable context-specific predictions in settings with limited data in network biology. During pretraining, Geneformer gained a fundamental understanding of network dynamics, encoding network hierarchy in the attention weights of the model in a completely self-supervised manner. With both zero-shot learning and fine-tuning with limited task-specific data, Geneformer consistently boosted predictive accuracy in a diverse panel of downstream tasks relevant to chromatin and network dynamics. In silico perturbation with zero-shot learning identified a novel transcription factor in cardiomyocytes that we experimentally validated to be critical to their ability to generate contractile force. In silico treatment with limited patient data revealed candidate therapeutic targets for cardiomyopathy that we experimentally validated to significantly improve the ability of cardiomyocytes to generate contractile force in an iPSC model of the disease. Overall, Geneformer represents a foundational deep learning model pretrained on a large-scale corpus of human single cell transcriptomes to gain a fundamental understanding of gene network dynamics that can now be democratized to a vast array of downstream tasks to accelerate discovery of key network regulators and candidate therapeutic targets.
8
+
9
+ In `our manuscript <https://rdcu.be/ddrx0>`_, we report results for the original 6 layer Geneformer model pretrained on Genecorpus-30M. We additionally provide within the repository a 12 layer Geneformer model, scaled up with retained width:depth aspect ratio, also pretrained on Genecorpus-30M.
10
+
11
+ Both the `6 <https://huggingface.co/ctheodoris/Geneformer/blob/main/gf-6L-30M-i2048/model.safetensors>`_ and `12 <https://huggingface.co/ctheodoris/Geneformer/blob/main/gf-12L-30M-i2048/pytorch_model.bin>`_ layer Geneformer models were pretrained in June 2021.
12
+
13
+ Also see `our 2024 manuscript <https://www.biorxiv.org/content/10.1101/2024.08.16.608180v1.full.pdf>`_, for details of the `expanded model <https://huggingface.co/ctheodoris/Geneformer/blob/main/model.safetensors>`_ trained on ~95 million transcriptomes in April 2024 and our continual learning, multitask learning, and quantization strategies.
14
+
15
+ Application
16
+ -----------
17
+
18
+ The pretrained Geneformer model can be used directly for zero-shot learning, for example for in silico perturbation analysis, or by fine-tuning towards the relevant downstream task, such as gene or cell state classification.
19
+
20
+ Example applications demonstrated in `our manuscript <https://rdcu.be/ddrx0>`_ include:
21
+
22
+ | *Fine-tuning*:
23
+ | - transcription factor dosage sensitivity
24
+ | - chromatin dynamics (bivalently marked promoters)
25
+ | - transcription factor regulatory range
26
+ | - gene network centrality
27
+ | - transcription factor targets
28
+ | - cell type annotation
29
+ | - batch integration
30
+ | - cell state classification across differentiation
31
+ | - disease classification
32
+ | - in silico perturbation to determine disease-driving genes
33
+ | - in silico treatment to determine candidate therapeutic targets
34
+
35
+ | *Zero-shot learning*:
36
+ | - batch integration
37
+ | - gene context specificity
38
+ | - in silico reprogramming
39
+ | - in silico differentiation
40
+ | - in silico perturbation to determine impact on cell state
41
+ | - in silico perturbation to determine transcription factor targets
42
+ | - in silico perturbation to determine transcription factor cooperativity
43
+
44
+ Citations
45
+ ---------
46
+
47
+ | C V Theodoris #, L Xiao, A Chopra, M D Chaffin, Z R Al Sayed, M C Hill, H Mantineo, E Brydon, Z Zeng, X S Liu, P T Ellinor #. `Transfer learning enables predictions in network biology. <https://rdcu.be/ddrx0>`_ *Nature*, 31 May 2023. (# co-corresponding authors)
48
+
49
+ | H Chen \*, M S Venkatesh \*, J Gomez Ortega, S V Mahesh, T Nandi, R Madduri, K Pelka †, C V Theodoris † #. `Quantized multi-task learning for context-specific representations of gene network dynamics. <https://www.biorxiv.org/content/10.1101/2024.08.16.608180v1.full.pdf>`_ *bioRxiv*, 19 Aug 2024. (\* co-first authors, † co-senior authors, # corresponding author)
docs/source/api.rst ADDED
@@ -0,0 +1,51 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ API
2
+ ===
3
+
4
+ Tokenizer
5
+ ---------
6
+
7
+ .. toctree::
8
+ :maxdepth: 1
9
+
10
+ geneformer.tokenizer
11
+
12
+ Classifier
13
+ ----------
14
+
15
+ .. toctree::
16
+ :maxdepth: 1
17
+
18
+ geneformer.classifier
19
+
20
+ Multitask Classifier
21
+ --------------------
22
+
23
+ .. toctree::
24
+ :maxdepth: 1
25
+
26
+ geneformer.mtl_classifier
27
+
28
+ Embedding Extractor
29
+ -------------------
30
+
31
+ .. toctree::
32
+ :maxdepth: 1
33
+
34
+ geneformer.emb_extractor
35
+
36
+ In Silico Perturber
37
+ -------------------
38
+
39
+ .. toctree::
40
+ :maxdepth: 1
41
+
42
+ geneformer.in_silico_perturber
43
+
44
+
45
+ In Silico Perturber Stats
46
+ -------------------------
47
+
48
+ .. toctree::
49
+ :maxdepth: 1
50
+
51
+ geneformer.in_silico_perturber_stats
docs/source/conf.py ADDED
@@ -0,0 +1,80 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Configuration file for the Sphinx documentation builder.
2
+ #
3
+ # For the full list of built-in configuration values, see the documentation:
4
+ # https://www.sphinx-doc.org/en/master/usage/configuration.html
5
+
6
+ import pathlib
7
+ import re
8
+ import sys
9
+
10
+ from sphinx.ext import autodoc
11
+
12
+ sys.path.insert(0, pathlib.Path(__file__).parents[2].resolve().as_posix())
13
+
14
+
15
+ # -- Project information -----------------------------------------------------
16
+ # https://www.sphinx-doc.org/en/master/usage/configuration.html#project-information
17
+
18
+ project = "geneformer"
19
+ copyright = "2024, Christina Theodoris"
20
+ author = "Christina Theodoris"
21
+ release = "0.1.0"
22
+ repository_url = "https://huggingface.co/ctheodoris/Geneformer"
23
+
24
+ # -- General configuration ---------------------------------------------------
25
+ # https://www.sphinx-doc.org/en/master/usage/configuration.html#general-configuration
26
+
27
+ extensions = [
28
+ "sphinx.ext.autodoc",
29
+ "sphinx.ext.autosummary",
30
+ "nbsphinx",
31
+ "sphinx.ext.viewcode",
32
+ "sphinx.ext.doctest",
33
+ ]
34
+
35
+ templates_path = ["_templates"]
36
+ exclude_patterns = [
37
+ "**.ipynb_checkpoints",
38
+ ]
39
+ autoclass_content = "both"
40
+
41
+
42
+ class MockedClassDocumenter(autodoc.ClassDocumenter):
43
+ def add_line(self, line: str, source: str, *lineno: int) -> None:
44
+ if line == " Bases: :py:class:`object`":
45
+ return
46
+ super().add_line(line, source, *lineno)
47
+
48
+
49
+ autodoc.ClassDocumenter = MockedClassDocumenter
50
+ add_module_names = False
51
+
52
+
53
+ def process_signature(app, what, name, obj, options, signature, return_annotation):
54
+ # loop through each line in the docstring and replace path with
55
+ # the generic path text
56
+ signature = re.sub(r"PosixPath\(.*?\)", "FILEPATH", signature)
57
+ return (signature, None)
58
+
59
+
60
+ def setup(app):
61
+ app.connect("autodoc-process-signature", process_signature)
62
+
63
+
64
+ # -- Options for HTML output -------------------------------------------------
65
+ # https://www.sphinx-doc.org/en/master/usage/configuration.html#options-for-html-output
66
+
67
+ html_theme = "sphinx_rtd_theme"
68
+ html_show_sphinx = False
69
+ html_static_path = ["_static"]
70
+ html_logo = "_static/gf_logo.png"
71
+ html_theme_options = {
72
+ "collapse_navigation": False,
73
+ "sticky_navigation": True,
74
+ "navigation_depth": 3,
75
+ "logo_only": True,
76
+ }
77
+ html_css_files = [
78
+ "css/custom.css",
79
+ ]
80
+ html_show_sourcelink = False
docs/source/geneformer.classifier.rst ADDED
@@ -0,0 +1,10 @@
 
 
 
 
 
 
 
 
 
 
 
1
+ geneformer.classifier
2
+ =====================
3
+
4
+ .. automodule:: geneformer.classifier
5
+ :members:
6
+ :undoc-members:
7
+ :show-inheritance:
8
+ :exclude-members:
9
+ valid_option_dict,
10
+ validate_options
docs/source/geneformer.emb_extractor.rst ADDED
@@ -0,0 +1,26 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ geneformer.emb\_extractor
2
+ =========================
3
+
4
+ .. automodule:: geneformer.emb_extractor
5
+ :members:
6
+ :undoc-members:
7
+ :show-inheritance:
8
+ :exclude-members:
9
+ accumulate_tdigests,
10
+ gen_heatmap_class_colors,
11
+ gen_heatmap_class_dict,
12
+ get_embs,
13
+ label_cell_embs,
14
+ label_gene_embs,
15
+ make_colorbar,
16
+ plot_heatmap,
17
+ plot_umap,
18
+ summarize_gene_embs,
19
+ tdigest_mean,
20
+ tdigest_median,
21
+ test_emb,
22
+ update_tdigest_dict,
23
+ update_tdigest_dict_mean,
24
+ update_tdigest_dict_median,
25
+ valid_option_dict,
26
+ validate_options
docs/source/geneformer.in_silico_perturber.rst ADDED
@@ -0,0 +1,8 @@
 
 
 
 
 
 
 
 
 
1
+ geneformer.in\_silico\_perturber
2
+ =======================================
3
+
4
+ .. automodule:: geneformer.in_silico_perturber
5
+ :members:
6
+ :undoc-members:
7
+ :show-inheritance:
8
+ :exclude-members: valid_option_dict, validate_options, apply_additional_filters, isp_perturb_all, isp_perturb_set, update_perturbation_dictionary
docs/source/geneformer.in_silico_perturber_stats.rst ADDED
@@ -0,0 +1,25 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ geneformer.in\_silico\_perturber\_stats
2
+ ==============================================
3
+
4
+ .. automodule:: geneformer.in_silico_perturber_stats
5
+ :members:
6
+ :undoc-members:
7
+ :show-inheritance:
8
+ :exclude-members:
9
+ find,
10
+ get_fdr,
11
+ get_gene_list,
12
+ get_impact_component,
13
+ invert_dict,
14
+ isp_aggregate_gene_shifts,
15
+ isp_aggregate_grouped_perturb,
16
+ isp_stats_mixture_model,
17
+ isp_stats_to_goal_state,
18
+ isp_stats_vs_null,
19
+ n_detections,
20
+ read_dict,
21
+ read_dictionaries,
22
+ token_to_gene_name,
23
+ token_tuple_to_ensembl_ids,
24
+ valid_option_dict,
25
+ validate_options
docs/source/geneformer.mtl_classifier.rst ADDED
@@ -0,0 +1,11 @@
 
 
 
 
 
 
 
 
 
 
 
 
1
+ geneformer.mtl\_classifier
2
+ ==========================
3
+
4
+ .. automodule:: geneformer.mtl_classifier
5
+ :members:
6
+ :undoc-members:
7
+ :show-inheritance:
8
+ :exclude-members:
9
+ valid_option_dict,
10
+ validate_options,
11
+ validate_additional_options
docs/source/geneformer.tokenizer.rst ADDED
@@ -0,0 +1,15 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ geneformer.tokenizer
2
+ ====================
3
+
4
+ .. automodule:: geneformer.tokenizer
5
+ :members:
6
+ :undoc-members:
7
+ :show-inheritance:
8
+ :exclude-members:
9
+ create_dataset,
10
+ tokenize_anndata,
11
+ tokenize_files,
12
+ tokenize_loom,
13
+ rank_genes,
14
+ tokenize_cell,
15
+ sum_ensembl_ids
docs/source/getstarted.rst ADDED
@@ -0,0 +1,36 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ Getting Started
2
+ ===============
3
+
4
+ Installation
5
+ ------------
6
+
7
+ Geneformer installation instructions.
8
+
9
+ Make sure you have git-lfs installed (https://git-lfs.com).
10
+
11
+ .. code-block:: bash
12
+
13
+ git lfs install
14
+ git clone https://huggingface.co/ctheodoris/Geneformer
15
+ cd Geneformer
16
+ pip install .
17
+
18
+
19
+ Tutorials
20
+ ---------
21
+
22
+ | See `examples <https://huggingface.co/ctheodoris/Geneformer/tree/main/examples>`_ for:
23
+ | - tokenizing transcriptomes
24
+ | - pretraining
25
+ | - hyperparameter tuning
26
+ | - fine-tuning
27
+ | - extracting and plotting cell embeddings
28
+ | - in silico perturbation
29
+
30
+ Please note that the fine-tuning examples are meant to be generally applicable and the input datasets and labels will vary dependent on the downstream task. Example input files for a few of the downstream tasks demonstrated in the manuscript are located within the `example_input_files directory <https://huggingface.co/datasets/ctheodoris/Genecorpus-30M/tree/main/example_input_files>`_ in the dataset repository, but these only represent a few example fine-tuning applications.
31
+
32
+
33
+ Tips
34
+ ----
35
+
36
+ Please note that GPU resources are required for efficient usage of Geneformer. Additionally, we strongly recommend tuning hyperparameters for each downstream fine-tuning application as this can significantly boost predictive potential in the downstream task (e.g. max learning rate, learning schedule, number of layers to freeze, etc.).
docs/source/index.rst ADDED
@@ -0,0 +1,16 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ Geneformer
2
+ ==========
3
+
4
+ Geneformer is a foundation transformer model pretrained on a large-scale corpus of single cell transcriptomes to enable context-aware predictions in network biology.
5
+
6
+ See `our manuscript <https://rdcu.be/ddrx0>`_ for details.
7
+
8
+ Table of Contents
9
+ -----------------
10
+
11
+ .. toctree::
12
+ :maxdepth: 2
13
+
14
+ about
15
+ getstarted
16
+ api
examples/cell_classification.ipynb ADDED
The diff for this file is too large to render. See raw diff
 
examples/extract_and_plot_cell_embeddings.ipynb ADDED
The diff for this file is too large to render. See raw diff
 
examples/gene_classification.ipynb ADDED
The diff for this file is too large to render. See raw diff
 
examples/in_silico_perturbation.ipynb ADDED
@@ -0,0 +1,159 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ {
2
+ "cells": [
3
+ {
4
+ "cell_type": "code",
5
+ "execution_count": null,
6
+ "id": "e10ac0c9-40ce-41fb-b6fa-3d62b76f2e57",
7
+ "metadata": {},
8
+ "outputs": [],
9
+ "source": [
10
+ "from geneformer import InSilicoPerturber\n",
11
+ "from geneformer import InSilicoPerturberStats\n",
12
+ "from geneformer import EmbExtractor"
13
+ ]
14
+ },
15
+ {
16
+ "cell_type": "markdown",
17
+ "id": "cbd6851c-060e-4967-b816-e605ffe58b23",
18
+ "metadata": {
19
+ "tags": []
20
+ },
21
+ "source": [
22
+ "### in silico perturbation in deletion mode to determine genes whose deletion in the dilated cardiomyopathy (dcm) state significantly shifts the embedding towards non-failing (nf) state"
23
+ ]
24
+ },
25
+ {
26
+ "cell_type": "code",
27
+ "execution_count": null,
28
+ "id": "c53e98cd-c603-4878-82ba-db471181bb55",
29
+ "metadata": {},
30
+ "outputs": [],
31
+ "source": [
32
+ "# first obtain start, goal, and alt embedding positions\n",
33
+ "# this function was changed to be separate from perturb_data\n",
34
+ "# to avoid repeating calcuations when parallelizing perturb_data\n",
35
+ "cell_states_to_model={\"state_key\": \"disease\", \n",
36
+ " \"start_state\": \"dcm\", \n",
37
+ " \"goal_state\": \"nf\", \n",
38
+ " \"alt_states\": [\"hcm\"]}\n",
39
+ "\n",
40
+ "filter_data_dict={\"cell_type\":[\"Cardiomyocyte1\",\"Cardiomyocyte2\",\"Cardiomyocyte3\"]}\n",
41
+ "\n",
42
+ "# OF NOTE: token_dictionary_file must be set to the gc-30M token dictionary if using a 30M series model\n",
43
+ "# (otherwise the EmbExtractor will use the current default model dictionary)\n",
44
+ "# 30M token dictionary: https://huggingface.co/ctheodoris/Geneformer/blob/main/geneformer/gene_dictionaries_30m/token_dictionary_gc30M.pkl\n",
45
+ "embex = EmbExtractor(model_type=\"CellClassifier\", # if using previously fine-tuned cell classifier model\n",
46
+ " num_classes=3,\n",
47
+ " filter_data=filter_data_dict,\n",
48
+ " max_ncells=1000,\n",
49
+ " emb_layer=0,\n",
50
+ " summary_stat=\"exact_mean\",\n",
51
+ " forward_batch_size=256,\n",
52
+ " nproc=16)\n",
53
+ "\n",
54
+ "state_embs_dict = embex.get_state_embs(cell_states_to_model,\n",
55
+ " \"../fine_tuned_models/gf-6L-30M-i2048_CellClassifier_cardiomyopathies_220224\", # example 30M fine-tuned model\n",
56
+ " \"path/to/input_data\",\n",
57
+ " \"path/to/output_directory\",\n",
58
+ " \"output_prefix\")"
59
+ ]
60
+ },
61
+ {
62
+ "cell_type": "code",
63
+ "execution_count": null,
64
+ "id": "981e1190-62da-4543-b7d3-6e2a2d6a6d56",
65
+ "metadata": {
66
+ "tags": []
67
+ },
68
+ "outputs": [],
69
+ "source": [
70
+ "# OF NOTE: token_dictionary_file must be set to the gc-30M token dictionary if using a 30M series model\n",
71
+ "# (otherwise the InSilicoPerturber will use the current default model dictionary)\n",
72
+ "# 30M token dictionary: https://huggingface.co/ctheodoris/Geneformer/blob/main/geneformer/gene_dictionaries_30m/token_dictionary_gc30M.pkl\n",
73
+ "isp = InSilicoPerturber(perturb_type=\"delete\",\n",
74
+ " perturb_rank_shift=None,\n",
75
+ " genes_to_perturb=\"all\",\n",
76
+ " combos=0,\n",
77
+ " anchor_gene=None,\n",
78
+ " model_type=\"CellClassifier\", # if using previously fine-tuned cell classifier model\n",
79
+ " num_classes=3,\n",
80
+ " emb_mode=\"cell\",\n",
81
+ " cell_emb_style=\"mean_pool\",\n",
82
+ " filter_data=filter_data_dict,\n",
83
+ " cell_states_to_model=cell_states_to_model,\n",
84
+ " state_embs_dict=state_embs_dict,\n",
85
+ " max_ncells=2000,\n",
86
+ " emb_layer=0,\n",
87
+ " forward_batch_size=400,\n",
88
+ " nproc=16)"
89
+ ]
90
+ },
91
+ {
92
+ "cell_type": "code",
93
+ "execution_count": null,
94
+ "id": "0525a663-871a-4ce0-a135-cc203817ffa9",
95
+ "metadata": {},
96
+ "outputs": [],
97
+ "source": [
98
+ "# outputs intermediate files from in silico perturbation\n",
99
+ "\n",
100
+ "isp.perturb_data(\"../fine_tuned_models/gf-6L-30M-i2048_CellClassifier_cardiomyopathies_220224\", # example 30M fine-tuned model\n",
101
+ " \"path/to/input_data\",\n",
102
+ " \"path/to/isp_output_directory\",\n",
103
+ " \"output_prefix\")"
104
+ ]
105
+ },
106
+ {
107
+ "cell_type": "code",
108
+ "execution_count": null,
109
+ "id": "f8aadabb-516a-4dc0-b307-6de880e64e26",
110
+ "metadata": {},
111
+ "outputs": [],
112
+ "source": [
113
+ "# OF NOTE: token_dictionary_file must be set to the gc-30M token dictionary if using a 30M series model\n",
114
+ "# (otherwise the InSilicoPerturberStats will use the current default model dictionary)\n",
115
+ "# 30M token dictionary: https://huggingface.co/ctheodoris/Geneformer/blob/main/geneformer/gene_dictionaries_30m/token_dictionary_gc30M.pkl\n",
116
+ "ispstats = InSilicoPerturberStats(mode=\"goal_state_shift\",\n",
117
+ " genes_perturbed=\"all\",\n",
118
+ " combos=0,\n",
119
+ " anchor_gene=None,\n",
120
+ " cell_states_to_model=cell_states_to_model)"
121
+ ]
122
+ },
123
+ {
124
+ "cell_type": "code",
125
+ "execution_count": null,
126
+ "id": "ffecfae6-e737-43e3-99e9-fa37ff46610b",
127
+ "metadata": {},
128
+ "outputs": [],
129
+ "source": [
130
+ "# extracts data from intermediate files and processes stats to output in final .csv\n",
131
+ "ispstats.get_stats(\"path/to/isp_output_directory\", # this should be the directory \n",
132
+ " None,\n",
133
+ " \"path/to/isp_stats_output_directory\",\n",
134
+ " \"output_prefix\")"
135
+ ]
136
+ }
137
+ ],
138
+ "metadata": {
139
+ "kernelspec": {
140
+ "display_name": "Python 3 (ipykernel)",
141
+ "language": "python",
142
+ "name": "python3"
143
+ },
144
+ "language_info": {
145
+ "codemirror_mode": {
146
+ "name": "ipython",
147
+ "version": 3
148
+ },
149
+ "file_extension": ".py",
150
+ "mimetype": "text/x-python",
151
+ "name": "python",
152
+ "nbconvert_exporter": "python",
153
+ "pygments_lexer": "ipython3",
154
+ "version": "3.10.15"
155
+ }
156
+ },
157
+ "nbformat": 4,
158
+ "nbformat_minor": 5
159
+ }
examples/multitask_cell_classification.ipynb ADDED
@@ -0,0 +1,420 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ {
2
+ "cells": [
3
+ {
4
+ "cell_type": "markdown",
5
+ "id": "866f100c-e11a-4e7b-a37c-831775d845a7",
6
+ "metadata": {},
7
+ "source": [
8
+ "# Geneformer Multi-Task Cell Classifier Tutorial\n",
9
+ "\n",
10
+ "This tutorial demonstrates how to use the Geneformer Multi-Task Cell Classifier and optimizatize hyperparameter for fine-tuning"
11
+ ]
12
+ },
13
+ {
14
+ "cell_type": "markdown",
15
+ "id": "311ba456-b44d-40c7-941d-3fc03bcda85a",
16
+ "metadata": {},
17
+ "source": [
18
+ "## 1. Installation and Imports\n",
19
+ "\n",
20
+ "First import the necessary modules."
21
+ ]
22
+ },
23
+ {
24
+ "cell_type": "code",
25
+ "execution_count": 3,
26
+ "id": "cd9defdc-0524-4c3b-a741-27117ed3a5be",
27
+ "metadata": {},
28
+ "outputs": [],
29
+ "source": [
30
+ "from geneformer import MTLClassifier"
31
+ ]
32
+ },
33
+ {
34
+ "cell_type": "markdown",
35
+ "id": "790e9c3c-f6d9-44b3-b9a5-05725760f4fd",
36
+ "metadata": {},
37
+ "source": [
38
+ "## 2. Set up Paths and Parameters\n",
39
+ "\n",
40
+ "Now, let's set up the necessary paths and parameters for our classifier. We'll also define our task columns, which are specific columns from our dataset that represent the classification tasks we want to train the model on."
41
+ ]
42
+ },
43
+ {
44
+ "cell_type": "code",
45
+ "execution_count": null,
46
+ "id": "04a04197-8e45-47f8-a86f-202209ea10ae",
47
+ "metadata": {},
48
+ "outputs": [],
49
+ "source": [
50
+ "# Define paths\n",
51
+ "pretrained_path = \"/path/to/pretrained/Geneformer/model\" \n",
52
+ "# input data is tokenized rank value encodings generated by Geneformer tokenizer (see tokenizing_scRNAseq_data.ipynb)\n",
53
+ "train_path = \"/path/to/train/data.dataset\"\n",
54
+ "val_path = \"/path/to/val/data.dataset\"\n",
55
+ "test_path = \"/path/to/test/data.dataset\"\n",
56
+ "results_dir = \"/path/to/results/directory\"\n",
57
+ "model_save_path = \"/path/to/model/save/path\"\n",
58
+ "tensorboard_log_dir = \"/path/to/tensorboard/log/dir\"\n",
59
+ "\n",
60
+ "# Define tasks and hyperparameters\n",
61
+ "# task_columns should be a list of column names from your dataset\n",
62
+ "# Each column represents a specific classification task (e.g. cell type, disease state)\n",
63
+ "task_columns = [\"cell_type\", \"disease_state\"] # Example task columns\n",
64
+ "\n",
65
+ "hyperparameters = {\n",
66
+ " \"learning_rate\": {\"type\": \"float\", \"low\": 1e-5, \"high\": 1e-3, \"log\": True},\n",
67
+ " \"warmup_ratio\": {\"type\": \"float\", \"low\": 0.005, \"high\": 0.01},\n",
68
+ " \"weight_decay\": {\"type\": \"float\", \"low\": 0.01, \"high\": 0.1},\n",
69
+ " \"dropout_rate\": {\"type\": \"float\", \"low\": 0.0, \"high\": 0.7},\n",
70
+ " \"lr_scheduler_type\": {\"type\": \"categorical\", \"choices\": [\"cosine\"]},\n",
71
+ " \"task_weights\": {\"type\": \"float\", \"low\": 0.1, \"high\": 2.0}\n",
72
+ "}"
73
+ ]
74
+ },
75
+ {
76
+ "cell_type": "markdown",
77
+ "id": "31857690-a739-435a-aefd-f171fafc1b78",
78
+ "metadata": {},
79
+ "source": [
80
+ "In the code above, we've defined `task_columns` as `[\"cell_type\", \"disease_state\"]`. This means our model will be trained to classify cells based on two tasks:\n",
81
+ "1. Identifying the cell type\n",
82
+ "2. Determining the disease state\n",
83
+ "3. Note: \"unique_cell_id\" is a required column in the dataset for logging and inference purposes\n",
84
+ "\n",
85
+ "These column names should correspond to actual columns in your dataset. Each column should contain the labels for that specific classification task.\n",
86
+ "\n",
87
+ "For example, your dataset might look something like this:\n",
88
+ "\n",
89
+ " | unique_cell_id | input_ids | ... | cell_type | disease_state |\n",
90
+ " |----------------|-----------|-----|-----------|---------------|\n",
91
+ " | cell1 | ... | ... | neuron | healthy |\n",
92
+ " | cell2 | ... | ... | astrocyte | diseased |\n",
93
+ " | ... | ... | ... | ... | ... |\n",
94
+ "The model will learn to predict classes within 'cell_type' and 'disease_state' "
95
+ ]
96
+ },
97
+ {
98
+ "cell_type": "markdown",
99
+ "id": "b9e3050a-6162-4c01-b6fd-8784bf4ab1e4",
100
+ "metadata": {},
101
+ "source": [
102
+ "## 3. Initialize the MTLClassifier\n",
103
+ "\n",
104
+ "Now, let's create an instance of the MTLClassifier with our defined parameters and task columns."
105
+ ]
106
+ },
107
+ {
108
+ "cell_type": "code",
109
+ "execution_count": null,
110
+ "id": "e27caac9-670c-409d-9313-50201c665cb9",
111
+ "metadata": {},
112
+ "outputs": [],
113
+ "source": [
114
+ "mc = MTLClassifier(\n",
115
+ " task_columns=task_columns, # Our defined classification tasks\n",
116
+ " study_name=\"MTLClassifier_example\",\n",
117
+ " pretrained_path=pretrained_path,\n",
118
+ " train_path=train_path,\n",
119
+ " val_path=val_path,\n",
120
+ " test_path=test_path,\n",
121
+ " model_save_path=model_save_path,\n",
122
+ " results_dir=results_dir,\n",
123
+ " tensorboard_log_dir=tensorboard_log_dir,\n",
124
+ " hyperparameters=hyperparameters,\n",
125
+ " n_trials=15, # Number of trials for hyperparameter optimization (at least 50 suggested)\n",
126
+ " epochs=1, # Number of training epochs (1 suggested to prevent overfitting)\n",
127
+ " batch_size=8, # Adjust based on available GPU memory\n",
128
+ " seed=42\n",
129
+ ")"
130
+ ]
131
+ },
132
+ {
133
+ "cell_type": "markdown",
134
+ "id": "0d729444-e3ad-4584-9659-0c464ac97462",
135
+ "metadata": {},
136
+ "source": [
137
+ "## 4. Run Hyperparameter Optimization\n",
138
+ "\n",
139
+ "Now, let's run the Optuna study to optimize our hyperparameters for both classification tasks."
140
+ ]
141
+ },
142
+ {
143
+ "cell_type": "code",
144
+ "execution_count": null,
145
+ "id": "9298aa3e-6a52-4aa8-b9ff-b63d97beac93",
146
+ "metadata": {},
147
+ "outputs": [],
148
+ "source": [
149
+ "mc.run_optuna_study()"
150
+ ]
151
+ },
152
+ {
153
+ "cell_type": "markdown",
154
+ "id": "af23075d-d07b-43d3-bc5d-4df4d5d7199b",
155
+ "metadata": {},
156
+ "source": [
157
+ "## 5. Evaluate the Model on Test Data\n",
158
+ "\n",
159
+ "After optimization, we can evaluate our model on the test dataset. This will provide performance metrics for both classification tasks. CSV containing following keys will be generated in specified results directiory \"Cell ID, task(1...n) True,task(1.,.n) Pred,task(1...n) Probabilities\""
160
+ ]
161
+ },
162
+ {
163
+ "cell_type": "code",
164
+ "execution_count": null,
165
+ "id": "461bf8d3-b964-4ff4-994f-9f3d313d4614",
166
+ "metadata": {},
167
+ "outputs": [],
168
+ "source": [
169
+ "mc.load_and_evaluate_test_model()"
170
+ ]
171
+ },
172
+ {
173
+ "cell_type": "markdown",
174
+ "id": "31cfeb2d-6673-4b02-a79c-2533cc5e4d28",
175
+ "metadata": {},
176
+ "source": [
177
+ "## 6. (Optional) Manual Hyperparameter Tuning\n",
178
+ "\n",
179
+ "If you prefer to set hyperparameters manually, you can use the following approach:"
180
+ ]
181
+ },
182
+ {
183
+ "cell_type": "code",
184
+ "execution_count": null,
185
+ "id": "8ee6b99f-42e9-4abf-a292-aa9047735e0e",
186
+ "metadata": {},
187
+ "outputs": [],
188
+ "source": [
189
+ "manual_hyperparameters = {\n",
190
+ " \"learning_rate\": 0.001,\n",
191
+ " \"warmup_ratio\": 0.01,\n",
192
+ " \"weight_decay\": 0.1,\n",
193
+ " \"dropout_rate\": 0.1,\n",
194
+ " \"lr_scheduler_type\": \"cosine\",\n",
195
+ " \"task_weights\": [1, 1], # Weights for each task (cell_type, disease_state)\n",
196
+ " \"max_layers_to_freeze\": 2\n",
197
+ "}\n",
198
+ "\n",
199
+ "mc_manual = MTLClassifier(\n",
200
+ " task_columns=task_columns,\n",
201
+ " study_name=\"mtl_manual\",\n",
202
+ " pretrained_path=pretrained_path,\n",
203
+ " train_path=train_path,\n",
204
+ " val_path=val_path,\n",
205
+ " test_path=test_path,\n",
206
+ " model_save_path=model_save_path,\n",
207
+ " results_dir=results_dir,\n",
208
+ " tensorboard_log_dir=tensorboard_log_dir,\n",
209
+ " manual_hyperparameters=manual_hyperparameters,\n",
210
+ " use_manual_hyperparameters=True,\n",
211
+ " epochs=10,\n",
212
+ " batch_size=32,\n",
213
+ " seed=42\n",
214
+ ")\n",
215
+ "\n",
216
+ "mc_manual.run_manual_tuning()"
217
+ ]
218
+ },
219
+ {
220
+ "cell_type": "markdown",
221
+ "id": "dbaac008-fc00-4b71-8e78-89b2d922d9d8",
222
+ "metadata": {},
223
+ "source": [
224
+ "# Geneformer In Silico Perturber Tutorial (MTL Quantized)\n",
225
+ "This demonstrates how to use the Geneformer In Silico Perturber with a Multi-Task Learning (MTL) model in a quantized configuration to optimize runtime and memory."
226
+ ]
227
+ },
228
+ {
229
+ "cell_type": "code",
230
+ "execution_count": null,
231
+ "id": "2e15ad57-736c-48f0-be87-39cf5015bc5c",
232
+ "metadata": {},
233
+ "outputs": [],
234
+ "source": [
235
+ "from geneformer import InSilicoPerturber, EmbExtractor, InSilicoPerturberStats"
236
+ ]
237
+ },
238
+ {
239
+ "cell_type": "code",
240
+ "execution_count": null,
241
+ "id": "43c18140-151e-4d44-95b4-a9b3a47172cf",
242
+ "metadata": {},
243
+ "outputs": [],
244
+ "source": [
245
+ "# Define paths\n",
246
+ "model_directory = \"/path/to/model/save/path\"\n",
247
+ "input_data_file = \"/path/to/input/data.dataset\"\n",
248
+ "output_directory = \"/path/to/output/directory\"\n",
249
+ "output_prefix = \"mtl_quantized_perturbation\"\n",
250
+ "\n",
251
+ "# Define parameters\n",
252
+ "perturb_type = \"delete\" # or \"overexpress\"\n",
253
+ "\n",
254
+ "# Define cell states to model\n",
255
+ "cell_states_to_model = {\n",
256
+ " \"state_key\": \"disease_state\", \n",
257
+ " \"start_state\": \"disease\", \n",
258
+ " \"goal_state\": \"control\"\n",
259
+ "}\n",
260
+ "\n",
261
+ "# Define filter data\n",
262
+ "filter_data_dict = {\n",
263
+ " \"cell_type\": [\"Fibroblast\"]\n",
264
+ "}"
265
+ ]
266
+ },
267
+ {
268
+ "cell_type": "markdown",
269
+ "id": "3010d0bf-b23c-45c1-ac12-8c472dc8b7a1",
270
+ "metadata": {},
271
+ "source": [
272
+ "## 3. Extract State Embeddings\n",
273
+ "\n",
274
+ "Before we initialize the InSilicoPerturber, we need to extract the state embeddings using the EmbExtractor."
275
+ ]
276
+ },
277
+ {
278
+ "cell_type": "code",
279
+ "execution_count": null,
280
+ "id": "215f0a90-8041-417d-a5d3-b2483626c3b2",
281
+ "metadata": {},
282
+ "outputs": [],
283
+ "source": [
284
+ "# Initialize EmbExtractor\n",
285
+ "embex = EmbExtractor(\n",
286
+ " filter_data_dict=filter_data_dict,\n",
287
+ " max_ncells=1000, # Number of cells to extract embeddings for\n",
288
+ " emb_layer=0, # Use the second to last layer\n",
289
+ " emb_mode = \"cls\",\n",
290
+ " summary_stat=\"exact_mean\",\n",
291
+ " forward_batch_size=8, # Adjust based on available GPU memory\n",
292
+ " nproc=4\n",
293
+ ")\n",
294
+ "\n",
295
+ "# Extract state embeddings\n",
296
+ "state_embs_dict = embex.get_state_embs(\n",
297
+ " cell_states_to_model,\n",
298
+ " model_directory=model_directory,\n",
299
+ " input_data_file=input_data_file,\n",
300
+ " output_directory=output_directory,\n",
301
+ " output_prefix=output_prefix\n",
302
+ ")"
303
+ ]
304
+ },
305
+ {
306
+ "cell_type": "markdown",
307
+ "id": "23f14e36-4529-4fb2-8af9-7f4875cf81e3",
308
+ "metadata": {},
309
+ "source": [
310
+ "## 4. Initialize the InSilicoPerturber\n",
311
+ "\n",
312
+ "Now that we have our state embeddings, let's create an instance of the InSilicoPerturber with MTL and quantized configurations."
313
+ ]
314
+ },
315
+ {
316
+ "cell_type": "code",
317
+ "execution_count": null,
318
+ "id": "09f985a1-91bc-4e8d-8001-a3663531b570",
319
+ "metadata": {},
320
+ "outputs": [],
321
+ "source": [
322
+ "# Initialize InSilicoPerturber\n",
323
+ "isp = InSilicoPerturber(\n",
324
+ " perturb_type=perturb_type,\n",
325
+ " genes_to_perturb=\"all\", # Perturb all genes\n",
326
+ " model_type=\"MTLCellClassifier-Quantized\", # Use quantized MTL model\n",
327
+ " emb_mode=\"cls\", # Use CLS token embedding\n",
328
+ " cell_states_to_model=cell_states_to_model,\n",
329
+ " state_embs_dict=state_embs_dict,\n",
330
+ " max_ncells=1000, # Number of cells to perturb (larger number increases power)\n",
331
+ " emb_layer=0, \n",
332
+ " forward_batch_size=8, # Adjust based on available GPU memory\n",
333
+ " nproc=1\n",
334
+ ")"
335
+ ]
336
+ },
337
+ {
338
+ "cell_type": "markdown",
339
+ "id": "cfcc2c1e-fd7f-4a36-99fc-ac7f43e5be6b",
340
+ "metadata": {},
341
+ "source": [
342
+ "## 5. Run In Silico Perturbation\n",
343
+ "\n",
344
+ "Run the in silico perturbation on the dataset."
345
+ ]
346
+ },
347
+ {
348
+ "cell_type": "code",
349
+ "execution_count": null,
350
+ "id": "cf030c09-8ae4-45a7-aaf7-3fc2af4fe296",
351
+ "metadata": {},
352
+ "outputs": [],
353
+ "source": [
354
+ "# Run perturbation and output intermediate files\n",
355
+ "isp.perturb_data(\n",
356
+ " model_directory=model_directory,\n",
357
+ " input_data_file=input_data_file,\n",
358
+ " output_directory=output_directory,\n",
359
+ " output_prefix=output_prefix\n",
360
+ ")"
361
+ ]
362
+ },
363
+ {
364
+ "cell_type": "markdown",
365
+ "id": "bb8ec074-6f2f-422b-a973-37ed32a15c38",
366
+ "metadata": {},
367
+ "source": [
368
+ "## 6. Process Results with InSilicoPerturberStats\n",
369
+ "\n",
370
+ "After running the perturbation, we'll use InSilicoPerturberStats to process the intermediate files and generate the final statistics."
371
+ ]
372
+ },
373
+ {
374
+ "cell_type": "code",
375
+ "execution_count": null,
376
+ "id": "0a748043-43fc-47ad-ace5-f0ae3dd34674",
377
+ "metadata": {},
378
+ "outputs": [],
379
+ "source": [
380
+ "# Initialize InSilicoPerturberStats\n",
381
+ "ispstats = InSilicoPerturberStats(\n",
382
+ " mode=\"goal_state_shift\",\n",
383
+ " genes_perturbed=\"all\",\n",
384
+ " combos=0,\n",
385
+ " anchor_gene=None,\n",
386
+ " cell_states_to_model=cell_states_to_model\n",
387
+ ")\n",
388
+ "\n",
389
+ "# Process stats and output final .csv\n",
390
+ "ispstats.get_stats(\n",
391
+ " input_data_file,\n",
392
+ " None,\n",
393
+ " output_directory,\n",
394
+ " output_prefix\n",
395
+ ")"
396
+ ]
397
+ }
398
+ ],
399
+ "metadata": {
400
+ "kernelspec": {
401
+ "display_name": "Python 3 (ipykernel)",
402
+ "language": "python",
403
+ "name": "python3"
404
+ },
405
+ "language_info": {
406
+ "codemirror_mode": {
407
+ "name": "ipython",
408
+ "version": 3
409
+ },
410
+ "file_extension": ".py",
411
+ "mimetype": "text/x-python",
412
+ "name": "python",
413
+ "nbconvert_exporter": "python",
414
+ "pygments_lexer": "ipython3",
415
+ "version": "3.11.5"
416
+ }
417
+ },
418
+ "nbformat": 4,
419
+ "nbformat_minor": 5
420
+ }
examples/pretraining_new_model/obtain_nonzero_median_digests.ipynb ADDED
@@ -0,0 +1,365 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ {
2
+ "cells": [
3
+ {
4
+ "cell_type": "markdown",
5
+ "id": "charged-worcester",
6
+ "metadata": {},
7
+ "source": [
8
+ "# Obtain non-zero median expression value of each gene across Genecorpus-30M"
9
+ ]
10
+ },
11
+ {
12
+ "cell_type": "markdown",
13
+ "id": "28e87f2a-a33e-4fe3-81af-ad4cd62fcc1b",
14
+ "metadata": {},
15
+ "source": [
16
+ "#### Upon request, we are providing the code that we used for obtaining the non-zero median expression value of each gene across the broad range of cell types represented in Genecorpus-30M that we use as a normalization factor to prioritize genes that uniquely distinguish cell state.\n",
17
+ "\n",
18
+ "#### Please read the important information below before using this code.\n",
19
+ "\n",
20
+ "#### If using Geneformer, to ensure consistency of the normalization factor used for each gene for all future datasets, <ins>**users should use the Geneformer transcriptome tokenizer to tokenize their datasets and should not re-calculate this normalization factor for their individual dataset** </ins>. This code for re-calculating the normalization factor should only be used by users who are pretraining a new model from scratch with a new pretraining corpus other than Genecorpus-30M.\n",
21
+ "\n",
22
+ "#### It is critical that this calculation is performed on a large-scale pretraining corpus that has tens of millions of cells from a broad range of human tissues. <ins>**The richness of variable cell states in the pretraining corpus is what allows this normalization factor to accomplish the goal of prioritizing genes that uniquely distinguish cell states.** </ins> This normalization factor for each gene is calculated once from the large-scale pretraining corpus and is used for all future datasets presented to the model. \n",
23
+ "\n",
24
+ "#### Of note, as discussed in the Methods, we only included droplet-based sequencing platforms in the pretraining corpus to assure expression value unit comparability for the calculation of this normalization factor. Users wishing to pretrain a new model from scratch with a new pretraining corpus should choose either droplet-based or plate-based platforms for calculating this normalization factor, or they should exercise caution that including both platforms may cause unintended effects on the results. Once the normalization factor is calculated however, data from any platform can be used with the model because the expression value units will be consistent within each individual cell.\n",
25
+ "\n",
26
+ "#### Please see the Methods in the manuscript for a description of the procedure enacted by this code, an excerpt of which is below for convenience:\n",
27
+ "\n",
28
+ "#### \"To accomplish this, we first calculated the non-zero median value of expression of each detected gene across all cells passing quality filtering from the entire Genecorpus-30M. We aggregated the transcript count distribution for each gene in a memory-efficient manner by scanning through chunks of .loom data using loompy, normalizing the gene transcript counts in each cell by the total transcript count of that cell to account for varying sequencing depth and updating the normalized count distribution of the gene within the t-digest data structure developed for accurate online accumulation of rank-based statistics. We then normalized the genes in each single-cell transcriptome by the non-zero median value of expression of that gene across Genecorpus-30M and ordered the genes by the rank of their normalized expression in that specific cell. Of note, we opted to use the non-zero median value of expression rather than include zeros in the distribution so as not to weight the value by tissue representation within Genecorpus-30M, assuming that a representative range of transcript values would be observed within the cells in which each gene was detected. This normalization factor for each gene is calculated once from the pretraining corpus and is used for all future datasets presented to the model. The provided tokenizer code includes this normalization procedure and should be used for tokenizing new datasets presented to Geneformer to ensure consistency of the normalization factor used for each gene.\""
29
+ ]
30
+ },
31
+ {
32
+ "cell_type": "code",
33
+ "execution_count": 1,
34
+ "id": "textile-destruction",
35
+ "metadata": {},
36
+ "outputs": [],
37
+ "source": [
38
+ "import os\n",
39
+ "import numpy as np\n",
40
+ "import loompy as lp\n",
41
+ "import pandas as pd\n",
42
+ "import crick\n",
43
+ "import pickle\n",
44
+ "import math\n",
45
+ "from tqdm.notebook import tqdm"
46
+ ]
47
+ },
48
+ {
49
+ "cell_type": "markdown",
50
+ "id": "4af8cfef-05f2-47e0-b8d2-71ca025059c7",
51
+ "metadata": {
52
+ "tags": []
53
+ },
54
+ "source": [
55
+ "### The following code is an example of how the nonzero median expression values are obtained for a single input file. This calculation should be run as a script to be parallelized for all dataset files."
56
+ ]
57
+ },
58
+ {
59
+ "cell_type": "code",
60
+ "execution_count": 30,
61
+ "id": "physical-intro",
62
+ "metadata": {},
63
+ "outputs": [],
64
+ "source": [
65
+ "input_file = \"study1.loom\"\n",
66
+ "current_database = \"database1\"\n",
67
+ "\n",
68
+ "rootdir = f\"/path/to/{current_database}/data/\"\n",
69
+ "output_file = input_file.replace(\".loom\", \".gene_median_digest_dict.pickle\")\n",
70
+ "outdir = rootdir.replace(\"/data/\", \"/tdigest/\")\n",
71
+ "\n",
72
+ "with lp.connect(f\"{rootdir}{input_file}\") as data:\n",
73
+ " # define coordinates of protein-coding or miRNA genes\n",
74
+ " coding_miRNA_loc = np.where((data.ra.gene_type == \"protein_coding\") | (data.ra.gene_type == \"miRNA\"))[0]\n",
75
+ " coding_miRNA_genes = data.ra[\"ensembl_id\"][coding_miRNA_loc]\n",
76
+ " \n",
77
+ " # initiate tdigests\n",
78
+ " median_digests = [crick.tdigest.TDigest() for _ in range(len(coding_miRNA_loc))]\n",
79
+ " \n",
80
+ " # initiate progress meters\n",
81
+ " progress = tqdm(total=len(coding_miRNA_loc))\n",
82
+ " last_view_row = 0\n",
83
+ " progress.update(0)\n",
84
+ " \n",
85
+ " for (ix, selection, view) in data.scan(items=coding_miRNA_loc, axis=0):\n",
86
+ " # define coordinates of cells passing filter\n",
87
+ " filter_passed_loc = np.where(view.ca.filter_pass == 1)[0]\n",
88
+ " subview = view.view[:, filter_passed_loc]\n",
89
+ " # normalize by total counts per cell and multiply by 10,000 to allocate bits to precision\n",
90
+ " subview_norm_array = subview[:,:]/subview.ca.n_counts*10_000\n",
91
+ " # if integer, convert to float to prevent error with filling with nan\n",
92
+ " if np.issubdtype(subview_norm_array.dtype, np.integer):\n",
93
+ " subview_norm_array = subview_norm_array.astype(np.float32)\n",
94
+ " # mask zeroes from distribution tdigest by filling with nan\n",
95
+ " nonzero_data = np.ma.masked_equal(subview_norm_array, 0.0).filled(np.nan)\n",
96
+ " # update tdigests\n",
97
+ " [median_digests[i+last_view_row].update(nonzero_data[i,:]) for i in range(nonzero_data.shape[0])]\n",
98
+ " # update progress meters\n",
99
+ " progress.update(view.shape[0])\n",
100
+ " last_view_row = last_view_row + view.shape[0]\n",
101
+ " \n",
102
+ "median_digest_dict = dict(zip(coding_miRNA_genes, median_digests))\n",
103
+ "with open(f\"{outdir}{output_file}\", \"wb\") as fp:\n",
104
+ " pickle.dump(median_digest_dict, fp)"
105
+ ]
106
+ },
107
+ {
108
+ "cell_type": "markdown",
109
+ "id": "190a3754-aafa-4ccf-ba97-951c94ea3030",
110
+ "metadata": {
111
+ "tags": []
112
+ },
113
+ "source": [
114
+ "### After the above code is run as a script in parallel for all datasets to obtain the nonzero median tdigests for their contained genes, the following code can be run to merge the tdigests across all datasets."
115
+ ]
116
+ },
117
+ {
118
+ "cell_type": "code",
119
+ "execution_count": 2,
120
+ "id": "distributed-riding",
121
+ "metadata": {},
122
+ "outputs": [],
123
+ "source": [
124
+ "# merge new tdigests into total tdigest dict\n",
125
+ "def merge_digest(dict_key_ensembl_id, dict_value_tdigest, new_tdigest_dict):\n",
126
+ " new_gene_tdigest = new_tdigest_dict.get(dict_key_ensembl_id)\n",
127
+ " if new_gene_tdigest is not None:\n",
128
+ " dict_value_tdigest.merge(new_gene_tdigest)\n",
129
+ " return dict_value_tdigest\n",
130
+ " elif new_gene_tdigest is None:\n",
131
+ " return dict_value_tdigest"
132
+ ]
133
+ },
134
+ {
135
+ "cell_type": "code",
136
+ "execution_count": null,
137
+ "id": "distinct-library",
138
+ "metadata": {},
139
+ "outputs": [],
140
+ "source": [
141
+ "# use tdigest1.merge(tdigest2) to merge tdigest1, tdigest2, ...tdigestn\n",
142
+ "# then, extract median by tdigest1.quantile(0.5)\n",
143
+ "\n",
144
+ "databases = [\"database1\", \"database2\", \"...databaseN\"]\n",
145
+ "\n",
146
+ "# obtain gene list\n",
147
+ "gene_info = pd.read_csv(\"/path/to/gene_info_table.csv\", index_col=0)\n",
148
+ "func_gene_list = [i for i in gene_info[(gene_info[\"gene_type\"] == \"protein_coding\") | (gene_info[\"gene_type\"] == \"miRNA\")][\"ensembl_id\"]]\n",
149
+ "\n",
150
+ "# initiate tdigests\n",
151
+ "median_digests = [crick.tdigest.TDigest() for _ in range(len(func_gene_list))]\n",
152
+ "total_tdigest_dict = dict(zip(func_gene_list, median_digests))\n",
153
+ "\n",
154
+ "# merge tdigests\n",
155
+ "for current_database in databases:\n",
156
+ " rootdir = f\"/path/to/{current_database}/tdigest/\"\n",
157
+ " \n",
158
+ " for subdir, dirs, files in os.walk(rootdir):\t\n",
159
+ " for file in files:\n",
160
+ " if file.endswith(\".gene_median_digest_dict.pickle\"):\n",
161
+ " with open(f\"{rootdir}{file}\", \"rb\") as fp:\n",
162
+ " tdigest_dict = pickle.load(fp)\n",
163
+ " total_tdigest_dict = {k: merge_digest(k,v,tdigest_dict) for k, v in total_tdigest_dict.items()}\n",
164
+ "\n",
165
+ "# save dict of merged tdigests\n",
166
+ "with open(f\"/path/to/total_gene_tdigest_dict.pickle\", \"wb\") as fp:\n",
167
+ " pickle.dump(total_tdigest_dict, fp)\n",
168
+ "\n",
169
+ "# extract medians and save dict\n",
170
+ "total_median_dict = {k: v.quantile(0.5) for k, v in total_tdigest_dict.items()}\n",
171
+ "with open(f\"/path/to/total_gene_median_dict.pickle\", \"wb\") as fp:\n",
172
+ " pickle.dump(total_median_dict, fp)\n",
173
+ "\n",
174
+ "# save dict of only detected genes' medians \n",
175
+ "detected_median_dict = {k: v for k, v in total_median_dict.items() if not math.isnan(v)}\n",
176
+ "with open(f\"/path/to/detected_gene_median_dict.pickle\", \"wb\") as fp:\n",
177
+ " pickle.dump(detected_median_dict, fp)"
178
+ ]
179
+ },
180
+ {
181
+ "cell_type": "markdown",
182
+ "id": "e8e17ad6-79ac-4f34-aa0c-1eaa1bace2e5",
183
+ "metadata": {
184
+ "tags": []
185
+ },
186
+ "source": [
187
+ "### The below code displays some characteristics of the genes detected in the pretraining corpus."
188
+ ]
189
+ },
190
+ {
191
+ "cell_type": "code",
192
+ "execution_count": 38,
193
+ "id": "decent-switzerland",
194
+ "metadata": {},
195
+ "outputs": [],
196
+ "source": [
197
+ "gene_detection_counts_dict = {k: v.size() for k, v in total_tdigest_dict.items()}"
198
+ ]
199
+ },
200
+ {
201
+ "cell_type": "code",
202
+ "execution_count": 44,
203
+ "id": "polished-innocent",
204
+ "metadata": {},
205
+ "outputs": [
206
+ {
207
+ "name": "stderr",
208
+ "output_type": "stream",
209
+ "text": [
210
+ "/home1/ct68/miniconda3/lib/python3.8/site-packages/seaborn/distributions.py:2557: FutureWarning: `distplot` is a deprecated function and will be removed in a future version. Please adapt your code to use either `displot` (a figure-level function with similar flexibility) or `histplot` (an axes-level function for histograms).\n",
211
+ " warnings.warn(msg, FutureWarning)\n"
212
+ ]
213
+ },
214
+ {
215
+ "data": {
216
+ "image/png": "iVBORw0KGgoAAAANSUhEUgAABR8AAAMRCAYAAABlG8GWAAAAOXRFWHRTb2Z0d2FyZQBNYXRwbG90bGliIHZlcnNpb24zLjMuNCwgaHR0cHM6Ly9tYXRwbG90bGliLm9yZy8QVMy6AAAACXBIWXMAABcSAAAXEgFnn9JSAAC/KUlEQVR4nOzdd5hjZ3X48e/Z7l2vK240G0wzBgOmmmp6NT/TQmjBlCS0ACGE3gklJARCLyGYGgi9hRqwgYBpxnRMMTZgsI1x2+Lt5/fHe8d7dUfSSBpdaWb2+3kePaN7dcs7M1ea0dF5z4nMRJIkSZIkSZLGbdm0ByBJkiRJkiRpaTL4KEmSJEmSJKkVBh8lSZIkSZIktcLgoyRJkiRJkqRWGHyUJEmSJEmS1AqDj5IkSZIkSZJaYfBRkiRJkiRJUisMPkqSJEmSJElqhcFHSZIkSZIkSa0w+ChJkiRJkiSpFQYfJUmSJEmSJLXC4KMkSZIkSZKkVhh8lCRJkiRJktQKg4+SJEmSJEmSWrFi2gOQJEmTERFXAY4GDgf2A1YDG4BLgPOA0zPzT9Ma3yAi4mTgkbVVj8rMk/tsfwTwm9qqczLziDbGJqm3iDib8toz4xqZefZ0RqM9XUQcD3yltsq/DZLUIoOPkiQtYRFxLHAScF863/j32v4c4AvAe4GvZWa2OkANLCJOAe4wzmNmZozzeJImw9cDSdJi4rRrSZL6iIjlEXFpRGR1e/uA271m0mNtjOeGEfFF4HvA3zFA4LFyOPDXwKnAryLiERHh/wuSJEmSRuKbCUmS+rsJsE9t+Ss9trtZY7tT2hpQP1E8AzgduEufTRO4mDLtuld24zWBdwPfGOsgJUmSJO0xnHYtSVJ/zWltp/TY7k61+7uAr7Yymj6qDMX/AB7V5eE/AB8DPgt8F7gwM3dW+60GrgPcFrgf5XtZXtv3ei0OW6M7DXjXtAchaUHw9UCStGAZfJQkqb/ja/fPzMw/9NiuHnz8QWZe3N6QenoDswOPG4FXAK/JzMu77ZSZW4EfVbc3R8SRwHMptSKtAbZwnZmZb5n2IKRB2Myjdb4eSJIWLKddS5LUQ5VJeNvaqq5TriNiFXDrubZrU0Q8Gnh8Y/X5wO0y8+W9Ao/dZOavM/PRlO/prDEOU5IkSdIexuCjJEm93RjYr7bcK6h4K2BtbfmUdobTXURcGWg2uLkEuG1mnjHqcTPzNErNyy+MPDhJkiRJezSDj5Ik9bZY6j0+n85mNwBPycxfzffAmXkZ8JfzPY4kSZKkPZM1HyVJ6u342v2fZOYFPbarBx+/n5mXtjekThFxILPrPH41M989rnNk5q5R9ouI/SlZoYcAB1G6av8J+A1wWmZuG9cY2xQRhwI3Ao4A9gVWAZcDlwG/BX49jkCveouIFcAtgBsABwJbKE2UvtfWzz4i1lGu3+sA+1M+WPhjZg7U1CMiDgeOpVz7BwKbgQuAnwI/zMxeXeYHHV8ARwLHAIdRPoCI6jwXAedQ6gCeN+Lx96mOfR1KBvhaYCuwCTgXOBv4aWZun8/3MR8RsS9wG+DawN7ApZTr4quZeeGYzrEGuD1wdeBgys/gt8C3MvO34zjHYlS7/q5L+dnsQ0lsuRi4kPLc/E1L574ecDRwJeAAYAfld/8r4EeZ+acxnWcl5TXgBpTXgMspz+Fv+ZovScMx+ChJUhdVvcfb1Vad0mO7vYBbzrVdix4GrG6se9OEx3CFiFgOPBL4a+DmdHbNrtsYEZ8BXpyZP5vU+AZVBbseW91uOsD2FwFfAz4MvH/UgO1CEBF3BT5H5wyZl2TmC4c4xq2BU+n8X/M1mfm0Hts3A3HXyMyzI2It8EzgiZQAXrd9Twf+KTM/NsT4TgLeWVt1amYeXz12beBFwAOY/dyCPh2FI2I/4O8p2cLX6TOE8yLivcArMvOiQcddnWN/4OnAwylBn7m2/y3wReA9mXnqANvfHXgycDfmfq+wJSK+A3wMOLlfo62IOBs4vLbqGpl5dp/tXwTUr7l3ZeZJ1WPXBP4JeCCwssvuGRFfAp6VmafP8T30Ov9VgJdRroO9e2zzDcpr2Bd6jPnFmfmiUc6/EFXX3v2Be1MCsl2fk7XtzwXeAbx+vsHgqhHaPwInAFfus2lGxA+BjwLvyMxzRzjXeuA5wOPoLL1S3+ZnwAsy88PDHl+S9kROu5Yk7bEiInvdgJ2UTIcZT+yx3WZKJtyMf+hz3ONb+Dbu21i+iBIImLiIuC0lq+sdlGyRXoFHKG/mHwz8KCJeXmXRLAhVxtrpwJsZIPBYOQD4f8B7mD0FflHJzC8CL2msfl4VlJpTRFwJ+CCdgatvUoKIA4uIawDfA15A/yDHscBHI+JDVZbayKrGTT8CHkr3wGO/ff+W0qDpBfQPPAIcSgkgnhURDxziHHcDfkkJjMwZeKxcHXgM8O9zHHtNRHyQEni+F4MlKayhfEjzb8wuU9GKiHgI8GPgIXQPPELJAL0r8K2I+KsRz/FzygcpXQOPlVsDn4+If68+sFqyIuIY4DzgP4D7MUfgsXIVyvPhVxFxnxHPuyYi3kT5ffwt/QOPUH73NwJeDHx6hPPdiPIa8Cx6BB4rRwEfiog3L/XfvSSNgy+UkiQtUlV23m0aq78xjenMEfEI4H/pHnRJyhTljV0eWw48G/ivhfAGLiIOoGQw3rDHJpspUwo3TWxQ0/FS4PO15WXAeyPiqv12qn6H7wPq210I/MWQ03OvRLmertdYv5Ey9bGbBwIfj4ihgoYzIuKRlMB5c/9LgJ5jj4jlEfF64C10fmAxYydlKurWLo/tC/x3RPzdAOO7DfApugd9EthA+Vl3O88gPgz8RY/HtgJ/pjyPp5bVGxEPp1xfe9VW76L8fLu97q0ATo6IOw5xjkcC76V70HHmXDsb65/M7KZfS81aOj9oq9tOuT66vcZDuc4/GREPHuaEVTO1rwKPp3cw/DLKtd/1EEOe7waUxnKHNx66jN6v+Y+jBFglSX1M/Z98SZI0sqMomUd135n0IKqMlnfR+cb0YuBVwHHAmszcNzPXUzJJHgB8o3GYB1Ma50zby4Cr1ZYTeDdlCur+mbkuMw/KzL0p3+/RlAyskyn1LJeEatr4w4Hf1VZfiRIo65VtBuV3eLf6oYCHZ+bvhxzCG4BrVPd/TalreqXMXJ+ZaykZVU+m1F+ruzvwyiHPBSU7cKZcwS7KlOw7AKszc39KQPIISjZU0yuBJzXW/Yoy/fr6wMrMPCAz11TneQKlHuOMAF4bEXeiv7fR+Ry7DHg5pbzBuszcp7o211ACZzenBG0+Se+AbRlACQrdu7H6a5Tp41fJzDWZeaXM3JcSBDoCuA/wauDMOcY9LjegZN0Fpebnayh1QFdVP9/V1TZvoDNAGsB/VCUh+qqy3t5O53ukXcBbKdncqzPzAMrv4UaU17iZYO+TgXuM/N0tHpcDn6Fc87cB9svMVdX1sR5YR8kIfTWdwciZ38ORg5ykKmnyGcp1XHcxJTP7lpTfx76ZuQ/ld3ITyvPrS8wOEM9lL8qsgZkPED5G+X2uq86xN+V15x8oH0jUPSci5sp2lqQ9Wsyz1rUkSYtWRDyux0PLgNexe9rwl4EPddluNfDa2vIngc/2OeUnM/MPQw6zp4i4H6WuVd39MvPj4zrHAGM4HPg+nRlfXwQekZnnz7Hv8ygZdjN2ATfLzO/32edkylTIGY/KzJP7bH8EpcHNjHMy84ge266iBLP2ra3+y8z8YK/jd9n/AcDHM7NvsGcUEXEKnVNbr6iB15aIuBUl86gecPz3zHxql23vQsmWrAduXpqZc2YFdan5OONTlN/B5h77HVidsz49fhdw28z8Zp/znURnzccZG4D7ZuYpc425Os59gU80Vv8r8Nx+GcgRsTclg69eNuEPwJGZuaXL9rcAvlVbdQlwq8wcKPBXZfTeOTO7vY4REf8D3LO26s3AEwdtihMRtwf+1K9+6xhqPs74DXCvzPx5n30fQfnQoO7/ZeYn++yzjDLN/8a11RuBe2fmV/vsdzQl2HVol4dbq/k46deDiLguJUD9jkGbqlV1Mz9JKY0w4z8z8zED7Hsyna/1AB+nvOZfMsD+R1Cey6/r8fjxlCzHps2UD0x6li+pMiT/j84SGz1r2kqSgMz05s2bN2/evNVuwM0oGVszt4f32O4Oje3uOeFxPqFx/gRuN+Ex/Gfj/F+lZKMMuv8bG/t/YI7tT25sf9Ic2x/R2P7sPtter7Htt6d9LTbGd0qX3/d8bicOeN6ndNn3AY1trkIJ3Na3+RKwbMBzdBvfDylZs3PteyXg/Ma+n5ljn5N6nPM+Q/w+llHqL9b3f9UQ+6+hBLvq+z+ux7aPG/U8A47lvNqxtwH7tnD9nt34Ho6YY/sXdfn9XAZce8DzfbKx77vn2P5eo14PlIDl9i77v2jcP8faOafyejDCOA+mTMmeOc8WShZ5v32OoXyIUB/fhwd9PRlwXMf3+Dn85YD7P62x32/b+l178+bN21K4Oe1akqTZmtMfv9xju+Nr93cCX29lNL2t77JuoIyUcajq/z28tmoH8NjMHKbm3HMpAYUZD6yy2abhgMbyr6YyigUmM/+d8sa/7j+jdIWeqT36AeCg2uN/AB6a8+v6/eTskgXYZXwXUq6juntUDWuG8enMHKZBxQOBa9WWfwU8b9Cdq+/tHxure2Vjt31t1o9/YQ6Y2TYFr8zMXw647dsay83pu03Nn/0nBr0eMvMMdk/bV01mXkCppTpjNWVadj/PprNe4x8pf1varjf6xcz8wIDbvpPyN2/G1SLikBbGJElLgsFHSZJmqwcfz8zeU6WPr93/Xmb2Knrflm6NNSbZCOVBdE7H/Vxm/mKYA2SZPve52qrllO6503BJY/kmC6EJzgLxaKD+u90H+HBVl+2VwG1rj+2gZA816zEO46c54NTnynvpDGIvY3YNw7k0g1VzeVhj+S05ZLOnzPwyJetwxjHVFOmmSxrLg3ZhH1T9+IfM1VhoSnYx3O+oWVf2Or2ez1UA/c6N1cMGE9885PZ7ktMay7fqtWFVU/a+jdWvywGmWo/BwL/DzLwYaJYZaDbIkiRVenUNkyRpj1S98akHUrpmPVYddetvoE5pcVi9dMswXDfB89+hsfy5rlvN7Xt0dtk9jlLba9LOpDQzmKlfeT3grRHxtCkElgdxGqXRz6jOGHTDzNwQEQ+k1B2c6TR8DKUj9XGNzZ+bmV+bx7hgdh3Fuca3JSK+QMlGnHErSvORgQ4BnDro+aogVjNIPur1/31211sMSiONZu3YZvDmMRFxBvDWMWWDnQacUN1fRgks/2X2qck4BT+uslwHkpkXRcSl7K7huoySLd4tq/MYSjfnGVvonfHe63w/j4izgGsOs98YTez1oK7qSH09SjOx9ZRyAs0u081mLFejt1vS+buA8uHCJAz8GlA5C7hhbXm/8Q1FkpYWg4+SJHW6BZ0BvK/02O5WdHaaPqWtAfWxscu6fbusa0sze+V6fZr49HNMY/mwEcczL5m5MyLeSmdH48cCD4qIj1A6r34tMxdKV+szM/MtkzpZZv4oIh5Pqbs5oxl4/BTwL2M43ekj7lMPPt5oiH3PyczL5t7sCtehs8kSwJ0iYpSs3Ss1lmdd/5l5ekScxu7n3HJKZt4zIuK/KYHPb2WPxjwDeCO7g49QAkC/jIjPUgLBp2Tmr0c89ricPcI+G+h8TdyH7sHHZsbajzNzR5ft5vJ9phd8nNjrQUTcjZL5e29glDIZzedOXTOr95zM/P0I5xjWZZl50ZD7ND+U2qfrVpIkg4+SJDXUp1wnvYOPx9fuT6PeI5Q6WE3dpmyOXZX5dVBj9ZPGdPhp1XwEeAlwezprku1LmXb8aICI+AWl0+nXgC9n5jmTHuS0ZOa7qgBbt261ZwOPzMwcw6lG+Zk29xnmOvrzkOfq1tm4a1fdEfQa9yMoU4nrz7sjgGdUtx0R8X3KtflVSsDw4kFOmJmfj4jXAH9fW72CEpA8ASAizqvO/zXg1OzTlb4ll4ywz87G8vIe2zWDYb1Kbcyl22vykhER16JMfb/jPA/VrV7xjObflUnV3r1khH0Gvb4kaY9n8FGStEeJiJtRuln3Us+cuojSAKXbdvev3f8T8LAe2/0hMz857DgH1C0T6RiGnLI6ov1pr3Z0c8rdxGTm5RFxZ+AVlG7iq7psdp3q9iiAiPg28HbgXZm5fVJjnaLnU7pFN99oP2bQYNcAhslCnNHMaOuXXdXULYu4nzYD5F2v/8z8VUTclFKXrls9yxWUpio3B54KbK+mor82M78010kz82kR8XPgZczOxoQScL1/dSMizgbeTanHN2zwdhTjCGr3sl9jedQyC6Nct4tCRNyA0sF+HE1V+v3taD63LhnD+QbR5vUlSXs8g4+SpD3NfYAXDrjtgQxWgP7QPtudCrQVfPwppe5jvfHMXB1dx6VbUG5cukZxJ6XqQvz3EfFq4K+AE4Fj6Z3Vcovq9oyIeEhmfm8iA52CKBH2t9D9Z/F4hqyT18cogYBJXjdTuf4z83fAfSLiWMq1eS/g2j02X0kJUt47Ij5HyUrt2wQoM98WEe8HHkxpKHVbeteRPQJ4AfDUiHhiZk6qLl8bmvVzR/39tnldTE1VC/kDzA48/gD4KPAdSubxecDlwNZ6LdKIOJ7eswjmYlBQkpYAg4+SJC1Smbk9Ir5B5xS420TEqmG77o6gW6bTUZn585bPOzFVnbGXAy+PiPWUenvHAbehBGWaGWrXBr4cEbfNzB9NdLCT84/M7kQ744ER8ZTM/PcxnGeU2qXNemvjysLspnn9n5+Z3aZityIzT6fUuHxqRBxGKRNwa8p1eVNmB4fvAXwpIm6dmX2zPKvH3wG8owo63YRy3d+WUpLg4MYu+wDviYgVmXnyvL6x6WleK/uNeJxR91voHgYcXVveATxqiIDz3kOcq9lUaJgMZknSAtXWdClJkjQZzazKA4D7tX3SKrjZnGJ4rbbPOy2ZuSEzv5iZL8nMu1N+zicAn29sug+Dd1heVKpajy9rrD6rsfwvEXHLMZzu8DHs0+ZU4GbToUOqAPXEZeYfM/MjmfkPmXlLSnDwr4GfNTa9ISV4PMyxt2fmtzPz3zPzQZQs71tS6v41G7K8NiIWa6CoWavxqBGPM+p+C939G8uvHDLTtVnHsZ/mc2vJ/l2RpD2JwUdJ0h4lM1+UmdHtBvxPbdMz+mz31dp2p/Xarrod3/K39D6gmeX4hJbPOaPZcOL4CZ136jJza2Z+OjPvQWd3bIDbR8TVpzGutkTEwZRpl/VZM6dSaozWO1OvBP47Iubb+OjYMezzg3mOoZ+fAVsa6+7Q4vkGlpkXZeZ/ULp9f6rx8CPmeeysgpF/S8m4rgcg96WzY/Zi8t3G8lUj4irDHCAiVgM3HtuIFpYbN5bfPeT+txhi2+bv4vCIuOqQ55MkLTAGHyVJAiJiOXC72qpTemy3hpL5M2PUOlZjkZl/At7VWH37iPircZ2j6mzdzRcbyw+IiD2xpMurmJ05daNpDKQN1e//fcCVa6vPBx6SmZsoTZouqT12deDd0aMD04D+35BjXAPcrbH6tHmcv6+qLmizw/2D2zrfKKrmR89orL7GuDI0M/PrwEcaqxfldV/Vwmx2VX7YkIe5H73rYy52zan2A3ejr/623muIc30H2NRY9/Ah9pckLUAGHyVJKm4O1N+U9woqHkdng5epBh8rL2Z2d9Z/j4h5T1eLiH2A/+rx8EeAXbXlI4DHzPeci01mJrPfjC+lIMQLgbvUlncBD83MPwJk5m+oOn/X3Bt45jzOef2IGCaT8OF01nzcBXxmHucfxH83lh8SEddv+ZzD+k2XdeO8NpvHX8zX/fsay0+NiIFqj1Yfujx3/ENaMJrZ9fsNse9DKR9IDKQKmn+8sfrvBv1dSJIWJoOPkiQV9aYtu+icWt1ru+3A/7U2ogFl5rnAPzRW7wd8PSJGzkSKiFtRptTevcd5f06Zilv3rxFxk3mcc2qdrkfN2qyacjQDvefNf0TTFxF3A57XWP2izOzoap2ZHwde3djunyLi9vM4/eurqaxzjfFKzK5F+fkqKNqmk4Gza8vLgQ9FxH6jHrDX9T+PjOJmMHQnjZp688xWbh5/MV/3b6O8ps84DPiPKnNvLv8C3KCVUS0Mv28sDzS9vio/8doRzvdKOrtcX5nSAMn3rpK0SPkCLklSUQ8qnpGZl/TY7vja/W9l5ubWRjSEzHw75c1z3SHA1yLi2RGx16DHiohrRsQ7KIHVI+fY/PnApbXlvYH/jYgHDHq+6pyHRcQLmV2jbpKeEBGfjYh7DPkm9xXAlWrLGylTBxe1qs7a++j8f/HzzA70zXgWncH45cAHqnqRo7ghJZjX89qNiAOBz9E5LTT7jHFsqgytpzdWX58S9B8qEBURx0TE2ylZzN28OyLeHhHHDHHMdcwO/HwtM3c21t0gIn4YEY+p9hn0+P8PuE9j9ULIBB9JZv6BEvSqeyDwyYi4Wrd9IuLAiDgZeGq1qlkHdKn4cmP5ZRHR929DlQX8VUpzrqFk5o+BdzZWPwD48BDZqEdExJOHPbckqR17Yl0mSZI6RMQq4Da1Vaf02G4vFlC9xy6eAOxFZ1OJ9cDLgSdFxEeBzwLfAy6cCUJU2WXXpvwM7g/cmRI4mlNmnhURD6ZMcZ3ZZ3/Km8TTgP+gvAH9dWbuqs4XlGDRDYGbUrJojqMEub430nc+HsuAe1S3P0XEJyi/4zOAX1UdvgGIiEOA2wNPqr7Wvb2qhdi260bE4+Z5jK9k5pnNlVU23AfpDKr+Hnj4zO+xKTN3VNfC99nd3fYw4P0Rcbde+/XwLcpz7QTgRxHxT8AnM/OianyHUQJDz2N2Pbo3ZOZEMpIz8yPV2OrZoUcDZ0TExyglC76RmVdkBFaZdIdTmvUcR6lved3q4Wb26Iy1wEOAx0bEmcDHgG9Srs3f155by4BrUK7hpwHXbBznNT2Of0PKc/X1EfF5SpD5dOAn9Wu5qhd5c+CvKK8z9cD09+idMb5YvBS4J3Cz2rp7Ab+KiC9RmqFcSHmNuxElK3wmYPt7SimKp9T2rWfvta211wPgLcDj2f37PgT4bkS8DPjvzPwtXPG6cTPKVOu/BVZV25/C8A3JnkRpInXj2rr7AXeIiNcDnwZ+UH0IMJOBfn3Kc+r+wJ2AHwOvG/K8kqQWGHyUJKkEOdbWlnsFFW/N7jdT/babiszcGRGPBH4OvITOAOKVKW/mnjSzeURcXG2znv6zIZpdrZvn/XwVdHonnXUzb1XdAHZFxKXVeeY630JwEPDY6gZARGwGNlOCDb2y8b7L5Gq/1X++o3oU0C3Y8M+U633GDuDBmXlhv4Nl5rkR8TBKNuLM7/jOlLqRLxxiXE+i1FS8BiX79p0AEbGBcs2u7bHfl5jdZKVtL6CM6VnAzLTp5ZTg6AMBImIHJUN4DfOvi3hdOjusZ/Vz2U6pe7myx35vzMxPznHsvYATqxsAEbENuIxS67ZXs5o/UwLTkwy2jV1mbo+IuwNfoHwwMmMVJQjZq3HKRZRA+f0a6yeZCdna60Fm/jgiXkNneY/9KNPN/yUiNgFbKUHZZumAzwP/ypDBx8y8PCLuTcmGr3eyP4Da60n1dyUo1+bUynZIkvpb6P/4S5I0CfUp1zuBrw2w3VZK5tGCksXLKW+c+wVHg/Imbl96/z/wE+ABmXnHHo/Xz/sR4Bb07jC8jPLGtN/5dgE/mOtcLZorcLKWkgnYK/D438AdM/PysY5qwiLiRErWXN2zMvMbg+yfmV+kZJDVPa+qHzmoCylBy2YgZD29A48fA+5bdaKemOo59xxKBuNve2y2AjiQ/oHHzXQPBEP/azMoQccD6R543Ao8PzOf1OWxuY4NJfB2JXoHHs8AjqtqwC56VXbtHSkZc80p6t18i/L9n0Fn0yPo7AK/2D0DeEePx9ZR/p40g38fomQh7hjlhNVU+NtRPnzolTm9L+Xn3i3wOEy2tSSpRQYfJUmaXe/x0h7bHV+7f9qkgxzDyMwfZOadKFMk3wj8bsBdzwbeRHkzfYPM/OgQ5/x5Zh5HCRp9nM5akL1spmTGPB04PDOn2S37DcBtKTUcv0kJ2sxlM+UN9h0y88GZubHF8bUuIq5JaaRS94nM7DUduJeXULIQZywD3hsRVxn0AFXDmGOrY/25z6Y/AB6YmfefZuA3Mz9FaTz0KODrzO4Q3M0FlOntJwGHVrVbu3kYJbj5VuCnDDad93zKNX39zPynPuP+AXAU8I+U5+IlAxx7F2Uq7SOBm2bmLwfYZ9HIzA2Z+RTKNN4XUj5U+QMlu3QTJbv8XZRMyOMy8xfVrs0SABdPZsTty8xdmflY4EHM/SHRtykfXP3FfOsiZ+bmzHw0pTTAe+j/WgDl2vwW5Xoe5gMPSVKLYpHPjpAkSQOqGojcALg6ZcrcKkpzlIuBPwLfy8y53tgNc77llHpd16JkZe1PCchspLyRP5NSC3J7r2NMU1UL9HqUab9XpmR+LaeM/8+UINBPMnOQIKX6iIjmP6TXyMyza4+voGTW3pByLW2hXEPfW6iBr1qN2KtSxryeEqy+DDiHcv3/bpSpylXTjetRajoeTMk8S2AD5bn8I0qd0qEzv6qarNeqblejZJatrsZ+KfAL4Id9PqTZY0XELyk/txk3rJqnLDkRcS3K9X0oJRt8I+W6/nZmntvieZdR/q5ch5KRux9wOeXv2C+BH/VpGCdJmhKDj5IkSZqquYKP0kJXdTj/UW3VRmDfUQLAkiQtNU67liRJkqT5eX5j+csGHiVJKgw+SpIkSRIQEatH2OepwF80Vr9pLAOSJGkJMPgoSZIkScXLI+JjEXGPqu5rTxFxREScDLym8dC3gS+0NUBJkhabFdMegCRJkiQtEMuBE6vbhoj4FqWW4/mUOo57A4dRmq3cvNq+bgPw0FEaCUmStFQZfJQkSZKk2dYDd6lugzgPuH9m/rq9IUmStPg47VqSJEmSirOArUPusx14F3CzzPzm+IckSdLiFs4IkCRJ0jRFRPMf0mtk5tnTGIsUEeuBuwG3Bm4EHA4cBKwFErgE+DPwQ+BrwCcz83dTGawkSYuAwUdJkiRJkiRJrXDatSRJkiRJkqRWGHyUJEmSJEmS1AqDj5IkSZIkSZJaYfBRkiRJkiRJUisMPkqSJEmSJElqhcFHSZIkSZIkSa1YMe0BSJKk3iLiFOAOtVV3zMxTpjOa6YqIVcAxwDWBw4B1wE7gEuBi4EzgJ5m5Y1pjnISIyPpyZsYc258MPLK26lGZefL4RyZNX0QcD3ylturUzDx+KoPRghAR1wKuC1wN2AdYBWxk99+OnwC/zszsdQxJ0vwYfJSkJaRLkKGXncBllH+8fwl8G/hcZv5fa4OTRhARa4GHAX8B3A5YPccul0fEd4H/Bj6YmX9qeYiSFrkh/3ZuAC4FzgFOB04FPpOZ20c894uAF3Z56NOZecKIx2wG0e6amV8a5Vi1Y34CuG9j9ecy857zOW4bImIZcB/gwcA9gAMG2G1DRJwOfBL4SGae0+IQJWmP47RrSdozLQf2B64B3A14HvD1iPhBRNx7qiNrSUTcOCJeVLudNO0xqbeIWBkRTwd+B7wNuAtzBx4B9qIEKV8P/CEi3hURh7c3Ukl7kOXAfsDhwO2BpwIfA34fEU+PiOVjPNd9IuLWYzzeyCLiYOBeXR66W0RcZdLj6SciHgL8CvgE8FAGCzwCrKfMMng1cHZEfDUi7t7OKCVpz2PwUZJUdwzw6Yh47bQH0oIbU7JLZm4nTXMw6i0ijqRk4/4L/d84bgH+DGzt8fgK4K+AMyPi/411kJK028GU16tTI2L9GI/7ijEeaz4eQfcZc8sor7FTFxEHRcRngfdTPljtJSmzPi6lZLL2cjvgcxHxubENUpL2YE67lqSl7ZfAv3VZv4IS1LkJJaNs78bjT4mIyzLzBS2PT+oQETcDPs/soGMCXwI+U339bWZuqO13KOV6vitlqt2Va/uupv+bUUmq6/W3cybz8SjgzsChjcdvA3w8Iu6ambvGMI7bR8Q9MnPaAbCT+jz2KKYcJK0+sPoi3V/nzwA+DXwZ+DFwUWburPZbARwB3ILyv9D9KL/fuhu0MWZJ2tMYfJSkpe0PmfmWfhtExAGUNw5/03joeRHxscz8fmuj05z2pEYJEXFd4AuUkgB1XwWenpnf6bVvZp4HfBb4bEQ8A/hL4KWUN5Z7tMw8CTN9tYeoGnL1bcI0gEH+dq4CHgf8M7Cm9tCdgIcD757nGGa8PCI+P61mKBFxczoDcDPjmPkZXzsibjOtmtHVB09fBq7eeOhHwHMz81O99q2ak/2qur0/Ip5AyfJ8Nn5gJUlj5bRrSdrDZeZFmfm3lCljdUH3IvjS2EXEXpQaXc3A45spHb57Bh6bMnNHZr6Xkp30Gna/WZakscjMbZn5OkpdwaZnz+PQ59P5mnUT4EHzON58Paqx/GU6u4l322YiqsYyH2Z24PEjwC36BR67ycwtmfl2yt+O5wHbxjJQSZLBR0nSFZ4D/Kax7h5VUEhq20uB6zbWvTEznzDq9MXqjeTTgAcCm+c7QElqysyPUbKu664XEc2A2KB+BXygse6lY25mM5CIWAM8pLH6XdWt7i8iYt1kRtXhKZSp7nX/DfxFZm4Z9aCZuTUzXwbcitn/F0mSRmDwUZIEXDH96B2N1auBBdFtU0tX9Sb97xqrfwY8fRzHz8yPAv8xjmNJUhf/3WXdLedxvBcAO2rL12E62YUn0lkDcSPwUUq24Yba+vXAAyY2KiAi9qVkJ9adCzxuTPU2qcrO3GMcx5KkPZ01HyVJdd/osu7wUQ4UEYdQ3nxdg/LGZAvwg8z84gD7XplSAP5g4EBgE/An4BfA6dOqfTWXqvbUzSnjPojShflPlOYF350pcj9tVTbrbYDrAftS3kSeD/xfZv5+CkN6CrCqse6J88lcaRrlzegkr8OIuAbl2rkysBa4CPgpcFpmLripfxGxkpIVdAPKVPnLgQuAb2Xmr8Z0jqD8/K8DHFatPh84IzN/MI5zjFPV5fjWlN/hQZROun8Cfgd8c5zXc+O8y4GbAkdTrtUVlC7wH8nMP7Vxzi5jWEv53q9LCVZdDpwFfC0z/zzA/uuB46r996F0Iz4H+Epmbmpn1GP1oy7rDh71YJn5q4h4B/C3tdUvjIj3tnUd9dAMeH545vcRER9uPP4oxlfnchCPZnZjsmdk5sXjPMl8r78qI/Q4dr8uLKe8Lvye8je3laz8iDiI8nf+msBelL8pvwO+mpmXjfE8AdyI8jp9EOV/iospfw++nZm/G9e5JC1ymenNmzdv3pbIDTiZUitq5nbKkPsf1dg/Kf/Mz3WeF9UeuytwCrCry7F6jgdYCTwJ+GGX/eq3C4C3AFcb4Ps5fo5jzXU7YoBz7AU8jdJRs9v3PHP7M/C2QcbdOP4pjeMcP8f2J/X6mQOHUGoobuozztOAO03wmp15I1Yfw0+n+Bwa+3U4x/nuDXy7z3kuBV4N7Ffbp2ObAc5xcmOfk4Z83pxde2w9pUHVxX3G/FPggfP4mewFvAj4Q59znAU8AVjWY8ynjHr+EcZ7L0oNvG19xrsZ+CRwyyGPfUSv3zflTf4rgQt7nPP4MX1//a6HQ6rnweYeY9hKyTq+Uo9jX4MSsLq8x/6XA68F9p3HeOe8Fro8R4a6foBrdxn7cwfc90WN/b5erb9yl5/r0wY8ZnMsdxnh935VSgC96zUF3KHx2C7gmhN83jVfoy8AVk3q/AOM736U+phb+7wubAE+M8LrQvN6fVHtsZtQuns3f3czt+2UTN1rz/P7uxZltsz5fb6/mb8HTwBWTvt34s2bt+nenHYtSarr1iE0B9oxYkVEvJnSrfgOPY7Va99bAmcCrwduOMfmB1GyQX4REc8a9BxtiIgTKfW5Xk355L/f93wA8NeUcTenGLcuIu5EeRPwOEpmXS+3BP43Ip47kYGVjJArNdY1p/9PxCSvw4hYHRHvo7xJvHmfTfehBLd/FBE3GvY841Sd/0fAs+icitl0FPChiHhz1RBimHMcA/yE0uzqsD6bXgN4I/CViGhmP01ERBwcEV+mBA+OpwSue9kLOAE4LSLeX2UKzufct6A8n59JycqduIi4QzWGv6V8f92sAh4DfCcirtXY/4GUANIj6OwWXbeGkhn9zSqzfKHap8u6eWXMZeYfgDc0Vj+7yhKdhJPoLNF1DnBqbfmrdNZDDOCR7Q8LIuJwZr9GvzcXQJZ4RBwVEd+lTE+/I7Oz+utWUz68OC0i/rPqoD6fc/8D8B3Kh1q9XntXUBoYfT8i7jbCOVZFxOsppVEezdwZvkdRXqt/EhHXH/Z8kpYOg4/SFETEdSLiARHxxIh4dkQ8NiJOiIjrD/tGTRqzbm/uLhxw37dRAlt1OykZUj2nvEbECZTsgGv02OQSOmtfzVgDvCIi/mMaz5sq4PRRSnZK0y7KuLtNp1oDvC4imt3FW1MFHv+H2VPULqFkXnTzTxFxUovDmnGHLutOmcB5O0zyOqzeYH6E7l1yoWR7bWysuyrwpYg4cpBzjFtE3ICS3Xd446HL6B1keRyldt2g57gx/X8Hl1KyiOpuT7m2ewWvWhER16ZkCd+xxyYbKb/Hbh5CCZo2g+6DnvuGwBeZ/dqziXkGvIYYwy2Y/ZqyizK1s9vz5AjgMzNB1yrw+AFg79o2/f5eHAV8bAH/j9Ttg4HfdFk3rFdSrvsZVwL+YQzHHcRJjeV3Z+YVH0ZW95vTrB9ZTcNtW7e/G1+dwHn7qoJ536SUQehmI52/z7pHAV8c9YOJ6gPDf6XMJpixk97PyXXAJyLiekOc40DgS5TZAd3Kt23rc75rA9+IiGaDIEl7iIX6B1wLUEQsi4ijI+KREfH6iPhmRGyOiKzdjp/2OBeqiFgZEX8fET+iZNZ8mPKJ9suBt1OmY/0E+HNE/JcdhjUl3ZrLnDPAfg9kd+2nDcCLKbXgVmXmAZTAwE1ovFGJiKMob0Cb/2x/ilLkfU1m7k/JHDiK0hG5+Yb+McCze4zrF8Djq1vzTdIva4/1unWtVRYRT6RMPa2/yfoD8Hzg2Or73j8z11GyAv4K+HHjME+PiElkiRxKmWK1mvJG5D8pb9xWV2PcCziS8rNtBiJfExH7tzy+YxvLWynZUBMzgeuw6RWUzJS631OCdYdm5trMXE/JaHsk5W8GlMDDewc8xzjtBXyMUtuR6v49gHWZuW9m7g1chRIUuaSx73Mi4jpznaCqi/ZRZmfxfZGSMbguM/fLzDXA1YEnA+dV29ySkik5EdXf548zO0j6Y0oW3wGZuT4z11KyNx9PqbVWdwvgfSMGav6L3Zl2pwL3p0xL3rv6XRxICRz9cYRjD2Iv4IOU58sOyrTrW1Je9w6kvN7fjhKsrrsO8Mwq2HEyJUiymRJgO4YyLXPm78U9gWZdz1tRMq0Wogc3lndSgtPzkpkXUQJKdU8bNXA9qIi4PeXvQl23eo7vonN2xOHAndoaV82Nu6z77gTO21NEHEv5X37f2urNlOfH8ZTXsPWZuR8l8HcPStZ03e0pWYLDuhvl7xKU4OaL2P2cOpDyt+vmlNeOujXAWwc5QVXn95OU53bdqcDDgKtk5ura+W5EeY9Tb0y0L/Dhqia4pD3NtOd9e1scN0qGxkb61/QYW32hpXajfAL64wF+fvVb1/pI3rz1uzGPulWUT7HPauy/BdhrgPPUa/tcdcDzLaPUSazvvxN49Bz7XZsSEG3WMLrZHPudNOrPpnGcmzG7htN7gfUD/Hzf0thvE3DYHPudMszrbJfvc+b2J+A2c+x7x+pnWd/vyS1fsz9onO+MCT9nJn0d3orZtbi+0O/6oQSO39fj95oDfI/N5+tJc2x/fI9zbQLuN8e+N6C8+a3v928DjPHfu5yvb307yhvZr/UY6yktXjOv73K+t9GnphmlVubnu+z31DnOdUSv3zvwDxN6jvS6Hi4GbttnvxWU4Ep9nwuBr1f3zwau12f/vYHvN/b//gjjnfNa6PIcGfj6oXz41vzZfHqI/V/U2PfrXX4Ozbp6fZ9TXcYzVM1H4J39xtTY9tTGtu+dwDX5iea1OInnQp/x7Mvs/5++Bxw5wL6PogTx6/ved8jrtX7Ouf6neH6X/Y4ZYJyvbuyzGXjoAPtdi/IBWn3fj07z9+XNm7fp3Mx81KBuSvmUTkOKiOMo08iOrq3+HSUI8Q+Ufzr+jvLG6zT6TE+VWvYyZmfyfDYze00dbLoUuGsO3i35RGZPVXtmZv5nv50y85fAXeic0rwCmFSNwlfRWcPpg8AjMnNDj+0ByMwdlAyoT9dWr6XUM2vbDkrQ6P/6bZSZX6E0pKl7YGujKpp1/ebsjDtmJzLZ6/D5dM48+TlwYr/rJzO3UjIgpz2t8DGZ+bF+G2TmjymZz3V9r6GqZuPfNFb/a2b+2xznupRSL+3sftuNU1V3sDnWTwF/m5nbe+1X/X5PZHZW77MiYvUIQ/nXzHz1CPuN00Mz8+u9Hqxe855A5/81B1I68G4FTsjMn/fZfyOzXx9vPK3SA01V7bunMDsbeRuDZ0HPqfo5vKyx+gkRcbVxnaMuIvZm9nP2XX12ObmxfP+I2Hesg5qtWXLg4pbPN5en0Pn/05mUgO+v59oxM9/J7OvlOSOM4ffA3TNzroznlzE7q3iu1+hr0vlcTOBBmfn+uQaVmb+ivE7XO2yfWM04kLQHMfioUWylFDN+C9OZ/rVoVMXVP8/u6VGXUQqzH5GZj8/Mf8vMkzPzDZn51Mw8jvIP1XMp/7xKrYuI/SPiLcAzGg8ls4MI/bw0M88dYvvmm8ofAK8ZZMcq8PNPjdX3jYheteLGoqpzVq/xdjHwxMzMQfavtnsanW/G/3oCdcze2S9I0PC2xvKxLY+v2TyhVz2stkzsOoyIIyhT7er+LjO71QZtnmsmeD2tD6i+mJkfGHDbd9JZ8+tqc0yzO4nOmo1/ZMBp1FVQ72kDjmscHk/nhw+XA08Y5DWg+iCnWRf3EOAvhxzDnxiilmZLPpmZn51ro8w8h5Lp2PTmzPzRAPt/FfhtY/XNBhvivFw5Ih7X5fbEiHhORLynGtdrKZnJM3ZSgvRzfm9DegudJVBW016pgQfRWYtzC6V0Ry8forPW6F7MnoY+bns3li9p+Xw9VTUam03kHp+ZwwREX0PJnJxxy6oG7jCemZlz1ujOzF2UDvR1/ZqeATydzlqS78nM5pTxfuf8NSXJYkZQ3g9J2oMYfNSg3k35pP+mlKlht8jMxwP/O91hLVxVHad3sPuN9Qbgbpn5tuoPf1eZeX5mvjwzL+u1jTSEfm+gnh8RH6W8ger2T+DLMvOMAc+zndnZDz1FxD7AbRurX5+ZOwc9BiVDr16jcBmzAzvj9rDG8vsyc6hMvSpgVa9NdQClNlObmtmMPVWZa/XXn3VAKxk2lWbW10QaZsBUrsP70Pm/15mZ+aVBT5SZP6Vk0k/DMNfQxZROqHX9mhrcvbF88iAB2ZpPAsN88DEf92osf2SIbG8y85vAt+Y45lzeM0RGelvePsS23+myrhn86KdZy2/gBhnzcG3KNd+8vYGSOfZwSuC47kxK9v/YP5jP0sX5RY3VJ0XEdcd9LnbXb57x8SrLuKsqM/Ojcxxj3Ob9dyMift+oW9/vdkqfQ92dUo93xo+rWQQDqz5c+nBj9fFDHOIi+geIm77RWO75nKo+fGx+QPK6Ic41o1lv8vgRjiFpETP4qIFk5gsy8+2ZeXq/aUXjEMVNI+IREfEPEfH06v7Rc++9oDycUjh6xjMzs/mGQ2pbvzdQLwHux+wMAoB/z8znD3GeHw4ZhLsVnX+DktlvXvrKzEsozSjq2u6ieIfG8udGPM73GsvHjXicQVzC7ClWc/lNY3m/sYyku2b34kmW+Jj0ddj8PX98mHNVhhrfGJ065PZnNZb367ZR9UHdLRqr/2eYE1XB4s8Ps88oqgynGzdWf2SEQzWDBMO+bg0V2GhB0j2bsZdm5uJFlPrAo+6/3xD7TkJSyikcPWzQaUjvoTOov5zdTUbGopqx02wo0m/Kda9tbtXytNrm7KBploZaCP8XfL0KYA5qoNfnyjHsbjYGcGFmNsc6p8z8GZ2N2m5YTfGXtIdYMe0BSDMiYj3wTOCxzP40eWabXwIvzMzmp2cL0ZNq93/FgN3kpCn7EfDsYabT1PYbxg0by78ecorSjO9SOuHOaC2DsAo8NMd98xHrbjWn5TbrHo7Tb/tlW/fQrD+4T9etxmMjnVNu264VVjfp67BZW3LoN3Aj7jNfl2XpujuMQa+hw+h845uUBkDD+v4I+wzrKGb/79wtq28uzUy+q0TEAUP8jMc9pXdYl1ZB90E1s9J+O2ipisrGxnKbr0ejCEqJkjXA89o6SWbujIjn0RnwfmBEHJuZp4/pNCc1lv/I7A9XuvkyJUh89caxnjmWUc3WvCYm+Xej6VaN5atERLO8wiCawdph/i84e8hzDfM3vvn9bRzx+4MSNN6rur8MOJjZv0tJS5TBRy0IEXErSgZIv5pQULK43h8R9wMe1nYW5qgi4hg6MzneMcKbf6lNOylTay8Bfgl8G/jcXA1J+hi2SciBjeVmpt2gmsXcm8cdp4OZPWNgXDW32hz3JSPs05x2vLzbRhFxR2CYaX+fzMw/NNb9kc4pawcMcbz5mvR12Fx/9gjnGmWf+bpkhH0GuobozKgB2FBN4xzWXE0WxqH5+9s+ZJ3bGd2aUBxIyQgcxKSbMjUNWxameS3Md/9e19I4nZqZx9dXVFm6ewNHAnej1IudaXyyDHhuRKzKzGb95LHJzI9GxHfYXaMvgJczhpIj1fTav2qsfu8gZSgyM6s6mPWGW4+IiOcMWcZiUH+glIKaMcrfjWfRfeYHlAZfzaBbL4c2lh9S3eZrmP8LLhnmwFUgu76q32zI5vd3BEOU4ZjDgczOwpS0RBl81NRVb14/Ten6OuPMat2vKUXrrwv8Bbvrjj2Ikh3RdkHrUd2tsTzUFDJpjGa9gWrJsMGCZsBh1CYjzf3aDFy1GSBcO/cmIxsmw2hYj6xug/o55U1j3a/pzEA8KiJWTujDpUlfh83zjVLbd9INeaDda2i/xnLfrvF9TKJOclvXCwzx2jVicHac5ns9tHk9tabK1txAycw9IyLeTCmDcJfaZv8YEd/NzGHq7w3rOXRmI949Iu6QmcOWRmi6C7Pr+w4y5XrGyXQGHw+jBEWHnUkxiFnThiPiKsN8GNCvNmeVFDFo8LGt/w2G+b+gzefUYv3fR9ICY81HTVVEHEwpQDzzx2cL8BjgqMx8ema+uao1+XRKALI+dfkvIuIRkx3xwOpZjxuAHwNExLER8YaI+ElEXBYRGyPiNxHxsYj4m4jYq/vhpCUnGsvj+se5zX/AV829yciaP489SXO64GpmT4duy7Svw0UZhBmzZs3PUZ9nbT4/Z7R1vYz7WJqAqtP6/Zldv/JNEXGlLruM67xfYnbjqVeM4dCP7rLux4M2ZaHMomhqq/HMGV3WzdWxuS1tvfYslP8L/N9H0lgYfNS0vZLdU613AffLzP/sVgsoMy/PzMfRWevmpdU0kYXmJrX7vwTWRMQbKXWenghcn9IFex1l+sKJlMDqWRFx/4mOVJqO5vTC/UY8TrPO0yj1+gbVnOqYwLrMjDHcTmpx3Atdt2ydO07o3JO+DpvrR6lTNs3aZm2Y9TOJxnzAAe03hrHMpXm9jPq76LZfm69dakkVgHwU5X/YGQcy5kYwXTy7sXxcRJzQdcsBRMR+wP+b14i6OyEi2sic+2qXdbfvsm4Smv8b3HNM/xccMY1vpovm9/fBMX1/kZmnTOMbkjQdCzFooz1ERBwKPKy26j8yc5AOcU8GZqbjHQ7ca9xjG4ODavcvAD4EPIHdn/BtA37P7BothwIfjointD1Aacqa/8weMeJxrjnHccfpT43l6HL+PUpmnjSGNxrfYPbv7THtjx66nPeIEY8z6HU4jvONss9Cdj6dWX+rGO151WZn3RnN39+qiLhy1y376/b9TbuOo0aUmd+mdKKue2zVObrNc368sfpl8/hA/qF0Nv4al1V0/q8/Fpl5FrMzTh8eEZPIgG5q/m/Q2u99Spb69ydpQgw+apoeSGcq/2sG2alqVvCl2qq7jnNQ81VlbKyvrbozuwOkZwL3A/bJzKtl5v7A9YB31g8B/FtE3HkS45WmpNmt9VpV5sWwbtZY/sFow5lb1QX5nMbq49s6354iM3cw+437UVU94LZN+jpsrr9p1636G2WfBauqX/jzxupBa63Nd59h/YxSh7qu+bsfRHOf34/QTVwLywspHyzPWAE8v+VzPo/OjMsbMnqjk+b06C8Ajx/x9ok5jj0u72wsH0SpCT9p328sHz+FMbSp+f3daMS/k5L2cDac0TTdrnb/rMxsvvno59vAPav7t+y1UURcdZSBDejSarpN0zo6A/srq6+nA3fKzI5C85l5JvDoiPgZ8Kpq9TLgtRFxTLcp6NIScBrlTdPMcyUogfnmm4meImJfZn/48I0+uzSDBqN0TP0i8Nja8oOBN4xwHHX6d0p2eP0DqTdGxLGZuWUcJ4iIZZm5q7F60tfhNykZRjNOpHRcHcZSLM3xf3RmLj4MeN+gO1fZh8ePeUyzZObmiDiDzuDhA4BPDnmov2gs93vd0iKQmedExLuAv66tflhEvDQzf9XSOX8SEe+ls0P1SyJiqGY3EXEDZgfE/zkzm3UlBz3ed+icwn3jiLhxZp4xyvH6eAelwc1+tXWviojPTjiY/0U6G6/dIyL2bf6/v4h9A9hEeX8DJX7wAMrPX5IGZuajpulGtfs/GXLf82v3+wUYf9fi7Yk9znl5l3W7gIf3+0ckM/+FzgLiN6Czg6K0ZGTmZcDXGqufNOSUsccB9SZNu4B+pRuaHxaMUq+t+abuthFx9xGOo5rMPBt4Y2P1UcC/jOP4EXE/OoPGM+ed9HX4aTozla4bEQO/zkfE9YE7DTG2xaIZaLx71W12UC9gtA8TRtHs3PvAYaZeR8QtmZ2l2UY3YE3ey9ldFgjKNdl29mMz4/KadHmtm0MzM/E84JRRB5SZ32N285mxZz9WsxFe1lh9ZeBtE64H/1mg3oF+HcN/qLRgZeY2ZmezPj8iVk9jPJIWL4OPmqZ6AeoTBu2mV3XUe1Nt3/0nPO6+MnMnpWt33Rcy82cD7P7axvKCmlIujdnrGsvHUmq6zikijmT2m7pPZOZv+uz2x8bytYatD5WZX2R2ltI7I+LqwxynbsTmGkvR85j9hvVJEfHGUd9IRsSaiHg1pVHZ2h6bTew6rIKszcDk6yKi19jq51oBvJkl+L9bVQu0/iHkMuDkiDh4rn2rwPLftDS0bt5CZ7BnLSVLd87ncUSsqfavOx/4wPiGp2mpnt/vbqx+WMu1H88G3tZYPXDAs3pdeXhj9Qe7ZIkP678ayw9rqR7jaygZ7HUPAD4SEXt12X7sqizL5t+Rf4yIe3bbfhAL8P+ClwI7a8uHM8/MxwX4PUpq2ZL7B1aLyn5jOs6cb9qm4LLG8lcG3O9UOgvvHzue4UgL0seZXQPvXyPiEf12qgI+X2L3FCAoU6qbGRBNP6Jz6vVeDJ8hAvAPdGa3HAb8X0QM1WkzIq5ZBcbePsIYlpzM3EyZqtfMEH8C8JWIGLi2XkSsiIiHURoSPI3dzb66+TiTvQ7/ic7sx6OAj0fE3n3OtQp4F9Pr5joJT6bz7991ga9GxK27bVz9jp8OfJDy+x3L9Py5ZOZ5zA72nEgJQPYsZ1T9fj8K3Ljx0CuqzCItDS+n8+/McsoHK236J8q02BmHDbHvvYFmkL8ZOBxF8xgHAiN34+6l+sD/AZQmjnUnAt+OiPsOe8yIuBFwkyF3+1fgt7Xl5cBHI+LxQ557v4j4e+BbQ56/VVVprDc1Vj8sIj4aEQcMepyIWBYR94iI/2F3+SxJewiDj5qmzbX7FwO/nsetqyG7sA57e2Wf7605pt923Wr2eC+rfhYzDuq1rbTYVZkVD6HztWA58O6I+FhE3GVmWk8U142IFwM/ZHbH3xdWU736ne9y4PON1W+MiC9ExIsj4kkR8bjGbX2X45xGKapfd1Xg1Ij4fEQ8NCIOr3+qX/3DfdWIuFdEvCgiTqe8TjyN8X0Qs+hVGeJ3p/N1EErQ7dvVz/fJEXH9ZrAuIg6u3tS8GjgbeC9wjQHOOenr8JvA6xur7wr8NCL+pp7tFxH7V0HQH7C7VmQzy2dJqOrLNX8u16UE9r8dEa+MiKdExLMi4j8o5U/+hVJXeQfwkuYhWxzuM5jdaffxwHer5/9+Mysj4pCI+BtKZmfzzfYXmJ0xpUWs6sL83sbqh7ec/Xg+pW7uKJrToX+TmfMOfFXBqjPmONdYVI0o78TshnA3AD4REWdExEsj4viIuFJEdJRoiIgDIuLWEfEPEXFqNe5jhhzDxcB96QwCrwHeFBE/joi/i4gbdjn3gRFxh4h4akR8AbgA+DfK9PmF5mnMTqa4H/CbiHhtRNw5IvapP1jNPjgmIh5WvW7/kTJN/Z4Yh5D2ODac0TRdCMz8kfpQZv7tNAczZj8BjqstD5ORUd92zXiGIy1MmfmziPhLSvZSfYrUidWNiLiEkl22ku7eAfT7MKDuZZTgVv3v313pXeLgc8yuFUlmvqOa0vVvjXHdrboB7IyIS6vH96Z/9p0qmfmtKPX+PkTnG8Cg8+dLRGyh1Nram/6vl5uAX/Q556Svw2cC16NcizOuBrwVeGtEbKZMcWsGvy+kTJFspYHFAvD3lN/loxvrb17dutlFmXZ9dmN9a5mQmXl5RJxI+TCjHuC+EVX9yojYQAli95qd8R3gYTaVW5JeBjyC3XVIZ7IfT2rxnK+iBMAHLkVUfdBxr8bqcWQ91o9149ryPSLisMxslkCZt8z8ZZR6qu9h9t/zG1W3mQzUrF7Pg/J6M9f74a9QZjzMNYYfRMS9KLWhD6k9dDS7P2TIiLiM8rq1D5OrVTtvmbkjIu4PvJ/OD1L2AZ5S3Yb5uyxpD+MnDpqmenfro6c2inac0VgeaEpClSlV/8fxz+MakLRQZeangDsCveo17kf3gM8W4NmZ+dhB61NVWWePYHZphKFl5huAOzA7A2rGcspzfz29A4/bGb7h1pKXmb+gBJueDVzSZ9M1wJXo/QZnK6VO4rUy83/mOOckr8OtlKBmr660a5kdePw9cNfM7Jntv9hVP7/HUhq6XTLALn8ATsjMd7L7w8wZg+w/ssz8JeVDxl5lVdbTO/D4X8DxmXlhG2PTdFXdrd/fWN129uOlwD8PudvDmf2aNs7g4wfozEBeTvn724oqA/TulO7fZ/fZdOZ/7f3oHXhM4OvA/8vMO2Xm9wccw1cpU7Z7/b0JSrO7/ekfeBzofJOWmZdQpuo/h84mO3Vz/V2G0tRo7EFoSQubwUdNU/0f9ltFxJWmNpLx+1Rj+cYD7nddOrNuzhrLaKQFrprmdT1K3bcfz7H5nyg1164zR/mDXuf6AHAk8CTgY5SMuEvorOM46LG+SZnadX/gi3RO3e3lUkrnyCcAV87MFw573j1BZm6rfr9Xo3SU/jKdjT562Uzp1Po44NDMfEJVp2+Qc07yOtySmQ+m1Ln8bp9NL6M0VbhhZp4x7HkWmyzeBFyL8hz5AiWQsIUSTD6H8vx5DHBkLajcrFvXnLrfxljPz8w7AfehXHP9XkMup3Q7Py4zH1rVONXS9U90NuiYRO3H11EC8oNqToP+SWbO9bo3sMz8LbMbtLUy9bp2zszM91BeP+5HCYAO+lqwkRJwfAHlteV2mfnJEcbwx8y8N3BTSib0nwbYbRvw1erc183MBdtwsvoZvwK4OmW8P2GwMhe/pDTcuidw1blKlEhaesLZHpqPiDgJeGdt1R2zdK0cZN8jKH+IZj51fFVmPnOc45umiPgWcItq8VzgiMzc0WcXIuKFwItqqx6Tmf/ZzgilhSsirkJ5/hxMKVS/ifIP/JnA6Qt1qmKUxiA3o3SCPJCSWbGFMnX7d5Tx/2bQDDl1qmovHkMJHh9KmQa9k/Lm8mJKRv1PsjQhGMf5JnYdRsQ1qnNdmfIh1MWUrNpvpg1J5hQRb6ezgdTfVdnJkxzDeuA2lN/hQZRr80+Uus/fzMyJNMWRtFs1q+jalA/4r0bJSl5JCTZeQnmt/SXw8zb+Nlfnv351O6C67aL8X3A+5QPQX1QZ8YtSlUBS/1u5ht0/319RfrbO5pL2cAYfNS/zCT5W+7+b3VMwdgD3zswvDLF/ACsX4huziHgQnVPqnpmZr+qz/dUo3Xj3rVZdRglYtp69IUnSYlUF/X9DCfrNuGVmfntKQ5IkSVKN0641bc9gd82PFcCnqm5zfQsUR8RhEfF3lCyXY1se40gy80PAN2urXh4RXZvqVHWAvsjuwCPAqw08SpI0p8fQGXi8kNm1lyVJkjQlZj5qIFV3s25Ze+vprLP0B0pdo6ZnZOZHexz7OEpH2Xqx+AspXSTPAC6i1MrZD7gOJdh4E3Y3cDguM08b8FuZqGpq+Tcp0wNnfJ9Sr+p3lE5wt6LUi1td2+Z/gbuPa+qgJEkLXUSsGnYmQ/U/xP/SWS/5lZn57LEOTpIkSSMz+KiBdJlePaxHZebJfY5/FPBxSnBxWLfIzO+MOK7WRcSNKN/bEQPu8lHgrzJzU1tjkiRpoYmIE4HnAm8APtkv+z8i9qF0xX4RsKr20KXA0Zl5bnsjlSRJ0jBWzL2J1L7M/FlE3AB4NKXL6PXn2OWnwGeB9y707p+Z+YOIuCHlDdIjgV5dvX8MvAz44EJtpiFJUstuBpwM7IiI7wI/pHS4vowyQ+BAygyI21IaDjX9jYFHSZKkhcXMRy1IVYfRWwGHAPsD2yjd6H4N/Dgz/zTF4Y0sIlZQOmFek/K9baV0uvtmZv5mmmOTJGmaqszHj424+3ZKh+u3jm9EkiRJGgeDj5Wqa/KRwA2Aq1HqD26m1Bv8AfCjSdffi4hlwK2rcR1GmUp0LvA1G5FIkqSlpKrf+ClKduMwvgI8Z6HWf5YkSdrT7dHBx4hYD5wA3Be4E3BQn80vptQ8/NfM/GOf7cYxrhXAM4En0Nm9ccY2yj/nT8/Ms9sciyRJ0qRU/wPdHrgdcFPgGpT/hdZRygVdSvlg+FfA14DPZebp0xmtJEmSBrHHBh+rwOMFwJohd70IeGxmjjotqK+IOAT4NKXm0VwuozQm+UQbY5EkSZIkSZLmY08OPu5HyWasOws4FTgTuJASmLwh8AA6m4TsBB407gBkROxFmTp0y9rqc4H3UmodHgjck5IRMGMLcKfM/OY4xyJJkiRJkiTNl8HHkj34TuA/M/OHPbZdC7wW+Ova6ouB62TmhWMc078AT6+t+jDw8Mzc2tjuoZROkCurVb+rxrJlTOM4D1hbHVeSJEmSJEl7rqsBmzPz0FF23pODj3sDzwX+JTMvGnCf9wEPra16YWa+ZEzjuSrwS3ZPA/8hcLPM3N5j+2cBr6itenpmvnpMY7ls9erV64888shxHE6SJEmSJEmL1K9//Wu2bt26ITP3GWX/PTb4OIqIuDLweyCqVd/JzFuM6dgvA55TW3WPzPx8n+1XAGcDV6lW/T4zrzamsfzk+te//vV/8pOfjONwkiRJkiRJWqSOPvpofvrTn/40M48eZf9l4x7QUpaZfwB+Vls1ztTA+9XunwN8YY6x7KBMF59x1YgYpEmNJEmSJEmSNBEGH4e3sXZ/3TgOGBHXAI6qrfpSDpaS+sXG8n3GMR5JkiRJkiRpHAw+Du+I2v3zxnTMGzWWTxtwv28DO2rLx4xnOJIkSZIkSdL8GXwcQkTcFji4tuqbYzr0UY3lXw2yU9Xd+g+1Vdcf03gkSZIkSZKkeTP4OJxnNJb/e0zHvWZj+bdD7FvftnkcSZIkSZIkaWoMPg4oIh4CnFBbdQbwiTEdvtmq/KIh9r24dn9lRKwew3gkSZIkSZKkeVsx7QEsBhFxNPC22qodwF9n5q4xnWLvxvKWIfa9vMuxtg6yY0T8pMdD4+ziLUmSJEmSpD2UmY9ziIjDgM/QGSB8VmZ+d4ynWdNY3jbEvs1A417zHIskSZIkSZI0FmY+9hERBwCfBw6vrX5bZr56zKdqZjqu6rKul+Y062YmZE+ZeXS39VVGpM1rJEmSJEmSNC9mPvYQEfsAnwNuWFv9PuDxLZxuY2O5mQnZTzPTsXksSZIkSZIkaSoMPnYREXsDnwVuXlv9YeCRY6zzWHdZY3n/Ifbdr3Z/e2YOVO9RkiRJkiRJapvBx4aIWEup8Xjr2upPAg/NzJ0tnfY3jeWrD7FvfUr4WWMYiyRJkiRJkjQWBh9rImIv4FPA7WurPws8KDO3t3jqnzaWrzXIThGxBrhyn+NIkiRJkiRJU2PwsRIRq4GPA3eqrf4ScP/MHKb79Ch+0Fg+bsD9bkFn06AfjWc4kiRJkiRJ0vwZfAQiYhXwEeButdVfAe6bmYN2nR5ZZv4G+Hlt1V0iIgbY9a6N5U+Pb1SSJEmSJEnS/OzxwceIWAF8ALh3bfXXgBMy8/IJDuVjtfuH0xkInaUa96Nqq84FvtvCuCRJkiRJkqSR7NHBx4hYDrwXuF9t9TeAe2Xmpnke+4iIyNrtlDl2eTNQ71T9qohY2Wf7pwNXqS2/NjNzxOFKkiRJkiRJY7fHBh+rac3vAB5cW30acI/M3Djp8WTm74A31lYdA7yvqkXZISIeAry4tupc4A3tjlCSJEmSJEkazoq5N1mybgs8srHu6sD3Byu3eIU7ZOa5YxrT8ymdtm9WLT8IuHVEvAc4C9gfuBdwh9o+W4G/nERtSkmSJEmSJGkYe3LwcXmXdVce4Tj9pkYPJTM3R8QJwGeAY6vVVwGe1WOXDcAjM/Pr4xqDJEmSJEmSNC577LTrhSozzwNuBbwAOK/HZtsoDWpulJkf67GNJEmSJEmSNFV7bOZjZp4CDDW/esjjnz3q8TNzO/DSiHg5cGvgWsAhlEzH3wNfy8yLxjRUSZIkSZIkqRV7bPBxMcjMncDXqpskSZIkSZK0qDjtWpIkSZIkSVIrDD5KkiRJkiRJaoXTrqUWvf9bvx3r8R56y6uP9XiSJEmSJEltMvNRkiRJkiRJUisMPkqSJEmSJElqhcFHSZIkSZIkSa0w+ChJkiRJkiSpFQYfJUmSJEmSJLXC4KMkSZIkSZKkVhh8lCRJkiRJktQKg4+SJEmSJEmSWmHwUZIkSZIkSVIrDD5KkiRJkiRJaoXBR0mSJEmSJEmtMPgoSZIkSZIkqRUGHyVJkiRJkiS1wuCjJEmSJEmSpFYYfJQkSZIkSZLUCoOPkiRJkiRJklph8FGSJEmSJElSKww+SpIkSZIkSWqFwUdJkiRJkiRJrTD4KEmSJEmSJKkVBh8lSZIkSZIktcLgoyRJkiRJkqRWGHyUJEmSJEmS1AqDj5IkSZIkSZJaYfBRkiRJkiRJUisMPkqSJEmSJElqhcFHSZIkSZIkSa0w+ChJkiRJkiSpFQYfJUmSJEmSJLXC4KMkSZIkSZKkVhh8lCRJkiRJktQKg4+SJEmSJEmSWmHwUZIkSZIkSVIrDD5KkiRJkiRJaoXBR0mSJEmSJEmtMPgoSZIkSZIkqRUGHyVJkiRJkiS1wuCjJEmSJEmSpFYYfJQkSZIkSZLUCoOPkiRJkiRJklph8FGSJEmSJElSKww+SpIkSZIkSWqFwUdJkiRJkiRJrTD4KEmSJEmSJKkVBh8lSZIkSZIktcLgoyRJkiRJkqRWGHyUJEmSJEmS1AqDj5IkSZIkSZJaYfBRkiRJkiRJUisMPkqSJEmSJElqhcFHSZIkSZIkSa0w+ChJkiRJkiSpFQYfJUmSJEmSJLXC4KMkSZIkSZKkVhh8lCRJkiRJktQKg4+SJEmSJEmSWmHwUZIkSZIkSVIrDD5KkiRJkiRJaoXBR0mSJEmSJEmtMPgoSZIkSZIkqRUGHyVJkiRJkiS1wuCjJEmSJEmSpFYYfJQkSZIkSZLUCoOPkiRJkiRJklph8FGSJEmSJElSKww+SpIkSZIkSWqFwUdJkiRJkiRJrTD4KEmSJEmSJKkVBh8lSZIkSZIktcLgoyRJkiRJkqRWGHyUJEmSJEmS1AqDj5IkSZIkSZJaYfBRkiRJkiRJUisMPkqSJEmSJElqhcFHSZIkSZIkSa0w+ChJkiRJkiSpFQYfJUmSJEmSJLXC4KMkSZIkSZKkVhh8lCRJkiRJktQKg4+SJEmSJEmSWmHwUZIkSZIkSVIrDD5KkiRJkiRJaoXBR0mSJEmSJEmtMPgoSZIkSZIkqRUGHyVJkiRJkiS1wuCjJEmSJEmSpFYYfJQkSZIkSZLUCoOPkiRJkiRJklph8FGSJEmSJElSKww+SpIkSZIkSWqFwUdJkiRJkiRJrTD4KEmSJEmSJKkVBh8lSZIkSZIktcLgoyRJkiRJkqRWGHyUJEmSJEmS1AqDj5IkSZIkSZJaYfBRkiRJkiRJUisMPkqSJEmSJElqhcFHSZIkSZIkSa0w+ChJkiRJkiSpFQYfJUmSJEmSJLXC4KMkSZIkSZKkVhh8lCRJkiRJktQKg4+SJEmSJEmSWmHwUZIkSZIkSVIrDD5KkiRJkiRJaoXBR0mSJEmSJEmtMPgoSZIkSZIkqRV7fPAxIpZFxNER8ciIeH1EfDMiNkdE1m7HtzyG4xvnG+Z2szbHJkmSJEmSJI1qxbQHME0R8RHg7sC6aY9FkiRJkiRJWmr26OAjcFMWZuDxHGDHgNtuaXMgkiRJkiRJ0qj29OBj3Vbgh8D3gL2Bh09xLMdn5tlTPL8kSZIkSZI0b3t68PHdwO8oAccfZeZ2gIg4iekGHyVJkiRJkqRFb48OPmbmC6Y9BkmSJEmSJGmp2uO7XUuSJEmSJElqh8FHSZIkSZIkSa0w+ChJkiRJkiSpFXt0zccF7OURcX3gcGAdcAlwHvBN4PPAJzJz5/SGJ0mSJEmSJM3N4OPC9JDG8kHV7YbA3wBnRcTTMvMTEx+ZJEmSJEmSNCCDjwvXxcBllMzHA+icIn9N4OMR8fLMfO6oJ4iIn/R46MhRjylJkiRJkiTNsObjwvFn4PXAPYADM/OAzDwiMw+iBB/vD/xfY5/nRMRTJjxOSZIkSZIkaSBmPi4M3wOumplbuj2YmZcCH4uIjwPPBV5ae/ifI+Kjmfm7YU+amUd3W19lRF5/2ONJkiRJkiRJdWY+LgCZuaFX4LGxXWbmPwFvqa1eDTyjtcFJkiRJkiRJIzL4uDg9D7i8tnzCtAYiSZIkSZIk9WLwcRHKzD8Dp9ZWHR4Rh01rPJIkSZIkSVI3Bh8XrzMbywdPZRSSJEmSJElSDwYfF6/LG8trpzIKSZIkSZIkqQeDj4vXIY3lC6cyCkmSJEmSJKkHg4+L1+1q97cD505rIJIkSZIkSVI3Bh8XoYi4J3Ct2qr/y8zN0xqPJEmSJEmS1I3BxxZExBERkbXbKX223WvIYx8GvLWx+uThRylJkiRJkiS1y+Dj9D04Ik6NiPtGxKp+G0bEXYBvAVerrf4B8J42ByhJkiRJkiSNYsW0BzBNEXF/4FVdHlrfWH5fRDS7SwM8IzM/Ooah3L66XRIR/wf8EPgjsIHSxfoawF2BGzX2Ow84MTN3jWEMkiRJkiRJ0ljt0cFHYB/gyAG2u3Kf/cdpP+De1W0upwEPz8yzxzwGSZIkSZIkaSycdj193wXeCfwMyDm2TeAbwMOB22bmr1semyRJkiRJkjSyPTrzMTNPpoVmLVU2Ygy47Y+BRwNExH7ATYCrA1cC9gK2ApcAZwPfzsxLxz1eSZIkSZIkqQ17dPBxocnMS4CvTHsckiRJkiRJ0jg47VqSJEmSJElSKww+SpIkSZIkSWqFwUdJkiRJkiRJrTD4KEmSJEmSJKkVBh8lSZIkSZIktcLgoyRJkiRJkqRWGHyUJEmSJEmS1AqDj5IkSZIkSZJaYfBRkiRJkiRJUisMPkot+945F/GpH/yBizdvm/ZQJEmSJEmSJmrFtAcgLWV/uORyPnL6uQBcsnkbjzjuiOkOSJIkSZIkaYLMfJRadO4ll19x/6wLN7Erc4qjkSRJkiRJmiyDj1KLLtuy/Yr7W3fs4pLN2/tsLUmSJEmStLQYfJRatGHLjo7lP156eY8tJUmSJEmSlh6Dj1KLNlzemen4x0u3TGkkkiRJkiRJk2fwUWrRZbMyHw0+SpIkSZKkPYfBR6lFG7Z0Zj6e57RrSZIkSZK0BzH4KLVk566cVfPx4s3buXzbzimNSJIkSZIkabIMPkot+fOmrWSX9edd5tRrSZIkSZK0ZzD4KLXkgsu2dl1vx2tJkiRJkrSnMPgoteT8HhmO59l0RpIkSZIk7SEMPkotOb9n5qPBR0mSJEmStGcw+Ci1pJ75eNDeqzvW79zVrRqkJEmSJEnS0mLwUWrJBRt2Zz4eefDeLItyf8eu5MKN3bMiJUmSJEmSlhKDj1JLLqhlPh6wdiVXqmU/OvVakiRJkiTtCQw+Si05f8PuAOP6vVZy2L5rrlg+z47XkiRJkiRpD2DwUWpJveHMPmtWcti+e12xbOajJEmSJEnaExh8lFqwY+cu/ryxHnxc0ZH5aPBRkiRJkiTtCVZMewDSUvTnTduoN7Rev2Ylq1bsjvVv3LqDDVu2s37NyimMTpIkSZIkaTLMfJRacH6t2czqFctYtWIZ69esZO/Vu+P955n9KEmSJEmSljiDj1ILmvUeZzj1WpIkSZIk7UkMPkotqGc+rt9rd7ZjZ/DRjteSJEmSJGlpM/goteCCDd0zHw+147UkSZIkSdqDGHyUWnBBLfNxnzXdMx8v3LiV7Tt3TXRckiRJkiRJk2TwUWpBx7TrWubjlfZezYplAcCuhAtqtSElSZIkSZKWGoOPUgvqDWfW1zIfly8LDtnHuo+SJEmSJGnPYPBRasEFG+rTrld2PHZovenMZdZ9lCRJkiRJS5fBR2nMtu/cxZ83bbtieZ+9OoOPHR2vLzH4KEmSJEmSli6Dj9KYXbhxK5m7l+vTrgEOq3W8Pu+yy8n6xpIkSZIkSUuIwUdpzOr1HvdauZyVyzufZofWaj5u2b6Ly7bsmNjYJEmSJEmSJsngozRmnZ2uV8x6fK9Vy1lVC0hu3mbwUZIkSZIkLU0GH6Uxu+Cy3s1mZuy1avkV9y/ftrP1MUmSJEmSJE2DwUdpzC7YsHvadbfMRyjTsWdsNvgoSZIkSZKWKIOP0pjVp103O13PqGc+btlu8FGSJEmSJC1NBh+lMas3nDHzUZIkSZIk7ckMPkpj1tlwZoCaj2Y+SpIkSZKkJcrgozRm9ZqP+/TIfFy70oYzkiRJkiRp6TP4KI3Rth27uGjTtiuWB+p2beajJEmSJElaogw+SmP0p41bO5Z71nw0+ChJkiRJkvYABh+lMarXe9x/7UpWLO/+FNvLadeSJEmSJGkPYPBRGqMLasHHg9ev6bmdmY+SJEmSJGlPYPBRGqPzL9s97frgfVb33K6e+bh5245WxyRJkiRJkjQtBh+lMbpgw+7Mx0P26Z35uHbV7lqQW7fvYldmq+OSJEmSJEmahokEHyNi30mcR5q2eubjIQNmPiawxanXkiRJkiRpCZpU5uMfIuLdEXH7CZ1Pmop6w5l+mY+rVy4jass2nZEkSZIkSUvRpIKPewEPA74SEWdGxNMj4qAJnVuamAvqNR/7NJxZFsGalTadkSRJkiRJS9ukaz4GcG3gn4HfR8SHIuIeEx6D1JrzazUf+zWcgUbHazMfJUmSJEnSEjSp4OMrgT821q0E7g98JiLOjogXRMTVJjQeaey27tjJJZu3X7Hcb9o1dNZ9NPNRkiRJkiQtRRMJPmbmc4CrAycCnwRmIi1R3a4OvBA4KyL+JyJOjIjl3Y4lLVT1KdcAB+09eObjZjMfJUmSJEnSEjSxadeZuSszP5mZJ1KCjc8FftXYbDlwd+AjlGnZr4iIa09qjNJ8XFCbcn3gulWsWtH/6VXPfLTbtSRJkiRJWoomXfMRgMw8LzNfkZnXAe4IvB+YSRubyYY8BHgG8POI+EpEPDQi+qeSSVN0fr3ZzBxTrsHMR0mSJEmStPRNJfhYl5mnZubDgcOAJwNnNDYJ4PbAe4A/RMS/R8Qxkx2lNLfzL6s1m1k/d5x8rTUfJUmSJEnSEjf14OOMzLw0M9+QmccCNwPeClxWPTyTDbk/8CTg+xHxrYh4TESsm86IpU4XbNid+XjIHJ2uwW7XkiRJkiRp6Vswwce6zDw9Mx9PyYY8CTgfyOo2E4i8GfA24NyIeJ2dsjVtf6oFHw9eP8C0azMfJUmSJEnSErcgg48AEXEo8FTgecDBtYdyZpPq6z7AE4FfRMQ/RcTKiQ1Sqtm4ZccV9/fZa8Wc25v5KEmSJEmSlrq5IyQTFBHLgHsDjwXuSel+fcXD1dc/Au8DbgjctbZ+NfBs4OYRcc/M3DWRQUuVTdt2Bx/Xrhog+GjmoyRJkiRJWuIWRPAxIq4FPBp4JHDozOraJruAL1DqQH4qM3dW+10NeBzweGC/ap+7AE8A3jCJsUsz6h2r161e3mfLorPb9Y4+W0qSJEmSJC1OU5t2HRGrI+LhEfEV4EzgmZQajzM1HaFkOb4MODIz75mZH58JPAJk5u8y87nAtYEv1Q7/iIl8E1LNpq27A4jrhsx83L4z2bHLZF1JkiRJkrS0TDzzMSJuQplW/RBg35nVtU12AV+kkeXYT2b+OSJOAs6hTNU+apxjlgZRn3a9bvXcT63m1OzLt+1k/ZoFW4ZVkiRJkiRpaBMJPkbEvsDDKEHHG82sZnf3aihZju8E3p6Z5wx7jsz8Q0ScA1wTWDfvQUtD2rx1d5x87aq5p12vXB4sj2Bnlh5KJfhovyRJkiRJkrR0TCrz8Y+UhjCwO+hI9fULwNuATw6S5TiHDfPcXxrZxtq0670HyHyMCPZatfyK/Ww6I0mSJEmSlppJBR/X0JnleB7zyHLs44+UxjPSRO3YuYutO3bXbFw7QPARSt3HK4KP2ww+SpIkSZKkpWWSNR/HneU4+wSZ9xr3MaVBbG5kLa4bYNo1dHa8NvNRkiRJkiQtNZMKPr6c8Wc5SgtGvdM1zG4m00u94/VmMx8lSZIkSdISM5HgY2Y+bxLnkaZlU63ZzKrly1i1YrCu1WY+SpIkSZKkpWxS3a5vX929PDO/M4/jHAvsDZCZXx3H2KRx2Lxtd+bj2tWDTbmGRvDRzEdJkiRJkrTETGra9SmUmo+/Aq47j+O8AzimOtYk61VKfdU7Xa8bcMo1dE67NvNRkiRJkiQtNZMM4AW7u13P9zjSgrK5Nu163RCZj2vNfJQkSZIkSUvYYIXpxiMneC5pojbVp12b+ShJkiRJkgRMNvg4DjORmh19t5ImbNOImY/WfJQkSZIkSUvZYgs+HlZ93TjVUUgN9YYzo9Z83GzmoyRJkiRJWmIWTfAxIu4IHEiZvn3OlIcjdejMfBwi+FjLfNyybSeZVieQJEmSJElLx1gbzkTEMcCN+2yyPiL+aohDLgP2BW4APLi2/rThRye1p7Pm4xDTrmuZjzsz2bZzF6tXDL6/JEmSJEnSQjbubtf3A17Q47EADgbeOeKxZ7pcJ/CfIx5DasWmrbuDj3uPmPkIpe6jwUdJkiRJkrRUtDHtOubeZOjj1QOPL8zM7475HNK8bK41ixmm2/WKZctYtXz309CO15IkSZIkaSkZd+bjjF4ByGEDkzsozWXOpky1fmdmfmce45JasbGW+ThMt2so2Y/bLt8F2PFakiRJkiQtLWMNPmbmi4EXN9dHxC5K1uKvM/M64zyntBB0dLseYto1lLqPl16+HTDzUZIkSZIkLS2T7HY97unY0oJR73Y9TMMZ6Kz7aOajJEmSJElaStqadt00kw150YTOJ01UveHMuiFqPkJnx2szHyVJkiRJ0lIykeBjNR1bWrLqDWeGnnZdy3zcbOajJEmSJElaQiY57VpasjZtG73hzFozHyVJkiRJ0hJl8FEag/q067XDTru25qMkSZIkSVqiDD5K87Rtxy6278wrlveex7RrMx8lSZIkSdJSMraajxFRj5pkZq7o8dg4dBxfmqbNtSnXAGuHnHbd0XDGzEdJkiRJkrSEjDOAF0BWX4d5TFrUNm5tBB9XDhl8NPNRkiRJkiQtUeOedt0vuGjgUUtSvUP1mpXLWLF8uKdVPfOxmUUpSZIkSZK0mI0z8/FRIz4mLWr1ZjPrhmw2A53Bx63bd7Erk2VhrF6SJEmSJC1+Yws+Zua7RnlMWuw2bd2d+ThsvUfo7I6dlABkfSq2JEmSJEnSYmW3a2meNm2bX+bj6pXLOmoSOPVakiRJkiQtFQYfpXmqBwvXrR4++LgsgjUrbTojSZIkSZKWHoOP0jxtrE+7HnG6dEfH620GHyVJkiRJ0tIwzoYzYxURa4FrAyuBczLzT1MektTV5lrDmb1HyHyEzqYzZj5KkiRJkqSlYiKZjxGxd0Rcs7pdZY5trxwRHwQuBk4HvgWcFxFfi4hbTGK80jA2batnPo4YfKxlPm4281GSJEmSJC0Rk5p2/a/AL6vbP/baKCIOBU4DHkjJeIza7TbA1yPixLYHKw1j09Z6zccRp13XMh+3mPkoSZIkSZKWiEkFH+8LVzT0fWOf7d4EXLW6n43HkjJN/L0RcfXxDk8a3XwbzoCZj5IkSZIkaWlqPfgYEUcAh1KChz/LzF/22O5o4ER2Bx0vBP4euBfwHGBj9dhewIvaHLM0jE21hjPrRm04Y81HSZIkSZK0BE2i4cz1a/e/2We7R1RfA7gcOC4zz6rWfS4iTgO+XC0/KCKekJlbxjtUaXj1adej1nxca7drSZIkSZK0BE1i2nV9ivTP+mx3z+prAh+sBR7LysxTgFOqxbXAsWManzQvm7bZ7VqSJEmSJKmbSQQf96ndv7jbBhFxJeAGtVX/3eNYp9TuHzW/YUnjUa/RuHbUhjNmPkqSJEmSpCVoEtOuV9buN5vIzLgNuxvSbAdO7bHd72v395/nuBa0iFgG3Bo4EjgMuBQ4F/haZnYN4mo6Nta7XY847drMR0mSJEmStBRNIvi4oXb/wB7b3KH6msD3MvPyHtvVg5er5jswuCLIdxRws9rtRpTGNjPuWE37bl1ErACeCTwBuHKXTbZFxKeAp2fm2ZMYk/rbXG84M4Zu12Y+SpIkSZKkpWISwcdza/d71Wk8oXb/632OdUDt/saRR1SJiI8AdwfWzfdY4xARhwCfpgRAe1kFPAC4a0T8VWZ+YiKDU0/1mo9rx9DtetvOXezYtYsVyyZRFUGSJEmSJKk9kwg+nl59DeCEiDgwM/8882BE3IUytXjG//Y51nVq9/84hrHdlIUTeNwL+ASdgcdzgfcCv6Zkjd4TuH312D7AByLiTpnZr4u4WpSZHd2uR818bHbJvnzbTtavMfgoSZIkSZIWt9ajG5n5G+D7lCnT64BPRcQNImJ1RNwReCe7p1NfSP/g4y1r93855qFuBb4DvIUS8Ju0l9D5/X0YODIzn5WZb8/MV2bmHYCHUepiAqwBPhgRayY8VlW27tjFrloxgHUjNpxZuTxYHnHFslOvJUmSJEnSUjCp1KqXs7uhzC2BHwCbgS8BV6keS+C1mdk16hIRhwPHVItbgB+PYVzvBv6GkgG5PjNvkZmPp38AdOwi4qrAk2qrfgg8NDO3NrfNzPcDL6ituhrwxHZHqF7qWY8wesOZiOis+2jTGUmSJEmStARMJPiYmR+hZBTOBCCjdpvJG/s28Oo+h3nYzOGAb2fmjj7bDjquF1RZhadn5va592jN4ylZjDOeMcd4/pXOWppPbWNQmtumrZ1BwnrtxmHZ8VqSJEmSJC01Eysql5lPoATZzmk8tAV4M3CXzNzWbd+IWMXuzMAA/qetcU7J/Wr3zwG+0G/jKvD6ztqqq0ZEvyY1akm92cy6VctZtiz6bN2fHa8lSZIkSdJSM4mGM1fIzLcCb42IawCHUqZe/6xX0LFmf+DZteXPtTTEiat+FkfVVn0pM7PX9jVfBJ5XW74P8N1xjk1z21zvdD1is5kZZj5KkiRJkqSlZqLBxxlVE5rfDLH9+cC72hvRVN2osXzagPt9G9jB7t/hMX22VUs21qZdr1s1+pRr6Mx83GzmoyRJkiRJWgImNu1aPR3VWP7VIDtl5hbgD7VV1x/biDSwzbWGM+vmm/lowxlJkiRJkrTEGHycvms2ln87xL71bZvH0QRs2lbPfBzjtGszHyVJkiRJ0hJg8HH69mksXzTEvhfX7q+MiNVjGI+GsGlrvebj/KZdr7XhjCRJkiRJWmKmUvMxItZSah0eBewHrKN0sR5YZr5k/CObir0by1uG2PfyLsfaOujOEfGTHg8dOcQY9mgd3a5tOCNJkiRJktRhosHHiLgBpUPzfYH5ZuktleDjmsbyXJ2/65qBxr3mORYNaXNLDWfMfJQkSZIkSUvBxIKPEfF44LXVOWeyHJMhMx5r+y0VzUzHVV3W9dIM4DYzIfvKzKO7ra8yIm1gM4CN9WnXY6z5uNnMR0mSJEmStARMJPgYEScCb6wW64HDpNQ43DiJcSxQze99DYMHH5uZjnvyz3EqNtemXe89xm7XW7btJDOJGCU2L0mSJEmStDC0HnyMEj35t2pxJtPxv4C3At/OzGFqHC5FlzWW9wcuGXDf/Wr3t2fmwPUeNR71btfzbThTz3zcmcn2ncmqFQYfJUmSJEnS4jWJzMebA0ewO+PxMZn5zgmcd7H4TWP56l3W9XJ47f5Z4xmOhlHvdr1uvtOuGzUjN2/bwaoVq+Z1TEmSJEmSpGlaNoFz3Lh2/38NPM7y08bytQbZKSLWAFfucxxNQEfDmXlOu16xbBmrlu9+StrxWpIkSZIkLXaTCD4eULv/2Qmcb7H5QWP5uAH3uwWdmas/Gs9wNIxN2+qZj/Obdg12vJYkSZIkSUvLJIKPf67dv3gC51tUMvM3wM9rq+4Sg3UZuWtj+dPjG5UGVZ92vXaemY/QWffRzEdJkiRJkrTYTSL4eHbt/kETON9i9LHa/cOBu/XbOCJWAI+qrToX+G4L49Ic6g1n9p5nwxkw81GSJEmSJC0tkwg+nsLu7Mc7TuB8UxcRR0RE1m6nzLHLm4F6p+pXRcTKPts/HbhKbfm1mZm9NlZ7NtczH+fZcAbMfJQkSZIkSUtL68HHzNwOvBEI4G4RceO2z7nYZObvKD+jGccA74uI1c1tI+IhwItrq84F3tDuCNXNrl3Zkfk4327X0Jn5uNnMR0mSJEmStMjNP1oymJdSahQeB3wkIu6Ymb+d0Ll7ioj7A6/q8tD6xvL7IuLyLts9IzM/OqbhPB+4PXCzavlBwK0j4j3AWcD+wL2AO9T22Qr8ZWZuGdMYNIRmZuK6MUy7XmvmoyRJkiRJWkImEnzMzJ0RcU/gv4B7Aj+IiFcA78rM8ycxhh72AY4cYLsr99l/LDJzc0ScAHwGOLZafRXgWT122QA8MjO/Pq4xaDj1TtcA68bRcMaaj5IkSZIkaQmZSPAxIr5c3V0G7AL2BV4BvCIizgHOA4bJ3svMvPN4Rzl9mXleRNyKEnB8AnBol822UQKU/1B1ytaUbNq6Ozi4LGD1ivlXMegIPpr5KEmSJEmSFrlJTbs+Hqg3RElKDUiAIygdngcVjWONLDNPBk4ex7Eaxz2b3d/fsPtuB14aES8Hbg1cCziEkun4e+BrmXnRmIaqedhUazazbvUKIkb6lXfoaDhj5qMkSZIkSVrkJhV8hP7BuPlHbZaYzNwJfK26aQHaPOZmM2C3a0mSJEmStLRMKvj4rgmdR5qYeubj2jE0mwFrPkqSJEmSpKVlUg1nHjWJ80iTVG8400bm45btO9mVybIxTOeWJEmSJEmahvl3yJD2UJtrDWfWjSnzcW0tiJnA1u27xnJcSZIkSZKkaTD4KI1o49bxZz6uXrmsowDq5lp2pSRJkiRJ0mJj8FEaUT0wuHb1eIKPyyJYY9MZSZIkSZK0RBh8lEa0qdYQZu8xTbuGRtMZg4+SJEmSJGkRm1S361ki4l7AXYFbAlcF9gfWAr/KzOs2tl0J3KRa3JmZ35vkWKVuNte7XY9p2jV0Np2x47UkSZIkSVrMJh58jIiHAS8FDq+v7nEfgMzcHhHvBq5dHeMmmfnDVgcqzWFjveHMKjMfJUmSJEmSmiY27ToiVkTEfwHvpgQeo3aD0ty3nzfWtn14K4OUhlCv+bhuTDUfwcxHSZIkSZK0dEyy5uP7gAezO+C4Hfgs8GLgCdW6fgHID9Uev2d7w5QGU6/5OK6GM9DIfDT4KEmSJEmSFrGJTLuOiL8EHkQJHgbwEeDJmfnH2jZv6neMzDwvIr4PHAtcPyIOzMw/tzhsqa9NtZqPY512Xct83Oy0a0mSJEmStIhNKvPxxbX7b8nMB9UDj0M4vXb/BvMckzQvHcHHMWY+rjXzUZIkSZIkLRGtBx8j4vqURjEJ/A74+3kc7he1+0fOZ1zSfG3eVm8401LNRzMfJUmSJEnSIjaJzMdja/f/OzO3zuNYl9Tu7z+P40jzVs98XLu6pW7XZj5KkiRJkqRFbBLBx0Nq98+c57HqkZhV8zyWNC+bat2u926r27WZj5IkSZIkaRGbRPCx3sF6vhGaA2r3L57nsaSR7dyVbNm+64rlteNsOGPmoyRJkiRJWiImEXy8oHZ/vnUab1y7f/48jyWNrJ71CO3VfNy2cxc7du3qs7UkSZIkSdLCNYng409r9+8z6kEiYjVw99qq00YekTRPm7d2ZiSOt9t157HMfpQkSZIkSYtV68HHzDydkqUYwHUj4pEjHuqJwJUo07h/mpl/HNMQpaHVMx9XLg9WrRjfU2nl8mB5xBXL1n2UJEmSJEmL1SQyHwHeUX0N4E0Rcddhdq62f3lt1evGNTBpFB2drsc45RogIlhj3UdJkiRJkrQETCr4+EpK7ccE9gL+JyLeFBHX6bdTRBwQEa8APk3pbp3AL4D/bHm8Ul+batOux9npesZaO15LkiRJkqQlYPxRky4yc2NEnAh8iRJ8XA78LfC3EXEW8JPa5gdExJuBo4FbVdvOzEG9DDgxM43GaKo2b6tnPo6v0/UMO15LkiRJkqSlYFKZj2TmacAJdHa/DkoH7BMoWY0A+wN/A9yGzuDoH4F7ZuaZ7Y9W6m9jfdp1C5mPe5n5KEmSJEmSloCJBR8BMvMrwDHAycD22kPR2DRq63YC7wGOrQKY0tRt3lafdt1u5uNmMx8lSZIkSdIiNZFp13WZ+Sfg0RHxbOCBwO2AGwEHAvsBm4ELgZ8DXwE+nJnnTHqcUj9tNpyBxrRrMx8lSZIkSdIiNfHg44zMPB94Y3WTFpV6w5l1bdR8rE273mLmoyRJkiRJWqQmOu1aWirqDWfWtdHt2mnXkiRJkiRpCTD4KI1gU8vBRxvOSJIkSZKkpWCi064jIoBjq9uVgAOAfYBLgYsotR6/m5lnTHJc0rDq067XtjHtul7z0cxHSZIkSZK0SE0k+BgRtwCeAdyZEmyca/tLgC8C/5KZ32t3dNLw6g1n9jbzUZIkSZIkqatWp11HxCER8Wngm8D9gH2BqG5dd6lu+wMPAr4dER+PiIPaHKc0rHodxla6Xa/szHzMzLGfQ5IkSZIkqW2tBR8j4ijgNOCe7A421iMo0eVGY7sATgC+GRHXaWus0rA2bq3XfGx32vXOTLbvNPgoSZIkSZIWn1amXUfE1YCvAgdSAolJCSRuomRBfgM4G7gY2AisB/YDrgkcB9wKWMfuIOQ1gVMj4maZeW4bY5aG0dHtuo3Mx0Ydyc3bdrBqxaqxn0eSJEmSJKlNbdV8/E92Bx4D+BPwz8DbM3PDXDtHxD7A3wL/SGlMk8AhwH9QMimlqepoONNC5uOKZctYtXwZ23buAkrdx/3GfhZJkiRJkqR2jX3adUTcndJYZiZr8TvAsZn5b4MEHgEy87LM/BdKV+zvsntK9t0i4s7jHrM0rE0tZz5Co+O1TWckSZIkSdIi1EbNx6dUXwP4HXD3UadKZ+bvgXtUx5kJZj51vgOU5mtzLfNxXQvdrmF20xlJkiRJkqTFZqzBx4g4ALhLtZjAYzPzkvkcMzMvAh7L7qY0d4uI/eZzTGk+tu3YdcV0aGin4Qw0Mh8NPkqSJEmSpEVo3JmPd6TUkUzgR5n5pXEcNDO/CPyoWlwB3Gkcx5VGUW82A7C2rWnXK512LUmSJEmSFrdxBx9vVbv/7jEfu368W/XcSmrZpkYW4rpVZj5KkiRJkiR1M+7g4/Vq97815mOfVrt//TEfWxrY5q27Mx9Xr1jGiuVtlE6FtbXMx81mPkqSJEmSpEVo3FGTI2r3vzfmY59eu3/4mI8tDWxjLfjYVrMZMPNRkiRJkiQtfuMOPh5cfb08M7eM88CZeTmwmdJ05uA5Npdas3lbvdN1O1OuAdZY81GSJEmSJC1y4w4+7k1pNnPJmI87Y+a461s6vjSnTfXMx5aazQCsNfNRkiRJkiQtcuMOPq6uvm4e83FnXF59XdXS8aU5bap1u17bUrMZsNu1JEmSJEla/MYdfGyn88ZsMaHzSLNs2lqfdm3NR0mSJEmSpF4mFSyUlozN2yYz7bqe+bhl+052ZbZ2LkmSJEmSpDYYfJSGtLGW+bi2xYYza2uBzQS2bt/V2rkkSZIkSZLa0Fba1vqI+Ks2jtvCMaWhbK41nNm7xWnXq1cuIyiBR7DuoyRJkiRJWnzaipwcDLyzpWNLU7WpVn9xbYvTrpdFsGbl8iuCjvXp3pIkSZIkSYtBe5GTdprCWPROU7dpa73mY3vTrqE0nZkJPpr5KEmSJEmSFps2go9tdqK2y7WmrqPhTIvTrqGz6YwdryVJkiRJ0mIz7sjJo8Z8PGnB2VRrOLOuxYYzUDIfZ5j5KEmSJEmSFpuxBh8z813jPJ60EG2qZT62WfMRzHyUJEmSJEmL27JpD0BabDZNqNs1NDIfDT5KkiRJkqRFxuCjNKTNHd2uW552vdJp15IkSZIkafEy+CgNaePWyTWcqQc3N5v5KEmSJEmSFhmDj9IQMrMjCDjJbtf1LtuSJEmSJEmLgcFHaQhbd+xi5668Ynldy9Ou16/ZHdzcsMXgoyRJkiRJWlwMPkpDqDebAVjbcubj+jUrr7i/YcsOMrPP1pIkSZIkSQuLwUdpCM26i2tXtpv5uM9eu4OP23bu6qg3KUmSJEmStNAZfJSGsKlWd3HtquUsWxatnm/tquXUT3HBhq2tnk+SJEmSJGmcDD5KQ6hPu167qt0p1wDLIjqmXp9/2ZbWzylJkiRJkjQuBh+lIWzaunva9d6r251yPaPedOaCy8x8lCRJkiRJi4fBR2kIm7dNNvMROpvOXLDBzEdJkiRJkrR4GHyUhrCxlvm4bkKZj/vUMh/PN/NRkiRJkiQtIgYfpSFMJ/OxNu3ahjOSJEmSJGkRMfgoDaGz5uNkgo/72HBGkiRJkiQtUgYfpSF0drueRsMZg4+SJEmSJGnxMPgoDWFTbdr1ugllPnY2nNlKZk7kvJIkSZIkSfNl8FEawuZpNJzZa3fwcfO2nWysZV9KkiRJkiQtZAYfpSFsnELDmbWrlrMsdi/bdEaSJEmSJC0WBh+lIWyuZR2um1DNx2URHVOvbTojSZIkSZIWC4OP0hA2batPu55M5iM0m86Y+ShJkiRJkhYHg4/SEOrdricbfKw3nTHzUZIkSZIkLQ4GH6UhbK5lPq6d0LRrgH1qmY/nm/koSZIkSZIWCYOP0hDqmY97T2natTUfJUmSJEnSYmHwURpCPfg4qW7XAPt0TLs281GSJEmSJC0OBh+lAe3alWzeXm84M7lp1x01H818lCRJkiRJi4TBR2lAW3bsJHP38tS6XW/YStYHIkmSJEmStEAZfJQGtLE25Rpg3SSnXe+1O/Nx87ads8YiSZIkSZK0EBl8lAa0eevuKdfLAtasnNzTZ+2q5SyL3ct2vJYkSZIkSYuBwUdpQJu27c42XLdqBRHRZ+vxWhbRWfdxg3UfJUmSJEnSwmfwURrQplrm49oJNpuZ0VH30cxHSZIkSZK0CBh8lAbUzHycNDMfJUmSJEnSYmPwURpQvebjJDtdz9inlvlozUdJkiRJkrQYGHyUBrSp1mF67arpTrs+/zIzHyVJkiRJ0sJn8FEaUMe066lkPtanXZv5KEmSJEmSFj6Dj9KANm+b7rTrjpqPZj5KkiRJkqRFwOCjNKCNW+sNZ6bc7XrDVjJz4mOQJEmSJEkahsFHaUCbO2o+TmHa9V67Mx83b9vZEQyVJEmSJElaiAw+SgPaVJt2vffqyWc+rl21nBXL4oplO15LkiRJkqSFzuCjNKCObtdTqPm4LIKD1q++YvmCDdZ9lCRJkiRJC5vBR2lA9czHadR8BDh4nzVX3L/AzEdJkiRJkrTAGXyUBlSv+TiNbtcAB9cyH8+347UkSZIkSVrgDD5KA9o45YYzAIfsU592beajJEmSJEla2Aw+SgPaXJ92PYWGMwAHr9897drMR0mSJEmStNAZfJQGtHnb9Kddm/koSZIkSZIWE4OP0oDq067XTWnadWfDGTMfJUmSJEnSwmbwURrAzl3Jlu27rlheO61u1x0NZ7aSmVMZhyRJkiRJ0iAMPkoDqE+5Bth7atOud2c+Xr59Z0c2piRJkiRJ0kJj8FEawKatOzuW106p4cwBa1exYllcsXz+ZdZ9lCRJkiRJC5fBR2kAm2qZjyuWBauWT+eps2xZcND6etMZ6z5KkiRJkqSFy+CjNIDNtczHdatXEBF9tm5XZ9MZMx8lSZIkSdLCZfBRGkBnp+vpTLme0dl0xsxHSZIkSZK0cBl8lAZQbzizdkrNZmYcsk992rWZj5IkSZIkaeEy+CgNYNO2zmnX03TI+t3Trs18lCRJkiRJC5nBR2kAmxbStGszHyVJkiRJ0iJh8FEaQD34uHbVdDMfOxvOmPkoSZIkSZIWLoOP0gA216Zd7716ITWc2UpmTnE0kiRJkiRJvU03hWuBioijgWOAKwM7gXOB72bmb6Y6ME1NR+bjtGs+1jIfL9++k0s2b2f/daumOCJJkiRJkqTuDD7WRMQDgedTAo/dHv8G8NzMPKWFcx8PfGXE3W+emd8d32jUtGnbwqn5eOC6Vey710ouvXw7AGeev4FbXfPAqY5JkiRJkiSpG6ddAxGxPCLeCXyIHoHHyq2B/42Il05mZFooNm9dON2uI4LrHrr+iuUzz9swxdFIkiRJkiT1ZuZj8RrgpNryZuB9wBnAKuCWwAOAlZSA7fMi4qLMfE2LYzoH2DHnVoVdR1q2saPb9fSfNtc7dD3f/s1FAPzc4KMkSZIkSVqgph9FmbKIuDfwd7VVPwXukZm/a2x3I+B/KHUgAf41Ir6UmT9qaWjHZ+bZLR1bQ6o3nFk75YYzANc5ZHfm4y/ON/goSZIkSZIWpj162nVELANeXlu1GTihGXgEyMwfAA8CdlWrmvtqCavXfNx7ytOuoWQ+zvjFeRvseC1JkiRJkhakPTr4CNyZzhqPr8vMs3ptnJnfoNSFnHGfiLhWW4PTwtHR7XoBTLu+Ti34uGHrDs695PIpjkaSJEmSJKm7PT34eL/G8n8MsM/bG8snjmcoWsg21RvOTLnbNcA+a1Zylf32umLZpjOSJEmSJGkh2tODj/eu3f91Zv56gH2+RmeDl/uMd0haiDbXpl1Pu9v1jHrHa5vOSJIkSZKkhWiPDT5GxH7A1WurThtkv8zcBnyvtuqYXttq6ejIfFwADWegM/ho5qMkSf+/vfuOj+sq8z/+fWbUJcuSe4m705wGCYnTkw0JLQtLC7AhhLIsgVAWtoT9hYXfssvSFn6EXTphwyYQ6kJCykIIgfQKpOEUd8d27LhILurl+f1xp9y51sgz1ozmSvN5v17z8j1nzr33kXQsXT06BQAAAHFUtclHSUdHymuKODc8QrLdzOaUIJ6oT5vZo2bWYWb9ZvaCmT1uZt80s9ebWTwyYFWgf3BY/UPDmXIc1nyUIpvOsOM1AAAAAACIoWpOPi6NlDcVcW60bfRapfCXkk6Q1CapVtJMScdJeo+k/5H0rJn9RRnui4ie/qGcchynXa/dsV8DoQQpAAAAAABAHFRz8rE1Ut5dxLkdkfKUEVuNXYekjZJ2SopmlpZKusHM/q1M90bK/tB6j5LUFIMNZyRp6YwW1SRMkjQw5Fq3o6vCEQEAAAAAAOSq5uRjS6TcO2KrkfUc5FqHapek/5T0CknT3X2auy9295mSpkl6vaR7I+dcaWZ/cyg3M7M/jfSStGwsH8Rk092XTT7W1SRUm4zHf5u6moSWzmzOlJ/etreC0QAAAAAAABwoHlmUymiIlPuLOLcvUm4cYyxSsInNYe7+IXf/lbvnjMR09z3u/nNJZ0n6eOTcz5nZghLEgBF0haZdt8RkynXakXOyA3jZdAYAAAAAAMRNNScfoyMd64o4tz5Sjo6ELJq773P3g46+9MCnJH0jEs8Vh3DPY0Z6KXdDnarXFRr5GJcp12lHseM1AAAAAACIsWpOPu6PlKMjIUcTHekYvdZ4+CflJj1fXYEYqkI4+dgck52u046cnU0+Pk3yEQAAAAAAxEw1Jx+jC+S1F3FuW6Q87lkfd98l6c5Q1SIzmzvecVSD7tC06+b6eI18DO94vaWzR/t6ByoYDQAAAAAAQK5qTj6uj5QXFnHuokh53RhjOVTPRMqzKhLFJNcV2u26OWZrPh7W3pizDuWz2ysxCBcAAAAAAGBk1Zx8XBUpLy/i3PBu0B3uvq0E8RyK6FqTTRWJYpKL85qPZqYjZmc3W2fdRwAAAAAAECdVm3x0905Jm0JVpxVynpnVSTopVPVECcMq1uxIeWdFopjkuvrC067jNfJRyp16/cy26GoCAAAAAAAAlVO1yceUW0PHy8xsaQHnnKXczWluLm1IRTkrdDwgaUulApnMuvvju+GMxKYzAAAAAAAgvqo9+fjzSPmvCzgn2uaG0oRSHDN7pXKnit/r7t2ViGWy2x8a+dgUsw1nJOnIOa2Z42e275O7VzAaAAAAAACArGpPPt4u6clQ+YNmtiRfYzM7TdJFoapb3H11nraLzcxDr9+Nct3GYoJO7Wr9zUj1d4u5BgoX95GPR4WmXXd2D+iFfX0VjAYAAAAAACCrqpOP7j4s6cpQVbOkm8xsQbStmR0v6SfKfs6GJX2sRKG82czuNLPXpNaUzMvMzpf0oKRwjI9Juq5EsSAi7ms+tjfXadaU+kyZTWcAAAAAAEBcxC+TMs7c/SYz+5qky1NVx0h6ysy+L+lRSbWSTpX0xtRx2kfd/bEShnJ26tVpZvdKelzS85L2KdjFeomkCySdEDlvm6TXphKpKIPwbtfNMdvtOu3IOVMyIx6f2bZPZx8xs8IRAQAAAAAAkHxM+5CkKZLelio3S3pPnrYu6bPu/oUyxdIm6cLU62AekHSJu28oUyxQ7rTrphiOfJSCqdd3rw42O2fTGQAAAAAAEBdVPe06zd2H3P1SSW9W7hqQUQ9IOt/drxylzaF4RNI1kp5SkNwcjUu6T9Ilks5097UljgURXf3ZadctMdxwRpKOCO14/cz2vRWMBAAAAAAAICuew7gqxN1/LOnHZnaspOMlzZM0JGmrpIfdfV0R19ogyQps+6Skd0mSmbVJerGkhZJmSGqU1CepU9IGSQ+5+55C48DYhaddN8VwwxlJOiq04/Xq7fs1NOxKJgrqfgAAAAAAAGUTz0xKhaWSgaONgCznvTsl/bYS98bIctd8jOd/mcNntyhh0rBLfYPD2rCrS8tmtlQ6LAAAAAAAUOWYdg2Mwt3V3R/e7Tqe064bapNaPL05U35yC4NjAQAAAABA5ZF8BEbRNzisweHsMpzNMd1wRpJetLAtc/zbp1+oXCAAAAAAAAApJB+BUYRHPUpSU108Rz5K0gVHz84c3/H0CxoYGq5gNAAAAAAAACQfgVGF13uU4rvhjCSdfcRM1dUE/6X39g7q4fW7KxwRAAAAAACodiQfgVF09WeTj421yVjvIN1cX6Mzlk3PlG9btb2C0QAAAAAAAJB8BEbV1Rf/zWbCLlgxJ3P861Xb5e6jtAYAAAAAACgvko/AKPb2DmSOW2K82Uza+UfPyhxv6ezRU8/vq2A0AAAAAACg2pF8BEbR2d2fOW5rqqtgJIWZ1dqgFy1oy5R/zdRrAAAAAABQQSQfgVF0dGVHPrY31VYwksJdsCK76/Wvn9pWwUgAAAAAAEC1I/kIjCI88rG9Of4jHyXpZaHk45Nb9mprZ08FowEAAAAAANWM5CMwio7u8MjHiZF8XD6rRUtmNGfKtz/F1GsAAAAAAFAZJB+BUXSERz5OkGnXZpY79Zp1HwEAAAAAQIWQfARG0Rka+TgRNpxJCycfH1i3K2fXbgAAAAAAgPFC8hEYRe7Ix4mTfDxxYbump9aoHBhy/e6ZHRWOCAAAAAAAVCOSj8AoOrom3rRrSUomTOcdNStTZuo1AAAAAACohJpKBwDEWUfMpl1f/+CmgtvW1yQzx7f9aZuuvX+DahK5f2+4eOXCksUGAAAAAAAQxchHII/egSH1DAxlyu3NE2fkoxTsel2bNElS3+Cw1u/sqnBEAAAAAACg2pB8BPIIbzYjTaw1HyWpriah5TNbMuUnNu+pYDQAAAAAAKAakXwE8ghvNtNQm1BDbXKU1vF0zPypmeM/PtepvT3seg0AAAAAAMYPyUcgj4m603XY8fOnqrUhWNp1aNh192p2vQYAAAAAAOOH5COQR2fMNps5FDXJhM46fGam/NCG3drfN1jBiAAAAAAAQDUh+QjkER75OG2CbTYTdvLiaWquD0Y/Dgy57l2zs8IRAQAAAACAakHyEchjMox8lIKNZ85aPiNTvn/dLnX3M/oRAAAAAACUH8lHII+OrvCajxN35KMkrVwyTY2pDXP6B4d139pdFY4IAAAAAABUA5KPQB4doZGPE3XDmbT62qTOWD49U75v7U71DgxVMCIAAAAAAFANSD4CeXSG1nycyNOu005bOkP1NcF/+d6BYT2wjtGPAAAAAACgvEg+Anns7p48064lqbEuqdOWZUc/3rNmJ2s/AgAAAACAsiL5COTROYmmXaedsWyGapMmSeruH9L1D26qcEQAAAAAAGAyI/kI5NGRM+164o98lKTm+hqtXJId/fj1363N2VgHAAAAAACglEg+AiMYGnbt6Zl8Ix8l6azDs6Mfd3X1619vXlXhiAAAAAAAwGRF8hEYwd6eAblny5Mp+TiloVYXHD07U/7ZH7fojqe3VzAiAAAAAAAwWZF8BEYQnnKdMGlKQ00Foym905fP0IL2xkz5yp89qb29A6OcAQAAAAAAUDySj8AIOkKbzbQ11SmRsApGU3oJM73+xMNUlwy+BWzb26vP3Pp0haMCAAAAAACTDclHYASdoZGP7ZNks5mo2a0N+uB5yzPlHzy0Sfet2VnBiAAAAAAAwGRD8hEYQXjk42Ra7zHqvecu04q5rZnyR3/2uLr7BysYEQAAAAAAmExIPgIjCI98bJvEycfaZEKff+PxSqamlT+3u0f//qtnKhwVAAAAAACYLEg+AiPoqIJp12nHzp+q956zNFP+7n0bdC/TrwEAAAAAQAmQfARGkDPtunnyjnxM++B5h2v5rBZJkrv0Nz/8o7bv7a1wVAAAAAAAYKIj+QiMoKMrPO16co98lKSG2qS+/JYXqa4m+Jawc3+/Pnj9HzU4NFzhyAAAAAAAwERG8hEYQe6068k/8lGSjpk3VZ98zTGZ8kMbdusLtz1bwYgAAAAAAMBER/IRGEFnzm7Xk3/kY9pbTl6g1714fqb8jTvX6jdPba9gRAAAAAAAYCIj+QiMoKNKdruOMjP92+uO1eGp9R8l6W9//Jie291dwagAAAAAAMBERfIRiHD33A1nqij5KElNdTX6+iUnqrE2KUna0zOgD1z/B/UNDlU4MgAAAAAAMNGQfAQiegaG1D+Y3WilmqZdpy2fNUWfef1xmfJjm/fokzetqmBEAAAAAABgIqqpdABA3IRHPUqTe9r19Q9uGvX9UxZP00MbdmfadvUNauWS6XnbX7xyYUnjAwAAAAAAExsjH4GIjq7seo8t9TWqq6ne/yYXHj9XC9obM+WbHtuq9Tu7KhgRAAAAAACYSKo3qwLkEd7puq0Kp1yH1SYTeuvKRZrSEAySHnbp+gc3qjO0IQ8AAAAAAEA+JB+BiPBO19W22cxIWhtrdcnKRUomTJLU1T+k7z2wMWddTAAAAAAAgJGQfAQiwqP6qn3kY9qCaU167YvmZ8pb9/TqZ3/cLHevYFQAAAAAACDuSD4CEbu7stOuGfmYddKidp2+LLvZzOOb9+jOZ3dUMCIAAAAAABB3JB+BiNxp14x8DHvlsXO1dGZzpnzbqu16fHNn5QICAAAAAACxRvIRiMidds3Ix7BkwnTxyQs1rTn7efnp7zdrAztgAwAAAACAEZB8BCI6usPTrhn5GNVUX6N3nLZYjbVJSdLgsOu6BzZq576+CkcGAAAAAADihuQjEBEe+djezMjHkcyYUq9LTs3ugN0zMKTv3r9Bu/aTgAQAAAAAAFkkH4GI8MhHpl3nt2RGs9540mGZ8u6ufr372kfUOzBUwagAAAAAAECckHwEIthwpnAnHNaml6+YnSn/cVOnPvzDRzU07BWMCgAAAAAAxAXJRyBkcGhY+3oHM+V2Rj4e1NlHzNTJi6dlyr/80zZd8dPHNUwCEgAAAACAqkfyEQjp7BnIKbPm48GZmV5zwjwdMbslU/c/f9isf7rxSbmTgAQAAAAAoJqRfARCwpvN1CZNzXXJCkYzcSQTpotPWaSVS7IjIK9/cJP+5eZVJCABAAAAAKhiJB+BkOhmM2ZWwWgmlrqahL7zjpN14sK2TN01927QZ3/5NAlIAAAAAACqFMlHIKSji81mxqKlvkbffdcpOv6wqZm6b965TlfdvrqCUQEAAAAAgEoh+QiEhHe6bmOzmUPS2lCra991io6aMyVT9+XfrNYXfvUMIyABAAAAAKgyJB+BkPC0a0Y+Hrq2pjp9790rtXxWdhOar/x2jT52w5MaYhdsAAAAAACqBslHICQ88rGdkY9jMqOlXte/e6WOnJ0dAXn9g5v0wR/8QX2DQxWMDAAAAAAAjBeSj0BIZ1fuhjMYm1mtDfrRZafqpEXtmbpbn9imd333Ye3vG6xgZAAAAAAAYDyQfARCckc+Mu26FNqa6nTdX52ic4+cmam7d80uXfztB7Rrf18FIwMAAAAAAOVWU+kAgDjpzFnzkZGPxbr+wU1533vpUbPV0dWvxzbvkSQ9vnmPLvjSXbr0tEWaNaVhxHMuXrmwLHECAAAAAIDxwchHICR3t2tGPpZSMmG66CULdNqy6Zm63V39+sada7Xmhf0VjAwAAAAAAJQLyUcgJGe362ZGPpZawkx/ftxcvfyYOZm63oFhffe+9Xpw/a4KRgYAAAAAAMqB5COQ4u7qZLfrsjMznXPETF18ykLVJk2SNOzSjY9u1S2Pb9Wwe4UjBAAAAAAApULyEUjZ3zeoweFs4osNZ8rr2PlT9Z6zlqm1Ibv07L1rd+m6+zeqb2CogpEBAAAAAIBSIfkIpIQ3m5GkqY0kH8ttfnuj3nfucs2bmt1w5pnt+/TNu9blrL8JAAAAAAAmJpKPQEo42dXaUKOaJP89xsPUxlq95+xlWjG3NVO3bW+vvva7tfrDpo4KRgYAAAAAAMaK7AqQsrsrtN4jm82Mq7qahC5euVBnHz4zU9fVN6i3fOsB/eKxrRWMDAAAAAAAjAXJRyAlPO26jc1mxl3CTK84do7ecOJ8JS3YiKZ/cFgf+sEfddXtz8rZiAYAAAAAgAmH5COQ0pGz0zXrPVbKSYum6Z1nLlZjbTJTd9Xtq3XZdb/Xvt6BUc4EAAAAAABxQ/IRSOkIjXxsZ+RjRS2d0aL3nbtMS2c0Z+puW7Vdf/GVe7V6+74KRgYAAAAAAIpB8hFIeWFvb+Z4Gms+VtyMlnr9/PIzdM4R2XUg1+3s0l989V7d8vjzFYwMAAAAAAAUiuQjkLL6hf2Z46Uzm0dpifEytalW//WOk/Wh85Zn6rr7h/T+6/+gT9/6lAaHhisYHQAAAAAAOBiSj4Akd8+Zznv4rCkVjAZhyYTpb192pK6+9CWa0lCTqf/WXet00Tfv1/qdXRWMDgAAAAAAjIbkIyBpx74+7e0dzJQPn9VSwWgwkvNXzNZNHzhTR87OJob/uKlTr/ry3brugY3shg0AAAAAQAyRfASUO+V6Rkud2lnzMZYWz2jWz99/ut5w4mGZup6BIX38hif1jmse1vbQup0AAAAAAKDySD4CElOuJ5Cmuhp98U0n6BuXnKj2ptpM/Z3P7tDLr7pLN/xxC6MgAQAAAACICZKPgHJHPh4+mynXE8Erjp2rX33kbJ131KxMXWf3gD78o0d10Tfu1xOb91QwOgAAAAAAIEk1B28CTH45yUfWe4yN6x/cdNA2Lz1qlqY21OqWJ55Xf2r360c2dug1X7lHJy5q18tWzNaUhmCE5MUrF5Y1XgAAAAAAkIvkI6pedKfr5Uy7nlDMTCcvmaZls1p06xPPa9XzeyVJLun3Gzv05JY9OveImTp12fTKBgoAAAAAQBUi+Yiqt6urXx3dA5ky064npmnNdbrk1EVa88J+3fz4Vr2wr0+S1Dc4rF+t2q671+xUd/+QLj1tUWYkJAAAAAAAKC/WfETVW709O+V6WnOdZrTUVzAajNXyWS364HmH69UnzFNjbTJT390/pH//1TM683O/1VW3P6s9oYQzAAAAAAAoD5KPqHprXghPuWbU42SQTJhOWzpdf3fBETrr8BmqS2a/1e3pGdBVt6/WmZ+7Q5/536e0fW9vBSMFAAAAAGByI/mIqsdmM5NXU32NXnnsXP3Dy4/UuUfOVEt9dqWJfX2D+uad63Tm5+7Q3//kMT0bWvcTAAAAAACUBslHVL3wtGuSj5NTc32NXrZiju796Hn68PmHq7Uhm4QcGHL99Peb9bIv3aV3ffdhPbBul9y9gtECAAAAADB5sOEMql7OyMfZ7HQ9mU1tqtWHzz9C7z5rqX740CZ95571en5Pdtr1HU+/oDuefkEnHDZV7zl7mV5x7BwlE1bBiAEAAAAAmNhIPqKqdXT1a+f+vkyZkY+T2/UPbsocN9XV6PJzl+vxzZ26e/VObQut/fjY5j16//V/0LTmOp25fIZOXNiuupoDB4pfvHLhuMQNAAAAAMBERfIRVS086nFqY61mTmGn62qSTJhevLBdL1rQptUv7Nfdq3do7Y6uzPu7u/r1i8e26vantuvM5TO0csl0NdYlR7kiAAAAAAAII/mIqrY6tNP14bNaZMYU22pkZjpi9hQdMXuKtnT26O7VO/TE5j1Kr/zY3T+k21Zt153P7tCpS6frjOUzcjavAQAAAAAAI+O3Z1S1nM1mZjPlGtL8tka95eSFevmKft2zdqce2bBbA0NBGrJvcFh3PrtD967ZqZMXT9M5R87U/LbGCkcMAAAAAEB8sds1qtqa0LTr5bPYbAZZ7c11evXx8/QPLz9Kf3bkTDXUZr9dDg677l+3S+d8/re64qePad2O/aNcCQAAAACA6sXIR1S16LRrIKqlvkYXrJijsw6fqQfX79Y9a3aqq29QUpCE/PEjm/WT32/Wq46bq8vPXaZj5k2tcMQAAAAAAMQHyUdUrT09A9q+N7TTNdOuMYqG2qTOOWKmTl82XY9s7NDdz+5QZ8+AJMlduuXx53XL48/rjOXT9bZTF+v8o2epJsngcgAAAABAdSP5iKoVnnI9pb5Gc1obKhgNJoraZEKnLZ2uUxZPU2NdUl/73RqtC+2Qfe+aXbp3zS7Nm9qgt566SG8+eYFmtLCLOgAAAACgOjEsB1Vr9fbslOvls9npGsVJJkxvPOkw/foj5+jrbz1Rx85vzXl/655e/fuvntHpn7lDH7j+D/rNU9s1MDRcoWgBAAAAAKgMRj6iaq0OjXxkvUccqmTC9Mrj5uoVx87RIxs7dO39G/W/TzyvweFgh+z+oWHd/Pjzuvnx59XeVKsLj5+r175ovk5a1E7CGwAAAAAw6ZF8RNXKTT6y0zWKd/2Dmw6oO23pdB0zr1UPb9ith9bv1r7ewcx7Hd0D+t4Dm/S9BzapralWK+a2asXcVi2a3qxkwnTxyoXjGT4AAAAAAGVH8hFVa01k2jVQKq0NtXrpUbN17hGz9My2fXp0c6eefn5vZjSkJHV2D+i+tbt039pdaqxN6ui5UzStuU5nHj5DLfV8awYAAAAATA78houqtK93QFv39GbKTLtGOSQTphXzWrViXqt6B4b0p6179OhznVq3o0seatczMKQ/bOrUe7/3e9UkTCctatc5R87U2YfP1Iq5rUokmJ4NAAAAAJiYSD6iKq0N7U7cXJfU/LbGCkaDatBQm9RJi6bppEXTtLd3QE89v1dPPb9Xa3d0aSg0InJw2PXg+t16cP1uff6Xz2hGS73OWD5dpy4NXounN7FWJAAAAABgwiD5iKqUs9P1LHa6xvhqbajVyiXTtXLJdPUODOnZ7fu06vm92rirW3t6BnLa7tzfpxsf3aobH90qSZrT2qBTl07TKUum6+TF7Vo2s4WRkQAAAACA2CL5iKq0JrTZzHI2m0EFNdQmdfxhbTr+sDa9+eQFemxzp+56dofufHaHHnuuU8Oe237b3l7d8OhW3ZBKRrY11eqkhe16yeJpesnidh03f6oaapMV+EgAAAAAADgQyUdUpWdDIx8PZ7MZxMSPHn5OkjRrSoMuOmmBLjxurtbu6NL6nfu1bkeXXtjXd8A5nd0D+s3TL+g3T78gKVhn8rC2Ri2a3qRLT1uskxa1q725blw/DgAAAAAA0kg+jsDMjpF0vKR5koYkbZH0iLuvH+c4EpJOl7RM0lxJe1Kx3O3uHeMZy2Syp2dAj2zMfvrYbAZx1VRXo+PmT9Vx86dKkvb3DWr9zi6t27FfG3d1a/veXkUGRmpo2LVxd7c27u7WXat3SgqWFnjJomB05MmL27VwGutGAgAAAADGB8nHEDN7o6SPK0g8jvT+fZI+5u6/K3McNZI+KulyBQnQqH4zu0nS37v7hnLGMhn91z3rta93UJI0pb5GJy+ZVuGIgMK01OcmI3sHhrRpd7c27urShl3d2tzRrYGhaDoyWGZgzQv79cPUyMoZLfU6eXG7Tl06XSuXTtMRs6awbiQAAAAAoCxIPkoys6SkqyW94yBNT5f0GzP7tLt/vEyxzJZ0s6SXjNKsTtIbJF1gZpe6+43liGUy2tM9oP+6JzuA9V1nLlFrQ20FIwIOXUNtUkfMnqIjZgfrlg4Nu7Z29gQjH3d1afvePu3cf+BU7Z37+/S/T27T/z65TZLU3lSrkxdP08ql0/XihW1aMbeVdSMBAAAAACVB8jHwJeUmHrslfV/SowoSfSsVJPtqJSUk/ZOZ7Xb3L5UyCDNrlHSjchOPWyR9T9JaSdMlvVLS2an3WiX90MzOc/f7SxnLZPWde9ZpX19q1GNDjd515pIKRwSUTjJhWjCtSQumNenM5TP0l6cs0MZd3XpkY4ce2bBbj2zsyNlsKa2je0C3rdqu21ZtlyTVJk0r5rbqRQva9KKFbTpi9hQtndGixjoSkgAAAACA4lR98tHMLpT0wVDVKkmvcPfnIu1OkHSrstOgv2Bmt7v7EyUM518UJDrTfirpEncPD136rJldLOm7CpKhDZJ+ZGZHuHtvCWOZdDq7+/Vf927IlN995lJNbWTUIyavHzyU/TaW3lG7u29QG3d3a/3OLq3f2aWtnT0HrBs5MOR6bPMePbZ5j/77/o2Z+vltjVo2q0XLZjZrQXuT5kxt0OzWes2a0qBZrfWqryE5CQAAAADIVdXJx9SGLp8OVXVLenU08ShJ7v6YmV0k6W4Fox/T5766RLEcJukDoarHJV3s7gMjxHK9mS2U9JlU1QJJ75f0xVLEMlldffd67U+NemxtqNE7z1xc2YCACmiqr9HRc1t19NxWScG6kRt2dWn9ji5t2t2tLZ09Ghw+cN1ISdrS2aMtnT2669kdI74/tbFWbU21amusVWtjbaY8tbFWbY11mtpYq6mpcmtDraY01KilvkYtDTWqTSbK9jEDAAAAACqnqpOPkl6q3M1l/sPd1+Vr7O73mdlPJL05VfXnZrbc3deUIJb3KRjFmHbFSInHkC8oSFbOT5U/LJKPee3u6tc192bXenzP2UtZ6xFQsG7kUXNaddScIBk5NOzatrdXz+0ONrDZ2tmrnfv78iYkw/b0DGhPz4A2HrTlgeprEjnJyJb6GrXU16qlPqnGuho11ibVUJtQY21SjXVJNdQmU3VJNdYlcsuhNk11SRKbAAAAAFBB1Z58fF2kfHUB53xb2eSjJL1WQSKwlLFslHTbaI3dfdDMrpH0T6mqw8zsJe7+SAlimXS+ffc6dfUPSZLammr19tMXVzYgIKaSCdP8tkbNb2tUsMysNOyuzu4B7djXqx37+rRjf5/29Axob8+g9vYOqDv1f2ss+gaH1be/Xzv394/5WlEt9TVqa6rVtOY6tTXVaVpTbfBvc53aQ8dtTbVqb6pTe1Md61sCAAAAQIlUe/LxwtDxWndfW8A5d0vqVXaU4p9rjMlHM1si6ehQ1e3ufvBhRtKvlU0+pmMh+Rixa3+f/vu+DZnyX5+1VFMY9QgULGGmac1Bgu7IOQe+Pzg0rL29g+rqG1TPwJB6+ofUMzCk7v4h9abK3Zn6QfX0DwXJxsHhcYl/f9+g9vcNanNHT8Hn1Nck1N4USkg2B0nKtsagPKUhO0IzGLFZq+b6pKak/q1htCUAAAAASKri5KOZtUlaGKp6oJDz3L3fzH4v6YxU1fGjtS/QCZFyQbFIekjSoLJfx1LEMul86+51mZFZ7Yx6BEquJpnIJCeLMeyu/lQSsncglZAcGFJv6t++wWH1Dg5pcMjVPzSsgcFhDQwNa2DIU/8Gx/3p48FhDQy7BgaHD9hEp1h9g8PatrdX2/Ye2j5ejbXJ0PTx7HTyKfU1aq6PJi6Duik5U86D48bapMxsjB8NAAAAAFRO1SYflTvSUJKKWbdxrbLJx3Yzm+Pu28Y7FnfvNbOtyiZRV4whhklp5/4+XXtfdgW6y85Zppb6au72QHwkzNSQWqexlDvPu7uGhl0DQ66+wWAEZnf/kLr6B4PjvsFMuSdc3z+k/hKNxuwZCEZ/7tjXN6brmEkNNak1LGsSaqhLZsupNTDrU+tc1tUkVJdMqCZhqq1JqDaZUF3SVJsMjmtrEqpNWOY4/F5N0lSXbpdMqK4meC+ZMNUkgn/Tr5rIMclRAAAAAKOp5izM0kh5UxHnRtsulTSW5ONYY0knH6PXqXr/fd8G9QwEox6nN9fp0tMWVTgiAOVmZqpJmmqSUmNdUm1NhZ87ODSs7oF0wnJQ3X2paeOhBGX3wFB2ZGZ6xObgkAaGxjre8kDu2URmXCVMqkkklEho1ERlupxIJSvTSct06tIseAV1lq1TtkG4bfpcS9UnzGQ28r+JTNmUDMWZjq8maTmJ1pqc+BOh9w+sz9aFzs1Xn/r40x9rIhV7ULYRPx5JSiRC7RR8PEodj3id1OfNFfQhueRyuafrPDM6OL3ISxBbcH7ClBNrwkzJTNwkmwEAAFCcak4+tkbKu4s4tyNSnhKTWGrNrN7dxzbUZhJ537nBSMdv3rVOl52zVE111dzlARxMTTKh1mRCrYewLuzQcDCNvHcwO4U8naAM6odHTFr2DQxnjnsHgn8L2Fw8NoZd6h8aloYkaXzW8UTlRJORmeRuKlmZTvQmEqFjMyUSoeNQvYUSnenkZ/5EaKocOk6k2tsIx+lzw23zGW2lbR9lIYd0Qjd7jVSS17PnZpO+oevlSQRH63ISx56NJXzfzNcm8nUK6iynPNJ7kX9SbQ7848BIbdIJ85G+7tGvRbiPZL+GwR8E0sfpr1n485v5nB1Qp4LaKafdoV8n53KpymJiiCb8M21D10/H5zl12bY550ZiyNc/sudkzz9Y27xx5dwrG0e0v+Zcq5i4Ivca6WsY/qNUug+m6zXCH67Sf7TJHIf7f+QPWdnjkesVvVbe61vkGtk/FkWvW6himmc/99HvLz7C1+DAfhFtE/6elb5ubmy53y+in6fwe/m+N1n4ZEW/BnnOidw/G0/+9tn7HRjz6PfIvn/A98O83y8Pcq0ivqj5mmb/BxTavrzXH7ltnmsUHUuR1y/yOvlOGKm2mM/XzCn1umDF7Hx3rSrVnIlpiZSLWdgrumtB9FqVjqWg5KOZ/SnPW0etXbtWxxxzTBFhxNuwS1f9QPpycT/nx2xPz8D43hDApBBNQmR+gfARfjlIt0m/n7lG6JfFTL3nlLPHB/7CO+aFMwEAAIAq1liX0GHtRUzDirG1a9dK0oJDPb+ak48NkXJ/EedGk3uNkygWSRru6+vrWrVq1XMluFY1W5b6t5Bd1FFd6BvIh76BfOgbyIe+gXzoG8iHvoF86BslNCBp1fOVjqJkFkjqPtSTqzn5GB1dWMw2rfWRcnT0YSliKXT04yHH4u6TZ2hjDKVHlvJ5RhR9A/nQN5APfQP50DeQD30D+dA3kA99A+WSqHQAFbQ/Uo6OPhxNdHRh9FoTORYAAAAAAACgJKo5+bg3Um4v4ty2SHnf2EIpWSwDbDYDAAAAAACAuKjm5OP6SHlhEecuipTXxSSWscYBAAAAAAAAlEw1Jx9XRcrLizh3Wei4w923VSIWM2uQNG+U6wAAAAAAAAAVU7XJR3fvlLQpVHVaIeeZWZ2kk0JVT5QgnMci5YJikXSKcjcNKkUsAAAAAAAAQElU827XknSrpPemjpeZ2VJ3P9jU5bOUuyHMzWMNwt3Xm9nTko5KVZ1vZubufpBTL4iUxxwLSocdwpAPfQP50DeQD30D+dA3kA99A/nQN5APfQPlUrUjH1N+Hin/dQHnRNvcUJpQcmJZJOllozU2sxpJ7wxVbZH0SIliAQAAAAAAAMas2pOPt0t6MlT+oJktydfYzE6TdFGo6hZ3X52n7WIz89DrdweJ5euSwjtVf97Makdp//eS5ofKVxUwUhIAAAAAAAAYN1WdfHT3YUlXhqqaJd1kZguibc3seEk/UfZzNizpYyWM5TlJXw1VHS/p+2ZWP0Isfynpk6GqLZK+UqpYAAAAAAAAgFIwBstJZvZVSZeHqrokfV/So5JqJZ0q6Y2p47R/cPcvjHLNxZLWh6rudPdzDxJHk6Q7Jb0kVL1F0nWS1klql/QqSeeE3u+TdL673zPatQEAAAAAAIDxRvJRkpklJV0j6W0FNHdJn3X3K0drdCjJx9R5cyTdIunEAmLZJ+nt7h5duxIAAAAAAACouKqedp3m7kPufqmkNyt3DcioBxSMMhw18TjGWLYpGGn5CUnb8jTrV7BBzQkkHgEAAAAAABBXjHwcgZkdq2DNxXmShiRtlfSwu68b5ziSkk6XtFzSbAUjHTdLutvdd49nLAAAAAAAAECxSD4CIWZ2jHITz1skPeLu60c9sfRxJBQknpdJmitpTyqWu929YzxjQaDSfcPM6iQdLWmFpDmSmiTtlbQ9Fce4/nEEWZXuG4ivuPUNM2tV8LNlnqRZkvZLeiEV16Pu3lWJuKpRXPqGmS1TsNTPXElTJPVI2iXpcUlPuPvgeMaD+OBZFFE8iwIYC5KPgCQze6Okjyv4RWAk90n6mLv/rsxx1Ej6qIINkOaN0KRf0k2S/t7dN5QzFgQq2TfMbL6Cza5eJelMBQ95+ayR9DVJX3P3vlLHggPF5ftGPmb2fklfiVR/0t3/uQLhVJW49Q0zO0vBz5YLJNXlaTakYHmZj7n7neMRVzWKQ99Izax5n6T3SzpqlKY7Jf23pE8z46Z8Ukm+oxVsOJl+nSCpMdTsz8bx+wXPojERh77Bs2g8xaFvFIJnUYSRfERVSz2AXy3pHQU0H1bwAP7xMsUyW9LNyt3tPJ+9ki519xvLEQsq3zfM7GWSfinJijz1T5Le5O6rShULclW6bxTCzA6TtErBSKYwHvjKKG59w8yaFDz0v0OFfy/5B3f/QrliqlZx6RtmNkvBxoaFPGukvSDpDe5+T6njqXZm9j+SXi6p+SBNxyWJwLNofMShb/AsGk9x6BuF4FkUUTWVDgCosC8p9xeBbknfl/SogtEhKyW9QVKtgg2a/snMdrv7l0oZhJk1SrpRuQ97WyR9T9JaSdMlvVLS2an3WiX90MzOc/f7SxkLMirdN5qU+7A3LOkxSXdL2iipQ1K7gg2q/kLZ0UzHSLrDzM509zUligW5Kt03CvF1Hfiwh/KLTd8ws2YFSaZzQtU9kn6jYITjdklJBVPnXiTpPAU/W1AeFe8bqSmTv1buqMs+Sb9Q0Cd2S2qRdJyCkU7TUm1mSfpfM1tJMqHkTtLBEwjjgmfR2IlD3+BZNJ7i0DcKwbMocrk7L15V+ZJ0oSQPvf4kacEI7U5Q8PCVbjck6bgSx/LvkVh+Iql+hHYXK5jukm63SVJDpT+Xk+0Vh74h6bWpa65TMP1p7ihtFyqYqheO+a5Kfx4n4ysOfaOAGN8Suu+qSLz/XOnP4WR9xa1vSLo1Es+1kmaN0r5W0uskvaLSn8vJ9opL35B0RSSORyUtydN2iqQfRdr/utKfy8n2krQh9PntlfSQgl/Yr4t87s8dh1h4Fo3RKw59QzyLxvIVh75RQIw8i/I64JUQUIVS62R8OlTVLenV7v5ctK27PybpIgV/7ZOCEQmfjrYbQyyHSfpAqOpxSRf7CGuluPv1kj4RqlqgYM0mlEiM+sYLki6TdKS7f87dn8/X0N03KZh+8Uyo+iwzOzvPKTgEMeobo8U4XdKXU8VeSR8q9z0Rv75hZn+lYIRS2ufd/VJ3fyHfOe4+4O4/d/dfljKWahezvvH20HFPKo71IzV0932S3qrgmSTtpWY20hqAOHTXSnqPgpFMU9z9FHd/n4IRyuOGZ9FYikPf4Fk0nuLQN/LiWRT5kHxEtXqpcqcd/YePskObu9+n4C/AaX9uZstLFMv7JDWEyle4+8Ao7b+gYGRE2odLFAcCsegb7n6fu3/rIH0h3H6fpE9Gqv98rHEgRyz6xkF8ScEUSUn6lILF31F+sekbZjZFwc+JtAck/Z9SXBuHJBZ9w8waFOxQm3bzSAnQSCyDkr4dvozyb5SDQ+Dun3D3b7v7Hwr9eV8mPIvGTBz6Bs+i8RSHvnEQPItiRCQfUa1eFylfXcA5346UX1uaUHJi2SjpttEap34ZuCZUdZiZFbNwPEYXp75RrNsj5WUViWLyinXfSC0M/7ZUcZWkz5frXjhAnPrGJZLaQuUr3H04T1uUX1z6xvRIudBfBldHytNGbIWJjmdRlArPolWMZ1GMhuQjqtWFoeO17r62gHPuVjB0PG3Mf8kzsyWSjg5V3e4eLJRxEL+OlPmrYunEom8cov2R8kRYjHoiiW3fSG0u8s1U0SVdFtO/hk9Wceob7wkdP+Pud5foujg0cekbnQq+N6QV+vOhJVLOO3UfExPPoigxnkWrFM+iOBiSj6g6ZtamYFHktAcKOc/d+yX9PlRViqlHJ0TKBcWiYGHhwRLHUvVi1jcOxZJIeVtFopiEJkDf+JSkxanjq939njLdBxFx6htmNkPBztVpt471mjh0ceob7t6lYJfatPMKPPWloeP0xgaYXHgWRSnxLFq9eBbFqEg+ohodHSkXsw5FeMRCu5nNqUQs7t4raWuoakW+tihKnPrGoXh9pHx/BWKYrGLbN8zsFGUX896uYEdKjJ849Y1TIuX7pWDxdzP7iJndY2bPm1lf6t/7zOxTZnb4GO+LkcWpb0jSf4aOjzWzUTcJMbOTJb0rVPUtd99bgjgQLzyLopR4Fq1CPIuiECQfUY2WRsqbijg32jZ6rUrFMtY4EIhT3yiKmbVIujxU1S/pxvGMYZKLZd8ws1pJ31H25/lH3L2jVNdHQeLUN14cKT9tZm+Q9LSk/yfpDElzJNWl/j1N0sckPWVmXzOz+jHeH7ni1DekYI2+8M+F/0x93Y8KNzKzOWZ2haTfSkr3iYckXVmCGBA/PIuiJHgWrU48i6JQJB9RjVoj5d1FnBv9RjolJrHU8ktjScSpbxTri5LmhsrfcHemupROXPvGP0o6NnV8m7v/oITXRmHi1DdmRsrnKtg5eUaq7JJ2SHpe0lCoXVLBbre/MbPGMcaArDj1DaXW8XuTpKsUTJc1BV/3p8xsj5mtN7N0//icgrXaBiR9XdJLU1O3MfnwLIpS4Vm0OvEsioKQfEQ1ii6e3jtiq5H1HORaEzkWTNCvh5ldqtxNJjZJ+vh43b9KxK5vmNnRCkatpe/xvlJcF0WLU99oi5S/qCDB1CfpnyXNd/dZ7j5Pwe7Hlys30XCGgkQTSiNOfUNSsJ6ku39EwS+Kd4bealWwVteMUN0mSa9198vdPbqJBCaP2PVTTDw8i1YnnkVRDJKPqEYNkXJ/Eef2RcpjHSESp1gwAb8eZnaOpG+HqgYkvYV1uUouVn3DzEzB1z09yuRf3H3dWK+LQxKnvhH9xb9WwfeEV7n7J939+fQb7r7H3b8u6UxJu0LnvD211h/GLk59Q5JkZgkz+4ikuySdc5DmCyXdYma/NjOm1E5eseunmFh4Fq1OPIuiWCQfUY2if9GtK+Lc6HSS6F98J3IsmGBfDzM7SdIvlI3TJb3T3Vncu/Ti1jcuVzBKTZKeUDDCDZURp74x0oilL7r7HflOcPenJP1tpPrDY4wDgTj1DZlZg6SbFaz/OStVfbuk1yqYKlknqV1BUvLbyk7NP1/SI2Z24lhjQCzFqp9iYuFZtKrxLIqikHxENYpOHYr+xXc00b/ojnUaUpxiwQT6epjZcZJ+pdy1mi539++X875VLDZ9w8wWSPpMquiSLnP3gbFcE2MSm74haV+k7JL+o4DzrlewO2Xa+WOMA4E49Q1J+rKkV4bKV7r7Be5+o7tvc/cBd+9097vc/T2SXqZsYqpd0s9SG0pgcolbP8UEwbNo9eJZFIeC5COqUXQKQHsR57ZFytFf9IpVqlgG3D069QXFi1PfyCu1M+ntCtZsS/uwu3+jXPdErPrG15XdfOIbjC6ouDj1jWgsT4enWufj7oOS7glVzTKzw8YYC2LUN1Lrcv11qOoX7v6ZfO0lKTVi9mOhqkWSLhtLHIglnkVRNJ5Fqx7PoigayUdUo/WR8sIizl0UKY91XYtSxcL6GqURp74xIjM7XNIdyk6Zk6R/dPcvl+N+yIhF3zCz10i6MFXcJun/HOq1UDKx6BspayPlTUWcuzFSju6cjeLFqW+8RcHmQ2lfKfC8byp3DcDXjzEOxA/PoigKz6LVjWdRHKqaSgcAVMCqSHl5EecuCx13uPu2MsRy50gNw1LrNs0b5To4NHHqGwdILfh/h4K1udI+4e6fK/W9cIC49I3wpg9Nkn4frPedV/Tn/IfM7JJQ+VPu/t0xxIP49A1J+lOkXMyutdG2xUy9xMji1DeOj5QfKeQkd+8ys6dD5x8zxjgQPzyLomA8i0I8i+IQkXxE1XH3TjPbpOxfdk8r5Dwzq5N0UqjqiRKE81ikfJqk7xRw3inK/f9biliqXsz6RvQeiyT9VlJ4KuSn3P1fS30vHCimfaNVuessFaJduVPq2koWTZWKWd94UsEmIclUeVoR50bb7hqxFQoWs77RHCkXszZfV+iY3YwnH55FURCeRTECnkVRMKZdo1rdGjpelvor3sGcpdyRIDePNQh3Xy/p6VDV+XaQPx2lXBApjzkWZMSib4Sl1l67Q7lToT7n7h8v5X1wULHrG4iNWPQNd9+j3BFLx5tZoc96Lw4dD0jaPNZ4ICkmfUNSR6Q8p4hzwyOcSEpPMjyLohA8iwIYK5KPqFY/j5T/esRWo7e5oTSh5MSySMHuknmZWY2kd4aqtqjA6VMoSJz6hsxsroKHvfAvrP/P3f+xVPdAwSreN9z9Kne3Ql+SlkQu8clIm6vGEg8yKt43Qn4aOp6qg/xMkSQzWyLp5FDVA+7eXaJ4ql1c+saaSDmaOBpRam23xaGqZ0sQC+KHZ1HkxbMowngWxaEi+YhqdbuC6WlpH0z98jUiMztN0kWhqlvcfXWetovNzEOv3x0klq9LCu8O+Hkzqx2l/d9Lmh8qX+XufpB7oHCx6RtmNjMVz+Gh6v9w97872AeBsohN30DsxKlvXCdpe6j82dQ03tF8UbnPhP99kPYoXFz6xi8j5SvNbMqILXNF13H7VQHnoMJ4FkU+PIsiH55FUW4kH1GV3H1Y0pWhqmZJN5nZgmhbMzte0k+U/f8yLOljJYzlOUlfDVUdL+n7ZlY/Qix/KemToaotKnzHShQgLn3DzNol/VrSilD119z9b0pxfRQvLn0D8ROnvuHu+yX931DVCZJ+lvqeEo2l3sy+Kul1oepnJV1bqniqXVz6hrvfLenhUNUySbemplIewMyazOxq5faNvZK+XYp4EC88i2IkPIsCKCU2nEHVcvebzOxrki5PVR0j6Skz+76kRyXVSjpV0htTx2kfdffo4txj9XFJZ0t6Sap8kaTTzew6SesULMr7KknnhM7pk/QWdy9mN1MUICZ94wMKkgZhrzCz6NS50Wx293NLFA8Um76BGIpZ3/iWgp8Xf5kqXyhpjZn9WNLjkgYVjGJ5k4Iplmn7Jb3B3QdKHE9Vi1HfuEzSXZJaUuUzFfSLX0h6UMF6js0KEk9vkDQ9cv7fuPvOEsZT9czs9ZI+P8Jb0VGp3zeznhHaXeHuPytRODyLxkhM+gbPojEUk74BFI3kI6rdhxR8o35bqtws6T152rqkz7r7F0odhLt3m9mrJd0i6cRU9XxJ+dZS2Sfp7e5+T6ljQUal+0ZyhLpCNioI43t8eVS6byC+YtE33N3N7B0KRtC9OVU9TdJ7Rzlti6TXufuTo7TBoat433D3P5rZhZJ+qOwmMvUKkkwX5T1R6pX0EXf/binjgaRgl9hlBbSbN8r5JcGzaOzEoW/wLBpPcegbQNGYdo2q5u5D7n6pgl/ORvuF6wFJ57v7laO0GWss2xSMfPiEpG15mvUrWBT8BHePLmKPEopT30C80DeQT5z6hrv3u/tbFIxufHSUpnsUjKA4wd0fHqUdxiAufcPd75J0rKR/U/5njbRuSddIerG7f6Mc8SBeeBYFAJSLsTYwkGVmxyqYbjRP0pCkrZIedvd14xxHUtLpkpZLmq3gr8ubJd3t7rvHMxYE4tI3ED/0DeQTp75hZkdIenEqljoFU2xXSXrI3QfHO55qF4e+YWYm6WhJL5I0U8HIzB5JuxX0jUfdvS/vBTCp8SwKACglko8AAAAAAAAAyoJp1wAAAAAAAADKguQjAAAAAAAAgLIg+QgAAAAAAACgLEg+AgAAAAAAACgLko8AAAAAAAAAyoLkIwAAAAAAAICyIPkIAAAAAAAAoCxIPgIAAAAAAAAoC5KPAAAAAAAAAMqC5CMAAAAAAACAsiD5CAAAAAAAAKAsSD4CAAAAAAAAKAuSjwAAAAAAAADKguQjAAAAAAAAgLIg+QgAAAAAAACMwMwSZnaMmb3dzP7TzO43s24z89Dr3ErHmWZmGyKxHcrrd6WMqaaUFwMAAAAAAAAmAzP7H0kvl9Rc6VjGWWcpL0byEQAAAAAAADjQSZp4iccNkgaLPGeepMZQ+Qcli0YkHwEAAAAAAICD6ZP0uKTfS2qRdEllwxmZu59bTHszq5e0Rdnk4y5JN5QyJpKPAAAAAAAAwIGulfScgoTjE+4+IElm9g7FNPl4CF4raXqofJ2795XyBiQfAQAAAAAAgAh3/8R43cvMTNKJklZImiXJJG2X9Ad3/1MZb/3uSPk7pb4ByUcAAAAAAACgAsxsiqSPKkgCzs7TZrWk/+vuJV2L0cwWS3ppqOpBd3+ylPeQpESpLwgAAAAAAABgdGZ2qqTVkj6mPInHlMMlXW9mPzaz2hKG8C4FIyzTri7htTMY+QgAAAAAAACMIzP7M0k3S2oKVT+TqlurYMfqIyW9SdKC1PsXSXJJby7B/ROS3hGq6pL0o7FedyQkHwEAAAAAAIBxYmazJP1A2cRjr6T3S7rG3T3S9uOSviTpslTVm8zsZne/boxhvEzZpKYk/cjd943xmiNi2jUAAAAAAAAwfj6r7DTrYUmvc/f/iiYeJcnde9z9vZL+J1T9r6mRi2MR3WimLFOuJZKPAAAAAAAAwLgwszmS3hqqutrdf1nAqR+SNJA6XiTpVWOIYaak14SqVrn7/Yd6vYMh+QgAAAAAAACMjzdKqguVv1TISe6+VdLtoaoLxhDDpZLCG9d8ZwzXOiiSjwAAAAAAAMD4OCt0vM7dny7i3IdCxyvHEMO7Qsf9kq4dw7UOiuQjAAAAAAAAMD5OCB3/qchzt4eODzuUm5vZaZJWhKpudPedh3KtQrHbNQAAAAAAADA+poeOX21mB2wyU6D2Qzxv3DaaSWPkIwAAAAAAADA+2kp0naZiTzCzFklvClVtVO46kmXByEcAAAAAAABgfHRLak0dd0jaPY73foukllD5GncfLvdNST4CAAAAAAAA42OnssnHn7j7ZeN4778KHQ9LumY8bsq0awAAAAAAAGB8hHe3Pma8bmpmx0g6NVR1m7tvGo97k3wEAAAAAAAAxsdvQ8enmtmMcbrvX0XK3xmn+5J8BAAAAAAAAMbJTyUNpo6Tkv6h3Dc0szpJbwtV7ZB0Y7nvm0byEQAAAAAAABgH7r5B0g9CVX9rZi8r5hoWqCvilL+QFB5hea27DxRzz7Eg+QgAAAAAAACMnyskPZ86rpF0k5n9nZk1jHaSmc01sw8qWDfyxCLuV7Ep15Jk7j6e9wMAAAAAAABiz8xeL+nzI7w1RdKsUHmrpJ4R2l3h7j/Lc+3TJP1S2Z2vpWAn7F9JelTSbgXTstskHaEg2fhiSZZqe5q7P1DAx7BQ0nplByDe5+5nHOy8UqoZz5sBAAAAAAAAE0SrpGUFtJs3yvkjcvf7zexUSTcoSC5KwdTot6ZeBzNUQBtJeqdyZz5fXeB5JcO0awAAAAAAAGCcuftTko6V9F5Jqwo4ZZWkL0p6sbs/fLDGZmYKko9p+yT9+BBCHROmXQMAAAAAAAAVZmbzJZ0qabakdkn9kjokrZX0pLvvqGB4h4zkIwAAAAAAAICyYNo1AAAAAAAAgLIg+QgAAAAAAACgLEg+AgAAAAAAACgLko8AAAAAAAAAyoLkIwAAAAAAAICyIPkIAAAAAAAAoCxIPgIAAAAAAAAoC5KPAAAAAAAAAMqC5CMAAAAAAACAsiD5CAAAAAAAAKAsSD4CAAAAAAAAKAuSjwAAAAAAAADKguQjAAAAAAAAgLIg+QgAAAAAAACgLEg+AgAAAAAAACgLko8AAAAAAAAAyoLkIwAAAAAAAICyIPkIAAAAAAAAoCxIPgIAAAAAAAAoC5KPAAAAAAAAAMri/wMZEmFf1HC9MAAAAABJRU5ErkJggg==\n",
217
+ "text/plain": [
218
+ "<Figure size 1500x750 with 1 Axes>"
219
+ ]
220
+ },
221
+ "metadata": {
222
+ "needs_background": "light"
223
+ },
224
+ "output_type": "display_data"
225
+ }
226
+ ],
227
+ "source": [
228
+ "gene_detection_counts = [i for i in gene_detection_counts_dict.values()]\n",
229
+ "import seaborn as sns\n",
230
+ "import matplotlib.pyplot as plt\n",
231
+ "plt.figure(figsize=(10,5), dpi=150)\n",
232
+ "plt.rcParams.update({'font.size': 18})\n",
233
+ "count_plot = sns.distplot(gene_detection_counts).set_title(f\"# Cells Expressing Each\\nProtein-Coding or miRNA Gene\")"
234
+ ]
235
+ },
236
+ {
237
+ "cell_type": "code",
238
+ "execution_count": 47,
239
+ "id": "missing-bradley",
240
+ "metadata": {},
241
+ "outputs": [
242
+ {
243
+ "data": {
244
+ "text/plain": [
245
+ "27454"
246
+ ]
247
+ },
248
+ "execution_count": 47,
249
+ "metadata": {},
250
+ "output_type": "execute_result"
251
+ }
252
+ ],
253
+ "source": [
254
+ "len(gene_detection_counts)"
255
+ ]
256
+ },
257
+ {
258
+ "cell_type": "code",
259
+ "execution_count": 55,
260
+ "id": "perfect-signal",
261
+ "metadata": {},
262
+ "outputs": [
263
+ {
264
+ "data": {
265
+ "text/plain": [
266
+ "25424"
267
+ ]
268
+ },
269
+ "execution_count": 55,
270
+ "metadata": {},
271
+ "output_type": "execute_result"
272
+ }
273
+ ],
274
+ "source": [
275
+ "len([i for i in gene_detection_counts if i > 0])"
276
+ ]
277
+ },
278
+ {
279
+ "cell_type": "code",
280
+ "execution_count": 56,
281
+ "id": "faced-theory",
282
+ "metadata": {},
283
+ "outputs": [
284
+ {
285
+ "data": {
286
+ "text/plain": [
287
+ "22735"
288
+ ]
289
+ },
290
+ "execution_count": 56,
291
+ "metadata": {},
292
+ "output_type": "execute_result"
293
+ }
294
+ ],
295
+ "source": [
296
+ "len([i for i in gene_detection_counts if i > 100])"
297
+ ]
298
+ },
299
+ {
300
+ "cell_type": "code",
301
+ "execution_count": 57,
302
+ "id": "tough-workplace",
303
+ "metadata": {},
304
+ "outputs": [
305
+ {
306
+ "data": {
307
+ "text/plain": [
308
+ "21167"
309
+ ]
310
+ },
311
+ "execution_count": 57,
312
+ "metadata": {},
313
+ "output_type": "execute_result"
314
+ }
315
+ ],
316
+ "source": [
317
+ "len([i for i in gene_detection_counts if i > 1000])"
318
+ ]
319
+ },
320
+ {
321
+ "cell_type": "code",
322
+ "execution_count": 49,
323
+ "id": "cooperative-camcorder",
324
+ "metadata": {},
325
+ "outputs": [
326
+ {
327
+ "data": {
328
+ "text/plain": [
329
+ "173152.0299000284"
330
+ ]
331
+ },
332
+ "execution_count": 49,
333
+ "metadata": {},
334
+ "output_type": "execute_result"
335
+ }
336
+ ],
337
+ "source": [
338
+ "gene_detection_event_digest = crick.tdigest.TDigest()\n",
339
+ "gene_detection_event_digest.update(gene_detection_counts)\n",
340
+ "gene_detection_event_digest.quantile(0.5)"
341
+ ]
342
+ }
343
+ ],
344
+ "metadata": {
345
+ "kernelspec": {
346
+ "display_name": "Python 3 (ipykernel)",
347
+ "language": "python",
348
+ "name": "python3"
349
+ },
350
+ "language_info": {
351
+ "codemirror_mode": {
352
+ "name": "ipython",
353
+ "version": 3
354
+ },
355
+ "file_extension": ".py",
356
+ "mimetype": "text/x-python",
357
+ "name": "python",
358
+ "nbconvert_exporter": "python",
359
+ "pygments_lexer": "ipython3",
360
+ "version": "3.10.11"
361
+ }
362
+ },
363
+ "nbformat": 4,
364
+ "nbformat_minor": 5
365
+ }
examples/pretraining_new_model/pretrain_geneformer_w_deepspeed.py ADDED
@@ -0,0 +1,167 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #!/usr/bin/env python
2
+ # coding: utf-8
3
+
4
+ # run with:
5
+ # deepspeed --num_gpus=12 --num_nodes=3 pretrain_geneformer_w_deepspeed.py --deepspeed ds_config.json
6
+
7
+ import datetime
8
+
9
+ # imports
10
+ import os
11
+
12
+ os.environ["NCCL_DEBUG"] = "INFO"
13
+ os.environ["OMPI_MCA_opal_cuda_support"] = "true"
14
+ os.environ["CONDA_OVERRIDE_GLIBC"] = "2.56"
15
+
16
+ import pickle
17
+ import random
18
+ import subprocess
19
+
20
+ import numpy as np
21
+ import pytz
22
+ import torch
23
+ from datasets import load_from_disk
24
+ from transformers import BertConfig, BertForMaskedLM, TrainingArguments
25
+
26
+ from geneformer import GeneformerPretrainer
27
+
28
+ seed_num = 0
29
+ random.seed(seed_num)
30
+ np.random.seed(seed_num)
31
+ seed_val = 42
32
+ torch.manual_seed(seed_val)
33
+ torch.cuda.manual_seed_all(seed_val)
34
+
35
+ # set local time/directories
36
+ timezone = pytz.timezone("US/Eastern")
37
+ rootdir = "/parent_ouput_directory"
38
+
39
+ # set model parameters
40
+ # model type
41
+ model_type = "bert"
42
+ # max input size
43
+ max_input_size = 2**11 # 2048
44
+ # number of layers
45
+ num_layers = 6
46
+ # number of attention heads
47
+ num_attn_heads = 4
48
+ # number of embedding dimensions
49
+ num_embed_dim = 256
50
+ # intermediate size
51
+ intermed_size = num_embed_dim * 2
52
+ # activation function
53
+ activ_fn = "relu"
54
+ # initializer range, layer norm, dropout
55
+ initializer_range = 0.02
56
+ layer_norm_eps = 1e-12
57
+ attention_probs_dropout_prob = 0.02
58
+ hidden_dropout_prob = 0.02
59
+
60
+
61
+ # set training parameters
62
+ # total number of examples in Genecorpus-30M after QC filtering:
63
+ num_examples = 27_406_208
64
+ # number gpus
65
+ num_gpus = 12
66
+ # batch size for training and eval
67
+ geneformer_batch_size = 12
68
+ # max learning rate
69
+ max_lr = 1e-3
70
+ # learning schedule
71
+ lr_schedule_fn = "linear"
72
+ # warmup steps
73
+ warmup_steps = 10_000
74
+ # number of epochs
75
+ epochs = 3
76
+ # optimizer
77
+ optimizer = "adamw"
78
+ # weight_decay
79
+ weight_decay = 0.001
80
+
81
+
82
+ # output directories
83
+ current_date = datetime.datetime.now(tz=timezone)
84
+ datestamp = f"{str(current_date.year)[-2:]}{current_date.month:02d}{current_date.day:02d}_{current_date.strftime('%X').replace(':','')}"
85
+ run_name = f"{datestamp}_geneformer_30M_L{num_layers}_emb{num_embed_dim}_SL{max_input_size}_E{epochs}_B{geneformer_batch_size}_LR{max_lr}_LS{lr_schedule_fn}_WU{warmup_steps}_O{optimizer}_DS{num_gpus}"
86
+ training_output_dir = f"{rootdir}/models/{run_name}/"
87
+ logging_dir = f"{rootdir}/runs/{run_name}/"
88
+ model_output_dir = os.path.join(training_output_dir, "models/")
89
+
90
+
91
+ # ensure not overwriting previously saved model
92
+ model_output_file = os.path.join(model_output_dir, "pytorch_model.bin")
93
+ if os.path.isfile(model_output_file) is True:
94
+ raise Exception("Model already saved to this directory.")
95
+
96
+
97
+ # make training and model output directories
98
+ subprocess.call(f"mkdir {training_output_dir}", shell=True)
99
+ subprocess.call(f"mkdir {model_output_dir}", shell=True)
100
+
101
+
102
+ # load gene_ensembl_id:token dictionary (e.g. https://huggingface.co/datasets/ctheodoris/Genecorpus-30M/blob/main/token_dictionary.pkl)
103
+ with open("token_dictionary.pkl", "rb") as fp:
104
+ token_dictionary = pickle.load(fp)
105
+
106
+ # model configuration
107
+ config = {
108
+ "hidden_size": num_embed_dim,
109
+ "num_hidden_layers": num_layers,
110
+ "initializer_range": initializer_range,
111
+ "layer_norm_eps": layer_norm_eps,
112
+ "attention_probs_dropout_prob": attention_probs_dropout_prob,
113
+ "hidden_dropout_prob": hidden_dropout_prob,
114
+ "intermediate_size": intermed_size,
115
+ "hidden_act": activ_fn,
116
+ "max_position_embeddings": max_input_size,
117
+ "model_type": model_type,
118
+ "num_attention_heads": num_attn_heads,
119
+ "pad_token_id": token_dictionary.get("<pad>"),
120
+ "vocab_size": len(token_dictionary), # genes+2 for <mask> and <pad> tokens
121
+ }
122
+
123
+ config = BertConfig(**config)
124
+ model = BertForMaskedLM(config)
125
+ model = model.train()
126
+
127
+ # define the training arguments
128
+ training_args = {
129
+ "learning_rate": max_lr,
130
+ "do_train": True,
131
+ "do_eval": False,
132
+ "group_by_length": True,
133
+ "length_column_name": "length",
134
+ "disable_tqdm": False,
135
+ "lr_scheduler_type": lr_schedule_fn,
136
+ "warmup_steps": warmup_steps,
137
+ "weight_decay": weight_decay,
138
+ "per_device_train_batch_size": geneformer_batch_size,
139
+ "num_train_epochs": epochs,
140
+ "save_strategy": "steps",
141
+ "save_steps": np.floor(
142
+ num_examples / geneformer_batch_size / 8
143
+ ), # 8 saves per epoch
144
+ "logging_steps": 1000,
145
+ "output_dir": training_output_dir,
146
+ "logging_dir": logging_dir,
147
+ }
148
+ training_args = TrainingArguments(**training_args)
149
+
150
+ print("Starting training.")
151
+
152
+ # define the trainer
153
+ trainer = GeneformerPretrainer(
154
+ model=model,
155
+ args=training_args,
156
+ # pretraining corpus (e.g. https://huggingface.co/datasets/ctheodoris/Genecorpus-30M/tree/main/genecorpus_30M_2048.dataset)
157
+ train_dataset=load_from_disk("genecorpus_30M_2048.dataset"),
158
+ # file of lengths of each example cell (e.g. https://huggingface.co/datasets/ctheodoris/Genecorpus-30M/blob/main/genecorpus_30M_2048_lengths.pkl)
159
+ example_lengths_file="genecorpus_30M_2048_lengths.pkl",
160
+ token_dictionary=token_dictionary,
161
+ )
162
+
163
+ # train
164
+ trainer.train()
165
+
166
+ # save model
167
+ trainer.save_model(model_output_dir)
examples/tokenizing_scRNAseq_data.ipynb ADDED
@@ -0,0 +1,91 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ {
2
+ "cells": [
3
+ {
4
+ "cell_type": "markdown",
5
+ "id": "a91bca46-c056-4784-8c6c-b0f5d3f33496",
6
+ "metadata": {
7
+ "tags": []
8
+ },
9
+ "source": [
10
+ "## Tokenizing .loom or .h5ad single cell RNA-seq data to rank value encoding .dataset format"
11
+ ]
12
+ },
13
+ {
14
+ "cell_type": "markdown",
15
+ "id": "1fe86f48-5578-47df-b373-58c21ec170ab",
16
+ "metadata": {},
17
+ "source": [
18
+ "#### Input data is a directory with .loom or .h5ad files containing raw counts from single cell RNAseq data, including all genes detected in the transcriptome without feature selection. The input file type is specified by the argument file_format in the tokenize_data function.\n",
19
+ "\n",
20
+ "#### The discussion below references the .loom file format, but the analagous labels are required for .h5ad files, just that they will be column instead of row attributes and vice versa due to the transposed format of the two file types.\n",
21
+ "\n",
22
+ "#### Genes should be labeled with Ensembl IDs (loom row attribute \"ensembl_id\"), which provide a unique identifer for conversion to tokens. Other forms of gene annotations (e.g. gene names) can be converted to Ensembl IDs via Ensembl Biomart. Cells should be labeled with the total read count in the cell (loom column attribute \"n_counts\") to be used for normalization.\n",
23
+ "\n",
24
+ "#### No cell metadata is required, but custom cell attributes may be passed onto the tokenized dataset by providing a dictionary of custom attributes to be added, which is formatted as loom_col_attr_name : desired_dataset_col_attr_name. For example, if the original .loom dataset has column attributes \"cell_type\" and \"organ_major\" and one would like to retain these attributes as labels in the tokenized dataset with the new names \"cell_type\" and \"organ\", respectively, the following custom attribute dictionary should be provided: {\"cell_type\": \"cell_type\", \"organ_major\": \"organ\"}. \n",
25
+ "\n",
26
+ "#### Additionally, if the original .loom file contains a cell column attribute called \"filter_pass\", this column will be used as a binary indicator of whether to include these cells in the tokenized data. All cells with \"1\" in this attribute will be tokenized, whereas the others will be excluded. One may use this column to indicate QC filtering or other criteria for selection for inclusion in the final tokenized dataset.\n",
27
+ "\n",
28
+ "#### If one's data is in other formats besides .loom or .h5ad, one can use the relevant tools (such as Anndata tools) to convert the file to a .loom or .h5ad format prior to running the transcriptome tokenizer."
29
+ ]
30
+ },
31
+ {
32
+ "cell_type": "markdown",
33
+ "id": "32c69493-4e5a-4b07-8dc1-958ff2ee7d0b",
34
+ "metadata": {},
35
+ "source": [
36
+ "**********************************************************************************************************\n",
37
+ "#### OF NOTE: PLEASE ENSURE THE CORRECT TOKEN DICTIONARY AND GENE MEDIAN FILE IS USED FOR THE CORRECT MODEL.\n",
38
+ "#### 95M: current defaults; 30M: https://huggingface.co/ctheodoris/Geneformer/tree/main/geneformer/gene_dictionaries_30m\n",
39
+ "\n",
40
+ "#### ADDITIONALLY:\n",
41
+ "#### The 95M model series require the special_token argument to be set to True and model_input_size to be 4096. (current defaults)\n",
42
+ "#### The 30M model series require the special_token argument to be set to False and the model_input_size to be 2048."
43
+ ]
44
+ },
45
+ {
46
+ "cell_type": "code",
47
+ "execution_count": null,
48
+ "id": "080fdd9c-0c48-4d5d-a254-52b6c53cdf78",
49
+ "metadata": {},
50
+ "outputs": [],
51
+ "source": [
52
+ "from geneformer import TranscriptomeTokenizer"
53
+ ]
54
+ },
55
+ {
56
+ "cell_type": "code",
57
+ "execution_count": null,
58
+ "id": "37205758-aa52-4443-a383-0638519ee8a9",
59
+ "metadata": {},
60
+ "outputs": [],
61
+ "source": [
62
+ "tk = TranscriptomeTokenizer({\"cell_type\": \"cell_type\", \"organ_major\": \"organ\"}, nproc=16)\n",
63
+ "tk.tokenize_data(\"loom_data_directory\", \n",
64
+ " \"output_directory\", \n",
65
+ " \"output_prefix\", \n",
66
+ " file_format=\"loom\")"
67
+ ]
68
+ }
69
+ ],
70
+ "metadata": {
71
+ "kernelspec": {
72
+ "display_name": "Python 3 (ipykernel)",
73
+ "language": "python",
74
+ "name": "python3"
75
+ },
76
+ "language_info": {
77
+ "codemirror_mode": {
78
+ "name": "ipython",
79
+ "version": 3
80
+ },
81
+ "file_extension": ".py",
82
+ "mimetype": "text/x-python",
83
+ "name": "python",
84
+ "nbconvert_exporter": "python",
85
+ "pygments_lexer": "ipython3",
86
+ "version": "3.10.15"
87
+ }
88
+ },
89
+ "nbformat": 4,
90
+ "nbformat_minor": 5
91
+ }
fine_tuned_models/gf-12L-95M-i4096_MTLCellClassifier_CELLxGENE_240522/config.json ADDED
@@ -0,0 +1,24 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ {
2
+ "architectures": [
3
+ "BertForMaskedLM"
4
+ ],
5
+ "attention_probs_dropout_prob": 0.02,
6
+ "classifier_dropout": null,
7
+ "hidden_act": "relu",
8
+ "hidden_dropout_prob": 0.02,
9
+ "hidden_size": 512,
10
+ "initializer_range": 0.02,
11
+ "intermediate_size": 1024,
12
+ "layer_norm_eps": 1e-12,
13
+ "max_position_embeddings": 4096,
14
+ "model_type": "bert",
15
+ "num_attention_heads": 8,
16
+ "num_hidden_layers": 12,
17
+ "pad_token_id": 0,
18
+ "position_embedding_type": "absolute",
19
+ "torch_dtype": "float32",
20
+ "transformers_version": "4.37.2",
21
+ "type_vocab_size": 2,
22
+ "use_cache": true,
23
+ "vocab_size": 20275
24
+ }
fine_tuned_models/gf-6L-30M-i2048_CellClassifier_cardiomyopathies_220224/config.json ADDED
@@ -0,0 +1,35 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ {
2
+ "_name_or_path": "/n/home01/ctheodoris/models/210602_111318_geneformer_27M_L6_emb256_SL2048_E3_B12_LR0.001_LSlinear_WU10000_Oadamw_DS12/models/",
3
+ "architectures": [
4
+ "BertForSequenceClassification"
5
+ ],
6
+ "attention_probs_dropout_prob": 0.02,
7
+ "gradient_checkpointing": false,
8
+ "hidden_act": "relu",
9
+ "hidden_dropout_prob": 0.02,
10
+ "hidden_size": 256,
11
+ "id2label": {
12
+ "0": "LABEL_0",
13
+ "1": "LABEL_1",
14
+ "2": "LABEL_2"
15
+ },
16
+ "initializer_range": 0.02,
17
+ "intermediate_size": 512,
18
+ "label2id": {
19
+ "LABEL_0": 0,
20
+ "LABEL_1": 1,
21
+ "LABEL_2": 2
22
+ },
23
+ "layer_norm_eps": 1e-12,
24
+ "max_position_embeddings": 2048,
25
+ "model_type": "bert",
26
+ "num_attention_heads": 4,
27
+ "num_hidden_layers": 6,
28
+ "pad_token_id": 0,
29
+ "position_embedding_type": "absolute",
30
+ "problem_type": "single_label_classification",
31
+ "transformers_version": "4.6.0",
32
+ "type_vocab_size": 2,
33
+ "use_cache": true,
34
+ "vocab_size": 25426
35
+ }
fine_tuned_models/gf-6L-30M-i2048_CellClassifier_cardiomyopathies_220224/trainer_state.json ADDED
@@ -0,0 +1,150 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ {
2
+ "best_metric": 0.39658036828041077,
3
+ "best_model_checkpoint": "/n/holyscratch01/xiaoleliu_lab/Users/ctheodoris/models/220224_geneformer_27M_SequenceClassifier_tuning_hCMdCM_L2048_B12_LR1e-05_LScosine_WU500_E1_Oadamw_F2/run-8429a330/checkpoint-7020",
4
+ "epoch": 0.9,
5
+ "global_step": 7020,
6
+ "is_hyper_param_search": true,
7
+ "is_local_process_zero": true,
8
+ "is_world_process_zero": true,
9
+ "log_history": [
10
+ {
11
+ "epoch": 0.1,
12
+ "learning_rate": 0.00034606438343856935,
13
+ "loss": 0.911,
14
+ "step": 780
15
+ },
16
+ {
17
+ "epoch": 0.1,
18
+ "eval_accuracy": 0.4531576503366612,
19
+ "eval_loss": 1.4550466537475586,
20
+ "eval_runtime": 66.5164,
21
+ "eval_samples_per_second": 259.004,
22
+ "step": 780
23
+ },
24
+ {
25
+ "epoch": 0.2,
26
+ "learning_rate": 0.0006921287668771387,
27
+ "loss": 0.6273,
28
+ "step": 1560
29
+ },
30
+ {
31
+ "epoch": 0.2,
32
+ "eval_accuracy": 0.5953680055723242,
33
+ "eval_loss": 0.846651554107666,
34
+ "eval_runtime": 66.1267,
35
+ "eval_samples_per_second": 260.53,
36
+ "step": 1560
37
+ },
38
+ {
39
+ "epoch": 0.3,
40
+ "learning_rate": 0.0007330550166223805,
41
+ "loss": 0.5592,
42
+ "step": 2340
43
+ },
44
+ {
45
+ "epoch": 0.3,
46
+ "eval_accuracy": 0.5935105641978176,
47
+ "eval_loss": 1.0599186420440674,
48
+ "eval_runtime": 66.2608,
49
+ "eval_samples_per_second": 260.003,
50
+ "step": 2340
51
+ },
52
+ {
53
+ "epoch": 0.4,
54
+ "learning_rate": 0.0006283471571048975,
55
+ "loss": 0.3714,
56
+ "step": 3120
57
+ },
58
+ {
59
+ "epoch": 0.4,
60
+ "eval_accuracy": 0.686324587880195,
61
+ "eval_loss": 1.184874415397644,
62
+ "eval_runtime": 66.1411,
63
+ "eval_samples_per_second": 260.473,
64
+ "step": 3120
65
+ },
66
+ {
67
+ "epoch": 0.5,
68
+ "learning_rate": 0.0005236392975874146,
69
+ "loss": 0.2976,
70
+ "step": 3900
71
+ },
72
+ {
73
+ "epoch": 0.5,
74
+ "eval_accuracy": 0.7681100534014396,
75
+ "eval_loss": 0.6318939328193665,
76
+ "eval_runtime": 66.3309,
77
+ "eval_samples_per_second": 259.728,
78
+ "step": 3900
79
+ },
80
+ {
81
+ "epoch": 0.6,
82
+ "learning_rate": 0.0004189314380699318,
83
+ "loss": 0.2564,
84
+ "step": 4680
85
+ },
86
+ {
87
+ "epoch": 0.6,
88
+ "eval_accuracy": 0.7807058277223126,
89
+ "eval_loss": 0.7283642888069153,
90
+ "eval_runtime": 66.3416,
91
+ "eval_samples_per_second": 259.686,
92
+ "step": 4680
93
+ },
94
+ {
95
+ "epoch": 0.7,
96
+ "learning_rate": 0.0003142235785524487,
97
+ "loss": 0.2336,
98
+ "step": 5460
99
+ },
100
+ {
101
+ "epoch": 0.7,
102
+ "eval_accuracy": 0.8563965637334572,
103
+ "eval_loss": 0.5184123516082764,
104
+ "eval_runtime": 66.3416,
105
+ "eval_samples_per_second": 259.686,
106
+ "step": 5460
107
+ },
108
+ {
109
+ "epoch": 0.8,
110
+ "learning_rate": 0.0002095157190349659,
111
+ "loss": 0.1731,
112
+ "step": 6240
113
+ },
114
+ {
115
+ "epoch": 0.8,
116
+ "eval_accuracy": 0.8288832133735778,
117
+ "eval_loss": 0.5823884010314941,
118
+ "eval_runtime": 66.1535,
119
+ "eval_samples_per_second": 260.425,
120
+ "step": 6240
121
+ },
122
+ {
123
+ "epoch": 0.9,
124
+ "learning_rate": 0.00010480785951748295,
125
+ "loss": 0.1451,
126
+ "step": 7020
127
+ },
128
+ {
129
+ "epoch": 0.9,
130
+ "eval_accuracy": 0.886812166241003,
131
+ "eval_loss": 0.39658036828041077,
132
+ "eval_runtime": 66.3555,
133
+ "eval_samples_per_second": 259.632,
134
+ "step": 7020
135
+ }
136
+ ],
137
+ "max_steps": 7800,
138
+ "num_train_epochs": 1,
139
+ "total_flos": 0,
140
+ "trial_name": null,
141
+ "trial_params": {
142
+ "learning_rate": 0.0008039341830649843,
143
+ "lr_scheduler_type": "polynomial",
144
+ "num_train_epochs": 1,
145
+ "per_device_train_batch_size": 12,
146
+ "seed": 73.15243080311434,
147
+ "warmup_steps": 1812.6785581609881,
148
+ "weight_decay": 0.2588277764570262
149
+ }
150
+ }
geneformer/__init__.py ADDED
@@ -0,0 +1,34 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # ruff: noqa: F401
2
+ import warnings
3
+ from pathlib import Path
4
+
5
+ warnings.filterwarnings("ignore", message=".*The 'nopython' keyword.*") # noqa # isort:skip
6
+
7
+ GENE_MEDIAN_FILE = Path(__file__).parent / "gene_median_dictionary_gc95M.pkl"
8
+ TOKEN_DICTIONARY_FILE = Path(__file__).parent / "token_dictionary_gc95M.pkl"
9
+ ENSEMBL_DICTIONARY_FILE = Path(__file__).parent / "gene_name_id_dict_gc95M.pkl"
10
+ ENSEMBL_MAPPING_FILE = Path(__file__).parent / "ensembl_mapping_dict_gc95M.pkl"
11
+
12
+ from . import (
13
+ collator_for_classification,
14
+ emb_extractor,
15
+ in_silico_perturber,
16
+ in_silico_perturber_stats,
17
+ pretrainer,
18
+ tokenizer,
19
+ )
20
+ from .collator_for_classification import (
21
+ DataCollatorForCellClassification,
22
+ DataCollatorForGeneClassification,
23
+ )
24
+ from .emb_extractor import EmbExtractor, get_embs
25
+ from .in_silico_perturber import InSilicoPerturber
26
+ from .in_silico_perturber_stats import InSilicoPerturberStats
27
+ from .pretrainer import GeneformerPretrainer
28
+ from .tokenizer import TranscriptomeTokenizer
29
+
30
+ from . import classifier # noqa # isort:skip
31
+ from .classifier import Classifier # noqa # isort:skip
32
+
33
+ from . import mtl_classifier # noqa # isort:skip
34
+ from .mtl_classifier import MTLClassifier # noqa # isort:skip
geneformer/classifier.py ADDED
@@ -0,0 +1,1563 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ Geneformer classifier.
3
+
4
+ **Input data:**
5
+
6
+ | Cell state classifier:
7
+ | Single-cell transcriptomes as Geneformer rank value encodings with cell state labels in Geneformer .dataset format (generated from single-cell RNAseq data by tokenizer.py)
8
+
9
+ | Gene classifier:
10
+ | Dictionary in format {Gene_label: list(genes)} for gene labels and single-cell transcriptomes as Geneformer rank value encodings in Geneformer .dataset format (generated from single-cell RNAseq data by tokenizer.py)
11
+
12
+ **Usage:**
13
+
14
+ .. code-block :: python
15
+
16
+ >>> from geneformer import Classifier
17
+ >>> cc = Classifier(classifier="cell", # example of cell state classifier
18
+ ... cell_state_dict={"state_key": "disease", "states": "all"},
19
+ ... filter_data={"cell_type":["Cardiomyocyte1","Cardiomyocyte2","Cardiomyocyte3"]},
20
+ ... training_args=training_args,
21
+ ... freeze_layers = 2,
22
+ ... num_crossval_splits = 1,
23
+ ... forward_batch_size=200,
24
+ ... nproc=16)
25
+ >>> cc.prepare_data(input_data_file="path/to/input_data",
26
+ ... output_directory="path/to/output_directory",
27
+ ... output_prefix="output_prefix")
28
+ >>> all_metrics = cc.validate(model_directory="path/to/model",
29
+ ... prepared_input_data_file=f"path/to/output_directory/{output_prefix}_labeled.dataset",
30
+ ... id_class_dict_file=f"path/to/output_directory/{output_prefix}_id_class_dict.pkl",
31
+ ... output_directory="path/to/output_directory",
32
+ ... output_prefix="output_prefix",
33
+ ... predict_eval=True)
34
+ >>> cc.plot_conf_mat(conf_mat_dict={"Geneformer": all_metrics["conf_matrix"]},
35
+ ... output_directory="path/to/output_directory",
36
+ ... output_prefix="output_prefix",
37
+ ... custom_class_order=["healthy","disease1","disease2"])
38
+ >>> cc.plot_predictions(predictions_file=f"path/to/output_directory/datestamp_geneformer_cellClassifier_{output_prefix}/ksplit1/predictions.pkl",
39
+ ... id_class_dict_file=f"path/to/output_directory/{output_prefix}_id_class_dict.pkl",
40
+ ... title="disease",
41
+ ... output_directory="path/to/output_directory",
42
+ ... output_prefix="output_prefix",
43
+ ... custom_class_order=["healthy","disease1","disease2"])
44
+ """
45
+
46
+ import datetime
47
+ import logging
48
+ import os
49
+ import pickle
50
+ import subprocess
51
+ from pathlib import Path
52
+
53
+ import numpy as np
54
+ import pandas as pd
55
+ import seaborn as sns
56
+ from tqdm.auto import tqdm, trange
57
+ from transformers import Trainer
58
+ from transformers.training_args import TrainingArguments
59
+
60
+ from . import (
61
+ TOKEN_DICTIONARY_FILE,
62
+ DataCollatorForCellClassification,
63
+ DataCollatorForGeneClassification,
64
+ )
65
+ from . import classifier_utils as cu
66
+ from . import evaluation_utils as eu
67
+ from . import perturber_utils as pu
68
+
69
+ sns.set()
70
+
71
+
72
+ logger = logging.getLogger(__name__)
73
+
74
+
75
+ class Classifier:
76
+ valid_option_dict = {
77
+ "classifier": {"cell", "gene"},
78
+ "quantize": {bool, dict},
79
+ "cell_state_dict": {None, dict},
80
+ "gene_class_dict": {None, dict},
81
+ "filter_data": {None, dict},
82
+ "rare_threshold": {int, float},
83
+ "max_ncells": {None, int},
84
+ "max_ncells_per_class": {None, int},
85
+ "training_args": {None, dict},
86
+ "freeze_layers": {int},
87
+ "num_crossval_splits": {0, 1, 5},
88
+ "split_sizes": {None, dict},
89
+ "no_eval": {bool},
90
+ "stratify_splits_col": {None, str},
91
+ "forward_batch_size": {int},
92
+ "token_dictionary_file": {None, str},
93
+ "nproc": {int},
94
+ "ngpu": {int},
95
+ }
96
+
97
+ def __init__(
98
+ self,
99
+ classifier=None,
100
+ quantize=False,
101
+ cell_state_dict=None,
102
+ gene_class_dict=None,
103
+ filter_data=None,
104
+ rare_threshold=0,
105
+ max_ncells=None,
106
+ max_ncells_per_class=None,
107
+ training_args=None,
108
+ ray_config=None,
109
+ freeze_layers=0,
110
+ num_crossval_splits=1,
111
+ split_sizes={"train": 0.8, "valid": 0.1, "test": 0.1},
112
+ stratify_splits_col=None,
113
+ no_eval=False,
114
+ forward_batch_size=100,
115
+ token_dictionary_file=None,
116
+ nproc=4,
117
+ ngpu=1,
118
+ ):
119
+ """
120
+ Initialize Geneformer classifier.
121
+
122
+ **Parameters:**
123
+
124
+ classifier : {"cell", "gene"}
125
+ | Whether to fine-tune a cell state or gene classifier.
126
+ quantize : bool, dict
127
+ | Whether to fine-tune a quantized model.
128
+ | If True and no config provided, will use default.
129
+ | Will use custom config if provided.
130
+ | Configs should be provided as dictionary of BitsAndBytesConfig (transformers) and LoraConfig (peft).
131
+ | For example: {"bnb_config": BitsAndBytesConfig(...),
132
+ | "peft_config": LoraConfig(...)}
133
+ cell_state_dict : None, dict
134
+ | Cell states to fine-tune model to distinguish.
135
+ | Two-item dictionary with keys: state_key and states
136
+ | state_key: key specifying name of column in .dataset that defines the states to model
137
+ | states: list of values in the state_key column that specifies the states to model
138
+ | Alternatively, instead of a list of states, can specify "all" to use all states in that state key from input data.
139
+ | Of note, if using "all", states will be defined after data is filtered.
140
+ | Must have at least 2 states to model.
141
+ | For example: {"state_key": "disease",
142
+ | "states": ["nf", "hcm", "dcm"]}
143
+ | or
144
+ | {"state_key": "disease",
145
+ | "states": "all"}
146
+ gene_class_dict : None, dict
147
+ | Gene classes to fine-tune model to distinguish.
148
+ | Dictionary in format: {Gene_label_A: list(geneA1, geneA2, ...),
149
+ | Gene_label_B: list(geneB1, geneB2, ...)}
150
+ | Gene values should be Ensembl IDs.
151
+ filter_data : None, dict
152
+ | Default is to fine-tune with all input data.
153
+ | Otherwise, dictionary specifying .dataset column name and list of values to filter by.
154
+ rare_threshold : float
155
+ | Threshold below which rare cell states should be removed.
156
+ | For example, setting to 0.05 will remove cell states representing
157
+ | < 5% of the total cells from the cell state classifier's possible classes.
158
+ max_ncells : None, int
159
+ | Maximum number of cells to use for fine-tuning.
160
+ | Default is to fine-tune with all input data.
161
+ max_ncells_per_class : None, int
162
+ | Maximum number of cells per cell class to use for fine-tuning.
163
+ | Of note, will be applied after max_ncells above.
164
+ | (Only valid for cell classification.)
165
+ training_args : None, dict
166
+ | Training arguments for fine-tuning.
167
+ | If None, defaults will be inferred for 6 layer Geneformer.
168
+ | Otherwise, will use the Hugging Face defaults:
169
+ | https://huggingface.co/docs/transformers/main_classes/trainer#transformers.TrainingArguments
170
+ | Note: Hyperparameter tuning is highly recommended, rather than using defaults.
171
+ ray_config : None, dict
172
+ | Training argument ranges for tuning hyperparameters with Ray.
173
+ freeze_layers : int
174
+ | Number of layers to freeze from fine-tuning.
175
+ | 0: no layers will be frozen; 2: first two layers will be frozen; etc.
176
+ num_crossval_splits : {0, 1, 5}
177
+ | 0: train on all data without splitting
178
+ | 1: split data into train and eval sets by designated split_sizes["valid"]
179
+ | 5: split data into 5 folds of train and eval sets by designated split_sizes["valid"]
180
+ split_sizes : None, dict
181
+ | Dictionary of proportion of data to hold out for train, validation, and test sets
182
+ | {"train": 0.8, "valid": 0.1, "test": 0.1} if intending 80/10/10 train/valid/test split
183
+ stratify_splits_col : None, str
184
+ | Name of column in .dataset to be used for stratified splitting.
185
+ | Proportion of each class in this column will be the same in the splits as in the original dataset.
186
+ no_eval : bool
187
+ | If True, will skip eval step and use all data for training.
188
+ | Otherwise, will perform eval during training.
189
+ forward_batch_size : int
190
+ | Batch size for forward pass (for evaluation, not training).
191
+ token_dictionary_file : None, str
192
+ | Default is to use token dictionary file from Geneformer
193
+ | Otherwise, will load custom gene token dictionary.
194
+ nproc : int
195
+ | Number of CPU processes to use.
196
+ ngpu : int
197
+ | Number of GPUs available.
198
+
199
+ """
200
+
201
+ self.classifier = classifier
202
+ if self.classifier == "cell":
203
+ self.model_type = "CellClassifier"
204
+ elif self.classifier == "gene":
205
+ self.model_type = "GeneClassifier"
206
+ self.quantize = quantize
207
+ self.cell_state_dict = cell_state_dict
208
+ self.gene_class_dict = gene_class_dict
209
+ self.filter_data = filter_data
210
+ self.rare_threshold = rare_threshold
211
+ self.max_ncells = max_ncells
212
+ self.max_ncells_per_class = max_ncells_per_class
213
+ self.training_args = training_args
214
+ self.ray_config = ray_config
215
+ self.freeze_layers = freeze_layers
216
+ self.num_crossval_splits = num_crossval_splits
217
+ self.split_sizes = split_sizes
218
+ self.train_size = self.split_sizes["train"]
219
+ self.valid_size = self.split_sizes["valid"]
220
+ self.oos_test_size = self.split_sizes["test"]
221
+ self.eval_size = self.valid_size / (self.train_size + self.valid_size)
222
+ self.stratify_splits_col = stratify_splits_col
223
+ self.no_eval = no_eval
224
+ self.forward_batch_size = forward_batch_size
225
+ self.token_dictionary_file = token_dictionary_file
226
+ self.nproc = nproc
227
+ self.ngpu = ngpu
228
+
229
+ if self.training_args is None:
230
+ logger.warning(
231
+ "Hyperparameter tuning is highly recommended for optimal results. "
232
+ "No training_args provided; using default hyperparameters."
233
+ )
234
+
235
+ self.validate_options()
236
+
237
+ if self.filter_data is None:
238
+ self.filter_data = dict()
239
+
240
+ if self.classifier == "cell":
241
+ if self.cell_state_dict["states"] != "all":
242
+ self.filter_data[
243
+ self.cell_state_dict["state_key"]
244
+ ] = self.cell_state_dict["states"]
245
+
246
+ # load token dictionary (Ensembl IDs:token)
247
+ if self.token_dictionary_file is None:
248
+ self.token_dictionary_file = TOKEN_DICTIONARY_FILE
249
+ with open(self.token_dictionary_file, "rb") as f:
250
+ self.gene_token_dict = pickle.load(f)
251
+
252
+ self.token_gene_dict = {v: k for k, v in self.gene_token_dict.items()}
253
+
254
+ # filter genes for gene classification for those in token dictionary
255
+ if self.classifier == "gene":
256
+ all_gene_class_values = set(pu.flatten_list(self.gene_class_dict.values()))
257
+ missing_genes = [
258
+ gene
259
+ for gene in all_gene_class_values
260
+ if gene not in self.gene_token_dict.keys()
261
+ ]
262
+ if len(missing_genes) == len(all_gene_class_values):
263
+ logger.error(
264
+ "None of the provided genes to classify are in token dictionary."
265
+ )
266
+ raise
267
+ elif len(missing_genes) > 0:
268
+ logger.warning(
269
+ f"Genes to classify {missing_genes} are not in token dictionary."
270
+ )
271
+ self.gene_class_dict = {
272
+ k: list(set([self.gene_token_dict.get(gene) for gene in v]))
273
+ for k, v in self.gene_class_dict.items()
274
+ }
275
+ empty_classes = []
276
+ for k, v in self.gene_class_dict.items():
277
+ if len(v) == 0:
278
+ empty_classes += [k]
279
+ if len(empty_classes) > 0:
280
+ logger.error(
281
+ f"Class(es) {empty_classes} did not contain any genes in the token dictionary."
282
+ )
283
+ raise
284
+
285
+ def validate_options(self):
286
+ # confirm arguments are within valid options and compatible with each other
287
+ for attr_name, valid_options in self.valid_option_dict.items():
288
+ attr_value = self.__dict__[attr_name]
289
+ if not isinstance(attr_value, (list, dict)):
290
+ if attr_value in valid_options:
291
+ continue
292
+ valid_type = False
293
+ for option in valid_options:
294
+ if (option in [int, float, list, dict, bool, str]) and isinstance(
295
+ attr_value, option
296
+ ):
297
+ valid_type = True
298
+ break
299
+ if valid_type:
300
+ continue
301
+ logger.error(
302
+ f"Invalid option for {attr_name}. "
303
+ f"Valid options for {attr_name}: {valid_options}"
304
+ )
305
+ raise
306
+
307
+ if self.filter_data is not None:
308
+ for key, value in self.filter_data.items():
309
+ if not isinstance(value, list):
310
+ self.filter_data[key] = [value]
311
+ logger.warning(
312
+ "Values in filter_data dict must be lists. "
313
+ f"Changing {key} value to list ([{value}])."
314
+ )
315
+
316
+ if self.classifier == "cell":
317
+ if set(self.cell_state_dict.keys()) != set(["state_key", "states"]):
318
+ logger.error(
319
+ "Invalid keys for cell_state_dict. "
320
+ "The cell_state_dict should have only 2 keys: state_key and states"
321
+ )
322
+ raise
323
+
324
+ if self.cell_state_dict["states"] != "all":
325
+ if not isinstance(self.cell_state_dict["states"], list):
326
+ logger.error(
327
+ "States in cell_state_dict should be list of states to model."
328
+ )
329
+ raise
330
+ if len(self.cell_state_dict["states"]) < 2:
331
+ logger.error(
332
+ "States in cell_state_dict should contain at least 2 states to classify."
333
+ )
334
+ raise
335
+
336
+ if self.classifier == "gene":
337
+ if len(self.gene_class_dict.keys()) < 2:
338
+ logger.error(
339
+ "Gene_class_dict should contain at least 2 gene classes to classify."
340
+ )
341
+ raise
342
+ if sum(self.split_sizes.values()) != 1:
343
+ logger.error("Train, validation, and test proportions should sum to 1.")
344
+ raise
345
+
346
+ def prepare_data(
347
+ self,
348
+ input_data_file,
349
+ output_directory,
350
+ output_prefix,
351
+ split_id_dict=None,
352
+ test_size=None,
353
+ attr_to_split=None,
354
+ attr_to_balance=None,
355
+ max_trials=100,
356
+ pval_threshold=0.1,
357
+ ):
358
+ """
359
+ Prepare data for cell state or gene classification.
360
+
361
+ **Parameters**
362
+
363
+ input_data_file : Path
364
+ | Path to directory containing .dataset input
365
+ output_directory : Path
366
+ | Path to directory where prepared data will be saved
367
+ output_prefix : str
368
+ | Prefix for output file
369
+ split_id_dict : None, dict
370
+ | Dictionary of IDs for train and test splits
371
+ | Three-item dictionary with keys: attr_key, train, test
372
+ | attr_key: key specifying name of column in .dataset that contains the IDs for the data splits
373
+ | train: list of IDs in the attr_key column to include in the train split
374
+ | test: list of IDs in the attr_key column to include in the test split
375
+ | For example: {"attr_key": "individual",
376
+ | "train": ["patient1", "patient2", "patient3", "patient4"],
377
+ | "test": ["patient5", "patient6"]}
378
+ test_size : None, float
379
+ | Proportion of data to be saved separately and held out for test set
380
+ | (e.g. 0.2 if intending hold out 20%)
381
+ | If None, will inherit from split_sizes["test"] from Classifier
382
+ | The training set will be further split to train / validation in self.validate
383
+ | Note: only available for CellClassifiers
384
+ attr_to_split : None, str
385
+ | Key for attribute on which to split data while balancing potential confounders
386
+ | e.g. "patient_id" for splitting by patient while balancing other characteristics
387
+ | Note: only available for CellClassifiers
388
+ attr_to_balance : None, list
389
+ | List of attribute keys on which to balance data while splitting on attr_to_split
390
+ | e.g. ["age", "sex"] for balancing these characteristics while splitting by patient
391
+ | Note: only available for CellClassifiers
392
+ max_trials : None, int
393
+ | Maximum number of trials of random splitting to try to achieve balanced other attributes
394
+ | If no split is found without significant (p<0.05) differences in other attributes, will select best
395
+ | Note: only available for CellClassifiers
396
+ pval_threshold : None, float
397
+ | P-value threshold to use for attribute balancing across splits
398
+ | E.g. if set to 0.1, will accept trial if p >= 0.1 for all attributes in attr_to_balance
399
+ """
400
+
401
+ if test_size is None:
402
+ test_size = self.oos_test_size
403
+
404
+ # prepare data and labels for classification
405
+ data = pu.load_and_filter(self.filter_data, self.nproc, input_data_file)
406
+
407
+ if self.classifier == "cell":
408
+ if "label" in data.features:
409
+ logger.error(
410
+ "Column name 'label' must be reserved for class IDs. Please rename column."
411
+ )
412
+ raise
413
+ elif self.classifier == "gene":
414
+ if "labels" in data.features:
415
+ logger.error(
416
+ "Column name 'labels' must be reserved for class IDs. Please rename column."
417
+ )
418
+ raise
419
+
420
+ if (attr_to_split is not None) and (attr_to_balance is None):
421
+ logger.error(
422
+ "Splitting by attribute while balancing confounders requires both attr_to_split and attr_to_balance to be defined."
423
+ )
424
+ raise
425
+
426
+ if not isinstance(attr_to_balance, list):
427
+ attr_to_balance = [attr_to_balance]
428
+
429
+ if self.classifier == "cell":
430
+ # remove cell states representing < rare_threshold of cells
431
+ data = cu.remove_rare(
432
+ data, self.rare_threshold, self.cell_state_dict["state_key"], self.nproc
433
+ )
434
+ # downsample max cells and max per class
435
+ data = cu.downsample_and_shuffle(
436
+ data, self.max_ncells, self.max_ncells_per_class, self.cell_state_dict
437
+ )
438
+ # rename cell state column to "label"
439
+ data = cu.rename_cols(data, self.cell_state_dict["state_key"])
440
+
441
+ # convert classes to numerical labels and save as id_class_dict
442
+ # of note, will label all genes in gene_class_dict
443
+ # if (cross-)validating, genes will be relabeled in column "labels" for each split
444
+ # at the time of training with Classifier.validate
445
+ data, id_class_dict = cu.label_classes(
446
+ self.classifier, data, self.gene_class_dict, self.nproc
447
+ )
448
+
449
+ # save id_class_dict for future reference
450
+ id_class_output_path = (
451
+ Path(output_directory) / f"{output_prefix}_id_class_dict"
452
+ ).with_suffix(".pkl")
453
+ with open(id_class_output_path, "wb") as f:
454
+ pickle.dump(id_class_dict, f)
455
+
456
+ if split_id_dict is not None:
457
+ data_dict = dict()
458
+ data_dict["train"] = pu.filter_by_dict(
459
+ data, {split_id_dict["attr_key"]: split_id_dict["train"]}, self.nproc
460
+ )
461
+ data_dict["test"] = pu.filter_by_dict(
462
+ data, {split_id_dict["attr_key"]: split_id_dict["test"]}, self.nproc
463
+ )
464
+ train_data_output_path = (
465
+ Path(output_directory) / f"{output_prefix}_labeled_train"
466
+ ).with_suffix(".dataset")
467
+ test_data_output_path = (
468
+ Path(output_directory) / f"{output_prefix}_labeled_test"
469
+ ).with_suffix(".dataset")
470
+ data_dict["train"].save_to_disk(str(train_data_output_path))
471
+ data_dict["test"].save_to_disk(str(test_data_output_path))
472
+ elif (test_size is not None) and (self.classifier == "cell"):
473
+ if 1 > test_size > 0:
474
+ if attr_to_split is None:
475
+ data_dict = data.train_test_split(
476
+ test_size=test_size,
477
+ stratify_by_column=self.stratify_splits_col,
478
+ seed=42,
479
+ )
480
+ train_data_output_path = (
481
+ Path(output_directory) / f"{output_prefix}_labeled_train"
482
+ ).with_suffix(".dataset")
483
+ test_data_output_path = (
484
+ Path(output_directory) / f"{output_prefix}_labeled_test"
485
+ ).with_suffix(".dataset")
486
+ data_dict["train"].save_to_disk(str(train_data_output_path))
487
+ data_dict["test"].save_to_disk(str(test_data_output_path))
488
+ else:
489
+ data_dict, balance_df = cu.balance_attr_splits(
490
+ data,
491
+ attr_to_split,
492
+ attr_to_balance,
493
+ test_size,
494
+ max_trials,
495
+ pval_threshold,
496
+ self.cell_state_dict["state_key"],
497
+ self.nproc,
498
+ )
499
+ balance_df.to_csv(
500
+ f"{output_directory}/{output_prefix}_train_test_balance_df.csv"
501
+ )
502
+ train_data_output_path = (
503
+ Path(output_directory) / f"{output_prefix}_labeled_train"
504
+ ).with_suffix(".dataset")
505
+ test_data_output_path = (
506
+ Path(output_directory) / f"{output_prefix}_labeled_test"
507
+ ).with_suffix(".dataset")
508
+ data_dict["train"].save_to_disk(str(train_data_output_path))
509
+ data_dict["test"].save_to_disk(str(test_data_output_path))
510
+ else:
511
+ data_output_path = (
512
+ Path(output_directory) / f"{output_prefix}_labeled"
513
+ ).with_suffix(".dataset")
514
+ data.save_to_disk(str(data_output_path))
515
+ print(data_output_path)
516
+ else:
517
+ data_output_path = (
518
+ Path(output_directory) / f"{output_prefix}_labeled"
519
+ ).with_suffix(".dataset")
520
+ data.save_to_disk(str(data_output_path))
521
+
522
+ def train_all_data(
523
+ self,
524
+ model_directory,
525
+ prepared_input_data_file,
526
+ id_class_dict_file,
527
+ output_directory,
528
+ output_prefix,
529
+ save_eval_output=True,
530
+ gene_balance=False,
531
+ ):
532
+ """
533
+ Train cell state or gene classifier using all data.
534
+
535
+ **Parameters**
536
+
537
+ model_directory : Path
538
+ | Path to directory containing model
539
+ prepared_input_data_file : Path
540
+ | Path to directory containing _labeled.dataset previously prepared by Classifier.prepare_data
541
+ id_class_dict_file : Path
542
+ | Path to _id_class_dict.pkl previously prepared by Classifier.prepare_data
543
+ | (dictionary of format: numerical IDs: class_labels)
544
+ output_directory : Path
545
+ | Path to directory where model and eval data will be saved
546
+ output_prefix : str
547
+ | Prefix for output files
548
+ save_eval_output : bool
549
+ | Whether to save cross-fold eval output
550
+ | Saves as pickle file of dictionary of eval metrics
551
+ gene_balance : None, bool
552
+ | Whether to automatically balance genes in training set.
553
+ | Only available for binary gene classifications.
554
+
555
+ **Output**
556
+
557
+ Returns trainer after fine-tuning with all data.
558
+
559
+ """
560
+
561
+ if (gene_balance is True) and (len(self.gene_class_dict.values()) != 2):
562
+ logger.error(
563
+ "Automatically balancing gene sets for training is only available for binary gene classifications."
564
+ )
565
+ raise
566
+
567
+ ##### Load data and prepare output directory #####
568
+ # load numerical id to class dictionary (id:class)
569
+ with open(id_class_dict_file, "rb") as f:
570
+ id_class_dict = pickle.load(f)
571
+ class_id_dict = {v: k for k, v in id_class_dict.items()}
572
+
573
+ # load previously filtered and prepared data
574
+ data = pu.load_and_filter(None, self.nproc, prepared_input_data_file)
575
+ data = data.shuffle(seed=42) # reshuffle in case users provide unshuffled data
576
+
577
+ # define output directory path
578
+ current_date = datetime.datetime.now()
579
+ datestamp = f"{str(current_date.year)[-2:]}{current_date.month:02d}{current_date.day:02d}"
580
+ if output_directory[-1:] != "/": # add slash for dir if not present
581
+ output_directory = output_directory + "/"
582
+ output_dir = f"{output_directory}{datestamp}_geneformer_{self.classifier}Classifier_{output_prefix}/"
583
+ subprocess.call(f"mkdir {output_dir}", shell=True)
584
+
585
+ # get number of classes for classifier
586
+ num_classes = cu.get_num_classes(id_class_dict)
587
+
588
+ if self.classifier == "gene":
589
+ targets = pu.flatten_list(self.gene_class_dict.values())
590
+ labels = pu.flatten_list(
591
+ [
592
+ [class_id_dict[label]] * len(targets)
593
+ for label, targets in self.gene_class_dict.items()
594
+ ]
595
+ )
596
+ assert len(targets) == len(labels)
597
+ data = cu.prep_gene_classifier_all_data(
598
+ data, targets, labels, self.max_ncells, self.nproc, gene_balance
599
+ )
600
+
601
+ trainer = self.train_classifier(
602
+ model_directory, num_classes, data, None, output_dir
603
+ )
604
+
605
+ return trainer
606
+
607
+ def validate(
608
+ self,
609
+ model_directory,
610
+ prepared_input_data_file,
611
+ id_class_dict_file,
612
+ output_directory,
613
+ output_prefix,
614
+ split_id_dict=None,
615
+ attr_to_split=None,
616
+ attr_to_balance=None,
617
+ gene_balance=False,
618
+ max_trials=100,
619
+ pval_threshold=0.1,
620
+ save_eval_output=True,
621
+ predict_eval=True,
622
+ predict_trainer=False,
623
+ n_hyperopt_trials=0,
624
+ save_gene_split_datasets=True,
625
+ debug_gene_split_datasets=False,
626
+ ):
627
+ """
628
+ (Cross-)validate cell state or gene classifier.
629
+
630
+ **Parameters**
631
+
632
+ model_directory : Path
633
+ | Path to directory containing model
634
+ prepared_input_data_file : Path
635
+ | Path to directory containing _labeled.dataset previously prepared by Classifier.prepare_data
636
+ id_class_dict_file : Path
637
+ | Path to _id_class_dict.pkl previously prepared by Classifier.prepare_data
638
+ | (dictionary of format: numerical IDs: class_labels)
639
+ output_directory : Path
640
+ | Path to directory where model and eval data will be saved
641
+ output_prefix : str
642
+ | Prefix for output files
643
+ split_id_dict : None, dict
644
+ | Dictionary of IDs for train and eval splits
645
+ | Three-item dictionary with keys: attr_key, train, eval
646
+ | attr_key: key specifying name of column in .dataset that contains the IDs for the data splits
647
+ | train: list of IDs in the attr_key column to include in the train split
648
+ | eval: list of IDs in the attr_key column to include in the eval split
649
+ | For example: {"attr_key": "individual",
650
+ | "train": ["patient1", "patient2", "patient3", "patient4"],
651
+ | "eval": ["patient5", "patient6"]}
652
+ | Note: only available for CellClassifiers with 1-fold split (self.classifier="cell"; self.num_crossval_splits=1)
653
+ attr_to_split : None, str
654
+ | Key for attribute on which to split data while balancing potential confounders
655
+ | e.g. "patient_id" for splitting by patient while balancing other characteristics
656
+ | Note: only available for CellClassifiers with 1-fold split (self.classifier="cell"; self.num_crossval_splits=1)
657
+ attr_to_balance : None, list
658
+ | List of attribute keys on which to balance data while splitting on attr_to_split
659
+ | e.g. ["age", "sex"] for balancing these characteristics while splitting by patient
660
+ gene_balance : None, bool
661
+ | Whether to automatically balance genes in training set.
662
+ | Only available for binary gene classifications.
663
+ max_trials : None, int
664
+ | Maximum number of trials of random splitting to try to achieve balanced other attribute
665
+ | If no split is found without significant (p < pval_threshold) differences in other attributes, will select best
666
+ pval_threshold : None, float
667
+ | P-value threshold to use for attribute balancing across splits
668
+ | E.g. if set to 0.1, will accept trial if p >= 0.1 for all attributes in attr_to_balance
669
+ save_eval_output : bool
670
+ | Whether to save cross-fold eval output
671
+ | Saves as pickle file of dictionary of eval metrics
672
+ predict_eval : bool
673
+ | Whether or not to save eval predictions
674
+ | Saves as a pickle file of self.evaluate predictions
675
+ predict_trainer : bool
676
+ | Whether or not to save eval predictions from trainer
677
+ | Saves as a pickle file of trainer predictions
678
+ n_hyperopt_trials : int
679
+ | Number of trials to run for hyperparameter optimization
680
+ | If 0, will not optimize hyperparameters
681
+ save_gene_split_datasets : bool
682
+ | Whether or not to save train, valid, and test gene-labeled datasets
683
+ """
684
+ if self.num_crossval_splits == 0:
685
+ logger.error("num_crossval_splits must be 1 or 5 to validate.")
686
+ raise
687
+
688
+ if (gene_balance is True) and (len(self.gene_class_dict.values()) != 2):
689
+ logger.error(
690
+ "Automatically balancing gene sets for training is only available for binary gene classifications."
691
+ )
692
+ raise
693
+
694
+ # ensure number of genes in each class is > 5 if validating model
695
+ if self.classifier == "gene":
696
+ insuff_classes = [k for k, v in self.gene_class_dict.items() if len(v) < 5]
697
+ if (self.num_crossval_splits > 0) and (len(insuff_classes) > 0):
698
+ logger.error(
699
+ f"Insufficient # of members in class(es) {insuff_classes} to (cross-)validate."
700
+ )
701
+ raise
702
+
703
+ ##### Load data and prepare output directory #####
704
+ # load numerical id to class dictionary (id:class)
705
+ with open(id_class_dict_file, "rb") as f:
706
+ id_class_dict = pickle.load(f)
707
+ class_id_dict = {v: k for k, v in id_class_dict.items()}
708
+
709
+ # load previously filtered and prepared data
710
+ data = pu.load_and_filter(None, self.nproc, prepared_input_data_file)
711
+ data = data.shuffle(seed=42) # reshuffle in case users provide unshuffled data
712
+
713
+ # define output directory path
714
+ current_date = datetime.datetime.now()
715
+ datestamp = f"{str(current_date.year)[-2:]}{current_date.month:02d}{current_date.day:02d}"
716
+ if output_directory[-1:] != "/": # add slash for dir if not present
717
+ output_directory = output_directory + "/"
718
+ output_dir = f"{output_directory}{datestamp}_geneformer_{self.classifier}Classifier_{output_prefix}/"
719
+ subprocess.call(f"mkdir {output_dir}", shell=True)
720
+
721
+ # get number of classes for classifier
722
+ num_classes = cu.get_num_classes(id_class_dict)
723
+
724
+ ##### (Cross-)validate the model #####
725
+ results = []
726
+ all_conf_mat = np.zeros((num_classes, num_classes))
727
+ iteration_num = 1
728
+ if self.classifier == "cell":
729
+ for i in trange(self.num_crossval_splits):
730
+ print(
731
+ f"****** Validation split: {iteration_num}/{self.num_crossval_splits} ******\n"
732
+ )
733
+ ksplit_output_dir = os.path.join(output_dir, f"ksplit{iteration_num}")
734
+ if self.num_crossval_splits == 1:
735
+ # single 1-eval_size:eval_size split
736
+ if split_id_dict is not None:
737
+ data_dict = dict()
738
+ data_dict["train"] = pu.filter_by_dict(
739
+ data,
740
+ {split_id_dict["attr_key"]: split_id_dict["train"]},
741
+ self.nproc,
742
+ )
743
+ data_dict["test"] = pu.filter_by_dict(
744
+ data,
745
+ {split_id_dict["attr_key"]: split_id_dict["eval"]},
746
+ self.nproc,
747
+ )
748
+ elif attr_to_split is not None:
749
+ data_dict, balance_df = cu.balance_attr_splits(
750
+ data,
751
+ attr_to_split,
752
+ attr_to_balance,
753
+ self.eval_size,
754
+ max_trials,
755
+ pval_threshold,
756
+ self.cell_state_dict["state_key"],
757
+ self.nproc,
758
+ )
759
+
760
+ balance_df.to_csv(
761
+ f"{output_dir}/{output_prefix}_train_valid_balance_df.csv"
762
+ )
763
+ else:
764
+ data_dict = data.train_test_split(
765
+ test_size=self.eval_size,
766
+ stratify_by_column=self.stratify_splits_col,
767
+ seed=42,
768
+ )
769
+ train_data = data_dict["train"]
770
+ eval_data = data_dict["test"]
771
+ else:
772
+ # 5-fold cross-validate
773
+ num_cells = len(data)
774
+ fifth_cells = int(np.floor(num_cells * 0.2))
775
+ num_eval = min((self.eval_size * num_cells), fifth_cells)
776
+ start = i * fifth_cells
777
+ end = start + num_eval
778
+ eval_indices = [j for j in range(start, end)]
779
+ train_indices = [
780
+ j for j in range(num_cells) if j not in eval_indices
781
+ ]
782
+ eval_data = data.select(eval_indices)
783
+ train_data = data.select(train_indices)
784
+ if n_hyperopt_trials == 0:
785
+ trainer = self.train_classifier(
786
+ model_directory,
787
+ num_classes,
788
+ train_data,
789
+ eval_data,
790
+ ksplit_output_dir,
791
+ predict_trainer,
792
+ )
793
+ else:
794
+ trainer = self.hyperopt_classifier(
795
+ model_directory,
796
+ num_classes,
797
+ train_data,
798
+ eval_data,
799
+ ksplit_output_dir,
800
+ n_trials=n_hyperopt_trials,
801
+ )
802
+ if iteration_num == self.num_crossval_splits:
803
+ return
804
+ else:
805
+ iteration_num = iteration_num + 1
806
+ continue
807
+
808
+ result = self.evaluate_model(
809
+ trainer.model,
810
+ num_classes,
811
+ id_class_dict,
812
+ eval_data,
813
+ predict_eval,
814
+ ksplit_output_dir,
815
+ output_prefix,
816
+ )
817
+ results += [result]
818
+ all_conf_mat = all_conf_mat + result["conf_mat"]
819
+ iteration_num = iteration_num + 1
820
+
821
+ elif self.classifier == "gene":
822
+ # set up (cross-)validation splits
823
+ targets = pu.flatten_list(self.gene_class_dict.values())
824
+ labels = pu.flatten_list(
825
+ [
826
+ [class_id_dict[label]] * len(targets)
827
+ for label, targets in self.gene_class_dict.items()
828
+ ]
829
+ )
830
+ assert len(targets) == len(labels)
831
+ n_splits = int(1 / (1 - self.train_size))
832
+ skf = cu.StratifiedKFold3(n_splits=n_splits, random_state=0, shuffle=True)
833
+ # (Cross-)validate
834
+ test_ratio = self.oos_test_size / (self.eval_size + self.oos_test_size)
835
+ for train_index, eval_index, test_index in tqdm(
836
+ skf.split(targets, labels, test_ratio)
837
+ ):
838
+ print(
839
+ f"****** Validation split: {iteration_num}/{self.num_crossval_splits} ******\n"
840
+ )
841
+ ksplit_output_dir = os.path.join(output_dir, f"ksplit{iteration_num}")
842
+ # filter data for examples containing classes for this split
843
+ # subsample to max_ncells and relabel data in column "labels"
844
+ train_data, eval_data = cu.prep_gene_classifier_train_eval_split(
845
+ data,
846
+ targets,
847
+ labels,
848
+ train_index,
849
+ eval_index,
850
+ self.max_ncells,
851
+ iteration_num,
852
+ self.nproc,
853
+ gene_balance,
854
+ )
855
+
856
+ if save_gene_split_datasets is True:
857
+ for split_name in ["train", "valid"]:
858
+ labeled_dataset_output_path = (
859
+ Path(output_dir)
860
+ / f"{output_prefix}_{split_name}_gene_labeled_ksplit{iteration_num}"
861
+ ).with_suffix(".dataset")
862
+ if split_name == "train":
863
+ train_data.save_to_disk(str(labeled_dataset_output_path))
864
+ elif split_name == "valid":
865
+ eval_data.save_to_disk(str(labeled_dataset_output_path))
866
+
867
+ if self.oos_test_size > 0:
868
+ test_data = cu.prep_gene_classifier_split(
869
+ data,
870
+ targets,
871
+ labels,
872
+ test_index,
873
+ "test",
874
+ self.max_ncells,
875
+ iteration_num,
876
+ self.nproc,
877
+ )
878
+ if save_gene_split_datasets is True:
879
+ test_labeled_dataset_output_path = (
880
+ Path(output_dir)
881
+ / f"{output_prefix}_test_gene_labeled_ksplit{iteration_num}"
882
+ ).with_suffix(".dataset")
883
+ test_data.save_to_disk(str(test_labeled_dataset_output_path))
884
+ if debug_gene_split_datasets is True:
885
+ logger.error(
886
+ "Exiting after saving gene split datasets given debug_gene_split_datasets = True."
887
+ )
888
+ raise
889
+ if n_hyperopt_trials == 0:
890
+ trainer = self.train_classifier(
891
+ model_directory,
892
+ num_classes,
893
+ train_data,
894
+ eval_data,
895
+ ksplit_output_dir,
896
+ predict_trainer,
897
+ )
898
+ result = self.evaluate_model(
899
+ trainer.model,
900
+ num_classes,
901
+ id_class_dict,
902
+ eval_data,
903
+ predict_eval,
904
+ ksplit_output_dir,
905
+ output_prefix,
906
+ )
907
+ else:
908
+ trainer = self.hyperopt_classifier(
909
+ model_directory,
910
+ num_classes,
911
+ train_data,
912
+ eval_data,
913
+ ksplit_output_dir,
914
+ n_trials=n_hyperopt_trials,
915
+ )
916
+
917
+ model = cu.load_best_model(
918
+ ksplit_output_dir, self.model_type, num_classes
919
+ )
920
+
921
+ if self.oos_test_size > 0:
922
+ result = self.evaluate_model(
923
+ model,
924
+ num_classes,
925
+ id_class_dict,
926
+ test_data,
927
+ predict_eval,
928
+ ksplit_output_dir,
929
+ output_prefix,
930
+ )
931
+ else:
932
+ if iteration_num == self.num_crossval_splits:
933
+ return
934
+ else:
935
+ iteration_num = iteration_num + 1
936
+ continue
937
+ results += [result]
938
+ all_conf_mat = all_conf_mat + result["conf_mat"]
939
+ # break after 1 or 5 splits, each with train/eval proportions dictated by eval_size
940
+ if iteration_num == self.num_crossval_splits:
941
+ break
942
+ iteration_num = iteration_num + 1
943
+
944
+ all_conf_mat_df = pd.DataFrame(
945
+ all_conf_mat, columns=id_class_dict.values(), index=id_class_dict.values()
946
+ )
947
+ all_metrics = {
948
+ "conf_matrix": all_conf_mat_df,
949
+ "macro_f1": [result["macro_f1"] for result in results],
950
+ "acc": [result["acc"] for result in results],
951
+ }
952
+ all_roc_metrics = None # roc metrics not reported for multiclass
953
+ if num_classes == 2:
954
+ mean_fpr = np.linspace(0, 1, 100)
955
+ all_tpr = [result["roc_metrics"]["interp_tpr"] for result in results]
956
+ all_roc_auc = [result["roc_metrics"]["auc"] for result in results]
957
+ all_tpr_wt = [result["roc_metrics"]["tpr_wt"] for result in results]
958
+ mean_tpr, roc_auc, roc_auc_sd = eu.get_cross_valid_roc_metrics(
959
+ all_tpr, all_roc_auc, all_tpr_wt
960
+ )
961
+ all_roc_metrics = {
962
+ "mean_tpr": mean_tpr,
963
+ "mean_fpr": mean_fpr,
964
+ "all_roc_auc": all_roc_auc,
965
+ "roc_auc": roc_auc,
966
+ "roc_auc_sd": roc_auc_sd,
967
+ }
968
+ all_metrics["all_roc_metrics"] = all_roc_metrics
969
+ if save_eval_output is True:
970
+ eval_metrics_output_path = (
971
+ Path(output_dir) / f"{output_prefix}_eval_metrics_dict"
972
+ ).with_suffix(".pkl")
973
+ with open(eval_metrics_output_path, "wb") as f:
974
+ pickle.dump(all_metrics, f)
975
+
976
+ return all_metrics
977
+
978
+ def hyperopt_classifier(
979
+ self,
980
+ model_directory,
981
+ num_classes,
982
+ train_data,
983
+ eval_data,
984
+ output_directory,
985
+ n_trials=100,
986
+ ):
987
+ """
988
+ Fine-tune model for cell state or gene classification.
989
+
990
+ **Parameters**
991
+
992
+ model_directory : Path
993
+ | Path to directory containing model
994
+ num_classes : int
995
+ | Number of classes for classifier
996
+ train_data : Dataset
997
+ | Loaded training .dataset input
998
+ | For cell classifier, labels in column "label".
999
+ | For gene classifier, labels in column "labels".
1000
+ eval_data : None, Dataset
1001
+ | (Optional) Loaded evaluation .dataset input
1002
+ | For cell classifier, labels in column "label".
1003
+ | For gene classifier, labels in column "labels".
1004
+ output_directory : Path
1005
+ | Path to directory where fine-tuned model will be saved
1006
+ n_trials : int
1007
+ | Number of trials to run for hyperparameter optimization
1008
+ """
1009
+
1010
+ # initiate runtime environment for raytune
1011
+ import ray
1012
+ from ray import tune
1013
+ from ray.tune.search.hyperopt import HyperOptSearch
1014
+
1015
+ ray.shutdown() # engage new ray session
1016
+ ray.init()
1017
+
1018
+ ##### Validate and prepare data #####
1019
+ train_data, eval_data = cu.validate_and_clean_cols(
1020
+ train_data, eval_data, self.classifier
1021
+ )
1022
+
1023
+ if (self.no_eval is True) and (eval_data is not None):
1024
+ logger.warning(
1025
+ "no_eval set to True; hyperparameter optimization requires eval, proceeding with eval"
1026
+ )
1027
+
1028
+ # ensure not overwriting previously saved model
1029
+ saved_model_test = os.path.join(output_directory, "pytorch_model.bin")
1030
+ if os.path.isfile(saved_model_test) is True:
1031
+ logger.error("Model already saved to this designated output directory.")
1032
+ raise
1033
+ # make output directory
1034
+ subprocess.call(f"mkdir {output_directory}", shell=True)
1035
+
1036
+ ##### Load model and training args #####
1037
+ model = pu.load_model(
1038
+ self.model_type,
1039
+ num_classes,
1040
+ model_directory,
1041
+ "train",
1042
+ quantize=self.quantize,
1043
+ )
1044
+ def_training_args, def_freeze_layers = cu.get_default_train_args(
1045
+ model, self.classifier, train_data, output_directory
1046
+ )
1047
+ del model
1048
+
1049
+ if self.training_args is not None:
1050
+ def_training_args.update(self.training_args)
1051
+ logging_steps = round(
1052
+ len(train_data) / def_training_args["per_device_train_batch_size"] / 10
1053
+ )
1054
+ def_training_args["logging_steps"] = logging_steps
1055
+ def_training_args["output_dir"] = output_directory
1056
+ if eval_data is None:
1057
+ def_training_args["evaluation_strategy"] = "no"
1058
+ def_training_args["load_best_model_at_end"] = False
1059
+ def_training_args.update(
1060
+ {"save_strategy": "epoch", "save_total_limit": 1}
1061
+ ) # only save last model for each run
1062
+ training_args_init = TrainingArguments(**def_training_args)
1063
+
1064
+ ##### Fine-tune the model #####
1065
+ # define the data collator
1066
+ if self.classifier == "cell":
1067
+ data_collator = DataCollatorForCellClassification(
1068
+ token_dictionary=self.gene_token_dict
1069
+ )
1070
+ elif self.classifier == "gene":
1071
+ data_collator = DataCollatorForGeneClassification(
1072
+ token_dictionary=self.gene_token_dict
1073
+ )
1074
+
1075
+ # define function to initiate model
1076
+ def model_init():
1077
+ model = pu.load_model(
1078
+ self.model_type,
1079
+ num_classes,
1080
+ model_directory,
1081
+ "train",
1082
+ quantize=self.quantize,
1083
+ )
1084
+
1085
+ if self.freeze_layers is not None:
1086
+ def_freeze_layers = self.freeze_layers
1087
+
1088
+ if def_freeze_layers > 0:
1089
+ modules_to_freeze = model.bert.encoder.layer[:def_freeze_layers]
1090
+ for module in modules_to_freeze:
1091
+ for param in module.parameters():
1092
+ param.requires_grad = False
1093
+
1094
+ if self.quantize is False:
1095
+ model = model.to("cuda:0")
1096
+ return model
1097
+
1098
+ # create the trainer
1099
+ trainer = Trainer(
1100
+ model_init=model_init,
1101
+ args=training_args_init,
1102
+ data_collator=data_collator,
1103
+ train_dataset=train_data,
1104
+ eval_dataset=eval_data,
1105
+ compute_metrics=cu.compute_metrics,
1106
+ )
1107
+
1108
+ # specify raytune hyperparameter search space
1109
+ if self.ray_config is None:
1110
+ logger.warning(
1111
+ "No ray_config provided. Proceeding with default, but ranges may need adjustment depending on model."
1112
+ )
1113
+ def_ray_config = {
1114
+ "num_train_epochs": tune.choice([1]),
1115
+ "learning_rate": tune.loguniform(1e-6, 1e-3),
1116
+ "weight_decay": tune.uniform(0.0, 0.3),
1117
+ "lr_scheduler_type": tune.choice(["linear", "cosine", "polynomial"]),
1118
+ "warmup_steps": tune.uniform(100, 2000),
1119
+ "seed": tune.uniform(0, 100),
1120
+ "per_device_train_batch_size": tune.choice(
1121
+ [def_training_args["per_device_train_batch_size"]]
1122
+ ),
1123
+ }
1124
+
1125
+ hyperopt_search = HyperOptSearch(metric="eval_macro_f1", mode="max")
1126
+
1127
+ # optimize hyperparameters
1128
+ trainer.hyperparameter_search(
1129
+ direction="maximize",
1130
+ backend="ray",
1131
+ resources_per_trial={"cpu": int(self.nproc / self.ngpu), "gpu": 1},
1132
+ hp_space=lambda _: def_ray_config
1133
+ if self.ray_config is None
1134
+ else self.ray_config,
1135
+ search_alg=hyperopt_search,
1136
+ n_trials=n_trials, # number of trials
1137
+ progress_reporter=tune.CLIReporter(
1138
+ max_report_frequency=600,
1139
+ sort_by_metric=True,
1140
+ max_progress_rows=n_trials,
1141
+ mode="max",
1142
+ metric="eval_macro_f1",
1143
+ metric_columns=["loss", "eval_loss", "eval_accuracy", "eval_macro_f1"],
1144
+ ),
1145
+ storage_path=output_directory,
1146
+ )
1147
+
1148
+ return trainer
1149
+
1150
+ def train_classifier(
1151
+ self,
1152
+ model_directory,
1153
+ num_classes,
1154
+ train_data,
1155
+ eval_data,
1156
+ output_directory,
1157
+ predict=False,
1158
+ ):
1159
+ """
1160
+ Fine-tune model for cell state or gene classification.
1161
+
1162
+ **Parameters**
1163
+
1164
+ model_directory : Path
1165
+ | Path to directory containing model
1166
+ num_classes : int
1167
+ | Number of classes for classifier
1168
+ train_data : Dataset
1169
+ | Loaded training .dataset input
1170
+ | For cell classifier, labels in column "label".
1171
+ | For gene classifier, labels in column "labels".
1172
+ eval_data : None, Dataset
1173
+ | (Optional) Loaded evaluation .dataset input
1174
+ | For cell classifier, labels in column "label".
1175
+ | For gene classifier, labels in column "labels".
1176
+ output_directory : Path
1177
+ | Path to directory where fine-tuned model will be saved
1178
+ predict : bool
1179
+ | Whether or not to save eval predictions from trainer
1180
+ """
1181
+
1182
+ ##### Validate and prepare data #####
1183
+ train_data, eval_data = cu.validate_and_clean_cols(
1184
+ train_data, eval_data, self.classifier
1185
+ )
1186
+
1187
+ if (self.no_eval is True) and (eval_data is not None):
1188
+ logger.warning(
1189
+ "no_eval set to True; model will be trained without evaluation."
1190
+ )
1191
+ eval_data = None
1192
+
1193
+ if (self.classifier == "gene") and (predict is True):
1194
+ logger.warning(
1195
+ "Predictions during training not currently available for gene classifiers; setting predict to False."
1196
+ )
1197
+ predict = False
1198
+
1199
+ # ensure not overwriting previously saved model
1200
+ saved_model_test = os.path.join(output_directory, "pytorch_model.bin")
1201
+ if os.path.isfile(saved_model_test) is True:
1202
+ logger.error("Model already saved to this designated output directory.")
1203
+ raise
1204
+ # make output directory
1205
+ subprocess.call(f"mkdir {output_directory}", shell=True)
1206
+
1207
+ ##### Load model and training args #####
1208
+ model = pu.load_model(
1209
+ self.model_type,
1210
+ num_classes,
1211
+ model_directory,
1212
+ "train",
1213
+ quantize=self.quantize,
1214
+ )
1215
+
1216
+ def_training_args, def_freeze_layers = cu.get_default_train_args(
1217
+ model, self.classifier, train_data, output_directory
1218
+ )
1219
+
1220
+ if self.training_args is not None:
1221
+ def_training_args.update(self.training_args)
1222
+ logging_steps = round(
1223
+ len(train_data) / def_training_args["per_device_train_batch_size"] / 10
1224
+ )
1225
+ def_training_args["logging_steps"] = logging_steps
1226
+ def_training_args["output_dir"] = output_directory
1227
+ if eval_data is None:
1228
+ def_training_args["evaluation_strategy"] = "no"
1229
+ def_training_args["load_best_model_at_end"] = False
1230
+ training_args_init = TrainingArguments(**def_training_args)
1231
+
1232
+ if self.freeze_layers is not None:
1233
+ def_freeze_layers = self.freeze_layers
1234
+
1235
+ if def_freeze_layers > 0:
1236
+ modules_to_freeze = model.bert.encoder.layer[:def_freeze_layers]
1237
+ for module in modules_to_freeze:
1238
+ for param in module.parameters():
1239
+ param.requires_grad = False
1240
+
1241
+ ##### Fine-tune the model #####
1242
+ # define the data collator
1243
+ if self.classifier == "cell":
1244
+ data_collator = DataCollatorForCellClassification(
1245
+ token_dictionary=self.gene_token_dict
1246
+ )
1247
+ elif self.classifier == "gene":
1248
+ data_collator = DataCollatorForGeneClassification(
1249
+ token_dictionary=self.gene_token_dict
1250
+ )
1251
+
1252
+ # create the trainer
1253
+ trainer = Trainer(
1254
+ model=model,
1255
+ args=training_args_init,
1256
+ data_collator=data_collator,
1257
+ train_dataset=train_data,
1258
+ eval_dataset=eval_data,
1259
+ compute_metrics=cu.compute_metrics,
1260
+ )
1261
+
1262
+ # train the classifier
1263
+ trainer.train()
1264
+ trainer.save_model(output_directory)
1265
+ if predict is True:
1266
+ # make eval predictions and save predictions and metrics
1267
+ predictions = trainer.predict(eval_data)
1268
+ prediction_output_path = f"{output_directory}/predictions.pkl"
1269
+ with open(prediction_output_path, "wb") as f:
1270
+ pickle.dump(predictions, f)
1271
+ trainer.save_metrics("eval", predictions.metrics)
1272
+ return trainer
1273
+
1274
+ def evaluate_model(
1275
+ self,
1276
+ model,
1277
+ num_classes,
1278
+ id_class_dict,
1279
+ eval_data,
1280
+ predict=False,
1281
+ output_directory=None,
1282
+ output_prefix=None,
1283
+ ):
1284
+ """
1285
+ Evaluate the fine-tuned model.
1286
+
1287
+ **Parameters**
1288
+
1289
+ model : nn.Module
1290
+ | Loaded fine-tuned model (e.g. trainer.model)
1291
+ num_classes : int
1292
+ | Number of classes for classifier
1293
+ id_class_dict : dict
1294
+ | Loaded _id_class_dict.pkl previously prepared by Classifier.prepare_data
1295
+ | (dictionary of format: numerical IDs: class_labels)
1296
+ eval_data : Dataset
1297
+ | Loaded evaluation .dataset input
1298
+ predict : bool
1299
+ | Whether or not to save eval predictions
1300
+ output_directory : Path
1301
+ | Path to directory where eval data will be saved
1302
+ output_prefix : str
1303
+ | Prefix for output files
1304
+ """
1305
+
1306
+ ##### Evaluate the model #####
1307
+ labels = id_class_dict.keys()
1308
+ y_pred, y_true, logits_list = eu.classifier_predict(
1309
+ model, self.classifier, eval_data, self.forward_batch_size
1310
+ )
1311
+ conf_mat, macro_f1, acc, roc_metrics = eu.get_metrics(
1312
+ y_pred, y_true, logits_list, num_classes, labels
1313
+ )
1314
+ if predict is True:
1315
+ pred_dict = {
1316
+ "pred_ids": y_pred,
1317
+ "label_ids": y_true,
1318
+ "predictions": logits_list,
1319
+ }
1320
+ pred_dict_output_path = (
1321
+ Path(output_directory) / f"{output_prefix}_pred_dict"
1322
+ ).with_suffix(".pkl")
1323
+ with open(pred_dict_output_path, "wb") as f:
1324
+ pickle.dump(pred_dict, f)
1325
+ return {
1326
+ "conf_mat": conf_mat,
1327
+ "macro_f1": macro_f1,
1328
+ "acc": acc,
1329
+ "roc_metrics": roc_metrics,
1330
+ }
1331
+
1332
+ def evaluate_saved_model(
1333
+ self,
1334
+ model_directory,
1335
+ id_class_dict_file,
1336
+ test_data_file,
1337
+ output_directory,
1338
+ output_prefix,
1339
+ predict=True,
1340
+ ):
1341
+ """
1342
+ Evaluate the fine-tuned model.
1343
+
1344
+ **Parameters**
1345
+
1346
+ model_directory : Path
1347
+ | Path to directory containing model
1348
+ id_class_dict_file : Path
1349
+ | Path to _id_class_dict.pkl previously prepared by Classifier.prepare_data
1350
+ | (dictionary of format: numerical IDs: class_labels)
1351
+ test_data_file : Path
1352
+ | Path to directory containing test .dataset
1353
+ output_directory : Path
1354
+ | Path to directory where eval data will be saved
1355
+ output_prefix : str
1356
+ | Prefix for output files
1357
+ predict : bool
1358
+ | Whether or not to save eval predictions
1359
+ """
1360
+
1361
+ # load numerical id to class dictionary (id:class)
1362
+ with open(id_class_dict_file, "rb") as f:
1363
+ id_class_dict = pickle.load(f)
1364
+
1365
+ # get number of classes for classifier
1366
+ num_classes = cu.get_num_classes(id_class_dict)
1367
+
1368
+ # load previously filtered and prepared data
1369
+ test_data = pu.load_and_filter(None, self.nproc, test_data_file)
1370
+
1371
+ # load previously fine-tuned model
1372
+ model = pu.load_model(
1373
+ self.model_type,
1374
+ num_classes,
1375
+ model_directory,
1376
+ "eval",
1377
+ quantize=self.quantize,
1378
+ )
1379
+
1380
+ # evaluate the model
1381
+ result = self.evaluate_model(
1382
+ model,
1383
+ num_classes,
1384
+ id_class_dict,
1385
+ test_data,
1386
+ predict=predict,
1387
+ output_directory=output_directory,
1388
+ output_prefix=output_prefix,
1389
+ )
1390
+
1391
+ all_conf_mat_df = pd.DataFrame(
1392
+ result["conf_mat"],
1393
+ columns=id_class_dict.values(),
1394
+ index=id_class_dict.values(),
1395
+ )
1396
+ all_metrics = {
1397
+ "conf_matrix": all_conf_mat_df,
1398
+ "macro_f1": result["macro_f1"],
1399
+ "acc": result["acc"],
1400
+ }
1401
+ all_roc_metrics = None # roc metrics not reported for multiclass
1402
+
1403
+ if num_classes == 2:
1404
+ mean_fpr = np.linspace(0, 1, 100)
1405
+ mean_tpr = result["roc_metrics"]["interp_tpr"]
1406
+ all_roc_auc = result["roc_metrics"]["auc"]
1407
+ all_roc_metrics = {
1408
+ "mean_tpr": mean_tpr,
1409
+ "mean_fpr": mean_fpr,
1410
+ "all_roc_auc": all_roc_auc,
1411
+ }
1412
+ all_metrics["all_roc_metrics"] = all_roc_metrics
1413
+ test_metrics_output_path = (
1414
+ Path(output_directory) / f"{output_prefix}_test_metrics_dict"
1415
+ ).with_suffix(".pkl")
1416
+ with open(test_metrics_output_path, "wb") as f:
1417
+ pickle.dump(all_metrics, f)
1418
+
1419
+ return all_metrics
1420
+
1421
+ def plot_conf_mat(
1422
+ self,
1423
+ conf_mat_dict,
1424
+ output_directory,
1425
+ output_prefix,
1426
+ custom_class_order=None,
1427
+ ):
1428
+ """
1429
+ Plot confusion matrix results of evaluating the fine-tuned model.
1430
+
1431
+ **Parameters**
1432
+
1433
+ conf_mat_dict : dict
1434
+ | Dictionary of model_name : confusion_matrix_DataFrame
1435
+ | (all_metrics["conf_matrix"] from self.validate)
1436
+ output_directory : Path
1437
+ | Path to directory where plots will be saved
1438
+ output_prefix : str
1439
+ | Prefix for output file
1440
+ custom_class_order : None, list
1441
+ | List of classes in custom order for plots.
1442
+ | Same order will be used for all models.
1443
+ """
1444
+
1445
+ for model_name in conf_mat_dict.keys():
1446
+ eu.plot_confusion_matrix(
1447
+ conf_mat_dict[model_name],
1448
+ model_name,
1449
+ output_directory,
1450
+ output_prefix,
1451
+ custom_class_order,
1452
+ )
1453
+
1454
+ def plot_roc(
1455
+ self,
1456
+ roc_metric_dict,
1457
+ model_style_dict,
1458
+ title,
1459
+ output_directory,
1460
+ output_prefix,
1461
+ ):
1462
+ """
1463
+ Plot ROC curve results of evaluating the fine-tuned model.
1464
+
1465
+ **Parameters**
1466
+
1467
+ roc_metric_dict : dict
1468
+ | Dictionary of model_name : roc_metrics
1469
+ | (all_metrics["all_roc_metrics"] from self.validate)
1470
+ model_style_dict : dict[dict]
1471
+ | Dictionary of model_name : dictionary of style_attribute : style
1472
+ | where style includes color and linestyle
1473
+ | e.g. {'Model_A': {'color': 'black', 'linestyle': '-'}, 'Model_B': ...}
1474
+ title : str
1475
+ | Title of plot (e.g. 'Dosage-sensitive vs -insensitive factors')
1476
+ output_directory : Path
1477
+ | Path to directory where plots will be saved
1478
+ output_prefix : str
1479
+ | Prefix for output file
1480
+ """
1481
+
1482
+ eu.plot_ROC(
1483
+ roc_metric_dict, model_style_dict, title, output_directory, output_prefix
1484
+ )
1485
+
1486
+ def plot_predictions(
1487
+ self,
1488
+ predictions_file,
1489
+ id_class_dict_file,
1490
+ title,
1491
+ output_directory,
1492
+ output_prefix,
1493
+ custom_class_order=None,
1494
+ kwargs_dict=None,
1495
+ ):
1496
+ """
1497
+ Plot prediction results of evaluating the fine-tuned model.
1498
+
1499
+ **Parameters**
1500
+
1501
+ predictions_file : path
1502
+ | Path of model predictions output to plot
1503
+ | (saved output from self.validate if predict_eval=True)
1504
+ | (or saved output from self.evaluate_saved_model)
1505
+ id_class_dict_file : Path
1506
+ | Path to _id_class_dict.pkl previously prepared by Classifier.prepare_data
1507
+ | (dictionary of format: numerical IDs: class_labels)
1508
+ title : str
1509
+ | Title for legend containing class labels.
1510
+ output_directory : Path
1511
+ | Path to directory where plots will be saved
1512
+ output_prefix : str
1513
+ | Prefix for output file
1514
+ custom_class_order : None, list
1515
+ | List of classes in custom order for plots.
1516
+ | Same order will be used for all models.
1517
+ kwargs_dict : None, dict
1518
+ | Dictionary of kwargs to pass to plotting function.
1519
+ """
1520
+ # load predictions
1521
+ with open(predictions_file, "rb") as f:
1522
+ predictions = pickle.load(f)
1523
+
1524
+ # load numerical id to class dictionary (id:class)
1525
+ with open(id_class_dict_file, "rb") as f:
1526
+ id_class_dict = pickle.load(f)
1527
+
1528
+ if isinstance(predictions, dict):
1529
+ if all(
1530
+ [
1531
+ key in predictions.keys()
1532
+ for key in ["pred_ids", "label_ids", "predictions"]
1533
+ ]
1534
+ ):
1535
+ # format is output from self.evaluate_saved_model
1536
+ predictions_logits = np.array(predictions["predictions"])
1537
+ true_ids = predictions["label_ids"]
1538
+ else:
1539
+ # format is output from self.validate if predict_eval=True
1540
+ predictions_logits = predictions.predictions
1541
+ true_ids = predictions.label_ids
1542
+
1543
+ num_classes = len(id_class_dict.keys())
1544
+ num_predict_classes = predictions_logits.shape[1]
1545
+ assert num_classes == num_predict_classes
1546
+ classes = id_class_dict.values()
1547
+ true_labels = [id_class_dict[idx] for idx in true_ids]
1548
+ predictions_df = pd.DataFrame(predictions_logits, columns=classes)
1549
+ if custom_class_order is not None:
1550
+ predictions_df = predictions_df.reindex(columns=custom_class_order)
1551
+ predictions_df["true"] = true_labels
1552
+ custom_dict = dict(zip(classes, [i for i in range(len(classes))]))
1553
+ if custom_class_order is not None:
1554
+ custom_dict = dict(
1555
+ zip(custom_class_order, [i for i in range(len(custom_class_order))])
1556
+ )
1557
+ predictions_df = predictions_df.sort_values(
1558
+ by=["true"], key=lambda x: x.map(custom_dict)
1559
+ )
1560
+
1561
+ eu.plot_predictions(
1562
+ predictions_df, title, output_directory, output_prefix, kwargs_dict
1563
+ )
geneformer/classifier_utils.py ADDED
@@ -0,0 +1,648 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import json
2
+ import logging
3
+ import os
4
+ import random
5
+ from collections import Counter, defaultdict
6
+
7
+ import numpy as np
8
+ import pandas as pd
9
+ from scipy.stats import chisquare, ranksums
10
+ from sklearn.metrics import accuracy_score, f1_score
11
+ from sklearn.model_selection import StratifiedKFold, train_test_split
12
+
13
+ from . import perturber_utils as pu
14
+
15
+ logger = logging.getLogger(__name__)
16
+
17
+
18
+ def downsample_and_shuffle(data, max_ncells, max_ncells_per_class, cell_state_dict):
19
+ data = data.shuffle(seed=42)
20
+ num_cells = len(data)
21
+ # if max number of cells is defined, then subsample to this max number
22
+ if max_ncells is not None:
23
+ if num_cells > max_ncells:
24
+ data = data.select([i for i in range(max_ncells)])
25
+ if max_ncells_per_class is not None:
26
+ class_labels = data[cell_state_dict["state_key"]]
27
+ random.seed(42)
28
+ subsample_indices = subsample_by_class(class_labels, max_ncells_per_class)
29
+ data = data.select(subsample_indices)
30
+ return data
31
+
32
+
33
+ # subsample labels to maximum number N per class and return indices
34
+ def subsample_by_class(labels, N):
35
+ label_indices = defaultdict(list)
36
+ # Gather indices for each label
37
+ for idx, label in enumerate(labels):
38
+ label_indices[label].append(idx)
39
+ selected_indices = []
40
+ # Select up to N indices for each label
41
+ for label, indices in label_indices.items():
42
+ if len(indices) > N:
43
+ selected_indices.extend(random.sample(indices, N))
44
+ else:
45
+ selected_indices.extend(indices)
46
+ return selected_indices
47
+
48
+
49
+ def rename_cols(data, state_key):
50
+ data = data.rename_column(state_key, "label")
51
+ return data
52
+
53
+
54
+ def validate_and_clean_cols(train_data, eval_data, classifier):
55
+ # validate that data has expected label column and remove others
56
+ if classifier == "cell":
57
+ label_col = "label"
58
+ elif classifier == "gene":
59
+ label_col = "labels"
60
+
61
+ cols_to_keep = [label_col] + ["input_ids", "length"]
62
+ if label_col not in train_data.column_names:
63
+ logger.error(f"train_data must contain column {label_col} with class labels.")
64
+ raise
65
+ else:
66
+ train_data = remove_cols(train_data, cols_to_keep)
67
+
68
+ if eval_data is not None:
69
+ if label_col not in eval_data.column_names:
70
+ logger.error(
71
+ f"eval_data must contain column {label_col} with class labels."
72
+ )
73
+ raise
74
+ else:
75
+ eval_data = remove_cols(eval_data, cols_to_keep)
76
+ return train_data, eval_data
77
+
78
+
79
+ def remove_cols(data, cols_to_keep):
80
+ other_cols = list(data.features.keys())
81
+ other_cols = [ele for ele in other_cols if ele not in cols_to_keep]
82
+ data = data.remove_columns(other_cols)
83
+ return data
84
+
85
+
86
+ def remove_rare(data, rare_threshold, label, nproc):
87
+ if rare_threshold > 0:
88
+ total_cells = len(data)
89
+ label_counter = Counter(data[label])
90
+ nonrare_label_dict = {
91
+ label: [k for k, v in label_counter if (v / total_cells) > rare_threshold]
92
+ }
93
+ data = pu.filter_by_dict(data, nonrare_label_dict, nproc)
94
+ return data
95
+
96
+
97
+ def label_classes(classifier, data, gene_class_dict, nproc):
98
+ if classifier == "cell":
99
+ label_set = set(data["label"])
100
+ elif classifier == "gene":
101
+ # remove cells without any of the target genes
102
+ def if_contains_label(example):
103
+ a = pu.flatten_list(gene_class_dict.values())
104
+ b = example["input_ids"]
105
+ return not set(a).isdisjoint(b)
106
+
107
+ data = data.filter(if_contains_label, num_proc=nproc)
108
+ label_set = gene_class_dict.keys()
109
+
110
+ if len(data) == 0:
111
+ logger.error(
112
+ "No cells remain after filtering for target genes. Check target gene list."
113
+ )
114
+ raise
115
+
116
+ class_id_dict = dict(zip(label_set, [i for i in range(len(label_set))]))
117
+ id_class_dict = {v: k for k, v in class_id_dict.items()}
118
+
119
+ def classes_to_ids(example):
120
+ if classifier == "cell":
121
+ example["label"] = class_id_dict[example["label"]]
122
+ elif classifier == "gene":
123
+ example["labels"] = label_gene_classes(
124
+ example, class_id_dict, gene_class_dict
125
+ )
126
+ return example
127
+
128
+ data = data.map(classes_to_ids, num_proc=nproc)
129
+ return data, id_class_dict
130
+
131
+
132
+ def label_gene_classes(example, class_id_dict, gene_class_dict):
133
+ return [
134
+ class_id_dict.get(gene_class_dict.get(token_id, -100), -100)
135
+ for token_id in example["input_ids"]
136
+ ]
137
+
138
+
139
+ def prep_gene_classifier_train_eval_split(
140
+ data,
141
+ targets,
142
+ labels,
143
+ train_index,
144
+ eval_index,
145
+ max_ncells,
146
+ iteration_num,
147
+ num_proc,
148
+ balance=False,
149
+ ):
150
+ # generate cross-validation splits
151
+ train_data = prep_gene_classifier_split(
152
+ data,
153
+ targets,
154
+ labels,
155
+ train_index,
156
+ "train",
157
+ max_ncells,
158
+ iteration_num,
159
+ num_proc,
160
+ balance,
161
+ )
162
+ eval_data = prep_gene_classifier_split(
163
+ data,
164
+ targets,
165
+ labels,
166
+ eval_index,
167
+ "eval",
168
+ max_ncells,
169
+ iteration_num,
170
+ num_proc,
171
+ balance,
172
+ )
173
+ return train_data, eval_data
174
+
175
+
176
+ def prep_gene_classifier_split(
177
+ data,
178
+ targets,
179
+ labels,
180
+ index,
181
+ subset_name,
182
+ max_ncells,
183
+ iteration_num,
184
+ num_proc,
185
+ balance=False,
186
+ ):
187
+ # generate cross-validation splits
188
+ targets = np.array(targets)
189
+ labels = np.array(labels)
190
+ targets_subset = targets[index]
191
+ labels_subset = labels[index]
192
+ label_dict_subset = dict(zip(targets_subset, labels_subset))
193
+
194
+ # function to filter by whether contains train or eval labels
195
+ def if_contains_subset_label(example):
196
+ a = targets_subset
197
+ b = example["input_ids"]
198
+ return not set(a).isdisjoint(b)
199
+
200
+ # filter dataset for examples containing classes for this split
201
+ logger.info(f"Filtering data for {subset_name} genes in split {iteration_num}")
202
+ subset_data = data.filter(if_contains_subset_label, num_proc=num_proc)
203
+ logger.info(
204
+ f"Filtered {round((1-len(subset_data)/len(data))*100)}%; {len(subset_data)} remain\n"
205
+ )
206
+
207
+ # balance gene subsets if train
208
+ if (subset_name == "train") and (balance is True):
209
+ subset_data, label_dict_subset = balance_gene_split(
210
+ subset_data, label_dict_subset, num_proc
211
+ )
212
+
213
+ # subsample to max_ncells
214
+ subset_data = downsample_and_shuffle(subset_data, max_ncells, None, None)
215
+
216
+ # relabel genes for this split
217
+ def subset_classes_to_ids(example):
218
+ example["labels"] = [
219
+ label_dict_subset.get(token_id, -100) for token_id in example["input_ids"]
220
+ ]
221
+ return example
222
+
223
+ subset_data = subset_data.map(subset_classes_to_ids, num_proc=num_proc)
224
+
225
+ return subset_data
226
+
227
+
228
+ def prep_gene_classifier_all_data(
229
+ data, targets, labels, max_ncells, num_proc, balance=False
230
+ ):
231
+ targets = np.array(targets)
232
+ labels = np.array(labels)
233
+ label_dict_train = dict(zip(targets, labels))
234
+
235
+ # function to filter by whether contains train labels
236
+ def if_contains_train_label(example):
237
+ a = targets
238
+ b = example["input_ids"]
239
+ return not set(a).isdisjoint(b)
240
+
241
+ # filter dataset for examples containing classes for this split
242
+ logger.info("Filtering training data for genes to classify.")
243
+ train_data = data.filter(if_contains_train_label, num_proc=num_proc)
244
+ logger.info(
245
+ f"Filtered {round((1-len(train_data)/len(data))*100)}%; {len(train_data)} remain\n"
246
+ )
247
+
248
+ if balance is True:
249
+ train_data, label_dict_train = balance_gene_split(
250
+ train_data, label_dict_train, num_proc
251
+ )
252
+
253
+ # subsample to max_ncells
254
+ train_data = downsample_and_shuffle(train_data, max_ncells, None, None)
255
+
256
+ # relabel genes for this split
257
+ def train_classes_to_ids(example):
258
+ example["labels"] = [
259
+ label_dict_train.get(token_id, -100) for token_id in example["input_ids"]
260
+ ]
261
+ return example
262
+
263
+ train_data = train_data.map(train_classes_to_ids, num_proc=num_proc)
264
+
265
+ return train_data
266
+
267
+
268
+ def balance_gene_split(subset_data, label_dict_subset, num_proc):
269
+ # count occurrence of genes in each label category
270
+ label0_counts, label1_counts = count_genes_for_balancing(
271
+ subset_data, label_dict_subset, num_proc
272
+ )
273
+ label_ratio_0to1 = label0_counts / label1_counts
274
+
275
+ if 8 / 10 <= label_ratio_0to1 <= 10 / 8:
276
+ # gene sets already balanced
277
+ logger.info(
278
+ "Gene sets were already balanced within 0.8-1.25 fold and did not require balancing.\n"
279
+ )
280
+ return subset_data, label_dict_subset
281
+ else:
282
+ label_ratio_0to1_orig = label_ratio_0to1 + 0
283
+ label_dict_subset_orig = label_dict_subset.copy()
284
+ # balance gene sets
285
+ max_ntrials = 25
286
+ boost = 1
287
+ if label_ratio_0to1 > 10 / 8:
288
+ # downsample label 0
289
+ for i in range(max_ntrials):
290
+ label0 = 0
291
+ label0_genes = [k for k, v in label_dict_subset.items() if v == label0]
292
+ label0_ngenes = len(label0_genes)
293
+ label0_nremove = max(
294
+ 1,
295
+ int(
296
+ np.floor(
297
+ label0_ngenes - label0_ngenes / (label_ratio_0to1 * boost)
298
+ )
299
+ ),
300
+ )
301
+ random.seed(i)
302
+ label0_remove_genes = random.sample(label0_genes, label0_nremove)
303
+ label_dict_subset_new = {
304
+ k: v
305
+ for k, v in label_dict_subset.items()
306
+ if k not in label0_remove_genes
307
+ }
308
+ label0_counts, label1_counts = count_genes_for_balancing(
309
+ subset_data, label_dict_subset_new, num_proc
310
+ )
311
+ label_ratio_0to1 = label0_counts / label1_counts
312
+ if 8 / 10 <= label_ratio_0to1 <= 10 / 8:
313
+ # if gene sets now balanced, return new filtered data and new label_dict_subset
314
+ return filter_data_balanced_genes(
315
+ subset_data, label_dict_subset_new, num_proc
316
+ )
317
+ elif label_ratio_0to1 > 10 / 8:
318
+ boost = boost * 1.1
319
+ elif label_ratio_0to1 < 8 / 10:
320
+ boost = boost * 0.9
321
+ else:
322
+ # downsample label 1
323
+ for i in range(max_ntrials):
324
+ label1 = 1
325
+ label1_genes = [k for k, v in label_dict_subset.items() if v == label1]
326
+ label1_ngenes = len(label1_genes)
327
+ label1_nremove = max(
328
+ 1,
329
+ int(
330
+ np.floor(
331
+ label1_ngenes
332
+ - label1_ngenes / ((1 / label_ratio_0to1) * boost)
333
+ )
334
+ ),
335
+ )
336
+ random.seed(i)
337
+ label1_remove_genes = random.sample(label1_genes, label1_nremove)
338
+ label_dict_subset_new = {
339
+ k: v
340
+ for k, v in label_dict_subset.items()
341
+ if k not in label1_remove_genes
342
+ }
343
+ label0_counts, label1_counts = count_genes_for_balancing(
344
+ subset_data, label_dict_subset_new, num_proc
345
+ )
346
+ label_ratio_0to1 = label0_counts / label1_counts
347
+ if 8 / 10 <= label_ratio_0to1 <= 10 / 8:
348
+ # if gene sets now balanced, return new filtered data and new label_dict_subset
349
+ return filter_data_balanced_genes(
350
+ subset_data, label_dict_subset_new, num_proc
351
+ )
352
+ elif label_ratio_0to1 < 8 / 10:
353
+ boost = boost * 1.1
354
+ elif label_ratio_0to1 > 10 / 8:
355
+ boost = boost * 0.9
356
+
357
+ assert i + 1 == max_ntrials
358
+ if (label_ratio_0to1 <= label_ratio_0to1_orig < 8 / 10) or (
359
+ 10 / 8 > label_ratio_0to1_orig >= label_ratio_0to1
360
+ ):
361
+ label_ratio_0to1 = label_ratio_0to1_orig
362
+ label_dict_subset_new = label_dict_subset_orig
363
+ logger.warning(
364
+ f"Gene sets were not able to be balanced within 0.8-1.25 fold after {max_ntrials} trials. Imbalance level: {label_ratio_0to1}\n"
365
+ )
366
+ return filter_data_balanced_genes(subset_data, label_dict_subset_new, num_proc)
367
+
368
+
369
+ def count_genes_for_balancing(subset_data, label_dict_subset, num_proc):
370
+ def count_targets(example):
371
+ labels = [
372
+ label_dict_subset.get(token_id, -100) for token_id in example["input_ids"]
373
+ ]
374
+ counter_labels = Counter(labels)
375
+ # get count of labels 0 or 1, or if absent, return 0
376
+ example["labels_counts"] = [counter_labels.get(0, 0), counter_labels.get(1, 0)]
377
+ return example
378
+
379
+ subset_data = subset_data.map(count_targets, num_proc=num_proc)
380
+
381
+ label0_counts = sum([counts[0] for counts in subset_data["labels_counts"]])
382
+ label1_counts = sum([counts[1] for counts in subset_data["labels_counts"]])
383
+
384
+ subset_data = subset_data.remove_columns("labels_counts")
385
+
386
+ return label0_counts, label1_counts
387
+
388
+
389
+ def filter_data_balanced_genes(subset_data, label_dict_subset, num_proc):
390
+ # function to filter by whether contains labels
391
+ def if_contains_subset_label(example):
392
+ a = list(label_dict_subset.keys())
393
+ b = example["input_ids"]
394
+ return not set(a).isdisjoint(b)
395
+
396
+ # filter dataset for examples containing classes for this split
397
+ logger.info("Filtering data for balanced genes")
398
+ subset_data_len_orig = len(subset_data)
399
+ subset_data = subset_data.filter(if_contains_subset_label, num_proc=num_proc)
400
+ logger.info(
401
+ f"Filtered {round((1-len(subset_data)/subset_data_len_orig)*100)}%; {len(subset_data)} remain\n"
402
+ )
403
+
404
+ return subset_data, label_dict_subset
405
+
406
+
407
+ def balance_attr_splits(
408
+ data,
409
+ attr_to_split,
410
+ attr_to_balance,
411
+ eval_size,
412
+ max_trials,
413
+ pval_threshold,
414
+ state_key,
415
+ nproc,
416
+ ):
417
+ metadata_df = pd.DataFrame({"split_attr_ids": data[attr_to_split]})
418
+ for attr in attr_to_balance:
419
+ if attr == state_key:
420
+ metadata_df[attr] = data["label"]
421
+ else:
422
+ metadata_df[attr] = data[attr]
423
+ metadata_df = metadata_df.drop_duplicates()
424
+
425
+ split_attr_ids = list(metadata_df["split_attr_ids"])
426
+ assert len(split_attr_ids) == len(set(split_attr_ids))
427
+ eval_num = round(len(split_attr_ids) * eval_size)
428
+ colnames = (
429
+ ["trial_num", "train_ids", "eval_ids"]
430
+ + pu.flatten_list(
431
+ [
432
+ [
433
+ f"{attr}_train_mean_or_counts",
434
+ f"{attr}_eval_mean_or_counts",
435
+ f"{attr}_pval",
436
+ ]
437
+ for attr in attr_to_balance
438
+ ]
439
+ )
440
+ + ["mean_pval"]
441
+ )
442
+ balance_df = pd.DataFrame(columns=colnames)
443
+ data_dict = dict()
444
+ trial_num = 1
445
+ for i in range(max_trials):
446
+ if not all(
447
+ count > 1 for count in list(Counter(metadata_df[state_key]).values())
448
+ ):
449
+ logger.error(
450
+ f"Cannot balance by {attr_to_split} while retaining at least 1 occurrence of each {state_key} class in both data splits. "
451
+ )
452
+ raise
453
+ eval_base = []
454
+ for state in set(metadata_df[state_key]):
455
+ eval_base += list(
456
+ metadata_df.loc[
457
+ metadata_df[state_key][metadata_df[state_key].eq(state)]
458
+ .sample(1, random_state=i)
459
+ .index
460
+ ]["split_attr_ids"]
461
+ )
462
+ non_eval_base = [idx for idx in split_attr_ids if idx not in eval_base]
463
+ random.seed(i)
464
+ eval_ids = random.sample(non_eval_base, eval_num - len(eval_base)) + eval_base
465
+ train_ids = [idx for idx in split_attr_ids if idx not in eval_ids]
466
+ df_vals = [trial_num, train_ids, eval_ids]
467
+ pvals = []
468
+ for attr in attr_to_balance:
469
+ train_attr = list(
470
+ metadata_df[metadata_df["split_attr_ids"].isin(train_ids)][attr]
471
+ )
472
+ eval_attr = list(
473
+ metadata_df[metadata_df["split_attr_ids"].isin(eval_ids)][attr]
474
+ )
475
+ if attr == state_key:
476
+ # ensure IDs are interpreted as categorical
477
+ train_attr = [str(item) for item in train_attr]
478
+ eval_attr = [str(item) for item in eval_attr]
479
+ if all(isinstance(item, (int, float)) for item in train_attr + eval_attr):
480
+ train_attr_mean = np.nanmean(train_attr)
481
+ eval_attr_mean = np.nanmean(eval_attr)
482
+ pval = ranksums(train_attr, eval_attr, nan_policy="omit").pvalue
483
+ df_vals += [train_attr_mean, eval_attr_mean, pval]
484
+ elif all(isinstance(item, (str)) for item in train_attr + eval_attr):
485
+ obs_counts = Counter(train_attr)
486
+ exp_counts = Counter(eval_attr)
487
+ all_categ = set(obs_counts.keys()).union(set(exp_counts.keys()))
488
+ obs = [obs_counts[cat] for cat in all_categ]
489
+ exp = [
490
+ exp_counts[cat] * sum(obs) / sum(exp_counts.values())
491
+ for cat in all_categ
492
+ ]
493
+ pval = chisquare(f_obs=obs, f_exp=exp).pvalue
494
+ train_attr_counts = str(obs_counts).strip("Counter(").strip(")")
495
+ eval_attr_counts = str(exp_counts).strip("Counter(").strip(")")
496
+ df_vals += [train_attr_counts, eval_attr_counts, pval]
497
+ else:
498
+ logger.error(
499
+ f"Inconsistent data types in attribute {attr}. "
500
+ "Cannot infer if continuous or categorical. "
501
+ "Must be all numeric (continuous) or all strings (categorical) to balance."
502
+ )
503
+ raise
504
+ pvals += [pval]
505
+
506
+ df_vals += [np.nanmean(pvals)]
507
+ balance_df_i = pd.DataFrame(df_vals, index=colnames).T
508
+ balance_df = pd.concat([balance_df, balance_df_i], ignore_index=True)
509
+ valid_pvals = [
510
+ pval_i
511
+ for pval_i in pvals
512
+ if isinstance(pval_i, (int, float)) and not np.isnan(pval_i)
513
+ ]
514
+ if all(i >= pval_threshold for i in valid_pvals):
515
+ data_dict["train"] = pu.filter_by_dict(
516
+ data, {attr_to_split: balance_df_i["train_ids"][0]}, nproc
517
+ )
518
+ data_dict["test"] = pu.filter_by_dict(
519
+ data, {attr_to_split: balance_df_i["eval_ids"][0]}, nproc
520
+ )
521
+ return data_dict, balance_df
522
+ trial_num = trial_num + 1
523
+ balance_max_df = balance_df.iloc[balance_df["mean_pval"].idxmax(), :]
524
+ data_dict["train"] = pu.filter_by_dict(
525
+ data, {attr_to_split: balance_df_i["train_ids"][0]}, nproc
526
+ )
527
+ data_dict["test"] = pu.filter_by_dict(
528
+ data, {attr_to_split: balance_df_i["eval_ids"][0]}, nproc
529
+ )
530
+ logger.warning(
531
+ f"No splits found without significant difference in attr_to_balance among {max_trials} trials. "
532
+ f"Selecting optimal split (trial #{balance_max_df['trial_num']}) from completed trials."
533
+ )
534
+ return data_dict, balance_df
535
+
536
+
537
+ def get_num_classes(id_class_dict):
538
+ return len(set(id_class_dict.values()))
539
+
540
+
541
+ def compute_metrics(pred):
542
+ labels = pred.label_ids
543
+ preds = pred.predictions.argmax(-1)
544
+
545
+ # calculate accuracy and macro f1 using sklearn's function
546
+ if len(labels.shape) == 1:
547
+ acc = accuracy_score(labels, preds)
548
+ macro_f1 = f1_score(labels, preds, average="macro")
549
+ else:
550
+ flat_labels = labels.flatten().tolist()
551
+ flat_preds = preds.flatten().tolist()
552
+ logit_label_paired = [
553
+ item for item in list(zip(flat_preds, flat_labels)) if item[1] != -100
554
+ ]
555
+ y_pred = [item[0] for item in logit_label_paired]
556
+ y_true = [item[1] for item in logit_label_paired]
557
+
558
+ acc = accuracy_score(y_true, y_pred)
559
+ macro_f1 = f1_score(y_true, y_pred, average="macro")
560
+
561
+ return {"accuracy": acc, "macro_f1": macro_f1}
562
+
563
+
564
+ def get_default_train_args(model, classifier, data, output_dir):
565
+ num_layers = pu.quant_layers(model)
566
+ freeze_layers = 0
567
+ batch_size = 12
568
+ if classifier == "cell":
569
+ epochs = 10
570
+ evaluation_strategy = "epoch"
571
+ load_best_model_at_end = True
572
+ else:
573
+ epochs = 1
574
+ evaluation_strategy = "no"
575
+ load_best_model_at_end = False
576
+
577
+ if num_layers == 6:
578
+ default_training_args = {
579
+ "learning_rate": 5e-5,
580
+ "lr_scheduler_type": "linear",
581
+ "warmup_steps": 500,
582
+ "per_device_train_batch_size": batch_size,
583
+ "per_device_eval_batch_size": batch_size,
584
+ }
585
+ else:
586
+ default_training_args = {
587
+ "per_device_train_batch_size": batch_size,
588
+ "per_device_eval_batch_size": batch_size,
589
+ }
590
+
591
+ training_args = {
592
+ "num_train_epochs": epochs,
593
+ "do_train": True,
594
+ "do_eval": True,
595
+ "evaluation_strategy": evaluation_strategy,
596
+ "logging_steps": np.floor(len(data) / batch_size / 8), # 8 evals per epoch
597
+ "save_strategy": "epoch",
598
+ "group_by_length": False,
599
+ "length_column_name": "length",
600
+ "disable_tqdm": False,
601
+ "weight_decay": 0.001,
602
+ "load_best_model_at_end": load_best_model_at_end,
603
+ }
604
+ training_args.update(default_training_args)
605
+
606
+ return training_args, freeze_layers
607
+
608
+
609
+ def load_best_model(directory, model_type, num_classes, mode="eval"):
610
+ file_dict = dict()
611
+ for subdir, dirs, files in os.walk(directory):
612
+ for file in files:
613
+ if file.endswith("result.json"):
614
+ with open(f"{subdir}/{file}", "rb") as fp:
615
+ result_json = json.load(fp)
616
+ file_dict[f"{subdir}"] = result_json["eval_macro_f1"]
617
+ file_df = pd.DataFrame(
618
+ {"dir": file_dict.keys(), "eval_macro_f1": file_dict.values()}
619
+ )
620
+ model_superdir = (
621
+ "run-"
622
+ + file_df.iloc[file_df["eval_macro_f1"].idxmax()]["dir"]
623
+ .split("_objective_")[2]
624
+ .split("_")[0]
625
+ )
626
+
627
+ for subdir, dirs, files in os.walk(f"{directory}/{model_superdir}"):
628
+ for file in files:
629
+ if file.endswith("model.safetensors"):
630
+ model = pu.load_model(model_type, num_classes, f"{subdir}", mode)
631
+ return model
632
+
633
+
634
+ class StratifiedKFold3(StratifiedKFold):
635
+ def split(self, targets, labels, test_ratio=0.5, groups=None):
636
+ s = super().split(targets, labels, groups)
637
+ for train_indxs, test_indxs in s:
638
+ if test_ratio == 0:
639
+ yield train_indxs, test_indxs, None
640
+ else:
641
+ labels_test = np.array(labels)[test_indxs]
642
+ valid_indxs, test_indxs = train_test_split(
643
+ test_indxs,
644
+ stratify=labels_test,
645
+ test_size=test_ratio,
646
+ random_state=0,
647
+ )
648
+ yield train_indxs, valid_indxs, test_indxs
geneformer/collator_for_classification.py ADDED
@@ -0,0 +1,667 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ Geneformer collator for gene and cell classification.
3
+ Huggingface data collator modified to accommodate single-cell transcriptomics data for gene and cell classification.
4
+ """
5
+
6
+ import warnings
7
+ from enum import Enum
8
+ from typing import Dict, List, Optional, Union
9
+
10
+ import numpy as np
11
+ import torch
12
+ from transformers import (
13
+ BatchEncoding,
14
+ DataCollatorForTokenClassification,
15
+ SpecialTokensMixin,
16
+ )
17
+ from transformers.utils import is_tf_available, is_torch_available, logging, to_py_obj
18
+ from transformers.utils.generic import _is_tensorflow, _is_torch
19
+
20
+ EncodedInput = List[int]
21
+ logger = logging.get_logger(__name__)
22
+ VERY_LARGE_INTEGER = int(
23
+ 1e30
24
+ ) # This is used to set the max input length for a model with infinite size input
25
+ LARGE_INTEGER = int(
26
+ 1e20
27
+ ) # This is used when we need something big but slightly smaller than VERY_LARGE_INTEGER
28
+
29
+ # precollator functions
30
+
31
+
32
+ class ExplicitEnum(Enum):
33
+ """
34
+ Enum with more explicit error message for missing values.
35
+ """
36
+
37
+ @classmethod
38
+ def _missing_(cls, value):
39
+ raise ValueError(
40
+ "%r is not a valid %s, please select one of %s"
41
+ % (value, cls.__name__, str(list(cls._value2member_map_.keys())))
42
+ )
43
+
44
+
45
+ class TruncationStrategy(ExplicitEnum):
46
+ """
47
+ Possible values for the ``truncation`` argument in :meth:`PreTrainedTokenizerBase.__call__`. Useful for
48
+ tab-completion in an IDE.
49
+ """
50
+
51
+ ONLY_FIRST = "only_first"
52
+ ONLY_SECOND = "only_second"
53
+ LONGEST_FIRST = "longest_first"
54
+ DO_NOT_TRUNCATE = "do_not_truncate"
55
+
56
+
57
+ class PaddingStrategy(ExplicitEnum):
58
+ """
59
+ Possible values for the ``padding`` argument in :meth:`PreTrainedTokenizerBase.__call__`. Useful for tab-completion
60
+ in an IDE.
61
+ """
62
+
63
+ LONGEST = "longest"
64
+ MAX_LENGTH = "max_length"
65
+ DO_NOT_PAD = "do_not_pad"
66
+
67
+
68
+ class TensorType(ExplicitEnum):
69
+ """
70
+ Possible values for the ``return_tensors`` argument in :meth:`PreTrainedTokenizerBase.__call__`. Useful for
71
+ tab-completion in an IDE.
72
+ """
73
+
74
+ PYTORCH = "pt"
75
+ TENSORFLOW = "tf"
76
+ NUMPY = "np"
77
+ JAX = "jax"
78
+
79
+
80
+ class PrecollatorForGeneAndCellClassification(SpecialTokensMixin):
81
+ def __init__(self, *args, **kwargs) -> None:
82
+ super().__init__(mask_token="<mask>", pad_token="<pad>")
83
+
84
+ self.token_dictionary = kwargs.get("token_dictionary")
85
+ self.padding_side = "right"
86
+ self.model_input_names = ["input_ids"]
87
+ self._mask_token_id = self.token_dictionary.get("<mask>")
88
+ self._pad_token_id = self.token_dictionary.get("<pad>")
89
+ self._all_special_ids = [
90
+ self.token_dictionary.get("<mask>"),
91
+ self.token_dictionary.get("<pad>"),
92
+ ]
93
+
94
+ @property
95
+ def all_special_ids(self):
96
+ return self._all_special_ids
97
+
98
+ @property
99
+ def mask_token_id(self):
100
+ return self._mask_token_id
101
+
102
+ @property
103
+ def pad_token_id(self):
104
+ return self._pad_token_id
105
+
106
+ def _get_padding_truncation_strategies(
107
+ self,
108
+ padding=True,
109
+ truncation=False,
110
+ max_length=None,
111
+ pad_to_multiple_of=None,
112
+ verbose=True,
113
+ **kwargs,
114
+ ):
115
+ """
116
+ Find the correct padding/truncation strategy with backward compatibility for old arguments (truncation_strategy
117
+ and pad_to_max_length) and behaviors.
118
+ """
119
+ old_truncation_strategy = kwargs.pop("truncation_strategy", "do_not_truncate")
120
+ old_pad_to_max_length = kwargs.pop("pad_to_max_length", False)
121
+
122
+ # Backward compatibility for previous behavior, maybe we should deprecate it:
123
+ # If you only set max_length, it activates truncation for max_length
124
+ if max_length is not None and padding is False and truncation is False:
125
+ if verbose:
126
+ if not self.deprecation_warnings.get(
127
+ "Truncation-not-explicitly-activated", False
128
+ ):
129
+ logger.warning(
130
+ "Truncation was not explicitly activated but `max_length` is provided a specific value, "
131
+ "please use `truncation=True` to explicitly truncate examples to max length. "
132
+ "Defaulting to 'longest_first' truncation strategy. "
133
+ "If you encode pairs of sequences (GLUE-style) with the tokenizer you can select this strategy "
134
+ "more precisely by providing a specific strategy to `truncation`."
135
+ )
136
+ self.deprecation_warnings["Truncation-not-explicitly-activated"] = True
137
+ truncation = "longest_first"
138
+
139
+ # Get padding strategy
140
+ if padding is False and old_pad_to_max_length:
141
+ if verbose:
142
+ warnings.warn(
143
+ "The `pad_to_max_length` argument is deprecated and will be removed in a future version, "
144
+ "use `padding=True` or `padding='longest'` to pad to the longest sequence in the batch, or "
145
+ "use `padding='max_length'` to pad to a max length. In this case, you can give a specific "
146
+ "length with `max_length` (e.g. `max_length=45`) or leave max_length to None to pad to the "
147
+ "maximal input size of the model (e.g. 512 for Bert).",
148
+ FutureWarning,
149
+ )
150
+ if max_length is None:
151
+ padding_strategy = PaddingStrategy.LONGEST
152
+ else:
153
+ padding_strategy = PaddingStrategy.MAX_LENGTH
154
+ elif padding is not False:
155
+ if padding is True:
156
+ padding_strategy = (
157
+ PaddingStrategy.LONGEST
158
+ ) # Default to pad to the longest sequence in the batch
159
+ elif not isinstance(padding, PaddingStrategy):
160
+ padding_strategy = PaddingStrategy(padding)
161
+ elif isinstance(padding, PaddingStrategy):
162
+ padding_strategy = padding
163
+ else:
164
+ padding_strategy = PaddingStrategy.DO_NOT_PAD
165
+
166
+ # Get truncation strategy
167
+ if truncation is False and old_truncation_strategy != "do_not_truncate":
168
+ if verbose:
169
+ warnings.warn(
170
+ "The `truncation_strategy` argument is deprecated and will be removed in a future version, "
171
+ "use `truncation=True` to truncate examples to a max length. You can give a specific "
172
+ "length with `max_length` (e.g. `max_length=45`) or leave max_length to None to truncate to the "
173
+ "maximal input size of the model (e.g. 512 for Bert). "
174
+ " If you have pairs of inputs, you can give a specific truncation strategy selected among "
175
+ "`truncation='only_first'` (will only truncate the first sentence in the pairs) "
176
+ "`truncation='only_second'` (will only truncate the second sentence in the pairs) "
177
+ "or `truncation='longest_first'` (will iteratively remove tokens from the longest sentence in the pairs).",
178
+ FutureWarning,
179
+ )
180
+ truncation_strategy = TruncationStrategy(old_truncation_strategy)
181
+ elif truncation is not False:
182
+ if truncation is True:
183
+ truncation_strategy = (
184
+ TruncationStrategy.LONGEST_FIRST
185
+ ) # Default to truncate the longest sequences in pairs of inputs
186
+ elif not isinstance(truncation, TruncationStrategy):
187
+ truncation_strategy = TruncationStrategy(truncation)
188
+ elif isinstance(truncation, TruncationStrategy):
189
+ truncation_strategy = truncation
190
+ else:
191
+ truncation_strategy = TruncationStrategy.DO_NOT_TRUNCATE
192
+
193
+ # Set max length if needed
194
+ if max_length is None:
195
+ if padding_strategy == PaddingStrategy.MAX_LENGTH:
196
+ if self.model_max_length > LARGE_INTEGER:
197
+ if verbose:
198
+ if not self.deprecation_warnings.get(
199
+ "Asking-to-pad-to-max_length", False
200
+ ):
201
+ logger.warning(
202
+ "Asking to pad to max_length but no maximum length is provided and the model has no predefined maximum length. "
203
+ "Default to no padding."
204
+ )
205
+ self.deprecation_warnings["Asking-to-pad-to-max_length"] = True
206
+ padding_strategy = PaddingStrategy.DO_NOT_PAD
207
+ else:
208
+ max_length = self.model_max_length
209
+
210
+ if truncation_strategy != TruncationStrategy.DO_NOT_TRUNCATE:
211
+ if self.model_max_length > LARGE_INTEGER:
212
+ if verbose:
213
+ if not self.deprecation_warnings.get(
214
+ "Asking-to-truncate-to-max_length", False
215
+ ):
216
+ logger.warning(
217
+ "Asking to truncate to max_length but no maximum length is provided and the model has no predefined maximum length. "
218
+ "Default to no truncation."
219
+ )
220
+ self.deprecation_warnings[
221
+ "Asking-to-truncate-to-max_length"
222
+ ] = True
223
+ truncation_strategy = TruncationStrategy.DO_NOT_TRUNCATE
224
+ else:
225
+ max_length = self.model_max_length
226
+
227
+ # Test if we have a padding token
228
+ if padding_strategy != PaddingStrategy.DO_NOT_PAD and (
229
+ not self.pad_token or self.pad_token_id < 0
230
+ ):
231
+ raise ValueError(
232
+ "Asking to pad but the tokenizer does not have a padding token. "
233
+ "Please select a token to use as `pad_token` `(tokenizer.pad_token = tokenizer.eos_token e.g.)` "
234
+ "or add a new pad token via `tokenizer.add_special_tokens({'pad_token': '[PAD]'})`."
235
+ )
236
+
237
+ # Check that we will truncate to a multiple of pad_to_multiple_of if both are provided
238
+ if (
239
+ truncation_strategy != TruncationStrategy.DO_NOT_TRUNCATE
240
+ and padding_strategy != PaddingStrategy.DO_NOT_PAD
241
+ and pad_to_multiple_of is not None
242
+ and max_length is not None
243
+ and (max_length % pad_to_multiple_of != 0)
244
+ ):
245
+ raise ValueError(
246
+ f"Truncation and padding are both activated but "
247
+ f"truncation length ({max_length}) is not a multiple of pad_to_multiple_of ({pad_to_multiple_of})."
248
+ )
249
+
250
+ return padding_strategy, truncation_strategy, max_length, kwargs
251
+
252
+ def pad(
253
+ self,
254
+ encoded_inputs: Union[
255
+ BatchEncoding,
256
+ List[BatchEncoding],
257
+ Dict[str, EncodedInput],
258
+ Dict[str, List[EncodedInput]],
259
+ List[Dict[str, EncodedInput]],
260
+ ],
261
+ class_type, # options: "gene" or "cell"
262
+ padding: Union[bool, str, PaddingStrategy] = True,
263
+ max_length: Optional[int] = None,
264
+ pad_to_multiple_of: Optional[int] = None,
265
+ return_attention_mask: Optional[bool] = True,
266
+ return_tensors: Optional[Union[str, TensorType]] = None,
267
+ verbose: bool = True,
268
+ ) -> BatchEncoding:
269
+ """
270
+ Pad a single encoded input or a batch of encoded inputs up to predefined length or to the max sequence length
271
+ in the batch.
272
+ Padding side (left/right) padding token ids are defined at the tokenizer level (with ``self.padding_side``,
273
+ ``self.pad_token_id`` and ``self.pad_token_type_id``)
274
+ .. note::
275
+ If the ``encoded_inputs`` passed are dictionary of numpy arrays, PyTorch tensors or TensorFlow tensors, the
276
+ result will use the same type unless you provide a different tensor type with ``return_tensors``. In the
277
+ case of PyTorch tensors, you will lose the specific device of your tensors however.
278
+ Args:
279
+ encoded_inputs (:class:`~transformers.BatchEncoding`, list of :class:`~transformers.BatchEncoding`, :obj:`Dict[str, List[int]]`, :obj:`Dict[str, List[List[int]]` or :obj:`List[Dict[str, List[int]]]`):
280
+ Tokenized inputs. Can represent one input (:class:`~transformers.BatchEncoding` or :obj:`Dict[str,
281
+ List[int]]`) or a batch of tokenized inputs (list of :class:`~transformers.BatchEncoding`, `Dict[str,
282
+ List[List[int]]]` or `List[Dict[str, List[int]]]`) so you can use this method during preprocessing as
283
+ well as in a PyTorch Dataloader collate function.
284
+ Instead of :obj:`List[int]` you can have tensors (numpy arrays, PyTorch tensors or TensorFlow tensors),
285
+ see the note above for the return type.
286
+ padding (:obj:`bool`, :obj:`str` or :class:`~transformers.tokenization_utils_base.PaddingStrategy`, `optional`, defaults to :obj:`True`):
287
+ Select a strategy to pad the returned sequences (according to the model's padding side and padding
288
+ index) among:
289
+ * :obj:`True` or :obj:`'longest'`: Pad to the longest sequence in the batch (or no padding if only a
290
+ single sequence if provided).
291
+ * :obj:`'max_length'`: Pad to a maximum length specified with the argument :obj:`max_length` or to the
292
+ maximum acceptable input length for the model if that argument is not provided.
293
+ * :obj:`False` or :obj:`'do_not_pad'` (default): No padding (i.e., can output a batch with sequences of
294
+ different lengths).
295
+ max_length (:obj:`int`, `optional`):
296
+ Maximum length of the returned list and optionally padding length (see above).
297
+ pad_to_multiple_of (:obj:`int`, `optional`):
298
+ If set will pad the sequence to a multiple of the provided value.
299
+ This is especially useful to enable the use of Tensor Cores on NVIDIA hardware with compute capability
300
+ >= 7.5 (Volta).
301
+ return_attention_mask (:obj:`bool`, `optional`):
302
+ Whether to return the attention mask. If left to the default, will return the attention mask according
303
+ to the specific tokenizer's default, defined by the :obj:`return_outputs` attribute.
304
+ `What are attention masks? <../glossary.html#attention-mask>`__
305
+ return_tensors (:obj:`str` or :class:`~transformers.tokenization_utils_base.TensorType`, `optional`):
306
+ If set, will return tensors instead of list of python integers. Acceptable values are:
307
+ * :obj:`'tf'`: Return TensorFlow :obj:`tf.constant` objects.
308
+ * :obj:`'pt'`: Return PyTorch :obj:`torch.Tensor` objects.
309
+ * :obj:`'np'`: Return Numpy :obj:`np.ndarray` objects.
310
+ verbose (:obj:`bool`, `optional`, defaults to :obj:`True`):
311
+ Whether or not to print more information and warnings.
312
+ """
313
+ # If we have a list of dicts, let's convert it in a dict of lists
314
+ # We do this to allow using this method as a collate_fn function in PyTorch Dataloader
315
+ if isinstance(encoded_inputs, (list, tuple)) and isinstance(
316
+ encoded_inputs[0], (dict, BatchEncoding)
317
+ ):
318
+ encoded_inputs = {
319
+ key: [example[key] for example in encoded_inputs]
320
+ for key in encoded_inputs[0].keys()
321
+ }
322
+
323
+ # The model's main input name, usually `input_ids`, has be passed for padding
324
+ if self.model_input_names[0] not in encoded_inputs:
325
+ raise ValueError(
326
+ "You should supply an encoding or a list of encodings to this method"
327
+ f"that includes {self.model_input_names[0]}, but you provided {list(encoded_inputs.keys())}"
328
+ )
329
+
330
+ required_input = encoded_inputs[self.model_input_names[0]]
331
+
332
+ if not required_input:
333
+ if return_attention_mask:
334
+ encoded_inputs["attention_mask"] = []
335
+ return encoded_inputs
336
+
337
+ # If we have PyTorch/TF/NumPy tensors/arrays as inputs, we cast them as python objects
338
+ # and rebuild them afterwards if no return_tensors is specified
339
+ # Note that we lose the specific device the tensor may be on for PyTorch
340
+
341
+ first_element = required_input[0]
342
+ if isinstance(first_element, (list, tuple)):
343
+ # first_element might be an empty list/tuple in some edge cases so we grab the first non empty element.
344
+ index = 0
345
+ while len(required_input[index]) == 0:
346
+ index += 1
347
+ if index < len(required_input):
348
+ first_element = required_input[index][0]
349
+ # At this state, if `first_element` is still a list/tuple, it's an empty one so there is nothing to do.
350
+ if not isinstance(first_element, (int, list, tuple)):
351
+ if is_tf_available() and _is_tensorflow(first_element):
352
+ return_tensors = "tf" if return_tensors is None else return_tensors
353
+ elif is_torch_available() and _is_torch(first_element):
354
+ return_tensors = "pt" if return_tensors is None else return_tensors
355
+ elif isinstance(first_element, np.ndarray):
356
+ return_tensors = "np" if return_tensors is None else return_tensors
357
+ else:
358
+ raise ValueError(
359
+ f"type of {first_element} unknown: {type(first_element)}. "
360
+ f"Should be one of a python, numpy, pytorch or tensorflow object."
361
+ )
362
+
363
+ for key, value in encoded_inputs.items():
364
+ encoded_inputs[key] = to_py_obj(value)
365
+
366
+ # Convert padding_strategy in PaddingStrategy
367
+ padding_strategy, _, max_length, _ = self._get_padding_truncation_strategies(
368
+ padding=padding, max_length=max_length, verbose=verbose
369
+ )
370
+
371
+ required_input = encoded_inputs[self.model_input_names[0]]
372
+ if required_input and not isinstance(required_input[0], (list, tuple)):
373
+ encoded_inputs = self._pad(
374
+ encoded_inputs,
375
+ class_type=class_type,
376
+ max_length=max_length,
377
+ padding_strategy=padding_strategy,
378
+ pad_to_multiple_of=pad_to_multiple_of,
379
+ return_attention_mask=return_attention_mask,
380
+ )
381
+ return BatchEncoding(encoded_inputs, tensor_type=return_tensors)
382
+
383
+ batch_size = len(required_input)
384
+ assert all(
385
+ len(v) == batch_size for v in encoded_inputs.values()
386
+ ), "Some items in the output dictionary have a different batch size than others."
387
+
388
+ if padding_strategy == PaddingStrategy.LONGEST:
389
+ max_length = max(len(inputs) for inputs in required_input)
390
+ padding_strategy = PaddingStrategy.MAX_LENGTH
391
+
392
+ batch_outputs = {}
393
+ for i in range(batch_size):
394
+ inputs = dict((k, v[i]) for k, v in encoded_inputs.items())
395
+ outputs = self._pad(
396
+ inputs,
397
+ class_type=class_type,
398
+ max_length=max_length,
399
+ padding_strategy=padding_strategy,
400
+ pad_to_multiple_of=pad_to_multiple_of,
401
+ return_attention_mask=return_attention_mask,
402
+ )
403
+
404
+ for key, value in outputs.items():
405
+ if key not in batch_outputs:
406
+ batch_outputs[key] = []
407
+ batch_outputs[key].append(value)
408
+ if class_type == "cell":
409
+ del batch_outputs["label"]
410
+ return BatchEncoding(batch_outputs, tensor_type=return_tensors)
411
+
412
+ def _pad(
413
+ self,
414
+ encoded_inputs: Union[Dict[str, EncodedInput], BatchEncoding],
415
+ class_type, # options: "gene" or "cell"
416
+ max_length: Optional[int] = None,
417
+ padding_strategy: PaddingStrategy = PaddingStrategy.LONGEST,
418
+ pad_to_multiple_of: Optional[int] = None,
419
+ return_attention_mask: Optional[bool] = True,
420
+ ) -> dict:
421
+ """
422
+ Pad encoded inputs (on left/right and up to predefined length or max length in the batch)
423
+ Args:
424
+ encoded_inputs: Dictionary of tokenized inputs (`List[int]`) or batch of tokenized inputs (`List[List[int]]`).
425
+ max_length: maximum length of the returned list and optionally padding length (see below).
426
+ Will truncate by taking into account the special tokens.
427
+ padding_strategy: PaddingStrategy to use for padding.
428
+ - PaddingStrategy.LONGEST Pad to the longest sequence in the batch
429
+ - PaddingStrategy.MAX_LENGTH: Pad to the max length (default)
430
+ - PaddingStrategy.DO_NOT_PAD: Do not pad
431
+ The tokenizer padding sides are defined in self.padding_side:
432
+ - 'left': pads on the left of the sequences
433
+ - 'right': pads on the right of the sequences
434
+ pad_to_multiple_of: (optional) Integer if set will pad the sequence to a multiple of the provided value.
435
+ This is especially useful to enable the use of Tensor Core on NVIDIA hardware with compute capability
436
+ >= 7.5 (Volta).
437
+ return_attention_mask: (optional) Set to False to avoid returning attention mask (default: set to model specifics)
438
+ """
439
+ # Load from model defaults
440
+ if return_attention_mask is None:
441
+ return_attention_mask = "attention_mask" in self.model_input_names
442
+
443
+ required_input = encoded_inputs[self.model_input_names[0]]
444
+
445
+ if padding_strategy == PaddingStrategy.LONGEST:
446
+ max_length = len(required_input)
447
+
448
+ if (
449
+ max_length is not None
450
+ and pad_to_multiple_of is not None
451
+ and (max_length % pad_to_multiple_of != 0)
452
+ ):
453
+ max_length = ((max_length // pad_to_multiple_of) + 1) * pad_to_multiple_of
454
+
455
+ needs_to_be_padded = (
456
+ padding_strategy != PaddingStrategy.DO_NOT_PAD
457
+ and len(required_input) != max_length
458
+ )
459
+
460
+ if needs_to_be_padded:
461
+ difference = max_length - len(required_input)
462
+ if self.padding_side == "right":
463
+ if return_attention_mask:
464
+ encoded_inputs["attention_mask"] = [1] * len(required_input) + [
465
+ 0
466
+ ] * difference
467
+ if "token_type_ids" in encoded_inputs:
468
+ encoded_inputs["token_type_ids"] = (
469
+ encoded_inputs["token_type_ids"]
470
+ + [self.pad_token_type_id] * difference
471
+ )
472
+ if "special_tokens_mask" in encoded_inputs:
473
+ encoded_inputs["special_tokens_mask"] = (
474
+ encoded_inputs["special_tokens_mask"] + [1] * difference
475
+ )
476
+ encoded_inputs[self.model_input_names[0]] = (
477
+ required_input + [self.pad_token_id] * difference
478
+ )
479
+ if class_type == "gene":
480
+ encoded_inputs["labels"] = (
481
+ encoded_inputs["labels"] + [-100] * difference
482
+ )
483
+ elif self.padding_side == "left":
484
+ if return_attention_mask:
485
+ encoded_inputs["attention_mask"] = [0] * difference + [1] * len(
486
+ required_input
487
+ )
488
+ if "token_type_ids" in encoded_inputs:
489
+ encoded_inputs["token_type_ids"] = [
490
+ self.pad_token_type_id
491
+ ] * difference + encoded_inputs["token_type_ids"]
492
+ if "special_tokens_mask" in encoded_inputs:
493
+ encoded_inputs["special_tokens_mask"] = [
494
+ 1
495
+ ] * difference + encoded_inputs["special_tokens_mask"]
496
+ encoded_inputs[self.model_input_names[0]] = [
497
+ self.pad_token_id
498
+ ] * difference + required_input
499
+ if class_type == "gene":
500
+ encoded_inputs["labels"] = [-100] * difference + encoded_inputs[
501
+ "labels"
502
+ ]
503
+ else:
504
+ raise ValueError("Invalid padding strategy:" + str(self.padding_side))
505
+ elif return_attention_mask and "attention_mask" not in encoded_inputs:
506
+ encoded_inputs["attention_mask"] = [1] * len(required_input)
507
+
508
+ return encoded_inputs
509
+
510
+ def get_special_tokens_mask(
511
+ self,
512
+ token_ids_0: List[int],
513
+ token_ids_1: Optional[List[int]] = None,
514
+ already_has_special_tokens: bool = False,
515
+ ) -> List[int]:
516
+ """
517
+ Retrieves sequence ids from a token list that has no special tokens added. This method is called when adding
518
+ special tokens using the tokenizer ``prepare_for_model`` or ``encode_plus`` methods.
519
+ Args:
520
+ token_ids_0 (:obj:`List[int]`):
521
+ List of ids of the first sequence.
522
+ token_ids_1 (:obj:`List[int]`, `optional`):
523
+ List of ids of the second sequence.
524
+ already_has_special_tokens (:obj:`bool`, `optional`, defaults to :obj:`False`):
525
+ Whether or not the token list is already formatted with special tokens for the model.
526
+ Returns:
527
+ A list of integers in the range [0, 1]: 1 for a special token, 0 for a sequence token.
528
+ """
529
+ assert already_has_special_tokens and token_ids_1 is None, (
530
+ "You cannot use ``already_has_special_tokens=False`` with this tokenizer. "
531
+ "Please use a slow (full python) tokenizer to activate this argument."
532
+ "Or set `return_special_tokens_mask=True` when calling the encoding method "
533
+ "to get the special tokens mask in any tokenizer. "
534
+ )
535
+
536
+ all_special_ids = self.all_special_ids # cache the property
537
+
538
+ special_tokens_mask = [
539
+ 1 if token in all_special_ids else 0 for token in token_ids_0
540
+ ]
541
+
542
+ return special_tokens_mask
543
+
544
+ def convert_tokens_to_ids(
545
+ self, tokens: Union[str, List[str]]
546
+ ) -> Union[int, List[int]]:
547
+ """
548
+ Converts a token string (or a sequence of tokens) in a single integer id (or a sequence of ids), using the
549
+ vocabulary.
550
+ Args:
551
+ tokens (:obj:`str` or :obj:`List[str]`): One or several token(s) to convert to token id(s).
552
+ Returns:
553
+ :obj:`int` or :obj:`List[int]`: The token id or list of token ids.
554
+ """
555
+ if tokens is None:
556
+ return None
557
+
558
+ if isinstance(tokens, str):
559
+ return self._convert_token_to_id_with_added_voc(tokens)
560
+
561
+ ids = []
562
+ for token in tokens:
563
+ ids.append(self._convert_token_to_id_with_added_voc(token))
564
+ return ids
565
+
566
+ def _convert_token_to_id_with_added_voc(self, token):
567
+ if token is None:
568
+ return None
569
+
570
+ return self.token_dictionary.get(token)
571
+
572
+ def __len__(self):
573
+ return len(self.token_dictionary)
574
+
575
+
576
+ # collator functions
577
+
578
+
579
+ class DataCollatorForGeneClassification(DataCollatorForTokenClassification):
580
+ """
581
+ Data collator that will dynamically pad the inputs received, as well as the labels.
582
+ Args:
583
+ tokenizer (:class:`~transformers.PreTrainedTokenizer` or :class:`~transformers.PreTrainedTokenizerFast`):
584
+ The tokenizer used for encoding the data.
585
+ padding (:obj:`bool`, :obj:`str` or :class:`~transformers.tokenization_utils_base.PaddingStrategy`, `optional`, defaults to :obj:`True`):
586
+ Select a strategy to pad the returned sequences (according to the model's padding side and padding index)
587
+ among:
588
+ * :obj:`True` or :obj:`'longest'`: Pad to the longest sequence in the batch (or no padding if only a single
589
+ sequence if provided).
590
+ * :obj:`'max_length'`: Pad to a maximum length specified with the argument :obj:`max_length` or to the
591
+ maximum acceptable input length for the model if that argument is not provided.
592
+ * :obj:`False` or :obj:`'do_not_pad'` (default): No padding (i.e., can output a batch with sequences of
593
+ different lengths).
594
+ max_length (:obj:`int`, `optional`):
595
+ Maximum length of the returned list and optionally padding length (see above).
596
+ pad_to_multiple_of (:obj:`int`, `optional`):
597
+ If set will pad the sequence to a multiple of the provided value.
598
+ This is especially useful to enable the use of Tensor Cores on NVIDIA hardware with compute capability >=
599
+ 7.5 (Volta).
600
+ label_pad_token_id (:obj:`int`, `optional`, defaults to -100):
601
+ The id to use when padding the labels (-100 will be automatically ignore by PyTorch loss functions).
602
+ """
603
+
604
+ class_type = "gene"
605
+ padding: Union[bool, str, PaddingStrategy] = True
606
+ max_length: Optional[int] = None
607
+ pad_to_multiple_of: Optional[int] = None
608
+ label_pad_token_id: int = -100
609
+
610
+ def __init__(self, *args, **kwargs) -> None:
611
+ self.token_dictionary = kwargs.pop("token_dictionary")
612
+ super().__init__(
613
+ tokenizer=PrecollatorForGeneAndCellClassification(
614
+ token_dictionary=self.token_dictionary
615
+ ),
616
+ padding=self.padding,
617
+ max_length=self.max_length,
618
+ pad_to_multiple_of=self.pad_to_multiple_of,
619
+ label_pad_token_id=self.label_pad_token_id,
620
+ *args,
621
+ **kwargs,
622
+ )
623
+
624
+ def _prepare_batch(self, features):
625
+ label_name = "label" if "label" in features[0].keys() else "labels"
626
+ labels = (
627
+ [feature[label_name] for feature in features]
628
+ if label_name in features[0].keys()
629
+ else None
630
+ )
631
+ batch = self.tokenizer.pad(
632
+ features,
633
+ class_type=self.class_type,
634
+ padding=self.padding,
635
+ max_length=self.max_length,
636
+ pad_to_multiple_of=self.pad_to_multiple_of,
637
+ return_tensors="pt",
638
+ )
639
+ return batch
640
+
641
+ def __call__(self, features):
642
+ batch = self._prepare_batch(features)
643
+
644
+ batch = {k: torch.tensor(v, dtype=torch.int64) for k, v in batch.items()}
645
+ return batch
646
+
647
+
648
+ class DataCollatorForCellClassification(DataCollatorForGeneClassification):
649
+ class_type = "cell"
650
+
651
+ def _prepare_batch(self, features):
652
+ batch = super()._prepare_batch(features)
653
+
654
+ # Special handling for labels.
655
+ # Ensure that tensor is created with the correct type
656
+ # (it should be automatically the case, but let's make sure of it.)
657
+ first = features[0]
658
+ if "label" in first and first["label"] is not None:
659
+ label = (
660
+ first["label"].item()
661
+ if isinstance(first["label"], torch.Tensor)
662
+ else first["label"]
663
+ )
664
+ dtype = torch.long if isinstance(label, int) else torch.float
665
+ batch["labels"] = torch.tensor([f["label"] for f in features], dtype=dtype)
666
+
667
+ return batch
geneformer/emb_extractor.py ADDED
@@ -0,0 +1,863 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ Geneformer embedding extractor.
3
+
4
+ **Description:**
5
+
6
+ | Extracts gene or cell embeddings.
7
+ | Plots cell embeddings as heatmaps or UMAPs.
8
+ | Generates cell state embedding dictionary for use with InSilicoPerturber.
9
+
10
+ """
11
+
12
+ # imports
13
+ import logging
14
+ import pickle
15
+ from collections import Counter
16
+ from pathlib import Path
17
+
18
+ import anndata
19
+ import matplotlib.pyplot as plt
20
+ import pandas as pd
21
+ import scanpy as sc
22
+ import seaborn as sns
23
+ import torch
24
+ from tdigest import TDigest
25
+ from tqdm.auto import trange
26
+
27
+ from . import TOKEN_DICTIONARY_FILE
28
+ from . import perturber_utils as pu
29
+
30
+ logger = logging.getLogger(__name__)
31
+
32
+
33
+ # extract embeddings
34
+ def get_embs(
35
+ model,
36
+ filtered_input_data,
37
+ emb_mode,
38
+ layer_to_quant,
39
+ pad_token_id,
40
+ forward_batch_size,
41
+ token_gene_dict,
42
+ special_token=False,
43
+ summary_stat=None,
44
+ silent=False,
45
+ ):
46
+ model_input_size = pu.get_model_input_size(model)
47
+ total_batch_length = len(filtered_input_data)
48
+
49
+ if summary_stat is None:
50
+ embs_list = []
51
+ elif summary_stat is not None:
52
+ # get # of emb dims
53
+ emb_dims = pu.get_model_emb_dims(model)
54
+ if emb_mode == "cell":
55
+ # initiate tdigests for # of emb dims
56
+ embs_tdigests = [TDigest() for _ in range(emb_dims)]
57
+ if emb_mode == "gene":
58
+ gene_set = list(
59
+ {
60
+ element
61
+ for sublist in filtered_input_data["input_ids"]
62
+ for element in sublist
63
+ }
64
+ )
65
+ # initiate dict with genes as keys and tdigests for # of emb dims as values
66
+ embs_tdigests_dict = {
67
+ k: [TDigest() for _ in range(emb_dims)] for k in gene_set
68
+ }
69
+
70
+ # Check if CLS and EOS token is present in the token dictionary
71
+ cls_present = any("<cls>" in value for value in token_gene_dict.values())
72
+ eos_present = any("<eos>" in value for value in token_gene_dict.values())
73
+ if emb_mode == "cls":
74
+ assert cls_present, "<cls> token missing in token dictionary"
75
+ # Check to make sure that the first token of the filtered input data is cls token
76
+ gene_token_dict = {v: k for k, v in token_gene_dict.items()}
77
+ cls_token_id = gene_token_dict["<cls>"]
78
+ assert (
79
+ filtered_input_data["input_ids"][0][0] == cls_token_id
80
+ ), "First token is not <cls> token value"
81
+ elif emb_mode == "cell":
82
+ if cls_present:
83
+ logger.warning(
84
+ "CLS token present in token dictionary, excluding from average."
85
+ )
86
+ if eos_present:
87
+ logger.warning(
88
+ "EOS token present in token dictionary, excluding from average."
89
+ )
90
+
91
+ overall_max_len = 0
92
+
93
+ for i in trange(0, total_batch_length, forward_batch_size, leave=(not silent)):
94
+ max_range = min(i + forward_batch_size, total_batch_length)
95
+
96
+ minibatch = filtered_input_data.select([i for i in range(i, max_range)])
97
+
98
+ max_len = int(max(minibatch["length"]))
99
+ original_lens = torch.tensor(minibatch["length"], device="cuda")
100
+ minibatch.set_format(type="torch")
101
+
102
+ input_data_minibatch = minibatch["input_ids"]
103
+ input_data_minibatch = pu.pad_tensor_list(
104
+ input_data_minibatch, max_len, pad_token_id, model_input_size
105
+ )
106
+
107
+ with torch.no_grad():
108
+ outputs = model(
109
+ input_ids=input_data_minibatch.to("cuda"),
110
+ attention_mask=pu.gen_attention_mask(minibatch),
111
+ )
112
+
113
+ embs_i = outputs.hidden_states[layer_to_quant]
114
+
115
+ if emb_mode == "cell":
116
+ if cls_present:
117
+ non_cls_embs = embs_i[:, 1:, :] # Get all layers except the embs
118
+ if eos_present:
119
+ mean_embs = pu.mean_nonpadding_embs(non_cls_embs, original_lens - 2)
120
+ else:
121
+ mean_embs = pu.mean_nonpadding_embs(non_cls_embs, original_lens - 1)
122
+ else:
123
+ mean_embs = pu.mean_nonpadding_embs(embs_i, original_lens)
124
+ if summary_stat is None:
125
+ embs_list.append(mean_embs)
126
+ elif summary_stat is not None:
127
+ # update tdigests with current batch for each emb dim
128
+ accumulate_tdigests(embs_tdigests, mean_embs, emb_dims)
129
+ del mean_embs
130
+ elif emb_mode == "gene":
131
+ if summary_stat is None:
132
+ embs_list.append(embs_i)
133
+ elif summary_stat is not None:
134
+ for h in trange(len(minibatch)):
135
+ length_h = minibatch[h]["length"]
136
+ input_ids_h = minibatch[h]["input_ids"][0:length_h]
137
+
138
+ # double check dimensions before unsqueezing
139
+ embs_i_dim = embs_i.dim()
140
+ if embs_i_dim != 3:
141
+ logger.error(
142
+ f"Embedding tensor should have 3 dimensions, not {embs_i_dim}"
143
+ )
144
+ raise
145
+
146
+ embs_h = embs_i[h, :, :].unsqueeze(dim=1)
147
+ dict_h = dict(zip(input_ids_h, embs_h))
148
+ for k in dict_h.keys():
149
+ accumulate_tdigests(
150
+ embs_tdigests_dict[int(k)], dict_h[k], emb_dims
151
+ )
152
+ del embs_h
153
+ del dict_h
154
+ elif emb_mode == "cls":
155
+ cls_embs = embs_i[:, 0, :].clone().detach() # CLS token layer
156
+ embs_list.append(cls_embs)
157
+ del cls_embs
158
+
159
+ overall_max_len = max(overall_max_len, max_len)
160
+ del outputs
161
+ del minibatch
162
+ del input_data_minibatch
163
+ del embs_i
164
+
165
+ torch.cuda.empty_cache()
166
+
167
+ if summary_stat is None:
168
+ if (emb_mode == "cell") or (emb_mode == "cls"):
169
+ embs_stack = torch.cat(embs_list, dim=0)
170
+ elif emb_mode == "gene":
171
+ embs_stack = pu.pad_tensor_list(
172
+ embs_list,
173
+ overall_max_len,
174
+ pad_token_id,
175
+ model_input_size,
176
+ 1,
177
+ pu.pad_3d_tensor,
178
+ )
179
+
180
+ # calculate summary stat embs from approximated tdigests
181
+ elif summary_stat is not None:
182
+ if emb_mode == "cell":
183
+ if summary_stat == "mean":
184
+ summary_emb_list = tdigest_mean(embs_tdigests, emb_dims)
185
+ elif summary_stat == "median":
186
+ summary_emb_list = tdigest_median(embs_tdigests, emb_dims)
187
+ embs_stack = torch.tensor(summary_emb_list)
188
+ elif emb_mode == "gene":
189
+ if summary_stat == "mean":
190
+ [
191
+ update_tdigest_dict_mean(embs_tdigests_dict, gene, emb_dims)
192
+ for gene in embs_tdigests_dict.keys()
193
+ ]
194
+ elif summary_stat == "median":
195
+ [
196
+ update_tdigest_dict_median(embs_tdigests_dict, gene, emb_dims)
197
+ for gene in embs_tdigests_dict.keys()
198
+ ]
199
+ return embs_tdigests_dict
200
+
201
+ return embs_stack
202
+
203
+
204
+ def accumulate_tdigests(embs_tdigests, mean_embs, emb_dims):
205
+ # note: tdigest batch update known to be slow so updating serially
206
+ [
207
+ embs_tdigests[j].update(mean_embs[i, j].item())
208
+ for i in range(mean_embs.size(0))
209
+ for j in range(emb_dims)
210
+ ]
211
+
212
+
213
+ def update_tdigest_dict(embs_tdigests_dict, gene, gene_embs, emb_dims):
214
+ embs_tdigests_dict[gene] = accumulate_tdigests(
215
+ embs_tdigests_dict[gene], gene_embs, emb_dims
216
+ )
217
+
218
+
219
+ def update_tdigest_dict_mean(embs_tdigests_dict, gene, emb_dims):
220
+ embs_tdigests_dict[gene] = tdigest_mean(embs_tdigests_dict[gene], emb_dims)
221
+
222
+
223
+ def update_tdigest_dict_median(embs_tdigests_dict, gene, emb_dims):
224
+ embs_tdigests_dict[gene] = tdigest_median(embs_tdigests_dict[gene], emb_dims)
225
+
226
+
227
+ def summarize_gene_embs(h, minibatch, embs_i, embs_tdigests_dict, emb_dims):
228
+ length_h = minibatch[h]["length"]
229
+ input_ids_h = minibatch[h]["input_ids"][0:length_h]
230
+ embs_h = embs_i[h, :, :].unsqueeze(dim=1)
231
+ dict_h = dict(zip(input_ids_h, embs_h))
232
+ [
233
+ update_tdigest_dict(embs_tdigests_dict, k, dict_h[k], emb_dims)
234
+ for k in dict_h.keys()
235
+ ]
236
+
237
+
238
+ def tdigest_mean(embs_tdigests, emb_dims):
239
+ return [embs_tdigests[i].trimmed_mean(0, 100) for i in range(emb_dims)]
240
+
241
+
242
+ def tdigest_median(embs_tdigests, emb_dims):
243
+ return [embs_tdigests[i].percentile(50) for i in range(emb_dims)]
244
+
245
+
246
+ def label_cell_embs(embs, downsampled_data, emb_labels):
247
+ embs_df = pd.DataFrame(embs.cpu().numpy())
248
+ if emb_labels is not None:
249
+ for label in emb_labels:
250
+ emb_label = downsampled_data[label]
251
+ embs_df[label] = emb_label
252
+ return embs_df
253
+
254
+
255
+ def label_gene_embs(embs, downsampled_data, token_gene_dict):
256
+ gene_set = {
257
+ element for sublist in downsampled_data["input_ids"] for element in sublist
258
+ }
259
+ gene_emb_dict = {k: [] for k in gene_set}
260
+ for i in range(embs.size()[0]):
261
+ length = downsampled_data[i]["length"]
262
+ dict_i = dict(
263
+ zip(
264
+ downsampled_data[i]["input_ids"][0:length],
265
+ embs[i, :, :].unsqueeze(dim=1),
266
+ )
267
+ )
268
+ for k in dict_i.keys():
269
+ gene_emb_dict[k].append(dict_i[k])
270
+ for k in gene_emb_dict.keys():
271
+ gene_emb_dict[k] = (
272
+ torch.squeeze(torch.mean(torch.stack(gene_emb_dict[k]), dim=0), dim=0)
273
+ .cpu()
274
+ .numpy()
275
+ )
276
+ embs_df = pd.DataFrame(gene_emb_dict).T
277
+ embs_df.index = [token_gene_dict[token] for token in embs_df.index]
278
+ return embs_df
279
+
280
+
281
+ def plot_umap(embs_df, emb_dims, label, output_file, kwargs_dict, seed=0):
282
+ only_embs_df = embs_df.iloc[:, :emb_dims]
283
+ only_embs_df.index = pd.RangeIndex(0, only_embs_df.shape[0], name=None).astype(str)
284
+ only_embs_df.columns = pd.RangeIndex(0, only_embs_df.shape[1], name=None).astype(
285
+ str
286
+ )
287
+ vars_dict = {"embs": only_embs_df.columns}
288
+ obs_dict = {"cell_id": list(only_embs_df.index), f"{label}": list(embs_df[label])}
289
+ adata = anndata.AnnData(X=only_embs_df, obs=obs_dict, var=vars_dict)
290
+ sc.tl.pca(adata, svd_solver="arpack")
291
+ sc.pp.neighbors(adata, random_state=seed)
292
+ sc.tl.umap(adata, random_state=seed)
293
+ sns.set(rc={"figure.figsize": (10, 10)}, font_scale=2.3)
294
+ sns.set_style("white")
295
+ default_kwargs_dict = {"size": 200}
296
+ if kwargs_dict is not None:
297
+ default_kwargs_dict.update(kwargs_dict)
298
+
299
+ cats = set(embs_df[label])
300
+
301
+ with plt.rc_context():
302
+ ax = sc.pl.umap(adata, color=label, show=False, **default_kwargs_dict)
303
+ ax.legend(
304
+ markerscale=2,
305
+ frameon=False,
306
+ loc="center left",
307
+ bbox_to_anchor=(1, 0.5),
308
+ ncol=(1 if len(cats) <= 14 else 2 if len(cats) <= 30 else 3),
309
+ )
310
+ plt.show()
311
+ plt.savefig(output_file, bbox_inches="tight")
312
+
313
+
314
+ def gen_heatmap_class_colors(labels, df):
315
+ pal = sns.cubehelix_palette(
316
+ len(Counter(labels).keys()),
317
+ light=0.9,
318
+ dark=0.1,
319
+ hue=1,
320
+ reverse=True,
321
+ start=1,
322
+ rot=-2,
323
+ )
324
+ lut = dict(zip(map(str, Counter(labels).keys()), pal))
325
+ colors = pd.Series(labels, index=df.index).map(lut)
326
+ return colors
327
+
328
+
329
+ def gen_heatmap_class_dict(classes, label_colors_series):
330
+ class_color_dict_df = pd.DataFrame(
331
+ {"classes": classes, "color": label_colors_series}
332
+ )
333
+ class_color_dict_df = class_color_dict_df.drop_duplicates(subset=["classes"])
334
+ return dict(zip(class_color_dict_df["classes"], class_color_dict_df["color"]))
335
+
336
+
337
+ def make_colorbar(embs_df, label):
338
+ labels = list(embs_df[label])
339
+
340
+ cell_type_colors = gen_heatmap_class_colors(labels, embs_df)
341
+ label_colors = pd.DataFrame(cell_type_colors, columns=[label])
342
+
343
+ # create dictionary for colors and classes
344
+ label_color_dict = gen_heatmap_class_dict(labels, label_colors[label])
345
+ return label_colors, label_color_dict
346
+
347
+
348
+ def plot_heatmap(embs_df, emb_dims, label, output_file, kwargs_dict):
349
+ sns.set_style("white")
350
+ sns.set(font_scale=2)
351
+ plt.figure(figsize=(15, 15), dpi=150)
352
+ label_colors, label_color_dict = make_colorbar(embs_df, label)
353
+
354
+ default_kwargs_dict = {
355
+ "row_cluster": True,
356
+ "col_cluster": True,
357
+ "row_colors": label_colors,
358
+ "standard_scale": 1,
359
+ "linewidths": 0,
360
+ "xticklabels": False,
361
+ "yticklabels": False,
362
+ "figsize": (15, 15),
363
+ "center": 0,
364
+ "cmap": "magma",
365
+ }
366
+
367
+ if kwargs_dict is not None:
368
+ default_kwargs_dict.update(kwargs_dict)
369
+ g = sns.clustermap(
370
+ embs_df.iloc[:, 0:emb_dims].apply(pd.to_numeric), **default_kwargs_dict
371
+ )
372
+
373
+ plt.setp(g.ax_row_colors.get_xmajorticklabels(), rotation=45, ha="right")
374
+
375
+ for label_color in list(label_color_dict.keys()):
376
+ g.ax_col_dendrogram.bar(
377
+ 0, 0, color=label_color_dict[label_color], label=label_color, linewidth=0
378
+ )
379
+
380
+ g.ax_col_dendrogram.legend(
381
+ title=f"{label}",
382
+ loc="lower center",
383
+ ncol=4,
384
+ bbox_to_anchor=(0.5, 1),
385
+ facecolor="white",
386
+ )
387
+ plt.show()
388
+ logger.info(f"Output file: {output_file}")
389
+ plt.savefig(output_file, bbox_inches="tight")
390
+
391
+
392
+ class EmbExtractor:
393
+ valid_option_dict = {
394
+ "model_type": {"Pretrained", "GeneClassifier", "CellClassifier"},
395
+ "num_classes": {int},
396
+ "emb_mode": {"cls", "cell", "gene"},
397
+ "cell_emb_style": {"mean_pool"},
398
+ "gene_emb_style": {"mean_pool"},
399
+ "filter_data": {None, dict},
400
+ "max_ncells": {None, int},
401
+ "emb_layer": {-1, 0},
402
+ "emb_label": {None, list},
403
+ "labels_to_plot": {None, list},
404
+ "forward_batch_size": {int},
405
+ "token_dictionary_file": {None, str},
406
+ "nproc": {int},
407
+ "summary_stat": {None, "mean", "median", "exact_mean", "exact_median"},
408
+ }
409
+
410
+ def __init__(
411
+ self,
412
+ model_type="Pretrained",
413
+ num_classes=0,
414
+ emb_mode="cls",
415
+ cell_emb_style="mean_pool",
416
+ gene_emb_style="mean_pool",
417
+ filter_data=None,
418
+ max_ncells=1000,
419
+ emb_layer=-1,
420
+ emb_label=None,
421
+ labels_to_plot=None,
422
+ forward_batch_size=100,
423
+ nproc=4,
424
+ summary_stat=None,
425
+ token_dictionary_file=None,
426
+ ):
427
+ """
428
+ Initialize embedding extractor.
429
+
430
+ **Parameters:**
431
+
432
+ model_type : {"Pretrained", "GeneClassifier", "CellClassifier"}
433
+ | Whether model is the pretrained Geneformer or a fine-tuned gene or cell classifier.
434
+ num_classes : int
435
+ | If model is a gene or cell classifier, specify number of classes it was trained to classify.
436
+ | For the pretrained Geneformer model, number of classes is 0 as it is not a classifier.
437
+ emb_mode : {"cls", "cell", "gene"}
438
+ | Whether to output CLS, cell, or gene embeddings.
439
+ | CLS embeddings are cell embeddings derived from the CLS token in the front of the rank value encoding.
440
+ cell_emb_style : {"mean_pool"}
441
+ | Method for summarizing cell embeddings if not using CLS token.
442
+ | Currently only option is mean pooling of gene embeddings for given cell.
443
+ gene_emb_style : "mean_pool"
444
+ | Method for summarizing gene embeddings.
445
+ | Currently only option is mean pooling of contextual gene embeddings for given gene.
446
+ filter_data : None, dict
447
+ | Default is to extract embeddings from all input data.
448
+ | Otherwise, dictionary specifying .dataset column name and list of values to filter by.
449
+ max_ncells : None, int
450
+ | Maximum number of cells to extract embeddings from.
451
+ | Default is 1000 cells randomly sampled from input data.
452
+ | If None, will extract embeddings from all cells.
453
+ emb_layer : {-1, 0}
454
+ | Embedding layer to extract.
455
+ | The last layer is most specifically weighted to optimize the given learning objective.
456
+ | Generally, it is best to extract the 2nd to last layer to get a more general representation.
457
+ | -1: 2nd to last layer
458
+ | 0: last layer
459
+ emb_label : None, list
460
+ | List of column name(s) in .dataset to add as labels to embedding output.
461
+ labels_to_plot : None, list
462
+ | Cell labels to plot.
463
+ | Shown as color bar in heatmap.
464
+ | Shown as cell color in umap.
465
+ | Plotting umap requires labels to plot.
466
+ forward_batch_size : int
467
+ | Batch size for forward pass.
468
+ nproc : int
469
+ | Number of CPU processes to use.
470
+ summary_stat : {None, "mean", "median", "exact_mean", "exact_median"}
471
+ | If exact_mean or exact_median, outputs only exact mean or median embedding of input data.
472
+ | If mean or median, outputs only approximated mean or median embedding of input data.
473
+ | Non-exact recommended if encountering memory constraints while generating goal embedding positions.
474
+ | Non-exact is slower but more memory-efficient.
475
+ token_dictionary_file : Path
476
+ | Default is the Geneformer token dictionary
477
+ | Path to pickle file containing token dictionary (Ensembl ID:token).
478
+
479
+ **Examples:**
480
+
481
+ .. code-block :: python
482
+
483
+ >>> from geneformer import EmbExtractor
484
+ >>> embex = EmbExtractor(model_type="CellClassifier",
485
+ ... num_classes=3,
486
+ ... emb_mode="cell",
487
+ ... filter_data={"cell_type":["cardiomyocyte"]},
488
+ ... max_ncells=1000,
489
+ ... emb_layer=-1,
490
+ ... emb_label=["disease", "cell_type"],
491
+ ... labels_to_plot=["disease", "cell_type"])
492
+
493
+ """
494
+
495
+ self.model_type = model_type
496
+ self.num_classes = num_classes
497
+ self.emb_mode = emb_mode
498
+ self.cell_emb_style = cell_emb_style
499
+ self.gene_emb_style = gene_emb_style
500
+ self.filter_data = filter_data
501
+ self.max_ncells = max_ncells
502
+ self.emb_layer = emb_layer
503
+ self.emb_label = emb_label
504
+ self.labels_to_plot = labels_to_plot
505
+ self.token_dictionary_file = token_dictionary_file
506
+ self.forward_batch_size = forward_batch_size
507
+ self.nproc = nproc
508
+ if (summary_stat is not None) and ("exact" in summary_stat):
509
+ self.summary_stat = None
510
+ self.exact_summary_stat = summary_stat
511
+ else:
512
+ self.summary_stat = summary_stat
513
+ self.exact_summary_stat = None
514
+
515
+ self.validate_options()
516
+
517
+ # load token dictionary (Ensembl IDs:token)
518
+ if self.token_dictionary_file is None:
519
+ token_dictionary_file = TOKEN_DICTIONARY_FILE
520
+ with open(token_dictionary_file, "rb") as f:
521
+ self.gene_token_dict = pickle.load(f)
522
+
523
+ self.token_gene_dict = {v: k for k, v in self.gene_token_dict.items()}
524
+ self.pad_token_id = self.gene_token_dict.get("<pad>")
525
+
526
+ def validate_options(self):
527
+ # confirm arguments are within valid options and compatible with each other
528
+ for attr_name, valid_options in self.valid_option_dict.items():
529
+ attr_value = self.__dict__[attr_name]
530
+ if not isinstance(attr_value, (list, dict)):
531
+ if attr_value in valid_options:
532
+ continue
533
+ valid_type = False
534
+ for option in valid_options:
535
+ if (option in [int, list, dict, bool, str]) and isinstance(
536
+ attr_value, option
537
+ ):
538
+ valid_type = True
539
+ break
540
+ if valid_type:
541
+ continue
542
+ logger.error(
543
+ f"Invalid option for {attr_name}. "
544
+ f"Valid options for {attr_name}: {valid_options}"
545
+ )
546
+ raise
547
+
548
+ if self.filter_data is not None:
549
+ for key, value in self.filter_data.items():
550
+ if not isinstance(value, list):
551
+ self.filter_data[key] = [value]
552
+ logger.warning(
553
+ "Values in filter_data dict must be lists. "
554
+ f"Changing {key} value to list ([{value}])."
555
+ )
556
+
557
+ def extract_embs(
558
+ self,
559
+ model_directory,
560
+ input_data_file,
561
+ output_directory,
562
+ output_prefix,
563
+ output_torch_embs=False,
564
+ cell_state=None,
565
+ ):
566
+ """
567
+ Extract embeddings from input data and save as results in output_directory.
568
+
569
+ **Parameters:**
570
+
571
+ model_directory : Path
572
+ | Path to directory containing model
573
+ input_data_file : Path
574
+ | Path to directory containing .dataset inputs
575
+ output_directory : Path
576
+ | Path to directory where embedding data will be saved as csv
577
+ output_prefix : str
578
+ | Prefix for output file
579
+ output_torch_embs : bool
580
+ | Whether or not to also output the embeddings as a tensor.
581
+ | Note, if true, will output embeddings as both dataframe and tensor.
582
+ cell_state : dict
583
+ | Cell state key and value for state embedding extraction.
584
+
585
+ **Examples:**
586
+
587
+ .. code-block :: python
588
+
589
+ >>> embs = embex.extract_embs("path/to/model",
590
+ ... "path/to/input_data",
591
+ ... "path/to/output_directory",
592
+ ... "output_prefix")
593
+
594
+ """
595
+
596
+ filtered_input_data = pu.load_and_filter(
597
+ self.filter_data, self.nproc, input_data_file
598
+ )
599
+
600
+ # Check to make sure that all the labels exist in the tokenized data:
601
+ if self.emb_label is not None:
602
+ for label in self.emb_label:
603
+ assert label in filtered_input_data.features.keys(), f"Attribute `{label}` not present in dataset features"
604
+
605
+ if cell_state is not None:
606
+ filtered_input_data = pu.filter_by_dict(
607
+ filtered_input_data, cell_state, self.nproc
608
+ )
609
+ downsampled_data = pu.downsample_and_sort(filtered_input_data, self.max_ncells)
610
+ model = pu.load_model(
611
+ self.model_type, self.num_classes, model_directory, mode="eval"
612
+ )
613
+ layer_to_quant = pu.quant_layers(model) + self.emb_layer
614
+ embs = get_embs(
615
+ model=model,
616
+ filtered_input_data=downsampled_data,
617
+ emb_mode=self.emb_mode,
618
+ layer_to_quant=layer_to_quant,
619
+ pad_token_id=self.pad_token_id,
620
+ forward_batch_size=self.forward_batch_size,
621
+ token_gene_dict=self.token_gene_dict,
622
+ summary_stat=self.summary_stat,
623
+ )
624
+
625
+ if self.emb_mode == "cell":
626
+ if self.summary_stat is None:
627
+ embs_df = label_cell_embs(embs, downsampled_data, self.emb_label)
628
+ elif self.summary_stat is not None:
629
+ embs_df = pd.DataFrame(embs.cpu().numpy()).T
630
+ elif self.emb_mode == "gene":
631
+ if self.summary_stat is None:
632
+ embs_df = label_gene_embs(embs, downsampled_data, self.token_gene_dict)
633
+ elif self.summary_stat is not None:
634
+ embs_df = pd.DataFrame(embs).T
635
+ embs_df.index = [self.token_gene_dict[token] for token in embs_df.index]
636
+ elif self.emb_mode == "cls":
637
+ embs_df = label_cell_embs(embs, downsampled_data, self.emb_label)
638
+
639
+ # save embeddings to output_path
640
+ if cell_state is None:
641
+ output_path = (Path(output_directory) / output_prefix).with_suffix(".csv")
642
+ embs_df.to_csv(output_path)
643
+
644
+ if self.exact_summary_stat == "exact_mean":
645
+ embs = embs.mean(dim=0)
646
+ emb_dims = pu.get_model_emb_dims(model)
647
+ embs_df = pd.DataFrame(
648
+ embs_df[0 : emb_dims - 1].mean(axis="rows"),
649
+ columns=[self.exact_summary_stat],
650
+ ).T
651
+ elif self.exact_summary_stat == "exact_median":
652
+ embs = torch.median(embs, dim=0)[0]
653
+ emb_dims = pu.get_model_emb_dims(model)
654
+ embs_df = pd.DataFrame(
655
+ embs_df[0 : emb_dims - 1].median(axis="rows"),
656
+ columns=[self.exact_summary_stat],
657
+ ).T
658
+
659
+ if cell_state is not None:
660
+ return embs
661
+ else:
662
+ if output_torch_embs:
663
+ return embs_df, embs
664
+ else:
665
+ return embs_df
666
+
667
+ def get_state_embs(
668
+ self,
669
+ cell_states_to_model,
670
+ model_directory,
671
+ input_data_file,
672
+ output_directory,
673
+ output_prefix,
674
+ output_torch_embs=True,
675
+ ):
676
+ """
677
+ Extract exact mean or exact median cell state embedding positions from input data and save as results in output_directory.
678
+
679
+ **Parameters:**
680
+
681
+ cell_states_to_model : None, dict
682
+ | Cell states to model if testing perturbations that achieve goal state change.
683
+ | Four-item dictionary with keys: state_key, start_state, goal_state, and alt_states
684
+ | state_key: key specifying name of column in .dataset that defines the start/goal states
685
+ | start_state: value in the state_key column that specifies the start state
686
+ | goal_state: value in the state_key column taht specifies the goal end state
687
+ | alt_states: list of values in the state_key column that specify the alternate end states
688
+ | For example:
689
+ | {"state_key": "disease",
690
+ | "start_state": "dcm",
691
+ | "goal_state": "nf",
692
+ | "alt_states": ["hcm", "other1", "other2"]}
693
+ model_directory : Path
694
+ | Path to directory containing model
695
+ input_data_file : Path
696
+ | Path to directory containing .dataset inputs
697
+ output_directory : Path
698
+ | Path to directory where embedding data will be saved as csv
699
+ output_prefix : str
700
+ | Prefix for output file
701
+ output_torch_embs : bool
702
+ | Whether or not to also output the embeddings as a tensor.
703
+ | Note, if true, will output embeddings as both dataframe and tensor.
704
+
705
+ **Outputs**
706
+
707
+ | Outputs state_embs_dict for use with in silico perturber.
708
+ | Format is dictionary of embedding positions of each cell state to model shifts from/towards.
709
+ | Keys specify each possible cell state to model.
710
+ | Values are target embedding positions as torch.tensor.
711
+ | For example:
712
+ | {"nf": emb_nf,
713
+ | "hcm": emb_hcm,
714
+ | "dcm": emb_dcm,
715
+ | "other1": emb_other1,
716
+ | "other2": emb_other2}
717
+ """
718
+
719
+ pu.validate_cell_states_to_model(cell_states_to_model)
720
+ valid_summary_stats = ["exact_mean", "exact_median"]
721
+ if self.exact_summary_stat not in valid_summary_stats:
722
+ logger.error(
723
+ "For extracting state embs, summary_stat in EmbExtractor "
724
+ f"must be set to option in {valid_summary_stats}"
725
+ )
726
+ raise
727
+
728
+ if self.emb_label is not None:
729
+ logger.error(
730
+ "For extracting state embs, emb_label should be None since labels are based on state embs dict keys."
731
+ )
732
+ raise
733
+
734
+ state_embs_dict = dict()
735
+ state_key = cell_states_to_model["state_key"]
736
+ for k, v in cell_states_to_model.items():
737
+ if k == "state_key":
738
+ continue
739
+ elif (k == "start_state") or (k == "goal_state"):
740
+ state_embs_dict[v] = self.extract_embs(
741
+ model_directory,
742
+ input_data_file,
743
+ output_directory,
744
+ output_prefix,
745
+ output_torch_embs,
746
+ cell_state={state_key: v},
747
+ )
748
+ else: # k == "alt_states"
749
+ for alt_state in v:
750
+ state_embs_dict[alt_state] = self.extract_embs(
751
+ model_directory,
752
+ input_data_file,
753
+ output_directory,
754
+ output_prefix,
755
+ output_torch_embs,
756
+ cell_state={state_key: alt_state},
757
+ )
758
+
759
+ output_path = (Path(output_directory) / output_prefix).with_suffix(".pkl")
760
+ with open(output_path, "wb") as fp:
761
+ pickle.dump(state_embs_dict, fp)
762
+
763
+ return state_embs_dict
764
+
765
+ def plot_embs(
766
+ self,
767
+ embs,
768
+ plot_style,
769
+ output_directory,
770
+ output_prefix,
771
+ max_ncells_to_plot=1000,
772
+ kwargs_dict=None,
773
+ ):
774
+ """
775
+ Plot embeddings, coloring by provided labels.
776
+
777
+ **Parameters:**
778
+
779
+ embs : pandas.core.frame.DataFrame
780
+ | Pandas dataframe containing embeddings output from extract_embs
781
+ plot_style : str
782
+ | Style of plot: "heatmap" or "umap"
783
+ output_directory : Path
784
+ | Path to directory where plots will be saved as pdf
785
+ output_prefix : str
786
+ | Prefix for output file
787
+ max_ncells_to_plot : None, int
788
+ | Maximum number of cells to plot.
789
+ | Default is 1000 cells randomly sampled from embeddings.
790
+ | If None, will plot embeddings from all cells.
791
+ kwargs_dict : dict
792
+ | Dictionary of kwargs to pass to plotting function.
793
+
794
+ **Examples:**
795
+
796
+ .. code-block :: python
797
+
798
+ >>> embex.plot_embs(embs=embs,
799
+ ... plot_style="heatmap",
800
+ ... output_directory="path/to/output_directory",
801
+ ... output_prefix="output_prefix")
802
+
803
+ """
804
+
805
+ if plot_style not in ["heatmap", "umap"]:
806
+ logger.error(
807
+ "Invalid option for 'plot_style'. " "Valid options: {'heatmap','umap'}"
808
+ )
809
+ raise
810
+
811
+ if (plot_style == "umap") and (self.labels_to_plot is None):
812
+ logger.error("Plotting UMAP requires 'labels_to_plot'. ")
813
+ raise
814
+
815
+ if max_ncells_to_plot is not None:
816
+ if max_ncells_to_plot > self.max_ncells:
817
+ max_ncells_to_plot = self.max_ncells
818
+ logger.warning(
819
+ "max_ncells_to_plot must be <= max_ncells. "
820
+ f"Changing max_ncells_to_plot to {self.max_ncells}."
821
+ )
822
+ elif max_ncells_to_plot < self.max_ncells:
823
+ embs = embs.sample(max_ncells_to_plot, axis=0)
824
+
825
+ if self.emb_label is None:
826
+ label_len = 0
827
+ else:
828
+ label_len = len(self.emb_label)
829
+
830
+ emb_dims = embs.shape[1] - label_len
831
+
832
+ if self.emb_label is None:
833
+ emb_labels = None
834
+ else:
835
+ emb_labels = embs.columns[emb_dims:]
836
+
837
+ if plot_style == "umap":
838
+ for label in self.labels_to_plot:
839
+ if label not in emb_labels:
840
+ logger.warning(
841
+ f"Label {label} from labels_to_plot "
842
+ f"not present in provided embeddings dataframe."
843
+ )
844
+ continue
845
+ output_prefix_label = output_prefix + f"_umap_{label}"
846
+ output_file = (
847
+ Path(output_directory) / output_prefix_label
848
+ ).with_suffix(".pdf")
849
+ plot_umap(embs, emb_dims, label, output_file, kwargs_dict)
850
+
851
+ if plot_style == "heatmap":
852
+ for label in self.labels_to_plot:
853
+ if label not in emb_labels:
854
+ logger.warning(
855
+ f"Label {label} from labels_to_plot "
856
+ f"not present in provided embeddings dataframe."
857
+ )
858
+ continue
859
+ output_prefix_label = output_prefix + f"_heatmap_{label}"
860
+ output_file = (
861
+ Path(output_directory) / output_prefix_label
862
+ ).with_suffix(".pdf")
863
+ plot_heatmap(embs, emb_dims, label, output_file, kwargs_dict)
geneformer/evaluation_utils.py ADDED
@@ -0,0 +1,287 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import logging
2
+ import math
3
+ import pickle
4
+ from pathlib import Path
5
+
6
+ import matplotlib.pyplot as plt
7
+ import numpy as np
8
+ import pandas as pd
9
+ import seaborn as sns
10
+ import torch
11
+ from datasets.utils.logging import disable_progress_bar, enable_progress_bar
12
+ from sklearn import preprocessing
13
+ from sklearn.metrics import (
14
+ ConfusionMatrixDisplay,
15
+ accuracy_score,
16
+ auc,
17
+ confusion_matrix,
18
+ f1_score,
19
+ roc_curve,
20
+ )
21
+ from tqdm.auto import trange
22
+
23
+ from . import TOKEN_DICTIONARY_FILE
24
+ from .emb_extractor import make_colorbar
25
+
26
+ logger = logging.getLogger(__name__)
27
+
28
+
29
+ def preprocess_classifier_batch(cell_batch, max_len, label_name):
30
+ if max_len is None:
31
+ max_len = max([len(i) for i in cell_batch["input_ids"]])
32
+
33
+ # load token dictionary (Ensembl IDs:token)
34
+ with open(TOKEN_DICTIONARY_FILE, "rb") as f:
35
+ gene_token_dict = pickle.load(f)
36
+
37
+ def pad_label_example(example):
38
+ example[label_name] = np.pad(
39
+ example[label_name],
40
+ (0, max_len - len(example["input_ids"])),
41
+ mode="constant",
42
+ constant_values=-100,
43
+ )
44
+ example["input_ids"] = np.pad(
45
+ example["input_ids"],
46
+ (0, max_len - len(example["input_ids"])),
47
+ mode="constant",
48
+ constant_values=gene_token_dict.get("<pad>"),
49
+ )
50
+ example["attention_mask"] = (
51
+ example["input_ids"] != gene_token_dict.get("<pad>")
52
+ ).astype(int)
53
+ return example
54
+
55
+ padded_batch = cell_batch.map(pad_label_example)
56
+ return padded_batch
57
+
58
+
59
+ # Function to find the largest number smaller
60
+ # than or equal to N that is divisible by k
61
+ def find_largest_div(N, K):
62
+ rem = N % K
63
+ if rem == 0:
64
+ return N
65
+ else:
66
+ return N - rem
67
+
68
+
69
+ def vote(logit_list):
70
+ m = max(logit_list)
71
+ logit_list.index(m)
72
+ indices = [i for i, x in enumerate(logit_list) if x == m]
73
+ if len(indices) > 1:
74
+ return "tie"
75
+ else:
76
+ return indices[0]
77
+
78
+
79
+ def py_softmax(vector):
80
+ e = np.exp(vector)
81
+ return e / e.sum()
82
+
83
+
84
+ def classifier_predict(model, classifier_type, evalset, forward_batch_size):
85
+ if classifier_type == "gene":
86
+ label_name = "labels"
87
+ elif classifier_type == "cell":
88
+ label_name = "label"
89
+
90
+ predict_logits = []
91
+ predict_labels = []
92
+ model.eval()
93
+
94
+ # ensure there is at least 2 examples in each batch to avoid incorrect tensor dims
95
+ evalset_len = len(evalset)
96
+ max_divisible = find_largest_div(evalset_len, forward_batch_size)
97
+ if len(evalset) - max_divisible == 1:
98
+ evalset_len = max_divisible
99
+
100
+ max_evalset_len = max(evalset.select([i for i in range(evalset_len)])["length"])
101
+
102
+ disable_progress_bar() # disable progress bar for preprocess_classifier_batch mapping
103
+ for i in trange(0, evalset_len, forward_batch_size):
104
+ max_range = min(i + forward_batch_size, evalset_len)
105
+ batch_evalset = evalset.select([i for i in range(i, max_range)])
106
+ padded_batch = preprocess_classifier_batch(
107
+ batch_evalset, max_evalset_len, label_name
108
+ )
109
+ padded_batch.set_format(type="torch")
110
+
111
+ input_data_batch = padded_batch["input_ids"]
112
+ attn_msk_batch = padded_batch["attention_mask"]
113
+ label_batch = padded_batch[label_name]
114
+ with torch.no_grad():
115
+ outputs = model(
116
+ input_ids=input_data_batch.to("cuda"),
117
+ attention_mask=attn_msk_batch.to("cuda"),
118
+ labels=label_batch.to("cuda"),
119
+ )
120
+ predict_logits += [torch.squeeze(outputs.logits.to("cpu"))]
121
+ predict_labels += [torch.squeeze(label_batch.to("cpu"))]
122
+
123
+ enable_progress_bar()
124
+ logits_by_cell = torch.cat(predict_logits)
125
+ last_dim = len(logits_by_cell.shape) - 1
126
+ all_logits = logits_by_cell.reshape(-1, logits_by_cell.shape[last_dim])
127
+ labels_by_cell = torch.cat(predict_labels)
128
+ all_labels = torch.flatten(labels_by_cell)
129
+ logit_label_paired = [
130
+ item
131
+ for item in list(zip(all_logits.tolist(), all_labels.tolist()))
132
+ if item[1] != -100
133
+ ]
134
+ y_pred = [vote(item[0]) for item in logit_label_paired]
135
+ y_true = [item[1] for item in logit_label_paired]
136
+ logits_list = [item[0] for item in logit_label_paired]
137
+ return y_pred, y_true, logits_list
138
+
139
+
140
+ def get_metrics(y_pred, y_true, logits_list, num_classes, labels):
141
+ conf_mat = confusion_matrix(y_true, y_pred, labels=list(labels))
142
+ macro_f1 = f1_score(y_true, y_pred, average="macro")
143
+ acc = accuracy_score(y_true, y_pred)
144
+ roc_metrics = None # roc metrics not reported for multiclass
145
+ if num_classes == 2:
146
+ y_score = [py_softmax(item)[1] for item in logits_list]
147
+ fpr, tpr, _ = roc_curve(y_true, y_score)
148
+ mean_fpr = np.linspace(0, 1, 100)
149
+ interp_tpr = np.interp(mean_fpr, fpr, tpr)
150
+ interp_tpr[0] = 0.0
151
+ tpr_wt = len(tpr)
152
+ roc_auc = auc(fpr, tpr)
153
+ roc_metrics = {
154
+ "fpr": fpr,
155
+ "tpr": tpr,
156
+ "interp_tpr": interp_tpr,
157
+ "auc": roc_auc,
158
+ "tpr_wt": tpr_wt,
159
+ }
160
+ return conf_mat, macro_f1, acc, roc_metrics
161
+
162
+
163
+ # get cross-validated mean and sd metrics
164
+ def get_cross_valid_roc_metrics(all_tpr, all_roc_auc, all_tpr_wt):
165
+ wts = [count / sum(all_tpr_wt) for count in all_tpr_wt]
166
+ all_weighted_tpr = [a * b for a, b in zip(all_tpr, wts)]
167
+ mean_tpr = np.sum(all_weighted_tpr, axis=0)
168
+ mean_tpr[-1] = 1.0
169
+ all_weighted_roc_auc = [a * b for a, b in zip(all_roc_auc, wts)]
170
+ roc_auc = np.sum(all_weighted_roc_auc)
171
+ roc_auc_sd = math.sqrt(np.average((all_roc_auc - roc_auc) ** 2, weights=wts))
172
+ return mean_tpr, roc_auc, roc_auc_sd
173
+
174
+
175
+ # plot ROC curve
176
+ def plot_ROC(roc_metric_dict, model_style_dict, title, output_dir, output_prefix):
177
+ fig = plt.figure()
178
+ fig.set_size_inches(10, 8)
179
+ sns.set(font_scale=2)
180
+ sns.set_style("white")
181
+ lw = 3
182
+ for model_name in roc_metric_dict.keys():
183
+ mean_fpr = roc_metric_dict[model_name]["mean_fpr"]
184
+ mean_tpr = roc_metric_dict[model_name]["mean_tpr"]
185
+ roc_auc = roc_metric_dict[model_name]["roc_auc"]
186
+ roc_auc_sd = roc_metric_dict[model_name]["roc_auc_sd"]
187
+ color = model_style_dict[model_name]["color"]
188
+ linestyle = model_style_dict[model_name]["linestyle"]
189
+ if len(roc_metric_dict[model_name]["all_roc_auc"]) > 1:
190
+ label = f"{model_name} (AUC {roc_auc:0.2f} $\pm$ {roc_auc_sd:0.2f})"
191
+ else:
192
+ label = f"{model_name} (AUC {roc_auc:0.2f})"
193
+ plt.plot(
194
+ mean_fpr, mean_tpr, color=color, linestyle=linestyle, lw=lw, label=label
195
+ )
196
+
197
+ plt.plot([0, 1], [0, 1], color="black", lw=lw, linestyle="--")
198
+ plt.xlim([0.0, 1.0])
199
+ plt.ylim([0.0, 1.05])
200
+ plt.xlabel("False Positive Rate")
201
+ plt.ylabel("True Positive Rate")
202
+ plt.title(title)
203
+ plt.legend(loc="lower right")
204
+
205
+ output_file = (Path(output_dir) / f"{output_prefix}_roc").with_suffix(".pdf")
206
+ plt.savefig(output_file, bbox_inches="tight")
207
+ plt.show()
208
+
209
+
210
+ # plot confusion matrix
211
+ def plot_confusion_matrix(
212
+ conf_mat_df, title, output_dir, output_prefix, custom_class_order
213
+ ):
214
+ fig = plt.figure()
215
+ fig.set_size_inches(10, 10)
216
+ sns.set(font_scale=1)
217
+ sns.set_style("whitegrid", {"axes.grid": False})
218
+ if custom_class_order is not None:
219
+ conf_mat_df = conf_mat_df.reindex(
220
+ index=custom_class_order, columns=custom_class_order
221
+ )
222
+ display_labels = generate_display_labels(conf_mat_df)
223
+ conf_mat = preprocessing.normalize(conf_mat_df.to_numpy(), norm="l1")
224
+ display = ConfusionMatrixDisplay(
225
+ confusion_matrix=conf_mat, display_labels=display_labels
226
+ )
227
+ display.plot(cmap="Blues", values_format=".2g")
228
+ plt.title(title)
229
+ plt.show()
230
+
231
+ output_file = (Path(output_dir) / f"{output_prefix}_conf_mat").with_suffix(".pdf")
232
+ display.figure_.savefig(output_file, bbox_inches="tight")
233
+
234
+
235
+ def generate_display_labels(conf_mat_df):
236
+ display_labels = []
237
+ i = 0
238
+ for label in conf_mat_df.index:
239
+ display_labels += [f"{label}\nn={conf_mat_df.iloc[i,:].sum():.0f}"]
240
+ i = i + 1
241
+ return display_labels
242
+
243
+
244
+ def plot_predictions(predictions_df, title, output_dir, output_prefix, kwargs_dict):
245
+ sns.set(font_scale=2)
246
+ plt.figure(figsize=(10, 10), dpi=150)
247
+ label_colors, label_color_dict = make_colorbar(predictions_df, "true")
248
+ predictions_df = predictions_df.drop(columns=["true"])
249
+ predict_colors_list = [label_color_dict[label] for label in predictions_df.columns]
250
+ predict_label_list = [label for label in predictions_df.columns]
251
+ predict_colors = pd.DataFrame(
252
+ pd.Series(predict_colors_list, index=predict_label_list), columns=["predicted"]
253
+ )
254
+
255
+ default_kwargs_dict = {
256
+ "row_cluster": False,
257
+ "col_cluster": False,
258
+ "row_colors": label_colors,
259
+ "col_colors": predict_colors,
260
+ "linewidths": 0,
261
+ "xticklabels": False,
262
+ "yticklabels": False,
263
+ "center": 0,
264
+ "cmap": "vlag",
265
+ }
266
+
267
+ if kwargs_dict is not None:
268
+ default_kwargs_dict.update(kwargs_dict)
269
+ g = sns.clustermap(predictions_df, **default_kwargs_dict)
270
+
271
+ plt.setp(g.ax_row_colors.get_xmajorticklabels(), rotation=45, ha="right")
272
+
273
+ for label_color in list(label_color_dict.keys()):
274
+ g.ax_col_dendrogram.bar(
275
+ 0, 0, color=label_color_dict[label_color], label=label_color, linewidth=0
276
+ )
277
+
278
+ g.ax_col_dendrogram.legend(
279
+ title=f"{title}",
280
+ loc="lower center",
281
+ ncol=4,
282
+ bbox_to_anchor=(0.5, 1),
283
+ facecolor="white",
284
+ )
285
+
286
+ output_file = (Path(output_dir) / f"{output_prefix}_pred").with_suffix(".pdf")
287
+ plt.savefig(output_file, bbox_inches="tight")
geneformer/in_silico_perturber.py ADDED
@@ -0,0 +1,1579 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ Geneformer in silico perturber.
3
+
4
+ **Usage:**
5
+
6
+ .. code-block :: python
7
+
8
+ >>> from geneformer import InSilicoPerturber
9
+ >>> isp = InSilicoPerturber(perturb_type="delete",
10
+ ... perturb_rank_shift=None,
11
+ ... genes_to_perturb="all",
12
+ ... model_type="CellClassifier",
13
+ ... num_classes=0,
14
+ ... emb_mode="cell",
15
+ ... filter_data={"cell_type":["cardiomyocyte"]},
16
+ ... cell_states_to_model={"state_key": "disease", "start_state": "dcm", "goal_state": "nf", "alt_states": ["hcm", "other1", "other2"]},
17
+ ... state_embs_dict ={"nf": emb_nf, "hcm": emb_hcm, "dcm": emb_dcm, "other1": emb_other1, "other2": emb_other2},
18
+ ... max_ncells=None,
19
+ ... emb_layer=0,
20
+ ... forward_batch_size=100,
21
+ ... nproc=16)
22
+ >>> isp.perturb_data("path/to/model",
23
+ ... "path/to/input_data",
24
+ ... "path/to/output_directory",
25
+ ... "output_prefix")
26
+
27
+ **Description:**
28
+
29
+ | Performs in silico perturbation (e.g. deletion or overexpression) of defined set of genes or all genes in sample of cells.
30
+ | Outputs impact of perturbation on cell or gene embeddings.
31
+ | Output files are analyzed with ``in_silico_perturber_stats``.
32
+
33
+ """
34
+
35
+ import logging
36
+
37
+ # imports
38
+ import os
39
+ import pickle
40
+ from collections import defaultdict
41
+
42
+ import torch
43
+ from datasets import Dataset
44
+ from multiprocess import set_start_method
45
+ from tqdm.auto import trange
46
+
47
+ from . import TOKEN_DICTIONARY_FILE
48
+ from . import perturber_utils as pu
49
+ from .emb_extractor import get_embs
50
+
51
+ import datasets
52
+ datasets.logging.disable_progress_bar()
53
+
54
+
55
+ logger = logging.getLogger(__name__)
56
+
57
+
58
+ class InSilicoPerturber:
59
+ valid_option_dict = {
60
+ "perturb_type": {"delete", "overexpress", "inhibit", "activate"},
61
+ "perturb_rank_shift": {None, 1, 2, 3},
62
+ "genes_to_perturb": {"all", list},
63
+ "combos": {0, 1},
64
+ "anchor_gene": {None, str},
65
+ "model_type": {"Pretrained", "GeneClassifier", "CellClassifier", "MTLCellClassifier", "MTLCellClassifier-Quantized"},
66
+ "num_classes": {int},
67
+ "emb_mode": {"cls", "cell", "cls_and_gene", "cell_and_gene"},
68
+ "cell_emb_style": {"mean_pool"},
69
+ "filter_data": {None, dict},
70
+ "cell_states_to_model": {None, dict},
71
+ "state_embs_dict": {None, dict},
72
+ "max_ncells": {None, int},
73
+ "cell_inds_to_perturb": {"all", dict},
74
+ "emb_layer": {-1, 0},
75
+ "token_dictionary_file": {None, str},
76
+ "forward_batch_size": {int},
77
+ "nproc": {int},
78
+ }
79
+
80
+ def __init__(
81
+ self,
82
+ perturb_type="delete",
83
+ perturb_rank_shift=None,
84
+ genes_to_perturb="all",
85
+ combos=0,
86
+ anchor_gene=None,
87
+ model_type="Pretrained",
88
+ num_classes=0,
89
+ emb_mode="cls",
90
+ cell_emb_style="mean_pool",
91
+ filter_data=None,
92
+ cell_states_to_model=None,
93
+ state_embs_dict=None,
94
+ max_ncells=None,
95
+ cell_inds_to_perturb="all",
96
+ emb_layer=-1,
97
+ forward_batch_size=100,
98
+ nproc=4,
99
+ token_dictionary_file=None,
100
+ clear_mem_ncells=1000,
101
+ ):
102
+ """
103
+ Initialize in silico perturber.
104
+
105
+ **Parameters:**
106
+
107
+ perturb_type : {"delete", "overexpress", "inhibit", "activate"}
108
+ | Type of perturbation.
109
+ | "delete": delete gene from rank value encoding
110
+ | "overexpress": move gene to front of rank value encoding
111
+ | *(TBA)* "inhibit": move gene to lower quartile of rank value encoding
112
+ | *(TBA)* "activate": move gene to higher quartile of rank value encoding
113
+ *(TBA)* perturb_rank_shift : None, {1,2,3}
114
+ | Number of quartiles by which to shift rank of gene.
115
+ | For example, if perturb_type="activate" and perturb_rank_shift=1:
116
+ | genes in 4th quartile will move to middle of 3rd quartile.
117
+ | genes in 3rd quartile will move to middle of 2nd quartile.
118
+ | genes in 2nd quartile will move to middle of 1st quartile.
119
+ | genes in 1st quartile will move to front of rank value encoding.
120
+ | For example, if perturb_type="inhibit" and perturb_rank_shift=2:
121
+ | genes in 1st quartile will move to middle of 3rd quartile.
122
+ | genes in 2nd quartile will move to middle of 4th quartile.
123
+ | genes in 3rd or 4th quartile will move to bottom of rank value encoding.
124
+ genes_to_perturb : "all", list
125
+ | Default is perturbing each gene detected in each cell in the dataset.
126
+ | Otherwise, may provide a list of ENSEMBL IDs of genes to perturb.
127
+ | If gene list is provided, then perturber will only test perturbing them all together
128
+ | (rather than testing each possible combination of the provided genes).
129
+ combos : {0,1}
130
+ | Whether to perturb genes individually (0) or in pairs (1).
131
+ anchor_gene : None, str
132
+ | ENSEMBL ID of gene to use as anchor in combination perturbations.
133
+ | For example, if combos=1 and anchor_gene="ENSG00000148400":
134
+ | anchor gene will be perturbed in combination with each other gene.
135
+ model_type : {"Pretrained", "GeneClassifier", "CellClassifier", "MTLCellClassifier", "MTLCellClassifier-Quantized"}
136
+ | Whether model is the pretrained Geneformer or a fine-tuned gene, cell, or multitask cell classifier (+/- 8bit quantization).
137
+ num_classes : int
138
+ | If model is a gene or cell classifier, specify number of classes it was trained to classify.
139
+ | For the pretrained Geneformer model, number of classes is 0 as it is not a classifier.
140
+ emb_mode : {"cls", "cell", "cls_and_gene","cell_and_gene"}
141
+ | Whether to output impact of perturbation on CLS token, cell, and/or gene embeddings.
142
+ | Gene embedding shifts only available as compared to original cell, not comparing to goal state.
143
+ cell_emb_style : "mean_pool"
144
+ | Method for summarizing cell embeddings if not using CLS token.
145
+ | Currently only option is mean pooling of gene embeddings for given cell.
146
+ filter_data : None, dict
147
+ | Default is to use all input data for in silico perturbation study.
148
+ | Otherwise, dictionary specifying .dataset column name and list of values to filter by.
149
+ cell_states_to_model : None, dict
150
+ | Cell states to model if testing perturbations that achieve goal state change.
151
+ | Four-item dictionary with keys: state_key, start_state, goal_state, and alt_states
152
+ | state_key: key specifying name of column in .dataset that defines the start/goal states
153
+ | start_state: value in the state_key column that specifies the start state
154
+ | goal_state: value in the state_key column taht specifies the goal end state
155
+ | alt_states: list of values in the state_key column that specify the alternate end states
156
+ | For example: {"state_key": "disease",
157
+ | "start_state": "dcm",
158
+ | "goal_state": "nf",
159
+ | "alt_states": ["hcm", "other1", "other2"]}
160
+ state_embs_dict : None, dict
161
+ | Embedding positions of each cell state to model shifts from/towards (e.g. mean or median).
162
+ | Dictionary with keys specifying each possible cell state to model.
163
+ | Values are target embedding positions as torch.tensor.
164
+ | For example: {"nf": emb_nf,
165
+ | "hcm": emb_hcm,
166
+ | "dcm": emb_dcm,
167
+ | "other1": emb_other1,
168
+ | "other2": emb_other2}
169
+ max_ncells : None, int
170
+ | Maximum number of cells to test.
171
+ | If None, will test all cells.
172
+ cell_inds_to_perturb : "all", list
173
+ | Default is perturbing each cell in the dataset.
174
+ | Otherwise, may provide a dict of indices of cells to perturb with keys start_ind and end_ind.
175
+ | start_ind: the first index to perturb.
176
+ | end_ind: the last index to perturb (exclusive).
177
+ | Indices will be selected *after* the filter_data criteria and sorting.
178
+ | Useful for splitting extremely large datasets across separate GPUs.
179
+ emb_layer : {-1, 0}
180
+ | Embedding layer to use for quantification.
181
+ | 0: last layer (recommended for questions closely tied to model's training objective)
182
+ | -1: 2nd to last layer (recommended for questions requiring more general representations)
183
+ forward_batch_size : int
184
+ | Batch size for forward pass.
185
+ nproc : int
186
+ | Number of CPU processes to use.
187
+ token_dictionary_file : Path
188
+ | Path to pickle file containing token dictionary (Ensembl ID:token).
189
+ clear_mem_ncells : int
190
+ | Clear memory every n cells.
191
+ """
192
+ try:
193
+ set_start_method("spawn")
194
+ except RuntimeError:
195
+ pass
196
+
197
+ self.perturb_type = perturb_type
198
+ self.perturb_rank_shift = perturb_rank_shift
199
+ self.genes_to_perturb = genes_to_perturb
200
+ self.combos = combos
201
+ self.anchor_gene = anchor_gene
202
+ if self.genes_to_perturb == "all":
203
+ self.perturb_group = False
204
+ else:
205
+ self.perturb_group = True
206
+ if (self.anchor_gene is not None) or (self.combos != 0):
207
+ self.anchor_gene = None
208
+ self.combos = 0
209
+ logger.warning(
210
+ "anchor_gene set to None and combos set to 0. "
211
+ "If providing list of genes to perturb, "
212
+ "list of genes_to_perturb will be perturbed together, "
213
+ "without anchor gene or combinations."
214
+ )
215
+ self.model_type = model_type
216
+ self.num_classes = num_classes
217
+ self.emb_mode = emb_mode
218
+ self.cell_emb_style = cell_emb_style
219
+ self.filter_data = filter_data
220
+ self.cell_states_to_model = cell_states_to_model
221
+ self.state_embs_dict = state_embs_dict
222
+ self.max_ncells = max_ncells
223
+ self.cell_inds_to_perturb = cell_inds_to_perturb
224
+ self.emb_layer = emb_layer
225
+ self.forward_batch_size = forward_batch_size
226
+ self.nproc = nproc
227
+ self.token_dictionary_file = token_dictionary_file
228
+ self.clear_mem_ncells = clear_mem_ncells
229
+
230
+ self.validate_options()
231
+
232
+ # load token dictionary (Ensembl IDs:token)
233
+ if self.token_dictionary_file is None:
234
+ token_dictionary_file = TOKEN_DICTIONARY_FILE
235
+ with open(token_dictionary_file, "rb") as f:
236
+ self.gene_token_dict = pickle.load(f)
237
+ self.token_gene_dict = {v: k for k, v in self.gene_token_dict.items()}
238
+
239
+ self.pad_token_id = self.gene_token_dict.get("<pad>")
240
+ self.cls_token_id = self.gene_token_dict.get("<cls>")
241
+ self.eos_token_id = self.gene_token_dict.get("<eos>")
242
+
243
+ # Identify if special token is present in the token dictionary
244
+ if (self.cls_token_id is not None) and (self.eos_token_id is not None):
245
+ self.special_token = True
246
+ else:
247
+ if "cls" in self.emb_mode:
248
+ logger.error(
249
+ f"emb_mode set to {self.emb_mode} but <cls> or <eos> token not in token dictionary."
250
+ )
251
+ raise
252
+ self.special_token = False
253
+
254
+ if self.anchor_gene is None:
255
+ self.anchor_token = None
256
+ else:
257
+ try:
258
+ self.anchor_token = [self.gene_token_dict[self.anchor_gene]]
259
+ except KeyError:
260
+ logger.error(f"Anchor gene {self.anchor_gene} not in token dictionary.")
261
+ raise
262
+
263
+ if self.genes_to_perturb == "all":
264
+ self.tokens_to_perturb = "all"
265
+ else:
266
+ missing_genes = [
267
+ gene
268
+ for gene in self.genes_to_perturb
269
+ if gene not in self.gene_token_dict.keys()
270
+ ]
271
+ if len(missing_genes) == len(self.genes_to_perturb):
272
+ logger.error(
273
+ "None of the provided genes to perturb are in token dictionary."
274
+ )
275
+ raise
276
+ elif len(missing_genes) > 0:
277
+ logger.warning(
278
+ f"Genes to perturb {missing_genes} are not in token dictionary."
279
+ )
280
+ self.tokens_to_perturb = [
281
+ self.gene_token_dict.get(gene) for gene in self.genes_to_perturb
282
+ ]
283
+
284
+ def validate_options(self):
285
+ # first disallow options under development
286
+ if self.perturb_type in ["inhibit", "activate"]:
287
+ logger.error(
288
+ "In silico inhibition and activation currently under development. "
289
+ "Current valid options for 'perturb_type': 'delete' or 'overexpress'"
290
+ )
291
+ raise
292
+ if (self.combos > 0) and (self.anchor_gene is None):
293
+ logger.error(
294
+ "Combination perturbation without anchor gene is currently under development. "
295
+ "Currently, must provide anchor gene for combination perturbation."
296
+ )
297
+ raise
298
+
299
+ # confirm arguments are within valid options and compatible with each other
300
+ for attr_name, valid_options in self.valid_option_dict.items():
301
+ attr_value = self.__dict__[attr_name]
302
+ if type(attr_value) not in {list, dict}:
303
+ if attr_value in valid_options:
304
+ continue
305
+ if attr_name in ["anchor_gene"]:
306
+ if type(attr_name) in {str}:
307
+ continue
308
+ valid_type = False
309
+ for option in valid_options:
310
+ if (option in [bool, int, list, dict, str]) and isinstance(
311
+ attr_value, option
312
+ ):
313
+ valid_type = True
314
+ break
315
+ if valid_type:
316
+ continue
317
+ logger.error(
318
+ f"Invalid option for {attr_name}. "
319
+ f"Valid options for {attr_name}: {valid_options}"
320
+ )
321
+ raise
322
+
323
+ if self.perturb_type in ["delete", "overexpress"]:
324
+ if self.perturb_rank_shift is not None:
325
+ if self.perturb_type == "delete":
326
+ logger.warning(
327
+ "perturb_rank_shift set to None. "
328
+ "If perturb type is delete then gene is deleted entirely "
329
+ "rather than shifted by quartile"
330
+ )
331
+ elif self.perturb_type == "overexpress":
332
+ logger.warning(
333
+ "perturb_rank_shift set to None. "
334
+ "If perturb type is overexpress then gene is moved to front "
335
+ "of rank value encoding rather than shifted by quartile"
336
+ )
337
+ self.perturb_rank_shift = None
338
+
339
+ if (self.anchor_gene is not None) and (self.emb_mode == "cell_and_gene"):
340
+ self.emb_mode = "cell"
341
+ logger.warning(
342
+ "emb_mode set to 'cell'. "
343
+ "Currently, analysis with anchor gene "
344
+ "only outputs effect on cell embeddings."
345
+ )
346
+
347
+ if self.cell_states_to_model is not None:
348
+ pu.validate_cell_states_to_model(self.cell_states_to_model)
349
+
350
+ if self.anchor_gene is not None:
351
+ self.anchor_gene = None
352
+ logger.warning(
353
+ "anchor_gene set to None. "
354
+ "Currently, anchor gene not available "
355
+ "when modeling multiple cell states."
356
+ )
357
+
358
+ if self.state_embs_dict is None:
359
+ logger.error(
360
+ "state_embs_dict must be provided for mode with cell_states_to_model. "
361
+ "Format is dictionary with keys specifying each possible cell state to model. "
362
+ "Values are target embedding positions as torch.tensor."
363
+ )
364
+ raise
365
+
366
+ for state_emb in self.state_embs_dict.values():
367
+ if not torch.is_tensor(state_emb):
368
+ logger.error(
369
+ "state_embs_dict must be dictionary with values being torch.tensor."
370
+ )
371
+ raise
372
+
373
+ keys_absent = []
374
+ for k, v in self.cell_states_to_model.items():
375
+ if (k == "start_state") or (k == "goal_state"):
376
+ if v not in self.state_embs_dict.keys():
377
+ keys_absent.append(v)
378
+ if k == "alt_states":
379
+ for state in v:
380
+ if state not in self.state_embs_dict.keys():
381
+ keys_absent.append(state)
382
+ if len(keys_absent) > 0:
383
+ logger.error(
384
+ "Each start_state, goal_state, and alt_states in cell_states_to_model "
385
+ "must be a key in state_embs_dict with the value being "
386
+ "the state's embedding position as torch.tensor. "
387
+ f"Missing keys: {keys_absent}"
388
+ )
389
+ raise
390
+
391
+ if self.perturb_type in ["inhibit", "activate"]:
392
+ if self.perturb_rank_shift is None:
393
+ logger.error(
394
+ "If perturb_type is inhibit or activate then "
395
+ "quartile to shift by must be specified."
396
+ )
397
+ raise
398
+
399
+ if self.filter_data is not None:
400
+ for key, value in self.filter_data.items():
401
+ if not isinstance(value, list):
402
+ self.filter_data[key] = [value]
403
+ logger.warning(
404
+ "Values in filter_data dict must be lists. "
405
+ f"Changing {key} value to list ([{value}])."
406
+ )
407
+
408
+ if self.cell_inds_to_perturb != "all":
409
+ if set(self.cell_inds_to_perturb.keys()) != {"start", "end"}:
410
+ logger.error(
411
+ "If cell_inds_to_perturb is a dictionary, keys must be 'start' and 'end'."
412
+ )
413
+ raise
414
+ if (
415
+ self.cell_inds_to_perturb["start"] < 0
416
+ or self.cell_inds_to_perturb["end"] < 0
417
+ ):
418
+ logger.error("cell_inds_to_perturb must be positive.")
419
+ raise
420
+
421
+ def perturb_data(
422
+ self, model_directory, input_data_file, output_directory, output_prefix
423
+ ):
424
+ """
425
+ Perturb genes in input data and save as results in output_directory.
426
+
427
+ **Parameters:**
428
+
429
+ model_directory : Path
430
+ | Path to directory containing model
431
+ input_data_file : Path
432
+ | Path to directory containing .dataset inputs
433
+ output_directory : Path
434
+ | Path to directory where perturbation data will be saved as batched pickle files
435
+ output_prefix : str
436
+ | Prefix for output files
437
+ """
438
+
439
+ ### format output path ###
440
+ output_path_prefix = os.path.join(
441
+ output_directory, f"in_silico_{self.perturb_type}_{output_prefix}"
442
+ )
443
+
444
+ ### load model and define parameters ###
445
+ model = pu.load_model(
446
+ self.model_type, self.num_classes, model_directory, mode="eval"
447
+ )
448
+ self.max_len = pu.get_model_input_size(model)
449
+ layer_to_quant = pu.quant_layers(model) + self.emb_layer
450
+
451
+ ### filter input data ###
452
+ # general filtering of input data based on filter_data argument
453
+ filtered_input_data = pu.load_and_filter(
454
+ self.filter_data, self.nproc, input_data_file
455
+ )
456
+
457
+ # Ensure emb_mode is cls if first token of the filtered input data is cls token
458
+ if self.special_token:
459
+ if (filtered_input_data["input_ids"][0][0] == self.cls_token_id) and (
460
+ "cls" not in self.emb_mode
461
+ ):
462
+ logger.error(
463
+ "Emb mode 'cls' or 'cls_and_gene' required when first token is <cls>."
464
+ )
465
+ raise
466
+ if "cls" in self.emb_mode:
467
+ if (filtered_input_data["input_ids"][0][0] != self.cls_token_id) or (
468
+ filtered_input_data["input_ids"][0][-1] != self.eos_token_id
469
+ ):
470
+ logger.error(
471
+ "Emb mode 'cls' and 'cls_and_gene' require that first token is <cls> and last token is <eos>."
472
+ )
473
+ raise
474
+
475
+ filtered_input_data = self.apply_additional_filters(filtered_input_data)
476
+
477
+ if self.perturb_group is True:
478
+ if (self.special_token) and ("cls" in self.emb_mode):
479
+ self.isp_perturb_set_special(
480
+ model, filtered_input_data, layer_to_quant, output_path_prefix
481
+ )
482
+ else:
483
+ self.isp_perturb_set(
484
+ model, filtered_input_data, layer_to_quant, output_path_prefix
485
+ )
486
+ else:
487
+ if (self.special_token) and ("cls" in self.emb_mode):
488
+ self.isp_perturb_all_special(
489
+ model, filtered_input_data, layer_to_quant, output_path_prefix
490
+ )
491
+ else:
492
+ self.isp_perturb_all(
493
+ model, filtered_input_data, layer_to_quant, output_path_prefix
494
+ )
495
+
496
+ def apply_additional_filters(self, filtered_input_data):
497
+ # additional filtering of input data dependent on isp mode
498
+ if self.cell_states_to_model is not None:
499
+ # filter for cells with start_state and log result
500
+ filtered_input_data = pu.filter_data_by_start_state(
501
+ filtered_input_data, self.cell_states_to_model, self.nproc
502
+ )
503
+
504
+ if (self.tokens_to_perturb != "all") and (self.perturb_type != "overexpress"):
505
+ # filter for cells with tokens_to_perturb and log result
506
+ filtered_input_data = pu.filter_data_by_tokens_and_log(
507
+ filtered_input_data,
508
+ self.tokens_to_perturb,
509
+ self.nproc,
510
+ "genes_to_perturb",
511
+ )
512
+
513
+ if self.anchor_token is not None:
514
+ # filter for cells with anchor gene and log result
515
+ filtered_input_data = pu.filter_data_by_tokens_and_log(
516
+ filtered_input_data, self.anchor_token, self.nproc, "anchor_gene"
517
+ )
518
+
519
+ # downsample and sort largest to smallest to encounter memory constraints earlier
520
+ filtered_input_data = pu.downsample_and_sort(
521
+ filtered_input_data, self.max_ncells
522
+ )
523
+
524
+ # slice dataset if cells_inds_to_perturb is not "all"
525
+ if self.cell_inds_to_perturb != "all":
526
+ filtered_input_data = pu.slice_by_inds_to_perturb(
527
+ filtered_input_data, self.cell_inds_to_perturb
528
+ )
529
+
530
+ return filtered_input_data
531
+
532
+ def isp_perturb_set(
533
+ self,
534
+ model,
535
+ filtered_input_data: Dataset,
536
+ layer_to_quant: int,
537
+ output_path_prefix: str,
538
+ ):
539
+ def make_group_perturbation_batch(example):
540
+ example_input_ids = example["input_ids"]
541
+ example["tokens_to_perturb"] = self.tokens_to_perturb
542
+ indices_to_perturb = [
543
+ example_input_ids.index(token) if token in example_input_ids else None
544
+ for token in self.tokens_to_perturb
545
+ ]
546
+ indices_to_perturb = [
547
+ item for item in indices_to_perturb if item is not None
548
+ ]
549
+ if len(indices_to_perturb) > 0:
550
+ example["perturb_index"] = indices_to_perturb
551
+ else:
552
+ # -100 indicates tokens to overexpress are not present in rank value encoding
553
+ example["perturb_index"] = [-100]
554
+ if self.perturb_type == "delete":
555
+ example = pu.delete_indices(example)
556
+ elif self.perturb_type == "overexpress":
557
+ example = pu.overexpress_tokens(
558
+ example, self.max_len, self.special_token
559
+ )
560
+ example["n_overflow"] = pu.calc_n_overflow(
561
+ self.max_len,
562
+ example["length"],
563
+ self.tokens_to_perturb,
564
+ indices_to_perturb,
565
+ )
566
+ return example
567
+
568
+ total_batch_length = len(filtered_input_data)
569
+ if self.cell_states_to_model is None:
570
+ cos_sims_dict = defaultdict(list)
571
+ else:
572
+ cos_sims_dict = {
573
+ state: defaultdict(list)
574
+ for state in pu.get_possible_states(self.cell_states_to_model)
575
+ }
576
+
577
+ perturbed_data = filtered_input_data.map(
578
+ make_group_perturbation_batch, num_proc=self.nproc
579
+ )
580
+
581
+ if self.perturb_type == "overexpress":
582
+ filtered_input_data = filtered_input_data.add_column(
583
+ "n_overflow", perturbed_data["n_overflow"]
584
+ )
585
+ # remove overflow genes from original data so that embeddings are comparable
586
+ # i.e. if original cell has genes 0:2047 and you want to overexpress new gene 2048,
587
+ # then the perturbed cell will be 2048+0:2046 so we compare it to an original cell 0:2046.
588
+ # (otherwise we will be modeling the effect of both deleting 2047 and adding 2048,
589
+ # rather than only adding 2048)
590
+ filtered_input_data = filtered_input_data.map(
591
+ pu.truncate_by_n_overflow, num_proc=self.nproc
592
+ )
593
+
594
+ if self.emb_mode == "cell_and_gene":
595
+ stored_gene_embs_dict = defaultdict(list)
596
+
597
+ # iterate through batches
598
+ for i in trange(0, total_batch_length, self.forward_batch_size):
599
+ max_range = min(i + self.forward_batch_size, total_batch_length)
600
+ inds_select = [i for i in range(i, max_range)]
601
+
602
+ minibatch = filtered_input_data.select(inds_select)
603
+ perturbation_batch = perturbed_data.select(inds_select)
604
+
605
+ if self.cell_emb_style == "mean_pool":
606
+ full_original_emb = get_embs(
607
+ model,
608
+ minibatch,
609
+ "gene",
610
+ layer_to_quant,
611
+ self.pad_token_id,
612
+ self.forward_batch_size,
613
+ token_gene_dict=self.token_gene_dict,
614
+ summary_stat=None,
615
+ silent=True,
616
+ )
617
+ indices_to_perturb = perturbation_batch["perturb_index"]
618
+ # remove indices that were perturbed
619
+ original_emb = pu.remove_perturbed_indices_set(
620
+ full_original_emb,
621
+ self.perturb_type,
622
+ indices_to_perturb,
623
+ self.tokens_to_perturb,
624
+ minibatch["length"],
625
+ )
626
+ full_perturbation_emb = get_embs(
627
+ model,
628
+ perturbation_batch,
629
+ "gene",
630
+ layer_to_quant,
631
+ self.pad_token_id,
632
+ self.forward_batch_size,
633
+ token_gene_dict=self.token_gene_dict,
634
+ summary_stat=None,
635
+ silent=True,
636
+ )
637
+
638
+ # remove overexpressed genes
639
+ if self.perturb_type == "overexpress":
640
+ perturbation_emb = full_perturbation_emb[
641
+ :, len(self.tokens_to_perturb) :, :
642
+ ]
643
+
644
+ elif self.perturb_type == "delete":
645
+ perturbation_emb = full_perturbation_emb[
646
+ :, : max(perturbation_batch["length"]), :
647
+ ]
648
+
649
+ n_perturbation_genes = perturbation_emb.size()[1]
650
+
651
+ # if no goal states, the cosine similarties are the mean of gene cosine similarities
652
+ if (
653
+ self.cell_states_to_model is None
654
+ or self.emb_mode == "cell_and_gene"
655
+ ):
656
+ gene_cos_sims = pu.quant_cos_sims(
657
+ perturbation_emb,
658
+ original_emb,
659
+ self.cell_states_to_model,
660
+ self.state_embs_dict,
661
+ emb_mode="gene",
662
+ )
663
+
664
+ # if there are goal states, the cosine similarities are the cell cosine similarities
665
+ if self.cell_states_to_model is not None:
666
+ original_cell_emb = pu.mean_nonpadding_embs(
667
+ full_original_emb,
668
+ torch.tensor(minibatch["length"], device="cuda"),
669
+ dim=1,
670
+ )
671
+ perturbation_cell_emb = pu.mean_nonpadding_embs(
672
+ full_perturbation_emb,
673
+ torch.tensor(perturbation_batch["length"], device="cuda"),
674
+ dim=1,
675
+ )
676
+ cell_cos_sims = pu.quant_cos_sims(
677
+ perturbation_cell_emb,
678
+ original_cell_emb,
679
+ self.cell_states_to_model,
680
+ self.state_embs_dict,
681
+ emb_mode="cell",
682
+ )
683
+
684
+ # get cosine similarities in gene embeddings
685
+ # if getting gene embeddings, need gene names
686
+ if self.emb_mode == "cell_and_gene":
687
+ gene_list = minibatch["input_ids"]
688
+ # need to truncate gene_list
689
+ gene_list = [
690
+ [g for g in genes if g not in self.tokens_to_perturb][
691
+ :n_perturbation_genes
692
+ ]
693
+ for genes in gene_list
694
+ ]
695
+
696
+ for cell_i, genes in enumerate(gene_list):
697
+ for gene_j, affected_gene in enumerate(genes):
698
+ if len(self.genes_to_perturb) > 1:
699
+ tokens_to_perturb = tuple(self.tokens_to_perturb)
700
+ else:
701
+ tokens_to_perturb = self.tokens_to_perturb[0]
702
+
703
+ # fill in the gene cosine similarities
704
+ try:
705
+ stored_gene_embs_dict[
706
+ (tokens_to_perturb, affected_gene)
707
+ ].append(gene_cos_sims[cell_i, gene_j].item())
708
+ except KeyError:
709
+ stored_gene_embs_dict[
710
+ (tokens_to_perturb, affected_gene)
711
+ ] = gene_cos_sims[cell_i, gene_j].item()
712
+ else:
713
+ gene_list = None
714
+
715
+ if self.cell_states_to_model is None:
716
+ # calculate the mean of the gene cosine similarities for cell shift
717
+ # tensor of nonpadding lengths for each cell
718
+ if self.perturb_type == "overexpress":
719
+ # subtract number of genes that were overexpressed
720
+ # since they are removed before getting cos sims
721
+ n_overexpressed = len(self.tokens_to_perturb)
722
+ nonpadding_lens = [
723
+ x - n_overexpressed for x in perturbation_batch["length"]
724
+ ]
725
+ else:
726
+ nonpadding_lens = perturbation_batch["length"]
727
+ cos_sims_data = pu.mean_nonpadding_embs(
728
+ gene_cos_sims, torch.tensor(nonpadding_lens, device="cuda")
729
+ )
730
+ cos_sims_dict = self.update_perturbation_dictionary(
731
+ cos_sims_dict,
732
+ cos_sims_data,
733
+ gene_list,
734
+ )
735
+ else:
736
+ cos_sims_data = cell_cos_sims
737
+ for state in cos_sims_dict.keys():
738
+ cos_sims_dict[state] = self.update_perturbation_dictionary(
739
+ cos_sims_dict[state],
740
+ cos_sims_data[state],
741
+ gene_list,
742
+ )
743
+ del minibatch
744
+ del perturbation_batch
745
+ del original_emb
746
+ del perturbation_emb
747
+ del cos_sims_data
748
+
749
+ torch.cuda.empty_cache()
750
+
751
+ pu.write_perturbation_dictionary(
752
+ cos_sims_dict,
753
+ f"{output_path_prefix}_cell_embs_dict_{self.tokens_to_perturb}",
754
+ )
755
+
756
+ if self.emb_mode == "cell_and_gene":
757
+ pu.write_perturbation_dictionary(
758
+ stored_gene_embs_dict,
759
+ f"{output_path_prefix}_gene_embs_dict_{self.tokens_to_perturb}",
760
+ )
761
+
762
+ def isp_perturb_set_special(
763
+ self,
764
+ model,
765
+ filtered_input_data: Dataset,
766
+ layer_to_quant: int,
767
+ output_path_prefix: str,
768
+ ):
769
+ def make_group_perturbation_batch(example):
770
+ example_input_ids = example["input_ids"]
771
+ example["tokens_to_perturb"] = self.tokens_to_perturb
772
+ indices_to_perturb = [
773
+ example_input_ids.index(token) if token in example_input_ids else None
774
+ for token in self.tokens_to_perturb
775
+ ]
776
+ indices_to_perturb = [
777
+ item for item in indices_to_perturb if item is not None
778
+ ]
779
+ if len(indices_to_perturb) > 0:
780
+ example["perturb_index"] = indices_to_perturb
781
+ else:
782
+ # -100 indicates tokens to overexpress are not present in rank value encoding
783
+ example["perturb_index"] = [-100]
784
+ if self.perturb_type == "delete":
785
+ example = pu.delete_indices(example)
786
+ elif self.perturb_type == "overexpress":
787
+ example = pu.overexpress_tokens(
788
+ example, self.max_len, self.special_token
789
+ )
790
+ example["n_overflow"] = pu.calc_n_overflow(
791
+ self.max_len,
792
+ example["length"],
793
+ self.tokens_to_perturb,
794
+ indices_to_perturb,
795
+ )
796
+ return example
797
+
798
+ total_batch_length = len(filtered_input_data)
799
+
800
+
801
+ if self.cell_states_to_model is None:
802
+ cos_sims_dict = defaultdict(list)
803
+ else:
804
+ cos_sims_dict = {
805
+ state: defaultdict(list)
806
+ for state in pu.get_possible_states(self.cell_states_to_model)
807
+ }
808
+
809
+ perturbed_data = filtered_input_data.map(
810
+ make_group_perturbation_batch, num_proc=self.nproc
811
+ )
812
+
813
+ if self.perturb_type == "overexpress":
814
+ filtered_input_data = filtered_input_data.add_column(
815
+ "n_overflow", perturbed_data["n_overflow"]
816
+ )
817
+ filtered_input_data = filtered_input_data.map(
818
+ pu.truncate_by_n_overflow_special, num_proc=self.nproc
819
+ )
820
+
821
+ if self.emb_mode == "cls_and_gene":
822
+ stored_gene_embs_dict = defaultdict(list)
823
+
824
+ # iterate through batches
825
+ for i in trange(0, total_batch_length, self.forward_batch_size):
826
+ max_range = min(i + self.forward_batch_size, total_batch_length)
827
+ inds_select = [i for i in range(i, max_range)]
828
+
829
+ minibatch = filtered_input_data.select(inds_select)
830
+ perturbation_batch = perturbed_data.select(inds_select)
831
+
832
+ ##### CLS Embedding Mode #####
833
+ if self.emb_mode == "cls":
834
+ indices_to_perturb = perturbation_batch["perturb_index"]
835
+
836
+ original_cls_emb = get_embs(
837
+ model,
838
+ minibatch,
839
+ "cls",
840
+ layer_to_quant,
841
+ self.pad_token_id,
842
+ self.forward_batch_size,
843
+ token_gene_dict=self.token_gene_dict,
844
+ summary_stat=None,
845
+ silent=True,
846
+ )
847
+
848
+ perturbation_cls_emb = get_embs(
849
+ model,
850
+ perturbation_batch,
851
+ "cls",
852
+ layer_to_quant,
853
+ self.pad_token_id,
854
+ self.forward_batch_size,
855
+ token_gene_dict=self.token_gene_dict,
856
+ summary_stat=None,
857
+ silent=True,
858
+ )
859
+
860
+ # Calculate the cosine similarities
861
+ cls_cos_sims = pu.quant_cos_sims(
862
+ perturbation_cls_emb,
863
+ original_cls_emb,
864
+ self.cell_states_to_model,
865
+ self.state_embs_dict,
866
+ emb_mode="cell",
867
+ )
868
+
869
+ # Update perturbation dictionary
870
+ if self.cell_states_to_model is None:
871
+ cos_sims_dict = self.update_perturbation_dictionary(
872
+ cos_sims_dict,
873
+ cls_cos_sims,
874
+ gene_list=None,
875
+ )
876
+ else:
877
+ for state in cos_sims_dict.keys():
878
+ cos_sims_dict[state] = self.update_perturbation_dictionary(
879
+ cos_sims_dict[state],
880
+ cls_cos_sims[state],
881
+ gene_list=None,
882
+ )
883
+
884
+ ##### CLS and Gene Embedding Mode #####
885
+ elif self.emb_mode == "cls_and_gene":
886
+ full_original_emb = get_embs(
887
+ model,
888
+ minibatch,
889
+ "gene",
890
+ layer_to_quant,
891
+ self.pad_token_id,
892
+ self.forward_batch_size,
893
+ self.token_gene_dict,
894
+ summary_stat=None,
895
+ silent=True,
896
+ )
897
+ indices_to_perturb = perturbation_batch["perturb_index"]
898
+
899
+ # remove indices that were perturbed
900
+ original_emb = pu.remove_perturbed_indices_set(
901
+ full_original_emb,
902
+ self.perturb_type,
903
+ indices_to_perturb,
904
+ self.tokens_to_perturb,
905
+ minibatch["length"],
906
+ )
907
+
908
+ full_perturbation_emb = get_embs(
909
+ model,
910
+ perturbation_batch,
911
+ "gene",
912
+ layer_to_quant,
913
+ self.pad_token_id,
914
+ self.forward_batch_size,
915
+ self.token_gene_dict,
916
+ summary_stat=None,
917
+ silent=True,
918
+ )
919
+
920
+ # remove special tokens and padding
921
+ original_emb = original_emb[:, 1:-1, :]
922
+ if self.perturb_type == "overexpress":
923
+ perturbation_emb = full_perturbation_emb[
924
+ :, 1 + len(self.tokens_to_perturb) : -1, :
925
+ ]
926
+ elif self.perturb_type == "delete":
927
+ perturbation_emb = full_perturbation_emb[
928
+ :, 1 : max(perturbation_batch["length"]) - 1, :
929
+ ]
930
+
931
+ n_perturbation_genes = perturbation_emb.size()[1]
932
+
933
+ # truncate the original embedding as necessary
934
+ if self.perturb_type == "overexpress":
935
+ def calc_perturbation_length(ids):
936
+ if ids == [-100]:
937
+ return 0
938
+ else:
939
+ return len(ids)
940
+
941
+ max_tensor_size = max([length - calc_perturbation_length(ids) - 2 for length, ids in zip(minibatch["length"], indices_to_perturb)])
942
+
943
+ max_n_overflow = max(minibatch["n_overflow"])
944
+ if max_n_overflow > 0 and perturbation_emb.size()[1] < original_emb.size()[1]:
945
+ original_emb = original_emb[:, 0 : perturbation_emb.size()[1], :]
946
+ elif perturbation_emb.size()[1] < original_emb.size()[1]:
947
+ original_emb = original_emb[:, 0:max_tensor_size, :]
948
+
949
+ gene_cos_sims = pu.quant_cos_sims(
950
+ perturbation_emb,
951
+ original_emb,
952
+ self.cell_states_to_model,
953
+ self.state_embs_dict,
954
+ emb_mode="gene",
955
+ )
956
+
957
+ # get cls emb
958
+ original_cls_emb = full_original_emb[:, 0, :]
959
+ perturbation_cls_emb = full_perturbation_emb[:, 0, :]
960
+
961
+ cls_cos_sims = pu.quant_cos_sims(
962
+ perturbation_cls_emb,
963
+ original_cls_emb,
964
+ self.cell_states_to_model,
965
+ self.state_embs_dict,
966
+ emb_mode="cell",
967
+ )
968
+
969
+ # get cosine similarities in gene embeddings
970
+ # since getting gene embeddings, need gene names
971
+
972
+ gene_list = minibatch["input_ids"]
973
+ # need to truncate gene_list
974
+ genes_to_exclude = self.tokens_to_perturb + [
975
+ self.cls_token_id,
976
+ self.eos_token_id,
977
+ ]
978
+ gene_list = [
979
+ [g for g in genes if g not in genes_to_exclude][
980
+ :n_perturbation_genes
981
+ ]
982
+ for genes in gene_list
983
+ ]
984
+
985
+ for cell_i, genes in enumerate(gene_list):
986
+ for gene_j, affected_gene in enumerate(genes):
987
+ if len(self.genes_to_perturb) > 1:
988
+ tokens_to_perturb = tuple(self.tokens_to_perturb)
989
+ else:
990
+ tokens_to_perturb = self.tokens_to_perturb[0]
991
+
992
+ # fill in the gene cosine similarities
993
+ try:
994
+ stored_gene_embs_dict[
995
+ (tokens_to_perturb, affected_gene)
996
+ ].append(gene_cos_sims[cell_i, gene_j].item())
997
+ except KeyError:
998
+ stored_gene_embs_dict[
999
+ (tokens_to_perturb, affected_gene)
1000
+ ] = gene_cos_sims[cell_i, gene_j].item()
1001
+
1002
+ if self.cell_states_to_model is None:
1003
+ cos_sims_dict = self.update_perturbation_dictionary(
1004
+ cos_sims_dict,
1005
+ cls_cos_sims,
1006
+ gene_list=None,
1007
+ )
1008
+ else:
1009
+ for state in cos_sims_dict.keys():
1010
+ cos_sims_dict[state] = self.update_perturbation_dictionary(
1011
+ cos_sims_dict[state],
1012
+ cls_cos_sims[state],
1013
+ gene_list=None,
1014
+ )
1015
+ del full_original_emb
1016
+ del original_emb
1017
+ del full_perturbation_emb
1018
+ del perturbation_emb
1019
+ del gene_cos_sims
1020
+
1021
+ del original_cls_emb
1022
+ del perturbation_cls_emb
1023
+ del cls_cos_sims
1024
+ del minibatch
1025
+ del perturbation_batch
1026
+
1027
+ torch.cuda.empty_cache()
1028
+
1029
+ pu.write_perturbation_dictionary(
1030
+ cos_sims_dict,
1031
+ f"{output_path_prefix}_cell_embs_dict_{self.tokens_to_perturb}",
1032
+ )
1033
+
1034
+ if self.emb_mode == "cls_and_gene":
1035
+ pu.write_perturbation_dictionary(
1036
+ stored_gene_embs_dict,
1037
+ f"{output_path_prefix}_gene_embs_dict_{self.tokens_to_perturb}",
1038
+ )
1039
+
1040
+ def isp_perturb_all(
1041
+ self,
1042
+ model,
1043
+ filtered_input_data: Dataset,
1044
+ layer_to_quant: int,
1045
+ output_path_prefix: str,
1046
+ ):
1047
+ pickle_batch = -1
1048
+ if self.cell_states_to_model is None:
1049
+ cos_sims_dict = defaultdict(list)
1050
+ else:
1051
+ cos_sims_dict = {
1052
+ state: defaultdict(list)
1053
+ for state in pu.get_possible_states(self.cell_states_to_model)
1054
+ }
1055
+
1056
+ if self.emb_mode == "cell_and_gene":
1057
+ stored_gene_embs_dict = defaultdict(list)
1058
+
1059
+ num_inds_perturbed = 1 + self.combos
1060
+ for h in trange(len(filtered_input_data)):
1061
+ example_cell = filtered_input_data.select([h])
1062
+ full_original_emb = get_embs(
1063
+ model,
1064
+ example_cell,
1065
+ "gene",
1066
+ layer_to_quant,
1067
+ self.pad_token_id,
1068
+ self.forward_batch_size,
1069
+ self.token_gene_dict,
1070
+ summary_stat=None,
1071
+ silent=True,
1072
+ )
1073
+
1074
+ if self.cell_states_to_model is not None:
1075
+ original_cell_emb = pu.compute_nonpadded_cell_embedding(
1076
+ full_original_emb, "mean_pool"
1077
+ )
1078
+
1079
+ # gene_list is used to assign cos sims back to genes
1080
+ gene_list = example_cell["input_ids"][0][:]
1081
+ # need to remove the anchor gene
1082
+ if self.anchor_token is not None:
1083
+ for token in self.anchor_token:
1084
+ gene_list.remove(token)
1085
+ # index 0 is not overexpressed so remove
1086
+ if self.perturb_type == "overexpress":
1087
+ gene_list = gene_list[num_inds_perturbed:]
1088
+ # remove perturbed index for gene list dict
1089
+ perturbed_gene_dict = {
1090
+ gene: gene_list[:i] + gene_list[i + 1 :]
1091
+ for i, gene in enumerate(gene_list)
1092
+ }
1093
+
1094
+ perturbation_batch, indices_to_perturb = pu.make_perturbation_batch(
1095
+ example_cell,
1096
+ self.perturb_type,
1097
+ self.tokens_to_perturb,
1098
+ self.anchor_token,
1099
+ self.combos,
1100
+ self.nproc,
1101
+ )
1102
+
1103
+ ispall_total_batch_length = len(perturbation_batch)
1104
+ for i in trange(
1105
+ 0, ispall_total_batch_length, self.forward_batch_size, leave=False
1106
+ ):
1107
+ ispall_max_range = min(
1108
+ i + self.forward_batch_size, ispall_total_batch_length
1109
+ )
1110
+ perturbation_minibatch = perturbation_batch.select(
1111
+ [i for i in range(i, ispall_max_range)]
1112
+ )
1113
+ indices_to_perturb_mini = indices_to_perturb[i:ispall_max_range]
1114
+ gene_list_mini = gene_list[
1115
+ i:ispall_max_range
1116
+ ] # only perturbed genes from this minibatch
1117
+
1118
+ full_perturbation_emb = get_embs(
1119
+ model,
1120
+ perturbation_minibatch,
1121
+ "gene",
1122
+ layer_to_quant,
1123
+ self.pad_token_id,
1124
+ self.forward_batch_size,
1125
+ self.token_gene_dict,
1126
+ summary_stat=None,
1127
+ silent=True,
1128
+ )
1129
+
1130
+ del perturbation_minibatch
1131
+
1132
+ # need to remove overexpressed gene to quantify cosine shifts
1133
+ if self.perturb_type == "overexpress":
1134
+ perturbation_emb = full_perturbation_emb[:, num_inds_perturbed:, :]
1135
+
1136
+ elif self.perturb_type == "delete":
1137
+ perturbation_emb = full_perturbation_emb
1138
+
1139
+ if (
1140
+ self.cell_states_to_model is None
1141
+ or self.emb_mode == "cell_and_gene"
1142
+ ):
1143
+ original_emb_minibatch = pu.make_comparison_batch(
1144
+ full_original_emb, indices_to_perturb_mini, perturb_group=False
1145
+ )
1146
+ gene_cos_sims = pu.quant_cos_sims(
1147
+ perturbation_emb,
1148
+ original_emb_minibatch,
1149
+ self.cell_states_to_model,
1150
+ self.state_embs_dict,
1151
+ emb_mode="gene",
1152
+ )
1153
+ del original_emb_minibatch
1154
+
1155
+ if self.cell_states_to_model is not None:
1156
+ perturbation_cell_emb = pu.compute_nonpadded_cell_embedding(
1157
+ full_perturbation_emb, "mean_pool"
1158
+ )
1159
+
1160
+ cell_cos_sims = pu.quant_cos_sims(
1161
+ perturbation_cell_emb,
1162
+ original_cell_emb,
1163
+ self.cell_states_to_model,
1164
+ self.state_embs_dict,
1165
+ emb_mode="cell",
1166
+ )
1167
+ del perturbation_cell_emb
1168
+
1169
+ if self.emb_mode == "cell_and_gene":
1170
+ for perturbation_i, perturbed_gene in enumerate(gene_list_mini):
1171
+ for gene_j, affected_gene in enumerate(
1172
+ perturbed_gene_dict[perturbed_gene]
1173
+ ):
1174
+ try:
1175
+ stored_gene_embs_dict[
1176
+ (perturbed_gene, affected_gene)
1177
+ ].append(gene_cos_sims[perturbation_i, gene_j].item())
1178
+ except KeyError:
1179
+ stored_gene_embs_dict[
1180
+ (perturbed_gene, affected_gene)
1181
+ ] = gene_cos_sims[perturbation_i, gene_j].item()
1182
+
1183
+ del full_perturbation_emb
1184
+
1185
+ if self.cell_states_to_model is None:
1186
+ cos_sims_data = torch.mean(gene_cos_sims, dim=1)
1187
+ cos_sims_dict = self.update_perturbation_dictionary(
1188
+ cos_sims_dict,
1189
+ cos_sims_data,
1190
+ gene_list_mini,
1191
+ )
1192
+ else:
1193
+ cos_sims_data = cell_cos_sims
1194
+ for state in cos_sims_dict.keys():
1195
+ cos_sims_dict[state] = self.update_perturbation_dictionary(
1196
+ cos_sims_dict[state],
1197
+ cos_sims_data[state],
1198
+ gene_list_mini,
1199
+ )
1200
+
1201
+ # save dict to disk every self.clear_mem_ncells/10 (default 100) simulated cells
1202
+ if i % self.clear_mem_ncells / 10 == 0:
1203
+ pu.write_perturbation_dictionary(
1204
+ cos_sims_dict,
1205
+ f"{output_path_prefix}_dict_cell_embs_{h}batch{pickle_batch}",
1206
+ )
1207
+ if self.emb_mode == "cell_and_gene":
1208
+ pu.write_perturbation_dictionary(
1209
+ stored_gene_embs_dict,
1210
+ f"{output_path_prefix}_dict_gene_embs_{h}batch{pickle_batch}",
1211
+ )
1212
+
1213
+ # reset and clear memory every self.clear_mem_ncells (default 1000) simulated cells or at the end of the example cell
1214
+ if i % self.clear_mem_ncells == 0:
1215
+ pickle_batch += 1
1216
+ if self.cell_states_to_model is None:
1217
+ cos_sims_dict = defaultdict(list)
1218
+ else:
1219
+ cos_sims_dict = {
1220
+ state: defaultdict(list)
1221
+ for state in pu.get_possible_states(
1222
+ self.cell_states_to_model
1223
+ )
1224
+ }
1225
+
1226
+ if self.emb_mode == "cell_and_gene":
1227
+ stored_gene_embs_dict = defaultdict(list)
1228
+
1229
+ torch.cuda.empty_cache()
1230
+
1231
+ pu.write_perturbation_dictionary(
1232
+ cos_sims_dict,
1233
+ f"{output_path_prefix}_dict_cell_embs_{h}batch{pickle_batch}",
1234
+ )
1235
+
1236
+ if self.emb_mode == "cell_and_gene":
1237
+ pu.write_perturbation_dictionary(
1238
+ stored_gene_embs_dict,
1239
+ f"{output_path_prefix}_dict_gene_embs_{h}batch{pickle_batch}",
1240
+ )
1241
+
1242
+ pickle_batch = -1
1243
+ if self.cell_states_to_model is None:
1244
+ cos_sims_dict = defaultdict(list)
1245
+ else:
1246
+ cos_sims_dict = {
1247
+ state: defaultdict(list)
1248
+ for state in pu.get_possible_states(self.cell_states_to_model)
1249
+ }
1250
+
1251
+ if self.emb_mode == "cell_and_gene":
1252
+ stored_gene_embs_dict = defaultdict(list)
1253
+
1254
+ # clear memory between cells
1255
+ del perturbation_batch
1256
+ del full_original_emb
1257
+ if self.cell_states_to_model is not None:
1258
+ del original_cell_emb
1259
+ torch.cuda.empty_cache()
1260
+
1261
+ def isp_perturb_all_special(
1262
+ self,
1263
+ model,
1264
+ filtered_input_data: Dataset,
1265
+ layer_to_quant: int,
1266
+ output_path_prefix: str,
1267
+ ):
1268
+ pickle_batch = -1
1269
+ if self.cell_states_to_model is None:
1270
+ cos_sims_dict = defaultdict(list)
1271
+ else:
1272
+ cos_sims_dict = {
1273
+ state: defaultdict(list)
1274
+ for state in pu.get_possible_states(self.cell_states_to_model)
1275
+ }
1276
+
1277
+ if self.emb_mode == "cls_and_gene":
1278
+ stored_gene_embs_dict = defaultdict(list)
1279
+
1280
+ num_inds_perturbed = 1 + self.combos
1281
+ for h in trange(len(filtered_input_data)):
1282
+ example_cell = filtered_input_data.select([h])
1283
+
1284
+ # get original example cell cls and/or gene embs for comparison
1285
+ if self.emb_mode == "cls":
1286
+ original_cls_emb = get_embs(
1287
+ model,
1288
+ example_cell,
1289
+ "cls",
1290
+ layer_to_quant,
1291
+ self.pad_token_id,
1292
+ self.forward_batch_size,
1293
+ self.token_gene_dict,
1294
+ summary_stat=None,
1295
+ silent=True,
1296
+ )
1297
+ elif self.emb_mode == "cls_and_gene":
1298
+ full_original_emb = get_embs(
1299
+ model,
1300
+ example_cell,
1301
+ "gene",
1302
+ layer_to_quant,
1303
+ self.pad_token_id,
1304
+ self.forward_batch_size,
1305
+ self.token_gene_dict,
1306
+ summary_stat=None,
1307
+ silent=True,
1308
+ )
1309
+ original_cls_emb = full_original_emb[:, 0, :].clone().detach()
1310
+
1311
+ # gene_list is used to assign cos sims back to genes
1312
+ gene_list = example_cell["input_ids"][0][:]
1313
+
1314
+ # need to remove special tokens
1315
+ for token in [self.cls_token_id, self.eos_token_id]:
1316
+ gene_list.remove(token)
1317
+ # need to remove the anchor gene
1318
+ if self.anchor_token is not None:
1319
+ for token in self.anchor_token:
1320
+ gene_list.remove(token)
1321
+ # index 0 is not overexpressed so remove
1322
+ if self.perturb_type == "overexpress":
1323
+ gene_list = gene_list[num_inds_perturbed:]
1324
+ # remove perturbed index for gene list dict
1325
+ perturbed_gene_dict = {
1326
+ gene: gene_list[:i] + gene_list[i + 1 :]
1327
+ for i, gene in enumerate(gene_list)
1328
+ }
1329
+
1330
+ perturbation_batch, indices_to_perturb = pu.make_perturbation_batch_special(
1331
+ example_cell,
1332
+ self.perturb_type,
1333
+ self.tokens_to_perturb,
1334
+ self.anchor_token,
1335
+ self.combos,
1336
+ self.nproc,
1337
+ )
1338
+
1339
+ ispall_total_batch_length = len(perturbation_batch)
1340
+ for i in trange(
1341
+ 0, ispall_total_batch_length, self.forward_batch_size, leave=False
1342
+ ):
1343
+ ispall_max_range = min(
1344
+ i + self.forward_batch_size, ispall_total_batch_length
1345
+ )
1346
+ perturbation_minibatch = perturbation_batch.select(
1347
+ [i for i in range(i, ispall_max_range)]
1348
+ )
1349
+ indices_to_perturb_mini = indices_to_perturb[i:ispall_max_range]
1350
+ gene_list_mini = gene_list[
1351
+ i:ispall_max_range
1352
+ ] # only perturbed genes from this minibatch
1353
+
1354
+ ##### CLS Embedding Mode #####
1355
+ if self.emb_mode == "cls":
1356
+ # Extract cls embeddings from perturbed cells
1357
+ perturbation_cls_emb = get_embs(
1358
+ model,
1359
+ perturbation_minibatch,
1360
+ "cls",
1361
+ layer_to_quant,
1362
+ self.pad_token_id,
1363
+ self.forward_batch_size,
1364
+ self.token_gene_dict,
1365
+ summary_stat=None,
1366
+ silent=True,
1367
+ )
1368
+
1369
+ # Calculate cosine similarities
1370
+ cls_cos_sims = pu.quant_cos_sims(
1371
+ perturbation_cls_emb,
1372
+ original_cls_emb,
1373
+ self.cell_states_to_model,
1374
+ self.state_embs_dict,
1375
+ emb_mode="cell",
1376
+ )
1377
+
1378
+ if self.cell_states_to_model is None:
1379
+ cos_sims_dict = self.update_perturbation_dictionary(
1380
+ cos_sims_dict,
1381
+ cls_cos_sims,
1382
+ gene_list_mini,
1383
+ )
1384
+ else:
1385
+ for state in cos_sims_dict.keys():
1386
+ cos_sims_dict[state] = self.update_perturbation_dictionary(
1387
+ cos_sims_dict[state],
1388
+ cls_cos_sims[state],
1389
+ gene_list_mini,
1390
+ )
1391
+
1392
+ del perturbation_minibatch
1393
+ del perturbation_cls_emb
1394
+ del cls_cos_sims
1395
+
1396
+ ##### CLS and Gene Embedding Mode #####
1397
+ elif self.emb_mode == "cls_and_gene":
1398
+ full_perturbation_emb = get_embs(
1399
+ model,
1400
+ perturbation_minibatch,
1401
+ "gene",
1402
+ layer_to_quant,
1403
+ self.pad_token_id,
1404
+ self.forward_batch_size,
1405
+ self.token_gene_dict,
1406
+ summary_stat=None,
1407
+ silent=True,
1408
+ )
1409
+
1410
+ # need to remove overexpressed gene and cls/eos to quantify cosine shifts
1411
+ if self.perturb_type == "overexpress":
1412
+ perturbation_emb = (
1413
+ full_perturbation_emb[:, 1 + num_inds_perturbed : -1, :]
1414
+ .clone()
1415
+ .detach()
1416
+ )
1417
+ elif self.perturb_type == "delete":
1418
+ perturbation_emb = (
1419
+ full_perturbation_emb[:, 1:-1, :].clone().detach()
1420
+ )
1421
+
1422
+ original_emb_minibatch = pu.make_comparison_batch(
1423
+ full_original_emb, indices_to_perturb_mini, perturb_group=False
1424
+ )
1425
+
1426
+ original_emb_minibatch = (
1427
+ original_emb_minibatch[:, 1:-1, :].clone().detach()
1428
+ )
1429
+ gene_cos_sims = pu.quant_cos_sims(
1430
+ perturbation_emb,
1431
+ original_emb_minibatch,
1432
+ self.cell_states_to_model,
1433
+ self.state_embs_dict,
1434
+ emb_mode="gene",
1435
+ )
1436
+
1437
+ for perturbation_i, perturbed_gene in enumerate(gene_list_mini):
1438
+ for gene_j, affected_gene in enumerate(
1439
+ perturbed_gene_dict[perturbed_gene]
1440
+ ):
1441
+ try:
1442
+ stored_gene_embs_dict[
1443
+ (perturbed_gene, affected_gene)
1444
+ ].append(gene_cos_sims[perturbation_i, gene_j].item())
1445
+ except KeyError:
1446
+ stored_gene_embs_dict[
1447
+ (perturbed_gene, affected_gene)
1448
+ ] = gene_cos_sims[perturbation_i, gene_j].item()
1449
+
1450
+ # get cls emb
1451
+ perturbation_cls_emb = (
1452
+ full_perturbation_emb[:, 0, :].clone().detach()
1453
+ )
1454
+
1455
+ cls_cos_sims = pu.quant_cos_sims(
1456
+ perturbation_cls_emb,
1457
+ original_cls_emb,
1458
+ self.cell_states_to_model,
1459
+ self.state_embs_dict,
1460
+ emb_mode="cell",
1461
+ )
1462
+
1463
+ if self.cell_states_to_model is None:
1464
+ cos_sims_dict = self.update_perturbation_dictionary(
1465
+ cos_sims_dict,
1466
+ cls_cos_sims,
1467
+ gene_list_mini,
1468
+ )
1469
+ else:
1470
+ for state in cos_sims_dict.keys():
1471
+ cos_sims_dict[state] = self.update_perturbation_dictionary(
1472
+ cos_sims_dict[state],
1473
+ cls_cos_sims[state],
1474
+ gene_list_mini,
1475
+ )
1476
+
1477
+ del perturbation_minibatch
1478
+ del original_emb_minibatch
1479
+ del full_perturbation_emb
1480
+ del perturbation_emb
1481
+ del perturbation_cls_emb
1482
+ del cls_cos_sims
1483
+ del gene_cos_sims
1484
+
1485
+ # save dict to disk every self.clear_mem_ncells/10 (default 100) simulated cells
1486
+ if i % max(1, self.clear_mem_ncells / 10) == 0:
1487
+ pu.write_perturbation_dictionary(
1488
+ cos_sims_dict,
1489
+ f"{output_path_prefix}_dict_cell_embs_{h}batch{pickle_batch}",
1490
+ )
1491
+ if self.emb_mode == "cls_and_gene":
1492
+ pu.write_perturbation_dictionary(
1493
+ stored_gene_embs_dict,
1494
+ f"{output_path_prefix}_dict_gene_embs_{h}batch{pickle_batch}",
1495
+ )
1496
+
1497
+ # reset and clear memory every self.clear_mem_ncells (default 1000) simulated cells or at the end of the example cell
1498
+ if i % self.clear_mem_ncells == 0:
1499
+ pickle_batch += 1
1500
+ if self.cell_states_to_model is None:
1501
+ cos_sims_dict = defaultdict(list)
1502
+ else:
1503
+ cos_sims_dict = {
1504
+ state: defaultdict(list)
1505
+ for state in pu.get_possible_states(
1506
+ self.cell_states_to_model
1507
+ )
1508
+ }
1509
+
1510
+ if self.emb_mode == "cls_and_gene":
1511
+ stored_gene_embs_dict = defaultdict(list)
1512
+
1513
+ torch.cuda.empty_cache()
1514
+
1515
+ pu.write_perturbation_dictionary(
1516
+ cos_sims_dict,
1517
+ f"{output_path_prefix}_dict_cell_embs_{h}batch{pickle_batch}",
1518
+ )
1519
+
1520
+ if self.emb_mode == "cls_and_gene":
1521
+ pu.write_perturbation_dictionary(
1522
+ stored_gene_embs_dict,
1523
+ f"{output_path_prefix}_dict_gene_embs_{h}batch{pickle_batch}",
1524
+ )
1525
+
1526
+ pickle_batch = -1
1527
+ if self.cell_states_to_model is None:
1528
+ cos_sims_dict = defaultdict(list)
1529
+ else:
1530
+ cos_sims_dict = {
1531
+ state: defaultdict(list)
1532
+ for state in pu.get_possible_states(self.cell_states_to_model)
1533
+ }
1534
+
1535
+ if self.emb_mode == "cls_and_gene":
1536
+ stored_gene_embs_dict = defaultdict(list)
1537
+
1538
+ # clear memory between cells
1539
+ del perturbation_batch
1540
+ del original_cls_emb
1541
+ if self.emb_mode == "cls_and_gene":
1542
+ del full_original_emb
1543
+ torch.cuda.empty_cache()
1544
+
1545
+ def update_perturbation_dictionary(
1546
+ self,
1547
+ cos_sims_dict: defaultdict,
1548
+ cos_sims_data: torch.Tensor,
1549
+ gene_list=None,
1550
+ ):
1551
+ if gene_list is not None and cos_sims_data.shape[0] != len(gene_list):
1552
+ logger.error(
1553
+ f"len(cos_sims_data.shape[0]) != len(gene_list). \n \
1554
+ {cos_sims_data.shape[0]=}.\n \
1555
+ {len(gene_list)=}."
1556
+ )
1557
+ raise
1558
+
1559
+ if self.perturb_group is True:
1560
+ if len(self.tokens_to_perturb) > 1:
1561
+ perturbed_genes = tuple(self.tokens_to_perturb)
1562
+ else:
1563
+ perturbed_genes = self.tokens_to_perturb[0]
1564
+
1565
+ # if cell embeddings, can just append
1566
+ # shape will be (batch size, 1)
1567
+ cos_sims_data = torch.squeeze(cos_sims_data).tolist()
1568
+
1569
+ # handle case of single cell left
1570
+ if not isinstance(cos_sims_data, list):
1571
+ cos_sims_data = [cos_sims_data]
1572
+
1573
+ cos_sims_dict[(perturbed_genes, "cell_emb")] += cos_sims_data
1574
+
1575
+ else:
1576
+ for i, cos in enumerate(cos_sims_data.tolist()):
1577
+ cos_sims_dict[(gene_list[i], "cell_emb")].append(cos)
1578
+
1579
+ return cos_sims_dict
geneformer/in_silico_perturber_stats.py ADDED
@@ -0,0 +1,1104 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ Geneformer in silico perturber stats generator.
3
+
4
+ **Usage:**
5
+
6
+ .. code-block :: python
7
+
8
+ >>> from geneformer import InSilicoPerturberStats
9
+ >>> ispstats = InSilicoPerturberStats(mode="goal_state_shift",
10
+ ... cell_states_to_model={"state_key": "disease",
11
+ ... "start_state": "dcm",
12
+ ... "goal_state": "nf",
13
+ ... "alt_states": ["hcm", "other1", "other2"]})
14
+ >>> ispstats.get_stats("path/to/input_data",
15
+ ... None,
16
+ ... "path/to/output_directory",
17
+ ... "output_prefix")
18
+
19
+ **Description:**
20
+
21
+ | Aggregates data or calculates stats for in silico perturbations based on type of statistics specified in InSilicoPerturberStats.
22
+ | Input data is raw in silico perturbation results in the form of dictionaries outputted by ``in_silico_perturber``.
23
+
24
+ """
25
+
26
+
27
+ import logging
28
+ import os
29
+ import pickle
30
+ import random
31
+ from pathlib import Path
32
+
33
+ import numpy as np
34
+ import pandas as pd
35
+ import statsmodels.stats.multitest as smt
36
+ from scipy.stats import ranksums
37
+ from sklearn.mixture import GaussianMixture
38
+ from tqdm.auto import tqdm, trange
39
+
40
+ from . import ENSEMBL_DICTIONARY_FILE, TOKEN_DICTIONARY_FILE
41
+ from .perturber_utils import flatten_list, validate_cell_states_to_model
42
+
43
+ logger = logging.getLogger(__name__)
44
+
45
+
46
+ # invert dictionary keys/values
47
+ def invert_dict(dictionary):
48
+ return {v: k for k, v in dictionary.items()}
49
+
50
+
51
+ def read_dict(cos_sims_dict, cell_or_gene_emb, anchor_token):
52
+ if cell_or_gene_emb == "cell":
53
+ cell_emb_dict = {
54
+ k: v for k, v in cos_sims_dict.items() if v and "cell_emb" in k
55
+ }
56
+ return [cell_emb_dict]
57
+ elif cell_or_gene_emb == "gene":
58
+ if anchor_token is None:
59
+ gene_emb_dict = {k: v for k, v in cos_sims_dict.items() if v}
60
+ else:
61
+ gene_emb_dict = {
62
+ k: v for k, v in cos_sims_dict.items() if v and anchor_token == k[0]
63
+ }
64
+ return [gene_emb_dict]
65
+
66
+
67
+ # read raw dictionary files
68
+ def read_dictionaries(
69
+ input_data_directory,
70
+ cell_or_gene_emb,
71
+ anchor_token,
72
+ cell_states_to_model,
73
+ pickle_suffix,
74
+ ):
75
+ file_found = False
76
+ file_path_list = []
77
+ if cell_states_to_model is None:
78
+ dict_list = []
79
+ else:
80
+ validate_cell_states_to_model(cell_states_to_model)
81
+ cell_states_to_model_valid = {
82
+ state: value
83
+ for state, value in cell_states_to_model.items()
84
+ if state != "state_key"
85
+ and cell_states_to_model[state] is not None
86
+ and cell_states_to_model[state] != []
87
+ }
88
+ cell_states_list = []
89
+ # flatten all state values into list
90
+ for state in cell_states_to_model_valid:
91
+ value = cell_states_to_model_valid[state]
92
+ if isinstance(value, list):
93
+ cell_states_list += value
94
+ else:
95
+ cell_states_list.append(value)
96
+ state_dict = {state_value: dict() for state_value in cell_states_list}
97
+ for file in os.listdir(input_data_directory):
98
+ # process only files with given suffix (e.g. "_raw.pickle")
99
+ if file.endswith(pickle_suffix):
100
+ file_found = True
101
+ file_path_list += [f"{input_data_directory}/{file}"]
102
+ for file_path in tqdm(file_path_list):
103
+ with open(file_path, "rb") as fp:
104
+ cos_sims_dict = pickle.load(fp)
105
+ if cell_states_to_model is None:
106
+ dict_list += read_dict(cos_sims_dict, cell_or_gene_emb, anchor_token)
107
+ else:
108
+ for state_value in cell_states_list:
109
+ new_dict = read_dict(
110
+ cos_sims_dict[state_value], cell_or_gene_emb, anchor_token
111
+ )[0]
112
+ for key in new_dict:
113
+ try:
114
+ state_dict[state_value][key] += new_dict[key]
115
+ except KeyError:
116
+ state_dict[state_value][key] = new_dict[key]
117
+
118
+ if not file_found:
119
+ logger.error(
120
+ "No raw data for processing found within provided directory. "
121
+ "Please ensure data files end with '{pickle_suffix}'."
122
+ )
123
+ raise
124
+ if cell_states_to_model is None:
125
+ return dict_list
126
+ else:
127
+ return state_dict
128
+
129
+
130
+ # get complete gene list
131
+ def get_gene_list(dict_list, mode):
132
+ if mode == "cell":
133
+ position = 0
134
+ elif mode == "gene":
135
+ position = 1
136
+ gene_set = set()
137
+ if isinstance(dict_list, list):
138
+ for dict_i in dict_list:
139
+ gene_set.update([k[position] for k, v in dict_i.items() if v])
140
+ elif isinstance(dict_list, dict):
141
+ for state, dict_i in dict_list.items():
142
+ gene_set.update([k[position] for k, v in dict_i.items() if v])
143
+ else:
144
+ logger.error(
145
+ "dict_list should be a list, or if modeling shift to goal states, a dict. "
146
+ f"{type(dict_list)} is not the correct format."
147
+ )
148
+ raise
149
+ gene_list = list(gene_set)
150
+ if mode == "gene":
151
+ gene_list.remove("cell_emb")
152
+ gene_list.sort()
153
+ return gene_list
154
+
155
+
156
+ def token_tuple_to_ensembl_ids(token_tuple, gene_token_id_dict):
157
+ try:
158
+ return tuple([gene_token_id_dict.get(i, np.nan) for i in token_tuple])
159
+ except TypeError:
160
+ return gene_token_id_dict.get(token_tuple, np.nan)
161
+
162
+
163
+ def n_detections(token, dict_list, mode, anchor_token):
164
+ cos_sim_megalist = []
165
+ for dict_i in dict_list:
166
+ if mode == "cell":
167
+ cos_sim_megalist += dict_i.get((token, "cell_emb"), [])
168
+ elif mode == "gene":
169
+ cos_sim_megalist += dict_i.get((anchor_token, token), [])
170
+ return len(cos_sim_megalist)
171
+
172
+
173
+ def get_fdr(pvalues):
174
+ return list(smt.multipletests(pvalues, alpha=0.05, method="fdr_bh")[1])
175
+
176
+
177
+ def get_impact_component(test_value, gaussian_mixture_model):
178
+ impact_border = gaussian_mixture_model.means_[0][0]
179
+ nonimpact_border = gaussian_mixture_model.means_[1][0]
180
+ if test_value > nonimpact_border:
181
+ impact_component = 0
182
+ elif test_value < impact_border:
183
+ impact_component = 1
184
+ else:
185
+ impact_component_raw = gaussian_mixture_model.predict([[test_value]])[0]
186
+ if impact_component_raw == 1:
187
+ impact_component = 0
188
+ elif impact_component_raw == 0:
189
+ impact_component = 1
190
+ return impact_component
191
+
192
+
193
+ # aggregate data for single perturbation in multiple cells
194
+ def isp_aggregate_grouped_perturb(cos_sims_df, dict_list, genes_perturbed):
195
+ names = ["Cosine_sim", "Gene"]
196
+ cos_sims_full_dfs = []
197
+ if isinstance(genes_perturbed, list):
198
+ if len(genes_perturbed) > 1:
199
+ gene_ids_df = cos_sims_df.loc[
200
+ np.isin(
201
+ [set(idx) for idx in cos_sims_df["Ensembl_ID"]],
202
+ set(genes_perturbed),
203
+ ),
204
+ :,
205
+ ]
206
+ else:
207
+ gene_ids_df = cos_sims_df.loc[
208
+ np.isin(cos_sims_df["Ensembl_ID"], genes_perturbed), :
209
+ ]
210
+ else:
211
+ logger.error(
212
+ "aggregate_data is for perturbation of single gene or single group of genes. genes_to_perturb should be formatted as list."
213
+ )
214
+ raise
215
+
216
+ if gene_ids_df.empty:
217
+ logger.error("genes_to_perturb not found in data.")
218
+ raise
219
+
220
+ tokens = gene_ids_df["Gene"]
221
+ symbols = gene_ids_df["Gene_name"]
222
+
223
+ for token, symbol in zip(tokens, symbols):
224
+ cos_shift_data = []
225
+ for dict_i in dict_list:
226
+ cos_shift_data += dict_i.get((token, "cell_emb"), [])
227
+
228
+ df = pd.DataFrame(columns=names)
229
+ df["Cosine_sim"] = cos_shift_data
230
+ df["Gene"] = symbol
231
+ cos_sims_full_dfs.append(df)
232
+
233
+ return pd.concat(cos_sims_full_dfs)
234
+
235
+
236
+ def find(variable, x):
237
+ try:
238
+ if x in variable: # Test if variable is iterable and contains x
239
+ return True
240
+ elif x == variable:
241
+ return True
242
+ except (ValueError, TypeError):
243
+ return x == variable # Test if variable is x if non-iterable
244
+
245
+
246
+ def isp_aggregate_gene_shifts(
247
+ cos_sims_df, dict_list, gene_token_id_dict, gene_id_name_dict, token_dtype
248
+ ):
249
+ cos_shift_data = dict()
250
+ for i in trange(cos_sims_df.shape[0]):
251
+ token = cos_sims_df["Gene"][i]
252
+ for dict_i in dict_list:
253
+ if token_dtype == "nontuple":
254
+ affected_pairs = [k for k, v in dict_i.items() if k[0] == token]
255
+ else:
256
+ affected_pairs = [k for k, v in dict_i.items() if find(k[0], token)]
257
+ for key in affected_pairs:
258
+ if key in cos_shift_data.keys():
259
+ cos_shift_data[key] += dict_i.get(key, [])
260
+ else:
261
+ cos_shift_data[key] = dict_i.get(key, [])
262
+
263
+ cos_data_mean = {
264
+ k: [np.mean(v), np.std(v), len(v)] for k, v in cos_shift_data.items()
265
+ }
266
+ cos_sims_full_df = pd.DataFrame()
267
+ cos_sims_full_df["Perturbed"] = [k[0] for k, v in cos_data_mean.items()]
268
+ cos_sims_full_df["Gene_name"] = [
269
+ cos_sims_df[cos_sims_df["Gene"] == k[0]]["Gene_name"].item()
270
+ for k, v in cos_data_mean.items()
271
+ ]
272
+ cos_sims_full_df["Ensembl_ID"] = [
273
+ cos_sims_df[cos_sims_df["Gene"] == k[0]]["Ensembl_ID"].item()
274
+ for k, v in cos_data_mean.items()
275
+ ]
276
+
277
+ cos_sims_full_df["Affected"] = [k[1] for k, v in cos_data_mean.items()]
278
+ cos_sims_full_df["Affected_gene_name"] = [
279
+ gene_id_name_dict.get(gene_token_id_dict.get(token, np.nan), np.nan)
280
+ for token in cos_sims_full_df["Affected"]
281
+ ]
282
+ cos_sims_full_df["Affected_Ensembl_ID"] = [
283
+ gene_token_id_dict.get(token, np.nan) for token in cos_sims_full_df["Affected"]
284
+ ]
285
+ cos_sims_full_df["Cosine_sim_mean"] = [v[0] for k, v in cos_data_mean.items()]
286
+ cos_sims_full_df["Cosine_sim_stdev"] = [v[1] for k, v in cos_data_mean.items()]
287
+ cos_sims_full_df["N_Detections"] = [v[2] for k, v in cos_data_mean.items()]
288
+
289
+ specific_val = "cell_emb"
290
+ cos_sims_full_df["temp"] = list(cos_sims_full_df["Affected"] == specific_val)
291
+ # reorder so cell embs are at the top and all are subordered by magnitude of cosine sim
292
+ cos_sims_full_df = cos_sims_full_df.sort_values(
293
+ by=(["temp", "Cosine_sim_mean"]), ascending=[False, True]
294
+ ).drop("temp", axis=1)
295
+
296
+ return cos_sims_full_df
297
+
298
+
299
+ # stats comparing cos sim shifts towards goal state of test perturbations vs random perturbations
300
+ def isp_stats_to_goal_state(
301
+ cos_sims_df, result_dict, cell_states_to_model, genes_perturbed
302
+ ):
303
+ if (
304
+ ("alt_states" not in cell_states_to_model.keys())
305
+ or (len(cell_states_to_model["alt_states"]) == 0)
306
+ or (cell_states_to_model["alt_states"] == [None])
307
+ ):
308
+ alt_end_state_exists = False
309
+ elif (len(cell_states_to_model["alt_states"]) > 0) and (
310
+ cell_states_to_model["alt_states"] != [None]
311
+ ):
312
+ alt_end_state_exists = True
313
+
314
+ # for single perturbation in multiple cells, there are no random perturbations to compare to
315
+ if genes_perturbed != "all":
316
+ cos_sims_full_df = pd.DataFrame()
317
+
318
+ cos_shift_data_end = []
319
+ token = cos_sims_df["Gene"][0]
320
+ cos_shift_data_end += result_dict[cell_states_to_model["goal_state"]].get(
321
+ (token, "cell_emb"), []
322
+ )
323
+ cos_sims_full_df["Shift_to_goal_end"] = [np.mean(cos_shift_data_end)]
324
+ if alt_end_state_exists is True:
325
+ for alt_state in cell_states_to_model["alt_states"]:
326
+ cos_shift_data_alt_state = []
327
+ cos_shift_data_alt_state += result_dict.get(alt_state).get(
328
+ (token, "cell_emb"), []
329
+ )
330
+ cos_sims_full_df[f"Shift_to_alt_end_{alt_state}"] = [
331
+ np.mean(cos_shift_data_alt_state)
332
+ ]
333
+
334
+ # sort by shift to desired state
335
+ cos_sims_full_df = cos_sims_full_df.sort_values(
336
+ by=["Shift_to_goal_end"], ascending=[False]
337
+ )
338
+ return cos_sims_full_df
339
+
340
+ elif genes_perturbed == "all":
341
+ goal_end_random_megalist = []
342
+ if alt_end_state_exists is True:
343
+ alt_end_state_random_dict = {
344
+ alt_state: [] for alt_state in cell_states_to_model["alt_states"]
345
+ }
346
+ for i in trange(cos_sims_df.shape[0]):
347
+ token = cos_sims_df["Gene"][i]
348
+ goal_end_random_megalist += result_dict[
349
+ cell_states_to_model["goal_state"]
350
+ ].get((token, "cell_emb"), [])
351
+ if alt_end_state_exists is True:
352
+ for alt_state in cell_states_to_model["alt_states"]:
353
+ alt_end_state_random_dict[alt_state] += result_dict[alt_state].get(
354
+ (token, "cell_emb"), []
355
+ )
356
+
357
+ # downsample to improve speed of ranksums
358
+ if len(goal_end_random_megalist) > 100_000:
359
+ random.seed(42)
360
+ goal_end_random_megalist = random.sample(
361
+ goal_end_random_megalist, k=100_000
362
+ )
363
+ if alt_end_state_exists is True:
364
+ for alt_state in cell_states_to_model["alt_states"]:
365
+ if len(alt_end_state_random_dict[alt_state]) > 100_000:
366
+ random.seed(42)
367
+ alt_end_state_random_dict[alt_state] = random.sample(
368
+ alt_end_state_random_dict[alt_state], k=100_000
369
+ )
370
+
371
+ names = [
372
+ "Gene",
373
+ "Gene_name",
374
+ "Ensembl_ID",
375
+ "Shift_to_goal_end",
376
+ "Goal_end_vs_random_pval",
377
+ ]
378
+ if alt_end_state_exists is True:
379
+ [
380
+ names.append(f"Shift_to_alt_end_{alt_state}")
381
+ for alt_state in cell_states_to_model["alt_states"]
382
+ ]
383
+ names.append(names.pop(names.index("Goal_end_vs_random_pval")))
384
+ [
385
+ names.append(f"Alt_end_vs_random_pval_{alt_state}")
386
+ for alt_state in cell_states_to_model["alt_states"]
387
+ ]
388
+ cos_sims_full_df = pd.DataFrame(columns=names)
389
+
390
+ n_detections_dict = dict()
391
+ for i in trange(cos_sims_df.shape[0]):
392
+ token = cos_sims_df["Gene"][i]
393
+ name = cos_sims_df["Gene_name"][i]
394
+ ensembl_id = cos_sims_df["Ensembl_ID"][i]
395
+ goal_end_cos_sim_megalist = result_dict[
396
+ cell_states_to_model["goal_state"]
397
+ ].get((token, "cell_emb"), [])
398
+ n_detections_dict[token] = len(goal_end_cos_sim_megalist)
399
+ mean_goal_end = np.mean(goal_end_cos_sim_megalist)
400
+ pval_goal_end = ranksums(
401
+ goal_end_random_megalist, goal_end_cos_sim_megalist
402
+ ).pvalue
403
+
404
+ if alt_end_state_exists is True:
405
+ alt_end_state_dict = {
406
+ alt_state: [] for alt_state in cell_states_to_model["alt_states"]
407
+ }
408
+ for alt_state in cell_states_to_model["alt_states"]:
409
+ alt_end_state_dict[alt_state] = result_dict[alt_state].get(
410
+ (token, "cell_emb"), []
411
+ )
412
+ alt_end_state_dict[f"{alt_state}_mean"] = np.mean(
413
+ alt_end_state_dict[alt_state]
414
+ )
415
+ alt_end_state_dict[f"{alt_state}_pval"] = ranksums(
416
+ alt_end_state_random_dict[alt_state],
417
+ alt_end_state_dict[alt_state],
418
+ ).pvalue
419
+
420
+ results_dict = dict()
421
+ results_dict["Gene"] = token
422
+ results_dict["Gene_name"] = name
423
+ results_dict["Ensembl_ID"] = ensembl_id
424
+ results_dict["Shift_to_goal_end"] = mean_goal_end
425
+ results_dict["Goal_end_vs_random_pval"] = pval_goal_end
426
+ if alt_end_state_exists is True:
427
+ for alt_state in cell_states_to_model["alt_states"]:
428
+ results_dict[f"Shift_to_alt_end_{alt_state}"] = alt_end_state_dict[
429
+ f"{alt_state}_mean"
430
+ ]
431
+ results_dict[
432
+ f"Alt_end_vs_random_pval_{alt_state}"
433
+ ] = alt_end_state_dict[f"{alt_state}_pval"]
434
+
435
+ cos_sims_df_i = pd.DataFrame(results_dict, index=[i])
436
+ cos_sims_full_df = pd.concat([cos_sims_full_df, cos_sims_df_i])
437
+
438
+ cos_sims_full_df["Goal_end_FDR"] = get_fdr(
439
+ list(cos_sims_full_df["Goal_end_vs_random_pval"])
440
+ )
441
+ if alt_end_state_exists is True:
442
+ for alt_state in cell_states_to_model["alt_states"]:
443
+ cos_sims_full_df[f"Alt_end_FDR_{alt_state}"] = get_fdr(
444
+ list(cos_sims_full_df[f"Alt_end_vs_random_pval_{alt_state}"])
445
+ )
446
+
447
+ # quantify number of detections of each gene
448
+ cos_sims_full_df["N_Detections"] = [
449
+ n_detections_dict[token] for token in cos_sims_full_df["Gene"]
450
+ ]
451
+
452
+ # sort by shift to desired state
453
+ cos_sims_full_df["Sig"] = [
454
+ 1 if fdr < 0.05 else 0 for fdr in cos_sims_full_df["Goal_end_FDR"]
455
+ ]
456
+ cos_sims_full_df = cos_sims_full_df.sort_values(
457
+ by=["Sig", "Shift_to_goal_end", "Goal_end_FDR"],
458
+ ascending=[False, False, True],
459
+ )
460
+
461
+ return cos_sims_full_df
462
+
463
+
464
+ # stats comparing cos sim shifts of test perturbations vs null distribution
465
+ def isp_stats_vs_null(cos_sims_df, dict_list, null_dict_list):
466
+ cos_sims_full_df = cos_sims_df.copy()
467
+
468
+ cos_sims_full_df["Test_avg_shift"] = np.zeros(cos_sims_df.shape[0], dtype=float)
469
+ cos_sims_full_df["Null_avg_shift"] = np.zeros(cos_sims_df.shape[0], dtype=float)
470
+ cos_sims_full_df["Test_vs_null_avg_shift"] = np.zeros(
471
+ cos_sims_df.shape[0], dtype=float
472
+ )
473
+ cos_sims_full_df["Test_vs_null_pval"] = np.zeros(cos_sims_df.shape[0], dtype=float)
474
+ cos_sims_full_df["Test_vs_null_FDR"] = np.zeros(cos_sims_df.shape[0], dtype=float)
475
+ cos_sims_full_df["N_Detections_test"] = np.zeros(
476
+ cos_sims_df.shape[0], dtype="uint32"
477
+ )
478
+ cos_sims_full_df["N_Detections_null"] = np.zeros(
479
+ cos_sims_df.shape[0], dtype="uint32"
480
+ )
481
+
482
+ for i in trange(cos_sims_df.shape[0]):
483
+ token = cos_sims_df["Gene"][i]
484
+ test_shifts = []
485
+ null_shifts = []
486
+
487
+ for dict_i in dict_list:
488
+ test_shifts += dict_i.get((token, "cell_emb"), [])
489
+
490
+ for dict_i in null_dict_list:
491
+ null_shifts += dict_i.get((token, "cell_emb"), [])
492
+
493
+ cos_sims_full_df.loc[i, "Test_avg_shift"] = np.mean(test_shifts)
494
+ cos_sims_full_df.loc[i, "Null_avg_shift"] = np.mean(null_shifts)
495
+ cos_sims_full_df.loc[i, "Test_vs_null_avg_shift"] = np.mean(
496
+ test_shifts
497
+ ) - np.mean(null_shifts)
498
+ cos_sims_full_df.loc[i, "Test_vs_null_pval"] = ranksums(
499
+ test_shifts, null_shifts, nan_policy="omit"
500
+ ).pvalue
501
+ # remove nan values
502
+ cos_sims_full_df.Test_vs_null_pval = np.where(
503
+ np.isnan(cos_sims_full_df.Test_vs_null_pval),
504
+ 1,
505
+ cos_sims_full_df.Test_vs_null_pval,
506
+ )
507
+ cos_sims_full_df.loc[i, "N_Detections_test"] = len(test_shifts)
508
+ cos_sims_full_df.loc[i, "N_Detections_null"] = len(null_shifts)
509
+
510
+ cos_sims_full_df["Test_vs_null_FDR"] = get_fdr(
511
+ cos_sims_full_df["Test_vs_null_pval"]
512
+ )
513
+
514
+ cos_sims_full_df["Sig"] = [
515
+ 1 if fdr < 0.05 else 0 for fdr in cos_sims_full_df["Test_vs_null_FDR"]
516
+ ]
517
+ cos_sims_full_df = cos_sims_full_df.sort_values(
518
+ by=["Sig", "Test_vs_null_avg_shift", "Test_vs_null_FDR"],
519
+ ascending=[False, False, True],
520
+ )
521
+ return cos_sims_full_df
522
+
523
+
524
+ # stats for identifying perturbations with largest effect within a given set of cells
525
+ # fits a mixture model to 2 components (impact vs. non-impact) and
526
+ # reports the most likely component for each test perturbation
527
+ # Note: because assumes given perturbation has a consistent effect in the cells tested,
528
+ # we recommend only using the mixture model strategy with uniform cell populations
529
+ def isp_stats_mixture_model(cos_sims_df, dict_list, combos, anchor_token):
530
+ names = ["Gene", "Gene_name", "Ensembl_ID"]
531
+
532
+ if combos == 0:
533
+ names += ["Test_avg_shift"]
534
+ elif combos == 1:
535
+ names += [
536
+ "Anchor_shift",
537
+ "Test_token_shift",
538
+ "Sum_of_indiv_shifts",
539
+ "Combo_shift",
540
+ "Combo_minus_sum_shift",
541
+ ]
542
+
543
+ names += ["Impact_component", "Impact_component_percent"]
544
+
545
+ cos_sims_full_df = pd.DataFrame(columns=names)
546
+ avg_values = []
547
+ gene_names = []
548
+
549
+ for i in trange(cos_sims_df.shape[0]):
550
+ token = cos_sims_df["Gene"][i]
551
+ name = cos_sims_df["Gene_name"][i]
552
+ ensembl_id = cos_sims_df["Ensembl_ID"][i]
553
+ cos_shift_data = []
554
+
555
+ for dict_i in dict_list:
556
+ if (combos == 0) and (anchor_token is not None):
557
+ cos_shift_data += dict_i.get((anchor_token, token), [])
558
+ else:
559
+ cos_shift_data += dict_i.get((token, "cell_emb"), [])
560
+
561
+ # Extract values for current gene
562
+ if combos == 0:
563
+ test_values = cos_shift_data
564
+ elif combos == 1:
565
+ test_values = []
566
+ for tup in cos_shift_data:
567
+ test_values.append(tup[2])
568
+
569
+ if len(test_values) > 0:
570
+ avg_value = np.mean(test_values)
571
+ avg_values.append(avg_value)
572
+ gene_names.append(name)
573
+
574
+ # fit Gaussian mixture model to dataset of mean for each gene
575
+ avg_values_to_fit = np.array(avg_values).reshape(-1, 1)
576
+ gm = GaussianMixture(n_components=2, random_state=0).fit(avg_values_to_fit)
577
+
578
+ for i in trange(cos_sims_df.shape[0]):
579
+ token = cos_sims_df["Gene"][i]
580
+ name = cos_sims_df["Gene_name"][i]
581
+ ensembl_id = cos_sims_df["Ensembl_ID"][i]
582
+ cos_shift_data = []
583
+
584
+ for dict_i in dict_list:
585
+ if (combos == 0) and (anchor_token is not None):
586
+ cos_shift_data += dict_i.get((anchor_token, token), [])
587
+ else:
588
+ cos_shift_data += dict_i.get((token, "cell_emb"), [])
589
+
590
+ if combos == 0:
591
+ mean_test = np.mean(cos_shift_data)
592
+ impact_components = [
593
+ get_impact_component(value, gm) for value in cos_shift_data
594
+ ]
595
+ elif combos == 1:
596
+ anchor_cos_sim_megalist = [
597
+ anchor for anchor, token, combo in cos_shift_data
598
+ ]
599
+ token_cos_sim_megalist = [token for anchor, token, combo in cos_shift_data]
600
+ anchor_plus_token_cos_sim_megalist = [
601
+ 1 - ((1 - anchor) + (1 - token))
602
+ for anchor, token, combo in cos_shift_data
603
+ ]
604
+ combo_anchor_token_cos_sim_megalist = [
605
+ combo for anchor, token, combo in cos_shift_data
606
+ ]
607
+ combo_minus_sum_cos_sim_megalist = [
608
+ combo - (1 - ((1 - anchor) + (1 - token)))
609
+ for anchor, token, combo in cos_shift_data
610
+ ]
611
+
612
+ mean_anchor = np.mean(anchor_cos_sim_megalist)
613
+ mean_token = np.mean(token_cos_sim_megalist)
614
+ mean_sum = np.mean(anchor_plus_token_cos_sim_megalist)
615
+ mean_test = np.mean(combo_anchor_token_cos_sim_megalist)
616
+ mean_combo_minus_sum = np.mean(combo_minus_sum_cos_sim_megalist)
617
+
618
+ impact_components = [
619
+ get_impact_component(value, gm)
620
+ for value in combo_anchor_token_cos_sim_megalist
621
+ ]
622
+
623
+ impact_component = get_impact_component(mean_test, gm)
624
+ impact_component_percent = np.mean(impact_components) * 100
625
+
626
+ data_i = [token, name, ensembl_id]
627
+ if combos == 0:
628
+ data_i += [mean_test]
629
+ elif combos == 1:
630
+ data_i += [
631
+ mean_anchor,
632
+ mean_token,
633
+ mean_sum,
634
+ mean_test,
635
+ mean_combo_minus_sum,
636
+ ]
637
+ data_i += [impact_component, impact_component_percent]
638
+
639
+ cos_sims_df_i = pd.DataFrame(dict(zip(names, data_i)), index=[i])
640
+ cos_sims_full_df = pd.concat([cos_sims_full_df, cos_sims_df_i])
641
+
642
+ # quantify number of detections of each gene
643
+ if anchor_token is None:
644
+ cos_sims_full_df["N_Detections"] = [
645
+ n_detections(i, dict_list, "cell", anchor_token)
646
+ for i in cos_sims_full_df["Gene"]
647
+ ]
648
+ else:
649
+ cos_sims_full_df["N_Detections"] = [
650
+ n_detections(i, dict_list, "gene", anchor_token)
651
+ for i in cos_sims_full_df["Gene"]
652
+ ]
653
+
654
+ if combos == 0:
655
+ cos_sims_full_df = cos_sims_full_df.sort_values(
656
+ by=["Impact_component", "Test_avg_shift"], ascending=[False, True]
657
+ )
658
+ elif combos == 1:
659
+ cos_sims_full_df = cos_sims_full_df.sort_values(
660
+ by=["Impact_component", "Combo_minus_sum_shift"], ascending=[False, True]
661
+ )
662
+ return cos_sims_full_df
663
+
664
+
665
+ class InSilicoPerturberStats:
666
+ valid_option_dict = {
667
+ "mode": {
668
+ "goal_state_shift",
669
+ "vs_null",
670
+ "mixture_model",
671
+ "aggregate_data",
672
+ "aggregate_gene_shifts",
673
+ },
674
+ "genes_perturbed": {"all", list},
675
+ "combos": {0, 1},
676
+ "anchor_gene": {None, str},
677
+ "cell_states_to_model": {None, dict},
678
+ "pickle_suffix": {None, str},
679
+ }
680
+
681
+ def __init__(
682
+ self,
683
+ mode="mixture_model",
684
+ genes_perturbed="all",
685
+ combos=0,
686
+ anchor_gene=None,
687
+ cell_states_to_model=None,
688
+ pickle_suffix="_raw.pickle",
689
+ token_dictionary_file=TOKEN_DICTIONARY_FILE,
690
+ gene_name_id_dictionary_file=ENSEMBL_DICTIONARY_FILE,
691
+ ):
692
+ """
693
+ Initialize in silico perturber stats generator.
694
+
695
+ **Parameters:**
696
+
697
+ mode : {"goal_state_shift", "vs_null", "mixture_model", "aggregate_data", "aggregate_gene_shifts"}
698
+ | Type of stats.
699
+ | "goal_state_shift": perturbation vs. random for desired cell state shift
700
+ | "vs_null": perturbation vs. null from provided null distribution dataset
701
+ | "mixture_model": perturbation in impact vs. no impact component of mixture model (no goal direction)
702
+ | "aggregate_data": aggregates cosine shifts for single perturbation in multiple cells
703
+ | "aggregate_gene_shifts": aggregates cosine shifts of genes in response to perturbation(s)
704
+ genes_perturbed : "all", list
705
+ | Genes perturbed in isp experiment.
706
+ | Default is assuming genes_to_perturb in isp experiment was "all" (each gene in each cell).
707
+ | Otherwise, may provide a list of ENSEMBL IDs of genes perturbed as a group all together.
708
+ combos : {0,1,2}
709
+ | Whether genex perturbed in isp experiment were perturbed individually (0), in pairs (1), or in triplets (2).
710
+ anchor_gene : None, str
711
+ | ENSEMBL ID of gene to use as anchor in combination perturbations or in testing effect on downstream genes.
712
+ | For example, if combos=1 and anchor_gene="ENSG00000136574":
713
+ | analyzes data for anchor gene perturbed in combination with each other gene.
714
+ | However, if combos=0 and anchor_gene="ENSG00000136574":
715
+ | analyzes data for the effect of anchor gene's perturbation on the embedding of each other gene.
716
+ cell_states_to_model: None, dict
717
+ | Cell states to model if testing perturbations that achieve goal state change.
718
+ | Four-item dictionary with keys: state_key, start_state, goal_state, and alt_states
719
+ | state_key: key specifying name of column in .dataset that defines the start/goal states
720
+ | start_state: value in the state_key column that specifies the start state
721
+ | goal_state: value in the state_key column taht specifies the goal end state
722
+ | alt_states: list of values in the state_key column that specify the alternate end states
723
+ | For example: {"state_key": "disease",
724
+ | "start_state": "dcm",
725
+ | "goal_state": "nf",
726
+ | "alt_states": ["hcm", "other1", "other2"]}
727
+ token_dictionary_file : Path
728
+ | Path to pickle file containing token dictionary (Ensembl ID:token).
729
+ gene_name_id_dictionary_file : Path
730
+ | Path to pickle file containing gene name to ID dictionary (gene name:Ensembl ID).
731
+ """
732
+
733
+ self.mode = mode
734
+ self.genes_perturbed = genes_perturbed
735
+ self.combos = combos
736
+ self.anchor_gene = anchor_gene
737
+ self.cell_states_to_model = cell_states_to_model
738
+ self.pickle_suffix = pickle_suffix
739
+
740
+ self.validate_options()
741
+
742
+ # load token dictionary (Ensembl IDs:token)
743
+ with open(token_dictionary_file, "rb") as f:
744
+ self.gene_token_dict = pickle.load(f)
745
+
746
+ # load gene name dictionary (gene name:Ensembl ID)
747
+ with open(gene_name_id_dictionary_file, "rb") as f:
748
+ self.gene_name_id_dict = pickle.load(f)
749
+
750
+ if anchor_gene is None:
751
+ self.anchor_token = None
752
+ else:
753
+ self.anchor_token = self.gene_token_dict[self.anchor_gene]
754
+
755
+ def validate_options(self):
756
+ for attr_name, valid_options in self.valid_option_dict.items():
757
+ attr_value = self.__dict__[attr_name]
758
+ if type(attr_value) not in {list, dict}:
759
+ if attr_name in {"anchor_gene"}:
760
+ continue
761
+ elif attr_value in valid_options:
762
+ continue
763
+ valid_type = False
764
+ for option in valid_options:
765
+ if (option in [str, int, list, dict]) and isinstance(
766
+ attr_value, option
767
+ ):
768
+ valid_type = True
769
+ break
770
+ if not valid_type:
771
+ logger.error(
772
+ f"Invalid option for {attr_name}. "
773
+ f"Valid options for {attr_name}: {valid_options}"
774
+ )
775
+ raise
776
+
777
+ if self.cell_states_to_model is not None:
778
+ if len(self.cell_states_to_model.items()) == 1:
779
+ logger.warning(
780
+ "The single value dictionary for cell_states_to_model will be "
781
+ "replaced with a dictionary with named keys for start, goal, and alternate states. "
782
+ "Please specify state_key, start_state, goal_state, and alt_states "
783
+ "in the cell_states_to_model dictionary for future use. "
784
+ "For example, cell_states_to_model={"
785
+ "'state_key': 'disease', "
786
+ "'start_state': 'dcm', "
787
+ "'goal_state': 'nf', "
788
+ "'alt_states': ['hcm', 'other1', 'other2']}"
789
+ )
790
+ for key, value in self.cell_states_to_model.items():
791
+ if (len(value) == 3) and isinstance(value, tuple):
792
+ if (
793
+ isinstance(value[0], list)
794
+ and isinstance(value[1], list)
795
+ and isinstance(value[2], list)
796
+ ):
797
+ if len(value[0]) == 1 and len(value[1]) == 1:
798
+ all_values = value[0] + value[1] + value[2]
799
+ if len(all_values) == len(set(all_values)):
800
+ continue
801
+ # reformat to the new named key format
802
+ state_values = flatten_list(list(self.cell_states_to_model.values()))
803
+ self.cell_states_to_model = {
804
+ "state_key": list(self.cell_states_to_model.keys())[0],
805
+ "start_state": state_values[0][0],
806
+ "goal_state": state_values[1][0],
807
+ "alt_states": state_values[2:][0],
808
+ }
809
+ elif set(self.cell_states_to_model.keys()) == {
810
+ "state_key",
811
+ "start_state",
812
+ "goal_state",
813
+ "alt_states",
814
+ }:
815
+ if (
816
+ (self.cell_states_to_model["state_key"] is None)
817
+ or (self.cell_states_to_model["start_state"] is None)
818
+ or (self.cell_states_to_model["goal_state"] is None)
819
+ ):
820
+ logger.error(
821
+ "Please specify 'state_key', 'start_state', and 'goal_state' in cell_states_to_model."
822
+ )
823
+ raise
824
+
825
+ if (
826
+ self.cell_states_to_model["start_state"]
827
+ == self.cell_states_to_model["goal_state"]
828
+ ):
829
+ logger.error("All states must be unique.")
830
+ raise
831
+
832
+ if self.cell_states_to_model["alt_states"] is not None:
833
+ if not isinstance(self.cell_states_to_model["alt_states"], list):
834
+ logger.error(
835
+ "self.cell_states_to_model['alt_states'] must be a list (even if it is one element)."
836
+ )
837
+ raise
838
+ if len(self.cell_states_to_model["alt_states"]) != len(
839
+ set(self.cell_states_to_model["alt_states"])
840
+ ):
841
+ logger.error("All states must be unique.")
842
+ raise
843
+
844
+ elif set(self.cell_states_to_model.keys()) == {
845
+ "state_key",
846
+ "start_state",
847
+ "goal_state",
848
+ }:
849
+ self.cell_states_to_model["alt_states"] = []
850
+ else:
851
+ logger.error(
852
+ "cell_states_to_model must only have the following four keys: "
853
+ "'state_key', 'start_state', 'goal_state', 'alt_states'."
854
+ "For example, cell_states_to_model={"
855
+ "'state_key': 'disease', "
856
+ "'start_state': 'dcm', "
857
+ "'goal_state': 'nf', "
858
+ "'alt_states': ['hcm', 'other1', 'other2']}"
859
+ )
860
+ raise
861
+
862
+ if self.anchor_gene is not None:
863
+ self.anchor_gene = None
864
+ logger.warning(
865
+ "anchor_gene set to None. "
866
+ "Currently, anchor gene not available "
867
+ "when modeling multiple cell states."
868
+ )
869
+
870
+ if self.combos > 0:
871
+ if self.anchor_gene is None:
872
+ logger.error(
873
+ "Currently, stats are only supported for combination "
874
+ "in silico perturbation run with anchor gene. Please add "
875
+ "anchor gene when using with combos > 0. "
876
+ )
877
+ raise
878
+
879
+ if (self.mode == "mixture_model") and (self.genes_perturbed != "all"):
880
+ logger.error(
881
+ "Mixture model mode requires multiple gene perturbations to fit model "
882
+ "so is incompatible with a single grouped perturbation."
883
+ )
884
+ raise
885
+ if (self.mode == "aggregate_data") and (self.genes_perturbed == "all"):
886
+ logger.error(
887
+ "Simple data aggregation mode is for single perturbation in multiple cells "
888
+ "so is incompatible with a genes_perturbed being 'all'."
889
+ )
890
+ raise
891
+
892
+ def get_stats(
893
+ self,
894
+ input_data_directory,
895
+ null_dist_data_directory,
896
+ output_directory,
897
+ output_prefix,
898
+ null_dict_list=None,
899
+ ):
900
+ """
901
+ Get stats for in silico perturbation data and save as results in output_directory.
902
+
903
+ **Parameters:**
904
+
905
+ input_data_directory : Path
906
+ | Path to directory containing cos_sim dictionary inputs
907
+ null_dist_data_directory : Path
908
+ | Path to directory containing null distribution cos_sim dictionary inputs
909
+ output_directory : Path
910
+ | Path to directory where perturbation data will be saved as .csv
911
+ output_prefix : str
912
+ | Prefix for output .csv
913
+ null_dict_list: list[dict]
914
+ | List of loaded null distribution dictionary if more than one comparison vs. the null is to be performed
915
+
916
+ **Outputs:**
917
+
918
+ Definition of possible columns in .csv output file.
919
+
920
+ | Of note, not all columns will be present in all output files.
921
+ | Some columns are specific to particular perturbation modes.
922
+
923
+ | "Gene": gene token
924
+ | "Gene_name": gene name
925
+ | "Ensembl_ID": gene Ensembl ID
926
+ | "N_Detections": number of cells in which each gene or gene combination was detected in the input dataset
927
+ | "Sig": 1 if FDR<0.05, otherwise 0
928
+
929
+ | "Shift_to_goal_end": cosine shift from start state towards goal end state in response to given perturbation
930
+ | "Shift_to_alt_end": cosine shift from start state towards alternate end state in response to given perturbation
931
+ | "Goal_end_vs_random_pval": pvalue of cosine shift from start state towards goal end state by Wilcoxon
932
+ | pvalue compares shift caused by perturbing given gene compared to random genes
933
+ | "Alt_end_vs_random_pval": pvalue of cosine shift from start state towards alternate end state by Wilcoxon
934
+ | pvalue compares shift caused by perturbing given gene compared to random genes
935
+ | "Goal_end_FDR": Benjamini-Hochberg correction of "Goal_end_vs_random_pval"
936
+ | "Alt_end_FDR": Benjamini-Hochberg correction of "Alt_end_vs_random_pval"
937
+
938
+ | "Test_avg_shift": cosine shift in response to given perturbation in cells from test distribution
939
+ | "Null_avg_shift": cosine shift in response to given perturbation in cells from null distribution (e.g. random cells)
940
+ | "Test_vs_null_avg_shift": difference in cosine shift in cells from test vs. null distribution
941
+ | (i.e. "Test_avg_shift" minus "Null_avg_shift")
942
+ | "Test_vs_null_pval": pvalue of cosine shift in test vs. null distribution
943
+ | "Test_vs_null_FDR": Benjamini-Hochberg correction of "Test_vs_null_pval"
944
+ | "N_Detections_test": "N_Detections" in cells from test distribution
945
+ | "N_Detections_null": "N_Detections" in cells from null distribution
946
+
947
+ | "Anchor_shift": cosine shift in response to given perturbation of anchor gene
948
+ | "Test_token_shift": cosine shift in response to given perturbation of test gene
949
+ | "Sum_of_indiv_shifts": sum of cosine shifts in response to individually perturbing test and anchor genes
950
+ | "Combo_shift": cosine shift in response to given perturbation of both anchor and test gene(s) in combination
951
+ | "Combo_minus_sum_shift": difference of cosine shifts in response combo perturbation vs. sum of individual perturbations
952
+ | (i.e. "Combo_shift" minus "Sum_of_indiv_shifts")
953
+ | "Impact_component": whether the given perturbation was modeled to be within the impact component by the mixture model
954
+ | 1: within impact component; 0: not within impact component
955
+ | "Impact_component_percent": percent of cells in which given perturbation was modeled to be within impact component
956
+
957
+ | In case of aggregating data / gene shifts:
958
+ | "Perturbed": ID(s) of gene(s) being perturbed
959
+ | "Affected": ID of affected gene or "cell_emb" indicating the impact on the cell embedding as a whole
960
+ | "Cosine_sim_mean": mean of cosine similarity of cell or affected gene in original vs. perturbed
961
+ | "Cosine_sim_stdev": standard deviation of cosine similarity of cell or affected gene in original vs. perturbed
962
+ """
963
+
964
+ if self.mode not in [
965
+ "goal_state_shift",
966
+ "vs_null",
967
+ "mixture_model",
968
+ "aggregate_data",
969
+ "aggregate_gene_shifts",
970
+ ]:
971
+ logger.error(
972
+ "Currently, only modes available are stats for goal_state_shift, "
973
+ "vs_null (comparing to null distribution), "
974
+ "mixture_model (fitting mixture model for perturbations with or without impact), "
975
+ "and aggregating data for single perturbations or for gene embedding shifts."
976
+ )
977
+ raise
978
+
979
+ self.gene_token_id_dict = invert_dict(self.gene_token_dict)
980
+ self.gene_id_name_dict = invert_dict(self.gene_name_id_dict)
981
+
982
+ # obtain total gene list
983
+ if (self.combos == 0) and (self.anchor_token is not None):
984
+ # cos sim data for effect of gene perturbation on the embedding of each other gene
985
+ dict_list = read_dictionaries(
986
+ input_data_directory,
987
+ "gene",
988
+ self.anchor_token,
989
+ self.cell_states_to_model,
990
+ self.pickle_suffix,
991
+ )
992
+ gene_list = get_gene_list(dict_list, "gene")
993
+ elif (
994
+ (self.combos == 0)
995
+ and (self.anchor_token is None)
996
+ and (self.mode == "aggregate_gene_shifts")
997
+ ):
998
+ dict_list = read_dictionaries(
999
+ input_data_directory,
1000
+ "gene",
1001
+ self.anchor_token,
1002
+ self.cell_states_to_model,
1003
+ self.pickle_suffix,
1004
+ )
1005
+ gene_list = get_gene_list(dict_list, "cell")
1006
+ else:
1007
+ # cos sim data for effect of gene perturbation on the embedding of each cell
1008
+ dict_list = read_dictionaries(
1009
+ input_data_directory,
1010
+ "cell",
1011
+ self.anchor_token,
1012
+ self.cell_states_to_model,
1013
+ self.pickle_suffix,
1014
+ )
1015
+ gene_list = get_gene_list(dict_list, "cell")
1016
+
1017
+ # initiate results dataframe
1018
+ cos_sims_df_initial = pd.DataFrame(
1019
+ {
1020
+ "Gene": gene_list,
1021
+ "Gene_name": [self.token_to_gene_name(item) for item in gene_list],
1022
+ "Ensembl_ID": [
1023
+ token_tuple_to_ensembl_ids(genes, self.gene_token_id_dict)
1024
+ if self.genes_perturbed != "all"
1025
+ else self.gene_token_id_dict[genes[1]]
1026
+ if isinstance(genes, tuple)
1027
+ else self.gene_token_id_dict[genes]
1028
+ for genes in gene_list
1029
+ ],
1030
+ },
1031
+ index=[i for i in range(len(gene_list))],
1032
+ )
1033
+
1034
+ if self.mode == "goal_state_shift":
1035
+ cos_sims_df = isp_stats_to_goal_state(
1036
+ cos_sims_df_initial,
1037
+ dict_list,
1038
+ self.cell_states_to_model,
1039
+ self.genes_perturbed,
1040
+ )
1041
+
1042
+ elif self.mode == "vs_null":
1043
+ if null_dict_list is None:
1044
+ null_dict_list = read_dictionaries(
1045
+ null_dist_data_directory,
1046
+ "cell",
1047
+ self.anchor_token,
1048
+ self.cell_states_to_model,
1049
+ self.pickle_suffix,
1050
+ )
1051
+ cos_sims_df = isp_stats_vs_null(
1052
+ cos_sims_df_initial, dict_list, null_dict_list
1053
+ )
1054
+
1055
+ elif self.mode == "mixture_model":
1056
+ cos_sims_df = isp_stats_mixture_model(
1057
+ cos_sims_df_initial, dict_list, self.combos, self.anchor_token
1058
+ )
1059
+
1060
+ elif self.mode == "aggregate_data":
1061
+ cos_sims_df = isp_aggregate_grouped_perturb(
1062
+ cos_sims_df_initial, dict_list, self.genes_perturbed
1063
+ )
1064
+
1065
+ elif self.mode == "aggregate_gene_shifts":
1066
+ if (self.genes_perturbed == "all") and (self.combos == 0):
1067
+ tuple_types = [
1068
+ True if isinstance(genes, tuple) else False for genes in gene_list
1069
+ ]
1070
+ if all(tuple_types):
1071
+ token_dtype = "tuple"
1072
+ elif not any(tuple_types):
1073
+ token_dtype = "nontuple"
1074
+ else:
1075
+ token_dtype = "mix"
1076
+ else:
1077
+ token_dtype = "mix"
1078
+
1079
+ cos_sims_df = isp_aggregate_gene_shifts(
1080
+ cos_sims_df_initial,
1081
+ dict_list,
1082
+ self.gene_token_id_dict,
1083
+ self.gene_id_name_dict,
1084
+ token_dtype,
1085
+ )
1086
+
1087
+ # save perturbation stats to output_path
1088
+ output_path = (Path(output_directory) / output_prefix).with_suffix(".csv")
1089
+ cos_sims_df.to_csv(output_path)
1090
+
1091
+ def token_to_gene_name(self, item):
1092
+ if np.issubdtype(type(item), np.integer):
1093
+ return self.gene_id_name_dict.get(
1094
+ self.gene_token_id_dict.get(item, np.nan), np.nan
1095
+ )
1096
+ if isinstance(item, tuple):
1097
+ return tuple(
1098
+ [
1099
+ self.gene_id_name_dict.get(
1100
+ self.gene_token_id_dict.get(i, np.nan), np.nan
1101
+ )
1102
+ for i in item
1103
+ ]
1104
+ )
geneformer/mtl/__init__.py ADDED
@@ -0,0 +1 @@
 
 
1
+ # ruff: noqa: F401
geneformer/mtl/collators.py ADDED
@@ -0,0 +1,76 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # imports
2
+ import torch
3
+ import pickle
4
+ from ..collator_for_classification import DataCollatorForGeneClassification
5
+ from .. import TOKEN_DICTIONARY_FILE
6
+
7
+ """Geneformer collator for multi-task cell classification."""
8
+
9
+ class DataCollatorForMultitaskCellClassification(DataCollatorForGeneClassification):
10
+ class_type = "cell"
11
+
12
+ @staticmethod
13
+ def load_token_dictionary():
14
+ with open(TOKEN_DICTIONARY_FILE, 'rb') as f:
15
+ return pickle.load(f)
16
+
17
+ def __init__(self, *args, **kwargs) -> None:
18
+ # Load the token dictionary
19
+ token_dictionary = self.load_token_dictionary()
20
+ # Use the loaded token dictionary
21
+ super().__init__(token_dictionary=token_dictionary, *args, **kwargs)
22
+
23
+ def _prepare_batch(self, features):
24
+ # Process inputs as usual
25
+ batch = self.tokenizer.pad(
26
+ features,
27
+ class_type=self.class_type,
28
+ padding=self.padding,
29
+ max_length=self.max_length,
30
+ pad_to_multiple_of=self.pad_to_multiple_of,
31
+ return_tensors="pt",
32
+ )
33
+
34
+ # Check if labels are present
35
+ if "label" in features[0]:
36
+ # Initialize labels dictionary for all tasks
37
+ labels = {task: [] for task in features[0]["label"].keys()}
38
+ # Populate labels for each task
39
+ for feature in features:
40
+ for task, label in feature["label"].items():
41
+ labels[task].append(label)
42
+
43
+ # Convert label lists to tensors, handling dictionaries appropriately
44
+ for task in labels:
45
+ if isinstance(labels[task][0], (list, torch.Tensor)):
46
+ dtype = torch.long
47
+ labels[task] = torch.tensor(labels[task], dtype=dtype)
48
+ elif isinstance(labels[task][0], dict):
49
+ # Handle dict specifically if needed
50
+ pass # Resolve nested data structure
51
+
52
+ # Update the batch to include task-specific labels
53
+ batch["labels"] = labels
54
+ else:
55
+ # If no labels are present, create empty labels for all tasks
56
+ batch["labels"] = {
57
+ task: torch.tensor([], dtype=torch.long)
58
+ for task in features[0]["input_ids"].keys()
59
+ }
60
+
61
+ return batch
62
+
63
+ def __call__(self, features):
64
+ batch = self._prepare_batch(features)
65
+ for k, v in batch.items():
66
+ if torch.is_tensor(v):
67
+ batch[k] = v.clone().detach()
68
+ elif isinstance(v, dict):
69
+ # Assuming nested structure needs conversion
70
+ batch[k] = {
71
+ task: torch.tensor(labels, dtype=torch.int64)
72
+ for task, labels in v.items()
73
+ }
74
+ else:
75
+ batch[k] = torch.tensor(v, dtype=torch.int64)
76
+ return batch
geneformer/mtl/data.py ADDED
@@ -0,0 +1,162 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ from .collators import DataCollatorForMultitaskCellClassification
3
+ from .imports import *
4
+
5
+ def validate_columns(dataset, required_columns, dataset_type):
6
+ """Ensures required columns are present in the dataset."""
7
+ missing_columns = [col for col in required_columns if col not in dataset.column_names]
8
+ if missing_columns:
9
+ raise KeyError(
10
+ f"Missing columns in {dataset_type} dataset: {missing_columns}. "
11
+ f"Available columns: {dataset.column_names}"
12
+ )
13
+
14
+
15
+ def create_label_mappings(dataset, task_to_column):
16
+ """Creates label mappings for the dataset."""
17
+ task_label_mappings = {}
18
+ num_labels_list = []
19
+ for task, column in task_to_column.items():
20
+ unique_values = sorted(set(dataset[column]))
21
+ mapping = {label: idx for idx, label in enumerate(unique_values)}
22
+ task_label_mappings[task] = mapping
23
+ num_labels_list.append(len(unique_values))
24
+ return task_label_mappings, num_labels_list
25
+
26
+
27
+ def save_label_mappings(mappings, path):
28
+ """Saves label mappings to a pickle file."""
29
+ with open(path, "wb") as f:
30
+ pickle.dump(mappings, f)
31
+
32
+
33
+ def load_label_mappings(path):
34
+ """Loads label mappings from a pickle file."""
35
+ with open(path, "rb") as f:
36
+ return pickle.load(f)
37
+
38
+
39
+ def transform_dataset(dataset, task_to_column, task_label_mappings, config, is_test):
40
+ """Transforms the dataset to the required format."""
41
+ transformed_dataset = []
42
+ cell_id_mapping = {}
43
+
44
+ for idx, record in enumerate(dataset):
45
+ transformed_record = {
46
+ "input_ids": torch.tensor(record["input_ids"], dtype=torch.long),
47
+ "cell_id": idx, # Index-based cell ID
48
+ }
49
+
50
+ if not is_test:
51
+ label_dict = {
52
+ task: task_label_mappings[task][record[column]]
53
+ for task, column in task_to_column.items()
54
+ }
55
+ else:
56
+ label_dict = {task: -1 for task in config["task_names"]}
57
+
58
+ transformed_record["label"] = label_dict
59
+ transformed_dataset.append(transformed_record)
60
+ cell_id_mapping[idx] = record.get("unique_cell_id", idx)
61
+
62
+ return transformed_dataset, cell_id_mapping
63
+
64
+
65
+ def load_and_preprocess_data(dataset_path, config, is_test=False, dataset_type=""):
66
+ """Main function to load and preprocess data."""
67
+ try:
68
+ dataset = load_from_disk(dataset_path)
69
+
70
+ # Setup task and column mappings
71
+ task_names = [f"task{i+1}" for i in range(len(config["task_columns"]))]
72
+ task_to_column = dict(zip(task_names, config["task_columns"]))
73
+ config["task_names"] = task_names
74
+
75
+ label_mappings_path = os.path.join(
76
+ config["results_dir"],
77
+ f"task_label_mappings{'_val' if dataset_type == 'validation' else ''}.pkl"
78
+ )
79
+
80
+ if not is_test:
81
+ validate_columns(dataset, task_to_column.values(), dataset_type)
82
+
83
+ # Create and save label mappings
84
+ task_label_mappings, num_labels_list = create_label_mappings(dataset, task_to_column)
85
+ save_label_mappings(task_label_mappings, label_mappings_path)
86
+ else:
87
+ # Load existing mappings for test data
88
+ task_label_mappings = load_label_mappings(label_mappings_path)
89
+ num_labels_list = [len(mapping) for mapping in task_label_mappings.values()]
90
+
91
+ # Transform dataset
92
+ transformed_dataset, cell_id_mapping = transform_dataset(
93
+ dataset, task_to_column, task_label_mappings, config, is_test
94
+ )
95
+
96
+ return transformed_dataset, cell_id_mapping, num_labels_list
97
+
98
+ except KeyError as e:
99
+ raise ValueError(f"Configuration error or dataset key missing: {e}")
100
+ except Exception as e:
101
+ raise RuntimeError(f"Error during data loading or preprocessing: {e}")
102
+
103
+
104
+ def preload_and_process_data(config):
105
+ """Preloads and preprocesses train and validation datasets."""
106
+ # Process train data and save mappings
107
+ train_data = load_and_preprocess_data(config["train_path"], config, dataset_type="train")
108
+
109
+ # Process validation data and save mappings
110
+ val_data = load_and_preprocess_data(config["val_path"], config, dataset_type="validation")
111
+
112
+ # Validate that the mappings match
113
+ validate_label_mappings(config)
114
+
115
+ return (*train_data, *val_data[:2]) # Return train and val data along with mappings
116
+
117
+
118
+ def validate_label_mappings(config):
119
+ """Ensures train and validation label mappings are consistent."""
120
+ train_mappings_path = os.path.join(config["results_dir"], "task_label_mappings.pkl")
121
+ val_mappings_path = os.path.join(config["results_dir"], "task_label_mappings_val.pkl")
122
+ train_mappings = load_label_mappings(train_mappings_path)
123
+ val_mappings = load_label_mappings(val_mappings_path)
124
+
125
+ for task_name in config["task_names"]:
126
+ if train_mappings[task_name] != val_mappings[task_name]:
127
+ raise ValueError(
128
+ f"Mismatch in label mappings for task '{task_name}'.\n"
129
+ f"Train Mapping: {train_mappings[task_name]}\n"
130
+ f"Validation Mapping: {val_mappings[task_name]}"
131
+ )
132
+
133
+
134
+ def get_data_loader(preprocessed_dataset, batch_size):
135
+ """Creates a DataLoader with optimal settings."""
136
+ return DataLoader(
137
+ preprocessed_dataset,
138
+ batch_size=batch_size,
139
+ shuffle=True,
140
+ collate_fn=DataCollatorForMultitaskCellClassification(),
141
+ num_workers=os.cpu_count(),
142
+ pin_memory=True,
143
+ )
144
+
145
+
146
+ def preload_data(config):
147
+ """Preprocesses train and validation data for trials."""
148
+ train_loader = get_data_loader(*preload_and_process_data(config)[:2], config["batch_size"])
149
+ val_loader = get_data_loader(*preload_and_process_data(config)[2:4], config["batch_size"])
150
+ return train_loader, val_loader
151
+
152
+
153
+ def load_and_preprocess_test_data(config):
154
+ """Loads and preprocesses test data."""
155
+ return load_and_preprocess_data(config["test_path"], config, is_test=True)
156
+
157
+
158
+ def prepare_test_loader(config):
159
+ """Prepares DataLoader for test data."""
160
+ test_dataset, cell_id_mapping, num_labels_list = load_and_preprocess_test_data(config)
161
+ test_loader = get_data_loader(test_dataset, config["batch_size"])
162
+ return test_loader, cell_id_mapping, num_labels_list
geneformer/mtl/eval_utils.py ADDED
@@ -0,0 +1,88 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import pandas as pd
2
+
3
+ from .imports import * # noqa # isort:skip
4
+ from .data import prepare_test_loader # noqa # isort:skip
5
+ from .model import GeneformerMultiTask
6
+
7
+
8
+ def evaluate_test_dataset(model, device, test_loader, cell_id_mapping, config):
9
+ task_pred_labels = {task_name: [] for task_name in config["task_names"]}
10
+ task_pred_probs = {task_name: [] for task_name in config["task_names"]}
11
+ cell_ids = []
12
+
13
+ # # Load task label mappings from pickle file
14
+ # with open(f"{config['results_dir']}/task_label_mappings.pkl", "rb") as f:
15
+ # task_label_mappings = pickle.load(f)
16
+
17
+ model.eval()
18
+ with torch.no_grad():
19
+ for batch in test_loader:
20
+ input_ids = batch["input_ids"].to(device)
21
+ attention_mask = batch["attention_mask"].to(device)
22
+ _, logits, _ = model(input_ids, attention_mask)
23
+ for sample_idx in range(len(batch["input_ids"])):
24
+ cell_id = cell_id_mapping[batch["cell_id"][sample_idx].item()]
25
+ cell_ids.append(cell_id)
26
+ for i, task_name in enumerate(config["task_names"]):
27
+ pred_label = torch.argmax(logits[i][sample_idx], dim=-1).item()
28
+ pred_prob = (
29
+ torch.softmax(logits[i][sample_idx], dim=-1).cpu().numpy()
30
+ )
31
+ task_pred_labels[task_name].append(pred_label)
32
+ task_pred_probs[task_name].append(pred_prob)
33
+
34
+ # Save test predictions with cell IDs and probabilities to CSV
35
+ test_results_dir = config["results_dir"]
36
+ os.makedirs(test_results_dir, exist_ok=True)
37
+ test_preds_file = os.path.join(test_results_dir, "test_preds.csv")
38
+
39
+ rows = []
40
+ for sample_idx in range(len(cell_ids)):
41
+ row = {"Cell ID": cell_ids[sample_idx]}
42
+ for task_name in config["task_names"]:
43
+ row[f"{task_name} Prediction"] = task_pred_labels[task_name][sample_idx]
44
+ row[f"{task_name} Probabilities"] = ",".join(
45
+ map(str, task_pred_probs[task_name][sample_idx])
46
+ )
47
+ rows.append(row)
48
+
49
+ df = pd.DataFrame(rows)
50
+ df.to_csv(test_preds_file, index=False)
51
+ print(f"Test predictions saved to {test_preds_file}")
52
+
53
+
54
+ def load_and_evaluate_test_model(config):
55
+ device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
56
+ test_loader, cell_id_mapping, num_labels_list = prepare_test_loader(config)
57
+ model_directory = os.path.join(config["model_save_path"], "GeneformerMultiTask")
58
+ hyperparams_path = os.path.join(model_directory, "hyperparameters.json")
59
+
60
+ # Load the saved best hyperparameters
61
+ with open(hyperparams_path, "r") as f:
62
+ best_hyperparams = json.load(f)
63
+
64
+ # Extract the task weights if present, otherwise set to None
65
+ task_weights = best_hyperparams.get("task_weights", None)
66
+ normalized_task_weights = task_weights if task_weights else []
67
+
68
+ # Print the loaded hyperparameters
69
+ print("Loaded hyperparameters:")
70
+ for param, value in best_hyperparams.items():
71
+ if param == "task_weights":
72
+ print(f"normalized_task_weights: {value}")
73
+ else:
74
+ print(f"{param}: {value}")
75
+
76
+ best_model_path = os.path.join(model_directory, "pytorch_model.bin")
77
+ best_model = GeneformerMultiTask(
78
+ config["pretrained_path"],
79
+ num_labels_list,
80
+ dropout_rate=best_hyperparams["dropout_rate"],
81
+ use_task_weights=config["use_task_weights"],
82
+ task_weights=normalized_task_weights,
83
+ )
84
+ best_model.load_state_dict(torch.load(best_model_path))
85
+ best_model.to(device)
86
+
87
+ evaluate_test_dataset(best_model, device, test_loader, cell_id_mapping, config)
88
+ print("Evaluation completed.")
geneformer/mtl/imports.py ADDED
@@ -0,0 +1,43 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import functools
2
+ import gc
3
+ import json
4
+ import os
5
+ import pickle
6
+ import sys
7
+ import warnings
8
+ from enum import Enum
9
+ from itertools import chain
10
+ from typing import Dict, List, Optional, Union
11
+
12
+ import numpy as np
13
+ import optuna
14
+ import pandas as pd
15
+ import torch
16
+ import torch.nn as nn
17
+ import torch.nn.functional as F
18
+ import torch.optim as optim
19
+ from datasets import load_from_disk
20
+ from sklearn.metrics import accuracy_score, f1_score, roc_auc_score, roc_curve
21
+ from sklearn.model_selection import train_test_split
22
+ from sklearn.preprocessing import LabelEncoder
23
+ from torch.utils.data import DataLoader
24
+ from transformers import (
25
+ AdamW,
26
+ BatchEncoding,
27
+ BertConfig,
28
+ BertModel,
29
+ DataCollatorForTokenClassification,
30
+ SpecialTokensMixin,
31
+ get_cosine_schedule_with_warmup,
32
+ get_linear_schedule_with_warmup,
33
+ get_scheduler,
34
+ )
35
+ from transformers.utils import logging, to_py_obj
36
+
37
+ from .collators import DataCollatorForMultitaskCellClassification
38
+
39
+ # local modules
40
+ from .data import get_data_loader, preload_and_process_data
41
+ from .model import GeneformerMultiTask
42
+ from .optuna_utils import create_optuna_study
43
+ from .utils import save_model
geneformer/mtl/model.py ADDED
@@ -0,0 +1,121 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ import torch.nn as nn
3
+ from transformers import BertConfig, BertModel
4
+
5
+
6
+ class AttentionPool(nn.Module):
7
+ """Attention-based pooling layer."""
8
+
9
+ def __init__(self, hidden_size):
10
+ super(AttentionPool, self).__init__()
11
+ self.attention_weights = nn.Parameter(torch.randn(hidden_size, 1))
12
+ nn.init.xavier_uniform_(
13
+ self.attention_weights
14
+ ) # https://pytorch.org/docs/stable/nn.init.html
15
+
16
+ def forward(self, hidden_states):
17
+ attention_scores = torch.matmul(hidden_states, self.attention_weights)
18
+ attention_scores = torch.softmax(attention_scores, dim=1)
19
+ pooled_output = torch.sum(hidden_states * attention_scores, dim=1)
20
+ return pooled_output
21
+
22
+
23
+ class GeneformerMultiTask(nn.Module):
24
+ def __init__(
25
+ self,
26
+ pretrained_path,
27
+ num_labels_list,
28
+ dropout_rate=0.1,
29
+ use_task_weights=False,
30
+ task_weights=None,
31
+ max_layers_to_freeze=0,
32
+ use_attention_pooling=False,
33
+ ):
34
+ super(GeneformerMultiTask, self).__init__()
35
+ self.config = BertConfig.from_pretrained(pretrained_path)
36
+ self.bert = BertModel(self.config)
37
+ self.num_labels_list = num_labels_list
38
+ self.use_task_weights = use_task_weights
39
+ self.dropout = nn.Dropout(dropout_rate)
40
+ self.use_attention_pooling = use_attention_pooling
41
+
42
+ if use_task_weights and (
43
+ task_weights is None or len(task_weights) != len(num_labels_list)
44
+ ):
45
+ raise ValueError(
46
+ "Task weights must be defined and match the number of tasks when 'use_task_weights' is True."
47
+ )
48
+ self.task_weights = (
49
+ task_weights if use_task_weights else [1.0] * len(num_labels_list)
50
+ )
51
+
52
+ # Freeze the specified initial layers
53
+ for layer in self.bert.encoder.layer[:max_layers_to_freeze]:
54
+ for param in layer.parameters():
55
+ param.requires_grad = False
56
+
57
+ self.attention_pool = (
58
+ AttentionPool(self.config.hidden_size) if use_attention_pooling else None
59
+ )
60
+
61
+ self.classification_heads = nn.ModuleList(
62
+ [
63
+ nn.Linear(self.config.hidden_size, num_labels)
64
+ for num_labels in num_labels_list
65
+ ]
66
+ )
67
+ # initialization of the classification heads: https://pytorch.org/docs/stable/nn.init.html
68
+ for head in self.classification_heads:
69
+ nn.init.xavier_uniform_(head.weight)
70
+ nn.init.zeros_(head.bias)
71
+
72
+ def forward(self, input_ids, attention_mask, labels=None):
73
+ try:
74
+ outputs = self.bert(input_ids=input_ids, attention_mask=attention_mask)
75
+ except Exception as e:
76
+ raise RuntimeError(f"Error during BERT forward pass: {e}")
77
+
78
+ sequence_output = outputs.last_hidden_state
79
+
80
+ try:
81
+ pooled_output = (
82
+ self.attention_pool(sequence_output)
83
+ if self.use_attention_pooling
84
+ else sequence_output[:, 0, :]
85
+ )
86
+ pooled_output = self.dropout(pooled_output)
87
+ except Exception as e:
88
+ raise RuntimeError(f"Error during pooling and dropout: {e}")
89
+
90
+ total_loss = 0
91
+ logits = []
92
+ losses = []
93
+
94
+ for task_id, (head, num_labels) in enumerate(
95
+ zip(self.classification_heads, self.num_labels_list)
96
+ ):
97
+ try:
98
+ task_logits = head(pooled_output)
99
+ except Exception as e:
100
+ raise RuntimeError(
101
+ f"Error during forward pass of classification head {task_id}: {e}"
102
+ )
103
+
104
+ logits.append(task_logits)
105
+
106
+ if labels is not None:
107
+ try:
108
+ loss_fct = nn.CrossEntropyLoss()
109
+ task_loss = loss_fct(
110
+ task_logits.view(-1, num_labels), labels[task_id].view(-1)
111
+ )
112
+ if self.use_task_weights:
113
+ task_loss *= self.task_weights[task_id]
114
+ total_loss += task_loss
115
+ losses.append(task_loss.item())
116
+ except Exception as e:
117
+ raise RuntimeError(
118
+ f"Error during loss computation for task {task_id}: {e}"
119
+ )
120
+
121
+ return total_loss, logits, losses if labels is not None else logits
geneformer/mtl/optuna_utils.py ADDED
@@ -0,0 +1,27 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import optuna
2
+ from optuna.integration import TensorBoardCallback
3
+
4
+
5
+ def save_trial_callback(study, trial, trials_result_path):
6
+ with open(trials_result_path, "a") as f:
7
+ f.write(
8
+ f"Trial {trial.number}: Value (F1 Macro): {trial.value}, Params: {trial.params}\n"
9
+ )
10
+
11
+
12
+ def create_optuna_study(objective, n_trials, trials_result_path, tensorboard_log_dir):
13
+ study = optuna.create_study(direction="maximize")
14
+
15
+ # init TensorBoard callback
16
+ tensorboard_callback = TensorBoardCallback(
17
+ dirname=tensorboard_log_dir, metric_name="F1 Macro"
18
+ )
19
+
20
+ # callback and TensorBoard callback
21
+ callbacks = [
22
+ lambda study, trial: save_trial_callback(study, trial, trials_result_path),
23
+ tensorboard_callback,
24
+ ]
25
+
26
+ study.optimize(objective, n_trials=n_trials, callbacks=callbacks)
27
+ return study
geneformer/mtl/train.py ADDED
@@ -0,0 +1,380 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ import random
3
+
4
+ import numpy as np
5
+ import pandas as pd
6
+ import torch
7
+ from torch.utils.tensorboard import SummaryWriter
8
+ from tqdm import tqdm
9
+
10
+ from .imports import *
11
+ from .model import GeneformerMultiTask
12
+ from .utils import calculate_task_specific_metrics, get_layer_freeze_range
13
+
14
+
15
+ def set_seed(seed):
16
+ random.seed(seed)
17
+ np.random.seed(seed)
18
+ torch.manual_seed(seed)
19
+ torch.cuda.manual_seed_all(seed)
20
+ torch.backends.cudnn.deterministic = True
21
+ torch.backends.cudnn.benchmark = False
22
+
23
+
24
+ def initialize_wandb(config):
25
+ if config.get("use_wandb", False):
26
+ import wandb
27
+
28
+ wandb.init(project=config["wandb_project"], config=config)
29
+ print("Weights & Biases (wandb) initialized and will be used for logging.")
30
+ else:
31
+ print(
32
+ "Weights & Biases (wandb) is not enabled. Logging will use other methods."
33
+ )
34
+
35
+
36
+ def create_model(config, num_labels_list, device):
37
+ model = GeneformerMultiTask(
38
+ config["pretrained_path"],
39
+ num_labels_list,
40
+ dropout_rate=config["dropout_rate"],
41
+ use_task_weights=config["use_task_weights"],
42
+ task_weights=config["task_weights"],
43
+ max_layers_to_freeze=config["max_layers_to_freeze"],
44
+ use_attention_pooling=config["use_attention_pooling"],
45
+ )
46
+ if config["use_data_parallel"]:
47
+ model = nn.DataParallel(model)
48
+ return model.to(device)
49
+
50
+
51
+ def setup_optimizer_and_scheduler(model, config, total_steps):
52
+ optimizer = AdamW(
53
+ model.parameters(),
54
+ lr=config["learning_rate"],
55
+ weight_decay=config["weight_decay"],
56
+ )
57
+ warmup_steps = int(config["warmup_ratio"] * total_steps)
58
+
59
+ if config["lr_scheduler_type"] == "linear":
60
+ scheduler = get_linear_schedule_with_warmup(
61
+ optimizer, num_warmup_steps=warmup_steps, num_training_steps=total_steps
62
+ )
63
+ elif config["lr_scheduler_type"] == "cosine":
64
+ scheduler = get_cosine_schedule_with_warmup(
65
+ optimizer,
66
+ num_warmup_steps=warmup_steps,
67
+ num_training_steps=total_steps,
68
+ num_cycles=0.5,
69
+ )
70
+
71
+ return optimizer, scheduler
72
+
73
+
74
+ def train_epoch(
75
+ model, train_loader, optimizer, scheduler, device, config, writer, epoch
76
+ ):
77
+ model.train()
78
+ progress_bar = tqdm(train_loader, desc=f"Epoch {epoch+1}/{config['epochs']}")
79
+ for batch_idx, batch in enumerate(progress_bar):
80
+ optimizer.zero_grad()
81
+ input_ids = batch["input_ids"].to(device)
82
+ attention_mask = batch["attention_mask"].to(device)
83
+ labels = [
84
+ batch["labels"][task_name].to(device) for task_name in config["task_names"]
85
+ ]
86
+
87
+ loss, _, _ = model(input_ids, attention_mask, labels)
88
+ loss.backward()
89
+
90
+ if config["gradient_clipping"]:
91
+ torch.nn.utils.clip_grad_norm_(model.parameters(), config["max_grad_norm"])
92
+
93
+ optimizer.step()
94
+ scheduler.step()
95
+
96
+ writer.add_scalar(
97
+ "Training Loss", loss.item(), epoch * len(train_loader) + batch_idx
98
+ )
99
+ if config.get("use_wandb", False):
100
+ import wandb
101
+
102
+ wandb.log({"Training Loss": loss.item()})
103
+
104
+ # Update progress bar
105
+ progress_bar.set_postfix({"loss": f"{loss.item():.4f}"})
106
+
107
+ return loss.item() # Return the last batch loss
108
+
109
+
110
+ def validate_model(model, val_loader, device, config):
111
+ model.eval()
112
+ val_loss = 0.0
113
+ task_true_labels = {task_name: [] for task_name in config["task_names"]}
114
+ task_pred_labels = {task_name: [] for task_name in config["task_names"]}
115
+ task_pred_probs = {task_name: [] for task_name in config["task_names"]}
116
+
117
+ with torch.no_grad():
118
+ for batch in val_loader:
119
+ input_ids = batch["input_ids"].to(device)
120
+ attention_mask = batch["attention_mask"].to(device)
121
+ labels = [
122
+ batch["labels"][task_name].to(device)
123
+ for task_name in config["task_names"]
124
+ ]
125
+ loss, logits, _ = model(input_ids, attention_mask, labels)
126
+ val_loss += loss.item()
127
+
128
+ for sample_idx in range(len(batch["input_ids"])):
129
+ for i, task_name in enumerate(config["task_names"]):
130
+ true_label = batch["labels"][task_name][sample_idx].item()
131
+ pred_label = torch.argmax(logits[i][sample_idx], dim=-1).item()
132
+ pred_prob = (
133
+ torch.softmax(logits[i][sample_idx], dim=-1).cpu().numpy()
134
+ )
135
+ task_true_labels[task_name].append(true_label)
136
+ task_pred_labels[task_name].append(pred_label)
137
+ task_pred_probs[task_name].append(pred_prob)
138
+
139
+ val_loss /= len(val_loader)
140
+ return val_loss, task_true_labels, task_pred_labels, task_pred_probs
141
+
142
+
143
+ def log_metrics(task_metrics, val_loss, config, writer, epochs):
144
+ for task_name, metrics in task_metrics.items():
145
+ print(
146
+ f"{task_name} - Validation F1 Macro: {metrics['f1']:.4f}, Validation Accuracy: {metrics['accuracy']:.4f}"
147
+ )
148
+ if config.get("use_wandb", False):
149
+ import wandb
150
+
151
+ wandb.log(
152
+ {
153
+ f"{task_name} Validation F1 Macro": metrics["f1"],
154
+ f"{task_name} Validation Accuracy": metrics["accuracy"],
155
+ }
156
+ )
157
+
158
+ writer.add_scalar("Validation Loss", val_loss, epochs)
159
+ for task_name, metrics in task_metrics.items():
160
+ writer.add_scalar(f"{task_name} - Validation F1 Macro", metrics["f1"], epochs)
161
+ writer.add_scalar(
162
+ f"{task_name} - Validation Accuracy", metrics["accuracy"], epochs
163
+ )
164
+
165
+
166
+ def save_validation_predictions(
167
+ val_cell_id_mapping,
168
+ task_true_labels,
169
+ task_pred_labels,
170
+ task_pred_probs,
171
+ config,
172
+ trial_number=None,
173
+ ):
174
+ if trial_number is not None:
175
+ trial_results_dir = os.path.join(config["results_dir"], f"trial_{trial_number}")
176
+ os.makedirs(trial_results_dir, exist_ok=True)
177
+ val_preds_file = os.path.join(trial_results_dir, "val_preds.csv")
178
+ else:
179
+ val_preds_file = os.path.join(config["results_dir"], "manual_run_val_preds.csv")
180
+
181
+ rows = []
182
+ for sample_idx in range(len(val_cell_id_mapping)):
183
+ row = {"Cell ID": val_cell_id_mapping[sample_idx]}
184
+ for task_name in config["task_names"]:
185
+ row[f"{task_name} True"] = task_true_labels[task_name][sample_idx]
186
+ row[f"{task_name} Pred"] = task_pred_labels[task_name][sample_idx]
187
+ row[f"{task_name} Probabilities"] = ",".join(
188
+ map(str, task_pred_probs[task_name][sample_idx])
189
+ )
190
+ rows.append(row)
191
+
192
+ df = pd.DataFrame(rows)
193
+ df.to_csv(val_preds_file, index=False)
194
+ print(f"Validation predictions saved to {val_preds_file}")
195
+
196
+
197
+ def train_model(
198
+ config,
199
+ device,
200
+ train_loader,
201
+ val_loader,
202
+ train_cell_id_mapping,
203
+ val_cell_id_mapping,
204
+ num_labels_list,
205
+ ):
206
+ set_seed(config["seed"])
207
+ initialize_wandb(config)
208
+
209
+ model = create_model(config, num_labels_list, device)
210
+ total_steps = len(train_loader) * config["epochs"]
211
+ optimizer, scheduler = setup_optimizer_and_scheduler(model, config, total_steps)
212
+
213
+ log_dir = os.path.join(config["tensorboard_log_dir"], "manual_run")
214
+ writer = SummaryWriter(log_dir=log_dir)
215
+
216
+ epoch_progress = tqdm(range(config["epochs"]), desc="Training Progress")
217
+ for epoch in epoch_progress:
218
+ last_loss = train_epoch(
219
+ model, train_loader, optimizer, scheduler, device, config, writer, epoch
220
+ )
221
+ epoch_progress.set_postfix({"last_loss": f"{last_loss:.4f}"})
222
+
223
+ val_loss, task_true_labels, task_pred_labels, task_pred_probs = validate_model(
224
+ model, val_loader, device, config
225
+ )
226
+ task_metrics = calculate_task_specific_metrics(task_true_labels, task_pred_labels)
227
+
228
+ log_metrics(task_metrics, val_loss, config, writer, config["epochs"])
229
+ writer.close()
230
+
231
+ save_validation_predictions(
232
+ val_cell_id_mapping, task_true_labels, task_pred_labels, task_pred_probs, config
233
+ )
234
+
235
+ if config.get("use_wandb", False):
236
+ import wandb
237
+
238
+ wandb.finish()
239
+
240
+ print(f"\nFinal Validation Loss: {val_loss:.4f}")
241
+ return val_loss, model # Return both the validation loss and the trained model
242
+
243
+
244
+ def objective(
245
+ trial,
246
+ train_loader,
247
+ val_loader,
248
+ train_cell_id_mapping,
249
+ val_cell_id_mapping,
250
+ num_labels_list,
251
+ config,
252
+ device,
253
+ ):
254
+ set_seed(config["seed"]) # Set the seed before each trial
255
+ initialize_wandb(config)
256
+
257
+ # Hyperparameters
258
+ config["learning_rate"] = trial.suggest_float(
259
+ "learning_rate",
260
+ config["hyperparameters"]["learning_rate"]["low"],
261
+ config["hyperparameters"]["learning_rate"]["high"],
262
+ log=config["hyperparameters"]["learning_rate"]["log"],
263
+ )
264
+ config["warmup_ratio"] = trial.suggest_float(
265
+ "warmup_ratio",
266
+ config["hyperparameters"]["warmup_ratio"]["low"],
267
+ config["hyperparameters"]["warmup_ratio"]["high"],
268
+ )
269
+ config["weight_decay"] = trial.suggest_float(
270
+ "weight_decay",
271
+ config["hyperparameters"]["weight_decay"]["low"],
272
+ config["hyperparameters"]["weight_decay"]["high"],
273
+ )
274
+ config["dropout_rate"] = trial.suggest_float(
275
+ "dropout_rate",
276
+ config["hyperparameters"]["dropout_rate"]["low"],
277
+ config["hyperparameters"]["dropout_rate"]["high"],
278
+ )
279
+ config["lr_scheduler_type"] = trial.suggest_categorical(
280
+ "lr_scheduler_type", config["hyperparameters"]["lr_scheduler_type"]["choices"]
281
+ )
282
+ config["use_attention_pooling"] = trial.suggest_categorical(
283
+ "use_attention_pooling", [False]
284
+ )
285
+
286
+ if config["use_task_weights"]:
287
+ config["task_weights"] = [
288
+ trial.suggest_float(
289
+ f"task_weight_{i}",
290
+ config["hyperparameters"]["task_weights"]["low"],
291
+ config["hyperparameters"]["task_weights"]["high"],
292
+ )
293
+ for i in range(len(num_labels_list))
294
+ ]
295
+ weight_sum = sum(config["task_weights"])
296
+ config["task_weights"] = [
297
+ weight / weight_sum for weight in config["task_weights"]
298
+ ]
299
+ else:
300
+ config["task_weights"] = None
301
+
302
+ # Dynamic range for max_layers_to_freeze
303
+ freeze_range = get_layer_freeze_range(config["pretrained_path"])
304
+ config["max_layers_to_freeze"] = trial.suggest_int(
305
+ "max_layers_to_freeze",
306
+ freeze_range["min"],
307
+ freeze_range["max"]
308
+ )
309
+
310
+ model = create_model(config, num_labels_list, device)
311
+ total_steps = len(train_loader) * config["epochs"]
312
+ optimizer, scheduler = setup_optimizer_and_scheduler(model, config, total_steps)
313
+
314
+ log_dir = os.path.join(config["tensorboard_log_dir"], f"trial_{trial.number}")
315
+ writer = SummaryWriter(log_dir=log_dir)
316
+
317
+ for epoch in range(config["epochs"]):
318
+ train_epoch(
319
+ model, train_loader, optimizer, scheduler, device, config, writer, epoch
320
+ )
321
+
322
+ val_loss, task_true_labels, task_pred_labels, task_pred_probs = validate_model(
323
+ model, val_loader, device, config
324
+ )
325
+ task_metrics = calculate_task_specific_metrics(task_true_labels, task_pred_labels)
326
+
327
+ log_metrics(task_metrics, val_loss, config, writer, config["epochs"])
328
+ writer.close()
329
+
330
+ save_validation_predictions(
331
+ val_cell_id_mapping,
332
+ task_true_labels,
333
+ task_pred_labels,
334
+ task_pred_probs,
335
+ config,
336
+ trial.number,
337
+ )
338
+
339
+ trial.set_user_attr("model_state_dict", model.state_dict())
340
+ trial.set_user_attr("task_weights", config["task_weights"])
341
+
342
+ trial.report(val_loss, config["epochs"])
343
+
344
+ if trial.should_prune():
345
+ raise optuna.TrialPruned()
346
+
347
+ if config.get("use_wandb", False):
348
+ import wandb
349
+
350
+ wandb.log(
351
+ {
352
+ "trial_number": trial.number,
353
+ "val_loss": val_loss,
354
+ **{
355
+ f"{task_name}_f1": metrics["f1"]
356
+ for task_name, metrics in task_metrics.items()
357
+ },
358
+ **{
359
+ f"{task_name}_accuracy": metrics["accuracy"]
360
+ for task_name, metrics in task_metrics.items()
361
+ },
362
+ **{
363
+ k: v
364
+ for k, v in config.items()
365
+ if k
366
+ in [
367
+ "learning_rate",
368
+ "warmup_ratio",
369
+ "weight_decay",
370
+ "dropout_rate",
371
+ "lr_scheduler_type",
372
+ "use_attention_pooling",
373
+ "max_layers_to_freeze",
374
+ ]
375
+ },
376
+ }
377
+ )
378
+ wandb.finish()
379
+
380
+ return val_loss