Alejandro Velez
commited on
Commit
·
47990ca
1
Parent(s):
5f1a697
tdc geneformer
Browse filesThis view is limited to 50 files because it contains too many changes.
See raw diff
- .gitattributes +25 -0
- .gitignore +160 -0
- .pre-commit-config.yaml +26 -0
- .readthedocs.yaml +19 -0
- MANIFEST.in +4 -0
- README.md +96 -0
- config.json +24 -0
- docs/Makefile +20 -0
- docs/make.bat +35 -0
- docs/requirements.txt +3 -0
- docs/source/_static/css/custom.css +40 -0
- docs/source/_static/gf_logo.png +0 -0
- docs/source/about.rst +49 -0
- docs/source/api.rst +51 -0
- docs/source/conf.py +80 -0
- docs/source/geneformer.classifier.rst +10 -0
- docs/source/geneformer.emb_extractor.rst +26 -0
- docs/source/geneformer.in_silico_perturber.rst +8 -0
- docs/source/geneformer.in_silico_perturber_stats.rst +25 -0
- docs/source/geneformer.mtl_classifier.rst +11 -0
- docs/source/geneformer.tokenizer.rst +15 -0
- docs/source/getstarted.rst +36 -0
- docs/source/index.rst +16 -0
- examples/cell_classification.ipynb +0 -0
- examples/extract_and_plot_cell_embeddings.ipynb +0 -0
- examples/gene_classification.ipynb +0 -0
- examples/in_silico_perturbation.ipynb +159 -0
- examples/multitask_cell_classification.ipynb +420 -0
- examples/pretraining_new_model/obtain_nonzero_median_digests.ipynb +365 -0
- examples/pretraining_new_model/pretrain_geneformer_w_deepspeed.py +167 -0
- examples/tokenizing_scRNAseq_data.ipynb +91 -0
- fine_tuned_models/gf-12L-95M-i4096_MTLCellClassifier_CELLxGENE_240522/config.json +24 -0
- fine_tuned_models/gf-6L-30M-i2048_CellClassifier_cardiomyopathies_220224/config.json +35 -0
- fine_tuned_models/gf-6L-30M-i2048_CellClassifier_cardiomyopathies_220224/trainer_state.json +150 -0
- geneformer/__init__.py +34 -0
- geneformer/classifier.py +1563 -0
- geneformer/classifier_utils.py +648 -0
- geneformer/collator_for_classification.py +667 -0
- geneformer/emb_extractor.py +863 -0
- geneformer/evaluation_utils.py +287 -0
- geneformer/in_silico_perturber.py +1579 -0
- geneformer/in_silico_perturber_stats.py +1104 -0
- geneformer/mtl/__init__.py +1 -0
- geneformer/mtl/collators.py +76 -0
- geneformer/mtl/data.py +162 -0
- geneformer/mtl/eval_utils.py +88 -0
- geneformer/mtl/imports.py +43 -0
- geneformer/mtl/model.py +121 -0
- geneformer/mtl/optuna_utils.py +27 -0
- geneformer/mtl/train.py +380 -0
.gitattributes
CHANGED
@@ -1,27 +1,41 @@
|
|
1 |
*.7z filter=lfs diff=lfs merge=lfs -text
|
2 |
*.arrow filter=lfs diff=lfs merge=lfs -text
|
3 |
*.bin filter=lfs diff=lfs merge=lfs -text
|
|
|
4 |
*.bz2 filter=lfs diff=lfs merge=lfs -text
|
5 |
*.ckpt filter=lfs diff=lfs merge=lfs -text
|
|
|
|
|
|
|
|
|
6 |
*.ftz filter=lfs diff=lfs merge=lfs -text
|
7 |
*.gz filter=lfs diff=lfs merge=lfs -text
|
8 |
*.h5 filter=lfs diff=lfs merge=lfs -text
|
9 |
*.joblib filter=lfs diff=lfs merge=lfs -text
|
10 |
*.lfs.* filter=lfs diff=lfs merge=lfs -text
|
|
|
11 |
*.mlmodel filter=lfs diff=lfs merge=lfs -text
|
12 |
*.model filter=lfs diff=lfs merge=lfs -text
|
13 |
*.msgpack filter=lfs diff=lfs merge=lfs -text
|
14 |
*.npy filter=lfs diff=lfs merge=lfs -text
|
15 |
*.npz filter=lfs diff=lfs merge=lfs -text
|
|
|
|
|
|
|
|
|
16 |
*.onnx filter=lfs diff=lfs merge=lfs -text
|
17 |
*.ot filter=lfs diff=lfs merge=lfs -text
|
18 |
*.parquet filter=lfs diff=lfs merge=lfs -text
|
19 |
*.pb filter=lfs diff=lfs merge=lfs -text
|
|
|
20 |
*.pickle filter=lfs diff=lfs merge=lfs -text
|
|
|
|
|
21 |
*.pkl filter=lfs diff=lfs merge=lfs -text
|
22 |
*.pt filter=lfs diff=lfs merge=lfs -text
|
23 |
*.pth filter=lfs diff=lfs merge=lfs -text
|
24 |
*.rar filter=lfs diff=lfs merge=lfs -text
|
|
|
25 |
*.safetensors filter=lfs diff=lfs merge=lfs -text
|
26 |
saved_model/**/* filter=lfs diff=lfs merge=lfs -text
|
27 |
*.tar.* filter=lfs diff=lfs merge=lfs -text
|
@@ -33,3 +47,14 @@ saved_model/**/* filter=lfs diff=lfs merge=lfs -text
|
|
33 |
*.zip filter=lfs diff=lfs merge=lfs -text
|
34 |
*.zst filter=lfs diff=lfs merge=lfs -text
|
35 |
*tfevents* filter=lfs diff=lfs merge=lfs -text
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
*.7z filter=lfs diff=lfs merge=lfs -text
|
2 |
*.arrow filter=lfs diff=lfs merge=lfs -text
|
3 |
*.bin filter=lfs diff=lfs merge=lfs -text
|
4 |
+
<<<<<<< HEAD
|
5 |
*.bz2 filter=lfs diff=lfs merge=lfs -text
|
6 |
*.ckpt filter=lfs diff=lfs merge=lfs -text
|
7 |
+
=======
|
8 |
+
*.bin.* filter=lfs diff=lfs merge=lfs -text
|
9 |
+
*.bz2 filter=lfs diff=lfs merge=lfs -text
|
10 |
+
>>>>>>> 09de19734bf3da83050abc74408517ba15b5b185
|
11 |
*.ftz filter=lfs diff=lfs merge=lfs -text
|
12 |
*.gz filter=lfs diff=lfs merge=lfs -text
|
13 |
*.h5 filter=lfs diff=lfs merge=lfs -text
|
14 |
*.joblib filter=lfs diff=lfs merge=lfs -text
|
15 |
*.lfs.* filter=lfs diff=lfs merge=lfs -text
|
16 |
+
<<<<<<< HEAD
|
17 |
*.mlmodel filter=lfs diff=lfs merge=lfs -text
|
18 |
*.model filter=lfs diff=lfs merge=lfs -text
|
19 |
*.msgpack filter=lfs diff=lfs merge=lfs -text
|
20 |
*.npy filter=lfs diff=lfs merge=lfs -text
|
21 |
*.npz filter=lfs diff=lfs merge=lfs -text
|
22 |
+
=======
|
23 |
+
*.model filter=lfs diff=lfs merge=lfs -text
|
24 |
+
*.msgpack filter=lfs diff=lfs merge=lfs -text
|
25 |
+
>>>>>>> 09de19734bf3da83050abc74408517ba15b5b185
|
26 |
*.onnx filter=lfs diff=lfs merge=lfs -text
|
27 |
*.ot filter=lfs diff=lfs merge=lfs -text
|
28 |
*.parquet filter=lfs diff=lfs merge=lfs -text
|
29 |
*.pb filter=lfs diff=lfs merge=lfs -text
|
30 |
+
<<<<<<< HEAD
|
31 |
*.pickle filter=lfs diff=lfs merge=lfs -text
|
32 |
+
=======
|
33 |
+
>>>>>>> 09de19734bf3da83050abc74408517ba15b5b185
|
34 |
*.pkl filter=lfs diff=lfs merge=lfs -text
|
35 |
*.pt filter=lfs diff=lfs merge=lfs -text
|
36 |
*.pth filter=lfs diff=lfs merge=lfs -text
|
37 |
*.rar filter=lfs diff=lfs merge=lfs -text
|
38 |
+
<<<<<<< HEAD
|
39 |
*.safetensors filter=lfs diff=lfs merge=lfs -text
|
40 |
saved_model/**/* filter=lfs diff=lfs merge=lfs -text
|
41 |
*.tar.* filter=lfs diff=lfs merge=lfs -text
|
|
|
47 |
*.zip filter=lfs diff=lfs merge=lfs -text
|
48 |
*.zst filter=lfs diff=lfs merge=lfs -text
|
49 |
*tfevents* filter=lfs diff=lfs merge=lfs -text
|
50 |
+
=======
|
51 |
+
saved_model/**/* filter=lfs diff=lfs merge=lfs -text
|
52 |
+
*.tar.* filter=lfs diff=lfs merge=lfs -text
|
53 |
+
*.tflite filter=lfs diff=lfs merge=lfs -text
|
54 |
+
*.tgz filter=lfs diff=lfs merge=lfs -text
|
55 |
+
*.xz filter=lfs diff=lfs merge=lfs -text
|
56 |
+
*.zip filter=lfs diff=lfs merge=lfs -text
|
57 |
+
*.zstandard filter=lfs diff=lfs merge=lfs -text
|
58 |
+
*tfevents* filter=lfs diff=lfs merge=lfs -text
|
59 |
+
model.safetensors filter=lfs diff=lfs merge=lfs -text
|
60 |
+
>>>>>>> 09de19734bf3da83050abc74408517ba15b5b185
|
.gitignore
ADDED
@@ -0,0 +1,160 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# Byte-compiled / optimized / DLL files
|
2 |
+
__pycache__/
|
3 |
+
*.py[cod]
|
4 |
+
*$py.class
|
5 |
+
|
6 |
+
# C extensions
|
7 |
+
*.so
|
8 |
+
|
9 |
+
# Distribution / packaging
|
10 |
+
.Python
|
11 |
+
build/
|
12 |
+
develop-eggs/
|
13 |
+
dist/
|
14 |
+
downloads/
|
15 |
+
eggs/
|
16 |
+
.eggs/
|
17 |
+
lib/
|
18 |
+
lib64/
|
19 |
+
parts/
|
20 |
+
sdist/
|
21 |
+
var/
|
22 |
+
wheels/
|
23 |
+
share/python-wheels/
|
24 |
+
*.egg-info/
|
25 |
+
.installed.cfg
|
26 |
+
*.egg
|
27 |
+
MANIFEST
|
28 |
+
|
29 |
+
# PyInstaller
|
30 |
+
# Usually these files are written by a python script from a template
|
31 |
+
# before PyInstaller builds the exe, so as to inject date/other infos into it.
|
32 |
+
*.manifest
|
33 |
+
*.spec
|
34 |
+
|
35 |
+
# Installer logs
|
36 |
+
pip-log.txt
|
37 |
+
pip-delete-this-directory.txt
|
38 |
+
|
39 |
+
# Unit test / coverage reports
|
40 |
+
htmlcov/
|
41 |
+
.tox/
|
42 |
+
.nox/
|
43 |
+
.coverage
|
44 |
+
.coverage.*
|
45 |
+
.cache
|
46 |
+
nosetests.xml
|
47 |
+
coverage.xml
|
48 |
+
*.cover
|
49 |
+
*.py,cover
|
50 |
+
.hypothesis/
|
51 |
+
.pytest_cache/
|
52 |
+
cover/
|
53 |
+
|
54 |
+
# Translations
|
55 |
+
*.mo
|
56 |
+
*.pot
|
57 |
+
|
58 |
+
# Django stuff:
|
59 |
+
*.log
|
60 |
+
local_settings.py
|
61 |
+
db.sqlite3
|
62 |
+
db.sqlite3-journal
|
63 |
+
|
64 |
+
# Flask stuff:
|
65 |
+
instance/
|
66 |
+
.webassets-cache
|
67 |
+
|
68 |
+
# Scrapy stuff:
|
69 |
+
.scrapy
|
70 |
+
|
71 |
+
# Sphinx documentation
|
72 |
+
docs/_build/
|
73 |
+
|
74 |
+
# PyBuilder
|
75 |
+
.pybuilder/
|
76 |
+
target/
|
77 |
+
|
78 |
+
# Jupyter Notebook
|
79 |
+
.ipynb_checkpoints
|
80 |
+
|
81 |
+
# IPython
|
82 |
+
profile_default/
|
83 |
+
ipython_config.py
|
84 |
+
|
85 |
+
# pyenv
|
86 |
+
# For a library or package, you might want to ignore these files since the code is
|
87 |
+
# intended to run in multiple environments; otherwise, check them in:
|
88 |
+
# .python-version
|
89 |
+
|
90 |
+
# pipenv
|
91 |
+
# According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control.
|
92 |
+
# However, in case of collaboration, if having platform-specific dependencies or dependencies
|
93 |
+
# having no cross-platform support, pipenv may install dependencies that don't work, or not
|
94 |
+
# install all needed dependencies.
|
95 |
+
#Pipfile.lock
|
96 |
+
|
97 |
+
# poetry
|
98 |
+
# Similar to Pipfile.lock, it is generally recommended to include poetry.lock in version control.
|
99 |
+
# This is especially recommended for binary packages to ensure reproducibility, and is more
|
100 |
+
# commonly ignored for libraries.
|
101 |
+
# https://python-poetry.org/docs/basic-usage/#commit-your-poetrylock-file-to-version-control
|
102 |
+
#poetry.lock
|
103 |
+
|
104 |
+
# pdm
|
105 |
+
# Similar to Pipfile.lock, it is generally recommended to include pdm.lock in version control.
|
106 |
+
#pdm.lock
|
107 |
+
# pdm stores project-wide configurations in .pdm.toml, but it is recommended to not include it
|
108 |
+
# in version control.
|
109 |
+
# https://pdm.fming.dev/#use-with-ide
|
110 |
+
.pdm.toml
|
111 |
+
|
112 |
+
# PEP 582; used by e.g. github.com/David-OConnor/pyflow and github.com/pdm-project/pdm
|
113 |
+
__pypackages__/
|
114 |
+
|
115 |
+
# Celery stuff
|
116 |
+
celerybeat-schedule
|
117 |
+
celerybeat.pid
|
118 |
+
|
119 |
+
# SageMath parsed files
|
120 |
+
*.sage.py
|
121 |
+
|
122 |
+
# Environments
|
123 |
+
.env
|
124 |
+
.venv
|
125 |
+
env/
|
126 |
+
venv/
|
127 |
+
ENV/
|
128 |
+
env.bak/
|
129 |
+
venv.bak/
|
130 |
+
|
131 |
+
# Spyder project settings
|
132 |
+
.spyderproject
|
133 |
+
.spyproject
|
134 |
+
|
135 |
+
# Rope project settings
|
136 |
+
.ropeproject
|
137 |
+
|
138 |
+
# mkdocs documentation
|
139 |
+
/site
|
140 |
+
|
141 |
+
# mypy
|
142 |
+
.mypy_cache/
|
143 |
+
.dmypy.json
|
144 |
+
dmypy.json
|
145 |
+
|
146 |
+
# Pyre type checker
|
147 |
+
.pyre/
|
148 |
+
|
149 |
+
# pytype static type analyzer
|
150 |
+
.pytype/
|
151 |
+
|
152 |
+
# Cython debug symbols
|
153 |
+
cython_debug/
|
154 |
+
|
155 |
+
# PyCharm
|
156 |
+
# JetBrains specific template is maintained in a separate JetBrains.gitignore that can
|
157 |
+
# be found at https://github.com/github/gitignore/blob/main/Global/JetBrains.gitignore
|
158 |
+
# and can be added to the global gitignore or merged into this file. For a more nuclear
|
159 |
+
# option (not recommended) you can uncomment the following to ignore the entire idea folder.
|
160 |
+
#.idea/
|
.pre-commit-config.yaml
ADDED
@@ -0,0 +1,26 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# See https://pre-commit.com for more information
|
2 |
+
# See https://pre-commit.com/hooks.html for more hooks
|
3 |
+
repos:
|
4 |
+
- repo: https://github.com/pre-commit/pre-commit-hooks
|
5 |
+
rev: v3.2.0
|
6 |
+
hooks:
|
7 |
+
- id: trailing-whitespace
|
8 |
+
- id: end-of-file-fixer
|
9 |
+
- id: check-yaml
|
10 |
+
- id: check-added-large-files
|
11 |
+
- id: check-merge-conflict
|
12 |
+
- id: mixed-line-ending
|
13 |
+
- id: check-docstring-first
|
14 |
+
- repo: https://github.com/pycqa/isort
|
15 |
+
rev: 5.12.0
|
16 |
+
hooks:
|
17 |
+
- id: isort
|
18 |
+
args: ["--profile", "black"]
|
19 |
+
- repo: https://github.com/astral-sh/ruff-pre-commit
|
20 |
+
# Ruff version.
|
21 |
+
rev: v0.1.4
|
22 |
+
hooks:
|
23 |
+
# Run the Ruff linter.
|
24 |
+
- id: ruff
|
25 |
+
# Run the Ruff formatter.
|
26 |
+
- id: ruff-format
|
.readthedocs.yaml
ADDED
@@ -0,0 +1,19 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# Read the Docs configuration file
|
2 |
+
|
3 |
+
# Required
|
4 |
+
version: 2
|
5 |
+
|
6 |
+
# Set the OS, Python version and other tools you might need
|
7 |
+
build:
|
8 |
+
os: ubuntu-22.04
|
9 |
+
tools:
|
10 |
+
python: "3.10"
|
11 |
+
|
12 |
+
# Build documentation in the "docs/" directory with Sphinx
|
13 |
+
sphinx:
|
14 |
+
configuration: docs/source/conf.py
|
15 |
+
|
16 |
+
# Python requirements required build your documentation
|
17 |
+
python:
|
18 |
+
install:
|
19 |
+
- requirements: docs/requirements.txt
|
MANIFEST.in
ADDED
@@ -0,0 +1,4 @@
|
|
|
|
|
|
|
|
|
|
|
1 |
+
include geneformer/gene_median_dictionary_gc95M.pkl
|
2 |
+
include geneformer/gene_name_id_dict_gc95M.pkl
|
3 |
+
include geneformer/ensembl_mapping_dict_gc95M.pkl
|
4 |
+
include geneformer/token_dictionary_gc95M.pkl
|
README.md
CHANGED
@@ -1,4 +1,5 @@
|
|
1 |
---
|
|
|
2 |
license: apache-2.0
|
3 |
datasets:
|
4 |
- ctheodoris/Genecorpus-30M
|
@@ -20,3 +21,98 @@ model = AutoModelForMaskedLM.from_pretrained("ctheodoris/Geneformer")
|
|
20 |
```
|
21 |
|
22 |
For further details see: https://huggingface.co/ctheodoris/Geneformer
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
---
|
2 |
+
<<<<<<< HEAD
|
3 |
license: apache-2.0
|
4 |
datasets:
|
5 |
- ctheodoris/Genecorpus-30M
|
|
|
21 |
```
|
22 |
|
23 |
For further details see: https://huggingface.co/ctheodoris/Geneformer
|
24 |
+
=======
|
25 |
+
datasets: ctheodoris/Genecorpus-30M
|
26 |
+
license: apache-2.0
|
27 |
+
tags:
|
28 |
+
- single-cell
|
29 |
+
- genomics
|
30 |
+
---
|
31 |
+
# Geneformer
|
32 |
+
Geneformer is a foundational transformer model pretrained on a large-scale corpus of single cell transcriptomes to enable context-aware predictions in settings with limited data in network biology.
|
33 |
+
|
34 |
+
- See [our manuscript](https://rdcu.be/ddrx0) for details of the original model trained on ~30 million transcriptomes in June 2021 and the initial report of our in silico perturbation and cell and gene classification strategies.
|
35 |
+
- See [our manuscript](https://www.biorxiv.org/content/10.1101/2024.08.16.608180v1.full.pdf) for details of the expanded model trained on ~95 million transcriptomes in April 2024 and our continual learning, multitask learning, and quantization strategies.
|
36 |
+
- See [geneformer.readthedocs.io](https://geneformer.readthedocs.io) for documentation.
|
37 |
+
|
38 |
+
# Model Description
|
39 |
+
Geneformer is a foundational transformer model pretrained on a large-scale corpus of single cell transcriptomes representing a broad range of human tissues. Geneformer was originally pretrained in June 2021 on [Genecorpus-30M](https://huggingface.co/datasets/ctheodoris/Genecorpus-30M), a corpus comprised of ~30 million single cell transcriptomes. We excluded cells with high mutational burdens (e.g. malignant cells and immortalized cell lines) that could lead to substantial network rewiring without companion genome sequencing to facilitate interpretation. Then, in April 2024, Geneformer was pretrained on ~95 million non-cancer transcriptomes, followed by continual learning on ~14 million cancer transcriptomes to yield a cancer domain-tuned model.
|
40 |
+
|
41 |
+
Each single cell’s transcriptome is presented to the model as a rank value encoding where genes are ranked by their expression in that cell scaled by their expression across the entire Genecorpus-30M. The rank value encoding provides a nonparametric representation of that cell’s transcriptome and takes advantage of the many observations of each gene’s expression across the pretraining corpus to prioritize genes that distinguish cell state. Specifically, this method will deprioritize ubiquitously highly-expressed housekeeping genes by scaling them to a lower rank. Conversely, genes such as transcription factors that may be lowly expressed when they are expressed but highly distinguish cell state will move to a higher rank within the encoding. Furthermore, this rank-based approach may be more robust against technical artifacts that may systematically bias the absolute transcript counts value while the overall relative ranking of genes within each cell remains more stable.
|
42 |
+
|
43 |
+
The rank value encoding of each single cell’s transcriptome then proceeds through N layers of transformer encoder units, where N varies dependent on the model size. Pretraining was accomplished using a masked learning objective where 15% of the genes within each transcriptome were masked and the model was trained to predict which gene should be within each masked position in that specific cell state using the context of the remaining unmasked genes. A major strength of this approach is that it is entirely self-supervised and can be accomplished on completely unlabeled data, which allows the inclusion of large amounts of training data without being restricted to samples with accompanying labels.
|
44 |
+
|
45 |
+
We detail applications and results in [our manuscript](https://rdcu.be/ddrx0).
|
46 |
+
|
47 |
+
During pretraining, Geneformer gained a fundamental understanding of network dynamics, encoding network hierarchy in the model’s attention weights in a completely self-supervised manner. With both zero-shot learning and fine-tuning with limited task-specific data, Geneformer consistently boosted predictive accuracy in a diverse panel of downstream tasks relevant to chromatin and network dynamics. In silico perturbation with zero-shot learning identified a novel transcription factor in cardiomyocytes that we experimentally validated to be critical to their ability to generate contractile force. In silico treatment with limited patient data revealed candidate therapeutic targets for cardiomyopathy that we experimentally validated to significantly improve the ability of cardiomyocytes to generate contractile force in an induced pluripotent stem cell (iPSC) model of the disease. Overall, Geneformer represents a foundational deep learning model pretrained on a large-scale corpus human single cell transcriptomes to gain a fundamental understanding of gene network dynamics that can now be democratized to a vast array of downstream tasks to accelerate discovery of key network regulators and candidate therapeutic targets.
|
48 |
+
|
49 |
+
The repository includes the following pretrained models:
|
50 |
+
|
51 |
+
L=layers\
|
52 |
+
M=millions of cells used for pretraining\
|
53 |
+
i=input size\
|
54 |
+
(pretraining date)
|
55 |
+
|
56 |
+
- GF-6L-30M-i2048 (June 2021)
|
57 |
+
- GF-12L-30M-i2048 (June 2021)
|
58 |
+
- GF-12L-95M-i4096 (April 2024)
|
59 |
+
- GF-20L-95M-i4096 (April 2024)
|
60 |
+
|
61 |
+
The current default model in the main directory of the repository is GF-12L-95M-i4096.
|
62 |
+
|
63 |
+
The repository also contains fined tuned models in the fine_tuned_models directory and the cancer-tuned model following continual learning on ~14 million cancer cells, GF-12L-95M-i4096_CLcancer.
|
64 |
+
|
65 |
+
# Application
|
66 |
+
The pretrained Geneformer model can be used directly for zero-shot learning, for example for in silico perturbation analysis, or by fine-tuning towards the relevant downstream task, such as gene or cell state classification.
|
67 |
+
|
68 |
+
Example applications demonstrated in [our manuscript](https://rdcu.be/ddrx0) include:
|
69 |
+
|
70 |
+
*Fine-tuning*:
|
71 |
+
- transcription factor dosage sensitivity
|
72 |
+
- chromatin dynamics (bivalently marked promoters)
|
73 |
+
- transcription factor regulatory range
|
74 |
+
- gene network centrality
|
75 |
+
- transcription factor targets
|
76 |
+
- cell type annotation
|
77 |
+
- batch integration
|
78 |
+
- cell state classification across differentiation
|
79 |
+
- disease classification
|
80 |
+
- in silico perturbation to determine disease-driving genes
|
81 |
+
- in silico treatment to determine candidate therapeutic targets
|
82 |
+
|
83 |
+
*Zero-shot learning*:
|
84 |
+
- batch integration
|
85 |
+
- gene context specificity
|
86 |
+
- in silico reprogramming
|
87 |
+
- in silico differentiation
|
88 |
+
- in silico perturbation to determine impact on cell state
|
89 |
+
- in silico perturbation to determine transcription factor targets
|
90 |
+
- in silico perturbation to determine transcription factor cooperativity
|
91 |
+
|
92 |
+
# Installation
|
93 |
+
In addition to the pretrained model, contained herein are functions for tokenizing and collating data specific to single cell transcriptomics, pretraining the model, fine-tuning the model, extracting and plotting cell embeddings, and performing in silico pertrubation with either the pretrained or fine-tuned models. To install (~20s):
|
94 |
+
|
95 |
+
```bash
|
96 |
+
# Make sure you have git-lfs installed (https://git-lfs.com)
|
97 |
+
git lfs install
|
98 |
+
git clone https://huggingface.co/ctheodoris/Geneformer
|
99 |
+
cd Geneformer
|
100 |
+
pip install .
|
101 |
+
```
|
102 |
+
|
103 |
+
For usage, see [examples](https://huggingface.co/ctheodoris/Geneformer/tree/main/examples) for:
|
104 |
+
- tokenizing transcriptomes
|
105 |
+
- pretraining
|
106 |
+
- hyperparameter tuning
|
107 |
+
- fine-tuning
|
108 |
+
- extracting and plotting cell embeddings
|
109 |
+
- in silico perturbation
|
110 |
+
|
111 |
+
Please note that the fine-tuning examples are meant to be generally applicable and the input datasets and labels will vary dependent on the downstream task. Example input files for a few of the downstream tasks demonstrated in the manuscript are located within the [example_input_files directory](https://huggingface.co/datasets/ctheodoris/Genecorpus-30M/tree/main/example_input_files) in the dataset repository, but these only represent a few example fine-tuning applications.
|
112 |
+
|
113 |
+
Please note that GPU resources are required for efficient usage of Geneformer. Additionally, we strongly recommend tuning hyperparameters for each downstream fine-tuning application as this can significantly boost predictive potential in the downstream task (e.g. max learning rate, learning schedule, number of layers to freeze, etc.).
|
114 |
+
|
115 |
+
# Citations
|
116 |
+
- C V Theodoris#, L Xiao, A Chopra, M D Chaffin, Z R Al Sayed, M C Hill, H Mantineo, E Brydon, Z Zeng, X S Liu, P T Ellinor#. Transfer learning enables predictions in network biology. _**Nature**_, 31 May 2023. (#co-corresponding authors)
|
117 |
+
- H Chen*, M S Venkatesh*, J Gomez Ortega, S V Mahesh, T Nandi, R Madduri, K Pelka†, C V Theodoris†#. Quantized multi-task learning for context-specific representations of gene network dynamics. _**bioRxiv**_, 19 Aug 2024. (*co-first authors, †co-senior authors, #corresponding author)
|
118 |
+
>>>>>>> 09de19734bf3da83050abc74408517ba15b5b185
|
config.json
ADDED
@@ -0,0 +1,24 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
{
|
2 |
+
"architectures": [
|
3 |
+
"BertForMaskedLM"
|
4 |
+
],
|
5 |
+
"attention_probs_dropout_prob": 0.02,
|
6 |
+
"classifier_dropout": null,
|
7 |
+
"hidden_act": "relu",
|
8 |
+
"hidden_dropout_prob": 0.02,
|
9 |
+
"hidden_size": 512,
|
10 |
+
"initializer_range": 0.02,
|
11 |
+
"intermediate_size": 1024,
|
12 |
+
"layer_norm_eps": 1e-12,
|
13 |
+
"max_position_embeddings": 4096,
|
14 |
+
"model_type": "bert",
|
15 |
+
"num_attention_heads": 8,
|
16 |
+
"num_hidden_layers": 12,
|
17 |
+
"pad_token_id": 0,
|
18 |
+
"position_embedding_type": "absolute",
|
19 |
+
"torch_dtype": "float32",
|
20 |
+
"transformers_version": "4.37.1",
|
21 |
+
"type_vocab_size": 2,
|
22 |
+
"use_cache": true,
|
23 |
+
"vocab_size": 20275
|
24 |
+
}
|
docs/Makefile
ADDED
@@ -0,0 +1,20 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# Minimal makefile for Sphinx documentation
|
2 |
+
#
|
3 |
+
|
4 |
+
# You can set these variables from the command line, and also
|
5 |
+
# from the environment for the first two.
|
6 |
+
SPHINXOPTS ?=
|
7 |
+
SPHINXBUILD ?= sphinx-build
|
8 |
+
SOURCEDIR = source
|
9 |
+
BUILDDIR = build
|
10 |
+
|
11 |
+
# Put it first so that "make" without argument is like "make help".
|
12 |
+
help:
|
13 |
+
@$(SPHINXBUILD) -M help "$(SOURCEDIR)" "$(BUILDDIR)" $(SPHINXOPTS) $(O)
|
14 |
+
|
15 |
+
.PHONY: help Makefile
|
16 |
+
|
17 |
+
# Catch-all target: route all unknown targets to Sphinx using the new
|
18 |
+
# "make mode" option. $(O) is meant as a shortcut for $(SPHINXOPTS).
|
19 |
+
%: Makefile
|
20 |
+
@$(SPHINXBUILD) -M $@ "$(SOURCEDIR)" "$(BUILDDIR)" $(SPHINXOPTS) $(O)
|
docs/make.bat
ADDED
@@ -0,0 +1,35 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
@ECHO OFF
|
2 |
+
|
3 |
+
pushd %~dp0
|
4 |
+
|
5 |
+
REM Command file for Sphinx documentation
|
6 |
+
|
7 |
+
if "%SPHINXBUILD%" == "" (
|
8 |
+
set SPHINXBUILD=sphinx-build
|
9 |
+
)
|
10 |
+
set SOURCEDIR=source
|
11 |
+
set BUILDDIR=build
|
12 |
+
|
13 |
+
%SPHINXBUILD% >NUL 2>NUL
|
14 |
+
if errorlevel 9009 (
|
15 |
+
echo.
|
16 |
+
echo.The 'sphinx-build' command was not found. Make sure you have Sphinx
|
17 |
+
echo.installed, then set the SPHINXBUILD environment variable to point
|
18 |
+
echo.to the full path of the 'sphinx-build' executable. Alternatively you
|
19 |
+
echo.may add the Sphinx directory to PATH.
|
20 |
+
echo.
|
21 |
+
echo.If you don't have Sphinx installed, grab it from
|
22 |
+
echo.https://www.sphinx-doc.org/
|
23 |
+
exit /b 1
|
24 |
+
)
|
25 |
+
|
26 |
+
if "%1" == "" goto help
|
27 |
+
|
28 |
+
%SPHINXBUILD% -M %1 %SOURCEDIR% %BUILDDIR% %SPHINXOPTS% %O%
|
29 |
+
goto end
|
30 |
+
|
31 |
+
:help
|
32 |
+
%SPHINXBUILD% -M help %SOURCEDIR% %BUILDDIR% %SPHINXOPTS% %O%
|
33 |
+
|
34 |
+
:end
|
35 |
+
popd
|
docs/requirements.txt
ADDED
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
1 |
+
.
|
2 |
+
sphinx_rtd_theme==2.0.0
|
3 |
+
nbsphinx==0.9.3
|
docs/source/_static/css/custom.css
ADDED
@@ -0,0 +1,40 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
/* top left logo */
|
2 |
+
.wy-side-nav-search, .wy-nav-top {
|
3 |
+
background: linear-gradient(15deg, #13547a 0%, #80d0c7 100%);
|
4 |
+
}
|
5 |
+
|
6 |
+
|
7 |
+
/* unvisited link */
|
8 |
+
.wy-nav-content a:link {
|
9 |
+
color: #067abd;
|
10 |
+
}
|
11 |
+
|
12 |
+
/* visited link */
|
13 |
+
.wy-nav-content a:visited {
|
14 |
+
color: #4b827c;
|
15 |
+
}
|
16 |
+
|
17 |
+
/* mouse over link */
|
18 |
+
.wy-nav-content a:hover {
|
19 |
+
color: #80d0c7;
|
20 |
+
}
|
21 |
+
|
22 |
+
/* selected link */
|
23 |
+
.wy-nav-content a:active {
|
24 |
+
color: #4b827c;
|
25 |
+
}
|
26 |
+
|
27 |
+
/* class object */
|
28 |
+
.sig.sig-object {
|
29 |
+
padding: 5px 5px 5px 5px;
|
30 |
+
background-color: #ececec;
|
31 |
+
border-style: solid;
|
32 |
+
border-color: black;
|
33 |
+
border-width: 1px 0;
|
34 |
+
}
|
35 |
+
|
36 |
+
/* parameter object */
|
37 |
+
dt {
|
38 |
+
padding: 5px 5px 5px 5px;
|
39 |
+
background-color: #ececec;
|
40 |
+
}
|
docs/source/_static/gf_logo.png
ADDED
docs/source/about.rst
ADDED
@@ -0,0 +1,49 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
About
|
2 |
+
=====
|
3 |
+
|
4 |
+
Model Description
|
5 |
+
-----------------
|
6 |
+
|
7 |
+
**Geneformer** is a context-aware, attention-based deep learning model pretrained on a large-scale corpus of single-cell transcriptomes to enable context-specific predictions in settings with limited data in network biology. During pretraining, Geneformer gained a fundamental understanding of network dynamics, encoding network hierarchy in the attention weights of the model in a completely self-supervised manner. With both zero-shot learning and fine-tuning with limited task-specific data, Geneformer consistently boosted predictive accuracy in a diverse panel of downstream tasks relevant to chromatin and network dynamics. In silico perturbation with zero-shot learning identified a novel transcription factor in cardiomyocytes that we experimentally validated to be critical to their ability to generate contractile force. In silico treatment with limited patient data revealed candidate therapeutic targets for cardiomyopathy that we experimentally validated to significantly improve the ability of cardiomyocytes to generate contractile force in an iPSC model of the disease. Overall, Geneformer represents a foundational deep learning model pretrained on a large-scale corpus of human single cell transcriptomes to gain a fundamental understanding of gene network dynamics that can now be democratized to a vast array of downstream tasks to accelerate discovery of key network regulators and candidate therapeutic targets.
|
8 |
+
|
9 |
+
In `our manuscript <https://rdcu.be/ddrx0>`_, we report results for the original 6 layer Geneformer model pretrained on Genecorpus-30M. We additionally provide within the repository a 12 layer Geneformer model, scaled up with retained width:depth aspect ratio, also pretrained on Genecorpus-30M.
|
10 |
+
|
11 |
+
Both the `6 <https://huggingface.co/ctheodoris/Geneformer/blob/main/gf-6L-30M-i2048/model.safetensors>`_ and `12 <https://huggingface.co/ctheodoris/Geneformer/blob/main/gf-12L-30M-i2048/pytorch_model.bin>`_ layer Geneformer models were pretrained in June 2021.
|
12 |
+
|
13 |
+
Also see `our 2024 manuscript <https://www.biorxiv.org/content/10.1101/2024.08.16.608180v1.full.pdf>`_, for details of the `expanded model <https://huggingface.co/ctheodoris/Geneformer/blob/main/model.safetensors>`_ trained on ~95 million transcriptomes in April 2024 and our continual learning, multitask learning, and quantization strategies.
|
14 |
+
|
15 |
+
Application
|
16 |
+
-----------
|
17 |
+
|
18 |
+
The pretrained Geneformer model can be used directly for zero-shot learning, for example for in silico perturbation analysis, or by fine-tuning towards the relevant downstream task, such as gene or cell state classification.
|
19 |
+
|
20 |
+
Example applications demonstrated in `our manuscript <https://rdcu.be/ddrx0>`_ include:
|
21 |
+
|
22 |
+
| *Fine-tuning*:
|
23 |
+
| - transcription factor dosage sensitivity
|
24 |
+
| - chromatin dynamics (bivalently marked promoters)
|
25 |
+
| - transcription factor regulatory range
|
26 |
+
| - gene network centrality
|
27 |
+
| - transcription factor targets
|
28 |
+
| - cell type annotation
|
29 |
+
| - batch integration
|
30 |
+
| - cell state classification across differentiation
|
31 |
+
| - disease classification
|
32 |
+
| - in silico perturbation to determine disease-driving genes
|
33 |
+
| - in silico treatment to determine candidate therapeutic targets
|
34 |
+
|
35 |
+
| *Zero-shot learning*:
|
36 |
+
| - batch integration
|
37 |
+
| - gene context specificity
|
38 |
+
| - in silico reprogramming
|
39 |
+
| - in silico differentiation
|
40 |
+
| - in silico perturbation to determine impact on cell state
|
41 |
+
| - in silico perturbation to determine transcription factor targets
|
42 |
+
| - in silico perturbation to determine transcription factor cooperativity
|
43 |
+
|
44 |
+
Citations
|
45 |
+
---------
|
46 |
+
|
47 |
+
| C V Theodoris #, L Xiao, A Chopra, M D Chaffin, Z R Al Sayed, M C Hill, H Mantineo, E Brydon, Z Zeng, X S Liu, P T Ellinor #. `Transfer learning enables predictions in network biology. <https://rdcu.be/ddrx0>`_ *Nature*, 31 May 2023. (# co-corresponding authors)
|
48 |
+
|
49 |
+
| H Chen \*, M S Venkatesh \*, J Gomez Ortega, S V Mahesh, T Nandi, R Madduri, K Pelka †, C V Theodoris † #. `Quantized multi-task learning for context-specific representations of gene network dynamics. <https://www.biorxiv.org/content/10.1101/2024.08.16.608180v1.full.pdf>`_ *bioRxiv*, 19 Aug 2024. (\* co-first authors, † co-senior authors, # corresponding author)
|
docs/source/api.rst
ADDED
@@ -0,0 +1,51 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
API
|
2 |
+
===
|
3 |
+
|
4 |
+
Tokenizer
|
5 |
+
---------
|
6 |
+
|
7 |
+
.. toctree::
|
8 |
+
:maxdepth: 1
|
9 |
+
|
10 |
+
geneformer.tokenizer
|
11 |
+
|
12 |
+
Classifier
|
13 |
+
----------
|
14 |
+
|
15 |
+
.. toctree::
|
16 |
+
:maxdepth: 1
|
17 |
+
|
18 |
+
geneformer.classifier
|
19 |
+
|
20 |
+
Multitask Classifier
|
21 |
+
--------------------
|
22 |
+
|
23 |
+
.. toctree::
|
24 |
+
:maxdepth: 1
|
25 |
+
|
26 |
+
geneformer.mtl_classifier
|
27 |
+
|
28 |
+
Embedding Extractor
|
29 |
+
-------------------
|
30 |
+
|
31 |
+
.. toctree::
|
32 |
+
:maxdepth: 1
|
33 |
+
|
34 |
+
geneformer.emb_extractor
|
35 |
+
|
36 |
+
In Silico Perturber
|
37 |
+
-------------------
|
38 |
+
|
39 |
+
.. toctree::
|
40 |
+
:maxdepth: 1
|
41 |
+
|
42 |
+
geneformer.in_silico_perturber
|
43 |
+
|
44 |
+
|
45 |
+
In Silico Perturber Stats
|
46 |
+
-------------------------
|
47 |
+
|
48 |
+
.. toctree::
|
49 |
+
:maxdepth: 1
|
50 |
+
|
51 |
+
geneformer.in_silico_perturber_stats
|
docs/source/conf.py
ADDED
@@ -0,0 +1,80 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# Configuration file for the Sphinx documentation builder.
|
2 |
+
#
|
3 |
+
# For the full list of built-in configuration values, see the documentation:
|
4 |
+
# https://www.sphinx-doc.org/en/master/usage/configuration.html
|
5 |
+
|
6 |
+
import pathlib
|
7 |
+
import re
|
8 |
+
import sys
|
9 |
+
|
10 |
+
from sphinx.ext import autodoc
|
11 |
+
|
12 |
+
sys.path.insert(0, pathlib.Path(__file__).parents[2].resolve().as_posix())
|
13 |
+
|
14 |
+
|
15 |
+
# -- Project information -----------------------------------------------------
|
16 |
+
# https://www.sphinx-doc.org/en/master/usage/configuration.html#project-information
|
17 |
+
|
18 |
+
project = "geneformer"
|
19 |
+
copyright = "2024, Christina Theodoris"
|
20 |
+
author = "Christina Theodoris"
|
21 |
+
release = "0.1.0"
|
22 |
+
repository_url = "https://huggingface.co/ctheodoris/Geneformer"
|
23 |
+
|
24 |
+
# -- General configuration ---------------------------------------------------
|
25 |
+
# https://www.sphinx-doc.org/en/master/usage/configuration.html#general-configuration
|
26 |
+
|
27 |
+
extensions = [
|
28 |
+
"sphinx.ext.autodoc",
|
29 |
+
"sphinx.ext.autosummary",
|
30 |
+
"nbsphinx",
|
31 |
+
"sphinx.ext.viewcode",
|
32 |
+
"sphinx.ext.doctest",
|
33 |
+
]
|
34 |
+
|
35 |
+
templates_path = ["_templates"]
|
36 |
+
exclude_patterns = [
|
37 |
+
"**.ipynb_checkpoints",
|
38 |
+
]
|
39 |
+
autoclass_content = "both"
|
40 |
+
|
41 |
+
|
42 |
+
class MockedClassDocumenter(autodoc.ClassDocumenter):
|
43 |
+
def add_line(self, line: str, source: str, *lineno: int) -> None:
|
44 |
+
if line == " Bases: :py:class:`object`":
|
45 |
+
return
|
46 |
+
super().add_line(line, source, *lineno)
|
47 |
+
|
48 |
+
|
49 |
+
autodoc.ClassDocumenter = MockedClassDocumenter
|
50 |
+
add_module_names = False
|
51 |
+
|
52 |
+
|
53 |
+
def process_signature(app, what, name, obj, options, signature, return_annotation):
|
54 |
+
# loop through each line in the docstring and replace path with
|
55 |
+
# the generic path text
|
56 |
+
signature = re.sub(r"PosixPath\(.*?\)", "FILEPATH", signature)
|
57 |
+
return (signature, None)
|
58 |
+
|
59 |
+
|
60 |
+
def setup(app):
|
61 |
+
app.connect("autodoc-process-signature", process_signature)
|
62 |
+
|
63 |
+
|
64 |
+
# -- Options for HTML output -------------------------------------------------
|
65 |
+
# https://www.sphinx-doc.org/en/master/usage/configuration.html#options-for-html-output
|
66 |
+
|
67 |
+
html_theme = "sphinx_rtd_theme"
|
68 |
+
html_show_sphinx = False
|
69 |
+
html_static_path = ["_static"]
|
70 |
+
html_logo = "_static/gf_logo.png"
|
71 |
+
html_theme_options = {
|
72 |
+
"collapse_navigation": False,
|
73 |
+
"sticky_navigation": True,
|
74 |
+
"navigation_depth": 3,
|
75 |
+
"logo_only": True,
|
76 |
+
}
|
77 |
+
html_css_files = [
|
78 |
+
"css/custom.css",
|
79 |
+
]
|
80 |
+
html_show_sourcelink = False
|
docs/source/geneformer.classifier.rst
ADDED
@@ -0,0 +1,10 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
geneformer.classifier
|
2 |
+
=====================
|
3 |
+
|
4 |
+
.. automodule:: geneformer.classifier
|
5 |
+
:members:
|
6 |
+
:undoc-members:
|
7 |
+
:show-inheritance:
|
8 |
+
:exclude-members:
|
9 |
+
valid_option_dict,
|
10 |
+
validate_options
|
docs/source/geneformer.emb_extractor.rst
ADDED
@@ -0,0 +1,26 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
geneformer.emb\_extractor
|
2 |
+
=========================
|
3 |
+
|
4 |
+
.. automodule:: geneformer.emb_extractor
|
5 |
+
:members:
|
6 |
+
:undoc-members:
|
7 |
+
:show-inheritance:
|
8 |
+
:exclude-members:
|
9 |
+
accumulate_tdigests,
|
10 |
+
gen_heatmap_class_colors,
|
11 |
+
gen_heatmap_class_dict,
|
12 |
+
get_embs,
|
13 |
+
label_cell_embs,
|
14 |
+
label_gene_embs,
|
15 |
+
make_colorbar,
|
16 |
+
plot_heatmap,
|
17 |
+
plot_umap,
|
18 |
+
summarize_gene_embs,
|
19 |
+
tdigest_mean,
|
20 |
+
tdigest_median,
|
21 |
+
test_emb,
|
22 |
+
update_tdigest_dict,
|
23 |
+
update_tdigest_dict_mean,
|
24 |
+
update_tdigest_dict_median,
|
25 |
+
valid_option_dict,
|
26 |
+
validate_options
|
docs/source/geneformer.in_silico_perturber.rst
ADDED
@@ -0,0 +1,8 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
geneformer.in\_silico\_perturber
|
2 |
+
=======================================
|
3 |
+
|
4 |
+
.. automodule:: geneformer.in_silico_perturber
|
5 |
+
:members:
|
6 |
+
:undoc-members:
|
7 |
+
:show-inheritance:
|
8 |
+
:exclude-members: valid_option_dict, validate_options, apply_additional_filters, isp_perturb_all, isp_perturb_set, update_perturbation_dictionary
|
docs/source/geneformer.in_silico_perturber_stats.rst
ADDED
@@ -0,0 +1,25 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
geneformer.in\_silico\_perturber\_stats
|
2 |
+
==============================================
|
3 |
+
|
4 |
+
.. automodule:: geneformer.in_silico_perturber_stats
|
5 |
+
:members:
|
6 |
+
:undoc-members:
|
7 |
+
:show-inheritance:
|
8 |
+
:exclude-members:
|
9 |
+
find,
|
10 |
+
get_fdr,
|
11 |
+
get_gene_list,
|
12 |
+
get_impact_component,
|
13 |
+
invert_dict,
|
14 |
+
isp_aggregate_gene_shifts,
|
15 |
+
isp_aggregate_grouped_perturb,
|
16 |
+
isp_stats_mixture_model,
|
17 |
+
isp_stats_to_goal_state,
|
18 |
+
isp_stats_vs_null,
|
19 |
+
n_detections,
|
20 |
+
read_dict,
|
21 |
+
read_dictionaries,
|
22 |
+
token_to_gene_name,
|
23 |
+
token_tuple_to_ensembl_ids,
|
24 |
+
valid_option_dict,
|
25 |
+
validate_options
|
docs/source/geneformer.mtl_classifier.rst
ADDED
@@ -0,0 +1,11 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
geneformer.mtl\_classifier
|
2 |
+
==========================
|
3 |
+
|
4 |
+
.. automodule:: geneformer.mtl_classifier
|
5 |
+
:members:
|
6 |
+
:undoc-members:
|
7 |
+
:show-inheritance:
|
8 |
+
:exclude-members:
|
9 |
+
valid_option_dict,
|
10 |
+
validate_options,
|
11 |
+
validate_additional_options
|
docs/source/geneformer.tokenizer.rst
ADDED
@@ -0,0 +1,15 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
geneformer.tokenizer
|
2 |
+
====================
|
3 |
+
|
4 |
+
.. automodule:: geneformer.tokenizer
|
5 |
+
:members:
|
6 |
+
:undoc-members:
|
7 |
+
:show-inheritance:
|
8 |
+
:exclude-members:
|
9 |
+
create_dataset,
|
10 |
+
tokenize_anndata,
|
11 |
+
tokenize_files,
|
12 |
+
tokenize_loom,
|
13 |
+
rank_genes,
|
14 |
+
tokenize_cell,
|
15 |
+
sum_ensembl_ids
|
docs/source/getstarted.rst
ADDED
@@ -0,0 +1,36 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
Getting Started
|
2 |
+
===============
|
3 |
+
|
4 |
+
Installation
|
5 |
+
------------
|
6 |
+
|
7 |
+
Geneformer installation instructions.
|
8 |
+
|
9 |
+
Make sure you have git-lfs installed (https://git-lfs.com).
|
10 |
+
|
11 |
+
.. code-block:: bash
|
12 |
+
|
13 |
+
git lfs install
|
14 |
+
git clone https://huggingface.co/ctheodoris/Geneformer
|
15 |
+
cd Geneformer
|
16 |
+
pip install .
|
17 |
+
|
18 |
+
|
19 |
+
Tutorials
|
20 |
+
---------
|
21 |
+
|
22 |
+
| See `examples <https://huggingface.co/ctheodoris/Geneformer/tree/main/examples>`_ for:
|
23 |
+
| - tokenizing transcriptomes
|
24 |
+
| - pretraining
|
25 |
+
| - hyperparameter tuning
|
26 |
+
| - fine-tuning
|
27 |
+
| - extracting and plotting cell embeddings
|
28 |
+
| - in silico perturbation
|
29 |
+
|
30 |
+
Please note that the fine-tuning examples are meant to be generally applicable and the input datasets and labels will vary dependent on the downstream task. Example input files for a few of the downstream tasks demonstrated in the manuscript are located within the `example_input_files directory <https://huggingface.co/datasets/ctheodoris/Genecorpus-30M/tree/main/example_input_files>`_ in the dataset repository, but these only represent a few example fine-tuning applications.
|
31 |
+
|
32 |
+
|
33 |
+
Tips
|
34 |
+
----
|
35 |
+
|
36 |
+
Please note that GPU resources are required for efficient usage of Geneformer. Additionally, we strongly recommend tuning hyperparameters for each downstream fine-tuning application as this can significantly boost predictive potential in the downstream task (e.g. max learning rate, learning schedule, number of layers to freeze, etc.).
|
docs/source/index.rst
ADDED
@@ -0,0 +1,16 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
Geneformer
|
2 |
+
==========
|
3 |
+
|
4 |
+
Geneformer is a foundation transformer model pretrained on a large-scale corpus of single cell transcriptomes to enable context-aware predictions in network biology.
|
5 |
+
|
6 |
+
See `our manuscript <https://rdcu.be/ddrx0>`_ for details.
|
7 |
+
|
8 |
+
Table of Contents
|
9 |
+
-----------------
|
10 |
+
|
11 |
+
.. toctree::
|
12 |
+
:maxdepth: 2
|
13 |
+
|
14 |
+
about
|
15 |
+
getstarted
|
16 |
+
api
|
examples/cell_classification.ipynb
ADDED
The diff for this file is too large to render.
See raw diff
|
|
examples/extract_and_plot_cell_embeddings.ipynb
ADDED
The diff for this file is too large to render.
See raw diff
|
|
examples/gene_classification.ipynb
ADDED
The diff for this file is too large to render.
See raw diff
|
|
examples/in_silico_perturbation.ipynb
ADDED
@@ -0,0 +1,159 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
{
|
2 |
+
"cells": [
|
3 |
+
{
|
4 |
+
"cell_type": "code",
|
5 |
+
"execution_count": null,
|
6 |
+
"id": "e10ac0c9-40ce-41fb-b6fa-3d62b76f2e57",
|
7 |
+
"metadata": {},
|
8 |
+
"outputs": [],
|
9 |
+
"source": [
|
10 |
+
"from geneformer import InSilicoPerturber\n",
|
11 |
+
"from geneformer import InSilicoPerturberStats\n",
|
12 |
+
"from geneformer import EmbExtractor"
|
13 |
+
]
|
14 |
+
},
|
15 |
+
{
|
16 |
+
"cell_type": "markdown",
|
17 |
+
"id": "cbd6851c-060e-4967-b816-e605ffe58b23",
|
18 |
+
"metadata": {
|
19 |
+
"tags": []
|
20 |
+
},
|
21 |
+
"source": [
|
22 |
+
"### in silico perturbation in deletion mode to determine genes whose deletion in the dilated cardiomyopathy (dcm) state significantly shifts the embedding towards non-failing (nf) state"
|
23 |
+
]
|
24 |
+
},
|
25 |
+
{
|
26 |
+
"cell_type": "code",
|
27 |
+
"execution_count": null,
|
28 |
+
"id": "c53e98cd-c603-4878-82ba-db471181bb55",
|
29 |
+
"metadata": {},
|
30 |
+
"outputs": [],
|
31 |
+
"source": [
|
32 |
+
"# first obtain start, goal, and alt embedding positions\n",
|
33 |
+
"# this function was changed to be separate from perturb_data\n",
|
34 |
+
"# to avoid repeating calcuations when parallelizing perturb_data\n",
|
35 |
+
"cell_states_to_model={\"state_key\": \"disease\", \n",
|
36 |
+
" \"start_state\": \"dcm\", \n",
|
37 |
+
" \"goal_state\": \"nf\", \n",
|
38 |
+
" \"alt_states\": [\"hcm\"]}\n",
|
39 |
+
"\n",
|
40 |
+
"filter_data_dict={\"cell_type\":[\"Cardiomyocyte1\",\"Cardiomyocyte2\",\"Cardiomyocyte3\"]}\n",
|
41 |
+
"\n",
|
42 |
+
"# OF NOTE: token_dictionary_file must be set to the gc-30M token dictionary if using a 30M series model\n",
|
43 |
+
"# (otherwise the EmbExtractor will use the current default model dictionary)\n",
|
44 |
+
"# 30M token dictionary: https://huggingface.co/ctheodoris/Geneformer/blob/main/geneformer/gene_dictionaries_30m/token_dictionary_gc30M.pkl\n",
|
45 |
+
"embex = EmbExtractor(model_type=\"CellClassifier\", # if using previously fine-tuned cell classifier model\n",
|
46 |
+
" num_classes=3,\n",
|
47 |
+
" filter_data=filter_data_dict,\n",
|
48 |
+
" max_ncells=1000,\n",
|
49 |
+
" emb_layer=0,\n",
|
50 |
+
" summary_stat=\"exact_mean\",\n",
|
51 |
+
" forward_batch_size=256,\n",
|
52 |
+
" nproc=16)\n",
|
53 |
+
"\n",
|
54 |
+
"state_embs_dict = embex.get_state_embs(cell_states_to_model,\n",
|
55 |
+
" \"../fine_tuned_models/gf-6L-30M-i2048_CellClassifier_cardiomyopathies_220224\", # example 30M fine-tuned model\n",
|
56 |
+
" \"path/to/input_data\",\n",
|
57 |
+
" \"path/to/output_directory\",\n",
|
58 |
+
" \"output_prefix\")"
|
59 |
+
]
|
60 |
+
},
|
61 |
+
{
|
62 |
+
"cell_type": "code",
|
63 |
+
"execution_count": null,
|
64 |
+
"id": "981e1190-62da-4543-b7d3-6e2a2d6a6d56",
|
65 |
+
"metadata": {
|
66 |
+
"tags": []
|
67 |
+
},
|
68 |
+
"outputs": [],
|
69 |
+
"source": [
|
70 |
+
"# OF NOTE: token_dictionary_file must be set to the gc-30M token dictionary if using a 30M series model\n",
|
71 |
+
"# (otherwise the InSilicoPerturber will use the current default model dictionary)\n",
|
72 |
+
"# 30M token dictionary: https://huggingface.co/ctheodoris/Geneformer/blob/main/geneformer/gene_dictionaries_30m/token_dictionary_gc30M.pkl\n",
|
73 |
+
"isp = InSilicoPerturber(perturb_type=\"delete\",\n",
|
74 |
+
" perturb_rank_shift=None,\n",
|
75 |
+
" genes_to_perturb=\"all\",\n",
|
76 |
+
" combos=0,\n",
|
77 |
+
" anchor_gene=None,\n",
|
78 |
+
" model_type=\"CellClassifier\", # if using previously fine-tuned cell classifier model\n",
|
79 |
+
" num_classes=3,\n",
|
80 |
+
" emb_mode=\"cell\",\n",
|
81 |
+
" cell_emb_style=\"mean_pool\",\n",
|
82 |
+
" filter_data=filter_data_dict,\n",
|
83 |
+
" cell_states_to_model=cell_states_to_model,\n",
|
84 |
+
" state_embs_dict=state_embs_dict,\n",
|
85 |
+
" max_ncells=2000,\n",
|
86 |
+
" emb_layer=0,\n",
|
87 |
+
" forward_batch_size=400,\n",
|
88 |
+
" nproc=16)"
|
89 |
+
]
|
90 |
+
},
|
91 |
+
{
|
92 |
+
"cell_type": "code",
|
93 |
+
"execution_count": null,
|
94 |
+
"id": "0525a663-871a-4ce0-a135-cc203817ffa9",
|
95 |
+
"metadata": {},
|
96 |
+
"outputs": [],
|
97 |
+
"source": [
|
98 |
+
"# outputs intermediate files from in silico perturbation\n",
|
99 |
+
"\n",
|
100 |
+
"isp.perturb_data(\"../fine_tuned_models/gf-6L-30M-i2048_CellClassifier_cardiomyopathies_220224\", # example 30M fine-tuned model\n",
|
101 |
+
" \"path/to/input_data\",\n",
|
102 |
+
" \"path/to/isp_output_directory\",\n",
|
103 |
+
" \"output_prefix\")"
|
104 |
+
]
|
105 |
+
},
|
106 |
+
{
|
107 |
+
"cell_type": "code",
|
108 |
+
"execution_count": null,
|
109 |
+
"id": "f8aadabb-516a-4dc0-b307-6de880e64e26",
|
110 |
+
"metadata": {},
|
111 |
+
"outputs": [],
|
112 |
+
"source": [
|
113 |
+
"# OF NOTE: token_dictionary_file must be set to the gc-30M token dictionary if using a 30M series model\n",
|
114 |
+
"# (otherwise the InSilicoPerturberStats will use the current default model dictionary)\n",
|
115 |
+
"# 30M token dictionary: https://huggingface.co/ctheodoris/Geneformer/blob/main/geneformer/gene_dictionaries_30m/token_dictionary_gc30M.pkl\n",
|
116 |
+
"ispstats = InSilicoPerturberStats(mode=\"goal_state_shift\",\n",
|
117 |
+
" genes_perturbed=\"all\",\n",
|
118 |
+
" combos=0,\n",
|
119 |
+
" anchor_gene=None,\n",
|
120 |
+
" cell_states_to_model=cell_states_to_model)"
|
121 |
+
]
|
122 |
+
},
|
123 |
+
{
|
124 |
+
"cell_type": "code",
|
125 |
+
"execution_count": null,
|
126 |
+
"id": "ffecfae6-e737-43e3-99e9-fa37ff46610b",
|
127 |
+
"metadata": {},
|
128 |
+
"outputs": [],
|
129 |
+
"source": [
|
130 |
+
"# extracts data from intermediate files and processes stats to output in final .csv\n",
|
131 |
+
"ispstats.get_stats(\"path/to/isp_output_directory\", # this should be the directory \n",
|
132 |
+
" None,\n",
|
133 |
+
" \"path/to/isp_stats_output_directory\",\n",
|
134 |
+
" \"output_prefix\")"
|
135 |
+
]
|
136 |
+
}
|
137 |
+
],
|
138 |
+
"metadata": {
|
139 |
+
"kernelspec": {
|
140 |
+
"display_name": "Python 3 (ipykernel)",
|
141 |
+
"language": "python",
|
142 |
+
"name": "python3"
|
143 |
+
},
|
144 |
+
"language_info": {
|
145 |
+
"codemirror_mode": {
|
146 |
+
"name": "ipython",
|
147 |
+
"version": 3
|
148 |
+
},
|
149 |
+
"file_extension": ".py",
|
150 |
+
"mimetype": "text/x-python",
|
151 |
+
"name": "python",
|
152 |
+
"nbconvert_exporter": "python",
|
153 |
+
"pygments_lexer": "ipython3",
|
154 |
+
"version": "3.10.15"
|
155 |
+
}
|
156 |
+
},
|
157 |
+
"nbformat": 4,
|
158 |
+
"nbformat_minor": 5
|
159 |
+
}
|
examples/multitask_cell_classification.ipynb
ADDED
@@ -0,0 +1,420 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
{
|
2 |
+
"cells": [
|
3 |
+
{
|
4 |
+
"cell_type": "markdown",
|
5 |
+
"id": "866f100c-e11a-4e7b-a37c-831775d845a7",
|
6 |
+
"metadata": {},
|
7 |
+
"source": [
|
8 |
+
"# Geneformer Multi-Task Cell Classifier Tutorial\n",
|
9 |
+
"\n",
|
10 |
+
"This tutorial demonstrates how to use the Geneformer Multi-Task Cell Classifier and optimizatize hyperparameter for fine-tuning"
|
11 |
+
]
|
12 |
+
},
|
13 |
+
{
|
14 |
+
"cell_type": "markdown",
|
15 |
+
"id": "311ba456-b44d-40c7-941d-3fc03bcda85a",
|
16 |
+
"metadata": {},
|
17 |
+
"source": [
|
18 |
+
"## 1. Installation and Imports\n",
|
19 |
+
"\n",
|
20 |
+
"First import the necessary modules."
|
21 |
+
]
|
22 |
+
},
|
23 |
+
{
|
24 |
+
"cell_type": "code",
|
25 |
+
"execution_count": 3,
|
26 |
+
"id": "cd9defdc-0524-4c3b-a741-27117ed3a5be",
|
27 |
+
"metadata": {},
|
28 |
+
"outputs": [],
|
29 |
+
"source": [
|
30 |
+
"from geneformer import MTLClassifier"
|
31 |
+
]
|
32 |
+
},
|
33 |
+
{
|
34 |
+
"cell_type": "markdown",
|
35 |
+
"id": "790e9c3c-f6d9-44b3-b9a5-05725760f4fd",
|
36 |
+
"metadata": {},
|
37 |
+
"source": [
|
38 |
+
"## 2. Set up Paths and Parameters\n",
|
39 |
+
"\n",
|
40 |
+
"Now, let's set up the necessary paths and parameters for our classifier. We'll also define our task columns, which are specific columns from our dataset that represent the classification tasks we want to train the model on."
|
41 |
+
]
|
42 |
+
},
|
43 |
+
{
|
44 |
+
"cell_type": "code",
|
45 |
+
"execution_count": null,
|
46 |
+
"id": "04a04197-8e45-47f8-a86f-202209ea10ae",
|
47 |
+
"metadata": {},
|
48 |
+
"outputs": [],
|
49 |
+
"source": [
|
50 |
+
"# Define paths\n",
|
51 |
+
"pretrained_path = \"/path/to/pretrained/Geneformer/model\" \n",
|
52 |
+
"# input data is tokenized rank value encodings generated by Geneformer tokenizer (see tokenizing_scRNAseq_data.ipynb)\n",
|
53 |
+
"train_path = \"/path/to/train/data.dataset\"\n",
|
54 |
+
"val_path = \"/path/to/val/data.dataset\"\n",
|
55 |
+
"test_path = \"/path/to/test/data.dataset\"\n",
|
56 |
+
"results_dir = \"/path/to/results/directory\"\n",
|
57 |
+
"model_save_path = \"/path/to/model/save/path\"\n",
|
58 |
+
"tensorboard_log_dir = \"/path/to/tensorboard/log/dir\"\n",
|
59 |
+
"\n",
|
60 |
+
"# Define tasks and hyperparameters\n",
|
61 |
+
"# task_columns should be a list of column names from your dataset\n",
|
62 |
+
"# Each column represents a specific classification task (e.g. cell type, disease state)\n",
|
63 |
+
"task_columns = [\"cell_type\", \"disease_state\"] # Example task columns\n",
|
64 |
+
"\n",
|
65 |
+
"hyperparameters = {\n",
|
66 |
+
" \"learning_rate\": {\"type\": \"float\", \"low\": 1e-5, \"high\": 1e-3, \"log\": True},\n",
|
67 |
+
" \"warmup_ratio\": {\"type\": \"float\", \"low\": 0.005, \"high\": 0.01},\n",
|
68 |
+
" \"weight_decay\": {\"type\": \"float\", \"low\": 0.01, \"high\": 0.1},\n",
|
69 |
+
" \"dropout_rate\": {\"type\": \"float\", \"low\": 0.0, \"high\": 0.7},\n",
|
70 |
+
" \"lr_scheduler_type\": {\"type\": \"categorical\", \"choices\": [\"cosine\"]},\n",
|
71 |
+
" \"task_weights\": {\"type\": \"float\", \"low\": 0.1, \"high\": 2.0}\n",
|
72 |
+
"}"
|
73 |
+
]
|
74 |
+
},
|
75 |
+
{
|
76 |
+
"cell_type": "markdown",
|
77 |
+
"id": "31857690-a739-435a-aefd-f171fafc1b78",
|
78 |
+
"metadata": {},
|
79 |
+
"source": [
|
80 |
+
"In the code above, we've defined `task_columns` as `[\"cell_type\", \"disease_state\"]`. This means our model will be trained to classify cells based on two tasks:\n",
|
81 |
+
"1. Identifying the cell type\n",
|
82 |
+
"2. Determining the disease state\n",
|
83 |
+
"3. Note: \"unique_cell_id\" is a required column in the dataset for logging and inference purposes\n",
|
84 |
+
"\n",
|
85 |
+
"These column names should correspond to actual columns in your dataset. Each column should contain the labels for that specific classification task.\n",
|
86 |
+
"\n",
|
87 |
+
"For example, your dataset might look something like this:\n",
|
88 |
+
"\n",
|
89 |
+
" | unique_cell_id | input_ids | ... | cell_type | disease_state |\n",
|
90 |
+
" |----------------|-----------|-----|-----------|---------------|\n",
|
91 |
+
" | cell1 | ... | ... | neuron | healthy |\n",
|
92 |
+
" | cell2 | ... | ... | astrocyte | diseased |\n",
|
93 |
+
" | ... | ... | ... | ... | ... |\n",
|
94 |
+
"The model will learn to predict classes within 'cell_type' and 'disease_state' "
|
95 |
+
]
|
96 |
+
},
|
97 |
+
{
|
98 |
+
"cell_type": "markdown",
|
99 |
+
"id": "b9e3050a-6162-4c01-b6fd-8784bf4ab1e4",
|
100 |
+
"metadata": {},
|
101 |
+
"source": [
|
102 |
+
"## 3. Initialize the MTLClassifier\n",
|
103 |
+
"\n",
|
104 |
+
"Now, let's create an instance of the MTLClassifier with our defined parameters and task columns."
|
105 |
+
]
|
106 |
+
},
|
107 |
+
{
|
108 |
+
"cell_type": "code",
|
109 |
+
"execution_count": null,
|
110 |
+
"id": "e27caac9-670c-409d-9313-50201c665cb9",
|
111 |
+
"metadata": {},
|
112 |
+
"outputs": [],
|
113 |
+
"source": [
|
114 |
+
"mc = MTLClassifier(\n",
|
115 |
+
" task_columns=task_columns, # Our defined classification tasks\n",
|
116 |
+
" study_name=\"MTLClassifier_example\",\n",
|
117 |
+
" pretrained_path=pretrained_path,\n",
|
118 |
+
" train_path=train_path,\n",
|
119 |
+
" val_path=val_path,\n",
|
120 |
+
" test_path=test_path,\n",
|
121 |
+
" model_save_path=model_save_path,\n",
|
122 |
+
" results_dir=results_dir,\n",
|
123 |
+
" tensorboard_log_dir=tensorboard_log_dir,\n",
|
124 |
+
" hyperparameters=hyperparameters,\n",
|
125 |
+
" n_trials=15, # Number of trials for hyperparameter optimization (at least 50 suggested)\n",
|
126 |
+
" epochs=1, # Number of training epochs (1 suggested to prevent overfitting)\n",
|
127 |
+
" batch_size=8, # Adjust based on available GPU memory\n",
|
128 |
+
" seed=42\n",
|
129 |
+
")"
|
130 |
+
]
|
131 |
+
},
|
132 |
+
{
|
133 |
+
"cell_type": "markdown",
|
134 |
+
"id": "0d729444-e3ad-4584-9659-0c464ac97462",
|
135 |
+
"metadata": {},
|
136 |
+
"source": [
|
137 |
+
"## 4. Run Hyperparameter Optimization\n",
|
138 |
+
"\n",
|
139 |
+
"Now, let's run the Optuna study to optimize our hyperparameters for both classification tasks."
|
140 |
+
]
|
141 |
+
},
|
142 |
+
{
|
143 |
+
"cell_type": "code",
|
144 |
+
"execution_count": null,
|
145 |
+
"id": "9298aa3e-6a52-4aa8-b9ff-b63d97beac93",
|
146 |
+
"metadata": {},
|
147 |
+
"outputs": [],
|
148 |
+
"source": [
|
149 |
+
"mc.run_optuna_study()"
|
150 |
+
]
|
151 |
+
},
|
152 |
+
{
|
153 |
+
"cell_type": "markdown",
|
154 |
+
"id": "af23075d-d07b-43d3-bc5d-4df4d5d7199b",
|
155 |
+
"metadata": {},
|
156 |
+
"source": [
|
157 |
+
"## 5. Evaluate the Model on Test Data\n",
|
158 |
+
"\n",
|
159 |
+
"After optimization, we can evaluate our model on the test dataset. This will provide performance metrics for both classification tasks. CSV containing following keys will be generated in specified results directiory \"Cell ID, task(1...n) True,task(1.,.n) Pred,task(1...n) Probabilities\""
|
160 |
+
]
|
161 |
+
},
|
162 |
+
{
|
163 |
+
"cell_type": "code",
|
164 |
+
"execution_count": null,
|
165 |
+
"id": "461bf8d3-b964-4ff4-994f-9f3d313d4614",
|
166 |
+
"metadata": {},
|
167 |
+
"outputs": [],
|
168 |
+
"source": [
|
169 |
+
"mc.load_and_evaluate_test_model()"
|
170 |
+
]
|
171 |
+
},
|
172 |
+
{
|
173 |
+
"cell_type": "markdown",
|
174 |
+
"id": "31cfeb2d-6673-4b02-a79c-2533cc5e4d28",
|
175 |
+
"metadata": {},
|
176 |
+
"source": [
|
177 |
+
"## 6. (Optional) Manual Hyperparameter Tuning\n",
|
178 |
+
"\n",
|
179 |
+
"If you prefer to set hyperparameters manually, you can use the following approach:"
|
180 |
+
]
|
181 |
+
},
|
182 |
+
{
|
183 |
+
"cell_type": "code",
|
184 |
+
"execution_count": null,
|
185 |
+
"id": "8ee6b99f-42e9-4abf-a292-aa9047735e0e",
|
186 |
+
"metadata": {},
|
187 |
+
"outputs": [],
|
188 |
+
"source": [
|
189 |
+
"manual_hyperparameters = {\n",
|
190 |
+
" \"learning_rate\": 0.001,\n",
|
191 |
+
" \"warmup_ratio\": 0.01,\n",
|
192 |
+
" \"weight_decay\": 0.1,\n",
|
193 |
+
" \"dropout_rate\": 0.1,\n",
|
194 |
+
" \"lr_scheduler_type\": \"cosine\",\n",
|
195 |
+
" \"task_weights\": [1, 1], # Weights for each task (cell_type, disease_state)\n",
|
196 |
+
" \"max_layers_to_freeze\": 2\n",
|
197 |
+
"}\n",
|
198 |
+
"\n",
|
199 |
+
"mc_manual = MTLClassifier(\n",
|
200 |
+
" task_columns=task_columns,\n",
|
201 |
+
" study_name=\"mtl_manual\",\n",
|
202 |
+
" pretrained_path=pretrained_path,\n",
|
203 |
+
" train_path=train_path,\n",
|
204 |
+
" val_path=val_path,\n",
|
205 |
+
" test_path=test_path,\n",
|
206 |
+
" model_save_path=model_save_path,\n",
|
207 |
+
" results_dir=results_dir,\n",
|
208 |
+
" tensorboard_log_dir=tensorboard_log_dir,\n",
|
209 |
+
" manual_hyperparameters=manual_hyperparameters,\n",
|
210 |
+
" use_manual_hyperparameters=True,\n",
|
211 |
+
" epochs=10,\n",
|
212 |
+
" batch_size=32,\n",
|
213 |
+
" seed=42\n",
|
214 |
+
")\n",
|
215 |
+
"\n",
|
216 |
+
"mc_manual.run_manual_tuning()"
|
217 |
+
]
|
218 |
+
},
|
219 |
+
{
|
220 |
+
"cell_type": "markdown",
|
221 |
+
"id": "dbaac008-fc00-4b71-8e78-89b2d922d9d8",
|
222 |
+
"metadata": {},
|
223 |
+
"source": [
|
224 |
+
"# Geneformer In Silico Perturber Tutorial (MTL Quantized)\n",
|
225 |
+
"This demonstrates how to use the Geneformer In Silico Perturber with a Multi-Task Learning (MTL) model in a quantized configuration to optimize runtime and memory."
|
226 |
+
]
|
227 |
+
},
|
228 |
+
{
|
229 |
+
"cell_type": "code",
|
230 |
+
"execution_count": null,
|
231 |
+
"id": "2e15ad57-736c-48f0-be87-39cf5015bc5c",
|
232 |
+
"metadata": {},
|
233 |
+
"outputs": [],
|
234 |
+
"source": [
|
235 |
+
"from geneformer import InSilicoPerturber, EmbExtractor, InSilicoPerturberStats"
|
236 |
+
]
|
237 |
+
},
|
238 |
+
{
|
239 |
+
"cell_type": "code",
|
240 |
+
"execution_count": null,
|
241 |
+
"id": "43c18140-151e-4d44-95b4-a9b3a47172cf",
|
242 |
+
"metadata": {},
|
243 |
+
"outputs": [],
|
244 |
+
"source": [
|
245 |
+
"# Define paths\n",
|
246 |
+
"model_directory = \"/path/to/model/save/path\"\n",
|
247 |
+
"input_data_file = \"/path/to/input/data.dataset\"\n",
|
248 |
+
"output_directory = \"/path/to/output/directory\"\n",
|
249 |
+
"output_prefix = \"mtl_quantized_perturbation\"\n",
|
250 |
+
"\n",
|
251 |
+
"# Define parameters\n",
|
252 |
+
"perturb_type = \"delete\" # or \"overexpress\"\n",
|
253 |
+
"\n",
|
254 |
+
"# Define cell states to model\n",
|
255 |
+
"cell_states_to_model = {\n",
|
256 |
+
" \"state_key\": \"disease_state\", \n",
|
257 |
+
" \"start_state\": \"disease\", \n",
|
258 |
+
" \"goal_state\": \"control\"\n",
|
259 |
+
"}\n",
|
260 |
+
"\n",
|
261 |
+
"# Define filter data\n",
|
262 |
+
"filter_data_dict = {\n",
|
263 |
+
" \"cell_type\": [\"Fibroblast\"]\n",
|
264 |
+
"}"
|
265 |
+
]
|
266 |
+
},
|
267 |
+
{
|
268 |
+
"cell_type": "markdown",
|
269 |
+
"id": "3010d0bf-b23c-45c1-ac12-8c472dc8b7a1",
|
270 |
+
"metadata": {},
|
271 |
+
"source": [
|
272 |
+
"## 3. Extract State Embeddings\n",
|
273 |
+
"\n",
|
274 |
+
"Before we initialize the InSilicoPerturber, we need to extract the state embeddings using the EmbExtractor."
|
275 |
+
]
|
276 |
+
},
|
277 |
+
{
|
278 |
+
"cell_type": "code",
|
279 |
+
"execution_count": null,
|
280 |
+
"id": "215f0a90-8041-417d-a5d3-b2483626c3b2",
|
281 |
+
"metadata": {},
|
282 |
+
"outputs": [],
|
283 |
+
"source": [
|
284 |
+
"# Initialize EmbExtractor\n",
|
285 |
+
"embex = EmbExtractor(\n",
|
286 |
+
" filter_data_dict=filter_data_dict,\n",
|
287 |
+
" max_ncells=1000, # Number of cells to extract embeddings for\n",
|
288 |
+
" emb_layer=0, # Use the second to last layer\n",
|
289 |
+
" emb_mode = \"cls\",\n",
|
290 |
+
" summary_stat=\"exact_mean\",\n",
|
291 |
+
" forward_batch_size=8, # Adjust based on available GPU memory\n",
|
292 |
+
" nproc=4\n",
|
293 |
+
")\n",
|
294 |
+
"\n",
|
295 |
+
"# Extract state embeddings\n",
|
296 |
+
"state_embs_dict = embex.get_state_embs(\n",
|
297 |
+
" cell_states_to_model,\n",
|
298 |
+
" model_directory=model_directory,\n",
|
299 |
+
" input_data_file=input_data_file,\n",
|
300 |
+
" output_directory=output_directory,\n",
|
301 |
+
" output_prefix=output_prefix\n",
|
302 |
+
")"
|
303 |
+
]
|
304 |
+
},
|
305 |
+
{
|
306 |
+
"cell_type": "markdown",
|
307 |
+
"id": "23f14e36-4529-4fb2-8af9-7f4875cf81e3",
|
308 |
+
"metadata": {},
|
309 |
+
"source": [
|
310 |
+
"## 4. Initialize the InSilicoPerturber\n",
|
311 |
+
"\n",
|
312 |
+
"Now that we have our state embeddings, let's create an instance of the InSilicoPerturber with MTL and quantized configurations."
|
313 |
+
]
|
314 |
+
},
|
315 |
+
{
|
316 |
+
"cell_type": "code",
|
317 |
+
"execution_count": null,
|
318 |
+
"id": "09f985a1-91bc-4e8d-8001-a3663531b570",
|
319 |
+
"metadata": {},
|
320 |
+
"outputs": [],
|
321 |
+
"source": [
|
322 |
+
"# Initialize InSilicoPerturber\n",
|
323 |
+
"isp = InSilicoPerturber(\n",
|
324 |
+
" perturb_type=perturb_type,\n",
|
325 |
+
" genes_to_perturb=\"all\", # Perturb all genes\n",
|
326 |
+
" model_type=\"MTLCellClassifier-Quantized\", # Use quantized MTL model\n",
|
327 |
+
" emb_mode=\"cls\", # Use CLS token embedding\n",
|
328 |
+
" cell_states_to_model=cell_states_to_model,\n",
|
329 |
+
" state_embs_dict=state_embs_dict,\n",
|
330 |
+
" max_ncells=1000, # Number of cells to perturb (larger number increases power)\n",
|
331 |
+
" emb_layer=0, \n",
|
332 |
+
" forward_batch_size=8, # Adjust based on available GPU memory\n",
|
333 |
+
" nproc=1\n",
|
334 |
+
")"
|
335 |
+
]
|
336 |
+
},
|
337 |
+
{
|
338 |
+
"cell_type": "markdown",
|
339 |
+
"id": "cfcc2c1e-fd7f-4a36-99fc-ac7f43e5be6b",
|
340 |
+
"metadata": {},
|
341 |
+
"source": [
|
342 |
+
"## 5. Run In Silico Perturbation\n",
|
343 |
+
"\n",
|
344 |
+
"Run the in silico perturbation on the dataset."
|
345 |
+
]
|
346 |
+
},
|
347 |
+
{
|
348 |
+
"cell_type": "code",
|
349 |
+
"execution_count": null,
|
350 |
+
"id": "cf030c09-8ae4-45a7-aaf7-3fc2af4fe296",
|
351 |
+
"metadata": {},
|
352 |
+
"outputs": [],
|
353 |
+
"source": [
|
354 |
+
"# Run perturbation and output intermediate files\n",
|
355 |
+
"isp.perturb_data(\n",
|
356 |
+
" model_directory=model_directory,\n",
|
357 |
+
" input_data_file=input_data_file,\n",
|
358 |
+
" output_directory=output_directory,\n",
|
359 |
+
" output_prefix=output_prefix\n",
|
360 |
+
")"
|
361 |
+
]
|
362 |
+
},
|
363 |
+
{
|
364 |
+
"cell_type": "markdown",
|
365 |
+
"id": "bb8ec074-6f2f-422b-a973-37ed32a15c38",
|
366 |
+
"metadata": {},
|
367 |
+
"source": [
|
368 |
+
"## 6. Process Results with InSilicoPerturberStats\n",
|
369 |
+
"\n",
|
370 |
+
"After running the perturbation, we'll use InSilicoPerturberStats to process the intermediate files and generate the final statistics."
|
371 |
+
]
|
372 |
+
},
|
373 |
+
{
|
374 |
+
"cell_type": "code",
|
375 |
+
"execution_count": null,
|
376 |
+
"id": "0a748043-43fc-47ad-ace5-f0ae3dd34674",
|
377 |
+
"metadata": {},
|
378 |
+
"outputs": [],
|
379 |
+
"source": [
|
380 |
+
"# Initialize InSilicoPerturberStats\n",
|
381 |
+
"ispstats = InSilicoPerturberStats(\n",
|
382 |
+
" mode=\"goal_state_shift\",\n",
|
383 |
+
" genes_perturbed=\"all\",\n",
|
384 |
+
" combos=0,\n",
|
385 |
+
" anchor_gene=None,\n",
|
386 |
+
" cell_states_to_model=cell_states_to_model\n",
|
387 |
+
")\n",
|
388 |
+
"\n",
|
389 |
+
"# Process stats and output final .csv\n",
|
390 |
+
"ispstats.get_stats(\n",
|
391 |
+
" input_data_file,\n",
|
392 |
+
" None,\n",
|
393 |
+
" output_directory,\n",
|
394 |
+
" output_prefix\n",
|
395 |
+
")"
|
396 |
+
]
|
397 |
+
}
|
398 |
+
],
|
399 |
+
"metadata": {
|
400 |
+
"kernelspec": {
|
401 |
+
"display_name": "Python 3 (ipykernel)",
|
402 |
+
"language": "python",
|
403 |
+
"name": "python3"
|
404 |
+
},
|
405 |
+
"language_info": {
|
406 |
+
"codemirror_mode": {
|
407 |
+
"name": "ipython",
|
408 |
+
"version": 3
|
409 |
+
},
|
410 |
+
"file_extension": ".py",
|
411 |
+
"mimetype": "text/x-python",
|
412 |
+
"name": "python",
|
413 |
+
"nbconvert_exporter": "python",
|
414 |
+
"pygments_lexer": "ipython3",
|
415 |
+
"version": "3.11.5"
|
416 |
+
}
|
417 |
+
},
|
418 |
+
"nbformat": 4,
|
419 |
+
"nbformat_minor": 5
|
420 |
+
}
|
examples/pretraining_new_model/obtain_nonzero_median_digests.ipynb
ADDED
@@ -0,0 +1,365 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
{
|
2 |
+
"cells": [
|
3 |
+
{
|
4 |
+
"cell_type": "markdown",
|
5 |
+
"id": "charged-worcester",
|
6 |
+
"metadata": {},
|
7 |
+
"source": [
|
8 |
+
"# Obtain non-zero median expression value of each gene across Genecorpus-30M"
|
9 |
+
]
|
10 |
+
},
|
11 |
+
{
|
12 |
+
"cell_type": "markdown",
|
13 |
+
"id": "28e87f2a-a33e-4fe3-81af-ad4cd62fcc1b",
|
14 |
+
"metadata": {},
|
15 |
+
"source": [
|
16 |
+
"#### Upon request, we are providing the code that we used for obtaining the non-zero median expression value of each gene across the broad range of cell types represented in Genecorpus-30M that we use as a normalization factor to prioritize genes that uniquely distinguish cell state.\n",
|
17 |
+
"\n",
|
18 |
+
"#### Please read the important information below before using this code.\n",
|
19 |
+
"\n",
|
20 |
+
"#### If using Geneformer, to ensure consistency of the normalization factor used for each gene for all future datasets, <ins>**users should use the Geneformer transcriptome tokenizer to tokenize their datasets and should not re-calculate this normalization factor for their individual dataset** </ins>. This code for re-calculating the normalization factor should only be used by users who are pretraining a new model from scratch with a new pretraining corpus other than Genecorpus-30M.\n",
|
21 |
+
"\n",
|
22 |
+
"#### It is critical that this calculation is performed on a large-scale pretraining corpus that has tens of millions of cells from a broad range of human tissues. <ins>**The richness of variable cell states in the pretraining corpus is what allows this normalization factor to accomplish the goal of prioritizing genes that uniquely distinguish cell states.** </ins> This normalization factor for each gene is calculated once from the large-scale pretraining corpus and is used for all future datasets presented to the model. \n",
|
23 |
+
"\n",
|
24 |
+
"#### Of note, as discussed in the Methods, we only included droplet-based sequencing platforms in the pretraining corpus to assure expression value unit comparability for the calculation of this normalization factor. Users wishing to pretrain a new model from scratch with a new pretraining corpus should choose either droplet-based or plate-based platforms for calculating this normalization factor, or they should exercise caution that including both platforms may cause unintended effects on the results. Once the normalization factor is calculated however, data from any platform can be used with the model because the expression value units will be consistent within each individual cell.\n",
|
25 |
+
"\n",
|
26 |
+
"#### Please see the Methods in the manuscript for a description of the procedure enacted by this code, an excerpt of which is below for convenience:\n",
|
27 |
+
"\n",
|
28 |
+
"#### \"To accomplish this, we first calculated the non-zero median value of expression of each detected gene across all cells passing quality filtering from the entire Genecorpus-30M. We aggregated the transcript count distribution for each gene in a memory-efficient manner by scanning through chunks of .loom data using loompy, normalizing the gene transcript counts in each cell by the total transcript count of that cell to account for varying sequencing depth and updating the normalized count distribution of the gene within the t-digest data structure developed for accurate online accumulation of rank-based statistics. We then normalized the genes in each single-cell transcriptome by the non-zero median value of expression of that gene across Genecorpus-30M and ordered the genes by the rank of their normalized expression in that specific cell. Of note, we opted to use the non-zero median value of expression rather than include zeros in the distribution so as not to weight the value by tissue representation within Genecorpus-30M, assuming that a representative range of transcript values would be observed within the cells in which each gene was detected. This normalization factor for each gene is calculated once from the pretraining corpus and is used for all future datasets presented to the model. The provided tokenizer code includes this normalization procedure and should be used for tokenizing new datasets presented to Geneformer to ensure consistency of the normalization factor used for each gene.\""
|
29 |
+
]
|
30 |
+
},
|
31 |
+
{
|
32 |
+
"cell_type": "code",
|
33 |
+
"execution_count": 1,
|
34 |
+
"id": "textile-destruction",
|
35 |
+
"metadata": {},
|
36 |
+
"outputs": [],
|
37 |
+
"source": [
|
38 |
+
"import os\n",
|
39 |
+
"import numpy as np\n",
|
40 |
+
"import loompy as lp\n",
|
41 |
+
"import pandas as pd\n",
|
42 |
+
"import crick\n",
|
43 |
+
"import pickle\n",
|
44 |
+
"import math\n",
|
45 |
+
"from tqdm.notebook import tqdm"
|
46 |
+
]
|
47 |
+
},
|
48 |
+
{
|
49 |
+
"cell_type": "markdown",
|
50 |
+
"id": "4af8cfef-05f2-47e0-b8d2-71ca025059c7",
|
51 |
+
"metadata": {
|
52 |
+
"tags": []
|
53 |
+
},
|
54 |
+
"source": [
|
55 |
+
"### The following code is an example of how the nonzero median expression values are obtained for a single input file. This calculation should be run as a script to be parallelized for all dataset files."
|
56 |
+
]
|
57 |
+
},
|
58 |
+
{
|
59 |
+
"cell_type": "code",
|
60 |
+
"execution_count": 30,
|
61 |
+
"id": "physical-intro",
|
62 |
+
"metadata": {},
|
63 |
+
"outputs": [],
|
64 |
+
"source": [
|
65 |
+
"input_file = \"study1.loom\"\n",
|
66 |
+
"current_database = \"database1\"\n",
|
67 |
+
"\n",
|
68 |
+
"rootdir = f\"/path/to/{current_database}/data/\"\n",
|
69 |
+
"output_file = input_file.replace(\".loom\", \".gene_median_digest_dict.pickle\")\n",
|
70 |
+
"outdir = rootdir.replace(\"/data/\", \"/tdigest/\")\n",
|
71 |
+
"\n",
|
72 |
+
"with lp.connect(f\"{rootdir}{input_file}\") as data:\n",
|
73 |
+
" # define coordinates of protein-coding or miRNA genes\n",
|
74 |
+
" coding_miRNA_loc = np.where((data.ra.gene_type == \"protein_coding\") | (data.ra.gene_type == \"miRNA\"))[0]\n",
|
75 |
+
" coding_miRNA_genes = data.ra[\"ensembl_id\"][coding_miRNA_loc]\n",
|
76 |
+
" \n",
|
77 |
+
" # initiate tdigests\n",
|
78 |
+
" median_digests = [crick.tdigest.TDigest() for _ in range(len(coding_miRNA_loc))]\n",
|
79 |
+
" \n",
|
80 |
+
" # initiate progress meters\n",
|
81 |
+
" progress = tqdm(total=len(coding_miRNA_loc))\n",
|
82 |
+
" last_view_row = 0\n",
|
83 |
+
" progress.update(0)\n",
|
84 |
+
" \n",
|
85 |
+
" for (ix, selection, view) in data.scan(items=coding_miRNA_loc, axis=0):\n",
|
86 |
+
" # define coordinates of cells passing filter\n",
|
87 |
+
" filter_passed_loc = np.where(view.ca.filter_pass == 1)[0]\n",
|
88 |
+
" subview = view.view[:, filter_passed_loc]\n",
|
89 |
+
" # normalize by total counts per cell and multiply by 10,000 to allocate bits to precision\n",
|
90 |
+
" subview_norm_array = subview[:,:]/subview.ca.n_counts*10_000\n",
|
91 |
+
" # if integer, convert to float to prevent error with filling with nan\n",
|
92 |
+
" if np.issubdtype(subview_norm_array.dtype, np.integer):\n",
|
93 |
+
" subview_norm_array = subview_norm_array.astype(np.float32)\n",
|
94 |
+
" # mask zeroes from distribution tdigest by filling with nan\n",
|
95 |
+
" nonzero_data = np.ma.masked_equal(subview_norm_array, 0.0).filled(np.nan)\n",
|
96 |
+
" # update tdigests\n",
|
97 |
+
" [median_digests[i+last_view_row].update(nonzero_data[i,:]) for i in range(nonzero_data.shape[0])]\n",
|
98 |
+
" # update progress meters\n",
|
99 |
+
" progress.update(view.shape[0])\n",
|
100 |
+
" last_view_row = last_view_row + view.shape[0]\n",
|
101 |
+
" \n",
|
102 |
+
"median_digest_dict = dict(zip(coding_miRNA_genes, median_digests))\n",
|
103 |
+
"with open(f\"{outdir}{output_file}\", \"wb\") as fp:\n",
|
104 |
+
" pickle.dump(median_digest_dict, fp)"
|
105 |
+
]
|
106 |
+
},
|
107 |
+
{
|
108 |
+
"cell_type": "markdown",
|
109 |
+
"id": "190a3754-aafa-4ccf-ba97-951c94ea3030",
|
110 |
+
"metadata": {
|
111 |
+
"tags": []
|
112 |
+
},
|
113 |
+
"source": [
|
114 |
+
"### After the above code is run as a script in parallel for all datasets to obtain the nonzero median tdigests for their contained genes, the following code can be run to merge the tdigests across all datasets."
|
115 |
+
]
|
116 |
+
},
|
117 |
+
{
|
118 |
+
"cell_type": "code",
|
119 |
+
"execution_count": 2,
|
120 |
+
"id": "distributed-riding",
|
121 |
+
"metadata": {},
|
122 |
+
"outputs": [],
|
123 |
+
"source": [
|
124 |
+
"# merge new tdigests into total tdigest dict\n",
|
125 |
+
"def merge_digest(dict_key_ensembl_id, dict_value_tdigest, new_tdigest_dict):\n",
|
126 |
+
" new_gene_tdigest = new_tdigest_dict.get(dict_key_ensembl_id)\n",
|
127 |
+
" if new_gene_tdigest is not None:\n",
|
128 |
+
" dict_value_tdigest.merge(new_gene_tdigest)\n",
|
129 |
+
" return dict_value_tdigest\n",
|
130 |
+
" elif new_gene_tdigest is None:\n",
|
131 |
+
" return dict_value_tdigest"
|
132 |
+
]
|
133 |
+
},
|
134 |
+
{
|
135 |
+
"cell_type": "code",
|
136 |
+
"execution_count": null,
|
137 |
+
"id": "distinct-library",
|
138 |
+
"metadata": {},
|
139 |
+
"outputs": [],
|
140 |
+
"source": [
|
141 |
+
"# use tdigest1.merge(tdigest2) to merge tdigest1, tdigest2, ...tdigestn\n",
|
142 |
+
"# then, extract median by tdigest1.quantile(0.5)\n",
|
143 |
+
"\n",
|
144 |
+
"databases = [\"database1\", \"database2\", \"...databaseN\"]\n",
|
145 |
+
"\n",
|
146 |
+
"# obtain gene list\n",
|
147 |
+
"gene_info = pd.read_csv(\"/path/to/gene_info_table.csv\", index_col=0)\n",
|
148 |
+
"func_gene_list = [i for i in gene_info[(gene_info[\"gene_type\"] == \"protein_coding\") | (gene_info[\"gene_type\"] == \"miRNA\")][\"ensembl_id\"]]\n",
|
149 |
+
"\n",
|
150 |
+
"# initiate tdigests\n",
|
151 |
+
"median_digests = [crick.tdigest.TDigest() for _ in range(len(func_gene_list))]\n",
|
152 |
+
"total_tdigest_dict = dict(zip(func_gene_list, median_digests))\n",
|
153 |
+
"\n",
|
154 |
+
"# merge tdigests\n",
|
155 |
+
"for current_database in databases:\n",
|
156 |
+
" rootdir = f\"/path/to/{current_database}/tdigest/\"\n",
|
157 |
+
" \n",
|
158 |
+
" for subdir, dirs, files in os.walk(rootdir):\t\n",
|
159 |
+
" for file in files:\n",
|
160 |
+
" if file.endswith(\".gene_median_digest_dict.pickle\"):\n",
|
161 |
+
" with open(f\"{rootdir}{file}\", \"rb\") as fp:\n",
|
162 |
+
" tdigest_dict = pickle.load(fp)\n",
|
163 |
+
" total_tdigest_dict = {k: merge_digest(k,v,tdigest_dict) for k, v in total_tdigest_dict.items()}\n",
|
164 |
+
"\n",
|
165 |
+
"# save dict of merged tdigests\n",
|
166 |
+
"with open(f\"/path/to/total_gene_tdigest_dict.pickle\", \"wb\") as fp:\n",
|
167 |
+
" pickle.dump(total_tdigest_dict, fp)\n",
|
168 |
+
"\n",
|
169 |
+
"# extract medians and save dict\n",
|
170 |
+
"total_median_dict = {k: v.quantile(0.5) for k, v in total_tdigest_dict.items()}\n",
|
171 |
+
"with open(f\"/path/to/total_gene_median_dict.pickle\", \"wb\") as fp:\n",
|
172 |
+
" pickle.dump(total_median_dict, fp)\n",
|
173 |
+
"\n",
|
174 |
+
"# save dict of only detected genes' medians \n",
|
175 |
+
"detected_median_dict = {k: v for k, v in total_median_dict.items() if not math.isnan(v)}\n",
|
176 |
+
"with open(f\"/path/to/detected_gene_median_dict.pickle\", \"wb\") as fp:\n",
|
177 |
+
" pickle.dump(detected_median_dict, fp)"
|
178 |
+
]
|
179 |
+
},
|
180 |
+
{
|
181 |
+
"cell_type": "markdown",
|
182 |
+
"id": "e8e17ad6-79ac-4f34-aa0c-1eaa1bace2e5",
|
183 |
+
"metadata": {
|
184 |
+
"tags": []
|
185 |
+
},
|
186 |
+
"source": [
|
187 |
+
"### The below code displays some characteristics of the genes detected in the pretraining corpus."
|
188 |
+
]
|
189 |
+
},
|
190 |
+
{
|
191 |
+
"cell_type": "code",
|
192 |
+
"execution_count": 38,
|
193 |
+
"id": "decent-switzerland",
|
194 |
+
"metadata": {},
|
195 |
+
"outputs": [],
|
196 |
+
"source": [
|
197 |
+
"gene_detection_counts_dict = {k: v.size() for k, v in total_tdigest_dict.items()}"
|
198 |
+
]
|
199 |
+
},
|
200 |
+
{
|
201 |
+
"cell_type": "code",
|
202 |
+
"execution_count": 44,
|
203 |
+
"id": "polished-innocent",
|
204 |
+
"metadata": {},
|
205 |
+
"outputs": [
|
206 |
+
{
|
207 |
+
"name": "stderr",
|
208 |
+
"output_type": "stream",
|
209 |
+
"text": [
|
210 |
+
"/home1/ct68/miniconda3/lib/python3.8/site-packages/seaborn/distributions.py:2557: FutureWarning: `distplot` is a deprecated function and will be removed in a future version. Please adapt your code to use either `displot` (a figure-level function with similar flexibility) or `histplot` (an axes-level function for histograms).\n",
|
211 |
+
" warnings.warn(msg, FutureWarning)\n"
|
212 |
+
]
|
213 |
+
},
|
214 |
+
{
|
215 |
+
"data": {
|
216 |
+
"image/png": "iVBORw0KGgoAAAANSUhEUgAABR8AAAMRCAYAAABlG8GWAAAAOXRFWHRTb2Z0d2FyZQBNYXRwbG90bGliIHZlcnNpb24zLjMuNCwgaHR0cHM6Ly9tYXRwbG90bGliLm9yZy8QVMy6AAAACXBIWXMAABcSAAAXEgFnn9JSAAC/KUlEQVR4nOzdd5hjZ3X48e/Z7l2vK240G0wzBgOmmmp6NT/TQmjBlCS0ACGE3gklJARCLyGYGgi9hRqwgYBpxnRMMTZgsI1x2+Lt5/fHe8d7dUfSSBpdaWb2+3kePaN7dcs7M1ea0dF5z4nMRJIkSZIkSZLGbdm0ByBJkiRJkiRpaTL4KEmSJEmSJKkVBh8lSZIkSZIktcLgoyRJkiRJkqRWGHyUJEmSJEmS1AqDj5IkSZIkSZJaYfBRkiRJkiRJUisMPkqSJEmSJElqhcFHSZIkSZIkSa0w+ChJkiRJkiSpFQYfJUmSJEmSJLXC4KMkSZIkSZKkVhh8lCRJkiRJktQKg4+SJEmSJEmSWrFi2gOQJEmTERFXAY4GDgf2A1YDG4BLgPOA0zPzT9Ma3yAi4mTgkbVVj8rMk/tsfwTwm9qqczLziDbGJqm3iDib8toz4xqZefZ0RqM9XUQcD3yltsq/DZLUIoOPkiQtYRFxLHAScF863/j32v4c4AvAe4GvZWa2OkANLCJOAe4wzmNmZozzeJImw9cDSdJi4rRrSZL6iIjlEXFpRGR1e/uA271m0mNtjOeGEfFF4HvA3zFA4LFyOPDXwKnAryLiERHh/wuSJEmSRuKbCUmS+rsJsE9t+Ss9trtZY7tT2hpQP1E8AzgduEufTRO4mDLtuld24zWBdwPfGOsgJUmSJO0xnHYtSVJ/zWltp/TY7k61+7uAr7Yymj6qDMX/AB7V5eE/AB8DPgt8F7gwM3dW+60GrgPcFrgf5XtZXtv3ei0OW6M7DXjXtAchaUHw9UCStGAZfJQkqb/ja/fPzMw/9NiuHnz8QWZe3N6QenoDswOPG4FXAK/JzMu77ZSZW4EfVbc3R8SRwHMptSKtAbZwnZmZb5n2IKRB2Myjdb4eSJIWLKddS5LUQ5VJeNvaqq5TriNiFXDrubZrU0Q8Gnh8Y/X5wO0y8+W9Ao/dZOavM/PRlO/prDEOU5IkSdIexuCjJEm93RjYr7bcK6h4K2BtbfmUdobTXURcGWg2uLkEuG1mnjHqcTPzNErNyy+MPDhJkiRJezSDj5Ik9bZY6j0+n85mNwBPycxfzffAmXkZ8JfzPY4kSZKkPZM1HyVJ6u342v2fZOYFPbarBx+/n5mXtjekThFxILPrPH41M989rnNk5q5R9ouI/SlZoYcAB1G6av8J+A1wWmZuG9cY2xQRhwI3Ao4A9gVWAZcDlwG/BX49jkCveouIFcAtgBsABwJbKE2UvtfWzz4i1lGu3+sA+1M+WPhjZg7U1CMiDgeOpVz7BwKbgQuAnwI/zMxeXeYHHV8ARwLHAIdRPoCI6jwXAedQ6gCeN+Lx96mOfR1KBvhaYCuwCTgXOBv4aWZun8/3MR8RsS9wG+DawN7ApZTr4quZeeGYzrEGuD1wdeBgys/gt8C3MvO34zjHYlS7/q5L+dnsQ0lsuRi4kPLc/E1L574ecDRwJeAAYAfld/8r4EeZ+acxnWcl5TXgBpTXgMspz+Fv+ZovScMx+ChJUhdVvcfb1Vad0mO7vYBbzrVdix4GrG6se9OEx3CFiFgOPBL4a+DmdHbNrtsYEZ8BXpyZP5vU+AZVBbseW91uOsD2FwFfAz4MvH/UgO1CEBF3BT5H5wyZl2TmC4c4xq2BU+n8X/M1mfm0Hts3A3HXyMyzI2It8EzgiZQAXrd9Twf+KTM/NsT4TgLeWVt1amYeXz12beBFwAOY/dyCPh2FI2I/4O8p2cLX6TOE8yLivcArMvOiQcddnWN/4OnAwylBn7m2/y3wReA9mXnqANvfHXgycDfmfq+wJSK+A3wMOLlfo62IOBs4vLbqGpl5dp/tXwTUr7l3ZeZJ1WPXBP4JeCCwssvuGRFfAp6VmafP8T30Ov9VgJdRroO9e2zzDcpr2Bd6jPnFmfmiUc6/EFXX3v2Be1MCsl2fk7XtzwXeAbx+vsHgqhHaPwInAFfus2lGxA+BjwLvyMxzRzjXeuA5wOPoLL1S3+ZnwAsy88PDHl+S9kROu5Yk7bEiInvdgJ2UTIcZT+yx3WZKJtyMf+hz3ONb+Dbu21i+iBIImLiIuC0lq+sdlGyRXoFHKG/mHwz8KCJeXmXRLAhVxtrpwJsZIPBYOQD4f8B7mD0FflHJzC8CL2msfl4VlJpTRFwJ+CCdgatvUoKIA4uIawDfA15A/yDHscBHI+JDVZbayKrGTT8CHkr3wGO/ff+W0qDpBfQPPAIcSgkgnhURDxziHHcDfkkJjMwZeKxcHXgM8O9zHHtNRHyQEni+F4MlKayhfEjzb8wuU9GKiHgI8GPgIXQPPELJAL0r8K2I+KsRz/FzygcpXQOPlVsDn4+If68+sFqyIuIY4DzgP4D7MUfgsXIVyvPhVxFxnxHPuyYi3kT5ffwt/QOPUH73NwJeDHx6hPPdiPIa8Cx6BB4rRwEfiog3L/XfvSSNgy+UkiQtUlV23m0aq78xjenMEfEI4H/pHnRJyhTljV0eWw48G/ivhfAGLiIOoGQw3rDHJpspUwo3TWxQ0/FS4PO15WXAeyPiqv12qn6H7wPq210I/MWQ03OvRLmertdYv5Ey9bGbBwIfj4ihgoYzIuKRlMB5c/9LgJ5jj4jlEfF64C10fmAxYydlKurWLo/tC/x3RPzdAOO7DfApugd9EthA+Vl3O88gPgz8RY/HtgJ/pjyPp5bVGxEPp1xfe9VW76L8fLu97q0ATo6IOw5xjkcC76V70HHmXDsb65/M7KZfS81aOj9oq9tOuT66vcZDuc4/GREPHuaEVTO1rwKPp3cw/DLKtd/1EEOe7waUxnKHNx66jN6v+Y+jBFglSX1M/Z98SZI0sqMomUd135n0IKqMlnfR+cb0YuBVwHHAmszcNzPXUzJJHgB8o3GYB1Ma50zby4Cr1ZYTeDdlCur+mbkuMw/KzL0p3+/RlAyskyn1LJeEatr4w4Hf1VZfiRIo65VtBuV3eLf6oYCHZ+bvhxzCG4BrVPd/TalreqXMXJ+ZaykZVU+m1F+ruzvwyiHPBSU7cKZcwS7KlOw7AKszc39KQPIISjZU0yuBJzXW/Yoy/fr6wMrMPCAz11TneQKlHuOMAF4bEXeiv7fR+Ry7DHg5pbzBuszcp7o211ACZzenBG0+Se+AbRlACQrdu7H6a5Tp41fJzDWZeaXM3JcSBDoCuA/wauDMOcY9LjegZN0Fpebnayh1QFdVP9/V1TZvoDNAGsB/VCUh+qqy3t5O53ukXcBbKdncqzPzAMrv4UaU17iZYO+TgXuM/N0tHpcDn6Fc87cB9svMVdX1sR5YR8kIfTWdwciZ38ORg5ykKmnyGcp1XHcxJTP7lpTfx76ZuQ/ld3ITyvPrS8wOEM9lL8qsgZkPED5G+X2uq86xN+V15x8oH0jUPSci5sp2lqQ9Wsyz1rUkSYtWRDyux0PLgNexe9rwl4EPddluNfDa2vIngc/2OeUnM/MPQw6zp4i4H6WuVd39MvPj4zrHAGM4HPg+nRlfXwQekZnnz7Hv8ygZdjN2ATfLzO/32edkylTIGY/KzJP7bH8EpcHNjHMy84ge266iBLP2ra3+y8z8YK/jd9n/AcDHM7NvsGcUEXEKnVNbr6iB15aIuBUl86gecPz3zHxql23vQsmWrAduXpqZc2YFdan5OONTlN/B5h77HVidsz49fhdw28z8Zp/znURnzccZG4D7ZuYpc425Os59gU80Vv8r8Nx+GcgRsTclg69eNuEPwJGZuaXL9rcAvlVbdQlwq8wcKPBXZfTeOTO7vY4REf8D3LO26s3AEwdtihMRtwf+1K9+6xhqPs74DXCvzPx5n30fQfnQoO7/ZeYn++yzjDLN/8a11RuBe2fmV/vsdzQl2HVol4dbq/k46deDiLguJUD9jkGbqlV1Mz9JKY0w4z8z8zED7Hsyna/1AB+nvOZfMsD+R1Cey6/r8fjxlCzHps2UD0x6li+pMiT/j84SGz1r2kqSgMz05s2bN2/evNVuwM0oGVszt4f32O4Oje3uOeFxPqFx/gRuN+Ex/Gfj/F+lZKMMuv8bG/t/YI7tT25sf9Ic2x/R2P7sPtter7Htt6d9LTbGd0qX3/d8bicOeN6ndNn3AY1trkIJ3Na3+RKwbMBzdBvfDylZs3PteyXg/Ma+n5ljn5N6nPM+Q/w+llHqL9b3f9UQ+6+hBLvq+z+ux7aPG/U8A47lvNqxtwH7tnD9nt34Ho6YY/sXdfn9XAZce8DzfbKx77vn2P5eo14PlIDl9i77v2jcP8faOafyejDCOA+mTMmeOc8WShZ5v32OoXyIUB/fhwd9PRlwXMf3+Dn85YD7P62x32/b+l178+bN21K4Oe1akqTZmtMfv9xju+Nr93cCX29lNL2t77JuoIyUcajq/z28tmoH8NjMHKbm3HMpAYUZD6yy2abhgMbyr6YyigUmM/+d8sa/7j+jdIWeqT36AeCg2uN/AB6a8+v6/eTskgXYZXwXUq6juntUDWuG8enMHKZBxQOBa9WWfwU8b9Cdq+/tHxure2Vjt31t1o9/YQ6Y2TYFr8zMXw647dsay83pu03Nn/0nBr0eMvMMdk/bV01mXkCppTpjNWVadj/PprNe4x8pf1varjf6xcz8wIDbvpPyN2/G1SLikBbGJElLgsFHSZJmqwcfz8zeU6WPr93/Xmb2Knrflm6NNSbZCOVBdE7H/Vxm/mKYA2SZPve52qrllO6503BJY/kmC6EJzgLxaKD+u90H+HBVl+2VwG1rj+2gZA816zEO46c54NTnynvpDGIvY3YNw7k0g1VzeVhj+S05ZLOnzPwyJetwxjHVFOmmSxrLg3ZhH1T9+IfM1VhoSnYx3O+oWVf2Or2ez1UA/c6N1cMGE9885PZ7ktMay7fqtWFVU/a+jdWvywGmWo/BwL/DzLwYaJYZaDbIkiRVenUNkyRpj1S98akHUrpmPVYddetvoE5pcVi9dMswXDfB89+hsfy5rlvN7Xt0dtk9jlLba9LOpDQzmKlfeT3grRHxtCkElgdxGqXRz6jOGHTDzNwQEQ+k1B2c6TR8DKUj9XGNzZ+bmV+bx7hgdh3Fuca3JSK+QMlGnHErSvORgQ4BnDro+aogVjNIPur1/31211sMSiONZu3YZvDmMRFxBvDWMWWDnQacUN1fRgks/2X2qck4BT+uslwHkpkXRcSl7K7huoySLd4tq/MYSjfnGVvonfHe63w/j4izgGsOs98YTez1oK7qSH09SjOx9ZRyAs0u081mLFejt1vS+buA8uHCJAz8GlA5C7hhbXm/8Q1FkpYWg4+SJHW6BZ0BvK/02O5WdHaaPqWtAfWxscu6fbusa0sze+V6fZr49HNMY/mwEcczL5m5MyLeSmdH48cCD4qIj1A6r34tMxdKV+szM/MtkzpZZv4oIh5Pqbs5oxl4/BTwL2M43ekj7lMPPt5oiH3PyczL5t7sCtehs8kSwJ0iYpSs3Ss1lmdd/5l5ekScxu7n3HJKZt4zIuK/KYHPb2WPxjwDeCO7g49QAkC/jIjPUgLBp2Tmr0c89ricPcI+G+h8TdyH7sHHZsbajzNzR5ft5vJ9phd8nNjrQUTcjZL5e29glDIZzedOXTOr95zM/P0I5xjWZZl50ZD7ND+U2qfrVpIkg4+SJDXUp1wnvYOPx9fuT6PeI5Q6WE3dpmyOXZX5dVBj9ZPGdPhp1XwEeAlwezprku1LmXb8aICI+AWl0+nXgC9n5jmTHuS0ZOa7qgBbt261ZwOPzMwcw6lG+Zk29xnmOvrzkOfq1tm4a1fdEfQa9yMoU4nrz7sjgGdUtx0R8X3KtflVSsDw4kFOmJmfj4jXAH9fW72CEpA8ASAizqvO/zXg1OzTlb4ll4ywz87G8vIe2zWDYb1Kbcyl22vykhER16JMfb/jPA/VrV7xjObflUnV3r1khH0Gvb4kaY9n8FGStEeJiJtRuln3Us+cuojSAKXbdvev3f8T8LAe2/0hMz857DgH1C0T6RiGnLI6ov1pr3Z0c8rdxGTm5RFxZ+AVlG7iq7psdp3q9iiAiPg28HbgXZm5fVJjnaLnU7pFN99oP2bQYNcAhslCnNHMaOuXXdXULYu4nzYD5F2v/8z8VUTclFKXrls9yxWUpio3B54KbK+mor82M78010kz82kR8XPgZczOxoQScL1/dSMizgbeTanHN2zwdhTjCGr3sl9jedQyC6Nct4tCRNyA0sF+HE1V+v3taD63LhnD+QbR5vUlSXs8g4+SpD3NfYAXDrjtgQxWgP7QPtudCrQVfPwppe5jvfHMXB1dx6VbUG5cukZxJ6XqQvz3EfFq4K+AE4Fj6Z3Vcovq9oyIeEhmfm8iA52CKBH2t9D9Z/F4hqyT18cogYBJXjdTuf4z83fAfSLiWMq1eS/g2j02X0kJUt47Ij5HyUrt2wQoM98WEe8HHkxpKHVbeteRPQJ4AfDUiHhiZk6qLl8bmvVzR/39tnldTE1VC/kDzA48/gD4KPAdSubxecDlwNZ6LdKIOJ7eswjmYlBQkpYAg4+SJC1Smbk9Ir5B5xS420TEqmG77o6gW6bTUZn585bPOzFVnbGXAy+PiPWUenvHAbehBGWaGWrXBr4cEbfNzB9NdLCT84/M7kQ744ER8ZTM/PcxnGeU2qXNemvjysLspnn9n5+Z3aZityIzT6fUuHxqRBxGKRNwa8p1eVNmB4fvAXwpIm6dmX2zPKvH3wG8owo63YRy3d+WUpLg4MYu+wDviYgVmXnyvL6x6WleK/uNeJxR91voHgYcXVveATxqiIDz3kOcq9lUaJgMZknSAtXWdClJkjQZzazKA4D7tX3SKrjZnGJ4rbbPOy2ZuSEzv5iZL8nMu1N+zicAn29sug+Dd1heVKpajy9rrD6rsfwvEXHLMZzu8DHs0+ZU4GbToUOqAPXEZeYfM/MjmfkPmXlLSnDwr4GfNTa9ISV4PMyxt2fmtzPz3zPzQZQs71tS6v41G7K8NiIWa6CoWavxqBGPM+p+C939G8uvHDLTtVnHsZ/mc2vJ/l2RpD2JwUdJ0h4lM1+UmdHtBvxPbdMz+mz31dp2p/Xarrod3/K39D6gmeX4hJbPOaPZcOL4CZ136jJza2Z+OjPvQWd3bIDbR8TVpzGutkTEwZRpl/VZM6dSaozWO1OvBP47Iubb+OjYMezzg3mOoZ+fAVsa6+7Q4vkGlpkXZeZ/ULp9f6rx8CPmeeysgpF/S8m4rgcg96WzY/Zi8t3G8lUj4irDHCAiVgM3HtuIFpYbN5bfPeT+txhi2+bv4vCIuOqQ55MkLTAGHyVJAiJiOXC72qpTemy3hpL5M2PUOlZjkZl/At7VWH37iPircZ2j6mzdzRcbyw+IiD2xpMurmJ05daNpDKQN1e//fcCVa6vPBx6SmZsoTZouqT12deDd0aMD04D+35BjXAPcrbH6tHmcv6+qLmizw/2D2zrfKKrmR89orL7GuDI0M/PrwEcaqxfldV/Vwmx2VX7YkIe5H73rYy52zan2A3ejr/623muIc30H2NRY9/Ah9pckLUAGHyVJKm4O1N+U9woqHkdng5epBh8rL2Z2d9Z/j4h5T1eLiH2A/+rx8EeAXbXlI4DHzPeci01mJrPfjC+lIMQLgbvUlncBD83MPwJk5m+oOn/X3Bt45jzOef2IGCaT8OF01nzcBXxmHucfxH83lh8SEddv+ZzD+k2XdeO8NpvHX8zX/fsay0+NiIFqj1Yfujx3/ENaMJrZ9fsNse9DKR9IDKQKmn+8sfrvBv1dSJIWJoOPkiQV9aYtu+icWt1ru+3A/7U2ogFl5rnAPzRW7wd8PSJGzkSKiFtRptTevcd5f06Zilv3rxFxk3mcc2qdrkfN2qyacjQDvefNf0TTFxF3A57XWP2izOzoap2ZHwde3djunyLi9vM4/eurqaxzjfFKzK5F+fkqKNqmk4Gza8vLgQ9FxH6jHrDX9T+PjOJmMHQnjZp688xWbh5/MV/3b6O8ps84DPiPKnNvLv8C3KCVUS0Mv28sDzS9vio/8doRzvdKOrtcX5nSAMn3rpK0SPkCLklSUQ8qnpGZl/TY7vja/W9l5ubWRjSEzHw75c1z3SHA1yLi2RGx16DHiohrRsQ7KIHVI+fY/PnApbXlvYH/jYgHDHq+6pyHRcQLmV2jbpKeEBGfjYh7DPkm9xXAlWrLGylTBxe1qs7a++j8f/HzzA70zXgWncH45cAHqnqRo7ghJZjX89qNiAOBz9E5LTT7jHFsqgytpzdWX58S9B8qEBURx0TE2ylZzN28OyLeHhHHDHHMdcwO/HwtM3c21t0gIn4YEY+p9hn0+P8PuE9j9ULIBB9JZv6BEvSqeyDwyYi4Wrd9IuLAiDgZeGq1qlkHdKn4cmP5ZRHR929DlQX8VUpzrqFk5o+BdzZWPwD48BDZqEdExJOHPbckqR17Yl0mSZI6RMQq4Da1Vaf02G4vFlC9xy6eAOxFZ1OJ9cDLgSdFxEeBzwLfAy6cCUJU2WXXpvwM7g/cmRI4mlNmnhURD6ZMcZ3ZZ3/Km8TTgP+gvAH9dWbuqs4XlGDRDYGbUrJojqMEub430nc+HsuAe1S3P0XEJyi/4zOAX1UdvgGIiEOA2wNPqr7Wvb2qhdi260bE4+Z5jK9k5pnNlVU23AfpDKr+Hnj4zO+xKTN3VNfC99nd3fYw4P0Rcbde+/XwLcpz7QTgRxHxT8AnM/OianyHUQJDz2N2Pbo3ZOZEMpIz8yPV2OrZoUcDZ0TExyglC76RmVdkBFaZdIdTmvUcR6lved3q4Wb26Iy1wEOAx0bEmcDHgG9Srs3f155by4BrUK7hpwHXbBznNT2Of0PKc/X1EfF5SpD5dOAn9Wu5qhd5c+CvKK8z9cD09+idMb5YvBS4J3Cz2rp7Ab+KiC9RmqFcSHmNuxElK3wmYPt7SimKp9T2rWfvta211wPgLcDj2f37PgT4bkS8DPjvzPwtXPG6cTPKVOu/BVZV25/C8A3JnkRpInXj2rr7AXeIiNcDnwZ+UH0IMJOBfn3Kc+r+wJ2AHwOvG/K8kqQWGHyUJKkEOdbWlnsFFW/N7jdT/babiszcGRGPBH4OvITOAOKVKW/mnjSzeURcXG2znv6zIZpdrZvn/XwVdHonnXUzb1XdAHZFxKXVeeY630JwEPDY6gZARGwGNlOCDb2y8b7L5Gq/1X++o3oU0C3Y8M+U633GDuDBmXlhv4Nl5rkR8TBKNuLM7/jOlLqRLxxiXE+i1FS8BiX79p0AEbGBcs2u7bHfl5jdZKVtL6CM6VnAzLTp5ZTg6AMBImIHJUN4DfOvi3hdOjusZ/Vz2U6pe7myx35vzMxPznHsvYATqxsAEbENuIxS67ZXs5o/UwLTkwy2jV1mbo+IuwNfoHwwMmMVJQjZq3HKRZRA+f0a6yeZCdna60Fm/jgiXkNneY/9KNPN/yUiNgFbKUHZZumAzwP/ypDBx8y8PCLuTcmGr3eyP4Da60n1dyUo1+bUynZIkvpb6P/4S5I0CfUp1zuBrw2w3VZK5tGCksXLKW+c+wVHg/Imbl96/z/wE+ABmXnHHo/Xz/sR4Bb07jC8jPLGtN/5dgE/mOtcLZorcLKWkgnYK/D438AdM/PysY5qwiLiRErWXN2zMvMbg+yfmV+kZJDVPa+qHzmoCylBy2YgZD29A48fA+5bdaKemOo59xxKBuNve2y2AjiQ/oHHzXQPBEP/azMoQccD6R543Ao8PzOf1OWxuY4NJfB2JXoHHs8AjqtqwC56VXbtHSkZc80p6t18i/L9n0Fn0yPo7AK/2D0DeEePx9ZR/p40g38fomQh7hjlhNVU+NtRPnzolTm9L+Xn3i3wOEy2tSSpRQYfJUmaXe/x0h7bHV+7f9qkgxzDyMwfZOadKFMk3wj8bsBdzwbeRHkzfYPM/OgQ5/x5Zh5HCRp9nM5akL1spmTGPB04PDOn2S37DcBtKTUcv0kJ2sxlM+UN9h0y88GZubHF8bUuIq5JaaRS94nM7DUduJeXULIQZywD3hsRVxn0AFXDmGOrY/25z6Y/AB6YmfefZuA3Mz9FaTz0KODrzO4Q3M0FlOntJwGHVrVbu3kYJbj5VuCnDDad93zKNX39zPynPuP+AXAU8I+U5+IlAxx7F2Uq7SOBm2bmLwfYZ9HIzA2Z+RTKNN4XUj5U+QMlu3QTJbv8XZRMyOMy8xfVrs0SABdPZsTty8xdmflY4EHM/SHRtykfXP3FfOsiZ+bmzHw0pTTAe+j/WgDl2vwW5Xoe5gMPSVKLYpHPjpAkSQOqGojcALg6ZcrcKkpzlIuBPwLfy8y53tgNc77llHpd16JkZe1PCchspLyRP5NSC3J7r2NMU1UL9HqUab9XpmR+LaeM/8+UINBPMnOQIKX6iIjmP6TXyMyza4+voGTW3pByLW2hXEPfW6iBr1qN2KtSxryeEqy+DDiHcv3/bpSpylXTjetRajoeTMk8S2AD5bn8I0qd0qEzv6qarNeqblejZJatrsZ+KfAL4Id9PqTZY0XELyk/txk3rJqnLDkRcS3K9X0oJRt8I+W6/nZmntvieZdR/q5ch5KRux9wOeXv2C+BH/VpGCdJmhKDj5IkSZqquYKP0kJXdTj/UW3VRmDfUQLAkiQtNU67liRJkqT5eX5j+csGHiVJKgw+SpIkSRIQEatH2OepwF80Vr9pLAOSJGkJMPgoSZIkScXLI+JjEXGPqu5rTxFxREScDLym8dC3gS+0NUBJkhabFdMegCRJkiQtEMuBE6vbhoj4FqWW4/mUOo57A4dRmq3cvNq+bgPw0FEaCUmStFQZfJQkSZKk2dYDd6lugzgPuH9m/rq9IUmStPg47VqSJEmSirOArUPusx14F3CzzPzm+IckSdLiFs4IkCRJ0jRFRPMf0mtk5tnTGIsUEeuBuwG3Bm4EHA4cBKwFErgE+DPwQ+BrwCcz83dTGawkSYuAwUdJkiRJkiRJrXDatSRJkiRJkqRWGHyUJEmSJEmS1AqDj5IkSZIkSZJaYfBRkiRJkiRJUisMPkqSJEmSJElqhcFHSZIkSZIkSa1YMe0BSJKk3iLiFOAOtVV3zMxTpjOa6YqIVcAxwDWBw4B1wE7gEuBi4EzgJ5m5Y1pjnISIyPpyZsYc258MPLK26lGZefL4RyZNX0QcD3ylturUzDx+KoPRghAR1wKuC1wN2AdYBWxk99+OnwC/zszsdQxJ0vwYfJSkJaRLkKGXncBllH+8fwl8G/hcZv5fa4OTRhARa4GHAX8B3A5YPccul0fEd4H/Bj6YmX9qeYiSFrkh/3ZuAC4FzgFOB04FPpOZ20c894uAF3Z56NOZecKIx2wG0e6amV8a5Vi1Y34CuG9j9ecy857zOW4bImIZcB/gwcA9gAMG2G1DRJwOfBL4SGae0+IQJWmP47RrSdozLQf2B64B3A14HvD1iPhBRNx7qiNrSUTcOCJeVLudNO0xqbeIWBkRTwd+B7wNuAtzBx4B9qIEKV8P/CEi3hURh7c3Ukl7kOXAfsDhwO2BpwIfA34fEU+PiOVjPNd9IuLWYzzeyCLiYOBeXR66W0RcZdLj6SciHgL8CvgE8FAGCzwCrKfMMng1cHZEfDUi7t7OKCVpz2PwUZJUdwzw6Yh47bQH0oIbU7JLZm4nTXMw6i0ijqRk4/4L/d84bgH+DGzt8fgK4K+AMyPi/411kJK028GU16tTI2L9GI/7ijEeaz4eQfcZc8sor7FTFxEHRcRngfdTPljtJSmzPi6lZLL2cjvgcxHxubENUpL2YE67lqSl7ZfAv3VZv4IS1LkJJaNs78bjT4mIyzLzBS2PT+oQETcDPs/soGMCXwI+U339bWZuqO13KOV6vitlqt2Va/uupv+bUUmq6/W3cybz8SjgzsChjcdvA3w8Iu6ambvGMI7bR8Q9MnPaAbCT+jz2KKYcJK0+sPoi3V/nzwA+DXwZ+DFwUWburPZbARwB3ILyv9D9KL/fuhu0MWZJ2tMYfJSkpe0PmfmWfhtExAGUNw5/03joeRHxscz8fmuj05z2pEYJEXFd4AuUkgB1XwWenpnf6bVvZp4HfBb4bEQ8A/hL4KWUN5Z7tMw8CTN9tYeoGnL1bcI0gEH+dq4CHgf8M7Cm9tCdgIcD757nGGa8PCI+P61mKBFxczoDcDPjmPkZXzsibjOtmtHVB09fBq7eeOhHwHMz81O99q2ak/2qur0/Ip5AyfJ8Nn5gJUlj5bRrSdrDZeZFmfm3lCljdUH3IvjS2EXEXpQaXc3A45spHb57Bh6bMnNHZr6Xkp30Gna/WZakscjMbZn5OkpdwaZnz+PQ59P5mnUT4EHzON58Paqx/GU6u4l322YiqsYyH2Z24PEjwC36BR67ycwtmfl2yt+O5wHbxjJQSZLBR0nSFZ4D/Kax7h5VUEhq20uB6zbWvTEznzDq9MXqjeTTgAcCm+c7QElqysyPUbKu664XEc2A2KB+BXygse6lY25mM5CIWAM8pLH6XdWt7i8iYt1kRtXhKZSp7nX/DfxFZm4Z9aCZuTUzXwbcitn/F0mSRmDwUZIEXDH96B2N1auBBdFtU0tX9Sb97xqrfwY8fRzHz8yPAv8xjmNJUhf/3WXdLedxvBcAO2rL12E62YUn0lkDcSPwUUq24Yba+vXAAyY2KiAi9qVkJ9adCzxuTPU2qcrO3GMcx5KkPZ01HyVJdd/osu7wUQ4UEYdQ3nxdg/LGZAvwg8z84gD7XplSAP5g4EBgE/An4BfA6dOqfTWXqvbUzSnjPojShflPlOYF350pcj9tVTbrbYDrAftS3kSeD/xfZv5+CkN6CrCqse6J88lcaRrlzegkr8OIuAbl2rkysBa4CPgpcFpmLripfxGxkpIVdAPKVPnLgQuAb2Xmr8Z0jqD8/K8DHFatPh84IzN/MI5zjFPV5fjWlN/hQZROun8Cfgd8c5zXc+O8y4GbAkdTrtUVlC7wH8nMP7Vxzi5jWEv53q9LCVZdDpwFfC0z/zzA/uuB46r996F0Iz4H+Epmbmpn1GP1oy7rDh71YJn5q4h4B/C3tdUvjIj3tnUd9dAMeH545vcRER9uPP4oxlfnchCPZnZjsmdk5sXjPMl8r78qI/Q4dr8uLKe8Lvye8je3laz8iDiI8nf+msBelL8pvwO+mpmXjfE8AdyI8jp9EOV/iospfw++nZm/G9e5JC1ymenNmzdv3pbIDTiZUitq5nbKkPsf1dg/Kf/Mz3WeF9UeuytwCrCry7F6jgdYCTwJ+GGX/eq3C4C3AFcb4Ps5fo5jzXU7YoBz7AU8jdJRs9v3PHP7M/C2QcbdOP4pjeMcP8f2J/X6mQOHUGoobuozztOAO03wmp15I1Yfw0+n+Bwa+3U4x/nuDXy7z3kuBV4N7Ffbp2ObAc5xcmOfk4Z83pxde2w9pUHVxX3G/FPggfP4mewFvAj4Q59znAU8AVjWY8ynjHr+EcZ7L0oNvG19xrsZ+CRwyyGPfUSv3zflTf4rgQt7nPP4MX1//a6HQ6rnweYeY9hKyTq+Uo9jX4MSsLq8x/6XA68F9p3HeOe8Fro8R4a6foBrdxn7cwfc90WN/b5erb9yl5/r0wY8ZnMsdxnh935VSgC96zUF3KHx2C7gmhN83jVfoy8AVk3q/AOM736U+phb+7wubAE+M8LrQvN6fVHtsZtQuns3f3czt+2UTN1rz/P7uxZltsz5fb6/mb8HTwBWTvt34s2bt+nenHYtSarr1iE0B9oxYkVEvJnSrfgOPY7Va99bAmcCrwduOMfmB1GyQX4REc8a9BxtiIgTKfW5Xk355L/f93wA8NeUcTenGLcuIu5EeRPwOEpmXS+3BP43Ip47kYGVjJArNdY1p/9PxCSvw4hYHRHvo7xJvHmfTfehBLd/FBE3GvY841Sd/0fAs+icitl0FPChiHhz1RBimHMcA/yE0uzqsD6bXgN4I/CViGhmP01ERBwcEV+mBA+OpwSue9kLOAE4LSLeX2UKzufct6A8n59JycqduIi4QzWGv6V8f92sAh4DfCcirtXY/4GUANIj6OwWXbeGkhn9zSqzfKHap8u6eWXMZeYfgDc0Vj+7yhKdhJPoLNF1DnBqbfmrdNZDDOCR7Q8LIuJwZr9GvzcXQJZ4RBwVEd+lTE+/I7Oz+utWUz68OC0i/rPqoD6fc/8D8B3Kh1q9XntXUBoYfT8i7jbCOVZFxOsppVEezdwZvkdRXqt/EhHXH/Z8kpYOg4/SFETEdSLiARHxxIh4dkQ8NiJOiIjrD/tGTRqzbm/uLhxw37dRAlt1OykZUj2nvEbECZTsgGv02OQSOmtfzVgDvCIi/mMaz5sq4PRRSnZK0y7KuLtNp1oDvC4imt3FW1MFHv+H2VPULqFkXnTzTxFxUovDmnGHLutOmcB5O0zyOqzeYH6E7l1yoWR7bWysuyrwpYg4cpBzjFtE3ICS3Xd446HL6B1keRyldt2g57gx/X8Hl1KyiOpuT7m2ewWvWhER16ZkCd+xxyYbKb/Hbh5CCZo2g+6DnvuGwBeZ/dqziXkGvIYYwy2Y/ZqyizK1s9vz5AjgMzNB1yrw+AFg79o2/f5eHAV8bAH/j9Ttg4HfdFk3rFdSrvsZVwL+YQzHHcRJjeV3Z+YVH0ZW95vTrB9ZTcNtW7e/G1+dwHn7qoJ536SUQehmI52/z7pHAV8c9YOJ6gPDf6XMJpixk97PyXXAJyLiekOc40DgS5TZAd3Kt23rc75rA9+IiGaDIEl7iIX6B1wLUEQsi4ijI+KREfH6iPhmRGyOiKzdjp/2OBeqiFgZEX8fET+iZNZ8mPKJ9suBt1OmY/0E+HNE/JcdhjUl3ZrLnDPAfg9kd+2nDcCLKbXgVmXmAZTAwE1ovFGJiKMob0Cb/2x/ilLkfU1m7k/JHDiK0hG5+Yb+McCze4zrF8Djq1vzTdIva4/1unWtVRYRT6RMPa2/yfoD8Hzg2Or73j8z11GyAv4K+HHjME+PiElkiRxKmWK1mvJG5D8pb9xWV2PcCziS8rNtBiJfExH7tzy+YxvLWynZUBMzgeuw6RWUzJS631OCdYdm5trMXE/JaHsk5W8GlMDDewc8xzjtBXyMUtuR6v49gHWZuW9m7g1chRIUuaSx73Mi4jpznaCqi/ZRZmfxfZGSMbguM/fLzDXA1YEnA+dV29ySkik5EdXf548zO0j6Y0oW3wGZuT4z11KyNx9PqbVWdwvgfSMGav6L3Zl2pwL3p0xL3rv6XRxICRz9cYRjD2Iv4IOU58sOyrTrW1Je9w6kvN7fjhKsrrsO8Mwq2HEyJUiymRJgO4YyLXPm78U9gWZdz1tRMq0Wogc3lndSgtPzkpkXUQJKdU8bNXA9qIi4PeXvQl23eo7vonN2xOHAndoaV82Nu6z77gTO21NEHEv5X37f2urNlOfH8ZTXsPWZuR8l8HcPStZ03e0pWYLDuhvl7xKU4OaL2P2cOpDyt+vmlNeOujXAWwc5QVXn95OU53bdqcDDgKtk5ura+W5EeY9Tb0y0L/Dhqia4pD3NtOd9e1scN0qGxkb61/QYW32hpXajfAL64wF+fvVb1/pI3rz1uzGPulWUT7HPauy/BdhrgPPUa/tcdcDzLaPUSazvvxN49Bz7XZsSEG3WMLrZHPudNOrPpnGcmzG7htN7gfUD/Hzf0thvE3DYHPudMszrbJfvc+b2J+A2c+x7x+pnWd/vyS1fsz9onO+MCT9nJn0d3orZtbi+0O/6oQSO39fj95oDfI/N5+tJc2x/fI9zbQLuN8e+N6C8+a3v928DjPHfu5yvb307yhvZr/UY6yktXjOv73K+t9GnphmlVubnu+z31DnOdUSv3zvwDxN6jvS6Hi4GbttnvxWU4Ep9nwuBr1f3zwau12f/vYHvN/b//gjjnfNa6PIcGfj6oXz41vzZfHqI/V/U2PfrXX4Ozbp6fZ9TXcYzVM1H4J39xtTY9tTGtu+dwDX5iea1OInnQp/x7Mvs/5++Bxw5wL6PogTx6/ved8jrtX7Ouf6neH6X/Y4ZYJyvbuyzGXjoAPtdi/IBWn3fj07z9+XNm7fp3Mx81KBuSvmUTkOKiOMo08iOrq3+HSUI8Q+Ufzr+jvLG6zT6TE+VWvYyZmfyfDYze00dbLoUuGsO3i35RGZPVXtmZv5nv50y85fAXeic0rwCmFSNwlfRWcPpg8AjMnNDj+0ByMwdlAyoT9dWr6XUM2vbDkrQ6P/6bZSZX6E0pKl7YGujKpp1/ebsjDtmJzLZ6/D5dM48+TlwYr/rJzO3UjIgpz2t8DGZ+bF+G2TmjymZz3V9r6GqZuPfNFb/a2b+2xznupRSL+3sftuNU1V3sDnWTwF/m5nbe+1X/X5PZHZW77MiYvUIQ/nXzHz1CPuN00Mz8+u9Hqxe855A5/81B1I68G4FTsjMn/fZfyOzXx9vPK3SA01V7bunMDsbeRuDZ0HPqfo5vKyx+gkRcbVxnaMuIvZm9nP2XX12ObmxfP+I2Hesg5qtWXLg4pbPN5en0Pn/05mUgO+v59oxM9/J7OvlOSOM4ffA3TNzroznlzE7q3iu1+hr0vlcTOBBmfn+uQaVmb+ivE7XO2yfWM04kLQHMfioUWylFDN+C9OZ/rVoVMXVP8/u6VGXUQqzH5GZj8/Mf8vMkzPzDZn51Mw8jvIP1XMp/7xKrYuI/SPiLcAzGg8ls4MI/bw0M88dYvvmm8ofAK8ZZMcq8PNPjdX3jYheteLGoqpzVq/xdjHwxMzMQfavtnsanW/G/3oCdcze2S9I0PC2xvKxLY+v2TyhVz2stkzsOoyIIyhT7er+LjO71QZtnmsmeD2tD6i+mJkfGHDbd9JZ8+tqc0yzO4nOmo1/ZMBp1FVQ72kDjmscHk/nhw+XA08Y5DWg+iCnWRf3EOAvhxzDnxiilmZLPpmZn51ro8w8h5Lp2PTmzPzRAPt/FfhtY/XNBhvivFw5Ih7X5fbEiHhORLynGtdrKZnJM3ZSgvRzfm9DegudJVBW016pgQfRWYtzC6V0Ry8forPW6F7MnoY+bns3li9p+Xw9VTUam03kHp+ZwwREX0PJnJxxy6oG7jCemZlz1ujOzF2UDvR1/ZqeATydzlqS78nM5pTxfuf8NSXJYkZQ3g9J2oMYfNSg3k35pP+mlKlht8jMxwP/O91hLVxVHad3sPuN9Qbgbpn5tuoPf1eZeX5mvjwzL+u1jTSEfm+gnh8RH6W8ger2T+DLMvOMAc+zndnZDz1FxD7AbRurX5+ZOwc9BiVDr16jcBmzAzvj9rDG8vsyc6hMvSpgVa9NdQClNlObmtmMPVWZa/XXn3VAKxk2lWbW10QaZsBUrsP70Pm/15mZ+aVBT5SZP6Vk0k/DMNfQxZROqHX9mhrcvbF88iAB2ZpPAsN88DEf92osf2SIbG8y85vAt+Y45lzeM0RGelvePsS23+myrhn86KdZy2/gBhnzcG3KNd+8vYGSOfZwSuC47kxK9v/YP5jP0sX5RY3VJ0XEdcd9LnbXb57x8SrLuKsqM/Ojcxxj3Ob9dyMift+oW9/vdkqfQ92dUo93xo+rWQQDqz5c+nBj9fFDHOIi+geIm77RWO75nKo+fGx+QPK6Ic41o1lv8vgRjiFpETP4qIFk5gsy8+2ZeXq/aUXjEMVNI+IREfEPEfH06v7Rc++9oDycUjh6xjMzs/mGQ2pbvzdQLwHux+wMAoB/z8znD3GeHw4ZhLsVnX+DktlvXvrKzEsozSjq2u6ieIfG8udGPM73GsvHjXicQVzC7ClWc/lNY3m/sYyku2b34kmW+Jj0ddj8PX98mHNVhhrfGJ065PZnNZb367ZR9UHdLRqr/2eYE1XB4s8Ps88oqgynGzdWf2SEQzWDBMO+bg0V2GhB0j2bsZdm5uJFlPrAo+6/3xD7TkJSyikcPWzQaUjvoTOov5zdTUbGopqx02wo0m/Kda9tbtXytNrm7KBploZaCP8XfL0KYA5qoNfnyjHsbjYGcGFmNsc6p8z8GZ2N2m5YTfGXtIdYMe0BSDMiYj3wTOCxzP40eWabXwIvzMzmp2cL0ZNq93/FgN3kpCn7EfDsYabT1PYbxg0by78ecorSjO9SOuHOaC2DsAo8NMd98xHrbjWn5TbrHo7Tb/tlW/fQrD+4T9etxmMjnVNu264VVjfp67BZW3LoN3Aj7jNfl2XpujuMQa+hw+h845uUBkDD+v4I+wzrKGb/79wtq28uzUy+q0TEAUP8jMc9pXdYl1ZB90E1s9J+O2ipisrGxnKbr0ejCEqJkjXA89o6SWbujIjn0RnwfmBEHJuZp4/pNCc1lv/I7A9XuvkyJUh89caxnjmWUc3WvCYm+Xej6VaN5atERLO8wiCawdph/i84e8hzDfM3vvn9bRzx+4MSNN6rur8MOJjZv0tJS5TBRy0IEXErSgZIv5pQULK43h8R9wMe1nYW5qgi4hg6MzneMcKbf6lNOylTay8Bfgl8G/jcXA1J+hi2SciBjeVmpt2gmsXcm8cdp4OZPWNgXDW32hz3JSPs05x2vLzbRhFxR2CYaX+fzMw/NNb9kc4pawcMcbz5mvR12Fx/9gjnGmWf+bpkhH0GuobozKgB2FBN4xzWXE0WxqH5+9s+ZJ3bGd2aUBxIyQgcxKSbMjUNWxameS3Md/9e19I4nZqZx9dXVFm6ewNHAnej1IudaXyyDHhuRKzKzGb95LHJzI9GxHfYXaMvgJczhpIj1fTav2qsfu8gZSgyM6s6mPWGW4+IiOcMWcZiUH+glIKaMcrfjWfRfeYHlAZfzaBbL4c2lh9S3eZrmP8LLhnmwFUgu76q32zI5vd3BEOU4ZjDgczOwpS0RBl81NRVb14/Ten6OuPMat2vKUXrrwv8Bbvrjj2Ikh3RdkHrUd2tsTzUFDJpjGa9gWrJsMGCZsBh1CYjzf3aDFy1GSBcO/cmIxsmw2hYj6xug/o55U1j3a/pzEA8KiJWTujDpUlfh83zjVLbd9INeaDda2i/xnLfrvF9TKJOclvXCwzx2jVicHac5ns9tHk9tabK1txAycw9IyLeTCmDcJfaZv8YEd/NzGHq7w3rOXRmI949Iu6QmcOWRmi6C7Pr+w4y5XrGyXQGHw+jBEWHnUkxiFnThiPiKsN8GNCvNmeVFDFo8LGt/w2G+b+gzefUYv3fR9ICY81HTVVEHEwpQDzzx2cL8BjgqMx8ema+uao1+XRKALI+dfkvIuIRkx3xwOpZjxuAHwNExLER8YaI+ElEXBYRGyPiNxHxsYj4m4jYq/vhpCUnGsvj+se5zX/AV829yciaP489SXO64GpmT4duy7Svw0UZhBmzZs3PUZ9nbT4/Z7R1vYz7WJqAqtP6/Zldv/JNEXGlLruM67xfYnbjqVeM4dCP7rLux4M2ZaHMomhqq/HMGV3WzdWxuS1tvfYslP8L/N9H0lgYfNS0vZLdU613AffLzP/sVgsoMy/PzMfRWevmpdU0kYXmJrX7vwTWRMQbKXWenghcn9IFex1l+sKJlMDqWRFx/4mOVJqO5vTC/UY8TrPO0yj1+gbVnOqYwLrMjDHcTmpx3Atdt2ydO07o3JO+DpvrR6lTNs3aZm2Y9TOJxnzAAe03hrHMpXm9jPq76LZfm69dakkVgHwU5X/YGQcy5kYwXTy7sXxcRJzQdcsBRMR+wP+b14i6OyEi2sic+2qXdbfvsm4Smv8b3HNM/xccMY1vpovm9/fBMX1/kZmnTOMbkjQdCzFooz1ERBwKPKy26j8yc5AOcU8GZqbjHQ7ca9xjG4ODavcvAD4EPIHdn/BtA37P7BothwIfjointD1Aacqa/8weMeJxrjnHccfpT43l6HL+PUpmnjSGNxrfYPbv7THtjx66nPeIEY8z6HU4jvONss9Cdj6dWX+rGO151WZn3RnN39+qiLhy1y376/b9TbuOo0aUmd+mdKKue2zVObrNc368sfpl8/hA/qF0Nv4al1V0/q8/Fpl5FrMzTh8eEZPIgG5q/m/Q2u99Spb69ydpQgw+apoeSGcq/2sG2alqVvCl2qq7jnNQ81VlbKyvrbozuwOkZwL3A/bJzKtl5v7A9YB31g8B/FtE3HkS45WmpNmt9VpV5sWwbtZY/sFow5lb1QX5nMbq49s6354iM3cw+437UVU94LZN+jpsrr9p1636G2WfBauqX/jzxupBa63Nd59h/YxSh7qu+bsfRHOf34/QTVwLywspHyzPWAE8v+VzPo/OjMsbMnqjk+b06C8Ajx/x9ok5jj0u72wsH0SpCT9p328sHz+FMbSp+f3daMS/k5L2cDac0TTdrnb/rMxsvvno59vAPav7t+y1UURcdZSBDejSarpN0zo6A/srq6+nA3fKzI5C85l5JvDoiPgZ8Kpq9TLgtRFxTLcp6NIScBrlTdPMcyUogfnmm4meImJfZn/48I0+uzSDBqN0TP0i8Nja8oOBN4xwHHX6d0p2eP0DqTdGxLGZuWUcJ4iIZZm5q7F60tfhNykZRjNOpHRcHcZSLM3xf3RmLj4MeN+gO1fZh8ePeUyzZObmiDiDzuDhA4BPDnmov2gs93vd0iKQmedExLuAv66tflhEvDQzf9XSOX8SEe+ls0P1SyJiqGY3EXEDZgfE/zkzm3UlBz3ed+icwn3jiLhxZp4xyvH6eAelwc1+tXWviojPTjiY/0U6G6/dIyL2bf6/v4h9A9hEeX8DJX7wAMrPX5IGZuajpulGtfs/GXLf82v3+wUYf9fi7Yk9znl5l3W7gIf3+0ckM/+FzgLiN6Czg6K0ZGTmZcDXGqufNOSUsccB9SZNu4B+pRuaHxaMUq+t+abuthFx9xGOo5rMPBt4Y2P1UcC/jOP4EXE/OoPGM+ed9HX4aTozla4bEQO/zkfE9YE7DTG2xaIZaLx71W12UC9gtA8TRtHs3PvAYaZeR8QtmZ2l2UY3YE3ey9ldFgjKNdl29mMz4/KadHmtm0MzM/E84JRRB5SZ32N285mxZz9WsxFe1lh9ZeBtE64H/1mg3oF+HcN/qLRgZeY2ZmezPj8iVk9jPJIWL4OPmqZ6AeoTBu2mV3XUe1Nt3/0nPO6+MnMnpWt33Rcy82cD7P7axvKCmlIujdnrGsvHUmq6zikijmT2m7pPZOZv+uz2x8bytYatD5WZX2R2ltI7I+LqwxynbsTmGkvR85j9hvVJEfHGUd9IRsSaiHg1pVHZ2h6bTew6rIKszcDk6yKi19jq51oBvJkl+L9bVQu0/iHkMuDkiDh4rn2rwPLftDS0bt5CZ7BnLSVLd87ncUSsqfavOx/4wPiGp2mpnt/vbqx+WMu1H88G3tZYPXDAs3pdeXhj9Qe7ZIkP678ayw9rqR7jaygZ7HUPAD4SEXt12X7sqizL5t+Rf4yIe3bbfhAL8P+ClwI7a8uHM8/MxwX4PUpq2ZL7B1aLyn5jOs6cb9qm4LLG8lcG3O9UOgvvHzue4UgL0seZXQPvXyPiEf12qgI+X2L3FCAoU6qbGRBNP6Jz6vVeDJ8hAvAPdGa3HAb8X0QM1WkzIq5ZBcbePsIYlpzM3EyZqtfMEH8C8JWIGLi2XkSsiIiHURoSPI3dzb66+TiTvQ7/ic7sx6OAj0fE3n3OtQp4F9Pr5joJT6bz7991ga9GxK27bVz9jp8OfJDy+x3L9Py5ZOZ5zA72nEgJQPYsZ1T9fj8K3Ljx0CuqzCItDS+n8+/McsoHK236J8q02BmHDbHvvYFmkL8ZOBxF8xgHAiN34+6l+sD/AZQmjnUnAt+OiPsOe8yIuBFwkyF3+1fgt7Xl5cBHI+LxQ557v4j4e+BbQ56/VVVprDc1Vj8sIj4aEQcMepyIWBYR94iI/2F3+SxJewiDj5qmzbX7FwO/nsetqyG7sA57e2Wf7605pt923Wr2eC+rfhYzDuq1rbTYVZkVD6HztWA58O6I+FhE3GVmWk8U142IFwM/ZHbH3xdWU736ne9y4PON1W+MiC9ExIsj4kkR8bjGbX2X45xGKapfd1Xg1Ij4fEQ8NCIOr3+qX/3DfdWIuFdEvCgiTqe8TjyN8X0Qs+hVGeJ3p/N1EErQ7dvVz/fJEXH9ZrAuIg6u3tS8GjgbeC9wjQHOOenr8JvA6xur7wr8NCL+pp7tFxH7V0HQH7C7VmQzy2dJqOrLNX8u16UE9r8dEa+MiKdExLMi4j8o5U/+hVJXeQfwkuYhWxzuM5jdaffxwHer5/9+Mysj4pCI+BtKZmfzzfYXmJ0xpUWs6sL83sbqh7ec/Xg+pW7uKJrToX+TmfMOfFXBqjPmONdYVI0o78TshnA3AD4REWdExEsj4viIuFJEdJRoiIgDIuLWEfEPEXFqNe5jhhzDxcB96QwCrwHeFBE/joi/i4gbdjn3gRFxh4h4akR8AbgA+DfK9PmF5mnMTqa4H/CbiHhtRNw5IvapP1jNPjgmIh5WvW7/kTJN/Z4Yh5D2ODac0TRdCMz8kfpQZv7tNAczZj8BjqstD5ORUd92zXiGIy1MmfmziPhLSvZSfYrUidWNiLiEkl22ku7eAfT7MKDuZZTgVv3v313pXeLgc8yuFUlmvqOa0vVvjXHdrboB7IyIS6vH96Z/9p0qmfmtKPX+PkTnG8Cg8+dLRGyh1Nram/6vl5uAX/Q556Svw2cC16NcizOuBrwVeGtEbKZMcWsGvy+kTJFspYHFAvD3lN/loxvrb17dutlFmXZ9dmN9a5mQmXl5RJxI+TCjHuC+EVX9yojYQAli95qd8R3gYTaVW5JeBjyC3XVIZ7IfT2rxnK+iBMAHLkVUfdBxr8bqcWQ91o9149ryPSLisMxslkCZt8z8ZZR6qu9h9t/zG1W3mQzUrF7Pg/J6M9f74a9QZjzMNYYfRMS9KLWhD6k9dDS7P2TIiLiM8rq1D5OrVTtvmbkjIu4PvJ/OD1L2AZ5S3Yb5uyxpD+MnDpqmenfro6c2inac0VgeaEpClSlV/8fxz+MakLRQZeangDsCveo17kf3gM8W4NmZ+dhB61NVWWePYHZphKFl5huAOzA7A2rGcspzfz29A4/bGb7h1pKXmb+gBJueDVzSZ9M1wJXo/QZnK6VO4rUy83/mOOckr8OtlKBmr660a5kdePw9cNfM7Jntv9hVP7/HUhq6XTLALn8ATsjMd7L7w8wZg+w/ssz8JeVDxl5lVdbTO/D4X8DxmXlhG2PTdFXdrd/fWN129uOlwD8PudvDmf2aNs7g4wfozEBeTvn724oqA/TulO7fZ/fZdOZ/7f3oHXhM4OvA/8vMO2Xm9wccw1cpU7Z7/b0JSrO7/ekfeBzofJOWmZdQpuo/h84mO3Vz/V2G0tRo7EFoSQubwUdNU/0f9ltFxJWmNpLx+1Rj+cYD7nddOrNuzhrLaKQFrprmdT1K3bcfz7H5nyg1164zR/mDXuf6AHAk8CTgY5SMuEvorOM46LG+SZnadX/gi3RO3e3lUkrnyCcAV87MFw573j1BZm6rfr9Xo3SU/jKdjT562Uzp1Po44NDMfEJVp2+Qc07yOtySmQ+m1Ln8bp9NL6M0VbhhZp4x7HkWmyzeBFyL8hz5AiWQsIUSTD6H8vx5DHBkLajcrFvXnLrfxljPz8w7AfehXHP9XkMup3Q7Py4zH1rVONXS9U90NuiYRO3H11EC8oNqToP+SWbO9bo3sMz8LbMbtLUy9bp2zszM91BeP+5HCYAO+lqwkRJwfAHlteV2mfnJEcbwx8y8N3BTSib0nwbYbRvw1erc183MBdtwsvoZvwK4OmW8P2GwMhe/pDTcuidw1blKlEhaesLZHpqPiDgJeGdt1R2zdK0cZN8jKH+IZj51fFVmPnOc45umiPgWcItq8VzgiMzc0WcXIuKFwItqqx6Tmf/ZzgilhSsirkJ5/hxMKVS/ifIP/JnA6Qt1qmKUxiA3o3SCPJCSWbGFMnX7d5Tx/2bQDDl1qmovHkMJHh9KmQa9k/Lm8mJKRv1PsjQhGMf5JnYdRsQ1qnNdmfIh1MWUrNpvpg1J5hQRb6ezgdTfVdnJkxzDeuA2lN/hQZRr80+Uus/fzMyJNMWRtFs1q+jalA/4r0bJSl5JCTZeQnmt/SXw8zb+Nlfnv351O6C67aL8X3A+5QPQX1QZ8YtSlUBS/1u5ht0/319RfrbO5pL2cAYfNS/zCT5W+7+b3VMwdgD3zswvDLF/ACsX4huziHgQnVPqnpmZr+qz/dUo3Xj3rVZdRglYtp69IUnSYlUF/X9DCfrNuGVmfntKQ5IkSVKN0641bc9gd82PFcCnqm5zfQsUR8RhEfF3lCyXY1se40gy80PAN2urXh4RXZvqVHWAvsjuwCPAqw08SpI0p8fQGXi8kNm1lyVJkjQlZj5qIFV3s25Ze+vprLP0B0pdo6ZnZOZHexz7OEpH2Xqx+AspXSTPAC6i1MrZD7gOJdh4E3Y3cDguM08b8FuZqGpq+Tcp0wNnfJ9Sr+p3lE5wt6LUi1td2+Z/gbuPa+qgJEkLXUSsGnYmQ/U/xP/SWS/5lZn57LEOTpIkSSMz+KiBdJlePaxHZebJfY5/FPBxSnBxWLfIzO+MOK7WRcSNKN/bEQPu8lHgrzJzU1tjkiRpoYmIE4HnAm8APtkv+z8i9qF0xX4RsKr20KXA0Zl5bnsjlSRJ0jBWzL2J1L7M/FlE3AB4NKXL6PXn2OWnwGeB9y707p+Z+YOIuCHlDdIjgV5dvX8MvAz44EJtpiFJUstuBpwM7IiI7wI/pHS4vowyQ+BAygyI21IaDjX9jYFHSZKkhcXMRy1IVYfRWwGHAPsD2yjd6H4N/Dgz/zTF4Y0sIlZQOmFek/K9baV0uvtmZv5mmmOTJGmaqszHj424+3ZKh+u3jm9EkiRJGgeDj5Wqa/KRwA2Aq1HqD26m1Bv8AfCjSdffi4hlwK2rcR1GmUp0LvA1G5FIkqSlpKrf+ClKduMwvgI8Z6HWf5YkSdrT7dHBx4hYD5wA3Be4E3BQn80vptQ8/NfM/GOf7cYxrhXAM4En0Nm9ccY2yj/nT8/Ms9sciyRJ0qRU/wPdHrgdcFPgGpT/hdZRygVdSvlg+FfA14DPZebp0xmtJEmSBrHHBh+rwOMFwJohd70IeGxmjjotqK+IOAT4NKXm0VwuozQm+UQbY5EkSZIkSZLmY08OPu5HyWasOws4FTgTuJASmLwh8AA6m4TsBB407gBkROxFmTp0y9rqc4H3UmodHgjck5IRMGMLcKfM/OY4xyJJkiRJkiTNl8HHkj34TuA/M/OHPbZdC7wW+Ova6ouB62TmhWMc078AT6+t+jDw8Mzc2tjuoZROkCurVb+rxrJlTOM4D1hbHVeSJEmSJEl7rqsBmzPz0FF23pODj3sDzwX+JTMvGnCf9wEPra16YWa+ZEzjuSrwS3ZPA/8hcLPM3N5j+2cBr6itenpmvnpMY7ls9erV64888shxHE6SJEmSJEmL1K9//Wu2bt26ITP3GWX/PTb4OIqIuDLweyCqVd/JzFuM6dgvA55TW3WPzPx8n+1XAGcDV6lW/T4zrzamsfzk+te//vV/8pOfjONwkiRJkiRJWqSOPvpofvrTn/40M48eZf9l4x7QUpaZfwB+Vls1ztTA+9XunwN8YY6x7KBMF59x1YgYpEmNJEmSJEmSNBEGH4e3sXZ/3TgOGBHXAI6qrfpSDpaS+sXG8n3GMR5JkiRJkiRpHAw+Du+I2v3zxnTMGzWWTxtwv28DO2rLx4xnOJIkSZIkSdL8GXwcQkTcFji4tuqbYzr0UY3lXw2yU9Xd+g+1Vdcf03gkSZIkSZKkeTP4OJxnNJb/e0zHvWZj+bdD7FvftnkcSZIkSZIkaWoMPg4oIh4CnFBbdQbwiTEdvtmq/KIh9r24dn9lRKwew3gkSZIkSZKkeVsx7QEsBhFxNPC22qodwF9n5q4xnWLvxvKWIfa9vMuxtg6yY0T8pMdD4+ziLUmSJEmSpD2UmY9ziIjDgM/QGSB8VmZ+d4ynWdNY3jbEvs1A417zHIskSZIkSZI0FmY+9hERBwCfBw6vrX5bZr56zKdqZjqu6rKul+Y062YmZE+ZeXS39VVGpM1rJEmSJEmSNC9mPvYQEfsAnwNuWFv9PuDxLZxuY2O5mQnZTzPTsXksSZIkSZIkaSoMPnYREXsDnwVuXlv9YeCRY6zzWHdZY3n/Ifbdr3Z/e2YOVO9RkiRJkiRJapvBx4aIWEup8Xjr2upPAg/NzJ0tnfY3jeWrD7FvfUr4WWMYiyRJkiRJkjQWBh9rImIv4FPA7WurPws8KDO3t3jqnzaWrzXIThGxBrhyn+NIkiRJkiRJU2PwsRIRq4GPA3eqrf4ScP/MHKb79Ch+0Fg+bsD9bkFn06AfjWc4kiRJkiRJ0vwZfAQiYhXwEeButdVfAe6bmYN2nR5ZZv4G+Hlt1V0iIgbY9a6N5U+Pb1SSJEmSJEnS/OzxwceIWAF8ALh3bfXXgBMy8/IJDuVjtfuH0xkInaUa96Nqq84FvtvCuCRJkiRJkqSR7NHBx4hYDrwXuF9t9TeAe2Xmpnke+4iIyNrtlDl2eTNQ71T9qohY2Wf7pwNXqS2/NjNzxOFKkiRJkiRJY7fHBh+rac3vAB5cW30acI/M3Djp8WTm74A31lYdA7yvqkXZISIeAry4tupc4A3tjlCSJEmSJEkazoq5N1mybgs8srHu6sD3Byu3eIU7ZOa5YxrT8ymdtm9WLT8IuHVEvAc4C9gfuBdwh9o+W4G/nERtSkmSJEmSJGkYe3LwcXmXdVce4Tj9pkYPJTM3R8QJwGeAY6vVVwGe1WOXDcAjM/Pr4xqDJEmSJEmSNC577LTrhSozzwNuBbwAOK/HZtsoDWpulJkf67GNJEmSJEmSNFV7bOZjZp4CDDW/esjjnz3q8TNzO/DSiHg5cGvgWsAhlEzH3wNfy8yLxjRUSZIkSZIkqRV7bPBxMcjMncDXqpskSZIkSZK0qDjtWpIkSZIkSVIrDD5KkiRJkiRJaoXTrqUWvf9bvx3r8R56y6uP9XiSJEmSJEltMvNRkiRJkiRJUisMPkqSJEmSJElqhcFHSZIkSZIkSa0w+ChJkiRJkiSpFQYfJUmSJEmSJLXC4KMkSZIkSZKkVhh8lCRJkiRJktQKg4+SJEmSJEmSWmHwUZIkSZIkSVIrDD5KkiRJkiRJaoXBR0mSJEmSJEmtMPgoSZIkSZIkqRUGHyVJkiRJkiS1wuCjJEmSJEmSpFYYfJQkSZIkSZLUCoOPkiRJkiRJklph8FGSJEmSJElSKww+SpIkSZIkSWqFwUdJkiRJkiRJrTD4KEmSJEmSJKkVBh8lSZIkSZIktcLgoyRJkiRJkqRWGHyUJEmSJEmS1AqDj5IkSZIkSZJaYfBRkiRJkiRJUisMPkqSJEmSJElqhcFHSZIkSZIkSa0w+ChJkiRJkiSpFQYfJUmSJEmSJLXC4KMkSZIkSZKkVhh8lCRJkiRJktQKg4+SJEmSJEmSWmHwUZIkSZIkSVIrDD5KkiRJkiRJaoXBR0mSJEmSJEmtMPgoSZIkSZIkqRUGHyVJkiRJkiS1wuCjJEmSJEmSpFYYfJQkSZIkSZLUCoOPkiRJkiRJklph8FGSJEmSJElSKww+SpIkSZIkSWqFwUdJkiRJkiRJrTD4KEmSJEmSJKkVBh8lSZIkSZIktcLgoyRJkiRJkqRWGHyUJEmSJEmS1AqDj5IkSZIkSZJaYfBRkiRJkiRJUisMPkqSJEmSJElqhcFHSZIkSZIkSa0w+ChJkiRJkiSpFQYfJUmSJEmSJLXC4KMkSZIkSZKkVhh8lCRJkiRJktQKg4+SJEmSJEmSWmHwUZIkSZIkSVIrDD5KkiRJkiRJaoXBR0mSJEmSJEmtMPgoSZIkSZIkqRUGHyVJkiRJkiS1wuCjJEmSJEmSpFYYfJQkSZIkSZLUCoOPkiRJkiRJklph8FGSJEmSJElSKww+SpIkSZIkSWqFwUdJkiRJkiRJrTD4KEmSJEmSJKkVBh8lSZIkSZIktcLgoyRJkiRJkqRWGHyUJEmSJEmS1AqDj5IkSZIkSZJaYfBRkiRJkiRJUisMPkqSJEmSJElqhcFHSZIkSZIkSa0w+ChJkiRJkiSpFQYfJUmSJEmSJLXC4KMkSZIkSZKkVhh8lCRJkiRJktQKg4+SJEmSJEmSWmHwUZIkSZIkSVIrDD5KkiRJkiRJaoXBR0mSJEmSJEmtMPgoSZIkSZIkqRUGHyVJkiRJkiS1wuCjJEmSJEmSpFYYfJQkSZIkSZLUCoOPkiRJkiRJklph8FGSJEmSJElSKww+SpIkSZIkSWqFwUdJkiRJkiRJrTD4KEmSJEmSJKkVBh8lSZIkSZIktcLgoyRJkiRJkqRWGHyUJEmSJEmS1AqDj5IkSZIkSZJaYfBRkiRJkiRJUisMPkqSJEmSJElqhcFHSZIkSZIkSa0w+ChJkiRJkiSpFQYfJUmSJEmSJLXC4KMkSZIkSZKkVhh8lCRJkiRJktQKg4+SJEmSJEmSWmHwUZIkSZIkSVIrDD5KkiRJkiRJaoXBR0mSJEmSJEmtMPgoSZIkSZIkqRV7fPAxIpZFxNER8ciIeH1EfDMiNkdE1m7HtzyG4xvnG+Z2szbHJkmSJEmSJI1qxbQHME0R8RHg7sC6aY9FkiRJkiRJWmr26OAjcFMWZuDxHGDHgNtuaXMgkiRJkiRJ0qj29OBj3Vbgh8D3gL2Bh09xLMdn5tlTPL8kSZIkSZI0b3t68PHdwO8oAccfZeZ2gIg4iekGHyVJkiRJkqRFb48OPmbmC6Y9BkmSJEmSJGmp2uO7XUuSJEmSJElqh8FHSZIkSZIkSa0w+ChJkiRJkiSpFXt0zccF7OURcX3gcGAdcAlwHvBN4PPAJzJz5/SGJ0mSJEmSJM3N4OPC9JDG8kHV7YbA3wBnRcTTMvMTEx+ZJEmSJEmSNCCDjwvXxcBllMzHA+icIn9N4OMR8fLMfO6oJ4iIn/R46MhRjylJkiRJkiTNsObjwvFn4PXAPYADM/OAzDwiMw+iBB/vD/xfY5/nRMRTJjxOSZIkSZIkaSBmPi4M3wOumplbuj2YmZcCH4uIjwPPBV5ae/ifI+Kjmfm7YU+amUd3W19lRF5/2ONJkiRJkiRJdWY+LgCZuaFX4LGxXWbmPwFvqa1eDTyjtcFJkiRJkiRJIzL4uDg9D7i8tnzCtAYiSZIkSZIk9WLwcRHKzD8Dp9ZWHR4Rh01rPJIkSZIkSVI3Bh8XrzMbywdPZRSSJEmSJElSDwYfF6/LG8trpzIKSZIkSZIkqQeDj4vXIY3lC6cyCkmSJEmSJKkHg4+L1+1q97cD505rIJIkSZIkSVI3Bh8XoYi4J3Ct2qr/y8zN0xqPJEmSJEmS1I3BxxZExBERkbXbKX223WvIYx8GvLWx+uThRylJkiRJkiS1y+Dj9D04Ik6NiPtGxKp+G0bEXYBvAVerrf4B8J42ByhJkiRJkiSNYsW0BzBNEXF/4FVdHlrfWH5fRDS7SwM8IzM/Ooah3L66XRIR/wf8EPgjsIHSxfoawF2BGzX2Ow84MTN3jWEMkiRJkiRJ0ljt0cFHYB/gyAG2u3Kf/cdpP+De1W0upwEPz8yzxzwGSZIkSZIkaSycdj193wXeCfwMyDm2TeAbwMOB22bmr1semyRJkiRJkjSyPTrzMTNPpoVmLVU2Ygy47Y+BRwNExH7ATYCrA1cC9gK2ApcAZwPfzsxLxz1eSZIkSZIkqQ17dPBxocnMS4CvTHsckiRJkiRJ0jg47VqSJEmSJElSKww+SpIkSZIkSWqFwUdJkiRJkiRJrTD4KEmSJEmSJKkVBh8lSZIkSZIktcLgoyRJkiRJkqRWGHyUJEmSJEmS1AqDj5IkSZIkSZJaYfBRkiRJkiRJUisMPkot+945F/GpH/yBizdvm/ZQJEmSJEmSJmrFtAcgLWV/uORyPnL6uQBcsnkbjzjuiOkOSJIkSZIkaYLMfJRadO4ll19x/6wLN7Erc4qjkSRJkiRJmiyDj1KLLtuy/Yr7W3fs4pLN2/tsLUmSJEmStLQYfJRatGHLjo7lP156eY8tJUmSJEmSlh6Dj1KLNlzemen4x0u3TGkkkiRJkiRJk2fwUWrRZbMyHw0+SpIkSZKkPYfBR6lFG7Z0Zj6e57RrSZIkSZK0BzH4KLVk566cVfPx4s3buXzbzimNSJIkSZIkabIMPkot+fOmrWSX9edd5tRrSZIkSZK0ZzD4KLXkgsu2dl1vx2tJkiRJkrSnMPgoteT8HhmO59l0RpIkSZIk7SEMPkotOb9n5qPBR0mSJEmStGcw+Ci1pJ75eNDeqzvW79zVrRqkJEmSJEnS0mLwUWrJBRt2Zz4eefDeLItyf8eu5MKN3bMiJUmSJEmSlhKDj1JLLqhlPh6wdiVXqmU/OvVakiRJkiTtCQw+Si05f8PuAOP6vVZy2L5rrlg+z47XkiRJkiRpD2DwUWpJveHMPmtWcti+e12xbOajJEmSJEnaExh8lFqwY+cu/ryxHnxc0ZH5aPBRkiRJkiTtCVZMewDSUvTnTduoN7Rev2Ylq1bsjvVv3LqDDVu2s37NyimMTpIkSZIkaTLMfJRacH6t2czqFctYtWIZ69esZO/Vu+P955n9KEmSJEmSljiDj1ILmvUeZzj1WpIkSZIk7UkMPkotqGc+rt9rd7ZjZ/DRjteSJEmSJGlpM/goteCCDd0zHw+147UkSZIkSdqDGHyUWnBBLfNxnzXdMx8v3LiV7Tt3TXRckiRJkiRJk2TwUWpBx7TrWubjlfZezYplAcCuhAtqtSElSZIkSZKWGoOPUgvqDWfW1zIfly8LDtnHuo+SJEmSJGnPYPBRasEFG+rTrld2PHZovenMZdZ9lCRJkiRJS5fBR2nMtu/cxZ83bbtieZ+9OoOPHR2vLzH4KEmSJEmSli6Dj9KYXbhxK5m7l+vTrgEOq3W8Pu+yy8n6xpIkSZIkSUuIwUdpzOr1HvdauZyVyzufZofWaj5u2b6Ly7bsmNjYJEmSJEmSJsngozRmnZ2uV8x6fK9Vy1lVC0hu3mbwUZIkSZIkLU0GH6Uxu+Cy3s1mZuy1avkV9y/ftrP1MUmSJEmSJE2DwUdpzC7YsHvadbfMRyjTsWdsNvgoSZIkSZKWKIOP0pjVp103O13PqGc+btlu8FGSJEmSJC1NBh+lMas3nDHzUZIkSZIk7ckMPkpj1tlwZoCaj2Y+SpIkSZKkJcrgozRm9ZqP+/TIfFy70oYzkiRJkiRp6TP4KI3Rth27uGjTtiuWB+p2beajJEmSJElaogw+SmP0p41bO5Z71nw0+ChJkiRJkvYABh+lMarXe9x/7UpWLO/+FNvLadeSJEmSJGkPYPBRGqMLasHHg9ev6bmdmY+SJEmSJGlPYPBRGqPzL9s97frgfVb33K6e+bh5245WxyRJkiRJkjQtBh+lMbpgw+7Mx0P26Z35uHbV7lqQW7fvYldmq+OSJEmSJEmahokEHyNi30mcR5q2eubjIQNmPiawxanXkiRJkiRpCZpU5uMfIuLdEXH7CZ1Pmop6w5l+mY+rVy4jass2nZEkSZIkSUvRpIKPewEPA74SEWdGxNMj4qAJnVuamAvqNR/7NJxZFsGalTadkSRJkiRJS9ukaz4GcG3gn4HfR8SHIuIeEx6D1JrzazUf+zWcgUbHazMfJUmSJEnSEjSp4OMrgT821q0E7g98JiLOjogXRMTVJjQeaey27tjJJZu3X7Hcb9o1dNZ9NPNRkiRJkiQtRRMJPmbmc4CrAycCnwRmIi1R3a4OvBA4KyL+JyJOjIjl3Y4lLVT1KdcAB+09eObjZjMfJUmSJEnSEjSxadeZuSszP5mZJ1KCjc8FftXYbDlwd+AjlGnZr4iIa09qjNJ8XFCbcn3gulWsWtH/6VXPfLTbtSRJkiRJWoomXfMRgMw8LzNfkZnXAe4IvB+YSRubyYY8BHgG8POI+EpEPDQi+qeSSVN0fr3ZzBxTrsHMR0mSJEmStPRNJfhYl5mnZubDgcOAJwNnNDYJ4PbAe4A/RMS/R8Qxkx2lNLfzL6s1m1k/d5x8rTUfJUmSJEnSEjf14OOMzLw0M9+QmccCNwPeClxWPTyTDbk/8CTg+xHxrYh4TESsm86IpU4XbNid+XjIHJ2uwW7XkiRJkiRp6Vswwce6zDw9Mx9PyYY8CTgfyOo2E4i8GfA24NyIeJ2dsjVtf6oFHw9eP8C0azMfJUmSJEnSErcgg48AEXEo8FTgecDBtYdyZpPq6z7AE4FfRMQ/RcTKiQ1Sqtm4ZccV9/fZa8Wc25v5KEmSJEmSlrq5IyQTFBHLgHsDjwXuSel+fcXD1dc/Au8DbgjctbZ+NfBs4OYRcc/M3DWRQUuVTdt2Bx/Xrhog+GjmoyRJkiRJWuIWRPAxIq4FPBp4JHDozOraJruAL1DqQH4qM3dW+10NeBzweGC/ap+7AE8A3jCJsUsz6h2r161e3mfLorPb9Y4+W0qSJEmSJC1OU5t2HRGrI+LhEfEV4EzgmZQajzM1HaFkOb4MODIz75mZH58JPAJk5u8y87nAtYEv1Q7/iIl8E1LNpq27A4jrhsx83L4z2bHLZF1JkiRJkrS0TDzzMSJuQplW/RBg35nVtU12AV+kkeXYT2b+OSJOAs6hTNU+apxjlgZRn3a9bvXcT63m1OzLt+1k/ZoFW4ZVkiRJkiRpaBMJPkbEvsDDKEHHG82sZnf3aihZju8E3p6Z5wx7jsz8Q0ScA1wTWDfvQUtD2rx1d5x87aq5p12vXB4sj2Bnlh5KJfhovyRJkiRJkrR0TCrz8Y+UhjCwO+hI9fULwNuATw6S5TiHDfPcXxrZxtq0670HyHyMCPZatfyK/Ww6I0mSJEmSlppJBR/X0JnleB7zyHLs44+UxjPSRO3YuYutO3bXbFw7QPARSt3HK4KP2ww+SpIkSZKkpWWSNR/HneU4+wSZ9xr3MaVBbG5kLa4bYNo1dHa8NvNRkiRJkiQtNZMKPr6c8Wc5SgtGvdM1zG4m00u94/VmMx8lSZIkSdISM5HgY2Y+bxLnkaZlU63ZzKrly1i1YrCu1WY+SpIkSZKkpWxS3a5vX929PDO/M4/jHAvsDZCZXx3H2KRx2Lxtd+bj2tWDTbmGRvDRzEdJkiRJkrTETGra9SmUmo+/Aq47j+O8AzimOtYk61VKfdU7Xa8bcMo1dE67NvNRkiRJkiQtNZMM4AW7u13P9zjSgrK5Nu163RCZj2vNfJQkSZIkSUvYYIXpxiMneC5pojbVp12b+ShJkiRJkgRMNvg4DjORmh19t5ImbNOImY/WfJQkSZIkSUvZYgs+HlZ93TjVUUgN9YYzo9Z83GzmoyRJkiRJWmIWTfAxIu4IHEiZvn3OlIcjdejMfBwi+FjLfNyybSeZVieQJEmSJElLx1gbzkTEMcCN+2yyPiL+aohDLgP2BW4APLi2/rThRye1p7Pm4xDTrmuZjzsz2bZzF6tXDL6/JEmSJEnSQjbubtf3A17Q47EADgbeOeKxZ7pcJ/CfIx5DasWmrbuDj3uPmPkIpe6jwUdJkiRJkrRUtDHtOubeZOjj1QOPL8zM7475HNK8bK41ixmm2/WKZctYtXz309CO15IkSZIkaSkZd+bjjF4ByGEDkzsozWXOpky1fmdmfmce45JasbGW+ThMt2so2Y/bLt8F2PFakiRJkiQtLWMNPmbmi4EXN9dHxC5K1uKvM/M64zyntBB0dLseYto1lLqPl16+HTDzUZIkSZIkLS2T7HY97unY0oJR73Y9TMMZ6Kz7aOajJEmSJElaStqadt00kw150YTOJ01UveHMuiFqPkJnx2szHyVJkiRJ0lIykeBjNR1bWrLqDWeGnnZdy3zcbOajJEmSJElaQiY57VpasjZtG73hzFozHyVJkiRJ0hJl8FEag/q067XDTru25qMkSZIkSVqiDD5K87Rtxy6278wrlveex7RrMx8lSZIkSdJSMraajxFRj5pkZq7o8dg4dBxfmqbNtSnXAGuHnHbd0XDGzEdJkiRJkrSEjDOAF0BWX4d5TFrUNm5tBB9XDhl8NPNRkiRJkiQtUeOedt0vuGjgUUtSvUP1mpXLWLF8uKdVPfOxmUUpSZIkSZK0mI0z8/FRIz4mLWr1ZjPrhmw2A53Bx63bd7Erk2VhrF6SJEmSJC1+Yws+Zua7RnlMWuw2bd2d+ThsvUfo7I6dlABkfSq2JEmSJEnSYmW3a2meNm2bX+bj6pXLOmoSOPVakiRJkiQtFQYfpXmqBwvXrR4++LgsgjUrbTojSZIkSZKWHoOP0jxtrE+7HnG6dEfH620GHyVJkiRJ0tIwzoYzYxURa4FrAyuBczLzT1MektTV5lrDmb1HyHyEzqYzZj5KkiRJkqSlYiKZjxGxd0Rcs7pdZY5trxwRHwQuBk4HvgWcFxFfi4hbTGK80jA2batnPo4YfKxlPm4281GSJEmSJC0Rk5p2/a/AL6vbP/baKCIOBU4DHkjJeIza7TbA1yPixLYHKw1j09Z6zccRp13XMh+3mPkoSZIkSZKWiEkFH+8LVzT0fWOf7d4EXLW6n43HkjJN/L0RcfXxDk8a3XwbzoCZj5IkSZIkaWlqPfgYEUcAh1KChz/LzF/22O5o4ER2Bx0vBP4euBfwHGBj9dhewIvaHLM0jE21hjPrRm04Y81HSZIkSZK0BE2i4cz1a/e/2We7R1RfA7gcOC4zz6rWfS4iTgO+XC0/KCKekJlbxjtUaXj1adej1nxca7drSZIkSZK0BE1i2nV9ivTP+mx3z+prAh+sBR7LysxTgFOqxbXAsWManzQvm7bZ7VqSJEmSJKmbSQQf96ndv7jbBhFxJeAGtVX/3eNYp9TuHzW/YUnjUa/RuHbUhjNmPkqSJEmSpCVoEtOuV9buN5vIzLgNuxvSbAdO7bHd72v395/nuBa0iFgG3Bo4EjgMuBQ4F/haZnYN4mo6Nta7XY847drMR0mSJEmStBRNIvi4oXb/wB7b3KH6msD3MvPyHtvVg5er5jswuCLIdxRws9rtRpTGNjPuWE37bl1ErACeCTwBuHKXTbZFxKeAp2fm2ZMYk/rbXG84M4Zu12Y+SpIkSZKkpWISwcdza/d71Wk8oXb/632OdUDt/saRR1SJiI8AdwfWzfdY4xARhwCfpgRAe1kFPAC4a0T8VWZ+YiKDU0/1mo9rx9DtetvOXezYtYsVyyZRFUGSJEmSJKk9kwg+nl59DeCEiDgwM/8882BE3IUytXjG//Y51nVq9/84hrHdlIUTeNwL+ASdgcdzgfcCv6Zkjd4TuH312D7AByLiTpnZr4u4WpSZHd2uR818bHbJvnzbTtavMfgoSZIkSZIWt9ajG5n5G+D7lCnT64BPRcQNImJ1RNwReCe7p1NfSP/g4y1r93855qFuBb4DvIUS8Ju0l9D5/X0YODIzn5WZb8/MV2bmHYCHUepiAqwBPhgRayY8VlW27tjFrloxgHUjNpxZuTxYHnHFslOvJUmSJEnSUjCp1KqXs7uhzC2BHwCbgS8BV6keS+C1mdk16hIRhwPHVItbgB+PYVzvBv6GkgG5PjNvkZmPp38AdOwi4qrAk2qrfgg8NDO3NrfNzPcDL6ituhrwxHZHqF7qWY8wesOZiOis+2jTGUmSJEmStARMJPiYmR+hZBTOBCCjdpvJG/s28Oo+h3nYzOGAb2fmjj7bDjquF1RZhadn5va592jN4ylZjDOeMcd4/pXOWppPbWNQmtumrZ1BwnrtxmHZ8VqSJEmSJC01Eysql5lPoATZzmk8tAV4M3CXzNzWbd+IWMXuzMAA/qetcU7J/Wr3zwG+0G/jKvD6ztqqq0ZEvyY1akm92cy6VctZtiz6bN2fHa8lSZIkSdJSM4mGM1fIzLcCb42IawCHUqZe/6xX0LFmf+DZteXPtTTEiat+FkfVVn0pM7PX9jVfBJ5XW74P8N1xjk1z21zvdD1is5kZZj5KkiRJkqSlZqLBxxlVE5rfDLH9+cC72hvRVN2osXzagPt9G9jB7t/hMX22VUs21qZdr1s1+pRr6Mx83GzmoyRJkiRJWgImNu1aPR3VWP7VIDtl5hbgD7VV1x/biDSwzbWGM+vmm/lowxlJkiRJkrTEGHycvms2ln87xL71bZvH0QRs2lbPfBzjtGszHyVJkiRJ0hJg8HH69mksXzTEvhfX7q+MiNVjGI+GsGlrvebj/KZdr7XhjCRJkiRJWmKmUvMxItZSah0eBewHrKN0sR5YZr5k/CObir0by1uG2PfyLsfaOujOEfGTHg8dOcQY9mgd3a5tOCNJkiRJktRhosHHiLgBpUPzfYH5ZuktleDjmsbyXJ2/65qBxr3mORYNaXNLDWfMfJQkSZIkSUvBxIKPEfF44LXVOWeyHJMhMx5r+y0VzUzHVV3W9dIM4DYzIfvKzKO7ra8yIm1gM4CN9WnXY6z5uNnMR0mSJEmStARMJPgYEScCb6wW64HDpNQ43DiJcSxQze99DYMHH5uZjnvyz3EqNtemXe89xm7XW7btJDOJGCU2L0mSJEmStDC0HnyMEj35t2pxJtPxv4C3At/OzGFqHC5FlzWW9wcuGXDf/Wr3t2fmwPUeNR71btfzbThTz3zcmcn2ncmqFQYfJUmSJEnS4jWJzMebA0ewO+PxMZn5zgmcd7H4TWP56l3W9XJ47f5Z4xmOhlHvdr1uvtOuGzUjN2/bwaoVq+Z1TEmSJEmSpGlaNoFz3Lh2/38NPM7y08bytQbZKSLWAFfucxxNQEfDmXlOu16xbBmrlu9+StrxWpIkSZIkLXaTCD4eULv/2Qmcb7H5QWP5uAH3uwWdmas/Gs9wNIxN2+qZj/Obdg12vJYkSZIkSUvLJIKPf67dv3gC51tUMvM3wM9rq+4Sg3UZuWtj+dPjG5UGVZ92vXaemY/QWffRzEdJkiRJkrTYTSL4eHbt/kETON9i9LHa/cOBu/XbOCJWAI+qrToX+G4L49Ic6g1n9p5nwxkw81GSJEmSJC0tkwg+nsLu7Mc7TuB8UxcRR0RE1m6nzLHLm4F6p+pXRcTKPts/HbhKbfm1mZm9NlZ7NtczH+fZcAbMfJQkSZIkSUtL68HHzNwOvBEI4G4RceO2z7nYZObvKD+jGccA74uI1c1tI+IhwItrq84F3tDuCNXNrl3Zkfk4327X0Jn5uNnMR0mSJEmStMjNP1oymJdSahQeB3wkIu6Ymb+d0Ll7ioj7A6/q8tD6xvL7IuLyLts9IzM/OqbhPB+4PXCzavlBwK0j4j3AWcD+wL2AO9T22Qr8ZWZuGdMYNIRmZuK6MUy7XmvmoyRJkiRJWkImEnzMzJ0RcU/gv4B7Aj+IiFcA78rM8ycxhh72AY4cYLsr99l/LDJzc0ScAHwGOLZafRXgWT122QA8MjO/Pq4xaDj1TtcA68bRcMaaj5IkSZIkaQmZSPAxIr5c3V0G7AL2BV4BvCIizgHOA4bJ3svMvPN4Rzl9mXleRNyKEnB8AnBol822UQKU/1B1ytaUbNq6Ozi4LGD1ivlXMegIPpr5KEmSJEmSFrlJTbs+Hqg3RElKDUiAIygdngcVjWONLDNPBk4ex7Eaxz2b3d/fsPtuB14aES8Hbg1cCziEkun4e+BrmXnRmIaqedhUazazbvUKIkb6lXfoaDhj5qMkSZIkSVrkJhV8hP7BuPlHbZaYzNwJfK26aQHaPOZmM2C3a0mSJEmStLRMKvj4rgmdR5qYeubj2jE0mwFrPkqSJEmSpKVlUg1nHjWJ80iTVG8400bm45btO9mVybIxTOeWJEmSJEmahvl3yJD2UJtrDWfWjSnzcW0tiJnA1u27xnJcSZIkSZKkaTD4KI1o49bxZz6uXrmsowDq5lp2pSRJkiRJ0mJj8FEaUT0wuHb1eIKPyyJYY9MZSZIkSZK0RBh8lEa0qdYQZu8xTbuGRtMZg4+SJEmSJGkRm1S361ki4l7AXYFbAlcF9gfWAr/KzOs2tl0J3KRa3JmZ35vkWKVuNte7XY9p2jV0Np2x47UkSZIkSVrMJh58jIiHAS8FDq+v7nEfgMzcHhHvBq5dHeMmmfnDVgcqzWFjveHMKjMfJUmSJEmSmiY27ToiVkTEfwHvpgQeo3aD0ty3nzfWtn14K4OUhlCv+bhuTDUfwcxHSZIkSZK0dEyy5uP7gAezO+C4Hfgs8GLgCdW6fgHID9Uev2d7w5QGU6/5OK6GM9DIfDT4KEmSJEmSFrGJTLuOiL8EHkQJHgbwEeDJmfnH2jZv6neMzDwvIr4PHAtcPyIOzMw/tzhsqa9NtZqPY512Xct83Oy0a0mSJEmStIhNKvPxxbX7b8nMB9UDj0M4vXb/BvMckzQvHcHHMWY+rjXzUZIkSZIkLRGtBx8j4vqURjEJ/A74+3kc7he1+0fOZ1zSfG3eVm8401LNRzMfJUmSJEnSIjaJzMdja/f/OzO3zuNYl9Tu7z+P40jzVs98XLu6pW7XZj5KkiRJkqRFbBLBx0Nq98+c57HqkZhV8zyWNC+bat2u926r27WZj5IkSZIkaRGbRPCx3sF6vhGaA2r3L57nsaSR7dyVbNm+64rlteNsOGPmoyRJkiRJWiImEXy8oHZ/vnUab1y7f/48jyWNrJ71CO3VfNy2cxc7du3qs7UkSZIkSdLCNYng409r9+8z6kEiYjVw99qq00YekTRPm7d2ZiSOt9t157HMfpQkSZIkSYtV68HHzDydkqUYwHUj4pEjHuqJwJUo07h/mpl/HNMQpaHVMx9XLg9WrRjfU2nl8mB5xBXL1n2UJEmSJEmL1SQyHwHeUX0N4E0Rcddhdq62f3lt1evGNTBpFB2drsc45RogIlhj3UdJkiRJkrQETCr4+EpK7ccE9gL+JyLeFBHX6bdTRBwQEa8APk3pbp3AL4D/bHm8Ul+batOux9npesZaO15LkiRJkqQlYPxRky4yc2NEnAh8iRJ8XA78LfC3EXEW8JPa5gdExJuBo4FbVdvOzEG9DDgxM43GaKo2b6tnPo6v0/UMO15LkiRJkqSlYFKZj2TmacAJdHa/DkoH7BMoWY0A+wN/A9yGzuDoH4F7ZuaZ7Y9W6m9jfdp1C5mPe5n5KEmSJEmSloCJBR8BMvMrwDHAycD22kPR2DRq63YC7wGOrQKY0tRt3lafdt1u5uNmMx8lSZIkSdIiNZFp13WZ+Sfg0RHxbOCBwO2AGwEHAvsBm4ELgZ8DXwE+nJnnTHqcUj9tNpyBxrRrMx8lSZIkSdIiNfHg44zMPB94Y3WTFpV6w5l1bdR8rE273mLmoyRJkiRJWqQmOu1aWirqDWfWtdHt2mnXkiRJkiRpCTD4KI1gU8vBRxvOSJIkSZKkpWCi064jIoBjq9uVgAOAfYBLgYsotR6/m5lnTHJc0rDq067XtjHtul7z0cxHSZIkSZK0SE0k+BgRtwCeAdyZEmyca/tLgC8C/5KZ32t3dNLw6g1n9jbzUZIkSZIkqatWp11HxCER8Wngm8D9gH2BqG5dd6lu+wMPAr4dER+PiIPaHKc0rHodxla6Xa/szHzMzLGfQ5IkSZIkqW2tBR8j4ijgNOCe7A421iMo0eVGY7sATgC+GRHXaWus0rA2bq3XfGx32vXOTLbvNPgoSZIkSZIWn1amXUfE1YCvAgdSAolJCSRuomRBfgM4G7gY2AisB/YDrgkcB9wKWMfuIOQ1gVMj4maZeW4bY5aG0dHtuo3Mx0Ydyc3bdrBqxaqxn0eSJEmSJKlNbdV8/E92Bx4D+BPwz8DbM3PDXDtHxD7A3wL/SGlMk8AhwH9QMimlqepoONNC5uOKZctYtXwZ23buAkrdx/3GfhZJkiRJkqR2jX3adUTcndJYZiZr8TvAsZn5b4MEHgEy87LM/BdKV+zvsntK9t0i4s7jHrM0rE0tZz5Co+O1TWckSZIkSdIi1EbNx6dUXwP4HXD3UadKZ+bvgXtUx5kJZj51vgOU5mtzLfNxXQvdrmF20xlJkiRJkqTFZqzBx4g4ALhLtZjAYzPzkvkcMzMvAh7L7qY0d4uI/eZzTGk+tu3YdcV0aGin4Qw0Mh8NPkqSJEmSpEVo3JmPd6TUkUzgR5n5pXEcNDO/CPyoWlwB3Gkcx5VGUW82A7C2rWnXK512LUmSJEmSFrdxBx9vVbv/7jEfu368W/XcSmrZpkYW4rpVZj5KkiRJkiR1M+7g4/Vq97815mOfVrt//TEfWxrY5q27Mx9Xr1jGiuVtlE6FtbXMx81mPkqSJEmSpEVo3FGTI2r3vzfmY59eu3/4mI8tDWxjLfjYVrMZMPNRkiRJkiQtfuMOPh5cfb08M7eM88CZeTmwmdJ05uA5Npdas3lbvdN1O1OuAdZY81GSJEmSJC1y4w4+7k1pNnPJmI87Y+a461s6vjSnTfXMx5aazQCsNfNRkiRJkiQtcuMOPq6uvm4e83FnXF59XdXS8aU5bap1u17bUrMZsNu1JEmSJEla/MYdfGyn88ZsMaHzSLNs2lqfdm3NR0mSJEmSpF4mFSyUlozN2yYz7bqe+bhl+052ZbZ2LkmSJEmSpDYYfJSGtLGW+bi2xYYza2uBzQS2bt/V2rkkSZIkSZLa0Fba1vqI+Ks2jtvCMaWhbK41nNm7xWnXq1cuIyiBR7DuoyRJkiRJWnzaipwcDLyzpWNLU7WpVn9xbYvTrpdFsGbl8iuCjvXp3pIkSZIkSYtBe5GTdprCWPROU7dpa73mY3vTrqE0nZkJPpr5KEmSJEmSFps2go9tdqK2y7WmrqPhTIvTrqGz6YwdryVJkiRJ0mIz7sjJo8Z8PGnB2VRrOLOuxYYzUDIfZ5j5KEmSJEmSFpuxBh8z813jPJ60EG2qZT62WfMRzHyUJEmSJEmL27JpD0BabDZNqNs1NDIfDT5KkiRJkqRFxuCjNKTNHd2uW552vdJp15IkSZIkafEy+CgNaePWyTWcqQc3N5v5KEmSJEmSFhmDj9IQMrMjCDjJbtf1LtuSJEmSJEmLgcFHaQhbd+xi5668Ynldy9Ou16/ZHdzcsMXgoyRJkiRJWlwMPkpDqDebAVjbcubj+jUrr7i/YcsOMrPP1pIkSZIkSQuLwUdpCM26i2tXtpv5uM9eu4OP23bu6qg3KUmSJEmStNAZfJSGsKlWd3HtquUsWxatnm/tquXUT3HBhq2tnk+SJEmSJGmcDD5KQ6hPu167qt0p1wDLIjqmXp9/2ZbWzylJkiRJkjQuBh+lIWzaunva9d6r251yPaPedOaCy8x8lCRJkiRJi4fBR2kIm7dNNvMROpvOXLDBzEdJkiRJkrR4GHyUhrCxlvm4bkKZj/vUMh/PN/NRkiRJkiQtIgYfpSFMJ/OxNu3ahjOSJEmSJGkRMfgoDaGz5uNkgo/72HBGkiRJkiQtUgYfpSF0drueRsMZg4+SJEmSJGnxMPgoDWFTbdr1ugllPnY2nNlKZk7kvJIkSZIkSfNl8FEawuZpNJzZa3fwcfO2nWysZV9KkiRJkiQtZAYfpSFsnELDmbWrlrMsdi/bdEaSJEmSJC0WBh+lIWyuZR2um1DNx2URHVOvbTojSZIkSZIWC4OP0hA2batPu55M5iM0m86Y+ShJkiRJkhYHg4/SEOrdricbfKw3nTHzUZIkSZIkLQ4GH6UhbK5lPq6d0LRrgH1qmY/nm/koSZIkSZIWCYOP0hDqmY97T2natTUfJUmSJEnSYmHwURpCPfg4qW7XAPt0TLs281GSJEmSJC0OBh+lAe3alWzeXm84M7lp1x01H818lCRJkiRJi4TBR2lAW3bsJHP38tS6XW/YStYHIkmSJEmStEAZfJQGtLE25Rpg3SSnXe+1O/Nx87ads8YiSZIkSZK0EBl8lAa0eevuKdfLAtasnNzTZ+2q5SyL3ct2vJYkSZIkSYuBwUdpQJu27c42XLdqBRHRZ+vxWhbRWfdxg3UfJUmSJEnSwmfwURrQplrm49oJNpuZ0VH30cxHSZIkSZK0CBh8lAbUzHycNDMfJUmSJEnSYmPwURpQvebjJDtdz9inlvlozUdJkiRJkrQYGHyUBrSp1mF67arpTrs+/zIzHyVJkiRJ0sJn8FEaUMe066lkPtanXZv5KEmSJEmSFj6Dj9KANm+b7rTrjpqPZj5KkiRJkqRFwOCjNKCNW+sNZ6bc7XrDVjJz4mOQJEmSJEkahsFHaUCbO2o+TmHa9V67Mx83b9vZEQyVJEmSJElaiAw+SgPaVJt2vffqyWc+rl21nBXL4oplO15LkiRJkqSFzuCjNKCObtdTqPm4LIKD1q++YvmCDdZ9lCRJkiRJC5vBR2lA9czHadR8BDh4nzVX3L/AzEdJkiRJkrTAGXyUBlSv+TiNbtcAB9cyH8+347UkSZIkSVrgDD5KA9o45YYzAIfsU592beajJEmSJEla2Aw+SgPaXJ92PYWGMwAHr9897drMR0mSJEmStNAZfJQGtHnb9Kddm/koSZIkSZIWE4OP0oDq067XTWnadWfDGTMfJUmSJEnSwmbwURrAzl3Jlu27rlheO61u1x0NZ7aSmVMZhyRJkiRJ0iAMPkoDqE+5Bth7atOud2c+Xr59Z0c2piRJkiRJ0kJj8FEawKatOzuW106p4cwBa1exYllcsXz+ZdZ9lCRJkiRJC5fBR2kAm2qZjyuWBauWT+eps2xZcND6etMZ6z5KkiRJkqSFy+CjNIDNtczHdatXEBF9tm5XZ9MZMx8lSZIkSdLCZfBRGkBnp+vpTLme0dl0xsxHSZIkSZK0cBl8lAZQbzizdkrNZmYcsk992rWZj5IkSZIkaeEy+CgNYNO2zmnX03TI+t3Trs18lCRJkiRJC5nBR2kAmxbStGszHyVJkiRJ0iJh8FEaQD34uHbVdDMfOxvOmPkoSZIkSZIWLoOP0gA216Zd7716ITWc2UpmTnE0kiRJkiRJvU03hWuBioijgWOAKwM7gXOB72bmb6Y6ME1NR+bjtGs+1jIfL9++k0s2b2f/daumOCJJkiRJkqTuDD7WRMQDgedTAo/dHv8G8NzMPKWFcx8PfGXE3W+emd8d32jUtGnbwqn5eOC6Vey710ouvXw7AGeev4FbXfPAqY5JkiRJkiSpG6ddAxGxPCLeCXyIHoHHyq2B/42Il05mZFooNm9dON2uI4LrHrr+iuUzz9swxdFIkiRJkiT1ZuZj8RrgpNryZuB9wBnAKuCWwAOAlZSA7fMi4qLMfE2LYzoH2DHnVoVdR1q2saPb9fSfNtc7dD3f/s1FAPzc4KMkSZIkSVqgph9FmbKIuDfwd7VVPwXukZm/a2x3I+B/KHUgAf41Ir6UmT9qaWjHZ+bZLR1bQ6o3nFk75YYzANc5ZHfm4y/ON/goSZIkSZIWpj162nVELANeXlu1GTihGXgEyMwfAA8CdlWrmvtqCavXfNx7ytOuoWQ+zvjFeRvseC1JkiRJkhakPTr4CNyZzhqPr8vMs3ptnJnfoNSFnHGfiLhWW4PTwtHR7XoBTLu+Ti34uGHrDs695PIpjkaSJEmSJKm7PT34eL/G8n8MsM/bG8snjmcoWsg21RvOTLnbNcA+a1Zylf32umLZpjOSJEmSJGkh2tODj/eu3f91Zv56gH2+RmeDl/uMd0haiDbXpl1Pu9v1jHrHa5vOSJIkSZKkhWiPDT5GxH7A1WurThtkv8zcBnyvtuqYXttq6ejIfFwADWegM/ho5qMkSf+/vfuOj+sq8z/+fWbUJcuSe4m705wGCYnTkw0JLQtLC7AhhLIsgVAWtoT9hYXfssvSFn6EXTphwyYQ6kJCykIIgfQKpOEUd8d27LhILurl+f1xp9y51sgz1ozmSvN5v17z8j1nzr33kXQsXT06BQAAAHFUtclHSUdHymuKODc8QrLdzOaUIJ6oT5vZo2bWYWb9ZvaCmT1uZt80s9ebWTwyYFWgf3BY/UPDmXIc1nyUIpvOsOM1AAAAAACIoWpOPi6NlDcVcW60bfRapfCXkk6Q1CapVtJMScdJeo+k/5H0rJn9RRnui4ie/qGcchynXa/dsV8DoQQpAAAAAABAHFRz8rE1Ut5dxLkdkfKUEVuNXYekjZJ2SopmlpZKusHM/q1M90bK/tB6j5LUFIMNZyRp6YwW1SRMkjQw5Fq3o6vCEQEAAAAAAOSq5uRjS6TcO2KrkfUc5FqHapek/5T0CknT3X2auy9295mSpkl6vaR7I+dcaWZ/cyg3M7M/jfSStGwsH8Rk092XTT7W1SRUm4zHf5u6moSWzmzOlJ/etreC0QAAAAAAABwoHlmUymiIlPuLOLcvUm4cYyxSsInNYe7+IXf/lbvnjMR09z3u/nNJZ0n6eOTcz5nZghLEgBF0haZdt8RkynXakXOyA3jZdAYAAAAAAMRNNScfoyMd64o4tz5Sjo6ELJq773P3g46+9MCnJH0jEs8Vh3DPY0Z6KXdDnarXFRr5GJcp12lHseM1AAAAAACIsWpOPu6PlKMjIUcTHekYvdZ4+CflJj1fXYEYqkI4+dgck52u046cnU0+Pk3yEQAAAAAAxEw1Jx+jC+S1F3FuW6Q87lkfd98l6c5Q1SIzmzvecVSD7tC06+b6eI18DO94vaWzR/t6ByoYDQAAAAAAQK5qTj6uj5QXFnHuokh53RhjOVTPRMqzKhLFJNcV2u26OWZrPh7W3pizDuWz2ysxCBcAAAAAAGBk1Zx8XBUpLy/i3PBu0B3uvq0E8RyK6FqTTRWJYpKL85qPZqYjZmc3W2fdRwAAAAAAECdVm3x0905Jm0JVpxVynpnVSTopVPVECcMq1uxIeWdFopjkuvrC067jNfJRyp16/cy26GoCAAAAAAAAlVO1yceUW0PHy8xsaQHnnKXczWluLm1IRTkrdDwgaUulApnMuvvju+GMxKYzAAAAAAAgvqo9+fjzSPmvCzgn2uaG0oRSHDN7pXKnit/r7t2ViGWy2x8a+dgUsw1nJOnIOa2Z42e275O7VzAaAAAAAACArGpPPt4u6clQ+YNmtiRfYzM7TdJFoapb3H11nraLzcxDr9+Nct3GYoJO7Wr9zUj1d4u5BgoX95GPR4WmXXd2D+iFfX0VjAYAAAAAACCrqpOP7j4s6cpQVbOkm8xsQbStmR0v6SfKfs6GJX2sRKG82czuNLPXpNaUzMvMzpf0oKRwjI9Juq5EsSAi7ms+tjfXadaU+kyZTWcAAAAAAEBcxC+TMs7c/SYz+5qky1NVx0h6ysy+L+lRSbWSTpX0xtRx2kfd/bEShnJ26tVpZvdKelzS85L2KdjFeomkCySdEDlvm6TXphKpKIPwbtfNMdvtOu3IOVMyIx6f2bZPZx8xs8IRAQAAAAAAkHxM+5CkKZLelio3S3pPnrYu6bPu/oUyxdIm6cLU62AekHSJu28oUyxQ7rTrphiOfJSCqdd3rw42O2fTGQAAAAAAEBdVPe06zd2H3P1SSW9W7hqQUQ9IOt/drxylzaF4RNI1kp5SkNwcjUu6T9Ilks5097UljgURXf3ZadctMdxwRpKOCO14/cz2vRWMBAAAAAAAICuew7gqxN1/LOnHZnaspOMlzZM0JGmrpIfdfV0R19ogyQps+6Skd0mSmbVJerGkhZJmSGqU1CepU9IGSQ+5+55C48DYhaddN8VwwxlJOiq04/Xq7fs1NOxKJgrqfgAAAAAAAGUTz0xKhaWSgaONgCznvTsl/bYS98bIctd8jOd/mcNntyhh0rBLfYPD2rCrS8tmtlQ6LAAAAAAAUOWYdg2Mwt3V3R/e7Tqe064bapNaPL05U35yC4NjAQAAAABA5ZF8BEbRNzisweHsMpzNMd1wRpJetLAtc/zbp1+oXCAAAAAAAAApJB+BUYRHPUpSU108Rz5K0gVHz84c3/H0CxoYGq5gNAAAAAAAACQfgVGF13uU4rvhjCSdfcRM1dUE/6X39g7q4fW7KxwRAAAAAACodiQfgVF09WeTj421yVjvIN1cX6Mzlk3PlG9btb2C0QAAAAAAAJB8BEbV1Rf/zWbCLlgxJ3P861Xb5e6jtAYAAAAAACgvko/AKPb2DmSOW2K82Uza+UfPyhxv6ezRU8/vq2A0AAAAAACg2pF8BEbR2d2fOW5rqqtgJIWZ1dqgFy1oy5R/zdRrAAAAAABQQSQfgVF0dGVHPrY31VYwksJdsCK76/Wvn9pWwUgAAAAAAEC1I/kIjCI88rG9Of4jHyXpZaHk45Nb9mprZ08FowEAAAAAANWM5CMwio7u8MjHiZF8XD6rRUtmNGfKtz/F1GsAAAAAAFAZJB+BUXSERz5OkGnXZpY79Zp1HwEAAAAAQIWQfARG0Rka+TgRNpxJCycfH1i3K2fXbgAAAAAAgPFC8hEYRe7Ix4mTfDxxYbump9aoHBhy/e6ZHRWOCAAAAAAAVCOSj8AoOrom3rRrSUomTOcdNStTZuo1AAAAAACohJpKBwDEWUfMpl1f/+CmgtvW1yQzx7f9aZuuvX+DahK5f2+4eOXCksUGAAAAAAAQxchHII/egSH1DAxlyu3NE2fkoxTsel2bNElS3+Cw1u/sqnBEAAAAAACg2pB8BPIIbzYjTaw1HyWpriah5TNbMuUnNu+pYDQAAAAAAKAakXwE8ghvNtNQm1BDbXKU1vF0zPypmeM/PtepvT3seg0AAAAAAMYPyUcgj4m603XY8fOnqrUhWNp1aNh192p2vQYAAAAAAOOH5COQR2fMNps5FDXJhM46fGam/NCG3drfN1jBiAAAAAAAQDUh+QjkER75OG2CbTYTdvLiaWquD0Y/Dgy57l2zs8IRAQAAAACAakHyEchjMox8lIKNZ85aPiNTvn/dLnX3M/oRAAAAAACUH8lHII+OrvCajxN35KMkrVwyTY2pDXP6B4d139pdFY4IAAAAAABUA5KPQB4doZGPE3XDmbT62qTOWD49U75v7U71DgxVMCIAAAAAAFANSD4CeXSG1nycyNOu005bOkP1NcF/+d6BYT2wjtGPAAAAAACgvEg+Anns7p48064lqbEuqdOWZUc/3rNmJ2s/AgAAAACAsiL5COTROYmmXaedsWyGapMmSeruH9L1D26qcEQAAAAAAGAyI/kI5NGRM+164o98lKTm+hqtXJId/fj1363N2VgHAAAAAACglEg+AiMYGnbt6Zl8Ix8l6azDs6Mfd3X1619vXlXhiAAAAAAAwGRF8hEYwd6eAblny5Mp+TiloVYXHD07U/7ZH7fojqe3VzAiAAAAAAAwWZF8BEYQnnKdMGlKQ00Foym905fP0IL2xkz5yp89qb29A6OcAQAAAAAAUDySj8AIOkKbzbQ11SmRsApGU3oJM73+xMNUlwy+BWzb26vP3Pp0haMCAAAAAACTDclHYASdoZGP7ZNks5mo2a0N+uB5yzPlHzy0Sfet2VnBiAAAAAAAwGRD8hEYQXjk42Ra7zHqvecu04q5rZnyR3/2uLr7BysYEQAAAAAAmExIPgIjCI98bJvEycfaZEKff+PxSqamlT+3u0f//qtnKhwVAAAAAACYLEg+AiPoqIJp12nHzp+q956zNFP+7n0bdC/TrwEAAAAAQAmQfARGkDPtunnyjnxM++B5h2v5rBZJkrv0Nz/8o7bv7a1wVAAAAAAAYKIj+QiMoKMrPO16co98lKSG2qS+/JYXqa4m+Jawc3+/Pnj9HzU4NFzhyAAAAAAAwERG8hEYQe6068k/8lGSjpk3VZ98zTGZ8kMbdusLtz1bwYgAAAAAAMBER/IRGEFnzm7Xk3/kY9pbTl6g1714fqb8jTvX6jdPba9gRAAAAAAAYCIj+QiMoKNKdruOMjP92+uO1eGp9R8l6W9//Jie291dwagAAAAAAMBERfIRiHD33A1nqij5KElNdTX6+iUnqrE2KUna0zOgD1z/B/UNDlU4MgAAAAAAMNGQfAQiegaG1D+Y3WilmqZdpy2fNUWfef1xmfJjm/fokzetqmBEAAAAAABgIqqpdABA3IRHPUqTe9r19Q9uGvX9UxZP00MbdmfadvUNauWS6XnbX7xyYUnjAwAAAAAAExsjH4GIjq7seo8t9TWqq6ne/yYXHj9XC9obM+WbHtuq9Tu7KhgRAAAAAACYSKo3qwLkEd7puq0Kp1yH1SYTeuvKRZrSEAySHnbp+gc3qjO0IQ8AAAAAAEA+JB+BiPBO19W22cxIWhtrdcnKRUomTJLU1T+k7z2wMWddTAAAAAAAgJGQfAQiwqP6qn3kY9qCaU167YvmZ8pb9/TqZ3/cLHevYFQAAAAAACDuSD4CEbu7stOuGfmYddKidp2+LLvZzOOb9+jOZ3dUMCIAAAAAABB3JB+BiNxp14x8DHvlsXO1dGZzpnzbqu16fHNn5QICAAAAAACxRvIRiMidds3Ix7BkwnTxyQs1rTn7efnp7zdrAztgAwAAAACAEZB8BCI6usPTrhn5GNVUX6N3nLZYjbVJSdLgsOu6BzZq576+CkcGAAAAAADihuQjEBEe+djezMjHkcyYUq9LTs3ugN0zMKTv3r9Bu/aTgAQAAAAAAFkkH4GI8MhHpl3nt2RGs9540mGZ8u6ufr372kfUOzBUwagAAAAAAECckHwEIthwpnAnHNaml6+YnSn/cVOnPvzDRzU07BWMCgAAAAAAxAXJRyBkcGhY+3oHM+V2Rj4e1NlHzNTJi6dlyr/80zZd8dPHNUwCEgAAAACAqkfyEQjp7BnIKbPm48GZmV5zwjwdMbslU/c/f9isf7rxSbmTgAQAAAAAoJqRfARCwpvN1CZNzXXJCkYzcSQTpotPWaSVS7IjIK9/cJP+5eZVJCABAAAAAKhiJB+BkOhmM2ZWwWgmlrqahL7zjpN14sK2TN01927QZ3/5NAlIAAAAAACqFMlHIKSji81mxqKlvkbffdcpOv6wqZm6b965TlfdvrqCUQEAAAAAgEoh+QiEhHe6bmOzmUPS2lCra991io6aMyVT9+XfrNYXfvUMIyABAAAAAKgyJB+BkPC0a0Y+Hrq2pjp9790rtXxWdhOar/x2jT52w5MaYhdsAAAAAACqBslHICQ88rGdkY9jMqOlXte/e6WOnJ0dAXn9g5v0wR/8QX2DQxWMDAAAAAAAjBeSj0BIZ1fuhjMYm1mtDfrRZafqpEXtmbpbn9imd333Ye3vG6xgZAAAAAAAYDyQfARCckc+Mu26FNqa6nTdX52ic4+cmam7d80uXfztB7Rrf18FIwMAAAAAAOVWU+kAgDjpzFnzkZGPxbr+wU1533vpUbPV0dWvxzbvkSQ9vnmPLvjSXbr0tEWaNaVhxHMuXrmwLHECAAAAAIDxwchHICR3t2tGPpZSMmG66CULdNqy6Zm63V39+sada7Xmhf0VjAwAAAAAAJQLyUcgJGe362ZGPpZawkx/ftxcvfyYOZm63oFhffe+9Xpw/a4KRgYAAAAAAMqB5COQ4u7qZLfrsjMznXPETF18ykLVJk2SNOzSjY9u1S2Pb9Wwe4UjBAAAAAAApULyEUjZ3zeoweFs4osNZ8rr2PlT9Z6zlqm1Ibv07L1rd+m6+zeqb2CogpEBAAAAAIBSIfkIpIQ3m5GkqY0kH8ttfnuj3nfucs2bmt1w5pnt+/TNu9blrL8JAAAAAAAmJpKPQEo42dXaUKOaJP89xsPUxlq95+xlWjG3NVO3bW+vvva7tfrDpo4KRgYAAAAAAMaK7AqQsrsrtN4jm82Mq7qahC5euVBnHz4zU9fVN6i3fOsB/eKxrRWMDAAAAAAAjAXJRyAlPO26jc1mxl3CTK84do7ecOJ8JS3YiKZ/cFgf+sEfddXtz8rZiAYAAAAAgAmH5COQ0pGz0zXrPVbKSYum6Z1nLlZjbTJTd9Xtq3XZdb/Xvt6BUc4EAAAAAABxQ/IRSOkIjXxsZ+RjRS2d0aL3nbtMS2c0Z+puW7Vdf/GVe7V6+74KRgYAAAAAAIpB8hFIeWFvb+Z4Gms+VtyMlnr9/PIzdM4R2XUg1+3s0l989V7d8vjzFYwMAAAAAAAUiuQjkLL6hf2Z46Uzm0dpifEytalW//WOk/Wh85Zn6rr7h/T+6/+gT9/6lAaHhisYHQAAAAAAOBiSj4Akd8+Zznv4rCkVjAZhyYTpb192pK6+9CWa0lCTqf/WXet00Tfv1/qdXRWMDgAAAAAAjIbkIyBpx74+7e0dzJQPn9VSwWgwkvNXzNZNHzhTR87OJob/uKlTr/ry3brugY3shg0AAAAAQAyRfASUO+V6Rkud2lnzMZYWz2jWz99/ut5w4mGZup6BIX38hif1jmse1vbQup0AAAAAAKDySD4CElOuJ5Cmuhp98U0n6BuXnKj2ptpM/Z3P7tDLr7pLN/xxC6MgAQAAAACICZKPgHJHPh4+mynXE8Erjp2rX33kbJ131KxMXWf3gD78o0d10Tfu1xOb91QwOgAAAAAAIEk1B28CTH45yUfWe4yN6x/cdNA2Lz1qlqY21OqWJ55Xf2r360c2dug1X7lHJy5q18tWzNaUhmCE5MUrF5Y1XgAAAAAAkIvkI6pedKfr5Uy7nlDMTCcvmaZls1p06xPPa9XzeyVJLun3Gzv05JY9OveImTp12fTKBgoAAAAAQBUi+Yiqt6urXx3dA5ky064npmnNdbrk1EVa88J+3fz4Vr2wr0+S1Dc4rF+t2q671+xUd/+QLj1tUWYkJAAAAAAAKC/WfETVW709O+V6WnOdZrTUVzAajNXyWS364HmH69UnzFNjbTJT390/pH//1TM683O/1VW3P6s9oYQzAAAAAAAoD5KPqHprXghPuWbU42SQTJhOWzpdf3fBETrr8BmqS2a/1e3pGdBVt6/WmZ+7Q5/536e0fW9vBSMFAAAAAGByI/mIqsdmM5NXU32NXnnsXP3Dy4/UuUfOVEt9dqWJfX2D+uad63Tm5+7Q3//kMT0bWvcTAAAAAACUBslHVL3wtGuSj5NTc32NXrZiju796Hn68PmHq7Uhm4QcGHL99Peb9bIv3aV3ffdhPbBul9y9gtECAAAAADB5sOEMql7OyMfZ7HQ9mU1tqtWHzz9C7z5rqX740CZ95571en5Pdtr1HU+/oDuefkEnHDZV7zl7mV5x7BwlE1bBiAEAAAAAmNhIPqKqdXT1a+f+vkyZkY+T2/UPbsocN9XV6PJzl+vxzZ26e/VObQut/fjY5j16//V/0LTmOp25fIZOXNiuupoDB4pfvHLhuMQNAAAAAMBERfIRVS086nFqY61mTmGn62qSTJhevLBdL1rQptUv7Nfdq3do7Y6uzPu7u/r1i8e26vantuvM5TO0csl0NdYlR7kiAAAAAAAII/mIqrY6tNP14bNaZMYU22pkZjpi9hQdMXuKtnT26O7VO/TE5j1Kr/zY3T+k21Zt153P7tCpS6frjOUzcjavAQAAAAAAI+O3Z1S1nM1mZjPlGtL8tka95eSFevmKft2zdqce2bBbA0NBGrJvcFh3PrtD967ZqZMXT9M5R87U/LbGCkcMAAAAAEB8sds1qtqa0LTr5bPYbAZZ7c11evXx8/QPLz9Kf3bkTDXUZr9dDg677l+3S+d8/re64qePad2O/aNcCQAAAACA6sXIR1S16LRrIKqlvkYXrJijsw6fqQfX79Y9a3aqq29QUpCE/PEjm/WT32/Wq46bq8vPXaZj5k2tcMQAAAAAAMQHyUdUrT09A9q+N7TTNdOuMYqG2qTOOWKmTl82XY9s7NDdz+5QZ8+AJMlduuXx53XL48/rjOXT9bZTF+v8o2epJsngcgAAAABAdSP5iKoVnnI9pb5Gc1obKhgNJoraZEKnLZ2uUxZPU2NdUl/73RqtC+2Qfe+aXbp3zS7Nm9qgt566SG8+eYFmtLCLOgAAAACgOjEsB1Vr9fbslOvls9npGsVJJkxvPOkw/foj5+jrbz1Rx85vzXl/655e/fuvntHpn7lDH7j+D/rNU9s1MDRcoWgBAAAAAKgMRj6iaq0OjXxkvUccqmTC9Mrj5uoVx87RIxs7dO39G/W/TzyvweFgh+z+oWHd/Pjzuvnx59XeVKsLj5+r175ovk5a1E7CGwAAAAAw6ZF8RNXKTT6y0zWKd/2Dmw6oO23pdB0zr1UPb9ith9bv1r7ewcx7Hd0D+t4Dm/S9BzapralWK+a2asXcVi2a3qxkwnTxyoXjGT4AAAAAAGVH8hFVa01k2jVQKq0NtXrpUbN17hGz9My2fXp0c6eefn5vZjSkJHV2D+i+tbt039pdaqxN6ui5UzStuU5nHj5DLfV8awYAAAAATA78houqtK93QFv39GbKTLtGOSQTphXzWrViXqt6B4b0p6179OhznVq3o0seatczMKQ/bOrUe7/3e9UkTCctatc5R87U2YfP1Iq5rUokmJ4NAAAAAJiYSD6iKq0N7U7cXJfU/LbGCkaDatBQm9RJi6bppEXTtLd3QE89v1dPPb9Xa3d0aSg0InJw2PXg+t16cP1uff6Xz2hGS73OWD5dpy4NXounN7FWJAAAAABgwiD5iKqUs9P1LHa6xvhqbajVyiXTtXLJdPUODOnZ7fu06vm92rirW3t6BnLa7tzfpxsf3aobH90qSZrT2qBTl07TKUum6+TF7Vo2s4WRkQAAAACA2CL5iKq0JrTZzHI2m0EFNdQmdfxhbTr+sDa9+eQFemxzp+56dofufHaHHnuuU8Oe237b3l7d8OhW3ZBKRrY11eqkhe16yeJpesnidh03f6oaapMV+EgAAAAAADgQyUdUpWdDIx8PZ7MZxMSPHn5OkjRrSoMuOmmBLjxurtbu6NL6nfu1bkeXXtjXd8A5nd0D+s3TL+g3T78gKVhn8rC2Ri2a3qRLT1uskxa1q725blw/DgAAAAAA0kg+jsDMjpF0vKR5koYkbZH0iLuvH+c4EpJOl7RM0lxJe1Kx3O3uHeMZy2Syp2dAj2zMfvrYbAZx1VRXo+PmT9Vx86dKkvb3DWr9zi6t27FfG3d1a/veXkUGRmpo2LVxd7c27u7WXat3SgqWFnjJomB05MmL27VwGutGAgAAAADGB8nHEDN7o6SPK0g8jvT+fZI+5u6/K3McNZI+KulyBQnQqH4zu0nS37v7hnLGMhn91z3rta93UJI0pb5GJy+ZVuGIgMK01OcmI3sHhrRpd7c27urShl3d2tzRrYGhaDoyWGZgzQv79cPUyMoZLfU6eXG7Tl06XSuXTtMRs6awbiQAAAAAoCxIPkoys6SkqyW94yBNT5f0GzP7tLt/vEyxzJZ0s6SXjNKsTtIbJF1gZpe6+43liGUy2tM9oP+6JzuA9V1nLlFrQ20FIwIOXUNtUkfMnqIjZgfrlg4Nu7Z29gQjH3d1afvePu3cf+BU7Z37+/S/T27T/z65TZLU3lSrkxdP08ql0/XihW1aMbeVdSMBAAAAACVB8jHwJeUmHrslfV/SowoSfSsVJPtqJSUk/ZOZ7Xb3L5UyCDNrlHSjchOPWyR9T9JaSdMlvVLS2an3WiX90MzOc/f7SxnLZPWde9ZpX19q1GNDjd515pIKRwSUTjJhWjCtSQumNenM5TP0l6cs0MZd3XpkY4ce2bBbj2zsyNlsKa2je0C3rdqu21ZtlyTVJk0r5rbqRQva9KKFbTpi9hQtndGixjoSkgAAAACA4lR98tHMLpT0wVDVKkmvcPfnIu1OkHSrstOgv2Bmt7v7EyUM518UJDrTfirpEncPD136rJldLOm7CpKhDZJ+ZGZHuHtvCWOZdDq7+/Vf927IlN995lJNbWTUIyavHzyU/TaW3lG7u29QG3d3a/3OLq3f2aWtnT0HrBs5MOR6bPMePbZ5j/77/o2Z+vltjVo2q0XLZjZrQXuT5kxt0OzWes2a0qBZrfWqryE5CQAAAADIVdXJx9SGLp8OVXVLenU08ShJ7v6YmV0k6W4Fox/T5766RLEcJukDoarHJV3s7gMjxHK9mS2U9JlU1QJJ75f0xVLEMlldffd67U+NemxtqNE7z1xc2YCACmiqr9HRc1t19NxWScG6kRt2dWn9ji5t2t2tLZ09Ghw+cN1ISdrS2aMtnT2669kdI74/tbFWbU21amusVWtjbaY8tbFWbY11mtpYq6mpcmtDraY01KilvkYtDTWqTSbK9jEDAAAAACqnqpOPkl6q3M1l/sPd1+Vr7O73mdlPJL05VfXnZrbc3deUIJb3KRjFmHbFSInHkC8oSFbOT5U/LJKPee3u6tc192bXenzP2UtZ6xFQsG7kUXNaddScIBk5NOzatrdXz+0ONrDZ2tmrnfv78iYkw/b0DGhPz4A2HrTlgeprEjnJyJb6GrXU16qlPqnGuho11ibVUJtQY21SjXVJNdQmU3VJNdYlcsuhNk11SRKbAAAAAFBB1Z58fF2kfHUB53xb2eSjJL1WQSKwlLFslHTbaI3dfdDMrpH0T6mqw8zsJe7+SAlimXS+ffc6dfUPSZLammr19tMXVzYgIKaSCdP8tkbNb2tUsMysNOyuzu4B7djXqx37+rRjf5/29Axob8+g9vYOqDv1f2ss+gaH1be/Xzv394/5WlEt9TVqa6rVtOY6tTXVaVpTbfBvc53aQ8dtTbVqb6pTe1Md61sCAAAAQIlUe/LxwtDxWndfW8A5d0vqVXaU4p9rjMlHM1si6ehQ1e3ufvBhRtKvlU0+pmMh+Rixa3+f/vu+DZnyX5+1VFMY9QgULGGmac1Bgu7IOQe+Pzg0rL29g+rqG1TPwJB6+ofUMzCk7v4h9abK3Zn6QfX0DwXJxsHhcYl/f9+g9vcNanNHT8Hn1Nck1N4USkg2B0nKtsagPKUhO0IzGLFZq+b6pKak/q1htCUAAAAASKri5KOZtUlaGKp6oJDz3L3fzH4v6YxU1fGjtS/QCZFyQbFIekjSoLJfx1LEMul86+51mZFZ7Yx6BEquJpnIJCeLMeyu/lQSsncglZAcGFJv6t++wWH1Dg5pcMjVPzSsgcFhDQwNa2DIU/8Gx/3p48FhDQy7BgaHD9hEp1h9g8PatrdX2/Ye2j5ejbXJ0PTx7HTyKfU1aq6PJi6Duik5U86D48bapMxsjB8NAAAAAFRO1SYflTvSUJKKWbdxrbLJx3Yzm+Pu28Y7FnfvNbOtyiZRV4whhklp5/4+XXtfdgW6y85Zppb6au72QHwkzNSQWqexlDvPu7uGhl0DQ66+wWAEZnf/kLr6B4PjvsFMuSdc3z+k/hKNxuwZCEZ/7tjXN6brmEkNNak1LGsSaqhLZsupNTDrU+tc1tUkVJdMqCZhqq1JqDaZUF3SVJsMjmtrEqpNWOY4/F5N0lSXbpdMqK4meC+ZMNUkgn/Tr5rIMclRAAAAAKOp5izM0kh5UxHnRtsulTSW5ONYY0knH6PXqXr/fd8G9QwEox6nN9fp0tMWVTgiAOVmZqpJmmqSUmNdUm1NhZ87ODSs7oF0wnJQ3X2paeOhBGX3wFB2ZGZ6xObgkAaGxjre8kDu2URmXCVMqkkklEho1ERlupxIJSvTSct06tIseAV1lq1TtkG4bfpcS9UnzGQ28r+JTNmUDMWZjq8maTmJ1pqc+BOh9w+sz9aFzs1Xn/r40x9rIhV7ULYRPx5JSiRC7RR8PEodj3id1OfNFfQhueRyuafrPDM6OL3ISxBbcH7ClBNrwkzJTNwkmwEAAFCcak4+tkbKu4s4tyNSnhKTWGrNrN7dxzbUZhJ537nBSMdv3rVOl52zVE111dzlARxMTTKh1mRCrYewLuzQcDCNvHcwO4U8naAM6odHTFr2DQxnjnsHgn8L2Fw8NoZd6h8aloYkaXzW8UTlRJORmeRuKlmZTvQmEqFjMyUSoeNQvYUSnenkZ/5EaKocOk6k2tsIx+lzw23zGW2lbR9lIYd0Qjd7jVSS17PnZpO+oevlSQRH63ISx56NJXzfzNcm8nUK6iynPNJ7kX9SbQ7848BIbdIJ85G+7tGvRbiPZL+GwR8E0sfpr1n485v5nB1Qp4LaKafdoV8n53KpymJiiCb8M21D10/H5zl12bY550ZiyNc/sudkzz9Y27xx5dwrG0e0v+Zcq5i4Ivca6WsY/qNUug+m6zXCH67Sf7TJHIf7f+QPWdnjkesVvVbe61vkGtk/FkWvW6himmc/99HvLz7C1+DAfhFtE/6elb5ubmy53y+in6fwe/m+N1n4ZEW/BnnOidw/G0/+9tn7HRjz6PfIvn/A98O83y8Pcq0ivqj5mmb/BxTavrzXH7ltnmsUHUuR1y/yOvlOGKm2mM/XzCn1umDF7Hx3rSrVnIlpiZSLWdgrumtB9FqVjqWg5KOZ/SnPW0etXbtWxxxzTBFhxNuwS1f9QPpycT/nx2xPz8D43hDApBBNQmR+gfARfjlIt0m/n7lG6JfFTL3nlLPHB/7CO+aFMwEAAIAq1liX0GHtRUzDirG1a9dK0oJDPb+ak48NkXJ/EedGk3uNkygWSRru6+vrWrVq1XMluFY1W5b6t5Bd1FFd6BvIh76BfOgbyIe+gXzoG8iHvoF86BslNCBp1fOVjqJkFkjqPtSTqzn5GB1dWMw2rfWRcnT0YSliKXT04yHH4u6TZ2hjDKVHlvJ5RhR9A/nQN5APfQP50DeQD30D+dA3kA99A+WSqHQAFbQ/Uo6OPhxNdHRh9FoTORYAAAAAAACgJKo5+bg3Um4v4ty2SHnf2EIpWSwDbDYDAAAAAACAuKjm5OP6SHlhEecuipTXxSSWscYBAAAAAAAAlEw1Jx9XRcrLizh3Wei4w923VSIWM2uQNG+U6wAAAAAAAAAVU7XJR3fvlLQpVHVaIeeZWZ2kk0JVT5QgnMci5YJikXSKcjcNKkUsAAAAAAAAQElU827XknSrpPemjpeZ2VJ3P9jU5bOUuyHMzWMNwt3Xm9nTko5KVZ1vZubufpBTL4iUxxwLSocdwpAPfQP50DeQD30D+dA3kA99A/nQN5APfQPlUrUjH1N+Hin/dQHnRNvcUJpQcmJZJOllozU2sxpJ7wxVbZH0SIliAQAAAAAAAMas2pOPt0t6MlT+oJktydfYzE6TdFGo6hZ3X52n7WIz89DrdweJ5euSwjtVf97Makdp//eS5ofKVxUwUhIAAAAAAAAYN1WdfHT3YUlXhqqaJd1kZguibc3seEk/UfZzNizpYyWM5TlJXw1VHS/p+2ZWP0Isfynpk6GqLZK+UqpYAAAAAAAAgFIwBstJZvZVSZeHqrokfV/So5JqJZ0q6Y2p47R/cPcvjHLNxZLWh6rudPdzDxJHk6Q7Jb0kVL1F0nWS1klql/QqSeeE3u+TdL673zPatQEAAAAAAIDxRvJRkpklJV0j6W0FNHdJn3X3K0drdCjJx9R5cyTdIunEAmLZJ+nt7h5duxIAAAAAAACouKqedp3m7kPufqmkNyt3DcioBxSMMhw18TjGWLYpGGn5CUnb8jTrV7BBzQkkHgEAAAAAABBXjHwcgZkdq2DNxXmShiRtlfSwu68b5ziSkk6XtFzSbAUjHTdLutvdd49nLAAAAAAAAECxSD4CIWZ2jHITz1skPeLu60c9sfRxJBQknpdJmitpTyqWu929YzxjQaDSfcPM6iQdLWmFpDmSmiTtlbQ9Fce4/nEEWZXuG4ivuPUNM2tV8LNlnqRZkvZLeiEV16Pu3lWJuKpRXPqGmS1TsNTPXElTJPVI2iXpcUlPuPvgeMaD+OBZFFE8iwIYC5KPgCQze6Okjyv4RWAk90n6mLv/rsxx1Ej6qIINkOaN0KRf0k2S/t7dN5QzFgQq2TfMbL6Cza5eJelMBQ95+ayR9DVJX3P3vlLHggPF5ftGPmb2fklfiVR/0t3/uQLhVJW49Q0zO0vBz5YLJNXlaTakYHmZj7n7neMRVzWKQ99Izax5n6T3SzpqlKY7Jf23pE8z46Z8Ukm+oxVsOJl+nSCpMdTsz8bx+wXPojERh77Bs2g8xaFvFIJnUYSRfERVSz2AXy3pHQU0H1bwAP7xMsUyW9LNyt3tPJ+9ki519xvLEQsq3zfM7GWSfinJijz1T5Le5O6rShULclW6bxTCzA6TtErBSKYwHvjKKG59w8yaFDz0v0OFfy/5B3f/QrliqlZx6RtmNkvBxoaFPGukvSDpDe5+T6njqXZm9j+SXi6p+SBNxyWJwLNofMShb/AsGk9x6BuF4FkUUTWVDgCosC8p9xeBbknfl/SogtEhKyW9QVKtgg2a/snMdrv7l0oZhJk1SrpRuQ97WyR9T9JaSdMlvVLS2an3WiX90MzOc/f7SxkLMirdN5qU+7A3LOkxSXdL2iipQ1K7gg2q/kLZ0UzHSLrDzM509zUligW5Kt03CvF1Hfiwh/KLTd8ws2YFSaZzQtU9kn6jYITjdklJBVPnXiTpPAU/W1AeFe8bqSmTv1buqMs+Sb9Q0Cd2S2qRdJyCkU7TUm1mSfpfM1tJMqHkTtLBEwjjgmfR2IlD3+BZNJ7i0DcKwbMocrk7L15V+ZJ0oSQPvf4kacEI7U5Q8PCVbjck6bgSx/LvkVh+Iql+hHYXK5jukm63SVJDpT+Xk+0Vh74h6bWpa65TMP1p7ihtFyqYqheO+a5Kfx4n4ysOfaOAGN8Suu+qSLz/XOnP4WR9xa1vSLo1Es+1kmaN0r5W0uskvaLSn8vJ9opL35B0RSSORyUtydN2iqQfRdr/utKfy8n2krQh9PntlfSQgl/Yr4t87s8dh1h4Fo3RKw59QzyLxvIVh75RQIw8i/I64JUQUIVS62R8OlTVLenV7v5ctK27PybpIgV/7ZOCEQmfjrYbQyyHSfpAqOpxSRf7CGuluPv1kj4RqlqgYM0mlEiM+sYLki6TdKS7f87dn8/X0N03KZh+8Uyo+iwzOzvPKTgEMeobo8U4XdKXU8VeSR8q9z0Rv75hZn+lYIRS2ufd/VJ3fyHfOe4+4O4/d/dfljKWahezvvH20HFPKo71IzV0932S3qrgmSTtpWY20hqAOHTXSnqPgpFMU9z9FHd/n4IRyuOGZ9FYikPf4Fk0nuLQN/LiWRT5kHxEtXqpcqcd/YePskObu9+n4C/AaX9uZstLFMv7JDWEyle4+8Ao7b+gYGRE2odLFAcCsegb7n6fu3/rIH0h3H6fpE9Gqv98rHEgRyz6xkF8ScEUSUn6lILF31F+sekbZjZFwc+JtAck/Z9SXBuHJBZ9w8waFOxQm3bzSAnQSCyDkr4dvozyb5SDQ+Dun3D3b7v7Hwr9eV8mPIvGTBz6Bs+i8RSHvnEQPItiRCQfUa1eFylfXcA5346UX1uaUHJi2SjpttEap34ZuCZUdZiZFbNwPEYXp75RrNsj5WUViWLyinXfSC0M/7ZUcZWkz5frXjhAnPrGJZLaQuUr3H04T1uUX1z6xvRIudBfBldHytNGbIWJjmdRlArPolWMZ1GMhuQjqtWFoeO17r62gHPuVjB0PG3Mf8kzsyWSjg5V3e4eLJRxEL+OlPmrYunEom8cov2R8kRYjHoiiW3fSG0u8s1U0SVdFtO/hk9Wceob7wkdP+Pud5foujg0cekbnQq+N6QV+vOhJVLOO3UfExPPoigxnkWrFM+iOBiSj6g6ZtamYFHktAcKOc/d+yX9PlRViqlHJ0TKBcWiYGHhwRLHUvVi1jcOxZJIeVtFopiEJkDf+JSkxanjq939njLdBxFx6htmNkPBztVpt471mjh0ceob7t6lYJfatPMKPPWloeP0xgaYXHgWRSnxLFq9eBbFqEg+ohodHSkXsw5FeMRCu5nNqUQs7t4raWuoakW+tihKnPrGoXh9pHx/BWKYrGLbN8zsFGUX896uYEdKjJ849Y1TIuX7pWDxdzP7iJndY2bPm1lf6t/7zOxTZnb4GO+LkcWpb0jSf4aOjzWzUTcJMbOTJb0rVPUtd99bgjgQLzyLopR4Fq1CPIuiECQfUY2WRsqbijg32jZ6rUrFMtY4EIhT3yiKmbVIujxU1S/pxvGMYZKLZd8ws1pJ31H25/lH3L2jVNdHQeLUN14cKT9tZm+Q9LSk/yfpDElzJNWl/j1N0sckPWVmXzOz+jHeH7ni1DekYI2+8M+F/0x93Y8KNzKzOWZ2haTfSkr3iYckXVmCGBA/PIuiJHgWrU48i6JQJB9RjVoj5d1FnBv9RjolJrHU8ktjScSpbxTri5LmhsrfcHemupROXPvGP0o6NnV8m7v/oITXRmHi1DdmRsrnKtg5eUaq7JJ2SHpe0lCoXVLBbre/MbPGMcaArDj1DaXW8XuTpKsUTJc1BV/3p8xsj5mtN7N0//icgrXaBiR9XdJLU1O3MfnwLIpS4Vm0OvEsioKQfEQ1ii6e3jtiq5H1HORaEzkWTNCvh5ldqtxNJjZJ+vh43b9KxK5vmNnRCkatpe/xvlJcF0WLU99oi5S/qCDB1CfpnyXNd/dZ7j5Pwe7Hlys30XCGgkQTSiNOfUNSsJ6ku39EwS+Kd4bealWwVteMUN0mSa9198vdPbqJBCaP2PVTTDw8i1YnnkVRDJKPqEYNkXJ/Eef2RcpjHSESp1gwAb8eZnaOpG+HqgYkvYV1uUouVn3DzEzB1z09yuRf3H3dWK+LQxKnvhH9xb9WwfeEV7n7J939+fQb7r7H3b8u6UxJu0LnvD211h/GLk59Q5JkZgkz+4ikuySdc5DmCyXdYma/NjOm1E5eseunmFh4Fq1OPIuiWCQfUY2if9GtK+Lc6HSS6F98J3IsmGBfDzM7SdIvlI3TJb3T3Vncu/Ti1jcuVzBKTZKeUDDCDZURp74x0oilL7r7HflOcPenJP1tpPrDY4wDgTj1DZlZg6SbFaz/OStVfbuk1yqYKlknqV1BUvLbyk7NP1/SI2Z24lhjQCzFqp9iYuFZtKrxLIqikHxENYpOHYr+xXc00b/ojnUaUpxiwQT6epjZcZJ+pdy1mi539++X875VLDZ9w8wWSPpMquiSLnP3gbFcE2MSm74haV+k7JL+o4DzrlewO2Xa+WOMA4E49Q1J+rKkV4bKV7r7Be5+o7tvc/cBd+9097vc/T2SXqZsYqpd0s9SG0pgcolbP8UEwbNo9eJZFIeC5COqUXQKQHsR57ZFytFf9IpVqlgG3D069QXFi1PfyCu1M+ntCtZsS/uwu3+jXPdErPrG15XdfOIbjC6ouDj1jWgsT4enWufj7oOS7glVzTKzw8YYC2LUN1Lrcv11qOoX7v6ZfO0lKTVi9mOhqkWSLhtLHIglnkVRNJ5Fqx7PoigayUdUo/WR8sIizl0UKY91XYtSxcL6GqURp74xIjM7XNIdyk6Zk6R/dPcvl+N+yIhF3zCz10i6MFXcJun/HOq1UDKx6BspayPlTUWcuzFSju6cjeLFqW+8RcHmQ2lfKfC8byp3DcDXjzEOxA/PoigKz6LVjWdRHKqaSgcAVMCqSHl5EecuCx13uPu2MsRy50gNw1LrNs0b5To4NHHqGwdILfh/h4K1udI+4e6fK/W9cIC49I3wpg9Nkn4frPedV/Tn/IfM7JJQ+VPu/t0xxIP49A1J+lOkXMyutdG2xUy9xMji1DeOj5QfKeQkd+8ys6dD5x8zxjgQPzyLomA8i0I8i+IQkXxE1XH3TjPbpOxfdk8r5Dwzq5N0UqjqiRKE81ikfJqk7xRw3inK/f9biliqXsz6RvQeiyT9VlJ4KuSn3P1fS30vHCimfaNVuessFaJduVPq2koWTZWKWd94UsEmIclUeVoR50bb7hqxFQoWs77RHCkXszZfV+iY3YwnH55FURCeRTECnkVRMKZdo1rdGjpelvor3sGcpdyRIDePNQh3Xy/p6VDV+XaQPx2lXBApjzkWZMSib4Sl1l67Q7lToT7n7h8v5X1wULHrG4iNWPQNd9+j3BFLx5tZoc96Lw4dD0jaPNZ4ICkmfUNSR6Q8p4hzwyOcSEpPMjyLohA8iwIYK5KPqFY/j5T/esRWo7e5oTSh5MSySMHuknmZWY2kd4aqtqjA6VMoSJz6hsxsroKHvfAvrP/P3f+xVPdAwSreN9z9Kne3Ql+SlkQu8clIm6vGEg8yKt43Qn4aOp6qg/xMkSQzWyLp5FDVA+7eXaJ4ql1c+saaSDmaOBpRam23xaGqZ0sQC+KHZ1HkxbMowngWxaEi+YhqdbuC6WlpH0z98jUiMztN0kWhqlvcfXWetovNzEOv3x0klq9LCu8O+Hkzqx2l/d9Lmh8qX+XufpB7oHCx6RtmNjMVz+Gh6v9w97872AeBsohN30DsxKlvXCdpe6j82dQ03tF8UbnPhP99kPYoXFz6xi8j5SvNbMqILXNF13H7VQHnoMJ4FkU+PIsiH55FUW4kH1GV3H1Y0pWhqmZJN5nZgmhbMzte0k+U/f8yLOljJYzlOUlfDVUdL+n7ZlY/Qix/KemToaotKnzHShQgLn3DzNol/VrSilD119z9b0pxfRQvLn0D8ROnvuHu+yX931DVCZJ+lvqeEo2l3sy+Kul1oepnJV1bqniqXVz6hrvfLenhUNUySbemplIewMyazOxq5faNvZK+XYp4EC88i2IkPIsCKCU2nEHVcvebzOxrki5PVR0j6Skz+76kRyXVSjpV0htTx2kfdffo4txj9XFJZ0t6Sap8kaTTzew6SesULMr7KknnhM7pk/QWdy9mN1MUICZ94wMKkgZhrzCz6NS50Wx293NLFA8Um76BGIpZ3/iWgp8Xf5kqXyhpjZn9WNLjkgYVjGJ5k4Iplmn7Jb3B3QdKHE9Vi1HfuEzSXZJaUuUzFfSLX0h6UMF6js0KEk9vkDQ9cv7fuPvOEsZT9czs9ZI+P8Jb0VGp3zeznhHaXeHuPytRODyLxkhM+gbPojEUk74BFI3kI6rdhxR8o35bqtws6T152rqkz7r7F0odhLt3m9mrJd0i6cRU9XxJ+dZS2Sfp7e5+T6ljQUal+0ZyhLpCNioI43t8eVS6byC+YtE33N3N7B0KRtC9OVU9TdJ7Rzlti6TXufuTo7TBoat433D3P5rZhZJ+qOwmMvUKkkwX5T1R6pX0EXf/binjgaRgl9hlBbSbN8r5JcGzaOzEoW/wLBpPcegbQNGYdo2q5u5D7n6pgl/ORvuF6wFJ57v7laO0GWss2xSMfPiEpG15mvUrWBT8BHePLmKPEopT30C80DeQT5z6hrv3u/tbFIxufHSUpnsUjKA4wd0fHqUdxiAufcPd75J0rKR/U/5njbRuSddIerG7f6Mc8SBeeBYFAJSLsTYwkGVmxyqYbjRP0pCkrZIedvd14xxHUtLpkpZLmq3gr8ubJd3t7rvHMxYE4tI3ED/0DeQTp75hZkdIenEqljoFU2xXSXrI3QfHO55qF4e+YWYm6WhJL5I0U8HIzB5JuxX0jUfdvS/vBTCp8SwKACglko8AAAAAAAAAyoJp1wAAAAAAAADKguQjAAAAAAAAgLIg+QgAAAAAAACgLEg+AgAAAAAAACgLko8AAAAAAAAAyoLkIwAAAAAAAICyIPkIAAAAAAAAoCxIPgIAAAAAAAAoC5KPAAAAAAAAAMqC5CMAAAAAAACAsiD5CAAAAAAAAKAsSD4CAAAAAAAAKAuSjwAAAAAAAADKguQjAAAAAAAAgLIg+QgAAAAAAACMwMwSZnaMmb3dzP7TzO43s24z89Dr3ErHmWZmGyKxHcrrd6WMqaaUFwMAAAAAAAAmAzP7H0kvl9Rc6VjGWWcpL0byEQAAAAAAADjQSZp4iccNkgaLPGeepMZQ+Qcli0YkHwEAAAAAAICD6ZP0uKTfS2qRdEllwxmZu59bTHszq5e0Rdnk4y5JN5QyJpKPAAAAAAAAwIGulfScgoTjE+4+IElm9g7FNPl4CF4raXqofJ2795XyBiQfAQAAAAAAgAh3/8R43cvMTNKJklZImiXJJG2X9Ad3/1MZb/3uSPk7pb4ByUcAAAAAAACgAsxsiqSPKkgCzs7TZrWk/+vuJV2L0cwWS3ppqOpBd3+ylPeQpESpLwgAAAAAAABgdGZ2qqTVkj6mPInHlMMlXW9mPzaz2hKG8C4FIyzTri7htTMY+QgAAAAAAACMIzP7M0k3S2oKVT+TqlurYMfqIyW9SdKC1PsXSXJJby7B/ROS3hGq6pL0o7FedyQkHwEAAAAAAIBxYmazJP1A2cRjr6T3S7rG3T3S9uOSviTpslTVm8zsZne/boxhvEzZpKYk/cjd943xmiNi2jUAAAAAAAAwfj6r7DTrYUmvc/f/iiYeJcnde9z9vZL+J1T9r6mRi2MR3WimLFOuJZKPAAAAAAAAwLgwszmS3hqqutrdf1nAqR+SNJA6XiTpVWOIYaak14SqVrn7/Yd6vYMh+QgAAAAAAACMjzdKqguVv1TISe6+VdLtoaoLxhDDpZLCG9d8ZwzXOiiSjwAAAAAAAMD4OCt0vM7dny7i3IdCxyvHEMO7Qsf9kq4dw7UOiuQjAAAAAAAAMD5OCB3/qchzt4eODzuUm5vZaZJWhKpudPedh3KtQrHbNQAAAAAAADA+poeOX21mB2wyU6D2Qzxv3DaaSWPkIwAAAAAAADA+2kp0naZiTzCzFklvClVtVO46kmXByEcAAAAAAABgfHRLak0dd0jaPY73foukllD5GncfLvdNST4CAAAAAAAA42OnssnHn7j7ZeN4778KHQ9LumY8bsq0awAAAAAAAGB8hHe3Pma8bmpmx0g6NVR1m7tvGo97k3wEAAAAAAAAxsdvQ8enmtmMcbrvX0XK3xmn+5J8BAAAAAAAAMbJTyUNpo6Tkv6h3Dc0szpJbwtV7ZB0Y7nvm0byEQAAAAAAABgH7r5B0g9CVX9rZi8r5hoWqCvilL+QFB5hea27DxRzz7Eg+QgAAAAAAACMnyskPZ86rpF0k5n9nZk1jHaSmc01sw8qWDfyxCLuV7Ep15Jk7j6e9wMAAAAAAABiz8xeL+nzI7w1RdKsUHmrpJ4R2l3h7j/Lc+3TJP1S2Z2vpWAn7F9JelTSbgXTstskHaEg2fhiSZZqe5q7P1DAx7BQ0nplByDe5+5nHOy8UqoZz5sBAAAAAAAAE0SrpGUFtJs3yvkjcvf7zexUSTcoSC5KwdTot6ZeBzNUQBtJeqdyZz5fXeB5JcO0awAAAAAAAGCcuftTko6V9F5Jqwo4ZZWkL0p6sbs/fLDGZmYKko9p+yT9+BBCHROmXQMAAAAAAAAVZmbzJZ0qabakdkn9kjokrZX0pLvvqGB4h4zkIwAAAAAAAICyYNo1AAAAAAAAgLIg+QgAAAAAAACgLEg+AgAAAAAAACgLko8AAAAAAAAAyoLkIwAAAAAAAICyIPkIAAAAAAAAoCxIPgIAAAAAAAAoC5KPAAAAAAAAAMqC5CMAAAAAAACAsiD5CAAAAAAAAKAsSD4CAAAAAAAAKAuSjwAAAAAAAADKguQjAAAAAAAAgLIg+QgAAAAAAACgLEg+AgAAAAAAACgLko8AAAAAAAAAyoLkIwAAAAAAAICyIPkIAAAAAAAAoCxIPgIAAAAAAAAoC5KPAAAAAAAAAMri/wMZEmFf1HC9MAAAAABJRU5ErkJggg==\n",
|
217 |
+
"text/plain": [
|
218 |
+
"<Figure size 1500x750 with 1 Axes>"
|
219 |
+
]
|
220 |
+
},
|
221 |
+
"metadata": {
|
222 |
+
"needs_background": "light"
|
223 |
+
},
|
224 |
+
"output_type": "display_data"
|
225 |
+
}
|
226 |
+
],
|
227 |
+
"source": [
|
228 |
+
"gene_detection_counts = [i for i in gene_detection_counts_dict.values()]\n",
|
229 |
+
"import seaborn as sns\n",
|
230 |
+
"import matplotlib.pyplot as plt\n",
|
231 |
+
"plt.figure(figsize=(10,5), dpi=150)\n",
|
232 |
+
"plt.rcParams.update({'font.size': 18})\n",
|
233 |
+
"count_plot = sns.distplot(gene_detection_counts).set_title(f\"# Cells Expressing Each\\nProtein-Coding or miRNA Gene\")"
|
234 |
+
]
|
235 |
+
},
|
236 |
+
{
|
237 |
+
"cell_type": "code",
|
238 |
+
"execution_count": 47,
|
239 |
+
"id": "missing-bradley",
|
240 |
+
"metadata": {},
|
241 |
+
"outputs": [
|
242 |
+
{
|
243 |
+
"data": {
|
244 |
+
"text/plain": [
|
245 |
+
"27454"
|
246 |
+
]
|
247 |
+
},
|
248 |
+
"execution_count": 47,
|
249 |
+
"metadata": {},
|
250 |
+
"output_type": "execute_result"
|
251 |
+
}
|
252 |
+
],
|
253 |
+
"source": [
|
254 |
+
"len(gene_detection_counts)"
|
255 |
+
]
|
256 |
+
},
|
257 |
+
{
|
258 |
+
"cell_type": "code",
|
259 |
+
"execution_count": 55,
|
260 |
+
"id": "perfect-signal",
|
261 |
+
"metadata": {},
|
262 |
+
"outputs": [
|
263 |
+
{
|
264 |
+
"data": {
|
265 |
+
"text/plain": [
|
266 |
+
"25424"
|
267 |
+
]
|
268 |
+
},
|
269 |
+
"execution_count": 55,
|
270 |
+
"metadata": {},
|
271 |
+
"output_type": "execute_result"
|
272 |
+
}
|
273 |
+
],
|
274 |
+
"source": [
|
275 |
+
"len([i for i in gene_detection_counts if i > 0])"
|
276 |
+
]
|
277 |
+
},
|
278 |
+
{
|
279 |
+
"cell_type": "code",
|
280 |
+
"execution_count": 56,
|
281 |
+
"id": "faced-theory",
|
282 |
+
"metadata": {},
|
283 |
+
"outputs": [
|
284 |
+
{
|
285 |
+
"data": {
|
286 |
+
"text/plain": [
|
287 |
+
"22735"
|
288 |
+
]
|
289 |
+
},
|
290 |
+
"execution_count": 56,
|
291 |
+
"metadata": {},
|
292 |
+
"output_type": "execute_result"
|
293 |
+
}
|
294 |
+
],
|
295 |
+
"source": [
|
296 |
+
"len([i for i in gene_detection_counts if i > 100])"
|
297 |
+
]
|
298 |
+
},
|
299 |
+
{
|
300 |
+
"cell_type": "code",
|
301 |
+
"execution_count": 57,
|
302 |
+
"id": "tough-workplace",
|
303 |
+
"metadata": {},
|
304 |
+
"outputs": [
|
305 |
+
{
|
306 |
+
"data": {
|
307 |
+
"text/plain": [
|
308 |
+
"21167"
|
309 |
+
]
|
310 |
+
},
|
311 |
+
"execution_count": 57,
|
312 |
+
"metadata": {},
|
313 |
+
"output_type": "execute_result"
|
314 |
+
}
|
315 |
+
],
|
316 |
+
"source": [
|
317 |
+
"len([i for i in gene_detection_counts if i > 1000])"
|
318 |
+
]
|
319 |
+
},
|
320 |
+
{
|
321 |
+
"cell_type": "code",
|
322 |
+
"execution_count": 49,
|
323 |
+
"id": "cooperative-camcorder",
|
324 |
+
"metadata": {},
|
325 |
+
"outputs": [
|
326 |
+
{
|
327 |
+
"data": {
|
328 |
+
"text/plain": [
|
329 |
+
"173152.0299000284"
|
330 |
+
]
|
331 |
+
},
|
332 |
+
"execution_count": 49,
|
333 |
+
"metadata": {},
|
334 |
+
"output_type": "execute_result"
|
335 |
+
}
|
336 |
+
],
|
337 |
+
"source": [
|
338 |
+
"gene_detection_event_digest = crick.tdigest.TDigest()\n",
|
339 |
+
"gene_detection_event_digest.update(gene_detection_counts)\n",
|
340 |
+
"gene_detection_event_digest.quantile(0.5)"
|
341 |
+
]
|
342 |
+
}
|
343 |
+
],
|
344 |
+
"metadata": {
|
345 |
+
"kernelspec": {
|
346 |
+
"display_name": "Python 3 (ipykernel)",
|
347 |
+
"language": "python",
|
348 |
+
"name": "python3"
|
349 |
+
},
|
350 |
+
"language_info": {
|
351 |
+
"codemirror_mode": {
|
352 |
+
"name": "ipython",
|
353 |
+
"version": 3
|
354 |
+
},
|
355 |
+
"file_extension": ".py",
|
356 |
+
"mimetype": "text/x-python",
|
357 |
+
"name": "python",
|
358 |
+
"nbconvert_exporter": "python",
|
359 |
+
"pygments_lexer": "ipython3",
|
360 |
+
"version": "3.10.11"
|
361 |
+
}
|
362 |
+
},
|
363 |
+
"nbformat": 4,
|
364 |
+
"nbformat_minor": 5
|
365 |
+
}
|
examples/pretraining_new_model/pretrain_geneformer_w_deepspeed.py
ADDED
@@ -0,0 +1,167 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
#!/usr/bin/env python
|
2 |
+
# coding: utf-8
|
3 |
+
|
4 |
+
# run with:
|
5 |
+
# deepspeed --num_gpus=12 --num_nodes=3 pretrain_geneformer_w_deepspeed.py --deepspeed ds_config.json
|
6 |
+
|
7 |
+
import datetime
|
8 |
+
|
9 |
+
# imports
|
10 |
+
import os
|
11 |
+
|
12 |
+
os.environ["NCCL_DEBUG"] = "INFO"
|
13 |
+
os.environ["OMPI_MCA_opal_cuda_support"] = "true"
|
14 |
+
os.environ["CONDA_OVERRIDE_GLIBC"] = "2.56"
|
15 |
+
|
16 |
+
import pickle
|
17 |
+
import random
|
18 |
+
import subprocess
|
19 |
+
|
20 |
+
import numpy as np
|
21 |
+
import pytz
|
22 |
+
import torch
|
23 |
+
from datasets import load_from_disk
|
24 |
+
from transformers import BertConfig, BertForMaskedLM, TrainingArguments
|
25 |
+
|
26 |
+
from geneformer import GeneformerPretrainer
|
27 |
+
|
28 |
+
seed_num = 0
|
29 |
+
random.seed(seed_num)
|
30 |
+
np.random.seed(seed_num)
|
31 |
+
seed_val = 42
|
32 |
+
torch.manual_seed(seed_val)
|
33 |
+
torch.cuda.manual_seed_all(seed_val)
|
34 |
+
|
35 |
+
# set local time/directories
|
36 |
+
timezone = pytz.timezone("US/Eastern")
|
37 |
+
rootdir = "/parent_ouput_directory"
|
38 |
+
|
39 |
+
# set model parameters
|
40 |
+
# model type
|
41 |
+
model_type = "bert"
|
42 |
+
# max input size
|
43 |
+
max_input_size = 2**11 # 2048
|
44 |
+
# number of layers
|
45 |
+
num_layers = 6
|
46 |
+
# number of attention heads
|
47 |
+
num_attn_heads = 4
|
48 |
+
# number of embedding dimensions
|
49 |
+
num_embed_dim = 256
|
50 |
+
# intermediate size
|
51 |
+
intermed_size = num_embed_dim * 2
|
52 |
+
# activation function
|
53 |
+
activ_fn = "relu"
|
54 |
+
# initializer range, layer norm, dropout
|
55 |
+
initializer_range = 0.02
|
56 |
+
layer_norm_eps = 1e-12
|
57 |
+
attention_probs_dropout_prob = 0.02
|
58 |
+
hidden_dropout_prob = 0.02
|
59 |
+
|
60 |
+
|
61 |
+
# set training parameters
|
62 |
+
# total number of examples in Genecorpus-30M after QC filtering:
|
63 |
+
num_examples = 27_406_208
|
64 |
+
# number gpus
|
65 |
+
num_gpus = 12
|
66 |
+
# batch size for training and eval
|
67 |
+
geneformer_batch_size = 12
|
68 |
+
# max learning rate
|
69 |
+
max_lr = 1e-3
|
70 |
+
# learning schedule
|
71 |
+
lr_schedule_fn = "linear"
|
72 |
+
# warmup steps
|
73 |
+
warmup_steps = 10_000
|
74 |
+
# number of epochs
|
75 |
+
epochs = 3
|
76 |
+
# optimizer
|
77 |
+
optimizer = "adamw"
|
78 |
+
# weight_decay
|
79 |
+
weight_decay = 0.001
|
80 |
+
|
81 |
+
|
82 |
+
# output directories
|
83 |
+
current_date = datetime.datetime.now(tz=timezone)
|
84 |
+
datestamp = f"{str(current_date.year)[-2:]}{current_date.month:02d}{current_date.day:02d}_{current_date.strftime('%X').replace(':','')}"
|
85 |
+
run_name = f"{datestamp}_geneformer_30M_L{num_layers}_emb{num_embed_dim}_SL{max_input_size}_E{epochs}_B{geneformer_batch_size}_LR{max_lr}_LS{lr_schedule_fn}_WU{warmup_steps}_O{optimizer}_DS{num_gpus}"
|
86 |
+
training_output_dir = f"{rootdir}/models/{run_name}/"
|
87 |
+
logging_dir = f"{rootdir}/runs/{run_name}/"
|
88 |
+
model_output_dir = os.path.join(training_output_dir, "models/")
|
89 |
+
|
90 |
+
|
91 |
+
# ensure not overwriting previously saved model
|
92 |
+
model_output_file = os.path.join(model_output_dir, "pytorch_model.bin")
|
93 |
+
if os.path.isfile(model_output_file) is True:
|
94 |
+
raise Exception("Model already saved to this directory.")
|
95 |
+
|
96 |
+
|
97 |
+
# make training and model output directories
|
98 |
+
subprocess.call(f"mkdir {training_output_dir}", shell=True)
|
99 |
+
subprocess.call(f"mkdir {model_output_dir}", shell=True)
|
100 |
+
|
101 |
+
|
102 |
+
# load gene_ensembl_id:token dictionary (e.g. https://huggingface.co/datasets/ctheodoris/Genecorpus-30M/blob/main/token_dictionary.pkl)
|
103 |
+
with open("token_dictionary.pkl", "rb") as fp:
|
104 |
+
token_dictionary = pickle.load(fp)
|
105 |
+
|
106 |
+
# model configuration
|
107 |
+
config = {
|
108 |
+
"hidden_size": num_embed_dim,
|
109 |
+
"num_hidden_layers": num_layers,
|
110 |
+
"initializer_range": initializer_range,
|
111 |
+
"layer_norm_eps": layer_norm_eps,
|
112 |
+
"attention_probs_dropout_prob": attention_probs_dropout_prob,
|
113 |
+
"hidden_dropout_prob": hidden_dropout_prob,
|
114 |
+
"intermediate_size": intermed_size,
|
115 |
+
"hidden_act": activ_fn,
|
116 |
+
"max_position_embeddings": max_input_size,
|
117 |
+
"model_type": model_type,
|
118 |
+
"num_attention_heads": num_attn_heads,
|
119 |
+
"pad_token_id": token_dictionary.get("<pad>"),
|
120 |
+
"vocab_size": len(token_dictionary), # genes+2 for <mask> and <pad> tokens
|
121 |
+
}
|
122 |
+
|
123 |
+
config = BertConfig(**config)
|
124 |
+
model = BertForMaskedLM(config)
|
125 |
+
model = model.train()
|
126 |
+
|
127 |
+
# define the training arguments
|
128 |
+
training_args = {
|
129 |
+
"learning_rate": max_lr,
|
130 |
+
"do_train": True,
|
131 |
+
"do_eval": False,
|
132 |
+
"group_by_length": True,
|
133 |
+
"length_column_name": "length",
|
134 |
+
"disable_tqdm": False,
|
135 |
+
"lr_scheduler_type": lr_schedule_fn,
|
136 |
+
"warmup_steps": warmup_steps,
|
137 |
+
"weight_decay": weight_decay,
|
138 |
+
"per_device_train_batch_size": geneformer_batch_size,
|
139 |
+
"num_train_epochs": epochs,
|
140 |
+
"save_strategy": "steps",
|
141 |
+
"save_steps": np.floor(
|
142 |
+
num_examples / geneformer_batch_size / 8
|
143 |
+
), # 8 saves per epoch
|
144 |
+
"logging_steps": 1000,
|
145 |
+
"output_dir": training_output_dir,
|
146 |
+
"logging_dir": logging_dir,
|
147 |
+
}
|
148 |
+
training_args = TrainingArguments(**training_args)
|
149 |
+
|
150 |
+
print("Starting training.")
|
151 |
+
|
152 |
+
# define the trainer
|
153 |
+
trainer = GeneformerPretrainer(
|
154 |
+
model=model,
|
155 |
+
args=training_args,
|
156 |
+
# pretraining corpus (e.g. https://huggingface.co/datasets/ctheodoris/Genecorpus-30M/tree/main/genecorpus_30M_2048.dataset)
|
157 |
+
train_dataset=load_from_disk("genecorpus_30M_2048.dataset"),
|
158 |
+
# file of lengths of each example cell (e.g. https://huggingface.co/datasets/ctheodoris/Genecorpus-30M/blob/main/genecorpus_30M_2048_lengths.pkl)
|
159 |
+
example_lengths_file="genecorpus_30M_2048_lengths.pkl",
|
160 |
+
token_dictionary=token_dictionary,
|
161 |
+
)
|
162 |
+
|
163 |
+
# train
|
164 |
+
trainer.train()
|
165 |
+
|
166 |
+
# save model
|
167 |
+
trainer.save_model(model_output_dir)
|
examples/tokenizing_scRNAseq_data.ipynb
ADDED
@@ -0,0 +1,91 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
{
|
2 |
+
"cells": [
|
3 |
+
{
|
4 |
+
"cell_type": "markdown",
|
5 |
+
"id": "a91bca46-c056-4784-8c6c-b0f5d3f33496",
|
6 |
+
"metadata": {
|
7 |
+
"tags": []
|
8 |
+
},
|
9 |
+
"source": [
|
10 |
+
"## Tokenizing .loom or .h5ad single cell RNA-seq data to rank value encoding .dataset format"
|
11 |
+
]
|
12 |
+
},
|
13 |
+
{
|
14 |
+
"cell_type": "markdown",
|
15 |
+
"id": "1fe86f48-5578-47df-b373-58c21ec170ab",
|
16 |
+
"metadata": {},
|
17 |
+
"source": [
|
18 |
+
"#### Input data is a directory with .loom or .h5ad files containing raw counts from single cell RNAseq data, including all genes detected in the transcriptome without feature selection. The input file type is specified by the argument file_format in the tokenize_data function.\n",
|
19 |
+
"\n",
|
20 |
+
"#### The discussion below references the .loom file format, but the analagous labels are required for .h5ad files, just that they will be column instead of row attributes and vice versa due to the transposed format of the two file types.\n",
|
21 |
+
"\n",
|
22 |
+
"#### Genes should be labeled with Ensembl IDs (loom row attribute \"ensembl_id\"), which provide a unique identifer for conversion to tokens. Other forms of gene annotations (e.g. gene names) can be converted to Ensembl IDs via Ensembl Biomart. Cells should be labeled with the total read count in the cell (loom column attribute \"n_counts\") to be used for normalization.\n",
|
23 |
+
"\n",
|
24 |
+
"#### No cell metadata is required, but custom cell attributes may be passed onto the tokenized dataset by providing a dictionary of custom attributes to be added, which is formatted as loom_col_attr_name : desired_dataset_col_attr_name. For example, if the original .loom dataset has column attributes \"cell_type\" and \"organ_major\" and one would like to retain these attributes as labels in the tokenized dataset with the new names \"cell_type\" and \"organ\", respectively, the following custom attribute dictionary should be provided: {\"cell_type\": \"cell_type\", \"organ_major\": \"organ\"}. \n",
|
25 |
+
"\n",
|
26 |
+
"#### Additionally, if the original .loom file contains a cell column attribute called \"filter_pass\", this column will be used as a binary indicator of whether to include these cells in the tokenized data. All cells with \"1\" in this attribute will be tokenized, whereas the others will be excluded. One may use this column to indicate QC filtering or other criteria for selection for inclusion in the final tokenized dataset.\n",
|
27 |
+
"\n",
|
28 |
+
"#### If one's data is in other formats besides .loom or .h5ad, one can use the relevant tools (such as Anndata tools) to convert the file to a .loom or .h5ad format prior to running the transcriptome tokenizer."
|
29 |
+
]
|
30 |
+
},
|
31 |
+
{
|
32 |
+
"cell_type": "markdown",
|
33 |
+
"id": "32c69493-4e5a-4b07-8dc1-958ff2ee7d0b",
|
34 |
+
"metadata": {},
|
35 |
+
"source": [
|
36 |
+
"**********************************************************************************************************\n",
|
37 |
+
"#### OF NOTE: PLEASE ENSURE THE CORRECT TOKEN DICTIONARY AND GENE MEDIAN FILE IS USED FOR THE CORRECT MODEL.\n",
|
38 |
+
"#### 95M: current defaults; 30M: https://huggingface.co/ctheodoris/Geneformer/tree/main/geneformer/gene_dictionaries_30m\n",
|
39 |
+
"\n",
|
40 |
+
"#### ADDITIONALLY:\n",
|
41 |
+
"#### The 95M model series require the special_token argument to be set to True and model_input_size to be 4096. (current defaults)\n",
|
42 |
+
"#### The 30M model series require the special_token argument to be set to False and the model_input_size to be 2048."
|
43 |
+
]
|
44 |
+
},
|
45 |
+
{
|
46 |
+
"cell_type": "code",
|
47 |
+
"execution_count": null,
|
48 |
+
"id": "080fdd9c-0c48-4d5d-a254-52b6c53cdf78",
|
49 |
+
"metadata": {},
|
50 |
+
"outputs": [],
|
51 |
+
"source": [
|
52 |
+
"from geneformer import TranscriptomeTokenizer"
|
53 |
+
]
|
54 |
+
},
|
55 |
+
{
|
56 |
+
"cell_type": "code",
|
57 |
+
"execution_count": null,
|
58 |
+
"id": "37205758-aa52-4443-a383-0638519ee8a9",
|
59 |
+
"metadata": {},
|
60 |
+
"outputs": [],
|
61 |
+
"source": [
|
62 |
+
"tk = TranscriptomeTokenizer({\"cell_type\": \"cell_type\", \"organ_major\": \"organ\"}, nproc=16)\n",
|
63 |
+
"tk.tokenize_data(\"loom_data_directory\", \n",
|
64 |
+
" \"output_directory\", \n",
|
65 |
+
" \"output_prefix\", \n",
|
66 |
+
" file_format=\"loom\")"
|
67 |
+
]
|
68 |
+
}
|
69 |
+
],
|
70 |
+
"metadata": {
|
71 |
+
"kernelspec": {
|
72 |
+
"display_name": "Python 3 (ipykernel)",
|
73 |
+
"language": "python",
|
74 |
+
"name": "python3"
|
75 |
+
},
|
76 |
+
"language_info": {
|
77 |
+
"codemirror_mode": {
|
78 |
+
"name": "ipython",
|
79 |
+
"version": 3
|
80 |
+
},
|
81 |
+
"file_extension": ".py",
|
82 |
+
"mimetype": "text/x-python",
|
83 |
+
"name": "python",
|
84 |
+
"nbconvert_exporter": "python",
|
85 |
+
"pygments_lexer": "ipython3",
|
86 |
+
"version": "3.10.15"
|
87 |
+
}
|
88 |
+
},
|
89 |
+
"nbformat": 4,
|
90 |
+
"nbformat_minor": 5
|
91 |
+
}
|
fine_tuned_models/gf-12L-95M-i4096_MTLCellClassifier_CELLxGENE_240522/config.json
ADDED
@@ -0,0 +1,24 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
{
|
2 |
+
"architectures": [
|
3 |
+
"BertForMaskedLM"
|
4 |
+
],
|
5 |
+
"attention_probs_dropout_prob": 0.02,
|
6 |
+
"classifier_dropout": null,
|
7 |
+
"hidden_act": "relu",
|
8 |
+
"hidden_dropout_prob": 0.02,
|
9 |
+
"hidden_size": 512,
|
10 |
+
"initializer_range": 0.02,
|
11 |
+
"intermediate_size": 1024,
|
12 |
+
"layer_norm_eps": 1e-12,
|
13 |
+
"max_position_embeddings": 4096,
|
14 |
+
"model_type": "bert",
|
15 |
+
"num_attention_heads": 8,
|
16 |
+
"num_hidden_layers": 12,
|
17 |
+
"pad_token_id": 0,
|
18 |
+
"position_embedding_type": "absolute",
|
19 |
+
"torch_dtype": "float32",
|
20 |
+
"transformers_version": "4.37.2",
|
21 |
+
"type_vocab_size": 2,
|
22 |
+
"use_cache": true,
|
23 |
+
"vocab_size": 20275
|
24 |
+
}
|
fine_tuned_models/gf-6L-30M-i2048_CellClassifier_cardiomyopathies_220224/config.json
ADDED
@@ -0,0 +1,35 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
{
|
2 |
+
"_name_or_path": "/n/home01/ctheodoris/models/210602_111318_geneformer_27M_L6_emb256_SL2048_E3_B12_LR0.001_LSlinear_WU10000_Oadamw_DS12/models/",
|
3 |
+
"architectures": [
|
4 |
+
"BertForSequenceClassification"
|
5 |
+
],
|
6 |
+
"attention_probs_dropout_prob": 0.02,
|
7 |
+
"gradient_checkpointing": false,
|
8 |
+
"hidden_act": "relu",
|
9 |
+
"hidden_dropout_prob": 0.02,
|
10 |
+
"hidden_size": 256,
|
11 |
+
"id2label": {
|
12 |
+
"0": "LABEL_0",
|
13 |
+
"1": "LABEL_1",
|
14 |
+
"2": "LABEL_2"
|
15 |
+
},
|
16 |
+
"initializer_range": 0.02,
|
17 |
+
"intermediate_size": 512,
|
18 |
+
"label2id": {
|
19 |
+
"LABEL_0": 0,
|
20 |
+
"LABEL_1": 1,
|
21 |
+
"LABEL_2": 2
|
22 |
+
},
|
23 |
+
"layer_norm_eps": 1e-12,
|
24 |
+
"max_position_embeddings": 2048,
|
25 |
+
"model_type": "bert",
|
26 |
+
"num_attention_heads": 4,
|
27 |
+
"num_hidden_layers": 6,
|
28 |
+
"pad_token_id": 0,
|
29 |
+
"position_embedding_type": "absolute",
|
30 |
+
"problem_type": "single_label_classification",
|
31 |
+
"transformers_version": "4.6.0",
|
32 |
+
"type_vocab_size": 2,
|
33 |
+
"use_cache": true,
|
34 |
+
"vocab_size": 25426
|
35 |
+
}
|
fine_tuned_models/gf-6L-30M-i2048_CellClassifier_cardiomyopathies_220224/trainer_state.json
ADDED
@@ -0,0 +1,150 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
{
|
2 |
+
"best_metric": 0.39658036828041077,
|
3 |
+
"best_model_checkpoint": "/n/holyscratch01/xiaoleliu_lab/Users/ctheodoris/models/220224_geneformer_27M_SequenceClassifier_tuning_hCMdCM_L2048_B12_LR1e-05_LScosine_WU500_E1_Oadamw_F2/run-8429a330/checkpoint-7020",
|
4 |
+
"epoch": 0.9,
|
5 |
+
"global_step": 7020,
|
6 |
+
"is_hyper_param_search": true,
|
7 |
+
"is_local_process_zero": true,
|
8 |
+
"is_world_process_zero": true,
|
9 |
+
"log_history": [
|
10 |
+
{
|
11 |
+
"epoch": 0.1,
|
12 |
+
"learning_rate": 0.00034606438343856935,
|
13 |
+
"loss": 0.911,
|
14 |
+
"step": 780
|
15 |
+
},
|
16 |
+
{
|
17 |
+
"epoch": 0.1,
|
18 |
+
"eval_accuracy": 0.4531576503366612,
|
19 |
+
"eval_loss": 1.4550466537475586,
|
20 |
+
"eval_runtime": 66.5164,
|
21 |
+
"eval_samples_per_second": 259.004,
|
22 |
+
"step": 780
|
23 |
+
},
|
24 |
+
{
|
25 |
+
"epoch": 0.2,
|
26 |
+
"learning_rate": 0.0006921287668771387,
|
27 |
+
"loss": 0.6273,
|
28 |
+
"step": 1560
|
29 |
+
},
|
30 |
+
{
|
31 |
+
"epoch": 0.2,
|
32 |
+
"eval_accuracy": 0.5953680055723242,
|
33 |
+
"eval_loss": 0.846651554107666,
|
34 |
+
"eval_runtime": 66.1267,
|
35 |
+
"eval_samples_per_second": 260.53,
|
36 |
+
"step": 1560
|
37 |
+
},
|
38 |
+
{
|
39 |
+
"epoch": 0.3,
|
40 |
+
"learning_rate": 0.0007330550166223805,
|
41 |
+
"loss": 0.5592,
|
42 |
+
"step": 2340
|
43 |
+
},
|
44 |
+
{
|
45 |
+
"epoch": 0.3,
|
46 |
+
"eval_accuracy": 0.5935105641978176,
|
47 |
+
"eval_loss": 1.0599186420440674,
|
48 |
+
"eval_runtime": 66.2608,
|
49 |
+
"eval_samples_per_second": 260.003,
|
50 |
+
"step": 2340
|
51 |
+
},
|
52 |
+
{
|
53 |
+
"epoch": 0.4,
|
54 |
+
"learning_rate": 0.0006283471571048975,
|
55 |
+
"loss": 0.3714,
|
56 |
+
"step": 3120
|
57 |
+
},
|
58 |
+
{
|
59 |
+
"epoch": 0.4,
|
60 |
+
"eval_accuracy": 0.686324587880195,
|
61 |
+
"eval_loss": 1.184874415397644,
|
62 |
+
"eval_runtime": 66.1411,
|
63 |
+
"eval_samples_per_second": 260.473,
|
64 |
+
"step": 3120
|
65 |
+
},
|
66 |
+
{
|
67 |
+
"epoch": 0.5,
|
68 |
+
"learning_rate": 0.0005236392975874146,
|
69 |
+
"loss": 0.2976,
|
70 |
+
"step": 3900
|
71 |
+
},
|
72 |
+
{
|
73 |
+
"epoch": 0.5,
|
74 |
+
"eval_accuracy": 0.7681100534014396,
|
75 |
+
"eval_loss": 0.6318939328193665,
|
76 |
+
"eval_runtime": 66.3309,
|
77 |
+
"eval_samples_per_second": 259.728,
|
78 |
+
"step": 3900
|
79 |
+
},
|
80 |
+
{
|
81 |
+
"epoch": 0.6,
|
82 |
+
"learning_rate": 0.0004189314380699318,
|
83 |
+
"loss": 0.2564,
|
84 |
+
"step": 4680
|
85 |
+
},
|
86 |
+
{
|
87 |
+
"epoch": 0.6,
|
88 |
+
"eval_accuracy": 0.7807058277223126,
|
89 |
+
"eval_loss": 0.7283642888069153,
|
90 |
+
"eval_runtime": 66.3416,
|
91 |
+
"eval_samples_per_second": 259.686,
|
92 |
+
"step": 4680
|
93 |
+
},
|
94 |
+
{
|
95 |
+
"epoch": 0.7,
|
96 |
+
"learning_rate": 0.0003142235785524487,
|
97 |
+
"loss": 0.2336,
|
98 |
+
"step": 5460
|
99 |
+
},
|
100 |
+
{
|
101 |
+
"epoch": 0.7,
|
102 |
+
"eval_accuracy": 0.8563965637334572,
|
103 |
+
"eval_loss": 0.5184123516082764,
|
104 |
+
"eval_runtime": 66.3416,
|
105 |
+
"eval_samples_per_second": 259.686,
|
106 |
+
"step": 5460
|
107 |
+
},
|
108 |
+
{
|
109 |
+
"epoch": 0.8,
|
110 |
+
"learning_rate": 0.0002095157190349659,
|
111 |
+
"loss": 0.1731,
|
112 |
+
"step": 6240
|
113 |
+
},
|
114 |
+
{
|
115 |
+
"epoch": 0.8,
|
116 |
+
"eval_accuracy": 0.8288832133735778,
|
117 |
+
"eval_loss": 0.5823884010314941,
|
118 |
+
"eval_runtime": 66.1535,
|
119 |
+
"eval_samples_per_second": 260.425,
|
120 |
+
"step": 6240
|
121 |
+
},
|
122 |
+
{
|
123 |
+
"epoch": 0.9,
|
124 |
+
"learning_rate": 0.00010480785951748295,
|
125 |
+
"loss": 0.1451,
|
126 |
+
"step": 7020
|
127 |
+
},
|
128 |
+
{
|
129 |
+
"epoch": 0.9,
|
130 |
+
"eval_accuracy": 0.886812166241003,
|
131 |
+
"eval_loss": 0.39658036828041077,
|
132 |
+
"eval_runtime": 66.3555,
|
133 |
+
"eval_samples_per_second": 259.632,
|
134 |
+
"step": 7020
|
135 |
+
}
|
136 |
+
],
|
137 |
+
"max_steps": 7800,
|
138 |
+
"num_train_epochs": 1,
|
139 |
+
"total_flos": 0,
|
140 |
+
"trial_name": null,
|
141 |
+
"trial_params": {
|
142 |
+
"learning_rate": 0.0008039341830649843,
|
143 |
+
"lr_scheduler_type": "polynomial",
|
144 |
+
"num_train_epochs": 1,
|
145 |
+
"per_device_train_batch_size": 12,
|
146 |
+
"seed": 73.15243080311434,
|
147 |
+
"warmup_steps": 1812.6785581609881,
|
148 |
+
"weight_decay": 0.2588277764570262
|
149 |
+
}
|
150 |
+
}
|
geneformer/__init__.py
ADDED
@@ -0,0 +1,34 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# ruff: noqa: F401
|
2 |
+
import warnings
|
3 |
+
from pathlib import Path
|
4 |
+
|
5 |
+
warnings.filterwarnings("ignore", message=".*The 'nopython' keyword.*") # noqa # isort:skip
|
6 |
+
|
7 |
+
GENE_MEDIAN_FILE = Path(__file__).parent / "gene_median_dictionary_gc95M.pkl"
|
8 |
+
TOKEN_DICTIONARY_FILE = Path(__file__).parent / "token_dictionary_gc95M.pkl"
|
9 |
+
ENSEMBL_DICTIONARY_FILE = Path(__file__).parent / "gene_name_id_dict_gc95M.pkl"
|
10 |
+
ENSEMBL_MAPPING_FILE = Path(__file__).parent / "ensembl_mapping_dict_gc95M.pkl"
|
11 |
+
|
12 |
+
from . import (
|
13 |
+
collator_for_classification,
|
14 |
+
emb_extractor,
|
15 |
+
in_silico_perturber,
|
16 |
+
in_silico_perturber_stats,
|
17 |
+
pretrainer,
|
18 |
+
tokenizer,
|
19 |
+
)
|
20 |
+
from .collator_for_classification import (
|
21 |
+
DataCollatorForCellClassification,
|
22 |
+
DataCollatorForGeneClassification,
|
23 |
+
)
|
24 |
+
from .emb_extractor import EmbExtractor, get_embs
|
25 |
+
from .in_silico_perturber import InSilicoPerturber
|
26 |
+
from .in_silico_perturber_stats import InSilicoPerturberStats
|
27 |
+
from .pretrainer import GeneformerPretrainer
|
28 |
+
from .tokenizer import TranscriptomeTokenizer
|
29 |
+
|
30 |
+
from . import classifier # noqa # isort:skip
|
31 |
+
from .classifier import Classifier # noqa # isort:skip
|
32 |
+
|
33 |
+
from . import mtl_classifier # noqa # isort:skip
|
34 |
+
from .mtl_classifier import MTLClassifier # noqa # isort:skip
|
geneformer/classifier.py
ADDED
@@ -0,0 +1,1563 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
"""
|
2 |
+
Geneformer classifier.
|
3 |
+
|
4 |
+
**Input data:**
|
5 |
+
|
6 |
+
| Cell state classifier:
|
7 |
+
| Single-cell transcriptomes as Geneformer rank value encodings with cell state labels in Geneformer .dataset format (generated from single-cell RNAseq data by tokenizer.py)
|
8 |
+
|
9 |
+
| Gene classifier:
|
10 |
+
| Dictionary in format {Gene_label: list(genes)} for gene labels and single-cell transcriptomes as Geneformer rank value encodings in Geneformer .dataset format (generated from single-cell RNAseq data by tokenizer.py)
|
11 |
+
|
12 |
+
**Usage:**
|
13 |
+
|
14 |
+
.. code-block :: python
|
15 |
+
|
16 |
+
>>> from geneformer import Classifier
|
17 |
+
>>> cc = Classifier(classifier="cell", # example of cell state classifier
|
18 |
+
... cell_state_dict={"state_key": "disease", "states": "all"},
|
19 |
+
... filter_data={"cell_type":["Cardiomyocyte1","Cardiomyocyte2","Cardiomyocyte3"]},
|
20 |
+
... training_args=training_args,
|
21 |
+
... freeze_layers = 2,
|
22 |
+
... num_crossval_splits = 1,
|
23 |
+
... forward_batch_size=200,
|
24 |
+
... nproc=16)
|
25 |
+
>>> cc.prepare_data(input_data_file="path/to/input_data",
|
26 |
+
... output_directory="path/to/output_directory",
|
27 |
+
... output_prefix="output_prefix")
|
28 |
+
>>> all_metrics = cc.validate(model_directory="path/to/model",
|
29 |
+
... prepared_input_data_file=f"path/to/output_directory/{output_prefix}_labeled.dataset",
|
30 |
+
... id_class_dict_file=f"path/to/output_directory/{output_prefix}_id_class_dict.pkl",
|
31 |
+
... output_directory="path/to/output_directory",
|
32 |
+
... output_prefix="output_prefix",
|
33 |
+
... predict_eval=True)
|
34 |
+
>>> cc.plot_conf_mat(conf_mat_dict={"Geneformer": all_metrics["conf_matrix"]},
|
35 |
+
... output_directory="path/to/output_directory",
|
36 |
+
... output_prefix="output_prefix",
|
37 |
+
... custom_class_order=["healthy","disease1","disease2"])
|
38 |
+
>>> cc.plot_predictions(predictions_file=f"path/to/output_directory/datestamp_geneformer_cellClassifier_{output_prefix}/ksplit1/predictions.pkl",
|
39 |
+
... id_class_dict_file=f"path/to/output_directory/{output_prefix}_id_class_dict.pkl",
|
40 |
+
... title="disease",
|
41 |
+
... output_directory="path/to/output_directory",
|
42 |
+
... output_prefix="output_prefix",
|
43 |
+
... custom_class_order=["healthy","disease1","disease2"])
|
44 |
+
"""
|
45 |
+
|
46 |
+
import datetime
|
47 |
+
import logging
|
48 |
+
import os
|
49 |
+
import pickle
|
50 |
+
import subprocess
|
51 |
+
from pathlib import Path
|
52 |
+
|
53 |
+
import numpy as np
|
54 |
+
import pandas as pd
|
55 |
+
import seaborn as sns
|
56 |
+
from tqdm.auto import tqdm, trange
|
57 |
+
from transformers import Trainer
|
58 |
+
from transformers.training_args import TrainingArguments
|
59 |
+
|
60 |
+
from . import (
|
61 |
+
TOKEN_DICTIONARY_FILE,
|
62 |
+
DataCollatorForCellClassification,
|
63 |
+
DataCollatorForGeneClassification,
|
64 |
+
)
|
65 |
+
from . import classifier_utils as cu
|
66 |
+
from . import evaluation_utils as eu
|
67 |
+
from . import perturber_utils as pu
|
68 |
+
|
69 |
+
sns.set()
|
70 |
+
|
71 |
+
|
72 |
+
logger = logging.getLogger(__name__)
|
73 |
+
|
74 |
+
|
75 |
+
class Classifier:
|
76 |
+
valid_option_dict = {
|
77 |
+
"classifier": {"cell", "gene"},
|
78 |
+
"quantize": {bool, dict},
|
79 |
+
"cell_state_dict": {None, dict},
|
80 |
+
"gene_class_dict": {None, dict},
|
81 |
+
"filter_data": {None, dict},
|
82 |
+
"rare_threshold": {int, float},
|
83 |
+
"max_ncells": {None, int},
|
84 |
+
"max_ncells_per_class": {None, int},
|
85 |
+
"training_args": {None, dict},
|
86 |
+
"freeze_layers": {int},
|
87 |
+
"num_crossval_splits": {0, 1, 5},
|
88 |
+
"split_sizes": {None, dict},
|
89 |
+
"no_eval": {bool},
|
90 |
+
"stratify_splits_col": {None, str},
|
91 |
+
"forward_batch_size": {int},
|
92 |
+
"token_dictionary_file": {None, str},
|
93 |
+
"nproc": {int},
|
94 |
+
"ngpu": {int},
|
95 |
+
}
|
96 |
+
|
97 |
+
def __init__(
|
98 |
+
self,
|
99 |
+
classifier=None,
|
100 |
+
quantize=False,
|
101 |
+
cell_state_dict=None,
|
102 |
+
gene_class_dict=None,
|
103 |
+
filter_data=None,
|
104 |
+
rare_threshold=0,
|
105 |
+
max_ncells=None,
|
106 |
+
max_ncells_per_class=None,
|
107 |
+
training_args=None,
|
108 |
+
ray_config=None,
|
109 |
+
freeze_layers=0,
|
110 |
+
num_crossval_splits=1,
|
111 |
+
split_sizes={"train": 0.8, "valid": 0.1, "test": 0.1},
|
112 |
+
stratify_splits_col=None,
|
113 |
+
no_eval=False,
|
114 |
+
forward_batch_size=100,
|
115 |
+
token_dictionary_file=None,
|
116 |
+
nproc=4,
|
117 |
+
ngpu=1,
|
118 |
+
):
|
119 |
+
"""
|
120 |
+
Initialize Geneformer classifier.
|
121 |
+
|
122 |
+
**Parameters:**
|
123 |
+
|
124 |
+
classifier : {"cell", "gene"}
|
125 |
+
| Whether to fine-tune a cell state or gene classifier.
|
126 |
+
quantize : bool, dict
|
127 |
+
| Whether to fine-tune a quantized model.
|
128 |
+
| If True and no config provided, will use default.
|
129 |
+
| Will use custom config if provided.
|
130 |
+
| Configs should be provided as dictionary of BitsAndBytesConfig (transformers) and LoraConfig (peft).
|
131 |
+
| For example: {"bnb_config": BitsAndBytesConfig(...),
|
132 |
+
| "peft_config": LoraConfig(...)}
|
133 |
+
cell_state_dict : None, dict
|
134 |
+
| Cell states to fine-tune model to distinguish.
|
135 |
+
| Two-item dictionary with keys: state_key and states
|
136 |
+
| state_key: key specifying name of column in .dataset that defines the states to model
|
137 |
+
| states: list of values in the state_key column that specifies the states to model
|
138 |
+
| Alternatively, instead of a list of states, can specify "all" to use all states in that state key from input data.
|
139 |
+
| Of note, if using "all", states will be defined after data is filtered.
|
140 |
+
| Must have at least 2 states to model.
|
141 |
+
| For example: {"state_key": "disease",
|
142 |
+
| "states": ["nf", "hcm", "dcm"]}
|
143 |
+
| or
|
144 |
+
| {"state_key": "disease",
|
145 |
+
| "states": "all"}
|
146 |
+
gene_class_dict : None, dict
|
147 |
+
| Gene classes to fine-tune model to distinguish.
|
148 |
+
| Dictionary in format: {Gene_label_A: list(geneA1, geneA2, ...),
|
149 |
+
| Gene_label_B: list(geneB1, geneB2, ...)}
|
150 |
+
| Gene values should be Ensembl IDs.
|
151 |
+
filter_data : None, dict
|
152 |
+
| Default is to fine-tune with all input data.
|
153 |
+
| Otherwise, dictionary specifying .dataset column name and list of values to filter by.
|
154 |
+
rare_threshold : float
|
155 |
+
| Threshold below which rare cell states should be removed.
|
156 |
+
| For example, setting to 0.05 will remove cell states representing
|
157 |
+
| < 5% of the total cells from the cell state classifier's possible classes.
|
158 |
+
max_ncells : None, int
|
159 |
+
| Maximum number of cells to use for fine-tuning.
|
160 |
+
| Default is to fine-tune with all input data.
|
161 |
+
max_ncells_per_class : None, int
|
162 |
+
| Maximum number of cells per cell class to use for fine-tuning.
|
163 |
+
| Of note, will be applied after max_ncells above.
|
164 |
+
| (Only valid for cell classification.)
|
165 |
+
training_args : None, dict
|
166 |
+
| Training arguments for fine-tuning.
|
167 |
+
| If None, defaults will be inferred for 6 layer Geneformer.
|
168 |
+
| Otherwise, will use the Hugging Face defaults:
|
169 |
+
| https://huggingface.co/docs/transformers/main_classes/trainer#transformers.TrainingArguments
|
170 |
+
| Note: Hyperparameter tuning is highly recommended, rather than using defaults.
|
171 |
+
ray_config : None, dict
|
172 |
+
| Training argument ranges for tuning hyperparameters with Ray.
|
173 |
+
freeze_layers : int
|
174 |
+
| Number of layers to freeze from fine-tuning.
|
175 |
+
| 0: no layers will be frozen; 2: first two layers will be frozen; etc.
|
176 |
+
num_crossval_splits : {0, 1, 5}
|
177 |
+
| 0: train on all data without splitting
|
178 |
+
| 1: split data into train and eval sets by designated split_sizes["valid"]
|
179 |
+
| 5: split data into 5 folds of train and eval sets by designated split_sizes["valid"]
|
180 |
+
split_sizes : None, dict
|
181 |
+
| Dictionary of proportion of data to hold out for train, validation, and test sets
|
182 |
+
| {"train": 0.8, "valid": 0.1, "test": 0.1} if intending 80/10/10 train/valid/test split
|
183 |
+
stratify_splits_col : None, str
|
184 |
+
| Name of column in .dataset to be used for stratified splitting.
|
185 |
+
| Proportion of each class in this column will be the same in the splits as in the original dataset.
|
186 |
+
no_eval : bool
|
187 |
+
| If True, will skip eval step and use all data for training.
|
188 |
+
| Otherwise, will perform eval during training.
|
189 |
+
forward_batch_size : int
|
190 |
+
| Batch size for forward pass (for evaluation, not training).
|
191 |
+
token_dictionary_file : None, str
|
192 |
+
| Default is to use token dictionary file from Geneformer
|
193 |
+
| Otherwise, will load custom gene token dictionary.
|
194 |
+
nproc : int
|
195 |
+
| Number of CPU processes to use.
|
196 |
+
ngpu : int
|
197 |
+
| Number of GPUs available.
|
198 |
+
|
199 |
+
"""
|
200 |
+
|
201 |
+
self.classifier = classifier
|
202 |
+
if self.classifier == "cell":
|
203 |
+
self.model_type = "CellClassifier"
|
204 |
+
elif self.classifier == "gene":
|
205 |
+
self.model_type = "GeneClassifier"
|
206 |
+
self.quantize = quantize
|
207 |
+
self.cell_state_dict = cell_state_dict
|
208 |
+
self.gene_class_dict = gene_class_dict
|
209 |
+
self.filter_data = filter_data
|
210 |
+
self.rare_threshold = rare_threshold
|
211 |
+
self.max_ncells = max_ncells
|
212 |
+
self.max_ncells_per_class = max_ncells_per_class
|
213 |
+
self.training_args = training_args
|
214 |
+
self.ray_config = ray_config
|
215 |
+
self.freeze_layers = freeze_layers
|
216 |
+
self.num_crossval_splits = num_crossval_splits
|
217 |
+
self.split_sizes = split_sizes
|
218 |
+
self.train_size = self.split_sizes["train"]
|
219 |
+
self.valid_size = self.split_sizes["valid"]
|
220 |
+
self.oos_test_size = self.split_sizes["test"]
|
221 |
+
self.eval_size = self.valid_size / (self.train_size + self.valid_size)
|
222 |
+
self.stratify_splits_col = stratify_splits_col
|
223 |
+
self.no_eval = no_eval
|
224 |
+
self.forward_batch_size = forward_batch_size
|
225 |
+
self.token_dictionary_file = token_dictionary_file
|
226 |
+
self.nproc = nproc
|
227 |
+
self.ngpu = ngpu
|
228 |
+
|
229 |
+
if self.training_args is None:
|
230 |
+
logger.warning(
|
231 |
+
"Hyperparameter tuning is highly recommended for optimal results. "
|
232 |
+
"No training_args provided; using default hyperparameters."
|
233 |
+
)
|
234 |
+
|
235 |
+
self.validate_options()
|
236 |
+
|
237 |
+
if self.filter_data is None:
|
238 |
+
self.filter_data = dict()
|
239 |
+
|
240 |
+
if self.classifier == "cell":
|
241 |
+
if self.cell_state_dict["states"] != "all":
|
242 |
+
self.filter_data[
|
243 |
+
self.cell_state_dict["state_key"]
|
244 |
+
] = self.cell_state_dict["states"]
|
245 |
+
|
246 |
+
# load token dictionary (Ensembl IDs:token)
|
247 |
+
if self.token_dictionary_file is None:
|
248 |
+
self.token_dictionary_file = TOKEN_DICTIONARY_FILE
|
249 |
+
with open(self.token_dictionary_file, "rb") as f:
|
250 |
+
self.gene_token_dict = pickle.load(f)
|
251 |
+
|
252 |
+
self.token_gene_dict = {v: k for k, v in self.gene_token_dict.items()}
|
253 |
+
|
254 |
+
# filter genes for gene classification for those in token dictionary
|
255 |
+
if self.classifier == "gene":
|
256 |
+
all_gene_class_values = set(pu.flatten_list(self.gene_class_dict.values()))
|
257 |
+
missing_genes = [
|
258 |
+
gene
|
259 |
+
for gene in all_gene_class_values
|
260 |
+
if gene not in self.gene_token_dict.keys()
|
261 |
+
]
|
262 |
+
if len(missing_genes) == len(all_gene_class_values):
|
263 |
+
logger.error(
|
264 |
+
"None of the provided genes to classify are in token dictionary."
|
265 |
+
)
|
266 |
+
raise
|
267 |
+
elif len(missing_genes) > 0:
|
268 |
+
logger.warning(
|
269 |
+
f"Genes to classify {missing_genes} are not in token dictionary."
|
270 |
+
)
|
271 |
+
self.gene_class_dict = {
|
272 |
+
k: list(set([self.gene_token_dict.get(gene) for gene in v]))
|
273 |
+
for k, v in self.gene_class_dict.items()
|
274 |
+
}
|
275 |
+
empty_classes = []
|
276 |
+
for k, v in self.gene_class_dict.items():
|
277 |
+
if len(v) == 0:
|
278 |
+
empty_classes += [k]
|
279 |
+
if len(empty_classes) > 0:
|
280 |
+
logger.error(
|
281 |
+
f"Class(es) {empty_classes} did not contain any genes in the token dictionary."
|
282 |
+
)
|
283 |
+
raise
|
284 |
+
|
285 |
+
def validate_options(self):
|
286 |
+
# confirm arguments are within valid options and compatible with each other
|
287 |
+
for attr_name, valid_options in self.valid_option_dict.items():
|
288 |
+
attr_value = self.__dict__[attr_name]
|
289 |
+
if not isinstance(attr_value, (list, dict)):
|
290 |
+
if attr_value in valid_options:
|
291 |
+
continue
|
292 |
+
valid_type = False
|
293 |
+
for option in valid_options:
|
294 |
+
if (option in [int, float, list, dict, bool, str]) and isinstance(
|
295 |
+
attr_value, option
|
296 |
+
):
|
297 |
+
valid_type = True
|
298 |
+
break
|
299 |
+
if valid_type:
|
300 |
+
continue
|
301 |
+
logger.error(
|
302 |
+
f"Invalid option for {attr_name}. "
|
303 |
+
f"Valid options for {attr_name}: {valid_options}"
|
304 |
+
)
|
305 |
+
raise
|
306 |
+
|
307 |
+
if self.filter_data is not None:
|
308 |
+
for key, value in self.filter_data.items():
|
309 |
+
if not isinstance(value, list):
|
310 |
+
self.filter_data[key] = [value]
|
311 |
+
logger.warning(
|
312 |
+
"Values in filter_data dict must be lists. "
|
313 |
+
f"Changing {key} value to list ([{value}])."
|
314 |
+
)
|
315 |
+
|
316 |
+
if self.classifier == "cell":
|
317 |
+
if set(self.cell_state_dict.keys()) != set(["state_key", "states"]):
|
318 |
+
logger.error(
|
319 |
+
"Invalid keys for cell_state_dict. "
|
320 |
+
"The cell_state_dict should have only 2 keys: state_key and states"
|
321 |
+
)
|
322 |
+
raise
|
323 |
+
|
324 |
+
if self.cell_state_dict["states"] != "all":
|
325 |
+
if not isinstance(self.cell_state_dict["states"], list):
|
326 |
+
logger.error(
|
327 |
+
"States in cell_state_dict should be list of states to model."
|
328 |
+
)
|
329 |
+
raise
|
330 |
+
if len(self.cell_state_dict["states"]) < 2:
|
331 |
+
logger.error(
|
332 |
+
"States in cell_state_dict should contain at least 2 states to classify."
|
333 |
+
)
|
334 |
+
raise
|
335 |
+
|
336 |
+
if self.classifier == "gene":
|
337 |
+
if len(self.gene_class_dict.keys()) < 2:
|
338 |
+
logger.error(
|
339 |
+
"Gene_class_dict should contain at least 2 gene classes to classify."
|
340 |
+
)
|
341 |
+
raise
|
342 |
+
if sum(self.split_sizes.values()) != 1:
|
343 |
+
logger.error("Train, validation, and test proportions should sum to 1.")
|
344 |
+
raise
|
345 |
+
|
346 |
+
def prepare_data(
|
347 |
+
self,
|
348 |
+
input_data_file,
|
349 |
+
output_directory,
|
350 |
+
output_prefix,
|
351 |
+
split_id_dict=None,
|
352 |
+
test_size=None,
|
353 |
+
attr_to_split=None,
|
354 |
+
attr_to_balance=None,
|
355 |
+
max_trials=100,
|
356 |
+
pval_threshold=0.1,
|
357 |
+
):
|
358 |
+
"""
|
359 |
+
Prepare data for cell state or gene classification.
|
360 |
+
|
361 |
+
**Parameters**
|
362 |
+
|
363 |
+
input_data_file : Path
|
364 |
+
| Path to directory containing .dataset input
|
365 |
+
output_directory : Path
|
366 |
+
| Path to directory where prepared data will be saved
|
367 |
+
output_prefix : str
|
368 |
+
| Prefix for output file
|
369 |
+
split_id_dict : None, dict
|
370 |
+
| Dictionary of IDs for train and test splits
|
371 |
+
| Three-item dictionary with keys: attr_key, train, test
|
372 |
+
| attr_key: key specifying name of column in .dataset that contains the IDs for the data splits
|
373 |
+
| train: list of IDs in the attr_key column to include in the train split
|
374 |
+
| test: list of IDs in the attr_key column to include in the test split
|
375 |
+
| For example: {"attr_key": "individual",
|
376 |
+
| "train": ["patient1", "patient2", "patient3", "patient4"],
|
377 |
+
| "test": ["patient5", "patient6"]}
|
378 |
+
test_size : None, float
|
379 |
+
| Proportion of data to be saved separately and held out for test set
|
380 |
+
| (e.g. 0.2 if intending hold out 20%)
|
381 |
+
| If None, will inherit from split_sizes["test"] from Classifier
|
382 |
+
| The training set will be further split to train / validation in self.validate
|
383 |
+
| Note: only available for CellClassifiers
|
384 |
+
attr_to_split : None, str
|
385 |
+
| Key for attribute on which to split data while balancing potential confounders
|
386 |
+
| e.g. "patient_id" for splitting by patient while balancing other characteristics
|
387 |
+
| Note: only available for CellClassifiers
|
388 |
+
attr_to_balance : None, list
|
389 |
+
| List of attribute keys on which to balance data while splitting on attr_to_split
|
390 |
+
| e.g. ["age", "sex"] for balancing these characteristics while splitting by patient
|
391 |
+
| Note: only available for CellClassifiers
|
392 |
+
max_trials : None, int
|
393 |
+
| Maximum number of trials of random splitting to try to achieve balanced other attributes
|
394 |
+
| If no split is found without significant (p<0.05) differences in other attributes, will select best
|
395 |
+
| Note: only available for CellClassifiers
|
396 |
+
pval_threshold : None, float
|
397 |
+
| P-value threshold to use for attribute balancing across splits
|
398 |
+
| E.g. if set to 0.1, will accept trial if p >= 0.1 for all attributes in attr_to_balance
|
399 |
+
"""
|
400 |
+
|
401 |
+
if test_size is None:
|
402 |
+
test_size = self.oos_test_size
|
403 |
+
|
404 |
+
# prepare data and labels for classification
|
405 |
+
data = pu.load_and_filter(self.filter_data, self.nproc, input_data_file)
|
406 |
+
|
407 |
+
if self.classifier == "cell":
|
408 |
+
if "label" in data.features:
|
409 |
+
logger.error(
|
410 |
+
"Column name 'label' must be reserved for class IDs. Please rename column."
|
411 |
+
)
|
412 |
+
raise
|
413 |
+
elif self.classifier == "gene":
|
414 |
+
if "labels" in data.features:
|
415 |
+
logger.error(
|
416 |
+
"Column name 'labels' must be reserved for class IDs. Please rename column."
|
417 |
+
)
|
418 |
+
raise
|
419 |
+
|
420 |
+
if (attr_to_split is not None) and (attr_to_balance is None):
|
421 |
+
logger.error(
|
422 |
+
"Splitting by attribute while balancing confounders requires both attr_to_split and attr_to_balance to be defined."
|
423 |
+
)
|
424 |
+
raise
|
425 |
+
|
426 |
+
if not isinstance(attr_to_balance, list):
|
427 |
+
attr_to_balance = [attr_to_balance]
|
428 |
+
|
429 |
+
if self.classifier == "cell":
|
430 |
+
# remove cell states representing < rare_threshold of cells
|
431 |
+
data = cu.remove_rare(
|
432 |
+
data, self.rare_threshold, self.cell_state_dict["state_key"], self.nproc
|
433 |
+
)
|
434 |
+
# downsample max cells and max per class
|
435 |
+
data = cu.downsample_and_shuffle(
|
436 |
+
data, self.max_ncells, self.max_ncells_per_class, self.cell_state_dict
|
437 |
+
)
|
438 |
+
# rename cell state column to "label"
|
439 |
+
data = cu.rename_cols(data, self.cell_state_dict["state_key"])
|
440 |
+
|
441 |
+
# convert classes to numerical labels and save as id_class_dict
|
442 |
+
# of note, will label all genes in gene_class_dict
|
443 |
+
# if (cross-)validating, genes will be relabeled in column "labels" for each split
|
444 |
+
# at the time of training with Classifier.validate
|
445 |
+
data, id_class_dict = cu.label_classes(
|
446 |
+
self.classifier, data, self.gene_class_dict, self.nproc
|
447 |
+
)
|
448 |
+
|
449 |
+
# save id_class_dict for future reference
|
450 |
+
id_class_output_path = (
|
451 |
+
Path(output_directory) / f"{output_prefix}_id_class_dict"
|
452 |
+
).with_suffix(".pkl")
|
453 |
+
with open(id_class_output_path, "wb") as f:
|
454 |
+
pickle.dump(id_class_dict, f)
|
455 |
+
|
456 |
+
if split_id_dict is not None:
|
457 |
+
data_dict = dict()
|
458 |
+
data_dict["train"] = pu.filter_by_dict(
|
459 |
+
data, {split_id_dict["attr_key"]: split_id_dict["train"]}, self.nproc
|
460 |
+
)
|
461 |
+
data_dict["test"] = pu.filter_by_dict(
|
462 |
+
data, {split_id_dict["attr_key"]: split_id_dict["test"]}, self.nproc
|
463 |
+
)
|
464 |
+
train_data_output_path = (
|
465 |
+
Path(output_directory) / f"{output_prefix}_labeled_train"
|
466 |
+
).with_suffix(".dataset")
|
467 |
+
test_data_output_path = (
|
468 |
+
Path(output_directory) / f"{output_prefix}_labeled_test"
|
469 |
+
).with_suffix(".dataset")
|
470 |
+
data_dict["train"].save_to_disk(str(train_data_output_path))
|
471 |
+
data_dict["test"].save_to_disk(str(test_data_output_path))
|
472 |
+
elif (test_size is not None) and (self.classifier == "cell"):
|
473 |
+
if 1 > test_size > 0:
|
474 |
+
if attr_to_split is None:
|
475 |
+
data_dict = data.train_test_split(
|
476 |
+
test_size=test_size,
|
477 |
+
stratify_by_column=self.stratify_splits_col,
|
478 |
+
seed=42,
|
479 |
+
)
|
480 |
+
train_data_output_path = (
|
481 |
+
Path(output_directory) / f"{output_prefix}_labeled_train"
|
482 |
+
).with_suffix(".dataset")
|
483 |
+
test_data_output_path = (
|
484 |
+
Path(output_directory) / f"{output_prefix}_labeled_test"
|
485 |
+
).with_suffix(".dataset")
|
486 |
+
data_dict["train"].save_to_disk(str(train_data_output_path))
|
487 |
+
data_dict["test"].save_to_disk(str(test_data_output_path))
|
488 |
+
else:
|
489 |
+
data_dict, balance_df = cu.balance_attr_splits(
|
490 |
+
data,
|
491 |
+
attr_to_split,
|
492 |
+
attr_to_balance,
|
493 |
+
test_size,
|
494 |
+
max_trials,
|
495 |
+
pval_threshold,
|
496 |
+
self.cell_state_dict["state_key"],
|
497 |
+
self.nproc,
|
498 |
+
)
|
499 |
+
balance_df.to_csv(
|
500 |
+
f"{output_directory}/{output_prefix}_train_test_balance_df.csv"
|
501 |
+
)
|
502 |
+
train_data_output_path = (
|
503 |
+
Path(output_directory) / f"{output_prefix}_labeled_train"
|
504 |
+
).with_suffix(".dataset")
|
505 |
+
test_data_output_path = (
|
506 |
+
Path(output_directory) / f"{output_prefix}_labeled_test"
|
507 |
+
).with_suffix(".dataset")
|
508 |
+
data_dict["train"].save_to_disk(str(train_data_output_path))
|
509 |
+
data_dict["test"].save_to_disk(str(test_data_output_path))
|
510 |
+
else:
|
511 |
+
data_output_path = (
|
512 |
+
Path(output_directory) / f"{output_prefix}_labeled"
|
513 |
+
).with_suffix(".dataset")
|
514 |
+
data.save_to_disk(str(data_output_path))
|
515 |
+
print(data_output_path)
|
516 |
+
else:
|
517 |
+
data_output_path = (
|
518 |
+
Path(output_directory) / f"{output_prefix}_labeled"
|
519 |
+
).with_suffix(".dataset")
|
520 |
+
data.save_to_disk(str(data_output_path))
|
521 |
+
|
522 |
+
def train_all_data(
|
523 |
+
self,
|
524 |
+
model_directory,
|
525 |
+
prepared_input_data_file,
|
526 |
+
id_class_dict_file,
|
527 |
+
output_directory,
|
528 |
+
output_prefix,
|
529 |
+
save_eval_output=True,
|
530 |
+
gene_balance=False,
|
531 |
+
):
|
532 |
+
"""
|
533 |
+
Train cell state or gene classifier using all data.
|
534 |
+
|
535 |
+
**Parameters**
|
536 |
+
|
537 |
+
model_directory : Path
|
538 |
+
| Path to directory containing model
|
539 |
+
prepared_input_data_file : Path
|
540 |
+
| Path to directory containing _labeled.dataset previously prepared by Classifier.prepare_data
|
541 |
+
id_class_dict_file : Path
|
542 |
+
| Path to _id_class_dict.pkl previously prepared by Classifier.prepare_data
|
543 |
+
| (dictionary of format: numerical IDs: class_labels)
|
544 |
+
output_directory : Path
|
545 |
+
| Path to directory where model and eval data will be saved
|
546 |
+
output_prefix : str
|
547 |
+
| Prefix for output files
|
548 |
+
save_eval_output : bool
|
549 |
+
| Whether to save cross-fold eval output
|
550 |
+
| Saves as pickle file of dictionary of eval metrics
|
551 |
+
gene_balance : None, bool
|
552 |
+
| Whether to automatically balance genes in training set.
|
553 |
+
| Only available for binary gene classifications.
|
554 |
+
|
555 |
+
**Output**
|
556 |
+
|
557 |
+
Returns trainer after fine-tuning with all data.
|
558 |
+
|
559 |
+
"""
|
560 |
+
|
561 |
+
if (gene_balance is True) and (len(self.gene_class_dict.values()) != 2):
|
562 |
+
logger.error(
|
563 |
+
"Automatically balancing gene sets for training is only available for binary gene classifications."
|
564 |
+
)
|
565 |
+
raise
|
566 |
+
|
567 |
+
##### Load data and prepare output directory #####
|
568 |
+
# load numerical id to class dictionary (id:class)
|
569 |
+
with open(id_class_dict_file, "rb") as f:
|
570 |
+
id_class_dict = pickle.load(f)
|
571 |
+
class_id_dict = {v: k for k, v in id_class_dict.items()}
|
572 |
+
|
573 |
+
# load previously filtered and prepared data
|
574 |
+
data = pu.load_and_filter(None, self.nproc, prepared_input_data_file)
|
575 |
+
data = data.shuffle(seed=42) # reshuffle in case users provide unshuffled data
|
576 |
+
|
577 |
+
# define output directory path
|
578 |
+
current_date = datetime.datetime.now()
|
579 |
+
datestamp = f"{str(current_date.year)[-2:]}{current_date.month:02d}{current_date.day:02d}"
|
580 |
+
if output_directory[-1:] != "/": # add slash for dir if not present
|
581 |
+
output_directory = output_directory + "/"
|
582 |
+
output_dir = f"{output_directory}{datestamp}_geneformer_{self.classifier}Classifier_{output_prefix}/"
|
583 |
+
subprocess.call(f"mkdir {output_dir}", shell=True)
|
584 |
+
|
585 |
+
# get number of classes for classifier
|
586 |
+
num_classes = cu.get_num_classes(id_class_dict)
|
587 |
+
|
588 |
+
if self.classifier == "gene":
|
589 |
+
targets = pu.flatten_list(self.gene_class_dict.values())
|
590 |
+
labels = pu.flatten_list(
|
591 |
+
[
|
592 |
+
[class_id_dict[label]] * len(targets)
|
593 |
+
for label, targets in self.gene_class_dict.items()
|
594 |
+
]
|
595 |
+
)
|
596 |
+
assert len(targets) == len(labels)
|
597 |
+
data = cu.prep_gene_classifier_all_data(
|
598 |
+
data, targets, labels, self.max_ncells, self.nproc, gene_balance
|
599 |
+
)
|
600 |
+
|
601 |
+
trainer = self.train_classifier(
|
602 |
+
model_directory, num_classes, data, None, output_dir
|
603 |
+
)
|
604 |
+
|
605 |
+
return trainer
|
606 |
+
|
607 |
+
def validate(
|
608 |
+
self,
|
609 |
+
model_directory,
|
610 |
+
prepared_input_data_file,
|
611 |
+
id_class_dict_file,
|
612 |
+
output_directory,
|
613 |
+
output_prefix,
|
614 |
+
split_id_dict=None,
|
615 |
+
attr_to_split=None,
|
616 |
+
attr_to_balance=None,
|
617 |
+
gene_balance=False,
|
618 |
+
max_trials=100,
|
619 |
+
pval_threshold=0.1,
|
620 |
+
save_eval_output=True,
|
621 |
+
predict_eval=True,
|
622 |
+
predict_trainer=False,
|
623 |
+
n_hyperopt_trials=0,
|
624 |
+
save_gene_split_datasets=True,
|
625 |
+
debug_gene_split_datasets=False,
|
626 |
+
):
|
627 |
+
"""
|
628 |
+
(Cross-)validate cell state or gene classifier.
|
629 |
+
|
630 |
+
**Parameters**
|
631 |
+
|
632 |
+
model_directory : Path
|
633 |
+
| Path to directory containing model
|
634 |
+
prepared_input_data_file : Path
|
635 |
+
| Path to directory containing _labeled.dataset previously prepared by Classifier.prepare_data
|
636 |
+
id_class_dict_file : Path
|
637 |
+
| Path to _id_class_dict.pkl previously prepared by Classifier.prepare_data
|
638 |
+
| (dictionary of format: numerical IDs: class_labels)
|
639 |
+
output_directory : Path
|
640 |
+
| Path to directory where model and eval data will be saved
|
641 |
+
output_prefix : str
|
642 |
+
| Prefix for output files
|
643 |
+
split_id_dict : None, dict
|
644 |
+
| Dictionary of IDs for train and eval splits
|
645 |
+
| Three-item dictionary with keys: attr_key, train, eval
|
646 |
+
| attr_key: key specifying name of column in .dataset that contains the IDs for the data splits
|
647 |
+
| train: list of IDs in the attr_key column to include in the train split
|
648 |
+
| eval: list of IDs in the attr_key column to include in the eval split
|
649 |
+
| For example: {"attr_key": "individual",
|
650 |
+
| "train": ["patient1", "patient2", "patient3", "patient4"],
|
651 |
+
| "eval": ["patient5", "patient6"]}
|
652 |
+
| Note: only available for CellClassifiers with 1-fold split (self.classifier="cell"; self.num_crossval_splits=1)
|
653 |
+
attr_to_split : None, str
|
654 |
+
| Key for attribute on which to split data while balancing potential confounders
|
655 |
+
| e.g. "patient_id" for splitting by patient while balancing other characteristics
|
656 |
+
| Note: only available for CellClassifiers with 1-fold split (self.classifier="cell"; self.num_crossval_splits=1)
|
657 |
+
attr_to_balance : None, list
|
658 |
+
| List of attribute keys on which to balance data while splitting on attr_to_split
|
659 |
+
| e.g. ["age", "sex"] for balancing these characteristics while splitting by patient
|
660 |
+
gene_balance : None, bool
|
661 |
+
| Whether to automatically balance genes in training set.
|
662 |
+
| Only available for binary gene classifications.
|
663 |
+
max_trials : None, int
|
664 |
+
| Maximum number of trials of random splitting to try to achieve balanced other attribute
|
665 |
+
| If no split is found without significant (p < pval_threshold) differences in other attributes, will select best
|
666 |
+
pval_threshold : None, float
|
667 |
+
| P-value threshold to use for attribute balancing across splits
|
668 |
+
| E.g. if set to 0.1, will accept trial if p >= 0.1 for all attributes in attr_to_balance
|
669 |
+
save_eval_output : bool
|
670 |
+
| Whether to save cross-fold eval output
|
671 |
+
| Saves as pickle file of dictionary of eval metrics
|
672 |
+
predict_eval : bool
|
673 |
+
| Whether or not to save eval predictions
|
674 |
+
| Saves as a pickle file of self.evaluate predictions
|
675 |
+
predict_trainer : bool
|
676 |
+
| Whether or not to save eval predictions from trainer
|
677 |
+
| Saves as a pickle file of trainer predictions
|
678 |
+
n_hyperopt_trials : int
|
679 |
+
| Number of trials to run for hyperparameter optimization
|
680 |
+
| If 0, will not optimize hyperparameters
|
681 |
+
save_gene_split_datasets : bool
|
682 |
+
| Whether or not to save train, valid, and test gene-labeled datasets
|
683 |
+
"""
|
684 |
+
if self.num_crossval_splits == 0:
|
685 |
+
logger.error("num_crossval_splits must be 1 or 5 to validate.")
|
686 |
+
raise
|
687 |
+
|
688 |
+
if (gene_balance is True) and (len(self.gene_class_dict.values()) != 2):
|
689 |
+
logger.error(
|
690 |
+
"Automatically balancing gene sets for training is only available for binary gene classifications."
|
691 |
+
)
|
692 |
+
raise
|
693 |
+
|
694 |
+
# ensure number of genes in each class is > 5 if validating model
|
695 |
+
if self.classifier == "gene":
|
696 |
+
insuff_classes = [k for k, v in self.gene_class_dict.items() if len(v) < 5]
|
697 |
+
if (self.num_crossval_splits > 0) and (len(insuff_classes) > 0):
|
698 |
+
logger.error(
|
699 |
+
f"Insufficient # of members in class(es) {insuff_classes} to (cross-)validate."
|
700 |
+
)
|
701 |
+
raise
|
702 |
+
|
703 |
+
##### Load data and prepare output directory #####
|
704 |
+
# load numerical id to class dictionary (id:class)
|
705 |
+
with open(id_class_dict_file, "rb") as f:
|
706 |
+
id_class_dict = pickle.load(f)
|
707 |
+
class_id_dict = {v: k for k, v in id_class_dict.items()}
|
708 |
+
|
709 |
+
# load previously filtered and prepared data
|
710 |
+
data = pu.load_and_filter(None, self.nproc, prepared_input_data_file)
|
711 |
+
data = data.shuffle(seed=42) # reshuffle in case users provide unshuffled data
|
712 |
+
|
713 |
+
# define output directory path
|
714 |
+
current_date = datetime.datetime.now()
|
715 |
+
datestamp = f"{str(current_date.year)[-2:]}{current_date.month:02d}{current_date.day:02d}"
|
716 |
+
if output_directory[-1:] != "/": # add slash for dir if not present
|
717 |
+
output_directory = output_directory + "/"
|
718 |
+
output_dir = f"{output_directory}{datestamp}_geneformer_{self.classifier}Classifier_{output_prefix}/"
|
719 |
+
subprocess.call(f"mkdir {output_dir}", shell=True)
|
720 |
+
|
721 |
+
# get number of classes for classifier
|
722 |
+
num_classes = cu.get_num_classes(id_class_dict)
|
723 |
+
|
724 |
+
##### (Cross-)validate the model #####
|
725 |
+
results = []
|
726 |
+
all_conf_mat = np.zeros((num_classes, num_classes))
|
727 |
+
iteration_num = 1
|
728 |
+
if self.classifier == "cell":
|
729 |
+
for i in trange(self.num_crossval_splits):
|
730 |
+
print(
|
731 |
+
f"****** Validation split: {iteration_num}/{self.num_crossval_splits} ******\n"
|
732 |
+
)
|
733 |
+
ksplit_output_dir = os.path.join(output_dir, f"ksplit{iteration_num}")
|
734 |
+
if self.num_crossval_splits == 1:
|
735 |
+
# single 1-eval_size:eval_size split
|
736 |
+
if split_id_dict is not None:
|
737 |
+
data_dict = dict()
|
738 |
+
data_dict["train"] = pu.filter_by_dict(
|
739 |
+
data,
|
740 |
+
{split_id_dict["attr_key"]: split_id_dict["train"]},
|
741 |
+
self.nproc,
|
742 |
+
)
|
743 |
+
data_dict["test"] = pu.filter_by_dict(
|
744 |
+
data,
|
745 |
+
{split_id_dict["attr_key"]: split_id_dict["eval"]},
|
746 |
+
self.nproc,
|
747 |
+
)
|
748 |
+
elif attr_to_split is not None:
|
749 |
+
data_dict, balance_df = cu.balance_attr_splits(
|
750 |
+
data,
|
751 |
+
attr_to_split,
|
752 |
+
attr_to_balance,
|
753 |
+
self.eval_size,
|
754 |
+
max_trials,
|
755 |
+
pval_threshold,
|
756 |
+
self.cell_state_dict["state_key"],
|
757 |
+
self.nproc,
|
758 |
+
)
|
759 |
+
|
760 |
+
balance_df.to_csv(
|
761 |
+
f"{output_dir}/{output_prefix}_train_valid_balance_df.csv"
|
762 |
+
)
|
763 |
+
else:
|
764 |
+
data_dict = data.train_test_split(
|
765 |
+
test_size=self.eval_size,
|
766 |
+
stratify_by_column=self.stratify_splits_col,
|
767 |
+
seed=42,
|
768 |
+
)
|
769 |
+
train_data = data_dict["train"]
|
770 |
+
eval_data = data_dict["test"]
|
771 |
+
else:
|
772 |
+
# 5-fold cross-validate
|
773 |
+
num_cells = len(data)
|
774 |
+
fifth_cells = int(np.floor(num_cells * 0.2))
|
775 |
+
num_eval = min((self.eval_size * num_cells), fifth_cells)
|
776 |
+
start = i * fifth_cells
|
777 |
+
end = start + num_eval
|
778 |
+
eval_indices = [j for j in range(start, end)]
|
779 |
+
train_indices = [
|
780 |
+
j for j in range(num_cells) if j not in eval_indices
|
781 |
+
]
|
782 |
+
eval_data = data.select(eval_indices)
|
783 |
+
train_data = data.select(train_indices)
|
784 |
+
if n_hyperopt_trials == 0:
|
785 |
+
trainer = self.train_classifier(
|
786 |
+
model_directory,
|
787 |
+
num_classes,
|
788 |
+
train_data,
|
789 |
+
eval_data,
|
790 |
+
ksplit_output_dir,
|
791 |
+
predict_trainer,
|
792 |
+
)
|
793 |
+
else:
|
794 |
+
trainer = self.hyperopt_classifier(
|
795 |
+
model_directory,
|
796 |
+
num_classes,
|
797 |
+
train_data,
|
798 |
+
eval_data,
|
799 |
+
ksplit_output_dir,
|
800 |
+
n_trials=n_hyperopt_trials,
|
801 |
+
)
|
802 |
+
if iteration_num == self.num_crossval_splits:
|
803 |
+
return
|
804 |
+
else:
|
805 |
+
iteration_num = iteration_num + 1
|
806 |
+
continue
|
807 |
+
|
808 |
+
result = self.evaluate_model(
|
809 |
+
trainer.model,
|
810 |
+
num_classes,
|
811 |
+
id_class_dict,
|
812 |
+
eval_data,
|
813 |
+
predict_eval,
|
814 |
+
ksplit_output_dir,
|
815 |
+
output_prefix,
|
816 |
+
)
|
817 |
+
results += [result]
|
818 |
+
all_conf_mat = all_conf_mat + result["conf_mat"]
|
819 |
+
iteration_num = iteration_num + 1
|
820 |
+
|
821 |
+
elif self.classifier == "gene":
|
822 |
+
# set up (cross-)validation splits
|
823 |
+
targets = pu.flatten_list(self.gene_class_dict.values())
|
824 |
+
labels = pu.flatten_list(
|
825 |
+
[
|
826 |
+
[class_id_dict[label]] * len(targets)
|
827 |
+
for label, targets in self.gene_class_dict.items()
|
828 |
+
]
|
829 |
+
)
|
830 |
+
assert len(targets) == len(labels)
|
831 |
+
n_splits = int(1 / (1 - self.train_size))
|
832 |
+
skf = cu.StratifiedKFold3(n_splits=n_splits, random_state=0, shuffle=True)
|
833 |
+
# (Cross-)validate
|
834 |
+
test_ratio = self.oos_test_size / (self.eval_size + self.oos_test_size)
|
835 |
+
for train_index, eval_index, test_index in tqdm(
|
836 |
+
skf.split(targets, labels, test_ratio)
|
837 |
+
):
|
838 |
+
print(
|
839 |
+
f"****** Validation split: {iteration_num}/{self.num_crossval_splits} ******\n"
|
840 |
+
)
|
841 |
+
ksplit_output_dir = os.path.join(output_dir, f"ksplit{iteration_num}")
|
842 |
+
# filter data for examples containing classes for this split
|
843 |
+
# subsample to max_ncells and relabel data in column "labels"
|
844 |
+
train_data, eval_data = cu.prep_gene_classifier_train_eval_split(
|
845 |
+
data,
|
846 |
+
targets,
|
847 |
+
labels,
|
848 |
+
train_index,
|
849 |
+
eval_index,
|
850 |
+
self.max_ncells,
|
851 |
+
iteration_num,
|
852 |
+
self.nproc,
|
853 |
+
gene_balance,
|
854 |
+
)
|
855 |
+
|
856 |
+
if save_gene_split_datasets is True:
|
857 |
+
for split_name in ["train", "valid"]:
|
858 |
+
labeled_dataset_output_path = (
|
859 |
+
Path(output_dir)
|
860 |
+
/ f"{output_prefix}_{split_name}_gene_labeled_ksplit{iteration_num}"
|
861 |
+
).with_suffix(".dataset")
|
862 |
+
if split_name == "train":
|
863 |
+
train_data.save_to_disk(str(labeled_dataset_output_path))
|
864 |
+
elif split_name == "valid":
|
865 |
+
eval_data.save_to_disk(str(labeled_dataset_output_path))
|
866 |
+
|
867 |
+
if self.oos_test_size > 0:
|
868 |
+
test_data = cu.prep_gene_classifier_split(
|
869 |
+
data,
|
870 |
+
targets,
|
871 |
+
labels,
|
872 |
+
test_index,
|
873 |
+
"test",
|
874 |
+
self.max_ncells,
|
875 |
+
iteration_num,
|
876 |
+
self.nproc,
|
877 |
+
)
|
878 |
+
if save_gene_split_datasets is True:
|
879 |
+
test_labeled_dataset_output_path = (
|
880 |
+
Path(output_dir)
|
881 |
+
/ f"{output_prefix}_test_gene_labeled_ksplit{iteration_num}"
|
882 |
+
).with_suffix(".dataset")
|
883 |
+
test_data.save_to_disk(str(test_labeled_dataset_output_path))
|
884 |
+
if debug_gene_split_datasets is True:
|
885 |
+
logger.error(
|
886 |
+
"Exiting after saving gene split datasets given debug_gene_split_datasets = True."
|
887 |
+
)
|
888 |
+
raise
|
889 |
+
if n_hyperopt_trials == 0:
|
890 |
+
trainer = self.train_classifier(
|
891 |
+
model_directory,
|
892 |
+
num_classes,
|
893 |
+
train_data,
|
894 |
+
eval_data,
|
895 |
+
ksplit_output_dir,
|
896 |
+
predict_trainer,
|
897 |
+
)
|
898 |
+
result = self.evaluate_model(
|
899 |
+
trainer.model,
|
900 |
+
num_classes,
|
901 |
+
id_class_dict,
|
902 |
+
eval_data,
|
903 |
+
predict_eval,
|
904 |
+
ksplit_output_dir,
|
905 |
+
output_prefix,
|
906 |
+
)
|
907 |
+
else:
|
908 |
+
trainer = self.hyperopt_classifier(
|
909 |
+
model_directory,
|
910 |
+
num_classes,
|
911 |
+
train_data,
|
912 |
+
eval_data,
|
913 |
+
ksplit_output_dir,
|
914 |
+
n_trials=n_hyperopt_trials,
|
915 |
+
)
|
916 |
+
|
917 |
+
model = cu.load_best_model(
|
918 |
+
ksplit_output_dir, self.model_type, num_classes
|
919 |
+
)
|
920 |
+
|
921 |
+
if self.oos_test_size > 0:
|
922 |
+
result = self.evaluate_model(
|
923 |
+
model,
|
924 |
+
num_classes,
|
925 |
+
id_class_dict,
|
926 |
+
test_data,
|
927 |
+
predict_eval,
|
928 |
+
ksplit_output_dir,
|
929 |
+
output_prefix,
|
930 |
+
)
|
931 |
+
else:
|
932 |
+
if iteration_num == self.num_crossval_splits:
|
933 |
+
return
|
934 |
+
else:
|
935 |
+
iteration_num = iteration_num + 1
|
936 |
+
continue
|
937 |
+
results += [result]
|
938 |
+
all_conf_mat = all_conf_mat + result["conf_mat"]
|
939 |
+
# break after 1 or 5 splits, each with train/eval proportions dictated by eval_size
|
940 |
+
if iteration_num == self.num_crossval_splits:
|
941 |
+
break
|
942 |
+
iteration_num = iteration_num + 1
|
943 |
+
|
944 |
+
all_conf_mat_df = pd.DataFrame(
|
945 |
+
all_conf_mat, columns=id_class_dict.values(), index=id_class_dict.values()
|
946 |
+
)
|
947 |
+
all_metrics = {
|
948 |
+
"conf_matrix": all_conf_mat_df,
|
949 |
+
"macro_f1": [result["macro_f1"] for result in results],
|
950 |
+
"acc": [result["acc"] for result in results],
|
951 |
+
}
|
952 |
+
all_roc_metrics = None # roc metrics not reported for multiclass
|
953 |
+
if num_classes == 2:
|
954 |
+
mean_fpr = np.linspace(0, 1, 100)
|
955 |
+
all_tpr = [result["roc_metrics"]["interp_tpr"] for result in results]
|
956 |
+
all_roc_auc = [result["roc_metrics"]["auc"] for result in results]
|
957 |
+
all_tpr_wt = [result["roc_metrics"]["tpr_wt"] for result in results]
|
958 |
+
mean_tpr, roc_auc, roc_auc_sd = eu.get_cross_valid_roc_metrics(
|
959 |
+
all_tpr, all_roc_auc, all_tpr_wt
|
960 |
+
)
|
961 |
+
all_roc_metrics = {
|
962 |
+
"mean_tpr": mean_tpr,
|
963 |
+
"mean_fpr": mean_fpr,
|
964 |
+
"all_roc_auc": all_roc_auc,
|
965 |
+
"roc_auc": roc_auc,
|
966 |
+
"roc_auc_sd": roc_auc_sd,
|
967 |
+
}
|
968 |
+
all_metrics["all_roc_metrics"] = all_roc_metrics
|
969 |
+
if save_eval_output is True:
|
970 |
+
eval_metrics_output_path = (
|
971 |
+
Path(output_dir) / f"{output_prefix}_eval_metrics_dict"
|
972 |
+
).with_suffix(".pkl")
|
973 |
+
with open(eval_metrics_output_path, "wb") as f:
|
974 |
+
pickle.dump(all_metrics, f)
|
975 |
+
|
976 |
+
return all_metrics
|
977 |
+
|
978 |
+
def hyperopt_classifier(
|
979 |
+
self,
|
980 |
+
model_directory,
|
981 |
+
num_classes,
|
982 |
+
train_data,
|
983 |
+
eval_data,
|
984 |
+
output_directory,
|
985 |
+
n_trials=100,
|
986 |
+
):
|
987 |
+
"""
|
988 |
+
Fine-tune model for cell state or gene classification.
|
989 |
+
|
990 |
+
**Parameters**
|
991 |
+
|
992 |
+
model_directory : Path
|
993 |
+
| Path to directory containing model
|
994 |
+
num_classes : int
|
995 |
+
| Number of classes for classifier
|
996 |
+
train_data : Dataset
|
997 |
+
| Loaded training .dataset input
|
998 |
+
| For cell classifier, labels in column "label".
|
999 |
+
| For gene classifier, labels in column "labels".
|
1000 |
+
eval_data : None, Dataset
|
1001 |
+
| (Optional) Loaded evaluation .dataset input
|
1002 |
+
| For cell classifier, labels in column "label".
|
1003 |
+
| For gene classifier, labels in column "labels".
|
1004 |
+
output_directory : Path
|
1005 |
+
| Path to directory where fine-tuned model will be saved
|
1006 |
+
n_trials : int
|
1007 |
+
| Number of trials to run for hyperparameter optimization
|
1008 |
+
"""
|
1009 |
+
|
1010 |
+
# initiate runtime environment for raytune
|
1011 |
+
import ray
|
1012 |
+
from ray import tune
|
1013 |
+
from ray.tune.search.hyperopt import HyperOptSearch
|
1014 |
+
|
1015 |
+
ray.shutdown() # engage new ray session
|
1016 |
+
ray.init()
|
1017 |
+
|
1018 |
+
##### Validate and prepare data #####
|
1019 |
+
train_data, eval_data = cu.validate_and_clean_cols(
|
1020 |
+
train_data, eval_data, self.classifier
|
1021 |
+
)
|
1022 |
+
|
1023 |
+
if (self.no_eval is True) and (eval_data is not None):
|
1024 |
+
logger.warning(
|
1025 |
+
"no_eval set to True; hyperparameter optimization requires eval, proceeding with eval"
|
1026 |
+
)
|
1027 |
+
|
1028 |
+
# ensure not overwriting previously saved model
|
1029 |
+
saved_model_test = os.path.join(output_directory, "pytorch_model.bin")
|
1030 |
+
if os.path.isfile(saved_model_test) is True:
|
1031 |
+
logger.error("Model already saved to this designated output directory.")
|
1032 |
+
raise
|
1033 |
+
# make output directory
|
1034 |
+
subprocess.call(f"mkdir {output_directory}", shell=True)
|
1035 |
+
|
1036 |
+
##### Load model and training args #####
|
1037 |
+
model = pu.load_model(
|
1038 |
+
self.model_type,
|
1039 |
+
num_classes,
|
1040 |
+
model_directory,
|
1041 |
+
"train",
|
1042 |
+
quantize=self.quantize,
|
1043 |
+
)
|
1044 |
+
def_training_args, def_freeze_layers = cu.get_default_train_args(
|
1045 |
+
model, self.classifier, train_data, output_directory
|
1046 |
+
)
|
1047 |
+
del model
|
1048 |
+
|
1049 |
+
if self.training_args is not None:
|
1050 |
+
def_training_args.update(self.training_args)
|
1051 |
+
logging_steps = round(
|
1052 |
+
len(train_data) / def_training_args["per_device_train_batch_size"] / 10
|
1053 |
+
)
|
1054 |
+
def_training_args["logging_steps"] = logging_steps
|
1055 |
+
def_training_args["output_dir"] = output_directory
|
1056 |
+
if eval_data is None:
|
1057 |
+
def_training_args["evaluation_strategy"] = "no"
|
1058 |
+
def_training_args["load_best_model_at_end"] = False
|
1059 |
+
def_training_args.update(
|
1060 |
+
{"save_strategy": "epoch", "save_total_limit": 1}
|
1061 |
+
) # only save last model for each run
|
1062 |
+
training_args_init = TrainingArguments(**def_training_args)
|
1063 |
+
|
1064 |
+
##### Fine-tune the model #####
|
1065 |
+
# define the data collator
|
1066 |
+
if self.classifier == "cell":
|
1067 |
+
data_collator = DataCollatorForCellClassification(
|
1068 |
+
token_dictionary=self.gene_token_dict
|
1069 |
+
)
|
1070 |
+
elif self.classifier == "gene":
|
1071 |
+
data_collator = DataCollatorForGeneClassification(
|
1072 |
+
token_dictionary=self.gene_token_dict
|
1073 |
+
)
|
1074 |
+
|
1075 |
+
# define function to initiate model
|
1076 |
+
def model_init():
|
1077 |
+
model = pu.load_model(
|
1078 |
+
self.model_type,
|
1079 |
+
num_classes,
|
1080 |
+
model_directory,
|
1081 |
+
"train",
|
1082 |
+
quantize=self.quantize,
|
1083 |
+
)
|
1084 |
+
|
1085 |
+
if self.freeze_layers is not None:
|
1086 |
+
def_freeze_layers = self.freeze_layers
|
1087 |
+
|
1088 |
+
if def_freeze_layers > 0:
|
1089 |
+
modules_to_freeze = model.bert.encoder.layer[:def_freeze_layers]
|
1090 |
+
for module in modules_to_freeze:
|
1091 |
+
for param in module.parameters():
|
1092 |
+
param.requires_grad = False
|
1093 |
+
|
1094 |
+
if self.quantize is False:
|
1095 |
+
model = model.to("cuda:0")
|
1096 |
+
return model
|
1097 |
+
|
1098 |
+
# create the trainer
|
1099 |
+
trainer = Trainer(
|
1100 |
+
model_init=model_init,
|
1101 |
+
args=training_args_init,
|
1102 |
+
data_collator=data_collator,
|
1103 |
+
train_dataset=train_data,
|
1104 |
+
eval_dataset=eval_data,
|
1105 |
+
compute_metrics=cu.compute_metrics,
|
1106 |
+
)
|
1107 |
+
|
1108 |
+
# specify raytune hyperparameter search space
|
1109 |
+
if self.ray_config is None:
|
1110 |
+
logger.warning(
|
1111 |
+
"No ray_config provided. Proceeding with default, but ranges may need adjustment depending on model."
|
1112 |
+
)
|
1113 |
+
def_ray_config = {
|
1114 |
+
"num_train_epochs": tune.choice([1]),
|
1115 |
+
"learning_rate": tune.loguniform(1e-6, 1e-3),
|
1116 |
+
"weight_decay": tune.uniform(0.0, 0.3),
|
1117 |
+
"lr_scheduler_type": tune.choice(["linear", "cosine", "polynomial"]),
|
1118 |
+
"warmup_steps": tune.uniform(100, 2000),
|
1119 |
+
"seed": tune.uniform(0, 100),
|
1120 |
+
"per_device_train_batch_size": tune.choice(
|
1121 |
+
[def_training_args["per_device_train_batch_size"]]
|
1122 |
+
),
|
1123 |
+
}
|
1124 |
+
|
1125 |
+
hyperopt_search = HyperOptSearch(metric="eval_macro_f1", mode="max")
|
1126 |
+
|
1127 |
+
# optimize hyperparameters
|
1128 |
+
trainer.hyperparameter_search(
|
1129 |
+
direction="maximize",
|
1130 |
+
backend="ray",
|
1131 |
+
resources_per_trial={"cpu": int(self.nproc / self.ngpu), "gpu": 1},
|
1132 |
+
hp_space=lambda _: def_ray_config
|
1133 |
+
if self.ray_config is None
|
1134 |
+
else self.ray_config,
|
1135 |
+
search_alg=hyperopt_search,
|
1136 |
+
n_trials=n_trials, # number of trials
|
1137 |
+
progress_reporter=tune.CLIReporter(
|
1138 |
+
max_report_frequency=600,
|
1139 |
+
sort_by_metric=True,
|
1140 |
+
max_progress_rows=n_trials,
|
1141 |
+
mode="max",
|
1142 |
+
metric="eval_macro_f1",
|
1143 |
+
metric_columns=["loss", "eval_loss", "eval_accuracy", "eval_macro_f1"],
|
1144 |
+
),
|
1145 |
+
storage_path=output_directory,
|
1146 |
+
)
|
1147 |
+
|
1148 |
+
return trainer
|
1149 |
+
|
1150 |
+
def train_classifier(
|
1151 |
+
self,
|
1152 |
+
model_directory,
|
1153 |
+
num_classes,
|
1154 |
+
train_data,
|
1155 |
+
eval_data,
|
1156 |
+
output_directory,
|
1157 |
+
predict=False,
|
1158 |
+
):
|
1159 |
+
"""
|
1160 |
+
Fine-tune model for cell state or gene classification.
|
1161 |
+
|
1162 |
+
**Parameters**
|
1163 |
+
|
1164 |
+
model_directory : Path
|
1165 |
+
| Path to directory containing model
|
1166 |
+
num_classes : int
|
1167 |
+
| Number of classes for classifier
|
1168 |
+
train_data : Dataset
|
1169 |
+
| Loaded training .dataset input
|
1170 |
+
| For cell classifier, labels in column "label".
|
1171 |
+
| For gene classifier, labels in column "labels".
|
1172 |
+
eval_data : None, Dataset
|
1173 |
+
| (Optional) Loaded evaluation .dataset input
|
1174 |
+
| For cell classifier, labels in column "label".
|
1175 |
+
| For gene classifier, labels in column "labels".
|
1176 |
+
output_directory : Path
|
1177 |
+
| Path to directory where fine-tuned model will be saved
|
1178 |
+
predict : bool
|
1179 |
+
| Whether or not to save eval predictions from trainer
|
1180 |
+
"""
|
1181 |
+
|
1182 |
+
##### Validate and prepare data #####
|
1183 |
+
train_data, eval_data = cu.validate_and_clean_cols(
|
1184 |
+
train_data, eval_data, self.classifier
|
1185 |
+
)
|
1186 |
+
|
1187 |
+
if (self.no_eval is True) and (eval_data is not None):
|
1188 |
+
logger.warning(
|
1189 |
+
"no_eval set to True; model will be trained without evaluation."
|
1190 |
+
)
|
1191 |
+
eval_data = None
|
1192 |
+
|
1193 |
+
if (self.classifier == "gene") and (predict is True):
|
1194 |
+
logger.warning(
|
1195 |
+
"Predictions during training not currently available for gene classifiers; setting predict to False."
|
1196 |
+
)
|
1197 |
+
predict = False
|
1198 |
+
|
1199 |
+
# ensure not overwriting previously saved model
|
1200 |
+
saved_model_test = os.path.join(output_directory, "pytorch_model.bin")
|
1201 |
+
if os.path.isfile(saved_model_test) is True:
|
1202 |
+
logger.error("Model already saved to this designated output directory.")
|
1203 |
+
raise
|
1204 |
+
# make output directory
|
1205 |
+
subprocess.call(f"mkdir {output_directory}", shell=True)
|
1206 |
+
|
1207 |
+
##### Load model and training args #####
|
1208 |
+
model = pu.load_model(
|
1209 |
+
self.model_type,
|
1210 |
+
num_classes,
|
1211 |
+
model_directory,
|
1212 |
+
"train",
|
1213 |
+
quantize=self.quantize,
|
1214 |
+
)
|
1215 |
+
|
1216 |
+
def_training_args, def_freeze_layers = cu.get_default_train_args(
|
1217 |
+
model, self.classifier, train_data, output_directory
|
1218 |
+
)
|
1219 |
+
|
1220 |
+
if self.training_args is not None:
|
1221 |
+
def_training_args.update(self.training_args)
|
1222 |
+
logging_steps = round(
|
1223 |
+
len(train_data) / def_training_args["per_device_train_batch_size"] / 10
|
1224 |
+
)
|
1225 |
+
def_training_args["logging_steps"] = logging_steps
|
1226 |
+
def_training_args["output_dir"] = output_directory
|
1227 |
+
if eval_data is None:
|
1228 |
+
def_training_args["evaluation_strategy"] = "no"
|
1229 |
+
def_training_args["load_best_model_at_end"] = False
|
1230 |
+
training_args_init = TrainingArguments(**def_training_args)
|
1231 |
+
|
1232 |
+
if self.freeze_layers is not None:
|
1233 |
+
def_freeze_layers = self.freeze_layers
|
1234 |
+
|
1235 |
+
if def_freeze_layers > 0:
|
1236 |
+
modules_to_freeze = model.bert.encoder.layer[:def_freeze_layers]
|
1237 |
+
for module in modules_to_freeze:
|
1238 |
+
for param in module.parameters():
|
1239 |
+
param.requires_grad = False
|
1240 |
+
|
1241 |
+
##### Fine-tune the model #####
|
1242 |
+
# define the data collator
|
1243 |
+
if self.classifier == "cell":
|
1244 |
+
data_collator = DataCollatorForCellClassification(
|
1245 |
+
token_dictionary=self.gene_token_dict
|
1246 |
+
)
|
1247 |
+
elif self.classifier == "gene":
|
1248 |
+
data_collator = DataCollatorForGeneClassification(
|
1249 |
+
token_dictionary=self.gene_token_dict
|
1250 |
+
)
|
1251 |
+
|
1252 |
+
# create the trainer
|
1253 |
+
trainer = Trainer(
|
1254 |
+
model=model,
|
1255 |
+
args=training_args_init,
|
1256 |
+
data_collator=data_collator,
|
1257 |
+
train_dataset=train_data,
|
1258 |
+
eval_dataset=eval_data,
|
1259 |
+
compute_metrics=cu.compute_metrics,
|
1260 |
+
)
|
1261 |
+
|
1262 |
+
# train the classifier
|
1263 |
+
trainer.train()
|
1264 |
+
trainer.save_model(output_directory)
|
1265 |
+
if predict is True:
|
1266 |
+
# make eval predictions and save predictions and metrics
|
1267 |
+
predictions = trainer.predict(eval_data)
|
1268 |
+
prediction_output_path = f"{output_directory}/predictions.pkl"
|
1269 |
+
with open(prediction_output_path, "wb") as f:
|
1270 |
+
pickle.dump(predictions, f)
|
1271 |
+
trainer.save_metrics("eval", predictions.metrics)
|
1272 |
+
return trainer
|
1273 |
+
|
1274 |
+
def evaluate_model(
|
1275 |
+
self,
|
1276 |
+
model,
|
1277 |
+
num_classes,
|
1278 |
+
id_class_dict,
|
1279 |
+
eval_data,
|
1280 |
+
predict=False,
|
1281 |
+
output_directory=None,
|
1282 |
+
output_prefix=None,
|
1283 |
+
):
|
1284 |
+
"""
|
1285 |
+
Evaluate the fine-tuned model.
|
1286 |
+
|
1287 |
+
**Parameters**
|
1288 |
+
|
1289 |
+
model : nn.Module
|
1290 |
+
| Loaded fine-tuned model (e.g. trainer.model)
|
1291 |
+
num_classes : int
|
1292 |
+
| Number of classes for classifier
|
1293 |
+
id_class_dict : dict
|
1294 |
+
| Loaded _id_class_dict.pkl previously prepared by Classifier.prepare_data
|
1295 |
+
| (dictionary of format: numerical IDs: class_labels)
|
1296 |
+
eval_data : Dataset
|
1297 |
+
| Loaded evaluation .dataset input
|
1298 |
+
predict : bool
|
1299 |
+
| Whether or not to save eval predictions
|
1300 |
+
output_directory : Path
|
1301 |
+
| Path to directory where eval data will be saved
|
1302 |
+
output_prefix : str
|
1303 |
+
| Prefix for output files
|
1304 |
+
"""
|
1305 |
+
|
1306 |
+
##### Evaluate the model #####
|
1307 |
+
labels = id_class_dict.keys()
|
1308 |
+
y_pred, y_true, logits_list = eu.classifier_predict(
|
1309 |
+
model, self.classifier, eval_data, self.forward_batch_size
|
1310 |
+
)
|
1311 |
+
conf_mat, macro_f1, acc, roc_metrics = eu.get_metrics(
|
1312 |
+
y_pred, y_true, logits_list, num_classes, labels
|
1313 |
+
)
|
1314 |
+
if predict is True:
|
1315 |
+
pred_dict = {
|
1316 |
+
"pred_ids": y_pred,
|
1317 |
+
"label_ids": y_true,
|
1318 |
+
"predictions": logits_list,
|
1319 |
+
}
|
1320 |
+
pred_dict_output_path = (
|
1321 |
+
Path(output_directory) / f"{output_prefix}_pred_dict"
|
1322 |
+
).with_suffix(".pkl")
|
1323 |
+
with open(pred_dict_output_path, "wb") as f:
|
1324 |
+
pickle.dump(pred_dict, f)
|
1325 |
+
return {
|
1326 |
+
"conf_mat": conf_mat,
|
1327 |
+
"macro_f1": macro_f1,
|
1328 |
+
"acc": acc,
|
1329 |
+
"roc_metrics": roc_metrics,
|
1330 |
+
}
|
1331 |
+
|
1332 |
+
def evaluate_saved_model(
|
1333 |
+
self,
|
1334 |
+
model_directory,
|
1335 |
+
id_class_dict_file,
|
1336 |
+
test_data_file,
|
1337 |
+
output_directory,
|
1338 |
+
output_prefix,
|
1339 |
+
predict=True,
|
1340 |
+
):
|
1341 |
+
"""
|
1342 |
+
Evaluate the fine-tuned model.
|
1343 |
+
|
1344 |
+
**Parameters**
|
1345 |
+
|
1346 |
+
model_directory : Path
|
1347 |
+
| Path to directory containing model
|
1348 |
+
id_class_dict_file : Path
|
1349 |
+
| Path to _id_class_dict.pkl previously prepared by Classifier.prepare_data
|
1350 |
+
| (dictionary of format: numerical IDs: class_labels)
|
1351 |
+
test_data_file : Path
|
1352 |
+
| Path to directory containing test .dataset
|
1353 |
+
output_directory : Path
|
1354 |
+
| Path to directory where eval data will be saved
|
1355 |
+
output_prefix : str
|
1356 |
+
| Prefix for output files
|
1357 |
+
predict : bool
|
1358 |
+
| Whether or not to save eval predictions
|
1359 |
+
"""
|
1360 |
+
|
1361 |
+
# load numerical id to class dictionary (id:class)
|
1362 |
+
with open(id_class_dict_file, "rb") as f:
|
1363 |
+
id_class_dict = pickle.load(f)
|
1364 |
+
|
1365 |
+
# get number of classes for classifier
|
1366 |
+
num_classes = cu.get_num_classes(id_class_dict)
|
1367 |
+
|
1368 |
+
# load previously filtered and prepared data
|
1369 |
+
test_data = pu.load_and_filter(None, self.nproc, test_data_file)
|
1370 |
+
|
1371 |
+
# load previously fine-tuned model
|
1372 |
+
model = pu.load_model(
|
1373 |
+
self.model_type,
|
1374 |
+
num_classes,
|
1375 |
+
model_directory,
|
1376 |
+
"eval",
|
1377 |
+
quantize=self.quantize,
|
1378 |
+
)
|
1379 |
+
|
1380 |
+
# evaluate the model
|
1381 |
+
result = self.evaluate_model(
|
1382 |
+
model,
|
1383 |
+
num_classes,
|
1384 |
+
id_class_dict,
|
1385 |
+
test_data,
|
1386 |
+
predict=predict,
|
1387 |
+
output_directory=output_directory,
|
1388 |
+
output_prefix=output_prefix,
|
1389 |
+
)
|
1390 |
+
|
1391 |
+
all_conf_mat_df = pd.DataFrame(
|
1392 |
+
result["conf_mat"],
|
1393 |
+
columns=id_class_dict.values(),
|
1394 |
+
index=id_class_dict.values(),
|
1395 |
+
)
|
1396 |
+
all_metrics = {
|
1397 |
+
"conf_matrix": all_conf_mat_df,
|
1398 |
+
"macro_f1": result["macro_f1"],
|
1399 |
+
"acc": result["acc"],
|
1400 |
+
}
|
1401 |
+
all_roc_metrics = None # roc metrics not reported for multiclass
|
1402 |
+
|
1403 |
+
if num_classes == 2:
|
1404 |
+
mean_fpr = np.linspace(0, 1, 100)
|
1405 |
+
mean_tpr = result["roc_metrics"]["interp_tpr"]
|
1406 |
+
all_roc_auc = result["roc_metrics"]["auc"]
|
1407 |
+
all_roc_metrics = {
|
1408 |
+
"mean_tpr": mean_tpr,
|
1409 |
+
"mean_fpr": mean_fpr,
|
1410 |
+
"all_roc_auc": all_roc_auc,
|
1411 |
+
}
|
1412 |
+
all_metrics["all_roc_metrics"] = all_roc_metrics
|
1413 |
+
test_metrics_output_path = (
|
1414 |
+
Path(output_directory) / f"{output_prefix}_test_metrics_dict"
|
1415 |
+
).with_suffix(".pkl")
|
1416 |
+
with open(test_metrics_output_path, "wb") as f:
|
1417 |
+
pickle.dump(all_metrics, f)
|
1418 |
+
|
1419 |
+
return all_metrics
|
1420 |
+
|
1421 |
+
def plot_conf_mat(
|
1422 |
+
self,
|
1423 |
+
conf_mat_dict,
|
1424 |
+
output_directory,
|
1425 |
+
output_prefix,
|
1426 |
+
custom_class_order=None,
|
1427 |
+
):
|
1428 |
+
"""
|
1429 |
+
Plot confusion matrix results of evaluating the fine-tuned model.
|
1430 |
+
|
1431 |
+
**Parameters**
|
1432 |
+
|
1433 |
+
conf_mat_dict : dict
|
1434 |
+
| Dictionary of model_name : confusion_matrix_DataFrame
|
1435 |
+
| (all_metrics["conf_matrix"] from self.validate)
|
1436 |
+
output_directory : Path
|
1437 |
+
| Path to directory where plots will be saved
|
1438 |
+
output_prefix : str
|
1439 |
+
| Prefix for output file
|
1440 |
+
custom_class_order : None, list
|
1441 |
+
| List of classes in custom order for plots.
|
1442 |
+
| Same order will be used for all models.
|
1443 |
+
"""
|
1444 |
+
|
1445 |
+
for model_name in conf_mat_dict.keys():
|
1446 |
+
eu.plot_confusion_matrix(
|
1447 |
+
conf_mat_dict[model_name],
|
1448 |
+
model_name,
|
1449 |
+
output_directory,
|
1450 |
+
output_prefix,
|
1451 |
+
custom_class_order,
|
1452 |
+
)
|
1453 |
+
|
1454 |
+
def plot_roc(
|
1455 |
+
self,
|
1456 |
+
roc_metric_dict,
|
1457 |
+
model_style_dict,
|
1458 |
+
title,
|
1459 |
+
output_directory,
|
1460 |
+
output_prefix,
|
1461 |
+
):
|
1462 |
+
"""
|
1463 |
+
Plot ROC curve results of evaluating the fine-tuned model.
|
1464 |
+
|
1465 |
+
**Parameters**
|
1466 |
+
|
1467 |
+
roc_metric_dict : dict
|
1468 |
+
| Dictionary of model_name : roc_metrics
|
1469 |
+
| (all_metrics["all_roc_metrics"] from self.validate)
|
1470 |
+
model_style_dict : dict[dict]
|
1471 |
+
| Dictionary of model_name : dictionary of style_attribute : style
|
1472 |
+
| where style includes color and linestyle
|
1473 |
+
| e.g. {'Model_A': {'color': 'black', 'linestyle': '-'}, 'Model_B': ...}
|
1474 |
+
title : str
|
1475 |
+
| Title of plot (e.g. 'Dosage-sensitive vs -insensitive factors')
|
1476 |
+
output_directory : Path
|
1477 |
+
| Path to directory where plots will be saved
|
1478 |
+
output_prefix : str
|
1479 |
+
| Prefix for output file
|
1480 |
+
"""
|
1481 |
+
|
1482 |
+
eu.plot_ROC(
|
1483 |
+
roc_metric_dict, model_style_dict, title, output_directory, output_prefix
|
1484 |
+
)
|
1485 |
+
|
1486 |
+
def plot_predictions(
|
1487 |
+
self,
|
1488 |
+
predictions_file,
|
1489 |
+
id_class_dict_file,
|
1490 |
+
title,
|
1491 |
+
output_directory,
|
1492 |
+
output_prefix,
|
1493 |
+
custom_class_order=None,
|
1494 |
+
kwargs_dict=None,
|
1495 |
+
):
|
1496 |
+
"""
|
1497 |
+
Plot prediction results of evaluating the fine-tuned model.
|
1498 |
+
|
1499 |
+
**Parameters**
|
1500 |
+
|
1501 |
+
predictions_file : path
|
1502 |
+
| Path of model predictions output to plot
|
1503 |
+
| (saved output from self.validate if predict_eval=True)
|
1504 |
+
| (or saved output from self.evaluate_saved_model)
|
1505 |
+
id_class_dict_file : Path
|
1506 |
+
| Path to _id_class_dict.pkl previously prepared by Classifier.prepare_data
|
1507 |
+
| (dictionary of format: numerical IDs: class_labels)
|
1508 |
+
title : str
|
1509 |
+
| Title for legend containing class labels.
|
1510 |
+
output_directory : Path
|
1511 |
+
| Path to directory where plots will be saved
|
1512 |
+
output_prefix : str
|
1513 |
+
| Prefix for output file
|
1514 |
+
custom_class_order : None, list
|
1515 |
+
| List of classes in custom order for plots.
|
1516 |
+
| Same order will be used for all models.
|
1517 |
+
kwargs_dict : None, dict
|
1518 |
+
| Dictionary of kwargs to pass to plotting function.
|
1519 |
+
"""
|
1520 |
+
# load predictions
|
1521 |
+
with open(predictions_file, "rb") as f:
|
1522 |
+
predictions = pickle.load(f)
|
1523 |
+
|
1524 |
+
# load numerical id to class dictionary (id:class)
|
1525 |
+
with open(id_class_dict_file, "rb") as f:
|
1526 |
+
id_class_dict = pickle.load(f)
|
1527 |
+
|
1528 |
+
if isinstance(predictions, dict):
|
1529 |
+
if all(
|
1530 |
+
[
|
1531 |
+
key in predictions.keys()
|
1532 |
+
for key in ["pred_ids", "label_ids", "predictions"]
|
1533 |
+
]
|
1534 |
+
):
|
1535 |
+
# format is output from self.evaluate_saved_model
|
1536 |
+
predictions_logits = np.array(predictions["predictions"])
|
1537 |
+
true_ids = predictions["label_ids"]
|
1538 |
+
else:
|
1539 |
+
# format is output from self.validate if predict_eval=True
|
1540 |
+
predictions_logits = predictions.predictions
|
1541 |
+
true_ids = predictions.label_ids
|
1542 |
+
|
1543 |
+
num_classes = len(id_class_dict.keys())
|
1544 |
+
num_predict_classes = predictions_logits.shape[1]
|
1545 |
+
assert num_classes == num_predict_classes
|
1546 |
+
classes = id_class_dict.values()
|
1547 |
+
true_labels = [id_class_dict[idx] for idx in true_ids]
|
1548 |
+
predictions_df = pd.DataFrame(predictions_logits, columns=classes)
|
1549 |
+
if custom_class_order is not None:
|
1550 |
+
predictions_df = predictions_df.reindex(columns=custom_class_order)
|
1551 |
+
predictions_df["true"] = true_labels
|
1552 |
+
custom_dict = dict(zip(classes, [i for i in range(len(classes))]))
|
1553 |
+
if custom_class_order is not None:
|
1554 |
+
custom_dict = dict(
|
1555 |
+
zip(custom_class_order, [i for i in range(len(custom_class_order))])
|
1556 |
+
)
|
1557 |
+
predictions_df = predictions_df.sort_values(
|
1558 |
+
by=["true"], key=lambda x: x.map(custom_dict)
|
1559 |
+
)
|
1560 |
+
|
1561 |
+
eu.plot_predictions(
|
1562 |
+
predictions_df, title, output_directory, output_prefix, kwargs_dict
|
1563 |
+
)
|
geneformer/classifier_utils.py
ADDED
@@ -0,0 +1,648 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import json
|
2 |
+
import logging
|
3 |
+
import os
|
4 |
+
import random
|
5 |
+
from collections import Counter, defaultdict
|
6 |
+
|
7 |
+
import numpy as np
|
8 |
+
import pandas as pd
|
9 |
+
from scipy.stats import chisquare, ranksums
|
10 |
+
from sklearn.metrics import accuracy_score, f1_score
|
11 |
+
from sklearn.model_selection import StratifiedKFold, train_test_split
|
12 |
+
|
13 |
+
from . import perturber_utils as pu
|
14 |
+
|
15 |
+
logger = logging.getLogger(__name__)
|
16 |
+
|
17 |
+
|
18 |
+
def downsample_and_shuffle(data, max_ncells, max_ncells_per_class, cell_state_dict):
|
19 |
+
data = data.shuffle(seed=42)
|
20 |
+
num_cells = len(data)
|
21 |
+
# if max number of cells is defined, then subsample to this max number
|
22 |
+
if max_ncells is not None:
|
23 |
+
if num_cells > max_ncells:
|
24 |
+
data = data.select([i for i in range(max_ncells)])
|
25 |
+
if max_ncells_per_class is not None:
|
26 |
+
class_labels = data[cell_state_dict["state_key"]]
|
27 |
+
random.seed(42)
|
28 |
+
subsample_indices = subsample_by_class(class_labels, max_ncells_per_class)
|
29 |
+
data = data.select(subsample_indices)
|
30 |
+
return data
|
31 |
+
|
32 |
+
|
33 |
+
# subsample labels to maximum number N per class and return indices
|
34 |
+
def subsample_by_class(labels, N):
|
35 |
+
label_indices = defaultdict(list)
|
36 |
+
# Gather indices for each label
|
37 |
+
for idx, label in enumerate(labels):
|
38 |
+
label_indices[label].append(idx)
|
39 |
+
selected_indices = []
|
40 |
+
# Select up to N indices for each label
|
41 |
+
for label, indices in label_indices.items():
|
42 |
+
if len(indices) > N:
|
43 |
+
selected_indices.extend(random.sample(indices, N))
|
44 |
+
else:
|
45 |
+
selected_indices.extend(indices)
|
46 |
+
return selected_indices
|
47 |
+
|
48 |
+
|
49 |
+
def rename_cols(data, state_key):
|
50 |
+
data = data.rename_column(state_key, "label")
|
51 |
+
return data
|
52 |
+
|
53 |
+
|
54 |
+
def validate_and_clean_cols(train_data, eval_data, classifier):
|
55 |
+
# validate that data has expected label column and remove others
|
56 |
+
if classifier == "cell":
|
57 |
+
label_col = "label"
|
58 |
+
elif classifier == "gene":
|
59 |
+
label_col = "labels"
|
60 |
+
|
61 |
+
cols_to_keep = [label_col] + ["input_ids", "length"]
|
62 |
+
if label_col not in train_data.column_names:
|
63 |
+
logger.error(f"train_data must contain column {label_col} with class labels.")
|
64 |
+
raise
|
65 |
+
else:
|
66 |
+
train_data = remove_cols(train_data, cols_to_keep)
|
67 |
+
|
68 |
+
if eval_data is not None:
|
69 |
+
if label_col not in eval_data.column_names:
|
70 |
+
logger.error(
|
71 |
+
f"eval_data must contain column {label_col} with class labels."
|
72 |
+
)
|
73 |
+
raise
|
74 |
+
else:
|
75 |
+
eval_data = remove_cols(eval_data, cols_to_keep)
|
76 |
+
return train_data, eval_data
|
77 |
+
|
78 |
+
|
79 |
+
def remove_cols(data, cols_to_keep):
|
80 |
+
other_cols = list(data.features.keys())
|
81 |
+
other_cols = [ele for ele in other_cols if ele not in cols_to_keep]
|
82 |
+
data = data.remove_columns(other_cols)
|
83 |
+
return data
|
84 |
+
|
85 |
+
|
86 |
+
def remove_rare(data, rare_threshold, label, nproc):
|
87 |
+
if rare_threshold > 0:
|
88 |
+
total_cells = len(data)
|
89 |
+
label_counter = Counter(data[label])
|
90 |
+
nonrare_label_dict = {
|
91 |
+
label: [k for k, v in label_counter if (v / total_cells) > rare_threshold]
|
92 |
+
}
|
93 |
+
data = pu.filter_by_dict(data, nonrare_label_dict, nproc)
|
94 |
+
return data
|
95 |
+
|
96 |
+
|
97 |
+
def label_classes(classifier, data, gene_class_dict, nproc):
|
98 |
+
if classifier == "cell":
|
99 |
+
label_set = set(data["label"])
|
100 |
+
elif classifier == "gene":
|
101 |
+
# remove cells without any of the target genes
|
102 |
+
def if_contains_label(example):
|
103 |
+
a = pu.flatten_list(gene_class_dict.values())
|
104 |
+
b = example["input_ids"]
|
105 |
+
return not set(a).isdisjoint(b)
|
106 |
+
|
107 |
+
data = data.filter(if_contains_label, num_proc=nproc)
|
108 |
+
label_set = gene_class_dict.keys()
|
109 |
+
|
110 |
+
if len(data) == 0:
|
111 |
+
logger.error(
|
112 |
+
"No cells remain after filtering for target genes. Check target gene list."
|
113 |
+
)
|
114 |
+
raise
|
115 |
+
|
116 |
+
class_id_dict = dict(zip(label_set, [i for i in range(len(label_set))]))
|
117 |
+
id_class_dict = {v: k for k, v in class_id_dict.items()}
|
118 |
+
|
119 |
+
def classes_to_ids(example):
|
120 |
+
if classifier == "cell":
|
121 |
+
example["label"] = class_id_dict[example["label"]]
|
122 |
+
elif classifier == "gene":
|
123 |
+
example["labels"] = label_gene_classes(
|
124 |
+
example, class_id_dict, gene_class_dict
|
125 |
+
)
|
126 |
+
return example
|
127 |
+
|
128 |
+
data = data.map(classes_to_ids, num_proc=nproc)
|
129 |
+
return data, id_class_dict
|
130 |
+
|
131 |
+
|
132 |
+
def label_gene_classes(example, class_id_dict, gene_class_dict):
|
133 |
+
return [
|
134 |
+
class_id_dict.get(gene_class_dict.get(token_id, -100), -100)
|
135 |
+
for token_id in example["input_ids"]
|
136 |
+
]
|
137 |
+
|
138 |
+
|
139 |
+
def prep_gene_classifier_train_eval_split(
|
140 |
+
data,
|
141 |
+
targets,
|
142 |
+
labels,
|
143 |
+
train_index,
|
144 |
+
eval_index,
|
145 |
+
max_ncells,
|
146 |
+
iteration_num,
|
147 |
+
num_proc,
|
148 |
+
balance=False,
|
149 |
+
):
|
150 |
+
# generate cross-validation splits
|
151 |
+
train_data = prep_gene_classifier_split(
|
152 |
+
data,
|
153 |
+
targets,
|
154 |
+
labels,
|
155 |
+
train_index,
|
156 |
+
"train",
|
157 |
+
max_ncells,
|
158 |
+
iteration_num,
|
159 |
+
num_proc,
|
160 |
+
balance,
|
161 |
+
)
|
162 |
+
eval_data = prep_gene_classifier_split(
|
163 |
+
data,
|
164 |
+
targets,
|
165 |
+
labels,
|
166 |
+
eval_index,
|
167 |
+
"eval",
|
168 |
+
max_ncells,
|
169 |
+
iteration_num,
|
170 |
+
num_proc,
|
171 |
+
balance,
|
172 |
+
)
|
173 |
+
return train_data, eval_data
|
174 |
+
|
175 |
+
|
176 |
+
def prep_gene_classifier_split(
|
177 |
+
data,
|
178 |
+
targets,
|
179 |
+
labels,
|
180 |
+
index,
|
181 |
+
subset_name,
|
182 |
+
max_ncells,
|
183 |
+
iteration_num,
|
184 |
+
num_proc,
|
185 |
+
balance=False,
|
186 |
+
):
|
187 |
+
# generate cross-validation splits
|
188 |
+
targets = np.array(targets)
|
189 |
+
labels = np.array(labels)
|
190 |
+
targets_subset = targets[index]
|
191 |
+
labels_subset = labels[index]
|
192 |
+
label_dict_subset = dict(zip(targets_subset, labels_subset))
|
193 |
+
|
194 |
+
# function to filter by whether contains train or eval labels
|
195 |
+
def if_contains_subset_label(example):
|
196 |
+
a = targets_subset
|
197 |
+
b = example["input_ids"]
|
198 |
+
return not set(a).isdisjoint(b)
|
199 |
+
|
200 |
+
# filter dataset for examples containing classes for this split
|
201 |
+
logger.info(f"Filtering data for {subset_name} genes in split {iteration_num}")
|
202 |
+
subset_data = data.filter(if_contains_subset_label, num_proc=num_proc)
|
203 |
+
logger.info(
|
204 |
+
f"Filtered {round((1-len(subset_data)/len(data))*100)}%; {len(subset_data)} remain\n"
|
205 |
+
)
|
206 |
+
|
207 |
+
# balance gene subsets if train
|
208 |
+
if (subset_name == "train") and (balance is True):
|
209 |
+
subset_data, label_dict_subset = balance_gene_split(
|
210 |
+
subset_data, label_dict_subset, num_proc
|
211 |
+
)
|
212 |
+
|
213 |
+
# subsample to max_ncells
|
214 |
+
subset_data = downsample_and_shuffle(subset_data, max_ncells, None, None)
|
215 |
+
|
216 |
+
# relabel genes for this split
|
217 |
+
def subset_classes_to_ids(example):
|
218 |
+
example["labels"] = [
|
219 |
+
label_dict_subset.get(token_id, -100) for token_id in example["input_ids"]
|
220 |
+
]
|
221 |
+
return example
|
222 |
+
|
223 |
+
subset_data = subset_data.map(subset_classes_to_ids, num_proc=num_proc)
|
224 |
+
|
225 |
+
return subset_data
|
226 |
+
|
227 |
+
|
228 |
+
def prep_gene_classifier_all_data(
|
229 |
+
data, targets, labels, max_ncells, num_proc, balance=False
|
230 |
+
):
|
231 |
+
targets = np.array(targets)
|
232 |
+
labels = np.array(labels)
|
233 |
+
label_dict_train = dict(zip(targets, labels))
|
234 |
+
|
235 |
+
# function to filter by whether contains train labels
|
236 |
+
def if_contains_train_label(example):
|
237 |
+
a = targets
|
238 |
+
b = example["input_ids"]
|
239 |
+
return not set(a).isdisjoint(b)
|
240 |
+
|
241 |
+
# filter dataset for examples containing classes for this split
|
242 |
+
logger.info("Filtering training data for genes to classify.")
|
243 |
+
train_data = data.filter(if_contains_train_label, num_proc=num_proc)
|
244 |
+
logger.info(
|
245 |
+
f"Filtered {round((1-len(train_data)/len(data))*100)}%; {len(train_data)} remain\n"
|
246 |
+
)
|
247 |
+
|
248 |
+
if balance is True:
|
249 |
+
train_data, label_dict_train = balance_gene_split(
|
250 |
+
train_data, label_dict_train, num_proc
|
251 |
+
)
|
252 |
+
|
253 |
+
# subsample to max_ncells
|
254 |
+
train_data = downsample_and_shuffle(train_data, max_ncells, None, None)
|
255 |
+
|
256 |
+
# relabel genes for this split
|
257 |
+
def train_classes_to_ids(example):
|
258 |
+
example["labels"] = [
|
259 |
+
label_dict_train.get(token_id, -100) for token_id in example["input_ids"]
|
260 |
+
]
|
261 |
+
return example
|
262 |
+
|
263 |
+
train_data = train_data.map(train_classes_to_ids, num_proc=num_proc)
|
264 |
+
|
265 |
+
return train_data
|
266 |
+
|
267 |
+
|
268 |
+
def balance_gene_split(subset_data, label_dict_subset, num_proc):
|
269 |
+
# count occurrence of genes in each label category
|
270 |
+
label0_counts, label1_counts = count_genes_for_balancing(
|
271 |
+
subset_data, label_dict_subset, num_proc
|
272 |
+
)
|
273 |
+
label_ratio_0to1 = label0_counts / label1_counts
|
274 |
+
|
275 |
+
if 8 / 10 <= label_ratio_0to1 <= 10 / 8:
|
276 |
+
# gene sets already balanced
|
277 |
+
logger.info(
|
278 |
+
"Gene sets were already balanced within 0.8-1.25 fold and did not require balancing.\n"
|
279 |
+
)
|
280 |
+
return subset_data, label_dict_subset
|
281 |
+
else:
|
282 |
+
label_ratio_0to1_orig = label_ratio_0to1 + 0
|
283 |
+
label_dict_subset_orig = label_dict_subset.copy()
|
284 |
+
# balance gene sets
|
285 |
+
max_ntrials = 25
|
286 |
+
boost = 1
|
287 |
+
if label_ratio_0to1 > 10 / 8:
|
288 |
+
# downsample label 0
|
289 |
+
for i in range(max_ntrials):
|
290 |
+
label0 = 0
|
291 |
+
label0_genes = [k for k, v in label_dict_subset.items() if v == label0]
|
292 |
+
label0_ngenes = len(label0_genes)
|
293 |
+
label0_nremove = max(
|
294 |
+
1,
|
295 |
+
int(
|
296 |
+
np.floor(
|
297 |
+
label0_ngenes - label0_ngenes / (label_ratio_0to1 * boost)
|
298 |
+
)
|
299 |
+
),
|
300 |
+
)
|
301 |
+
random.seed(i)
|
302 |
+
label0_remove_genes = random.sample(label0_genes, label0_nremove)
|
303 |
+
label_dict_subset_new = {
|
304 |
+
k: v
|
305 |
+
for k, v in label_dict_subset.items()
|
306 |
+
if k not in label0_remove_genes
|
307 |
+
}
|
308 |
+
label0_counts, label1_counts = count_genes_for_balancing(
|
309 |
+
subset_data, label_dict_subset_new, num_proc
|
310 |
+
)
|
311 |
+
label_ratio_0to1 = label0_counts / label1_counts
|
312 |
+
if 8 / 10 <= label_ratio_0to1 <= 10 / 8:
|
313 |
+
# if gene sets now balanced, return new filtered data and new label_dict_subset
|
314 |
+
return filter_data_balanced_genes(
|
315 |
+
subset_data, label_dict_subset_new, num_proc
|
316 |
+
)
|
317 |
+
elif label_ratio_0to1 > 10 / 8:
|
318 |
+
boost = boost * 1.1
|
319 |
+
elif label_ratio_0to1 < 8 / 10:
|
320 |
+
boost = boost * 0.9
|
321 |
+
else:
|
322 |
+
# downsample label 1
|
323 |
+
for i in range(max_ntrials):
|
324 |
+
label1 = 1
|
325 |
+
label1_genes = [k for k, v in label_dict_subset.items() if v == label1]
|
326 |
+
label1_ngenes = len(label1_genes)
|
327 |
+
label1_nremove = max(
|
328 |
+
1,
|
329 |
+
int(
|
330 |
+
np.floor(
|
331 |
+
label1_ngenes
|
332 |
+
- label1_ngenes / ((1 / label_ratio_0to1) * boost)
|
333 |
+
)
|
334 |
+
),
|
335 |
+
)
|
336 |
+
random.seed(i)
|
337 |
+
label1_remove_genes = random.sample(label1_genes, label1_nremove)
|
338 |
+
label_dict_subset_new = {
|
339 |
+
k: v
|
340 |
+
for k, v in label_dict_subset.items()
|
341 |
+
if k not in label1_remove_genes
|
342 |
+
}
|
343 |
+
label0_counts, label1_counts = count_genes_for_balancing(
|
344 |
+
subset_data, label_dict_subset_new, num_proc
|
345 |
+
)
|
346 |
+
label_ratio_0to1 = label0_counts / label1_counts
|
347 |
+
if 8 / 10 <= label_ratio_0to1 <= 10 / 8:
|
348 |
+
# if gene sets now balanced, return new filtered data and new label_dict_subset
|
349 |
+
return filter_data_balanced_genes(
|
350 |
+
subset_data, label_dict_subset_new, num_proc
|
351 |
+
)
|
352 |
+
elif label_ratio_0to1 < 8 / 10:
|
353 |
+
boost = boost * 1.1
|
354 |
+
elif label_ratio_0to1 > 10 / 8:
|
355 |
+
boost = boost * 0.9
|
356 |
+
|
357 |
+
assert i + 1 == max_ntrials
|
358 |
+
if (label_ratio_0to1 <= label_ratio_0to1_orig < 8 / 10) or (
|
359 |
+
10 / 8 > label_ratio_0to1_orig >= label_ratio_0to1
|
360 |
+
):
|
361 |
+
label_ratio_0to1 = label_ratio_0to1_orig
|
362 |
+
label_dict_subset_new = label_dict_subset_orig
|
363 |
+
logger.warning(
|
364 |
+
f"Gene sets were not able to be balanced within 0.8-1.25 fold after {max_ntrials} trials. Imbalance level: {label_ratio_0to1}\n"
|
365 |
+
)
|
366 |
+
return filter_data_balanced_genes(subset_data, label_dict_subset_new, num_proc)
|
367 |
+
|
368 |
+
|
369 |
+
def count_genes_for_balancing(subset_data, label_dict_subset, num_proc):
|
370 |
+
def count_targets(example):
|
371 |
+
labels = [
|
372 |
+
label_dict_subset.get(token_id, -100) for token_id in example["input_ids"]
|
373 |
+
]
|
374 |
+
counter_labels = Counter(labels)
|
375 |
+
# get count of labels 0 or 1, or if absent, return 0
|
376 |
+
example["labels_counts"] = [counter_labels.get(0, 0), counter_labels.get(1, 0)]
|
377 |
+
return example
|
378 |
+
|
379 |
+
subset_data = subset_data.map(count_targets, num_proc=num_proc)
|
380 |
+
|
381 |
+
label0_counts = sum([counts[0] for counts in subset_data["labels_counts"]])
|
382 |
+
label1_counts = sum([counts[1] for counts in subset_data["labels_counts"]])
|
383 |
+
|
384 |
+
subset_data = subset_data.remove_columns("labels_counts")
|
385 |
+
|
386 |
+
return label0_counts, label1_counts
|
387 |
+
|
388 |
+
|
389 |
+
def filter_data_balanced_genes(subset_data, label_dict_subset, num_proc):
|
390 |
+
# function to filter by whether contains labels
|
391 |
+
def if_contains_subset_label(example):
|
392 |
+
a = list(label_dict_subset.keys())
|
393 |
+
b = example["input_ids"]
|
394 |
+
return not set(a).isdisjoint(b)
|
395 |
+
|
396 |
+
# filter dataset for examples containing classes for this split
|
397 |
+
logger.info("Filtering data for balanced genes")
|
398 |
+
subset_data_len_orig = len(subset_data)
|
399 |
+
subset_data = subset_data.filter(if_contains_subset_label, num_proc=num_proc)
|
400 |
+
logger.info(
|
401 |
+
f"Filtered {round((1-len(subset_data)/subset_data_len_orig)*100)}%; {len(subset_data)} remain\n"
|
402 |
+
)
|
403 |
+
|
404 |
+
return subset_data, label_dict_subset
|
405 |
+
|
406 |
+
|
407 |
+
def balance_attr_splits(
|
408 |
+
data,
|
409 |
+
attr_to_split,
|
410 |
+
attr_to_balance,
|
411 |
+
eval_size,
|
412 |
+
max_trials,
|
413 |
+
pval_threshold,
|
414 |
+
state_key,
|
415 |
+
nproc,
|
416 |
+
):
|
417 |
+
metadata_df = pd.DataFrame({"split_attr_ids": data[attr_to_split]})
|
418 |
+
for attr in attr_to_balance:
|
419 |
+
if attr == state_key:
|
420 |
+
metadata_df[attr] = data["label"]
|
421 |
+
else:
|
422 |
+
metadata_df[attr] = data[attr]
|
423 |
+
metadata_df = metadata_df.drop_duplicates()
|
424 |
+
|
425 |
+
split_attr_ids = list(metadata_df["split_attr_ids"])
|
426 |
+
assert len(split_attr_ids) == len(set(split_attr_ids))
|
427 |
+
eval_num = round(len(split_attr_ids) * eval_size)
|
428 |
+
colnames = (
|
429 |
+
["trial_num", "train_ids", "eval_ids"]
|
430 |
+
+ pu.flatten_list(
|
431 |
+
[
|
432 |
+
[
|
433 |
+
f"{attr}_train_mean_or_counts",
|
434 |
+
f"{attr}_eval_mean_or_counts",
|
435 |
+
f"{attr}_pval",
|
436 |
+
]
|
437 |
+
for attr in attr_to_balance
|
438 |
+
]
|
439 |
+
)
|
440 |
+
+ ["mean_pval"]
|
441 |
+
)
|
442 |
+
balance_df = pd.DataFrame(columns=colnames)
|
443 |
+
data_dict = dict()
|
444 |
+
trial_num = 1
|
445 |
+
for i in range(max_trials):
|
446 |
+
if not all(
|
447 |
+
count > 1 for count in list(Counter(metadata_df[state_key]).values())
|
448 |
+
):
|
449 |
+
logger.error(
|
450 |
+
f"Cannot balance by {attr_to_split} while retaining at least 1 occurrence of each {state_key} class in both data splits. "
|
451 |
+
)
|
452 |
+
raise
|
453 |
+
eval_base = []
|
454 |
+
for state in set(metadata_df[state_key]):
|
455 |
+
eval_base += list(
|
456 |
+
metadata_df.loc[
|
457 |
+
metadata_df[state_key][metadata_df[state_key].eq(state)]
|
458 |
+
.sample(1, random_state=i)
|
459 |
+
.index
|
460 |
+
]["split_attr_ids"]
|
461 |
+
)
|
462 |
+
non_eval_base = [idx for idx in split_attr_ids if idx not in eval_base]
|
463 |
+
random.seed(i)
|
464 |
+
eval_ids = random.sample(non_eval_base, eval_num - len(eval_base)) + eval_base
|
465 |
+
train_ids = [idx for idx in split_attr_ids if idx not in eval_ids]
|
466 |
+
df_vals = [trial_num, train_ids, eval_ids]
|
467 |
+
pvals = []
|
468 |
+
for attr in attr_to_balance:
|
469 |
+
train_attr = list(
|
470 |
+
metadata_df[metadata_df["split_attr_ids"].isin(train_ids)][attr]
|
471 |
+
)
|
472 |
+
eval_attr = list(
|
473 |
+
metadata_df[metadata_df["split_attr_ids"].isin(eval_ids)][attr]
|
474 |
+
)
|
475 |
+
if attr == state_key:
|
476 |
+
# ensure IDs are interpreted as categorical
|
477 |
+
train_attr = [str(item) for item in train_attr]
|
478 |
+
eval_attr = [str(item) for item in eval_attr]
|
479 |
+
if all(isinstance(item, (int, float)) for item in train_attr + eval_attr):
|
480 |
+
train_attr_mean = np.nanmean(train_attr)
|
481 |
+
eval_attr_mean = np.nanmean(eval_attr)
|
482 |
+
pval = ranksums(train_attr, eval_attr, nan_policy="omit").pvalue
|
483 |
+
df_vals += [train_attr_mean, eval_attr_mean, pval]
|
484 |
+
elif all(isinstance(item, (str)) for item in train_attr + eval_attr):
|
485 |
+
obs_counts = Counter(train_attr)
|
486 |
+
exp_counts = Counter(eval_attr)
|
487 |
+
all_categ = set(obs_counts.keys()).union(set(exp_counts.keys()))
|
488 |
+
obs = [obs_counts[cat] for cat in all_categ]
|
489 |
+
exp = [
|
490 |
+
exp_counts[cat] * sum(obs) / sum(exp_counts.values())
|
491 |
+
for cat in all_categ
|
492 |
+
]
|
493 |
+
pval = chisquare(f_obs=obs, f_exp=exp).pvalue
|
494 |
+
train_attr_counts = str(obs_counts).strip("Counter(").strip(")")
|
495 |
+
eval_attr_counts = str(exp_counts).strip("Counter(").strip(")")
|
496 |
+
df_vals += [train_attr_counts, eval_attr_counts, pval]
|
497 |
+
else:
|
498 |
+
logger.error(
|
499 |
+
f"Inconsistent data types in attribute {attr}. "
|
500 |
+
"Cannot infer if continuous or categorical. "
|
501 |
+
"Must be all numeric (continuous) or all strings (categorical) to balance."
|
502 |
+
)
|
503 |
+
raise
|
504 |
+
pvals += [pval]
|
505 |
+
|
506 |
+
df_vals += [np.nanmean(pvals)]
|
507 |
+
balance_df_i = pd.DataFrame(df_vals, index=colnames).T
|
508 |
+
balance_df = pd.concat([balance_df, balance_df_i], ignore_index=True)
|
509 |
+
valid_pvals = [
|
510 |
+
pval_i
|
511 |
+
for pval_i in pvals
|
512 |
+
if isinstance(pval_i, (int, float)) and not np.isnan(pval_i)
|
513 |
+
]
|
514 |
+
if all(i >= pval_threshold for i in valid_pvals):
|
515 |
+
data_dict["train"] = pu.filter_by_dict(
|
516 |
+
data, {attr_to_split: balance_df_i["train_ids"][0]}, nproc
|
517 |
+
)
|
518 |
+
data_dict["test"] = pu.filter_by_dict(
|
519 |
+
data, {attr_to_split: balance_df_i["eval_ids"][0]}, nproc
|
520 |
+
)
|
521 |
+
return data_dict, balance_df
|
522 |
+
trial_num = trial_num + 1
|
523 |
+
balance_max_df = balance_df.iloc[balance_df["mean_pval"].idxmax(), :]
|
524 |
+
data_dict["train"] = pu.filter_by_dict(
|
525 |
+
data, {attr_to_split: balance_df_i["train_ids"][0]}, nproc
|
526 |
+
)
|
527 |
+
data_dict["test"] = pu.filter_by_dict(
|
528 |
+
data, {attr_to_split: balance_df_i["eval_ids"][0]}, nproc
|
529 |
+
)
|
530 |
+
logger.warning(
|
531 |
+
f"No splits found without significant difference in attr_to_balance among {max_trials} trials. "
|
532 |
+
f"Selecting optimal split (trial #{balance_max_df['trial_num']}) from completed trials."
|
533 |
+
)
|
534 |
+
return data_dict, balance_df
|
535 |
+
|
536 |
+
|
537 |
+
def get_num_classes(id_class_dict):
|
538 |
+
return len(set(id_class_dict.values()))
|
539 |
+
|
540 |
+
|
541 |
+
def compute_metrics(pred):
|
542 |
+
labels = pred.label_ids
|
543 |
+
preds = pred.predictions.argmax(-1)
|
544 |
+
|
545 |
+
# calculate accuracy and macro f1 using sklearn's function
|
546 |
+
if len(labels.shape) == 1:
|
547 |
+
acc = accuracy_score(labels, preds)
|
548 |
+
macro_f1 = f1_score(labels, preds, average="macro")
|
549 |
+
else:
|
550 |
+
flat_labels = labels.flatten().tolist()
|
551 |
+
flat_preds = preds.flatten().tolist()
|
552 |
+
logit_label_paired = [
|
553 |
+
item for item in list(zip(flat_preds, flat_labels)) if item[1] != -100
|
554 |
+
]
|
555 |
+
y_pred = [item[0] for item in logit_label_paired]
|
556 |
+
y_true = [item[1] for item in logit_label_paired]
|
557 |
+
|
558 |
+
acc = accuracy_score(y_true, y_pred)
|
559 |
+
macro_f1 = f1_score(y_true, y_pred, average="macro")
|
560 |
+
|
561 |
+
return {"accuracy": acc, "macro_f1": macro_f1}
|
562 |
+
|
563 |
+
|
564 |
+
def get_default_train_args(model, classifier, data, output_dir):
|
565 |
+
num_layers = pu.quant_layers(model)
|
566 |
+
freeze_layers = 0
|
567 |
+
batch_size = 12
|
568 |
+
if classifier == "cell":
|
569 |
+
epochs = 10
|
570 |
+
evaluation_strategy = "epoch"
|
571 |
+
load_best_model_at_end = True
|
572 |
+
else:
|
573 |
+
epochs = 1
|
574 |
+
evaluation_strategy = "no"
|
575 |
+
load_best_model_at_end = False
|
576 |
+
|
577 |
+
if num_layers == 6:
|
578 |
+
default_training_args = {
|
579 |
+
"learning_rate": 5e-5,
|
580 |
+
"lr_scheduler_type": "linear",
|
581 |
+
"warmup_steps": 500,
|
582 |
+
"per_device_train_batch_size": batch_size,
|
583 |
+
"per_device_eval_batch_size": batch_size,
|
584 |
+
}
|
585 |
+
else:
|
586 |
+
default_training_args = {
|
587 |
+
"per_device_train_batch_size": batch_size,
|
588 |
+
"per_device_eval_batch_size": batch_size,
|
589 |
+
}
|
590 |
+
|
591 |
+
training_args = {
|
592 |
+
"num_train_epochs": epochs,
|
593 |
+
"do_train": True,
|
594 |
+
"do_eval": True,
|
595 |
+
"evaluation_strategy": evaluation_strategy,
|
596 |
+
"logging_steps": np.floor(len(data) / batch_size / 8), # 8 evals per epoch
|
597 |
+
"save_strategy": "epoch",
|
598 |
+
"group_by_length": False,
|
599 |
+
"length_column_name": "length",
|
600 |
+
"disable_tqdm": False,
|
601 |
+
"weight_decay": 0.001,
|
602 |
+
"load_best_model_at_end": load_best_model_at_end,
|
603 |
+
}
|
604 |
+
training_args.update(default_training_args)
|
605 |
+
|
606 |
+
return training_args, freeze_layers
|
607 |
+
|
608 |
+
|
609 |
+
def load_best_model(directory, model_type, num_classes, mode="eval"):
|
610 |
+
file_dict = dict()
|
611 |
+
for subdir, dirs, files in os.walk(directory):
|
612 |
+
for file in files:
|
613 |
+
if file.endswith("result.json"):
|
614 |
+
with open(f"{subdir}/{file}", "rb") as fp:
|
615 |
+
result_json = json.load(fp)
|
616 |
+
file_dict[f"{subdir}"] = result_json["eval_macro_f1"]
|
617 |
+
file_df = pd.DataFrame(
|
618 |
+
{"dir": file_dict.keys(), "eval_macro_f1": file_dict.values()}
|
619 |
+
)
|
620 |
+
model_superdir = (
|
621 |
+
"run-"
|
622 |
+
+ file_df.iloc[file_df["eval_macro_f1"].idxmax()]["dir"]
|
623 |
+
.split("_objective_")[2]
|
624 |
+
.split("_")[0]
|
625 |
+
)
|
626 |
+
|
627 |
+
for subdir, dirs, files in os.walk(f"{directory}/{model_superdir}"):
|
628 |
+
for file in files:
|
629 |
+
if file.endswith("model.safetensors"):
|
630 |
+
model = pu.load_model(model_type, num_classes, f"{subdir}", mode)
|
631 |
+
return model
|
632 |
+
|
633 |
+
|
634 |
+
class StratifiedKFold3(StratifiedKFold):
|
635 |
+
def split(self, targets, labels, test_ratio=0.5, groups=None):
|
636 |
+
s = super().split(targets, labels, groups)
|
637 |
+
for train_indxs, test_indxs in s:
|
638 |
+
if test_ratio == 0:
|
639 |
+
yield train_indxs, test_indxs, None
|
640 |
+
else:
|
641 |
+
labels_test = np.array(labels)[test_indxs]
|
642 |
+
valid_indxs, test_indxs = train_test_split(
|
643 |
+
test_indxs,
|
644 |
+
stratify=labels_test,
|
645 |
+
test_size=test_ratio,
|
646 |
+
random_state=0,
|
647 |
+
)
|
648 |
+
yield train_indxs, valid_indxs, test_indxs
|
geneformer/collator_for_classification.py
ADDED
@@ -0,0 +1,667 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
"""
|
2 |
+
Geneformer collator for gene and cell classification.
|
3 |
+
Huggingface data collator modified to accommodate single-cell transcriptomics data for gene and cell classification.
|
4 |
+
"""
|
5 |
+
|
6 |
+
import warnings
|
7 |
+
from enum import Enum
|
8 |
+
from typing import Dict, List, Optional, Union
|
9 |
+
|
10 |
+
import numpy as np
|
11 |
+
import torch
|
12 |
+
from transformers import (
|
13 |
+
BatchEncoding,
|
14 |
+
DataCollatorForTokenClassification,
|
15 |
+
SpecialTokensMixin,
|
16 |
+
)
|
17 |
+
from transformers.utils import is_tf_available, is_torch_available, logging, to_py_obj
|
18 |
+
from transformers.utils.generic import _is_tensorflow, _is_torch
|
19 |
+
|
20 |
+
EncodedInput = List[int]
|
21 |
+
logger = logging.get_logger(__name__)
|
22 |
+
VERY_LARGE_INTEGER = int(
|
23 |
+
1e30
|
24 |
+
) # This is used to set the max input length for a model with infinite size input
|
25 |
+
LARGE_INTEGER = int(
|
26 |
+
1e20
|
27 |
+
) # This is used when we need something big but slightly smaller than VERY_LARGE_INTEGER
|
28 |
+
|
29 |
+
# precollator functions
|
30 |
+
|
31 |
+
|
32 |
+
class ExplicitEnum(Enum):
|
33 |
+
"""
|
34 |
+
Enum with more explicit error message for missing values.
|
35 |
+
"""
|
36 |
+
|
37 |
+
@classmethod
|
38 |
+
def _missing_(cls, value):
|
39 |
+
raise ValueError(
|
40 |
+
"%r is not a valid %s, please select one of %s"
|
41 |
+
% (value, cls.__name__, str(list(cls._value2member_map_.keys())))
|
42 |
+
)
|
43 |
+
|
44 |
+
|
45 |
+
class TruncationStrategy(ExplicitEnum):
|
46 |
+
"""
|
47 |
+
Possible values for the ``truncation`` argument in :meth:`PreTrainedTokenizerBase.__call__`. Useful for
|
48 |
+
tab-completion in an IDE.
|
49 |
+
"""
|
50 |
+
|
51 |
+
ONLY_FIRST = "only_first"
|
52 |
+
ONLY_SECOND = "only_second"
|
53 |
+
LONGEST_FIRST = "longest_first"
|
54 |
+
DO_NOT_TRUNCATE = "do_not_truncate"
|
55 |
+
|
56 |
+
|
57 |
+
class PaddingStrategy(ExplicitEnum):
|
58 |
+
"""
|
59 |
+
Possible values for the ``padding`` argument in :meth:`PreTrainedTokenizerBase.__call__`. Useful for tab-completion
|
60 |
+
in an IDE.
|
61 |
+
"""
|
62 |
+
|
63 |
+
LONGEST = "longest"
|
64 |
+
MAX_LENGTH = "max_length"
|
65 |
+
DO_NOT_PAD = "do_not_pad"
|
66 |
+
|
67 |
+
|
68 |
+
class TensorType(ExplicitEnum):
|
69 |
+
"""
|
70 |
+
Possible values for the ``return_tensors`` argument in :meth:`PreTrainedTokenizerBase.__call__`. Useful for
|
71 |
+
tab-completion in an IDE.
|
72 |
+
"""
|
73 |
+
|
74 |
+
PYTORCH = "pt"
|
75 |
+
TENSORFLOW = "tf"
|
76 |
+
NUMPY = "np"
|
77 |
+
JAX = "jax"
|
78 |
+
|
79 |
+
|
80 |
+
class PrecollatorForGeneAndCellClassification(SpecialTokensMixin):
|
81 |
+
def __init__(self, *args, **kwargs) -> None:
|
82 |
+
super().__init__(mask_token="<mask>", pad_token="<pad>")
|
83 |
+
|
84 |
+
self.token_dictionary = kwargs.get("token_dictionary")
|
85 |
+
self.padding_side = "right"
|
86 |
+
self.model_input_names = ["input_ids"]
|
87 |
+
self._mask_token_id = self.token_dictionary.get("<mask>")
|
88 |
+
self._pad_token_id = self.token_dictionary.get("<pad>")
|
89 |
+
self._all_special_ids = [
|
90 |
+
self.token_dictionary.get("<mask>"),
|
91 |
+
self.token_dictionary.get("<pad>"),
|
92 |
+
]
|
93 |
+
|
94 |
+
@property
|
95 |
+
def all_special_ids(self):
|
96 |
+
return self._all_special_ids
|
97 |
+
|
98 |
+
@property
|
99 |
+
def mask_token_id(self):
|
100 |
+
return self._mask_token_id
|
101 |
+
|
102 |
+
@property
|
103 |
+
def pad_token_id(self):
|
104 |
+
return self._pad_token_id
|
105 |
+
|
106 |
+
def _get_padding_truncation_strategies(
|
107 |
+
self,
|
108 |
+
padding=True,
|
109 |
+
truncation=False,
|
110 |
+
max_length=None,
|
111 |
+
pad_to_multiple_of=None,
|
112 |
+
verbose=True,
|
113 |
+
**kwargs,
|
114 |
+
):
|
115 |
+
"""
|
116 |
+
Find the correct padding/truncation strategy with backward compatibility for old arguments (truncation_strategy
|
117 |
+
and pad_to_max_length) and behaviors.
|
118 |
+
"""
|
119 |
+
old_truncation_strategy = kwargs.pop("truncation_strategy", "do_not_truncate")
|
120 |
+
old_pad_to_max_length = kwargs.pop("pad_to_max_length", False)
|
121 |
+
|
122 |
+
# Backward compatibility for previous behavior, maybe we should deprecate it:
|
123 |
+
# If you only set max_length, it activates truncation for max_length
|
124 |
+
if max_length is not None and padding is False and truncation is False:
|
125 |
+
if verbose:
|
126 |
+
if not self.deprecation_warnings.get(
|
127 |
+
"Truncation-not-explicitly-activated", False
|
128 |
+
):
|
129 |
+
logger.warning(
|
130 |
+
"Truncation was not explicitly activated but `max_length` is provided a specific value, "
|
131 |
+
"please use `truncation=True` to explicitly truncate examples to max length. "
|
132 |
+
"Defaulting to 'longest_first' truncation strategy. "
|
133 |
+
"If you encode pairs of sequences (GLUE-style) with the tokenizer you can select this strategy "
|
134 |
+
"more precisely by providing a specific strategy to `truncation`."
|
135 |
+
)
|
136 |
+
self.deprecation_warnings["Truncation-not-explicitly-activated"] = True
|
137 |
+
truncation = "longest_first"
|
138 |
+
|
139 |
+
# Get padding strategy
|
140 |
+
if padding is False and old_pad_to_max_length:
|
141 |
+
if verbose:
|
142 |
+
warnings.warn(
|
143 |
+
"The `pad_to_max_length` argument is deprecated and will be removed in a future version, "
|
144 |
+
"use `padding=True` or `padding='longest'` to pad to the longest sequence in the batch, or "
|
145 |
+
"use `padding='max_length'` to pad to a max length. In this case, you can give a specific "
|
146 |
+
"length with `max_length` (e.g. `max_length=45`) or leave max_length to None to pad to the "
|
147 |
+
"maximal input size of the model (e.g. 512 for Bert).",
|
148 |
+
FutureWarning,
|
149 |
+
)
|
150 |
+
if max_length is None:
|
151 |
+
padding_strategy = PaddingStrategy.LONGEST
|
152 |
+
else:
|
153 |
+
padding_strategy = PaddingStrategy.MAX_LENGTH
|
154 |
+
elif padding is not False:
|
155 |
+
if padding is True:
|
156 |
+
padding_strategy = (
|
157 |
+
PaddingStrategy.LONGEST
|
158 |
+
) # Default to pad to the longest sequence in the batch
|
159 |
+
elif not isinstance(padding, PaddingStrategy):
|
160 |
+
padding_strategy = PaddingStrategy(padding)
|
161 |
+
elif isinstance(padding, PaddingStrategy):
|
162 |
+
padding_strategy = padding
|
163 |
+
else:
|
164 |
+
padding_strategy = PaddingStrategy.DO_NOT_PAD
|
165 |
+
|
166 |
+
# Get truncation strategy
|
167 |
+
if truncation is False and old_truncation_strategy != "do_not_truncate":
|
168 |
+
if verbose:
|
169 |
+
warnings.warn(
|
170 |
+
"The `truncation_strategy` argument is deprecated and will be removed in a future version, "
|
171 |
+
"use `truncation=True` to truncate examples to a max length. You can give a specific "
|
172 |
+
"length with `max_length` (e.g. `max_length=45`) or leave max_length to None to truncate to the "
|
173 |
+
"maximal input size of the model (e.g. 512 for Bert). "
|
174 |
+
" If you have pairs of inputs, you can give a specific truncation strategy selected among "
|
175 |
+
"`truncation='only_first'` (will only truncate the first sentence in the pairs) "
|
176 |
+
"`truncation='only_second'` (will only truncate the second sentence in the pairs) "
|
177 |
+
"or `truncation='longest_first'` (will iteratively remove tokens from the longest sentence in the pairs).",
|
178 |
+
FutureWarning,
|
179 |
+
)
|
180 |
+
truncation_strategy = TruncationStrategy(old_truncation_strategy)
|
181 |
+
elif truncation is not False:
|
182 |
+
if truncation is True:
|
183 |
+
truncation_strategy = (
|
184 |
+
TruncationStrategy.LONGEST_FIRST
|
185 |
+
) # Default to truncate the longest sequences in pairs of inputs
|
186 |
+
elif not isinstance(truncation, TruncationStrategy):
|
187 |
+
truncation_strategy = TruncationStrategy(truncation)
|
188 |
+
elif isinstance(truncation, TruncationStrategy):
|
189 |
+
truncation_strategy = truncation
|
190 |
+
else:
|
191 |
+
truncation_strategy = TruncationStrategy.DO_NOT_TRUNCATE
|
192 |
+
|
193 |
+
# Set max length if needed
|
194 |
+
if max_length is None:
|
195 |
+
if padding_strategy == PaddingStrategy.MAX_LENGTH:
|
196 |
+
if self.model_max_length > LARGE_INTEGER:
|
197 |
+
if verbose:
|
198 |
+
if not self.deprecation_warnings.get(
|
199 |
+
"Asking-to-pad-to-max_length", False
|
200 |
+
):
|
201 |
+
logger.warning(
|
202 |
+
"Asking to pad to max_length but no maximum length is provided and the model has no predefined maximum length. "
|
203 |
+
"Default to no padding."
|
204 |
+
)
|
205 |
+
self.deprecation_warnings["Asking-to-pad-to-max_length"] = True
|
206 |
+
padding_strategy = PaddingStrategy.DO_NOT_PAD
|
207 |
+
else:
|
208 |
+
max_length = self.model_max_length
|
209 |
+
|
210 |
+
if truncation_strategy != TruncationStrategy.DO_NOT_TRUNCATE:
|
211 |
+
if self.model_max_length > LARGE_INTEGER:
|
212 |
+
if verbose:
|
213 |
+
if not self.deprecation_warnings.get(
|
214 |
+
"Asking-to-truncate-to-max_length", False
|
215 |
+
):
|
216 |
+
logger.warning(
|
217 |
+
"Asking to truncate to max_length but no maximum length is provided and the model has no predefined maximum length. "
|
218 |
+
"Default to no truncation."
|
219 |
+
)
|
220 |
+
self.deprecation_warnings[
|
221 |
+
"Asking-to-truncate-to-max_length"
|
222 |
+
] = True
|
223 |
+
truncation_strategy = TruncationStrategy.DO_NOT_TRUNCATE
|
224 |
+
else:
|
225 |
+
max_length = self.model_max_length
|
226 |
+
|
227 |
+
# Test if we have a padding token
|
228 |
+
if padding_strategy != PaddingStrategy.DO_NOT_PAD and (
|
229 |
+
not self.pad_token or self.pad_token_id < 0
|
230 |
+
):
|
231 |
+
raise ValueError(
|
232 |
+
"Asking to pad but the tokenizer does not have a padding token. "
|
233 |
+
"Please select a token to use as `pad_token` `(tokenizer.pad_token = tokenizer.eos_token e.g.)` "
|
234 |
+
"or add a new pad token via `tokenizer.add_special_tokens({'pad_token': '[PAD]'})`."
|
235 |
+
)
|
236 |
+
|
237 |
+
# Check that we will truncate to a multiple of pad_to_multiple_of if both are provided
|
238 |
+
if (
|
239 |
+
truncation_strategy != TruncationStrategy.DO_NOT_TRUNCATE
|
240 |
+
and padding_strategy != PaddingStrategy.DO_NOT_PAD
|
241 |
+
and pad_to_multiple_of is not None
|
242 |
+
and max_length is not None
|
243 |
+
and (max_length % pad_to_multiple_of != 0)
|
244 |
+
):
|
245 |
+
raise ValueError(
|
246 |
+
f"Truncation and padding are both activated but "
|
247 |
+
f"truncation length ({max_length}) is not a multiple of pad_to_multiple_of ({pad_to_multiple_of})."
|
248 |
+
)
|
249 |
+
|
250 |
+
return padding_strategy, truncation_strategy, max_length, kwargs
|
251 |
+
|
252 |
+
def pad(
|
253 |
+
self,
|
254 |
+
encoded_inputs: Union[
|
255 |
+
BatchEncoding,
|
256 |
+
List[BatchEncoding],
|
257 |
+
Dict[str, EncodedInput],
|
258 |
+
Dict[str, List[EncodedInput]],
|
259 |
+
List[Dict[str, EncodedInput]],
|
260 |
+
],
|
261 |
+
class_type, # options: "gene" or "cell"
|
262 |
+
padding: Union[bool, str, PaddingStrategy] = True,
|
263 |
+
max_length: Optional[int] = None,
|
264 |
+
pad_to_multiple_of: Optional[int] = None,
|
265 |
+
return_attention_mask: Optional[bool] = True,
|
266 |
+
return_tensors: Optional[Union[str, TensorType]] = None,
|
267 |
+
verbose: bool = True,
|
268 |
+
) -> BatchEncoding:
|
269 |
+
"""
|
270 |
+
Pad a single encoded input or a batch of encoded inputs up to predefined length or to the max sequence length
|
271 |
+
in the batch.
|
272 |
+
Padding side (left/right) padding token ids are defined at the tokenizer level (with ``self.padding_side``,
|
273 |
+
``self.pad_token_id`` and ``self.pad_token_type_id``)
|
274 |
+
.. note::
|
275 |
+
If the ``encoded_inputs`` passed are dictionary of numpy arrays, PyTorch tensors or TensorFlow tensors, the
|
276 |
+
result will use the same type unless you provide a different tensor type with ``return_tensors``. In the
|
277 |
+
case of PyTorch tensors, you will lose the specific device of your tensors however.
|
278 |
+
Args:
|
279 |
+
encoded_inputs (:class:`~transformers.BatchEncoding`, list of :class:`~transformers.BatchEncoding`, :obj:`Dict[str, List[int]]`, :obj:`Dict[str, List[List[int]]` or :obj:`List[Dict[str, List[int]]]`):
|
280 |
+
Tokenized inputs. Can represent one input (:class:`~transformers.BatchEncoding` or :obj:`Dict[str,
|
281 |
+
List[int]]`) or a batch of tokenized inputs (list of :class:`~transformers.BatchEncoding`, `Dict[str,
|
282 |
+
List[List[int]]]` or `List[Dict[str, List[int]]]`) so you can use this method during preprocessing as
|
283 |
+
well as in a PyTorch Dataloader collate function.
|
284 |
+
Instead of :obj:`List[int]` you can have tensors (numpy arrays, PyTorch tensors or TensorFlow tensors),
|
285 |
+
see the note above for the return type.
|
286 |
+
padding (:obj:`bool`, :obj:`str` or :class:`~transformers.tokenization_utils_base.PaddingStrategy`, `optional`, defaults to :obj:`True`):
|
287 |
+
Select a strategy to pad the returned sequences (according to the model's padding side and padding
|
288 |
+
index) among:
|
289 |
+
* :obj:`True` or :obj:`'longest'`: Pad to the longest sequence in the batch (or no padding if only a
|
290 |
+
single sequence if provided).
|
291 |
+
* :obj:`'max_length'`: Pad to a maximum length specified with the argument :obj:`max_length` or to the
|
292 |
+
maximum acceptable input length for the model if that argument is not provided.
|
293 |
+
* :obj:`False` or :obj:`'do_not_pad'` (default): No padding (i.e., can output a batch with sequences of
|
294 |
+
different lengths).
|
295 |
+
max_length (:obj:`int`, `optional`):
|
296 |
+
Maximum length of the returned list and optionally padding length (see above).
|
297 |
+
pad_to_multiple_of (:obj:`int`, `optional`):
|
298 |
+
If set will pad the sequence to a multiple of the provided value.
|
299 |
+
This is especially useful to enable the use of Tensor Cores on NVIDIA hardware with compute capability
|
300 |
+
>= 7.5 (Volta).
|
301 |
+
return_attention_mask (:obj:`bool`, `optional`):
|
302 |
+
Whether to return the attention mask. If left to the default, will return the attention mask according
|
303 |
+
to the specific tokenizer's default, defined by the :obj:`return_outputs` attribute.
|
304 |
+
`What are attention masks? <../glossary.html#attention-mask>`__
|
305 |
+
return_tensors (:obj:`str` or :class:`~transformers.tokenization_utils_base.TensorType`, `optional`):
|
306 |
+
If set, will return tensors instead of list of python integers. Acceptable values are:
|
307 |
+
* :obj:`'tf'`: Return TensorFlow :obj:`tf.constant` objects.
|
308 |
+
* :obj:`'pt'`: Return PyTorch :obj:`torch.Tensor` objects.
|
309 |
+
* :obj:`'np'`: Return Numpy :obj:`np.ndarray` objects.
|
310 |
+
verbose (:obj:`bool`, `optional`, defaults to :obj:`True`):
|
311 |
+
Whether or not to print more information and warnings.
|
312 |
+
"""
|
313 |
+
# If we have a list of dicts, let's convert it in a dict of lists
|
314 |
+
# We do this to allow using this method as a collate_fn function in PyTorch Dataloader
|
315 |
+
if isinstance(encoded_inputs, (list, tuple)) and isinstance(
|
316 |
+
encoded_inputs[0], (dict, BatchEncoding)
|
317 |
+
):
|
318 |
+
encoded_inputs = {
|
319 |
+
key: [example[key] for example in encoded_inputs]
|
320 |
+
for key in encoded_inputs[0].keys()
|
321 |
+
}
|
322 |
+
|
323 |
+
# The model's main input name, usually `input_ids`, has be passed for padding
|
324 |
+
if self.model_input_names[0] not in encoded_inputs:
|
325 |
+
raise ValueError(
|
326 |
+
"You should supply an encoding or a list of encodings to this method"
|
327 |
+
f"that includes {self.model_input_names[0]}, but you provided {list(encoded_inputs.keys())}"
|
328 |
+
)
|
329 |
+
|
330 |
+
required_input = encoded_inputs[self.model_input_names[0]]
|
331 |
+
|
332 |
+
if not required_input:
|
333 |
+
if return_attention_mask:
|
334 |
+
encoded_inputs["attention_mask"] = []
|
335 |
+
return encoded_inputs
|
336 |
+
|
337 |
+
# If we have PyTorch/TF/NumPy tensors/arrays as inputs, we cast them as python objects
|
338 |
+
# and rebuild them afterwards if no return_tensors is specified
|
339 |
+
# Note that we lose the specific device the tensor may be on for PyTorch
|
340 |
+
|
341 |
+
first_element = required_input[0]
|
342 |
+
if isinstance(first_element, (list, tuple)):
|
343 |
+
# first_element might be an empty list/tuple in some edge cases so we grab the first non empty element.
|
344 |
+
index = 0
|
345 |
+
while len(required_input[index]) == 0:
|
346 |
+
index += 1
|
347 |
+
if index < len(required_input):
|
348 |
+
first_element = required_input[index][0]
|
349 |
+
# At this state, if `first_element` is still a list/tuple, it's an empty one so there is nothing to do.
|
350 |
+
if not isinstance(first_element, (int, list, tuple)):
|
351 |
+
if is_tf_available() and _is_tensorflow(first_element):
|
352 |
+
return_tensors = "tf" if return_tensors is None else return_tensors
|
353 |
+
elif is_torch_available() and _is_torch(first_element):
|
354 |
+
return_tensors = "pt" if return_tensors is None else return_tensors
|
355 |
+
elif isinstance(first_element, np.ndarray):
|
356 |
+
return_tensors = "np" if return_tensors is None else return_tensors
|
357 |
+
else:
|
358 |
+
raise ValueError(
|
359 |
+
f"type of {first_element} unknown: {type(first_element)}. "
|
360 |
+
f"Should be one of a python, numpy, pytorch or tensorflow object."
|
361 |
+
)
|
362 |
+
|
363 |
+
for key, value in encoded_inputs.items():
|
364 |
+
encoded_inputs[key] = to_py_obj(value)
|
365 |
+
|
366 |
+
# Convert padding_strategy in PaddingStrategy
|
367 |
+
padding_strategy, _, max_length, _ = self._get_padding_truncation_strategies(
|
368 |
+
padding=padding, max_length=max_length, verbose=verbose
|
369 |
+
)
|
370 |
+
|
371 |
+
required_input = encoded_inputs[self.model_input_names[0]]
|
372 |
+
if required_input and not isinstance(required_input[0], (list, tuple)):
|
373 |
+
encoded_inputs = self._pad(
|
374 |
+
encoded_inputs,
|
375 |
+
class_type=class_type,
|
376 |
+
max_length=max_length,
|
377 |
+
padding_strategy=padding_strategy,
|
378 |
+
pad_to_multiple_of=pad_to_multiple_of,
|
379 |
+
return_attention_mask=return_attention_mask,
|
380 |
+
)
|
381 |
+
return BatchEncoding(encoded_inputs, tensor_type=return_tensors)
|
382 |
+
|
383 |
+
batch_size = len(required_input)
|
384 |
+
assert all(
|
385 |
+
len(v) == batch_size for v in encoded_inputs.values()
|
386 |
+
), "Some items in the output dictionary have a different batch size than others."
|
387 |
+
|
388 |
+
if padding_strategy == PaddingStrategy.LONGEST:
|
389 |
+
max_length = max(len(inputs) for inputs in required_input)
|
390 |
+
padding_strategy = PaddingStrategy.MAX_LENGTH
|
391 |
+
|
392 |
+
batch_outputs = {}
|
393 |
+
for i in range(batch_size):
|
394 |
+
inputs = dict((k, v[i]) for k, v in encoded_inputs.items())
|
395 |
+
outputs = self._pad(
|
396 |
+
inputs,
|
397 |
+
class_type=class_type,
|
398 |
+
max_length=max_length,
|
399 |
+
padding_strategy=padding_strategy,
|
400 |
+
pad_to_multiple_of=pad_to_multiple_of,
|
401 |
+
return_attention_mask=return_attention_mask,
|
402 |
+
)
|
403 |
+
|
404 |
+
for key, value in outputs.items():
|
405 |
+
if key not in batch_outputs:
|
406 |
+
batch_outputs[key] = []
|
407 |
+
batch_outputs[key].append(value)
|
408 |
+
if class_type == "cell":
|
409 |
+
del batch_outputs["label"]
|
410 |
+
return BatchEncoding(batch_outputs, tensor_type=return_tensors)
|
411 |
+
|
412 |
+
def _pad(
|
413 |
+
self,
|
414 |
+
encoded_inputs: Union[Dict[str, EncodedInput], BatchEncoding],
|
415 |
+
class_type, # options: "gene" or "cell"
|
416 |
+
max_length: Optional[int] = None,
|
417 |
+
padding_strategy: PaddingStrategy = PaddingStrategy.LONGEST,
|
418 |
+
pad_to_multiple_of: Optional[int] = None,
|
419 |
+
return_attention_mask: Optional[bool] = True,
|
420 |
+
) -> dict:
|
421 |
+
"""
|
422 |
+
Pad encoded inputs (on left/right and up to predefined length or max length in the batch)
|
423 |
+
Args:
|
424 |
+
encoded_inputs: Dictionary of tokenized inputs (`List[int]`) or batch of tokenized inputs (`List[List[int]]`).
|
425 |
+
max_length: maximum length of the returned list and optionally padding length (see below).
|
426 |
+
Will truncate by taking into account the special tokens.
|
427 |
+
padding_strategy: PaddingStrategy to use for padding.
|
428 |
+
- PaddingStrategy.LONGEST Pad to the longest sequence in the batch
|
429 |
+
- PaddingStrategy.MAX_LENGTH: Pad to the max length (default)
|
430 |
+
- PaddingStrategy.DO_NOT_PAD: Do not pad
|
431 |
+
The tokenizer padding sides are defined in self.padding_side:
|
432 |
+
- 'left': pads on the left of the sequences
|
433 |
+
- 'right': pads on the right of the sequences
|
434 |
+
pad_to_multiple_of: (optional) Integer if set will pad the sequence to a multiple of the provided value.
|
435 |
+
This is especially useful to enable the use of Tensor Core on NVIDIA hardware with compute capability
|
436 |
+
>= 7.5 (Volta).
|
437 |
+
return_attention_mask: (optional) Set to False to avoid returning attention mask (default: set to model specifics)
|
438 |
+
"""
|
439 |
+
# Load from model defaults
|
440 |
+
if return_attention_mask is None:
|
441 |
+
return_attention_mask = "attention_mask" in self.model_input_names
|
442 |
+
|
443 |
+
required_input = encoded_inputs[self.model_input_names[0]]
|
444 |
+
|
445 |
+
if padding_strategy == PaddingStrategy.LONGEST:
|
446 |
+
max_length = len(required_input)
|
447 |
+
|
448 |
+
if (
|
449 |
+
max_length is not None
|
450 |
+
and pad_to_multiple_of is not None
|
451 |
+
and (max_length % pad_to_multiple_of != 0)
|
452 |
+
):
|
453 |
+
max_length = ((max_length // pad_to_multiple_of) + 1) * pad_to_multiple_of
|
454 |
+
|
455 |
+
needs_to_be_padded = (
|
456 |
+
padding_strategy != PaddingStrategy.DO_NOT_PAD
|
457 |
+
and len(required_input) != max_length
|
458 |
+
)
|
459 |
+
|
460 |
+
if needs_to_be_padded:
|
461 |
+
difference = max_length - len(required_input)
|
462 |
+
if self.padding_side == "right":
|
463 |
+
if return_attention_mask:
|
464 |
+
encoded_inputs["attention_mask"] = [1] * len(required_input) + [
|
465 |
+
0
|
466 |
+
] * difference
|
467 |
+
if "token_type_ids" in encoded_inputs:
|
468 |
+
encoded_inputs["token_type_ids"] = (
|
469 |
+
encoded_inputs["token_type_ids"]
|
470 |
+
+ [self.pad_token_type_id] * difference
|
471 |
+
)
|
472 |
+
if "special_tokens_mask" in encoded_inputs:
|
473 |
+
encoded_inputs["special_tokens_mask"] = (
|
474 |
+
encoded_inputs["special_tokens_mask"] + [1] * difference
|
475 |
+
)
|
476 |
+
encoded_inputs[self.model_input_names[0]] = (
|
477 |
+
required_input + [self.pad_token_id] * difference
|
478 |
+
)
|
479 |
+
if class_type == "gene":
|
480 |
+
encoded_inputs["labels"] = (
|
481 |
+
encoded_inputs["labels"] + [-100] * difference
|
482 |
+
)
|
483 |
+
elif self.padding_side == "left":
|
484 |
+
if return_attention_mask:
|
485 |
+
encoded_inputs["attention_mask"] = [0] * difference + [1] * len(
|
486 |
+
required_input
|
487 |
+
)
|
488 |
+
if "token_type_ids" in encoded_inputs:
|
489 |
+
encoded_inputs["token_type_ids"] = [
|
490 |
+
self.pad_token_type_id
|
491 |
+
] * difference + encoded_inputs["token_type_ids"]
|
492 |
+
if "special_tokens_mask" in encoded_inputs:
|
493 |
+
encoded_inputs["special_tokens_mask"] = [
|
494 |
+
1
|
495 |
+
] * difference + encoded_inputs["special_tokens_mask"]
|
496 |
+
encoded_inputs[self.model_input_names[0]] = [
|
497 |
+
self.pad_token_id
|
498 |
+
] * difference + required_input
|
499 |
+
if class_type == "gene":
|
500 |
+
encoded_inputs["labels"] = [-100] * difference + encoded_inputs[
|
501 |
+
"labels"
|
502 |
+
]
|
503 |
+
else:
|
504 |
+
raise ValueError("Invalid padding strategy:" + str(self.padding_side))
|
505 |
+
elif return_attention_mask and "attention_mask" not in encoded_inputs:
|
506 |
+
encoded_inputs["attention_mask"] = [1] * len(required_input)
|
507 |
+
|
508 |
+
return encoded_inputs
|
509 |
+
|
510 |
+
def get_special_tokens_mask(
|
511 |
+
self,
|
512 |
+
token_ids_0: List[int],
|
513 |
+
token_ids_1: Optional[List[int]] = None,
|
514 |
+
already_has_special_tokens: bool = False,
|
515 |
+
) -> List[int]:
|
516 |
+
"""
|
517 |
+
Retrieves sequence ids from a token list that has no special tokens added. This method is called when adding
|
518 |
+
special tokens using the tokenizer ``prepare_for_model`` or ``encode_plus`` methods.
|
519 |
+
Args:
|
520 |
+
token_ids_0 (:obj:`List[int]`):
|
521 |
+
List of ids of the first sequence.
|
522 |
+
token_ids_1 (:obj:`List[int]`, `optional`):
|
523 |
+
List of ids of the second sequence.
|
524 |
+
already_has_special_tokens (:obj:`bool`, `optional`, defaults to :obj:`False`):
|
525 |
+
Whether or not the token list is already formatted with special tokens for the model.
|
526 |
+
Returns:
|
527 |
+
A list of integers in the range [0, 1]: 1 for a special token, 0 for a sequence token.
|
528 |
+
"""
|
529 |
+
assert already_has_special_tokens and token_ids_1 is None, (
|
530 |
+
"You cannot use ``already_has_special_tokens=False`` with this tokenizer. "
|
531 |
+
"Please use a slow (full python) tokenizer to activate this argument."
|
532 |
+
"Or set `return_special_tokens_mask=True` when calling the encoding method "
|
533 |
+
"to get the special tokens mask in any tokenizer. "
|
534 |
+
)
|
535 |
+
|
536 |
+
all_special_ids = self.all_special_ids # cache the property
|
537 |
+
|
538 |
+
special_tokens_mask = [
|
539 |
+
1 if token in all_special_ids else 0 for token in token_ids_0
|
540 |
+
]
|
541 |
+
|
542 |
+
return special_tokens_mask
|
543 |
+
|
544 |
+
def convert_tokens_to_ids(
|
545 |
+
self, tokens: Union[str, List[str]]
|
546 |
+
) -> Union[int, List[int]]:
|
547 |
+
"""
|
548 |
+
Converts a token string (or a sequence of tokens) in a single integer id (or a sequence of ids), using the
|
549 |
+
vocabulary.
|
550 |
+
Args:
|
551 |
+
tokens (:obj:`str` or :obj:`List[str]`): One or several token(s) to convert to token id(s).
|
552 |
+
Returns:
|
553 |
+
:obj:`int` or :obj:`List[int]`: The token id or list of token ids.
|
554 |
+
"""
|
555 |
+
if tokens is None:
|
556 |
+
return None
|
557 |
+
|
558 |
+
if isinstance(tokens, str):
|
559 |
+
return self._convert_token_to_id_with_added_voc(tokens)
|
560 |
+
|
561 |
+
ids = []
|
562 |
+
for token in tokens:
|
563 |
+
ids.append(self._convert_token_to_id_with_added_voc(token))
|
564 |
+
return ids
|
565 |
+
|
566 |
+
def _convert_token_to_id_with_added_voc(self, token):
|
567 |
+
if token is None:
|
568 |
+
return None
|
569 |
+
|
570 |
+
return self.token_dictionary.get(token)
|
571 |
+
|
572 |
+
def __len__(self):
|
573 |
+
return len(self.token_dictionary)
|
574 |
+
|
575 |
+
|
576 |
+
# collator functions
|
577 |
+
|
578 |
+
|
579 |
+
class DataCollatorForGeneClassification(DataCollatorForTokenClassification):
|
580 |
+
"""
|
581 |
+
Data collator that will dynamically pad the inputs received, as well as the labels.
|
582 |
+
Args:
|
583 |
+
tokenizer (:class:`~transformers.PreTrainedTokenizer` or :class:`~transformers.PreTrainedTokenizerFast`):
|
584 |
+
The tokenizer used for encoding the data.
|
585 |
+
padding (:obj:`bool`, :obj:`str` or :class:`~transformers.tokenization_utils_base.PaddingStrategy`, `optional`, defaults to :obj:`True`):
|
586 |
+
Select a strategy to pad the returned sequences (according to the model's padding side and padding index)
|
587 |
+
among:
|
588 |
+
* :obj:`True` or :obj:`'longest'`: Pad to the longest sequence in the batch (or no padding if only a single
|
589 |
+
sequence if provided).
|
590 |
+
* :obj:`'max_length'`: Pad to a maximum length specified with the argument :obj:`max_length` or to the
|
591 |
+
maximum acceptable input length for the model if that argument is not provided.
|
592 |
+
* :obj:`False` or :obj:`'do_not_pad'` (default): No padding (i.e., can output a batch with sequences of
|
593 |
+
different lengths).
|
594 |
+
max_length (:obj:`int`, `optional`):
|
595 |
+
Maximum length of the returned list and optionally padding length (see above).
|
596 |
+
pad_to_multiple_of (:obj:`int`, `optional`):
|
597 |
+
If set will pad the sequence to a multiple of the provided value.
|
598 |
+
This is especially useful to enable the use of Tensor Cores on NVIDIA hardware with compute capability >=
|
599 |
+
7.5 (Volta).
|
600 |
+
label_pad_token_id (:obj:`int`, `optional`, defaults to -100):
|
601 |
+
The id to use when padding the labels (-100 will be automatically ignore by PyTorch loss functions).
|
602 |
+
"""
|
603 |
+
|
604 |
+
class_type = "gene"
|
605 |
+
padding: Union[bool, str, PaddingStrategy] = True
|
606 |
+
max_length: Optional[int] = None
|
607 |
+
pad_to_multiple_of: Optional[int] = None
|
608 |
+
label_pad_token_id: int = -100
|
609 |
+
|
610 |
+
def __init__(self, *args, **kwargs) -> None:
|
611 |
+
self.token_dictionary = kwargs.pop("token_dictionary")
|
612 |
+
super().__init__(
|
613 |
+
tokenizer=PrecollatorForGeneAndCellClassification(
|
614 |
+
token_dictionary=self.token_dictionary
|
615 |
+
),
|
616 |
+
padding=self.padding,
|
617 |
+
max_length=self.max_length,
|
618 |
+
pad_to_multiple_of=self.pad_to_multiple_of,
|
619 |
+
label_pad_token_id=self.label_pad_token_id,
|
620 |
+
*args,
|
621 |
+
**kwargs,
|
622 |
+
)
|
623 |
+
|
624 |
+
def _prepare_batch(self, features):
|
625 |
+
label_name = "label" if "label" in features[0].keys() else "labels"
|
626 |
+
labels = (
|
627 |
+
[feature[label_name] for feature in features]
|
628 |
+
if label_name in features[0].keys()
|
629 |
+
else None
|
630 |
+
)
|
631 |
+
batch = self.tokenizer.pad(
|
632 |
+
features,
|
633 |
+
class_type=self.class_type,
|
634 |
+
padding=self.padding,
|
635 |
+
max_length=self.max_length,
|
636 |
+
pad_to_multiple_of=self.pad_to_multiple_of,
|
637 |
+
return_tensors="pt",
|
638 |
+
)
|
639 |
+
return batch
|
640 |
+
|
641 |
+
def __call__(self, features):
|
642 |
+
batch = self._prepare_batch(features)
|
643 |
+
|
644 |
+
batch = {k: torch.tensor(v, dtype=torch.int64) for k, v in batch.items()}
|
645 |
+
return batch
|
646 |
+
|
647 |
+
|
648 |
+
class DataCollatorForCellClassification(DataCollatorForGeneClassification):
|
649 |
+
class_type = "cell"
|
650 |
+
|
651 |
+
def _prepare_batch(self, features):
|
652 |
+
batch = super()._prepare_batch(features)
|
653 |
+
|
654 |
+
# Special handling for labels.
|
655 |
+
# Ensure that tensor is created with the correct type
|
656 |
+
# (it should be automatically the case, but let's make sure of it.)
|
657 |
+
first = features[0]
|
658 |
+
if "label" in first and first["label"] is not None:
|
659 |
+
label = (
|
660 |
+
first["label"].item()
|
661 |
+
if isinstance(first["label"], torch.Tensor)
|
662 |
+
else first["label"]
|
663 |
+
)
|
664 |
+
dtype = torch.long if isinstance(label, int) else torch.float
|
665 |
+
batch["labels"] = torch.tensor([f["label"] for f in features], dtype=dtype)
|
666 |
+
|
667 |
+
return batch
|
geneformer/emb_extractor.py
ADDED
@@ -0,0 +1,863 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
"""
|
2 |
+
Geneformer embedding extractor.
|
3 |
+
|
4 |
+
**Description:**
|
5 |
+
|
6 |
+
| Extracts gene or cell embeddings.
|
7 |
+
| Plots cell embeddings as heatmaps or UMAPs.
|
8 |
+
| Generates cell state embedding dictionary for use with InSilicoPerturber.
|
9 |
+
|
10 |
+
"""
|
11 |
+
|
12 |
+
# imports
|
13 |
+
import logging
|
14 |
+
import pickle
|
15 |
+
from collections import Counter
|
16 |
+
from pathlib import Path
|
17 |
+
|
18 |
+
import anndata
|
19 |
+
import matplotlib.pyplot as plt
|
20 |
+
import pandas as pd
|
21 |
+
import scanpy as sc
|
22 |
+
import seaborn as sns
|
23 |
+
import torch
|
24 |
+
from tdigest import TDigest
|
25 |
+
from tqdm.auto import trange
|
26 |
+
|
27 |
+
from . import TOKEN_DICTIONARY_FILE
|
28 |
+
from . import perturber_utils as pu
|
29 |
+
|
30 |
+
logger = logging.getLogger(__name__)
|
31 |
+
|
32 |
+
|
33 |
+
# extract embeddings
|
34 |
+
def get_embs(
|
35 |
+
model,
|
36 |
+
filtered_input_data,
|
37 |
+
emb_mode,
|
38 |
+
layer_to_quant,
|
39 |
+
pad_token_id,
|
40 |
+
forward_batch_size,
|
41 |
+
token_gene_dict,
|
42 |
+
special_token=False,
|
43 |
+
summary_stat=None,
|
44 |
+
silent=False,
|
45 |
+
):
|
46 |
+
model_input_size = pu.get_model_input_size(model)
|
47 |
+
total_batch_length = len(filtered_input_data)
|
48 |
+
|
49 |
+
if summary_stat is None:
|
50 |
+
embs_list = []
|
51 |
+
elif summary_stat is not None:
|
52 |
+
# get # of emb dims
|
53 |
+
emb_dims = pu.get_model_emb_dims(model)
|
54 |
+
if emb_mode == "cell":
|
55 |
+
# initiate tdigests for # of emb dims
|
56 |
+
embs_tdigests = [TDigest() for _ in range(emb_dims)]
|
57 |
+
if emb_mode == "gene":
|
58 |
+
gene_set = list(
|
59 |
+
{
|
60 |
+
element
|
61 |
+
for sublist in filtered_input_data["input_ids"]
|
62 |
+
for element in sublist
|
63 |
+
}
|
64 |
+
)
|
65 |
+
# initiate dict with genes as keys and tdigests for # of emb dims as values
|
66 |
+
embs_tdigests_dict = {
|
67 |
+
k: [TDigest() for _ in range(emb_dims)] for k in gene_set
|
68 |
+
}
|
69 |
+
|
70 |
+
# Check if CLS and EOS token is present in the token dictionary
|
71 |
+
cls_present = any("<cls>" in value for value in token_gene_dict.values())
|
72 |
+
eos_present = any("<eos>" in value for value in token_gene_dict.values())
|
73 |
+
if emb_mode == "cls":
|
74 |
+
assert cls_present, "<cls> token missing in token dictionary"
|
75 |
+
# Check to make sure that the first token of the filtered input data is cls token
|
76 |
+
gene_token_dict = {v: k for k, v in token_gene_dict.items()}
|
77 |
+
cls_token_id = gene_token_dict["<cls>"]
|
78 |
+
assert (
|
79 |
+
filtered_input_data["input_ids"][0][0] == cls_token_id
|
80 |
+
), "First token is not <cls> token value"
|
81 |
+
elif emb_mode == "cell":
|
82 |
+
if cls_present:
|
83 |
+
logger.warning(
|
84 |
+
"CLS token present in token dictionary, excluding from average."
|
85 |
+
)
|
86 |
+
if eos_present:
|
87 |
+
logger.warning(
|
88 |
+
"EOS token present in token dictionary, excluding from average."
|
89 |
+
)
|
90 |
+
|
91 |
+
overall_max_len = 0
|
92 |
+
|
93 |
+
for i in trange(0, total_batch_length, forward_batch_size, leave=(not silent)):
|
94 |
+
max_range = min(i + forward_batch_size, total_batch_length)
|
95 |
+
|
96 |
+
minibatch = filtered_input_data.select([i for i in range(i, max_range)])
|
97 |
+
|
98 |
+
max_len = int(max(minibatch["length"]))
|
99 |
+
original_lens = torch.tensor(minibatch["length"], device="cuda")
|
100 |
+
minibatch.set_format(type="torch")
|
101 |
+
|
102 |
+
input_data_minibatch = minibatch["input_ids"]
|
103 |
+
input_data_minibatch = pu.pad_tensor_list(
|
104 |
+
input_data_minibatch, max_len, pad_token_id, model_input_size
|
105 |
+
)
|
106 |
+
|
107 |
+
with torch.no_grad():
|
108 |
+
outputs = model(
|
109 |
+
input_ids=input_data_minibatch.to("cuda"),
|
110 |
+
attention_mask=pu.gen_attention_mask(minibatch),
|
111 |
+
)
|
112 |
+
|
113 |
+
embs_i = outputs.hidden_states[layer_to_quant]
|
114 |
+
|
115 |
+
if emb_mode == "cell":
|
116 |
+
if cls_present:
|
117 |
+
non_cls_embs = embs_i[:, 1:, :] # Get all layers except the embs
|
118 |
+
if eos_present:
|
119 |
+
mean_embs = pu.mean_nonpadding_embs(non_cls_embs, original_lens - 2)
|
120 |
+
else:
|
121 |
+
mean_embs = pu.mean_nonpadding_embs(non_cls_embs, original_lens - 1)
|
122 |
+
else:
|
123 |
+
mean_embs = pu.mean_nonpadding_embs(embs_i, original_lens)
|
124 |
+
if summary_stat is None:
|
125 |
+
embs_list.append(mean_embs)
|
126 |
+
elif summary_stat is not None:
|
127 |
+
# update tdigests with current batch for each emb dim
|
128 |
+
accumulate_tdigests(embs_tdigests, mean_embs, emb_dims)
|
129 |
+
del mean_embs
|
130 |
+
elif emb_mode == "gene":
|
131 |
+
if summary_stat is None:
|
132 |
+
embs_list.append(embs_i)
|
133 |
+
elif summary_stat is not None:
|
134 |
+
for h in trange(len(minibatch)):
|
135 |
+
length_h = minibatch[h]["length"]
|
136 |
+
input_ids_h = minibatch[h]["input_ids"][0:length_h]
|
137 |
+
|
138 |
+
# double check dimensions before unsqueezing
|
139 |
+
embs_i_dim = embs_i.dim()
|
140 |
+
if embs_i_dim != 3:
|
141 |
+
logger.error(
|
142 |
+
f"Embedding tensor should have 3 dimensions, not {embs_i_dim}"
|
143 |
+
)
|
144 |
+
raise
|
145 |
+
|
146 |
+
embs_h = embs_i[h, :, :].unsqueeze(dim=1)
|
147 |
+
dict_h = dict(zip(input_ids_h, embs_h))
|
148 |
+
for k in dict_h.keys():
|
149 |
+
accumulate_tdigests(
|
150 |
+
embs_tdigests_dict[int(k)], dict_h[k], emb_dims
|
151 |
+
)
|
152 |
+
del embs_h
|
153 |
+
del dict_h
|
154 |
+
elif emb_mode == "cls":
|
155 |
+
cls_embs = embs_i[:, 0, :].clone().detach() # CLS token layer
|
156 |
+
embs_list.append(cls_embs)
|
157 |
+
del cls_embs
|
158 |
+
|
159 |
+
overall_max_len = max(overall_max_len, max_len)
|
160 |
+
del outputs
|
161 |
+
del minibatch
|
162 |
+
del input_data_minibatch
|
163 |
+
del embs_i
|
164 |
+
|
165 |
+
torch.cuda.empty_cache()
|
166 |
+
|
167 |
+
if summary_stat is None:
|
168 |
+
if (emb_mode == "cell") or (emb_mode == "cls"):
|
169 |
+
embs_stack = torch.cat(embs_list, dim=0)
|
170 |
+
elif emb_mode == "gene":
|
171 |
+
embs_stack = pu.pad_tensor_list(
|
172 |
+
embs_list,
|
173 |
+
overall_max_len,
|
174 |
+
pad_token_id,
|
175 |
+
model_input_size,
|
176 |
+
1,
|
177 |
+
pu.pad_3d_tensor,
|
178 |
+
)
|
179 |
+
|
180 |
+
# calculate summary stat embs from approximated tdigests
|
181 |
+
elif summary_stat is not None:
|
182 |
+
if emb_mode == "cell":
|
183 |
+
if summary_stat == "mean":
|
184 |
+
summary_emb_list = tdigest_mean(embs_tdigests, emb_dims)
|
185 |
+
elif summary_stat == "median":
|
186 |
+
summary_emb_list = tdigest_median(embs_tdigests, emb_dims)
|
187 |
+
embs_stack = torch.tensor(summary_emb_list)
|
188 |
+
elif emb_mode == "gene":
|
189 |
+
if summary_stat == "mean":
|
190 |
+
[
|
191 |
+
update_tdigest_dict_mean(embs_tdigests_dict, gene, emb_dims)
|
192 |
+
for gene in embs_tdigests_dict.keys()
|
193 |
+
]
|
194 |
+
elif summary_stat == "median":
|
195 |
+
[
|
196 |
+
update_tdigest_dict_median(embs_tdigests_dict, gene, emb_dims)
|
197 |
+
for gene in embs_tdigests_dict.keys()
|
198 |
+
]
|
199 |
+
return embs_tdigests_dict
|
200 |
+
|
201 |
+
return embs_stack
|
202 |
+
|
203 |
+
|
204 |
+
def accumulate_tdigests(embs_tdigests, mean_embs, emb_dims):
|
205 |
+
# note: tdigest batch update known to be slow so updating serially
|
206 |
+
[
|
207 |
+
embs_tdigests[j].update(mean_embs[i, j].item())
|
208 |
+
for i in range(mean_embs.size(0))
|
209 |
+
for j in range(emb_dims)
|
210 |
+
]
|
211 |
+
|
212 |
+
|
213 |
+
def update_tdigest_dict(embs_tdigests_dict, gene, gene_embs, emb_dims):
|
214 |
+
embs_tdigests_dict[gene] = accumulate_tdigests(
|
215 |
+
embs_tdigests_dict[gene], gene_embs, emb_dims
|
216 |
+
)
|
217 |
+
|
218 |
+
|
219 |
+
def update_tdigest_dict_mean(embs_tdigests_dict, gene, emb_dims):
|
220 |
+
embs_tdigests_dict[gene] = tdigest_mean(embs_tdigests_dict[gene], emb_dims)
|
221 |
+
|
222 |
+
|
223 |
+
def update_tdigest_dict_median(embs_tdigests_dict, gene, emb_dims):
|
224 |
+
embs_tdigests_dict[gene] = tdigest_median(embs_tdigests_dict[gene], emb_dims)
|
225 |
+
|
226 |
+
|
227 |
+
def summarize_gene_embs(h, minibatch, embs_i, embs_tdigests_dict, emb_dims):
|
228 |
+
length_h = minibatch[h]["length"]
|
229 |
+
input_ids_h = minibatch[h]["input_ids"][0:length_h]
|
230 |
+
embs_h = embs_i[h, :, :].unsqueeze(dim=1)
|
231 |
+
dict_h = dict(zip(input_ids_h, embs_h))
|
232 |
+
[
|
233 |
+
update_tdigest_dict(embs_tdigests_dict, k, dict_h[k], emb_dims)
|
234 |
+
for k in dict_h.keys()
|
235 |
+
]
|
236 |
+
|
237 |
+
|
238 |
+
def tdigest_mean(embs_tdigests, emb_dims):
|
239 |
+
return [embs_tdigests[i].trimmed_mean(0, 100) for i in range(emb_dims)]
|
240 |
+
|
241 |
+
|
242 |
+
def tdigest_median(embs_tdigests, emb_dims):
|
243 |
+
return [embs_tdigests[i].percentile(50) for i in range(emb_dims)]
|
244 |
+
|
245 |
+
|
246 |
+
def label_cell_embs(embs, downsampled_data, emb_labels):
|
247 |
+
embs_df = pd.DataFrame(embs.cpu().numpy())
|
248 |
+
if emb_labels is not None:
|
249 |
+
for label in emb_labels:
|
250 |
+
emb_label = downsampled_data[label]
|
251 |
+
embs_df[label] = emb_label
|
252 |
+
return embs_df
|
253 |
+
|
254 |
+
|
255 |
+
def label_gene_embs(embs, downsampled_data, token_gene_dict):
|
256 |
+
gene_set = {
|
257 |
+
element for sublist in downsampled_data["input_ids"] for element in sublist
|
258 |
+
}
|
259 |
+
gene_emb_dict = {k: [] for k in gene_set}
|
260 |
+
for i in range(embs.size()[0]):
|
261 |
+
length = downsampled_data[i]["length"]
|
262 |
+
dict_i = dict(
|
263 |
+
zip(
|
264 |
+
downsampled_data[i]["input_ids"][0:length],
|
265 |
+
embs[i, :, :].unsqueeze(dim=1),
|
266 |
+
)
|
267 |
+
)
|
268 |
+
for k in dict_i.keys():
|
269 |
+
gene_emb_dict[k].append(dict_i[k])
|
270 |
+
for k in gene_emb_dict.keys():
|
271 |
+
gene_emb_dict[k] = (
|
272 |
+
torch.squeeze(torch.mean(torch.stack(gene_emb_dict[k]), dim=0), dim=0)
|
273 |
+
.cpu()
|
274 |
+
.numpy()
|
275 |
+
)
|
276 |
+
embs_df = pd.DataFrame(gene_emb_dict).T
|
277 |
+
embs_df.index = [token_gene_dict[token] for token in embs_df.index]
|
278 |
+
return embs_df
|
279 |
+
|
280 |
+
|
281 |
+
def plot_umap(embs_df, emb_dims, label, output_file, kwargs_dict, seed=0):
|
282 |
+
only_embs_df = embs_df.iloc[:, :emb_dims]
|
283 |
+
only_embs_df.index = pd.RangeIndex(0, only_embs_df.shape[0], name=None).astype(str)
|
284 |
+
only_embs_df.columns = pd.RangeIndex(0, only_embs_df.shape[1], name=None).astype(
|
285 |
+
str
|
286 |
+
)
|
287 |
+
vars_dict = {"embs": only_embs_df.columns}
|
288 |
+
obs_dict = {"cell_id": list(only_embs_df.index), f"{label}": list(embs_df[label])}
|
289 |
+
adata = anndata.AnnData(X=only_embs_df, obs=obs_dict, var=vars_dict)
|
290 |
+
sc.tl.pca(adata, svd_solver="arpack")
|
291 |
+
sc.pp.neighbors(adata, random_state=seed)
|
292 |
+
sc.tl.umap(adata, random_state=seed)
|
293 |
+
sns.set(rc={"figure.figsize": (10, 10)}, font_scale=2.3)
|
294 |
+
sns.set_style("white")
|
295 |
+
default_kwargs_dict = {"size": 200}
|
296 |
+
if kwargs_dict is not None:
|
297 |
+
default_kwargs_dict.update(kwargs_dict)
|
298 |
+
|
299 |
+
cats = set(embs_df[label])
|
300 |
+
|
301 |
+
with plt.rc_context():
|
302 |
+
ax = sc.pl.umap(adata, color=label, show=False, **default_kwargs_dict)
|
303 |
+
ax.legend(
|
304 |
+
markerscale=2,
|
305 |
+
frameon=False,
|
306 |
+
loc="center left",
|
307 |
+
bbox_to_anchor=(1, 0.5),
|
308 |
+
ncol=(1 if len(cats) <= 14 else 2 if len(cats) <= 30 else 3),
|
309 |
+
)
|
310 |
+
plt.show()
|
311 |
+
plt.savefig(output_file, bbox_inches="tight")
|
312 |
+
|
313 |
+
|
314 |
+
def gen_heatmap_class_colors(labels, df):
|
315 |
+
pal = sns.cubehelix_palette(
|
316 |
+
len(Counter(labels).keys()),
|
317 |
+
light=0.9,
|
318 |
+
dark=0.1,
|
319 |
+
hue=1,
|
320 |
+
reverse=True,
|
321 |
+
start=1,
|
322 |
+
rot=-2,
|
323 |
+
)
|
324 |
+
lut = dict(zip(map(str, Counter(labels).keys()), pal))
|
325 |
+
colors = pd.Series(labels, index=df.index).map(lut)
|
326 |
+
return colors
|
327 |
+
|
328 |
+
|
329 |
+
def gen_heatmap_class_dict(classes, label_colors_series):
|
330 |
+
class_color_dict_df = pd.DataFrame(
|
331 |
+
{"classes": classes, "color": label_colors_series}
|
332 |
+
)
|
333 |
+
class_color_dict_df = class_color_dict_df.drop_duplicates(subset=["classes"])
|
334 |
+
return dict(zip(class_color_dict_df["classes"], class_color_dict_df["color"]))
|
335 |
+
|
336 |
+
|
337 |
+
def make_colorbar(embs_df, label):
|
338 |
+
labels = list(embs_df[label])
|
339 |
+
|
340 |
+
cell_type_colors = gen_heatmap_class_colors(labels, embs_df)
|
341 |
+
label_colors = pd.DataFrame(cell_type_colors, columns=[label])
|
342 |
+
|
343 |
+
# create dictionary for colors and classes
|
344 |
+
label_color_dict = gen_heatmap_class_dict(labels, label_colors[label])
|
345 |
+
return label_colors, label_color_dict
|
346 |
+
|
347 |
+
|
348 |
+
def plot_heatmap(embs_df, emb_dims, label, output_file, kwargs_dict):
|
349 |
+
sns.set_style("white")
|
350 |
+
sns.set(font_scale=2)
|
351 |
+
plt.figure(figsize=(15, 15), dpi=150)
|
352 |
+
label_colors, label_color_dict = make_colorbar(embs_df, label)
|
353 |
+
|
354 |
+
default_kwargs_dict = {
|
355 |
+
"row_cluster": True,
|
356 |
+
"col_cluster": True,
|
357 |
+
"row_colors": label_colors,
|
358 |
+
"standard_scale": 1,
|
359 |
+
"linewidths": 0,
|
360 |
+
"xticklabels": False,
|
361 |
+
"yticklabels": False,
|
362 |
+
"figsize": (15, 15),
|
363 |
+
"center": 0,
|
364 |
+
"cmap": "magma",
|
365 |
+
}
|
366 |
+
|
367 |
+
if kwargs_dict is not None:
|
368 |
+
default_kwargs_dict.update(kwargs_dict)
|
369 |
+
g = sns.clustermap(
|
370 |
+
embs_df.iloc[:, 0:emb_dims].apply(pd.to_numeric), **default_kwargs_dict
|
371 |
+
)
|
372 |
+
|
373 |
+
plt.setp(g.ax_row_colors.get_xmajorticklabels(), rotation=45, ha="right")
|
374 |
+
|
375 |
+
for label_color in list(label_color_dict.keys()):
|
376 |
+
g.ax_col_dendrogram.bar(
|
377 |
+
0, 0, color=label_color_dict[label_color], label=label_color, linewidth=0
|
378 |
+
)
|
379 |
+
|
380 |
+
g.ax_col_dendrogram.legend(
|
381 |
+
title=f"{label}",
|
382 |
+
loc="lower center",
|
383 |
+
ncol=4,
|
384 |
+
bbox_to_anchor=(0.5, 1),
|
385 |
+
facecolor="white",
|
386 |
+
)
|
387 |
+
plt.show()
|
388 |
+
logger.info(f"Output file: {output_file}")
|
389 |
+
plt.savefig(output_file, bbox_inches="tight")
|
390 |
+
|
391 |
+
|
392 |
+
class EmbExtractor:
|
393 |
+
valid_option_dict = {
|
394 |
+
"model_type": {"Pretrained", "GeneClassifier", "CellClassifier"},
|
395 |
+
"num_classes": {int},
|
396 |
+
"emb_mode": {"cls", "cell", "gene"},
|
397 |
+
"cell_emb_style": {"mean_pool"},
|
398 |
+
"gene_emb_style": {"mean_pool"},
|
399 |
+
"filter_data": {None, dict},
|
400 |
+
"max_ncells": {None, int},
|
401 |
+
"emb_layer": {-1, 0},
|
402 |
+
"emb_label": {None, list},
|
403 |
+
"labels_to_plot": {None, list},
|
404 |
+
"forward_batch_size": {int},
|
405 |
+
"token_dictionary_file": {None, str},
|
406 |
+
"nproc": {int},
|
407 |
+
"summary_stat": {None, "mean", "median", "exact_mean", "exact_median"},
|
408 |
+
}
|
409 |
+
|
410 |
+
def __init__(
|
411 |
+
self,
|
412 |
+
model_type="Pretrained",
|
413 |
+
num_classes=0,
|
414 |
+
emb_mode="cls",
|
415 |
+
cell_emb_style="mean_pool",
|
416 |
+
gene_emb_style="mean_pool",
|
417 |
+
filter_data=None,
|
418 |
+
max_ncells=1000,
|
419 |
+
emb_layer=-1,
|
420 |
+
emb_label=None,
|
421 |
+
labels_to_plot=None,
|
422 |
+
forward_batch_size=100,
|
423 |
+
nproc=4,
|
424 |
+
summary_stat=None,
|
425 |
+
token_dictionary_file=None,
|
426 |
+
):
|
427 |
+
"""
|
428 |
+
Initialize embedding extractor.
|
429 |
+
|
430 |
+
**Parameters:**
|
431 |
+
|
432 |
+
model_type : {"Pretrained", "GeneClassifier", "CellClassifier"}
|
433 |
+
| Whether model is the pretrained Geneformer or a fine-tuned gene or cell classifier.
|
434 |
+
num_classes : int
|
435 |
+
| If model is a gene or cell classifier, specify number of classes it was trained to classify.
|
436 |
+
| For the pretrained Geneformer model, number of classes is 0 as it is not a classifier.
|
437 |
+
emb_mode : {"cls", "cell", "gene"}
|
438 |
+
| Whether to output CLS, cell, or gene embeddings.
|
439 |
+
| CLS embeddings are cell embeddings derived from the CLS token in the front of the rank value encoding.
|
440 |
+
cell_emb_style : {"mean_pool"}
|
441 |
+
| Method for summarizing cell embeddings if not using CLS token.
|
442 |
+
| Currently only option is mean pooling of gene embeddings for given cell.
|
443 |
+
gene_emb_style : "mean_pool"
|
444 |
+
| Method for summarizing gene embeddings.
|
445 |
+
| Currently only option is mean pooling of contextual gene embeddings for given gene.
|
446 |
+
filter_data : None, dict
|
447 |
+
| Default is to extract embeddings from all input data.
|
448 |
+
| Otherwise, dictionary specifying .dataset column name and list of values to filter by.
|
449 |
+
max_ncells : None, int
|
450 |
+
| Maximum number of cells to extract embeddings from.
|
451 |
+
| Default is 1000 cells randomly sampled from input data.
|
452 |
+
| If None, will extract embeddings from all cells.
|
453 |
+
emb_layer : {-1, 0}
|
454 |
+
| Embedding layer to extract.
|
455 |
+
| The last layer is most specifically weighted to optimize the given learning objective.
|
456 |
+
| Generally, it is best to extract the 2nd to last layer to get a more general representation.
|
457 |
+
| -1: 2nd to last layer
|
458 |
+
| 0: last layer
|
459 |
+
emb_label : None, list
|
460 |
+
| List of column name(s) in .dataset to add as labels to embedding output.
|
461 |
+
labels_to_plot : None, list
|
462 |
+
| Cell labels to plot.
|
463 |
+
| Shown as color bar in heatmap.
|
464 |
+
| Shown as cell color in umap.
|
465 |
+
| Plotting umap requires labels to plot.
|
466 |
+
forward_batch_size : int
|
467 |
+
| Batch size for forward pass.
|
468 |
+
nproc : int
|
469 |
+
| Number of CPU processes to use.
|
470 |
+
summary_stat : {None, "mean", "median", "exact_mean", "exact_median"}
|
471 |
+
| If exact_mean or exact_median, outputs only exact mean or median embedding of input data.
|
472 |
+
| If mean or median, outputs only approximated mean or median embedding of input data.
|
473 |
+
| Non-exact recommended if encountering memory constraints while generating goal embedding positions.
|
474 |
+
| Non-exact is slower but more memory-efficient.
|
475 |
+
token_dictionary_file : Path
|
476 |
+
| Default is the Geneformer token dictionary
|
477 |
+
| Path to pickle file containing token dictionary (Ensembl ID:token).
|
478 |
+
|
479 |
+
**Examples:**
|
480 |
+
|
481 |
+
.. code-block :: python
|
482 |
+
|
483 |
+
>>> from geneformer import EmbExtractor
|
484 |
+
>>> embex = EmbExtractor(model_type="CellClassifier",
|
485 |
+
... num_classes=3,
|
486 |
+
... emb_mode="cell",
|
487 |
+
... filter_data={"cell_type":["cardiomyocyte"]},
|
488 |
+
... max_ncells=1000,
|
489 |
+
... emb_layer=-1,
|
490 |
+
... emb_label=["disease", "cell_type"],
|
491 |
+
... labels_to_plot=["disease", "cell_type"])
|
492 |
+
|
493 |
+
"""
|
494 |
+
|
495 |
+
self.model_type = model_type
|
496 |
+
self.num_classes = num_classes
|
497 |
+
self.emb_mode = emb_mode
|
498 |
+
self.cell_emb_style = cell_emb_style
|
499 |
+
self.gene_emb_style = gene_emb_style
|
500 |
+
self.filter_data = filter_data
|
501 |
+
self.max_ncells = max_ncells
|
502 |
+
self.emb_layer = emb_layer
|
503 |
+
self.emb_label = emb_label
|
504 |
+
self.labels_to_plot = labels_to_plot
|
505 |
+
self.token_dictionary_file = token_dictionary_file
|
506 |
+
self.forward_batch_size = forward_batch_size
|
507 |
+
self.nproc = nproc
|
508 |
+
if (summary_stat is not None) and ("exact" in summary_stat):
|
509 |
+
self.summary_stat = None
|
510 |
+
self.exact_summary_stat = summary_stat
|
511 |
+
else:
|
512 |
+
self.summary_stat = summary_stat
|
513 |
+
self.exact_summary_stat = None
|
514 |
+
|
515 |
+
self.validate_options()
|
516 |
+
|
517 |
+
# load token dictionary (Ensembl IDs:token)
|
518 |
+
if self.token_dictionary_file is None:
|
519 |
+
token_dictionary_file = TOKEN_DICTIONARY_FILE
|
520 |
+
with open(token_dictionary_file, "rb") as f:
|
521 |
+
self.gene_token_dict = pickle.load(f)
|
522 |
+
|
523 |
+
self.token_gene_dict = {v: k for k, v in self.gene_token_dict.items()}
|
524 |
+
self.pad_token_id = self.gene_token_dict.get("<pad>")
|
525 |
+
|
526 |
+
def validate_options(self):
|
527 |
+
# confirm arguments are within valid options and compatible with each other
|
528 |
+
for attr_name, valid_options in self.valid_option_dict.items():
|
529 |
+
attr_value = self.__dict__[attr_name]
|
530 |
+
if not isinstance(attr_value, (list, dict)):
|
531 |
+
if attr_value in valid_options:
|
532 |
+
continue
|
533 |
+
valid_type = False
|
534 |
+
for option in valid_options:
|
535 |
+
if (option in [int, list, dict, bool, str]) and isinstance(
|
536 |
+
attr_value, option
|
537 |
+
):
|
538 |
+
valid_type = True
|
539 |
+
break
|
540 |
+
if valid_type:
|
541 |
+
continue
|
542 |
+
logger.error(
|
543 |
+
f"Invalid option for {attr_name}. "
|
544 |
+
f"Valid options for {attr_name}: {valid_options}"
|
545 |
+
)
|
546 |
+
raise
|
547 |
+
|
548 |
+
if self.filter_data is not None:
|
549 |
+
for key, value in self.filter_data.items():
|
550 |
+
if not isinstance(value, list):
|
551 |
+
self.filter_data[key] = [value]
|
552 |
+
logger.warning(
|
553 |
+
"Values in filter_data dict must be lists. "
|
554 |
+
f"Changing {key} value to list ([{value}])."
|
555 |
+
)
|
556 |
+
|
557 |
+
def extract_embs(
|
558 |
+
self,
|
559 |
+
model_directory,
|
560 |
+
input_data_file,
|
561 |
+
output_directory,
|
562 |
+
output_prefix,
|
563 |
+
output_torch_embs=False,
|
564 |
+
cell_state=None,
|
565 |
+
):
|
566 |
+
"""
|
567 |
+
Extract embeddings from input data and save as results in output_directory.
|
568 |
+
|
569 |
+
**Parameters:**
|
570 |
+
|
571 |
+
model_directory : Path
|
572 |
+
| Path to directory containing model
|
573 |
+
input_data_file : Path
|
574 |
+
| Path to directory containing .dataset inputs
|
575 |
+
output_directory : Path
|
576 |
+
| Path to directory where embedding data will be saved as csv
|
577 |
+
output_prefix : str
|
578 |
+
| Prefix for output file
|
579 |
+
output_torch_embs : bool
|
580 |
+
| Whether or not to also output the embeddings as a tensor.
|
581 |
+
| Note, if true, will output embeddings as both dataframe and tensor.
|
582 |
+
cell_state : dict
|
583 |
+
| Cell state key and value for state embedding extraction.
|
584 |
+
|
585 |
+
**Examples:**
|
586 |
+
|
587 |
+
.. code-block :: python
|
588 |
+
|
589 |
+
>>> embs = embex.extract_embs("path/to/model",
|
590 |
+
... "path/to/input_data",
|
591 |
+
... "path/to/output_directory",
|
592 |
+
... "output_prefix")
|
593 |
+
|
594 |
+
"""
|
595 |
+
|
596 |
+
filtered_input_data = pu.load_and_filter(
|
597 |
+
self.filter_data, self.nproc, input_data_file
|
598 |
+
)
|
599 |
+
|
600 |
+
# Check to make sure that all the labels exist in the tokenized data:
|
601 |
+
if self.emb_label is not None:
|
602 |
+
for label in self.emb_label:
|
603 |
+
assert label in filtered_input_data.features.keys(), f"Attribute `{label}` not present in dataset features"
|
604 |
+
|
605 |
+
if cell_state is not None:
|
606 |
+
filtered_input_data = pu.filter_by_dict(
|
607 |
+
filtered_input_data, cell_state, self.nproc
|
608 |
+
)
|
609 |
+
downsampled_data = pu.downsample_and_sort(filtered_input_data, self.max_ncells)
|
610 |
+
model = pu.load_model(
|
611 |
+
self.model_type, self.num_classes, model_directory, mode="eval"
|
612 |
+
)
|
613 |
+
layer_to_quant = pu.quant_layers(model) + self.emb_layer
|
614 |
+
embs = get_embs(
|
615 |
+
model=model,
|
616 |
+
filtered_input_data=downsampled_data,
|
617 |
+
emb_mode=self.emb_mode,
|
618 |
+
layer_to_quant=layer_to_quant,
|
619 |
+
pad_token_id=self.pad_token_id,
|
620 |
+
forward_batch_size=self.forward_batch_size,
|
621 |
+
token_gene_dict=self.token_gene_dict,
|
622 |
+
summary_stat=self.summary_stat,
|
623 |
+
)
|
624 |
+
|
625 |
+
if self.emb_mode == "cell":
|
626 |
+
if self.summary_stat is None:
|
627 |
+
embs_df = label_cell_embs(embs, downsampled_data, self.emb_label)
|
628 |
+
elif self.summary_stat is not None:
|
629 |
+
embs_df = pd.DataFrame(embs.cpu().numpy()).T
|
630 |
+
elif self.emb_mode == "gene":
|
631 |
+
if self.summary_stat is None:
|
632 |
+
embs_df = label_gene_embs(embs, downsampled_data, self.token_gene_dict)
|
633 |
+
elif self.summary_stat is not None:
|
634 |
+
embs_df = pd.DataFrame(embs).T
|
635 |
+
embs_df.index = [self.token_gene_dict[token] for token in embs_df.index]
|
636 |
+
elif self.emb_mode == "cls":
|
637 |
+
embs_df = label_cell_embs(embs, downsampled_data, self.emb_label)
|
638 |
+
|
639 |
+
# save embeddings to output_path
|
640 |
+
if cell_state is None:
|
641 |
+
output_path = (Path(output_directory) / output_prefix).with_suffix(".csv")
|
642 |
+
embs_df.to_csv(output_path)
|
643 |
+
|
644 |
+
if self.exact_summary_stat == "exact_mean":
|
645 |
+
embs = embs.mean(dim=0)
|
646 |
+
emb_dims = pu.get_model_emb_dims(model)
|
647 |
+
embs_df = pd.DataFrame(
|
648 |
+
embs_df[0 : emb_dims - 1].mean(axis="rows"),
|
649 |
+
columns=[self.exact_summary_stat],
|
650 |
+
).T
|
651 |
+
elif self.exact_summary_stat == "exact_median":
|
652 |
+
embs = torch.median(embs, dim=0)[0]
|
653 |
+
emb_dims = pu.get_model_emb_dims(model)
|
654 |
+
embs_df = pd.DataFrame(
|
655 |
+
embs_df[0 : emb_dims - 1].median(axis="rows"),
|
656 |
+
columns=[self.exact_summary_stat],
|
657 |
+
).T
|
658 |
+
|
659 |
+
if cell_state is not None:
|
660 |
+
return embs
|
661 |
+
else:
|
662 |
+
if output_torch_embs:
|
663 |
+
return embs_df, embs
|
664 |
+
else:
|
665 |
+
return embs_df
|
666 |
+
|
667 |
+
def get_state_embs(
|
668 |
+
self,
|
669 |
+
cell_states_to_model,
|
670 |
+
model_directory,
|
671 |
+
input_data_file,
|
672 |
+
output_directory,
|
673 |
+
output_prefix,
|
674 |
+
output_torch_embs=True,
|
675 |
+
):
|
676 |
+
"""
|
677 |
+
Extract exact mean or exact median cell state embedding positions from input data and save as results in output_directory.
|
678 |
+
|
679 |
+
**Parameters:**
|
680 |
+
|
681 |
+
cell_states_to_model : None, dict
|
682 |
+
| Cell states to model if testing perturbations that achieve goal state change.
|
683 |
+
| Four-item dictionary with keys: state_key, start_state, goal_state, and alt_states
|
684 |
+
| state_key: key specifying name of column in .dataset that defines the start/goal states
|
685 |
+
| start_state: value in the state_key column that specifies the start state
|
686 |
+
| goal_state: value in the state_key column taht specifies the goal end state
|
687 |
+
| alt_states: list of values in the state_key column that specify the alternate end states
|
688 |
+
| For example:
|
689 |
+
| {"state_key": "disease",
|
690 |
+
| "start_state": "dcm",
|
691 |
+
| "goal_state": "nf",
|
692 |
+
| "alt_states": ["hcm", "other1", "other2"]}
|
693 |
+
model_directory : Path
|
694 |
+
| Path to directory containing model
|
695 |
+
input_data_file : Path
|
696 |
+
| Path to directory containing .dataset inputs
|
697 |
+
output_directory : Path
|
698 |
+
| Path to directory where embedding data will be saved as csv
|
699 |
+
output_prefix : str
|
700 |
+
| Prefix for output file
|
701 |
+
output_torch_embs : bool
|
702 |
+
| Whether or not to also output the embeddings as a tensor.
|
703 |
+
| Note, if true, will output embeddings as both dataframe and tensor.
|
704 |
+
|
705 |
+
**Outputs**
|
706 |
+
|
707 |
+
| Outputs state_embs_dict for use with in silico perturber.
|
708 |
+
| Format is dictionary of embedding positions of each cell state to model shifts from/towards.
|
709 |
+
| Keys specify each possible cell state to model.
|
710 |
+
| Values are target embedding positions as torch.tensor.
|
711 |
+
| For example:
|
712 |
+
| {"nf": emb_nf,
|
713 |
+
| "hcm": emb_hcm,
|
714 |
+
| "dcm": emb_dcm,
|
715 |
+
| "other1": emb_other1,
|
716 |
+
| "other2": emb_other2}
|
717 |
+
"""
|
718 |
+
|
719 |
+
pu.validate_cell_states_to_model(cell_states_to_model)
|
720 |
+
valid_summary_stats = ["exact_mean", "exact_median"]
|
721 |
+
if self.exact_summary_stat not in valid_summary_stats:
|
722 |
+
logger.error(
|
723 |
+
"For extracting state embs, summary_stat in EmbExtractor "
|
724 |
+
f"must be set to option in {valid_summary_stats}"
|
725 |
+
)
|
726 |
+
raise
|
727 |
+
|
728 |
+
if self.emb_label is not None:
|
729 |
+
logger.error(
|
730 |
+
"For extracting state embs, emb_label should be None since labels are based on state embs dict keys."
|
731 |
+
)
|
732 |
+
raise
|
733 |
+
|
734 |
+
state_embs_dict = dict()
|
735 |
+
state_key = cell_states_to_model["state_key"]
|
736 |
+
for k, v in cell_states_to_model.items():
|
737 |
+
if k == "state_key":
|
738 |
+
continue
|
739 |
+
elif (k == "start_state") or (k == "goal_state"):
|
740 |
+
state_embs_dict[v] = self.extract_embs(
|
741 |
+
model_directory,
|
742 |
+
input_data_file,
|
743 |
+
output_directory,
|
744 |
+
output_prefix,
|
745 |
+
output_torch_embs,
|
746 |
+
cell_state={state_key: v},
|
747 |
+
)
|
748 |
+
else: # k == "alt_states"
|
749 |
+
for alt_state in v:
|
750 |
+
state_embs_dict[alt_state] = self.extract_embs(
|
751 |
+
model_directory,
|
752 |
+
input_data_file,
|
753 |
+
output_directory,
|
754 |
+
output_prefix,
|
755 |
+
output_torch_embs,
|
756 |
+
cell_state={state_key: alt_state},
|
757 |
+
)
|
758 |
+
|
759 |
+
output_path = (Path(output_directory) / output_prefix).with_suffix(".pkl")
|
760 |
+
with open(output_path, "wb") as fp:
|
761 |
+
pickle.dump(state_embs_dict, fp)
|
762 |
+
|
763 |
+
return state_embs_dict
|
764 |
+
|
765 |
+
def plot_embs(
|
766 |
+
self,
|
767 |
+
embs,
|
768 |
+
plot_style,
|
769 |
+
output_directory,
|
770 |
+
output_prefix,
|
771 |
+
max_ncells_to_plot=1000,
|
772 |
+
kwargs_dict=None,
|
773 |
+
):
|
774 |
+
"""
|
775 |
+
Plot embeddings, coloring by provided labels.
|
776 |
+
|
777 |
+
**Parameters:**
|
778 |
+
|
779 |
+
embs : pandas.core.frame.DataFrame
|
780 |
+
| Pandas dataframe containing embeddings output from extract_embs
|
781 |
+
plot_style : str
|
782 |
+
| Style of plot: "heatmap" or "umap"
|
783 |
+
output_directory : Path
|
784 |
+
| Path to directory where plots will be saved as pdf
|
785 |
+
output_prefix : str
|
786 |
+
| Prefix for output file
|
787 |
+
max_ncells_to_plot : None, int
|
788 |
+
| Maximum number of cells to plot.
|
789 |
+
| Default is 1000 cells randomly sampled from embeddings.
|
790 |
+
| If None, will plot embeddings from all cells.
|
791 |
+
kwargs_dict : dict
|
792 |
+
| Dictionary of kwargs to pass to plotting function.
|
793 |
+
|
794 |
+
**Examples:**
|
795 |
+
|
796 |
+
.. code-block :: python
|
797 |
+
|
798 |
+
>>> embex.plot_embs(embs=embs,
|
799 |
+
... plot_style="heatmap",
|
800 |
+
... output_directory="path/to/output_directory",
|
801 |
+
... output_prefix="output_prefix")
|
802 |
+
|
803 |
+
"""
|
804 |
+
|
805 |
+
if plot_style not in ["heatmap", "umap"]:
|
806 |
+
logger.error(
|
807 |
+
"Invalid option for 'plot_style'. " "Valid options: {'heatmap','umap'}"
|
808 |
+
)
|
809 |
+
raise
|
810 |
+
|
811 |
+
if (plot_style == "umap") and (self.labels_to_plot is None):
|
812 |
+
logger.error("Plotting UMAP requires 'labels_to_plot'. ")
|
813 |
+
raise
|
814 |
+
|
815 |
+
if max_ncells_to_plot is not None:
|
816 |
+
if max_ncells_to_plot > self.max_ncells:
|
817 |
+
max_ncells_to_plot = self.max_ncells
|
818 |
+
logger.warning(
|
819 |
+
"max_ncells_to_plot must be <= max_ncells. "
|
820 |
+
f"Changing max_ncells_to_plot to {self.max_ncells}."
|
821 |
+
)
|
822 |
+
elif max_ncells_to_plot < self.max_ncells:
|
823 |
+
embs = embs.sample(max_ncells_to_plot, axis=0)
|
824 |
+
|
825 |
+
if self.emb_label is None:
|
826 |
+
label_len = 0
|
827 |
+
else:
|
828 |
+
label_len = len(self.emb_label)
|
829 |
+
|
830 |
+
emb_dims = embs.shape[1] - label_len
|
831 |
+
|
832 |
+
if self.emb_label is None:
|
833 |
+
emb_labels = None
|
834 |
+
else:
|
835 |
+
emb_labels = embs.columns[emb_dims:]
|
836 |
+
|
837 |
+
if plot_style == "umap":
|
838 |
+
for label in self.labels_to_plot:
|
839 |
+
if label not in emb_labels:
|
840 |
+
logger.warning(
|
841 |
+
f"Label {label} from labels_to_plot "
|
842 |
+
f"not present in provided embeddings dataframe."
|
843 |
+
)
|
844 |
+
continue
|
845 |
+
output_prefix_label = output_prefix + f"_umap_{label}"
|
846 |
+
output_file = (
|
847 |
+
Path(output_directory) / output_prefix_label
|
848 |
+
).with_suffix(".pdf")
|
849 |
+
plot_umap(embs, emb_dims, label, output_file, kwargs_dict)
|
850 |
+
|
851 |
+
if plot_style == "heatmap":
|
852 |
+
for label in self.labels_to_plot:
|
853 |
+
if label not in emb_labels:
|
854 |
+
logger.warning(
|
855 |
+
f"Label {label} from labels_to_plot "
|
856 |
+
f"not present in provided embeddings dataframe."
|
857 |
+
)
|
858 |
+
continue
|
859 |
+
output_prefix_label = output_prefix + f"_heatmap_{label}"
|
860 |
+
output_file = (
|
861 |
+
Path(output_directory) / output_prefix_label
|
862 |
+
).with_suffix(".pdf")
|
863 |
+
plot_heatmap(embs, emb_dims, label, output_file, kwargs_dict)
|
geneformer/evaluation_utils.py
ADDED
@@ -0,0 +1,287 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import logging
|
2 |
+
import math
|
3 |
+
import pickle
|
4 |
+
from pathlib import Path
|
5 |
+
|
6 |
+
import matplotlib.pyplot as plt
|
7 |
+
import numpy as np
|
8 |
+
import pandas as pd
|
9 |
+
import seaborn as sns
|
10 |
+
import torch
|
11 |
+
from datasets.utils.logging import disable_progress_bar, enable_progress_bar
|
12 |
+
from sklearn import preprocessing
|
13 |
+
from sklearn.metrics import (
|
14 |
+
ConfusionMatrixDisplay,
|
15 |
+
accuracy_score,
|
16 |
+
auc,
|
17 |
+
confusion_matrix,
|
18 |
+
f1_score,
|
19 |
+
roc_curve,
|
20 |
+
)
|
21 |
+
from tqdm.auto import trange
|
22 |
+
|
23 |
+
from . import TOKEN_DICTIONARY_FILE
|
24 |
+
from .emb_extractor import make_colorbar
|
25 |
+
|
26 |
+
logger = logging.getLogger(__name__)
|
27 |
+
|
28 |
+
|
29 |
+
def preprocess_classifier_batch(cell_batch, max_len, label_name):
|
30 |
+
if max_len is None:
|
31 |
+
max_len = max([len(i) for i in cell_batch["input_ids"]])
|
32 |
+
|
33 |
+
# load token dictionary (Ensembl IDs:token)
|
34 |
+
with open(TOKEN_DICTIONARY_FILE, "rb") as f:
|
35 |
+
gene_token_dict = pickle.load(f)
|
36 |
+
|
37 |
+
def pad_label_example(example):
|
38 |
+
example[label_name] = np.pad(
|
39 |
+
example[label_name],
|
40 |
+
(0, max_len - len(example["input_ids"])),
|
41 |
+
mode="constant",
|
42 |
+
constant_values=-100,
|
43 |
+
)
|
44 |
+
example["input_ids"] = np.pad(
|
45 |
+
example["input_ids"],
|
46 |
+
(0, max_len - len(example["input_ids"])),
|
47 |
+
mode="constant",
|
48 |
+
constant_values=gene_token_dict.get("<pad>"),
|
49 |
+
)
|
50 |
+
example["attention_mask"] = (
|
51 |
+
example["input_ids"] != gene_token_dict.get("<pad>")
|
52 |
+
).astype(int)
|
53 |
+
return example
|
54 |
+
|
55 |
+
padded_batch = cell_batch.map(pad_label_example)
|
56 |
+
return padded_batch
|
57 |
+
|
58 |
+
|
59 |
+
# Function to find the largest number smaller
|
60 |
+
# than or equal to N that is divisible by k
|
61 |
+
def find_largest_div(N, K):
|
62 |
+
rem = N % K
|
63 |
+
if rem == 0:
|
64 |
+
return N
|
65 |
+
else:
|
66 |
+
return N - rem
|
67 |
+
|
68 |
+
|
69 |
+
def vote(logit_list):
|
70 |
+
m = max(logit_list)
|
71 |
+
logit_list.index(m)
|
72 |
+
indices = [i for i, x in enumerate(logit_list) if x == m]
|
73 |
+
if len(indices) > 1:
|
74 |
+
return "tie"
|
75 |
+
else:
|
76 |
+
return indices[0]
|
77 |
+
|
78 |
+
|
79 |
+
def py_softmax(vector):
|
80 |
+
e = np.exp(vector)
|
81 |
+
return e / e.sum()
|
82 |
+
|
83 |
+
|
84 |
+
def classifier_predict(model, classifier_type, evalset, forward_batch_size):
|
85 |
+
if classifier_type == "gene":
|
86 |
+
label_name = "labels"
|
87 |
+
elif classifier_type == "cell":
|
88 |
+
label_name = "label"
|
89 |
+
|
90 |
+
predict_logits = []
|
91 |
+
predict_labels = []
|
92 |
+
model.eval()
|
93 |
+
|
94 |
+
# ensure there is at least 2 examples in each batch to avoid incorrect tensor dims
|
95 |
+
evalset_len = len(evalset)
|
96 |
+
max_divisible = find_largest_div(evalset_len, forward_batch_size)
|
97 |
+
if len(evalset) - max_divisible == 1:
|
98 |
+
evalset_len = max_divisible
|
99 |
+
|
100 |
+
max_evalset_len = max(evalset.select([i for i in range(evalset_len)])["length"])
|
101 |
+
|
102 |
+
disable_progress_bar() # disable progress bar for preprocess_classifier_batch mapping
|
103 |
+
for i in trange(0, evalset_len, forward_batch_size):
|
104 |
+
max_range = min(i + forward_batch_size, evalset_len)
|
105 |
+
batch_evalset = evalset.select([i for i in range(i, max_range)])
|
106 |
+
padded_batch = preprocess_classifier_batch(
|
107 |
+
batch_evalset, max_evalset_len, label_name
|
108 |
+
)
|
109 |
+
padded_batch.set_format(type="torch")
|
110 |
+
|
111 |
+
input_data_batch = padded_batch["input_ids"]
|
112 |
+
attn_msk_batch = padded_batch["attention_mask"]
|
113 |
+
label_batch = padded_batch[label_name]
|
114 |
+
with torch.no_grad():
|
115 |
+
outputs = model(
|
116 |
+
input_ids=input_data_batch.to("cuda"),
|
117 |
+
attention_mask=attn_msk_batch.to("cuda"),
|
118 |
+
labels=label_batch.to("cuda"),
|
119 |
+
)
|
120 |
+
predict_logits += [torch.squeeze(outputs.logits.to("cpu"))]
|
121 |
+
predict_labels += [torch.squeeze(label_batch.to("cpu"))]
|
122 |
+
|
123 |
+
enable_progress_bar()
|
124 |
+
logits_by_cell = torch.cat(predict_logits)
|
125 |
+
last_dim = len(logits_by_cell.shape) - 1
|
126 |
+
all_logits = logits_by_cell.reshape(-1, logits_by_cell.shape[last_dim])
|
127 |
+
labels_by_cell = torch.cat(predict_labels)
|
128 |
+
all_labels = torch.flatten(labels_by_cell)
|
129 |
+
logit_label_paired = [
|
130 |
+
item
|
131 |
+
for item in list(zip(all_logits.tolist(), all_labels.tolist()))
|
132 |
+
if item[1] != -100
|
133 |
+
]
|
134 |
+
y_pred = [vote(item[0]) for item in logit_label_paired]
|
135 |
+
y_true = [item[1] for item in logit_label_paired]
|
136 |
+
logits_list = [item[0] for item in logit_label_paired]
|
137 |
+
return y_pred, y_true, logits_list
|
138 |
+
|
139 |
+
|
140 |
+
def get_metrics(y_pred, y_true, logits_list, num_classes, labels):
|
141 |
+
conf_mat = confusion_matrix(y_true, y_pred, labels=list(labels))
|
142 |
+
macro_f1 = f1_score(y_true, y_pred, average="macro")
|
143 |
+
acc = accuracy_score(y_true, y_pred)
|
144 |
+
roc_metrics = None # roc metrics not reported for multiclass
|
145 |
+
if num_classes == 2:
|
146 |
+
y_score = [py_softmax(item)[1] for item in logits_list]
|
147 |
+
fpr, tpr, _ = roc_curve(y_true, y_score)
|
148 |
+
mean_fpr = np.linspace(0, 1, 100)
|
149 |
+
interp_tpr = np.interp(mean_fpr, fpr, tpr)
|
150 |
+
interp_tpr[0] = 0.0
|
151 |
+
tpr_wt = len(tpr)
|
152 |
+
roc_auc = auc(fpr, tpr)
|
153 |
+
roc_metrics = {
|
154 |
+
"fpr": fpr,
|
155 |
+
"tpr": tpr,
|
156 |
+
"interp_tpr": interp_tpr,
|
157 |
+
"auc": roc_auc,
|
158 |
+
"tpr_wt": tpr_wt,
|
159 |
+
}
|
160 |
+
return conf_mat, macro_f1, acc, roc_metrics
|
161 |
+
|
162 |
+
|
163 |
+
# get cross-validated mean and sd metrics
|
164 |
+
def get_cross_valid_roc_metrics(all_tpr, all_roc_auc, all_tpr_wt):
|
165 |
+
wts = [count / sum(all_tpr_wt) for count in all_tpr_wt]
|
166 |
+
all_weighted_tpr = [a * b for a, b in zip(all_tpr, wts)]
|
167 |
+
mean_tpr = np.sum(all_weighted_tpr, axis=0)
|
168 |
+
mean_tpr[-1] = 1.0
|
169 |
+
all_weighted_roc_auc = [a * b for a, b in zip(all_roc_auc, wts)]
|
170 |
+
roc_auc = np.sum(all_weighted_roc_auc)
|
171 |
+
roc_auc_sd = math.sqrt(np.average((all_roc_auc - roc_auc) ** 2, weights=wts))
|
172 |
+
return mean_tpr, roc_auc, roc_auc_sd
|
173 |
+
|
174 |
+
|
175 |
+
# plot ROC curve
|
176 |
+
def plot_ROC(roc_metric_dict, model_style_dict, title, output_dir, output_prefix):
|
177 |
+
fig = plt.figure()
|
178 |
+
fig.set_size_inches(10, 8)
|
179 |
+
sns.set(font_scale=2)
|
180 |
+
sns.set_style("white")
|
181 |
+
lw = 3
|
182 |
+
for model_name in roc_metric_dict.keys():
|
183 |
+
mean_fpr = roc_metric_dict[model_name]["mean_fpr"]
|
184 |
+
mean_tpr = roc_metric_dict[model_name]["mean_tpr"]
|
185 |
+
roc_auc = roc_metric_dict[model_name]["roc_auc"]
|
186 |
+
roc_auc_sd = roc_metric_dict[model_name]["roc_auc_sd"]
|
187 |
+
color = model_style_dict[model_name]["color"]
|
188 |
+
linestyle = model_style_dict[model_name]["linestyle"]
|
189 |
+
if len(roc_metric_dict[model_name]["all_roc_auc"]) > 1:
|
190 |
+
label = f"{model_name} (AUC {roc_auc:0.2f} $\pm$ {roc_auc_sd:0.2f})"
|
191 |
+
else:
|
192 |
+
label = f"{model_name} (AUC {roc_auc:0.2f})"
|
193 |
+
plt.plot(
|
194 |
+
mean_fpr, mean_tpr, color=color, linestyle=linestyle, lw=lw, label=label
|
195 |
+
)
|
196 |
+
|
197 |
+
plt.plot([0, 1], [0, 1], color="black", lw=lw, linestyle="--")
|
198 |
+
plt.xlim([0.0, 1.0])
|
199 |
+
plt.ylim([0.0, 1.05])
|
200 |
+
plt.xlabel("False Positive Rate")
|
201 |
+
plt.ylabel("True Positive Rate")
|
202 |
+
plt.title(title)
|
203 |
+
plt.legend(loc="lower right")
|
204 |
+
|
205 |
+
output_file = (Path(output_dir) / f"{output_prefix}_roc").with_suffix(".pdf")
|
206 |
+
plt.savefig(output_file, bbox_inches="tight")
|
207 |
+
plt.show()
|
208 |
+
|
209 |
+
|
210 |
+
# plot confusion matrix
|
211 |
+
def plot_confusion_matrix(
|
212 |
+
conf_mat_df, title, output_dir, output_prefix, custom_class_order
|
213 |
+
):
|
214 |
+
fig = plt.figure()
|
215 |
+
fig.set_size_inches(10, 10)
|
216 |
+
sns.set(font_scale=1)
|
217 |
+
sns.set_style("whitegrid", {"axes.grid": False})
|
218 |
+
if custom_class_order is not None:
|
219 |
+
conf_mat_df = conf_mat_df.reindex(
|
220 |
+
index=custom_class_order, columns=custom_class_order
|
221 |
+
)
|
222 |
+
display_labels = generate_display_labels(conf_mat_df)
|
223 |
+
conf_mat = preprocessing.normalize(conf_mat_df.to_numpy(), norm="l1")
|
224 |
+
display = ConfusionMatrixDisplay(
|
225 |
+
confusion_matrix=conf_mat, display_labels=display_labels
|
226 |
+
)
|
227 |
+
display.plot(cmap="Blues", values_format=".2g")
|
228 |
+
plt.title(title)
|
229 |
+
plt.show()
|
230 |
+
|
231 |
+
output_file = (Path(output_dir) / f"{output_prefix}_conf_mat").with_suffix(".pdf")
|
232 |
+
display.figure_.savefig(output_file, bbox_inches="tight")
|
233 |
+
|
234 |
+
|
235 |
+
def generate_display_labels(conf_mat_df):
|
236 |
+
display_labels = []
|
237 |
+
i = 0
|
238 |
+
for label in conf_mat_df.index:
|
239 |
+
display_labels += [f"{label}\nn={conf_mat_df.iloc[i,:].sum():.0f}"]
|
240 |
+
i = i + 1
|
241 |
+
return display_labels
|
242 |
+
|
243 |
+
|
244 |
+
def plot_predictions(predictions_df, title, output_dir, output_prefix, kwargs_dict):
|
245 |
+
sns.set(font_scale=2)
|
246 |
+
plt.figure(figsize=(10, 10), dpi=150)
|
247 |
+
label_colors, label_color_dict = make_colorbar(predictions_df, "true")
|
248 |
+
predictions_df = predictions_df.drop(columns=["true"])
|
249 |
+
predict_colors_list = [label_color_dict[label] for label in predictions_df.columns]
|
250 |
+
predict_label_list = [label for label in predictions_df.columns]
|
251 |
+
predict_colors = pd.DataFrame(
|
252 |
+
pd.Series(predict_colors_list, index=predict_label_list), columns=["predicted"]
|
253 |
+
)
|
254 |
+
|
255 |
+
default_kwargs_dict = {
|
256 |
+
"row_cluster": False,
|
257 |
+
"col_cluster": False,
|
258 |
+
"row_colors": label_colors,
|
259 |
+
"col_colors": predict_colors,
|
260 |
+
"linewidths": 0,
|
261 |
+
"xticklabels": False,
|
262 |
+
"yticklabels": False,
|
263 |
+
"center": 0,
|
264 |
+
"cmap": "vlag",
|
265 |
+
}
|
266 |
+
|
267 |
+
if kwargs_dict is not None:
|
268 |
+
default_kwargs_dict.update(kwargs_dict)
|
269 |
+
g = sns.clustermap(predictions_df, **default_kwargs_dict)
|
270 |
+
|
271 |
+
plt.setp(g.ax_row_colors.get_xmajorticklabels(), rotation=45, ha="right")
|
272 |
+
|
273 |
+
for label_color in list(label_color_dict.keys()):
|
274 |
+
g.ax_col_dendrogram.bar(
|
275 |
+
0, 0, color=label_color_dict[label_color], label=label_color, linewidth=0
|
276 |
+
)
|
277 |
+
|
278 |
+
g.ax_col_dendrogram.legend(
|
279 |
+
title=f"{title}",
|
280 |
+
loc="lower center",
|
281 |
+
ncol=4,
|
282 |
+
bbox_to_anchor=(0.5, 1),
|
283 |
+
facecolor="white",
|
284 |
+
)
|
285 |
+
|
286 |
+
output_file = (Path(output_dir) / f"{output_prefix}_pred").with_suffix(".pdf")
|
287 |
+
plt.savefig(output_file, bbox_inches="tight")
|
geneformer/in_silico_perturber.py
ADDED
@@ -0,0 +1,1579 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
"""
|
2 |
+
Geneformer in silico perturber.
|
3 |
+
|
4 |
+
**Usage:**
|
5 |
+
|
6 |
+
.. code-block :: python
|
7 |
+
|
8 |
+
>>> from geneformer import InSilicoPerturber
|
9 |
+
>>> isp = InSilicoPerturber(perturb_type="delete",
|
10 |
+
... perturb_rank_shift=None,
|
11 |
+
... genes_to_perturb="all",
|
12 |
+
... model_type="CellClassifier",
|
13 |
+
... num_classes=0,
|
14 |
+
... emb_mode="cell",
|
15 |
+
... filter_data={"cell_type":["cardiomyocyte"]},
|
16 |
+
... cell_states_to_model={"state_key": "disease", "start_state": "dcm", "goal_state": "nf", "alt_states": ["hcm", "other1", "other2"]},
|
17 |
+
... state_embs_dict ={"nf": emb_nf, "hcm": emb_hcm, "dcm": emb_dcm, "other1": emb_other1, "other2": emb_other2},
|
18 |
+
... max_ncells=None,
|
19 |
+
... emb_layer=0,
|
20 |
+
... forward_batch_size=100,
|
21 |
+
... nproc=16)
|
22 |
+
>>> isp.perturb_data("path/to/model",
|
23 |
+
... "path/to/input_data",
|
24 |
+
... "path/to/output_directory",
|
25 |
+
... "output_prefix")
|
26 |
+
|
27 |
+
**Description:**
|
28 |
+
|
29 |
+
| Performs in silico perturbation (e.g. deletion or overexpression) of defined set of genes or all genes in sample of cells.
|
30 |
+
| Outputs impact of perturbation on cell or gene embeddings.
|
31 |
+
| Output files are analyzed with ``in_silico_perturber_stats``.
|
32 |
+
|
33 |
+
"""
|
34 |
+
|
35 |
+
import logging
|
36 |
+
|
37 |
+
# imports
|
38 |
+
import os
|
39 |
+
import pickle
|
40 |
+
from collections import defaultdict
|
41 |
+
|
42 |
+
import torch
|
43 |
+
from datasets import Dataset
|
44 |
+
from multiprocess import set_start_method
|
45 |
+
from tqdm.auto import trange
|
46 |
+
|
47 |
+
from . import TOKEN_DICTIONARY_FILE
|
48 |
+
from . import perturber_utils as pu
|
49 |
+
from .emb_extractor import get_embs
|
50 |
+
|
51 |
+
import datasets
|
52 |
+
datasets.logging.disable_progress_bar()
|
53 |
+
|
54 |
+
|
55 |
+
logger = logging.getLogger(__name__)
|
56 |
+
|
57 |
+
|
58 |
+
class InSilicoPerturber:
|
59 |
+
valid_option_dict = {
|
60 |
+
"perturb_type": {"delete", "overexpress", "inhibit", "activate"},
|
61 |
+
"perturb_rank_shift": {None, 1, 2, 3},
|
62 |
+
"genes_to_perturb": {"all", list},
|
63 |
+
"combos": {0, 1},
|
64 |
+
"anchor_gene": {None, str},
|
65 |
+
"model_type": {"Pretrained", "GeneClassifier", "CellClassifier", "MTLCellClassifier", "MTLCellClassifier-Quantized"},
|
66 |
+
"num_classes": {int},
|
67 |
+
"emb_mode": {"cls", "cell", "cls_and_gene", "cell_and_gene"},
|
68 |
+
"cell_emb_style": {"mean_pool"},
|
69 |
+
"filter_data": {None, dict},
|
70 |
+
"cell_states_to_model": {None, dict},
|
71 |
+
"state_embs_dict": {None, dict},
|
72 |
+
"max_ncells": {None, int},
|
73 |
+
"cell_inds_to_perturb": {"all", dict},
|
74 |
+
"emb_layer": {-1, 0},
|
75 |
+
"token_dictionary_file": {None, str},
|
76 |
+
"forward_batch_size": {int},
|
77 |
+
"nproc": {int},
|
78 |
+
}
|
79 |
+
|
80 |
+
def __init__(
|
81 |
+
self,
|
82 |
+
perturb_type="delete",
|
83 |
+
perturb_rank_shift=None,
|
84 |
+
genes_to_perturb="all",
|
85 |
+
combos=0,
|
86 |
+
anchor_gene=None,
|
87 |
+
model_type="Pretrained",
|
88 |
+
num_classes=0,
|
89 |
+
emb_mode="cls",
|
90 |
+
cell_emb_style="mean_pool",
|
91 |
+
filter_data=None,
|
92 |
+
cell_states_to_model=None,
|
93 |
+
state_embs_dict=None,
|
94 |
+
max_ncells=None,
|
95 |
+
cell_inds_to_perturb="all",
|
96 |
+
emb_layer=-1,
|
97 |
+
forward_batch_size=100,
|
98 |
+
nproc=4,
|
99 |
+
token_dictionary_file=None,
|
100 |
+
clear_mem_ncells=1000,
|
101 |
+
):
|
102 |
+
"""
|
103 |
+
Initialize in silico perturber.
|
104 |
+
|
105 |
+
**Parameters:**
|
106 |
+
|
107 |
+
perturb_type : {"delete", "overexpress", "inhibit", "activate"}
|
108 |
+
| Type of perturbation.
|
109 |
+
| "delete": delete gene from rank value encoding
|
110 |
+
| "overexpress": move gene to front of rank value encoding
|
111 |
+
| *(TBA)* "inhibit": move gene to lower quartile of rank value encoding
|
112 |
+
| *(TBA)* "activate": move gene to higher quartile of rank value encoding
|
113 |
+
*(TBA)* perturb_rank_shift : None, {1,2,3}
|
114 |
+
| Number of quartiles by which to shift rank of gene.
|
115 |
+
| For example, if perturb_type="activate" and perturb_rank_shift=1:
|
116 |
+
| genes in 4th quartile will move to middle of 3rd quartile.
|
117 |
+
| genes in 3rd quartile will move to middle of 2nd quartile.
|
118 |
+
| genes in 2nd quartile will move to middle of 1st quartile.
|
119 |
+
| genes in 1st quartile will move to front of rank value encoding.
|
120 |
+
| For example, if perturb_type="inhibit" and perturb_rank_shift=2:
|
121 |
+
| genes in 1st quartile will move to middle of 3rd quartile.
|
122 |
+
| genes in 2nd quartile will move to middle of 4th quartile.
|
123 |
+
| genes in 3rd or 4th quartile will move to bottom of rank value encoding.
|
124 |
+
genes_to_perturb : "all", list
|
125 |
+
| Default is perturbing each gene detected in each cell in the dataset.
|
126 |
+
| Otherwise, may provide a list of ENSEMBL IDs of genes to perturb.
|
127 |
+
| If gene list is provided, then perturber will only test perturbing them all together
|
128 |
+
| (rather than testing each possible combination of the provided genes).
|
129 |
+
combos : {0,1}
|
130 |
+
| Whether to perturb genes individually (0) or in pairs (1).
|
131 |
+
anchor_gene : None, str
|
132 |
+
| ENSEMBL ID of gene to use as anchor in combination perturbations.
|
133 |
+
| For example, if combos=1 and anchor_gene="ENSG00000148400":
|
134 |
+
| anchor gene will be perturbed in combination with each other gene.
|
135 |
+
model_type : {"Pretrained", "GeneClassifier", "CellClassifier", "MTLCellClassifier", "MTLCellClassifier-Quantized"}
|
136 |
+
| Whether model is the pretrained Geneformer or a fine-tuned gene, cell, or multitask cell classifier (+/- 8bit quantization).
|
137 |
+
num_classes : int
|
138 |
+
| If model is a gene or cell classifier, specify number of classes it was trained to classify.
|
139 |
+
| For the pretrained Geneformer model, number of classes is 0 as it is not a classifier.
|
140 |
+
emb_mode : {"cls", "cell", "cls_and_gene","cell_and_gene"}
|
141 |
+
| Whether to output impact of perturbation on CLS token, cell, and/or gene embeddings.
|
142 |
+
| Gene embedding shifts only available as compared to original cell, not comparing to goal state.
|
143 |
+
cell_emb_style : "mean_pool"
|
144 |
+
| Method for summarizing cell embeddings if not using CLS token.
|
145 |
+
| Currently only option is mean pooling of gene embeddings for given cell.
|
146 |
+
filter_data : None, dict
|
147 |
+
| Default is to use all input data for in silico perturbation study.
|
148 |
+
| Otherwise, dictionary specifying .dataset column name and list of values to filter by.
|
149 |
+
cell_states_to_model : None, dict
|
150 |
+
| Cell states to model if testing perturbations that achieve goal state change.
|
151 |
+
| Four-item dictionary with keys: state_key, start_state, goal_state, and alt_states
|
152 |
+
| state_key: key specifying name of column in .dataset that defines the start/goal states
|
153 |
+
| start_state: value in the state_key column that specifies the start state
|
154 |
+
| goal_state: value in the state_key column taht specifies the goal end state
|
155 |
+
| alt_states: list of values in the state_key column that specify the alternate end states
|
156 |
+
| For example: {"state_key": "disease",
|
157 |
+
| "start_state": "dcm",
|
158 |
+
| "goal_state": "nf",
|
159 |
+
| "alt_states": ["hcm", "other1", "other2"]}
|
160 |
+
state_embs_dict : None, dict
|
161 |
+
| Embedding positions of each cell state to model shifts from/towards (e.g. mean or median).
|
162 |
+
| Dictionary with keys specifying each possible cell state to model.
|
163 |
+
| Values are target embedding positions as torch.tensor.
|
164 |
+
| For example: {"nf": emb_nf,
|
165 |
+
| "hcm": emb_hcm,
|
166 |
+
| "dcm": emb_dcm,
|
167 |
+
| "other1": emb_other1,
|
168 |
+
| "other2": emb_other2}
|
169 |
+
max_ncells : None, int
|
170 |
+
| Maximum number of cells to test.
|
171 |
+
| If None, will test all cells.
|
172 |
+
cell_inds_to_perturb : "all", list
|
173 |
+
| Default is perturbing each cell in the dataset.
|
174 |
+
| Otherwise, may provide a dict of indices of cells to perturb with keys start_ind and end_ind.
|
175 |
+
| start_ind: the first index to perturb.
|
176 |
+
| end_ind: the last index to perturb (exclusive).
|
177 |
+
| Indices will be selected *after* the filter_data criteria and sorting.
|
178 |
+
| Useful for splitting extremely large datasets across separate GPUs.
|
179 |
+
emb_layer : {-1, 0}
|
180 |
+
| Embedding layer to use for quantification.
|
181 |
+
| 0: last layer (recommended for questions closely tied to model's training objective)
|
182 |
+
| -1: 2nd to last layer (recommended for questions requiring more general representations)
|
183 |
+
forward_batch_size : int
|
184 |
+
| Batch size for forward pass.
|
185 |
+
nproc : int
|
186 |
+
| Number of CPU processes to use.
|
187 |
+
token_dictionary_file : Path
|
188 |
+
| Path to pickle file containing token dictionary (Ensembl ID:token).
|
189 |
+
clear_mem_ncells : int
|
190 |
+
| Clear memory every n cells.
|
191 |
+
"""
|
192 |
+
try:
|
193 |
+
set_start_method("spawn")
|
194 |
+
except RuntimeError:
|
195 |
+
pass
|
196 |
+
|
197 |
+
self.perturb_type = perturb_type
|
198 |
+
self.perturb_rank_shift = perturb_rank_shift
|
199 |
+
self.genes_to_perturb = genes_to_perturb
|
200 |
+
self.combos = combos
|
201 |
+
self.anchor_gene = anchor_gene
|
202 |
+
if self.genes_to_perturb == "all":
|
203 |
+
self.perturb_group = False
|
204 |
+
else:
|
205 |
+
self.perturb_group = True
|
206 |
+
if (self.anchor_gene is not None) or (self.combos != 0):
|
207 |
+
self.anchor_gene = None
|
208 |
+
self.combos = 0
|
209 |
+
logger.warning(
|
210 |
+
"anchor_gene set to None and combos set to 0. "
|
211 |
+
"If providing list of genes to perturb, "
|
212 |
+
"list of genes_to_perturb will be perturbed together, "
|
213 |
+
"without anchor gene or combinations."
|
214 |
+
)
|
215 |
+
self.model_type = model_type
|
216 |
+
self.num_classes = num_classes
|
217 |
+
self.emb_mode = emb_mode
|
218 |
+
self.cell_emb_style = cell_emb_style
|
219 |
+
self.filter_data = filter_data
|
220 |
+
self.cell_states_to_model = cell_states_to_model
|
221 |
+
self.state_embs_dict = state_embs_dict
|
222 |
+
self.max_ncells = max_ncells
|
223 |
+
self.cell_inds_to_perturb = cell_inds_to_perturb
|
224 |
+
self.emb_layer = emb_layer
|
225 |
+
self.forward_batch_size = forward_batch_size
|
226 |
+
self.nproc = nproc
|
227 |
+
self.token_dictionary_file = token_dictionary_file
|
228 |
+
self.clear_mem_ncells = clear_mem_ncells
|
229 |
+
|
230 |
+
self.validate_options()
|
231 |
+
|
232 |
+
# load token dictionary (Ensembl IDs:token)
|
233 |
+
if self.token_dictionary_file is None:
|
234 |
+
token_dictionary_file = TOKEN_DICTIONARY_FILE
|
235 |
+
with open(token_dictionary_file, "rb") as f:
|
236 |
+
self.gene_token_dict = pickle.load(f)
|
237 |
+
self.token_gene_dict = {v: k for k, v in self.gene_token_dict.items()}
|
238 |
+
|
239 |
+
self.pad_token_id = self.gene_token_dict.get("<pad>")
|
240 |
+
self.cls_token_id = self.gene_token_dict.get("<cls>")
|
241 |
+
self.eos_token_id = self.gene_token_dict.get("<eos>")
|
242 |
+
|
243 |
+
# Identify if special token is present in the token dictionary
|
244 |
+
if (self.cls_token_id is not None) and (self.eos_token_id is not None):
|
245 |
+
self.special_token = True
|
246 |
+
else:
|
247 |
+
if "cls" in self.emb_mode:
|
248 |
+
logger.error(
|
249 |
+
f"emb_mode set to {self.emb_mode} but <cls> or <eos> token not in token dictionary."
|
250 |
+
)
|
251 |
+
raise
|
252 |
+
self.special_token = False
|
253 |
+
|
254 |
+
if self.anchor_gene is None:
|
255 |
+
self.anchor_token = None
|
256 |
+
else:
|
257 |
+
try:
|
258 |
+
self.anchor_token = [self.gene_token_dict[self.anchor_gene]]
|
259 |
+
except KeyError:
|
260 |
+
logger.error(f"Anchor gene {self.anchor_gene} not in token dictionary.")
|
261 |
+
raise
|
262 |
+
|
263 |
+
if self.genes_to_perturb == "all":
|
264 |
+
self.tokens_to_perturb = "all"
|
265 |
+
else:
|
266 |
+
missing_genes = [
|
267 |
+
gene
|
268 |
+
for gene in self.genes_to_perturb
|
269 |
+
if gene not in self.gene_token_dict.keys()
|
270 |
+
]
|
271 |
+
if len(missing_genes) == len(self.genes_to_perturb):
|
272 |
+
logger.error(
|
273 |
+
"None of the provided genes to perturb are in token dictionary."
|
274 |
+
)
|
275 |
+
raise
|
276 |
+
elif len(missing_genes) > 0:
|
277 |
+
logger.warning(
|
278 |
+
f"Genes to perturb {missing_genes} are not in token dictionary."
|
279 |
+
)
|
280 |
+
self.tokens_to_perturb = [
|
281 |
+
self.gene_token_dict.get(gene) for gene in self.genes_to_perturb
|
282 |
+
]
|
283 |
+
|
284 |
+
def validate_options(self):
|
285 |
+
# first disallow options under development
|
286 |
+
if self.perturb_type in ["inhibit", "activate"]:
|
287 |
+
logger.error(
|
288 |
+
"In silico inhibition and activation currently under development. "
|
289 |
+
"Current valid options for 'perturb_type': 'delete' or 'overexpress'"
|
290 |
+
)
|
291 |
+
raise
|
292 |
+
if (self.combos > 0) and (self.anchor_gene is None):
|
293 |
+
logger.error(
|
294 |
+
"Combination perturbation without anchor gene is currently under development. "
|
295 |
+
"Currently, must provide anchor gene for combination perturbation."
|
296 |
+
)
|
297 |
+
raise
|
298 |
+
|
299 |
+
# confirm arguments are within valid options and compatible with each other
|
300 |
+
for attr_name, valid_options in self.valid_option_dict.items():
|
301 |
+
attr_value = self.__dict__[attr_name]
|
302 |
+
if type(attr_value) not in {list, dict}:
|
303 |
+
if attr_value in valid_options:
|
304 |
+
continue
|
305 |
+
if attr_name in ["anchor_gene"]:
|
306 |
+
if type(attr_name) in {str}:
|
307 |
+
continue
|
308 |
+
valid_type = False
|
309 |
+
for option in valid_options:
|
310 |
+
if (option in [bool, int, list, dict, str]) and isinstance(
|
311 |
+
attr_value, option
|
312 |
+
):
|
313 |
+
valid_type = True
|
314 |
+
break
|
315 |
+
if valid_type:
|
316 |
+
continue
|
317 |
+
logger.error(
|
318 |
+
f"Invalid option for {attr_name}. "
|
319 |
+
f"Valid options for {attr_name}: {valid_options}"
|
320 |
+
)
|
321 |
+
raise
|
322 |
+
|
323 |
+
if self.perturb_type in ["delete", "overexpress"]:
|
324 |
+
if self.perturb_rank_shift is not None:
|
325 |
+
if self.perturb_type == "delete":
|
326 |
+
logger.warning(
|
327 |
+
"perturb_rank_shift set to None. "
|
328 |
+
"If perturb type is delete then gene is deleted entirely "
|
329 |
+
"rather than shifted by quartile"
|
330 |
+
)
|
331 |
+
elif self.perturb_type == "overexpress":
|
332 |
+
logger.warning(
|
333 |
+
"perturb_rank_shift set to None. "
|
334 |
+
"If perturb type is overexpress then gene is moved to front "
|
335 |
+
"of rank value encoding rather than shifted by quartile"
|
336 |
+
)
|
337 |
+
self.perturb_rank_shift = None
|
338 |
+
|
339 |
+
if (self.anchor_gene is not None) and (self.emb_mode == "cell_and_gene"):
|
340 |
+
self.emb_mode = "cell"
|
341 |
+
logger.warning(
|
342 |
+
"emb_mode set to 'cell'. "
|
343 |
+
"Currently, analysis with anchor gene "
|
344 |
+
"only outputs effect on cell embeddings."
|
345 |
+
)
|
346 |
+
|
347 |
+
if self.cell_states_to_model is not None:
|
348 |
+
pu.validate_cell_states_to_model(self.cell_states_to_model)
|
349 |
+
|
350 |
+
if self.anchor_gene is not None:
|
351 |
+
self.anchor_gene = None
|
352 |
+
logger.warning(
|
353 |
+
"anchor_gene set to None. "
|
354 |
+
"Currently, anchor gene not available "
|
355 |
+
"when modeling multiple cell states."
|
356 |
+
)
|
357 |
+
|
358 |
+
if self.state_embs_dict is None:
|
359 |
+
logger.error(
|
360 |
+
"state_embs_dict must be provided for mode with cell_states_to_model. "
|
361 |
+
"Format is dictionary with keys specifying each possible cell state to model. "
|
362 |
+
"Values are target embedding positions as torch.tensor."
|
363 |
+
)
|
364 |
+
raise
|
365 |
+
|
366 |
+
for state_emb in self.state_embs_dict.values():
|
367 |
+
if not torch.is_tensor(state_emb):
|
368 |
+
logger.error(
|
369 |
+
"state_embs_dict must be dictionary with values being torch.tensor."
|
370 |
+
)
|
371 |
+
raise
|
372 |
+
|
373 |
+
keys_absent = []
|
374 |
+
for k, v in self.cell_states_to_model.items():
|
375 |
+
if (k == "start_state") or (k == "goal_state"):
|
376 |
+
if v not in self.state_embs_dict.keys():
|
377 |
+
keys_absent.append(v)
|
378 |
+
if k == "alt_states":
|
379 |
+
for state in v:
|
380 |
+
if state not in self.state_embs_dict.keys():
|
381 |
+
keys_absent.append(state)
|
382 |
+
if len(keys_absent) > 0:
|
383 |
+
logger.error(
|
384 |
+
"Each start_state, goal_state, and alt_states in cell_states_to_model "
|
385 |
+
"must be a key in state_embs_dict with the value being "
|
386 |
+
"the state's embedding position as torch.tensor. "
|
387 |
+
f"Missing keys: {keys_absent}"
|
388 |
+
)
|
389 |
+
raise
|
390 |
+
|
391 |
+
if self.perturb_type in ["inhibit", "activate"]:
|
392 |
+
if self.perturb_rank_shift is None:
|
393 |
+
logger.error(
|
394 |
+
"If perturb_type is inhibit or activate then "
|
395 |
+
"quartile to shift by must be specified."
|
396 |
+
)
|
397 |
+
raise
|
398 |
+
|
399 |
+
if self.filter_data is not None:
|
400 |
+
for key, value in self.filter_data.items():
|
401 |
+
if not isinstance(value, list):
|
402 |
+
self.filter_data[key] = [value]
|
403 |
+
logger.warning(
|
404 |
+
"Values in filter_data dict must be lists. "
|
405 |
+
f"Changing {key} value to list ([{value}])."
|
406 |
+
)
|
407 |
+
|
408 |
+
if self.cell_inds_to_perturb != "all":
|
409 |
+
if set(self.cell_inds_to_perturb.keys()) != {"start", "end"}:
|
410 |
+
logger.error(
|
411 |
+
"If cell_inds_to_perturb is a dictionary, keys must be 'start' and 'end'."
|
412 |
+
)
|
413 |
+
raise
|
414 |
+
if (
|
415 |
+
self.cell_inds_to_perturb["start"] < 0
|
416 |
+
or self.cell_inds_to_perturb["end"] < 0
|
417 |
+
):
|
418 |
+
logger.error("cell_inds_to_perturb must be positive.")
|
419 |
+
raise
|
420 |
+
|
421 |
+
def perturb_data(
|
422 |
+
self, model_directory, input_data_file, output_directory, output_prefix
|
423 |
+
):
|
424 |
+
"""
|
425 |
+
Perturb genes in input data and save as results in output_directory.
|
426 |
+
|
427 |
+
**Parameters:**
|
428 |
+
|
429 |
+
model_directory : Path
|
430 |
+
| Path to directory containing model
|
431 |
+
input_data_file : Path
|
432 |
+
| Path to directory containing .dataset inputs
|
433 |
+
output_directory : Path
|
434 |
+
| Path to directory where perturbation data will be saved as batched pickle files
|
435 |
+
output_prefix : str
|
436 |
+
| Prefix for output files
|
437 |
+
"""
|
438 |
+
|
439 |
+
### format output path ###
|
440 |
+
output_path_prefix = os.path.join(
|
441 |
+
output_directory, f"in_silico_{self.perturb_type}_{output_prefix}"
|
442 |
+
)
|
443 |
+
|
444 |
+
### load model and define parameters ###
|
445 |
+
model = pu.load_model(
|
446 |
+
self.model_type, self.num_classes, model_directory, mode="eval"
|
447 |
+
)
|
448 |
+
self.max_len = pu.get_model_input_size(model)
|
449 |
+
layer_to_quant = pu.quant_layers(model) + self.emb_layer
|
450 |
+
|
451 |
+
### filter input data ###
|
452 |
+
# general filtering of input data based on filter_data argument
|
453 |
+
filtered_input_data = pu.load_and_filter(
|
454 |
+
self.filter_data, self.nproc, input_data_file
|
455 |
+
)
|
456 |
+
|
457 |
+
# Ensure emb_mode is cls if first token of the filtered input data is cls token
|
458 |
+
if self.special_token:
|
459 |
+
if (filtered_input_data["input_ids"][0][0] == self.cls_token_id) and (
|
460 |
+
"cls" not in self.emb_mode
|
461 |
+
):
|
462 |
+
logger.error(
|
463 |
+
"Emb mode 'cls' or 'cls_and_gene' required when first token is <cls>."
|
464 |
+
)
|
465 |
+
raise
|
466 |
+
if "cls" in self.emb_mode:
|
467 |
+
if (filtered_input_data["input_ids"][0][0] != self.cls_token_id) or (
|
468 |
+
filtered_input_data["input_ids"][0][-1] != self.eos_token_id
|
469 |
+
):
|
470 |
+
logger.error(
|
471 |
+
"Emb mode 'cls' and 'cls_and_gene' require that first token is <cls> and last token is <eos>."
|
472 |
+
)
|
473 |
+
raise
|
474 |
+
|
475 |
+
filtered_input_data = self.apply_additional_filters(filtered_input_data)
|
476 |
+
|
477 |
+
if self.perturb_group is True:
|
478 |
+
if (self.special_token) and ("cls" in self.emb_mode):
|
479 |
+
self.isp_perturb_set_special(
|
480 |
+
model, filtered_input_data, layer_to_quant, output_path_prefix
|
481 |
+
)
|
482 |
+
else:
|
483 |
+
self.isp_perturb_set(
|
484 |
+
model, filtered_input_data, layer_to_quant, output_path_prefix
|
485 |
+
)
|
486 |
+
else:
|
487 |
+
if (self.special_token) and ("cls" in self.emb_mode):
|
488 |
+
self.isp_perturb_all_special(
|
489 |
+
model, filtered_input_data, layer_to_quant, output_path_prefix
|
490 |
+
)
|
491 |
+
else:
|
492 |
+
self.isp_perturb_all(
|
493 |
+
model, filtered_input_data, layer_to_quant, output_path_prefix
|
494 |
+
)
|
495 |
+
|
496 |
+
def apply_additional_filters(self, filtered_input_data):
|
497 |
+
# additional filtering of input data dependent on isp mode
|
498 |
+
if self.cell_states_to_model is not None:
|
499 |
+
# filter for cells with start_state and log result
|
500 |
+
filtered_input_data = pu.filter_data_by_start_state(
|
501 |
+
filtered_input_data, self.cell_states_to_model, self.nproc
|
502 |
+
)
|
503 |
+
|
504 |
+
if (self.tokens_to_perturb != "all") and (self.perturb_type != "overexpress"):
|
505 |
+
# filter for cells with tokens_to_perturb and log result
|
506 |
+
filtered_input_data = pu.filter_data_by_tokens_and_log(
|
507 |
+
filtered_input_data,
|
508 |
+
self.tokens_to_perturb,
|
509 |
+
self.nproc,
|
510 |
+
"genes_to_perturb",
|
511 |
+
)
|
512 |
+
|
513 |
+
if self.anchor_token is not None:
|
514 |
+
# filter for cells with anchor gene and log result
|
515 |
+
filtered_input_data = pu.filter_data_by_tokens_and_log(
|
516 |
+
filtered_input_data, self.anchor_token, self.nproc, "anchor_gene"
|
517 |
+
)
|
518 |
+
|
519 |
+
# downsample and sort largest to smallest to encounter memory constraints earlier
|
520 |
+
filtered_input_data = pu.downsample_and_sort(
|
521 |
+
filtered_input_data, self.max_ncells
|
522 |
+
)
|
523 |
+
|
524 |
+
# slice dataset if cells_inds_to_perturb is not "all"
|
525 |
+
if self.cell_inds_to_perturb != "all":
|
526 |
+
filtered_input_data = pu.slice_by_inds_to_perturb(
|
527 |
+
filtered_input_data, self.cell_inds_to_perturb
|
528 |
+
)
|
529 |
+
|
530 |
+
return filtered_input_data
|
531 |
+
|
532 |
+
def isp_perturb_set(
|
533 |
+
self,
|
534 |
+
model,
|
535 |
+
filtered_input_data: Dataset,
|
536 |
+
layer_to_quant: int,
|
537 |
+
output_path_prefix: str,
|
538 |
+
):
|
539 |
+
def make_group_perturbation_batch(example):
|
540 |
+
example_input_ids = example["input_ids"]
|
541 |
+
example["tokens_to_perturb"] = self.tokens_to_perturb
|
542 |
+
indices_to_perturb = [
|
543 |
+
example_input_ids.index(token) if token in example_input_ids else None
|
544 |
+
for token in self.tokens_to_perturb
|
545 |
+
]
|
546 |
+
indices_to_perturb = [
|
547 |
+
item for item in indices_to_perturb if item is not None
|
548 |
+
]
|
549 |
+
if len(indices_to_perturb) > 0:
|
550 |
+
example["perturb_index"] = indices_to_perturb
|
551 |
+
else:
|
552 |
+
# -100 indicates tokens to overexpress are not present in rank value encoding
|
553 |
+
example["perturb_index"] = [-100]
|
554 |
+
if self.perturb_type == "delete":
|
555 |
+
example = pu.delete_indices(example)
|
556 |
+
elif self.perturb_type == "overexpress":
|
557 |
+
example = pu.overexpress_tokens(
|
558 |
+
example, self.max_len, self.special_token
|
559 |
+
)
|
560 |
+
example["n_overflow"] = pu.calc_n_overflow(
|
561 |
+
self.max_len,
|
562 |
+
example["length"],
|
563 |
+
self.tokens_to_perturb,
|
564 |
+
indices_to_perturb,
|
565 |
+
)
|
566 |
+
return example
|
567 |
+
|
568 |
+
total_batch_length = len(filtered_input_data)
|
569 |
+
if self.cell_states_to_model is None:
|
570 |
+
cos_sims_dict = defaultdict(list)
|
571 |
+
else:
|
572 |
+
cos_sims_dict = {
|
573 |
+
state: defaultdict(list)
|
574 |
+
for state in pu.get_possible_states(self.cell_states_to_model)
|
575 |
+
}
|
576 |
+
|
577 |
+
perturbed_data = filtered_input_data.map(
|
578 |
+
make_group_perturbation_batch, num_proc=self.nproc
|
579 |
+
)
|
580 |
+
|
581 |
+
if self.perturb_type == "overexpress":
|
582 |
+
filtered_input_data = filtered_input_data.add_column(
|
583 |
+
"n_overflow", perturbed_data["n_overflow"]
|
584 |
+
)
|
585 |
+
# remove overflow genes from original data so that embeddings are comparable
|
586 |
+
# i.e. if original cell has genes 0:2047 and you want to overexpress new gene 2048,
|
587 |
+
# then the perturbed cell will be 2048+0:2046 so we compare it to an original cell 0:2046.
|
588 |
+
# (otherwise we will be modeling the effect of both deleting 2047 and adding 2048,
|
589 |
+
# rather than only adding 2048)
|
590 |
+
filtered_input_data = filtered_input_data.map(
|
591 |
+
pu.truncate_by_n_overflow, num_proc=self.nproc
|
592 |
+
)
|
593 |
+
|
594 |
+
if self.emb_mode == "cell_and_gene":
|
595 |
+
stored_gene_embs_dict = defaultdict(list)
|
596 |
+
|
597 |
+
# iterate through batches
|
598 |
+
for i in trange(0, total_batch_length, self.forward_batch_size):
|
599 |
+
max_range = min(i + self.forward_batch_size, total_batch_length)
|
600 |
+
inds_select = [i for i in range(i, max_range)]
|
601 |
+
|
602 |
+
minibatch = filtered_input_data.select(inds_select)
|
603 |
+
perturbation_batch = perturbed_data.select(inds_select)
|
604 |
+
|
605 |
+
if self.cell_emb_style == "mean_pool":
|
606 |
+
full_original_emb = get_embs(
|
607 |
+
model,
|
608 |
+
minibatch,
|
609 |
+
"gene",
|
610 |
+
layer_to_quant,
|
611 |
+
self.pad_token_id,
|
612 |
+
self.forward_batch_size,
|
613 |
+
token_gene_dict=self.token_gene_dict,
|
614 |
+
summary_stat=None,
|
615 |
+
silent=True,
|
616 |
+
)
|
617 |
+
indices_to_perturb = perturbation_batch["perturb_index"]
|
618 |
+
# remove indices that were perturbed
|
619 |
+
original_emb = pu.remove_perturbed_indices_set(
|
620 |
+
full_original_emb,
|
621 |
+
self.perturb_type,
|
622 |
+
indices_to_perturb,
|
623 |
+
self.tokens_to_perturb,
|
624 |
+
minibatch["length"],
|
625 |
+
)
|
626 |
+
full_perturbation_emb = get_embs(
|
627 |
+
model,
|
628 |
+
perturbation_batch,
|
629 |
+
"gene",
|
630 |
+
layer_to_quant,
|
631 |
+
self.pad_token_id,
|
632 |
+
self.forward_batch_size,
|
633 |
+
token_gene_dict=self.token_gene_dict,
|
634 |
+
summary_stat=None,
|
635 |
+
silent=True,
|
636 |
+
)
|
637 |
+
|
638 |
+
# remove overexpressed genes
|
639 |
+
if self.perturb_type == "overexpress":
|
640 |
+
perturbation_emb = full_perturbation_emb[
|
641 |
+
:, len(self.tokens_to_perturb) :, :
|
642 |
+
]
|
643 |
+
|
644 |
+
elif self.perturb_type == "delete":
|
645 |
+
perturbation_emb = full_perturbation_emb[
|
646 |
+
:, : max(perturbation_batch["length"]), :
|
647 |
+
]
|
648 |
+
|
649 |
+
n_perturbation_genes = perturbation_emb.size()[1]
|
650 |
+
|
651 |
+
# if no goal states, the cosine similarties are the mean of gene cosine similarities
|
652 |
+
if (
|
653 |
+
self.cell_states_to_model is None
|
654 |
+
or self.emb_mode == "cell_and_gene"
|
655 |
+
):
|
656 |
+
gene_cos_sims = pu.quant_cos_sims(
|
657 |
+
perturbation_emb,
|
658 |
+
original_emb,
|
659 |
+
self.cell_states_to_model,
|
660 |
+
self.state_embs_dict,
|
661 |
+
emb_mode="gene",
|
662 |
+
)
|
663 |
+
|
664 |
+
# if there are goal states, the cosine similarities are the cell cosine similarities
|
665 |
+
if self.cell_states_to_model is not None:
|
666 |
+
original_cell_emb = pu.mean_nonpadding_embs(
|
667 |
+
full_original_emb,
|
668 |
+
torch.tensor(minibatch["length"], device="cuda"),
|
669 |
+
dim=1,
|
670 |
+
)
|
671 |
+
perturbation_cell_emb = pu.mean_nonpadding_embs(
|
672 |
+
full_perturbation_emb,
|
673 |
+
torch.tensor(perturbation_batch["length"], device="cuda"),
|
674 |
+
dim=1,
|
675 |
+
)
|
676 |
+
cell_cos_sims = pu.quant_cos_sims(
|
677 |
+
perturbation_cell_emb,
|
678 |
+
original_cell_emb,
|
679 |
+
self.cell_states_to_model,
|
680 |
+
self.state_embs_dict,
|
681 |
+
emb_mode="cell",
|
682 |
+
)
|
683 |
+
|
684 |
+
# get cosine similarities in gene embeddings
|
685 |
+
# if getting gene embeddings, need gene names
|
686 |
+
if self.emb_mode == "cell_and_gene":
|
687 |
+
gene_list = minibatch["input_ids"]
|
688 |
+
# need to truncate gene_list
|
689 |
+
gene_list = [
|
690 |
+
[g for g in genes if g not in self.tokens_to_perturb][
|
691 |
+
:n_perturbation_genes
|
692 |
+
]
|
693 |
+
for genes in gene_list
|
694 |
+
]
|
695 |
+
|
696 |
+
for cell_i, genes in enumerate(gene_list):
|
697 |
+
for gene_j, affected_gene in enumerate(genes):
|
698 |
+
if len(self.genes_to_perturb) > 1:
|
699 |
+
tokens_to_perturb = tuple(self.tokens_to_perturb)
|
700 |
+
else:
|
701 |
+
tokens_to_perturb = self.tokens_to_perturb[0]
|
702 |
+
|
703 |
+
# fill in the gene cosine similarities
|
704 |
+
try:
|
705 |
+
stored_gene_embs_dict[
|
706 |
+
(tokens_to_perturb, affected_gene)
|
707 |
+
].append(gene_cos_sims[cell_i, gene_j].item())
|
708 |
+
except KeyError:
|
709 |
+
stored_gene_embs_dict[
|
710 |
+
(tokens_to_perturb, affected_gene)
|
711 |
+
] = gene_cos_sims[cell_i, gene_j].item()
|
712 |
+
else:
|
713 |
+
gene_list = None
|
714 |
+
|
715 |
+
if self.cell_states_to_model is None:
|
716 |
+
# calculate the mean of the gene cosine similarities for cell shift
|
717 |
+
# tensor of nonpadding lengths for each cell
|
718 |
+
if self.perturb_type == "overexpress":
|
719 |
+
# subtract number of genes that were overexpressed
|
720 |
+
# since they are removed before getting cos sims
|
721 |
+
n_overexpressed = len(self.tokens_to_perturb)
|
722 |
+
nonpadding_lens = [
|
723 |
+
x - n_overexpressed for x in perturbation_batch["length"]
|
724 |
+
]
|
725 |
+
else:
|
726 |
+
nonpadding_lens = perturbation_batch["length"]
|
727 |
+
cos_sims_data = pu.mean_nonpadding_embs(
|
728 |
+
gene_cos_sims, torch.tensor(nonpadding_lens, device="cuda")
|
729 |
+
)
|
730 |
+
cos_sims_dict = self.update_perturbation_dictionary(
|
731 |
+
cos_sims_dict,
|
732 |
+
cos_sims_data,
|
733 |
+
gene_list,
|
734 |
+
)
|
735 |
+
else:
|
736 |
+
cos_sims_data = cell_cos_sims
|
737 |
+
for state in cos_sims_dict.keys():
|
738 |
+
cos_sims_dict[state] = self.update_perturbation_dictionary(
|
739 |
+
cos_sims_dict[state],
|
740 |
+
cos_sims_data[state],
|
741 |
+
gene_list,
|
742 |
+
)
|
743 |
+
del minibatch
|
744 |
+
del perturbation_batch
|
745 |
+
del original_emb
|
746 |
+
del perturbation_emb
|
747 |
+
del cos_sims_data
|
748 |
+
|
749 |
+
torch.cuda.empty_cache()
|
750 |
+
|
751 |
+
pu.write_perturbation_dictionary(
|
752 |
+
cos_sims_dict,
|
753 |
+
f"{output_path_prefix}_cell_embs_dict_{self.tokens_to_perturb}",
|
754 |
+
)
|
755 |
+
|
756 |
+
if self.emb_mode == "cell_and_gene":
|
757 |
+
pu.write_perturbation_dictionary(
|
758 |
+
stored_gene_embs_dict,
|
759 |
+
f"{output_path_prefix}_gene_embs_dict_{self.tokens_to_perturb}",
|
760 |
+
)
|
761 |
+
|
762 |
+
def isp_perturb_set_special(
|
763 |
+
self,
|
764 |
+
model,
|
765 |
+
filtered_input_data: Dataset,
|
766 |
+
layer_to_quant: int,
|
767 |
+
output_path_prefix: str,
|
768 |
+
):
|
769 |
+
def make_group_perturbation_batch(example):
|
770 |
+
example_input_ids = example["input_ids"]
|
771 |
+
example["tokens_to_perturb"] = self.tokens_to_perturb
|
772 |
+
indices_to_perturb = [
|
773 |
+
example_input_ids.index(token) if token in example_input_ids else None
|
774 |
+
for token in self.tokens_to_perturb
|
775 |
+
]
|
776 |
+
indices_to_perturb = [
|
777 |
+
item for item in indices_to_perturb if item is not None
|
778 |
+
]
|
779 |
+
if len(indices_to_perturb) > 0:
|
780 |
+
example["perturb_index"] = indices_to_perturb
|
781 |
+
else:
|
782 |
+
# -100 indicates tokens to overexpress are not present in rank value encoding
|
783 |
+
example["perturb_index"] = [-100]
|
784 |
+
if self.perturb_type == "delete":
|
785 |
+
example = pu.delete_indices(example)
|
786 |
+
elif self.perturb_type == "overexpress":
|
787 |
+
example = pu.overexpress_tokens(
|
788 |
+
example, self.max_len, self.special_token
|
789 |
+
)
|
790 |
+
example["n_overflow"] = pu.calc_n_overflow(
|
791 |
+
self.max_len,
|
792 |
+
example["length"],
|
793 |
+
self.tokens_to_perturb,
|
794 |
+
indices_to_perturb,
|
795 |
+
)
|
796 |
+
return example
|
797 |
+
|
798 |
+
total_batch_length = len(filtered_input_data)
|
799 |
+
|
800 |
+
|
801 |
+
if self.cell_states_to_model is None:
|
802 |
+
cos_sims_dict = defaultdict(list)
|
803 |
+
else:
|
804 |
+
cos_sims_dict = {
|
805 |
+
state: defaultdict(list)
|
806 |
+
for state in pu.get_possible_states(self.cell_states_to_model)
|
807 |
+
}
|
808 |
+
|
809 |
+
perturbed_data = filtered_input_data.map(
|
810 |
+
make_group_perturbation_batch, num_proc=self.nproc
|
811 |
+
)
|
812 |
+
|
813 |
+
if self.perturb_type == "overexpress":
|
814 |
+
filtered_input_data = filtered_input_data.add_column(
|
815 |
+
"n_overflow", perturbed_data["n_overflow"]
|
816 |
+
)
|
817 |
+
filtered_input_data = filtered_input_data.map(
|
818 |
+
pu.truncate_by_n_overflow_special, num_proc=self.nproc
|
819 |
+
)
|
820 |
+
|
821 |
+
if self.emb_mode == "cls_and_gene":
|
822 |
+
stored_gene_embs_dict = defaultdict(list)
|
823 |
+
|
824 |
+
# iterate through batches
|
825 |
+
for i in trange(0, total_batch_length, self.forward_batch_size):
|
826 |
+
max_range = min(i + self.forward_batch_size, total_batch_length)
|
827 |
+
inds_select = [i for i in range(i, max_range)]
|
828 |
+
|
829 |
+
minibatch = filtered_input_data.select(inds_select)
|
830 |
+
perturbation_batch = perturbed_data.select(inds_select)
|
831 |
+
|
832 |
+
##### CLS Embedding Mode #####
|
833 |
+
if self.emb_mode == "cls":
|
834 |
+
indices_to_perturb = perturbation_batch["perturb_index"]
|
835 |
+
|
836 |
+
original_cls_emb = get_embs(
|
837 |
+
model,
|
838 |
+
minibatch,
|
839 |
+
"cls",
|
840 |
+
layer_to_quant,
|
841 |
+
self.pad_token_id,
|
842 |
+
self.forward_batch_size,
|
843 |
+
token_gene_dict=self.token_gene_dict,
|
844 |
+
summary_stat=None,
|
845 |
+
silent=True,
|
846 |
+
)
|
847 |
+
|
848 |
+
perturbation_cls_emb = get_embs(
|
849 |
+
model,
|
850 |
+
perturbation_batch,
|
851 |
+
"cls",
|
852 |
+
layer_to_quant,
|
853 |
+
self.pad_token_id,
|
854 |
+
self.forward_batch_size,
|
855 |
+
token_gene_dict=self.token_gene_dict,
|
856 |
+
summary_stat=None,
|
857 |
+
silent=True,
|
858 |
+
)
|
859 |
+
|
860 |
+
# Calculate the cosine similarities
|
861 |
+
cls_cos_sims = pu.quant_cos_sims(
|
862 |
+
perturbation_cls_emb,
|
863 |
+
original_cls_emb,
|
864 |
+
self.cell_states_to_model,
|
865 |
+
self.state_embs_dict,
|
866 |
+
emb_mode="cell",
|
867 |
+
)
|
868 |
+
|
869 |
+
# Update perturbation dictionary
|
870 |
+
if self.cell_states_to_model is None:
|
871 |
+
cos_sims_dict = self.update_perturbation_dictionary(
|
872 |
+
cos_sims_dict,
|
873 |
+
cls_cos_sims,
|
874 |
+
gene_list=None,
|
875 |
+
)
|
876 |
+
else:
|
877 |
+
for state in cos_sims_dict.keys():
|
878 |
+
cos_sims_dict[state] = self.update_perturbation_dictionary(
|
879 |
+
cos_sims_dict[state],
|
880 |
+
cls_cos_sims[state],
|
881 |
+
gene_list=None,
|
882 |
+
)
|
883 |
+
|
884 |
+
##### CLS and Gene Embedding Mode #####
|
885 |
+
elif self.emb_mode == "cls_and_gene":
|
886 |
+
full_original_emb = get_embs(
|
887 |
+
model,
|
888 |
+
minibatch,
|
889 |
+
"gene",
|
890 |
+
layer_to_quant,
|
891 |
+
self.pad_token_id,
|
892 |
+
self.forward_batch_size,
|
893 |
+
self.token_gene_dict,
|
894 |
+
summary_stat=None,
|
895 |
+
silent=True,
|
896 |
+
)
|
897 |
+
indices_to_perturb = perturbation_batch["perturb_index"]
|
898 |
+
|
899 |
+
# remove indices that were perturbed
|
900 |
+
original_emb = pu.remove_perturbed_indices_set(
|
901 |
+
full_original_emb,
|
902 |
+
self.perturb_type,
|
903 |
+
indices_to_perturb,
|
904 |
+
self.tokens_to_perturb,
|
905 |
+
minibatch["length"],
|
906 |
+
)
|
907 |
+
|
908 |
+
full_perturbation_emb = get_embs(
|
909 |
+
model,
|
910 |
+
perturbation_batch,
|
911 |
+
"gene",
|
912 |
+
layer_to_quant,
|
913 |
+
self.pad_token_id,
|
914 |
+
self.forward_batch_size,
|
915 |
+
self.token_gene_dict,
|
916 |
+
summary_stat=None,
|
917 |
+
silent=True,
|
918 |
+
)
|
919 |
+
|
920 |
+
# remove special tokens and padding
|
921 |
+
original_emb = original_emb[:, 1:-1, :]
|
922 |
+
if self.perturb_type == "overexpress":
|
923 |
+
perturbation_emb = full_perturbation_emb[
|
924 |
+
:, 1 + len(self.tokens_to_perturb) : -1, :
|
925 |
+
]
|
926 |
+
elif self.perturb_type == "delete":
|
927 |
+
perturbation_emb = full_perturbation_emb[
|
928 |
+
:, 1 : max(perturbation_batch["length"]) - 1, :
|
929 |
+
]
|
930 |
+
|
931 |
+
n_perturbation_genes = perturbation_emb.size()[1]
|
932 |
+
|
933 |
+
# truncate the original embedding as necessary
|
934 |
+
if self.perturb_type == "overexpress":
|
935 |
+
def calc_perturbation_length(ids):
|
936 |
+
if ids == [-100]:
|
937 |
+
return 0
|
938 |
+
else:
|
939 |
+
return len(ids)
|
940 |
+
|
941 |
+
max_tensor_size = max([length - calc_perturbation_length(ids) - 2 for length, ids in zip(minibatch["length"], indices_to_perturb)])
|
942 |
+
|
943 |
+
max_n_overflow = max(minibatch["n_overflow"])
|
944 |
+
if max_n_overflow > 0 and perturbation_emb.size()[1] < original_emb.size()[1]:
|
945 |
+
original_emb = original_emb[:, 0 : perturbation_emb.size()[1], :]
|
946 |
+
elif perturbation_emb.size()[1] < original_emb.size()[1]:
|
947 |
+
original_emb = original_emb[:, 0:max_tensor_size, :]
|
948 |
+
|
949 |
+
gene_cos_sims = pu.quant_cos_sims(
|
950 |
+
perturbation_emb,
|
951 |
+
original_emb,
|
952 |
+
self.cell_states_to_model,
|
953 |
+
self.state_embs_dict,
|
954 |
+
emb_mode="gene",
|
955 |
+
)
|
956 |
+
|
957 |
+
# get cls emb
|
958 |
+
original_cls_emb = full_original_emb[:, 0, :]
|
959 |
+
perturbation_cls_emb = full_perturbation_emb[:, 0, :]
|
960 |
+
|
961 |
+
cls_cos_sims = pu.quant_cos_sims(
|
962 |
+
perturbation_cls_emb,
|
963 |
+
original_cls_emb,
|
964 |
+
self.cell_states_to_model,
|
965 |
+
self.state_embs_dict,
|
966 |
+
emb_mode="cell",
|
967 |
+
)
|
968 |
+
|
969 |
+
# get cosine similarities in gene embeddings
|
970 |
+
# since getting gene embeddings, need gene names
|
971 |
+
|
972 |
+
gene_list = minibatch["input_ids"]
|
973 |
+
# need to truncate gene_list
|
974 |
+
genes_to_exclude = self.tokens_to_perturb + [
|
975 |
+
self.cls_token_id,
|
976 |
+
self.eos_token_id,
|
977 |
+
]
|
978 |
+
gene_list = [
|
979 |
+
[g for g in genes if g not in genes_to_exclude][
|
980 |
+
:n_perturbation_genes
|
981 |
+
]
|
982 |
+
for genes in gene_list
|
983 |
+
]
|
984 |
+
|
985 |
+
for cell_i, genes in enumerate(gene_list):
|
986 |
+
for gene_j, affected_gene in enumerate(genes):
|
987 |
+
if len(self.genes_to_perturb) > 1:
|
988 |
+
tokens_to_perturb = tuple(self.tokens_to_perturb)
|
989 |
+
else:
|
990 |
+
tokens_to_perturb = self.tokens_to_perturb[0]
|
991 |
+
|
992 |
+
# fill in the gene cosine similarities
|
993 |
+
try:
|
994 |
+
stored_gene_embs_dict[
|
995 |
+
(tokens_to_perturb, affected_gene)
|
996 |
+
].append(gene_cos_sims[cell_i, gene_j].item())
|
997 |
+
except KeyError:
|
998 |
+
stored_gene_embs_dict[
|
999 |
+
(tokens_to_perturb, affected_gene)
|
1000 |
+
] = gene_cos_sims[cell_i, gene_j].item()
|
1001 |
+
|
1002 |
+
if self.cell_states_to_model is None:
|
1003 |
+
cos_sims_dict = self.update_perturbation_dictionary(
|
1004 |
+
cos_sims_dict,
|
1005 |
+
cls_cos_sims,
|
1006 |
+
gene_list=None,
|
1007 |
+
)
|
1008 |
+
else:
|
1009 |
+
for state in cos_sims_dict.keys():
|
1010 |
+
cos_sims_dict[state] = self.update_perturbation_dictionary(
|
1011 |
+
cos_sims_dict[state],
|
1012 |
+
cls_cos_sims[state],
|
1013 |
+
gene_list=None,
|
1014 |
+
)
|
1015 |
+
del full_original_emb
|
1016 |
+
del original_emb
|
1017 |
+
del full_perturbation_emb
|
1018 |
+
del perturbation_emb
|
1019 |
+
del gene_cos_sims
|
1020 |
+
|
1021 |
+
del original_cls_emb
|
1022 |
+
del perturbation_cls_emb
|
1023 |
+
del cls_cos_sims
|
1024 |
+
del minibatch
|
1025 |
+
del perturbation_batch
|
1026 |
+
|
1027 |
+
torch.cuda.empty_cache()
|
1028 |
+
|
1029 |
+
pu.write_perturbation_dictionary(
|
1030 |
+
cos_sims_dict,
|
1031 |
+
f"{output_path_prefix}_cell_embs_dict_{self.tokens_to_perturb}",
|
1032 |
+
)
|
1033 |
+
|
1034 |
+
if self.emb_mode == "cls_and_gene":
|
1035 |
+
pu.write_perturbation_dictionary(
|
1036 |
+
stored_gene_embs_dict,
|
1037 |
+
f"{output_path_prefix}_gene_embs_dict_{self.tokens_to_perturb}",
|
1038 |
+
)
|
1039 |
+
|
1040 |
+
def isp_perturb_all(
|
1041 |
+
self,
|
1042 |
+
model,
|
1043 |
+
filtered_input_data: Dataset,
|
1044 |
+
layer_to_quant: int,
|
1045 |
+
output_path_prefix: str,
|
1046 |
+
):
|
1047 |
+
pickle_batch = -1
|
1048 |
+
if self.cell_states_to_model is None:
|
1049 |
+
cos_sims_dict = defaultdict(list)
|
1050 |
+
else:
|
1051 |
+
cos_sims_dict = {
|
1052 |
+
state: defaultdict(list)
|
1053 |
+
for state in pu.get_possible_states(self.cell_states_to_model)
|
1054 |
+
}
|
1055 |
+
|
1056 |
+
if self.emb_mode == "cell_and_gene":
|
1057 |
+
stored_gene_embs_dict = defaultdict(list)
|
1058 |
+
|
1059 |
+
num_inds_perturbed = 1 + self.combos
|
1060 |
+
for h in trange(len(filtered_input_data)):
|
1061 |
+
example_cell = filtered_input_data.select([h])
|
1062 |
+
full_original_emb = get_embs(
|
1063 |
+
model,
|
1064 |
+
example_cell,
|
1065 |
+
"gene",
|
1066 |
+
layer_to_quant,
|
1067 |
+
self.pad_token_id,
|
1068 |
+
self.forward_batch_size,
|
1069 |
+
self.token_gene_dict,
|
1070 |
+
summary_stat=None,
|
1071 |
+
silent=True,
|
1072 |
+
)
|
1073 |
+
|
1074 |
+
if self.cell_states_to_model is not None:
|
1075 |
+
original_cell_emb = pu.compute_nonpadded_cell_embedding(
|
1076 |
+
full_original_emb, "mean_pool"
|
1077 |
+
)
|
1078 |
+
|
1079 |
+
# gene_list is used to assign cos sims back to genes
|
1080 |
+
gene_list = example_cell["input_ids"][0][:]
|
1081 |
+
# need to remove the anchor gene
|
1082 |
+
if self.anchor_token is not None:
|
1083 |
+
for token in self.anchor_token:
|
1084 |
+
gene_list.remove(token)
|
1085 |
+
# index 0 is not overexpressed so remove
|
1086 |
+
if self.perturb_type == "overexpress":
|
1087 |
+
gene_list = gene_list[num_inds_perturbed:]
|
1088 |
+
# remove perturbed index for gene list dict
|
1089 |
+
perturbed_gene_dict = {
|
1090 |
+
gene: gene_list[:i] + gene_list[i + 1 :]
|
1091 |
+
for i, gene in enumerate(gene_list)
|
1092 |
+
}
|
1093 |
+
|
1094 |
+
perturbation_batch, indices_to_perturb = pu.make_perturbation_batch(
|
1095 |
+
example_cell,
|
1096 |
+
self.perturb_type,
|
1097 |
+
self.tokens_to_perturb,
|
1098 |
+
self.anchor_token,
|
1099 |
+
self.combos,
|
1100 |
+
self.nproc,
|
1101 |
+
)
|
1102 |
+
|
1103 |
+
ispall_total_batch_length = len(perturbation_batch)
|
1104 |
+
for i in trange(
|
1105 |
+
0, ispall_total_batch_length, self.forward_batch_size, leave=False
|
1106 |
+
):
|
1107 |
+
ispall_max_range = min(
|
1108 |
+
i + self.forward_batch_size, ispall_total_batch_length
|
1109 |
+
)
|
1110 |
+
perturbation_minibatch = perturbation_batch.select(
|
1111 |
+
[i for i in range(i, ispall_max_range)]
|
1112 |
+
)
|
1113 |
+
indices_to_perturb_mini = indices_to_perturb[i:ispall_max_range]
|
1114 |
+
gene_list_mini = gene_list[
|
1115 |
+
i:ispall_max_range
|
1116 |
+
] # only perturbed genes from this minibatch
|
1117 |
+
|
1118 |
+
full_perturbation_emb = get_embs(
|
1119 |
+
model,
|
1120 |
+
perturbation_minibatch,
|
1121 |
+
"gene",
|
1122 |
+
layer_to_quant,
|
1123 |
+
self.pad_token_id,
|
1124 |
+
self.forward_batch_size,
|
1125 |
+
self.token_gene_dict,
|
1126 |
+
summary_stat=None,
|
1127 |
+
silent=True,
|
1128 |
+
)
|
1129 |
+
|
1130 |
+
del perturbation_minibatch
|
1131 |
+
|
1132 |
+
# need to remove overexpressed gene to quantify cosine shifts
|
1133 |
+
if self.perturb_type == "overexpress":
|
1134 |
+
perturbation_emb = full_perturbation_emb[:, num_inds_perturbed:, :]
|
1135 |
+
|
1136 |
+
elif self.perturb_type == "delete":
|
1137 |
+
perturbation_emb = full_perturbation_emb
|
1138 |
+
|
1139 |
+
if (
|
1140 |
+
self.cell_states_to_model is None
|
1141 |
+
or self.emb_mode == "cell_and_gene"
|
1142 |
+
):
|
1143 |
+
original_emb_minibatch = pu.make_comparison_batch(
|
1144 |
+
full_original_emb, indices_to_perturb_mini, perturb_group=False
|
1145 |
+
)
|
1146 |
+
gene_cos_sims = pu.quant_cos_sims(
|
1147 |
+
perturbation_emb,
|
1148 |
+
original_emb_minibatch,
|
1149 |
+
self.cell_states_to_model,
|
1150 |
+
self.state_embs_dict,
|
1151 |
+
emb_mode="gene",
|
1152 |
+
)
|
1153 |
+
del original_emb_minibatch
|
1154 |
+
|
1155 |
+
if self.cell_states_to_model is not None:
|
1156 |
+
perturbation_cell_emb = pu.compute_nonpadded_cell_embedding(
|
1157 |
+
full_perturbation_emb, "mean_pool"
|
1158 |
+
)
|
1159 |
+
|
1160 |
+
cell_cos_sims = pu.quant_cos_sims(
|
1161 |
+
perturbation_cell_emb,
|
1162 |
+
original_cell_emb,
|
1163 |
+
self.cell_states_to_model,
|
1164 |
+
self.state_embs_dict,
|
1165 |
+
emb_mode="cell",
|
1166 |
+
)
|
1167 |
+
del perturbation_cell_emb
|
1168 |
+
|
1169 |
+
if self.emb_mode == "cell_and_gene":
|
1170 |
+
for perturbation_i, perturbed_gene in enumerate(gene_list_mini):
|
1171 |
+
for gene_j, affected_gene in enumerate(
|
1172 |
+
perturbed_gene_dict[perturbed_gene]
|
1173 |
+
):
|
1174 |
+
try:
|
1175 |
+
stored_gene_embs_dict[
|
1176 |
+
(perturbed_gene, affected_gene)
|
1177 |
+
].append(gene_cos_sims[perturbation_i, gene_j].item())
|
1178 |
+
except KeyError:
|
1179 |
+
stored_gene_embs_dict[
|
1180 |
+
(perturbed_gene, affected_gene)
|
1181 |
+
] = gene_cos_sims[perturbation_i, gene_j].item()
|
1182 |
+
|
1183 |
+
del full_perturbation_emb
|
1184 |
+
|
1185 |
+
if self.cell_states_to_model is None:
|
1186 |
+
cos_sims_data = torch.mean(gene_cos_sims, dim=1)
|
1187 |
+
cos_sims_dict = self.update_perturbation_dictionary(
|
1188 |
+
cos_sims_dict,
|
1189 |
+
cos_sims_data,
|
1190 |
+
gene_list_mini,
|
1191 |
+
)
|
1192 |
+
else:
|
1193 |
+
cos_sims_data = cell_cos_sims
|
1194 |
+
for state in cos_sims_dict.keys():
|
1195 |
+
cos_sims_dict[state] = self.update_perturbation_dictionary(
|
1196 |
+
cos_sims_dict[state],
|
1197 |
+
cos_sims_data[state],
|
1198 |
+
gene_list_mini,
|
1199 |
+
)
|
1200 |
+
|
1201 |
+
# save dict to disk every self.clear_mem_ncells/10 (default 100) simulated cells
|
1202 |
+
if i % self.clear_mem_ncells / 10 == 0:
|
1203 |
+
pu.write_perturbation_dictionary(
|
1204 |
+
cos_sims_dict,
|
1205 |
+
f"{output_path_prefix}_dict_cell_embs_{h}batch{pickle_batch}",
|
1206 |
+
)
|
1207 |
+
if self.emb_mode == "cell_and_gene":
|
1208 |
+
pu.write_perturbation_dictionary(
|
1209 |
+
stored_gene_embs_dict,
|
1210 |
+
f"{output_path_prefix}_dict_gene_embs_{h}batch{pickle_batch}",
|
1211 |
+
)
|
1212 |
+
|
1213 |
+
# reset and clear memory every self.clear_mem_ncells (default 1000) simulated cells or at the end of the example cell
|
1214 |
+
if i % self.clear_mem_ncells == 0:
|
1215 |
+
pickle_batch += 1
|
1216 |
+
if self.cell_states_to_model is None:
|
1217 |
+
cos_sims_dict = defaultdict(list)
|
1218 |
+
else:
|
1219 |
+
cos_sims_dict = {
|
1220 |
+
state: defaultdict(list)
|
1221 |
+
for state in pu.get_possible_states(
|
1222 |
+
self.cell_states_to_model
|
1223 |
+
)
|
1224 |
+
}
|
1225 |
+
|
1226 |
+
if self.emb_mode == "cell_and_gene":
|
1227 |
+
stored_gene_embs_dict = defaultdict(list)
|
1228 |
+
|
1229 |
+
torch.cuda.empty_cache()
|
1230 |
+
|
1231 |
+
pu.write_perturbation_dictionary(
|
1232 |
+
cos_sims_dict,
|
1233 |
+
f"{output_path_prefix}_dict_cell_embs_{h}batch{pickle_batch}",
|
1234 |
+
)
|
1235 |
+
|
1236 |
+
if self.emb_mode == "cell_and_gene":
|
1237 |
+
pu.write_perturbation_dictionary(
|
1238 |
+
stored_gene_embs_dict,
|
1239 |
+
f"{output_path_prefix}_dict_gene_embs_{h}batch{pickle_batch}",
|
1240 |
+
)
|
1241 |
+
|
1242 |
+
pickle_batch = -1
|
1243 |
+
if self.cell_states_to_model is None:
|
1244 |
+
cos_sims_dict = defaultdict(list)
|
1245 |
+
else:
|
1246 |
+
cos_sims_dict = {
|
1247 |
+
state: defaultdict(list)
|
1248 |
+
for state in pu.get_possible_states(self.cell_states_to_model)
|
1249 |
+
}
|
1250 |
+
|
1251 |
+
if self.emb_mode == "cell_and_gene":
|
1252 |
+
stored_gene_embs_dict = defaultdict(list)
|
1253 |
+
|
1254 |
+
# clear memory between cells
|
1255 |
+
del perturbation_batch
|
1256 |
+
del full_original_emb
|
1257 |
+
if self.cell_states_to_model is not None:
|
1258 |
+
del original_cell_emb
|
1259 |
+
torch.cuda.empty_cache()
|
1260 |
+
|
1261 |
+
def isp_perturb_all_special(
|
1262 |
+
self,
|
1263 |
+
model,
|
1264 |
+
filtered_input_data: Dataset,
|
1265 |
+
layer_to_quant: int,
|
1266 |
+
output_path_prefix: str,
|
1267 |
+
):
|
1268 |
+
pickle_batch = -1
|
1269 |
+
if self.cell_states_to_model is None:
|
1270 |
+
cos_sims_dict = defaultdict(list)
|
1271 |
+
else:
|
1272 |
+
cos_sims_dict = {
|
1273 |
+
state: defaultdict(list)
|
1274 |
+
for state in pu.get_possible_states(self.cell_states_to_model)
|
1275 |
+
}
|
1276 |
+
|
1277 |
+
if self.emb_mode == "cls_and_gene":
|
1278 |
+
stored_gene_embs_dict = defaultdict(list)
|
1279 |
+
|
1280 |
+
num_inds_perturbed = 1 + self.combos
|
1281 |
+
for h in trange(len(filtered_input_data)):
|
1282 |
+
example_cell = filtered_input_data.select([h])
|
1283 |
+
|
1284 |
+
# get original example cell cls and/or gene embs for comparison
|
1285 |
+
if self.emb_mode == "cls":
|
1286 |
+
original_cls_emb = get_embs(
|
1287 |
+
model,
|
1288 |
+
example_cell,
|
1289 |
+
"cls",
|
1290 |
+
layer_to_quant,
|
1291 |
+
self.pad_token_id,
|
1292 |
+
self.forward_batch_size,
|
1293 |
+
self.token_gene_dict,
|
1294 |
+
summary_stat=None,
|
1295 |
+
silent=True,
|
1296 |
+
)
|
1297 |
+
elif self.emb_mode == "cls_and_gene":
|
1298 |
+
full_original_emb = get_embs(
|
1299 |
+
model,
|
1300 |
+
example_cell,
|
1301 |
+
"gene",
|
1302 |
+
layer_to_quant,
|
1303 |
+
self.pad_token_id,
|
1304 |
+
self.forward_batch_size,
|
1305 |
+
self.token_gene_dict,
|
1306 |
+
summary_stat=None,
|
1307 |
+
silent=True,
|
1308 |
+
)
|
1309 |
+
original_cls_emb = full_original_emb[:, 0, :].clone().detach()
|
1310 |
+
|
1311 |
+
# gene_list is used to assign cos sims back to genes
|
1312 |
+
gene_list = example_cell["input_ids"][0][:]
|
1313 |
+
|
1314 |
+
# need to remove special tokens
|
1315 |
+
for token in [self.cls_token_id, self.eos_token_id]:
|
1316 |
+
gene_list.remove(token)
|
1317 |
+
# need to remove the anchor gene
|
1318 |
+
if self.anchor_token is not None:
|
1319 |
+
for token in self.anchor_token:
|
1320 |
+
gene_list.remove(token)
|
1321 |
+
# index 0 is not overexpressed so remove
|
1322 |
+
if self.perturb_type == "overexpress":
|
1323 |
+
gene_list = gene_list[num_inds_perturbed:]
|
1324 |
+
# remove perturbed index for gene list dict
|
1325 |
+
perturbed_gene_dict = {
|
1326 |
+
gene: gene_list[:i] + gene_list[i + 1 :]
|
1327 |
+
for i, gene in enumerate(gene_list)
|
1328 |
+
}
|
1329 |
+
|
1330 |
+
perturbation_batch, indices_to_perturb = pu.make_perturbation_batch_special(
|
1331 |
+
example_cell,
|
1332 |
+
self.perturb_type,
|
1333 |
+
self.tokens_to_perturb,
|
1334 |
+
self.anchor_token,
|
1335 |
+
self.combos,
|
1336 |
+
self.nproc,
|
1337 |
+
)
|
1338 |
+
|
1339 |
+
ispall_total_batch_length = len(perturbation_batch)
|
1340 |
+
for i in trange(
|
1341 |
+
0, ispall_total_batch_length, self.forward_batch_size, leave=False
|
1342 |
+
):
|
1343 |
+
ispall_max_range = min(
|
1344 |
+
i + self.forward_batch_size, ispall_total_batch_length
|
1345 |
+
)
|
1346 |
+
perturbation_minibatch = perturbation_batch.select(
|
1347 |
+
[i for i in range(i, ispall_max_range)]
|
1348 |
+
)
|
1349 |
+
indices_to_perturb_mini = indices_to_perturb[i:ispall_max_range]
|
1350 |
+
gene_list_mini = gene_list[
|
1351 |
+
i:ispall_max_range
|
1352 |
+
] # only perturbed genes from this minibatch
|
1353 |
+
|
1354 |
+
##### CLS Embedding Mode #####
|
1355 |
+
if self.emb_mode == "cls":
|
1356 |
+
# Extract cls embeddings from perturbed cells
|
1357 |
+
perturbation_cls_emb = get_embs(
|
1358 |
+
model,
|
1359 |
+
perturbation_minibatch,
|
1360 |
+
"cls",
|
1361 |
+
layer_to_quant,
|
1362 |
+
self.pad_token_id,
|
1363 |
+
self.forward_batch_size,
|
1364 |
+
self.token_gene_dict,
|
1365 |
+
summary_stat=None,
|
1366 |
+
silent=True,
|
1367 |
+
)
|
1368 |
+
|
1369 |
+
# Calculate cosine similarities
|
1370 |
+
cls_cos_sims = pu.quant_cos_sims(
|
1371 |
+
perturbation_cls_emb,
|
1372 |
+
original_cls_emb,
|
1373 |
+
self.cell_states_to_model,
|
1374 |
+
self.state_embs_dict,
|
1375 |
+
emb_mode="cell",
|
1376 |
+
)
|
1377 |
+
|
1378 |
+
if self.cell_states_to_model is None:
|
1379 |
+
cos_sims_dict = self.update_perturbation_dictionary(
|
1380 |
+
cos_sims_dict,
|
1381 |
+
cls_cos_sims,
|
1382 |
+
gene_list_mini,
|
1383 |
+
)
|
1384 |
+
else:
|
1385 |
+
for state in cos_sims_dict.keys():
|
1386 |
+
cos_sims_dict[state] = self.update_perturbation_dictionary(
|
1387 |
+
cos_sims_dict[state],
|
1388 |
+
cls_cos_sims[state],
|
1389 |
+
gene_list_mini,
|
1390 |
+
)
|
1391 |
+
|
1392 |
+
del perturbation_minibatch
|
1393 |
+
del perturbation_cls_emb
|
1394 |
+
del cls_cos_sims
|
1395 |
+
|
1396 |
+
##### CLS and Gene Embedding Mode #####
|
1397 |
+
elif self.emb_mode == "cls_and_gene":
|
1398 |
+
full_perturbation_emb = get_embs(
|
1399 |
+
model,
|
1400 |
+
perturbation_minibatch,
|
1401 |
+
"gene",
|
1402 |
+
layer_to_quant,
|
1403 |
+
self.pad_token_id,
|
1404 |
+
self.forward_batch_size,
|
1405 |
+
self.token_gene_dict,
|
1406 |
+
summary_stat=None,
|
1407 |
+
silent=True,
|
1408 |
+
)
|
1409 |
+
|
1410 |
+
# need to remove overexpressed gene and cls/eos to quantify cosine shifts
|
1411 |
+
if self.perturb_type == "overexpress":
|
1412 |
+
perturbation_emb = (
|
1413 |
+
full_perturbation_emb[:, 1 + num_inds_perturbed : -1, :]
|
1414 |
+
.clone()
|
1415 |
+
.detach()
|
1416 |
+
)
|
1417 |
+
elif self.perturb_type == "delete":
|
1418 |
+
perturbation_emb = (
|
1419 |
+
full_perturbation_emb[:, 1:-1, :].clone().detach()
|
1420 |
+
)
|
1421 |
+
|
1422 |
+
original_emb_minibatch = pu.make_comparison_batch(
|
1423 |
+
full_original_emb, indices_to_perturb_mini, perturb_group=False
|
1424 |
+
)
|
1425 |
+
|
1426 |
+
original_emb_minibatch = (
|
1427 |
+
original_emb_minibatch[:, 1:-1, :].clone().detach()
|
1428 |
+
)
|
1429 |
+
gene_cos_sims = pu.quant_cos_sims(
|
1430 |
+
perturbation_emb,
|
1431 |
+
original_emb_minibatch,
|
1432 |
+
self.cell_states_to_model,
|
1433 |
+
self.state_embs_dict,
|
1434 |
+
emb_mode="gene",
|
1435 |
+
)
|
1436 |
+
|
1437 |
+
for perturbation_i, perturbed_gene in enumerate(gene_list_mini):
|
1438 |
+
for gene_j, affected_gene in enumerate(
|
1439 |
+
perturbed_gene_dict[perturbed_gene]
|
1440 |
+
):
|
1441 |
+
try:
|
1442 |
+
stored_gene_embs_dict[
|
1443 |
+
(perturbed_gene, affected_gene)
|
1444 |
+
].append(gene_cos_sims[perturbation_i, gene_j].item())
|
1445 |
+
except KeyError:
|
1446 |
+
stored_gene_embs_dict[
|
1447 |
+
(perturbed_gene, affected_gene)
|
1448 |
+
] = gene_cos_sims[perturbation_i, gene_j].item()
|
1449 |
+
|
1450 |
+
# get cls emb
|
1451 |
+
perturbation_cls_emb = (
|
1452 |
+
full_perturbation_emb[:, 0, :].clone().detach()
|
1453 |
+
)
|
1454 |
+
|
1455 |
+
cls_cos_sims = pu.quant_cos_sims(
|
1456 |
+
perturbation_cls_emb,
|
1457 |
+
original_cls_emb,
|
1458 |
+
self.cell_states_to_model,
|
1459 |
+
self.state_embs_dict,
|
1460 |
+
emb_mode="cell",
|
1461 |
+
)
|
1462 |
+
|
1463 |
+
if self.cell_states_to_model is None:
|
1464 |
+
cos_sims_dict = self.update_perturbation_dictionary(
|
1465 |
+
cos_sims_dict,
|
1466 |
+
cls_cos_sims,
|
1467 |
+
gene_list_mini,
|
1468 |
+
)
|
1469 |
+
else:
|
1470 |
+
for state in cos_sims_dict.keys():
|
1471 |
+
cos_sims_dict[state] = self.update_perturbation_dictionary(
|
1472 |
+
cos_sims_dict[state],
|
1473 |
+
cls_cos_sims[state],
|
1474 |
+
gene_list_mini,
|
1475 |
+
)
|
1476 |
+
|
1477 |
+
del perturbation_minibatch
|
1478 |
+
del original_emb_minibatch
|
1479 |
+
del full_perturbation_emb
|
1480 |
+
del perturbation_emb
|
1481 |
+
del perturbation_cls_emb
|
1482 |
+
del cls_cos_sims
|
1483 |
+
del gene_cos_sims
|
1484 |
+
|
1485 |
+
# save dict to disk every self.clear_mem_ncells/10 (default 100) simulated cells
|
1486 |
+
if i % max(1, self.clear_mem_ncells / 10) == 0:
|
1487 |
+
pu.write_perturbation_dictionary(
|
1488 |
+
cos_sims_dict,
|
1489 |
+
f"{output_path_prefix}_dict_cell_embs_{h}batch{pickle_batch}",
|
1490 |
+
)
|
1491 |
+
if self.emb_mode == "cls_and_gene":
|
1492 |
+
pu.write_perturbation_dictionary(
|
1493 |
+
stored_gene_embs_dict,
|
1494 |
+
f"{output_path_prefix}_dict_gene_embs_{h}batch{pickle_batch}",
|
1495 |
+
)
|
1496 |
+
|
1497 |
+
# reset and clear memory every self.clear_mem_ncells (default 1000) simulated cells or at the end of the example cell
|
1498 |
+
if i % self.clear_mem_ncells == 0:
|
1499 |
+
pickle_batch += 1
|
1500 |
+
if self.cell_states_to_model is None:
|
1501 |
+
cos_sims_dict = defaultdict(list)
|
1502 |
+
else:
|
1503 |
+
cos_sims_dict = {
|
1504 |
+
state: defaultdict(list)
|
1505 |
+
for state in pu.get_possible_states(
|
1506 |
+
self.cell_states_to_model
|
1507 |
+
)
|
1508 |
+
}
|
1509 |
+
|
1510 |
+
if self.emb_mode == "cls_and_gene":
|
1511 |
+
stored_gene_embs_dict = defaultdict(list)
|
1512 |
+
|
1513 |
+
torch.cuda.empty_cache()
|
1514 |
+
|
1515 |
+
pu.write_perturbation_dictionary(
|
1516 |
+
cos_sims_dict,
|
1517 |
+
f"{output_path_prefix}_dict_cell_embs_{h}batch{pickle_batch}",
|
1518 |
+
)
|
1519 |
+
|
1520 |
+
if self.emb_mode == "cls_and_gene":
|
1521 |
+
pu.write_perturbation_dictionary(
|
1522 |
+
stored_gene_embs_dict,
|
1523 |
+
f"{output_path_prefix}_dict_gene_embs_{h}batch{pickle_batch}",
|
1524 |
+
)
|
1525 |
+
|
1526 |
+
pickle_batch = -1
|
1527 |
+
if self.cell_states_to_model is None:
|
1528 |
+
cos_sims_dict = defaultdict(list)
|
1529 |
+
else:
|
1530 |
+
cos_sims_dict = {
|
1531 |
+
state: defaultdict(list)
|
1532 |
+
for state in pu.get_possible_states(self.cell_states_to_model)
|
1533 |
+
}
|
1534 |
+
|
1535 |
+
if self.emb_mode == "cls_and_gene":
|
1536 |
+
stored_gene_embs_dict = defaultdict(list)
|
1537 |
+
|
1538 |
+
# clear memory between cells
|
1539 |
+
del perturbation_batch
|
1540 |
+
del original_cls_emb
|
1541 |
+
if self.emb_mode == "cls_and_gene":
|
1542 |
+
del full_original_emb
|
1543 |
+
torch.cuda.empty_cache()
|
1544 |
+
|
1545 |
+
def update_perturbation_dictionary(
|
1546 |
+
self,
|
1547 |
+
cos_sims_dict: defaultdict,
|
1548 |
+
cos_sims_data: torch.Tensor,
|
1549 |
+
gene_list=None,
|
1550 |
+
):
|
1551 |
+
if gene_list is not None and cos_sims_data.shape[0] != len(gene_list):
|
1552 |
+
logger.error(
|
1553 |
+
f"len(cos_sims_data.shape[0]) != len(gene_list). \n \
|
1554 |
+
{cos_sims_data.shape[0]=}.\n \
|
1555 |
+
{len(gene_list)=}."
|
1556 |
+
)
|
1557 |
+
raise
|
1558 |
+
|
1559 |
+
if self.perturb_group is True:
|
1560 |
+
if len(self.tokens_to_perturb) > 1:
|
1561 |
+
perturbed_genes = tuple(self.tokens_to_perturb)
|
1562 |
+
else:
|
1563 |
+
perturbed_genes = self.tokens_to_perturb[0]
|
1564 |
+
|
1565 |
+
# if cell embeddings, can just append
|
1566 |
+
# shape will be (batch size, 1)
|
1567 |
+
cos_sims_data = torch.squeeze(cos_sims_data).tolist()
|
1568 |
+
|
1569 |
+
# handle case of single cell left
|
1570 |
+
if not isinstance(cos_sims_data, list):
|
1571 |
+
cos_sims_data = [cos_sims_data]
|
1572 |
+
|
1573 |
+
cos_sims_dict[(perturbed_genes, "cell_emb")] += cos_sims_data
|
1574 |
+
|
1575 |
+
else:
|
1576 |
+
for i, cos in enumerate(cos_sims_data.tolist()):
|
1577 |
+
cos_sims_dict[(gene_list[i], "cell_emb")].append(cos)
|
1578 |
+
|
1579 |
+
return cos_sims_dict
|
geneformer/in_silico_perturber_stats.py
ADDED
@@ -0,0 +1,1104 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
"""
|
2 |
+
Geneformer in silico perturber stats generator.
|
3 |
+
|
4 |
+
**Usage:**
|
5 |
+
|
6 |
+
.. code-block :: python
|
7 |
+
|
8 |
+
>>> from geneformer import InSilicoPerturberStats
|
9 |
+
>>> ispstats = InSilicoPerturberStats(mode="goal_state_shift",
|
10 |
+
... cell_states_to_model={"state_key": "disease",
|
11 |
+
... "start_state": "dcm",
|
12 |
+
... "goal_state": "nf",
|
13 |
+
... "alt_states": ["hcm", "other1", "other2"]})
|
14 |
+
>>> ispstats.get_stats("path/to/input_data",
|
15 |
+
... None,
|
16 |
+
... "path/to/output_directory",
|
17 |
+
... "output_prefix")
|
18 |
+
|
19 |
+
**Description:**
|
20 |
+
|
21 |
+
| Aggregates data or calculates stats for in silico perturbations based on type of statistics specified in InSilicoPerturberStats.
|
22 |
+
| Input data is raw in silico perturbation results in the form of dictionaries outputted by ``in_silico_perturber``.
|
23 |
+
|
24 |
+
"""
|
25 |
+
|
26 |
+
|
27 |
+
import logging
|
28 |
+
import os
|
29 |
+
import pickle
|
30 |
+
import random
|
31 |
+
from pathlib import Path
|
32 |
+
|
33 |
+
import numpy as np
|
34 |
+
import pandas as pd
|
35 |
+
import statsmodels.stats.multitest as smt
|
36 |
+
from scipy.stats import ranksums
|
37 |
+
from sklearn.mixture import GaussianMixture
|
38 |
+
from tqdm.auto import tqdm, trange
|
39 |
+
|
40 |
+
from . import ENSEMBL_DICTIONARY_FILE, TOKEN_DICTIONARY_FILE
|
41 |
+
from .perturber_utils import flatten_list, validate_cell_states_to_model
|
42 |
+
|
43 |
+
logger = logging.getLogger(__name__)
|
44 |
+
|
45 |
+
|
46 |
+
# invert dictionary keys/values
|
47 |
+
def invert_dict(dictionary):
|
48 |
+
return {v: k for k, v in dictionary.items()}
|
49 |
+
|
50 |
+
|
51 |
+
def read_dict(cos_sims_dict, cell_or_gene_emb, anchor_token):
|
52 |
+
if cell_or_gene_emb == "cell":
|
53 |
+
cell_emb_dict = {
|
54 |
+
k: v for k, v in cos_sims_dict.items() if v and "cell_emb" in k
|
55 |
+
}
|
56 |
+
return [cell_emb_dict]
|
57 |
+
elif cell_or_gene_emb == "gene":
|
58 |
+
if anchor_token is None:
|
59 |
+
gene_emb_dict = {k: v for k, v in cos_sims_dict.items() if v}
|
60 |
+
else:
|
61 |
+
gene_emb_dict = {
|
62 |
+
k: v for k, v in cos_sims_dict.items() if v and anchor_token == k[0]
|
63 |
+
}
|
64 |
+
return [gene_emb_dict]
|
65 |
+
|
66 |
+
|
67 |
+
# read raw dictionary files
|
68 |
+
def read_dictionaries(
|
69 |
+
input_data_directory,
|
70 |
+
cell_or_gene_emb,
|
71 |
+
anchor_token,
|
72 |
+
cell_states_to_model,
|
73 |
+
pickle_suffix,
|
74 |
+
):
|
75 |
+
file_found = False
|
76 |
+
file_path_list = []
|
77 |
+
if cell_states_to_model is None:
|
78 |
+
dict_list = []
|
79 |
+
else:
|
80 |
+
validate_cell_states_to_model(cell_states_to_model)
|
81 |
+
cell_states_to_model_valid = {
|
82 |
+
state: value
|
83 |
+
for state, value in cell_states_to_model.items()
|
84 |
+
if state != "state_key"
|
85 |
+
and cell_states_to_model[state] is not None
|
86 |
+
and cell_states_to_model[state] != []
|
87 |
+
}
|
88 |
+
cell_states_list = []
|
89 |
+
# flatten all state values into list
|
90 |
+
for state in cell_states_to_model_valid:
|
91 |
+
value = cell_states_to_model_valid[state]
|
92 |
+
if isinstance(value, list):
|
93 |
+
cell_states_list += value
|
94 |
+
else:
|
95 |
+
cell_states_list.append(value)
|
96 |
+
state_dict = {state_value: dict() for state_value in cell_states_list}
|
97 |
+
for file in os.listdir(input_data_directory):
|
98 |
+
# process only files with given suffix (e.g. "_raw.pickle")
|
99 |
+
if file.endswith(pickle_suffix):
|
100 |
+
file_found = True
|
101 |
+
file_path_list += [f"{input_data_directory}/{file}"]
|
102 |
+
for file_path in tqdm(file_path_list):
|
103 |
+
with open(file_path, "rb") as fp:
|
104 |
+
cos_sims_dict = pickle.load(fp)
|
105 |
+
if cell_states_to_model is None:
|
106 |
+
dict_list += read_dict(cos_sims_dict, cell_or_gene_emb, anchor_token)
|
107 |
+
else:
|
108 |
+
for state_value in cell_states_list:
|
109 |
+
new_dict = read_dict(
|
110 |
+
cos_sims_dict[state_value], cell_or_gene_emb, anchor_token
|
111 |
+
)[0]
|
112 |
+
for key in new_dict:
|
113 |
+
try:
|
114 |
+
state_dict[state_value][key] += new_dict[key]
|
115 |
+
except KeyError:
|
116 |
+
state_dict[state_value][key] = new_dict[key]
|
117 |
+
|
118 |
+
if not file_found:
|
119 |
+
logger.error(
|
120 |
+
"No raw data for processing found within provided directory. "
|
121 |
+
"Please ensure data files end with '{pickle_suffix}'."
|
122 |
+
)
|
123 |
+
raise
|
124 |
+
if cell_states_to_model is None:
|
125 |
+
return dict_list
|
126 |
+
else:
|
127 |
+
return state_dict
|
128 |
+
|
129 |
+
|
130 |
+
# get complete gene list
|
131 |
+
def get_gene_list(dict_list, mode):
|
132 |
+
if mode == "cell":
|
133 |
+
position = 0
|
134 |
+
elif mode == "gene":
|
135 |
+
position = 1
|
136 |
+
gene_set = set()
|
137 |
+
if isinstance(dict_list, list):
|
138 |
+
for dict_i in dict_list:
|
139 |
+
gene_set.update([k[position] for k, v in dict_i.items() if v])
|
140 |
+
elif isinstance(dict_list, dict):
|
141 |
+
for state, dict_i in dict_list.items():
|
142 |
+
gene_set.update([k[position] for k, v in dict_i.items() if v])
|
143 |
+
else:
|
144 |
+
logger.error(
|
145 |
+
"dict_list should be a list, or if modeling shift to goal states, a dict. "
|
146 |
+
f"{type(dict_list)} is not the correct format."
|
147 |
+
)
|
148 |
+
raise
|
149 |
+
gene_list = list(gene_set)
|
150 |
+
if mode == "gene":
|
151 |
+
gene_list.remove("cell_emb")
|
152 |
+
gene_list.sort()
|
153 |
+
return gene_list
|
154 |
+
|
155 |
+
|
156 |
+
def token_tuple_to_ensembl_ids(token_tuple, gene_token_id_dict):
|
157 |
+
try:
|
158 |
+
return tuple([gene_token_id_dict.get(i, np.nan) for i in token_tuple])
|
159 |
+
except TypeError:
|
160 |
+
return gene_token_id_dict.get(token_tuple, np.nan)
|
161 |
+
|
162 |
+
|
163 |
+
def n_detections(token, dict_list, mode, anchor_token):
|
164 |
+
cos_sim_megalist = []
|
165 |
+
for dict_i in dict_list:
|
166 |
+
if mode == "cell":
|
167 |
+
cos_sim_megalist += dict_i.get((token, "cell_emb"), [])
|
168 |
+
elif mode == "gene":
|
169 |
+
cos_sim_megalist += dict_i.get((anchor_token, token), [])
|
170 |
+
return len(cos_sim_megalist)
|
171 |
+
|
172 |
+
|
173 |
+
def get_fdr(pvalues):
|
174 |
+
return list(smt.multipletests(pvalues, alpha=0.05, method="fdr_bh")[1])
|
175 |
+
|
176 |
+
|
177 |
+
def get_impact_component(test_value, gaussian_mixture_model):
|
178 |
+
impact_border = gaussian_mixture_model.means_[0][0]
|
179 |
+
nonimpact_border = gaussian_mixture_model.means_[1][0]
|
180 |
+
if test_value > nonimpact_border:
|
181 |
+
impact_component = 0
|
182 |
+
elif test_value < impact_border:
|
183 |
+
impact_component = 1
|
184 |
+
else:
|
185 |
+
impact_component_raw = gaussian_mixture_model.predict([[test_value]])[0]
|
186 |
+
if impact_component_raw == 1:
|
187 |
+
impact_component = 0
|
188 |
+
elif impact_component_raw == 0:
|
189 |
+
impact_component = 1
|
190 |
+
return impact_component
|
191 |
+
|
192 |
+
|
193 |
+
# aggregate data for single perturbation in multiple cells
|
194 |
+
def isp_aggregate_grouped_perturb(cos_sims_df, dict_list, genes_perturbed):
|
195 |
+
names = ["Cosine_sim", "Gene"]
|
196 |
+
cos_sims_full_dfs = []
|
197 |
+
if isinstance(genes_perturbed, list):
|
198 |
+
if len(genes_perturbed) > 1:
|
199 |
+
gene_ids_df = cos_sims_df.loc[
|
200 |
+
np.isin(
|
201 |
+
[set(idx) for idx in cos_sims_df["Ensembl_ID"]],
|
202 |
+
set(genes_perturbed),
|
203 |
+
),
|
204 |
+
:,
|
205 |
+
]
|
206 |
+
else:
|
207 |
+
gene_ids_df = cos_sims_df.loc[
|
208 |
+
np.isin(cos_sims_df["Ensembl_ID"], genes_perturbed), :
|
209 |
+
]
|
210 |
+
else:
|
211 |
+
logger.error(
|
212 |
+
"aggregate_data is for perturbation of single gene or single group of genes. genes_to_perturb should be formatted as list."
|
213 |
+
)
|
214 |
+
raise
|
215 |
+
|
216 |
+
if gene_ids_df.empty:
|
217 |
+
logger.error("genes_to_perturb not found in data.")
|
218 |
+
raise
|
219 |
+
|
220 |
+
tokens = gene_ids_df["Gene"]
|
221 |
+
symbols = gene_ids_df["Gene_name"]
|
222 |
+
|
223 |
+
for token, symbol in zip(tokens, symbols):
|
224 |
+
cos_shift_data = []
|
225 |
+
for dict_i in dict_list:
|
226 |
+
cos_shift_data += dict_i.get((token, "cell_emb"), [])
|
227 |
+
|
228 |
+
df = pd.DataFrame(columns=names)
|
229 |
+
df["Cosine_sim"] = cos_shift_data
|
230 |
+
df["Gene"] = symbol
|
231 |
+
cos_sims_full_dfs.append(df)
|
232 |
+
|
233 |
+
return pd.concat(cos_sims_full_dfs)
|
234 |
+
|
235 |
+
|
236 |
+
def find(variable, x):
|
237 |
+
try:
|
238 |
+
if x in variable: # Test if variable is iterable and contains x
|
239 |
+
return True
|
240 |
+
elif x == variable:
|
241 |
+
return True
|
242 |
+
except (ValueError, TypeError):
|
243 |
+
return x == variable # Test if variable is x if non-iterable
|
244 |
+
|
245 |
+
|
246 |
+
def isp_aggregate_gene_shifts(
|
247 |
+
cos_sims_df, dict_list, gene_token_id_dict, gene_id_name_dict, token_dtype
|
248 |
+
):
|
249 |
+
cos_shift_data = dict()
|
250 |
+
for i in trange(cos_sims_df.shape[0]):
|
251 |
+
token = cos_sims_df["Gene"][i]
|
252 |
+
for dict_i in dict_list:
|
253 |
+
if token_dtype == "nontuple":
|
254 |
+
affected_pairs = [k for k, v in dict_i.items() if k[0] == token]
|
255 |
+
else:
|
256 |
+
affected_pairs = [k for k, v in dict_i.items() if find(k[0], token)]
|
257 |
+
for key in affected_pairs:
|
258 |
+
if key in cos_shift_data.keys():
|
259 |
+
cos_shift_data[key] += dict_i.get(key, [])
|
260 |
+
else:
|
261 |
+
cos_shift_data[key] = dict_i.get(key, [])
|
262 |
+
|
263 |
+
cos_data_mean = {
|
264 |
+
k: [np.mean(v), np.std(v), len(v)] for k, v in cos_shift_data.items()
|
265 |
+
}
|
266 |
+
cos_sims_full_df = pd.DataFrame()
|
267 |
+
cos_sims_full_df["Perturbed"] = [k[0] for k, v in cos_data_mean.items()]
|
268 |
+
cos_sims_full_df["Gene_name"] = [
|
269 |
+
cos_sims_df[cos_sims_df["Gene"] == k[0]]["Gene_name"].item()
|
270 |
+
for k, v in cos_data_mean.items()
|
271 |
+
]
|
272 |
+
cos_sims_full_df["Ensembl_ID"] = [
|
273 |
+
cos_sims_df[cos_sims_df["Gene"] == k[0]]["Ensembl_ID"].item()
|
274 |
+
for k, v in cos_data_mean.items()
|
275 |
+
]
|
276 |
+
|
277 |
+
cos_sims_full_df["Affected"] = [k[1] for k, v in cos_data_mean.items()]
|
278 |
+
cos_sims_full_df["Affected_gene_name"] = [
|
279 |
+
gene_id_name_dict.get(gene_token_id_dict.get(token, np.nan), np.nan)
|
280 |
+
for token in cos_sims_full_df["Affected"]
|
281 |
+
]
|
282 |
+
cos_sims_full_df["Affected_Ensembl_ID"] = [
|
283 |
+
gene_token_id_dict.get(token, np.nan) for token in cos_sims_full_df["Affected"]
|
284 |
+
]
|
285 |
+
cos_sims_full_df["Cosine_sim_mean"] = [v[0] for k, v in cos_data_mean.items()]
|
286 |
+
cos_sims_full_df["Cosine_sim_stdev"] = [v[1] for k, v in cos_data_mean.items()]
|
287 |
+
cos_sims_full_df["N_Detections"] = [v[2] for k, v in cos_data_mean.items()]
|
288 |
+
|
289 |
+
specific_val = "cell_emb"
|
290 |
+
cos_sims_full_df["temp"] = list(cos_sims_full_df["Affected"] == specific_val)
|
291 |
+
# reorder so cell embs are at the top and all are subordered by magnitude of cosine sim
|
292 |
+
cos_sims_full_df = cos_sims_full_df.sort_values(
|
293 |
+
by=(["temp", "Cosine_sim_mean"]), ascending=[False, True]
|
294 |
+
).drop("temp", axis=1)
|
295 |
+
|
296 |
+
return cos_sims_full_df
|
297 |
+
|
298 |
+
|
299 |
+
# stats comparing cos sim shifts towards goal state of test perturbations vs random perturbations
|
300 |
+
def isp_stats_to_goal_state(
|
301 |
+
cos_sims_df, result_dict, cell_states_to_model, genes_perturbed
|
302 |
+
):
|
303 |
+
if (
|
304 |
+
("alt_states" not in cell_states_to_model.keys())
|
305 |
+
or (len(cell_states_to_model["alt_states"]) == 0)
|
306 |
+
or (cell_states_to_model["alt_states"] == [None])
|
307 |
+
):
|
308 |
+
alt_end_state_exists = False
|
309 |
+
elif (len(cell_states_to_model["alt_states"]) > 0) and (
|
310 |
+
cell_states_to_model["alt_states"] != [None]
|
311 |
+
):
|
312 |
+
alt_end_state_exists = True
|
313 |
+
|
314 |
+
# for single perturbation in multiple cells, there are no random perturbations to compare to
|
315 |
+
if genes_perturbed != "all":
|
316 |
+
cos_sims_full_df = pd.DataFrame()
|
317 |
+
|
318 |
+
cos_shift_data_end = []
|
319 |
+
token = cos_sims_df["Gene"][0]
|
320 |
+
cos_shift_data_end += result_dict[cell_states_to_model["goal_state"]].get(
|
321 |
+
(token, "cell_emb"), []
|
322 |
+
)
|
323 |
+
cos_sims_full_df["Shift_to_goal_end"] = [np.mean(cos_shift_data_end)]
|
324 |
+
if alt_end_state_exists is True:
|
325 |
+
for alt_state in cell_states_to_model["alt_states"]:
|
326 |
+
cos_shift_data_alt_state = []
|
327 |
+
cos_shift_data_alt_state += result_dict.get(alt_state).get(
|
328 |
+
(token, "cell_emb"), []
|
329 |
+
)
|
330 |
+
cos_sims_full_df[f"Shift_to_alt_end_{alt_state}"] = [
|
331 |
+
np.mean(cos_shift_data_alt_state)
|
332 |
+
]
|
333 |
+
|
334 |
+
# sort by shift to desired state
|
335 |
+
cos_sims_full_df = cos_sims_full_df.sort_values(
|
336 |
+
by=["Shift_to_goal_end"], ascending=[False]
|
337 |
+
)
|
338 |
+
return cos_sims_full_df
|
339 |
+
|
340 |
+
elif genes_perturbed == "all":
|
341 |
+
goal_end_random_megalist = []
|
342 |
+
if alt_end_state_exists is True:
|
343 |
+
alt_end_state_random_dict = {
|
344 |
+
alt_state: [] for alt_state in cell_states_to_model["alt_states"]
|
345 |
+
}
|
346 |
+
for i in trange(cos_sims_df.shape[0]):
|
347 |
+
token = cos_sims_df["Gene"][i]
|
348 |
+
goal_end_random_megalist += result_dict[
|
349 |
+
cell_states_to_model["goal_state"]
|
350 |
+
].get((token, "cell_emb"), [])
|
351 |
+
if alt_end_state_exists is True:
|
352 |
+
for alt_state in cell_states_to_model["alt_states"]:
|
353 |
+
alt_end_state_random_dict[alt_state] += result_dict[alt_state].get(
|
354 |
+
(token, "cell_emb"), []
|
355 |
+
)
|
356 |
+
|
357 |
+
# downsample to improve speed of ranksums
|
358 |
+
if len(goal_end_random_megalist) > 100_000:
|
359 |
+
random.seed(42)
|
360 |
+
goal_end_random_megalist = random.sample(
|
361 |
+
goal_end_random_megalist, k=100_000
|
362 |
+
)
|
363 |
+
if alt_end_state_exists is True:
|
364 |
+
for alt_state in cell_states_to_model["alt_states"]:
|
365 |
+
if len(alt_end_state_random_dict[alt_state]) > 100_000:
|
366 |
+
random.seed(42)
|
367 |
+
alt_end_state_random_dict[alt_state] = random.sample(
|
368 |
+
alt_end_state_random_dict[alt_state], k=100_000
|
369 |
+
)
|
370 |
+
|
371 |
+
names = [
|
372 |
+
"Gene",
|
373 |
+
"Gene_name",
|
374 |
+
"Ensembl_ID",
|
375 |
+
"Shift_to_goal_end",
|
376 |
+
"Goal_end_vs_random_pval",
|
377 |
+
]
|
378 |
+
if alt_end_state_exists is True:
|
379 |
+
[
|
380 |
+
names.append(f"Shift_to_alt_end_{alt_state}")
|
381 |
+
for alt_state in cell_states_to_model["alt_states"]
|
382 |
+
]
|
383 |
+
names.append(names.pop(names.index("Goal_end_vs_random_pval")))
|
384 |
+
[
|
385 |
+
names.append(f"Alt_end_vs_random_pval_{alt_state}")
|
386 |
+
for alt_state in cell_states_to_model["alt_states"]
|
387 |
+
]
|
388 |
+
cos_sims_full_df = pd.DataFrame(columns=names)
|
389 |
+
|
390 |
+
n_detections_dict = dict()
|
391 |
+
for i in trange(cos_sims_df.shape[0]):
|
392 |
+
token = cos_sims_df["Gene"][i]
|
393 |
+
name = cos_sims_df["Gene_name"][i]
|
394 |
+
ensembl_id = cos_sims_df["Ensembl_ID"][i]
|
395 |
+
goal_end_cos_sim_megalist = result_dict[
|
396 |
+
cell_states_to_model["goal_state"]
|
397 |
+
].get((token, "cell_emb"), [])
|
398 |
+
n_detections_dict[token] = len(goal_end_cos_sim_megalist)
|
399 |
+
mean_goal_end = np.mean(goal_end_cos_sim_megalist)
|
400 |
+
pval_goal_end = ranksums(
|
401 |
+
goal_end_random_megalist, goal_end_cos_sim_megalist
|
402 |
+
).pvalue
|
403 |
+
|
404 |
+
if alt_end_state_exists is True:
|
405 |
+
alt_end_state_dict = {
|
406 |
+
alt_state: [] for alt_state in cell_states_to_model["alt_states"]
|
407 |
+
}
|
408 |
+
for alt_state in cell_states_to_model["alt_states"]:
|
409 |
+
alt_end_state_dict[alt_state] = result_dict[alt_state].get(
|
410 |
+
(token, "cell_emb"), []
|
411 |
+
)
|
412 |
+
alt_end_state_dict[f"{alt_state}_mean"] = np.mean(
|
413 |
+
alt_end_state_dict[alt_state]
|
414 |
+
)
|
415 |
+
alt_end_state_dict[f"{alt_state}_pval"] = ranksums(
|
416 |
+
alt_end_state_random_dict[alt_state],
|
417 |
+
alt_end_state_dict[alt_state],
|
418 |
+
).pvalue
|
419 |
+
|
420 |
+
results_dict = dict()
|
421 |
+
results_dict["Gene"] = token
|
422 |
+
results_dict["Gene_name"] = name
|
423 |
+
results_dict["Ensembl_ID"] = ensembl_id
|
424 |
+
results_dict["Shift_to_goal_end"] = mean_goal_end
|
425 |
+
results_dict["Goal_end_vs_random_pval"] = pval_goal_end
|
426 |
+
if alt_end_state_exists is True:
|
427 |
+
for alt_state in cell_states_to_model["alt_states"]:
|
428 |
+
results_dict[f"Shift_to_alt_end_{alt_state}"] = alt_end_state_dict[
|
429 |
+
f"{alt_state}_mean"
|
430 |
+
]
|
431 |
+
results_dict[
|
432 |
+
f"Alt_end_vs_random_pval_{alt_state}"
|
433 |
+
] = alt_end_state_dict[f"{alt_state}_pval"]
|
434 |
+
|
435 |
+
cos_sims_df_i = pd.DataFrame(results_dict, index=[i])
|
436 |
+
cos_sims_full_df = pd.concat([cos_sims_full_df, cos_sims_df_i])
|
437 |
+
|
438 |
+
cos_sims_full_df["Goal_end_FDR"] = get_fdr(
|
439 |
+
list(cos_sims_full_df["Goal_end_vs_random_pval"])
|
440 |
+
)
|
441 |
+
if alt_end_state_exists is True:
|
442 |
+
for alt_state in cell_states_to_model["alt_states"]:
|
443 |
+
cos_sims_full_df[f"Alt_end_FDR_{alt_state}"] = get_fdr(
|
444 |
+
list(cos_sims_full_df[f"Alt_end_vs_random_pval_{alt_state}"])
|
445 |
+
)
|
446 |
+
|
447 |
+
# quantify number of detections of each gene
|
448 |
+
cos_sims_full_df["N_Detections"] = [
|
449 |
+
n_detections_dict[token] for token in cos_sims_full_df["Gene"]
|
450 |
+
]
|
451 |
+
|
452 |
+
# sort by shift to desired state
|
453 |
+
cos_sims_full_df["Sig"] = [
|
454 |
+
1 if fdr < 0.05 else 0 for fdr in cos_sims_full_df["Goal_end_FDR"]
|
455 |
+
]
|
456 |
+
cos_sims_full_df = cos_sims_full_df.sort_values(
|
457 |
+
by=["Sig", "Shift_to_goal_end", "Goal_end_FDR"],
|
458 |
+
ascending=[False, False, True],
|
459 |
+
)
|
460 |
+
|
461 |
+
return cos_sims_full_df
|
462 |
+
|
463 |
+
|
464 |
+
# stats comparing cos sim shifts of test perturbations vs null distribution
|
465 |
+
def isp_stats_vs_null(cos_sims_df, dict_list, null_dict_list):
|
466 |
+
cos_sims_full_df = cos_sims_df.copy()
|
467 |
+
|
468 |
+
cos_sims_full_df["Test_avg_shift"] = np.zeros(cos_sims_df.shape[0], dtype=float)
|
469 |
+
cos_sims_full_df["Null_avg_shift"] = np.zeros(cos_sims_df.shape[0], dtype=float)
|
470 |
+
cos_sims_full_df["Test_vs_null_avg_shift"] = np.zeros(
|
471 |
+
cos_sims_df.shape[0], dtype=float
|
472 |
+
)
|
473 |
+
cos_sims_full_df["Test_vs_null_pval"] = np.zeros(cos_sims_df.shape[0], dtype=float)
|
474 |
+
cos_sims_full_df["Test_vs_null_FDR"] = np.zeros(cos_sims_df.shape[0], dtype=float)
|
475 |
+
cos_sims_full_df["N_Detections_test"] = np.zeros(
|
476 |
+
cos_sims_df.shape[0], dtype="uint32"
|
477 |
+
)
|
478 |
+
cos_sims_full_df["N_Detections_null"] = np.zeros(
|
479 |
+
cos_sims_df.shape[0], dtype="uint32"
|
480 |
+
)
|
481 |
+
|
482 |
+
for i in trange(cos_sims_df.shape[0]):
|
483 |
+
token = cos_sims_df["Gene"][i]
|
484 |
+
test_shifts = []
|
485 |
+
null_shifts = []
|
486 |
+
|
487 |
+
for dict_i in dict_list:
|
488 |
+
test_shifts += dict_i.get((token, "cell_emb"), [])
|
489 |
+
|
490 |
+
for dict_i in null_dict_list:
|
491 |
+
null_shifts += dict_i.get((token, "cell_emb"), [])
|
492 |
+
|
493 |
+
cos_sims_full_df.loc[i, "Test_avg_shift"] = np.mean(test_shifts)
|
494 |
+
cos_sims_full_df.loc[i, "Null_avg_shift"] = np.mean(null_shifts)
|
495 |
+
cos_sims_full_df.loc[i, "Test_vs_null_avg_shift"] = np.mean(
|
496 |
+
test_shifts
|
497 |
+
) - np.mean(null_shifts)
|
498 |
+
cos_sims_full_df.loc[i, "Test_vs_null_pval"] = ranksums(
|
499 |
+
test_shifts, null_shifts, nan_policy="omit"
|
500 |
+
).pvalue
|
501 |
+
# remove nan values
|
502 |
+
cos_sims_full_df.Test_vs_null_pval = np.where(
|
503 |
+
np.isnan(cos_sims_full_df.Test_vs_null_pval),
|
504 |
+
1,
|
505 |
+
cos_sims_full_df.Test_vs_null_pval,
|
506 |
+
)
|
507 |
+
cos_sims_full_df.loc[i, "N_Detections_test"] = len(test_shifts)
|
508 |
+
cos_sims_full_df.loc[i, "N_Detections_null"] = len(null_shifts)
|
509 |
+
|
510 |
+
cos_sims_full_df["Test_vs_null_FDR"] = get_fdr(
|
511 |
+
cos_sims_full_df["Test_vs_null_pval"]
|
512 |
+
)
|
513 |
+
|
514 |
+
cos_sims_full_df["Sig"] = [
|
515 |
+
1 if fdr < 0.05 else 0 for fdr in cos_sims_full_df["Test_vs_null_FDR"]
|
516 |
+
]
|
517 |
+
cos_sims_full_df = cos_sims_full_df.sort_values(
|
518 |
+
by=["Sig", "Test_vs_null_avg_shift", "Test_vs_null_FDR"],
|
519 |
+
ascending=[False, False, True],
|
520 |
+
)
|
521 |
+
return cos_sims_full_df
|
522 |
+
|
523 |
+
|
524 |
+
# stats for identifying perturbations with largest effect within a given set of cells
|
525 |
+
# fits a mixture model to 2 components (impact vs. non-impact) and
|
526 |
+
# reports the most likely component for each test perturbation
|
527 |
+
# Note: because assumes given perturbation has a consistent effect in the cells tested,
|
528 |
+
# we recommend only using the mixture model strategy with uniform cell populations
|
529 |
+
def isp_stats_mixture_model(cos_sims_df, dict_list, combos, anchor_token):
|
530 |
+
names = ["Gene", "Gene_name", "Ensembl_ID"]
|
531 |
+
|
532 |
+
if combos == 0:
|
533 |
+
names += ["Test_avg_shift"]
|
534 |
+
elif combos == 1:
|
535 |
+
names += [
|
536 |
+
"Anchor_shift",
|
537 |
+
"Test_token_shift",
|
538 |
+
"Sum_of_indiv_shifts",
|
539 |
+
"Combo_shift",
|
540 |
+
"Combo_minus_sum_shift",
|
541 |
+
]
|
542 |
+
|
543 |
+
names += ["Impact_component", "Impact_component_percent"]
|
544 |
+
|
545 |
+
cos_sims_full_df = pd.DataFrame(columns=names)
|
546 |
+
avg_values = []
|
547 |
+
gene_names = []
|
548 |
+
|
549 |
+
for i in trange(cos_sims_df.shape[0]):
|
550 |
+
token = cos_sims_df["Gene"][i]
|
551 |
+
name = cos_sims_df["Gene_name"][i]
|
552 |
+
ensembl_id = cos_sims_df["Ensembl_ID"][i]
|
553 |
+
cos_shift_data = []
|
554 |
+
|
555 |
+
for dict_i in dict_list:
|
556 |
+
if (combos == 0) and (anchor_token is not None):
|
557 |
+
cos_shift_data += dict_i.get((anchor_token, token), [])
|
558 |
+
else:
|
559 |
+
cos_shift_data += dict_i.get((token, "cell_emb"), [])
|
560 |
+
|
561 |
+
# Extract values for current gene
|
562 |
+
if combos == 0:
|
563 |
+
test_values = cos_shift_data
|
564 |
+
elif combos == 1:
|
565 |
+
test_values = []
|
566 |
+
for tup in cos_shift_data:
|
567 |
+
test_values.append(tup[2])
|
568 |
+
|
569 |
+
if len(test_values) > 0:
|
570 |
+
avg_value = np.mean(test_values)
|
571 |
+
avg_values.append(avg_value)
|
572 |
+
gene_names.append(name)
|
573 |
+
|
574 |
+
# fit Gaussian mixture model to dataset of mean for each gene
|
575 |
+
avg_values_to_fit = np.array(avg_values).reshape(-1, 1)
|
576 |
+
gm = GaussianMixture(n_components=2, random_state=0).fit(avg_values_to_fit)
|
577 |
+
|
578 |
+
for i in trange(cos_sims_df.shape[0]):
|
579 |
+
token = cos_sims_df["Gene"][i]
|
580 |
+
name = cos_sims_df["Gene_name"][i]
|
581 |
+
ensembl_id = cos_sims_df["Ensembl_ID"][i]
|
582 |
+
cos_shift_data = []
|
583 |
+
|
584 |
+
for dict_i in dict_list:
|
585 |
+
if (combos == 0) and (anchor_token is not None):
|
586 |
+
cos_shift_data += dict_i.get((anchor_token, token), [])
|
587 |
+
else:
|
588 |
+
cos_shift_data += dict_i.get((token, "cell_emb"), [])
|
589 |
+
|
590 |
+
if combos == 0:
|
591 |
+
mean_test = np.mean(cos_shift_data)
|
592 |
+
impact_components = [
|
593 |
+
get_impact_component(value, gm) for value in cos_shift_data
|
594 |
+
]
|
595 |
+
elif combos == 1:
|
596 |
+
anchor_cos_sim_megalist = [
|
597 |
+
anchor for anchor, token, combo in cos_shift_data
|
598 |
+
]
|
599 |
+
token_cos_sim_megalist = [token for anchor, token, combo in cos_shift_data]
|
600 |
+
anchor_plus_token_cos_sim_megalist = [
|
601 |
+
1 - ((1 - anchor) + (1 - token))
|
602 |
+
for anchor, token, combo in cos_shift_data
|
603 |
+
]
|
604 |
+
combo_anchor_token_cos_sim_megalist = [
|
605 |
+
combo for anchor, token, combo in cos_shift_data
|
606 |
+
]
|
607 |
+
combo_minus_sum_cos_sim_megalist = [
|
608 |
+
combo - (1 - ((1 - anchor) + (1 - token)))
|
609 |
+
for anchor, token, combo in cos_shift_data
|
610 |
+
]
|
611 |
+
|
612 |
+
mean_anchor = np.mean(anchor_cos_sim_megalist)
|
613 |
+
mean_token = np.mean(token_cos_sim_megalist)
|
614 |
+
mean_sum = np.mean(anchor_plus_token_cos_sim_megalist)
|
615 |
+
mean_test = np.mean(combo_anchor_token_cos_sim_megalist)
|
616 |
+
mean_combo_minus_sum = np.mean(combo_minus_sum_cos_sim_megalist)
|
617 |
+
|
618 |
+
impact_components = [
|
619 |
+
get_impact_component(value, gm)
|
620 |
+
for value in combo_anchor_token_cos_sim_megalist
|
621 |
+
]
|
622 |
+
|
623 |
+
impact_component = get_impact_component(mean_test, gm)
|
624 |
+
impact_component_percent = np.mean(impact_components) * 100
|
625 |
+
|
626 |
+
data_i = [token, name, ensembl_id]
|
627 |
+
if combos == 0:
|
628 |
+
data_i += [mean_test]
|
629 |
+
elif combos == 1:
|
630 |
+
data_i += [
|
631 |
+
mean_anchor,
|
632 |
+
mean_token,
|
633 |
+
mean_sum,
|
634 |
+
mean_test,
|
635 |
+
mean_combo_minus_sum,
|
636 |
+
]
|
637 |
+
data_i += [impact_component, impact_component_percent]
|
638 |
+
|
639 |
+
cos_sims_df_i = pd.DataFrame(dict(zip(names, data_i)), index=[i])
|
640 |
+
cos_sims_full_df = pd.concat([cos_sims_full_df, cos_sims_df_i])
|
641 |
+
|
642 |
+
# quantify number of detections of each gene
|
643 |
+
if anchor_token is None:
|
644 |
+
cos_sims_full_df["N_Detections"] = [
|
645 |
+
n_detections(i, dict_list, "cell", anchor_token)
|
646 |
+
for i in cos_sims_full_df["Gene"]
|
647 |
+
]
|
648 |
+
else:
|
649 |
+
cos_sims_full_df["N_Detections"] = [
|
650 |
+
n_detections(i, dict_list, "gene", anchor_token)
|
651 |
+
for i in cos_sims_full_df["Gene"]
|
652 |
+
]
|
653 |
+
|
654 |
+
if combos == 0:
|
655 |
+
cos_sims_full_df = cos_sims_full_df.sort_values(
|
656 |
+
by=["Impact_component", "Test_avg_shift"], ascending=[False, True]
|
657 |
+
)
|
658 |
+
elif combos == 1:
|
659 |
+
cos_sims_full_df = cos_sims_full_df.sort_values(
|
660 |
+
by=["Impact_component", "Combo_minus_sum_shift"], ascending=[False, True]
|
661 |
+
)
|
662 |
+
return cos_sims_full_df
|
663 |
+
|
664 |
+
|
665 |
+
class InSilicoPerturberStats:
|
666 |
+
valid_option_dict = {
|
667 |
+
"mode": {
|
668 |
+
"goal_state_shift",
|
669 |
+
"vs_null",
|
670 |
+
"mixture_model",
|
671 |
+
"aggregate_data",
|
672 |
+
"aggregate_gene_shifts",
|
673 |
+
},
|
674 |
+
"genes_perturbed": {"all", list},
|
675 |
+
"combos": {0, 1},
|
676 |
+
"anchor_gene": {None, str},
|
677 |
+
"cell_states_to_model": {None, dict},
|
678 |
+
"pickle_suffix": {None, str},
|
679 |
+
}
|
680 |
+
|
681 |
+
def __init__(
|
682 |
+
self,
|
683 |
+
mode="mixture_model",
|
684 |
+
genes_perturbed="all",
|
685 |
+
combos=0,
|
686 |
+
anchor_gene=None,
|
687 |
+
cell_states_to_model=None,
|
688 |
+
pickle_suffix="_raw.pickle",
|
689 |
+
token_dictionary_file=TOKEN_DICTIONARY_FILE,
|
690 |
+
gene_name_id_dictionary_file=ENSEMBL_DICTIONARY_FILE,
|
691 |
+
):
|
692 |
+
"""
|
693 |
+
Initialize in silico perturber stats generator.
|
694 |
+
|
695 |
+
**Parameters:**
|
696 |
+
|
697 |
+
mode : {"goal_state_shift", "vs_null", "mixture_model", "aggregate_data", "aggregate_gene_shifts"}
|
698 |
+
| Type of stats.
|
699 |
+
| "goal_state_shift": perturbation vs. random for desired cell state shift
|
700 |
+
| "vs_null": perturbation vs. null from provided null distribution dataset
|
701 |
+
| "mixture_model": perturbation in impact vs. no impact component of mixture model (no goal direction)
|
702 |
+
| "aggregate_data": aggregates cosine shifts for single perturbation in multiple cells
|
703 |
+
| "aggregate_gene_shifts": aggregates cosine shifts of genes in response to perturbation(s)
|
704 |
+
genes_perturbed : "all", list
|
705 |
+
| Genes perturbed in isp experiment.
|
706 |
+
| Default is assuming genes_to_perturb in isp experiment was "all" (each gene in each cell).
|
707 |
+
| Otherwise, may provide a list of ENSEMBL IDs of genes perturbed as a group all together.
|
708 |
+
combos : {0,1,2}
|
709 |
+
| Whether genex perturbed in isp experiment were perturbed individually (0), in pairs (1), or in triplets (2).
|
710 |
+
anchor_gene : None, str
|
711 |
+
| ENSEMBL ID of gene to use as anchor in combination perturbations or in testing effect on downstream genes.
|
712 |
+
| For example, if combos=1 and anchor_gene="ENSG00000136574":
|
713 |
+
| analyzes data for anchor gene perturbed in combination with each other gene.
|
714 |
+
| However, if combos=0 and anchor_gene="ENSG00000136574":
|
715 |
+
| analyzes data for the effect of anchor gene's perturbation on the embedding of each other gene.
|
716 |
+
cell_states_to_model: None, dict
|
717 |
+
| Cell states to model if testing perturbations that achieve goal state change.
|
718 |
+
| Four-item dictionary with keys: state_key, start_state, goal_state, and alt_states
|
719 |
+
| state_key: key specifying name of column in .dataset that defines the start/goal states
|
720 |
+
| start_state: value in the state_key column that specifies the start state
|
721 |
+
| goal_state: value in the state_key column taht specifies the goal end state
|
722 |
+
| alt_states: list of values in the state_key column that specify the alternate end states
|
723 |
+
| For example: {"state_key": "disease",
|
724 |
+
| "start_state": "dcm",
|
725 |
+
| "goal_state": "nf",
|
726 |
+
| "alt_states": ["hcm", "other1", "other2"]}
|
727 |
+
token_dictionary_file : Path
|
728 |
+
| Path to pickle file containing token dictionary (Ensembl ID:token).
|
729 |
+
gene_name_id_dictionary_file : Path
|
730 |
+
| Path to pickle file containing gene name to ID dictionary (gene name:Ensembl ID).
|
731 |
+
"""
|
732 |
+
|
733 |
+
self.mode = mode
|
734 |
+
self.genes_perturbed = genes_perturbed
|
735 |
+
self.combos = combos
|
736 |
+
self.anchor_gene = anchor_gene
|
737 |
+
self.cell_states_to_model = cell_states_to_model
|
738 |
+
self.pickle_suffix = pickle_suffix
|
739 |
+
|
740 |
+
self.validate_options()
|
741 |
+
|
742 |
+
# load token dictionary (Ensembl IDs:token)
|
743 |
+
with open(token_dictionary_file, "rb") as f:
|
744 |
+
self.gene_token_dict = pickle.load(f)
|
745 |
+
|
746 |
+
# load gene name dictionary (gene name:Ensembl ID)
|
747 |
+
with open(gene_name_id_dictionary_file, "rb") as f:
|
748 |
+
self.gene_name_id_dict = pickle.load(f)
|
749 |
+
|
750 |
+
if anchor_gene is None:
|
751 |
+
self.anchor_token = None
|
752 |
+
else:
|
753 |
+
self.anchor_token = self.gene_token_dict[self.anchor_gene]
|
754 |
+
|
755 |
+
def validate_options(self):
|
756 |
+
for attr_name, valid_options in self.valid_option_dict.items():
|
757 |
+
attr_value = self.__dict__[attr_name]
|
758 |
+
if type(attr_value) not in {list, dict}:
|
759 |
+
if attr_name in {"anchor_gene"}:
|
760 |
+
continue
|
761 |
+
elif attr_value in valid_options:
|
762 |
+
continue
|
763 |
+
valid_type = False
|
764 |
+
for option in valid_options:
|
765 |
+
if (option in [str, int, list, dict]) and isinstance(
|
766 |
+
attr_value, option
|
767 |
+
):
|
768 |
+
valid_type = True
|
769 |
+
break
|
770 |
+
if not valid_type:
|
771 |
+
logger.error(
|
772 |
+
f"Invalid option for {attr_name}. "
|
773 |
+
f"Valid options for {attr_name}: {valid_options}"
|
774 |
+
)
|
775 |
+
raise
|
776 |
+
|
777 |
+
if self.cell_states_to_model is not None:
|
778 |
+
if len(self.cell_states_to_model.items()) == 1:
|
779 |
+
logger.warning(
|
780 |
+
"The single value dictionary for cell_states_to_model will be "
|
781 |
+
"replaced with a dictionary with named keys for start, goal, and alternate states. "
|
782 |
+
"Please specify state_key, start_state, goal_state, and alt_states "
|
783 |
+
"in the cell_states_to_model dictionary for future use. "
|
784 |
+
"For example, cell_states_to_model={"
|
785 |
+
"'state_key': 'disease', "
|
786 |
+
"'start_state': 'dcm', "
|
787 |
+
"'goal_state': 'nf', "
|
788 |
+
"'alt_states': ['hcm', 'other1', 'other2']}"
|
789 |
+
)
|
790 |
+
for key, value in self.cell_states_to_model.items():
|
791 |
+
if (len(value) == 3) and isinstance(value, tuple):
|
792 |
+
if (
|
793 |
+
isinstance(value[0], list)
|
794 |
+
and isinstance(value[1], list)
|
795 |
+
and isinstance(value[2], list)
|
796 |
+
):
|
797 |
+
if len(value[0]) == 1 and len(value[1]) == 1:
|
798 |
+
all_values = value[0] + value[1] + value[2]
|
799 |
+
if len(all_values) == len(set(all_values)):
|
800 |
+
continue
|
801 |
+
# reformat to the new named key format
|
802 |
+
state_values = flatten_list(list(self.cell_states_to_model.values()))
|
803 |
+
self.cell_states_to_model = {
|
804 |
+
"state_key": list(self.cell_states_to_model.keys())[0],
|
805 |
+
"start_state": state_values[0][0],
|
806 |
+
"goal_state": state_values[1][0],
|
807 |
+
"alt_states": state_values[2:][0],
|
808 |
+
}
|
809 |
+
elif set(self.cell_states_to_model.keys()) == {
|
810 |
+
"state_key",
|
811 |
+
"start_state",
|
812 |
+
"goal_state",
|
813 |
+
"alt_states",
|
814 |
+
}:
|
815 |
+
if (
|
816 |
+
(self.cell_states_to_model["state_key"] is None)
|
817 |
+
or (self.cell_states_to_model["start_state"] is None)
|
818 |
+
or (self.cell_states_to_model["goal_state"] is None)
|
819 |
+
):
|
820 |
+
logger.error(
|
821 |
+
"Please specify 'state_key', 'start_state', and 'goal_state' in cell_states_to_model."
|
822 |
+
)
|
823 |
+
raise
|
824 |
+
|
825 |
+
if (
|
826 |
+
self.cell_states_to_model["start_state"]
|
827 |
+
== self.cell_states_to_model["goal_state"]
|
828 |
+
):
|
829 |
+
logger.error("All states must be unique.")
|
830 |
+
raise
|
831 |
+
|
832 |
+
if self.cell_states_to_model["alt_states"] is not None:
|
833 |
+
if not isinstance(self.cell_states_to_model["alt_states"], list):
|
834 |
+
logger.error(
|
835 |
+
"self.cell_states_to_model['alt_states'] must be a list (even if it is one element)."
|
836 |
+
)
|
837 |
+
raise
|
838 |
+
if len(self.cell_states_to_model["alt_states"]) != len(
|
839 |
+
set(self.cell_states_to_model["alt_states"])
|
840 |
+
):
|
841 |
+
logger.error("All states must be unique.")
|
842 |
+
raise
|
843 |
+
|
844 |
+
elif set(self.cell_states_to_model.keys()) == {
|
845 |
+
"state_key",
|
846 |
+
"start_state",
|
847 |
+
"goal_state",
|
848 |
+
}:
|
849 |
+
self.cell_states_to_model["alt_states"] = []
|
850 |
+
else:
|
851 |
+
logger.error(
|
852 |
+
"cell_states_to_model must only have the following four keys: "
|
853 |
+
"'state_key', 'start_state', 'goal_state', 'alt_states'."
|
854 |
+
"For example, cell_states_to_model={"
|
855 |
+
"'state_key': 'disease', "
|
856 |
+
"'start_state': 'dcm', "
|
857 |
+
"'goal_state': 'nf', "
|
858 |
+
"'alt_states': ['hcm', 'other1', 'other2']}"
|
859 |
+
)
|
860 |
+
raise
|
861 |
+
|
862 |
+
if self.anchor_gene is not None:
|
863 |
+
self.anchor_gene = None
|
864 |
+
logger.warning(
|
865 |
+
"anchor_gene set to None. "
|
866 |
+
"Currently, anchor gene not available "
|
867 |
+
"when modeling multiple cell states."
|
868 |
+
)
|
869 |
+
|
870 |
+
if self.combos > 0:
|
871 |
+
if self.anchor_gene is None:
|
872 |
+
logger.error(
|
873 |
+
"Currently, stats are only supported for combination "
|
874 |
+
"in silico perturbation run with anchor gene. Please add "
|
875 |
+
"anchor gene when using with combos > 0. "
|
876 |
+
)
|
877 |
+
raise
|
878 |
+
|
879 |
+
if (self.mode == "mixture_model") and (self.genes_perturbed != "all"):
|
880 |
+
logger.error(
|
881 |
+
"Mixture model mode requires multiple gene perturbations to fit model "
|
882 |
+
"so is incompatible with a single grouped perturbation."
|
883 |
+
)
|
884 |
+
raise
|
885 |
+
if (self.mode == "aggregate_data") and (self.genes_perturbed == "all"):
|
886 |
+
logger.error(
|
887 |
+
"Simple data aggregation mode is for single perturbation in multiple cells "
|
888 |
+
"so is incompatible with a genes_perturbed being 'all'."
|
889 |
+
)
|
890 |
+
raise
|
891 |
+
|
892 |
+
def get_stats(
|
893 |
+
self,
|
894 |
+
input_data_directory,
|
895 |
+
null_dist_data_directory,
|
896 |
+
output_directory,
|
897 |
+
output_prefix,
|
898 |
+
null_dict_list=None,
|
899 |
+
):
|
900 |
+
"""
|
901 |
+
Get stats for in silico perturbation data and save as results in output_directory.
|
902 |
+
|
903 |
+
**Parameters:**
|
904 |
+
|
905 |
+
input_data_directory : Path
|
906 |
+
| Path to directory containing cos_sim dictionary inputs
|
907 |
+
null_dist_data_directory : Path
|
908 |
+
| Path to directory containing null distribution cos_sim dictionary inputs
|
909 |
+
output_directory : Path
|
910 |
+
| Path to directory where perturbation data will be saved as .csv
|
911 |
+
output_prefix : str
|
912 |
+
| Prefix for output .csv
|
913 |
+
null_dict_list: list[dict]
|
914 |
+
| List of loaded null distribution dictionary if more than one comparison vs. the null is to be performed
|
915 |
+
|
916 |
+
**Outputs:**
|
917 |
+
|
918 |
+
Definition of possible columns in .csv output file.
|
919 |
+
|
920 |
+
| Of note, not all columns will be present in all output files.
|
921 |
+
| Some columns are specific to particular perturbation modes.
|
922 |
+
|
923 |
+
| "Gene": gene token
|
924 |
+
| "Gene_name": gene name
|
925 |
+
| "Ensembl_ID": gene Ensembl ID
|
926 |
+
| "N_Detections": number of cells in which each gene or gene combination was detected in the input dataset
|
927 |
+
| "Sig": 1 if FDR<0.05, otherwise 0
|
928 |
+
|
929 |
+
| "Shift_to_goal_end": cosine shift from start state towards goal end state in response to given perturbation
|
930 |
+
| "Shift_to_alt_end": cosine shift from start state towards alternate end state in response to given perturbation
|
931 |
+
| "Goal_end_vs_random_pval": pvalue of cosine shift from start state towards goal end state by Wilcoxon
|
932 |
+
| pvalue compares shift caused by perturbing given gene compared to random genes
|
933 |
+
| "Alt_end_vs_random_pval": pvalue of cosine shift from start state towards alternate end state by Wilcoxon
|
934 |
+
| pvalue compares shift caused by perturbing given gene compared to random genes
|
935 |
+
| "Goal_end_FDR": Benjamini-Hochberg correction of "Goal_end_vs_random_pval"
|
936 |
+
| "Alt_end_FDR": Benjamini-Hochberg correction of "Alt_end_vs_random_pval"
|
937 |
+
|
938 |
+
| "Test_avg_shift": cosine shift in response to given perturbation in cells from test distribution
|
939 |
+
| "Null_avg_shift": cosine shift in response to given perturbation in cells from null distribution (e.g. random cells)
|
940 |
+
| "Test_vs_null_avg_shift": difference in cosine shift in cells from test vs. null distribution
|
941 |
+
| (i.e. "Test_avg_shift" minus "Null_avg_shift")
|
942 |
+
| "Test_vs_null_pval": pvalue of cosine shift in test vs. null distribution
|
943 |
+
| "Test_vs_null_FDR": Benjamini-Hochberg correction of "Test_vs_null_pval"
|
944 |
+
| "N_Detections_test": "N_Detections" in cells from test distribution
|
945 |
+
| "N_Detections_null": "N_Detections" in cells from null distribution
|
946 |
+
|
947 |
+
| "Anchor_shift": cosine shift in response to given perturbation of anchor gene
|
948 |
+
| "Test_token_shift": cosine shift in response to given perturbation of test gene
|
949 |
+
| "Sum_of_indiv_shifts": sum of cosine shifts in response to individually perturbing test and anchor genes
|
950 |
+
| "Combo_shift": cosine shift in response to given perturbation of both anchor and test gene(s) in combination
|
951 |
+
| "Combo_minus_sum_shift": difference of cosine shifts in response combo perturbation vs. sum of individual perturbations
|
952 |
+
| (i.e. "Combo_shift" minus "Sum_of_indiv_shifts")
|
953 |
+
| "Impact_component": whether the given perturbation was modeled to be within the impact component by the mixture model
|
954 |
+
| 1: within impact component; 0: not within impact component
|
955 |
+
| "Impact_component_percent": percent of cells in which given perturbation was modeled to be within impact component
|
956 |
+
|
957 |
+
| In case of aggregating data / gene shifts:
|
958 |
+
| "Perturbed": ID(s) of gene(s) being perturbed
|
959 |
+
| "Affected": ID of affected gene or "cell_emb" indicating the impact on the cell embedding as a whole
|
960 |
+
| "Cosine_sim_mean": mean of cosine similarity of cell or affected gene in original vs. perturbed
|
961 |
+
| "Cosine_sim_stdev": standard deviation of cosine similarity of cell or affected gene in original vs. perturbed
|
962 |
+
"""
|
963 |
+
|
964 |
+
if self.mode not in [
|
965 |
+
"goal_state_shift",
|
966 |
+
"vs_null",
|
967 |
+
"mixture_model",
|
968 |
+
"aggregate_data",
|
969 |
+
"aggregate_gene_shifts",
|
970 |
+
]:
|
971 |
+
logger.error(
|
972 |
+
"Currently, only modes available are stats for goal_state_shift, "
|
973 |
+
"vs_null (comparing to null distribution), "
|
974 |
+
"mixture_model (fitting mixture model for perturbations with or without impact), "
|
975 |
+
"and aggregating data for single perturbations or for gene embedding shifts."
|
976 |
+
)
|
977 |
+
raise
|
978 |
+
|
979 |
+
self.gene_token_id_dict = invert_dict(self.gene_token_dict)
|
980 |
+
self.gene_id_name_dict = invert_dict(self.gene_name_id_dict)
|
981 |
+
|
982 |
+
# obtain total gene list
|
983 |
+
if (self.combos == 0) and (self.anchor_token is not None):
|
984 |
+
# cos sim data for effect of gene perturbation on the embedding of each other gene
|
985 |
+
dict_list = read_dictionaries(
|
986 |
+
input_data_directory,
|
987 |
+
"gene",
|
988 |
+
self.anchor_token,
|
989 |
+
self.cell_states_to_model,
|
990 |
+
self.pickle_suffix,
|
991 |
+
)
|
992 |
+
gene_list = get_gene_list(dict_list, "gene")
|
993 |
+
elif (
|
994 |
+
(self.combos == 0)
|
995 |
+
and (self.anchor_token is None)
|
996 |
+
and (self.mode == "aggregate_gene_shifts")
|
997 |
+
):
|
998 |
+
dict_list = read_dictionaries(
|
999 |
+
input_data_directory,
|
1000 |
+
"gene",
|
1001 |
+
self.anchor_token,
|
1002 |
+
self.cell_states_to_model,
|
1003 |
+
self.pickle_suffix,
|
1004 |
+
)
|
1005 |
+
gene_list = get_gene_list(dict_list, "cell")
|
1006 |
+
else:
|
1007 |
+
# cos sim data for effect of gene perturbation on the embedding of each cell
|
1008 |
+
dict_list = read_dictionaries(
|
1009 |
+
input_data_directory,
|
1010 |
+
"cell",
|
1011 |
+
self.anchor_token,
|
1012 |
+
self.cell_states_to_model,
|
1013 |
+
self.pickle_suffix,
|
1014 |
+
)
|
1015 |
+
gene_list = get_gene_list(dict_list, "cell")
|
1016 |
+
|
1017 |
+
# initiate results dataframe
|
1018 |
+
cos_sims_df_initial = pd.DataFrame(
|
1019 |
+
{
|
1020 |
+
"Gene": gene_list,
|
1021 |
+
"Gene_name": [self.token_to_gene_name(item) for item in gene_list],
|
1022 |
+
"Ensembl_ID": [
|
1023 |
+
token_tuple_to_ensembl_ids(genes, self.gene_token_id_dict)
|
1024 |
+
if self.genes_perturbed != "all"
|
1025 |
+
else self.gene_token_id_dict[genes[1]]
|
1026 |
+
if isinstance(genes, tuple)
|
1027 |
+
else self.gene_token_id_dict[genes]
|
1028 |
+
for genes in gene_list
|
1029 |
+
],
|
1030 |
+
},
|
1031 |
+
index=[i for i in range(len(gene_list))],
|
1032 |
+
)
|
1033 |
+
|
1034 |
+
if self.mode == "goal_state_shift":
|
1035 |
+
cos_sims_df = isp_stats_to_goal_state(
|
1036 |
+
cos_sims_df_initial,
|
1037 |
+
dict_list,
|
1038 |
+
self.cell_states_to_model,
|
1039 |
+
self.genes_perturbed,
|
1040 |
+
)
|
1041 |
+
|
1042 |
+
elif self.mode == "vs_null":
|
1043 |
+
if null_dict_list is None:
|
1044 |
+
null_dict_list = read_dictionaries(
|
1045 |
+
null_dist_data_directory,
|
1046 |
+
"cell",
|
1047 |
+
self.anchor_token,
|
1048 |
+
self.cell_states_to_model,
|
1049 |
+
self.pickle_suffix,
|
1050 |
+
)
|
1051 |
+
cos_sims_df = isp_stats_vs_null(
|
1052 |
+
cos_sims_df_initial, dict_list, null_dict_list
|
1053 |
+
)
|
1054 |
+
|
1055 |
+
elif self.mode == "mixture_model":
|
1056 |
+
cos_sims_df = isp_stats_mixture_model(
|
1057 |
+
cos_sims_df_initial, dict_list, self.combos, self.anchor_token
|
1058 |
+
)
|
1059 |
+
|
1060 |
+
elif self.mode == "aggregate_data":
|
1061 |
+
cos_sims_df = isp_aggregate_grouped_perturb(
|
1062 |
+
cos_sims_df_initial, dict_list, self.genes_perturbed
|
1063 |
+
)
|
1064 |
+
|
1065 |
+
elif self.mode == "aggregate_gene_shifts":
|
1066 |
+
if (self.genes_perturbed == "all") and (self.combos == 0):
|
1067 |
+
tuple_types = [
|
1068 |
+
True if isinstance(genes, tuple) else False for genes in gene_list
|
1069 |
+
]
|
1070 |
+
if all(tuple_types):
|
1071 |
+
token_dtype = "tuple"
|
1072 |
+
elif not any(tuple_types):
|
1073 |
+
token_dtype = "nontuple"
|
1074 |
+
else:
|
1075 |
+
token_dtype = "mix"
|
1076 |
+
else:
|
1077 |
+
token_dtype = "mix"
|
1078 |
+
|
1079 |
+
cos_sims_df = isp_aggregate_gene_shifts(
|
1080 |
+
cos_sims_df_initial,
|
1081 |
+
dict_list,
|
1082 |
+
self.gene_token_id_dict,
|
1083 |
+
self.gene_id_name_dict,
|
1084 |
+
token_dtype,
|
1085 |
+
)
|
1086 |
+
|
1087 |
+
# save perturbation stats to output_path
|
1088 |
+
output_path = (Path(output_directory) / output_prefix).with_suffix(".csv")
|
1089 |
+
cos_sims_df.to_csv(output_path)
|
1090 |
+
|
1091 |
+
def token_to_gene_name(self, item):
|
1092 |
+
if np.issubdtype(type(item), np.integer):
|
1093 |
+
return self.gene_id_name_dict.get(
|
1094 |
+
self.gene_token_id_dict.get(item, np.nan), np.nan
|
1095 |
+
)
|
1096 |
+
if isinstance(item, tuple):
|
1097 |
+
return tuple(
|
1098 |
+
[
|
1099 |
+
self.gene_id_name_dict.get(
|
1100 |
+
self.gene_token_id_dict.get(i, np.nan), np.nan
|
1101 |
+
)
|
1102 |
+
for i in item
|
1103 |
+
]
|
1104 |
+
)
|
geneformer/mtl/__init__.py
ADDED
@@ -0,0 +1 @@
|
|
|
|
|
1 |
+
# ruff: noqa: F401
|
geneformer/mtl/collators.py
ADDED
@@ -0,0 +1,76 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# imports
|
2 |
+
import torch
|
3 |
+
import pickle
|
4 |
+
from ..collator_for_classification import DataCollatorForGeneClassification
|
5 |
+
from .. import TOKEN_DICTIONARY_FILE
|
6 |
+
|
7 |
+
"""Geneformer collator for multi-task cell classification."""
|
8 |
+
|
9 |
+
class DataCollatorForMultitaskCellClassification(DataCollatorForGeneClassification):
|
10 |
+
class_type = "cell"
|
11 |
+
|
12 |
+
@staticmethod
|
13 |
+
def load_token_dictionary():
|
14 |
+
with open(TOKEN_DICTIONARY_FILE, 'rb') as f:
|
15 |
+
return pickle.load(f)
|
16 |
+
|
17 |
+
def __init__(self, *args, **kwargs) -> None:
|
18 |
+
# Load the token dictionary
|
19 |
+
token_dictionary = self.load_token_dictionary()
|
20 |
+
# Use the loaded token dictionary
|
21 |
+
super().__init__(token_dictionary=token_dictionary, *args, **kwargs)
|
22 |
+
|
23 |
+
def _prepare_batch(self, features):
|
24 |
+
# Process inputs as usual
|
25 |
+
batch = self.tokenizer.pad(
|
26 |
+
features,
|
27 |
+
class_type=self.class_type,
|
28 |
+
padding=self.padding,
|
29 |
+
max_length=self.max_length,
|
30 |
+
pad_to_multiple_of=self.pad_to_multiple_of,
|
31 |
+
return_tensors="pt",
|
32 |
+
)
|
33 |
+
|
34 |
+
# Check if labels are present
|
35 |
+
if "label" in features[0]:
|
36 |
+
# Initialize labels dictionary for all tasks
|
37 |
+
labels = {task: [] for task in features[0]["label"].keys()}
|
38 |
+
# Populate labels for each task
|
39 |
+
for feature in features:
|
40 |
+
for task, label in feature["label"].items():
|
41 |
+
labels[task].append(label)
|
42 |
+
|
43 |
+
# Convert label lists to tensors, handling dictionaries appropriately
|
44 |
+
for task in labels:
|
45 |
+
if isinstance(labels[task][0], (list, torch.Tensor)):
|
46 |
+
dtype = torch.long
|
47 |
+
labels[task] = torch.tensor(labels[task], dtype=dtype)
|
48 |
+
elif isinstance(labels[task][0], dict):
|
49 |
+
# Handle dict specifically if needed
|
50 |
+
pass # Resolve nested data structure
|
51 |
+
|
52 |
+
# Update the batch to include task-specific labels
|
53 |
+
batch["labels"] = labels
|
54 |
+
else:
|
55 |
+
# If no labels are present, create empty labels for all tasks
|
56 |
+
batch["labels"] = {
|
57 |
+
task: torch.tensor([], dtype=torch.long)
|
58 |
+
for task in features[0]["input_ids"].keys()
|
59 |
+
}
|
60 |
+
|
61 |
+
return batch
|
62 |
+
|
63 |
+
def __call__(self, features):
|
64 |
+
batch = self._prepare_batch(features)
|
65 |
+
for k, v in batch.items():
|
66 |
+
if torch.is_tensor(v):
|
67 |
+
batch[k] = v.clone().detach()
|
68 |
+
elif isinstance(v, dict):
|
69 |
+
# Assuming nested structure needs conversion
|
70 |
+
batch[k] = {
|
71 |
+
task: torch.tensor(labels, dtype=torch.int64)
|
72 |
+
for task, labels in v.items()
|
73 |
+
}
|
74 |
+
else:
|
75 |
+
batch[k] = torch.tensor(v, dtype=torch.int64)
|
76 |
+
return batch
|
geneformer/mtl/data.py
ADDED
@@ -0,0 +1,162 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import os
|
2 |
+
from .collators import DataCollatorForMultitaskCellClassification
|
3 |
+
from .imports import *
|
4 |
+
|
5 |
+
def validate_columns(dataset, required_columns, dataset_type):
|
6 |
+
"""Ensures required columns are present in the dataset."""
|
7 |
+
missing_columns = [col for col in required_columns if col not in dataset.column_names]
|
8 |
+
if missing_columns:
|
9 |
+
raise KeyError(
|
10 |
+
f"Missing columns in {dataset_type} dataset: {missing_columns}. "
|
11 |
+
f"Available columns: {dataset.column_names}"
|
12 |
+
)
|
13 |
+
|
14 |
+
|
15 |
+
def create_label_mappings(dataset, task_to_column):
|
16 |
+
"""Creates label mappings for the dataset."""
|
17 |
+
task_label_mappings = {}
|
18 |
+
num_labels_list = []
|
19 |
+
for task, column in task_to_column.items():
|
20 |
+
unique_values = sorted(set(dataset[column]))
|
21 |
+
mapping = {label: idx for idx, label in enumerate(unique_values)}
|
22 |
+
task_label_mappings[task] = mapping
|
23 |
+
num_labels_list.append(len(unique_values))
|
24 |
+
return task_label_mappings, num_labels_list
|
25 |
+
|
26 |
+
|
27 |
+
def save_label_mappings(mappings, path):
|
28 |
+
"""Saves label mappings to a pickle file."""
|
29 |
+
with open(path, "wb") as f:
|
30 |
+
pickle.dump(mappings, f)
|
31 |
+
|
32 |
+
|
33 |
+
def load_label_mappings(path):
|
34 |
+
"""Loads label mappings from a pickle file."""
|
35 |
+
with open(path, "rb") as f:
|
36 |
+
return pickle.load(f)
|
37 |
+
|
38 |
+
|
39 |
+
def transform_dataset(dataset, task_to_column, task_label_mappings, config, is_test):
|
40 |
+
"""Transforms the dataset to the required format."""
|
41 |
+
transformed_dataset = []
|
42 |
+
cell_id_mapping = {}
|
43 |
+
|
44 |
+
for idx, record in enumerate(dataset):
|
45 |
+
transformed_record = {
|
46 |
+
"input_ids": torch.tensor(record["input_ids"], dtype=torch.long),
|
47 |
+
"cell_id": idx, # Index-based cell ID
|
48 |
+
}
|
49 |
+
|
50 |
+
if not is_test:
|
51 |
+
label_dict = {
|
52 |
+
task: task_label_mappings[task][record[column]]
|
53 |
+
for task, column in task_to_column.items()
|
54 |
+
}
|
55 |
+
else:
|
56 |
+
label_dict = {task: -1 for task in config["task_names"]}
|
57 |
+
|
58 |
+
transformed_record["label"] = label_dict
|
59 |
+
transformed_dataset.append(transformed_record)
|
60 |
+
cell_id_mapping[idx] = record.get("unique_cell_id", idx)
|
61 |
+
|
62 |
+
return transformed_dataset, cell_id_mapping
|
63 |
+
|
64 |
+
|
65 |
+
def load_and_preprocess_data(dataset_path, config, is_test=False, dataset_type=""):
|
66 |
+
"""Main function to load and preprocess data."""
|
67 |
+
try:
|
68 |
+
dataset = load_from_disk(dataset_path)
|
69 |
+
|
70 |
+
# Setup task and column mappings
|
71 |
+
task_names = [f"task{i+1}" for i in range(len(config["task_columns"]))]
|
72 |
+
task_to_column = dict(zip(task_names, config["task_columns"]))
|
73 |
+
config["task_names"] = task_names
|
74 |
+
|
75 |
+
label_mappings_path = os.path.join(
|
76 |
+
config["results_dir"],
|
77 |
+
f"task_label_mappings{'_val' if dataset_type == 'validation' else ''}.pkl"
|
78 |
+
)
|
79 |
+
|
80 |
+
if not is_test:
|
81 |
+
validate_columns(dataset, task_to_column.values(), dataset_type)
|
82 |
+
|
83 |
+
# Create and save label mappings
|
84 |
+
task_label_mappings, num_labels_list = create_label_mappings(dataset, task_to_column)
|
85 |
+
save_label_mappings(task_label_mappings, label_mappings_path)
|
86 |
+
else:
|
87 |
+
# Load existing mappings for test data
|
88 |
+
task_label_mappings = load_label_mappings(label_mappings_path)
|
89 |
+
num_labels_list = [len(mapping) for mapping in task_label_mappings.values()]
|
90 |
+
|
91 |
+
# Transform dataset
|
92 |
+
transformed_dataset, cell_id_mapping = transform_dataset(
|
93 |
+
dataset, task_to_column, task_label_mappings, config, is_test
|
94 |
+
)
|
95 |
+
|
96 |
+
return transformed_dataset, cell_id_mapping, num_labels_list
|
97 |
+
|
98 |
+
except KeyError as e:
|
99 |
+
raise ValueError(f"Configuration error or dataset key missing: {e}")
|
100 |
+
except Exception as e:
|
101 |
+
raise RuntimeError(f"Error during data loading or preprocessing: {e}")
|
102 |
+
|
103 |
+
|
104 |
+
def preload_and_process_data(config):
|
105 |
+
"""Preloads and preprocesses train and validation datasets."""
|
106 |
+
# Process train data and save mappings
|
107 |
+
train_data = load_and_preprocess_data(config["train_path"], config, dataset_type="train")
|
108 |
+
|
109 |
+
# Process validation data and save mappings
|
110 |
+
val_data = load_and_preprocess_data(config["val_path"], config, dataset_type="validation")
|
111 |
+
|
112 |
+
# Validate that the mappings match
|
113 |
+
validate_label_mappings(config)
|
114 |
+
|
115 |
+
return (*train_data, *val_data[:2]) # Return train and val data along with mappings
|
116 |
+
|
117 |
+
|
118 |
+
def validate_label_mappings(config):
|
119 |
+
"""Ensures train and validation label mappings are consistent."""
|
120 |
+
train_mappings_path = os.path.join(config["results_dir"], "task_label_mappings.pkl")
|
121 |
+
val_mappings_path = os.path.join(config["results_dir"], "task_label_mappings_val.pkl")
|
122 |
+
train_mappings = load_label_mappings(train_mappings_path)
|
123 |
+
val_mappings = load_label_mappings(val_mappings_path)
|
124 |
+
|
125 |
+
for task_name in config["task_names"]:
|
126 |
+
if train_mappings[task_name] != val_mappings[task_name]:
|
127 |
+
raise ValueError(
|
128 |
+
f"Mismatch in label mappings for task '{task_name}'.\n"
|
129 |
+
f"Train Mapping: {train_mappings[task_name]}\n"
|
130 |
+
f"Validation Mapping: {val_mappings[task_name]}"
|
131 |
+
)
|
132 |
+
|
133 |
+
|
134 |
+
def get_data_loader(preprocessed_dataset, batch_size):
|
135 |
+
"""Creates a DataLoader with optimal settings."""
|
136 |
+
return DataLoader(
|
137 |
+
preprocessed_dataset,
|
138 |
+
batch_size=batch_size,
|
139 |
+
shuffle=True,
|
140 |
+
collate_fn=DataCollatorForMultitaskCellClassification(),
|
141 |
+
num_workers=os.cpu_count(),
|
142 |
+
pin_memory=True,
|
143 |
+
)
|
144 |
+
|
145 |
+
|
146 |
+
def preload_data(config):
|
147 |
+
"""Preprocesses train and validation data for trials."""
|
148 |
+
train_loader = get_data_loader(*preload_and_process_data(config)[:2], config["batch_size"])
|
149 |
+
val_loader = get_data_loader(*preload_and_process_data(config)[2:4], config["batch_size"])
|
150 |
+
return train_loader, val_loader
|
151 |
+
|
152 |
+
|
153 |
+
def load_and_preprocess_test_data(config):
|
154 |
+
"""Loads and preprocesses test data."""
|
155 |
+
return load_and_preprocess_data(config["test_path"], config, is_test=True)
|
156 |
+
|
157 |
+
|
158 |
+
def prepare_test_loader(config):
|
159 |
+
"""Prepares DataLoader for test data."""
|
160 |
+
test_dataset, cell_id_mapping, num_labels_list = load_and_preprocess_test_data(config)
|
161 |
+
test_loader = get_data_loader(test_dataset, config["batch_size"])
|
162 |
+
return test_loader, cell_id_mapping, num_labels_list
|
geneformer/mtl/eval_utils.py
ADDED
@@ -0,0 +1,88 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import pandas as pd
|
2 |
+
|
3 |
+
from .imports import * # noqa # isort:skip
|
4 |
+
from .data import prepare_test_loader # noqa # isort:skip
|
5 |
+
from .model import GeneformerMultiTask
|
6 |
+
|
7 |
+
|
8 |
+
def evaluate_test_dataset(model, device, test_loader, cell_id_mapping, config):
|
9 |
+
task_pred_labels = {task_name: [] for task_name in config["task_names"]}
|
10 |
+
task_pred_probs = {task_name: [] for task_name in config["task_names"]}
|
11 |
+
cell_ids = []
|
12 |
+
|
13 |
+
# # Load task label mappings from pickle file
|
14 |
+
# with open(f"{config['results_dir']}/task_label_mappings.pkl", "rb") as f:
|
15 |
+
# task_label_mappings = pickle.load(f)
|
16 |
+
|
17 |
+
model.eval()
|
18 |
+
with torch.no_grad():
|
19 |
+
for batch in test_loader:
|
20 |
+
input_ids = batch["input_ids"].to(device)
|
21 |
+
attention_mask = batch["attention_mask"].to(device)
|
22 |
+
_, logits, _ = model(input_ids, attention_mask)
|
23 |
+
for sample_idx in range(len(batch["input_ids"])):
|
24 |
+
cell_id = cell_id_mapping[batch["cell_id"][sample_idx].item()]
|
25 |
+
cell_ids.append(cell_id)
|
26 |
+
for i, task_name in enumerate(config["task_names"]):
|
27 |
+
pred_label = torch.argmax(logits[i][sample_idx], dim=-1).item()
|
28 |
+
pred_prob = (
|
29 |
+
torch.softmax(logits[i][sample_idx], dim=-1).cpu().numpy()
|
30 |
+
)
|
31 |
+
task_pred_labels[task_name].append(pred_label)
|
32 |
+
task_pred_probs[task_name].append(pred_prob)
|
33 |
+
|
34 |
+
# Save test predictions with cell IDs and probabilities to CSV
|
35 |
+
test_results_dir = config["results_dir"]
|
36 |
+
os.makedirs(test_results_dir, exist_ok=True)
|
37 |
+
test_preds_file = os.path.join(test_results_dir, "test_preds.csv")
|
38 |
+
|
39 |
+
rows = []
|
40 |
+
for sample_idx in range(len(cell_ids)):
|
41 |
+
row = {"Cell ID": cell_ids[sample_idx]}
|
42 |
+
for task_name in config["task_names"]:
|
43 |
+
row[f"{task_name} Prediction"] = task_pred_labels[task_name][sample_idx]
|
44 |
+
row[f"{task_name} Probabilities"] = ",".join(
|
45 |
+
map(str, task_pred_probs[task_name][sample_idx])
|
46 |
+
)
|
47 |
+
rows.append(row)
|
48 |
+
|
49 |
+
df = pd.DataFrame(rows)
|
50 |
+
df.to_csv(test_preds_file, index=False)
|
51 |
+
print(f"Test predictions saved to {test_preds_file}")
|
52 |
+
|
53 |
+
|
54 |
+
def load_and_evaluate_test_model(config):
|
55 |
+
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
|
56 |
+
test_loader, cell_id_mapping, num_labels_list = prepare_test_loader(config)
|
57 |
+
model_directory = os.path.join(config["model_save_path"], "GeneformerMultiTask")
|
58 |
+
hyperparams_path = os.path.join(model_directory, "hyperparameters.json")
|
59 |
+
|
60 |
+
# Load the saved best hyperparameters
|
61 |
+
with open(hyperparams_path, "r") as f:
|
62 |
+
best_hyperparams = json.load(f)
|
63 |
+
|
64 |
+
# Extract the task weights if present, otherwise set to None
|
65 |
+
task_weights = best_hyperparams.get("task_weights", None)
|
66 |
+
normalized_task_weights = task_weights if task_weights else []
|
67 |
+
|
68 |
+
# Print the loaded hyperparameters
|
69 |
+
print("Loaded hyperparameters:")
|
70 |
+
for param, value in best_hyperparams.items():
|
71 |
+
if param == "task_weights":
|
72 |
+
print(f"normalized_task_weights: {value}")
|
73 |
+
else:
|
74 |
+
print(f"{param}: {value}")
|
75 |
+
|
76 |
+
best_model_path = os.path.join(model_directory, "pytorch_model.bin")
|
77 |
+
best_model = GeneformerMultiTask(
|
78 |
+
config["pretrained_path"],
|
79 |
+
num_labels_list,
|
80 |
+
dropout_rate=best_hyperparams["dropout_rate"],
|
81 |
+
use_task_weights=config["use_task_weights"],
|
82 |
+
task_weights=normalized_task_weights,
|
83 |
+
)
|
84 |
+
best_model.load_state_dict(torch.load(best_model_path))
|
85 |
+
best_model.to(device)
|
86 |
+
|
87 |
+
evaluate_test_dataset(best_model, device, test_loader, cell_id_mapping, config)
|
88 |
+
print("Evaluation completed.")
|
geneformer/mtl/imports.py
ADDED
@@ -0,0 +1,43 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import functools
|
2 |
+
import gc
|
3 |
+
import json
|
4 |
+
import os
|
5 |
+
import pickle
|
6 |
+
import sys
|
7 |
+
import warnings
|
8 |
+
from enum import Enum
|
9 |
+
from itertools import chain
|
10 |
+
from typing import Dict, List, Optional, Union
|
11 |
+
|
12 |
+
import numpy as np
|
13 |
+
import optuna
|
14 |
+
import pandas as pd
|
15 |
+
import torch
|
16 |
+
import torch.nn as nn
|
17 |
+
import torch.nn.functional as F
|
18 |
+
import torch.optim as optim
|
19 |
+
from datasets import load_from_disk
|
20 |
+
from sklearn.metrics import accuracy_score, f1_score, roc_auc_score, roc_curve
|
21 |
+
from sklearn.model_selection import train_test_split
|
22 |
+
from sklearn.preprocessing import LabelEncoder
|
23 |
+
from torch.utils.data import DataLoader
|
24 |
+
from transformers import (
|
25 |
+
AdamW,
|
26 |
+
BatchEncoding,
|
27 |
+
BertConfig,
|
28 |
+
BertModel,
|
29 |
+
DataCollatorForTokenClassification,
|
30 |
+
SpecialTokensMixin,
|
31 |
+
get_cosine_schedule_with_warmup,
|
32 |
+
get_linear_schedule_with_warmup,
|
33 |
+
get_scheduler,
|
34 |
+
)
|
35 |
+
from transformers.utils import logging, to_py_obj
|
36 |
+
|
37 |
+
from .collators import DataCollatorForMultitaskCellClassification
|
38 |
+
|
39 |
+
# local modules
|
40 |
+
from .data import get_data_loader, preload_and_process_data
|
41 |
+
from .model import GeneformerMultiTask
|
42 |
+
from .optuna_utils import create_optuna_study
|
43 |
+
from .utils import save_model
|
geneformer/mtl/model.py
ADDED
@@ -0,0 +1,121 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import torch
|
2 |
+
import torch.nn as nn
|
3 |
+
from transformers import BertConfig, BertModel
|
4 |
+
|
5 |
+
|
6 |
+
class AttentionPool(nn.Module):
|
7 |
+
"""Attention-based pooling layer."""
|
8 |
+
|
9 |
+
def __init__(self, hidden_size):
|
10 |
+
super(AttentionPool, self).__init__()
|
11 |
+
self.attention_weights = nn.Parameter(torch.randn(hidden_size, 1))
|
12 |
+
nn.init.xavier_uniform_(
|
13 |
+
self.attention_weights
|
14 |
+
) # https://pytorch.org/docs/stable/nn.init.html
|
15 |
+
|
16 |
+
def forward(self, hidden_states):
|
17 |
+
attention_scores = torch.matmul(hidden_states, self.attention_weights)
|
18 |
+
attention_scores = torch.softmax(attention_scores, dim=1)
|
19 |
+
pooled_output = torch.sum(hidden_states * attention_scores, dim=1)
|
20 |
+
return pooled_output
|
21 |
+
|
22 |
+
|
23 |
+
class GeneformerMultiTask(nn.Module):
|
24 |
+
def __init__(
|
25 |
+
self,
|
26 |
+
pretrained_path,
|
27 |
+
num_labels_list,
|
28 |
+
dropout_rate=0.1,
|
29 |
+
use_task_weights=False,
|
30 |
+
task_weights=None,
|
31 |
+
max_layers_to_freeze=0,
|
32 |
+
use_attention_pooling=False,
|
33 |
+
):
|
34 |
+
super(GeneformerMultiTask, self).__init__()
|
35 |
+
self.config = BertConfig.from_pretrained(pretrained_path)
|
36 |
+
self.bert = BertModel(self.config)
|
37 |
+
self.num_labels_list = num_labels_list
|
38 |
+
self.use_task_weights = use_task_weights
|
39 |
+
self.dropout = nn.Dropout(dropout_rate)
|
40 |
+
self.use_attention_pooling = use_attention_pooling
|
41 |
+
|
42 |
+
if use_task_weights and (
|
43 |
+
task_weights is None or len(task_weights) != len(num_labels_list)
|
44 |
+
):
|
45 |
+
raise ValueError(
|
46 |
+
"Task weights must be defined and match the number of tasks when 'use_task_weights' is True."
|
47 |
+
)
|
48 |
+
self.task_weights = (
|
49 |
+
task_weights if use_task_weights else [1.0] * len(num_labels_list)
|
50 |
+
)
|
51 |
+
|
52 |
+
# Freeze the specified initial layers
|
53 |
+
for layer in self.bert.encoder.layer[:max_layers_to_freeze]:
|
54 |
+
for param in layer.parameters():
|
55 |
+
param.requires_grad = False
|
56 |
+
|
57 |
+
self.attention_pool = (
|
58 |
+
AttentionPool(self.config.hidden_size) if use_attention_pooling else None
|
59 |
+
)
|
60 |
+
|
61 |
+
self.classification_heads = nn.ModuleList(
|
62 |
+
[
|
63 |
+
nn.Linear(self.config.hidden_size, num_labels)
|
64 |
+
for num_labels in num_labels_list
|
65 |
+
]
|
66 |
+
)
|
67 |
+
# initialization of the classification heads: https://pytorch.org/docs/stable/nn.init.html
|
68 |
+
for head in self.classification_heads:
|
69 |
+
nn.init.xavier_uniform_(head.weight)
|
70 |
+
nn.init.zeros_(head.bias)
|
71 |
+
|
72 |
+
def forward(self, input_ids, attention_mask, labels=None):
|
73 |
+
try:
|
74 |
+
outputs = self.bert(input_ids=input_ids, attention_mask=attention_mask)
|
75 |
+
except Exception as e:
|
76 |
+
raise RuntimeError(f"Error during BERT forward pass: {e}")
|
77 |
+
|
78 |
+
sequence_output = outputs.last_hidden_state
|
79 |
+
|
80 |
+
try:
|
81 |
+
pooled_output = (
|
82 |
+
self.attention_pool(sequence_output)
|
83 |
+
if self.use_attention_pooling
|
84 |
+
else sequence_output[:, 0, :]
|
85 |
+
)
|
86 |
+
pooled_output = self.dropout(pooled_output)
|
87 |
+
except Exception as e:
|
88 |
+
raise RuntimeError(f"Error during pooling and dropout: {e}")
|
89 |
+
|
90 |
+
total_loss = 0
|
91 |
+
logits = []
|
92 |
+
losses = []
|
93 |
+
|
94 |
+
for task_id, (head, num_labels) in enumerate(
|
95 |
+
zip(self.classification_heads, self.num_labels_list)
|
96 |
+
):
|
97 |
+
try:
|
98 |
+
task_logits = head(pooled_output)
|
99 |
+
except Exception as e:
|
100 |
+
raise RuntimeError(
|
101 |
+
f"Error during forward pass of classification head {task_id}: {e}"
|
102 |
+
)
|
103 |
+
|
104 |
+
logits.append(task_logits)
|
105 |
+
|
106 |
+
if labels is not None:
|
107 |
+
try:
|
108 |
+
loss_fct = nn.CrossEntropyLoss()
|
109 |
+
task_loss = loss_fct(
|
110 |
+
task_logits.view(-1, num_labels), labels[task_id].view(-1)
|
111 |
+
)
|
112 |
+
if self.use_task_weights:
|
113 |
+
task_loss *= self.task_weights[task_id]
|
114 |
+
total_loss += task_loss
|
115 |
+
losses.append(task_loss.item())
|
116 |
+
except Exception as e:
|
117 |
+
raise RuntimeError(
|
118 |
+
f"Error during loss computation for task {task_id}: {e}"
|
119 |
+
)
|
120 |
+
|
121 |
+
return total_loss, logits, losses if labels is not None else logits
|
geneformer/mtl/optuna_utils.py
ADDED
@@ -0,0 +1,27 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import optuna
|
2 |
+
from optuna.integration import TensorBoardCallback
|
3 |
+
|
4 |
+
|
5 |
+
def save_trial_callback(study, trial, trials_result_path):
|
6 |
+
with open(trials_result_path, "a") as f:
|
7 |
+
f.write(
|
8 |
+
f"Trial {trial.number}: Value (F1 Macro): {trial.value}, Params: {trial.params}\n"
|
9 |
+
)
|
10 |
+
|
11 |
+
|
12 |
+
def create_optuna_study(objective, n_trials, trials_result_path, tensorboard_log_dir):
|
13 |
+
study = optuna.create_study(direction="maximize")
|
14 |
+
|
15 |
+
# init TensorBoard callback
|
16 |
+
tensorboard_callback = TensorBoardCallback(
|
17 |
+
dirname=tensorboard_log_dir, metric_name="F1 Macro"
|
18 |
+
)
|
19 |
+
|
20 |
+
# callback and TensorBoard callback
|
21 |
+
callbacks = [
|
22 |
+
lambda study, trial: save_trial_callback(study, trial, trials_result_path),
|
23 |
+
tensorboard_callback,
|
24 |
+
]
|
25 |
+
|
26 |
+
study.optimize(objective, n_trials=n_trials, callbacks=callbacks)
|
27 |
+
return study
|
geneformer/mtl/train.py
ADDED
@@ -0,0 +1,380 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import os
|
2 |
+
import random
|
3 |
+
|
4 |
+
import numpy as np
|
5 |
+
import pandas as pd
|
6 |
+
import torch
|
7 |
+
from torch.utils.tensorboard import SummaryWriter
|
8 |
+
from tqdm import tqdm
|
9 |
+
|
10 |
+
from .imports import *
|
11 |
+
from .model import GeneformerMultiTask
|
12 |
+
from .utils import calculate_task_specific_metrics, get_layer_freeze_range
|
13 |
+
|
14 |
+
|
15 |
+
def set_seed(seed):
|
16 |
+
random.seed(seed)
|
17 |
+
np.random.seed(seed)
|
18 |
+
torch.manual_seed(seed)
|
19 |
+
torch.cuda.manual_seed_all(seed)
|
20 |
+
torch.backends.cudnn.deterministic = True
|
21 |
+
torch.backends.cudnn.benchmark = False
|
22 |
+
|
23 |
+
|
24 |
+
def initialize_wandb(config):
|
25 |
+
if config.get("use_wandb", False):
|
26 |
+
import wandb
|
27 |
+
|
28 |
+
wandb.init(project=config["wandb_project"], config=config)
|
29 |
+
print("Weights & Biases (wandb) initialized and will be used for logging.")
|
30 |
+
else:
|
31 |
+
print(
|
32 |
+
"Weights & Biases (wandb) is not enabled. Logging will use other methods."
|
33 |
+
)
|
34 |
+
|
35 |
+
|
36 |
+
def create_model(config, num_labels_list, device):
|
37 |
+
model = GeneformerMultiTask(
|
38 |
+
config["pretrained_path"],
|
39 |
+
num_labels_list,
|
40 |
+
dropout_rate=config["dropout_rate"],
|
41 |
+
use_task_weights=config["use_task_weights"],
|
42 |
+
task_weights=config["task_weights"],
|
43 |
+
max_layers_to_freeze=config["max_layers_to_freeze"],
|
44 |
+
use_attention_pooling=config["use_attention_pooling"],
|
45 |
+
)
|
46 |
+
if config["use_data_parallel"]:
|
47 |
+
model = nn.DataParallel(model)
|
48 |
+
return model.to(device)
|
49 |
+
|
50 |
+
|
51 |
+
def setup_optimizer_and_scheduler(model, config, total_steps):
|
52 |
+
optimizer = AdamW(
|
53 |
+
model.parameters(),
|
54 |
+
lr=config["learning_rate"],
|
55 |
+
weight_decay=config["weight_decay"],
|
56 |
+
)
|
57 |
+
warmup_steps = int(config["warmup_ratio"] * total_steps)
|
58 |
+
|
59 |
+
if config["lr_scheduler_type"] == "linear":
|
60 |
+
scheduler = get_linear_schedule_with_warmup(
|
61 |
+
optimizer, num_warmup_steps=warmup_steps, num_training_steps=total_steps
|
62 |
+
)
|
63 |
+
elif config["lr_scheduler_type"] == "cosine":
|
64 |
+
scheduler = get_cosine_schedule_with_warmup(
|
65 |
+
optimizer,
|
66 |
+
num_warmup_steps=warmup_steps,
|
67 |
+
num_training_steps=total_steps,
|
68 |
+
num_cycles=0.5,
|
69 |
+
)
|
70 |
+
|
71 |
+
return optimizer, scheduler
|
72 |
+
|
73 |
+
|
74 |
+
def train_epoch(
|
75 |
+
model, train_loader, optimizer, scheduler, device, config, writer, epoch
|
76 |
+
):
|
77 |
+
model.train()
|
78 |
+
progress_bar = tqdm(train_loader, desc=f"Epoch {epoch+1}/{config['epochs']}")
|
79 |
+
for batch_idx, batch in enumerate(progress_bar):
|
80 |
+
optimizer.zero_grad()
|
81 |
+
input_ids = batch["input_ids"].to(device)
|
82 |
+
attention_mask = batch["attention_mask"].to(device)
|
83 |
+
labels = [
|
84 |
+
batch["labels"][task_name].to(device) for task_name in config["task_names"]
|
85 |
+
]
|
86 |
+
|
87 |
+
loss, _, _ = model(input_ids, attention_mask, labels)
|
88 |
+
loss.backward()
|
89 |
+
|
90 |
+
if config["gradient_clipping"]:
|
91 |
+
torch.nn.utils.clip_grad_norm_(model.parameters(), config["max_grad_norm"])
|
92 |
+
|
93 |
+
optimizer.step()
|
94 |
+
scheduler.step()
|
95 |
+
|
96 |
+
writer.add_scalar(
|
97 |
+
"Training Loss", loss.item(), epoch * len(train_loader) + batch_idx
|
98 |
+
)
|
99 |
+
if config.get("use_wandb", False):
|
100 |
+
import wandb
|
101 |
+
|
102 |
+
wandb.log({"Training Loss": loss.item()})
|
103 |
+
|
104 |
+
# Update progress bar
|
105 |
+
progress_bar.set_postfix({"loss": f"{loss.item():.4f}"})
|
106 |
+
|
107 |
+
return loss.item() # Return the last batch loss
|
108 |
+
|
109 |
+
|
110 |
+
def validate_model(model, val_loader, device, config):
|
111 |
+
model.eval()
|
112 |
+
val_loss = 0.0
|
113 |
+
task_true_labels = {task_name: [] for task_name in config["task_names"]}
|
114 |
+
task_pred_labels = {task_name: [] for task_name in config["task_names"]}
|
115 |
+
task_pred_probs = {task_name: [] for task_name in config["task_names"]}
|
116 |
+
|
117 |
+
with torch.no_grad():
|
118 |
+
for batch in val_loader:
|
119 |
+
input_ids = batch["input_ids"].to(device)
|
120 |
+
attention_mask = batch["attention_mask"].to(device)
|
121 |
+
labels = [
|
122 |
+
batch["labels"][task_name].to(device)
|
123 |
+
for task_name in config["task_names"]
|
124 |
+
]
|
125 |
+
loss, logits, _ = model(input_ids, attention_mask, labels)
|
126 |
+
val_loss += loss.item()
|
127 |
+
|
128 |
+
for sample_idx in range(len(batch["input_ids"])):
|
129 |
+
for i, task_name in enumerate(config["task_names"]):
|
130 |
+
true_label = batch["labels"][task_name][sample_idx].item()
|
131 |
+
pred_label = torch.argmax(logits[i][sample_idx], dim=-1).item()
|
132 |
+
pred_prob = (
|
133 |
+
torch.softmax(logits[i][sample_idx], dim=-1).cpu().numpy()
|
134 |
+
)
|
135 |
+
task_true_labels[task_name].append(true_label)
|
136 |
+
task_pred_labels[task_name].append(pred_label)
|
137 |
+
task_pred_probs[task_name].append(pred_prob)
|
138 |
+
|
139 |
+
val_loss /= len(val_loader)
|
140 |
+
return val_loss, task_true_labels, task_pred_labels, task_pred_probs
|
141 |
+
|
142 |
+
|
143 |
+
def log_metrics(task_metrics, val_loss, config, writer, epochs):
|
144 |
+
for task_name, metrics in task_metrics.items():
|
145 |
+
print(
|
146 |
+
f"{task_name} - Validation F1 Macro: {metrics['f1']:.4f}, Validation Accuracy: {metrics['accuracy']:.4f}"
|
147 |
+
)
|
148 |
+
if config.get("use_wandb", False):
|
149 |
+
import wandb
|
150 |
+
|
151 |
+
wandb.log(
|
152 |
+
{
|
153 |
+
f"{task_name} Validation F1 Macro": metrics["f1"],
|
154 |
+
f"{task_name} Validation Accuracy": metrics["accuracy"],
|
155 |
+
}
|
156 |
+
)
|
157 |
+
|
158 |
+
writer.add_scalar("Validation Loss", val_loss, epochs)
|
159 |
+
for task_name, metrics in task_metrics.items():
|
160 |
+
writer.add_scalar(f"{task_name} - Validation F1 Macro", metrics["f1"], epochs)
|
161 |
+
writer.add_scalar(
|
162 |
+
f"{task_name} - Validation Accuracy", metrics["accuracy"], epochs
|
163 |
+
)
|
164 |
+
|
165 |
+
|
166 |
+
def save_validation_predictions(
|
167 |
+
val_cell_id_mapping,
|
168 |
+
task_true_labels,
|
169 |
+
task_pred_labels,
|
170 |
+
task_pred_probs,
|
171 |
+
config,
|
172 |
+
trial_number=None,
|
173 |
+
):
|
174 |
+
if trial_number is not None:
|
175 |
+
trial_results_dir = os.path.join(config["results_dir"], f"trial_{trial_number}")
|
176 |
+
os.makedirs(trial_results_dir, exist_ok=True)
|
177 |
+
val_preds_file = os.path.join(trial_results_dir, "val_preds.csv")
|
178 |
+
else:
|
179 |
+
val_preds_file = os.path.join(config["results_dir"], "manual_run_val_preds.csv")
|
180 |
+
|
181 |
+
rows = []
|
182 |
+
for sample_idx in range(len(val_cell_id_mapping)):
|
183 |
+
row = {"Cell ID": val_cell_id_mapping[sample_idx]}
|
184 |
+
for task_name in config["task_names"]:
|
185 |
+
row[f"{task_name} True"] = task_true_labels[task_name][sample_idx]
|
186 |
+
row[f"{task_name} Pred"] = task_pred_labels[task_name][sample_idx]
|
187 |
+
row[f"{task_name} Probabilities"] = ",".join(
|
188 |
+
map(str, task_pred_probs[task_name][sample_idx])
|
189 |
+
)
|
190 |
+
rows.append(row)
|
191 |
+
|
192 |
+
df = pd.DataFrame(rows)
|
193 |
+
df.to_csv(val_preds_file, index=False)
|
194 |
+
print(f"Validation predictions saved to {val_preds_file}")
|
195 |
+
|
196 |
+
|
197 |
+
def train_model(
|
198 |
+
config,
|
199 |
+
device,
|
200 |
+
train_loader,
|
201 |
+
val_loader,
|
202 |
+
train_cell_id_mapping,
|
203 |
+
val_cell_id_mapping,
|
204 |
+
num_labels_list,
|
205 |
+
):
|
206 |
+
set_seed(config["seed"])
|
207 |
+
initialize_wandb(config)
|
208 |
+
|
209 |
+
model = create_model(config, num_labels_list, device)
|
210 |
+
total_steps = len(train_loader) * config["epochs"]
|
211 |
+
optimizer, scheduler = setup_optimizer_and_scheduler(model, config, total_steps)
|
212 |
+
|
213 |
+
log_dir = os.path.join(config["tensorboard_log_dir"], "manual_run")
|
214 |
+
writer = SummaryWriter(log_dir=log_dir)
|
215 |
+
|
216 |
+
epoch_progress = tqdm(range(config["epochs"]), desc="Training Progress")
|
217 |
+
for epoch in epoch_progress:
|
218 |
+
last_loss = train_epoch(
|
219 |
+
model, train_loader, optimizer, scheduler, device, config, writer, epoch
|
220 |
+
)
|
221 |
+
epoch_progress.set_postfix({"last_loss": f"{last_loss:.4f}"})
|
222 |
+
|
223 |
+
val_loss, task_true_labels, task_pred_labels, task_pred_probs = validate_model(
|
224 |
+
model, val_loader, device, config
|
225 |
+
)
|
226 |
+
task_metrics = calculate_task_specific_metrics(task_true_labels, task_pred_labels)
|
227 |
+
|
228 |
+
log_metrics(task_metrics, val_loss, config, writer, config["epochs"])
|
229 |
+
writer.close()
|
230 |
+
|
231 |
+
save_validation_predictions(
|
232 |
+
val_cell_id_mapping, task_true_labels, task_pred_labels, task_pred_probs, config
|
233 |
+
)
|
234 |
+
|
235 |
+
if config.get("use_wandb", False):
|
236 |
+
import wandb
|
237 |
+
|
238 |
+
wandb.finish()
|
239 |
+
|
240 |
+
print(f"\nFinal Validation Loss: {val_loss:.4f}")
|
241 |
+
return val_loss, model # Return both the validation loss and the trained model
|
242 |
+
|
243 |
+
|
244 |
+
def objective(
|
245 |
+
trial,
|
246 |
+
train_loader,
|
247 |
+
val_loader,
|
248 |
+
train_cell_id_mapping,
|
249 |
+
val_cell_id_mapping,
|
250 |
+
num_labels_list,
|
251 |
+
config,
|
252 |
+
device,
|
253 |
+
):
|
254 |
+
set_seed(config["seed"]) # Set the seed before each trial
|
255 |
+
initialize_wandb(config)
|
256 |
+
|
257 |
+
# Hyperparameters
|
258 |
+
config["learning_rate"] = trial.suggest_float(
|
259 |
+
"learning_rate",
|
260 |
+
config["hyperparameters"]["learning_rate"]["low"],
|
261 |
+
config["hyperparameters"]["learning_rate"]["high"],
|
262 |
+
log=config["hyperparameters"]["learning_rate"]["log"],
|
263 |
+
)
|
264 |
+
config["warmup_ratio"] = trial.suggest_float(
|
265 |
+
"warmup_ratio",
|
266 |
+
config["hyperparameters"]["warmup_ratio"]["low"],
|
267 |
+
config["hyperparameters"]["warmup_ratio"]["high"],
|
268 |
+
)
|
269 |
+
config["weight_decay"] = trial.suggest_float(
|
270 |
+
"weight_decay",
|
271 |
+
config["hyperparameters"]["weight_decay"]["low"],
|
272 |
+
config["hyperparameters"]["weight_decay"]["high"],
|
273 |
+
)
|
274 |
+
config["dropout_rate"] = trial.suggest_float(
|
275 |
+
"dropout_rate",
|
276 |
+
config["hyperparameters"]["dropout_rate"]["low"],
|
277 |
+
config["hyperparameters"]["dropout_rate"]["high"],
|
278 |
+
)
|
279 |
+
config["lr_scheduler_type"] = trial.suggest_categorical(
|
280 |
+
"lr_scheduler_type", config["hyperparameters"]["lr_scheduler_type"]["choices"]
|
281 |
+
)
|
282 |
+
config["use_attention_pooling"] = trial.suggest_categorical(
|
283 |
+
"use_attention_pooling", [False]
|
284 |
+
)
|
285 |
+
|
286 |
+
if config["use_task_weights"]:
|
287 |
+
config["task_weights"] = [
|
288 |
+
trial.suggest_float(
|
289 |
+
f"task_weight_{i}",
|
290 |
+
config["hyperparameters"]["task_weights"]["low"],
|
291 |
+
config["hyperparameters"]["task_weights"]["high"],
|
292 |
+
)
|
293 |
+
for i in range(len(num_labels_list))
|
294 |
+
]
|
295 |
+
weight_sum = sum(config["task_weights"])
|
296 |
+
config["task_weights"] = [
|
297 |
+
weight / weight_sum for weight in config["task_weights"]
|
298 |
+
]
|
299 |
+
else:
|
300 |
+
config["task_weights"] = None
|
301 |
+
|
302 |
+
# Dynamic range for max_layers_to_freeze
|
303 |
+
freeze_range = get_layer_freeze_range(config["pretrained_path"])
|
304 |
+
config["max_layers_to_freeze"] = trial.suggest_int(
|
305 |
+
"max_layers_to_freeze",
|
306 |
+
freeze_range["min"],
|
307 |
+
freeze_range["max"]
|
308 |
+
)
|
309 |
+
|
310 |
+
model = create_model(config, num_labels_list, device)
|
311 |
+
total_steps = len(train_loader) * config["epochs"]
|
312 |
+
optimizer, scheduler = setup_optimizer_and_scheduler(model, config, total_steps)
|
313 |
+
|
314 |
+
log_dir = os.path.join(config["tensorboard_log_dir"], f"trial_{trial.number}")
|
315 |
+
writer = SummaryWriter(log_dir=log_dir)
|
316 |
+
|
317 |
+
for epoch in range(config["epochs"]):
|
318 |
+
train_epoch(
|
319 |
+
model, train_loader, optimizer, scheduler, device, config, writer, epoch
|
320 |
+
)
|
321 |
+
|
322 |
+
val_loss, task_true_labels, task_pred_labels, task_pred_probs = validate_model(
|
323 |
+
model, val_loader, device, config
|
324 |
+
)
|
325 |
+
task_metrics = calculate_task_specific_metrics(task_true_labels, task_pred_labels)
|
326 |
+
|
327 |
+
log_metrics(task_metrics, val_loss, config, writer, config["epochs"])
|
328 |
+
writer.close()
|
329 |
+
|
330 |
+
save_validation_predictions(
|
331 |
+
val_cell_id_mapping,
|
332 |
+
task_true_labels,
|
333 |
+
task_pred_labels,
|
334 |
+
task_pred_probs,
|
335 |
+
config,
|
336 |
+
trial.number,
|
337 |
+
)
|
338 |
+
|
339 |
+
trial.set_user_attr("model_state_dict", model.state_dict())
|
340 |
+
trial.set_user_attr("task_weights", config["task_weights"])
|
341 |
+
|
342 |
+
trial.report(val_loss, config["epochs"])
|
343 |
+
|
344 |
+
if trial.should_prune():
|
345 |
+
raise optuna.TrialPruned()
|
346 |
+
|
347 |
+
if config.get("use_wandb", False):
|
348 |
+
import wandb
|
349 |
+
|
350 |
+
wandb.log(
|
351 |
+
{
|
352 |
+
"trial_number": trial.number,
|
353 |
+
"val_loss": val_loss,
|
354 |
+
**{
|
355 |
+
f"{task_name}_f1": metrics["f1"]
|
356 |
+
for task_name, metrics in task_metrics.items()
|
357 |
+
},
|
358 |
+
**{
|
359 |
+
f"{task_name}_accuracy": metrics["accuracy"]
|
360 |
+
for task_name, metrics in task_metrics.items()
|
361 |
+
},
|
362 |
+
**{
|
363 |
+
k: v
|
364 |
+
for k, v in config.items()
|
365 |
+
if k
|
366 |
+
in [
|
367 |
+
"learning_rate",
|
368 |
+
"warmup_ratio",
|
369 |
+
"weight_decay",
|
370 |
+
"dropout_rate",
|
371 |
+
"lr_scheduler_type",
|
372 |
+
"use_attention_pooling",
|
373 |
+
"max_layers_to_freeze",
|
374 |
+
]
|
375 |
+
},
|
376 |
+
}
|
377 |
+
)
|
378 |
+
wandb.finish()
|
379 |
+
|
380 |
+
return val_loss
|