diff --git a/Essay_classifier/.gitignore b/Essay_classifier/.gitignore
new file mode 100644
index 0000000000000000000000000000000000000000..379c808db0c8f8e1e1fa7c0fdcc59785044aa788
--- /dev/null
+++ b/Essay_classifier/.gitignore
@@ -0,0 +1,168 @@
+# Byte-compiled / optimized / DLL files
+__pycache__/
+*.py[cod]
+*$py.class
+
+# C extensions
+*.so
+
+# Distribution / packaging
+.Python
+build/
+cache_dir/
+develop-eggs/
+dist/
+downloads/
+eggs/
+.eggs/
+lib/
+lib64/
+parts/
+sdist/
+var/
+wandb/
+wheels/
+share/python-wheels/
+*.egg-info/
+.installed.cfg
+*.egg
+MANIFEST
+
+# PyInstaller
+# Usually these files are written by a python script from a template
+# before PyInstaller builds the exe, so as to inject date/other infos into it.
+*.manifest
+*.spec
+
+# Installer logs
+pip-log.txt
+pip-delete-this-directory.txt
+
+# Unit test / coverage reports
+htmlcov/
+.tox/
+.nox/
+.coverage
+.coverage.*
+.cache
+nosetests.xml
+coverage.xml
+*.cover
+*.py,cover
+.hypothesis/
+.pytest_cache/
+cover/
+
+# Translations
+*.mo
+*.pot
+
+# Django stuff:
+*.log
+local_settings.py
+db.sqlite3
+db.sqlite3-journal
+
+# Flask stuff:
+instance/
+.webassets-cache
+
+# Scrapy stuff:
+.scrapy
+
+# Sphinx documentation
+docs/_build/
+
+# PyBuilder
+.pybuilder/
+target/
+
+# Jupyter Notebook
+.ipynb_checkpoints
+
+# IPython
+profile_default/
+ipython_config.py
+
+# pyenv
+# For a library or package, you might want to ignore these files since the code is
+# intended to run in multiple environments; otherwise, check them in:
+# .python-version
+
+# pipenv
+# According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control.
+# However, in case of collaboration, if having platform-specific dependencies or dependencies
+# having no cross-platform support, pipenv may install dependencies that don't work, or not
+# install all needed dependencies.
+#Pipfile.lock
+
+# poetry
+# Similar to Pipfile.lock, it is generally recommended to include poetry.lock in version control.
+# This is especially recommended for binary packages to ensure reproducibility, and is more
+# commonly ignored for libraries.
+# https://python-poetry.org/docs/basic-usage/#commit-your-poetrylock-file-to-version-control
+#poetry.lock
+
+# pdm
+# Similar to Pipfile.lock, it is generally recommended to include pdm.lock in version control.
+#pdm.lock
+# pdm stores project-wide configurations in .pdm.toml, but it is recommended to not include it
+# in version control.
+# https://pdm.fming.dev/#use-with-ide
+.pdm.toml
+
+# PEP 582; used by e.g. github.com/David-OConnor/pyflow and github.com/pdm-project/pdm
+__pypackages__/
+
+# Celery stuff
+celerybeat-schedule
+celerybeat.pid
+
+# SageMath parsed files
+*.sage.py
+
+# Environments
+.env
+.venv
+env/
+venv/
+ENV/
+env.bak/
+venv.bak/
+
+# Spyder project settings
+.spyderproject
+.spyproject
+
+# Rope project settings
+.ropeproject
+
+# mkdocs documentation
+/site
+
+# mypy
+.mypy_cache/
+.dmypy.json
+dmypy.json
+
+# Pyre type checker
+.pyre/
+
+# pytype static type analyzer
+.pytype/
+
+# Cython debug symbols
+cython_debug/
+
+# PyCharm
+# JetBrains specific template is maintained in a separate JetBrains.gitignore that can
+# be found at https://github.com/github/gitignore/blob/main/Global/JetBrains.gitignore
+# and can be added to the global gitignore or merged into this file. For a more nuclear
+# option (not recommended) the following line ignores the entire .idea folder.
+.idea/
+*.pyc
+
+# S5 specific stuff
+wandb/
+cache_dir/
+raw_datasets/
diff --git a/Essay_classifier/.idea/.gitignore b/Essay_classifier/.idea/.gitignore
new file mode 100644
index 0000000000000000000000000000000000000000..26d33521af10bcc7fd8cea344038eaaeb78d0ef5
--- /dev/null
+++ b/Essay_classifier/.idea/.gitignore
@@ -0,0 +1,3 @@
+# Default ignored files
+/shelf/
+/workspace.xml
diff --git a/Essay_classifier/.idea/Essay_classifier.iml b/Essay_classifier/.idea/Essay_classifier.iml
new file mode 100644
index 0000000000000000000000000000000000000000..8e5446ac9594d6e198c2a2923123566d13b94bf9
--- /dev/null
+++ b/Essay_classifier/.idea/Essay_classifier.iml
@@ -0,0 +1,14 @@
+
+
+
+
+
+
+
diff --git a/Essay_classifier/README.md b/Essay_classifier/README.md
new file mode 100644
--- /dev/null
+++ b/Essay_classifier/README.md
+Figure 1: S5 uses a single multi-input, multi-output linear state-space model, coupled with non-linearities, to define a non-linear sequence-to-sequence transformation. Parallel scans are used for efficient offline processing.
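The caption above mentions that parallel scans are used for efficient offline processing. Below is a minimal, self-contained sketch of that idea for a diagonalized linear SSM recurrence, written with `jax.lax.associative_scan`. It is illustrative only and is not the repository's `s5/ssm.py` implementation; the names `binary_operator`, `apply_ssm`, `Lambda_bar`, and `B_bar` are placeholders introduced here for the sketch.

```python
# Sketch: unrolling x_k = Lambda_bar * x_{k-1} + B_bar @ u_k, y_k = Re(C @ x_k)
# over a whole sequence with a parallel (associative) scan.
import jax
import jax.numpy as jnp

def binary_operator(elem_i, elem_j):
    """Associative combine for a first-order linear recurrence."""
    a_i, b_i = elem_i
    a_j, b_j = elem_j
    return a_j * a_i, a_j * b_i + b_j

def apply_ssm(Lambda_bar, B_bar, C, u):
    """Lambda_bar: (P,) complex, B_bar: (P, H) complex, C: (H, P) complex, u: (L, H) real."""
    Bu = jax.vmap(lambda u_k: B_bar @ u_k)(u)                             # (L, P)
    Lambda_elems = jnp.repeat(Lambda_bar[None, :], u.shape[0], axis=0)    # (L, P)
    _, xs = jax.lax.associative_scan(binary_operator, (Lambda_elems, Bu)) # states x_1..x_L
    return jax.vmap(lambda x_k: (C @ x_k).real)(xs)                       # (L, H)

# Toy usage with random placeholder parameters.
k1, k2, k3, k4 = jax.random.split(jax.random.PRNGKey(0), 4)
L, H, P = 16, 4, 8
Lambda_bar = jnp.exp(-0.1 + 1j * jax.random.normal(k1, (P,)))  # stable diagonal dynamics
B_bar = jax.random.normal(k2, (P, H)) + 0j
C = jax.random.normal(k3, (H, P)) + 0j
u = jax.random.normal(k4, (L, H))
y = apply_ssm(Lambda_bar, B_bar, C, u)  # (L, H) sequence-to-sequence output
print(y.shape)
```

Because the combine step is associative, the scan over the sequence can be evaluated in parallel rather than step by step, which is the point the caption is making.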
+ + +The S5 layer builds on the prior S4 work ([paper](https://arxiv.org/abs/2111.00396)). While it has departed considerably, this repository originally started off with much of the JAX implementation of S4 from the +Annotated S4 blog by Rush and Karamcheti (available [here](https://github.com/srush/annotated-s4)). + + +## Requirements & Installation +To run the code on your own machine, run either `pip install -r requirements_cpu.txt` or `pip install -r requirements_gpu.txt`. The GPU installation of JAX can be tricky, and so we include requirements that should work for most people, although further instructions are available [here](https://github.com/google/jax#installation). + +Run from within the root directory `pip install -e .` to install the package. + + +## Data Download +Downloading the raw data is done differently for each dataset. The following datasets require no action: +- Text (IMDb) +- Image (Cifar black & white) +- sMNIST +- psMNIST +- Cifar (Color) + +The remaining datasets need to be manually downloaded. To download _everything_, run `./bin/download_all.sh`. This will download quite a lot of data and will take some time. + +Below is a summary of the steps for each dataset: +- ListOps: run `./bin/download_lra.sh` to download the full LRA dataset. +- Retrieval (AAN): run `./bin/download_aan.sh` +- Pathfinder: run `./bin/download_lra.sh` to download the full LRA dataset. +- Path-X: run `./bin/download_lra.sh` to download the full LRA dataset. +- Speech commands 35: run `./bin/download_sc35.sh` to download the speech commands data. + +*With the exception of SC35.* When the dataset is used for the first time, a cache is created in `./cache_dir`. Converting the data (e.g. tokenizing) can be quite slow, and so this cache contains the processed dataset. The cache can be moved and specified with the `--dir_name` argument (i.e. the default is `--dir_name=./cache_dir`) to avoid applying this preprocessing every time the code is run somewhere new. + +SC35 is slightly different. SC35 doesn't use `--dir_name`, and instead requires that the following path exists: `./raw_datasets/speech_commands/0.0.2/SpeechCommands` (i.e. the directory `./raw_datasets/speech_commands/0.0.2/SpeechCommands/zero` must exist). The cache is then stored in `./raw_datasets/speech_commands/0.0.2/SpeechCommands/processed_data`. This directory can then be copied (preserving the directory path) to move the preprocessed dataset to a new location. + + +## Repository Structure +Directories and files that ship with GitHub repo: +``` +s5/ Source code for models, datasets, etc. + dataloading.py Dataloading functions. + layers.py Defines the S5 layer which wraps the S5 SSM with nonlinearity, norms, dropout, etc. + seq_model.py Defines deep sequence models that consist of stacks of S5 layers. + ssm.py S5 SSM implementation. + ssm_init.py Helper functions for initializing the S5 SSM . + train.py Training loop code. + train_helpers.py Functions for optimization, training and evaluation steps. + dataloaders/ Code mainly derived from S4 processing each dataset. + utils/ Range of utility functions. +bin/ Shell scripts for downloading data and running example experiments. +requirements_cpu.txt Requirements for running in CPU mode (not advised). +requirements_gpu.txt Requirements for running in GPU mode (installation can be highly system-dependent). +run_train.py Training loop entrypoint. +``` + +Directories that may be created on-the-fly: +``` +raw_datasets/ Raw data as downloaded. +cache_dir/ Precompiled caches of data. 
Can be copied to new locations to avoid preprocessing. +wandb/ Local WandB log files. +``` + +## Experiments + +The configurations to run the LRA and 35-way Speech Commands experiments from the paper are located in `bin/run_experiments`. For example, +to run the LRA text (character level IMDB) experiment, run `./bin/run_experiments/run_lra_imdb.sh`. +To log with W&B, adjust the default `USE_WANDB, wandb_entity, wandb_project` arguments. +Note: the pendulum +regression dataloading and experiments will be added soon. + +## Citation +Please use the following when citing our work: +``` +@misc{smith2022s5, + doi = {10.48550/ARXIV.2208.04933}, + url = {https://arxiv.org/abs/2208.04933}, + author = {Smith, Jimmy T. H. and Warrington, Andrew and Linderman, Scott W.}, + keywords = {Machine Learning (cs.LG), FOS: Computer and information sciences, FOS: Computer and information sciences}, + title = {Simplified State Space Layers for Sequence Modeling}, + publisher = {arXiv}, + year = {2022}, + copyright = {Creative Commons Attribution 4.0 International} +} +``` + +Please reach out if you have any questions. + +-- The S5 authors. diff --git a/Essay_classifier/S5.egg-info/PKG-INFO b/Essay_classifier/S5.egg-info/PKG-INFO new file mode 100644 index 0000000000000000000000000000000000000000..062426a7b301d3abc5ad8151418d82e0e2bf1550 --- /dev/null +++ b/Essay_classifier/S5.egg-info/PKG-INFO @@ -0,0 +1,13 @@ +Metadata-Version: 2.1 +Name: S5 +Version: 0.1 +Summary: Simplified State Space Models for Sequence Modeling. +Home-page: UNKNOWN +Author: J.T.H. Smith, A. Warrington, S. Linderman. +Author-email: jsmith14@stanford.edu +License: UNKNOWN +Platform: UNKNOWN +License-File: LICENSE + +UNKNOWN + diff --git a/Essay_classifier/S5.egg-info/SOURCES.txt b/Essay_classifier/S5.egg-info/SOURCES.txt new file mode 100644 index 0000000000000000000000000000000000000000..6f54e68780e1a1648602ad2a7a4d7013047c72b0 --- /dev/null +++ b/Essay_classifier/S5.egg-info/SOURCES.txt @@ -0,0 +1,7 @@ +LICENSE +README.md +setup.py +S5.egg-info/PKG-INFO +S5.egg-info/SOURCES.txt +S5.egg-info/dependency_links.txt +S5.egg-info/top_level.txt \ No newline at end of file diff --git a/Essay_classifier/S5.egg-info/dependency_links.txt b/Essay_classifier/S5.egg-info/dependency_links.txt new file mode 100644 index 0000000000000000000000000000000000000000..8b137891791fe96927ad78e64b0aad7bded08bdc --- /dev/null +++ b/Essay_classifier/S5.egg-info/dependency_links.txt @@ -0,0 +1 @@ + diff --git a/Essay_classifier/S5.egg-info/top_level.txt b/Essay_classifier/S5.egg-info/top_level.txt new file mode 100644 index 0000000000000000000000000000000000000000..8b137891791fe96927ad78e64b0aad7bded08bdc --- /dev/null +++ b/Essay_classifier/S5.egg-info/top_level.txt @@ -0,0 +1 @@ + diff --git a/Essay_classifier/bin/download_aan.sh b/Essay_classifier/bin/download_aan.sh new file mode 100644 index 0000000000000000000000000000000000000000..840acf9e59d69c7a47a69a41abd63f0440fd20d1 --- /dev/null +++ b/Essay_classifier/bin/download_aan.sh @@ -0,0 +1,4 @@ +mkdir raw_datasets + +# Download the raw AAN data from the TutorialBank Corpus. +wget -v https://github.com/Yale-LILY/TutorialBank/blob/master/resources-v2022-clean.tsv -P ./raw_datasets diff --git a/Essay_classifier/bin/download_all.sh b/Essay_classifier/bin/download_all.sh new file mode 100644 index 0000000000000000000000000000000000000000..6224e06011f425112971d714e028aca802bd1fe1 --- /dev/null +++ b/Essay_classifier/bin/download_all.sh @@ -0,0 +1,8 @@ +# Make a directory to dump the raw data into. 
+rm -rf ./raw_datasets +mkdir ./raw_datasets + +./bin/download_lra.sh +./bin/download_aan.sh +./bin/download_sc35.sh + diff --git a/Essay_classifier/bin/download_lra.sh b/Essay_classifier/bin/download_lra.sh new file mode 100644 index 0000000000000000000000000000000000000000..e3691c514b09818d11d4967775ee570af6327af7 --- /dev/null +++ b/Essay_classifier/bin/download_lra.sh @@ -0,0 +1,9 @@ +mkdir raw_datasets + +# Clone and unpack the LRA object. +# This can take a long time, so get comfortable. +rm -rf ./raw_datasets/lra_release.gz ./raw_datasets/lra_release # Clean out any old datasets. +wget -v https://storage.googleapis.com/long-range-arena/lra_release.gz -P ./raw_datasets + +# Add a progress bar because this can be slow. +pv ./raw_datasets/lra_release.gz | tar -zx -C ./raw_datasets/ diff --git a/Essay_classifier/bin/download_sc35.sh b/Essay_classifier/bin/download_sc35.sh new file mode 100644 index 0000000000000000000000000000000000000000..f3135a9659e0f84ed89dd6550bfc1e5302805ed0 --- /dev/null +++ b/Essay_classifier/bin/download_sc35.sh @@ -0,0 +1,4 @@ +mkdir raw_datasets + +# Use tfds to download the speech commands dataset. +python ./bin/python_scripts/download_sc.py diff --git a/Essay_classifier/bin/python_scripts/download_sc.py b/Essay_classifier/bin/python_scripts/download_sc.py new file mode 100644 index 0000000000000000000000000000000000000000..71695a1623a3e30e3512c85159066860a4115833 --- /dev/null +++ b/Essay_classifier/bin/python_scripts/download_sc.py @@ -0,0 +1,4 @@ +import tensorflow_datasets as tfds +import os +cfg = tfds.download.DownloadConfig(extract_dir=os.getcwd() + '/raw_datasets/') +tfds.load('speech_commands', data_dir='./raw_datasets', download=True, download_and_prepare_kwargs={'download_dir': os.getcwd() + '/raw_datasets/', 'download_config': cfg}) diff --git a/Essay_classifier/bin/run_experiments/run_gpt_classifier.sh b/Essay_classifier/bin/run_experiments/run_gpt_classifier.sh new file mode 100644 index 0000000000000000000000000000000000000000..766e2a76e74fa91f2469922a7cf8c58994611f2a --- /dev/null +++ b/Essay_classifier/bin/run_experiments/run_gpt_classifier.sh @@ -0,0 +1,7 @@ +python run_train.py --C_init=lecun_normal --activation_fn=half_glu2 \ + --batchnorm=True --bidirectional=True --blocks=12 --bsz=8 \ + --d_model=64 --dataset=imdb-classification \ + --dt_global=True --epochs=35 --jax_seed=8825365 --lr_factor=4 \ + --n_layers=6 --opt_config=standard --p_dropout=0.1 --ssm_lr_base=0.001 \ + --ssm_size_base=192 --warmup_end=0 --weight_decay=0.07 \ + --USE_WANDB True --wandb_project awsome_0 --wandb_entity Vodolay diff --git a/Essay_classifier/bin/run_experiments/run_lra_aan.sh b/Essay_classifier/bin/run_experiments/run_lra_aan.sh new file mode 100644 index 0000000000000000000000000000000000000000..a27a7bf758085de00c9baed9fb8fbed0865bdcae --- /dev/null +++ b/Essay_classifier/bin/run_experiments/run_lra_aan.sh @@ -0,0 +1,5 @@ +python run_train.py --C_init=trunc_standard_normal --batchnorm=True --bidirectional=True \ + --blocks=16 --bsz=32 --d_model=128 --dataset=aan-classification \ + --dt_global=True --epochs=20 --jax_seed=5464368 --lr_factor=2 --n_layers=6 \ + --opt_config=standard --p_dropout=0.0 --ssm_lr_base=0.001 --ssm_size_base=256 \ + --warmup_end=1 --weight_decay=0.05 \ No newline at end of file diff --git a/Essay_classifier/bin/run_experiments/run_lra_cifar.sh b/Essay_classifier/bin/run_experiments/run_lra_cifar.sh new file mode 100644 index 0000000000000000000000000000000000000000..6db2d63872b74337aaa40c251f23a3ff21f21682 --- /dev/null +++ 
b/Essay_classifier/bin/run_experiments/run_lra_cifar.sh @@ -0,0 +1,4 @@ +python run_train.py --C_init=lecun_normal --batchnorm=True --bidirectional=True \ + --blocks=3 --bsz=50 --clip_eigs=True --d_model=512 --dataset=lra-cifar-classification \ + --epochs=250 --jax_seed=16416 --lr_factor=4.5 --n_layers=6 --opt_config=BfastandCdecay \ + --p_dropout=0.1 --ssm_lr_base=0.001 --ssm_size_base=384 --warmup_end=1 --weight_decay=0.07 \ No newline at end of file diff --git a/Essay_classifier/bin/run_experiments/run_lra_imdb.sh b/Essay_classifier/bin/run_experiments/run_lra_imdb.sh new file mode 100644 index 0000000000000000000000000000000000000000..766e2a76e74fa91f2469922a7cf8c58994611f2a --- /dev/null +++ b/Essay_classifier/bin/run_experiments/run_lra_imdb.sh @@ -0,0 +1,7 @@ +python run_train.py --C_init=lecun_normal --activation_fn=half_glu2 \ + --batchnorm=True --bidirectional=True --blocks=12 --bsz=8 \ + --d_model=64 --dataset=imdb-classification \ + --dt_global=True --epochs=35 --jax_seed=8825365 --lr_factor=4 \ + --n_layers=6 --opt_config=standard --p_dropout=0.1 --ssm_lr_base=0.001 \ + --ssm_size_base=192 --warmup_end=0 --weight_decay=0.07 \ + --USE_WANDB True --wandb_project awsome_0 --wandb_entity Vodolay diff --git a/Essay_classifier/bin/run_experiments/run_lra_listops.sh b/Essay_classifier/bin/run_experiments/run_lra_listops.sh new file mode 100644 index 0000000000000000000000000000000000000000..3c01a0099192901ee18111ecba7a1044295aa939 --- /dev/null +++ b/Essay_classifier/bin/run_experiments/run_lra_listops.sh @@ -0,0 +1,4 @@ +python run_train.py --C_init=lecun_normal --activation_fn=half_glu2 --batchnorm=True \ + --bidirectional=True --blocks=8 --bsz=50 --d_model=128 --dataset=listops-classification \ + --epochs=40 --jax_seed=6554595 --lr_factor=3 --n_layers=8 --opt_config=BfastandCdecay \ + --p_dropout=0 --ssm_lr_base=0.001 --ssm_size_base=16 --warmup_end=1 --weight_decay=0.04 \ No newline at end of file diff --git a/Essay_classifier/bin/run_experiments/run_lra_pathfinder.sh b/Essay_classifier/bin/run_experiments/run_lra_pathfinder.sh new file mode 100644 index 0000000000000000000000000000000000000000..ddf1bf21fdbb0b6f510b514deb6b8b5f638a82ee --- /dev/null +++ b/Essay_classifier/bin/run_experiments/run_lra_pathfinder.sh @@ -0,0 +1,5 @@ +python run_train.py --C_init=trunc_standard_normal --batchnorm=True --bidirectional=True \ + --blocks=8 --bn_momentum=0.9 --bsz=64 --d_model=192 \ + --dataset=pathfinder-classification --epochs=200 --jax_seed=8180844 --lr_factor=5 \ + --n_layers=6 --opt_config=standard --p_dropout=0.05 --ssm_lr_base=0.0009 \ + --ssm_size_base=256 --warmup_end=1 --weight_decay=0.03 \ No newline at end of file diff --git a/Essay_classifier/bin/run_experiments/run_lra_pathx.sh b/Essay_classifier/bin/run_experiments/run_lra_pathx.sh new file mode 100644 index 0000000000000000000000000000000000000000..970333cdeac2f2caeaf6a3137132698ab2c35164 --- /dev/null +++ b/Essay_classifier/bin/run_experiments/run_lra_pathx.sh @@ -0,0 +1,5 @@ +python run_train.py --C_init=complex_normal --batchnorm=True --bidirectional=True \ + --blocks=16 --bn_momentum=0.9 --bsz=32 --d_model=128 --dataset=pathx-classification \ + --dt_min=0.0001 --epochs=75 --jax_seed=6429262 --lr_factor=3 --n_layers=6 \ + --opt_config=BandCdecay --p_dropout=0.0 --ssm_lr_base=0.0006 --ssm_size_base=256 \ + --warmup_end=1 --weight_decay=0.06 \ No newline at end of file diff --git a/Essay_classifier/bin/run_experiments/run_speech35.sh b/Essay_classifier/bin/run_experiments/run_speech35.sh new file mode 100644 index 
0000000000000000000000000000000000000000..8b9c5a5b31ba5e102cbee26ebe069ac9a523c535 --- /dev/null +++ b/Essay_classifier/bin/run_experiments/run_speech35.sh @@ -0,0 +1,4 @@ +python run_train.py --C_init=lecun_normal --batchnorm=True --bidirectional=True \ + --blocks=16 --bsz=16 --d_model=96 --dataset=speech35-classification \ + --epochs=40 --jax_seed=4062966 --lr_factor=4 --n_layers=6 --opt_config=noBCdecay \ + --p_dropout=0.1 --ssm_lr_base=0.002 --ssm_size_base=128 --warmup_end=1 --weight_decay=0.04 \ No newline at end of file diff --git a/Essay_classifier/docs/figures/pdfs/s3-block-diagram-2.pdf b/Essay_classifier/docs/figures/pdfs/s3-block-diagram-2.pdf new file mode 100644 index 0000000000000000000000000000000000000000..d85f6a9a1161afb2e1ec4f9389f2293d3dec9348 Binary files /dev/null and b/Essay_classifier/docs/figures/pdfs/s3-block-diagram-2.pdf differ diff --git a/Essay_classifier/docs/figures/pdfs/s4-matrix-blocks.pdf b/Essay_classifier/docs/figures/pdfs/s4-matrix-blocks.pdf new file mode 100644 index 0000000000000000000000000000000000000000..d785fa49f14b67c2132de5b7c43174b73a26d3f5 Binary files /dev/null and b/Essay_classifier/docs/figures/pdfs/s4-matrix-blocks.pdf differ diff --git a/Essay_classifier/docs/figures/pdfs/s4-s3-block-diagram-2.pdf b/Essay_classifier/docs/figures/pdfs/s4-s3-block-diagram-2.pdf new file mode 100644 index 0000000000000000000000000000000000000000..b51e5d1555c3b65824cd5af5a6eab63160fe60e8 Binary files /dev/null and b/Essay_classifier/docs/figures/pdfs/s4-s3-block-diagram-2.pdf differ diff --git a/Essay_classifier/docs/figures/pdfs/s5-matrix-blocks.pdf b/Essay_classifier/docs/figures/pdfs/s5-matrix-blocks.pdf new file mode 100644 index 0000000000000000000000000000000000000000..a7c0d3e72dbd059504de529539cb6af06d9f5c5a Binary files /dev/null and b/Essay_classifier/docs/figures/pdfs/s5-matrix-blocks.pdf differ diff --git a/Essay_classifier/docs/figures/pngs/pendulum.png b/Essay_classifier/docs/figures/pngs/pendulum.png new file mode 100644 index 0000000000000000000000000000000000000000..ba7d987923781bb074771d81ce16818e7d4df2d7 Binary files /dev/null and b/Essay_classifier/docs/figures/pngs/pendulum.png differ diff --git a/Essay_classifier/docs/figures/pngs/s3-block-diagram-2.png b/Essay_classifier/docs/figures/pngs/s3-block-diagram-2.png new file mode 100644 index 0000000000000000000000000000000000000000..8e415ab59ae9b6150ec6be87c1438efee7832a10 Binary files /dev/null and b/Essay_classifier/docs/figures/pngs/s3-block-diagram-2.png differ diff --git a/Essay_classifier/docs/figures/pngs/s4-matrix-blocks.png b/Essay_classifier/docs/figures/pngs/s4-matrix-blocks.png new file mode 100644 index 0000000000000000000000000000000000000000..a35d926efcbf002cf18e7eadea926e4a7faa6b65 Binary files /dev/null and b/Essay_classifier/docs/figures/pngs/s4-matrix-blocks.png differ diff --git a/Essay_classifier/docs/figures/pngs/s4-s3-block-diagram-2.png b/Essay_classifier/docs/figures/pngs/s4-s3-block-diagram-2.png new file mode 100644 index 0000000000000000000000000000000000000000..0b594e2419752a0e8b4c3533f00edc589e58ad00 Binary files /dev/null and b/Essay_classifier/docs/figures/pngs/s4-s3-block-diagram-2.png differ diff --git a/Essay_classifier/docs/figures/pngs/s5-matrix-blocks.png b/Essay_classifier/docs/figures/pngs/s5-matrix-blocks.png new file mode 100644 index 0000000000000000000000000000000000000000..d9fb746b499d664bd2326d87ea04066e7a514c64 Binary files /dev/null and b/Essay_classifier/docs/figures/pngs/s5-matrix-blocks.png differ diff --git 
a/Essay_classifier/docs/s5_blog.md b/Essay_classifier/docs/s5_blog.md new file mode 100644 index 0000000000000000000000000000000000000000..6a5458c43eb2f71af665b6ba3b2b0093ea4f3854 --- /dev/null +++ b/Essay_classifier/docs/s5_blog.md @@ -0,0 +1,77 @@ + + + + +# S5: Simplified State Space Layers for Sequence Modeling + +_By [Jimmy Smith](https://icme.stanford.edu/people/jimmy-smith), [Andrew Warrington](https://github.com/andrewwarrington) & [Scott Linderman](https://web.stanford.edu/~swl1/)._ + +_This post accompanies the preprint by Smith et al [2022], available [here](https://arxiv.org/pdf/2208.04933.pdf). Code for the paper is available [here](https://github.com/lindermanlab/S5)_. + + + +## TL;DR. +In our preprint we demonstrate that we can build a state-of-the-art deep sequence-to-sequence model by stacking layers, each built around a single dense, multi-input, multi-output (MIMO) state space model (SSM). This replaces the many single-input, single-output (SISO) SSMs used by the _structured state space sequence_ (S4) model [Gu et al, 2021]. This allows us to use an efficient parallel scan to achieve the same computational efficiency as S4, without the need for frequency-domain and convolutional methods. We show that S5 performs as well as, if not better than, S4 on a range of long-range sequence modeling tasks. + +![](./figures/pngs/s5-matrix-blocks.png) +_Figure 1: Our S5 layer uses a single, dense, multi-input, multi-output state space model as a layer in a deep sequence-to-sequence model._ + + + +## S4 is Epically Good. So... Why? + + +![](./figures/pngs/s4-s3-block-diagram-2.png) +_Figure 2: A schematic of the computations required by S4. \\(H\\) SISO SSMs are applied in the frequency domain, passed through a non-linearity, and then mixed to provide the input to the next layer. Deriving the "Frequency domain convolution kernel generation" (and the required parameterization, indicated in blue) is the primary focus of Gu et al [2021]._ + +The performance of S4 is unarguable. Transformer-based methods were clawing for single percentage point gains on the long range arena benchmark dataset [Tay et al, 2021]. S4 beat many SotA transformer methods by as much as twenty percentage points. AND, to top it off, it could process sequences with complexity linear in the sequence length, and sublinear in parallel time (with a reasonable number of processors). + +However, the original S4 is a very involved method. It requires specific matrix parameterizations, decompositions, mathematical identities, Fourier transforms, and more, as illustrated in [Figure 2](#fig_s4_stack). As a research group, we spent several weeks trying to understand all the intricacies of the method. This left us asking: is there a different way of using the same core concepts, retaining performance and complexity, but, maybe, making it (subjectively, we admit!) simpler? + +Enter S5. + + +## From SISO to MIMO. From Convolution to Parallel Recurrence. + + + +![](./figures/pngs/s4-matrix-blocks.png) +_Figure 3: Our S5 layer uses a single, dense, multi-input, multi-output state space model as a layer in a deep sequence-to-sequence model._ + +--- + +todo + +--- + +## S4 and Its Variants. +Since publishing the original S4 model, its authors have released three further papers studying it. The most significant of these are S4D [Gu, 2022] and DSS [Gupta, 2022]. These papers explore using diagonal state spaces, similar to what we use.
S4D provided a proof as to why the (diagonalizable) normal matrix, from the normal-plus-low-rank factorization of the HiPPO-LegS matrix, provides such a good initialization for SISO systems. We show (although it's really not that difficult!) that this initialization enjoys similar characteristics in the MIMO case. We note, however, that while S4D and DSS provide computationally simpler implementations of S4, they do not perform quite as strongly. Most importantly, though, S5 isn't the only simplification of S4. + + + +## Other Resources. +- Much of our understanding and early code was based on the _excellent_ blog post, _The Annotated S4_, by [Rush and Karamcheti \[2021\]](https://srush.github.io/annotated-s4/). +- Full code for the original S4 implementation, and many of its forerunners and derivatives, is available [here](https://github.com/HazyResearch/state-spaces). +- Instructions for obtaining the LRA dataset are [here](https://openreview.net/pdf?id=qVyeW-grC2k). + + + +## Awesome Other Work. +There are obviously many other great researchers working on adapting, extending, and understanding S4. We outline some very recent work here: + +- Mega, by Ma et al [2022], combines linear state space layers with transformer heads for sequence modeling. The main Mega method has \\(O(L^2)\\) complexity. A second method, Mega-chunk, is presented that has \\(O(L)\\) complexity, but does not achieve the same performance as Mega. Combining SSMs with transformer heads is a great avenue for future work. +- Liquid-S4, by Hasani et al [2022], extends S4 by adding a dependence on the input signal into the state matrix. When expanded, this is equivalent to adding cross-terms between the \\(k^{th}\\) input and all previous inputs. Evaluating all previous terms is intractable, and so this sequence is often truncated. Extending the linear SSM so that it is conditionally linear is a really exciting opportunity for making linear state space layers more expressive. +- ADD "what makes conv great" once it is de-anonymized. + + + +## Bibliography +- Smith, Jimmy T. H., Andrew Warrington, and Scott W. Linderman. "Simplified State Space Layers for Sequence Modeling." arXiv preprint arXiv:2208.04933 (2022). [Link](https://arxiv.org/pdf/2208.04933.pdf). +- Gu, Albert, Karan Goel, and Christopher Re. "Efficiently Modeling Long Sequences with Structured State Spaces." International Conference on Learning Representations (2021). [Link](https://openreview.net/pdf?id=uYLFoz1vlAC). +- Rush, Sasha, and Sidd Karamcheti. "The Annotated S4." Blog Track at ICLR 2022 (2022). [Link](https://srush.github.io/annotated-s4/). +- Tay, Yi, et al. "Long Range Arena: A Benchmark for Efficient Transformers." International Conference on Learning Representations (2021). [Link](https://openreview.net/pdf?id=qVyeW-grC2k). +- Ma, Xuezhe, et al. "Mega: Moving Average Equipped Gated Attention." arXiv preprint arXiv:2209.10655 (2022). [Link](https://arxiv.org/pdf/2209.10655). +- Hasani, Ramin, et al. "Liquid Structural State-Space Models." arXiv preprint arXiv:2209.12951 (2022). [Link](https://web10.arxiv.org/pdf/2209.12951.pdf). +- Gu, Albert, et al. "On the Parameterization and Initialization of Diagonal State Space Models." arXiv preprint arXiv:2206.11893 (2022). [Link](https://arxiv.org/abs/2206.11893).
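The diffs that follow add a saved dataset cache under `Essay_classifier/essays/` (`dataset_dict.json`, per-split Arrow shards, `dataset_info.json`, `state.json`). That layout matches what the Hugging Face `datasets` library writes with `DatasetDict.save_to_disk`; assuming that is how it was produced, a minimal sketch of reloading and inspecting it looks like this (the path is taken from the diff, everything else is standard `datasets` API):

```python
# Minimal sketch: reload the saved essay splits with Hugging Face datasets,
# assuming the directory was written via DatasetDict.save_to_disk.
from datasets import load_from_disk

ds = load_from_disk("Essay_classifier/essays")  # returns a DatasetDict
print(ds)                                       # splits: "train" and "test" (per dataset_dict.json)
print(ds["train"].features)                     # "text" (string) and "label" (int64), per dataset_info.json

example = ds["train"][0]
print(example["label"], example["text"][:80])   # peek at one labeled essay
```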
+ + diff --git a/Essay_classifier/essays/dataset_dict.json b/Essay_classifier/essays/dataset_dict.json new file mode 100644 index 0000000000000000000000000000000000000000..31145eaf92f08c428500ba08c6ee6e186851a3ac --- /dev/null +++ b/Essay_classifier/essays/dataset_dict.json @@ -0,0 +1 @@ +{"splits": ["train", "test"]} \ No newline at end of file diff --git a/Essay_classifier/essays/test/data-00000-of-00001.arrow b/Essay_classifier/essays/test/data-00000-of-00001.arrow new file mode 100644 index 0000000000000000000000000000000000000000..cac3bebff5f0eacf56cbce81328bb7aed51c7f06 --- /dev/null +++ b/Essay_classifier/essays/test/data-00000-of-00001.arrow @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:cf303ddc5053fcc1d3d8ad2e8a885dbf12d31960c1e097a4979980362c4c94ea +size 470136 diff --git a/Essay_classifier/essays/test/dataset_info.json b/Essay_classifier/essays/test/dataset_info.json new file mode 100644 index 0000000000000000000000000000000000000000..9db8efda3c90ef76c70e46a23d428d89b7cfa75e --- /dev/null +++ b/Essay_classifier/essays/test/dataset_info.json @@ -0,0 +1,20 @@ +{ + "citation": "", + "description": "", + "features": { + "text": { + "dtype": "string", + "_type": "Value" + }, + "label": { + "dtype": "int64", + "_type": "Value" + }, + "__index_level_0__": { + "dtype": "int64", + "_type": "Value" + } + }, + "homepage": "", + "license": "" +} \ No newline at end of file diff --git a/Essay_classifier/essays/test/state.json b/Essay_classifier/essays/test/state.json new file mode 100644 index 0000000000000000000000000000000000000000..42ba7bd763f87bd157e0012614bc18ad8a19e279 --- /dev/null +++ b/Essay_classifier/essays/test/state.json @@ -0,0 +1,13 @@ +{ + "_data_files": [ + { + "filename": "data-00000-of-00001.arrow" + } + ], + "_fingerprint": "cf3f779c3519cf1d", + "_format_columns": null, + "_format_kwargs": {}, + "_format_type": null, + "_output_all_columns": false, + "_split": null +} \ No newline at end of file diff --git a/Essay_classifier/essays/train/data-00000-of-00001.arrow b/Essay_classifier/essays/train/data-00000-of-00001.arrow new file mode 100644 index 0000000000000000000000000000000000000000..5083d817d768f6e5501a15cae48aae4c75775cd7 --- /dev/null +++ b/Essay_classifier/essays/train/data-00000-of-00001.arrow @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:a92f80195891f57e75536c3660c1a67e36e044a6079678742eaa8f7f72711411 +size 1943280 diff --git a/Essay_classifier/essays/train/dataset_info.json b/Essay_classifier/essays/train/dataset_info.json new file mode 100644 index 0000000000000000000000000000000000000000..9db8efda3c90ef76c70e46a23d428d89b7cfa75e --- /dev/null +++ b/Essay_classifier/essays/train/dataset_info.json @@ -0,0 +1,20 @@ +{ + "citation": "", + "description": "", + "features": { + "text": { + "dtype": "string", + "_type": "Value" + }, + "label": { + "dtype": "int64", + "_type": "Value" + }, + "__index_level_0__": { + "dtype": "int64", + "_type": "Value" + } + }, + "homepage": "", + "license": "" +} \ No newline at end of file diff --git a/Essay_classifier/essays/train/state.json b/Essay_classifier/essays/train/state.json new file mode 100644 index 0000000000000000000000000000000000000000..3f764513d47664c8720aac9e93e6c32aff13a2b5 --- /dev/null +++ b/Essay_classifier/essays/train/state.json @@ -0,0 +1,13 @@ +{ + "_data_files": [ + { + "filename": "data-00000-of-00001.arrow" + } + ], + "_fingerprint": "c8e09f6301a80e82", + "_format_columns": null, + "_format_kwargs": {}, + "_format_type": null, + "_output_all_columns": 
false, + "_split": null +} \ No newline at end of file diff --git a/Essay_classifier/requirements_cpu.txt b/Essay_classifier/requirements_cpu.txt new file mode 100644 index 0000000000000000000000000000000000000000..acf288e15f74c1889385a57c72fb9aae3b84c7a3 --- /dev/null +++ b/Essay_classifier/requirements_cpu.txt @@ -0,0 +1,9 @@ +flax==0.5.2 +torch==1.11.0 +torchtext==0.12.0 +tensorflow-datasets==4.5.2 +pydub==0.25.1 +datasets==2.4.0 +tqdm==4.62.3 +jaxlib==0.3.5 +jax==0.3.5 diff --git a/Essay_classifier/requirements_gpu.txt b/Essay_classifier/requirements_gpu.txt new file mode 100644 index 0000000000000000000000000000000000000000..51a5e98e7691fc45da6459a65b30d563b6a9b23b --- /dev/null +++ b/Essay_classifier/requirements_gpu.txt @@ -0,0 +1,9 @@ +flax +torch +torchtext +tensorflow-datasets==4.5.2 +pydub==0.25.1 +datasets +tqdm +--find-links https://storage.googleapis.com/jax-releases/jax_releases.html +jax[cuda]>=version diff --git a/Essay_classifier/run_train.py b/Essay_classifier/run_train.py new file mode 100644 index 0000000000000000000000000000000000000000..328e1ae479f0454d3ccc04c9dedc3d709972d26f --- /dev/null +++ b/Essay_classifier/run_train.py @@ -0,0 +1,101 @@ +import argparse +from s5.utils.util import str2bool +from s5.train import train +from s5.dataloading import Datasets + +if __name__ == "__main__": + + parser = argparse.ArgumentParser() + + parser.add_argument("--USE_WANDB", type=str2bool, default=False, + help="log with wandb?") + parser.add_argument("--wandb_project", type=str, default=None, + help="wandb project name") + parser.add_argument("--wandb_entity", type=str, default=None, + help="wandb entity name, e.g. username") + parser.add_argument("--dir_name", type=str, default='./cache_dir', + help="name of directory where data is cached") + parser.add_argument("--dataset", type=str, choices=Datasets.keys(), + default='mnist-classification', + help="dataset name") + + # Model Parameters + parser.add_argument("--n_layers", type=int, default=6, + help="Number of layers in the network") + parser.add_argument("--d_model", type=int, default=128, + help="Number of features, i.e. H, " + "dimension of layer inputs/outputs") + parser.add_argument("--ssm_size_base", type=int, default=256, + help="SSM Latent size, i.e. P") + parser.add_argument("--blocks", type=int, default=8, + help="How many blocks, J, to initialize with") + parser.add_argument("--C_init", type=str, default="trunc_standard_normal", + choices=["trunc_standard_normal", "lecun_normal", "complex_normal"], + help="Options for initialization of C: \\" + "trunc_standard_normal: sample from trunc. std. 
normal then multiply by V \\ " \ + "lecun_normal sample from lecun normal, then multiply by V\\ " \ + "complex_normal: sample directly from complex standard normal") + parser.add_argument("--discretization", type=str, default="zoh", choices=["zoh", "bilinear"]) + parser.add_argument("--mode", type=str, default="pool", choices=["pool", "last"], + help="options: (for classification tasks) \\" \ + " pool: mean pooling \\" \ + "last: take last element") + parser.add_argument("--activation_fn", default="half_glu1", type=str, + choices=["full_glu", "half_glu1", "half_glu2", "gelu"]) + parser.add_argument("--conj_sym", type=str2bool, default=True, + help="whether to enforce conjugate symmetry") + parser.add_argument("--clip_eigs", type=str2bool, default=False, + help="whether to enforce the left-half plane condition") + parser.add_argument("--bidirectional", type=str2bool, default=False, + help="whether to use bidirectional model") + parser.add_argument("--dt_min", type=float, default=0.001, + help="min value to sample initial timescale params from") + parser.add_argument("--dt_max", type=float, default=0.1, + help="max value to sample initial timescale params from") + + # Optimization Parameters + parser.add_argument("--prenorm", type=str2bool, default=True, + help="True: use prenorm, False: use postnorm") + parser.add_argument("--batchnorm", type=str2bool, default=True, + help="True: use batchnorm, False: use layernorm") + parser.add_argument("--bn_momentum", type=float, default=0.95, + help="batchnorm momentum") + parser.add_argument("--bsz", type=int, default=64, + help="batch size") + parser.add_argument("--epochs", type=int, default=100, + help="max number of epochs") + parser.add_argument("--early_stop_patience", type=int, default=1000, + help="number of epochs to continue training when val loss plateaus") + parser.add_argument("--ssm_lr_base", type=float, default=1e-3, + help="initial ssm learning rate") + parser.add_argument("--lr_factor", type=float, default=1, + help="global learning rate = lr_factor*ssm_lr_base") + parser.add_argument("--dt_global", type=str2bool, default=False, + help="Treat timescale parameter as global parameter or SSM parameter") + parser.add_argument("--lr_min", type=float, default=0, + help="minimum learning rate") + parser.add_argument("--cosine_anneal", type=str2bool, default=True, + help="whether to use cosine annealing schedule") + parser.add_argument("--warmup_end", type=int, default=1, + help="epoch to end linear warmup") + parser.add_argument("--lr_patience", type=int, default=1000000, + help="patience before decaying learning rate for lr_decay_on_val_plateau") + parser.add_argument("--reduce_factor", type=float, default=1.0, + help="factor to decay learning rate for lr_decay_on_val_plateau") + parser.add_argument("--p_dropout", type=float, default=0.0, + help="probability of dropout") + parser.add_argument("--weight_decay", type=float, default=0.05, + help="weight decay value") + parser.add_argument("--opt_config", type=str, default="standard", choices=['standard', + 'BandCdecay', + 'BfastandCdecay', + 'noBCdecay'], + help="Opt configurations: \\ " \ + "standard: no weight decay on B (ssm lr), weight decay on C (global lr) \\" \ + "BandCdecay: weight decay on B (ssm lr), weight decay on C (global lr) \\" \ + "BfastandCdecay: weight decay on B (global lr), weight decay on C (global lr) \\" \ + "noBCdecay: no weight decay on B (ssm lr), no weight decay on C (ssm lr) \\") + parser.add_argument("--jax_seed", type=int, default=1919, + help="seed randomness") 
+ + train(parser.parse_args()) diff --git a/Essay_classifier/s5/__init__.py b/Essay_classifier/s5/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/Essay_classifier/s5/dataloaders/README.md b/Essay_classifier/s5/dataloaders/README.md new file mode 100644 index 0000000000000000000000000000000000000000..55545a61aa8675656f7c53227cee60f6557b0fe8 --- /dev/null +++ b/Essay_classifier/s5/dataloaders/README.md @@ -0,0 +1,8 @@ +# Data & Dataloaders +The scripts in this directory deal with downloading, preparing and caching datasets, as well as building dataloaders from (preferably) a cache +or downloading the data directly. The scripts in this directory are **HEAVILY** based on the scripts in the original S4 repository, +but have been modified to remove them from the PyTorch Lightning ecosystem. + +These files were originally distributed under the Apache 2.0 license, (c) Albert Gu. The original copyright therefore remains with the original +authors, but we modify and distribute under the permissions of the license. Warranty, trademarking and liability are also therefore not allowed. + diff --git a/Essay_classifier/s5/dataloaders/__init__.py b/Essay_classifier/s5/dataloaders/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..f5857ca8bdcd618404894e55db024b5ec4b02246 --- /dev/null +++ b/Essay_classifier/s5/dataloaders/__init__.py @@ -0,0 +1,2 @@ +from . import audio, basic +from .base import SequenceDataset \ No newline at end of file diff --git a/Essay_classifier/s5/dataloaders/audio.py b/Essay_classifier/s5/dataloaders/audio.py new file mode 100644 index 0000000000000000000000000000000000000000..43395a120aa43c033155397f73564f32e69eb451 --- /dev/null +++ b/Essay_classifier/s5/dataloaders/audio.py @@ -0,0 +1,1005 @@ +"""Audio datasets and utilities.""" +import os +from os import listdir +from os.path import join + +import torch +import torchaudio +from torch import nn +from torch.nn import functional as F + +from .base import default_data_path, SequenceDataset, deprecated + + +def minmax_scale(tensor, range_min=0, range_max=1): + """ + Min-max scaling to [0, 1]. + """ + min_val = torch.amin(tensor, dim=(1, 2), keepdim=True) + max_val = torch.amax(tensor, dim=(1, 2), keepdim=True) + return range_min + (range_max - range_min) * (tensor - min_val) / (max_val - min_val + 1e-6) + +def quantize(samples, bits=8, epsilon=0.01): + """ + Linearly quantize a signal in [0, 1] to a signal in [0, q_levels - 1]. + """ + q_levels = 1 << bits + samples *= q_levels - epsilon + samples += epsilon / 2 + return samples.long() + +def dequantize(samples, bits=8): + """ + Dequantize a signal in [0, q_levels - 1]. + """ + q_levels = 1 << bits + return samples.float() / (q_levels / 2) - 1 + +def mu_law_encode(audio, bits=8): + """ + Perform mu-law companding transformation. + """ + mu = torch.tensor((1 << bits) - 1) + + # Audio must be min-max scaled between -1 and 1 + audio = minmax_scale(audio, range_min=-1, range_max=1) + + # Perform mu-law companding transformation. + numerator = torch.log1p(mu * torch.abs(audio + 1e-8)) + denominator = torch.log1p(mu) + encoded = torch.sign(audio) * (numerator / denominator) + + # Shift signal to [0, 1] + encoded = (encoded + 1) / 2 + + # Quantize signal to the specified number of levels. + return quantize(encoded, bits=bits) + +def mu_law_decode(encoded, bits=8): + """ + Perform inverse mu-law transformation. 
+ """ + mu = (1 << bits) - 1 + # Invert the quantization + x = dequantize(encoded, bits=bits) + + # Invert the mu-law transformation + x = torch.sign(x) * ((1 + mu)**(torch.abs(x)) - 1) / mu + + # Returned values in range [-1, 1] + return x + +def linear_encode(samples, bits=8): + """ + Perform scaling and linear quantization. + """ + samples = samples.clone() + samples = minmax_scale(samples) + return quantize(samples, bits=bits) + +def linear_decode(samples, bits=8): + """ + Invert the linear quantization. + """ + return dequantize(samples, bits=bits) + +def q_zero(bits=8): + """ + The quantized level of the 0.0 value. + """ + return 1 << (bits - 1) + + +class AbstractAudioDataset(torch.utils.data.Dataset): + + def __init__( + self, + bits=8, + sample_len=None, + quantization='linear', + return_type='autoregressive', + drop_last=True, + target_sr=None, + context_len=None, + pad_len=None, + **kwargs, + ) -> None: + super().__init__() + + self.bits = bits + self.sample_len = sample_len + self.quantization = quantization + self.return_type = return_type + self.drop_last = drop_last + self.target_sr = target_sr + self.zero = q_zero(bits) + self.context_len = context_len + self.pad_len = pad_len + + for key, value in kwargs.items(): + setattr(self, key, value) + + self.file_names = NotImplementedError("Must be assigned in setup().") + self.transforms = {} + + self.setup() + self.create_quantizer(self.quantization) + self.create_examples(self.sample_len) + + + def setup(self): + return NotImplementedError("Must assign a list of filepaths to self.file_names.") + + def __getitem__(self, index): + # Load signal + if self.sample_len is not None: + file_name, start_frame, num_frames = self.examples[index] + seq, sr = torchaudio.load(file_name, frame_offset=start_frame, num_frames=num_frames) + else: + seq, sr = torchaudio.load(self.examples[index]) + + # Average non-mono signals across channels + if seq.shape[0] > 1: + seq = seq.mean(dim=0, keepdim=True) + + # Resample signal if required + if self.target_sr is not None and sr != self.target_sr: + if sr not in self.transforms: + self.transforms[sr] = torchaudio.transforms.Resample(orig_freq=sr, new_freq=self.target_sr) + seq = self.transforms[sr](seq) + + # Transpose the signal to get (L, 1) + seq = seq.transpose(0, 1) + + # Unsqueeze to (1, L, 1) + seq = seq.unsqueeze(0) + + # Quantized signal + qseq = self.quantizer(seq, self.bits) + + # Squeeze back to (L, 1) + qseq = qseq.squeeze(0) + + # Return the signal + if self.return_type == 'autoregressive': + # Autoregressive training + # x is [0, qseq[0], qseq[1], ..., qseq[-2]] + # y is [qseq[0], qseq[1], ..., qseq[-1]] + y = qseq + x = torch.roll(qseq, 1, 0) # Roll the signal 1 step + x[0] = self.zero # Fill the first element with q_0 + x = x.squeeze(1) # Squeeze to (L, ) + if self.context_len is not None: + y = y[self.context_len:] # Trim the signal + if self.pad_len is not None: + x = torch.cat((torch.zeros(self.pad_len, dtype=self.qtype) + self.zero, x)) # Pad the signal + return x, y + elif self.return_type is None: + return qseq + else: + raise NotImplementedError(f'Invalid return type {self.return_type}') + + def __len__(self): + return len(self.examples) + + def create_examples(self, sample_len: int): + # Get metadata for all files + self.metadata = [ + torchaudio.info(file_name) for file_name in self.file_names + ] + + if sample_len is not None: + # Reorganize files into a flat list of (file_name, start_frame) pairs + # so that consecutive items are separated by sample_len + self.examples = [] 
+ for file_name, metadata in zip(self.file_names, self.metadata): + # Update the sample_len if resampling to target_sr is required + # This is because the resampling will change the length of the signal + # so we need to adjust the sample_len accordingly (e.g. if downsampling + # the sample_len will need to be increased) + sample_len_i = sample_len + if self.target_sr is not None and metadata.sample_rate != self.target_sr: + sample_len_i = int(sample_len * metadata.sample_rate / self.target_sr) + + margin = metadata.num_frames % sample_len_i + for start_frame in range(0, metadata.num_frames - margin, sample_len_i): + self.examples.append((file_name, start_frame, sample_len_i)) + + if margin > 0 and not self.drop_last: + # Last (leftover) example is shorter than sample_len, and equal to the margin + # (must be padded in collate_fn) + self.examples.append((file_name, metadata.num_frames - margin, margin)) + else: + self.examples = self.file_names + + def create_quantizer(self, quantization: str): + if quantization == 'linear': + self.quantizer = linear_encode + self.dequantizer = linear_decode + self.qtype = torch.long + elif quantization == 'mu-law': + self.quantizer = mu_law_encode + self.dequantizer = mu_law_decode + self.qtype = torch.long + elif quantization is None: + self.quantizer = lambda x, bits: x + self.dequantizer = lambda x, bits: x + self.qtype = torch.float + else: + raise ValueError('Invalid quantization type') + +class QuantizedAudioDataset(AbstractAudioDataset): + """ + Adapted from https://github.com/deepsound-project/samplernn-pytorch/blob/master/dataset.py + """ + + def __init__( + self, + path, + bits=8, + ratio_min=0, + ratio_max=1, + sample_len=None, + quantization='linear', # [linear, mu-law] + return_type='autoregressive', # [autoregressive, None] + drop_last=False, + target_sr=None, + context_len=None, + pad_len=None, + **kwargs, + ): + super().__init__( + bits=bits, + sample_len=sample_len, + quantization=quantization, + return_type=return_type, + drop_last=drop_last, + target_sr=target_sr, + path=path, + ratio_min=ratio_min, + ratio_max=ratio_max, + context_len=context_len, + pad_len=pad_len, + **kwargs, + ) + + def setup(self): + from natsort import natsorted + file_names = natsorted( + [join(self.path, file_name) for file_name in listdir(self.path)] + ) + self.file_names = file_names[ + int(self.ratio_min * len(file_names)) : int(self.ratio_max * len(file_names)) + ] + +class QuantizedAutoregressiveAudio(SequenceDataset): + _name_ = 'qautoaudio' + + @property + def d_input(self): + return 1 + + @property + def d_output(self): + return 1 << self.bits + + @property + def l_output(self): + return self.sample_len + + @property + def n_tokens(self): + return 1 << self.bits + + @property + def init_defaults(self): + return { + 'path': None, + 'bits': 8, + 'sample_len': None, + 'train_percentage': 0.88, + 'quantization': 'linear', + 'drop_last': False, + 'context_len': None, + 'pad_len': None, + } + + def setup(self): + from src.dataloaders_2.audio import QuantizedAudioDataset + assert self.path is not None or self.data_dir is not None, "Pass a path to a folder of audio: either `data_dir` for full directory or `path` for relative path." 
+ if self.data_dir is None: + self.data_dir = default_data_path / self.path + + self.dataset_train = QuantizedAudioDataset( + path=self.data_dir, + bits=self.bits, + ratio_min=0, + ratio_max=self.train_percentage, + sample_len=self.sample_len, + quantization=self.quantization, + drop_last=self.drop_last, + context_len=self.context_len, + pad_len=self.pad_len, + ) + + self.dataset_val = QuantizedAudioDataset( + path=self.data_dir, + bits=self.bits, + ratio_min=self.train_percentage, + ratio_max=self.train_percentage + (1 - self.train_percentage) / 2, + sample_len=self.sample_len, + quantization=self.quantization, + drop_last=self.drop_last, + context_len=self.context_len, + pad_len=self.pad_len, + ) + + self.dataset_test = QuantizedAudioDataset( + path=self.data_dir, + bits=self.bits, + ratio_min=self.train_percentage + (1 - self.train_percentage) / 2, + ratio_max=1, + sample_len=self.sample_len, + quantization=self.quantization, + drop_last=self.drop_last, + context_len=self.context_len, + pad_len=self.pad_len, + ) + + def collate_fn(batch): + x, y, *z = zip(*batch) + assert len(z) == 0 + lengths = torch.tensor([len(e) for e in x]) + max_length = lengths.max() + if self.pad_len is None: + pad_length = int(min(2**max_length.log2().ceil(), self.sample_len) - max_length) + else: + pad_length = int(min(2**max_length.log2().ceil(), self.sample_len + self.pad_len) - max_length) + x = nn.utils.rnn.pad_sequence( + x, + padding_value=self.dataset_train.zero, + batch_first=True, + ) + x = F.pad(x, (0, pad_length), value=self.dataset_train.zero) + y = nn.utils.rnn.pad_sequence( + y, + padding_value=-100, # pad with -100 to ignore these locations in cross-entropy loss + batch_first=True, + ) + return x, y, {"lengths": lengths} + + if not self.drop_last: + self._collate_fn = collate_fn # TODO not tested + +class SpeechCommands09(AbstractAudioDataset): + + CLASSES = [ + "zero", + "one", + "two", + "three", + "four", + "five", + "six", + "seven", + "eight", + "nine", + ] + + CLASS_TO_IDX = dict(zip(CLASSES, range(len(CLASSES)))) + + def __init__( + self, + path, + bits=8, + split='train', + sample_len=16000, + quantization='linear', # [linear, mu-law] + return_type='autoregressive', # [autoregressive, None] + drop_last=False, + target_sr=None, + dequantize=False, + pad_len=None, + **kwargs, + ): + super().__init__( + bits=bits, + sample_len=sample_len, + quantization=quantization, + return_type=return_type, + split=split, + drop_last=drop_last, + target_sr=target_sr, + path=path, + dequantize=dequantize, + pad_len=pad_len, + **kwargs, + ) + + def setup(self): + with open(join(self.path, 'validation_list.txt')) as f: + validation_files = set([line.rstrip() for line in f.readlines()]) + + with open(join(self.path, 'testing_list.txt')) as f: + test_files = set([line.rstrip() for line in f.readlines()]) + + # Get all files in the paths named after CLASSES + self.file_names = [] + for class_name in self.CLASSES: + self.file_names += [ + (class_name, file_name) + for file_name in listdir(join(self.path, class_name)) + if file_name.endswith('.wav') + ] + + # Keep files based on the split + if self.split == 'train': + self.file_names = [ + join(self.path, class_name, file_name) + for class_name, file_name in self.file_names + if join(class_name, file_name) not in validation_files + and join(class_name, file_name) not in test_files + ] + elif self.split == 'validation': + self.file_names = [ + join(self.path, class_name, file_name) + for class_name, file_name in self.file_names + if join(class_name, file_name) in 
validation_files + ] + elif self.split == 'test': + self.file_names = [ + join(self.path, class_name, file_name) + for class_name, file_name in self.file_names + if join(class_name, file_name) in test_files + ] + + def __getitem__(self, index): + item = super().__getitem__(index) + x, y, *z = item + if self.dequantize: + x = self.dequantizer(x).unsqueeze(1) + return (x, y, *z) + +class SpeechCommands09Autoregressive(SequenceDataset): + _name_ = 'sc09' + + @property + def d_input(self): + return 1 + + @property + def d_output(self): + return 1 << self.bits + + @property + def l_output(self): + return self.sample_len + + @property + def n_tokens(self): + return 1 << self.bits + + @property + def init_defaults(self): + return { + 'bits': 8, + 'quantization': 'mu-law', + 'dequantize': False, + 'pad_len': None, + } + + def setup(self): + from src.dataloaders_2.audio import SpeechCommands09 + self.data_dir = self.data_dir or default_data_path / self._name_ + + self.dataset_train = SpeechCommands09( + path=self.data_dir, + bits=self.bits, + split='train', + quantization=self.quantization, + dequantize=self.dequantize, + pad_len=self.pad_len, + ) + + self.dataset_val = SpeechCommands09( + path=self.data_dir, + bits=self.bits, + split='validation', + quantization=self.quantization, + dequantize=self.dequantize, + pad_len=self.pad_len, + ) + + self.dataset_test = SpeechCommands09( + path=self.data_dir, + bits=self.bits, + split='test', + quantization=self.quantization, + dequantize=self.dequantize, + pad_len=self.pad_len, + ) + + self.sample_len = self.dataset_train.sample_len + + def _collate_fn(self, batch): + x, y, *z = zip(*batch) + assert len(z) == 0 + lengths = torch.tensor([len(e) for e in x]) + max_length = lengths.max() + if self.pad_len is None: + pad_length = int(min(2**max_length.log2().ceil(), self.sample_len) - max_length) + else: + pad_length = 0 # int(self.sample_len + self.pad_len - max_length) + x = nn.utils.rnn.pad_sequence( + x, + padding_value=self.dataset_train.zero if not self.dequantize else 0., + batch_first=True, + ) + x = F.pad(x, (0, pad_length), value=self.dataset_train.zero if not self.dequantize else 0.) 
+ y = nn.utils.rnn.pad_sequence( + y, + padding_value=-100, # pad with -100 to ignore these locations in cross-entropy loss + batch_first=True, + ) + return x, y, {"lengths": lengths} + +class MaestroDataset(AbstractAudioDataset): + + YEARS = [2004, 2006, 2008, 2009, 2011, 2013, 2014, 2015, 2017, 2018] + SPLITS = ['train', 'validation', 'test'] + + def __init__( + self, + path, + bits=8, + split='train', + sample_len=None, + quantization='linear', + return_type='autoregressive', + drop_last=False, + target_sr=16000, + ): + super().__init__( + bits=bits, + sample_len=sample_len, + quantization=quantization, + return_type=return_type, + split=split, + path=path, + drop_last=drop_last, + target_sr=target_sr, + ) + + def setup(self): + import pandas as pd + from natsort import natsorted + + self.path = str(self.path) + + # Pull out examples in the specified split + df = pd.read_csv(self.path + '/maestro-v3.0.0.csv') + df = df[df['split'] == self.split] + + file_names = [] + for filename in df['audio_filename'].values: + filepath = os.path.join(self.path, filename) + assert os.path.exists(filepath) + file_names.append(filepath) + self.file_names = natsorted(file_names) + +class MaestroAutoregressive(SequenceDataset): + _name_ = 'maestro' + + @property + def d_input(self): + return 1 + + @property + def d_output(self): + return 1 << self.bits + + @property + def l_output(self): + return self.sample_len + + @property + def n_tokens(self): + return 1 << self.bits + + @property + def init_defaults(self): + return { + 'bits': 8, + 'sample_len': None, + 'quantization': 'mu-law', + } + + def setup(self): + from src.dataloaders_2.audio import MaestroDataset + self.data_dir = self.data_dir or default_data_path / self._name_ / 'maestro-v3.0.0' + + self.dataset_train = MaestroDataset( + path=self.data_dir, + bits=self.bits, + split='train', + sample_len=self.sample_len, + quantization=self.quantization, + ) + + self.dataset_val = MaestroDataset( + path=self.data_dir, + bits=self.bits, + split='validation', + sample_len=self.sample_len, + quantization=self.quantization, + ) + + self.dataset_test = MaestroDataset( + path=self.data_dir, + bits=self.bits, + split='test', + sample_len=self.sample_len, + quantization=self.quantization, + ) + + def _collate_fn(self, batch): + x, y, *z = zip(*batch) + assert len(z) == 0 + lengths = torch.tensor([len(e) for e in x]) + max_length = lengths.max() + pad_length = int(min(max(1024, 2**max_length.log2().ceil()), self.sample_len) - max_length) + x = nn.utils.rnn.pad_sequence( + x, + padding_value=self.dataset_train.zero, + batch_first=True, + ) + x = F.pad(x, (0, pad_length), value=self.dataset_train.zero) + y = nn.utils.rnn.pad_sequence( + y, + padding_value=self.dataset_train.zero, + batch_first=True, + ) + return x, y, {"lengths": lengths} + +class LJSpeech(QuantizedAudioDataset): + + def __init__( + self, + path, + bits=8, + ratio_min=0, + ratio_max=1, + sample_len=None, + quantization='linear', # [linear, mu-law] + return_type='autoregressive', # [autoregressive, None] + drop_last=False, + target_sr=None, + use_text=False, + ): + super().__init__( + bits=bits, + sample_len=sample_len, + quantization=quantization, + return_type=return_type, + drop_last=drop_last, + target_sr=target_sr, + path=path, + ratio_min=ratio_min, + ratio_max=ratio_max, + use_text=use_text, + ) + + def setup(self): + import pandas as pd + from sklearn.preprocessing import LabelEncoder + super().setup() + + self.vocab_size = None + if self.use_text: + self.transcripts = {} + with 
open(str(self.path.parents[0] / 'metadata.csv'), 'r') as f: + for line in f: + index, raw_transcript, normalized_transcript = line.rstrip('\n').split("|") + self.transcripts[index] = normalized_transcript + # df = pd.read_csv(self.path.parents[0] / 'metadata.csv', sep="|", header=None) + # self.transcripts = dict(zip(df[0], df[2])) # use normalized transcripts + + self.tok_transcripts = {} + self.vocab = set() + for file_name in self.file_names: + # Very simple tokenization, character by character + # Capitalization is ignored for simplicity + file_name = file_name.split('/')[-1].split('.')[0] + self.tok_transcripts[file_name] = list(self.transcripts[file_name].lower()) + self.vocab.update(self.tok_transcripts[file_name]) + + # Fit a label encoder mapping characters to numbers + self.label_encoder = LabelEncoder() + self.label_encoder.fit(list(self.vocab)) + # add a token for padding, no additional token for UNK (our dev/test set contain no unseen characters) + self.vocab_size = len(self.vocab) + 1 + + # Finalize the tokenized transcripts + for file_name in self.file_names: + file_name = file_name.split('/')[-1].split('.')[0] + self.tok_transcripts[file_name] = torch.tensor(self.label_encoder.transform(self.tok_transcripts[file_name])) + + + def __getitem__(self, index): + item = super().__getitem__(index) + if self.use_text: + file_name, _, _ = self.examples[index] + tok_transcript = self.tok_transcripts[file_name.split('/')[-1].split('.')[0]] + return (*item, tok_transcript) + return item + +class LJSpeechAutoregressive(SequenceDataset): + _name_ = 'ljspeech' + + @property + def d_input(self): + return 1 + + @property + def d_output(self): + return 1 << self.bits + + @property + def l_output(self): + return self.sample_len + + @property + def n_tokens(self): + return 1 << self.bits + + @property + def init_defaults(self): + return { + 'bits': 8, + 'sample_len': None, + 'quantization': 'mu-law', + 'train_percentage': 0.88, + 'use_text': False, + } + + def setup(self): + from src.dataloaders_2.audio import LJSpeech + self.data_dir = self.data_dir or default_data_path / self._name_ / 'LJSpeech-1.1' / 'wavs' + + self.dataset_train = LJSpeech( + path=self.data_dir, + bits=self.bits, + ratio_min=0, + ratio_max=self.train_percentage, + sample_len=self.sample_len, + quantization=self.quantization, + target_sr=16000, + use_text=self.use_text, + ) + + self.dataset_val = LJSpeech( + path=self.data_dir, + bits=self.bits, + ratio_min=self.train_percentage, + ratio_max=self.train_percentage + (1 - self.train_percentage) / 2, + sample_len=self.sample_len, + quantization=self.quantization, + target_sr=16000, + use_text=self.use_text, + ) + + self.dataset_test = LJSpeech( + path=self.data_dir, + bits=self.bits, + ratio_min=self.train_percentage + (1 - self.train_percentage) / 2, + ratio_max=1, + sample_len=self.sample_len, + quantization=self.quantization, + target_sr=16000, + use_text=self.use_text, + ) + + self.vocab_size = self.dataset_train.vocab_size + + def _collate_fn(self, batch): + x, y, *z = zip(*batch) + + if self.use_text: + tokens = z[0] + text_lengths = torch.tensor([len(e) for e in tokens]) + tokens = nn.utils.rnn.pad_sequence( + tokens, + padding_value=self.vocab_size - 1, + batch_first=True, + ) + else: + assert len(z) == 0 + lengths = torch.tensor([len(e) for e in x]) + max_length = lengths.max() + pad_length = int(min(2**max_length.log2().ceil(), self.sample_len) - max_length) + x = nn.utils.rnn.pad_sequence( + x, + padding_value=self.dataset_train.zero, + batch_first=True, + ) + x = 
F.pad(x, (0, pad_length), value=self.dataset_train.zero) + y = nn.utils.rnn.pad_sequence( + y, + padding_value=-100, # pad with -100 to ignore these locations in cross-entropy loss + batch_first=True, + ) + if self.use_text: + return x, y, {"lengths": lengths, "tokens": tokens, "text_lengths": text_lengths} + else: + return x, y, {"lengths": lengths} + +class _SpeechCommands09Classification(SpeechCommands09): + + def __init__( + self, + path, + bits=8, + split='train', + sample_len=16000, + quantization='linear', # [linear, mu-law] + drop_last=False, + target_sr=None, + **kwargs, + ): + super().__init__( + bits=bits, + sample_len=sample_len, + quantization=quantization, + return_type=None, + split=split, + drop_last=drop_last, + target_sr=target_sr, + path=path, + **kwargs, + ) + + def __getitem__(self, index): + x = super().__getitem__(index) + x = mu_law_decode(x) + y = torch.tensor(self.CLASS_TO_IDX[self.file_names[index].split("/")[-2]]) + return x, y + +class SpeechCommands09Classification(SequenceDataset): + _name_ = 'sc09cls' + + @property + def d_input(self): + return 1 + + @property + def d_output(self): + return 10 + + @property + def l_output(self): + return 0 + + @property + def n_tokens(self): + return 1 << self.bits + + @property + def init_defaults(self): + return { + 'bits': 8, + 'quantization': 'mu-law', + } + + def setup(self): + from src.dataloaders_2.audio import _SpeechCommands09Classification + self.data_dir = self.data_dir or default_data_path / 'sc09' + + self.dataset_train = _SpeechCommands09Classification( + path=self.data_dir, + bits=self.bits, + split='train', + quantization=self.quantization, + ) + + self.dataset_val = _SpeechCommands09Classification( + path=self.data_dir, + bits=self.bits, + split='validation', + quantization=self.quantization, + ) + + self.dataset_test = _SpeechCommands09Classification( + path=self.data_dir, + bits=self.bits, + split='test', + quantization=self.quantization, + ) + + self.sample_len = self.dataset_train.sample_len + + def collate_fn(self, batch): + x, y, *z = zip(*batch) + assert len(z) == 0 + lengths = torch.tensor([len(e) for e in x]) + max_length = lengths.max() + pad_length = int(min(2**max_length.log2().ceil(), self.sample_len) - max_length) + x = nn.utils.rnn.pad_sequence( + x, + padding_value=self.dataset_train.zero, + batch_first=True, + ) + x = F.pad(x, (0, pad_length), value=0.)#self.dataset_train.zero) + y = torch.tensor(y) + return x, y, {"lengths": lengths} + +@deprecated +class SpeechCommandsGeneration(SequenceDataset): + _name_ = "scg" + + init_defaults = { + "mfcc": False, + "dropped_rate": 0.0, + "length": 16000, + "all_classes": False, + "discrete_input": False, + } + + @property + def n_tokens(self): + return 256 if self.discrete_input else None + + def init(self): + if self.mfcc: + self.d_input = 20 + self.L = 161 + else: + self.d_input = 1 + self.L = self.length + + if self.dropped_rate > 0.0: + self.d_input += 1 + + self.d_output = 256 + self.l_output = self.length + + def setup(self): + from src.dataloaders_2.datasets.sc import _SpeechCommandsGeneration + + # TODO refactor with data_dir argument + self.dataset_train = _SpeechCommandsGeneration( + partition="train", + length=self.length, # self.L, + mfcc=self.mfcc, + sr=1, + dropped_rate=self.dropped_rate, + path=default_data_path, + all_classes=self.all_classes, + discrete_input=self.discrete_input, + ) + + self.dataset_val = _SpeechCommandsGeneration( + partition="val", + length=self.length, # self.L, + mfcc=self.mfcc, + sr=1, + 
dropped_rate=self.dropped_rate, + path=default_data_path, + all_classes=self.all_classes, + discrete_input=self.discrete_input, + ) + + self.dataset_test = _SpeechCommandsGeneration( + partition="test", + length=self.length, # self.L, + mfcc=self.mfcc, + sr=1, + dropped_rate=self.dropped_rate, + path=default_data_path, + all_classes=self.all_classes, + discrete_input=self.discrete_input, + ) + + @classmethod + def _return_callback(cls, return_value, *args, **kwargs): + x, y, *z = return_value + return (x, y.long(), *z) diff --git a/Essay_classifier/s5/dataloaders/base.py b/Essay_classifier/s5/dataloaders/base.py new file mode 100644 index 0000000000000000000000000000000000000000..a4579dafe261f61548914b520db020d8314a250c --- /dev/null +++ b/Essay_classifier/s5/dataloaders/base.py @@ -0,0 +1,350 @@ +""" Datasets for core experimental results """ +from functools import partial +from pathlib import Path +import torch +import torchaudio.functional as TF +import torchvision +from einops import rearrange + +from ..utils.util import is_list + + +def deprecated(cls_or_func): + def _deprecated(*args, **kwargs): + print(f"{cls_or_func} is deprecated") + return cls_or_func(*args, **kwargs) + return _deprecated + + +# Default data path is environment variable or hippo/data +default_data_path = Path(__file__).parent.parent.parent.absolute() +default_data_path = default_data_path / "raw_data" + + +class DefaultCollateMixin: + """Controls collating in the DataLoader + + The CollateMixin classes instantiate a dataloader by separating collate arguments with the rest of the dataloader arguments. + Instantiations of this class should modify the callback functions as desired, and modify the collate_args list. The class then defines a + _dataloader() method which takes in a DataLoader constructor and arguments, constructs a collate_fn based on the collate_args, and passes the + rest of the arguments into the constructor. + """ + + @classmethod + def _collate_callback(cls, x, *args, **kwargs): + """ + Modify the behavior of the default _collate method. + """ + return x + + _collate_arg_names = [] + + @classmethod + def _return_callback(cls, return_value, *args, **kwargs): + """ + Modify the return value of the collate_fn. + Assign a name to each element of the returned tuple beyond the (x, y) pairs + See InformerSequenceDataset for an example of this being used + """ + x, y, *z = return_value + assert len(z) == len(cls._collate_arg_names), "Specify a name for each auxiliary data item returned by dataset" + return x, y, {k: v for k, v in zip(cls._collate_arg_names, z)} + + @classmethod + def _collate(cls, batch, *args, **kwargs): + # From https://github.com/pyforch/pytorch/blob/master/torch/utils/data/_utils/collate.py + elem = batch[0] + if isinstance(elem, torch.Tensor): + out = None + if torch.utils.data.get_worker_info() is not None: + # If we're in a background process, concatenate directly into a + # shared memory tensor to avoid an extra copy + numel = sum(x.numel() for x in batch) + storage = elem.storage()._new_shared(numel) + out = elem.new(storage) + x = torch.stack(batch, dim=0, out=out) + + # Insert custom functionality into the collate_fn + x = cls._collate_callback(x, *args, **kwargs) + + return x + else: + return torch.tensor(batch) + + @classmethod + def _collate_fn(cls, batch, *args, **kwargs): + """ + Default collate function. 
+ Generally accessed by the dataloader() methods to pass into torch DataLoader + + Arguments: + batch: list of (x, y) pairs + args, kwargs: extra arguments that get passed into the _collate_callback and _return_callback + """ + x, y, *z = zip(*batch) + + x = cls._collate(x, *args, **kwargs) + y = cls._collate(y) + z = [cls._collate(z_) for z_ in z] + + return_value = (x, y, *z) + return cls._return_callback(return_value, *args, **kwargs) + + # List of loader arguments to pass into collate_fn + collate_args = [] + + def _dataloader(self, dataset, **loader_args): + collate_args = {k: loader_args[k] for k in loader_args if k in self.collate_args} + loader_args = {k: loader_args[k] for k in loader_args if k not in self.collate_args} + loader_cls = loader_registry[loader_args.pop("_name_", None)] + return loader_cls( + dataset=dataset, + collate_fn=partial(self._collate_fn, **collate_args), + **loader_args, + ) + + +class SequenceResolutionCollateMixin(DefaultCollateMixin): + """self.collate_fn(resolution) produces a collate function that subsamples elements of the sequence""" + + @classmethod + def _collate_callback(cls, x, resolution=None): + if resolution is None: + pass + elif is_list(resolution): # Resize to first resolution, then apply resampling technique + # Sample to first resolution + x = x.squeeze(-1) # (B, L) + L = x.size(1) + x = x[:, ::resolution[0]] # assume length is first axis after batch + _L = L // resolution[0] + for r in resolution[1:]: + x = TF.resample(x, _L, L//r) + _L = L // r + x = x.unsqueeze(-1) # (B, L, 1) + else: + # Assume x is (B, L_0, L_1, ..., L_k, C) for x.ndim > 2 and (B, L) for x.ndim = 2 + assert x.ndim >= 2 + n_resaxes = max(1, x.ndim - 2) # [AG 22/07/02] this line looks suspicious... are there cases with 2 axes? + # rearrange: b (l_0 res_0) (l_1 res_1) ... (l_k res_k) ... -> res_0 res_1 .. res_k b l_0 l_1 ... + lhs = "b " + " ".join([f"(l{i} res{i})" for i in range(n_resaxes)]) + " ..." + rhs = " ".join([f"res{i}" for i in range(n_resaxes)]) + " b " + " ".join([f"l{i}" for i in range(n_resaxes)]) + " ..." + x = rearrange(x, lhs + " -> " + rhs, **{f'res{i}': resolution for i in range(n_resaxes)}) + x = x[tuple([0] * n_resaxes)] + + return x + + @classmethod + def _return_callback(cls, return_value, resolution=None): + return (*return_value, {"rate": resolution}) + + collate_args = ['resolution'] + + +class ImageResolutionCollateMixin(SequenceResolutionCollateMixin): + """self.collate_fn(resolution, img_size) produces a collate function that resizes inputs to size img_size/resolution""" + + _interpolation = torchvision.transforms.InterpolationMode.BILINEAR + _antialias = True + + @classmethod + def _collate_callback(cls, x, resolution=None, img_size=None, channels_last=True): + if x.ndim < 4: + return super()._collate_callback(x, resolution=resolution) + if img_size is None: + x = super()._collate_callback(x, resolution=resolution) + else: + x = rearrange(x, 'b ... c -> b c ...') if channels_last else x + _size = round(img_size/resolution) + x = torchvision.transforms.functional.resize( + x, + size=[_size, _size], + interpolation=cls._interpolation, + antialias=cls._antialias, + ) + x = rearrange(x, 'b c ... -> b ... 
c') if channels_last else x + return x + + @classmethod + def _return_callback(cls, return_value, resolution=None, img_size=None, channels_last=True): + return (*return_value, {"rate": resolution}) + + collate_args = ['resolution', 'img_size', 'channels_last'] + + +class TBPTTDataLoader(torch.utils.data.DataLoader): + """ + Adapted from https://github.com/deepsound-project/samplernn-pytorch + """ + + def __init__( + self, + dataset, + batch_size, + chunk_len, + overlap_len, + *args, + **kwargs + ): + super().__init__(dataset, batch_size, *args, **kwargs) + assert chunk_len is not None and overlap_len is not None, "TBPTTDataLoader: chunk_len and overlap_len must be specified." + + # Zero padding value, given by the dataset + self.zero = dataset.zero if hasattr(dataset, "zero") else 0 + + # Size of the chunks to be fed into the model + self.chunk_len = chunk_len + + # Keep `overlap_len` from the previous chunk (e.g. SampleRNN requires this) + self.overlap_len = overlap_len + + def __iter__(self): + for batch in super().__iter__(): + x, y, z = batch # (B, L) (B, L, 1) {'lengths': (B,)} + + # Pad with self.overlap_len - 1 zeros + pad = lambda x, val: torch.cat([x.new_zeros((x.shape[0], self.overlap_len - 1, *x.shape[2:])) + val, x], dim=1) + x = pad(x, self.zero) + y = pad(y, 0) + z = { k: pad(v, 0) for k, v in z.items() if v.ndim > 1 } + _, seq_len, *_ = x.shape + + reset = True + + for seq_begin in list(range(self.overlap_len - 1, seq_len, self.chunk_len))[:-1]: + from_index = seq_begin - self.overlap_len + 1 + to_index = seq_begin + self.chunk_len + # TODO: check this + # Ensure divisible by overlap_len + if self.overlap_len > 0: + to_index = min(to_index, seq_len - ((seq_len - self.overlap_len + 1) % self.overlap_len)) + + x_chunk = x[:, from_index:to_index] + if len(y.shape) == 3: + y_chunk = y[:, seq_begin:to_index] + else: + y_chunk = y + z_chunk = {k: v[:, from_index:to_index] for k, v in z.items() if len(v.shape) > 1} + + yield (x_chunk, y_chunk, {**z_chunk, "reset": reset}) + + reset = False + + def __len__(self): + raise NotImplementedError() + + +# class SequenceDataset(LightningDataModule): +# [21-09-10 AG] Subclassing LightningDataModule fails due to trying to access _has_setup_fit. No idea why. So we just +# provide our own class with the same core methods as LightningDataModule (e.g. 
setup) +class SequenceDataset(DefaultCollateMixin): + registry = {} + _name_ = NotImplementedError("Dataset must have shorthand name") + + # Since subclasses do not specify __init__ which is instead handled by this class + # Subclasses can provide a list of default arguments which are automatically registered as attributes + # TODO it might be possible to write this as a @dataclass, but it seems tricky to separate from the other features of this class + # such as the _name_ and d_input/d_output + @property + def init_defaults(self): + return {} + + # https://www.python.org/dev/peps/pep-0487/#subclass-registration + def __init_subclass__(cls, **kwargs): + super().__init_subclass__(**kwargs) + cls.registry[cls._name_] = cls + + def __init__(self, _name_, data_dir=None, **dataset_cfg): + assert _name_ == self._name_ + self.data_dir = Path(data_dir).absolute() if data_dir is not None else None + + # Add all arguments to self + init_args = self.init_defaults.copy() + init_args.update(dataset_cfg) + for k, v in init_args.items(): + setattr(self, k, v) + + # The train, val, test datasets must be set by `setup()` + self.dataset_train = self.dataset_val = self.dataset_test = None + + self.init() + + def init(self): + """Hook called at end of __init__, override this instead of __init__""" + pass + + def setup(self): + """This method should set self.dataset_train, self.dataset_val, and self.dataset_test.""" + raise NotImplementedError + + def split_train_val(self, val_split): + """ + Randomly split self.dataset_train into a new (self.dataset_train, self.dataset_val) pair. + """ + train_len = int(len(self.dataset_train) * (1.0 - val_split)) + self.dataset_train, self.dataset_val = torch.utils.data.random_split( + self.dataset_train, + (train_len, len(self.dataset_train) - train_len), + generator=torch.Generator().manual_seed( + getattr(self, "seed", 42) + ), # PL is supposed to have a way to handle seeds properly, but doesn't seem to work for us + ) + + def train_dataloader(self, **kwargs): + return self._train_dataloader(self.dataset_train, **kwargs) + + def _train_dataloader(self, dataset, **kwargs): + if dataset is None: return + kwargs['shuffle'] = 'sampler' not in kwargs # shuffle cant be True if we have custom sampler + return self._dataloader(dataset, **kwargs) + + def val_dataloader(self, **kwargs): + return self._eval_dataloader(self.dataset_val, **kwargs) + + def test_dataloader(self, **kwargs): + return self._eval_dataloader(self.dataset_test, **kwargs) + + def _eval_dataloader(self, dataset, **kwargs): + if dataset is None: return + # Note that shuffle=False by default + return self._dataloader(dataset, **kwargs) + + def __str__(self): + return self._name_ + + +class ResolutionSequenceDataset(SequenceDataset, SequenceResolutionCollateMixin): + + def _train_dataloader(self, dataset, train_resolution=None, eval_resolutions=None, **kwargs): + if train_resolution is None: train_resolution = [1] + if not is_list(train_resolution): train_resolution = [train_resolution] + assert len(train_resolution) == 1, "Only one train resolution supported for now." 
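# --- Illustrative aside (not part of the diff above) -------------------------
# A minimal sketch of what a scalar `resolution` r means for these loaders:
# the collate callback effectively keeps every r-th step of each sequence
# (see SequenceResolutionCollateMixin earlier in this file).  Names and shapes
# below are hypothetical.
import torch

batch = torch.arange(32.).reshape(2, 16)   # toy batch of shape (B=2, L=16)
resolution = 2
subsampled = batch[:, ::resolution]        # (B, L // resolution) == (2, 8)
assert subsampled.shape == (2, 8)
# Training uses a single resolution (enforced by the assert above), while
# _eval_dataloader below builds one loader per entry of eval_resolutions,
# keyed as {None: full-rate loader, "2": half-rate loader, ...}.
# ------------------------------------------------------------------------------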
+ return super()._train_dataloader(dataset, resolution=train_resolution[0], **kwargs) + + def _eval_dataloader(self, dataset, train_resolution=None, eval_resolutions=None, **kwargs): + if dataset is None: return + if eval_resolutions is None: eval_resolutions = [1] + if not is_list(eval_resolutions): eval_resolutions = [eval_resolutions] + + dataloaders = [] + for resolution in eval_resolutions: + dataloaders.append(super()._eval_dataloader(dataset, resolution=resolution, **kwargs)) + + return ( + { + None if res == 1 else str(res): dl + for res, dl in zip(eval_resolutions, dataloaders) + } + if dataloaders is not None else None + ) + + +class ImageResolutionSequenceDataset(ResolutionSequenceDataset, ImageResolutionCollateMixin): + pass + + +# Registry for dataloader class +loader_registry = { + "tbptt": TBPTTDataLoader, + None: torch.utils.data.DataLoader, # default case +} + diff --git a/Essay_classifier/s5/dataloaders/basic.py b/Essay_classifier/s5/dataloaders/basic.py new file mode 100644 index 0000000000000000000000000000000000000000..66b12e9a755e2ff849a2c306a851a196124bfe5a --- /dev/null +++ b/Essay_classifier/s5/dataloaders/basic.py @@ -0,0 +1,272 @@ +"""Implementation of basic benchmark datasets used in S4 experiments: MNIST, CIFAR10 and Speech Commands.""" +import numpy as np +import torch +import torchvision +from einops.layers.torch import Rearrange + +from .base import default_data_path, ImageResolutionSequenceDataset, ResolutionSequenceDataset, SequenceDataset +from ..utils import permutations + + +class MNIST(SequenceDataset): + _name_ = "mnist" + d_input = 1 + d_output = 10 + l_output = 0 + L = 784 + + @property + def init_defaults(self): + return { + "permute": True, + "val_split": 0.1, + "seed": 42, # For train/val split + } + + def setup(self): + self.data_dir = self.data_dir or default_data_path / self._name_ + + transform_list = [ + torchvision.transforms.ToTensor(), + torchvision.transforms.Lambda(lambda x: x.view(self.d_input, self.L).t()), + ] # (L, d_input) + if self.permute: + # below is another permutation that other works have used + # permute = np.random.RandomState(92916) + # permutation = torch.LongTensor(permute.permutation(784)) + permutation = permutations.bitreversal_permutation(self.L) + transform_list.append( + torchvision.transforms.Lambda(lambda x: x[permutation]) + ) + # TODO does MNIST need normalization? 
+ # torchvision.transforms.Normalize((0.1307,), (0.3081,)) # normalize inputs + transform = torchvision.transforms.Compose(transform_list) + self.dataset_train = torchvision.datasets.MNIST( + self.data_dir, + train=True, + download=True, + transform=transform, + ) + self.dataset_test = torchvision.datasets.MNIST( + self.data_dir, + train=False, + transform=transform, + ) + self.split_train_val(self.val_split) + + def __str__(self): + return f"{'p' if self.permute else 's'}{self._name_}" + + +class CIFAR10(ImageResolutionSequenceDataset): + _name_ = "cifar" + d_output = 10 + l_output = 0 + + @property + def init_defaults(self): + return { + "permute": None, + "grayscale": False, + "tokenize": False, # if grayscale, tokenize into discrete byte inputs + "augment": False, + "cutout": False, + "rescale": None, + "random_erasing": False, + "val_split": 0.1, + "seed": 42, # For validation split + } + + @property + def d_input(self): + if self.grayscale: + if self.tokenize: + return 256 + else: + return 1 + else: + assert not self.tokenize + return 3 + + def setup(self): + img_size = 32 + if self.rescale: + img_size //= self.rescale + + if self.grayscale: + preprocessors = [ + torchvision.transforms.Grayscale(), + torchvision.transforms.ToTensor(), + ] + permutations_list = [ + torchvision.transforms.Lambda( + lambda x: x.view(1, img_size * img_size).t() + ) # (L, d_input) + ] + + if self.tokenize: + preprocessors.append( + torchvision.transforms.Lambda(lambda x: (x * 255).long()) + ) + permutations_list.append(Rearrange("l 1 -> l")) + else: + preprocessors.append( + torchvision.transforms.Normalize( + mean=122.6 / 255.0, std=61.0 / 255.0 + ) + ) + else: + preprocessors = [ + torchvision.transforms.ToTensor(), + torchvision.transforms.Normalize( + (0.4914, 0.4822, 0.4465), (0.247, 0.243, 0.261) + ), + ] + permutations_list = [ + torchvision.transforms.Lambda( + Rearrange("z h w -> (h w) z", z=3, h=img_size, w=img_size) + ) # (L, d_input) + ] + + # Permutations and reshaping + if self.permute == "br": + permutation = permutations.bitreversal_permutation(img_size * img_size) + print("bit reversal", permutation) + permutations_list.append(torchvision.transforms.Lambda(lambda x: x[permutation])) + elif self.permute == "snake": + permutation = permutations.snake_permutation(img_size, img_size) + print("snake", permutation) + permutations_list.append(torchvision.transforms.Lambda(lambda x: x[permutation])) + elif self.permute == "hilbert": + permutation = permutations.hilbert_permutation(img_size) + print("hilbert", permutation) + permutations_list.append(torchvision.transforms.Lambda(lambda x: x[permutation])) + elif self.permute == "transpose": + permutation = permutations.transpose_permutation(img_size, img_size) + transform = torchvision.transforms.Lambda( + lambda x: torch.cat([x, x[permutation]], dim=-1) + ) + permutations_list.append(transform) + elif self.permute == "2d": # h, w, c + permutation = torchvision.transforms.Lambda( + Rearrange("(h w) c -> h w c", h=img_size, w=img_size) + ) + permutations_list.append(permutation) + elif self.permute == "2d_transpose": # c, h, w + permutation = torchvision.transforms.Lambda( + Rearrange("(h w) c -> c h w", h=img_size, w=img_size) + ) + permutations_list.append(permutation) + + # Augmentation + if self.augment: + augmentations = [ + torchvision.transforms.RandomCrop( + img_size, padding=4, padding_mode="symmetric" + ), + torchvision.transforms.RandomHorizontalFlip(), + ] + + post_augmentations = [] + if self.cutout: + raise NotImplementedError("Cutout 
not currently supported.") + # post_augmentations.append(Cutout(1, img_size // 2)) + pass + if self.random_erasing: + # augmentations.append(RandomErasing()) + pass + else: + augmentations, post_augmentations = [], [] + transforms_train = ( + augmentations + preprocessors + post_augmentations + permutations_list + ) + transforms_eval = preprocessors + permutations_list + + transform_train = torchvision.transforms.Compose(transforms_train) + transform_eval = torchvision.transforms.Compose(transforms_eval) + self.dataset_train = torchvision.datasets.CIFAR10( + f"{default_data_path}/{self._name_}", + train=True, + download=True, + transform=transform_train, + ) + self.dataset_test = torchvision.datasets.CIFAR10( + f"{default_data_path}/{self._name_}", train=False, transform=transform_eval + ) + + if self.rescale: + print(f"Resizing all images to {img_size} x {img_size}.") + self.dataset_train.data = self.dataset_train.data.reshape((self.dataset_train.data.shape[0], 32 // self.rescale, self.rescale, 32 // self.rescale, self.rescale, 3)).max(4).max(2).astype(np.uint8) + self.dataset_test.data = self.dataset_test.data.reshape((self.dataset_test.data.shape[0], 32 // self.rescale, self.rescale, 32 // self.rescale, self.rescale, 3)).max(4).max(2).astype(np.uint8) + + self.split_train_val(self.val_split) + + def __str__(self): + return f"{'p' if self.permute else 's'}{self._name_}" + +class SpeechCommands(ResolutionSequenceDataset): + _name_ = "sc" + + @property + def init_defaults(self): + return { + "mfcc": False, + "dropped_rate": 0.0, + "length": 16000, + "all_classes": False, + } + + @property + def d_input(self): + _d_input = 20 if self.mfcc else 1 + _d_input += 1 if self.dropped_rate > 0.0 else 0 + return _d_input + + @property + def d_output(self): + return 10 if not self.all_classes else 35 + + @property + def l_output(self): + return 0 + + @property + def L(self): + return 161 if self.mfcc else self.length + + + def setup(self): + self.data_dir = self.data_dir or default_data_path # TODO make same logic as other classes + + from s5.dataloaders.sc import _SpeechCommands + + # TODO refactor with data_dir argument + self.dataset_train = _SpeechCommands( + partition="train", + length=self.L, + mfcc=self.mfcc, + sr=self.sr, + dropped_rate=self.dropped_rate, + path=self.data_dir, + all_classes=self.all_classes, + ) + + self.dataset_val = _SpeechCommands( + partition="val", + length=self.L, + mfcc=self.mfcc, + sr=self.sr, + dropped_rate=self.dropped_rate, + path=self.data_dir, + all_classes=self.all_classes, + ) + + self.dataset_test = _SpeechCommands( + partition="test", + length=self.L, + mfcc=self.mfcc, + sr=self.sr, + dropped_rate=self.dropped_rate, + path=self.data_dir, + all_classes=self.all_classes, + ) diff --git a/Essay_classifier/s5/dataloaders/lra.py b/Essay_classifier/s5/dataloaders/lra.py new file mode 100644 index 0000000000000000000000000000000000000000..3bd9b88f2ede65c3cfebb8bb7bc510e008f013fe --- /dev/null +++ b/Essay_classifier/s5/dataloaders/lra.py @@ -0,0 +1,736 @@ +"""Long Range Arena datasets""" +import io +import logging +import os +import pickle +from pathlib import Path +import torch +from torch import nn +import torch.nn.functional as F +import torchtext +import torchvision +from einops.layers.torch import Rearrange, Reduce +from PIL import Image # Only used for Pathfinder +from datasets import DatasetDict, Value, load_dataset, load_from_disk + +from .base import default_data_path, SequenceDataset, ImageResolutionSequenceDataset + + +class IMDB(SequenceDataset): + _name_ 
= "imdb" + d_output = 2 + l_output = 0 + + @property + def init_defaults(self): + return { + "l_max": 4096, + "level": "char", + "min_freq": 15, + "seed": 42, + "val_split": 0.0, + "append_bos": False, + "append_eos": True, + # 'max_vocab': 135, + "n_workers": 4, # Only used for tokenizing dataset before caching + } + + @property + def n_tokens(self): + return len(self.vocab) + + def prepare_data(self): + if self.cache_dir is None: # Just download the dataset + load_dataset(self._name_, cache_dir=self.data_dir) + else: # Process the dataset and save it + self.process_dataset() + + def setup(self, stage=None): + """If cache_dir is not None, we'll cache the processed dataset there.""" + + # # NOTE - AW - we manually set these elsewhere. + # self.data_dir = self.data_dir or default_data_path / self._name_ + # self.cache_dir = self.data_dir / "cache" + + assert self.level in [ + "word", + "char", + ], f"level {self.level} not supported" + + if stage == "test" and hasattr(self, "dataset_test"): + return + dataset, self.tokenizer, self.vocab = self.process_dataset() + print( + f"IMDB {self.level} level | min_freq {self.min_freq} | vocab size {len(self.vocab)}" + ) + dataset.set_format(type="torch", columns=["input_ids", "label"]) + + # Create all splits + dataset_train, self.dataset_test = dataset["train"], dataset["test"] + if self.val_split == 0.0: + # Use test set as val set, as done in the LRA paper + self.dataset_train, self.dataset_val = dataset_train, None + else: + train_val = dataset_train.train_test_split( + test_size=self.val_split, seed=self.seed + ) + self.dataset_train, self.dataset_val = ( + train_val["train"], + train_val["test"], + ) + + def _collate_fn(self, batch): + xs, ys = zip(*[(data["input_ids"], data["label"]) for data in batch]) + lengths = torch.tensor([len(x) for x in xs]) + xs = nn.utils.rnn.pad_sequence( + xs, padding_value=self.vocab["optgroup
+ # element, or if there is no more content in the parent
+ # element.
+ if type == "StartTag":
+ return next["name"] in ('option', 'optgroup')
+ else:
+ return type == "EndTag" or type is None
+ elif tagname in ('rt', 'rp'):
+ # An rt element's end tag may be omitted if the rt element is
+ # immediately followed by an rt or rp element, or if there is
+ # no more content in the parent element.
+ # An rp element's end tag may be omitted if the rp element is
+ # immediately followed by an rt or rp element, or if there is
+ # no more content in the parent element.
+ if type == "StartTag":
+ return next["name"] in ('rt', 'rp')
+ else:
+ return type == "EndTag" or type is None
+ elif tagname == 'colgroup':
+ # A colgroup element's end tag may be omitted if the colgroup
+ # element is not immediately followed by a space character or
+ # a comment.
+ if type in ("Comment", "SpaceCharacters"):
+ return False
+ elif type == "StartTag":
+ # XXX: we also look for an immediately following colgroup
+ # element. See is_optional_start.
+ return next["name"] != 'colgroup'
+ else:
+ return True
+ elif tagname in ('thead', 'tbody'):
+ # A thead element's end tag may be omitted if the thead element
+ # is immediately followed by a tbody or tfoot element.
+ # A tbody element's end tag may be omitted if the tbody element
+ # is immediately followed by a tbody or tfoot element, or if
+ # there is no more content in the parent element.
+ # A tfoot element's end tag may be omitted if the tfoot element
+ # is immediately followed by a tbody element, or if there is no
+ # more content in the parent element.
+ # XXX: we never omit the end tag when the following element is
+ # a tbody. See is_optional_start.
+ if type == "StartTag":
+ return next["name"] in ['tbody', 'tfoot']
+ elif tagname == 'tbody':
+ return type == "EndTag" or type is None
+ else:
+ return False
+ elif tagname == 'tfoot':
+ # A tfoot element's end tag may be omitted if the tfoot element
+ # is immediately followed by a tbody element, or if there is no
+ # more content in the parent element.
+ # XXX: we never omit the end tag when the following element is
+ # a tbody. See is_optional_start.
+ if type == "StartTag":
+ return next["name"] == 'tbody'
+ else:
+ return type == "EndTag" or type is None
+ elif tagname in ('td', 'th'):
+ # A td element's end tag may be omitted if the td element is
+ # immediately followed by a td or th element, or if there is
+ # no more content in the parent element.
+ # A th element's end tag may be omitted if the th element is
+ # immediately followed by a td or th element, or if there is
+ # no more content in the parent element.
+ if type == "StartTag":
+ return next["name"] in ('td', 'th')
+ else:
+ return type == "EndTag" or type is None
+ return False
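The branches above all share one shape: an end tag is optional when the next token starts a sibling that implicitly closes the element, or when the parent element has no more content. A simplified, self-contained illustration of the td/th case (using a toy token format, not html5lib's actual token stream):

# Toy illustration of the "optional end tag" rule for <td>/<th>.
def td_end_tag_is_optional(next_token):
    """A </td> may be omitted when followed by another <td>/<th>, or when the row ends."""
    if next_token is None:                       # no more content in the parent element
        return True
    if next_token["type"] == "StartTag":
        return next_token["name"] in ("td", "th")
    return next_token["type"] == "EndTag"

assert td_end_tag_is_optional({"type": "StartTag", "name": "td"}) is True
assert td_end_tag_is_optional({"type": "Characters", "data": "text"}) is False
assert td_end_tag_is_optional(None) is True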
diff --git a/Essay_classifier/venv/Lib/site-packages/pip/_vendor/html5lib/filters/sanitizer.py b/Essay_classifier/venv/Lib/site-packages/pip/_vendor/html5lib/filters/sanitizer.py
new file mode 100644
index 0000000000000000000000000000000000000000..aa7431d131213f85ab36cacc54b000e88898080b
--- /dev/null
+++ b/Essay_classifier/venv/Lib/site-packages/pip/_vendor/html5lib/filters/sanitizer.py
@@ -0,0 +1,916 @@
+"""Deprecated from html5lib 1.1.
+
+See `here <https://github.com/html5lib/html5lib-python/issues/443>`_ for
+information about its deprecation; `Bleach <https://github.com/mozilla/bleach>`_
+is recommended as a replacement. Please let us know in the aforementioned issue
+if Bleach is unsuitable for your needs.
+
+"""
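Only the deprecation notice of the sanitizer survives in this excerpt, but its approach is allow-list filtering of the token stream: tags and attributes not on an allowed list are escaped or stripped. A much-simplified conceptual sketch of that idea (not html5lib's implementation, token format, or allow lists):

# Toy allow-list sanitizer over simplified tag tokens; names and lists are
# illustrative only.
ALLOWED_TAGS = {"p", "b", "i", "em", "strong", "a"}
ALLOWED_ATTRS = {"a": {"href", "title"}}

def sanitize_tag_token(token):
    """Return None for disallowed tags; strip disallowed attributes otherwise."""
    if token["name"] not in ALLOWED_TAGS:
        return None  # html5lib's real filter escapes the disallowed markup instead
    if token["type"] == "StartTag":
        allowed = ALLOWED_ATTRS.get(token["name"], set())
        token["attrs"] = {k: v for k, v in token["attrs"].items() if k in allowed}
    return token

tok = {"type": "StartTag", "name": "a", "attrs": {"href": "/x", "onclick": "alert(1)"}}
assert sanitize_tag_token(tok)["attrs"] == {"href": "/x"}
assert sanitize_tag_token({"type": "StartTag", "name": "script", "attrs": {}}) is None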