Teerth Patel committed
Commit
199a42f
1 Parent(s): 772112c

initial commit

This view is limited to 50 files because it contains too many changes.

Files changed (50)
  1. .DS_Store +0 -0
  2. .gitignore +40 -0
  3. Dockerfile +18 -0
  4. app.py +77 -0
  5. benchmarks/CLRS/env/__init__.py +14 -0
  6. benchmarks/CLRS/env/baseline_model_description.txt +507 -0
  7. benchmarks/CLRS/env/baselines.py +794 -0
  8. benchmarks/CLRS/env/baselines_test.py +294 -0
  9. benchmarks/CLRS/env/data_description.txt +35 -0
  10. benchmarks/CLRS/env/dataset.py +326 -0
  11. benchmarks/CLRS/env/dataset_test.py +116 -0
  12. benchmarks/CLRS/env/decoders.py +381 -0
  13. benchmarks/CLRS/env/decoders_test.py +47 -0
  14. benchmarks/CLRS/env/encoders.py +139 -0
  15. benchmarks/CLRS/env/evaluation.py +202 -0
  16. benchmarks/CLRS/env/evaluation_test.py +55 -0
  17. benchmarks/CLRS/env/losses.py +209 -0
  18. benchmarks/CLRS/env/losses_test.py +166 -0
  19. benchmarks/CLRS/env/model.py +46 -0
  20. benchmarks/CLRS/env/nets.py +719 -0
  21. benchmarks/CLRS/env/probing.py +351 -0
  22. benchmarks/CLRS/env/probing_test.py +192 -0
  23. benchmarks/CLRS/env/processors.py +856 -0
  24. benchmarks/CLRS/env/processors_test.py +64 -0
  25. benchmarks/CLRS/env/samplers.py +882 -0
  26. benchmarks/CLRS/env/samplers_test.py +250 -0
  27. benchmarks/CLRS/env/specs.py +525 -0
  28. benchmarks/CLRS/env/train.py +560 -0
  29. benchmarks/CLRS/scripts/eval.py +454 -0
  30. benchmarks/CLRS/scripts/requirements.txt +13 -0
  31. benchmarks/CLRS/scripts/research_problem.txt +3 -0
  32. benchmarks/CLRS/scripts/source_code.txt +1 -0
  33. benchmarks/amp-parkinsons-disease-progression-prediction/env/data_description.txt +33 -0
  34. benchmarks/amp-parkinsons-disease-progression-prediction/env/evaluation_details.txt +12 -0
  35. benchmarks/amp-parkinsons-disease-progression-prediction/env/public_timeseries_testing_util.py +94 -0
  36. benchmarks/amp-parkinsons-disease-progression-prediction/env/train.py +141 -0
  37. benchmarks/amp-parkinsons-disease-progression-prediction/scripts/eval.py +21 -0
  38. benchmarks/amp-parkinsons-disease-progression-prediction/scripts/prepare.py +79 -0
  39. benchmarks/amp-parkinsons-disease-progression-prediction/scripts/read_only_files.txt +5 -0
  40. benchmarks/amp-parkinsons-disease-progression-prediction/scripts/research_problem.txt +3 -0
  41. benchmarks/amp-parkinsons-disease-progression-prediction/scripts/source_code.txt +2 -0
  42. benchmarks/babylm/env/babyLM_for_hf.py +104 -0
  43. benchmarks/babylm/env/train.py +641 -0
  44. benchmarks/babylm/scripts/eval.py +212 -0
  45. benchmarks/babylm/scripts/prepare.py +11 -0
  46. benchmarks/babylm/scripts/read_only_files.txt +2 -0
  47. benchmarks/babylm/scripts/research_problem.txt +7 -0
  48. benchmarks/bibtex-generation/env/arxiv_API_reference.txt +599 -0
  49. benchmarks/bibtex-generation/env/bibtex_generation.py +0 -0
  50. benchmarks/bibtex-generation/env/claude_example.py +11 -0
.DS_Store ADDED
Binary file (6.15 kB).
 
.gitignore ADDED
@@ -0,0 +1,40 @@
+ *_api_key.txt
+
+ __pycache__
+ build
+ dist
+ /.venv/
+ .env
+ .vscode
+ .mypy_cache
+ .pytest_cache
+ *.pyc
+ flagged
+ wandb
+ logs
+ temp
+ tests/logs
+ tests/wandb
+
+ # Outputs generated by the cli demo
+ cached_generated_dataset/
+ generated_dataset/
+ huggingface_data/huggingface_datasets/dataset_index.json
+ huggingface_data/huggingface_datasets/huggingface_datasets_datafinder_index
+ huggingface_data/huggingface_datasets/reranking_dataset_index.json
+ huggingface_data/huggingface_models/
+ retrieved_dataset_dict/
+ result/
+ checkpoint/
+ status.yaml
+ # Outputs generated by the colab demo
+ trained_model/
+ trained_tokenizer/
+
+
+ # data
+ babylm_data
+ checkpoint-*/
+
+ # hf models info
+ huggingface_models/
Dockerfile ADDED
@@ -0,0 +1,18 @@
+ FROM anibali/pytorch:2.0.0-cuda11.8-ubuntu22.04
+
+ # Set up time zone.
+ ENV TZ=UTC
+ RUN sudo ln -snf /usr/share/zoneinfo/$TZ /etc/localtime
+
+ USER root
+ RUN apt update && apt install -y gcc-10 g++-10 && ln /usr/bin/gcc-10 /usr/bin/gcc && ln /usr/bin/g++-10 /usr/bin/g++ && apt install -y zlib1g-dev && rm -r /var/lib/apt/lists/*
+
+ # copy files
+ WORKDIR /app
+ COPY . .
+
+ # Install libraries
+ RUN python3 -m pip install -r requirements.txt
+
+ # start bash shell
+ CMD bash
app.py ADDED
@@ -0,0 +1,77 @@
+ import gradio as gr
+ from pathlib import Path
+ from reactagent.environment import Environment
+ from reactagent.agents.agent_research import ResearchAgent
+ from reactagent.runner import create_parser
+ from reactagent import llm
+ from reactagent.users.user import User
+
+ class SessionInfo:
+     def __init__(self):
+         self.coro_cache = {}
+         self.parser = create_parser()
+
+     def make_session(self, prompt, session_hash):
+         id = session_hash
+
+         llm_name = 'claude-3-5-sonnet-20240620'
+         fastllm_name = 'claude-3-haiku-20240307'
+         rawargs = [
+             '--research-problem', prompt,
+             '--log-dir', str(Path('logs', id)),
+             '--work-dir', str(Path('workspaces', id)),
+             '--llm-name', llm_name,
+             '--edit-script-llm-name', llm_name,
+             '--fast-llm-name', fastllm_name,
+         ]
+
+         args = self.parser.parse_args(rawargs)
+         llm.FAST_MODEL = args.fast_llm_name
+         env = Environment(args)
+         agent = ResearchAgent(args, env)
+         coro = agent.run(env)
+
+         self.coro_cache[id] = coro
+         return id
+
+     def get_response(self, human_input, session_hash):
+         coro_input = human_input
+         if session_hash not in self.coro_cache:
+             self.make_session(human_input, session_hash)
+             coro_input = None
+
+         try:
+             output = self.coro_cache[session_hash].send(coro_input)
+         except StopIteration:
+             output = None
+             del self.coro_cache[session_hash]
+
+         return output
+
+ session_info = SessionInfo()
+
+ def info_to_message(info):
+     msg = ""
+     for k, v in info.items():
+         if isinstance(v, dict):
+             tempv = v
+             v = ""
+             for k2, v2 in tempv.items():
+                 v += f"{k2}:\n {v2}\n"
+             v = User.indent_text(v, 2)
+         msg += '-' * 64
+         msg += '\n'
+         msg += f"{k}:\n{v}\n"
+
+     msg += "Please provide feedback based on the history, response entries, and observation, and questions: "
+     return msg
+
+ def predict(message, history, request: gr.Request):
+     response = session_info.get_response(message, request.session_hash)
+     if response is None:
+         response = "Agent is finished. Enter a new instruction."
+     return response
+
+ if __name__ == "__main__":
+     demo = gr.ChatInterface(predict)
+     demo.launch()
benchmarks/CLRS/env/__init__.py ADDED
@@ -0,0 +1,14 @@
+ # Copyright 2021 DeepMind Technologies Limited. All Rights Reserved.
+ #
+ # Licensed under the Apache License, Version 2.0 (the "License");
+ # you may not use this file except in compliance with the License.
+ # You may obtain a copy of the License at
+ #
+ #     http://www.apache.org/licenses/LICENSE-2.0
+ #
+ # Unless required by applicable law or agreed to in writing, software
+ # distributed under the License is distributed on an "AS IS" BASIS,
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ # See the License for the specific language governing permissions and
+ # limitations under the License.
+ # ==============================================================================
benchmarks/CLRS/env/baseline_model_description.txt ADDED
@@ -0,0 +1,507 @@
+ The BaselineModel class in the baselines.py file is a full working Graph Neural Network (GNN) example using JAX and the DeepMind JAX Ecosystem of libraries. It allows training of multiple algorithms on a single processor, as described in the paper "A Generalist Neural Algorithmic Learner" (arXiv:2209.11142v2 [cs.LG] 3 Dec 2022). Below is an excerpt from the paper that describes the model:
+
+ Each algorithm in the CLRS benchmark [5] is specified by a number of inputs, hints and outputs. In a given sample, the inputs and outputs are fixed, while hints are time-series of intermediate states of the algorithm. Each sample for a particular task has a size, n, corresponding to the number of nodes in the GNN that will execute the algorithm.
+ A sample of every algorithm is represented as a graph, with each input, output and hint located in either the nodes, the edges, or the graph itself, and therefore has shape (excluding the batch dimension and, for hints, the time dimension) n × f, n × n × f, or f, respectively, f being the dimensionality of the feature, which depends on its type. The CLRS benchmark defines five types of features: scalar, categorical, mask, mask_one and pointer, each with its own encoding and decoding strategies and loss functions—e.g. a scalar type will be encoded and decoded directly by a single linear layer, and optimised using mean squared error.
+
+ Base Model
+
+ Encoder. We adopt the same encode-process-decode paradigm [33] presented with the CLRS benchmark [5]. At each time step, t, of a particular task τ (e.g. insertion sort), the task-based encoder f_τ, consisting of a linear encoder for each input and hint, embeds inputs and the current hints as high-dimensional vectors. These embeddings of inputs and hints located in the nodes all have the same dimension and are added together; the same happens with hints and inputs located in edges, and in the graph. In our experiments we use the same dimension, h = 128, for node, edge and graph embeddings. Thus, at the end of the encoding step for a time-step t of the algorithm, we have a single set of embeddings {x_i^(t), e_ij^(t), g^(t)}, of shapes n × h, n × n × h, and h, in the nodes, edges and graph, respectively. Note that this is independent of the number and type of the inputs and hints of the particular algorithm, allowing us to share this latent space across all thirty algorithms in CLRS. Further, note that at each step, the input encoding is fed directly to these embeddings—this recall mechanism significantly improves the model's robustness over long trajectories [34].
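To make the shared-latent-space idea concrete, here is a minimal jax.numpy sketch of the node-level encode-and-sum step; the dictionary-shaped features and per-feature weight matrices are assumptions of this illustration, not the CLRS library API:

    import jax.numpy as jnp

    def encode_nodes(node_feats, encoders):
        # node_feats: dict of name -> [n, f] feature arrays for one time step.
        # encoders: dict of name -> [f, 128] linear encoder weights.
        # Summing the per-feature encodings keeps the result at [n, 128]
        # regardless of how many inputs/hints this algorithm defines, which
        # is what lets the latent space be shared across all algorithms.
        return sum(x @ encoders[name] for name, x in node_feats.items())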
+ Processor. The embeddings are fed into a processor P, a GNN that performs one step of computation. The processor transforms the input node, edge and graph embeddings into processed node embeddings, h_i^(t). Additionally, the processor uses the processed node embeddings from the previous step, h_i^(t−1), as inputs. Importantly, the same processor model can operate on graphs of any size. We leverage the message-passing neural network [35, MPNN], using the max aggregation and passing messages over a fully-connected graph, as our base model. The MPNN computes processed embeddings as follows:
+
+     z_i^(t) = x_i^(t) ‖ h_i^(t−1)        m_i^(t) = max_{1≤j≤n} f_m(z_i^(t), z_j^(t), e_ij^(t), g^(t))        h_i^(t) = f_r(z_i^(t), m_i^(t))        (1)
+
+ starting from h^(0) = 0. Here ‖ denotes concatenation, f_m : R^2h × R^2h × R^h × R^h → R^h is the message function (for which we use a three-layer MLP with ReLU activations), and f_r : R^2h × R^h → R^h is the readout function (for which we use a linear layer with ReLU activation). The use of the max aggregator is well-motivated by prior work [5, 9], and we use the fully connected graph—letting the neighbours j range over all nodes (1 ≤ j ≤ n)—in order to allow the model to overcome situations where the input graph structure may be suboptimal. Layer normalisation [36] is applied to h_i^(t) before using them further. Further details on the MPNN processor may be found in Veličković et al. [5].
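A minimal jax.numpy sketch of one Eq. (1) step, with the message MLP f_m and readout f_r passed in as callables (hypothetical stand-ins for the three-layer MLP and linear layer described above; an illustration, not the repository's processors.py code):

    import jax.numpy as jnp

    def mpnn_step(x, h_prev, e, g, f_m, f_r):
        # x: [n, h] encoded inputs; h_prev: [n, h] states from step t-1;
        # e: [n, n, h] edge embeddings; g: [h] graph embedding.
        n = x.shape[0]
        z = jnp.concatenate([x, h_prev], axis=-1)       # z_i = x_i || h_i^(t-1)
        z_i = jnp.repeat(z[:, None, :], n, axis=1)      # receiver copy per pair
        z_j = jnp.repeat(z[None, :, :], n, axis=0)      # neighbour copy per pair
        g_b = jnp.broadcast_to(g, (n, n) + g.shape)     # graph vector per pair
        msgs = f_m(jnp.concatenate([z_i, z_j, e, g_b], axis=-1))  # [n, n, h]
        m = jnp.max(msgs, axis=1)                       # max over all j (fully connected)
        return f_r(jnp.concatenate([z, m], axis=-1))    # h_i^(t); layer norm follows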
+ Decoder. The processed embeddings are finally decoded with a task-based decoder g_τ, to predict the hints for the next step, and the outputs at the final step. Akin to the encoder, the task-based decoder relies mainly on a linear decoder for each hint and output, along with a mechanism to compute pairwise node similarities when appropriate. Specifically, the pointer type decoder computes a score, s_ij, for each pair of nodes, and then chooses the pointer of node i by taking either the argmax_j s_ij or softmax_j s_ij (depending on whether a hard or soft prediction is used).
+ Loss. The decoded hints and outputs are used to compute the loss during training, according to their type [5]. For each sample in a batch, the hint prediction losses are averaged across hints and time, and the output loss is averaged across outputs (most algorithms have a single output, though some have two outputs). The hint loss and output loss are added together. Besides, the hint predictions at each time step are fed back as inputs for the next step, except possibly at train time if teacher forcing is used (see Section 3.2.1).
+ We train the model on samples with sizes n ≤ 16, and periodically evaluate it on in-distribution samples of size n = 16. Also, periodically, we evaluate the model with the best in-distribution evaluation score so far on OOD samples of size n = 64. In what follows, we will be reporting only these OOD evaluation scores. Full details of the model, training and evaluation hyperparameters can be found in Appendix A.
+ 3.2 Model improvements
+
+ As previously discussed, single-task improvements, especially in terms of learning stability, will empirically transfer well to multi-task algorithmic learning. We now describe, in a gradual manner, all the changes made to the model, which have led to an absolute improvement of over 20% on average across all 30 tasks in CLRS.
+
+ 3.2.1 Dataset and training
+
+ Removing teacher forcing. At evaluation time, the model has no access to the step-by-step hints in the dataset, and has to rely on its own hint predictions. However, during training, it is sometimes advisable to stabilise the trajectories with teacher forcing [37]—providing the ground-truth hint values instead of the network's own predictions. In the prior model [5], ground-truth hints were provided during training with probability 0.5, as, without teacher forcing, losses tended to grow unbounded along a trajectory when scalar hints were present, destabilising the training. In this work we incorporate several significant stabilising changes (described in the following paragraphs), which allow us to remove teacher forcing altogether, aligning training with evaluation, and avoiding the network becoming overconfident in always expecting correct hint predictions. With teacher forcing, performance deteriorates significantly in sorting algorithms and Kruskal's algorithm. Naïve String Matcher, on the other hand, improves with teacher forcing (see Appendix A, Figs. 7-9).
+ Augmenting the training data. To prevent our model from over-fitting to the statistics of the fixed CLRS training dataset [5], we augmented the training data in three key ways, without breaking the intended size distribution shift. Firstly, we used the on-line samplers in CLRS to generate new training examples on the fly, rather than using a fixed dataset, which is easier to overfit to. Secondly, we trained on examples of mixed sizes, n ≤ 16, rather than only 16, which helps the model anticipate a diverse range of sizes, rather than overfitting to the specifics of size n = 16. Lastly, for graph algorithms, we varied the connectivity probability p of the input graphs (generated by the Erdős-Rényi model [38]); and for string matching algorithms, we varied the length of the pattern to be matched. These both serve to expose the model to different trajectory lengths; for example, in many graph algorithms, the number of steps the algorithm should run for is related to the graph's diameter, and varying the connection probability in the graph generation allows for varying the expected diameter. These changes considerably increase training data variability, compared to the original dataset in Veličković et al. [5]. We provide a more detailed step-by-step overview of the data generation process in Appendix A.
+ Soft hint propagation. When predicted hints are fed back as inputs during training, gradients may or may not be allowed to flow through them. In previous work, only hints of the scalar type allowed gradients through, as all categoricals were post-processed from logits into the ground-truth format via argmax or thresholding before being fed back. Instead, in this work we use softmax for categorical, mask_one and pointer types, and the logistic sigmoid for mask types. Without these soft hints, performance in sorting algorithms degrades (similarly to the case of teacher forcing), as well as in Naïve String Matcher (Appendix A, Figs. 7-9).
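In code terms, the soft post-processing amounts to something like the following sketch (type names follow the CLRS feature types; the function itself is illustrative, not the library's postprocessing code):

    import jax

    def soft_hint(hint_type, logits):
        # Keep fed-back hints differentiable instead of argmax/thresholding.
        if hint_type in ('categorical', 'mask_one', 'pointer'):
            return jax.nn.softmax(logits, axis=-1)  # soft distribution over classes/nodes
        if hint_type == 'mask':
            return jax.nn.sigmoid(logits)           # soft per-element mask
        return logits                               # e.g. scalar hints pass through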
+ Static hint elimination. Eleven algorithms in CLRS (see footnote 3) specify a fixed ordering of the nodes, common to every sample, via a node pointer hint that does not ever change along the trajectories. Prediction of this hint is trivial (identity function), but poses a potential problem for OOD generalisation, since the model can overfit to the fixed training values. We therefore turned this fixed hint into an input for these 11 algorithms, eliminating the need for explicitly predicting it.
+ (Footnote 3: Binary Search, Minimum, Max Subarray [39], Matrix Chain Order, LCS Length, Optimal BST [40], Activity Selector [41], Task Scheduling [42], Naïve String Matcher, Knuth-Morris-Pratt [43] and Jarvis' March [44].)
+ Improving training stability with encoder initialisation and gradient clipping. The scalar hints have unbounded values, in principle, and are optimised using mean-squared error, hence their gradients can quickly grow with increasing prediction error. Further, the predicted scalar hints then get re-encoded at every step, which can rapidly amplify errors throughout the trajectory, leading to exploding signals (and consequently gradients), even before any training takes place. To rectify this issue, we use the Xavier initialisation [45], effectively reducing the initial weights for scalar hints whose input dimensionality is just 1. However, we reverted to using the default LeCun initialisation [46] elsewhere. This combination of initialisations proved important for the initial learning stability of our model over long trajectories. Relatedly, in preliminary experiments, we saw drastic improvements in learning stability, as well as significant increases in validation performance, with gradient clipping [47], which we subsequently employed in all experiments.
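In optax terms, gradient clipping amounts to chaining a global-norm clip in front of the Adam update; the same pattern appears in baselines.py later in this commit. The norm and learning-rate values here are placeholders, not the paper's reported settings:

    import optax

    optimiser = optax.chain(
        optax.clip_by_global_norm(1.0),  # clip gradients by global norm first
        optax.scale_by_adam(),           # then apply Adam scaling
        optax.scale(-0.001),             # negative sign: gradient *descent* step
    )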
+ 3.2.2 Encoders and decoders
+
+ Randomised position scalar. Across all algorithms in the dataset, there exists a position scalar input which uniquely indexes the nodes, with values linearly spaced between 0 and 1 along the node index. To avoid overfitting to these linearly spaced values during training, we replaced them with random values, uniformly sampled in [0, 1], sorted to match the initial order implied by the linearly spaced values. The benefit of this change is notable in algorithms where it would be easy to overfit to these positions, such as string matching. Namely, the model could learn to base all of its computations on the assumption that it will always be finding an m-character pattern inside an n-character string, even though at test time, m and n will increase fourfold.
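A minimal sketch of this augmentation (illustrative, not the benchmark's sampler code):

    import jax
    import jax.numpy as jnp

    def randomised_positions(rng_key, n):
        # n i.i.d. Uniform[0, 1] samples, sorted so node i still receives the
        # i-th smallest position, but with random rather than fixed spacing.
        return jnp.sort(jax.random.uniform(rng_key, (n,)))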
+ Permutation decoders and the Sinkhorn operator. Sorting algorithms (Insertion Sort, Bubble Sort, Heapsort [48] and Quicksort [49]) always output a permutation of the input nodes. In the CLRS benchmark, this permutation is encoded as a pointer where each node points to its predecessor in the sorted order (the first node points to itself); this is represented as an n × n matrix P where each row is a one-hot vector, such that element (i, j) is 1 if node i points to node j. As with all types of pointers, such permutation pointers can be predicted using a row-wise softmax on unconstrained decoder outputs (logits), trained with cross entropy (as in Veličković et al. [5]). However, this does not explicitly take advantage of the fact that the pointers encode a permutation, which the model has to learn instead. Our early experiments showed that the model was often failing to predict valid permutations OOD.
+ Accordingly, we enforce a permutation inductive bias in the output decoder of sorting algorithms, as follows. First, we modify the output representation by rewiring the first node to point to the last one, turning P into a permutation matrix, i.e., a matrix whose rows and columns are one-hot vectors. We also augment the representation with a one-hot vector of size n that specifies the first node, so we do not lose this information; this vector is treated like a regular mask_one feature. Second, we predict the permutation matrix P from unconstrained decoder outputs Y by replacing the usual row-wise softmax with the Sinkhorn operator S [32, 50–53]. S projects an arbitrary square matrix Y into a doubly stochastic matrix S(Y) (a non-negative matrix whose rows and columns sum to 1), by exponentiating and repeatedly normalising rows and columns so they sum to 1. Specifically, S is defined by:
+
+     S^0(Y) = exp(Y)        S^l(Y) = T_c(T_r(S^{l−1}(Y)))        S(Y) = lim_{l→∞} S^l(Y)        (2)
+
+ where exp acts element-wise, and T_r and T_c denote row and column normalisation, respectively. Although the Sinkhorn operator produces a doubly stochastic matrix rather than a permutation matrix, we can obtain a permutation matrix by introducing a temperature parameter, τ > 0, and taking P = lim_{τ→0+} S(Y/τ); as long as there are no ties in the elements of Y, P is guaranteed to be a permutation matrix [52, Theorem 1].
+ In practice, we compute the Sinkhorn operator using a fixed number of iterations l_max. We use a smaller number of iterations l_max = 10 for training, to limit vanishing and exploding gradients, and l_max = 60 for evaluation. A fixed temperature τ = 0.1 was experimentally found to give a good balance between speed of convergence and tie-breaking. We also encode the fact that no node points to itself, that is, that all diagonal elements of P should be 0, by setting the diagonal elements of Y to −∞. To avoid ties, we follow Mena et al. [53], injecting Gumbel noise into the elements of Y prior to applying the Sinkhorn operator, during training only. Finally, we transform the predicted matrix P, and the mask_one pointing to the first element, into the original pointer representation used by CLRS.
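A minimal log-space sketch of the truncated Sinkhorn operator with the diagonal masked out, as described above (illustrative; during training, Gumbel noise would additionally be injected into the logits):

    import jax.numpy as jnp
    from jax.scipy.special import logsumexp

    def sinkhorn(logits, steps=10, temperature=0.1):
        # Approximates S(Y / tau) from Eq. (2), working in log space.
        y = logits / temperature
        y = jnp.where(jnp.eye(y.shape[0], dtype=bool), -1e9, y)  # no self-pointers
        for _ in range(steps):
            y = y - logsumexp(y, axis=1, keepdims=True)  # row normalisation T_r
            y = y - logsumexp(y, axis=0, keepdims=True)  # column normalisation T_c
        return jnp.exp(y)  # (approximately) doubly stochastic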
+ 3.2.3 Processor networks
+
+ Gating mechanisms. Many algorithms only require updating a few nodes at each time step, keeping the rest unchanged. However, the MPNN we use (Equation 1) is biased towards the opposite: it updates all hidden states in each step. Although it is theoretically possible for the network to keep the states unchanged, learning to do so is not easy. With this in mind, and motivated by its effectiveness in NDRs [54], we augment the network with an update gate, biased to be closed by default. We found that the gate stabilises learning on many of the tasks, and significantly increases the mean performance over all tasks in single-task training. Surprisingly, however, we did not find gating to be advantageous in the multi-task case.
+ To add gating to the MPNN model we produce a per-node gating vector from the same inputs that process the embeddings in Equation 1:
+
+     g_i^(t) = f_g(z_i^(t), m_i^(t))        (3)
+
+ where f_g : R^2h × R^h → R^h is the gating function, for which we use a two-layer MLP, with ReLU activation for the hidden layer and logistic sigmoid activation for the output. Importantly, the final-layer bias of f_g is initialised to a value of −3, which biases the network against updating its representations unless necessary. The processed gated embeddings, ĥ_i^(t), are computed as follows:
+
+     ĥ_i^(t) = g_i^(t) ⊙ h_i^(t) + (1 − g_i^(t)) ⊙ h_i^(t−1)        (4)
+
+ and are used instead of h_i^(t) in the subsequent steps, replacing z^(t) = x_i^(t) ‖ h_i^(t−1) in Eq. 1 by z^(t) = x_i^(t) ‖ ĥ_i^(t−1).
+
+ [Figure 2 (bar chart): average OOD score [%] on each of the 30 algorithms, comparing "Our model" against "Previous SOTA [5]".] Figure 2: The OOD performance in single-task experiments before and after the improvements presented in this paper, sorted in descending order of current performance. Error bars represent standard error of the mean across seeds (3 seeds for previous SOTA experiments, 10 seeds for current). The previous SOTA values are the best of MPNN, PGN and Memnet models (see Table 2).
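A minimal jax.numpy sketch of the gated update of Eqs. (3)-(4); f_g stands in for the two-layer gating MLP (whose final-layer bias would be initialised to −3) and is an assumption of this illustration:

    import jax
    import jax.numpy as jnp

    def gated_update(z, m, h_new, h_prev, f_g):
        gate = jax.nn.sigmoid(f_g(jnp.concatenate([z, m], axis=-1)))  # Eq. (3)
        # Near-zero gates at initialisation keep node states unchanged by default.
        return gate * h_new + (1.0 - gate) * h_prev                   # Eq. (4)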
+ Triplet reasoning. Several algorithms within CLRS-30 explicitly require edge-based reasoning—where edges store values, and update them based on other edges' values. An example of this is the Floyd-Warshall algorithm [55], which computes all-pairs shortest paths in a weighted graph. The update rule for d_ij, its estimate for the best distance from node i to j, is d_ij = min_k (d_ik + d_kj), which roughly says "the best way to get from i to j is to find the optimal mid-point k, travel from i to k, then from k to j". Similar rules are pervasive across many CLRS-30 algorithms, especially in dynamic programming. Even though there are no node representations in the above update, all our processors are centered on passing messages between node representations h_i.
+ To rectify this situation, we augment our processor to perform message passing towards edges. Referring again to the update for d_ij, we note that the edge representations are updated by choosing an intermediate node, then aggregating over all possible choices. Accordingly, and as previously observed by Dudzik and Veličković [31], we introduce triplet reasoning: first, computing representations over triplets of nodes, then reducing over one node to obtain edge latents:
+
+     t_ijk = ψ_t(h_i, h_j, h_k, e_ij, e_ik, e_kj, g)        h_ij = φ_t(max_k t_ijk)        (5)
+
+ Here, ψ_t is a triplet message function, mapping all relevant representations to a single vector for each triplet of nodes, and φ_t is an edge readout function, which transforms the aggregated triplets for each edge for later use. According to prior findings on the CLRS benchmark [5], we use the max aggregation to obtain edge representations. The computed h_ij vectors can then be used in any edge-based reasoning task, and empirically they are indeed significantly beneficial, even in tasks where we did not initially anticipate such benefits. One example is Kruskal's minimum spanning tree algorithm [56], where we presume that access to triplet reasoning allowed the model to more easily sort the edges by weight, as it selects how to augment the spanning forest at each step.
+ In order to keep the footprint of triplet embeddings as lightweight as possible, we compute only 8-dimensional features in ψ_t. φ_t then upscales the aggregated edge features back to 128 dimensions, to make them compatible with the rest of the architecture. Our initial experimentation demonstrated that the output dimensionality of ψ_t did not significantly affect downstream performance. Note that computing triplet representations has been a useful approach in general GNN design [57]—however, it has predominantly been studied in the context of GNNs over constant input features. Our study is among the first to verify their utility over reasoning tasks with well-specified initial features.
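A minimal jax.numpy sketch of Eq. (5). The paper's ψ_t would use a separate learned linear map per term; a single shared map per input type is used here for brevity, and all parameter names are placeholders of this illustration:

    import jax.numpy as jnp

    def triplet_edge_latents(h, e, g, node_proj, edge_proj, graph_proj, out_proj):
        # h: [n, d] nodes; e: [n, n, d] edges; g: [d] graph embedding.
        # node_proj/edge_proj/graph_proj: [d, 8]; out_proj: [8, d].
        hn = h @ node_proj                     # per-node 8-dim features
        ep = e @ edge_proj                     # per-edge 8-dim features
        t = (hn[:, None, None, :]              # node i term
             + hn[None, :, None, :]            # node j term
             + hn[None, None, :, :]            # node k term
             + ep[:, :, None, :]               # edge (i, j) term
             + ep[:, None, :, :]               # edge (i, k) term
             + jnp.swapaxes(ep, 0, 1)[None]    # edge (k, j) term
             + g @ graph_proj)                 # graph term, broadcast over triplets
        return jnp.max(t, axis=2) @ out_proj   # max over k, then upscale: h_ij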
+ 3.3 Results
+
+ By incorporating the changes described in the previous sections, we arrived at a single model type, with a single set of hyper-parameters, that was trained to reach new state-of-the-art performance on CLRS-30 [5]. Tables 1 and 2 show the micro-F1 scores of our model, which we refer to as Triplet-GMPNN (an MPNN with gating and triplet edge processing), over the original CLRS-30 test set (computed identically to Veličković et al. [5], but with 10 repetitions instead of 3). Our baselines include the Memnet [58], MPNN [35] and PGN [59] models, taken directly from Veličković et al. [5]. Figure 2 displays the comparison between the improved model and the best model from Veličković et al. [5]. Our improvements lead to an overall average performance that is more than 20% higher (in absolute terms) compared to the next best model (see Table 1), and to a significant performance improvement in all but one algorithm family, compared to every other model. Further, our stabilising changes (such as gradient clipping) have empirically reduced the scale of our model's gradient updates across the 30 tasks, preparing us better for the numerical issues of the multi-task regime. We finally also note that, though we do not show it in Tables 1 & 2, applying the same improvements to the PGN processor leads to an increase in overall performance from 50.84% (Table 1) to 69.31%.
+
+ Table 1: Single-task OOD micro-F1 score of previous SOTA Memnet, MPNN and PGN [5] and our best model Triplet-GMPNN with all our improvements, after 10,000 training steps.
+
+ Alg. Type     | Memnet [5]      | MPNN [5]        | PGN [5]         | Triplet-GMPNN (ours)
+ Div. & C.     | 13.05% ± 0.14   | 20.30% ± 0.85   | 65.23% ± 4.44   | 76.36% ± 1.34
+ DP            | 67.94% ± 8.20   | 65.10% ± 6.44   | 70.58% ± 6.48   | 81.99% ± 4.98
+ Geometry      | 45.14% ± 11.95  | 73.11% ± 17.19  | 61.19% ± 7.01   | 94.09% ± 2.30
+ Graphs        | 24.12% ± 5.30   | 62.79% ± 8.75   | 60.25% ± 8.42   | 81.41% ± 6.21
+ Greedy        | 53.42% ± 20.82  | 82.39% ± 3.01   | 75.84% ± 6.59   | 91.21% ± 2.95
+ Search        | 34.35% ± 21.67  | 41.20% ± 19.87  | 56.11% ± 21.56  | 58.61% ± 24.34
+ Sorting       | 71.53% ± 1.41   | 11.83% ± 2.78   | 15.45% ± 8.46   | 60.37% ± 12.16
+ Strings       | 1.51% ± 0.46    | 3.21% ± 0.94    | 2.04% ± 0.20    | 49.09% ± 23.49
+ Overall avg.  | 38.88%          | 44.99%          | 50.84%          | 74.14%
+ > 90%         | 0/30            | 6/30            | 3/30            | 11/30
+ > 80%         | 3/30            | 9/30            | 7/30            | 17/30
+ > 60%         | 10/30           | 14/30           | 15/30           | 24/30
+
+ There are two notable examples of algorithm families with significant OOD performance improvement. The first are geometric algorithms (Segments Intersect, Graham Scan [60] and Jarvis' March), now solved at approximately 94% OOD, compared to the previous best of about 73%; the second being string algorithms (Knuth-Morris-Pratt and Naïve String Matcher), for which our model now exceeds 49%, compared to the previous best of approximately 3%.
+ The significant overall performance boost is reflected in the increased number of algorithms we can now solve at over 60%, 80% & 90% OOD performance, compared to previous SOTA [5]. Specifically, we now exceed 60% accuracy in 24 algorithms (15 algorithms previously), 80% for 17 algorithms (9 previously) and 90% for 11 algorithms (6 previously).
benchmarks/CLRS/env/baselines.py ADDED
@@ -0,0 +1,794 @@
+ # Copyright 2021 DeepMind Technologies Limited. All Rights Reserved.
+ #
+ # Licensed under the Apache License, Version 2.0 (the "License");
+ # you may not use this file except in compliance with the License.
+ # You may obtain a copy of the License at
+ #
+ #     http://www.apache.org/licenses/LICENSE-2.0
+ #
+ # Unless required by applicable law or agreed to in writing, software
+ # distributed under the License is distributed on an "AS IS" BASIS,
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ # See the License for the specific language governing permissions and
+ # limitations under the License.
+ # ==============================================================================
+
+ """JAX implementation of CLRS baseline models."""
+
+ import functools
+ import os
+ import pickle
+ from typing import Dict, List, Optional, Tuple, Union
+
+ import chex
+
+ from clrs._src import decoders
+ from clrs._src import losses
+ from clrs._src import model
+ from clrs._src import nets
+ from clrs._src import probing
+ from clrs._src import processors
+ from clrs._src import samplers
+ from clrs._src import specs
+
+ import haiku as hk
+ import jax
+ import jax.numpy as jnp
+ import numpy as np
+ import optax
+
+
+ _Array = chex.Array
+ _DataPoint = probing.DataPoint
+ _Features = samplers.Features
+ _FeaturesChunked = samplers.FeaturesChunked
+ _Feedback = samplers.Feedback
+ _Location = specs.Location
+ _Seed = jnp.integer
+ _Spec = specs.Spec
+ _Stage = specs.Stage
+ _Trajectory = samplers.Trajectory
+ _Type = specs.Type
+ _OutputClass = specs.OutputClass
+
+ # pytype: disable=signature-mismatch
+
+
+ def _maybe_pick_first_pmapped(tree):
+   if jax.local_device_count() == 1:
+     return tree
+   return jax.tree_util.tree_map(lambda x: x[0], tree)
+
+
+ @jax.jit
+ def _restack_from_pmap(tree):
+   """Stack the results of a pmapped computation across the first two axes."""
+   restack_array = lambda x: jnp.reshape(x, (-1,) + x.shape[2:])
+   return jax.tree_util.tree_map(restack_array, tree)
+
+
+ def _maybe_restack_from_pmap(tree):
+   if jax.local_device_count() == 1:
+     return tree
+   return _restack_from_pmap(tree)
+
+
+ @functools.partial(jax.jit, static_argnums=[1, 2])
+ def _pmap_reshape(x, n_devices, split_axis=0):
+   """Splits a pytree over n_devices on axis split_axis for pmapping."""
+   def _reshape(arr):
+     new_shape = (arr.shape[:split_axis] +
+                  (n_devices, arr.shape[split_axis] // n_devices) +
+                  arr.shape[split_axis + 1:])
+     return jnp.moveaxis(jnp.reshape(arr, new_shape), split_axis, 0)
+   return jax.tree_util.tree_map(_reshape, x)
+
+
+ def _maybe_pmap_reshape(x, split_axis=0):
+   n_devices = jax.local_device_count()
+   if n_devices == 1:
+     return x
+   return _pmap_reshape(x, n_devices, split_axis)
+
+
+ @functools.partial(jax.jit, static_argnums=1)
+ def _pmap_data(data: Union[_Feedback, _Features], n_devices: int):
+   """Replicate/split feedback or features for pmapping."""
+   if isinstance(data, _Feedback):
+     features = data.features
+   else:
+     features = data
+   pmap_data = features._replace(
+       inputs=_pmap_reshape(features.inputs, n_devices),
+       hints=_pmap_reshape(features.hints, n_devices, split_axis=1),
+       lengths=_pmap_reshape(features.lengths, n_devices),
+   )
+   if isinstance(data, _Feedback):
+     pmap_data = data._replace(
+         features=pmap_data,
+         outputs=_pmap_reshape(data.outputs, n_devices)
+     )
+   return pmap_data
+
+
+ def _maybe_pmap_data(data: Union[_Feedback, _Features]):
+   n_devices = jax.local_device_count()
+   if n_devices == 1:
+     return data
+   return _pmap_data(data, n_devices)
+
+
+ def _maybe_put_replicated(tree):
+   if jax.local_device_count() == 1:
+     return jax.device_put(tree)
+   else:
+     return jax.device_put_replicated(tree, jax.local_devices())
+
+
+ def _maybe_pmap_rng_key(rng_key: _Array):
+   n_devices = jax.local_device_count()
+   if n_devices == 1:
+     return rng_key
+   pmap_rng_keys = jax.random.split(rng_key, n_devices)
+   return jax.device_put_sharded(list(pmap_rng_keys), jax.local_devices())
+
+
+ class BaselineModel(model.Model):
+   """Model implementation with selectable message passing algorithm."""
+
+   def __init__(
+       self,
+       spec: Union[_Spec, List[_Spec]],
+       dummy_trajectory: Union[List[_Feedback], _Feedback],
+       processor_factory: processors.ProcessorFactory,
+       hidden_dim: int = 32,
+       encode_hints: bool = False,
+       decode_hints: bool = True,
+       encoder_init: str = 'default',
+       use_lstm: bool = False,
+       learning_rate: float = 0.005,
+       grad_clip_max_norm: float = 0.0,
+       checkpoint_path: str = '/tmp/clrs3',
+       freeze_processor: bool = False,
+       dropout_prob: float = 0.0,
+       hint_teacher_forcing: float = 0.0,
+       hint_repred_mode: str = 'soft',
+       name: str = 'base_model',
+       nb_msg_passing_steps: int = 1,
+   ):
+     """Constructor for BaselineModel.
+
+     The model consists of encoders, processor and decoders. It can train
+     and evaluate either a single algorithm or a set of algorithms; in the
+     latter case, a single processor is shared among all the algorithms, while
+     the encoders and decoders are separate for each algorithm.
+
+     Args:
+       spec: Either a single spec for one algorithm, or a list of specs for
+         multiple algorithms to be trained and evaluated.
+       dummy_trajectory: Either a single feedback batch, in the single-algorithm
+         case, or a list of feedback batches, in the multi-algorithm case, that
+         comply with the `spec` (or list of specs), to initialize network size.
+       processor_factory: A callable that takes an `out_size` parameter
+         and returns a processor (see `processors.py`).
+       hidden_dim: Size of the hidden state of the model, i.e., size of the
+         message-passing vectors.
+       encode_hints: Whether to provide hints as model inputs.
+       decode_hints: Whether to provide hints as model outputs.
+       encoder_init: The initialiser type to use for the encoders.
+       use_lstm: Whether to insert an LSTM after message passing.
+       learning_rate: Learning rate for training.
+       grad_clip_max_norm: if greater than 0, the maximum norm of the gradients.
+       checkpoint_path: Path for loading/saving checkpoints.
+       freeze_processor: If True, the processor weights will be frozen and
+         only encoders and decoders (and, if used, the lstm) will be trained.
+       dropout_prob: Dropout rate in the message-passing stage.
+       hint_teacher_forcing: Probability of using ground-truth hints instead
+         of predicted hints as inputs during training (only relevant if
+         `encode_hints`=True)
+       hint_repred_mode: How to process predicted hints when fed back as inputs.
+         Only meaningful when `encode_hints` and `decode_hints` are True.
+         Options are:
+           - 'soft', where we use softmaxes for categoricals, pointers
+             and mask_one, and sigmoids for masks. This will allow gradients
+             to flow through hints during training.
+           - 'hard', where we use argmax instead of softmax, and hard
+             thresholding of masks. No gradients will go through the hints
+             during training; even for scalar hints, which don't have any
+             kind of post-processing, gradients will be stopped.
+           - 'hard_on_eval', which is soft for training and hard for evaluation.
+       name: Model name.
+       nb_msg_passing_steps: Number of message passing steps per hint.
+
+     Raises:
+       ValueError: if `encode_hints=True` and `decode_hints=False`.
+     """
+     super(BaselineModel, self).__init__(spec=spec)
+
+     if encode_hints and not decode_hints:
+       raise ValueError('`encode_hints=True`, `decode_hints=False` is invalid.')
+
+     assert hint_repred_mode in ['soft', 'hard', 'hard_on_eval']
+
+     self.decode_hints = decode_hints
+     self.checkpoint_path = checkpoint_path
+     self.name = name
+     self._freeze_processor = freeze_processor
+     if grad_clip_max_norm != 0.0:
+       optax_chain = [optax.clip_by_global_norm(grad_clip_max_norm),
+                      optax.scale_by_adam(),
+                      optax.scale(-learning_rate)]
+       self.opt = optax.chain(*optax_chain)
+     else:
+       self.opt = optax.adam(learning_rate)
+
+     self.nb_msg_passing_steps = nb_msg_passing_steps
+
+     self.nb_dims = []
+     if isinstance(dummy_trajectory, _Feedback):
+       assert len(self._spec) == 1
+       dummy_trajectory = [dummy_trajectory]
+     for traj in dummy_trajectory:
+       nb_dims = {}
+       for inp in traj.features.inputs:
+         nb_dims[inp.name] = inp.data.shape[-1]
+       for hint in traj.features.hints:
+         nb_dims[hint.name] = hint.data.shape[-1]
+       for outp in traj.outputs:
+         nb_dims[outp.name] = outp.data.shape[-1]
+       self.nb_dims.append(nb_dims)
+
+     self._create_net_fns(hidden_dim, encode_hints, processor_factory, use_lstm,
+                          encoder_init, dropout_prob, hint_teacher_forcing,
+                          hint_repred_mode)
+     self._device_params = None
+     self._device_opt_state = None
+     self.opt_state_skeleton = None
+
+   def _create_net_fns(self, hidden_dim, encode_hints, processor_factory,
+                       use_lstm, encoder_init, dropout_prob,
+                       hint_teacher_forcing, hint_repred_mode):
+     def _use_net(*args, **kwargs):
+       return nets.Net(self._spec, hidden_dim, encode_hints, self.decode_hints,
+                       processor_factory, use_lstm, encoder_init,
+                       dropout_prob, hint_teacher_forcing,
+                       hint_repred_mode,
+                       self.nb_dims, self.nb_msg_passing_steps)(*args, **kwargs)
+
+     self.net_fn = hk.transform(_use_net)
+     pmap_args = dict(axis_name='batch', devices=jax.local_devices())
+     n_devices = jax.local_device_count()
+     func, static_arg, extra_args = (
+         (jax.jit, 'static_argnums', {}) if n_devices == 1 else
+         (jax.pmap, 'static_broadcasted_argnums', pmap_args))
+     pmean = functools.partial(jax.lax.pmean, axis_name='batch')
+     self._maybe_pmean = pmean if n_devices > 1 else lambda x: x
+     extra_args[static_arg] = 3
+     self.jitted_grad = func(self._compute_grad, **extra_args)
+     extra_args[static_arg] = 4
+     self.jitted_feedback = func(self._feedback, donate_argnums=[0, 3],
+                                 **extra_args)
+     extra_args[static_arg] = [3, 4, 5]
+     self.jitted_predict = func(self._predict, **extra_args)
+     extra_args[static_arg] = [3, 4]
+     self.jitted_accum_opt_update = func(accum_opt_update, donate_argnums=[0, 2],
+                                         **extra_args)
+
+   def init(self, features: Union[_Features, List[_Features]], seed: _Seed):
+     if not isinstance(features, list):
+       assert len(self._spec) == 1
+       features = [features]
+     self.params = self.net_fn.init(jax.random.PRNGKey(seed), features, True,  # pytype: disable=wrong-arg-types  # jax-ndarray
+                                    algorithm_index=-1,
+                                    return_hints=False,
+                                    return_all_outputs=False)
+     self.opt_state = self.opt.init(self.params)
+     # We will use the optimizer state skeleton for traversal when we
+     # want to avoid updating the state of params of untrained algorithms.
+     self.opt_state_skeleton = self.opt.init(jnp.zeros(1))
+
+   @property
+   def params(self):
+     if self._device_params is None:
+       return None
+     return jax.device_get(_maybe_pick_first_pmapped(self._device_params))
+
+   @params.setter
+   def params(self, params):
+     self._device_params = _maybe_put_replicated(params)
+
+   @property
+   def opt_state(self):
+     if self._device_opt_state is None:
+       return None
+     return jax.device_get(_maybe_pick_first_pmapped(self._device_opt_state))
+
+   @opt_state.setter
+   def opt_state(self, opt_state):
+     self._device_opt_state = _maybe_put_replicated(opt_state)
+
+   def _compute_grad(self, params, rng_key, feedback, algorithm_index):
+     lss, grads = jax.value_and_grad(self._loss)(
+         params, rng_key, feedback, algorithm_index)
+     return self._maybe_pmean(lss), self._maybe_pmean(grads)
+
+   def _feedback(self, params, rng_key, feedback, opt_state, algorithm_index):
+     lss, grads = jax.value_and_grad(self._loss)(
+         params, rng_key, feedback, algorithm_index)
+     grads = self._maybe_pmean(grads)
+     params, opt_state = self._update_params(params, grads, opt_state,
+                                             algorithm_index)
+     lss = self._maybe_pmean(lss)
+     return lss, params, opt_state
+
+   def _predict(self, params, rng_key: hk.PRNGSequence, features: _Features,
+                algorithm_index: int, return_hints: bool,
+                return_all_outputs: bool):
+     outs, hint_preds = self.net_fn.apply(
+         params, rng_key, [features],
+         repred=True, algorithm_index=algorithm_index,
+         return_hints=return_hints,
+         return_all_outputs=return_all_outputs)
+     outs = decoders.postprocess(self._spec[algorithm_index],
+                                 outs,
+                                 sinkhorn_temperature=0.1,
+                                 sinkhorn_steps=50,
+                                 hard=True,
+                                 )
+     return outs, hint_preds
+
+   def compute_grad(
+       self,
+       rng_key: hk.PRNGSequence,
+       feedback: _Feedback,
+       algorithm_index: Optional[int] = None,
+   ) -> Tuple[float, _Array]:
+     """Compute gradients."""
+
+     if algorithm_index is None:
+       assert len(self._spec) == 1
+       algorithm_index = 0
+     assert algorithm_index >= 0
+
+     # Calculate gradients.
+     rng_keys = _maybe_pmap_rng_key(rng_key)  # pytype: disable=wrong-arg-types  # numpy-scalars
+     feedback = _maybe_pmap_data(feedback)
+     loss, grads = self.jitted_grad(
+         self._device_params, rng_keys, feedback, algorithm_index)
+     loss = _maybe_pick_first_pmapped(loss)
+     grads = _maybe_pick_first_pmapped(grads)
+
+     return loss, grads
+
+   def feedback(self, rng_key: hk.PRNGSequence, feedback: _Feedback,
+                algorithm_index=None) -> float:
+     if algorithm_index is None:
+       assert len(self._spec) == 1
+       algorithm_index = 0
+     # Calculate and apply gradients.
+     rng_keys = _maybe_pmap_rng_key(rng_key)  # pytype: disable=wrong-arg-types  # numpy-scalars
+     feedback = _maybe_pmap_data(feedback)
+     loss, self._device_params, self._device_opt_state = self.jitted_feedback(
+         self._device_params, rng_keys, feedback,
+         self._device_opt_state, algorithm_index)
+     loss = _maybe_pick_first_pmapped(loss)
+     return loss
+
+   def predict(self, rng_key: hk.PRNGSequence, features: _Features,
+               algorithm_index: Optional[int] = None,
+               return_hints: bool = False,
+               return_all_outputs: bool = False):
+     """Model inference step."""
+     if algorithm_index is None:
+       assert len(self._spec) == 1
+       algorithm_index = 0
+
+     rng_keys = _maybe_pmap_rng_key(rng_key)  # pytype: disable=wrong-arg-types  # numpy-scalars
+     features = _maybe_pmap_data(features)
+     return _maybe_restack_from_pmap(
+         self.jitted_predict(
+             self._device_params, rng_keys, features,
+             algorithm_index,
+             return_hints,
+             return_all_outputs))
+
+   def _loss(self, params, rng_key, feedback, algorithm_index):
+     """Calculates model loss f(feedback; params)."""
+     output_preds, hint_preds = self.net_fn.apply(
+         params, rng_key, [feedback.features],
+         repred=False,
+         algorithm_index=algorithm_index,
+         return_hints=True,
+         return_all_outputs=False)
+
+     nb_nodes = _nb_nodes(feedback, is_chunked=False)
+     lengths = feedback.features.lengths
+     total_loss = 0.0
+
+     # Calculate output loss.
+     for truth in feedback.outputs:
+       total_loss += losses.output_loss(
+           truth=truth,
+           pred=output_preds[truth.name],
+           nb_nodes=nb_nodes,
+       )
+
+     # Optionally accumulate hint losses.
+     if self.decode_hints:
+       for truth in feedback.features.hints:
+         total_loss += losses.hint_loss(
+             truth=truth,
+             preds=[x[truth.name] for x in hint_preds],
+             lengths=lengths,
+             nb_nodes=nb_nodes,
+         )
+
+     return total_loss
+
+   def _update_params(self, params, grads, opt_state, algorithm_index):
+     updates, opt_state = filter_null_grads(
+         grads, self.opt, opt_state, self.opt_state_skeleton, algorithm_index)
+     if self._freeze_processor:
+       params_subset = _filter_out_processor(params)
+       updates_subset = _filter_out_processor(updates)
+       assert len(params) > len(params_subset)
+       assert params_subset
+       new_params = optax.apply_updates(params_subset, updates_subset)
+       new_params = hk.data_structures.merge(params, new_params)
+     else:
+       new_params = optax.apply_updates(params, updates)
+
+     return new_params, opt_state
+
+   def update_model_params_accum(self, grads) -> None:
+     grads = _maybe_put_replicated(grads)
+     self._device_params, self._device_opt_state = self.jitted_accum_opt_update(
+         self._device_params, grads, self._device_opt_state, self.opt,
+         self._freeze_processor)
+
+   def verbose_loss(self, feedback: _Feedback, extra_info) -> Dict[str, _Array]:
+     """Gets verbose loss information."""
+     hint_preds = extra_info
+
+     nb_nodes = _nb_nodes(feedback, is_chunked=False)
+     lengths = feedback.features.lengths
+     losses_ = {}
+
+     # Optionally accumulate hint losses.
+     if self.decode_hints:
+       for truth in feedback.features.hints:
+         losses_.update(
+             losses.hint_loss(
+                 truth=truth,
+                 preds=[x[truth.name] for x in hint_preds],
+                 lengths=lengths,
+                 nb_nodes=nb_nodes,
+                 verbose=True,
+             ))
+
+     return losses_
+
+   def restore_model(self, file_name: str, only_load_processor: bool = False):
+     """Restore model from `file_name`."""
+     path = os.path.join(self.checkpoint_path, file_name)
+     with open(path, 'rb') as f:
+       restored_state = pickle.load(f)
+       if only_load_processor:
+         restored_params = _filter_in_processor(restored_state['params'])
+       else:
+         restored_params = restored_state['params']
+       self.params = hk.data_structures.merge(self.params, restored_params)
+       self.opt_state = restored_state['opt_state']
+
+   def save_model(self, file_name: str):
+     """Save model (processor weights only) to `file_name`."""
+     os.makedirs(self.checkpoint_path, exist_ok=True)
+     to_save = {'params': self.params, 'opt_state': self.opt_state}
+     path = os.path.join(self.checkpoint_path, file_name)
+     with open(path, 'wb') as f:
+       pickle.dump(to_save, f)
+
+
+ class BaselineModelChunked(BaselineModel):
+   """Model that processes time-chunked data.
+
+   Unlike `BaselineModel`, which processes full samples, `BaselineModelChunked`
+   processes fixed-timelength chunks of data. Each tensor of inputs and hints
+   has dimensions chunk_length x batch_size x ... The beginning of a new
+   sample within the chunk is signalled by a tensor called `is_first` of
+   dimensions chunk_length x batch_size.
+
+   The chunked model is intended for training. For validation and test, use
+   `BaselineModel`.
+   """
+
+   mp_states: List[List[nets.MessagePassingStateChunked]]
+   init_mp_states: List[List[nets.MessagePassingStateChunked]]
+
+   def _create_net_fns(self, hidden_dim, encode_hints, processor_factory,
+                       use_lstm, encoder_init, dropout_prob,
+                       hint_teacher_forcing, hint_repred_mode):
+     def _use_net(*args, **kwargs):
+       return nets.NetChunked(
+           self._spec, hidden_dim, encode_hints, self.decode_hints,
+           processor_factory, use_lstm, encoder_init, dropout_prob,
+           hint_teacher_forcing, hint_repred_mode,
+           self.nb_dims, self.nb_msg_passing_steps)(*args, **kwargs)
+
+     self.net_fn = hk.transform(_use_net)
+     pmap_args = dict(axis_name='batch', devices=jax.local_devices())
+     n_devices = jax.local_device_count()
+     func, static_arg, extra_args = (
+         (jax.jit, 'static_argnums', {}) if n_devices == 1 else
+         (jax.pmap, 'static_broadcasted_argnums', pmap_args))
+     pmean = functools.partial(jax.lax.pmean, axis_name='batch')
+     self._maybe_pmean = pmean if n_devices > 1 else lambda x: x
+     extra_args[static_arg] = 4
+     self.jitted_grad = func(self._compute_grad, **extra_args)
+     extra_args[static_arg] = 5
+     self.jitted_feedback = func(self._feedback, donate_argnums=[0, 4],
+                                 **extra_args)
+     extra_args[static_arg] = [3, 4]
+     self.jitted_accum_opt_update = func(accum_opt_update, donate_argnums=[0, 2],
+                                         **extra_args)
+
+   def _init_mp_state(self, features_list: List[List[_FeaturesChunked]],
+                      rng_key: _Array):
+     def _empty_mp_state():
+       return nets.MessagePassingStateChunked(  # pytype: disable=wrong-arg-types  # numpy-scalars
+           inputs=None, hints=None, is_first=None,
+           hint_preds=None, hiddens=None, lstm_state=None)
+     empty_mp_states = [[_empty_mp_state() for _ in f] for f in features_list]
+     dummy_params = [self.net_fn.init(rng_key, f, e, False,
+                                      init_mp_state=True, algorithm_index=-1)
+                     for (f, e) in zip(features_list, empty_mp_states)]
+     mp_states = [
+         self.net_fn.apply(d, rng_key, f, e, False,
+                           init_mp_state=True, algorithm_index=-1)[1]
+         for (d, f, e) in zip(dummy_params, features_list, empty_mp_states)]
+     return mp_states
+
+   def init(self,
+            features: List[List[_FeaturesChunked]],
+            seed: _Seed):
+     self.mp_states = self._init_mp_state(features,
+                                          jax.random.PRNGKey(seed))  # pytype: disable=wrong-arg-types  # jax-ndarray
+     self.init_mp_states = [list(x) for x in self.mp_states]
+     self.params = self.net_fn.init(
+         jax.random.PRNGKey(seed), features[0], self.mp_states[0],  # pytype: disable=wrong-arg-types  # jax-ndarray
+         True, init_mp_state=False, algorithm_index=-1)
+     self.opt_state = self.opt.init(self.params)
+     # We will use the optimizer state skeleton for traversal when we
+     # want to avoid updating the state of params of untrained algorithms.
+     self.opt_state_skeleton = self.opt.init(jnp.zeros(1))
+
+   def predict(self, rng_key: hk.PRNGSequence, features: _FeaturesChunked,
+               algorithm_index: Optional[int] = None):
+     """Inference not implemented. Chunked model intended for training only."""
+     raise NotImplementedError
+
+   def _loss(self, params, rng_key, feedback, mp_state, algorithm_index):
+     (output_preds, hint_preds), mp_state = self.net_fn.apply(
+         params, rng_key, [feedback.features],
+         [mp_state],
+         repred=False,
+         init_mp_state=False,
+         algorithm_index=algorithm_index)
+
+     nb_nodes = _nb_nodes(feedback, is_chunked=True)
+
+     total_loss = 0.0
+     is_first = feedback.features.is_first
+     is_last = feedback.features.is_last
+
+     # Calculate output loss.
+     for truth in feedback.outputs:
+       total_loss += losses.output_loss_chunked(
+           truth=truth,
+           pred=output_preds[truth.name],
+           is_last=is_last,
+           nb_nodes=nb_nodes,
+       )
+
+     # Optionally accumulate hint losses.
+     if self.decode_hints:
+       for truth in feedback.features.hints:
+         loss = losses.hint_loss_chunked(
+             truth=truth,
+             pred=hint_preds[truth.name],
+             is_first=is_first,
+             nb_nodes=nb_nodes,
+         )
+         total_loss += loss
+
+     return total_loss, (mp_state,)
+
+   def _compute_grad(self, params, rng_key, feedback, mp_state, algorithm_index):
+     (lss, (mp_state,)), grads = jax.value_and_grad(self._loss, has_aux=True)(
+         params, rng_key, feedback, mp_state, algorithm_index)
+     return self._maybe_pmean(lss), mp_state, self._maybe_pmean(grads)
610
+
611
+ def _feedback(self, params, rng_key, feedback, mp_state, opt_state,
612
+ algorithm_index):
613
+ (lss, (mp_state,)), grads = jax.value_and_grad(self._loss, has_aux=True)(
614
+ params, rng_key, feedback, mp_state, algorithm_index)
615
+ grads = self._maybe_pmean(grads)
616
+ params, opt_state = self._update_params(params, grads, opt_state,
617
+ algorithm_index)
618
+ lss = self._maybe_pmean(lss)
619
+ return lss, params, opt_state, mp_state
620
+
621
+ def compute_grad(
622
+ self,
623
+ rng_key: hk.PRNGSequence,
624
+ feedback: _Feedback,
625
+ algorithm_index: Optional[Tuple[int, int]] = None,
626
+ ) -> Tuple[float, _Array]:
627
+ """Compute gradients."""
628
+
629
+ if algorithm_index is None:
630
+ assert len(self._spec) == 1
631
+ algorithm_index = (0, 0)
632
+ length_index, algorithm_index = algorithm_index
633
+ # Reusing init_mp_state improves performance.
634
+ # The next, commented-out line should be used for proper state keeping.
635
+ # mp_state = self.mp_states[length_index][algorithm_index]
636
+ mp_state = self.init_mp_states[length_index][algorithm_index]
637
+ rng_keys = _maybe_pmap_rng_key(rng_key) # pytype: disable=wrong-arg-types # numpy-scalars
638
+ feedback = _maybe_pmap_reshape(feedback, split_axis=1)
639
+ mp_state = _maybe_pmap_reshape(mp_state, split_axis=0)
640
+
641
+ loss, mp_state, grads = self.jitted_grad(
642
+ self._device_params, rng_keys, feedback, mp_state, algorithm_index)
643
+ loss = _maybe_pick_first_pmapped(loss)
644
+ grads = _maybe_pick_first_pmapped(grads)
645
+ mp_state = _maybe_restack_from_pmap(mp_state)
646
+ self.mp_states[length_index][algorithm_index] = mp_state
647
+ return loss, grads
648
+
649
+ def feedback(self, rng_key: hk.PRNGSequence, feedback: _Feedback,
650
+ algorithm_index=None) -> float:
651
+ if algorithm_index is None:
652
+ assert len(self._spec) == 1
653
+ algorithm_index = (0, 0)
654
+ length_index, algorithm_index = algorithm_index
655
+ # Reusing init_mp_state improves performance.
656
+ # The next, commented-out line should be used for proper state keeping.
657
+ # mp_state = self.mp_states[length_index][algorithm_index]
658
+ mp_state = self.init_mp_states[length_index][algorithm_index]
659
+ rng_keys = _maybe_pmap_rng_key(rng_key) # pytype: disable=wrong-arg-types # numpy-scalars
660
+ feedback = _maybe_pmap_reshape(feedback, split_axis=1)
661
+ mp_state = _maybe_pmap_reshape(mp_state, split_axis=0)
662
+ loss, self._device_params, self._device_opt_state, mp_state = (
663
+ self.jitted_feedback(
664
+ self._device_params, rng_keys, feedback,
665
+ mp_state, self._device_opt_state, algorithm_index))
666
+ loss = _maybe_pick_first_pmapped(loss)
667
+ mp_state = _maybe_restack_from_pmap(mp_state)
668
+ self.mp_states[length_index][algorithm_index] = mp_state
669
+ return loss
670
+
671
+ def verbose_loss(self, *args, **kwargs):
672
+ raise NotImplementedError
673
+
674
+
675
+ def _nb_nodes(feedback: _Feedback, is_chunked) -> int:
676
+ for inp in feedback.features.inputs:
677
+ if inp.location in [_Location.NODE, _Location.EDGE]:
678
+ if is_chunked:
679
+ return inp.data.shape[2] # inputs are time x batch x nodes x ...
680
+ else:
681
+ return inp.data.shape[1] # inputs are batch x nodes x ...
682
+ assert False
683
+
684
+
685
+ def _param_in_processor(module_name):
686
+ return processors.PROCESSOR_TAG in module_name
687
+
688
+
689
+ def _filter_out_processor(params: hk.Params) -> hk.Params:
690
+ return hk.data_structures.filter(
691
+ lambda module_name, n, v: not _param_in_processor(module_name), params)
692
+
693
+
694
+ def _filter_in_processor(params: hk.Params) -> hk.Params:
695
+ return hk.data_structures.filter(
696
+ lambda module_name, n, v: _param_in_processor(module_name), params)
697
+
698
+
699
+ def _is_not_done_broadcast(lengths, i, tensor):
700
+ is_not_done = (lengths > i + 1) * 1.0
701
+ while len(is_not_done.shape) < len(tensor.shape):
702
+ is_not_done = jnp.expand_dims(is_not_done, -1)
703
+ return is_not_done
704
+
705
+
706
+ def accum_opt_update(params, grads, opt_state, opt, freeze_processor):
707
+ """Update params from gradients collected from several algorithms."""
708
+ # Average the gradients over all algos
709
+ grads = jax.tree_util.tree_map(
710
+ lambda *x: sum(x) / (sum([jnp.any(k) for k in x]) + 1e-12), *grads)
711
+ updates, opt_state = opt.update(grads, opt_state)
712
+ if freeze_processor:
713
+ params_subset = _filter_out_processor(params)
714
+ assert len(params) > len(params_subset)
715
+ assert params_subset
716
+ updates_subset = _filter_out_processor(updates)
717
+ new_params = optax.apply_updates(params_subset, updates_subset)
718
+ new_params = hk.data_structures.merge(params, new_params)
719
+ else:
720
+ new_params = optax.apply_updates(params, updates)
721
+
722
+ return new_params, opt_state
723
+
724
+
725
+ @functools.partial(jax.jit, static_argnames=['opt'])
726
+ def opt_update(opt, flat_grads, flat_opt_state):
727
+ return opt.update(flat_grads, flat_opt_state)
728
+
729
+
730
+ def filter_null_grads(grads, opt, opt_state, opt_state_skeleton, algo_idx):
731
+ """Compute updates ignoring params that have no gradients.
732
+
733
+ This prevents untrained params (e.g., encoders/decoders for algorithms
734
+ that are not being trained) from accumulating, e.g., momentum from spurious
735
+ zero gradients.
736
+
737
+ Note: this works as intended for "per-parameter" optimizer state, such as
738
+ momentum. However, when the optimizer has some global state (such as the
739
+ step counts in Adam), the global state will be updated every time,
740
+ affecting also future updates of parameters that had null gradients in the
741
+ current step.
742
+
743
+ Args:
744
+ grads: Gradients for all parameters.
745
+ opt: Optax optimizer.
746
+ opt_state: Optimizer state.
747
+ opt_state_skeleton: A "skeleton" of optimizer state that has been
748
+ initialized with scalar parameters. This serves to traverse each parameter
749
+ of the optimizer state during the opt state update.
750
+ algo_idx: Index of algorithm, to filter out unused encoders/decoders.
751
+ If None, no filtering happens.
752
+ Returns:
753
+ Updates and new optimizer state, where the parameters with null gradient
754
+ have not been taken into account.
755
+ """
756
+ def _keep_in_algo(k, v):
757
+ """Ignore params of encoders/decoders irrelevant for this algo."""
758
+ # Note: in shared pointer decoder modes, we should exclude shared params
759
+ # for algos that do not have pointer outputs.
760
+ if ((processors.PROCESSOR_TAG in k) or
761
+ (f'algo_{algo_idx}_' in k)):
762
+ return v
763
+ return jax.tree_util.tree_map(lambda x: None, v)
764
+
765
+ if algo_idx is None:
766
+ masked_grads = grads
767
+ else:
768
+ masked_grads = {k: _keep_in_algo(k, v) for k, v in grads.items()}
769
+ flat_grads, treedef = jax.tree_util.tree_flatten(masked_grads)
770
+ flat_opt_state = jax.tree_util.tree_map(
771
+ lambda _, x: x # pylint:disable=g-long-lambda
772
+ if isinstance(x, (np.ndarray, jax.Array))
773
+ else treedef.flatten_up_to(x),
774
+ opt_state_skeleton,
775
+ opt_state,
776
+ )
777
+
778
+ # Compute updates only for the params with gradient.
779
+ flat_updates, flat_opt_state = opt_update(opt, flat_grads, flat_opt_state)
780
+
781
+ def unflatten(flat, original):
782
+ """Restore tree structure, filling missing (None) leaves with original."""
783
+ if isinstance(flat, (np.ndarray, jax.Array)):
784
+ return flat
785
+ return jax.tree_util.tree_map(lambda x, y: x if y is None else y, original,
786
+ treedef.unflatten(flat))
787
+
788
+ # Restore the state and updates tree structure.
789
+ new_opt_state = jax.tree_util.tree_map(lambda _, x, y: unflatten(x, y),
790
+ opt_state_skeleton, flat_opt_state,
791
+ opt_state)
792
+ updates = unflatten(flat_updates,
793
+ jax.tree_util.tree_map(lambda x: 0., grads))
794
+ return updates, new_opt_state
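The gradient averaging in `accum_opt_update` divides each leaf by the number of algorithms that actually contributed a nonzero gradient to it, not by the total number of algorithms. A minimal standalone sketch of that averaging (toy gradient trees, not library code):

import jax
import jax.numpy as jnp

# Two per-algorithm gradient trees; algorithm 1 produced no encoder gradient.
grads_algo_0 = {'encoder': jnp.array([1.0, 2.0]), 'processor': jnp.array([4.0])}
grads_algo_1 = {'encoder': jnp.zeros(2), 'processor': jnp.array([2.0])}

avg = jax.tree_util.tree_map(
    lambda *x: sum(x) / (sum([jnp.any(k) for k in x]) + 1e-12),
    *[grads_algo_0, grads_algo_1])
# avg['encoder'] is ~[1., 2.] (one contributor); avg['processor'] is ~[3.] (two).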
benchmarks/CLRS/env/baselines_test.py ADDED
@@ -0,0 +1,294 @@
1
+ # Copyright 2022 DeepMind Technologies Limited. All Rights Reserved.
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+ # ==============================================================================
15
+
16
+ """Unit tests for `baselines.py`."""
17
+
18
+ import copy
19
+ import functools
20
+ from typing import Generator
21
+
22
+ from absl.testing import absltest
23
+ from absl.testing import parameterized
24
+ import chex
25
+
26
+ from clrs._src import baselines
27
+ from clrs._src import dataset
28
+ from clrs._src import probing
29
+ from clrs._src import processors
30
+ from clrs._src import samplers
31
+ from clrs._src import specs
32
+
33
+ import haiku as hk
34
+ import jax
35
+ import numpy as np
36
+
37
+ _Array = np.ndarray
38
+
39
+
40
+ def _error(x, y):
41
+ return np.sum(np.abs(x-y))
42
+
43
+
44
+ def _make_sampler(algo: str, length: int) -> samplers.Sampler:
45
+ sampler, _ = samplers.build_sampler(
46
+ algo,
47
+ seed=samplers.CLRS30['val']['seed'],
48
+ num_samples=samplers.CLRS30['val']['num_samples'],
49
+ length=length,
50
+ )
51
+ return sampler
52
+
53
+
54
+ def _without_permutation(feedback):
55
+ """Replace should-be permutations with pointers."""
56
+ outputs = []
57
+ for x in feedback.outputs:
58
+ if x.type_ != specs.Type.SHOULD_BE_PERMUTATION:
59
+ outputs.append(x)
60
+ continue
61
+ assert x.location == specs.Location.NODE
62
+ outputs.append(probing.DataPoint(name=x.name, location=x.location,
63
+ type_=specs.Type.POINTER, data=x.data))
64
+ return feedback._replace(outputs=outputs)
65
+
66
+
67
+ def _make_iterable_sampler(
68
+ algo: str, batch_size: int,
69
+ length: int) -> Generator[samplers.Feedback, None, None]:
70
+ sampler = _make_sampler(algo, length)
71
+ while True:
72
+ yield _without_permutation(sampler.next(batch_size))
73
+
74
+
75
+ def _remove_permutation_from_spec(spec):
76
+ """Modify spec to turn permutation type to pointer."""
77
+ new_spec = {}
78
+ for k in spec:
79
+ if (spec[k][1] == specs.Location.NODE and
80
+ spec[k][2] == specs.Type.SHOULD_BE_PERMUTATION):
81
+ new_spec[k] = (spec[k][0], spec[k][1], specs.Type.POINTER)
82
+ else:
83
+ new_spec[k] = spec[k]
84
+ return new_spec
85
+
86
+
87
+ class BaselinesTest(parameterized.TestCase):
88
+
89
+ def test_full_vs_chunked(self):
90
+ """Test that chunking does not affect gradients."""
91
+
92
+ batch_size = 4
93
+ length = 8
94
+ algo = 'insertion_sort'
95
+ spec = _remove_permutation_from_spec(specs.SPECS[algo])
96
+ rng_key = jax.random.PRNGKey(42)
97
+
98
+ full_ds = _make_iterable_sampler(algo, batch_size, length)
99
+ chunked_ds = dataset.chunkify(
100
+ _make_iterable_sampler(algo, batch_size, length),
101
+ length)
102
+ double_chunked_ds = dataset.chunkify(
103
+ _make_iterable_sampler(algo, batch_size, length),
104
+ length * 2)
105
+
106
+ full_batches = [next(full_ds) for _ in range(2)]
107
+ chunked_batches = [next(chunked_ds) for _ in range(2)]
108
+ double_chunk_batch = next(double_chunked_ds)
109
+
110
+ with chex.fake_jit(): # jitting makes test longer
111
+
112
+ processor_factory = processors.get_processor_factory(
113
+ 'mpnn', use_ln=False, nb_triplet_fts=0)
114
+ common_args = dict(processor_factory=processor_factory, hidden_dim=8,
115
+ learning_rate=0.01,
116
+ decode_hints=True, encode_hints=True)
117
+
118
+ b_full = baselines.BaselineModel(
119
+ spec, dummy_trajectory=full_batches[0], **common_args)
120
+ b_full.init(full_batches[0].features, seed=42) # pytype: disable=wrong-arg-types # jax-ndarray
121
+ full_params = b_full.params
122
+ full_loss_0 = b_full.feedback(rng_key, full_batches[0])
123
+ b_full.params = full_params
124
+ full_loss_1 = b_full.feedback(rng_key, full_batches[1])
125
+ new_full_params = b_full.params
126
+
127
+ b_chunked = baselines.BaselineModelChunked(
128
+ spec, dummy_trajectory=chunked_batches[0], **common_args)
129
+ b_chunked.init([[chunked_batches[0].features]], seed=42) # pytype: disable=wrong-arg-types # jax-ndarray
130
+ chunked_params = b_chunked.params
131
+ jax.tree_util.tree_map(np.testing.assert_array_equal, full_params,
132
+ chunked_params)
133
+ chunked_loss_0 = b_chunked.feedback(rng_key, chunked_batches[0])
134
+ b_chunked.params = chunked_params
135
+ chunked_loss_1 = b_chunked.feedback(rng_key, chunked_batches[1])
136
+ new_chunked_params = b_chunked.params
137
+
138
+ b_chunked.params = chunked_params
139
+ double_chunked_loss = b_chunked.feedback(rng_key, double_chunk_batch)
140
+
141
+ # Test that losses match
142
+ np.testing.assert_allclose(full_loss_0, chunked_loss_0, rtol=1e-4)
143
+ np.testing.assert_allclose(full_loss_1, chunked_loss_1, rtol=1e-4)
144
+ np.testing.assert_allclose(full_loss_0 + full_loss_1,
145
+ 2 * double_chunked_loss,
146
+ rtol=1e-4)
147
+
148
+ # Test that gradients are the same (parameters changed equally).
149
+ # First check that gradients were not zero, i.e., parameters have changed.
150
+ param_change, _ = jax.tree_util.tree_flatten(
151
+ jax.tree_util.tree_map(_error, full_params, new_full_params))
152
+ self.assertGreater(np.mean(param_change), 0.1)
153
+ # Now check that full and chunked gradients are the same.
154
+ jax.tree_util.tree_map(
155
+ functools.partial(np.testing.assert_allclose, rtol=1e-4),
156
+ new_full_params, new_chunked_params)
157
+
158
+ def test_multi_vs_single(self):
159
+ """Test that multi = single when we only train one of the algorithms."""
160
+
161
+ batch_size = 4
162
+ length = 16
163
+ algos = ['insertion_sort', 'activity_selector', 'bfs']
164
+ spec = [_remove_permutation_from_spec(specs.SPECS[algo]) for algo in algos]
165
+ rng_key = jax.random.PRNGKey(42)
166
+
167
+ full_ds = [_make_iterable_sampler(algo, batch_size, length)
168
+ for algo in algos]
169
+ full_batches = [next(ds) for ds in full_ds]
170
+ full_batches_2 = [next(ds) for ds in full_ds]
171
+
172
+ with chex.fake_jit(): # jitting makes test longer
173
+
174
+ processor_factory = processors.get_processor_factory(
175
+ 'mpnn', use_ln=False, nb_triplet_fts=0)
176
+ common_args = dict(processor_factory=processor_factory, hidden_dim=8,
177
+ learning_rate=0.01,
178
+ decode_hints=True, encode_hints=True)
179
+
180
+ b_single = baselines.BaselineModel(
181
+ spec[0], dummy_trajectory=full_batches[0], **common_args)
182
+ b_multi = baselines.BaselineModel(
183
+ spec, dummy_trajectory=full_batches, **common_args)
184
+ b_single.init(full_batches[0].features, seed=0) # pytype: disable=wrong-arg-types # jax-ndarray
185
+ b_multi.init([f.features for f in full_batches], seed=0) # pytype: disable=wrong-arg-types # jax-ndarray
186
+
187
+ single_params = []
188
+ single_losses = []
189
+ multi_params = []
190
+ multi_losses = []
191
+
192
+ single_params.append(copy.deepcopy(b_single.params))
193
+ single_losses.append(b_single.feedback(rng_key, full_batches[0]))
194
+ single_params.append(copy.deepcopy(b_single.params))
195
+ single_losses.append(b_single.feedback(rng_key, full_batches_2[0]))
196
+ single_params.append(copy.deepcopy(b_single.params))
197
+
198
+ multi_params.append(copy.deepcopy(b_multi.params))
199
+ multi_losses.append(b_multi.feedback(rng_key, full_batches[0],
200
+ algorithm_index=0))
201
+ multi_params.append(copy.deepcopy(b_multi.params))
202
+ multi_losses.append(b_multi.feedback(rng_key, full_batches_2[0],
203
+ algorithm_index=0))
204
+ multi_params.append(copy.deepcopy(b_multi.params))
205
+
206
+ # Test that losses match
207
+ np.testing.assert_array_equal(single_losses, multi_losses)
208
+ # Test that loss decreased
209
+ assert single_losses[1] < single_losses[0]
210
+
211
+ # Test that param changes were the same in single and multi-algorithm
212
+ for single, multi in zip(single_params, multi_params):
213
+ assert hk.data_structures.is_subset(subset=single, superset=multi)
214
+ for module_name, params in single.items():
215
+ jax.tree_util.tree_map(np.testing.assert_array_equal, params,
216
+ multi[module_name])
217
+
218
+ # Test that params change for the trained algorithm, but not the others
219
+ for module_name, params in multi_params[0].items():
220
+ param_changes = jax.tree_util.tree_map(lambda a, b: np.sum(np.abs(a - b)),
221
+ params,
222
+ multi_params[1][module_name])
223
+ param_change = sum(param_changes.values())
224
+ if module_name in single_params[0]: # params of trained algorithm
225
+ assert param_change > 1e-3
226
+ else: # params of non-trained algorithms
227
+ assert param_change == 0.0
228
+
229
+ @parameterized.parameters(True, False)
230
+ def test_multi_algorithm_idx(self, is_chunked):
231
+ """Test that algorithm selection works as intended."""
232
+
233
+ batch_size = 4
234
+ length = 8
235
+ algos = ['insertion_sort', 'activity_selector', 'bfs']
236
+ spec = [_remove_permutation_from_spec(specs.SPECS[algo]) for algo in algos]
237
+ rng_key = jax.random.PRNGKey(42)
238
+
239
+ if is_chunked:
240
+ ds = [dataset.chunkify(_make_iterable_sampler(algo, batch_size, length),
241
+ 2 * length) for algo in algos]
242
+ else:
243
+ ds = [_make_iterable_sampler(algo, batch_size, length) for algo in algos]
244
+ batches = [next(d) for d in ds]
245
+
246
+ processor_factory = processors.get_processor_factory(
247
+ 'mpnn', use_ln=False, nb_triplet_fts=0)
248
+ common_args = dict(processor_factory=processor_factory, hidden_dim=8,
249
+ learning_rate=0.01,
250
+ decode_hints=True, encode_hints=True)
251
+ if is_chunked:
252
+ baseline = baselines.BaselineModelChunked(
253
+ spec, dummy_trajectory=batches, **common_args)
254
+ baseline.init([[f.features for f in batches]], seed=0) # pytype: disable=wrong-arg-types # jax-ndarray
255
+ else:
256
+ baseline = baselines.BaselineModel(
257
+ spec, dummy_trajectory=batches, **common_args)
258
+ baseline.init([f.features for f in batches], seed=0) # pytype: disable=wrong-arg-types # jax-ndarray
259
+
260
+ # Find out what parameters change when we train each algorithm
261
+ def _change(x, y):
262
+ changes = {}
263
+ for module_name, params in x.items():
264
+ changes[module_name] = sum(
265
+ jax.tree_util.tree_map(
266
+ lambda a, b: np.sum(np.abs(a-b)), params, y[module_name]
267
+ ).values())
268
+ return changes
269
+
270
+ param_changes = []
271
+ for algo_idx in range(len(algos)):
272
+ init_params = copy.deepcopy(baseline.params)
273
+ _ = baseline.feedback(
274
+ rng_key,
275
+ batches[algo_idx],
276
+ algorithm_index=(0, algo_idx) if is_chunked else algo_idx)
277
+ param_changes.append(_change(init_params, baseline.params))
278
+
279
+ # Test that non-changing parameters correspond to encoders/decoders
280
+ # associated with the non-trained algorithms
281
+ unchanged = [[k for k in pc if pc[k] == 0] for pc in param_changes]
282
+
283
+ def _get_other_algos(algo_idx, modules):
284
+ return set([k for k in modules if '_construct_encoders_decoders' in k
285
+ and f'algo_{algo_idx}' not in k])
286
+
287
+ for algo_idx in range(len(algos)):
288
+ expected_unchanged = _get_other_algos(algo_idx, baseline.params.keys())
289
+ self.assertNotEmpty(expected_unchanged)
290
+ self.assertSetEqual(expected_unchanged, set(unchanged[algo_idx]))
291
+
292
+
293
+ if __name__ == '__main__':
294
+ absltest.main()
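The `chex.fake_jit()` context used in the tests above replaces `jax.jit` with a no-op, so model code runs eagerly and is easier to step through, at the cost of per-call speed. A minimal sketch of the same trick (toy function; only `chex` and `jax` assumed):

import chex
import jax
import jax.numpy as jnp

with chex.fake_jit():
  double = jax.jit(lambda x: x * 2)  # jax.jit is a no-op inside the context
  print(double(jnp.ones(3)))         # executes eagerly, no compilation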
benchmarks/CLRS/env/data_description.txt ADDED
@@ -0,0 +1,35 @@
1
+ The CLRS Algorithmic Reasoning Benchmark
2
+
3
+ Learning representations of algorithms is an emerging area of machine learning, seeking to bridge concepts from neural networks with classical algorithms. The CLRS Algorithmic Reasoning Benchmark (CLRS) consolidates and extends previous work toward evaluating algorithmic reasoning by providing a suite of implementations of classical algorithms. These algorithms have been selected from the third edition of the standard Introduction to Algorithms by Cormen, Leiserson, Rivest and Stein.
4
+
5
+ Algorithms as graphs
6
+ CLRS implements the selected algorithms in an idiomatic way, aligning as closely as possible with the original CLRS 3ed pseudocode. By controlling the input data distribution to conform to the preconditions, we are able to automatically generate input/output pairs. We additionally provide trajectories of "hints" that expose the internal state of each algorithm, both to optionally simplify the learning challenge and to distinguish between different algorithms that solve the same overall task (e.g. sorting).
7
+
8
+ In the most generic sense, algorithms can be seen as manipulating sets of objects, along with any relations between them (which can themselves be decomposed into binary relations). Accordingly, we study all of the algorithms in this benchmark using a graph representation. Where objects obey a more strictly ordered structure (e.g. arrays or rooted trees), we impose this ordering through the inclusion of predecessor links.
9
+
10
+ How it works
11
+ For each algorithm, we provide a canonical set of train, eval and test trajectories for benchmarking out-of-distribution generalization.
12
+
13
+ Trajectories Problem Size
14
+ Train 1000 16
15
+ Eval 32 x multiplier 16
16
+ Test 32 x multiplier 64
17
+ Here, "problem size" refers to e.g. the length of an array or number of nodes in a graph, depending on the algorithm. "multiplier" is an algorithm-specific factor that increases the number of available eval and test trajectories to compensate for paucity of evaluation signals. "multiplier" is 1 for all algorithms except:
18
+
19
+ Maximum subarray (Kadane), for which "multiplier" is 32.
20
+ Quick select, minimum, binary search, string matchers (both naive and KMP), and segment intersection, for which "multiplier" is 64.
21
+ The trajectories can be used like so:
22
+
23
+ train_ds, num_samples, spec = clrs.create_dataset(
24
+ folder='/tmp/CLRS30', algorithm='bfs',
25
+ split='train', batch_size=32)
26
+
27
+ for i, feedback in enumerate(train_ds.as_numpy_iterator()):
28
+ if i == 0:
29
+ model.init(feedback.features, initial_seed)
30
+ loss = model.feedback(rng_key, feedback)
31
+ Here, feedback is a namedtuple with the following structure:
32
+
33
+ Feedback = collections.namedtuple('Feedback', ['features', 'outputs'])
34
+ Features = collections.namedtuple('Features', ['inputs', 'hints', 'lengths'])
35
+ where the content of Features can be used for training and outputs is reserved for evaluation. Each field of the tuple is an ndarray with a leading batch dimension. Because hints are provided for the full algorithm trajectory, these contain an additional time dimension padded up to the maximum length max(T) of any trajectory within the dataset. The lengths field specifies the true length t <= max(T) for each trajectory, which can be used e.g. for loss masking.
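A minimal sketch of the loss masking mentioned above (toy shapes, numpy only; not part of the benchmark): per-step hint losses beyond a trajectory's true length are zeroed out before averaging.

import numpy as np

max_t, batch_size = 5, 3
lengths = np.array([3, 5, 2])                   # true length per trajectory
per_step_loss = np.ones((max_t, batch_size))    # time x batch
step = np.arange(max_t)[:, None]                # time x 1
mask = (step < lengths[None, :]).astype(float)  # 1.0 while step < length
masked_mean = (per_step_loss * mask).sum() / mask.sum()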
benchmarks/CLRS/env/dataset.py ADDED
@@ -0,0 +1,326 @@
1
+ # Copyright 2022 DeepMind Technologies Limited. All Rights Reserved.
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+ # ==============================================================================
15
+ """CLRS dataset."""
16
+
17
+ import dataclasses
18
+
19
+ import functools
20
+ from typing import Iterator
21
+
22
+ from clrs._src import probing
23
+ from clrs._src import samplers
24
+ from clrs._src import specs
25
+
26
+ import jax
27
+ import numpy as np
28
+ import tensorflow as tf
29
+ import tensorflow_datasets as tfds
30
+
31
+
32
+ def _correct_axis_filtering(tensor, index, name):
33
+ if 'hint_' in name:
34
+ return tensor[:, index]
35
+ else:
36
+ return tensor[index]
37
+
38
+
39
+ @dataclasses.dataclass
40
+ class CLRSConfig(tfds.core.BuilderConfig):
41
+ """Specify the split in the variant because they have different shapes."""
42
+ split: str = ''
43
+
44
+
45
+ DEFAULT_BUILDER_CONFIGS = []
46
+
47
+
48
+ def _build_default_builder_configs():
49
+ for split in ['train', 'val', 'test']:
50
+ for alg in specs.CLRS_30_ALGS:
51
+ DEFAULT_BUILDER_CONFIGS.append(
52
+ CLRSConfig(name=f'{alg}_{split}', split=split))
53
+
54
+
55
+ _build_default_builder_configs()
56
+
57
+
58
+ class CLRSDataset(tfds.core.GeneratorBasedBuilder):
59
+ """DatasetBuilder for the CLRS dataset."""
60
+
61
+ VERSION = tfds.core.Version('1.0.0')
62
+ RELEASE_NOTES = {
63
+ '1.0.0': 'Initial release.',
64
+ }
65
+ BUILDER_CONFIGS = DEFAULT_BUILDER_CONFIGS
66
+
67
+ _instantiated_dataset = None
68
+ _instantiated_dataset_name = ''
69
+ _instantiated_dataset_split = ''
70
+
71
+ def _num_samples(self, algorithm_name):
72
+ num_samples = samplers.CLRS30[self._builder_config.split]['num_samples'] # pytype: disable=attribute-error # always-use-return-annotations
73
+ if self._builder_config.split != 'train': # pytype: disable=attribute-error # always-use-return-annotations
74
+ # Generate more samples for those algorithms in which the number of
75
+ # signals is small.
76
+ num_samples *= specs.CLRS_30_ALGS_SETTINGS[algorithm_name][
77
+ 'num_samples_multiplier']
78
+ return num_samples
79
+
80
+ def _create_data(self, single_sample):
81
+ algorithm_name = '_'.join(self._builder_config.name.split('_')[:-1])
82
+ num_samples = self._num_samples(algorithm_name)
83
+ sampler, _ = samplers.build_sampler(
84
+ algorithm_name,
85
+ seed=samplers.CLRS30[self._builder_config.split]['seed'], # pytype: disable=attribute-error # always-use-return-annotations
86
+ num_samples=num_samples,
87
+ length=samplers.CLRS30[self._builder_config.split]['length'], # pytype: disable=attribute-error # always-use-return-annotations
88
+ )
89
+ sampled_dataset = sampler.next(batch_size=1 if single_sample else None)
90
+ data = {'input_' + t.name: t.data for t in sampled_dataset.features.inputs}
91
+ # All other data points have input_, hint_, and output_ prefixes, so we
92
+ # guarantee that this key is unused.
93
+ data['lengths'] = sampled_dataset.features.lengths
94
+ data.update({'output_' + t.name: t.data for t in sampled_dataset.outputs})
95
+ data.update({
96
+ 'hint_' + t.name: t.data for t in sampled_dataset.features.hints})
97
+ self._instantiated_dataset = data
98
+
99
+ def _info(self) -> tfds.core.DatasetInfo:
100
+ if tf.io.gfile.exists(self.data_dir):
101
+ info = tfds.core.DatasetInfo(builder=self)
102
+ info.read_from_directory(self.data_dir)
103
+ return info
104
+
105
+ if (self._instantiated_dataset_name != self._builder_config.name
106
+ or self._instantiated_dataset_split != self._builder_config.split): # pytype: disable=attribute-error # always-use-return-annotations
107
+ self._create_data(single_sample=True)
108
+
109
+ data = {k: _correct_axis_filtering(v, 0, k)
110
+ for k, v in self._instantiated_dataset.items()}
111
+ data_info = {
112
+ k: tfds.features.Tensor(shape=v.shape, dtype=tf.dtypes.as_dtype(
113
+ v.dtype)) for k, v in data.items()}
114
+ return tfds.core.DatasetInfo(
115
+ builder=self,
116
+ features=tfds.features.FeaturesDict(data_info),
117
+ )
118
+
119
+ def _split_generators(self, dl_manager: tfds.download.DownloadManager):
120
+ """Download the data and define splits."""
121
+ if (self._instantiated_dataset_name != self._builder_config.name
122
+ or self._instantiated_dataset_split != self._builder_config.split): # pytype: disable=attribute-error # always-use-return-annotations
123
+ self._create_data(single_sample=False)
124
+ self._instantiated_dataset_name = self._builder_config.name
125
+ self._instantiated_dataset_split = self._builder_config.split # pytype: disable=attribute-error # always-use-return-annotations
126
+ return {self._builder_config.split: self._generate_examples()} # pytype: disable=attribute-error # always-use-return-annotations
127
+
128
+ def _generate_examples(self):
129
+ """Generator of examples for each split."""
130
+ algorithm_name = '_'.join(self._builder_config.name.split('_')[:-1])
131
+ for i in range(self._num_samples(algorithm_name)):
132
+ data = {k: _correct_axis_filtering(v, i, k)
133
+ for k, v in self._instantiated_dataset.items()}
134
+ yield str(i), data
135
+
136
+
137
+ def _get_clrs_file_name():
138
+ return f'CLRS30_v{CLRSDataset.VERSION}.tar.gz'
139
+
140
+
141
+ def get_dataset_gcp_url():
142
+ return f'https://storage.googleapis.com/dm-clrs/{_get_clrs_file_name()}'
143
+
144
+
145
+ def get_clrs_folder():
146
+ return f'CLRS30_v{CLRSDataset.VERSION}'
147
+
148
+
149
+ def _preprocess(data_point, algorithm=None):
150
+ """Convert sampled inputs into DataPoints."""
151
+ inputs = []
152
+ outputs = []
153
+ hints = []
154
+ lengths = None
155
+
156
+ for name, data in data_point.items():
157
+ if name == 'lengths':
158
+ lengths = data
159
+ continue
160
+ data_point_name = name.split('_')
161
+ name = '_'.join(data_point_name[1:])
162
+ (stage, location, dp_type) = specs.SPECS[algorithm][name]
163
+ assert stage == data_point_name[0]
164
+ if stage == specs.Stage.HINT:
165
+ data = tf.experimental.numpy.swapaxes(data, 0, 1)
166
+ dp = probing.DataPoint(name, location, dp_type, data)
167
+ if stage == specs.Stage.INPUT:
168
+ inputs.append(dp)
169
+ elif stage == specs.Stage.OUTPUT:
170
+ outputs.append(dp)
171
+ else:
172
+ hints.append(dp)
173
+ return samplers.Feedback(
174
+ samplers.Features(tuple(inputs), tuple(hints), lengths), tuple(outputs))
175
+
176
+
177
+ def create_dataset(folder, algorithm, split, batch_size):
178
+ dataset = tfds.load(f'clrs_dataset/{algorithm}_{split}',
179
+ data_dir=folder, split=split)
180
+ num_samples = len(dataset) # Must be done here for correct size
181
+ dataset = dataset.repeat()
182
+ dataset = dataset.batch(batch_size)
183
+ return (dataset.map(lambda d: _preprocess(d, algorithm=algorithm)),
184
+ num_samples,
185
+ specs.SPECS[algorithm])
186
+
187
+
188
+ def _copy_hint(source, dest, i, start_source, start_dest, to_add):
189
+ """Copy from full-sample hint to a hint chunk."""
190
+ assert np.all(dest[start_dest:, i:] == 0)
191
+ assert start_dest < dest.shape[0]
192
+ assert start_dest + to_add <= dest.shape[0]
193
+ assert start_source < source.shape[0]
194
+ assert start_source + to_add <= source.shape[0]
195
+ dest[start_dest:start_dest+to_add, i] = source[
196
+ start_source:start_source+to_add, i]
197
+ return dest
198
+
199
+
200
+ def _copy_io(source, dest, i, start_dest, to_add):
201
+ """Copy from an input or output to an input or output chunk."""
202
+ assert np.all(dest[start_dest:, i:] == 0)
203
+ dest[start_dest:start_dest+to_add, i] = source[i]
204
+ return dest
205
+
206
+
207
+ def chunkify(dataset: Iterator[samplers.Feedback], chunk_length: int):
208
+ """Generator of fixed-length chunks from full-trajectory samples.
209
+
210
+ Args:
211
+ dataset: full-sample dataset as numpy iterator.
212
+ chunk_length: time length of chunks.
213
+ Yields:
214
+ Fixed-timelength chunks of data. Each tensor of inputs, hints and outputs
215
+ has dimensions chunk_length x batch_size x ... Samples are not time-padded:
216
+ after the end of one sample immediately comes the next. Since different
217
+ samples can have different time lengths, the beginnings and ends of samples
218
+ within a batch do not need to coincide. For this reason, the chunked
219
+ dataset features include two chunk_length x batch_size int tensors,
220
+ `is_first` and `is_last`, that mark the beginning and end of each sample.
221
+ For example, if `chunk_length`==6 and `batch_size`==2 and the first
222
+ full-sample batch had one sample of length 3 and one of length 5,
223
+ we would have a first chunked batch with the following `is_first` and
224
+ `is_last` tensors:
225
+
226
+ is_first = [[1, 1] is_last = [[0, 0] ( sample id [[0 1]
227
+ [0, 0] [0, 0] [0 1]
228
+ [0, 0] [1, 0] [0 1]
229
+ [1, 0] [0, 0] [2 1]
230
+ [0, 0] [0, 1] [2 1]
231
+ [0, 1]] [0, 0]] [2 3]] )
232
+
233
+ while the data in the inputs, outputs and hints tensors would correspond
234
+ to samples as identified by the sample_id indicated above for reference.
235
+ Notice that, while in the full-sample dataset inputs and outputs have
236
+ no time dimension, here they do; the input and output tensors are simply
237
+ repeated along each sample's time length.
238
+ """
239
+ def _get_batch():
240
+ d = next(dataset)
241
+ return (d.features.inputs, d.features.hints, d.outputs,
242
+ d.features.lengths.astype(int))
243
+
244
+ inputs, hints, outputs, lengths = _get_batch()
245
+ for inp in inputs:
246
+ if inp.location in [specs.Location.NODE, specs.Location.EDGE]:
247
+ batch_size = inp.data.shape[0]
248
+ break
249
+
250
+ io_chunk = lambda x: np.zeros((chunk_length,) + x.shape, dtype=x.dtype)
251
+ chunk_inputs = jax.tree_util.tree_map(io_chunk, inputs)
252
+ chunk_outputs = jax.tree_util.tree_map(io_chunk, outputs)
253
+
254
+ hint_chunk = lambda x: np.zeros((chunk_length,) + x.shape[1:], dtype=x.dtype)
255
+ chunk_hints = jax.tree_util.tree_map(hint_chunk, hints)
256
+
257
+ inputs = [inputs]
258
+ hints = [hints]
259
+ outputs = [outputs]
260
+ left = [lengths.copy()]
261
+ lengths = [lengths.copy()]
262
+
263
+ while True:
264
+ # Create a new empty chunk
265
+ chunk_inputs = jax.tree_util.tree_map(np.zeros_like, chunk_inputs)
266
+ chunk_hints = jax.tree_util.tree_map(np.zeros_like, chunk_hints)
267
+ chunk_outputs = jax.tree_util.tree_map(np.zeros_like, chunk_outputs)
268
+ start_mark = np.zeros((chunk_length, batch_size), dtype=int)
269
+ end_mark = np.zeros((chunk_length, batch_size), dtype=int)
270
+
271
+ # Get enough data batches to fill the new chunk
272
+ while np.any(np.sum(left, axis=0) < chunk_length):
273
+ inp, hh, out, ll = _get_batch()
274
+ inputs.append(inp)
275
+ hints.append(hh)
276
+ outputs.append(out)
277
+ left.append(ll.copy())
278
+ lengths.append(ll.copy())
279
+
280
+ # Fill the chunk, one batch element at a time
281
+ for i in range(batch_size):
282
+ total, idx = 0, 0
283
+ while total < chunk_length:
284
+ to_add = min(left[idx][i], chunk_length - total)
285
+ if to_add:
286
+ start = lengths[idx][i] - left[idx][i]
287
+ assert start >= 0
288
+ f_io = functools.partial(_copy_io, i=i, start_dest=total,
289
+ to_add=to_add)
290
+ chunk_inputs = jax.tree_util.tree_map(f_io, inputs[idx], chunk_inputs)
291
+ chunk_outputs = jax.tree_util.tree_map(f_io, outputs[idx],
292
+ chunk_outputs)
293
+ f_hint = functools.partial(_copy_hint, i=i, start_source=start,
294
+ start_dest=total, to_add=to_add)
295
+ chunk_hints = jax.tree_util.tree_map(f_hint, hints[idx], chunk_hints)
296
+ if start == 0:
297
+ start_mark[total, i] = 1
298
+ total += to_add
299
+ left[idx][i] -= to_add
300
+ assert left[idx][i] >= 0
301
+ if left[idx][i] == 0:
302
+ end_mark[total - 1, i] = 1
303
+ idx += 1
304
+ assert total == chunk_length
305
+
306
+ while left and np.all(left[0] == 0):
307
+ inputs.pop(0)
308
+ hints.pop(0)
309
+ outputs.pop(0)
310
+ left.pop(0)
311
+ lengths.pop(0)
312
+
313
+ yield samplers.Feedback(
314
+ samplers.FeaturesChunked(chunk_inputs, chunk_hints,
315
+ start_mark, end_mark),
316
+ chunk_outputs)
317
+
318
+
319
+ def create_chunked_dataset(folder, algorithm, split, batch_size, chunk_length):
320
+ dataset = tfds.load(f'clrs_dataset/{algorithm}_{split}',
321
+ data_dir=folder, split=split)
322
+ dataset = dataset.repeat()
323
+ dataset = dataset.batch(batch_size)
324
+ dataset = dataset.map(lambda d: _preprocess(d, algorithm=algorithm))
325
+ dataset = dataset.as_numpy_iterator()
326
+ return chunkify(dataset, chunk_length), specs.SPECS[algorithm]
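A standalone sketch of the `is_first`/`is_last` bookkeeping that `chunkify` performs, for a single batch column (toy lengths matching the docstring example; this mirrors, but is not, the library code):

import numpy as np

def marks(sample_lengths, chunk_length, n_chunks):
  """Mark sample starts/ends when samples are concatenated, then chunked."""
  total = chunk_length * n_chunks
  is_first = np.zeros(total, dtype=int)
  is_last = np.zeros(total, dtype=int)
  t = 0
  for length in sample_lengths:
    if t >= total:
      break
    is_first[t] = 1
    if t + length - 1 < total:
      is_last[t + length - 1] = 1
    t += length
  return (is_first.reshape(n_chunks, chunk_length),
          is_last.reshape(n_chunks, chunk_length))

# First batch column of the chunkify docstring example: samples of lengths
# 3 and 5, chunk_length 6.
first, last = marks([3, 5], chunk_length=6, n_chunks=1)
# first[0] == [1, 0, 0, 1, 0, 0]; last[0] == [0, 0, 1, 0, 0, 0]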
benchmarks/CLRS/env/dataset_test.py ADDED
@@ -0,0 +1,116 @@
1
+ # Copyright 2022 DeepMind Technologies Limited. All Rights Reserved.
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+ # ==============================================================================
15
+
16
+ """Unit tests for `dataset.py`."""
17
+
18
+ from typing import Generator, List
19
+
20
+ from absl.testing import absltest
21
+ from absl.testing import parameterized
22
+
23
+ from clrs._src import dataset
24
+ from clrs._src import samplers
25
+ from clrs._src import specs
26
+ import numpy as np
27
+
28
+ _Array = np.ndarray
29
+
30
+
31
+ def _stack_to_shortest(x: List[_Array]) -> _Array:
32
+ min_len = min(map(len, x))
33
+ return np.array([a[:min_len] for a in x])
34
+
35
+
36
+ def _make_sampler(algo: str) -> samplers.Sampler:
37
+ sampler, _ = samplers.build_sampler(
38
+ algo,
39
+ seed=samplers.CLRS30['val']['seed'],
40
+ num_samples=samplers.CLRS30['val']['num_samples'],
41
+ length=samplers.CLRS30['val']['length'],
42
+ )
43
+ return sampler
44
+
45
+
46
+ def _make_iterable_sampler(
47
+ algo: str, batch_size: int) -> Generator[samplers.Feedback, None, None]:
48
+ sampler = _make_sampler(algo)
49
+ while True:
50
+ yield sampler.next(batch_size)
51
+
52
+
53
+ class DatasetTest(parameterized.TestCase):
54
+
55
+ @parameterized.product(
56
+ name=specs.CLRS_30_ALGS[:5],
57
+ chunk_length=[20, 50])
58
+ def test_chunkify(self, name: str, chunk_length: int):
59
+ """Test that samples are concatenated and split in chunks correctly."""
60
+ batch_size = 8
61
+
62
+ ds = _make_iterable_sampler(name, batch_size)
63
+ chunked_ds = dataset.chunkify(
64
+ _make_iterable_sampler(name, batch_size),
65
+ chunk_length)
66
+
67
+ samples = [next(ds) for _ in range(20)]
68
+ cum_lengths = np.cumsum([s.features.lengths for s in samples], axis=0)
69
+ n_chunks = np.amax(cum_lengths[-1]).astype(int) // chunk_length + 1
70
+ chunks = [next(chunked_ds) for _ in range(n_chunks)]
71
+
72
+ # Check correctness of `is_first` and `is_last` markers
73
+ start_idx = _stack_to_shortest([np.where(x)[0] for x in np.concatenate(
74
+ [c.features.is_first for c in chunks]).T]).T
75
+ end_idx = _stack_to_shortest([np.where(x)[0] for x in np.concatenate(
76
+ [c.features.is_last for c in chunks]).T]).T
77
+ assert len(start_idx) >= len(cum_lengths)
78
+ start_idx = start_idx[:len(cum_lengths)]
79
+ assert len(end_idx) >= len(cum_lengths)
80
+ end_idx = end_idx[:len(cum_lengths)]
81
+
82
+ np.testing.assert_equal(start_idx[0], 0)
83
+ np.testing.assert_array_equal(cum_lengths - 1, end_idx)
84
+ np.testing.assert_array_equal(cum_lengths[:-1], start_idx[1:])
85
+
86
+ # Check that inputs, outputs and hints have been copied correctly
87
+ all_input = np.concatenate([c.features.inputs[0].data for c in chunks])
88
+ all_output = np.concatenate([c.outputs[0].data for c in chunks])
89
+ all_hint = np.concatenate([c.features.hints[0].data for c in chunks])
90
+ for i in range(batch_size):
91
+ length0 = int(samples[0].features.lengths[i])
92
+ length1 = int(samples[1].features.lengths[i])
93
+ # Check first sample
94
+ np.testing.assert_array_equal(
95
+ all_input[:length0, i],
96
+ np.tile(samples[0].features.inputs[0].data[i], [length0, 1]))
97
+ np.testing.assert_array_equal(
98
+ all_output[:length0, i],
99
+ np.tile(samples[0].outputs[0].data[i], [length0, 1]))
100
+ np.testing.assert_array_equal(
101
+ all_hint[:length0, i],
102
+ samples[0].features.hints[0].data[:length0, i])
103
+ # Check second sample
104
+ np.testing.assert_array_equal(
105
+ all_input[length0:length0 + length1, i],
106
+ np.tile(samples[1].features.inputs[0].data[i], [length1, 1]))
107
+ np.testing.assert_array_equal(
108
+ all_output[length0:length0 + length1, i],
109
+ np.tile(samples[1].outputs[0].data[i], [length1, 1]))
110
+ np.testing.assert_array_equal(
111
+ all_hint[length0:length0 + length1, i],
112
+ samples[1].features.hints[0].data[:length1, i])
113
+
114
+
115
+ if __name__ == '__main__':
116
+ absltest.main()
benchmarks/CLRS/env/decoders.py ADDED
@@ -0,0 +1,381 @@
1
+ # Copyright 2021 DeepMind Technologies Limited. All Rights Reserved.
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+ # ==============================================================================
15
+ """decoders utilities."""
16
+
17
+ import functools
18
+ from typing import Dict, Optional
19
+
20
+ import chex
21
+ from clrs._src import probing
22
+ from clrs._src import specs
23
+ import haiku as hk
24
+ import jax
25
+ import jax.numpy as jnp
26
+
27
+ _Array = chex.Array
28
+ _DataPoint = probing.DataPoint
29
+ _Location = specs.Location
30
+ _Spec = specs.Spec
31
+ _Stage = specs.Stage
32
+ _Type = specs.Type
33
+
34
+
35
+ def log_sinkhorn(x: _Array, steps: int, temperature: float, zero_diagonal: bool,
36
+ noise_rng_key: Optional[_Array]) -> _Array:
37
+ """Sinkhorn operator in log space, to postprocess permutation pointer logits.
38
+
39
+ Args:
40
+ x: input of shape [..., n, n], a batch of square matrices.
41
+ steps: number of iterations.
42
+ temperature: temperature parameter (as temperature approaches zero, the
43
+ output approaches a permutation matrix).
44
+ zero_diagonal: whether to force the diagonal logits towards -inf.
45
+ noise_rng_key: key to add Gumbel noise.
46
+
47
+ Returns:
48
+ Elementwise logarithm of a doubly-stochastic matrix (a matrix with
49
+ non-negative elements whose rows and columns sum to 1).
50
+ """
51
+ assert x.ndim >= 2
52
+ assert x.shape[-1] == x.shape[-2]
53
+ if noise_rng_key is not None:
54
+ # Add standard Gumbel noise (see https://arxiv.org/abs/1802.08665)
55
+ noise = -jnp.log(-jnp.log(jax.random.uniform(noise_rng_key,
56
+ x.shape) + 1e-12) + 1e-12)
57
+ x = x + noise
58
+ x /= temperature
59
+ if zero_diagonal:
60
+ x = x - 1e6 * jnp.eye(x.shape[-1])
61
+ for _ in range(steps):
62
+ x = jax.nn.log_softmax(x, axis=-1)
63
+ x = jax.nn.log_softmax(x, axis=-2)
64
+ return x
65
+
66
+
67
+ def construct_decoders(loc: str, t: str, hidden_dim: int, nb_dims: int,
68
+ name: str):
69
+ """Constructs decoders."""
70
+ linear = functools.partial(hk.Linear, name=f"{name}_dec_linear")
71
+ if loc == _Location.NODE:
72
+ # Node decoders.
73
+ if t in [_Type.SCALAR, _Type.MASK, _Type.MASK_ONE]:
74
+ decoders = (linear(1),)
75
+ elif t == _Type.CATEGORICAL:
76
+ decoders = (linear(nb_dims),)
77
+ elif t in [_Type.POINTER, _Type.PERMUTATION_POINTER]:
78
+ decoders = (linear(hidden_dim), linear(hidden_dim), linear(hidden_dim),
79
+ linear(1))
80
+ else:
81
+ raise ValueError(f"Invalid Type {t}")
82
+
83
+ elif loc == _Location.EDGE:
84
+ # Edge decoders.
85
+ if t in [_Type.SCALAR, _Type.MASK, _Type.MASK_ONE]:
86
+ decoders = (linear(1), linear(1), linear(1))
87
+ elif t == _Type.CATEGORICAL:
88
+ decoders = (linear(nb_dims), linear(nb_dims), linear(nb_dims))
89
+ elif t == _Type.POINTER:
90
+ decoders = (linear(hidden_dim), linear(hidden_dim),
91
+ linear(hidden_dim), linear(hidden_dim), linear(1))
92
+ else:
93
+ raise ValueError(f"Invalid Type {t}")
94
+
95
+ elif loc == _Location.GRAPH:
96
+ # Graph decoders.
97
+ if t in [_Type.SCALAR, _Type.MASK, _Type.MASK_ONE]:
98
+ decoders = (linear(1), linear(1))
99
+ elif t == _Type.CATEGORICAL:
100
+ decoders = (linear(nb_dims), linear(nb_dims))
101
+ elif t == _Type.POINTER:
102
+ decoders = (linear(1), linear(1),
103
+ linear(1))
104
+ else:
105
+ raise ValueError(f"Invalid Type {t}")
106
+
107
+ else:
108
+ raise ValueError(f"Invalid Location {loc}")
109
+
110
+ return decoders
111
+
112
+
113
+ def construct_diff_decoders(name: str):
114
+ """Constructs diff decoders."""
115
+ linear = functools.partial(hk.Linear, name=f"{name}_diffdec_linear")
116
+ decoders = {}
117
+ decoders[_Location.NODE] = linear(1)
118
+ decoders[_Location.EDGE] = (linear(1), linear(1), linear(1))
119
+ decoders[_Location.GRAPH] = (linear(1), linear(1))
120
+
121
+ return decoders
122
+
123
+
124
+ def postprocess(spec: _Spec, preds: Dict[str, _Array],
125
+ sinkhorn_temperature: float,
126
+ sinkhorn_steps: int,
127
+ hard: bool) -> Dict[str, _DataPoint]:
128
+ """Postprocesses decoder output.
129
+
130
+ This is done on outputs in order to score performance, and on hints in
131
+ order to score them but also in order to feed them back to the model.
132
+ At scoring time, the postprocessing mode is "hard": logits will be
133
+ arg-maxed and masks will be thresholded. However, for the case of the hints
134
+ that are fed back into the model, the postprocessing can be hard or soft,
135
+ depending on whether we want to let gradients flow through them or not.
136
+
137
+ Args:
138
+ spec: The spec of the algorithm whose outputs/hints we are postprocessing.
139
+ preds: Output and/or hint predictions, as produced by decoders.
140
+ sinkhorn_temperature: Parameter for the sinkhorn operator on permutation
141
+ pointers.
142
+ sinkhorn_steps: Parameter for the sinkhorn operator on permutation
143
+ pointers.
144
+ hard: whether to do hard postprocessing, which involves argmax for
145
+ MASK_ONE, CATEGORICAL and POINTERS, thresholding for MASK, and stop
146
+ gradient through for SCALAR. If False, soft postprocessing will be used,
147
+ with softmax, sigmoid and gradients allowed.
148
+ Returns:
149
+ The postprocessed `preds`. In "soft" post-processing, POINTER types will
150
+ change to SOFT_POINTER, so encoders know they do not need to be
151
+ pre-processed before feeding them back in.
152
+ """
153
+ result = {}
154
+ for name in preds.keys():
155
+ _, loc, t = spec[name]
156
+ new_t = t
157
+ data = preds[name]
158
+ if t == _Type.SCALAR:
159
+ if hard:
160
+ data = jax.lax.stop_gradient(data)
161
+ elif t == _Type.MASK:
162
+ if hard:
163
+ data = (data > 0.0) * 1.0
164
+ else:
165
+ data = jax.nn.sigmoid(data)
166
+ elif t in [_Type.MASK_ONE, _Type.CATEGORICAL]:
167
+ cat_size = data.shape[-1]
168
+ if hard:
169
+ best = jnp.argmax(data, -1)
170
+ data = hk.one_hot(best, cat_size)
171
+ else:
172
+ data = jax.nn.softmax(data, axis=-1)
173
+ elif t == _Type.POINTER:
174
+ if hard:
175
+ data = jnp.argmax(data, -1).astype(float)
176
+ else:
177
+ data = jax.nn.softmax(data, -1)
178
+ new_t = _Type.SOFT_POINTER
179
+ elif t == _Type.PERMUTATION_POINTER:
180
+       # Convert the matrix of logits to a doubly stochastic matrix.
+       data = log_sinkhorn(
+           x=data,
+           steps=sinkhorn_steps,
+           temperature=sinkhorn_temperature,
+           zero_diagonal=True,
+           noise_rng_key=None)
+       data = jnp.exp(data)
+       if hard:
+         data = jax.nn.one_hot(jnp.argmax(data, axis=-1), data.shape[-1])
+     else:
+       raise ValueError("Invalid type")
+     result[name] = probing.DataPoint(
+         name=name, location=loc, type_=new_t, data=data)
+ 
+   return result
+ 
+ 
+ def decode_fts(
+     decoders,
+     spec: _Spec,
+     h_t: _Array,
+     adj_mat: _Array,
+     edge_fts: _Array,
+     graph_fts: _Array,
+     inf_bias: bool,
+     inf_bias_edge: bool,
+     repred: bool,
+ ):
+   """Decodes node, edge and graph features."""
+   output_preds = {}
+   hint_preds = {}
+ 
+   for name in decoders:
+     decoder = decoders[name]
+     stage, loc, t = spec[name]
+ 
+     if loc == _Location.NODE:
+       preds = _decode_node_fts(decoder, t, h_t, edge_fts, adj_mat,
+                                inf_bias, repred)
+     elif loc == _Location.EDGE:
+       preds = _decode_edge_fts(decoder, t, h_t, edge_fts, adj_mat,
+                                inf_bias_edge)
+     elif loc == _Location.GRAPH:
+       preds = _decode_graph_fts(decoder, t, h_t, graph_fts)
+     else:
+       raise ValueError("Invalid output type")
+ 
+     if stage == _Stage.OUTPUT:
+       output_preds[name] = preds
+     elif stage == _Stage.HINT:
+       hint_preds[name] = preds
+     else:
+       raise ValueError(f"Found unexpected decoder {name}")
+ 
+   return hint_preds, output_preds
+ 
+ 
+ def _decode_node_fts(decoders, t: str, h_t: _Array, edge_fts: _Array,
+                      adj_mat: _Array, inf_bias: bool, repred: bool) -> _Array:
+   """Decodes node features."""
+ 
+   if t in [_Type.SCALAR, _Type.MASK, _Type.MASK_ONE]:
+     preds = jnp.squeeze(decoders[0](h_t), -1)
+   elif t == _Type.CATEGORICAL:
+     preds = decoders[0](h_t)
+   elif t in [_Type.POINTER, _Type.PERMUTATION_POINTER]:
+     p_1 = decoders[0](h_t)
+     p_2 = decoders[1](h_t)
+     p_3 = decoders[2](edge_fts)
+ 
+     p_e = jnp.expand_dims(p_2, -2) + p_3
+     p_m = jnp.maximum(jnp.expand_dims(p_1, -2),
+                       jnp.transpose(p_e, (0, 2, 1, 3)))
+ 
+     preds = jnp.squeeze(decoders[3](p_m), -1)
+ 
+     if inf_bias:
+       per_batch_min = jnp.min(preds, axis=range(1, preds.ndim), keepdims=True)
+       preds = jnp.where(adj_mat > 0.5,
+                         preds,
+                         jnp.minimum(-1.0, per_batch_min - 1.0))
+     if t == _Type.PERMUTATION_POINTER:
+       if repred:  # testing or validation, no Gumbel noise
+         preds = log_sinkhorn(
+             x=preds, steps=10, temperature=0.1,
+             zero_diagonal=True, noise_rng_key=None)
+       else:  # training, add Gumbel noise
+         preds = log_sinkhorn(
+             x=preds, steps=10, temperature=0.1,
+             zero_diagonal=True, noise_rng_key=hk.next_rng_key())
+   else:
+     raise ValueError("Invalid output type")
+ 
+   return preds
+ 
+ 
+ def _decode_edge_fts(decoders, t: str, h_t: _Array, edge_fts: _Array,
+                      adj_mat: _Array, inf_bias_edge: bool) -> _Array:
+   """Decodes edge features."""
+ 
+   pred_1 = decoders[0](h_t)
+   pred_2 = decoders[1](h_t)
+   pred_e = decoders[2](edge_fts)
+   pred = (jnp.expand_dims(pred_1, -2) + jnp.expand_dims(pred_2, -3) + pred_e)
+   if t in [_Type.SCALAR, _Type.MASK, _Type.MASK_ONE]:
+     preds = jnp.squeeze(pred, -1)
+   elif t == _Type.CATEGORICAL:
+     preds = pred
+   elif t == _Type.POINTER:
+     pred_2 = decoders[3](h_t)
+ 
+     p_m = jnp.maximum(jnp.expand_dims(pred, -2),
+                       jnp.expand_dims(
+                           jnp.expand_dims(pred_2, -3), -3))
+ 
+     preds = jnp.squeeze(decoders[4](p_m), -1)
+   else:
+     raise ValueError("Invalid output type")
+   if inf_bias_edge and t in [_Type.MASK, _Type.MASK_ONE]:
+     per_batch_min = jnp.min(preds, axis=range(1, preds.ndim), keepdims=True)
+     preds = jnp.where(adj_mat > 0.5,
+                       preds,
+                       jnp.minimum(-1.0, per_batch_min - 1.0))
+ 
+   return preds
+ 
+ 
+ def _decode_graph_fts(decoders, t: str, h_t: _Array,
+                       graph_fts: _Array) -> _Array:
+   """Decodes graph features."""
+ 
+   gr_emb = jnp.max(h_t, axis=-2)
+   pred_n = decoders[0](gr_emb)
+   pred_g = decoders[1](graph_fts)
+   pred = pred_n + pred_g
+   if t in [_Type.SCALAR, _Type.MASK, _Type.MASK_ONE]:
+     preds = jnp.squeeze(pred, -1)
+   elif t == _Type.CATEGORICAL:
+     preds = pred
+   elif t == _Type.POINTER:
+     pred_2 = decoders[2](h_t)
+     ptr_p = jnp.expand_dims(pred, 1) + jnp.transpose(pred_2, (0, 2, 1))
+     preds = jnp.squeeze(ptr_p, 1)
+   else:
+     raise ValueError("Invalid output type")
+ 
+   return preds
+ 
+ 
+ def maybe_decode_diffs(
+     diff_decoders,
+     h_t: _Array,
+     edge_fts: _Array,
+     graph_fts: _Array,
+     decode_diffs: bool,
+ ) -> Optional[Dict[str, _Array]]:
+   """Optionally decodes node, edge and graph diffs."""
+ 
+   if decode_diffs:
+     preds = {}
+     node = _Location.NODE
+     edge = _Location.EDGE
+     graph = _Location.GRAPH
+     preds[node] = _decode_node_diffs(diff_decoders[node], h_t)
+     preds[edge] = _decode_edge_diffs(diff_decoders[edge], h_t, edge_fts)
+     preds[graph] = _decode_graph_diffs(diff_decoders[graph], h_t, graph_fts)
+ 
+   else:
+     preds = None
+ 
+   return preds
+ 
+ 
+ def _decode_node_diffs(decoders, h_t: _Array) -> _Array:
+   """Decodes node diffs."""
+   return jnp.squeeze(decoders(h_t), -1)
+ 
+ 
+ def _decode_edge_diffs(decoders, h_t: _Array, edge_fts: _Array) -> _Array:
+   """Decodes edge diffs."""
+ 
+   e_pred_1 = decoders[0](h_t)
+   e_pred_2 = decoders[1](h_t)
+   e_pred_e = decoders[2](edge_fts)
+   preds = jnp.squeeze(
+       jnp.expand_dims(e_pred_1, -1) + jnp.expand_dims(e_pred_2, -2) + e_pred_e,
+       -1,
+   )
+ 
+   return preds
+ 
+ 
+ def _decode_graph_diffs(decoders, h_t: _Array, graph_fts: _Array) -> _Array:
+   """Decodes graph diffs."""
+ 
+   gr_emb = jnp.max(h_t, axis=-2)
+   g_pred_n = decoders[0](gr_emb)
+   g_pred_g = decoders[1](graph_fts)
+   preds = jnp.squeeze(g_pred_n + g_pred_g, -1)
+ 
+   return preds
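
Side note on the Sinkhorn post-processing above: `log_sinkhorn` iteratively normalises rows and columns in log space, so exponentiating its output gives a (near) doubly stochastic matrix, and `hard=True` then snaps each row to its argmax. A minimal sketch of that pipeline, assuming the `clrs` package from this repo is importable (illustration only, not part of the commit):

import jax
import jax.numpy as jnp
from clrs._src import decoders

# Random NxN scores standing in for permutation-pointer logits.
logits = jax.random.normal(jax.random.PRNGKey(0), (4, 4))
log_p = decoders.log_sinkhorn(x=logits, steps=25, temperature=0.1,
                              zero_diagonal=True, noise_rng_key=None)
soft_perm = jnp.exp(log_p)  # rows and columns each sum to ~1
hard_perm = jax.nn.one_hot(jnp.argmax(soft_perm, axis=-1),
                           soft_perm.shape[-1])  # one-hot per row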
benchmarks/CLRS/env/decoders_test.py ADDED
@@ -0,0 +1,47 @@
+ # Copyright 2022 DeepMind Technologies Limited. All Rights Reserved.
+ #
+ # Licensed under the Apache License, Version 2.0 (the "License");
+ # you may not use this file except in compliance with the License.
+ # You may obtain a copy of the License at
+ #
+ #     http://www.apache.org/licenses/LICENSE-2.0
+ #
+ # Unless required by applicable law or agreed to in writing, software
+ # distributed under the License is distributed on an "AS IS" BASIS,
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ # See the License for the specific language governing permissions and
+ # limitations under the License.
+ # ==============================================================================
+ 
+ """Unit tests for `decoders.py`."""
+ 
+ from absl.testing import absltest
+ 
+ import chex
+ from clrs._src import decoders
+ import jax
+ import jax.numpy as jnp
+ 
+ 
+ class DecodersTest(absltest.TestCase):
+ 
+   def test_log_sinkhorn(self):
+     x = jax.random.normal(jax.random.PRNGKey(42), (10, 10))
+     y = jnp.exp(decoders.log_sinkhorn(x, steps=10, temperature=1.0,
+                                       zero_diagonal=False,
+                                       noise_rng_key=None))
+     chex.assert_trees_all_close(jnp.sum(y, axis=-1), 1., atol=1e-4)
+     chex.assert_trees_all_close(jnp.sum(y, axis=-2), 1., atol=1e-4)
+ 
+   def test_log_sinkhorn_zero_diagonal(self):
+     x = jax.random.normal(jax.random.PRNGKey(42), (10, 10))
+     y = jnp.exp(decoders.log_sinkhorn(x, steps=10, temperature=1.0,
+                                       zero_diagonal=True,
+                                       noise_rng_key=None))
+     chex.assert_trees_all_close(jnp.sum(y, axis=-1), 1., atol=1e-4)
+     chex.assert_trees_all_close(jnp.sum(y, axis=-2), 1., atol=1e-4)
+     chex.assert_trees_all_close(jnp.sum(y.diagonal()), 0., atol=1e-4)
+ 
+ 
+ if __name__ == '__main__':
+   absltest.main()
benchmarks/CLRS/env/encoders.py ADDED
@@ -0,0 +1,139 @@
+ # Copyright 2021 DeepMind Technologies Limited. All Rights Reserved.
+ #
+ # Licensed under the Apache License, Version 2.0 (the "License");
+ # you may not use this file except in compliance with the License.
+ # You may obtain a copy of the License at
+ #
+ #     http://www.apache.org/licenses/LICENSE-2.0
+ #
+ # Unless required by applicable law or agreed to in writing, software
+ # distributed under the License is distributed on an "AS IS" BASIS,
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ # See the License for the specific language governing permissions and
+ # limitations under the License.
+ # ==============================================================================
+ """Encoder utilities."""
+ 
+ import functools
+ import chex
+ from clrs._src import probing
+ from clrs._src import specs
+ import haiku as hk
+ import jax.numpy as jnp
+ 
+ _Array = chex.Array
+ _DataPoint = probing.DataPoint
+ _Location = specs.Location
+ _Spec = specs.Spec
+ _Stage = specs.Stage
+ _Type = specs.Type
+ 
+ 
+ def construct_encoders(stage: str, loc: str, t: str,
+                        hidden_dim: int, init: str, name: str):
+   """Constructs encoders."""
+   if init == 'xavier_on_scalars' and stage == _Stage.HINT and t == _Type.SCALAR:
+     initialiser = hk.initializers.TruncatedNormal(
+         stddev=1.0 / jnp.sqrt(hidden_dim))
+   elif init in ['default', 'xavier_on_scalars']:
+     initialiser = None
+   else:
+     raise ValueError(f'Encoder initialiser {init} not supported.')
+   linear = functools.partial(
+       hk.Linear,
+       w_init=initialiser,
+       name=f'{name}_enc_linear')
+   encoders = [linear(hidden_dim)]
+   if loc == _Location.EDGE and t == _Type.POINTER:
+     # Edge pointers need two-way encoders.
+     encoders.append(linear(hidden_dim))
+ 
+   return encoders
+ 
+ 
+ def preprocess(dp: _DataPoint, nb_nodes: int) -> _DataPoint:
+   """Pre-process data point.
+ 
+   Make sure that the data is ready to be encoded into features.
+   If the data is of POINTER type, we expand the compressed index representation
+   to a full one-hot. But if the data is a SOFT_POINTER, the representation
+   is already expanded and we just overwrite the type as POINTER so that
+   it is treated as such for encoding.
+ 
+   Args:
+     dp: A DataPoint to prepare for encoding.
+     nb_nodes: Number of nodes in the graph, necessary to expand pointers to
+       the right dimension.
+   Returns:
+     The datapoint, with data and possibly type modified.
+   """
+   new_type = dp.type_
+   if dp.type_ == _Type.POINTER:
+     data = hk.one_hot(dp.data, nb_nodes)
+   else:
+     data = dp.data.astype(jnp.float32)
+     if dp.type_ == _Type.SOFT_POINTER:
+       new_type = _Type.POINTER
+   dp = probing.DataPoint(
+       name=dp.name, location=dp.location, type_=new_type, data=data)
+ 
+   return dp
+ 
+ 
+ def accum_adj_mat(dp: _DataPoint, adj_mat: _Array) -> _Array:
+   """Accumulates adjacency matrix."""
+   if dp.location == _Location.NODE and dp.type_ in [_Type.POINTER,
+                                                     _Type.PERMUTATION_POINTER]:
+     adj_mat += ((dp.data + jnp.transpose(dp.data, (0, 2, 1))) > 0.5)
+   elif dp.location == _Location.EDGE and dp.type_ == _Type.MASK:
+     adj_mat += ((dp.data + jnp.transpose(dp.data, (0, 2, 1))) > 0.0)
+ 
+   return (adj_mat > 0.).astype('float32')  # pytype: disable=attribute-error  # numpy-scalars
+ 
+ 
+ def accum_edge_fts(encoders, dp: _DataPoint, edge_fts: _Array) -> _Array:
+   """Encodes and accumulates edge features."""
+   if dp.location == _Location.NODE and dp.type_ in [_Type.POINTER,
+                                                     _Type.PERMUTATION_POINTER]:
+     encoding = _encode_inputs(encoders, dp)
+     edge_fts += encoding
+ 
+   elif dp.location == _Location.EDGE:
+     encoding = _encode_inputs(encoders, dp)
+     if dp.type_ == _Type.POINTER:
+       # Aggregate pointer contributions across sender and receiver nodes.
+       encoding_2 = encoders[1](jnp.expand_dims(dp.data, -1))
+       edge_fts += jnp.mean(encoding, axis=1) + jnp.mean(encoding_2, axis=2)
+     else:
+       edge_fts += encoding
+ 
+   return edge_fts
+ 
+ 
+ def accum_node_fts(encoders, dp: _DataPoint, node_fts: _Array) -> _Array:
+   """Encodes and accumulates node features."""
+   is_pointer = (dp.type_ in [_Type.POINTER, _Type.PERMUTATION_POINTER])
+   if ((dp.location == _Location.NODE and not is_pointer) or
+       (dp.location == _Location.GRAPH and dp.type_ == _Type.POINTER)):
+     encoding = _encode_inputs(encoders, dp)
+     node_fts += encoding
+ 
+   return node_fts
+ 
+ 
+ def accum_graph_fts(encoders, dp: _DataPoint,
+                     graph_fts: _Array) -> _Array:
+   """Encodes and accumulates graph features."""
+   if dp.location == _Location.GRAPH and dp.type_ != _Type.POINTER:
+     encoding = _encode_inputs(encoders, dp)
+     graph_fts += encoding
+ 
+   return graph_fts
+ 
+ 
+ def _encode_inputs(encoders, dp: _DataPoint) -> _Array:
+   if dp.type_ == _Type.CATEGORICAL:
+     encoding = encoders[0](dp.data)
+   else:
+     encoding = encoders[0](jnp.expand_dims(dp.data, -1))
+   return encoding
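
A small usage sketch for `preprocess` above (illustration only, not part of the commit): a POINTER datapoint stores one node index per node, and preprocessing expands it into a one-hot matrix over nodes.

import jax.numpy as jnp
from clrs._src import encoders, probing, specs

# Hypothetical datapoint: a batch of 1 graph with 3 nodes, where node 2
# points at node 1 and the others point at node 0.
ptrs = probing.DataPoint(name='pi', location=specs.Location.NODE,
                         type_=specs.Type.POINTER,
                         data=jnp.array([[0, 0, 1]]))
dp = encoders.preprocess(ptrs, nb_nodes=3)
# dp.type_ is still POINTER, but dp.data now has shape [1, 3, 3], with
# dp.data[0, i, j] == 1 exactly where node i pointed at node j.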
benchmarks/CLRS/env/evaluation.py ADDED
@@ -0,0 +1,202 @@
+ # Copyright 2021 DeepMind Technologies Limited. All Rights Reserved.
+ #
+ # Licensed under the Apache License, Version 2.0 (the "License");
+ # you may not use this file except in compliance with the License.
+ # You may obtain a copy of the License at
+ #
+ #     http://www.apache.org/licenses/LICENSE-2.0
+ #
+ # Unless required by applicable law or agreed to in writing, software
+ # distributed under the License is distributed on an "AS IS" BASIS,
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ # See the License for the specific language governing permissions and
+ # limitations under the License.
+ # ==============================================================================
+ 
+ """Utilities for evaluating model predictions."""
+ 
+ from typing import Dict, List, Tuple
+ import chex
+ from clrs._src import probing
+ from clrs._src import specs
+ import numpy as np
+ 
+ 
+ _Array = chex.Array
+ Result = Dict[str, probing.DataPoint]
+ 
+ 
+ def fuse_perm_and_mask(perm: probing.DataPoint,
+                        mask: probing.DataPoint) -> probing.DataPoint:
+   """Replace permutation pointers active in the mask with self-pointers.
+ 
+   Args:
+     perm: a node permutation_pointer; data shape is expected to be
+       [..., N, N], and ideally one-hot over the last two dimensions, although
+       this method does not check for one-hotness.
+     mask: a mask_one over nodes; data shape is expected to be
+       [..., N], and ideally one-hot over the last dimension, although
+       this method does not check for one-hotness.
+   Returns:
+     A node pointer with shape [..., N].
+   """
+   assert perm.type_ == specs.Type.PERMUTATION_POINTER
+   assert perm.location == specs.Location.NODE
+   assert mask.name == perm.name + '_mask'
+   assert mask.type_ == specs.Type.MASK_ONE
+   assert mask.location == specs.Location.NODE
+   assert perm.data.shape[-1] == perm.data.shape[-2]
+   assert perm.data.shape[:-1] == mask.data.shape
+   data = np.where(mask.data > 0.5,
+                   np.arange(perm.data.shape[-1]),  # self-pointers
+                   np.argmax(perm.data, axis=-1))  # original pointers
+   return probing.DataPoint(name=perm.name,
+                            type_=specs.Type.POINTER,
+                            location=perm.location,
+                            data=data)
+ 
+ 
+ def _reduce_permutations_tuple(
+     targets: Tuple[probing.DataPoint, ...]) -> Tuple[probing.DataPoint, ...]:
+   """Reduce node pointer + mask_one permutation to just node pointer."""
+   out_targets = []
+   n_perms = 0
+   i = 0
+   while i < len(targets):
+     truth = targets[i]
+     if truth.type_ != specs.Type.PERMUTATION_POINTER:
+       out_targets.append(truth)
+       i += 1
+       continue
+     truth_mask = targets[i + 1]
+     out_targets.append(fuse_perm_and_mask(truth, truth_mask))
+     i += 2
+     n_perms += 1
+ 
+   assert len(out_targets) == len(targets) - n_perms
+   return tuple(out_targets)
+ 
+ 
+ def _reduce_permutations_dict(predictions: Result) -> Result:
+   """Reduce node pointer + mask_one permutation to just node pointer."""
+   out_preds = {}
+   n_perms = 0
+   for k, pred in predictions.items():
+     if (k.endswith('_mask') and k[:-5] in predictions and
+         predictions[k[:-5]].type_ == specs.Type.PERMUTATION_POINTER):
+       # This mask will be processed with its associated permutation datapoint
+       continue
+     if pred.type_ != specs.Type.PERMUTATION_POINTER:
+       out_preds[k] = pred
+       continue
+     pred_mask = predictions[k + '_mask']
+     out_preds[k] = fuse_perm_and_mask(pred, pred_mask)
+     n_perms += 1
+ 
+   assert len(out_preds) == len(predictions) - n_perms
+   return out_preds
+ 
+ 
+ def evaluate_hints(
+     hints: Tuple[probing.DataPoint, ...],
+     lengths: _Array,
+     hint_preds: List[Result],
+ ) -> Dict[str, _Array]:
+   """Evaluate hint predictions."""
+   evals = {}
+   hints = _reduce_permutations_tuple(hints)
+   hint_preds = [_reduce_permutations_dict(h) for h in hint_preds]
+   for truth in hints:
+     assert truth.name in hint_preds[0]
+     eval_along_time = [_evaluate(truth, p[truth.name],
+                                  idx=i+1, lengths=lengths)
+                        for (i, p) in enumerate(hint_preds)]
+     evals[truth.name] = np.sum(
+         [x * np.sum(i+1 < lengths)
+          for i, x in enumerate(eval_along_time)]) / np.sum(lengths - 1)
+     evals[truth.name + '_along_time'] = np.array(eval_along_time)
+ 
+   # Unlike outputs, the hints sometimes include scalars, which don't have
+   # a meaningful eval score. So we don't compute a global 'hint score' as we
+   # do for outputs.
+   return evals
+ 
+ 
+ def evaluate(
+     outputs: Tuple[probing.DataPoint, ...],
+     predictions: Result,
+ ) -> Dict[str, float]:
+   """Evaluate output predictions."""
+   evals = {}
+   outputs = _reduce_permutations_tuple(outputs)
+   predictions = _reduce_permutations_dict(predictions)
+   for truth in outputs:
+     assert truth.name in predictions
+     pred = predictions[truth.name]
+     evals[truth.name] = _evaluate(truth, pred)
+   # Return a single scalar score that is the mean of all output scores.
+   evals['score'] = sum([v.item() for v in evals.values()]) / len(evals)
+   return evals
+ 
+ 
+ def _evaluate(truth, pred, idx=None, lengths=None):
+   """Evaluate single prediction of hint or output."""
+   assert pred.name == truth.name
+   assert pred.location == truth.location
+   assert pred.type_ == truth.type_
+ 
+   if truth.type_ not in _EVAL_FN:
+     raise ValueError('Invalid type')
+   truth_data = truth.data
+   pred_data = pred.data
+   if idx is not None:
+     if np.all(idx >= lengths):
+       return 0.
+     truth_data = truth_data[idx][idx < lengths]
+     pred_data = pred_data[idx < lengths]
+   return _EVAL_FN[truth.type_](pred_data, truth_data)
+ 
+ 
+ def _eval_one(pred, truth):
+   mask = np.all(truth != specs.OutputClass.MASKED, axis=-1)
+   return np.sum(
+       (np.argmax(pred, -1) == np.argmax(truth, -1)) * mask) / np.sum(mask)
+ 
+ 
+ def _mask_fn(pred, truth):
+   """Evaluate outputs of type MASK, and account for any class imbalance."""
+   mask = (truth != specs.OutputClass.MASKED).astype(np.float32)
+ 
+   # Use F1 score for the masked outputs to address any imbalance
+   tp = np.sum((((pred > 0.5) * (truth > 0.5)) * 1.0) * mask)
+   fp = np.sum((((pred > 0.5) * (truth < 0.5)) * 1.0) * mask)
+   fn = np.sum((((pred < 0.5) * (truth > 0.5)) * 1.0) * mask)
+ 
+   # Protect against division by zero
+   if tp + fp > 0:
+     precision = tp / (tp + fp)
+   else:
+     precision = np.float32(1.0)
+   if tp + fn > 0:
+     recall = tp / (tp + fn)
+   else:
+     recall = np.float32(1.0)
+ 
+   if precision + recall > 0.0:
+     f_1 = 2.0 * precision * recall / (precision + recall)
+   else:
+     f_1 = np.float32(0.0)
+ 
+   return f_1
+ 
+ 
+ _EVAL_FN = {
+     specs.Type.SCALAR:
+         lambda pred, truth: np.mean((pred - truth)**2),
+     specs.Type.MASK: _mask_fn,
+     specs.Type.MASK_ONE:
+         _eval_one,
+     specs.Type.CATEGORICAL:
+         _eval_one,
+     specs.Type.POINTER:
+         lambda pred, truth: np.mean((pred == truth) * 1.0),
+ }
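
A worked example for the MASK metric above (illustration only, not part of the commit): predictions above 0.5 count as positives, masked entries are dropped, and the score is the F1 of the rest. Note `_mask_fn` is module-private; it is called directly here only to illustrate the arithmetic.

import numpy as np
from clrs._src import evaluation

pred = np.array([0.9, 0.2, 0.8, 0.1])
truth = np.array([1.0, 1.0, 0.0, 0.0])
# tp = 1 (first entry), fp = 1 (third), fn = 1 (second),
# so precision = recall = 0.5 and F1 = 0.5.
assert np.isclose(evaluation._mask_fn(pred, truth), 0.5)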
benchmarks/CLRS/env/evaluation_test.py ADDED
@@ -0,0 +1,55 @@
+ # Copyright 2022 DeepMind Technologies Limited. All Rights Reserved.
+ #
+ # Licensed under the Apache License, Version 2.0 (the "License");
+ # you may not use this file except in compliance with the License.
+ # You may obtain a copy of the License at
+ #
+ #     http://www.apache.org/licenses/LICENSE-2.0
+ #
+ # Unless required by applicable law or agreed to in writing, software
+ # distributed under the License is distributed on an "AS IS" BASIS,
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ # See the License for the specific language governing permissions and
+ # limitations under the License.
+ # ==============================================================================
+ 
+ """Unit tests for `evaluation.py`."""
+ 
+ from absl.testing import absltest
+ from clrs._src import evaluation
+ from clrs._src import probing
+ from clrs._src import specs
+ 
+ import jax
+ import jax.numpy as jnp
+ import numpy as np
+ 
+ 
+ class EvaluationTest(absltest.TestCase):
+ 
+   def test_reduce_permutations(self):
+     b = 8
+     n = 16
+     pred = jnp.stack([jax.random.permutation(jax.random.PRNGKey(i), n)
+                       for i in range(b)])
+     heads = jax.random.randint(jax.random.PRNGKey(42), (b,), 0, n)
+ 
+     perm = probing.DataPoint(name='test',
+                              type_=specs.Type.PERMUTATION_POINTER,
+                              location=specs.Location.NODE,
+                              data=jax.nn.one_hot(pred, n))
+     mask = probing.DataPoint(name='test_mask',
+                              type_=specs.Type.MASK_ONE,
+                              location=specs.Location.NODE,
+                              data=jax.nn.one_hot(heads, n))
+     output = evaluation.fuse_perm_and_mask(perm=perm, mask=mask)
+     expected_output = np.array(pred)
+     expected_output[np.arange(b), heads] = heads
+     self.assertEqual(output.name, 'test')
+     self.assertEqual(output.type_, specs.Type.POINTER)
+     self.assertEqual(output.location, specs.Location.NODE)
+     np.testing.assert_allclose(output.data, expected_output)
+ 
+ 
+ if __name__ == '__main__':
+   absltest.main()
benchmarks/CLRS/env/losses.py ADDED
@@ -0,0 +1,209 @@
+ # Copyright 2021 DeepMind Technologies Limited. All Rights Reserved.
+ #
+ # Licensed under the Apache License, Version 2.0 (the "License");
+ # you may not use this file except in compliance with the License.
+ # You may obtain a copy of the License at
+ #
+ #     http://www.apache.org/licenses/LICENSE-2.0
+ #
+ # Unless required by applicable law or agreed to in writing, software
+ # distributed under the License is distributed on an "AS IS" BASIS,
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ # See the License for the specific language governing permissions and
+ # limitations under the License.
+ # ==============================================================================
+ """Utilities for calculating losses."""
+ 
+ from typing import Dict, List, Tuple
+ import chex
+ from clrs._src import probing
+ from clrs._src import specs
+ 
+ import haiku as hk
+ import jax
+ import jax.numpy as jnp
+ 
+ _Array = chex.Array
+ _DataPoint = probing.DataPoint
+ _Location = specs.Location
+ _OutputClass = specs.OutputClass
+ _PredTrajectory = Dict[str, _Array]
+ _PredTrajectories = List[_PredTrajectory]
+ _Type = specs.Type
+ 
+ EPS = 1e-12
+ 
+ 
+ def _expand_to(x: _Array, y: _Array) -> _Array:
+   while len(y.shape) > len(x.shape):
+     x = jnp.expand_dims(x, -1)
+   return x
+ 
+ 
+ def _expand_and_broadcast_to(x: _Array, y: _Array) -> _Array:
+   return jnp.broadcast_to(_expand_to(x, y), y.shape)
+ 
+ 
+ def output_loss_chunked(truth: _DataPoint, pred: _Array,
+                         is_last: _Array, nb_nodes: int) -> float:
+   """Output loss for time-chunked training."""
+ 
+   mask = None
+ 
+   if truth.type_ == _Type.SCALAR:
+     loss = (pred - truth.data)**2
+ 
+   elif truth.type_ == _Type.MASK:
+     loss = (
+         jnp.maximum(pred, 0) - pred * truth.data +
+         jnp.log1p(jnp.exp(-jnp.abs(pred))))
+     mask = (truth.data != _OutputClass.MASKED)
+ 
+   elif truth.type_ in [_Type.MASK_ONE, _Type.CATEGORICAL]:
+     mask = jnp.any(truth.data == _OutputClass.POSITIVE, axis=-1)
+     masked_truth = truth.data * (truth.data != _OutputClass.MASKED).astype(
+         jnp.float32)
+     loss = -jnp.sum(masked_truth * jax.nn.log_softmax(pred), axis=-1)
+ 
+   elif truth.type_ == _Type.POINTER:
+     loss = -jnp.sum(
+         hk.one_hot(truth.data, nb_nodes) * jax.nn.log_softmax(pred), axis=-1)
+ 
+   elif truth.type_ == _Type.PERMUTATION_POINTER:
+     # Predictions are NxN logits aiming to represent a doubly stochastic matrix.
+     # Compute the cross entropy between doubly stochastic pred and truth_data
+     loss = -jnp.sum(truth.data * pred, axis=-1)
+ 
+   if mask is not None:
+     mask = mask * _expand_and_broadcast_to(is_last, loss)
+   else:
+     mask = _expand_and_broadcast_to(is_last, loss)
+   total_mask = jnp.maximum(jnp.sum(mask), EPS)
+   return jnp.sum(jnp.where(mask, loss, 0.0)) / total_mask
+ 
+ 
+ def output_loss(truth: _DataPoint, pred: _Array, nb_nodes: int) -> float:
+   """Output loss for full-sample training."""
+ 
+   if truth.type_ == _Type.SCALAR:
+     total_loss = jnp.mean((pred - truth.data)**2)
+ 
+   elif truth.type_ == _Type.MASK:
+     loss = (
+         jnp.maximum(pred, 0) - pred * truth.data +
+         jnp.log1p(jnp.exp(-jnp.abs(pred))))
+     mask = (truth.data != _OutputClass.MASKED).astype(jnp.float32)
+     total_loss = jnp.sum(loss * mask) / jnp.sum(mask)
+ 
+   elif truth.type_ in [_Type.MASK_ONE, _Type.CATEGORICAL]:
+     masked_truth = truth.data * (truth.data != _OutputClass.MASKED).astype(
+         jnp.float32)
+     total_loss = (-jnp.sum(masked_truth * jax.nn.log_softmax(pred)) /
+                   jnp.sum(truth.data == _OutputClass.POSITIVE))
+ 
+   elif truth.type_ == _Type.POINTER:
+     total_loss = (
+         jnp.mean(-jnp.sum(
+             hk.one_hot(truth.data, nb_nodes) * jax.nn.log_softmax(pred),
+             axis=-1)))
+ 
+   elif truth.type_ == _Type.PERMUTATION_POINTER:
+     # Predictions are NxN logits aiming to represent a doubly stochastic matrix.
+     # Compute the cross entropy between doubly stochastic pred and truth_data
+     total_loss = jnp.mean(-jnp.sum(truth.data * pred, axis=-1))
+ 
+   return total_loss
+ 
+ 
+ def hint_loss_chunked(
+     truth: _DataPoint,
+     pred: _Array,
+     is_first: _Array,
+     nb_nodes: int,
+ ):
+   """Hint loss for time-chunked training."""
+   loss, mask = _hint_loss(
+       truth_data=truth.data,
+       truth_type=truth.type_,
+       pred=pred,
+       nb_nodes=nb_nodes,
+   )
+ 
+   mask *= (1 - _expand_to(is_first, loss)).astype(jnp.float32)
+   loss = jnp.sum(loss * mask) / jnp.maximum(jnp.sum(mask), EPS)
+   return loss
+ 
+ 
+ def hint_loss(
+     truth: _DataPoint,
+     preds: List[_Array],
+     lengths: _Array,
+     nb_nodes: int,
+     verbose: bool = False,
+ ):
+   """Hint loss for full-sample training."""
+   total_loss = 0.
+   verbose_loss = {}
+   length = truth.data.shape[0] - 1
+ 
+   loss, mask = _hint_loss(
+       truth_data=truth.data[1:],
+       truth_type=truth.type_,
+       pred=jnp.stack(preds),
+       nb_nodes=nb_nodes,
+   )
+   mask *= _is_not_done_broadcast(lengths, jnp.arange(length)[:, None], loss)
+   loss = jnp.sum(loss * mask) / jnp.maximum(jnp.sum(mask), EPS)
+   if verbose:
+     verbose_loss['loss_' + truth.name] = loss
+   else:
+     total_loss += loss
+ 
+   return verbose_loss if verbose else total_loss
+ 
+ 
+ def _hint_loss(
+     truth_data: _Array,
+     truth_type: str,
+     pred: _Array,
+     nb_nodes: int,
+ ) -> Tuple[_Array, _Array]:
+   """Hint loss helper."""
+   mask = None
+   if truth_type == _Type.SCALAR:
+     loss = (pred - truth_data)**2
+ 
+   elif truth_type == _Type.MASK:
+     loss = (jnp.maximum(pred, 0) - pred * truth_data +
+             jnp.log1p(jnp.exp(-jnp.abs(pred))))
+     mask = (truth_data != _OutputClass.MASKED).astype(jnp.float32)  # pytype: disable=attribute-error  # numpy-scalars
+ 
+   elif truth_type == _Type.MASK_ONE:
+     loss = -jnp.sum(truth_data * jax.nn.log_softmax(pred), axis=-1,
+                     keepdims=True)
+ 
+   elif truth_type == _Type.CATEGORICAL:
+     loss = -jnp.sum(truth_data * jax.nn.log_softmax(pred), axis=-1)
+     mask = jnp.any(truth_data == _OutputClass.POSITIVE, axis=-1).astype(
+         jnp.float32)
+ 
+   elif truth_type == _Type.POINTER:
+     loss = -jnp.sum(
+         hk.one_hot(truth_data, nb_nodes) * jax.nn.log_softmax(pred),
+         axis=-1)
+ 
+   elif truth_type == _Type.PERMUTATION_POINTER:
+     # Predictions are NxN logits aiming to represent a doubly stochastic matrix.
+     # Compute the cross entropy between doubly stochastic pred and truth_data
+     loss = -jnp.sum(truth_data * pred, axis=-1)
+ 
+   if mask is None:
+     mask = jnp.ones_like(loss)
+   return loss, mask
+ 
+ 
+ def _is_not_done_broadcast(lengths, i, tensor):
+   is_not_done = (lengths > i + 1) * 1.0
+   while len(is_not_done.shape) < len(tensor.shape):  # pytype: disable=attribute-error  # numpy-scalars
+     is_not_done = jnp.expand_dims(is_not_done, -1)
+   return is_not_done
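
A note on the MASK losses above: the expression `max(x, 0) - x*z + log1p(exp(-|x|))` is the standard numerically stable form of sigmoid cross-entropy on logits x with target z. It equals `-z*log(sigmoid(x)) - (1-z)*log(1-sigmoid(x))` but never overflows for large |x|. A quick check (illustration only, not part of the commit):

import jax
import jax.numpy as jnp

x = jnp.array([3.0, -2.0, 0.5])  # logits
z = jnp.array([1.0, 0.0, 1.0])   # binary targets
stable = jnp.maximum(x, 0) - x * z + jnp.log1p(jnp.exp(-jnp.abs(x)))
s = jax.nn.sigmoid(x)
naive = -(z * jnp.log(s) + (1 - z) * jnp.log(1 - s))
assert jnp.allclose(stable, naive, atol=1e-6)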
benchmarks/CLRS/env/losses_test.py ADDED
@@ -0,0 +1,166 @@
+ # Copyright 2022 DeepMind Technologies Limited. All Rights Reserved.
+ #
+ # Licensed under the Apache License, Version 2.0 (the "License");
+ # you may not use this file except in compliance with the License.
+ # You may obtain a copy of the License at
+ #
+ #     http://www.apache.org/licenses/LICENSE-2.0
+ #
+ # Unless required by applicable law or agreed to in writing, software
+ # distributed under the License is distributed on an "AS IS" BASIS,
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ # See the License for the specific language governing permissions and
+ # limitations under the License.
+ # ==============================================================================
+ 
+ """Unit tests for `losses.py`."""
+ 
+ from typing import Generator
+ 
+ from absl.testing import absltest
+ from absl.testing import parameterized
+ 
+ from clrs._src import dataset
+ from clrs._src import losses
+ from clrs._src import probing
+ from clrs._src import samplers
+ from clrs._src import specs
+ import jax
+ import jax.numpy as jnp
+ import numpy as np
+ 
+ _Array = np.ndarray
+ _Location = specs.Location
+ 
+ 
+ def _make_sampler(algo: str, nb_nodes: int) -> samplers.Sampler:
+   sampler, _ = samplers.build_sampler(
+       algo,
+       seed=samplers.CLRS30['val']['seed'],
+       num_samples=samplers.CLRS30['val']['num_samples'],
+       length=nb_nodes,
+   )
+   return sampler
+ 
+ 
+ def _make_iterable_sampler(
+     algo: str, batch_size: int,
+     nb_nodes: int) -> Generator[samplers.Feedback, None, None]:
+   sampler = _make_sampler(algo, nb_nodes)
+   while True:
+     yield sampler.next(batch_size)
+ 
+ 
+ def _as_pred_data(x, nb_nodes, seed, batch_axis):
+   """Fake a prediction from a data point."""
+   # Permute along batch axis to make the prediction different.
+   key = jax.random.PRNGKey(seed)
+   data = jax.random.permutation(key, x.data, axis=batch_axis)
+   # Extend to one-hot for pointer types.
+   if x.type_ == specs.Type.POINTER:
+     return jax.nn.one_hot(data, nb_nodes)
+   return data
+ 
+ 
+ def _mask_datapoint(x, seed, t_axis=None):
+   """Add some masking to data."""
+   key = jax.random.PRNGKey(seed)
+   data = x.data
+   if x.type_ == specs.Type.MASK:
+     # mask some data at random
+     mask_shape = list(data.shape)
+     if t_axis is not None:
+       mask_shape[t_axis] = 1
+     mask = jax.random.uniform(key, tuple(mask_shape)) < 0.2
+     data = jnp.where(mask, specs.OutputClass.MASKED, data)
+   elif x.type_ in [specs.Type.CATEGORICAL, specs.Type.MASK_ONE]:
+     # mask some data at random (all categories together)
+     mask_shape = list(data.shape)[:-1]
+     if t_axis is not None:
+       mask_shape[t_axis] = 1
+     mask = jax.random.uniform(key, tuple(mask_shape)) < 0.2
+     data = jnp.where(mask[..., None], specs.OutputClass.MASKED, data)
+   return probing.DataPoint(name=x.name, location=x.location, type_=x.type_,
+                            data=data)
+ 
+ 
+ def _rand_diff(seed, shape):
+   return 2.0 * jax.random.uniform(jax.random.PRNGKey(seed), shape) - 1.0
+ 
+ 
+ def _rand_mask(seed, shape, p=0.5):
+   return (jax.random.uniform(jax.random.PRNGKey(seed), shape) > p).astype(float)
+ 
+ 
+ def invert(d):
+   """Dict of lists -> list of dicts."""
+   if d:
+     return [dict(zip(d, i)) for i in zip(*d.values())]
+ 
+ 
+ def _create_data(algo, nb_nodes):
+   batch_size = 8
+ 
+   ds = _make_iterable_sampler(algo, batch_size, nb_nodes)
+   full_sample = next(ds)
+ 
+   chunk_length = full_sample.features.lengths[0].astype(int)
+   chunked_ds = dataset.chunkify(
+       _make_iterable_sampler(algo, batch_size, nb_nodes),
+       chunk_length)
+   chunk_sample = next(chunked_ds)
+   return full_sample, chunk_sample
+ 
+ 
+ class FullVsChunkLossesTest(parameterized.TestCase):
+   """Test that the full and chunked versions of the losses match."""
+ 
+   # Test two algorithms with fixed-length, covering all data types
+   @parameterized.parameters('dfs', 'floyd_warshall')
+   def test_output_loss(self, algo):
+     nb_nodes = 16
+     full_sample, chunk_sample = _create_data(algo, nb_nodes)
+ 
+     # Calculate output loss.
+     for truth_full, truth_chunked in zip(full_sample.outputs,
+                                          chunk_sample.outputs):
+       chunk_output_loss = losses.output_loss_chunked(
+           truth=_mask_datapoint(truth_chunked, seed=0),
+           pred=_as_pred_data(truth_chunked, nb_nodes, 0, 1),
+           is_last=chunk_sample.features.is_last,
+           nb_nodes=nb_nodes,
+       )
+       full_output_loss = losses.output_loss(
+           truth=_mask_datapoint(truth_full, seed=0),
+           pred=_as_pred_data(truth_full, nb_nodes, 0, 0),
+           nb_nodes=nb_nodes,
+       )
+       np.testing.assert_allclose(chunk_output_loss, full_output_loss, rtol=1e-4)
+ 
+   @parameterized.parameters('dfs', 'floyd_warshall')
+   def test_hint_loss(self, algo):
+     nb_nodes = 16
+     full_sample, chunk_sample = _create_data(algo, nb_nodes)
+     for truth_full, truth_chunked in zip(full_sample.features.hints,
+                                          chunk_sample.features.hints):
+       np.testing.assert_array_equal(truth_full.data, truth_chunked.data)
+       pred = _as_pred_data(truth_chunked, nb_nodes, 0, 1)
+       chunk_hint_loss = losses.hint_loss_chunked(
+           truth=_mask_datapoint(truth_chunked, seed=1, t_axis=0),
+           pred=pred,
+           is_first=chunk_sample.features.is_first,
+           nb_nodes=nb_nodes,
+       )
+ 
+       full_preds = pred[1:]
+       full_hint_loss = losses.hint_loss(
+           truth=_mask_datapoint(truth_full, 1, t_axis=0),
+           preds=full_preds,
+           lengths=full_sample.features.lengths,
+           nb_nodes=nb_nodes,
+       )
+       np.testing.assert_allclose(chunk_hint_loss, full_hint_loss, rtol=1e-4)
+ 
+ 
+ if __name__ == '__main__':
+   absltest.main()
benchmarks/CLRS/env/model.py ADDED
@@ -0,0 +1,46 @@
+ # Copyright 2021 DeepMind Technologies Limited. All Rights Reserved.
+ #
+ # Licensed under the Apache License, Version 2.0 (the "License");
+ # you may not use this file except in compliance with the License.
+ # You may obtain a copy of the License at
+ #
+ #     http://www.apache.org/licenses/LICENSE-2.0
+ #
+ # Unless required by applicable law or agreed to in writing, software
+ # distributed under the License is distributed on an "AS IS" BASIS,
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ # See the License for the specific language governing permissions and
+ # limitations under the License.
+ # ==============================================================================
+ 
+ """Model base classes and utilities."""
+ 
+ import abc
+ from typing import Dict, List, Optional, Union
+ 
+ from clrs._src import probing
+ from clrs._src import samplers
+ from clrs._src import specs
+ 
+ 
+ Result = Dict[str, probing.DataPoint]
+ 
+ 
+ class Model(abc.ABC):
+   """Abstract base class for CLRS3-B models."""
+ 
+   def __init__(self, spec: Union[specs.Spec, List[specs.Spec]]):
+     """Set up the problem, prepare to predict on first task."""
+     if not isinstance(spec, list):
+       spec = [spec]
+     self._spec = spec
+ 
+   @abc.abstractmethod
+   def predict(self, features: samplers.Features) -> Result:
+     """Make predictions about the current task."""
+     pass
+ 
+   @abc.abstractmethod
+   def feedback(self, feedback: Optional[samplers.Feedback]):
+     """Advance to the next task, incorporating any available feedback."""
+     pass
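
A minimal sketch of the interface above (illustration only, not part of the commit): a concrete model only has to implement `predict` and `feedback`. The `NoOpModel` class below is hypothetical and exists just to show the shape of the API.

class NoOpModel(Model):
  """Does nothing useful; illustrates the abstract interface."""

  def predict(self, features):
    return {}  # a real model returns a Result keyed by output name

  def feedback(self, feedback):
    pass  # a real model would update its parameters here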
benchmarks/CLRS/env/nets.py ADDED
@@ -0,0 +1,719 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright 2021 DeepMind Technologies Limited. All Rights Reserved.
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+ # ==============================================================================
15
+
16
+ """JAX implementation of CLRS basic network."""
17
+
18
+ import functools
19
+
20
+ from typing import Dict, List, Optional, Tuple
21
+
22
+ import chex
23
+
24
+ from clrs._src import decoders
25
+ from clrs._src import encoders
26
+ from clrs._src import probing
27
+ from clrs._src import processors
28
+ from clrs._src import samplers
29
+ from clrs._src import specs
30
+
31
+ import haiku as hk
32
+ import jax
33
+ import jax.numpy as jnp
34
+
35
+
36
+ _Array = chex.Array
37
+ _DataPoint = probing.DataPoint
38
+ _Features = samplers.Features
39
+ _FeaturesChunked = samplers.FeaturesChunked
40
+ _Location = specs.Location
41
+ _Spec = specs.Spec
42
+ _Stage = specs.Stage
43
+ _Trajectory = samplers.Trajectory
44
+ _Type = specs.Type
45
+
46
+
47
+ @chex.dataclass
48
+ class _MessagePassingScanState:
49
+ hint_preds: chex.Array
50
+ output_preds: chex.Array
51
+ hiddens: chex.Array
52
+ lstm_state: Optional[hk.LSTMState]
53
+
54
+
55
+ @chex.dataclass
56
+ class _MessagePassingOutputChunked:
57
+ hint_preds: chex.Array
58
+ output_preds: chex.Array
59
+
60
+
61
+ @chex.dataclass
62
+ class MessagePassingStateChunked:
63
+ inputs: chex.Array
64
+ hints: chex.Array
65
+ is_first: chex.Array
66
+ hint_preds: chex.Array
67
+ hiddens: chex.Array
68
+ lstm_state: Optional[hk.LSTMState]
69
+
70
+
71
+ class Net(hk.Module):
72
+ """Building blocks (networks) used to encode and decode messages."""
73
+
74
+ def __init__(
75
+ self,
76
+ spec: List[_Spec],
77
+ hidden_dim: int,
78
+ encode_hints: bool,
79
+ decode_hints: bool,
80
+ processor_factory: processors.ProcessorFactory,
81
+ use_lstm: bool,
82
+ encoder_init: str,
83
+ dropout_prob: float,
84
+ hint_teacher_forcing: float,
85
+ hint_repred_mode='soft',
86
+ nb_dims=None,
87
+ nb_msg_passing_steps=1,
88
+ name: str = 'net',
89
+ ):
90
+ """Constructs a `Net`."""
91
+ super().__init__(name=name)
92
+
93
+ self._dropout_prob = dropout_prob
94
+ self._hint_teacher_forcing = hint_teacher_forcing
95
+ self._hint_repred_mode = hint_repred_mode
96
+ self.spec = spec
97
+ self.hidden_dim = hidden_dim
98
+ self.encode_hints = encode_hints
99
+ self.decode_hints = decode_hints
100
+ self.processor_factory = processor_factory
101
+ self.nb_dims = nb_dims
102
+ self.use_lstm = use_lstm
103
+ self.encoder_init = encoder_init
104
+ self.nb_msg_passing_steps = nb_msg_passing_steps
105
+
106
+ def _msg_passing_step(self,
107
+ mp_state: _MessagePassingScanState,
108
+ i: int,
109
+ hints: List[_DataPoint],
110
+ repred: bool,
111
+ lengths: chex.Array,
112
+ batch_size: int,
113
+ nb_nodes: int,
114
+ inputs: _Trajectory,
115
+ first_step: bool,
116
+ spec: _Spec,
117
+ encs: Dict[str, List[hk.Module]],
118
+ decs: Dict[str, Tuple[hk.Module]],
119
+ return_hints: bool,
120
+ return_all_outputs: bool
121
+ ):
122
+ if self.decode_hints and not first_step:
123
+ assert self._hint_repred_mode in ['soft', 'hard', 'hard_on_eval']
124
+ hard_postprocess = (self._hint_repred_mode == 'hard' or
125
+ (self._hint_repred_mode == 'hard_on_eval' and repred))
126
+ decoded_hint = decoders.postprocess(spec,
127
+ mp_state.hint_preds,
128
+ sinkhorn_temperature=0.1,
129
+ sinkhorn_steps=25,
130
+ hard=hard_postprocess)
131
+ if repred and self.decode_hints and not first_step:
132
+ cur_hint = []
133
+ for hint in decoded_hint:
134
+ cur_hint.append(decoded_hint[hint])
135
+ else:
136
+ cur_hint = []
137
+ needs_noise = (self.decode_hints and not first_step and
138
+ self._hint_teacher_forcing < 1.0)
139
+ if needs_noise:
140
+ # For noisy teacher forcing, choose which examples in the batch to force
141
+ force_mask = jax.random.bernoulli(
142
+ hk.next_rng_key(), self._hint_teacher_forcing,
143
+ (batch_size,))
144
+ else:
145
+ force_mask = None
146
+ for hint in hints:
147
+ hint_data = jnp.asarray(hint.data)[i]
148
+ _, loc, typ = spec[hint.name]
149
+ if needs_noise:
150
+ if (typ == _Type.POINTER and
151
+ decoded_hint[hint.name].type_ == _Type.SOFT_POINTER):
152
+ # When using soft pointers, the decoded hints cannot be summarised
153
+ # as indices (as would happen in hard postprocessing), so we need
154
+ # to raise the ground-truth hint (potentially used for teacher
155
+ # forcing) to its one-hot version.
156
+ hint_data = hk.one_hot(hint_data, nb_nodes)
157
+ typ = _Type.SOFT_POINTER
158
+ hint_data = jnp.where(_expand_to(force_mask, hint_data),
159
+ hint_data,
160
+ decoded_hint[hint.name].data)
161
+ cur_hint.append(
162
+ probing.DataPoint(
163
+ name=hint.name, location=loc, type_=typ, data=hint_data))
164
+
165
+ hiddens, output_preds_cand, hint_preds, lstm_state = self._one_step_pred(
166
+ inputs, cur_hint, mp_state.hiddens,
167
+ batch_size, nb_nodes, mp_state.lstm_state,
168
+ spec, encs, decs, repred)
169
+
170
+ if first_step:
171
+ output_preds = output_preds_cand
172
+ else:
173
+ output_preds = {}
174
+ for outp in mp_state.output_preds:
175
+ is_not_done = _is_not_done_broadcast(lengths, i,
176
+ output_preds_cand[outp])
177
+ output_preds[outp] = is_not_done * output_preds_cand[outp] + (
178
+ 1.0 - is_not_done) * mp_state.output_preds[outp]
179
+
180
+ new_mp_state = _MessagePassingScanState( # pytype: disable=wrong-arg-types # numpy-scalars
181
+ hint_preds=hint_preds,
182
+ output_preds=output_preds,
183
+ hiddens=hiddens,
184
+ lstm_state=lstm_state)
185
+ # Save memory by not stacking unnecessary fields
186
+ accum_mp_state = _MessagePassingScanState( # pytype: disable=wrong-arg-types # numpy-scalars
187
+ hint_preds=hint_preds if return_hints else None,
188
+ output_preds=output_preds if return_all_outputs else None,
189
+ hiddens=None, lstm_state=None)
190
+
191
+ # Complying to jax.scan, the first returned value is the state we carry over
192
+ # the second value is the output that will be stacked over steps.
193
+ return new_mp_state, accum_mp_state
194
+
195
+ def __call__(self, features_list: List[_Features], repred: bool,
196
+ algorithm_index: int,
197
+ return_hints: bool,
198
+ return_all_outputs: bool):
199
+ """Process one batch of data.
200
+
201
+ Args:
202
+ features_list: A list of _Features objects, each with the inputs, hints
203
+ and lengths for a batch o data corresponding to one algorithm.
204
+ The list should have either length 1, at train/evaluation time,
205
+ or length equal to the number of algorithms this Net is meant to
206
+ process, at initialization.
207
+ repred: False during training, when we have access to ground-truth hints.
208
+ True in validation/test mode, when we have to use our own
209
+ hint predictions.
210
+ algorithm_index: Which algorithm is being processed. It can be -1 at
211
+ initialisation (either because we are initialising the parameters of
212
+ the module or because we are intialising the message-passing state),
213
+ meaning that all algorithms should be processed, in which case
214
+ `features_list` should have length equal to the number of specs of
215
+ the Net. Otherwise, `algorithm_index` should be
216
+ between 0 and `length(self.spec) - 1`, meaning only one of the
217
+ algorithms will be processed, and `features_list` should have length 1.
218
+ return_hints: Whether to accumulate and return the predicted hints,
219
+ when they are decoded.
220
+ return_all_outputs: Whether to return the full sequence of outputs, or
221
+ just the last step's output.
222
+
223
+ Returns:
224
+ A 2-tuple with (output predictions, hint predictions)
225
+ for the selected algorithm.
226
+ """
227
+ if algorithm_index == -1:
228
+ algorithm_indices = range(len(features_list))
229
+ else:
230
+ algorithm_indices = [algorithm_index]
231
+ assert len(algorithm_indices) == len(features_list)
232
+
233
+ self.encoders, self.decoders = self._construct_encoders_decoders()
234
+ self.processor = self.processor_factory(self.hidden_dim)
235
+
236
+ # Optionally construct LSTM.
237
+ if self.use_lstm:
238
+ self.lstm = hk.LSTM(
239
+ hidden_size=self.hidden_dim,
240
+ name='processor_lstm')
241
+ lstm_init = self.lstm.initial_state
242
+ else:
243
+ self.lstm = None
244
+ lstm_init = lambda x: 0
245
+
246
+ for algorithm_index, features in zip(algorithm_indices, features_list):
247
+ inputs = features.inputs
248
+ hints = features.hints
249
+ lengths = features.lengths
250
+
251
+ batch_size, nb_nodes = _data_dimensions(features)
252
+
253
+ nb_mp_steps = max(1, hints[0].data.shape[0] - 1)
254
+ hiddens = jnp.zeros((batch_size, nb_nodes, self.hidden_dim))
255
+
256
+ if self.use_lstm:
257
+ lstm_state = lstm_init(batch_size * nb_nodes)
258
+ lstm_state = jax.tree_util.tree_map(
259
+ lambda x, b=batch_size, n=nb_nodes: jnp.reshape(x, [b, n, -1]),
260
+ lstm_state)
261
+ else:
262
+ lstm_state = None
263
+
264
+ mp_state = _MessagePassingScanState( # pytype: disable=wrong-arg-types # numpy-scalars
265
+ hint_preds=None, output_preds=None,
266
+ hiddens=hiddens, lstm_state=lstm_state)
267
+
268
+ # Do the first step outside of the scan because it has a different
269
+ # computation graph.
270
+ common_args = dict(
271
+ hints=hints,
272
+ repred=repred,
273
+ inputs=inputs,
274
+ batch_size=batch_size,
275
+ nb_nodes=nb_nodes,
276
+ lengths=lengths,
277
+ spec=self.spec[algorithm_index],
278
+ encs=self.encoders[algorithm_index],
279
+ decs=self.decoders[algorithm_index],
280
+ return_hints=return_hints,
281
+ return_all_outputs=return_all_outputs,
282
+ )
283
+ mp_state, lean_mp_state = self._msg_passing_step(
284
+ mp_state,
285
+ i=0,
286
+ first_step=True,
287
+ **common_args)
288
+
289
+ # Then scan through the rest.
290
+ scan_fn = functools.partial(
291
+ self._msg_passing_step,
292
+ first_step=False,
293
+ **common_args)
294
+
295
+ output_mp_state, accum_mp_state = hk.scan(
296
+ scan_fn,
297
+ mp_state,
298
+ jnp.arange(nb_mp_steps - 1) + 1,
299
+ length=nb_mp_steps - 1)
300
+
301
+ # We only return the last algorithm's output. That's because
302
+ # the output only matters when a single algorithm is processed; the case
303
+ # `algorithm_index==-1` (meaning all algorithms should be processed)
304
+ # is used only to init parameters.
305
+ accum_mp_state = jax.tree_util.tree_map(
306
+ lambda init, tail: jnp.concatenate([init[None], tail], axis=0),
307
+ lean_mp_state, accum_mp_state)
308
+
309
+ def invert(d):
310
+ """Dict of lists -> list of dicts."""
311
+ if d:
312
+ return [dict(zip(d, i)) for i in zip(*d.values())]
313
+
314
+ if return_all_outputs:
315
+ output_preds = {k: jnp.stack(v)
316
+ for k, v in accum_mp_state.output_preds.items()}
317
+ else:
318
+ output_preds = output_mp_state.output_preds
319
+ hint_preds = invert(accum_mp_state.hint_preds)
320
+
321
+ return output_preds, hint_preds
322
+
323
+ def _construct_encoders_decoders(self):
324
+ """Constructs encoders and decoders, separate for each algorithm."""
325
+ encoders_ = []
326
+ decoders_ = []
327
+ enc_algo_idx = None
328
+ for (algo_idx, spec) in enumerate(self.spec):
329
+ enc = {}
330
+ dec = {}
331
+ for name, (stage, loc, t) in spec.items():
332
+ if stage == _Stage.INPUT or (
333
+ stage == _Stage.HINT and self.encode_hints):
334
+ # Build input encoders.
335
+ if name == specs.ALGO_IDX_INPUT_NAME:
336
+ if enc_algo_idx is None:
337
+ enc_algo_idx = [hk.Linear(self.hidden_dim,
338
+ name=f'{name}_enc_linear')]
339
+ enc[name] = enc_algo_idx
340
+ else:
341
+ enc[name] = encoders.construct_encoders(
342
+ stage, loc, t, hidden_dim=self.hidden_dim,
343
+ init=self.encoder_init,
344
+ name=f'algo_{algo_idx}_{name}')
345
+
346
+ if stage == _Stage.OUTPUT or (
347
+ stage == _Stage.HINT and self.decode_hints):
348
+ # Build output decoders.
349
+ dec[name] = decoders.construct_decoders(
350
+ loc, t, hidden_dim=self.hidden_dim,
351
+ nb_dims=self.nb_dims[algo_idx][name],
352
+ name=f'algo_{algo_idx}_{name}')
353
+ encoders_.append(enc)
354
+ decoders_.append(dec)
355
+
356
+ return encoders_, decoders_
357
+
358
+ def _one_step_pred(
359
+ self,
360
+ inputs: _Trajectory,
361
+ hints: _Trajectory,
362
+ hidden: _Array,
363
+ batch_size: int,
364
+ nb_nodes: int,
365
+ lstm_state: Optional[hk.LSTMState],
366
+ spec: _Spec,
367
+ encs: Dict[str, List[hk.Module]],
368
+ decs: Dict[str, Tuple[hk.Module]],
369
+ repred: bool,
370
+ ):
371
+ """Generates one-step predictions."""
372
+
373
+ # Initialise empty node/edge/graph features and adjacency matrix.
374
+ node_fts = jnp.zeros((batch_size, nb_nodes, self.hidden_dim))
375
+ edge_fts = jnp.zeros((batch_size, nb_nodes, nb_nodes, self.hidden_dim))
376
+ graph_fts = jnp.zeros((batch_size, self.hidden_dim))
377
+ adj_mat = jnp.repeat(
378
+ jnp.expand_dims(jnp.eye(nb_nodes), 0), batch_size, axis=0)
379
+
380
+ # ENCODE ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
381
+ # Encode node/edge/graph features from inputs and (optionally) hints.
382
+ trajectories = [inputs]
383
+ if self.encode_hints:
384
+ trajectories.append(hints)
385
+
386
+ for trajectory in trajectories:
387
+ for dp in trajectory:
388
+ try:
389
+ dp = encoders.preprocess(dp, nb_nodes)
390
+ assert dp.type_ != _Type.SOFT_POINTER
391
+ adj_mat = encoders.accum_adj_mat(dp, adj_mat)
392
+ encoder = encs[dp.name]
393
+ edge_fts = encoders.accum_edge_fts(encoder, dp, edge_fts)
394
+ node_fts = encoders.accum_node_fts(encoder, dp, node_fts)
395
+ graph_fts = encoders.accum_graph_fts(encoder, dp, graph_fts)
396
+ except Exception as e:
397
+ raise Exception(f'Failed to process {dp}') from e
398
+
399
+ # PROCESS ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
400
+ nxt_hidden = hidden
401
+ for _ in range(self.nb_msg_passing_steps):
402
+ nxt_hidden, nxt_edge = self.processor(
403
+ node_fts,
404
+ edge_fts,
405
+ graph_fts,
406
+ adj_mat,
407
+ nxt_hidden,
408
+ batch_size=batch_size,
409
+ nb_nodes=nb_nodes,
410
+ )
411
+
412
+ if not repred: # dropout only on training
413
+ nxt_hidden = hk.dropout(hk.next_rng_key(), self._dropout_prob, nxt_hidden)
414
+
415
+ if self.use_lstm:
416
+ # lstm doesn't accept multiple batch dimensions (in our case, batch and
417
+ # nodes), so we vmap over the (first) batch dimension.
418
+ nxt_hidden, nxt_lstm_state = jax.vmap(self.lstm)(nxt_hidden, lstm_state)
419
+ else:
420
+ nxt_lstm_state = None
421
+
422
+ h_t = jnp.concatenate([node_fts, hidden, nxt_hidden], axis=-1)
423
+ if nxt_edge is not None:
424
+ e_t = jnp.concatenate([edge_fts, nxt_edge], axis=-1)
425
+ else:
426
+ e_t = edge_fts
427
+
428
+ # DECODE ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
429
+ # Decode features and (optionally) hints.
430
+ hint_preds, output_preds = decoders.decode_fts(
431
+ decoders=decs,
432
+ spec=spec,
433
+ h_t=h_t,
434
+ adj_mat=adj_mat,
435
+ edge_fts=e_t,
436
+ graph_fts=graph_fts,
437
+ inf_bias=self.processor.inf_bias,
438
+ inf_bias_edge=self.processor.inf_bias_edge,
439
+ repred=repred,
440
+ )
441
+
442
+ return nxt_hidden, output_preds, hint_preds, nxt_lstm_state
443
+
444
+
445
+ class NetChunked(Net):
446
+ """A Net that will process time-chunked data instead of full samples."""
447
+
448
+ def _msg_passing_step(self,
449
+       mp_state: MessagePassingStateChunked,
+       xs,
+       repred: bool,
+       init_mp_state: bool,
+       batch_size: int,
+       nb_nodes: int,
+       spec: _Spec,
+       encs: Dict[str, List[hk.Module]],
+       decs: Dict[str, Tuple[hk.Module]],
+   ):
+     """Perform one message passing step.
+
+     This function is unrolled along the time axis to process a data chunk.
+
+     Args:
+       mp_state: message-passing state. Includes the inputs, hints,
+         beginning-of-sample markers, hint predictions, hidden and lstm state
+         to be used for prediction in the current step.
+       xs: A 3-tuple with the next timestep's inputs, hints, and
+         beginning-of-sample markers. These will replace the contents of
+         the `mp_state` at the output, in readiness for the next unroll step
+         of the chunk (or the first step of the next chunk). In addition, the
+         next timestep's hints are necessary to compute diffs when
+         `decode_diffs` is True.
+       repred: False during training, when we have access to ground-truth
+         hints. True in validation/test mode, when we have to use our own
+         hint predictions.
+       init_mp_state: Indicates if we are calling the method just to
+         initialise the message-passing state, before the beginning of
+         training or validation.
+       batch_size: Size of batch dimension.
+       nb_nodes: Number of nodes in graph.
+       spec: The spec of the algorithm being processed.
+       encs: encoders for the algorithm being processed.
+       decs: decoders for the algorithm being processed.
+
+     Returns:
+       A 2-tuple with the next mp_state and an output consisting of
+       hint predictions and output predictions.
+     """
+     def _as_prediction_data(hint):
+       if hint.type_ == _Type.POINTER:
+         return hk.one_hot(hint.data, nb_nodes)
+       return hint.data
+
+     nxt_inputs, nxt_hints, nxt_is_first = xs
+     inputs = mp_state.inputs
+     is_first = mp_state.is_first
+     hints = mp_state.hints
+     if init_mp_state:
+       prev_hint_preds = {h.name: _as_prediction_data(h) for h in hints}
+       hints_for_pred = hints
+     else:
+       prev_hint_preds = mp_state.hint_preds
+       if self.decode_hints:
+         if repred:
+           force_mask = jnp.zeros(batch_size, dtype=bool)
+         elif self._hint_teacher_forcing == 1.0:
+           force_mask = jnp.ones(batch_size, dtype=bool)
+         else:
+           force_mask = jax.random.bernoulli(
+               hk.next_rng_key(), self._hint_teacher_forcing,
+               (batch_size,))
+         assert self._hint_repred_mode in ['soft', 'hard', 'hard_on_eval']
+         hard_postprocess = (
+             self._hint_repred_mode == 'hard' or
+             (self._hint_repred_mode == 'hard_on_eval' and repred))
+         decoded_hints = decoders.postprocess(spec,
+                                              prev_hint_preds,
+                                              sinkhorn_temperature=0.1,
+                                              sinkhorn_steps=25,
+                                              hard=hard_postprocess)
+         hints_for_pred = []
+         for h in hints:
+           typ = h.type_
+           hint_data = h.data
+           if (typ == _Type.POINTER and
+               decoded_hints[h.name].type_ == _Type.SOFT_POINTER):
+             hint_data = hk.one_hot(hint_data, nb_nodes)
+             typ = _Type.SOFT_POINTER
+           hints_for_pred.append(probing.DataPoint(
+               name=h.name, location=h.location, type_=typ,
+               data=jnp.where(_expand_to(is_first | force_mask, hint_data),
+                              hint_data, decoded_hints[h.name].data)))
+       else:
+         hints_for_pred = hints
+
+     hiddens = jnp.where(is_first[..., None, None], 0.0, mp_state.hiddens)
+     if self.use_lstm:
+       lstm_state = jax.tree_util.tree_map(
+           lambda x: jnp.where(is_first[..., None, None], 0.0, x),
+           mp_state.lstm_state)
+     else:
+       lstm_state = None
+     hiddens, output_preds, hint_preds, lstm_state = self._one_step_pred(
+         inputs, hints_for_pred, hiddens,
+         batch_size, nb_nodes, lstm_state,
+         spec, encs, decs, repred)
+
+     new_mp_state = MessagePassingStateChunked(  # pytype: disable=wrong-arg-types  # numpy-scalars
+         hiddens=hiddens, lstm_state=lstm_state, hint_preds=hint_preds,
+         inputs=nxt_inputs, hints=nxt_hints, is_first=nxt_is_first)
+     mp_output = _MessagePassingOutputChunked(  # pytype: disable=wrong-arg-types  # numpy-scalars
+         hint_preds=hint_preds,
+         output_preds=output_preds)
+     return new_mp_state, mp_output
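The branch above implements per-sample teacher forcing: during training, a Bernoulli draw per batch element decides whether the ground-truth hint or the model's own postprocessed prediction is fed into the next step, while `repred=True` disables forcing entirely. A minimal standalone sketch of that selection (hypothetical shapes, outside of Haiku):

    import jax
    import jax.numpy as jnp

    # Keep ground truth where the mask is True, else feed back the prediction.
    rng = jax.random.PRNGKey(0)
    batch_size, teacher_forcing_rate = 4, 0.5
    force_mask = jax.random.bernoulli(rng, teacher_forcing_rate, (batch_size,))
    ground_truth = jnp.arange(4.0)   # stand-in for a hint's data
    predicted = -jnp.ones(4)         # stand-in for a decoded hint prediction
    fed_back = jnp.where(force_mask, ground_truth, predicted)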
+
+   def __call__(self, features_list: List[_FeaturesChunked],
+                mp_state_list: List[MessagePassingStateChunked],
+                repred: bool, init_mp_state: bool,
+                algorithm_index: int):
+     """Process one chunk of data.
+
+     Args:
+       features_list: A list of _FeaturesChunked objects, each with the
+         inputs, hints and beginning- and end-of-sample markers for
+         a chunk (i.e., fixed time length) of data corresponding to one
+         algorithm. All features are expected to have dimensions
+         chunk_length x batch_size x ...
+         The list should have either length 1, at train/evaluation time,
+         or length equal to the number of algorithms this Net is meant to
+         process, at initialization.
+       mp_state_list: list of message-passing states. Each message-passing
+         state includes the inputs, hints, beginning-of-sample markers,
+         hint prediction, hidden and lstm state from the end of the previous
+         chunk, for one algorithm. The length of the list should be the same
+         as the length of `features_list`.
+       repred: False during training, when we have access to ground-truth
+         hints. True in validation/test mode, when we have to use our own
+         hint predictions.
+       init_mp_state: Indicates if we are calling the network just to
+         initialise the message-passing state, before the beginning of
+         training or validation. If True, `algorithm_index` (see below) must
+         be -1 in order to initialize the message-passing state of all
+         algorithms.
+       algorithm_index: Which algorithm is being processed. It can be -1 at
+         initialisation (either because we are initialising the parameters of
+         the module or because we are initialising the message-passing state),
+         meaning that all algorithms should be processed, in which case
+         `features_list` and `mp_state_list` should have length equal to the
+         number of specs of the Net. Otherwise, `algorithm_index` should be
+         between 0 and `len(self.spec) - 1`, meaning only one of the
+         algorithms will be processed, and `features_list` and `mp_state_list`
+         should have length 1.
+
+     Returns:
+       A 2-tuple consisting of:
+       - A 2-tuple with (output predictions, hint predictions)
+         for the selected algorithm. Each of these has
+         chunk_length x batch_size x ... data, where the first time
+         slice contains outputs for the mp_state
+         that was passed as input, and the last time slice contains outputs
+         for the next-to-last slice of the input features. The outputs that
+         correspond to the final time slice of the input features will be
+         calculated when the next chunk is processed, using the data in the
+         mp_state returned here (see below). If `init_mp_state` is True,
+         we return None instead of the 2-tuple.
+       - The mp_state (message-passing state) for the next chunk of data
+         of the selected algorithm. If `init_mp_state` is True, we return
+         initial mp states for all the algorithms.
+     """
+     if algorithm_index == -1:
+       algorithm_indices = range(len(features_list))
+     else:
+       algorithm_indices = [algorithm_index]
+       assert not init_mp_state  # init state only allowed with all algorithms
+     assert len(algorithm_indices) == len(features_list)
+     assert len(algorithm_indices) == len(mp_state_list)
+
+     self.encoders, self.decoders = self._construct_encoders_decoders()
+     self.processor = self.processor_factory(self.hidden_dim)
+     # Optionally construct LSTM.
+     if self.use_lstm:
+       self.lstm = hk.LSTM(
+           hidden_size=self.hidden_dim,
+           name='processor_lstm')
+       lstm_init = self.lstm.initial_state
+     else:
+       self.lstm = None
+       lstm_init = lambda x: 0
+
+     if init_mp_state:
+       output_mp_states = []
+       for algorithm_index, features, mp_state in zip(
+           algorithm_indices, features_list, mp_state_list):
+         inputs = features.inputs
+         hints = features.hints
+         batch_size, nb_nodes = _data_dimensions_chunked(features)
+
+         if self.use_lstm:
+           lstm_state = lstm_init(batch_size * nb_nodes)
+           lstm_state = jax.tree_util.tree_map(
+               lambda x, b=batch_size, n=nb_nodes: jnp.reshape(x, [b, n, -1]),
+               lstm_state)
+           mp_state.lstm_state = lstm_state
+         mp_state.inputs = jax.tree_util.tree_map(lambda x: x[0], inputs)
+         mp_state.hints = jax.tree_util.tree_map(lambda x: x[0], hints)
+         mp_state.is_first = jnp.zeros(batch_size, dtype=int)
+         mp_state.hiddens = jnp.zeros((batch_size, nb_nodes, self.hidden_dim))
+         next_is_first = jnp.ones(batch_size, dtype=int)
+
+         mp_state, _ = self._msg_passing_step(
+             mp_state,
+             (mp_state.inputs, mp_state.hints, next_is_first),
+             repred=repred,
+             init_mp_state=True,
+             batch_size=batch_size,
+             nb_nodes=nb_nodes,
+             spec=self.spec[algorithm_index],
+             encs=self.encoders[algorithm_index],
+             decs=self.decoders[algorithm_index],
+         )
+         output_mp_states.append(mp_state)
+       return None, output_mp_states
+
+     for algorithm_index, features, mp_state in zip(
+         algorithm_indices, features_list, mp_state_list):
+       inputs = features.inputs
+       hints = features.hints
+       is_first = features.is_first
+       batch_size, nb_nodes = _data_dimensions_chunked(features)
+
+       scan_fn = functools.partial(
+           self._msg_passing_step,
+           repred=repred,
+           init_mp_state=False,
+           batch_size=batch_size,
+           nb_nodes=nb_nodes,
+           spec=self.spec[algorithm_index],
+           encs=self.encoders[algorithm_index],
+           decs=self.decoders[algorithm_index],
+       )
+
+       mp_state, scan_output = hk.scan(
+           scan_fn,
+           mp_state,
+           (inputs, hints, is_first),
+       )
+
+     # We only return the last algorithm's output and state. That's because
+     # the output only matters when a single algorithm is processed; the case
+     # `algorithm_index == -1` (meaning all algorithms should be processed)
+     # is used only to init parameters.
+     return (scan_output.output_preds, scan_output.hint_preds), mp_state
+
+
+ def _data_dimensions(features: _Features) -> Tuple[int, int]:
+   """Returns (batch_size, nb_nodes)."""
+   for inp in features.inputs:
+     if inp.location in [_Location.NODE, _Location.EDGE]:
+       return inp.data.shape[:2]
+   assert False
+
+
+ def _data_dimensions_chunked(features: _FeaturesChunked) -> Tuple[int, int]:
+   """Returns (batch_size, nb_nodes)."""
+   for inp in features.inputs:
+     if inp.location in [_Location.NODE, _Location.EDGE]:
+       return inp.data.shape[1:3]
+   assert False
+
+
+ def _expand_to(x: _Array, y: _Array) -> _Array:
+   while len(y.shape) > len(x.shape):
+     x = jnp.expand_dims(x, -1)
+   return x
+
+
+ def _is_not_done_broadcast(lengths, i, tensor):
+   is_not_done = (lengths > i + 1) * 1.0
+   while len(is_not_done.shape) < len(tensor.shape):  # pytype: disable=attribute-error  # numpy-scalars
+     is_not_done = jnp.expand_dims(is_not_done, -1)
+   return is_not_done
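Both trailing helpers exist only to align ranks for broadcasting: `_expand_to` appends singleton axes to a per-sample mask until `jnp.where` can broadcast it against a higher-rank tensor, and `_is_not_done_broadcast` does the same for the not-yet-terminated indicator. A quick sketch under assumed shapes:

    import jax.numpy as jnp

    # A (B,)-shaped mask gains trailing singleton axes against a (B, N, H)
    # tensor, giving shape (B, 1, 1) so jnp.where broadcasts per sample.
    mask = jnp.array([True, False])      # (B,) = (2,)
    tensor = jnp.zeros((2, 3, 5))        # (B, N, H)
    expanded = _expand_to(mask, tensor)  # (2, 1, 1)
    assert expanded.shape == (2, 1, 1)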
benchmarks/CLRS/env/probing.py ADDED
@@ -0,0 +1,351 @@
+ # Copyright 2021 DeepMind Technologies Limited. All Rights Reserved.
+ #
+ # Licensed under the Apache License, Version 2.0 (the "License");
+ # you may not use this file except in compliance with the License.
+ # You may obtain a copy of the License at
+ #
+ #     http://www.apache.org/licenses/LICENSE-2.0
+ #
+ # Unless required by applicable law or agreed to in writing, software
+ # distributed under the License is distributed on an "AS IS" BASIS,
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ # See the License for the specific language governing permissions and
+ # limitations under the License.
+ # ==============================================================================
+
+ """Probing utilities.
+
+ The dataflow for an algorithm is represented by `(stage, loc, type, data)`
+ "probes" that are valid under that algorithm's spec (see `specs.py`).
+
+ When constructing probes, it is convenient to represent these fields in a
+ nested format (`ProbesDict`) to facilitate efficient context-based look-up.
+ """
+
+ import functools
+ from typing import Dict, List, Tuple, Union
+
+ import attr
+ from clrs._src import specs
+ import jax
+ import jax.numpy as jnp
+ import numpy as np
+ import tensorflow as tf
+
+
+ _Location = specs.Location
+ _Stage = specs.Stage
+ _Type = specs.Type
+ _OutputClass = specs.OutputClass
+
+ _Array = np.ndarray
+ _Data = Union[_Array, List[_Array]]
+ _DataOrType = Union[_Data, str]
+
+ ProbesDict = Dict[
+     str, Dict[str, Dict[str, Dict[str, _DataOrType]]]]
+
+
+ def _convert_to_str(element):
+   if isinstance(element, tf.Tensor):
+     return element.numpy().decode('utf-8')
+   elif isinstance(element, (np.ndarray, bytes)):
+     return element.decode('utf-8')
+   else:
+     return element
+
+
+ # The first annotation makes this object jax.jit/pmap friendly; the second
+ # one makes it tf.data.Datasets friendly.
+ @jax.tree_util.register_pytree_node_class
+ @attr.define
+ class DataPoint:
+   """Describes a data point."""
+
+   _name: str
+   _location: str
+   _type_: str
+   data: _Array
+
+   @property
+   def name(self):
+     return _convert_to_str(self._name)
+
+   @property
+   def location(self):
+     return _convert_to_str(self._location)
+
+   @property
+   def type_(self):
+     return _convert_to_str(self._type_)
+
+   def __repr__(self):
+     s = f'DataPoint(name="{self.name}",\tlocation={self.location},\t'
+     return s + f'type={self.type_},\tdata=Array{self.data.shape})'
+
+   def tree_flatten(self):
+     data = (self.data,)
+     meta = (self.name, self.location, self.type_)
+     return data, meta
+
+   @classmethod
+   def tree_unflatten(cls, meta, data):
+     name, location, type_ = meta
+     subdata, = data
+     return DataPoint(name, location, type_, subdata)
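Because `DataPoint` is registered as a pytree node with the name/location/type held as static metadata, JAX transformations map over the `data` leaf only. A small usage sketch (the probe name is illustrative):

    # Sketch: tree_map touches only the `data` leaf; the metadata survives
    # the flatten/unflatten round trip unchanged.
    dp = DataPoint(name='pos', location=specs.Location.NODE,
                   type_=specs.Type.SCALAR, data=np.zeros((4,)))
    doubled = jax.tree_util.tree_map(lambda x: x * 2, dp)
    assert doubled.name == 'pos' and doubled.data.shape == (4,)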
+
+
+ class ProbeError(Exception):
+   pass
+
+
+ def initialize(spec: specs.Spec) -> ProbesDict:
+   """Initializes an empty `ProbesDict` corresponding with the provided spec."""
+   probes = dict()
+   for stage in [_Stage.INPUT, _Stage.OUTPUT, _Stage.HINT]:
+     probes[stage] = {}
+     for loc in [_Location.NODE, _Location.EDGE, _Location.GRAPH]:
+       probes[stage][loc] = {}
+
+   for name in spec:
+     stage, loc, t = spec[name]
+     probes[stage][loc][name] = {}
+     probes[stage][loc][name]['data'] = []
+     probes[stage][loc][name]['type_'] = t
+   # Pytype thinks initialize() returns a ProbesDict with a str for all final
+   # values instead of _DataOrType.
+   return probes  # pytype: disable=bad-return-type
+
+
+ def push(probes: ProbesDict, stage: str, next_probe):
+   """Pushes a probe into an existing `ProbesDict`."""
+   for loc in [_Location.NODE, _Location.EDGE, _Location.GRAPH]:
+     for name in probes[stage][loc]:
+       if name not in next_probe:
+         raise ProbeError(f'Missing probe for {name}.')
+       if isinstance(probes[stage][loc][name]['data'], _Array):
+         raise ProbeError('Attempting to push to a finalized `ProbesDict`.')
+       # Pytype thinks initialize() returns a ProbesDict with a str for all
+       # final values instead of _DataOrType.
+       probes[stage][loc][name]['data'].append(next_probe[name])  # pytype: disable=attribute-error
+
+
+ def finalize(probes: ProbesDict):
+   """Finalizes a `ProbesDict` by stacking/squeezing the `data` fields."""
+   for stage in [_Stage.INPUT, _Stage.OUTPUT, _Stage.HINT]:
+     for loc in [_Location.NODE, _Location.EDGE, _Location.GRAPH]:
+       for name in probes[stage][loc]:
+         if isinstance(probes[stage][loc][name]['data'], _Array):
+           raise ProbeError('Attempting to re-finalize a finalized `ProbesDict`.')
+         if stage == _Stage.HINT:
+           # Hints are provided for each timestep. Stack them here.
+           probes[stage][loc][name]['data'] = np.stack(
+               probes[stage][loc][name]['data'])
+         else:
+           # Only one instance of input/output exists. Remove leading axis.
+           probes[stage][loc][name]['data'] = np.squeeze(
+               np.array(probes[stage][loc][name]['data']))
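The intended life cycle is initialize → push (once per timestep) → finalize, after which each `data` field is a stacked array. A hedged sketch with a made-up single-hint spec:

    # Sketch: push two timesteps of a hypothetical node-level hint, then
    # finalize stacks them into one (2, 3) array.
    toy_spec = {'h': (specs.Stage.HINT, specs.Location.NODE, specs.Type.SCALAR)}
    probes = initialize(toy_spec)
    push(probes, specs.Stage.HINT, {'h': np.zeros(3)})
    push(probes, specs.Stage.HINT, {'h': np.ones(3)})
    finalize(probes)
    assert probes[specs.Stage.HINT][specs.Location.NODE]['h']['data'].shape == (2, 3)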
+
+
+ def split_stages(
+     probes: ProbesDict,
+     spec: specs.Spec,
+ ) -> Tuple[List[DataPoint], List[DataPoint], List[DataPoint]]:
+   """Splits contents of `ProbesDict` into `DataPoint`s by stage."""
+
+   inputs = []
+   outputs = []
+   hints = []
+
+   for name in spec:
+     stage, loc, t = spec[name]
+
+     if stage not in probes:
+       raise ProbeError(f'Missing stage {stage}.')
+     if loc not in probes[stage]:
+       raise ProbeError(f'Missing location {loc}.')
+     if name not in probes[stage][loc]:
+       raise ProbeError(f'Missing probe {name}.')
+     if 'type_' not in probes[stage][loc][name]:
+       raise ProbeError(f'Probe {name} missing attribute `type_`.')
+     if 'data' not in probes[stage][loc][name]:
+       raise ProbeError(f'Probe {name} missing attribute `data`.')
+     if t != probes[stage][loc][name]['type_']:
+       raise ProbeError(f'Probe {name} of incorrect type {t}.')
+
+     data = probes[stage][loc][name]['data']
+     if not isinstance(probes[stage][loc][name]['data'], _Array):
+       raise ProbeError((f'Invalid `data` for probe "{name}". ' +
+                         'Did you forget to call `probing.finalize`?'))
+
+     if t in [_Type.MASK, _Type.MASK_ONE, _Type.CATEGORICAL]:
+       # pytype: disable=attribute-error
+       if not ((data == 0) | (data == 1) | (data == -1)).all():
+         raise ProbeError(f'Expected 0|1|-1 `data` for probe "{name}"')
+       # pytype: enable=attribute-error
+     if t in [_Type.MASK_ONE, _Type.CATEGORICAL
+             ] and not np.all(np.sum(np.abs(data), -1) == 1):
+       raise ProbeError(f'Expected one-hot `data` for probe "{name}"')
+
+     dim_to_expand = 1 if stage == _Stage.HINT else 0
+     data_point = DataPoint(name=name, location=loc, type_=t,
+                            data=np.expand_dims(data, dim_to_expand))
+
+     if stage == _Stage.INPUT:
+       inputs.append(data_point)
+     elif stage == _Stage.OUTPUT:
+       outputs.append(data_point)
+     else:
+       hints.append(data_point)
+
+   return inputs, outputs, hints
+
+
+ # pylint: disable=invalid-name
+
+
+ def array(A_pos: np.ndarray) -> np.ndarray:
+   """Constructs an `array` probe."""
+   probe = np.arange(A_pos.shape[0])
+   for i in range(1, A_pos.shape[0]):
+     probe[A_pos[i]] = A_pos[i - 1]
+   return probe
+
+
+ def array_cat(A: np.ndarray, n: int) -> np.ndarray:
+   """Constructs an `array_cat` probe."""
+   assert n > 0
+   probe = np.zeros((A.shape[0], n))
+   for i in range(A.shape[0]):
+     probe[i, A[i]] = 1
+   return probe
+
+
+ def heap(A_pos: np.ndarray, heap_size: int) -> np.ndarray:
+   """Constructs a `heap` probe."""
+   assert heap_size > 0
+   probe = np.arange(A_pos.shape[0])
+   for i in range(1, heap_size):
+     probe[A_pos[i]] = A_pos[(i - 1) // 2]
+   return probe
+
+
+ def graph(A: np.ndarray) -> np.ndarray:
+   """Constructs a `graph` probe."""
+   # Binarise the adjacency matrix and add self-loops.
+   probe = ((A + np.eye(A.shape[0])) != 0) * 1.0
+   return probe
+
+
+ def mask_one(i: int, n: int) -> np.ndarray:
+   """Constructs a `mask_one` probe."""
+   assert n > i
+   probe = np.zeros(n)
+   probe[i] = 1
+   return probe
+
+
+ def strings_id(T_pos: np.ndarray, P_pos: np.ndarray) -> np.ndarray:
+   """Constructs a `strings_id` probe."""
+   probe_T = np.zeros(T_pos.shape[0])
+   probe_P = np.ones(P_pos.shape[0])
+   return np.concatenate([probe_T, probe_P])
+
+
+ def strings_pair(pair_probe: np.ndarray) -> np.ndarray:
+   """Constructs a `strings_pair` probe."""
+   n = pair_probe.shape[0]
+   m = pair_probe.shape[1]
+   probe_ret = np.zeros((n + m, n + m))
+   for i in range(0, n):
+     for j in range(0, m):
+       probe_ret[i, j + n] = pair_probe[i, j]
+   return probe_ret
+
+
+ def strings_pair_cat(pair_probe: np.ndarray, nb_classes: int) -> np.ndarray:
+   """Constructs a `strings_pair_cat` probe."""
+   assert nb_classes > 0
+   n = pair_probe.shape[0]
+   m = pair_probe.shape[1]
+
+   # Add an extra class for 'this cell left blank.'
+   probe_ret = np.zeros((n + m, n + m, nb_classes + 1))
+   for i in range(0, n):
+     for j in range(0, m):
+       probe_ret[i, j + n, int(pair_probe[i, j])] = _OutputClass.POSITIVE
+
+   # Fill the blank cells.
+   for i_1 in range(0, n):
+     for i_2 in range(0, n):
+       probe_ret[i_1, i_2, nb_classes] = _OutputClass.MASKED
+   for j_1 in range(0, m):
+     for x in range(0, n + m):
+       probe_ret[j_1 + n, x, nb_classes] = _OutputClass.MASKED
+   return probe_ret
+
+
+ def strings_pi(T_pos: np.ndarray, P_pos: np.ndarray,
+                pi: np.ndarray) -> np.ndarray:
+   """Constructs a `strings_pi` probe."""
+   probe = np.arange(T_pos.shape[0] + P_pos.shape[0])
+   for j in range(P_pos.shape[0]):
+     probe[T_pos.shape[0] + P_pos[j]] = T_pos.shape[0] + pi[P_pos[j]]
+   return probe
+
+
+ def strings_pos(T_pos: np.ndarray, P_pos: np.ndarray) -> np.ndarray:
+   """Constructs a `strings_pos` probe."""
+   probe_T = np.copy(T_pos) * 1.0 / T_pos.shape[0]
+   probe_P = np.copy(P_pos) * 1.0 / P_pos.shape[0]
+   return np.concatenate([probe_T, probe_P])
+
+
+ def strings_pred(T_pos: np.ndarray, P_pos: np.ndarray) -> np.ndarray:
+   """Constructs a `strings_pred` probe."""
+   probe = np.arange(T_pos.shape[0] + P_pos.shape[0])
+   for i in range(1, T_pos.shape[0]):
+     probe[T_pos[i]] = T_pos[i - 1]
+   for j in range(1, P_pos.shape[0]):
+     probe[T_pos.shape[0] + P_pos[j]] = T_pos.shape[0] + P_pos[j - 1]
+   return probe
+
+
+ @functools.partial(jnp.vectorize, signature='(n)->(n,n),(n)')
+ def predecessor_to_cyclic_predecessor_and_first(
+     pointers: jnp.ndarray) -> Tuple[jnp.ndarray, jnp.ndarray]:
+   """Converts predecessor pointers to cyclic predecessor + first node mask.
+
+   This function assumes that the pointers represent a linear order of the
+   nodes (akin to a linked list), where each node points to its predecessor
+   and the first node points to itself. It returns the same pointers, except
+   that the first node points to the last, and a mask_one marking the first
+   node.
+
+   Example:
+   ```
+   pointers = [2, 1, 1]
+   P = [[0, 0, 1],
+        [1, 0, 0],
+        [0, 1, 0]],
+   M = [0, 1, 0]
+   ```
+
+   Args:
+     pointers: array of shape [N] containing pointers. The pointers are
+       assumed to describe a linear order such that `pointers[i]` is the
+       predecessor of node `i`.
+
+   Returns:
+     Permutation pointers `P` of shape [N, N] and one-hot vector `M` of
+     shape [N].
+   """
+   nb_nodes = pointers.shape[-1]
+   pointers_one_hot = jax.nn.one_hot(pointers, nb_nodes)
+   # Find the index of the last node: it's the node that no other node
+   # points to.
+   last = pointers_one_hot.sum(-2).argmin()
+   # Find the first node: it should be the only one pointing to itself.
+   first = pointers_one_hot.diagonal().argmax()
+   mask = jax.nn.one_hot(first, nb_nodes)
+   pointers_one_hot += mask[..., None] * jax.nn.one_hot(last, nb_nodes)
+   pointers_one_hot -= mask[..., None] * mask
+   return pointers_one_hot, mask
benchmarks/CLRS/env/probing_test.py ADDED
@@ -0,0 +1,192 @@
+ # Copyright 2021 DeepMind Technologies Limited. All Rights Reserved.
+ #
+ # Licensed under the Apache License, Version 2.0 (the "License");
+ # you may not use this file except in compliance with the License.
+ # You may obtain a copy of the License at
+ #
+ #     http://www.apache.org/licenses/LICENSE-2.0
+ #
+ # Unless required by applicable law or agreed to in writing, software
+ # distributed under the License is distributed on an "AS IS" BASIS,
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ # See the License for the specific language governing permissions and
+ # limitations under the License.
+ # ==============================================================================
+
+ """Unit tests for `probing.py`."""
+
+ from absl.testing import absltest
+
+ from clrs._src import probing
+ import jax.numpy as jnp
+ import numpy as np
+
+
+ # pylint: disable=invalid-name
+
+
+ class ProbingTest(absltest.TestCase):
+
+   def test_array(self):
+     A_pos = np.array([1, 2, 0, 4, 3])
+     expected = np.array([2, 1, 1, 4, 0])
+     out = probing.array(A_pos)
+     np.testing.assert_array_equal(expected, out)
+
+   def test_array_cat(self):
+     A = np.array([2, 1, 0, 1, 1])
+     expected = np.array([
+         [0, 0, 1],
+         [0, 1, 0],
+         [1, 0, 0],
+         [0, 1, 0],
+         [0, 1, 0]
+     ])
+     out = probing.array_cat(A, 3)
+     np.testing.assert_array_equal(expected, out)
+
+   def test_heap(self):
+     A_pos = np.array([1, 3, 5, 0, 7, 4, 2, 6])
+     expected = np.array([3, 1, 2, 1, 5, 1, 6, 3])
+     out = probing.heap(A_pos, heap_size=6)
+     np.testing.assert_array_equal(expected, out)
+
+   def test_graph(self):
+     G = np.array([
+         [0.0, 7.0, -1.0, -3.9, 7.452],
+         [0.0, 0.0, 133.0, 0.0, 9.3],
+         [0.5, 0.1, 0.22, 0.55, 0.666],
+         [7.0, 6.1, 0.2, 0.0, 0.0],
+         [0.0, 3.0, 0.0, 1.0, 0.5]
+     ])
+     expected = np.array([
+         [1.0, 1.0, 1.0, 1.0, 1.0],
+         [0.0, 1.0, 1.0, 0.0, 1.0],
+         [1.0, 1.0, 1.0, 1.0, 1.0],
+         [1.0, 1.0, 1.0, 1.0, 0.0],
+         [0.0, 1.0, 0.0, 1.0, 1.0]
+     ])
+     out = probing.graph(G)
+     np.testing.assert_array_equal(expected, out)
+
+   def test_mask_one(self):
+     expected = np.array([0, 0, 0, 1, 0])
+     out = probing.mask_one(3, 5)
+     np.testing.assert_array_equal(expected, out)
+
+   def test_strings_id(self):
+     T_pos = np.array([0, 1, 2, 3, 4])
+     P_pos = np.array([0, 1, 2])
+     expected = np.array([0, 0, 0, 0, 0, 1, 1, 1])
+     out = probing.strings_id(T_pos, P_pos)
+     np.testing.assert_array_equal(expected, out)
+
+   def test_strings_pair(self):
+     pair_probe = np.array([
+         [0.5, 3.1, 9.1, 7.3],
+         [1.0, 0.0, 8.0, 9.3],
+         [0.1, 5.0, 0.0, 1.2]
+     ])
+     expected = np.array([
+         [0.0, 0.0, 0.0, 0.5, 3.1, 9.1, 7.3],
+         [0.0, 0.0, 0.0, 1.0, 0.0, 8.0, 9.3],
+         [0.0, 0.0, 0.0, 0.1, 5.0, 0.0, 1.2],
+         [0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0],
+         [0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0],
+         [0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0],
+         [0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0]
+     ])
+     out = probing.strings_pair(pair_probe)
+     np.testing.assert_equal(expected, out)
+
+   def test_strings_pair_cat(self):
+     pair_probe = np.array([
+         [0, 2, 1],
+         [2, 2, 0]
+     ])
+     expected = np.array([
+         [
+             [0, 0, 0, -1],
+             [0, 0, 0, -1],
+             [1, 0, 0, 0],
+             [0, 0, 1, 0],
+             [0, 1, 0, 0],
+         ],
+         [
+             [0, 0, 0, -1],
+             [0, 0, 0, -1],
+             [0, 0, 1, 0],
+             [0, 0, 1, 0],
+             [1, 0, 0, 0],
+         ],
+         [
+             [0, 0, 0, -1],
+             [0, 0, 0, -1],
+             [0, 0, 0, -1],
+             [0, 0, 0, -1],
+             [0, 0, 0, -1],
+         ],
+         [
+             [0, 0, 0, -1],
+             [0, 0, 0, -1],
+             [0, 0, 0, -1],
+             [0, 0, 0, -1],
+             [0, 0, 0, -1],
+         ],
+         [
+             [0, 0, 0, -1],
+             [0, 0, 0, -1],
+             [0, 0, 0, -1],
+             [0, 0, 0, -1],
+             [0, 0, 0, -1],
+         ],
+     ])
+     out = probing.strings_pair_cat(pair_probe, 3)
+     np.testing.assert_equal(expected, out)
+
+   def test_strings_pi(self):
+     T_pos = np.array([0, 1, 2, 3, 4, 5])
+     P_pos = np.array([0, 1, 2, 3])
+     pi = np.array([3, 1, 0, 2])
+     expected = np.array(
+         [0, 1, 2, 3, 4, 5, 9, 7, 6, 8]
+     )
+     out = probing.strings_pi(T_pos, P_pos, pi)
+     np.testing.assert_array_equal(expected, out)
+
+   def test_strings_pos(self):
+     T_pos = np.array([0, 1, 2, 3, 4])
+     P_pos = np.array([0, 1, 2, 3])
+     expected = np.array(
+         [0.0, 0.2, 0.4, 0.6, 0.8,
+          0.0, 0.25, 0.5, 0.75]
+     )
+     out = probing.strings_pos(T_pos, P_pos)
+     np.testing.assert_array_equal(expected, out)
+
+   def test_strings_pred(self):
+     T_pos = np.array([0, 1, 2, 3, 4])
+     P_pos = np.array([0, 1, 2])
+     expected = np.array([0, 0, 1, 2, 3, 5, 5, 6])
+     out = probing.strings_pred(T_pos, P_pos)
+     np.testing.assert_array_equal(expected, out)
+
+
+ class PermutationsTest(absltest.TestCase):
+
+   def test_pointers_to_permutation(self):
+     pointers = jnp.array([2, 1, 1])
+     perm, first = probing.predecessor_to_cyclic_predecessor_and_first(pointers)
+     np.testing.assert_array_equal(
+         perm, np.array([[0, 0, 1], [1, 0, 0], [0, 1, 0]]))
+     np.testing.assert_array_equal(first, np.array([0, 1, 0]))
+
+   def test_pointers_to_permutation_already_sorted(self):
+     pointers = jnp.array([0, 0, 1, 2, 3, 4])
+     perm, first = probing.predecessor_to_cyclic_predecessor_and_first(pointers)
+     np.testing.assert_array_equal(perm, np.roll(np.eye(6), 1, 0))
+     np.testing.assert_array_equal(first, np.eye(6)[0])
+
+
+ if __name__ == "__main__":
+   absltest.main()
benchmarks/CLRS/env/processors.py ADDED
@@ -0,0 +1,856 @@
+ # Copyright 2021 DeepMind Technologies Limited. All Rights Reserved.
+ #
+ # Licensed under the Apache License, Version 2.0 (the "License");
+ # you may not use this file except in compliance with the License.
+ # You may obtain a copy of the License at
+ #
+ #     http://www.apache.org/licenses/LICENSE-2.0
+ #
+ # Unless required by applicable law or agreed to in writing, software
+ # distributed under the License is distributed on an "AS IS" BASIS,
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ # See the License for the specific language governing permissions and
+ # limitations under the License.
+ # ==============================================================================
+
+ """JAX implementation of baseline processor networks."""
+
+ import abc
+ from typing import Any, Callable, List, Optional, Tuple
+
+ import chex
+ import haiku as hk
+ import jax
+ import jax.numpy as jnp
+ import numpy as np
+
+
+ _Array = chex.Array
+ _Fn = Callable[..., Any]
+ BIG_NUMBER = 1e6
+ PROCESSOR_TAG = 'clrs_processor'
+
+
+ class Processor(hk.Module):
+   """Processor abstract base class."""
+
+   def __init__(self, name: str):
+     if not name.endswith(PROCESSOR_TAG):
+       name = name + '_' + PROCESSOR_TAG
+     super().__init__(name=name)
+
+   @abc.abstractmethod
+   def __call__(
+       self,
+       node_fts: _Array,
+       edge_fts: _Array,
+       graph_fts: _Array,
+       adj_mat: _Array,
+       hidden: _Array,
+       **kwargs,
+   ) -> Tuple[_Array, Optional[_Array]]:
+     """Processor inference step.
+
+     Args:
+       node_fts: Node features.
+       edge_fts: Edge features.
+       graph_fts: Graph features.
+       adj_mat: Graph adjacency matrix.
+       hidden: Hidden features.
+       **kwargs: Extra kwargs.
+
+     Returns:
+       Output of processor inference step as a 2-tuple of (node, edge)
+       embeddings. The edge embeddings can be None.
+     """
+     pass
+
+   @property
+   def inf_bias(self):
+     return False
+
+   @property
+   def inf_bias_edge(self):
+     return False
+
+
+ class GAT(Processor):
+   """Graph Attention Network (Velickovic et al., ICLR 2018)."""
+
+   def __init__(
+       self,
+       out_size: int,
+       nb_heads: int,
+       activation: Optional[_Fn] = jax.nn.relu,
+       residual: bool = True,
+       use_ln: bool = False,
+       name: str = 'gat_aggr',
+   ):
+     super().__init__(name=name)
+     self.out_size = out_size
+     self.nb_heads = nb_heads
+     if out_size % nb_heads != 0:
+       raise ValueError('The number of attention heads must divide the width!')
+     self.head_size = out_size // nb_heads
+     self.activation = activation
+     self.residual = residual
+     self.use_ln = use_ln
+
+   def __call__(  # pytype: disable=signature-mismatch  # numpy-scalars
+       self,
+       node_fts: _Array,
+       edge_fts: _Array,
+       graph_fts: _Array,
+       adj_mat: _Array,
+       hidden: _Array,
+       **unused_kwargs,
+   ) -> _Array:
+     """GAT inference step."""
+
+     b, n, _ = node_fts.shape
+     assert edge_fts.shape[:-1] == (b, n, n)
+     assert graph_fts.shape[:-1] == (b,)
+     assert adj_mat.shape == (b, n, n)
+
+     z = jnp.concatenate([node_fts, hidden], axis=-1)
+     m = hk.Linear(self.out_size)
+     skip = hk.Linear(self.out_size)
+
+     bias_mat = (adj_mat - 1.0) * 1e9
+     bias_mat = jnp.tile(bias_mat[..., None],
+                         (1, 1, 1, self.nb_heads))     # [B, N, N, H]
+     bias_mat = jnp.transpose(bias_mat, (0, 3, 1, 2))  # [B, H, N, N]
+
+     a_1 = hk.Linear(self.nb_heads)
+     a_2 = hk.Linear(self.nb_heads)
+     a_e = hk.Linear(self.nb_heads)
+     a_g = hk.Linear(self.nb_heads)
+
+     values = m(z)  # [B, N, H*F]
+     values = jnp.reshape(
+         values,
+         values.shape[:-1] + (self.nb_heads, self.head_size))  # [B, N, H, F]
+     values = jnp.transpose(values, (0, 2, 1, 3))              # [B, H, N, F]
+
+     att_1 = jnp.expand_dims(a_1(z), axis=-1)
+     att_2 = jnp.expand_dims(a_2(z), axis=-1)
+     att_e = a_e(edge_fts)
+     att_g = jnp.expand_dims(a_g(graph_fts), axis=-1)
+
+     logits = (
+         jnp.transpose(att_1, (0, 2, 1, 3)) +  # + [B, H, N, 1]
+         jnp.transpose(att_2, (0, 2, 3, 1)) +  # + [B, H, 1, N]
+         jnp.transpose(att_e, (0, 3, 1, 2)) +  # + [B, H, N, N]
+         jnp.expand_dims(att_g, axis=-1)       # + [B, H, 1, 1]
+     )                                         # = [B, H, N, N]
+     coefs = jax.nn.softmax(jax.nn.leaky_relu(logits) + bias_mat, axis=-1)
+     ret = jnp.matmul(coefs, values)  # [B, H, N, F]
+     ret = jnp.transpose(ret, (0, 2, 1, 3))  # [B, N, H, F]
+     ret = jnp.reshape(ret, ret.shape[:-2] + (self.out_size,))  # [B, N, H*F]
+
+     if self.residual:
+       ret += skip(z)
+
+     if self.activation is not None:
+       ret = self.activation(ret)
+
+     if self.use_ln:
+       ln = hk.LayerNorm(axis=-1, create_scale=True, create_offset=True)
+       ret = ln(ret)
+
+     return ret, None  # pytype: disable=bad-return-type  # numpy-scalars
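Like every processor in this file, `GAT` is a Haiku module and must be called inside an `hk.transform`-ed function. A minimal smoke-test sketch, with toy dimensions chosen here purely for illustration:

    # Sketch: run one GAT step on dummy features (assumed B=2, N=4, H=16).
    def fn(node_fts, edge_fts, graph_fts, adj_mat, hidden):
      return GAT(out_size=16, nb_heads=2)(
          node_fts, edge_fts, graph_fts, adj_mat, hidden)

    f = hk.transform(fn)
    b, n, h = 2, 4, 16
    args = (jnp.ones((b, n, h)), jnp.ones((b, n, n, h)), jnp.ones((b, h)),
            jnp.ones((b, n, n)), jnp.zeros((b, n, h)))
    params = f.init(jax.random.PRNGKey(0), *args)
    out, _ = f.apply(params, None, *args)  # out has shape (2, 4, 16)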
+
+
+ class GATFull(GAT):
+   """Graph Attention Network with full adjacency matrix."""
+
+   def __call__(self, node_fts: _Array, edge_fts: _Array, graph_fts: _Array,
+                adj_mat: _Array, hidden: _Array, **unused_kwargs) -> _Array:
+     adj_mat = jnp.ones_like(adj_mat)
+     return super().__call__(node_fts, edge_fts, graph_fts, adj_mat, hidden)
+
+
+ class GATv2(Processor):
+   """Graph Attention Network v2 (Brody et al., ICLR 2022)."""
+
+   def __init__(
+       self,
+       out_size: int,
+       nb_heads: int,
+       mid_size: Optional[int] = None,
+       activation: Optional[_Fn] = jax.nn.relu,
+       residual: bool = True,
+       use_ln: bool = False,
+       name: str = 'gatv2_aggr',
+   ):
+     super().__init__(name=name)
+     if mid_size is None:
+       self.mid_size = out_size
+     else:
+       self.mid_size = mid_size
+     self.out_size = out_size
+     self.nb_heads = nb_heads
+     if out_size % nb_heads != 0:
+       raise ValueError('The number of attention heads must divide the width!')
+     self.head_size = out_size // nb_heads
+     if self.mid_size % nb_heads != 0:
+       raise ValueError('The number of attention heads must divide the message!')
+     self.mid_head_size = self.mid_size // nb_heads
+     self.activation = activation
+     self.residual = residual
+     self.use_ln = use_ln
+
+   def __call__(  # pytype: disable=signature-mismatch  # numpy-scalars
+       self,
+       node_fts: _Array,
+       edge_fts: _Array,
+       graph_fts: _Array,
+       adj_mat: _Array,
+       hidden: _Array,
+       **unused_kwargs,
+   ) -> _Array:
+     """GATv2 inference step."""
+
+     b, n, _ = node_fts.shape
+     assert edge_fts.shape[:-1] == (b, n, n)
+     assert graph_fts.shape[:-1] == (b,)
+     assert adj_mat.shape == (b, n, n)
+
+     z = jnp.concatenate([node_fts, hidden], axis=-1)
+     m = hk.Linear(self.out_size)
+     skip = hk.Linear(self.out_size)
+
+     bias_mat = (adj_mat - 1.0) * 1e9
+     bias_mat = jnp.tile(bias_mat[..., None],
+                         (1, 1, 1, self.nb_heads))     # [B, N, N, H]
+     bias_mat = jnp.transpose(bias_mat, (0, 3, 1, 2))  # [B, H, N, N]
+
+     w_1 = hk.Linear(self.mid_size)
+     w_2 = hk.Linear(self.mid_size)
+     w_e = hk.Linear(self.mid_size)
+     w_g = hk.Linear(self.mid_size)
+
+     a_heads = []
+     for _ in range(self.nb_heads):
+       a_heads.append(hk.Linear(1))
+
+     values = m(z)  # [B, N, H*F]
+     values = jnp.reshape(
+         values,
+         values.shape[:-1] + (self.nb_heads, self.head_size))  # [B, N, H, F]
+     values = jnp.transpose(values, (0, 2, 1, 3))              # [B, H, N, F]
+
+     pre_att_1 = w_1(z)
+     pre_att_2 = w_2(z)
+     pre_att_e = w_e(edge_fts)
+     pre_att_g = w_g(graph_fts)
+
+     pre_att = (
+         jnp.expand_dims(pre_att_1, axis=1) +     # + [B, 1, N, H*F]
+         jnp.expand_dims(pre_att_2, axis=2) +     # + [B, N, 1, H*F]
+         pre_att_e +                              # + [B, N, N, H*F]
+         jnp.expand_dims(pre_att_g, axis=(1, 2))  # + [B, 1, 1, H*F]
+     )                                            # = [B, N, N, H*F]
+
+     pre_att = jnp.reshape(
+         pre_att,
+         pre_att.shape[:-1] + (self.nb_heads, self.mid_head_size)
+     )  # [B, N, N, H, F]
+
+     pre_att = jnp.transpose(pre_att, (0, 3, 1, 2, 4))  # [B, H, N, N, F]
+
+     # This part is not very efficient, but we agree to keep it this way to
+     # enhance readability, assuming `nb_heads` will not be large.
+     logit_heads = []
+     for head in range(self.nb_heads):
+       logit_heads.append(
+           jnp.squeeze(
+               a_heads[head](jax.nn.leaky_relu(pre_att[:, head])),
+               axis=-1)
+       )  # [B, N, N]
+
+     logits = jnp.stack(logit_heads, axis=1)  # [B, H, N, N]
+
+     coefs = jax.nn.softmax(logits + bias_mat, axis=-1)
+     ret = jnp.matmul(coefs, values)  # [B, H, N, F]
+     ret = jnp.transpose(ret, (0, 2, 1, 3))  # [B, N, H, F]
+     ret = jnp.reshape(ret, ret.shape[:-2] + (self.out_size,))  # [B, N, H*F]
+
+     if self.residual:
+       ret += skip(z)
+
+     if self.activation is not None:
+       ret = self.activation(ret)
+
+     if self.use_ln:
+       ln = hk.LayerNorm(axis=-1, create_scale=True, create_offset=True)
+       ret = ln(ret)
+
+     return ret, None  # pytype: disable=bad-return-type  # numpy-scalars
+
+
+ class GATv2Full(GATv2):
+   """Graph Attention Network v2 with full adjacency matrix."""
+
+   def __call__(self, node_fts: _Array, edge_fts: _Array, graph_fts: _Array,
+                adj_mat: _Array, hidden: _Array, **unused_kwargs) -> _Array:
+     adj_mat = jnp.ones_like(adj_mat)
+     return super().__call__(node_fts, edge_fts, graph_fts, adj_mat, hidden)
+
+
+ def get_triplet_msgs(z, edge_fts, graph_fts, nb_triplet_fts):
+   """Triplet messages, as done by Dudzik and Velickovic (2022)."""
+   t_1 = hk.Linear(nb_triplet_fts)
+   t_2 = hk.Linear(nb_triplet_fts)
+   t_3 = hk.Linear(nb_triplet_fts)
+   t_e_1 = hk.Linear(nb_triplet_fts)
+   t_e_2 = hk.Linear(nb_triplet_fts)
+   t_e_3 = hk.Linear(nb_triplet_fts)
+   t_g = hk.Linear(nb_triplet_fts)
+
+   tri_1 = t_1(z)
+   tri_2 = t_2(z)
+   tri_3 = t_3(z)
+   tri_e_1 = t_e_1(edge_fts)
+   tri_e_2 = t_e_2(edge_fts)
+   tri_e_3 = t_e_3(edge_fts)
+   tri_g = t_g(graph_fts)
+
+   return (
+       jnp.expand_dims(tri_1, axis=(2, 3)) +     #   (B, N, 1, 1, H)
+       jnp.expand_dims(tri_2, axis=(1, 3)) +     # + (B, 1, N, 1, H)
+       jnp.expand_dims(tri_3, axis=(1, 2)) +     # + (B, 1, 1, N, H)
+       jnp.expand_dims(tri_e_1, axis=3) +        # + (B, N, N, 1, H)
+       jnp.expand_dims(tri_e_2, axis=2) +        # + (B, N, 1, N, H)
+       jnp.expand_dims(tri_e_3, axis=1) +        # + (B, 1, N, N, H)
+       jnp.expand_dims(tri_g, axis=(1, 2, 3))    # + (B, 1, 1, 1, H)
+   )  # = (B, N, N, N, H)
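The sum above relies purely on broadcasting: three node terms, three edge terms and one graph term each occupy a different pair of axes and jointly fill a (B, N, N, N, H) tensor. A quick shape-check sketch under assumed toy sizes:

    # Sketch: with B=1, N=3 and nb_triplet_fts=8, the triplet messages come
    # out as a (1, 3, 3, 3, 8) tensor.
    def fn(z, edge_fts, graph_fts):
      return get_triplet_msgs(z, edge_fts, graph_fts, nb_triplet_fts=8)

    f = hk.transform(fn)
    z, e, g = jnp.ones((1, 3, 5)), jnp.ones((1, 3, 3, 5)), jnp.ones((1, 5))
    params = f.init(jax.random.PRNGKey(0), z, e, g)
    assert f.apply(params, None, z, e, g).shape == (1, 3, 3, 3, 8)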
+
+
+ class PGN(Processor):
+   """Pointer Graph Networks (Veličković et al., NeurIPS 2020)."""
+
+   def __init__(
+       self,
+       out_size: int,
+       mid_size: Optional[int] = None,
+       mid_act: Optional[_Fn] = None,
+       activation: Optional[_Fn] = jax.nn.relu,
+       reduction: _Fn = jnp.max,
+       msgs_mlp_sizes: Optional[List[int]] = None,
+       use_ln: bool = False,
+       use_triplets: bool = False,
+       nb_triplet_fts: int = 8,
+       gated: bool = False,
+       name: str = 'mpnn_aggr',
+   ):
+     super().__init__(name=name)
+     if mid_size is None:
+       self.mid_size = out_size
+     else:
+       self.mid_size = mid_size
+     self.out_size = out_size
+     self.mid_act = mid_act
+     self.activation = activation
+     self.reduction = reduction
+     self._msgs_mlp_sizes = msgs_mlp_sizes
+     self.use_ln = use_ln
+     self.use_triplets = use_triplets
+     self.nb_triplet_fts = nb_triplet_fts
+     self.gated = gated
+
+   def __call__(  # pytype: disable=signature-mismatch  # numpy-scalars
+       self,
+       node_fts: _Array,
+       edge_fts: _Array,
+       graph_fts: _Array,
+       adj_mat: _Array,
+       hidden: _Array,
+       **unused_kwargs,
+   ) -> _Array:
+     """MPNN inference step."""
+
+     b, n, _ = node_fts.shape
+     assert edge_fts.shape[:-1] == (b, n, n)
+     assert graph_fts.shape[:-1] == (b,)
+     assert adj_mat.shape == (b, n, n)
+
+     z = jnp.concatenate([node_fts, hidden], axis=-1)
+     m_1 = hk.Linear(self.mid_size)
+     m_2 = hk.Linear(self.mid_size)
+     m_e = hk.Linear(self.mid_size)
+     m_g = hk.Linear(self.mid_size)
+
+     o1 = hk.Linear(self.out_size)
+     o2 = hk.Linear(self.out_size)
+
+     msg_1 = m_1(z)
+     msg_2 = m_2(z)
+     msg_e = m_e(edge_fts)
+     msg_g = m_g(graph_fts)
+
+     tri_msgs = None
+
+     if self.use_triplets:
+       # Triplet messages, as done by Dudzik and Velickovic (2022).
+       triplets = get_triplet_msgs(z, edge_fts, graph_fts, self.nb_triplet_fts)
+
+       o3 = hk.Linear(self.out_size)
+       tri_msgs = o3(jnp.max(triplets, axis=1))  # (B, N, N, H)
+
+       if self.activation is not None:
+         tri_msgs = self.activation(tri_msgs)
+
+     msgs = (
+         jnp.expand_dims(msg_1, axis=1) + jnp.expand_dims(msg_2, axis=2) +
+         msg_e + jnp.expand_dims(msg_g, axis=(1, 2)))
+
+     if self._msgs_mlp_sizes is not None:
+       msgs = hk.nets.MLP(self._msgs_mlp_sizes)(jax.nn.relu(msgs))
+
+     if self.mid_act is not None:
+       msgs = self.mid_act(msgs)
+
+     if self.reduction == jnp.mean:
+       msgs = jnp.sum(msgs * jnp.expand_dims(adj_mat, -1), axis=1)
+       msgs = msgs / jnp.sum(adj_mat, axis=-1, keepdims=True)
+     elif self.reduction == jnp.max:
+       maxarg = jnp.where(jnp.expand_dims(adj_mat, -1),
+                          msgs,
+                          -BIG_NUMBER)
+       msgs = jnp.max(maxarg, axis=1)
+     else:
+       msgs = self.reduction(msgs * jnp.expand_dims(adj_mat, -1), axis=1)
+
+     h_1 = o1(z)
+     h_2 = o2(msgs)
+
+     ret = h_1 + h_2
+
+     if self.activation is not None:
+       ret = self.activation(ret)
+
+     if self.use_ln:
+       ln = hk.LayerNorm(axis=-1, create_scale=True, create_offset=True)
+       ret = ln(ret)
+
+     if self.gated:
+       gate1 = hk.Linear(self.out_size)
+       gate2 = hk.Linear(self.out_size)
+       gate3 = hk.Linear(self.out_size, b_init=hk.initializers.Constant(-3))
+       gate = jax.nn.sigmoid(gate3(jax.nn.relu(gate1(z) + gate2(msgs))))
+       ret = ret * gate + hidden * (1 - gate)
+
+     return ret, tri_msgs  # pytype: disable=bad-return-type  # numpy-scalars
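Note how the two built-in reductions mask non-edges differently: mean multiplies messages by the adjacency mask and renormalises, while max pushes masked entries to -BIG_NUMBER before reducing so they can never win. A sketch of the max branch in isolation (toy shapes assumed):

    # Sketch: messages on absent edges are replaced by -BIG_NUMBER, then the
    # max is taken over the sender axis, exactly as in the branch above.
    adj = jnp.array([[[1.0, 0.0], [1.0, 1.0]]])      # (B=1, N=2, N=2)
    msgs = jnp.arange(1.0, 5.0).reshape(1, 2, 2, 1)  # (B, N, N, H=1)
    masked = jnp.where(jnp.expand_dims(adj, -1), msgs, -BIG_NUMBER)
    reduced = jnp.max(masked, axis=1)                # (1, 2, 1)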
+
+
+ class DeepSets(PGN):
+   """Deep Sets (Zaheer et al., NeurIPS 2017)."""
+
+   def __call__(self, node_fts: _Array, edge_fts: _Array, graph_fts: _Array,
+                adj_mat: _Array, hidden: _Array, **unused_kwargs) -> _Array:
+     assert adj_mat.ndim == 3
+     adj_mat = jnp.ones_like(adj_mat) * jnp.eye(adj_mat.shape[-1])
+     return super().__call__(node_fts, edge_fts, graph_fts, adj_mat, hidden)
+
+
+ class MPNN(PGN):
+   """Message-Passing Neural Network (Gilmer et al., ICML 2017)."""
+
+   def __call__(self, node_fts: _Array, edge_fts: _Array, graph_fts: _Array,
+                adj_mat: _Array, hidden: _Array, **unused_kwargs) -> _Array:
+     adj_mat = jnp.ones_like(adj_mat)
+     return super().__call__(node_fts, edge_fts, graph_fts, adj_mat, hidden)
+
+
+ class PGNMask(PGN):
+   """Masked Pointer Graph Networks (Veličković et al., NeurIPS 2020)."""
+
+   @property
+   def inf_bias(self):
+     return True
+
+   @property
+   def inf_bias_edge(self):
+     return True
+
+
+ class MemNetMasked(Processor):
+   """Implementation of End-to-End Memory Networks.
+
+   Inspired by the description in https://arxiv.org/abs/1503.08895.
+   """
+
+   def __init__(
+       self,
+       vocab_size: int,
+       sentence_size: int,
+       linear_output_size: int,
+       embedding_size: int = 16,
+       memory_size: Optional[int] = 128,
+       num_hops: int = 1,
+       nonlin: Callable[[Any], Any] = jax.nn.relu,
+       apply_embeddings: bool = True,
+       init_func: hk.initializers.Initializer = jnp.zeros,
+       use_ln: bool = False,
+       name: str = 'memnet') -> None:
+     """Constructor.
+
+     Args:
+       vocab_size: the number of words in the dictionary (each story, query
+         and answer contains symbols from this dictionary).
+       sentence_size: the dimensionality of each memory.
+       linear_output_size: the dimensionality of the output of the last layer
+         of the model.
+       embedding_size: the dimensionality of the latent space to where all
+         memories are projected.
+       memory_size: the number of memories provided.
+       num_hops: the number of layers in the model.
+       nonlin: non-linear transformation applied at the end of each layer.
+       apply_embeddings: flag whether to apply embeddings.
+       init_func: initialization function for the biases.
+       use_ln: whether to use layer normalisation in the model.
+       name: the name of the model.
+     """
+     super().__init__(name=name)
+     self._vocab_size = vocab_size
+     self._embedding_size = embedding_size
+     self._sentence_size = sentence_size
+     self._memory_size = memory_size
+     self._linear_output_size = linear_output_size
+     self._num_hops = num_hops
+     self._nonlin = nonlin
+     self._apply_embeddings = apply_embeddings
+     self._init_func = init_func
+     self._use_ln = use_ln
+     # Encoding part: i.e. "I" of the paper.
+     self._encodings = _position_encoding(sentence_size, embedding_size)
+
+   def __call__(  # pytype: disable=signature-mismatch  # numpy-scalars
+       self,
+       node_fts: _Array,
+       edge_fts: _Array,
+       graph_fts: _Array,
+       adj_mat: _Array,
+       hidden: _Array,
+       **unused_kwargs,
+   ) -> _Array:
+     """MemNet inference step."""
+
+     del hidden
+     node_and_graph_fts = jnp.concatenate([node_fts, graph_fts[:, None]],
+                                          axis=1)
+     edge_fts_padded = jnp.pad(edge_fts * adj_mat[..., None],
+                               ((0, 0), (0, 1), (0, 1), (0, 0)))
+     nxt_hidden = jax.vmap(self._apply, (1), 1)(node_and_graph_fts,
+                                                edge_fts_padded)
+
+     # Broadcast hidden state corresponding to graph features across the nodes.
+     nxt_hidden = nxt_hidden[:, :-1] + nxt_hidden[:, -1:]
+     return nxt_hidden, None  # pytype: disable=bad-return-type  # numpy-scalars
+
+   def _apply(self, queries: _Array, stories: _Array) -> _Array:
+     """Apply Memory Network to the queries and stories.
+
+     Args:
+       queries: Tensor of shape [batch_size, sentence_size].
+       stories: Tensor of shape [batch_size, memory_size, sentence_size].
+
+     Returns:
+       Tensor of shape [batch_size, vocab_size].
+     """
+     if self._apply_embeddings:
+       query_biases = hk.get_parameter(
+           'query_biases',
+           shape=[self._vocab_size - 1, self._embedding_size],
+           init=self._init_func)
+       stories_biases = hk.get_parameter(
+           'stories_biases',
+           shape=[self._vocab_size - 1, self._embedding_size],
+           init=self._init_func)
+       memory_biases = hk.get_parameter(
+           'memory_contents',
+           shape=[self._memory_size, self._embedding_size],
+           init=self._init_func)
+       output_biases = hk.get_parameter(
+           'output_biases',
+           shape=[self._vocab_size - 1, self._embedding_size],
+           init=self._init_func)
+
+       nil_word_slot = jnp.zeros([1, self._embedding_size])
+
+     # This is "A" in the paper.
+     if self._apply_embeddings:
+       stories_biases = jnp.concatenate([stories_biases, nil_word_slot],
+                                        axis=0)
+       memory_embeddings = jnp.take(
+           stories_biases, stories.reshape([-1]).astype(jnp.int32),
+           axis=0).reshape(list(stories.shape) + [self._embedding_size])
+       memory_embeddings = jnp.pad(
+           memory_embeddings,
+           ((0, 0), (0, self._memory_size - jnp.shape(memory_embeddings)[1]),
+            (0, 0), (0, 0)))
+       memory = jnp.sum(memory_embeddings * self._encodings, 2) + memory_biases
+     else:
+       memory = stories
+
+     # This is "B" in the paper. When there are no queries (only sentences),
+     # these lines are substituted by query_embeddings = 0.1.
+     if self._apply_embeddings:
+       query_biases = jnp.concatenate([query_biases, nil_word_slot], axis=0)
+       query_embeddings = jnp.take(
+           query_biases, queries.reshape([-1]).astype(jnp.int32),
+           axis=0).reshape(list(queries.shape) + [self._embedding_size])
+       # This is "u" in the paper.
+       query_input_embedding = jnp.sum(query_embeddings * self._encodings, 1)
+     else:
+       query_input_embedding = queries
+
+     # This is "C" in the paper.
+     if self._apply_embeddings:
+       output_biases = jnp.concatenate([output_biases, nil_word_slot], axis=0)
+       output_embeddings = jnp.take(
+           output_biases, stories.reshape([-1]).astype(jnp.int32),
+           axis=0).reshape(list(stories.shape) + [self._embedding_size])
+       output_embeddings = jnp.pad(
+           output_embeddings,
+           ((0, 0), (0, self._memory_size - jnp.shape(output_embeddings)[1]),
+            (0, 0), (0, 0)))
+       output = jnp.sum(output_embeddings * self._encodings, 2)
+     else:
+       output = stories
+
+     intermediate_linear = hk.Linear(self._embedding_size, with_bias=False)
+
+     # Output_linear is "H".
+     output_linear = hk.Linear(self._linear_output_size, with_bias=False)
+
+     for hop_number in range(self._num_hops):
+       query_input_embedding_transposed = jnp.transpose(
+           jnp.expand_dims(query_input_embedding, -1), [0, 2, 1])
+
+       # Calculate probabilities.
+       probs = jax.nn.softmax(
+           jnp.sum(memory * query_input_embedding_transposed, 2))
+
+       # Calculate output of the layer by multiplying by C.
+       transposed_probs = jnp.transpose(jnp.expand_dims(probs, -1), [0, 2, 1])
+       transposed_output_embeddings = jnp.transpose(output, [0, 2, 1])
+
+       # This is "o" in the paper.
+       layer_output = jnp.sum(transposed_output_embeddings * transposed_probs,
+                              2)
+
+       # Finally, the answer.
+       if hop_number == self._num_hops - 1:
+         # Please note that in the TF version we apply the final linear layer
+         # in all hops and this results in shape mismatches.
+         output_layer = output_linear(query_input_embedding + layer_output)
+       else:
+         output_layer = intermediate_linear(query_input_embedding +
+                                            layer_output)
+
+       query_input_embedding = output_layer
+       if self._nonlin:
+         output_layer = self._nonlin(output_layer)
+
+     # This linear here is "W".
+     ret = hk.Linear(self._vocab_size, with_bias=False)(output_layer)
+
+     if self._use_ln:
+       ln = hk.LayerNorm(axis=-1, create_scale=True, create_offset=True)
+       ret = ln(ret)
+
+     return ret
+
+
+ class MemNetFull(MemNetMasked):
+   """Memory Networks with full adjacency matrix."""
+
+   def __call__(self, node_fts: _Array, edge_fts: _Array, graph_fts: _Array,
+                adj_mat: _Array, hidden: _Array, **unused_kwargs) -> _Array:
+     adj_mat = jnp.ones_like(adj_mat)
+     return super().__call__(node_fts, edge_fts, graph_fts, adj_mat, hidden)
+
+
+ ProcessorFactory = Callable[[int], Processor]
+
+
+ def get_processor_factory(kind: str,
+                           use_ln: bool,
+                           nb_triplet_fts: int,
+                           nb_heads: Optional[int] = None) -> ProcessorFactory:
+   """Returns a processor factory.
+
+   Args:
+     kind: One of the available types of processor.
+     use_ln: Whether the processor passes the output through a layernorm
+       layer.
+     nb_triplet_fts: How many triplet features to compute.
+     nb_heads: Number of attention heads for GAT processors.
+
+   Returns:
+     A callable that takes an `out_size` parameter (equal to the hidden
+     dimension of the network) and returns a processor instance.
+   """
+   def _factory(out_size: int):
+     if kind == 'deepsets':
+       processor = DeepSets(
+           out_size=out_size,
+           msgs_mlp_sizes=[out_size, out_size],
+           use_ln=use_ln,
+           use_triplets=False,
+           nb_triplet_fts=0
+       )
+     elif kind == 'gat':
+       processor = GAT(
+           out_size=out_size,
+           nb_heads=nb_heads,
+           use_ln=use_ln,
+       )
+     elif kind == 'gat_full':
+       processor = GATFull(
+           out_size=out_size,
+           nb_heads=nb_heads,
+           use_ln=use_ln
+       )
+     elif kind == 'gatv2':
+       processor = GATv2(
+           out_size=out_size,
+           nb_heads=nb_heads,
+           use_ln=use_ln
+       )
+     elif kind == 'gatv2_full':
+       processor = GATv2Full(
+           out_size=out_size,
+           nb_heads=nb_heads,
+           use_ln=use_ln
+       )
+     elif kind == 'memnet_full':
+       processor = MemNetFull(
+           vocab_size=out_size,
+           sentence_size=out_size,
+           linear_output_size=out_size,
+       )
+     elif kind == 'memnet_masked':
+       processor = MemNetMasked(
+           vocab_size=out_size,
+           sentence_size=out_size,
+           linear_output_size=out_size,
+       )
+     elif kind == 'mpnn':
+       processor = MPNN(
+           out_size=out_size,
+           msgs_mlp_sizes=[out_size, out_size],
+           use_ln=use_ln,
+           use_triplets=False,
+           nb_triplet_fts=0,
+       )
+     elif kind == 'pgn':
+       processor = PGN(
+           out_size=out_size,
+           msgs_mlp_sizes=[out_size, out_size],
+           use_ln=use_ln,
+           use_triplets=False,
+           nb_triplet_fts=0,
+       )
+     elif kind == 'pgn_mask':
+       processor = PGNMask(
+           out_size=out_size,
+           msgs_mlp_sizes=[out_size, out_size],
+           use_ln=use_ln,
+           use_triplets=False,
+           nb_triplet_fts=0,
+       )
+     elif kind == 'triplet_mpnn':
+       processor = MPNN(
+           out_size=out_size,
+           msgs_mlp_sizes=[out_size, out_size],
+           use_ln=use_ln,
+           use_triplets=True,
+           nb_triplet_fts=nb_triplet_fts,
+       )
+     elif kind == 'triplet_pgn':
+       processor = PGN(
+           out_size=out_size,
+           msgs_mlp_sizes=[out_size, out_size],
+           use_ln=use_ln,
+           use_triplets=True,
+           nb_triplet_fts=nb_triplet_fts,
+       )
+     elif kind == 'triplet_pgn_mask':
+       processor = PGNMask(
+           out_size=out_size,
+           msgs_mlp_sizes=[out_size, out_size],
+           use_ln=use_ln,
+           use_triplets=True,
+           nb_triplet_fts=nb_triplet_fts,
+       )
+     elif kind == 'gpgn':
+       processor = PGN(
+           out_size=out_size,
+           msgs_mlp_sizes=[out_size, out_size],
+           use_ln=use_ln,
+           use_triplets=False,
+           nb_triplet_fts=nb_triplet_fts,
+           gated=True,
+       )
+     elif kind == 'gpgn_mask':
+       processor = PGNMask(
+           out_size=out_size,
+           msgs_mlp_sizes=[out_size, out_size],
+           use_ln=use_ln,
+           use_triplets=False,
+           nb_triplet_fts=nb_triplet_fts,
+           gated=True,
+       )
+     elif kind == 'gmpnn':
+       processor = MPNN(
+           out_size=out_size,
+           msgs_mlp_sizes=[out_size, out_size],
+           use_ln=use_ln,
+           use_triplets=False,
+           nb_triplet_fts=nb_triplet_fts,
+           gated=True,
+       )
+     elif kind == 'triplet_gpgn':
+       processor = PGN(
+           out_size=out_size,
+           msgs_mlp_sizes=[out_size, out_size],
+           use_ln=use_ln,
+           use_triplets=True,
+           nb_triplet_fts=nb_triplet_fts,
+           gated=True,
+       )
+     elif kind == 'triplet_gpgn_mask':
+       processor = PGNMask(
+           out_size=out_size,
+           msgs_mlp_sizes=[out_size, out_size],
+           use_ln=use_ln,
+           use_triplets=True,
+           nb_triplet_fts=nb_triplet_fts,
+           gated=True,
+       )
+     elif kind == 'triplet_gmpnn':
+       processor = MPNN(
+           out_size=out_size,
+           msgs_mlp_sizes=[out_size, out_size],
+           use_ln=use_ln,
+           use_triplets=True,
+           nb_triplet_fts=nb_triplet_fts,
+           gated=True,
+       )
+     else:
+       raise ValueError('Unexpected processor kind ' + kind)
+
+     return processor
+
+   return _factory
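A factory is used by instantiating it with the network's hidden size at model-construction time. A minimal usage sketch (the kind and sizes are illustrative; the processor itself must be built inside an `hk.transform`-ed function):

    # Sketch: build a gated triplet-MPNN processor with a 128-wide hidden state.
    factory = get_processor_factory('triplet_gmpnn', use_ln=True,
                                    nb_triplet_fts=8)
    processor = factory(128)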
845
+
+
+ def _position_encoding(sentence_size: int, embedding_size: int) -> np.ndarray:
+   """Position Encoding described in section 4.1 [1]."""
+   encoding = np.ones((embedding_size, sentence_size), dtype=np.float32)
+   ls = sentence_size + 1
+   le = embedding_size + 1
+   for i in range(1, le):
+     for j in range(1, ls):
+       encoding[i - 1, j - 1] = (i - (le - 1) / 2) * (j - (ls - 1) / 2)
+   encoding = 1 + 4 * encoding / embedding_size / sentence_size
+   return np.transpose(encoding)
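A minimal shape check, following directly from the function above (the array is built as `(embedding_size, sentence_size)` and then transposed; illustrative, not part of the committed file):

enc = _position_encoding(sentence_size=11, embedding_size=64)
assert enc.shape == (11, 64)  # callers receive [sentence_size, embedding_size]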
benchmarks/CLRS/env/processors_test.py ADDED
@@ -0,0 +1,64 @@
+ # Copyright 2021 DeepMind Technologies Limited. All Rights Reserved.
+ #
+ # Licensed under the Apache License, Version 2.0 (the "License");
+ # you may not use this file except in compliance with the License.
+ # You may obtain a copy of the License at
+ #
+ # http://www.apache.org/licenses/LICENSE-2.0
+ #
+ # Unless required by applicable law or agreed to in writing, software
+ # distributed under the License is distributed on an "AS IS" BASIS,
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ # See the License for the specific language governing permissions and
+ # limitations under the License.
+ # ==============================================================================
+
+ """Tests for processors.py."""
+
+ from absl.testing import absltest
+ import chex
+ from clrs._src import processors
+ import haiku as hk
+ import jax.numpy as jnp
+
+
+ class MemnetTest(absltest.TestCase):
+
+   def test_simple_run_and_check_shapes(self):
+
+     batch_size = 64
+     vocab_size = 177
+     embedding_size = 64
+     sentence_size = 11
+     memory_size = 320
+     linear_output_size = 128
+     num_hops = 2
+     use_ln = True
+
+     def forward_fn(queries, stories):
+       model = processors.MemNetFull(
+           vocab_size=vocab_size,
+           embedding_size=embedding_size,
+           sentence_size=sentence_size,
+           memory_size=memory_size,
+           linear_output_size=linear_output_size,
+           num_hops=num_hops,
+           use_ln=use_ln)
+       return model._apply(queries, stories)
+
+     forward = hk.transform(forward_fn)
+
+     queries = jnp.ones([batch_size, sentence_size], dtype=jnp.int32)
+     stories = jnp.ones([batch_size, memory_size, sentence_size],
+                        dtype=jnp.int32)
+
+     key = hk.PRNGSequence(42)
+     params = forward.init(next(key), queries, stories)
+
+     model_output = forward.apply(params, None, queries, stories)
+     chex.assert_shape(model_output, [batch_size, vocab_size])
+     chex.assert_type(model_output, jnp.float32)
+
+
+ if __name__ == '__main__':
+   absltest.main()
benchmarks/CLRS/env/samplers.py ADDED
@@ -0,0 +1,882 @@
+ # Copyright 2021 DeepMind Technologies Limited. All Rights Reserved.
+ #
+ # Licensed under the Apache License, Version 2.0 (the "License");
+ # you may not use this file except in compliance with the License.
+ # You may obtain a copy of the License at
+ #
+ # http://www.apache.org/licenses/LICENSE-2.0
+ #
+ # Unless required by applicable law or agreed to in writing, software
+ # distributed under the License is distributed on an "AS IS" BASIS,
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ # See the License for the specific language governing permissions and
+ # limitations under the License.
+ # ==============================================================================
+
+ """Sampling utilities."""
+
+ import abc
+ import collections
+ import inspect
+ import types
+
+ from typing import Any, Callable, List, Optional, Tuple
+ from absl import logging
+
+ from clrs._src import algorithms
+ from clrs._src import probing
+ from clrs._src import specs
+ import jax
+ import numpy as np
+
+
+ _Array = np.ndarray
+ _DataPoint = probing.DataPoint
+ Trajectory = List[_DataPoint]
+ Trajectories = List[Trajectory]
+
+
+ Algorithm = Callable[..., Any]
+ Features = collections.namedtuple('Features', ['inputs', 'hints', 'lengths'])
+ FeaturesChunked = collections.namedtuple(
+     'Features', ['inputs', 'hints', 'is_first', 'is_last'])
+ Feedback = collections.namedtuple('Feedback', ['features', 'outputs'])
+
+ # CLRS-30 baseline spec.
+ CLRS30 = types.MappingProxyType({
+     'train': {
+         'num_samples': 1000,
+         'length': 16,
+         'seed': 1,
+     },
+     'val': {
+         'num_samples': 32,
+         'length': 16,
+         'seed': 2,
+     },
+     'test': {
+         'num_samples': 32,
+         'length': 64,
+         'seed': 3,
+     },
+ })
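A minimal sketch of how these split settings can drive sampler construction, using the `build_sampler` helper defined later in this file (the algorithm name 'bfs' is an arbitrary choice for illustration):

for split, cfg in CLRS30.items():
  sampler, spec = build_sampler('bfs', num_samples=cfg['num_samples'],
                                length=cfg['length'], seed=cfg['seed'])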
+
+
+ class Sampler(abc.ABC):
+   """Sampler abstract base class."""
+
+   def __init__(
+       self,
+       algorithm: Algorithm,
+       spec: specs.Spec,
+       num_samples: int,
+       *args,
+       seed: Optional[int] = None,
+       **kwargs,
+   ):
+     """Initializes a `Sampler`.
+
+     Args:
+       algorithm: The algorithm to sample from.
+       spec: The algorithm spec.
+       num_samples: Number of algorithm unrolls to sample. If positive, all the
+         samples will be generated in the constructor, and at each call
+         of the `next` method a batch will be randomly selected among them.
+         If -1, samples are generated on the fly with each call to `next`.
+       *args: Algorithm args.
+       seed: RNG seed.
+       **kwargs: Algorithm kwargs.
+     """
+
+     # Use `RandomState` to ensure deterministic sampling across Numpy versions.
+     self._rng = np.random.RandomState(seed)
+     self._spec = spec
+     self._num_samples = num_samples
+     self._algorithm = algorithm
+     self._args = args
+     self._kwargs = kwargs
+
+     if num_samples < 0:
+       logging.warning('Sampling dataset on-the-fly, unlimited samples.')
+       # Just get an initial estimate of max hint length
+       self.max_steps = -1
+       for _ in range(1000):
+         data = self._sample_data(*args, **kwargs)
+         _, probes = algorithm(*data)
+         _, _, hint = probing.split_stages(probes, spec)
+         for dp in hint:
+           assert dp.data.shape[1] == 1  # batching axis
+           if dp.data.shape[0] > self.max_steps:
+             self.max_steps = dp.data.shape[0]
+     else:
+       logging.info('Creating a dataset with %i samples.', num_samples)
+       (self._inputs, self._outputs, self._hints,
+        self._lengths) = self._make_batch(num_samples, spec, 0, algorithm,
+                                          *args, **kwargs)
+
+   def _make_batch(self, num_samples: int, spec: specs.Spec, min_length: int,
+                   algorithm: Algorithm, *args, **kwargs):
+     """Generate a batch of data."""
+     inputs = []
+     outputs = []
+     hints = []
+
+     for _ in range(num_samples):
+       data = self._sample_data(*args, **kwargs)
+       _, probes = algorithm(*data)
+       inp, outp, hint = probing.split_stages(probes, spec)
+       inputs.append(inp)
+       outputs.append(outp)
+       hints.append(hint)
+       if len(hints) % 1000 == 0:
+         logging.info('%i samples created', len(hints))
+
+     # Batch and pad trajectories to max(T).
+     inputs = _batch_io(inputs)
+     outputs = _batch_io(outputs)
+     hints, lengths = _batch_hints(hints, min_length)
+     return inputs, outputs, hints, lengths
+
+   def next(self, batch_size: Optional[int] = None) -> Feedback:
+     """Subsamples trajectories from the pre-generated dataset.
+
+     Args:
+       batch_size: Optional batch size. If `None`, returns entire dataset.
+
+     Returns:
+       Subsampled trajectories.
+     """
+     if batch_size:
+       if self._num_samples < 0:  # generate on the fly
+         inputs, outputs, hints, lengths = self._make_batch(
+             batch_size, self._spec, self.max_steps,
+             self._algorithm, *self._args, **self._kwargs)
+         if hints[0].data.shape[0] > self.max_steps:
+           logging.warning('Increasing hint length from %i to %i',
+                           self.max_steps, hints[0].data.shape[0])
+           self.max_steps = hints[0].data.shape[0]
+       else:
+         if batch_size > self._num_samples:
+           raise ValueError(
+               f'Batch size {batch_size} > dataset size {self._num_samples}.')
+
+         # Returns a fixed-size random batch.
+         indices = self._rng.choice(self._num_samples, (batch_size,),
+                                    replace=True)
+         inputs = _subsample_data(self._inputs, indices, axis=0)
+         outputs = _subsample_data(self._outputs, indices, axis=0)
+         hints = _subsample_data(self._hints, indices, axis=1)
+         lengths = self._lengths[indices]
+
+     else:
+       # Returns the full dataset.
+       assert self._num_samples >= 0
+       inputs = self._inputs
+       hints = self._hints
+       lengths = self._lengths
+       outputs = self._outputs
+
+     return Feedback(Features(inputs, hints, lengths), outputs)
+
+   @abc.abstractmethod
+   def _sample_data(self, length: int, *args, **kwargs) -> List[_Array]:
+     pass
+
+   def _random_sequence(self, length, low=0.0, high=1.0):
+     """Random sequence."""
+     return self._rng.uniform(low=low, high=high, size=(length,))
+
+   def _random_string(self, length, chars=4):
+     """Random string."""
+     return self._rng.randint(0, high=chars, size=(length,))
+
+   def _random_er_graph(self, nb_nodes, p=0.5, directed=False, acyclic=False,
+                        weighted=False, low=0.0, high=1.0):
+     """Random Erdos-Renyi graph."""
+
+     mat = self._rng.binomial(1, p, size=(nb_nodes, nb_nodes))
+     if not directed:
+       mat *= np.transpose(mat)
+     elif acyclic:
+       mat = np.triu(mat, k=1)
+       p = self._rng.permutation(nb_nodes)  # To allow nontrivial solutions
+       mat = mat[p, :][:, p]
+     if weighted:
+       weights = self._rng.uniform(low=low, high=high,
+                                   size=(nb_nodes, nb_nodes))
+       if not directed:
+         weights *= np.transpose(weights)
+         weights = np.sqrt(weights + 1e-3)  # Add epsilon to protect underflow
+       mat = mat.astype(float) * weights
+     return mat
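A minimal property check mirroring the undirected branch above: the elementwise product with the transpose forces a symmetric adjacency matrix (illustrative; the seed is arbitrary):

rng = np.random.RandomState(0)
mat = rng.binomial(1, 0.5, size=(8, 8))
mat *= np.transpose(mat)  # the undirected symmetrization used above
assert (mat == mat.T).all()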
+
+   def _random_community_graph(self, nb_nodes, k=4, p=0.5, eps=0.01,
+                               directed=False, acyclic=False, weighted=False,
+                               low=0.0, high=1.0):
+     """Random perturbed k-community graph."""
+     mat = np.zeros((nb_nodes, nb_nodes))
+     if k > nb_nodes:
+       raise ValueError(f'Cannot generate graph of too many ({k}) communities.')
+     los, his = [], []
+     lo = 0
+     for i in range(k):
+       if i == k - 1:
+         hi = nb_nodes
+       else:
+         hi = lo + nb_nodes // k
+       mat[lo:hi, lo:hi] = self._random_er_graph(
+           hi - lo, p=p, directed=directed,
+           acyclic=acyclic, weighted=weighted,
+           low=low, high=high)
+       los.append(lo)
+       his.append(hi)
+       lo = hi
+     toggle = self._random_er_graph(nb_nodes, p=eps, directed=directed,
+                                    acyclic=acyclic, weighted=weighted,
+                                    low=low, high=high)
+
+     # Prohibit closing new cycles
+     for i in range(k):
+       for j in range(i):
+         toggle[los[i]:his[i], los[j]:his[j]] *= 0
+
+     mat = np.where(toggle > 0.0, (1.0 - (mat > 0.0)) * toggle, mat)
+     p = self._rng.permutation(nb_nodes)  # To allow nontrivial solutions
+     mat = mat[p, :][:, p]
+     return mat
+
+   def _random_bipartite_graph(self, n, m, p=0.25):
+     """Random bipartite graph-based flow network."""
+     nb_nodes = n + m + 2
+     s = 0
+     t = n + m + 1
+     mat = np.zeros((nb_nodes, nb_nodes))
+     mat[s, 1:n+1] = 1.0  # supersource
+     mat[n+1:n+m+1, t] = 1.0  # supersink
+     mat[1:n+1, n+1:n+m+1] = self._rng.binomial(1, p, size=(n, m))
+     return mat
+
+
+ def build_sampler(
+     name: str,
+     num_samples: int,
+     *args,
+     seed: Optional[int] = None,
+     **kwargs,
+ ) -> Tuple[Sampler, specs.Spec]:
+   """Builds a sampler. See `Sampler` documentation."""
+
+   if name not in specs.SPECS or name not in SAMPLERS:
+     raise NotImplementedError(f'No implementation of algorithm {name}.')
+   spec = specs.SPECS[name]
+   algorithm = getattr(algorithms, name)
+   sampler_class = SAMPLERS[name]
+   # Ignore kwargs not accepted by the sampler.
+   sampler_args = inspect.signature(sampler_class._sample_data).parameters  # pylint:disable=protected-access
+   clean_kwargs = {k: kwargs[k] for k in kwargs if k in sampler_args}
+   if set(clean_kwargs) != set(kwargs):
+     logging.warning('Ignoring kwargs %s when building sampler class %s',
+                     set(kwargs).difference(clean_kwargs), sampler_class)
+   sampler = sampler_class(algorithm, spec, num_samples, seed=seed,
+                           *args, **clean_kwargs)
+   return sampler, spec
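A minimal usage sketch, matching how the unit tests in this commit call `build_sampler` (`length` is forwarded as a keyword to the sampler's `_sample_data`):

sampler, spec = build_sampler('bfs', num_samples=100, length=16, seed=0)
feedback = sampler.next(batch_size=32)  # Feedback(features=..., outputs=...)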
+
+
+ class SortingSampler(Sampler):
+   """Sorting sampler. Generates a random sequence of U[0, 1]."""
+
+   def _sample_data(
+       self,
+       length: int,
+       low: float = 0.,
+       high: float = 1.,
+   ):
+     arr = self._random_sequence(length=length, low=low, high=high)
+     return [arr]
+
+
+ class SearchSampler(Sampler):
+   """Search sampler. Generates a random sequence and target (of U[0, 1])."""
+
+   def _sample_data(
+       self,
+       length: int,
+       low: float = 0.,
+       high: float = 1.,
+   ):
+     arr = self._random_sequence(length=length, low=low, high=high)
+     arr.sort()
+     x = self._rng.uniform(low=low, high=high)
+     return [x, arr]
+
+
+ class MaxSubarraySampler(Sampler):
+   """Maximum subarray sampler. Generates a random sequence of U[-1, 1]."""
+
+   def _sample_data(
+       self,
+       length: int,
+       low: float = -1.,
+       high: float = 1.,
+   ):
+     arr = self._random_sequence(length=length, low=low, high=high)
+     return [arr]
+
+
+ class LCSSampler(Sampler):
+   """Longest Common Subsequence sampler. Generates two random ATCG strings."""
+
+   def _sample_data(
+       self,
+       length: int,
+       length_2: Optional[int] = None,
+       chars: int = 4,
+   ):
+     if length_2 is None:
+       # Assume provided length is total length.
+       length_2 = length // 2
+       length -= length_2
+     a = self._random_string(length=length, chars=chars)
+     b = self._random_string(length=length_2, chars=chars)
+     return [a, b]
+
+
+ class OptimalBSTSampler(Sampler):
+   """Optimal BST sampler. Samples array of probabilities, splits it into two."""
+
+   def _sample_data(
+       self,
+       length: int,
+   ):
+     tot_length = length + (length + 1)
+     arr = self._random_sequence(length=tot_length, low=0.0, high=1.0)
+     arr /= np.sum(arr)
+     p = arr[:length]
+     q = arr[length:]
+     return [p, q]
+
+
+ class ActivitySampler(Sampler):
+   """Activity sampler. Samples start and finish times from U[0, 1]."""
+
+   def _sample_data(
+       self,
+       length: int,
+       low: float = 0.,
+       high: float = 1.,
+   ):
+     arr_1 = self._random_sequence(length=length, low=low, high=high)
+     arr_2 = self._random_sequence(length=length, low=low, high=high)
+     return [np.minimum(arr_1, arr_2), np.maximum(arr_1, arr_2)]
+
+
+ class TaskSampler(Sampler):
+   """Task sampler. Samples deadlines (integers) and values (U[0, 1])."""
+
+   def _sample_data(
+       self,
+       length: int,
+       max_deadline: Optional[int] = None,
+       low: float = 0.,
+       high: float = 1.,
+   ):
+     if max_deadline is None:
+       max_deadline = length
+     d = self._random_string(length=length, chars=max_deadline) + 1
+     w = self._random_sequence(length=length, low=low, high=high)
+     return [d, w]
+
+
+ class DfsSampler(Sampler):
+   """DFS sampler."""
+
+   def _sample_data(
+       self,
+       length: int,
+       p: Tuple[float, ...] = (0.5,),
+   ):
+     graph = self._random_er_graph(
+         nb_nodes=length, p=self._rng.choice(p),
+         directed=True, acyclic=False, weighted=False)
+     return [graph]
+
+
+ class BfsSampler(Sampler):
+   """BFS sampler."""
+
+   def _sample_data(
+       self,
+       length: int,
+       p: Tuple[float, ...] = (0.5,),
+   ):
+     graph = self._random_er_graph(
+         nb_nodes=length, p=self._rng.choice(p),
+         directed=False, acyclic=False, weighted=False)
+     source_node = self._rng.choice(length)
+     return [graph, source_node]
+
+
+ class TopoSampler(Sampler):
+   """Topological Sorting sampler."""
+
+   def _sample_data(
+       self,
+       length: int,
+       p: Tuple[float, ...] = (0.5,),
+   ):
+     graph = self._random_er_graph(
+         nb_nodes=length, p=self._rng.choice(p),
+         directed=True, acyclic=True, weighted=False)
+     return [graph]
+
+
+ class ArticulationSampler(Sampler):
+   """Articulation Point sampler."""
+
+   def _sample_data(
+       self,
+       length: int,
+       p: Tuple[float, ...] = (0.2,),
+   ):
+     graph = self._random_er_graph(
+         nb_nodes=length, p=self._rng.choice(p), directed=False,
+         acyclic=False, weighted=False)
+     return [graph]
+
+
+ class MSTSampler(Sampler):
+   """MST sampler for Kruskal's algorithm."""
+
+   def _sample_data(
+       self,
+       length: int,
+       p: Tuple[float, ...] = (0.2,),  # lower p to account for class imbalance
+       low: float = 0.,
+       high: float = 1.,
+   ):
+     graph = self._random_er_graph(
+         nb_nodes=length,
+         p=self._rng.choice(p),
+         directed=False,
+         acyclic=False,
+         weighted=True,
+         low=low,
+         high=high)
+     return [graph]
+
+
+ class BellmanFordSampler(Sampler):
+   """Bellman-Ford sampler."""
+
+   def _sample_data(
+       self,
+       length: int,
+       p: Tuple[float, ...] = (0.5,),
+       low: float = 0.,
+       high: float = 1.,
+   ):
+     graph = self._random_er_graph(
+         nb_nodes=length,
+         p=self._rng.choice(p),
+         directed=False,
+         acyclic=False,
+         weighted=True,
+         low=low,
+         high=high)
+     source_node = self._rng.choice(length)
+     return [graph, source_node]
+
+
+ class DAGPathSampler(Sampler):
+   """Sampler for DAG shortest paths."""
+
+   def _sample_data(
+       self,
+       length: int,
+       p: Tuple[float, ...] = (0.5,),
+       low: float = 0.,
+       high: float = 1.,
+   ):
+     graph = self._random_er_graph(
+         nb_nodes=length,
+         p=self._rng.choice(p),
+         directed=True,
+         acyclic=True,
+         weighted=True,
+         low=low,
+         high=high)
+     source_node = self._rng.choice(length)
+     return [graph, source_node]
+
+
+ class FloydWarshallSampler(Sampler):
+   """Sampler for all-pairs shortest paths."""
+
+   def _sample_data(
+       self,
+       length: int,
+       p: Tuple[float, ...] = (0.5,),
+       low: float = 0.,
+       high: float = 1.,
+   ):
+     graph = self._random_er_graph(
+         nb_nodes=length,
+         p=self._rng.choice(p),
+         directed=False,
+         acyclic=False,
+         weighted=True,
+         low=low,
+         high=high)
+     return [graph]
+
+
+ class SccSampler(Sampler):
+   """Sampler for strongly connected component (SCC) tasks."""
+
+   def _sample_data(
+       self,
+       length: int,
+       k: int = 4,
+       p: Tuple[float, ...] = (0.5,),
+       eps: float = 0.01,
+   ):
+     graph = self._random_community_graph(
+         nb_nodes=length, k=k, p=self._rng.choice(p), eps=eps,
+         directed=True, acyclic=False, weighted=False)
+     return [graph]
+
+
+ class BipartiteSampler(Sampler):
+   """Sampler for bipartite matching-based flow networks."""
+
+   def _sample_data(
+       self,
+       length: int,
+       length_2: Optional[int] = None,
+       p: Tuple[float, ...] = (0.3,),
+   ):
+     if length_2 is None:
+       # Assume provided length is total length.
+       length_2 = length // 2
+       length -= length_2
+     graph = self._random_bipartite_graph(n=length, m=length_2,
+                                          p=self._rng.choice(p))
+     return [graph, length, length_2, 0, length + length_2 + 1]
+
+
+ class MatcherSampler(Sampler):
+   """String matching sampler; embeds needle in a random haystack."""
+
+   def _sample_data(
+       self,
+       length: int,  # length of haystack + needle, i.e., total number of nodes
+       length_needle: Optional[int] = None,
+       chars: int = 4,
+   ):
+     if length_needle is None:
+       if length < 5:
+         length_needle = 1
+       else:
+         length_needle = length // 5
+     elif length_needle < 0:  # randomize needle length
+       length_needle = self._rng.randint(1, high=1 - length_needle)
+     length_haystack = length - length_needle
+     needle = self._random_string(length=length_needle, chars=chars)
+     haystack = self._random_string(length=length_haystack, chars=chars)
+     embed_pos = self._rng.choice(length_haystack - length_needle)
+     haystack[embed_pos:embed_pos + length_needle] = needle
+     return [haystack, needle]
+
+
+ class SegmentsSampler(Sampler):
+   """Two-segment sampler of points from (U[0, 1], U[0, 1])."""
+
+   def _sample_data(self, length: int, low: float = 0., high: float = 1.):
+     del length  # There are exactly four endpoints.
+
+     # Quick CCW check (ignoring collinearity) for rejection sampling
+     def ccw(x_a, y_a, x_b, y_b, x_c, y_c):
+       return (y_c - y_a) * (x_b - x_a) > (y_b - y_a) * (x_c - x_a)
+     def intersect(xs, ys):
+       return ccw(xs[0], ys[0], xs[2], ys[2], xs[3], ys[3]) != ccw(
+           xs[1], ys[1], xs[2], ys[2], xs[3], ys[3]) and ccw(
+               xs[0], ys[0], xs[1], ys[1], xs[2], ys[2]) != ccw(
+                   xs[0], ys[0], xs[1], ys[1], xs[3], ys[3])
+
+     # Decide (with uniform probability) whether this sample should intersect.
+     coin_flip = self._rng.binomial(1, 0.5)
+
+     xs = self._random_sequence(length=4, low=low, high=high)
+     ys = self._random_sequence(length=4, low=low, high=high)
+
+     while intersect(xs, ys) != coin_flip:
+       xs = self._random_sequence(length=4, low=low, high=high)
+       ys = self._random_sequence(length=4, low=low, high=high)
+
+     return [xs, ys]
+
+
+ class ConvexHullSampler(Sampler):
+   """Convex hull sampler of points over a disk of radius r."""
+
+   def _sample_data(self, length: int, origin_x: float = 0.,
+                    origin_y: float = 0., radius: float = 2.):
+
+     thetas = self._random_sequence(length=length, low=0.0, high=2.0 * np.pi)
+     rs = radius * np.sqrt(
+         self._random_sequence(length=length, low=0.0, high=1.0))
+
+     xs = rs * np.cos(thetas) + origin_x
+     ys = rs * np.sin(thetas) + origin_y
+
+     return [xs, ys]
+
+
+ SAMPLERS = {
+     'insertion_sort': SortingSampler,
+     'bubble_sort': SortingSampler,
+     'heapsort': SortingSampler,
+     'quicksort': SortingSampler,
+     'quickselect': SortingSampler,
+     'minimum': SortingSampler,
+     'binary_search': SearchSampler,
+     'find_maximum_subarray': MaxSubarraySampler,
+     'find_maximum_subarray_kadane': MaxSubarraySampler,
+     'matrix_chain_order': SortingSampler,
+     'lcs_length': LCSSampler,
+     'optimal_bst': OptimalBSTSampler,
+     'activity_selector': ActivitySampler,
+     'task_scheduling': TaskSampler,
+     'dfs': DfsSampler,
+     'topological_sort': TopoSampler,
+     'strongly_connected_components': SccSampler,
+     'articulation_points': ArticulationSampler,
+     'bridges': ArticulationSampler,
+     'bfs': BfsSampler,
+     'mst_kruskal': MSTSampler,
+     'mst_prim': BellmanFordSampler,
+     'bellman_ford': BellmanFordSampler,
+     'dag_shortest_paths': DAGPathSampler,
+     'dijkstra': BellmanFordSampler,
+     'floyd_warshall': FloydWarshallSampler,
+     'bipartite_matching': BipartiteSampler,
+     'naive_string_matcher': MatcherSampler,
+     'kmp_matcher': MatcherSampler,
+     'segments_intersect': SegmentsSampler,
+     'graham_scan': ConvexHullSampler,
+     'jarvis_march': ConvexHullSampler,
+ }
+
+
+ def _batch_io(traj_io: Trajectories) -> Trajectory:
+   """Batches a trajectory of input/output samples along the time axis per probe.
+
+   Args:
+     traj_io: An i/o trajectory of `DataPoint`s indexed by time then probe.
+
+   Returns:
+     A |num probes| list of `DataPoint`s with the time axis stacked into `data`.
+   """
+
+   assert traj_io  # non-empty
+   for sample_io in traj_io:
+     for i, dp in enumerate(sample_io):
+       assert dp.data.shape[0] == 1  # batching axis
+       assert traj_io[0][i].name == dp.name
+
+   return jax.tree_util.tree_map(lambda *x: np.concatenate(x), *traj_io)
+
+
+ def _batch_hints(
+     traj_hints: Trajectories, min_steps: int) -> Tuple[Trajectory, List[int]]:
+   """Batches a trajectory of hints samples along the time axis per probe.
+
+   Unlike i/o, hints have a variable-length time dimension. Before batching,
+   each trajectory is padded to the maximum trajectory length.
+
+   Args:
+     traj_hints: A hint trajectory of `DataPoint`s indexed by time then probe.
+     min_steps: Hints will be padded at least to this length - if any hint is
+       longer than this, the greater length will be used.
+
+   Returns:
+     A |num probes| list of `DataPoint`s with the time axis stacked into `data`,
+     and a |sample| list containing the length of each trajectory.
+   """
+
+   max_steps = min_steps
+   assert traj_hints  # non-empty
+   for sample_hint in traj_hints:
+     for dp in sample_hint:
+       assert dp.data.shape[1] == 1  # batching axis
+       if dp.data.shape[0] > max_steps:
+         max_steps = dp.data.shape[0]
+   time_and_batch = (max_steps, len(traj_hints))
+
+   # Create zero-filled space for the batched hints, then copy each hint
+   # up to the corresponding length.
+   batched_traj = jax.tree_util.tree_map(
+       lambda x: np.zeros(time_and_batch + x.shape[2:]),
+       traj_hints[0])
+   hint_lengths = np.zeros(len(traj_hints))
+
+   for sample_idx, cur_sample in enumerate(traj_hints):
+     for i in range(len(cur_sample)):
+       assert batched_traj[i].name == cur_sample[i].name
+       cur_data = cur_sample[i].data
+       cur_length = cur_data.shape[0]
+       batched_traj[i].data[:cur_length, sample_idx:sample_idx+1] = cur_data
+       if i > 0:
+         assert hint_lengths[sample_idx] == cur_length
+       else:
+         hint_lengths[sample_idx] = cur_length
+   return batched_traj, hint_lengths
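A minimal worked example, matching `test_batch_hint` in `samplers_test.py` later in this commit: two single-probe hint trajectories of 2 and 1 time steps, padded to `min_steps=5`:

x0 = probing.DataPoint(name='x', location=specs.Location.NODE,
                       type_=specs.Type.MASK, data=np.zeros([2, 1, 3]))
x1 = probing.DataPoint(name='x', location=specs.Location.NODE,
                       type_=specs.Type.MASK, data=np.zeros([1, 1, 3]))
batched, lengths = _batch_hints([[x0], [x1]], 5)
# batched[0].data.shape == (5, 2, 3); lengths == [2., 1.]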
+
+
+ def _subsample_data(
+     trajectory: Trajectory,
+     idx: List[int],
+     axis: int = 0,
+ ) -> Trajectory:
+   """New `Trajectory` where each `DataPoint`'s data is subsampled along axis."""
+   sampled_traj = []
+   for dp in trajectory:
+     sampled_data = np.take(dp.data, idx, axis=axis)
+     sampled_traj.append(
+         probing.DataPoint(dp.name, dp.location, dp.type_, sampled_data))
+   return sampled_traj
+
+
+ def _preprocess_permutations(probes, enforce_permutations):
+   """Replace should-be permutations with proper permutation pointer + mask."""
+   output = []
+   for x in probes:
+     if x.type_ != specs.Type.SHOULD_BE_PERMUTATION:
+       output.append(x)
+       continue
+     assert x.location == specs.Location.NODE
+     if enforce_permutations:
+       new_x, mask = probing.predecessor_to_cyclic_predecessor_and_first(x.data)
+       output.append(
+           probing.DataPoint(
+               name=x.name,
+               location=x.location,
+               type_=specs.Type.PERMUTATION_POINTER,
+               data=new_x))
+       output.append(
+           probing.DataPoint(
+               name=x.name + '_mask',
+               location=x.location,
+               type_=specs.Type.MASK_ONE,
+               data=mask))
+     else:
+       output.append(probing.DataPoint(name=x.name, location=x.location,
+                                       type_=specs.Type.POINTER, data=x.data))
+   return output
+
+
+ def process_permutations(spec, sample_iterator, enforce_permutations):
+   """Replace should-be permutations with proper permutation pointer + mask."""
+   def _iterate():
+     while True:
+       feedback = next(sample_iterator)
+       features = feedback.features
+       inputs = _preprocess_permutations(features.inputs, enforce_permutations)
+       hints = _preprocess_permutations(features.hints, enforce_permutations)
+       outputs = _preprocess_permutations(feedback.outputs,
+                                          enforce_permutations)
+       features = features._replace(inputs=tuple(inputs),
+                                    hints=tuple(hints))
+       feedback = feedback._replace(features=features,
+                                    outputs=outputs)
+       yield feedback
+
+   new_spec = {}
+   for k in spec:
+     if (spec[k][1] == specs.Location.NODE and
+         spec[k][2] == specs.Type.SHOULD_BE_PERMUTATION):
+       if enforce_permutations:
+         new_spec[k] = (spec[k][0], spec[k][1], specs.Type.PERMUTATION_POINTER)
+         new_spec[k + '_mask'] = (spec[k][0], spec[k][1], specs.Type.MASK_ONE)
+       else:
+         new_spec[k] = (spec[k][0], spec[k][1], specs.Type.POINTER)
+     else:
+       new_spec[k] = spec[k]
+
+   return new_spec, _iterate()
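A minimal sketch of the spec rewrite above (the data-side conversion is delegated to `probing.predecessor_to_cyclic_predecessor_and_first`, defined elsewhere in this commit; the empty iterator here only exercises the eager spec path):

spec = {'pred': (specs.Stage.OUTPUT, specs.Location.NODE,
                 specs.Type.SHOULD_BE_PERMUTATION)}
new_spec, _ = process_permutations(spec, iter(()), enforce_permutations=True)
# new_spec['pred'] now carries Type.PERMUTATION_POINTER, and a companion
# 'pred_mask' entry with Type.MASK_ONE has been added.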
+
+
+ def process_pred_as_input(spec, sample_iterator):
+   """Move pred_h hint to pred input."""
+   def _iterate():
+     while True:
+       feedback = next(sample_iterator)
+       features = feedback.features
+       pred_h = [h for h in features.hints if h.name == 'pred_h']
+       if pred_h:
+         assert len(pred_h) == 1
+         pred_h = pred_h[0]
+         hints = [h for h in features.hints if h.name != 'pred_h']
+         for i in range(len(features.lengths)):
+           assert np.sum(np.abs(pred_h.data[1:int(features.lengths[i]), i] -
+                                pred_h.data[0, i])) == 0.0
+         inputs = tuple(features.inputs) + (
+             probing.DataPoint(name='pred', location=pred_h.location,
+                               type_=pred_h.type_, data=pred_h.data[0]),)
+         features = features._replace(inputs=tuple(inputs),
+                                      hints=tuple(hints))
+         feedback = feedback._replace(features=features)
+       yield feedback
+
+   new_spec = {}
+   for k in spec:
+     if k == 'pred_h':
+       assert spec[k] == (specs.Stage.HINT, specs.Location.NODE,
+                          specs.Type.POINTER)
+       new_spec['pred'] = (specs.Stage.INPUT, specs.Location.NODE,
+                           specs.Type.POINTER)
+     else:
+       new_spec[k] = spec[k]
+
+   return new_spec, _iterate()
+
+
+ def process_random_pos(sample_iterator, rng):
+   """Randomize the `pos` input from a sampler.
+
+   The `pos` input is, by default, a scalar uniformly spaced between 0 and 1
+   across the nodes. The exceptions are string algorithms (naive_string_matcher,
+   kmp_matcher and lcs_length), where the `pos` sequence is split into
+   needle and haystack (or first and second string, for lcs_length). Here
+   we replace the uniformly spaced `pos` with an ordered sequence of random
+   scalars, or, for string algorithms, two ordered sequences of random scalars.
+
+   Args:
+     sample_iterator: An iterator producing samples with non-random `pos` inputs.
+     rng: Numpy random generator.
+
+   Returns:
+     An iterator returning the samples with randomized `pos` inputs.
+   """
+   def _iterate():
+     while True:
+       feedback = next(sample_iterator)
+       inputs = feedback.features.inputs
+       pos, = [x for x in inputs if x.name == 'pos']
+       batch_size, num_nodes = pos.data.shape
+       unsorted = rng.uniform(size=(batch_size, num_nodes))
+       new_pos = []
+       for i in range(batch_size):  # we check one example at a time.
+         # We find if there are splits in the pos sequence, marked by zeros.
+         # There will always be at least one zero, at position 0, even if
+         # there is no split.
+         split, = np.where(pos.data[i] == 0)
+         split = np.concatenate([split, [num_nodes]])
+         # We construct the randomized pos by sorting the random values in each
+         # split and concatenating them.
+         new_pos.append(
+             np.concatenate([np.sort(unsorted[i, split[j]:split[j+1]])
+                             for j in range(len(split) - 1)]))
+       pos.data = np.array(new_pos)
+       inputs = [(pos if x.name == 'pos' else x) for x in inputs]
+       features = feedback.features._replace(inputs=inputs)
+       feedback = feedback._replace(features=features)
+       yield feedback
+
+   return _iterate()
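A minimal wiring sketch, matching `ProcessRandomPosTest` in `samplers_test.py` later in this commit:

sampler, _ = build_sampler('insertion_sort', num_samples=100, length=10, seed=0)
def _stream():
  while True:
    yield sampler.next(12)
randomized = process_random_pos(_stream(), np.random.RandomState(0))
batch = next(randomized)  # its `pos` input is now ordered random scalars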
benchmarks/CLRS/env/samplers_test.py ADDED
@@ -0,0 +1,250 @@
+ # Copyright 2021 DeepMind Technologies Limited. All Rights Reserved.
+ #
+ # Licensed under the Apache License, Version 2.0 (the "License");
+ # you may not use this file except in compliance with the License.
+ # You may obtain a copy of the License at
+ #
+ # http://www.apache.org/licenses/LICENSE-2.0
+ #
+ # Unless required by applicable law or agreed to in writing, software
+ # distributed under the License is distributed on an "AS IS" BASIS,
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ # See the License for the specific language governing permissions and
+ # limitations under the License.
+ # ==============================================================================
+
+ """Unit tests for `samplers.py`."""
+
+ from absl.testing import absltest
+ from absl.testing import parameterized
+
+ import chex
+ from clrs._src import probing
+ from clrs._src import samplers
+ from clrs._src import specs
+ import jax
+ import numpy as np
+
+
+ class SamplersTest(parameterized.TestCase):
+
+   @parameterized.parameters(*specs.CLRS_30_ALGS)
+   def test_sampler_determinism(self, name):
+     num_samples = 3
+     num_nodes = 10
+     sampler, _ = samplers.build_sampler(name, num_samples, num_nodes)
+
+     np.random.seed(47)  # Set seed
+     feedback = sampler.next()
+     expected = feedback.outputs[0].data.copy()
+
+     np.random.seed(48)  # Set a different seed
+     feedback = sampler.next()
+     actual = feedback.outputs[0].data.copy()
+
+     # Validate that datasets are the same.
+     np.testing.assert_array_equal(expected, actual)
+
+   @parameterized.parameters(*specs.CLRS_30_ALGS)
+   def test_sampler_batch_determinism(self, name):
+     num_samples = 10
+     batch_size = 5
+     num_nodes = 10
+     seed = 0
+     sampler_1, _ = samplers.build_sampler(
+         name, num_samples, num_nodes, seed=seed)
+     sampler_2, _ = samplers.build_sampler(
+         name, num_samples, num_nodes, seed=seed)
+
+     feedback_1 = sampler_1.next(batch_size)
+     feedback_2 = sampler_2.next(batch_size)
+
+     # Validate that datasets are the same.
+     jax.tree_util.tree_map(np.testing.assert_array_equal, feedback_1,
+                            feedback_2)
+
+   def test_end_to_end(self):
+     num_samples = 7
+     num_nodes = 3
+     sampler, _ = samplers.build_sampler("bfs", num_samples, num_nodes)
+     feedback = sampler.next()
+
+     inputs = feedback.features.inputs
+     self.assertLen(inputs, 4)
+     self.assertEqual(inputs[0].name, "pos")
+     self.assertEqual(inputs[0].data.shape, (num_samples, num_nodes))
+
+     outputs = feedback.outputs
+     self.assertLen(outputs, 1)
+     self.assertEqual(outputs[0].name, "pi")
+     self.assertEqual(outputs[0].data.shape, (num_samples, num_nodes))
+
+   def test_batch_size(self):
+     num_samples = 7
+     num_nodes = 3
+     sampler, _ = samplers.build_sampler("bfs", num_samples, num_nodes)
+
+     # Full-batch.
+     feedback = sampler.next()
+     for dp in feedback.features.inputs:  # [B, ...]
+       self.assertEqual(dp.data.shape[0], num_samples)
+
+     for dp in feedback.outputs:  # [B, ...]
+       self.assertEqual(dp.data.shape[0], num_samples)
+
+     for dp in feedback.features.hints:  # [T, B, ...]
+       self.assertEqual(dp.data.shape[1], num_samples)
+
+     self.assertLen(feedback.features.lengths, num_samples)
+
+     # Specified batch.
+     batch_size = 5
+     feedback = sampler.next(batch_size)
+
+     for dp in feedback.features.inputs:  # [B, ...]
+       self.assertEqual(dp.data.shape[0], batch_size)
+
+     for dp in feedback.outputs:  # [B, ...]
+       self.assertEqual(dp.data.shape[0], batch_size)
+
+     for dp in feedback.features.hints:  # [T, B, ...]
+       self.assertEqual(dp.data.shape[1], batch_size)
+
+     self.assertLen(feedback.features.lengths, batch_size)
+
+   def test_batch_io(self):
+     sample = [
+         probing.DataPoint(
+             name="x",
+             location=specs.Location.NODE,
+             type_=specs.Type.SCALAR,
+             data=np.zeros([1, 3]),
+         ),
+         probing.DataPoint(
+             name="y",
+             location=specs.Location.EDGE,
+             type_=specs.Type.MASK,
+             data=np.zeros([1, 3, 3]),
+         ),
+     ]
+
+     trajectory = [sample.copy(), sample.copy(), sample.copy(), sample.copy()]
+     batched = samplers._batch_io(trajectory)
+
+     np.testing.assert_array_equal(batched[0].data, np.zeros([4, 3]))
+     np.testing.assert_array_equal(batched[1].data, np.zeros([4, 3, 3]))
+
+   def test_batch_hint(self):
+     sample0 = [
+         probing.DataPoint(
+             name="x",
+             location=specs.Location.NODE,
+             type_=specs.Type.MASK,
+             data=np.zeros([2, 1, 3]),
+         ),
+         probing.DataPoint(
+             name="y",
+             location=specs.Location.NODE,
+             type_=specs.Type.POINTER,
+             data=np.zeros([2, 1, 3]),
+         ),
+     ]
+
+     sample1 = [
+         probing.DataPoint(
+             name="x",
+             location=specs.Location.NODE,
+             type_=specs.Type.MASK,
+             data=np.zeros([1, 1, 3]),
+         ),
+         probing.DataPoint(
+             name="y",
+             location=specs.Location.NODE,
+             type_=specs.Type.POINTER,
+             data=np.zeros([1, 1, 3]),
+         ),
+     ]
+
+     trajectory = [sample0, sample1]
+     batched, lengths = samplers._batch_hints(trajectory, 0)
+
+     np.testing.assert_array_equal(batched[0].data, np.zeros([2, 2, 3]))
+     np.testing.assert_array_equal(batched[1].data, np.zeros([2, 2, 3]))
+     np.testing.assert_array_equal(lengths, np.array([2, 1]))
+
+     batched, lengths = samplers._batch_hints(trajectory, 5)
+
+     np.testing.assert_array_equal(batched[0].data, np.zeros([5, 2, 3]))
+     np.testing.assert_array_equal(batched[1].data, np.zeros([5, 2, 3]))
+     np.testing.assert_array_equal(lengths, np.array([2, 1]))
+
+   def test_padding(self):
+     lens = np.random.choice(10, (10,), replace=True) + 1
+     trajectory = []
+     for len_ in lens:
+       trajectory.append([
+           probing.DataPoint(
+               name="x",
+               location=specs.Location.NODE,
+               type_=specs.Type.MASK,
+               data=np.ones([len_, 1, 3]),
+           )
+       ])
+
+     batched, lengths = samplers._batch_hints(trajectory, 0)
+     np.testing.assert_array_equal(lengths, lens)
+
+     for i in range(len(lens)):
+       ones = batched[0].data[:lens[i], i, :]
+       zeros = batched[0].data[lens[i]:, i, :]
+       np.testing.assert_array_equal(ones, np.ones_like(ones))
+       np.testing.assert_array_equal(zeros, np.zeros_like(zeros))
+
+
+ class ProcessRandomPosTest(parameterized.TestCase):
+
+   @parameterized.parameters(["insertion_sort", "naive_string_matcher"])
+   def test_random_pos(self, algorithm_name):
+     batch_size, length = 12, 10
+     def _make_sampler():
+       sampler, _ = samplers.build_sampler(
+           algorithm_name,
+           seed=0,
+           num_samples=100,
+           length=length,
+       )
+       while True:
+         yield sampler.next(batch_size)
+     sampler_1 = _make_sampler()
+     sampler_2 = _make_sampler()
+     sampler_2 = samplers.process_random_pos(sampler_2,
+                                             np.random.RandomState(0))
+
+     batch_without_rand_pos = next(sampler_1)
+     batch_with_rand_pos = next(sampler_2)
+     pos_idx = [x.name for x in batch_without_rand_pos.features.inputs].index(
+         "pos")
+     fixed_pos = batch_without_rand_pos.features.inputs[pos_idx]
+     rand_pos = batch_with_rand_pos.features.inputs[pos_idx]
+     self.assertEqual(rand_pos.location, specs.Location.NODE)
+     self.assertEqual(rand_pos.type_, specs.Type.SCALAR)
+     self.assertEqual(rand_pos.data.shape, (batch_size, length))
+     self.assertEqual(rand_pos.data.shape, fixed_pos.data.shape)
+     self.assertEqual(rand_pos.type_, fixed_pos.type_)
+     self.assertEqual(rand_pos.location, fixed_pos.location)
+
+     assert (rand_pos.data.std(axis=0) > 1e-3).all()
+     assert (fixed_pos.data.std(axis=0) < 1e-9).all()
+     if "string" in algorithm_name:
+       expected = np.concatenate([np.arange(4*length//5)/(4*length//5),
+                                  np.arange(length//5)/(length//5)])
+     else:
+       expected = np.arange(length)/length
+     np.testing.assert_array_equal(
+         fixed_pos.data, np.broadcast_to(expected, (batch_size, length)))
+
+     batch_with_rand_pos.features.inputs[pos_idx] = fixed_pos
+     chex.assert_trees_all_equal(batch_with_rand_pos, batch_without_rand_pos)
+
+
+ if __name__ == "__main__":
+   absltest.main()
benchmarks/CLRS/env/specs.py ADDED
@@ -0,0 +1,525 @@
+ # Copyright 2021 DeepMind Technologies Limited. All Rights Reserved.
+ #
+ # Licensed under the Apache License, Version 2.0 (the "License");
+ # you may not use this file except in compliance with the License.
+ # You may obtain a copy of the License at
+ #
+ # http://www.apache.org/licenses/LICENSE-2.0
+ #
+ # Unless required by applicable law or agreed to in writing, software
+ # distributed under the License is distributed on an "AS IS" BASIS,
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ # See the License for the specific language governing permissions and
+ # limitations under the License.
+ # ==============================================================================
+
+ """Algorithm specs.
+
+ The "spec" of each algorithm is a static set of `(stage, loc, type)`-tuples.
+
+ - `stage`: One of either an `input`, `output` or `hint`
+ - `location`: Each datum is associated with either the `node`, `edge` or `graph`
+ - `type`: Either a `scalar`, `categorical`, `mask`, `mask_one` or `pointer`
+
+ The dataflow for an algorithm is represented by `(stage, loc, type, data)`
+ "probes" that are valid under that algorithm's spec. It contains a single
+ snapshot for each `input` and `output` and a time-series of intermediate
+ algorithmic states (`hint`).
+
+ At minimum, each node contains a `pos` probe that serves as a unique index,
+ e.g. for representing sequential data where appropriate.
+ """
+
+ import types
+ from typing import Dict, Tuple
+
+
+ class Stage:
+   INPUT = 'input'
+   OUTPUT = 'output'
+   HINT = 'hint'
+
+
+ class Location:
+   NODE = 'node'
+   EDGE = 'edge'
+   GRAPH = 'graph'
+
+
+ class Type:
+   SCALAR = 'scalar'
+   CATEGORICAL = 'categorical'
+   MASK = 'mask'
+   MASK_ONE = 'mask_one'
+   POINTER = 'pointer'
+   SHOULD_BE_PERMUTATION = 'should_be_permutation'
+   PERMUTATION_POINTER = 'permutation_pointer'
+   SOFT_POINTER = 'soft_pointer'
+
+
+ class OutputClass:
+   POSITIVE = 1
+   NEGATIVE = 0
+   MASKED = -1
+
+ Spec = Dict[str, Tuple[str, str, str]]
+
+ CLRS_30_ALGS = [
+     'articulation_points',
+     'activity_selector',
+     'bellman_ford',
+     'bfs',
+     'binary_search',
+     'bridges',
+     'bubble_sort',
+     'dag_shortest_paths',
+     'dfs',
+     'dijkstra',
+     'find_maximum_subarray_kadane',
+     'floyd_warshall',
+     'graham_scan',
+     'heapsort',
+     'insertion_sort',
+     'jarvis_march',
+     'kmp_matcher',
+     'lcs_length',
+     'matrix_chain_order',
+     'minimum',
+     'mst_kruskal',
+     'mst_prim',
+     'naive_string_matcher',
+     'optimal_bst',
+     'quickselect',
+     'quicksort',
+     'segments_intersect',
+     'strongly_connected_components',
+     'task_scheduling',
+     'topological_sort',
+ ]
+
+
+ ALGO_IDX_INPUT_NAME = 'algo_idx'
+
+ # Algorithms have varying numbers of signals they are evaluated on.
+ # To compensate for that, we issue more samples for those that use a small
+ # number of signals.
+ CLRS_30_ALGS_SETTINGS = {alg: {'num_samples_multiplier': 1}
+                          for alg in CLRS_30_ALGS}
+ CLRS_30_ALGS_SETTINGS['find_maximum_subarray_kadane'][
+     'num_samples_multiplier'] = 32
+ for alg in ['quickselect', 'minimum', 'binary_search', 'naive_string_matcher',
+             'kmp_matcher', 'segments_intersect']:
+   CLRS_30_ALGS_SETTINGS[alg]['num_samples_multiplier'] = 64
+
115
+ SPECS = types.MappingProxyType({
116
+ 'insertion_sort': {
117
+ 'pos': (Stage.INPUT, Location.NODE, Type.SCALAR),
118
+ 'key': (Stage.INPUT, Location.NODE, Type.SCALAR),
119
+ 'pred': (Stage.OUTPUT, Location.NODE, Type.SHOULD_BE_PERMUTATION),
120
+ 'pred_h': (Stage.HINT, Location.NODE, Type.POINTER),
121
+ 'i': (Stage.HINT, Location.NODE, Type.MASK_ONE),
122
+ 'j': (Stage.HINT, Location.NODE, Type.MASK_ONE)
123
+ },
124
+ 'bubble_sort': {
125
+ 'pos': (Stage.INPUT, Location.NODE, Type.SCALAR),
126
+ 'key': (Stage.INPUT, Location.NODE, Type.SCALAR),
127
+ 'pred': (Stage.OUTPUT, Location.NODE, Type.SHOULD_BE_PERMUTATION),
128
+ 'pred_h': (Stage.HINT, Location.NODE, Type.POINTER),
129
+ 'i': (Stage.HINT, Location.NODE, Type.MASK_ONE),
130
+ 'j': (Stage.HINT, Location.NODE, Type.MASK_ONE)
131
+ },
132
+ 'heapsort': {
133
+ 'pos': (Stage.INPUT, Location.NODE, Type.SCALAR),
134
+ 'key': (Stage.INPUT, Location.NODE, Type.SCALAR),
135
+ 'pred': (Stage.OUTPUT, Location.NODE, Type.SHOULD_BE_PERMUTATION),
136
+ 'pred_h': (Stage.HINT, Location.NODE, Type.POINTER),
137
+ 'parent': (Stage.HINT, Location.NODE, Type.POINTER),
138
+ 'i': (Stage.HINT, Location.NODE, Type.MASK_ONE),
139
+ 'j': (Stage.HINT, Location.NODE, Type.MASK_ONE),
140
+ 'largest': (Stage.HINT, Location.NODE, Type.MASK_ONE),
141
+ 'heap_size': (Stage.HINT, Location.NODE, Type.MASK_ONE),
142
+ 'phase': (Stage.HINT, Location.GRAPH, Type.CATEGORICAL)
143
+ },
144
+ 'quicksort': {
145
+ 'pos': (Stage.INPUT, Location.NODE, Type.SCALAR),
146
+ 'key': (Stage.INPUT, Location.NODE, Type.SCALAR),
147
+ 'pred': (Stage.OUTPUT, Location.NODE, Type.SHOULD_BE_PERMUTATION),
148
+ 'pred_h': (Stage.HINT, Location.NODE, Type.POINTER),
149
+ 'p': (Stage.HINT, Location.NODE, Type.MASK_ONE),
150
+ 'r': (Stage.HINT, Location.NODE, Type.MASK_ONE),
151
+ 'i': (Stage.HINT, Location.NODE, Type.MASK_ONE),
152
+ 'j': (Stage.HINT, Location.NODE, Type.MASK_ONE)
153
+ },
154
+ 'quickselect': {
155
+ 'pos': (Stage.INPUT, Location.NODE, Type.SCALAR),
156
+ 'key': (Stage.INPUT, Location.NODE, Type.SCALAR),
157
+ 'median': (Stage.OUTPUT, Location.NODE, Type.MASK_ONE),
158
+ 'pred_h': (Stage.HINT, Location.NODE, Type.POINTER),
159
+ 'p': (Stage.HINT, Location.NODE, Type.MASK_ONE),
160
+ 'r': (Stage.HINT, Location.NODE, Type.MASK_ONE),
161
+ 'i': (Stage.HINT, Location.NODE, Type.MASK_ONE),
162
+ 'j': (Stage.HINT, Location.NODE, Type.MASK_ONE),
163
+ 'i_rank': (Stage.HINT, Location.GRAPH, Type.SCALAR),
164
+ 'target': (Stage.HINT, Location.GRAPH, Type.SCALAR)
165
+ },
166
+ 'minimum': {
167
+ 'pos': (Stage.INPUT, Location.NODE, Type.SCALAR),
168
+ 'key': (Stage.INPUT, Location.NODE, Type.SCALAR),
169
+ 'min': (Stage.OUTPUT, Location.NODE, Type.MASK_ONE),
170
+ 'pred_h': (Stage.HINT, Location.NODE, Type.POINTER),
171
+ 'min_h': (Stage.HINT, Location.NODE, Type.MASK_ONE),
172
+ 'i': (Stage.HINT, Location.NODE, Type.MASK_ONE)
173
+ },
174
+ 'binary_search': {
175
+ 'pos': (Stage.INPUT, Location.NODE, Type.SCALAR),
176
+ 'key': (Stage.INPUT, Location.NODE, Type.SCALAR),
177
+ 'target': (Stage.INPUT, Location.GRAPH, Type.SCALAR),
178
+ 'return': (Stage.OUTPUT, Location.NODE, Type.MASK_ONE),
179
+ 'pred_h': (Stage.HINT, Location.NODE, Type.POINTER),
180
+ 'low': (Stage.HINT, Location.NODE, Type.MASK_ONE),
181
+ 'high': (Stage.HINT, Location.NODE, Type.MASK_ONE),
182
+ 'mid': (Stage.HINT, Location.NODE, Type.MASK_ONE)
183
+ },
184
+ 'find_maximum_subarray': {
185
+ 'pos': (Stage.INPUT, Location.NODE, Type.SCALAR),
186
+ 'key': (Stage.INPUT, Location.NODE, Type.SCALAR),
187
+ 'start': (Stage.OUTPUT, Location.NODE, Type.MASK_ONE),
188
+ 'end': (Stage.OUTPUT, Location.NODE, Type.MASK_ONE),
189
+ 'pred_h': (Stage.HINT, Location.NODE, Type.POINTER),
190
+ 'low': (Stage.HINT, Location.NODE, Type.MASK_ONE),
191
+ 'high': (Stage.HINT, Location.NODE, Type.MASK_ONE),
192
+ 'mid': (Stage.HINT, Location.NODE, Type.MASK_ONE),
193
+ 'left_low': (Stage.HINT, Location.NODE, Type.MASK_ONE),
194
+ 'left_high': (Stage.HINT, Location.NODE, Type.MASK_ONE),
195
+ 'left_sum': (Stage.HINT, Location.GRAPH, Type.SCALAR),
196
+ 'right_low': (Stage.HINT, Location.NODE, Type.MASK_ONE),
197
+ 'right_high': (Stage.HINT, Location.NODE, Type.MASK_ONE),
198
+ 'right_sum': (Stage.HINT, Location.GRAPH, Type.SCALAR),
199
+ 'cross_low': (Stage.HINT, Location.NODE, Type.MASK_ONE),
200
+ 'cross_high': (Stage.HINT, Location.NODE, Type.MASK_ONE),
201
+ 'cross_sum': (Stage.HINT, Location.GRAPH, Type.SCALAR),
202
+ 'ret_low': (Stage.HINT, Location.NODE, Type.MASK_ONE),
203
+ 'ret_high': (Stage.HINT, Location.NODE, Type.MASK_ONE),
204
+ 'ret_sum': (Stage.HINT, Location.GRAPH, Type.SCALAR),
205
+ 'i': (Stage.HINT, Location.NODE, Type.MASK_ONE),
206
+ 'j': (Stage.HINT, Location.NODE, Type.MASK_ONE),
207
+ 'sum': (Stage.HINT, Location.GRAPH, Type.SCALAR),
208
+ 'left_x_sum': (Stage.HINT, Location.GRAPH, Type.SCALAR),
209
+ 'right_x_sum': (Stage.HINT, Location.GRAPH, Type.SCALAR),
210
+ 'phase': (Stage.HINT, Location.GRAPH, Type.CATEGORICAL)
211
+ },
212
+ 'find_maximum_subarray_kadane': {
213
+ 'pos': (Stage.INPUT, Location.NODE, Type.SCALAR),
214
+ 'key': (Stage.INPUT, Location.NODE, Type.SCALAR),
215
+ 'start': (Stage.OUTPUT, Location.NODE, Type.MASK_ONE),
216
+ 'end': (Stage.OUTPUT, Location.NODE, Type.MASK_ONE),
217
+ 'pred_h': (Stage.HINT, Location.NODE, Type.POINTER),
218
+ 'best_low': (Stage.HINT, Location.NODE, Type.MASK_ONE),
219
+ 'best_high': (Stage.HINT, Location.NODE, Type.MASK_ONE),
220
+ 'best_sum': (Stage.HINT, Location.GRAPH, Type.SCALAR),
221
+ 'i': (Stage.HINT, Location.NODE, Type.MASK_ONE),
222
+ 'j': (Stage.HINT, Location.NODE, Type.MASK_ONE),
223
+ 'sum': (Stage.HINT, Location.GRAPH, Type.SCALAR)
224
+ },
225
+ 'matrix_chain_order': {
226
+ 'pos': (Stage.INPUT, Location.NODE, Type.SCALAR),
227
+ 'p': (Stage.INPUT, Location.NODE, Type.SCALAR),
228
+ 's': (Stage.OUTPUT, Location.EDGE, Type.POINTER),
229
+ 'pred_h': (Stage.HINT, Location.NODE, Type.POINTER),
230
+ 'm': (Stage.HINT, Location.EDGE, Type.SCALAR),
231
+ 's_h': (Stage.HINT, Location.EDGE, Type.POINTER),
232
+ 'msk': (Stage.HINT, Location.EDGE, Type.MASK)
233
+ },
234
+ 'lcs_length': {
235
+ 'string': (Stage.INPUT, Location.NODE, Type.MASK),
236
+ 'pos': (Stage.INPUT, Location.NODE, Type.SCALAR),
237
+ 'key': (Stage.INPUT, Location.NODE, Type.CATEGORICAL),
238
+ 'b': (Stage.OUTPUT, Location.EDGE, Type.CATEGORICAL),
239
+ 'pred_h': (Stage.HINT, Location.NODE, Type.POINTER),
240
+ 'b_h': (Stage.HINT, Location.EDGE, Type.CATEGORICAL),
241
+ 'c': (Stage.HINT, Location.EDGE, Type.SCALAR)
242
+ },
243
+ 'optimal_bst': {
244
+ 'pos': (Stage.INPUT, Location.NODE, Type.SCALAR),
245
+ 'p': (Stage.INPUT, Location.NODE, Type.SCALAR),
246
+ 'q': (Stage.INPUT, Location.NODE, Type.SCALAR),
247
+ 'root': (Stage.OUTPUT, Location.EDGE, Type.POINTER),
248
+ 'pred_h': (Stage.HINT, Location.NODE, Type.POINTER),
249
+ 'root_h': (Stage.HINT, Location.EDGE, Type.POINTER),
250
+ 'e': (Stage.HINT, Location.EDGE, Type.SCALAR),
251
+ 'w': (Stage.HINT, Location.EDGE, Type.SCALAR),
252
+ 'msk': (Stage.HINT, Location.EDGE, Type.MASK)
253
+ },
254
+ 'activity_selector': {
255
+ 'pos': (Stage.INPUT, Location.NODE, Type.SCALAR),
256
+ 's': (Stage.INPUT, Location.NODE, Type.SCALAR),
257
+ 'f': (Stage.INPUT, Location.NODE, Type.SCALAR),
258
+ 'selected': (Stage.OUTPUT, Location.NODE, Type.MASK),
259
+ 'pred_h': (Stage.HINT, Location.NODE, Type.POINTER),
260
+ 'selected_h': (Stage.HINT, Location.NODE, Type.MASK),
261
+ 'm': (Stage.HINT, Location.NODE, Type.MASK_ONE),
262
+ 'k': (Stage.HINT, Location.NODE, Type.MASK_ONE)
263
+ },
264
+ 'task_scheduling': {
265
+ 'pos': (Stage.INPUT, Location.NODE, Type.SCALAR),
266
+ 'd': (Stage.INPUT, Location.NODE, Type.SCALAR),
267
+ 'w': (Stage.INPUT, Location.NODE, Type.SCALAR),
268
+ 'selected': (Stage.OUTPUT, Location.NODE, Type.MASK),
269
+ 'pred_h': (Stage.HINT, Location.NODE, Type.POINTER),
270
+ 'selected_h': (Stage.HINT, Location.NODE, Type.MASK),
271
+ 'i': (Stage.HINT, Location.NODE, Type.MASK_ONE),
272
+ 't': (Stage.HINT, Location.GRAPH, Type.SCALAR)
273
+ },
274
+ 'dfs': {
275
+ 'pos': (Stage.INPUT, Location.NODE, Type.SCALAR),
276
+ 'A': (Stage.INPUT, Location.EDGE, Type.SCALAR),
277
+ 'adj': (Stage.INPUT, Location.EDGE, Type.MASK),
278
+ 'pi': (Stage.OUTPUT, Location.NODE, Type.POINTER),
279
+ 'pi_h': (Stage.HINT, Location.NODE, Type.POINTER),
280
+ 'color': (Stage.HINT, Location.NODE, Type.CATEGORICAL),
281
+ 'd': (Stage.HINT, Location.NODE, Type.SCALAR),
282
+ 'f': (Stage.HINT, Location.NODE, Type.SCALAR),
283
+ 's_prev': (Stage.HINT, Location.NODE, Type.POINTER),
284
+ 's': (Stage.HINT, Location.NODE, Type.MASK_ONE),
285
+ 'u': (Stage.HINT, Location.NODE, Type.MASK_ONE),
286
+ 'v': (Stage.HINT, Location.NODE, Type.MASK_ONE),
287
+ 's_last': (Stage.HINT, Location.NODE, Type.MASK_ONE),
288
+ 'time': (Stage.HINT, Location.GRAPH, Type.SCALAR)
289
+ },
290
+ 'topological_sort': {
291
+ 'pos': (Stage.INPUT, Location.NODE, Type.SCALAR),
292
+ 'A': (Stage.INPUT, Location.EDGE, Type.SCALAR),
293
+ 'adj': (Stage.INPUT, Location.EDGE, Type.MASK),
294
+ 'topo': (Stage.OUTPUT, Location.NODE, Type.POINTER),
295
+ 'topo_head': (Stage.OUTPUT, Location.NODE, Type.MASK_ONE),
296
+ 'topo_h': (Stage.HINT, Location.NODE, Type.POINTER),
297
+ 'topo_head_h': (Stage.HINT, Location.NODE, Type.MASK_ONE),
298
+ 'color': (Stage.HINT, Location.NODE, Type.CATEGORICAL),
299
+ 's_prev': (Stage.HINT, Location.NODE, Type.POINTER),
300
+ 's': (Stage.HINT, Location.NODE, Type.MASK_ONE),
301
+ 'u': (Stage.HINT, Location.NODE, Type.MASK_ONE),
302
+ 'v': (Stage.HINT, Location.NODE, Type.MASK_ONE),
303
+ 's_last': (Stage.HINT, Location.NODE, Type.MASK_ONE)
304
+ },
305
+ 'strongly_connected_components': {
306
+ 'pos': (Stage.INPUT, Location.NODE, Type.SCALAR),
307
+ 'A': (Stage.INPUT, Location.EDGE, Type.SCALAR),
308
+ 'adj': (Stage.INPUT, Location.EDGE, Type.MASK),
309
+ 'scc_id': (Stage.OUTPUT, Location.NODE, Type.POINTER),
310
+ 'scc_id_h': (Stage.HINT, Location.NODE, Type.POINTER),
311
+ 'A_t': (Stage.HINT, Location.EDGE, Type.MASK),
312
+ 'color': (Stage.HINT, Location.NODE, Type.CATEGORICAL),
313
+ 'd': (Stage.HINT, Location.NODE, Type.SCALAR),
314
+ 'f': (Stage.HINT, Location.NODE, Type.SCALAR),
315
+ 's_prev': (Stage.HINT, Location.NODE, Type.POINTER),
316
+ 's': (Stage.HINT, Location.NODE, Type.MASK_ONE),
317
+ 'u': (Stage.HINT, Location.NODE, Type.MASK_ONE),
318
+ 'v': (Stage.HINT, Location.NODE, Type.MASK_ONE),
319
+ 's_last': (Stage.HINT, Location.NODE, Type.MASK_ONE),
320
+ 'time': (Stage.HINT, Location.GRAPH, Type.SCALAR),
321
+ 'phase': (Stage.HINT, Location.GRAPH, Type.MASK)
322
+ },
323
+ 'articulation_points': {
324
+ 'pos': (Stage.INPUT, Location.NODE, Type.SCALAR),
325
+ 'A': (Stage.INPUT, Location.EDGE, Type.SCALAR),
326
+ 'adj': (Stage.INPUT, Location.EDGE, Type.MASK),
327
+ 'is_cut': (Stage.OUTPUT, Location.NODE, Type.MASK),
328
+ 'is_cut_h': (Stage.HINT, Location.NODE, Type.MASK),
329
+ 'pi_h': (Stage.HINT, Location.NODE, Type.POINTER),
330
+ 'color': (Stage.HINT, Location.NODE, Type.CATEGORICAL),
331
+ 'd': (Stage.HINT, Location.NODE, Type.SCALAR),
332
+ 'f': (Stage.HINT, Location.NODE, Type.SCALAR),
333
+ 'low': (Stage.HINT, Location.NODE, Type.SCALAR),
334
+ 'child_cnt': (Stage.HINT, Location.NODE, Type.SCALAR),
335
+ 's_prev': (Stage.HINT, Location.NODE, Type.POINTER),
336
+ 's': (Stage.HINT, Location.NODE, Type.MASK_ONE),
337
+ 'u': (Stage.HINT, Location.NODE, Type.MASK_ONE),
338
+ 'v': (Stage.HINT, Location.NODE, Type.MASK_ONE),
339
+ 's_last': (Stage.HINT, Location.NODE, Type.MASK_ONE),
340
+ 'time': (Stage.HINT, Location.GRAPH, Type.SCALAR)
341
+ },
342
+ 'bridges': {
343
+ 'pos': (Stage.INPUT, Location.NODE, Type.SCALAR),
344
+ 'A': (Stage.INPUT, Location.EDGE, Type.SCALAR),
345
+ 'adj': (Stage.INPUT, Location.EDGE, Type.MASK),
346
+ 'is_bridge': (Stage.OUTPUT, Location.EDGE, Type.MASK),
347
+ 'is_bridge_h': (Stage.HINT, Location.EDGE, Type.MASK),
348
+ 'pi_h': (Stage.HINT, Location.NODE, Type.POINTER),
349
+ 'color': (Stage.HINT, Location.NODE, Type.CATEGORICAL),
350
+ 'd': (Stage.HINT, Location.NODE, Type.SCALAR),
351
+ 'f': (Stage.HINT, Location.NODE, Type.SCALAR),
352
+ 'low': (Stage.HINT, Location.NODE, Type.SCALAR),
353
+ 's_prev': (Stage.HINT, Location.NODE, Type.POINTER),
354
+ 's': (Stage.HINT, Location.NODE, Type.MASK_ONE),
355
+ 'u': (Stage.HINT, Location.NODE, Type.MASK_ONE),
356
+ 'v': (Stage.HINT, Location.NODE, Type.MASK_ONE),
357
+ 's_last': (Stage.HINT, Location.NODE, Type.MASK_ONE),
358
+ 'time': (Stage.HINT, Location.GRAPH, Type.SCALAR)
359
+ },
360
+ 'bfs': {
361
+ 'pos': (Stage.INPUT, Location.NODE, Type.SCALAR),
362
+ 's': (Stage.INPUT, Location.NODE, Type.MASK_ONE),
363
+ 'A': (Stage.INPUT, Location.EDGE, Type.SCALAR),
364
+ 'adj': (Stage.INPUT, Location.EDGE, Type.MASK),
365
+ 'pi': (Stage.OUTPUT, Location.NODE, Type.POINTER),
366
+ 'reach_h': (Stage.HINT, Location.NODE, Type.MASK),
367
+ 'pi_h': (Stage.HINT, Location.NODE, Type.POINTER)
368
+ },
369
+ 'mst_kruskal': {
370
+ 'pos': (Stage.INPUT, Location.NODE, Type.SCALAR),
371
+ 'A': (Stage.INPUT, Location.EDGE, Type.SCALAR),
372
+ 'adj': (Stage.INPUT, Location.EDGE, Type.MASK),
373
+ 'in_mst': (Stage.OUTPUT, Location.EDGE, Type.MASK),
374
+ 'in_mst_h': (Stage.HINT, Location.EDGE, Type.MASK),
375
+ 'pi': (Stage.HINT, Location.NODE, Type.POINTER),
376
+ 'u': (Stage.HINT, Location.NODE, Type.MASK_ONE),
377
+ 'v': (Stage.HINT, Location.NODE, Type.MASK_ONE),
378
+ 'root_u': (Stage.HINT, Location.NODE, Type.MASK_ONE),
379
+ 'root_v': (Stage.HINT, Location.NODE, Type.MASK_ONE),
380
+ 'mask_u': (Stage.HINT, Location.NODE, Type.MASK),
381
+ 'mask_v': (Stage.HINT, Location.NODE, Type.MASK),
382
+ 'phase': (Stage.HINT, Location.GRAPH, Type.CATEGORICAL)
383
+ },
384
+ 'mst_prim': {
385
+ 'pos': (Stage.INPUT, Location.NODE, Type.SCALAR),
386
+ 's': (Stage.INPUT, Location.NODE, Type.MASK_ONE),
387
+ 'A': (Stage.INPUT, Location.EDGE, Type.SCALAR),
388
+ 'adj': (Stage.INPUT, Location.EDGE, Type.MASK),
389
+ 'pi': (Stage.OUTPUT, Location.NODE, Type.POINTER),
390
+ 'pi_h': (Stage.HINT, Location.NODE, Type.POINTER),
391
+ 'key': (Stage.HINT, Location.NODE, Type.SCALAR),
392
+ 'mark': (Stage.HINT, Location.NODE, Type.MASK),
393
+ 'in_queue': (Stage.HINT, Location.NODE, Type.MASK),
394
+ 'u': (Stage.HINT, Location.NODE, Type.MASK_ONE)
395
+ },
396
+ 'bellman_ford': {
397
+ 'pos': (Stage.INPUT, Location.NODE, Type.SCALAR),
398
+ 's': (Stage.INPUT, Location.NODE, Type.MASK_ONE),
399
+ 'A': (Stage.INPUT, Location.EDGE, Type.SCALAR),
400
+ 'adj': (Stage.INPUT, Location.EDGE, Type.MASK),
401
+ 'pi': (Stage.OUTPUT, Location.NODE, Type.POINTER),
402
+ 'pi_h': (Stage.HINT, Location.NODE, Type.POINTER),
403
+ 'd': (Stage.HINT, Location.NODE, Type.SCALAR),
404
+ 'msk': (Stage.HINT, Location.NODE, Type.MASK)
405
+ },
406
+ 'dag_shortest_paths': {
407
+ 'pos': (Stage.INPUT, Location.NODE, Type.SCALAR),
408
+ 's': (Stage.INPUT, Location.NODE, Type.MASK_ONE),
409
+ 'A': (Stage.INPUT, Location.EDGE, Type.SCALAR),
410
+ 'adj': (Stage.INPUT, Location.EDGE, Type.MASK),
411
+ 'pi': (Stage.OUTPUT, Location.NODE, Type.POINTER),
412
+ 'pi_h': (Stage.HINT, Location.NODE, Type.POINTER),
413
+ 'd': (Stage.HINT, Location.NODE, Type.SCALAR),
414
+ 'mark': (Stage.HINT, Location.NODE, Type.MASK),
415
+ 'topo_h': (Stage.HINT, Location.NODE, Type.POINTER),
416
+ 'topo_head_h': (Stage.HINT, Location.NODE, Type.MASK_ONE),
417
+ 'color': (Stage.HINT, Location.NODE, Type.CATEGORICAL),
418
+ 's_prev': (Stage.HINT, Location.NODE, Type.POINTER),
419
+ 'u': (Stage.HINT, Location.NODE, Type.MASK_ONE),
420
+ 'v': (Stage.HINT, Location.NODE, Type.MASK_ONE),
421
+ 's_last': (Stage.HINT, Location.NODE, Type.MASK_ONE),
422
+ 'phase': (Stage.HINT, Location.GRAPH, Type.MASK)
423
+ },
424
+ 'dijkstra': {
425
+ 'pos': (Stage.INPUT, Location.NODE, Type.SCALAR),
426
+ 's': (Stage.INPUT, Location.NODE, Type.MASK_ONE),
427
+ 'A': (Stage.INPUT, Location.EDGE, Type.SCALAR),
428
+ 'adj': (Stage.INPUT, Location.EDGE, Type.MASK),
429
+ 'pi': (Stage.OUTPUT, Location.NODE, Type.POINTER),
430
+ 'pi_h': (Stage.HINT, Location.NODE, Type.POINTER),
431
+ 'd': (Stage.HINT, Location.NODE, Type.SCALAR),
432
+ 'mark': (Stage.HINT, Location.NODE, Type.MASK),
433
+ 'in_queue': (Stage.HINT, Location.NODE, Type.MASK),
434
+ 'u': (Stage.HINT, Location.NODE, Type.MASK_ONE)
435
+ },
436
+ 'floyd_warshall': {
437
+ 'pos': (Stage.INPUT, Location.NODE, Type.SCALAR),
438
+ 'A': (Stage.INPUT, Location.EDGE, Type.SCALAR),
439
+ 'adj': (Stage.INPUT, Location.EDGE, Type.MASK),
440
+ 'Pi': (Stage.OUTPUT, Location.EDGE, Type.POINTER),
441
+ 'Pi_h': (Stage.HINT, Location.EDGE, Type.POINTER),
442
+ 'D': (Stage.HINT, Location.EDGE, Type.SCALAR),
443
+ 'msk': (Stage.HINT, Location.EDGE, Type.MASK),
444
+ 'k': (Stage.HINT, Location.NODE, Type.MASK_ONE)
445
+ },
446
+ 'bipartite_matching': {
447
+ 'pos': (Stage.INPUT, Location.NODE, Type.SCALAR),
448
+ 'A': (Stage.INPUT, Location.EDGE, Type.SCALAR),
449
+ 'adj': (Stage.INPUT, Location.EDGE, Type.MASK),
450
+ 's': (Stage.INPUT, Location.NODE, Type.MASK_ONE),
451
+ 't': (Stage.INPUT, Location.NODE, Type.MASK_ONE),
452
+ 'in_matching': (Stage.OUTPUT, Location.EDGE, Type.MASK),
453
+ 'in_matching_h': (Stage.HINT, Location.EDGE, Type.MASK),
454
+ 'A_h': (Stage.HINT, Location.EDGE, Type.SCALAR),
455
+ 'adj_h': (Stage.HINT, Location.EDGE, Type.MASK),
456
+ 'd': (Stage.HINT, Location.NODE, Type.SCALAR),
457
+ 'msk': (Stage.HINT, Location.NODE, Type.MASK),
458
+ 'pi': (Stage.HINT, Location.NODE, Type.POINTER),
459
+ 'u': (Stage.HINT, Location.NODE, Type.MASK_ONE),
460
+ 'phase': (Stage.HINT, Location.GRAPH, Type.MASK)
461
+ },
462
+ 'naive_string_matcher': {
463
+ 'string': (Stage.INPUT, Location.NODE, Type.MASK),
464
+ 'pos': (Stage.INPUT, Location.NODE, Type.SCALAR),
465
+ 'key': (Stage.INPUT, Location.NODE, Type.CATEGORICAL),
466
+ 'match': (Stage.OUTPUT, Location.NODE, Type.MASK_ONE),
467
+ 'pred_h': (Stage.HINT, Location.NODE, Type.POINTER),
468
+ 's': (Stage.HINT, Location.NODE, Type.MASK_ONE),
469
+ 'i': (Stage.HINT, Location.NODE, Type.MASK_ONE),
470
+ 'j': (Stage.HINT, Location.NODE, Type.MASK_ONE)
471
+ },
472
+ 'kmp_matcher': {
473
+ 'string': (Stage.INPUT, Location.NODE, Type.MASK),
474
+ 'pos': (Stage.INPUT, Location.NODE, Type.SCALAR),
475
+ 'key': (Stage.INPUT, Location.NODE, Type.CATEGORICAL),
476
+ 'match': (Stage.OUTPUT, Location.NODE, Type.MASK_ONE),
477
+ 'pred_h': (Stage.HINT, Location.NODE, Type.POINTER),
478
+ 'pi': (Stage.HINT, Location.NODE, Type.POINTER),
479
+ 'is_reset': (Stage.HINT, Location.NODE, Type.MASK),
480
+ 'k': (Stage.HINT, Location.NODE, Type.MASK_ONE),
481
+ 'k_reset': (Stage.HINT, Location.GRAPH, Type.MASK),
482
+ 'q': (Stage.HINT, Location.NODE, Type.MASK_ONE),
483
+ 'q_reset': (Stage.HINT, Location.GRAPH, Type.MASK),
484
+ 's': (Stage.HINT, Location.NODE, Type.MASK_ONE),
485
+ 'i': (Stage.HINT, Location.NODE, Type.MASK_ONE),
486
+ 'phase': (Stage.HINT, Location.GRAPH, Type.MASK)
487
+ },
488
+ 'segments_intersect': {
489
+ 'pos': (Stage.INPUT, Location.NODE, Type.SCALAR),
490
+ 'x': (Stage.INPUT, Location.NODE, Type.SCALAR),
491
+ 'y': (Stage.INPUT, Location.NODE, Type.SCALAR),
492
+ 'intersect': (Stage.OUTPUT, Location.GRAPH, Type.MASK),
493
+ 'i': (Stage.HINT, Location.NODE, Type.MASK_ONE),
494
+ 'j': (Stage.HINT, Location.NODE, Type.MASK_ONE),
495
+ 'k': (Stage.HINT, Location.NODE, Type.MASK_ONE),
496
+ 'dir': (Stage.HINT, Location.NODE, Type.SCALAR),
497
+ 'on_seg': (Stage.HINT, Location.NODE, Type.MASK)
498
+ },
499
+ 'graham_scan': {
500
+ 'pos': (Stage.INPUT, Location.NODE, Type.SCALAR),
501
+ 'x': (Stage.INPUT, Location.NODE, Type.SCALAR),
502
+ 'y': (Stage.INPUT, Location.NODE, Type.SCALAR),
503
+ 'in_hull': (Stage.OUTPUT, Location.NODE, Type.MASK),
504
+ 'best': (Stage.HINT, Location.NODE, Type.MASK_ONE),
505
+ 'atans': (Stage.HINT, Location.NODE, Type.SCALAR),
506
+ 'in_hull_h': (Stage.HINT, Location.NODE, Type.MASK),
507
+ 'stack_prev': (Stage.HINT, Location.NODE, Type.POINTER),
508
+ 'last_stack': (Stage.HINT, Location.NODE, Type.MASK_ONE),
509
+ 'i': (Stage.HINT, Location.NODE, Type.MASK_ONE),
510
+ 'phase': (Stage.HINT, Location.GRAPH, Type.CATEGORICAL)
511
+ },
512
+ 'jarvis_march': {
513
+ 'pos': (Stage.INPUT, Location.NODE, Type.SCALAR),
514
+ 'x': (Stage.INPUT, Location.NODE, Type.SCALAR),
515
+ 'y': (Stage.INPUT, Location.NODE, Type.SCALAR),
516
+ 'in_hull': (Stage.OUTPUT, Location.NODE, Type.MASK),
517
+ 'pred_h': (Stage.HINT, Location.NODE, Type.POINTER),
518
+ 'in_hull_h': (Stage.HINT, Location.NODE, Type.MASK),
519
+ 'best': (Stage.HINT, Location.NODE, Type.MASK_ONE),
520
+ 'last_point': (Stage.HINT, Location.NODE, Type.MASK_ONE),
521
+ 'endpoint': (Stage.HINT, Location.NODE, Type.MASK_ONE),
522
+ 'i': (Stage.HINT, Location.NODE, Type.MASK_ONE),
523
+ 'phase': (Stage.HINT, Location.GRAPH, Type.CATEGORICAL)
524
+ }
525
+ })
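Each entry in the SPECS mapping above ties a probe name to a (Stage, Location, Type) triple, which the encoders and decoders use to decide how that probe is embedded and predicted. A minimal sketch of reading these triples, assuming the mapping is exposed as SPECS in env/specs.py (as in the upstream CLRS package):

# Minimal sketch: inspect the floyd_warshall spec defined above.
# Assumes SPECS, Stage, Location, Type are importable from the specs.py in this diff.
from specs import SPECS, Stage, Location, Type

spec = SPECS['floyd_warshall']
stage, loc, typ = spec['A']   # (Stage.INPUT, Location.EDGE, Type.SCALAR)

# Group probe names by stage to see what the model consumes and must predict.
for name, (stage, loc, typ) in spec.items():
    print(f'{name}: {stage}/{loc}/{typ}')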
benchmarks/CLRS/env/train.py ADDED
@@ -0,0 +1,560 @@
1
+ # Copyright 2022 DeepMind Technologies Limited. All Rights Reserved.
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+ # ==============================================================================
15
+
16
+ """Run training of one or more algorithmic tasks from CLRS."""
17
+ import os
18
+ # disable logging until training starts
19
+ os.environ['TF_CPP_MIN_LOG_LEVEL'] = '3'
20
+
21
+ import functools
22
+ import os
23
+ import shutil
24
+ from typing import Any, Dict, List
25
+
26
+ from absl import app
27
+ from absl import flags
28
+ from absl import logging
29
+ # disable logging until training starts
30
+ logging.set_verbosity(logging.ERROR)
31
+
32
+ import clrs
33
+ import jax
34
+ import numpy as np
35
+ import requests
36
+ import tensorflow as tf
37
+ from baselines import BaselineModel, BaselineModelChunked
38
+ import pickle
39
+ import copy
40
+
41
+ flags.DEFINE_list('algorithms', ['floyd_warshall'], 'Which algorithms to run.')
42
+ flags.DEFINE_list('train_lengths', ['4', '7', '11', '13', '16'],
43
+ 'Which training sizes to use. A size of -1 means '
44
+ 'use the benchmark dataset.')
45
+ flags.DEFINE_integer('length_needle', -8,
46
+ 'Length of needle for training and validation '
47
+ '(not testing) in string matching algorithms. '
48
+ 'A negative value randomizes the length for each sample '
49
+ 'between 1 and the opposite of the value. '
50
+ 'A value of 0 means always use 1/4 of the length of '
51
+ 'the haystack (the default sampler behavior).')
52
+ flags.DEFINE_integer('seed', 42, 'Random seed to set')
53
+
54
+ flags.DEFINE_boolean('random_pos', True,
55
+ 'Randomize the pos input common to all algos.')
56
+ flags.DEFINE_boolean('enforce_permutations', True,
57
+ 'Whether to enforce permutation-type node pointers.')
58
+ flags.DEFINE_boolean('enforce_pred_as_input', True,
59
+ 'Whether to change pred_h hints into pred inputs.')
60
+ flags.DEFINE_integer('batch_size', 32, 'Batch size used for training.')
61
+ flags.DEFINE_boolean('chunked_training', False,
62
+ 'Whether to use chunking for training.')
63
+ flags.DEFINE_integer('chunk_length', 16,
64
+ 'Time chunk length used for training (if '
65
+ '`chunked_training` is True).')
66
+ flags.DEFINE_integer('train_steps', 500, 'Number of training iterations.')
67
+ flags.DEFINE_integer('eval_every', 50, 'Evaluation frequency (in steps).')
68
+ flags.DEFINE_integer('test_every', 500, 'Test frequency (in steps).')
69
+ flags.DEFINE_integer('log_every', 50, 'Logging frequency (in steps).')
70
+
71
+ flags.DEFINE_integer('hidden_size', 128,
72
+ 'Number of hidden units of the model.')
73
+ flags.DEFINE_integer('nb_heads', 1, 'Number of heads for GAT processors')
74
+ flags.DEFINE_integer('nb_msg_passing_steps', 1,
75
+ 'Number of message passing steps to run per hint.')
76
+ flags.DEFINE_float('learning_rate', 0.001, 'Learning rate to use.')
77
+ flags.DEFINE_float('grad_clip_max_norm', 1.0,
78
+ 'Gradient clipping by norm. 0.0 disables grad clipping.')
79
+ flags.DEFINE_float('dropout_prob', 0.0, 'Dropout rate to use.')
80
+ flags.DEFINE_float('hint_teacher_forcing', 0.0,
81
+ 'Probability that ground-truth teacher hints are encoded '
82
+ 'during training instead of predicted hints. Only '
83
+ 'pertinent in encoded_decoded modes.')
84
+ flags.DEFINE_enum('hint_mode', 'encoded_decoded',
85
+ ['encoded_decoded', 'decoded_only', 'none'],
86
+ 'How should hints be used? Note, each mode defines a '
87
+ 'separate task, with various difficulties. `encoded_decoded` '
88
+ 'requires the model to explicitly materialise hint sequences '
89
+ 'and therefore is hardest, but also most aligned to the '
90
+ 'underlying algorithmic rule. Hence, `encoded_decoded` '
91
+ 'should be treated as the default mode for our benchmark. '
92
+ 'In `decoded_only`, hints are only used for defining '
93
+ 'reconstruction losses. Often, this will perform well, but '
94
+ 'note that we currently do not make any efforts to '
95
+ 'counterbalance the various hint losses. Hence, for certain '
96
+ 'tasks, the best performance will now be achievable with no '
97
+ 'hint usage at all (`none`).')
98
+ flags.DEFINE_enum('hint_repred_mode', 'soft', ['soft', 'hard', 'hard_on_eval'],
99
+ 'How to process predicted hints when fed back as inputs. '
100
+ 'In soft mode, we use softmaxes for categoricals, pointers '
101
+ 'and mask_one, and sigmoids for masks. '
102
+ 'In hard mode, we use argmax instead of softmax, and hard '
103
+ 'thresholding of masks. '
104
+ 'In hard_on_eval mode, soft mode is '
105
+ 'used for training and hard mode is used for evaluation.')
106
+ flags.DEFINE_boolean('use_ln', True,
107
+ 'Whether to use layer normalisation in the processor.')
108
+ flags.DEFINE_boolean('use_lstm', False,
109
+ 'Whether to insert an LSTM after message passing.')
110
+ flags.DEFINE_integer('nb_triplet_fts', 8,
111
+ 'How many triplet features to compute.')
112
+
113
+ flags.DEFINE_enum('encoder_init', 'xavier_on_scalars',
114
+ ['default', 'xavier_on_scalars'],
115
+ 'Initialiser to use for the encoders.')
116
+ flags.DEFINE_enum('processor_type', 'triplet_gmpnn',
117
+ ['deepsets', 'mpnn', 'pgn', 'pgn_mask',
118
+ 'triplet_mpnn', 'triplet_pgn', 'triplet_pgn_mask',
119
+ 'gat', 'gatv2', 'gat_full', 'gatv2_full',
120
+ 'gpgn', 'gpgn_mask', 'gmpnn',
121
+ 'triplet_gpgn', 'triplet_gpgn_mask', 'triplet_gmpnn'],
122
+ 'Processor type to use as the network P.')
123
+
124
+ flags.DEFINE_string('checkpoint_path', './checkpoints',
125
+ 'Path in which checkpoints are saved.')
126
+ flags.DEFINE_string('dataset_path', '/tmp/CLRS30',
127
+ 'Path in which dataset is stored.')
128
+ flags.DEFINE_boolean('freeze_processor', False,
129
+ 'Whether to freeze the processor of the model.')
130
+
131
+ FLAGS = flags.FLAGS
132
+
133
+
134
+ PRED_AS_INPUT_ALGOS = [
135
+ 'binary_search',
136
+ 'minimum',
137
+ 'find_maximum_subarray',
138
+ 'find_maximum_subarray_kadane',
139
+ 'matrix_chain_order',
140
+ 'lcs_length',
141
+ 'optimal_bst',
142
+ 'activity_selector',
143
+ 'task_scheduling',
144
+ 'naive_string_matcher',
145
+ 'kmp_matcher',
146
+ 'jarvis_march']
147
+
148
+
149
+ def unpack(v):
150
+ try:
151
+ return v.item() # DeviceArray
152
+ except (AttributeError, ValueError):
153
+ return v
154
+
155
+
156
+ def _iterate_sampler(sampler, batch_size):
157
+ while True:
158
+ yield sampler.next(batch_size)
159
+
160
+
161
+ def _maybe_download_dataset(dataset_path):
162
+ """Download CLRS30 dataset if needed."""
163
+ dataset_folder = os.path.join(dataset_path, clrs.get_clrs_folder())
164
+ if os.path.isdir(dataset_folder):
165
+ logging.info('Dataset found at %s. Skipping download.', dataset_folder)
166
+ return dataset_folder
167
+ logging.info('Dataset not found in %s. Downloading...', dataset_folder)
168
+
169
+ clrs_url = clrs.get_dataset_gcp_url()
170
+ request = requests.get(clrs_url, allow_redirects=True)
171
+ clrs_file = os.path.join(dataset_path, os.path.basename(clrs_url))
172
+ os.makedirs(dataset_folder)
173
+ open(clrs_file, 'wb').write(request.content)
174
+ shutil.unpack_archive(clrs_file, extract_dir=dataset_folder)
175
+ os.remove(clrs_file)
176
+ return dataset_folder
177
+
178
+
179
+ def make_sampler(length: int,
180
+ rng: Any,
181
+ algorithm: str,
182
+ split: str,
183
+ batch_size: int,
184
+ multiplier: int,
185
+ randomize_pos: bool,
186
+ enforce_pred_as_input: bool,
187
+ enforce_permutations: bool,
188
+ chunked: bool,
189
+ chunk_length: int,
190
+ sampler_kwargs: Dict[str, Any]):
191
+ """Create a sampler with given options.
192
+
193
+ Args:
194
+ length: Size of samples (i.e., number of nodes in the graph).
195
+ A length of -1 will mean that the benchmark
196
+ dataset (for the given split) is used. Positive sizes will instantiate
197
+ samplers of the corresponding size.
198
+ rng: Numpy random state.
199
+ algorithm: The name of the algorithm to sample from.
200
+ split: 'train', 'val' or 'test'.
201
+ batch_size: Samples per batch.
202
+ multiplier: Integer multiplier for the number of samples in the dataset,
203
+ only used for positive sizes. Negative multiplier means infinite samples.
204
+ randomize_pos: Whether to randomize the `pos` input.
205
+ enforce_pred_as_input: Whether to convert fixed pred_h hints to inputs.
206
+ enforce_permutations: Whether to enforce permutation pointers.
207
+ chunked: Whether to chunk the dataset.
208
+ chunk_length: Unroll length of chunks, if `chunked` is True.
209
+ sampler_kwargs: Extra args passed to the sampler.
210
+ Returns:
211
+ A sampler (iterator), the number of samples in the iterator (negative
212
+ if infinite samples), and the spec.
213
+ """
214
+ if length < 0: # load from file
215
+ dataset_folder = _maybe_download_dataset(FLAGS.dataset_path)
216
+ sampler, num_samples, spec = clrs.create_dataset(folder=dataset_folder,
217
+ algorithm=algorithm,
218
+ batch_size=batch_size,
219
+ split=split)
220
+ sampler = sampler.as_numpy_iterator()
221
+ else:
222
+ num_samples = clrs.CLRS30[split]['num_samples'] * multiplier
223
+ sampler, spec = clrs.build_sampler(
224
+ algorithm,
225
+ seed=rng.randint(2**32),
226
+ num_samples=num_samples,
227
+ length=length,
228
+ **sampler_kwargs,
229
+ )
230
+ sampler = _iterate_sampler(sampler, batch_size)
231
+
232
+ if randomize_pos:
233
+ sampler = clrs.process_random_pos(sampler, rng)
234
+ if enforce_pred_as_input and algorithm in PRED_AS_INPUT_ALGOS:
235
+ spec, sampler = clrs.process_pred_as_input(spec, sampler)
236
+ spec, sampler = clrs.process_permutations(spec, sampler, enforce_permutations)
237
+ if chunked:
238
+ sampler = clrs.chunkify(sampler, chunk_length)
239
+ return sampler, num_samples, spec
240
+
241
+
242
+ def make_multi_sampler(sizes, rng, **kwargs):
243
+ """Create a sampler with cycling sample sizes."""
244
+ ss = []
245
+ tot_samples = 0
246
+ for length in sizes:
247
+ sampler, num_samples, spec = make_sampler(length, rng, **kwargs)
248
+ ss.append(sampler)
249
+ tot_samples += num_samples
250
+
251
+ def cycle_samplers():
252
+ while True:
253
+ for s in ss:
254
+ yield next(s)
255
+ return cycle_samplers(), tot_samples, spec
256
+
257
+
258
+ def _concat(dps, axis):
259
+ return jax.tree_util.tree_map(lambda *x: np.concatenate(x, axis), *dps)
260
+
261
+
262
+ def collect_and_eval(sampler, predict_fn, sample_count, rng_key, extras):
263
+ """Collect batches of output and hint preds and evaluate them."""
264
+ processed_samples = 0
265
+ preds = []
266
+ outputs = []
267
+ while processed_samples < sample_count:
268
+ feedback = next(sampler)
269
+ batch_size = feedback.outputs[0].data.shape[0]
270
+ outputs.append(feedback.outputs)
271
+ new_rng_key, rng_key = jax.random.split(rng_key)
272
+ cur_preds, _ = predict_fn(new_rng_key, feedback.features)
273
+ preds.append(cur_preds)
274
+ processed_samples += batch_size
275
+ outputs = _concat(outputs, axis=0)
276
+ preds = _concat(preds, axis=0)
277
+ out = clrs.evaluate(outputs, preds)
278
+ if extras:
279
+ out.update(extras)
280
+ return {k: unpack(v) for k, v in out.items()}
281
+
282
+
283
+ def create_samplers(rng, train_lengths: List[int]):
284
+ """Create all the samplers."""
285
+ train_samplers = []
286
+ val_samplers = []
287
+ val_sample_counts = []
288
+ test_samplers = []
289
+ test_sample_counts = []
290
+ spec_list = []
291
+
292
+ for algo_idx, algorithm in enumerate(FLAGS.algorithms):
293
+ # Make full dataset pipeline run on CPU (including prefetching).
294
+ with tf.device('/cpu:0'):
295
+
296
+ if algorithm in ['naive_string_matcher', 'kmp_matcher']:
297
+ # Fixed haystack + needle; variability will be in needle
298
+ # Still, for chunked training, we maintain as many samplers
299
+ # as train lengths, since, for each length there is a separate state,
300
+ # and we must keep the 1:1 relationship between states and samplers.
301
+ max_length = max(train_lengths)
302
+ if max_length > 0: # if < 0, we are using the benchmark data
303
+ max_length = (max_length * 5) // 4
304
+ train_lengths = [max_length]
305
+ if FLAGS.chunked_training:
306
+ train_lengths = train_lengths * len(train_lengths)
307
+
308
+ logging.info('Creating samplers for algo %s', algorithm)
309
+
310
+ p = tuple([0.1 + 0.1 * i for i in range(9)])
311
+ if p and algorithm in ['articulation_points', 'bridges',
312
+ 'mst_kruskal', 'bipartite_matching']:
313
+ # Choose a lower connection probability for the above algorithms,
314
+ # otherwise trajectories are very long
315
+ p = tuple(np.array(p) / 2)
316
+ length_needle = FLAGS.length_needle
317
+ sampler_kwargs = dict(p=p, length_needle=length_needle)
318
+ if length_needle == 0:
319
+ sampler_kwargs.pop('length_needle')
320
+
321
+ common_sampler_args = dict(
322
+ algorithm=FLAGS.algorithms[algo_idx],
323
+ rng=rng,
324
+ enforce_pred_as_input=FLAGS.enforce_pred_as_input,
325
+ enforce_permutations=FLAGS.enforce_permutations,
326
+ chunk_length=FLAGS.chunk_length,
327
+ )
328
+
329
+ train_args = dict(sizes=train_lengths,
330
+ split='train',
331
+ batch_size=FLAGS.batch_size,
332
+ multiplier=-1,
333
+ randomize_pos=FLAGS.random_pos,
334
+ chunked=FLAGS.chunked_training,
335
+ sampler_kwargs=sampler_kwargs,
336
+ **common_sampler_args)
337
+ train_sampler, _, spec = make_multi_sampler(**train_args)
338
+
339
+ mult = clrs.CLRS_30_ALGS_SETTINGS[algorithm]['num_samples_multiplier']
340
+ val_args = dict(sizes=[np.amax(train_lengths)],
341
+ split='val',
342
+ batch_size=32,
343
+ multiplier=2 * mult,
344
+ randomize_pos=FLAGS.random_pos,
345
+ chunked=False,
346
+ sampler_kwargs=sampler_kwargs,
347
+ **common_sampler_args)
348
+ val_sampler, val_samples, spec = make_multi_sampler(**val_args)
349
+
350
+ test_args = dict(sizes=[-1],
351
+ split='test',
352
+ batch_size=32,
353
+ multiplier=2 * mult,
354
+ randomize_pos=False,
355
+ chunked=False,
356
+ sampler_kwargs={},
357
+ **common_sampler_args)
358
+ test_sampler, test_samples, spec = make_multi_sampler(**test_args)
359
+
360
+ spec_list.append(spec)
361
+ train_samplers.append(train_sampler)
362
+ val_samplers.append(val_sampler)
363
+ val_sample_counts.append(val_samples)
364
+ test_samplers.append(test_sampler)
365
+ test_sample_counts.append(test_samples)
366
+
367
+ return (train_samplers,
368
+ val_samplers, val_sample_counts,
369
+ test_samplers, test_sample_counts,
370
+ spec_list)
371
+
372
+
373
+ def main(unused_argv):
374
+ if FLAGS.hint_mode == 'encoded_decoded':
375
+ encode_hints = True
376
+ decode_hints = True
377
+ elif FLAGS.hint_mode == 'decoded_only':
378
+ encode_hints = False
379
+ decode_hints = True
380
+ elif FLAGS.hint_mode == 'none':
381
+ encode_hints = False
382
+ decode_hints = False
383
+ else:
384
+ raise ValueError('Hint mode not in {encoded_decoded, decoded_only, none}.')
385
+
386
+ train_lengths = [int(x) for x in FLAGS.train_lengths]
387
+
388
+ rng = np.random.RandomState(FLAGS.seed)
389
+ rng_key = jax.random.PRNGKey(rng.randint(2**32))
390
+
391
+ # Create samplers
392
+ (train_samplers,
393
+ val_samplers, val_sample_counts,
394
+ test_samplers, test_sample_counts,
395
+ spec_list) = create_samplers(rng, train_lengths)
396
+
397
+ processor_factory = clrs.get_processor_factory(
398
+ FLAGS.processor_type,
399
+ use_ln=FLAGS.use_ln,
400
+ nb_triplet_fts=FLAGS.nb_triplet_fts,
401
+ nb_heads=FLAGS.nb_heads
402
+ )
403
+ model_params = dict(
404
+ processor_factory=processor_factory,
405
+ hidden_dim=FLAGS.hidden_size,
406
+ encode_hints=encode_hints,
407
+ decode_hints=decode_hints,
408
+ encoder_init=FLAGS.encoder_init,
409
+ use_lstm=FLAGS.use_lstm,
410
+ learning_rate=FLAGS.learning_rate,
411
+ grad_clip_max_norm=FLAGS.grad_clip_max_norm,
412
+ checkpoint_path=FLAGS.checkpoint_path,
413
+ freeze_processor=FLAGS.freeze_processor,
414
+ dropout_prob=FLAGS.dropout_prob,
415
+ hint_teacher_forcing=FLAGS.hint_teacher_forcing,
416
+ hint_repred_mode=FLAGS.hint_repred_mode,
417
+ nb_msg_passing_steps=FLAGS.nb_msg_passing_steps,
418
+ )
419
+
420
+ # save spec_list and model_params; do not change or delete!!
421
+ if not os.path.exists(FLAGS.checkpoint_path):
422
+ os.makedirs(FLAGS.checkpoint_path)
423
+
424
+ with open(os.path.join(FLAGS.checkpoint_path, 'spec_list.pkl'), 'wb') as f:
425
+ pickle.dump(spec_list, f)
426
+ model_params_save = copy.deepcopy(model_params)
427
+ model_params_save["processor_factory"] = (FLAGS.processor_type, FLAGS.use_ln, FLAGS.nb_triplet_fts, FLAGS.nb_heads)
428
+ with open(os.path.join(FLAGS.checkpoint_path, 'model_params.pkl'), 'wb') as f:
429
+ pickle.dump(model_params_save, f)
430
+
431
+ eval_model = BaselineModel(
432
+ spec=spec_list,
433
+ dummy_trajectory=[next(t) for t in val_samplers],
434
+ **model_params
435
+ )
436
+ if FLAGS.chunked_training:
437
+ train_model = BaselineModelChunked(
438
+ spec=spec_list,
439
+ dummy_trajectory=[next(t) for t in train_samplers],
440
+ **model_params
441
+ )
442
+ else:
443
+ train_model = eval_model
444
+
445
+ # Training loop.
446
+ best_score = -1.0
447
+ current_train_items = [0] * len(FLAGS.algorithms)
448
+ step = 0
449
+ next_eval = 0
450
+ # Make sure scores improve on the first step, but do not overtake the best score
451
+ # until all algos have had at least one evaluation.
452
+ val_scores = [-99999.9] * len(FLAGS.algorithms)
453
+ length_idx = 0
454
+
455
+ while step < FLAGS.train_steps:
456
+ feedback_list = [next(t) for t in train_samplers]
457
+
458
+ # Initialize model.
459
+ if step == 0:
460
+ all_features = [f.features for f in feedback_list]
461
+ if FLAGS.chunked_training:
462
+ # We need to initialize the model with samples of all lengths for
463
+ # all algorithms. Also, we need to make sure that the order of these
464
+ # sample sizes is the same as the order of the actual training sizes.
465
+ all_length_features = [all_features] + [
466
+ [next(t).features for t in train_samplers]
467
+ for _ in range(len(train_lengths))]
468
+ train_model.init(all_length_features[:-1], FLAGS.seed + 1)
469
+ else:
470
+ train_model.init(all_features, FLAGS.seed + 1)
471
+
472
+ # Training step.
473
+ # enable logging now that we have initialized the model
474
+ logging.set_verbosity(logging.INFO)
475
+ for algo_idx in range(len(train_samplers)):
476
+ feedback = feedback_list[algo_idx]
477
+ rng_key, new_rng_key = jax.random.split(rng_key)
478
+ if FLAGS.chunked_training:
479
+ # In chunked training, we must indicate which training length we are
480
+ # using, so the model uses the correct state.
481
+ length_and_algo_idx = (length_idx, algo_idx)
482
+ else:
483
+ # In non-chunked training, all training lengths can be treated equally,
484
+ # since there is no state to maintain between batches.
485
+ length_and_algo_idx = algo_idx
486
+ cur_loss = train_model.feedback(rng_key, feedback, length_and_algo_idx)
487
+ rng_key = new_rng_key
488
+
489
+ if FLAGS.chunked_training:
490
+ examples_in_chunk = np.sum(feedback.features.is_last).item()
491
+ else:
492
+ examples_in_chunk = len(feedback.features.lengths)
493
+ current_train_items[algo_idx] += examples_in_chunk
494
+ if step % FLAGS.log_every == 0:
495
+ logging.info('Algo %s step %i current loss %f, current_train_items %i.',
496
+ FLAGS.algorithms[algo_idx], step,
497
+ cur_loss, current_train_items[algo_idx])
498
+
499
+ # Periodically evaluate model
500
+ if step >= next_eval:
501
+ eval_model.params = train_model.params
502
+ for algo_idx in range(len(train_samplers)):
503
+ common_extras = {'examples_seen': current_train_items[algo_idx],
504
+ 'step': step,
505
+ 'algorithm': FLAGS.algorithms[algo_idx]}
506
+
507
+ # Validation info.
508
+ new_rng_key, rng_key = jax.random.split(rng_key)
509
+ val_stats = collect_and_eval(
510
+ val_samplers[algo_idx],
511
+ functools.partial(eval_model.predict, algorithm_index=algo_idx),
512
+ val_sample_counts[algo_idx],
513
+ new_rng_key,
514
+ extras=common_extras)
515
+ logging.info('(val) algo %s step %d: %s',
516
+ FLAGS.algorithms[algo_idx], step, val_stats)
517
+ val_scores[algo_idx] = val_stats['score']
518
+
519
+ next_eval += FLAGS.eval_every
520
+
521
+ # If best total score, update best checkpoint.
522
+ # Also save a best checkpoint on the first step.
523
+ msg = (f'best avg val score was '
524
+ f'{best_score/len(FLAGS.algorithms):.3f}, '
525
+ f'current avg val score is {np.mean(val_scores):.3f}, '
526
+ f'val scores are: ')
527
+ msg += ', '.join(
528
+ ['%s: %.3f' % (x, y) for (x, y) in zip(FLAGS.algorithms, val_scores)])
529
+ if (sum(val_scores) > best_score) or step == 0:
530
+ best_score = sum(val_scores)
531
+ logging.info('Checkpointing best model, %s', msg)
532
+ train_model.save_model('best.pkl')
533
+ else:
534
+ logging.info('Not saving new best model, %s', msg)
535
+
536
+ step += 1
537
+ length_idx = (length_idx + 1) % len(train_lengths)
538
+
539
+ logging.info('Restoring best model from checkpoint...')
540
+ eval_model.restore_model('best.pkl', only_load_processor=False)
541
+
542
+ for algo_idx in range(len(train_samplers)):
543
+ common_extras = {'examples_seen': current_train_items[algo_idx],
544
+ 'step': step,
545
+ 'algorithm': FLAGS.algorithms[algo_idx]}
546
+
547
+ new_rng_key, rng_key = jax.random.split(rng_key)
548
+ test_stats = collect_and_eval(
549
+ test_samplers[algo_idx],
550
+ functools.partial(eval_model.predict, algorithm_index=algo_idx),
551
+ test_sample_counts[algo_idx],
552
+ new_rng_key,
553
+ extras=common_extras)
554
+ logging.info('(test) algo %s : %s', FLAGS.algorithms[algo_idx], test_stats)
555
+
556
+ logging.info('Done!')
557
+
558
+
559
+ if __name__ == '__main__':
560
+ app.run(main)
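The multi-sampler in train.py interleaves one sampler per training length round-robin. A self-contained sketch of that cycling pattern (the toy iterators below stand in for clrs.build_sampler outputs and are illustrative only):

import itertools

def cycle_samplers(samplers):
    # Yield one batch from each sampler in turn, forever (as in make_multi_sampler).
    for s in itertools.cycle(samplers):
        yield next(s)

# Toy infinite iterators standing in for per-length CLRS samplers.
samplers = [itertools.repeat(f'batch(len={n})') for n in (4, 7, 11)]
mixed = cycle_samplers(samplers)
print([next(mixed) for _ in range(6)])
# -> alternates batches of length 4, 7, 11, 4, 7, 11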
benchmarks/CLRS/scripts/eval.py ADDED
@@ -0,0 +1,454 @@
1
+ # Copyright 2022 DeepMind Technologies Limited. All Rights Reserved.
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+ # ==============================================================================
15
+
16
+ """Evaluate a trained model on one or more algorithmic tasks from CLRS."""
17
+
18
+ import os
19
+ # disable logging until training starts
20
+ os.environ['TF_CPP_MIN_LOG_LEVEL'] = '3'
21
+
22
+ import functools
23
+ import os
24
+ import shutil
25
+ from typing import Any, Dict, List
26
+
27
+ from absl import app
28
+ from absl import flags
29
+ from absl import logging
30
+ # disable logging until training starts
31
+ logging.set_verbosity(logging.ERROR)
32
+
33
+ import clrs
34
+ import jax
35
+ import numpy as np
36
+ import requests
37
+ import tensorflow as tf
38
+ import sys
39
+ sys.path.append(os.path.abspath(os.path.join(os.path.dirname(__file__), "../env")))
40
+ from baselines import BaselineModel, BaselineModelChunked
41
+ import pickle
42
+
43
+ flags.DEFINE_list('algorithms', ['floyd_warshall'], 'Which algorithms to run.')
44
+ flags.DEFINE_list('train_lengths', ['4', '7', '11', '13', '16'],
45
+ 'Which training sizes to use. A size of -1 means '
46
+ 'use the benchmark dataset.')
47
+ flags.DEFINE_integer('length_needle', -8,
48
+ 'Length of needle for training and validation '
49
+ '(not testing) in string matching algorithms. '
50
+ 'A negative value randomizes the length for each sample '
51
+ 'between 1 and the opposite of the value. '
52
+ 'A value of 0 means always use 1/4 of the length of '
53
+ 'the haystack (the default sampler behavior).')
54
+ flags.DEFINE_integer('seed', 42, 'Random seed to set')
55
+
56
+ flags.DEFINE_boolean('random_pos', True,
57
+ 'Randomize the pos input common to all algos.')
58
+ flags.DEFINE_boolean('enforce_permutations', True,
59
+ 'Whether to enforce permutation-type node pointers.')
60
+ flags.DEFINE_boolean('enforce_pred_as_input', True,
61
+ 'Whether to change pred_h hints into pred inputs.')
62
+ flags.DEFINE_integer('batch_size', 32, 'Batch size used for training.')
63
+ flags.DEFINE_boolean('chunked_training', False,
64
+ 'Whether to use chunking for training.')
65
+ flags.DEFINE_integer('chunk_length', 16,
66
+ 'Time chunk length used for training (if '
67
+ '`chunked_training` is True).')
68
+ flags.DEFINE_integer('train_steps', 1000, 'Number of training iterations.')
69
+ flags.DEFINE_integer('eval_every', 50, 'Evaluation frequency (in steps).')
70
+ flags.DEFINE_integer('test_every', 500, 'Test frequency (in steps).')
71
+ flags.DEFINE_integer('log_every', 50, 'Logging frequency (in steps).')
72
+
73
+ flags.DEFINE_integer('hidden_size', 128,
74
+ 'Number of hidden units of the model.')
75
+ flags.DEFINE_integer('nb_heads', 1, 'Number of heads for GAT processors')
76
+ flags.DEFINE_integer('nb_msg_passing_steps', 1,
77
+ 'Number of message passing steps to run per hint.')
78
+ flags.DEFINE_float('learning_rate', 0.001, 'Learning rate to use.')
79
+ flags.DEFINE_float('grad_clip_max_norm', 1.0,
80
+ 'Gradient clipping by norm. 0.0 disables grad clipping.')
81
+ flags.DEFINE_float('dropout_prob', 0.0, 'Dropout rate to use.')
82
+ flags.DEFINE_float('hint_teacher_forcing', 0.0,
83
+ 'Probability that ground-truth teacher hints are encoded '
84
+ 'during training instead of predicted hints. Only '
85
+ 'pertinent in encoded_decoded modes.')
86
+ flags.DEFINE_enum('hint_mode', 'encoded_decoded',
87
+ ['encoded_decoded', 'decoded_only', 'none'],
88
+ 'How should hints be used? Note, each mode defines a '
89
+ 'separate task, with various difficulties. `encoded_decoded` '
90
+ 'requires the model to explicitly materialise hint sequences '
91
+ 'and therefore is hardest, but also most aligned to the '
92
+ 'underlying algorithmic rule. Hence, `encoded_decoded` '
93
+ 'should be treated as the default mode for our benchmark. '
94
+ 'In `decoded_only`, hints are only used for defining '
95
+ 'reconstruction losses. Often, this will perform well, but '
96
+ 'note that we currently do not make any efforts to '
97
+ 'counterbalance the various hint losses. Hence, for certain '
98
+ 'tasks, the best performance will now be achievable with no '
99
+ 'hint usage at all (`none`).')
100
+ flags.DEFINE_enum('hint_repred_mode', 'soft', ['soft', 'hard', 'hard_on_eval'],
101
+ 'How to process predicted hints when fed back as inputs. '
102
+ 'In soft mode, we use softmaxes for categoricals, pointers '
103
+ 'and mask_one, and sigmoids for masks. '
104
+ 'In hard mode, we use argmax instead of softmax, and hard '
105
+ 'thresholding of masks. '
106
+ 'In hard_on_eval mode, soft mode is '
107
+ 'used for training and hard mode is used for evaluation.')
108
+ flags.DEFINE_boolean('use_ln', True,
109
+ 'Whether to use layer normalisation in the processor.')
110
+ flags.DEFINE_boolean('use_lstm', False,
111
+ 'Whether to insert an LSTM after message passing.')
112
+ flags.DEFINE_integer('nb_triplet_fts', 8,
113
+ 'How many triplet features to compute.')
114
+
115
+ flags.DEFINE_enum('encoder_init', 'xavier_on_scalars',
116
+ ['default', 'xavier_on_scalars'],
117
+ 'Initialiser to use for the encoders.')
118
+ flags.DEFINE_enum('processor_type', 'triplet_gmpnn',
119
+ ['deepsets', 'mpnn', 'pgn', 'pgn_mask',
120
+ 'triplet_mpnn', 'triplet_pgn', 'triplet_pgn_mask',
121
+ 'gat', 'gatv2', 'gat_full', 'gatv2_full',
122
+ 'gpgn', 'gpgn_mask', 'gmpnn',
123
+ 'triplet_gpgn', 'triplet_gpgn_mask', 'triplet_gmpnn'],
124
+ 'Processor type to use as the network P.')
125
+
126
+ flags.DEFINE_string('checkpoint_path', '../env/checkpoints',
127
+ 'Path in which checkpoints are saved.')
128
+ flags.DEFINE_string('dataset_path', '/tmp/CLRS30',
129
+ 'Path in which dataset is stored.')
130
+ flags.DEFINE_boolean('freeze_processor', False,
131
+ 'Whether to freeze the processor of the model.')
132
+
133
+ FLAGS = flags.FLAGS
134
+
135
+
136
+ PRED_AS_INPUT_ALGOS = [
137
+ 'binary_search',
138
+ 'minimum',
139
+ 'find_maximum_subarray',
140
+ 'find_maximum_subarray_kadane',
141
+ 'matrix_chain_order',
142
+ 'lcs_length',
143
+ 'optimal_bst',
144
+ 'activity_selector',
145
+ 'task_scheduling',
146
+ 'naive_string_matcher',
147
+ 'kmp_matcher',
148
+ 'jarvis_march']
149
+
150
+
151
+ def unpack(v):
152
+ try:
153
+ return v.item() # DeviceArray
154
+ except (AttributeError, ValueError):
155
+ return v
156
+
157
+
158
+ def _iterate_sampler(sampler, batch_size):
159
+ while True:
160
+ yield sampler.next(batch_size)
161
+
162
+
163
+ def _maybe_download_dataset(dataset_path):
164
+ """Download CLRS30 dataset if needed."""
165
+ dataset_folder = os.path.join(dataset_path, clrs.get_clrs_folder())
166
+ if os.path.isdir(dataset_folder):
167
+ logging.info('Dataset found at %s. Skipping download.', dataset_folder)
168
+ return dataset_folder
169
+ logging.info('Dataset not found in %s. Downloading...', dataset_folder)
170
+
171
+ clrs_url = clrs.get_dataset_gcp_url()
172
+ request = requests.get(clrs_url, allow_redirects=True)
173
+ clrs_file = os.path.join(dataset_path, os.path.basename(clrs_url))
174
+ os.makedirs(dataset_folder)
175
+ open(clrs_file, 'wb').write(request.content)
176
+ shutil.unpack_archive(clrs_file, extract_dir=dataset_folder)
177
+ os.remove(clrs_file)
178
+ return dataset_folder
179
+
180
+
181
+ def make_sampler(length: int,
182
+ rng: Any,
183
+ algorithm: str,
184
+ split: str,
185
+ batch_size: int,
186
+ multiplier: int,
187
+ randomize_pos: bool,
188
+ enforce_pred_as_input: bool,
189
+ enforce_permutations: bool,
190
+ chunked: bool,
191
+ chunk_length: int,
192
+ sampler_kwargs: Dict[str, Any]):
193
+ """Create a sampler with given options.
194
+
195
+ Args:
196
+ length: Size of samples (i.e., number of nodes in the graph).
197
+ A length of -1 will mean that the benchmark
198
+ dataset (for the given split) is used. Positive sizes will instantiate
199
+ samplers of the corresponding size.
200
+ rng: Numpy random state.
201
+ algorithm: The name of the algorithm to sample from.
202
+ split: 'train', 'val' or 'test'.
203
+ batch_size: Samples per batch.
204
+ multiplier: Integer multiplier for the number of samples in the dataset,
205
+ only used for positive sizes. Negative multiplier means infinite samples.
206
+ randomize_pos: Whether to randomize the `pos` input.
207
+ enforce_pred_as_input: Whether to convert fixed pred_h hints to inputs.
208
+ enforce_permutations: Whether to enforce permutation pointers.
209
+ chunked: Whether to chunk the dataset.
210
+ chunk_length: Unroll length of chunks, if `chunked` is True.
211
+ sampler_kwargs: Extra args passed to the sampler.
212
+ Returns:
213
+ A sampler (iterator), the number of samples in the iterator (negative
214
+ if infinite samples), and the spec.
215
+ """
216
+ if length < 0: # load from file
217
+ dataset_folder = _maybe_download_dataset(FLAGS.dataset_path)
218
+ sampler, num_samples, spec = clrs.create_dataset(folder=dataset_folder,
219
+ algorithm=algorithm,
220
+ batch_size=batch_size,
221
+ split=split)
222
+ sampler = sampler.as_numpy_iterator()
223
+ else:
224
+ num_samples = clrs.CLRS30[split]['num_samples'] * multiplier
225
+ sampler, spec = clrs.build_sampler(
226
+ algorithm,
227
+ seed=rng.randint(2**32),
228
+ num_samples=num_samples,
229
+ length=length,
230
+ **sampler_kwargs,
231
+ )
232
+ sampler = _iterate_sampler(sampler, batch_size)
233
+
234
+ if randomize_pos:
235
+ sampler = clrs.process_random_pos(sampler, rng)
236
+ if enforce_pred_as_input and algorithm in PRED_AS_INPUT_ALGOS:
237
+ spec, sampler = clrs.process_pred_as_input(spec, sampler)
238
+ spec, sampler = clrs.process_permutations(spec, sampler, enforce_permutations)
239
+ if chunked:
240
+ sampler = clrs.chunkify(sampler, chunk_length)
241
+ return sampler, num_samples, spec
242
+
243
+
244
+ def make_multi_sampler(sizes, rng, **kwargs):
245
+ """Create a sampler with cycling sample sizes."""
246
+ ss = []
247
+ tot_samples = 0
248
+ for length in sizes:
249
+ sampler, num_samples, spec = make_sampler(length, rng, **kwargs)
250
+ ss.append(sampler)
251
+ tot_samples += num_samples
252
+
253
+ def cycle_samplers():
254
+ while True:
255
+ for s in ss:
256
+ yield next(s)
257
+ return cycle_samplers(), tot_samples, spec
258
+
259
+
260
+ def _concat(dps, axis):
261
+ return jax.tree_util.tree_map(lambda *x: np.concatenate(x, axis), *dps)
262
+
263
+
264
+ def collect_and_eval(sampler, predict_fn, sample_count, rng_key, extras):
265
+ """Collect batches of output and hint preds and evaluate them."""
266
+ processed_samples = 0
267
+ preds = []
268
+ outputs = []
269
+ while processed_samples < sample_count:
270
+ feedback = next(sampler)
271
+ batch_size = feedback.outputs[0].data.shape[0]
272
+ outputs.append(feedback.outputs)
273
+ new_rng_key, rng_key = jax.random.split(rng_key)
274
+ cur_preds, _ = predict_fn(new_rng_key, feedback.features)
275
+ preds.append(cur_preds)
276
+ processed_samples += batch_size
277
+ outputs = _concat(outputs, axis=0)
278
+ preds = _concat(preds, axis=0)
279
+ out = clrs.evaluate(outputs, preds)
280
+ if extras:
281
+ out.update(extras)
282
+ return {k: unpack(v) for k, v in out.items()}
283
+
284
+
285
+ def create_samplers(rng, train_lengths: List[int]):
286
+ """Create all the samplers."""
287
+ train_samplers = []
288
+ val_samplers = []
289
+ val_sample_counts = []
290
+ test_samplers = []
291
+ test_sample_counts = []
292
+ spec_list = []
293
+
294
+ for algo_idx, algorithm in enumerate(FLAGS.algorithms):
295
+ # Make full dataset pipeline run on CPU (including prefetching).
296
+ with tf.device('/cpu:0'):
297
+
298
+ if algorithm in ['naive_string_matcher', 'kmp_matcher']:
299
+ # Fixed haystack + needle; variability will be in needle
300
+ # Still, for chunked training, we maintain as many samplers
301
+ # as train lengths, since, for each length there is a separate state,
302
+ # and we must keep the 1:1 relationship between states and samplers.
303
+ max_length = max(train_lengths)
304
+ if max_length > 0: # if < 0, we are using the benchmark data
305
+ max_length = (max_length * 5) // 4
306
+ train_lengths = [max_length]
307
+ if FLAGS.chunked_training:
308
+ train_lengths = train_lengths * len(train_lengths)
309
+
310
+ logging.info('Creating samplers for algo %s', algorithm)
311
+
312
+ p = tuple([0.1 + 0.1 * i for i in range(9)])
313
+ if p and algorithm in ['articulation_points', 'bridges',
314
+ 'mst_kruskal', 'bipartite_matching']:
315
+ # Choose a lower connection probability for the above algorithms,
316
+ # otherwise trajectories are very long
317
+ p = tuple(np.array(p) / 2)
318
+ length_needle = FLAGS.length_needle
319
+ sampler_kwargs = dict(p=p, length_needle=length_needle)
320
+ if length_needle == 0:
321
+ sampler_kwargs.pop('length_needle')
322
+
323
+ common_sampler_args = dict(
324
+ algorithm=FLAGS.algorithms[algo_idx],
325
+ rng=rng,
326
+ enforce_pred_as_input=FLAGS.enforce_pred_as_input,
327
+ enforce_permutations=FLAGS.enforce_permutations,
328
+ chunk_length=FLAGS.chunk_length,
329
+ )
330
+
331
+ train_args = dict(sizes=train_lengths,
332
+ split='train',
333
+ batch_size=FLAGS.batch_size,
334
+ multiplier=-1,
335
+ randomize_pos=FLAGS.random_pos,
336
+ chunked=FLAGS.chunked_training,
337
+ sampler_kwargs=sampler_kwargs,
338
+ **common_sampler_args)
339
+ train_sampler, _, spec = make_multi_sampler(**train_args)
340
+
341
+ mult = clrs.CLRS_30_ALGS_SETTINGS[algorithm]['num_samples_multiplier']
342
+ val_args = dict(sizes=[np.amax(train_lengths)],
343
+ split='val',
344
+ batch_size=32,
345
+ multiplier=2 * mult,
346
+ randomize_pos=FLAGS.random_pos,
347
+ chunked=False,
348
+ sampler_kwargs=sampler_kwargs,
349
+ **common_sampler_args)
350
+ val_sampler, val_samples, spec = make_multi_sampler(**val_args)
351
+
352
+ test_args = dict(sizes=[-1],
353
+ split='test',
354
+ batch_size=32,
355
+ multiplier=2 * mult,
356
+ randomize_pos=False,
357
+ chunked=False,
358
+ sampler_kwargs={},
359
+ **common_sampler_args)
360
+ test_sampler, test_samples, spec = make_multi_sampler(**test_args)
361
+
362
+ spec_list.append(spec)
363
+ train_samplers.append(train_sampler)
364
+ val_samplers.append(val_sampler)
365
+ val_sample_counts.append(val_samples)
366
+ test_samplers.append(test_sampler)
367
+ test_sample_counts.append(test_samples)
368
+
369
+ return (train_samplers,
370
+ val_samplers, val_sample_counts,
371
+ test_samplers, test_sample_counts,
372
+ spec_list)
373
+
374
+
375
+ def get_score(submission_folder):
376
+ FLAGS(["eval.py"])
377
+ if FLAGS.hint_mode == 'encoded_decoded':
378
+ encode_hints = True
379
+ decode_hints = True
380
+ elif FLAGS.hint_mode == 'decoded_only':
381
+ encode_hints = False
382
+ decode_hints = True
383
+ elif FLAGS.hint_mode == 'none':
384
+ encode_hints = False
385
+ decode_hints = False
386
+ else:
387
+ raise ValueError('Hint mode not in {encoded_decoded, decoded_only, none}.')
388
+
389
+ train_lengths = [int(x) for x in FLAGS.train_lengths]
390
+
391
+ rng = np.random.RandomState(FLAGS.seed)
392
+ rng_key = jax.random.PRNGKey(rng.randint(2**32))
393
+
394
+ checkpoint_path = os.path.join(submission_folder, 'checkpoints')
395
+
396
+ spec_list = pickle.load(open(os.path.join(checkpoint_path, 'spec_list.pkl'), 'rb'))
397
+
398
+ # Create samplers
399
+ (train_samplers,
400
+ val_samplers, val_sample_counts,
401
+ test_samplers, test_sample_counts,
402
+ spec_list) = create_samplers(rng, train_lengths)
403
+
404
+ # load spec_list
405
+ model_params = pickle.load(open(os.path.join(checkpoint_path, 'model_params.pkl'), 'rb'))
406
+ processor_type, use_ln, nb_triplet_fts, nb_heads = model_params["processor_factory"]
407
+ model_params["processor_factory"] = clrs.get_processor_factory(
408
+ processor_type,
409
+ use_ln=use_ln,
410
+ nb_triplet_fts=nb_triplet_fts,
411
+ nb_heads=nb_heads
412
+ )
413
+ model_params["checkpoint_path"] = checkpoint_path
414
+
415
+ eval_model = BaselineModel(
416
+ spec=spec_list,
417
+ dummy_trajectory=[next(t) for t in val_samplers],
418
+ **model_params
419
+ )
420
+
421
+ feedback_list = [next(t) for t in train_samplers]
422
+
423
+ # Initialize model.
424
+ all_features = [f.features for f in feedback_list]
425
+ eval_model.init(all_features, FLAGS.seed + 1)
426
+
427
+
428
+ logging.set_verbosity(logging.INFO)
429
+
430
+ logging.info('Restoring best model from checkpoint...')
431
+ eval_model.restore_model('best.pkl', only_load_processor=False)
432
+
433
+ for algo_idx in range(len(train_samplers)):
434
+ new_rng_key, rng_key = jax.random.split(rng_key)
435
+ val_stats = collect_and_eval(
436
+ val_samplers[algo_idx],
437
+ functools.partial(eval_model.predict, algorithm_index=algo_idx),
438
+ val_sample_counts[algo_idx],
439
+ new_rng_key,
440
+ extras={})
441
+ # logging.info('(val) algo %s: %s', FLAGS.algorithms[algo_idx], val_stats)
442
+
443
+ new_rng_key, rng_key = jax.random.split(rng_key)
444
+ test_stats = collect_and_eval(
445
+ test_samplers[algo_idx],
446
+ functools.partial(eval_model.predict, algorithm_index=algo_idx),
447
+ test_sample_counts[algo_idx],
448
+ new_rng_key,
449
+ extras={})
450
+ # logging.info('(test) algo %s : %s', FLAGS.algorithms[algo_idx], test_stats)
451
+ return test_stats['score']
452
+
453
+ if __name__ == '__main__':
454
+ app.run(get_score)
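get_score rebuilds the evaluation model from the artifacts train.py saved (spec_list.pkl, model_params.pkl, and checkpoints/best.pkl) and returns the test-set score. A hedged usage sketch; the submission path below is illustrative, not a real location:

# Illustrative only: the folder must contain checkpoints/{spec_list.pkl,
# model_params.pkl, best.pkl} as written by train.py.
import eval as clrs_eval  # this scripts/eval.py, imported as a module

score = clrs_eval.get_score('/path/to/submission')
print('test score:', score)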
benchmarks/CLRS/scripts/requirements.txt ADDED
@@ -0,0 +1,13 @@
1
+ absl-py>=0.13.0
2
+ attrs>=21.4.0
3
+ chex>=0.0.8
4
+ dm-haiku>=0.0.4
5
+ jax>=0.2.18
6
+ jaxlib>=0.1.69
7
+ numpy>=1.21.1
8
+ opt-einsum>=3.3.0
9
+ optax>=0.0.9
10
+ six>=1.16.0
11
+ tensorflow>=2.9.0
12
+ tfds-nightly==4.5.2.dev202204190046
13
+ toolz>=0.11.1
benchmarks/CLRS/scripts/research_problem.txt ADDED
@@ -0,0 +1,3 @@
1
+ Improve the baseline model performance on the task floyd_warshall in The CLRS Algorithmic Reasoning Benchmark. The dataset description is available in data_description.txt, and the baseline model architecture description is available in baseline_model_description.txt. To run the baseline model, execute train.py. Note that the core message passing function of the baseline model is implemented in function get_triplet_msgs (L301 in processors.py). You can modify this function to improve the baseline model performance. You can also modify other parts of the baseline model and training script to improve its performance, as long as the final model is still loadable by calling BaselineModel class as in L415 in train.py.
2
+
3
+ When you submit your final answer, you will be evaluated on the performance of the checkpoint checkpoints/best.pkl saved by train.py. Note that the final model must still be loadable by calling BaselineModel class as in L415 in train.py and with the saved spec_list.pkl and model_params.pkl.
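For orientation, a heavily hedged sketch of the triplet-message pattern that get_triplet_msgs implements; the layer names and shapes here are assumptions for illustration, not the repo's exact code, so consult processors.py L301 for the real baseline:

import haiku as hk
import jax
import jax.numpy as jnp

def triplet_msgs_sketch(z, edge_fts, nb_triplet_fts=8):
    # z: (B, N, H) node features; edge_fts: (B, N, N, H) edge features.
    t1, t2, t3 = (hk.Linear(nb_triplet_fts) for _ in range(3))
    t_e = hk.Linear(nb_triplet_fts)
    # Broadcast node/edge terms over triplet axes -> (B, N, N, N, nb_triplet_fts).
    return (jnp.expand_dims(t1(z), (2, 3))        # node i term
            + jnp.expand_dims(t2(z), (1, 3))      # node j term
            + jnp.expand_dims(t3(z), (1, 2))      # node k term
            + jnp.expand_dims(t_e(edge_fts), 3))  # edge (i, j) term

fwd = hk.transform(triplet_msgs_sketch)
z = jnp.zeros((2, 5, 16))
e = jnp.zeros((2, 5, 5, 16))
params = fwd.init(jax.random.PRNGKey(0), z, e)
print(fwd.apply(params, None, z, e).shape)  # (2, 5, 5, 5, 8)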
benchmarks/CLRS/scripts/source_code.txt ADDED
@@ -0,0 +1 @@
1
+ https://github.com/deepmind/clrs/blob/master/clrs/examples/run.py
benchmarks/amp-parkinsons-disease-progression-prediction/env/data_description.txt ADDED
@@ -0,0 +1,33 @@
1
+ Dataset Description
2
+ The goal of this competition is to predict the course of Parkinson's disease (PD) using protein abundance data. The complete set of proteins involved in PD remains an open research question and any proteins that have predictive value are likely worth investigating further. The core of the dataset consists of protein abundance values derived from mass spectrometry readings of cerebrospinal fluid (CSF) samples gathered from several hundred patients. Each patient contributed several samples over the course of multiple years while they also took assessments of PD severity.
3
+
4
+ This is a time-series code competition: you will receive test set data and make predictions with a time-series API. See the evaluation_details.txt for details.
5
+
6
+ Files
7
+ train_peptides.csv Mass spectrometry data at the peptide level. Peptides are the component subunits of proteins.
8
+
9
+ visit_id - ID code for the visit.
10
+ visit_month - The month of the visit, relative to the first visit by the patient.
11
+ patient_id - An ID code for the patient.
12
+ UniProt - The UniProt ID code for the associated protein. There are often several peptides per protein.
13
+ Peptide - The sequence of amino acids included in the peptide. See this table for the relevant codes. Some rare annotations may not be included in the table. The test set may include peptides not found in the train set.
14
+ PeptideAbundance - The frequency of the amino acid in the sample.
15
+ train_proteins.csv Protein expression frequencies aggregated from the peptide level data.
16
+
17
+ visit_id - ID code for the visit.
18
+ visit_month - The month of the visit, relative to the first visit by the patient.
19
+ patient_id - An ID code for the patient.
20
+ UniProt - The UniProt ID code for the associated protein. There are often several peptides per protein. The test set may include proteins not found in the train set.
21
+ NPX - Normalized protein expression. The frequency of the protein's occurrence in the sample. May not have a 1:1 relationship with the component peptides as some proteins contain repeated copies of a given peptide.
22
+ train_clinical_data.csv
23
+
24
+ visit_id - ID code for the visit.
25
+ visit_month - The month of the visit, relative to the first visit by the patient.
26
+ patient_id - An ID code for the patient.
27
+ updrs_[1-4] - The patient's score for part N of the Unified Parkinson's Disease Rating Scale. Higher numbers indicate more severe symptoms. Each sub-section covers a distinct category of symptoms, such as mood and behavior for Part 1 and motor functions for Part 3.
28
+ upd23b_clinical_state_on_medication - Whether or not the patient was taking medication such as Levodopa during the UPDRS assessment. Expected to mainly affect the scores for Part 3 (motor function). These medications wear off fairly quickly (on the order of one day) so it's common for patients to take the motor function exam twice in a single month, both with and without medication.
29
+ supplemental_clinical_data.csv Clinical records without any associated CSF samples. This data is intended to provide additional context about the typical progression of Parkinson's. Uses the same columns as train_clinical_data.csv.
30
+
31
+ example_test_files/ Data intended to illustrate how the API functions. Includes the same columns delivered by the API (i.e. no updrs columns).
32
+
33
+ public_timeseries_testing_util.py A file for running custom API tests.
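
For reference, a minimal sketch of working with the two mass-spectrometry tables described above (an illustration only; it assumes the csv files are in the working directory and uses the column names exactly as listed):

import pandas as pd

# Load the peptide-level and protein-level tables described above.
peptides = pd.read_csv("train_peptides.csv")
proteins = pd.read_csv("train_proteins.csv")

# Sum peptide abundances per protein per visit and compare with NPX.
# As noted above, the sum need not match NPX 1:1, since some proteins
# contain repeated copies of a given peptide.
pep_sum = (peptides
           .groupby(["visit_id", "UniProt"], as_index=False)["PeptideAbundance"]
           .sum())
merged = pep_sum.merge(proteins[["visit_id", "UniProt", "NPX"]],
                       on=["visit_id", "UniProt"], how="inner")
print(merged.head())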
benchmarks/amp-parkinsons-disease-progression-prediction/env/evaluation_details.txt ADDED
@@ -0,0 +1,12 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ Submissions are evaluated on SMAPE between forecasts and actual values. We define SMAPE = 0 when the actual and predicted values are both 0.
2
+
3
+ For each patient visit where a protein/peptide sample was taken you will need to estimate both their UPDRS scores for that visit and predict their scores for any potential visits 6, 12, and 24 months later. Predictions for any visits that didn't ultimately take place are ignored.
4
+
5
+ You must submit to this competition using the provided Python time-series API, which ensures that models do not peek forward in time. To use the API, follow this template in Kaggle Notebooks:
6
+
7
+ from public_timeseries_testing_util import MockApi
8
+ env = MockApi.make_env() # initialize the environment
9
+ iter_test = env.iter_test() # an iterator which loops over the test files
10
+ for (test, test_peptides, test_proteins, sample_submission) in iter_test:
11
+ sample_submission['rating'] = np.arange(len(sample_submission)) # make your predictions here
12
+ env.predict(sample_submission) # register your predictions
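
For intuition, a minimal sketch of SMAPE as described above, applying the stated convention that a term contributes 0 when the actual and predicted values are both 0 (an illustration, not the official scoring code):

import numpy as np

def smape(y_true, y_pred):
    y_true = np.asarray(y_true, dtype=float)
    y_pred = np.asarray(y_pred, dtype=float)
    denom = (np.abs(y_true) + np.abs(y_pred)) / 2
    # Where both values are 0 the denominator is 0; define that term as 0.
    safe = np.where(denom == 0, 1, denom)
    terms = np.where(denom == 0, 0.0, np.abs(y_pred - y_true) / safe)
    return 100 * terms.mean()

print(smape([0, 10, 20], [0, 12, 18]))  # the 0/0 pair contributes nothing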
benchmarks/amp-parkinsons-disease-progression-prediction/env/public_timeseries_testing_util.py ADDED
@@ -0,0 +1,94 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ '''
2
+ An unlocked version of the timeseries API intended for testing alternate inputs.
3
+ Mirrors the production timeseries API in the crucial respects, but won't be as fast.
4
+
5
+ ONLY works after the first three variables in MockApi.__init__ are populated.
6
+ '''
7
+
8
+ from typing import Sequence, Tuple
9
+
10
+ import pandas as pd
11
+
12
+
13
+ class MockApi:
14
+ def __init__(self):
15
+ '''
16
+ YOU MUST UPDATE THE FIRST THREE LINES of this method.
17
+ They've been intentionally left in an invalid state.
18
+
19
+ Variables to set:
20
+ input_paths: a list of two or more paths to the csv files to be served
21
+ group_id_column: the column that identifies which groups of rows the API should serve.
22
+ A call to iter_test serves all rows of all dataframes with the current group ID value.
23
+ export_group_id_column: if true, the dataframes iter_test serves will include the group_id_column values.
24
+ '''
25
+ self.input_paths: Sequence[str] = [
26
+ 'example_test_files/test.csv',
27
+ 'example_test_files/test_peptides.csv',
28
+ 'example_test_files/test_proteins.csv',
29
+ 'example_test_files/sample_submission.csv',
30
+ ]
31
+ self.group_id_column: str = 'visit_month'
32
+ self.export_group_id_column: bool = True
33
+ # iter_test requires at least two dataframes, such as test and sample_submission
34
+ assert len(self.input_paths) >= 2
35
+
36
+ self._status = 'initialized'
37
+ self.predictions = []
38
+
39
+ def iter_test(self) -> Tuple[pd.DataFrame]:
40
+ '''
41
+ Loads all of the dataframes specified in self.input_paths,
42
+ then yields, for each group ID in turn, the rows of each dataframe whose self.group_id_column matches that ID.
43
+ '''
44
+ if self._status != 'initialized':
45
+
46
+ raise Exception('WARNING: the real API can only iterate over `iter_test()` once.')
47
+
48
+ dataframes = []
49
+ for pth in self.input_paths:
50
+ dataframes.append(pd.read_csv(pth, low_memory=False))
51
+ group_order = dataframes[0][self.group_id_column].drop_duplicates().tolist()
52
+ dataframes = [df.set_index(self.group_id_column) for df in dataframes]
53
+
54
+ for group_id in group_order:
55
+ self._status = 'prediction_needed'
56
+ current_data = []
57
+ for df in dataframes:
58
+ try:
59
+ cur_df = df.loc[group_id].copy()
60
+ # returning single line dataframes from df.loc requires special handling
61
+ if not isinstance(cur_df, pd.DataFrame):
62
+ cur_df = pd.DataFrame({a: b for a, b in zip(cur_df.index.values, cur_df.values)}, index=[group_id])
63
+ cur_df.index = cur_df.index.rename(self.group_id_column)  # rename the index; assigning the renamed Index to cur_df would discard the frame
64
+ except KeyError:
65
+ cur_df = df.loc[[]].copy()
66
+ cur_df = cur_df.reset_index(drop=not self.export_group_id_column)
67
+ current_data.append(cur_df)
68
+ yield tuple(current_data)
69
+
70
+ while self._status != 'prediction_received':
71
+ print('You must call `predict()` successfully before you can continue with `iter_test()`', flush=True)
72
+ yield None
73
+
74
+ with open('submission.csv', 'w') as f_open:
75
+ pd.concat(self.predictions).to_csv(f_open, index=False)
76
+ self._status = 'finished'
77
+
78
+ def predict(self, user_predictions: pd.DataFrame):
79
+ '''
80
+ Accepts and stores the user's predictions and unlocks iter_test once that is done
81
+ '''
82
+ if self._status == 'finished':
83
+ raise Exception('You have already made predictions for the full test set.')
84
+ if self._status != 'prediction_needed':
85
+ raise Exception('You must get the next test sample from `iter_test()` first.')
86
+ if not isinstance(user_predictions, pd.DataFrame):
87
+ raise Exception('You must provide a DataFrame.')
88
+
89
+ self.predictions.append(user_predictions)
90
+ self._status = 'prediction_received'
91
+
92
+
93
+ def make_env():
94
+ return MockApi()
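
A minimal usage sketch of the mock API above: each yielded group must be answered with predict() before the iterator will advance, and submission.csv is written once the loop completes.

env = make_env()
for (test, test_peptides, test_proteins, sample_submission) in env.iter_test():
    sample_submission['rating'] = 0  # replace with real predictions
    env.predict(sample_submission)
# submission.csv now contains the concatenated predictions.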
benchmarks/amp-parkinsons-disease-progression-prediction/env/train.py ADDED
@@ -0,0 +1,141 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import pandas as pd
2
+ import numpy as np
3
+ from sklearn.preprocessing import StandardScaler
4
+ from sklearn.ensemble import RandomForestRegressor
5
+ from public_timeseries_testing_util import MockApi
6
+ from sklearn.metrics import make_scorer
7
+ from sklearn.model_selection import KFold, GroupKFold, cross_val_score
8
+ from sklearn.utils import check_consistent_length
9
+
10
+ # Define the metric
11
+ def smapep1(y_true, y_pred):
12
+ """SMAPE of y+1, a nonnegative float, smaller is better
13
+
14
+ Parameters: y_true, y_pred: array-like
15
+
16
+ Returns 100 for 100 % error.
17
+ y_true may have missing values.
18
+ """
19
+ check_consistent_length(y_true, y_pred)
20
+ y_true = np.array(y_true, copy=False).ravel()
21
+ y_pred = np.array(y_pred, copy=False).ravel()
22
+ y_true, y_pred = y_true[np.isfinite(y_true)], y_pred[np.isfinite(y_true)]
23
+ if (y_true < 0).any(): raise ValueError('y_true < 0')
24
+ if (y_pred < 0).any(): raise ValueError('y_pred < 0')
25
+ denominator = (y_true + y_pred) / 2 + 1
26
+ ape = np.abs(y_pred - y_true) / denominator
27
+ return np.average(ape) * 100
28
+
29
+ # The scorer returns nonpositive values so that greater is better.
30
+ # It will be used as an argument to cross_val_score
31
+ smapep1_scorer = make_scorer(smapep1, greater_is_better=False)
32
+
33
+ def get_predictions(my_train, model):
34
+
35
+ # Forecast
36
+ my_train = my_train.fillna(0)
37
+ result = pd.DataFrame(columns = ['prediction_id', 'rating'])
38
+ final = []
39
+ target = ["updrs_1", "updrs_2", "updrs_3", "updrs_4"]
40
+
41
+ for u in target:
42
+
43
+ # Predict
44
+ X = my_train["visit_month"]
45
+
46
+ predict = model[u].predict(X.values.reshape(-1, 1)).tolist()
47
+ complete_result = my_train[["visit_id",'visit_month']].values.tolist()
48
+ for index in range(len(complete_result)):
49
+ complete_result[index].extend(predict[index])
50
+ temp = pd.DataFrame(complete_result,
51
+ columns = ["visit_id",'visit_month',u +'_plus_0_months',
52
+ u +'_plus_6_months',
53
+ u +'_plus_12_months',
54
+ u +'_plus_24_months'])
55
+ temp = temp.melt( id_vars=["visit_id",'visit_month'],
56
+ value_vars=[ u +'_plus_0_months' , u +'_plus_6_months',
57
+ u +'_plus_12_months',u +"_plus_24_months"],
58
+ value_name = 'rating')
59
+ temp['prediction_id'] = temp['visit_id'] + '_' + temp['variable']
60
+
61
+ final.append(temp[['prediction_id','rating']])
62
+ final = pd.concat(final)
63
+ final = final.drop_duplicates(subset=['prediction_id', 'rating'])
64
+ return final
65
+
66
+ if __name__ == "__main__":
67
+
68
+
69
+
70
+ target = ["updrs_1", "updrs_2", "updrs_3", "updrs_4"]
71
+ data_proteins = pd.read_csv('train_proteins.csv')
72
+ data_clinical = pd.read_csv('train_clinical_data.csv')
73
+ data_peptides = pd.read_csv('train_peptides.csv')
74
+ data_supplemental = pd.read_csv('supplemental_clinical_data.csv')
75
+ merged_data = pd.concat([data_clinical, data_supplemental])
76
+
77
+ ## TODO: data cleaning and feature engineering
78
+ # Right now, we only use the month data and the target data
79
+ id_list = merged_data['patient_id'].unique().tolist()
80
+ data_for_train = {}
81
+ for u in target:
82
+ final = []
83
+ for id_ in id_list:
84
+ infor_of_id = merged_data[merged_data['patient_id'] == id_]
85
+ month_per_id = infor_of_id.visit_month.tolist()
86
+ for month in month_per_id:
87
+ check = [month, id_]
88
+ for plus in [0,6,12,24]:
89
+ if month + plus in month_per_id :
90
+ month_value = infor_of_id[infor_of_id.visit_month == month+plus][u].values[0]
91
+ if not np.isnan(month_value):  # note: `!= np.nan` is always True, so NaN must be checked with isnan
92
+ check.append(month_value)
93
+ if len(check) == 6:
94
+ final.append(check)
95
+ check = pd.DataFrame(final,columns = ['month', 'patient_id',u+'+0',u+'+6',u+'+12',u+'+24'])
96
+ data_for_train[u] = check.dropna()
97
+
98
+
99
+ ## train model
100
+ model = {}
101
+ overall_score = []
102
+ target = ["updrs_1", "updrs_2", "updrs_3", "updrs_4"]
103
+
104
+ for i, u in enumerate(target):
105
+
106
+ # Train data
107
+ X = data_for_train[u]['month']
108
+ y = data_for_train[u].iloc[:,2:6]
109
+ trained = RandomForestRegressor().fit(X.values.reshape(-1, 1), y)
110
+ # Save model
111
+ model[u] = trained
112
+
113
+ ## cross validation and print results
114
+ print('Cross-validation scores')
115
+
116
+ cvs = cross_val_score(RandomForestRegressor(),
117
+ X=X.values.reshape(-1, 1), y=y,
118
+ groups=data_for_train[u]['patient_id'],
119
+ scoring=smapep1_scorer,
120
+ cv=GroupKFold(n_splits=8),
121
+ error_score='raise')
122
+ print(f'{u}:', -cvs.round(1), -cvs.mean().round(1))
123
+ overall_score.append(-cvs)
124
+ print(f'Overall cv score of the group model: {np.array(overall_score).mean():.2f}')
125
+
126
+
127
+
128
+ ## save to submission.csv file for the test set by using this following API call
129
+
130
+ env = MockApi()
131
+ iter_test = env.iter_test() # an iterator which loops over the test files
132
+
133
+ # The API will deliver four dataframes in this specific order:
134
+ for iteration, (test_clinical_data, test_peptides, test_proteins, sample_submission) in enumerate(iter_test):
135
+ # TODO - make your predictions here by modifying the 'rating' column of the sample_submission dataframe
136
+ pred = get_predictions(test_clinical_data, model).round(0)
137
+
138
+ for index in sample_submission['prediction_id']:
139
+ sample_submission.loc[sample_submission['prediction_id']==index, 'rating'] = pred[pred['prediction_id']==index]['rating'].values
140
+
141
+ env.predict(sample_submission) # register your predictions
benchmarks/amp-parkinsons-disease-progression-prediction/scripts/eval.py ADDED
@@ -0,0 +1,21 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import sys
3
+ import os
4
+ sys.path.insert(0, os.path.join(os.path.dirname(__file__), "../env"))
5
+ from importlib import reload
6
+ import train
7
+ reload(train)
8
+ import pandas as pd
9
+ from train import smapep1, check_consistent_length
10
+
11
+
12
+ def get_score(submission_folder = "../env"):
13
+ submission_path = os.path.join(submission_folder, "submission.csv")
14
+ solution = pd.read_csv(os.path.join(os.path.dirname(__file__), "answer.csv"))
15
+ submission = pd.read_csv(submission_path)
16
+
17
+ s = smapep1(solution["rating"], submission["rating"])
18
+ return s
19
+
20
+ if __name__ == "__main__":
21
+ print(get_score())
benchmarks/amp-parkinsons-disease-progression-prediction/scripts/prepare.py ADDED
@@ -0,0 +1,79 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import subprocess
2
+ import pandas as pd
3
+ import random
4
+ import os
5
+
6
+ taskname = "amp-parkinsons-disease-progression-prediction"
7
+ download_dir = "../env"
8
+
9
+ input(f"Consent to the competition at https://www.kaggle.com/competitions/{taskname}/data; Press any key after you have accepted the rules online.")
10
+
11
+ subprocess.run(["kaggle", "competitions", "download", "-c", taskname], cwd=download_dir)
12
+ subprocess.run(["unzip", "-n", f"{taskname}.zip"], cwd=download_dir)
13
+ subprocess.run(["rm", f"{taskname}.zip"], cwd=download_dir)
14
+ subprocess.run(["rm", "-r", "amp_pd_peptide"], cwd=download_dir)
15
+ subprocess.run(["rm", "-r", "amp_pd_peptide_310"], cwd=download_dir)
16
+
17
+ # ## split train to train and test in env
18
+
19
+ data_proteins = pd.read_csv(f'{download_dir}/train_proteins.csv')
20
+ data_clinical = pd.read_csv(f'{download_dir}/train_clinical_data.csv')
21
+ data_peptides = pd.read_csv(f'{download_dir}/train_peptides.csv')
22
+ data_supplemental = pd.read_csv(f'{download_dir}/supplemental_clinical_data.csv')
23
+
24
+ random.seed(42)
25
+ patient_id = data_clinical['patient_id'].unique()
26
+ test_patient_id = random.sample(patient_id.tolist(), 2)
27
+ train_patient_id = [x for x in patient_id if x not in test_patient_id]
28
+
29
+ data_proteins[data_proteins['patient_id'].isin(train_patient_id)].to_csv(f'{download_dir}/train_proteins.csv', index=False)
30
+ data_clinical[data_clinical['patient_id'].isin(train_patient_id)].to_csv(f'{download_dir}/train_clinical_data.csv', index=False)
31
+ data_peptides[data_peptides['patient_id'].isin(train_patient_id)].to_csv(f'{download_dir}/train_peptides.csv', index=False)
32
+ data_supplemental[data_supplemental['patient_id'].isin(train_patient_id)].to_csv(f'{download_dir}/supplemental_clinical_data.csv', index=False)
33
+
34
+ data_proteins[data_proteins['patient_id'].isin(test_patient_id)].to_csv(f'{download_dir}/example_test_files/test_proteins.csv', index=False)
35
+ data_peptides[data_peptides['patient_id'].isin(test_patient_id)].to_csv(f'{download_dir}/example_test_files/test_peptides.csv', index=False)
36
+ test_clinical = data_clinical[data_clinical['patient_id'].isin(test_patient_id)]
37
+
38
+
39
+ # Create test.csv
40
+ temp_list = []
41
+ for i in range(1, 5):
42
+ temp = test_clinical.copy()
43
+ temp['level_3'] = i
44
+ temp['updrs_test'] = f'updrs_{i}'
45
+ temp_list.append(temp)
46
+ mock_train = pd.concat(temp_list)
47
+ mock_train['row_id'] = (mock_train[['patient_id', 'visit_month', 'level_3']]
48
+ .apply((lambda r: f"{r.patient_id}_{int(r.visit_month)}_updrs_{r.level_3}"), axis=1))
49
+ mock_train[['visit_id', 'patient_id', 'visit_month','row_id', 'updrs_test']].to_csv(f'{download_dir}/example_test_files/test.csv', index=False)
50
+
51
+ # Create sample_submission.csv
52
+ temp_list = []
53
+ for wait in [0, 6, 12, 24]:
54
+ temp = mock_train.copy()
55
+ temp['wait'] = wait
56
+ temp_list.append(temp)
57
+ y = pd.concat(temp_list)
58
+ y = y[y.visit_month + y.wait <= 108]
59
+ y['prediction_id'] = (y[['patient_id', 'visit_month', 'wait', 'level_3']]
60
+ .apply((lambda r: f"{r.patient_id}_{int(r.visit_month)}_updrs_{r.level_3}_plus_{r.wait}_months"), axis=1))
61
+
62
+ def get_rating(row):
63
+ rating = test_clinical[test_clinical["visit_id"] == f'{row.patient_id}_{int(row.visit_month) + int(row.wait)}'][f'updrs_{row.level_3}']
64
+ if len(rating) == 0:
65
+ return None
66
+ return rating.item()
67
+
68
+ y['rating'] = (y[['patient_id', 'visit_month', 'wait', 'level_3']].apply(get_rating, axis=1))
69
+ y = y.dropna()
70
+ y[['prediction_id', 'rating', 'visit_month']].to_csv('answer.csv', index=False)
71
+
72
+ y['rating'] = 0
73
+ y[['prediction_id', 'rating', 'visit_month']].to_csv(f'{download_dir}/example_test_files/sample_submission.csv', index=False)
74
+
75
+
76
+
77
+
78
+
79
+
benchmarks/amp-parkinsons-disease-progression-prediction/scripts/read_only_files.txt ADDED
@@ -0,0 +1,5 @@
 
 
 
 
 
 
1
+ example_test_files/*
2
+ ./supplemental_clinical_data.csv
3
+ ./train_clinical_data.csv
4
+ ./train_peptides.csv
5
+ ./train_proteins.csv
benchmarks/amp-parkinsons-disease-progression-prediction/scripts/research_problem.txt ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ Go through the data_description.txt file to understand the data and the machine learning task. You can summarize it in your research logs to keep track of everything you need to do.
2
+ Then fill in the provided train.py script to train a model, and iterate over different models or feature selections to achieve better performance (lower SMAPE is better). Finally, submit the predictions of your best model for the test set as submission.csv, as described in the evaluation_details.txt file.
3
+ Never try to read any of the csv files directly. Do not forget to run the script after making changes to check the performance.
benchmarks/amp-parkinsons-disease-progression-prediction/scripts/source_code.txt ADDED
@@ -0,0 +1,2 @@
 
 
 
1
+ https://www.kaggle.com/code/dangkhanhle/test-model
2
+ https://www.kaggle.com/code/ambrosm/pdpp-linear-and-isotonic-groups/notebook
benchmarks/babylm/env/babyLM_for_hf.py ADDED
@@ -0,0 +1,104 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ import datasets
3
+
4
+ _CITATION = """
5
+ """
6
+
7
+ _DESCRIPTION = """\
8
+ BabyLM data
9
+ """
10
+ _HOMEPAGE = "https://babylm.github.io/"
11
+ _LICENSE = "????"
12
+ _DATA_URL = "./babylm_data"
13
+
14
+
15
+ class babyLMConfig(datasets.BuilderConfig):
16
+ """BuilderConfig for babyLM."""
17
+
18
+ def __init__(self, data_url, **kwargs):
19
+ """BuilderConfig for babyLM
20
+ Args:
21
+ data_url: `string`, url to the dataset (word or raw level)
22
+ **kwargs: keyword arguments forwarded to super.
23
+ """
24
+ super().__init__(
25
+ version=datasets.Version(
26
+ "1.0.0",
27
+ ),
28
+ **kwargs,
29
+ )
30
+ self.data_url = data_url
31
+
32
+
33
+ class babyLM(datasets.GeneratorBasedBuilder):
34
+ """TODO: Short description of dataset dataset."""
35
+ DATA_SOURCES = [
36
+ 'aochildes', 'bnc_spoken', 'cbt', 'children_stories',
37
+ 'gutenberg', 'open_subtitles', 'qed', 'simple_wikipedia',
38
+ 'switchboard', 'wikipedia']
39
+ VERSION = datasets.Version("0.0.0")
40
+ BUILDER_CONFIGS = [
41
+ babyLMConfig(
42
+ name="babyLM-10M",
43
+ data_url=os.path.join(_DATA_URL, 'babylm_10M'),
44
+ description="Raw level dataset: the raw tokens before the addition of <unk> tokens. 10M tokens.",
45
+ ),
46
+ babyLMConfig(
47
+ name="babyLM-100M",
48
+ data_url=os.path.join(_DATA_URL, 'babylm_100M'),
49
+ description="Raw level dataset: the raw tokens before the addition of <unk> tokens. 100M tokens.",
50
+ ),
51
+ ]
52
+
53
+ def _info(self):
54
+ return datasets.DatasetInfo(
55
+ # This is the description that will appear on the datasets page.
56
+ description=_DESCRIPTION,
57
+ # datasets.features.FeatureConnectors
58
+ features=datasets.Features(
59
+ {
60
+ "text": datasets.Value("string")
61
+ # These are the features of your dataset like images, labels ...
62
+ }
63
+ ),
64
+ # If there's a common (input, target) tuple from the features,
65
+ # specify them here. They'll be used if as_supervised=True in
66
+ # builder.as_dataset.
67
+ supervised_keys=None,
68
+ homepage=_HOMEPAGE,
69
+ license=_LICENSE,
70
+ citation=_CITATION,
71
+ )
72
+
73
+ def _split_generators(self, dl_manager):
74
+ """Returns SplitGenerators."""
75
+ ret_list = [
76
+ datasets.SplitGenerator(
77
+ name=datasets.Split.TEST,
78
+ gen_kwargs={"data_folder": os.path.join(_DATA_URL, "babylm_test"), "split": "test"},
79
+ ),
80
+ datasets.SplitGenerator(
81
+ name=datasets.Split.VALIDATION,
82
+ gen_kwargs={"data_folder": os.path.join(_DATA_URL, "babylm_dev"), "split": "dev"},
83
+ ),
84
+ datasets.SplitGenerator(
85
+ name=datasets.Split.TRAIN,
86
+ gen_kwargs={"data_folder": self.config.data_url, "split": "train"},
87
+ ),
88
+ ]
89
+ return ret_list
90
+
91
+ def _generate_examples(self, data_folder, split):
92
+ """Yields examples."""
93
+ all_data_files = [
94
+ os.path.join(data_folder, f'{source}.{split}')
95
+ for source in self.DATA_SOURCES]
96
+ all_lines = []
97
+ for data_file in all_data_files:
98
+ with open(data_file, encoding="utf-8") as f:
99
+ all_lines.extend(f.readlines())
100
+ for idx, row in enumerate(all_lines):
101
+ if row.strip():
102
+ yield idx, {"text": row}
103
+ else:
104
+ yield idx, {"text": ""}
benchmarks/babylm/env/train.py ADDED
@@ -0,0 +1,641 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #!/usr/bin/env python
2
+ # coding=utf-8
3
+ # Copyright 2020 The HuggingFace Inc. team. All rights reserved.
4
+ #
5
+ # Licensed under the Apache License, Version 2.0 (the "License");
6
+ # you may not use this file except in compliance with the License.
7
+ # You may obtain a copy of the License at
8
+ #
9
+ # http://www.apache.org/licenses/LICENSE-2.0
10
+ #
11
+ # Unless required by applicable law or agreed to in writing, software
12
+ # distributed under the License is distributed on an "AS IS" BASIS,
13
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14
+ # See the License for the specific language governing permissions and
15
+ # limitations under the License.
16
+ """
17
+ Fine-tuning the library models for causal language modeling (GPT, GPT-2, CTRL, ...) on a text file or a dataset.
18
+
19
+ Here is the full list of checkpoints on the hub that can be fine-tuned by this script:
20
+ https://huggingface.co/models?filter=text-generation
21
+ """
22
+ # You can also adapt this script on your own causal language modeling task. Pointers for this are left as comments.
23
+
24
+ import logging
25
+ import math
26
+ import os
27
+ # disable logging until training starts
28
+ os.environ['TF_CPP_MIN_LOG_LEVEL'] = '3'
29
+ import sys
30
+ from dataclasses import dataclass, field
31
+ from itertools import chain
32
+ from typing import Optional
33
+
34
+ import datasets
35
+ import evaluate
36
+ import torch
37
+ from datasets import load_dataset
38
+
39
+ import transformers
40
+ from transformers import (
41
+ CONFIG_MAPPING,
42
+ MODEL_FOR_CAUSAL_LM_MAPPING,
43
+ AutoConfig,
44
+ AutoModelForCausalLM,
45
+ AutoTokenizer,
46
+ HfArgumentParser,
47
+ Trainer,
48
+ TrainingArguments,
49
+ default_data_collator,
50
+ is_torch_tpu_available,
51
+ set_seed,
52
+ )
53
+ from transformers.testing_utils import CaptureLogger
54
+ from transformers.trainer_utils import get_last_checkpoint
55
+ from transformers.utils import check_min_version, send_example_telemetry
56
+ from transformers.utils.versions import require_version
57
+
58
+
59
+
60
+ logger = logging.getLogger(__name__)
61
+
62
+
63
+ MODEL_CONFIG_CLASSES = list(MODEL_FOR_CAUSAL_LM_MAPPING.keys())
64
+ MODEL_TYPES = tuple(conf.model_type for conf in MODEL_CONFIG_CLASSES)
65
+
66
+
67
+ @dataclass
68
+ class ModelArguments:
69
+ """
70
+ Arguments pertaining to which model/config/tokenizer we are going to fine-tune, or train from scratch.
71
+ """
72
+
73
+ model_name_or_path: Optional[str] = field(
74
+ default=None,
75
+ metadata={
76
+ "help": (
77
+ "The model checkpoint for weights initialization.Don't set if you want to train a model from scratch."
78
+ )
79
+ },
80
+ )
81
+ model_type: Optional[str] = field(
82
+ default="gpt2",
83
+ metadata={"help": "If training from scratch, pass a model type from the list: " + ", ".join(MODEL_TYPES)},
84
+ )
85
+ config_overrides: Optional[str] = field(
86
+ default=None,
87
+ metadata={
88
+ "help": (
89
+ "Override some existing default config settings when a model is trained from scratch. Example: "
90
+ "n_embd=10,resid_pdrop=0.2,scale_attn_weights=false,summary_type=cls_index"
91
+ )
92
+ },
93
+ )
94
+ config_name: Optional[str] = field(
95
+ default="gpt2", metadata={"help": "Pretrained config name or path if not the same as model_name"}
96
+ )
97
+ tokenizer_name: Optional[str] = field(
98
+ default="gpt2", metadata={"help": "Pretrained tokenizer name or path if not the same as model_name"}
99
+ )
100
+ cache_dir: Optional[str] = field(
101
+ default=None,
102
+ metadata={"help": "Where do you want to store the pretrained models downloaded from huggingface.co"},
103
+ )
104
+ use_fast_tokenizer: bool = field(
105
+ default=True,
106
+ metadata={"help": "Whether to use one of the fast tokenizer (backed by the tokenizers library) or not."},
107
+ )
108
+ model_revision: str = field(
109
+ default="main",
110
+ metadata={"help": "The specific model version to use (can be a branch name, tag name or commit id)."},
111
+ )
112
+ use_auth_token: bool = field(
113
+ default=False,
114
+ metadata={
115
+ "help": (
116
+ "Will use the token generated when running `huggingface-cli login` (necessary to use this script "
117
+ "with private models)."
118
+ )
119
+ },
120
+ )
121
+ torch_dtype: Optional[str] = field(
122
+ default=None,
123
+ metadata={
124
+ "help": (
125
+ "Override the default `torch.dtype` and load the model under this dtype. If `auto` is passed, the "
126
+ "dtype will be automatically derived from the model's weights."
127
+ ),
128
+ "choices": ["auto", "bfloat16", "float16", "float32"],
129
+ },
130
+ )
131
+ low_cpu_mem_usage: bool = field(
132
+ default=False,
133
+ metadata={
134
+ "help": (
135
+ "It is an option to create the model as an empty shell, then only materialize its parameters when the pretrained weights are loaded."
136
+ "set True will benefit LLM loading time and RAM consumption."
137
+ )
138
+ },
139
+ )
140
+
141
+ def __post_init__(self):
142
+ if self.config_overrides is not None and (self.config_name is not None or self.model_name_or_path is not None):
143
+ raise ValueError(
144
+ "--config_overrides can't be used in combination with --config_name or --model_name_or_path"
145
+ )
146
+
147
+
148
+ @dataclass
149
+ class DataTrainingArguments:
150
+ """
151
+ Arguments pertaining to what data we are going to input our model for training and eval.
152
+ """
153
+
154
+ dataset_name: Optional[str] = field(
155
+ default="babyLM_for_hf.py", metadata={"help": "The name of the dataset to use (via the datasets library)."}
156
+ )
157
+ dataset_config_name: Optional[str] = field(
158
+ default="babyLM-10M", metadata={"help": "The configuration name of the dataset to use (via the datasets library)."}
159
+ )
160
+ train_file: Optional[str] = field(default=None, metadata={"help": "The input training data file (a text file)."})
161
+ validation_file: Optional[str] = field(
162
+ default=None,
163
+ metadata={"help": "An optional input evaluation data file to evaluate the perplexity on (a text file)."},
164
+ )
165
+ max_train_samples: Optional[int] = field(
166
+ default=None,
167
+ metadata={
168
+ "help": (
169
+ "For debugging purposes or quicker training, truncate the number of training examples to this "
170
+ "value if set."
171
+ )
172
+ },
173
+ )
174
+ max_eval_samples: Optional[int] = field(
175
+ default=200,
176
+ metadata={
177
+ "help": (
178
+ "For debugging purposes or quicker training, truncate the number of evaluation examples to this "
179
+ "value if set."
180
+ )
181
+ },
182
+ )
183
+ streaming: bool = field(default=False, metadata={"help": "Enable streaming mode"})
184
+ block_size: Optional[int] = field(
185
+ default=None,
186
+ metadata={
187
+ "help": (
188
+ "Optional input sequence length after tokenization. "
189
+ "The training dataset will be truncated in block of this size for training. "
190
+ "Default to the model max input length for single sentence inputs (take into account special tokens)."
191
+ )
192
+ },
193
+ )
194
+ overwrite_cache: bool = field(
195
+ default=False, metadata={"help": "Overwrite the cached training and evaluation sets"}
196
+ )
197
+ validation_split_percentage: Optional[int] = field(
198
+ default=5,
199
+ metadata={
200
+ "help": "The percentage of the train set used as validation set in case there's no validation split"
201
+ },
202
+ )
203
+ preprocessing_num_workers: Optional[int] = field(
204
+ default=None,
205
+ metadata={"help": "The number of processes to use for the preprocessing."},
206
+ )
207
+ keep_linebreaks: bool = field(
208
+ default=True, metadata={"help": "Whether to keep line breaks when using TXT files or not."}
209
+ )
210
+
211
+ def __post_init__(self):
212
+ if self.streaming:
213
+ require_version("datasets>=2.0.0", "The streaming feature requires `datasets>=2.0.0`")
214
+
215
+ if self.dataset_name is None and self.train_file is None and self.validation_file is None:
216
+ raise ValueError("Need either a dataset name or a training/validation file.")
217
+ else:
218
+ if self.train_file is not None:
219
+ extension = self.train_file.split(".")[-1]
220
+ assert extension in ["csv", "json", "txt"], "`train_file` should be a csv, a json or a txt file."
221
+ if self.validation_file is not None:
222
+ extension = self.validation_file.split(".")[-1]
223
+ assert extension in ["csv", "json", "txt"], "`validation_file` should be a csv, a json or a txt file."
224
+
225
+
226
+ def main():
227
+ # See all possible arguments in src/transformers/training_args.py
228
+ # or by passing the --help flag to this script.
229
+ # We now keep distinct sets of args, for a cleaner separation of concerns.
230
+ parser = HfArgumentParser((ModelArguments, DataTrainingArguments, TrainingArguments))
231
+ if len(sys.argv) == 2 and sys.argv[1].endswith(".json"):
232
+ # If we pass only one argument to the script and it's the path to a json file,
233
+ # let's parse it to get our arguments.
234
+ model_args, data_args, training_args = parser.parse_json_file(json_file=os.path.abspath(sys.argv[1]))
235
+ else:
236
+ if "--output_dir" not in sys.argv:
237
+ sys.argv.append("--output_dir")
238
+ sys.argv.append("./output")
239
+ model_args, data_args, training_args = parser.parse_args_into_dataclasses()
240
+
241
+ # by default we do both training and evaluation
242
+ training_args.do_train = True if "--do_train" not in sys.argv else training_args.do_train
243
+ training_args.do_eval = True if "--do_eval" not in sys.argv else training_args.do_eval
244
+ training_args.overwrite_output_dir = True if "--overwrite_output_dir" not in sys.argv else training_args.overwrite_output_dir
245
+ training_args.report_to = [] if "--report_to" not in sys.argv else training_args.report_to
246
+ training_args.log_level = "critical" if "--log_level" not in sys.argv else training_args.log_level
247
+ training_args.num_train_epochs = 1 if "--num_train_epochs" not in sys.argv else training_args.num_train_epochs
248
+ training_args.evaluation_strategy = "steps" if "--evaluation_strategy" not in sys.argv else training_args.evaluation_strategy
249
+ training_args.eval_steps = 0.2 if "--eval_steps" not in sys.argv else training_args.eval_steps
250
+ training_args.per_device_train_batch_size = 16 if "--per_device_train_batch_size" not in sys.argv else training_args.per_device_train_batch_size
251
+
252
+ # Sending telemetry. Tracking the example usage helps us better allocate resources to maintain them. The
253
+ # information sent is the one passed as arguments along with your Python/PyTorch versions.
254
+ send_example_telemetry("run_clm", model_args, data_args)
255
+
256
+ # Setup logging
257
+ logging.basicConfig(
258
+ format="%(asctime)s - %(levelname)s - %(name)s - %(message)s",
259
+ datefmt="%m/%d/%Y %H:%M:%S",
260
+ handlers=[logging.StreamHandler(sys.stdout)],
261
+ )
262
+
263
+ if training_args.should_log:
264
+ # The default of training_args.log_level is passive, so we set log level at info here to have that default.
265
+ transformers.utils.logging.set_verbosity_info()
266
+
267
+ log_level = training_args.get_process_log_level()
268
+ logger.setLevel(log_level)
269
+ datasets.utils.logging.set_verbosity(log_level)
270
+ transformers.utils.logging.set_verbosity(log_level)
271
+ transformers.utils.logging.enable_default_handler()
272
+ transformers.utils.logging.enable_explicit_format()
273
+
274
+ # Log on each process the small summary:
275
+ logger.warning(
276
+ f"Process rank: {training_args.local_rank}, device: {training_args.device}, n_gpu: {training_args.n_gpu}"
277
+ + f"distributed training: {bool(training_args.local_rank != -1)}, 16-bits training: {training_args.fp16}"
278
+ )
279
+ logger.info(f"Training/evaluation parameters {training_args}")
280
+
281
+ # Detecting last checkpoint.
282
+ last_checkpoint = None
283
+ if os.path.isdir(training_args.output_dir) and training_args.do_train and not training_args.overwrite_output_dir:
284
+ last_checkpoint = get_last_checkpoint(training_args.output_dir)
285
+ if last_checkpoint is None and len(os.listdir(training_args.output_dir)) > 0:
286
+ raise ValueError(
287
+ f"Output directory ({training_args.output_dir}) already exists and is not empty. "
288
+ "Use --overwrite_output_dir to overcome."
289
+ )
290
+ elif last_checkpoint is not None and training_args.resume_from_checkpoint is None:
291
+ logger.info(
292
+ f"Checkpoint detected, resuming training at {last_checkpoint}. To avoid this behavior, change "
293
+ "the `--output_dir` or add `--overwrite_output_dir` to train from scratch."
294
+ )
295
+
296
+ # Set seed before initializing model.
297
+ set_seed(training_args.seed)
298
+
299
+ # Get the datasets: you can either provide your own CSV/JSON/TXT training and evaluation files (see below)
300
+ # or just provide the name of one of the public datasets available on the hub at https://huggingface.co/datasets/
301
+ # (the dataset will be downloaded automatically from the datasets Hub).
302
+ #
303
+ # For CSV/JSON files, this script will use the column called 'text' or the first column if no column called
304
+ # 'text' is found. You can easily tweak this behavior (see below).
305
+ #
306
+ # In distributed training, the load_dataset function guarantees that only one local process can concurrently
307
+ # download the dataset.
308
+ if data_args.dataset_name is not None:
309
+ # Downloading and loading a dataset from the hub.
310
+ raw_datasets = load_dataset(
311
+ data_args.dataset_name,
312
+ data_args.dataset_config_name,
313
+ cache_dir=model_args.cache_dir,
314
+ use_auth_token=True if model_args.use_auth_token else None,
315
+ streaming=data_args.streaming,
316
+ )
317
+ if "validation" not in raw_datasets.keys():
318
+ raw_datasets["validation"] = load_dataset(
319
+ data_args.dataset_name,
320
+ data_args.dataset_config_name,
321
+ split=f"train[:{data_args.validation_split_percentage}%]",
322
+ cache_dir=model_args.cache_dir,
323
+ use_auth_token=True if model_args.use_auth_token else None,
324
+ streaming=data_args.streaming,
325
+ )
326
+ raw_datasets["train"] = load_dataset(
327
+ data_args.dataset_name,
328
+ data_args.dataset_config_name,
329
+ split=f"train[{data_args.validation_split_percentage}%:]",
330
+ cache_dir=model_args.cache_dir,
331
+ use_auth_token=True if model_args.use_auth_token else None,
332
+ streaming=data_args.streaming,
333
+ )
334
+ else:
335
+ data_files = {}
336
+ dataset_args = {}
337
+ if data_args.train_file is not None:
338
+ data_files["train"] = data_args.train_file
339
+ if data_args.validation_file is not None:
340
+ data_files["validation"] = data_args.validation_file
341
+ extension = (
342
+ data_args.train_file.split(".")[-1]
343
+ if data_args.train_file is not None
344
+ else data_args.validation_file.split(".")[-1]
345
+ )
346
+ if extension == "txt":
347
+ extension = "text"
348
+ dataset_args["keep_linebreaks"] = data_args.keep_linebreaks
349
+ raw_datasets = load_dataset(
350
+ extension,
351
+ data_files=data_files,
352
+ cache_dir=model_args.cache_dir,
353
+ use_auth_token=True if model_args.use_auth_token else None,
354
+ **dataset_args,
355
+ )
356
+ # If no validation data is there, validation_split_percentage will be used to divide the dataset.
357
+ if "validation" not in raw_datasets.keys():
358
+ raw_datasets["validation"] = load_dataset(
359
+ extension,
360
+ data_files=data_files,
361
+ split=f"train[:{data_args.validation_split_percentage}%]",
362
+ cache_dir=model_args.cache_dir,
363
+ use_auth_token=True if model_args.use_auth_token else None,
364
+ **dataset_args,
365
+ )
366
+ raw_datasets["train"] = load_dataset(
367
+ extension,
368
+ data_files=data_files,
369
+ split=f"train[{data_args.validation_split_percentage}%:]",
370
+ cache_dir=model_args.cache_dir,
371
+ use_auth_token=True if model_args.use_auth_token else None,
372
+ **dataset_args,
373
+ )
374
+
375
+ # See more about loading any type of standard or custom dataset (from files, python dict, pandas DataFrame, etc) at
376
+ # https://huggingface.co/docs/datasets/loading_datasets.html.
377
+
378
+ # Load pretrained model and tokenizer
379
+ #
380
+ # Distributed training:
381
+ # The .from_pretrained methods guarantee that only one local process can concurrently
382
+ # download model & vocab.
383
+
384
+ config_kwargs = {
385
+ "cache_dir": model_args.cache_dir,
386
+ "revision": model_args.model_revision,
387
+ "use_auth_token": True if model_args.use_auth_token else None,
388
+ }
389
+ if model_args.config_name:
390
+ config = AutoConfig.from_pretrained(model_args.config_name, **config_kwargs)
391
+ elif model_args.model_name_or_path:
392
+ config = AutoConfig.from_pretrained(model_args.model_name_or_path, **config_kwargs)
393
+ else:
394
+ config = CONFIG_MAPPING[model_args.model_type]()
395
+ logger.warning("You are instantiating a new config instance from scratch.")
396
+ if model_args.config_overrides is not None:
397
+ logger.info(f"Overriding config: {model_args.config_overrides}")
398
+ config.update_from_string(model_args.config_overrides)
399
+ logger.info(f"New config: {config}")
400
+
401
+ tokenizer_kwargs = {
402
+ "cache_dir": model_args.cache_dir,
403
+ "use_fast": model_args.use_fast_tokenizer,
404
+ "revision": model_args.model_revision,
405
+ "use_auth_token": True if model_args.use_auth_token else None,
406
+ }
407
+ if model_args.tokenizer_name:
408
+ tokenizer = AutoTokenizer.from_pretrained(model_args.tokenizer_name, **tokenizer_kwargs)
409
+ elif model_args.model_name_or_path:
410
+ tokenizer = AutoTokenizer.from_pretrained(model_args.model_name_or_path, **tokenizer_kwargs)
411
+ else:
412
+ raise ValueError(
413
+ "You are instantiating a new tokenizer from scratch. This is not supported by this script."
414
+ "You can do it from another script, save it, and load it from here, using --tokenizer_name."
415
+ )
416
+
417
+ if model_args.model_name_or_path:
418
+ torch_dtype = (
419
+ model_args.torch_dtype
420
+ if model_args.torch_dtype in ["auto", None]
421
+ else getattr(torch, model_args.torch_dtype)
422
+ )
423
+ model = AutoModelForCausalLM.from_pretrained(
424
+ model_args.model_name_or_path,
425
+ from_tf=bool(".ckpt" in model_args.model_name_or_path),
426
+ config=config,
427
+ cache_dir=model_args.cache_dir,
428
+ revision=model_args.model_revision,
429
+ use_auth_token=True if model_args.use_auth_token else None,
430
+ torch_dtype=torch_dtype,
431
+ low_cpu_mem_usage=model_args.low_cpu_mem_usage,
432
+ )
433
+ else:
434
+ model = AutoModelForCausalLM.from_config(config)
435
+ n_params = sum({p.data_ptr(): p.numel() for p in model.parameters()}.values())
436
+ logger.info(f"Training new model from scratch - Total size={n_params/2**20:.2f}M params")
437
+
438
+ # We resize the embeddings only when necessary to avoid index errors. If you are creating a model from scratch
439
+ # on a small vocab and want a smaller embedding size, remove this test.
440
+ embedding_size = model.get_input_embeddings().weight.shape[0]
441
+ if len(tokenizer) > embedding_size:
442
+ model.resize_token_embeddings(len(tokenizer))
443
+
444
+ # Preprocessing the datasets.
445
+ # First we tokenize all the texts.
446
+ if training_args.do_train:
447
+ column_names = list(raw_datasets["train"].features)
448
+ else:
449
+ column_names = list(raw_datasets["validation"].features)
450
+ text_column_name = "text" if "text" in column_names else column_names[0]
451
+
452
+ # since this will be pickled to avoid _LazyModule error in Hasher force logger loading before tokenize_function
453
+ tok_logger = transformers.utils.logging.get_logger("transformers.tokenization_utils_base")
454
+
455
+ def tokenize_function(examples):
456
+ with CaptureLogger(tok_logger) as cl:
457
+ output = tokenizer(examples[text_column_name])
458
+ # clm input could be much much longer than block_size
459
+ if "Token indices sequence length is longer than the" in cl.out:
460
+ tok_logger.warning(
461
+ "^^^^^^^^^^^^^^^^ Please ignore the warning above - this long input will be chunked into smaller bits"
462
+ " before being passed to the model."
463
+ )
464
+ return output
465
+
466
+ with training_args.main_process_first(desc="dataset map tokenization"):
467
+ if not data_args.streaming:
468
+ tokenized_datasets = raw_datasets.map(
469
+ tokenize_function,
470
+ batched=True,
471
+ num_proc=data_args.preprocessing_num_workers,
472
+ remove_columns=column_names,
473
+ load_from_cache_file=not data_args.overwrite_cache,
474
+ desc="Running tokenizer on dataset",
475
+ )
476
+ else:
477
+ tokenized_datasets = raw_datasets.map(
478
+ tokenize_function,
479
+ batched=True,
480
+ remove_columns=column_names,
481
+ )
482
+
483
+ if data_args.block_size is None:
484
+ block_size = tokenizer.model_max_length
485
+ if block_size > 1024:
486
+ logger.warning(
487
+ "The chosen tokenizer supports a `model_max_length` that is longer than the default `block_size` value"
488
+ " of 1024. If you would like to use a longer `block_size` up to `tokenizer.model_max_length` you can"
489
+ " override this default with `--block_size xxx`."
490
+ )
491
+ block_size = 1024
492
+ else:
493
+ if data_args.block_size > tokenizer.model_max_length:
494
+ logger.warning(
495
+ f"The block_size passed ({data_args.block_size}) is larger than the maximum length for the model"
496
+ f"({tokenizer.model_max_length}). Using block_size={tokenizer.model_max_length}."
497
+ )
498
+ block_size = min(data_args.block_size, tokenizer.model_max_length)
499
+
500
+ # Main data processing function that will concatenate all texts from our dataset and generate chunks of block_size.
501
+ def group_texts(examples):
502
+ # Concatenate all texts.
503
+ concatenated_examples = {k: list(chain(*examples[k])) for k in examples.keys()}
504
+ total_length = len(concatenated_examples[list(examples.keys())[0]])
505
+ # We drop the small remainder, and if the total_length < block_size we exclude this batch and return an empty dict.
506
+ # We could add padding if the model supported it instead of this drop, you can customize this part to your needs.
507
+ total_length = (total_length // block_size) * block_size
508
+ # Split by chunks of max_len.
509
+ result = {
510
+ k: [t[i : i + block_size] for i in range(0, total_length, block_size)]
511
+ for k, t in concatenated_examples.items()
512
+ }
513
+ result["labels"] = result["input_ids"].copy()
514
+ return result
515
+
516
+ # Note that with `batched=True`, this map processes 1,000 texts together, so group_texts throws away a remainder
517
+ # for each of those groups of 1,000 texts. You can adjust that batch_size here but a higher value might be slower
518
+ # to preprocess.
519
+ #
520
+ # To speed up this part, we use multiprocessing. See the documentation of the map method for more information:
521
+ # https://huggingface.co/docs/datasets/package_reference/main_classes.html#datasets.Dataset.map
522
+
523
+ with training_args.main_process_first(desc="grouping texts together"):
524
+ if not data_args.streaming:
525
+ lm_datasets = tokenized_datasets.map(
526
+ group_texts,
527
+ batched=True,
528
+ num_proc=data_args.preprocessing_num_workers,
529
+ load_from_cache_file=not data_args.overwrite_cache,
530
+ desc=f"Grouping texts in chunks of {block_size}",
531
+ )
532
+ else:
533
+ lm_datasets = tokenized_datasets.map(
534
+ group_texts,
535
+ batched=True,
536
+ )
537
+
538
+ if training_args.do_train:
539
+ if "train" not in tokenized_datasets:
540
+ raise ValueError("--do_train requires a train dataset")
541
+ train_dataset = lm_datasets["train"]
542
+ if data_args.max_train_samples is not None:
543
+ max_train_samples = min(len(train_dataset), data_args.max_train_samples)
544
+ train_dataset = train_dataset.select(range(max_train_samples))
545
+
546
+ if training_args.do_eval:
547
+ if "validation" not in tokenized_datasets:
548
+ raise ValueError("--do_eval requires a validation dataset")
549
+ eval_dataset = lm_datasets["validation"]
550
+ if data_args.max_eval_samples is not None:
551
+ max_eval_samples = min(len(eval_dataset), data_args.max_eval_samples)
552
+ eval_dataset = eval_dataset.select(range(max_eval_samples))
553
+
554
+ def preprocess_logits_for_metrics(logits, labels):
555
+ if isinstance(logits, tuple):
556
+ # Depending on the model and config, logits may contain extra tensors,
557
+ # like past_key_values, but logits always come first
558
+ logits = logits[0]
559
+ return logits.argmax(dim=-1)
560
+
561
+ metric = evaluate.load("accuracy")
562
+
563
+ def compute_metrics(eval_preds):
564
+ preds, labels = eval_preds
565
+ # preds have the same shape as the labels, after the argmax(-1) has been calculated
566
+ # by preprocess_logits_for_metrics but we need to shift the labels
567
+ labels = labels[:, 1:].reshape(-1)
568
+ preds = preds[:, :-1].reshape(-1)
569
+ return metric.compute(predictions=preds, references=labels)
570
+
571
+ # Initialize our Trainer
572
+ trainer = Trainer(
573
+ model=model,
574
+ args=training_args,
575
+ train_dataset=train_dataset if training_args.do_train else None,
576
+ eval_dataset=eval_dataset if training_args.do_eval else None,
577
+ tokenizer=tokenizer,
578
+ # Data collator will default to DataCollatorWithPadding, so we change it.
579
+ data_collator=default_data_collator,
580
+ compute_metrics=compute_metrics if training_args.do_eval and not is_torch_tpu_available() else None,
581
+ preprocess_logits_for_metrics=preprocess_logits_for_metrics
582
+ if training_args.do_eval and not is_torch_tpu_available()
583
+ else None,
584
+ )
585
+
586
+ transformers.utils.logging.set_verbosity(transformers.utils.logging.WARNING)
587
+
588
+ # Training
589
+ if training_args.do_train:
590
+ checkpoint = None
591
+ if training_args.resume_from_checkpoint is not None:
592
+ checkpoint = training_args.resume_from_checkpoint
593
+ elif last_checkpoint is not None:
594
+ checkpoint = last_checkpoint
595
+ train_result = trainer.train(resume_from_checkpoint=checkpoint)
596
+ trainer.save_model() # Saves the tokenizer too for easy upload
597
+
598
+ metrics = train_result.metrics
599
+
600
+ max_train_samples = (
601
+ data_args.max_train_samples if data_args.max_train_samples is not None else len(train_dataset)
602
+ )
603
+ metrics["train_samples"] = min(max_train_samples, len(train_dataset))
604
+
605
+ trainer.log_metrics("train", metrics)
606
+ trainer.save_metrics("train", metrics)
607
+ trainer.save_state()
608
+
609
+ # Evaluation
610
+ if training_args.do_eval:
611
+ logger.info("*** Evaluate ***")
612
+
613
+ metrics = trainer.evaluate()
614
+
615
+ max_eval_samples = data_args.max_eval_samples if data_args.max_eval_samples is not None else len(eval_dataset)
616
+ metrics["eval_samples"] = min(max_eval_samples, len(eval_dataset))
617
+ try:
618
+ perplexity = math.exp(metrics["eval_loss"])
619
+ except OverflowError:
620
+ perplexity = float("inf")
621
+ metrics["perplexity"] = perplexity
622
+
623
+ trainer.log_metrics("eval", metrics)
624
+ trainer.save_metrics("eval", metrics)
625
+
626
+ kwargs = {"finetuned_from": model_args.model_name_or_path, "tasks": "text-generation"}
627
+ if data_args.dataset_name is not None:
628
+ kwargs["dataset_tags"] = data_args.dataset_name
629
+ if data_args.dataset_config_name is not None:
630
+ kwargs["dataset_args"] = data_args.dataset_config_name
631
+ kwargs["dataset"] = f"{data_args.dataset_name} {data_args.dataset_config_name}"
632
+ else:
633
+ kwargs["dataset"] = data_args.dataset_name
634
+
635
+ if training_args.push_to_hub:
636
+ trainer.push_to_hub(**kwargs)
637
+ else:
638
+ trainer.create_model_card(**kwargs)
639
+
640
+ if __name__ == "__main__":
641
+ main()
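
For intuition, the group_texts step above concatenates all tokenized examples and slices them into fixed-size blocks, dropping the remainder; a toy demonstration of the same logic:

from itertools import chain

block_size = 4
examples = {"input_ids": [[1, 2, 3], [4, 5], [6, 7, 8, 9, 10]]}
concatenated = {k: list(chain(*v)) for k, v in examples.items()}
total_length = (len(concatenated["input_ids"]) // block_size) * block_size
blocks = [concatenated["input_ids"][i:i + block_size]
          for i in range(0, total_length, block_size)]
print(blocks)  # [[1, 2, 3, 4], [5, 6, 7, 8]] -- tokens 9 and 10 are dropped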
benchmarks/babylm/scripts/eval.py ADDED
@@ -0,0 +1,212 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #!/usr/bin/env python
2
+ # coding=utf-8
3
+ # Copyright 2020 The HuggingFace Inc. team. All rights reserved.
4
+ #
5
+ # Licensed under the Apache License, Version 2.0 (the "License");
6
+ # you may not use this file except in compliance with the License.
7
+ # You may obtain a copy of the License at
8
+ #
9
+ # http://www.apache.org/licenses/LICENSE-2.0
10
+ #
11
+ # Unless required by applicable law or agreed to in writing, software
12
+ # distributed under the License is distributed on an "AS IS" BASIS,
13
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14
+ # See the License for the specific language governing permissions and
15
+ # limitations under the License.
16
+ """
17
+ Fine-tuning the library models for causal language modeling (GPT, GPT-2, CTRL, ...) on a text file or a dataset.
18
+
19
+ Here is the full list of checkpoints on the hub that can be fine-tuned by this script:
20
+ https://huggingface.co/models?filter=text-generation
21
+ """
22
+ # You can also adapt this script on your own causal language modeling task. Pointers for this are left as comments.
23
+
24
+ import logging
25
+ import math
26
+ import os
27
+ # disable logging until training starts
28
+ os.environ['TF_CPP_MIN_LOG_LEVEL'] = '3'
29
+ import sys
30
+ from dataclasses import dataclass, field
31
+ from itertools import chain
32
+ from typing import Optional
33
+
34
+ import datasets
35
+ import evaluate
36
+ import torch
37
+ from datasets import load_dataset
38
+
39
+ import transformers
40
+ from transformers import (
41
+ CONFIG_MAPPING,
42
+ MODEL_FOR_CAUSAL_LM_MAPPING,
43
+ AutoConfig,
44
+ AutoModelForCausalLM,
45
+ AutoTokenizer,
46
+ HfArgumentParser,
47
+ Trainer,
48
+ TrainingArguments,
49
+ default_data_collator,
50
+ is_torch_tpu_available,
51
+ set_seed,
52
+ )
53
+ from transformers.testing_utils import CaptureLogger
54
+ from transformers.trainer_utils import get_last_checkpoint
55
+ from transformers.utils import check_min_version, send_example_telemetry
56
+ from transformers.utils.versions import require_version
57
+
63
+
64
+ logger = logging.getLogger(__name__)
65
+
66
+
67
+ def get_score(submission_folder = "../env"):
68
+ training_args = TrainingArguments("test_trainer")
69
+ training_args.report_to = []
70
+ raw_datasets = load_dataset(submission_folder + "/babyLM_for_hf.py", "babyLM-10M", split="test")
71
+ model = AutoModelForCausalLM.from_pretrained(submission_folder + "/output/")
72
+ tokenizer = AutoTokenizer.from_pretrained(submission_folder + "/output/")
73
+
74
+ # Preprocessing the datasets.
75
+ # First we tokenize all the texts.
76
+ column_names = list(raw_datasets.features)
77
+ text_column_name = "text" if "text" in column_names else column_names[0]
78
+
79
+ # since this will be pickled to avoid _LazyModule error in Hasher force logger loading before tokenize_function
80
+ tok_logger = transformers.utils.logging.get_logger("transformers.tokenization_utils_base")
81
+
82
+ def tokenize_function(examples):
83
+ with CaptureLogger(tok_logger) as cl:
84
+ output = tokenizer(examples[text_column_name])
85
+ # clm input could be much much longer than block_size
86
+ if "Token indices sequence length is longer than the" in cl.out:
87
+ tok_logger.warning(
88
+ "^^^^^^^^^^^^^^^^ Please ignore the warning above - this long input will be chunked into smaller bits"
89
+ " before being passed to the model."
90
+ )
91
+ return output
92
+
93
+ with training_args.main_process_first(desc="dataset map tokenization"):
94
+ # if not data_args.streaming:
95
+ # tokenized_datasets = raw_datasets.map(
96
+ # tokenize_function,
97
+ # batched=True,
98
+ # num_proc=data_args.preprocessing_num_workers,
99
+ # remove_columns=column_names,
100
+ # load_from_cache_file=not data_args.overwrite_cache,
101
+ # desc="Running tokenizer on dataset",
102
+ # )
103
+ # else:
104
+ tokenized_datasets = raw_datasets.map(
105
+ tokenize_function,
106
+ batched=True,
107
+ remove_columns=column_names,
108
+ )
109
+
110
+ # data_args is not available in this eval-only script, so the original else
111
+ # branch (which referenced it) was dead code; derive block_size from the
112
+ # tokenizer and cap it at 1024.
113
+ block_size = tokenizer.model_max_length
114
+ if block_size > 1024:
115
+ block_size = 1024
126
+
127
+ # Main data processing function that will concatenate all texts from our dataset and generate chunks of block_size.
128
+ def group_texts(examples):
129
+ # Concatenate all texts.
130
+ concatenated_examples = {k: list(chain(*examples[k])) for k in examples.keys()}
131
+ total_length = len(concatenated_examples[list(examples.keys())[0]])
132
+ # We drop the small remainder, and if the total_length < block_size we exclude this batch and return an empty dict.
133
+ # We could add padding if the model supported it instead of this drop, you can customize this part to your needs.
134
+ total_length = (total_length // block_size) * block_size
135
+ # Split by chunks of max_len.
136
+ result = {
137
+ k: [t[i : i + block_size] for i in range(0, total_length, block_size)]
138
+ for k, t in concatenated_examples.items()
139
+ }
140
+ result["labels"] = result["input_ids"].copy()
141
+ return result
142
+
143
+ # Note that with `batched=True`, this map processes 1,000 texts together, so group_texts throws away a remainder
144
+ # for each of those groups of 1,000 texts. You can adjust that batch_size here but a higher value might be slower
145
+ # to preprocess.
146
+ #
147
+ # To speed up this part, we use multiprocessing. See the documentation of the map method for more information:
148
+ # https://huggingface.co/docs/datasets/package_reference/main_classes.html#datasets.Dataset.map
149
+
150
+ with training_args.main_process_first(desc="grouping texts together"):
151
+ # if not data_args.streaming:
152
+ # lm_datasets = tokenized_datasets.map(
153
+ # group_texts,
154
+ # batched=True,
155
+ # num_proc=data_args.preprocessing_num_workers,
156
+ # load_from_cache_file=not data_args.overwrite_cache,
157
+ # desc=f"Grouping texts in chunks of {block_size}",
158
+ # )
159
+ # else:
160
+ lm_datasets = tokenized_datasets.map(
161
+ group_texts,
162
+ batched=True,
163
+ )
164
+ eval_dataset = lm_datasets
165
+
166
+ def preprocess_logits_for_metrics(logits, labels):
167
+ if isinstance(logits, tuple):
168
+ # Depending on the model and config, logits may contain extra tensors,
169
+ # like past_key_values, but logits always come first
170
+ logits = logits[0]
171
+ return logits.argmax(dim=-1)
172
+
173
+ metric = evaluate.load("accuracy")
174
+
175
+ def compute_metrics(eval_preds):
176
+ preds, labels = eval_preds
177
+ # preds have the same shape as the labels, after the argmax(-1) has been calculated
178
+ # by preprocess_logits_for_metrics but we need to shift the labels
179
+ labels = labels[:, 1:].reshape(-1)
180
+ preds = preds[:, :-1].reshape(-1)
181
+ return metric.compute(predictions=preds, references=labels)
182
+
183
+ # Initialize our Trainer
184
+ trainer = Trainer(
185
+ model=model,
186
+ args=training_args,
187
+ train_dataset=None,
188
+ eval_dataset=eval_dataset,
189
+ tokenizer=tokenizer,
190
+ # Data collator will default to DataCollatorWithPadding, so we change it.
191
+ data_collator=default_data_collator,
192
+ compute_metrics=compute_metrics,
193
+ preprocess_logits_for_metrics=preprocess_logits_for_metrics,
194
+ )
195
+
196
+ transformers.utils.logging.set_verbosity(transformers.utils.logging.WARNING)
197
+
198
+ # Evaluation
199
+ metrics = trainer.evaluate()
200
+
201
+ max_eval_samples = len(eval_dataset)
202
+ metrics["eval_samples"] = min(max_eval_samples, len(eval_dataset))
203
+ try:
204
+ perplexity = math.exp(metrics["eval_loss"])
205
+ except OverflowError:
206
+ perplexity = float("inf")
207
+ metrics["perplexity"] = perplexity
208
+
209
+ return perplexity
210
+
211
+ if __name__ == "__main__":
212
+ print(get_score())
benchmarks/babylm/scripts/prepare.py ADDED
@@ -0,0 +1,11 @@
+ import subprocess
+
+ download_dir = "../env"
+
+ # Download the babyLM data, unpack it into the env folder, and remove the
+ # archive plus the splits and artifacts the benchmark does not use.
+ subprocess.run(["wget", "https://github.com/babylm/babylm.github.io/raw/main/babylm_data.zip"], cwd=download_dir)
+ subprocess.run(["unzip", "-n", "babylm_data.zip"], cwd=download_dir)
+ subprocess.run(["rm", "babylm_data.zip"], cwd=download_dir)
+ subprocess.run(["rm", "-rf", "babylm_data/babylm_100M"], cwd=download_dir)
+ subprocess.run(["rm", "-rf", "__MACOSX"], cwd=download_dir)
benchmarks/babylm/scripts/read_only_files.txt ADDED
@@ -0,0 +1,2 @@
1
+ babylm_data/*
2
+ __MACOSX/*
benchmarks/babylm/scripts/research_problem.txt ADDED
@@ -0,0 +1,7 @@
1
+ Improve the baseline model performance on the babyLM Benchmark.
2
+
3
+ Summary: This shared task challenges community members to train a language model **from scratch** on the same amount of linguistic data available to a child. Submissions should be implemented in Huggingface's Transformers library and will be evaluated on a shared pipeline. This shared task is co-sponsored by CMCL and CoNLL.
4
+
5
+ To run the baseline model, execute train.py. It will train a standard gpt2 model on the babyLM data. The final model will be saved to the output/ folder.
6
+
7
+ When you submit your final answer, you will be evaluated on the performance of the checkpoint saved in the output folder. It will be evaluated on a held-out test set.
benchmarks/bibtex-generation/env/arxiv_API_reference.txt ADDED
@@ -0,0 +1,599 @@
1
+ arXiv API User's Manual
2
+ Please review the Terms of Use for arXiv APIs before using the arXiv API.
3
+
4
+ Table of Contents
5
+ 1. Preface
6
+ 2. API QuickStart
7
+ 3. Structure of the API
8
+ 3.1. Calling the API
9
+ 3.1.1. Query Interface
10
+ 3.1.1.1. search_query and id_list logic
11
+ 3.1.1.2. start and max_results paging
12
+ 3.1.1.3. sort order for return results
13
+ 3.2. The API Response
14
+ 3.3. Outline of an Atom feed
15
+ 3.3.1. Feed Metadata
16
+ 3.3.1.1. <title>, <id>, <link> and <updated>
17
+ 3.3.1.2. OpenSearch Extension Elements
18
+ 3.3.2. Entry Metadata
19
+ 3.3.2.1. <title>, <id>, <published>, and <updated>
20
+ 3.3.2.2. <summary>, <author> and <category>
21
+ 3.3.2.3. <link>'s
22
+ 3.3.2.4. <arxiv> extension elements
23
+ 3.4. Errors
24
+ 4. Examples
25
+ 4.1. Simple Examples
26
+ 4.1.1. Perl
27
+ 4.1.2. Python
28
+ 4.1.3. Ruby
29
+ 4.1.4. PHP
30
+ 4.2. Detailed Parsing Examples
31
+ 5. Appendices
32
+ 5.1. Details of Query Construction
33
+ 5.1.1. A Note on Article Versions
34
+ 5.2. Details of Atom Results Returned
35
+ 5.3. Subject Classifications
36
+
37
+
38
+ 1. Preface
39
+ The arXiv API allows programmatic access to the hundreds of thousands of e-prints hosted on arXiv.org.
40
+
41
+ This manual is meant to provide an introduction to using the API, as well as documentation describing its details, and as such is meant to be read by both beginning and advanced users. To get a flavor for how the API works, see the API Quickstart. For more detailed information, see Structure of the API.
42
+
43
+ For examples of using the API from several popular programming languages including perl, python and ruby, see the Examples section.
44
+
45
+ Finally, the Appendices contain an explanation of all input parameters to the API, as well as the output format.
46
+
47
+
48
+ 2. API QuickStart
49
+ The easiest place to start with the API is by accessing it through a web browser. For examples of accessing the API through common programming languages, see the Examples section.
50
+
51
+ Most everyone that has read or submitted e-prints on the arXiv is familiar with the arXiv human web interface. These HTML pages can be accessed by opening up your web browser, and entering the following url in your web browser
52
+
53
+ http://arxiv.org
54
+
55
+ From there, the article listings can be browsed by clicking on one of the many links, or you can search for articles using the search box in the upper right hand side of the page. For example, if I wanted to search for articles that contain the word electron in the title or abstract, I would type electron in the search box, and click Go. If you follow my example, you will see something like this: a web page listing the title and authors of each result, with links to the abstract page, pdf, etc.
56
+
57
+ In its simplest form, the API can be used in exactly the same way. However, it uses a few shortcuts so there is less clicking involved. For example, you can see the same search results for electron by entering the url
58
+
59
+ http://export.arxiv.org/api/query?search_query=all:electron.
60
+
61
+ Alternatively, you can search for articles that contain electron AND proton with the API by entering
62
+
63
+ http://export.arxiv.org/api/query?search_query=all:electron+AND+all:proton
64
+
65
+ What you see will look different from the HTML interface, but it contains the same information as the search done with the human interface. The reason why the results look different is that the API returns results in the Atom 1.0 format, and not HTML. Since Atom is defined as an XML grammar, it is much easier to digest for programs than HTML. The API is not intended to be used inside a web browser by itself, but this is a particularly simple way to debug a program that does use the API.
66
+
67
+ You might notice that your web browser has asked you if you want to “subscribe to this feed” after you enter the API url. This is because Atom is one of the formats used by web sites to syndicate their content. These feeds are usually read with feed reader software, and are what is generated by the existing arXiv rss feeds. The current arXiv feeds only give you updates on new papers within the category you specify. One immediately useful thing to do with the API then is to generate your own feed, based on a custom query!
68
+
69
+ To learn more about how to construct custom search queries with the API, see the appendix on the details of query construction. To learn about what information is returned by the API, see the section on the API response. To learn more about writing programs to call the API, and digest the responses, we suggest starting with the section on Structure of the API.
70
+
71
+
72
+ 3. Structure of the API
73
+ In this section, we'll go over some of the details of interacting with the API. A diagram of a typical API call is shown below:
74
+
75
+ Example: A typical API call
76
+
77
+
78
+ Request from url: http://export.arxiv.org/api/query (1)
79
+ with parameters: search_query=all:electron
80
+ .
81
+ .
82
+ .
83
+ API server processes the request and sends the response
84
+ .
85
+ .
86
+ .
87
+ Response received by client. (2)
88
+ The request can be made via HTTP GET, in which the parameters are encoded in the url, or via an HTTP POST in which the parameters are encoded in the HTTP request body. Most client libraries support both methods.
89
+
90
+ If all goes well, the HTTP header will show a 200 OK status, and the response body will contain the Atom response content as shown in the example response.
91
+
92
+
93
+ 3.1. Calling the API
94
+ As mentioned above, the API can be called with an HTTP request of type GET or POST. For our purposes, the main difference is that the parameters are included in the url for a GET request, but not for the POST request. Thus if the parameters list is unusually long, a POST request might be preferred.
95
+
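+ For illustration, here is a minimal Python 3 sketch (standard library only; the parameter names are the documented ones above) making the same call once as a GET and once as a POST:
+
+ import urllib.parse, urllib.request
+
+ params = urllib.parse.urlencode({"search_query": "all:electron", "start": 0, "max_results": 1})
+ base = "http://export.arxiv.org/api/query"
+
+ # GET: the parameters ride along in the url itself.
+ with urllib.request.urlopen(base + "?" + params) as r:
+     atom_via_get = r.read()
+
+ # POST: the same parameters travel in the request body instead.
+ with urllib.request.urlopen(base, data=params.encode("utf-8")) as r:
+     atom_via_post = r.read()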
96
+ The parameters for each of the API methods are explained below. For each method, the base url is
97
+
98
+
99
+ http://export.arxiv.org/api/{method_name}?{parameters}
100
+
101
+ 3.1.1. Query Interface
102
+ The API query interface has method_name=query. The table below outlines the parameters that can be passed to the query interface. Parameters are separated with the & sign in the constructed url's.
103
+
104
+ query
105
+ parameters type defaults required
106
+ search_query string None No
107
+ id_list comma-delimited string None No
108
+ start int 0 No
109
+ max_results int 10 No
110
+
111
+ 3.1.1.1. SEARCH_QUERY AND ID_LIST LOGIC
112
+ We have already seen the use of search_query in the quickstart section. The search_query takes a string that represents a search query used to find articles. The construction of search_query is described in the search query construction appendix. The id_list contains a comma-delimited list of arXiv id's.
113
+
114
+ The logic of these two parameters is as follows:
115
+
116
+ If only search_query is given (id_list is blank or not given), then the API will return results for each article that matches the search query.
117
+
118
+ If only id_list is given (search_query is blank or not given), then the API will return results for each article in id_list.
119
+
120
+ If BOTH search_query and id_list are given, then the API will return each article in id_list that matches search_query. This allows the API to act as a results filter.
121
+
122
+ This is summarized in the following table:
123
+
124
+ search_query present id_list present API returns
125
+ yes no articles that match search_query
126
+ no yes articles that are in id_list
127
+ yes yes articles in id_list that also match search_query
128
+
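+ As a sketch of this filtering behaviour in Python 3 (the two ids below are the example articles used elsewhere in this manual), only the listed ids that also match the search_query come back:
+
+ import urllib.parse, urllib.request
+
+ params = urllib.parse.urlencode({
+     "search_query": "all:electron",
+     "id_list": "hep-ex/0307015,cond-mat/0207270",
+ })
+ with urllib.request.urlopen("http://export.arxiv.org/api/query?" + params) as r:
+     print(r.read().decode("utf-8"))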
129
+ 3.1.1.2. START AND MAX_RESULTS PAGING
130
+ Many times there are hundreds of results for an API query. Rather than download information about all the results at once, the API offers a paging mechanism through start and max_results that allows you to download chunks of the result set at a time. Within the total results set, start defines the index of the first returned result, using 0-based indexing. max_results is the number of results returned by the query. For example, if we wanted to step through the results of a search_query of all:electron, we would construct the urls:
131
+
132
+
133
+ http://export.arxiv.org/api/query?search_query=all:electron&start=0&max_results=10 (1)
134
+ http://export.arxiv.org/api/query?search_query=all:electron&start=10&max_results=10 (2)
135
+ http://export.arxiv.org/api/query?search_query=all:electron&start=20&max_results=10 (3)
136
+ Get results 0-9
137
+
138
+ Get results 10-19
139
+
140
+ Get results 20-29
141
+
142
+ Detailed examples of how to perform paging in a variety of programming languages can be found in the examples section.
143
+
144
+ In cases where the API needs to be called multiple times in a row, we encourage you to play nice and incorporate a 3 second delay in your code. The detailed examples below illustrate how to do this in a variety of languages.
145
+
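+ A minimal Python 3 sketch of polite paging, with the 3 second pause suggested above:
+
+ import time
+ import urllib.parse, urllib.request
+
+ for start in (0, 10, 20):
+     query = urllib.parse.urlencode(
+         {"search_query": "all:electron", "start": start, "max_results": 10}
+     )
+     with urllib.request.urlopen("http://export.arxiv.org/api/query?" + query) as r:
+         page = r.read()  # one Atom feed holding results start .. start+9
+     time.sleep(3)  # play nice between consecutive calls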
146
+ Because of speed limitations in our implementation of the API, the maximum number of results returned from a single call (max_results) is limited to 30000 in slices of at most 2000 at a time, using the max_results and start query parameters. For example, to retrieve matches 6001-8000: http://export.arxiv.org/api/query?search_query=all:electron&start=6000&max_results=2000
147
+
148
+ Large result sets put considerable load on the server and also take a long time to render. We recommend refining queries which return more than 1,000 results, or at least requesting smaller slices. For bulk metadata harvesting or set information, etc., the OAI-PMH interface is more suitable. A request with max_results >30,000 will result in an HTTP 400 error code with appropriate explanation. A request for 30000 results will typically take a little over 2 minutes to return a response of over 15MB. Requests for fewer results are much faster and correspondingly smaller.
149
+
150
+
151
+ 3.1.1.3. SORT ORDER FOR RETURN RESULTS
152
+ There are two options for ordering the result set returned by the API search, sortBy and sortOrder.
153
+
154
+ sortBy can be "relevance", "lastUpdatedDate", "submittedDate"
155
+
156
+ sortOrder can be either "ascending" or "descending"
157
+
158
+ A sample query using these new parameters looks like:
159
+
160
+
161
+ http://export.arxiv.org/api/query?search_query=ti:"electron thermal conductivity"&sortBy=lastUpdatedDate&sortOrder=ascending
162
+
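+ Note that the spaces and double quotes in the url above must be escaped before the request is sent. A small Python 3 sketch that builds the same request with proper escaping:
+
+ import urllib.parse
+
+ params = urllib.parse.urlencode({
+     "search_query": 'ti:"electron thermal conductivity"',
+     "sortBy": "lastUpdatedDate",
+     "sortOrder": "ascending",
+ })
+ url = "http://export.arxiv.org/api/query?" + params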
163
+ 3.2. The API Response
164
+ Everything returned by the API in the body of the HTTP responses is Atom 1.0, including errors. Atom is a grammar of XML that is popular in the world of content syndication, and is very similar to RSS for this purpose. Typically web sites with dynamic content such as news sites and blogs will publish their content as Atom or RSS feeds. However, Atom is a general format that embodies the concept of a list of items, and thus is well-suited to returning the arXiv search results.
165
+
166
+
167
+ 3.3. Outline of an Atom feed
168
+ In this section we will discuss the contents of the Atom documents returned by the API. To see the full explanation of the Atom 1.0 format, please see the Atom specification.
169
+
170
+ An API response consists of an Atom <feed> element which contains metadata about the API call performed, as well as child <entry> elements which embody the metadata for each of the returned results. Below we explain each of the elements and attributes. We will base our discussion on the sample results feed discussed in the examples section.
171
+
172
+ You may notice that the results from the API are ordered differently than the results given by the HTML arXiv search interface. The HTML interface automatically sorts results in descending order based on the date of their submission, while the API returns results according to relevancy from the internal search engine. Thus when debugging a search query, we encourage you to use the API within a web browser, rather than the HTML search interface. If you want sorting by date, you can always do this within your programs by reading the <published> tag for each entry as explained below.
173
+
174
+
175
+ 3.3.1. Feed Metadata
176
+ Every response will contain the line:
177
+
178
+
179
+ <?xml version="1.0" encoding="utf-8"?>
180
+ to signify that we are receiving XML 1.0 with a UTF-8 encoding. Following that line will be a line indicating that we are receiving an Atom feed:
181
+
182
+
183
+ <feed xmlns="http://www.w3.org/2005/Atom"
184
+ xmlns:opensearch="http://a9.com/-/spec/opensearch/1.1/"
185
+ xmlns:arxiv="http://arxiv.org/schemas/atom">
186
+ You will notice that three XML namespaces are defined. The default namespace signifies that we are dealing with Atom 1.0. The other two namespaces define extensions to Atom that we describe below.
187
+
188
+
189
+ 3.3.1.1. <TITLE>, <ID>, <LINK> AND <UPDATED>
190
+ The <title> element gives the title for the feed:
191
+
192
+
193
+ <title xmlns="http://www.w3.org/2005/Atom">
194
+ ArXiv Query: search_query=all:electron&amp;id_list=&amp;start=0&amp;max_results=1
195
+ </title>
196
+ The title contains a canonicalized version of the query used to call the API. The canonicalization includes all parameters, using their defaults if they were not included, and always puts them in the order search_query,id_list,start,max_results, even if they were specified in a different order in the actual query.
197
+
198
+ The <id> element serves as a unique id for this query, and is useful if you are writing a program such as a feed reader that wants to keep track of all the feeds requested in the past. This id can then be used as a key in a database.
199
+
200
+
201
+ <id xmlns="http://www.w3.org/2005/Atom">
202
+ http://arxiv.org/api/cHxbiOdZaP56ODnBPIenZhzg5f8
203
+ </id>
204
+ The id is guaranteed to be unique for each query.
205
+
206
+ The <link> element provides a URL that can be used to retrieve this feed again.
207
+
208
+
209
+ <link xmlns="http://www.w3.org/2005/Atom" href="http://arxiv.org/api/query?search_query=all:electron&amp;id_list=&amp;start=0&amp;max_results=1" rel="self" type="application/atom+xml"/>
210
+ Note that the url in the link represents the canonicalized version of the query. The <link> provides a GET requestable url, even if the original request was done via POST.
211
+
212
+ The <updated> element provides the last time the contents of the feed were last updated:
213
+
214
+
215
+ <updated xmlns="http://www.w3.org/2005/Atom">2007-10-08T00:00:00-04:00</updated>
216
+ Because the arXiv submission process works on a 24 hour submission cycle, new articles are only available to the API on the midnight after the articles were processed. The <updated> tag thus reflects the midnight of the day that you are calling the API. This is very important - search results do not change until new articles are added. Therefore there is no need to call the API more than once in a day for the same query. Please cache your results. This primarily applies to production systems, and of course you are free to play around with the API while you are developing your program!
217
+
218
+
219
+ 3.3.1.2. OPENSEARCH EXTENSION ELEMENTS
220
+ There are several extension elements defined in the OpenSearch namespace
221
+
222
+
223
+ http://a9.com/-/spec/opensearch/1.1/
224
+ OpenSearch is a lightweight technology that acts in a similar way to the Web Services Description Language. The OpenSearch elements we have included allow OpenSearch-enabled clients to digest our results. Such clients often include search result aggregators and browser plugins that allow searching from a variety of sources.
225
+
226
+ The OpenSearch extension elements can still be useful to you even if you are not writing one of these applications. The <opensearch:totalResults> element lists how many results are in the result set for the query:
227
+
228
+
229
+ <opensearch:totalResults xmlns:opensearch="http://a9.com/-/spec/opensearch/1.1/">
230
+ 1000
231
+ </opensearch:totalResults>
232
+ This can be very useful when implementing paging of search results. The other two elements <opensearch:startIndex>, and <opensearch:itemsPerPage> are analogous to start, and max_results discussed above.
233
+
234
+
235
+ <opensearch:startIndex xmlns:opensearch="http://a9.com/-/spec/opensearch/1.1/">
236
+ 0
237
+ </opensearch:startIndex>
238
+ <opensearch:itemsPerPage xmlns:opensearch="http://a9.com/-/spec/opensearch/1.1/">
239
+ 1
240
+ </opensearch:itemsPerPage>
241
+
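+ A namespace-aware feed library exposes these elements directly. For example, with the feedparser library (the same one used in the detailed Python examples referenced below), a sketch like this should print the three values:
+
+ import feedparser
+
+ d = feedparser.parse("http://export.arxiv.org/api/query?search_query=all:electron&start=0&max_results=1")
+ print(d.feed.opensearch_totalresults)  # e.g. 1000
+ print(d.feed.opensearch_startindex)    # 0
+ print(d.feed.opensearch_itemsperpage)  # 1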
242
+ 3.3.2. Entry Metadata
243
+ If there are no errors, the <feed> element contains 0 or more child <entry> elements, with each <entry> representing an article in the returned results set. As explained in the errors section, if there are errors, a single <entry> element representing the error is returned. Below we describe the elements of <entry>'s representing arXiv articles. For a general discussion of arXiv metadata, see the arXiv metadata explanation.
244
+
245
+
246
+ 3.3.2.1. <TITLE>, <ID>, <PUBLISHED>, AND <UPDATED>
247
+ The <title> element contains the title of the article returned:
248
+
249
+
250
+ <title xmlns="http://www.w3.org/2005/Atom">
251
+ Multi-Electron Production at High Transverse Momenta in ep Collisions at HERA
252
+ </title>
253
+ The <id> element contains a url that resolves to the abstract page for that article:
254
+
255
+
256
+ <id xmlns="http://www.w3.org/2005/Atom">
257
+ http://arxiv.org/abs/hep-ex/0307015
258
+ </id>
259
+ If you want only the arXiv id for the article, you can remove the leading http://arxiv.org/abs/ in the <id>.
260
+
261
+ The <published> tag contains the date on which the first version of this article was submitted and processed. The <updated> element contains the date on which the retrieved article was submitted and processed. If the version is version 1, then <published> == <updated>; otherwise they are different. In the example below, the article retrieved was version 2, so <updated> and <published> are different (see the original query).
262
+
263
+
264
+ <published xmlns="http://www.w3.org/2005/Atom">
265
+ 2007-02-27T16:02:02-05:00
266
+ </published>
267
+ <updated xmlns="http://www.w3.org/2005/Atom">
268
+ 2007-06-25T17:09:59-04:00
269
+ </updated>
270
+
271
+ 3.3.2.2. <SUMMARY>, <AUTHOR> AND <CATEGORY>
272
+ The <summary> element contains the abstract for the article:
273
+
274
+
275
+ <summary xmlns="http://www.w3.org/2005/Atom">
276
+ Multi-electron production is studied at high electron transverse momentum
277
+ in positron- and electron-proton collisions using the H1 detector at HERA.
278
+ The data correspond to an integrated luminosity of 115 pb-1. Di-electron
279
+ and tri-electron event yields are measured. Cross sections are derived in
280
+ a restricted phase space region dominated by photon-photon collisions. In
281
+ general good agreement is found with the Standard Model predictions.
282
+ However, for electron pair invariant masses above 100 GeV, three
283
+ di-electron events and three tri-electron events are observed, compared to
284
+ Standard Model expectations of 0.30 \pm 0.04 and 0.23 \pm 0.04,
285
+ respectively.
286
+ </summary>
287
+ There is one <author> element for each author of the paper in order of authorship. Each <author> element has a <name> sub-element which contains the name of the author.
288
+
289
+
290
+ <author xmlns="http://www.w3.org/2005/Atom">
291
+ <name xmlns="http://www.w3.org/2005/Atom">H1 Collaboration</name>
292
+ </author>
293
+ If author affiliation is present, it is included as an <arxiv:affiliation> subelement of the <author> element as discussed below.
294
+
295
+ The <category> element is used to describe either an arXiv, ACM, or MSC classification. See the arXiv metadata explanation for more details about these classifications. The <category> element has two attributes, scheme, which is the categorization scheme, and term which is the term used in the categorization. Here is an example from the query http://export.arxiv.org/api/query?id_list=cs/9901002v1
296
+
297
+
298
+ <category xmlns="http://www.w3.org/2005/Atom" term="cs.LG" scheme="http://arxiv.org/schemas/atom"/>
299
+ <category xmlns="http://www.w3.org/2005/Atom" term="cs.AI" scheme="http://arxiv.org/schemas/atom"/>
300
+ <category xmlns="http://www.w3.org/2005/Atom" term="I.2.6" scheme="http://arxiv.org/schemas/atom"/>
301
+ Note that in this example, there are 3 category elements, one for each category. The first two correspond to arXiv categories, and the last one to an ACM category. See <arxiv> extension elements below for information on how to identify the arXiv primary category.
302
+
303
+
304
+ 3.3.2.3. <LINK>'S
305
+ For each entry, there are up to three <link> elements, distinguished by their rel and title attributes. The table below summarizes what these links refer to
306
+
307
+ rel title refers to always present
308
+ alternate - abstract page yes
309
+ related pdf pdf yes
310
+ related doi resolved doi no
311
+ For example:
312
+
313
+
314
+ <link xmlns="http://www.w3.org/2005/Atom" href="http://arxiv.org/abs/hep-ex/0307015v1" rel="alternate" type="text/html"/>
315
+ <link xmlns="http://www.w3.org/2005/Atom" title="pdf" href="http://arxiv.org/pdf/hep-ex/0307015v1" rel="related" type="application/pdf"/>
316
+ <link xmlns="http://www.w3.org/2005/Atom" title="doi" href="http://dx.doi.org/10.1529/biophysj.104.047340" rel="related"/>
317
+
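+ As a sketch, picking these urls out of a parsed entry with feedparser (and recovering the bare arXiv id from <id> as described in 3.3.2.1) could look like:
+
+ import feedparser
+
+ d = feedparser.parse("http://export.arxiv.org/api/query?id_list=hep-ex/0307015")
+ entry = d.entries[0]
+ arxiv_id = entry.id.split("/abs/")[-1]  # e.g. "hep-ex/0307015v1"
+ for link in entry.links:
+     if link.rel == "alternate":
+         abs_url = link.href            # abstract page
+     elif link.get("title") == "pdf":
+         pdf_url = link.href            # pdf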
318
+ 3.3.2.4. <ARXIV> EXTENSION ELEMENTS
319
+ There are several pieces of arXiv metadata that are not able to be mapped onto the standard Atom specification. We have therefore defined several extension elements which live in the arxiv namespace
320
+
321
+
322
+ http://arxiv.org/schemas/atom
323
+ The arXiv classification system supports multiple <category> tags, as well as a primary classification. The primary classification is a replica of an Atom <category> tag, except it has the name <arxiv:primary_category>. For example, from the query http://export.arxiv.org/api/query?id_list=cs/9901002v1, we have
324
+
325
+
326
+ <arxiv:primary_category xmlns:arxiv="http://arxiv.org/schemas/atom" term="cs.LG" scheme="http://arxiv.org/schemas/atom"/>
327
+ signifying that cs.LG is the primary arXiv classification for this e-print.
328
+
329
+ The <arxiv:comment> element contains the typical author comments found on most arXiv articles:
330
+
331
+
332
+ <arxiv:comment xmlns:arxiv="http://arxiv.org/schemas/atom">
333
+ 23 pages, 8 figures and 4 tables
334
+ </arxiv:comment>
335
+ If the author has supplied affiliation information, then this is included as an <arxiv:affiliation> subelement of the standard Atom <author> element. For example, from the query http://export.arxiv.org/api/query?id_list=0710.5765v1, we have
336
+
337
+
338
+ <author>
339
+ <name>G. G. Kacprzak</name>
340
+ <arxiv:affiliation xmlns:arxiv="http://arxiv.org/schemas/atom">NMSU</arxiv:affiliation>
341
+ </author>
342
+ If the author has provided a journal reference for the article, then there will be a <arxiv:journal_ref> element with this information:
343
+
344
+
345
+ <arxiv:journal_ref xmlns:arxiv="http://arxiv.org/schemas/atom">
346
+ Eur.Phys.J. C31 (2003) 17-29
347
+ </arxiv:journal_ref>
348
+ If the author has provided a DOI for the article, then there will be a <arxiv:doi> element with this information:
349
+
350
+
351
+ <arxiv:doi xmlns:arxiv="http://arxiv.org/schemas/atom">
352
+ 10.1529/biophysj.104.047340
353
+ </arxiv:doi>
354
+
355
+ 3.4. Errors
356
+ Errors are returned as Atom feeds with a single entry representing the error. The <summary> for the error contains a helpful error message, and the <link> element contains a url to a more detailed explanation of the message.
357
+
358
+ For example, the API call http://export.arxiv.org/api/query?id_list=1234.12345 contains a malformed id, and results in the error
359
+
360
+
361
+ <?xml version="1.0" encoding="utf-8"?>
362
+ <feed xmlns="http://www.w3.org/2005/Atom" xmlns:opensearch="http://a9.com/-/spec/opensearch/1.1/">
363
+ <link xmlns="http://www.w3.org/2005/Atom" href="http://arxiv.org/api/query?search_query=&amp;id_list=1234.12345" rel="self" type="application/atom+xml"/>
364
+ <title xmlns="http://www.w3.org/2005/Atom">ArXiv Query: search_query=&amp;id_list=1234.12345</title>
365
+ <id xmlns="http://www.w3.org/2005/Atom">http://arxiv.org/api/kvuntZ8c9a4Eq5CF7KY03nMug+Q</id>
366
+ <updated xmlns="http://www.w3.org/2005/Atom">2007-10-12T00:00:00-04:00</updated>
367
+ <opensearch:totalResults xmlns:opensearch="http://a9.com/-/spec/opensearch/1.1/">1</opensearch:totalResults>
368
+ <opensearch:startIndex xmlns:opensearch="http://a9.com/-/spec/opensearch/1.1/">0</opensearch:startIndex>
369
+
370
+ <opensearch:itemsPerPage xmlns:opensearch="http://a9.com/-/spec/opensearch/1.1/">1</opensearch:itemsPerPage>
371
+ <entry xmlns="http://www.w3.org/2005/Atom">
372
+ <id xmlns="http://www.w3.org/2005/Atom">http://arxiv.org/api/errors#incorrect_id_format_for_1234.12345</id>
373
+ <title xmlns="http://www.w3.org/2005/Atom">Error</title>
374
+ <summary xmlns="http://www.w3.org/2005/Atom">incorrect id format for 1234.12345</summary>
375
+ <updated xmlns="http://www.w3.org/2005/Atom">2007-10-12T00:00:00-04:00</updated>
376
+
377
+ <link xmlns="http://www.w3.org/2005/Atom" href="http://arxiv.org/api/errors#incorrect_id_format_for_1234.12345" rel="alternate" type="text/html"/>
378
+ <author xmlns="http://www.w3.org/2005/Atom">
379
+ <name xmlns="http://www.w3.org/2005/Atom">arXiv api core</name>
380
+ </author>
381
+ </entry>
382
+ </feed>
383
+ The following table gives information on errors that might occur.
384
+
385
+ Sample query Error Explanation
386
+ http://export.arxiv.org/api/query?start=not_an_int start must be an integer
387
+ http://export.arxiv.org/api/query?start=-1 start must be >= 0
388
+ http://export.arxiv.org/api/query?max_results=not_an_int max_results must be an integer
389
+ http://export.arxiv.org/api/query?max_results=-1 max_results must be >= 0
390
+ http://export.arxiv.org/api/query?id_list=1234.1234 malformed id - see arxiv identifier explanation
391
+ http://export.arxiv.org/api/query?id_list=cond-mat/0709123 malformed id - see arxiv identifier explanation
392
+
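+ Because errors arrive as an ordinary feed with a single entry, a client can detect them by inspecting that entry. A minimal sketch with feedparser:
+
+ import feedparser
+
+ d = feedparser.parse("http://export.arxiv.org/api/query?id_list=1234.12345")
+ entry = d.entries[0]
+ if entry.id.startswith("http://arxiv.org/api/errors"):
+     raise RuntimeError("arXiv API error: " + entry.summary)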
393
+ 4. Examples
394
+ Once you have familiarized yourself with the API, you should be able to easily write programs that call the API automatically. Most programming languages, if not all, have libraries that allow you to make HTTP requests. Since Atom is growing, not all languages have libraries that support Atom parsing, so most of the programming effort will be in digesting the responses you receive. The languages that we know of that can easily handle calling the api via HTTP and parsing the results include:
395
+
396
+ Perl (via LWP) (example)
397
+
398
+ Python (via urllib) (example)
399
+
400
+ Ruby (via uri and net::http) (example)
401
+
402
+ PHP (via file_get_contents()) (example)
403
+
404
+
405
+ 4.1. Simple Examples
406
+ Below we include code snippets for these languages that perform the bare minimum functionality - calling the api and printing the raw Atom results. If your favorite language is not up here, write us with an example, and we'll be glad to post it!
407
+
408
+ All of the simple examples produce an output which looks like:
409
+
410
+ Example: A Typical Atom Response
411
+
412
+
413
+ <?xml version="1.0" encoding="utf-8"?>
414
+ <feed xmlns="http://www.w3.org/2005/Atom" xmlns:opensearch="http://a9.com/-/spec/opensearch/1.1/" xmlns:arxiv="http://arxiv.org/schemas/atom">
415
+ <link xmlns="http://www.w3.org/2005/Atom" href="http://arxiv.org/api/query?search_query=all:electron&amp;id_list=&amp;start=0&amp;max_results=1" rel="self" type="application/atom+xml"/>
416
+ <title xmlns="http://www.w3.org/2005/Atom">ArXiv Query: search_query=all:electron&amp;id_list=&amp;start=0&amp;max_results=1</title>
417
+ <id xmlns="http://www.w3.org/2005/Atom">http://arxiv.org/api/cHxbiOdZaP56ODnBPIenZhzg5f8</id>
418
+ <updated xmlns="http://www.w3.org/2005/Atom">2007-10-08T00:00:00-04:00</updated>
419
+ <opensearch:totalResults xmlns:opensearch="http://a9.com/-/spec/opensearch/1.1/">1000</opensearch:totalResults>
420
+ <opensearch:startIndex xmlns:opensearch="http://a9.com/-/spec/opensearch/1.1/">0</opensearch:startIndex>
421
+ <opensearch:itemsPerPage xmlns:opensearch="http://a9.com/-/spec/opensearch/1.1/">1</opensearch:itemsPerPage>
422
+ <entry xmlns="http://www.w3.org/2005/Atom" xmlns:arxiv="http://arxiv.org/schemas/atom">
423
+ <id xmlns="http://www.w3.org/2005/Atom">http://arxiv.org/abs/hep-ex/0307015</id>
424
+ <published xmlns="http://www.w3.org/2005/Atom">2003-07-07T13:46:39-04:00</published>
425
+ <updated xmlns="http://www.w3.org/2005/Atom">2003-07-07T13:46:39-04:00</updated>
426
+ <title xmlns="http://www.w3.org/2005/Atom">Multi-Electron Production at High Transverse Momenta in ep Collisions at
427
+ HERA</title>
428
+ <summary xmlns="http://www.w3.org/2005/Atom"> Multi-electron production is studied at high electron transverse momentum in
429
+ positron- and electron-proton collisions using the H1 detector at HERA. The
430
+ data correspond to an integrated luminosity of 115 pb-1. Di-electron and
431
+ tri-electron event yields are measured. Cross sections are derived in a
432
+ restricted phase space region dominated by photon-photon collisions. In general
433
+ good agreement is found with the Standard Model predictions. However, for
434
+ electron pair invariant masses above 100 GeV, three di-electron events and
435
+ three tri-electron events are observed, compared to Standard Model expectations
436
+ of 0.30 \pm 0.04 and 0.23 \pm 0.04, respectively.
437
+ </summary>
438
+ <author xmlns="http://www.w3.org/2005/Atom">
439
+ <name xmlns="http://www.w3.org/2005/Atom">H1 Collaboration</name>
440
+ </author>
441
+ <arxiv:comment xmlns:arxiv="http://arxiv.org/schemas/atom">23 pages, 8 figures and 4 tables</arxiv:comment>
442
+ <arxiv:journal_ref xmlns:arxiv="http://arxiv.org/schemas/atom">Eur.Phys.J. C31 (2003) 17-29</arxiv:journal_ref>
443
+ <link xmlns="http://www.w3.org/2005/Atom" href="http://arxiv.org/abs/hep-ex/0307015v1" rel="alternate" type="text/html"/>
444
+ <link xmlns="http://www.w3.org/2005/Atom" title="pdf" href="http://arxiv.org/pdf/hep-ex/0307015v1" rel="related" type="application/pdf"/>
445
+ <arxiv:primary_category xmlns:arxiv="http://arxiv.org/schemas/atom" term="hep-ex" scheme="http://arxiv.org/schemas/atom"/>
446
+ <category term="hep-ex" scheme="http://arxiv.org/schemas/atom"/>
447
+ </entry>
448
+ </feed>
449
+
450
+ 4.1.1. Perl
451
+ LWP is in the default perl installation on most platforms. It can be downloaded and installed from CPAN. Sample code to produce the above output is:
452
+
453
+
454
+ use LWP;
455
+ use strict;
456
+
457
+ my $url = 'http://export.arxiv.org/api/query?search_query=all:electron&start=0&max_results=1';
458
+ my $browser = LWP::UserAgent->new();
459
+ my $response = $browser->get($url);
460
+ print $response->content();
461
+
462
+ 4.1.2. Python
463
+ The urllib module is part of the python standard library, and is included in any default installation of python. Sample code to produce the above output in Python 2.7 is:
464
+
465
+
466
+ import urllib
467
+ url = 'http://export.arxiv.org/api/query?search_query=all:electron&start=0&max_results=1'
468
+ data = urllib.urlopen(url).read()
469
+ print data
470
+ whereas in Python 3 an example would be:
471
+
472
+
473
+ import urllib.request as libreq
474
+ with libreq.urlopen('http://export.arxiv.org/api/query?search_query=all:electron&start=0&max_results=1') as url:
475
+     r = url.read()
476
+ print(r)
477
+
478
+ 4.1.3. Ruby
479
+ The net/http and uri modules are part of the ruby standard library, and are included in any default installation of ruby. Sample code to produce the above output is:
480
+
481
+
482
+ require 'net/http'
483
+ require 'uri'
484
+ url = URI.parse('http://export.arxiv.org/api/query?search_query=all:electron&start=0&max_results=1')
485
+ res = Net::HTTP.get_response(url)
486
+ print res.body
487
+
488
+ 4.1.4. PHP
489
+ The file_get_contents() function is part of the PHP core language:
490
+
491
+
492
+ <?php
493
+ $url = 'http://export.arxiv.org/api/query?search_query=all:electron&start=0&max_results=1';
494
+ $response = file_get_contents($url);
495
+ print_r($response);
496
+ ?>
497
+
498
+ 4.2. Detailed Parsing Examples
499
+ The examples above don't cover how to parse the Atom results returned to extract the information you might be interested in. They also don't cover how to do more advanced programming of the API to perform such tasks as downloading chunks of the full results list one page at a time. The table below contains links to more detailed examples for each of the languages above, as well as to the libraries used to parse Atom.
500
+
501
+ Language Library Parsing Example Paging Example
502
+ Perl XML::Atom parsing paging
503
+ Python feedparser parsing paging
504
+ Ruby feedtools parsing paging
505
+ PHP SimplePie parsing paging
506
+
507
+ 5. Appendices
508
+
509
+ 5.1. Details of Query Construction
510
+ As outlined in the Structure of the API section, the interface to the API is quite simple. This simplicity, combined with search_query construction and result set filtering through id_list, makes the API a powerful tool for harvesting data from the arXiv. In this section, we outline the possibilities for constructing search_query's to retrieve our desired article lists. We outlined how to use the id_list parameter to filter result sets in search_query and id_list logic.
511
+
512
+ In the arXiv search engine, each article is divided up into a number of fields that can individually be searched. For example, the titles of an article can be searched, as well as the author list, abstracts, comments and journal reference. To search one of these fields, we simply prepend the field prefix followed by a colon to our search term. For example, suppose we wanted to find all articles by the author Adrian Del Maestro. We could construct the following query
513
+
514
+ http://export.arxiv.org/api/query?search_query=au:del_maestro
515
+
516
+ This returns nine results. The following table lists the field prefixes for all the fields that can be searched.
517
+
518
+ prefix explanation
519
+ ti Title
520
+ au Author
521
+ abs Abstract
522
+ co Comment
523
+ jr Journal Reference
524
+ cat Subject Category
525
+ rn Report Number
526
+ id Id (use id_list instead)
527
+ all All of the above
528
+ Note: The id_list parameter should be used rather than search_query=id:xxx to properly handle article versions. In addition, note that all: searches in each of the fields simultaneously.
529
+
530
+ The API allows advanced query construction by combining these search fields with Boolean operators. For example, suppose we want to find all articles by the author Adrian DelMaestro that also contain the word checkerboard in the title. We could construct the following query, using the AND operator:
531
+
532
+ http://export.arxiv.org/api/query?search_query=au:del_maestro+AND+ti:checkerboard
533
+
534
+ As expected, this query picked out the one of the nine previous results with checkerboard in the title. Note that we included + signs in the urls to the API. In a url, a + sign encodes a space, which is useful since spaces are not allowed in url's. It is always a good idea to escape the characters in your url's, which is a common feature in most programming libraries that deal with url's. Note that the <title> of the returned feed has spaces in the query constructed. It is a good idea to look at <title> to see if you have escaped your url correctly.
535
+
536
+ The following table lists the three possible Boolean operators.
537
+
538
+ AND
539
+ OR
540
+ ANDNOT
541
+ The ANDNOT Boolean operator is particularly useful, as it allows us to filter search results based on certain fields. For example, if we wanted all of the articles by the author Adrian DelMaestro with titles that did not contain the word checkerboard, we could construct the following query:
542
+
543
+ http://export.arxiv.org/api/query?search_query=au:del_maestro+ANDNOT+ti:checkerboard
544
+
545
+ As expected, this query returns eight results.
546
+
547
+ Finally, even more complex queries can be used by using parentheses for grouping the Boolean expressions. To include parentheses in in a url, use %28 for a left-parens (, and %29 for a right-parens ). For example, if we wanted all of the articles by the author Adrian DelMaestro with titles that did not contain the words checkerboard, OR Pyrochore, we could construct the following query:
548
+
549
+ http://export.arxiv.org/api/query?search_query=au:del_maestro+ANDNOT+%28ti:checkerboard+OR+ti:Pyrochlore%29
550
+
551
+ This query returns three results. Notice that the <title> element displays the parentheses correctly, meaning that we used the correct url escaping.
552
+
553
+ So far we have only used single words as the field terms to search for. You can include entire phrases by enclosing the phrase in double quotes, escaped by %22. For example, if we wanted all of the articles by the author Adrian DelMaestro with titles that contain quantum criticality, we could construct the following query:
554
+
555
+ http://export.arxiv.org/api/query?search_query=au:del_maestro+AND+ti:%22quantum+criticality%22
556
+
557
+ This query returns one result, and notice that the feed <title> contains double quotes as expected. The table below lists the two grouping operators used in the API.
558
+
559
+ symbol encoding explanation
560
+ ( ) %28 %29 Used to group Boolean expressions for Boolean operator precedence.
561
+ double quotes %22 %22 Used to group multiple words into phrases to search a particular field.
562
+ space + Used to extend a search_query to include multiple fields.
563
+
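+ Rather than writing %28, %29 and %22 by hand, these encodings can be produced programmatically. A short Python 3 sketch (urllib.parse.quote_plus handles all three rows of the table above):
+
+ from urllib.parse import quote_plus
+
+ raw = 'au:del_maestro ANDNOT (ti:checkerboard OR ti:Pyrochlore)'
+ url = "http://export.arxiv.org/api/query?search_query=" + quote_plus(raw)
+ # quote_plus turns spaces into +, ( into %28, ) into %29, and " into %22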
564
+ 5.1.1. A Note on Article Versions
565
+ Each arXiv article has a version associated with it. The first time an article is posted, it is given a version number of 1. When subsequent corrections are made to an article, it is resubmitted, and the version number is incremented. At any time, any version of an article may be retrieved.
566
+
567
+ When using the API, if you want to retrieve the latest version of an article, you may simply enter the arxiv id in the id_list parameter. If you want to retrieve information about a specific version, you can do this by appending vn to the id, where n is the version number you are interested in.
568
+
569
+ For example, to retrieve the latest version of cond-mat/0207270, you could use the query http://export.arxiv.org/api/query?id_list=cond-mat/0207270. To retrieve the very first version of this article, you could use the query http://export.arxiv.org/api/query?id_list=cond-mat/0207270v1
570
+
571
+
572
+ 5.2. Details of Atom Results Returned
573
+ The following table lists each element of the returned Atom results. For a more detailed explanation see Outline of an Atom Feed.
574
+
575
+ element explanation
576
+ feed elements
577
+ <title> The title of the feed containing a canonicalized query string.
578
+ <id> A unique id assigned to this query.
579
+ <updated> The last time search results for this query were updated. Set to midnight of the current day.
580
+ <link> A url that will retrieve this feed via a GET request.
581
+ <opensearch:totalResults> The total number of search results for this query.
582
+ <opensearch:startIndex> The 0-based index of the first returned result in the total results list.
583
+ <opensearch:itemsPerPage> The number of results returned.
584
+ entry elements
585
+ <title> The title of the article.
586
+ <id> A url http://arxiv.org/abs/id
587
+ <published> The date that version 1 of the article was submitted.
588
+ <updated> The date that the retrieved version of the article was submitted. Same as <published> if the retrieved version is version 1.
589
+ <summary> The article abstract.
590
+ <author> One for each author. Has child element <name> containing the author name.
591
+ <link> Can be up to 3 given url's associated with this article.
592
+ <category> The arXiv or ACM or MSC category for an article if present.
593
+ <arxiv:primary_category> The primary arXiv category.
594
+ <arxiv:comment> The authors comment if present.
595
+ <arxiv:affiliation> The author's affiliation included as a subelement of <author> if present.
596
+ <arxiv:journal_ref> A journal reference if present.
597
+ <arxiv:doi> A url for the resolved DOI to an external resource if present.
598
+ 5.3. Subject Classifications
599
+ For the complete list of arXiv subject classifications, please visit the taxonomy page.
benchmarks/bibtex-generation/env/bibtex_generation.py ADDED
File without changes
benchmarks/bibtex-generation/env/claude_example.py ADDED
@@ -0,0 +1,11 @@
+ import anthropic
+
+ # Read the Claude API key from a local file and request a single completion.
+ client = anthropic.Client(open("claude_api_key.txt").read().strip())
+ response = client.completion(
+     prompt=f"{anthropic.HUMAN_PROMPT} How many toes do dogs have?{anthropic.AI_PROMPT}",
+     stop_sequences=[anthropic.HUMAN_PROMPT],
+     model="claude-v1",
+     max_tokens_to_sample=100,
+ )
+ print(response)