# S5: Simplified State Space Layers for Sequence Modeling

This repository provides the implementation for the paper: Simplified State Space Layers for Sequence Modeling. The preprint is available [here](https://arxiv.org/abs/2208.04933).

![](./docs/figures/pngs/s5-matrix-blocks.png)
<p style="text-align: center;">
Figure 1:  S5 uses a single multi-input, multi-output linear state-space model, coupled with non-linearities, to define a non-linear sequence-to-sequence transformation. Parallel scans are used for efficient offline processing. 
</p>
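The parallel scan mentioned in the caption can be illustrated with a minimal JAX sketch. It assumes the state matrix has already been diagonalized and discretized, so the recurrence is `x_k = Lambda_bar * x_{k-1} + B_bar @ u_k` and `y_k = Re(C @ x_k) + D * u_k` with a zero initial state. The argument names and shapes are illustrative assumptions; this is not the code in `s5/ssm.py`.

```python
import jax
import jax.numpy as jnp


def binary_operator(q_i, q_j):
    """Compose two elementwise affine maps x -> a * x + b (earlier q_i, later q_j)."""
    a_i, b_i = q_i
    a_j, b_j = q_j
    return a_j * a_i, a_j * b_i + b_j


def apply_ssm_parallel(Lambda_bar, B_bar, C, D, u):
    """Run a diagonal MIMO linear SSM over a whole sequence with a parallel scan.

    Lambda_bar: (P,) complex diagonal of the discretized state matrix.
    B_bar:      (P, H) complex discretized input matrix.
    C:          (H, P) complex output matrix.
    D:          (H,) real feedthrough.
    u:          (L, H) real input sequence.
    Returns:    (L, H) real output sequence.
    """
    L = u.shape[0]
    # Each timestep contributes the pair (Lambda_bar, B_bar @ u_k).
    a = jnp.broadcast_to(Lambda_bar, (L, Lambda_bar.shape[0]))
    b = u @ B_bar.T  # (L, P)
    # The second output of the scan holds every state x_k.
    _, xs = jax.lax.associative_scan(binary_operator, (a, b))
    return (xs @ C.T).real + u * D


def apply_ssm_sequential(Lambda_bar, B_bar, C, D, u):
    """Reference step-by-step recurrence (x_{-1} = 0) for checking the scan."""
    def step(x, u_k):
        x = Lambda_bar * x + B_bar @ u_k
        y = (C @ x).real + D * u_k
        return x, y

    x0 = jnp.zeros(Lambda_bar.shape[0], dtype=jnp.complex64)
    _, ys = jax.lax.scan(step, x0, u)
    return ys
```

Because composing affine maps is associative, `jax.lax.associative_scan` can combine the per-step elements in a tree of logarithmic depth instead of strictly left to right; the sequential `jax.lax.scan` version is only a reference to compare against.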


The S5 layer builds on the prior S4 work ([paper](https://arxiv.org/abs/2111.00396)). Although it has since departed considerably, this repository originally started from much of the JAX implementation of S4 in the
Annotated S4 blog by Rush and Karamcheti (available [here](https://github.com/srush/annotated-s4)).


## Requirements & Installation
To run the code on your own machine, run either `pip install -r requirements_cpu.txt` or `pip install -r requirements_gpu.txt`. The GPU installation of JAX can be tricky, so we include requirements that should work for most people; further instructions are available [here](https://github.com/google/jax#installation).

To install the package, run `pip install -e .` from within the root directory.
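Because the GPU installation of JAX is system-dependent, it can help to confirm which backend JAX actually picked up after installing. The sketch below uses only public JAX calls (`jax.default_backend()`, `jax.devices()`); the helper name `report_jax_backend` is hypothetical and not part of this repository.

```python
import jax
import jax.numpy as jnp


def report_jax_backend():
    """Print the active JAX backend and devices, then run a tiny computation."""
    backend = jax.default_backend()  # e.g. "cpu" or "gpu"
    devices = jax.devices()
    print(f"JAX {jax.__version__} backend: {backend}")
    for d in devices:
        print(f"  device: {d}")
    # A small matmul forces compilation and execution on the default device.
    x = jnp.ones((128, 128))
    total = float((x @ x).sum())
    return backend, devices, total
```

If this reports `cpu` on a machine with a GPU, the CUDA-enabled JAX build was most likely not installed, and the JAX installation instructions linked above are the place to look.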


## Data Download
Downloading the raw data is done differently for each dataset.  The following datasets require no action:
- Text (IMDb)
- Image (CIFAR grayscale)
- sMNIST
- psMNIST
- CIFAR (color)

The remaining datasets need to be manually downloaded.  To download _everything_, run `./bin/download_all.sh`.  This will download quite a lot of data and will take some time.  

Below is a summary of the steps for each dataset:
- ListOps: run `./bin/download_lra.sh` to download the full LRA dataset.  
- Retrieval (AAN): run `./bin/download_aan.sh`
- Pathfinder: run `./bin/download_lra.sh` to download the full LRA dataset.
- Path-X: run `./bin/download_lra.sh` to download the full LRA dataset.
- Speech Commands 35 (SC35): run `./bin/download_sc35.sh` to download the speech commands data.

With the exception of SC35, a cache is created in `./cache_dir` the first time a dataset is used. Converting the data (e.g. tokenizing) can be quite slow, so this cache stores the processed dataset. The cache can be moved and its location specified with the `--dir_name` argument (the default is `--dir_name=./cache_dir`), which avoids repeating this preprocessing every time the code is run somewhere new.
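The cache described above follows a common preprocess-once pattern: build the dataset on first use, write it under `--dir_name`, and reload it on later runs. A minimal sketch of that pattern, assuming a hypothetical `load_or_build` helper and `pickle` as the storage format (the repository's dataloaders may store their caches differently):

```python
import pickle
from pathlib import Path
from typing import Any, Callable


def load_or_build(dir_name: str, key: str, build: Callable[[], Any]) -> Any:
    """Return the object cached at dir_name/key.pkl, building it on first use.

    dir_name plays the role of --dir_name; key identifies the dataset.
    """
    cache_path = Path(dir_name) / f"{key}.pkl"
    if cache_path.exists():
        # Later runs: skip preprocessing and load the stored result.
        with cache_path.open("rb") as f:
            return pickle.load(f)
    data = build()  # slow preprocessing, e.g. tokenization
    cache_path.parent.mkdir(parents=True, exist_ok=True)
    with cache_path.open("wb") as f:
        pickle.dump(data, f)
    return data
```

Moving the cache then amounts to copying the directory and pointing `dir_name` at its new location.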

SC35 is slightly different. It does not use `--dir_name`; instead, it requires that the path `./raw_datasets/speech_commands/0.0.2/SpeechCommands` exists (that is, the directory `./raw_datasets/speech_commands/0.0.2/SpeechCommands/zero` must exist). The cache is then stored in `./raw_datasets/speech_commands/0.0.2/SpeechCommands/processed_data`. This directory can be copied (preserving the directory path) to move the preprocessed dataset to a new location.


## Repository Structure
Directories and files that ship with the GitHub repo:
```
s5/                    Source code for models, datasets, etc.
    dataloading.py          Dataloading functions.
    layers.py               Defines the S5 layer which wraps the S5 SSM with nonlinearity, norms, dropout, etc.
    seq_model.py            Defines deep sequence models that consist of stacks of S5 layers.
    ssm.py                  S5 SSM implementation.
    ssm_init.py             Helper functions for initializing the S5 SSM.
    train.py                Training loop code.
    train_helpers.py        Functions for optimization, training and evaluation steps.
    dataloaders/            Code, mainly derived from S4, for processing each dataset.
    utils/                  Range of utility functions.
bin/                    Shell scripts for downloading data and running example experiments.
requirements_cpu.txt    Requirements for running in CPU mode (not advised).
requirements_gpu.txt    Requirements for running in GPU mode (installation can be highly system-dependent).
run_train.py            Training loop entrypoint.
```
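`layers.py` is described above as wrapping the S5 SSM with a nonlinearity, norms and dropout, and `seq_model.py` as stacking those layers. A minimal functional JAX sketch of that structure follows. The pre-norm ordering, GELU activation, inverted dropout and the names `sequence_layer` and `stack_layers` are assumptions for illustration; the repository's own layers may order these differently or offer other activation and normalization options.

```python
import jax
import jax.numpy as jnp


def layer_norm(x, eps=1e-6):
    """Normalize each timestep over the feature axis (no learned scale or bias)."""
    mean = x.mean(axis=-1, keepdims=True)
    var = x.var(axis=-1, keepdims=True)
    return (x - mean) / jnp.sqrt(var + eps)


def sequence_layer(ssm_fn, u, key, dropout_rate=0.1, training=True):
    """Hypothetical S5-style layer: pre-norm -> SSM -> GELU -> dropout -> residual.

    ssm_fn maps an (L, H) sequence to an (L, H) sequence, e.g. a linear SSM.
    """
    skip = u
    x = layer_norm(u)
    x = ssm_fn(x)
    x = jax.nn.gelu(x)
    if training and dropout_rate > 0.0:
        # Inverted dropout: zero some units and rescale the rest.
        keep = jax.random.bernoulli(key, 1.0 - dropout_rate, x.shape)
        x = jnp.where(keep, x / (1.0 - dropout_rate), 0.0)
    return skip + x


def stack_layers(ssm_fns, u, key, **kwargs):
    """Hypothetical deep model: apply one sequence layer per SSM, in order."""
    keys = jax.random.split(key, len(ssm_fns))
    for fn, k in zip(ssm_fns, keys):
        u = sequence_layer(fn, u, k, **kwargs)
    return u
```

With pre-norm, the residual path `skip + x` is never normalized, which is one common choice for deep stacks; post-norm variants instead normalize after the residual addition.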

Directories that may be created on-the-fly:
```
raw_datasets/       Raw data as downloaded.
cache_dir/          Precompiled caches of data.  Can be copied to new locations to avoid preprocessing.
wandb/              Local WandB log files.
```

## Experiments

The configurations to run the LRA and 35-way Speech Commands experiments from the paper are located in `bin/run_experiments`. For example,
to run the LRA text (character-level IMDb) experiment, run `./bin/run_experiments/run_lra_imdb.sh`.
To log with W&B, adjust the default `USE_WANDB`, `wandb_entity`, and `wandb_project` arguments.
Note: the pendulum regression dataloading and experiments will be added soon.

## Citation
Please use the following when citing our work:
```
@misc{smith2022s5,
  doi = {10.48550/ARXIV.2208.04933},
  url = {https://arxiv.org/abs/2208.04933},
  author = {Smith, Jimmy T. H. and Warrington, Andrew and Linderman, Scott W.},
  keywords = {Machine Learning (cs.LG), FOS: Computer and information sciences, FOS: Computer and information sciences},
  title = {Simplified State Space Layers for Sequence Modeling},
  publisher = {arXiv},
  year = {2022},
  copyright = {Creative Commons Attribution 4.0 International}
}
```

Please reach out if you have any questions.

-- The S5 authors.